欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow Save

程序员文章站 2022-05-29 12:11:01
...

保存为四个文件:

my-model.ckpt.meta          保存整个计算图的结构

my-model.ckpt.data-*        保存模型中每个变量的取值

my-model.ckpt.index

checkpoint                          记录目录下所有模型文件列表

Tensorflow Save

.ckpt模型     图结构.meta与变量值.ckpt分离

from __future__ import print_function
import tensorflow as tf
import numpy as np


'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''


'''
#**********************************************在一张图、会话中存入 再载入 变量************************************
tf.reset_default_graph()
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    print(W.name,b.name)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    print(W.name,b.name)

tf.reset_default_graph()  #不然以下的w名称为 weight_1     
'''
'''
#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)



#---------------------------------reload-----------------

tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
#1 WW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights2")
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#print(W.name)
# 自己定义图上运算 参数无需初始化 而将值直接按名称加载进来

saver = tf.train.Saver()
#2 saver = tf.train.Saver([W])              # **************************[对应名称的变量的张量] 列表形式只获取部分变量
#1 saver = tf.train.Saver({'weights':WW})   #***********************{‘原名’: }形式 重命名变量  将原名weight的值放入WW中     %名字无需加上 :0 部分  --计算节点
with tf.Session() as sess:
   
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    #print("weights:", sess.run(W))
    print("biases:", sess.run(W))
    
'''  

'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''

#**********************************************直接加载图而无需重复定义图上的运算***********************************

tf.reset_default_graph() 

init = tf.global_variables_initializer()
saver = tf.train.import_meta_graph("C:/Users/Administrator/Desktop/my_net/save_net.ckpt.meta")  #加载图 同时 下面导入变量值

with tf.Session() as sess:
    sess.run(init)
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    #通过张量的名称来获取张量 )#Tensor names must be of the form "<op_name>:<output_index>".
    print(sess.run(tf.get_default_graph().get_tensor_by_name('weights:0')))                              #%名字需加上 :0 部分   因为是获取张量


.pb模型   freeze的模型,该模型已经是包含图和相应的参数了

import tensorflow as tf
from tensorflow.python.framework import graph_util  


'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''

    
''' 

#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)

init = tf.global_variables_initializer()


with tf.Session() as sess:
    sess.run(init)
    graph_def = tf.get_default_graph().as_graph_def()    #得到当前的图的 GraphDef 部分,==输入层到输出层的计算过程
    output_graph_def = graph_util.convert_variables_to_constants(sess,   #计算图中的变量及其取值通过常量的方式保存于一个文件中  
                                                        graph_def, ['weights'])   ##需要保存【计算节点】的名字    %舍去无用节点 保存该节点下子图及变量值
  
    with tf.gfile.GFile("model/w.pb", 'wb') as f:  #通过 tf.gfile.GFile 进行模型持久化
        f.write(output_graph_def.SerializeToString())   # 序列化输出

'''

#---------------------------------reload-----------------
from tensorflow.python.platform import gfile  

tf.reset_default_graph() #!!!!!!!!!!!!!!!!!

  
with tf.Session() as sess:  
    model_filename = "Model/combined_model.pb"  
    with gfile.FastGFile(model_filename, 'rb') as f:  
        graph_def = tf.GraphDef()  
        graph_def.ParseFromString(f.read())  
  
    result = tf.import_graph_def(graph_def, return_elements=['weights:0'])   #得输出节点的值--【张量】
    print(sess.run(result)) # [array([ 3.], dtype=float32)]  


参考:

http://blog.csdn.net/marsjhao/article/details/72829635  书译

http://blog.csdn.net/michael_yt/article/details/74737489

相关标签: tensorflow save