Pytorch1.4自动梯度计算自定义
程序员文章站
2022-05-27 09:45:53
...
Pytorch提供自动求导机制,但有些时候进行自定义操作不可避免。这里采用官网EXTENDING PYTORCH的例子,对其加以注释,并对一些代码进行必要解释。
提示
使用者需要掌握ctx中的特殊功能,以确保用户自定义可以与autograd一起正常使用:
- save_for_backward():保存前向传播的输入和输出,为了后面反向传播使用,官方文档对此做出解释:
save_for_backward(*tensors)
1、保存给定的张量,以备将来调用backward();
2、应该最多调用一次,并且只能从forward()方法内部调用;
3、以后,可以通过saved_tensors属性访问已保存的张量,在将它们返回给用户之前,应进行检查,确保其未被就地(in-place)操作修改。
4、参数也可以为None。 - mark_dirty():标记由前向传播就地(inplace)修改的输入;
- mark_non_differentiable():如果输出不可微分的话,通过此功能被告知。
Linear函数
#继承Function
class LinearFunction(Function):
# 注意:前向和反向都使用了 @staticmethods
# @staticmethods 和 @classmethods 我会在写一篇文章专门介绍
@staticmethod
# 偏置bias是可选参数,这里默认None
def forward(ctx, input, weight, bias=None):
# ctx在这里类似self,ctx的属性可以在backward中调用
# 将Tensor保存到ctx中
ctx.save_for_backward(input, weight, bias)
# torch.t()方法,对2D tensor进行转置
output = input.mm(weight.t())
if bias is not None:
# expand_as(tensor)等价于expand(tensor.size()),
# 将原tensor按照新的size进行扩展
output += bias.unsqueeze(0).expand_as(output)
return output
#此函数只有一个输出,因此只能得到一个梯度
@staticmethod
def backward(ctx, grad_output):
# 这是一个非常方便的模式-位于backward的顶部
@staticmethod
# 加载save_for_backward保存的tensor数据
input, weight, bias = ctx.saved_tensors
# 分别代表输入,权值,偏置三者的梯度
grad_input = grad_weight = grad_bias = None
# 这些needs_input_grad是可选的,仅用于提高效率。
# 如果您想简化代码,可以跳过它们
# 为不需要的输入返回梯度并不会产生一个错误
# 判断三者是否需要进行反向求导计算梯度
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
# 现在,为了使使用这些自定义操作更容易,我们建议为它们的apply方法加上别名
linear = LinearFunction.apply
在这里,我们给出了一个由非Tensor参数参数化的函数的附加示例:
class MulConstant(Function):
@staticmethod
def forward(ctx, tensor, constant):
# ctx是一个上下文对象,可用于隐藏信息以进行backward计算
ctx.constant = constant
return tensor * constant
@staticmethod
def backward(ctx, grad_output):
# 我们返回与参数一样多的输入梯度
# 前向传播的非Tensor参数的梯度必须为None
return grad_output * ctx.constant, None
检测backward()是否正确
我们通过调用torch.autograd.gradcheck方法检查实现的backward()是否正确。
这里简要介绍下torch.autograd.gradcheck:
torch.autograd.gradcheck(func, inputs, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, check_sparse_nnz=False, nondet_tol=0.0)
参数:
- func(function)-接受Tensor输入并返回Tensor或Tensors元组的Python函数
- inputs (tuple of Tensor or Tensor) –输入
- eps (python:float, optional)-有限差分摄动
- atol (python:float, optional) –绝对公差
from torch.autograd import gradcheck
# gradcheck将tensor元组作为输入,
#检查使用这些张量评估的梯度是否足够接近数字近似值,如果它们都验证了此条件,则返回True
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True),
torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
如果代码没问题,输出结果为True