flask搭建Keras服务出现的问题解决办法
程序员文章站
2024-01-29 23:37:10
...
当使用Keras训练好了一个识别模型后,如果采用线上部署为服务,一般情况下采用flask或者Django进行服务搭建。在我电脑上使用的环境是Keras 2.3.1、tensorflow 1.15.3这个版本。将手写数字的识别模型使用flask部署为服务。代码如下:
from flask import Flask
from flask import request
import numpy as np
import keras
from keras import models
import tensorflow as tf
import cv2
#将network定义为全局变量
network=models.load_model('./model/mnist_cnn.h5')
app=Flask(__name__)
@app.route('/')
def index():
return 'welcome to visit huanhuncao server'
@app.route('/post',methods=['GET','POST'])
def lx_post():
if request.method=='POST':
test_img=cv2.imread('3.jpg',0)
test_img=test_img.reshape(28,28,1)
test_img=test_img.reshape((1,)+test_img.shape)
test_img=test_img.astype('float32')/255
output=network.predict(test_img)
output=output.argmax(axis=1)
output=str(output)
return output
if __name__=='__main__':
app.run(host='172.24.103.157',port=6001)
我们分析如上的代码,首先在全局变量里就将模型文件载入进去,然后一旦接收到post请求,就将结果返回,从程序逻辑上看似乎没什么问题。运行这个服务,进行测试下:
发现服务会报错:
错误提示:
ValueError: Tensor Tensor("dense_2/Softmax:0", shape=(?, 10), dtype=float32) is not an element of this graph.
错误的原因在于Keras使用tensorflow作为后端时,tensorflow的操作都是默认加载在一个默认的Graph中,所以如果为了避免出错,自己就要创建Graph以及Session。
针对这个问题进行修改,修改后的代码如下:
from flask import Flask
from flask import request
import numpy as np
import keras
from keras import models
import tensorflow as tf
import cv2
#将network定义为全局变量
global sess,graph
#tf2.x中为sess = tf.compat.v1.keras.backend.get_session()
sess=keras.backend.get_session()
graph=tf.get_default_graph()
network=models.load_model('./model/mnist_cnn.h5')
app=Flask(__name__)
@app.route('/')
def index():
return 'welcome to visit huanhuncao server'
@app.route('/post',methods=['GET','POST'])
def lx_post():
if request.method=='POST':
test_img=cv2.imread('3.jpg',0)
test_img=test_img.reshape(28,28,1)
test_img=test_img.reshape((1,)+test_img.shape)
test_img=test_img.astype('float32')/255
#在默认会话与计算图中进行模型的预测
with sess.as_default():
with graph.as_default():
output=network.predict(test_img)
output=output.argmax(axis=1)
output=str(output)
return output
if __name__=='__main__':
app.run(host='172.24.103.157',port=6001)
运行这个服务,可以看到已经成功了。
在使用flask进行Keras模型预测的时候,还有一种错误会出现,比如使用Keras 2.2和tensorflow 1.15会出现一种线程错误,这种解决办法是将flask以单线程进行运行,网上很多说法是将Keras和tensorflow的版本进行降级,在我看来是不需要这种做法的。
app.run(host='172.24.103.157',port=6001,threaded=False)
其实只要tensorflow和Keras的版本一一对应上,是不会出现这个问题的。Keras与tensorflow的版本对应从这个网页里可以查看。