飞燕提醒:detach()函数的一大作用是将将张量从GPU移动到CPU时,在本教程的《tensor与numpy互换小节》中大家可以深刻体会。
在PyTorch中,detach()函数的作用有两个:
(1)将张量从当前计算图中分离出来,从而不需要跟踪张量的梯度变化
(2)将张量从GPU移动到CPU时。
import torch
x = torch.tensor(2.0, requires_grad = True)
print("Tensor:", x)
x_detach = x.detach()
print("Tensor with detach:", x_detach)
结果为:
Tensor: tensor(2., requires_grad=True)
Tensor with detach: tensor(2.)
代码分析:在上述输出中,分离后的张量没有requires_grad=True
import torch
x = torch.rand(3, requires_grad = True)
print("x:", x)
y = 3 + x
z = 3 * x.detach()
print("y:", y)
print("z:", z)
输出结果为:
x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])
在PyTorch中,detach()函数是一个非常有用的方法,它主要用于创建一个与原张量tensor逻辑上分离的新张量。尽管它们共享相同的存储空间,但任何对detach()返回的新张量的修改都不会影响到原始张量,反之亦然。如下代码所示:
import torch
x = torch.tensor([1.0])
y = x + 1.0
print("y值:", y)
y_ = y.detach()
y_ = y_ + 1.0
print("分离出的y_值:", y_)
print("y值:", y)
结果为:
y值: tensor([2.])
分离出的y_值: tensor([3.])
y值: tensor([2.])
控制梯度更新:当你想控制计算图特定部分的梯度更新时,detach很有用。例如,在强化学习中,你可能希望将动作值从策略网络中分离出来,以防止梯度流过它们。
创建独立张量:如果你需要创建一个独立于当前计算图的张量,detach会有所帮助。当你想在不影响梯度计算的情况下存储中间结果时,detach可能很有用。
调试和可视化:在调试和可视化过程中,分离张量可能会有所帮助。通过分离张量,你可以确保调试或可视化代码不会干扰梯度计算。