Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

news/2024/5/18 21:45:40/文章来源:https://blog.csdn.net/xian0710830114/article/details/126551268

目录

写在前面

一、 torch.optim.SGD 随机梯度下降

SGD代码

SGD算法解析

1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法

 2.Momentum动量

3.NAG(Nesterov accelerated gradient)

SGD总结

二、torch.optim.ASGD随机平均梯度下降

三、torch.optim.Rprop

四、torch.optim.Adagrad 自适应梯度

Adagrad 代码

Adagrad 算法解析

AdaGrad总结


写在前面

        优化器时深度学习中的重要组件,在深度学习中有举足轻重的地位。在实际开发中我们并不用亲手实现一个优化器,很多框架都帮我们实现好了,但如果不明白各个优化器的特点,就很难选择适合自己任务的优化器。接下来我会开一个系列,以Pytorch为例,介绍所有主流的优化器,如果都搞明白了,对优化器算法的掌握也就差不多了。

        作为系列的第一篇文章,本文介绍Pytorch中的SGD、ASGD、Rprop、Adagrad,其中主要介绍SGD和Adagrad。因为这四个优化器出现的比较早,都存在一些硬伤,而作为现在主流优化器的基础又跳不过,所以作为开端吧。

        我们定义一个通用的思路框架,方便在后面理解各算法之间的关系和改进。首先定义待优化参数 \theta,目标函数J(\theta ),学习率为 \eta ,然后我们进行迭代优化,假设当前的epoch为t,参数更新步骤如下:

1. 计算目标函数关于当前参数的梯度: 

g_{t}=\bigtriangledown J(\theta _{t})                               (1)

 2. 根据历史梯度计算一阶动量和二阶动量:

m_{t}=\phi (g_{1},g_{2}...,g_{t})                (2)

v_{t}=\varphi (g_{1},g_{2}...,g_{t})                 (3)

 3. 计算当前时刻的下降梯度: 

\bigtriangleup _{\theta _{t}}=\eta *\frac{m_{t}}{\sqrt{v_{t}}}                           (4)

4. 根据下降梯度进行更新:  

 \theta _{t+1}=\theta _{t}-\bigtriangleup _{\theta _{t}}                       (5)

        下面介绍的所有优化算法基本都能套用这个流程,只是式子(4)的形式会有变化。

一、 torch.optim.SGD 随机梯度下降

        该类可实现 SGD 优化算法,带动量 的SGD 优化算法和带 NAG(Nesterov accelerated gradient)的 SGD 优化算法,并且均可拥有 weight_decay(权重衰减) 项。

SGD代码

