ICCV 2021 | Y-Net:轨迹-场景信息的真正融合

news/2024/5/9 17:17:32/文章来源:https://blog.csdn.net/m0_57541899/article/details/127520793

今天没有多余的解释,直接开始吧~

1. Y-Net网络结构

       Y-Net的网络结构长什么样子呢?Y-Net的网络结构就长下图这样子。看上去我好像在自言自语,其实你仔细揣摩就会发现,我真的是在自言自语。可以看到说,Y-Net网络输入的是一张张的图片,而不是序列。这一点很重要的,因为只有先搞清输入输出是什么,才能进行接下来的工作。那在(一)中的时候说,对于给定的RGB三通道图片I,先通过语义分割网络得到图片I的语义分割图S,于此同时将行人的过去轨迹转化为轨迹热力图H,然后将语义分割图S与轨迹热力图进行concatenate拼接,之后送到我们的编码器U_{e}当中。编码器U_{e}的最终输出H_{M}作为解码器U_{g}U_{t}的输入,U_{e}的中间特征H_{m}(1\leq m\leq M)将与俩解码器的中间层输出进行skip connection,也就是进行特征融合。在Goal & Waypoint decoder中,最后的输出层由一个卷积层后跟一个像素级的sigmoid函数组成,对goal 和 waypont生成一个概率分布,从这个概率分布中采样我们所需的goal 和 waypoint,紧接着将采样得到的goal and waypoint转换为goal & waypoint Heatmap,记为H_{g}。最后,对向量H_{g}进行下采样以匹配U_{t}中每个block的空间尺寸(从图中来看是要下采样6次,但是最后一个下采样箭头是不是画错位置了呢?另外,匹配的意思指的是”拼接“,而不是”输入“)

 2. encoder U_{e}

       我们知道,编码器要干的事情就是提取图片的深层特征。Y-Net的encoder架构很显然是仿照U-Net encoder所设计的,encoder的输入是Segmentation Map与Trajectory Heatmap所拼接的张量tensor,那这个tensor的维度是多少呢?那我们先给定这个tensor的维度是4 \times 14 \times 416 \times 512 (N \times C \times H \times W )。每一层的输出如下图所示:

3. decoder U_{g}

        解码器实际上就是要恢复原图的大小,并融合深层的特征。那解码器U_{g}是怎么实现该功能的呢?从图中我们可以看到,H_{M}首先会被送到U_{g}的center block当中去,然后再将center block的输出送到接下来的模块。接下来的模块是反卷积——skip connection——卷积的重复操作。

       我们来看前向传播过程:在encoder的前向传播代码中,会将六层的future特征预存在feature[ ]中,这是为了在decoder中方便进行特征融合。读过U-Net的朋友应该知道,特征融合应该有相同的维度(当然channel数可以不同)。所以在U_{g}的前向传播过程中,首先将encoder保存的特征进行逆排序,然后将feature进行切片——也就是代码中的feature[0]——赋给center_feature。经过两层卷积后以center_feature的输出维度作为后续block的输入x。在for循环当中,enumerate()是枚举函数,zip()函数将对应的元素打包成一个个元组,F.interpolate()对输入的x进行插值,为接下来上采样做准备;torch.cat()将encoder中间层的特征与decoder进行concatenate拼接,拼接好后输入到module()模块,该模块在decoder该类中有所定义,实际上就是decoder架构中反卷积后的两层卷积层。在for循环完之后,有一个self.predictor(),同样得,它在decoder该类中有定义,实际上就是U_{g}的输出层,经过它我们将得到goal $ waypoint heatmap logits.

每一层的输出如下:

4.  decoder U_{t} 

       上面说过, 将采样得到的goal and waypoint转换为goal & waypoint Heatmap,记为H_{g}。最后,对向量H_{g}进行下采样以匹配U_{t}中每个block的空间尺寸。采样的过程该怎么用代码实现呢?

gt_waypoints_maps_downsampled = [nn.AvgPool2d(kernel_size=2**i, stride=2**i)(gt_waypoint_map) for i in range(1, len(features))]gt_waypoints_maps_downsampled = [gt_waypoint_map] + gt_waypoints_maps_downsampled

       采样的结果将保存在列表里,将随H_{M}H_{m}一同输入到U_{t}当中:

