欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

用 Keras 实现多任务学习

程序员文章站 2022-05-27 09:45:23
...
def deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim, *,
                     dropout_rate=0.2, cvr_loss_weight=1.0, profit_loss_weight=0.3):
    """Build and compile a shared-trunk, two-headed Keras model.

    A stack of fully connected ReLU layers (512-384-256-dropout-128-64) is
    shared by both tasks; each task then gets its own small ReLU layer and a
    softmax output layer.

    Args:
        feature_dim: Number of input features per sample.
        cvr_label_dim: Number of classes of the ``output_cvr`` head.
        profit_label_dim: Number of classes of the ``output_profit`` head.
        dropout_rate: Dropout rate applied after the 256-unit shared layer.
        cvr_loss_weight: Weight of the ``output_cvr`` loss in the total loss.
        profit_loss_weight: Weight of the ``output_profit`` loss in the total loss.

    Returns:
        A compiled Keras ``Model`` whose outputs are ``[output_cvr, output_profit]``.
    """
    inputs = Input(shape=(feature_dim,))

    # Shared trunk: both task heads read from dense_5.
    dense_1 = Dense(512, activation='relu')(inputs)
    dense_2 = Dense(384, activation='relu')(dense_1)
    dense_3 = Dense(256, activation='relu')(dense_2)
    drop_1 = Dropout(dropout_rate)(dense_3)
    dense_4 = Dense(128, activation='relu')(drop_1)
    dense_5 = Dense(64, activation='relu')(dense_4)

    # Task-specific heads.
    output_1 = Dense(32, activation='relu')(dense_5)
    output_cvr = Dense(cvr_label_dim, activation='softmax', name='output_cvr')(output_1)

    output_2 = Dense(16, activation='relu')(dense_5)
    output_profit = Dense(profit_label_dim, activation='softmax', name='output_profit')(output_2)

    # The model has two outputs: output_cvr and output_profit.
    model = Model(inputs=inputs, outputs=[output_cvr, output_profit])
    model.summary()

    # One categorical cross-entropy loss per output. The dict keys must match
    # the `name=` given to the output layers above, or Keras cannot pair them.
    model.compile(
        optimizer='adam',
        loss={
            'output_cvr': 'categorical_crossentropy',
            'output_profit': 'categorical_crossentropy',
        },
        loss_weights={
            'output_cvr': cvr_loss_weight,
            'output_profit': profit_loss_weight,
        },
        metrics=[categorical_accuracy],
    )

    return model


# Generator that yields training batches for Keras.
# The model has one input and two outputs, so each item is
# (inputs, {'output_cvr': ..., 'output_profit': ...}).
def generate_arrays(X_train, y_train_cvr_label, y_train_profit_label, batch_size=1):
    """Yield ``(x_batch, {output_name: y_batch})`` tuples forever, cycling over the data.

    The targets dict is keyed by the model's output-layer names
    (``output_cvr`` / ``output_profit``). Batches are taken in order; the final
    batch of each pass may be smaller than ``batch_size``. With the default
    ``batch_size=1`` every batch has a leading axis of length 1.

    Args:
        X_train: Feature array; the first axis indexes samples.
        y_train_cvr_label: Targets for ``output_cvr``, same number of samples.
        y_train_profit_label: Targets for ``output_profit``, same number of samples.
        batch_size: Number of samples per yielded batch (must be >= 1).

    Raises:
        ValueError: (on the first ``next()``) if there are no samples, if the
            three inputs disagree on the number of samples, or if
            ``batch_size < 1``. Without the empty check the infinite loop
            would spin forever without ever yielding.
    """
    features = np.asarray(X_train)
    cvr_labels = np.asarray(y_train_cvr_label)
    profit_labels = np.asarray(y_train_profit_label)

    n_samples = len(features)
    if n_samples == 0:
        raise ValueError("generate_arrays: X_train contains no samples")
    if len(cvr_labels) != n_samples or len(profit_labels) != n_samples:
        # zip() would silently truncate to the shortest input and misalign nothing
        # visibly, so make the mismatch loud instead.
        raise ValueError(
            "generate_arrays: X_train, y_train_cvr_label and y_train_profit_label "
            "must have the same number of samples (got %d, %d, %d)"
            % (n_samples, len(cvr_labels), len(profit_labels))
        )
    if batch_size < 1:
        raise ValueError("generate_arrays: batch_size must be >= 1, got %r" % (batch_size,))

    while True:
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            # Slicing keeps the leading batch axis, so batch_size=1 yields
            # shape (1, ...), the same as x[np.newaxis, :].
            yield (features[start:stop],
                   {'output_cvr': cvr_labels[start:stop],
                    'output_profit': profit_labels[start:stop]})


# Train the multi-task model with early stopping on the validation loss.
def train_multi(X_train, y_train_cvr_label, y_train_profit_label,
                X_test, y_test_cvr_label, y_test_profit_label,
                batch_size=32, epochs=100, patience=15):
    """Build the two-headed model via ``deep_multi_model`` and fit it.

    Training uses ``model.fit`` directly on the in-memory arrays, with targets
    keyed by output-layer name. Each epoch is therefore one full (shuffled)
    pass over the training set, and validation runs on the whole test set.

    This replaces the earlier ``fit_generator`` call, which had two problems:
    - ``steps_per_epoch`` / ``validation_steps`` were hard-coded to 1024
      batches of size 1, so epochs did not match the dataset size and
      ``val_loss`` came from a rotating sample window.
    - ``fit_generator`` is deprecated in TF2 and removed in Keras 3.

    Args:
        X_train, X_test: 2-D feature arrays; ``X_train.shape[1]`` is the input width.
        y_train_cvr_label, y_test_cvr_label: Targets for ``output_cvr``;
            ``shape[1]`` is the number of classes.
        y_train_profit_label, y_test_profit_label: Targets for ``output_profit``;
            ``shape[1]`` is the number of classes.
        batch_size: Samples per gradient update.
        epochs: Maximum number of passes over the training set.
        patience: Epochs without ``val_loss`` improvement before stopping.

    Returns:
        The trained Keras model.
    """
    feature_dim = X_train.shape[1]
    cvr_label_dim = y_train_cvr_label.shape[1]
    profit_label_dim = y_train_profit_label.shape[1]

    # deep_multi_model already prints model.summary(); no second call here.
    model = deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim)

    early_stopping = EarlyStopping(monitor='val_loss', patience=patience, verbose=0)

    # Target dicts are keyed by output-layer name, matching the loss dict used in compile().
    train_targets = {'output_cvr': y_train_cvr_label, 'output_profit': y_train_profit_label}
    val_targets = {'output_cvr': y_test_cvr_label, 'output_profit': y_test_profit_label}

    model.fit(X_train, train_targets,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(X_test, val_targets),
              callbacks=[early_stopping])

    return model