【DETR 论文解读】End-to-End Object Detection with Transformer

news/2024/4/25 20:56:36/文章来源:https://blog.csdn.net/qq_38253797/article/details/127429466

目录

  • 前言
  • 一、整体架构
  • 二、基于集合预测的损失函数
    • 2.1、二分图匹配确定有效预测框
    • 2.2、损失函数
  • 三、前向推理
  • 四、掉包版代码
  • 五、一些问题
  • Reference

前言

贡献/特点:

  1. 端到端:去除NMS和anchor,没有那么多的超参,计算量也大大减少,整个网络变得很简单;
  2. 基于Transformer:首次将Transformer引入到目标检测任务当中;
  3. 提出一种全新的基于集合的损失函数:通过二分图匹配的方法强制模型输出一组独一无二的预测框,每个物体只会产生一个预测框,这样就将目标检测问题直接转换为集合预测的问题,所以才不用nms,达到端到端的效果;
  4. 而且在decoder输入一组可学习的object query和encoder输出的全局上下文特征,直接以并行方式强制输出最终的100个预测框,替代了anchor;
  5. 缺点:对大物体的检测效果很好,但是对小物体的检测效果不好;训练起来比较慢;
  6. 优点:在COCO数据集上速度和精度和Faster RCNN差不多;可以扩展到很多任务中,比如分割、追踪、多模态等;

一、整体架构

在这里插入图片描述

  1. 图片输入,首先经过一个CNN网络提取图片的局部特征;
  2. 再把特征拉直,输入Transformer Encoder中,进一步学习这个特征的全局信息。经过Encoder后就可以计算出没一个点或者没一个特征和这个图片的其他特征的相关性;
  3. 再把Encoder的输出送入Decoder中,并且这里还要输入Object Query,限制解码出100个框,这一步作用就是生成100个预测框;
  4. 预测出的100个框和gt框,通过二分图匹配的方式,确定其中哪些预测框是有物体的,哪些是没有物体的(背景),再把有物体的框和gt框一起计算分类损失和回归损失;推理的时候更简单,直接对decoder中生成的100个预测框设置一个置信度阈值(0.7),大于的保留,小于的抑制;

二、基于集合预测的损失函数

2.1、二分图匹配确定有效预测框

预测得到N(100)个预测框,gt为M个框,通常N>M,那么怎么计算损失呢?

这里呢,就先对这100个预测框和gt框进行一个二分图的匹配,先确定每个gt对应的是哪个预测框,最终再计算M个预测框和M个gt框的总损失。

其实很简单,假设现在有一个矩阵,横坐标就是我们预测的100个预测框,纵坐标就是gt框,再分别计算每个预测框和其他所有gt框的cost,这样就构成了一个cost matrix,再确定把如何把所有gt框分配给对应的预测框,才能使得最终的总cost最小。

在这里插入图片描述
这里计算的方法就是很经典的匈牙利算法,通常是调用scipy包中的linear_sum_assignment函数来完成。这个函数的输入就是cost matrix,输出一组行索引和一个对应的列索引,给出最佳分配。

匈牙利算法通常用来解决二分图匹配问题,具体原理可以看这里: 二分图匈牙利算法的理解和代码 和 算法学习笔记(5):匈牙利算法

所以通过以上的步骤,就确定了最终100个预测框中哪些预测框会作为有效预测框,哪些预测框会称为背景。再将有效预测框和gt框计算最终损失(有效预测框个数等于gt框个数)。

2.2、损失函数

损失函数:分类损失+回归损失

在这里插入图片描述
分类损失:交叉熵损失,去掉log

回归损失:GIOU Loss + L1 Loss

三、前向推理

在这里插入图片描述

DETR前向传播流程:

  1. 假设输入图片:3x800x1066;
  2. 输入CNN网络(ResNet50)中,走到Conv5,此时对原图片下采样32倍,输出2048x25x34;
  3. 经过一个1x1卷积降为,输出256x25x34;
  4. 生成位置编码256x25x34,再和前面CNN输出的特征相加,输出256x25x34的特征;
  5. 再把特征拉直,变成850x256,输入transformer encoder中;
  6. 经过6个encoder模块,进行全局建模,输入同样850x256的特征;
  7. 生成一个可学习的object queries(positional embedding)100x256;
  8. 将encode输出的全局特征850x256和object queries 100x256一起输入6层decoder中,反复的做自注意力操作,最后得到一个100x256的特征;(细节:这里每个decoder都会做一次object query的自注意力操作,第一个decoder可以不做,这主要是为了移除冗余框;为了让模型训练的更快更稳定,所以在Decoder后面加了很多的auxiliary loss,不光在最后一层decoder中计算loss,在之前的decoder中也计算loss)
  9. 最后再接上两个feed forward network预测头(全连接层),一个FFN做物体类别的预测(类别个数),另一个FFN做box预测(4 xywh);
  10. 再用这100个预测框和gt框(N个)通过匈牙利算法做最优匹配,找到最终N个有效的预测框,其他的(100-N)框当作背景,舍去;
  11. 再用这N个预测框和N个GT框计算损失函数(交叉熵损失,去掉log + GIOU Loss + L1 Loss),梯度回传;

