pytorch的buffer学习整理

news/2024/5/18 21:08:44/文章来源:https://blog.csdn.net/qq_38765642/article/details/128013565

pytorch模型中的buffer

这段时间忙于做项目,但是在项目中一直在模型构建中遇到buffer数据,所以花点时间整理下模型中的parameter和buffer数据的区别💕

1.torch.nn.Module.named_buffers(prefix=‘‘, recurse=True)

贴上pytorch官网对其的说明:
在这里插入图片描述
官网翻译:

named_buffers(prefix='', recurse=True)
方法: named_buffers(prefix='', recurse=True)Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.返回一个迭代器,该迭代器能够遍历模块的缓冲buffer,并且迭代返回的结果是缓冲的名字和缓冲本身.Parameters  参数prefix (str) – prefix to prepend to all buffer names.prefix (字符串) – 添加到所有缓冲名字之前的前缀.recurse (bool)if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.recurse (布尔类型) – 如果该参数是True,那么表示递归地迭代返回,即迭代返回该模块的缓冲以及该模块的所有子模块的缓冲. 默认为TrueYields  迭代返回(string, torch.Tensor) – Tuple containing the name and buffer(字符串,torch.Tensor类型) - 包含缓冲名字和缓冲自身的元组Example:  例子:>>> for name, buf in self.named_buffers():>>>    if name in ['running_var']:>>>        print(buf.size())

总结,缓冲buffer必须要登记注册才会有效,如果仅仅将张量赋值给Module模块的属性,不会被自动转为缓冲buffer.因而也无法被state_dict()、buffers()、named_buffers()访问到.此外state_dict()可以遍历缓冲buffer和参数Parameter.
可以概括为,缓冲buffer和参数Parameter的区别是前者不需要训练优化,而后者需要训练优化.在创建方法上也有区别,前者必须要将一个张量使用方法register_buffer()来登记注册,后者比较灵活,可以直接赋值给模块的属性,也可以使用方法register_parameter()来登记注册.
下面使用代码测试一下buffer数据:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28]))self.attribute_buffer_in = torch.randn(3,5)                       # 仅仅赋值给模型属性,是无法访问到该buffer数据register_buffer_in_temp = torch.randn(4,6)               self.register_buffer('register_buffer_in', register_buffer_in_temp)   # 注册buffer数据,才能生效,能获取到数据def forward(self,x): passprint('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)
model = Model() #.cuda()print('初始化之后模型修改之前'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))   
for name, buf in model.named_buffers():print(name,'-->',buf.shape)print('调用named_parameters()'.center(100,"-"))
for name, param in model.named_parameters():     # 访问模型的parameter参数数据的名字和其本身print(name,'-->',param.shape)print('调用buffers()'.center(100,"-"))           # 访问模型中的buffer数据本身
for buf in model.buffers():print(buf.shape)print('调用parameters()'.center(100,"-"))        # 访问模型中的parameter数据本身
for param in model.parameters():print(param.shape)print('调用state_dict()'.center(100,"-"))        # 同时获取模型的parameter参数数据、buffer参数数据
for k, v in model.state_dict().items():print(k, '-->', v.shape)model.attribute_buffer_out = torch.randn(10,10)      # 赋值给模型属性
register_buffer_out_temp = torch.randn(15,15)
model.register_buffer('register_buffer_out', register_buffer_out_temp)  # 通过注册的方式,使得模型的buffer成员属性生效
print('模型初始化以及修改之后'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))         # 修改模型buffer属性之后,访问buffer数据名字和其本身
for name, buf in model.named_buffers():print(name,'-->',buf.shape)print('调用named_parameters()'.center(100,"-"))      # 修改模型buffer属性之后,访问模型parameter数据名字和其本身
for name, param in model.named_parameters():print(name,'-->',param.shape)print('调用buffers()'.center(100,"-"))
for buf in model.buffers():print(buf.shape)print('调用parameters()'.center(100,"-"))
for param in model.parameters():print(param.shape)print('调用state_dict()'.center(100,"-"))
for k, v in model.state_dict().items():print(k, '-->', v.shape)  

