欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow 读取图片文件

程序员文章站 2022-03-20 17:01:17
...

测试文件
链接:https://pan.baidu.com/s/11bfWr8tIk3mDlNkoQvODuQ
提取码:wco7

TensorFlow 读取图片文件 tf 1.0

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time

from tensorflow import keras

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)
def pic_read(file_list):
    """
    读取狗图片并转换成张量
    :param file_list :文件路径+名字的列表
    """
    # 1.构造文件队列
    file_queue = tf.train.string_input_producer(file_list)
    
    # 2、构造阅读器去读取图片内容(默认读取一张图片)
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)
    
    print(value)
    # 3.对读取的数据图片进行解码
    image = tf.image.decode_jpeg(value)
    print(image)
    
    # 5.处理图片的大小(统一大小)
    img_resize = tf.image.resize_images(image, [200, 200])
    
    print(img_resize)
    # 注意:一定要把样本的形状固定[200, 200, 3] 在批处理的时候所有数据形状必须定义
    img_resize.set_shape([200, 200, 3])
    print(img_resize)
    
    # 6.进行批处理
    image_batch = tf.train.batch([img_resize], batch_size=20, num_threads=1, capacity=20)
    
    print(image_batch)
    
    return image_batch
if __name__ == "__main__":
    # 找到文件,放入列表  路径+名字 ->列表当中
    file_names = os.listdir("./dog/")
    file_list = [os.path.join("./dog/", file) for file in file_names]
    print(file_list)
    image_batch = pic_read(file_list)
    
    # 开启会话运行结果
    with tf.Session() as sess:
        # 定义一个线程协调
        coord = tf.train.Coordinator()
        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess, coord)
        
        print(sess.run([image_batch]))
        # 回收子线程
        coord.request_stop()
        coord.join(threads)

TensorFlow 读取图片文件 tf 2.0

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)
# 设置文件列表
train_dir  = "./image/"               
print(os.path.exists(train_dir))
# 图片大小 高*宽*深度
height = 200    
width = 200
channels = 3
batch_size = 20

# 类型个数
num_classes = 2

# 图片读取方式
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale = 1./ 255,           # 数据归一化 
    rotation_range = 40,         # 随机旋转 0-40度
    width_shift_range = 0.2,     # 水平移动百分百  0-20%
    height_shift_range = 0.2,    # 垂直移动百分百  0-20%
    shear_range = 0.2,           # 剪切百分百
    zoom_range = 0.2,            # 缩放百分百
    horizontal_flip = True,      # 水平翻转
    fill_mode = "nearest",       # 空数据填坑方式:最近数据
)

# 读取图片文件夹
train_generator = train_datagen.flow_from_directory(
    train_dir,                     # 训练文件夹
    target_size = (height, width), # 图片目标尺寸
    batch_size = batch_size,       # batch 尺寸
    seed = 7,                      # 随机方式
    shuffle = True,                # 是否缓存数据
    class_mode = "categorical"     
)
# 显示读取数据情况
train_num = train_generator.samples
print(train_num)

for i in range(2):
    x, y = train_generator.next()
    print(x.shape, y.shape)
    print(y)
相关标签: TensorFlow