这篇文章的目标是具体的解释Flash Attention,为什么要解释FlashAttention呢?因为FlashAttention 是一种从新排序注意力计算的算法,它无需任何近似即可减速注意力计算并缩小内存占用。所以作为目前LLM的模型减速它是一个十分好的解决方案,本文介绍经典的V1版本,最新的V2做了其余优化咱们这里临时不介绍。因为V1版的FlashAttention号称能够提速5-10倍,所以咱们来钻研一下它到底是怎么实现的。

介绍

论文的题目是:

“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”

内存的效率与一般注意力相比(序列长度是二次的,O(N²)),FlashAttention是次二次的/线性的N (O(N))。并且它不是注意力机制的近似值(例如,稠密或低秩矩阵近似值办法)-它的输入与“传统”注意力机制雷同。与一般的注意力相比,FlashAttention的注意力是”有感知“的。

它利用底层硬件的内存档次常识(例如gpu,但其余AI加速器也应该工作,我这里应用gpu作为示例)。一些[近似]办法在序列长度上将计算要求升高到线性或近线性,但其中许多办法专一于缩小FLOP,而疏忽内存拜访(IO)的开销。

通过多年的倒退gpu的FLOPS的增长速度始终在以比内存吞吐量(TB/s)更快。内存的瓶颈应该引起器重。FLOPS和内存吞吐量须要紧密结合,因为硬件上的差距,咱们就须要软件层面上的工作进行均衡。

依据计算和内存拜访之间的比率,操作能够分为以下两种:

  • 计算束缚 :矩阵乘法
  • 内存束缚:元素操作(激活,dropout,masking),归并操作(softmax, layer norm,sum等)

在以后的AI加速器(GPU)上是受内存大小限度的。因为它“次要由元素操作组成”,或者更精确地说,注意力的算术密度不是很高。

咱们看看这个图:

能够看到,masking,softmax和dropout是占用大量工夫的操作,而不是矩阵乘法(即便大部分FLOPS是在matmul中)。内存不是一个繁多的工件,它在实质上是分层的,个别的规定是:内存越快,越低廉,容量越小。

咱们在下面说的,FlashAttention的注意力是”有感知“的能够归结为利用SRAM比HBM(高带宽内存)快得多来确保缩小两者之间的通信。

以A100为例:

A100 GPU有40-80GB的高带宽内存(HBM),带宽为1.5-2.0 TB/s,而每108个流处理器有192KB的SRAM,带宽预计在19TB/s左右。

能够看到大小小了很多,然而速度却晋升了10倍,所以如何高效的利用SRAM是提速的要害,让咱们看看规范注意力实现背地的计算:

规范实现如何显示对HW操作形式不大尊重。它基本上将HBM加载/存储操作视为0老本(它不是“io感知”)。

咱们首先思考如何使这个实现更无效(工夫和内存方面)。最简略的办法是删除冗余的HBM读/写。

如何把S写回HBM只是为了(从新)加载它来计算softmax,那么咱们能够将其保留在SRAM中,执行所有两头步骤,而后将最终后果写回HBM。

内核基本上是“GPU操作”的一种奇异的说法(参考咱们以前公布的CUDA入门,往简略了说就是一个函数)。交融则能够将多个操作交融在一起。所以只从HBM加载一次,执行交融的op,而后将后果写回来。这样做能够缩小通信开销。

这里还有一个业余名词术语是“materialization”(物化/实体化)。它指的是,在下面的规范注意力实现中,曾经调配了残缺的NxN矩阵(S, P)。上面咱们将看到如何间接将内存复杂度从O(N²)升高到O(N)。

Flash attention基本上能够归结为两个次要观点:

Tiling (在向前和向后传递时应用)-基本上将NxN softmax/scores矩阵分块成块。

Recomputation (仅在向后传递中应用)

算法如下:

下面咱们提到了很多名词,你可能还不理解。没关系上面咱们开始逐行解释算法。

FlashAttention算法

让Tiling办法的次要阻碍是softmax。因为softmax须要将所有的分数列耦合在一起。

看到分母了吗?这就是问题所在。

要计算输出序列中的特定第i个标记对序列中其余标记的关注水平,须要在SRAM中随时可用所有这些分数(这里用z_j示意)。

然而SRAM的容量是无限的。N(序列长度)能够是1000甚至100000个令牌。所以N²爆炸得很快。所以论文应用了一个技巧:把softmax的计算分成更小的块,最终依然失去完全相同的后果。

咱们能够只获取前一个B分数(x_1到x_B)并为它们计算softmax。而后通过迭代,“收敛”到正确的后果。以一种聪慧的形式组合这些每块局部softmax的数字,这样最终的后果实际上是正确的。办法如下:

