欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch自动求导机制、自定义**函数和梯度

程序员文章站 2022-05-27 09:46:59
...

Pytorch自动求导机制、自定义**函数和梯度

前言:

由于pytorch框架只是提供了正向传播的机制,模块中的参数的梯度是通过自动求导推倒出来的,当我们需要自定义某一个针对张量的一些列操作时候就不够用了。

1 自动求导机制

Pytorch会根据计算过程来自动生成动态计算图,然后可以根据动态图的创建过程进行反向传播,计算得到每个节点的梯度直。

1.0 张量本身grad_fn

为了能记录张量的梯度,首先需要在张量创建的时候设置 requires_grad =True.

对于pytorch来说,每一个张量都有一个grad_fn方法,这个方法包含着创建该张量的运算的导数信息。本身携带计算图的信息,该方法还有一个next_functions属性,包含链接该张量的其他张量的grad_fn。

1.1 torch.autograd

Pytorch提供了一个专门用来做自动求导的包,torch.autograd.

包含2个重要函数:

1.1.1 torch.autograd.backward

这个函数通过传入根节点张量,以及初始梯度张量,可以计算产生该根节点所对应的叶子节点的梯度。

当张量为标量张量的时候(及只有一个元素的张量)可以部传入初始梯度张量,默认会设置初始梯度张量为1。

当计算梯度张量的时候,原先建立的计算图会自动释放,如果直接再次求导,肯定就会报错。

如果要在反向传播的时候保留计算图,可以设置retain_graph= True.

在自动求导的时候默认是不会建立反向传播图的,如果需要反向传播计算的同时建立和梯度张量相关的计算图,可以设置create_graph=Ture.

另外,对于一个可到的张量,也可以直接调用该张量内部的backward函数来自动求导。

t1=torch.randn(3,3,requires_grad=True)
t2 =t1.pow(2).sum()
#t2对t1张量求导
t2.backward()#反向传播
t1.grad
t2 =t1.pow(2).sum()
t2.backward()#再次反向传播
t1.grad #梯度累计
t1.grad.zero_() # 单个张量清零

1.1.2 torch.autograd.grad

在某些情况下,我们并不需要求出当前张量对所有产生该张量的叶子节点的梯度,这时候我们可以使用torch.autograd.grad方法。

该函数有2个参数,第一个参数是计算图的数据结果张量,第二个参数是需要对计算图求导的张量,最后输出的结果是第一个参数对第二个参数的求导结果,这个输出梯度也是会累计的。

要注意的地方:

1、这个函数部会改变叶子节点的grad属性。

2、反向传播求导时,自动释放计算图,如果要保留,可以设置retain_graph= True.

3、如果需要反向传播计算图,可以设置create_graph=Ture.

t1=torch.randn(3,3,requires_grad=True)
t2 =t1.pow(2).sum()
#t2对t1张量求导
torch.autograd.grad(t2,t1)

2 自定义**函数和梯度

前言里说了,仅仅使用模块有时候是不能满足我们需要效果的。我们需要自定义**函数,在**函数中定义前向传播和反向传播的代码来实现自己的需求。

2.1 类及方法

Pytorch自定义**函数继承于torch.autograd.Function,其内部有2个静态方法:forward和backward

class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx,input):
        return result
    
    @staticmethod
    def backward(ctx,grad_output):
        return grad_output

2.2 实例

Quoc V.Le等人的研究成果中,将Swish**函数定义为

Pytorch自动求导机制、自定义**函数和梯度

可以看到,这个公式还是比较复杂的,如果要生成图,中间有部少计算节点。

有了公式之后,我们可以求出导数函数,这样方便进行反向传播。

有了**函数和其导数函数,我们就可以来自定义相关**函数了。

swish =Swish.apply #获得**函数
torch.autograd.gradcheck(
swish,torch.randn(
10,requires_grad =True,
dtype =torch.double)
)
#测试反向传播,正常返回值为True

class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx,input):
        ctx.input =input
        return input*torch.sigmoid(1*input) #假设b=1
     @staticmethod
    def backward(ctx,grad_output):
        ctx.input =input
        tmp = torch.sigmoid(1*input)
        
        return grad_output*(tmp +1 *input*tmp(1-tmp))
    

2.3 tips

在上面代码可以看到,我们记录了前像传播和反向传播的过程,并且在backward方法中实现了数值梯度的方法。

可以通过讲apply方法赋值给一个变量的方法来**自定义的**函数。

为了保持梯度精度,我们一般都使用双精度类型为张量数值类型