欢迎您访问程序员文章站!本站旨在为大家提供并分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow 读取图片并写入 TFRecord

程序员文章站 2022-03-20 17:33:44
...

图片下载地址:
链接:https://pan.baidu.com/s/1gvvr5ovcYT1pQTy0umpzrA
提取码:bh0t

import os
import tensorflow as tf 
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
from PIL import Image
%matplotlib inline
from tqdm import tqdm

# Report the TensorFlow and NumPy versions in use, one per line.
print(tf.__version__, np.__version__, sep="\n")
# Collect file paths and their folder-derived labels from a directory tree.
def load_sample(sample_dir):
    """Walk ``sample_dir`` and label every file by its containing folder.

    Each file found by ``os.walk`` is recorded together with the name of the
    directory that directly contains it. The distinct folder names are sorted
    and numbered from 0, and that number is used as the file's label.

    Args:
        sample_dir: Root directory to scan recursively.

    Returns:
        A pair ``((file_paths, labels), label_names)`` of NumPy arrays:
        ``file_paths`` holds the path of every file found, ``labels`` holds
        the integer label of each file (same order), and ``label_names``
        holds the sorted folder names, so ``label_names[labels[k]]`` is the
        folder name of ``file_paths[k]``.
    """
    print("加载图片数据")
    file_name_list = []
    labels_names = []

    for dir_path, _dir_names, file_names in os.walk(sample_dir):
        # os.path.basename splits on the platform's own separator(s), so the
        # folder name is extracted correctly on POSIX as well as on Windows
        # (splitting on a literal backslash only worked on Windows).
        folder_name = os.path.basename(dir_path)
        for file_name in file_names:
            file_name_list.append(os.path.join(dir_path, file_name))
            labels_names.append(folder_name)

    # Sorting makes the name -> index mapping independent of walk order.
    lab = sorted(set(labels_names))
    labdict = {name: index for index, name in enumerate(lab)}
    labels = [labdict[name] for name in labels_names]

    return (np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)
# Load the file names and labels.
# os.path.join(..., '') appends the platform's own trailing separator, so
# this is 'man_woman\\' on Windows and 'man_woman/' on POSIX instead of a
# hard-coded backslash that only names the right directory on Windows.
data_dir = os.path.join('man_woman', '')  # root directory of the image folders

(images, labels), labelsnames = load_sample(data_dir)  # file paths, labels, label names
print(len(images), images)  # file paths
print(len(labels), labels)  # integer labels
print(labelsnames)  # label names (folder names)
def makeTFRec(filenames, labels, output_dir="tfrecord_dir",
              filename="mydata.tfrecords", image_size=(256, 256)):
    """Write resized images and their labels to a single TFRecord file.

    Every image is opened with PIL, resized to ``image_size`` and stored as
    raw bytes under the ``"img_raw"`` feature; its label is stored as an
    int64 under the ``"label"`` feature. One ``tf.train.Example`` is written
    per label.

    Args:
        filenames: Indexable collection of image file paths.
        labels: Sequence of integer labels; ``labels[i]`` belongs to
            ``filenames[i]``.
        output_dir: Directory that receives the TFRecord file; created if
            it does not exist.
        filename: Name of the TFRecord file inside ``output_dir``.
        image_size: ``(width, height)`` every image is resized to.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_dir, exist_ok=True)
    filename_fullpath = os.path.join(output_dir, filename)

    with tf.io.TFRecordWriter(filename_fullpath) as writer:
        for i, label in enumerate(tqdm(labels)):
            # The context manager closes the underlying file handle, which
            # would otherwise stay open for every image processed.
            with Image.open(filenames[i]) as img:
                # NOTE(review): bytes are written in the image's native mode,
                # so their length depends on the channel count; presumably
                # the reader expects one fixed mode -- confirm, and convert
                # here if the inputs can mix modes.
                img_raw = img.resize(image_size).tobytes()

            features = tf.train.Features(feature={
                # int() turns a NumPy integer into a plain Python int.
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
                "img_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            })
            # The Example wraps the label and image features of one record.
            example = tf.train.Example(features=features)

            writer.write(example.SerializeToString())
# Write every loaded image path and its label to the TFRecord file.
makeTFRec(images, labels)
相关标签: TensorFlow