Lecture3 梯度下降(Gradient Descent)

news/2024/4/20 11:23:03/文章来源:https://blog.csdn.net/m0_56494923/article/details/128910107

目录

1 问题背景

2 批量梯度下降 (Batch Gradient Descent)

3 鞍点(Saddle Point)

3 随机梯度下降 (Stochastic Gradient Descent)

4 小批量梯度下降 (Mini-batch Gradient Descent)


1 问题背景

图1 上节课讲述的穷举法求最优权重值

  在Lecture2中,介绍了使用穷举法来确定最优\omega值,然而当遇到\omega范围较大,或者数量过多等情况时,穷举法的时间复杂度过大。因此,我们需要优化该算法。

2 批量梯度下降 (Batch Gradient Descent)

  在这次课中,介绍了一种寻找\omega最优值的算法——批量梯度下降 (Batch Gradient Descent, BGD)

简单介绍下该算法。首先对于下图:

图2 训练过程中权重初始值与最优值的位置

  假设我们目前的起始\omega位于上图红色点,为了找到最优\omega点(位于绿点),那么我们需要向左边移动,这样才能到达最优\omega点。

图3 我们需要计算梯度以向左移动权值点

 

   如何让权值点向左还是向右移动呢?此时我们需要计算当前点的梯度(Gradient),也就是用成本函数对权重进行求导,如果梯度<0,则向函数值递减方向移动;梯度>0,则向函数值递增方向移动。

  因为要移动起来,所以我们每移动一步,就要更新一下\omega值。更新函数如下图Update处:

图4 更新权重值的函数公式

  在这个更新函数中, α代表学习率(Learning Rate),学习率是机器学习中常用的一个超参数,它定义了每次更新参数时步长的大小,即每次更新参数时参数值变化的幅度。如果学习率设置得过大,所求结果可能会在最优解的附近来回震荡,而无法找到全局最优解。如果学习率设置得过小,那么模型的训练将会非常缓慢,甚至找不到最优解。

  这个式子中,梯度前面用了减号,是为了朝函数值递减方向,也就是往最优\omega所在的点移动,所以在梯度前面加负号。 就这样持续一步步地更新\omega,直到找到最优\omega

下面我们来具体讲讲如何去计算更新函数中的\frac{\partial cost}{\partial \omega} :

图5 更新函数

计算过程中,需要用到上节课总结的两个公式:

图6 均方误差MSE

图7 预测值y_hat

  接着把上述两个公式代入原式:

图8 求导过程

   蓝色处,因为cost=MSE,所以直接代入上一节课的MSE公式,然后对\omega求导。

  绿色处,由有理运算法则,和的导数等于导数的和,所以这里可以把\frac{\partial}{\partial \omega}移入求和式子中,对里面先进行求导后,再求和相加。

  黄色处,根据复合导数的链式求导法进行求导。

代码实现

from matplotlib import pyplot as pltx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0  # 初始权重,由这个权重开始迭代'''线性模型,算出预测值y_hat'''
def forward(x):return x * w'''均方误差MSE'''
def cost(xs, ys):cost = 0for x, y in zip(xs, ys):y_pred = forward(x)  # 算出y_hatcost += (y_pred - y) ** 2  # (y_hat - y)²return cost / len(xs)  # 除以样本总数求均值'''梯度下降公式'''
def gradient(xs, ys):grad = 0for x, y in zip(xs, ys):grad += 2 * x * (x * w - y)return grad / len(xs)print('Predict (before training)', 4, forward(4))  # 训练前,模型对输入的4的最终预测结果cost_list = [] # 保存每轮迭代后的cost值
epoch_list = [] # 保存每轮的迭代后的epoch值
for epoch in range(100):  # 进行100轮训练cost_val = cost(x_data, y_data)grad_val = gradient(x_data, y_data)w -= 0.01 * grad_val # 使用梯度下降法更新权重,0.01表示学习率print('Epoch:', epoch, 'w=%.2f' % w, 'loss=%.2f' % cost_val)cost_list.append(cost_val)epoch_list.append(epoch)
print('Predict (after training)', 4, forward(4))  # 训练后,模型对输入的4的最终预测结果'''绘图'''
plt.plot(epoch_list, cost_list)
plt.ylabel('Cost')
plt.xlabel('Epoch')
plt.grid()
plt.show()
图9 输出结果图像

  将MSE公式和Linear Model公式代入整合,的最终更新函数:

