欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow/serving部署keras模型

程序员文章站 2022-05-26 17:57:22
...

之前写了一篇tensorflow/serving部署tensorflow模型的文章,记录了详细的操作步骤与常见的错误及解决方案,具体见:TensorFlow Serving模型转换与部署

本文主要记录tensorflow/serving部署keras模型过程中的一些重要步骤,以便后续查阅。

我们在keras中保存模型通常用model.save或者model.save_weights函数。
其中,model.save函数保存的模型往往比的是模型的结构与权重,而model.save_weights函数保存的仅仅是模型的结构,因此model.save函数保存的模型往往比model.save_weights函数保存的模型要大些。

在前一篇tensorflow/serving介绍中TensorFlow Serving模型转换与部署,我们知道tensorflow/serving需要pb格式的模型,而本篇文章我们讨论的keras模型是.h5.weights格式的,因此,首先我们需要将.h5.weights格式的keras模型转换为tensorflow/serving框架可识别的pb格式模型,转换代码如下:

def keras_model_to_tfs(model, export_path):
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_x': model.input}, 
        outputs={'output_y': model.output}
    )
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    print('Build done.')

简要说明一下keras_model_to_tfs函数的参数
model:导入的keras模型,用keras的load_modelload_weights导入的模型
export_path:转换成pb格式模型后的保存路径

模型转换完成后,剩下的工作就是部署tensorflow/serving框架,并利用grpc接口调用模型预测。
关于具体的tensorflow/serving的部署,可参考之前文章:TensorFlow Serving模型转换与部署,预测代码在之前那篇文章中也有,本文再次贴出一个。

def tfserving_grpc(title, content):
    content = content or title
    content = filter_waste(content)
    model_dir = os.path.join(project_path, 'models_weights')
    with open(os.path.join(model_dir, 'tokenizer.plk'), 'rb') as f:
        tokenizer = pickle.load(f)
    x = tokenizer.texts_to_sequences([jieba.lcut(content)])
    x = x[0]
    if len(x) > MAX_LEN:
        x = x[:MAX_LEN]
    else:
        x = x + [0] * (MAX_LEN - len(x))

    # ip地址为部署tensorflow/serving的IP
    channel = grpc.insecure_channel('xx.xx.xx.xx:8500')  
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'new_yq_model'
    # request.model_spec.version.value = 1000001
    request.model_spec.signature_name = 'serving_default'

    request.inputs["input_x"].CopyFrom(tf.contrib.util.make_tensor_proto(np.array([x], dtype=np.float)))
    response = stub.Predict(request, 10.0)

    results = {}
    for key in response.outputs:
        tensor_proto = response.outputs[key]
        results[key] = tf.contrib.util.make_ndarray(tensor_proto)

    return results

最后给一个main函数的整体过程代码。

model = build_model(len(tokenizer.index_word))
model.load_weights(os.path.join(model_dir, 'best_model.weights'))
model.summary()
export_path = './tfs_models'
keras_model_to_tfs(model, export_path)

参考

使用tensorflow serving部署keras模型(tensorflow 2.0.0)
keras、tensorflow serving踩坑记