
MXNet finetune example (finetuning only selected layers)

The script below loads a pre-trained model (vgg16 or resnet-152), replaces its classifier with a new fully connected layer, and freezes every pre-trained layer by zeroing its learning-rate multiplier, so that only the new layer is trained.


import logging
import mxnet as mx
import numpy as np
import os.path, time, sys
from mAP_metric import mAP  # custom mAP metric module shipped alongside this script

print("\n****** File updated %ds ago ******" % (time.time() - os.path.getmtime(sys.argv[0])))  # file update check

 

# Data iterators: build training and validation iterators from .rec files
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,    # random cropping for augmentation
        rand_mirror=True,  # random horizontal flips for augmentation
        max_rotate_angle=0)
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train, val
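# Optional sanity check (not part of the original post): build the iterators once
# and print the batch layout they will feed to the module, e.g.
#   train, val = get_iterators(32, rec_train, rec_val, lst_train)
#   print(train.provide_data, train.provide_label)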

 

# Load the pre-trained model and swap in a new classifier
def get_fine_tune_model(model_name):
    # load the pre-trained checkpoint (epoch 0)
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/" + model_name, 0)
    # cut the network at its last feature layer
    all_layers = symbol.get_internals()
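    # To find the right layer name to cut at (names differ per network), the
    # internal outputs can be listed, e.g.:
    # print(all_layers.list_outputs()[-10:])  # the last few entries are the classifier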

    if model_name == "vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the old classifier's weights so the new fc layer is initialized from scratch
    new_args = dict({k: arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args, aux_params)

 

# Model training
def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_available):
    devs = [mx.gpu(i) for i in gpu_available]
    model = mx.mod.Module(symbol=symbol, context=devs)
    # metrics: accuracy plus a custom mAP (remove the mAP line if unnecessary)
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str))
    # optimizer: freeze all pre-trained layers by setting their lr multipliers to 0;
    # the new fc layer is absent from arg_params, so it keeps the default multiplier of 1
    sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
    finetune_lr = dict({k: 0 for k in arg_params})
    sgd.set_lr_mult(finetune_lr)
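    # To unfreeze a few more layers rather than only the new fc layer, give their
    # parameters a non-zero multiplier before calling set_lr_mult; the 'conv5'
    # prefix below is hypothetical, check arg_params.keys() for the real names:
    # for k in finetune_lr:
    #     if k.startswith('conv5'):
    #         finetune_lr[k] = 0.1
    # sgd.set_lr_mult(finetune_lr)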

    # training; eval_metric='acc' is used during fit, while the composite metric
    # (accuracy + mAP) is applied once in the final score() call below
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer=sgd,
        optimizer_params={'learning_rate': 0.01},
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    return model.score(iter_val, com_metric)

 

#=======================================================================================================================
# Logger: print messages both to the screen and to a file
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s', filename='acc_record.log', filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
logging.getLogger('').addHandler(console)

 

# Data and pre-trained model
rec_train = './data/rec/hico_train_full.rec'
# rec_train = './data/rec/hico_train_200500.rec'
model_name = 'vgg16'
# model_name = 'resnet-152'
rec_val = './data/rec/hico_val.rec'
lst_train = rec_train[:-3] + 'lst'  # the .lst file sits next to the .rec file

 

# Parameters
num_classes = 600
class_str = ["c" + str(i) for i in range(num_classes)]  # placeholder class names for the mAP metric
batch_per_gpu = 40
num_epoch = 10
gpu_available = [0, 1, 2, 3]
num_gpus = len(gpu_available)
batch_size = batch_per_gpu * num_gpus
if rec_train == './data/rec/hico_train_full.rec':
    print('-----------Batches per epoch: %d-----------' % (7000.0 / batch_size))
if rec_train == './data/rec/hico_train_200500.rec':
    print('-----------Batches per epoch: %d-----------' % (137120.0 / batch_size))

#-----------------------------------------------------------------------------------------------------------------------

 

(new_sym, new_args, aux_params) = get_fine_tune_model(model_name)
(iter_train, iter_val) = get_iterators(batch_size, rec_train, rec_val, lst_train)
mod_score = fit(new_sym, new_args, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_available)
print(mod_score)
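As a footnote, the same freezing effect can be achieved without touching learning-rate multipliers: the Module constructor accepts a fixed_param_names list and excludes those parameters from gradient updates altogether. A minimal sketch under that approach, assuming the new_sym, new_args, aux_params and iterators built above (this variant is not part of the original post):

# Alternative: freeze pre-trained layers via fixed_param_names instead of lr_mult
frozen = [k for k in new_args]  # every pre-trained parameter; 'newfc1' stays trainable
model_alt = mx.mod.Module(symbol=new_sym, context=[mx.gpu(0)],
                          fixed_param_names=frozen)
model_alt.fit(iter_train, iter_val,
    num_epoch=num_epoch,
    arg_params=new_args,
    aux_params=aux_params,
    allow_missing=True,
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.01},
    initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
    eval_metric='acc')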


 

From: https://blog.csdn.net/hwj_wayne/article/details/78602570