图10 最终的更新函数

补充

训练后的结果一般来说,cost会趋于收敛情况

图11 通常训练后cost图像会趋于收敛

 如果发生如下情况,说明训练失败,原因有很多,其中之一可能是学习率取得太大:

图12 训练失败

 

  这就是批量梯度下降算法,本质上是一个贪心算法(Greedy Algorithm)。不过该算法有局限性,比如当前的预测值\omega正好位于下图绿线处,因为再往右移动会梯度会发生变化,使得程序直接终止,于是误将红的点作为最优\omega值,而忽略了处于蓝色点的最优\omega值:

图13 局部最优和全局最优示意图

   我们把上图中的红点称为局部最优点(Local Optimum),蓝色点称为全局最优点(Global Optimum)。因此对于该梯度下降算法,很可能会找到局部最优点,而忽略了全局最优点。不过这种现象不必担心,因为在实际训练中,往往很难陷入局部最优点。

3 鞍点(Saddle Point)

  在实际训练中,往往很难陷入局部最优点,而最需要解决的问题是鞍点(Saddle Point),鞍点是机器学习和数学中的一个概念,它指的是一个特殊的局部极小值,在某些方向上是极小值,但在其他方向上是极大值。在一元函数中,梯度=0的点就是鞍点。比如下图中,红色小球所处的位置就在鞍点,此时梯度为零,会导致更新函数无法更新(因为梯度=0,\omega=\omega-α*0相当于没有发生更新):

图14 鞍点示意图

 

 从多维角度来分析,比如下图红球处于马鞍面(Saddle Surface),从一个切面看可以处于最小值,从另一个切面看又处于最大值:

图15 位于马鞍面的鞍点

 

  在优化问题中,鞍点是一种特殊的局部最优解,是一个难以优化的点,因为优化算法可能很难从鞍点附近找到全局最优解。这是因为,如果优化算法在鞍点附近搜索,它可能会被误导到其他附近的局部最优解,而不是真正的全局最优解。所以在深度学习中,需要克服的最大问题就是鞍点而非局部最优问题。

3 随机梯度下降 (Stochastic Gradient Descent)

  随机梯度下降 (Stochastic Gradient Descent, SGD)在深度学习中很常用,和BGD算法的区别是,BGD使用所有的样本的均值的平均损失来作为\omega的更新依据,而SGD是从所有样本中随机选择单个样本的损失值来对\omega进行更新。

  随机梯度下降的优点是,每次仅使用一个数据点的梯度,因此在每次迭代时都有可能沿着非0梯度的方向更新参数,这样就避免陷入到鞍点导致无法更新参数。

图16 BGD到SGD公式上的改变

代码实现

import randomx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0def forward(x):return x * wdef loss(x, y):y_pred = forward(x)return (y_pred - y) ** 2def gradient(x, y):return 2 * x * (x * w - y)print('Predict (before training)', 4, forward(4))
for epoch in range(100):t = random.randrange(0, 3) # 随机得到一个样本x = x_data[t]y = y_data[t]grad = gradient(x, y)w = w - 0.01 * gradprint("\tgrad: ", x, y, '%.2f' % grad)l = loss(x, y)print("progress:", epoch, "w=%.2f" % w, "loss=%.2f" % l)
print('Predict (after training)', 4, forward(4))

部分输出结果

Predict (before training) 4 4.0
    grad:  3.0 6.0 -18.00
progress: 0 w=1.18 loss=6.05
    grad:  2.0 4.0 -6.56
progress: 1 w=1.25 loss=2.28
    grad:  3.0 6.0 -13.58
progress: 2 w=1.38 loss=3.44
    grad:  1.0 2.0 -1.24
progress: 3 w=1.39 loss=0.37
    grad:  2.0 4.0 -4.85
progress: 4 w=1.44 loss=1.24

···

    grad:  1.0 2.0 -0.00
progress: 97 w=2.00 loss=0.00
    grad:  1.0 2.0 -0.00
progress: 98 w=2.00 loss=0.00
    grad:  2.0 4.0 -0.00
progress: 99 w=2.00 loss=0.00
Predict (after training) 4 7.999910864525451

