关于深度学习:想要更好地理解大模型架构从计算参数量快速入手

153次阅读

共计 9200 个字符,预计需要花费 23 分钟才能阅读完成。

编者按:要了解一种新的机器学习架构(以及其余任何新技术),最无效的办法就是从头开始实现它。然而,还有一种更简略的办法——计算参数数量。

通过计算参数数量,读者能够更好地了解模型架构,并查看其解决方案中是否存在未被发现的谬误。

该文章提供了准确的 Transformers 模型的参数量计算公式和不太精确的简略公式版本,使读者可能疾速估算基于 Transformer 的任何模型中参数的数量。

以下是译文,Enjoy!

作者 | Dmytro Nikolaiev (Dimid)

编译 | 岳扬

要了解一种新的机器学习架构(以及其余任何新技术),最无效的办法就是从头开始实现它。 尽管这可能会非常复杂、耗时,并且有时简直不可能实现,但这是帮忙咱们了解每个技术细节的最佳办法。例如,如果没有相似的计算资源或数据,咱们将无奈确保咱们的解决方案中没有未被发现的谬误。

然而,还有一种更简略的办法——计算参数数量。 这比仅仅浏览论文要艰难得多,但能够让咱们深刻开掘并查看是否齐全了解了新架构的构件(在本文的例子是 Transformer 的编码器(Encoder)和解码器(Decoder)构件)。

咱们能够通过上面这幅图表来思考这个问题,这张图表展现了三种了解新 ML 架构的办法——圆圈的大小示意对该架构的了解水平。

本文次要钻研驰名的 Transformer 架构,并思考如何计算 PyTorch TransformerEncoderLayer[1] 和TransformerDecoderLayer[2]类中的参数数量。因而,咱们须要确保对于该架构由哪些局部组成不再充斥神秘感。

TLDR(总结)

(该文篇幅比拟长,如果不想深入探讨或工夫无限,能够间接看总结局部)

您能够浏览“论断 Conclusions”局部,所有参数量计算公式都总结在“论断 Conclusions”局部。

本文不仅提供准确的参数量计算公式,还可能提供不太精确的公式近似版本,将使您可能疾速估算基于 Transformer 的任何模型中参数的数量。

01 Transformer 架构

驰名的 Transformer 架构于 2017 年在《Attention Is All You Need[3]》这篇论文中提出,并因其具备可能无效捕获长距离的依赖关系(long-range dependencies)的能力而成为自然语言解决和计算机视觉工作中的规范架构。

早在 2023 年初,扩散模型(Diffusion)[4]因为文转图生成模型 [5] 的大火而变得极其风行。兴许,很快扩散模型将成为各种工作的最先进技术,就像 Transformer 与 LSTM 和 CNN 一样。但咱们先来看看 Transformer……

本文并不试图去解释 Transformer 架构,因为曾经有很多足够好的文章做到了这一点。这篇文章只是让咱们可能从不同的角度去对待它,或者解说一些细节问题。所以如果你正在寻找更多无关此架构的学习资源,我能够向你举荐一些;否则,您能够持续浏览上来。

1.1 理解更多 Transformer 的资源

如果你正在寻找更加具体的 Transformer 架构概述,能够浏览以下资料(请留神,互联网上有很多技术内容,我只是集体喜爱这些):

  • 首先,能够浏览 官网论文 [3]。第一次接触 Transformer 就浏览论文可能不是最佳形式,但这并不像看起来那么简单。能够尝试应用Explainpaper 来帮忙您浏览此论文 [6] 或其余论文(这是一种基于 AI 的工具,能够解释用鼠标标记的文本)。
  • Jay Alammar 的“Great Illustrated Transformer[7]”。如果您不喜爱阅读文章,能够观看同一作者的 YouTube 视频[8]。
  • Lukasz Kaiser 在 Google Brain 的 “Awesome Tensor2Tensor” 讲座[9]。
  • 如果想间接进行实操并应用各种 Transformer 模型构建应用程序,请查看 Hugging Face 课程[10]。

1.2 Original Transformer

首先,让咱们回顾一下 Transformer 的基础知识。

Transformer 的架构由两个组件组成:编码器(在右边)和解码器(在左边)。编码器承受输出 token 序列并生成暗藏状态序列(sequence of hidden states),而解码器则承受这个暗藏状态序列并生成输入 token 序列。

Transformer 架构图,来自 https://arxiv.org/pdf/1706.03762.pdf

