飞燕提醒:tensor类型转为numpy类型,要注意使用detach()函数,将张量从GPU设备迁移到CPU设备上。此处照应了本教程中《detach原理》相关小节的内容。
在做AI业务开发时候,常常需要PyTorch的tensor类型和numpy类型进行转换,下面给大家介绍一下两者的转换过程。
首先,导入需要使用的包:
import numpy as np
import torch
然后,创建一个numpy类型的数组:
x = np.ones(3)
print(type(x))
这里创建了一个一维的数组,3个都为1,我们打印一下这个x的类型显示如下:
<class "numpy.ndarray">
用下面的代码将上述的x转换成tensor类型:
y = torch.tensor(x)
print(type(x))
这个打印的结果是:
<class "torch.Tensor">
当然,也可以使用:
y = torch.from_numpy(x)
import torch
x = torch.ones(3)
y = x.detach().numpy()
也可以使用:
y = x.numpy()