关于requires_grad和优化器optim中parameters的记录

news/2024/4/24 3:23:29/文章来源:https://blog.csdn.net/immc1979/article/details/128109715

在模型中如果设置了requires_grad=True,则表示该层要进行梯度计算,标记为False则不计算梯度,在迁移学习中一般会设置成False,这样会大量减少算力。

而optim中的parameters是定义要对那些层进行参数优化

一般在迁移学习的代码过程中我们会先把加载的模型所有层定义成requires_grad=False,再

将模型编辑成我们需要的样子,例如将全连接层的输出定义成我们要的输出。

然后根据没有冻结的层创建优化器。

# 加载模型
model_ft = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)# 冻结模型
for param in model_ft.parameters():param.requires_grad = False# 编辑模型
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10)# 取出未冻结的层
param_update= []
for param in model_ft.parameters():if param.requires_grad:param_update.append(param)# 定义优化器
optimizer_ft = optim.Adam(param_update, lr=0.001)

这样做的好处是可以节省计算梯度和优化的算力。

但是这样做有一个问题:

我在代码中会保存优化器以便于之后的继续测试

state = {'state_dict': mymodel.state_dict(),'optimizer': optimizer.state_dict()}
torch.save(state, "save_test.pth")

读取的时候

model_ft = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)# 此时不冻结任何层 对全部层进行训练optimizer = optim.Adam(model_ft.parameters(), lr=1e-2)checkpoint = torch.load("save_test.pth")optimizer.load_state_dict(checkpoint['optimizer'])

如果保存的时候对优化器设置的是只针对某几个层进行优化,而读取的时候希望对所有的层进行优化程序就会报错:

File "D:\anaconda3\envs\pytorch_gpu\lib\site-packages\torch\optim\optimizer.py", line 171, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

所以,在读取优化器记录的时候要注意,如果保存的数据是针对某些层保存的,而加载的时候又希望对所有层进行训练,这时候优化器数据就没必要加载了,要么写个try,要么就把这行注释掉。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_39122.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

opengl,opengl es,egl,glfw,glew

OpenGL ES之GLFW窗口搭建 - Plato - 博客园概述 本章节主要总结如何使用GLFW来创建Opengl窗口。主要包括如下内容: OpenGl窗口创建介绍 GLFW Window版编译介绍 GLFW简单工程源码介绍 OpenGL窗口创建介绍 能用于Ohttps://www.cnblogs.com/feng-sc/p/5093262.htmlOp…

2022-11-30 Github Forking 工作流模式

Forking 工作流 fork 操作是在个人远程仓库新建一份目标远程仓库的副本,流程如下: 比如在 GitHub 上操作时,在项目的主页点击 fork 按钮(页面右上角),即可拷贝该目标远程仓库。 假设开发者 A 拥有一个远程仓…

电脑怎么提取图片中的文字?

图片记录着我们生活的点点滴滴,比如各种办公截图、查快递单号、布置的课堂作业等等,都离不开这种便捷的方法。而我们有时难免需要从图片中提取想要的文字,总不能就靠打字打到手软吧,那么电脑怎么提取图片中的文字呢?有需要的朋友…

终于有阿里p8进行了大汇总(Redis+JVM+MySQL+Spring)还有面试题解全在这里了!

Redis特性 Redis是一直基于键值对的NoSQL数据库; Redis支持5种主要数据结构:string、hash、list、set、zset以及bitmaps、hyperLoglog、GEO等特化的数据结构; Redis是内存数据库,因此它有足够好的读写性能; Redis支持…

verilog实现分频(奇数分频和偶数分频,通用版)

大家好,最近写了一些分频器的设计,发现奇数分频和偶数分频是比较常用分频效果,所以写了一个比较简单的分频代码,适用于奇数分频和偶数分频(不考虑占空比),代码已经经过测试,需要可自…

如何应对Redis并发访问带来的问题

前言 我们在使用Redis的过程中,难免会遇到并发访问及数据更新的问题。但很多场景对数据的并发修改是很敏感的,比如库存数据如果没有做好并发读取和更新的版本控制,就会导致严重的业务问题。今天就来说说应该如何做好并发访问及数据更新问题。…

ROS2--概述

ROS2概述1 ROS2对比ROS12 ROS2 通信3 核心概念4 ros2 安装5 话题、服务、动作6 参数参考1 ROS2对比ROS1 多机器人系统:未来机器人一定不会是独立的个体,机器人和机器人之间也需要通信和协作,ROS2为多机器人系统的应用提供了标准方法和通信机…

Windows系统--AD域控--DHCP服务器

Windows系统--AD域控--DHCP服务器 虚拟机网络准备 1.将VMware网络编辑器的NAT模式--取消勾选 使用本地DHCP服务器; 从机(win10)将内置网卡的IPv4网络改为 自动获取IP地址、自动获取DNS AD服务器 部署 DHCP服务器

