
mmengine.dist.all_gather

mmengine.dist.all_gather(data, group=None)

Gather data from the whole group in a list.

Note

Calling all_gather in a non-distributed environment performs no communication and simply returns a list containing data itself.

Note

Unlike PyTorch's torch.distributed.all_gather, MMEngine's all_gather() does not require the caller to pre-allocate and pass in gather_list; it builds gather_list internally and returns it, which is more convenient. The two interfaces differ as follows:

  • MMEngine: all_gather(data, group) -> gather_list

  • PyTorch: all_gather(gather_list, data, group) -> None
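The difference between the two calling conventions can be illustrated with a toy, single-process model in which threads stand in for ranks. Everything below (`ToyGroup`, `torch_style_all_gather`, `mmengine_style_all_gather`, `run_on_ranks`) is hypothetical and performs no real communication; it is a sketch of who allocates the output list, not MMEngine's implementation. Plain Python lists replace tensors, and `group=None` is treated as "not distributed" for simplicity, whereas in MMEngine `group=None` selects the default process group.

```python
import threading
from typing import Any, Callable, List, Optional

# Per-thread storage for the simulated rank id (stands in for the
# per-process rank a real distributed backend would track).
_ctx = threading.local()


class ToyGroup:
    """Hypothetical in-process 'process group': each thread plays one rank."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self.slots: List[Any] = [None] * world_size  # one slot per rank
        self.barrier = threading.Barrier(world_size)


def torch_style_all_gather(gather_list: List[Any], data: Any, group: ToyGroup) -> None:
    """PyTorch-like convention: caller pre-allocates gather_list, filled in place."""
    group.slots[_ctx.rank] = data
    group.barrier.wait()  # wait until every rank has published its data
    for i, item in enumerate(group.slots):
        gather_list[i] = item  # toy: stores references, no tensor copies
    group.barrier.wait()  # keep slots intact until every rank has read them
    return None


def mmengine_style_all_gather(data: Any, group: Optional[ToyGroup] = None) -> List[Any]:
    """MMEngine-like convention: the output list is allocated here and returned."""
    if group is None:
        # Toy simplification: None means "not distributed", so just wrap data.
        return [data]
    gather_list: List[Any] = [None] * group.world_size
    torch_style_all_gather(gather_list, data, group)
    return gather_list


def run_on_ranks(world_size: int, fn: Callable[[int, ToyGroup], Any]) -> List[Any]:
    """Run fn(rank, group) on world_size threads and collect each rank's result."""
    group = ToyGroup(world_size)
    results: List[Any] = [None] * world_size

    def worker(rank: int) -> None:
        _ctx.rank = rank
        results[rank] = fn(rank, group)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Mirrors the 2-rank example below: rank r contributes [1 + 2r, 2 + 2r].
per_rank = run_on_ranks(
    2, lambda rank, g: mmengine_style_all_gather([1 + 2 * rank, 2 + 2 * rank], g)
)
print(per_rank)  # [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]  (one entry per rank)
print(mmengine_style_all_gather([0, 1]))  # [[0, 1]]
```

Every rank ends up with the same rank-ordered list; the only difference between the conventions is whether the caller or the function allocates it.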

Parameters:
  • data (Tensor) – Tensor to be gathered.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Defaults to None.

Returns:

A list containing the data from every rank in the group, ordered by rank, if in a distributed environment; otherwise, a list containing only data itself.

Return type:

list[Tensor]

Examples

>>> import torch
>>> import mmengine.dist as dist
>>> # non-distributed environment
>>> data = torch.arange(2, dtype=torch.int64)
>>> data
tensor([0, 1])
>>> output = dist.all_gather(data)
>>> output
[tensor([0, 1])]
>>> # distributed environment
>>> # We have 1 process group with 2 ranks.
>>> rank = dist.get_rank()
>>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> data
tensor([1, 2])  # Rank 0
tensor([3, 4])  # Rank 1
>>> output = dist.all_gather(data)
>>> output
[tensor([1, 2]), tensor([3, 4])]  # Rank 0
[tensor([1, 2]), tensor([3, 4])]  # Rank 1