self.register_buffer方法使用解析(pytorch)

news/2024/5/18 18:56:53/文章来源:https://blog.csdn.net/weixin_38252409/article/details/134246688

self.register_buffer就是pytorch框架用来保存不更新参数的方法。

列子如下:

self.register_buffer("position_emb", torch.randn((5, 3)))

第一个参数position_emb传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数torch.randn((5, 3),并一次初始化后保存于模型,不会有梯度传播给它,能被模型的model.state_dict()记录下来,可以理解为模型的常数。当然,你想保留固定值,使用如下代码:

self.register_buffer("position_emb", torch.tensorrt([[2,5],[8,9]]))

进一步探讨训练对该参数是否有影响,答案是:没影响。具体可看下面实现的列子代码:

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 3)self.register_buffer("position_emb", torch.randn((5, 3)))def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)emb_vec1=emb_vec1+self.position_emboutput = torch.einsum('ik, kj -> ij', emb_vec1, vec)return output
def simple_train():model = Model()vec = torch.randn((3, 1))label = torch.Tensor(5, 1).fill_(3)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)print('初始化后position_emb参数:\n',model.position_emb)for iter_num in range(100):output = model(vec)loss = loss_fun(output, label)opt.zero_grad()loss.backward(retain_graph=True)opt.step()print('训练后position_emb参数:\n', model.position_emb)if __name__ == '__main__':simple_train()  # 训练与保存权重

实现结果如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_190614.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaEE平台技术——MyBatis

JavaEE平台技术——MyBatis 1. 对象关系映射框架——Hibernate、MyBatis2. 对象关系模型映射3. MyBatis的实现机制4. MyBatis的XML定义5. Spring事务 在观看这个之前,大家请查阅前序内容。 😀JavaEE的渊源 😀😀JavaEE平台技术——…

大数据毕业设计选题推荐-设备环境监测平台-Hadoop-Spark-Hive

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

AI:57-基于机器学习的番茄叶部病害图像识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

Web前端—网页制作(以“学成在线”为例)

版本说明 当前版本号[20231105]。 版本修改说明20231105初版 目录 文章目录 版本说明目录day07-学成在线01-项目目录02-版心居中03-布局思路04-header区域-整体布局HTML结构CSS样式 05-header区域-logo06-header区域-导航HTML结构CSS样式 07-header区域-搜索布局HTML结构CSS…

挑战100天 AI In LeetCode Day02(1)

挑战100天 AI In LeetCode Day02(1) 一、LeetCode介绍二、LeetCode 热题 HOT 100-32.1 题目2.2 题解 三、面试经典 150 题-33.1 题目3.2 题解 一、LeetCode介绍 LeetCode是一个在线编程网站,提供各种算法和数据结构的题目,面向程序…

使用Objective-C和ASIHTTPRequest库进行Douban电影分析

概述 Douban是一个提供图书、音乐、电影等文化内容的社交网站,它的电影频道包含了大量的电影信息和用户评价。本文将介绍如何使用Objective-C语言和ASIHTTPRequest库进行Douban电影分析,包括如何获取电影数据、如何解析JSON格式的数据、如何使用代理IP技…

【JavaEE】JVM 剖析

JVM 1. JVM 的内存划分2. JVM 类加载机制2.1 类加载的大致流程2.2 双亲委派模型2.3 类加载的时机 3. 垃圾回收机制3.1 为什么会存在垃圾回收机制?3.2 垃圾回收, 到底实在做什么?3.3 垃圾回收的两步骤第一步: 判断对象是否是"垃圾"第二步: 如何回收垃圾 1. JVM 的内…

计算机网络第4章-网络层(1)

引子 网络层能够被分解为两个相互作用的部分: 数据平面和控制平面。 网络层概述 路由器具有截断的协议栈,即没有网络层以上的部分。 如下图所示,是一个简单网络: 转发和路由选择:数据平面和控制平面 网络层的作用…

webgoat-(A1)injection

SQL Injection (intro) SQL 命令主要分为三类: 数据操作语言 (DML)DML 语句可用于请求记录 (SELECT)、添加记录 (INSERT)、删除记录 (DELETE) 和修改现有记录 &#xff…