'''
params(iterable)- 参数组,优化器要优化的那部分参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
momentum(float)- 动量,通常设置为 0.9,0.8
dampening(float)- dampening for momentum ,暂时不了其功能,在源码中是这样用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening 必须为 0.
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数
nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)
'''
class torch.optim.SGD(params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

SGD算法解析

1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法

        明明类名是SGD,为什么介绍MBGD呢,因为在Pytorch中,torch.optim.SGD其实是实现的MBGD,要想使用SGD,只要将batch_size设成1就行了。

        MBGD就是结合BGD和SGD的折中,对于含有 n个训练样本的数据集,每次参数更新,选择一个大小为 m(m<n) 的mini-batch数据样本计算其梯度,其参数更新公式如下,其中j是一个batch的开始:

\theta _{t+1}=\theta _{t}-\eta *\frac{1}{m}*\sum_{i=j}^{i=j+m-1}\bigtriangledown _{\theta _{i}}J_{i}(\theta _{t})                (6)

优点:使用mini-batch的时候,可以收敛得很快,有一定摆脱局部最优的能力。

缺点:a.在随机选择梯度的同时会引入噪声,使得权值更新的方向不一定正确

           b.不能解决局部最优解的问题

 2.Momentum动量

         动量是一种有助于在相关方向上加速SGD并抑制振荡的方法,通过将当前梯度与过去梯度加权平均,来获取即将更新的梯度。如下图b图所示。它通过将过去时间步长的更新向量的一小部分添加到当前更新向量来实现这一点:

image-20211126212003953

 动量项通常设置为0.9或类似值。

参数更新公式如下,其中ρ 是动量衰减率,m是速率(即一阶动量)

g_{t}=\bigtriangledown_\theta J(\theta _{t})                             (7)

m_{t} = \rho *m_{t-1} +g_{t}                (8)

\theta _{t+1}=\theta _{t}-\eta *m_{t}                  (9)

3.NAG(Nesterov accelerated gradient)

        NAG的思想是在动量法的基础上展开的。动量法是思想是,将当前梯度与过去梯度加权平均,来获取即将更新的梯度。在知道梯度之后,更新自变量到新的位置。也就是说我们其实在每一步,是知道下一时刻位置的。这时Nesterov就说了:那既然这样的话,我们何不直接采用下一时刻的梯度来和上一时刻梯度进行加权平均呢?下面两张图看明白,就理解NAG了:

        这里写图片描述

 无

NAG和经典动量法的差别就在B点和C点梯度的不同。 

 参数更新公式:

g_{t}=\bigtriangledown_\theta J(\theta _{t}-\rho m_{t-1})                (10)

m_{t} = \rho *m_{t-1} +g_{t}                        (11)

\theta _{t+1}=\theta _{t}-\eta *m_{t}                           (12)

        上式中的-\rho m_{t-1}就是图中的B到C那一段向量,\theta _{t}-\rho m_{t-1}就是C点坐标(参数)。可以看到NAG除了式子(10)与式子(7)有所不同,其余公式和Momentum是一样的。

        一般情况下NAG方法相比Momentum收敛速度快、波动也小。实际上NAG方法用到了二阶信息,所以才会有这么好的结果。

         Nesterov动量梯度的计算在模型参数施加当前速度之后,因此可以理解为往标准动量中添加了一个校正因子。在凸批量梯度的情况下,Nesterov动量将额外误差收敛率从O(\frac{1}{k})(k步后)改进到  O(\frac{1}{k^2}),然而,在随机梯度情况下,Nesterov动量对收敛率的作用却不是很大。

SGD总结

使用了Momentum或NAG的MBGD有如下特点:

优点:加快收敛速度,有一定摆脱局部最优的能力,一定程度上缓解了没有动量的时候的问题

缺点:a.仍然继承了一部分SGD的缺点

          b.在随机梯度情况下,NAG对收敛率的作用不是很大

          c.Momentum和NAG都是为了使梯度更新更灵活。但是人工设计的学习率总是有些生硬,下面介绍几种自适应学习率的方法。

推荐程度:带Momentum的torch.optim.SGD 可以一试。

二、torch.optim.ASGD随机平均梯度下降

        ASGD 也称为 SAG,表示随机平均梯度下降(Averaged Stochastic Gradient Descent),简单地说 ASGD 就是用空间换时间的一种 SGD,因为很少使用,所以不详细介绍,详情可参看论文: http://riejohnson.com/rie/stograd_nips.pdf

'''
params(iterable)- 参数组,优化器要优化的那些参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
lambd(float)- 衰减项,默认值 1e-4。
alpha(float)- power for eta update ,默认值 0.75。
t0(float)- point at which to start averaging,默认值 1e6。
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数。
'''
class torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)

 推荐程度:不常见

三、torch.optim.Rprop

        该类实现 Rprop 优化方法(弹性反向传播),适用于 full-batch,不适用于 mini-batch,因而在 mini-batch 大行其道的时代里,很少见到。

'''
params - 参数组,优化器要优化的那些参数。
lr - 学习率
etas (Tuple[float, float])- 乘法增减因子
step_sizes (Tuple[float, float]) - 允许的最小和最大步长
'''
class torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))

优点:它可以自动调节学习率,不需要人为调节

缺点:仍依赖于人工设置一个全局学习率,随着迭代次数增多,学习率会越来越小,最终会趋近于0

推荐程度:不推荐

四、torch.optim.Adagrad 自适应梯度

        该类可实现 Adagrad 优化方法(Adaptive Gradient),Adagrad 是一种自适应优化方法,是自适应的为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。

Adagrad 代码

'''
params (iterable) – 待优化参数的iterable或者是定义了参数组的dict
lr (float, 可选) – 学习率(默认: 1e-2)
lr_decay (float, 可选) – 学习率衰减(默认: 0)
weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)
initial_accumulator_value - 累加器的起始值,必须为正。'''
class torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)

Adagrad 算法解析

        AdaGrad对学习率进行了一个约束,对于经常更新的参数,我们已经积累了大量关于它的知识,不希望被单个样本影响太大,希望学习速率慢一些;对于偶尔更新的参数,我们了解的信息太少,希望能从每个偶然出现的样本身上多学一些,即学习速率大一些。这样大大提高梯度下降的鲁棒性而该方法中开始使用二阶动量,才意味着“自适应学习率”优化算法时代的到来。
        在SGD中,我们每次迭代对所有参数进行更新,因为每个参数使用相同的学习率。而AdaGrad在每个时间步长对每个参数使用不同的学习率。AdaGrad消除了手动调整学习率的需要。AdaGrad在迭代过程中不断调整学习率,并让目标函数中的每个参数都分别拥有自己的学习率。大多数实现使用学习率默认值为0.01,开始设置一个较大的学习率。

        AdaGrad引入了二阶动量。二阶动量是迄今为止所有梯度值的平方和,即v_{t}=\sum_{i=1}^{t}g_{t}^{2}它是用来度量历史更新频率的。也就是说,我们的学习率现在是\frac{\eta }{\sqrt{v_{t}+\epsilon }},从这里我们就会发现 \sqrt{v_{t}+\epsilon }是恒大于0的,而且参数更新越频繁,二阶动量越大,学习率就越小,这一方法在稀疏数据场景下表现非常好,参数更新公式如下: 

        v_{t}=\sum_{i=1}^{t}g_{t}^{2}                                                    (13)

        \theta _{t-1}=\theta _{t}-\eta *\frac{g_{t}}{\sqrt{v_{t}+\epsilon }}                        (14)

