
ZeroRedundancyOptimizer

class mmengine.optim.ZeroRedundancyOptimizer(params, optimizer_type, **kwargs)

A wrapper of PyTorch's ZeroRedundancyOptimizer that takes the local optimizer type as a string.

This class wraps an arbitrary torch.optim.Optimizer and shards its states across the ranks in the process group, as described by ZeRO. The local optimizer instance on each rank is responsible for updating only about 1 / world_size of the parameters and therefore only needs to keep about 1 / world_size of the optimizer states. After the parameters are updated locally, each rank broadcasts its updated parameters to all other peers so that all model replicas stay in the same state. ZeroRedundancyOptimizer can be used together with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
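To put the 1 / world_size figure into numbers, here is a back-of-the-envelope sketch. It assumes an Adam-style optimizer that keeps two fp32 state tensors (first and second moments) per parameter and ignores scalar entries such as the step counter; the helper name is hypothetical. The result is only approximate in practice, because parameters are assigned to ranks whole rather than split evenly.

```python
def adam_state_bytes_per_rank(num_params: int, world_size: int,
                              bytes_per_value: int = 4) -> float:
    """Rough optimizer-state memory per rank under ZeRO-style sharding.

    Assumes two state values (exp_avg, exp_avg_sq) per parameter, as in Adam,
    and ignores the scalar step counter and any amsgrad buffer.
    """
    total_bytes = 2 * num_params * bytes_per_value  # unsharded optimizer state
    return total_bytes / world_size                 # ideal even split across ranks


# A 1B-parameter model in fp32 needs about 8 GB of Adam state unsharded;
# sharded across 8 ranks, each rank holds about 1 GB.
print(adam_state_bytes_per_rank(1_000_000_000, world_size=8) / 1e9)  # 1.0
```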

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to assign parameters to ranks. Each parameter belongs to exactly one rank and is never split across ranks. The resulting partition is arbitrary and might not match the parameter registration or usage order.
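The sorted-greedy idea can be sketched as follows: visit parameters from largest to smallest and give each one to the rank that currently holds the fewest elements. This is a simplified illustration over a flat list of parameter sizes, not PyTorch's actual code (which also handles param groups); the function name is hypothetical.

```python
from typing import List


def sorted_greedy_partition(numels: List[int], world_size: int) -> List[List[int]]:
    """Assign each parameter (by index) to exactly one rank.

    Parameters are visited largest first; each goes to the rank whose
    running total of elements is currently smallest (lowest rank on ties).
    """
    sizes = [0] * world_size                               # elements per rank
    shards: List[List[int]] = [[] for _ in range(world_size)]
    # Largest parameters first; Python's sort is stable, so equal sizes
    # keep their original relative order.
    order = sorted(range(len(numels)), key=lambda i: numels[i], reverse=True)
    for idx in order:
        rank = min(range(world_size), key=lambda r: sizes[r])
        shards[rank].append(idx)                           # never split a parameter
        sizes[rank] += numels[idx]
    return shards


shards = sorted_greedy_partition([4096, 64, 1024, 1024, 64, 512], world_size=2)
print(shards)  # [[0], [2, 3, 5, 1, 4]]
```

The output shows both caveats from the text: rank 1's indices do not follow registration order, and a single large parameter (index 0) leaves the split only approximately balanced (4096 vs. 2688 elements).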

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.8.

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.12 to enable param groups.

Parameters
  • params (Iterable) – an iterable of torch.Tensor objects or dicts (parameter groups) giving all parameters, which will be sharded across ranks.

  • optimizer_type (str) – the name of the local optimizer class, as a string (for example, 'AdamW').
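For reference, a minimal sketch of how this optimizer might be selected in an MMEngine-style config. It assumes the class is registered under the name ZeroRedundancyOptimizer and that extra keys such as lr and weight_decay are forwarded to the local optimizer through **kwargs. The job must still run in a distributed environment (for example, launched with torchrun), since the optimizer shards state across the ranks of a process group.

```python
# Hypothetical config fragment: 'AdamW' names the local optimizer class;
# lr and weight_decay are assumed to be forwarded to it via **kwargs.
optim_wrapper = dict(
    optimizer=dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='AdamW',
        lr=1e-3,
        weight_decay=0.05,
    ))
```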

state_dict()

Consolidate the state_dicts from all ranks so that the full state_dict can be saved.
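Consolidation is a gather: each rank holds state only for the parameters it owns, and the full state is assembled in one place. PyTorch's ZeroRedundancyOptimizer exposes this step as consolidate_state_dict(to=0), which is a collective call, so every rank must call state_dict() rather than only the rank that writes the checkpoint. The toy below models only the merge in pure Python with no torch.distributed calls; the function name and the dict layout (global parameter index mapped to its state) are hypothetical.

```python
from typing import Any, Dict, List


def consolidate(rank_states: List[Dict[int, Any]]) -> Dict[int, Any]:
    """Merge per-rank optimizer states keyed by global parameter index."""
    full: Dict[int, Any] = {}
    for rank, state in enumerate(rank_states):
        for idx, param_state in state.items():
            # Each parameter is owned by exactly one rank, so keys never collide.
            if idx in full:
                raise ValueError(f"parameter {idx} also owned by rank {rank}")
            full[idx] = param_state
    # Restore global parameter order, which the partition does not preserve.
    return dict(sorted(full.items()))


rank_states = [
    {0: {"exp_avg": 0.1}},                       # rank 0 owns parameter 0
    {2: {"exp_avg": 0.3}, 1: {"exp_avg": 0.2}},  # rank 1 owns parameters 2 and 1
]
print(list(consolidate(rank_states)))  # [0, 1, 2]
```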
