关于人工智能:FlashAttention算法详解

12次阅读

共计 5688 个字符,预计需要花费 15 分钟才能阅读完成。

这篇文章的目标是具体的解释 Flash Attention,为什么要解释 FlashAttention 呢?因为 FlashAttention 是一种从新排序注意力计算的算法,它无需任何近似即可减速注意力计算并缩小内存占用。所以作为目前 LLM 的模型减速它是一个十分好的解决方案,本文介绍经典的 V1 版本,最新的 V2 做了其余优化咱们这里临时不介绍。因为 V1 版的 FlashAttention 号称能够提速 5 -10 倍,所以咱们来钻研一下它到底是怎么实现的。

介绍

论文的题目是:

“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”

内存的效率与一般注意力相比(序列长度是二次的,O(N²)),FlashAttention 是次二次的 / 线性的 N (O(N))。并且它不是注意力机制的近似值(例如,稠密或低秩矩阵近似值办法)- 它的输入与“传统”注意力机制雷同。与一般的注意力相比,FlashAttention 的注意力是”有感知“的。

它利用底层硬件的内存档次常识 (例如 gpu,但其余 AI 加速器也应该工作,我这里应用 gpu 作为示例)。一些[近似] 办法在序列长度上将计算要求升高到线性或近线性,但其中许多办法专一于缩小 FLOP,而疏忽内存拜访 (IO) 的开销。

通过多年的倒退 gpu 的 FLOPS 的增长速度始终在以比内存吞吐量 (TB/s) 更快。内存的瓶颈应该引起器重。FLOPS 和内存吞吐量须要紧密结合,因为硬件上的差距,咱们就须要软件层面上的工作进行均衡。

依据计算和内存拜访之间的比率,操作能够分为以下两种:

  • 计算束缚:矩阵乘法
  • 内存束缚: 元素操作(激活,dropout,masking),归并操作(softmax,layer norm,sum 等)

在以后的 AI 加速器(GPU)上是受内存大小限度的。因为它“次要由元素操作组成”,或者更精确地说,注意力的算术密度不是很高。

咱们看看这个图:

能够看到,masking,softmax 和 dropout 是占用大量工夫的操作,而不是矩阵乘法(即便大部分 FLOPS 是在 matmul 中)。内存不是一个繁多的工件,它在实质上是分层的,个别的规定是: 内存越快,越低廉,容量越小。

咱们在下面说的,FlashAttention 的注意力是”有感知“的能够归结为利用 SRAM 比 HBM(高带宽内存)快得多来确保缩小两者之间的通信。

以 A100 为例:

A100 GPU 有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0 TB/s,而每 108 个流处理器有 192KB 的 SRAM,带宽预计在 19TB/ s 左右。

能够看到大小小了很多,然而速度却晋升了 10 倍,所以如何高效的利用 SRAM 是提速的要害,让咱们看看规范注意力实现背地的计算:

规范实现如何显示对 HW 操作形式不大尊重。它基本上将 HBM 加载 / 存储操作视为 0 老本(它不是“io 感知”)。

咱们首先思考如何使这个实现更无效(工夫和内存方面)。最简略的办法是删除冗余的 HBM 读 / 写。

如何把 S 写回 HBM 只是为了 (从新) 加载它来计算 softmax,那么咱们能够将其保留在 SRAM 中,执行所有两头步骤,而后将最终后果写回 HBM。

内核基本上是“GPU 操作”的一种奇异的说法(参考咱们以前公布的 CUDA 入门,往简略了说就是一个函数)。交融则能够将多个操作交融在一起。所以只从 HBM 加载一次,执行交融的 op,而后将后果写回来。这样做能够缩小通信开销。

这里还有一个业余名词术语是“materialization”(物化 / 实体化)。它指的是,在下面的规范注意力实现中,曾经调配了残缺的 NxN 矩阵 (S, P)。上面咱们将看到如何间接将内存复杂度从 O(N²) 升高到 O(N)。

Flash attention 基本上能够归结为两个次要观点:

Tiling (在向前和向后传递时应用)- 基本上将 NxN softmax/scores 矩阵分块成块。

Recomputation (仅在向后传递中应用)

算法如下:

下面咱们提到了很多名词,你可能还不理解。没关系上面咱们开始逐行解释算法。

FlashAttention 算法

让 Tiling 办法的次要阻碍是 softmax。因为 softmax 须要将所有的分数列耦合在一起。

看到分母了吗? 这就是问题所在。

要计算输出序列中的特定第 i 个标记对序列中其余标记的关注水平,须要在 SRAM 中随时可用所有这些分数(这里用 z_j 示意)。

然而 SRAM 的容量是无限的。N(序列长度)能够是 1000 甚至 100000 个令牌。所以 N²爆炸得很快。所以论文应用了一个技巧:把 softmax 的计算分成更小的块,最终依然失去完全相同的后果。

咱们能够只获取前一个 B 分数 (x_1 到 x_B) 并为它们计算 softmax。而后通过迭代,“收敛”到正确的后果。以一种聪慧的形式组合这些每块局部 softmax 的数字,这样最终的后果实际上是正确的。办法如下:

