MMDistributedDataParallel

class mmengine.model.MMDistributedDataParallel(module, detect_anomalous_params=False, **kwargs)

A distributed model wrapper used in the training, validation and testing loops.

Unlike DistributedDataParallel, MMDistributedDataParallel implements three methods, train_step(), val_step() and test_step(), which are called by train_loop, val_loop and test_loop respectively.

  • train_step: Called by runner.train_loop. It implements the default model forward, gradient backpropagation and parameter update logic. To take advantage of DistributedDataParallel's automatic gradient synchronization, train_step calls DistributedDataParallel.forward to calculate the losses, and calls other methods of BaseModel to pre-process the data and parse the losses. Finally, it updates the model parameters through OptimWrapper and returns the loss dictionary used for logging.

  • val_step: Called by runner.val_loop to get the inference results. Since no gradient synchronization is required, this procedure is equivalent to BaseModel.val_step.

  • test_step: Called by runner.test_loop; equivalent to val_step.

Parameters
  • detect_anomalous_params (bool) –

    This option is intended only for debugging and slows down training. It detects anomalous parameters, that is, parameters that are not included in the computational graph with the loss as the root. There are two cases:

    • Parameters were not used during forward pass.

    • Parameters were not used to produce loss.

    Defaults to False.

  • **kwargs

    keyword arguments passed to DistributedDataParallel.

    • device_ids (List[int] or torch.device, optional): CUDA devices for module.

    • output_device (int or torch.device, optional): Device location of output for single-device CUDA modules.

    • dim (int): Defaults to 0.

    • broadcast_buffers (bool): Flag that enables syncing (broadcasting) the buffers of the module at the beginning of the forward function. Defaults to True.

    • find_unused_parameters (bool): Whether to find parameters of the module that are not in the forward graph. Defaults to False.

    • process_group (ProcessGroup, optional): The process group to be used for distributed data all-reduction.

    • bucket_cap_mb (int): Bucket size in megabytes (MB). Defaults to 25.

    • check_reduction (bool): This argument is deprecated. Defaults to False.

    • gradient_as_bucket_view (bool): Defaults to False.

    • static_graph (bool): Defaults to False.

See more information about arguments in torch.nn.parallel.DistributedDataParallel.
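The idea behind detect_anomalous_params can be illustrated in plain PyTorch. The snippet below is a minimal sketch, not the library's implementation: the helper find_params_outside_loss_graph and the ToyModel class are hypothetical names introduced here. It walks the autograd graph backwards from the loss and reports every trainable parameter that is never reached.

```python
import torch
from torch import nn


def find_params_outside_loss_graph(loss, model):
    """Return names of trainable parameters that autograd cannot reach from ``loss``.

    Hypothetical helper for illustration only.
    """
    reachable = set()  # ids of parameters that are leaves of the loss graph
    visited = set()
    stack = [loss.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in visited:
            continue
        visited.add(node)
        # AccumulateGrad nodes expose the leaf tensor they accumulate into.
        if hasattr(node, 'variable'):
            reachable.add(id(node.variable))
        stack.extend(parent for parent, _ in node.next_functions)
    return [name for name, param in model.named_parameters()
            if param.requires_grad and id(param) not in reachable]


class ToyModel(nn.Module):
    """Hypothetical model with one layer that never contributes to the loss."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward

    def forward(self, x):
        return self.used(x).sum()


model = ToyModel()
loss = model(torch.randn(3, 4))
anomalous = find_params_outside_loss_graph(loss, model)
print(anomalous)  # ['unused.weight', 'unused.bias']
```

Under DistributedDataParallel such parameters never receive a gradient, so gradient reduction can stall or raise an error unless find_unused_parameters is enabled. A check like this helps locate them during debugging.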

Note

If the model has multiple submodules and each submodule has its own optimization strategy, MMSeparateDistributedDataParallel should be used to wrap the model.

Note

If the model itself has a custom optimization strategy, rather than simply running the forward pass and updating the model, define a custom model wrapper that inherits from MMDistributedDataParallel and overrides the train_step method.

test_step(data)

Gets the predictions of the module during testing.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

The predictions of given data.

Return type

list

train_step(data, optim_wrapper)

Interface for the model forward pass, backward pass and parameter update during training.

train_step() will perform the following steps in order:

  • If module defines the preprocess method, call module.preprocess to pre-process the data.

  • Call module.forward(**data) and get losses.

  • Parse losses.

  • Call optim_wrapper.optimizer_step to update parameters.

  • Return log messages of losses.
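The steps above can be sketched in plain PyTorch on a single process. This is a simplified illustration under stated assumptions, not the wrapper's actual code: ToyModel, its preprocess method, the loss-dictionary format and train_step_sketch are hypothetical, and a bare torch.optim.SGD stands in for OptimWrapper. In the real wrapper the forward pass goes through DistributedDataParallel so that gradients are synchronized across processes.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model; ``preprocess`` and the loss dict are assumptions."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def preprocess(self, data):
        # Stack the raw samples into batched tensors.
        return dict(inputs=torch.stack(data['inputs']),
                    targets=torch.stack(data['targets']))

    def forward(self, inputs, targets):
        pred = self.linear(inputs)
        return dict(loss_mse=nn.functional.mse_loss(pred, targets))


def train_step_sketch(module, data, optimizer):
    # 1. Pre-process the data if the module knows how to.
    if hasattr(module, 'preprocess'):
        data = module.preprocess(data)
    # 2. Forward pass that returns a dict of losses.
    losses = module(**data)
    # 3. Parse the losses: sum every entry whose name contains "loss".
    total_loss = sum(value for name, value in losses.items() if 'loss' in name)
    # 4. Update the parameters (a plain optimizer stands in for OptimWrapper).
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # 5. Return detached tensors for logging.
    log_vars = {name: value.detach() for name, value in losses.items()}
    log_vars['loss'] = total_loss.detach()
    return log_vars


torch.manual_seed(0)
model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = dict(inputs=[torch.randn(4) for _ in range(8)],
            targets=[torch.randn(1) for _ in range(8)])
first = train_step_sketch(model, data, optimizer)
second = train_step_sketch(model, data, optimizer)
print(sorted(first))  # ['loss', 'loss_mse']
```

Calling the function twice on the same batch should yield a lower loss on the second call, which is a quick sanity check that the parameter update took effect.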

Parameters
  • data (dict or tuple or list) – Data sampled from dataset.

  • optim_wrapper (OptimWrapper) – A wrapper of optimizer to update parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]

val_step(data)

Gets the predictions of the module during validation.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

The predictions of given data.

Return type

list
