欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

MIDL 2019——Boundary loss代码

程序员文章站 2022-03-04 20:10:22
...

会议MIDL简介
8 - 10 July 2019

全名International Conference on Medical Imaging with Deep Learning,会议主题是医学影像+深度学习。

Boundary loss由Boundary loss for highly unbalanced segmentation这篇文章提出,用于图像分割loss,作者的实验结果表明dice loss+Boundary loss效果非常好,一个是利用区域,一个利用边界。作者对这两个loss的用法是给他们一个权重,训练初期dice loss很高,随着训练进行,Boundary loss比例增加,也就是说越到训练后期越关注边界的准确,边界处理得更细一些。

对这篇文章更具体的介绍看以下文章:一票难求的MIDL 2019 Day 1-Boundary loss

这里我主要把作者开源的代码中的Boundary loss部分拿出来,并介绍如何使用,以二分类为例。

import torch
import numpy as np
from torch import einsum
from torch import Tensor
from scipy.ndimage import distance_transform_edt as distance
from scipy.spatial.distance import directed_hausdorff

from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union

# switch between representations
def probs2class(probs: Tensor) -> Tensor:
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    assert simplex(probs)

    res = probs.argmax(dim=1)
    assert res.shape == (b, w, h)

    return res

def probs2one_hot(probs: Tensor) -> Tensor:
    _, C, _, _ = probs.shape
    assert simplex(probs)

    res = class2one_hot(probs2class(probs), C)
    assert res.shape == probs.shape
    assert one_hot(res)

    return res


def class2one_hot(seg: Tensor, C: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
    assert res.shape == (b, C, w, h)
    assert one_hot(res)

    return res


def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    res = np.zeros_like(seg)
    for c in range(C):
        posmask = seg[c].astype(np.bool)

        if posmask.any():
            negmask = ~posmask
            # print('negmask:', negmask)
            # print('distance(negmask):', distance(negmask))
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
            # print('res[c]', res[c])
    return res


def simplex(t: Tensor, axis=1) -> bool:
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)


def one_hot(t: Tensor, axis=1) -> bool:
    return simplex(t, axis) and sset(t, [0, 1])

    # Assert utils
def uniq(a: Tensor) -> Set:
    return set(torch.unique(a.cpu()).numpy())


def sset(a: Tensor, sub: Iterable) -> bool:
    return uniq(a).issubset(sub)

class SurfaceLoss():
    def __init__(self):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = [1]   #这里忽略背景类  https://github.com/LIVIAETS/surface-loss/issues/3

    # probs: bcwh, dist_maps: bcwh
    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        print('pc', pc)
        print('dc', dc)

        multipled = einsum("bcwh,bcwh->bcwh", pc, dc)

        loss = multipled.mean()

        return loss


if __name__ == "__main__":
    data = torch.tensor([[[0, 0, 0, 0, 0, 0, 0], 
                          [0, 1, 1, 0, 0, 0, 0],
                          [0, 1, 1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0]]])

    data2 = class2one_hot(data, 2)
    # print(data2)
    data2 = data2[0].numpy()
    data3 = one_hot2dist(data2)   #bcwh

    # print(data3)
    print("data3.shape:", data3.shape)

    logits = torch.tensor([[[0, 0, 0, 0, 0, 0, 0], 
                             [0, 1, 1, 1, 1, 1, 0],
                             [0, 1, 1, 0, 0, 0, 0],
                             [0, 0, 0, 0, 0, 0, 0]]])

    logits = class2one_hot(logits, 2)
                
    Loss = SurfaceLoss()
    data3 = torch.tensor(data3).unsqueeze(0)

    res = Loss(logits, data3, None)
    print('loss:', res)

输出结果:

loss: tensor(0.2143)

如果prediction和label一致,loss为0。如果prediction比label小并被label包围,loss为负。

其中label计算距离图,即

data2 = class2one_hot(data, 2)
data2 = data2[0].numpy()
data3 = one_hot2dist(data2)   #bcwh

这几步,可以放到读取数据集,做出label之后。