输出结果为:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。尝试新的跨平台 PowerShell https://aka.ms/pscore6加载个人及系统配置文件用了 840 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63490' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
--------------------------------------------初始化之后模型修改之前---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])                     # 
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------模型初始化以及修改之后---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
torch.Size([15, 15])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 

模型中的buffer和parameter区别

在这里插入图片描述
在这里插入图片描述
下面使用代码进行说明:
pytorch保存模型参数的一种方式为:

# save
torch.save(model.state_dict(), PATH)# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

可以看到模型保存的是 model.state_dict() 的返回对象。 model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

class MyModule(nn.Module):def __init__(self, input_size, output_size):super(MyModule, self).__init__()self.lin = nn.Linear(input_size, output_size)def forward(self, x):return self.lin(x)module = MyModule(4, 2)
print(module.state_dict())

输出结果:
在这里插入图片描述
分析一个parameter和buffer的例子:

class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()buffer = torch.randn(2, 3)  # tensorself.register_buffer('my_buffer', buffer)self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量def forward(self, x):# 可以通过 self.param 和 self.my_buffer 访问pass
model = MyModel()
for param in model.parameters():print(param)
print("----------------")
for buffer in model.buffers():print(buffer)
print("----------------")
print(model.state_dict())

输出结果:
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_224881.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式文件系统HDFS实践及原理详解part3

HDFS原理 说明:3.5开头目录是因为和上篇文章内容同属一章,所以开头使用了3.5 3.5 HDFS核心设计 3.5.1 心跳机制 1、 Hadoop 是 Master/Slave 结构,Master 中有 NameNode 和 ResourceManager,Slave 中有 Datanode 和 NodeManag…

异构网络小入

A Survey of Heterogeneous Information Network Analysis Heterogeneous Graph Attention Network 异构网络很火吗? 在一个网络中,不用节点的类型不同,这是肯定的。 所以,异构网络在表征比较复杂的情形时,是比较合适…

基于图像识别的小车智能寻迹控制系统

目录 摘要…… I Abstract II 基于图像识别的智能寻迹控制系统设计 I Design of Intelligent tracking Control system based on Image recognition II 目录 III 第1章 绪论 1 1.1 课题背景 1 1.1 国内外文献综述 1 1.2 论文研究内容 2 第2章 基于图像识别的智能寻迹控制系统方…

【安装Ubuntu18.04遇到的问题】未找到WIFI适配器

大家好,我是小政。好久没有更新文章,近期开始陆续分享一些研究生阶段正在学习的知识和遇到的一些问题。 联想拯救者Y9000P关于安装Ubuntu未找到WIFI适配器的解决方法1.Ubuntu18.042.网卡信息3.解决方法(1)用手机USB连接电脑提供网…

动态规划--树型dp

6个题1. 树的最长路径2.树的中点.由于第三题需要用到一些数学地知识,所以先去补一补数学知识。连接链接在这里4.二叉苹果树5.战略游戏6.皇宫守卫1. 树的最长路径 定义:树中两个点直接的最远距离称为树的直径 先说一个结论 先任意找到一个树中一个点u&am…

分布式协调系统ZooKeeper实践与原理剖析

基础的一些知识,高阶知识后续看看补充 第一章 ZooKeeper概述 1.1 介绍 What is ZooKeeper? Apache ZooKeeper is an effort to develop and maintain an open-source server which enables highly reliable distributed coordination ZooKeeper is…

大学生静态HTML网页设计--公司官网首页

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材,DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 公司官网网站 | 企业官网 | 酒店官网 | 等网站的设计与制 HTML期末大学生网页设计作业,Web大学生网页 HTML:结构 CSS&#xf…

SpringIoc依赖查找-5

1. 依赖查找的今世前生: Spring IoC容器从Java标准中学到了什么? 单一类型依赖查找 JNDI - javax.naming.Context#lookup(javax.naming.Name) JavaBeans - java.beans.beancontext.BeanContext 集合类型依赖查找 java.beans.beancontext.BeanContext 集合查找方法 层…

