【全网首个手把手教你在yolov5中使用SAM优化器 】yolov5优化方案(2):更换SAM优化器

news/2024/5/3 5:32:15/文章来源:https://blog.csdn.net/weixin_50862344/article/details/126412909

SAM优化器

论文地址:https://arxiv.org/pdf/2010.01412v2.pdf

项目地址:GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)

意义:增强泛化能力,避免过拟合
作者在实验部分的原话(翻译):

从理论上讲,您可以通过运行更长时间(1800个epoch而不是200个epoch)来获得更低的误差,因为SAM不应该容易过度拟合。SAM 使用 ,而 ASAM 设置为 ,如其作者所建议的那样。rho=0.05rho=2.0

在这里插入图片描述

直接运行

具体代码的讲解部分可以看这篇

base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

从项目给出的示例代码,SAM都是和其他优化器搭配使用的

在这里插入图片描述

代码解析

(1)StepLR类

class StepLR:def __init__(self, optimizer, learning_rate: float, total_epochs: int):self.optimizer = optimizerself.total_epochs = total_epochsself.base = learning_ratedef __call__(self, epoch):if epoch < self.total_epochs * 3/10:lr = self.baseelif epoch < self.total_epochs * 6/10:lr = self.base * 0.2elif epoch < self.total_epochs * 8/10:lr = self.base * 0.2 ** 2else:lr = self.base * 0.2 ** 3for param_group in self.optimizer.param_groups:param_group["lr"] = lrdef lr(self) -> float:return self.optimizer.param_groups[0]["lr"]

但是最后没用上

尝试在yolov5中使用

这套代码是我自己摸索出来的可能有点毛病,请见谅多指正,非常感谢。因为各个版本有所不同,我只能京可能覆盖当前比较常用的版本,但是可能还有一些问题,希望能及时指正交流
!!
个人对于模型的了解还是比较粗浅,这篇train代码写得还算比较清晰,多有学习。

(1)修改优化器首先要知道优化器在哪使用:train.py中的train函数

请添加图片描述

(2)对SAM有基本的了解

(3)上手

第三处修改被我删除了,因此序号不对应

  1. import

我是将SAM的py文件和utility一起copy到了yolov5的utils下

我已经整好了

链接:https://pan.baidu.com/s/1aMz8RplazAMyBpRfK1FG9w?pwd=jm9n
提取码:jm9n

解压之后放入utils就行
在这里插入图片描述

#(1)import文件
from utils.SAM_Optimizer.sam import SAM
from utils.SAM_Optimizer.utility.bypass_bn import enable_running_stats, disable_running_stats
from utils.SAM_Optimizer.utility.step_lr import StepLR

如果选择优化器被封装到smart_optimizer,同样也要在yolov5/utils/torch_utils.py下copy上上面这段

  1. 增加"SAM"到优化器中
    (1)常规版本

在这里插入图片描述

 #(2)修改2:增加"SAM"elif opt.optimizer == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g0, base_optimizer, lr=hyp['lr0'], momentum=hyp['momentum'])

(2)没有optimizer这个参数
在某些版本中已经将优化器简化到只剩SGD和adam,甚至没有optimizer这个参数,因此我只能增加optimizer这个参数,转化到和上述一样。

if opt.optimizer == 'Adam':optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum#---------------------(2)修改2:增加"SAM":默认使用SAM优化器----------------------------------------------------------elif opt.optimizer == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g0, base_optimizer, lr=hyp['lr0'], momentum=hyp['momentum'])# --------------------------------------------------------else:optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

(3)封装到smart_optimizer
同样有些版本直接将选择优化器被封装到smart_optimizer,那么同样也要在yolov5/utils/torch_utils.py下对smart_optimizer进行修改。
直接整个函数给你们

def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):# YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decayg = [], [], []  # optimizer parameter groupsbn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()for v in model.modules():if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)g[2].append(v.bias)if isinstance(v, bn):  # weight (no decay)g[1].append(v.weight)elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)g[0].append(v.weight)if name == 'Adam':optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentumelif name == 'AdamW':optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)elif name == 'RMSProp':optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)elif name == 'SGD':optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)#-----------修改2---------------------------                #----------------SAM---------------elif name == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g[2], base_optimizer, lr=lr, momentum=momentum)#----------------SAM---------------
#-------------------修改2 -----------------------------  else:raise NotImplementedError(f'Optimizer {name} not implemented.')optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decayoptimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")return optimizer
  1. enable
    在这里插入图片描述
    这部分是插入在Multi-scale后面
