欢迎您访问程序员文章站,本站旨在为大家提供并分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow 2.0 医学图像分割(肝脏CT图像分割)

程序员文章站 2024-03-19 16:33:46
...

1.数据集

肝脏分割数据集,训练集一共有400张肝脏CT图像以及对应的分割模板,验证集一共有20张肝脏CT图像以及对应的分割模板,如下图所示:
Tensorflow 2.0 医学图像分割(肝脏CT图像分割)
Tensorflow 2.0 医学图像分割(肝脏CT图像分割)

2.代码

model.py

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Activation, BatchNormalization, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate

# Define the U-Net network model
def unet(input_size=(416, 416, 3)):
    """Build a U-Net with a single-channel sigmoid output.

    The encoder runs four stages (64, 128, 256, 512 filters), each followed
    by 2x2 max pooling, then a 1024-filter bottleneck. The decoder mirrors
    the encoder: each stage upsamples with a stride-2 transposed convolution,
    concatenates the matching encoder skip tensor on the channel axis, and
    applies the same two-convolution stage.

    Args:
        input_size: Shape of one input sample as (height, width, channels).
            Height and width are pooled by 2 four times and then upsampled
            by 2 four times, so they need to be divisible by 16 for the
            skip concatenations to line up.

    Returns:
        An uncompiled ``Model`` mapping the input tensor to a per-pixel
        sigmoid map with one channel.
    """

    def conv_stage(tensor, filters):
        # Conv -> ReLU -> Conv -> BatchNorm -> ReLU.
        # Returns the second convolution's raw (pre-BatchNorm) output, which
        # is what the skip connections use, together with the activated output.
        hidden = Conv2D(filters, (3, 3), padding='same')(tensor)
        hidden = Activation('relu')(hidden)
        raw = Conv2D(filters, (3, 3), padding='same')(hidden)
        normed = BatchNormalization(axis=3)(raw)
        return raw, Activation('relu')(normed)

    def upsample_and_merge(tensor, skip, filters):
        # Double the spatial size, then join with the encoder skip tensor.
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(tensor)
        return concatenate([upsampled, skip], axis=3)

    inputs = Input(input_size)

    # Contracting path
    features = inputs
    skips = []
    for filters in (64, 128, 256, 512):
        raw, activated = conv_stage(features, filters)
        skips.append(raw)
        features = MaxPooling2D(pool_size=(2, 2))(activated)

    # Bottleneck
    _, features = conv_stage(features, 1024)

    # Expanding path: consume the skips from deepest to shallowest
    for filters in (512, 256, 128, 64):
        merged = upsample_and_merge(features, skips.pop(), filters)
        _, features = conv_stage(merged, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)

    return Model(inputs=[inputs], outputs=[outputs])

train.py

import os
from model import unet
import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint

# Input image size in pixels
im_width = 416
im_height = 416

BATCH_SIZE = 8

# Each split directory is read through its 'image' and 'label' sub-folders
train_files = './data/train/'
val_files = './data/val/'

# Create the checkpoint directory. exist_ok=True replaces the separate
# existence check, so there is no window between checking and creating.
os.makedirs("save_weights", exist_ok=True)


# Sample counts, taken as the number of entries in each split's 'image' folder
train_num = len(os.listdir(os.path.join(train_files, 'image')))
val_num = len(os.listdir(os.path.join(val_files, 'image')))