基本上,为了计算属于前 2 个块 (大小为 B) 的分数的 softmax,必须要跟踪每个块的 2 个统计数据:m(x)(最大分数)和 l(x) (exp 分数总和)。而后就能够用归一化系数将它们无缝地交融在一起。

这里次要是一些根本的代数运算,通过开展 f(x)和 l(x)项并与 e^x 相乘一些项会互相对消,这里就不写了。这个逻辑递归地始终继续到最初一个 (N/B) 块,这样就失去了 N 维正确的 softmax 输入!

为了具体的介绍这个算法,假如有一个大小为 1 的批处理 (即单个序列) 和单个注意力头,稍后会扩大它(通过简略地跨 GPU 的并行化 - 稍后会具体介绍)。咱们临时疏忽了 dropout 和 masking,因为稍后再增加。

咱们开始计算:

初始化:HBM 的容量以 GB 为单位测量(例如 RTX 3090 有 24 GB 的 VRAM/HBM, A100 有 40-80 GB 等),因而调配 Q, K 和 V 不是问题。

第 1 步

计算行 / 列块大小。为什么 ceil(M / 4 d) ? 因为查问、键和值向量是 d 维的,所以咱们还须要将它们组合成输入的 d 维向量。所以这个大小基本上容许咱们用 q k v 和 0 个向量最大化 SRAM 的容量。

比如说,假如 M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载 50 个 q, k, v, o 个向量的块,这样能够缩小 HBM/SRAM 之间的读 / 写次数。

对于 B_r,我也不太确定他们为什么要用 d 执行最小运算? 如果有人晓得,请评论指教!

第 2 步:

用全 0 初始化输入矩阵 O。它将作为一个累加器,l 也相似它的目标是保留 softmax 的累积分母——exp 分数的总和)。M(保留逐行最大分数)初始化为 -inf,因为咱们将对其进行 Max 运算符,因而无论第一个块的 Max 是什么 - 它必定大于 -inf。

第 3 步:

步骤 1 中的块大小将 Q, K 和 V 分成块。

第 4 步:

将 O, l, m 宰割成块(与 Q 的块大小雷同)。

第 5 步:

开始跨列循环,即跨键 / 值向量(上图中的内部循环)。

第 6 步:

将 K_j 和 V_j 块从 HBM 加载到 SRAM。在这个工夫点上咱们依然有 50% 的 SRAM 未被占用(专用于 Q 和 O)。所以 SRAM 是这样的:

第 7 步:

开始跨行外部循环,即跨查问向量。

第 8 步:

将 Q_i (B_r x d)和 O_i (B_r x d)块以及 l_i (B_r)和 m_i (B_r)加载到 SRAM 中。

这里须要保障 l_i 和 m_i 可能载入 SRAM(包含所有两头变量),这块可能是 CUDA 的常识,我不太确定如何计算,所以如果你有相干的信息,请留言

第 9 步:

计算 Q_i (B_r x d)和 K_j 转置 (d x B_c) 之间的点积,失去分数 (B_r x B_c)。并没有将整个 nxns(分数) 矩阵“物化”。

假如内部循环索引为 j (j=3),外部循环索引为 i (i=2),N 为 25,块大小为 5,上面就是刚刚计算的后果(假如以 1 为根底的索引):

也就是输出序列中标记 11-15 的标记 6 -10 的注意力得分。这里的一个要点是,这些都是准确的分数,它们永远不会扭转。

第 10 步:

应用上一步计算的分数计算 m_i_j、li_j 和 P~i_j。M ~_i_j 是按行计算的,找到下面每一行的最大元素。

而后通过利用元素运算失去 P~_i_j:

归一化 - 取行最大值并从行分数中减去它,而后 EXP

l~_i_j 是矩阵 P 的逐行和。

第 11 步:

计算 m_new_i 和 l_new_i。同样非常简单,能够重复使用下面的图表:

M_i 蕴含之前所有块的逐行最大值(j=1 & j=2,用绿色示意)。M _i_j 蕴含以后块的逐行最大值(用黄色示意)。为了失去 m_new_i 咱们只须要在 m_i_j 和 m_i 之间取一个最大值,l_new_i 也相似。

第 12 步(最重要):

这是算法中最难的局部。

它容许咱们用矩阵的模式做逐行标量乘法。如果你有一列标量 s (N)和一个矩阵 a (NxN)如果你做 diag(s)* a 你基本上是在用这些标量做 a 行的元素乘法。

公式 1(为了不便再次粘贴在这里):

