Shortcuts

mmengine.optim.optimizer.zero_optimizer 源代码

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that takes the
    local optimizer type as a string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated
    locally, each rank will broadcast its parameters to all other peers to
    keep all model replicas in the same state. ``ZeroRedundancyOptimizer``
    can be used in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank
    peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and
    might not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
        groups.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the name of the local optimizer class. It is
            looked up as an attribute of :mod:`torch.optim`
            (e.g. ``'SGD'``, ``'AdamW'``).
        **kwargs: forwarded unchanged to the constructor of PyTorch's
            ``ZeroRedundancyOptimizer``.

    Raises:
        AssertionError: if the PyTorch version is lower than 1.8, if
            ``torch.distributed.rpc`` is not available, or if param groups
            (non-tensor entries in ``params``) are given with a PyTorch
            version lower than 1.12.
        AttributeError: if ``optimizer_type`` is not found in
            :mod:`torch.optim`.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        # NOTE(review): presumably the PyTorch implementation cannot be
        # imported/used without RPC support in some builds -- confirm.
        assert is_available(), 'torch.distributed.rpc is not available.'
        # Avoid the generator becoming empty after the following check
        params = list(params)
        # Entries that are not tensors (i.e. param-group dicts) are only
        # supported by PyTorch >= 1.12.
        assert (
            all(isinstance(p, torch.Tensor) for p in params)
            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
                'PyTorch ZeroRedundancyOptimizer started to support param '
                'groups since 1.12.0. Please update your pytorch version to '
                'enable this feature, or disable param groups by deleting '
                '`paramwise_cfg` field in config file.')
        optimizer_class = getattr(torch.optim, optimizer_type, None)
        if optimizer_class is None:
            # Same exception type as a plain ``getattr`` miss, but with a
            # message that points at the offending config value.
            raise AttributeError(
                f'`torch.optim` has no optimizer named {optimizer_type!r}.')
        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
        # Currently only `overlap_with_ddp=False` is supported. For more
        # details, please refer to the pytorch's official documentation.
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate the sharded optimizer states and return them.

        ``consolidate_state_dict`` is invoked on every rank that calls this
        method. Afterwards, the main process returns the full state dict
        produced by the parent class, and every other rank returns an
        empty dict.

        Note:
            Per PyTorch's documentation, ``consolidate_state_dict`` must be
            called on all ranks, so this method is expected to be called on
            every rank as well.

        Returns:
            dict: the consolidated optimizer state dict on the main process,
            otherwise an empty dict.
        """
        self.consolidate_state_dict()
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict

© Copyright 2022, mmengine contributors. Revision b2295a25.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.1
Versions
latest
stable
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.