关于神经网络:如何估算transformer模型的显存大小

7次阅读

共计 2665 个字符,预计需要花费 7 分钟才能阅读完成。

在微调 GPT/BERT 模型时,会常常遇到“cuda out of memory”的状况。这是因为 transformer 是内存密集型的模型,并且内存要求也随序列长度而减少。所以如果能对模型的内存要求进行粗略的预计将有助于预计工作所需的资源。

如果你想间接看后果,能够跳到本文最初。不过在浏览本文前请记住所有神经网络都是通过反向流传的办法进行训练的,这一点对于咱们计算内存的占用非常重要。

 total_memory = memory_modal + memory_activations + memory_gradients

这里的 memory_modal 是指存储模型所有参数所需的内存。memory_activations 是计算并存储在正向流传中的两头变量,在计算梯度时须要应用这些变量。因为模型中梯度的数量通常等于两头变量的数量,所以 memory_activations= memory_gradients。因而能够写成:

 total_memory = memory_modal + 2 * memory_activations

所以咱们计算总体内存的要求时只须要找到 memory_modal 和 memory_activations 就能够了。

估算模型的内存

上面咱们以 GPT 为例。GPT 由许多 transformer 块组成(前面我用 n_tr_blocks 示意其数量)。每个 transformer 块都蕴含以下构造:

 multi_headed_attention --> layer_normalization --> MLP -->layer_normalization

每个 multi_headed_attention 元素都由键,值和查问组成。其中包含 n_head 个注意力头和 dim 个维度。MLP 是蕴含有 n_head * dim 的尺寸。这些权重都是要占用内存的,那么

 memory_modal = memory of multi_headed_attention + memory of MLP
  = memory of value  + memory of key + memory of query + memory of MLP
  = square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
  = 4*square_of(n_head * dim)

因为咱们的模型蕴含了 n 个单元。所以最初内存就变为:

 memory_modal = 4*n_tr_blocks*square_of(n_head * dim)

下面的估算没有思考到偏差所需的内存,因为这大部分是动态的,不依赖于批大小、输出序列等。

估算两头变量的内存

多头注意力通常应用 softmax,能够写成:

 multi_headed_attention = softmax(query * key * sequence_length) * value

k,q,v 的维度是:

 [batch_size, n_head, sequence_length, dim]

multi_headed_attention 操作会得出如下形态:

 [batch_size, n_head, sequence_length, sequence_length]

所以最终得内存为:

 memory_softmax  = batch_size * n_head * square_of(sequence_length)

q k sequence_length 操作乘以 value 的形态为 [batch_size, n_head, sequence_length, dim]。MLP 也有雷同的维度:

 memory of MLP  = batch_size * n_head * sequence_length * dim
 memory of value = batch_size * n_head * sequence_length * dim

咱们把下面的整合在一起,单个 transformer 的两头变量为:

 memory_activations = memory_softmax + memory_value + memory_MLP
 = batch_size * n_head * square_of(sequence_length)
   + batch_size * n_head * sequence_length * dim
   + batch_size * n_head * sequence_length * dim
 = batch_size * n_head * sequence_length * (sequence_length + 2*dim)

再乘以块的数量,模型所有的 memory_activations 就是:

 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

整合在一起

咱们把下面两个公式进行演绎总结,想看后果的话间接看这里就行了。transformer 模型所需的总内存为:

 total_memory = memory_modal + 2 * memory_activations

模型参数的内存:

 4*n_tr_blocks*square_of(n_head * dim)

两头变量内存:

 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

咱们应用上面的符号能够更简洁地写出这些公式。

 R = n_tr_blocks = transformer 层重叠的数量
 N = n_head = 注意力头数量
 D = dim = 注意力头的维度
 B = batch_size = 批大小
 S = sequence_length = 输出序列的长度
 
 memory modal = 4 * R * N^2 * D^2
 
 memory activations = RBNS(S + 2D)

所以在训练模型时总的内存占用为:

 M = (4 * R * N^2 * D^2) + RBNS(S + 2D)

因为内存的占用和序列长度又很大的关系,如果有一个很长的序列长度 S >> D S + 2D <——> S,这时能够将计算变为:

 M = (4 * R * N^2 * D^2) + RBNS(S) = 4*R*N^2*D^2 + RBNS^2

能够看到对于较大的序列,M 与输出序列长度的平方成正比,与批大小成线性比例,这也就证实了序列长度和内存占用有很大的关系。

所以最终的内存占用的评估为:

 总内存 = ((4 * R * N^2 * D^2) + RBNS(S + 2D)) * float64(以字节为单位)

https://avoid.overfit.cn/post/6724eec842b740d482f73386b1b8b012

作者:Schartz Rehan

正文完
 0