BERT 文本分类：修改 run_classifier.py 的预测模块
程序员文章站
2022-05-14 18:16:01
...
修改位置1:run_classifier.py model_fn() 函数中:
源码1:
else:
  # Original upstream fall-through branch used for prediction: the prediction
  # output is just the `probabilities` tensor, so each item main() iterates
  # over is one per-class probability vector.
  output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)
将源码1替换为：
elif mode == tf.estimator.ModeKeys.PREDICT:
def metric_fn(logits,probabilities):
predicted_classes = tf.argmax(logits, axis=1,output_type=tf.int32)
return {
'pred_class_ids': predicted_classes[:, tf.newaxis],
'probabilities':probabilities,
'logits': logits}
pred_metrics = metric_fn(logits,probabilities)
output_spec = tf.estimator.EstimatorSpec(
mode=mode,predictions=pred_metrics)
修改位置2:run_classifier.py main()函数中:
源码2:
# Original upstream result writer: each `prediction` yielded by `result` is the
# bare probabilities output of the original prediction branch, written as one
# tab-separated line per example.
with tf.gfile.GFile(output_predict_file, "w") as writer:
  tf.logging.info("***** Predict results *****")
  for prediction in result:
    # Join this example's per-class values with tabs and end the line.
    output_line = "\t".join(
        str(class_probability) for class_probability in prediction) + "\n"
    writer.write(output_line)
将源码2替换为：
with tf.gfile.GFile(output_predict_file, "w") as writer:
  tf.logging.info("***** Predict results *****")
  num_written = 0  # predictions written to the file
  num_labeled = 0  # predictions that include a "label_ids" entry
  num_correct = 0  # labeled predictions whose predicted class equals the label
  for prediction in result:
    # One line per example with tab-separated "key:value" fields.
    # tolist() turns numpy arrays/scalars into plain Python values, so long
    # arrays are not wrapped across several lines the way numpy's repr wraps.
    output_line = "\t".join(
        "%s:%s" % (key, value.tolist())
        for key, value in prediction.items()) + "\n"
    writer.write(output_line)
    num_written += 1
    # Compare against the gold label rather than summing predicted ids.
    # A sum of class ids counts predictions of class 1, not correct ones.
    if "label_ids" in prediction:
      num_labeled += 1
      # .item() turns the size-1 array / numpy scalar into a Python int.
      if prediction["pred_class_ids"].item() == prediction["label_ids"].item():
        num_correct += 1
  tf.logging.info("Wrote %d predictions to %s", num_written, output_predict_file)
  # Guard the division: with no labeled predictions (including an empty
  # result) there is no meaningful accuracy to report.
  if num_labeled:
    pred_accuracy = float(num_correct) / num_labeled
    tf.logging.info("pred_accuracy = %.4f (%d/%d)",
                    pred_accuracy, num_correct, num_labeled)
    writer.write("\npred_accuracy:" + str(pred_accuracy) + "\n")
  else:
    tf.logging.warning(
        "No label_ids in predictions; pred_accuracy not computed.")
上一篇: 【发邮件】C++