Keras 多输入模型
程序员文章站
2022-05-26 19:17:57
...
def build_model(product_shape: int, level_shape: int, attr_shape: int,
                period_shape: int) -> "keras.Model":
    """Build and compile a four-input Keras functional model.

    The product input goes through its own ``Dense(256, relu)`` +
    ``BatchNormalization`` branch. Its output is concatenated with the raw
    level, attr and period inputs, and the merged tensor passes through three
    ``Dense`` + ``BatchNormalization`` stages (256, 128, 64 units, all ReLU).
    A single linear unit produces the output. The model is compiled with a
    default-configured Adam optimizer and mean-squared-error loss.

    Args:
        product_shape: Length of the product feature vector; used as
            ``shape=(product_shape,)`` for the first input.
        level_shape: Length of the level feature vector.
        attr_shape: Length of the attr feature vector.
        period_shape: Length of the period feature vector.

    Returns:
        The compiled ``keras.Model``. Its inputs are ordered
        ``[product, level, attr, period]``, so data passed to ``fit`` /
        ``predict`` must follow that order.
    """
    # NOTE(review): assumes `keras` is imported at module level; the import is
    # not visible in this excerpt -- confirm.
    product_inputs = keras.Input(shape=(product_shape,))
    level_inputs = keras.Input(shape=(level_shape,))
    attr_inputs = keras.Input(shape=(attr_shape,))
    period_inputs = keras.Input(shape=(period_shape,))

    # Only the product features get a dedicated branch before merging.
    product_dense = keras.layers.Dense(256, activation='relu')(product_inputs)
    product_dense = keras.layers.BatchNormalization()(product_dense)

    # The other three inputs are merged as-is, without a branch of their own.
    hidden = keras.layers.concatenate(
        [product_dense, level_inputs, attr_inputs, period_inputs]
    )

    # Shared trunk: each stage is Dense(relu) followed by BatchNormalization,
    # created in the same order as the original hand-unrolled code.
    for units in (256, 128, 64):
        hidden = keras.layers.Dense(units, activation='relu')(hidden)
        hidden = keras.layers.BatchNormalization()(hidden)

    outputs = keras.layers.Dense(1, activation='linear')(hidden)

    model = keras.Model(
        inputs=[product_inputs, level_inputs, attr_inputs, period_inputs],
        outputs=outputs,
    )

    # Adam with library defaults. The original author also left a disabled
    # alternative: RMSprop with a learning rate of 3e-3 (spelled
    # `learning_rate=3e-3` in current Keras; `lr=` is deprecated).
    opt = keras.optimizers.Adam()
    model.compile(optimizer=opt, loss='mse')
    return model