VF01销售开票发票金额控制增强

实施隐式增强 全部代码如下: method IF_EX_BADI_SD_BILLING~INVOICE_DOCUMENT_CHECK. CALL FUNCTION ‘SIPT_DOC_CHECK_SD’ EXPORTING it_xvbrk fxvbrk it_xvbrp fxvbrp it_xkomv fxkomv it_xvbpa fxvbpa IMPORTING ev_bad_data fbad_data. “”“”“”“…

Word控件Spire.Doc 【图像形状】教程(8): 如何借助C#/VB.NET在 Word 中插入艺术字

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下,轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具,专注于创建、编辑、转…

Kotlin高仿微信-第12篇-单聊-图片

Kotlin高仿微信-项目实践58篇详细讲解了各个功能点,包括:注册、登录、主页、单聊(文本、表情、语音、图片、小视频、视频通话、语音通话、红包、转账)、群聊、个人信息、朋友圈、支付服务、扫一扫、搜索好友、添加好友、开通VIP等众多功能。 Kotlin高仿…

【JavaEE】MyBatis

文章目录1.MyBatis介绍2.MyBatis快速入门3.Mapper代理开发4.MyBatis核心配置文件5.配置文件完成增删改查5.1 查询5.2 添加/修改5.3 删除6.MyBatis参数传递7.注解完成增删改查1.MyBatis介绍 1.什么是MyBatis? MyBatis是一款优秀的 持久层框架,用于简化JDBC开发MyBat…

入门力扣自学笔记208 C++ (题目编号:895)

895. 最大频率栈​​​​​​ 题目: 设计一个类似堆栈的数据结构,将元素推入堆栈,并从堆栈中弹出出现频率最高的元素。 实现 FreqStack 类: FreqStack() 构造一个空的堆栈。 void push(int val) 将一个整数 val 压入栈顶。 int pop() 删除…

Kotlin高仿微信-第11篇-单聊-语音

Kotlin高仿微信-项目实践58篇详细讲解了各个功能点,包括:注册、登录、主页、单聊(文本、表情、语音、图片、小视频、视频通话、语音通话、红包、转账)、群聊、个人信息、朋友圈、支付服务、扫一扫、搜索好友、添加好友、开通VIP等众多功能。 Kotlin高仿…

基于Tree-LSTM网络语义表示模型

TC;DR 目前的LSTM仅能对序列信息进行建模, 但是自然语言中通常由词组成的短语形成了句法依存的语义树。为了学习到树结构的语义信息。论文中提出了两种Tree-LSTM模型。Child-Sum、Tree-LSTM、和N-ary Tree LSTMs。实验部分的Tree-LSTM、对比多种LSTMs的…

nuxtjs中asyncData异步数据请求、代理配置、fetch网络请求、vuex的使用、中间件处理

文章目录1. asyncData异步数据请求2. 代理配置3. fetch网络请求4. vuex4.1 state中的数据展示4.2 同步方法与异步方法4.3 数据持久化处理5. 中间件处理1. asyncData异步数据请求 Nuxt.js 扩展了 Vue.js,增加了一个叫 asyncData 和 fetch 的方法,使得我们…

这或许是全网最详细的介绍预言机赛道的视频课程,通俗易通,有趣有料!

图片来源:由无界版图 AI 绘画工具生成有一句话在创业者中很流行:Web3创业三大坑,隐私、跨链、预言机……搞塌加密市场的DK和SBF还在豪华度假酒店里思考人生搞隐私,一毛钱没赚到的Tornado cash开发者却在吃牢饭……加密圈前十大资产…

力扣(LeetCode)895. 最大频率栈(C++)

设计 ①维护最大频率,②维护每个数的出现次数,③维护出现次数对应的栈。 压栈时,新数压入出现次数对应的栈,每次压入新数,维护最大频率(所有出现次数中的最大出现次数)。 弹栈时,找最大频率对应的栈&…

拖死项目的不是团队,可能是失败的管理

项目中的活动,归根结底是由人来完成的,如何发挥项目成员的能力,对于项目的成败起着至关重要的作用。如何充分地发挥团队成员的能力,对项目经理也是一个挑战。 在团队管理者我们会遇见这些难题: 1、团队凝聚力不足&a…

【MySQL 18】Docker 安装 MySQL8 .0.30

1、查看可用的 MySQL 版本 访问 MySQL 镜像库地址: https://hub.docker.com/_/mysql?tabtags 。2、拉取 MySQL 8.0.30 镜像 拉取官方的指定版本的镜像: docker pull mysql:8.0.30[rootlocalhost deploy]# docker pull mysql:8.0.30 8.0.30: Pulling…