SAM优化器
论文地址:https://arxiv.org/pdf/2010.01412v2.pdf
项目地址:GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)
意义:增强泛化能力,避免过拟合
作者在实验部分的原话(翻译):
从理论上讲,您可以通过运行更长时间(1800个epoch而不是200个epoch)来获得更低的误差,因为SAM不应该容易过度拟合。SAM 使用 ,而 ASAM 设置为 ,如其作者所建议的那样。rho=0.05rho=2.0
直接运行
具体代码的讲解部分可以看这篇
base_optimizer = torch.optim.SGD # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
从项目给出的示例代码,SAM都是和其他优化器搭配使用的
代码解析
(1)StepLR类
class StepLR:def __init__(self, optimizer, learning_rate: float, total_epochs: int):self.optimizer = optimizerself.total_epochs = total_epochsself.base = learning_ratedef __call__(self, epoch):if epoch < self.total_epochs * 3/10:lr = self.baseelif epoch < self.total_epochs * 6/10:lr = self.base * 0.2elif epoch < self.total_epochs * 8/10:lr = self.base * 0.2 ** 2else:lr = self.base * 0.2 ** 3for param_group in self.optimizer.param_groups:param_group["lr"] = lrdef lr(self) -> float:return self.optimizer.param_groups[0]["lr"]
但是最后没用上
尝试在yolov5中使用
这套代码是我自己摸索出来的可能有点毛病,请见谅多指正,非常感谢。因为各个版本有所不同,我只能京可能覆盖当前比较常用的版本,但是可能还有一些问题,希望能及时指正交流
!!
个人对于模型的了解还是比较粗浅,这篇train代码写得还算比较清晰,多有学习。
(1)修改优化器首先要知道优化器在哪使用:train.py中的train函数
(2)对SAM有基本的了解
(3)上手
第三处修改被我删除了,因此序号不对应
- import
我是将SAM的py文件和utility一起copy到了yolov5的utils下
我已经整好了
链接:https://pan.baidu.com/s/1aMz8RplazAMyBpRfK1FG9w?pwd=jm9n
提取码:jm9n
解压之后放入utils就行
#(1)import文件
from utils.SAM_Optimizer.sam import SAM
from utils.SAM_Optimizer.utility.bypass_bn import enable_running_stats, disable_running_stats
from utils.SAM_Optimizer.utility.step_lr import StepLR
如果选择优化器被封装到smart_optimizer,同样也要在yolov5/utils/torch_utils.py
下copy上上面这段
- 增加"SAM"到优化器中
(1)常规版本
#(2)修改2:增加"SAM"elif opt.optimizer == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g0, base_optimizer, lr=hyp['lr0'], momentum=hyp['momentum'])
(2)没有optimizer这个参数
在某些版本中已经将优化器简化到只剩SGD和adam,甚至没有optimizer这个参数,因此我只能增加optimizer这个参数,转化到和上述一样。
if opt.optimizer == 'Adam':optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum#---------------------(2)修改2:增加"SAM":默认使用SAM优化器----------------------------------------------------------elif opt.optimizer == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g0, base_optimizer, lr=hyp['lr0'], momentum=hyp['momentum'])# --------------------------------------------------------else:optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
(3)封装到smart_optimizer
同样有些版本直接将选择优化器被封装到smart_optimizer,那么同样也要在yolov5/utils/torch_utils.py
下对smart_optimizer进行修改。
直接整个函数给你们
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):# YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decayg = [], [], [] # optimizer parameter groupsbn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()for v in model.modules():if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)g[2].append(v.bias)if isinstance(v, bn): # weight (no decay)g[1].append(v.weight)elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)g[0].append(v.weight)if name == 'Adam':optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentumelif name == 'AdamW':optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)elif name == 'RMSProp':optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)elif name == 'SGD':optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)#-----------修改2--------------------------- #----------------SAM---------------elif name == 'SAM':base_optimizer = torch.optim.SGDoptimizer = SAM(g[2], base_optimizer, lr=lr, momentum=momentum)#----------------SAM---------------
#-------------------修改2 ----------------------------- else:raise NotImplementedError(f'Optimizer {name} not implemented.')optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decayoptimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")return optimizer
- enable
这部分是插入在Multi-scale后面
#--------------修改4:enable-------------------------------------修改4if(opt.optimizer=='SAM'):enable_running_stats(model)#--------------修改4:enable-------------------------------------修改4
- 为了不影响其他优化器的使用,增加if-else
直接搜索if ni - last_opt_step >= accumulate:
然后用下面这段替换一下
if ni - last_opt_step >= accumulate:#--------------修改5:第一阶段结束加上开启第二段-------------------------------------修改5if opt.optimizer=='SAM':optimizer.first_step(zero_grad=True)disable_running_stats(model)with amp.autocast(enabled=cuda):pred = model(imgs) # forwardloss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_sizeif RANK != -1:loss *= WORLD_SIZE # gradient averaged between devices in DDP modeif opt.quad:loss *= 4.# Backwardscaler.scale(loss).backward()optimizer.second_step(zero_grad=True)#--------------修改5:第一阶段结束加上开启第二段-------------------------------------修改5else:scaler.unscale_(optimizer) # unscale gradientstorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradientsscaler.step(optimizer) # optimizer.stepscaler.update()optimizer.zero_grad()if ema:ema.update(model)last_opt_step = ni
- parser增加choice
- 【最核心修改】使用SAM时不使用半精度
#-------------修改7:不使用半精度-------------------------------if(opt.optimizer='SAM'):scaler = torch.cuda.amp.GradScaler(enabled=False)else :scaler = torch.cuda.amp.GradScaler(enabled=amp)
注意一下有些版本import的时候是from torch.cuda import amp
,此时替换成上面这么该应该没问题,也可以改成下面这种
scaler = amp.GradScaler(enabled=False)
如果使用了半精度会有nan输出,暂时还没想到更好的修改办法!