FLASH:一种高效的Transformer设计

news/2024/5/18 18:00:57/文章来源:https://blog.csdn.net/shawroad88/article/details/126617904

背景

近年来,Transformer凭借其优秀的设计,在文本、图像、语音等方向大杀四方。但是由于其attention的二次复杂度限制了其在长序列上的应用。本文提出了一种快(速度快)、省(省显存)的模型FLASH(Fast Linear Attention with a Single Head),在长序列的表现远远高于标准的Transformer。

模型介绍

GAU(Gated Attention Unit)

在标准的Transformer结构中,多头注意力和FFN是交替连接的。GLU那篇论文中,将FFN替换成基于门控的线性单元,发现效果会变好。因此,我们先简单了解一下门控单元GLU的计算,如下左图:

具体计算:

也就是将输入X分别经过放射变换(线性映射+激活函数)得到U,VU,VU,V。然后再将U,VU,VU,V进行点积,最后再进行线性映射,得到门控线性单元的输出。

上述的GLU中没有对token两两进行注意力计算,如果在上面的U,VU,VU,V中引入注意力,那岂不是就省了前面的多头注意力计算了。如下式:

U,VU,VU,V进行点积计算的时候,如果给VVV乘一个注意力矩阵AAA(维度为nxn,其中n为序列的长度),那岂不是就引入了注意力信息。

基于此,本文就提出了一种新的结构GAU。主要是给出了一种注意力矩阵A的计算方法。具体计算如下:

对输入X进行放射变换(线性映射+激活函数)得到Z,然后对Z分别进行Q,K\mathcal{Q},\mathcal{K}Q,K变换,就是对Z中的每一个标量进行平移等运算。 这里的Q,K\mathcal{Q},\mathcal{K}Q,K变化类似于LayerNorm中的α,β\alpha, \betaα,β,是可训练的。 然后将两种变换的结果进行矩阵乘。最后再经过一个relu2relu^2relu2的激活函数(relu2relu^2relu2是将relu的计算结果平方),得到最终的注意力矩阵。然后和门控注意力单元中的V进行相乘,这样就在门控注意力单元中引入了注意力信息。具体结果如下图所示:

注意:可能是GLU对attention的依赖没有那么强,因此,作者在实验中只用了一个注意力头。

在GAU的实验中,作者固定e=2d,那"n层Attention+n层FFN"的标准Transformer模型,对应的就是"2n层GAU"的新模型,即该模型为FLASH-Quad。其中Quad表明复杂度依然是二次的。即:FLASH的二次复杂度版本。

Fast Linear Attention with GAU

可能有读者发现,在上述的GAU中,你只是将attention和FFN合并起来,替代了标准的attention+FNN,并没有解决attention的二次复杂度呀?对,因此,作者提出了一种快速计算注意力的方法。

过去,在解决注意力的二次复杂度问题上,有两种主流方法: (1)将注意力计算稀疏化、(2)将注意力计算线性化。稀疏化即人为根据先验知识规定哪些token可以进行注意力计算(典型代表: Longformer、BigBird等)。线性化则是提出另外的方法,去逼近标准注意力的效果(典型代表: Linformer、Performer等),如下公式所示:

正常的注意力是将Q,KQ,KQ,K进行矩阵乘,接着经过softmax,最后乘V。如果将K和V先进行乘,则可以大大减少计算量。假设Q,K,VQ,K,VQ,K,V的维度为:(m,d)(m,d)(m,d),则标准注意力的计算量为:
m∗d∗m+m∗d∗mm*d*m+m*d*mmdm+mdm,即: 2dm22dm^22dm2,是跟序列长度m成平方正比。如果先算K乘V,则计算量为:d∗m∗d+d∗m∗dd*m*d+d*m*ddmd+dmd,即:2md22md^22md2,是跟序列长度m成一次正比。所以,第二种方法随着序列的边长,效率会远高于第一种方法。

本文则是根据上述两种方式,结合"稀疏化"和"线性化"的优点,提出了一种"局部+全局"的分块混合的注意力计算方法。

首先是分块注意力的计算,假设序列长度为n,每个块的维度为c,则可分成n/c个块(默认可整除)。Ug,Vg∈Rc×e,Zg∈Rc×s\boldsymbol{U}_g,\boldsymbol{V}_g\in\mathbb{R}^{c\times e},\boldsymbol{Z}_g\in\mathbb{R}^{c\times s}Ug,VgRc×e,ZgRc×s,其中g指的是第g个块。将Zg\boldsymbol{Z}_gZg通过四个放射变换(线性映射+激活)分别得到Qgquad,Kgquad,Qglin,Kglin\boldsymbol{Q}_g^{\text{quad}},\boldsymbol{K}_g^{\text{quad}},\boldsymbol{Q}_g^{\text{lin}},\boldsymbol{K}_g^{\text{lin}}Qgquad,Kgquad,Qglin,Kglin。则块内注意力计算如下:

