欢迎您访问程序员文章站,本站旨在为大家分享程序员计算机编程知识!
您现在的位置是: 首页

flask工程部署

程序员文章站 2022-06-13 15:19:16
...

1. 本次博客记录了 Flask 的部署方案,仅作参考记录。

#-*- coding=utf-8 -*-
"""Http-Server for the ASR demo."""
import os
import io
import time
import wave
import shutil
import argparse
import functools
import pydub
# import numpy as np
import urllib
import urllib2
import thread
import hashlib
import sys
# Python 2 hack: reload() exposes setdefaultencoding so implicit
# str<->unicode coercion uses UTF-8 instead of ASCII.
reload(sys)
sys.setdefaultencoding('utf-8')
import _init_paths
import threading
import traceback
import paddle.v2 as paddle
import subprocess
import json
import math
from flask_cors import CORS
from Queue import Queue
# from werkzeug.utils import secure_filename
from flask import Flask, request, jsonify,render_template
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.utility import add_arguments, print_arguments
import logging
# Log-line layout used by the file handler configured further below.
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Command-line configuration: each add_arg call registers one option
# (name, type, default, help text) on the shared parser via the project's
# add_arguments helper.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Server and decoder tuning knobs.
add_arg('host_port',        int,    5000,    "Server's IP port.")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    1024,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  2.5,   "Coef of LM for beam search.")
add_arg('beta',             float,  0.3,   "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
add_arg('use_gru',          bool,   True,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
# Paths to the model, vocabulary, normalizer stats and language model.
add_arg('host_ip',          str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir',  str,
        'data',
        "Directory to save demo audios.")
add_arg('mean_std_path',    str,
        '../models/Hoge_ASR/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        '../models/Hoge_ASR/vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path',       str,
        '../checkpoints/Hoge_ASR/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path',  str,
        '../models/lm/zh_giga.no_cna_cmn.prune01244.klm',
        "Filepath for language model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices = ['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: disable
# Parsed at import time: args is read by every handler below.
args = parser.parse_args()

# def warm_up_test(audio_process_handler,au_path):
#     """recognition."""
#     start_time = time.time()
#     transcript = audio_process_handler(au_path)
#     finish_time = time.time()
#     print("Response Time: %f, Transcript: %s" %(finish_time - start_time, transcript))


def start_transcript():
    """Build the ASR inference pipeline and return the transcription handler.

    Creates the data generator and the DeepSpeech2 model from the parsed
    command-line ``args``, initializes the external LM scorer when beam
    search is selected, and returns a closure that maps an audio file path
    to its transcript.

    Returns:
        callable: ``file_to_transcript(filename) -> transcript``.
    """
    # prepare data generator (no augmentation at inference time)
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1,
        keep_transcription_text=True)
    # prepare ASR model from the pre-trained checkpoint
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # the decoders expect the vocabulary as UTF-8 byte strings
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    # the external LM scorer is only used by beam-search decoding
    if args.decoding_method == "ctc_beam_search":
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)

    def file_to_transcript(filename):
        """Transcribe one audio file and return the best hypothesis."""
        feature = data_generator.process_utterance(filename, "")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=[feature],
            feeding_dict=data_generator.feeding)
        if args.decoding_method == "ctc_greedy":
            result_transcript = ds2_model.decode_batch_greedy(
                probs_split=probs_split,
                vocab_list=vocab_list)
        else:
            result_transcript = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=args.alpha,
                beam_beta=args.beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=1)
        return result_transcript[0]

    # return the closure directly (the original aliased it needlessly)
    return file_to_transcript
# Start the Flask web application.
app = Flask(__name__)

# File logging: everything at DEBUG and above goes to ai_asr.log,
# formatted per LOG_FORMAT, attached to the root logger.
handler = logging.FileHandler('ai_asr.log', encoding='UTF-8')
handler.setLevel(logging.DEBUG)
logging_format = logging.Formatter(LOG_FORMAT)
handler.setFormatter(logging_format)
logging.getLogger().addHandler(handler)

CORS(app)
# Main job queue: request threads push (data, reply_queue) pairs that the
# single inference worker thread consumes.
sendQ = Queue()
# Root path: serve the landing page.
@app.route('/')
def index():
    """Render the demo front page."""
    page = render_template('index.html')
    return page

# Build the standard failure response body.
def errorResp(msg):
    """Wrap *msg* in the JSON error envelope (code -1)."""
    body = jsonify(code=-1, message=msg)
    return body

# Build the standard success response body.
def successResp(data):
    """Wrap *data* in the JSON success envelope (code 0)."""
    body = jsonify(code=0, message="success", data=data)
    return body