AdaGrad总结

        AdaGrad在每个时间步长对每个参数使用不同的学习率。并且引入了二阶动量,二阶动量是迄今为止所有梯度值的平方和。

优点:AdaGrad消除了手动调整学习率的需要。AdaGrad在迭代过程中不断调整学习率,并让目标函数中的每个参数都分别拥有自己的学习率。

缺点:a.仍需要手工设置一个全局学习率  , 如果  设置过大的话,会使regularizer过于敏感,对梯度的调节太大

        b.在分母中累积平方梯度,由于每个添加项都是正数,因此在训练过程中累积和不断增长。这导致学习率不断变小并最终变得无限小,此时算法不再能够获得额外的知识即导致模型不会再次学习。

 推荐程度:不推荐

接下来adam相关的将是重点,敬请期待。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_5744.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【手把手】ios苹果打包——遇见项目实战|超详细的教程分享

六年代码两茫茫&#xff0c;不思量&#xff0c;自难忘 6年资深前端主管一枚&#xff0c;只分享技术干货&#xff0c;项目实战经验 关注博主不迷路~ 文章目录前言weex介绍eeui介绍一、安装CocoaPods1.CocoaPods介绍2.CocoaPods的安装二、登录开发者中心四、添加测试手机设备五、…

2022最新iOS证书(.p12)、描述文件(.mobileprovision)申请和HBuider打包及注意注意事项

制作p12证书1、在钥匙串界面中,选中安装好的开发者证书,【右键】选择导出在弹出的界面中3、在接下来的弹窗中填写p12文件的安装密码(后面他人安装该p12文件时需要输入这个密码,重要)4、继续上面的步骤,这里需要输入电脑的开机密码,p12开发者证书到这里即制作完成。以上就…

【芯片前端】根据数据有效选择输出的握手型FIFO结构探究

前言 之前要做一个一读多写的fifo&#xff0c;也就是master写入数据到fifo中&#xff0c;多个slave读取数据&#xff0c;结构如下图所示&#xff1a; 由于slave需要的数据一致&#xff0c;fifo内只需要例化一个ram以节约空间。这个fifo的具体结构下次博客中再来讨论。在这个fi…

Git 之 revert

转自: Git 之 revertrevert 可以撤销指定的提交内容,撤销后会生成一个新的commit。 1、两种commit: 当讨论 revert 时,需要分两种情况,因为 commit 分为两种:一种是常规的 commit,也就是使用 git commit 提交的 commit; 另一种是 merge commit,在使用 git merge 合并两…

mysql 主从备份原理

mysql 主从备份原理 1.1 用途及条件 mysql主从复制用途实时灾备,用于故障切换 读写分离,提供查询服务 备份,避免影响业务主从部署必要条件:主库开启binlog日志(设置log-bin参数) 主从server-id不同 从库服务器能连通主库2.1 主从原理在备库 B 上通过 change master 命令,…

服务端挂了,客户端的 TCP 连接还在吗?

作者:小林coding 计算机八股文网站:https://xiaolincoding.com大家好,我是小林。 如果「服务端挂掉」指的是「服务端进程崩溃」,服务端的进程在发生崩溃的时候,内核会发送 FIN 报文,与客户端进行四次挥手。 但是,如果「服务端挂掉」指的是「服务端主机宕机」,那么是不会…

[第二章 web进阶]XSS闯关-1

定义:跨站脚本(Cross_Site Scripting,简称为XSS或跨站脚本或跨站脚本攻击)是一种针对网站应用程序的安全漏洞攻击技术,是代码注入的一种。它允许恶意用户将代码注入网页,其他用户浏览网页时就会受到影响。恶意用户利用XSS代码攻击成功后,可能得到包括但不限于更高的权限、会…

K8s简介之什么是K8s

1.概述 欢迎来到K8s入门课程。Kubernetes,也被称为K8s或Kube,是谷歌推出的业界最受欢迎的容器编排器。本K8s教程由一系列关于K8s的文章组成。在第一部分,我们将讨论什么是K8s和K8s的基本概念。 本课程是专为初学者开设的,你可以零基础学习这项技术。我们将带你了解全部K8s的…

