Keras 多任务学习（multi-task learning）
程序员文章站
2022-05-27 09:47:17
...
base_model 选用的是 ResNet50
# Path to a local ResNet50 weights file for the headless ("notop") network.
# When `weights` is given a file path, Keras loads it and expects the file to
# exist at this location.
resnet_weights_path = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

# Shared backbone: ResNet50 without its classification head. pooling='avg'
# collapses the final feature map into one feature vector per image.
my_new_model = Sequential()
# Load the pretrained weights from resnet_weights_path; weights=None would
# leave the path unused and the backbone randomly initialised.
my_new_model.add(ResNet50(include_top=False, pooling='avg',
                          weights=resnet_weights_path))
my_new_model.add(Activation('relu'))
my_new_model.add(Dropout(0.5))
my_new_model.summary()

# Functional-API wrapper: one image input feeds the shared backbone, whose
# feature vector is consumed by two independent softmax heads.
img1 = Input(shape=(224, 224, 3), name='img')
feature = my_new_model(img1)
# Head 1: 11-way category classifier.
category = Dense(11, activation='softmax', name='category_out1')(feature)
# Head 2: 3-way age classifier.
age = Dense(3, activation='softmax', name='age_out2')(feature)
model = Model(inputs=[img1], outputs=[category, age])

# Each head gets its own categorical cross-entropy loss, keyed by the output
# layer name; with both loss weights at 1.0 the heads contribute equally to
# the total training loss.
model.compile(optimizer='sgd',
              loss={
                  'category_out1': 'categorical_crossentropy',
                  'age_out2': 'categorical_crossentropy'
              },
              loss_weights={
                  'category_out1': 1.,
                  'age_out2': 1.
              },
              metrics=['accuracy'])
数据生成器（generator）
import keras
import numpy as np
import cv2
class MultiTaskGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding image batches with two one-hot targets.

    ``__getitem__`` returns ``(X, [y_category, y_age])``:

    * ``X`` has shape ``(batch_size, *dim)``, with pixel values divided by 255.
    * ``y_category`` is one-hot over ``n_classes`` categories, with names mapped
      through ``cls2id``.
    * ``y_age`` is one-hot over the age groups in ``age2id``.

    Only full batches are produced, so the last ``len(img_files) % batch_size``
    samples of each epoch's ordering are not used.

    Args:
        img_files: Image file paths, read with ``cv2.imread``.
        labels: Category name per image; each must be a key of ``cls2id``.
        age_labels: Age-group name per image; each must be a key of ``age2id``.
        batch_size: Number of samples per batch.
        n_classes: Width of the category one-hot vectors.
        shuffle: Reshuffle the sample order whenever ``on_epoch_end`` runs.
        dim: Target ``(height, width, channels)`` of each image.

    Raises:
        ValueError: If ``img_files``, ``labels`` and ``age_labels`` differ in
            length.
    """

    def __init__(self, img_files=None, labels=None, age_labels=None,
                 batch_size=32, n_classes=11, shuffle=True, dim=(224, 224, 3)):
        cls_dict = {'bixiong': 0, 'chaiquan': 1, 'demu': 2, 'fadou': 3,
                    'guibing': 4, 'jinmao': 5, 'jiwawa': 6, 'keka': 7,
                    'labu': 8, 'xuenarui': 9, 'yueke': 10}
        # Indices are drawn from one range and used to index all three
        # sequences, so they must line up one-to-one.
        if not len(img_files) == len(labels) == len(age_labels):
            raise ValueError(
                'img_files, labels and age_labels must have the same length, '
                'got {}, {} and {}'.format(
                    len(img_files), len(labels), len(age_labels)))
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.age_labels = age_labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.cls2id = cls_dict
        self.age2id = {'old': 0, 'small': 1, 'teen': 2}
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.img_files) // self.batch_size

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set.

        Called once from ``__init__``; Keras also invokes it at the end of
        every training epoch.
        """
        self.indexes = np.arange(len(self.img_files))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, [y_category, y_age])``."""
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]
        X, y, age = self.__data_generation(indexes)
        return X, [y, age]

    def __data_generation(self, list_IDs_temp):
        """Load the images for ``list_IDs_temp`` and one-hot encode their labels.

        Returns:
            ``(X, y_category, y_age)`` for one batch.

        Raises:
            OSError: If OpenCV cannot read one of the image files.
        """
        height, width = self.dim[0], self.dim[1]
        X = np.empty((self.batch_size, *self.dim))
        y = np.empty(self.batch_size, dtype=int)
        age = np.empty(self.batch_size, dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            path = self.img_files[ID]
            img = cv2.imread(path)
            # cv2.imread reports failure by returning None instead of raising.
            if img is None:
                raise OSError('could not read image: {}'.format(path))
            # cv2.resize takes dsize as (width, height), the reverse of dim.
            img = cv2.resize(img.astype('float'), (width, height))
            X[i, ] = img / 255
            y[i] = self.cls2id[self.labels[ID]]
            age[i] = self.age2id[self.age_labels[ID]]
        return (X,
                keras.utils.to_categorical(y, num_classes=self.n_classes),
                keras.utils.to_categorical(age, num_classes=len(self.age2id)))
训练过程
# Checkpoint file names embed the epoch number and the validation accuracy of
# the age head (output layer 'age_out2').
filepath = "multi_task-{epoch:02d}-{val_age_out2_acc:.2f}.hdf5"

# TensorBoard logging to ./logs: graph only, no histograms or weight images.
tensorboard = TensorBoard(
    log_dir='./logs',
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)

# Write a checkpoint only when the monitored validation age accuracy improves.
# NOTE(review): 'val_age_out2_acc' assumes a Keras version that logs
# metrics=['accuracy'] under the suffix "acc"; newer Keras logs "accuracy"
# (i.e. 'val_age_out2_accuracy') — confirm against the installed version.
checkpoint = ModelCheckpoint(
    filepath,
    monitor='val_age_out2_acc',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# 40 epochs of 40 training batches each, validating on 4 batches per epoch.
model.fit_generator(
    generator=train_gen,
    steps_per_epoch=40,
    validation_data=valid_gen,
    validation_steps=4,
    epochs=40,
    callbacks=[checkpoint, tensorboard],
)
上一篇: jupyter notebook连接远程服务器配置
下一篇: 求子序列和的最大值