
Source code for mmengine.optim.optimizer.zero_optimizer

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated locally,
    each rank will broadcast its parameters to all other peers to keep all
    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
    reduce per-rank peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and might
    not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
        groups.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'
        # Avoid the generator becoming empty after the following check
        params = list(params)
        assert (
            all(isinstance(p, torch.Tensor) for p in params)
            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
                'PyTorch ZeroRedundancyOptimizer started to support param '
                'groups since 1.12.0. Please update your pytorch version to '
                'enable this feature, or disable param groups by deleting '
                '`paramwise_cfg` field in the config file.')
        optimizer_class = getattr(torch.optim, optimizer_type)
        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
        # Currently only `overlap_with_ddp=False` is supported. For more
        # details, please refer to the pytorch's official documentation.
        super().__init__(params, optimizer_class, **kwargs)
    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        self.consolidate_state_dict()
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict
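
The following is a minimal usage sketch, not part of the module above. It assumes a distributed process group has already been initialized (for example via ``torchrun`` and ``torch.distributed.init_process_group``) and uses a toy ``nn.Linear`` model purely for illustration; ``optimizer_type`` is resolved to a ``torch.optim`` class by name, and the remaining keyword arguments are forwarded to the local optimizer.

# Minimal usage sketch (assumption: running under an initialized
# process group, e.g. launched with torchrun).
import torch
import torch.nn as nn

from mmengine.optim.optimizer.zero_optimizer import ZeroRedundancyOptimizer

model = nn.Linear(16, 16)  # toy model, for illustration only

# 'AdamW' is looked up as `torch.optim.AdamW`; lr and weight_decay
# are passed through to the local optimizer instance on each rank.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_type='AdamW', lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()

# `state_dict()` consolidates the sharded states; only the main process
# returns the full dict, all other ranks get an empty dict.
ckpt = optimizer.state_dict()

Because the class is decorated with ``@OPTIMIZERS.register_module()``, it can also be referenced by the name ``'ZeroRedundancyOptimizer'`` when optimizers are built through the ``OPTIMIZERS`` registry, with ``optimizer_type`` and the local-optimizer keyword arguments passed through in the same way.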