四、掉包版代码

论文原文给出的掉包版代码,mAP好像有40,虽然比源码低了2个点,但是代码很简单,只有40多行,方便我们了解整个detr的网络结构:

import torch
from torch import nn
from torchvision.models import resnet50class DETR(nn.Module):def __init__(self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers):super().__init__()# backbone = resnet50 除掉average pool和fc层  只保留conv1 - conv5_xself.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])# 1x1卷积降维 2048->256self.conv = nn.Conv2d(2048, hidden_dim, 1)# 6层encoder + 6层decoder    hidden_dim=256  nheads多头注意力机制 8头   num_encoder_layers=num_decoder_layers=6self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)# 分类头self.linear_class = nn.Linear(hidden_dim, num_classes + 1)# 回归头self.linear_bbox = nn.Linear(hidden_dim, 4)# 位置编码  encoder输入self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))# query pos编码  decoder输入self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))def forward(self, inputs):x = self.backbone(inputs)    # [1,3,800,1066] -> [1,2048,25,34]h = self.conv(x)             # [1,2048,25,34] -> [1,256,25,34]H, W = h.shape[-2:]          # H=25  W=34# pos = [850,1,256]  self.col_embed = [50,128]  self.row_embed[:H]=[50,128]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)# encoder输入  decoder输入h = self.transformer(pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1))return self.linear_class(h), self.linear_bbox(h).sigmoid()detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1066)
logits, bboxes = detr(inputs)
print(logits.shape)   # torch.Size([100, 1, 92])
print(bboxes.shape)   # torch.Size([100, 1, 4])

五、一些问题

1、为什么ViT只有Encoder,而DETR要用Encoder+Decoder?(从论文实验部分得出结论)
Encoder:Encoder自注意力主要进行全局建模,学习全局的特征,通过这一步其实已经基本可以把图片中的各个物体尽可能的分开;

Decoder:这个时候再使用Decoder自注意力,再做目标检测和分割任务,模型就可以进一步把物体的边界的极值点区域进行一个更进一步精确的划分,让边缘的识别更加精确;

2、object query有什么用?
object query是用来替换anchor的,通过引入可学习的object query,可以让模型自动的去学习图片当中哪些区域是可能有物体的,最终通过object query可以找到100个这种可能有物体的区域。再后面通过二分图匹配的方式找到100个预测框中有效的预测框,进而计算损失即可。

所以说object query就起到了替换anchor的作用,以可学习的方式找到可能有物体的区域,而不会因为使用anchor而造成大量的冗余框。

Reference

b站: DETR 论文精读【论文精读】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_404480.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二叉树的OJ练习题

1.单值二叉树 描述:如果二叉树每个节点都具有相同的值,那么该二叉树就是单值二叉树。只有给定的树是单值二叉树时,才返回 true;否则返回 false。 链接:965. 单值二叉树 - 力扣(LeetCode) 思路…

世界陶瓷卫浴100强榜单发布!

​  经过一年的严格数据审查,科学统计分析,备受全行业期待的 【世界陶瓷卫浴100强统计排行榜 】于2022年10月19日在中国佛山正式发布,除了陶瓷卫浴企业100强总榜以外,还发布了全球瓷砖企业30强、全球卫浴企业20强,全…

Python中的对象池是什么

在程序设计中,创建物体模块主要是通过生成对象来实现。当对象使用结束后,则会成为不再需要的模块进行销毁。 而在系统进行对象的生成与销毁过程中会大量的增加内存的消耗,同时对象的销毁往往会留下残留的信息,这样将会伴随内存泄露…

javaWeb SSM车辆调度系统myeclipse定制开发mysql数据库网页模式java编程SpringMVC

一、源码特点 JSP SSM车辆调度系统是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码 系统采用SSM框架,系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql5.0&a…

swagger动态开关实践

swagger动态开关实践1. 背景2. 配置文件监听2.1 基于注解2.2 基于jdk3. swagger改造3.1 bean刷新3.2 方法重写4. 总结5. 参考资料1. 背景 系统漏洞扫描,扫出了swagger的问题。这个问题其实比较基础,那就是生产环境不应该开启swagger! 但是&…

FreeRTOS 软件定时器的使用

FreeRTOS中加入了软件定时器这个功能组件,是一个可选的、不属于freeRTOS内核的功能,由定时器服务任务(其实就是一个定时器任务)来提供。 软件定时器是当设定一个定时时间,当达到设定的时间之后就会执行指定的功能函数&…

el-switch接口实现

