pytorch有关 Dataset和 DataLoader的心得
程序员文章站
2022-03-15 15:08:55
先来看看官方文档的说法:https://pytorch.org/docs/stable/data.htmlDataLoader支持两种数据集:map-style datasets 和 iterable-style datasets.一般我们用的最多的是map-style datasets,因此这里只讲map类型的,还有我到目前为止也没用过iterable类型的。(无知导致无能,很抱歉,这部分我不知道~)我们要使用map-style datasets,要实现两种方法__getitem__()和__...
先来看看官方文档的说法:https://pytorch.org/docs/stable/data.html
DataLoader支持两种数据集:map-style datasets 和 iterable-style datasets.
一般我们用的最多的是map-style datasets,因此这里只讲map类型的,还有我到目前为止也没用过iterable类型的。(无知导致无能,很抱歉,这部分我不知道~)
我们要使用map-style datasets,要实现两种方法__getitem__()和__len__(),这里我拿出我最近写的一个demo
class myDataset(Dataset): def __init__(self, data, label): self.data_list = data self.label_list = label def __getitem__(self, index): data_idx = [] data_idx.append(word2idx.transform(self.data_list[index].split(), max_len=max_len)) text = torch.LongTensor(data_idx) label = torch.LongTensor([self.label_list[index]]) return text, label def __len__(self): return len(self.data_list)
dataset = myDataset(data_list, label_list) data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
注意点:
1、__getitem__(self,index)里面每次的返回值,是一对数据,即文本和标签,我们通过参数index来确定返回哪个数据。
(我之前是返回的一批数据直接导致内存爆了,真的是无知者无畏啊,给大家看看爆的多少,一百多G的GPU)
本文地址:https://blog.csdn.net/qq_40819945/article/details/109622591