Focal loss in PyTorch

import torch

# NOTE: `config` is assumed to be defined elsewhere (e.g. a project-level config
# module) providing `config.categories` (number of labels per sample) and
# `config.activ_th` (the activation threshold, 0.5 here).

def multi_label_loss(y_pred, y_true):
    '''
    Pairwise ranking loss for multi-label classification (BP-MLL).
    Zhang, M. L., & Zhou, Z. H. (2006). Multilabel neural networks with applications to functional genomics and text categorization. IEEE Transactions on Knowledge and Data Engineering, 18(10), 1338-1351.
    '''
    # Note: 0.5 is the threshold of the activation function.
    y_true = torch.reshape(y_true, (-1, config.categories))
    y_pred = torch.reshape(y_pred, (-1, config.categories))
    
    ml_loss = []
    for i in range(y_true.shape[0]):
        y_true_i, y_pred_i = y_true[i], y_pred[i]
        
        # Split the predictions into positive (label 1) and negative (label 0) scores.
        ones_mask = (y_true_i == 1.)
        zero_mask = (y_true_i == 0.)
        ones, zeros = y_pred_i[ones_mask], y_pred_i[zero_mask]

        if ones.shape[0] == 0:
            # No positive labels in this sample: anchor the positive side at the
            # activation threshold (0.5), so the loss only pushes the negative
            # predictions below the threshold.
            ones = config.activ_th * torch.ones(1, device=y_pred.device)
        elif zeros.shape[0] == 0:
            # No negative labels: anchor the negative side at the threshold, so
            # the loss only pushes the positive predictions above it.
            zeros = config.activ_th * torch.ones(1, device=y_pred.device)
        # Enumerate every (positive, negative) score pair via broadcasting.
        p_repeat = ones.unsqueeze(1).expand(ones.size(0), zeros.size(0)).reshape((-1, 1))
        n_repeat = zeros.unsqueeze(0).expand(ones.size(0), zeros.size(0)).reshape((-1, 1))
        p_n_pairs = torch.cat((p_repeat, n_repeat), 1)

        # exp(neg - pos): large when a negative score approaches a positive one;
        # normalise by the number of pairs.
        ml_loss_i = torch.exp(-p_n_pairs[:, 0] + p_n_pairs[:, 1])
        ml_loss_i = torch.div(torch.sum(ml_loss_i), 1.0 * ones.shape[0] * zeros.shape[0])

        ml_loss.append(ml_loss_i.unsqueeze(0))
        
    # Average the per-sample losses over the batch.
    ml_loss = torch.mean(torch.cat(ml_loss, 0))
    return ml_loss
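A minimal smoke test, with `config` stubbed out via `types.SimpleNamespace` (the values 4 and 0.5 are placeholders for illustration; the real project presumably imports its own config module):

import types
config = types.SimpleNamespace(categories=4, activ_th=0.5)  # hypothetical stub

y_pred = torch.sigmoid(torch.randn(2, 4))       # sigmoid scores in (0, 1)
y_true = torch.tensor([[1., 0., 1., 0.],
                       [0., 0., 0., 0.]])       # second sample has no positives
print(multi_label_loss(y_pred, y_true))         # scalar tensor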

def binary_focal_loss(y_pred, y_true, gamma=2., alpha=.25):
    """
    Binary focal loss, extended from binary classification to multi-label.
    Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE International Conference on Computer Vision. 2017.

        FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
        where p = sigmoid(x), and p_t = p or 1 - p depending on whether the
        label is 1 or 0, respectively.

    :param y_true: a 1-D tensor of the same shape as `y_pred`
    :param y_pred: a 1-D tensor resulting from a sigmoid
    :return: output tensor (summed loss).
    """
    # pt_1 keeps the prediction where the label is 1 and is 1 elsewhere (zero
    # loss contribution); pt_0 keeps the prediction where the label is 0 and is
    # 0 elsewhere (~zero contribution after the pow/log terms).
    ones = torch.ones_like(y_pred, dtype=torch.float)
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    pt_1 = torch.where(y_true == 1, y_pred, ones)
    pt_0 = torch.where(y_true == 0, y_pred, zeros)

    # Clamp away from 0 and 1 to avoid log(0).
    epsilon = 1e-10
    pt_1 = torch.clamp(pt_1, epsilon, 1. - epsilon)
    pt_0 = torch.clamp(pt_0, epsilon, 1. - epsilon)

    bi_fl = -torch.sum(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1)) \
            - torch.sum((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
    # Returning the sum effectively scales alpha by the batch size; returning
    # the mean instead would effectively scale it by 1/n_class.
    return bi_fl
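
A quick usage sketch on random data (the 16x8 shape is illustrative, not from the original post):

logits = torch.randn(16 * 8)                    # 16 samples x 8 binary labels, flattened
y_pred = torch.sigmoid(logits)
y_true = torch.randint(0, 2, (16 * 8,)).float()
print(binary_focal_loss(y_pred, y_true))        # scalar tensor (sum over all terms)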
	
def categorical_focal_loss(y_pred, y_true, gamma=2., weight=None):
    """
    Focal loss for single-label (multi-class) classification.
    :param y_true: a tensor of class indices, e.g. [1, 2, 5, 4, 0, ...], shape = [n_batch]
    :param y_pred: a tensor resulting from a softmax, shape = [n_batch, n_class]
    :param weight: optional per-class weights, broadcastable to [n_batch, n_class];
                   each cross-entropy term is multiplied by its weight before averaging
    :return: mean loss over the batch.
    """
    # Convert y_true from class indices to a one-hot tensor of shape [n_batch, n_class].
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    index = (torch.arange(y_true.size(0), device=y_true.device), y_true)
    y_true = zeros.index_put_(indices=index,
                              values=torch.ones_like(y_true, dtype=torch.float))

    # Clamp away from 0 and 1 to avoid log(0).
    epsilon = 1e-10
    y_pred = torch.clamp(y_pred, epsilon, 1. - epsilon)
    # Calculate Cross Entropy
    cross_entropy = -y_true * torch.log(y_pred)
    # Calculate Focal Loss
    loss = torch.pow(1 - y_pred, gamma) * cross_entropy
    if weight is not None:
        loss = weight * loss
    # Mean over all [n_batch, n_class] entries; only the true-class entries are non-zero.
    return torch.mean(loss)
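
And a short usage sketch, assuming softmax outputs and integer labels (4 samples and 5 classes are arbitrary choices for illustration):

logits = torch.randn(4, 5)
y_pred = torch.softmax(logits, dim=1)
y_true = torch.tensor([1, 2, 4, 0])
print(categorical_focal_loss(y_pred, y_true))                  # unweighted
class_w = torch.tensor([1., 2., 1., 1., 0.5])                  # hypothetical class weights
print(categorical_focal_loss(y_pred, y_true, weight=class_w))  # weighted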