Shortcuts

Source code for mmengine.hooks.sampler_seed_hook

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import HOOKS
from .hook import Hook


@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """Hook that passes the current epoch to the training data sampler.

    Before each training epoch, this hook calls ``set_epoch(runner.epoch)``
    on the train dataloader's sampler, or, failing that, on the sampler
    wrapped by the dataloader's batch sampler. Nothing happens when neither
    object exposes ``set_epoch``.

    When distributed training, it is only useful in conjunction with
    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
    purpose with :obj:`IterLoader`.
    """

    priority = 'NORMAL'

    def before_train_epoch(self, runner) -> None:
        """Set the seed for sampler and batch_sampler.

        Args:
            runner (Runner): The runner of the training process.
        """
        dataloader = runner.train_loop.dataloader

        # In case the `_SingleProcessDataLoaderIter` has no sampler,
        # or data loader uses `SequentialSampler` in Pytorch.
        sampler = getattr(dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(runner.epoch)
            return

        # In case the `_SingleProcessDataLoaderIter` has no batch sampler.
        # The batch sampler in PyTorch wraps the sampler as an attribute.
        # Each level is resolved with a ``None`` default so that a missing
        # or ``None`` ``batch_sampler`` (or a batch sampler without a
        # ``sampler`` attribute) is skipped instead of raising
        # ``AttributeError``.
        batch_sampler = getattr(dataloader, 'batch_sampler', None)
        wrapped_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(wrapped_sampler, 'set_epoch'):
            wrapped_sampler.set_epoch(runner.epoch)

© Copyright 2022, mmengine contributors. Revision d9fee4fb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.4
Versions
latest
stable
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.