欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch如何自定义Loss

程序员文章站 2022-05-27 09:52:16
...
将Loss视作单独的层,在forward函数里写明loss的计算方式,无需定义backward

class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()
        print '1'
    def forward(self, pred, truth):
        return  torch.mean(torch.mean((pred-truth)**2,1),0)

初始化函数一定要添加 
super(MyLoss, self).__init__()

否则出现 AttributeError: 'MyLoss' object has no attribute '_forward_pre_hooks' 错误。

要打印loss可以用

loss.data.cpu().numpy()[0]
访问