第2章 第一个Spring Boot项目

开发工具选择 工欲善其事必先利其器&#xff0c;我们进行Java项目开发&#xff0c;选择一个好的集成开发工具&#xff08;IDE&#xff09;对提高我们的开发调试效率有非常大的帮助。这里我们选择大名鼎鼎的IDEA &#xff0c;它全称 IntelliJ IDEA。 ​IntelliJ IDEA公认最好的J…

【云原生 | Kubernetes 系列】K8s 实战 如何给应用注入数据 II 将pod数据传递给容器

将pod数据传递给容器前言一、通过环境变量将 Pod 信息传递给容器1.1、用 Container 字段作为环境变量的值二、通过文件将 Pod 信息呈现给容器2.1、存储容器字段总结前言 在上一篇文章中&#xff0c;我们学习了针对容器设置启动时要执行的命令和参数、定义相互依赖的环境变量、为…

关于订单过期的监听和处理

订单过期监听和处理 业务需求 有些时候 用户发起订单 但是没有付款 这个时候一般来说 会设置一个订单过期时间 如果订单过期 则需要重新下单 问题来了 如果每过一段很小的时间就去盘一次数据库 那压力也太大了 demo 搭建 用到的 mysql mybatis plus redis rabbit mq 目录结…

【毕业设计】单片机远程wifi红外无接触体温测量系统 - 物联网 stm32

文章目录0 前言1 简介2 主要器件3 实现效果4 设计原理4.1 **MLX90614红外测温传感器**4.2 TOF10120激光测距传感器4.3 DS18B20传感器**DS18B20单总线协议**5 部分核心代码5 最后0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设…

精妙绝伦

精妙绝伦啊,精妙绝伦啊,大妙! 今天讨论到一个二级联动省和市在一个表中的情况, 这么一组数据,需要达成一个sql语句便能把省和市同时显示出来,愚绞尽脑汁思虑良久,未得有用之策,经同事提点,顿醍醐灌顶! 先来解释一下这串代码:Select * from TBSpace inner join TBPla…

three.js绘制地图(平面、曲面)

加载中国地图json数据 let loader = new THREE.FileLoader(); loader.load(model/chinaJson.json, function (data) {let jsonData = JSON.parse(data);initMap(jsonData); // 解析并绘制地图 });绘制曲面地图function initMap( chinaJson ) {//创建一个空对象存放对象map = ne…

Vue指令

Vue指令分为内置指令和自定义指令 内置指令 v-bind 单向绑定解析表达式&#xff0c; 简写&#xff1a; &#xff1a;xxx <div id"root">单项数据绑定&#xff1a;<input type"text" v-bind:value"name"><br></div> v…

2023秋招——快手数据研发一、二面面经

&#x1f33c;今天来总结一下快手数据研发的一、二面&#xff0c;在面试中进步&#xff0c;在总结中成长&#xff01;对往期内容感兴趣的小伙伴可以参考下面&#x1f447;&#xff1a; 链接: 2022暑期实习字节跳动数据研发面试经历-基础数仓.链接: 2022百度大数据开发工程师实…

three.js实现鼠标拾取例子

基本思路 <script> var renderer,scene,camera; var light; var raycaster,//相机->鼠标的射线mouse,//鼠标所在位置actionObject;//选中的物体 init(); animation();function init(){//渲染器//场景//相机//方向光//创建2000个立方体//创建射线//创建鼠标二维向量(圆…

epoll实现异步请求数据---以UDP为例

文章目录同步UDP请求数据的问题异步请求的模型具体的代码同步UDP请求数据的问题 不管是请求DNS资源还是其他资源。如果以串行的方式请求数据&#xff0c;也就是send以后recv阻塞等待获取数据&#xff0c;这样做的效率非常低效&#xff0c;网络延迟、服务器处理请求、再加上rec…

【C# 学习笔记 ②】C#基本语法(数组、判断和循环、字符串、枚举、结构体)

由于在自己的工作和学习过程中&#xff0c;只查看某个大佬的教程或文章无法满足自己的学习需求和解决遇到的问题&#xff0c;所以自己在追赶大佬们步伐的基础上&#xff0c;又自己总结、整理、汇总了一些资料&#xff0c;方便自己理解和后续回顾&#xff0c;同时也希望给大家带…

【我不熟悉的css】07. css命名,bem规范,跟着组件库element-ui学习组件命名

在去年&#xff0c;我总结了一篇文章&#xff0c;跟着element-ui学习css命名 【系统学习css】跟着element-ui学习css的命名_我有一棵树的博客-CSDN博客每日鸡汤&#xff0c;每一个你想要学习的念头都是未来的你向自己求救写css 最烦人的就是给class起名字了&#xff0c;这里不…