【Pytorch】语义分割、医学图像分割one-hot的两种实现方式
简要说明one-hot编码
这里只做简单解释,详细说明请查看相关资料。
假设我们有3个分类标签[1, 2, 3]
,one-hot则将其编码为:
-
1
->1 0 0
-
2
->0 1 0
-
3
->0 0 1
即有多少个分类标签,就要用多少位来进行one-hot编码,这里的位数通常也就对应语义分割中网络模型最后输出的特征图(即预测图)的通道数(num_classes)。
语义分割中one-hot一般用于mask(或gt),这取决于在计算损失的时候用什么损失函数。
例如,在pytorch中,假设训练的时候用的是CrossEntropyLoss
,则不用对mask进行手动one-hot编码,CrossEntropyLoss
底层会自动对传入的mask进行one-hot编码。
若用的是BCELoss
,就需要对mask进行手动one-hot编码。
下面进入正题。使用的数据样本描述:(灰色128
表示膀胱内外壁,白色255
表示肿瘤,黑色0
表示背景)。
one-hot第一种实现
def mask_to_onehot(mask, palette):
"""
Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
hot encoding vector, C is usually 1 or 3, and K is the number of class.
"""
semantic_map = []
for colour in palette:
equality = np.equal(mask, colour)
class_map = np.all(equality, axis=-1)
semantic_map.append(class_map)
semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
return semantic_map
def onehot_to_mask(mask, palette):
"""
Converts a mask (H, W, K) to (H, W, C)
"""
x = np.argmax(mask, axis=-1)
colour_codes = np.array(palette)
x = np.uint8(colour_codes[x.astype(np.uint8)])
return x
one-hot第二种实现
def mask2onehot(mask, num_classes):
"""
Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
hot encoding vector
"""
_mask = [mask == i for i in range(num_classes)]
return np.array(_mask).astype(np.uint8)
def onehot2mask(mask):
"""
Converts a mask (K, H, W) to (H,W)
"""
_mask = np.argmax(mask, axis=0).astype(np.uint8)
return _mask
下面演示两种方法各自的使用方法。
from PIL import Image
import cv2
np.set_printoptions(threshold=9999999)
'''color map:膀胱内外壁灰色128,肿瘤白色255, 背景黑色0
bladder wall: 128
tumor: 255
background: 0
'''
mask = Image.open('D:\\Learning\\datasets\\基于磁共振成像的膀胱内外壁分割与肿瘤检测\\Labels\\18-IM351.png')
mask = np.array(mask)
mask_copy = mask.copy()
'''第一种onehot方式'''
# 不同颜色映射的标签
label2trainid = {128: 0, 255: 1, 0: 2}
# color to trainid
for k, v in label2trainid.items():
mask_copy[mask == k] = v
mask_with_trainid = mask_copy.astype(np.uint8)
# one-hot 编码结果
mask_onehot1 = mask2onehot(mask_with_trainid, 3) # shape = (K, H, W)
# 恢复成原来的mask
mask_with_trainid = onehot2mask(mask_onehot1)
mask_recover1 = mask_with_trainid.copy() # shape = (H, W)
for k, v in label2trainid.items():
mask_recover1[mask_with_trainid == v] = k
'''第二种onehot方式'''
# 调色板
palette = [[128], [255], [0]]
# 扩展最后一维,用来做one-hot映射
mask_onehot2 = np.expand_dims(mask, axis=2)# shape = (H, W) -> (H, W, 1)
# one-hot 编码结果
mask_onehot2 = mask_to_onehot(mask_onehot2, palette) # shape = (H, W, K)
# 恢复成原来的mask, 逆向one-hot映射并压缩最后一维
mask_recover2 = onehot_to_mask(mask_onehot2, palette).squeeze() # shape = (H, W)
# 判断两种onehot的结果是否一样
print(np.equal(mask_onehot1.transpose([1, 2, 0]), mask_onehot2).all()) # True
# 判断恢复两种onehot的结果是否一样
print(np.equal(mask_recover1, mask_recover2).all()) # True
cv2.imshow('onehot1, onehot2', np.hstack([mask_recover1, mask_recover2]))
cv2.waitKey(0)
代码注释部分已经很详细了,就不做过多的说明。
one-hot的效果
下面用一个例子描述一下one-hot的效果。
为了简单起见,假设我们有一个 2 × 2
大小的mask图(上面用到的样本数据的简化版),里面的值如下:
mask = np.array([[128, 128],
[255, 0]])
# mask
[[128, 128],
[255, 0 ]]
可见,总共有4个值,里面有两个值属于bladder,一个值属于tumor, 一个值属于background。
现在我们对数据最后一维进行扩展,由 2 × 2
变为2 × 2 × 1
,方便做one-hot。
mask = np.expand_dims(mask, axis=2)
# mask
[[[128]
[128]]
[[255]
[ 0]]]
再对mask做one-hot映射之后,数据里面的值就会变成下面这样:
palette = [[128], [255], [0]]
mask = mask_to_onehot(mask, palette)
# mask
[[[1. 0. 0.]
[1. 0. 0.]]
[[0. 1. 0.]
[0. 0. 1.]]]
分析:one-hot其实就是将原始数据中一个像素值映射成了一个n维的向量(n=分类数),按照我们写的调色板palette = [[128], [255], [0]]
里面的值的顺序:
-
128
对应第1
类,则对应的one-hot映射就为1 0 0
-
255
对应第2
类,则对应的one-hot映射就为0 1 0
-
0
对应第3
类,则对应的one-hot映射就为0 0 1
one-hot编码之后,数据由 2 × 2 × 1
变成了 了 2 × 2 × 3
, 在pytorch中训练时,一般还需再对数据进行转置,由2 × 2 × 3
变为 3 × 2 × 2
, 最后在训练任务中同时加载多张图的时候, 维度就会变成N × 3 × 2 × 2
,即 N × C × H × W
的形式(这里的N
指 batch_size
的大小)。分割任务中网络模型最后的输出也是N × C × H × W
的形式。
此时,可以直接用 BCELoss
或者 DiceLoss
进行损失计算。应为这两种损失函数要求接收的Input
和Target
的维度要一致。但如果此时我们要用CrossEntropyLoss
进行训练的话怎么办呢,CrossEntropyLoss
要求接收的Target
的维度为N × H × W
,比Input
的维度少一位,并且要求Target
的数据类型为Long
,我觉得是因为CrossEntropyLoss
要求Target
为标签形式而不是one-hot形式。
我们只需要对上面得到的one-hot形式的数据求argmax
即可, 使用argmax
需要指定维度,这里维度指定C
所在的维度即可,上面经过one-hot编码后的数据维度为 2 × 2 × 3
,C
所在的维度为2,则代码这样写:(补充说明:对于维度为N × C × H × W
的形式,可见C
所在的维度为1, 则代码里面的 axis/dim = 1
)
mask = np.argmax(mask, axis=2)
# 若 mask 为 tensor 则用 torch.argmax(mask, dim=2)
# mask
[[0 0]
[1 2]]
这样就又得到数据的标签了, 数据维度由2 × 2 × 3
变回了2 × 2
,此时就可以用CrossEntropyLoss
进行训练了。
-
128
对应第1
类,则对应标签为0
-
255
对应第2
类,则对应标签为1
-
0
对应第3
类,则对应标签为2
至此,要说的都讲完啦~