mxnet加载模型的params和json文件来预测
程序员文章站
2024-03-14 20:38:05
...
导读
有时候我们在使用别人的mxnet预训练模型时,会有两个文件params
和json
文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构。这里我们以ImageNet
的预训练模型为例,来介绍如何加载模型并进行预测。
使用预训练模型进行预测
- 下载预训练模型
import mxnet as mx
#定义一个data batch
from collections import namedtuple
import cv2
import numpy as np
def download_model():
model_url = 'http://data.mxnet.io/models/imagenet/'
#下载ResNet-18预训练模型
mx.test_utils.download(model_url + 'resnet/18-layers/resnet-18-0000.params')
mx.test_utils.download(model_url + 'resnet/18-layers/resnet-18-symbol.json')
#下载label name文件
mx.test_utils.download(model_url + 'synset.txt')
- 读取标签文件
def get_label_names(label_path):
with open(label_path,"r") as rf:
labels = [line_info.rstrip() for line_info in rf]
return labels
- 获取模型
def get_mod(model_str,ctx,data_shape):
_vec = model_str.split(",")
prefix = _vec[0]
epoch = int(_vec[1])
sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
mod = mx.mod.Module(symbol=sym,context=ctx,label_names=None)
#注意修改data_shapes,ImageNet使用的shape是(224,224,3)
mod.bind(for_training=False,data_shapes=[("data",data_shape)],
label_shapes=mod._label_shapes)
#加载网络的参数
mod.set_params(arg_params,aux_params,allow_missing=True)
return mod
- 图片预处理
def preprocess_img(img_path,data_shape,ctx):
# 读取图片
img = cv2.imread(img_path)
# 将图片缩放到与bind中shape的width和height一致
img = cv2.resize(img, (data_shape[2], data_shape[3]))
# 将图片由BGR转为RGB
img = img[:, :, ::-1]
# 将numpy array转为ndarray
nd_img = mx.nd.array(img,ctx=ctx).transpose((2, 0, 1))
# 将图片的格式转为NCHW
nd_img = mx.nd.expand_dims(nd_img, axis=0)
return nd_img
- 模型预测
def predict(model_str,ctx,data_shape,img_path,label_path):
#通过标签文件获取标签名称
label_names = get_label_names(label_path)
Batch = namedtuple("Batch",["data"])
mod = get_mod(model_str,ctx,data_shape)
#获取预测的图片
nd_img = preprocess_img(img_path,data_shape,ctx)
#计算网络的预测值
mod.forward(Batch([nd_img]))
prob = mod.get_outputs()[0].asnumpy()
#获取top5
prob = np.squeeze(prob)
#将图片预测的概率由大到小进行排序
sort_prob = np.argsort(prob)[::-1]
for i in sort_prob[:5]:
print("label name=%s,probability=%f"%(label_names[i],prob[i]))
model_str = "resnet-18,0"
ctx = mx.cpu()
data_shape = (1,3,224,224)
img_path = "img/panda.jpg"
label_path = "synset.txt"
predict(model_str,ctx,data_shape,img_path,label_path)
获取指定层的输出
有些时候我们不需要网络的输出,而是只需要网络某个层的输出来通过网络提取图片的特征
,这时候我们就需要指定提取层的名称
,这里我们通过提取网络最后一层的全连接层为例
def get_specify_mod(model_str,ctx,data_shpae,layer_name):
_vec = model_str.split(",")
prefix = _vec[0]
epoch = int(_vec[1])
sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
#获取神经网络所有的层
all_layers = sym.get_internals()
#获取输出层
sym = all_layers[layer_name+"_output"]
mod = mx.mod.Module(symbol=sym,context=ctx)
mod.bind(data_shapes=[("data",data_shpae)])
mod.set_params(arg_params,aux_params)
return mod
def predict_specify(model_str,ctx,data_shape,img_path,label_path):
label_names = get_label_names(label_path)
#通过输出网络层的名称,输出层全连接层的名称为fc1
mod = get_specify_mod(model_str,ctx,data_shape,layer_name="fc1")
nd_img = preprocess_img(img_path,data_shape,ctx)
#将需要预测的图片封装为Batch
data_batch = mx.io.DataBatch(data=(nd_img,))
#计算网络的预测值
mod.forward(data_batch,is_train=False)
#获取网络的输出值
output = mod.get_outputs()[0]
#对输出值进行softmax处理
proba = mx.nd.softmax(output)
#获取前top5的值
top_proba = proba.topk(k=5)[0].asnumpy()
for index in top_proba:
probability = proba[0][int(index)].asscalar()*100
pred_label_name = label_names[int(index)]
print("label name=%s,probability=%f"%(pred_label_name,probability))
上一篇: 常用类库-----15.14 国际化程序
下一篇: Java的Locale类