欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow的模型保存以及加载

程序员文章站 2022-06-16 23:10:06
...

一、模型保存

为了更好地保存和加载我们已经训练好的模型,TensorFlow使用tf.train.Saver类和checkpoint的机制去实现这一过程,

什么是checkpoints?

        是用于存储变量的二进制文件,在其内部使用“ 字典结构 ”存储变量,键 即变量的名字,值 为变量的tensor值。

其中Saver类的定义如下所示:

class Saver(object):
    def __init__(self,
               var_list=None,                          #要保存的变量列表
               reshape=False,                          #加载时是否恢复变量形状
               sharded=False,
               max_to_keep=5,                          #最大保留几个checkpoint点
               keep_checkpoint_every_n_hours=10000.0,  #隔多长时间保留一个checkpoint
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None)

1、保存的操作步骤:

saver=tf.train.Saver()    #创建saver对象

save_path=saver.save(sess,"titanic_model_saver/titanic_model.ckpt")

print(f"模型的保存路径为{save_path}")

注意事项:

(1)Saver对象在初始化的时候如果没有指定需要存储的变量列表,默认只会自动收集saver定义之前的所有变量,在saver初始化后面的相关变量则不会保存下来。也可以指定保存一些指定的变量:

    saver = tf.train.Saver([w1,w2])   #只保存 w1 w2这两个变量

(2)模型保存之后,会在相应的文件夹之下生成4个文件。

.ckpt文件:该文件是真实存储变量及其变量值的文件

.ckpt.meta文件:它是一个描述文件,在这个文件存储的是MetaGraphDef结构的对象经过二进制序列化之后的内容。             MetaGraphDef结构是由Protocol buffer定义的,其中包含了整个计算图的描述、各个变量的定义和声明、输入管道的形式、以及其他的一些信息

.ckpt.index文件:存储变量在checkpoints文件中的位置索引

checkpoint文件:最后还有一个名称为checkpoint的文件,这个是文件中存储了最新存档的文件路径。

2、相关的参数设置

前面的Saver没有添加任何参数,这样的模型存储,只会讲模型最终训练的数据存储起来,即存储最终的“ 稳定 ”的模型。除此之外,还可以引入“ 迭代计数器 ”的方式,即按照训练迭代轮次进行存储。即如下所示:

      saver.save(sess,'my_model.ckpt', global_step=step)

这里的global_step和记录日志里面的  writer.add_summary(summary,global_step=step)中的是一个意思,然后会在自动生成的带有测试的轮次和版本号的checkpoint文件。

但是因为每一次迭代记录都会生成一组checkpoint,那么迭代成千上万次之后的训练后会占用大量的磁盘空间,为了防范这种情况,Saver类中的构造函数会有两个参数进行设置。

参数一:max_to_keep :此参数指定存储操作以更迭的方式只会保留最后的5个版本的checkpoint

参数二:keep_checkpoint_every_n_hour: 这种方式以时间为单位,每n个小时存储一个checkpoint,该参数的默认值是10000,                 即10000个小时记录一个checkpoint。

当然我们可以在自己定义Saver对象的时候修改这两个参数的值,但是我们一般不推荐这样去做。

二、模型的恢复和加载

1、模型加载的方式

with tf.Session() as sess:   #第一步:构造会话对象

       new_saver=tf.train.import_meta_graph('titanic_model_saver/titanic_model.ckpt.meta') #第二步:导入模型的图结构

       new_saver.restore(sess,'titanic_model_saver/titanic_model.ckpt')  #第三步:将这个会话绑定到导入的图中

       #new_saver.restore(sess,tf.train.latest_checkpoint('mymodel'))    #第三步也可以是这样操作,因为会从mymodel文件夹中获取checkpoint,而checkpoint中存储了最新存档的文件路径

2、获取模型内保存的变量以及相关的参数

      (1)直接获取

               w1=sess.graph.get_xxxxx()        #直接获取,因为sess已经和加载的图进行了绑定

      (2)先创建所获取的图对象

              graph=tf.get_default_graph()   #获取该session所绑定的默认图

              w1=graph.get_xxxxxxx()

3、获取模型中tensor的值

      (1)方法一

               w2=sess.run('w2:0')            #"name:index"的形式

               print(w2)

      (2)方法二

              w2=sess.graph.get_operation_by_name('w2:0')  

              print(sess.run(w2)

由此可见,第一种方法更加简单,获取张量的时候,一定要写成“ name:index ”的形式。