AmpOptimWrapper

class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, use_fsdp=False, **kwargs)[source]

A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.

AmpOptimWrapper provides the same interface as OptimWrapper, so it can be used in the same way.

Warning

AmpOptimWrapper requires PyTorch >= 1.6.

Parameters:
  • loss_scale (float or str or dict) –

    The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the detailed arguments. Defaults to 'dynamic'.

    • "dynamic": Initialize GradScaler without any arguments.

    • float: Initialize GradScaler with init_scale set to the given value.

    • dict: Initialize GradScaler with a more detailed configuration.

  • dtype (str or torch.dtype, optional) – The data type to autocast in amp. If a str is given, it will be converted to torch.dtype. Valid strings are 'float16', 'bfloat16', 'float32' and 'float64'. If set to None, the default data type will be used. Defaults to None. New in version 0.6.1.

  • use_fsdp (bool) – Use ShardedGradScaler when it is True. It should be enabled when using FullyShardedDataParallel. Defaults to False. New in version 0.8.0.

  • **kwargs – Keyword arguments passed to OptimWrapper.

Warning

The dtype argument is only available with PyTorch >= 1.10.0; with older versions of PyTorch it will be ignored.

Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
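A minimal construction sketch; the toy model, the SGD optimizer and the concrete loss_scale values are illustrative assumptions, not requirements:

    import torch
    import torch.nn as nn

    from mmengine.optim import AmpOptimWrapper

    # Toy model and optimizer used only for illustration; AMP assumes a
    # device that supports mixed precision (typically CUDA).
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Dynamic loss scaling with the default autocast dtype.
    optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale='dynamic')

    # Alternatives: a float sets GradScaler's init_scale, a dict forwards
    # arbitrary GradScaler arguments.
    # optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale=512.)
    # optim_wrapper = AmpOptimWrapper(
    #     optimizer=optimizer,
    #     loss_scale=dict(init_scale=512., growth_interval=1000))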

backward(loss, **kwargs)[source]

Perform gradient backpropagation with loss_scaler.

Parameters:

  • loss (torch.Tensor) – The loss of the current iteration. It is scaled by loss_scaler before backpropagation.

  • kwargs – Keyword arguments passed to torch.Tensor.backward().
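A minimal sketch of calling backward directly; optim_wrapper and loss are assumed to come from a training step such as the one shown under optim_context() below:

    # loss is assumed to be a scalar tensor computed in the forward pass.
    # Internally the loss is multiplied by the current loss scale before
    # .backward() runs, so gradients stay scaled until step() unscales them.
    optim_wrapper.backward(loss)
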
load_state_dict(state_dict)[source]

Load and parse the state dictionary of optimizer and loss_scaler.

If state_dict contains a "loss_scaler" entry, the loss_scaler will load the corresponding state; otherwise, only the optimizer will load the state dictionary.

Parameters:

state_dict (dict) – The state dict of the optimizer and loss_scaler.
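A sketch of restoring the wrapper from a checkpoint; the file name 'amp_checkpoint.pth' is a placeholder assumption:

    import torch

    # If the loaded dict carries a "loss_scaler" entry (as produced by
    # state_dict() below), the GradScaler state is restored as well;
    # otherwise only the optimizer state is loaded.
    ckpt = torch.load('amp_checkpoint.pth')
    optim_wrapper.load_state_dict(ckpt)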

optim_context(model)[source]

Enables the autocast context for mixed precision training and, during gradient accumulation, the context that disables gradient synchronization.

Parameters:

model (nn.Module) – The training model.
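A sketch of one training iteration under the context; the wrapper, the model and the (inputs, targets) batch are assumed to exist already, and the MSE loss is only illustrative:

    import torch.nn.functional as F

    # optim_context() enables torch.autocast for the forward pass and, when
    # gradient accumulation is configured, skips gradient synchronization on
    # non-update iterations of distributed models.
    with optim_wrapper.optim_context(model):
        preds = model(inputs)              # runs in mixed precision
        loss = F.mse_loss(preds, targets)  # illustrative loss

    optim_wrapper.backward(loss)  # scaled backward pass
    optim_wrapper.step()          # unscale gradients and update parameters
    optim_wrapper.zero_grad()

In typical MMEngine training loops, optim_wrapper.update_params(loss) wraps the last three calls into one.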

state_dict()[source]

Get the state dictionary of optimizer and loss_scaler.

The returned state dictionary is the optimizer's state dictionary with an additional key named "loss_scaler".

Returns:

The merged state dict of loss_scaler and optimizer.

Return type:

dict
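A sketch of checkpointing the wrapper; the file name is again a placeholder assumption:

    import torch

    state = optim_wrapper.state_dict()
    # The optimizer's own state plus a "loss_scaler" entry, so resuming
    # continues with the current loss scale.
    assert 'loss_scaler' in state
    torch.save(state, 'amp_checkpoint.pth')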

step(**kwargs)[source]

Update parameters with loss_scaler.

Parameters:

kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
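A minimal sketch; gradients are assumed to have been produced by backward() beforehand:

    # step() routes through the GradScaler: gradients are unscaled, the
    # optimizer update is skipped if inf/NaN gradients are found, and the
    # loss scale is adjusted for the next iteration.
    optim_wrapper.step()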