# Data generator yielding paired (image, mask) batches
def data_generator(data_path, batch_size, aug_dict,
                   image_color_mode="rgb",
                   mask_color_mode="grayscale",
                   image_save_prefix="image",
                   mask_save_prefix="mask",
                   save_to_dir=None,
                   target_size=(im_width, im_height),
                   seed=1):
    """Yield preprocessed (image, mask) batches read from ``data_path``.

    Images are read from the 'image' sub-folder and masks from the 'label'
    sub-folder. Both streams are built with the same augmentation settings
    and the same seed, and every batch pair is passed through
    ``adjust_data`` before being yielded.

    Args:
        data_path: Directory containing the 'image' and 'label' sub-folders.
        batch_size: Batch size passed to both directory iterators.
        aug_dict: Keyword arguments forwarded to ``ImageDataGenerator``.
        image_color_mode: Color mode used when loading images.
        mask_color_mode: Color mode used when loading masks.
        image_save_prefix: File prefix for saved augmented images.
        mask_save_prefix: File prefix for saved augmented masks.
        save_to_dir: Optional directory where augmented samples are written.
        target_size: Size every image and mask is resized to.
            NOTE(review): Keras documents this as (height, width) while the
            default is built as (im_width, im_height); the two only agree
            while both constants are equal - confirm before changing either.
        seed: Random seed shared by both iterators.

    Yields:
        Tuples ``(img, mask)`` as returned by ``adjust_data``.
    """
    # Arguments that must be identical for the image and mask streams
    shared_flow_args = dict(
        class_mode=None,
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        seed=seed,
    )

    image_augmenter = ImageDataGenerator(**aug_dict)
    mask_augmenter = ImageDataGenerator(**aug_dict)

    image_stream = image_augmenter.flow_from_directory(
        data_path,
        classes=['image'],
        color_mode=image_color_mode,
        save_prefix=image_save_prefix,
        **shared_flow_args)

    mask_stream = mask_augmenter.flow_from_directory(
        data_path,
        classes=['label'],
        color_mode=mask_color_mode,
        save_prefix=mask_save_prefix,
        **shared_flow_args)

    for image_batch, mask_batch in zip(image_stream, mask_stream):
        image_batch, mask_batch = adjust_data(image_batch, mask_batch)
        yield (image_batch, mask_batch)


# Batch preprocessing
def adjust_data(img, mask):
    """Scale a batch by 1/255 and binarize the mask.

    Args:
        img: Image array with values on a 0-255 scale.
        mask: Mask array with values on a 0-255 scale.

    Returns:
        A tuple ``(img, mask)`` of new arrays: the image divided by 255, and
        the mask divided by 255 with values above 0.5 set to 1 and values at
        or below 0.5 set to 0. The input arrays are not modified.
    """
    img = img / 255
    mask = mask / 255

    # Both selections are taken before any assignment; writing 1 into the
    # foreground cannot move an element into the background selection.
    is_foreground = mask > 0.5
    is_background = mask <= 0.5
    mask[is_foreground] = 1
    mask[is_background] = 0

    return (img, mask)


# Additive smoothing term used in the overlap ratios; it keeps the ratio
# defined (equal to 1) when every sum in it is zero.
smooth = 100


# Dice coefficient
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient over the whole batch.

    Both tensors are flattened, so the value is computed over all elements
    at once rather than per sample:
    ``(2 * sum(true * pred) + smooth) / (sum(true) + sum(pred) + smooth)``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        A scalar tensor holding the smoothed Dice coefficient.
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2 * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator


