
Modifying the prediction module of BERT text classification (run_classifier)

程序员文章站 2022-05-14 18:16:01
...

Change 1: in the model_fn() function of run_classifier.py

Original code 1:

else:
	output_spec = tf.contrib.tpu.TPUEstimatorSpec(
		mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)

Replacement for original code 1:

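elif mode == tf.estimator.ModeKeys.PREDICT:
    def build_predictions(logits, probabilities):
        # Predicted class = index of the largest logit for each example.
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        return {
            # Shape [batch, 1]: each per-example result is a length-1 array.
            'pred_class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': probabilities,
            'logits': logits}

    pred_metrics = build_predictions(logits, probabilities)
    # Keep TPUEstimatorSpec and scaffold_fn, as in the original branch, so the
    # spec matches the TPUEstimator that run_classifier.py builds.
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, predictions=pred_metrics, scaffold_fn=scaffold_fn)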
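To see what this predictions dict hands back to main(), the sketch below mimics it in plain Python without TensorFlow. It is illustrative only: the hypothetical build_predictions() computes softmax probabilities and the argmax class from a batch of logits, and the hypothetical unbatch() splits the batched dict into one dict per example, which is how Estimator.predict() yields results by default (yield_single_examples=True). Because of the extra [:, tf.newaxis] axis, each example's pred_class_ids is a length-1 list here (a length-1 array in TensorFlow). The input logits are made-up values.

```python
import math
from typing import Dict, Iterator, List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(row: List[float]) -> int:
    """Index of the largest value in one row (first index on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def build_predictions(logits: List[List[float]]) -> Dict[str, list]:
    """Batched predictions dict shaped like the PREDICT branch (hypothetical)."""
    return {
        # [[id], ...] mirrors predicted_classes[:, tf.newaxis]: shape [batch, 1]
        "pred_class_ids": [[argmax(row)] for row in logits],
        "probabilities": [softmax(row) for row in logits],
        "logits": logits,
    }


def unbatch(predictions: Dict[str, list]) -> Iterator[Dict[str, list]]:
    """Yield one dict per example by slicing every entry along the batch axis."""
    batch_size = len(predictions["logits"])
    for i in range(batch_size):
        yield {key: value[i] for key, value in predictions.items()}


if __name__ == "__main__":
    for pred in unbatch(build_predictions([[0.2, 1.5], [2.0, -1.0]])):
        print(pred["pred_class_ids"], pred["probabilities"])
```

This is why the main() loop can read a single class id per example with prediction["pred_class_ids"][0].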

Change 2: in the main() function of run_classifier.py

Original code 2:

with tf.gfile.GFile(output_predict_file, "w") as writer:
    tf.logging.info("***** Predict results *****")        
    for prediction in result:
        output_line = "\t".join(
        	str(class_probability) for class_probability in prediction) + "\n" 
        writer.write(output_line)   
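The original loop writes only the class probabilities, one tab-separated line per example (test_results.tsv under output_dir in the stock script). As a hedged sketch that assumes exactly that line format, the predicted class can still be recovered offline by taking the argmax of each line; predicted_class() and predicted_classes() are hypothetical helpers, not part of BERT.

```python
from typing import List


def predicted_class(line: str) -> int:
    """Argmax index of one tab-separated line of class probabilities."""
    probs = [float(field) for field in line.rstrip("\n").split("\t")]
    return max(range(len(probs)), key=lambda i: probs[i])


def predicted_classes(path: str) -> List[int]:
    """Recover class ids from a file written by the original loop (blank lines skipped)."""
    with open(path, encoding="utf-8") as f:
        return [predicted_class(line) for line in f if line.strip()]


if __name__ == "__main__":
    print(predicted_class("0.12\t0.88\n"))
```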

Replacement for original code 2:

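label_map = {label: i for i, label in enumerate(label_list)}
with tf.gfile.GFile(output_predict_file, "w") as writer:
    tf.logging.info("***** Predict results *****")
    pred_true_nums = 0    # number of correct predictions
    test_sample_nums = 0  # number of predicted samples
    # result is in the same order as predict_examples; example.label must be a real
    # label (several stock processors use the placeholder "0" for test data).
    for example, prediction in zip(predict_examples, result):
        pred_class_id = int(prediction["pred_class_ids"][0])
        probs = "\t".join(str(p) for p in prediction["probabilities"])
        writer.write("%d\t%s\n" % (pred_class_id, probs))
        test_sample_nums += 1
        pred_true_nums += int(pred_class_id == label_map[example.label])
    if test_sample_nums > 0:
        writer.write("\npred_accuracy:%.4f\n" % (float(pred_true_nums) / test_sample_nums))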
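Accuracy has to compare each predicted id with the id of the example's true label, using the same label-to-id mapping built from label_list. The framework-free sketch below shows that computation on plain lists; accuracy() is a hypothetical helper, and string labels such as "0"/"1" are an assumption matching the stock processors.

```python
from typing import Sequence


def accuracy(pred_class_ids: Sequence[int],
             true_labels: Sequence[str],
             label_list: Sequence[str]) -> float:
    """Share of examples whose predicted id equals the id of their true label."""
    if len(pred_class_ids) != len(true_labels):
        raise ValueError("predictions and labels must have the same length")
    if not true_labels:
        return 0.0
    # Label -> id mapping in label_list order, as used when converting examples.
    label_map = {label: i for i, label in enumerate(label_list)}
    correct = sum(1 for pred_id, label in zip(pred_class_ids, true_labels)
                  if pred_id == label_map[label])
    return correct / len(true_labels)


if __name__ == "__main__":
    print(accuracy([1, 0, 1, 1], ["1", "0", "0", "0"], ["0", "1"]))  # 0.5
```

Summing the predicted ids instead would give 0.75 for the same inputs: that is the share of samples predicted as class 1, which equals accuracy only when every true label is class 1.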