keras和tensorflow模型同时读取要慎重
程序员文章站
2022-05-26 17:54:51
...
项目中,先加载了一个 Keras 模型(.hdf5)来获取模型的输入尺寸,再加载由该 Keras 模型转换得到的 TensorFlow pb 模型进行预测。
报错:
Attempting to use uninitialized value batch_normalization_14/moving_mean
在论坛上看到有人建议,在创建 Session 后加上全局变量初始化:
sess.run(tf.global_variables_initializer())
但这样做会把所有变量重置为初始值,训练得到的模型参数随之被覆盖,预测结果也就失去意义。
最后发现,只要去掉 Keras 模型的加载(把输入尺寸直接写死为 (64, 64)),问题即可解决。
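猜测原因:Keras(TensorFlow 后端)的 load_model 会把模型节点,以及 BN 层的 moving_mean 等变量,建在 TensorFlow 的默认图中。下面代码里的 tf.Session() 和 tf.import_graph_def(..., name='') 用的也是同一张默认图(sess.graph.as_default() 没有配合 with 使用,实际上不起作用)。因此按名字 get_tensor_by_name("predictions/Softmax:0") 取到的,可能是 Keras 那份依赖未初始化变量的节点,而不是 pb 中已固化为常量的节点。更稳妥的做法是把 pb 导入一张独立的 tf.Graph(),并用 tf.Session(graph=graph) 运行,这样即使同时加载 Keras 模型也不会互相干扰。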
import cv2
import numpy as np
from keras.models import load_model
from utils.datasets import get_labels
from utils.preprocessor import preprocess_input
import time
import os
import tensorflow as tf
from tensorflow.python.platform import gfile
# Hide all GPUs so inference runs on CPU. Set before any tf.Session is created,
# since that is when TensorFlow initialises CUDA devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

emotion_labels = get_labels('fer2013')
# Hard-coded instead of reading emotion_classifier.input_shape[1:3]: loading the
# Keras .hdf5 model in this same process broke the frozen-graph session.
# cv2.resize takes (width, height); identical here because the input is square.
emotion_target_size = (64, 64)

path = '/mnt/nas/cv_data/emotion/test'
FROZEN_MODEL_PATH = "./trans_model/emotion_mode.pb"
INPUT_TENSOR_NAME = 'input_1:0'
OUTPUT_TENSOR_NAME = "predictions/Softmax:0"
# OS metadata files that may sit next to the images and are not images.
SKIPPED_FILES = {'.DS_Store', 'Thumbs.db'}


def _load_frozen_graph(pb_path):
    """Load the frozen GraphDef at *pb_path* into a new, isolated ``tf.Graph``.

    Importing into a fresh graph keeps the frozen model's nodes from colliding
    with nodes that other code (e.g. ``keras.models.load_model``) may already
    have built in TensorFlow's process-wide default graph. Otherwise
    ``get_tensor_by_name`` can resolve to a node that depends on variables this
    session never initialised.

    Returns:
        The ``tf.Graph`` containing the imported model (node names unprefixed).
    """
    graph_def = tf.GraphDef()
    with gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    # as_default() must be entered with `with`; calling it bare is a no-op.
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph


def _prepare_face(bgr_image, target_size):
    """Turn a BGR image into a single-image, single-channel model input batch.

    Converts to grayscale, resizes to *target_size*, applies the project's
    ``preprocess_input``, then adds a leading batch axis and a trailing channel
    axis. Returns None when OpenCV cannot resize the image.
    """
    gray_face = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    try:
        gray_face = cv2.resize(gray_face, target_size)
    except cv2.error:
        return None
    gray_face = preprocess_input(gray_face, True)
    gray_face = np.expand_dims(gray_face, 0)   # batch axis
    return np.expand_dims(gray_face, -1)       # channel axis


image_dir = os.path.abspath(path)
filelist = os.listdir(image_dir)

graph = _load_frozen_graph(FROZEN_MODEL_PATH)
# Resolve tensors once, outside the per-image loop.
input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
pred = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# The frozen graph holds its weights as constants, so no variable
# initialisation is needed (running global_variables_initializer would be
# wrong for a restored model anyway).
with tf.Session(graph=graph) as sess:
    for item in filelist:
        if item in SKIPPED_FILES:
            continue
        src = os.path.join(image_dir, item)
        bgr_image = cv2.imread(src)
        if bgr_image is None:  # cv2.imread returns None for unreadable/non-image files
            print("skip unreadable:", src)
            continue
        gray_face = _prepare_face(bgr_image, emotion_target_size)
        if gray_face is None:
            continue
        res = sess.run(pred, {input_tensor: gray_face})
        print("src:", src)
        emotion_probability = np.max(res[0])
        emotion_label_arg = np.argmax(res[0])
        emotion_text = emotion_labels[emotion_label_arg]
        print("predict:", res[0], ",prob:", emotion_probability,
              ",label:", emotion_label_arg, ",text:", emotion_text)
上一篇: Python virtualenv
下一篇: java对于半角和全角的转换