pytorch 使用 xformers 库 加速多头注意力计算 和 大幅节省显存

news/2024/4/16 15:41:29/文章来源:https://blog.csdn.net/ONE_SIX_MIX/article/details/129062323

效果概览:
好处:使用 google PALM 架构的小模型做 生成任务,改为 xformers 实现后,加速比为 2倍,显存消耗为原来的 1/3 ,非常给力。
缺点:相比pytorch的原生实现,误差略大。。。

xformers 官方github仓库:https://github.com/facebookresearch/xformers
xformers 官方文档:https://facebookresearch.github.io/xformers/
https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops

前两周 xformers 官方提供了 pypi 和 whl 轮包
windows 和 linux 均可用,最低版本要求 pytorch 1.13.1 版本

pip 安装 xformers

pip install -U xformers

如果需要用于编码器或需要位置偏置,则需要安装 0.17 以上版本
当前(2023/2/26) v0.17 为预发行版,需要使用 --pre 来安装

pip install --pre -U xformers

使用方法

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMaskdevice='cuda'
batch = 4
n_head = 8
head_dim = 16
seq_len = 128q = torch.rand(batch, seq_len, n_head, head_dim).to(device)
k = torch.rand(batch, seq_len, n_head, head_dim).to(device)
v = torch.rand(batch, seq_len, n_head, head_dim).to(device)# 使用 causal 掩码
o = memory_efficient_attention(q, k, v, LowerTriangularMask())# 不使用编码
o = memory_efficient_attention(q, k, v)# 使用自定义的 attn_bias,要求 xformers 版本 大于等于 0.17
## 这里的 from_len,to_len 分别代表Decoder的序列长度,Encoder的序列长度
from_len = seq_len
to_len = seq_len
attn_bias = torch.rand(batch, n_head, from_len, to_len).to(device)
o = memory_efficient_attention(q, k, v, attn_bias)

memory_efficient_attention 的大概的 等效pytorch实现
来自 https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops


def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):# q [batch, seq_len, n_head, head_dim]# k [batch, seq_len, n_head, head_dim]# v [batch, seq_len, n_head, head_dim]# attn_bias [batch, n_head, seq_len, seq_len]if scale is None:scale = 1 / query.shape[-1] ** 0.5query = query * scaleattn = query @ key.transpose(-2, -1)if attn_bias is not None:attn = attn + attn_biasattn = attn.softmax(-1)attn = F.dropout(attn, p)return attn @ value

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_74505.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构入门DAY1

力扣刷题合集:力扣刷题_Sunlightʊə的博客-CSDN博客217.存在重复元素相关题目链接:力扣 - 存在重复元素题目重现给你一个整数数组 nums 。如果任一值在数组中出现 至少两次 ,返回 true ;如果数组中每个元素互不相同,返…

大数据框架之Hadoop:MapReduce(三)MapReduce框架原理——ReduceTask工作机制

1、ReduceTask工作机制 ReduceTask工作机制,如下图所示。 (1)Copy阶段:ReduceTask从各个MapTask上远程拷贝一片数据,并针对某一片数据,如果其大小超过一定阈值,则写到磁盘上,否则直…

Active Directory 05 - 初识 AD CS 证书服务

写在最前 如果你是信息安全爱好者,如果你想考一些证书来提升自己的能力,那么欢迎大家来我的 Discord 频道 Northern Bay。邀请链接在这里: https://discord.gg/9XvvuFq9Wb我会提供备考过程中尽可能多的帮助,并分享学习和实践过程…

1029 旧键盘 C++中find函数的使用

题目链接: 一、自己的想法:(弱化版双指针) 思路为用两个“指针”i, j分别指向原来字符串和实际输入字符串的第一个字符,然后判断i,j所指字符是否一致,若是则i, j同时,若否则将i所指…

【5G RRC】5G系统消息SIB3介绍

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

Windows下命令执行绕过技巧总结(渗透测试专用)

一、连接符1、双引号不要求双引号闭合举例:"who"a"mi" //闭合的 "who"a"mi //不闭合的2、圆括号必须在两边,不能包括中间的字符。举例:((whoami))3、^符号(转译符号)不可以在结尾&…

Go项目(商品微服务-1)

文章目录简介建表protohandler商品小结简介 商品微服务主要在于表的设计,建哪些表?表之间的关系是怎样的? 主要代码就是 CURD表和字段的设计是一个比较有挑战性的工作,比较难说清楚,也需要经验的积累,这里…

