欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Tensorflow导出pb模型,并在python和matlab下分别进行预测

程序员文章站 2022-06-21 16:46:05
tensorflow下训练完模型测试程序比较杂乱,特此整理一下。1、我都是在linux下训练,windows下测试分析,训练保存模型如下所示。2、然后调用frozen_model.py将模型进行固化...

tensorflow下训练完模型测试程序比较杂乱,特此整理一下。

1、我都是在linux下训练,windows下调用测试,训练保存模型如下所示。

Tensorflow导出pb模型,并在python和matlab下分别进行预测

2、然后调用frozen_model.py将模型进行固化,这里需要注意一点就是网络输出结点的名称,可以在tensorboard查看GRAPHS中网络输出结点名或训练时进行命名。

frozen_model.py

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
    # 原模型中输出节点名称
    output_node_names = "generator/decoder_1/output_node"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() 
    input_graph_def = graph.as_graph_def() 
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) 
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))
 
        with tf.gfile.GFile(output_graph, "wb") as f: 
            f.write(output_graph_def.SerializeToString()) 
        print("%d ops in the final graph." % len(output_graph_def.node)) 
        

tf.reset_default_graph()
input_checkpoint = "tensorflow_model/spot_train/model-500"
output_graph = "frozen_model/frozen_model.pb"
freeze_graph(input_checkpoint,output_graph)

3、得到pb模型后,调用test.py进行测试。需要注意输入输出tensor名字一定要写对,一般结点名字后面加":0"就是对应tensor名,可以在这个网站打开pb模型查看tensor名

https://lutzroeder.github.io/netron/

Tensorflow导出pb模型,并在python和matlab下分别进行预测

test.py

#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io 
from tensorflow.python.platform import gfile
tf.reset_default_graph()

pb_file_path = 'model_package/frozen_model/'
result_file_path = 'test_results/'

def preprocess(x):   
    Max = np.max(x)
    Min = np.min(x)
    x = (x-Min)/(Max-Min)
    return x*2-1

def deprocess(x):
    return (x+1)/2

data = scipy.io.loadmat('1.mat')['data']
flatten_img = preprocess(np.reshape(data, [1, 256,256,1]))   

sess = tf.Session()
with gfile.FastGFile(pb_file_path + 'frozen_model.pb', 'rb') as f: #加载模型
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')  # 导入计算图

# 初始化
#sess.run(tf.global_variables_initializer())

x = sess.graph.get_tensor_by_name('batch:1')  
y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')  
                                                       
y_out=sess.run(y,feed_dict={x:flatten_img})
scipy.io.savemat(result_file_path+'test.mat', {'output':y_out})

后面为方便matlab调用,又整理成类了

python_test.py

#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io 
from tensorflow.python.platform import gfile
from glob import glob

#pb_file_path = 'model_package/frozen_model/'

class Predict(object):
    
    def __init__(self):
        tf.reset_default_graph()    
        self.sess = tf.Session()
        with gfile.FastGFile('frozen_model/frozen_model.pb', 'rb') as f: #加载模型
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')  #导入计算图
        print('load model success')
        #sess.run(tf.global_variables_initializer()) #初始化
        self.x = sess.graph.get_tensor_by_name('batch:1')  
        self.y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0') 
        
    def preprocess(x):   
        Max = np.max(x)
        Min = np.min(x)
        x = (x-Min)/(Max-Min)
        return x*2-1

    def deprocess(x):
        return (x+1)/2   
    
    def predict(self,input_path):
        #加载测试输出数据
        img= scipy.io.loadmat(input_path)['data'] 
        flatten_img = preprocess(np.reshape(img, [1, 256,256,1]))  
        
        y_out = sess.run(self.y,feed_dict={self.x:flatten_img})
        y_out = np.squeeze(deprocess(y_out))
        scipy.io.savemat('test_results/'+input_path[17:], {'output':y_out})

if __name__ == '__main__':
    model = Predict()
    model.predict("1.mat")

4、Matlab中测试采用的是调用python测试脚本还实现的。

clear;close all;clc
clear classes

tf = py.importlib.import_module('tensorflow');
np = py.importlib.import_module('numpy');
%plt = py.importlib.import_module('matplotlib.pyplot');
sio = py.importlib.import_module('scipy.io');

obj = py.importlib.import_module('python_test'); %python测试脚本路径
py.importlib.reload(obj);
a = py.pix2pix_test.Predict();
a.predict('test_data/2.bmp')
a.predict('test_data/3.bmp')

写完回头看了一眼,咋这么混乱都没说清也就自己能看懂了,先这样吧后面要提高一下写作能力了,再接再厉,加油!

本文地址:https://blog.csdn.net/megaoliyuanzhende/article/details/107365281