可以看出上述公式很好理解,不做过多描述。接下来算算其复杂度。每个块内注意力计算复杂度为c2c^2c2,有n/c个块,则块注意力计算整体的复杂度为:(n/c)∗c2(n/c) * c^2(n/c)c2,即nc,也就是正比于n。

接着用Qglin,Kglin\boldsymbol{Q}_g^{\text{lin}},\boldsymbol{K}_g^{\text{lin}}Qglin,Kglin进行全局的attention计算。这里采用的是上述介绍的注意力线性化方法。计算如下:


上述(7)式就是全局的两两计算。(8)式主要是在生成任务中,对标的是带有mask的多头注意力。

最后将两种attention结果整合到GAU中,得到线性版的GAU网络,计算如下:

作者在论文中还贴出来注意力计算的代码,如下所示:

至此,论文就介绍完了,下面简单看一下实验结果。

实验

首先,作者对比了GAU、多头注意力+FFN、以及多头注意力+GLU三种结构,在自回归任务和MLM任务上的表现,如下图:

横轴为速度,纵轴为效果,越靠右上,效果越好。上述实验是在长度为512上的效果对比。可以发现,GAU的在相同效果的前提下,速度更快;在相同速度的前提下,效果更好。

从上表中可以看出,虽然FLASH-Quad也是二次复杂度,但是也比标准的Transformer效果好,速度也更快。另外,随着序列的逐渐变长,FLASH的速度远远快于标准的Transformer。

最后看看消融实验,如下图所示:

上面MF-TFM++模型采用的是多头注意力+FFN的结构,只是多头注意力采用的线性注意力+块注意力。也就是本文提出的注意力计算。从消融实验中,可以看出一个很有用的信息,即localOnly attention比GlobalOnly attention更重要。

本文参考:
[1] 论文: https://arxiv.org/abs/2202.10447
[2] GLU: https://arxiv.org/abs/2002.05202
[3] 苏剑林. (Feb. 25, 2022). 《FLASH:可能是近来最有意思的高效Transformer设计 》[Blog post]. Retrieved from https://kexue.fm/archives/8934

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_3398.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot 和 Vue前后端分离在线工具项目实战,源码+超详细讲解

一、前言 主要通过SpringBoot和Vue来实现一个前后端分离的在线工具平台,包含PDF转换、图片处理、文本处理、图表展示、二维码工具等功能。 为了更直观展示项目效果,也给大家提供了在线体验地址:http://49.234.28.149, 源码资源见文末。 通过…

org.apache.ibatis.binding.BindingException: Invalid bound statement (not found):

org.apache.ibatis.binding.BindingException: Invalid bound statement (not found): 无效的绑定语句(未找到),就是写的sql 方法找不到sql。解决: 1 namespace 指向是否正确 路径与引用的方法的路径保持一致a.namespace 没有指向Dao b. id ,方法名没有对应上2 引用的方法…

记录Kettle连不上mysql8

如图所示,mysql升级到8了。 在很早之前,我一直用的是Mysql 5的驱动包去连接数据库,今天发现突然连接不上了,想了一下,应该是我以前升级mysql后的原因,换了mysql8的驱动后依旧没个卵用。 报错如下&#xff…

远程Debug远端服务器JVM配置

远程调试非本机的Java进程 远端Java进程启动的JVM参数 注意:以下配置尽量不要在线上生产环境开启,或者 JDK4: -Xdebug -Xrunjdwp:transportdt_socket,servery,suspendn,address{port} JDK5-JDK8: -agentlib:jdwptransportdt_socket,servery,suspen…

Python——LeetCode刷题——【383. 赎金信】

题目描述: 解题思路: 用字典记录字符串magazine中每个字符出现的次数。然后看看字典中magazine的各个字符的出现次数是否“够”字符串ransomNote中各个字符出现的次数。如果够,return True。如果存在有点字符不够,return False。…

学习:Python进阶 冒泡排序

#原理 列表每两个相邻的数,如果前面的数比后面的数大,则交换这两个数 一趟排序完成后,则无序曲减少一个数,有序区增加一个数 每循环一趟,从无序区冒出来一个最大的数,放入有序区,最终得到一个升序的列表

认真研究ConcurrentHashMap中的元素统计策略

这里我们想研究的是jdk1.8中ConcurrentHashMap的addCount(long x, int check)方法。如下所示在put方法的最后会触发addCount(long x, int check)方法进行元素个数的统计。 我们再回顾一下另一个参数binCount : 在操作链表的分支if (fh > 0)中 用于统计put前链表…

TinyRenderer学习笔记--Lesson 3、4

