Transformer

news/2024/5/17 16:59:39/文章来源:https://blog.csdn.net/qq_36372569/article/details/127084422

参考
https://www.ylkz.life/deeplearning/p12158901/
https://zhuanlan.zhihu.com/p/396221959

模型结构

在这里插入图片描述

Input Embedding

将文本中词汇的数字表示转变为向量表示, 希望得到其在高维空间中的特征表示向量。

# 导入必备的工具包
import torch
import torch.nn as nn
import math
from torch.autograd import Variable# 定义Embeddings类来实现文本嵌入层,这里s说明代表两个一模一样的嵌入层, 他们共享参数.
class Embeddings(nn.Module):def __init__(self, d_model, vocab):"""类的初始化函数,有两个参数, d_model: 指词嵌入的维度, vocab: 指词表的大小"""super(Embeddings, self).__init__()# 调用nn中的预定义层Embedding, 获得一个词嵌入对象self.lutself.lut = nn.Embedding(vocab, d_model)# 将d_model传入类中self.d_model = d_modeldef forward(self, x):# 将x传给self.lut并与根号下self.d_model相乘作为结果返回return self.lut(x) * math.sqrt(self.d_model)

Positional Encoding

在Transformer的编码器结构中, 并没有针对词汇位置信息的处理,因此需要在Embedding层后加入位置编码器,将词汇位置不同可能会产生不同语义的信息加入到词嵌入张量中, 以弥补位置信息的缺失.

# 定义位置编码器类    
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):"""位置编码器类的初始化函数, 共有三个参数分别是d_model: 词嵌入维度 dropout: 置0比率max_len: 每个句子的最大长度"""super(PositionalEncoding, self).__init__()# 实例化nn中预定义的Dropout层, 并将dropout传入其中, 获得对象self.dropoutself.dropout = nn.Dropout(p=dropout)# 初始化一个位置编码矩阵, 它是一个0阵,矩阵的大小是max_len x d_model.pe = torch.zeros(max_len, d_model)# 初始化一个绝对位置矩阵, 在这里,词汇的绝对位置用它的索引表示. # 首先使用arange方法获得一个连续自然数向量,然后再使用unsqueeze方法拓展向量维度使其成为矩阵, position = torch.arange(0, max_len).unsqueeze(1)# 绝对位置矩阵初始化之后,接下来就是考虑如何将这些位置信息加入到位置编码矩阵中,# 最简单思路就是先将max_len x 1的绝对位置矩阵, 变换成max_len x d_model形状,然后覆盖原来的初始位置编码矩阵即可, # 要做这种矩阵变换,就需要一个1xd_model形状的变换矩阵div_term,我们对这个变换矩阵的要求除了形状外,# 还希望它能够将自然数的绝对位置编码缩放成足够小的数字,有助于在之后的梯度下降过程中更快的收敛.  这样我们就可以开始初始化这个变换矩阵了   div_term = torch.exp(torch.arange(0, d_model, 2) *-(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 使用unsqueeze拓展维度.pe = pe.unsqueeze(0)# 最后把pe位置编码矩阵注册成模型的bufferself.register_buffer('pe', pe)def forward(self, x):"""forward函数的参数是x, 表示文本序列的词嵌入表示"""# 在相加之前我们对pe做一些适配工作, 将这个三维张量的第二维也就是句子最大长度的那一维将切片到与输入的x的第二维相同即x.size(1),# 因为我们默认max_len为5000一般来讲实在太大了,很难有一条句子包含5000个词汇,所以要进行与输入张量的适配. # 最后使用Variable进行封装,使其与x的样式相同,但是它是不需要进行梯度求解的,因此把requires_grad设置成false.x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)# 最后使用self.dropout对象进行'丢弃'操作, 并返回结果.return self.dropout(x)

掩码张量

只有1和0的元素,代表位置被遮掩或者不被遮掩,至于是0位置被遮掩还是1位置被遮掩可以自定义,因此它的作用就是让另外一个张量中的一些数值被遮掩,也可以说被替换, 它的表现形式是一个张量

在transformer中, 掩码张量的主要作用在应用attention,有一些生成的attention张量中的值计算有可能已知了未来信息而得到的,未来信息被看到是因为训练时会把整个输出结果都一次性进行Embedding,但是理论上解码器的的输出却不是一次就能产生最终结果的,而是一次次通过上一次结果综合得出的,因此,未来的信息可能被提前利用. 所以,需要进行遮掩

