

class mmengine._strategy.FSDPStrategy(*, model_wrapper=None, skip_init_weights=False, state_dict_cfg='local', activation_checkpointing=None, **kwargs)

Support training models with FullyShardedDataParallel (FSDP).

Keyword Arguments:
  • model_wrapper (dict, optional) –

    Config dict for model wrapper. The default configuration is:


    >>> model_wrapper = dict(
    >>>    type='MMFullyShardedDataParallel',
    >>>    use_orig_params=True,
    >>> )

    See MMFullyShardedDataParallel for more configurable arguments. Defaults to None.

  • skip_init_weights (bool, optional) – Whether to skip initialization of weights. Defaults to False. This is useful when the parameters of the large model are loaded from a checkpoint, since skipping the initialization of weights can save a lot of time.

  • state_dict_cfg (str or dict) –

    Configuration for how to save and load the state dict of the model, optimizer, and scheduler.

    • "local": save and load the sharded state dict in all ranks.

    • "full": save and load the full state dict in rank 0.

    • dict object: save and load the state dict more flexibly. For example, you can offload the state dict to the CPU before saving it to disk, which lets you load the checkpoint in an environment without GPUs:

      >>> state_dict_cfg=dict(
      >>>     state_dict_type='FULL_STATE_DICT',
      >>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
      >>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
      >>> )

      See the official FSDP API documentation for more configurable arguments of state_dict_cfg, state_dict_config, and optim_state_dict_config.

  • kwargs (dict) –

    Additional arguments passed to DDPStrategy:

    • work_dir (str): The working directory to save checkpoints. Logs are saved in a subdirectory of work_dir named after the timestamp. Defaults to ‘work_dirs’.

    • experiment_name (str, optional): Name of the current experiment. If not specified, the timestamp will be used as experiment_name. Defaults to None.

    • env_kwargs (dict, optional): Environment config passed in setup_env(). Defaults to None.

    • log_kwargs (dict, optional): Logger config passed in build_logger(). Defaults to None.

  • activation_checkpointing (dict, optional) –

    Config dict for activation checkpointing (also known as gradient checkpointing).


    >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
    >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))

    The check_fn field should behave consistently with the auto_wrap_policy defined in model_wrapper; all other fields are passed to apply_activation_checkpointing.

    New in version 0.9.0.


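The model_wrapper and state_dict_cfg arguments above are plain config dicts, so their documented defaults can be illustrated without mmengine or PyTorch installed. The sketch below is a minimal, hypothetical model of resolving them. resolve_model_wrapper and expand_state_dict_cfg are illustrative names, not mmengine APIs. The following are our own assumptions:

* User-supplied keys override the defaults.
* "full" corresponds to rank0_only=True and offload_to_cpu=True.

The config type names (LocalStateDictConfig, FullStateDictConfig and their optimizer counterparts) are PyTorch FSDP classes, referenced here only as strings.

```python
import copy

# Documented default configuration of model_wrapper.
DEFAULT_MODEL_WRAPPER = dict(type='MMFullyShardedDataParallel', use_orig_params=True)


def resolve_model_wrapper(model_wrapper=None):
    """Hypothetical helper: overlay a user model_wrapper on the documented default.

    Assumption: user-supplied keys take precedence over the defaults.
    """
    merged = copy.deepcopy(DEFAULT_MODEL_WRAPPER)
    if model_wrapper is not None:
        # Deep-copy so the caller's dict is never mutated.
        merged.update(copy.deepcopy(model_wrapper))
    return merged


def expand_state_dict_cfg(state_dict_cfg='local'):
    """Hypothetical helper: expand the 'local'/'full' shorthands into the dict form."""
    if isinstance(state_dict_cfg, dict):
        # Already in dict form: pass it through unchanged.
        return state_dict_cfg
    if state_dict_cfg == 'local':
        # Every rank keeps only its own shard.
        return dict(
            state_dict_type='LOCAL_STATE_DICT',
            state_dict_config=dict(type='LocalStateDictConfig'),
            optim_state_dict_config=dict(type='LocalOptimStateDictConfig'),
        )
    if state_dict_cfg == 'full':
        # Assumption: "full state dict in rank 0" maps to rank0_only + CPU offload.
        return dict(
            state_dict_type='FULL_STATE_DICT',
            state_dict_config=dict(
                type='FullStateDictConfig', offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config=dict(
                type='FullOptimStateDictConfig', offload_to_cpu=True, rank0_only=True),
        )
    raise ValueError(
        f"state_dict_cfg must be 'local', 'full' or a dict, got {state_dict_cfg!r}")


# Usage: sync_module_states is a regular PyTorch FSDP argument.
wrapper_cfg = resolve_model_wrapper(dict(sync_module_states=True))
full_cfg = expand_state_dict_cfg('full')
```

Merging over a copy of the default keeps `use_orig_params=True` unless the user explicitly overrides it. That setting matters later, because it determines the key layout of the local model state dict.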

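For activation_checkpointing, one way to keep check_fn consistent with the auto-wrap policy is to derive both from the same tuple of layer classes. The sketch below uses plain Python stand-in classes instead of nn.Module subclasses. TransformerBlock, Embedding, WRAPPED_LAYER_CLASSES and make_check_fn are hypothetical names. In a real model, the same classes would typically also be given to PyTorch's transformer_auto_wrap_policy as transformer_layer_cls. Registering the function so it can be referenced by name, as in check_fn='CustomCheckFn', is not shown.

```python
class TransformerBlock:
    """Stand-in for a transformer layer (a real model would use an nn.Module subclass)."""


class Embedding:
    """Stand-in for a module that should not be checkpointed."""


# One tuple of layer classes, intended to be shared by the wrap policy and check_fn.
WRAPPED_LAYER_CLASSES = (TransformerBlock,)


def make_check_fn(layer_classes):
    """Return a check_fn that selects exactly the submodules whose class is in layer_classes."""

    def check_fn(submodule):
        # True means "apply activation checkpointing to this submodule".
        return isinstance(submodule, layer_classes)

    return check_fn


check_fn = make_check_fn(WRAPPED_LAYER_CLASSES)
```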
build_model(model)

Build model.

If skip_init_weights is True, the model will be built with empty weights, which means that load_checkpoint() must be called to fill the weights before training.

Parameters:
  • model (nn.Module or dict) – An nn.Module object or a dict used to build an nn.Module object. If model is already an nn.Module object, it is returned as is.

Returns:
  The model built from model.

Return type:
  nn.Module

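Because skip_init_weights=True leaves the weights unfilled after build_model(), it should be paired with a checkpoint to load. The following is a minimal sketch of such a configuration. It assumes the dict is passed to a FlexibleRunner that loads load_from through the strategy's load_checkpoint(). The checkpoint path is a placeholder.

```python
# Hypothetical runner-level config; 'path/to/checkpoint.pth' is a placeholder.
strategy = dict(
    type='FSDPStrategy',
    # Skip weight initialization to save time; the weights stay empty for now.
    skip_init_weights=True,
    state_dict_cfg='full',
)

runner_kwargs = dict(
    strategy=strategy,
    # Assumption: the runner loads this checkpoint via load_checkpoint(),
    # which fills the empty weights before training starts.
    load_from='path/to/checkpoint.pth',
)
```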

build_optim_wrapper(optim_wrapper, model=None)

Support sharding the optimizer state dict given a built optimizer or optim_wrapper.

See BaseStrategy.build_optim_wrapper() for specific usage.

Return type:
  BaseOptimWrapper


load_checkpoint(filename, **kwargs)

Load checkpoint from given filename.

Note: If state_dict_type is local, filename should be a directory containing the rank{i}.pth files.

Parameters:
  • filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.

Keyword Arguments:
  • map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to ‘cpu’.

  • strict (bool) – Whether the parameters of the model and the checkpoint must match exactly.

  • revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular expression operations. Defaults to [(r'^module.', '')], which strips the prefix 'module.'.

  • callback (callable, optional) – Callback function to modify the checkpoint after loading it. Defaults to None.

Return type:
  dict

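The revise_keys option applies each (pattern, replacement) pair as a regular-expression substitution on every key of the loaded state dict. The following is a standalone sketch of that renaming step. revise_state_dict_keys is a hypothetical helper, not the mmengine implementation, and its default mirrors the documented one.

```python
import re
from collections import OrderedDict


def revise_state_dict_keys(state_dict, revise_keys=((r'^module.', ''),)):
    """Apply each (pattern, replacement) pair, in order, to every key.

    Returns a new OrderedDict; the input dict is left untouched.
    """
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items())
    return state_dict


# Usage: strip the DDP/FSDP 'module.' prefix, then rename a submodule.
ckpt = OrderedDict([('module.backbone.weight', 1), ('module.head.bias', 2)])
renamed = revise_state_dict_keys(
    ckpt, [(r'^module\.', ''), (r'^backbone\.', 'encoder.')])
```

Pairs are applied sequentially, so a later pattern sees the keys already rewritten by earlier ones.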

load_model_state_dict(state_dict, *, strict=False, revise_keys=[('^module.', '')])

Load model state from dict.

Note: revise_keys is not supported yet.

Parameters:
  • state_dict (dict) – Model state dict returned by FSDPStrategy.model_state_dict(). If state_dict_type is full, state_dict can also be the result of model.state_dict().

  • strict (bool) – Whether to load the model state dict strictly. Defaults to False.

  • revise_keys (list) – Not supported yet.

Return type:
  None



load_optim_state_dict(state_dict)

Load optimizer state from dict.

Parameters:
  • state_dict (dict) – The optimizer state dict. If state_dict_type is full, state_dict can also be the result of optimizer.state_dict().

Return type:
  None



model_state_dict()

Get model state dict based on the state_dict_type.

If state_dict_type is full, the model state dict will be the same as that of the original unsharded model.

If state_dict_type is local and use_orig_params is True in model_wrapper, the keys of the state dict will be the same as those of the original unsharded model, but the values will be sharded.

If state_dict_type is local and use_orig_params is False in model_wrapper, a flattened and sharded state dict will be returned.

See the official FSDP API documentation for more details.

Return type:
  dict


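When use_orig_params is False, the local state dict no longer follows the original parameter names. The sketch below shows why, using the general idea of flattening parameters and splitting them into equal per-rank shards. It uses plain Python lists, and flatten_and_shard is a hypothetical name. It is a conceptual simplification, not PyTorch FSDP's actual implementation.

```python
import math


def flatten_and_shard(params, world_size):
    """Conceptual sketch: flatten named params and split them into equal per-rank shards.

    params: dict mapping parameter names to flat lists of numbers.
    Returns one list per rank; the tail is zero-padded so all shards have equal size.
    """
    # Concatenate all parameters in insertion order; names are lost at this point.
    flat = [x for values in params.values() for x in values]
    shard_size = math.ceil(len(flat) / world_size)
    # Pad so the flat buffer divides evenly across ranks.
    flat += [0.0] * (shard_size * world_size - len(flat))
    return [flat[rank * shard_size:(rank + 1) * shard_size]
            for rank in range(world_size)]


# Usage: rank 1's shard mixes the tail of the weight, the bias and padding.
shards = flatten_and_shard(
    {'linear.weight': [1.0, 2.0, 3.0, 4.0], 'linear.bias': [5.0]}, world_size=2)
```

A single shard can straddle several original parameters. A per-parameter key layout therefore cannot be recovered from one rank's shard alone.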

optim_state_dict()

Get optimizer state dict based on the state_dict_type.

If state_dict_type is full, the optimizer state dict can be loaded by the original unsharded optimizer.

Otherwise, the optimizer state dict can only be loaded by an optimizer with sharded parameters.

Note: Even in full mode, the optimizer state dict is not identical to that of the original optimizer, although both can be loaded correctly.

See the official FSDP API documentation for more details.

Return type:
  dict


save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)

Save checkpoint to given filename.

If state_dict_type is full, the checkpoint will only be saved in rank 0. The structure of the saved checkpoint is the same as the one saved by DDPStrategy.

If state_dict_type is local, each rank saves its sharded state dict into a directory named filename, so the saved structure looks like this:

epoch_0.pth
    ├── rank0.pth
    ├── rank1.pth
    ├── ...
    └── rank8.pth

Parameters:
  • filename (str) – Filename to save checkpoint.


Keyword Arguments:
  • save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.

  • save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.

  • extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.

  • callback (callable, optional) – Callback function to modify the checkpoint before saving it. Defaults to None.

Return type:
  None

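In local mode the checkpoint "file" is a directory holding one file per rank, as in the layout above. The sketch below computes those per-rank paths and reports any that are missing, which may be a useful sanity check before calling load_checkpoint(). Both helpers are hypothetical and assume the rank{i}.pth naming documented above.

```python
import os


def local_checkpoint_paths(filename, world_size):
    """Return the expected per-rank file paths inside a local-mode checkpoint directory."""
    return [os.path.join(filename, f'rank{rank}.pth') for rank in range(world_size)]


def missing_rank_files(filename, world_size):
    """List the per-rank files that do not exist (empty list means the checkpoint is complete)."""
    return [path for path in local_checkpoint_paths(filename, world_size)
            if not os.path.isfile(path)]


# Usage: every rank must find its own shard before loading.
expected = local_checkpoint_paths(os.path.join('work_dirs', 'epoch_0.pth'), world_size=2)
```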