欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow模型转ONNX格式-Part1

程序员文章站 2022-06-26 15:30:44
...

TensorFlow模型转ONNX格式-官方示例

https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb
——Translated by Blssel

前言

对Tensorflow和ONNX来说,虽然它们使用的是不同的计算图格式,但你可以使用Tensorflow-ONNX将一个Tensorflow模型转化为ONNX。本文将分为两个部分:第1部分介绍基本的转换方法,第2部分讨论更高级的话题。目录可以概括如下:

  1. 转换TensorFlow模型的步骤
    -准备tensorflow模型
    -转换为ONNX
    -验证
  2. 额外说明

1. 转换TensorFlow模型的步骤

步骤1:准备tensorflow模型
Tensorflow有好几种保存模型的文件格式,如检查点(checkpoint)文件、graph with weight(called frozen graph next) 以及saved_model,,你可以在训练模型时使用tensorflow提供的api来生成这些文件,可以参考脚本tensorflow_to_onnx_example.py
无论是这三种格式中的哪一种,Tensorflow-onnx都能够将它们转换成onnx格式不过更推荐使用“saved_model”格式,因为它不需要用户指定图形的输入和输出名称。本节将以它为例进行介绍,然后在第2部分(part2)中介绍其他两个。此外,你还可以从tensorflow-onnx的README文件中获得更多细节。

import os
import shutil
import tensorflow as tf
from assets.tensorflow_to_onnx_example import create_and_train_mnist
def save_model_to_saved_model(sess, input_tensor, output_tensor):
    from tensorflow.saved_model import simple_save
    save_path = r"./output/saved_model"
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    simple_save(sess, save_path, {input_tensor.name: input_tensor}, {output_tensor.name: output_tensor})

print("please wait for a while, because the script will train MNIST from scratch")
tf.reset_default_graph()
sess_tf, saver, input_tensor, output_tensor = create_and_train_mnist()
print("save tensorflow in format \"saved_model\"")
save_model_to_saved_model(sess_tf, input_tensor, output_tensor)
please wait for a while, because the script will train MNIST from scratch
Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz
step 0, training accuracy 0.18
step 1000, training accuracy 0.98
step 2000, training accuracy 0.94
step 3000, training accuracy 1
step 4000, training accuracy 1
test accuracy 0.976
save tensorflow in format "saved_model"

步骤2:转换为ONNX
tensorflow-onnx有几个条目用于转换不同的tensorflow格式的tensorflow模型,本节只讨论“saved_model”,“frozen graph”和“checkpoint”将在第2部分中介绍。
另外,tensorflow-onnx还导出了相关的python api,这样用户就可以直接在脚本中调用它们,而不是在命令行中调用它们,具体细节将在第2部分中介绍。

# generating mnist.onnx using saved_model
!python -m tf2onnx.convert \
        --saved-model ./output/saved_model \
        --output ./output/mnist1.onnx \
        --opset 7
2019-06-17 07:22:03,871 - INFO - Using tensorflow=1.12.0, onnx=1.5.0, tf2onnx=1.5.1/0c735a
2019-06-17 07:22:03,871 - INFO - Using opset <onnx, 7>
2019-06-17 07:22:03,989 - INFO - 
2019-06-17 07:22:04,012 - INFO - Optimizing ONNX model
2019-06-17 07:22:04,029 - INFO - After optimization: Add -2 (4->2), Identity -3 (3->0), Transpose -8 (9->1)
2019-06-17 07:22:04,031 - INFO - 
2019-06-17 07:22:04,032 - INFO - Successfully converted TensorFlow model ./output/saved_model to ONNX
2019-06-17 07:22:04,044 - INFO - ONNX model is saved at ./output/mnist1.onnx

步骤3:验证
有好几种可以运行ONNX模型的方式,这里使用ONNXRuntime框架,由微软开源,可以确保生成的ONNX计算图正常运行。输入”image.npz”是一幅手写的“7”图像,因此模型的预期分类结果应为“7”。

import numpy as np
import onnxruntime as ort

img = np.load("./assets/image.npz").reshape([1, 784])  
sess_ort = ort.InferenceSession("./output/mnist1.onnx")
res = sess_ort.run(output_names=[output_tensor.name], input_feed={input_tensor.name: img})
print("the expected result is \"7\"")
print("the digit is classified as \"%s\" in ONNXRruntime"%np.argmax(res))
the expected result is "7"
the digit is classified as "7" in ONNXRruntime

2. 额外说明

以上的命令行应该适用于大多数tensorflow模型。在某些情况下,您可能会遇到需要额外选项的问题。
选项中最重要的概念是opset(操作集)选项,ONNX是一个不断发展的标准,它将添加更多的新操作并增强现有的操作,因此不同的opset版本将包含不同的操作,它们可能会有些不同 。默认版本“tensorflow-onnx”使用的是7,ONNX现在最高支持版本10,所以如果转换失败,您可以尝试不同的版本,通过命令行选项“——opset”,看看它是否工作。

继续第2部分,解释高级主题。

相关标签: TensorFlow ONNX