MMFullyShardedDataParallel
- class mmengine.model.MMFullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, param_init_fn=None, use_orig_params=True, **kwargs)[source]
A wrapper for sharding Module parameters across data parallel workers.
Different from FullyShardedDataParallel, MMFullyShardedDataParallel implements three methods: train_step(), val_step() and test_step(), which are called by train_loop, val_loop and test_loop respectively.
- train_step: Called by runner.train_loop. Implements the default model forward, gradient back-propagation and parameter updating logic.
- val_step: Called by runner.val_loop to get the inference results. Note that since MMFullyShardedDataParallel wraps the model recursively, problems may arise if BaseModel.val_step is reused directly to implement val_step here. To avoid this, val_step first calls methods of BaseModel to pre-process the data, then uses FullyShardedDataParallel.forward to get the results.
- test_step: Called by runner.test_loop to get the inference results. Its logic is equivalent to val_step.
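In practice, the wrapper is usually enabled through the runner's model wrapper config rather than constructed by hand. A minimal sketch of such a config fragment; the option values are illustrative, not defaults:
>>> # Hypothetical runner config fragment; values are examples only.
>>> cfg = dict(
>>>     model_wrapper_cfg=dict(
>>>         type='MMFullyShardedDataParallel',
>>>         cpu_offload=True,
>>>         backward_prefetch='BACKWARD_PRE'))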
- Parameters:
module (nn.Module) – module to be wrapped with FSDP.
process_group (ProcessGroup, optional) – process group for sharding.
cpu_offload (bool, CPUOffload, optional) – CPU offloading config. Different from FullyShardedDataParallel, since it can be set by users' pre-defined config in MMEngine, its type is expected to be None, bool or CPUOffload. Currently, only parameter and gradient CPU offload are supported. They can be enabled by passing in cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU so that parameters and gradients are on the same device to work with the optimizer. This API is subject to change. Defaults to None, in which case there is no offloading.
auto_wrap_policy (str or Callable, optional) – Specifying a policy to recursively wrap layers with FSDP. Different from FullyShardedDataParallel, since it can be set by users' pre-defined config in MMEngine, its type is expected to be None, str or Callable. If it is a str, MMFullyShardedDataParallel will try to get the specified method from the FSDP_WRAP_POLICIES registry, and this method will be passed to FullyShardedDataParallel to finally initialize the model. Note that this policy currently only applies to child modules of the passed-in module; the remaining modules are always wrapped in the returned FSDP root instance. default_auto_wrap_policy written in torch.distributed.fsdp.wrap is an example of an auto_wrap_policy callable; this policy wraps layers with parameter sizes larger than 100M. Users can supply a customized auto_wrap_policy callable, which should accept the following arguments: module: nn.Module, recurse: bool, unwrapped_params: int. Extra customized arguments can be added to the customized auto_wrap_policy callable as well. Example:
>>> def custom_auto_wrap_policy(
>>>     module: nn.Module,
>>>     recurse: bool,
>>>     unwrapped_params: int,
>>>     # These are customizable for this policy function.
>>>     min_num_params: int = int(1e8),
>>> ) -> bool:
>>>     return unwrapped_params >= min_num_params
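When the extra customized arguments of such a policy need non-default values, the callable can be bound with functools.partial before being passed in. A minimal sketch using the policy defined above; model stands in for any nn.Module to be wrapped:
>>> import functools
>>> wrapped = MMFullyShardedDataParallel(
>>>     module=model,
>>>     auto_wrap_policy=functools.partial(
>>>         custom_auto_wrap_policy, min_num_params=int(1e6)))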
backward_prefetch (str or BackwardPrefetch, optional) – Different from FullyShardedDataParallel, this argument can be a string or a BackwardPrefetch instance. If it is a string, it should be 'BACKWARD_PRE' or 'BACKWARD_POST'. A construction sketch using the string form follows the parameter list.
mixed_precision (dict or MixedPrecision, optional) – This configures native mixed precision for FSDP. If this is set to None, no mixed precision is used. Different from the native FSDP, this argument can be a dict like this:
Examples
>>> mixed_precision=dict(param_dtype='float16',
>>>                      buffer_dtype='float32',
>>>                      reduce_dtype='float32')
Defaults to None.
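For reference, the dict form above is equivalent to constructing the native MixedPrecision config directly; a sketch with the same dtypes as the example:
>>> import torch
>>> from torch.distributed.fsdp import MixedPrecision
>>> mixed_precision = MixedPrecision(param_dtype=torch.float16,
>>>                                  buffer_dtype=torch.float32,
>>>                                  reduce_dtype=torch.float32)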
use_orig_params (bool) – Different from native FullyShardedDataParallel, it defaults to True.
sharding_strategy (str or ShardingStrategy, optional) – Strategy for sharding parameters, gradients and optimizer states. Different from FullyShardedDataParallel, it can also be given as a string, which should be the name of a ShardingStrategy member such as 'FULL_SHARD' or 'SHARD_GRAD_OP'.
**kwargs – Keyword arguments passed to FullyShardedDataParallel.
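Putting the parameters together, a minimal construction sketch. It uses a single-process gloo group and a hypothetical toy BaseModel purely for illustration; FSDP is primarily intended for multi-GPU runs under a distributed launcher, and CPU-only execution may require a recent PyTorch:
>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel, MMFullyShardedDataParallel
>>> # FSDP needs an initialized process group, even with one worker.
>>> os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
>>> os.environ.setdefault('MASTER_PORT', '29500')
>>> dist.init_process_group(backend='gloo', rank=0, world_size=1)
>>> class ToyModel(BaseModel):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.linear = nn.Linear(8, 2)
>>>     def forward(self, inputs, data_samples=None, mode='tensor'):
>>>         out = self.linear(inputs)
>>>         if mode == 'loss':
>>>             return dict(loss=out.sum())
>>>         return out
>>> fsdp_model = MMFullyShardedDataParallel(
>>>     module=ToyModel(),
>>>     backward_prefetch='BACKWARD_PRE')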
- test_step(data)[source]
Gets the predictions of module during the testing process.
- Parameters:
data (dict) – Data sampled by dataloader.
- Returns:
The predictions of the given data.
- Return type:
List[BaseDataElement]
- train_step(data, optim_wrapper)[source]
Interface for model forward, backward and parameter updating during the training process. train_step() will perform the following steps in order:
- If module defines the preprocess method, call module.preprocess to pre-process the data.
- Call module.forward(**data) and get the losses.
- Parse the losses.
- Call optim_wrapper.optimizer_step to update the parameters.
- Return the log messages of the losses.
- Parameters:
data (dict) – Data sampled by dataloader.
optim_wrapper (OptimWrapper) – A wrapper of the optimizer used to update parameters.
- Returns:
A dict of tensors for logging.
- Return type:
Dict[str, torch.Tensor]
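Continuing the construction sketch above, train_step can be driven by hand (normally runner.train_loop does this); OptimWrapper comes from mmengine.optim, and the input shape matches the toy model:
>>> from mmengine.optim import OptimWrapper
>>> optim_wrapper = OptimWrapper(
>>>     torch.optim.SGD(fsdp_model.parameters(), lr=0.01))
>>> data = dict(inputs=torch.rand(4, 8))
>>> log_vars = fsdp_model.train_step(data, optim_wrapper=optim_wrapper)
>>> print(log_vars['loss'])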
- val_step(data)[source]
Gets the predictions of module during the validation process.
- Parameters:
data (dict) – Data sampled by dataloader.
- Returns:
The predictions of the given data.
- Return type:
List[BaseDataElement] or dict
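Continuing the same sketch, predictions can also be fetched outside of a runner loop; with a real model the returned values would be BaseDataElement instances rather than raw tensors:
>>> with torch.no_grad():
>>>     preds = fsdp_model.val_step(dict(inputs=torch.rand(4, 8)))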