FSDPStrategy

class mmengine._strategy.FSDPStrategy(*, model_wrapper=None, skip_init_weights=False, state_dict_cfg='local', activation_checkpointing=None, **kwargs)[source]

Support training model with FullyShardedDataParallel (FSDP).

Keyword Arguments:
  • model_wrapper (dict, optional) –

    Config dict for model wrapper. The default configuration is:

    Example

    >>> model_wrapper = dict(
    ...     type='MMFullyShardedDataParallel',
    ...     use_orig_params=True,
    ... )
    

    See more configurable arguments in MMFullyShardedDataParallel. Defaults to None.

  • skip_init_weights (bool, optional) – Whether to skip initialization of weights. Defaults to False. This is useful when the parameters of the large model are loaded from a checkpoint, since skipping the initialization of weights can save a lot of time.

  • state_dict_cfg (str or dict) –

    Configuration for how to save and load the state dict of the model, optimizer, and scheduler.

    • 'local': save and load the sharded state dict on all ranks.

    • 'full': save and load the full state dict on rank 0.

    • dict object: save and load the state dict more flexibly. For example, you can first offload the state dict to CPU and then save it to disk, which helps when loading the checkpoint in a non-GPU environment:

      Examples:
      >>> state_dict_cfg = dict(
      ...     state_dict_type='FULL_STATE_DICT',
      ...     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
      ...     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
      ... )

      See more configurable arguments for state_dict_type, state_dict_config, and optim_state_dict_config in the FSDP official API documents.

  • kwargs (dict) –

    Additional arguments passed to DDPStrategy:

    • work_dir (str): The working directory to save checkpoints. The logs will be saved in a subdirectory of work_dir named after the timestamp. Defaults to 'work_dirs'.

    • experiment_name (str, optional): Name of current experiment. If not specified, timestamp will be used as experiment_name. Defaults to None.

    • env_kwargs (dict, optional): Environment config passed in setup_env(). Defaults to None.

    • log_kwargs (dict, optional): Logger config passed in build_logger(). Defaults to None.

  • activation_checkpointing (dict, optional) –

    Config dict for gradient checkpointing.

    Example

    >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
    >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
    

    The check_fn field should behave consistently with the auto_wrap_policy defined in model_wrapper, and the other fields will be passed to apply_activation_checkpointing. See the construction sketch after this parameter list for how these keyword arguments fit together.

    New in version 0.9.0.

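Putting the keyword arguments above together, a minimal construction sketch (the concrete values are illustrative, and 'CustomCheckFn' is a hypothetical registered function, not part of MMEngine):

>>> from mmengine._strategy import FSDPStrategy
>>> strategy = FSDPStrategy(
...     model_wrapper=dict(
...         type='MMFullyShardedDataParallel',
...         use_orig_params=True),
...     skip_init_weights=True,       # weights are filled by load_checkpoint() later
...     state_dict_cfg='full',        # gather the full state dict on rank 0
...     activation_checkpointing=dict(check_fn='CustomCheckFn'),  # hypothetical check_fn
...     work_dir='work_dirs')         # forwarded to DDPStrategy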

build_model(model)[source]

Build model.

If skip_init_weights is True, the model will be built with empty weights. This means that load_checkpoint() must be called to fill in the weights before training.

Parameters:

model (nn.Module or dict) – An nn.Module object or a dict used to build an nn.Module object. If model is already an nn.Module object, it is returned as-is.

Returns:

Model built from model.

Return type:

nn.Module
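A short sketch of the two accepted inputs, continuing the construction sketch above; 'ToyModel' is a hypothetical registered model, used only for illustration:

>>> import torch.nn as nn
>>> model = strategy.build_model(nn.Linear(8, 8))        # nn.Module: returned as-is
>>> model = strategy.build_model(dict(type='ToyModel'))  # dict: built via the registry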

build_optim_wrapper(optim_wrapper, model=None)[source]

Support sharding the optimizer state dict given a built optimizer or optim_wrapper.

See specific usage in BaseStrategy.build_optim_wrapper().

Parameters:
  • optim_wrapper (Optimizer, OptimWrapper, or dict) – A built optimizer, a built optim_wrapper, or a dict used to build one.

  • model (nn.Module, optional) – The built model. Defaults to None.

Return type:

BaseOptimWrapper
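A hedged sketch of building a sharded optim wrapper from a config dict, reusing the model built above; the AdamW settings are illustrative:

>>> optim_wrapper = strategy.build_optim_wrapper(
...     dict(type='OptimWrapper',
...          optimizer=dict(type='AdamW', lr=1e-4)),
...     model=model)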

load_checkpoint(filename, **kwargs)[source]

Load checkpoint from given filename.

Note

If state_dict_type is local, filename should be a directory containing rank{i}.pth files.

Parameters:

filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.

Keyword Arguments:
  • map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'cpu'.

  • strict (bool) – Whether to allow different params for the model and checkpoint.

  • revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular expression operations. Defaults to stripping the prefix 'module.' via [(r'^module.', '')].

  • callback (callable, optional) – Callback function to modify the checkpoint after loading it. Defaults to None.

Return type:

dict
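A usage sketch for both state-dict modes; the paths are illustrative:

>>> # 'full' mode: filename points at a single checkpoint file
>>> ckpt = strategy.load_checkpoint(
...     'work_dirs/epoch_1.pth',
...     map_location='cpu',
...     revise_keys=[(r'^module.', '')])
>>> # 'local' mode: filename points at the directory holding rank{i}.pth
>>> ckpt = strategy.load_checkpoint('work_dirs/epoch_0.pth')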

load_model_state_dict(state_dict, *, strict=False, revise_keys=[('^module.', '')])[source]

Load model state from dict.

Warning

revise_keys is not supported yet.

Parameters:
  • state_dict (dict) – Model state dict returned by FSDPStrategy.model_state_dict(). If state_dict_type is full, state_dict could also be the result of model.state_dict().

  • strict (bool) – Whether to load model state dict strictly. Defaults to False.

  • revise_keys (list) –

Return type:

None

load_optim_state_dict(state_dict)[source]

Load optimizer state from dict.

Parameters:

state_dict (dict) – The optimizer state dict. If state_dict_type is full, state_dict could also be the result of optimizer.state_dict().

Return type:

None
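A restore sketch, assuming a checkpoint saved in 'full' mode; the path and the key names 'state_dict'/'optimizer' follow the usual MMEngine checkpoint layout and are assumptions here:

>>> import torch
>>> ckpt = torch.load('full_ckpt.pth', map_location='cpu')  # illustrative path and keys
>>> strategy.load_model_state_dict(ckpt['state_dict'], strict=False)
>>> strategy.load_optim_state_dict(ckpt['optimizer'])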

model_state_dict()[source]

Get model state dict based on the state_dict_type.

If state_dict_type is full, the model state dict will be the same as that of the original unsharded model.

If state_dict_type is local and use_orig_params is True in model_wrapper, the keys of the state dict will be the same as those of the original unsharded model, but the values will be sharded.

If state_dict_type is local and use_orig_params is False in model_wrapper, the flattened and sharded state dict will be returned.

See more details in the official API documents.

Return type:

dict

optim_state_dict()[source]

Get optimizer state dict based on the state_dict_type.

If state_dict_type is full, the optimizer state dict can be loaded by the original unsharded optimizer.

Otherwise, the optimizer state dict could only be loaded by the optimizer with sharded parameters.

Note

Even in full mode, the optimizer state dict is not identical to that of the original optimizer, although it can be loaded correctly.

See more details in the official API documents.

Return type:

dict
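A sketch of dumping the gathered dicts by hand in 'full' mode; the rank gating and file name are illustrative:

>>> import torch
>>> import torch.distributed as dist
>>> model_sd = strategy.model_state_dict()   # gathered on rank 0 in 'full' mode
>>> optim_sd = strategy.optim_state_dict()
>>> if dist.get_rank() == 0:
...     torch.save(dict(state_dict=model_sd, optimizer=optim_sd), 'full_ckpt.pth')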

save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]

Save checkpoint to given filename.

If state_dict_type is full, the checkpoint will only be saved on rank 0. The structure of the saved checkpoint is the same as the one saved by DDPStrategy.

If state_dict_type is local, each rank will save the sharded state dict to a directory, which means the saved structure will look like this:

── epoch_0.pth
    ├── rank0.pth
    ├── rank1.pth
    ├── ...
    └── rank8.pth
Parameters:
  • filename (str) – Filename to save checkpoint.

Keyword Arguments:
  • save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.

  • save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.

  • extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.

  • callback (callable, optional) – Callback function to modify the checkpoint before saving it. Defaults to None.

Return type:

None
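A closing sketch combining these keyword arguments; the callback signature (receiving the checkpoint dict) and the metadata it adds are assumptions for illustration:

>>> def add_meta(checkpoint):
...     checkpoint['meta'] = dict(seed=42)   # stamp illustrative metadata before writing
>>> strategy.save_checkpoint(
...     'work_dirs/epoch_1.pth',
...     save_optimizer=True,
...     extra_ckpt=dict(best_acc=0.9),
...     callback=add_meta)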