def errorCallBack(url):
    """Build a failure notice for the caller's callback URL.

    NOTE(review): the actual HTTP POST is commented out in the original
    code, so this currently only builds the payload and returns None.

    Args:
        url: callback URL, e.g. "https://10.0.1.111:6100".
    """
    msg = {u'code': -1, u'status': u'fail'}
    # TODO: actually deliver the failure notice, e.g.:
    # r = requests.post(url=url, data=json.dumps(msg))
    # logging.info(r)
    return None

def successCallBack(url, task_id, param=None):
    """POST the transcription result to the callback URL.

    Args:
        url: callback URL to POST to.
        task_id: identifier of the completed task; added to the payload.
        param: optional dict with result data. A copy is posted so the
            caller's dict is never mutated (the original mutated it and
            crashed with the default ``param=None``).

    Returns:
        (response body, task id) as strings.
    """
    import requests  # local import: server can start without requests installed
    payload = dict(param) if param else {}
    payload['task_id'] = task_id
    resp = requests.post(url, data=json.dumps(payload))
    return str(resp.text), str(task_id)

# Simple echo endpoint used to exercise the callback mechanism.
@app.route('/callback_test',methods=['POST'])
def callback_test():
    """Parse the POSTed JSON body and echo it back as text."""
    payload = json.loads(request.get_data())
    return str(payload)

