Keras multi-input, multi-output networks
程序员文章站
2022-05-26 19:30:09
...
Multi-input, multi-output networks in Keras
The official guide to building multi-input, multi-output networks (the functional API):
http://keras-cn.readthedocs.io/en/latest/getting_started/functional_API/
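The demo below has one input and two outputs. To illustrate the general multi-input, multi-output pattern from the functional API guide, here is a minimal sketch in the same style. The layer names, sizes and the 3-class/regression split are hypothetical choices for illustration, not taken from the original post. Branches are built per input, merged with `concatenate`, and each output head gets its own loss.

```python
# Minimal sketch of a two-input, two-output model built with the Keras functional API.
# All names and sizes here are illustrative assumptions.
import numpy as np
from keras.layers import Input, Dense, concatenate
from keras.models import Model

# Two separate inputs (e.g. a main feature vector and some auxiliary metadata)
main_input = Input(shape=(32,), name='main_input')
aux_input = Input(shape=(8,), name='aux_input')

# Each input gets its own branch; the branches are then merged
x1 = Dense(16, activation='relu')(main_input)
x2 = Dense(4, activation='relu')(aux_input)
merged = concatenate([x1, x2])

# Two heads: a 3-class classifier and a single regression value
class_out = Dense(3, activation='softmax', name='class_out')(merged)
score_out = Dense(1, name='score_out')(merged)

model = Model(inputs=[main_input, aux_input], outputs=[class_out, score_out])

# One loss (and loss weight) per output, listed in the same order as `outputs=[...]`
model.compile(optimizer='adam',
              loss=['categorical_crossentropy', 'mse'],
              loss_weights=[1.0, 0.5])

# Inputs are passed as a list in the same order as `inputs=[...]`
batch = 5
class_pred, score_pred = model.predict(
    [np.random.rand(batch, 32), np.random.rand(batch, 8)], verbose=0)
print(class_pred.shape, score_pred.shape)  # (5, 3) (5, 1)
```

Training works the same way: `model.fit` takes a list of input arrays and a list of target arrays, one per output.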
Demo:
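```python
# Backbone networks bundled with Keras (keras.applications)
from keras.applications.mobilenet import MobileNet
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.vgg19 import VGG19
from keras.applications.xception import Xception
# Layers, initializer and Model class used to build the output heads
from keras.layers import Dense, Dropout, Flatten
from keras.initializers import RandomNormal
from keras.models import Model


def generate_model(application, num_class, img_size, pre_weights=None):
    """Build a single-input, two-output model on top of a Keras backbone.

    pre_weights is forwarded to the backbone's `weights` argument
    (None = random initialization, 'imagenet' = pretrained weights).
    """
    input_shape = (img_size, img_size, 3)
    # include_top=False drops the backbone's own classifier so we can add ours
    if application == 'InceptionV3':
        base_model = InceptionV3(input_shape=input_shape, include_top=False, weights=pre_weights)
    elif application == 'MobileNet':
        base_model = MobileNet(input_shape=input_shape, include_top=False, weights=pre_weights)
    elif application == 'VGG19':
        base_model = VGG19(input_shape=input_shape, include_top=False, weights=pre_weights)
    elif application == 'InceptionResNetV2':
        base_model = InceptionResNetV2(input_shape=input_shape, include_top=False, weights=pre_weights)
    elif application == 'Xception':
        base_model = Xception(input_shape=input_shape, include_top=False, weights=pre_weights)
    else:
        # Raising a plain string is an error in Python 3; raise an exception instance
        raise ValueError('Unknown application type: %s' % application)

    x = base_model.output
    # Output 2: flattened backbone features, taken before the classification layer
    feature = Flatten(name='feature')(x)
    predictions = Dropout(0.5)(feature)
    # x = GlobalAveragePooling2D()(x)
    # predictions = Dense(1024, activation='relu')(x)
    # Output 1: class probabilities from a softmax classification layer
    predictions = Dense(num_class, activation='softmax',
                        name='pred',
                        kernel_initializer=RandomNormal(mean=0.0, stddev=0.001))(predictions)
    model = Model(inputs=base_model.input, outputs=[predictions, feature])
    # model.summary()
    return model
```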
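This function builds on the classification networks that ship with Keras and defines a single-input, two-output network:
- Input: a three-channel image of shape (img_size, img_size, 3)
- Output 1 ('pred'): the class probabilities produced by the softmax layer; trained with categorical cross-entropy, with accuracy as the reported metric
- Output 2 ('feature'): the feature vector the model produces before the softmax classification layer (the flattened backbone output); trained with a custom triplet loss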
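The post does not show its custom triplet loss. As a reference for what such a loss computes on the 'feature' output, here is a NumPy sketch of the standard formulation, loss = mean(max(||a - p||² - ||a - n||² + margin, 0)). Two things are assumptions, not from the original: the batch is laid out as consecutive (anchor, positive, negative) triples, and the margin is 0.2. The helper `split_triplets` is hypothetical.

```python
import numpy as np


def split_triplets(features):
    """Hypothetical helper: assumes rows are ordered anchor, positive, negative, anchor, ..."""
    return features[0::3], features[1::3], features[2::3]


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Mean triplet loss over a batch of feature vectors (margin=0.2 is an assumed default)."""
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)  # squared L2 distance anchor-positive
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)  # squared L2 distance anchor-negative
    # Penalize triplets where the negative is not at least `margin` farther away than the positive
    return float(np.mean(np.maximum(pos_dist - neg_dist + margin, 0.0)))


# Two triplets: the first already satisfies the margin (loss 0), the second does not (loss 0.2)
features = np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 3.0],
                     [0.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
loss = triplet_loss(*split_triplets(features))
print(loss)  # 0.1
```

In Keras, a loss like this would be rewritten with tensor operations instead of NumPy and attached per output at compile time, for example `model.compile(loss={'pred': 'categorical_crossentropy', 'feature': triplet_loss_fn}, metrics={'pred': 'accuracy'})`, where the dict keys are the output layer names and `triplet_loss_fn` is a placeholder for your own `(y_true, y_pred)` implementation.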