NLP学习笔记(二) LSTM基本介绍

news/2024/5/10 19:51:12/文章来源:https://blog.csdn.net/wsmrzx/article/details/128335055

大家好,我是半虹,这篇文章来讲长短期记忆网络 (Long Short-Term Memory, LSTM)

文章行文思路如下:

  1. 首先通过循环神经网络引出为啥需要长短期记忆网络
  2. 然后介绍长短期记忆网络的核心思想与运作方式
  3. 最后通过简短的代码深入理解长短期记忆网络的运作方式

长短期记忆网络可以看作是循环神经网络的改进版本,想要理解长短期记忆网络,首先要了解循环神经网络

由于我们之前已详细介绍过循环神经网络,所以这里我们只会做一个简单的回顾,想看详细的说明请戳这里


对比前馈神经网络,循环神经网络通过增加隐状态实现对隐藏层信息的传递,以此达到记住历史输入的目的

网络在每个时间步里读取上一隐藏层输出作为当前隐藏层输入,并保存当前隐藏层输出作为下一隐藏层输入

其结构简图如下:

循环神经网络结构

其中 XXX 是输入 ,HHH 是隐藏层的输出,图中的每个矩形都表示同一个循环神经网络隐藏层

下面我们把隐藏层中的细节也画出来,方便后面与长短期记忆网络来对比

循环神经网络结构

其中 XXX 是输入 ,HHH 是隐藏层的输出,图中的灰色矩形同样代表隐藏层,σ\sigmaσ 表示一个带激活函数的线性层

对应的公式表达如下:
Ht=α(XtWxh+Ht−1Whh+bh)H_{t} = \alpha(X_{t} W_{xh} + H_{t-1} W_{hh} + b_{h}) Ht=α(XtWxh+Ht1Whh+bh)
其中 XtX_{t}Xt 是当前输入,HtH_{t}Ht 是当前隐藏层输出,Ht−1H_{t-1}Ht1 是先前隐藏层输出,WxhW_{xh}WxhWhhW_{hh}Whhbhb_{h}bh 都是网络参数


理论上,上述介绍的循环神经网络能处理任意长的序列,但实际上却并非如此

在实际应用循环神经网络处理长序列时通常会出现梯度爆炸或梯度消失的情况,导致网络难以捕捉长期依赖

这是为什么呢?通过简单分析一下梯度计算公式就能发现端倪

为了阐述方便,我们暂且假定所有的参数都是一维的,用字母 θ\thetaθ 表示,对参数求导并按时间展开后如下所示
dHtdθ=∂Ht∂θ+∂Ht∂Ht−1dHt−1dθ=∂Ht∂θ+∂Ht∂Ht−1∂Ht−1∂θ+∂Ht∂Ht−1∂Ht−1∂Ht−2dHt−2dθ+⋯\begin{align*} \frac{d H_{t}}{d \theta} &= \frac{\partial H_{t}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{d H_{t-1}}{d \theta} \\ &= \frac{\partial H_{t}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}} \frac{d H_{t-2}}{d \theta} + \cdots \end{align*} dθdHt=θHt+Ht1HtdθdHt1=θHt+Ht1HtθHt1+Ht1HtHt2Ht1dθdHt2+
不难发现,当前梯度 dHtdθ\frac{d H_{t}}{d \theta}dθdHt 由当前梯度值 ∂Ht∂θ\frac{\partial H_{t}}{\partial \theta}θHt 以及先前梯度 dHt−1dθ\frac{d H_{t-1}}{d \theta}dθdHt1 决定,对于先前梯度权重 ∂Ht∂Ht−1\frac{\partial H_{t}}{\partial H_{t-1}}Ht1Ht

  • ∣∂Ht∂Ht−1∣<1|\frac{\partial H_{t}}{\partial H_{t-1}}| < 1Ht1Ht<1 时,表示历史的梯度信息是逐渐减弱的,随着时间步不断增加,很可能会出现梯度消失
  • ∣∂Ht∂Ht−1∣>1|\frac{\partial H_{t}}{\partial H_{t-1}}| > 1Ht1Ht>1 时,表示历史的梯度信息是逐渐增强的,随着时间步不断增加,很可能会出现梯度爆炸

由推导式可以看出,梯度爆炸和梯度消失更容易出现在与当前时间步距离更远的梯度

