MMDistributedDataParallel

- class mmengine.model.MMDistributedDataParallel(module, detect_anomalous_params=False, **kwargs)[source]

A distributed model wrapper used for training, testing and validation in loops.
Different from DistributedDataParallel, MMDistributedDataParallel implements three methods, train_step(), val_step() and test_step(), which will be called by train_loop, val_loop and test_loop.

- train_step: Called by runner.train_loop; implements the default model forward, gradient back-propagation and parameter-updating logic. To take advantage of DistributedDataParallel's automatic gradient synchronization, train_step calls DistributedDataParallel.forward to calculate the losses, and calls other methods of BaseModel to pre-process data and parse losses. Finally, it updates the model parameters with OptimWrapper and returns the loss dictionary used for logging (a minimal usage sketch is shown after this list).
- val_step: Called by runner.val_loop to get the inference results. Since there is no gradient synchronization requirement, this procedure is equivalent to BaseModel.val_step.
- test_step: Called by runner.test_loop, equivalent to val_step.
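A minimal sketch of how these three methods might be called outside the runner. ToyModel (a BaseModel subclass), train_dataloader and val_dataloader are hypothetical placeholders assumed to be built elsewhere, and the distributed process group is assumed to be initialized already:

```python
# Sketch only: ToyModel, train_dataloader and val_dataloader are assumed
# placeholders; the process group is assumed to be initialized
# (e.g. via torch.distributed.init_process_group).
import torch
from mmengine.model import MMDistributedDataParallel
from mmengine.optim import OptimWrapper

model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(
    module=model, device_ids=[torch.cuda.current_device()])
optim_wrapper = OptimWrapper(
    optimizer=torch.optim.SGD(ddp_model.parameters(), lr=0.01))

# Roughly what runner.train_loop does for each batch.
for data in train_dataloader:
    log_vars = ddp_model.train_step(data, optim_wrapper=optim_wrapper)

# Roughly what runner.val_loop does; no gradient synchronization is needed.
for data in val_dataloader:
    outputs = ddp_model.val_step(data)
```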
- Parameters:
  - detect_anomalous_params (bool) – This option is only used for debugging and will slow down the training speed. It detects anomalous parameters that are not included in the computational graph with the loss as the root. There are two cases:
    - Parameters were not used during the forward pass.
    - Parameters were not used to produce the loss.
    Defaults to False.
  - **kwargs – keyword arguments passed to DistributedDataParallel:
    - device_ids (List[int] or torch.device, optional): CUDA devices for module.
    - output_device (int or torch.device, optional): Device location of output for single-device CUDA modules.
    - dim (int): Defaults to 0.
    - broadcast_buffers (bool): Flag that enables syncing (broadcasting) buffers of the module at the beginning of the forward function. Defaults to True.
    - find_unused_parameters (bool): Whether to find parameters of the module which are not in the forward graph. Defaults to False.
    - process_group (ProcessGroup, optional): The process group to be used for distributed data all-reduction.
    - bucket_cap_mb (int): Bucket size in MegaBytes (MB). Defaults to 25.
    - check_reduction (bool): This argument is deprecated. Defaults to False.
    - gradient_as_bucket_view (bool): Defaults to False.
    - static_graph (bool): Defaults to False.

    See more information about these arguments in torch.nn.parallel.DistributedDataParallel. A construction sketch using a few of them follows.
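The sketch below forwards a handful of these keyword arguments; `model` is an assumed, already-built BaseModel instance placed on the current CUDA device, and the chosen values are illustrative rather than recommended defaults:

```python
# Sketch only: `model` is an assumed BaseModel instance on the current CUDA
# device; the process group is assumed to be initialized.
import torch
from mmengine.model import MMDistributedDataParallel

wrapped = MMDistributedDataParallel(
    module=model,
    detect_anomalous_params=False,       # debugging aid only; slows training
    device_ids=[torch.cuda.current_device()],
    broadcast_buffers=False,             # skip buffer broadcast on each forward
    find_unused_parameters=True,         # if some parameters miss the graph
    bucket_cap_mb=25,                    # gradient bucket size in MB
)
```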
Note

If the model has multiple submodules and each submodule has a separate optimization strategy, MMSeparateDistributedDataParallel should be used to wrap the model.

Note

If the model itself has a custom optimization strategy, rather than simply forwarding the model and updating its parameters, a custom model wrapper inheriting from MMDistributedDataParallel should be defined and should override the train_step method, as in the sketch below.
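A hedged sketch of that pattern. The body only mirrors the default training flow using BaseModel and OptimWrapper helpers and is purely illustrative; a real custom wrapper would put its own optimization strategy here:

```python
# Illustrative only: the train_step body below mirrors the default flow;
# replace it with the model's actual custom optimization strategy.
from mmengine.model import MMDistributedDataParallel
from mmengine.registry import MODEL_WRAPPERS


@MODEL_WRAPPERS.register_module()
class CustomDistributedDataParallel(MMDistributedDataParallel):

    def train_step(self, data, optim_wrapper):
        # optim_context enables AMP / gradient accumulation handling.
        with optim_wrapper.optim_context(self):
            data = self.module.data_preprocessor(data, training=True)
            # Calling self(...) goes through DistributedDataParallel.forward,
            # which keeps automatic gradient synchronization; `data` is
            # assumed to be a dict after pre-processing.
            losses = self(**data, mode='loss')
        parsed_loss, log_vars = self.module.parse_losses(losses)
        optim_wrapper.update_params(parsed_loss)
        return log_vars
```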
- train_step(data, optim_wrapper)[source]

Interface for model forward, backward and parameter updating during the training process.

train_step() will perform the following steps in order:

1. If module defines the preprocess method, call module.preprocess to pre-process the data.
2. Call module.forward(**data) and get the losses.
3. Parse the losses.
4. Call optim_wrapper.optimizer_step to update the parameters.
5. Return the log messages of the losses.
- Parameters:
  - optim_wrapper (OptimWrapper) – A wrapper of the optimizer to update parameters.
- Returns:
  A dict of tensors for logging.
- Return type:
  Dict[str, torch.Tensor]
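A short usage sketch of train_step and its returned log dictionary; ddp_model and optim_wrapper are the assumed objects from the earlier sketches, and data_batch is an arbitrary batch from a dataloader:

```python
# `ddp_model`, `optim_wrapper` and `data_batch` are assumed placeholders
# (see the construction sketch above).
log_vars = ddp_model.train_step(data_batch, optim_wrapper=optim_wrapper)

# The returned dict maps loss names to tensors, e.g. {'loss': tensor(0.2345)}.
for name, value in log_vars.items():
    print(f'{name}: {value.item():.4f}')
```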