Source code for mmengine.model.wrappers.fully_sharded_distributed

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch, CPUOffload, FullyShardedDataParallel)

from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS, Registry
from mmengine.structures import BaseDataElement

# Support customized FSDP wrap policies.
FSDP_WRAP_POLICIES = Registry('fsdp wrap policy')
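

# Illustrative sketch (not part of the original mmengine module): a custom
# wrap policy can be registered in ``FSDP_WRAP_POLICIES`` so that a config may
# refer to it by name, e.g. ``fsdp_auto_wrap_policy='toy_min_params_policy'``.
# The policy name and the 1e8 threshold below are assumptions made for this
# example; user code would import the registry with
# ``from mmengine.model.wrappers.fully_sharded_distributed import FSDP_WRAP_POLICIES``.
@FSDP_WRAP_POLICIES.register_module()
def toy_min_params_policy(module: nn.Module,
                          recurse: bool,
                          unwrapped_params: int,
                          min_num_params: int = int(1e8)) -> bool:
    # Always recurse into children; wrap a submodule once the number of
    # parameters that have not been sharded yet reaches the threshold.
    return recurse or unwrapped_params >= min_num_params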


@MODEL_WRAPPERS.register_module()
class MMFullyShardedDataParallel(FullyShardedDataParallel):
    """A wrapper for sharding Module parameters across data parallel workers.

    Different from FullyShardedDataParallel, MMFullyShardedDataParallel
    implements three methods :meth:`train_step`, :meth:`val_step` and
    :meth:`test_step`, which will be called by ``train_loop``, ``val_loop``
    and ``test_loop``.

    - ``train_step``: Called by ``runner.train_loop``, and implements the
      default model forward, gradient back propagation and parameter updating
      logic.

    - ``val_step``: Called by ``runner.val_loop`` to get the inference
      results. Since MMFullyShardedDataParallel wraps the model recursively,
      simply reusing ``BaseModel.val_step`` here may cause problems. To avoid
      that, ``val_step`` first calls methods of :obj:`BaseModel` to
      pre-process the data, and then uses ``FullyShardedDataParallel.forward``
      to get the result.

    - ``test_step``: Called by ``runner.test_loop`` to get the inference
      results. Its logic is equivalent to ``val_step``.

    Args:
        module (nn.Module): module to be wrapped with FSDP.
        process_group (Optional[ProcessGroup]): process group for sharding.
        cpu_offload (Optional[Union[bool, CPUOffload]]): CPU offloading
            config. Different from FullyShardedDataParallel, since it can be
            set by users' pre-defined config in MMEngine, its type is
            expected to be `None`, `bool` or `CPUOffload`.

            Currently, only parameter and gradient CPU offload is supported.
            It can be enabled via passing in
            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
            currently implicitly enables gradient offloading to CPU in order
            for params and grads to be on the same device to work with the
            optimizer. This API is subject to change. Default is ``None``,
            in which case there will be no offloading.
        fsdp_auto_wrap_policy (Optional[Union[str, Callable]]): A policy to
            recursively wrap layers with FSDP. Different from
            FullyShardedDataParallel, since it can be set by users'
            pre-defined config in MMEngine, its type is expected to be
            `None`, `str` or `Callable`. If it is a `str`,
            MMFullyShardedDataParallel will try to get the specified method
            from the ``FSDP_WRAP_POLICIES`` registry, and this method will be
            passed to FullyShardedDataParallel to finally initialize the
            model.

            Note that this policy currently only applies to child modules of
            the passed-in module. The remaining modules are always wrapped in
            the returned FSDP root instance. ``default_auto_wrap_policy``
            written in ``torch.distributed.fsdp.wrap`` is an example of an
            ``fsdp_auto_wrap_policy`` callable; this policy wraps layers with
            parameter sizes larger than 100M. Users can supply a customized
            ``fsdp_auto_wrap_policy`` callable that should accept the
            following arguments: ``module: nn.Module``, ``recurse: bool``,
            ``unwrapped_params: int``. Extra customized arguments can be
            added to the customized ``fsdp_auto_wrap_policy`` callable as
            well.

            Example::

                >>> def custom_auto_wrap_policy(
                >>>     module: nn.Module,
                >>>     recurse: bool,
                >>>     unwrapped_params: int,
                >>>     # These are customizable for this policy function.
                >>>     min_num_params: int = int(1e8),
                >>> ) -> bool:
                >>>     return unwrapped_params >= min_num_params

        backward_prefetch (Optional[Union[str, BackwardPrefetch]]): Different
            from FullyShardedDataParallel, since it can be set by users'
            pre-defined config in MMEngine, its type is expected to be
            `None`, `str` or `BackwardPrefetch`.

            This is an experimental feature that is subject to change in the
            near future. It allows users to enable two different
            backward_prefetch algorithms to help overlap backward
            communication and computation. The pros and cons of each
            algorithm are explained in the class ``BackwardPrefetch``.
        **kwargs: Keyword arguments passed to
            :class:`FullyShardedDataParallel`.
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        cpu_offload: Optional[Union[bool, CPUOffload]] = None,
        fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None,
        backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None,
        **kwargs,
    ):
        if cpu_offload is not None:
            if isinstance(cpu_offload, bool):
                cpu_offload = CPUOffload(offload_params=cpu_offload)
            elif not isinstance(cpu_offload, CPUOffload):
                raise TypeError(
                    '`cpu_offload` should be `None`, `bool` '
                    f'or `CPUOffload`, but has type {type(cpu_offload)}')

        if fsdp_auto_wrap_policy is not None:
            if isinstance(fsdp_auto_wrap_policy, str):
                assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \
                    '`FSDP_WRAP_POLICIES` has no ' \
                    f'function {fsdp_auto_wrap_policy}'
                fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get(  # type: ignore
                    fsdp_auto_wrap_policy)
                if not isinstance(fsdp_auto_wrap_policy,
                                  Callable):  # type: ignore
                    raise TypeError(
                        'Registered `fsdp_auto_wrap_policy` needs to be '
                        '`Callable`, but has type '
                        f'{type(fsdp_auto_wrap_policy)}')
            elif not isinstance(fsdp_auto_wrap_policy,
                                Callable):  # type: ignore
                raise TypeError(
                    '`fsdp_auto_wrap_policy` should be `None`, `str` '
                    'or `Callable`, but has type '
                    f'{type(fsdp_auto_wrap_policy)}')

        if backward_prefetch is not None:
            if isinstance(backward_prefetch, str):
                assert backward_prefetch in ['pre', 'post'], \
                    '`backward_prefetch` should be either `pre` or `post`,' \
                    f' but got {backward_prefetch}'
                if backward_prefetch == 'pre':
                    backward_prefetch = BackwardPrefetch.BACKWARD_PRE
                else:
                    backward_prefetch = BackwardPrefetch.BACKWARD_POST
            elif not isinstance(backward_prefetch, BackwardPrefetch):
                raise TypeError(
                    '`backward_prefetch` should be `None`, `str` '
                    'or `BackwardPrefetch`, but has type '
                    f'{type(backward_prefetch)}')

        super().__init__(
            module=module,
            process_group=process_group,
            auto_wrap_policy=fsdp_auto_wrap_policy,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            **kwargs)

    def train_step(self, data: dict,
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Interface for model forward, backward and parameter updating
        during the training process.

        :meth:`train_step` will perform the following steps in order:

        - If :attr:`module` defines the preprocess method, call
          ``module.preprocess`` to pre-process the data.
        - Call ``module.forward(**data)`` and get losses.
        - Parse losses.
        - Call ``optim_wrapper.optimizer_step`` to update parameters.
        - Return log messages of losses.

        Args:
            data (dict): Data sampled by dataloader.
            optim_wrapper (OptimWrapper): A wrapper of optimizer to update
                parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
        """
        # Enable the automatic mixed precision training context.
        with optim_wrapper.optim_context(self):
            data = self.module.data_preprocessor(data, training=True)
            if isinstance(data, dict):
                losses = self(**data, mode='loss')
            elif isinstance(data, (list, tuple)):
                losses = self(*data, mode='loss')
            else:
                raise TypeError('Output of `data_preprocessor` should be '
                                f'list, tuple or dict, but got {type(data)}')
        parsed_loss, log_vars = self.module.parse_losses(losses)
        optim_wrapper.update_params(parsed_loss)
        return log_vars

    def val_step(self, data: dict) -> List[BaseDataElement]:
        """Gets the predictions of module during validation process.

        Args:
            data (dict): Data sampled by dataloader.

        Returns:
            List[BaseDataElement] or dict: The predictions of given data.
        """
        inputs, data_sample = self.module.data_preprocessor(data, False)
        return self(inputs, data_sample, mode='predict')

    def test_step(self, data: dict) -> List[BaseDataElement]:
        """Gets the predictions of module during testing process.

        Args:
            data (dict): Data sampled by dataloader.

        Returns:
            List[BaseDataElement]: The predictions of given data.
        """
        inputs, data_sample = self.module.data_preprocessor(data, False)
        return self(inputs, data_sample, mode='predict')
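

# A minimal usage sketch, not taken from mmengine itself: it shows how the
# wrapper's constructor options and ``train_step`` are typically driven by a
# training loop. Assumptions made for the example: the script is launched
# with ``torchrun --nproc_per_node=<num_gpus>`` on CUDA devices, and
# ``ToyModel`` is a hypothetical ``BaseModel`` subclass following the
# ``mode`` convention.
import os  # extra imports used only by this sketch

import torch.distributed as dist

from mmengine.model import BaseModel


class ToyModel(BaseModel):
    """Hypothetical model used only to illustrate the wrapper."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        out = self.linear(inputs)
        if mode == 'loss':
            return {'loss': out.sum()}
        return out


if __name__ == '__main__':
    # FSDP needs an initialized process group; torchrun provides the
    # rendezvous environment variables and one process per GPU.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

    model = MMFullyShardedDataParallel(
        ToyModel().cuda(),
        backward_prefetch='pre',  # converted to BackwardPrefetch.BACKWARD_PRE
        # cpu_offload=True,       # would become CPUOffload(offload_params=True)
    )
    optim_wrapper = OptimWrapper(torch.optim.SGD(model.parameters(), lr=0.01))

    # One training iteration: forward, loss parsing, backward and optimizer
    # step all happen inside ``train_step``.
    data = {'inputs': torch.randn(4, 8).cuda()}
    log_vars = model.train_step(data, optim_wrapper)
    print(log_vars)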