【机器学习笔记】Python基础笔记

目录基础语法加载数据:pd.read_csv查看数据大小:shape浏览数据行字段:columns浏览少量数据:head()浏览数据概要:describe()输出:to_csv基础功能语法缺省值去除缺失值:dropna按行删除&#xff1a…

Paddle配置

目录: 1.激活环境 2.版本选择 突发情况:ModuleNotFoundError: No module named paddle 检验是否安装成功 1.激活环境 Anaconda: conda remove -n paddle --all conda activate paddle 2.版本选择 打开链接:https://www.pa…

基于企业微信应用消息的每日早安推送

基于企业微信应用消息的每日早安推送 第一步:注册企业微信 企业微信注册地址:https://work.weixin.qq.com/wework_admin/register_wx 按照正常流程填写信息即可,个人也可以注册企业微信,不需要公司 注册完成后,登录…

Google Guice 4:Bindings(2)

4 Scopes (实例的作用域) 4.1 默认规则:unreuse instance 到目前为止,通过bind().to()和Provides定义的binding,每次需要注入实例对象时,Guice都会创建一个新的实例 // 修改DatabaseTransactionLog,使其打…

Ncvicat 打开sql文件方法

Nacicat打开sql文件时,有比较多的文章介绍可以直接打开,方法介绍的比较多,但是我遇到了一个坑,就是如何配置环境都无法打开。 本机环境: windows10 mysql 5.7.40 Navicat12.1 一、遇到问题情况 1.1、通过navicat…

【python量化】大幅提升预测性能,将NSTransformer用于股价预测

写在前面 NSTransformer模型来自NIPS 2022的一篇paper《Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting》。NSTransformer的目的主要是为了解决其他方法出现过平稳化处理的问题。其通过提出序列平稳化以及去平稳化注意力机制可以使得模型…

2023年三月份图形化二级打卡试题

活动时间 从2023年3月1日至3月21日,每天一道编程题。 本次打卡的规则如下: 小朋友每天利用10~15分钟做一道编程题,遇到问题就来群内讨论,我来给大家答疑。 小朋友做完题目后,截图到朋友圈打卡并把打卡的截图发到活动群…

【尚硅谷MySQL入门到高级-宋红康】数据库概述

1、为什么要使用数据库 数据的持久化 2、数据库与数据库管理系统 2.1 数据库的相关概念 2.2 数据库与数据库管理系统的关系 3、 MySQL介绍 MySQL从5.7版本直接跳跃发布了8.0版本 ,可见这是一个令人兴奋的里程碑版本。MySQL 8版本在功能上做了显著的改进与增强&a…

CXL技术分析

CXL,全称Compute Express Link,该技术由Intel牵头开发用于高性能计算、数据中心,主要解决处理器、加速器和内存之间的cache一致性问题,可消除CPU、专用加速器的计算密集型工作负载的传输瓶颈,显著提升系统性能。 一、…

python的装饰器与设计模式中的装饰器模式

相信很多人在初次接触python中的装饰器时,会跟我一样有个疑问,这跟设计模式中的装饰器模式有什么区别吗?本质上是一样的,都是对现有对象,包括函数或者类的一种扩展。这篇文档将进行对比分析。 python的装饰器 装饰器…

duboo+zookeeper分布式架构入门

分布式 dubbo Zookeeper 分布式系统就是若干独立计算机的集合(并且这些计算机之间相互有关联,就像是一台计算机中的C盘F盘等),这些计算对于用户来说就是一个独立的系统。 zookeeper安装 下载地址:Index of /dist/z…

【数据库系统概论】基础知识总结

🌹作者:云小逸 📝个人主页:云小逸的主页 📝Github:云小逸的Github 🤟motto:要敢于一个人默默的面对自己,强大自己才是核心。不要等到什么都没有了,才下定决心去做。种一颗树,最好的时间是十年前…

C++10:非类型模板参数以及模板的特化

目录 非类型模板参数 模板的特化 模板类的特化 1.全特化 2.偏特化 模板其实还有其他的玩法&#xff0c;比如非类型模板参数以及模板的特化。 非类型模板参数 在记述非类型模板参数前&#xff0c;我们认识一下C中一个比较鸡肋的类&#xff0c;array #include<iostream&g…