# Loss function
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, so minimizing it maximizes overlap.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        ``-dice_coef(y_true, y_pred)``.
    """
    dice = dice_coef(y_true, y_pred)
    return -dice


# Intersection over union
def iou(y_true, y_pred):
    """Return the smoothed intersection-over-union (Jaccard index).

    Computed over all elements of the two tensors as
    ``(intersection + smooth) / (sum(true + pred) - intersection + smooth)``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        A scalar tensor holding the smoothed IoU.
    """
    overlap = K.sum(y_true * y_pred)
    # Subtracting the overlap from the total removes the double-counted part
    union = K.sum(y_true + y_pred) - overlap
    return (overlap + smooth) / (union + smooth)


def jac_distance(y_true, y_pred):
    """Return the negated smoothed IoU, usable as a loss to minimize.

    The previous version flattened both tensors into locals that were never
    used; those dead operations are removed, the returned value is unchanged.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        ``-iou(y_true, y_pred)``.
    """
    return -iou(y_true, y_pred)

model = unet()
# Print the layer-by-layer model summary
model.summary()

# Augmentation settings for the training images. data_generator applies the
# same settings and seed to the images and to their masks.
# NOTE(review): Keras documents rotation_range in degrees, so 0.2 is a very
# small rotation - confirm this is the intended magnitude.
train_generator_args = dict(rotation_range=0.2,
                            width_shift_range=0.05,
                            height_shift_range=0.05,
                            shear_range=0.05,
                            zoom_range=0.05,
                            horizontal_flip=True,
                            fill_mode='nearest')
# Training data stream (augmented)
train_gen = data_generator(train_files, BATCH_SIZE,
                           train_generator_args,
                           target_size=(im_height, im_width))
# Validation data stream (no augmentation)
val_gen = data_generator(val_files, BATCH_SIZE,
                         dict(),
                         target_size=(im_height, im_width))

# Ceiling division so the final partial batch is consumed and every epoch
# covers the whole split (floor division left validation on a shifting
# subset), with a floor of 1 so a split smaller than BATCH_SIZE never
# yields 0 steps.
train_steps = max(1, -(-train_num // BATCH_SIZE))
val_steps = max(1, -(-val_num // BATCH_SIZE))

# Keep the best full model seen so far. The monitored key must match a logged
# metric: the model is compiled with "binary_accuracy", so the validation key
# is 'val_binary_accuracy' ('val_acc' is never logged, which made
# save_best_only skip every save). The deprecated `period` argument is
# dropped; saving once per epoch is the default.
checkpoint = ModelCheckpoint(
    filepath='./save_weights/myUnet.h5',
    monitor='val_binary_accuracy',
    save_weights_only=False,
    save_best_only=True,
    mode='max',
)

# Cut the learning rate by 10x when validation loss stalls for 2 epochs
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1)

# Compile the model
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=dice_coef_loss, metrics=["binary_accuracy", iou, dice_coef])

history = model.fit(train_gen,
                    steps_per_epoch=train_steps,
                    epochs=30,
                    validation_data=val_gen,
                    validation_steps=val_steps,
                    callbacks=[checkpoint, reduce_lr])

# Save the last-epoch weights to a separate file so they do not overwrite the
# best checkpoint stored at ./save_weights/myUnet.h5.
model.save_weights('./save_weights/myUnet_last.h5')

predict.py

from model import unet
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt


val_img_files = './data/val/image/'
val_label_files = './data/val/label/'

# os.listdir returns names in arbitrary order and images are paired with
# labels by position, so both listings are sorted first.
# NOTE(review): this assumes image and label file names sort into the same
# order (e.g. identical names in both folders) - confirm against the dataset.
images = sorted(os.listdir(val_img_files))
mask = sorted(os.listdir(val_label_files))

model = unet()
model.load_weights('./save_weights/myUnet.h5')

# The preview grid below is 4 x 5, so at most 20 samples are rendered
max_previews = 20

prediction_overlap = []
for image_name, label_name in zip(images[:max_previews], mask[:max_previews]):
    image_path = os.path.join(val_img_files, image_name)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError("Could not read image: {}".format(image_path))
    # OpenCV loads BGR; convert to RGB so the channel order matches the
    # "rgb" color mode used by the training generator and what plt.imshow
    # expects.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (416, 416))

    # Scale as in training and add the batch axis
    img1 = img / 255
    img1 = img1[np.newaxis, :, :, :]
    prediction = np.squeeze(model.predict(img1))

    label_path = os.path.join(val_label_files, label_name)
    ground_truth = cv2.imread(label_path)
    if ground_truth is None:
        raise OSError("Could not read label: {}".format(label_path))
    ground_truth = cv2.resize(ground_truth, (416, 416))
    ground_truth = cv2.cvtColor(ground_truth, cv2.COLOR_BGR2GRAY)
    _, thresh_gt = cv2.threshold(ground_truth, 127, 255, cv2.THRESH_BINARY)
    # [-2] selects the contour list whether findContours returns two values
    # (OpenCV 4) or three (OpenCV 3).
    contours = cv2.findContours(thresh_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Ground truth outline in green (drawn in place on img)
    overlap_img = cv2.drawContours(img, contours, -1, (0, 255, 0), 2)

    # Binarize the sigmoid output at 0.3 straight into a 0/255 uint8 mask
    thresh_p = np.where(prediction >= 0.3, 255, 0).astype("uint8")
    contours_p = cv2.findContours(thresh_p, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Predicted outline in red (the image is RGB at this point)
    overlap_img = cv2.drawContours(img, contours_p, -1, (255, 36, 0), 2)
    prediction_overlap.append(overlap_img)

plt.figure(figsize=(20, 20))
for i, overlap in enumerate(prediction_overlap):
    plt.subplot(4, 5, i + 1)
    plt.imshow(overlap)
    plt.title('Predict Image')
    plt.xticks([])
    plt.yticks([])


plt.savefig("Predict.png")
plt.show()

3.结果

绿色轮廓为真实分割结果,红色轮廓为预测分割结果
Tensorflow 2.0 医学图像分割(肝脏CT图像分割)
肝脏分割数据集和整个项目文件下载地址:https://download.csdn.net/download/qq_38412266/12628230