
MXNet fine-tuning example (fine-tuning only selected layers)


1. Download the matching pre-trained model from the MXNet Model Zoo:

http://mxnet.incubator.apache.org/model_zoo/index.html
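
For example, the checkpoint files (a -symbol.json and a -0000.params) can be fetched with mx.test_utils.download. A minimal sketch; the data.mxnet.io URL prefix is an assumption borrowed from the official tutorial and may need adjusting for other models:

import os
import mxnet as mx

os.path.isdir('data/pre_train') or os.makedirs('data/pre_train')

# fetch the symbol definition and the epoch-0 weights for a given model;
# the data.mxnet.io prefix is an assumption, not guaranteed for every model
def download_model(url_prefix, out_prefix, epoch=0):
    mx.test_utils.download(url_prefix + '-symbol.json', out_prefix + '-symbol.json')
    mx.test_utils.download(url_prefix + '-%04d.params' % epoch, out_prefix + '-%04d.params' % epoch)

download_model('http://data.mxnet.io/models/imagenet/vgg/vgg16', 'data/pre_train/vgg16')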


2. Convert the data to .rec format; the first part of the official fine-tuning example shows how:

http://mxnet.incubator.apache.org/how_to/finetune.html
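
The usual route is tools/im2rec.py from the MXNet repository, but a .rec file can also be written directly from Python with the recordio API. A sketch; the image paths and labels below are placeholders:

import cv2
import mxnet as mx

# pack (image, label) pairs into an indexed .rec file; samples are placeholders
samples = [('img0.jpg', 0), ('img1.jpg', 1)]
record = mx.recordio.MXIndexedRecordIO('train.idx', 'train.rec', 'w')
for i, (path, label) in enumerate(samples):
    img = cv2.imread(path)  # BGR HxWxC uint8
    header = mx.recordio.IRHeader(flag=0, label=float(label), id=i, id2=0)
    record.write_idx(i, mx.recordio.pack_img(header, img, quality=95, img_fmt='.jpg'))
record.close()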


3. Define the data-iterator factory:

# data iterators: build training/validation iterators from .rec files
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,
        rand_mirror=True,
        max_rotate_angle=0)
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train, val
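
A quick way to verify the iterators are wired up correctly is to pull one batch and check its shape (the paths are the ones used in the full script below):

train, val = get_iterators(batch_size=8,
                           rec_train='./data/rec/hico_train_full.rec',
                           rec_val='./data/rec/hico_val.rec',
                           lst_train='./data/rec/hico_train_full.lst')
batch = train.next()
print(batch.data[0].shape, batch.label[0].shape)  # expect (8, 3, 224, 224) and (8,)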

4. Define the function that loads the pre-trained model and rewires its output layer:

def get_fine_tune_model(model_name):
    # load the pre-trained checkpoint (epoch 0)
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
    # cut the network just before its original classifier
    all_layers = symbol.get_internals()
    if model_name == "vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    # attach a new output layer; num_classes is a global defined below
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the original fc1 weights (resnet naming) so the new layer starts fresh
    new_args = {k: arg_params[k] for k in arg_params if 'fc1' not in k}
    return (net, new_args, aux_params)
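
The cut-point names ('drop7_output', 'flatten0_output') depend on how the checkpoint was exported; if they raise a KeyError, list the symbol's internal outputs and pick the right one:

sym, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/vgg16", 0)
print(sym.get_internals().list_outputs()[-10:])  # the last few layer names, e.g. 'drop7_output'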

5. Train the model. During training, set the learning-rate multiplier of every frozen layer to 0, so that only the last few layers are actually fine-tuned:

def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_available):
    devs = [mx.gpu(i) for i in gpu_available]
    model = mx.mod.Module(symbol=symbol, context=devs)
    # metrics: accuracy plus a custom mAP metric
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str))  # remove if unnecessary
    # optimizer: freeze every pre-trained layer with a learning-rate multiplier
    # of 0; 'newfc1' is absent from arg_params, so it keeps the default of 1.
    # The learning rate is set here because optimizer_params is ignored when an
    # Optimizer instance (rather than a name) is passed to fit().
    sgd = mx.optimizer.create('sgd', learning_rate=0.01)
    finetune_lr = {k: 0 for k in arg_params}
    sgd.set_lr_mult(finetune_lr)
    # training; allow_missing=True lets the initializer fill in the new layer
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer=sgd,
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    return model.score(iter_val, com_metric)
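
Setting the multiplier to 0 for every checkpoint parameter freezes the whole backbone. To fine-tune the last convolutional block as well, give those parameters a non-zero multiplier instead. A sketch; the 'conv5' prefix matches VGG16 naming and is an assumption for other architectures:

# unfreeze the conv5 block too; the 'conv5' prefix is VGG16-specific (an assumption)
finetune_lr = {k: (1.0 if k.startswith('conv5') else 0.0) for k in arg_params}
sgd.set_lr_mult(finetune_lr)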



Full code:

import logging
import mxnet as mx
import numpy as np
import os.path, time,sys
from mAP_metric import mAP

print ("\n******File updated %ds ago%s******" % (time.time()-os.path.getmtime(sys.argv[0])))# file updatation check

# data iterators: build training/validation iterators from .rec files
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,
        rand_mirror=True,
        max_rotate_angle=0)
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train,val

# load the pre-trained model and rewire its output layer
def get_fine_tune_model(model_name):
    # load the pre-trained checkpoint (epoch 0)
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
    # cut the network just before its original classifier
    all_layers = symbol.get_internals()
    if model_name == "vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    # attach a new output layer; num_classes is a global defined below
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the original fc1 weights (resnet naming) so the new layer starts fresh
    new_args = {k: arg_params[k] for k in arg_params if 'fc1' not in k}
    return (net, new_args, aux_params)

# model training: freeze the backbone and train only the new output layer
def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_available):
    devs = [mx.gpu(i) for i in gpu_available]
    model = mx.mod.Module(symbol=symbol, context=devs)
    # metrics: accuracy plus a custom mAP metric
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str))  # remove if unnecessary
    # optimizer: freeze every pre-trained layer with a learning-rate multiplier
    # of 0; 'newfc1' is absent from arg_params, so it keeps the default of 1.
    # The learning rate is set here because optimizer_params is ignored when an
    # Optimizer instance (rather than a name) is passed to fit().
    sgd = mx.optimizer.create('sgd', learning_rate=0.01)
    finetune_lr = {k: 0 for k in arg_params}
    sgd.set_lr_mult(finetune_lr)
    # training; allow_missing=True lets the initializer fill in the new layer
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer=sgd,
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    return model.score(iter_val, com_metric)

#=======================================================================================================================
# set up the logger: print messages to both the screen and a file
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s',filename='acc_record.log',filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
logging.getLogger('').addHandler(console)

# data and pre-trained model
rec_train='./data/rec/hico_train_full.rec'
# rec_train='./data/rec/hico_train_200500.rec'
model_name='vgg16'
# model_name='resnet-152'
rec_val='./data/rec/hico_val.rec'
lst_train=rec_train[:-3]+'lst'

# parameters
num_classes = 600
class_str = ["c" + str(i) for i in range(num_classes)]
batch_per_gpu = 40
num_epoch = 10
gpu_available = [0, 1, 2, 3]
num_gpus = len(gpu_available)
batch_size = batch_per_gpu * num_gpus
if rec_train == './data/rec/hico_train_full.rec':
    print('-----------Batches per epoch: %d-----------' % (7000.0/batch_size))
if rec_train == './data/rec/hico_train_200500.rec':
    print('-----------Batches per epoch: %d-----------' % (137120.0/batch_size))
#-----------------------------------------------------------------------------------------------------------------------

(new_sym,new_args,aux_params)=get_fine_tune_model(model_name)
(iter_train, iter_val) = get_iterators(batch_size,rec_train,rec_val,lst_train)
mod_score = fit(new_sym, new_args, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_available)
print(mod_score)
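
The script only prints the final score; if the fine-tuned weights should be kept, a checkpoint callback can be added to model.fit. A sketch; the output prefix 'data/finetuned/vgg16' is arbitrary:

checkpoint = mx.callback.do_checkpoint('data/finetuned/vgg16')
# pass epoch_end_callback=checkpoint to model.fit(...) inside fit(); reload later with:
sym, arg_params, aux_params = mx.model.load_checkpoint('data/finetuned/vgg16', num_epoch)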