
Validation during training in TensorFlow

程序员文章站 2022-07-13 12:41:51


Because TensorFlow 1.x builds a static graph, validating while you train is somewhat more involved than in PyTorch.

Loading the data

Get the training-set and validation-set batches separately. Here I read the data from TFRecord files.

# training data
img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train = \
next_batch(dataset_name = xxx, ..., is_training = True)

# validation data
img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val = \
next_batch(dataset_name = xxx, ..., is_training = False)
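next_batch is the author's own helper, and its internals are not shown. The framework-free sketch below gives a rough illustration of the two-pipeline idea. It assumes that is_training=True means "reshuffle every epoch" and is_training=False means "keep a fixed order". That mapping, and the names batch_stream, shuffle and seed, are hypothetical and used only for illustration. Both streams repeat indefinitely because the training loop keeps pulling validation batches each time it validates.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batch_stream(samples: Sequence[T], batch_size: int, shuffle: bool,
                 seed: int = 0) -> Iterator[List[T]]:
    """Return an endless iterator of fixed-size batches (hypothetical helper)."""
    # Validate eagerly so a bad batch_size fails here, not on the first next().
    if not 0 < batch_size <= len(samples):
        raise ValueError("batch_size must be between 1 and len(samples)")
    return _batches(list(samples), batch_size, shuffle, random.Random(seed))


def _batches(samples: List[T], batch_size: int, shuffle: bool,
             rng: random.Random) -> Iterator[List[T]]:
    order = list(range(len(samples)))
    while True:  # repeat over epochs indefinitely
        if shuffle:
            rng.shuffle(order)  # assumed training behaviour: new order every epoch
        # Drop the incomplete tail batch so every batch has the same size.
        for start in range(0, len(order) - batch_size + 1, batch_size):
            yield [samples[i] for i in order[start:start + batch_size]]


# Training stream is shuffled; validation stream keeps a fixed order.
train_batches = batch_stream(list(range(10)), batch_size=4, shuffle=True)
val_batches = batch_stream([100, 101, 102, 103, 104, 105], batch_size=3, shuffle=False)
print(next(val_batches))  # [100, 101, 102]
print(next(val_batches))  # [103, 104, 105]
print(next(val_batches))  # [100, 101, 102]  (second epoch, same order)
```

Dropping the tail batch is a design choice. Graphs built with a fixed batch dimension usually expect every batch to have the same shape.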

Note the is_training argument: it distinguishes the training pipeline from the validation pipeline.

Defining the is_training placeholder

# Scalar boolean fed at run time: True for training steps, False for validation steps
is_training = tf.placeholder(tf.bool, shape=())

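A tf.placeholder selects between training and validation.
This creates a conditional branch inside a single graph, so the value fed to the placeholder through feed_dict decides whether a step runs on training data or on validation data.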
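The plain-Python sketch below shows the switching idea without TensorFlow. The helper cond is a hypothetical stand-in, not tf.cond itself. It mirrors one property of tf.cond: the branches are passed as zero-argument callables, and only the selected one is executed. The stream names and next_input are also illustrative.

```python
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def cond(pred: bool, true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
    # Plain-Python analogue of tf.cond: branches are zero-argument callables
    # and only the selected one is executed.
    return true_fn() if pred else false_fn()


def labeled_stream(prefix: str) -> Iterator[str]:
    # Endless stream of labeled "batches": prefix-0, prefix-1, ...
    i = 0
    while True:
        yield f"{prefix}-{i}"
        i += 1


train_stream = labeled_stream("train")
val_stream = labeled_stream("val")


def next_input(is_training: bool) -> str:
    # Single entry point that the rest of the model reads from; the flag plays
    # the role of the value fed to the is_training placeholder.
    return cond(is_training, lambda: next(train_stream), lambda: next(val_stream))


print(next_input(True))   # train-0
print(next_input(True))   # train-1
print(next_input(False))  # val-0
print(next_input(True))   # train-2
```

next() is called inside the lambdas, so validation batches are consumed only when the flag is False. The training stream is never advanced by a validation step.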

Using is_training to keep a single copy of the graph nodes

img_name_batch, img_batch, gtboxes_and_label_batch, num_objs_batch, img_h_batch, img_w_batch = \
    tf.cond(is_training,
            lambda: (img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train,
                     num_objs_batch_train, img_h_batch_train, img_w_batch_train),
            lambda: (img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val,
                     num_objs_batch_val, img_h_batch_val, img_w_batch_val))

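Suppose you skip tf.cond() and build the network a second time on the validation tensors. This adds many new nodes to the original graph, and their variables are newly created and must be initialized again. Validation would therefore not use the trained weights, unless the variables are explicitly shared, for example with tf.variable_scope(..., reuse=True). With tf.cond(), the network is built only once on top of a single set of input tensors, so training and validation share the same weights. One caveat: the tf.cond documentation warns that tensors created outside true_fn and false_fn are evaluated regardless of which branch is taken. The training and validation batches here are created before the tf.cond call, so both input pipelines advance on every sess.run call.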
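The toy model below, written in plain Python, illustrates why building the network twice differs from building it once behind a switch. ToyGraph, dense and forward are hypothetical stand-ins, not TensorFlow APIs. They only mimic the TF1 behaviour in which each newly built layer gets its own freshly initialized weight, and "training" is simulated by overwriting that weight.

```python
import random
from typing import Callable, List


class ToyGraph:
    """Hypothetical stand-in for a TF1 graph; not a TensorFlow API."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.weights: List[float] = []

    def dense(self) -> Callable[[float], float]:
        # Building a layer adds a new, freshly initialized weight to the graph.
        self.weights.append(self.rng.uniform(-1.0, 1.0))
        idx = len(self.weights) - 1
        return lambda x: self.weights[idx] * x


# Case 1: the network is built twice, once per input.
g1 = ToyGraph()
train_net = g1.dense()
val_net = g1.dense()          # second build -> second, untrained weight
g1.weights[0] = 2.0           # pretend training learned w = 2.0 for train_net
print(len(g1.weights))        # 2
print(train_net(1.0))         # 2.0
print(val_net(1.0) == 2.0)    # False: validation uses the untrained weight

# Case 2: the network is built once and a flag selects its input.
g2 = ToyGraph()
net = g2.dense()
g2.weights[0] = 2.0           # the same "training" update


def forward(is_training: bool, train_x: float, val_x: float) -> float:
    return net(train_x if is_training else val_x)


print(len(g2.weights))             # 1
print(forward(False, 1.0, 3.0))    # 6.0: validation uses the trained weight
```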

Running the session

_, global_stepnp, total_loss_dict_ = sess.run([train_op, global_step, total_loss_dict], feed_dict = {is_training:True})

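val_loss_list = []
# Fetch the loss tensors (not the numpy results of the training step); no train_op, so weights stay fixed
val_loss_dict_ = sess.run(total_loss_dict, feed_dict={is_training: False})
val_loss_list.append(val_loss_dict_)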
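The two sess.run calls above cover one training step and one validation batch. The framework-free sketch below shows a possible surrounding loop that validates every few steps and averages the per-batch loss dictionaries. run_step is a hypothetical stand-in for those sess.run calls: with True it would also run train_op, and with False it would only fetch losses. The names num_steps, val_every and num_val_batches, and the "total_loss" key, are illustrative assumptions rather than values from the original code.

```python
from typing import Callable, Dict, List

LossDict = Dict[str, float]


def average_losses(loss_dicts: List[LossDict]) -> LossDict:
    # Per-key mean over all validation batches of one validation pass.
    keys = loss_dicts[0].keys()
    return {k: sum(d[k] for d in loss_dicts) / len(loss_dicts) for k in keys}


def train_with_validation(run_step: Callable[[bool], LossDict], num_steps: int,
                          val_every: int, num_val_batches: int) -> List[LossDict]:
    """Run num_steps training steps, validating every val_every steps.

    run_step(is_training) stands in for sess.run with the is_training
    placeholder fed; with True it is expected to also run train_op.
    """
    history: List[LossDict] = []
    for step in range(1, num_steps + 1):
        run_step(True)  # one training step: weights are updated
        if step % val_every == 0:
            # Validation fetches losses only, so the weights are not changed.
            val_loss_list = [run_step(False) for _ in range(num_val_batches)]
            history.append(average_losses(val_loss_list))
    return history


# Fake session for demonstration: the validation loss equals the number of
# training steps taken so far, so the averages are easy to check.
state = {"step": 0}


def fake_run_step(is_training: bool) -> LossDict:
    if is_training:
        state["step"] += 1
    return {"total_loss": float(state["step"])}


history = train_with_validation(fake_run_step, num_steps=10, val_every=5, num_val_batches=3)
print(history)  # [{'total_loss': 5.0}, {'total_loss': 10.0}]
```

Averaging over several validation batches gives a steadier estimate than the single batch fetched in the snippet above.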