TensorFlow 读取图片并写入tfrecord
程序员文章站
2022-03-20 17:33:44
...
图片下载地址:
链接:https://pan.baidu.com/s/1gvvr5ovcYT1pQTy0umpzrA
提取码:bh0t
import os
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
from PIL import Image
%matplotlib inline
from tqdm import tqdm
print(tf.__version__)
print(np.__version__)
# 读取文件夹文件与标签
def load_sample(sample_dir):
print("加载图片数据")
file_name_list = []
labels_names = []
for(dir_path, dir_names, file_names) in os.walk(sample_dir):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
# 获取图片路径与文件夹名字(标签)
file_name_list.append(file_path)
labels_names.append(dir_path.split("\\")[-1])
lab = list(sorted(set(labels_names)))
labdict = dict(zip(lab, list(range(len(lab)))))
labels = [labdict[i] for i in labels_names]
return (np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)
# 读取文件名与标签
data_dir = 'man_woman\\' # 定义文件路径
(images, labels), labelsnames = load_sample(data_dir) # 载入文件名称与标签
print(len(images), images) # 文件名
print(len(labels), labels) # 标签
print(labelsnames) # 标签字符串
def makeTFRec(filenames, labels): # 定义函数生成TFRecord
output_dir = "tfrecord_dir"
if not os.path.exists(output_dir):
os.mkdir(output_dir)
filename = "mydata.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
for i in tqdm(range(0, len(labels))):
img = Image.open(filenames[i])
img = img.resize((256, 256))
img_raw = img.tobytes() # 将图片转化为二进制格式
features = tf.train.Features(feature = {
"label":tf.train.Feature(
int64_list=tf.train.Int64List(value=[labels[i]])),
"img_raw":tf.train.Feature(
bytes_list = tf.train.BytesList(value=[img_raw]))
})
example = tf.train.Example(features=features) # example对象对label和image数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
makeTFRec(images, labels)
上一篇: 怎么用JS做出切换隐藏与显示同时切换图标
下一篇: Apache Geode 缓存管理介绍