
Source code for mmengine.model.wrappers.distributed

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Union

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS
from ..utils import detect_anomalous_params

MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
MODEL_WRAPPERS.register_module(module=DataParallel)
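
# The registrations above expose PyTorch's stock ``DistributedDataParallel``
# and ``DataParallel`` through the ``MODEL_WRAPPERS`` registry, so a wrapper
# can be selected by name from a config instead of being constructed by hand
# (``MMDistributedDataParallel`` below registers itself via its decorator).
# A minimal sketch, assuming a Runner-style ``model_wrapper_cfg`` entry; the
# remaining keys are ordinary wrapper keyword arguments:
#
#     model_wrapper_cfg = dict(type='MMDistributedDataParallel',
#                              broadcast_buffers=False)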


@MODEL_WRAPPERS.register_module()
class MMDistributedDataParallel(DistributedDataParallel):
    """A distributed model wrapper used for training, testing and validation
    in loops.

    Different from ``DistributedDataParallel``, ``MMDistributedDataParallel``
    implements three methods :meth:`train_step`, :meth:`val_step` and
    :meth:`test_step`, which will be called by ``train_loop``, ``val_loop``
    and ``test_loop``.

    - ``train_step``: Called by ``runner.train_loop``. It implements the
      default model forward, gradient back propagation and parameter updating
      logic. To take advantage of DistributedDataParallel's automatic gradient
      synchronization, ``train_step`` calls ``DistributedDataParallel.forward``
      to calculate the losses, and calls other methods of :class:`BaseModel`
      to pre-process data and parse losses. Finally, it updates model
      parameters by :class:`OptimWrapper` and returns the loss dictionary used
      for logging.

    - ``val_step``: Called by ``runner.val_loop`` to get the inference
      results. Since there is no gradient synchronization requirement, this
      procedure is equivalent to ``BaseModel.val_step``.

    - ``test_step``: Called by ``runner.test_loop``, equivalent to
      ``val_step``.

    Args:
        detect_anomalous_params (bool): This option is only used for
            debugging and will slow down the training speed. It detects
            anomalous parameters that are not included in the computational
            graph with ``loss`` as the root. There are two cases:

            - Parameters were not used during the forward pass.
            - Parameters were not used to produce the loss.

            Defaults to False.
        **kwargs: Keyword arguments passed to ``DistributedDataParallel``.

            - device_ids (List[int] or torch.device, optional): CUDA devices
              for module.
            - output_device (int or torch.device, optional): Device location
              of output for single-device CUDA modules.
            - dim (int): Defaults to 0.
            - broadcast_buffers (bool): Flag that enables syncing
              (broadcasting) buffers of the module at the beginning of the
              ``forward`` function. Defaults to True.
            - find_unused_parameters (bool): Whether to find parameters of
              the module that are not in the forward graph. Defaults to
              False.
            - process_group (ProcessGroup, optional): The process group to be
              used for distributed data all-reduction.
            - bucket_cap_mb (int): Bucket size in MegaBytes (MB). Defaults
              to 25.
            - check_reduction (bool): This argument is deprecated. Defaults
              to False.
            - gradient_as_bucket_view (bool): Defaults to False.
            - static_graph (bool): Defaults to False.

    See more information about these arguments in
    :class:`torch.nn.parallel.DistributedDataParallel`.

    Note:
        If the model has multiple submodules and each submodule has a
        separate optimization strategy,
        :class:`MMSeparateDistributedDataParallel` should be used to wrap
        the model.

    Note:
        If the model itself has a custom optimization strategy, rather than
        simply forwarding the model and updating its parameters, a custom
        model wrapper inheriting from ``MMDistributedDataParallel`` should be
        defined and should override the ``train_step`` method.
    """

    def __init__(self,
                 module,
                 detect_anomalous_params: bool = False,
                 **kwargs):
        super().__init__(module=module, **kwargs)
        self.detect_anomalous_params = detect_anomalous_params

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Interface for model forward, backward and parameter updating
        during the training process.

        :meth:`train_step` will perform the following steps in order:

        - If :attr:`module` defines the preprocess method, call
          ``module.preprocess`` to pre-process the data.
        - Call ``module.forward(**data)`` and get losses.
        - Parse losses.
        - Call ``optim_wrapper.optimizer_step`` to update parameters.
        - Return log messages of losses.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            optim_wrapper (OptimWrapper): A wrapper of optimizer to update
                parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
        """
        # Enable automatic mixed precision training context.
        with optim_wrapper.optim_context(self):
            data = self.module.data_preprocessor(data, training=True)
            losses = self._run_forward(data, mode='loss')
        parsed_loss, log_vars = self.module.parse_losses(losses)
        optim_wrapper.update_params(parsed_loss)
        if self.detect_anomalous_params:
            detect_anomalous_params(parsed_loss, model=self)
        return log_vars

    def val_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the predictions of the module during the validation process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        return self.module.val_step(data)

    def test_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the predictions of the module during the testing process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        return self.module.test_step(data)

    def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
        """Unpacks data for :meth:`forward`.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of training or testing mode.
        """
        if isinstance(data, dict):
            results = self(**data, mode=mode)
        elif isinstance(data, (list, tuple)):
            results = self(*data, mode=mode)
        else:
            raise TypeError('Output of `data_preprocessor` should be '
                            f'list, tuple or dict, but got {type(data)}')
        return results
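
The listing above only defines the wrapper; the sketch below shows, under stated assumptions, how the ``train_step``, ``val_step`` and ``test_step`` interfaces described in the docstrings are typically driven. It is a minimal sketch, not part of the module: it assumes a distributed process group has already been initialized (for example via ``torch.distributed.init_process_group`` in a launcher script), and ``ToyModel`` is a hypothetical :class:`BaseModel` subclass written only for this illustration. In a normal setup the runner's loops issue these calls for you.

import torch
from torch.optim import SGD

from mmengine.model import BaseModel
from mmengine.model.wrappers import MMDistributedDataParallel
from mmengine.optim import OptimWrapper


class ToyModel(BaseModel):
    """Hypothetical model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        outputs = self.linear(inputs)
        if mode == 'loss':
            # ``train_step`` runs the forward pass with ``mode='loss'`` and
            # expects a dict of loss tensors that ``parse_losses`` can handle.
            return dict(loss=outputs.sum())
        # ``val_step``/``test_step`` reach this branch through
        # ``BaseModel.val_step``/``test_step``.
        return outputs


# Assumes ``torch.distributed`` has already been initialized by the launcher.
model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(
    module=model, device_ids=[torch.cuda.current_device()])
optim_wrapper = OptimWrapper(optimizer=SGD(model.parameters(), lr=0.01))

data = dict(inputs=torch.rand(4, 2))
log_vars = ddp_model.train_step(data, optim_wrapper=optim_wrapper)  # forward + backward + step
predictions = ddp_model.val_step(data)  # inference only, no gradient synchronization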