PyTorch - Autograd: Automatic Differentiation
程序员文章站
2022-07-12 23:00:07
...
flyfish
Reference URL
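import torch
# torch.autograd.Variable is deprecated since PyTorch 0.4; plain tensors track gradients now

a = torch.randn(2, 2)
print(a)
a = ((a * 3) / (a - 1))
print(a)
# Tensors created by the user default to requires_grad=False
print(a.requires_grad)  # False
# requires_grad_() switches gradient tracking on in place
a.requires_grad_(True)
print(a.requires_grad)  # True
b = (a * a).sum()
# b was produced by an operation on a tracked tensor, so it has a grad_fn
print(b.grad_fn)  # <SumBackward0 object at 0x...>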
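The snippet above stops at printing b.grad_fn. Since b = Σ a_ij², its gradient with respect to a should be 2a. The following minimal sketch goes one step further and calls b.backward() to confirm this; the fixed seed (torch.manual_seed(0)) is an assumption added only to make the run repeatable.

```python
import torch

torch.manual_seed(0)              # assumption: fixed seed for a repeatable run
a = torch.randn(2, 2)
a = (a * 3) / (a - 1)             # computed without tracking, so a is still a leaf
a.requires_grad_(True)            # start recording operations on a from here on
b = (a * a).sum()                 # b = sum of a_ij^2, a scalar
b.backward()                      # fills a.grad with db/da

# db/da_ij = 2 * a_ij
print(torch.allclose(a.grad, 2 * a.detach()))
```

Because requires_grad_() is called after the arithmetic, a stays a leaf tensor, which is why its .grad attribute gets populated.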
#Create a tensor and set requires_grad=True to track computation with it
x = torch.ones(2, 2, requires_grad=True)
print(x)
# =============================================================================
# tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
# =============================================================================
#Do a tensor operation:
y = x + 2
# (older PyTorch releases called this attribute y.creator; it is now y.grad_fn)
print(y)
# =============================================================================
# tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
# =============================================================================
# y was created as a result of an operation, so it has a grad_fn.
print(y.grad_fn)
#<AddBackward0 object at 0x7f3709bbc780>
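The distinction between x and y here is that x is a leaf (created directly by the user) while y is the result of an operation. A small sketch of how that shows up in the tensor attributes:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)  # created by the user: a leaf tensor
y = x + 2                                  # created by an operation: not a leaf

print(x.is_leaf, x.grad_fn)                 # True None  (leaves have no grad_fn)
print(y.is_leaf, type(y.grad_fn).__name__)  # False AddBackward0
print(x.grad)                               # None until backward() has run
```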
z = y * y * 3
# backward() can create the gradient implicitly only for scalar outputs,
# so reduce z to a scalar with mean()
out = z.mean()
print(z, out)
# =============================================================================
# tensor([[27., 27.],
# [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
# =============================================================================
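To see why the scalar reduction matters, the sketch below calls backward() directly on the 2×2 tensor z, which raises a RuntimeError, and then passes the upstream gradient explicitly. Since out is the mean of 4 elements, d(out)/dz_i = 1/4, so supplying a tensor of 0.25 should reproduce the same x.grad as out.backward().

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3                     # shape (2, 2), not a scalar

try:
    z.backward()                  # no implicit gradient for a non-scalar output
except RuntimeError as err:
    print("RuntimeError:", err)

# Explicit upstream gradient: d(mean)/dz_i = 1/4 for each of the 4 elements
z.backward(torch.full_like(z, 0.25))
print(x.grad)                     # 4.5 in every position, as with out.backward()
```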
#Let’s backprop now. Because out contains a single scalar,
#out.backward() is equivalent to out.backward(torch.tensor(1.)).
out.backward()
print(x.grad)#Print gradients d(out)/dx
# =============================================================================
# tensor([[4.5000, 4.5000],
# [4.5000, 4.5000]])
# =============================================================================
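As a quick check of the comment that out.backward() is equivalent to out.backward(torch.tensor(1.)), the sketch below computes the gradient three ways: implicitly, with an explicit gradient of 1, and with torch.autograd.grad, which returns the gradient instead of accumulating it into x.grad. The helper name grad_of_out is hypothetical, introduced only for this illustration.

```python
import torch

def grad_of_out(mode: str) -> torch.Tensor:
    """Hypothetical helper: rebuild the graph and return d(out)/dx."""
    x = torch.ones(2, 2, requires_grad=True)
    out = (3 * (x + 2) * (x + 2)).mean()
    if mode == "implicit":
        out.backward()                    # implicit gradient for a scalar output
        return x.grad
    if mode == "explicit":
        out.backward(torch.tensor(1.))    # explicit d(out)/d(out) = 1
        return x.grad
    (gx,) = torch.autograd.grad(out, x)   # returns the gradient, leaves x.grad alone
    return gx

g_implicit = grad_of_out("implicit")
g_explicit = grad_of_out("explicit")
g_functional = grad_of_out("functional")
print(torch.equal(g_implicit, g_explicit), torch.equal(g_implicit, g_functional))
```

Each call rebuilds x from scratch, so gradients from one call cannot accumulate into another.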
The derivation given in the official tutorial:
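Let the output variable out be o. Then
$$o = \frac{1}{4}\sum_i z_i, \qquad z_i = 3(x_i + 2)^2, \qquad z_i\big|_{x_i=1} = 27.$$
Therefore
$$\frac{\partial o}{\partial x_i} = \frac{3}{2}(x_i + 2),$$
hence
$$\frac{\partial o}{\partial x_i}\bigg|_{x_i=1} = \frac{9}{2} = 4.5.$$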
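The formula ∂o/∂x_i = (3/2)(x_i + 2) holds for any x, not only x_i = 1. The sketch below compares it with autograd at a random point; the seed is an assumption added for repeatability.

```python
import torch

torch.manual_seed(0)                         # assumption: fixed seed
x = torch.randn(2, 2, requires_grad=True)    # an arbitrary point, not only x_i = 1
o = (3 * (x + 2) ** 2).mean()                # o = (1/4) * sum_i 3 (x_i + 2)^2
o.backward()

analytic = 1.5 * (x.detach() + 2)            # do/dx_i = (3/2)(x_i + 2)
print(torch.allclose(x.grad, analytic))
```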
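My own manual calculation: differentiating z_i = 3(x_i + 2)^2 gives
$$\frac{\partial z_i}{\partial x_i} = 6(x_i + 2).$$
Substitute each element, then divide each result by 4 (out is the mean of the 4 elements of z), that is
$$\frac{\partial o}{\partial x_i} = \frac{6(x_i + 2)}{4}.$$
For example, with the input: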
# input
[[1., 2.],
 [3., 4.]]
Substituting each element:
6*(1+2)/4 = 4.5
6*(2+2)/4 = 6
6*(3+2)/4 = 7.5
6*(4+2)/4 = 9
# output
[[4.5000, 6.0000],
 [7.5000, 9.0000]]
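The hand calculation can be reproduced with autograd by feeding the same input through the same y, z and out steps and comparing x.grad with 6(x_i + 2)/4:

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()

manual = 6 * (x.detach() + 2) / 4   # dz_i/dx_i = 6(x_i + 2), divided by 4 for the mean
print(x.grad)                       # [[4.5, 6.0], [7.5, 9.0]]
print(torch.allclose(x.grad, manual))
```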