MIDL 2019——Boundary loss代码
程序员文章站
2022-03-04 20:10:22
...
会议MIDL简介
8 - 10 July 2019
全名International Conference on Medical Imaging with Deep Learning,会议主题是医学影像+深度学习。
Boundary loss由Boundary loss for highly unbalanced segmentation这篇文章提出,用于图像分割loss,作者的实验结果表明dice loss+Boundary loss效果非常好,一个是利用区域,一个利用边界。作者对这两个loss的用法是给他们一个权重,训练初期dice loss很高,随着训练进行,Boundary loss比例增加,也就是说越到训练后期越关注边界的准确,边界处理得更细一些。
对这篇文章更具体的介绍看以下文章:一票难求的MIDL 2019 Day 1-Boundary loss
这里我主要把作者开源的代码中的Boundary loss部分拿出来,并介绍如何使用,以二分类为例。
import torch
import numpy as np
from torch import einsum
from torch import Tensor
from scipy.ndimage import distance_transform_edt as distance
from scipy.spatial.distance import directed_hausdorff
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union
# switch between representations
def probs2class(probs: Tensor) -> Tensor:
b, _, w, h = probs.shape # type: Tuple[int, int, int, int]
assert simplex(probs)
res = probs.argmax(dim=1)
assert res.shape == (b, w, h)
return res
def probs2one_hot(probs: Tensor) -> Tensor:
_, C, _, _ = probs.shape
assert simplex(probs)
res = class2one_hot(probs2class(probs), C)
assert res.shape == probs.shape
assert one_hot(res)
return res
def class2one_hot(seg: Tensor, C: int) -> Tensor:
if len(seg.shape) == 2: # Only w, h, used by the dataloader
seg = seg.unsqueeze(dim=0)
assert sset(seg, list(range(C)))
b, w, h = seg.shape # type: Tuple[int, int, int]
res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
assert res.shape == (b, C, w, h)
assert one_hot(res)
return res
def one_hot2dist(seg: np.ndarray) -> np.ndarray:
assert one_hot(torch.Tensor(seg), axis=0)
C: int = len(seg)
res = np.zeros_like(seg)
for c in range(C):
posmask = seg[c].astype(np.bool)
if posmask.any():
negmask = ~posmask
# print('negmask:', negmask)
# print('distance(negmask):', distance(negmask))
res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
# print('res[c]', res[c])
return res
def simplex(t: Tensor, axis=1) -> bool:
_sum = t.sum(axis).type(torch.float32)
_ones = torch.ones_like(_sum, dtype=torch.float32)
return torch.allclose(_sum, _ones)
def one_hot(t: Tensor, axis=1) -> bool:
return simplex(t, axis) and sset(t, [0, 1])
# Assert utils
def uniq(a: Tensor) -> Set:
return set(torch.unique(a.cpu()).numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
return uniq(a).issubset(sub)
class SurfaceLoss():
def __init__(self):
# Self.idc is used to filter out some classes of the target mask. Use fancy indexing
self.idc: List[int] = [1] #这里忽略背景类 https://github.com/LIVIAETS/surface-loss/issues/3
# probs: bcwh, dist_maps: bcwh
def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
assert simplex(probs)
assert not one_hot(dist_maps)
pc = probs[:, self.idc, ...].type(torch.float32)
dc = dist_maps[:, self.idc, ...].type(torch.float32)
print('pc', pc)
print('dc', dc)
multipled = einsum("bcwh,bcwh->bcwh", pc, dc)
loss = multipled.mean()
return loss
if __name__ == "__main__":
data = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]]])
data2 = class2one_hot(data, 2)
# print(data2)
data2 = data2[0].numpy()
data3 = one_hot2dist(data2) #bcwh
# print(data3)
print("data3.shape:", data3.shape)
logits = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]]])
logits = class2one_hot(logits, 2)
Loss = SurfaceLoss()
data3 = torch.tensor(data3).unsqueeze(0)
res = Loss(logits, data3, None)
print('loss:', res)
输出结果:
loss: tensor(0.2143)
如果prediction和label一致,loss为0。如果prediction比label小并被label包围,loss为负。
其中label计算距离图,即
data2 = class2one_hot(data, 2)
data2 = data2[0].numpy()
data3 = one_hot2dist(data2) #bcwh
这几步,可以放到读取数据集,做出label之后。
上一篇: Oracle 配置远程访问教程
下一篇: 难道你上班只是为了钱么