从零构建深度学习推理框架-1 简介和Tensor

news/2024/5/20 20:10:36/文章来源:https://blog.csdn.net/zhuangtu1999/article/details/131921483

源代码作者:https://github.com/zjhellofss

本文仅作为个人学习心得领悟 ,将原作品提炼,更加适合新手

什么是推理框架?

深度学习推理框架用于对已训练完成的神经网络进行预测,也就是说,能够将深度训练框架例如Pytorch、Tensorflow中定义的算法移植到中心侧和端侧,并高效执行。与训练框架不同的是,深度学习推理框架没有梯度反向传播功能,因为算法模型文件中的权重系数已经被固化,推理框架只需要读取、加载并完成对新数据的预测即可。

模型加载阶段

训练完成的模型被放置在两个文件中,一个是模型定义文件,一个是权重文件

ONNX文件是将模型定义文件和权重文件合二为一的文件格式。

关于维度的预备知识

在Tensor张量中,共有三维数据进行顺序存放,分别是Channels(维度),Rows(行高), Cols(行宽),

三维矩阵我们可以看作多个连续的二维矩阵组成,最简单的方法就是std::vector<std::vector<std::vector<float>>>,但是这种方法非常不利于数据的访问(尤其是内存不连续的问题 、修改以及查询,特别是在扩容的时候非常不方便。不能满足使用需求

不连续会造成数组访问慢的问题,在这里我用chrono做了测试:

#include<iostream>
#include <gtest/gtest.h>
#include <armadillo>
#include <glog/logging.h>
#include <vector>
#include <chrono>
#define TICK(x) auto bench_##x = std::chrono::steady_clock::now();
#define TOCK(x) std::cout << #x ": " << std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() - bench_##x).count() *1000000<< "ns" << std::endl;
using namespace std;
int m = 10000 , n = 10000 , channel = 2;
TEST(test_compare_vector , speed2D){LOG(INFO)<<"Test of vector & cube"<<endl;vector<vector<float>> matA (m , vector<float>(n , 1));TICK(2D);for (int  i = 0; i <matA.size(); i++){for (int j = 0; j < matA[0].size(); j++){matA[i][j] = matA[i][j]*matA[j][i];}}TOCK(2D);arma::fcube matB( m , n , 1 , arma::fill::ones);TICK(cube);for (int  i = 0; i <m; i++){for (int j = 0; j < n; j++){matB(i , j , 0) = matB(i , j , 0)*matB( j , i , 0);}}TOCK(cube);}

因此,综合考虑灵活性和开发的难易度,作者在这里以Armadillo类中的arma::mat(矩阵 matrix)类和arma::cube 作为数据管理(三维矩阵)类来实现Tensor 库中类的主体,一个cube由多个matrix组成,cube又是Tensor类中的数据实际管理者。一块连续的大内存分配开始写一个tensor,工作量会特别大,折中!

 作者设计的类是以arma::cube为基础实现了Tensor我们主要是提供了更方便的访问方式和对外接口 

 上图即为Tensor与cube的对应关系。

 cube一般有多个维度,在channel维度上有多个matrix。

arma::cube(2,5,3),表示当前的三维矩阵共有2个矩阵构成,每个矩阵都是5行3列的。如果放在我们项目中会以这形式提供 Tensor tensor(2,5,3)

下图是这种情况下的三维结构图,可以看出一个Cube一共有两个Matrix,也就是共有两个Channel。一个Channel放一个Matrix. Matrix的行宽均为Rows和Cols.

 Tensor类方法总览

这里的很多都不需要我们重新去造轮子

比如Fill(float value)就可以直接调用cube里的fill:

void Tensor<float>::Fill(float value) {CHECK(!this->data_.empty());this->data_.fill(value);
}

再比如这个at:

float &Tensor<float>::at(uint32_t channel, uint32_t row, uint32_t col) {CHECK_LT(row, this->rows());CHECK_LT(col, this->cols());CHECK_LT(channel, this->channels());return this->data_.at(row, col, channel);
}

再难一些的就需要我们自己去实现了:

Fill(vector)方法实现:

TEST(test_tensor, fill) {using namespace kuiper_infer;Tensor<float> tensor(3, 3, 3);ASSERT_EQ(tensor.channels(), 3);ASSERT_EQ(tensor.rows(), 3);ASSERT_EQ(tensor.cols(), 3);std::vector<float> values;for (int i = 0; i < 27; ++i) {values.push_back((float) i);}tensor.Fill(values);LOG(INFO) << tensor.data();int index = 0;for (int c = 0; c < tensor.channels(); ++c) {for (int r = 0; r < tensor.rows(); ++r) {for (int c_ = 0; c_ < tensor.cols(); ++c_) {ASSERT_EQ(values.at(index), tensor.at(c, r, c_));index += 1;}}}LOG(INFO) << "Test1 passed!";
}

padding功能实现:

 


TEST(test_tensor, padding1) {using namespace kuiper_infer;Tensor<float> tensor(3, 3, 3);ASSERT_EQ(tensor.channels(), 3);ASSERT_EQ(tensor.rows(), 3);ASSERT_EQ(tensor.cols(), 3);tensor.Fill(1.f); // 填充为1tensor.Padding({1, 1, 1, 1}, 0); // 边缘填充为0ASSERT_EQ(tensor.rows(), 5);ASSERT_EQ(tensor.cols(), 5);int index = 0;// 检查一下边缘被填充的行、列是否都是0for (int c = 0; c < tensor.channels(); ++c) {for (int r = 0; r < tensor.rows(); ++r) {for (int c_ = 0; c_ < tensor.cols(); ++c_) {if (c_ == 0 || r == 0) {ASSERT_EQ(tensor.at(c, r, c_), 0);}index += 1;}}}LOG(INFO) << "Test2 passed!";
}

再谈谈Tensor类中数据的排布

我们以具体的图片作为例子,来讲讲Tensor中数据管理类arma::cube的数据排布方式Tensor类是arma::cube对外更方便的接口,所以说armadillo::cube怎么管理内存的,Tensor类就是怎么管理内存的。希望大家的能理解到位。如下图中的一个Cube,Cube的维度是2,每个维度上存放的是一个Matrix,一个Matrix中的存储空间被用来存放一张图像(lena) 。一个框内(channel)是一个Matrix,Matrix1存放在Cube第1维度(channel 1)上,Matrix2存放在Cube的第2维度上(channel 2). Matrix1和Matrix2的Rows和Cols均代表着图像的高和宽,在本例中就是512和384。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_146831.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UE虚幻引擎教程_生成云平台指定路径下的exe文件

市面上大量优秀的游戏都是基于UE制作的&#xff0c;UE虚幻引擎制作的作品可以在windows、mac、linux以及ps4、x-boxone、ios、android甚至是html5等平台上运行。本文介绍了UE虚幻引擎如何生成云平台指定路径下的EXE。 一、云平台会运行打包文件夹下指定路径的EXE文件 但有时候…

【多选框、表格全选】element el-checkbox、el-table

话不多说 先看效果&#xff1a; 多选框&#xff1a; 表格全选&#xff1a; <template><div><div class"titleLabel"><div class"lineStyle"></div>统计部门</div><div style"display: flex"><e…

项目开启启动命令整合

启动RabbitMQ管理插件 1.启动 RabbitMQ 管理插件。 rabbitmq-plugins enable rabbitmq_management rabbitmq-server # 直接启动&#xff0c;如果关闭窗⼝或需要在该窗⼝使⽤其他命令时应⽤就会停⽌ rabbitmq-server -detached # 后台启动 rabbitmq-server start # 启⽤服务 rab…

(二)安装部署InfluxDB

以下内容来自 尚硅谷&#xff0c;写这一系列的文章&#xff0c;主要是为了方便后续自己的查看&#xff0c;不用带着个PDF找来找去的&#xff0c;太麻烦&#xff01; 第 2 章 安装部署InfluxDB 1、linux 安装方式如下 通过包管理工具安装&#xff0c;比如apt 和yum直接下载可执…

springboot()—— 集成redis

1、新建一个springboot项目 2、添加redis依赖包 可以在新建项目的时候就选上 也可以建完项目以后手动导入pom.xml <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </d…

2023年基准Kubernetes报告:6个K8s可靠性失误

云计算日益成为组织构建应用程序和服务的首选目的地。尽管一年来经济不确定性的头条新闻主要集中在通货膨胀增长和银行动荡方面&#xff0c;但大多数组织预计今年的云使用和支出将与计划的相同&#xff08;45%&#xff09;&#xff0c;或高于计划的&#xff08;45%&#xff09;…

装饰模式-扩展系统功能

买了新车后&#xff0c;不少人会对车进行装饰&#xff0c;比如给车贴膜&#xff0c;喷上骚粉的漆等。某天&#xff0c;小李和小张都买了辆车&#xff0c;小李想给车贴膜&#xff0c;小张想给车先喷漆然后再贴膜。现在中的做法是&#xff0c;把车开到改装店&#xff0c;如果要喷…

浏览器调试Android App

浏览器调试Android App 1. 背景2. 调试工具3. 手机设置4. 打开浏览器(edge)5. 连接手机6. 点击inspect 开始调试 1. 背景 在工作中经常会遇到在原生app中嵌套h5&#xff0c; 但是在某些需要在app里面调试的内容&#xff0c; 却没有像chrome开发者工具这样的工具来帮助我们快速…

react 在build读取env 数据

默认会读取.env 文件 npm install dotenv --save npm install dotenv-cli --save-dev例如读取.env.test "build:test": "dotenv -e .env.test react-app-rewired build",.env.test REACT_APP_CURRENTMODE devREACT_APP_Public_Path "https://baid…

[NLP]使用Alpaca-Lora基于llama模型进行微调教程

Stanford Alpaca 是在 LLaMA 整个模型上微调&#xff0c;即对预训练模型中的所有参数都进行微调&#xff08;full fine-tuning&#xff09;。但该方法对于硬件成本要求仍然偏高且训练低效。 [NLP]理解大型语言模型高效微调(PEFT) 因此&#xff0c; Alpaca-Lora 则是利用 Lora…

算法竞赛入门【码蹄集新手村600题】(MT1040-1060)

算法竞赛入门【码蹄集新手村600题】(MT1040-1060&#xff09; 目录MT1041 求圆面积和周长MT1042 求矩形的面积和周长MT1043 椭圆计算MT1044 三角形面积MT1045 平行四边形MT1046 菱形MT1047 梯形MT1048 扇形面积MT1049 三角形坐标MT1050 空间三角形MT1051 四边形坐标MT1052 直角…

Java通过URL对象实现简单爬虫功能

目录 一、URL类 1. URL类基本概念 2. 构造器 3. 常用方法 二、爬虫实例 1. 爬取网络图片&#xff08;简易&#xff09; 2. 爬取网页源代码 3. 爬取网站所有图片 一、URL类 1. URL类基本概念 URL&#xff1a;Uniform Resource Locator 统一资源定位符 表示统一资源定位…

day39-Password Strength Background(密码强度背景)

50 天学习 50 个项目 - HTMLCSS and JavaScript day39-Password Strength Background&#xff08;密码强度背景&#xff09; 效果 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name&quo…

Linux 学习记录57(ARM篇)

Linux 学习记录57(ARM篇) 本文目录 Linux 学习记录57(ARM篇)一、外部中断1. 概念2. 流程图框 二、相关寄存器1. GIC CPU Interface (GICC)2. GIC distributor (GICD)3. EXTI registers 三、EXTI 寄存器1. 概述2. 内部框图3. 寄存器功能描述4. EXTI选择框图5. EXTI_EXTICR1 &…

go-zero学习 第六章 分布式事务dtm

go-zero学习 第六章 分布式事务dtm 1 参考文档2 官方示例3 go-zero使用dtm参考代码3.1 go-zero支持dtm 代码操作步骤※3.2 gozerodtm 代码操作步骤 4 注意事项4.1 grpc接口地址※4.2 动态调用过程4.3 dtm的回滚补偿4.4 barrier的空补偿、悬挂等4.5 barrier在rpc中本地事务 1 参…

Volatile关键字详解

Volatile关键字详解 volatile的定义 这个引用JSR中的定义&#xff1a; The Java programming language allows threads to access shared variables (17.1). As a rule, to ensure that shared variables are consistently and reliably updated, a thread should ensure tha…

数据结构【栈和队列】

第三章 栈与队列 一、栈 1.定义&#xff1a;只允许一端进行插入和删除的线性表&#xff0c;结构与手枪的弹夹差不多&#xff0c;可以作为实现递归函数&#xff08;调用和返回都是后进先出&#xff09;调用的一种数据结构&#xff1b; 栈顶&#xff1a;允许插入删除的那端&…

25.3 matlab里面的10中优化方法介绍——Nelder-Mead法(matlab程序)

1.简述 fminsearch函数用来求解多维无约束的线性优化问题 用derivative-free的方法找到多变量无约束函数的最小值 语法 x fminsearch(fun,x0) x fminsearch(fun,x0,options) [x,fval] fminsearch(...) [x,fval,exitflag] fminsearch(...) [x,fval,exitflag,output] fmins…

SpringCloudAlibaba微服务实战系列(二)Nacos配置中心

SpringCloudAlibaba Nacos配置中心 在java代码中或者在配置文件中写配置&#xff0c;是最不雅的&#xff0c;意味着每次修改配置都需要重新打包或者替换class文件。若放在远程的配置文件中&#xff0c;每次修改了配置后只需要重启一次服务即可。话不多说&#xff0c;直接干货拉…

网络通信原理(第十八课)

网络通信原理(第十八课) 4.1 回顾 1.什么是TCP/IP 目前应用广泛的网络通信协议集 国际互联网上电脑相互通信的规则、约定。 2.主机通信的三要素 IP地址:用来标识一个节点的网络地址(区分网络中电脑身份的地址,如人有名字) 子网掩码:配合IP地址确定网络号 IP路由:网…