0

0

Paddle可视化神经网络热力图(CAM)

P粉084495128

P粉084495128

发布时间:2025-07-18 10:03:50

|

502人浏览过

|

来源于php中文网

原创

本文介绍了使用Paddle实现神经网络热力图(CAM)可视化的方法。CAM可展示CNN分类时关注的区域,文中详细阐述其原理,提供了完整代码,包括图像预处理、模型输出提取、梯度计算等步骤,还说明如何通过热力图指导数据集扩充和数据增强,帮助优化模型性能。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

paddle可视化神经网络热力图(cam) - php中文网

Paddle可视化神经网络热力图(CAM)

Class Activation Mapping(CAM)是一个帮助可视化CNN的工具,通过它我们可以观察为了达到正确分类的目的,网络更侧重于哪块区域。比如,下面两幅图,一个是刷牙,一个是砍树,我们根据热力图可以看到高响应区域的确集中在我们认为最有助于作出判断的部位。

Paddle可视化神经网络热力图(CAM) - php中文网

本项目最初还是来自于项目:讯飞农作物生长情况识别挑战赛baseline(非官方),因为数据集不大,尽管模型收敛很好,但线上的分数确始终不能更进一步。于是想到可以可视化一下网络的CAM,观察一下指导分类的高响应区域是否落在目标核心部位上。

CAM论文链接地址

  • CAM原理

其计算方法如下图所示。对于一个CNN模型,对其最后一个featuremap做全局平均池化(GAP)计算各通道均值,然后通过FC层等映射到class score,找出argmax,计算最大的那一类的输出相对于最后一个featuremap的梯度,再把这个梯度可视化到原图上即可。直观来说,就是看一下网络抽取到的高层特征的哪部分对最终的classifier影响更大。

ChatGPT Website Builder
ChatGPT Website Builder

ChatGPT网站生成器,AI对话快速生成网站

下载

Paddle可视化神经网络热力图(CAM) - php中文网

In [ ]
!cd 'data/data106772' && unzip -q img.zip
In [1]
%matplotlib inlineimport osfrom PIL import Imageimport paddleimport numpy as npimport cv2import matplotlib.pyplot as pltfrom draw_features import Res2Net_vdimport paddle.nn.functional as Fimport paddleimport warnings
warnings.filterwarnings('ignore')def draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False):
    '''
    绘制 Class Activation Map
    :param model: 加载好权重的Pytorch model
    :param img_path: 测试图片路径
    :param save_path: CAM结果保存路径
    :param transform: 输入图像预处理方法
    :param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
    :return:
    '''
    # 图像加载&预处理
    img = Image.open(img_path).convert('RGB')
    img = img.resize((224, 224), Image.BILINEAR) #Image.BILINEAR双线性插值
    if transform:
        img = transform(img)    # img = img.unsqueeze(0)
    img = np.array(img).astype('float32')
    img = img.transpose((2, 0, 1))
    img = paddle.to_tensor(img)
    img = paddle.unsqueeze(img, axis=0)    # print(img.shape)
    # 获取模型输出的feature/score

    output,features = model(img)    print('outputshape:',output.shape)    print('featureshape:',features.shape)    # lab = np.argmax(out.numpy())
    # 为了能读取到中间梯度定义的辅助函数
    def extract(g):
        global features_grad
        features_grad = g 
    # 预测得分最高的那一类对应的输出score
    pred = np.argmax(output.numpy())    # print('***********pred:',pred)
    pred_class = output[:, pred]    # print(pred_class)

    features.register_hook(extract)
    pred_class.backward() # 计算梯度
 
    grads = features_grad   # 获取梯度
    # print(grads.shape)
    # pooled_grads = paddle.nn.functional.adaptive_avg_pool2d( x = grads, output_size=[1, 1])
    pooled_grads = grads    
    # 此处batch size默认为1,所以去掉了第0维(batch size维)
    pooled_grads = pooled_grads[0]    # print('pooled_grads:', pooled_grads.shape)
    # print(pooled_grads.shape)
    features = features[0]    # print(features.shape)
    # 最后一层feature的通道数
    for i in range(2048):
        features[i, ...] *= pooled_grads[i, ...]
 
    heatmap = features.detach().numpy()
    
    heatmap = np.mean(heatmap, axis=0)    # print(heatmap)
    heatmap = np.maximum(heatmap, 0)    # print('+++++++++',heatmap)
    heatmap /= np.max(heatmap)    # print('+++++++++',heatmap)
    # 可视化原始热力图
    if visual_heatmap:
        plt.matshow(heatmap)
        plt.show()
 
    img = cv2.imread(img_path)  # 用cv2加载原始图像
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    # print(heatmap.shape)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原始图像
    superimposed_img = heatmap * 0.4 + img  # 这里的0.4是热力图强度因子
    cv2.imwrite(save_path, superimposed_img)  # 将图像保存到硬盘model_re2 = Res2Net_vd(layers=50, scales=4, width=26, class_dim=4)# model_re2 = Res2Net50_vd_26w_4s(class_dim=4)modelre2_state_dict = paddle.load("Hapi_MyCNN.pdparams")
