keras 实现多任务学习
程序员文章站
2022-05-27 09:45:23
...
def deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim):
inputs = Input(shape=(feature_dim,))
dense_1 = Dense(512, activation='relu')(inputs)
dense_2 = Dense(384, activation='relu')(dense_1)
dense_3 = Dense(256, activation='relu')(dense_2)
drop_1 = Dropout(0.2)(dense_3)
dense_4 = Dense(128, activation='relu')(drop_1)
dense_5 = Dense(64, activation='relu')(dense_4)
output_1 = Dense(32, activation='relu')(dense_5)
output_cvr = Dense(cvr_label_dim, activation='softmax', name='output_cvr')(output_1)
output_2 = Dense(16, activation='relu')(dense_5)
output_profit = Dense(profit_label_dim, activation='softmax', name='output_profit')(output_2)
# 模型有两个输出 output_cvr, output_profit
model = Model(inputs=inputs, outputs=[output_cvr, output_profit])
model.summary()
# 模型有两个 loss, 都是 categorical_crossentropy
# loss 的 key 需要和模型的 output 层的 name 保持一致
model.compile(optimizer='adam',
loss={'output_cvr': 'categorical_crossentropy', 'output_profit': 'categorical_crossentropy'},
loss_weights={'output_cvr':1, 'output_profit': 0.3},
metrics=[categorical_accuracy])
return model
# 产生训练数据的生成器
# 模型只有一个 input 有两个 output,所以 yield 格式为如下
def generate_arrays(X_train, y_train_cvr_label, y_train_profit_label):
while True:
for x, y_cvr, y_profit in zip(X_train, y_train_cvr_label, y_train_profit_label):
yield (x[np.newaxis, :], {'output_cvr': y_cvr[np.newaxis, :], 'output_profit': y_profit[np.newaxis, :]})
# fit_generator 进行 fit 训练
def train_multi(X_train, y_train_cvr_label, y_train_profit_label, X_test, y_test_cvr_label, y_test_profit_label):
feature_dim = X_train.shape[1]
cvr_label_dim = y_train_cvr_label.shape[1]
profit_label_dim = y_train_profit_label.shape[1]
model = deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim)
model.summary()
early_stopping = EarlyStopping(monitor='val_loss', patience=15, verbose=0)
model.fit_generator(generate_arrays(X_train, y_train_cvr_label, y_train_profit_label),
steps_per_epoch=1024,
epochs=100,
validation_data=generate_arrays(X_test, y_test_cvr_label, y_test_profit_label),
validation_steps=1024,
callbacks=[early_stopping])
return model
上一篇: php给$
下一篇: pytorch: 自定义损失函数Loss
推荐阅读
-
JavaScript学习笔记整理_简单实现枚举类型,扑克牌应用
-
【莫烦强化学习】视频笔记(二)3.Q_Learning算法实现走迷宫
-
4.18学习笔记 三级联动(对象实现)
-
OpenCV学习笔记(18)双目测距与三维重建的OpenCV实现问题集锦(三)立体匹配与视差计算
-
[机器学习与深度学习] - No.2 遗传算法原理及简单实现
-
OpenGL学习笔记三(引入GLM库,实现transform)
-
Cocos2d游戏开发学习记录——1.Surface、SurfaceView、SurfaceHolder实现简单的游戏demo
-
openCV学习笔记(五):滤波的实现
-
Javascript中从学习bind到实现bind的过程详解
-
MySQL学习笔记之数据的增、删、改实现方法_MySQL