第 12 步的第一项所做的 (用绿色下划线) 是: 更新了在同一行块中以后块之前的块的以后 softmax 预计。如果 j =1(这是这一行的第一个块。

第一项乘以 diag(l_i)是为了对消之前迭代中除以的雷同常数(这个常数暗藏在 O_i 中)。

表达式的第二项 (黄色下划线) 是不须要消去的,因为能够看到咱们间接将 P~_i_j 矩阵与 V 向量块 (V_j) 相乘。

e^x 项是用来批改矩阵 P~_i_j & O_i 的,办法是消去前一次迭代中的 m,用最新的预计 (m_new_i) 来更新它,该预计蕴含到目前为止逐行最大值。

以下是我的逐渐剖析(实际上只须要 5 分钟,心愿能有所帮忙!)

重点是这些里面的 e 项和 P / O 矩阵外面的 e 项消掉了,所以总是失去最新的 m_new_1 预计!

第三次迭代也是相似的,失去了正确的最终后果!

回忆一下: 这只是对最终 O_i 的以后预计。只有在咱们遍历上图中的所有红色块之后,咱们能力最终失去确切的后果。

第 13 步

将最新的累加到统计数据 (l_i & m_i) 写回 HBM。留神它们的维数是 B_r。

第 13、14、15、1 步

嵌套的 for 循环完结,O (Nxd)将蕴含最终后果: 每个输出令牌的注意力加权值向量!

简略汇总

算法能够很容易地扩大到“block-sparse FlashAttention”,这是一种比 FlashAttention 快 2 - 4 的稠密注意力算法,扩大到 64k 的序列长度! 通过应用一个块模式的掩码矩阵,能够跳过下面嵌套的 for 循环中的某些加载 / 存储,这样咱们能够按比例节俭稠密系数,比方下图

当初让咱们简略地讨论一下复杂性。

复杂度剖析

空间: 在 HBM 中调配了 Q, K, V, O (Nxd),l 和 m (N)。等于 4 Nd + 2*N。去掉常量,并且晓得 d 也是一个常量并且通常比 N 小得多 (例如 d ={32,64,128},N={1024,…,100k}),能够失去 O(N) 的空间,这有助于扩大到 64k 序列长度(再加上一些其余“技巧”,比方 ALiBi)。

工夫: 这里不会严格地进行工夫复杂度剖析,然而咱们将应用一个好的指标:HBM 拜访的数量。

论文的解释如下:

他们是怎么失去这个数字的? 让咱们来剖析嵌套的 for 循环:

咱们的块大小是 M /4d。这意味着向量被宰割成 N /(M/4d)块。取它的 2 次方 (因为要遍历行 / 列块) 失去 O(N²d²/ M²)

咱们不能一次获取整个块,如果做一个大 O 剖析,可能会让咱们认为这并不比规范注意力好多少,但对于典型的数字,这导致拜访次数缩小了 9 倍(依据下面的论文截图)。

咱们的伪算法集中在一个单头注意力,假如批处理大小为 1。上面咱们就开始进行扩大了

多头注意力

要扩大到 batch_size > 1 和 num_heads > 1 实际上并不难。

算法基本上是由单个线程块 (CUDA 编程术语) 解决的。这个线程块在单个流多处理器 (SM) 上执行(例如,A100 上有 108 个这样的处理器)。为了并行化计算,只须要在不同的 SMs 上并行运行 batch_size * num_heads 线程块。该数字与零碎上可用的 SMs 数量越靠近,利用率就越高(现实状况下是多个,因为每个 SM 能够运行多个线程块)。

反向流传

对于 GPU 内存的占用,另外一个大头就是反向流传,通过存储输入 O (Nxd)和 softmax 归一化统计数据 (N),咱们能够间接从 SRAM 中的 Q, K 和 V (Nxd) 块中反向计算注意力矩阵 S (NxN)和 P (NxN) ! 从而使内存放弃在 O (N)。这个比拟业余了,咱们理解以下就能够了,所以须要具体的内容请看原论文。

代码实现

最初,让咱们看看在应用 flash attention 时可能呈现的一些问题。因为波及到显存的操作,所以咱们只能深刻 CUDA,然而 CUDA 又比较复杂。

这就是 OpenAI 的 Triton 等我的项目的劣势 (参见他们的 FlashAttention 实现)。Triton 基本上是一种 DSL(畛域特定语言),介于 CUDA 和其余畛域特定语言(例如 TVM) 之间的形象级别。能够编写超级优化的 Python 代码(一旦编译),而不用间接解决 CUDA。这样 Python 代码能够部署在任意的加速器上(这是 Triton 工作)。

另外一个好消息是 Triton 最近曾经与 PyTorch 2.0 集成了。

另外对于某些用例,比方对于超过 1K 的序列长度,一些近似留神办法 (如 Linformer) 开始变得更快。然而 flash attention 的块稠密实现优于所有其余办法。

总结

你有没有想过,对于这种底层优化的算法为什么是一个斯坦福大学的学生公布,而不是 NVIDIA 的工程师?

我认为有 2 种可能的解释:

1、FlashAttention 更容易 / 只能在最新的 gpu 上实现(原始代码库不反对 V100)。

2、通常“局外人”是那些以初学者的眼光对待问题,可能看到问题的本源并从根本准则登程解决问题

最初咱们还是要进行个总结

FlashAttention 可能让 BERT-large 训练中节俭 15%,将 GPT 训练速度进步 2 /3,并且是在不须要批改代码的状况下,这是一个十分重要的提高,它为 LLM 的钻研又提出了一个新的方向。

论文地址:

https://avoid.overfit.cn/post/9d812b7a909e49e6ad4fb115cc25cdc1

作者:Aleksa Gordić

正文完
 0