欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

keras多任务学习multi-task learning

程序员文章站 2022-05-27 09:47:17
...

base_model 选择的是 ResNet50，在其输出特征上接两个 softmax 分支：类别（11 类）和年龄（3 类）。

# Local weight file for the ResNet50 backbone without its classification head
# (presumably the standard Keras ImageNet "notop" file -- confirm it exists
# next to this script before running).
resnet_weights_path='resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

# Shared feature extractor: ResNet50 with global average pooling, followed by
# ReLU and dropout. Both task heads consume this single feature vector.
my_new_model = Sequential()
# Initialise the backbone from the weight file declared above instead of
# random weights, so the declared path is actually used.
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))
my_new_model.add(Activation('relu'))
my_new_model.add(Dropout(0.5))
my_new_model.summary()

# One image input feeding two softmax heads: an 11-way category head and a
# 3-way age head.
img1 = Input(shape=(224, 224, 3), name='img')
feature = my_new_model(img1)
category = Dense(11, activation='softmax',name='category_out1')(feature)
age = Dense(3,activation='softmax',name='age_out2')(feature)
model = Model(inputs=[img1], outputs=[category, age])

# Losses and loss weights are keyed by the output layer names; both tasks
# contribute equally to the total loss.
model.compile(optimizer='sgd',
              loss={
                  'category_out1': 'categorical_crossentropy',
                  'age_out2': 'categorical_crossentropy'
              },
              loss_weights={
                  'category_out1': 1.,
                  'age_out2': 1.
              },
              metrics=['accuracy'])

数据生成器（generator）

import keras
import numpy as np
import cv2


class MultiTaskGenerator(keras.utils.Sequence):
    """Batch generator yielding images with two one-hot targets (category, age).

    Each item is ``(X, [y, age])`` where ``X`` holds the resized images scaled
    to [0, 1], ``y`` is the one-hot category target and ``age`` is the one-hot
    age target.

    Args:
        img_files: sequence of image paths readable by ``cv2.imread``.
        labels: sequence of category names (keys of ``cls2id``), parallel to
            ``img_files``.
        age_labels: sequence of age names (keys of ``age2id``), parallel to
            ``img_files``.
        batch_size: number of samples per batch.
        n_classes: width of the one-hot category target.
        shuffle: whether to reshuffle the sample order at every epoch end.
        dim: ``(height, width, channels)`` of each image placed in the batch.
    """

    def __init__(self,img_files=None, labels=None,age_labels=None,batch_size=32,n_classes=11,shuffle=True,dim=(224,224,3)):
        # Newer Keras versions require the base-class initialiser to run;
        # on older versions this is a harmless no-op.
        super().__init__()
        cls_dict = {'bixiong': 0, 'chaiquan': 1, 'demu': 2, 'fadou': 3, 'guibing': 4, 'jinmao': 5, 'jiwawa': 6, 'keka': 7, 'labu': 8, 'xuenarui': 9, 'yueke': 10}
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.age_labels = age_labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.cls2id=cls_dict
        self.age2id={'old':0,'small':1,'teen':2}
        # Build the initial (optionally shuffled) index order.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (remainder is dropped)."""
        return len(self.img_files) // self.batch_size

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        # NOTE: img_files, labels and age_labels are assumed to have the same
        # length; the index range is taken from ``labels``.
        self.indexes = np.arange(len(self.labels))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self,index):
        """Return batch number ``index`` as ``(X, [category, age])``."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        X, y,age=self.__data_generation(indexes)
        return X,[y,age]

    def __data_generation(self,list_IDs_temp):
        """Load, resize and scale the images for the given sample indices.

        Returns:
            ``(X, y_onehot, age_onehot)`` with one row per index.

        Raises:
            OSError: if ``cv2.imread`` cannot read one of the image files.
        """
        # Size the arrays by the indices actually requested so a short slice
        # never leaves uninitialised rows in the batch.
        n_samples = len(list_IDs_temp)
        height, width = self.dim[0], self.dim[1]
        X = np.empty((n_samples,*self.dim))
        y = np.empty((n_samples),dtype=int)
        age = np.empty((n_samples),dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            path = self.img_files[ID]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError("cv2.imread could not read image: %r" % (path,))
            # cv2.resize takes the target size as (width, height).
            img = cv2.resize(img.astype('float'),(width,height))
            X[i,] = img/255
            y[i] = self.cls2id[self.labels[ID]]
            age[i] = self.age2id[self.age_labels[ID]]

        return (X,
                keras.utils.to_categorical(y,num_classes=self.n_classes),
                keras.utils.to_categorical(age,num_classes=len(self.age2id)))
    

训练过程

# Checkpoint file name template; the placeholders are filled in by
# ModelCheckpoint with the epoch number and the monitored metric.
filepath="multi_task-{epoch:02d}-{val_age_out2_acc:.2f}.hdf5"

# Log training to ./logs for TensorBoard (graph only, no histograms/images).
tensorboard = TensorBoard(
    log_dir='./logs',
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)

# Keep only the checkpoint with the highest 'val_age_out2_acc' seen so far.
checkpoint = ModelCheckpoint(
    filepath,
    monitor='val_age_out2_acc',
    verbose=1,
    save_best_only=True,
    mode='max',
)

# 40 epochs of 40 training steps each, validating on 4 batches per epoch.
model.fit_generator(
    generator=train_gen,
    steps_per_epoch=40,
    epochs=40,
    validation_data=valid_gen,
    validation_steps=4,
    callbacks=[checkpoint, tensorboard],
)

 

相关标签: multi-task learning