PyTorch中学习率调度器可视化介绍

news/2024/5/4 21:53:25/文章来源:https://blog.csdn.net/m0_46510245/article/details/128262227

神经网络有许多影响模型性能的超参数。一个最基本的超参数是学习率(LR),它决定了在训练步骤之间模型权重的变化程度。在最简单的情况下,LR值是0到1之间的固定值。

选择正确的LR值是具有挑战性。一方面较大的学习率有助于算法快速收敛,但它也会导致算法在最小值附近跳跃而没有达到它,甚至在它太大时跳过它。另一方面,较小的学习率可以更好地收敛到最小值,但是如果优化器太小,可能需要太长时间才能收敛,或者陷入停滞。

什么是学习率调度器?

一种帮助算法快速收敛到最优的解决方案是使用学习率调度器。学习率调度器在训练过程中根据预先定义的时间表调整学习率。

通常,学习率在训练开始时设置为比较高的值,允许更快的收敛。随着训练的进行,学习率会降低,使收敛到最优,获得更好的性能。在训练过程中降低学习率也称为退火或衰减。

学习率调度器有很多个,并且我们还可以自定义调度器。本文将介绍PyTorch中不同的预定义学习率调度器如何在训练期间调整学习率

学习率调度器

对于本文,我们使用PyTorch 1.13.0版本。你可以在PyTorch文档中关于学习率调度器的细节。

 import torch

在本文末尾的附录中会包含用于可视化PyTorch学习率调度器的Python代码。

1、StepLR

