数据集(Dataset)与数据加载器(DataLoader)¶
提示
如果你没有接触过 PyTorch 的数据集与数据加载器,我们推荐先浏览 PyTorch 官方教程以了解一些基本概念
数据集与数据加载器是 MMEngine 中训练流程的必要组件,它们的概念来源于 PyTorch,并且在含义上与 PyTorch 保持一致。通常来说,数据集定义了数据的总体数量、读取方式以及预处理,而数据加载器则在不同的设置下迭代地加载数据,如批次大小(batch_size
)、随机乱序(shuffle
)、并行(num_workers
)等。数据集经过数据加载器封装后构成了数据源。在本篇教程中,我们将按照从外(数据加载器)到内(数据集)的顺序,逐步介绍它们在 MMEngine 执行器中的用法,并给出一些常用示例。读完本篇教程,你将会:
掌握如何在 MMEngine 的执行器中配置数据加载器
学会在配置文件中使用已有(如
torchvision
)数据集了解如何使用自己的数据集
数据加载器详解¶
在执行器(Runner
)中,你可以分别配置以下 3 个参数来指定对应的数据加载器
train_dataloader
:在Runner.train()
中被使用,为模型提供训练数据val_dataloader
:在Runner.val()
中被使用,也会在Runner.train()
中每间隔一段时间被使用,用于模型的验证评测test_dataloader
:在Runner.test()
中被使用,用于模型的测试
MMEngine 完全支持 PyTorch 的原生 DataLoader
,因此上述 3 个参数均可以直接传入构建好的 DataLoader
,如15分钟上手中的例子所示。同时,借助 MMEngine 的注册机制,以上参数也可以传入 dict
,如下面代码(以下简称例 1)所示。字典中的键值与 DataLoader
的构造参数一一对应。
runner = Runner(
train_dataloader=dict(
batch_size=32,
sampler=dict(
type='DefaultSampler',
shuffle=True),
dataset=torchvision.datasets.CIFAR10(...),
collate_fn=dict(type='default_collate')
)
)
在这种情况下,数据加载器会在实际被用到时,在执行器内部被构建。
备注
关于 DataLoader
的更多可配置参数,你可以参考 PyTorch API 文档
备注
如果你对于构建的具体细节感兴趣,你可以参考 build_dataloader
细心的你可能会发现,例 1 并非直接由15分钟上手中的示例代码简单修改而来。你可能本以为将 DataLoader
简单替换为 dict
就可以无缝切换,但遗憾的是,基于注册机制构建时 MMEngine 会有一些隐式的转换和约定。我们将介绍其中的不同点,以避免你使用配置文件时产生不必要的疑惑。
sampler 与 shuffle¶
与 15 分钟上手明显不同,例 1 中我们添加了 sampler
参数,这是由于在 MMEngine 中我们要求通过 dict
传入的数据加载器的配置必须包含 sampler
参数。同时,shuffle
参数也从 DataLoader
中移除,这是由于在 PyTorch 中 sampler
与 shuffle
参数是互斥的,见 PyTorch API 文档。
备注
事实上,在 PyTorch 的实现中,shuffle
只是一个便利记号。当设置为 True
时 DataLoader
会自动在内部使用 RandomSampler
当考虑 sampler
时,例 1 代码基本可以认为等价于下面的代码块
from mmengine.dataset import DefaultSampler
dataset = torchvision.datasets.CIFAR10(...)
sampler = DefaultSampler(dataset, shuffle=True)
runner = Runner(
train_dataloader=DataLoader(
batch_size=32,
sampler=sampler,
dataset=dataset,
collate_fn=default_collate
)
)
警告
上述代码的等价性只有在:1)使用单进程训练,以及 2)没有配置执行器的 randomness
参数时成立。这是由于使用 dict
传入 sampler
时,执行器会保证它在分布式训练环境设置完成后才被惰性构造,并接收到正确的随机种子。这两点在手动构造时需要额外工作且极易出错。因此,上述的写法只是一个示意而非推荐写法。我们强烈建议 sampler
以 dict
的形式传入,让执行器处理构造顺序,以避免出现问题。
DefaultSampler¶
上面例子可能会让你好奇:DefaultSampler
是什么,为什么要使用它,是否有其他选项?事实上,DefaultSampler
是 MMEngine 内置的一种采样器,它屏蔽了单进程训练与多进程训练的细节差异,使得单卡与多卡训练可以无缝切换。如果你有过使用 PyTorch DistributedDataParallel
的经验,你一定会对其中更换数据加载器的 sampler
参数有所印象。但在 MMEngine 中,这一细节通过 DefaultSampler
而被屏蔽。
除了 Dataset
本身之外,DefaultSampler
还支持以下参数配置:
shuffle
设置为True
时会打乱数据集的读取顺序seed
打乱数据集所用的随机种子,通常不需要在此手动设置,会从Runner
的randomness
入参中读取round_up
设置为True
时,与 PyTorchDataLoader
中设置drop_last=False
行为一致。如果你在迁移 PyTorch 的项目,你可能需要注意这一点。
备注
更多关于 DefaultSampler
的内容可以参考 API 文档
DefaultSampler
适用于绝大部分情况,并且我们保证在执行器中使用它时,随机数等容易出错的细节都被正确地处理,防止你陷入多进程训练的常见陷阱。如果你想要使用基于迭代次数 (iteration-based) 的训练流程,你也许会对 InfiniteSampler 感兴趣。如果你有更多的进阶需求,你可能会想要参考上述两个内置 sampler
的代码,实现一个自定义的 sampler
并注册到 DATA_SAMPLERS
根注册器中。
@DATA_SAMPLERS.register_module()
class MySampler(Sampler):
pass
runner = Runner(
train_dataloader=dict(
sampler=dict(type='MySampler'),
...
)
)
不起眼的 collate_fn¶
PyTorch 的 DataLoader
中,collate_fn
这一参数常常被使用者忽略,但在 MMEngine 中你需要额外注意:当你传入 dict
来构造数据加载器时,MMEngine 会默认使用内置的 pseudo_collate,这一点明显区别于 PyTorch 默认的 default_collate。因此,当你迁移 PyTorch 项目时,需要在配置文件中手动指明 collate_fn
以保持行为一致。
备注
MMEngine 中使用 pseudo_collate
作为默认值,主要是由于历史兼容性原因,你可以不必过于深究,只需了解并避免错误使用即可。
MMengine 中提供了 2 种内置的 collate_fn
:
pseudo_collate
,缺省时的默认参数。它不会将数据沿着batch
的维度合并。详细说明可以参考 pseudo_collatedefault_collate
,与 PyTorch 中的default_collate
行为几乎完全一致,会将数据转化为Tensor
并沿着batch
维度合并。一些细微不同和详细说明可以参考 default_collate
如果你想要使用自定义的 collate_fn
,你也可以将它注册到 FUNCTIONS
根注册器中来使用
@FUNCTIONS.register_module()
def my_collate_func(data_batch: Sequence) -> Any:
pass
runner = Runner(
train_dataloader=dict(
...
collate_fn=dict(type='my_collate_func')
)
)
数据集详解¶
数据集通常定义了数据的数量、读取方式与预处理,并作为参数传递给数据加载器供后者分批次加载。由于我们使用了 PyTorch 的 DataLoader
,因此数据集也自然与 PyTorch Dataset
完全兼容。同时得益于注册机制,当数据加载器使用 dict
在执行器内部构建时,dataset
参数也可以使用 dict
传入并在内部被构建。这一点使得编写配置文件成为可能。
使用 torchvision 数据集¶
torchvision
中提供了丰富的公开数据集,它们都可以在 MMEngine 中直接使用,例如 15 分钟上手中的示例代码就使用了其中的 Cifar10
数据集,并且使用了 torchvision
中内置的数据预处理模块。
但是,当需要将上述示例转换为配置文件时,你需要对 torchvision
中的数据集进行额外的注册。如果你同时用到了 torchvision
中的数据预处理模块,那么你也需要编写额外代码来对它们进行注册和构建。下面我们将给出一个等效的例子来展示如何做到这一点。
import torchvision.transforms as tvt
from mmengine.registry import DATASETS, TRANSFORMS
from mmengine.dataset.base_dataset import Compose
# 注册 torchvision 的 CIFAR10 数据集
# 数据预处理也需要在此一起构建
@DATASETS.register_module(name='Cifar10', force=False)
def build_torchvision_cifar10(transform=None, **kwargs):
if isinstance(transform, dict):
transform = [transform]
if isinstance(transform, (list, tuple)):
transform = Compose(transform)
return torchvision.datasets.CIFAR10(**kwargs, transform=transform)
# 注册 torchvision 中用到的数据预处理模块
DATA_TRANSFORMS.register_module('RandomCrop', module=tvt.RandomCrop)
DATA_TRANSFORMS.register_module('RandomHorizontalFlip', module=tvt.RandomHorizontalFlip)
DATA_TRANSFORMS.register_module('ToTensor', module=tvt.ToTensor)
DATA_TRANSFORMS.register_module('Normalize', module=tvt.Normalize)
# 在 Runner 中使用
runner = Runner(
train_dataloader=dict(
batch_size=32,
sampler=dict(
type='DefaultSampler',
shuffle=True),
dataset=dict(type='Cifar10',
root='data/cifar10',
train=True,
download=True,
transform=[
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize', **norm_cfg)])
)
)
备注
上述例子中大量使用了注册机制,并且用到了 MMEngine 中的 Compose。如果你急需在配置文件中使用 torchvision
数据集,你可以参考上述代码并略作修改。但我们更加推荐你有需要时在下游库(如 MMDet 和 MMPretrain 等)中寻找对应的数据集实现,从而获得更好的使用体验。
自定义数据集¶
你可以像使用 PyTorch 一样,自由地定义自己的数据集,或将之前 PyTorch 项目中的数据集拷贝过来。如果你想要了解如何自定义数据集,可以参考 PyTorch 官方教程
使用 MMEngine 的数据集基类¶
除了直接使用 PyTorch 的 Dataset
来自定义数据集之外,你也可以使用 MMEngine 内置的 BaseDataset
,参考数据集基类文档。它对标注文件的格式做了一些约定,使得数据接口更加统一、多任务训练更加便捷。同时,数据集基类也可以轻松地搭配内置的数据变换使用,减轻你从头搭建训练流程的工作量。
目前,BaseDataset
已经在 OpenMMLab 2.0 系列的下游仓库中被广泛使用。