Shortcuts

Source code for mmengine.dataset.sampler

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
from typing import Iterator, Optional, Sized

import torch
from torch.utils.data import Sampler

from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class DefaultSampler(Sampler):
    """The default data sampler for both distributed and non-distributed
    environments.

    It differs from PyTorch's ``DistributedSampler`` in the following ways:

    1. This sampler also supports a non-distributed environment.
    2. The round-up behavior is slightly different.

       - If ``round_up=True``, this sampler adds extra samples so that the
         number of samples is evenly divisible by the world size. This is
         the same behavior as ``DistributedSampler`` with
         ``drop_last=False``.
       - If ``round_up=False``, this sampler neither removes nor adds any
         samples, whereas ``DistributedSampler`` with ``drop_last=True``
         removes tail samples.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset or not.
            Defaults to True.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. If None, a seed is obtained
            from ``sync_random_seed()``. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            # Every rank gets the same number of samples; ``__iter__`` pads
            # the index list by repeating indices up to ``total_size``.
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            # No padding: rank ``r`` takes indices r, r + world_size, ...
            # from the unpadded list, so lower ranks may get one extra item.
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices assigned to this rank."""
        num_items = len(self.dataset)
        # Deterministically shuffle based on epoch and seed, so all ranks
        # build the same permutation before subsampling (provided ``seed``
        # is identical across ranks).
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(num_items, generator=g).tolist()
        else:
            indices = list(range(num_items))

        # Add extra samples to make the total evenly divisible. Skip this
        # for an empty dataset: there is nothing to repeat, and computing
        # the repeat count would divide by zero.
        if self.round_up and indices:
            repeats = self.total_size // len(indices) + 1
            indices = (indices * repeats)[:self.total_size]

        # Subsample: this rank takes every ``world_size``-th index starting
        # at its own rank.
        indices = indices[self.rank:self.total_size:self.world_size]
        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
@DATA_SAMPLERS.register_module()
class InfiniteSampler(Sampler):
    """An endless index sampler designed for iteration-based runners.

    It yields dataset indices forever. The full index sequence (a fresh
    random permutation per pass when ``shuffle=True``, otherwise
    ``0 .. len(dataset) - 1``) is repeated endlessly, and each rank takes
    every ``world_size``-th index starting at its own rank.

    The index generator is created once in ``__init__`` and shared by every
    ``iter()`` call, so a new iterator resumes where the previous one
    stopped instead of restarting from the beginning.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset or not.
            Defaults to True.
        seed (int, optional): Random seed. If None, a seed is obtained from
            ``sync_random_seed()``. Defaults to None.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.size = len(dataset)
        # Created once: all iterations of this sampler draw from this single
        # generator.
        self.indices = self._indices_of_rank()

    def _infinite_indices(self) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        # Seeded once, so successive permutations differ from each other but
        # the whole stream is reproducible for a given seed.
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices, continuing the shared index generator."""
        yield from self.indices

    def __len__(self) -> int:
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch: int) -> None:
        """Not supported in iteration-based runner."""
        pass

© Copyright 2022, mmengine contributors. Revision b2295a25.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.1
Versions
latest
stable
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.