pytorch代码实现之动态卷积模块ODConv

news/2024/5/20 0:01:29/文章来源:https://blog.csdn.net/DM_zx/article/details/132857530

ODConv动态卷积模块

ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。作为一种“即插即用”的操作,它可以轻易的嵌入到现有CNN网络中。ImageNet分类与COCO检测任务上的实验验证了所提ODConv的优异性:即可提升大模型的性能,又可提升轻量型模型的性能,实乃万金油是也!值得一提的是,受益于其改进的特征提取能力,ODConv搭配一个卷积核时仍可取得与现有多核动态卷积相当甚至更优的性能。

原文地址:Omni-Dimensional Dynamic Convolution

ODConv结构图
代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd
from models.common import Conv, autopadclass Attention(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):super(Attention, self).__init__()attention_channel = max(int(in_planes * reduction), min_channel)self.kernel_size = kernel_sizeself.kernel_num = kernel_numself.temperature = 1.0self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = Conv(in_planes, attention_channel, act=nn.ReLU(inplace=True))self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)self.func_channel = self.get_channel_attentionif in_planes == groups and in_planes == out_planes:  # depth-wise convolutionself.func_filter = self.skipelse:self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)self.func_filter = self.get_filter_attentionif kernel_size == 1:  # point-wise convolutionself.func_spatial = self.skipelse:self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)self.func_spatial = self.get_spatial_attentionif kernel_num == 1:self.func_kernel = self.skipelse:self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)self.func_kernel = self.get_kernel_attentionself._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def update_temperature(self, temperature):self.temperature = temperature@staticmethoddef skip(_):return 1.0def get_channel_attention(self, x):channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return channel_attentiondef get_filter_attention(self, x):filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return filter_attentiondef get_spatial_attention(self, x):spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)spatial_attention = torch.sigmoid(spatial_attention / self.temperature)return spatial_attentiondef get_kernel_attention(self, x):kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)return kernel_attentiondef forward(self, x):x = self.avgpool(x)x = self.fc(x)return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)class ODConv2d(nn.Module):def __init__(self, in_planes, out_planes, k, s=1, p=None, g=1, act=True, d=1,reduction=0.0625, kernel_num=1):super(ODConv2d, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.kernel_size = kself.stride = sself.padding = autopad(k, p)self.dilation = dself.groups = gself.kernel_num = kernel_numself.attention = Attention(in_planes, out_planes, k, groups=g,reduction=reduction, kernel_num=kernel_num)self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//g, k, k),requires_grad=True)self._initialize_weights()self.bn = nn.BatchNorm2d(out_planes)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())if self.kernel_size == 1 and self.kernel_num == 1:self._forward_impl = self._forward_impl_pw1xelse:self._forward_impl = self._forward_impl_commondef _initialize_weights(self):for i in range(self.kernel_num):nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')def update_temperature(self, temperature):self.attention.update_temperature(temperature)def _forward_impl_common(self, x):# Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,# while we observe that when using the latter method the models will run faster with less gpu memory cost.channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)batch_size, in_planes, height, width = x.size()x = x * channel_attentionx = x.reshape(1, -1, height, width)aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)aggregate_weight = torch.sum(aggregate_weight, dim=1).view([-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups * batch_size)output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))output = output * filter_attentionreturn outputdef _forward_impl_pw1x(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)x = x * channel_attentionoutput = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups)output = output * filter_attentionreturn outputdef forward(self, x):return self.act(self.bn(self._forward_impl(x)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_550711.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】实验13 使用Dropout抑制过拟合

文章目录 使用Dropout抑制过拟合1. 环境准备2. 导入数据集3. 对所有数据的预测3.1 数据集3.2 构建神经网络 3.3 训练模型3.4 分析模型 4. 对未见过数据的预测4.1 划分数据集4.2 构建神经网络4.3 训练模型4.4 分析模型 5. 使用Dropout抑制过拟合5.1 构建神经网络5.2 训练模型5.3…

加密货币交易所偿付能力的零知识证明

如何检测下一个 FTX 和 Mt. Gox 加密货币交易所 FTX 的内爆导致数十亿客户资金流失,这是加密货币历史上交易所破产的最新例子。历史可以追溯到 2014 年,当时处理 70% 比特币交易的历史最悠久、规模最大的交易所 Mt. Gox 丢失了用户的 850,000 个比特币。…

python使用websocket实现多端数据同步,多个websocket同步消息,断开链接自动清理

我使用的是flask_sock这个模块,我的使用场景是:可以让数据多端实时同步。在游戏控制后台和游戏选手的ipad上都可以实时调整角色的技能和点数什么的,所以需要这样的一个功能来实现数据实时同步。 下面是最小的demo案例: from fla…

指数渐变线

指数渐变线是非均匀传输线的一种。为何叫指数渐变线呢?其分布参数变化规律为指数规律,比如:单位长度的电感、电容、特性阻抗。 1、分析过程 从非均匀线的微分方程出发: 对方程两侧同时取微分: 化简得: …

elementui 中 DateTimePicker 组件时间自定义格式化

elementui 中 DateTimePicker 组件时间自定义格式化 需求分析 需求 elementui 中 DateTimePicker 组件时间自定义格式化 自定义需求&#xff1a;需要获取到 DateTimePicker 组件时间的值为[“2023/9/5 20:2”,“2023/9/4 2:10”] 分析 源码如下&#xff1a; <el-date-pick…

