Training a model with BERT and converting it to pb format
程序员文章站
2023-11-06 23:45:46
The full code is on GitHub:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.py
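import pandas as pd
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

def serving_input_fn():
    # Export the model in SavedModel format.
    # Uses raw features: the inputs are the feature Tensors themselves.
    # With build_parsing_serving_input_receiver_fn, the inputs would instead be serialized tf.Example protos.
    # FLAGS is defined elsewhere in Bert_Train.py; here data_dir is read as the path to the TSV training file.
    df = pd.read_csv(FLAGS.data_dir, delimiter="\t", names=['labels', 'text'], header=None)
    # Number of distinct labels = width of the label_ids vector
    dense_units = len(df.labels.unique())
    label_ids = tf.placeholder(tf.int32, [None, dense_units], name='label_ids')
    # 128 must match the max_seq_length used during training
    input_ids = tf.placeholder(tf.int32, [None, 128], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, 128], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, 128], name='segment_ids')
    # build_raw_serving_input_receiver_fn returns a function; calling it yields a ServingInputReceiver
    # whose dict keys become the input names of the exported serving signature.
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn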
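The placeholders in serving_input_fn fix the contract a client of the exported model has to satisfy: input_ids, input_mask and segment_ids as int32 rows of width 128, plus label_ids of width dense_units. The plain-Python sketch below shows one way to prepare such a batch. It assumes the token ids already come from the BERT tokenizer with [CLS] and [SEP] included, that label_ids holds a one-hot vector (as its [None, dense_units] shape suggests), and that an all-zero vector is an acceptable dummy at prediction time. build_features and build_request are hypothetical helpers, not part of the repository.

```python
from typing import Dict, List

# Must match the placeholder shape [None, 128] declared in serving_input_fn.
MAX_SEQ_LENGTH = 128
FEATURE_KEYS = ("input_ids", "input_mask", "segment_ids", "label_ids")


def build_features(token_ids: List[int], num_labels: int,
                   max_seq_length: int = MAX_SEQ_LENGTH) -> Dict[str, List[int]]:
    """Pad one tokenized sequence into the four inputs of the export signature.

    Assumption: token_ids already contains the [CLS] ... [SEP] ids.
    """
    if len(token_ids) > max_seq_length:
        # Truncate but keep the final token (assumed to be [SEP]).
        token_ids = token_ids[:max_seq_length - 1] + token_ids[-1:]
    n = len(token_ids)
    pad = max_seq_length - n
    return {
        "input_ids": token_ids + [0] * pad,
        "input_mask": [1] * n + [0] * pad,    # 1 = real token, 0 = padding
        "segment_ids": [0] * max_seq_length,  # single-sentence classification: all segment A
        # Not needed for prediction, but the receiver declares it,
        # so a dummy zero vector of width num_labels is fed (assumption).
        "label_ids": [0] * num_labels,
    }


def build_request(batch: List[List[int]], num_labels: int) -> Dict[str, List[List[int]]]:
    """Stack per-example features into batched lists keyed by the receiver's input names."""
    feats = [build_features(ids, num_labels) for ids in batch]
    return {key: [f[key] for f in feats] for key in FEATURE_KEYS}


if __name__ == "__main__":
    # Hypothetical token ids for a single short sentence.
    req = build_request([[101, 2000, 102]], num_labels=3)
    print(len(req["input_ids"][0]))  # 128
    print(req["input_mask"][0][:4])  # [1, 1, 1, 0]
```

The dummy label_ids is only there because the receiver dict includes that placeholder, presumably because the model_fn reads features['label_ids'] unconditionally, as the original BERT run_classifier code does. If the model_fn skipped labels in predict mode, label_ids could be dropped from both the receiver and the request.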
Original article: https://blog.csdn.net/qq236237606/article/details/107078973