编码器和解码器都由一堆雷同的层组成。对于编码器,该层包含 多头注意力 (multi-head attention)(1——此处及下文中的数字指的是上面的图片中标序号的局部)和一个带有一些 层归一化 (3)和 跳跃连贯 (skip connections)的 前馈神经网络(feed-forward neural network)(2)。

解码器也相似于编码器,但除了 第一个多头注意力 (4)(在机器翻译工作中被屏蔽,所以解码器不会通过查看将来的 tokens 进行舞弊)和一个 前缀网络 (5)之外,它还具备 第二个多头注意力机制 (6)。它容许解码器在生成输入时应用编码器提供的上下文(context)。与编码器一样,解码器也有一些 层归一化 (layer normalization)(7)和 跳跃连贯组件

带有序号标记组件的 Transformer 架构图

来自 https://arxiv.org/pdf/1706.03762.pdf

我不会将输出嵌入层(带有 地位编码)和最终输入层(linear+softmax)视为 Transformer 组件,而只关注编码器和解码器块。这样做是因为这些组件是实用于某些特定工作和嵌入办法的,而编码器和解码器栈是其余体系结构的根底。

这种架构的例子包含用于编码器的基于 BERT 的模型(BERT、RoBERTa、ALBERT、DeBERTa 等),用于解码器的基于 GPT 的模型(GPT、GPT-2、GPT-3、ChatGPT),以及构建在残缺的编码器 - 解码器框架上的模型(T5、BART 等)。

只管咱们在该架构中标记了七个组件,但咱们能够看到,其中仅有三个独特的组件:

  • 多头注意力(Multi-head attention);
  • 前馈网络(Feed-forward network);
  • 层的归一化(Layer normalization)。

Transformer 构件 来自论文 https://arxiv.org/pdf/1706.03762.pdf

02 Transformer 构件块

让咱们考虑一下每个模块的内部结构以及它须要多少参数。在本节中,咱们还将开始应用 PyTorch[11]来验证咱们的计算结果。

为了查看某个模型块的参数数量,我将应用以下这行函数[12]:

