关于教程:手把手推导分布式矩阵乘的最优并行策略

2次阅读

共计 7354 个字符,预计需要花费 19 分钟才能阅读完成。

作者|郭冉、李一鹏、柳俊丞、袁进辉

罕用深度学习框架的主动并行机制还不够欠缺,还须要用户依据教训来配置并行形式,这给开发者带来了不小的智力累赘。因而,实现主动最优并行就成为一个乏味的课题。

矩阵乘是深度学习最罕用的底层计算原语,譬如卷积算子,注意力机制都是通过矩阵乘来实现的,所以大规模神经网络的并行实现大多数时候也是在解决分布式矩阵乘。本文就以如何最优地实现分布式矩阵乘为例来展现主动并行的解决思路。

1

如何实现最优的分布式矩阵乘?

通过上一篇文章《手把手推导 Ring all-reduce 的数学性质》咱们晓得了常见集群通信操作的通信量和所需通信工夫的数学性质,在这篇文章里咱们看看怎么应用这些性质来抉择最优的并行矩阵乘策略。

在文章《如何超过数据并行和模型并行:从 GShard 谈起》,咱们介绍了如何从个别的数据并行、模型并行提炼出最一般性的算子并行的形象示意 SBP。

假如咱们心愿在 4 张显卡 (2 台服务器,每台服务器上有 2 张显卡) 上实现一个矩阵乘 \(X\times W=Y \),也就是 \(y_{ij}=\sum_{k}{x_{ik}\times w_{kj}}\),其中 \(X \)和 \(W \)依照特定的 SBP 签名被摆放(place)到 4 张显卡上,那么将有多个形式实现分布式矩阵乘,它们在数学上等价,不过须要调用的集群通信操作不同,从而触发的通信代价也不同。

沿用《手把手推导 Ring all-reduce 的数学性质》里的符号,\(p \)示意设施数,\(V \)示意矩阵大小 \(V_{x} \)示意矩阵 \(X \)的大小,\(V_{w} \)示意矩阵 \(W \)的大小),\(\beta \)示意传输带宽。

2

数据并行还是模型并行?


图 1:基于 1D 矩阵乘的数据并行

如果 \(X \)和 \(W \)的 SBP 签名别离是 \(S(0) \)和 \(B \),那么能够推导进去 \(Y \)的 SBP 是 \(S(0) \),也就是左矩阵 \(X \)是行划分,右矩阵 \(W \)是在各个卡上是截然不同的拷贝(broadcast)。如果 \(X \)示意特色数据(feature map),\(W \)示意模型参数,那么这是一个典型的数据并行,上面咱们剖析一下数据并行的通信代价。

数据并行的反向须要执行集群通信操作 all-reduce,如果采纳环状算法,那么所有设施间的数据传输量是 \(2(p-1)V_{w} \),执行工夫是 \(\frac{2(p-1)V_{w}}{p\beta} \)。


图 2:基于输入层神经元划分的模型并行

如果 \(X \)和 \(W \)的 SBP 签名别离是 \(B \)和 \(S(1) \),那么能够推导进去 \(Y \)的 SBP 是 \(S(1) \),也就是左矩阵 \(X \)在各个卡上是截然不同的拷贝(broadcast),右矩阵 \(W \)在各个卡上列划分。如果 \(X \)示意特色数据(feature map),\(W \)示意模型参数,那么这是一个典型的模型并行,上面咱们剖析一下这种模型并行的通信代价。

如果 \(Y \)以 \(S(1) \)的状态参加上游的计算,那么 \(Y=X \times W \)自身并不需要引入额定的通信。但假如 \(Y \)须要被复原成和 \(X \)一样的状态(broadcast)参加上游计算,则前向计算时须要在 \(S(1) \)签名的 \(Y \) 上调用 all-gather 操作,后向计算时须要在 \(Y \)的反向 error signal 上调用 reduce-scatter 操作。那么前向和反向总的通信量是 \(2(p-1)V_{y} \),执行工夫是 \(\frac{2(p-1)V_{y}}{p\beta} \)。

留神,矩阵乘引入的通信量不只是由以后算子决定的,还取决于它所处的上下文;咱们这里的剖析假如上游的算子须要 \(Y \)放弃和输出 \(X \)一样的 SBP 签名,在这种状况下探讨不同并行形式的通信量。


图 3:基于输出层神经元划分的模型并行

如果 \(X \)和 \(W \)的 SBP 签名别离是 \(S(1) \)和 \(S(0) \),那么能够推导进去 \(Y \)的 SBP 是 \(P \),也就是左矩阵 \(X \)在各个卡上是列划分,右矩阵 \(W \)在各个卡上行划分。如果 \(X \)示意特色数据(feature map),\(W \)示意模型参数,那么这也是一个模型并行的形式(只不过是对全连贯层的输出神经元划分而来),上面咱们剖析一下这种模型并行的通信代价。

如果 \(Y \)以与 \(X \)雷同的 \(S(1) \)的状态参加上游的计算,则前向计算时须要在 \(P \)签名的 \(Y \)上调用 reduce-scatter 操作,后向计算时须要在 \(Y \)的误差上调用 all-gather 操作。那么前向和反向总的通信量是 \(2(p-1)V_{y} \),执行工夫是 \(\frac{2(p-1)V_{y}}{p\beta} \)。

依据以上的剖析,数据并行的通信量是 \(2(p-1)V_{w} \),模型并行的通信量是 \(2(p-1)V_{y} \),因而单就这一个矩阵乘而言,到底应用数据并行还是模型并行是比拟容易确定的,也就是取决于 \(V_{w} \)和 \(V_{y} \)哪个大,如果 \(V_{w} > V_{y} \),示意权重矩阵的容量大于输入特色数据的容量(譬如超大的全连贯层),那么适宜模型并行;如果 \( V_{w} < V_{y} \),示意示意权重矩阵的容量小于输入特色数据的容量(譬如卷积层),那么适宜数据并行。

值得一提的是,在实践中,数据并行和模型并行还不单单由 \(V_{w} \)和 \(V_{y} \)哪个大来决定,数据并行中 all-reduce 通信比拟容易被反向计算所覆盖,而模型并行的通信不容易被计算覆盖,因而即便 \(V_{w} > V_{y} \)了,实践上应该用模型并行了,但当数据并行反向覆盖 all-reduce 的劣势超过模型并行中通信量更小的劣势时,应用数据并行还是更优的。这就是问题的简单之处,最优的并行形式不仅仅是一个代价函数决定的,还和零碎具体实现密切相关。

3

高维并行(矩阵乘)是怎么回事?

