
OptimWrapper

class mmengine.optim.OptimWrapper(optimizer, accumulative_counts=1, clip_grad=None)

Optimizer wrapper provides a common interface for updating parameters.

Optimizer wrapper provides a unified interface for single precision training and automatic mixed precision training on different hardware. OptimWrapper encapsulates an optimizer to provide simplified interfaces for commonly used training techniques such as gradient accumulation and gradient clipping. OptimWrapper implements the basic logic of gradient accumulation and gradient clipping based on torch.optim.Optimizer. Subclasses only need to override some methods to implement mixed precision training. See AmpOptimWrapper for more information.

Parameters
  • optimizer (Optimizer) – Optimizer used to update model parameters.

  • accumulative_counts (int) – The number of iterations over which gradients are accumulated. Parameters are updated once every accumulative_counts iterations. Defaults to 1.

  • clip_grad (dict, optional) –

    If clip_grad is not None, it is used as the keyword arguments of torch.nn.utils.clip_grad_norm_() or torch.nn.utils.clip_grad_value_(). clip_grad should be a dict, and its keys can be set as follows. Defaults to None.

    If the key type is not set, or type is “norm”, the accepted keys are:

    • max_norm (float or int): Max norm of the gradients.

    • norm_type (float or int): Type of the used p-norm. Can be 'inf' for infinity norm.

    • error_if_nonfinite (bool): If True, an error is raised if the total norm of the gradients from parameters is nan, inf, or -inf. Defaults to False (will switch to True in the future).

    If the key type is set to “value”, the accepted keys are as follows:

    • clip_value (float or int): Maximum allowed absolute value of the gradients. The gradients are clipped to the range [-clip_value, clip_value].
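The two clip_grad modes behave quite differently. The following is a minimal pure-Python sketch of their semantics, using plain lists of floats in place of gradient tensors. It is not the PyTorch or mmengine implementation. The helper names clip_grad_norm, clip_grad_value, and apply_clip_grad are hypothetical. The norm branch follows the rescaling rule documented for torch.nn.utils.clip_grad_norm_(): multiply by max_norm / (total_norm + eps) and never scale up.

```python
import math


def clip_grad_norm(grads, max_norm, norm_type=2.0, error_if_nonfinite=False):
    """Scale all gradients by one common factor so their total p-norm is <= max_norm.

    Returns the clipped gradients and the total norm measured before clipping.
    """
    norm_type = float(norm_type)  # accepts the string 'inf' as well as numbers
    if norm_type == math.inf:
        total_norm = max(abs(g) for g in grads)
    else:
        total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    if error_if_nonfinite and not math.isfinite(total_norm):
        raise RuntimeError(f'total gradient norm is non-finite: {total_norm}')
    # A small epsilon avoids division by zero; the factor is capped at 1,
    # so gradients that are already small enough are left unchanged.
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    return [g * clip_coef for g in grads], total_norm


def clip_grad_value(grads, clip_value):
    """Clamp each gradient element independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def apply_clip_grad(grads, clip_grad):
    """Hypothetical dispatcher: route a clip_grad dict to norm or value clipping."""
    if clip_grad is None:
        return list(grads)
    kwargs = dict(clip_grad)  # copy so the caller's config is not mutated
    clip_type = kwargs.pop('type', 'norm')  # a missing 'type' means norm clipping
    if clip_type == 'norm':
        clipped, _ = clip_grad_norm(grads, **kwargs)
        return clipped
    if clip_type == 'value':
        return clip_grad_value(grads, **kwargs)
    raise ValueError(f'unsupported clip_grad type: {clip_type!r}')


print(apply_clip_grad([3.0, 4.0], dict(max_norm=1.0)))  # roughly [0.6, 0.8]
print(apply_clip_grad([0.5, -0.1, -3.0], dict(type='value', clip_value=0.2)))  # [0.2, -0.1, -0.2]
```

Norm clipping preserves the direction of the whole gradient vector. Value clipping acts element by element and can change that direction.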

Note

If accumulative_counts is larger than 1, running the forward pass under the optim_context() context manager before calling update_params() avoids unnecessary gradient synchronization.

Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts so that the number of parameter updates stays the same.
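The adjustment is plain arithmetic. Each optimizer step now consumes accumulative_counts iterations, so the iteration budget must grow by the same factor. The function name below is hypothetical, and the numbers are only an illustration.

```python
def iters_for_same_updates(max_iters, accumulative_counts):
    """Iterations needed to keep max_iters optimizer steps when each step
    consumes accumulative_counts iterations of accumulated gradients."""
    return max_iters * accumulative_counts


# A schedule of 10_000 iterations (one step each) switched to accumulative_counts=4
new_max_iters = iters_for_same_updates(10_000, accumulative_counts=4)
print(new_max_iters)  # 40000
```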

Note

Subclasses must ensure that _inner_count is incremented by 1 (_inner_count += 1) every time update_params() is called.

Examples

>>> # Config sample of OptimWrapper with gradient clipping by
>>> # norm.
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=1,
>>>     clip_grad=dict(max_norm=0.2))
>>> # Config sample of OptimWrapper with gradient clipping by
>>> # value.
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=1,
>>>     clip_grad=dict(type='value', clip_value=0.2))
>>> # Use OptimWrapper to update the model.
>>> import torch.nn as nn
>>> import torch
>>> from torch.optim import SGD
>>> from torch.utils.data import DataLoader
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> dataset = torch.randn(10, 1, 1)
>>> dataloader = DataLoader(dataset)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>>
>>> for data in dataloader:
>>>     # Reduce the output to a scalar so that backward() can be called.
>>>     loss = model(data).sum()
>>>     optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=3,
>>>     clip_grad=dict(max_norm=0.2))
>>> # DistributedDataParallel requires an initialized process group,
>>> # e.g. via torch.distributed.init_process_group().
>>> from torch.nn.parallel import DistributedDataParallel
>>> ddp_model = DistributedDataParallel(model)
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
>>> optim_wrapper.initialize_count_status(ddp_model, 0, len(dataloader))
>>> # If the model is a DistributedDataParallel instance (or has a
>>> # `no_sync` method), the `optim_context` context manager avoids
>>> # unnecessary gradient synchronization.
>>> for data in dataloader:
>>>     with optim_wrapper.optim_context(ddp_model):
>>>         loss = ddp_model(data).sum()
>>>     optim_wrapper.update_params(loss)
backward(loss, **kwargs)

Perform gradient back propagation.

Provides a unified backward interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the backward process.

Note

If a subclass of OptimWrapper overrides backward, it must also perform _inner_count += 1.

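Parameters
  • loss (torch.Tensor) – The loss of the current iteration.

  • kwargs – Keyword arguments passed to torch.Tensor.backward().

Return type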

None

initialize_count_status(model, init_counts, max_counts)

Initialize gradient accumulation related attributes.

OptimWrapper can be used without calling initialize_count_status(). However, consider the case where len(dataloader) == 10 and accumulative_counts == 3. Since 10 is not divisible by 3, the last iteration will not trigger optimizer.step(), resulting in one fewer parameter update.

Parameters
  • model (nn.Module) – The training model.

  • init_counts (int) – The initial value of the inner count.

  • max_counts (int) – The maximum value of the inner count.

Return type

None
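The lost update described above can be checked by enumerating the iterations at which a step fires. This is a pure-Python sketch. It assumes, consistent with the notes on backward(), that the inner count is incremented before the update check, so it equals the 1-based iteration index. It also assumes that a known max_counts forces an update on the final iteration. The function name is hypothetical.

```python
def update_iterations(num_iters, accumulative_counts, max_counts=None):
    """1-based iterations at which optimizer.step() fires under gradient accumulation.

    inner_count is assumed to be incremented by backward() before the check,
    so it equals the 1-based index of the current iteration.
    """
    steps = []
    for inner_count in range(1, num_iters + 1):
        full_group = inner_count % accumulative_counts == 0
        last_iter = max_counts is not None and inner_count == max_counts
        if full_group or last_iter:
            steps.append(inner_count)
    return steps


# len(dataloader) == 10, accumulative_counts == 3
print(update_iterations(10, 3))                 # [3, 6, 9]: iteration 10 never steps
print(update_iterations(10, 3, max_counts=10))  # [3, 6, 9, 10]
```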

property inner_count

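Get the inner iteration count of the optimizer wrapper, which is incremented once per update_params() call. With gradient accumulation, it counts iterations rather than actual optimizer steps.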

optim_context(model)

A context manager for gradient accumulation and automatic mixed precision training.

If subclasses need to enable a context for mixed precision training, e.g., AmpOptimWrapper, the corresponding context should be enabled in optim_context. Since OptimWrapper uses fp32 training by default, optim_context only enables the context that blocks unnecessary gradient synchronization during gradient accumulation.

If model has a no_sync method (which blocks gradient synchronization) and self._accumulative_counts != 1, gradient synchronization is blocked with model.no_sync() unless should_sync() returns True, i.e. unless the current iteration is the one at which the parameters are updated. In all other cases, this method enables an empty context.

Parameters

model (nn.Module) – The training model.
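The selection logic can be sketched with contextlib. This is a minimal sketch, not mmengine's source. The helper sync_context and the allow_sync flag are hypothetical stand-ins for optim_context() and the result of should_sync(). FakeDDPModel only imitates the no_sync() context manager that torch.nn.parallel.DistributedDataParallel provides.

```python
from contextlib import contextmanager


@contextmanager
def sync_context(model, allow_sync):
    """Hypothetical helper: enter model.no_sync() when gradient sync should be
    skipped and the model supports it; otherwise enter an empty context."""
    if not allow_sync and hasattr(model, 'no_sync'):
        with model.no_sync():
            yield
    else:
        yield


class FakeDDPModel:
    """Stand-in for a DistributedDataParallel model that only tracks a sync flag."""

    def __init__(self):
        self.sync_enabled = True

    @contextmanager
    def no_sync(self):
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous


model = FakeDDPModel()
with sync_context(model, allow_sync=False):
    print(model.sync_enabled)  # False: gradients accumulate locally
with sync_context(model, allow_sync=True):
    print(model.sync_enabled)  # True: this iteration updates parameters
```

A model without no_sync (for example, a plain single-process module) always gets the empty context, so the same training loop works with and without DDP.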

scale_loss(loss)

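Get the scaled loss according to _accumulative_counts, _inner_count and max_counts. Dividing the loss by the number of iterations accumulated into the next update makes the accumulated gradient an average over those iterations rather than a sum.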

Parameters

loss (torch.Tensor) – Original loss calculated by model.

Returns

Scaled loss.

Return type

torch.Tensor
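The choice of divisor is the subtle part when max_counts is not divisible by the accumulation count. Below is a pure-Python sketch of one scheme consistent with the description above. Full groups divide by accumulative_counts. The final, incomplete group divides by its own size, so its iterations are not under-weighted. Treat it as an assumption, not mmengine's exact code. The function names and the use of -1 to mean "max_counts unknown" are hypothetical.

```python
def loss_factor(inner_count, accumulative_counts, max_counts=-1):
    """Divisor for the loss of the current iteration.

    inner_count: iterations already back-propagated (before this one).
    max_counts: total iterations, or -1 if unknown (hypothetical convention).
    """
    if accumulative_counts == 1:
        return 1  # no accumulation: the loss is used as-is
    if max_counts == -1:
        return accumulative_counts
    remainder = max_counts % accumulative_counts
    divisible = max_counts - remainder
    # Iterations in the final, incomplete group share a smaller divisor.
    return accumulative_counts if inner_count < divisible else remainder


def scale_loss(loss, inner_count, accumulative_counts, max_counts=-1):
    return loss / loss_factor(inner_count, accumulative_counts, max_counts)


print([loss_factor(i, 3, max_counts=10) for i in range(10)])
# [3, 3, 3, 3, 3, 3, 3, 3, 3, 1]
```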

should_sync()

Decide whether the automatic gradient synchronization should be allowed at the current iteration.

It takes effect only when gradient accumulation is enabled, where it skips synchronization at iterations in which the parameters are not updated.

should_sync() is called by optim_context() before backward(), so self._inner_count has not yet been incremented for the current iteration. The check therefore has to account for the pending self._inner_count += 1 itself.

Returns

Whether to allow automatic gradient synchronization at the current iteration.

Return type

bool
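The look-ahead is easiest to see on concrete counts. This is a pure-Python sketch under the assumption that an update happens when the post-increment count is divisible by accumulative_counts or equals max_counts. It evaluates inner_count + 1 rather than mutating the counter. The function signature is hypothetical, and -1 stands for an unknown max_counts.

```python
def should_sync(inner_count, accumulative_counts, max_counts=-1):
    """Allow gradient sync only for the iteration that will update parameters.

    Called before backward(), so inner_count does not yet include the current
    iteration; look one count ahead instead of modifying inner_count.
    """
    next_count = inner_count + 1
    return next_count % accumulative_counts == 0 or next_count == max_counts


# accumulative_counts=3, 10 iterations in total
print([i for i in range(10) if should_sync(i, 3, max_counts=10)])  # [2, 5, 8, 9]
```

The inner counts 2, 5, 8 and 9 correspond to the 3rd, 6th, 9th and 10th iterations, which are exactly the iterations that call optimizer.step().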

should_update()

Decide whether the parameters should be updated at the current iteration.

Called by update_params() to check whether the optimizer wrapper should update parameters at the current iteration.

Returns

Whether to update parameters.

Return type

bool

step(**kwargs)

A wrapper of Optimizer.step.

Provides a unified step interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the step process.

If clip_grad_kwargs is not None, the gradients are clipped first; then the parameters are updated.

Parameters

kwargs – Keyword arguments passed to torch.optim.Optimizer.step().

Return type

None

update_params(loss, step_kwargs=None, zero_kwargs=None)

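Update parameters in the optimizer: scale the loss with scale_loss(), back-propagate it with backward(), and, if should_update() returns True, call step() and zero_grad().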

Parameters
  • loss (torch.Tensor) – A tensor for back propagation.

  • step_kwargs (dict) – Arguments for optimizer.step. Defaults to None. New in version v0.4.0.

  • zero_kwargs (dict) – Arguments for optimizer.zero_grad. Defaults to None. New in version v0.4.0.

Return type

None
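The control flow of update_params() can be modelled without PyTorch. The toy class below is hypothetical. It treats the scaled loss itself as the gradient contribution and has step() simply record the accumulated gradient. It also simplifies scale_loss() to always divide by accumulative_counts, ignoring the remainder case, and it omits step_kwargs and zero_kwargs. It shows where _inner_count is incremented and when step() and zero_grad() fire.

```python
class ToyOptimWrapper:
    """Pure-Python stand-in that mimics the update_params() control flow."""

    def __init__(self, accumulative_counts=1, max_counts=-1):
        self.accumulative_counts = accumulative_counts
        self.max_counts = max_counts
        self.inner_count = 0
        self.grad = 0.0
        self.steps = []  # accumulated gradient seen by each step() call

    def scale_loss(self, loss):
        # Simplified: ignores the smaller divisor of a final partial group.
        return loss / self.accumulative_counts

    def backward(self, loss):
        self.grad += loss      # accumulate instead of overwriting
        self.inner_count += 1  # every backward() must advance the counter

    def should_update(self):
        return (self.inner_count % self.accumulative_counts == 0
                or self.inner_count == self.max_counts)

    def step(self):
        self.steps.append(self.grad)

    def zero_grad(self):
        self.grad = 0.0

    def update_params(self, loss):
        loss = self.scale_loss(loss)
        self.backward(loss)
        if self.should_update():
            self.step()
            self.zero_grad()


wrapper = ToyOptimWrapper(accumulative_counts=2)
for loss in [1.0, 3.0, 2.0, 2.0]:
    wrapper.update_params(loss)
print(wrapper.steps)  # [2.0, 2.0]: each step sees the mean of two losses
```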

zero_grad(**kwargs)

A wrapper of Optimizer.zero_grad.

Provides a unified zero_grad interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic.

Parameters

kwargs – Keyword arguments passed to torch.optim.Optimizer.zero_grad().

Return type

None
