flask搭建Keras服务出现的问题解决办法
当使用Keras训练好了一个识别模型后,如果采用线上部署为服务,一般情况下采用flask或者Django进行服务搭建。在我电脑上使用的环境是Keras 2.3.1、tensorflow 1.15.3这个版本。将手写数字的识别模型使用flask部署为服务。代码如下:
from flask import Flask from flask import request import numpy as np import keras from keras import models import tensorflow as tf import cv2 #将network定义为全局变量 network=models.load_model('./model/mnist_cnn.h5') app=Flask(__name__) @app.route('/') def index(): return 'welcome to visit huanhuncao server' @app.route('/post',methods=['GET','POST']) def lx_post(): if request.method=='POST': test_img=cv2.imread('3.jpg',0) test_img=test_img.reshape(28,28,1) test_img=test_img.reshape((1,)+test_img.shape) test_img=test_img.astype('float32')/255 output=network.predict(test_img) output=output.argmax(axis=1) output=str(output) return output if __name__=='__main__': app.run(host='172.24.103.157',port=6001)
我们分析如上的代码,首先在全局变量里就将模型文件载入进去,然后一旦接收到post请求,就将结果返回,从程序逻辑上看似乎没什么问题。运行这个服务,进行测试下:
发现服务会报错:
错误提示:
ValueError: Tensor Tensor("dense_2/Softmax:0", shape=(?, 10), dtype=float32) is not an element of this graph.
错误的原因在于Keras使用tensorflow作为后端时,tensorflow的操作都是默认加载在一个默认的Graph中,所以如果为了避免出错,自己就要创建Graph以及Session。
针对这个问题进行修改,修改后的代码如下:
from flask import Flask from flask import request import numpy as np import keras from keras import models import tensorflow as tf import cv2 #将network定义为全局变量 global sess,graph #tf2.x中为sess = tf.compat.v1.keras.backend.get_session() sess=keras.backend.get_session() graph=tf.get_default_graph() network=models.load_model('./model/mnist_cnn.h5') app=Flask(__name__) @app.route('/') def index(): return 'welcome to visit huanhuncao server' @app.route('/post',methods=['GET','POST']) def lx_post(): if request.method=='POST': test_img=cv2.imread('3.jpg',0) test_img=test_img.reshape(28,28,1) test_img=test_img.reshape((1,)+test_img.shape) test_img=test_img.astype('float32')/255 #在默认会话与计算图中进行模型的预测 with sess.as_default(): with graph.as_default(): output=network.predict(test_img) output=output.argmax(axis=1) output=str(output) return output if __name__=='__main__': app.run(host='172.24.103.157',port=6001)
运行这个服务,可以看到已经成功了。
在使用flask进行Keras模型预测的时候,还有一种错误会出现,比如使用Keras 2.2和tensorflow 1.15会出现一种线程错误,这种解决办法是将flask以单线程进行运行,网上很多说法是将Keras和tensorflow的版本进行降级,在我看来是不需要这种做法的。
app.run(host='172.24.103.157',port=6001,threaded=False)
其实只要tensorflow和Keras的版本一一对应上,是不会出现这个问题的。Keras与tensorflow的版本对应从这个网页里可以查看。
本文地址:https://blog.csdn.net/qq_37781464/article/details/108842854
上一篇: Android补间、逐帧动画
下一篇: mysql回表致索引失效案例讲解