欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

程序员文章站 2022-07-14 15:42:32
...

前言:在深度学习中,数据的预处理是第一步,pytorch提供了非常规范的处理接口,本文将针对处理过程中的一些问题来进行说明,本文所针对的主要数据是图像数据集。

本文的案例来源于车道线语义分割,采用的数据集是tusimple数据集,当然先需要将tusimple数据集写一个简单的脚本程序转换成指定的数据格式,如下:

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

一、基本概述

pytorch输入数据PipeLine一般遵循一个“三步走”的策略,一般pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。
② 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要事先里面的其他方法了。
③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练

注意这三个类均在torch.utils.data 中,这个模块中定义了下面几个功能,

from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
from .distributed import DistributedSampler
from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split
from .dataloader import DataLoader
 
# 可见,采样器sanpler,dataset,dataloader都是定义在里面的

pytorch数据处理pipeline 三步走的 一般格式如下:

dataset = MyDataset()           # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象
 
num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
    for img, label in dataloader:
        # 训练代码

二、Dataset类详解

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中Dataset类中的两个私有成员函数必须被重载,否则将会触发错误提示:

  • def __getitem__(self, index):
  • def __len__(self):
  • def __init__(self): 构造函数一般情况下我们也是要自己定义的,但是不是强制性的。

其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。这个Dataset抽象父类的定义如下:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

总结:Dataset的子类中除了上面的三个函数以外,当然还可以添加自己的一些处理函数,比如随机打乱,归一化等等,但是上面这三个一般情况下是必须要自己实现的。而且这三个函数的功能也有所侧重,一般情况下:

(1)__init__(self): 主要是数据的获取,比如从某个文件中获取

(2)__len__(self): 整个数据集的长度

(3)__getitem__(self,index): 这个是最重要的,一般情况下它会包含以下几个业务需要处理,

  • 第一,比如如果我们需要在读取数据的同时对图像进行增强的话,当然,图像增强的方法可以使用Pytorch内置的图像增强方式,也可以使用自定义或者其他的图像增强库这个很灵活。
  • 第二,在Pytorch中得到的图像必须是tensor,也就是说我们必须要将数据格式转化成pytorch的tensor格式才行。

2.1 构造函数__init__()

# coding: utf-8
 
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
 
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms
 
import random
 
class LaneDataSet(Dataset):
    def __init__(self, dataset, transform):
        '''
        param:
            detaset: 实际上就是tusimple数据集的三个文本文件train.txt、val.txt、test.txt三者的文件路径
            transform: 决定是否进行变换,它其实是一个函数或者是几个函数的组合
        构造三个列表,存储每一张图片的文件路径          
        '''
        self._gt_img_list = []
        self._gt_label_binary_list = []
        self._gt_label_instance_list = []
        self.transform = transform
 
        with open(dataset, 'r') as file:  # 打开其实是那个 training下面的那个train.txt 文件
            for _info in file:
                info_tmp = _info.strip(' ').split()
 
                self._gt_img_list.append(info_tmp[0])
                self._gt_label_binary_list.append(info_tmp[1])
                self._gt_label_instance_list.append(info_tmp[2])
 
        assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)
 
        self._shuffle()

此构造函数主要功能是实现将tusimple的数据集的gt_image、binary_image、instance_image的路径分别存储在三个列表中,并且随机打乱。

这里有一个_shuffle()函数,如下:

def _shuffle(self):
    # 将gt_image、binary_image、instance_image三者所对应的图片路径组合起来,再进行随机打乱
    c = list(zip(self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list))
    random.shuffle(c)
    self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list = zip(*c)

2.2 必须要实现的__len__()函数

def __len__(self):
    return len(self._gt_img_list)

其实就是返回样本的数量,

2.3 必须要实现的__getitem__()函数

