
Source code for mmengine._strategy.distributed

# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, Optional

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from mmengine.device import get_device
from mmengine.dist import init_dist, is_distributed, master_only
from mmengine.model import convert_sync_batchnorm, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS, STRATEGIES
from .single_device import SingleDeviceStrategy


@STRATEGIES.register_module()
class DDPStrategy(SingleDeviceStrategy):
    """Distribution strategy for distributed data parallel training.

    Args:
        model_wrapper (dict): Dict for model wrapper. Defaults to None.
        sync_bn (str): Type of sync batch norm. Defaults to None.
            Options are 'torch' and 'mmcv'.
        **kwargs: Other arguments for :class:`BaseStrategy`.
    """

    def __init__(
        self,
        *,
        model_wrapper: Optional[dict] = None,
        sync_bn: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn

    def _setup_distributed(  # type: ignore
        self,
        launcher: str = 'pytorch',
        backend: str = 'nccl',
        **kwargs,
    ):
        """Setup distributed environment.

        Args:
            launcher (str): Way to launch multiple processes. Supported
                launchers are 'pytorch', 'mpi' and 'slurm'.
            backend (str): Communication backend. Supported backends are
                'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
            **kwargs: Other arguments for :func:`init_dist`.
        """
        if not is_distributed():
            init_dist(launcher, backend, **kwargs)
    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
        (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

        Args:
            model (nn.Module): Model to be converted.

        Returns:
            nn.Module: Converted model.
        """
        if self.sync_bn is not None:
            try:
                model = convert_sync_batchnorm(model, self.sync_bn)
            except ValueError as e:
                self.logger.error('cfg.sync_bn should be "torch" or '
                                  f'"mmcv", but got {self.sync_bn}')
                raise e
        return model
    def _wrap_model(self, model: nn.Module) -> DistributedDataParallel:
        """Wrap the model to :obj:`MMDistributedDataParallel` or other custom
        distributed data-parallel module wrappers.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            nn.Module or DistributedDataParallel: nn.Module or subclass of
            ``DistributedDataParallel``.
        """
        if is_model_wrapper(model):
            return model

        model = model.to(get_device())

        model = self.convert_model(model)

        if self.model_wrapper is None:
            # set broadcast_buffers as False to keep compatibility with
            # OpenMMLab repos
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)

        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])])
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model
    @master_only
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        super().save_checkpoint(
            filename=filename,
            save_optimizer=save_optimizer,
            save_param_scheduler=save_param_scheduler,
            extra_ckpt=extra_ckpt,
            callback=callback)
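The sketches below are not part of the module above; they only illustrate how the class is typically used. First, a minimal construction sketch for the two keyword arguments documented in ``DDPStrategy``. The ``find_unused_parameters`` flag is just an illustrative extra DDP option, and the import path assumes the package's public ``mmengine._strategy`` namespace:

# Minimal configuration sketch: SyncBN conversion plus a custom wrapper config.
from mmengine._strategy import DDPStrategy  # assumed public import path

strategy = DDPStrategy(
    sync_bn='torch',  # convert_model() will swap BatchNorm for SyncBatchNorm
    model_wrapper=dict(
        type='MMDistributedDataParallel',
        broadcast_buffers=False,
        find_unused_parameters=True),  # illustrative extra DDP kwarg
)

In practice the same options are usually supplied as a ``strategy=dict(type='DDPStrategy', ...)`` config to a runner that supports strategies rather than constructed by hand.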
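To see what ``convert_model`` does when ``sync_bn='torch'``, here is a standalone sketch using the same ``convert_sync_batchnorm`` helper imported at the top of the module. The toy ``nn.Sequential`` model is purely illustrative, and no distributed setup is needed for the conversion itself:

import torch.nn as nn

from mmengine.model import convert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# 'torch' replaces every BatchNorm layer with torch.nn.SyncBatchNorm;
# 'mmcv' would use mmcv.ops.SyncBatchNorm instead.
model = convert_sync_batchnorm(model, 'torch')
print(type(model[1]).__name__)  # SyncBatchNorm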
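Finally, a sketch of the wrapping step in ``_wrap_model``: the user-provided ``model_wrapper`` config is merged with ``default_args``, and keys already present in the config take precedence. The single-process gloo group and the tiny ``nn.Linear`` module are assumptions made so the snippet can run on CPU without a launcher, which is also why ``device_ids`` is omitted here:

import os

import torch.distributed as dist
import torch.nn as nn

from mmengine.registry import MODEL_WRAPPERS

# Stand-in for the process group that init_dist() would normally create.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29510')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

wrapper_cfg = dict(type='MMDistributedDataParallel', broadcast_buffers=False)
default_args = dict(type='MMDistributedDataParallel', module=nn.Linear(4, 2))
# Config keys win over default_args, mirroring the build call in _wrap_model.
wrapped = MODEL_WRAPPERS.build(wrapper_cfg, default_args=default_args)
print(type(wrapped).__name__)  # MMDistributedDataParallel

dist.destroy_process_group()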
