YOLOV5中添加CBAM模块详解——原理+代码

news/2024/5/4 5:21:48/文章来源:https://blog.csdn.net/python_plus/article/details/129375664

目录

  • 一、前言
  • 二、CAM
        • 1. CAM计算过程
        • 2. 代码实现
        • 3. 流程图
  • 三、SAM
        • 1. SAM计算过程
        • 2. 代码实现
        • 3. 流程图
  • 四、YOLOv5中添加CBAM模块
  • 参考文章

一、前言

  由于卷积操作通过融合通道和空间信息来提取特征(通过N×NN×NN×N的卷积核与原特征图相乘,融合空间信息;通过不同通道的特征图加权求和,融合通道信息),论文提出的Convolutional Block Attention Module(CBAM)沿两个独立的维度(通道和空间)依次学习特征,然后与学习后的特征图与输入特征图相乘,进行自适应特征细化。

在这里插入图片描述

图1-1 CBAM结构图

  上图可以看到,CBAM包含CAM(Channel Attention Module)和SAM(Spartial Attention Module)两个子模块,分别进行通道和空间上的Attention。这样不只能够节约参数和计算力,并且保证了其能够做为即插即用的模块集成到现有的网络架构中去。

二、CAM

1. CAM计算过程

在这里插入图片描述

图2-1 CAM结构图

  输入特征图FFF首先经过两个并行的MaxPool层和AvgPool层,将特征图的维度从C×H×WC×H×WC×H×W变为C×1×1C×1×1C×1×1,然后经过Shared MLP模块。在该模块中,它先将通道数压缩为原来的1/r1/r1/r倍,再经过ReLU激活函数,然后扩张到原通道数。将这两个输出结果进行逐元素相加,再通过一个sigmoid激活函数得到Channel Attention的输出结果,然后将这个输出结果与原图相乘,变回C×H×WC×H×WC×H×W的大小。

  上述过程的计算公式如下:

Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))M_{c}(F)=\sigma (MLP(AvgPool(F))+MLP(MaxPool(F)))Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))
=σ(W1(W0(Favgc))+W1(W0(Fmaxc)))=\sigma (W_{1}(W_{0}(F^{c}_{avg}))+W_{1}(W_{0}(F^{c}_{max})))=σ(W1(W0(Favgc))+W1(W0(Fmaxc)))

  其中,σ\sigmaσ代表sigmoid激活函数,W0∈RC/r×CW_{0}\in R^{C/r\times C}W0RC/r×CW1∈RC×C/rW_{1}\in R^{C\times C/r}W1RC×C/r,且MLP的权重W0W_{0}W0W1W_{1}W1对于输入来说是共享的,ReLU激活函数位于W0W_{0}W0之后,W1W_{1}W1之前。

2. 代码实现

class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) # 上面公式中的W0self.relu = nn.ReLU()self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) # 上面公式中的W1self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))max_out = self.f2(self.relu(self.f1(self.max_pool(x))))out = self.sigmoid(avg_out + max_out)return torch.mul(x, out)

3. 流程图

  CAM过程的详细流程如下图所示:

在这里插入图片描述

图2-2 CAM流程图

三、SAM

1. SAM计算过程

在这里插入图片描述

图3-1 SAM结构图

  将Channel Attention的输出结果通过最大池化和平均池化得到两个1×H×W1×H×W1×H×W的特征图,然后经过Concat操作对两个特征图进行拼接,再通过7×77×77×7卷积将特征图的通道数变为111(实验证明7×77×77×7效果比3×33×33×3好),再经过一个sigmoid得到Spatial Attention的特征图,最后将输出结果与原输入特征图相乘,变回CHW大小。

  上述过程的计算公式如下:

Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))M_{s}(F)=\sigma (f^{7\times 7}([AvgPool(F);MaxPool(F)])) Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))

=σ(f7×7([Favgs;Fmaxs]))=\sigma (f^{7\times 7}([F^{s}_{avg};F^{s}_{max}]))=σ(f7×7([Favgs;Fmaxs]))

  其中,σ\sigmaσ代表sigmoid激活函数,f7×7f^{7\times 7}f7×7代表卷积核大小为7×77×77×7的卷积过程。

