9--RNN

news/2024/5/14 12:46:40/文章来源:https://blog.csdn.net/wangyumei0916/article/details/126788155

有隐藏状态的循环神经网络

        假设在时间步t有小批量输入\mathbf{X}_t \in \mathbb{R}^{n \times d},即对于n个序列样本的小批量,\mathbf{X}_t的每一行对应于来自该序列的时间步t处的一个样本,用\mathbf{H}_t \in \mathbb{R}^{n \times h}表示时间步t的隐藏变量。与MLP不同的是, 我们在这里保存了前一个时间步的隐藏变量\mathbf{H}_{t-1},并引入了一个新的权重参数\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}。当前时间步隐藏变量由当前时间步的输入与前一个时间步的隐藏变量一起计算得出:

\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h).

        从相邻时间步的隐藏变量\mathbf{H}_t\mathbf{H}_{t-1}之间的关系可知, 这些变量捕获并保留了序列直到其当前时间步的历史信息, 就如当前时间步下神经网络的状态或记忆, 因此这样的隐藏变量被称为隐状态(hidden state)。对于时间步t,输出层的输出类似于多层感知机中的计算:

\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q.

        其实循环神经网络与MLP不同的地方就在于,中间隐藏层的更新会依赖于上一时间步的隐藏层。(下图中蓝色的点为隐藏层)

基于循环神经网络的字符级语言模型 

        根据过去的词与当前的词来对下一个词进行预测,可以将词的原始序列位移一个词源作为一个标签。考虑使用神经网络来进行语言建模,设小批量大小为1,批量中的那个文本序列为“machine”。这里考虑字符级语言模型,下图展示了如何通过之前以及当前字符预测下一个字符。

        在训练过程中,对每个时间步的输出都进行一个softmax操作,并利用交叉熵损失计算模型输出和标签之间的误差。

困惑度(Perplexity)

        对于语言模型预测的结果,通过计算序列的似然概率来度量模型的质量。 一个更好的语言模型应该能更准确地预测下一个词元。因此,它在压缩序列时花费更少的比特。所以可以通过一个序列中所有的n个词元的交叉熵损失的平均值来衡量:

\frac{1}{n} \sum_{t=1}^n -\log P(x_t \mid x_{t-1}, \ldots, x_1),

        其中P由语言模型给出, xt是在时间步t从该序列中观察到的实际词元,上式的指数则称为困惑度,即下一个词元的实际选择数的调和平均数

\exp\left(-\frac{1}{n} \sum_{t=1}^n \log P(x_t \mid x_{t-1}, \ldots, x_1)\right). 

        在最好的情况下,模型总是完美地估计标签词元的概率为1(即预测结果为一个词元), 在这种情况下,模型的困惑度为1。 在最坏的情况下,模型总是预测标签词元的概率为0,在这种情况下,困惑度是正无穷大。在基线上,该模型的预测是词表的所有可用词元上的均匀分布,困惑度等于词表中唯一词元的数量。

实例

        基于时光机器数据集来训练模型,具体代码如下:

