欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch模型转onnx

程序员文章站 2022-07-13 11:43:30
...

安装onnx

pip install onnx
pip install onnxruntime

pth模型转换为onnx

以Resnet50为例

pytorch1.6.0
onnx1.9.0
onnxruntime1.7.0

简单示例如下:

import torch
import torchvision
import numpy
import onnx
import onnxruntime


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = torchvision.models.resnet50(pretrained=True).cuda()
model.to(device)
model.eval()

# 输出固定尺寸onnx模型
export_onnx_file = "./resnet50.onnx"
x=torch.onnx.export(model,  # 待转换的网络模型和参数
                torch.randn(1, 3, 224, 224, device='cuda'), # 虚拟的输入,用于确定输入尺寸和推理计算图每个节点的尺寸
                export_onnx_file,  # 输出文件的名称
                verbose=False,      # 是否以字符串的形式显示计算图
                input_names=["input"],# + ["params_%d"%i for i in range(120)],  # 输入节点的名称,这里也可以给一个list,list中名称分别对应每一层可学习的参数,便于后续查询
                output_names=["output"], # 输出节点的名称
                opset_version=11,   # onnx 支持采用的operator set, pytorch版本相关
                do_constant_folding=True, # 是否压缩常量
                )

# 输出可变尺寸onnx模型
export_onnx_file = "./resnet50_dynamic.onnx"
x=torch.onnx.export(model,  # 待转换的网络模型和参数
                torch.randn(1, 3, 224, 224, device='cuda'), # 虚拟的输入,用于确定输入尺寸和推理计算图每个节点的尺寸
                export_onnx_file,  # 输出文件的名称
                verbose=False,      # 是否以字符串的形式显示计算图
                input_names=["input"],# + ["params_%d"%i for i in range(120)],  # 输入节点的名称,这里也可以给一个list,list中名称分别对应每一层可学习的参数,便于后续查询
                output_names=["output"], # 输出节点的名称
                opset_version=11,   # onnx 支持采用的operator set, 和pytorch版本相关
                do_constant_folding=True, # 是否压缩常量
                dynamic_axes={"input":{0: "batch_size"}, "output":{0: "batch_size"},} #设置动态维度,此处指明input节点的第0维度可变,命名为batch_size
                )

这里提供了导出两种类型的onnx模型,一种是固定尺寸的onnx的模型,于是在run时,输入尺寸只能为为1x3x224x224的量;一种是可变尺寸的onnx模型,即允许输入发生变化的维度,比如这里我们给的dummy input是1x3x224x224尺寸,然后限定input的第0维可以发生变化,于是在run时,可以输入尺寸为16x3x224x224的量。
可变尺寸的模型一般是用来多batch推理,需要注意的是,多batch推理时相比于单张推理,平均每张推理速度会加快,但是显存的占用会变大,可以根据自己的需求选择哪一种。
转换好的模型可以用下面这种方式检查导出是否正确

onnx_model = onnx.load("./resnet50.onnx")
onnx.checker.check_model(onnx_model)
print('The model is:\n{}'.format(onnx_model.graph.input))

通过netron可视化生成的onnx模型,可以看到给参数命的名称。

上一篇: Pytorch

下一篇: HTML实现旋转动画