分布式机器学习之张量并行:层内切分 Transformer 模型
Tensor Parallel 是目前大模型训练和推理中最常用的并行方式之一,主要针对 Transformer 类模型。本文首先介绍了 GEMM 切分的按列并行和按行并行两种方式,然后在 Transformer 的各个组件,包括 MLP、Attention、input embedding、LM head + cross entropy loss 中根据实际情况设计具体的切分方案。
分布式机器学习之张量并行:层内切分 Transformer 模型
Tensor Parallel (张量并行)是目前大模型训练推理最常用的并行方式之一,首先由 NVIDIA 的 Megatron LM 提出,最初是专门针对 Transformer 模型设计的。Transformer 一般是由输入 embedding 层,输出 LM Head 层以及多个 Attention Block。我们需要对他们分别设计张量切分方案。这些模块实际上都是由线性层组成的,也就是说基础算子都是 GEMM(通用矩阵乘法),本文将首先介绍 GEMM 的按列并行和按行并行两种方式,然后介绍在 Transformer 的这些组件中,如何设计具体的张量并行方案。
如果图示和理解有误,请指出。
GEMM 的切分
由于 Transformer 中所有的参数层(MLP、Attention、Embedding、Head)本质上的基础算子都是 GEMM(通用矩阵乘法)。我们先来看看如何将 GEMM 切分到不同的 GPU 上。矩阵有行和列两个维度,那么很自然的,GEMM 的张量并行切分也分为按列并行和按行并行两种方式。
记 GEMM 操作中,输入矩阵 XXX,参数矩阵 AAA,输出 Y=AXY=AXY=AX。
按列拆分并行
按列并行方式是将参数矩阵 AAA 按行列分开,即 A=[A1,A2]A=[A_1,A_2]A=[A1,A2],输入矩阵 XXX 不变,这样每个 GPU 计算得到的是某一部分的最终结果,然后进行一次 All Gather 操作,将各个 GPU 上的最终结果通信,得到完整的最终结果。
Y=[Y1,Y2]=[XA1,XA2] Y=[Y_1,Y_2]=[XA_1,XA_2] \notag \\ Y=[Y1,Y2]=[XA1,XA2]
按行拆分并行
将参数矩阵 AAA 按行切分开,将即 A=[A1A2]A=\left[\begin{array}{c} A_1 \\ A_2\end{array} \right]A=[A1A2] ,输入矩阵 XXX 按列切分开,X=[X1,X2]X=[X_1,X_2]X=[X1,X2],这样每个 GPU 计算得到的是完整的部分结果(这样表述可能有点歧义,实际上是指每个 GPU 都得到了完整形状的输出结果,但需要将所有 gpu 的结果相加才是完整的最终结果),此时需要进行一次 All Reduce 操作,将每个 GPU 上的完整的部分结果进行加和,得到完整的最终结果。
Y=X1A1+X2A2 Y=X_1A_1+X_2A_2 \notag \\ Y=X1A1+X2A2
在理解了 GEMM 的两种切分方式之后,我们来看 Transformer 中各个组件的具体切分方案。
MLP
首先来看 MLP 层的切分,MLP 层由两个线性层组成,在它们之间是一个按元素的、非线性的激活函数(比如 GeLU),即 Y=act(XA)BY=\text{act}(XA)BY=act(XA)B 。如何设计 MLP 层切分的最优方案呢?
对于第一个线性层,如果我们用按行并行的方式,每个 GPU 上计算得到的结果是完整的部分结果,由于激活函数是非线性的(即 act(A+B)≠act(A)+act(B)\text{act}(A+B)\ne\text{act(A)}+\text{act}(B)act(A+B)=act(A)+act(B)),因此接下来必须先进行一次 sum reduce 同步,获取完整的最终结果,然后才能计算激活函数。而如果我们用按列并行的方式,得到的是部分的最终结果,可以先在单个 GPU 上各自进行激活函数的计算,不需要进行额外的同步通信。
并且,从上图可以看到按列并行的单 GPU 结果正好是按行并行需要的输入矩阵,既然激活函数的计算已经可以在单 GPU 上各自完成,那么到第二个线性层,也不需要额外的通信,直接采用按行并行的方式进行计算,完成后再通过一次 sum all reduce 同步结果即可。
综上所述,在 MLP 层中,我们在第一个线性层使用按行并行的切分方式,第二个线性层则使用按列并行,这样整个 MLP 前向计算过程中就只需要在最后进行一次 all reduce 即可,中间没有任何额外通信。
Attention
attention 层的切分看起来要稍微复杂一点,但其实思路上跟 MLP 层差不多。multi head attention 这种形式天然就很适合进行切分,因为它本来就要对多个 head 进行列切分,我们正好让每个 GPU 处理 n_heads / n_gpu 个 head 就好了。即,参数矩阵 WQ,K,VW_{Q,K,V}WQ,K,V 也是按列进行切分。每个 GPU 上的注意力计算完之后的结果,正好也是行切分输入矩阵的形式,然后还是将输出参数矩阵 WOW_OWO 按行切分,最后进行一次 all reduce。
整个过程中也是没有中间过程的通信,只需要在最后进行一次 all reduce。
input embedding 与 LM head + cross entropy loss
在语言模型中, input embedding 与 LM head 的形状是相同的,都是 (h,V)(h,V)(h,V),hhh 和 VVV 分别表示隐层大小和词表大小,一个常用的设计是将二者进行参数绑定(tied),即二者共用同样一套参数。对它们进行张量并行切分时,一般是按照词表维度进行切分。
在 input embedding 中,输入是 token ids,形状为 (b,l)(b,l)(b,l),讲这些 token id 在 embedding 中查表,获取到其对应的 hhh 维度的隐层向量。在张量并行时,将词表切分到不同的 GPU,那么在单个 GPU 上,输入 token id 查表时就有可能查不到,查不到的先用全 0 的 hhh 维向量填充,然后进行一次 all reduce (sum) 操作,即可将从其他 GPU 的结果中补上缺失的向量。
LM head 部分,首先与 last hidden state 计算出 logits,同样是按照词表维度切分,在这里是按列切分。接下来先不着急做 all reduce,因为这里如果要通信的话,数据量是 b×l×Vb\times l\times Vb×l×V,其中词表大小 VVV 一般是几万到十几万,b×lb\times lb×l 也能到几万,这样通信的数据量太大。为了减小通信量,我们可以让每个 GPU 自己先算 softmax 和损失,然后只通信最终损失这个标量值,这样通信数据量就小的多了。具体来说,此时每个 GPU 上已经计算出了自己部分词表的 logits,我们先计算自己部分的 softmax,但是由于 softmax 不是 elementwise 的操作,需要全局的指数和做分母,因此这里我们需要先 all reduce 一次,获取全局的完整指数和,这里的通信量是 b×lb\times lb×l。接下来,各个 GPU 就可以自己完成 softmax 和损失计算,最后在 all reduce 一下最终的损失值,通信量也是 b×lb\times lb×l。这样通信成本就大大减小了。
(注意这里图示中将 b×lb\times lb×l 维度展平来绘制了)
总结
Tensor Parallel 是目前大模型训练和推理中最常用的并行方式之一,主要针对 Transformer 类模型。本文首先介绍了 GEMM 切分的按列并行和按行并行两种方式,然后在 Transformer 的各个组件,包括 MLP、Attention、input embedding、LM head + cross entropy loss 中根据实际情况设计具体的切分方案。
更多推荐
所有评论(0)