traj_input = [torch.cat([feature, goal], dim=1) for feature, goal in zip(features, gt_waypoints_maps_downsampled)]
pred_traj_map = model.pred_traj(traj_input)

       这一行代码一定要注意,得到的traj_input的输出为:(4,33,416,512),(4,33,208,256),(4,65,104,128),(4,65,52,64),(4,65,26,32),(4,65,13,16),从代码中我们可以看到,zip()函数将来自U_{e}的特征即该行代码里的feature,与对H_{g}下采样的结果打包成了元组,然后再进行concat拼接得到traj_input的输出。所以,待会U_{t}在前向传播的时候,已经有了来自下采样的拼接,所以我希望读者看到这的时候不会疑惑。

 U_{t}的网络结构与U_{g}并没有太大区别,不同的是U_{t}的输入除了H_{M}之外,还有下采样H_{g}的结果。前向传播的过程与U_{g}是一样的,这里就不在赘述。每一层的输入输出如下:

5. 训练与测试 

       那回顾之前对Y-Net的描述,读者会发现说,其实我并没有涉及对算法很具体的描述,只弄一弄输入输出是提高不了博客质量的。这篇文章是很早之前写的,那现在来将算法描述清楚。现在是2022年10月25日,是的,没有看错,过去了很久的时间。

       Y-Net网络架构是很特别的,为什么这么说呢?在Y-Net之前,行人轨迹预测的论文对行人与场景、行人与行人之间的交互都是简单的将其放在一个大矩阵里,然后喂给相应的模型进行特征融合,再经过某种处理得到最后的预测轨迹,这种模式相当于:轨迹——Pooling——轨迹。但是Y-Net通过将轨迹映射到特征图的方式做到了真正的融合,这种模式相当于:轨迹——Mapping——轨迹,这是一种很大胆的创新。

       在开始之前,论文中几个符号的含义请务必弄清:

K_{e}:goal的数目

K_{a}:到达某个goal的路径数目

N^{w}:waypoint的映射图

④waypoint:到达某个goal中途的点

       Y-Net训练与测试存在一定的不同。举例来说,在训练阶段,轨迹heatmap与场景图叠加后作为模型的输入(4,8,416,512),经过U_{e}U_{g}处理得到(4,12,46,512)大小维度的输出。这一步实际上可以将U-Net视为预测部件,输入8帧(张)图片,得到后续12帧(张)的图片。U_{t}模块的输入来自三个方面,一是U_{e}最后一层的输出、二是来自U_{g}下采样的拼接、三是与U_{e}中间层的跳跃连接。一、三方面的输入是固定的,变化的是第二方面的输入。举例来说,假定超参数K_{e}=20K_{a}=1N^{w}=[10,11],需要解释的是上述超参数分别代表选取goal的数目、到达某个goal存在的路径数目以及选择GT的第10、11张heatmap图作为waypoint的map(注意:①训练阶段不引入goal的信息,或者换句话说直接将GT的第12张heatmap图作为goal的信息;②这里所说的GT的第10、11张图是指,人为的将GT heatmap分为前8张(0~8)与后12张(0~12))。训练阶段提供真实的第19帧轨迹坐标映射的heatmap与下分支U_{t}的各个block进行拼接,经U_{t}输出(4,12,416,512)大小的特征图,再对该输出特征图进行采样(softargmax)得到概率值经变换后得到预所有行人12帧下的预测坐标(4,12,2)。请注意,这里的上分支的监督12个点,与GT算loss,得到goal_loss,下分支也监督12个点,也与GT算loss得到traj_loss。这里有意思的是:上下分支监督12个相同的点,那这两个loss有什么区别?并且也没有实现论文中所提及的L_{waypoint} loss.

