CAM类激活映射 |神经网络可视化 | 热力图

news/2024/5/19 14:58:00/文章来源:https://blog.csdn.net/holly_Z_P_F/article/details/130011296

文章目录

    • 前言:
    • 安装库:
    • 分类案例--ResNet50
    • 分割案例
      • AttributeError: ‘tuple‘ object has no attribute ‘cpu‘
      • RuntimeError: grad can be implicitly created only for scalar outputs
      • TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
      • 完整代码

前言:

本篇文章只是教程,不涉及原理,感兴趣可以自行搜索
如图,热力图可以很好的反映出网络究竟注意图片的哪一部分
在这里插入图片描述
github官方教程:
https://github.com/jacobgil/pytorch-grad-cam
参考博客:
https://blog.csdn.net/u014264373/article/details/85415921
https://blog.csdn.net/u014264373/article/details/116302678
但还是遇到了很多报错,解决过程记录如下:

安装库:

pip install grad-cam

分类案例–ResNet50

案例图片:
在这里插入图片描述

案例代码:
这个代码是可以跑通的,将图片保存到你本地,然后设置好路径即可。
(需要下载ResNet预训练模型)

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as npdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.detach().numpy()  # [0]# 对数据进行归一化if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrpath = "../examples/both.png"
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1; squeeze
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:# targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]  # 指定查看class_num为386的热力图targets = None  # 选定目标类别,如果不设置,则默认为分数最高的那一类# aug_smooth=True, eigen_smooth=True 使用图像增强是热力图变得更加平滑grayscale_cams = cam(input_tensor=input_tensor, targets=targets)  # targets=None 自动调用概率最大的类别显示for grayscale_cam, tensor in zip(grayscale_cams, input_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

最后出来的结果应该就是这样一张图。
在这里插入图片描述

分割案例

如果上面的代码你跑通了
那么如何为自己的网络生成热力图呢?
有几个需要注意的点:(最后会附上完整代码)

首先,切换成你的网络了模型加载就不说了,这个自己搞好。
然后,你的网络是否是在gpu上跑的,如果是
输入数据要放gpu上

path = './test_img/yu.jpg'
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeeze
img_tensor = torchvision.transforms.functional.resize(img, [352, 352])
img_tensor = img_tensor.cuda()   # 加一句这个

然后按照上面的代码,修改这一句,改成你要查看的层:

target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值

把这个改成你要的层,然后运行一下,可能会遇到报错:

AttributeError: ‘tuple‘ object has no attribute ‘cpu‘

如果出现这个报错,可以看下你的网络最终输出是几个特征。因为是自己写的网络,有的因为训练需要,最终返回的是多个结果。

print(model(x))

如果有多个结果,会被变成一个元组。后面需要转cpu,元组tuple没有.cpu的方法,所以报错。
解决方法:
先把你的网络包装一下,你返回了多个值,选择有用的那一个就行
我这里选择了多个输出的第一个,自己视情况而定

class SegmentationModelOutputWrapper(torch.nn.Module):def __init__(self, model):super(SegmentationModelOutputWrapper, self).__init__()self.model = modeldef forward(self, x):return self.model(x)[0]  # 我这里选择了多个输出的第一个,自己视情况而定model = NetWork()
model.load_state_dict(torch.load(opt.snap_path))
# 网络加载后先包装下  修改输出
model = SegmentationModelOutputWrapper(model)

然后再运行,可能会出现报错:

RuntimeError: grad can be implicitly created only for scalar outputs

这个问题的解决办法是:
你需要去到base_cam.py这个库文件里面去
第85行有一句loss.backward(retain_graph = True)
将其修改为loss.backward(torch.ones_like(loss),retain_graph=True)

在这里插入图片描述

参考链接:https://blog.csdn.net/weixin_44390884/article/details/127893163

还有一个报错:

TypeError: can’t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

如果你的模型最后返回的特征是tensor的特征,那么需要对tensor2img做改动
在这里插入图片描述

np_arr = tensor.detach().numpy()  # [0]

修改为:

np_arr = tensor.cpu().detach().numpy()  # [0]

完整代码

