[pytorch] 二分类交叉熵逆样本频率权重
程序员文章站
2022-05-26 19:02:44
...
通常,由于类别不均衡,需要使用weighted cross entropy loss平衡。
def inverse_freq(label):
"""
输入label [N,1,H,W],1是channel数目
"""
den = label.sum() # 0
_,_,h,w= label.shape
num = h*w
alpha = den/num # 0
return torch.tensor([alpha, 1-alpha]).cuda()
# train
...
loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))
代码比较简单,写在博客上保存。
上一篇: 普陀佛茶是什么茶,关于普陀佛茶你了解多少
下一篇: 2019年用户最爱朋友圈广告TOP10