FlashAttention解析——大预言模型核心组建

论文名称:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文地址:https://arxiv.org/abs/2205.14135        

一、研究FlashAttention的Motivate       

        FlashAttention技术在现在的主流大语言模型中均有应用,其主要作用是减少Transformer结构中运算(主要是self-attention,包括softmax、dropout等)的显存消耗,进一步解除文本处理长度限制,使得模型能够处理更长、更复杂的文本数据 以及 多轮对话功能。

        让我们继续看看原论文中的说法:

        Transformer 作为语言模型的基础架构提供了强大的特征表达能力,已经作为LLM的基础模型构件被大量使用。

        在Transformer中核心组件是多头自注意力(multi-head selft-attention),这里的计算复杂度和空间复杂度是序列长度的二次方O(n2)。因此长文本处理仍然面临挑战。

        当然有许多尝试用于减少注意力的计算和内存开销。例如,稀疏近似和低秩近似得方法,将计算复杂度降低到序列长度的线性或亚线性,但这些方法主要关注FLOPs(浮点数计算次数)的减少(这部分消耗主要由矩阵运算提供),而忽略了IO读写的内存访问开销。

        由下图可以看到,GPT-2中的标准attention,耗时对比:矩阵运算 < softmax < Dropout。在现代GPU中,计算速度超过显存访问速度。基于这样的发现,论文作者将突破‘超长文本处理’的契机放在了注意力的IO瓶颈。论文团队在对GPU硬件和注意力实现进行性能剖析后,将性能瓶颈锁定在‘HBM内存的读写压力过大’,指标论文的主要优化方向为‘降低HBM的IO次数’。

二、标准注意力机制与HBM的访问关系

2.1 标准Attention机制推理过程

        Q,K,V\epsilon R^{N\times D}, Attention(Q, K, V)=softmax(\frac{Q*K^{T}}{\sqrt{d}})V

        将上面的步骤进行拆解可以得到

        S=QK^{T} \epsilon R^{N\times N}P=softmax(S) \epsilon R^{N\times N}O=PV \epsilon R^{N\times d}

        如下图所示,一次标准Attention的实现需要多次读写HBM

        1. 按块从HBM中读取矩阵Q和K,计算S,并将S写入HBM;

        2. 从HBM中读取S,计算完P=softmax(S)之后,将P写入HBM;

        3. 按块从HBM读取中间结果P和V,计算O=PV,将O写入HBM;

        4. 返回O

        注意:笔者不清楚Q和K是同时读取还是分为2次;有相关科普说是分别读取(读两次HBM)

2.2 GPU结构的一些知识

        这里是论文中给出的GPU A-100的内存结构:

        1. HBM(High Bandwidth Memory,高带宽存取存储器)

                由多个DRAM堆叠。

        2. SRAM(Static Random-Access Memory, 静态随机访问存储器)

                用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高。由图中可见SRAM的执行/读写速度是HBM的12.67倍,但存储空间远远小于HBM。

三、FlashAttention

        FlashAttention,总的来说是一种优化访问开销精准注意力算法。

        motivation:从GPU内存结构来看,要想提升Attention的性能,应该让计算过程尽可能在SRAM中进行。由于序列长度 N 可能会很长,无法将Q、K、V 以及中间结果完整存储在SRAM中,因此FlashAttention就采用了‘分块’操作,每块的计算所需内存不超过SRAM大小。

        这里有两种核心操作:tiling(平铺) 、recomputation(重计算) ,最后使用 kernel fusion 进行融合。

        1. tiling:利用更高速的SRAM代替HBM;

        2. recomputation:放弃中间结果写回,需要使用时再次计算,用计算Trade-off访存;

        3. kernel fusion:基于Tiling使用一个kernel完成整个计算。

3.1 tiling 平铺

        tiling 基本思路:不直接对整个输入序列计算注意力,而是根据SRAM大小将其分为多个较小的块,逐个对‘块’进行计算,在计算过程中增量式(详见3.1.2)地进行softmax的逼近。在整个计算过程中只需要更新某些中间变量(如全局最大值,详见3.1.2),不需要计算整个注意力权重矩阵。

        而‘分块’操作的难点在于softmax的计算,softmax计算中分母位置包含所有元素的求和项(该项用于归一化),论文重点描述了softmax的‘分块’计算。

3.1.1 先来看看标准softmax计算流程(无分块)

        这里有两个版本的softmax,softmax(x)应该是我们常见的理论上的实现方式;但是在实际操作中,我们通常使用safe_softmax(x),笔者已经替大家试过了,两者结果一致