原理实现

>>> atten_data=torch.tensor([[4,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15],[16,17,18,19,20]])
>>> data
tensor([[ 4,  2,  3,  4,  5],[ 6,  7,  8,  9, 10],[11, 12, 13, 14, 15],[16, 17, 18, 19, 20]])>>> mask=np.triu([[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]],k=1)
>>> mask=1-mask
>>> mask
array([[1, 0, 0, 0, 0],[1, 1, 0, 0, 0],[1, 1, 1, 0, 0],[1, 1, 1, 1, 0]])>>> data=data.masked_fill(mask==0,-1e9)
>>> data
tensor([[          4, -1000000000, -1000000000, -1000000000, -1000000000],[          6,           7, -1000000000, -1000000000, -1000000000],[         11,          12,          13, -1000000000, -1000000000],[         16,          17,          18,          19, -1000000000]])

transformer中

多层Transformer

在实际使用种常常使用多层transformer结构(原论文6层)
在这里插入图片描述
在多层Transformer中,多层编码器先对输入序列进行编码,然后得到最后一个Encoder的输出Memory;解码器先通过Masked Multi-Head Attention对输入序列进行编码,然后将输出结果同Memory通过Encoder-Decoder Attention后得到第1层解码器的输出;接着再将第1层Decoder的输出通过Masked Multi-Head Attention进行编码,接着将编码后的结果同Memory通过Encoder-Decoder Attention后得到第2层解码器的输出,以此类推得到最后一个Decoder的输出。

值得注意的是,在多层Transformer的解码过程中,每一个Decoder在Encoder-Decoder Attention中所使用的Memory均是同一个。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_15680.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt5开发从入门到精通——第九篇一节( Qt5 文件及磁盘处理—— 读写文本文件)

CSDN话题挑战赛第2期 参赛话题:学习笔记 欢迎小伙伴的点评✨✨,相互学习c/c应用开发。🍳🍳🍳 博主🧑🧑 本着开源的精神交流Qt开发的经验、将持续更新续章,为社区贡献博主自身的开源精…

esp32-C3 CAN接口使用