这是因为这些梯度的权重连乘项更多,举例来说,对于时间步 ttt,其梯度 dHtdθ\frac{d H_{t}}{d \theta}dθdHt 由以下梯度相加组成

  • 时间步 t−1t - 1t1 的梯度 dHt−1dθ\frac{d H_{t-1}}{d \theta}dθdHt1,与时间步 ttt 的距离为 111,其权重为 ∂Ht∂Ht−1\frac{\partial H_{t}}{\partial H_{t-1}}Ht1Ht
  • 时间步 t−2t - 2t2 的梯度 dHt−2dθ\frac{d H_{t-2}}{d \theta}dθdHt2,与时间步 ttt 的距离为 222,其权重为 ∂Ht∂Ht−1∂Ht−1∂Ht−2\frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}}Ht1HtHt2Ht1
  • 时间步 t−3t - 3t3 的梯度 dHt−2dθ\frac{d H_{t-2}}{d \theta}dθdHt2,与时间步 ttt 的距离为 333,其权重为 ∂Ht∂Ht−1∂Ht−1∂Ht−2∂Ht−3∂Ht−3\frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}} \frac{\partial H_{t-3}}{\partial H_{t-3}}Ht1HtHt2Ht1Ht3Ht3
  • ……

这说明了什么?这说明了对于当前输入,距其更远的输入的梯度更容易出现梯度爆炸或梯度消失

从而导致长距离的梯度反馈失效,这就是循环神经网络难以捕捉长期依赖的实际含义


最后提醒大家注意一个细节,对于时间步 ttt 的梯度 dHtdθ\frac{d H_{t}}{d \theta}dθdHt

  • 假设有且仅有最后一项梯度爆炸,那么就会导致整个梯度爆炸,因为 dHt−1dθ+⋯+NaN=NaN\frac{d H_{t-1}}{d \theta} + \cdots + NaN = NaNdθdHt1++NaN=NaN
  • 假设有且仅有最后一项梯度消失,这并不会导致整个梯度消失,因为 dHt−1dθ+⋯+0≠0\frac{d H_{t-1}}{d \theta} + \cdots + 0 \neq 0dθdHt1++0=0

总结一下,梯度反向传播时发生的异常,主要可以分为两种,一是梯度爆炸,二是梯度消失

梯度爆炸比较容易处理,一个简单但有效的做法是设置一个梯度阈值,当梯度超过这个阈值时直接截断

梯度消失更难处理一些,而现在流行的做法正是将循环神经网络替换成长短期记忆网络

注意,长短期记忆网络能缓解梯度消失的问题,但并不能缓解梯度爆炸的问题


上面我们从反向传播的角度解释了什么是梯度消失

如果我们从前向计算的角度来看,则梯度消失可以理解成隐状态对短期记忆敏感,对长期记忆作用有限

为了维持长期记忆,长短期记忆网络引入记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动

从直觉上来说,先前重要的记忆会保留在记忆元,不重要的记忆会被过滤,以此来达到长期记忆的目的


这里有两个概念需要解释,一是记忆元,二是门机制,这两个就是长短期记忆网络的核心

先说记忆元,可以理解成另一种隐状态,都是用来记录附加信息的,简称为单元,英文为 Cell\text{Cell}Cell

再说门机制,这是用来控制记忆元中信息流动的机制,具体来说包括三个控制门:

  • 输入门:控制是否将信息写入记忆元,英文为 Input Gate\text{Input Gate}Input Gate
  • 遗忘门:控制是否从记忆元丢弃信息,英文为 Forget Gate\text{Forget Gate}Forget Gate
  • 输出门:控制是否从记忆元读出信息,英文为 Output Gate\text{Output Gate}Output Gate

本质上来说,上述三个控制门都是由一个线性层加一个激活函数组成的,这里激活函数用的是 sigmoid\text{sigmoid}sigmoid

因为这样能将输出限制在零到一之间,以表示门的打开程度,控制信息流动的程度


相比循环神经网络只有一个传输状态,即隐状态,长短期记忆网络有两个传输状态,即隐状态和记忆元

二者的输入输出对比图如下:

输入输出对比

其中 HHH 表示隐状态,CCC 表示记忆元,知道输入输出后,下面开始介绍长短期记忆网络的内部工作原理

首先,根据当前输入 XtX_{t}Xt 和先前隐状态 Ht−1H_{t-1}Ht1,计算得到输入门 ItI_tIt、遗忘门 FtF_tFt、输出门 OtO_tOt