基本上,为了计算属于前2个块(大小为B)的分数的softmax,必须要跟踪每个块的2个统计数据:m(x)(最大分数)和l(x) (exp分数总和)。而后就能够用归一化系数将它们无缝地交融在一起。

这里次要是一些根本的代数运算,通过开展f(x)和l(x)项并与e^x相乘一些项会互相对消,这里就不写了。这个逻辑递归地始终继续到最初一个(N/B)块,这样就失去了N维正确的softmax输入!

为了具体的介绍这个算法,假如有一个大小为1的批处理(即单个序列)和单个注意力头,稍后会扩大它(通过简略地跨GPU的并行化-稍后会具体介绍)。咱们临时疏忽了dropout和masking,因为稍后再增加。

咱们开始计算:

初始化:HBM的容量以GB为单位测量(例如RTX 3090有24 GB的VRAM/HBM, A100有40-80 GB等),因而调配Q, K和V不是问题。

第1步

计算行/列块大小。为什么ceil(M / 4 d) ?因为查问、键和值向量是d维的,所以咱们还须要将它们组合成输入的d维向量。所以这个大小基本上容许咱们用q k v和0个向量最大化SRAM的容量。

比如说,假如M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载50个q, k, v, o个向量的块,这样能够缩小HBM/SRAM之间的读/写次数。

对于B_r,我也不太确定他们为什么要用d执行最小运算?如果有人晓得,请评论指教!

第2步:

用全0初始化输入矩阵O。它将作为一个累加器,l也相似它的目标是保留softmax的累积分母——exp分数的总和)。M(保留逐行最大分数)初始化为-inf,因为咱们将对其进行Max运算符,因而无论第一个块的Max是什么-它必定大于-inf 。

第3步:

步骤1中的块大小将Q, K和V分成块。

第4步:

将O, l, m宰割成块(与Q的块大小雷同)。

第5步:

开始跨列循环,即跨键/值向量(上图中的内部循环)。

第6步:

将K_j和V_j块从HBM加载到SRAM。在这个工夫点上咱们依然有50%的SRAM未被占用(专用于Q和O)。所以SRAM是这样的:

第7步:

开始跨行外部循环,即跨查问向量。

第8步:

将Q_i (B_r x d)和O_i (B_r x d)块以及l_i (B_r)和m_i (B_r)加载到SRAM中。

这里须要保障l_i和m_i可能载入SRAM(包含所有两头变量),这块可能是CUDA的常识,我不太确定如何计算,所以如果你有相干的信息,请留言

第9步:

计算Q_i (B_r x d)和K_j转置(d x B_c)之间的点积,失去分数(B_r x B_c)。并没有将整个nxns(分数)矩阵“物化”。

假如内部循环索引为j (j=3),外部循环索引为i (i=2), N为25,块大小为5,上面就是刚刚计算的后果(假如以1为根底的索引):

也就是输出序列中标记11-15的标记6-10的注意力得分。这里的一个要点是,这些都是准确的分数,它们永远不会扭转。

第10步:

应用上一步计算的分数计算m_i_j、li_j和P~i_j。M ~_i_j是按行计算的,找到下面每一行的最大元素。

而后通过利用元素运算失去P~_i_j:

归一化-取行最大值并从行分数中减去它,而后EXP

l~_i_j是矩阵P的逐行和。

第11步:

计算m_new_i和l_new_i。同样非常简单,能够重复使用下面的图表:

M_i蕴含之前所有块的逐行最大值(j=1 & j=2,用绿色示意)。M _i_j蕴含以后块的逐行最大值(用黄色示意)。为了失去m_new_i咱们只须要在m_i_j和m_i之间取一个最大值,l_new_i也相似。

第12步(最重要):

这是算法中最难的局部。

它容许咱们用矩阵的模式做逐行标量乘法。如果你有一列标量s (N)和一个矩阵a (NxN)如果你做diag(s)* a你基本上是在用这些标量做a行的元素乘法。

公式1(为了不便再次粘贴在这里):

