PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性,但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”。 显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算。在PyTorch中使用autocast配置半精度训练,同时需要在下面三处加以设置:
import autocast
from torch.cuda.amp import autocast
在模型定义中,使用python的装饰器方法,用autocast装饰模型中的forward函数。关于装饰器的使用,可以参考这里:
@autocast()
def forward(self, x):
...
return x
在训练过程中,只需在将数据输入模型及其之后的部分放入“with autocast():“即可:
for x in train_loader:
x = x.cuda()
with autocast():
output = model(x)
...
注意:半精度训练主要适用于数据本身的size比较大,比如说3D图像、视频等。当数据本身的size并不大时,比如手写数字MNIST数据集的图片尺寸只有28 * 28,使用半精度训练则可能不会带来显著的提升。