欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Keras框架使用Vnet2d模型对遥感图像语义分割

程序员文章站 2024-03-15 10:39:35
...

数据集:Keras框架使用Vnet2d模型对遥感图像语义分割

 Keras框架使用Vnet2d模型对遥感图像语义分割

 

 

vnet2d网络如下,只分隔一个目标(大棚),所以是个二分类问题,因为目标区域较大这里我们采用传统的二分类交叉熵binary_crossentropy损失。



import numpy as np 
import os
import skimage.io as io
import skimage.transform as trans
import keras.backend as K
from keras.utils import multi_gpu_model
#from keras import optimizers
#from keras.utils import plot_model#使用plot_mode时打开
from keras.models import Model
from keras.layers import Conv2D,PReLU,Conv2DTranspose,add,concatenate,Input,Dropout,BatchNormalization,Activation
from keras.optimizers import Nadam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler

# from loss import generalized_dice_loss_fun,generalized_dice_coef_fun,dice_coef,dice_loss
#from keras import backend as keras
from keras.models import Model,load_model
from keras.layers import MaxPooling2D,UpSampling2D

from loss import *


def resBlock(conv,stage,keep_prob,stage_num=5):#收缩路径
    
    inputs=conv
    for _ in range(3 if stage>3 else stage):
        conv=PReLU()(BatchNormalization()(Conv2D(16*(2**(stage-1)), 5, activation = None, padding = 'same', kernel_initializer = 'he_normal')(conv)))
        #print('conv_down_stage_%d:' %stage,conv.get_shape().as_list())#输出收缩路径中每个stage内的卷积
    conv_add=PReLU()(add([inputs,conv]))
    #print('conv_add:',conv_add.get_shape().as_list())
    conv_drop=Dropout(keep_prob)(conv_add)
    
    if stage<stage_num:
        conv_downsample=PReLU()(BatchNormalization()(Conv2D(16*(2**stage), 2, strides=(2, 2),activation = None, padding = 'same', kernel_initializer = 'he_normal')(conv_drop)))
        return conv_downsample,conv_add#返回每个stage下采样后的结果,以及在相加之前的结果
    else:
        return conv_add,conv_add#返回相加之后的结果,为了和上面输出保持一致,所以重复输出
        
def up_resBlock(forward_conv,input_conv,stage):#扩展路径
    
    conv=concatenate([forward_conv,input_conv],axis = -1)
    print('conv_concatenate:',conv.get_shape().as_list())
    for _ in range(3 if stage>3 else stage):
        conv=PReLU()(BatchNormalization()(Conv2D(16*(2**(stage-1)), 5, activation = None, padding = 'same', kernel_initializer = 'he_normal')(conv)))
        print('conv_up_stage_%d:' %stage,conv.get_shape().as_list())#输出扩展路径中每个stage内的卷积
    conv_add=PReLU()(add([input_conv,conv]))
    if stage>1:
        conv_upsample=PReLU()(BatchNormalization()(Conv2DTranspose(16*(2**(stage-2)),2,strides=(2, 2),padding='valid',activation = None,kernel_initializer = 'he_normal')(conv_add)))
        return conv_upsample
    else:
        return conv_add

def vnet(pretrained_weights = None,input_size = (320,320,3),num_class=1,is_training=True,stage_num=6,thresh=0.5):#二分类时num_classes设置成1,不是2,stage_num可自行改变,也即可自行改变网络深度
    keep_prob = 1.0 if is_training else 1.0#不使用dropout
    features=[]
    input_model = Input(input_size)
    x=PReLU()(BatchNormalization()(Conv2D(16, 5, activation = None, padding = 'same', kernel_initializer = 'he_normal')(input_model)))
    
    for s in range(1,stage_num+1):
        x,feature=resBlock(x,s,keep_prob,stage_num)#调用收缩路径
        features.append(feature)
        
    conv_up=PReLU()(BatchNormalization()(Conv2DTranspose(16*(2**(s-2)),2,strides=(2, 2),padding='valid',activation = None,kernel_initializer = 'he_normal')(x)))
    
    for d in range(stage_num-1,0,-1):
        conv_up=up_resBlock(features[d-1],conv_up,d)#调用扩展路径
    if num_class>1:
        conv_out=Conv2D(num_class, 1, activation = 'softmax', padding = 'same', kernel_initializer = 'he_normal')(conv_up)
    else:
        conv_out=Conv2D(num_class, 1, activation = 'sigmoid', padding = 'same', kernel_initializer = 'he_normal')(conv_up)
     
    
    model=Model(inputs=input_model,outputs=conv_out)
    print(model.output_shape)
    
    model_dice=dice_loss(smooth=1e-5)
    # model_dice=generalized_dice_loss_fun(smooth=1e-5)

    model.compile(optimizer = Nadam(lr = 2e-4), loss = "binary_crossentropy", metrics = ['accuracy'])
    # model.compile(optimizer = Nadam(lr = 2e-4), loss = model_dice, metrics = ['accuracy'])

    #不使用metric
    # model.compile(optimizer = Nadam(lr = 2e-4), loss = model_dice)
    #plot_model(model, to_file='model.png')
    if(pretrained_weights):
    	model.load_weights(pretrained_weights)
    return model

loss与精准度曲线:

Keras框架使用Vnet2d模型对遥感图像语义分割

 

Keras框架使用Vnet2d模型对遥感图像语义分割

 

效果测试:

Keras框架使用Vnet2d模型对遥感图像语义分割

Keras框架使用Vnet2d模型对遥感图像语义分割

相关标签: 语义分割