[pytorch]医学图像之肝脏语义分割(训练+预测代码)
程序员文章站
2022-07-05 11:06:21
...
一、U-Net 结构:
结合上图所示的 U-Net 结构,用 PyTorch 实现的 U-Net 代码如下:
unet.py:
import torch.nn as nn
import torch
from torch import autograd
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged, so only the channel count
    changes (``in_ch`` -> ``out_ch``).  All six layers live in one
    ``nn.Sequential`` called ``conv``; that layout determines the state_dict
    keys (``conv.0.weight``, ``conv.1.running_mean``, ...) used by saved
    checkpoints.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """U-Net: a 4-level encoder/decoder with skip connections.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of output channels (one probability map per class).

    ``forward`` maps an ``(N, in_ch, H, W)`` tensor to an
    ``(N, out_ch, H, W)`` tensor of sigmoid probabilities in ``[0, 1]``.
    The encoder halves the spatial size four times, so ``H`` and ``W`` must
    each be at least 16.  When they are not multiples of 16, each upsampled
    decoder map is zero-padded to the size of its skip connection before
    concatenation; for multiples of 16 no padding is applied.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder: DoubleConv followed by 2x2 max-pooling, 64 -> 512 channels.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck.
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: 2x transposed-conv upsampling.  Each DoubleConv receives the
        # upsampled map concatenated with the matching encoder map, hence the
        # doubled input channel counts (e.g. 512 + 512 = 1024 for conv6).
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 convolution projecting the 64 features to out_ch maps.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its H and W equal those of ``skip``."""
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, up, conv, below, skip):
        """Upsample ``below``, align it with ``skip``, concatenate, DoubleConv."""
        upsampled = self._match_size(up(below), skip)
        # Channel-axis concatenation with the upsampled map first.
        return conv(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        # Encoder path; c1..c4 are kept for the skip connections.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Decoder path.
        c6 = self._decode(self.up6, self.conv6, c5, c4)
        c7 = self._decode(self.up7, self.conv7, c6, c3)
        c8 = self._decode(self.up8, self.conv8, c7, c2)
        c9 = self._decode(self.up9, self.conv9, c8, c1)
        # torch.sigmoid gives the same result as nn.Sigmoid()(...) without
        # constructing a new module on every forward pass.
        return torch.sigmoid(self.conv10(c9))
二、数据集:肝脏图像及其对应的 mask(标注)图像
三、代码
主程序:main.py
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from mIou import *
import os
import cv2
# 是否使用cuda
import PIL.Image as Image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def get_data(i, root=r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val"):
    """Return the ``(image_path, mask_path)`` pair at index ``i`` under ``root``.

    Args:
        i: index into the list produced by ``dataset.make_dataset(root)``
            (negative indices count from the end, as for any list).
        root: directory to scan; defaults to the validation directory.

    Returns:
        A 2-tuple ``(image_path, mask_path)``.

    Raises:
        IndexError: if ``i`` is out of range.
    """
    import dataset
    # Index the pair list directly instead of building two parallel lists.
    img_path, mask_path = dataset.make_dataset(root)[i]
    return img_path, mask_path
def train_model(model, criterion, optimizer, dataload, num_epochs=21,
                save_path=r'H:\BaiduNetdisk\BaiduDownload\u_net_liver-master/weights.pth'):
    """Train ``model`` for ``num_epochs`` epochs and checkpoint its weights.

    Args:
        model: network to optimise; the caller is expected to have moved it to
            ``device`` already (``train()`` does).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataload: ``DataLoader`` yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``dataload``.
        save_path: file the ``state_dict`` is written to at the end of every
            epoch (overwritten each time, so an interrupted run keeps the
            latest weights).

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    # Ensure BatchNorm/Dropout layers are in training mode even if the caller
    # previously switched the model to eval mode.
    model.train()
    dt_size = len(dataload.dataset)
    # Batches per epoch (ceiling division); only used for progress output.
    num_batches = (dt_size - 1) // dataload.batch_size + 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
        # epoch_loss is the sum (not the mean) of the per-batch losses.
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), save_path)
    return model
# Train the model.
def train(data_root=r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\train"):
    """Train a new ``Unet(3, 1)`` on the image/mask pairs in ``data_root``.

    Reads ``args.batch_size`` and the ``x_transforms`` / ``y_transforms``
    globals created in the ``__main__`` block.  The loss is ``BCELoss``
    because the network already ends in a sigmoid; the optimizer is Adam with
    its default settings.  Training and checkpointing are delegated to
    ``train_model``.
    """
    model = Unet(3, 1).to(device)
    batch_size = args.batch_size
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset(data_root, transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    train_model(model, criterion, optimizer, dataloaders)
# Visualise the model's predictions and report the validation mIoU.
def test(data_root=r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val"):
    """Evaluate the trained network on every image/mask pair in ``data_root``.

    Loads the weights from ``args.ckp``, shows each input image next to the
    network's prediction, prints the per-image IoU (via ``get_iou``) and
    finally the average of those values over all images.  Relies on the
    ``args``, ``x_transforms`` and ``y_transforms`` globals created in the
    ``__main__`` block.
    """
    import matplotlib.pyplot as plt

    # 3 input channels; 1 output channel because, apart from the background,
    # the liver is the only class.
    model = Unet(3, 1).to(device)
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))  # trained weights
    model.eval()
    liver_dataset = LiverDataset(data_root, transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    num = len(dataloaders)  # number of validation images (one per batch)
    if num == 0:
        print('No validation images found in %s' % data_root)
        return

    miou_total = 0
    plt.ion()  # interactive mode: redraw the figure on every iteration
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloaders):
            # The loader is not shuffled, so batch i is liver_dataset.imgs[i];
            # reuse its paths instead of rescanning the directory.
            img_path, mask_path = liver_dataset.imgs[i]
            y = model(x.to(device))
            # (1, 1, H, W) tensor -> (H, W) numpy array of probabilities.
            img_y = torch.squeeze(y).cpu().numpy()
            # get_iou binarises img_y in place, so the right-hand plot shows
            # the thresholded prediction.
            miou_total += get_iou(mask_path, img_y)
            plt.subplot(121)
            plt.imshow(Image.open(img_path))
            plt.subplot(122)
            plt.imshow(img_y)
            plt.pause(0.01)
    # Average over the number of images actually evaluated.
    print('Miou=%f' % (miou_total / num))
    # Leave interactive mode so the final figure stays open until closed.
    plt.ioff()
    plt.show()
if __name__ == "__main__":
    # Input images: convert to a tensor in [0, 1], then normalise each of the
    # three channels to [-1, 1].
    x_transforms = transforms.Compose([
        transforms.ToTensor(),  # -> [0,1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # -> [-1,1]
    ])
    # Masks only need to be converted to a tensor.
    y_transforms = transforms.ToTensor()
    # Command-line arguments.  --action selects the function to run; its
    # default is "test" so a bare `python main.py` evaluates the checkpoint,
    # and --ckp defaults to the checkpoint path written by training.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, choices=["train", "test"], help="train or test", default="test")
    parse.add_argument("--batch_size", type=int, default=1)
    parse.add_argument("--ckp", type=str, help="the path of model weight file",
                       default=r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\weights.pth")
    args = parse.parse_args()
    if args.action == "train":
        train()
    else:
        test()
数据读取代码:
dataset.py:
import torch.utils.data as data
import PIL.Image as Image
import os
def make_dataset(root):
    """List the ``(image_path, mask_path)`` pairs stored in ``root``.

    Samples are named ``000.png``, ``001.png``, ... with masks
    ``000_mask.png``, ``001_mask.png``, ....  The number of samples is the
    number of ``*_mask.png`` files, so unrelated files in the directory
    (e.g. ``Thumbs.db``) do not inflate the count.

    Returns:
        list of ``(image_path, mask_path)`` tuples ordered by index.
    """
    n = sum(1 for name in os.listdir(root) if name.endswith("_mask.png"))
    return [
        (os.path.join(root, "%03d.png" % i), os.path.join(root, "%03d_mask.png" % i))
        for i in range(n)
    ]
class LiverDataset(data.Dataset):
    """Dataset of ``(image, mask)`` samples listed by ``make_dataset(root)``.

    Args:
        root: directory holding ``NNN.png`` images and ``NNN_mask.png`` masks.
        transform: optional callable applied to the input PIL image.
        target_transform: optional callable applied to the mask PIL image.

    When a transform is ``None``, the corresponding PIL image is returned
    unchanged.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # List of (image_path, mask_path) tuples.
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        # Start from the raw images so the result is defined even when a
        # transform is None (both default to None).
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
四、网络结果:
网络的效果一般从两个方面评价:
1. 直观的可视化效果;
2. 量化指标。
语义分割中一个重要的指标是 mIoU(mean Intersection over Union,平均交并比):先对每个类别计算 IoU = TP / (TP + FP + FN),再对所有类别取平均。mIoU 的计算代码如下:
mIou.py:
import cv2
import numpy as np
class IOUMetric:
    """Confusion-matrix based segmentation metrics (mean IoU and friends).

    ``self.hist[t, p]`` counts pixels whose ground-truth class is ``t`` and
    whose predicted class is ``p``; it accumulates across ``add_batch`` calls.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        """Confusion matrix for one flattened prediction / ground-truth pair.

        Ground-truth entries outside ``[0, num_classes)`` are skipped.
        """
        k = self.num_classes
        valid = (label_true >= 0) & (label_true < k)
        # Encode every (true, pred) pair as the flat bin index true * k + pred.
        flat_index = k * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(flat_index, minlength=k ** 2)
        return counts.reshape(k, k)

    def add_batch(self, predictions, gts):
        """Add the confusion counts of each (prediction, ground truth) pair."""
        for pred, truth in zip(predictions, gts):
            self.hist += self._fast_hist(pred.flatten(), truth.flatten())

    def evaluate(self):
        """Return ``(acc, acc_cls, iu, mean_iu, fwavacc)`` from ``self.hist``.

        * ``acc``: overall pixel accuracy.
        * ``acc_cls``: per-class accuracy averaged over classes (NaNs ignored).
        * ``iu``: per-class IoU array, TP / (TP + FP + FN).
        * ``mean_iu``: mean of ``iu`` ignoring NaNs (classes absent from both
          prediction and ground truth).
        * ``fwavacc``: IoU weighted by ground-truth class frequency.
        """
        hist = self.hist
        tp = np.diag(hist)
        gt_per_class = hist.sum(axis=1)
        pred_per_class = hist.sum(axis=0)
        total = hist.sum()

        acc = tp.sum() / total
        acc_cls = np.nanmean(tp / gt_per_class)
        iu = tp / (gt_per_class + pred_per_class - tp)
        mean_iu = np.nanmean(iu)
        freq = gt_per_class / total
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc
def get_iou(mask_name, predict):
    """Return the mean IoU between a predicted probability map and a mask file.

    Args:
        mask_name: path of the ground-truth mask image.  It is read as
            grayscale, and pixels ``>= 125`` count as foreground (liver).
        predict: 2-D array of foreground probabilities in ``[0, 1]``
            (values closer to 1 are more likely liver).  It is binarised
            **in place** (``< 0.5`` -> 0, otherwise 1); ``test()`` displays
            the array after this call.

    Returns:
        Mean IoU over the two classes (background, liver) from
        ``IOUMetric(2)``; the value is also printed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``mask_name``.
        ValueError: if the prediction and mask shapes differ.
    """
    image_mask = cv2.imread(mask_name, 0)  # flag 0: load as grayscale
    if image_mask is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('cannot read mask image: %s' % mask_name)
    # Threshold the probabilities at 0.5 in one vectorised pass, writing back
    # into the caller's array.  NaN fails "< 0.5" and therefore maps to 1.
    predict[...] = np.where(predict < 0.5, 0, 1)
    # The mask is a black-and-white grayscale image: below 125 counts as black.
    image_mask = (image_mask >= 125).astype(np.uint8)
    if predict.shape != image_mask.shape:
        raise ValueError('prediction shape %s does not match mask shape %s'
                         % (predict.shape, image_mask.shape))
    predict = predict.astype(np.int16)
    Iou = IOUMetric(2)  # 2 classes: liver + background
    Iou.add_batch(predict, image_mask)
    mean_iu = Iou.evaluate()[3]
    print('%s:iou=%f' % (mask_name, mean_iu))
    return mean_iu
五、运行结果:
验证集上的 mIoU:
六、代码和数据集获取:
代码和数据集都在以下链接中:
https://pan.baidu.com/s/1ej9Maetb5pJ_mjMAfg9Lsg
提取码:e203
上一篇: 判断文件是否是目录