Source code for mmengine.optim.optimizer.apex_optimizer_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

    import apex.amp as apex_amp
except ImportError:
    apex_amp = None

@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Model's parameters and
            buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        # The options below are stored as-is and forwarded verbatim to
        # ``apex_amp.initialize`` the first time ``optim_context`` runs.
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        # apex_amp state received by ``load_state_dict`` before apex_amp was
        # initialized; applied (and cleared) in ``optim_context``.
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation on the loss scaled by
        ``apex_amp.scale_loss``.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If state_dict contains "apex_amp", the :attr:`apex_amp` will
        load the corresponding keys. Otherwise, only the :attr:`optimizer`
        will load the state dictionary.

        The ``state_dict`` passed in is not modified: the "apex_amp" entry is
        removed from a shallow copy before the rest is given to the optimizer.

        Note:
            :meth:`load_state_dict` should be called after
            `apex_amp.initialize` is called. If it is called earlier, the
            "apex_amp" part is cached and loaded once :meth:`optim_context`
            has initialized apex_amp.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`
        """
        # Shallow copy so the caller's dict (e.g. a loaded checkpoint) keeps
        # its "apex_amp" entry and can be reused after this call.
        state_dict = dict(state_dict)
        if 'apex_amp' in state_dict:
            apex_state = state_dict.pop('apex_amp')
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(apex_state)
            else:
                self._apex_amp_state_dict = apex_state
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and enables the
        context for disabling gradient synchronization during gradient
        accumulation context.

        On first entry this also runs ``apex_amp.initialize`` on the model
        (unwrapped from any model wrapper) and the optimizer, then applies
        any apex_amp state cached by :meth:`load_state_dict`.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # once an optimizer has been passed through apex_amp.initialize,
            # it carries an "_amp_stash" attribute, so this runs only once
            if not hasattr(self.optimizer, '_amp_stash'):
                # initialize the bare module rather than its wrapper
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # loading apex_amp state_dict after initialization of apex_amp
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield

© Copyright 2022, mmengine contributors. Revision b2295a25.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.1
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.