# safe softmax / 安全 softmax
def safe_softmax(x):# 防止数值计算时的下溢,先将x中的每个元素减去x中的最大值e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=0)# 一般形式softmax
def softmax(x):e_x_2 = np.exp(x)return e_x_2 / e_x.sum(axis=0)# input = np.array([1, 2, 3, 4])
# output = [0.0320586  0.08714432 0.23688282 0.64391426]

        让我们继续用一个case,推理下safe_softmax的计算:x = [1, 2, 3, 4]

        a. 计算组间最大值,防止计算下溢,m(x) = max(x) = 4

        b. 指数计算,f(x) = [e^{1-m(x)}, e^{2-m(x)}, e^{3-m(x)}, e^{4-m(x)}] = [e^{-3}, e^{-2}, e^{-1}, e^{0}]

        c. 计算softmax分母 / 归一化因子, l(x) = e^{-3} + e^{-2} + e^{-1} + e^{-0}

        d. softmax计算, softmax(x) = \frac{f(x)}{l(x)}

3.1.2 继续看看分块softmax计算流程:举例推理

举例推理:简单起见,这里分为2块计算,x1 = [1, 2],x2 = [3, 4]

        a. 计算第一块内的最大值 m(x1) = max(x1) = 2 = m(x) {记录全局最大值m(x)}

        b. 第一个块内,进行指数计算 f(x1) = [e^{1-m(x1)}, e^{2-m(x1)}] = [e^{-1}, e^{0}]。初始化赋值f(x)=f(x1)

        c. 第一个块内,计算归一化因子 l(x1) = e^{-1} + e^{-0}。注意,这里是中间变量

        d. 开始操作第二个模块,更新到此刻为止的最大值 m(x)=m(x2) = max(m(x1), x2) = 4

                补充:论文提供的伪代码中,使用for循环处理每个块,每一步都会更新最大值。

        e. 第二个块内,进行指数计算  f(x2) = [e^{1-m(x2)}, e^{2-m(x2)}] = [e^{3-4}, e^{4-4}] = [e^{-1}, e^{0}]

        f. 第二个块内,计算归一化因子 l(x2) = e^{-1} + e^{-0}。注意,这里仍然是中间变量。

        g. 柔和两个块的中间结果计算全局f(x) 和 l(x)

f(x) = [e^{m(x1) - m(x)}f(x1), e^{m(x2) - m(x)}f(x2)]

f(x) = [e^{2-4}(e^{-1}, e^{0}), e^{4-4}(e^{-1}, e^{0})] = [e^{-2}(e^{-1}, e^{0}), e^{0}(e^{-1}, e^{0})] 

       l(x) = e^{m(x1) - m(x)} * l(x1) + e^{m(x2) - m(x)} * l(x2)

 l(x) = e^{2-4} * (e^{-1} + e^{-0}) + e^{4-4} * (e^{-1} + e^{-0}) = e^{-3} + e^{-2} + e^{-1} + e^{-0}

        :至此,各位会发现分块计算的f(x)和l(x)到了这一步的结果和不分块计算的结果一致。

3.1.3 补充 + 尚存问题

        tiling 操作在FlashAttention中是一个贯穿正向传播和反向传播的重要策略。它不仅在正向传播中用于分块处理输入矩阵以提高计算效率和减少内存使用,还在反向传播中用于优化内存访问和重新计算必要的中间变量。

3.2 recomputation 重计算

        Recomputation是一种算力换内存的操作,即基于trade-off的思想。在上述分析中重点在于优化访问开销,既然GPU计算时间 小于 HBM读写时间,那么就不存储注意力计算过程中的中间结果,而是在某层反向传播中临时计算梯度更新所需的正向传播的中间状态。

        相对于标准注意力机制从HBM中读取很大的中间注意力矩阵,重新计算尽管增加了额外的计算量FLOPs,但仍能够减少运行时间。由下图可见,虽然增加了FLOPs,但是减少了HBM的读写量,最终耗时性能收益明显。

        注1:在这里(反向传播),仅保存了前向 tiling 过程中的两个统计量 m(x) 和 l(x);

        注2:在正向传播中,变量S、P(见2.1)不会被保存;但是在反向传播中需要计算S、P关于Q、K、V的偏导,然后用于更新权重,在这里是重新计算中间结果S和P。

        注3:在recomputation中同样基于 tiling 平铺的思想重新计算所需的注意力权重矩阵。看到这么一种说法:“recomputation 可以看作是一种基于 tiling 的特殊的 gradient checkpointing”。

