继承Function类,自定义backward函数求loss
程序员文章站
2022-05-27 09:46:11
...
torch.nn.Function类
自定义模型、自定义层、自定义**函数、自定义损失函数都属于pytorch的拓展,前面讲过通过继承torch.nn.Module类来实现拓展,它最大的特点是以下几点:
- 包装torch普通函数和torch.nn.functional专用于神经网络的函数;(torch.nn.functional是专门为神经网络所定义的函数集合)
- 只需要重新实现__init__和forward函数,求导的函数是不需要设置的,会自动按照求导规则求导。
- 可以保存参数和状态信息;
注意:当在构建模型时,有时候一些****操作是不可导****的,这时候你需要自定义求导方式,就不能再使用上面提到的方式了,需要通过继承torch.nn.Function类来实现拓展。
它最大的特点是:
- 在有些操作通过组合pytorch中已有的层或者是已有的方法实现不了的时候,比如你要实现一个新的方法,这个新的方法需要forward和backward一起写,然后自己写对中间变量的操作。
- 需要重新实现__init__和forward函数,以及backward函数,需要自己定义求导规则;
- 不可以保存参数和状态信息
Function类和Module类最明显的区别是它多了一个backward方法,这也是他俩****最本质的区别:****
如果某一个类my_function继承自Function类,实现了这个类的forward和backward方法,那么我依然可以用nn.Module对这个自定义的类my_function进行包装组合。
# 定义一个继承了Function类的子类,实现y=f(x)的正向运算以及反向求导
class sqrt_and_inverse(torch.autograd.Function):
'''
本例子所采用的数学公式是:
z=sqrt(x)+1/x+2*power(y,2)
z是关于x,y的一个二元函数它的导数是
z'(x)=1/(2*sqrt(x))-1/power(x,2)
z'(y)=4*y
forward和backward可以定义成静态方法,向定义中那样,也可以定义成实例方法
'''
# 前向运算
def forward(self, input_x, input_y):
'''
self.save_for_backward(input_x,input_y) ,这个函数是定义在Function的父类_ContextMethodMixin中
它是将函数的输入参数保存起来以便后面在求导时候再使用,起前向反向传播中协调作用
'''
self.save_for_backward(input_x, input_y)
# 对输入和参数进行的操作,其实就是前向运算的函数表达式]
output = torch.sqrt(input_x) + torch.reciprocal(input_x) + 2 * torch.pow(input_y, 2)
return output
def backward(self, grad_output):
# 计算梯度是链式法则,输入的参数grad_output为反向传播上一级计算得到的梯度值
input_x, input_y = self.saved_tensors # 获取前面保存的参数,也可以使用self.saved_variables
# 求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式
# 这里上一级梯度值grad_output乘以当前级的梯度
grad_x = grad_output * (torch.reciprocal(2 * torch.sqrt(input_x)) - torch.reciprocal(torch.pow(input_x, 2)))
grad_y = grad_output * (4 * input_y)
return grad_x, grad_y # 需要注意的是,反向传播得到的结果需要与输入的参数相匹配
# 方法一 用类将我们继承了Function类的自定义子类包装
class DemoModel(torch.nn.Module):
def __init__(self):
super(DemoModel, self).__init__()
def forward(self,input_x, input_y): # 这里是对象调用的含义,因为function中实现了__call__
return sqrt_and_inverse()(input_x, input_y)
# 方法二 为了让它看起来更像是一个pytorch函数,包装一下
def sqrt_and_inverse_func(input_x, input_y):
return sqrt_and_inverse()(input_x, input_y) # 这里是对象调用的含义,因为function中实现了__call__
# x = torch.tensor(3.0, requires_grad=True) # 标量
# y = torch.tensor(2.0, requires_grad=True)
# z = sqrt_and_inverse_func(x, y)
# print('开始前向传播')
#
# z = sqrt_and_inverse_func(x, y).sum()
# print('开始反向传播')
# z.backward() # 这里是标量对标量求导
# print(x.grad)
# print(y.grad)
x=torch.tensor([2.0,3.0,4.0],requires_grad=True) #tensor
y=torch.tensor([12.0,13.0,14.0],requires_grad=True) #tensor
print('开始前向传播')
z = sqrt_and_inverse_func(x, y).sum()
# z = DemoModel()(x,y).sum()
print('开始反向传播')
z.backward() # 这里是标量对向量求导
print(x.grad)
print(y.grad)
'''运行结果为:
开始前向传播
开始反向传播
tensor(0.1776)
tensor(8.)
================================
开始前向传播
开始反向传播
tensor([0.1036, 0.1776, 0.1875])
tensor([48., 52., 56.])
# 当类包装使用结果
开始前向传播
开始反向传播
tensor([0.1036, 0.1776, 0.1875])
tensor([48., 52., 56.])
'''
def backward(self, grad_output):函数的参数是grad_output为反向传播上一级计算得到的梯度值, 其实就是上篇文章讲解,y.backward(gradient)中gradient。