【CV】ckpt文件转为pb文件(fasterrcnn)
程序员文章站
2022-07-01 15:29:46
...
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
def freeze_graph(input_checkpoint,output_graph):
#指定输出的节点名称,该节点名称必须是原模型中存在的节点。直接用最后输出的节点,可以在tensorboard中查找到,tensorboard只能在linux中使用
output_node_names = "SCORE/resnet_v1_101_5/cls_prob/cls_prob/scores,SCORE/resnet_v1_101_5/bbox_pred/BiasAdd/bbox_pred/scores,SCORE/resnet_v1_101_5/cls_pred/cls_pred/scores"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #通过 import_meta_graph 导入模型中的图----1
graph = tf.get_default_graph() #获得默认的图
input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #通过 saver.restore 从模型中恢复图中各个变量的数据----2
output_graph_def = graph_util.convert_variables_to_constants( #通过 graph_util.convert_variables_to_constants 将模型持久化----3
sess=sess,
input_graph_def=input_graph_def, #等于:sess.graph_def
output_node_names=output_node_names.split(",")) #如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
input_checkpoint='./checkpoints/res101_faster_rcnn_iter_70000.ckpt'
out_pb_path='./checkpoints/frozen_model.pb'
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map: # Print tensor name and values
print("tensor_name: ", key)
#print(reader.get_tensor(key))
freeze_graph(input_checkpoint, out_pb_path)