在每个预定义的训练步骤数之后,StepLR通过乘法因子降低学习率。

 from torch.optim.lr_scheduler import StepLRscheduler = StepLR(optimizer, step_size = 4, # Period of learning rate decaygamma = 0.5) # Multiplicative factor of learning rate decay

2、MultiStepLR

MultiStepLR -类似于StepLR -也通过乘法因子降低了学习率,但在可以自定义修改学习率的时间节点。

 from torch.optim.lr_scheduler import MultiStepLRscheduler = MultiStepLR(optimizer, milestones=[8, 24, 28], # List of epoch indicesgamma =0.5) # Multiplicative factor of learning rate decay

3、ConstantLR

ConstantLR通过乘法因子降低学习率,直到训练达到预定义步数。

 from torch.optim.lr_scheduler import ConstantLRscheduler = ConstantLR(optimizer, factor = 0.5, # The number we multiply learning rate until the milestone.total_iters = 8) # The number of steps that the scheduler decays the learning rate

如果起始因子小于1,那么学习率调度器在训练过程中会提高学习率,而不是降低学习率。

4、LinearLR

LinearLR -类似于ConstantLR -在训练开始时通过乘法因子降低了学习率。但是它会在一定数量的训练步骤中线性地改变学习率,直到它达到最初设定的学习率。

 from torch.optim.lr_scheduler import LinearLRscheduler = LinearLR(optimizer, start_factor = 0.5, # The number we multiply learning rate in the first epochtotal_iters = 8) # The number of iterations that multiplicative factor reaches to 1

5、ExponentialLR

ExponentialLR在每个训练步骤中通过乘法因子降低学习率。

 rom torch.optim.lr_scheduler import ExponentialLRscheduler = ExponentialLR(optimizer, gamma = 0.5) # Multiplicative factor of learning rate decay.

6、PolynomialLR

PolynomialLR通过对定义的步骤数使用多项式函数来降低学习率。

 from torch.optim.lr_scheduler import PolynomialLRscheduler = PolynomialLR(optimizer, total_iters = 8, # The number of steps that the scheduler decays the learning rate.power = 1) # The power of the polynomial.

下图为power= 1时的学习率衰减结果。

power= 2时,学习率衰减如下所示。

7、CosineAnnealingLR

CosineAnnealingLR通过余弦函数降低学习率。

可以从技术上安排学习率调整以跟随多个周期,但他的思想是在半个周期内衰减学习率以获得最大的迭代次数。

 from torch.optim.lr_scheduler import CosineAnnealingLRscheduler = CosineAnnealingLR(optimizer,T_max = 32, # Maximum number of iterations.eta_min = 1e-4) # Minimum learning rate.

两位Kaggle大赛大师Philipp Singer和Yauhen Babakhin建议使用余弦衰减作为深度迁移学习[2]的学习率调度器。

8、CosineAnnealingWarmRestartsLR

CosineAnnealingWarmRestartsLR类似于CosineAnnealingLR。但是它允许在(例如,每个轮次中)使用初始LR重新启动LR计划。

 from torch.optim.lr_scheduler import CosineAnnealingWarmRestartsscheduler = CosineAnnealingWarmRestarts(optimizer, T_0 = 8,# Number of iterations for the first restartT_mult = 1, # A factor increases TiTi after a restarteta_min = 1e-4) # Minimum learning rate

这个计划调度于2017年[1]推出。虽然增加LR会导致模型发散但是这种有意的分歧使模型能够逃避局部最小值,并找到更好的全局最小值。

9、CyclicLR

CyclicLR根据循环学习率策略调整学习率,该策略基于我们在前一节中讨论过的重启的概念。在PyTorch中有三个内置策略。

 from torch.optim.lr_scheduler import CyclicLRscheduler = CyclicLR(optimizer, base_lr = 0.0001, # Initial learning rate which is the lower boundary in the cycle for each parameter groupmax_lr = 1e-3, # Upper learning rate boundaries in the cycle for each parameter groupstep_size_up = 4, # Number of training iterations in the increasing half of a cyclemode = "triangular")

当mode = " triangle "时,学习率衰减将遵循一个基本的三角形循环,没有振幅缩放,如下图所示。

对于mode = " triangar2 ",所得到的学习率衰减将遵循一个基本的三角形循环,每个循环将初始振幅缩放一半,如下图所示。

使用mode = “exp_range”,得到的学习率衰减将如下所示。

10、OneCycleLR

OneCycleLR根据1cycle学习率策略降低学习率,该策略在2017年[3]的一篇论文中提出。

与许多其他学习率调度器相比,学习率不仅在训练过程中下降。相反,学习率从初始学习率增加到某个最大学习率,然后再次下降。

 from torch.optim.lr_scheduler import OneCycleLRscheduler = OneCycleLR(optimizer, max_lr = 1e-3, # Upper learning rate boundaries in the cycle for each parameter groupsteps_per_epoch = 8, # The number of steps per epoch to train for.epochs = 4, # The number of epochs to train for.anneal_strategy = 'cos') # Specifies the annealing strategy

使用anneal_strategy = "cos"得到的学习率衰减将如下所示。

使用anneal_strategy = “linear”,得到的学习率衰减将如下所示。

11、ReduceLROnPlateauLR

当指标度量停止改进时,ReduceLROnPlateau会降低学习率。这很难可视化,因为学习率降低时间取决于您的模型、数据和超参数。

12、自定义学习率调度器

如果内置的学习率调度器不能满足需求,我们可以使用lambda函数定义一个调度器。lambda函数是一个返回基于epoch值的乘法因子的函数。

LambdaLR通过将lambda函数的乘法因子应用到初始LR来调整学习速率。

 lr_epoch[t] = lr_initial * lambda(epoch)

MultiplicativeLR通过将lambda函数的乘法因子应用到前一个epoch的LR来调整学习速率。

 lr_epoch[t] = lr_epoch[t-1] * lambda(epoch)

这些学习率调度器也有点难以可视化,因为它们高度依赖于已定义的lambda函数。

可视化汇总

以上就是PyTorch内置的学习率调度器,应该为深度学习项目选择哪种学习率调度器呢?

答案并不那么容易,ReduceLROnPlateau是一个流行的学习率调度器。而现在其他的方法如CosineAnnealingLR和OneCycleLR或像cosineannealingwarmrestart和CyclicLR这样的热重启方法已经越来越受欢迎。

所以我们需要运行一些实验来确定哪种学习率调度器最适合要解决问题。但是可以说的是使用任何学习调度器都会影响到模型性能。

下面是PyTorch中讨论过的学习率调度器的可视化总结。

引用和附录

[1] Loshchilov, I., & Hutter, F. (2016). Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983.

[2] Singer, P. & Babakhin, Y. (2022) Practical Tips for Deep Transfer Learning. In: Kaggle Days Paris 2022.

[3] Smith, L. N., & Topin, N. (2019). Super-convergence: Very fast training of neural networks using large learning rates. In Artificial intelligence and machine learning for multi-domain operations applications (Vol. 11006, pp. 369–386). SPIE.

下面是来可视化学习率调度器的代码:

 import torchfrom torch.optim.lr_scheduler import StepLR # Import your choice of scheduler hereimport matplotlib.pyplot as pltfrom matplotlib.ticker import MultipleLocatorLEARNING_RATE = 1e-3EPOCHS = 4STEPS_IN_EPOCH = 8# Set model and optimizermodel = torch.nn.Linear(2, 1)optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)# Define your scheduler here as described above# ...# Get learning rates as each training steplearning_rates = []for i in range(EPOCHS*STEPS_IN_EPOCH):optimizer.step()learning_rates.append(optimizer.param_groups[0]["lr"])scheduler.step()# Visualize learinig rate schedulerfig, ax = plt.subplots(1,1, figsize=(10,5))ax.plot(range(EPOCHS*STEPS_IN_EPOCH), learning_rates,marker='o', color='black')ax.set_xlim([0, EPOCHS*STEPS_IN_EPOCH])ax.set_ylim([0, LEARNING_RATE + 0.0001])ax.set_xlabel('Steps')ax.set_ylabel('Learning Rate')ax.spines['top'].set_visible(False)ax.spines['right'].set_visible(False)ax.xaxis.set_major_locator(MultipleLocator(STEPS_IN_EPOCH))ax.xaxis.set_minor_locator(MultipleLocator(1))plt.show()

https://avoid.overfit.cn/post/9987b70f3d1e47d5b6a0b5ff70d8133f

作者:Leonie Monigatti

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_233453.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[论文阅读] 颜色迁移-梯度保护颜色迁移

[论文阅读] 颜色迁移-梯度保护颜色迁移 文章: [Gradient-Preserving Color Transfer], [代码未公开] 本文目的: 如题所示为梯度保护的颜色迁移方法. 1-算法原理 人类的视觉系统对局部强度差异比强度本身更敏感, 因而, 保持颜色梯度是场景保真度的必要条件, 因而作者认为: 一…

如何批量查询谷歌PR权重是多少?谷歌PR权重怎么批量查询

权重是就是网站在搜索引擎心目中的位置,如果一个网站在搜索引擎心目中的位置高的话,当然容易获得较好的排名,今天不是来跟大家聊如何提升网站权重的,而是教大家如何去看一个网站的权重,做网站的朋友都要知道要做关键词…

数据库面试题1-数据库基本概念、常用SQL语言

题1:什么是数据库 数据库(Database) 是保存有组织的数据的容器(通常是一个文件或一组文件),是通过 数据库管理系统(DataBase- Management System,DBMS) 创建和操纵的容器…

Metal每日分享,波动滤镜/涂鸦滤镜效果

本案例的目的是理解如何用Metal实现图像波动效果滤镜,还可类似涂鸦效果,主要就是对纹理坐标进行正余弦偏移处理; Demo HarbethDemo地址 实操代码 // 波动效果 let filter C7Fluctuate.init(extent: 50, amplitude: 0.003, fluctuate: 2.5…

自动驾驶两大路线对决,渐进式玩家为何更容易得人心?

HiEV消息(文/长海)对自动驾驶赛道而言,2022年的冬天格外冷冽。寒潮袭来,从各家的应变方式看,不同路径的玩家呈现“冰火两重天”,进化的趋势也越来越清晰。 以Waymo为代表、持续研发L4级无人驾驶的跨越式路线…

Python实现房产数据分析与可视化 数据分析 实战

Python库的选择 话说,工欲善其事,必先利其器,虽然我们已经选择Python来完成剩余的工作,但是我们需要考虑具体选择使用Pytho的哪些利器来帮助我们更快更好地完成剩余的工作。 我们可以看一下,在这个任务中&#xff0c…

UIAutomator测试框架介绍

uiautomator简介 UiAutomator是Google提供的用来做安卓自动化测试的一个Java库,基于Accessibility服务。功能很强,可以对第三方App进行测试,获取屏幕上任意一个APP的任意一个控件属性,并对其进行任意操作,但有两个缺点…

【Docker学习教程系列】8-如何将本地的Docker镜像发布到私服?

通过前面的学习,我们已经知道,怎么将本地自己制作的镜像发布到阿里云远程镜像仓库中去。但是在实际工作开发中,一般,我们都是将公司的镜像发布到公司自己搭建的私服镜像仓库中,那么一个私服的镜像仓库怎么搭建&#xf…

【云原生】Kubernetes(k8s)Istio Gateway 介绍与实战操作

文章目录一、概述二、Istio 架构三、通过 istioctl 部署 Istio1)安装istioctl 工具2)通过istioctl安装istio3)检查四、Istio Gateway五、Istio VirtualService 虚拟服务六、示例演示(bookinfo)1)安装bookin…

笔试强训(四十一)

目录一、选择题二、编程题2.1 Emacs计算器2.1.1 题目2.1.1 题解一、选择题 (1)某主机的IP地址为180.80.77.55,子网掩码为255.255.252.0.若该主机向其所在子网发送广播分组,则目的地址可以是(D) A.180.80.7…

制作移动端整页滚动动画

制作移动端整页滚动动画 需要用到 rem7.5.js(rem适配) pageSlider.js(控制动画的js文件) 基于zepto&#xff0c;引入zepto.js文件 animate.css(动画样式) base.css(公共样式) 下面看一下页面结构 <div class"section sec1"style"background-image:url(./ima…

ASP.NET微信快速开发框架源码【源码分享】

ASP.NET微信快速开发框架源码 微信公众平台快速开发框架源码 需要源码学习&#xff0c;查看文末卡片获取&#xff0c;或私信我。 框架主要技术&#xff1a; ASP.NET MVC5、ASP.NET Identity、Bootstrap、KnockoutJs、Entity Framework等。 主要特色&#xff1a; 1、快速迭代开…

​创新不是公司的救命良药

阅读本文大概需要1.06 分钟。之前问说当整个大环境都差的时候&#xff0c;公司还有项目可做就不错了&#xff0c;不要觉得只能赚点小钱就看不上&#xff0c;现在已经从伸手抓钱&#xff0c;变成弯腰捡钱的时代了。 开始赚的钱是不多&#xff0c;但能验证方向&#xff0c;先把跑…

【Keras+计算机视觉+Tensorflow】生成对抗神经网络中DCGAN、CycleGAN网络的讲解(图文解释 超详细)

觉得有帮助麻烦点赞关注收藏~~~ 一、生成对抗网络简介 生成对抗网络(GANs&#xff0c;Generative Adversarial Nets),由Ian Goodfellow在2014年提出的,是当今计算机科学中最有趣的概念之一。GAN最早提出是为了弥补真实数据的不足&#xff0c;生成高质量的人工数据。GAN的主要思…

Java项目中集成Redis提升系统的性能

概述 安装Redis 安装 启动Rocky Linux 9.0&#xff0c;在浏览器中打开web console. 如果没有安装Web console&#xff0c;按以下步骤安装启用&#xff1a; 安装命令&#xff1a; # dnf install cockpit 启用并运行服务 # systemctl enable --now cockpit.socket 开通防火墙&…

【每日小技巧】如果Tomcat的端口被占用,怎么处理该报错

苦恼的问题&#xff1a;当我们在用Tomcat时&#xff0c;发现我们要用的端口被其他程序占用了&#xff0c;如图&#xff1a; 解决办法&#xff1a; ①winR&#xff0c;输入cmd&#xff0c;打开命令行 输入命令netstat -ano&#xff0c;列出所有的端口号使用情况 ②查看PID&#…

Linux命令_ps 进程管理

简介 ps通过读取 /proc 中的虚拟文件来工作&#xff0c;不需要 setuid kmem 或有任何特权来运行。 CPU使用率目前表示为进程整个生命周期中运行所花费时间的百分比。这是不理想的&#xff0c;它不符合ps在其他方面所符合的标准。CPU使用率加起来不太可能达到100%。 SIZE和RSS字…

E. DS哈希查找--Trie树

目录 题目描述 思路分析 AC代码 题目描述 Trie树又称单词查找树&#xff0c;是一种树形结构&#xff0c;如下图所示。 它是一种哈希树的变种。典型应用是用于统计&#xff0c;排序和保存大量的字符串&#xff08;但不仅限于字符串&#xff09;&#xff0c;所以经常被搜索引擎…

HTML列表与表格详解_高效学习攻略

HTML列表与表格HTML篇_第六章、HTML列表与表格一、列表1.1定义1.2列表的分类1.3列表的对比二、表格2.1表格的定义2.2表格的边框2.3表格的表头单元格2.4表格标题 <caption>2.5表格的高度和宽度2.6表格背景2.7表格空间2.8合并单元格2.9表格头部、主题和页脚2.10表格的嵌套H…

【C++常用容器】STL基础语法学习queue容器

目录 ●queue的基本概念 ●queue常用接口 ●构造函数 ●赋值操作 ●数据存取 ●大小操作 ●queue的基本概念 简要介绍&#xff1a;queue是一种先进先出的的数据结构&#xff0c;它有两个出口。队列容器允许从一端新增元素&#xff0c;从另一端移除元素。队列中只有队…