tensorflow读取数据时报错:cannot create a tensor proto whose content is larger than 2GB
程序员文章站
2022-03-17 14:51:39
...
一、背景
我在做实验时,利用vangogh2photo数据集[1]实现图像风格转换时,由于数据集过大,其中trainB文件夹中包含由6287张图片,trainA包含整整400张图片,每张图片的大小均为256*256,如果直接将数据集全部读取到内存中,就会报错,报错的结果如下图所示:
[1]下载地址:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
出现的报错为:cannot create a tensor proto whose content is larger than 2GB,且电脑会卡死。。。。。
二、问题原因
一次读入太多图像导致系统的内存溢出,难以执行后续的运算,我所采用的读取数据的示例代码为:
import tensorflow as tf
import numpy as np
import os
from PIL import Image
image_height = 256
image_width = 256
image_channel = 3
image_size = image_height * image_width
data_dir = "./vangogh2photo"
def main(input_dir, floderA, floderB):
imagesA = os.listdir(input_dir + floderA)
imagesB = os.listdir(input_dir + floderB)
imageA_len = len(imagesA)
imageB_len = len(imagesB)
dataA = np.empty((imageA_len, image_width, image_height, image_channel), dtype="float32")
dataB = np.empty((imageB_len, image_width, image_height, image_channel), dtype="float32")
for i in range(imageA_len):
img = Image.open(input_dir + floderA + "/" + imagesA[i])
img = img.resize((image_width, image_height))
arr = np.asarray(img, dtype="float32")
dataA[i, :, :, :] = arr * 1.0 / 127.5 - 1.0
for i in range(imageB_len):
img = Image.open(input_dir + floderB + "/" + imagesB[i])
img = img.resize((image_width, image_height))
arr = np.asarray(img, dtype="float32")
dataB[i, :, :, :] = arr * 1.0 / 127.5 - 1.0
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
dataA = tf.reshape(dataA, [-1, image_width, image_height, image_channel])
dataB = tf.reshape(dataB, [-1, image_width, image_height, image_channel])
train_set_A = sess.run(dataA)
train_set_B = sess.run(dataB)
return train_set_A, train_set_B
if __name__ =="__main__":
a, b = main(data_dir, "/trainA", "/trainB")
print(len(b))
我的电脑的内存为8G,显卡为GTX 1060 3G。只要运行上述代码必会报错。
三、解决办法
网上关于这个问题的讨论比较少,目前能想到的解决方法有三个:
1.在对实验没有大的影响的情况下,适当减少数据集的大小。
2.对代码进行修改,只读取每张图像的路径,然后再feed_dict中再单独读取batch_size个图像,读取每张图像路径的函数可以用os.listdir,在实验的过程中可以考虑这么使用:
from PIL import Image
import tensorflow as tf
import os
# 读取图像的路径位置
def read_data_path(data_dir):
images = os.listdir(data_dir)
image_len = len(images)
return images, image_len
# ############################################################# #
# 使用时,只需提前创建存储变量data,例如 #
# data = np.empty((image_len, 256, 256, 3), dtype="float32") #
# ############################################################# #
# 读取图像, images为输入数据,将其读到data当中的第i层
def load_data(data_dir, images, data, i):
img = Image.open(data_dir + "/" + images)
arr = np.asarray(img, dtype="float32")
data[i, :, :, :] = arr * 1.0 / 127.5 - 1.0
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
data = sess.run(tf.reshape(data, [-1, 256, 256, 3]))
return data
3.将输入数据分批制作成.tfrecords格式进行后续实验,可参考:https://blog.csdn.net/Liangjun_Feng/article/details/79698809
上一篇: js中!与!!的用法介绍
下一篇: js如何实现下拉控制列表