Flask project deployment
1. This post records a Flask deployment scheme for a DeepSpeech2 ASR demo server, kept here for reference. The script below exposes a /translate endpoint that accepts an audio file path plus a callback URL, converts and splits the audio with ffmpeg/pydub, runs recognition on a single PaddlePaddle worker thread fed through a queue, and POSTs the transcript back to the callback URL.
# -*- coding: utf-8 -*-
"""HTTP server for the ASR demo."""
import os
import time
import shutil
import argparse
import functools
import pydub
import thread
import hashlib
import sys
# Python 2 hack: make UTF-8 the default encoding for implicit conversions
reload(sys)
sys.setdefaultencoding('utf-8')
import _init_paths
import threading
import traceback
import paddle.v2 as paddle
import subprocess
import json
import math
from flask_cors import CORS
from Queue import Queue
from flask import Flask, request, jsonify, render_template
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.utility import add_arguments, print_arguments
import logging
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('host_port',        int,    5000,   "Server's port.")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    1024,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  2.5,    "Coef of LM for beam search.")
add_arg('beta',             float,  0.3,    "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
add_arg('use_gru',          bool,   True,   "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights', bool,  True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('host_ip', str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir', str,
        'data',
        "Directory to save demo audios.")
add_arg('mean_std_path', str,
        '../models/Hoge_ASR/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
        '../models/Hoge_ASR/vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path', str,
        '../checkpoints/Hoge_ASR/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
        '../models/lm/zh_giga.no_cna_cmn.prune01244.klm',
        "Filepath for language model.")
add_arg('decoding_method', str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type', str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: disable
args = parser.parse_args()
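
# Usage note: any of the defaults above can be overridden on the command line
# (the script name deploy_server.py is illustrative), e.g.
#     python deploy_server.py --use_gpu=False --beam_size=300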

def start_transcript():
    """Start the ASR server."""
    # prepare data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1,
        keep_transcription_text=True)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
    if args.decoding_method == "ctc_beam_search":
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        feature = data_generator.process_utterance(filename, "")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=[feature],
            feeding_dict=data_generator.feeding)
        if args.decoding_method == "ctc_greedy":
            result_transcript = ds2_model.decode_batch_greedy(
                probs_split=probs_split,
                vocab_list=vocab_list)
        else:
            result_transcript = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=args.alpha,
                beam_beta=args.beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=1)
        return result_transcript[0]

    return file_to_transcript

# ------------------------------------------------------------------
# Start the Flask web app
app = Flask(__name__)
handler = logging.FileHandler('ai_asr.log', encoding='UTF-8')
handler.setLevel(logging.DEBUG)
logging_format = logging.Formatter(LOG_FORMAT)
handler.setFormatter(logging_format)
logging.getLogger().addHandler(handler)
CORS(app)

# Main queue: request handlers enqueue work here for the single inference worker
sendQ = Queue()

# Root path: serve the home page
@app.route('/')
def index():
    return render_template('index.html')

# Error response
def errorResp(msg):
    return jsonify(code=-1, message=msg)

# Success response
def successResp(data):
    return jsonify(code=0, message="success", data=data)

# Report failure to the callback URL, e.g. "callback_url": "https://10.0.1.111:6100"
def errorCallBack(url):
    import requests
    msg = {u'code': -1, u'status': u'fail'}
    resp = requests.post(url, data=json.dumps(msg))
    logging.info(resp)

# Report the recognition result to the callback URL
def successCallBack(url, task_id, param=None):
    import requests
    param = param or {}
    param['task_id'] = task_id
    resp = requests.post(url, data=json.dumps(param))
    return str(resp.text), str(task_id)

@app.route('/callback_test', methods=['POST'])
def callback_test():
    # Echo endpoint, handy as a callback target when testing
    f = request.get_data()
    data = json.loads(f)
    return str(data)
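
# Split an audio file into roughly one-minute fragments so long recordings
# can be recognized piece by piece. Returns (success flag, total duration in
# seconds, list of fragment file paths).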
def split_audio(filepath):
    audio_fragments = []
    audio_time = 0.0
    try:
        tmp_dir = os.path.dirname(filepath)
        sound = pydub.AudioSegment.from_file(filepath)
        audio_len = len(sound)  # length in milliseconds
        audio_time = round(audio_len / 1000.0, 3)
        # One extra fragment for every full minute beyond the first
        if audio_len > 60000:
            numbers = audio_len / 60000
        else:
            numbers = 0
        fragment_audio_len = int(math.ceil(1.0 * audio_len / (numbers + 1)))
        # Normalize to 16 kHz, mono, 16-bit, as expected by the model
        mono = sound.set_frame_rate(16000).set_channels(1).set_sample_width(2)
        for x in range(numbers + 1):
            s_time_minute = (x * fragment_audio_len) / 60000
            s_time_second = round(((x * fragment_audio_len) % 60000) / 1000.0, 3)
            e_time_minute = ((x + 1) * fragment_audio_len) / 60000
            e_time_second = round((((x + 1) * fragment_audio_len) % 60000) / 1000.0, 3)
            # Encode the fragment's time range in its filename, e.g. "0:0.000_1:0.000_.wav"
            audio_fragment_filename = "%d:%.3f_%d:%.3f_.wav" % (
                s_time_minute, s_time_second, e_time_minute, e_time_second)
            filename = os.path.join(tmp_dir, audio_fragment_filename)
            audio_fragments.append(filename)
            # 20 ms overlap so words are not clipped at fragment boundaries
            mono[x * fragment_audio_len:(x + 1) * fragment_audio_len + 20].export(filename, format='wav')
        success = 1
    except Exception:
        logging.info(traceback.format_exc())
        success = 0
    return success, audio_time, audio_fragments
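
# Full pipeline for one job: convert the input with ffmpeg, split it into
# fragments, push each fragment through the inference worker via the main
# queue, then POST the collected transcript to callback_url.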
def s_translate(src_path, task_id, callback_url):
    success = 0
    start_time = time.time()
    # Note: this temp dir is wiped on every request, so concurrent jobs would
    # clobber each other; the queue below serializes inference only.
    wav_temp = u'/storage/AI_server/tmp/files_wav/'
    if os.path.exists(wav_temp):
        shutil.rmtree(wav_temp)
    os.makedirs(wav_temp)
    wav_filename = wav_temp + os.path.basename(src_path)[:-4] + u'.wav'
    # Convert the source file to 16 kHz WAV with ffmpeg
    cmd_shell = u'ffmpeg -i ' + src_path + u' -f wav -ar 16k -ac 2 -y ' + wav_filename
    try:
        return_code = subprocess.call(cmd_shell, shell=True)
        if return_code:
            logging.info("File format conversion failed!")
            success = 0
    except Exception:
        logging.info("Your file format does not meet the requirements!")
        success = 0
    recv_queue = Queue()
    resps = {'text': [], 'audioTime': 0}
    _result, audioTime, fragment_files = split_audio(wav_filename)
    resps['audioTime'] = audioTime
    if _result:
        for file_t in fragment_files:
            try:
                # Recover the fragment's time range from its filename
                stat0 = os.path.basename(file_t).split(u"_")[0].split(u":")
                stat1 = os.path.basename(file_t).split(u"_")[1].split(u":")
                data = [(file_t,)]
                # Hand the fragment to the inference worker and wait for the reply
                sendQ.put((data, recv_queue))
                success, text = recv_queue.get()
                r_time = stat0[0] + ":" + stat0[1] + "," + stat1[0] + ":" + stat1[1]
                resps['text'].append(([r_time], text.decode("utf8", "ignore")))
                logging.info("Received utterance result: %s, Time: %s." % (text, r_time))
                success = 1
            except Exception:
                success = 0
    finish_time = time.time()
    logging.info("Response Time: %f" % (finish_time - start_time))
    if success:
        return successCallBack(callback_url, task_id, resps)
    else:
        return errorCallBack(callback_url)

@app.route('/translate', methods=['POST'])
def translate():
    f = request.get_data()
    audio_path = json.loads(f)
    src_path = audio_path["filepath"]
    callback_url = audio_path["callback_url"]
    # Use the MD5 of the file path as the task id returned to the caller
    task_id = hashlib.md5(src_path.encode('utf-8')).hexdigest()
    try:
        # Run recognition in the background; the result goes to callback_url
        thread.start_new_thread(s_translate, (src_path, task_id, callback_url))
        success = 1
    except Exception:
        logging.info("Error: unable to start thread")
        success = 0
    if success:
        return successResp(task_id)
    else:
        return errorResp("fail")

# The single PaddlePaddle inference worker thread
def worker(s):
    while True:
        # Fetch a job and its per-request reply queue
        filename, recv_queue = sendQ.get()
        try:
            # Run inference on the fragment
            result = s(filename[0][0])
            # Send the result back
            recv_queue.put((True, result))
        except Exception:
            # Send the traceback back through the reply queue
            trace = traceback.format_exc()
            logging.info(trace)
            recv_queue.put((False, trace))
            continue

def main():
    # Print model/runtime arguments
    print_arguments(args)
    # Initialize PaddlePaddle (CPU inference here; use use_gpu=args.use_gpu
    # to honor the command-line flag instead)
    paddle.init(use_gpu=False, trainer_count=2)
    # Load model parameters and build the inference handler
    file_to_transcript = start_transcript()
    logging.info("ASR Server Started.")
    return file_to_transcript

if __name__ == "__main__":
    s = main()
    t = threading.Thread(target=worker, args=(s,))
    t.daemon = True
    t.start()
    # Note: host/port are hard-coded here; the host_ip/host_port flags are unused
    app.run(host='0.0.0.0', port=6100, threaded=True)
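
2. To verify the deployment end to end, submit a job over HTTP. Below is a minimal client sketch (the audio path is a placeholder): it POSTs a job to /translate and points the callback at the server's own /callback_test route, so the transcript is delivered back to the same process. Note that "filepath" must be readable on the server's filesystem.

# -*- coding: utf-8 -*-
# Minimal client sketch: submit one recognition job and print the task id.
import json
import requests

resp = requests.post(
    'http://127.0.0.1:6100/translate',
    data=json.dumps({
        'filepath': '/storage/AI_server/tmp/demo.mp3',          # placeholder path
        'callback_url': 'http://127.0.0.1:6100/callback_test',  # echo endpoint above
    }))
print(resp.json())  # e.g. {"code": 0, "message": "success", "data": "<task_id>"}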