2023数学建模国赛游记

第一参加数学建模国赛&#xff0c;大概也是最后一次参加了&#xff0c;记录一下这几天的历程吧。 我们队的情况是计算机电气数统&#xff0c;计算机负责编程&#xff0c;电气学院的负责论文部分&#xff0c;数统的同学负责建模&#xff0c;数据处理部分我们是共同承担。 第一天…

微服务学习(七):docker安装Mysql

微服务学习&#xff08;七&#xff09;&#xff1a;docker安装Mysql 1、拉取镜像 docker pull mysql2、查看安装的镜像 docker images3、安装mysql docker run -p 3306:3306 --name mysql \ -v /mydata/mysql/log:/var/log/mysql \ -v /mydata/mysql/data:/var/lib/mysql \…

基于SpringBoot+Vue的宠物领养饲养交流管理平台设计与实现

前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f447;&#x1f3fb;…

计算机专业毕业设计项目推荐08-英语在线点读平台(SpringBoot+Vue+MongoDB)

英语在线点读平台&#xff08;SpringBootVueMongoDB&#xff09; **介绍****系统总体开发情况-功能模块****各部分模块实现** 介绍 本系列(后期可能博主会统一为专栏)博文献给即将毕业的计算机专业同学们,因为博主自身本科和硕士也是科班出生,所以也比较了解计算机专业的毕业设…

SSRF漏洞(利用file协议读取本地文件)

简介 当利用SSRF漏洞时&#xff0c;攻击者可以通过构造恶意请求来读取本地文件。其中一种方法是使用file协议来读取本地文件。例如&#xff0c;file:///etc/passwd是一个常见的示例&#xff0c;它用于读取Linux系统上的passwd文件。 passwd文件是Linux系统中用于存储用户账户…

2.求循环小数

题目 对于任意的真分数 N/M &#xff08; 0 < N < M &#xff09;&#xff0c;均可以求出对应的小数。如果采用链表表示各个小数&#xff0c;对于循环节采用循环链表表示&#xff0c;则所有分数均可以表示为如下链表形式。 输入&#xff1a; N M 输出&#xff1a; 转换…

时序数据库 TimescaleDB 安装与使用

TimescaleDB 是一个时间序列数据库&#xff0c;建立在 PostgreSQL 之上。然而&#xff0c;不仅如此&#xff0c;它还是时间序列的关系数据库。使用 TimescaleDB 的开发人员将受益于专门构建的时间序列数据库以及经典的关系数据库 (PostgreSQL)&#xff0c;所有这些都具有完整的…

力扣:103. 二叉树的锯齿形层序遍历(Python3)

题目&#xff1a; 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 来源&#xff1a;力扣&#xff08;LeetCode&#…

Android StateFlow初探

Android StateFlow初探 前言&#xff1a; 最近在学习StateFlow&#xff0c;感觉很好用&#xff0c;也很神奇&#xff0c;于是记录了一下. 1.简介&#xff1a; StateFlow 是一个状态容器式可观察数据流&#xff0c;可以向其收集器发出当前状态更新和新状态更新。还可通过其 …

Go语言开发环境搭建指南:快速上手构建高效的Go开发环境

Go 官网&#xff1a;https://go.dev/dl/ Go 语言中文网&#xff1a;https://studygolang.com/dl 下载 Go 的语言包 进入官方网站 Go 官网 或 Go 语言中文网&#xff1a; 选择下载对应操作系统的安装包&#xff1a; 等待下载完成&#xff1a; 安装 Go 的语言包 双击运行上…

Linux 远程登录(Xshell7)

为什么需要远程登录Linux&#xff1f;因为通常在公司做开发的时候&#xff0c;Linux 一般作为服务器使用&#xff0c;而服务器一般放在机房&#xff0c;linux服务器是开发小组共享&#xff0c;且正式上线的项目是运行在公网&#xff0c;因此需要远程登录到Liux进行项日管理或者…

PgSQL-安全加固实践-如何设置非全零监听

PgSQL-安全加固实践-如何设置非全零监听 1、介绍 PgSQL在启动前需要配置listen_addresses配置项&#xff0c;该配置项表示允许PgSQL服务监听程序绑定的IP。我们知道一个host上可以有多个网卡&#xff0c;每个网卡可以绑定多个IP&#xff0c;该参数就是控制PgSQL服务绑定在哪个或…

容器的数据卷

容器的数据卷 操作数据卷 # 基本格式 docker volume [common] # 创建一个volume docker volume create # 显示一个或多个volume docker volume inspect # 列出所以的volume docker volume ls # 删除未使用的volume docker volume prune # 删除一个或多个volume docker volume…

C++笔记之引用折叠规则

C笔记之引用折叠规则 文章目录 C笔记之引用折叠规则1. 当两个左值引用结合在一起时&#xff0c;它们会折叠成一个左值引用。2. 当一个左值引用和一个右值引用结合在一起时&#xff0c;它们会折叠成一个左值引用。3. 当两个右值引用结合在一起时&#xff0c;它们也会折叠成一个右…

typeof的作用

typeof 是 JavaScript 中的一种运算符&#xff0c;用于获取给定值的数据类型。 它的作用是返回一个字符串&#xff0c;表示目标值的数据类型。通过使用 typeof 运算符&#xff0c;我们可以在运行时确定一个值的类型&#xff0c;从而进行相应的处理或逻辑判断。 常见的数据类型…