FSDPStrategy
- class mmengine._strategy.FSDPStrategy(*, model_wrapper=None, skip_init_weights=False, state_dict_cfg='local', activation_checkpointing=None, **kwargs)[source]
Support training models with FullyShardedDataParallel (FSDP).
- Keyword Arguments:
model_wrapper (dict, optional) –
Config dict for the model wrapper. The default configuration is:
Examples
>>> model_wrapper = dict(
>>>     type='MMFullyShardedDataParallel',
>>>     use_orig_params=True,
>>> )
See more configurable arguments in MMFullyShardedDataParallel. Defaults to None.
skip_init_weights (bool, optional) – Whether to skip initialization of weights. Defaults to False. This is useful when the parameters of a large model are loaded from a checkpoint, since skipping the initialization of weights can save a lot of time.
state_dict_cfg (str or dict) –
Configuration for how to save and load the state dict of the model, optimizer, and scheduler.
"local": save and load the sharded state dict on all ranks.
"full": save and load the full state dict on rank 0.
dict object: save and load the state dict more flexibly. For example, you can first offload the state dict to the CPU and then save it to disk. This can help you load the checkpoint in a non-GPU environment:
- Examples:
>>> state_dict_cfg = dict(
>>>     state_dict_type='FULL_STATE_DICT',
>>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
>>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
>>> )
See more configurable arguments for ``state_dict_type``, ``state_dict_config``, and ``optim_state_dict_config`` in the FSDP official API documents.
kwargs (dict) –
Additional arguments passed to DDPStrategy:
- work_dir (str): The working directory to save checkpoints. The logs will be saved in the subdirectory of work_dir named timestamp. Defaults to 'work_dirs'.
- experiment_name (str, optional): Name of the current experiment. If not specified, the timestamp will be used as experiment_name. Defaults to None.
- env_kwargs (dict, optional): Environment config passed in setup_env(). Defaults to None.
- log_kwargs (dict, optional): Logger config passed in build_logger(). Defaults to None.
activation_checkpointing (dict, optional) –
Config dict for gradient checkpointing.
Examples
>>> activation_checkpointing = dict(check_fn='CustomCheckFn')
>>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
The ``check_fn`` field should behave consistently with the ``auto_wrap_policy`` defined in ``model_wrapper``; the other fields will be passed to ``apply_activation_checkpointing`` (see the combined construction sketch after this argument list).
New in version 0.9.0.
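The keyword arguments above can be combined into a single strategy instance. The sketch below is illustrative only, not an official recipe; it uses only the arguments documented here and assumes the distributed launch and model preparation are handled elsewhere (e.g. by a runner that accepts a strategy, such as FlexibleRunner).
>>> from mmengine._strategy import FSDPStrategy
>>> # Wrap the model with MMFullyShardedDataParallel, keep original parameter
>>> # names, and save/load the full (unsharded) state dict on rank 0.
>>> strategy = FSDPStrategy(
>>>     model_wrapper=dict(
>>>         type='MMFullyShardedDataParallel',
>>>         use_orig_params=True,
>>>     ),
>>>     state_dict_cfg='full',
>>>     work_dir='work_dirs',
>>> )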
- build_model(model)[source]
Build model.
If skip_init_weights is True, the model will be built with empty weights. This means that load_checkpoint() must be called to fill the weights before training.
- Parameters:
model (nn.Module or dict) – An ``nn.Module`` object or a dict used to build an ``nn.Module`` object. If ``model`` is already an ``nn.Module`` object, it is returned as-is.
- Returns:
The model built from ``model``.
- Return type:
nn.Module
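As a hedged illustration of the two accepted inputs (not taken from the library's own examples; ToyModel is a hypothetical registered model name):
>>> import torch.nn as nn
>>> strategy = FSDPStrategy(skip_init_weights=True)
>>> # An nn.Module instance is returned as-is.
>>> model = strategy.build_model(nn.Linear(8, 8))
>>> # A config dict is built into an nn.Module; 'ToyModel' is hypothetical, and
>>> # with skip_init_weights=True its weights stay empty until load_checkpoint().
>>> model = strategy.build_model(dict(type='ToyModel'))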
- build_optim_wrapper(optim_wrapper, model=None)[source]
Support sharding the optimizer state dict given a built optimizer or optim_wrapper.
See specific usage in BaseStrategy.build_optim_wrapper().
- Parameters:
optim_wrapper (Optimizer | OptimWrapper | dict) –
model (Module | None) –
- Return type:
BaseOptimWrapper
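A minimal usage sketch, assuming the strategy has already built (and, for sharding, wrapped) the model; the OptimWrapper/SGD config is a common MMEngine pattern used here for illustration, not a requirement of this method:
>>> optim_wrapper = strategy.build_optim_wrapper(
>>>     dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, momentum=0.9)),
>>>     model=model,
>>> )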
- load_checkpoint(filename, **kwargs)[source]
Load checkpoint from the given ``filename``.
Note
If ``state_dict_type`` is ``local``, ``filename`` should be a directory containing ``rank{i}.pth``.
- Parameters:
filename (str) – Accepts a local filepath, URL, ``torchvision://xxx``, or ``open-mmlab://xxx``.
- Keyword Arguments:
map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'cpu'.
strict (bool) – Whether to allow different params for the model and checkpoint.
revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular expression operations. Defaults to stripping the prefix 'module.' via [(r'^module.', '')].
callback (callable, optional) – Callback function to modify the checkpoint after loading it. Defaults to None.
- Return type:
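A hedged example of the keyword arguments above, assuming the strategy already holds a prepared model; the checkpoint path is hypothetical, and with a local state dict it would instead point at the per-rank directory described in the note:
>>> strategy.load_checkpoint(
>>>     'work_dirs/epoch_0.pth',          # hypothetical path
>>>     map_location='cpu',
>>>     revise_keys=[(r'^module.', '')],  # strip a legacy 'module.' prefix
>>> )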
- load_model_state_dict(state_dict, *, strict=False, revise_keys=[('^module.', '')])[source]
Load model state from dict.
Warning
revise_keys is not supported yet.
- Parameters:
state_dict (dict) – Model state dict returned by FSDPStrategy.model_state_dict(). If ``state_dict_type`` is ``full``, ``state_dict`` could also be the result of ``model.state_dict()``.
strict (bool) – Whether to load model state dict strictly. Defaults to False.
revise_keys (list) –
- Return type:
None
- load_optim_state_dict(state_dict)[source]
Load optimizer state from dict.
- Parameters:
state_dict (dict) – The optimizer state dict. If ``state_dict_type`` is ``full``, ``state_dict`` could also be the result of ``optimizer.state_dict()``.
- Return type:
None
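A hedged round-trip sketch for the two loaders, assuming a prepared strategy; the state dicts are captured with the accessors documented below and loaded back unchanged within the same (equally sharded) setup:
>>> model_sd = strategy.model_state_dict()
>>> optim_sd = strategy.optim_state_dict()
>>> # Restore the snapshot taken above.
>>> strategy.load_model_state_dict(model_sd)
>>> strategy.load_optim_state_dict(optim_sd)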
- model_state_dict()[source]
Get the model state dict based on ``state_dict_type``.
If ``state_dict_type`` is ``full``, the model state dict will be the same as that of the original unsharded model.
If ``state_dict_type`` is ``local`` and ``use_orig_params`` is ``True`` in ``model_wrapper``, the keys of the state dict will be the same as those of the original unsharded model, but the values will be sharded.
If ``state_dict_type`` is ``local`` and ``use_orig_params`` is ``False`` in ``model_wrapper``, the flattened and sharded state dict will be returned.
See more details in the official API documents.
- Return type:
dict
- optim_state_dict()[source]
Get the optimizer state dict based on ``state_dict_type``.
If ``state_dict_type`` is ``full``, the optimizer state dict can be loaded by the original unsharded optimizer.
Otherwise, the optimizer state dict can only be loaded by an optimizer with sharded parameters.
Note
The optimizer state dict is not the same as that of the original optimizer even in ``full`` mode, although it can be loaded correctly.
See more details in the official API documents.
- Return type:
dict
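As a hedged sketch of how the two accessors are commonly combined: with state_dict_cfg='full', rank 0 gathers the complete state dicts, which can then be inspected or saved manually (save_checkpoint() below is the usual way to persist them); the file name is hypothetical:
>>> import torch
>>> from mmengine.dist import is_main_process
>>> model_sd = strategy.model_state_dict()
>>> optim_sd = strategy.optim_state_dict()
>>> if is_main_process():
>>>     torch.save({'state_dict': model_sd, 'optimizer': optim_sd}, 'manual_ckpt.pth')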
- save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]
Save checkpoint to the given ``filename``.
If ``state_dict_type`` is ``full``, the checkpoint will only be saved on rank 0. The structure of the saved checkpoint is the same as the one saved by DDPStrategy.
If ``state_dict_type`` is ``local``, each rank will save its sharded state dict to a directory, which means the saved structure will look like this:
epoch_0.pth
├── rank0.pth
├── rank1.pth
├── ...
└── rank8.pth
- Parameters:
filename (str) – Filename to save the checkpoint to.
- Keyword Arguments:
save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.
save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.
extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.
callback (callable, optional) – Callback function to modify the checkpoint before saving it. Defaults to None.
- Return type:
None
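A closing sketch combining the arguments above, assuming a prepared strategy in a distributed run; the path, the extra_ckpt content, and the callback are illustrative only, and the callback is assumed to receive the checkpoint dict:
>>> def drop_ema(checkpoint):
>>>     # Hypothetical callback: drop an 'ema_state_dict' entry before saving.
>>>     checkpoint.pop('ema_state_dict', None)
>>> strategy.save_checkpoint(
>>>     'work_dirs/epoch_1.pth',               # hypothetical path
>>>     save_optimizer=True,
>>>     save_param_scheduler=True,
>>>     extra_ckpt=dict(note='baseline run'),  # illustrative extra entry
>>>     callback=drop_ema,
>>> )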