2. 代码实现

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avg_out, max_out], dim=1)out = self.sigmoid(self.conv(out))return torch.mul(x, out)

3. 流程图

  SAM过程的详细流程如下图所示:

在这里插入图片描述

图3-2 SAM流程图

四、YOLOv5中添加CBAM模块

  • 修改common.py
    在common.py中添加下列代码:
class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)self.relu = nn.ReLU()self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))max_out = self.f2(self.relu(self.f1(self.max_pool(x))))out = self.sigmoid(avg_out + max_out)return torch.mul(x, out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avg_out, max_out], dim=1)out = self.sigmoid(self.conv(out))return torch.mul(x, out)class CBAMC3(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper(CBAMC3, self).__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])self.channel_attention = ChannelAttention(c2, 16)self.spatial_attention = SpatialAttention(7)def forward(self, x):# 将最后的标准卷积模块改为了注意力机制提取特征return self.spatial_attention(self.channel_attention(self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))))
  • 修改yolo.py
    在yolo.py的if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3TR,......]中添加CBAMC3,即修改后的代码为:
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,C3, C3TR, ASPP, CBAMC3]:c1, c2 = ch[f], args[0]  if c2 != no:  c2 = make_divisible(c2 * gw, 8)  args = [c1, c2, *args[1:]] 
  • 修改yolov5s.yaml
    修改后的yolov5s.yaml如下:
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, CBAMC3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, CBAMC3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, CBAMC3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, CBAMC3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

参考文章

