
OHEM, Focal loss, and GHM loss for binary classification in PyTorch (mitigating hard/easy sample imbalance)


https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
Survey: solving the sample imbalance problem in object detection
That survey mainly covers OHEM, Focal loss, and GHM loss. My binary-classification dataset has no positive/negative class imbalance, so I focused on the hard/easy sample imbalance (in the usual case, easy samples are plentiful and hard samples are scarce). Since I only have a classification problem, I wrote classification versions of each loss. The last layer of my network is a softmax, so the network output pred is the result of passing the pre-softmax logits through softmax, and plain cross-entropy is then sum(-gt * log(pred)). However, torch.nn.CrossEntropyLoss() applies a log-softmax to its input internally, so I use torch.nn.NLLLoss on log(pred) instead. (In my tests, keeping the final softmax layer and still feeding the output to torch.nn.CrossEntropyLoss() performs about the same as using torch.nn.NLLLoss anyway.)
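As a sanity check on that equivalence, here is a minimal sketch with made-up logits and labels, showing that NLLLoss on log(softmax(logits)) matches CrossEntropyLoss on the raw logits:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)           # raw scores before the softmax layer
target = torch.tensor([0, 1, 1, 0])  # dummy binary labels
pred = F.softmax(logits, dim=1)      # what the network outputs in this post

ce = F.cross_entropy(logits, target)       # applies log-softmax internally
nll = F.nll_loss(torch.log(pred), target)  # log of the softmax output
print(torch.allclose(ce, nll))             # True (up to float error)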

OHEM (Online Hard Example Mining):
Code reference: https://www.codeleading.com/article/7442852142/

import torch

def ohem_loss(pred, target, keep_num):
    # Per-sample cross-entropy: pred is already softmaxed, so log + NLLLoss
    loss = torch.nn.NLLLoss(reduction='none')(torch.log(pred), target)
    # Keep only the keep_num hardest (highest-loss) samples in the batch
    loss_sorted, _ = torch.sort(loss, descending=True)
    loss_keep = loss_sorted[:keep_num]
    return loss_keep.sum() / keep_num
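For illustration, a quick call on dummy data, keeping the 2 hardest of 4 samples (the values here are hypothetical, not from the post):

pred = torch.softmax(torch.randn(4, 2), dim=1)  # fake softmax outputs
target = torch.tensor([0, 1, 1, 0])
loss = ohem_loss(pred, target, keep_num=2)      # average over the hardest 2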

Focal loss:
Details: the original paper, Focal Loss for Dense Object Detection
Code reference: https://zhuanlan.zhihu.com/p/80594704

def focal_loss(pred, target, gamma=0.5):
    # pt: predicted probability of the true class for each sample,
    # detached so the focal weight itself carries no gradient
    pt = pred.gather(1, target.unsqueeze(1)).squeeze(1).detach()
    # Down-weight easy samples (pt close to 1) by (1 - pt)^gamma
    focal_weight = (1 - pt).pow(gamma)
    loss = torch.nn.NLLLoss(reduction='none')(torch.log(pred), target)
    return torch.mean(loss * focal_weight)
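Usage looks the same as OHEM (dummy data again); note that gamma = 0.5 is this post's default, while the paper most often uses gamma = 2:

pred = torch.softmax(torch.randn(4, 2), dim=1)
target = torch.tensor([0, 1, 1, 0])
loss = focal_loss(pred, target, gamma=2)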

GHM loss:
Details: https://zhuanlan.zhihu.com/p/80594704
Code reference: https://github.com/DHPO/GHM_Loss.pytorch/blob/master/GHM_loss.py

import torch
import torch.nn as nn

class GHM_Loss(nn.Module):
    def __init__(self, bins, alpha):
        super(GHM_Loss, self).__init__()
        self._bins = bins    # number of gradient-density bins
        self._alpha = alpha  # momentum for the moving average of bin counts
        self._last_bin_count = None

    def _g2bin(self, g):
        # Map a gradient norm g in [0, 1] to a bin index in [0, bins - 1]
        return torch.floor(g * (self._bins - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        # Implemented by subclasses: the loss, weighted per sample
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        # Implemented by subclasses: the per-sample gradient norm g
        raise NotImplementedError

    def forward(self, x, target):
        # Per-sample gradient norm g, used as the measure of difficulty
        g = torch.abs(self._custom_loss_grad(x, target))
        bin_idx = self._g2bin(g)

        # Count how many samples fall into each bin
        bin_count = torch.zeros((self._bins,))
        for i in range(self._bins):
            bin_count[i] = (bin_idx == i).sum().item()

        N = x.size(0)

        # Moving average of bin counts across batches with momentum alpha,
        # following the referenced implementation (this is what alpha is for)
        if self._last_bin_count is None:
            self._last_bin_count = bin_count
        else:
            bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
            self._last_bin_count = bin_count

        nonempty_bins = (bin_count > 0).sum().item()
        gd = bin_count * nonempty_bins   # gradient density of each bin
        gd = torch.clamp(gd, min=0.0001)
        beta = N / gd                    # weight is inversely proportional to density
        return self._custom_loss(x, target, beta[bin_idx])
        
class GHMC_Loss(GHM_Loss):
    # GHM-C: the GHM weighting applied to a classification loss
    def __init__(self, bins, alpha):
        super(GHMC_Loss, self).__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        # Weighted NLL over log-probabilities, normalized by the total weight
        weight = weight.to(x.device).detach()
        loss = torch.nn.NLLLoss(reduction='none')(torch.log(x), target)
        return (loss * weight).sum() / weight.sum()

    def _custom_loss_grad(self, x, target):
        # Gradient norm for binary softmax outputs: g = p1 - y, where p1 is
        # the predicted probability of class 1 and y is the 0/1 label
        # (computed on CPU and detached from the graph, as in the reference code)
        x = x.detach().cpu()
        target = target.cpu()
        return x[:, 1] - target.float()
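A quick end-to-end sketch; the bins and alpha values here are illustrative, not tuned:

logits = torch.randn(8, 2, requires_grad=True)
pred = torch.softmax(logits, dim=1)
target = torch.randint(0, 2, (8,))

criterion = GHMC_Loss(bins=10, alpha=0.75)
loss = criterion(pred, target)
loss.backward()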