深度学习——注意力机制(笔记+代码)

news/2024/4/28 13:41:55/文章来源:https://blog.csdn.net/jbkjhji/article/details/128939432

 1.从心理学的角度出发

人类根据随意线索(随着意志,主动的,有意识)和不随意线索(无主动,潜意识)选择注意点

第一眼看到红色咖啡杯比较突出和易见就是潜意识的不随意线索

 随着意识想主动读书,看到的书就是随意线索

2.注意力机制

①卷积,全连接,池化层都只考虑不随意线索,因为它们让数据原有的特点更加突出,能让特点注意到,就是不随意。

Ⅰ池化层操作是将感受野范围最大值提取出来(最大池化)

Ⅱ卷积操作是将输入全部通过卷积核进行操作,提取出明显的特征。

②注意力机制则显示的考虑随意线索(想要的

Ⅰ随意线索称为查询(query)—想要做的

Ⅱ每个输入是一个值(value)和不随意线索(key)的键值对—理解为环境,就是键值对,key和value可以相同和不同

Ⅲ通过注意力池化层偏向的选择某些输入—根据query偏向的选择输入,显示的加入query,根据query查询所需要的东西。

 3.非参注意力池化层:不需要学习参数

非参:不需要学习参数

x,y:key-value键值对

f(x):就是query查询的东西

平均池化:最简单的方案,不需要管查询的东西(f(x)的x),只对y求和取平均就可以了。

4. Nadaraya-Watson 核回归:

①核:K函数,衡量x和xi之间距离的函数

②在给定的数据进行查询xi,选择和新给定的值比较近的数据,然后将这些数据对应的value值进行加权求和,得到最终的query,不需要学习参数。

5.K的选择:高斯核

 

代入公式得到

①U:x-xi代表之间的距离

②exp:结果是大于0的数

③softmax得到0-1之间的数作为权重

④上述公式加一个可学习的参数w

 

【总结】

①心理学认为人通过随意线索和不随意线索选择注意点

②注意力机制中,通过query(随意线索)和key(不随意线索)偏向选择输入,写作

 

 f(x)的 key 和所有的不随意线索的 key 做距离上的计算(α(x,xi),通常称为注意力权重),分别作为所有的 value 的权重

【代码】

import torch
from torch import nn
from d2l import torch as d2l

1.生成数据集

n_train = 50  # 训练数据样本
x_train, _ = torch.sort(torch.rand(n_train) * 5)  # 排序后的训练样本def f(x):return 2 * torch.sin(x) + x ** 0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

2.平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

3.非参数注意力汇聚

#  x_repeat的形状是(n_test,n_train),每一行包含相同的测试输入
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train) ** 2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

4.训练可以学习的参数

# 使用小批量乘法计算加权平均值
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

5.带参数的注意力汇聚

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w) ** 2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

6.将训练数据集转换为键和值

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

7.训练

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

8.最后结果

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_255494.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

谁说菜鸟不会数据分析,不用Python,不用代码也轻松搞定

作为一个菜鸟,你可能觉得数据分析就是做表格的,或者觉得搞个报表很简单。实际上,当前有规模的公司任何一个岗位如果没有数据分析的思维和能力,都会被淘汰,数据驱动分析是解决日常问题的重点方式。很多时候,…

TypeScript快速入门

TypeScript快速入门1.TypeScript介绍1.1.TypeScript为什么要为JS添加类型支持1.2.TypeScript相比JS优势2.TypeScript初体验2.1.安装编译TS的工具包2.2.编译并运行TS代码2.3.简化运行TS代码3.TypeScript常用类型3.1.类型注解3.2.常用基础类型3.3.原始类型 number/string/boolean…

MG996R舵机介绍

舵机简介舵机是一种位置(角度)伺服的驱动器,适用于那些需要角度不断变化并可以保持的控制系统。在高档遥控玩具,如飞机、潜艇模型,遥控机器人中已经得到了普遍应用。舵机主要是由外壳、电路板、驱动马达、减速器与位置…

【c语言技能树】文件

Halo,这里是Ppeua。平时主要更新C语言,C,数据结构算法......感兴趣就关注我吧!你定不会失望。 🌈个人主页:主页链接 🌈算法专栏:专栏链接 我会一直往里填充内容哒! &…

NAS系列 硬件选择

转自我的博客文章https://blognas.hwb0307.com/nas/3224,内容更新仅在个人博客可见。欢迎关注! 前言 经过《NAS系列 为什么你需要一台NAS》的简单介绍,如果你也决定像我一样组装一台自己的NAS,那么就千万不要错过本文喔&#xff…

负载均衡反向代理下的webshell上传+apache漏洞

目录一、负载均衡反向代理下的webshell上传1、nginx 负载均衡2、搭建环境3、负载均衡下的 WebShell连接的难点总结难点一、需要在每一台节点的相同位置都上传相同内容的 WebShell难点二、无法预测下次的请求交给哪台机器去执行。难点三、下载文件时,可能会出现飘逸&…