CBAM——即插即用的注意力模块(附代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_267626.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录-51-110.平衡二叉树

目录前言题目1.求高度和深度的区别节点的高度节点的深度2. 本题思路分析:3. 算法实现4. pop函数的算法复杂度5. 算法坑点前言 在本科毕设结束后,我开始刷卡哥的“代码随想录”,每天一节。自己的总结笔记均会放在“算法刷题-代码随想录”该专…

学习笔记:基于SpringBoot的牛客网社区项目实现(二)之Spring MVC入门

1.1 函数的返回值为空,因为可以使用response对象向浏览器返回数据。声明了request对象和response对象,dispatcherservlet自动将这两个对象传入 RequestMapping("/http")public void http(HttpServletRequest request, HttpServletResponse re…

不会吧,难道真的有程序员不知道怎么接单赚钱吗?

随着大环境逐渐转好,跳槽、新工作、兼职等等机会都浮出水面。抛开跳槽、新工作不谈,今天就专门来说说程序员接单赚钱有哪些靠谱的平台。 首先分享一波关于接私活有哪些注意事项,给大家提个醒,避免盲目入坑。 一、程序员接单须知…

深度学习知识点全面总结_深度学习总结

深度学习知识点全面总结_深度学习总结 神经网络与深度学习结构(图片选自《神经网络与深度学习》一邱锡鹏) 目录 常见的分类算法 一、深度学习概念 1.深度学习定义 2.深度学习应用 3.深度学习主要术语 二、神经网络基础 1. 神经网络组成 感知机 多层感知机 3.前向传播…

复位和时钟控制(RCC)

目录 复位 系统复位 电源复位 备份区复位 时钟控制 什么是时钟? 时钟来源 二级时钟源: 如何使用CubeMX配置时钟 复位 系统复位 当发生以下任一事件时,产生一个系统复位:1. NRST引脚上的低电平(外部复位) 2. 窗口看门狗计数终止(WWD…

项目实战典型案例27——单表的更新接口有9个之多

单表的更新接口有9个之多一:背景介绍环境准备引入pom依赖配置数据库连接mybatis配置文件Mybatis的配置类编写通用的更新语句可以覆盖的更新接口暂时无法覆盖的接口测试四:总结五:升华一:背景介绍 本篇博客是对项目开发中出现的单…

197.Spark(四):Spark 案例实操,MVC方式代码编程

一、Spark 案例实操 1.数据准备 电商网站的用户行为数据,主要包含用户的 4 种行为:搜索,点击,下单,支付 样例类: 2. Top10 热门品类 先按照点击数排名,靠前的就排名高;如果点击数相同,再比较下单数;下单数再相同,就比较支付数。 我们有多种写法,越往后性能越…

k8s学习之路 | k8s 工作负载 ReplicaSet

文章目录1. ReplicaSet 基础概念1.1 RS 是什么?1.2 RS 工作原理1.3 什么时候使用 RS1.4 RS 示例1.5 非模板 Pod 的获得1.6 编写 RS1.7 使用 RS1.8 RS 替代方案2. ReplicaSet 与 ReplicationController2.1 关于 RS、RC2.2 两者的选择器区别2.3 总结1. ReplicaSet 基础…

yii2项目使用frp https2http插件问题

yii2内网项目,使用frp进行内网穿透,使用 https2http插件把内网服务器http流量转成https,会存在一个问题:当使用 $this->redirect(...) 或 $this->goHome() (其实用的也是前者)等重定向时,…

物联网毕设 -- 智能厨房监测系统(改)

前言 在家庭生活中,厨房是必不可少的,所以厨房的安全问题关乎着我们大家的生命,所以提出智能厨房监测系统,目的就是为我们减少不必要的安全问题 ⚠️⚠️(本文章仅提供思路和实现方法,并不包含代码&#x…

javaWeb在线考试系统

一、项目简介 本项目是一套javaWeb在线考试系统,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse 确保…

DBeaver连接mysql、oracle数据库

1. DBeaver连接mysql 1) 下载DBeaver https://dbeaver.io/download/,并安装 2) 新建数据库连接 3)选择mysql驱动程序 4)填写连接设置内容 5)点击 “编辑驱动设置”,并填写相关信息 6)选择本地…

厦大纪老师chatgpt相关讲座3.7

在线更新数据,迭代学习训练,进而提高模型性能。 比较明显的是API部分,这一步学习的就是intruction,实现人机写作的复杂系统工程 数据充足,维基类似于百度百科 transformer结构更有优势,预测下一个字,模型越…

优思学院|盘点,精益生产25个工具!【必需收藏】

精益生产方法需要一种全面的方法才能有效实施。精益这个概念是每个接触产品供应链的人都要实践的,无论是在计划方面还是在分析方面。 精益生产工具有助于持续改进生产效率和产品或服务质量。精益工具是要减少 Muda (浪费),从生产过…

6.4 深度负反馈放大电路放大倍数的分析

实用的放大电路中多引入深度负反馈,因此分析负反馈放大电路的重点是从电路中分离出反馈网络,并求出反馈系数 F˙\pmb{\dot F}F˙。 一、深度负反馈的实质 在负反馈放大电路的一般表达式中,若 ∣1A˙F˙∣>>1|1\dot A\dot F|>>1…

FPGA使用GTX实现SFP光纤收发SDI视频 全网首创略显高端 提供工程源码和技术支持

目录1、前言2、设计思路和框架3、vivado工程详解4、上板调试验证并演示5、福利:工程代码的获取1、前言 FPGA实现SDI视频编解码目前有两种方案: 一是使用专用编解码芯片,比如典型的接收器GS2971,发送器GS2972,优点是简…

MCM 箱模型建模方法及大气 O3 来源解析实用干货

OBM 箱模型可用于模拟光化学污染的发生、演变过程,研究臭氧的生成机制和进行敏感性分析,探讨前体物的排放对光化学污染的影响。箱模型通常由化学机理、物理过程、初始条件、输入和输出模块构成,化学机理是其核心部分。MCM (Master Chemical M…

机器学习中的数学——精确率与召回率

在Yolov5训练完之后会有很多图片,它们的具体含义是什么呢? 通过这篇博客,你将清晰的明白什么是精确率、召回率。这个专栏名为白话机器学习中数学学习笔记,主要是用来分享一下我在 机器学习中的学习笔记及一些感悟,也希…

自动化框架如何搭建?让10年阿里自动化测试老司机帮你搞定!自动化测试脚本怎么写?

一、何为框架?何为自动化测试框架? 无论是日常技术交流,还是在自动化测试实践中,经常会听到一个词叫:框架。之前对“框架”这个词知其然不知其所以然。现在看过一些资料以及加上我自己的一些实践有了我自己的一些看法…

日常文档标题级别规范

这里写自定义目录标题欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants创建一个自定义列表如何创建一个注…