在英伟达为大规模预训练模型开发的 Megatron-LM 里,矩阵乘应用了 2D 并行,譬如同一个算子在机器间应用了数据并行,机器外部应用了模型并行。有一篇论文也提出 2D 并行来实现矩阵乘 An Efficient 2D Method for Training Super-Large Deep Learning Models(https://arxiv.org/pdf/2104.05…)。

2D 并行是怎么回事?真的会带来益处吗?为什么呢?咱们还没有发现已有文献对这个问题从实践上探讨分明,心愿这篇博客能彻底搞清楚这些问题。


图 4:2D 并行

假如咱们有 2 台机器,每台机器 2 个设施,\(X \)在机器间是 \(S(0) \),在机器外部是 \(B \),而 \(W \)在机器间是 \(B \),在机器外部是 \(S(1) \),计算结果在机器间是 \(S(0) \), 机器外部是 \(S(1) \)。

这个例子里,机器间是数据并行,机器外部是模型并行。

把 \(Y \)从 \({S(0),S(1)} \)转换成和 X 一样的 \({S(0),B} \),那么前向计算须要每台机器外部执行 all-gather,反向须要在每台机器外部执行 reduce-scatter,其传输量是 \(2(\sqrt{p}-1)V_{y} \)。同时机器之间是数据并行,反向计算须要在第 1 台机器的第 1 张卡和第 2 台机器的第 1 张卡之间,以及第 1 台机器的第 2 张卡和第 2 台机的第 2 张卡之间别离调用 all-reduce,传输量是 \(2(\sqrt{p}-1)V_{w} \),总的传输量是 \(2(\sqrt{p}-1)(V_{y}+V_{w} \))。

2D 的 all-gather 为例,咱们再粗疏的解释一下下面的传输量是怎么推导进去的。

假如一共 \(\sqrt{p} \)台机器,每台机器上有 \(\sqrt{p} \)个设施,每台机器外部须要在 \(\sqrt{p} \)个设施之间实现 \(\frac{V_{y}}{\sqrt{p}} \)大小的矩阵,所以每台机器外部的传输量是 \(\frac{2(\sqrt{p}-1)V_{y}}{\sqrt{p}} \),一共 \(\sqrt{p} \)台机器,因而前向 all-gather 传输量是 \(2(\sqrt{p}-1)V_{y} \)。


图 5:2D 矩阵乘

2 台机器,每台机器 2 个设施,\(X \)在机器间是 \(S(0) \),在机器外部是 \(S(1) \),而 \(W \)在机器间是 \(B \),在机器外部是 \(S(0) \),计算结果在机器间是 \(S(0) \), 机器外部是 \(P \)。

机器间是数据并行,机器外部是模型并行。

把 \(Y \)从 \({S(0),P} \)转换成和 \(X \) 一样的 \({S(0),S(1)} \),那么前向计算须要每台机器外部执行 reduce-scatter,反向须要在每台机器外部执行 all-gather,其传输量是 \(2(\sqrt{p}-1)V_{y} \)。同时机器之间是数据并行,反向计算须要在第 1 台机器的第 1 张卡和第 2 台机器的第 1 张卡之间,以及第 1 台机器的第 2 张卡和第 2 台机器的第 2 张卡之间别离调用 all-reduce,传输量是 \(2(\sqrt{p}-1)V_{w} \),
总的传输量是 \(2(\sqrt{p}-1)(V_{y}+V_{w}) \)。


图 6:2D 矩阵乘

图 6 展现了经典的 2D SUMMA 算法的实现。间接依照图 6 所示的数据分布是无奈间接执行矩阵乘的,\(X \)和 \(W \) 在机器外部都须要执行 all-gather 计算,变成图 4 所示的数据分布才能够,相应的反向计算须要在机器外部执行 reduce-scatter,总的通信量是 \(2(\sqrt{p}-1)(V_{x}+V_{w}) \)。

4

高维矩阵乘有什么益处?

咱们以图 4 所示的 2D 矩阵乘为例来探讨高维矩阵乘绝对于 1D 矩阵乘带来了什么益处。

首先假如 \(V_{x}=V_{w}=V_{y}=V \),那么 1D 矩阵乘的通信量是 \(2(p-1)V \),而 2D 矩阵乘的通信量是 \(4(\sqrt{p}-1)V \),基本上能够认为,当 \(p>4 \),2D 矩阵乘通信量就小于 1D 矩阵乘的通信量了。

能够揣测,如果是 3D 矩阵乘,那么通信量是和 \(\sqrt[3]{p} \)成正比的。高维矩阵乘的实质是减小了每一个集群通信操作的”宽度“,咱们在上一篇博客《手把手推导 Ring all-reduce 的数学性质》推导过集群通信的通信量是和通信宽度成正比的。

5

高维矩阵乘会升高通信工夫吗?

仔细的敌人可能留神到了,咱们在探讨 1D 矩阵乘的通信代价时,总是同时探讨通信量和通信工夫,然而在探讨 2D 矩阵乘的通信代价时,却只探讨了通信量,没有探讨通信工夫。方才咱们也探讨了,高维矩阵乘会升高通信量,那么高维矩阵乘的通信工夫也会升高吗?

实际上不会。论断有点违反直觉,为什么呢?起因是:通信量变成原来的 \(\frac{1}{\sqrt{p}} \),但每个设施同时参加多组集群通信,每组集群通信可应用的带宽也变成原来的 \(\frac{1}{\sqrt{p}} \)。上面咱们看一个具体的例子。


图 7:DGX-A100 通信拓扑

图 7 展现了 DGX-A100 机器的通信拓扑,假如一共有 4 台机器,每台机器有 4 个 GPU,每台机器有 4 张网卡,因而机器之间的带宽是每张网卡带宽的 4 倍。


图 8:1D 并行的环状通信拓扑

在 1D 并行 ,假如所有 GPU 形成图 8 所示的一个大环。机器间通信带宽为 \(\beta=\sqrt{p}\times \beta_{IB} \)(留神: 下文的公式和上文公式带宽差一个 \(\sqrt{p} \)系数,来源于此),其中 \(\beta_{IB} \) 示意 IB 网卡带宽,在 DGX A100 拓扑中机器间 IB 带宽通常小于机器内 GPU 设施间通信带宽,因而此处整体通信受限于机器间带宽,通信工夫为 \(\frac{2(p-1)V}{p\times \sqrt{p}\beta_{IB}} \) (留神:分母须要乘以设施总数 \( p \))。


图 9:2D 并行的环状通信拓扑

在 2D 并行 ,以 SUMMA 矩阵乘法为例,每行的 4 个 GPU 设施形成一个环,即[machine 0 : gpu 0,machine 1 : gpu0,machine 2 : gpu 0,machine 3 : gpu0]、[machine 0 : gpu 1,machine 1 : gpu1,machine 2 : gpu 1,machine 3 : gpu1] 组成一个环等。每列的 4 个 GPU 设施也形成一个环。前向计算时,每个环上都要同时执行 all-gather 操作,跨机器的每个集群通信操作都会占用 \(\frac{1}{\sqrt{p}} \)的网络带宽,也就是 \(\beta_{IB} \),机器外部的每个集群通信带宽不是瓶颈所在,因而不影响最终后果。通信工夫不难推导,是 \(\frac{2(\sqrt{p}-1)V}{p\times \beta_{IB}} \) (这里除以 p 失去的是每个设施的通信量),和 1D 并行的通信工夫 \( \frac{2(p-1)V}{p\times \sqrt{p}\beta_{IB}} \) 是同一个数量级。

至此,咱们晓得:2D 矩阵乘减小了集群通信的宽度,因而升高了所须要的通信量,但不会升高通信工夫。

甚至,在特定的状况下,1D 矩阵乘的通信工夫要小于 2D 矩阵乘,这又是为什么呢?

2D 矩阵乘的通信工夫是 \(\max{\frac{2(\sqrt{p}-1)V_{1}}{p\beta_{1}},\frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}}} \)
其中区别了不同的矩阵和不同环的传输带宽。假如 \(\beta_{1} < \beta_{2} \)(机器间带宽小于机器外部带宽),那么 2D 矩阵乘的通信工夫至多是

