Unet图像分割网络Pytorch实现
介绍
最近一个月接触了一下Pytorch,个人认为Pytorch相较于Tensorflow来说好用很多。本文的内容是我对Unet论文的总结与提炼,需要提醒的是,Unet原文发布的时候还没有提出BN(Batch Normalization). 所以在本文中我会增加这一个步骤。
如果想要安装Python和Pytorch或者获得进一步的信息可以点击Python ,Pytorch
在图像分割这个大问题上,主要有两个流派:U-shape和dialated Conv。本文介绍的是U-shape网络中最为经典的U-Net。随着骨干网路的进化,很多相应衍生出来的网络大多都是对于Unet进行了改进但是本质上的思路还是没有太多的变化。比如结合DenseNet 和Unet的FCDenseNet, Unet++
Unet
Unet是一个为医学图像分割设计的auto-encoder-decoder结构的网络。行业里也把它视作一种FCN(fully connected network)。 它可以分成两个部分,down(encoder) 和 up(decoder)。down的主要结构可以看成conv后面跟maxpool。 up的主要结构是一个upsample后面跟conv。
Unet的核心思想
想要弄清这个问题首先要感性的理解一下卷积的作用。就拿MINIST数据集训练数字识别这个简单的CNN网络为例, 它把一个28*28的图片抽象成一个0-9的向量。卷积可以看成是特征的提取,它可以提取出输入的信息的抽象概念。但是Pool和Conv会损失空间信息。其中,空间信息在pool的过程中损失的更为严重。对于图像分割来说, 空间信息和抽象信息同样重要。既然每一个次pool的时候会严重损失空间信息,也就是说maxpool之间的空间信息多于之后的。于是Unet提出,把down的特征连接到对应的up上。
Unet的结构
其中灰色箭头copy and crop
中的copy
就是concatenate
而crop
是为了让两者的长宽一致
左半边就是down path右半边 就是up path。我们来分别介绍这两个部分。
Down Path
图中input image tile
就是我们输入的训练数据。除了第一层是两个conv,其他层都可以看成是maxpool后面跟两个conv。在Unet中绝大部分的conv都是两个conv连用的形式存在的,为了方便,我们可以先自定义一个double_conv
类。
# 实现double conv
class double_conv(nn.Module):
''' Conv => Batch_Norm => ReLU => Conv2d => Batch_Norm => ReLU
'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
self.conv.apply(self.init_weights)
def forward(self, x):
x = self.conv(x)
return x
@staticmethod
def init_weights(m):
if type(m) == nn.Conv2d:
init.xavier_normal(m.weight)
init.constant(m.bias,0)
下面我们来实现input conv, 它实际上用一个double_conv
也就完成了。
# 实现input conv
class inconv(nn.Module):
''' input conv layer
let input 3 channels image to 64 channels
The oly difference between `inconv` and `down` is maxpool layer
'''
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
接下来我们来实现down
类,它的结构是一个maxpool
接一个double_conv
。
class down(nn.Module):
''' normal down path
MaxPool2d => double_conv
'''
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
Up path
Unet的up path主要的结构是upsampl
加上double_conv
但是也可以使用ConvTranspose2d
代替upsample。下面的代码给出了两种选择。
在up path 中,我们需要将down path 中的特征合并进来。在up.forward
中crop从而让两个特征一致。
class up(nn.Module):
''' up path
conv_transpose => double_conv
'''
def __init__(self, in_ch, out_ch, Transpose=False):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if Transpose:
self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
else:
# self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(in_ch, in_ch//2, kernel_size=1, padding=0),
nn.ReLU(inplace=True))
self.conv = double_conv(in_ch, out_ch)
self.up.apply(self.init_weights)
def forward(self, x1, x2):
'''
conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1
'''
x1 = self.up(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2,x1], dim=1)
x = self.conv(x)
return x
@staticmethod
def init_weights(m):
if type(m) == nn.Conv2d:
init.xavier_normal(m.weight)
init.constant(m.bias,0)
*已经造好了,那么我们来实现Unet让它跑起来
class Unet(nn.Module):
def __init__(self, in_ch, out_ch, gpu_ids=[]):
super(Unet, self).__init__()
self.loss_stack = 0
self.matrix_iou_stack = 0
self.stack_count = 0
self.display_names = ['loss_stack', 'matrix_iou_stack']
self.gpu_ids = gpu_ids
self.bce_loss = nn.BCELoss()
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if torch.cuda.is_available() else torch.device('cpu')
self.inc = inconv(in_ch, 64)
self.down1 = down(64, 128)
# print(list(self.down1.parameters()))
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.drop3 = nn.Dropout2d(0.5)
self.down4 = down(512, 1024)
self.drop4 = nn.Dropout2d(0.5)
self.up1 = up(1024, 512, False)
self.up2 = up(512, 256, False)
self.up3 = up(256, 128, False)
self.up4 = up(128, 64, False)
self.outc = outconv(64, 1)
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
# self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
def forward(self):
x1 = self.inc(self.x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x4 = self.drop3(x4)
x5 = self.down4(x4)
x5 = self.drop4(x5)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
self.pred_y = nn.functional.sigmoid(x)
def set_input(self, x, y):
self.x = x.to(self.device)
self.y = y.to(self.device)
def optimize_params(self):
self.forward()
self._bce_iou_loss()
_ = self.accu_iou()
self.stack_count += 1
self.zero_grad()
self.loss.backward()
self.optimizer.step()
def accu_iou(self):
# B is the mask pred, A is the malanoma
y_pred = (self.pred_y > 0.5) * 1.0
y_true = (self.y > 0.5) * 1.0
pred_flat = y_pred.view(y_pred.numel())
true_flat = y_true.view(y_true.numel())
intersection = float(torch.sum(pred_flat * true_flat)) + 1e-7
denominator = float(torch.sum(pred_flat + true_flat)) - intersection + 2e-7
self.matrix_iou = intersection/denominator
self.matrix_iou_stack += self.matrix_iou
return self.matrix_iou
def _bce_iou_loss(self):
y_pred = self.pred_y
y_true = self.y
pred_flat = y_pred.view(y_pred.numel())
true_flat = y_true.view(y_true.numel())
intersection = torch.sum(pred_flat * true_flat) + 1e-7
denominator = torch.sum(pred_flat + true_flat) - intersection + 1e-7
iou = torch.div(intersection, denominator)
bce_loss = self.bce_loss(pred_flat, true_flat)
self.loss = bce_loss - iou + 1
self.loss_stack += self.loss
def get_current_losses(self):
errors_ret = {}
for name in self.display_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, name)) / self.stack_count
self.loss_stack = 0
self.matrix_iou_stack = 0
self.stack_count = 0
return errors_ret
def eval_iou(self):
with torch.no_grad():
self.forward()
self._bce_iou_loss()
_ = self.accu_iou()
self.stack_count += 1
其他的代码就是很固定的pytorch模板代码了。
代码参考自GitHub
转载请标明出处