import os
import torch
import argparse
import numpy as np
import imageio
import torchvision
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from matplotlib import pyplot as pltdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.cpu().detach().numpy()  # [0]if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrif __name__ == '__main__':model = NetWork()model.load_state_dict(torch.load(opt.snap_path))# torchinfo.summary(model=model,input_size=(8, 3, 352, 352))# 包装下  修改输出model = SegmentationModelOutputWrapper(model)model.eval()path = './test_img/yu.jpg'bin_data = torchvision.io.read_file(path)  # 加载二进制数据img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeezeimg_tensor = torchvision.transforms.functional.resize(img, [352, 352])img_tensor = img_tensor.cuda()target_layers = [model.model.ncd]targets = Nonewith GradCAM(model=model,target_layers=target_layers,use_cuda=True) as cam:grayscale_cams = cam(input_tensor=img_tensor,targets=targets,aug_smooth=True)# cam_image = show_cam_on_image(img_rgb, grayscale_cam, use_rgb=True)for grayscale_cam, tensor in zip(grayscale_cams, img_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_283012.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯嵌入式第十三届省赛题目解析

马上就要比赛了&#xff0c;我也是把自己写完调试好的题目分享出来给大家&#xff0c;同时也祝大家取得自己理想的成绩。 好了废话不多说&#xff0c;我们先看客观题再看程序设计题。 目录 客观题&#xff1a; 程序设计题&#xff1a; 题目解析&#xff1a; CubeMX配置 …

分期的秘密:名义利率和实际利率

分期付款&#xff0c;是一种常见的消费方式&#xff0c;但是这其中却有不少猫腻。名义上的年化利率和实际上的利率竟有可能相差两倍之多。今天&#xff0c;我就以招行的现金分期举例&#xff0c;简单剖析一下其中的玄机。 以上就是招行现金分期的月利率&#xff0c;我们做一点小…

【李宏毅】深度学习——HW4-Speaker Identification

Speaker Identification 1.Goal 根据给定的语音内容&#xff0c;识别出说话者是谁 2.Data formats 2.1data directory 目录下有三个json文件和很多pt文件&#xff0c;三个json文件作用标注在下图中&#xff0c;pt文件就是语音内容。 mapping文件 metadata文件 n_mels:Th…

视频批量剪辑,如何在合并视频的时候添加上片头片尾。

我们经常在批量剪辑视频的时候&#xff0c;会遇到要给视频添加片头片尾的情况&#xff1f;那么应该要如何操作呢&#xff1f;今天就由小编来教教大家一个操作办法。 首先我们要进入媒体梦工厂主页面&#xff0c;并切换到“嵌套合并”这个操作页面来。 第二步&#xff0c;封面如…

第十七章 镜像架构和规划 - 双数据中心镜像配置和异地容灾

文章目录第十七章 镜像架构和规划 - 双数据中心镜像配置和异地容灾双数据中心镜像配置和异地容灾具有本地 DR 和地理上分离的 DR 的故障转移对具有地理位置分离的完全冗余灾难恢复环境的故障转移对地理上分离的故障转移对第十七章 镜像架构和规划 - 双数据中心镜像配置和异地容…

软件工程复习4.7

软件危机 软件危机的定义 软件在开发和维护过程中遇到的一系列问题 软件危机的表现 成本高软件质量得不到保证进度难以控制维护非常困难 软件危机包含两方面问题 如何开发软件&#xff0c;以满足不断增长&#xff0c;日趋复杂的需求如何维护数量不断膨胀的软件产品 软件…

一款多功能多合一的档案库房常用的一款空气质量检测仪

