Fastformer: Additive Attention Can Be All You Need

news/2024/5/2 23:24:13/文章来源:https://www.cnblogs.com/Uriel-w/p/16704705.html

创新点:

  • 本文根据transformer模型进行改进,提出了一个高效的模型,模型复杂度呈线性。
  • 主要改进了注意力机制,出发点在于降低了注意力矩阵的重要程度,该方法采用一个(1*T)一维向量替换了原始T*T大小的注意力矩阵。

注意力结构图:

 

       在这里,输入同样通过不同的线性映射得到Q,K,V,然后通过Q得到Q的权重:

 

 

 其中,从Q到Q的权重变化过程为:

weight:(B,T,D)->(B,T,h)->(B,h,T)->(B,h,1,T)

然后Q和Q的权重做乘法运算weight*query=(B,h,1,ad)->(B,1,h,ad)->(B,1,D)->(B,T,D):

 

 得到的结果和K做逐点乘法运算:

 

 

 K的权重向量和Q的求法相同:

 

 同样的K和K的权重做乘法运算:

 

 最后的结果和V做逐点运算:

 

 在这里Q和V是相同的,采用了权重共享的方法。

在espnet中的代码实现:

import numpy
import torchclass FastSelfAttention(torch.nn.Module):"""Fast self-attention used in Fastformer."""def __init__(self,size,attention_heads,dropout_rate,):super().__init__()if size % attention_heads != 0:raise ValueError(f"Hidden size ({size}) is not an integer multiple "f"of attention heads ({attention_heads})")self.attention_head_size = size // attention_headsself.num_attention_heads = attention_headsself.query = torch.nn.Linear(size, size)self.query_att = torch.nn.Linear(size, attention_heads)self.key = torch.nn.Linear(size, size)self.key_att = torch.nn.Linear(size, attention_heads)self.transform = torch.nn.Linear(size, size)self.dropout = torch.nn.Dropout(dropout_rate)def espnet_initialization_fn(self):self.apply(self.init_weights)def init_weights(self, module):if isinstance(module, torch.nn.Linear):module.weight.data.normal_(mean=0.0, std=0.02)if isinstance(module, torch.nn.Linear) and module.bias is not None:module.bias.data.zero_()def transpose_for_scores(self, x):"""Reshape and transpose to compute scores.Args:x: (batch, time, size = n_heads * attn_dim)Returns:(batch, n_heads, time, attn_dim)"""new_x_shape = x.shape[:-1] + (self.num_attention_heads,self.attention_head_size,)return x.reshape(*new_x_shape).transpose(1, 2)def forward(self, xs_pad, mask):"""Forward method.Args:xs_pad: (batch, time, size = n_heads * attn_dim)mask: (batch, 1, time), nonpadding is 1, padding is 0Returns:torch.Tensor: (batch, time, size)"""batch_size, seq_len, _ = xs_pad.shapemixed_query_layer = self.query(xs_pad)  # (batch, time, size)mixed_key_layer = self.key(xs_pad)  # (batch, time, size)if mask is not None:mask = mask.eq(0)  # padding is 1, nonpadding is 0# (batch, n_heads, time)query_for_score = (self.query_att(mixed_query_layer).transpose(1, 2)/ self.attention_head_size**0.5)if mask is not None:min_value = float(numpy.finfo(torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype).min)query_for_score = query_for_score.masked_fill(mask, min_value)query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)else:query_weight = torch.softmax(query_for_score, dim=-1)query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)query_layer = self.transpose_for_scores(mixed_query_layer)  # (batch, n_heads, time, attn_dim)
pooled_query = (torch.matmul(query_weight, query_layer).transpose(1, 2).reshape(-1, 1, self.num_attention_heads * self.attention_head_size))  # (batch, 1, size = n_heads * attn_dim)pooled_query = self.dropout(pooled_query)pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
mixed_query_key_layer = (mixed_key_layer * pooled_query_repeat)  # (batch, time, size)# (batch, n_heads, time)query_key_score = (self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5).transpose(1, 2)if mask is not None:min_value = float(numpy.finfo(torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype).min)query_key_score = query_key_score.masked_fill(mask, min_value)query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(mask, 0.0)else:query_key_weight = torch.softmax(query_key_score, dim=-1)query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)key_layer = self.transpose_for_scores(mixed_query_key_layer)  # (batch, n_heads, time, attn_dim)pooled_key = torch.matmul(query_key_weight, key_layer)  # (batch, n_heads, 1, attn_dim)pooled_key = self.dropout(pooled_key)# NOTE: value = query, due to param sharingweighted_value = (pooled_key * query_layer).transpose(1, 2)  # (batch, time, n_heads, attn_dim)weighted_value = weighted_value.reshape(weighted_value.shape[:-2]+ (self.num_attention_heads * self.attention_head_size,))  # (batch, time, size)weighted_value = (self.dropout(self.transform(weighted_value)) + mixed_query_layer)return weighted_value

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_9623.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue23全面知识总结七(2)

