PyTorch:Unet网络实现人脑图像分割
程序员文章站
2022-04-08 08:45:59
...
1 介绍
U-Net是一篇基本结构非常好的论文,主要针对生物医学图像的分割;此后的许多医学图像分割网络中,很大一部分都采用U-Net作为网络的主干。与当年的其他方法相比,它在ISBI 2012的EM segmentation challenge上取得了优于当时最佳结果的成绩,而且速度非常快。它还有一个很大的优点:在小数据集上也能取得较好的效果。例如EM 2012数据集只有30张果蝇一龄幼虫腹侧神经索连续切片的透射电子显微镜图像。
本文使用PyTorch实现了基于U-Net网络的人脑图像分割,并展示了部分分割效果图。希望对你有所帮助!
2 源代码
(1)网络结构代码(保存为unet.py,训练脚本通过 from unet import Unet 导入)
import torch.nn as nn
import torch
from torch import autograd
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    With ``padding=1`` and the default stride of 1, each 3x3 convolution keeps
    the spatial size, so only the channel count changes (``in_ch`` -> ``out_ch``).

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # A single nn.Sequential named ``conv`` keeps the state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...) compatible with
        # previously saved checkpoints.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both conv/BN/ReLU stages to ``input`` and return the result."""
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel probability map.

    Four 2x max-pool downsampling stages (64 -> 1024 channels) are mirrored by
    four 2x transposed-convolution upsampling stages. Each decoder stage
    concatenates the matching encoder feature map (skip connection) before a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_ch`` channels and a
    sigmoid squashes every value into [0, 1].

    Input height/width need not be multiples of 16. When a pooling step floors
    an odd size, the upsampled tensor is zero-padded back to the size of its
    skip connection before concatenation. Sides smaller than 16 still fail,
    because the fourth pooling would produce an empty map.

    Args:
        in_ch: channels of the input image (1 for grayscale).
        out_ch: channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Attribute names and creation order are unchanged so that saved
        # state_dicts keep loading and initialisation order is identical.
        # Encoder: each DoubleConv doubles the channels, each pool halves H and W.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)  # bottleneck
        # Decoder: each transposed conv doubles H and W and halves the channels;
        # the following DoubleConv sees [upsampled, skip] concatenated, hence
        # twice the channels on its input.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its height and width equal those of ``skip``.

        A stride-2 transposed conv doubles the size, which is one pixel short
        of the skip connection whenever the preceding max-pool floored an odd
        size. For even sizes this returns ``up`` unchanged.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # pad order for the last two dims: (left, right, top, bottom)
        return nn.functional.pad(
            up,
            (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_ch, H, W) to probabilities (N, out_ch, H, W)."""
        # Encoder
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder with skip connections
        up_6 = self._match_size(self.up6(c5), c4)
        c6 = self.conv6(torch.cat([up_6, c4], dim=1))
        up_7 = self._match_size(self.up7(c6), c3)
        c7 = self.conv7(torch.cat([up_7, c3], dim=1))
        up_8 = self._match_size(self.up8(c7), c2)
        c8 = self.conv8(torch.cat([up_8, c2], dim=1))
        up_9 = self._match_size(self.up9(c8), c1)
        c9 = self.conv9(torch.cat([up_9, c1], dim=1))
        c10 = self.conv10(c9)
        # Same values as nn.Sigmoid()(c10) without building a module per call.
        return torch.sigmoid(c10)
(2)数据集准备(保存为dataset.py,训练脚本通过 from dataset import MyDataset 导入)
import torch.utils.data as data
import PIL.Image as Image
import os
def make_dataset(rootdata, roottarget):
    """Pair every entry of ``rootdata`` with the same-named path in ``roottarget``.

    Args:
        rootdata: directory holding the input images.
        roottarget: directory holding the masks. Each mask is expected to share
            its file name with the corresponding image; its existence is not
            checked here.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, sorted by file name.
        ``os.listdir`` returns entries in arbitrary order, so sorting makes the
        index -> file mapping reproducible across runs and file systems.
    """
    return [
        (os.path.join(rootdata, name), os.path.join(roottarget, name))
        for name in sorted(os.listdir(rootdata))
    ]
class MyDataset(data.Dataset):
    """Image/mask pairs for segmentation, read from disk on each access.

    Args:
        rootdata: directory of input images.
        roottarget: directory of masks with the same file names as the images.
        transform: optional callable applied to the input image.
        target_transform: optional callable applied to the mask.
    """

    def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
        # List of (image_path, mask_path) tuples, one per sample.
        self.imgs = make_dataset(rootdata, roottarget)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _load_gray(path):
        """Open ``path`` and return a copy converted to 8-bit grayscale (mode 'L').

        The context manager closes the underlying file immediately instead of
        leaving that to Pillow or the garbage collector. ``convert`` returns a
        new, fully loaded image, so the result stays valid after closing.
        """
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, index):
        """Return the (image, mask) pair at ``index``, each optionally transformed."""
        x_path, y_path = self.imgs[index]
        # Mode 'L' is 8-bit grayscale (0-255), not a binary image; any
        # binarisation of the mask has to happen in target_transform.
        img_x = self._load_gray(x_path)
        img_y = self._load_gray(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
(3)训练与测试
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torch.autograd import Variable
from torchvision import transforms
from unet import Unet
from dataset import MyDataset
from dataset import make_dataset
import os
import cv2
import time
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input images are loaded as single-channel ('L') PIL images, so normalize with
# a single mean/std value. ToTensor maps pixels to [0, 1] and
# Normalize([0.5], [0.5]) rescales them to [-1, 1]. A 3-value mean/std cannot
# be applied in place to a 1-channel tensor in recent torchvision; it raises a
# broadcast error.
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# Masks only need converting to a tensor.
y_transforms = transforms.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=150):
    """Train ``model`` for ``num_epochs`` epochs, checkpointing after each one.

    Args:
        model: network to train. Batches are moved to the device its
            parameters live on.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataload: ``DataLoader`` yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``dataload``.

    Returns:
        The same ``model`` object, trained and left on its original device.

    Side effects:
        After every epoch, writes ``weights_<epoch>.pth`` (CPU state_dict) and
        ``weights_<epoch>_dc.pth`` (whole pickled CPU model) to the current
        working directory, and prints per-step and per-epoch losses.
    """
    # Take the device from the model itself instead of relying on a global.
    run_device = next(model.parameters()).device
    num_batches = len(dataload)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model.train()
        epoch_loss = 0  # sum of the batch losses in this epoch
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(run_device)
            labels = y.to(run_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.6f" % (step, num_batches, loss.item()))
        print("epoch %d loss:%0.6f" % (epoch, epoch_loss))
        # Save CPU tensors so the checkpoints load on machines without a GPU.
        # Module.cpu() moves the model in place, so move it back afterwards.
        # Otherwise the next epoch would feed GPU batches to a CPU model and
        # the optimizer state would sit on a different device.
        model.cpu()
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
        torch.save(model, 'weights_%d_dc.pth' % epoch)
        model.to(run_device)
    return model
# Train the model
def train(train_data_path, train_gt_path, batch_size=1, num_epochs=150):
    """Build a shuffled loader over the training pairs and train the global model.

    Uses the module-level ``model``, ``criterion`` and ``optimizer`` created in
    the ``__main__`` section. It therefore only works when the file is run as a
    script, or after those globals have been set.

    Args:
        train_data_path: directory of training images.
        train_gt_path: directory of masks named like the images.
        batch_size: samples per batch (default 1).
        num_epochs: number of training epochs (default 150).

    Returns:
        The model returned by ``train_model``.
    """
    train_dataset = MyDataset(
        train_data_path, train_gt_path,
        transform=x_transforms, target_transform=y_transforms)
    dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return train_model(model, criterion, optimizer, dataloader,
                       num_epochs=num_epochs)
# Run the model on a folder of images and save the predicted maps
def test(test_data_path, test_gt_path, save_pre_path):
    """Predict a map for every image in ``test_data_path`` and save it to disk.

    Uses the module-level ``model``. Each prediction (the model output squeezed
    with ``torch.squeeze``) is written with ``plt.imsave`` under
    ``save_pre_path``, using the input image's file name. The wall-clock time
    of each forward pass is printed.

    Args:
        test_data_path: directory of input images.
        test_gt_path: directory of masks with the same file names. The dataset
            opens them even though they are not used here.
        save_pre_path: output directory; created if missing.
    """
    import matplotlib.pyplot as plt

    test_dataset = MyDataset(
        test_data_path, test_gt_path,
        transform=x_transforms, target_transform=y_transforms)
    # No shuffling, so batch i corresponds to test_dataset.imgs[i].
    dataloader = DataLoader(test_dataset, batch_size=1)
    os.makedirs(save_pre_path, exist_ok=True)
    # Use whatever device the model lives on instead of hard-coding CUDA.
    model_device = next(model.parameters()).device
    was_training = model.training
    # eval(): BatchNorm must use its running statistics at inference time.
    model.eval()
    try:
        with torch.no_grad():
            for (x, _), (img_path, _) in zip(dataloader, test_dataset.imgs):
                # time.clock() was removed in Python 3.8.
                start = time.perf_counter()
                y = model(x.to(model_device))
                img_y = torch.squeeze(y).cpu().numpy()
                elapsed = time.perf_counter() - start
                print("Time used:", elapsed)
                plt.imsave(
                    os.path.join(save_pre_path, os.path.basename(img_path)), img_y)
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)
def test_forDirsImages(source_data_path, source_gt_path, save_path):
    """Run ``test`` on the numbered sub-folders ``1`` .. ``99`` of a directory tree.

    For each ``i`` where both ``source_data_path/i`` and ``source_gt_path/i``
    exist, predictions are written to ``save_path/i`` (created as needed).
    Missing sub-folders are skipped instead of aborting the whole run.

    Args:
        source_data_path: root folder containing the numbered image folders.
        source_gt_path: root folder containing the numbered mask folders.
        save_path: root folder for the predictions; created if missing.
    """
    for root in (source_data_path, source_gt_path):
        if not os.path.exists(root):
            # Previously a silent return; report why nothing was processed.
            print("Directory not found, nothing to test: %s" % root)
            return
    os.makedirs(save_path, exist_ok=True)
    for i in range(1, 100):
        data_dir = os.path.join(source_data_path, str(i))
        gt_dir = os.path.join(source_gt_path, str(i))
        if not (os.path.isdir(data_dir) and os.path.isdir(gt_dir)):
            continue
        out_dir = os.path.join(save_path, str(i))
        os.makedirs(out_dir, exist_ok=True)
        test(data_dir, gt_dir, out_dir)
if __name__ == '__main__':
    # Set to True to start from a saved checkpoint instead of random weights.
    pretrained = False
    # 1 input channel (grayscale image) -> 1 output channel.
    model = Unet(1, 1).to(device)
    if pretrained:
        # map_location lets a checkpoint written on one device (GPU/CPU)
        # load on a machine that only has the other.
        model.load_state_dict(torch.load('./weights_149.pth', map_location=device))
    # Unet.forward ends in a sigmoid, so plain BCELoss applies.
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    # Raw strings keep the Windows back-slashes literal.
    train(r"I:\0\train20191022\img", r"I:\0\train20191022\gt")
3 效果展示
说明:左图为原始图像,右图为分割结果
推荐阅读
-
【Pytorch】语义分割、医学图像分割one-hot的两种实现方式
-
【语义分割系列:四】Unet 论文阅读翻译笔记 医学图像 pytorch实现
-
2D UNet3+ Pytorch实现 脑肿瘤分割
-
图像分类的卷积神经网络LeNet,VGG,GoogLeNet,ResNet通俗解读及其pytorch实现
-
TensorFlow keras实现unet网络并进行图像分割入门实例(非常适合新手!)
-
PyTorch:Unet网络实现人脑图像分割
-
Unet图像分割网络Pytorch实现
-
Keras:Unet网络实现多类语义分割方式
-
图像分类的卷积神经网络LeNet,VGG,GoogLeNet,ResNet通俗解读及其pytorch实现