【Pytorch深度学习实战】(9)神经语言模型(RNN-LM)

news/2024/5/14 8:06:41/文章来源:https://blog.csdn.net/sikh_0529/article/details/126923368

 🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎

📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃

🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​

📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】

 🖍foreword

✔说明⇢本人讲解主要包括Python、机器学习(ML)、深度学习(DL)、自然语言处理(NLP)等内容。

如果你对这个系列感兴趣的话,可以关注订阅哟👋

神经网络语言模型(RNN-LM)


传统语言模型的上述几个内在缺陷使得人们开始把目光转向神经网络模型,期望深度学习技术能够自动化地学习代表语法和语义的特征,解决稀疏性问题,并提高泛化能力。我们这里主要介绍两类神经网络模型:前馈神经网络模型(FFLM)和循环神经网络模型(RNNLM)。前者主要设计来解决稀疏性问题,而后者主要设计来解决泛化能力,尤其是对长上下文信息的处理。在实际工作中,基于循环神经网络及其变种的模型已经实现了非常好的效果。

我们前面提到,语言模型的一个主要任务就是要解决给定到当前的上下文的文字信息,如何估计现在每一个单词出现的概率。Bengio等人提出的第一个前馈神经网络模型利用一个三层,包含一个嵌入层、一个全连接层、一个输出层,的全连接神经网络模型来估计给定n-1个上文的情况下,第n个单词出现的概率。其架构如下图所示:

在这里插入图片描述
RNN语言模型训练过程
另一类循环神经网络模型不要求固定窗口的数据训练。FFLM假设每个输入都是独立的,但是这个假设并不合理。经常一起出现的单词以后也经常出现的概率会更高,并且当前应该出现的词通常是由前面一段文字决定的,利用这个相关性能提高模型的预测能力。循环神经网络的结构能利用文字的这种上下文序列关系,从而有利于对文字建模。这一点相比FFLM模型更接近人脑对文字的处理模型。比如一个人说:"我是中国人,我的母语是___ "。 对于在“__”中需要填写的内容,通过前文的“母语”知道需要是一种语言,通过“中国”知道这个语言需要是“中文”。通过RNNLM能回溯到前两个分句的内容,形成对“母语”,“中国”等上下文的记忆。一个典型的RNNLM模型结构如下图所示。

在这里插入图片描述
RNN语言模型训练过程

在这里插入图片描述


RNN语言模型反向传播

在这里插入图片描述
语言模型评估
迷惑度/困惑度/混乱度(perplexity),其基本思想是给测试集的句子赋予较高概率值的语言模型较好,当语言模型训练完之后,测试集中的句子都是正常的句子,那么训练好的模型就是在测试集上的概率越高越好。迷惑度越小,句子概率越大,语言模型越好。

 在这里插入图片描述

 神经网络语言模型Pytorch的实现

