欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow读取数据时报错:cannot create a tensor proto whose content is larger than 2GB

程序员文章站 2022-03-17 14:51:39
...

一、背景

我在做实验时,利用vangogh2photo数据集[1]实现图像风格转换时,由于数据集过大,其中trainB文件夹中包含由6287张图片,trainA包含整整400张图片,每张图片的大小均为256*256,如果直接将数据集全部读取到内存中,就会报错,报错的结果如下图所示:

[1]下载地址:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

tensorflow读取数据时报错:cannot create a tensor proto whose content is larger than 2GB

出现的报错为:cannot create a tensor proto whose content is larger than 2GB,且电脑会卡死。。。。。

二、问题原因

一次读入太多图像导致系统的内存溢出,难以执行后续的运算,我所采用的读取数据的示例代码为:

import tensorflow as tf
import numpy as np
import os
from PIL import Image


image_height = 256
image_width = 256
image_channel = 3
image_size = image_height * image_width

data_dir = "./vangogh2photo"

def main(input_dir, floderA, floderB):

    imagesA = os.listdir(input_dir + floderA)
    imagesB = os.listdir(input_dir + floderB)
    imageA_len = len(imagesA)
    imageB_len = len(imagesB)

    dataA = np.empty((imageA_len, image_width, image_height, image_channel), dtype="float32")
    dataB = np.empty((imageB_len, image_width, image_height, image_channel), dtype="float32")

    for i in range(imageA_len):
        img = Image.open(input_dir + floderA + "/" + imagesA[i])
        img = img.resize((image_width, image_height))
        arr = np.asarray(img, dtype="float32")
        dataA[i, :, :, :] = arr * 1.0 / 127.5 - 1.0

    for i in range(imageB_len):
        img = Image.open(input_dir + floderB + "/" + imagesB[i])
        img = img.resize((image_width, image_height))
        arr = np.asarray(img, dtype="float32")
        dataB[i, :, :, :] = arr * 1.0 / 127.5 - 1.0


    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        dataA = tf.reshape(dataA, [-1, image_width, image_height, image_channel])
        dataB = tf.reshape(dataB, [-1, image_width, image_height, image_channel])

        train_set_A = sess.run(dataA)
        train_set_B = sess.run(dataB)

    return train_set_A, train_set_B


if __name__ =="__main__":
    a, b = main(data_dir, "/trainA", "/trainB")
    print(len(b))

我的电脑的内存为8G,显卡为GTX 1060  3G。只要运行上述代码必会报错。

三、解决办法

网上关于这个问题的讨论比较少,目前能想到的解决方法有三个:

1.在对实验没有大的影响的情况下,适当减少数据集的大小。

2.对代码进行修改,只读取每张图像的路径,然后再feed_dict中再单独读取batch_size个图像,读取每张图像路径的函数可以用os.listdir,在实验的过程中可以考虑这么使用:

from PIL import Image
import tensorflow as tf
import os


# 读取图像的路径位置
def read_data_path(data_dir):

    images = os.listdir(data_dir)
    image_len = len(images)

    return images, image_len


# ############################################################# #
# 使用时,只需提前创建存储变量data,例如                           #
# data = np.empty((image_len, 256, 256, 3), dtype="float32")    #
# ############################################################# #


# 读取图像, images为输入数据,将其读到data当中的第i层
def load_data(data_dir, images, data, i):

    img = Image.open(data_dir + "/" + images)
    arr = np.asarray(img, dtype="float32")
    data[i, :, :, :] = arr * 1.0 / 127.5 - 1.0

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        data = sess.run(tf.reshape(data, [-1, 256, 256, 3]))

    return data
   

3.将输入数据分批制作成.tfrecords格式进行后续实验,可参考:https://blog.csdn.net/Liangjun_Feng/article/details/79698809

 

相关标签: tensorflow