Dataset and DataLoader¶
If you have never used PyTorch’s Dataset and DataLoader classes, we recommend reading the official PyTorch tutorial first to get familiar with the basic concepts.
Datasets and DataLoaders are necessary components in MMEngine’s training pipeline. They are conceptually derived from and consistent with PyTorch. Typically, a dataset defines the quantity, parsing, and pre-processing of the data, while a dataloader iteratively loads data according to settings such as batch_size, shuffle, num_workers, etc. Datasets are wrapped by dataloaders, and together they constitute the data source.
In this tutorial, we will step through their usage in MMEngine runner from the outside (dataloader) to the inside (dataset) and give some practical examples. After reading through this tutorial, you will be able to:
Master the configuration of dataloaders in MMEngine
Learn to use existing datasets (e.g. those from torchvision) from config files
Know about building and using your own dataset
Details on dataloader¶
Dataloaders can be configured in MMEngine’s Runner with 3 arguments:
train_dataloader: Used in Runner.train() to provide training data for models
val_dataloader: Used in Runner.train() at regular intervals for model evaluation
test_dataloader: Used in Runner.test() for the final test
MMEngine has full support for PyTorch native DataLoader objects. Therefore, you can simply pass your valid, already built dataloaders to the runner, as shown in "Getting Started in 15 Minutes". Meanwhile, thanks to the Registry Mechanism of MMEngine, those arguments also accept dicts as inputs, as illustrated in the following example (referred to as example 1). The keys in the dictionary correspond to the arguments of DataLoader’s __init__ method.
    import torchvision
    from mmengine.runner import Runner

    runner = Runner(
        train_dataloader=dict(
            batch_size=32,
            sampler=dict(
                type='DefaultSampler',
                shuffle=True),
            dataset=torchvision.datasets.CIFAR10(...),
            collate_fn=dict(type='default_collate')
        )
    )
When passed to the runner in the form of a dict, the dataloader will be lazily built in the runner when actually needed.
For more configurable arguments of the DataLoader, please refer to the PyTorch API documentation.
If you are interested in the details of the building procedure, you may refer to build_dataloader.
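The idea of building an object lazily from a dict can be sketched in plain Python. This is a toy illustration only: ToyRegistry, ToySampler and LazyHolder are hypothetical names invented for this sketch and are not part of MMEngine, whose actual Registry and build_dataloader do considerably more.

```python
from typing import Any, Callable, Dict, Optional


class ToyRegistry:
    """A toy name -> callable registry (hypothetical, not MMEngine's Registry)."""

    def __init__(self) -> None:
        self._modules: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str, module: Callable[..., Any]) -> None:
        self._modules[name] = module

    def build(self, cfg: Dict[str, Any]) -> Any:
        # Copy first so the caller's config dict is left untouched.
        kwargs = dict(cfg)
        module = self._modules[kwargs.pop('type')]
        return module(**kwargs)


class ToySampler:
    """Stand-in for a real sampler class (hypothetical)."""

    def __init__(self, shuffle: bool = False) -> None:
        self.shuffle = shuffle


class LazyHolder:
    """Stores a config dict and builds the object only on first access."""

    def __init__(self, registry: ToyRegistry, cfg: Dict[str, Any]) -> None:
        self._registry = registry
        self._cfg = cfg
        self._built: Optional[Any] = None

    @property
    def obj(self) -> Any:
        if self._built is None:
            self._built = self._registry.build(self._cfg)
        return self._built


SAMPLERS = ToyRegistry()
SAMPLERS.register('ToySampler', ToySampler)

holder = LazyHolder(SAMPLERS, dict(type='ToySampler', shuffle=True))
assert holder._built is None  # nothing has been built yet
sampler = holder.obj          # the object is built on first access
```

Because only the dict is stored at construction time, the owner of the holder is free to finish any required setup before the object is actually created.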
You may find that example 1 differs from the one in "Getting Started in 15 Minutes" in some arguments. Indeed, due to some obscure conventions in MMEngine, you can’t seamlessly switch it to a dict by simply replacing DataLoader with dict. We will discuss the differences between our convention and PyTorch’s in the following sections, in case you run into trouble when using config files.
sampler and shuffle¶
One obvious difference is that we add a sampler argument to the dict. This is because we require sampler to be explicitly specified when using a dict as a dataloader. Meanwhile, shuffle is also removed from the DataLoader arguments, because it conflicts with sampler in PyTorch, as noted in the PyTorch DataLoader API documentation.
In fact, shuffle is just a convenience shorthand in the PyTorch implementation. If shuffle is set to True, the dataloader will automatically switch to RandomSampler.
With a sampler argument, the code in example 1 is nearly equivalent to the code block below:
    import torchvision
    from torch.utils.data import DataLoader
    from mmengine.dataset import DefaultSampler, default_collate
    from mmengine.runner import Runner

    dataset = torchvision.datasets.CIFAR10(...)
    sampler = DefaultSampler(dataset, shuffle=True)

    runner = Runner(
        train_dataloader=DataLoader(
            batch_size=32,
            sampler=sampler,
            dataset=dataset,
            collate_fn=default_collate
        )
    )
This equivalence holds only if: 1) you are training with a single process, and 2) no randomness argument is passed to the runner. The reason is that the sampler must be built after the distributed environment is set up in order to be correct. The runner guarantees the correct order and a proper random seed by applying lazy initialization, which is only possible for dict inputs. Building a sampler manually, by contrast, requires extra work and is highly error-prone. Therefore, the code block above is just for illustration and definitely not recommended. We strongly suggest passing sampler as a dict to avoid potential problems.
The above example may make you wonder what a DefaultSampler is, why it is used, and whether there are other options. In fact, DefaultSampler is a built-in sampler in MMEngine that eliminates the gap between distributed and non-distributed training and thus enables a seamless conversion between them. If you have experience with DistributedDataParallel in PyTorch, you may remember having to change the sampler argument to make it work correctly. However, in MMEngine, you don’t need to bother with this detail, because DefaultSampler handles it for you.
DefaultSampler accepts the following arguments:
shuffle: Set to True to load data in the dataset in random order
seed: Random seed used to shuffle the dataset. Typically it doesn’t require manual configuration here because the runner will handle it with the randomness argument
round_up: When set to True, this is the same behavior as setting drop_last=False in PyTorch DataLoader. You should take care of it when migrating from PyTorch.
For more details about DefaultSampler, please refer to its API docs.
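To make the effect of round_up more concrete, here is a simplified pure-Python model of how a distributed sampler might split dataset indices across ranks. It is an assumption-laden sketch, not the actual DefaultSampler code: shard_indices is a hypothetical helper, and shuffling and seeding are left out entirely.

```python
import math
from typing import List


def shard_indices(num_samples: int, rank: int, world_size: int,
                  round_up: bool = True) -> List[int]:
    """Return the dataset indices that one rank would iterate over.

    Hypothetical helper for illustration; assumes num_samples > 0.
    """
    indices = list(range(num_samples))
    if round_up:
        # Pad by repeating indices from the start so that every rank
        # receives the same number of samples.
        per_rank = math.ceil(num_samples / world_size)
        total_size = per_rank * world_size
        repeats = math.ceil(total_size / num_samples)
        indices = (indices * repeats)[:total_size]
    else:
        # No padding: some ranks may receive one sample fewer.
        total_size = num_samples
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total_size:world_size]


# 10 samples split across 4 processes
shards_padded = [shard_indices(10, r, 4, round_up=True) for r in range(4)]
shards_exact = [shard_indices(10, r, 4, round_up=False) for r in range(4)]
```

In this model, padding gives every rank three indices, with indices 0 and 1 visited twice, while the unpadded variant gives the last two ranks only two indices each. The sketch also shows why rank and world size must be known before the sampler can be built.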
DefaultSampler handles most cases. We ensure that error-prone details such as random seeds are handled properly when you use it in a runner, which keeps you from running into trouble with distributed training. Apart from DefaultSampler, you may also be interested in InfiniteSampler for iteration-based training pipelines. If you have more advanced demands, you may want to refer to the code of these two built-in samplers to implement your own and register it to the DATA_SAMPLERS registry.
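    from torch.utils.data import Sampler
    from mmengine.registry import DATA_SAMPLERS
    from mmengine.runner import Runner

    @DATA_SAMPLERS.register_module()
    class MySampler(Sampler):
        pass

    runner = Runner(
        train_dataloader=dict(
            sampler=dict(type='MySampler'),
            ...
        )
    )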
The obscure collate_fn¶
Among the arguments of PyTorch DataLoader, collate_fn is often ignored by users, but in MMEngine you must pay special attention to it. When you pass the dataloader argument as a dict, MMEngine will use the built-in pseudo_collate by default, which differs significantly from PyTorch’s default_collate. Therefore, when migrating from PyTorch, you have to explicitly specify the collate_fn in config files to keep the behavior consistent.
pseudo_collate is the default value mainly for historical compatibility reasons. You don’t have to look deeply into it; it is enough to know about it so that you can avoid potential errors.
MMEngine provides 2 built-in collate_fn:
pseudo_collate: Default value in MMEngine. It won’t concatenate data along the batch index. Detailed explanations can be found in the pseudo_collate API doc
default_collate: It behaves almost identically to PyTorch’s default_collate. It will convert data into Tensor and concatenate them along the batch index. More details and slight differences from PyTorch can be found in the default_collate API doc
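The following toy function gives a rough idea of collating without concatenation. It is a simplified sketch written for this tutorial, not MMEngine's pseudo_collate: toy_pseudo_collate is a hypothetical name, and it only covers samples that are dicts or tuples with scalar or string fields.

```python
from typing import Any, Sequence


def toy_pseudo_collate(data_batch: Sequence[Any]) -> Any:
    """Regroup a batch field by field without stacking anything.

    Hypothetical, simplified helper; assumes a non-empty batch.
    """
    first = data_batch[0]
    if isinstance(first, dict):
        # One list per key, gathered across all samples.
        return {
            key: toy_pseudo_collate([sample[key] for sample in data_batch])
            for key in first
        }
    if isinstance(first, tuple):
        # Transpose: a list of samples becomes a list of fields.
        return [toy_pseudo_collate(list(field)) for field in zip(*data_batch)]
    # Leaf values are kept as a plain Python list, not converted to a tensor.
    return list(data_batch)


dict_batch = toy_pseudo_collate([
    {'path': 'a.png', 'label': 0},
    {'path': 'b.png', 'label': 1},
])
tuple_batch = toy_pseudo_collate([('a.png', 0), ('b.png', 1)])
```

A stacking collate function would go one step further and convert the numeric fields into a single batched tensor, which is the part this sketch deliberately leaves out.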
If you want to use a custom collate_fn, you can register it to the FUNCTIONS registry.
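    from typing import Any, Sequence

    from mmengine.registry import FUNCTIONS
    from mmengine.runner import Runner

    @FUNCTIONS.register_module()
    def my_collate_func(data_batch: Sequence) -> Any:
        pass

    runner = Runner(
        train_dataloader=dict(
            ...
            collate_fn=dict(type='my_collate_func')
        )
    )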
Details on dataset¶
Typically, datasets define the quantity, parsing, and pre-processing of the data. A dataset is wrapped by a dataloader, which loads its data in batches. Since we fully support PyTorch DataLoader, the dataset is also compatible. Meanwhile, thanks to the registry mechanism, when a dataloader is given as a dict, its dataset argument can also be given as a dict, which enables lazy initialization in the runner. This mechanism is what makes it possible to describe datasets in config files.
Use torchvision datasets¶
torchvision provides various open datasets. They can be directly used in MMEngine as shown in "Getting Started in 15 Minutes", where a CIFAR10 dataset is used together with torchvision’s built-in data transforms.
However, if you want to use the dataset in config files, registration is needed. What’s more, if you also require data transforms in torchvision, some more registrations are required. The following example illustrates how to do it.
    import torchvision
    import torchvision.transforms as tvt
    from mmengine.dataset.base_dataset import Compose
    from mmengine.registry import DATASETS, TRANSFORMS
    from mmengine.runner import Runner

    # register CIFAR10 dataset in torchvision
    # data transforms should also be built here
    @DATASETS.register_module(name='Cifar10', force=False)
    def build_torchvision_cifar10(transform=None, **kwargs):
        if isinstance(transform, dict):
            transform = [transform]
        if isinstance(transform, (list, tuple)):
            transform = Compose(transform)
        return torchvision.datasets.CIFAR10(**kwargs, transform=transform)

    # register data transforms in torchvision
    TRANSFORMS.register_module('RandomCrop', module=tvt.RandomCrop)
    TRANSFORMS.register_module('RandomHorizontalFlip', module=tvt.RandomHorizontalFlip)
    TRANSFORMS.register_module('ToTensor', module=tvt.ToTensor)
    TRANSFORMS.register_module('Normalize', module=tvt.Normalize)

    # specify in runner
    # note: norm_cfg is not defined in this snippet; it is expected to be a
    # dict of keyword arguments for tvt.Normalize
    runner = Runner(
        train_dataloader=dict(
            batch_size=32,
            sampler=dict(
                type='DefaultSampler',
                shuffle=True),
            dataset=dict(type='Cifar10',
                root='data/cifar10',
                train=True,
                download=True,
                transform=[
                    dict(type='RandomCrop', size=32, padding=4),
                    dict(type='RandomHorizontalFlip'),
                    dict(type='ToTensor'),
                    dict(type='Normalize', **norm_cfg)])
        )
    )
The above example makes extensive use of the registry mechanism and borrows the Compose module from MMEngine. If you need to use a torchvision dataset in your config files, you can refer to it and make some slight modifications. However, we recommend borrowing datasets from downstream repos such as MMDet, MMCls, etc., which may give you a better experience.
Customize your dataset¶
You are free to customize your own datasets, as you would with PyTorch. You can also copy existing datasets from your previous PyTorch projects. If you want to learn how to customize your dataset, please refer to the official PyTorch tutorials.
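As a reminder of what a map-style dataset looks like, the sketch below implements the two methods such a dataset needs, __len__ and __getitem__. SquaresDataset is a hypothetical example class; in a real project you would normally subclass torch.utils.data.Dataset, which is omitted here so that the sketch runs with the standard library alone.

```python
from typing import List, Tuple


class SquaresDataset:
    """Toy map-style dataset: item i is the pair (i, i * i)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        # Number of samples in the dataset.
        return self.size

    def __getitem__(self, index: int) -> Tuple[int, int]:
        # Return one sample; raise IndexError for out-of-range indices.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


dataset = SquaresDataset(5)
samples: List[Tuple[int, int]] = [dataset[i] for i in range(len(dataset))]
```

A sampler produces indices, the dataset maps each index to a sample, and the collate function merges the resulting samples into a batch.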
Use MMEngine BaseDataset¶
Apart from directly using the PyTorch native Dataset class, you can also use MMEngine’s built-in class BaseDataset to customize your own dataset, as described in the BaseDataset tutorial. It defines some conventions on the format of annotation files, which makes the data interface more unified and multi-task training more convenient. Meanwhile, BaseDataset works easily with the built-in data transforms in MMEngine, which saves you from writing them from scratch.
BaseDataset has been widely used in downstream repos of OpenMMLab 2.0 projects.
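As a rough illustration of such an annotation convention, the sketch below writes and reads back a small JSON annotation file. The top-level keys metainfo and data_list reflect our understanding of the BaseDataset tutorial; the field names inside them (classes, img_path, img_label) and the file name are illustrative assumptions, so check the BaseDataset tutorial for the authoritative format.

```python
import json
import os
import tempfile

# Assumed layout: dataset-level information under 'metainfo',
# one dict per sample under 'data_list'.
annotation = {
    'metainfo': {'classes': ['cat', 'dog']},
    'data_list': [
        {'img_path': 'images/0001.jpg', 'img_label': 0},
        {'img_path': 'images/0002.jpg', 'img_label': 1},
    ],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    ann_file = os.path.join(tmp_dir, 'annotation.json')
    # Write the annotation file ...
    with open(ann_file, 'w') as f:
        json.dump(annotation, f)
    # ... and read it back, as a dataset class would at load time.
    with open(ann_file) as f:
        loaded = json.load(f)
```

Keeping dataset-level information separate from per-sample records is what allows different tasks to share one data interface.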