
PyTorch: Splitting Datasets


Problem background:

The datasets in torchvision.datasets each come as a single, monolithic whole. In federated learning, however, different participants must receive different portions of the data, so splitting a torchvision dataset into per-participant shards is a key practical step in federated-learning experiments.

torch.utils.data.Subset

The torch package ships a Subset class (in torch.utils.data) for taking a slice of a torchvision dataset. Usage is as follows:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataloader as dataloader
from torch.utils.data import Subset

# Load the full MNIST training set (60000 samples).
train_set = torchvision.datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
# Carve out three disjoint 1000-sample subsets, one per participant.
train_set_A = Subset(train_set, range(0, 1000))
train_set_B = Subset(train_set, range(1000, 2000))
train_set_C = Subset(train_set, range(2000, 3000))
train_loader_A = dataloader.DataLoader(dataset=train_set_A, batch_size=1000, shuffle=False)
train_loader_B = dataloader.DataLoader(dataset=train_set_B, batch_size=1000, shuffle=False)
train_loader_C = dataloader.DataLoader(dataset=train_set_C, batch_size=1000, shuffle=False)
# A shared test set: the first 2000 samples of the MNIST test split.
test_set = torchvision.datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)
test_set = Subset(test_set, range(0, 2000))
test_loader = dataloader.DataLoader(dataset=test_set, shuffle=True)
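
The contiguous ranges above simply hand each participant the first few thousand samples in dataset order. If you want randomly drawn, disjoint shards instead, you can shuffle the indices with torch.randperm and build one Subset per participant. A minimal sketch (the helper name random_partition and the seed handling are my own additions, not part of the original post):

import torch
from torch.utils.data import Subset, DataLoader

def random_partition(dataset, n_workers=3, size=1000, seed=0):
    # Shuffle all sample indices once, then give each worker a disjoint slice.
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=g).tolist()
    return [Subset(dataset, perm[i * size:(i + 1) * size]) for i in range(n_workers)]

shards = random_partition(train_set, n_workers=3, size=1000)   # train_set as loaded above
loaders = [DataLoader(s, batch_size=100, shuffle=True) for s in shards]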

Manual splitting


import numpy as np

def iterate_minibatches(inputs, targets, batchsize=100, shuffle=False):
    # Yield (inputs, targets) mini-batches, optionally in a random order.
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)

    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)

        yield inputs[excerpt], targets[excerpt]

def inf_data(x, y, batchsize, shuffle=False):
    # Infinite generator: cycle over the data forever, one mini-batch at a time.
    while True:
        for b in iterate_minibatches(x, y, batchsize=batchsize, shuffle=shuffle):
            yield b
            
            
def split_data(dataset, n_workers=3, size=1000):
    X = dataset.data
    # Reshape from 60000*28*28 to 60000*1*28*28 (add the channel dimension).
    X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])
    # dataset.data holds raw uint8 pixels; scale to [0, 1] floats by hand,
    # since accessing .data bypasses the ToTensor() transform.
    X = X.float() / 255
    y = dataset.targets

    # Shuffle the sample indices so each worker gets a random shard.
    index = np.arange(X.shape[0])
    np.random.shuffle(index)

    splitted_X = []
    splitted_y = []
    for i in range(n_workers):
        xx = X[index[size * i: size * (i + 1)]]
        yy = y[index[size * i: size * (i + 1)]]
        splitted_X.append(xx)
        splitted_y.append(yy)

    return splitted_X, splitted_y
    
n_workers = 3       # number of participants (assumed; not defined in the original snippet)
batchsize = 100     # mini-batch size (assumed; not defined in the original snippet)
# Split the MNIST training set loaded above (train_set) into one 10000-sample shard per worker.
splitted_X, splitted_y = split_data(train_set, n_workers, 10000)
dataloaders = []
for i in range(n_workers):
    dataloaders.append(inf_data(splitted_X[i], splitted_y[i], batchsize))
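
Each entry in dataloaders is now an infinite generator, so a training loop can simply call next() on it whenever a worker needs a batch. A quick illustration (mine, not from the original post):

inputs, targets = next(dataloaders[0])   # one 100-sample batch for worker 0
print(inputs.shape, targets.shape)       # expected: torch.Size([100, 1, 28, 28]) torch.Size([100])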

Instead of hand-writing a batch generator, the split tensors can also be wrapped with torch.utils.data.TensorDataset and turned directly into a DataLoader.
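
A minimal sketch of that alternative, reusing splitted_X, splitted_y and n_workers from above (the list name worker_loaders is my own):

from torch.utils.data import TensorDataset, DataLoader

worker_loaders = []
for i in range(n_workers):
    # Wrap each worker's tensors in a TensorDataset; DataLoader then handles batching and shuffling.
    ds = TensorDataset(splitted_X[i], splitted_y[i])
    worker_loaders.append(DataLoader(ds, batch_size=100, shuffle=True))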

Reference:
https://blog.csdn.net/qq_39220308/article/details/107899927