import torch
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_parameters(model: torch.nn.Module) -> int:
"""Returns the number of learnable parameters for a PyTorch model"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)

在咱们开始之前,请留神一个事实,即 所有构件块都是标准化的,并且应用跳跃连贯 这意味着所有输出和输入的 shape(更确切地说,是其最初一个数字 因为 batch size 和 tokens 数量可能会有所不同)必须雷同 对于原论文,这个数字(d_model)为 512。

2.1 多头注意力

驰名的注意力机制是 Transformer 架构的要害。然而,无论设计动机和技术细节如何,它只波及几个矩阵乘法。

Transformer 多头注意力架构图 

来自论文 https://arxiv.org/pdf/1706.03762.pdf

计算了每个 head 的注意力后,咱们将所有 head 连接起来,并通过一个 线性层(W_O 矩阵)进行传递。反过来,每个 head 都是用三个独立的矩阵乘以 query、key 和 value(别离为 W_Q、W_K 和 W_V 矩阵)的 Scaled dot-product attention(缩放点积注意力)。这三个矩阵对每个 head 都是不同的,这就是下标 i 呈现的起因。

最终线性层(final linear layer)(W_O)的 shape 为 d_model 到 d_model。其余三个矩阵(W_Q、W_K 和 W_V)的 shape 雷同:d_model 到 d_qkv。

请留神,在下面的图像中,d_qkv 被示意为原论文中的 d_k 或 d_v。我认为这个名称更直观,因为只管这些矩阵可能具备不同的 shape,但简直总是雷同的。

此外,请留神,d_qkv = d_model / num_heads (文中的 h)。这就是为什么 d_model 必须可能被 num_heads 整除的起因:以确保前面的连贯正确。

能够通过查看上图中的所有两头阶段的 shape(正确的 shape 在右下角标出)来自行测试。

因而,咱们须要每个 head 有三个较小的矩阵和一个大的最终矩阵。那么咱们须要多少参数(不要疏忽偏差)?

用于计算 Transformer 注意力模块中参数数量的公式。图片由作者提供

我心愿这个公式不会太繁琐——我试图让推导的后果尽可能的清晰。不要放心! 将来的公式会更加简短。

参数的大抵数量是这样的,因为与 4 d_model 相比,咱们能够疏忽 4 d_model^2。让咱们当初用 PyTorch 进行测试。

from torch import nn
d_model = 512
n_heads = 8 # must be a divisor of `d_model`
multi_head_attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)
print(count_parameters(multi_head_attention)) # 1050624
print(4 * (d_model * d_model + d_model)) # 1050624

数字匹配,这意味着咱们做得很好!

2.2 前馈网络

Transformer 中的前馈网络由 两个全连贯层 (fully connected layers)组成,其 两头有一个 ReLU 激活函数。该网络的外部局部比输出和输入(input and output)更具表现力(输出和输入必须雷同)。

在个别状况下,它是 MLP(d_model, d_ff) -> ReLU -> MLP(d_ff, d_model),对于原始论文,d_ff = 2048

前馈神经网络形容 图来自论文 https://arxiv.org/pdf/1706.03762.pdf

略微进行一下可视化不会有害处。

Transformer 中的前馈网络。作者提供的图像。

参数的计算相当容易,次要的还是不要被弄混。

用于计算 Transformer 前馈网络中参数数量的公式。图像由作者提供。

咱们能够应用以下代码形容这样一个简略的网络并查看其参数的数量(请留神,官网的 PyTorch 实现也应用了 dropout,咱们将在前面的编码器 / 解码器代码中看到。然而正如咱们所知,dropout 层没有可训练的参数,因而为了简略起见,我在这里省略它):

from torch import nn
class TransformerFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super(TransformerFeedForward, self).__init__()
       self.d_model = d_model
       self.d_ff = d_ff
       self.linear1 = nn.Linear(self.d_model, self.d_ff)
       self.relu = nn.ReLU()
       self.linear2 = nn.Linear(self.d_ff, self.d_model)
def forward(self, x):
       x = self.linear1(x)
       x = self.relu(x)
       x = self.linear2(x)
return x
d_model = 512
d_ff = 2048
feed_forward = TransformerFeedForward(d_model, d_ff)
print(count_parameters(feed_forward)) # 2099712
print(2 * d_model * d_ff + d_model + d_ff) # 2099712

再次看看图中的数字,仅剩下一个组件没有介绍啦。

2.3 层归一化

Transformer 架构的最初一个构件块是 层归一化。简略地说,只是一种智能的(即可学习的)归一化形式,具备缩放性能,能够进步训练过程的稳定性。

Transformer 的层归一化,图片由作者提供

这里的可训练参数是两个向量 gamma 和 beta,每个向量的维度都是 d_model。

用于计算 Transformer 层归一化模块中参数数量的公式。作者提供的图像。

让咱们应用代码来测验咱们的假如。

from torch import nn
d_model = 512
layer_normalization = nn.LayerNorm(d_model)
print(count_parameters(layer_normalization)) # 1024
print(d_model * 2) # 1024

很好! 在近似计算中,这个数字能够 忽略不计 ,因为 层归一化的参数大大少于前馈网络或多头注意力块(只管这个模块呈现了几次)。

03 推导出残缺的公式

当初咱们有了所有,能够计算整个编码器 / 解码器模块的参数了!

3.1 用 PyTorch 实现的编码器和解码器

请让咱们记住,编码器是由一个注意力块、前馈网络和两个层归一化组成。

Transformer 编码器。来源于论文 https://arxiv.org/pdf/1706.03762.pdf

咱们能够查看 PyTorch 代码中的细节来验证所有组件是否都已就位。其中 多头注意力机制用红色标注 (左侧), 前馈网络用蓝色标注 层归一化用绿色标注(在 PyCharm 中的 Python 控制台截图)。

PyTorch TransformerEncoderLayer。图片由作者提供

3.2 最终公式

确认好之后,咱们能够编写以下函数来计算参数数量。实际上,这只是三行代码,甚至能够合并为一行。函数的其余部分是文档字符串以作阐明。

def transformer_count_params(d_model=512, d_ff=2048, encoder=True, approx=False):
"""
   Calculate the number of parameters in Transformer Encoder/Decoder.
   Formulas are the following:
       multi-head attention: 4*(d_model^2 + d_model)
           if approx=False, 4*d_model^2 otherwise
       feed-forward: 2*d_model*d_ff + d_model + d_ff
           if approx=False, 2*d_model*d_ff otherwise
       layer normalization: 2*d_model if approx=False, 0 otherwise
   Encoder block consists of:
       1 multi-head attention block,
       1 feed-forward net, and
       2 layer normalizations.
   Decoder block consists of:
       2 multi-head attention blocks,
       1 feed-forward net, and
       3 layer normalizations.
   :param d_model: (int) model dimensionality
   :param d_ff: (int) internal dimensionality of a feed-forward neural network
   :param encoder: (bool) if True, return the number of parameters of the Encoder,
       otherwise the Decoder
   :param approx: (bool) if True, result is approximate (see formulas)
   :return: (int) number of learnable parameters in Transformer Encoder/Decoder
   """
   attention = 4 * (d_model ** 2 + d_model) if not approx else 4 * d_model ** 2
   feed_forward = 2 * d_model * d_ff + d_model + d_ff if not approx else 2 * d_model * d_ff
   layer_norm = 2 * d_model if not approx else 0
return attention + feed_forward + 2 * layer_norm \
if encoder else 2 * attention + feed_forward + 3 * layer_norm

当初是测试它的时候了。

from torch import nn
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
print(count_parameters(encoder_layer))  # 3152384
print(transformer_count_params(d_model=512, d_ff=2048, encoder=True, approx=False))  # 3152384
print(transformer_count_params(d_model=512, d_ff=2048, encoder=True, approx=True))   # 3145728
# ~0.21% difference
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
print(count_parameters(decoder_layer))  # 4204032
print(transformer_count_params(d_model=512, d_ff=2048, encoder=False, approx=False))  # 4204032
print(transformer_count_params(d_model=512, d_ff=2048, encoder=False, approx=True))   # 4194304
# ~0.23% difference

精确的公式是正确的,这意味着咱们曾经正确地确定了所有构件块并将其分解成其各组成部分。乏味的是,因为咱们在近似公式中疏忽了绝对较小的值(与百万相比只有数千个),因而绝对于准确后果,误差仅约为 0.2%!然而还有一种办法能够使这些公式更简略。

注意力块的近似参数数量为 4 d_model^2。思考到 d_model 是一个重要的超参数,这听起来计算会非常简略。然而对于前馈网络,咱们须要晓得 d_ff,因为公式是 2 d_model * d_ff。

d_ff 是一个独自的超参数,当初必须在公式中记住它,因而让咱们思考如何解脱它。正如咱们下面看到的,当 d_model = 512 时,d_ff = 2048,因而 d_ff = 4 * d_model。

对于许多 Transformer 模型来说,这样的假如将是有意义的,大大简化了公式,并依然给出一个大略的参数数量。毕竟,没有人想晓得确切的数量,只是理解这个数量是几十万还是几千万。

近似的编码器 - 解码器公式。由作者提供的图像。

为了理解你正在解决的数量级,你也能够将乘数四舍五入。这样每个编码器 / 解码器层就会失去 10 * d_model ^ 2 个参数。

04 Conclusion 论断

上面给咱们明天推导出的所有公式做一个总结。

公式总结,由作者提供的图像。

在本文计算了 Transformer 编码器 / 解码器块中的参数数量,然而当然,咱们并不建议您去计算所有新模型的参数。之所以抉择这种办法,是因为当我开始钻研 Transformers 时,我很诧异没有找到这样的文章。

尽管参数数量能够让咱们晓得模型的复杂性和训练所需数据量,但这只是更深刻地理解模型架构的一种形式。我想激励您摸索和试验:去查看、实现、运行具备不同超参数的代码等等。因而,请持续学习并 enjoy 人工智能的乐趣!

END

参考资料

1.https://pytorch.org/docs/stable/generated/torch.nn.Transforme…

2.https://pytorch.org/docs/stable/generated/torch.nn.Transforme…

3.https://arxiv.org/abs/1706.03762

4.https://techcrunch.com/2022/12/22/a-brief-history-of-diffusio…

5.https://www.washingtonpost.com/technology/interactive/2022/ai…

6.https://www.explainpaper.com/papers/attention

7.https://jalammar.github.io/illustrated-transformer/

8.https://youtu.be/-QH8fRhqFHM

9.https://www.youtube.com/watch?v=rBCqOTEfxv

10.https://huggingface.co/course/chapter1/1

11.https://pytorch.org/

12.https://discuss.pytorch.org/t/how-do-i-check-the-number-of-pa…

本文经原作者受权,由 Baihai IDP 编译。如需转载译文,请分割获取受权。

原文链接

https://towardsdatascience.com/how-to-estimate-the-number-of-…

正文完
 0