esp32-C3 CAN接口使用功能概述CAN协议关注点接收过滤器单过滤器模式双过滤器模式关键函数说明配置和安装驱动获取TWAI状态信息发送/接收消息使用示例CAN控制器自回环测试CAN收发带过滤测试功能概述 ESP32-C3具有1个CAN控制器支持以下特性: 兼容ISO 11898-1协议(CA…

伟大的micropython smartconfig 配网它来了!!!

我这其实只是实验和搬运,还是感谢伟大的walkline群主,他弄好的,我只是负责搬运发布给新手看。 之前一大堆人问我配网的事儿,输入下wifi名称密码这么麻烦吗,好吧,有求必应,之前的配网是通过ap模式…

PICO高管专访:关于PICO 4硬件、内容、定价、海外布局的一切解答

PICO 4昨天正式在国内发布,简单来说这是一款相对均衡的VR一体机,在硬件素质、内容生态建设上都可圈可点,对于国内还未入手VR的朋友们来说是非常好的选择。相关阅读:《PICO 4评测:Pancake光学新标杆,VR娱乐V…

20【访问者设计模式】

文章目录二十、访问者设计模式20.1 访问者设计模式简介20.1.1 访问者设计模式概述20.1.2 访问者设计模式的UML类图20.2 访问者设计模式的实现20.3 访问者设计模式的优缺点二十、访问者设计模式 20.1 访问者设计模式简介 20.1.1 访问者设计模式概述 访问者设计模式&#xff0…

计算机网络基础 VLSM----可变长子网掩码;CIDR技术----无类域间路由;

VLSM----可变长子网掩码: 概述: 通过网络位向主机位借位的方式,延长子网掩码,从而达到将一个大网络划分为多个小网络;借出的位数称之为子网位,决定了能划分网络的个数。 优点: 更高效的利用…

记一次导入下载好的源码工程到本地工程异常解决方案

今天在学习okhttp相关视频时,安装视频的操作在自己的工程中引入三方的模块,但是发现引入后和预期的不一致。不一致指的是,视频中以module方式引入sample-okhttp并解决冲突后,sample-okhttp能够被android stuidio识别为applicayion…

Style样式设置器

构成Style最重要的两种元素: Setter类帮助我们设置控件的静态外观风格 Trigger类则帮助我们设置控件的行为风格。 Setter,设置器,我们给属性赋值的时候一般都采用“属 性名属性值”的形式 上面的例子中针对TextBlock的Style,Style中使用 若…

解决csdn强制关注博主才能阅读文章

问题 有的时候查阅资料的时候,关注博主并不是很方便,查csdn会出现下面的提示解决办法 打开控制台输入以下代码: var article_content=document.getElementById("article_content"); article_content.removeAttribute("style");var follow_text=document…

深入理解计算机系统——第七章 Linking

深入理解计算机系统——第七章 Linking7.1 Compiler Drivers7.2 Static Linking7.3 Object Files7.4 Relocatable Object Files7.5 Symbols and Symbol Tables7.6 Symbol Resolution7.6.1 How Linkers Resolve Duplicate Symbol Names7.6.2 Linking with Static Libraries7.6.3…

人体神经元结构示意图,神经细胞内部结构图

人体神经结构图???? 谷歌人工智能写作项目:神经网络伪原创 下图为神经系统的结构示意图,请根据图回答: (1)构成神经系统的结构、功能单位是神经元,图中E部分…

19【迭代器设计模式】

文章目录十九、迭代器设计模式19.1 迭代器设计模式简介19.1.1 迭代器设计模式概述19.1.2 迭代器设计模式的UML类图19.2 迭代器设计模式的实现19.3 迭代器设计模式的优缺点十九、迭代器设计模式 19.1 迭代器设计模式简介 19.1.1 迭代器设计模式概述 迭代器设计模式&#xff0…

DeFi借贷重新洗牌 透过协议变化能找到哪些新趋势?

在过去的几个月里,DeFi 借贷赛道产生了重大变化,1kx 研究员 Mikey 0x 对此场域重新进行梳理,BlockBeats 对其整理翻译如下: 本文内容将包括对新借贷协议的介绍、核心数据统计以及发展趋势,也许可以让我们大致把握下一…

Python3操作MongoDB数据库

Python3操作MongoDB数据库 文章目录Python3操作MongoDB数据库0. 写在前面1. 安装开源驱动库pymongo2. 参考0. 写在前面 Linux:Ubuntu Kylin 16.04MongoDB:MongoDB3.2.7Python:Anaconda With Python3.7 1. 安装开源驱动库pymongo pymongo驱动…

公众号题库搜题对接(免费接口)

公众号题库搜题对接(免费接口) 本平台优点: 多题库查题、独立后台、响应速度快、全网平台可查、功能最全! 1.想要给自己的公众号获得查题接口,只需要两步! 2.题库: 题库:题库后台(点击跳转&a…

用神经网络表示与逻辑,神经网络实现逻辑运算

数据挖掘中的神经网络和模糊逻辑的概念是啥? 【神经网络】人工神经网络(Artificial Neural Networks,简写为ANNs)也简称为神经网络(NNs)或称作连接模型(Connection Model)&#xff…

Frp内网穿透win系统实录

文章目录前言公网服务器端配置基于Docker配置简单文件配置内网服务器端配置frpc配置安装OpenSSH服务配置连接XShell和Xftp连接前言 由于实验室的某些原因,分配了一台win10的服务器(QAQ),但是由于服务器在内网,无法访问…

【常用排序算法】

文章目录写在最前面只想用其中的某个算法?类关系图工具类NumberArrayUtil用于测试排序的父类 SortTest冒泡排序堆排序插入排序归并排序快速排序选择排序希尔排序写在最前面 只想用其中的某个算法? 如果你只是想要对应的排序算法,可删除每个…

A-Level数学P4:反证法题型变革趋势

历年来,真题中Prove by contradiction的常见题型有三类: 1►Even/Odd相关证明2►Multiple of 3相关证明3►Irrational number相关证明 但是从2022年开始,该考点有越变越活的趋势。不再局限于书本上出现过的习题类型,而是进一步考察…

SpringBoot生产监控

文章目录一、健康监控简介1、介绍2、SpringBoot准备工作3、其他二、健康检测触达关键组件1、内置组件健康详情2、自定义组件健康详情3、自定义多 HealthIndicator 聚合三、对外暴露应用内部重要组件的状态1、内部状态数据暴露2、JMX MBean四、指标 Metrics 快速定位五、总结一、…