\(\max{\frac{2(\sqrt{p}-1)V_{1}}{p\beta_{2}},\frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}}} \)

1D 矩阵乘的通信工夫是抉择数据并行和模型并行中更优的那一个:

\(\min{\frac{2(p-1)V_{1}}{p\sqrt{p}\beta_{1}},\frac{2(p-1)V_{2}}{p\sqrt{p}\beta_{1}}} \)

当 \(V_{1} \) 和 \(V_{2} \) 相差比拟迥异时,无妨假如 \(V_{1}<V_{2} \),那么 2D 并行通信工夫的下界是 \(\frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}} \),而 1D 并行的通信工夫是 \(\frac{2(p-1)V_{1}}{p\sqrt{p}\beta_{1}} \),不难失去,当 \(V_{1}<\frac{\beta_{1}}{\beta_{2}}V_{2} \) 时,1D 并行的通信工夫肯定小于 2D 并行的通信工夫。

因而,2D 并行在升高通信量(或者带宽需要)上有劣势,1D 并行在升高通信工夫上有劣势。

一般来说,一个神经网络中同时存在很多相似矩阵乘的算子,算子档次的并行都须要引入通信需要,通信带宽十分富余,那么就能够释怀的应用 1D 并行,这样确保通信工夫是最小的;如果通信带宽是瓶颈,那么每一个算子都应该尽可能升高通信量的需要,节俭带宽,这样能力让总体的通信工夫最小。

