为什么模型训练这么吃显存?

现象

我有一个pytorch深度学习模型,模型本身不大,经过计算发现其自身只有约40 million的待学习参数量,大约仅需要150m内存就够了,但是我在一块有24g显存的显卡上训练时,很容易OOM,即使将batch size调整到刚刚能训练且不爆显存的大小,在backward阶段仍然会OOM,问题在于为什么这么小的模型却需要这么大的显存?

排查

搜索一番发现显存占用主要存在于4个方面:

  • 模型参数(parameters)
  • 模型参数的梯度(gradients)
  • 优化器状态(optimizer states)
  • 中间激活值(intermediate activations) 

问题是它们会在什么时候才开始占用显存?会占用多大显存?

模型参数

模型创建时就会占用,占用大小为模型的 weights 和 bias 的总数。例如,创建下面这个线性模型:

self.fc = torch.nn.Linear(in_features=3, out_features=2)

一共需要 (3+ 1) * 2 * 4 字节的显存,其中 (3 + 1) 是3个weight和1个bias,将这个模型展开其实就是下面这个方程组:

w11 * x1 + w12 * x2 + w13 * x3 = y1
w21 * x1 + w22 * x2 + w23 * x3 = y2

这个占用大小一般都是固定的,不会在训练过程中变化

模型参数的梯度

这个会在backward调用时占用,这也是为什么训练的时候没用OOM,一backward就出错,所以设置batch size的时候应适当留一些显存用于存储梯度。

它占用的大小和模型参数相同,一个学习参数对应一个梯度嘛。占用大小也是固定不变的

优化器状态

会在 optimizer.step() 的时候占用,但是占用大小和具体优化器实现有关,例如,如果使用的是SGD,则其不会占用额外空间,但如果使用的是Adam,则会再额外占用2倍模型参数那么多空间,因为它还需要为每个参数保存两个梯度更新状态。

训练过程中占用大小固定不变。

中间激活值

这就是为什么这么小的模型需要这么大显存训练的原因,它会占所有显存的大头,但会在backward过程中(求导之后)被gc。并且其占用的显存会随着batch size的增大而增大。

一个大模型里面肯定会嵌套各种各样的小模型:

pred = model1(inner_model1(inner_model2(inner_model3(x)))) + model2(...)

例如,当你使用卷积提取图像特征时,你可能会写:

Relu( Batch_norm( Conv2d(img) ) )

整体是一个模型,但其实是三个模型的嵌套,这三个模型各有各的待学习参数,所以pytorch在forward时需要将每个模型的输出都单独保存下来作为计算图的节点,用于在backward过程中对其对应模型参数求导。导求完了也就不需要了,就会被gc掉。

需要注意的是,只有有学习参数的模型的输出才会被保存下来,例如:

...
x = Conv2d(x)   # x会被单独保存下来
x = Batch_norm(x)   # x也会被单独保存下来,上面保存的那个x并不会被gc
x = x * 2   # 这个x不会单独保存,因为它不是模型的输出,它其实等价于和上面合并:x = 2 * Batch_norm(x)
x = Relu(x)   # 会被单独保存
...

所以,假如上述4个x输出大小都是一样的话,则一共会额外占据 3 * x 的显存空间(非模型输出不额外占据空间)

要知道,一个模型的激活值就是下个模型的输入,也就是说,在整个训练过程中,会将batch size的数据集“重复保存”n次(假设n个模型嵌套,每个模型的输出大小等于原始输入大小。如果某些模型会增大输入size,则会占用更多显存),则占用的显存大小和batch size呈现n倍增长。

所以通常来说,模型越深,其占用的显存也就越大。

工具

查看显存占用

pytorch中显存使用量有三个概念:reservation、max、located

located:程序目前真正使用了多少显存。

torch.cuda.memory_allocated(0) # 参数0表示第0块显卡,需根据实际使用修改

max:程序运行中,显存的占用是动态变化的(例如backward是会申请新的显存存储梯度),所以一定存在某一时刻显存占用达到峰值,而这个max就是获取这个峰值的,你可以在程序中很多地方获取这个值,然后从后往前找,找到第一个达到该值的地方就是最占用显存的地方了。

torch.cuda.max_memory_allocated(0)

reservation:pytorch占用显存的数据被gc后,其占用的显存并不会被立即释放,而是会作为缓存留下了,下次这个程序又需要申请显存时就能直接从这些缓存中获取。使用nvidia-smi看到的其实是这个的占用,而不是真正的占用。

torch.cuda.memory_reserved(0)

如何查看中间激活值占用大小

计算出模型的学习参数大小

# 模型占用大小
params = sum(p.numel() for p in myModel.parameters() if p.requires_grad) * 4

# 优化器占用大小
if optimizer == SGD:
    optimizer_cost = 0
elif optimizer == Adam:
    optimizer_cost = params * 2

在backward之前加上

before_backward = torch.cuda.memory_allocated()  # B + 中间激活值占用,其中B为其他占用

backward之后也加上

after_backward = torch.cuda.memory_allocated()  # B + 梯度占用 + 优化器占用

则激活值占用的空间为(字节):

# after_backward - before_backward = (B + 梯度占用 + 优化器占用) - (B + 中间激活值占用) = 梯度占用 + 优化器占用 - 中间激活值占用
# 中间激活值占用  =  梯度占用 + 优化器占用 - (after_backward - before_backward)
activations_memory_cost = params + optimizer_cost - (after_backward - before_backward)

注:这里面只包含了中间激活值,模型最初的输入和最后的输出是不计在内的,例如模型最初的输入为数据集的图片,最终的输出为一个激活函数的输出,这个图片和激活函数的输出是不计的。

参考

https://blog.csdn.net/hxxjxw/article/details/121176061

Leave a Comment