欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

pytorch中的dataloader批次数值取出使用

程序员文章站 2024-03-30 20:34:09
pytorch当中的dataloader可以实现相应的取出对应的dataloader的数值并进行使用,对应的定义如下# 实现Dataloaderclass Dataset(tud.Dataset): # 继承tud.Dataset父类 def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts): super(Dataset, self).__init__()...

pytorch当中的dataloader可以实现相应的取出对应的dataloader的数值并进行使用,对应的定义如下

# 实现Dataloader
class Dataset(tud.Dataset): # 继承tud.Dataset父类
    
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):    
        super(Dataset, self).__init__() 
        ......
        
    def __len__(self): 
    	......
        return len(self.text_encoded) #所有单词的总数
        
    def __getitem__(self, idx):
        ......
        return center_word, pos_words, neg_words 


dataset = Dataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)  

注意

def __len__(self):

函数相当于定义对应的dataloader当中可以取出数值的总长度
然后使用后续相应的enumerate进行调用

for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):

遍历相应的dataloader中的对应内容,后面的三个相应的参数(input_labels,pos_labels,neg_labels)为dataloader之中取出的相应的内容

本文地址:https://blog.csdn.net/znevegiveup1/article/details/110671853

相关标签: pytorch笔记