#--------------修改4:enable-------------------------------------修改4if(opt.optimizer=='SAM'):enable_running_stats(model)#--------------修改4:enable-------------------------------------修改4
  1. 为了不影响其他优化器的使用,增加if-else
    直接搜索 if ni - last_opt_step >= accumulate:然后用下面这段替换一下
            if ni - last_opt_step >= accumulate:#--------------修改5:第一阶段结束加上开启第二段-------------------------------------修改5if opt.optimizer=='SAM':optimizer.first_step(zero_grad=True)disable_running_stats(model)with amp.autocast(enabled=cuda):pred = model(imgs)  # forwardloss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_sizeif RANK != -1:loss *= WORLD_SIZE  # gradient averaged between devices in DDP modeif opt.quad:loss *= 4.# Backwardscaler.scale(loss).backward()optimizer.second_step(zero_grad=True)#--------------修改5:第一阶段结束加上开启第二段-------------------------------------修改5else:scaler.unscale_(optimizer)  # unscale gradientstorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradientsscaler.step(optimizer)  # optimizer.stepscaler.update()optimizer.zero_grad()if ema:ema.update(model)last_opt_step = ni
  1. parser增加choice

在这里插入图片描述

  1. 【最核心修改】使用SAM时不使用半精度
    #-------------修改7:不使用半精度-------------------------------if(opt.optimizer='SAM'):scaler = torch.cuda.amp.GradScaler(enabled=False)else :scaler = torch.cuda.amp.GradScaler(enabled=amp)

注意一下有些版本import的时候是from torch.cuda import amp,此时替换成上面这么该应该没问题,也可以改成下面这种

scaler = amp.GradScaler(enabled=False)

如果使用了半精度会有nan输出,暂时还没想到更好的修改办法!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_7034.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第16讲:MySQL中常用的字符串函数以及基本使用

文章目录1.函数的概念以及常用的几种函数2.常用的字符串函数以及基本使用2.1.常用的几种字符串函数2.2.CONCAT将多个字符串拼接2.3.LOWER将字符串转换为小写2.4.UPPER将字符串转换为小写2.5.LPAD/RPAD字符串左右填充2.6.TRIM去除字符串两侧的空格2.7.SUBSTRING截取字符串2.8.字…

环境搭建+hello world!

php动态网页设计 1 环境下载搭建1.phpstudy服务器环境软件 下载链接:https://www.xp.cn/ 安装完毕之后可能需要重新配置端口号(默认为80,但是有可能被占用了,修改一个端口号即可) 启动apache、mysql新建一个网站(可能已经有个默认的,就用默认的) 点击管理,启动这个网站 然后…

nodejs基于微信小程序的图书销售商城系统 uniapp 小程序

伴随着社会以及科学技术的发展,互联网已经渗透在人们的身边,网络慢慢的变成了人们的生活必不可少的一部分,随着互联网的飞速发展,系统这一名词已不陌生,越来越多的书店都会使用系统来定制一款属于自己个性化的系统。书籍销售系统采用nodejs技术, mysql数据库进行开发,实现了首页…

用python找出400多万次KDJ金叉死叉,胜率有多高?附代码

引言: 邢不行的系列帖子“量化小讲堂”&#xff0c;通过实际案例教初学者使用python进行量化投Z&#xff0c;了解行业研究方向 这是邢不行第90期量化小课堂分享 作者 l 邢不行 不知道大家有没有发现&#xff0c;打开任意一个交易软件&#xff0c;无论是针对A股、美股、期货、…

spring+SpringMVC+MyBatis之配置多数据源

数据库准备   1、准备2个数据库&#xff0c;本例以mysql为例 在第一个数据库新建表user -- ---------------------------- -- Table structure for user -- ---------------------------- DROP TABLE IF EXISTS user; CREATE TABLE user (id int(11) NOT NULL AUTO_INCREME…

gateway过滤器

简介 1 作用: 过滤器就是在请求的传递过程中,对请求和响应做一些手脚 2 生命周期: Pre Post 3 分类: 局部过滤器(作用在某一个路由上) 全局过滤器(作用全部路由上) 在Gateway中, Filter的生命周期只有两个&#xff1a;“pre” 和 “post”。 PRE&#xff1a; 这种过滤器在请…

【蓝桥杯国赛真题24】Scratch货物运输 第十三届蓝桥杯 图形化编程scratch国赛真题和答案讲解

目录 scratch货物运输 一、题目要求 编程实现 二、案例分析 1、角色分析

PostGIS是什么

1. 什么是GIS(知识地图定位) 1.1. GIS概念 地理信息系统&#xff08;Geographic Information System或 Geo&#xff0d;Information system&#xff0c;GIS&#xff09;有时又称为“地学信息系统”。它是一种特定的十分重要的空间信息系统。它是在计算机硬、软件系统支持下&a…

注册中心对比和选型:Zookeeper、Eureka、Nacos、Consul和ETCD

转自:https://juejin.cn/post/7068065361312088095 下面是文章目录:注册中心基本概念 什么是注册中心? 注册中心主要有三种角色:服务提供者(RPC Server):在启动时,向 Registry 注册自身服务,并向 Registry 定期发送心跳汇报存活状态。 服务消费者(RPC Client):在启…

