mxnet finetune例子(只finetune某几层)
程序员文章站
2022-05-27 09:33:47
...
import logging
import mxnet as mx
import numpy as np
import os.path, time,sys
from mAP_metric import mAP
# Sanity check: print how many seconds ago this script file was last modified,
# so that accidentally running a stale copy is easy to notice.
# (The original format string had a stray '%s' with no matching argument, which
# raised "TypeError: not enough arguments for format string".)
print("\n******File updated %ds ago******" % (time.time() - os.path.getmtime(sys.argv[0])))
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    """Build the training and validation iterators from RecordIO (.rec) files.

    Args:
        batch_size: number of images per batch (total across all devices).
        rec_train: path to the training .rec file.
        rec_val: path to the validation .rec file.
        lst_train: path to the .lst index file passed alongside ``rec_train``.
        data_shape: (channels, height, width) of each image produced.

    Returns:
        A ``(train, val)`` tuple of ``mx.io.ImageRecordIter`` objects. Only the
        training iterator shuffles and applies random-crop / random-mirror
        augmentation; the validation iterator uses no augmentation options.
    """
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,
        # Randomly flip about half of the images. Do NOT also pass mirror=True:
        # that flips every image unconditionally, which cancels out rand_mirror
        # and makes the training images differ from the (unflipped) validation set.
        rand_mirror=True,
        max_rotate_angle=0)
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train, val
def get_fine_tune_model(model_name, num_classes=None):
    """Load a pre-trained checkpoint and replace its classifier with a new FC layer.

    The checkpoint ``data/pre_train/<model_name>`` (epoch 0) is loaded. The network
    is cut at an internal layer (``drop7`` for vgg16, ``flatten0`` otherwise), and
    a fresh FullyConnected layer named ``newfc1`` plus a softmax output are appended.

    Args:
        model_name: checkpoint prefix under ``data/pre_train/``, e.g. ``"vgg16"``.
        num_classes: width of the new FC layer. When omitted, the module-level
            ``num_classes`` setting is used, which is what this function always
            read before the parameter existed.

    Returns:
        ``(net, new_args, aux_params)``: the new symbol, the pre-trained argument
        parameters that the new symbol still consumes, and the auxiliary params.
    """
    if num_classes is None:
        # Backward-compatible fallback to the script-level configuration.
        num_classes = globals()['num_classes']
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/" + model_name, 0)
    # Cut the pre-trained network at its feature layer and attach a new classifier.
    all_layers = symbol.get_internals()
    if model_name == "vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # Keep only the pre-trained weights the truncated network still uses. This
    # discards the old classifier whatever its name (fc8 in vgg16, fc1 in resnet),
    # whereas a substring test on 'fc1' kept vgg16's fc8 and could also drop
    # unrelated parameters such as 'fc10_*'. The new 'newfc1_*' parameters are
    # never in the checkpoint, so they are left for the initializer.
    used_args = set(net.list_arguments())
    new_args = {k: v for k, v in arg_params.items() if k in used_args}
    return (net, new_args, aux_params)
def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_avaliable):
    """Train only the newly added layer(s), then score the model on the validation set.

    Every parameter present in ``arg_params`` (the pre-trained weights) is frozen.
    Parameters missing from it (the new classifier) are initialised with Xavier
    (``allow_missing=True``) and trained with SGD at learning rate 0.01.

    Args:
        symbol: network symbol to train.
        arg_params: pre-trained argument parameters (these stay frozen).
        aux_params: pre-trained auxiliary parameters.
        iter_train: training data iterator.
        iter_val: validation data iterator.
        class_str: per-class names passed to the mAP metric.
        num_epoch: number of training epochs.
        batch_size: total batch size; used for the Speedometer progress log.
        gpu_avaliable: list of GPU ids to train on.

    Returns:
        Whatever ``Module.score`` returns for the composite accuracy + mAP metric
        evaluated on ``iter_val``.
    """
    devs = [mx.gpu(i) for i in gpu_avaliable]
    # Freeze the pre-trained layers via the Module itself: fixed parameters get no
    # gradient and are never touched by the optimizer. This is cheaper than a zero
    # learning-rate multiplier, which still computes every gradient. Module rejects
    # fixed names the symbol does not use, so filter against the symbol's arguments.
    symbol_args = set(symbol.list_arguments())
    fixed_params = [name for name in arg_params if name in symbol_args]
    model = mx.mod.Module(symbol=symbol, context=devs, fixed_param_names=fixed_params)
    # Metrics reported by the final score() call.
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str))  # remove if unnecessary
    model.fit(iter_train, iter_val,
              num_epoch=num_epoch,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=mx.callback.Speedometer(batch_size, 10),
              kvstore='device',
              # Pass the optimizer by name so Module honours optimizer_params and
              # sets rescale_grad = 1/batch_size. A pre-built Optimizer instance
              # silently ignores optimizer_params and keeps rescale_grad = 1.0,
              # which multiplies the effective learning rate by the batch size.
              optimizer='sgd',
              optimizer_params={'learning_rate': 0.01},
              initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
              eval_metric='acc')
    return model.score(iter_val, com_metric)
#=======================================================================================================================
# set logger, print message on screen and file
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s',filename='acc_record.log',filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
logging.getLogger('').addHandler(console)
# Data and pre-trained model.
rec_train = './data/rec/hico_train_full.rec'
# rec_train = './data/rec/hico_train_200500.rec'
model_name = 'vgg16'
# model_name = 'resnet-152'
rec_val = './data/rec/hico_val.rec'
# The .lst index sits next to the .rec file and shares its stem.
lst_train = rec_train[:-3] + 'lst'

# Training parameters.
num_classes = 600
# One display name per class ("c0" ... "c599"), passed to the mAP metric.
class_str = ["c" + str(i) for i in range(num_classes)]
batch_per_gpu = 40
num_epoch = 10
gpu_avaliable = [0, 1, 2, 3]
num_gpus = len(gpu_avaliable)
batch_size = batch_per_gpu * num_gpus

# Number of training images in each known .rec file (used only for the message below).
train_set_sizes = {
    './data/rec/hico_train_full.rec': 7000,
    './data/rec/hico_train_200500.rec': 137120,
}
if rec_train in train_set_sizes:
    # ImageRecordIter still yields the final partial batch (wrapped or padded), so
    # round up; '%d' on a float division truncated, e.g. 7000/160 -> 43 instead of 44.
    batches_per_epoch = -(-train_set_sizes[rec_train] // batch_size)
    print('-----------Batchs per epoch: %d-----------' % batches_per_epoch)
#-----------------------------------------------------------------------------------------------------------------------
# Build the network, then the data pipeline, then train and print validation scores.
new_sym, new_args, aux_params = get_fine_tune_model(model_name)
iter_train, iter_val = get_iterators(
    batch_size=batch_size,
    rec_train=rec_train,
    rec_val=rec_val,
    lst_train=lst_train,
)
mod_score = fit(
    symbol=new_sym,
    arg_params=new_args,
    aux_params=aux_params,
    iter_train=iter_train,
    iter_val=iter_val,
    class_str=class_str,
    num_epoch=num_epoch,
    batch_size=batch_size,
    gpu_avaliable=gpu_avaliable,
)
print(mod_score)
Source: https://blog.csdn.net/hwj_wayne/article/details/78602570
上一篇: yii2 composer 安装失败
下一篇: 简单的php缓存类分享_PHP教程