Shortcuts

Source code for mmengine.dataset.utils

# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Any, Mapping, Sequence

import numpy as np
import torch
from torch.utils.data._utils.collate import \
    default_collate as torch_default_collate

from mmengine.registry import FUNCTIONS
from mmengine.structures import BaseDataElement

# ``FUNCTIONS`` was introduced in MMEngine v0.7.0. ``COLLATE_FUNCTIONS`` is
# kept as an alias bound to the very same ``FUNCTIONS`` object so that code
# still referring to the old name keeps working (backward compatibility).
COLLATE_FUNCTIONS = FUNCTIONS


def worker_init_fn(worker_id: int,
                   num_workers: int,
                   rank: int,
                   seed: int,
                   disable_subprocess_warning: bool = False) -> None:
    """This function will be called on each worker subprocess after seeding
    and before data loading.

    Every worker is seeded with ``num_workers * rank + worker_id + seed`` so
    that different workers (on the same or on different ranks) get distinct
    but reproducible random streams for ``random``, ``numpy`` and ``torch``.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1].
        num_workers (int): How many subprocesses to use for data loading.
        rank (int): Rank of process in distributed environment. If in
            non-distributed environment, it is a constant number `0`.
        seed (int): Random seed.
        disable_subprocess_warning (bool): If True, every worker except
            worker 0 ignores all warnings, so only worker 0 keeps emitting
            them (presumably to avoid the same warning being printed once per
            worker). Defaults to False.
    """
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    # ``np.random.seed`` only accepts seeds in [0, 2**32 - 1] and raises
    # ``ValueError`` otherwise. Fold the seed into that range; this is a
    # no-op for seeds that were already valid.
    np.random.seed(worker_seed % 2**32)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    if disable_subprocess_warning and worker_id != 0:
        warnings.simplefilter('ignore')


@FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    The default behavior of dataloader is to merge a list of samples to form
    a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
    neither stacks tensors into batch tensors nor converts int, float or
    ndarray to tensors: it only transposes the nested structure, and every
    leaf value is returned as the sequence of per-sample values.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transposed data in the same format as the data_item of
        ``data_batch``.

    Raises:
        RuntimeError: If the sequence samples in ``data_batch`` do not all
            have the same length.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (str, bytes)):
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named tuple
        return data_item_type(*(pseudo_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        # Collate the children once, outside the ``try`` below, so that a
        # TypeError raised while constructing the container does not trigger
        # a second full recursive collation.
        collated = [pseudo_collate(samples) for samples in transposed]
        if isinstance(data_item, tuple):
            # Compat with Pytorch: tuples are collated into lists.
            return collated
        try:
            return data_item_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`).
            return collated
    elif isinstance(data_item, Mapping):
        collated_map = {
            key: pseudo_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated_map)
        except TypeError:
            # The mapping type may not support `__init__(mapping)` (e.g.,
            # `collections.defaultdict`, whose first positional argument is
            # the default factory). Fall back to a plain dict, as PyTorch
            # does.
            return collated_map
    else:
        return data_batch


@FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    Different from :func:`pseudo_collate`, ``default_collate`` will stack
    tensors contained in ``data_batch`` into a batched tensor with the first
    dimension being the batch size. Leaf values are collated by PyTorch's
    ``default_collate``; this function does not move tensors to any device.

    Different from ``default_collate`` in pytorch, ``default_collate`` will
    not process ``BaseDataElement``: a batch of ``BaseDataElement`` (as well
    as of ``str`` / ``bytes``) is returned unchanged.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accepts input tensors with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same format as the data_item of ``data_batch``, of
        which tensors have been stacked, and ndarray, int, float have been
        converted to tensors.

    Raises:
        RuntimeError: If the sequence samples in ``data_batch`` do not all
            have the same length.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)

    if isinstance(data_item, (BaseDataElement, str, bytes)):
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named_tuple
        return data_item_type(*(default_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        # Collate the children once, outside the ``try`` below, so that a
        # TypeError raised while constructing the container does not trigger
        # a second full recursive collation.
        collated = [default_collate(samples) for samples in transposed]
        if isinstance(data_item, tuple):
            # Compat with Pytorch: tuples are collated into lists.
            return collated
        try:
            return data_item_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`).
            return collated
    elif isinstance(data_item, Mapping):
        collated_map = {
            key: default_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated_map)
        except TypeError:
            # The mapping type may not support `__init__(mapping)` (e.g.,
            # `collections.defaultdict`, whose first positional argument is
            # the default factory). Fall back to a plain dict, as PyTorch
            # does.
            return collated_map
    else:
        return torch_default_collate(data_batch)

© Copyright 2022, mmengine contributors. Revision 66fb81f7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.