AmpOptimWrapper

class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, **kwargs)[source]

A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.

AmpOptimWrapper provides a unified interface with OptimWrapper, so it can be used in the same way as OptimWrapper.

Warning

AmpOptimWrapper requires PyTorch >= 1.6.

Parameters
  • loss_scale (float or str or dict) –

    The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the specific arguments. Defaults to 'dynamic'.

    • "dynamic": Initialize GradScaler without any arguments.

    • float: Initialize GradScaler with init_scale set to the given value.

    • dict: Initialize GradScaler with the given keyword arguments.

  • dtype (str or torch.dtype, optional) – The data type used for autocast in AMP. If a str is given, it will be converted to torch.dtype. Valid str values are 'float16', 'bfloat16', 'float32' and 'float64'. If set to None, the default data type is used. Defaults to None. New in version 0.6.1.

  • **kwargs – Keyword arguments passed to OptimWrapper.

Warning

The dtype argument is only available with PyTorch >= 1.10.0. With older PyTorch versions, it will be ignored.

Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
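
As an illustration only, the loss_scale options map onto GradScaler construction roughly as in the following sketch (hypothetical values; a CUDA device is assumed, since AmpOptimWrapper builds on torch.cuda.amp):

    import torch
    import torch.nn as nn
    from mmengine.optim import AmpOptimWrapper

    model = nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # "dynamic": GradScaler is created with its default arguments.
    optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale='dynamic')

    # float: GradScaler is created with init_scale set to this value.
    optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale=512.0)

    # dict: the entries are forwarded to GradScaler as keyword arguments;
    # dtype requires PyTorch >= 1.10.
    optim_wrapper = AmpOptimWrapper(
        optimizer=optimizer,
        loss_scale=dict(init_scale=512.0, growth_interval=100),
        dtype='float16')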

backward(loss, **kwargs)[source]

Perform gradient backpropagation with loss_scaler.

Parameters

  • loss (torch.Tensor) – The loss of the current iteration.

  • kwargs – Keyword arguments passed to torch.Tensor.backward().
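
For illustration, a minimal sketch of a direct call (optim_wrapper, model and the input tensor are placeholders taken from the construction example above):

    # backward() scales the loss with the internal GradScaler before
    # backpropagating, roughly loss_scaler.scale(loss).backward(**kwargs).
    inputs = torch.randn(8, 4).cuda()
    loss = model(inputs).sum()
    optim_wrapper.backward(loss)
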
load_state_dict(state_dict)[source]

Load and parse the state dictionary of optimizer and loss_scaler.

If state_dict contains a "loss_scaler" key, the loss_scaler will load the corresponding state. Otherwise, only the optimizer will load the state dictionary.

Parameters

state_dict (dict) – The state dict of the optimizer and loss_scaler.
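
A hedged sketch of restoring a previously saved state (the checkpoint path is a placeholder):

    # `ckpt` is expected to be a dict produced by AmpOptimWrapper.state_dict(),
    # i.e. the optimizer state plus a "loss_scaler" entry; without that entry
    # only the optimizer state is restored.
    ckpt = torch.load('amp_optim_wrapper.pth')  # placeholder path
    optim_wrapper.load_state_dict(ckpt)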

optim_context(model)[source]

Enables the autocast context for mixed precision training, and disables gradient synchronization during gradient accumulation.

Parameters

model (nn.Module) – The training model.
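
In practice, optim_context is used as a context manager around the forward pass, as in this sketch (dataloader is a placeholder; update_params is the inherited OptimWrapper helper that chains backward, step and zero_grad):

    # The forward pass runs under autocast; gradient synchronization is
    # skipped on accumulation steps when the model is distributed.
    for data, target in dataloader:  # placeholder iterable of tensor pairs
        with optim_wrapper.optim_context(model):
            preds = model(data.cuda())
            loss = torch.nn.functional.mse_loss(preds, target.cuda())
        optim_wrapper.update_params(loss)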

state_dict()[source]

Get the state dictionary of optimizer and loss_scaler.

The returned state dictionary is the optimizer's state dictionary with an additional key named "loss_scaler".

Returns

The merged state dict of loss_scaler and optimizer.

Return type

dict
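
For example, the merged dictionary can be inspected or persisted as in this sketch (the filename is a placeholder):

    state = optim_wrapper.state_dict()
    # The optimizer's usual keys ("state", "param_groups") plus "loss_scaler".
    print(sorted(state.keys()))
    torch.save(state, 'amp_optim_wrapper.pth')  # placeholder path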

step(**kwargs)[source]

Update parameters with loss_scaler.

Parameters

kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
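
When backward, step and zero_grad are driven manually rather than through update_params, the calls compose as in this sketch (names taken from the construction example above):

    # step() unscales the gradients, skips the parameter update if the
    # scaled gradients contain infs/NaNs, and then updates the loss scale.
    loss = model(torch.randn(8, 4).cuda()).sum()
    optim_wrapper.backward(loss)
    optim_wrapper.step()
    optim_wrapper.zero_grad()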
