3.8.cuda运行时API-使用cuda核函数加速yolov5后处理

news/2024/5/6 5:13:18/文章来源:https://blog.csdn.net/qq_40672115/article/details/131627933

目录

    • 前言
    • 1. Yolov5后处理
    • 2. 后处理案例
      • 2.1 cpu_decode
      • 2.2 gpu_decode
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习精简 CUDA 教程-使用 cuda 核函数加速 yolov5 的后处理

课程大纲可看下面的思维导图

在这里插入图片描述

1. Yolov5后处理

Yolov5 是目标检测中比较经典的模型,学习对其后处理进行解码是非常有必要的。在这里我们仅使用核函数对 Yolov5 推理的结果进行解码并恢复成框,掌握后处理所解决的问题,以及对于性能的考虑。

经验之谈

  1. 对于后处理的代码研究,可以把 PyTorch 的数据通过转换成 numpy 后,tobytes 再写到文件,然后再到 c++ 中读取的方式,能够快速进行问题研究和排查,此时不需要 tensorRT 推理也可以做后处理研究。这也叫变量控制法
  2. fast_nms_kernel 会在极端情况下少框,但是这个极端情况一般不会出现,实测几乎没有影响
  3. fast nms 在 cuda 实现上比较简单,高效,不用排序

2. 后处理案例

我们来看下 Yolov5 整个后处理过程:decode解码 + nms

由于整个后处理过程可能有点复杂,因此我们可以先在 CPU 上完成,然后再考虑 GPU 上的工作。

为了方便演示整个后处理过程,我们通过 PyTorch 去进行推理,把推理的结果利用 numpy 保存下来,然后利用 c++ 读取进行后处理,同时也可以看下 PyTorch 最终的结果和我们后处理的结果是否一致。

numpy 保存推理结果的代码如下:

with open("../workspace/predict.data", "wb") as f:f.write(pred.cpu().data.numpy().tobytes())

Yolov5 在 COCO 数据集上的输入是一个 [n,85] 为维度的 tensor,其中 85 是 [cx,cy,width,objectness,classfication * 80]

关于后处理原理和更多细节请查看 YOLOv5推理详解及预处理高性能实现

2.1 cpu_decode

我们先来看 cpu_decode,CPU 解码的重点有:

  1. 避免多余的计算,需要知道有些数学运算需要的事件远超过很多 if,减少他们的次数就是提高性能的关键
  2. nms 的实现是可以优化的,例如 remove_flags 并且预先分配内存,reserve 对输出分配内存

核心代码如下:

