Dataset Preprocessing
mini-ImageNet
First, set up the dataset: images, train.csv, test.csv, val.csv. The directory layout is as follows:
miniimagenet/
├── images/
│   ├── n0210891500001298.jpg
│   ├── n0287152500001298.jpg
│   └── ...
├── test.csv
├── val.csv
├── train.csv
└── proc_images.py
The images folder contains 60000 images, which are first resized to 84*84; each csv file lists image filenames and their corresponding labels.
The images are split into a training, a test, and a validation set according to the labels in the csv files. In the training set, each folder corresponds to one class, and the folder name is the label.
For example: 'images/n0153282900000005.jpg' -> 'train/n01532829/n0153282900000005.jpg'
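For reference, each csv file has a header row followed by filename,label pairs (the script below skips the header). A rough illustration of what train.csv looks like (only the first filename is taken from this article, the rest of the rows depend on your copy of the split files):

filename,label
n0153282900000005.jpg,n01532829
...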
proc_images.py
Given the images folder and the train.csv, test.csv, and val.csv files, the script moves the images into training, test, and validation sets according to the csv files, automatically creating the three folders (train, test, val).
'''
Windows version
'''
from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')  # collect all image paths with glob

# Resize every image to 84*84
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Read the csv files, split the images into three sets and
# create the corresponding directories (train, val, test)
for datatype in ['train', 'val', 'test']:
    os.mkdir(datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the header row
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.mkdir(cur_dir)
                last_label = label
            os.rename('images/' + image_name, cur_dir + image_name)
'''
Linux version
'''
from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize images
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Put in correct directory
for datatype in ['train', 'val', 'test']:
    os.system('mkdir ' + datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the headers
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.system('mkdir ' + cur_dir)
                last_label = label
            os.system('mv images/' + image_name + ' ' + cur_dir)
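After running either script, a quick way to verify the result is to count the class folders and moved images in each split. A minimal sketch (not part of the original script; it assumes you run it from the miniimagenet/ directory after the split has been created):

import glob
import os

# count class sub-folders and images in each split
for split in ['train', 'val', 'test']:
    classes = os.listdir(split)
    images = glob.glob(os.path.join(split, '*', '*.jpg'))
    print('%s: %d classes, %d images' % (split, len(classes), len(images)))

# the three splits together should account for all 60000 images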
'''
PyTorch version
'''
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import numpy as np
import collections
from PIL import Image
import csv
import random


class MiniImagenet(Dataset):
    """
    put mini-imagenet files as:
    root:
        |- images/*.jpg  includes all images
        |- train.csv
        |- test.csv
        |- val.csv
    NOTE: meta-learning differs from ordinary supervised learning,
    especially in the notions of batch and set.
    batch: contains several sets
    set: n_way * k_shot images form the meta-train (support) set,
         n_way * k_query images form the meta-test (query) set.
    """

    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        """
        :param root: root path of mini-imagenet
        :param mode: train, val or test
        :param batchsz: batch size of sets, not batch of imgs
        :param n_way: number of classes per set
        :param k_shot: number of support imgs per class
        :param k_query: number of query imgs per class
        :param resize: resize to
        :param startidx: start indexing labels from startidx
        """
        self.batchsz = batchsz  # batch of sets, not batch of imgs
        self.n_way = n_way  # n-way
        self.k_shot = k_shot  # k-shot
        self.k_query = k_query  # for evaluation
        self.setsz = self.n_way * self.k_shot  # num of samples per set
        self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
        self.resize = resize  # resize to
        self.startidx = startidx  # index label not from 0, but from startidx
        print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
            mode, batchsz, n_way, k_shot, k_query, resize))

        if mode == 'train':
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 # transforms.RandomHorizontalFlip(),
                                                 # transforms.RandomRotation(5),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])
        else:
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])

        self.path = os.path.join(root, 'images')  # image path
        csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
            self.img2label[k] = i + self.startidx  # {"img_name[:9]": label}
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)
    def loadCSV(self, csvf):
        """
        return a dict saving the information of the csv
        :param csvf: csv file name
        :return: {label: [file1, file2, ...]}
        """
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip the header row (filename, label)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                # append filename to current label
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels
    def create_batch(self, batchsz):
        """
        create batches for meta-learning.
        an 'episode' here is one set; batchsz is how many sets we want to keep.
        :param batchsz: batch size (number of episodes)
        :return:
        """
        self.support_x_batch = []  # support set batch
        self.query_x_batch = []  # query set batch
        for b in range(batchsz):  # for each batch
            # 1. select n_way classes randomly
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicate
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                # 2. select k_shot + k_query images for each class
                selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
                indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
                support_x.append(
                    np.array(self.data[cls])[indexDtrain].tolist())  # get all image filenames for current Dtrain
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            # shuffle the corresponding relation between support set and query set
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)  # append set to current sets
            self.query_x_batch.append(query_x)  # append sets to current sets
    def __getitem__(self, index):
        """
        index means index of sets, 0 <= index <= batchsz - 1
        :param index:
        :return:
        """
        # [setsz, 3, resize, resize]
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        # [setsz]
        support_y = np.zeros((self.setsz), dtype=int)
        # [querysz, 3, resize, resize]
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        # [querysz]
        query_y = np.zeros((self.querysz), dtype=int)

        flatten_support_x = [os.path.join(self.path, item)
                             for sublist in self.support_x_batch[index] for item in sublist]
        support_y = np.array(
            [self.img2label[item[:9]]  # filename: n0153282900000005.jpg, the first 9 characters are the label
             for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)

        flatten_query_x = [os.path.join(self.path, item)
                           for sublist in self.query_x_batch[index] for item in sublist]
        query_y = np.array([self.img2label[item[:9]]
                            for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)

        # print('global:', support_y, query_y)
        # support_y: [setsz]
        # query_y: [querysz]
        # unique: [n-way], sorted
        unique = np.unique(support_y)
        random.shuffle(unique)
        # relative means the label ranges from 0 to n_way - 1
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx
        # print('relative:', support_y_relative, query_y_relative)

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)
        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)
        # print(support_set_y)
        # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)

        return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)
    def __len__(self):
        # we have built batchsz sets in total, so the dataset length is batchsz;
        # a DataLoader can then sample smaller batches of sets from it
        return self.batchsz
if __name__ == '__main__':
    # the block below visualizes one batch of images with TensorBoard
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt
    from tensorboardX import SummaryWriter
    import time

    plt.ion()  # turn on interactive mode
    tb = SummaryWriter('runs', 'mini-imagenet')  # log writer

    # build the episode dataset
    mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168)

    for i, set_ in enumerate(mini):
        # support_x: [n_way * k_shot, 3, resize, resize]
        support_x, support_y, query_x, query_y = set_

        support_x = make_grid(support_x, nrow=2)
        query_x = make_grid(query_x, nrow=2)

        plt.figure(1)
        plt.imshow(support_x.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
        plt.pause(0.5)
        plt.figure(2)
        plt.imshow(query_x.permute(1, 2, 0).numpy())
        plt.pause(0.5)

        tb.add_image('support_x', support_x)
        tb.add_image('query_x', query_x)

        time.sleep(5)

    tb.close()
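Since each item returned by MiniImagenet is already one episode (a support set of n_way * k_shot images plus a query set of n_way * k_query images), it can also be wrapped in an ordinary DataLoader to draw several episodes per step for meta-training. A minimal sketch, not part of the original code: the values k_query=15, batchsz=10000, resize=84 and batch_size=4 are only illustrative, and the loop body is a placeholder for your meta-learner.

from torch.utils.data import DataLoader

mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1, k_query=15,
                    batchsz=10000, resize=84)
# each item of `mini` is one episode; the DataLoader stacks several episodes per step
db = DataLoader(mini, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

for support_x, support_y, query_x, query_y in db:
    # support_x: [4, n_way * k_shot, 3, 84, 84], support_y: [4, n_way * k_shot]
    # query_x:   [4, n_way * k_query, 3, 84, 84], query_y: [4, n_way * k_query]
    pass  # feed the episodes to the meta-learner here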
Omniglot