3.3 Kernal Fusion

        核心思想是将多个操作融合成一个操作,以此减少HBM的访问次数。tiling 分块计算使得可以用一个Kernal完成注意力的所有操作。        

        例如:在 SRAM 中计算完 𝑆 之后紧接着就通过 𝑆 计算 𝑃 ,这样可以避免在 HBM 和 SRAM 交换 𝑆 。

3.4 不确定的部分

        笔者猜测全流程:从HBM加载输入数据(如完整的Q、K),然后‘分块’加载到SRAM执行计算,在SRAM基于一个Kernal Fusion的概念,将mask、softmax、dropout等计算完整,最后将结果写回HBM。整个流程只有‘两次’读写HBM操作?

        是否是这个样子,各位可以评论区留言。

        但是,看伪代码,for循环不断的从HBM加载数据到SRAM,这一步也需要消耗吧。

        

四、论文伪代码解析

4.1 FlashAttention前向传播

按行数进行代码描述

首先确定SRAM的大小,记M,保证Q、K、V和结果O的分块能够保留在SRAM内;

1. 计算 ‘块数’ or 列大小 Bc

2. 在HBM中初始化输出矩阵O,中间变量l和m,其中m用于记录每一行中行最大值,初始化-inf;

3. 将Q、K、V切块,块数分别为Tr 、Tc、Tc;

4. 将2中初始化的O、l、m切块,块数和Q一样,均为Tr;

5+6. 外层循环,将 Kj、Vj 从HBM加载到SRAM;

7+8. 内层循环,将Qi、Oi、li、mi 从HBM加载到SRAM;

9. 开始注意力机制的计算,计算中间变量 Sij;

10. 计算Sij每一行的最大值,记mij(Sij是一个Br * Bc的矩阵,有Br行);按行开展safe_softmax指数运算得到Pij(约等于第三章中的f(x));计算Pij每一行的和,记Lij(softmax分母);

11. 计算 mi(new)、li(new),这一步类比3.1.2中(d,e,f,g),再更新最大值之后,计算分母累计值;

12. 累加计算注意力(KV部分)更新Oi并写入HBM,供下一轮循环读取;

13. 重新赋值并将当前累积 li、mi 写入HBM;在下一轮中,将作为上一轮的累积结果

补充:GPU内多线程分块读取 + 计算。

作者还将Flash Attention扩展到了块稀疏注意力,产生了一种更优的近似的注意力算法。

4.2 反向传播过程(我要开始偷懒了)

        已知前向过程只将Oi、li、mi 写入了HBM,并没有保存S和P,再根据标准self_attention反向传播计算dQ、dK、dV的公式(如下图,图来自于原论文最后的补充材料),分块计算结果。

        ‘分块’attention 反向传播伪代码如下:

        1~4. 前向过程会保留Q,K,V,O,l,m在HBM中,dO由反向传播计算得到后,按照和前向传播相同的分块模式重新分块;

        5. 初始化dQ,dK,dV为全0矩阵,并按照对等Q,K,V的分割方式分割dQ,dK,dV;

        6~10. 外循环:从HBM中读取K、V 块到SRAM;内循环:读取Q块到SRAM;

        11~20. 根据前向过程重新计算对应的Sij和Pij;按分块矩阵的方式分别计算对应梯度d(Sij)和d(Pij)

        21~end. 累积形式更新dQ、dK、dV

五、总结  

        FlashAttention是通过减少HBM访问开销、以内存换时间等操作优化后的精准Attention,虽然多了很多计算步骤,可能会导致一定的精度损失,但仍然能够保证模型在处理复杂任务时的精确性和可靠性。

        核心收益如下:

        长文本处理能力:更小的内存(显存复杂度从O(N^2)降低到了O(N)) + 更快的推理速度(减少HBM访问),这些特性扩展了文本处理长度限制,C哈她GLM2应用该技术后,将文本可处理长度从2K提升到了32K。大预言模型能够处理更长、更复杂的文本数据。这一改进推进了‘长文本’的处理和模型效果优化。

        增强上下文理解能力:更长的输入,可能会增强对长历史对话的理解能力,确保模型在多轮对话中能够准确捕捉和整合上下文信息。

        灵活的组件:FlashAttention可以应用于各种类型的神经网络,包括卷积神经网络(CNN)、循环神经网络(RNN)和Transformer等。这种灵活性使得FlashAttention能够在多种场景和任务中发挥作用。

        主要缺陷

        硬件依赖:FlashAttention起作用的一部分起因是计算开销 < 访问开销,因此能够起到更好的作用,就比较依赖于内存带宽和计算带宽。

        额外的调度配置:分块、动态规划(累积计算中间结果和最终结果)和缓存机制等方法来优化计算过程,那么在GPU内不同线程之间如何调度、如何分区的配置需要根据任务和数据反复调试,以找到最佳配置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3281503.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Java--多态和抽象类

