欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

onnxruntime加载pytorch图像分类模型

程序员文章站 2022-07-13 10:20:25
...
  • 从 PyTorch 模型导出 ONNX 模型，可以参考笔者的前一篇博文：https://blog.csdn.net/ouening/article/details/109245243
  • 使用 Netron 查看 ONNX 模型结构，如下图：
    [图：onnxruntime加载pytorch图像分类模型 —— Netron 中显示的模型结构截图]
    注意输入、输出的名称（name）以及数据类型和维度（type）。
  • 程序
import numpy as np    # we're going to use numpy to process input and output data
import onnxruntime    # to inference ONNX models, we use the ONNX Runtime
import onnx
from onnx import numpy_helper
import urllib.request
import json
import time
from imageio import imread
import warnings
warnings.filterwarnings('ignore')
# display images in notebook
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

# Path to the ONNX model exported from PyTorch.
onnx_model = r"D:\Files\python\opencv\调用pytorch-onnx模型\exported.onnx"

# Create the inference session. Since onnxruntime 1.9, builds that ship
# non-CPU execution providers (e.g. onnxruntime-gpu) raise ValueError unless
# `providers` is given explicitly, so request every provider this build
# reports as available (in the order onnxruntime returns them).
session = onnxruntime.InferenceSession(
    onnx_model, providers=onnxruntime.get_available_providers()
)

# Names of the first input and first output tensor; printed so they can be
# checked against the model graph (e.g. in netron).
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
print('Input Name:', input_name)
print('Output Name:', output_name)

# Test image to classify.
img_file = r"C:\Users\LX\Pictures\elephant.jpg"

# Default label file (ILSVRC2012 class names, path into the OpenCV samples tree).
DEFAULT_CLASS_FILE = r"E:\ScientificComputing\opencv\sources\samples\data\dnn\classification_classes_ILSVRC2012.txt"


def load_labels(class_file=DEFAULT_CLASS_FILE):
    """Read class labels from a text file, one label per line.

    Args:
        class_file: path to the label file; defaults to ``DEFAULT_CLASS_FILE``.

    Returns:
        list of str where element ``i`` is line ``i`` of the file, i.e. the
        label for class index ``i``.
    """
    # Explicit encoding: open()'s default is locale-dependent (e.g. GBK on a
    # Chinese Windows install), which makes the result machine-dependent.
    with open(class_file, 'rt', encoding='utf-8') as f:
        # Strip trailing newline(s) so the list does not end with empty labels.
        return f.read().rstrip('\n').split('\n')

def preprocess(input_data, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize a channel-first image and add a leading batch axis.

    Args:
        input_data: array-like of shape (C, H, W) with pixel values in the
            0-255 range (e.g. ``np.array(pil_image).transpose(2, 0, 1)``).
        mean: per-channel mean applied after scaling pixels to [0, 1].
        std: per-channel standard deviation applied after subtracting ``mean``.
            The defaults are the ImageNet mean/std commonly used with
            torchvision-trained classifiers.

    Returns:
        float32 ndarray of shape (1, C, H, W).

    Raises:
        ValueError: if ``input_data`` is not 3-D or its channel count does not
            match the lengths of ``mean`` and ``std``.
    """
    img_data = np.asarray(input_data, dtype=np.float32)
    mean_vec = np.asarray(mean, dtype=np.float32)
    stddev_vec = np.asarray(std, dtype=np.float32)

    if img_data.ndim != 3 or not (img_data.shape[0] == mean_vec.size == stddev_vec.size):
        raise ValueError(
            f"expected a (C, H, W) image with C == {mean_vec.size}, "
            f"got shape {img_data.shape}"
        )

    # Broadcast the per-channel constants over H and W: (C,) -> (C, 1, 1).
    norm_img_data = (img_data / 255.0 - mean_vec[:, None, None]) / stddev_vec[:, None, None]

    # Add the batch axis; unlike a fixed reshape this works for any H x W.
    return norm_img_data[np.newaxis].astype(np.float32)

def softmax(x):
    """Return the softmax over every element of *x*, as a 1-D array.

    *x* is flattened first. The maximum is subtracted before exponentiating
    so that large logits cannot overflow; this leaves the result unchanged.
    """
    flat = x.ravel()
    weights = np.exp(flat - flat.max())
    total = weights.sum()
    return weights / total

def postprocess(result):
    """Convert the raw value returned by ``session.run`` into probabilities.

    Everything in *result* is stacked into one array and passed to
    ``softmax``, which flattens it. The result is returned as a plain Python
    list of floats.
    """
    logits = np.array(result)
    probabilities = softmax(logits)
    return probabilities.tolist()

# Load the test image and resize it to the 224x224 size used by this script.
# convert('RGB') guarantees exactly three channels: grayscale ('L'), palette
# ('P') or RGBA files would otherwise yield a 2-D array or a 4-channel array,
# breaking the HWC->CHW transpose and the per-channel normalization below.
with Image.open(img_file) as src:
    image = src.convert('RGB').resize((224, 224))

print("Image size: ", image.size)
plt.axis('off')
display_image = plt.imshow(image)

# PIL/NumPy give HWC; the model input is channel-first (CHW).
image_data = np.array(image).transpose(2, 0, 1)
input_data = preprocess(image_data)

#%%
# Time one forward pass. perf_counter() is monotonic and high-resolution,
# unlike time.time(), whose resolution can be coarse on some platforms.
# NOTE(review): a first run may also include one-off warm-up cost.
start = time.perf_counter()
# Request only the output that is post-processed. run([]) returns every
# output, and postprocess() would then softmax all of them together if the
# model had more than one.
raw_result = session.run([output_name], {input_name: input_data})
end = time.perf_counter()
res = postprocess(raw_result)

inference_time = round((end - start) * 1000, 2)
idx = int(np.argmax(res))
labels = load_labels()
print('========================================')
print('Final top prediction is: ' + labels[idx])
print('========================================')

print('========================================')
print('Inference time: ' + str(inference_time) + " ms")
print('========================================')

# Class indices ordered from most to least probable.
sort_idx = np.argsort(res)[::-1]
print('============ Top 5 labels are: ============================')
for k in sort_idx[:5]:
    print(f'{labels[k]}: {res[k]:.4f}')
print('===========================================================')

plt.axis('off')
display_image = plt.imshow(image)
# Without this, no figure window appears when run as a plain script
# (inline notebook backends render it either way).
plt.show()
参考链接：https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb