欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

针对tfrecords数据集,标签和图片读取不一致问题。

程序员文章站 2024-01-30 15:30:58
...

前一段时间,训练2分类深度网络时,loss一直维持在2.3左右。在网上看了很多博客,最后从这篇博客中找到了,解决的方法。具体的不详细说了,可以参考那篇博客。我按照上面的提到问题,仔细检测自己的网络,发现可能是我的数据集中的图片和标签不一致所造成。

我的数据集是tfrecords格式的二进制文件(前提:确保制作数据集的图片不存在问题,不然读取时会报错。),我制作代码如下:

import glob
import tensorflow as tf
from PIL import Image
import numpy as np
import random

num=0
bestnum=5000   
recordfilenum=0
filenames=[]
for filename in glob.glob('./data/PetImages/Cat/*.jpg')[2500:3000]:
    tmp=[]
    tmp.append(filename)
    tmp.append(0)
    filenames.append(tmp)
for filename in glob.glob('./data/PetImages/Dog/*.jpg')[2500:3000]:
    tmp=[]
    tmp.append(filename)
    tmp.append(1)
    filenames.append(tmp)
   
random.shuffle(filenames)

for filename in filenames:           
    if not num % bestnum:    #超过1000,写入下一个tfrecord        
        recordfilenum += 1
        ftrecordfilename = ("valpetdata.tfrecords_%.3d" % recordfilenum)
        writer = tf.python_io.TFRecordWriter(os.path.join('./data',ftrecordfilename))
    num = num + 1
    img = Image.open(filename[0], 'r')    
    img = img.resize([224,224],Image.ANTIALIAS)
    img_raw = img.tobytes()  # 将图片转化为二进制格式
    example = tf.train.Example(
    features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value[filename[1]])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    }))
    writer.write(example.SerializeToString())  # 序列化为字符串    
print('制作完成!!!!!')
writer.close()

通过上面的代码,成功的制作了tfrecords文件,通过以下代码,读取数据集(这段代码是问题代码)

# -*- coding: utf-8 -*-

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

def read_and_decode_tfrecord(filename):
    filename_deque = tf.train.string_input_producer(filename)    
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_deque)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)})
    
    label = tf.cast(features['label'], tf.int32)
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])   
    img = tf.cast(img, tf.float32) / 255.0
    return img, label

val_list = ['./data/trpetdata.tfrecords_001']
valimg, vallabel = read_and_decode_tfrecord(val_list)
img,label = tf.train.batch([valimg, vallabel ], batch_size=64, capacity=500)

#img_batch, label_batch =tf.train.shuffle_batch([img, label], batch_size=128, #capacity=3500,min_after_dequeue=1000)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 创建一个协调器,管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tf.train.start_queue_runners(sess)
    for i in range(120):
        print('第%d次:'%i,label.eval()[0])
        plt.imshow(img.eval()[0])
        plt.axis('off')
        plt.show()
    coord.request_stop()
    coord.join(threads)

通过上面代码,会发现有时候标签类型和图片对应不上(前提是制作数据集时,标签和图片是一一对应)咨询了别人,发现存在的问题:tfrecords格式的数据集,要同时读取图片和标签,才能保证图片和标签一一对应。根据问题,修改读取代码:

val_list = ['./data/trpetdata.tfrecords_001']
valimg, vallabel = read_and_decode_tfrecord(val_list)
img,label = tf.train.batch([valimg, vallabel ], batch_size=64, capacity=500)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 创建一个协调器,管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tf.train.start_queue_runners(sess)
    for i in range(120):
        md_img,md_label=sess.run([img,label])
        print('第%d次:'%i,md_label[0])
        plt.imshow(md_img[0])
        plt.axis('off')
        plt.show()
    coord.request_stop()
    coord.join(threads)

这样读取的图片和标签也就一一对应了。