class mmengine.model.MMFullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, param_init_fn=None, use_orig_params=True, **kwargs)

A wrapper for sharding Module parameters across data parallel workers.

Different from FullyShardedDataParallel, MMFullyShardedDataParallel implements three methods, train_step(), val_step() and test_step(), which are called by train_loop, val_loop and test_loop respectively.

  • train_step: Called by runner.train_loop; implements the default model forward, gradient back-propagation and parameter-updating logic.

  • val_step: Called by runner.val_loop to get the inference results. Since MMFullyShardedDataParallel wraps the model recursively, simply reusing BaseModel.val_step here may cause problems. To avoid that, val_step first calls methods of BaseModel to pre-process the data, and then uses FullyShardedDataParallel.forward to get the results.

  • test_step: Called by runner.test_loop to get the inference results. Its logic is equivalent to val_step.
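The contract between the runner loops and these three methods can be illustrated with a toy, framework-free sketch. ToyWrappedModel and run_toy_loops are hypothetical names, not MMEngine APIs; the sketch only mirrors which loop calls which method and what each call receives (train_step also gets the optimizer wrapper, while val_step and test_step only get the data batch).

```python
from typing import Any, Dict, List


class ToyWrappedModel:
    """Hypothetical stand-in exposing the three step methods the loops call."""

    def __init__(self) -> None:
        self.calls: List[str] = []  # names of the step methods invoked, in order

    def train_step(self, data: Dict[str, Any], optim_wrapper: Any) -> Dict[str, float]:
        # Forward, backward and parameter update would happen here.
        self.calls.append('train_step')
        return {'loss': 0.0}

    def val_step(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Pre-process, then forward in prediction mode.
        self.calls.append('val_step')
        return [data]

    def test_step(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Same logic as val_step.
        self.calls.append('test_step')
        return [data]


def run_toy_loops(model: ToyWrappedModel, batches: List[Dict[str, Any]]) -> None:
    """Hypothetical miniature of train_loop, val_loop and test_loop, run back to back."""
    optim_wrapper = object()  # placeholder for an OptimWrapper instance
    for batch in batches:  # train_loop: data plus the optimizer wrapper
        model.train_step(batch, optim_wrapper)
    for batch in batches:  # val_loop: data only
        model.val_step(batch)
    for batch in batches:  # test_loop: data only
        model.test_step(batch)


model = ToyWrappedModel()
run_toy_loops(model, [{'inputs': 1}, {'inputs': 2}])
print(model.calls)
```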

Parameters

  • module (nn.Module) – module to be wrapped with FSDP.

  • process_group (ProcessGroup, optional) – process group for sharding.

  • cpu_offload (bool, CPUOffload, optional) –

    CPU offloading config. Different from FullyShardedDataParallel, it can be set from a user's pre-defined config in MMEngine, so its type is expected to be None, bool or CPUOffload.

    Currently, only parameter and gradient CPU offload is supported. It can be enabled by passing in cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU so that parameters and gradients are on the same device, as the optimizer requires. This API is subject to change. Defaults to None, in which case there is no offloading.

  • auto_wrap_policy (str or Callable, optional) –

    Policy for recursively wrapping layers with FSDP. Different from FullyShardedDataParallel, it can be set from a user's pre-defined config in MMEngine, so its type is expected to be None, str or Callable. If it is a str, MMFullyShardedDataParallel looks up the method of that name in the FSDP_WRAP_POLICIES registry and passes it to FullyShardedDataParallel to initialize the model.

    Note that this policy currently only applies to child modules of the passed-in module; the remaining modules are always wrapped in the returned FSDP root instance. default_auto_wrap_policy in torch.distributed.fsdp.wrap is an example of an auto_wrap_policy callable; it wraps layers whose parameter count is larger than 100M. Users can supply a customized auto_wrap_policy callable that accepts the following arguments: module: nn.Module, recurse: bool, unwrapped_params: int. Extra customized arguments can be added to it as well, for example:


    >>> import torch.nn as nn
    >>> def custom_auto_wrap_policy(
    ...     module: nn.Module,
    ...     recurse: bool,
    ...     unwrapped_params: int,
    ...     # These are customizable for this policy function.
    ...     min_num_params: int = int(1e8),
    ... ) -> bool:
    ...     return unwrapped_params >= min_num_params

  • backward_prefetch (str or BackwardPrefetch, optional) – Different from FullyShardedDataParallel, this argument can be a string or a BackwardPrefetch instance. If it is a string, it should be BACKWARD_PRE or BACKWARD_POST.

  • mixed_precision (dict or MixedPrecision, optional) –

    This configures native mixed precision for FSDP. If this is set to None, no mixed precision is applied. Different from native FSDP, this argument can also be a dict like this:


    >>> mixed_precision=dict(param_dtype='float16',
    ...                      buffer_dtype='float32',
    ...                      reduce_dtype='float32')

    Defaults to None.

  • use_orig_params (bool) – Different from native FullyShardedDataParallel, it defaults to True.

  • **kwargs – Keyword arguments passed to FullyShardedDataParallel.

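  • sharding_strategy (Union[str, torch.distributed.fsdp.api.ShardingStrategy]) – Sharding strategy used by FSDP, given as a ShardingStrategy or a str.

  • param_init_fn (Union[str, Callable[[torch.nn.modules.module.Module], None]]) – Callable that specifies how modules currently on the meta device are initialized onto an actual device, or a str.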
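Several arguments above accept config-friendly values (bool, str, dict) in addition to the native PyTorch types. The following is a minimal sketch of how such values could be normalized before being handed to FullyShardedDataParallel; it is not MMEngine's implementation. CPUOffload, BackwardPrefetch, ShardingStrategy and MixedPrecision are local stand-ins that only borrow the PyTorch names (ShardingStrategy lists a subset of members), WRAP_POLICIES is a hypothetical dict standing in for the FSDP_WRAP_POLICIES registry, and normalize_fsdp_args and size_policy are illustrative helpers. Dtype strings are kept as strings here; with PyTorch installed, one way to turn them into dtypes is getattr(torch, name).

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional


# Local stand-ins that only borrow the PyTorch FSDP type names.
@dataclass
class CPUOffload:
    offload_params: bool = False


class BackwardPrefetch(Enum):
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()


class ShardingStrategy(Enum):  # subset of the real members
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()


@dataclass
class MixedPrecision:
    # The real MixedPrecision expects torch.dtype values; strings are kept here.
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None
    buffer_dtype: Optional[str] = None


def size_policy(module: Any, recurse: bool, unwrapped_params: int,
                min_num_params: int = int(1e8)) -> bool:
    """Illustrative policy with the signature described for auto_wrap_policy."""
    return unwrapped_params >= min_num_params


# Hypothetical registry standing in for FSDP_WRAP_POLICIES.
WRAP_POLICIES: Dict[str, Callable[..., bool]] = {'size_policy': size_policy}


def normalize_fsdp_args(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Convert config-style values into the objects FSDP expects.

    Values that are None or already of the native type pass through unchanged.
    """
    out = dict(cfg)  # do not mutate the caller's config
    if isinstance(out.get('cpu_offload'), bool):
        out['cpu_offload'] = CPUOffload(offload_params=out['cpu_offload'])
    if isinstance(out.get('backward_prefetch'), str):
        # Enum lookup by member name, e.g. 'BACKWARD_PRE'.
        out['backward_prefetch'] = BackwardPrefetch[out['backward_prefetch']]
    if isinstance(out.get('sharding_strategy'), str):
        out['sharding_strategy'] = ShardingStrategy[out['sharding_strategy']]
    if isinstance(out.get('mixed_precision'), dict):
        out['mixed_precision'] = MixedPrecision(**out['mixed_precision'])
    if isinstance(out.get('auto_wrap_policy'), str):
        out['auto_wrap_policy'] = WRAP_POLICIES[out['auto_wrap_policy']]
    return out


args = normalize_fsdp_args(dict(
    cpu_offload=True,
    backward_prefetch='BACKWARD_PRE',
    sharding_strategy='FULL_SHARD',
    mixed_precision=dict(param_dtype='float16',
                         buffer_dtype='float32',
                         reduce_dtype='float32'),
    auto_wrap_policy='size_policy',
))
print(args['cpu_offload'], args['backward_prefetch'])
```

Dispatching on isinstance leaves native objects untouched, so a config can mix plain strings and dicts with already-built PyTorch objects.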


test_step(data)

Gets the predictions of the module during the testing process.

Parameters

data (dict) – Data sampled by dataloader.

Returns

The predictions of given data.

Return type

List[BaseDataElement] or dict

train_step(data, optim_wrapper)

Interface for model forward, backward and parameter updating during the training process.

train_step() will perform the following steps in order:

  • If the module defines a preprocess method, call module.preprocess to pre-process the data.

  • Call module.forward(**data) and get losses.

  • Parse losses.

  • Call optim_wrapper.optimizer_step to update parameters.

  • Return log messages of losses.

Parameters

  • data (dict) – Data sampled by dataloader.

  • optim_wrapper (OptimWrapper) – A wrapper of optimizer to update parameters.


Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
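The five train_step steps can be traced with a framework-free sketch. ToyModule and ToyOptimWrapper are hypothetical stand-ins that use the method names from the step list (preprocess, forward, optimizer_step); parse_losses is a hypothetical helper name for the "Parse losses" step, and summing every entry whose key contains 'loss' is an assumed convention. The backward pass and parameter update are reduced to recording the loss.

```python
from typing import Any, Dict, List, Tuple


class ToyModule:
    """Hypothetical model exposing the hooks named in the train_step steps."""

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Step 1: pre-process the batch (here: scale every input by 2).
        return {'x': [v * 2 for v in data['x']]}

    def forward(self, x: List[float]) -> Dict[str, float]:
        # Step 2: return a dict of named losses.
        return {'loss_a': sum(x) / len(x), 'loss_b': 1.0}

    def parse_losses(self, losses: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
        # Step 3 (hypothetical helper): sum entries whose key contains 'loss'.
        total = sum(v for k, v in losses.items() if 'loss' in k)
        return total, dict(losses, loss=total)


class ToyOptimWrapper:
    """Hypothetical optimizer wrapper that only records what it was asked to do."""

    def __init__(self) -> None:
        self.stepped_on: List[float] = []

    def optimizer_step(self, loss: float) -> None:
        # Step 4: a real wrapper would back-propagate, step and zero the grads.
        self.stepped_on.append(loss)


def train_step(module: Any, data: Dict[str, Any],
               optim_wrapper: ToyOptimWrapper) -> Dict[str, float]:
    if hasattr(module, 'preprocess'):  # step 1 is optional
        data = module.preprocess(data)
    losses = module.forward(**data)               # step 2
    loss, log_vars = module.parse_losses(losses)  # step 3
    optim_wrapper.optimizer_step(loss)            # step 4
    return log_vars                               # step 5


optim_wrapper = ToyOptimWrapper()
log_vars = train_step(ToyModule(), {'x': [1, 2, 3]}, optim_wrapper)
print(log_vars)  # {'loss_a': 4.0, 'loss_b': 1.0, 'loss': 5.0}
```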


val_step(data)

Gets the predictions of the module during the validation process.

Parameters

data (dict) – Data sampled by dataloader.

Returns

The predictions of given data.

Return type

List[BaseDataElement] or dict
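The val_step and test_step flow from the class description (pre-process with the inner model's methods, then run the forward pass through the wrapper instead of delegating to BaseModel.val_step) can be sketched as follows. ToyModule, ToyFSDPWrapper, the preprocess signature and the mode='predict' argument are assumptions made for illustration; the parameter gathering a real FSDP forward performs is only indicated by a comment.

```python
from typing import Any, Dict, List


class ToyModule:
    """Hypothetical inner model following a mode='predict' convention."""

    def preprocess(self, data: Dict[str, Any], training: bool = False) -> Dict[str, Any]:
        # Stand-in for the BaseModel pre-processing step: rescale raw inputs.
        return {'inputs': [v / 10 for v in data['inputs']]}

    def forward(self, inputs: List[float], mode: str = 'tensor') -> Any:
        if mode == 'predict':
            return [{'pred': v > 0.5} for v in inputs]
        return inputs  # 'tensor' mode: raw outputs


class ToyFSDPWrapper:
    """Hypothetical wrapper mimicking how val_step routes through the wrapper."""

    def __init__(self, module: ToyModule) -> None:
        self.module = module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # A real FSDP forward would gather the sharded parameters here
        # before delegating to the wrapped module.
        return self.module.forward(*args, **kwargs)

    def val_step(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        data = self.module.preprocess(data, training=False)  # inner model's method
        return self.forward(**data, mode='predict')          # wrapper's forward

    def test_step(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        return self.val_step(data)  # same logic as val_step


wrapper = ToyFSDPWrapper(ToyModule())
print(wrapper.val_step({'inputs': [3, 8]}))  # [{'pred': False}, {'pred': True}]
```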

Read the Docs v: v0.8.1