欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

tensorflow三种模型的加载和保存的方法(.ckpt,.pb,SavedModel)

程序员文章站 2022-12-20 16:51:48
工作中尝试用的是.ckpt,最近在研究tensorflow serving所以需要将模型转化为SavedModel格式。而有时模型平台调用又需要.pb模型,所以对这三种文件进行了解。一、ckpt文件的保存和加载1、模型保存文件格式checkpoint文件:用于告知某些TF函数,这是最新的检查点文件.data文件:包含训练变量的文件.index文件:描述variable中key和value的对应关系.meta文件:保存完整的网络图结构使用这种方法保存模型时会保存成上面这四个文件,重新加载模型时通...

工作中经常使用的是.ckpt,最近在研究tensorflow serving所以需要将模型转化为SavedModel格式。而有时模型平台调用又需要.pb模型,所以对这三种文件进行了解。

一、.ckpt文件的保存和加载

1、模型保存文件格式

checkpoint文件:b包含最新的和所有的文件地址
.data文件:包含训练变量的文件
.index文件:描述variable中key和value的对应关系
.meta文件:保存完整的网络图结构
使用这种方法保存模型时会保存成上面这四个文件,重新加载模型时通常只会用到.meta文件恢复图结构然后用.data文件把各个变量的值再加进去。

2、模型保存方法

saver=tf.train.Saver(max_to_keep=5)  #表示保存最近的几个模型,设置为None或者0 就是保存全部的模型。此处max_to_keep=5意思就是保存最近的5个模型
saver.save(sess,'D:/model',global_step=epoch)

创建一个saver,调用save方法将当前sess会话中的图和变量等信息保存到指定路径,global_step代表当前的轮数,设置之后会在文件名后面缀一个"-epco"

3、模型加载方法

saver=tf.train.import_meta_graph('model/model-0720-4.meta')  #恢复计算图结构
saver.restore(sess, tf.train.latest_checkpoint("model/"))  #恢复所有变量信息
#现在sess中已经恢复了网络结构和变量信息了,接下来可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
#或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)

4、特点

.ckpt方式保存模型,这种模型文件是依赖 TensorFlow 的,只能在其框架下使用

二、.pb文件的保存和加载

1、模型保存文件格式

.pb文件里面保存了图结构+数据,加载模型时只需要这一个文件就好。

2、模型保存方法

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op'])
with tf.gfile.FastGFile('D:/pycharm files/model.pb', mode='wb') as f:
  f.write(constant_graph.SerializeToString())

3、模型加载方法

with tf.gfile.FastGFile(pb_file_path, 'rb') as f:
  graph_def = tf.GraphDef() # 生成图
  graph_def.ParseFromString(f.read()) # 图加载模型
  tf.import_graph_def(graph_def, name='')
#接下来与前面的相同可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
#或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)

4、特点

谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小。
加载一个pb文件之后再对其进行微调(也就是将这个pb文件的网络作为自己网络的一部分),然后再保存成pb文件,后一个pb网络会包含前一个pb网络。

三、saved model

1、模型保存文件格式

在传入的目录下会有一个pb文件和一个variables文件夹:

2、模型保存方法

builder = tf.saved_model.builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(sess,['cpu_server_1'])

3、模型加载方法

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ['cpu_server_1'], pb_file_path+'savemodel')
#接下来可以直接使用名字或者get_tensor_by_name后再进行使用
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  op = sess.graph.get_tensor_by_name('op:0')
  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})

4、特点

saved_model模块主要用于TensorFlow Serving,目的是要实现inference的代码统一。详细可点击参考https://blog.csdn.net/thriving_fcl/article/details/75213361

四、将 .ckpt转化为.pb

参考https://www.jianshu.com/p/06548e3e8f4b

五、将 .ckpt转化为SavedModel

参考https://www.jianshu.com/p/451c46bd9287

六、案例

参考https://www.cnblogs.com/biandekeren-blog/p/11876032.html

参考文献
https://www.cnblogs.com/biandekeren-blog/p/11876032.html

本文地址:https://blog.csdn.net/weixin_44388679/article/details/107458536

相关标签: 深度学习