U-Net Medical Image Segmentation with Keras
Introduction
I've recently been working on a medical project extracting retinal (fundus) blood vessels, which uses U-Net. U-Net itself has already been covered by plenty of blog posts, so I won't repeat that here. Instead I'll describe how my own training actually went and the errors that come up most often.
Dataset
The images you get usually look something like this:
When I received the data there were only 60 images in total, which is pitifully few. I tried training directly on those 60 images and it simply didn't work: train_loss only got down to about 0.3 and the predicted images came out solid black, so data augmentation was clearly needed. Since augmenting an original image requires applying the same change to its label, we can put the label's first channel into the original image's third channel to get a new merged image.
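A minimal sketch of this merging step, assuming one image/label pair loaded as grayscale arrays (the paths and file names here are hypothetical):

import numpy as np
from keras.preprocessing.image import load_img, img_to_array, array_to_img

orig = img_to_array(load_img('train/01.png', grayscale=True))   # shape (H, W, 1)
label = img_to_array(load_img('label/01.png', grayscale=True))  # shape (H, W, 1)

# Channel 0 = original, channel 2 = label; channel 1 stays empty.
# Augmenting this one 3-channel image keeps image and label aligned
# under any geometric transform.
merged = np.zeros(orig.shape[:2] + (3,), dtype='float32')
merged[:, :, 0] = orig[:, :, 0]
merged[:, :, 2] = label[:, :, 0]
array_to_img(merged).save('merge/01.png')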
This way the original and the label get modified together during augmentation. Below is the key part of my augmentation code. Note that I normalize the original images here: mean and std are not computed from each batch, but from all the training images. Since the 60 originals were not enough, I wrote a separate script that generated roughly 1,700 images via augmentation, computed the mean and std over those, and saved them as .npy files; both training and prediction need to load them back in. (The self.datagen the code relies on is sketched just below.)
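The snippet that follows calls self.datagen, which isn't shown in the post; here is a plausible configuration. The exact transform settings are my assumption, restricted to geometric transforms so the label channel remains a valid mask:

from keras.preprocessing.image import ImageDataGenerator

# Assumed settings; any geometric transform applied to the merged image
# moves the original and the label identically
self.datagen = ImageDataGenerator(rotation_range=30,
                                  width_shift_range=0.1,
                                  height_shift_range=0.1,
                                  zoom_range=0.1,
                                  horizontal_flip=True,
                                  fill_mode='nearest')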
img = load_img(os.path.join(self.image_path, filename),
               target_size=(self.target_size, self.target_size))
img = img_to_array(img)
x_t = img.reshape((1,) + img.shape)
# Draw one randomly augmented sample of the merged image
flow = self.datagen.flow(x_t, batch_size=1)
batch = next(flow)[0]
# Channel 0 holds the original image, channel 2 holds the label
train = batch[:, :, 0]
mask = batch[:, :, 2]
X[i] = train
y[i] = mask
X = X.astype('float32')
y = y.astype('float32')
# Normalize with the mean/std precomputed over the whole training set,
# not per batch
X /= 255
X -= mean
X /= std
# Binarize the label
y /= 255
y[y >= 0.5] = 1
y[y < 0.5] = 0
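The standalone script mentioned above, which augments the 60 originals into roughly 1,700 samples and saves the global mean and std, might look something like this; the folder names, file names, and the per-image sample count are my guesses:

import os
import numpy as np
from keras.preprocessing.image import (ImageDataGenerator, load_img,
                                       img_to_array)

# Same assumed generator configuration as above
datagen = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                             height_shift_range=0.1, zoom_range=0.1,
                             horizontal_flip=True, fill_mode='nearest')

samples = []
for filename in os.listdir('merge'):               # hypothetical folder of merged images
    img = img_to_array(load_img(os.path.join('merge', filename),
                                target_size=(512, 512)))
    x_t = img.reshape((1,) + img.shape)
    flow = datagen.flow(x_t, batch_size=1)
    for _ in range(30):                            # 60 originals * 30 ≈ 1800 samples
        batch = next(flow)[0]
        samples.append(batch[:, :, 0] / 255.0)     # channel 0 = original image

samples = np.asarray(samples, dtype='float32')
np.save('mean.npy', np.mean(samples))              # loaded again at train/predict time
np.save('std.npy', np.std(samples))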
Model Training
My U-Net is modified from the original model. The original trained to a loss of 0.5 and 97% accuracy; the modified version reaches a loss of 0.38 and 98% accuracy, and extracts vessels a bit better in actual predictions. The modification is essentially removing the dropout layers and adding BN layers, which speeds up convergence and helps control overfitting.
Note that because of this augmentation scheme the input is no longer a 3-channel RGB image but a grayscale one, so when predicting, the image has to be read like this:
img = load_img(img, grayscale=True)
Of course the extracted vessels shouldn't end up gray: the prediction is a (512, 512) map with values in (0, 1), so multiplying it back onto the original image recovers the vessels (see the sketch after the model code below).
# Imports this snippet needs
from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D,
                          BatchNormalization, concatenate)
from keras.optimizers import Adam

inputs = Input((self.img_rows, self.img_cols, 1))

# Encoder: two 3x3 convs per level, 2x2 max pooling, then BN
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
pool1 = BatchNormalization()(pool1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
pool2 = BatchNormalization()(pool2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
pool3 = BatchNormalization()(pool3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
pool4 = BatchNormalization()(pool4)

# Bottleneck
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization()(conv5)

# Decoder: upsample, 2x2 conv, concatenate with the matching encoder feature map
up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
    UpSampling2D(size=(2, 2))(conv5))
merge6 = concatenate([conv4, up6], axis=3)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
conv6 = BatchNormalization()(conv6)
up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
    UpSampling2D(size=(2, 2))(conv6))
merge7 = concatenate([conv3, up7], axis=3)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
conv7 = BatchNormalization()(conv7)
up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
    UpSampling2D(size=(2, 2))(conv7))
merge8 = concatenate([conv2, up8], axis=3)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
conv8 = BatchNormalization()(conv8)
up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
    UpSampling2D(size=(2, 2))(conv8))
merge9 = concatenate([conv1, up9], axis=3)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
conv9 = BatchNormalization()(conv9)
conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

# 1x1 conv with sigmoid gives a per-pixel vessel probability
conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)
model = Model(inputs=inputs, outputs=conv10)
model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
return model
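Putting the prediction steps together as a sketch: the file names are my guesses, and model is assumed to be the compiled U-Net from above with trained weights already loaded.

import numpy as np
from keras.preprocessing.image import load_img, img_to_array

mean = np.load('mean.npy')                     # the statistics saved earlier
std = np.load('std.npy')

# Grayscale load, matching the model's 1-channel input
img = img_to_array(load_img('test/01.png', grayscale=True,
                            target_size=(512, 512)))
x = img.reshape((1, 512, 512, 1)).astype('float32')
x /= 255
x -= mean
x /= std

# pred is a (512, 512) map of vessel probabilities in (0, 1);
# multiplying it back onto the original image recovers the vessels
pred = model.predict(x)[0, :, :, 0]
vessels = pred * img[:, :, 0]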
Finally, here is an actual prediction result; its accuracy is quite high.