!pip install git+https://github.com/d2l-ai/d2l-zh@release  # installing d2l
!pip install matplotlib_inline
!pip install matplotlib==3.0.0import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size , num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size , num_steps)#构造一个具有256个隐藏单元的单隐藏层的循环神经网络层
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab),num_hiddens,1)class RNNModel(nn.Module):def __init__(self,rnn_layer,vocab_size,**kwargs):super(RNNModel,self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_sizeif not self.rnn.bidirectional:self.num_directions=1self.linear = nn.Linear(self.num_hiddens,self.vocab_size)else:self.num_directions=2self.linear = nn.Linear(self.num_hiddens*2,self.vocab_size)def forward(self,inputs,state):X = F.one_hot(inputs.T.long(),self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X,state)output = self.linear(Y.reshape(-1,Y.shape[-1]))return output,state#初始化隐状态为0 形状是(隐藏层数,批量大小,隐藏单元数)def begin_state(self,device,batch_size=1):if not isinstance(self.rnn,nn.LSTM):return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))
device = d2l.try_gpu()
net = RNNModel(rnn_layer,vocab_size=len(vocab))
num_epochs,lr = 500,1
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

        运行结果如下,500个epoch后困惑度达到了1.3。

        另外,这里分别使用训练前和训练后的模型对“time traveller”后续词元进行续写,可以看出模型训练前完全是随机性的预测字符串,虽然训练后的模型预测结果语义上不太通顺,但预测出来的单词大部分是正确的(该模型的词元是字符)。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_6191.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《Mycat分布式数据库架构》之数据切分实战

文章目录1、引言2、前期准备2.1 系统环境2.2 数据库集群3 注意事项3.1 分片原则3.2 如何选择分片键4 数据切分实战4.1 配置访问用户及权限4.2 配置逻辑库及逻辑表4.3 配置分片规则4.3.1 简单取模分片4.3.2 哈希取模分片4.3.3 枚举分片4.3.4 字符串范围取模分片前文回顾&#xf…

Selenium操作已经打开的Chrome(只怪自己尝试的太迟)

🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝🔝 🥰 博客首页:…

抖音视频

刻度尺读取方法0n:/ 复制打开抖音,看看【天子骄龙的作品】初中物理-刻度尺读数 ηηQ2VtW0nGyv8▽▽ 秒表读取方法 8.76 aNW:/ 复制打开抖音,看看【天子骄龙的作品】初中物理-秒表读数# 专业的事交给专业的人 初中物理... https://v.douyin.com/6RTySK2/

微信支付v3

文章目录前言1. 微信支付产品介绍2 接入指引2.1 获取商户号2.2 获取appid2.3 获取密钥和证书3 支付安全3.1 对称加密和非对称加密3.2 身份认证3.3 数字证书3.4 https中的数字证书3.5 微信支付中的证书密钥和签名4 基础支付apiv34.1 基础支付APly3-引入支付参数4.2 基础支付APly…

frame标签使用

当页面采用框架集的时候,如果点击,某个部分想在当前页面跳转到一个全新的无框架集的页面,可以在超链接中指定 target属性,如果指定为_top,则是整个页面,也可以指定某个frame 。 默认的几种值有: _self:当前frame(或者当前部分) _blank:打开新的一个窗口 _parent:当…

upload-labs靶场通关指南(9-11关)

今天继续给大家介绍渗透测试相关知识,本文主要内容是upload-labs靶场通关指南(9-10关) 免责声明: 本文所介绍的内容仅做学习交流使用,严禁利用文中技术进行非法行为,否则造成一切严重后果自负! …

JavaScript每日一题_立即执行函数中函数名和变量同名,输出的是什么

立即执行函数中函数名和变量同名,输出的是什么 代码如下 var a 1;(function a() {a 2console.log(a)})();首先,不是输出2,也不是输出1 运行代码 输出的是函数a未定义 一句一句代码解读 实现 var a 1;会在window对象上挂载一个属性a,并赋值为1 然后是 (function a() {a …

.NET操作Excel高效低内存的开源框架 - MiniExcel

.Net平台上对Excel进行操作主要有两种方式。第一种,把Excel文件看成一个数据库,通过OleDb的方式进行读取与操作;第二种,调用Excel的COM组件。两种方式各有特点。 今天给大家介绍第三种方式:插件方式,目前主流框架大多需要将数据全载入到内存方便操作,但这会导致内存消耗…

【ZJSU - 大红大紫:ACM - Template】比赛用模板12:STL与库函数

模板整理12:STL与库函数(更新至v6.0,2022.09.10)\(\tt STL\) 与库函数 后继 \(\tt lower\_bound、upper\_bound\) lower 表示 \(\ge\) ,upper 表示 \(>\) 。使用前记得先进行排序。 //返回a数组[start,end)区间中第一个>=x的地址【地址!!!】 cout << lower…

剑指offer--重建二叉树

目录Start代码及分析EndingStart 代码及分析 在已知前序遍历和中序遍历之后&#xff0c;如何建树呢&#xff1f; 我们知道&#xff0c;在二叉树的前序遍历当中&#xff0c;第一个数字总是根结点的值。而在中序遍历中&#xff0c;根节点位于中间位置&#xff0c;根结点的左边是…

【数据结构】交换排序之冒泡排序与快速排序

承接上文&#xff1a; (32条消息) 【数据结构】常见排序之插入排序与选择排序_vpurple__的博客-CSDN博客https://blog.csdn.net/vpurple_/article/details/126568614?spm1001.2014.3001.5502https://blog.csdn.net/vpurple_/article/details/126568614?spm1001.2014.3001.55…

【算法刷题日记之本手篇】微信红包与计算字符串的编辑距离

⭐️前面的话⭐️ 本篇文章介绍来自牛客试题广场的两道题题解&#xff0c;分别为【微信红包】和【计算字符串的编辑距离】&#xff0c;展示语言java。 小贴士&#xff1a;本专栏所有题目来自牛客->面试刷题必用工具 &#x1f4d2;博客主页&#xff1a;未见花闻的博客主页 …

索引优化分析_预热_JOIN

索引优化分析_预热_JOIN1.性能下降SQL慢 执行时间长 等待时间长2.常见通用的Join查询2.1.SQL执行顺序2.2.Join图2.3.建表SQL2.4 7种JOIN2.5.扩展(掌门人)1.性能下降SQL慢 执行时间长 等待时间长 数据过多——分库分表 mycat索引失效&#xff0c;没有充分利用到索引——索引建立…

Java项目:ssm流浪狗领养系统

作者主页&#xff1a;夜未央5788 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 流浪狗领养网站是一个基于ssm(Spring SpringMVC MyBatis)的项目&#xff0c;项目分为前后台。 前台网站主要首页(包含轮播图、关键字搜索、点击排行…

常见的屏幕接口

常见的屏幕接口 常见的屏幕接口有: 6800、8080、RGB、I2C、SPI、MIPI-SDI、LVDS等今天聊一聊我最近想手动给我的esp8266开发板加一块LCD裸屏,网上找了有一遍绝大部分都是LCD屏幕模块,于是自己查了一通资料整理一下。我们平时用的大部分都是屏幕模块,而模块上面只有几个引脚而…

springboot学生成绩课堂表现过程性评价系统java

随着互联网技术的发发展,计算机技术广泛应用在人们的生活中,逐渐成为日常工作、生活不可或缺的工具,高校各种管理系统层出不穷。高校作为学习知识和技术的高等学府,信息技术更加的成熟,为校园教务管理开发必要的系统,能够有效的提升管理效率。一直以来,校园教务一直没有进行系统…

python-opencv之形态学操作(腐蚀和膨胀)原理详解

形态学操作作用 Removing noise.Isolation of individual elements and joining disparate elements in an image.Finding of intensity bumps or holes in an image. 最基本的形态操作是侵蚀和扩张。让我们更详细地了解这些操作。 Erosion 腐蚀 原理 它会侵蚀前景物体的边…

Spring Cloud Alibaba 中 Nacos 组件的使用

Spring Cloud Alibaba 微服务工具集 阿里巴巴版本: 2.2.1 Boot版本: 2.2.5 1.简介 Spring Cloud Alibaba provides a one-stop solution for distributed application development. It contains all the components required to develop distributed applications, making …

字符串匹配算法之——KMP算法

字符串匹配在日常开发中很常用,用于判断一个字符串中是否包含另外一个字符串,例如Java中的indexOf方法,查到则返回对应的位置,未查询到则返回-1。 如图-1,在“abcabd”中查找“abd”,最终在下标3的位置匹配。 图-1 至于是如何匹配的,直觉上…

spring的set注入方式流程图解

spring的set注入方式流程图解 自己学习spring的一些笔记,详细画出了spring的set方式实现依赖注入的流程。注意:<property name="UserDao" ref="userDao"></property>的name属性值要与UserServiceImpl中的setxxx();方法的名字相同,但是首字…