Pytorch如何自定义Loss
程序员文章站
2022-05-27 09:46:53
...
Pytorch如何自定义Loss
原文:https://blog.csdn.net/yutingzhaomeng/article/details/80454545
将Loss视作单独的层,在forward函数里写明loss的计算方式,无需定义backward
-
class MyLoss(nn.Module):
-
def __init__(self):
-
super(MyLoss, self).__init__()
-
print '1'
-
def forward(self, pred, truth):
-
return torch.mean(torch.mean((pred-truth)**2,1),0)
初始化函数一定要添加
super(MyLoss, self).__init__()
否则出现 AttributeError: 'MyLoss' object has no attribute '_forward_pre_hooks' 错误。
要打印loss可以用
loss.data.cpu().numpy()[0]
下一篇: XGBOOST学习实战