欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

flask搭建Keras服务出现的问题解决办法

程序员文章站 2024-01-29 23:37:10
...

当使用Keras训练好了一个识别模型后,如果采用线上部署为服务,一般情况下采用flask或者Django进行服务搭建。在我电脑上使用的环境是Keras 2.3.1、tensorflow 1.15.3这个版本。将手写数字的识别模型使用flask部署为服务。代码如下:

from flask import Flask
from flask import request
import numpy as np
import keras
from keras import models
import tensorflow as tf
import cv2

#将network定义为全局变量
network=models.load_model('./model/mnist_cnn.h5')

app=Flask(__name__)

@app.route('/')
def index():
    return 'welcome to visit huanhuncao server'

@app.route('/post',methods=['GET','POST'])
def lx_post():
    if request.method=='POST':
        test_img=cv2.imread('3.jpg',0)
        test_img=test_img.reshape(28,28,1)
        test_img=test_img.reshape((1,)+test_img.shape)
        test_img=test_img.astype('float32')/255
        output=network.predict(test_img)
        output=output.argmax(axis=1)         
        output=str(output)
        return output

if __name__=='__main__':
    app.run(host='172.24.103.157',port=6001)

我们分析如上的代码,首先在全局变量里就将模型文件载入进去,然后一旦接收到post请求,就将结果返回,从程序逻辑上看似乎没什么问题。运行这个服务,进行测试下:
发现服务会报错:
flask搭建Keras服务出现的问题解决办法
错误提示:

ValueError: Tensor Tensor("dense_2/Softmax:0", shape=(?, 10), dtype=float32) is not an element of this graph.

错误的原因在于Keras使用tensorflow作为后端时,tensorflow的操作都是默认加载在一个默认的Graph中,所以如果为了避免出错,自己就要创建Graph以及Session。
针对这个问题进行修改,修改后的代码如下:

from flask import Flask
from flask import request
import numpy as np
import keras
from keras import models
import tensorflow as tf
import cv2

#将network定义为全局变量
global sess,graph
#tf2.x中为sess = tf.compat.v1.keras.backend.get_session()
sess=keras.backend.get_session()
graph=tf.get_default_graph()

network=models.load_model('./model/mnist_cnn.h5')

app=Flask(__name__)

@app.route('/')
def index():
    return 'welcome to visit huanhuncao server'

@app.route('/post',methods=['GET','POST'])
def lx_post():
    if request.method=='POST':
        test_img=cv2.imread('3.jpg',0)
        test_img=test_img.reshape(28,28,1)
        test_img=test_img.reshape((1,)+test_img.shape)
        test_img=test_img.astype('float32')/255
        #在默认会话与计算图中进行模型的预测
        with sess.as_default():
            with graph.as_default():
                output=network.predict(test_img)
        output=output.argmax(axis=1)         
        output=str(output)
        return output

if __name__=='__main__':
    app.run(host='172.24.103.157',port=6001)

运行这个服务,可以看到已经成功了。
flask搭建Keras服务出现的问题解决办法
在使用flask进行Keras模型预测的时候,还有一种错误会出现,比如使用Keras 2.2和tensorflow 1.15会出现一种线程错误,这种解决办法是将flask以单线程进行运行,网上很多说法是将Keras和tensorflow的版本进行降级,在我看来是不需要这种做法的。

app.run(host='172.24.103.157',port=6001,threaded=False)

其实只要tensorflow和Keras的版本一一对应上,是不会出现这个问题的。Keras与tensorflow的版本对应从这个网页里可以查看。