目录 多态实现多态的条件重写重写和重载静态绑定和动态绑定向上转型和向下转型向上转型向下转型instanceof 抽象类 多态 在Java中多态就是在完成一件事的时候&#xff0c;不同人去完成产生的结果不同 比方说打印&#xff0c;我们就是要打印一个东西&#xff0c;如果我们交给彩色…

性能提升20%,字节跳动HTTPDNS从中心下沉到边缘

摘要&#xff1a;本文介绍了HTTPDNS服务从中心迁移至边缘详细的落地过程。主要内容为&#xff1a; HTTPDNS下沉边缘实践遇到的挑战&#xff0c;包括服务放置、流量调度 HTTPDNS下沉边缘解决方案 从性能、成本出发&#xff0c;谈谈HTTPDNS下沉边缘后的收益 传统的DNS流程中…

微信小程序-获取手机号:HttpClientErrorException: 412 Precondition Failed: [no body]

问题&#xff1a; 412 异常就是你的请求参数获取请求头与服务器的不符&#xff0c;缺少请求体&#xff01; 我的问题&#xff1a; 我这里获取微信手机号的时候突然给我报错142&#xff0c;但是代码用的是原来的代码&#xff0c;换了一个框架就噶了&#xff01; 排查问题&am…

java算法day27

java算法day27 动态规划初步总结509 斐波那契数杨辉三角打家劫舍完全平方数 动态规划初步总结 如果你感觉某个问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。 动态规划的过程就是每一个状态一定是由上一个状态推导出来的&#xff0c;这一点就区分于贪心了。贪心是…

鄂维南院士:人工智能的零数据、小数据、大数据和全数据方法

源自&#xff1a; 中国计算机学会 注&#xff1a;若出现无法显示完全的情况&#xff0c;可 V 搜索“人工智能技术与咨询”查看完整文章 人工智能、大数据、多模态大模型、计算机视觉、自然语言处理、数字孪生、深度强化学习 课程也可加V“人工智能技术与咨询”报名参加学习 致…

android java socket server端 可以不断的连接断开,不断的收发 TCP转发

adb.exe forward tcp:5902 tcp:5902 前面本地5901 转发到 后面设备为5902查看转发 adb forward --list删除所有转发 adb forward --remove-allpublic static final String TAG "Communicate";private static boolean isEnable;private final WebConfig webConfig;//…

四步教你快速解决UE5文件迁移失败❗️

本期作者&#xff1a;尼克 易知微3D引擎技术负责人 不知道大家在用UE5迁移文件时&#xff0c;有没有发现这个问题&#xff1a;如果文件输出的路径选择了非项目路径&#xff0c;那么UE会提示无法迁移。在UE4中&#xff0c;这样做是不存在问题的&#xff0c;只要选择「忽略」就可…

Studying-代码随想录训练营day48| 739. 每日温度、496.下一个更大元素 I、503.下一个更大元素II

第48天&#xff0c;单调栈part01&#xff0c;栈的特殊应用场所&#xff01;编程语言&#xff1a;C 目录 739. 每日温度 496.下一个更大元素 I 503.下一个更大元素II 总结&#xff1a; 739. 每日温度 文档讲解&#xff1a;代码随想录每日温度 视频讲解&#xff1a;手撕每日…

AI识别智能称重-收银系统源码

系统概况 专门为零售行业的连锁店量身打造的收银系统&#xff0c;适用于常规超市、生鲜超市、水果店、便利店、零食专卖店、服装店、母婴用品、农贸市场等类型的门店使用。同时线上线下数据打通&#xff0c;线下收银的数据与小程序私域商城中的数据完全同步&#xff0c;如商品…

什么是数据血缘?怎么做好数据血缘分析?

