Focal loss in PyTorch
import torch

# `config` is the external configuration module used throughout this snippet;
# it is expected to provide `categories` (the number of labels) and
# `activ_th` (the activation threshold, 0.5 here).

def multi_label_loss(y_pred, y_true):
    '''
    Pairwise ranking loss (BP-MLL) from:
    Zhang, M. L., & Zhou, Z. H. (2006). Multilabel neural networks with
    applications to functional genomics and text categorization. IEEE
    Transactions on Knowledge and Data Engineering, 18(10), 1338-1351.
    Per sample: E_i = (1 / (|Y_i| * |Y_i^c|)) * sum_{(k,l) in Y_i x Y_i^c} exp(-(c_k - c_l)),
    where Y_i is the set of positive labels and Y_i^c the set of negative labels.
    '''
    # Note: 0.5 is the threshold of the activation function.
    y_true = torch.reshape(y_true, (-1, config.categories))
    y_pred = torch.reshape(y_pred, (-1, config.categories))
    ml_loss = []
    for i in range(y_true.shape[0]):
        y_true_i, y_pred_i = y_true[i], y_pred[i]
        ones_mask = (y_true_i == torch.tensor(1.))
        zero_mask = (y_true_i == torch.tensor(0.))
        # ones_mask, zero_mask = (y_true_i == 1), (y_true_i == 0)
        ones, zeros = y_pred_i[ones_mask], y_pred_i[zero_mask]
        if ones.shape[0] == 0:
            # The ground truth has no 1s (all labels are 0): with
            # ones.shape[0] == 0, zeros.shape[0] == 8 and an activation
            # threshold of 0.5, the negative predictions only need to fall
            # below 0.5, so substitute the threshold as the positive score.
            ones = config.activ_th * torch.ones(1, requires_grad=True).cuda()
        elif zeros.shape[0] == 0:
            # zeros.shape[0] == 0, ones.shape[0] == 8: with an activation
            # threshold of 0.5, the positive predictions only need to exceed
            # 0.5, so substitute the threshold as the negative score.
            zeros = config.activ_th * torch.ones(1, requires_grad=True).cuda()
        # print('ones, zeros', ones.requires_grad, zeros.requires_grad)
        # Pair every positive score with every negative score.
        p_repeat = ones.unsqueeze(1).expand(ones.size()[0], zeros.size()[0]).reshape((-1, 1))
        n_repeat = zeros.unsqueeze(0).expand(ones.size()[0], zeros.size()[0]).reshape((-1, 1))
        p_n_pairs = torch.cat((p_repeat, n_repeat), 1)
        ml_loss_i = torch.exp(-p_n_pairs[:, 0] + p_n_pairs[:, 1])
        ml_loss_i = torch.div(torch.sum(ml_loss_i), 1.0 * ones.shape[0] * zeros.shape[0])
        ml_loss.append(ml_loss_i.unsqueeze(0))
    ml_loss = torch.cat(ml_loss, 0)
    ml_loss = torch.mean(ml_loss)
    return ml_loss
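A minimal smoke test, sketched under assumptions: the `config` class below is a hypothetical stand-in for the external config module above, the labels are built so every sample has both positive and negative entries (the fallback branches call `.cuda()` and would need a GPU), and all shapes and values are illustrative.

class config:                 # hypothetical stand-in for the external config module
    categories = 8            # number of labels (assumed)
    activ_th = 0.5            # activation threshold (assumed)

y_pred = torch.sigmoid(torch.randn(4, config.categories, requires_grad=True))
y_true = torch.zeros(4, config.categories)
y_true[:, :3] = 1.            # every sample gets three positive labels
loss = multi_label_loss(y_pred, y_true)
loss.backward()               # scalar loss; gradients flow through y_pred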
def binary_focal_loss(y_pred, y_true, gamma=2., alpha=.25):
    """
    Extends the binary case to multi-label classification.
    Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings
    of the IEEE International Conference on Computer Vision. 2017.
    Binary form of focal loss:
        FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
    where p = sigmoid(x), and p_t = p or 1 - p depending on whether the label
    is 1 or 0, respectively.
    :param y_true: a 1-D tensor of the same shape as `y_pred`
    :param y_pred: a 1-D tensor resulting from a sigmoid
    :return: output tensor.
    """
    # print('y_true, y_pred:', y_true.requires_grad, y_pred.requires_grad)  # False, True
    ones = torch.ones_like(y_pred, dtype=torch.float)
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    # pt_1 keeps p where the label is 1 (1 elsewhere, contributing zero loss);
    # pt_0 keeps p where the label is 0 (0 elsewhere, contributing zero loss).
    pt_1 = torch.where(y_true == 1, y_pred, ones)
    pt_0 = torch.where(y_true == 0, y_pred, zeros)
    epsilon = 1e-10
    pt_1 = torch.clamp(pt_1, epsilon, 1. - epsilon)
    pt_0 = torch.clamp(pt_0, epsilon, 1. - epsilon)
    bi_fl = -torch.sum(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1)) \
            - torch.sum((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
    # Returning the sum amounts to alpha = n_batchsize; returning the mean
    # amounts to alpha = 1 / n_class.
    return bi_fl
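A short usage sketch for the multi-label case; the `logits` and `targets` names, shapes, and values here are illustrative only:

logits = torch.randn(4, 8, requires_grad=True)   # illustrative batch of logits
targets = (torch.rand(4, 8) > 0.5).float()       # random multi-hot labels
probs = torch.sigmoid(logits)                    # y_pred must be sigmoid outputs
loss = binary_focal_loss(probs.view(-1), targets.view(-1))
loss.backward()                                  # sums over all label positions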
def categorical_focal_loss(y_pred, y_true, gamma=2., weight=None):
    """
    For single-label (multi-class) classification.
    :param y_true: index list, e.g. [1, 2, 5, 4, 0, ...], shape = [n_batch]
    :param y_pred: a tensor resulting from a softmax, shape = [n_batch, n_class]
    :return: loss of a batch.
    if weight: loss = sum([w1*CE_1, w2*CE_2, ...]) / (w1 + w2 + ...)
    """
    # Convert y_true into a one-hot tensor of shape [n_batch, n_class].
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    index = (torch.arange(y_true.size()[0], device=y_pred.device), y_true)  # row indices on y_pred's device
    y_true = zeros.index_put_(indices=index,
                              values=torch.ones_like(y_true, dtype=torch.float))
    epsilon = 1e-10
    y_pred = torch.clamp(y_pred, epsilon, 1. - epsilon)
    # Calculate the cross entropy.
    cross_entropy = -y_true * torch.log(y_pred)
    # Calculate the focal loss.
    loss = torch.pow(1 - y_pred, gamma) * cross_entropy
    if weight is not None:
        loss = weight * loss
    # Average the losses over the mini-batch (torch.sum would also work).
    return torch.mean(loss)
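A minimal usage sketch for the single-label case (names and shapes are illustrative): `y_true` must hold integer class indices and `y_pred` must already be softmax probabilities.

logits = torch.randn(4, 10, requires_grad=True)  # illustrative batch, 10 classes
labels = torch.randint(0, 10, (4,))              # class indices, shape [n_batch]
probs = torch.softmax(logits, dim=1)
loss = categorical_focal_loss(probs, labels)     # weight=None: unweighted focal loss
loss.backward()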