欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 使用amp.autocast半精度加速训练

程序员文章站 2024-02-28 19:51:10
...

pytorch 使用autocast半精度加速训练

如何使用autocast?

根据官方提供的方法,
答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。
因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。
要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
    	with autocast():
    		# code
    	return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

  1. matmul
  2. addbmm
  3. addmm
  4. addmv
  5. addr
  6. baddbmm
  7. bmm
  8. chain_matmul
  9. conv1d
  10. conv2d
  11. conv3d
  12. conv_transpose1d
  13. conv_transpose2d
  14. conv_transpose3d
  15. linear
  16. matmul
  17. mm
  18. mv
  19. prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。