4 小批量梯度下降 (Mini-batch Gradient Descent)

  SGD算法虽然可以在一定程度上避免陷入局部最优以及鞍点问题,但是运算所需时间复杂度过高,每次仅使用一个数据点的梯度,因此它的收敛速度通常比较慢。

  因此有一个折中的办法,就是使用小批量梯度下降 (Mini-batch Gradient Descent) 算法。简单来说,小批量梯度下降是一种介于批量梯度下降和随机梯度下降之间的优化算法。结合了这两种方法,通过使用小的随机选择的训练数据子集(称为mini-batch)计算损失函数关于参数的梯度的平均值来更新模型参数。

  总之,小批量梯度下降算法实现了BGD的高计算效率和SGD的良好收敛性之间的平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_255655.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

重磅!微软推出首款 ChatGPT 版搜索引擎!

微软近期推出了首款 ChatGPT 版搜索引擎&#xff0c;今天带大家一起来看一下。 一夜之间&#xff0c;全球最大的科技公司仿佛都回到了自己年轻时的样子。 在谷歌宣布「实验性对话式人工智能服务」Bard 之后仅 24 小时&#xff0c;北京时间 2 月 8 日凌晨两点&#xff0c;微软发…

2023 软件测试行业内卷动荡,红利期过去后,何去何从?

前段时间席卷全互联网行业的内卷现象&#xff0c;想必有不少人都深陷其中。其实刚开始测试行业人才往往供不应求&#xff0c;而在发展了十几年后&#xff0c;很多人涌入这个行业开始面对存量竞争。红利期过去了&#xff0c;只剩内部争夺。 即便如此&#xff0c;测试行业仍有许…

微服务 分片 运维管理

微服务 分片 运维管理分片分片的概念分片案例环境搭建案例改造成任务分片Dataflow类型调度代码示例运维管理事件追踪运维平台搭建步骤使用步骤分片 分片的概念 当只有一台机器的情况下&#xff0c;给定时任务分片四个&#xff0c;在机器A启动四个线程&#xff0c;分别处理四个…

Python编程自动化办公案例(1)

作者简介&#xff1a;一名在校计算机学生、每天分享Python的学习经验、和学习笔记。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​​ 目录 前言 一.使用库讲解 1.xlrd 2.xlwt 二.主要案例 1.批量合并 模板如下&#xf…

Monkey

文章目录一、简介二、原理2.1 特殊处理三、命令3.1 启动3.2 关闭四、事件4.1 触摸事件4.2 手势事件4.3 二指缩放事件4.4 轨迹事件4.5 屏幕旋转事件4.6 基本导航事件4.7 主要导航事件4.8 系统按键事件4.9 启动activity事件4.10 键盘事件4.11 其他类型事件五、参数5.1 常规类参数…

go语言实现的一个基于go-zero框架的微服务影院票务系统cinema-ticket

一个基于go-zero框架的微服务影院票务系统cinema-ticket 前言 项目基本介绍 项目开源地址&#xff1a;butane123/cinema-ticket: 一个基于go-zero框架的微服务影院票务系统cinema-ticket (github.com) 这是一个微服务影院票务系统&#xff0c;基于go-zero框架实现&#xff0c…

【Java进阶打卡】JDBC- jdbc连接池

【Java进阶打卡】JDBC- jdbc连接池概述自定义数据库连接池归还连接-装饰设计模式归还连接-适配器设计模式动态代理动态代理-归还数据库连接概述 自定义数据库连接池 DataSource接口概述 javax.sql.DataSource接口&#xff1a;数据源&#xff08;数据库连接池&#xff09; Java…

LaoCat带你认识容器与镜像(实践篇二上)

实践篇主要以各容器的挂载和附加命令为主。 本章内容 本文实操全部基于Ubuntu 20.04 宿主机 > linux服务器本身 Docker > 20.10.22 在开始本章内容之前&#xff0c;我解答一个问题&#xff0c;有小伙伴问我说&#xff0c;有的容器DockerHub官网并没有提供任何可参考的文…

软件测试标准流程

软件测试的基本流程大概要经历四个阶段&#xff0c;分别是制定测试计划、测试需求分析、测试用例设计与编写以及测试用例评审。因此软件测试的工作内容&#xff0c;远远没有许多人想象的只是找出bug那么简单。准确的说&#xff0c;从一个项目立项以后&#xff0c;软件测试从业者…

【项目精选】基于Java的敬老院管理系统的设计和实现

