欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

损失函数 DiceLoss 的 Pytorch、TensorFlow 实现

程序员文章站 2022-05-04 12:36:32
Pytorch 实现def dice_loss(preds, targets): """ preds: tensor of shape (N, C) targets: tensor of shape (N, C) """ assert preds.shape == targets.shape preds = preds.float() targets = targets.float() numer...

二分类的 DiceLoss 损失函数

二分类 Dice 系数计算

假设模型输出的预测值 preds 经过 sigmoid 后,得到 logits 如下所示
损失函数 DiceLoss 的 Pytorch、TensorFlow 实现
该 logits 对应的标签 label 如下所示,0 表示不属于某一类,1 表示属于某一类:
损失函数 DiceLoss 的 Pytorch、TensorFlow 实现
根据 DiceLoss 系数的定义有:
XY=[0.53220.49320.17640.31070.52970.16040.38410.35370.35740.33230.83010.6436][000000111111]=[0.00000.00000.00000.00000.00000.00000.38410.35370.35740.33230.83010.6436]2.9012() \begin{aligned} |X \cap Y| &=\begin{bmatrix} 0.5322&0.4932&0.1764\\ 0.3107&0.5297&0.1604\\ 0.3841&0.3537&0.3574\\ 0.3323&0.8301&0.6436 \end{bmatrix} \star \begin{bmatrix} 0&0&0\\ 0&0&0\\ 1&1&1\\ 1&1&1 \end{bmatrix} \\&= \begin{bmatrix} 0.0000&0.0000&0.0000\\ 0.0000&0.0000&0.0000\\ 0.3841&0.3537&0.3574\\ 0.3323&0.8301&0.6436 \end{bmatrix} \rightarrow 2.9012 (求和) \end{aligned}

X=[0.53220.49320.17640.31070.52970.16040.38410.35370.35740.33230.83010.6436]5.1038 |X| = \begin{bmatrix} 0.5322&0.4932&0.1764\\ 0.3107&0.5297&0.1604\\ 0.3841&0.3537&0.3574\\ 0.3323&0.8301&0.6436 \end{bmatrix} \rightarrow 5.1038

Y=[000000111111]8 |Y| = \begin{bmatrix} 0&0&0\\ 0&0&0\\ 1&1&1\\ 1&1&1 \end{bmatrix} \rightarrow 8

所以 Dice 系数为
D=2XY+1X+Y+1=22.9012+15.1038+81=0.5901 D = \frac{2 * |X\cap Y| +1}{|X| + |Y | + 1} = \frac{2 * 2.9012 + 1}{ 5.1038 + 8+1}=0.5901

所以 Dice 损失 L=1D=0.4099L = 1-D=0.4099

这是二分类一个批次只有一张图的情况,当一个批次有 NN 张图片时,可以将图片压缩为一维向量,如下所示:
损失函数 DiceLoss 的 Pytorch、TensorFlow 实现
对应的 label 也做相应的变换,最后一起计算 NN 张图片的 Dice 系数 和 Loss。

上面这个过程的 pytorch 代码实现如下所示;

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Model):
	def __init__(self):
		super(BinaryDiceLoss, self).__init__()
	
	def forward(self, input, targets):
		# 获取每个批次的大小 N
		N = targets.size()[0]
		# 平滑变量
		smooth = 1
		# 将宽高 reshape 到同一纬度
		input_flat = input.view(N, -1)
		targets_flat = targets.view(N, -1)
	
		# 计算交集
		intersection = input_flat * targets_flat 
		N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
		# 计算一个批次中平均每张图的损失
		loss = 1 - dice_eff.sum() / N
		return loss

多分类 DiceLoss 损失函数

当有多个分类时,label 通过 one hot 转化为多个二分类,如下图所示:
损失函数 DiceLoss 的 Pytorch、TensorFlow 实现
每个channel 切面,可以看作是一个二分类问题,所以多分类 DiceLoss 损失函数,可以通过计算每个类别的二分类 DiceLoss 损失,最后再求均值得到。pytorch 代码如下所示:

import torch
import torch.nn as nn

class MultiClassDiceLoss(nn.Module):
	def __init__(self, weight=None, ignore_index=None, **kwargs):
		super(MultiClassDiceLoss, self).__init__()
		self.weight = weight
		self.ignore_index = ignore_index
		self.kwargs = kwargs
	
	def forward(self, input, target):
		"""
			input tesor of shape = (N, C, H, W)
			target tensor of shape = (N, C, H, W)
		"""
		assert input.shape == target.shape, "predict & target shape do not match"
		
		binaryDiceLoss = BinaryDiceLoss()
		total_loss = 0
		
		# 归一化输出
		logits = F.softmax(input, dim=1)
		C = target.shape[1]
		
		# 遍历 channel,得到每个类别的二分类 DiceLoss
		for i in range(C):
			dice_loss = binaryDiceLoss(logits[:, i], target[:, i])
			total_loss += dice_loss
		
		# 每个类别的平均 dice_loss
		return total_loss / C

本文地址:https://blog.csdn.net/liangjiu2009/article/details/107352164

上一篇: CSS3总结

下一篇: javascript闭包详解