signed

QiShunwang

“诚信为本、客户至上”

【PyTorch Flask】通过 Flask 的 REST API 在 Python 中部署 PyTorch 模型

2021/6/3 13:47:11   来源:

本文将使用 Flask 部署 PyTorch 模型，这是“在生产环境中部署 PyTorch 模型”系列教程的第一篇。以这种方式使用 Flask，是开始为 PyTorch 模型提供服务的最简单的方法。代码和文件（包括 imagenet_class_index.json）下载见 >> GitHub 链接

1、安装环境

pip install Flask==1.0.3
pip install torchvision==0.3.0

2、Rest API 实例

import io
import json

import torch
import torchvision.transforms as transforms
from flask import Flask, jsonify, request
from PIL import Image
from torchvision import models


app = Flask(__name__)


@app.route('/')
def hello():
    """Root endpoint: reply with a fixed plain-text greeting."""
    message = 'Hello World!'
    return message


# Load torchvision's pretrained DenseNet-121 and put it in inference mode.
# Module.eval() returns the module itself, so the call can be chained.
model = models.densenet121(pretrained=True).eval()

# Preprocess an uploaded image into model input.
def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized single-image batch tensor.

    Args:
        image_bytes: raw bytes of an encoded image file (JPEG, PNG, ...).

    Returns:
        A float tensor of shape (1, 3, 224, 224). The image is resized so its
        shorter side is 255, center-cropped to 224x224, and converted to [0, 1].
        It is then normalized with the ImageNet per-channel mean/std.

    Raises:
        OSError: if PIL cannot identify or decode the bytes as an image.
    """
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225]),
    ])
    # Force 3-channel RGB. Grayscale ("L"), palette ("P") or RGBA uploads
    # would otherwise yield 1 or 4 channels, which the 3-channel
    # normalization and the model's input layer cannot handle.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    # Add a leading batch dimension of size 1.
    return my_transforms(image).unsqueeze(0)


# Load the ImageNet class-index mapping used to label predictions.
# Keys are looked up as strings in get_prediction. Judging by the /predict
# handler and the sample output, each value is a [class_id, class_name] pair
# (verify against the downloaded JSON file).
# The with-block closes the file; JSON is UTF-8 regardless of locale.
with open('./imagenet_class_index.json', encoding='utf-8') as _index_file:
    imagenet_class_index = json.load(_index_file)


# Predict the image class.
def get_prediction(image_bytes):
    """Classify raw image bytes and return the top-1 entry of the class index.

    Args:
        image_bytes: encoded image file contents.

    Returns:
        The ``imagenet_class_index`` value for the highest-scoring class.
        The ``/predict`` handler unpacks it as ``(class_id, class_name)``.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: disable autograd so no graph is built per request.
    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    # The JSON index uses string keys, so convert the class index before lookup.
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return its label as JSON.

    Expects a multipart/form-data POST with the image under the form field
    ``file``. On success it responds with ``{"class_id": ..., "class_name": ...}``.

    Responds with HTTP 400 and a JSON ``{"error": ...}`` body when the
    ``file`` field is missing or its contents cannot be decoded as an image.
    """
    # The route only accepts POST, so Flask answers other methods with 405
    # before this function runs; no method check is needed here.
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': "missing form field 'file'"}), 400
    img_bytes = file.read()
    try:
        class_id, class_name = get_prediction(image_bytes=img_bytes)
    except OSError:
        # PIL raises OSError (UnidentifiedImageError in newer Pillow) for
        # empty, truncated, or non-image uploads. Report a client error, not a 500.
        return jsonify({'error': 'uploaded file is not a readable image'}), 400
    return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    # Listen on all interfaces so other machines can reach the API.
    # Note: app.run() starts Flask's development server.
    try:
        app.run(host='0.0.0.0', port=7200)
    except Exception as e:
        # Fail loudly with a non-zero exit status instead of swallowing the
        # error (e.g. when the port is already in use).
        raise SystemExit('Failed to start server: {}'.format(e))

3、请求接口进行测试（客户端示例）

import requests

# Client example: upload one image to the running /predict endpoint.
url = "http://127.0.0.1:7200/predict"
image_path = r'D:\28.jpeg'
# The with-block closes the image file after the upload. The timeout keeps the
# client from hanging forever if the server is unreachable or stuck.
with open(image_path, 'rb') as image_file:
    resp = requests.post(url, files={"file": image_file}, timeout=30)
print(resp.text)

运行结果:

{"class_id":"n02124075","class_name":"Egyptian_cat"}