Tensorflow导出pb模型,并在python和matlab下分别进行预测
程序员文章站
2022-03-12 12:37:30
tensorflow下训练完模型测试程序比较杂乱,特此整理一下。1、我都是在linux下训练,windows下测试分析,训练保存模型如下所示。2、然后调用frozen_model.py将模型进行固化...
tensorflow下训练完模型测试程序比较杂乱,特此整理一下。
1、我都是在linux下训练,windows下调用测试,训练保存模型如下所示。
2、然后调用frozen_model.py将模型进行固化,这里需要注意一点就是网络输出结点的名称,可以在tensorboard查看GRAPHS中网络输出结点名或训练时进行命名。
frozen_model.py
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
# 原模型中输出节点名称
output_node_names = "generator/decoder_1/output_node"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
with tf.Session() as sess:
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
tf.reset_default_graph()
input_checkpoint = "tensorflow_model/spot_train/model-500"
output_graph = "frozen_model/frozen_model.pb"
freeze_graph(input_checkpoint,output_graph)
3、得到pb模型后,调用test.py进行测试。需要注意输入输出tensor名字一定要写对,一般结点名字后面加":0"就是对应tensor名,可以在这个网站打开pb模型查看tensor名
https://lutzroeder.github.io/netron/
test.py
#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io
from tensorflow.python.platform import gfile
tf.reset_default_graph()
pb_file_path = 'model_package/frozen_model/'
result_file_path = 'test_results/'
def preprocess(x):
Max = np.max(x)
Min = np.min(x)
x = (x-Min)/(Max-Min)
return x*2-1
def deprocess(x):
return (x+1)/2
data = scipy.io.loadmat('1.mat')['data']
flatten_img = preprocess(np.reshape(data, [1, 256,256,1]))
sess = tf.Session()
with gfile.FastGFile(pb_file_path + 'frozen_model.pb', 'rb') as f: #加载模型
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') # 导入计算图
# 初始化
#sess.run(tf.global_variables_initializer())
x = sess.graph.get_tensor_by_name('batch:1')
y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')
y_out=sess.run(y,feed_dict={x:flatten_img})
scipy.io.savemat(result_file_path+'test.mat', {'output':y_out})
后面为方便matlab调用,又整理成类了
python_test.py
#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io
from tensorflow.python.platform import gfile
from glob import glob
#pb_file_path = 'model_package/frozen_model/'
class Predict(object):
def __init__(self):
tf.reset_default_graph()
self.sess = tf.Session()
with gfile.FastGFile('frozen_model/frozen_model.pb', 'rb') as f: #加载模型
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') #导入计算图
print('load model success')
#sess.run(tf.global_variables_initializer()) #初始化
self.x = sess.graph.get_tensor_by_name('batch:1')
self.y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')
def preprocess(x):
Max = np.max(x)
Min = np.min(x)
x = (x-Min)/(Max-Min)
return x*2-1
def deprocess(x):
return (x+1)/2
def predict(self,input_path):
#加载测试输出数据
img= scipy.io.loadmat(input_path)['data']
flatten_img = preprocess(np.reshape(img, [1, 256,256,1]))
y_out = sess.run(self.y,feed_dict={self.x:flatten_img})
y_out = np.squeeze(deprocess(y_out))
scipy.io.savemat('test_results/'+input_path[17:], {'output':y_out})
if __name__ == '__main__':
model = Predict()
model.predict("1.mat")
4、Matlab中测试采用的是调用python测试脚本还实现的。
clear;close all;clc
clear classes
tf = py.importlib.import_module('tensorflow');
np = py.importlib.import_module('numpy');
%plt = py.importlib.import_module('matplotlib.pyplot');
sio = py.importlib.import_module('scipy.io');
obj = py.importlib.import_module('python_test'); %python测试脚本路径
py.importlib.reload(obj);
a = py.pix2pix_test.Predict();
a.predict('test_data/2.bmp')
a.predict('test_data/3.bmp')
写完回头看了一眼,咋这么混乱都没说清也就自己能看懂了,先这样吧后面要提高一下写作能力了,再接再厉,加油!
本文地址:https://blog.csdn.net/megaoliyuanzhende/article/details/107365281