其中,WxiW_{xi}WxiWhiW_{hi}Whibib_{i}biWxfW_{xf}WxfWhfW_{hf}Whfbfb_{f}bfWxoW_{xo}WxoWhoW_{ho}Whobob_{o}bo 都是网络参数,σ\sigmaσsigmoid\text{sigmoid}sigmoid 激活函数
It=σ(XtWxi+Ht−1Whi+bi)Ft=σ(XtWxf+Ht−1Whf+bf)Ot=σ(XtWxo+Ht−1Who+bo)\begin{align*} I_{t} &= \sigma (X_{t} W_{xi} + H_{t-1} W_{hi} + b_{i}) \\ F_{t} &= \sigma (X_{t} W_{xf} + H_{t-1} W_{hf} + b_{f}) \\ O_{t} &= \sigma (X_{t} W_{xo} + H_{t-1} W_{ho} + b_{o}) \end{align*} ItFtOt=σ(XtWxi+Ht1Whi+bi)=σ(XtWxf+Ht1Whf+bf)=σ(XtWxo+Ht1Who+bo)

然后,根据当前输入 XtX_{t}Xt 和先前隐状态 Ht−1H_{t-1}Ht1,计算得到候选记忆元 C~t\widetilde{C}_{t}Ct

其中,WxcW_{xc}WxcWhcW_{hc}Whcbcb_{c}bc 都是网络参数,tanh⁡\tanhtanhtanh⁡\tanhtanh 激活函数
C~t=tanh⁡(XtWxc+Ht−1Whc+bc)\widetilde{C}_{t} = \tanh (X_{t} W_{xc} + H_{t-1} W_{hc} + b_{c}) Ct=tanh(XtWxc+Ht1Whc+bc)
接着,输入门 ItI_tIt 控制采用多少来自 C~t\widetilde{C}_{t}Ct 的新信息,遗忘门 FtF_tFt 控制保留多少来自 Ct−1C_{t-1}Ct1 的旧信息,计算得 CtC_tCt

其中,⊙\odot 表示按元素乘法,当 It=0I_{t} = 0It=0Ft=1F_{t} = 1Ft=1 时,则过去记忆元被保留并传递到当前时间步
Ct=Ft⊙Ct−1+It⊙C~tC_{t} = F_{t} \odot C_{t-1} + I_{t} \odot \widetilde{C}_{t} Ct=FtCt1+ItCt
最后,输出门 OtO_tOt 控制采用多少来自 CtC_{t}Ct 的长记忆,计算得 HtH_{t}Ht

其中,⊙\odot 表示按元素乘法,tanh⁡\tanhtanh 表示 tanh⁡\tanhtanh 激活函数,当 OtO_{t}Ot 接近 111 时,就可以将长期记忆传递给隐状态
Ht=Ot⊙tanh⁡(Ct)H_{t} = O_{t} \odot \tanh (C_{t}) Ht=Ottanh(Ct)
上述计算过程对应的计算图如下所示:

长短期记忆网络结构

为了帮助大家进一步理解长短期记忆网络的工作方式,下面我们举一个例子来说,并给出关键代码

假设我们用长短期记忆网络对下面这个句子进行编码:我在画画

import torch
import torch.nn as nn# 定义输入数据
# 对于输入句子我在画画,首先用独热编码得到其向量表示x1 = torch.tensor([1, 0, 0]).float() # 我
x2 = torch.tensor([0, 1, 0]).float() # 在
x3 = torch.tensor([0, 0, 1]).float() # 画
x4 = torch.tensor([0, 0, 1]).float() # 画h0 = torch.zeros(5) # 初始化隐状态
c0 = torch.zeros(5) # 初始化记忆元# 定义模型参数
# 模型的输入是三维向量,这里定义模型的输出是五维向量W_xi = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hi = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_i  = nn.Parameter(torch.randn(5)   , requires_grad = True)W_xf = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hf = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_f  = nn.Parameter(torch.randn(5)   , requires_grad = True)W_xo = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_ho = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_o  = nn.Parameter(torch.randn(5)   , requires_grad = True)W_xc = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hc = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_c  = nn.Parameter(torch.randn(5)   , requires_grad = True)# 前向传播def forward(X, H, C):# 计算输入门、遗忘门、输出门I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)# 计算候选记忆元C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)# 计算当前记忆元C = F * C + I * C_tilde# 计算当前隐状态H = O * C.tanh()# 返回结果return H, Ch1, c1 = forward(x1, h0, c0)
h2, c2 = forward(x2, h1, c1)
h3, c3 = forward(x3, h2, c2)
h4, c4 = forward(x4, h3, c3)# 结果输出print(h3) # tensor([-0.0408,  0.1785,  0.0455,  0.3802,  0.0235])
print(h4) # tensor([-0.0560,  0.1269,  0.0346,  0.3426,  0.0118])