vector<Box> cpu_decode(float* predict, int rows, int cols, float confidence_threshold = 0.25f, float nms_threshold = 0.45f){vector<Box> boxes;int num_classes = cols - 5;for(int i = 0; i < rows; ++i){float* pitem = predict + i * cols;float objness = pitem[4];if(objness < confidence_threshold)continue;float* pclass = pitem + 5;int label     = std::max_element(pclass, pclass + num_classes) - pclass;float prob    = pclass[label];float confidence = prob * objness;if(confidence < confidence_threshold)continue;float cx     = pitem[0];float cy     = pitem[1];float width  = pitem[2];float height = pitem[3];float left   = cx - width * 0.5;float top    = cy - height * 0.5;float right  = cx + width * 0.5;float bottom = cy + height * 0.5;boxes.emplace_back(left, top, right, bottom, confidence, (float)label);}std::sort(boxes.begin(), boxes.end(), [](Box& a, Box& b){return a.confidence > b.confidence;});std::vector<bool> remove_flags(boxes.size());std::vector<Box> box_result;box_result.reserve(boxes.size());auto iou = [](const Box& a, const Box& b){float cross_left   = std::max(a.left, b.left);float cross_top    = std::max(a.top, b.top);float cross_right  = std::min(a.right, b.right);float cross_bottom = std::min(a.bottom, b.bottom);float cross_area = std::max(0.0f, cross_right - cross_left) * std::max(0.0f, cross_bottom - cross_top);float union_area = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top) + std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top) - cross_area;if(cross_area == 0 || union_area == 0) return 0.0f;return cross_area / union_area;};for(int i = 0; i < boxes.size(); ++i){if(remove_flags[i]) continue;auto& ibox = boxes[i];box_result.emplace_back(ibox);for(int j = i + 1; j < boxes.size(); ++j){if(remove_flags[j]) continue;auto& jbox = boxes[j];if(ibox.label == jbox.label){// class matchedif(iou(ibox, jbox) >= nms_threshold)remove_flags[j] = true;}}}return box_result;
}

该代码主要可分为预处结果解码和非极大值抑制两部分

预测结果解码

首先遍历每个预测框,通过置信度阈值(confidence_threshold)对预测结果进行过滤。然后,计算预测框的类别,选择 80 个类别中最高概率的类别作为预测框的标签。接下来,将预测框的中心点和宽高转变成左上角和右下角坐标,并将预测框的信息保存到 boxes

非极大值抑制(NMS)

首先我们需要对 boxes 中的所有预测框按照置信度进行降序排序,方便后续 NMS 操作。NMS 的实现主要是通过 remove_flags 这个标志来实现的,将未标记为需要移除的预测框保存到 box_result

关键的性能优化点

  • 预测框过滤,在 decode 过程中先利用置信度阈值过滤,避免了不必要的后续计算和处理
  • 预测框排序,在 lambda 函数中传引用,同时对 box_result 利用 reverse 进行预分配提升性能
  • 使用标志位:在 NMS 过程中,使用 remove_flags 标志位来标记需要移除的预测框,相比于两两预测框比较提高了效率

2.2 gpu_decode

我们再来看 gpu_decode,GPU 解码的重点有:

  1. 表示输出数量不确定的数组,用 [count, box1, box2, box3] 的方式,此时需要有最大数量限制
  2. 通过 atomicAdd 实现数组元素的加入,并返回索引
  3. 和 cpu_decode 一样,不必要的计算尽量省掉

decode 核心代码如下:

static __global__ void decode_kernel(float* predict, int num_bboxes, int num_classes, float confidence_threshold, float* invert_affine_matrix, float* parray, int max_objects, int NUM_BOX_ELEMENT
){  int position = blockDim.x * blockIdx.x + threadIdx.x;if (position >= num_bboxes) return;float* pitem     = predict + (5 + num_classes) * position;float objectness = pitem[4];if(objectness < confidence_threshold)return;float* class_confidence = pitem + 5;float confidence        = *class_confidence++;int label               = 0;for(int i = 1; i < num_classes; ++i, ++class_confidence){if(*class_confidence > confidence){confidence = *class_confidence;label      = i;}}confidence *= objectness;if(confidence < confidence_threshold)return;int index = atomicAdd(parray, 1);if(index >= max_objects)return;float cx         = *pitem++;float cy         = *pitem++;float width      = *pitem++;float height     = *pitem++;float left   = cx - width * 0.5f;float top    = cy - height * 0.5f;float right  = cx + width * 0.5f;float bottom = cy + height * 0.5f;// affine_project(invert_affine_matrix, left,  top,    &left,  &top);// affine_project(invert_affine_matrix, right, bottom, &right, &bottom);// left, top, right, bottom, confidence, class, keepflagfloat* pout_item = parray + 1 + index * NUM_BOX_ELEMENT;*pout_item++ = left;*pout_item++ = top;*pout_item++ = right;*pout_item++ = bottom;*pout_item++ = confidence;*pout_item++ = label;*pout_item++ = 1; // 1 = keep, 0 = ignore
}

上述 gpu_decode 代码和 cpu 处理非常像,其中核函数启动的线程数为预测框的数量,每个线程处理一个框的解码工作,position 代表当前线程的 Idx,*predict 为所有预测框的首地址,pitem 为当前线程要处理的预测框的起始地址,如下图所示:

在这里插入图片描述

图2-1 pitem

同时为了保存 decode 后的预测框,我们使用原子加(atomicAdd)操作来避免多个线程同时写入输出数组时的冲突问题,可以确保结果的准确性。具体来说,index = atomicAdd(parray, 1) 表示将 parray 指向的内存位置的值加上 1,并将加前的值赋给 index,而 index 表示当前所处理的边界框在所有边界框中的索引值。为了避免超过最大边界框数量,会在 index 超过 MAX_IMAGE_BOXES 时直接返回,不再处理该边界框。

将预测框完成解码后就需要将其解码后的框信息保存下来,保存的首地址是 *parrayparray 的第一个元素是保存下来的框的数量,后面才是一个个框的信息,如下图所示。

在这里插入图片描述

图2-2 pout_item

当然对于 nsm 你也可以采用 cuda 加入,代码如下:

static __global__ void fast_nms_kernel(float* bboxes, int max_objects, float threshold, int NUM_BOX_ELEMENT){int position = (blockDim.x * blockIdx.x + threadIdx.x);int count = min((int)*bboxes, max_objects);if (position >= count) return;// left, top, right, bottom, confidence, class, keepflagfloat* pcurrent = bboxes + 1 + position * NUM_BOX_ELEMENT;for(int i = 0; i < count; ++i){float* pitem = bboxes + 1 + i * NUM_BOX_ELEMENT;if(i == position || pcurrent[5] != pitem[5]) continue;if(pitem[4] >= pcurrent[4]){if(pitem[4] == pcurrent[4] && i < position)continue;float iou = box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],pitem[0],    pitem[1],    pitem[2],    pitem[3]);if(iou > threshold){pcurrent[6] = 0;  // 1=keep, 0=ignorereturn;}}}
} 

fast_nms_kernel 在极端情况下会少框,比如当存在多个重叠框,并且它们具有相同的置信度时,由于核函数中的条件判断和并行计算的特性,可能会导致后面的框覆盖前面的框,从而使得前面的框被忽略。

值得注意的是在对 mAP 进行测试性能的时候,只能采用 CPU 版本的 nms,这是因为 mAP 测试需要精确计算每个框的重叠情况,并且需要按照特定的算法进行排序和抑制。而在 GPU 上进行并行计算的 nms 方法往往会牺牲一定的精确性,无法满足 mAP 测试的要求。

下图对比了 PyTorch 的效果和我们自己实现的后处理的效果,可以看到结果是没问题的

在这里插入图片描述

图2-3 PyTorch效果

在这里插入图片描述

图2-4 自定义实现后处理的效果

总结

本次课程学习了经典目标检测算法 Yolov5 的后处理,我们先在 cpu 上实现了整个 decode,cpu 版本的实现性能已经非常高了,适合在一些边缘嵌入式设备上运行,随后我们根据 cpu 版本的 decode 编写了核函数来加速整个 decode 解码过程,很多东西还是需要大家自己多去动手,多去尝试。

关于代码的更多探讨可参考 infer源码阅读之yolo.cu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_328896.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 2.0 的使用

day01 1. Vue简介 一套用于构建用户界面的 <font colorred>渐进式框架</font> 2. 初识Vue 2.1 搭建Vue开发环境 第一步&#xff1a;去<a href"https://v2.cn.vuejs.org/">Vue2官网</a>&#xff0c;下载依赖包。 第二步&#xff1a;在 …

多线程与并发编程【线程休眠、线程让步、线程联合、判断线程是否存活】(二)-全面详解(学习总结---从入门到深化)

目录 线程休眠 线程让步 线程联合 Thread类中的其他常用方法 判断线程是否存活 线程的优先级 线程休眠 sleep()方法&#xff1a;可以让正在运行的线程进入阻塞状态&#xff0c;直到休眠时间 满了&#xff0c;进入就绪状态。sleep方法的参数为休眠的毫秒数。 public class…

两部搞定Pytorch 安装与配置(小白也能搞定!!!)

Pytorch 安装与配置 NVIDIA系统管理界面查看 nvidia-smi 进入NVIDIA系统管理界面 对应的详细解释看下图 参考博文 (53条消息) nvidia-smi命令详解和一些高阶技巧介绍_Chaos_Wang_的博客-CSDN博客 CUDA 查看 CUDA 有两类&#xff1a;其中一类是驱动API(Driver API)&#xff…

实现windows系统文件传输到Linux系统中的工具

1、实现windows系统文件传输到Linux系统中的工具 yum -y install lrzsz然后就可以将windows中的文件&#xff0c;直接拖到Xshell窗口即可。

【钱处理】商业计算怎样才能保证精度不丢失

以项目驱动学习&#xff0c;以实践检验真知 前言 很多系统都有「处理金额」的需求&#xff0c;比如电商系统、财务系统、收银系统&#xff0c;等等。只要和钱扯上关系&#xff0c;就不得不打起十二万分精神来对待&#xff0c;一分一毫都不能出错&#xff0c;否则对系统和用户来…

Kafka入门,mysql5.7 Kafka-Eagle部署(二十五)

官网 https://www.kafka-eagle.org/ 下载解压 这里使用的是2.0.8 创建mysql数据库 创建名为ke数据库,新版本会自动创建&#xff0c;不会创建的话&#xff0c;自己手动创建&#xff0c;不然会报查不到相关表信息错误 SET NAMES utf8; SET FOREIGN_KEY_CHECKS 0;-- ------…

拥有铁粉,怀抱CSDN大家庭

&#x1f451; 个人主页 &#x1f451; &#xff1a;&#x1f61c;&#x1f61c;&#x1f61c;Fish_Vast&#x1f61c;&#x1f61c;&#x1f61c; &#x1f41d; 个人格言 &#x1f41d; &#xff1a;&#x1f9d0;&#x1f9d0;&#x1f9d0;说到做到&#xff0c;言出必行&am…

python_day4

def test():return 1, a, Truex, y, z test() print(f"x{x},y{y},z{z}")位置参数&#xff1a;调用时根据参数位置传递参数 关键字参数&#xff1a;调用时通过“键值”形式传参 def user(name, age, gender):print(f"name:{name},age:{age},gender:{gender}&q…

图床项目之公网发布和测试

项目发布和测试 一、http服务测试1.1、ab http压力测试1.2、post测试&#xff08;注册请求和登录请求&#xff09; 二、性能测试2.1、生成测试脚本2.2、上传测试2.2.1、单客户端测试本地上传到本机服务器2.2.2、如果使用集群的方式进行测试 2.3、下载测试2.4、删除测试2.5、测试…

springboot请求重定向失败问题解决方案

今天晚上在写登录页面时&#xff0c;发现自己的首页无法正常访问&#xff0c;用户名和密码正常的情况下还是无法访问首页。于是开始进行debug&#xff0c; 程序执行至此处时无任何异常&#xff0c;但是就是在进行重定向页面时出现了404&#xff0c;在检查导航栏后发现地址栏也发…

深度学习——批数据训练

代码与详细注释&#xff1a; BATCH_SIZE 5&#xff0c;shuffleTrue import torch import torch.utils.data as Data# 添加随机种子以使结果可复现 torch.manual_seed(1) # reproducible# 批大小 BATCH_SIZE 5 # BATCH_SIZE 8x torch.linspace(1, 10, 10) # this…

dvwa靶场通关(九)

第九关&#xff1a;Weak Session IDs&#xff08;弱会话IDs&#xff09; 当用户登录后&#xff0c;在服务器就会创建一个会话(session)&#xff0c;叫做会话控制&#xff0c;接着访问页面的时候就不用登录&#xff0c;只需要携带 Sesion去访问。 sessionID作为特定用户访问站…

用技术指标伦敦金行情走势图

经常有投资者说&#xff0c;伦敦金行情走势图老是涨跌涨跌&#xff0c;抓不准它涨跌的规律&#xff0c;老是被它弄得头昏脑胀。其实看伦敦金行情走势图的方法有很多&#xff0c;最直接的就是使用技术指标。技术指标本来就是投资者为了避免伦敦金行情走势图上价格干扰性波动&…

什么是热修复?它的优缺点是什么?

我们开发时常常要考虑的一些问题。 开发上线的版本能保证不存在Bug么&#xff1f; 修复后的版本能保证用户都及时更新么&#xff1f; 如何最大化减少线上Bug对业务的影响&#xff1f; 热修复技术帮助我们解决了很多问题&#xff0c;带来的优势不言而喻。不知道各位对于热修复技…

【AcWing算法基础课】第四章 数学知识(未完待续)

文章目录 前言课前温习番外&#xff1a;秦九韶算法核心模板 一、质数1. 试除法判定质数核心模板1.1题目描述1.2思路分析1.3代码实现 2、试除法分解质因数核心模板1.4题目描述1.5思路分析1.6代码实现 二、筛素数1.朴素筛法求素数核心模板2.线性筛法求素数&#xff08;O(n)&#…

vue拼接html中onclick的触发方式,模版字符串拼接点击事件在vue项目中不生效问题

模版字符串拼接点击事件在vue项目中不生效问题 下面的点击事件没有任何效果&#xff0c;但是如果换成onclick绑定事件则会提示没有该方法。主要原因是&#xff1a; 模版字符串中拼接的html片段中的方法调不到vue中this.methods里的东西&#xff0c;因为methods里的代码是编译…

STM32 Proteus UCOSII系统多路数据采集系统8路开关量4路电压-0058

STM32 Proteus UCOSII系统多路数据采集系统8路开关量4路电压-0058 Proteus仿真小实验&#xff1a; STM32 Proteus UCOSII系统多路数据采集系统8路开关量4路电压-0058 功能&#xff1a; 硬件组成&#xff1a;STM32F103R6单片机 LCD1602显示器8路光耦隔离开关量采集4路微小信号…

你的流量虚了吗?分析手机流量卡不足量的套路

当今时代&#xff0c;手机流量的使用是每个人每天都在消耗的事情&#xff0c;在有WIFI的情况下还好&#xff0c;大家不需要担心流量用多了还是少了&#xff0c;但是在使用手机流量的时候&#xff0c;就需要注意了&#xff0c;看看是不是会用超什么的&#xff0c;但是现在有一个…

网络编程5——TCP协议的五大效率机制:滑动窗口+流量控制+拥塞控制+延时应答+捎带应答

文章目录 前言一、TCP协议段与机制TCP协议的特点TCP报头结构TCP协议的机制与特性 二、TCP协议的 滑动窗口机制 三、TCP协议的 流量控制机制 四、TCP协议的 拥塞控制机制 五、TCP协议的 延时应答机制 六、TCP协议的 捎带应答机制 总结 前言 本人是一个普通程序猿!分享一点自己的…

QT事件处理

设计一个闹钟&#xff0c;定时播报内容。 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QTimerEvent> #include <QDateTime> #include <QMessageBox> #include <QTextToSpeech> #include <QDebug> namespa…