MMSeparateDistributedDataParallel

class mmengine.model.MMSeparateDistributedDataParallel(module, broadcast_buffers=False, find_unused_parameters=False, **kwargs)[source]

A DistributedDataParallel wrapper for models in MMGeneration.

In MMEditing and MMGeneration, different modules within a model need to be wrapped with separate DistributedDataParallel instances; otherwise GAN training will fail. For example, a GAN model usually has two submodules: a generator and a discriminator. If both were wrapped in a single standard DistributedDataParallel, errors would occur during training, because when the parameters of the generator (or discriminator) are updated, the parameters of the discriminator (or generator) are not, which DistributedDataParallel does not allow. This wrapper is therefore designed to wrap the generator and the discriminator with separate DistributedDataParallel instances. It performs two operations:

  1. Wraps each submodule of the model with a separate MMDistributedDataParallel. Note that only modules with parameters will be wrapped.

  2. Calls train_step, val_step and test_step of submodules to get losses and predictions.
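
Example:

A minimal sketch (not taken from MMEngine itself) of wrapping a GAN-style model. It assumes torch.distributed has already been initialized (e.g. the script was launched with torchrun). A real model would subclass mmengine.model.BaseModel and implement train_step, val_step and test_step; ToyGAN below is a hypothetical stand-in that only shows the submodule structure.

    import torch.nn as nn
    from mmengine.model import MMSeparateDistributedDataParallel


    class ToyGAN(nn.Module):
        """Hypothetical GAN-style model with two parameterized submodules."""

        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(8, 8)
            self.discriminator = nn.Linear(8, 1)


    model = ToyGAN()
    ddp_model = MMSeparateDistributedDataParallel(module=model)

    # Each submodule with parameters is wrapped by its own
    # MMDistributedDataParallel, so the generator and the discriminator can be
    # updated in alternating steps without DDP errors about unused parameters.
    print(type(ddp_model.module.generator))
    print(type(ddp_model.module.discriminator))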

Parameters:
  • module (nn.Module) – Model containing multiple submodules that have separate updating strategies.

  • broadcast_buffers (bool) – Same as that in torch.nn.parallel.distributed.DistributedDataParallel. Defaults to False.

  • find_unused_parameters (bool) – Same as that in torch.nn.parallel.distributed.DistributedDataParallel. Traverses the autograd graph of all tensors contained in the return value of the wrapped module’s forward function. Defaults to False.

  • **kwargs

    Keyword arguments passed to MMDistributedDataParallel.

    • device_ids (List[int] or torch.device, optional): CUDA devices for module.

    • output_device (int or torch.device, optional): Device location of output for single-device CUDA modules.

    • dim (int): Defaults to 0.

    • process_group (ProcessGroup, optional): The process group to be used for distributed data all-reduction.

    • bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults to 25.

    • check_reduction (bool): This argument is deprecated. Defaults to False.

    • gradient_as_bucket_view (bool): Defaults to False.

    • static_graph (bool): Defaults to False.

See torch.nn.parallel.DistributedDataParallel for more information about these arguments.
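
Example:

A hedged construction sketch: every keyword argument below is forwarded to the MMDistributedDataParallel built around each parameterized submodule. It assumes the ToyGAN model from the example above and an initialized process group.

    ddp_model = MMSeparateDistributedDataParallel(
        module=model,
        broadcast_buffers=False,      # same meaning as in torch DDP
        find_unused_parameters=True,  # helpful when some branches are skipped
        bucket_cap_mb=25,             # gradient bucket size in MB
    )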

no_sync()[source]

Enables no_sync context of all sub MMDistributedDataParallel modules.
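
Example:

A minimal gradient-accumulation sketch. Inside no_sync() every sub MMDistributedDataParallel skips its gradient all-reduce, so gradients only accumulate locally; the final backward outside the context synchronizes them. compute_loss and batches are hypothetical placeholders, and ddp_model is assumed from the examples above.

    with ddp_model.no_sync():
        for data in batches[:-1]:
            # no all-reduce here, gradients accumulate on each rank
            compute_loss(ddp_model.module, data).backward()
    # the last backward, outside the context, triggers synchronization
    compute_loss(ddp_model.module, batches[-1]).backward()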

test_step(data)[source]

Gets the predictions of the module during the testing process.

Parameters:

data (dict or tuple or list) – Data sampled from dataset.

Returns:

The predictions for the given data.

Return type:

list
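
Example:

A brief sketch, assuming the objects from the examples above: the call is delegated to the wrapped module's own test_step, so the module is expected to implement it (for instance by inheriting from mmengine.model.BaseModel). test_batch is a hypothetical batch produced by a test dataloader.

    predictions = ddp_model.test_step(test_batch)  # a list of predictions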

train(mode=True)[source]

Sets the module in training mode.

In order to make the DDP wrapper inheritance hierarchy more uniform, MMSeparateDistributedDataParallel inherits from DistributedDataParallel but does not call its constructor. Since the attributes of DistributedDataParallel have not been initialized, calling the train method of DistributedDataParallel would raise an error with PyTorch versions <= 1.9. Therefore, this method is overridden to call the train method of the submodules.

Parameters:

mode (bool) – Whether to set training mode (True) or evaluation mode (False). Defaults to True.

Returns:

self.

Return type:

Module
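
Example:

A small sketch of the behaviour described above, assuming the ToyGAN model from the first example: train() and train(mode=False) are forwarded to the wrapped submodules rather than going through DistributedDataParallel.train.

    ddp_model.train()             # training mode
    assert ddp_model.module.generator.training

    ddp_model.train(mode=False)   # evaluation mode
    assert not ddp_model.module.generator.training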

train_step(data, optim_wrapper)[source]

Interface for model forward, backward, and parameter updating during the training process.

Parameters:

  • data (dict or tuple or list) – Data sampled from dataset.

  • optim_wrapper (OptimWrapperDict) – A wrapper of the optimizers used to update the parameters of the submodules.

Returns:

A dict of tensors for logging.

Return type:

Dict[str, torch.Tensor]
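
Example:

A hedged sketch of one training iteration driven by train_step(). In practice the MMEngine Runner performs these calls; here an OptimWrapperDict is built by hand with one OptimWrapper per submodule, using the hypothetical ToyGAN from the first example. data_batch is a hypothetical batch from a dataloader, and the wrapped module's own train_step decides how the generator and discriminator are updated.

    from torch.optim import SGD
    from mmengine.optim import OptimWrapper, OptimWrapperDict

    optim_wrapper = OptimWrapperDict(
        generator=OptimWrapper(SGD(model.generator.parameters(), lr=0.01)),
        discriminator=OptimWrapper(SGD(model.discriminator.parameters(), lr=0.01)),
    )

    log_vars = ddp_model.train_step(data_batch, optim_wrapper=optim_wrapper)
    print(log_vars)  # Dict[str, torch.Tensor] of losses for logging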

val_step(data)[source]

Gets the predictions of the module during the validation process.

Parameters:

data (dict or tuple or list) – Data sampled from dataset.

Returns:

The predictions for the given data.

Return type:

list
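
Example:

A brief sketch of a validation call, with the same delegation as test_step: the wrapped module's own val_step is invoked. val_batch is a hypothetical batch from a validation dataloader, and the other objects come from the examples above.

    import torch

    ddp_model.train(mode=False)   # switch submodules to evaluation mode
    with torch.no_grad():
        outputs = ddp_model.val_step(val_batch)  # a list of predictions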
