AmpOptimWrapper
- class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', **kwargs)
A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.
AmpOptimWrapper exposes the same interface as OptimWrapper, so it can be used in the same way as OptimWrapper.
Warning
AmpOptimWrapper requires PyTorch >= 1.6.
- Parameters
loss_scale (float or str or dict) –
The initial configuration of torch.cuda.amp.GradScaler. See the PyTorch AMP documentation for the specific arguments. Defaults to "dynamic".
"dynamic": Initialize GradScaler without any arguments.
float: Initialize GradScaler with init_scale.
dict: Initialize GradScaler with a more detailed configuration.
**kwargs – Keyword arguments passed to OptimWrapper.
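The three accepted forms of loss_scale can be read as a small dispatch onto GradScaler keyword arguments. The sketch below only mirrors the description above: grad_scaler_kwargs is a hypothetical helper, not part of mmengine, and the real constructor may validate its input differently.

```python
def grad_scaler_kwargs(loss_scale="dynamic"):
    """Hypothetical helper: map `loss_scale` to GradScaler keyword arguments."""
    if loss_scale == "dynamic":
        # "dynamic": GradScaler is built with its default arguments.
        return {}
    if isinstance(loss_scale, float):
        # float: used as the initial scale factor.
        return {"init_scale": loss_scale}
    if isinstance(loss_scale, dict):
        # dict: forwarded as-is, e.g. init_scale or growth_interval.
        return dict(loss_scale)
    raise TypeError(
        f"loss_scale must be 'dynamic', a float or a dict, got {type(loss_scale)}"
    )
```

Under this reading, passing loss_scale=512.0 would be equivalent to passing loss_scale=dict(init_scale=512.0).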
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
- backward(loss, **kwargs)
Perform gradient backpropagation with loss_scaler.
- Parameters
loss (torch.Tensor) – The loss of the current iteration.
kwargs – Keyword arguments passed to torch.Tensor.backward().
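With a GradScaler, backpropagation usually runs on the scaled loss rather than on the raw loss. Below is a minimal sketch of that pattern, assuming loss_scaler behaves like torch.cuda.amp.GradScaler. scaled_backward is a hypothetical stand-alone function, and the wrapper's own method may do additional bookkeeping.

```python
def scaled_backward(loss, loss_scaler, **kwargs):
    """Hypothetical stand-in for the wrapper's backward step."""
    # GradScaler.scale multiplies the loss by the current scale factor,
    # which helps keep small float16 gradients from underflowing to zero.
    scaled_loss = loss_scaler.scale(loss)
    # Extra keyword arguments (e.g. retain_graph=True) go to Tensor.backward().
    scaled_loss.backward(**kwargs)
```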
- load_state_dict(state_dict)
Load and parse the state dictionary of optimizer and loss_scaler.
If state_dict contains the key "loss_scaler", the loss_scaler will load the corresponding entry. Otherwise, only the optimizer will load the state dictionary.
- Parameters
state_dict (dict) – The state dict of optimizer and loss_scaler.
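The loading rule above amounts to separating an optional scaler entry from the optimizer state. A minimal sketch on plain dictionaries follows. split_state_dict is a hypothetical helper for illustration only, and the key layout is assumed from the description of state_dict().

```python
def split_state_dict(state_dict):
    """Hypothetical helper: separate scaler state from optimizer state."""
    # Work on a shallow copy so the caller's dictionary is left untouched.
    optimizer_state = dict(state_dict)
    # The scaler state, if present, is assumed to live under "loss_scaler".
    scaler_state = optimizer_state.pop("loss_scaler", None)
    return optimizer_state, scaler_state
```

When scaler_state comes back as None, only the optimizer would have anything to load, which matches the "otherwise" branch described above.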
- optim_context(model)
Enables the context for mixed precision training and, during gradient accumulation, the context that disables gradient synchronization.
- Parameters
model (nn.Module) – The training model.
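In a hand-written loop, the forward pass would typically sit inside optim_context, followed by the backward and step calls. The sketch below is duck-typed and makes several assumptions: run_iteration and compute_loss are hypothetical names, zero_grad is assumed to be available from OptimWrapper, and gradient accumulation is ignored (with accumulation, step and zero_grad would not run on every iteration).

```python
def run_iteration(optim_wrapper, model, compute_loss):
    """Hypothetical helper: one training iteration with an optimizer wrapper."""
    # Run the forward pass inside the context so that mixed precision
    # applies to it.
    with optim_wrapper.optim_context(model):
        loss = compute_loss(model)
    optim_wrapper.backward(loss)  # backward pass on the scaled loss
    optim_wrapper.step()          # parameter update through the loss scaler
    optim_wrapper.zero_grad()     # clear gradients for the next iteration
    return loss
```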
- state_dict()
Get the state dictionary of optimizer and loss_scaler.
Based on the state dictionary of the optimizer, the returned state dictionary adds a key named "loss_scaler".
- Returns
The merged state dict of loss_scaler and optimizer.
- Return type
dict
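The merged dictionary can be pictured as the optimizer state with one extra entry. A minimal sketch on plain dictionaries, where merge_state_dicts is a hypothetical helper and the inner contents of both dictionaries are placeholders:

```python
def merge_state_dicts(optimizer_state, scaler_state):
    """Hypothetical helper: build the combined state dictionary."""
    # Start from a shallow copy of the optimizer's state dictionary.
    merged = dict(optimizer_state)
    # Attach the scaler state under the "loss_scaler" key.
    merged["loss_scaler"] = scaler_state
    return merged
```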
- step(**kwargs)
Update parameters with loss_scaler.
- Parameters
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
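Updating parameters through a GradScaler usually takes two calls: one to step the optimizer and one to refresh the scale factor. Below is a minimal sketch of that pattern, assuming loss_scaler behaves like torch.cuda.amp.GradScaler. scaled_step is a hypothetical stand-alone function, and the wrapper's own method may add steps such as gradient clipping.

```python
def scaled_step(optimizer, loss_scaler, **kwargs):
    """Hypothetical stand-in for the wrapper's step."""
    # GradScaler.step unscales the gradients and skips the optimizer
    # update if any of them contain inf or NaN values.
    loss_scaler.step(optimizer, **kwargs)
    # GradScaler.update adjusts the scale factor for the next iteration.
    loss_scaler.update()
```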