Keras multi-input model

程序员文章站 2022-05-26 19:17:57
...
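from tensorflow import keras  # TensorFlow's bundled Keras; with Keras 3, `import keras` also works


def build_model(product_shape, level_shape, attr_shape, period_shape):
    """Four-input regression network; each *_shape is the width of a flat feature vector."""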
    # One Input per feature group, each a 1-D vector per sample.
    product_inputs = keras.Input(shape=(product_shape,))
    level_inputs = keras.Input(shape=(level_shape,))
    attr_inputs = keras.Input(shape=(attr_shape,))
    period_inputs = keras.Input(shape=(period_shape,))

    # Only the product features get their own hidden layer before merging.
    product_dense = keras.layers.Dense(256, activation='relu')(product_inputs)
    product_dense = keras.layers.BatchNormalization()(product_dense)

    # Merge the processed product branch with the raw level/attr/period features.
    laptop_inputs = keras.layers.concatenate([product_dense, level_inputs, attr_inputs, period_inputs])

    # Shared MLP: 256 -> 128 -> 64 units, each Dense followed by BatchNormalization.
    laptop_dense = keras.layers.Dense(256, activation='relu')(laptop_inputs)
    laptop_dense = keras.layers.BatchNormalization()(laptop_dense)
    laptop_dense = keras.layers.Dense(128, activation='relu')(laptop_dense)
    laptop_dense = keras.layers.BatchNormalization()(laptop_dense)
    laptop_dense = keras.layers.Dense(64, activation='relu')(laptop_dense)
    laptop_dense = keras.layers.BatchNormalization()(laptop_dense)
    # Single linear unit: the model predicts one continuous value.
    outputs = keras.layers.Dense(1, activation='linear')(laptop_dense)
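
    # The order of `inputs` is the order in which arrays must be passed to fit()/predict().
    model = keras.Model(inputs=[product_inputs, level_inputs, attr_inputs, period_inputs], outputs=outputs)
    opt = keras.optimizers.Adam()
    # opt = keras.optimizers.RMSprop(learning_rate=3e-3)  # older releases accepted `lr`; current ones expect `learning_rate`
    model.compile(optimizer=opt, loss='mse')  # mean squared error for regression

    return model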
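As a rough check of the architecture above, the number of weights can be worked out by hand and compared against the totals that `model.summary()` reports. The sketch below is plain Python (no Keras needed) and assumes the default layer configurations: a Dense layer mapping `n_in` inputs to `n_out` units holds `n_in * n_out` weights plus `n_out` biases, and a BatchNormalization layer over `n` features holds `gamma` and `beta` (trainable) plus a moving mean and moving variance (non-trainable), i.e. `4 * n` values. The helper names `dense_params`, `bn_params` and `count_params` and the feature widths 50, 5, 10 and 12 are hypothetical, chosen only for illustration.

```python
def dense_params(n_in, n_out):
    """Kernel of shape (n_in, n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


def bn_params(n):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * n


def count_params(product_shape, level_shape, attr_shape, period_shape):
    """Return (total, non_trainable) parameter counts for the build_model layout."""
    total = 0
    non_trainable = 0

    # Product branch: Dense(256) followed by BatchNormalization.
    total += dense_params(product_shape, 256) + bn_params(256)
    non_trainable += 2 * 256

    # Concatenation has no weights; it only widens the vector fed to the next layer.
    width = 256 + level_shape + attr_shape + period_shape

    # Shared stack: Dense(256) -> BN -> Dense(128) -> BN -> Dense(64) -> BN.
    for units in (256, 128, 64):
        total += dense_params(width, units) + bn_params(units)
        non_trainable += 2 * units
        width = units

    # Linear regression head: Dense(1), no BatchNormalization after it.
    total += dense_params(width, 1)
    return total, non_trainable


# Hypothetical feature widths, for illustration only.
total, frozen = count_params(50, 5, 10, 12)
print(total, frozen, total - frozen)
```

With these made-up widths it prints `129793 1408 128385` (total, non-trainable, trainable). Most of the weights sit in the first shared Dense layer, whose input width is `256 + level_shape + attr_shape + period_shape` because the concatenation simply stacks the branch outputs side by side.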