def split_audio(filepath):
    """Split an audio file into roughly one-minute mono 16 kHz WAV fragments.

    Fragments are written next to *filepath* and named
    "<start_min>:<start_sec>_<end_min>:<end_sec>_.wav" so the caller can
    recover each fragment's time range from its filename.

    Args:
        filepath: path of the source audio file (any format pydub reads).

    Returns:
        (success, audio_time, audio_fragments): success is 1/0, audio_time
        the total duration in seconds, audio_fragments the fragment paths
        written so far (possibly partial when success == 0).
    """
    audio_fragments = []
    audio_time = 0.0
    try:
        tmp_dir = os.path.dirname(filepath)
        sound = pydub.AudioSegment.from_file(filepath)
        audio_len = len(sound)  # pydub length is in milliseconds
        audio_time = round(audio_len / 1000.0, 3)
        # extra fragments beyond the first: one per full minute of audio
        # (// keeps this an int under both Python 2 and 3; plain / would
        # produce a float in Python 3 and break range() below)
        numbers = audio_len // 60000 if audio_len > 60000 else 0
        fragment_audio_len = int(math.ceil(1.0 * audio_len / (numbers + 1)))
        # normalize to 16 kHz, mono, 16-bit samples
        mono = sound.set_frame_rate(16000).set_channels(1).set_sample_width(2)

        for x in range(numbers + 1):
            start_ms = x * fragment_audio_len
            end_ms = (x + 1) * fragment_audio_len
            audio_fragment_filename = "%d:%.3f_%d:%.3f_.wav" % (
                start_ms // 60000, round((start_ms % 60000) / 1000.0, 3),
                end_ms // 60000, round((end_ms % 60000) / 1000.0, 3))
            filename = os.path.join(tmp_dir, audio_fragment_filename)
            audio_fragments.append(filename)
            # +20 extends each slice by 20 ms past the nominal boundary —
            # presumably to avoid clipping a word at the cut; confirm intent
            mono[start_ms:end_ms + 20].export(filename, format='wav')
        success = 1
    except Exception:
        # log the failure instead of silently swallowing it (the original
        # used a bare except: that hid every error, including typos)
        logging.info(traceback.format_exc())
        success = 0
    return success, audio_time, audio_fragments

def s_translate(src_path, task_id, callback_url):
    """Convert *src_path* to WAV, transcribe it fragment by fragment, and
    report the result to *callback_url*.

    Runs in a background thread; inference requests are serialized through
    the module-level sendQ queue and answered on a private reply queue.

    Args:
        src_path: path of the uploaded source audio file.
        task_id: identifier returned to the client when the job was accepted.
        callback_url: URL that receives the success/failure notification.
    """
    success = 0
    start_time = time.time()
    wav_temp = u'/storage/AI_server/tmp/files_wav/'
    # NOTE(review): the temp dir is wiped on every call, so concurrent
    # requests would clobber each other's files — confirm single-job use.
    if os.path.exists(wav_temp):
        shutil.rmtree(wav_temp)
    os.makedirs(wav_temp)
    wav_filename = wav_temp + os.path.basename(src_path)[:-4] + u'.wav'
    # Pass an argument list with shell=False so the externally supplied
    # path cannot be abused for shell injection (the original built a
    # shell command string with shell=True).
    cmd = [u'ffmpeg', u'-i', src_path, u'-f', u'wav', u'-ar', u'16k',
           u'-ac', u'2', u'-y', wav_filename]
    try:
        return_code = subprocess.call(cmd)
        if return_code:
            logging.info("Your file format is failed to transform!")
            success = 0
    except Exception:
        logging.info("Your file format does not meet the requirements!")
        success = 0
    recv_queue = Queue()
    resps = {'text': [], 'audioTime': 0}
    # wav_filename already lives in wav_temp; no need to rebuild the path
    _result, audioTime, fragment_files = split_audio(wav_filename)
    resps['audioTime'] = audioTime

    if _result:
        for file_t in fragment_files:
            try:
                # fragment names encode their range: "<m>:<s>_<m>:<s>_.wav"
                name_parts = os.path.basename(file_t).split(u"_")
                stat0 = name_parts[0].split(u":")
                stat1 = name_parts[1].split(u":")
                # hand the fragment to the worker thread and wait for it
                sendQ.put(([(file_t,)], recv_queue))
                success, text = recv_queue.get()
                r_time = stat0[0] + ":" + stat0[1] + "," + stat1[0] + ":" + stat1[1]
                resps['text'].append(([r_time], text.decode("utf8", "ignore")))
                logging.info("Received utterance result: %s, Time: %s." % (text, r_time))
                success = 1
            except Exception:
                # keep going on a bad fragment, but leave a trace in the log
                # (the original bare except hid every failure)
                logging.info(traceback.format_exc())
                success = 0
    finish_time = time.time()
    logging.info("Response Time: %f" % (finish_time - start_time))
    if success:
        return successCallBack(callback_url, task_id, resps)
    else:
        return errorCallBack(callback_url)

@app.route('/translate',methods=['POST'])
def translate():
    """Accept a transcription job and run it asynchronously.

    Expects a JSON body with "filepath" and "callback_url". Responds
    immediately with a task id (md5 of the file path); the transcript is
    delivered later through the callback URL by s_translate.
    """
    body = json.loads(request.get_data())
    src_path = body["filepath"]
    callback_url = body["callback_url"]
    # task id = md5 of the source path; encode first so non-ASCII paths
    # hash correctly (md5 of a unicode string fails on non-ASCII in Py2)
    task_id = hashlib.md5(src_path.encode('utf-8')).hexdigest()
    try:
        # use threading (already imported) instead of the deprecated
        # Python 2 `thread` module
        t = threading.Thread(target=s_translate,
                             args=(src_path, task_id, callback_url))
        t.daemon = True
        t.start()
        success = 1
    except Exception:
        logging.error("Error: unable to start thread")
        success = 0
    if success:
        return successResp(task_id)
    else:
        return errorResp("fail")

# Single PaddlePaddle inference thread: all requests are serialized
# through sendQ, and each answer goes back on the request's own queue.
def worker(s):
    """Consume (data, reply_queue) jobs from sendQ forever.

    Args:
        s: the file-to-transcript handler built by start_transcript();
           jobs carry data shaped [(filepath,)].

    On success the reply queue receives (True, transcript); on failure it
    receives (False, formatted traceback) so the waiting request thread is
    never left blocked.
    """
    while True:
        # fetch the next job and its private reply queue
        filename, recv_queue = sendQ.get()
        try:
            result = s(filename[0][0])
            recv_queue.put((True, result))
        except Exception:
            # `except Exception` (not bare except) so KeyboardInterrupt /
            # SystemExit can still stop the thread; report the failure to
            # the waiting request instead of dying silently
            trace = traceback.format_exc()
            logging.info(trace)
            recv_queue.put((False, trace))
            continue

def main():
    """Initialize PaddlePaddle and build the transcription handler."""
    # show the effective model/server configuration
    print_arguments(args)
    # initialize PaddlePaddle on CPU with trainer_count=2
    paddle.init(use_gpu=False, trainer_count=2)
    # load model parameters + topology into an inference handler
    transcriber = start_transcript()
    logging.info("ASR Server Started.")
    return transcriber

if __name__ == "__main__":
    recognizer = main()
    # run all inference on a single daemon worker thread
    worker_thread = threading.Thread(target=worker, args=(recognizer,))
    worker_thread.daemon = True
    worker_thread.start()
    # NOTE(review): listens on 6100; the parsed host_port arg (5000) is unused
    app.run(host='0.0.0.0', port=6100, threaded=True)