Shortcuts

Source code for mmengine.optim.optimizer.optimizer_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer

from mmengine.logging import MessageHub, print_log
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils.dl_utils import has_batch_norm
from .base import BaseOptimWrapper


@OPTIM_WRAPPERS.register_module()
class OptimWrapper(BaseOptimWrapper):
    """Optimizer wrapper provides a common interface for updating parameters.

    Optimizer wrapper provides a unified interface for single precision
    training and automatic mixed precision training with different hardware.
    OptimWrapper encapsulates optimizer to provide simplified interfaces
    for commonly used training techniques such as gradient accumulation and
    gradient clipping. ``OptimWrapper`` implements the basic logic of
    gradient accumulation and gradient clipping based on
    ``torch.optim.Optimizer``. The subclasses only need to override some
    methods to implement the mixed precision training. See more information
    in :class:`AmpOptimWrapper`.

    Args:
        optimizer (Optimizer): Optimizer used to update model parameters.
        accumulative_counts (int): The number of iterations to accumulate
            gradients. The parameters will be updated per
            ``accumulative_counts``. Must be >= 1. Defaults to 1.
        clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
            the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
            :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be
            a dict, and the keys could be set as follows:

            If the key ``type`` is not set, or ``type`` is "norm",
            the accepted keys are as follows:

            - max_norm (float or int): Max norm of the gradients.
            - norm_type (float or int): Type of the used p-norm. Can be
              ``'inf'`` for infinity norm.
            - error_if_nonfinite (bool): If True, an error is thrown if
              the total norm of the gradients from :attr:`parameters` is
              ``nan``, ``inf``, or ``-inf``. Defaults to False (will switch
              to True in the future)

            If the key ``type`` is set to "value", the accepted keys are as
            follows:

            - clip_value (float or int): maximum allowed value of the
              gradients. The gradients are clipped in the range
              ``(-clip_value, +clip_value)``.

            The given dict is copied before ``type`` is removed from it, so
            the caller's dict is left unmodified. Defaults to None.

    Note:
        If ``accumulative_counts`` is larger than 1, running the model's
        forward pass under the context of :meth:`optim_context` (and then
        calling :meth:`update_params`) avoids unnecessary gradient
        synchronization.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        The subclass should ensure that once :meth:`update_params` is called,
        ``_inner_count += 1`` is automatically performed.

    Examples:
        >>> # Config sample of OptimWrapper and enable clipping gradient by
        >>> # norm.
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     accumulative_counts=1,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> # Config sample of OptimWrapper and enable clipping gradient by
        >>> # value.
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     accumulative_counts=1,
        >>>     clip_grad=dict(type='value', clip_value=0.2))
        >>> # Use OptimWrapper to update model.
        >>> import torch.nn as nn
        >>> import torch
        >>> from torch.optim import SGD
        >>> from torch.utils.data import DataLoader
        >>> from mmengine.optim import OptimWrapper
        >>>
        >>> model = nn.Linear(1, 1)
        >>> dataset = torch.randn(10, 1, 1)
        >>> dataloader = DataLoader(dataset)
        >>> optimizer = SGD(model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(optimizer)
        >>>
        >>> for data in dataloader:
        >>>     loss = model(data).sum()
        >>>     optim_wrapper.update_params(loss)
        >>> # Enable gradient accumulation
        >>> from torch.nn.parallel import DistributedDataParallel
        >>> ddp_model = DistributedDataParallel(model)
        >>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(
        >>>     optimizer, accumulative_counts=3,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> optim_wrapper.initialize_count_status(
        >>>     ddp_model, 0, len(dataloader))
        >>> # If model is a subclass instance of DistributedDataParallel,
        >>> # `optim_context` context manager can avoid unnecessary gradient
        >>> # synchronize.
        >>> for iter, data in enumerate(dataloader):
        >>>     with optim_wrapper.optim_context(ddp_model):
        >>>         loss = ddp_model(data).sum()
        >>>     optim_wrapper.update_params(loss)
    """

    def __init__(self,
                 optimizer: Optimizer,
                 accumulative_counts: int = 1,
                 clip_grad: Optional[dict] = None):
        assert accumulative_counts > 0, (
            '_accumulative_counts at least greater than or equal to 1')
        self._accumulative_counts = accumulative_counts
        self.optimizer = optimizer

        if clip_grad is not None:
            # `clip_grad` should be a non-empty dict.
            assert isinstance(clip_grad, dict) and clip_grad, (
                'If `clip_grad` is not None, it should be a `dict` '
                'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
                'or clip_grad_value_`.')
            # Work on a shallow copy: popping `type` from the caller's dict
            # would otherwise change the clipping mode the next time the same
            # config dict is used to build a wrapper.
            clip_grad = dict(clip_grad)
            clip_type = clip_grad.pop('type', 'norm')
            if clip_type == 'norm':
                self.clip_func = torch.nn.utils.clip_grad_norm_
                self.grad_name = 'grad_norm'
            elif clip_type == 'value':
                self.clip_func = torch.nn.utils.clip_grad_value_
                self.grad_name = 'grad_value'
            else:
                raise ValueError('type of clip_grad should be "norm" or '
                                 f'"value" but got {clip_type}')
            assert clip_grad, ('`clip_grad` should contain other arguments '
                               'besides `type`. The arguments should match '
                               'with the `torch.nn.utils.clip_grad_norm_` or '
                               'clip_grad_value_`')
        # Always assigned (None when clipping is disabled) because `step`
        # reads it unconditionally.
        self.clip_grad_kwargs = clip_grad

        # Used to update `grad_norm` log message.
        self.message_hub = MessageHub.get_current_instance()
        self._inner_count = 0
        # `_max_counts` means the total number of parameter updates. It
        # ensures that the gradient of the last few iterations will not be
        # lost when the `_max_counts` is not divisible by
        # `accumulative_counts`.
        self._max_counts = -1
        # The `_remainder_counts` is used for calculating loss factor at the
        # last few iterations. If `_max_counts` has not been initialized,
        # the loss factor will always be the same as `_accumulative_counts`.
        self._remainder_counts = -1

        # When the optimizer has more than one parameter group, keep a
        # standalone settings dict holding a dummy parameter plus the
        # optimizer's default hyper-parameters (`optimizer.defaults`).
        # NOTE(review): this class never reads `base_param_settings`;
        # presumably it is consumed by the base class or by schedulers to
        # track base hyper-parameters -- confirm against callers.
        if len(optimizer.param_groups) > 1:
            self.base_param_settings = {
                'params': torch.tensor([0.0], dtype=torch.float)
            }
            self.base_param_settings.update(**self.optimizer.defaults)
        else:
            self.base_param_settings = None  # type: ignore

    def update_params(  # type: ignore
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Update parameters in :attr:`optimizer`.

        The loss is scaled by :meth:`scale_loss` and back-propagated; the
        optimizer steps and gradients are zeroed only when
        :meth:`should_update` returns True.

        Args:
            loss (torch.Tensor): A tensor for back propagation.
            step_kwargs (dict, optional): Arguments for optimizer.step.
                Defaults to None.
                New in version v0.4.0.
            zero_kwargs (dict, optional): Arguments for optimizer.zero_grad.
                Defaults to None.
                New in version v0.4.0.
        """
        if step_kwargs is None:
            step_kwargs = {}
        if zero_kwargs is None:
            zero_kwargs = {}
        loss = self.scale_loss(loss)
        self.backward(loss)
        # Update parameters only if `self._inner_count` is divisible by
        # `self._accumulative_counts` or `self._inner_count` equals to
        # `self._max_counts`
        if self.should_update():
            self.step(**step_kwargs)
            self.zero_grad(**zero_kwargs)

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation.

        Provide unified ``backward`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic. For example, ``torch.cuda.amp`` require some extra
        operation on GradScaler during backward process.

        Note:
            If subclasses inherit from ``OptimWrapper`` override
            ``backward``, ``_inner_count +=1`` must be implemented.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        loss.backward(**kwargs)
        self._inner_count += 1

    def zero_grad(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.zero_grad``.

        Provide unified ``zero_grad`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.zero_grad`.
        """
        self.optimizer.zero_grad(**kwargs)

    def step(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.step``.

        Provide unified ``step`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic. For example, ``torch.cuda.amp`` require some extra
        operation on ``GradScaler`` during step process.

        Clip grad if :attr:`clip_grad_kwargs` is not None, and then update
        parameters.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        if self.clip_grad_kwargs:
            self._clip_grad()
        self.optimizer.step(**kwargs)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """A Context for gradient accumulation and automatic mix precision
        training.

        If subclasses need to enable the context for mix precision training,
        e.g., :class:`AmpOptimWrapper`, the corresponding context should be
        enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32
        training, ``optim_context`` will only enable the context for
        blocking the unnecessary gradient synchronization during gradient
        accumulation.

        If ``model`` has a ``no_sync`` method (which blocks gradient
        synchronization) and :meth:`should_sync` returns False -- i.e. the
        upcoming iteration will not update parameters -- the body runs under
        ``model.no_sync()``. Otherwise, this method enables an empty context.
        With ``accumulative_counts == 1`` every iteration updates parameters,
        so the context is always empty.

        Args:
            model (nn.Module): The training model.
        """
        # During gradient accumulation process, the gradient synchronize
        # should only happen before updating parameters.
        if not self.should_sync() and hasattr(model, 'no_sync'):
            with model.no_sync():
                yield
        else:
            yield

    def _clip_grad(self) -> None:
        """Clip the gradients of parameters.

        Only parameters that require grad and currently hold a gradient are
        clipped. For norm clipping, the returned total norm is logged to the
        message hub under ``train/<grad_name>``.
        """
        params: List[torch.Tensor] = []
        for param_group in self.optimizer.param_groups:
            params.extend(param_group['params'])

        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            grad = self.clip_func(params, **self.clip_grad_kwargs)
            # `torch.nn.utils.clip_grad_value_` will return None.
            if grad is not None:
                self.message_hub.update_scalar(f'train/{self.grad_name}',
                                               float(grad))

    def initialize_count_status(self, model: nn.Module, init_counts: int,
                                max_counts: int) -> None:
        """Initialize gradient accumulation related attributes.

        ``OptimWrapper`` can be used without calling
        ``initialize_count_status``. However, consider the case of
        ``len(dataloader) == 10`` and ``accumulative_counts == 3``. Since 10
        is not divisible by 3, the last iteration will not trigger
        ``optimizer.step()``, resulting in one less parameter updating.

        Warnings are logged when ``init_counts`` is not divisible by
        ``accumulative_counts`` and when ``model`` contains BatchNorm layers
        while gradient accumulation is enabled.

        Args:
            model (nn.Module): Training model
            init_counts (int): The initial value of the inner count.
            max_counts (int): The maximum value of the inner count.
        """
        self._inner_count = init_counts
        self._max_counts = max_counts
        if self._inner_count % self._accumulative_counts != 0:
            print_log(
                'Resumed iteration number is not divisible by '
                '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
                'which means the gradient of some iterations is lost and the '
                'result may be influenced slightly.',
                logger='current',
                level=logging.WARNING)

        if has_batch_norm(model) and self._accumulative_counts > 1:
            print_log(
                'Gradient accumulative may slightly decrease '
                'performance because the model has BatchNorm layers.',
                logger='current',
                level=logging.WARNING)
        # Remainder of `_max_counts` divided by `_accumulative_counts`
        self._remainder_counts = self._max_counts % self._accumulative_counts

    def should_update(self) -> bool:
        """Decide whether the parameters should be updated at the current
        iteration.

        Called by :meth:`update_params` and check whether the optimizer
        wrapper should update parameters at current iteration.

        Returns:
            bool: Whether to update parameters.
        """
        return (self._inner_count % self._accumulative_counts == 0
                or self._inner_count == self._max_counts)

    def should_sync(self) -> bool:
        """Decide whether the automatic gradient synchronization should be
        allowed at the current iteration.

        It takes effect when gradient accumulation is used to skip
        synchronization at the iterations where the parameter is not updated.

        Since ``should_sync`` is called by :meth:`optim_context`, and it is
        called before :meth:`backward` which means ``self._inner_count += 1``
        has not happened yet. Therefore, ``self._inner_count += 1`` should be
        performed manually here.

        Returns:
            bool: True if gradient synchronization should be allowed (the
            upcoming iteration updates parameters), False if it can be
            blocked.
        """
        return ((self._inner_count + 1) % self._accumulative_counts == 0
                or (self._inner_count + 1) == self._max_counts)

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Get scaled loss according to ``_accumulative_counts``,
        ``_inner_count`` and max_counts.

        Args:
            loss (torch.Tensor): Original loss calculated by model.

        Returns:
            loss (torch.Tensor): Scaled loss.
        """
        if self._accumulative_counts == 1:
            # update parameters without gradient accumulation. The gradient
            # should not be rescaled and `loss_factor=1`.
            loss_factor = 1
        elif self._max_counts == -1:
            # `initialize_count_status` was never called: always divide by
            # the accumulation window size.
            loss_factor = self._accumulative_counts
        else:
            # if `self._accumulative_counts > 1`, the gradient needs to be
            # rescaled and accumulated. In most cases, `loss_factor` equals to
            # `self._accumulative_counts`. However, `self._max_counts` may not
            # be divisible by `self._accumulative_counts`, so the
            # `loss_scale` for the last few iterations needs to be
            # recalculated.
            if self._inner_count < self._max_counts - self._remainder_counts:
                loss_factor = self._accumulative_counts
            else:
                loss_factor = self._remainder_counts
            assert loss_factor > 0, (
                'loss_factor should be larger than zero! This error could '
                'happened when initialize_iter_status called with an '
                'error `init_counts` or `max_counts`')

        loss = loss / loss_factor
        return loss

    @property
    def inner_count(self):
        """Get the number of updating parameters of optimizer wrapper."""
        return self._inner_count

    def __repr__(self):
        wrapper_info = (f'Type: {type(self).__name__}\n'
                        f'_accumulative_counts: {self._accumulative_counts}\n'
                        'optimizer: \n')
        optimizer_str = repr(self.optimizer) + '\n'
        return wrapper_info + optimizer_str

© Copyright 2022, mmengine contributors. Revision 2c4516c6.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.