Pytorch剪枝api测试和结果

news/2024/4/24 13:14:21/文章来源:https://blog.csdn.net/Bismarckczy/article/details/130373822

Pytorch 官方给出的prune接口

下面是基于prune的接口进行剪枝的方法步骤

1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:

  • RandomUnstructured
  • L1Unstructured
  • RandomStructured
  • LnStructured
  • CustomFromMask
    ps:非结构性剪枝不会给剪枝后模型的速度带来提升。

2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数

from torch.nn.utils import prune 
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0123
'''

3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存

prune.remove(module, 'weight')
print(list(module.named_parameters()))

5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。

对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()

import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''

然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝

prune全局非结构性剪枝测试结果

# 推理模型tiny-yolov4def model_global_prune(amount: float):detect_model = Darknet('/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg')  # TODO:改成相对路径detect_model.load_weights("/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")parameters_to_prune = list()nums = 0for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):nums += 1parameters_to_prune.append((detect_model.models[i][j], 'weight'))elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):nums += 2parameters_to_prune.append((detect_model.models[i][j], 'weight'))parameters_to_prune.append((detect_model.models[i][j], 'bias'))parameters_to_prune = tuple(parameters_to_prune)assert (nums == len(parameters_to_prune))pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)for i, modules in enumerate(detect_model.models):if isinstance(modules, nn.Sequential):  for j, module in enumerate(modules):if isinstance(detect_model.models[i][j], nn.Conv2d):pt_prune.remove(detect_model.models[i][j], 'weight')elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):pt_prune.remove(detect_model.models[i][j], 'weight')pt_prune.remove(detect_model.models[i][j], 'bias')return detect_model

base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323

case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364

case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859

case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’

case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578

case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322

prune结构性剪枝测试结果

通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)

结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。

结论:pytorch的官方接口并不能直接使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_103521.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

隋唐洛阳“西宫”:上阳宫的GIS视角

隋唐洛阳城简介 营建 隋大业元年(605年),在隋炀帝的授意下,隋代著名城市设计师宇文恺,在汉魏故城以西重新选址,历时8个月,日役劳工200万,兴建新都洛阳城。 城和苑 隋唐洛阳城采用…

eBPF技术介绍

前言 eBPF起源于linux内核,它可以以砂箱程序运行在操作系统内核的特权上下文,高效,安全,易于扩展而不需要修改内核源码或者加载内核模块。 操作系统一直是实现观测,安全和网络功能的最理想的地方,因为内核的…

优思学院|精益管理的理念是什么?

作为一个企业,我们都希望拥有高效率和优异的竞争力。但是,如何才能在竞争激烈的市场中脱颖而出?这时,精益管理理念的出现可以帮助我们。 精益管理的基本概念是什么? 精益管理的核心理念是通过消除浪费来实现生产效率…

Java线程间通信方式(3)

前文了解了线程通信方式中的CountDownLatch, Condition,ReentrantLock以及CyclicBarrier,接下来我们继续了解其他的线程间通信方式。 Phaser Phaser是JDK1.7中引入的一种功能上和CycliBarrier和CountDownLatch相似的同步工具,相…

辛弃疾最经典的10首词

他,文能挥笔填词,武能上马杀敌; 他,被称为“词中之龙”, 他,一生赤子,追求收复山河; 他,是与苏轼齐名的豪放派词人; 他是辛弃疾。 辛弃疾一生怀着赤子之…

IO多路复用——select函数

1.select函数原型和fd_set结构体说明 1.1 select函数原型 ​ 使用 select 这种 IO 多路转接方式需要调用一个同名函数 select,这个函数是跨平台的,Linux、Mac、Windows 都是支持的。程序员通过调用这个函数可以委托内核帮助我们检测若干个文件描述符的…

【MCS-51】51单片机结构原理

至今为止,MCS-51系列单片机有许多种型号的产品:其中又分为普通型51(8031、8051、89S51)和增强型52(8032、8052、89S52等)。它们最大的区别在于存储器配置各有差异。下面我举例子的都是8051这一系列的单片机…

STM32-HAL-定时器(无源蜂鸣器的驱动)

文章目录 一、蜂鸣器的介绍二、常用的无源蜂鸣器的电路三、测试准备四、初始化片上外设4.1 初始化定时器4的通道2为PWM输出模式4.2 编写驱动代码4.3 Logic分析仪查看波形4.4 代码分析 一、蜂鸣器的介绍 有源蜂鸣器: 有源蜂鸣器内部有一个发声电路,也就是“源”&…

数据湖Iceberg-Hive集成Iceberg(3)

文章目录 Hive集成Iceberg环境准备Hive与Iceberg的版本对应关系如下上传jar包,拷贝到Hive的auxlib目录中修改hive-site.xml,添加配置项启动 HMS 服务启动 Hadoop 创建和管理 Catalog默认使用 HiveCatalog指定 Catalog 类型使用 HiveCatalog使用 HadoopCa…

C++学习记录——이십 map和set

文章目录 1、setmultiset 2、map3、map::operator[] 1、set vector/list/deque等是序列式容器,map,set是关联式容器。序列式容器的特点就是数据线性存放,而关联式容器的数据并不是线性,数据之间有很强的关系。 它们的底层是平衡…

在当前互联网行情下,Android想转音视频开发,会有前景吗?

前言 近年来,由于三年疫情的影响,很多公司都开始陆陆续续的在裁员,Android开发工作岗位也是,可能有些从事Android开发的朋友还没有意识到,Android开发岗位正在变少,求职者,僧多粥少&#xff0c…

视频大文件传输的演变:从“卷轴男孩”到自动化

200年前,从纽约市到英国伦敦的单程旅行需要乘坐一艘跨大西洋轮船将近三周——如果你能负担得起的话,那就是。那些不能在满是汗水、狭窄的帆船上安顿大约一个半月的人。 今天,视频专业人士能够在几小时甚至几分钟内跨越相同的物理距离传输大量…

《用于估计血压变化的光电体积描记图和心电图的特征》阅读笔记

目录 一、摘要 二、十大问题 Q1论文试图解决什么问题? Q2这是否是一个新的问题? Q3这篇文章要验证一个什么科学假设? Q4有哪些相关研究?如何归类?谁是这一课题在领域内值得关注的研究员? Q5论文中提…

微信小程序第五节——登录那些事儿(超详细的前后端完整流程)

📌 微信小程序第一节 ——自定义顶部、底部导航栏以及获取胶囊体位置信息。 📌 微信小程序第二节 —— 自定义组件 📌 微信小程序第三节 —— 页面跳转的那些事儿 📌 微信小程序第四节—— 网络请求那些事儿 😜作 …

MFC之CRect详解

2023年4月25日,周二晚上。 今天查了不少关于CRect类及其相关内容的资料,学到了不少东西,所以我决定写一篇详细的关于CRect类及其相关内容的文章,以记录今天所学。 CRect类 在 MFC 中,CRect 类表示一个矩形区域。它是…

linux 命令之 tar -czvf和 tar -xzvf

文章目录 一、概述:二、基础知识 一、概述: tar 用于linux 系统中压缩和解压 二、基础知识 tar常用命令参数说明 tar命令的czvf/xzvf参数分别代表的意义如下: -c 或–create 建立新的备份文件。 -x或–extract或–get 从备份文件中还原文件…

SparkStreaming学习之——无状态与有状态转化、遍历kafka的topic消息、WindowOperations

目录 一、状态转化 二、kafka topic A→SparkStreaming→kafka topic B (一)rdd.foreach与rdd.foreachPartition (二)案例实操1 1.需求: 2.代码实现: 3.运行结果 (三)案例实操2 1.需求: 2.代码实现: 3.运行结果 三、W…

Eclipse代码提示突然失灵的解决方案

不知道改动了啥,突然间Eclipse的代码提示就失效了,发现缺少后极不方便。 使用快捷键:Alt/ 提示 No Default Proposals 为什么使用快捷键:Alt/ 会提示“No Default Proposals。”呢? 网上提示可能是热键冲突 但是一套…

数据可视化大屏电商数据展示平台开发实录(Echarts柱图曲线图、mysql筛选统计语句、时间计算、大数据量统计)

数据可视化大屏电商数据展示平台 一、前言二、项目介绍三、项目展示四、项目经验分享4.1 翻牌器4.1.1 翻牌器-今日实时交易4.1.2.翻牌器后端统计SUM函数的使用 4.2 不同时间指标的数据MySql内部的时间计算 4.3 实时交易播报MySql联表查询和内部遍历循环 4.4 每日交易量4.4.1.近…

5.5 高斯型求积公式简历

学习目标: 我会按照以下步骤学习高斯求积公式简介: 理解积分的概念:学习什么是积分以及积分的几何和物理意义,如面积、质量、电荷等概念。 掌握基本的积分技巧:掌握基本的积分公式和技巧,如换元法、分部积…