欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【CV】ckpt文件转为pb文件(fasterrcnn)

程序员文章站 2022-07-01 15:29:46
...
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow

def freeze_graph(input_checkpoint,output_graph):
    #指定输出的节点名称,该节点名称必须是原模型中存在的节点。直接用最后输出的节点,可以在tensorboard中查找到,tensorboard只能在linux中使用
    output_node_names = "SCORE/resnet_v1_101_5/cls_prob/cls_prob/scores,SCORE/resnet_v1_101_5/bbox_pred/BiasAdd/bbox_pred/scores,SCORE/resnet_v1_101_5/cls_pred/cls_pred/scores"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #通过 import_meta_graph 导入模型中的图----1
    graph = tf.get_default_graph() #获得默认的图
    input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #通过 saver.restore 从模型中恢复图中各个变量的数据----2
        output_graph_def = graph_util.convert_variables_to_constants(  #通过 graph_util.convert_variables_to_constants 将模型持久化----3
            sess=sess,
            input_graph_def=input_graph_def, #等于:sess.graph_def
            output_node_names=output_node_names.split(","))  #如果有多个输出节点,以逗号隔开
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

input_checkpoint='./checkpoints/res101_faster_rcnn_iter_70000.ckpt'
out_pb_path='./checkpoints/frozen_model.pb'

reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:  # Print tensor name and values
    print("tensor_name: ", key)
    #print(reader.get_tensor(key))

freeze_graph(input_checkpoint, out_pb_path)
相关标签: image