【C++】详解IO流(输入输出流+文件流+字符串流)

文章目录 一、标准输入输出流1.1提取符>>&#xff08;赋值给&#xff09;与插入符<<&#xff08;输出到&#xff09;理解cin >> a理解ifstream&#xff08;读&#xff09; >> a例子 1.2get系列函数get与getline函数细小但又重要的区别 1.3获取状态信息…

升级Python版本后,anaconda navigator启动失败

anaconda navigator启动失败&#xff0c;尤其是重装不解决问题的&#xff0c;大概率是库冲突 1.通过anaconda-navigator的图标启动&#xff0c;没有反应 2.在命令窗口&#xff0c;输入anaconda-navigator&#xff0c;报错如下 anaconda-navigator 3.错误来自这里 File &quo…

小程序day02

目标 WXML模板语法 数据绑定 事件绑定 那麽問題來了&#xff0c;一次點擊會觸發兩個組件事件的話&#xff0c;該怎么阻止事件冒泡呢&#xff1f; 文本框和data的双向绑定 注意点: 只在标签里面用value“{{info}}”&#xff0c;只会是info到文本框的单向绑定&#xff0c;必须在…

Python---排序算法

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 Python中的排序算法用于对数据进行排序。排序算法可以使数据按照一定的规则进行排列&#xff0c;以便于数据的查找、统计、比较等操作。在数据分析、机器学习、图形计算等领域&#xff0c…

gcc -static 在centos stream8 和centos stream9中运行报错的解决办法

gcc -static 在centos stream8 和centos stream9中运行报错的解决办法&#xff1a; 报/usr/bin/ld: cannot find -lc 我们下载glibc-static&#xff1a; 选择x86_64的。 还有一个是libxcrypt-static&#xff0c;依旧在这个网站里搜。 rpm -ivh glibc-static-2.28-239.el8.x…

【数字三角形】

题目描述 上图给出了一个数字三角形。从三角形的顶部到底部有很多条不同的路径。对于每条路径&#xff0c;把路径上面的数加起来可以得到一个和&#xff0c;你的任务就是找到最大的和。 路径上的每一步只能从一个数走到下一层和它最近的左边的那个数或者右 边的那个数。此外…

线扫相机DALSA软件开发套件有哪些

Win10和Win7系统完整SDK目录截图&#xff1a; Sapera Configuration 缓存与内存管理&#xff0c;以及通信端口配置工具&#xff0c;部分功能等效于Detection(查找相机)内的Settings。 Sapera Log Viewer 打开Log Viewer后会显示之前发生过的所有与Sapera LT软件有关的运行信息…

无需专线、无需固定公网IP,各地安防数据如何高效上云?

某专注于安防领域的企业&#xff0c;供机场、金融、智慧大厦等行业&#xff0c;包括门禁系统、巡更系统、视频监控在内的整体解决方案。 在实际方案交付过程中&#xff0c;往往需要在多地分支机构分别部署相应的安防设备&#xff0c;并将产生的数据实时统一汇总至云平台进行管理…

【云服务器】对比传统服务器,为什么说云服务器更具优势?

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️其他领域】 文章目录 前言云服务器云服务器的优势成本可扩展性可靠性和安全性 总结 前言 2006年搜索引擎大会上&#xff0c;“云服务器”的概念孕育而生&#xff0c;时至今日云服务器与传统服务器的…

ChinaSoft 论坛巡礼 | 安全攸关软件的智能化开发方法论坛

2023年CCF中国软件大会&#xff08;CCF ChinaSoft 2023&#xff09;由CCF主办&#xff0c;CCF系统软件专委会、形式化方法专委会、软件工程专委会以及复旦大学联合承办&#xff0c;将于2023年12月1-3日在上海国际会议中心举行。 本次大会主题是“智能化软件创新推动数字经济与社…

CSS3媒体查询与页面自适应

2017年9月&#xff0c;W3C发布媒体查询(Media Query Level 4)候选推荐标准规范&#xff0c;它扩展了已经发布的媒体查询的功能。该规范用于CSS的media规则&#xff0c;可以为文档设定特定条件的样式&#xff0c;也可以用于HTML、JavaScript等语言。 1、媒体查询基础 媒体查询…