【深度学习笔记】6_6 通过时间反向传播(back-propagation through time)

news/2024/4/24 14:54:29/文章来源:https://blog.csdn.net/qq_44894943/article/details/136586452

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.6 通过时间反向传播

在前面两节中,如果不裁剪梯度,模型将无法正常训练。为了深刻理解这一现象,本节将介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。

我们在3.14节(正向传播、反向传播和计算图)中介绍了神经网络中梯度计算与存储的一般思路,并强调正向传播和反向传播相互依赖。正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

6.6.1 定义模型

简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。设时间步 t t t 的输入为单样本 x t ∈ R d \boldsymbol{x}_t \in \mathbb{R}^d xtRd,标签为 y t y_t yt,那么隐藏状态 h t ∈ R h \boldsymbol{h}_t \in \mathbb{R}^h htRh的计算表达式为

h t = W h x x t + W h h h t − 1 , \boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht=Whxxt+Whhht1,

其中 W h x ∈ R h × d \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} WhxRh×d W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是隐藏层权重参数。设输出层权重参数 W q h ∈ R q × h \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} WqhRq×h,时间步 t t t的输出层变量 o t ∈ R q \boldsymbol{o}_t \in \mathbb{R}^q otRq计算为

o t = W q h h t . \boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot=Wqhht.

设时间步 t t t的损失为 ℓ ( o t , y t ) \ell(\boldsymbol{o}_t, y_t) (ot,yt)。时间步数为 T T T的损失函数 L L L定义为

L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1t=1T(ot,yt).

我们将 L L L称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。

6.6.2 模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态 h 3 \boldsymbol{h}_3 h3的计算依赖模型参数 W h x \boldsymbol{W}_{hx} Whx W h h \boldsymbol{W}_{hh} Whh、上一时间步隐藏状态 h 2 \boldsymbol{h}_2 h2以及当前时间步输入 x 3 \boldsymbol{x}_3 x3

在这里插入图片描述

图6.3 时间步数为3的循环神经网络模型计算中的依赖关系。方框代表变量(无阴影)或参数(有阴影),圆圈代表运算符

6.6.3 方法

刚刚提到,图6.3中的模型的参数是 W h x \boldsymbol{W}_{hx} Whx, W h h \boldsymbol{W}_{hh} Whh W q h \boldsymbol{W}_{qh} Wqh。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \boldsymbol{W}_{qh} L/Wqh
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们依然采用3.14节中表达链式法则的运算符prod。

首先,目标函数有关各时间步输出层变量的梯度 ∂ L / ∂ o t ∈ R q \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q L/otRq很容易计算:

∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . \frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}. otL=Tot(ot,yt).

下面,我们可以计算目标函数有关模型参数 W q h \boldsymbol{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} L/WqhRq×h。根据图6.3, L L L通过 o 1 , … , o T \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T o1,,oT依赖 W q h \boldsymbol{W}_{qh} Wqh。依据链式法则,

∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. WqhL=t=1Tprod(otL,Wqhot)=t=1TotLht.

其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中, L L L只通过 o T \boldsymbol{o}_T oT依赖最终时间步 T T T的隐藏状态 h T \boldsymbol{h}_T hT。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度 ∂ L / ∂ h T ∈ R h \partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h L/hTRh。依据链式法则,我们得到

∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. hTL=prod(oTL,hToT)=WqhoTL.

接下来对于时间步 t < T t < T t<T, 在图6.3中, L L L通过 h t + 1 \boldsymbol{h}_{t+1} ht+1 o t \boldsymbol{o}_t ot依赖 h t \boldsymbol{h}_t ht。依据链式法则,
目标函数有关时间步 t < T t < T t<T的隐藏状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h L/htRh需要按照时间步从大到小依次计算:
∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t \frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL

将上面的递归公式展开,对任意时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. htL=i=tT(Whh)TiWqhoT+tiL.

