DeepSpeedOptimWrapper
- class mmengine.optim.DeepSpeedOptimWrapper(optimizer)
- backward(loss, **kwargs)
Perform gradient backpropagation.
- Parameters
loss (torch.Tensor) – The loss to backpropagate.
- Return type
None
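The backward description above is terse, so here is a minimal sketch of the delegation pattern. It assumes, as is usual with DeepSpeed (whose engine exposes backward(loss)), that the wrapper forwards the loss to an engine object instead of calling loss.backward() directly, which lets the engine apply its own handling such as fp16 loss scaling. FakeEngine and SketchDeepSpeedOptimWrapper are hypothetical stand-ins, not mmengine or DeepSpeed classes. The loss is a plain float so the sketch runs without torch.

```python
class FakeEngine:
    """Hypothetical stand-in for a DeepSpeed engine; it only records calls."""

    def __init__(self):
        self.backward_calls = []

    def backward(self, loss):
        # A real engine would scale the loss and run autograd here.
        self.backward_calls.append(loss)


class SketchDeepSpeedOptimWrapper:
    """Hypothetical wrapper that delegates backward() to an engine."""

    def __init__(self, optimizer, engine=None):
        self.optimizer = optimizer
        self.engine = engine

    def backward(self, loss, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility
        # but ignored, since the engine decides how to backpropagate.
        if self.engine is None:
            raise ValueError("engine must be set before calling backward")
        self.engine.backward(loss)


engine = FakeEngine()
wrapper = SketchDeepSpeedOptimWrapper(optimizer=None, engine=engine)
wrapper.backward(0.5)
print(engine.backward_calls)  # [0.5]
```

Delegating to the engine keeps the optimizer wrapper thin: any loss scaling or gradient partitioning stays inside the component that owns it.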
- load_state_dict(state_dict)
A wrapper of Optimizer.load_state_dict that loads the state dict of the wrapped optimizer.
Provides a unified load_state_dict interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic; for example, the state dictionary of GradScaler should be loaded when training with torch.cuda.amp.
- Parameters
state_dict (dict) – The state dictionary of optimizer.
- Return type
None
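The override pattern described for load_state_dict can be sketched as follows. A base wrapper forwards the whole dict to the optimizer, while an AMP-style subclass first pops the scaler's state and loads it separately. All class names and the "loss_scaler" key are hypothetical. FakeOptimizer and FakeScaler stand in for torch.optim.Optimizer and GradScaler so the example runs without torch; this is not mmengine's implementation.

```python
class FakeOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""

    def __init__(self):
        self.state = {}

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


class FakeScaler:
    """Hypothetical stand-in for a gradient scaler such as GradScaler."""

    def __init__(self):
        self.state = {}

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


class BaseWrapperSketch:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def load_state_dict(self, state_dict):
        # Default behaviour: forward everything to the wrapped optimizer.
        self.optimizer.load_state_dict(state_dict)


class AmpWrapperSketch(BaseWrapperSketch):
    def __init__(self, optimizer, scaler):
        super().__init__(optimizer)
        self.scaler = scaler

    def load_state_dict(self, state_dict):
        # Copy so the caller's dict is not mutated, then split off the
        # scaler state (the "loss_scaler" key is an assumption).
        state_dict = dict(state_dict)
        scaler_state = state_dict.pop("loss_scaler", None)
        if scaler_state is not None:
            self.scaler.load_state_dict(scaler_state)
        super().load_state_dict(state_dict)


opt, scaler = FakeOptimizer(), FakeScaler()
wrapper = AmpWrapperSketch(opt, scaler)
wrapper.load_state_dict(
    {"state": {}, "param_groups": [], "loss_scaler": {"scale": 1024.0}}
)
print(opt.state)     # {'state': {}, 'param_groups': []}
print(scaler.state)  # {'scale': 1024.0}
```

Popping the extra key before calling the parent keeps the optimizer from receiving entries it does not recognize.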