Validation during TensorFlow training
程序员文章站
2022-07-13 12:41:51
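Because TensorFlow (the 1.x-style graph API used here) builds a static graph, running validation while training is somewhat more involved than in PyTorch.
Loading the data
Fetch the training-set and validation-set data separately. Here I read the data from TFRecord files.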
# training data
img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train = \
    next_batch(dataset_name=xxx, ..., is_training=True)
# validation data
img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val = \
    next_batch(dataset_name=xxx, ..., is_training=False)
Note the is_training argument: True for the training set, False for the validation set.
Defining the is_training placeholder
is_training = tf.placeholder(tf.bool, shape=())
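A single tf.placeholder controls whether a step trains or validates.
This creates a conditional branch inside one graph, so the value fed to the placeholder decides whether validation is run.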
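The idea can be sketched in isolation as follows. This is a minimal illustration rather than the code from this post: it assumes TF 1.x-style graph mode through `tensorflow.compat.v1`, and the two constant "batches" are hypothetical stand-ins for real input tensors.

```python
import tensorflow.compat.v1 as tf  # assumption: graph mode via the compat module

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Scalar boolean switch fed at run time.
    is_training = tf.placeholder(tf.bool, shape=(), name="is_training")
    # Hypothetical stand-ins for a training batch and a validation batch.
    train_batch = tf.constant([1.0, 2.0, 3.0])
    val_batch = tf.constant([10.0, 20.0, 30.0])
    # One downstream tensor, fed from whichever branch the placeholder selects.
    batch = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
    batch_sum = tf.reduce_sum(batch)

with tf.Session(graph=graph) as sess:
    train_sum = sess.run(batch_sum, feed_dict={is_training: True})
    val_sum = sess.run(batch_sum, feed_dict={is_training: False})
```

Everything built on top of `batch` exists only once in the graph; only the fed value changes between the two `sess.run` calls.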
Using is_training to keep the graph nodes unique
img_name_batch, img_batch, gtboxes_and_label_batch, num_objs_batch, img_h_batch, img_w_batch = \
    tf.cond(is_training,
            lambda: (img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train,
                     num_objs_batch_train, img_h_batch_train, img_w_batch_train),
            lambda: (img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val,
                     num_objs_batch_val, img_h_batch_val, img_w_batch_val))
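If you do not use tf.cond() (for example, if you build the network a second time on the validation tensors), many new nodes are added to the original graph. The parameters of those nodes have to be initialized again, which means validation would not be using the trained weights.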
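To make the weight-sharing point concrete, here is a small sketch under stated assumptions: a hypothetical one-parameter linear model replaces the real detection network, the data are made-up constants, and graph mode comes from `tensorflow.compat.v1`. The model is built once on the output of `tf.cond`, so the validation loss is computed with the very weight that training updates.

```python
import tensorflow.compat.v1 as tf  # assumption: graph mode via the compat module

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    is_training = tf.placeholder(tf.bool, shape=())
    # Made-up data for y = 2 * x; both splits use the same shape.
    x_train, y_train = tf.constant([[1.0], [2.0]]), tf.constant([[2.0], [4.0]])
    x_val, y_val = tf.constant([[3.0], [4.0]]), tf.constant([[6.0], [8.0]])
    x, y = tf.cond(is_training, lambda: (x_train, y_train), lambda: (x_val, y_val))

    # The model is built exactly once, on the selected inputs.
    w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
    init_op = tf.global_variables_initializer()
    num_variables = len(graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    val_loss_before = sess.run(loss, feed_dict={is_training: False})
    for _ in range(50):
        sess.run(train_op, feed_dict={is_training: True})
    # Same loss tensor, same variable: validation sees the trained weight.
    val_loss_after = sess.run(loss, feed_dict={is_training: False})
```

The graph holds a single variable, and the validation loss falls after training even though no training step ever saw the validation data. One caveat worth checking in a real pipeline: tensors created outside the two lambdas (as both input batches are here) are evaluated on every step regardless of the branch taken, so both input pipelines may advance on each `sess.run`.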
Running the session
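# one training step
_, global_stepnp, total_loss_dict_ = sess.run([train_op, global_step, total_loss_dict], feed_dict={is_training: True})
# one validation step: fetch the tensor total_loss_dict, not the already-fetched result total_loss_dict_
val_loss_list = []
val_loss_dict_ = sess.run(total_loss_dict, feed_dict={is_training: False})
val_loss_list.append(val_loss_dict_)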
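A validation pass usually covers several batches, so the per-batch results collected in val_loss_list need to be reduced to one number per loss. The helper below is a hypothetical sketch, not part of the original code: it assumes each fetched result is a dict mapping a loss name to a scalar and that every dict has the same keys. The loss names in the sample data are invented for illustration.

```python
def average_loss_dicts(loss_dicts):
    """Average a list of {loss_name: scalar} dicts key by key."""
    if not loss_dicts:
        raise ValueError("loss_dicts must not be empty")
    totals = {}
    for loss_dict in loss_dicts:
        for name, value in loss_dict.items():
            # float() also accepts NumPy scalars returned by sess.run.
            totals[name] = totals.get(name, 0.0) + float(value)
    return {name: total / len(loss_dicts) for name, total in totals.items()}


# Hypothetical results from two validation batches.
val_loss_list = [
    {"rpn_loss": 0.5, "cls_loss": 1.0},
    {"rpn_loss": 1.5, "cls_loss": 3.0},
]
mean_val_loss = average_loss_dicts(val_loss_list)
```

In the training loop, this would be called after running the validation step (with is_training fed as False) for as many batches as the validation set needs.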