Tensorflow Save
程序员文章站
2022-05-29 12:11:01
...
保存为四个文件:
my-model.ckpt.meta 保存整个计算图的结构
my-model.ckpt.data-* 保存模型中每个变量的取值
my-model.ckpt.index
checkpoint 记录目录下所有模型文件列表
.ckpt模型 图结构.meta与变量值.ckpt分离
from __future__ import print_function
import tensorflow as tf
import numpy as np
'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''
'''
#**********************************************在一张图、会话中存入 再载入 变量************************************
tf.reset_default_graph()
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
print("weights:", sess.run(W))
print("biases:", sess.run(b))
print(W.name,b.name)
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path: ", save_path)
saver.restore(sess, "my_net/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
print(W.name,b.name)
tf.reset_default_graph() #不然以下的w名称为 weight_1
'''
'''
#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights') #导出张量名为weights:0 计算节点名weights
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path: ", save_path)
#---------------------------------reload-----------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
#1 WW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights2")
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#print(W.name)
# 自己定义图上运算 参数无需初始化 而将值直接按名称加载进来
saver = tf.train.Saver()
#2 saver = tf.train.Saver([W]) # **************************[对应名称的变量的张量] 列表形式只获取部分变量
#1 saver = tf.train.Saver({'weights':WW}) #***********************{‘原名’: }形式 重命名变量 将原名weight的值放入WW中 %名字无需加上 :0 部分 --计算节点
with tf.Session() as sess:
# 提取变量
saver.restore(sess, "my_net/save_net.ckpt")
#print("weights:", sess.run(W))
print("biases:", sess.run(W))
'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
#**********************************************直接加载图而无需重复定义图上的运算***********************************
tf.reset_default_graph()
init = tf.global_variables_initializer()
saver = tf.train.import_meta_graph("C:/Users/Administrator/Desktop/my_net/save_net.ckpt.meta") #加载图 同时 下面导入变量值
with tf.Session() as sess:
sess.run(init)
# 提取变量
saver.restore(sess, "my_net/save_net.ckpt")
#通过张量的名称来获取张量 )#Tensor names must be of the form "<op_name>:<output_index>".
print(sess.run(tf.get_default_graph().get_tensor_by_name('weights:0'))) #%名字需加上 :0 部分 因为是获取张量
.pb模型 freeze的模型,该模型已经是包含图和相应的参数了
import tensorflow as tf
from tensorflow.python.framework import graph_util
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''
#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights') #导出张量名为weights:0 计算节点名weights
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,==输入层到输出层的计算过程
output_graph_def = graph_util.convert_variables_to_constants(sess, #计算图中的变量及其取值通过常量的方式保存于一个文件中
graph_def, ['weights']) ##需要保存【计算节点】的名字 %舍去无用节点 保存该节点下子图及变量值
with tf.gfile.GFile("model/w.pb", 'wb') as f: #通过 tf.gfile.GFile 进行模型持久化
f.write(output_graph_def.SerializeToString()) # 序列化输出
'''
#---------------------------------reload-----------------
from tensorflow.python.platform import gfile
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
with tf.Session() as sess:
model_filename = "Model/combined_model.pb"
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
result = tf.import_graph_def(graph_def, return_elements=['weights:0']) #得输出节点的值--【张量】
print(sess.run(result)) # [array([ 3.], dtype=float32)]
参考:
推荐阅读
-
mobilenetv3-tensorflow实战项目准备和代码调试
-
MAC下安装tensorflow 1.15.0版本
-
windows环境下Tensorflow 2.1.0 打包成exe可执行程序
-
在python下使用tensorflow判断是否存在文件夹的实例
-
定制IE下载对话框的按钮(打开Run/保存Save)
-
windows10下安装TensorFlow Object Detection API的步骤
-
TensorFlow进阶项目实战
-
如何使用C#将Tensorflow训练的.pb文件用在生产环境详解
-
Hibernate save() saveorupdate()的用法第1/2页
-
tensorflow2.0之keras实现卷积神经网络