使用Pytorch导出自定义ONNX算子

news/2024/4/21 12:41:30/文章来源:https://blog.csdn.net/qq_37541097/article/details/136464555

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomAffineGrid(Function):@staticmethoddef forward(ctx, theta: torch.Tensor, size: torch.Tensor):grid = F.affine_grid(theta=theta, size=size.cpu().tolist())return grid@staticmethoddef symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):return g.op("AffineGrid", theta, size)class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):grid = CustomAffineGrid.apply(theta, size)x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")return xdef main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)theta = torch.randn(1, 2, 3)size = torch.as_tensor([1, 3, 512, 512])torch.onnx.export(model=custum_model,args=(x, theta, size),f="custom.onnx",input_names=["input0_x", "input1_theta", "input2_size"],output_names=["output"],dynamic_axes={"input0_x": {2: "h0", 3: "w0"},"output": {2: "h1", 3: "w1"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomRot90AndScale(Function):@staticmethoddef forward(ctx, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90x *= 1.2return x@staticmethoddef symbolic(g: torch.Graph, x: torch.Tensor):return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor):return CustomRot90AndScale.apply(x)def main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)torch.onnx.export(model=custum_model,args=(x,),f="custom_rot90.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {2: "h0", 3: "w0"},"output": {2: "w0", 3: "h0"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_998041.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode:LCR 006. 两数之和 II - 输入有序数组(python3解法)

难度&#xff1a;简单 给定一个已按照 升序排列 的整数数组 numbers &#xff0c;请你从数组中找出两个数满足相加之和等于目标数 target 。 函数应该以长度为 2 的整数数组的形式返回这两个数的下标值。numbers 的下标 从 0 开始计数 &#xff0c;所以答案数组应当满足 0 <…

【教程】 iOS构建版本无效问题解决方案

引言 在进行iOS应用上架时&#xff0c;有时会遇到构建版本无效的问题&#xff0c;即通过XCode上传成功后&#xff0c;但在App Store Connect的TestFlight中无法显示构建版本&#xff0c;或者显示一会儿后就消失了。本文将介绍可能的原因分析&#xff0c;并提供解决问题的方法。…

ubuntu系统(11):ubuntu20.04命令行安装vscode

目录 1、更新软件包索引&#xff0c;并且安装依赖软件 2、使用 wget 命令插入 Microsoft GPG key 3、启用vscode存储库 4、更新软件包并安装vscode 5、当前目录进入vscode 6、设置样式&#xff0c;添加所需扩展 最近换了个新的服务器&#xff0c;所以要重新配置服务器的…

[DevOps云实践] 跨AWS账户及Region调用Lambda

[DevOps云实践] 跨AWS账户及Region调用Lambda 本文將幫大家理清一下幾個問題: 如何跨不同AWS賬戶,不同Region來調用Lambda? 不同Lambda之間如何互相調用?有時我們希望我們的Lambda脚本能夠運行在多個AWS賬戶中的不同Region下,但是,我們還不希望每個下面都去建立一個運行…

算法刷题day20:二分

目录 引言概念一、借教室二、分巧克力三、管道四、技能升级五、冶炼金属六、数的范围七、最佳牛围栏八、套餐设计九、牛的学术圈I十、我在哪&#xff1f; 引言 这几天一直在做二分的题&#xff0c;都是上了难度的题目&#xff0c;本来以为自己的二分水平已经非常熟悉了&#x…

Oracle定时任务和存储过程

--1.声明定时任务 DECLAREjob NUMBER; BIGIN dbms_job.sumit(job, --任务ID,系统定义的test_prcedure(19)&#xff0c;--调用存储过程&#xff1f;to_date(20240305 02:00&#xff0c;yyyymmdd hh24:mi) --任务开始时间sysdate1/(24*60) --任务执行周期 [每分钟执行…

springboot 加入 日志+ controller 加入全局异常捕获

提下比较好点 包含将捕获的异常堆栈完整的返回给前端。方便 后端人员用 swagger 或 knife 工具验证接口时&#xff0c;直接看到异常。 有啥用呢&#xff1f;在现场环境&#xff0c;或不方便远程服务器机器时&#xff0c;非常有用&#xff01;&#xff01;&#xff01; 同时&…

统信UOS及麒麟KYLINOS操作系统上如何切换键盘布局

原文链接&#xff1a;如何切换键盘布局 | 统信UOS | 麒麟KYLINOS Hello&#xff0c;大家好啊&#xff0c;最近有朋友在群里提到他的键盘输入“Y”会显示“Z”&#xff0c;输入“Z”会显示“Y”。这个问题听起来可能有些奇怪&#xff0c;但其实并不罕见。出现这种情况的原因&…

性别和年龄的视频实时监测项目

注意&#xff1a;本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 &#xff08;[www.aideeplearning.cn]&#xff09; 性别和年龄检测 Python 项目 首先介绍性别和年龄检测的高级Python项目中使用的专业术语 什么是计算机视觉&#xff1f; 计算机视觉是使计算机能…

设计模式(十):抽象工厂模式(创建型模式)

Abstract Factory&#xff0c;抽象工厂&#xff1a;提供一个创建一系列相关或相互依赖对 象的接口&#xff0c;而无须指定它们的具体类。 之前写过简单工厂和工厂方法模式(创建型模式)&#xff0c;这两种模式比较简单。 简单工厂模式其实不符合开闭原则&#xff0c;即对修改关闭…

C#,入门教程(26)——数据的基本概念与使用方法

上一篇&#xff1a; C#&#xff0c;入门教程(25)——注释&#xff08;Comments&#xff09;你会吗&#xff1f;看多图演示&#xff0c;学真正注释。https://blog.csdn.net/beijinghorn/article/details/124681888 本文所述的知识基本上适用于C/C&#xff0c;java等其他语言。 …

Rethinking Data Augmentation for Image Super-resolution

文章目录 Rethinking Data Augmentation for Image Super-resolution:1.概述2.一些现有方法的分析3.cutblur4.MOA 各种策略的混合5.降噪6.cutblur 代码 Rethinking Data Augmentation for Image Super-resolution: A Comprehensive Analysis and a New Strategy 1.概述 根据…

【JavaScript】字符串练习

练习 1&#xff1a;"smyhvaevaesmyh"查找字符串中所有 m 出现的位置。 代码实现&#xff1a; var str2 smyhvaevaesmyh; for (var i 0; i < str2.length; i) {//如果指定位置的符号 "o"//str2[i]if (str2.charAt(i) m) {console.log(i);} }练习 2&…

蚂蚁SEO什么是蜘蛛池2024最新强势蜘蛛池

蜘蛛池是一种搜索引擎优化&#xff08;SEO&#xff09;策略&#xff0c;通过在互联网上建立大量的网站和链接&#xff0c;吸引搜索引擎的爬虫&#xff08;也称为“蜘蛛”&#xff09;访问&#xff0c;以提高网站的搜索排名和曝光率。以下是关于蜘蛛池的详细解释&#xff1a; 获…

【网站项目】202物流管理系统

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

排序算法及Arrays

冒泡排序 1.相邻的数据两两比较&#xff0c;小的放前面&#xff0c;大的放后面。 2.第一轮比较完毕后&#xff0c;最大值已经确定了&#xff0c;第二轮可以少循环一次&#xff0c;后面依次类推。 3.如果数组中有n个数据&#xff0c;总共我们只执行n-1轮的代码就可以。 pack…

K8S之实现业务的金丝雀发布

如何实现金丝雀发布 金丝雀发布简介优缺点在k8s中实现金丝雀发布 金丝雀发布简介 金丝雀发布的由来&#xff1a;17 世纪&#xff0c;英国矿井工人发现&#xff0c;金丝雀对瓦斯这种气体十分敏感。空气中哪怕有极其微量的瓦斯&#xff0c;金丝雀也会停止歌唱&#xff1b;当瓦斯…

vscode 使用ssh进行远程开发 (remote-ssh),首次连接及后续使用,详细介绍

在vscode添加remote ssh插件 首次连接 选择左侧栏的扩展&#xff0c;并搜索remote ssh 它大概长这样&#xff0c;点击安装 安装成功后&#xff0c;在左侧栏会出现远程连接的图标&#xff0c;点击后选择ssh旁加号便可以进行连接。 安装成功后vscode左下角会有一个图标 点击图…

基于springboot的迷你天猫商城设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 系统开发技术 3 1.1 Springboot 3 1.2 MyEclipse 3 1.3 MySQL 3 1.4 Apache JMeter 3 1.5 系统开发背景 4 1.6 系统需求分析 4 1.7 本章小结 4 2 系统分析 5 2.1 技术可行性分析 5 2.2 系统经济可行性分析 5 2.3 系统功能需求分析 5 2.4 …

网工学习 DHCP配置-接口模式

网工学习 DHCP配置-接口模式 学习DHCP总是看到&#xff0c;接口模式、全局模式、中继模式。理解起来也不困难&#xff0c;但是自己动手操作起来全是问号。跟着老师视频配置啥问题没有&#xff0c;自己组建网络环境配置就是不通&#xff0c;悲催。今天总结一下我学习接口模式的…