
ResNetWhaleTailClassification (Step 5: the ResNetTrain module)

import tensorflow as tf
import tensorflow.contrib.slim.nets as nets
import os

def read_and_decode_tfrecord(filename):
    # Build the filename queue (filename is a list of TFRecord paths) and the reader
    filename_queue = tf.train.string_input_producer(filename)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # returns the key and the serialized example
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img': tf.FixedLenFeature([], tf.string),
                                       })  # parse the serialized example into its image and label features
    label = tf.cast(features['label'], tf.int32)
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) / 255   # normalize pixel values to [0, 1]
    return img, label
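
The reader above expects each TFRecord example to hold the raw uint8 bytes of a 224x224x3 image under the key 'img' and an int64 class index under 'label'. For reference, here is a minimal sketch of a matching writer; the helper name and the use of OpenCV are illustrative assumptions, not the project's actual TFRecord-writing module.

import cv2
import tensorflow as tf

def write_one_example(writer, image_path, label):
    # writer is a tf.python_io.TFRecordWriter
    # resize to 224x224 and store the raw uint8 bytes, matching tf.decode_raw above
    img = cv2.resize(cv2.imread(image_path), (224, 224))
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
    }))
    writer.write(example.SerializeToString())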

save_model_dir = 'model/'
if not os.path.exists(save_model_dir):
    os.makedirs(save_model_dir)

# Batch size, learning rate, and placeholders for the input images (x) and labels (y)
batch_size_ = 2
lr = tf.Variable(0.0001, dtype=tf.float32)
x = tf.placeholder(tf.float32, [None, 224, 224, 3])
y = tf.placeholder(tf.float32, [None])

# Collect all training TFRecord files
train_list = []
class_path = 'data/tfrecord/'
for dir_name in os.listdir(class_path):
    if 'train' in dir_name:
        train_list.append(os.path.join(class_path, dir_name))

img, label = read_and_decode_tfrecord(train_list)

# Shuffle examples from the queue into batches: the buffer holds up to `capacity` examples
# and always keeps at least `min_after_dequeue` of them for good shuffling
img_batch, label_batch = tf.train.shuffle_batch([img, label], num_threads=2, batch_size=batch_size_, capacity=100,
                                                min_after_dequeue=95)
# One-hot encode the labels (50 whale classes)
one_hot_labels = tf.one_hot(indices=tf.cast(y, tf.int32), depth=50)

# ResNet-50 v2 forward pass: returns the class logits and a dict of the network's intermediate activations (end points)
pre, end_points = nets.resnet_v2.resnet_v2_50(x, num_classes=50, is_training=True)
# the logits come back with shape [batch, 1, 1, 50]; flatten them to [batch, 50]
pre = tf.reshape(pre, shape=[-1, 50])

# Softmax cross-entropy loss and Adam optimizer
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pre, labels=one_hot_labels))
# run the batch-norm moving-average updates together with each optimizer step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

# Training accuracy: compare the predicted class with the ground-truth class
a = tf.argmax(pre, 1)
b = tf.argmax(one_hot_labels, 1)
correct_pre = tf.equal(a, b)
accuracy = tf.reduce_mean(tf.cast(correct_pre, tf.float32))

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Create a coordinator to manage the queue-runner threads
    coord = tf.train.Coordinator()
    # Start the QueueRunners; the filename queue now begins filling
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    while True:
        i += 1
        b_image, b_label = sess.run([img_batch, label_batch])
        if len(b_label) == batch_size_:   # skip incomplete batches
            # the extra fetches (y_t, y_p, a_, b_) are kept only for debugging
            acc_train, _, loss_, y_t, y_p, a_, b_ = sess.run(
                [accuracy, optimizer, loss, one_hot_labels, pre, a, b],
                feed_dict={x: b_image, y: b_label})
            print('step: {}, train_loss: {}, acc_train: {}'.format(i, loss_, acc_train))
            # save intermediate and final checkpoints
            if i in (500, 800, 1000):
                saver.save(sess, os.path.join(save_model_dir, 'model.ckpt'), global_step=i)
            if i == 1000:
                break
    # ask the queue-runner threads to stop once training is finished
    coord.request_stop()
    # join() returns only after all threads have shut down
    coord.join(threads)
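
Once training finishes, the checkpoints under model/ can be restored for inference. Below is a minimal sketch, assuming the graph is rebuilt in a fresh process and the checkpoint prefix written above; the variable names are illustrative.

import tensorflow as tf
import tensorflow.contrib.slim.nets as nets

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
# rebuild the same network, but with is_training=False so batch norm uses its moving statistics
logits, _ = nets.resnet_v2.resnet_v2_50(x, num_classes=50, is_training=False)
logits = tf.reshape(logits, shape=[-1, 50])
pred_class = tf.argmax(logits, 1)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('model/'))
    # images: a float32 numpy array of shape [N, 224, 224, 3], scaled to [0, 1]
    # predictions = sess.run(pred_class, feed_dict={x: images})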
