pytorch中的dataloader批次数值取出使用
程序员文章站
2024-03-30 20:34:09
pytorch当中的dataloader可以实现相应的取出对应的dataloader的数值并进行使用,对应的定义如下# 实现Dataloaderclass Dataset(tud.Dataset): # 继承tud.Dataset父类 def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts): super(Dataset, self).__init__()...
pytorch当中的dataloader可以实现相应的取出对应的dataloader的数值并进行使用,对应的定义如下
# 实现Dataloader
class Dataset(tud.Dataset): # 继承tud.Dataset父类
def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
super(Dataset, self).__init__()
......
def __len__(self):
......
return len(self.text_encoded) #所有单词的总数
def __getitem__(self, idx):
......
return center_word, pos_words, neg_words
dataset = Dataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
注意
def __len__(self):
函数相当于定义对应的dataloader当中可以取出数值的总长度
然后使用后续相应的enumerate进行调用
for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
遍历相应的dataloader中的对应内容,后面的三个相应的参数(input_labels,pos_labels,neg_labels)为dataloader之中取出的相应的内容
本文地址:https://blog.csdn.net/znevegiveup1/article/details/110671853