Shortcuts

数据集(Dataset)与数据加载器(DataLoader)

提示

如果你没有接触过 PyTorch 的数据集与数据加载器,我们推荐先浏览 PyTorch 官方教程以了解一些基本概念

数据集与数据加载器是 MMEngine 中训练流程的必要组件,它们的概念来源于 PyTorch,并且在含义上与 PyTorch 保持一致。通常来说,数据集定义了数据的总体数量、读取方式以及预处理,而数据加载器则在不同的设置下迭代地加载数据,如批次大小(batch_size)、随机乱序(shuffle)、并行(num_workers)等。数据集经过数据加载器封装后构成了数据源。在本篇教程中,我们将按照从外(数据加载器)到内(数据集)的顺序,逐步介绍它们在 MMEngine 执行器中的用法,并给出一些常用示例。读完本篇教程,你将会:

  • 掌握如何在 MMEngine 的执行器中配置数据加载器

  • 学会在配置文件中使用已有(如 torchvision)数据集

  • 了解如何使用自己的数据集

数据加载器详解

在执行器(Runner)中,你可以分别配置以下 3 个参数来指定对应的数据加载器

  • train_dataloader:在 Runner.train() 中被使用,为模型提供训练数据

  • val_dataloader:在 Runner.val() 中被使用,也会在 Runner.train() 中每间隔一段时间被使用,用于模型的验证评测

  • test_dataloader:在 Runner.test() 中被使用,用于模型的测试

MMEngine 完全支持 PyTorch 的原生 DataLoader,因此上述 3 个参数均可以直接传入构建好的 DataLoader,如15分钟上手中的例子所示。同时,借助 MMEngine 的注册机制,以上参数也可以传入 dict,如下面代码(以下简称例 1)所示。字典中的键值与 DataLoader 的构造参数一一对应。

runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=torchvision.datasets.CIFAR10(...),
        collate_fn=dict(type='default_collate')
    )
)

在这种情况下,数据加载器会在实际被用到时,在执行器内部被构建。

备注

关于 DataLoader 的更多可配置参数,你可以参考 PyTorch API 文档

备注

如果你对于构建的具体细节感兴趣,你可以参考 build_dataloader

细心的你可能会发现,例 1 并非直接由15分钟上手中的示例代码简单修改而来。你可能本以为将 DataLoader 简单替换为 dict 就可以无缝切换,但遗憾的是,基于注册机制构建时 MMEngine 会有一些隐式的转换和约定。我们将介绍其中的不同点,以避免你使用配置文件时产生不必要的疑惑。

sampler 与 shuffle

与 15 分钟上手明显不同,例 1 中我们添加了 sampler 参数,这是由于在 MMEngine 中我们要求通过 dict 传入的数据加载器的配置必须包含 sampler 参数。同时,shuffle 参数也从 DataLoader 中移除,这是由于在 PyTorch 中 samplershuffle 参数是互斥的,见 PyTorch API 文档

备注

事实上,在 PyTorch 的实现中,shuffle 只是一个便利记号。当设置为 TrueDataLoader 会自动在内部使用 RandomSampler

当考虑 sampler 时,例 1 代码基本可以认为等价于下面的代码块

from mmengine.dataset import DefaultSampler

dataset = torchvision.datasets.CIFAR10(...)
sampler = DefaultSampler(dataset, shuffle=True)

runner = Runner(
    train_dataloader=DataLoader(
        batch_size=32,
        sampler=sampler,
        dataset=dataset,
        collate_fn=default_collate
    )
)

警告

上述代码的等价性只有在:1)使用单进程训练,以及 2)没有配置执行器的 randomness 参数时成立。这是由于使用 dict 传入 sampler 时,执行器会保证它在分布式训练环境设置完成后才被惰性构造,并接收到正确的随机种子。这两点在手动构造时需要额外工作且极易出错。因此,上述的写法只是一个示意而非推荐写法。我们强烈建议 samplerdict 的形式传入,让执行器处理构造顺序,以避免出现问题。

DefaultSampler

上面例子可能会让你好奇:DefaultSampler 是什么,为什么要使用它,是否有其他选项?事实上,DefaultSampler 是 MMEngine 内置的一种采样器,它屏蔽了单进程训练与多进程训练的细节差异,使得单卡与多卡训练可以无缝切换。如果你有过使用 PyTorch DistributedDataParallel 的经验,你一定会对其中更换数据加载器的 sampler 参数有所印象。但在 MMEngine 中,这一细节通过 DefaultSampler 而被屏蔽。

除了 Dataset 本身之外,DefaultSampler 还支持以下参数配置:

  • shuffle 设置为 True 时会打乱数据集的读取顺序

  • seed 打乱数据集所用的随机种子,通常不需要在此手动设置,会从 Runnerrandomness 入参中读取

  • round_up 设置为 True 时,与 PyTorch DataLoader 中设置 drop_last=False 行为一致。如果你在迁移 PyTorch 的项目,你可能需要注意这一点。

备注

更多关于 DefaultSampler 的内容可以参考 API 文档

DefaultSampler 适用于绝大部分情况,并且我们保证在执行器中使用它时,随机数等容易出错的细节都被正确地处理,防止你陷入多进程训练的常见陷阱。如果你想要使用基于迭代次数 (iteration-based) 的训练流程,你也许会对 InfiniteSampler 感兴趣。如果你有更多的进阶需求,你可能会想要参考上述两个内置 sampler 的代码,实现一个自定义的 sampler 并注册到 DATA_SAMPLERS 根注册器中。

