针对tfrecords数据集,标签和图片读取不一致问题。
程序员文章站
2024-01-30 15:30:58
...
前一段时间,训练2分类深度网络时,loss一直维持在2.3左右。在网上看了很多博客,最后从这篇博客中找到了,解决的方法。具体的不详细说了,可以参考那篇博客。我按照上面的提到问题,仔细检测自己的网络,发现可能是我的数据集中的图片和标签不一致所造成。
我的数据集是tfrecords格式的二进制文件(前提:确保制作数据集的图片不存在问题,不然读取时会报错。),我制作代码如下:
import glob
import tensorflow as tf
from PIL import Image
import numpy as np
import random
num=0
bestnum=5000
recordfilenum=0
filenames=[]
for filename in glob.glob('./data/PetImages/Cat/*.jpg')[2500:3000]:
tmp=[]
tmp.append(filename)
tmp.append(0)
filenames.append(tmp)
for filename in glob.glob('./data/PetImages/Dog/*.jpg')[2500:3000]:
tmp=[]
tmp.append(filename)
tmp.append(1)
filenames.append(tmp)
random.shuffle(filenames)
for filename in filenames:
if not num % bestnum: #超过1000,写入下一个tfrecord
recordfilenum += 1
ftrecordfilename = ("valpetdata.tfrecords_%.3d" % recordfilenum)
writer = tf.python_io.TFRecordWriter(os.path.join('./data',ftrecordfilename))
num = num + 1
img = Image.open(filename[0], 'r')
img = img.resize([224,224],Image.ANTIALIAS)
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value[filename[1]])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
}))
writer.write(example.SerializeToString()) # 序列化为字符串
print('制作完成!!!!!')
writer.close()
通过上面的代码,成功的制作了tfrecords文件,通过以下代码,读取数据集(这段代码是问题代码):
# -*- coding: utf-8 -*-
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
def read_and_decode_tfrecord(filename):
filename_deque = tf.train.string_input_producer(filename)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_deque)
features = tf.parse_single_example(serialized_example, features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)})
label = tf.cast(features['label'], tf.int32)
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [224, 224, 3])
img = tf.cast(img, tf.float32) / 255.0
return img, label
val_list = ['./data/trpetdata.tfrecords_001']
valimg, vallabel = read_and_decode_tfrecord(val_list)
img,label = tf.train.batch([valimg, vallabel ], batch_size=64, capacity=500)
#img_batch, label_batch =tf.train.shuffle_batch([img, label], batch_size=128, #capacity=3500,min_after_dequeue=1000)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 创建一个协调器,管理线程
coord = tf.train.Coordinator()
# 启动QueueRunner,此时文件名队列已经进队
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
tf.train.start_queue_runners(sess)
for i in range(120):
print('第%d次:'%i,label.eval()[0])
plt.imshow(img.eval()[0])
plt.axis('off')
plt.show()
coord.request_stop()
coord.join(threads)
通过上面代码,会发现有时候标签类型和图片对应不上(前提是制作数据集时,标签和图片是一一对应)。咨询了别人,发现存在的问题:tfrecords格式的数据集,要同时读取图片和标签,才能保证图片和标签一一对应。根据问题,修改读取代码:
val_list = ['./data/trpetdata.tfrecords_001']
valimg, vallabel = read_and_decode_tfrecord(val_list)
img,label = tf.train.batch([valimg, vallabel ], batch_size=64, capacity=500)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 创建一个协调器,管理线程
coord = tf.train.Coordinator()
# 启动QueueRunner,此时文件名队列已经进队
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
tf.train.start_queue_runners(sess)
for i in range(120):
md_img,md_label=sess.run([img,label])
print('第%d次:'%i,md_label[0])
plt.imshow(md_img[0])
plt.axis('off')
plt.show()
coord.request_stop()
coord.join(threads)
这样读取的图片和标签也就一一对应了。
上一篇: linux查看硬件常用命令
下一篇: 大数据消息中间件Kafka概述学习