mmengine.dataset.utils 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Any, Mapping, Sequence
import numpy as np
import torch
from torch.utils.data._utils.collate import \
default_collate as torch_default_collate
from mmengine.registry import FUNCTIONS
from mmengine.structures import BaseDataElement
# ``FUNCTIONS`` was introduced in MMEngine v0.7.0. ``COLLATE_FUNCTIONS`` is
# kept as an alias of the same registry object so that code written against
# the older name keeps working.
COLLATE_FUNCTIONS = FUNCTIONS
def worker_init_fn(worker_id: int,
                   num_workers: int,
                   rank: int,
                   seed: int,
                   disable_subprocess_warning: bool = False) -> None:
    """This function will be called on each worker subprocess after seeding and
    before data loading.

    The per-worker seed is ``num_workers * rank + worker_id + seed``, which is
    unique for every ``(rank, worker_id)`` pair. It is used to seed ``numpy``,
    ``random`` and ``torch`` in the worker.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1].
        num_workers (int): How many subprocesses to use for data loading.
        rank (int): Rank of process in distributed environment. If in
            non-distributed environment, it is a constant number `0`.
        seed (int): Random seed.
        disable_subprocess_warning (bool): If True, every worker except
            worker 0 ignores all warnings. Defaults to False.
    """
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    # ``np.random.seed`` only accepts values in [0, 2**32 - 1]. Wrap the seed
    # into that range so large or negative seeds do not crash the worker;
    # seeds already in range are unchanged. ``random`` and ``torch`` accept
    # the full integer, so they are seeded with it directly.
    np.random.seed(worker_seed % 2**32)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    if disable_subprocess_warning and worker_id != 0:
        # Presumably keeps warnings from being repeated by every worker;
        # only worker 0 still emits them.
        warnings.simplefilter('ignore')
@FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    The default behavior of dataloader is to merge a list of samples to form
    a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
    neither stacks tensors into batch tensors nor converts int, float or
    ndarray values to tensors: containers (named tuples, sequences, mappings)
    are transposed recursively, and each leaf position is returned as the
    sequence of its per-sample values.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transposed data in the same format as the data_item of
        ``data_batch``.

    Raises:
        RuntimeError: If the sequence samples in ``data_batch`` do not all
            have the same length.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (str, bytes)):
        # Strings are Sequences too; treat them as leaves, not containers.
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named tuple
        return data_item_type(*(pseudo_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        # Collate every field exactly once; the fallback below reuses it
        # instead of recomputing the whole recursive collation.
        collated = [pseudo_collate(samples) for samples in transposed]
        if isinstance(data_item, tuple):
            return collated  # Compat with Pytorch.
        try:
            return data_item_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`).
            return collated
    elif isinstance(data_item, Mapping):
        # Keys are taken from the first sample; every sample must have them.
        collated_map = {
            key: pseudo_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated_map)
        except TypeError:
            # The mapping type may not support `__init__(mapping)` (e.g.,
            # `collections.defaultdict`, whose first argument must be the
            # default factory), so fall back to a plain dict.
            return collated_map
    else:
        return data_batch
@FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    Different from :func:`pseudo_collate`, ``default_collate`` hands every
    leaf position to PyTorch's ``default_collate``, which stacks tensors into
    a batched tensor whose first dimension is the batch size and converts
    ndarray, int and float values to tensors. No device transfer is done
    here.

    Different from ``default_collate`` in pytorch, ``default_collate`` will
    not process ``BaseDataElement``: a batch of ``BaseDataElement`` (or of
    ``str``/``bytes``) is returned unchanged.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accept input tensor with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same format as the data_item of ``data_batch``, of
        which tensors have been stacked, and ndarray, int, float have been
        converted to tensors.

    Raises:
        RuntimeError: If the sequence samples in ``data_batch`` do not all
            have the same length.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (BaseDataElement, str, bytes)):
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named_tuple
        return data_item_type(*(default_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        # Collate every field exactly once; the fallback below reuses it
        # instead of recomputing the whole recursive collation.
        collated = [default_collate(samples) for samples in transposed]
        if isinstance(data_item, tuple):
            return collated  # Compat with Pytorch.
        try:
            return data_item_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`).
            return collated
    elif isinstance(data_item, Mapping):
        # Keys are taken from the first sample; every sample must have them.
        collated_map = {
            key: default_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated_map)
        except TypeError:
            # The mapping type may not support `__init__(mapping)` (e.g.,
            # `collections.defaultdict`, whose first argument must be the
            # default factory), so fall back to a plain dict.
            return collated_map
    else:
        return torch_default_collate(data_batch)