最后提醒大家一点,如果长短期记忆网络后有接其他网络,例如后面接一个线性层做单词预测

那么通常不会用记忆元的输出,而是用隐藏层的输出


至此本文结束,要点总结如下:

  1. 循环神经网络在处理长序列时很容易会出现梯度爆炸和梯度消失的情况,导致网络难以捕捉长期依赖

    对于梯度爆炸,通常可以采用梯度裁剪解决,对于梯度消失,可以采用长短期记忆网络缓解

  2. 除了有隐状态,长短期记忆网络还增加记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_236326.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电子厂测试题——难倒众多主播——大司马也才90分

一、选择题 1、1-2 ( ) A.1 B.3 C.-1 D.-3 2、|1-2|( ) A.1 B.3 C. -1 D.-3 3、1x2x3( ) A.5 B.6 C.7 D.8 4、3643( ) A.29 B.16 C.8 D.3 5、55x5( ) A.15 B.30 C.50 D.125 二、填空题(请填写阿拉伯数字) 6、110100 1000_______ 7、一个三角形砍去1个角&#…

Linux(三) makefile与gdb调试

makefile mkefile文件中定义了一系列的规则来指定&#xff0c;哪些文件需要线编译&#xff0c;哪些后编译&#xff0c;哪些需要重新编译&#xff0c;甚至进行更复杂的功能操作&#xff0c;因为makefile就像一个Shell脚本一样&#xff0c;其中也可以执行操作系统的命令。 mkef…

硬件需知知识 -- 基本元件(电阻)

一、电阻 1.1 贴片电阻 1.1.1 贴片电阻的封装大小是和功率时相关的。 封装大小功率(W)0201120\frac{1}{20}201​0402116\frac{1}{16}161​0603110\frac{1}{10}101​080518\frac{1}{8}81​12060.2518120.5或1201012\frac{1}{2}21​25121或者21.1.2 贴片电阻读数 贴片电阻的读数…

Ac-EEVVAC-pNA,389868-12-6

Ac-EEVVAC-pNA, chromogenic substrate for a continuous spectrophotometric assay of HCV NS3 protease. The sequence EEVVAC is derived from the 5A-5B cleavage junction of the HCV polyprotein. Ac-EEVVAC-pNA, HCV NS3蛋白酶连续分光光度法测定的显色底物。EEVVAC序列…

新冠病毒:KN95(GB2626类型口罩)是否有效阻挡?

点击上方“青年码农”关注回复“源码”可获取各种资料​今天刷新闻&#xff0c;看到很多官方账号发布&#xff0c;只有五种编码口罩能防疫&#xff0c;分别是医用防护口罩&#xff08;GB19083-2010&#xff09;医用外科口罩&#xff08;YY0469-2011&#xff09;一次性使用医用口…

带有匹配滤波器的雷达信号调制和脉冲压缩Matlab仿真

up目录 一、理论基础 二、核心程序 三、测试结果 一、理论基础 匹配滤波器&#xff1a; 匹配滤波器是输出端的信号瞬时功率与噪声平均功率的比值最大的线性滤波器也就是说有最大的信噪比。其滤波器的传递函数形式是信号频谱的共轭。在通信系统中&#xff0c;滤波器是其中重…

微服务框架 SpringCloud微服务架构 多级缓存 47 Lua 语法入门 47.3 条件控制、函数

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式&#xff0c;系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 多级缓存 文章目录微服务框架多级缓存47 Lua 语法入门47.3 条件控制、函数47.3.1 函数47.3.2 条件控制47 Lua 语法入门 47.3 条件控制、函数…

【图像融合】DCT域多焦点图像融合【含Matlab源码 1973期】

⛄一、基于DCT变换的图像融合算法简介 在图像融合过程中,最主要的就是如何提取低高频系数以及低高频系数的融合准则。基于DCT变换的图像融合算法原理如图2所示。 图2 DCT融合算法原理 算法步骤如下。 步骤1精确配准待融合的源图像。 步骤2采用分块的方法将参与融合的每幅大小…

【Java版oj】逆波兰表达式求值

