pytorch-划分数据集
程序员文章站
2022-07-14 13:32:19
...
问题来源:
torchvision.datasets里的数据集都是完整的一份。在联邦学习中,我们需要把数据集划分给不同的参与者,因此如何划分torchvision中的数据集就成为联邦学习实验中的一个重要问题。
torch.utils.data.Subset
torch包有划分torchvision数据集的函数Subset,用法如下:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataloader as dataloader
from torch.utils.data import Subset
# Download MNIST once, then carve out three disjoint shards
# (one per federated participant) with torch.utils.data.Subset.
mnist_transform = transforms.ToTensor()
train_set = torchvision.datasets.MNIST(root="./data", train=True, transform=mnist_transform, download=True)

shard = 1000  # number of samples handed to each participant
train_set_A = Subset(train_set, range(0 * shard, 1 * shard))
train_set_B = Subset(train_set, range(1 * shard, 2 * shard))
train_set_C = Subset(train_set, range(2 * shard, 3 * shard))

# One loader per participant; batch_size equals the shard size, so each
# loader yields its whole shard in a single batch.
loader_kwargs = dict(batch_size=shard, shuffle=False)
train_loader_A = dataloader.DataLoader(dataset=train_set_A, **loader_kwargs)
train_loader_B = dataloader.DataLoader(dataset=train_set_B, **loader_kwargs)
train_loader_C = dataloader.DataLoader(dataset=train_set_C, **loader_kwargs)

# Shared evaluation split: the first 2000 MNIST test images.
test_set = torchvision.datasets.MNIST(root="./data", train=False, transform=mnist_transform, download=True)
test_set = Subset(test_set, range(0, 2 * shard))
test_loader = dataloader.DataLoader(dataset=test_set, shuffle=True)
手动划分
def iterate_minibatches(inputs, targets, batchsize=100, shuffle=False):
    """Yield successive ``(inputs, targets)`` minibatches of size `batchsize`.

    A trailing partial batch (fewer than `batchsize` samples) is dropped.
    When `shuffle` is True the sample order is randomised with numpy, and
    batches are gathered by fancy indexing instead of slicing.
    """
    assert len(inputs) == len(targets)
    n = len(inputs)
    order = None
    if shuffle:
        order = np.arange(n)
        np.random.shuffle(order)
    for lo in range(0, n - batchsize + 1, batchsize):
        # Shuffled mode picks scattered indices; ordered mode uses a cheap slice.
        pick = order[lo:lo + batchsize] if shuffle else slice(lo, lo + batchsize)
        yield inputs[pick], targets[pick]
def inf_data(x, y, batchsize, shuffle=False):
    """Endlessly cycle over ``(x, y)`` in minibatches of size `batchsize`.

    Restarts `iterate_minibatches` from the beginning whenever it is
    exhausted, so the generator never terminates.

    Bug fix: the original passed a hard-coded ``batchsize=100`` to
    `iterate_minibatches`, silently ignoring the caller's `batchsize`.
    """
    while True:
        for batch in iterate_minibatches(x, y, batchsize=batchsize, shuffle=shuffle):
            yield batch
def split_data(dataset, n_workers=3, size=1000):
    """Randomly partition `dataset` into `n_workers` disjoint shards.

    Parameters
    ----------
    dataset : object exposing ``.data`` (an N*H*W tensor, e.g. MNIST's
        uint8 images) and ``.targets`` (length-N label tensor).
    n_workers : int, number of shards to produce.
    size : int, samples per shard (``n_workers * size`` must not exceed N).

    Returns
    -------
    (splitted_X, splitted_y) : two lists of length `n_workers`, holding
    tensors of shape ``(size, 1, H, W)`` and ``(size,)`` respectively.
    """
    X = dataset.data
    # Insert a channel dimension: N*H*W -> N*1*H*W.
    # Bug fix: use reshape rather than the deprecated Tensor.resize.
    X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])
    # Normalize pixel values to [0, 1].
    # Bug fix: the original divided by 225 — a typo for 255.
    X = X / 255
    y = dataset.targets
    # Shuffle once so each shard is a random (non-overlapping) sample.
    index = np.arange(X.shape[0])
    np.random.shuffle(index)
    splitted_X = []
    splitted_y = []
    for i in range(n_workers):
        sel = index[size * i: size * (i + 1)]
        splitted_X.append(X[sel])
        splitted_y.append(y[sel])
    return splitted_X, splitted_y
# Split the training set into one shard per worker, then wrap each shard
# in an infinite minibatch generator.
splitted_X, splitted_y = split_data(trainset, n_workers, 10000)
dataloaders = [
    inf_data(x_shard, y_shard, batchsize)
    for x_shard, y_shard in zip(splitted_X, splitted_y)
]
上面的方法也可以用torch.utils.data.TensorDataset
将分好的tensor类型的数据直接变成dataloader
参考:
https://blog.csdn.net/qq_39220308/article/details/107899927