Dataset和DataLoader简介#

在使用Pytorch构建和训练模型的过程中,经常需要把原始数据(文本格式数据等)转换为张量的格式,这时pytorch的数据导入功能便能发挥作用。

Pytorch导入数据主要依靠 torch.utils.data.DataLoader和 torch.utils.data.Dataset这两个类来完成,并采用以下步骤完成最终的数据读取。

1 pytorch输入数据pipline“三步走”策略#

这里主要引用了这篇blog的内容:https://blog.csdn.net/loveliuzz/article/details/108756253

一般pytorch的数据加载到模型的操作顺序如下:

  1. 创建一个 Dataset 对象,必须实现 __len__() ,__getitem__() 两个方法;

  2. 创建一个 DataLoader 对象,对 Dataset 对象进行迭代的,一般不需要事先实现里面其他方法了;

  3. 循环遍历DataLoader对象,将样本数据加载到模型中训练。

注:Dataset、DataLoader 均在 torch.utils.data 中定义,使用它们的时候代码会有如下形式:

from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
 
dataset = MyDataset() # 第一步:构建 Dataset 对象
dataloader = DataLoader(dataset) # 第二步:通过Dataloader来构建迭代对象
 
num_epoches = 100
for epoch in range(num_epoches):
    for i, data in enumerate(dataloader):
        # 训练代码

数据集遍历一般化流程:

for i, data in enumerate(dataloader): 会调用dataloader 的 iter() 方法,产生了一个DataLoaderIter(迭代器),这里判断使用单线程还是多线程,调用 DataLoaderIter 的 next() 方法来得到 batch data 。在__next__() 方法中使用 __next_index()方法调用sampler(采样器)获得index索引,接着通过 Dataset_fetcher 的 fetch() 方法根据index(索引)调用dataset的 getitem() 方法,然后用 collate_fn 把它们打包成batch。当数据读完后, next() 抛出一个 StopIteration 异常,for循环结束,dataloader 失效。

2 Dataset#

torch.utils.data.Dataset是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  • getitem()

  • len()

第一个最为重要, 即每次怎么读数据,比如对于一个图片:

def __getitem__(self, index):
    img_path, label = self.data[index].img_path, self.data[index].label
    img = Image.open(img_path)

    return img, label

可以看到,img是模型训练的输入数据,label是模型训练的输出数据。

第二个比较简单, 就是返回整个数据集所有样本的个数,比如:

def __len__(self):
    return len(self.data)

3 Dataloader#

torch.utils.data.DataLoader类定义为:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)

可以看到, 主要参数有这么几个:

  • dataset : 即上面自定义的dataset.

  • collate_fn: 这个函数用来打包batch, 后面详细讲.

  • num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.

这个类一共干了两件事:

定义了一堆成员变量, 到时候赋给torch.utils.data.dataloader.DataLoaderIter 然后有一个iter() 函数, 把自己 “装进” DataLoaderIter 里面.

def iter(self): return DataLoaderIter(self)

首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

class DataLoader(object):
    ...
    
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,具体的细节如果我们不重写新函数就不用管。

然后Dataset和DataLoader在什么时候产生关系呢?就是我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

所以有这样的关系图:

最后,Dataset和DatasetLoader及数据集切分的基本使用流程如下:

  • Dataset是一个抽象类,需要派生一个子类构造数据集,需要改写的方法有__init__,__getitem__等。

  • DataLoader是一个迭代器,方便我们访问Dataset里的对象

如果大家对以上内容感兴趣,接下来这里有些根据网上查的资料总结下的dataset,dataloader等的用法可供大家参考。

参考资料有: