AmpOptimWrapper
- class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', **kwargs)

  A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp. AmpOptimWrapper provides a unified interface with OptimWrapper, so it can be used in the same way as OptimWrapper.

  Warning
  AmpOptimWrapper requires PyTorch >= 1.6.

  - Parameters
    - loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the specific arguments. Defaults to "dynamic".
      - "dynamic": initialize GradScaler without any arguments.
      - float: initialize GradScaler with init_scale set to this value.
      - dict: initialize GradScaler with this more detailed configuration.
    - **kwargs – Keyword arguments passed to OptimWrapper.
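The three accepted forms of loss_scale map onto the GradScaler constructor roughly as follows. This is a minimal sketch of the dispatch only: the stand-in GradScaler class below is hypothetical and exists so the example runs without PyTorch; the real wrapper builds a torch.cuda.amp.GradScaler.

```python
class GradScaler:
    """Stand-in for torch.cuda.amp.GradScaler (defaults mirror PyTorch's
    documented defaults); used only to illustrate the dispatch."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.init_scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval


def build_scaler(loss_scale):
    """Sketch of how a loss_scale argument could select GradScaler arguments."""
    if loss_scale == 'dynamic':
        # "dynamic": GradScaler with its default (dynamic) scaling behaviour.
        return GradScaler()
    if isinstance(loss_scale, float):
        # float: used as the initial scale.
        return GradScaler(init_scale=loss_scale)
    if isinstance(loss_scale, dict):
        # dict: forwarded as keyword arguments for finer control.
        return GradScaler(**loss_scale)
    raise TypeError(
        f'loss_scale must be "dynamic", float or dict, got {loss_scale!r}')
```

For example, build_scaler(512.0) yields a scaler whose initial scale is 512.0, while a dict such as {'init_scale': 1024.0, 'growth_interval': 100} tunes several knobs at once.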
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.

- backward(loss, **kwargs)
  Perform gradient back propagation with loss_scaler.

  - Parameters
    - loss (torch.Tensor) – The loss of the current iteration.
    - kwargs – Keyword arguments passed to torch.Tensor.backward().
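The point of routing backward() through the loss scaler is numerical: back-propagating scale * loss multiplies every gradient by scale, lifting tiny fp16 gradients away from underflow, and the scale is divided back out before the optimizer update. A toy arithmetic sketch of that round trip (the function names here are illustrative, not part of the mmengine API):

```python
def scaled_grad(grad, scale):
    # Back-propagating ``scale * loss`` multiplies every gradient by
    # ``scale``, keeping small fp16 gradients representable.
    return grad * scale


def unscaled_grad(scaled, scale):
    # Before the optimizer update, the gradients are unscaled again,
    # so the parameter update is unaffected by the scaling.
    return scaled / scale
```

With a power-of-two scale such as 1024.0 the round trip is exact: a gradient of 1e-7 is scaled up, back-propagated, and recovered unchanged.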
- load_state_dict(state_dict)
  Load and parse the state dictionary of optimizer and loss_scaler.

  If state_dict contains the "loss_scaler" key, loss_scaler will load the corresponding state. Otherwise, only the optimizer will load the state dictionary.

  - Parameters
    - state_dict (dict) – The state dict of optimizer and loss_scaler.
- optim_context(model)

  Enable the context for mixed precision training, and enable the context for disabling gradient synchronization during gradient accumulation.

  - Parameters
    - model (nn.Module) – The training model.
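In other words, optim_context() stacks two context managers: the gradient-accumulation context inherited from OptimWrapper and the mixed-precision autocast context. A schematic of that nesting with contextlib, using stand-in contexts (the real method enters the parent's optim_context and a torch autocast context, which are not used here so the sketch runs anywhere):

```python
from contextlib import contextmanager

events = []


@contextmanager
def no_sync_sketch():
    # Stand-in for the parent OptimWrapper context that can skip
    # gradient synchronization during accumulation.
    events.append('no_sync on')
    yield
    events.append('no_sync off')


@contextmanager
def autocast_sketch():
    # Stand-in for the mixed-precision autocast context.
    events.append('autocast on')
    yield
    events.append('autocast off')


@contextmanager
def optim_context_sketch():
    # Enter the accumulation context first, then autocast; they unwind
    # in reverse order around the forward pass.
    with no_sync_sketch(), autocast_sketch():
        yield


with optim_context_sketch():
    events.append('forward')  # the model's forward pass runs here
```

The recorded order shows the forward pass running inside both contexts, with each exiting in reverse order of entry.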
- state_dict()

  Get the state dictionary of optimizer and loss_scaler.

  Based on the state dictionary of the optimizer, the returned state dictionary will add a key named "loss_scaler".

  - Returns
    The merged state dict of loss_scaler and optimizer.
  - Return type
    dict
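The merge is the mirror image of load_state_dict(): a sketch with plain dicts (state_dict_sketch is an illustrative name, not the mmengine API; the real method calls the optimizer's and scaler's own state_dict() methods):

```python
def state_dict_sketch(optimizer_state, scaler_state):
    """Sketch of merging optimizer and loss-scaler state into one dict."""
    # Start from the optimizer's state dict and attach the scaler's
    # state under the extra "loss_scaler" key.
    merged = dict(optimizer_state)
    merged['loss_scaler'] = scaler_state
    return merged
```

Because the scaler state lives under a single key, load_state_dict() can later detect its presence and restore both objects from one checkpoint entry.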
- step(**kwargs)

  Update parameters with loss_scaler.

  - Parameters
    - kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
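Putting the pieces together: in a typical iteration the forward pass runs inside optim_context(), and update_params() (inherited from OptimWrapper) then drives backward(), step(), and zero_grad(). The sketch below records that call order with a stand-in SketchWrapper class (hypothetical, so it runs without torch or mmengine; real code would call these methods on an AmpOptimWrapper built around a torch optimizer):

```python
from contextlib import contextmanager


class SketchWrapper:
    """Stand-in for AmpOptimWrapper that records one iteration's call order."""

    def __init__(self):
        self.calls = []

    @contextmanager
    def optim_context(self, model):
        self.calls.append('optim_context:enter')
        yield
        self.calls.append('optim_context:exit')

    def backward(self, loss):
        self.calls.append('backward')

    def step(self):
        self.calls.append('step')

    def zero_grad(self):
        self.calls.append('zero_grad')

    def update_params(self, loss):
        # OptimWrapper.update_params drives backward, step and zero_grad
        # (step/zero_grad fire only on gradient-accumulation boundaries).
        self.backward(loss)
        self.step()
        self.zero_grad()


wrapper = SketchWrapper()
with wrapper.optim_context(model=None):
    wrapper.calls.append('forward')  # loss = model(batch) would run here
wrapper.update_params(loss=None)
```

Note that the parameter update happens after the context exits: only the forward pass needs the autocast and no-sync contexts.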