PyTorch Custom Gradients
程序员文章站
2022-05-27 09:54:04
...
A simple example of [3]. The forward pass uses sign:

$$y=\text{sign}(x)=\begin{cases} -1 & x < 0 \\ 0 & x = 0 \\ 1 & x > 0 \end{cases}$$

The true derivative of sign is zero almost everywhere, so the backward pass uses the derivative of Htanh (hard tanh) instead:

$$\frac{\partial y}{\partial x}=\begin{cases} 1 & -\epsilon\le x\le \epsilon \\ 0 & \text{otherwise} \end{cases}$$
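In backpropagation, `backward` receives the upstream gradient `dy`, i.e. ∂L/∂y for some scalar loss L (a symbol introduced here for illustration), and must return ∂L/∂x. Because each y_i depends only on x_i, substituting the Htanh derivative into the chain rule gives, element-wise:

```latex
\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial y_i}\,\frac{\partial y_i}{\partial x_i}
  = \begin{cases}
      \dfrac{\partial L}{\partial y_i} & -\epsilon \le x_i \le \epsilon \\
      0 & \text{otherwise}
    \end{cases}
```

This is what the `torch.where` call in `backward` computes: it keeps `dy` inside [-ε, ε] and writes zeros elsewhere.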
Code
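import torch

class Htanh(torch.autograd.Function):
    """Forward: sign(x). Backward: hard-tanh derivative (pass dy through where |x| <= epsilon)."""

    @staticmethod
    def forward(ctx, x, epsilon=1.0):
        # Save the input itself (not x.data) so autograd can detect in-place modification
        ctx.save_for_backward(x)
        ctx.epsilon = epsilon  # non-tensor values can be stored directly on ctx
        return x.sign()

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        eps = ctx.epsilon
        # Zero the gradient outside [-epsilon, epsilon]; keep dy unchanged inside
        dx = torch.where((x < -eps) | (x > eps), torch.zeros_like(dy), dy)
        # One gradient per forward input: dx for x, None for epsilon
        return dx, None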
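Usage
htanh = Htanh.apply  # apply is called on the class itself; no instance is needed
x = torch.tensor([-2.5, -1.6, 0.0, 3.6, 4.7], requires_grad=True)
y = htanh(x)  # values: -1, -1, 0, 1, 1
y.sum().backward()  # upstream gradient dy is all ones
print(x.grad)  # values: 0, 0, 1, 0, 0 (only x = 0 lies within [-1, 1])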
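For comparison, and not the method used in this post: the same straight-through behaviour can also be obtained without a custom `torch.autograd.Function` by the "detach trick". The sketch below is a hedged illustration; `sign_ste` is a hypothetical helper name, and using `torch.clamp` as the differentiable surrogate is an assumption that matches the Htanh derivative away from the boundary points ±ε (behaviour exactly at ±ε is not checked here).

```python
import torch

def sign_ste(x: torch.Tensor, epsilon: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: value sign(x), gradient 1 inside (-epsilon, epsilon), 0 outside.

    surrogate + (sign - surrogate).detach() has the value of sign(x)
    (up to floating-point rounding), but only `surrogate` is seen by autograd.
    """
    surrogate = torch.clamp(x, -epsilon, epsilon)  # slope 1 between the bounds, 0 outside
    return surrogate + (x.sign() - surrogate).detach()


x = torch.tensor([-2.5, -1.6, 0.0, 3.6, 4.7], requires_grad=True)
y = sign_ste(x)
y.sum().backward()
print(y.detach())  # values: -1, -1, 0, 1, 1
print(x.grad)      # values: 0, 0, 1, 0, 0
```

The custom Function makes the surrogate gradient explicit and stores only what `backward` needs; the detach trick is shorter but ties the gradient to whichever differentiable surrogate is chosen.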
References