欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

如何使用tf.data读取tfrecords数据集2

程序员文章站 2024-01-19 12:57:34
...

在检查完了数据是否一样后就要开始转图片格式,其实不一定要这一步,但是我怕数据的类型不同影响数据集的效果。

import os
import tensorflow as tf
from PIL import Image
import PIL
import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import imread, imsave, imresize

cwd = r'/home/hehe/python/dataset1/washing/'

num=0
#使用os.listdir()获取cwd里面所有文件,然后转化格式保存
for img_name in os.listdir(cwd):
    img_path = cwd + img_name


    img = Image.open(img_path)
    img = img.resize((100, 100))
     # # img.show()
    img.save('/home/hehe/python/load_cifar10/datadir/washing/washing{}.JPEG'.format(num))
    num += 1
print("change format finish")
   

这个是我转化图片的完整代码,在cwd里面使用os.listdir得到所有数据的名字,然后使用PIL模块中的resize来定制大小,最后使用img.save来保存数据到指定的数据集中。

好了,这个就是转换图片的代码。接下来就要使用代码为所有的数据打标签,然后shuffle

#下面的代码是为了生成list.txt , 把不同文件夹下的图片和 数字label对应起来
import os

classes = {'bookrack':1 ,'cleaner':2, 'fan':3, 'lamp':4, 'microwave':5,
           'soft':6, 'bed':7, 'chair':0, 'washing':8, 'desk':9}
data_dir = r'/home/hehe/python/load_cifar10/datadir/'
output_path = 'list.txt'
fd = open(output_path, 'w')
for class_name in classes.keys():
    images_list = os.listdir(data_dir + class_name)
    for image_name in images_list:
        fd.write('{}/{} {}\n'.format(class_name, image_name, classes[class_name]))
fd.close()
print('finish task')

在classes里面定义好类型,字典里面key是文件夹的名字,value是lable。下面的代码是把所有文件名打乱,相当于shuffle效果

#随机生成训练集和验证集(在总量中随机选取_NUM_VALIDATION=100个样本作为验证集)

import random
_NUM_VALIDATION = 2000
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'
fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)
fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
    fd.write(line)
fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
    fd.write(line)
fd.close()

也有另外一种比较简单的方案

#随机生成训练集和验证集(在总量中随机选取_NUM_VALIDATION=100个样本作为验证集)


import random


ratio=0.2                #选择0.2,那么意味着你的测试集只有20%,训练集80%
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'
fd = open(list_path)
lines = fd.readlines()
_NUM_VALIDATION = int(len(lines)*ratio)
_RANDOM_SEED = 0
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)
fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
    fd.write(line)
fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
    fd.write(line)
fd.close()

这种方案只需要改变ratio的大小就够了,很方便