【Linux虚拟机安装】在VMware Workstation上安装ubuntu虚拟机

目录0、工具清单1、下载操作系统镜像2、创建虚拟机3、设置ubuntu系统0、工具清单 虚拟机软件&#xff1a;VMware Workstationubuntu镜像版本&#xff1a;Ubuntu 20.04.4 LTS (Focal Fossa)宿主机操作系统&#xff1a;Windows 10 专业版 1、下载操作系统镜像 官方下载网址&am…

氨基聚苯乙烯包覆硅胶微球SG-PS-NH2/聚苯乙烯/硫化镉PS/CdS复合材料/聚苯乙烯支载井冈霉素微球制备

今天小编给大家分享了氨基聚苯乙烯包覆硅胶微球SG-PS-NH2/聚苯乙烯/硫化镉PS/CdS复合材料/聚苯乙烯支载井冈霉素微球的制备方法&#xff0c;一起来看看&#xff01; 小编分享-氨基聚苯乙烯包覆硅胶微球SG-PS-NH2的制备方法&#xff1a; 通过对硅胶微球进行聚苯乙烯包覆,然后功…

javascript为什么叫脚本语言

脚本script是使用一种特定的描述性语言&#xff0c;依据一定的格式编写的可执行文件&#xff0c;又称作宏或批处理文件。 脚本通常可以由应用程序临时调用并执行。 各类脚本目前被广泛地应用于网页设计中&#xff0c;因为脚本不仅可以减小网页的规模和提高网页浏览速度&#xf…

为什么ASO很重要?

由于用户对多功能App的需求量增大&#xff0c;导致榜单影响力下滑&#xff0c;越来越多的用户通过搜索相关词来查找目标App。同时搜索对排名的影响权重也被各家应用商店加大。数据库显示&#xff0c;用户越来越习惯直接搜索关键词来搜索想要的应用。各应用商店收录热词现阶段有…

压缩网络相关

同样搬运模式 勿怪呀 大佬们 自从深度学习&#xff08;Deep Learning&#xff09;开始流行&#xff0c;已经在很多领域有了很大的突破&#xff0c;尤其是AlexNet一举夺得ILSVRC 2012 ImageNet图像分类竞赛的冠军后&#xff0c;卷积神经网络&#xff08;CNN&#xff09;的热潮便…

【JY】YJK前处理参数详解及常见问题分析:控制信息(二)

点击蓝字 求求关注【写在前文】本文介绍计算控制信息之控制信息。【计算信息参数详解】一、控制信息A区参数详解1、水平力与整体坐标夹角该参数为地震作用、风荷载计算时的X正向与结构整体坐标系下X轴的夹角&#xff0c;逆时针方向为正&#xff0c;单位为度。常见问题&#xf…

10、MyBatis-Plus 多数据源

第一篇&#xff1a;1、Mybatis-Plus 创建SpringBoot项目 第二篇&#xff1a;2、Mybatis-Plus 测试增、删、改、查 第三篇&#xff1a;3、Mybatis-Plus 自定义sql语句 第四篇&#xff1a;4、Mybatis-Plus 通用service的操作 第五篇&#xff1a;5、Mybatis-Plus 常用注解 第六篇&…

004-GoingDeeperConvolutions2014(googLeNet)

Going Deeper with Convolutions #paper1. paper-info 1.1 MetadataAuthor:: [[Christian Szegedy]], [[Wei Liu]], [[Yangqing Jia]], [[Pierre Sermanet]], [[Scott Reed]], [[Dragomir Anguelov]], [[Dumitru Erhan]], [[Vincent Vanhoucke]], [[Andrew Rabinovich]] 作者机…

UNIAPP----video标签层级问题的三种解决方法

uniapp的app端&#xff0c;video标签层级过高&#xff0c;无法轻易被遮盖。 三种解决方法&#xff0c;真机测试没问题。代码复制即可。 1.cover-view或者cover-image&#xff0c;放在video标签内使用&#xff0c;子绝父相 缺点&#xff1a;只能改变cover-view样式&#xff0…

算法落地思考:如何让智能运维更智能

嘉宾 | 王鹏 整理人 | 西狩xs 出品 | CSDN云原生 AIOps是人工智能与运维的结合&#xff0c;能够基于已有的运维数据&#xff0c;利用人工智能算法&#xff0c;通过机器学习的方式帮助企业提升运维效率&#xff0c;解决自动化运维无法管理的问题。 2022年8月30日&#xff0…

spring底层原理初探

一&#xff0c;spring原理初探 1&#xff0c;bean的创建生命周期 userService.class --> 推断构造方法 --> 实例化对象 --> 依赖注入(属性填充) --> 初始化前(PostConstruct) --> 初始化 (Initializingbean) --> 初始化后(AOP&#xff0c;bean的后置处理器…