import torch
input = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)
l1 = input * w1
l2 = l1 + w2
l3 = l1 * w3
l4 = l2 * l3
loss = l4.mean()
loss.backward()
print(w1.grad, w2.grad, w3.grad)
The result is:
tensor(28.) tensor(8.) tensor(10.)
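Note: in the figure, circles represent operators. Because executing an operator always produces an intermediate result, each circle stands for both the operator and its output (the intermediate result). Intermediate results are not created by the user, so they are non-leaf nodes. The operands drawn as boxes are leaf nodes.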
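The leaf / non-leaf distinction described in the note can be inspected directly on the tensors. The following is a minimal sketch, assuming the same setup as the example above; the name `x` is used here in place of `input` only to avoid shadowing the Python builtin.

```python
import torch

# Same values as the example above; `x` replaces `input` (illustrative rename).
x = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)

l1 = x * w1   # intermediate result produced by a multiplication operator
l2 = l1 + w2  # intermediate result produced by an addition operator

# Tensors created directly by the user are leaves and carry no grad_fn.
print(w1.is_leaf, w1.grad_fn)   # True None
# Tensors with requires_grad=False are treated as leaves by convention.
print(x.is_leaf)                # True
# Results of operators are non-leaf nodes and record the operator that made them.
print(l1.is_leaf, l2.is_leaf)   # False False
print(l1.grad_fn is not None)   # True
```

By default only leaf tensors that require gradients keep a populated `.grad` after `backward()`, which is why the example prints the gradients of `w1`, `w2`, and `w3` rather than those of `l1` to `l4`.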
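Because every element of input is 1, every element of each intermediate tensor has the same value, and the mean of l4 equals any single element of it. The computation can therefore be treated as scalar:
$ l_1 = w_1 $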
$ l_2 = w_1 + w_2$
$ l_3 = w_1w_3$
$ l_4 = l_2 * l_3 = (w_1 + w_2) * w_1w_3 = {w_1}^2w_3 + w_1w_2w_3$
$ \frac{ \partial l_4 }{\partial w_1} = 2w_1w_3 + w_2w_3 = 28 $
$ \frac{ \partial l_4 }{\partial w_2} = w_1w_3 = 8 $
$ \frac{ \partial l_4 }{\partial w_3} = {w_1}^2 + w_1w_2 = 10 $
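As a sanity check on the derivation above, the three partial derivatives can be compared against a numerical estimate. This is a sketch in plain Python, independent of PyTorch; the helper names `l4` and `central_diff` and the step size `h` are illustrative choices, not part of the original example.

```python
def l4(w1, w2, w3):
    # l4 = l2 * l3 = (w1 + w2) * (w1 * w3), as in the derivation above
    return (w1 + w2) * (w1 * w3)

def central_diff(f, args, i, h=1e-5):
    # Central finite difference of f with respect to its i-th argument.
    hi = list(args)
    lo = list(args)
    hi[i] += h
    lo[i] -= h
    return (f(*hi) - f(*lo)) / (2 * h)

w1, w2, w3 = 2.0, 3.0, 4.0
point = (w1, w2, w3)

# Numerical estimates of dl4/dw1, dl4/dw2, dl4/dw3
numeric = [central_diff(l4, point, i) for i in range(3)]

# Closed-form results from the derivation above
analytic = [2 * w1 * w3 + w2 * w3, w1 * w3, w1 ** 2 + w1 * w2]

print([round(v, 4) for v in numeric])  # [28.0, 8.0, 10.0]
print(analytic)                        # [28.0, 8.0, 10.0]
```

Both lists agree with the values printed by `loss.backward()` in the example.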
import torch
input = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)
l1 = input * w1
l2 = l1 + w2
l3 = l1 * w3
l3 = l3.detach()  # cut l3 out of the graph: it is now a constant, so no gradient flows back through it
l4 = l2 * l3
loss = l4.mean()
loss.backward()
print("w1.grad:", w1.grad, "\nw2.grad:", w2.grad, "\nw3.grad:", w3.grad)
The result is:
w1.grad: tensor(8.)
w2.grad: tensor(8.)
w3.grad: None
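After detach, l3 is treated as a constant, so gradients reach w1 and w2 only through l2, and w3 receives no gradient at all. Using the chain rule and the variable relationships derived in Section 3 above, we get: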
$ \frac{ \partial l_4 }{\partial w_1} = \frac{ \partial l_4 }{\partial l_2} * \frac{ \partial l_2 }{\partial w_1} = \frac{ \partial (l_2 * l_3) }{\partial l_2} * \frac{ \partial (w_1+w_2) }{\partial w_1} = l_3 * 1 = w_1w_3 = 8$
$ \frac{ \partial l_4 }{\partial w_2} = \frac{ \partial l_4 }{\partial l_2} * \frac{ \partial l_2 }{\partial w_2} = \frac{ \partial (l_2 * l_3) }{\partial l_2} * \frac{ \partial (w_1+w_2) }{\partial w_2} = l_3 * 1 = w_1w_3 = 8$
This gives w1.grad = tensor(8.) and w2.grad = tensor(8.), matching the printed output.
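To see why the detached `l3` behaves as a constant in the derivation, it can help to inspect the tensor that `detach()` returns. The snippet below is a small sketch under the same values as the example; `x` and `l3_detached` are names introduced here for illustration.

```python
import torch

x = torch.ones([2, 2])
w1 = torch.tensor(2.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)

l3 = (x * w1) * w3          # still attached to the graph
l3_detached = l3.detach()   # same values, but no autograd history

print(l3.requires_grad, l3.grad_fn is not None)        # True True
print(l3_detached.requires_grad, l3_detached.grad_fn)  # False None
print(l3_detached.is_leaf)                             # True
print(torch.equal(l3, l3_detached))                    # True
```

Since the detached tensor has no `grad_fn`, `backward()` has no path from it back to `w3`, which is consistent with `w3.grad` being `None` in the output above.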