ZeroRedundancyOptimizer¶
- class mmengine.optim.ZeroRedundancyOptimizer(params, optimizer_type, **kwargs)[source]¶
A wrapper class of ZeroRedundancyOptimizer that accepts the local optimizer type as a string.

This class wraps an arbitrary torch.optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size of the parameters and hence only needs to keep 1 / world_size of the optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.8.

Warning

ZeroRedundancyOptimizer requires PyTorch >= 1.12 to enable param groups.

- Parameters:
  - params (Iterable) – an Iterable of torch.Tensor s or dict s giving all parameters, which will be sharded across ranks.
  - optimizer_type (str) – the string of the local optimizer class.
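A minimal usage sketch (not from the source docs): it relies only on the documented signature above and on the fact that state sharding requires an initialized distributed process group. The single-process gloo setup, the toy model, and the SGD hyperparameters are illustrative assumptions, not part of the API.

```python
import os

import torch
import torch.distributed as dist
from mmengine.optim import ZeroRedundancyOptimizer

# ZeroRedundancyOptimizer shards optimizer states across ranks, so a
# process group must exist before construction. A single-process gloo
# group is used here purely for illustration (normally processes are
# launched with torchrun across multiple ranks).
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = torch.nn.Linear(10, 10)  # toy model, illustrative only

# optimizer_type names the local optimizer class as a string; the
# remaining keyword arguments are forwarded to that optimizer.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_type='SGD', lr=0.01, momentum=0.9)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()  # updates this rank's shard, then syncs parameters

dist.destroy_process_group()
```

In an MMEngine config, the same optimizer would typically be selected through the registry, e.g. optimizer=dict(type='ZeroRedundancyOptimizer', optimizer_type='SGD', lr=0.01); consult the runner documentation for the exact config keys in your version.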