欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[loss] pytorch实现交叉熵损失函数及其变种

程序员文章站 2022-06-15 13:55:55
...

1. 交叉熵反映两种分布的相似程度

如下公式 pi qip_{i}\ q_{i} 为两随机变量,N为有限样本量,交叉熵通过有限样本量近似衡量两随机变量分布的相似性

cross_entropy=i=1Nqilog(pi) cross_\_entropy=\sum_{i=1}^{N}q_{i}log(p_{i})
稍微分类一下变种

  • class weighted 不同类别的样本权重不同。如平衡正负样本(balance class frequency)
  • sample weighted 不同样本权重不同。如平衡难易样本(focalloss、GHMloss)
  • hardlabel/softlabel 上述两者计算单个样本loss时只用到了概率向量中真实类别对应的概率,softlabel则用到了概率向量中所有的概率(计算单个样本loss时为不同类别赋予不同权重,区分先前组合minibatch loss时为不同类别赋予不同权重)如 label smoothing

一个minibatch的loss是由该minibatch中每个样本的loss通过sum/avg等规则计算的,对于不同的任务,样本的定义从单张图像到单个像素,但梯度下降永远用的是一个minibatch的loss

我们既要关注 单个样本loss 的计算方式(如上述softlabel)又要清楚它们如何 组合成最后minibatch的loss(如上述class/sample weight)

2. pytorch内置的weighted CELoss

分类任务

  • 二分类

    模型输出尺寸 (batchsize) 一维向量,即每个sample输出概率数
    pytorch对应 torch.nn.BCEWithLogitsLoss(weight, reduction='mean', pos_weight=None) 内置sigmoid操作

    • weight (Tensor) a Tensor of size nbatch 注意这里是每个样本给了一个权重(但依旧是按样本所属类别赋予的权重值)
    • pos_weight (Tensor, optional) a weight of positive examples. a vector of length num_classes
    • reduction(string) 对minibatch每个sample算得的loss如何处理 'none' | 'mean' | 'sum'
      显然根据minibatch的label可以由pos_weight生成
  • 多分类

  • 模型输出尺寸 (batchsize, num_classes) 即每个sample输出概率向量
    pytorch对应 torch.nn.CrossEntropyLoss(weight, ignore_index, reduction='mean')

    • weight (Tensor) a Tensor of size C
    • ignore_index (int) Specifies a target value
    • reduction (string) Specifies the reduction to apply to the output 'none' | 'mean' | 'sum'

分割任务

  • 模型输出尺寸 (batchsize, num_classes, H, W) 即每个sample输出概率向量,但sample的定义从一张图变成图上的向量
    pytorch对应 torch.nn.CrossEntropyLoss(weight, ignore_index, reduction='mean')
    • weight (Tensor) a Tensor of size C 即为length of num_classes的一维向量,为每个类别赋予不同权重
    • ignore_index (int, optional) Specifies a target value
    • reduction (string, optional): Specifies the reduction to apply to the output:
      'none' | 'mean' | 'sum'

3. 各种变种如何根据现有内置loss改动

3.1 CELoss与NLLLoss/log/softmax

考虑最后一层只做了WX+b,输出Z,尺寸为(N, num_classes)(N, C, d_1, d_2, ..., d_K)

  • softmax 既作为**函数提供nonlinearity,又使输出归一化可看做概率valid probability
  • log 代表最开始那个交叉熵公式中的第二个乘积因子log(pi)log(p_{i})\quad (第一个乘积因子qiq_{i}是外部信息)

可以看到log_softmax后,得到了所有样本的log probability(log概率向量)

  • nllloss 做了三件事:计算单个样本概率 + 组合minibatch中全部样本的loss + 取负数
    计算单个样本loss的内部实现显然就是矩阵乘法 \quad 考虑单个样本的log概率向量和label,利用最开始那个交叉熵公式计算单个样本loss(注意此时label可为softlabel,即一个样本不不仅仅考虑真实类别对应的那一个概率)

此外还要说个代码细节:nllloss支持的label尺寸(N)(N, d_1, d_2, ..., d_K) ,可以这么理解:二分类时N;多分类问题既可为N一系列标签,又可以自己手动实现one-hot后的(N, d_1, d_2, ..., d_K),其中元素0/1;还可以是softlabel形式的(N, d_1, d_2, ..., d_K),不同样本真实类别对应的label相同,同一样本其余类别的label相同

3.2 具体变种原理和代码

拆解CELoss原理为上述两步后,变种都是在做完log_softmax后介入

常规而言,分类任务要求输入标签形式为 length of batchsize vector,分割任务标签形式 tensor of (batchsize, H, W) 即pytorch内部自动帮我们把标签矩阵转为one-hot形式

  • softlabel 计算单个样本时介入
    改labels,需要自行coding构造 (N, d_1, d_2, ..., d_K)尺寸的labels矩阵

  • class weight 可在计算单个样本loss时介入,但更推荐组合minibatch loss时介入
    后者利用CELoss/NLLLoss 初始化自带weight参数,注意weight参数是长度为num_classes的权重向量(BCE的参数略有不同,见上
    前者需要手动实现"类似broadcasting" : 把对类别的权重向量变成对每个样本的权重矩阵

  • sample weight 计算单个loss或组合loss时都可
    对于分割,每个像素点的sample loss其实就是class weight

    对于focalloss的sample loss,每个样本的损失和自身输出的probability有关,所以即便是同类,不同样本的权重也一般不同,所以这个变种介入不像上述,都在nllloss计算中介入,而是直接更改了nllloss的input,把log_softmax改为带有modulating_factor(与probability相关的权重)的log_softmax

    # outs为上述最后一层输出的z
    p = F.softmax(outs, dim=1)
    modulating_factor = (1 - p) ** self.gamma
    
    loss_fn = torch.nn.NLLLoss(weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
    focal_loss = loss_fn(modulating_factor*p, labels)
    

    完整FocalLoss代码指路 pytorch实现Focalloss