本系统主要是针对敬老院工作人员即管理员和员工设计的。敬老院管理系统 将IT技术为养老院提供一个接口便于管理信息,存储老人个人信息和其他信息,查找 和更新信息的养老院档案,节省了员工的劳动时间,大大降低了成本。 其主要功能包括&#xff1a; 系统管理员用户功能介绍&#…

面临激烈竞争的汽车之家仍有新的增长机会

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 竞争激烈是汽车之家面临的主要问题 在汽车之家&#xff08;ATHM&#xff09;2021财年的20-F文件中&#xff0c;汽车之家将自己描述为中国最大的“汽车服务平台”运营商&#xff0c;但行业数据却显示&#xff0c;汽车之家的…

Python 如何快速搭建环境?

Python可应用于多平台包括 Linux 和 Mac OS X。 你可以通过终端窗口输入 “python” 命令来查看本地是否已经安装Python以及Python的安装版本。 Unix (Solaris, Linux, FreeBSD, AIX, HP/UX, SunOS, IRIX, 等等。) Win 9x/NT/2000 Macintosh (Intel, PPC, 68K) OS/2 DOS (多个…

10条终身受益的Salesforce职业发展建议!

Salesforce这个千亿美金巨兽&#xff0c;在全球范围内有42,000多名员工。作为一家发展迅速的科技公司&#xff0c;一直在招聘各种角色&#xff0c;包括销售、营销、工程师和管理人员等。 据IDC估计&#xff0c;从2016年到2020年&#xff0c;该生态系统创造了190万个工作岗位。…

训练营day16

104.二叉树的最大深度 559.n叉树的最大深度111.二叉树的最小深度222.完全二叉树的节点个数104.二叉树的最大深度 力扣题目链接 给定一个二叉树&#xff0c;找出其最大深度。 二叉树的深度为根节点到最远叶子节点的最长路径上的节点数。 说明: 叶子节点是指没有子节点的节点。 示…

Java基础-网络编程

1. 网络编程入门 1.1 网络编程概述 计算机网络 是指将地理位置不同的具有独立功能的多台计算机及其外部设备&#xff0c;通过通信线路连接起来&#xff0c;在网络操作系统&#xff0c;网络管理软件及网络通信协议的管理和协调下&#xff0c;实现资源共享和信息传递的计算机系统…

Vue中路由缓存及activated与deactivated的详解

目录前言一&#xff0c;路由缓存1.1 引子1.2 路由缓存的方法1.2.1 keep-alive1.2.2 keep-alive标签中的include属性1.2.3 include中多组件的配置二&#xff0c;activated与deactivated2.1 引子2.2 介绍activated与deactivated2.3 解决需求三&#xff0c;整体代码总结前言 在Vu…

【深度学习基础8】卷积神经网络 经典网络

一、卷积操作 1. 基本原理 相信大家对卷积操作并不陌生,先来回顾一下卷积的工作原理(2-D):👇 卷积的目的是进行特征提取,不同的卷积核可以提取到不同的特征,比如下面的三个卷积核的功能分别是:模糊化、锐化、边缘化👇 卷积的本质就是滤波器, 将滤波器沿着图像…

【JavaScript】面向对象和构造函数详解

&#x1f4bb; 【JavaScript】面向对象和构造函数详解 &#x1f3e0;专栏&#xff1a;JavaScript &#x1f440;个人主页&#xff1a;繁星学编程&#x1f341; &#x1f9d1;个人简介&#xff1a;一个不断提高自我的平凡人&#x1f680; &#x1f50a;分享方向&#xff1a;目前…

加拿大访问学者家属如何办理探亲签证?

由于大多数访问学者的访学期限都为一年&#xff0c;家人来访不仅可以缓解访学的寂寞生活&#xff0c;而且也是家人到加拿大体验国外风情的好机会。家属在国内申请赴加签证时&#xff0c;如果材料齐全&#xff0c;一般上午递交了申请&#xff0c;下午就可以拿到签证。以下是家人…

基于merlin使用chatGPT进行对话

最近chatGPT很热&#xff0c;大家都想试用它。但由于各种限制&#xff0c;一般情况下国内不能试用。 下面给大家介绍基于merlin使用chatGPT&#xff08;目前每天只有11次问答次数&#xff09;。 1 打开merlin页面 访问地址merlin.foyer.work&#xff0c;点击“add to chrome”…