后台返回的数据: active-textswitch 打开时的文字描述string——inactive-textswitch 关闭时的文字描述string——active-valueswitch 打开时的值boolean / string / number—trueinactive-valueswitch 关闭时的值boolean / string / number—falseactive-colorswi…

Enzo丨艾美捷Enzo Ciglitazone解决方案

艾美捷Enzo Ciglitazone是一种噻唑烷二酮类降血糖药。它在遗传性肥胖的C57 Bl/6 ob/ob小鼠中显示抗高血糖活性,并且是选择性PPARγ激动剂(EC50=3M)。抑制人间充质干细胞中HUVEC分化和血管生成,并刺激脂肪生成和减少成骨…

区块链 — Overview

文章目录区块链的概念区块链数据结构区块链的基础技术哈希运算数字签名共识算法智能合约P2P网络区块链分类公有链联盟链私有链区块链的概念 狭义上,区块链是一种按照时间顺序将数据区块以顺序相连的方式组合成的一种链式数据结构,并以密码学方式保证的不…

深度神经网络图像识别,深度神经网络图像配准

如何用Python和深度神经网络寻找相似图像 代码首先,读入TuriCreate软件包import turicreate as tc我们指定图像所在的文件夹image,让TuriCreate读取所有的图像文件,并且存储到data数据框data tc.image_analysis.load_images(./image/)我们来…

《python 可视化之 matplotlib》第一章 折线图 plot

《python 可视化之 matplotlib》第一章 折线图 本章节内容包括以下几方面内容: 绘制曲线 yx2yx^2yx2;让曲线更加光滑;常见的相关属性设置;多条折线图的绘制;折线图之间的颜色填充;时间序列可视化;常见问题…

iNFTnews|在元宇宙中探索NFT的无限可能

元宇宙正在使我们当下的生活发生显著变化。 我们都玩过很多电子游戏,看过很多相关的科幻电影,也有过很多关于元宇宙进入我们日常生活后,我们周围的事物将会受到怎样的巨大影响的讨论。 我们很快就会看到,如此先进的技术突破将逐…

人工神经网络概念及组成,人工神经网络基本概念

1、什么是BP神经网络? BP算法的基本思想是:学习过程由信号正向传播与误差的反向回传两个部分组成;正向传播时,输入样本从输入层传入,经各隐层依次逐层处理,传向输出层,若输出层输出与期望不符&…

含汞废水的深度处理方法

CH-95 是一款为了从工业废水中去除回收汞和贵金属而专门开发的螯合树脂。拥有聚乙烯 异硫脲官能基的大孔树脂,这种树脂对汞有极高的选择性。钠,碱土,铁铜等重金属等不能干扰 其对汞的选择性去除。 CH-97 是一款含有附着甲基硫醇聚苯乙烯共…

基于PB的企业人力资源信息系统设计与实现

目 录 摘 要 I Abstract II 第1章 引言 1 1.1选题背景及意义 1 1.2发展现状 1 1.3论文结构 2 第2章 系统分析 3 2.1 系统目标 3 2.2 系统需求分析 3 第3章 系统设计 5 3.1 系统功能结构设计 5 3.2 数据库设计与实现 7 3.2.1数据库需求分析 7 3.2.2数据库概念结构设计 8 3.2.3数…

[oeasy]python0010 - python虚拟机解释执行py文件的原理

解释运行程序 🥊 回忆上次内容 我们这次设置了断点 设置断点的目的是更快地调试调试的目的是去除​​bug​​别害怕​​bug​​一步步地总能找到​​bug​​这就是程序员基本功 调试​​debug​​ 我心中还是有疑问 ​​python3​​ 是怎么解释​​hello.py​​ 的…

Python实现SSA智能麻雀搜索算法优化支持向量机分类模型(SVC算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 麻雀搜索算法(Sparrow Search Algorithm, SSA)是一种新型的群智能优化算法,在2020年提出&am…

pytorch:常见的pytorch参数初始化方法总结

pytorch参数初始化1. 关于常见的初始化方法1) 均匀分布初始化torch.nn.init.uniform_()2) 正态分布初始化torch.nn.init.normal_()3) 常量初始化torch.nn.init.constant_()4) Xavier均匀分布5)Xavier正态分布初始化6) kaiming均匀分布初始化7) kaiming正…

除了pid还有什么控制算法,类似pid算法还有哪些

什么是专家PID?他和传统的PID有什么区别? PID是智能控制啊,比如要控制一个水管的水流量,通过流量计,开关阀,让PID来控制开关阀的开关大小使水流量正确.专家PID记得是PID的高级设置,某些个场合一般的PID无法使用,出现了了专用的,有特殊功能的.记忆中是这…

防火墙的ISP选路

拓补图: 实验目的: 让R1走ISP1的路径访问192.168.1.1,R2走ISP2的路径访问172.16.1.1 1. IP地址的配置略 2. 防火墙区域的划分(防火墙的g1/0/2接口是属于ISP1接口,所以需要自己新建一个区域然后添加接口,…