Training

       测试阶段,超参数K_{e}=20K_{a}=1N^{w}=[10,11]不变,请注意这里N^{w}=[10,11]含义不同了:此时N^{w}=[10,11]的含义是选择U_{g}输出(4,12,46,512)的第10、11张特征图作为采样heatmap和goal的map。读者可能很自然的想到,在第10张图片上采样waypoint,在第11张图片上采样goal。事实上,我们只需要在第11图片上采样就可以了。那为什么还需要第10张图片呢?这里就涉及具体编程时的逻辑语句。举例来说,如果只设置N^{w}=[10],那len(N^{w}=[10])=1,相当于goal就是waypoint。没有waypoint什么事的话,那多尴尬啊。所以设置len(N^{w}=[10,11])=2,就意味着waypoint与goal不可能重合。由于我们只在第11图片上采样,那为方便叙述,我们给它取名h_{11},它的大小为(4,1,288,512)。好,那怎么得到所需要的goal呢?一个不太直觉的想法是:在你得到的h_{11}上进行采样。那采样(sampling)这件事怎么又该怎么操作呢?举例来说,我们首先会将h_{11}进行平铺的操作,得到维度大小为(4,157696)的概率矩阵,行数4代表着4位行人,列数实际上是每一张概率图的展平。现在假定我们要取样20个点,也就是n_sample = goal = 20,我们利用torch.multinomial函数对该矩阵的每一行随机采样20次,返回每次采样概率所在行的位置,那这样我们就得到了一个大小为(4,20)的样本矩阵,4代表4个行人,20代表每一个行人采样20个样本点。到这里其实还不是我们想要的目的,你想想,我们要采样样本点的目的实际上是想得到每个goal的坐标,所以我们还需要将(4,20)的样本矩阵处理成最终的坐标形式,得到(20,4,1,2)(batch, 1, n_samples, dim)矩阵才是每一个goal在图中的像素坐标。

       现在goal有了,waypoint又该怎么得到呢?这里分情况讨论:第一种情况是你的K_{a}=1(注意K_{a}的含义),此时goal就是waypoint,waypoint就是goal。第二种情况是你的K_{a}\geqslant 2,此时你还需要在h_{11}对waypoint进行采样,那采样的次数由什么决定呢?采样的次数由你预先设想的到每个goal有几条路径所决定,用直观的符号表示就是n_{samples}=K_{e}\times K_{a},这里的K_{a}就是到某个goal的路径数。有了采样的次数,你需要的就是在h_{11}重复上述采样过程就可以得到waypoint的像素坐标。

       Anyway,到目前为止,我们有了goal的像素坐标,有了waypoint的像素坐标,紧接着就该将每个采样点的坐标映射到图中了吧。将每个采样点的坐标映射到图中这一操作具有普遍性,也就是说不管你是情况一还是情况二都是这么进行映射的,那具体怎么操作呢?举例来说在事先你会有一个距离模板矩阵,这个矩阵的大小是1050\times 1050,该矩阵中的每一个数是网格索引到中心点的距离的归一化。这个模板矩阵的意义是:减少每个点映射到图中的计算量,举例来说要产生一个点到图的映射,待会你只需在做好的模板矩阵上截取patch就能够实现,而不需要繁杂的计算。这一过程用公式来表述为:

      以第一种情况K_{a}=1为例,超参数设置 K_{e}=20K_{a}=1N^{w}=[10,11]不变,采样得到waypoint(20,4.1,2)和goal(20,4.1,2)之后,我们将其Concatenate拼接起来(20,4.2,2),每一个样本映射到一张heatmap图,所以你的样本总数(20)有多少,就会有多少张heatmap图,并且每次循环只用一张heatmap,总共循环20次,每次循环最终的输出为(4,12,2)大小的矩阵,那堆叠20次循环的输出,最终得到的输出维度为(20,4.12,2)(n_samples, batch, seq_len, dim)

Testing

 那到这里,Y-Net也就结束了,这是一篇很有启发性的文章,代码非常优美,读者有时间一定要细看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_218387.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TPH-YOLOv5: 基于Transformer预测头的改进YOLOv5用于无人机捕获场景目标检测

代码链接:GitHub - cv516Buaa/tph-yolov5 这是一篇针对无人机小目标算法比赛后写的论文,无人机捕获场景下的目标检测是近年来的热门课题。由于无人机总是在不同的高度上飞行,目标尺度变化剧烈,给网络优化带来了负担。此外&#xf…

buu [NPUCTF2020]认清形势,建立信心

题目: from Crypto.Util.number import * from gmpy2 import * from secret import flagp getPrime(25) e # Hidden q getPrime(25) n p * q m bytes_to_long(flag.strip(b"npuctf{").strip(b"}"))c pow(m, e, n) print(c) print(pow(2,…

hadoop至MapReduce-004

MapReduce定义 MapReduce是一个分布式运算程序的编程框架,核心功能是将用户编写的业务逻辑代码和自带默认组件组合成一个完整的分布式运算程序,并发运行在hadoop集群上 MapReduce的优缺点 优点 易于编程:用户只关心业务逻辑代码扩展性&am…

webpack 异步import生成代码解析

文章目录原文件内容文件目录打包前打包后入口文件生成代码生成的一些辅助方法__webpack_require__.m__webpack_require__.d__webpack_require__.o__webpack_require__.u__webpack_require__.g__webpack_require__.r导入文件通用方法__webpack_require__异步文件引入获取下载文件…

AntDB-M设计之CheckPoint

1.引 言 数据库服务能力提升是一项系统性的工程,在不同的应用场景下,用户对于数据库各项能力的关注点也不同,如:读写延迟、吞吐量、扩展性、可靠性、可用性等等。国内不少数据库系统通过系统架构优化、硬件设备升级等方式&…

教程:使用Jmeter对带token的接口进行压测

