欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[pytorch] 二分类交叉熵逆样本频率权重

程序员文章站 2022-05-26 19:02:44
...

通常,由于类别不均衡,需要使用weighted cross entropy loss平衡。

def inverse_freq(label):
	"""
	输入label [N,1,H,W],1是channel数目
	"""
    den = label.sum() # 0
    _,_,h,w= label.shape
    num = h*w
    alpha = den/num # 0
    return torch.tensor([alpha, 1-alpha]).cuda()

# train
...
loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))

代码比较简单,写在博客上保存。

相关标签: torch