sqli-labs/Less-51

这一关的欢迎界面依然是以sort作为注入点 我们首先来判断一下是否为数字型注入 输入如下 sortrand() 对尝试几次 发现页面并没有发生变化 说明这道题的注入类型属于字符型 然后尝试输入以下内容 sort1 报错了 报错信息如下 我们从报错信息可以知道这道题的注入类型属于单…

期末前端web大作业——HTML+CSS+JavaScript仿京东购物商城网页制作(7页)

常见网页设计作业题材有 个人、 美食、 公司、 学校、 旅游、 电商、 宠物、 电器、 茶叶、 家居、 酒店、 舞蹈、 动漫、 服装、 体育、 化妆品、 物流、 环保、 书籍、 婚纱、 游戏、 节日、 戒烟、 电影、 摄影、 文化、 家乡、 鲜花、 礼品、 汽车、 其他等网页设计题目, A…

#边学边考 必修5 高项:对人管理 第2章 项目沟通管理和干系人管理

答题报告 自我分析 有可能是间隔时间太长,本章节从开始学习到今天(11.24)学完,中间至少停止了1周以上,造成对基本知识记忆不牢固。对重点知识没有重点记忆,走马观花,以至于混淆。 答题解析 关…

MySQL 进阶 图文详解InnoDB储存引擎

前言 SQL 语句的最终执行者是存储引擎。存储引擎在经解析器、优化器处理后被执行器调用其接口执行优化后的执行计划。MySQL 存储引擎包括 InnoDB、Myisam、Memory、Archive、CSV 存储引擎等,其中最常用也是MySQL 默认的存储引擎是 InnoDB。 写入缓冲池(…

用DIV+CSS技术设计的水果介绍网站(web前端网页制作课作业)

🎀 精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

软件测试面试技巧有哪些?可以从这2个方面去进行准备

面试所有只职场人,通往工作岗位的第一道关卡,也是最重要的一道门槛。而面试中,如何回答HR提出的问题很大程度上决定了面试能不能成功。所以这些软件测试的面试技巧你可不能错过了。 首先是自我介绍 自我介绍的时间不能太短,几十秒…

(附源码)计算机毕业设计JavaJava毕设项目财务管理系统的设计与实现

项目运行 环境配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis Maven Vue 等等组成,B/…

【Flutter】shape 属性 ShapeBorder,形状

文章目录前言一、shape 是什么?二、不同的形状1.BeveledRectangleBorder2.Border3.CircleBorder圆形4.ContinuousRectangleBorder连续圆角5.StadiumBorder 体育场边界 ,药丸形状6.OutlineInputBorder外边框可以定制圆角7.UnderlineInputBorder下划线总结…

Springboot Security 前后端分离模式自由接口最小工作模型

但凡讲解Springboot Security的教程,都是根据其本身的定义,前后端整合在一起,登录采用form或者basic。我们现在的很多项目,前后端分离,form登录已经不适用了。很多程序的架构要求所有的接口都采用application/json方式…

复制集群架构设计技巧

Redis Sentinel设计技巧 Redis Sentinel基本架构 Monitoring Sentinel可以监控Redis节点的状态 Notification Sentinel可以通过API进行集群状态通知 Automatic failover Sentinel实现故障自动切换 Configuration provider Sentinel为client提供发现master节点的发现功能…

Java项目:JSP校园运动会管理系统

作者主页:源码空间站2022 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本项目包含三种角色:运动员、裁判员、管理员; 运动员角色包含以下功能: 运动员登录,个人信息修改,运动成绩…

【优化求解】粒子群算法求解干扰受限无人机辅助网络优化问题【含Matlab源码 230期】

⛄一、粒子群简介 1 粒子群优化算法 粒子群优化算法( PSO)是指通过模拟鸟群觅食的协作行为,实现群体最优化。PSO是一种并行计算的智能算法,其基本模型如下: 假设群体规模为M,在D维空间中,群体中的第i个个体表示为XD ( xm1,xm2…xm D)T,速度表示为VD ( vm1,vm2…vm D)T,位置( …