Shortcuts

Source code for mmengine.hooks.sync_buffer_hook

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dist import all_reduce_params, is_distributed
from mmengine.registry import HOOKS
from .hook import Hook


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN
    at the end of each epoch, and before validation if they have not been
    synchronized since the last training epoch ended.

    Buffers are all-reduced with ``op='mean'`` so every rank ends up holding
    the averaged values. When ``is_distributed()`` reports ``False`` at
    construction time, the hook does nothing.
    """

    priority = 'NORMAL'

    def __init__(self) -> None:
        # Queried once and cached. Synchronization therefore only happens if
        # the distributed environment is already initialized when this hook
        # is constructed.
        self.distributed = is_distributed()
        # Set to True after ``after_train_epoch`` has synchronized the
        # buffers, so that the following ``before_val_epoch`` does not
        # repeat the same all-reduce.
        self.called_in_train = False

    def before_val_epoch(self, runner) -> None:
        """All-reduce model buffers before each validation epoch.

        If the buffers were not synchronized at the end of a training epoch
        since the previous validation, synchronize them here. This covers,
        for example, iteration-based training where validation is triggered
        without a training epoch having just ended. Every rank then validates
        with the same buffer values.

        Args:
            runner (Runner): The runner of the validation process.
        """
        if self.distributed:
            if not self.called_in_train:
                all_reduce_params(runner.model.buffers(), op='mean')
            # Reset the flag so the next validation checks again.
            self.called_in_train = False

    def after_train_epoch(self, runner) -> None:
        """All-reduce model buffers at the end of each epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.distributed:
            all_reduce_params(runner.model.buffers(), op='mean')
            # Record that buffers are already in sync so an immediately
            # following validation can skip its own all-reduce.
            self.called_in_train = True

© Copyright 2022, mmengine contributors. Revision 8d4885cb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.5.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.