PyTorch 对自定义 loss 函数自动求梯度
程序员文章站
2022-03-03 14:37:00
...
示例：通过 torch.autograd.grad 求自定义 loss 对各输入张量的梯度
class MMD(nn.Module):
    """Linear-kernel Maximum Mean Discrepancy (MMD) loss between two feature batches.

    The loss is the squared Euclidean distance between the per-feature means
    of the two batches::

        || mean(X, dim=0) - mean(Y, dim=0) ||_2 ** 2
    """

    def __init__(self):
        super().__init__()
        # Kept so existing code that accesses ``self.mmd`` keeps working; the
        # MMD value itself is computed by ``forward``, not by this MSE loss.
        self.mmd = torch.nn.MSELoss()

    def forward(self, fc1Features1, fc1Features2):
        """Return the linear MMD between two batches of features.

        Args:
            fc1Features1: Tensor whose first dimension indexes samples.
            fc1Features2: Tensor with the same trailing shape as
                ``fc1Features1``; its number of samples may differ.

        Returns:
            Scalar tensor ``||mean(fc1Features1) - mean(fc1Features2)||^2``,
            where each mean is taken over dimension 0.

        Raises:
            ValueError: If either batch contains no samples.
        """
        n1 = len(fc1Features1)
        n2 = len(fc1Features2)
        if n1 == 0 or n2 == 0:
            raise ValueError("MMD requires non-empty feature batches")
        # Each batch is averaged over its *own* sample count so batches of
        # different sizes are compared correctly.
        mean1 = torch.sum(fc1Features1, dim=0) / n1
        mean2 = torch.sum(fc1Features2, dim=0) / n2
        mean_diff = mean1 - mean2
        # Sum of squares equals norm(p=2) ** 2, but its gradient is simply
        # 2 * mean_diff, which stays well-defined (zero) when the means coincide.
        return torch.sum(mean_diff * mean_diff)

    def fc1_constrain(self, fc1Features1, fc1Features2, retain_graph=None, create_graph=False):
        """Compute the MMD loss and its gradients w.r.t. both feature batches.

        Args:
            fc1Features1: First feature batch. It must be a tensor with
                ``requires_grad=True`` (or derived from one) so autograd can
                differentiate the loss with respect to it.
            fc1Features2: Second feature batch; same requirement.
            retain_graph: Forwarded to ``torch.autograd.grad``. With the
                default ``None`` the graph is freed, so pass ``True`` if the
                returned loss will be backpropagated again.
            create_graph: Forwarded to ``torch.autograd.grad``; ``True`` builds
                a graph of the gradients for higher-order derivatives.

        Returns:
            ``(mmdLoss, [grad1, grad2])``. ``grad1`` and ``grad2`` are the
            partial derivatives of ``mmdLoss`` w.r.t. ``fc1Features1`` and
            ``fc1Features2``, respectively.
        """
        # A side CUDA stream is optional: wait on it only when one has been
        # attached to this module as ``self.stream``.
        stream = getattr(self, "stream", None)
        if stream is not None:
            torch.cuda.current_stream().wait_stream(stream)
        mmdLoss = self(fc1Features1, fc1Features2)
        # torch.autograd.grad(y, [x1, x2]) returns dy/dx1 and dy/dx2.
        grad1, grad2 = torch.autograd.grad(
            mmdLoss,
            [fc1Features1, fc1Features2],
            retain_graph=retain_graph,
            create_graph=create_graph,
        )
        mmdGradList = [grad1, grad2]
        return mmdLoss, mmdGradList