欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【Pytorch】语义分割、医学图像分割one-hot的两种实现方式

程序员文章站 2022-07-05 11:06:33
...

简要说明one-hot编码

这里只做简单解释,详细说明请查看相关资料。

假设我们有3个分类标签[1, 2, 3],one-hot则将其编码为:

  • 1 -> 1 0 0
  • 2 -> 0 1 0
  • 3 -> 0 0 1

即有多少个分类标签,就要用多少位来进行one-hot编码,这里的位数通常也就对应语义分割中网络模型最后输出的特征图(即预测图)的通道数(num_classes)。

语义分割中one-hot一般用于mask(或gt),这取决于在计算损失的时候用什么损失函数。

例如,在pytorch中,假设训练的时候用的是CrossEntropyLoss,则不用对mask进行手动one-hot编码,CrossEntropyLoss底层会自动对传入的mask进行one-hot编码。
若用的是BCELoss,就需要对mask进行手动one-hot编码。

下面进入正题。使用的数据样本描述:(灰色128表示膀胱内外壁,白色255表示肿瘤,黑色0表示背景)。
【Pytorch】语义分割、医学图像分割one-hot的两种实现方式

one-hot第一种实现

def mask_to_onehot(mask, palette):
    """
    Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
    hot encoding vector, C is usually 1 or 3, and K is the number of class.
    """
    semantic_map = []
    for colour in palette:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map

def onehot_to_mask(mask, palette):
    """
    Converts a mask (H, W, K) to (H, W, C)
    """
    x = np.argmax(mask, axis=-1)
    colour_codes = np.array(palette)
    x = np.uint8(colour_codes[x.astype(np.uint8)])
    return x

one-hot第二种实现

def mask2onehot(mask, num_classes):
    """
    Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
    hot encoding vector

    """
    _mask = [mask == i for i in range(num_classes)]
    return np.array(_mask).astype(np.uint8)

def onehot2mask(mask):
    """
    Converts a mask (K, H, W) to (H,W)
    """
    _mask = np.argmax(mask, axis=0).astype(np.uint8)
    return _mask

下面演示两种方法各自的使用方法。

from PIL import Image
import cv2
np.set_printoptions(threshold=9999999)


'''color map:膀胱内外壁灰色128,肿瘤白色255, 背景黑色0
bladder wall: 128
tumor: 255
background: 0
'''

mask = Image.open('D:\\Learning\\datasets\\基于磁共振成像的膀胱内外壁分割与肿瘤检测\\Labels\\18-IM351.png')
mask = np.array(mask)
mask_copy = mask.copy()

'''第一种onehot方式'''
# 不同颜色映射的标签
label2trainid = {128: 0, 255: 1, 0: 2}

# color to trainid
for k, v in label2trainid.items():
    mask_copy[mask == k] = v
mask_with_trainid = mask_copy.astype(np.uint8)
# one-hot 编码结果
mask_onehot1 = mask2onehot(mask_with_trainid, 3)  # shape = (K, H, W)
# 恢复成原来的mask
mask_with_trainid = onehot2mask(mask_onehot1)
mask_recover1 = mask_with_trainid.copy()  # shape = (H, W)
for k, v in label2trainid.items():
    mask_recover1[mask_with_trainid == v] = k

'''第二种onehot方式'''
# 调色板
palette = [[128], [255], [0]]
# 扩展最后一维,用来做one-hot映射
mask_onehot2 = np.expand_dims(mask, axis=2)# shape = (H, W) -> (H, W, 1) 
# one-hot 编码结果
mask_onehot2 = mask_to_onehot(mask_onehot2, palette)  # shape = (H, W, K)
# 恢复成原来的mask, 逆向one-hot映射并压缩最后一维
mask_recover2 = onehot_to_mask(mask_onehot2, palette).squeeze()  # shape = (H, W)

# 判断两种onehot的结果是否一样
print(np.equal(mask_onehot1.transpose([1, 2, 0]), mask_onehot2).all())  # True
# 判断恢复两种onehot的结果是否一样
print(np.equal(mask_recover1, mask_recover2).all())  # True
cv2.imshow('onehot1, onehot2', np.hstack([mask_recover1, mask_recover2]))
cv2.waitKey(0)

代码注释部分已经很详细了,就不做过多的说明。

one-hot的效果

下面用一个例子描述一下one-hot的效果。
为了简单起见,假设我们有一个 2 × 2 大小的mask图(上面用到的样本数据的简化版),里面的值如下:

mask = np.array([[128, 128],
                 [255, 0]])
# mask
[[128, 128],
 [255, 0  ]]    

可见,总共有4个值,里面有两个值属于bladder,一个值属于tumor, 一个值属于background。
现在我们对数据最后一维进行扩展,由 2 × 2 变为2 × 2 × 1,方便做one-hot。

mask = np.expand_dims(mask, axis=2)

# mask
[[[128]
  [128]]

 [[255]
  [  0]]]

再对mask做one-hot映射之后,数据里面的值就会变成下面这样:

palette = [[128], [255], [0]]
mask = mask_to_onehot(mask, palette)

# mask
[[[1. 0. 0.]
  [1. 0. 0.]]

 [[0. 1. 0.]
  [0. 0. 1.]]]

分析:one-hot其实就是将原始数据中一个像素值映射成了一个n维的向量(n=分类数),按照我们写的调色板palette = [[128], [255], [0]]里面的值的顺序:

  • 128对应第1类,则对应的one-hot映射就为 1 0 0
  • 255 对应第2类,则对应的one-hot映射就为 0 1 0
  • 0对应第3类,则对应的one-hot映射就为 0 0 1

one-hot编码之后,数据由 2 × 2 × 1 变成了 了 2 × 2 × 3, 在pytorch中训练时,一般还需再对数据进行转置,由2 × 2 × 3 变为 3 × 2 × 2, 最后在训练任务中同时加载多张图的时候, 维度就会变成N × 3 × 2 × 2,即 N × C × H × W的形式(这里的Nbatch_size的大小)。分割任务中网络模型最后的输出也是N × C × H × W的形式。
此时,可以直接用 BCELoss 或者 DiceLoss进行损失计算。应为这两种损失函数要求接收的InputTarget的维度要一致。但如果此时我们要用CrossEntropyLoss进行训练的话怎么办呢,CrossEntropyLoss要求接收的Target的维度为N × H × W,比Input 的维度少一位,并且要求Target的数据类型为Long,我觉得是因为CrossEntropyLoss要求Target为标签形式而不是one-hot形式。

我们只需要对上面得到的one-hot形式的数据求argmax即可, 使用argmax 需要指定维度,这里维度指定C所在的维度即可,上面经过one-hot编码后的数据维度为 2 × 2 × 3C所在的维度为2,则代码这样写:(补充说明:对于维度为N × C × H × W的形式,可见C所在的维度为1, 则代码里面的 axis/dim = 1 )

mask = np.argmax(mask, axis=2)
# 若 mask 为 tensor 则用 torch.argmax(mask, dim=2)
# mask
[[0 0]
 [1 2]]

这样就又得到数据的标签了, 数据维度由2 × 2 × 3 变回了2 × 2,此时就可以用CrossEntropyLoss进行训练了。

  • 128对应第1类,则对应标签为0
  • 255对应第2类,则对应标签为1
  • 0对应第3类,则对应标签为2

至此,要说的都讲完啦~

相关标签: 图像分割