TensorFlow slim code example

A slim code example covering reading TFRecords, training with batch normalization (BN), and L2 regularization.

Reading tfrecords
import glob
import cv2
import numpy as np
import tensorflow as tf
# import a few parameters from a custom config.py
from config import HEIGHT, WIDTH, CHANNEL

def read(file_queue):
    # Parse one serialized tf.train.Example from the file queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={'image': tf.FixedLenFeature([], tf.string),
                  'field': tf.VarLenFeature(tf.float32)}
    )
    # image is stored as raw uint8 bytes; field as float32 values
    image = tf.decode_raw(features['image'], tf.uint8)
    field = tf.sparse_tensor_to_dense(features['field'], default_value=0)
    image = tf.reshape(image, [HEIGHT, WIDTH, CHANNEL])
    field = tf.reshape(field, [HEIGHT, WIDTH, 1])
    return image, field

def read_batch(file_queue, batch_size):
    image, field = read(file_queue)
    # shuffle_batch keeps at least min_after_dequeue examples buffered and
    # mixes them across num_threads reader threads before emitting each batch
    images, fields = tf.train.shuffle_batch(
        tensors=[image, field], batch_size=batch_size,
        capacity=400, min_after_dequeue=200, num_threads=4
    )
    return images, fields

# Demo
if __name__ == '__main__':
    files = glob.glob('../dataset/car/tfrecords/training/*.tfrecords')
    file_queue = tf.train.string_input_producer(
        string_tensor=tf.train.match_filenames_once(files),
        num_epochs=1,
        shuffle=True)
    image_batch, field_batch = read_batch(file_queue=file_queue, batch_size=16)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        images, fields = sess.run([image_batch, field_batch])
        for i in range(16):
            cv2.imshow('image', images[i, :, :, :])
            cv2.imshow('field', np.uint8(fields[i, :, :, :] * 255))
            cv2.waitKey(0)
        # Stop and join the queue runners after the display loop, not inside it,
        # so all 16 samples are shown
        coord.request_stop()
        coord.join(threads)
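
For reference, the records consumed above pair a raw uint8 image buffer under the key 'image' with a flat float32 list under 'field'. A minimal writer sketch that would produce compatible files might look like the following; the output filename and the zero-filled placeholder arrays are assumptions, not from the original article:

import numpy as np
import tensorflow as tf

from config import HEIGHT, WIDTH, CHANNEL

def write_example(writer, image, field):
    # image: uint8 array of shape [HEIGHT, WIDTH, CHANNEL]
    # field: float32 array of shape [HEIGHT, WIDTH, 1]
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        'field': tf.train.Feature(float_list=tf.train.FloatList(value=field.flatten()))
    }))
    writer.write(example.SerializeToString())

if __name__ == '__main__':
    # Hypothetical usage with placeholder data
    with tf.python_io.TFRecordWriter('demo.tfrecords') as writer:
        image = np.zeros([HEIGHT, WIDTH, CHANNEL], dtype=np.uint8)
        field = np.zeros([HEIGHT, WIDTH, 1], dtype=np.float32)
        write_example(writer, image, field)
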
The complete network, training, and testing
import os
import numpy as np
import cv2
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

from data_reader import read_batch
from config import HEIGHT, WIDTH, CHANNEL
from config import EXAMPLE_NUM, BATCH_SIZE, MAX_STEP
from config import L2_REGULARIZER, MOMENTUM, LEARNING_RATE
from config import DECAY_STEP, DECAY_RATE, SAVE_STEP
from config import TRAIN_SET_DIR, VAL_SET_DIR, LOG_DIR, MODEL_PATH


class EFNet:

    def __init__(self):
        print('New network')
        self.input_image = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, CHANNEL], name='input_image')
        self.input_label = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, 1], name='input_label')
        self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')

        self.prediction = None
        self.field_loss, self.regularization_loss, self.total_loss = [None] * 3
        self.learning_rate, self.global_step, self.train_step = [None] * 3
        self.optimizer, self.update_ops = [None] * 2

    def id_block(self, x, d, scope):
        # Identity residual block (1x1 -> 3x3 -> 1x1 convolutions with BN);
        # x must already have 2*d channels so the shortcut x + y is valid
        with slim.arg_scope([slim.conv2d],
                            weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                            activation_fn=tf.nn.relu,
                            normalizer_fn=tf.layers.batch_normalization,
                            normalizer_params={'training': self.is_training, 'momentum': MOMENTUM}):
            y = slim.conv2d(inputs=x, num_outputs=d, kernel_size=[1, 1], scope=scope + '_conv1')
            y = slim.conv2d(inputs=y, num_outputs=d, kernel_size=[3, 3], scope=scope + '_conv2')
            y = slim.conv2d(inputs=y, num_outputs=2*d, kernel_size=[1, 1], activation_fn=None, scope=scope + '_conv3')
            return tf.nn.relu(x + y)

    def conv_block(self, x, d, scope):
        # Projection residual block: a 1x1 convolution on the shortcut
        # brings x up to the 2*d output channels before the addition
        with slim.arg_scope([slim.conv2d],
                            weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                            activation_fn=tf.nn.relu,
                            normalizer_fn=tf.layers.batch_normalization,
                            normalizer_params={'training': self.is_training, 'momentum': MOMENTUM}):
            y1 = slim.conv2d(inputs=x, num_outputs=d, kernel_size=[1, 1], scope=scope + '_conv_1')
            y1 = slim.conv2d(inputs=y1, num_outputs=d, kernel_size=[3, 3], scope=scope + '_conv_2')
            y1 = slim.conv2d(inputs=y1, num_outputs=2*d, kernel_size=[1, 1], activation_fn=None, scope=scope + '_conv3')
            y2 = slim.conv2d(inputs=x, num_outputs=2*d, kernel_size=[1, 1], activation_fn=None, scope=scope + '_conv4')
            return tf.nn.relu(y1 + y2)

    @staticmethod
    def downsample(x, scope):
        return slim.avg_pool2d(inputs=x, kernel_size=[2, 2], stride=2, padding='VALID', scope=scope)

    @staticmethod
    def upsample(x):
        # Double the spatial dimensions (bilinear is the tf.image.resize_images default)
        h, w = int(x.get_shape()[1]), int(x.get_shape()[2])
        return tf.image.resize_images(images=x, size=(h * 2, w * 2))

    def build_net(self):
        # inference
        with tf.name_scope('inference'):
            net = self.input_image / 255.0 - 0.5  # scale pixels to [-0.5, 0.5]
            # encoder
            net = slim.conv2d(inputs=net, num_outputs=64, kernel_size=[5, 5], stride=1, padding='SAME',
                              normalizer_fn=tf.layers.batch_normalization,
                              normalizer_params={'training': self.is_training, 'momentum': MOMENTUM},
                              weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                              scope='conv1_of_encoder')
            net = slim.conv2d(inputs=net, num_outputs=64, kernel_size=[3, 3], stride=1, padding='SAME',
                              normalizer_fn=tf.layers.batch_normalization,
                              normalizer_params={'training': self.is_training, 'momentum': MOMENTUM},
                              weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                              scope='conv2_of_encoder')
            # id_block*1, conv_block, downsample
            net = self.id_block(x=net, d=32, scope='id1_of_encoder')
            net = self.conv_block(x=net, d=64, scope='conv1_of_encoder')
            net = self.downsample(x=net, scope='pool1_of_encoder')
            # id_block*2, conv_block, downsample
            net = self.id_block(x=net, d=64, scope='id2_of_encoder')
            net = self.id_block(x=net, d=64, scope='id3_of_encoder')
            net = self.conv_block(x=net, d=128, scope='conv2_of_encoder')
            net = self.downsample(x=net, scope='pool2_of_encoder')
            # id_block*3, conv_block, downsample
            net = self.id_block(x=net, d=128, scope='id4_of_encoder')
            net = self.id_block(x=net, d=128, scope='id5_of_encoder')
            net = self.id_block(x=net, d=128, scope='id6_of_encoder')
            net = self.conv_block(x=net, d=256, scope='conv3_of_encoder')
            net = self.downsample(x=net, scope='pool3_of_encoder')

            # decoder
            # convolution, upsample, id*2
            net = slim.conv2d(inputs=net, num_outputs=256, kernel_size=[3, 3], stride=1, padding='SAME',
                              normalizer_fn=tf.layers.batch_normalization,
                              normalizer_params={'training': self.is_training, 'momentum': MOMENTUM},
                              weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                              scope='conv1_of_decoder')
            net = self.upsample(x=net)
            net = self.id_block(x=net, d=128, scope='id1_of_decoder')
            net = self.id_block(x=net, d=128, scope='id2_of_decoder')
            # convolution, upsample, id*1
            net = slim.conv2d(inputs=net, num_outputs=128, kernel_size=[3, 3], stride=1, padding='SAME',
                              normalizer_fn=tf.layers.batch_normalization,
                              normalizer_params={'training': self.is_training, 'momentum': MOMENTUM},
                              weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                              scope='conv2_of_decoder')
            net = self.upsample(x=net)
            net = self.id_block(x=net, d=64, scope='id3_of_decoder')
            # convolution, upsample, id*1
            net = slim.conv2d(inputs=net, num_outputs=64, kernel_size=[3, 3], stride=1, padding='SAME',
                              normalizer_fn=tf.layers.batch_normalization,
                              normalizer_params={'training': self.is_training, 'momentum': MOMENTUM},
                              weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                              scope='conv3_of_decoder')
            net = self.upsample(x=net)
            net = self.id_block(x=net, d=32, scope='id4_of_decoder')
            self.prediction = slim.conv2d(inputs=net, num_outputs=1, kernel_size=[3, 3], stride=1, padding='SAME',
                                          normalizer_fn=None,
                                          normalizer_params=None,
                                          weights_regularizer=slim.l2_regularizer(L2_REGULARIZER),
                                          scope='conv_of_output')

        # loss
        with tf.name_scope('loss'):
            self.field_loss = tf.losses.mean_squared_error(labels=self.input_label, predictions=self.prediction)
            # sum the L2 penalties registered via weights_regularizer
            self.regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
            self.total_loss = self.field_loss + self.regularization_loss

        # train_step
        with tf.name_scope('train_op'):
            self.global_step = tf.Variable(0, trainable=False)
            self.learning_rate = tf.train.exponential_decay(
                learning_rate=LEARNING_RATE, global_step=self.global_step,
                decay_steps=DECAY_STEP, decay_rate=DECAY_RATE, staircase=True)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            self.train_step = slim.learning.create_train_op(
                total_loss=self.total_loss, optimizer=self.optimizer, global_step=self.global_step)
            # Run the BN moving-average updates (tf.GraphKeys.UPDATE_OPS) before each
            # train step; create_train_op already collects these by default, so the
            # explicit dependency here is defensive
            self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if self.update_ops:
                self.train_step = control_flow_ops.with_dependencies([tf.group(*self.update_ops)], self.train_step)

    def train(self):
        train_file = [os.path.join(TRAIN_SET_DIR, file) for file in os.listdir(TRAIN_SET_DIR)]
        train_file_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(train_file),
            shuffle=True)
        train_image_batch, train_label_batch = read_batch(train_file_queue, BATCH_SIZE)
        val_file = [os.path.join(VAL_SET_DIR, file) for file in os.listdir(VAL_SET_DIR)]
        val_file_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(val_file),
            shuffle=False)
        val_image_batch, val_label_batch = read_batch(val_file_queue, BATCH_SIZE)

        train_sum_dir = LOG_DIR + '/train'
        val_sum_dir = LOG_DIR + '/validation'
        loss_sum = tf.summary.scalar('mse', self.field_loss)

        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            train_writer = tf.summary.FileWriter(train_sum_dir, sess.graph)
            val_writer = tf.summary.FileWriter(val_sum_dir, sess.graph)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            try:
                while not coord.should_stop():
                    train_images, train_labels = sess.run([train_image_batch, train_label_batch])
                    step, _ = sess.run([self.global_step, self.train_step],
                                       feed_dict={self.input_image: train_images,
                                                  self.input_label: train_labels,
                                                  self.is_training: True})

                    if step % 50 == 0 or step == 1:
                        train_loss, lr, train_sum = \
                            sess.run([self.field_loss, self.learning_rate, loss_sum],
                                     feed_dict={self.input_image: train_images,
                                                self.input_label: train_labels,
                                                self.is_training: True})
                        train_writer.add_summary(train_sum, step)
                        print('Epoch: {}, step: {}, train_loss: {:.6f}, lr: {:.6f}'
                              .format(step * BATCH_SIZE // EXAMPLE_NUM + 1, step, train_loss, lr))

                        val_images, val_labels = sess.run([val_image_batch, val_label_batch])
                        val_loss, val_sum = sess.run([self.field_loss, loss_sum],
                                                     feed_dict={self.input_image: val_images,
                                                                self.input_label: val_labels,
                                                                self.is_training: False})
                        val_writer.add_summary(val_sum, step)
                        print('val_loss: {:.6f}'.format(val_loss))

                    if step % SAVE_STEP == 0:
                        saver.save(sess, MODEL_PATH, step)

                    if step == MAX_STEP:
                        break

            except tf.errors.OutOfRangeError:
                print('Training done!')
            finally:
                coord.request_stop()
            coord.join(threads)
            train_writer.close()
            val_writer.close()
        print('Training finished!')

    def predict(self, model, image):
        # model: checkpoint path to restore; image: path to a single test image file
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess=sess, save_path=model)
            img = np.reshape(a=cv2.imread(image), newshape=[1, HEIGHT, WIDTH, CHANNEL])
            field = sess.run(self.prediction, feed_dict={self.input_image: img, self.is_training: False})
            field = field[0, :, :, :]
            cv2.imshow('field', np.uint8(field * 255))
            cv2.waitKey(0)

if __name__ == '__main__':
    # train
    # my_net = EFNet()
    # my_net.build_net()
    # my_net.train()

    # predict
    # model = '../models/car/Car.ckpt'
    # image = '../dataset/car/test/'
    # my_net = EFNet()
    # my_net.build_net()
    # my_net.predict(model, image)
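
Both listings import their hyper-parameters from a config.py that the article never shows. A plausible sketch, where every value is an assumption chosen only to make the code self-contained (the car-dataset paths follow the paths that do appear above):

# config.py -- hypothetical values; the original article does not list them
HEIGHT, WIDTH, CHANNEL = 256, 256, 3   # must be divisible by 8: the encoder downsamples 2x three times
EXAMPLE_NUM = 10000                    # training-set size, used for the epoch counter
BATCH_SIZE = 16
MAX_STEP = 100000
SAVE_STEP = 5000
L2_REGULARIZER = 1e-4                  # weight-decay coefficient
MOMENTUM = 0.9                         # batch-norm moving-average momentum
LEARNING_RATE = 1e-3
DECAY_STEP = 10000
DECAY_RATE = 0.9
TRAIN_SET_DIR = '../dataset/car/tfrecords/training'
VAL_SET_DIR = '../dataset/car/tfrecords/validation'
LOG_DIR = '../logs/car'
MODEL_PATH = '../models/car/Car.ckpt'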