
TensorFlow: fixing test-time classification accuracy that depends on batch_size


While finetuning ResNet-50 with TensorFlow's slim.nets module, I ran into a problem: training accuracy and loss looked normal, but accuracy during validation and testing was very low. Concretely, the smaller the test batch size, the lower the test accuracy. In principle the batch size at test time should only affect running time, not accuracy, so here is a record of how I tracked down and solved the problem.

1. The Batch Normalization layer:

The network I used for finetuning comes from slim.nets.resnet_v2.resnet_v2_50(), whose BN layers are implemented with layers.batch_norm(). The source of that function carries this note:

  Note: when training, the moving_mean and moving_variance need to be updated.
  By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
  need to be added as a dependency to the `train_op`. For example:

  ```python
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss)
  ```

In short: during training, moving_mean and moving_variance have to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency of the `train_op`.

mean and variance are the two key statistics of batch norm: at training time they are the mean and variance of the current batch. The note tells us to make the train op depend on update_ops; otherwise moving_mean and moving_variance are never updated, and the values restored at test time may still be their initial defaults, which leads to very poor test results.

I changed my code following the example given in the source, but the results did not improve: test accuracy was still affected by batch size, so I kept debugging.
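As a sanity check (not part of the original fix, just a sketch you can drop into an open training session), you can verify whether the BN moving statistics are actually being updated:

```python
# Check that the graph actually contains BN update ops, and watch one of
# the moving statistics drift away from its initializer as training runs.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print('number of BN update ops:', len(update_ops))   # should be > 0 for a BN network

moving_means = [v for v in tf.global_variables() if 'moving_mean' in v.op.name]
print(sess.run(moving_means[0])[:5])   # run again after some steps; values should change
```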

2. The is_training parameter:

batch_norm() takes an important parameter, is_training, documented in the source as follows:

is_training: Whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.

In other words, when is_training is True the layer accumulates the batch statistics into moving_mean and moving_variance with an exponential moving average (controlled by decay); when is_training is False it uses the stored values of moving_mean and moving_variance instead.
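For intuition, here is a toy illustration (pure Python, made-up numbers; I believe layers.batch_norm's default decay is 0.999, treat that as an assumption) of how the moving statistics are accumulated:

```python
# Exponential moving average as used for moving_mean / moving_variance:
# with decay close to 1 the stored value moves very slowly, so it takes
# many training steps (with the update ops wired in) before it reflects
# the real data statistics.
decay = 0.999            # assumed default decay of layers.batch_norm
moving_mean = 0.0        # initializer value of moving_mean
for batch_mean in [2.0, 2.1, 1.9, 2.05]:   # made-up per-batch means
    moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean
print(moving_mean)       # still close to 0 after only a few updates
```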

So during training, is_training should be True so that the moving mean and variance are accumulated, and at test time it should be False so that the values saved during training are used rather than statistics computed from the current batch. When the statistics come from the current batch, a small test batch gives noisy estimates, which is exactly why test accuracy depended on batch size.

With that, the problem was finally solved. Here is the full training (and validation) code:

# -*- coding: utf-8 -*-
import tensorflow as tf
from preprocess import preprocess_for_train, preprocess_for_test
from utils.parse_functions import parser_tfrecords, parse_list
from utils.generate_list import images_and_labels_list
from tensorflow.contrib.slim import nets
import os
import numpy as np
import time

slim = tf.contrib.slim

# use only the first GPU for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
"""
""net parameters
"""
num_classes = 10
image_size = 32
batch_size = 128
valid_batch_size = 128

shuffle_buffer = 50000
NUM_EPOCHS = 100
starter_learning_rate = 0.001
boundaries = [15000, 30000, 45000, 60000]
values = [0.001, 0.0005, 0.0001, 0.00005, 0.00001]
"""
""path parameters
"""
resnet_model_path = './resnet_v2_50.ckpt'
models_path = './models/temporary/train-model.ckpt'   # checkpoint path for the Saver

logs_dir = "./logs/temporary"
logs_train_dir = os.path.join(logs_dir, "train")
logs_valid_dir = os.path.join(logs_dir, "valid")

train_tfrecords = ["./tfrecords/cifar10/train.tfrecords"]
test_tfrecords = ["./tfrecords/cifar10/test.tfrecords"]

for dir_name in [logs_dir, logs_train_dir, logs_valid_dir]:
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)
"""
""Datasets
"""
dataset = tf.data.TFRecordDataset(train_tfrecords)
dataset = dataset.map(parser_tfrecords)
dataset = dataset.shuffle(shuffle_buffer)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(NUM_EPOCHS)
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()
label_batch = tf.one_hot(indices=tf.cast(label_batch, tf.int32), depth=num_classes)

val_dataset = tf.data.TFRecordDataset(test_tfrecords)
val_dataset = val_dataset.map(parser_tfrecords)
val_dataset = val_dataset.batch(valid_batch_size)
val_iterator = val_dataset.make_initializable_iterator()
val_image_batch, val_label_batch = val_iterator.get_next()
val_label_batch = tf.one_hot(indices=tf.cast(val_label_batch, tf.int32), depth=num_classes)
"""
""placeholder
"""
inputs = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
labels = tf.placeholder(tf.int32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)
is_training = tf.placeholder(tf.bool)
"""
""inference
"""
with slim.arg_scope(nets.resnet_v2.resnet_arg_scope()):
    net, end_points = nets.resnet_v2.resnet_v2_50(inputs, num_classes=None, is_training=is_training)
with tf.variable_scope('Logits'):
    net = tf.squeeze(net, axis=[1, 2])
    net = slim.dropout(net, keep_prob=keep_prob, scope='scope')
    logits = slim.fully_connected(net, num_outputs=num_classes,
                                  activation_fn=None, scope='fc')
"""
""Learning rate
"""
with tf.variable_scope("learning_rate") as scope:
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
    tf.summary.scalar(scope.name, learning_rate)
"""
""Loss
"""
with tf.variable_scope("loss") as scope:
    losses = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
    loss = tf.reduce_mean(losses)
    tf.summary.scalar(scope.name, loss)
"""
""Accuracy
"""
with tf.variable_scope("accuracy") as scope:
    correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), tf.cast(tf.argmax(labels, 1), tf.int32))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar(scope.name, accuracy)
"""
""Optimizer
"""
with tf.name_scope("optimizer"):
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
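    # Fix 1: make the train step depend on the BN update ops so moving_mean / moving_variance actually get updated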
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = optimizer.minimize(loss, global_step)
        # train_step = slim.learning.create_train_op(loss, optimizer, global_step=global_step)
"""
""Summary
"""
summary_op = tf.summary.merge_all()  # merge all summary ops
"""
" Restore resnet50
"""
checkpoint_exclude_scopes = 'Logits'
exclusions = None
if checkpoint_exclude_scopes:
    exclusions = [
        scope.strip() for scope in checkpoint_exclude_scopes.split(',')]
variables_to_restore = []
for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
            excluded = True
    if not excluded:
        variables_to_restore.append(var)
saver_restore = tf.train.Saver(var_list=variables_to_restore)
saver = tf.train.Saver()
"""
" Open Session
"""
init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(init)
    # saver_restore.restore(sess, resnet_model_path)

    train_writer = tf.summary.FileWriter(os.path.join(logs_dir, 'train'), sess.graph)  # training log writer
    valid_writer = tf.summary.FileWriter(os.path.join(logs_dir, 'valid'), sess.graph)  # validation log writer
    # initialize the training-data iterator
    sess.run(iterator.initializer)
    step = 0
    best_acc = 0.0
    start_time = time.time()
    while True:
        try:
            step += 1
            batch_x, batch_y = sess.run([image_batch, label_batch])
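            # Fix 2: feed is_training=True during training (BN uses batch statistics and accumulates the moving averages)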
            train_dict = {inputs: batch_x,
                          labels: batch_y,
                          is_training: True,
                          keep_prob: 1.0}
            summary_str, _, loss_t, acc_t = sess.run([summary_op, train_step, loss, accuracy],
                                                     feed_dict=train_dict)
            if step % 10 == 0:
                print('Step: {},Train_Acc: {:.4f},Loss: {:.8f}'.format(step, acc_t, loss_t))
                train_writer.add_summary(summary_str, step)
            if step % 1000 == 0:   # evaluate on the validation set
                sess.run(val_iterator.initializer)
                acc_reg = []
                loss_reg = []
                while True:
                    try:
                        batch_x, batch_y = sess.run([val_image_batch, val_label_batch])
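                        # feed is_training=False at validation time (BN uses the accumulated moving statistics)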
                        valid_dict = {inputs: batch_x,
                                      labels: batch_y,
                                      is_training: False,
                                      keep_prob: 1.0}
                        loss_v, acc_v, summary_str = sess.run([loss, accuracy, summary_op],
                                                              feed_dict=valid_dict)
                        valid_writer.add_summary(summary_str, step)
                        acc_reg.append(acc_v)
                        loss_reg.append(loss_v)
                    except tf.errors.OutOfRangeError:
                        break
                avg_acc = np.mean(np.array(acc_reg))
                avg_loss = np.mean(np.array(loss_reg))
                print('------------------------------------------------------')
                print('Valid-----> ,Valid_Acc: {:.4f}, Valid_Loss: {:.7f}'.format(avg_acc, avg_loss))
                print('------------------------------------------------------')
                """
                " Save the best model
                """
                if avg_acc > best_acc:
                    best_acc = avg_acc
                    saver.save(sess=sess, save_path=models_path, global_step=step)
                    print("模型保存成功")
                    print("Save the best model with val_acc %0.4f" % best_acc)
                else:
                    print("Val_acc stay with val_acc %0.4f" % best_acc)
        except tf.errors.OutOfRangeError:
            train_writer.close()
            valid_writer.close()
            saver.save(sess=sess, save_path=models_path, global_step=step)
            break
    end_time = time.time()
    print("总共用时:", end_time-start_time)
    print('Ended......')

References:

https://www.zhihu.com/question/284663735

https://zhuanlan.zhihu.com/p/41117510