Lesson 3 zbuffer 无论怎样,生活中的显示器基本上都是平面,是一个2D的场景,而我们的模型却是3D的,是有深度的,实际上我们看见的都只是离我们的眼睛最近的那一个平面,一个不透明的3D物体的内部和背面是我们…

河北稳控科技使用标准信号检测 VM振弦采集模块测量精度

河北稳控科技使用标准信号检测 VM振弦采集模块测量精度(一) (1)电源1.1VDD 引脚电源必须使用 LDO 稳压或者低纹波线性电源, LDO 推荐使用 AM1117_3.3V 芯片,测试时发现 SPX 生产的 LDO会造成非常严重的干扰(其它品牌应该也会有类似的问题)。1.2VSEN 引脚电源单通道模块…

阿里、滴滴、华为等一线互联网分布式消息中间件:RocketMQ核心笔记

本篇介绍了RocketMQ的基本使用方法及其各个组件的基本原理,讲解原理时,都是采用先整体架构后详细分解的方式。详细分解时不会深入源码逐段讲,而是从代码结构出发梳理整个运行过程。 这份RocketMQ分布式消息中间件—核心原理与最佳实践的完整…

Android Studio应用基础,手把手教你从入门到精通(小白学习)总结2 之 常用界面布局和ListView

总结1链接: (156条消息) Android Studio应用基础,手把手教你从入门到精通(小白学习)总结1_好喜欢吃红柚子的博客-CSDN博客 学习视频链接: (学完必会)Android studio基础,从入门到…

尚好房 07_前端房源展示

尚好房&#xff1a;前端房源展示 一、分页显示房源列表 1、效果 2、项目搭建 2.1 创建项目 在web项目中创建子工程web-front 2.2 pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0&…

stm32学习(二)|ADC电压采集DMA

利用ADC通道采集外部传感器数值,ADC通道选择依据实际查询芯片手册可得,相关配置利用Cubemx完成。 ADC参数配置首先选择需要使用的ADC通道,并设置对应的引脚ADC_IN0X.ADC参数设置(Paremeter setting)Mode : Independent mode,只使用一个ADC通道 Clock Prescaler,Resolut…

OpenGL 反色

目录 一.OpenGL 反色 1.IOS Object-C 版本2.Windows OpenGL ES 版本3.Windows OpenGL 版本 二.OpenGL 反色 GLSL Shader三.猜你喜欢 零基础 OpenGL ES 学习路线推荐 : OpenGL ES 学习目录 >> OpenGL ES 基础 零基础 OpenGL ES 学习路线推荐 : OpenGL ES 学习目录 >&…

Windows OpenGL ES 图像反色

目录 一.OpenGL ES 图像反色 1.原始图片2.效果演示 二.OpenGL ES 图像反色源码下载三.猜你喜欢 零基础 OpenGL ES 学习路线推荐 : OpenGL ES 学习目录 >> OpenGL ES 基础 零基础 OpenGL ES 学习路线推荐 : OpenGL ES 学习目录 >> OpenGL ES 特效 零基础 OpenGL E…

责任链模式

1、责任链模式是什么 行为模式&#xff0c;一个对象产生的消息会被另外的对象处理。对象发出消息后&#xff0c;不管被哪种、多少个其他对象收到和处理消息。【客户端和handler解耦】 2、为什么使用 如果不使用责任链&#xff0c;则client要知道有多少个handler、什么情况调…

2.IP子网划分

IP子网划分地址分类网络位与主机位一个网段可以容纳多少IPIP地址&#xff1a;互联网中计算机的‘身份证号’&#xff0c;唯一标识一台网络设备的身份ID NAT技术&#xff1a;网络地址转换&#xff0c;节约公网IP 例: IP地址 192.168.1.1 192.168.1 …

电商数仓项目中各层的表

ODS operation Data store 操作数据存储 DWD Data Warehouse detail 细节数据层, DIM Dimension---------------范围&#xff0c;维度 DWS Data Warehouse Summary 数据库汇总 ADS Application Data Service 应用数据服务层 【电商数仓每一层的表】 【ODS层】 operation Data s…

Spring之AOP思想

目录 什么是AOP ​​​为什么用AOP Spring AOP 应该怎么学习呢 AOP下的一些核心概念&#xff08;SpringAOP并没有实现所有的概念&#xff09; 基于概念的使用Spring的AOP 一个使用的实例 关于切点的匹配 通知的种类 使用注解的方式来实现功能​编辑 AOP框架背后的核心 …

TypeScript 小结

TypeScript 是什么&#xff1f; TypeScript 是由微软开发的一种自由和开源的编程语言。它是 JavaScript 的一个超集&#xff0c;本质上是在 JavaScript 的基础上添加了可选的静态类型和基于类的面向对象编程。 TypeScript 和 JavaScript 的区别&#xff1f; TypeScript 的安装…