PyTorch的最大特点是动态计算图,计算图是用来描述运算的有向无环图。计算图有两种主要元素:结点(Node)和边(Edge)。结点表示数据,例如张量,而边表示运算,例如加、减、乘、除、卷积等。
对于节点而言,又分为叶子节点和非叶子节点。我们通常关注的叶子节点,那什么叶子节点呢?PyTorch中的张量tensor有一个属性是is_leaf,当is_leaf为True时,该tensor是叶子张量,也叫叶子节点。
在PyTorch中,默认情况下,非叶节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留,只有叶子节点的梯度值能够被保留下来。对于非叶子节点而言,PyTorch出于节省内存的考虑,通常不会保存节点的到数值。总之,一句话:在调用backward()时,只有当节点的requires_grad和is_leaf同时为真时,才会计算节点的梯度值,也就是说节点的grad属性才会赋值,否则为None
简单来说,所有用户创建的向量都是叶子结点。其中分为两种情况,分别为显示叶子节点和隐式叶子节点:
(1)用户创建的训练数据,因为显而易见所以称之为“显示叶子节点”
import torch
input = torch.ones([2, 2])
(2)用户创建的网络模型中自带的权重参数,因为暗藏其中所以称之为“隐式叶子节点”。例如,nn.Linear()
, nn.Conv2d()
等网络模型, 它们是用户创建的,其内部的权重参数也是叶子节点。
需要提醒的是:默认情况下,我们显示创建的张量tensor的requires_grad都是False值的,因为我们训练网络训练的是网络模型的权重,而不需要训练输入。
如何区分叶子节点和非叶子节点,下面有三个例子。
import torch
a = torch.tensor([1.0, 1.0], requires_grad=False)
print(a.is_leaf)
print(a.requires_grad)
b = a + 1
print(b.is_leaf)
print(b.requires_grad)
结果为:
True
False
True
False
代码分析:张量b是由张量a计算而得,并非我们创建,按理应该是非叶子节点,但是从PyTorch的角度来看,由于a的requires_grad的为False,其不要求获得梯度,那么a在反向传播时其实是“无意义”的,可认为是游离在计算图之外的,故b仍然为叶子节点。
import torch
a = torch.tensor([1.0, 1.0], requires_grad=True)
print(a.is_leaf)
print(a.requires_grad)
b = a + 1
print(b.is_leaf)
print(b.requires_grad)
结果为:
True
True
False
True
代码分析:与例子1相比,所不同之处在于a的requires_grad为True。因为张量b是由张量a计算而得,从直觉上来说,b应该是非叶子节点,实际上的确如此,与我们的直觉相契合。通过例子1和例子2,我们可以得到一条收获:中间变量并非都是非叶子节点,跟它所处的计算图有密切关系。
import torch
input = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)
l1 = input * w1
print("l1 is leaf?", l1.is_leaf)
print("l1 requires_grad:", l1.requires_grad)
l2 = l1 + w2
l3 = l1 * w3
l4 = l2 * l3
loss = l4.mean()
loss.backward()
print("w1.grad:", w1.grad, "\nw2.grad:", w2.grad, "\nw3.grad:", w3.grad)
print("l1.grad:", l1.grad, "\nl2.grad:", l2.grad, "\nl3.grad", l3.grad, "\nl4.grad", l4.grad)
结果为:
l1 is leaf? False
l1 requires_grad: True
w1.grad: tensor(28.)
w2.grad: tensor(8.)
w3.grad: tensor(10.)
l1.grad: None
l2.grad: None
l3.grad None
l4.grad None
代码分析:这个代码是深度学习常见的形式,l1,l2,l3,l4是非常典型的中间变量,是地地道道的非叶子节点。