import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils import clip_grad_norm_
from data_utils import Dictionary, Corpus# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数
embed_size = 128
hidden_size = 1024
num_layers = 1
num_epochs = 5
num_samples = 1000     # number of words to be sampled
batch_size = 20
seq_length = 30
learning_rate = 0.002# 加载“Penn Treebank”数据集
corpus = Corpus()
ids = corpus.get_data('data/train.txt', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length# 基于RNN的语言模型
class RNNLM(nn.Module):def __init__(self, vocab_size, embed_size, hidden_size, num_layers):super(RNNLM, self).__init__()self.embed = nn.Embedding(vocab_size, embed_size)self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)self.linear = nn.Linear(hidden_size, vocab_size)def forward(self, x, h):# 将单词 id 嵌入到向量中x = self.embed(x)# 前向传播 LSTMout, (h, c) = self.lstm(x, h)# 将输出重塑为 (batch_size*sequence_length, hidden_​​size)out = out.reshape(out.size(0)*out.size(1), out.size(2))# 解码所有时间步的隐藏状态out = self.linear(out)return out, (h, c)model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)# 损失和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 截断反向传播
def detach(states):return [state.detach() for state in states] # 训练模型
for epoch in range(num_epochs):# 设置初始隐藏和单元格状态states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),torch.zeros(num_layers, batch_size, hidden_size).to(device))for i in range(0, ids.size(1) - seq_length, seq_length):# 获取小批量输入和目标inputs = ids[:, i:i+seq_length].to(device)targets = ids[:, (i+1):(i+1)+seq_length].to(device)# 前传states = detach(states)outputs, states = model(inputs, states)loss = criterion(outputs, targets.reshape(-1))# 向后优化optimizer.zero_grad()loss.backward()clip_grad_norm_(model.parameters(), 0.5)optimizer.step()step = (i+1) // seq_lengthif step % 100 == 0:print ('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'.format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))# 测试模型
with torch.no_grad():with open('sample.txt', 'w') as f:# 设置初始隐藏单元状态state = (torch.zeros(num_layers, 1, hidden_size).to(device),torch.zeros(num_layers, 1, hidden_size).to(device))# 随机选择一个单词idprob = torch.ones(vocab_size)input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)for i in range(num_samples):# 前向传播 RNNoutput, state = model(input, state)# 采样一个单词idprob = output.exp()word_id = torch.multinomial(prob, num_samples=1).item()# 用采样的单词 id 填充输入以用于下一个时间步input.fill_(word_id)# 文件写入word = corpus.dictionary.idx2word[word_id]word = '\n' if word == '<eos>' else word + ' 'f.write(word)if (i+1) % 100 == 0:print('Sampled [{}/{}] words and save to {}'.format(i+1, num_samples, 'sample.txt'))# 保存模型checkpoints
torch.save(model.state_dict(), 'model.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_10083.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第一视角体验搭载全志T507-H的开发板MYD-YT507H开发板

如今车规级芯片市场潜力巨大&#xff0c;需求旺盛&#xff0c;芯片都在逐渐走向国产化。本文要介绍的主角是MYD-YT507H开发板&#xff0c;该开发板是米尔科技结合全志国产工业级平台CPU——全志T507-H芯片研制的CPU模组&#xff0c;全志T507-H可广泛用于电力物联网、汽车电子、…

目标检测开源框架YOLOv6全面升级,更快更准的2.0版本来啦

9月5日&#xff0c;美团视觉智能部发布了YOLOv6 2.0版本&#xff0c;本次更新对轻量级网络进行了全面升级&#xff0c;量化版模型 YOLOv6-S 达到了 869 FPS&#xff0c;同时&#xff0c;还推出了综合性能优异的中大型网络&#xff08;YOLOv6-M/L&#xff09;&#xff0c;丰富了…

一个div靠左另一个靠右

1.使用flex布局<style>#back{border: red solid 1px;width: 800px;height: 500px;display: flex;align-items: center;}#left{border: blue 1px solid;width: 100px;height: 100px;justify-content: flex-start;}#right{border: blue 1px solid;width: 100px;height: 100…

【前端进阶】-TypeScript类型声明文件详解及使用说明

前言 博主主页&#x1f449;&#x1f3fb;蜡笔雏田学代码 专栏链接&#x1f449;&#x1f3fb;【TypeScript专栏】 前三篇文章讲解了TypeScript的一些高级类型 详细内容请阅读如下&#xff1a;&#x1f53d; 【前端进阶】-TypeScript高级类型 | 泛型约束、泛型接口、泛型工具类…

Google Pub/Sub入门

什么是Google Pub/Sub&#xff1f; 首先他是一个messaging buffer/coupler消息缓冲区/耦合器&#xff0c;Decouples senders and receivers解耦发送者和接收者。 一些特性&#xff1a; 使用 Dataflow 注入分析事件并将其流式插入到 BigQuery免运维、安全、可伸缩的消息传递系…

MySQL基础总结合集

MySQL是啥&#xff1f;数据库又是啥&#xff1f; MySQL&#xff1a; MySQL 是最流行的关系型数据库管理系统&#xff0c;在 WEB 应用方面 MySQL 是最好的 RDBMS(Relational Database Management System&#xff1a;关系数据库管理系统)应用软件之一。 数据库&#xff1a; 数…

基于nodejs+vue的读书会网站

实行网上读书会网站&#xff0c;对其改善目前人们读书现状提供一些帮助和优化措施&#xff0c;为人们在未来看书节约了很多时间&#xff0c;使得人们在未来利用自己有限的时间可以看到更多对自己有益的书籍。 基于Vue的读书会网站的实现&#xff0c;通过网上系统的研发构造&…

你是否想过,GitHub Pages也可以自动构建?|原创

本文讲述了如何利用 GitHub Actions 来自动构建 GitHub Pages 项目&#xff0c;免去繁琐的手动构建再提交过程&#xff0c;让你专注于写作。点击上方“后端开发技术”&#xff0c;选择“设为星标” &#xff0c;优质资源及时送达GitHub Actions 自动构建之前的文章我们已经讲过…

Tomcat 在IDEA中运行Tomcat,控制台乱码问题的解决方案

IDEA中运行Tomcat,控制台乱码问题的解决方案试了好多种网上的方案(只有这一种能解决)环境:jdk 11 idea 2022.2.4 tomcat 9.0.54解决方案: 1.打开tomcat的配置文件(apache-tomcat-9.0.54\conf\logging.properties)将文件中的java.util.logging.ConsoleHandler.encoding =…

el-tree增加提示语

element ui tree树形控件加提示信息<el-tree :data="tieLinedata" :props="defaultProps" @node-click="handleNodeClick"><span class="span-ellipsis" slot-scope="{ node, data }"><span :title="no…

【图像增强】基于DEHAZENET和HWD的水下去散射图像增强附matlab代码

1 内容介绍 去散射和边缘增强是解决水下图像的对比度严重衰减、颜色偏差和边缘模糊等问题的关键步骤。这篇论文提出了一种较好的水下图像增强的方法。首先使用经过端到端训练的卷积神经网络去测量输入图片&#xff0c;同时以自适应双边滤波器对传输图片进行处理。接着提出一种…

allure介绍——生成完美的测试报告

一、allure简介 Allure是输出网页测试报告的一种框架 1、该框架是基于Java写的,所以安装该框架需要先安装JDK; 2、下载allure命令行工具,路径:https://github.com/allure-framework/allure2/releases 注:①下载包放到pytest文件夹中,然后将allure/bin的路径放到环境变量的…

css font-size设置小于12px失效(转)

原文:https://blog.csdn.net/weixin_38629529/article/details/119866495 1、描述 不知道你有没有遇到这样的情况,设置了font-size为10px,打开控制台审查元素也显示的是10px,但浏览器渲染的字体大小还是没有发生改变。 这是因为浏览器(以Chrome为例,其他没测试过)在中文…

第五篇、Callable接口实现多线程

文章目录前言一、实现Callable接口二、代码示例1.Callable接口实现多线程总结前言 上一篇我们共同认识了并发问题&#xff0c;那么本篇我们将一起来学习Callable接口实现多线程。 一、实现Callable接口 上篇内容我们通过实现Runnable实现多线程&#xff0c;本篇我们将学习如何…

非零基础自学Java (老师:韩顺平) 第13章 常用类 13.11 日期类

非零基础自学Java (老师&#xff1a;韩顺平) ✈【【零基础 快速学Java】韩顺平 零基础30天学会Java】 第13章 常用类 文章目录非零基础自学Java (老师&#xff1a;韩顺平)第13章 常用类13.11 日期类13.11.1 第一代日期类13.11.2 第二代日期类13.11.3 第三代日期类13.11.4 Dat…

线稿图视频制作--从此短视频平台不缺上传视频了

&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d;&#x1f51d; &#x1f970; 博客首页&#xff1a;knighthood2001 &#x1f6…

[unity][通过代码]控制模型旋转,动态改变模型角度,让模型转动到固定角度

阅读建议 ⏰阅读时长 : 10分钟 &#x1f3eb;阅读难度 : 初级 &#x1f33e;阅读收获 : 了解模型的旋转基本原理,了解瞬间旋转和过度旋转的理论,并学习到一种过渡方式的代码编写 &#x1f587;例子地址 : https://gitee.com/asiworld/unity3d-basic-function-code &#x1f921…

Vulnhub_Noob

本文内容涉及程序/技术原理可能带有攻击性&#xff0c;仅用于安全研究和教学使用&#xff0c;务必在模拟环境下进行实验&#xff0c;请勿将其用于其他用途。因此造成的后果自行承担&#xff0c;如有违反国家法律则自行承担全部法律责任&#xff0c;与作者及分享者无关主机信息K…

ESP8266-Arduino编程实例-PCT2075温度数字转换器驱动

PCT2075温度数字转换器驱动 1、PCT2075介绍 PCT2075 是一款温度数字转换器,在 ‑25 C 至 +100 C 范围内具有 1 C 的精度。它使用片上带隙温度传感器和 Sigma-Delta A-D 转换技术,具有过温检测输出,是其他 LM75 系列热传感器的直接替代品。 该设备包含多个数据寄存器: 配置…

攻防世界 - (题目名称-文件包含)

我会在 writeup 中写出我在解决这道题时遇到的问题&#xff0c;以及对问题的思考&#xff0c;而不是直接给出 payload。 进入场景&#xff1a; 显然是考文件包含&#xff0c;第一反应是用 php://filter 加 convert.base64-encode 文件读取 flag.php&#xff1a; 提示 "do…