欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

【AI模型部署】基于flask的pytorch简单分类模型部署

程序员文章站 2022-03-27 20:13:39
【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask【官方教程】https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html【官方代码】https://github.com/avinassh/pytorch-flask-api(本博客基于此代码有改动)【flask文档】https://flask.palletsprojects.com/en/1.1.x...

一.简介

通过flask框架,部署pytorch模型后,可以通过不同主机向服务端主机发送图像路径或图片请求服务,处理后返回结果。本文的返回结果是,对图片的分类结果】

1.1服务端

部署成功后,服务端接收不同主机请求的过程图如下:

【AI模型部署】基于flask的pytorch简单分类模型部署

服务端本地图片信息 :

 【AI模型部署】基于flask的pytorch简单分类模型部署

1.2返回结果

其他主机的浏览器向服务端传图片路径(图片在服务端本机)

【AI模型部署】基于flask的pytorch简单分类模型部署

其他主机直接传送图片到服务端(图片在客户端)

【AI模型部署】基于flask的pytorch简单分类模型部署

二、实现过程

2.1测试flask服务

【参考】https://flask.palletsprojects.com/en/1.1.x/quickstart/#a-minimal-application

安装

pip install flask

测试程序:hello.py 

from flask import Flask 
app = Flask(__name__)

@app.route('/')
def predict():
	return "hello world!  It is flask!"
	
if __name__ == '__main__':
    app.run()

运行:

python hello.py

 结果

(本地浏览器访问:http://127.0.0.1:5000/

【AI模型部署】基于flask的pytorch简单分类模型部署

2.2 pytorch模型

代码

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)  # 固定写法
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

#'/predict'是会影响请求的格式,可*改名。
# 需要添加“get”方法,才能直接通过浏览器发送请求
# 请求的路径path是图片的路径,一般是在服务端本机
# 浏览器输入实例,请换自己的ip和路径:http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':  # 接收传输的图片
        file = request.files['file']
    # zxy add for GET
    else:
        image_file = request.args.get("path") #接收其他客户端浏览器发送的请求
        file = open(image_file, 'rb')
    img_bytes = file.read()
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    # app.run() # 原工程的写法,默认只能本机访问
    app.run(host='0.0.0.0', port=5005)  # 使其他主机可以访问服务

 外部主机请求服务(需修改代码,指定ip等)

#从外部主机发送图片到服务器,并接收返回结果

curl -X POST -F file=@2.jpg http://192.168.1.139:5005/predict

# 从浏览器发出请求,图片在服务端本地

http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg

结果

如博文简介部分所示

 

三、完整工程

【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask

 

 

本文地址:https://blog.csdn.net/imwaters/article/details/109264716