【AI模型部署】基于flask的pytorch简单分类模型部署
程序员文章站
2022-03-27 20:13:39
【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask【官方教程】https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html【官方代码】https://github.com/avinassh/pytorch-flask-api(本博客基于此代码有改动)【flask文档】https://flask.palletsprojects.com/en/1.1.x...
- 【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask
- 【官方教程】https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html
- 【官方代码】https://github.com/avinassh/pytorch-flask-api(本博客基于此代码有改动)
- 【flask文档】https://flask.palletsprojects.com/en/1.1.x/
一.简介
通过flask框架,部署pytorch模型后,可以通过不同主机向服务端主机发送图像路径或图片请求服务,处理后返回结果。本文的返回结果是,对图片的分类结果】
1.1服务端
部署成功后,服务端接收不同主机请求的过程图如下:
服务端本地图片信息 :
1.2返回结果
其他主机的浏览器向服务端传图片路径(图片在服务端本机)
其他主机直接传送图片到服务端(图片在客户端)
二、实现过程
2.1测试flask服务
【参考】https://flask.palletsprojects.com/en/1.1.x/quickstart/#a-minimal-application
安装
pip install flask
测试程序:hello.py
from flask import Flask
app = Flask(__name__)
@app.route('/')
def predict():
return "hello world! It is flask!"
if __name__ == '__main__':
app.run()
运行:
python hello.py
结果
(本地浏览器访问:http://127.0.0.1:5000/)
2.2 pytorch模型
代码
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__) # 固定写法
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
#'/predict'是会影响请求的格式,可*改名。
# 需要添加“get”方法,才能直接通过浏览器发送请求
# 请求的路径path是图片的路径,一般是在服务端本机
# 浏览器输入实例,请换自己的ip和路径:http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg
@app.route('/predict', methods=['GET', 'POST'])
def predict():
if request.method == 'POST': # 接收传输的图片
file = request.files['file']
# zxy add for GET
else:
image_file = request.args.get("path") #接收其他客户端浏览器发送的请求
file = open(image_file, 'rb')
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
# app.run() # 原工程的写法,默认只能本机访问
app.run(host='0.0.0.0', port=5005) # 使其他主机可以访问服务
外部主机请求服务(需修改代码,指定ip等)
#从外部主机发送图片到服务器,并接收返回结果
curl -X POST -F file=@2.jpg http://192.168.1.139:5005/predict
# 从浏览器发出请求,图片在服务端本地
http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg
结果
如博文简介部分所示
三、完整工程
【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask
本文地址:https://blog.csdn.net/imwaters/article/details/109264716
上一篇: 5年前端竟败在了CSS面试上
下一篇: 最长回文子串-java版