Deploying a pretrained BERT model with TensorFlow Serving


This post isn't fully polished yet; consider it a work in progress~


Introduction to the BERT model

BERT's key technical innovation is applying the bidirectional training of the Transformer, a popular attention model, to language modeling. With Masked LM (MLM), before word sequences are fed into BERT, 15% of the tokens in each sequence are replaced with a [MASK] token. The model then tries to predict the original value of each masked token from the context provided by the other, non-masked tokens in the sequence.
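
Purely as an illustrative sketch of this masking step (not code from BERT itself; the 15% rate and the [MASK] token id 103 from the uncased vocab are assumptions of the example):

import random

def mask_tokens(token_ids, mask_id=103, mask_prob=0.15):
    # Replace roughly 15% of the token ids with [MASK]; the MLM objective
    # is to recover the original ids at the masked positions.
    masked, labels = [], []
    for tid in token_ids:
        if random.random() < mask_prob:
            masked.append(mask_id)
            labels.append(tid)      # target the model must predict
        else:
            masked.append(tid)
            labels.append(-1)       # not masked, ignored by the MLM loss
    return masked, labels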

This post mainly documents how to deploy a trained BERT model with TensorFlow Serving and how to obtain sentence vector representations from the served model.

Converting a ckpt checkpoint to the SavedModel format

The original Google BERT pretrained model is saved in ckpt format, while deployment with TF Serving requires the SavedModel (pb) format, so a conversion step is needed.

import json
import os
import tensorflow as tf
import argparse

import modeling

def create_model(bert_config, is_training, input_ids):
    # Build the BERT graph for the given input ids and take the pooled
    # [CLS] representation as the sentence vector.
    model = modeling.BertModel(config=bert_config, is_training=is_training, input_ids=input_ids)
    output = model.get_pooled_output()
    # output = model.get_sequence_output()  # use this instead for per-token vectors

    return output

def transfer_saved_model(args):

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    sess = tf.Session(config=gpu_config)

    print("going to restore checkpoint")
    bert_config_file = os.path.join(args.model_path, 'bert_config.json')
    bert_config = modeling.BertConfig.from_json_file(bert_config_file)

    input_ids = tf.placeholder(tf.int32, [None, args.max_seq_len], name="input_ids")
    output = create_model(bert_config=bert_config, is_training=False, input_ids=input_ids)

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(args.model_path))

    tf.saved_model.simple_save(sess, args.export_path,
                               inputs={'input_ids': input_ids},
                               outputs={"outputs": output})
    print('savedModel export finished.')
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trans ckpt file to .pb file')
    
    parser.add_argument('-model_path', type=str, required=True, help='dir of a pretrained BERT model')
    parser.add_argument('-export_path', type=str, default=None, help='export model path')
    parser.add_argument('-max_seq_len', type=int, default=128, help='maximum length of a sequence')
    args = parser.parse_args()

    transfer_saved_model(args)

Run the above Python script:

python3 export_bert_test.py -model_path uncased_L-12_H-768_A-12 \
    -export_path bert_output/1 -max_seq_len 128
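
Before deploying, the exported SavedModel's signature can be inspected with TensorFlow's saved_model_cli (the directory below assumes the export path used above); it should list a serving_default signature with the input_ids input and a 768-dimensional outputs tensor:

saved_model_cli show --dir bert_output/1 --all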

Deploying the model with Docker

Deploy the model with the Docker image of TensorFlow Serving.

nvidia-docker run -p 8525:8525 -p 8526:8526 --name=bert_model \
--mount type=bind,source=$(pwd)/bert_output,target=/models/bert_output \
-t --entrypoint=tensorflow_model_server tensorflow/serving:1.12.0-gpu \
--port=8525 --rest_api_port=8526 \
--enable_batching=true --file_system_poll_wait_seconds=300 \
--grpc_channel_arguments="grpc.max_connection_age_ms=5000" \
--per_process_gpu_memory_fraction=0.4 \
--model_config_file=/models/bert_output/tfserving.conf
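
The last flag points to a model config file. A minimal sketch of tfserving.conf, assuming the model name and mount path used above, could look like the following; note that TensorFlow Serving expects the base_path to contain a numeric version subdirectory, e.g. /models/bert_output/1/saved_model.pb:

model_config_list {
  config {
    name: "bert_output"
    base_path: "/models/bert_output"
    model_platform: "tensorflow"
  }
}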

HTTP and gRPC requests

import tokenization  # tokenization.py from the Google BERT repo

# Build the WordPiece tokenizer from the pretrained model's vocabulary
# (the vocab path assumes the uncased_L-12_H-768_A-12 directory used above).
tokenizer = tokenization.FullTokenizer(
    vocab_file='uncased_L-12_H-768_A-12/vocab.txt', do_lower_case=True)

def process_input(texts, max_len):
    def tokenize_input(text):
        # Wrap the tokens with [CLS]/[SEP] so the pooled output corresponds
        # to [CLS], truncate to max_len, then pad with 0 up to max_len.
        tokens = ['[CLS]'] + tokenizer.tokenize(text)[:max_len - 2] + ['[SEP]']
        text_ids = tokenizer.convert_tokens_to_ids(tokens)
        while len(text_ids) < max_len:
            text_ids.append(0)
        return text_ids

    input_ids = []
    for text in texts:
        input_ids.append(tokenize_input(text))

    return input_ids

HTTP request

import json

import numpy as np
import requests

def predict_http(input_texts):
    SERVER_URL = "http://localhost:8526/v1/models/bert_output:predict"

    input_ids = process_input(input_texts, max_len=128)
    input_ids = np.array(input_ids).tolist()
    input_data = {"input_ids": input_ids}

    # Use the columnar "inputs" format of the TF Serving REST API;
    # the response then carries the result under the "outputs" key.
    request_data = {"signature_name": "serving_default", "inputs": input_data}
    request_data = json.dumps(request_data)

    response = requests.post(SERVER_URL, data=request_data)
    result = response.json()
    pred_value = result['outputs']

    return pred_value
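
A quick sanity check with arbitrary example sentences; it should return one 768-dimensional vector per input:

vectors = predict_http(["hello world", "tensorflow serving deploys bert"])
print(len(vectors), len(vectors[0]))  # expected: 2 768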

gRPC request

import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

host = "localhost"
port = "8525"  # gRPC port (--port in the docker command); 8526 is the REST port

# initialize grpc channel
channel = grpc.insecure_channel('{host}:{port}'.format(host=host, port=port))

def predict_grpc(input_texts, max_len=128):

    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    result = []

    def predict_request(input_ids):
        """Send one prediction request to the model server."""
        input_ids = np.array(input_ids)
        input_tensor = tf.make_tensor_proto(input_ids, shape=input_ids.shape, dtype=tf.int32)
        try:
            request = predict_pb2.PredictRequest()
            request.inputs["input_ids"].CopyFrom(input_tensor)
            request.model_spec.name = "bert_output"
            request.model_spec.signature_name = "serving_default"
            response = stub.Predict(request, 50)  # 50-second timeout
            _result = tf.make_ndarray(response.outputs["outputs"]).tolist()
            result.extend(_result)
        except Exception as e:
            print(e)

    input_ids = process_input(input_texts, max_len)
    predict_request(input_ids)

    return result
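
Given the same input, the two clients should return (numerically almost) the same vectors; a quick comparison with an arbitrary example sentence:

http_vec = predict_http(["hello world"])[0]
grpc_vec = predict_grpc(["hello world"])[0]
print(len(grpc_vec), max(abs(a - b) for a, b in zip(http_vec, grpc_vec)))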