Shortcuts

Source code for mmengine._strategy.distributed

# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, Optional

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from mmengine.device import get_device
from mmengine.dist import init_dist, is_distributed, master_only
from mmengine.model import convert_sync_batchnorm, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS, STRATEGIES
from .single_device import SingleDeviceStrategy


@STRATEGIES.register_module()
class DDPStrategy(SingleDeviceStrategy):
    """Distribution strategy for distributed data parallel training.

    Args:
        model_wrapper (dict, optional): Config dict for the model wrapper.
            When None, ``_wrap_model`` fills in a default
            ``MMDistributedDataParallel`` config with
            ``broadcast_buffers=False``. Defaults to None.
        sync_bn (str, optional): Type of sync batch norm. Options are
            'torch' and 'mmcv'. When None, BatchNorm layers are left
            unconverted. Defaults to None.
        **kwargs: Other arguments for :class:`BaseStrategy`.
    """

    def __init__(
        self,
        *,
        model_wrapper: Optional[dict] = None,
        sync_bn: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn

    def _setup_distributed(  # type: ignore
        self,
        launcher: str = 'pytorch',
        backend: str = 'nccl',
        **kwargs,
    ):
        """Set up the distributed environment.

        Does nothing when ``is_distributed()`` already reports an
        initialized distributed environment, so calling it twice is safe.

        Args:
            launcher (str): Way to launch multiple processes. Supported
                launchers are 'pytorch', 'mpi' and 'slurm'.
                Defaults to 'pytorch'.
            backend (str): Communication backend. Supported backends are
                'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
            **kwargs: Other arguments for :func:`init_dist`.
        """
        if not is_distributed():
            init_dist(launcher, backend, **kwargs)

    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
        (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

        The model is returned unchanged when ``self.sync_bn`` is None.

        Args:
            model (nn.Module): Model to be converted.

        Returns:
            nn.Module: Converted model.

        Raises:
            ValueError: If ``convert_sync_batchnorm`` rejects
                ``self.sync_bn``. The error is logged first, then re-raised.
        """
        if self.sync_bn is not None:
            try:
                model = convert_sync_batchnorm(model, self.sync_bn)
            except ValueError:
                self.logger.error('cfg.sync_bn should be "torch" or '
                                  f'"mmcv", but got {self.sync_bn}')
                # Bare ``raise`` keeps the original traceback intact.
                raise
        return model

    def _wrap_model(self, model: nn.Module) -> DistributedDataParallel:
        """Wrap the model to :obj:``MMDistributedDataParallel`` or other
        custom distributed data-parallel module wrappers.

        A model that is already a model wrapper is returned as-is. Otherwise
        it is moved to the current device, optionally converted to SyncBN,
        and built with ``MODEL_WRAPPERS`` from ``self.model_wrapper``.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            nn.Module or DistributedDataParallel: The unchanged model when it
            is already wrapped, otherwise the module built from
            ``self.model_wrapper`` (a subclass of ``DistributedDataParallel``
            for the default config).
        """
        if is_model_wrapper(model):
            return model

        device = get_device()
        model = model.to(device)
        model = self.convert_model(model)

        if self.model_wrapper is None:
            # set broadcast_buffers as False to keep compatibility with
            # OpenMMLab repos
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)

        # PyTorch DDP requires ``device_ids`` to be None for CPU modules;
        # passing ``[LOCAL_RANK]`` there raises a ValueError. Assumes
        # ``get_device()`` reports CPU-only runs as the string 'cpu'.
        if device == 'cpu':
            device_ids = None
        else:
            device_ids = [int(os.environ['LOCAL_RANK'])]

        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=device_ids)
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model

    @master_only
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        """Save a checkpoint from the master process.

        Decorated with ``master_only`` so that only the master process writes
        the file. All arguments are forwarded unchanged to the parent
        strategy's ``save_checkpoint``.

        Args:
            filename (str): Path of the checkpoint file.
            save_optimizer (bool): Whether to save optimizer state.
                Defaults to True.
            save_param_scheduler (bool): Whether to save parameter scheduler
                state. Defaults to True.
            extra_ckpt (dict, optional): Extra items to store in the
                checkpoint. Defaults to None.
            callback (Callable, optional): Hook forwarded to the parent
                implementation. Defaults to None.
        """
        super().save_checkpoint(
            filename=filename,
            save_optimizer=save_optimizer,
            save_param_scheduler=save_param_scheduler,
            extra_ckpt=extra_ckpt,
            callback=callback)

© Copyright 2022, mmengine contributors. Revision 317d8f31.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.0
Versions
latest
stable
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.