mmengine.dist.all_gather
- mmengine.dist.all_gather(data, group=None)
Gather data from the whole group in a list.
Note

Calling all_gather in a non-distributed environment does nothing and just returns a list containing data itself.

Note

Unlike torch.distributed.all_gather in PyTorch, all_gather() in MMEngine does not require passing in an empty gather_list; instead, it returns the gather_list directly, which is more convenient. The difference between their interfaces is as below:

MMEngine: all_gather(data, group) -> gather_list
PyTorch: all_gather(gather_list, data, group) -> None
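For illustration, gathering the same tensor through both interfaces might look like the following. This is a minimal sketch, assuming the default process group has already been initialized; the tensor contents are placeholders:

>>> import torch
>>> import torch.distributed as torch_dist
>>> import mmengine.dist as dist
>>> data = torch.arange(2, dtype=torch.int64)
>>> # PyTorch: the caller pre-allocates gather_list and passes it in.
>>> gather_list = [torch.empty_like(data) for _ in range(torch_dist.get_world_size())]
>>> torch_dist.all_gather(gather_list, data)
>>> # MMEngine: the gathered list is simply returned.
>>> gather_list = dist.all_gather(data)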
- Parameters:
data (Tensor) – Tensor to be gathered.
group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Defaults to None.
- Returns:
A list containing data from the whole group if in a distributed environment, otherwise a list containing only data itself.
- Return type:
list[Tensor]
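The group argument makes it possible to gather within a sub-group rather than the default (world) process group. A hypothetical sketch, assuming four ranks and using torch.distributed.new_group to build the sub-group:

>>> import torch
>>> import torch.distributed as torch_dist
>>> import mmengine.dist as dist
>>> # new_group must be called by every process, even non-members.
>>> subgroup = torch_dist.new_group(ranks=[0, 1])
>>> data = torch.tensor([dist.get_rank()])
>>> if dist.get_rank() in (0, 1):
...     # Only members of the sub-group participate; len(output) == 2.
...     output = dist.all_gather(data, group=subgroup)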
Examples
>>> import torch
>>> import mmengine.dist as dist

>>> # non-distributed environment
>>> data = torch.arange(2, dtype=torch.int64)
>>> data
tensor([0, 1])
>>> output = dist.all_gather(data)
>>> output
[tensor([0, 1])]

>>> # distributed environment
>>> # We have 2 process groups, 2 ranks.
>>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> data
tensor([1, 2])  # Rank 0
tensor([3, 4])  # Rank 1
>>> output = dist.all_gather(data)
>>> output
[tensor([1, 2]), tensor([3, 4])]  # Rank 0
[tensor([1, 2]), tensor([3, 4])]  # Rank 1
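To reproduce the distributed example end to end, the code can be put in a script and started with a multi-process launcher. A minimal sketch, assuming two processes launched via torchrun and the gloo backend for CPU tensors (on GPU, the default nccl backend would be the usual choice):

# all_gather_demo.py
# Run with: torchrun --nproc_per_node=2 all_gather_demo.py
import torch
import mmengine.dist as dist

if __name__ == '__main__':
    # Initialize the default process group from the torchrun environment.
    dist.init_dist(launcher='pytorch', backend='gloo')
    rank = dist.get_rank()
    data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
    output = dist.all_gather(data)
    print(f'Rank {rank}: {output}')
    # Rank 0: [tensor([1, 2]), tensor([3, 4])]
    # Rank 1: [tensor([1, 2]), tensor([3, 4])]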