def __getitem__(self, idx):
    assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \
               == len(self._gt_img_list)
 
    # 读取所有图片
    img = cv2.imread(self._gt_img_list[idx], cv2.IMREAD_COLOR) #真实图片 (720,1280,3)
 
    label_instance_img = cv2.imread(self._gt_label_instance_list[idx], cv2.IMREAD_UNCHANGED) # instance图片 (720,1280)
 
    label_binary_img = cv2.imread(self._gt_label_binary_list[idx], cv2.IMREAD_GRAYSCALE) #binary图片 (720,1280)
 
    # optional transformations,裁剪成(256,512)
    if self.transform:
        img = self.transform(img)
        label_binary_img = self.transform(label_binary_img)
        label_instance_img = self.transform(label_instance_img)
 
    img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) #(3,720,1280) 这里都没有问题
    return (img, label_binary_img, label_instance_img)

本例没有在__getitem__实现了使用transform来对样本数据进行处理,但是还没有转化成tensor,返回的是numpy数组。后面在处理也是一样的。

三、DataLoader类详解(_DataLoaderIter类

DataLoader的几种访问方式:

(1)dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问,由于它本身就是一个可迭代对象,可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

(2)先使用iter对dataloader进行第一步包装,使用iter(dataloader)返回的是一个迭代器,然后就可以可以使用next访问了。

先来看一下DataLoader 的定义,如下:

class DataLoader(object):
    __initialized = False
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
    def __setattr__(self, attr, val):
    def __iter__(self):
    def __len__(self):

注意:

(1)我们一般不需要再自己去实现DataLoader的方法了,只需要在构造函数中指定相应的参数即可,比如常见的batch_size,shuffle等等参数。所以使用DataLoader十分简洁方便。既然都是通过指定构造函数的参数实现,这里重点介绍一下构造函数中参数的含义。

(2)DataLoader实际上一个较为高层的封装类,它的功能都是通过更底层的_DataLoader来完成的,但是_DataLoader类较为低层,这里就不再展开叙述了。DataLoaderIter就是_DataLoaderIter的一个框架, 用来传给_DataLoaderIter 一堆参数, 并把自己装进DataLoaderIter 里。

3.1 DataLoader的构造函数参数

 
class DataLoader(object):
 
    Arguments:
        dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.
        batch_size (int, optional): 每一个batch加载多少组样本,即指定batch_size,默认是 1 
        shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False
------------------------------------------------------------------------------------
        sampler (Sampler, optional): 自定义从数据集中抽取样本的策略,如果指定这个参数,那么shuffle必须为False
        batch_sampler (Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥)
------------------------------------------------------------------------------------
        num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
        collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数(这个还不是很懂)
        pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
------------------------------------------------------------------------------------
        drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了,如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
------------------------------------------------------------------------------------
        timeout (numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
 
        worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``)
 
    注意事项note: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraies
              may be duplicated upon initializing workers (w.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.
 
    警告warning: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an unpicklable object, e.g., a lambda function.

其实用的较多的,就是dataset,batch_size,shuffle这三个参数。

3.2 参数解析之——batch_size和shuffle参数

看下面的简单应用:

import time
import os
import sys
 
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
import torch
from torch import cuda
from torch.utils.data import DataLoader
from torchvision import transforms
 
# 这是自己项目里面的模块
from data_loader.data_loaders import LaneDataSet
from data_loader.transformers import Rescale
from lanenet.lanenet import LaneNet
from lanenet.Model import ESPNet,compute_loss  # 导入ESPNet
 
def train(train_loader):
    t=enumerate(iter(train_loader)) # 这里使用iter对dataloader进行了包装
    for batch_idx, batch in t:
        
        # 注意 ,这三个数据都是 FloatTensor
        image_data = batch[0].type(torch.FloatTensor).to(DEVICE)      # (8,3,256,512) 
        binary_label = batch[1].type(torch.FloatTensor).to(DEVICE)    # [8,256,512]  ,只有 0,255 这两个值
        instance_label = batch[2].type(torch.FloatTensor).to(DEVICE)  # (8,256,512)  ,只有 0,20,70,120,170 每根车道线的值
  
        # 查看每一个batch里面的第一张样本和所对应的标签
        binary_label=binary_label.detach().cpu().numpy()
        instance_label=instance_label.detach().cpu().numpy()
        image_data=image_data.detach().cpu().type(torch.IntTensor).numpy()
 
        image_data = image_data.reshape(image_data.shape[0],image_data.shape[2], image_data.shape[3], image_data.shape[1]) #(8,256,512,3) 
 
        plt.figure('image_data')
        plt.imshow(image_data[0][:,:,::-1])  #(256,512,3)
 
        plt.figure('binary_image')
        plt.imshow(binary_label[0], cmap='gray')  #(256,512)
 
        plt.figure('instance_image')
        plt.imshow(instance_label[0], cmap='gray')  #(256,512)
 
        plt.show()
        print("--------------------------------------------")
 
def main():
    train_dataset_file = 'H:/tusimple_dataset/training/train.txt'
 
    # 第一步: 构造dataset 对象
    train_dataset = LaneDataSet(train_dataset_file, transform=transforms.Compose([Rescale((512,256))]))
 
    # 第二步: 构造dataloader 对象
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
 
    # 第三步:迭代dataloader,进行训练
    train(train_loader)  
 
if __name__ == '__main__':
    main()

运行结果如下:

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

3.3 参数解析之——sampler参数

这个参数其实就是一个“采样器”,表示从样本中究竟如何取样,pytorch中采样器有如下几个:

class Sampler(object):
  
class SequentialSampler(Sampler):
 
class RandomSampler(Sampler):
 
class SubsetRandomSampler(Sampler):
 
class WeightedRandomSampler(Sampler):
 
class BatchSampler(Sampler):

注意:Sampler类是所有的采样器的基类,每一个继承自Sampler的子类都必须实现它的__iter__方法和__len__方法。前者实现如何迭代样本,后者实现一共有多少个样本。

其实DataLoader里面在构造函数中就定义了采样器——如何采样,__init__中的部分代码如下所示:

if batch_sampler is None: # 没有手动传入batch_sampler参数时
    if sampler is None:   # 没有手动传入sampler参数时
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
 
self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True

3.4 参数解析之——collate_fn(这个参数往往是出现错误的根源所在

DataLoader能够为我们自动生成一个多线程的迭代器,只要传入几个参数进行就可以了,第一个参数就是上面定义的数据集,后面几个参数就是batch size的大小,是否打乱数据,读取数据的线程数目等等,这样一来,我们就建立了一个多线程的I/O。

读到这里,你可能觉得PyTorch真的太方便了,真的是简单实用,但是在使用的过程中很有可能性就报错了,而且你也是一步一步按着实现来的,怎么就报错了呢?

不用着急,下面就来讲一下为什么会报错,以及这一块pyhon实现的解读,这样你就能够真正知道如何进行自定义的数据读入。

(1)问题来源

通过上面的实现,可能会遇到各种不同的问题,Dataset非常简单,一般都不会有错,只要Dataset实现正确,那么问题的来源只有一个,那就是torch.utils.data.DataLoader中的一个参数collate_fn,这里我们需要找到DataLoader的源码进行查看这个参数到底是什么。

可以看到collate_fn默认是等于default_collate,那么这个函数的定义如下。

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
 
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
 
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
 
    raise TypeError((error_msg.format(type(batch[0]))))

这是他的定义,但是到目前为止,我们似乎并不知道这个函数的作用,它的输入参数是batch,这是什么意思也不知道,我们找到这个函数的调用部分,看一看究竟给这个函数传递进去的是什么。

前面说到了,DataLoader实际上是通过_DataLoaderIter来实现的,进入_DataLoaderIter,找到函数的调用如下:

def __next__(self):
    if self.num_workers == 0:  # same-process loading
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])  # 在这里调用了collate_fn函数,传递的参数是一个列表
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch

由上面可以发现,default_collate(batch)中的参数就是这里的  [self.dataset[i] for i in indices] 。从这里看这就是一个list,list中的每个元素就是self.data[i],如果你在往上看,可以看到这个self.data就是我们需要预先定义的Dataset,那么这里self.data[i]就等价于我们在Dataset里面定义的__getitem__这个函数。

所以我们知道了collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。

这时我们再去看看collate_fn这个函数,其实可以看到非常简单,就是通过对一些情况的排除,然后最后输出结果,比如第一个if,如果我们的输入是一个tensor,那么最后会将一个batch size的tensor重新stack在一起,比如输入的tensor是一张图片,3x30x30,如果batch size是32,那么按第一维stack之后的结果就是32x3x30x30,这里stack和concat有一点区别就是会增加一维。

所以通过上面的源码解读我们知道了数据读入具体是如何操作的,那么我们就能够实现自定义的数据读入了,我们需要自己按需要重新定义collate_fn这个函数,下面举个例子。

(2)collate_fn的案例一

下面我们来举一个麻烦的例子,比如做文本识别,需要将一张图片上的字符识别出来,比如下面这些图片

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

那么这个问题的输入就是一张一张的图片,他的label就是一串字符,但是由于长度是变化的,所以这个问题比较麻烦。

下面我们就来简单实现一下。

我们有一个train.txt的文件,上面有图片的名称和对应的label,首先我们需要定义一个Dataset。

class custom_dset(Dataset):
    def __init__(self,
                 img_path,
                 txt_path,
                 img_transform=None,
                 loader=default_loader):
        with open(txt_path, 'r') as f:
            lines = f.readlines()
            self.img_list = [
                os.path.join(img_path, i.split()[0]) for i in lines
            ]
            self.label_list = [i.split()[1] for i in lines]
        self.img_transform = img_transform
        self.loader = loader
 
    def __getitem__(self, index):
        img_path = self.img_list[index]
        label = self.label_list[index]
 
        img = img_path
        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, label
 
    def __len__(self):
        return len(self.label_list)

这里非常简单,就是将txt文件打开,然后分别读取图片名和label,由于存放图片的文件夹我并没有放上去,因为数据太大,所以读取图片以及对图片做一些变换的操作就不进行了。

接着我们自定义一个collate_fn,这里可以使用任何名字,只要在DataLoader里面传入就可以了。

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    img, label = zip(*batch)
    pad_label = []
    lens = []
    max_len = len(label[0])
    for i in range(len(label)):
        temp_label = [0] * max_len
        temp_label[:len(label[i])] = label[i]
        pad_label.append(temp_label)
        lens.append(len(label[i]))
    return img, pad_label, lens

(3)collate_fn的案例二

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在_ getitem _函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

from torch.utils.data.dataloader import default_collate  # 导入这个函数
def collate_fn(batch):
    '''
    batch 实际上是一个列表,列表的长度就是一个batch_size,列表的每一个元素形如(data, label),
          这实际上是定义DataSet的时候,每一个__getitem__得到的元素
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: 
        return torch.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据,这里的defaut_collate就是pytorch默认给collate_fn传递的函数,需要导入才能使用
# 第一步:定义dataset
dataset = NewDogCat(root='data/dogcat_wrong/', transform=transform)
 
# 第二步:定义dataloader,需要注意的是,这里的collate_fn是我自己定义的啊
dataloader = DataLoader(dataset, 2, collate_fn=collate_fn, num_workers=1,shuffle=True)
 
# 第三步:迭代dataloader
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())

总结:什么时候该使用DataLoader的collate_fn这个参数?

         当定义DataSet类中的__getitem__函数的时候,由于每次返回的是一组类似于(x,y)的样本,但是如果在返回的每一组样本x,y中出现什么错误,或者是还需要进一步对x,y进行一些处理的时候,我们就需要再定义一个collate_fn函数来实现这些功能。当然我也可以自己在实现__getitem__的时候就实现这些后处理也是可以的。

      collate_fn,中单词collate的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些。

后面有一篇专门讲解collate_fn的文章,请参考:

(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

相关标签: pytorch数据预处理