2D 并行的带宽需要升高了,但通信工夫没有变动,起因是什么呢?直观的了解是,在 2D 并行中肯定有一部分带宽是被闲置了。设想一下,一个大环被切成几段,造成几个小环,小环和小环之间的带宽是不须要用的。

6

结语

如果你在 GPU 上实现过单卡矩阵乘法,那可能对下面 2D 矩阵乘的示意图很相熟,没错,在单卡实现矩阵乘时,要害也在于尽可能减小 global memory 和 shared memory 之间的数据搬运。

因而,那里也须要做相似于分布式矩阵乘的通信代价剖析,分布式是宏观档次的数据搬运,单卡是宏观档次的数据搬运,二者在原理上十分类似。实际上,已有文献对分布式矩阵乘的通信代价的实践剖析曾经十分成熟,本文探讨的 2D 阵乘或 3D 矩阵乘的实现形式都已实现了各自拓扑下通信代价的实践下界。

本文只探讨了一个算子并行时的最优策略,其实每个算子的最优策略也和它所处的上下文相干,一个算子不仅仅要思考那个并行策略对本身是不是无利,还要思考它的计算结果对四周的算子是不是无利。

因而,给定一个神经网络,它的最优并行策略是一个组合优化问题,如果这个神经网络是链状(chain-structure)的,那么能够证实,应用动静布局算法就能够在多项式工夫内求出全局最优解,当神经网络的构造不是链状时,就无奈应用动静布局,就须要一系列伎俩尽可能升高搜寻空间的规模。

auto-placement 和 auto-parallelism 是业界宽泛关注的一个热点问题。很多钻研工作间接就把问题模式化成一个组合优化的问题,但比拟少探讨分布式深度学习本身的数学法则。

OneFlow 团队在钻研过程中发现,如果能对问题自身的数学性质做深刻的实践剖析,充分利用这些实践性质,auto-placement 和 auto-parallelism 的求解能够出其不意的简略。

迄今为止,咱们应该对数据并行和模型并行探讨得很深刻了,将来,咱们会对流水并行的实践性质展开讨论。

正如本文在探讨 1D 并行和 2D 并行实现时所画的各种示意图所示,不同的数据切分形式带来不同的并行形式,也带有不同的通信代价。有些切分形式并不直观,怎么能力从实践上保障一种切分形式是正确的?怎么能力穷尽所有实践上正确的切分形式?

OneFlow SBP 提供了一种很弱小的数学形象,不仅能够用来剖析 1D 矩阵乘,还能够很不便地剖析 2D 矩阵乘,大大简化了剖析这些简单问题的难度。强烈推荐做这方面工作的小伙伴儿都来用这套工具。

如果想更具体理解 SBP 如何在分布式模型训练里施展威力,能够参照 OneFlow 公布的 LiBai(
https://github.com/Oneflow-In…),仅仅 1 万行外围代码就实现了 NVIDIA Megatron-LM 和 Microsoft DeepSpeed 须要五六倍代码量能力实现的性能。

欢送下载体验 OneFlow v0.7.0 最新版本:
https://github.com/Oneflow-In…

正文完
 0