目录 一、什么是数据血缘&#xff1f; 二、数据血缘关系的四大特征 三、数据血缘分析怎么做&#xff1f; 1.定义元数据模型 2.收集元数据 3.建立血缘关系模型 4.追踪数据流动 5.可视化分析 6.集成到数据治理中 7.持续更新和维护 8.应用分析结果 四、数据血缘技术趋势 1.通用的血…

测试环境领域到测试环境产品

作者&#xff1a;攻心 去年之前&#xff0c;阿里巴巴的淘天集团测试环境是以领域方式运作&#xff1a;不局限测试环境治理本身&#xff0c;从测试模式方法论及用好测试环境思路引领集团测试环境治理。领域运作最难的是“统一思想”。业务进一步细分调整后&#xff0c;测试环境治…

修改所属用户/用户组——chown

目录 &#xff08;1&#xff09;修改所属用户 &#xff08;2&#xff09;修改所属用户组 &#xff08;3&#xff09;修改所属用户和用户组 &#xff08;4&#xff09; 选项 -R 使用 chown 可以修改文件/文件夹的所属用户&#xff0c;所属用户组&#xff1b; 当然与 chmod …

数字人直播系统搭建能力评测!3招教你快速摸清源码厂商的真实实力?

随着数字人直播的应用场景不断拓展和应用频率的持续升高&#xff0c;其所蕴含着的市场前景和收益潜力逐渐显现&#xff0c;连带着数字人直播系统搭建的热度也迎来了新的高潮。在此背景下&#xff0c;作为非科班和研发资源有限的创业者们主要的入局途径&#xff0c;各大数字人源…

C++原创系列创斯人工智能Trons10.0.135.7911最新概念版本预告及思路总结

这次更新删掉了以前的所有代码&#xff0c;重新编写&#xff0c;只因我有了新的思路&#xff0c;以前的思路太过于原始&#xff0c;我的思路中的聊天功能如下 这只是聊天函数的原理&#xff0c;聊天函数对一句话的回答有5个到10个&#xff0c;在主函数中多次运行这个函数&#…

ruoyi vue3版本web端隐藏侧边栏及其顶部导航栏

做项目时有个需求是在web端里面嵌入一个页面全屏的大屏&#xff0c;但若依web自带的侧边栏导航和顶部导航一时还不知道怎么隐藏起来&#xff0c;于是在网上到处查找资料&#xff0c;终于&#xff0c;还是在若依的gitee文档中发现了线索 怎么隐藏侧边栏和顶部导航栏实现完全的全…

<数据集>工程机械识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;6338张 标注数量(xml文件个数)&#xff1a;6338 标注数量(txt文件个数)&#xff1a;6338 标注类别数&#xff1a;7 标注类别名称&#xff1a;[Excavator, Loader, Dumb_truck, Mobile_crane, Roller, Bull_dozer, …

微信小程序之使用智能对话服务,客服回复的跳转小程序指定页面链接无效

在微信小程序中使用了微信智能对话服务&#xff0c;客服回复的是小程序指定页面的链接&#xff0c;无法正确跳转&#xff0c;而是返回到进入客服时的页面去了 解决方案&#xff1a; 需在小程序的客服组件 button 上添加 bindcontact 监听事件即可 <movable-area class"…

【ROS 最简单教程 007/300】ROS 架构 - 目录解析 增删改查 计算图

⭐ 工作空间目录解析如下 &#xff1a; WorkSpace --- 自定义的工作空间|--- build:编译空间&#xff0c;用于存放 CMake 和 catkin的 缓存信息、配置信息和其他中间文件|--- devel:开发空间&#xff0c;用于存放编译后生成的目标文件&#xff0c;包括头文件、动态&静态链接…

MySQL基础练习题14-产品销售分析1

题目&#xff1a;获取 Sales 表中所有 sale_id 对应的 product_name 以及该产品的所有 year 和 price 。 准备数据 分析数据 题目&#xff1a;获取 Sales 表中所有 sale_id 对应的 product_name 以及该产品的所有 year 和 price 。 准备数据 ## 创建库 create database db;…

DNS查询服务器的基本流程以及https的加密过程

DNS查询服务器的基本流程&#xff0c;能画出图更好&#xff0c;并说明为什么DNS查询为什么不直接从单一服务器查询ip&#xff0c;而是要经过多次查询&#xff0c;多次查询不会增加开销么&#xff08;即DNS多级查询的优点&#xff09;&#xff1f; 用户发起请求&#xff1a;用户…