Smooth L1 Loss
程序员文章站
2023-12-31 22:19:58
...
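实现代码如下（其中 $\beta = 1/\sigma^2$；当 $|x| < \beta$ 时损失为 $0.5\,x^2/\beta$，否则为 $|x| - 0.5\,\beta$，$x$ 为预测值与目标值之差）：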
def smooth_l1_loss(input, target, sigma, reduce=True, normalizer=1.0):
    """Compute the sigma-parameterized Smooth L1 loss between two tensors.

    The transition point between the quadratic and linear regimes is
    ``beta = 1 / sigma**2``. For each element, with ``d = |input - target|``:

        0.5 * d**2 / beta   if d < beta
        d - 0.5 * beta      otherwise

    Args:
        input: Predicted values (a torch tensor).
        target: Target values; must be subtractable from / broadcastable
            with ``input``.
        sigma: Controls ``beta = 1 / sigma**2``. If ``sigma`` is a Python
            number equal to 0, the ``1. / ...`` division raises
            ZeroDivisionError.
        reduce: If truthy, sum the per-element losses over every element
            and return a scalar tensor. Otherwise sum only along dim 1.
            NOTE(review): this presumably expects an (N, D) layout and
            yields one value per row; it requires at least 2 dimensions.
        normalizer: The summed loss is divided by this value.

    Returns:
        The summed (and normalized) loss tensor.
    """
    sigma_sq = sigma ** 2
    beta = 1. / sigma_sq

    abs_err = torch.abs(input - target)

    # Both branches are evaluated element-wise; torch.where picks one per element.
    quadratic_part = 0.5 * abs_err ** 2 / beta
    linear_part = abs_err - 0.5 * beta
    per_element = torch.where(abs_err < beta, quadratic_part, linear_part)

    summed = per_element.sum() if reduce else per_element.sum(dim=1)
    return summed / normalizer