AmpOptimWrapper
- class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', **kwargs)[source]

A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp. AmpOptimWrapper provides a unified interface with OptimWrapper, so AmpOptimWrapper can be used in the same way as OptimWrapper.

Warning

AmpOptimWrapper requires PyTorch >= 1.6.

- Parameters
  - loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the specific arguments. Defaults to "dynamic".

    - "dynamic": Initialize GradScaler without any arguments.
    - float: Initialize GradScaler with init_scale.
    - dict: Initialize GradScaler with a more detailed configuration.

  - **kwargs – Keyword arguments passed to OptimWrapper.
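The three accepted forms of loss_scale can be pictured as a small dispatch onto GradScaler keyword arguments. The helper below is a hedged sketch for illustration (parse_loss_scale is a hypothetical name, not part of mmengine):

```python
def parse_loss_scale(loss_scale):
    """Map a loss_scale value to GradScaler constructor kwargs (sketch)."""
    if loss_scale == 'dynamic':
        return {}                           # GradScaler() with its defaults
    if isinstance(loss_scale, float):
        return {'init_scale': loss_scale}   # fixed initial scale
    if isinstance(loss_scale, dict):
        return dict(loss_scale)             # full configuration passed through
    raise TypeError(
        f'loss_scale must be "dynamic", a float or a dict, got {loss_scale!r}')

print(parse_loss_scale('dynamic'))  # {}
print(parse_loss_scale(512.0))      # {'init_scale': 512.0}
print(parse_loss_scale({'init_scale': 2.0 ** 16, 'growth_interval': 1000}))
```

A plain float is the simplest choice when dynamic scaling is too aggressive for a model; the dict form exposes the remaining GradScaler knobs such as growth_factor and growth_interval.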
Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.

- backward(loss, **kwargs)[source]
Perform gradient back propagation with loss_scaler.

- Parameters

  - loss (torch.Tensor) – The loss of the current iteration.
  - kwargs – Keyword arguments passed to torch.Tensor.backward().
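As a rough illustration of why backward routes the loss through loss_scaler: gradients are computed on scale * loss so that small fp16 gradients do not underflow to zero, and they are divided back by scale before the optimizer update. The arithmetic below is a toy sketch, not mmengine code:

```python
# Toy sketch of loss scaling: for the loss 0.5 * w**2 the gradient is w.
# Backpropagating scale * loss multiplies every gradient by scale.
SCALE = 2.0 ** 16

def grad_of_loss(w):
    return w  # d/dw (0.5 * w**2) = w

w = 1e-6                               # a gradient this small underflows in fp16
scaled_grad = SCALE * grad_of_loss(w)  # what backward() on scale * loss yields
unscaled_grad = scaled_grad / SCALE    # GradScaler unscales before stepping
print(scaled_grad, unscaled_grad)
```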
- load_state_dict(state_dict)[source]

Load and parse the state dictionary of optimizer and loss_scaler.

If state_dict contains the "loss_scaler" key, loss_scaler will load the corresponding entries. Otherwise, only the optimizer will load the state dictionary.

- Parameters

  - state_dict (dict) – The state dict of optimizer and loss_scaler.
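The documented loading behavior can be sketched as a plain-dict split: pull the scaler entry out if present, and hand the rest to the optimizer. The helper name split_state_dict is hypothetical, not mmengine's implementation:

```python
def split_state_dict(state_dict):
    """Separate optimizer entries from the loss_scaler entry, if any (sketch)."""
    state_dict = dict(state_dict)                    # don't mutate the caller's dict
    scaler_sd = state_dict.pop('loss_scaler', None)  # None -> optimizer-only checkpoint
    return state_dict, scaler_sd

optim_sd, scaler_sd = split_state_dict({
    'state': {},
    'param_groups': [],
    'loss_scaler': {'scale': 65536.0},
})
print(sorted(optim_sd))  # ['param_groups', 'state']
print(scaler_sd)         # {'scale': 65536.0}
```

This is what lets AmpOptimWrapper resume both from its own checkpoints and from plain OptimWrapper checkpoints that have no scaler state.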
- optim_context(model)[source]

Enables the context for mixed precision training, and the context that disables gradient synchronization during gradient accumulation.

- Parameters

  - model (nn.Module) – The training model.
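A hedged sketch of the two contexts described above: the dummy autocast and no_sync context managers below only record their entry order, standing in for mixed-precision autocast and DDP gradient-sync suppression; the logic is illustrative, not mmengine's code:

```python
from contextlib import ExitStack, contextmanager

events = []

@contextmanager
def autocast():           # stand-in for the mixed-precision autocast context
    events.append('autocast')
    yield

@contextmanager
def no_sync(model):       # stand-in for DDP gradient-synchronization suppression
    events.append('no_sync')
    yield

@contextmanager
def optim_context_sketch(model, should_sync):
    with ExitStack() as stack:
        if not should_sync:  # mid-accumulation step: skip the gradient all-reduce
            stack.enter_context(no_sync(model))
        stack.enter_context(autocast())
        yield

with optim_context_sketch(model=None, should_sync=False):
    events.append('forward')
print(events)  # ['no_sync', 'autocast', 'forward']
```

The forward pass is run inside the context, so activations are computed in reduced precision and, in the middle of an accumulation window, gradient synchronization is deferred until the final accumulated step.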
- state_dict()[source]

Get the state dictionary of optimizer and loss_scaler.

Based on the state dictionary of the optimizer, the returned state dictionary will add a key named "loss_scaler".

- Returns

  The merged state dict of loss_scaler and optimizer.

- Return type

  dict
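The documented merge amounts to one extra key on top of the optimizer's state dict; the sketch below uses the hypothetical name merge_state_dicts for illustration:

```python
def merge_state_dicts(optim_sd, scaler_sd):
    """Return the optimizer state dict with the scaler state attached (sketch)."""
    merged = dict(optim_sd)            # start from the optimizer's state dict
    merged['loss_scaler'] = scaler_sd  # the documented extra key
    return merged

merged = merge_state_dicts(
    {'state': {}, 'param_groups': []},
    {'scale': 65536.0, 'growth_tracker': 0},
)
print(sorted(merged))  # ['loss_scaler', 'param_groups', 'state']
```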
- step(**kwargs)[source]

Update parameters with loss_scaler.

- Parameters

  - kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
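The skip-on-overflow behavior that step inherits from GradScaler can be sketched in plain Python: unscale the gradients, and if any are non-finite (the loss scale was too large), skip the parameter update entirely. The function scaler_step below is illustrative, not the real API:

```python
import math

def scaler_step(params, grads, scale, lr=0.1):
    """Unscale gradients; skip the update if any are non-finite (sketch)."""
    unscaled = [g / scale for g in grads]
    if any(not math.isfinite(g) for g in unscaled):
        return params, False  # overflow detected: keep parameters unchanged
    new_params = [p - lr * g for p, g in zip(params, unscaled)]
    return new_params, True

params, stepped = scaler_step([1.0, 2.0], grads=[65536.0, 131072.0], scale=65536.0)
print(params, stepped)  # [0.9, 1.8] True

skipped, stepped2 = scaler_step([1.0, 2.0], grads=[math.inf, 0.0], scale=65536.0)
print(skipped, stepped2)  # [1.0, 2.0] False
```

With dynamic scaling, a skipped step also triggers a reduction of the scale, which is why occasional "skipping step" messages early in AMP training are normal.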