欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch模型转ONNX

程序员文章站 2022-07-13 11:43:06
...

本程序用于将PyTorch框架下建立的模型(.pt,.pth等)转换为ONNX模型。利用Netron工具箱,可方便快捷地实现ONNX模型结构的可视化。

"""
Exporting a .pth model to ONNX format

Reference:
https://github.com/ultralytics/yolov5/blob/master/models/export.py
"""

import onnx
import torch

from model_irse import IR_SE_152
# import the model you need

if __name__ == '__main__':
    model = IR_SE_152((112, 112), 512)
    model_path = './IR_SE_152.pth'
    model.load_state_dict(torch.load(model_path))
    
    # Input
    x = torch.randn(1, 3, 112, 112)

    # ONNX export
    try:
        print('Starting ONNX export with onnx %s...' % onnx.__version__)
        f = './models/IR_SE_152.onnx'  # export filename
        # torch.onnx.export(model, x, f, verbose=False, opset_version=12, input_names=['images'],
        #                   output_names=['classes', 'boxes'] if y is None else ['output'])
        torch.onnx.export(model, x, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=['output'])

        # Check the ONNX model
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        print('ONNX model exported successfully, saved as %s' % f)
    except Exception as e:
        print('Failed to export ONNX model: %s' % e)

    
相关标签: 深度学习 python