model_re2.set_state_dict(modelre2_state_dict, use_structured_name=True)
use_gpu = Truepaddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')

model_re2.eval()

draw_CAM(model_re2, 'data/data106772/img/test/629.jpg', 'test3.jpg', transform=None, visual_heatmap=True)
W1123 17:02:34.611114  6534 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1123 17:02:34.615968  6534 device_context.cc:465] device: 0, cuDNN Version: 7.6.
outputshape: [1, 4]
featureshape: [1, 2048, 7, 7]
  • 实验结果

Paddle可视化神经网络热力图(CAM) - php中文网

  • 代码详解

关于代码,相信注释已经写得很明白了,需要注意的是,我把网络结构多返回了softmax层之前的特征向量,代码如下所示:

    def forward(self, inputs):        y = self.conv1_1(inputs)        y = self.conv1_2(y)        y = self.conv1_3(y)        y = self.pool2d_max(y)        blocks = []
        for block in self.block_list:            y = block(y)            blocks.append(y)        # draw_features(32, 32, y.cpu().numpy(), "{}/f7_layer3.png".format(savepath))
        # y = self.convf_1(y)
        y = self.pool2d_avg(y)        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])        out = self.out(y)        blocks.append(out)        return blocks[-1:-3:-1]

总结

通过该可视化方法,可以有针对性的对数据集进行扩充,以此来指导数据增强的方向。需要注意的是,大家需要对网络结构足够了解,CAM主要使用最后一层的特征向量,大家注意区分。

In [ ]

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
C++ 设计模式与软件架构
C++ 设计模式与软件架构

本专题深入讲解 C++ 中的常见设计模式与架构优化,包括单例模式、工厂模式、观察者模式、策略模式、命令模式等,结合实际案例展示如何在 C++ 项目中应用这些模式提升代码可维护性与扩展性。通过案例分析,帮助开发者掌握 如何运用设计模式构建高质量的软件架构,提升系统的灵活性与可扩展性。

14

2026.01.30

c++ 字符串格式化
c++ 字符串格式化

本专题整合了c++字符串格式化用法、输出技巧、实践等等内容,阅读专题下面的文章了解更多详细内容。

9

2026.01.30

java 字符串格式化
java 字符串格式化

本专题整合了java如何进行字符串格式化相关教程、使用解析、方法详解等等内容。阅读专题下面的文章了解更多详细教程。

12

2026.01.30

python 字符串格式化
python 字符串格式化

本专题整合了python字符串格式化教程、实践、方法、进阶等等相关内容,阅读专题下面的文章了解更多详细操作。

4

2026.01.30

java入门学习合集
java入门学习合集

本专题整合了java入门学习指南、初学者项目实战、入门到精通等等内容,阅读专题下面的文章了解更多详细学习方法。

20

2026.01.29

java配置环境变量教程合集
java配置环境变量教程合集

本专题整合了java配置环境变量设置、步骤、安装jdk、避免冲突等等相关内容,阅读专题下面的文章了解更多详细操作。

18

2026.01.29

java成品学习网站推荐大全
java成品学习网站推荐大全

本专题整合了java成品网站、在线成品网站源码、源码入口等等相关内容,阅读专题下面的文章了解更多详细推荐内容。

19

2026.01.29

Java字符串处理使用教程合集
Java字符串处理使用教程合集

本专题整合了Java字符串截取、处理、使用、实战等等教程内容,阅读专题下面的文章了解详细操作教程。

3

2026.01.29

Java空对象相关教程合集
Java空对象相关教程合集

本专题整合了Java空对象相关教程,阅读专题下面的文章了解更多详细内容。

6

2026.01.29

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 4.4万人学习

Pandas 教程
Pandas 教程

共15课时 | 1.0万人学习

ASP 教程
ASP 教程

共34课时 | 4.3万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号