第12步的第一项所做的(用绿色下划线)是:更新了在同一行块中以后块之前的块的以后softmax预计。如果j=1(这是这一行的第一个块。

第一项乘以diag(l_i)是为了对消之前迭代中除以的雷同常数(这个常数暗藏在O_i中)。

表达式的第二项(黄色下划线)是不须要消去的,因为能够看到咱们间接将P~_i_j矩阵与V向量块(V_j)相乘。

e^x项是用来批改矩阵P~_i_j & O_i的,办法是消去前一次迭代中的m,用最新的预计(m_new_i)来更新它,该预计蕴含到目前为止逐行最大值。

以下是我的逐渐剖析(实际上只须要5分钟,心愿能有所帮忙!)

重点是这些里面的e项和P/O矩阵外面的e项消掉了,所以总是失去最新的m_new_1预计!

第三次迭代也是相似的,失去了正确的最终后果!

回忆一下:这只是对最终O_i的以后预计。只有在咱们遍历上图中的所有红色块之后,咱们能力最终失去确切的后果。

第13步

将最新的累加到统计数据(l_i & m_i)写回HBM。留神它们的维数是B_r。

第13、14、15、1步

嵌套的for循环完结,O (Nxd)将蕴含最终后果:每个输出令牌的注意力加权值向量!

简略汇总

算法能够很容易地扩大到“block-sparse FlashAttention”,这是一种比FlashAttention快2-4的稠密注意力算法,扩大到64k的序列长度!通过应用一个块模式的掩码矩阵,能够跳过下面嵌套的for循环中的某些加载/存储,这样咱们能够按比例节俭稠密系数,比方下图

当初让咱们简略地讨论一下复杂性。

复杂度剖析

空间:在HBM中调配了Q, K, V, O (Nxd), l和m (N)。等于4Nd + 2*N。去掉常量,并且晓得d也是一个常量并且通常比N小得多(例如d={32,64,128}, N={1024,…,100k}),能够失去O(N)的空间,这有助于扩大到64k序列长度(再加上一些其余“技巧”,比方ALiBi)。

工夫:这里不会严格地进行工夫复杂度剖析,然而咱们将应用一个好的指标:HBM拜访的数量。

论文的解释如下:

他们是怎么失去这个数字的?让咱们来剖析嵌套的for循环:

咱们的块大小是M/4d。这意味着向量被宰割成N/(M/4d)块。取它的2次方(因为要遍历行/列块)失去O(N²d²/ M²)

咱们不能一次获取整个块,如果做一个大O剖析,可能会让咱们认为这并不比规范注意力好多少,但对于典型的数字,这导致拜访次数缩小了9倍(依据下面的论文截图)。

咱们的伪算法集中在一个单头注意力,假如批处理大小为1。上面咱们就开始进行扩大了

多头注意力

要扩大到batch_size > 1和num_heads > 1实际上并不难。

算法基本上是由单个线程块(CUDA编程术语)解决的。这个线程块在单个流多处理器(SM)上执行(例如,A100上有108个这样的处理器)。为了并行化计算,只须要在不同的SMs上并行运行batch_size * num_heads线程块。该数字与零碎上可用的SMs数量越靠近,利用率就越高(现实状况下是多个,因为每个SM能够运行多个线程块)。

反向流传

对于GPU内存的占用,另外一个大头就是反向流传,通过存储输入O (Nxd)和softmax归一化统计数据(N),咱们能够间接从SRAM中的Q, K和V (Nxd)块中反向计算注意力矩阵S (NxN)和P (NxN) !从而使内存放弃在O (N)。这个比拟业余了,咱们理解以下就能够了,所以须要具体的内容请看原论文。

代码实现

最初,让咱们看看在应用flash attention时可能呈现的一些问题。因为波及到显存的操作,所以咱们只能深刻CUDA,然而CUDA又比较复杂。

这就是OpenAI的Triton等我的项目的劣势(参见他们的FlashAttention实现)。Triton基本上是一种DSL(畛域特定语言),介于CUDA和其余畛域特定语言(例如TVM)之间的形象级别。能够编写超级优化的Python代码(一旦编译),而不用间接解决CUDA。这样Python代码能够部署在任意的加速器上(这是Triton工作)。

另外一个好消息是Triton最近曾经与PyTorch 2.0集成了。

另外对于某些用例,比方对于超过1K的序列长度,一些近似留神办法(如Linformer)开始变得更快。然而flash attention的块稠密实现优于所有其余办法。

总结

你有没有想过,对于这种底层优化的算法为什么是一个斯坦福大学的学生公布,而不是NVIDIA的工程师?

我认为有2种可能的解释:

1、FlashAttention更容易/只能在最新的gpu上实现(原始代码库不反对V100)。

2、通常“局外人”是那些以初学者的眼光对待问题,可能看到问题的本源并从根本准则登程解决问题

最初咱们还是要进行个总结

FlashAttention可能让BERT-large训练中节俭15%,将GPT训练速度进步2/3,并且是在不须要批改代码的状况下,这是一个十分重要的提高,它为LLM的钻研又提出了一个新的方向。

论文地址:

https://avoid.overfit.cn/post/9d812b7a909e49e6ad4fb115cc25cdc1

作者:Aleksa Gordić