keras模型转TensorFlow模型->tensorrt UFF格式
程序员文章站
2022-03-04 20:17:52
...
最近在学习tensorrt,需要将keras训练好保存的.hdf5格式模型转为tensorflow的.pb模型,然后转为tensorrt支持的uff格式。
做个记录。
代码如下:
转为tensorflow的.pb格式
# h5_to_pb.py
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
# 路径参数
input_path = '/home/cendelian/Study_CV/face_age_competition/Resnet_Tensorrt-master/model/'
weight_file = 'resnet50_0.1273_3.7708.hdf5'
weight_file_path = osp.join(input_path, weight_file)
output_graph_name = weight_file[:-3] + '.pb'
# 转换函数
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
if osp.exists(output_dir) == False:
os.mkdir(output_dir)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
sess = K.get_session()
from tensorflow.python.framework import graph_util, graph_io
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir)
# 输出路径
output_dir = osp.join(os.getcwd(), "trans_model")
# 加载模型
print(weight_file_path)
h5_model = load_model(weight_file_path)
# model.load_weights
h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
print('model saved')
Note
-
期间遇到一个错误:
tensorflow.python.framework.errors_impl.NotFoundError: libnvinfer.so.5: cannot open shared object file: No such file or directory
后面解决方法是将tensorflow环境切换到tensorflow-gpu 1.15。
原因是我安装的tensorrt版本是7.0,
tensorrt7.0版本需要tensorflow1.14+以上的环境才可以。 -
keras模型需要用save()保存模型结果,save()既保持了模型的图结构,又保存了模型的参数。
save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。
本次运行的环境:tensorflow1.15,cuda10.0,cudnn7.6
转为tensorrt 的uff格式
只需要要一行指令如下:
convert-to-uff resnet50_0.1273_3.7708.h.pb
其中‘resnet50_0.1273_3.7708.h.pb’为上一个步骤转换得到的结果。
该指令运行后可以得到一个
’ resnet50_0.1273_3.7708.h.uff '文件。
结束!
上一篇: java acm A1003--顺序合并两个非有序的链表
下一篇: RestTemplate小记