最近在研究并发,用到了Jmeter对接口进行压力测试,记录下使用过程 一. 配置/bin下的Jmeter.properties,打开以下两项配置,一个是默认的编码,一个是默认的语言 二. 打开jmeter.bat运行,新建线程组&#xff0…

qt学习笔记6:ui实例 登录窗口布局

首先从ui布局界面去进行大致布局, 可以先把默认的一些移除掉,变成一个大的空窗口 用户窗口,一般都得有一个用户名和密码(用label)输入用Line edit, 再来俩按钮pushButton, 但仅仅这样是没有意义…

kafka学习(四):生产者发送消息的分区策略

Kafka为了增加系统的伸缩性(Scalability),引入了分区(Partitioning)的概念。 Kafka 中的分区机制指的是将每个主题划分成多个分区(Partition),每个分区是一组有序的消息日志。主题下的每条消息只会保存在某一个分区中,…

python 基于PHP在线音乐网站

随着时代的发展,人们的生活水平越来越高,相对应的对精神世界的追求也越来越多,而音乐一直以来一直是人们追求美好生活的象征,它不仅可以陶冶人们的情操还可以美化人们的灵魂,音乐也一直是千百年来人们不断追求的一个精神文明的产物,为了能够让更多的人找到自己喜欢的音乐,我开发…

1.3.1操作系统的运行机制和体系结构

文章目录运行机制两种指令两种状态两种程序操作系统内核内核在计算机的系统中的层次结构内核的功能时钟管理(基本功能)中断机制(基本功能)原语(基本功能)对资源的进行管理的功能运行机制 两种指令 指令和…

python基于PHP旅游网站的设计与开发

在经济高速发展的现在,人们的工作越来越繁重,生活节奏越来越快,生活工作压力也越来越大。反而留给自己休息,享受旅游生活的时间越来越少,缺少对周边旅游信息的了解,无法与兴趣一致的户外旅友进行交流。这则会导致人们会花更多的时间去寻找旅游地点,并进行路线规划,花费的时间在…

彻底理解闭包实现原理

前言 闭包对于一个长期写 Java 的开发者来说估计鲜有耳闻,我在写 Python 和 Go 之前也是没怎么了解,光这名字感觉就有点"神秘莫测",这篇文章的主要目的就是从编译器的角度来分析闭包,彻底搞懂闭包的实现原理。 函数一等公民 一门语言在实现闭包之前首先要具有的特…

工程项目部质量管理体系的控制要点分析

质量管理是施工企业风险控制的重要组成部分。本文从有序的生产过程控制,提高企业质量意识出发,结合贯彻ISO9001标准及50430规范的企业贯标工作,分阶段研究和分析施工企业工程项目部质量管理体系的控制要点。 质量是企业的生命线,…

Android实战——单元测试从吹水到实践

目录1.单元测试到底需要不需要了?开发时间紧张,不需要做单元测试了吧?开发经验丰富,不需要做单元测试了吧?或许存在一种”自动化“的测试,就不需要做单元测试了吧?2.单元测试的好处单元测试可以…

【附源码】计算机毕业设计SSM校园拍卖平台

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

React 状态管理器,我是这样选的

前言 我们的前端团队在一直深度使用 React ,从最早的 CRA ,到后来切换到 umijs ,从 1.x、2.x、3.x 再到现在的 4.x,其中有一点不变的,就是我们一直在使用基于 react-redux 思想的 dva 作为状态管理工具。 在状态共享这…

(附源码)计算机毕业设计SSM跨移动平台的新闻阅读应用

(附源码)计算机毕业设计SSM跨移动平台的新闻阅读应用 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目…

DM-DM DBLINK使用配置

简单介绍 DM-DM DBLINK支持3种连接方式创建,分别是:dmmal、dpi、odbc。 其中dpi、odbc属于第三方接口,dmmal属于原生接口。dpi类型dblink为新版本新添加支持,以前版本中不支持。 环境说明 (1)数据库版本…

2023届C/C++软件开发工程师校招面试常问知识点复盘Part 7

目录46、C类的成员变量初始化顺序及拓展47、强制转换类型操作符号48、const 成员函数–常成员函数与常量对象49、volatile关键字50、赫夫曼树51、前缀树46、C类的成员变量初始化顺序及拓展 注意: 1、const成员或者引用必须在成员变量初始化列表中初始化,…

git的基础指令操作

git的下载地址:https://git-scm.com/download 安装好git后 在桌面上右键即可以看到两个git的快捷方式。 需要先对git进行基本的配置,即需要配置用户名和用户邮箱 1. 打开Git Bash 2. 设置用户信息 git confifig --global user.name “zqy” git confi…