由上式中的指数项可见,当时间步数 T T T 较大或者时间步 t t t 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含 ∂ L / ∂ h t \partial L / \partial \boldsymbol{h}_t L/ht项的梯度,例如隐藏层中模型参数的梯度 ∂ L / ∂ W h x ∈ R h × d \partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} L/WhxRh×d ∂ L / ∂ W h h ∈ R h × h \partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} L/WhhRh×h
在图6.3中, L L L通过 h 1 , … , h T \boldsymbol{h}_1, \ldots, \boldsymbol{h}_T h1,,hT依赖这些模型参数。
依据链式法则,我们有

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ . \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} WhxLWhhL=t=1Tprod(htL,Whxht)=t=1ThtLxt,=t=1Tprod(htL,Whhht)=t=1ThtLht1.

我们已在3.14节里解释过,每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht被计算和存储,之后的模型参数梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算可以直接读取 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度 ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算需要依赖隐藏状态在时间步 t = 0 , … , T − 1 t = 0, \ldots, T-1 t=0,,T1的当前值 h t \boldsymbol{h}_t ht h 0 \boldsymbol{h}_0 h0是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

小结

  • 通过时间反向传播是反向传播在循环神经网络中的具体应用。
  • 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。

注:本节与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_999580.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第19章-IPv6基础

1. IPv4的缺陷 2. IPv6的优势 3. 地址格式 3.1 格式 3.2 长度 4. 地址书写压缩 4.1 段内前导0压缩 4.2 全0段压缩 4.3 例子1 4.4 例子 5. 网段划分 5.1 前缀 5.2 接口标识符 5.3 前缀长度 5.4 地址规模分类 6. 地址分类 6.1 单播地址 6.2 组播地址 6.3 任播地址 6.4 例子 …

离线数仓(五) [ 从数据仓库概述到建模 ]

前言 今天开始正式数据仓库的内容了, 前面我们把生产数据 , 数据上传到 HDFS , Kafka 的通道都已经搭建完毕了, 数据也就正式进入数据仓库了, 解下来的数仓建模是重中之重 , 是将来吃饭的家伙 ! 以及 Hive SQL 必须熟练到像喝水一样 ! 第1章 数据仓库概述 1.1 数据仓库概念 数…

MySQL实战:SQL优化及问题排查

有更合适的索引不走&#xff0c;怎么办&#xff1f; MySQL在选取索引时&#xff0c;会参考索引的基数&#xff0c;基数是MySQL估算的&#xff0c;反映这个字段有多少种取值&#xff0c;估算的策略为选取几个页算出取值的平均值&#xff0c;再乘以页数&#xff0c;即为基数 查…

【Qt】四种绘图设备详细使用

绘图设备有4个: **绘图设备是指继承QPainterDevice的子类————**QPixmap QImage QPicture QBitmap(黑白图片) QBitmap——父类QPixmapQPixmap图片类&#xff0c;主要用来显示&#xff0c;它针对于显示器显示做了特殊优化&#xff0c;依赖于平台的&#xff0c;只能在主线程…

鸿蒙实战多媒体运用:【音频组件】

音频组件用于实现音频相关的功能&#xff0c;包括音频播放&#xff0c;录制&#xff0c;音量管理和设备管理。 图 1 音频组件架构图 基本概念 采样 采样是指将连续时域上的模拟信号按照一定的时间间隔采样&#xff0c;获取到离散时域上离散信号的过程。 采样率 采样率为每…

【视频转码】基于ZLMediakit的视频转码技术概述

一、概述 zlmediakit pro版本支持基于ffmpeg的转码能力&#xff0c;在开源版本强大功能的基础上&#xff0c;新增支持如下能力&#xff1a; 1、音视频间任意转码(包括h265/h264/opus/g711/aac等)。2、基于配置文件的转码&#xff0c;支持设置比特率&#xff0c;codec类型等参…

c#触发事件

Demo1 触发事件 <Window x:Class"WPFExample.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"Title"WPF Example" Height"600" Wi…

【性能测试】Jmeter性能压测-阶梯式/波浪式场景总结(详细)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、阶梯式场景&am…

Flink hello world

下载并且解压Flink Downloads | Apache Flink 启动Flink. $ ./bin/start-cluster.sh Starting cluster. Starting standalonesession daemon on host harrydeMacBook-Pro.local. Starting taskexecutor daemon on host harrydeMacBook-Pro.local. 访问localhost:8081 Flin…

