
Source code for mmengine.optim.optimizer._deepspeed

# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmengine.registry import OPTIM_WRAPPERS
from .base import BaseOptimWrapper


@OPTIM_WRAPPERS.register_module()
class DeepSpeedOptimWrapper(BaseOptimWrapper):
    """Optimizer wrapper for DeepSpeed.

    Gradient back propagation and the optimizer step are delegated to the
    DeepSpeed engine assigned to :attr:`model`, which must be set before
    parameters are updated.
    """

    def __init__(self, optimizer):
        super().__init__(optimizer)
        self._model = None

    @property
    def model(self):
        """The DeepSpeed engine used to perform ``backward`` and ``step``."""
        if self._model is None:
            raise ValueError(
                'model attribute should be set before accessing.')
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    def update_params(self, loss) -> None:  # type: ignore
        """Update parameters in :attr:`optimizer`."""
        self.backward(loss)
        self.step()

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation."""
        self.model.backward(loss)

    def zero_grad(self, **kwargs) -> None:
        """Not supported by :class:`DeepSpeedOptimWrapper`."""
        raise NotImplementedError(
            'DeepSpeedOptimWrapper does not support zero_grad method '
            'currently.')

    def step(self, **kwargs):
        """Delegate the optimizer step to the DeepSpeed engine."""
        self.model.step()

    def state_dict(self) -> dict:
        """Return the base parameter settings as the wrapper state dict."""
        state_dict = {}
        if self.base_param_settings is not None:
            state_dict['base_param_settings'] = self.base_param_settings
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the base parameter settings from ``state_dict``."""
        base_param_settings = state_dict.pop('base_param_settings', None)
        if base_param_settings is not None:
            self.base_param_settings = base_param_settings
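The wrapper delegates both ``backward`` and ``step`` to whatever object is assigned to ``model``, so that attribute must be set before ``update_params`` is called. In real training that object is the DeepSpeed engine; the sketch below uses a hypothetical stand-in class (``_DummyEngine``, not part of mmengine or DeepSpeed) so the example stays self-contained and only relies on the code shown above.

import torch

from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper


class _DummyEngine:
    """Hypothetical stand-in for a DeepSpeed engine (illustration only)."""

    def backward(self, loss: torch.Tensor) -> None:
        loss.backward()

    def step(self) -> None:
        pass  # a real engine would update parameters here


param = torch.nn.Parameter(torch.zeros(2))
wrapper = DeepSpeedOptimWrapper(optimizer=torch.optim.SGD([param], lr=0.1))
wrapper.model = _DummyEngine()   # must be assigned before update_params
loss = (param ** 2).sum()
wrapper.update_params(loss)      # calls model.backward(loss), then model.step()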
