Shortcuts

Source code for mmengine.hooks.sync_buffer_hook

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dist import all_reduce_params, is_distributed
from mmengine.registry import HOOKS
from .hook import Hook


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Average model buffers (e.g. ``running_mean`` / ``running_var`` of BN
    layers) across processes.

    When running distributed, buffers are all-reduced with ``op='mean'`` at
    the end of every training epoch. They are also all-reduced right before a
    validation epoch, unless a training-epoch sync already happened since the
    previous validation epoch.
    """

    priority = 'NORMAL'

    def __init__(self) -> None:
        # Queried once at construction time; the hooks below are no-ops when
        # this is False.
        self.distributed = is_distributed()
        # Set to True by ``after_train_epoch`` and cleared by
        # ``before_val_epoch``; lets validation skip a redundant all-reduce
        # when the buffers were just synchronized at the end of training.
        self.called_in_train = False

    def before_val_epoch(self, runner) -> None:
        """All-reduce model buffers before a validation epoch if needed.

        The all-reduce is skipped when ``after_train_epoch`` has already
        synchronized the buffers since the last validation epoch. Either way
        the flag is reset afterwards. The upstream docstring notes this path
        is what syncs buffers when using ``IterBasedTrainLoop``.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        already_synced = self.called_in_train
        if not already_synced:
            all_reduce_params(runner.model.buffers(), op='mean')
        self.called_in_train = False

    def after_train_epoch(self, runner) -> None:
        """All-reduce model buffers at the end of a training epoch.

        Also records that a sync happened, so the next ``before_val_epoch``
        can skip its own all-reduce.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        all_reduce_params(runner.model.buffers(), op='mean')
        self.called_in_train = True

© Copyright 2022, mmengine contributors. Revision 39ed23fa.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.