wing-loss pytorch
程序员文章站
2022-05-27 09:42:57
...
原文链接:点此查看原链接
# -*- coding: utf-8 -*-
# @Time : 2019/9/9
# @Author : Elliott Zheng
# @Email : [email protected]
import math
import torch
from torch import nn
# torch.log and math.log are both natural (base-e) logarithms
class WingLoss(nn.Module):
    """Wing loss: logarithmic for small errors, linear for large ones.

    For an absolute error ``d = |target - pred|`` the element-wise loss is::

        omega * ln(1 + d / epsilon)    if d < omega
        d - C                          otherwise

    where ``C = omega - omega * ln(1 + omega / epsilon)``, so that both
    pieces take the same value at ``d == omega``.

    Args:
        omega: Threshold on the absolute error; errors strictly below it use
            the logarithmic piece. Must be positive.
        epsilon: Divisor applied to the error inside the logarithm. Must be
            positive.
        reduction: ``"mean"`` (default) averages over every element,
            ``"sum"`` adds all elements up, ``"none"`` returns the
            element-wise loss with the broadcast shape of the inputs.

    Raises:
        ValueError: If ``omega`` or ``epsilon`` is not positive, or if
            ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, omega=10, epsilon=2, reduction="mean"):
        super(WingLoss, self).__init__()
        # Non-positive values would silently yield inf/NaN losses, so fail early.
        if omega <= 0 or epsilon <= 0:
            raise ValueError(
                "omega and epsilon must be positive, got omega={!r}, "
                "epsilon={!r}".format(omega, epsilon)
            )
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                "reduction must be one of {!r}, got {!r}".format(
                    self._REDUCTIONS, reduction
                )
            )
        self.omega = omega
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, pred, target):
        """Return the wing loss between ``pred`` and ``target``.

        Args:
            pred: Predicted tensor.
            target: Ground-truth tensor; must be broadcastable with ``pred``.

        Returns:
            A scalar tensor for ``"mean"`` / ``"sum"`` reduction, or the
            element-wise loss tensor for ``"none"``.
        """
        delta = (target - pred).abs()

        # Offset joining the linear piece to the log piece at delta == omega.
        # Computed per call so later changes to self.omega / self.epsilon
        # are honoured.
        offset = self.omega - self.omega * math.log(1 + self.omega / self.epsilon)

        # log1p(x) == log(1 + x) but stays accurate when x is very small.
        log_piece = self.omega * torch.log1p(delta / self.epsilon)
        linear_piece = delta - offset

        # Select per element instead of gathering with boolean masks: this
        # keeps the element-wise layout and avoids data-dependent shapes.
        loss = torch.where(delta < self.omega, log_piece, linear_piece)

        if self.reduction == "none":
            return loss
        if self.reduction == "sum":
            return loss.sum()
        return loss.mean()
if __name__ == "__main__":
    # Smoke test: run one forward/backward pass and print the loss.
    # Every absolute error is 1, which is below the default omega of 10, so
    # each element uses the log piece: 10 * ln(1 + 1/2), roughly 4.0547.
    loss_func = WingLoss()
    y = torch.ones(2, 68, 64, 64)
    y_hat = torch.zeros(2, 68, 64, 64)
    # Track gradients on the prediction so backward() has a leaf to fill.
    y_hat.requires_grad_(True)
    loss = loss_func(y_hat, y)
    loss.backward()
    print(loss)