【3】深度学习之Pytorch——如何使用张量处理表格数据集(葡萄酒数据集)

张量是PyTorch中数据的基础。神经网络将张量输入并产生张量作为输出,实际上,神经网络内部和优化期间的所有操作都是张量之间的操作,而神经网络中的所有参数(例如权重和偏差)也都是张量。 怎样获取一条数据、一段视频或…

Springboot + RabbitMq 消息队列

前言 一、RabbitMq简介 1、RabbitMq场景应用,RabbitMq特点 场景应用 以订单系统为例,用户下单之后的业务逻辑可能包括:生成订单、扣减库存、使用优惠券、增加积分、通知商家用户下单、发短信通知等等。在业务发展初期这些逻辑可能放在一起…

openGL学习之GLFW和GLAD的下载和编译

背景:为什么使用GLFW和GLADOPenGL环境 目前主流的桌面平台是GLFW和GLAD之前使用的GLUT和Free GLUT已经基本淘汰了,所以记录一下如何下载GLFW和GLAD并且编译.GLFW下载:An OpenGL library | GLFW复制到你想存放的位置,我这里就存放到C盘Libaray文件夹下了,这里是我存放…

中国区注册使用ChatGPT指南(OpenAI‘s services are not available in your country)

ChatGPT又火了,各大平台热搜提到手软。暴增的访问量,即使强如ChatGPT,也表示顶不住了。Openai表示服务器已满负荷,ChatGPT暂无法提供服务由于目前ChatGPT未在中国开放,所以国内目前是无法注册使用ChatGPT。但我经过一番…

『 MySQL篇 』:MySQL表的聚合与联合查询

基础篇 MySQL系列专栏(持续更新中 …)1『 MySQL篇 』:库操作、数据类型2『 MySQL篇 』:MySQL表的CURD操作3『 MySQL篇 』:MySQL表的相关约束4『 MySQL篇 』:MySQL表的聚合与联合查询目录一. 聚合查询1.1 聚合函数1.2 GROUP BY子句…

Python将字典转换为csv

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。喜欢通过博客创作的方式对所学的知识进行总结与归纳,不仅形成深入且独到的理…

MySQL篇02-三大范式,多表查询

数据入库时,由于数据设计不合理,会存在数据重复、更新插入异常等情况, 故数据库中表的设计遵循的设计规范:三大范式1.第一范式(1NF)要求数据库的每一列都是不可分割的原子数据项,即原子性。强调的是列的原子性,即数据库中每一列的…

攀升MaxBook P2电脑U盘重装系统方法教学

攀升MaxBook P2电脑U盘重装系统方法教学。攀升MaxBook P2电脑是一款性价比非常高的笔记本。有用户购买了这款电脑后,想要将系统进行重装。今天和大家分享一个U盘重装系统的方法,学会这个方法后以后就可以自己轻松去重装电脑系统了。接下来一起看看具体的…

相机坐标系的正向投影和反向投影

1 、正向投影: 世界坐标系到像素坐标系 世界3D坐标系(x, y, z) 到图像像素坐标(u,v)的映射过程 (1)世界坐标系到相机坐标系的映射。 两个坐标系的转换比较简单,就是旋转矩阵 平移矩阵,旋转矩阵则是绕X, Y&#xff…

Thread 类及常见方法

Thread 类是 JVM 用来管理线程的一个类,换句话说,每个线程都有一个唯一的 Thread 对象与之关联。用我们上面的例子来看,每个执行流,也需要有一个对象来描述,类似下图所示,而 Thread 类的对象就是用来描述一…

分享111个JS焦点图代码,总有一款适合您

分享111个JS焦点图代码,总有一款适合您 111个JS焦点图代码下载链接:https://pan.baidu.com/s/1GxjW5m9DNOPEQd-Qf_gGSA?pwd4aci 提取码:4aci Python采集代码下载链接:https://wwgn.lanzoul.com/iKGwb0kye3wj jQuery宽屏左右…

没有人能比我快,用Python写一个自动填写答案的脚本

前言 不是标题党,真的就是没有人比我快,今天用Python写了个自动填写答案的脚本,快就算了,准确率还是百分之百 话不多说 咱先看代码 后看效果 不想看全文的 点击文末名片 领取源码 环境使用 Python 3.8Pycharm 模块使用 imp…

Request Method: OPTIONS

节选自https://blog.csdn.net/Amnesiac666/article/details/121105088版权归原作者所有,如有侵权请联系删除Request Method: OPTIONS一些接口在请求时,会自动发送一个的请求,我查了一遍代码,不是代码中写明的。 网上给出的解释涉及…

Pinecone:一款专为红队研究人员设计的WLAN网络安全审计框架

关于Pinecone Pinecone是一款专为红队研究人员设计的WLAN网络安全审计框架,该工具基于模块化开发,允许广大研究人员根据任务需求进行自定义功能扩展。Pinecone设计之初专用于树莓派,可以将树莓派打造为便携式无线网络安全审计工具&#xff0…