mmengine.dist.all_reduce
- mmengine.dist.all_reduce(data, op='sum', group=None)
Reduces the tensor data across all machines in such a way that all get the final result.
After the call, data is going to be bitwise identical in all processes.
Note
Calling all_reduce in a non-distributed environment does nothing.
- Parameters:
data (Tensor) – Input and output of the collective. The function operates in-place.
op (str) – Operation to reduce data. Defaults to ‘sum’. Optional values are ‘sum’, ‘mean’, ‘product’, ‘min’, ‘max’, ‘band’, ‘bor’ and ‘bxor’.
group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Defaults to None.
- Return type:
None
Examples
>>> import torch
>>> import mmengine.dist as dist
>>> # non-distributed environment
>>> data = torch.arange(2, dtype=torch.int64)
>>> dist.all_reduce(data)
>>> data
tensor([0, 1])
>>> # distributed environment: 2 ranks in the default process group
>>> rank = dist.get_rank()
>>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> data
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(data, op='sum')
>>> data
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
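The examples above use the default ‘sum’ reduction. The sketch below shows the same pattern as a self-contained script using op='mean', which divides the summed result by the world size. It is illustrative rather than part of the official docs: the ‘gloo’ backend, the torchrun launcher, and the file name demo_all_reduce.py are assumptions, not requirements of all_reduce itself.

# demo_all_reduce.py -- a minimal sketch; launch with, e.g.:
#   torchrun --nproc_per_node=2 demo_all_reduce.py
# Assumptions: the torchrun launcher and the 'gloo' backend.
import torch
import torch.distributed as torch_dist
import mmengine.dist as dist

# torchrun sets the RANK/WORLD_SIZE/MASTER_ADDR environment variables,
# so no extra arguments are needed to create the default process group.
torch_dist.init_process_group(backend='gloo')
rank = dist.get_rank()

# Rank 0 holds [1., 2.]; rank 1 holds [3., 4.].
data = torch.arange(2, dtype=torch.float32) + 1 + 2 * rank

# In-place average across ranks: every rank ends up with [2., 3.].
dist.all_reduce(data, op='mean')
print(f'rank {rank}: {data}')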