分布式数字身份:通往Web3.0世界的个人钥匙

数字化时代&#xff0c;个人身份已不再仅仅局限于传统形式&#xff0c;分布式数字身份&#xff08;Decentralized Identity&#xff0c;简称DID&#xff09;正崭露头角&#xff0c;它允许个人通过数字签名等加密技术&#xff0c;完全掌握和控制自己的身份信息。研究报告显示&am…

VScode+Live Service+Five Service实现php实时调试

VScodeLive ServiceFive Service实现php实时调试 一、VScode插件安装及配置 1.Code Runner settings.json设置&#xff08;打开方式&#xff1a;ctrlp&#xff0c;搜索settings.json&#xff09; 设置php为绝对路径&#xff08;注意路径分隔符为\\或/&#xff09; 2. Live S…

计算机网络——计算机网络的性能

计算机网络——计算机网络的性能 速率带宽吞吐量时延时延宽带积往返时间RTT利用率信道利用率网络利用率 我们今天来看看计算机网络的性能。 速率 速率这个很简单&#xff0c;就是数据的传送速率&#xff0c;也称为数据率&#xff0c;或者比特率&#xff0c;单位为bit/s&#…

DataWhale公开课笔记2:Diffusion Model和Transformer Diffusion

Stable Diffusion和AIGC AIGC是什么 AIGC的全称叫做AI generated content&#xff0c;AlGC (Al-Generated Content&#xff0c;人工智能生产内容)&#xff0c;是利用AI自动生产内容的生产方式。 在传统的内容创作领域中&#xff0c;专业生成内容&#xff08;PGC&#xff09;…

XSS靶场-DOM型初级关卡

一、环境 XSS靶场 二、闯关 1、第一关 先看源码 使用DOM型&#xff0c;获取h2标签&#xff0c;使用innerHTML将内容插入到h2中 我们直接插入<script>标签试一下 明显插入到h2标签中了&#xff0c;为什么不显示呢&#xff1f;看一下官方文档 尽管插入进去了&#xff0…

gitlab仓库迁移至bitbucket

0. 场景描述 假设已有一个gitlab仓库&#xff1a;ssh://xxx_origin.git&#xff0c;想要把这个仓库迁移至bitbucket上。 默认gitlab和bitbucket的SSH key都已添加。 1. 新建bitbucket仓库 在bitbucket上创建新的仓库&#xff0c;并复制url地址。假设为&#xff1a; https:/…

07.axios封装实例

一.简易axios封装-获取省份列表 1. 需求&#xff1a;基于 Promise 和 XHR 封装 myAxios 函数&#xff0c;获取省份列表展示到页面 2. 核心语法&#xff1a; function myAxios(config) {return new Promise((resolve, reject) > {// XHR 请求// 调用成功/失败的处理程序}) …

前端知识点、技巧、webpack、性能优化(持续更新~)

1、 请求太多 页面加载慢 &#xff08;webpack性能优化&#xff09; 可以把 图片转换成 base64 放在src里面 减少服务器请求 但是图片会稍微大一点点 以上的方法不需要一个一个自己转化 可以在webpack 进行 性能优化 &#xff08;官网有详细描述&#xff09;

Thingsboard学习杂记

知识杂记 1.遵循磁盘绑定的内存数据库和遵循磁盘支持的内存数据库 遵循磁盘绑定的内存数据库和遵循磁盘支持的内存数据库有不同的工作方式&#xff0c;它们的优点和缺点也不同。 遵循磁盘绑定的内存数据库的优点&#xff1a; 数据库可以支持更大的数据集合&#xff0c;因为数…

seq2seq翻译实战-Pytorch复现

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客 &#x1f366; 参考文章&#xff1a;365天深度学习训练营 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制]\n&#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/…

ChatGPT 升级出现「我们未能验证您的支付方式/we are unable to authenticate」怎么办?

ChatGPT 升级出现「我们未能验证您的支付方式/we are unable to authenticate」怎么办&#xff1f; 在订阅 ChatGPT Plus 时&#xff0c;有时候会出现以下报错 &#xff1a; We are unable to authenticate your payment method. 我们未能验证您的支付方式。 出现 unable to a…