目录 一、原题再现 二、问题分析 三、完整代码 一、原题再现 150. 逆波兰表达式求值 有效的算符包括 、-、*、/ 。每个运算对象可以是整数&#xff0c;也可以是另一个逆波兰表达式。 注意 两个整数之间的除法只保留整数部分。 可以保证给定的逆波兰表达式总是有效的。换句话…

1、浮动(float)

提示&#xff1a;我们一般网页上下用标准流&#xff0c;左右用浮动来写 1.1传统网页布局三种方式 网页布局本质——用css来摆放盒子&#xff0c;把盒子摆放到相应位置。css提供了三种传统布局简单方式&#xff0c;说就是盒子如何进行排列顺序&#xff1a; 普通流&#xff08;或…

[附源码]Python计算机毕业设计高校助学金管理系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等…

Selenium3自动化测试【40】Html测试报告

&#x1f4cc; 博客主页&#xff1a; 程序员二黑 &#x1f4cc; 专注于软件测试领域相关技术实践和思考&#xff0c;持续分享自动化软件测试开发干货知识&#xff01; &#x1f4cc; 公号同名&#xff0c;欢迎加入我的测试交流群&#xff0c;我们一起交流学习&#xff01; 目录…

tkinter: 基本+Button+Layout

简介 简介 Tcl 动态解释型编程语言可独立执行&#xff0c;多嵌入C程序中作为脚本引擎&#xff0c;或者作为使用Tk工具包的接口Tcl库可以创建一个或多个Tcl解释器实例&#xff0c;然后在这些实例上运行C或Tcl命令和脚本每个解释器有一个事件队列&#xff0c;接受事件并处理他们…

分享10个比B站更刺激的网站,千万别轻易点开

作为一个码龄8年程序员&#xff0c;到现在还能保持着浓密的头发和健壮的身体&#xff0c;全靠这10个网站让我健&#xff08;偷&#xff09;康&#xff08;偷&#xff09;生&#xff08;摸&#xff09;活&#xff08;鱼&#xff09;&#xff0c;今天就把我收藏夹里的网站无私分享…

【实时数仓】在Hbase建立维度表、保存维度数据到Hbase、保存业务数据到kafka主题

文章目录一 分流Sink之建立维度表到HBase(Phoenix)1 拼接建表语句&#xff08;1&#xff09;定义配置常量类&#xff08;2&#xff09;引入依赖&#xff08;3&#xff09;hbase-site.xml&#xff08;4&#xff09;在phoenix中执行&#xff08;5&#xff09;增加代码a TableProc…

用Python写一个模拟qq聊天小程序的代码实例

前言 今天小编就为大家分享一篇关于用Python写一个模拟qq聊天小程序的代码实例&#xff0c;小编觉得内容挺不错的&#xff0c;现在分享给大家&#xff0c;具有很好的参考价值&#xff0c;需要的朋友一起跟随小编来看看吧 Python 超简单的聊天程序 客户端: 服务器: 模拟qq聊…

张驰咨询:快速提高流程效率的5个关键精益生产工具

精益&#xff0c;又称“精益制造”或“精益生产”&#xff0c;注重通过消除浪费、消除缺陷&#xff0c;实现客户价值最大化。精益工具是关于理解过程&#xff0c;发现浪费&#xff0c;防止错误和记录你所做的事情。 让我们来看看流程改进中使用的五种精益工具&#xff0c;它们…

对 CSS 工程化的理解

CSS 工程化是为了解决以下问题&#xff1a; 宏观设计&#xff1a;CSS 代码如何组织、如何拆分、模块结构怎样设计&#xff1f;编码优化&#xff1a;怎样写出更好的 CSS&#xff1f;构建&#xff1a;如何处理我的 CSS&#xff0c;才能让它的打包结果最优&#xff1f;可维护性&a…

ReplicaSet和Deployment

ReplicaSet和Deployment 写在前面 语雀原文阅读效果更佳&#xff1a;198 ReplicaSet和Deployment 语雀 《198 ReplicaSet和Deployment》 1、ReplicaSet 假如我们现在有一个 Pod 正在提供线上的服务&#xff0c;我们来想想一下我们可能会遇到的一些场景&#xff1a; 某次运营…

计算机毕业设计django基于python大学生多媒体学习系统

项目介绍 随着计算机多媒体技术的发展和网络的普及。采用当前流行的B/S模式以及3层架构的设计思想通过Python技术来开发此系统的目的是建立一个配合网络环境的大学生多媒体学习系统的平台,这样可以有效地解决数据学习系统混乱的局面。 本文首先介绍了大学生多媒体学习系统的发…