感兴趣的朋友可以去我的语雀平台进行查看更多的知识。 https://www.yuque.com/ambition-bcpii/muziteng 7.8 路由的props配置 props作用:让路由组件更方便的收到参数 {name:detail,path:detail/:id,component:Detail,//第一种写法:props值为对象&…

Java内存模型:创建对象在堆区如何分配内存

一、Heap堆区 Heap堆是JVM所管理的内存中最大的一块区域,被所有线程共享的一块内存区域。堆区中存放对象实例和数组,“几乎”所有的对象实例以及数组都在这里分配内存。 新生代、老年代 二、创建对象的内存分配 初始创建对象会在新生代的Eden区生成&…

行为型设计模式之策略模式

行为型设计模式之策略模式策略模式应用场景优缺点主要角色策略模式的基本使用创建抽象策略角色创建具体策略角色创建上下文角色客户端执行策略模式实现支付方式的选择创建抽象策略角色创建具体策略角色创建上下文角色客户端执行策略模式 策略模式(Strategy Pattern)…

线程安全集合:CopyOnWriteArrayList源码分析

目录 一、基本思想 二、源码分析 add()方法 set()方法 remove()方法 get()方法 三、小结 一、基本思想 首先CopyOnWrite 简称 COW ,是一种用于对集合并发访问的优化策略。基本思想是:当我们往一个集合容器中写入元素时(比如添加…

C++左值右值、左值引用右值引用、移动语义move

目录 1.什么是左值、右值 2.什么是左值引用&、右值引用&& 2.1左值引用& 2.2右值引用&& 2.3对左右值引用本质的讨论 2.3.1右值引用有办法指向左值吗? 2.3.2左值引用、右值引用本身是左值还是右值? 2.4 右值引用使用场景…

51单片机学习:静态数码管实验

实验名称:静态数码管实验 接线说明: 实验现象:下载程序后“数码管模块”最左边数码管显示数字0 注意事项: ***************************…

神经体液调节网络,神经网络能干嘛

神经网络的发展趋势如何? 神经网络的云集成模式还不是很成熟,应该有发展潜力,但神经网络有自己的硬伤,不知道能够达到怎样的效果,所以决策支持系统中并不是很热门,但是神经网络无视过程的优点也是无可替代…

CSDN编程竞赛-第六期(上)

CSDN编程竞赛报名地址:https://edu.csdn.net/contest/detail/16 努力是为了让自己不平庸: 前言/背景 四道题都是相关数组的,思路很好想,但是需要熟练使用,不能有小错误。 参赛流程 活动时间:9月8日-21日&a…

Python机器视觉--OpenCV进阶(核心)--图像直方图与掩膜直方图与直方图均衡化

1.图像直方图 1.1 图像直方图的基本概念 在统计学中,直方图是一种对数据分布情况的图形表示,是一种二维统计图表. 图像直方图是用一表示数字图像中亮度分布的直方图,标绘了图像中每个亮度值的像素数。可以借助观察该直方图了解需要如何调整…

记录一次关于Rank()排序函数问题

先来看应用场景吧 就是页面上有个top按钮 根据不同的top 进行筛选 比如我选择top5 那么在下方当前大区的销售额降序筛选出来最高的前五个销售员or客户这种场景 💖 问题 问题1:为什么我的这个rank排序函数 这个华南大区 不是从1开始的呢 其他大区都是正…

java毕业设计选题系统ssm实现的商城系统(电商购物项目)

🍅文末获取联系🍅 一、项目介绍 《ssm实现的商城系统》该项目采用技术:springspringMVCmybaitsEasyUIjQueryAjax等相关技术,项目含有源码、文档、配套开发软件、软件安装教程、项目发布教程等 1.1 课题背景、目的及意义 当今社…

java 同学聚会AA制共享账单系统springboot 小程序022

本系统在一般同学会小程序的基础上增加了首页推送最新信息的功能方便用户快速浏览,是一个高效的、动态的、交互友好的同学会小程序。 用户在首页上会看到各类模块的推送内容,可以以最直接的方式获取信息,注册登陆后,可以对应经费信…

Unity基础笔记(5)—— Unity渲染基础与动画系统

Unity渲染基础与动画系统 Unity渲染基础 一、摄像机 1. 摄像机概念和现实中的摄像机很接近,Unity 中 Camera 组件负责将游戏画面拍摄然后投放到画面上 Camera 拍摄到的画面决定了 Game 面板的画面 创建场景的时候,Unity 会默认创建一个摄像机,所以我们点击 Game 面板才有画面…

【算法刷题】链表篇-链表的回文结构

文章目录题目要求方法1:思路代码方法2代码题目要求 链接:链表的回文结构_牛客题霸_牛客网 (nowcoder.com) 1 -> 2 -> 3 -> 2 -> 1 1 -> 2 -> 2 -> 1 上面两个是回文结构 方法1:思路 1.遍历链表,把结点对应的…

网络安全基础——对称加密算法和非对称加密算法(+CA数字证书)

目录 一、数据传输时的安全特性 二、对称加密算法: 三、非对称加密算法 四、对称加密和非对称加密 — 融合算法: 五、CA数字证书: 一、数据传输时的安全特性 ———————————————————————————————————…

分布式进化算法

1 多解优化问题 多解优化问题是指一类具有多个最优解的复杂优化问题。多峰优化问题和多目标优化问题都是两类典型的多解优化问题,它们之前的统一关系,即都具有多个最优解。多峰优化问题要求算法找到多个具有相同适应度值得最优解,多目标优化问…

SpringBoot的核心原理(扒笔记记录)

这一课的主要重点: 自动装配以及starterJDBC数据库连接池ORM、JPA、MyBatis、Hibernate这样相关的一些技术 从Spring到SpringBoot 我们在工作中都可能用过了SpringBoot,特别是最近几点,Java开发者大军里的一员,我们一般可能上手就…

卷积神经网络相比循环神经网络具有哪些特征

CNN卷积神经网络结构有哪些特点? 局部连接,权值共享,池化操作,多层次结构。 1、局部连接使网络可以提取数据的局部特征;2、权值共享大大降低了网络的训练难度,一个Filter只提取一个特征,在整个…

Docker容器互联

前言: 虽然每个docker容器之间都能通过ip来进行互联,但当容器重新启动,ip就会被重新分配给重新启动的容器,这时同个容器由于重启导致ip不一样了,这时就会导致开发和运维的困难程度大大增加,这时候就要考虑…

springboot+学生信息管理 毕业设计-附源码191219

学生信息管理的设计与实现 摘 要 科技进步的飞速发展引起人们日常生活的巨大变化,电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流,人类发展的历史正进入一个新时代。在现实运用中&…