欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch:Unet网络实现人脑图像分割

程序员文章站 2022-04-08 08:45:59
...

1 介绍

U-Net 是一篇基本结构非常好的论文，主要针对生物医学图像的分割；在此后的许多医学图像分割网络中，很大一部分都采用 U-Net 作为网络主干。它在 ISBI 2012 的 EM segmentation challenge 上取得了优于当时最佳方法的成绩，而且速度也非常快。它还有一个很好的优点：在小数据集上也能取得较好的效果。例如 EM 2012 数据集只包含 30 张果蝇一龄幼虫腹侧神经索连续切片的透射电子显微镜图像。

本文利用 PyTorch 实现了基于 U-Net 网络的人脑图像分割，并展示了部分分割效果图，希望对你有所帮助！

2 源代码

(1)网络结构代码（保存为 unet.py，训练脚本通过 from unet import Unet 导入）

import torch.nn as nn
import torch
from torch import autograd

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch`` (the second conv maps out_ch -> out_ch).
    The layers live in a single ``nn.Sequential`` named ``conv`` so the
    state_dict keys are ``conv.0.*`` .. ``conv.4.*``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage consumes in_ch channels, the second consumes out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        features = self.conv(input)
        return features


class Unet(nn.Module):
    """U-Net segmentation network with a sigmoid output.

    Encoder: four DoubleConv + 2x max-pool stages (64 -> 512 channels) and a
    1024-channel bottleneck. Decoder: four 2x transposed-conv upsamplings,
    each concatenated (channels: [upsampled, skip]) with the encoder map of the
    same level and fused by a DoubleConv. A final 1x1 conv maps 64 channels to
    ``out_ch`` and a sigmoid squashes every value into (0, 1).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.

    Height/width do not have to be multiples of 16: when pooling floors an odd
    size, the upsampled map is zero-padded to the skip connection's size
    before concatenation. Each spatial dimension must still be at least 16,
    because the input goes through four 2x poolings.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Attribute names form the state_dict keys used by saved checkpoints,
        # so they must stay stable.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` (roughly centred) so its H and W equal those of ``ref``."""
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def _up_merge(self, up, conv, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate on channels, then fuse with ``conv``."""
        upsampled = self._pad_to(up(x), skip)
        return conv(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        c6 = self._up_merge(self.up6, self.conv6, c5, c4)
        c7 = self._up_merge(self.up7, self.conv7, c6, c3)
        c8 = self._up_merge(self.up8, self.conv8, c7, c2)
        c9 = self._up_merge(self.up9, self.conv9, c8, c1)
        # Same result as nn.Sigmoid()(...) without building a module per call.
        return torch.sigmoid(self.conv10(c9))

(2)数据集准备（保存为 dataset.py，训练脚本通过 from dataset import MyDataset 导入）

import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(rootdata, roottarget):
    """Pair every entry of ``rootdata`` with the same-named path in ``roottarget``.

    Args:
        rootdata: folder holding the input images.
        roottarget: folder holding the masks; a mask is expected to have
            exactly the same file name as its image (not checked here).

    Returns:
        A list of ``(image_path, mask_path)`` tuples, sorted by file name.
    """
    # os.listdir order is arbitrary; sorting makes indices reproducible, which
    # callers rely on when they map a dataset index back to a file name.
    return [
        (os.path.join(rootdata, name), os.path.join(roottarget, name))
        for name in sorted(os.listdir(rootdata))
    ]


class MyDataset(data.Dataset):
    """Paired (image, mask) dataset built from two folders of same-named files.

    Both files are opened with PIL and converted to mode 'L' (8-bit
    single-channel grayscale, not a binary image). The optional transforms
    are then applied to the image and to the mask respectively.
    """

    def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
        self.imgs = make_dataset(rootdata, roottarget)  # list of (image_path, mask_path)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _load_gray(path):
        # Mode 'L' = 8-bit grayscale.
        return Image.open(path).convert('L')

    def __getitem__(self, index):
        image_path, mask_path = self.imgs[index]
        image = self._load_gray(image_path)
        mask = self._load_gray(mask_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)

(3)训练与测试（主脚本，依赖上面的 unet.py 与 dataset.py）

import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torch.autograd import Variable
from torchvision import transforms
from unet import Unet
from dataset import MyDataset
from dataset import make_dataset
import os
import cv2
import time

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MyDataset loads images as single-channel 'L' images, so Normalize must get
# exactly one mean/std value; a 3-value Normalize on a 1-channel tensor fails
# with a broadcast error. ToTensor scales to [0, 1]; (x - 0.5) / 0.5 then maps
# the input to [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Masks only need conversion to a tensor (ToTensor scales 8-bit values to [0, 1]).
y_transforms = transforms.ToTensor()


def train_model(model, criterion, optimizer, dataload, num_epochs=150):
    """Train ``model`` for ``num_epochs`` epochs and checkpoint it at the end.

    Each batch from ``dataload`` is moved to the device the model's parameters
    live on, and one optimizer step is taken on ``criterion(outputs, labels)``.
    The loss of every step and the summed loss of every epoch are printed.

    After the last epoch two files are written to the current directory:
    ``weights_<last_epoch>.pth`` (the state_dict) and
    ``weights_<last_epoch>_dc.pth`` (the whole pickled module). Both are
    serialized from CPU tensors; the model is then moved back to its original
    device. Nothing is saved when ``num_epochs`` <= 0.

    Returns:
        The same ``model`` object, still on its original device.
    """
    model_device = next(model.parameters()).device
    model.train()  # BatchNorm must use batch statistics while training.
    dt_size = len(dataload.dataset)
    steps_per_epoch = (dt_size - 1) // dataload.batch_size + 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(model_device)
            labels = y.to(model_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.6f" % (step, steps_per_epoch, loss.item()))
        print("epoch %d loss:%0.6f" % (epoch, epoch_loss))
    if num_epochs > 0:
        last_epoch = num_epochs - 1
        # Serialize from CPU so the checkpoints load without a GPU, then restore
        # the original device so the returned model still matches its inputs.
        model.cpu()
        torch.save(model.state_dict(), 'weights_%d.pth' % last_epoch)
        torch.save(model, 'weights_%d_dc.pth' % last_epoch)
        model.to(model_device)
    return model


#训练模型
def train(train_data_path, train_gt_path, batch_size=1, num_epochs=150):
    """Train the module-level ``model`` on paired image/mask folders.

    Relies on the globals ``model``, ``criterion`` and ``optimizer`` (created
    in the ``__main__`` block) and on ``x_transforms``/``y_transforms``.

    Args:
        train_data_path: folder with the training images.
        train_gt_path: folder with the same-named ground-truth masks.
        batch_size: samples per batch (default 1, as before).
        num_epochs: epochs to train (default 150, same as ``train_model``).

    Returns:
        The trained model returned by ``train_model``.
    """
    liver_dataset = MyDataset(
        train_data_path, train_gt_path, transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(
        liver_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return train_model(model, criterion, optimizer, dataloaders, num_epochs=num_epochs)


#显示模型的输出结果
def test(test_data_path, test_gt_path, save_pre_path):
    """Predict a mask for every image in ``test_data_path`` and save it.

    Uses the module-level ``model``. Each prediction is saved with
    ``plt.imsave`` (default colormap, no ``cmap`` passed) into
    ``save_pre_path`` under the same file name as its input. The masks in
    ``test_gt_path`` are loaded by the dataset but not used here. The
    per-image inference time is printed.
    """
    import matplotlib.pyplot as plt

    liver_dataset = MyDataset(
        test_data_path, test_gt_path, transform=x_transforms, target_transform=y_transforms)
    # Reuse the dataset's own path list so the output name always matches the
    # image just predicted (DataLoader without shuffle keeps this order).
    imgs = liver_dataset.imgs
    print(liver_dataset)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    # Send inputs wherever the model lives instead of hard-coding 'cuda'.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # BatchNorm must use its running statistics at inference.
    plt.ion()
    try:
        with torch.no_grad():
            for count, (x, _) in enumerate(dataloaders):
                # time.clock() was removed in Python 3.8.
                start = time.perf_counter()
                y = model(x.to(model_device))
                img_y = torch.squeeze(y).cpu().numpy()
                elapsed = time.perf_counter() - start
                print("Time used:", elapsed)
                plt.imsave(os.path.join(save_pre_path, os.path.basename(imgs[count][1])), img_y)
            plt.show()
    finally:
        model.train(was_training)

def test_forDirsImages(source_data_path, source_gt_path, save_path, case_ids=range(1, 100)):
    """Run ``test`` on every numbered case sub-folder.

    For each id in ``case_ids``, images are read from
    ``source_data_path/<id>`` and masks from ``source_gt_path/<id>``.
    Predictions are written to ``save_path/<id>``, which is created if needed.
    Ids whose data or mask folder is missing are skipped, so a gap in the
    numbering no longer aborts the run. Returns immediately if either source
    root does not exist.

    Args:
        case_ids: case numbers to process (default 1..99, as before).
    """
    if not os.path.exists(source_data_path) or not os.path.exists(source_gt_path):
        return
    os.makedirs(save_path, exist_ok=True)
    for case_id in case_ids:
        case_data = os.path.join(source_data_path, str(case_id))
        case_gt = os.path.join(source_gt_path, str(case_id))
        if not (os.path.isdir(case_data) and os.path.isdir(case_gt)):
            continue
        case_save = os.path.join(save_path, str(case_id))
        os.makedirs(case_save, exist_ok=True)
        test(case_data, case_gt, case_save)

if __name__ == '__main__':
    # Set to True to resume from the checkpoint train_model writes after 150 epochs.
    pretrained = False
    checkpoint = './weights_149.pth'
    data_dir = "I:\\0\\train20191022\\img"
    gt_dir = "I:\\0\\train20191022\\gt"

    model = Unet(1, 1).to(device)
    if pretrained:
        state = torch.load(checkpoint)
        model.load_state_dict(state)

    # Unet.forward already ends in a sigmoid, so the plain (non-logits) BCE loss fits.
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())

    train(data_dir, gt_dir)

 

3 效果展示

说明:左图为原始图像,右图为分割结果

PyTorch:Unet网络实现人脑图像分割