AmpOptimWrapper¶
- class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, use_fsdp=False, **kwargs)[source]¶
A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.
AmpOptimWrapper provides a unified interface with OptimWrapper, so AmpOptimWrapper can be used in the same way as OptimWrapper.
Warning
AmpOptimWrapper requires PyTorch >= 1.6.
- Parameters:
loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the specific arguments. Defaults to "dynamic". Each form is shown in the construction sketch below.
"dynamic": Initialize GradScaler without any arguments.
float: Initialize GradScaler with init_scale.
dict: Initialize GradScaler with a more detailed configuration.
dtype (str or torch.dtype, optional) – The data type to autocast in amp. If a str is given, it will be converted to torch.dtype. Valid str formats are 'float16', 'bfloat16', 'float32' and 'float64'. If set to None, the default data type will be used. Defaults to None. New in version 0.6.1.
use_fsdp (bool) – Use ShardedGradScaler when it is True. It should be enabled when using FullyShardedDataParallel. Defaults to False. New in version 0.8.0.
**kwargs – Keyword arguments passed to OptimWrapper.
Warning
The dtype argument is only available with PyTorch >= 1.10.0. If you use an older version of PyTorch, it will be ignored.
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
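For illustration only, a minimal sketch of constructing AmpOptimWrapper directly around a plain torch.optim.SGD optimizer. The toy model, learning rate and GradScaler settings are placeholder assumptions, and a CUDA device is required:
>>> import torch
>>> import torch.nn as nn
>>> from mmengine.optim import AmpOptimWrapper
>>> model = nn.Linear(4, 2).cuda()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> # "dynamic" loss scaling (the default); dtype needs PyTorch >= 1.10.0
>>> optim_wrapper = AmpOptimWrapper(
...     optimizer=optimizer, loss_scale='dynamic', dtype='float16')
>>> # A fixed initial scale can be given as a float ...
>>> optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale=512.)
>>> # ... or the GradScaler can be configured in full with a dict.
>>> optim_wrapper = AmpOptimWrapper(
...     optimizer=optimizer,
...     loss_scale=dict(init_scale=512., growth_interval=1000))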
- backward(loss, **kwargs)[source]¶
Perform gradient back propagation with loss_scaler.
- Parameters:
loss (torch.Tensor) – The loss of the current iteration.
kwargs – Keyword arguments passed to torch.Tensor.backward().
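As a sketch, reusing the model and optim_wrapper assumed in the construction example above, backward() backpropagates the loss after it has been scaled by the internal loss_scaler:
>>> # Assumes `model` and `optim_wrapper` from the construction sketch above.
>>> inputs = torch.randn(8, 4).cuda()
>>> loss = model(inputs).abs().mean()
>>> optim_wrapper.backward(loss)  # the scaled loss is backpropagated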
- load_state_dict(state_dict)[source]¶
Load and parse the state dictionary of optimizer and loss_scaler.
If state_dict contains "loss_scaler", the loss_scaler will load the corresponding keys. Otherwise, only the optimizer will load the state dictionary.
- Parameters:
state_dict (dict) – The state dict of optimizer and loss_scaler.
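A sketch of restoring a previously saved state; the checkpoint path is hypothetical and is assumed to hold the dict returned by state_dict():
>>> # Assumes `optim_wrapper` is an AmpOptimWrapper as constructed above.
>>> state = torch.load('amp_optim_wrapper.pth')  # hypothetical checkpoint path
>>> optim_wrapper.load_state_dict(state)  # restores optimizer and loss_scaler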
- optim_context(model)[source]¶
Enables the autocast context for mixed precision training, and the context for disabling gradient synchronization during gradient accumulation.
- Parameters:
model (nn.Module) – The training model.
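A sketch of a forward pass under optim_context, again reusing the objects assumed above; update_params() is inherited from OptimWrapper and bundles backward, step and zero_grad:
>>> # Assumes `model`, `inputs` and `optim_wrapper` from the sketches above.
>>> with optim_wrapper.optim_context(model):
...     outputs = model(inputs)  # forward runs under the autocast context
>>> loss = outputs.abs().mean()
>>> optim_wrapper.update_params(loss)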
- state_dict()[source]¶
Get the state dictionary of optimizer and loss_scaler.
The returned state dictionary is the optimizer's state dictionary with an additional key named "loss_scaler".
- Returns:
The merged state dict of loss_scaler and optimizer.
- Return type:
dict
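A sketch of inspecting and saving the merged state; the file name is a placeholder and the optimizer part keeps the keys of torch.optim.Optimizer.state_dict():
>>> state = optim_wrapper.state_dict()
>>> 'loss_scaler' in state  # merged alongside the optimizer's own keys
True
>>> torch.save(state, 'amp_optim_wrapper.pth')  # hypothetical checkpoint path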
- step(**kwargs)[source]¶
Update parameters with loss_scaler.
- Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
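Putting the methods together, a minimal hand-written training loop as a sketch; inside a Runner this loop is driven for you, and update_params() is the usual shortcut for the last three calls:
>>> # Assumes `model` and `optim_wrapper` from the construction sketch above.
>>> for _ in range(10):
...     with optim_wrapper.optim_context(model):
...         loss = model(torch.randn(8, 4).cuda()).abs().mean()
...     optim_wrapper.backward(loss)  # scaled backward
...     optim_wrapper.step()          # unscales gradients, then optimizer step
...     optim_wrapper.zero_grad()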