Distributed Communication
In distributed training, different processes sometimes need to apply different logic depending on their rank, local_rank, etc. They also need to communicate with each other and synchronize data. These demands rely on distributed communication. PyTorch provides a set of basic distributed communication primitives. Based on these primitives, MMEngine provides higher-level APIs to meet more diverse demands. Using these APIs provided by MMEngine, modules can:
- ignore the differences between distributed and non-distributed environments
- deliver data of various types besides Tensor
- ignore the frameworks or backends used for communication
These APIs are roughly categorized into 3 types:
- Initialization: init_dist, which sets up the distributed environment for the runner
- Query & control: functions such as get_world_size for querying world_size, rank, and other distributed information
- Collective communication: collective communication functions such as all_reduce
We will detail these APIs in the following sections.
Initialization
init_dist: Launch function of distributed training. Currently it supports three launchers: PyTorch, Slurm, and MPI. It also sets up the given communication backend, which defaults to NCCL.
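A minimal sketch (not from the original doc) of initializing the distributed environment manually with init_dist, assuming the script is launched with torchrun so that the pytorch launcher can read the environment variables torchrun sets:

```python
# Minimal sketch: set up the distributed environment outside the Runner.
# Assumes a launch such as `torchrun --nproc_per_node=2 demo.py`, so the
# 'pytorch' launcher can read RANK/WORLD_SIZE/LOCAL_RANK from the environment.
from mmengine.dist import get_rank, get_world_size, init_dist

init_dist(launcher='pytorch', backend='nccl')
print(f'rank {get_rank()} of {get_world_size()} processes initialized')
```

When training through the Runner, init_dist is called for you, so this is only needed for standalone scripts.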
If you need to change the runtime timeout (default: 30 minutes) for distributed operations that take very long, you can specify a different timeout in the env_cfg configuration passed to the Runner, like this:

```python
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout=10800),  # sets the timeout to 3h (10800 seconds)
)
runner = Runner(xxx, env_cfg=env_cfg)
```
Query and control
The query and control functions require no arguments and can be used in both distributed and non-distributed environments. Their functionalities are listed below; a usage sketch follows the list.
- get_world_size: Returns the number of processes in the current process group. Returns 1 when non-distributed
- get_rank: Returns the global rank of the current process in the current process group. Returns 0 when non-distributed
- get_backend: Returns the communication backend used by the current process group. Returns None when non-distributed
- get_local_rank: Returns the local rank of the current process in the current process group. Returns 0 when non-distributed
- get_local_size: Returns the number of processes that are both in the current process group and on the same machine as the current process. Returns 1 when non-distributed
- get_dist_info: Returns the world_size and rank of the current process group. Returns world_size = 1 and rank = 0 when non-distributed
- is_main_process: Returns True if the current process is rank 0 in the current process group, otherwise False. Always returns True when non-distributed
- master_only: A function decorator. Functions decorated by master_only will only be executed on the rank 0 process
- barrier: A synchronization primitive. Every process will block until all processes in the current process group reach the same barrier location
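A small sketch (the checkpoint path and dictionary contents are made up for illustration) showing how these helpers behave the same way whether or not the run is distributed:

```python
import torch
from mmengine.dist import (barrier, get_dist_info, get_rank, is_main_process,
                           master_only)

@master_only
def save_checkpoint(state, path='checkpoint.pth'):
    # Only the rank 0 process executes this body; other ranks skip it.
    torch.save(state, path)

world_size, rank = get_dist_info()   # (1, 0) when non-distributed
if is_main_process():
    print(f'running with world_size={world_size}')
save_checkpoint({'rank': get_rank()})
barrier()                            # wait until every process reaches this point
```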
Collective communication
Collective communication functions are used for data transfer between processes in the same process group. We provide the following APIs based on PyTorch native functions such as all_reduce, all_gather, gather, and broadcast. These APIs are compatible with non-distributed environments and support data types other than Tensor. A usage sketch follows the list.
- all_reduce: AllReduce operation on Tensors in the current process group
- all_gather: AllGather operation on Tensors in the current process group
- gather: Gathers Tensors in the current process group to a designated rank
- broadcast: Broadcasts a Tensor to all processes in the current process group
- sync_random_seed: Synchronizes the random seed between processes in the current process group
- broadcast_object_list: Broadcasts a list of Python objects. It requires that the objects can be serialized by Pickle
- all_reduce_dict: AllReduce operation on a dict. It is based on broadcast and all_reduce
- all_gather_object: AllGather operation on any Python object that can be serialized by Pickle. It is based on all_gather
- gather_object: Gathers Python objects that can be serialized by Pickle
- collect_results: Unified API for collecting a list of data in the current process group. It supports both CPU and GPU communication
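A rough sketch (the tensor values and dictionary contents are made up for illustration; with an NCCL backend the tensors would need to be moved to GPU first) of a few of these helpers. In a non-distributed run each call simply leaves the local data unchanged:

```python
import torch
from mmengine.dist import all_gather_object, all_reduce, broadcast, get_rank

# Average a scalar loss across all processes (modified in place).
loss = torch.tensor(1.0)
all_reduce(loss, op='mean')

# Make every process hold the tensor from rank 0 (modified in place).
flag = torch.tensor([get_rank()])
broadcast(flag, src=0)

# Exchange arbitrary picklable Python objects; returns a list with one
# entry per rank (a single-element list when non-distributed).
results = all_gather_object({'rank': get_rank(), 'num_samples': 100})
```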