Shortcuts

Source code for mmengine.hooks.sampler_seed_hook

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import HOOKS
from .hook import Hook


@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """Hook that forwards the current epoch to the training sampler.

    Before every training epoch this hook calls ``set_epoch(runner.epoch)``
    on the train dataloader's sampler, or, if that sampler has no such
    method, on the sampler wrapped by the dataloader's batch sampler.
    Nothing happens when neither exposes ``set_epoch``.

    When distributed training, it is only useful in conjunction with
    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
    purpose with :obj:`IterLoader`.
    """

    priority = 'NORMAL'

    def before_train_epoch(self, runner) -> None:
        """Set the epoch on the sampler or on the batch sampler's sampler.

        Args:
            runner (Runner): The runner of the training process. Its
                ``train_loop.dataloader`` and ``epoch`` attributes are read.
        """
        dataloader = runner.train_loop.dataloader

        # The dataloader may have no sampler at all (e.g.
        # `_SingleProcessDataLoaderIter`), or a sampler without `set_epoch`
        # (e.g. PyTorch's `SequentialSampler`); both cases fall through.
        sampler = getattr(dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(runner.epoch)
            return

        # The PyTorch batch sampler wraps the real sampler as an attribute.
        # Look it up defensively: `batch_sampler` can be missing or `None`,
        # and a custom batch sampler need not expose `.sampler`. An
        # unguarded attribute access would raise AttributeError here.
        batch_sampler = getattr(dataloader, 'batch_sampler', None)
        wrapped_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(wrapped_sampler, 'set_epoch'):
            wrapped_sampler.set_epoch(runner.epoch)

© Copyright 2022, mmengine contributors. Revision 8d4885cb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.5.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.