# Source code for mmengine.optim.optimizer._deepspeed
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.registry import OPTIM_WRAPPERS
from .base import BaseOptimWrapper
@OPTIM_WRAPPERS.register_module()
class DeepSpeedOptimWrapper(BaseOptimWrapper):
    """Optimizer wrapper that delegates gradient updates to :attr:`model`.

    Backward propagation and parameter updates are routed through
    ``model.backward`` / ``model.step`` instead of being performed directly
    on the wrapped optimizer. :attr:`model` must be assigned before
    :meth:`backward`, :meth:`step` or :meth:`update_params` is called.

    NOTE(review): :attr:`model` is presumably a DeepSpeed engine that owns
    the optimizer step (and gradient zeroing) — confirm against the code
    that assigns it.

    Args:
        optimizer: The optimizer to wrap; passed unchanged to
            ``BaseOptimWrapper.__init__``.
    """

    def __init__(self, optimizer) -> None:
        super().__init__(optimizer)
        # Assigned later through the ``model`` setter; ``None`` means "unset".
        self._model = None

    @property
    def model(self):
        """The object that performs backward and step.

        Raises:
            ValueError: If the model has not been assigned yet.
        """
        if self._model is None:
            raise ValueError('model attribute should be set before accessing.')
        return self._model

    @model.setter
    def model(self, value) -> None:
        self._model = value

    def update_params(self, loss) -> None:  # type: ignore
        """Update parameters in :attr:`optimizer`.

        Runs :meth:`backward` on ``loss`` followed by :meth:`step`.

        Args:
            loss: The loss to back-propagate.
        """
        self.backward(loss)
        self.step()

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation through :attr:`model`.

        Args:
            loss: The loss to back-propagate.
            **kwargs: Extra keyword arguments, forwarded to
                ``model.backward`` (previously they were silently dropped).
        """
        self.model.backward(loss, **kwargs)

    def step(self, **kwargs) -> None:
        """Update parameters through :attr:`model`.

        The update goes through ``model.step()`` rather than the wrapped
        optimizer, mirroring how :meth:`backward` delegates to
        ``model.backward``. Without this override :meth:`update_params`
        would fall back to the base-class ``step``.

        Args:
            **kwargs: Accepted for interface compatibility with
                ``BaseOptimWrapper.step``; currently unused.
        """
        self.model.step()

    def zero_grad(self, **kwargs) -> None:
        """Not supported by this wrapper.

        Raises:
            NotImplementedError: Always.
        """
        # NOTE(review): gradients are presumably cleared by the engine as
        # part of ``model.step()``, hence no separate zero_grad — confirm.
        raise NotImplementedError(
            'DeepSpeedOptimWrapper does not support zero_grad method '
            'currently.')

    def state_dict(self) -> dict:
        """Return the wrapper's own state.

        Only ``base_param_settings`` is included, and only when it is not
        ``None``; the wrapped optimizer's state is not part of this dict.

        Returns:
            dict: ``{'base_param_settings': ...}`` or an empty dict.
        """
        state_dict = {}
        if self.base_param_settings is not None:
            state_dict['base_param_settings'] = self.base_param_settings
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore state produced by :meth:`state_dict`.

        The input dict is only read, not modified, so the same checkpoint
        dict can be loaded more than once (the previous ``pop`` removed the
        key from the caller's dict).

        Args:
            state_dict: A dict that may contain ``'base_param_settings'``.
        """
        base_param_settings = state_dict.get('base_param_settings')
        if base_param_settings is not None:
            self.base_param_settings = base_param_settings