欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

keras 多输入多输出网络

程序员文章站 2022-05-26 19:30:09
...

keras中的多输入多输出网络

多输入多输出网络搭建的官网介绍:
http://keras-cn.readthedocs.io/en/latest/getting_started/functional_API/

Demo:

from keras.applications.mobilenet import MobileNet
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.vgg19 import VGG19
from keras.applications.xception import Xception

def generate_model(application, num_class, img_size, pre_weights=None):
    if application == 'InceptionV3':
        base_model = InceptionV3(input_shape=(img_size, img_size, 3),
                                 include_top=False,
                                 weights=pre_weights)
    elif application == 'MobileNet':
        base_model = MobileNet(input_shape=(img_size, img_size, 3),
                               include_top=False,
                               weights=pre_weights)
    elif application == 'VGG19':
        base_model = VGG19(input_shape=(img_size, img_size, 3),
                           weights=pre_weights,
                           include_top=None)
    elif application == 'InceptionResNetV2':
        base_model = InceptionResNetV2(input_shape=(img_size, img_size, 3),
                                       weights=pre_weights,
                                       include_top=None)
    elif application == 'Xception':
        base_model = Xception(input_shape=(img_size, img_size, 3),
                              weights=pre_weights,
                              include_top=None)
    else:
        raise ('No specific aplication type!')

    x = base_model.output
    feature = Flatten(name='feature')(x)
    predictions = Dropout(0.5)(feature)
    #x = GlobalAveragePooling2D()(x)
    #predictions = Dense(1024, activation='relu')(x)
    predictions = Dense(num_class, activation='softmax',
                        name='pred',
                        kernel_initializer=RandomNormal(mean=0.0, stddev=0.001))(predictions)
    model = Model(inputs=base_model.input, outputs=[predictions, feature])
    #Model.summary(model)
    return model

该函数基于keras自带的分类网络,定义了一个单输入双输出的网络
- 输入:(img_size, img_size, 3)的三通道图像
- 输出1:softmax后输出的分类类别,损失函数为多分类交叉熵,输出accuracy
- 输出2:softmax前模型输出的特征向量,损失函数为自定义的Triplet loss