OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)
https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
综述:解决目标检测中的样本不均衡问题
该综述主要介绍了OHEM,Focal loss,GHM loss;由于我这的二分类数据集不存在正负样本不均衡的问题,所以着重看了处理难易样本不均衡(正常情况下,容易的样本较多,困难的样本较少);由于我只是分类问题,所以写了各种分类的loss,且网络的最后一层为softmax,所以网络输出的pred是softmax层前的logits经过softmax后的结果,普通的交叉熵损失即为sum(-gt*log(pred)),但torch.nn.CrossEntropyLoss()中会对于输入的pred再进行一次softmax,所以这里使用torch.nn.NLLLoss代替,当然经测试,即使网络最后一层使用softmax损失函数还是使用torch.nn.CrossEntropyLoss(),效果和使用torch.nn.NLLLoss差不多。。。
OHEM:
代码参考:https://www.codeleading.com/article/7442852142/
def ohem_loss(pred, target, keep_num):
loss = torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)
print(loss)
loss_sorted, idx = torch.sort(loss, descending=True)
loss_keep = loss_sorted[:keep_num]
return loss_keep.sum() / keep_num
Focal loss:
详解:原论文Focal Loss for Dense Object Detection
代码参考:https://zhuanlan.zhihu.com/p/80594704
def focal_loss(pred,target,gamma=0.5):
pred_temp=pred.detach().cpu()
target_temp=target.detach().cpu()
pt = torch.tensor([pred_temp[i,target_temp[i]] for i in range(target_temp.shape[0])])
focal_weight = (1-pt).pow(gamma)
return torch.mean((torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)).mul(focal_weight.to(device).detach()))
GHM loss:
详解:https://zhuanlan.zhihu.com/p/80594704
代码参考:https://github.com/DHPO/GHM_Loss.pytorch/blob/master/GHM_loss.py
class GHM_Loss(nn.Module):
def __init__(self, bins, alpha):
super(GHM_Loss, self).__init__()
self._bins = bins
self._alpha = alpha
self._last_bin_count = None
def _g2bin(self, g):
return torch.floor(g * (self._bins - 0.0001)).long()
def _custom_loss(self, x, target, weight):
raise NotImplementedError
def _custom_loss_grad(self, x, target):
raise NotImplementedError
def forward(self, x, target):
g = torch.abs(self._custom_loss_grad(x, target))
bin_idx = self._g2bin(g)
bin_count = torch.zeros((self._bins))
for i in range(self._bins):
bin_count[i] = (bin_idx == i).sum().item()
N = x.size(0)
nonempty_bins = (bin_count > 0).sum().item()
gd = bin_count * nonempty_bins
gd = torch.clamp(gd, min=0.0001)
beta = N / gd
return self._custom_loss(x, target, beta[bin_idx])
class GHMC_Loss(GHM_Loss):
def __init__(self, bins, alpha):
super(GHMC_Loss, self).__init__(bins, alpha)
def _custom_loss(self, x, target, weight):
return torch.sum((torch.nn.NLLLoss(reduce=False)(torch.log(x),target)).mul(weight.to(device).detach()))/torch.sum(weight.to(device).detach())
def _custom_loss_grad(self, x, target):
x=x.cpu().detach()
target=target.cpu()
return torch.tensor([x[i,target[i]] for i in range(target.shape[0])])-target
上一篇: 怎么取得路径返回父目录呢
下一篇: win7 下mysql的安装和设置