欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

mxnet加载模型的params和json文件来预测

程序员文章站 2024-03-14 20:38:05
...

导读

有时候我们在使用别人的mxnet预训练模型时,会有两个文件paramsjson文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构。这里我们以ImageNet的预训练模型为例,来介绍如何加载模型并进行预测。

使用预训练模型进行预测

  • 下载预训练模型
import mxnet as mx
#定义一个data batch
from collections import namedtuple
import cv2
import numpy as np

def download_model():
    model_url = 'http://data.mxnet.io/models/imagenet/'
    #下载ResNet-18预训练模型
    mx.test_utils.download(model_url + 'resnet/18-layers/resnet-18-0000.params')
    mx.test_utils.download(model_url + 'resnet/18-layers/resnet-18-symbol.json')
    #下载label name文件
    mx.test_utils.download(model_url + 'synset.txt')
  • 读取标签文件
def get_label_names(label_path):
    with open(label_path,"r") as rf:
        labels = [line_info.rstrip() for line_info in rf]
        return labels
  • 获取模型
def get_mod(model_str,ctx,data_shape):
    _vec = model_str.split(",")
    prefix = _vec[0]
    epoch = int(_vec[1])
    sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
    mod = mx.mod.Module(symbol=sym,context=ctx,label_names=None)
    #注意修改data_shapes,ImageNet使用的shape是(224,224,3)
    mod.bind(for_training=False,data_shapes=[("data",data_shape)],
             label_shapes=mod._label_shapes)
    #加载网络的参数
    mod.set_params(arg_params,aux_params,allow_missing=True)
    return mod
  • 图片预处理
def preprocess_img(img_path,data_shape,ctx):
    # 读取图片
    img = cv2.imread(img_path)
    # 将图片缩放到与bind中shape的width和height一致
    img = cv2.resize(img, (data_shape[2], data_shape[3]))
    # 将图片由BGR转为RGB
    img = img[:, :, ::-1]
    # 将numpy array转为ndarray
    nd_img = mx.nd.array(img,ctx=ctx).transpose((2, 0, 1))
    # 将图片的格式转为NCHW
    nd_img = mx.nd.expand_dims(nd_img, axis=0)
    return nd_img
  • 模型预测
def predict(model_str,ctx,data_shape,img_path,label_path):
    #通过标签文件获取标签名称
    label_names = get_label_names(label_path)
    Batch = namedtuple("Batch",["data"])
    mod = get_mod(model_str,ctx,data_shape)
    #获取预测的图片
    nd_img = preprocess_img(img_path,data_shape,ctx)
    #计算网络的预测值
    mod.forward(Batch([nd_img]))
    prob = mod.get_outputs()[0].asnumpy()
    #获取top5
    prob = np.squeeze(prob)
    #将图片预测的概率由大到小进行排序
    sort_prob = np.argsort(prob)[::-1]
    for i in sort_prob[:5]:
        print("label name=%s,probability=%f"%(label_names[i],prob[i]))

model_str = "resnet-18,0"
ctx = mx.cpu()
data_shape = (1,3,224,224)
img_path = "img/panda.jpg"
label_path = "synset.txt"
predict(model_str,ctx,data_shape,img_path,label_path)

mxnet加载模型的params和json文件来预测
mxnet加载模型的params和json文件来预测

获取指定层的输出

有些时候我们不需要网络的输出,而是只需要网络某个层的输出来通过网络提取图片的特征,这时候我们就需要指定提取层的名称,这里我们通过提取网络最后一层的全连接层为例

def get_specify_mod(model_str,ctx,data_shpae,layer_name):
    _vec = model_str.split(",")
    prefix = _vec[0]
    epoch = int(_vec[1])
    sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
    #获取神经网络所有的层
    all_layers = sym.get_internals()
    #获取输出层
    sym = all_layers[layer_name+"_output"]
    mod = mx.mod.Module(symbol=sym,context=ctx)
    mod.bind(data_shapes=[("data",data_shpae)])
    mod.set_params(arg_params,aux_params)
    return mod
    
def predict_specify(model_str,ctx,data_shape,img_path,label_path):
    label_names = get_label_names(label_path)
    #通过输出网络层的名称,输出层全连接层的名称为fc1
    mod = get_specify_mod(model_str,ctx,data_shape,layer_name="fc1")
    nd_img = preprocess_img(img_path,data_shape,ctx)
    #将需要预测的图片封装为Batch
    data_batch = mx.io.DataBatch(data=(nd_img,))
    #计算网络的预测值
    mod.forward(data_batch,is_train=False)
    #获取网络的输出值
    output = mod.get_outputs()[0]
    #对输出值进行softmax处理
    proba = mx.nd.softmax(output)
    #获取前top5的值
    top_proba = proba.topk(k=5)[0].asnumpy()
    for index in top_proba:
        probability = proba[0][int(index)].asscalar()*100
        pred_label_name = label_names[int(index)]
        print("label name=%s,probability=%f"%(pred_label_name,probability))

mxnet加载模型的params和json文件来预测

相关标签: mxnet修炼之路