@DATA_SAMPLERS.register_module()
class MySampler(Sampler):
    pass

runner = Runner(
    train_dataloader=dict(
        sampler=dict(type='MySampler'),
        ...
    )
)

不起眼的 collate_fn

PyTorch 的 DataLoader 中,collate_fn 这一参数常常被使用者忽略,但在 MMEngine 中你需要额外注意:当你传入 dict 来构造数据加载器时,MMEngine 会默认使用内置的 pseudo_collate,这一点明显区别于 PyTorch 默认的 default_collate。因此,当你迁移 PyTorch 项目时,需要在配置文件中手动指明 collate_fn 以保持行为一致。

备注

MMEngine 中使用 pseudo_collate 作为默认值,主要是由于历史兼容性原因,你可以不必过于深究,只需了解并避免错误使用即可。

MMengine 中提供了 2 种内置的 collate_fn

  • pseudo_collate,缺省时的默认参数。它不会将数据沿着 batch 的维度合并。详细说明可以参考 pseudo_collate

  • default_collate,与 PyTorch 中的 default_collate 行为几乎完全一致,会将数据转化为 Tensor 并沿着 batch 维度合并。一些细微不同和详细说明可以参考 default_collate

如果你想要使用自定义的 collate_fn,你也可以将它注册到 FUNCTIONS 根注册器中来使用

@FUNCTIONS.register_module()
def my_collate_func(data_batch: Sequence) -> Any:
    pass

runner = Runner(
    train_dataloader=dict(
        ...
        collate_fn=dict(type='my_collate_func')
    )
)

数据集详解

数据集通常定义了数据的数量、读取方式与预处理,并作为参数传递给数据加载器供后者分批次加载。由于我们使用了 PyTorch 的 DataLoader,因此数据集也自然与 PyTorch Dataset 完全兼容。同时得益于注册机制,当数据加载器使用 dict 在执行器内部构建时,dataset 参数也可以使用 dict 传入并在内部被构建。这一点使得编写配置文件成为可能。

使用 torchvision 数据集

torchvision 中提供了丰富的公开数据集,它们都可以在 MMEngine 中直接使用,例如 15 分钟上手中的示例代码就使用了其中的 Cifar10 数据集,并且使用了 torchvision 中内置的数据预处理模块。

但是,当需要将上述示例转换为配置文件时,你需要对 torchvision 中的数据集进行额外的注册。如果你同时用到了 torchvision 中的数据预处理模块,那么你也需要编写额外代码来对它们进行注册和构建。下面我们将给出一个等效的例子来展示如何做到这一点。

import torchvision.transforms as tvt
from mmengine.registry import DATASETS, TRANSFORMS
from mmengine.dataset.base_dataset import Compose

# 注册 torchvision 的 CIFAR10 数据集
# 数据预处理也需要在此一起构建
@DATASETS.register_module(name='Cifar10', force=False)
def build_torchvision_cifar10(transform=None, **kwargs):
    if isinstance(transform, dict):
        transform = [transform]
    if isinstance(transform, (list, tuple)):
        transform = Compose(transform)
    return torchvision.datasets.CIFAR10(**kwargs, transform=transform)

# 注册 torchvision 中用到的数据预处理模块
DATA_TRANSFORMS.register_module('RandomCrop', module=tvt.RandomCrop)
DATA_TRANSFORMS.register_module('RandomHorizontalFlip', module=tvt.RandomHorizontalFlip)
DATA_TRANSFORMS.register_module('ToTensor', module=tvt.ToTensor)
DATA_TRANSFORMS.register_module('Normalize', module=tvt.Normalize)

# 在 Runner 中使用
runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=dict(type='Cifar10',
            root='data/cifar10',
            train=True,
            download=True,
            transform=[
                dict(type='RandomCrop', size=32, padding=4),
                dict(type='RandomHorizontalFlip'),
                dict(type='ToTensor'),
                dict(type='Normalize', **norm_cfg)])
    )
)

备注

上述例子中大量使用了注册机制,并且用到了 MMEngine 中的 Compose。如果你急需在配置文件中使用 torchvision 数据集,你可以参考上述代码并略作修改。但我们更加推荐你有需要时在下游库(如 MMDetMMPretrain 等)中寻找对应的数据集实现,从而获得更好的使用体验。

自定义数据集

你可以像使用 PyTorch 一样,自由地定义自己的数据集,或将之前 PyTorch 项目中的数据集拷贝过来。如果你想要了解如何自定义数据集,可以参考 PyTorch 官方教程

使用 MMEngine 的数据集基类

除了直接使用 PyTorch 的 Dataset 来自定义数据集之外,你也可以使用 MMEngine 内置的 BaseDataset,参考数据集基类文档。它对标注文件的格式做了一些约定,使得数据接口更加统一、多任务训练更加便捷。同时,数据集基类也可以轻松地搭配内置的数据变换使用,减轻你从头搭建训练流程的工作量。

目前,BaseDataset 已经在 OpenMMLab 2.0 系列的下游仓库中被广泛使用。

Read the Docs v: stable
Versions
latest
stable
v0.10.4
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.