PyTorch初始化时,其默认浮点数据类型为torch.float32,同时相应的复数的默认数据类型为torch.complex64。利用torch.set_default_dtype(d)可以修改pytorch全局的默认精度,代码如下所示:
>>> # PyTorch 默认浮点类型是 torch.float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # PyTorch 默认复数的浮点类型是 torch.complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> torch.tensor([1.2, 3]).dtype
torch.float64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> torch.tensor([1.2, 3]).dtype
torch.float16
>>> torch.tensor([1.2, 3j]).dtype
torch.complex32