档案馆库房专用的一款智能型空气质量云测仪 空气质量检测仪 空气质量传感器 环境集成传感器 集成/温湿度、粉尘PM2.5 PM10/甲醛/TVOC/CO2等高度集成的一款传感器/RS485信号输出 ◆温度测量参数: (1)温度测量范围: -40~80℃(2&#xff09;输出分辨率:0.1oC (3&#xff09;测…

GPT-4博客介绍

文章目录gpt-4Visual inputs(视觉输入)Training process(训练过程)Loss Prediction&#xff08;损失预测&#xff09;Steerability&#xff08;可控性&#xff09;Limitations(限制)Risks & mitigations(风险和应对措施)gpt-4 在越复杂任务上&#xff0c;GPT4越是强于chat…

溯源取证-内存取证基础篇

使用工具&#xff1a; volatility_2.6_lin64_standalone 镜像文件&#xff1a; CYBERDEF-567078-20230213-171333.raw 使用环境&#xff1a; kali linux 2022.02 我们只有一个RAW映像文件&#xff0c;如何从该映像文件中提取出我们想要的东西呢&#xff1f; 1.Which volatili…

IPV6 资料收集

IPV6与IPV4区别 1、地址长度的区别&#xff1a;IPv4协议具有32位&#xff08;4字节&#xff09;地址长度&#xff1b;IPv6协议具有128位&#xff08;16字节&#xff09;地址长度。 2、地址的表示方法区别&#xff1a;IPv4地址是以小数表示的二进制数。 IPv6地址是以十六进制表…

Anaconda使用(一)使用Navigator或者prompt创建虚拟环境

入门 conda是一个功能强大的环境管理器&#xff0c;可以有效避免python各个版本和库之间产生的冲突问题。 安装问题 Navigator Navigator是conda中的一个图形化用户界面&#xff0c;可以在类似Web的界面中使用conda。 以下以windows为例子&#xff0c;打开的过程会比较长…

蓝桥杯赛前自救攻略,备赛抱佛脚指南

目录前言一、复习语言知识1、代码起手框架2、vector初始化2、unordered_map3、输入输出问题二、复习考试范围知识1、深度优先搜索&#xff08;Depth-First-Search&#xff09;模板2、随机字符、数字三、复习比赛真题1、模拟题2、动态规划题四、其他前言 明天就要开始蓝桥杯了&a…

安全防御 --- 防火墙-- ASPF、NAT

ASPF、NAT 1、FTP技术 &#xff08;1&#xff09;简介&#xff1a; 主机之间传输文件是IP网络的一个重要功能&#xff0c;如今人们可以方便地使用网页、邮箱进行文件传输。 然而在互联网早期&#xff0c;Web&#xff08;World Wide Web&#xff0c;万维网&#xff09;还未出现…

CDA证书值得考吗?数据分析前景怎么样?

在数据时代快速发展的现在&#xff0c;涌现出一批高薪的岗位&#xff0c;像大数据开发工程师、数据挖掘工程师、数据分析师等等。相对开发工程师等类型更偏向于技术的岗位而言&#xff0c;数据分析师对于学习者的要求同样严格。 下面就来给大家科普一下不同数据分析相关岗位提供…

把ChatGPT接入我的个人网站

效果图 详细内容和使用说明可以查看我的个人网站文章 把ChatGPT接入我的个人网站 献给有外网服务器的小伙伴 如果你本人已经有一台外网的服务器&#xff0c;并且页拥有一个OpenAI API Key&#xff0c;那么下面就可以参照我的教程来搭建一个自己的ChatGPT。 需要的环境 Cento…

降噪蓝牙耳机哪个品牌好?降噪蓝牙耳机排行推荐

随着蓝牙耳机品牌越来越多&#xff0c;型号更是让人眼花缭乱&#xff0c;各种功能也是层出不穷。但是很多人在眼花缭乱的耳机中并不知道如何选择合适的&#xff0c;下面是我根据多年的耳机使用经验总结的几款值得推荐的降噪蓝牙耳机&#xff0c;快速来看。 1.南卡A2真无线降噪…

回收站文件恢复怎么做?4种方法推荐!

案例&#xff1a;回收站文件恢复 【今天弟弟借用我的电脑&#xff0c;不小心把我的回收站清空了&#xff01;里面还有些被我误删的文件&#xff0c;有朋友知道回收站文件删除后应该怎么恢复吗&#xff1f;急求一个解决方法&#xff01;感谢感谢&#xff01;】 当我们意外地删…

跨时钟域传输数据——单bit和多bit信号(总结)

文章目录前言一、慢时钟域到快时钟域1、单bit信号2、多bit信号二、快时钟域到慢时钟域1、单bit信号2、多bit信号三、多bit信号跨时钟域传输1、多个信号合并2、多周期路径 Multi-cycle Path/MCP3、使用格雷码4、使用异步FIFO5、使用DMUX电路结构6、握手信号传输四、简答题1、跨时…

RWKV 语言模型:具有 Transformer 优点的 CNN

RWKV 语言模型&#xff0c;这是一个具有巨大潜力的开源大型语言模型。由于 ChatGPT 和一般的大型语言模型最近受到了很多关注。在这篇文章中&#xff0c;我将尝试解释与大多数语言模型&#xff08;transformer&#xff09;相比&#xff0c;RWKV 有何特别之处。 RWKV 可视化 浅谈…

Filter过滤器和Listener监听器在Servlet的应用

文章目录Filter过滤器Filter过滤器的实现原理多个Filter的执行顺序Filter过滤器在Servlet的应用Listener监听器Listener监听器的介绍Listener监听器在Servlet中的应用Filter过滤器 Filter过滤器的实现原理 过滤器&#xff0c;顾名思义。我们可以根据需要制作特定的过滤器在某些…