Source code for mmengine._strategy.deepspeed
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import time
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from mmengine.logging import print_log
try:
import deepspeed
except ImportError:
deepspeed = None
import logging
import torch.nn as nn
import mmengine
from mmengine.dist import init_dist, is_main_process
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
STRATEGIES)
from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
from mmengine.utils import apply_to, digit_version, get_git_hash
from .base import BaseStrategy
def register_deepspeed_optimizers() -> List[str]:
"""Register optimizers in ``deepspeed`` to the ``OPTIMIZERS`` registry.
Returns:
List[str]: A list of the registered optimizers' names.
"""
deepspeed_optimizers = []
try:
import deepspeed # noqa: F401
except ImportError:
pass
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.ops.lamb import FusedLamb
from deepspeed.runtime.fp16.onebit import (OnebitAdam, OnebitLamb,
ZeroOneAdam)
OPTIMIZERS.register_module(module=DeepSpeedCPUAdam)
deepspeed_optimizers.append('DeepSpeedCPUAdam')
OPTIMIZERS.register_module(module=FusedAdam)
deepspeed_optimizers.append('FusedAdam')
OPTIMIZERS.register_module(module=FusedLamb)
deepspeed_optimizers.append('FusedLamb')
OPTIMIZERS.register_module(module=OnebitAdam)
deepspeed_optimizers.append('OnebitAdam')
OPTIMIZERS.register_module(module=OnebitLamb)
deepspeed_optimizers.append('OnebitLamb')
OPTIMIZERS.register_module(module=ZeroOneAdam)
deepspeed_optimizers.append('ZeroOneAdam')
return deepspeed_optimizers
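# Illustrative usage sketch (not part of the original module): once
# ``register_deepspeed_optimizers()`` has run, DeepSpeed optimizers can be
# built from the ``OPTIMIZERS`` registry by config. The optimizer type,
# learning rate and helper name below are assumptions for demonstration only.
def _example_build_registered_optimizer(model: nn.Module):
    register_deepspeed_optimizers()
    # Registry config: 'type' selects the registered class; the remaining
    # keys are forwarded to its constructor.
    return OPTIMIZERS.build(
        dict(type='FusedAdam', params=model.parameters(), lr=1e-3))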
@OPTIM_WRAPPERS.register_module()
class DeepSpeedOptimWrapper(BaseOptimWrapper):
def __init__(self, optimizer):
super().__init__(optimizer)
self._model = None
@property
def model(self):
if self._model is None:
raise ValueError('model attribute should be set before accessing.')
return self._model
@model.setter
def model(self, value):
self._model = value
def update_params(self, loss) -> None: # type: ignore
"""Update parameters in :attr:`optimizer`."""
self.backward(loss)
self.step()
def backward(self, loss: torch.Tensor, **kwargs) -> None:
"""Perform gradient back propagation."""
self.model.backward(loss)
def zero_grad(self, **kwargs) -> None:
raise NotImplementedError(
'DeepSpeedOptimWrapper does not support zero_grad method '
'currently.')
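def step(self, **kwargs):
    # DeepSpeed owns the optimization loop, so the parameter update is
    # delegated to the wrapped DeepSpeedEngine (set via the ``model``
    # property above).
    self.model.step()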
def state_dict(self) -> dict:
state_dict = {}
if self.base_param_settings is not None:
state_dict['base_param_settings'] = self.base_param_settings
return state_dict
def load_state_dict(self, state_dict: dict) -> None:
base_param_settings = state_dict.pop('base_param_settings', None)
if base_param_settings is not None:
self.base_param_settings = base_param_settings
@MODEL_WRAPPERS.register_module()
class MMDeepSpeedEngineWrapper:
def __init__(
self,
*,
model: 'deepspeed.DeepSpeedEngine',
inputs_to_half: Optional[List[Union[int, str]]] = None,
):
self.model = model
self._inputs_to_half = inputs_to_half
def __getattr__(self, name):
return getattr(self.model, name)
def train_step(
self,
data: Union[dict, tuple, list],
optim_wrapper: DeepSpeedOptimWrapper,
) -> Dict[str, torch.Tensor]:
data = self.model.module.data_preprocessor(data, training=True)
data = self._cast_inputs_half(data)
losses = self._run_forward(data, mode='loss')
parsed_loss, log_vars = self.model.module.parse_losses(losses)
optim_wrapper.update_params(parsed_loss)
return log_vars
def val_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of the module during the validation process.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.model.module.data_preprocessor(data, False)
data = self._cast_inputs_half(data)
return self._run_forward(data, mode='predict')
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of the module during the testing process.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.model.module.data_preprocessor(data, False)
data = self._cast_inputs_half(data)
return self._run_forward(data, mode='predict')
def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
"""
if isinstance(data, dict):
results = self.model(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self.model(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results
def _cast_inputs_half(self, inputs: Union[list, tuple, dict, None]):
"""Cast inputs to half precision if needed.
Args:
inputs (list or tuple or dict or None): Inputs to be cast.
Returns:
list or tuple or dict or None: The cast inputs.
"""
if self._inputs_to_half is None:
return inputs
dtype = next(self.model.parameters()).dtype
if isinstance(inputs, (list, tuple)):
new_inputs = []
for i, v in enumerate(inputs):
if i in self._inputs_to_half:
new_inputs.append(
apply_to(v, lambda x: hasattr(x, 'to'),
lambda x: x.to(dtype)))
else:
new_inputs.append(v)
return inputs.__class__(new_inputs)
elif isinstance(inputs, dict):
for k, v in inputs.items():
if k in self._inputs_to_half:
inputs[k] = apply_to(v, lambda x: hasattr(x, 'to'),
lambda x: x.to(dtype))
return inputs
else:
raise TypeError('inputs should be list, tuple or dict, '
f'but got {type(inputs)}')
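# Illustrative sketch (not part of the original module): how this wrapper is
# typically created, mirroring ``DeepSpeedStrategy._wrap_model`` below.
# ``ds_config`` is an assumed DeepSpeed config dict; ``inputs_to_half=['inputs']``
# tells ``_cast_inputs_half`` to cast only the ``inputs`` entry of a dict
# batch to the engine's parameter dtype (e.g. fp16), leaving other keys alone.
def _example_wrap_engine(model: nn.Module, ds_config: dict):
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    return MMDeepSpeedEngineWrapper(
        model=engine, inputs_to_half=['inputs'])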
@STRATEGIES.register_module()
class DeepSpeedStrategy(BaseStrategy):
"""Support training models with DeepSpeed.
Note:
The detailed usage of parameters can be found at
https://www.deepspeed.ai/docs/config-json/.
Args:
config (str or dict, optional): If it is a string, it is treated as a
path to a JSON config file for deepspeed. Defaults to None.
zero_optimization (dict, optional): Enabling and configuring ZeRO
memory optimizations. Defaults to None.
gradient_clipping (float, optional): Enable gradient clipping with the
given value. Defaults to None.
fp16 (dict, optional): Configuration for using mixed precision/FP16
training that leverages NVIDIA's Apex package. Defaults to None.
inputs_to_half (list[int or str], optional): Which inputs are to be
converted to half precision. If ``fp16`` is enabled, this should
also be set. Defaults to None.
bf16 (dict, optional): Configuration for using bfloat16 floating-point
format as an alternative to FP16. Defaults to None.
amp (dict, optional): Configuration for using automatic mixed
precision (AMP) training that leverages NVIDIA's Apex AMP package.
Defaults to None.
activation_checkpointing (dict, optional): Reduce memory usage by
clearing activations of certain layers and recomputing them
during a backward pass.
Defaults to None.
aio (dict, optional): Configuring the asynchronous I/O module for
offloading parameter and optimizer states to persistent (NVMe)
storage. This module uses Linux native asynchronous I/O (libaio).
Defaults to None.
train_micro_batch_size_per_gpu (int, optional): Batch size to be
processed by one GPU in one step (without gradient accumulation).
Defaults to None.
gradient_accumulation_steps (int, optional): Number of training steps
to accumulate gradients before averaging and applying them.
Defaults to None.
exclude_frozen_parameters (bool, optional): Whether to exclude frozen
parameters from the saved checkpoint. Defaults to None.
"""
def __init__(
self,
*,
# the following args are for deepspeed
config: Union[str, dict, None] = None,
zero_optimization: Optional[dict] = None,
gradient_clipping: Optional[float] = None,
fp16: Optional[dict] = None,
inputs_to_half: Optional[List[Union[int, str]]] = None,
bf16: Optional[dict] = None,
amp: Optional[dict] = None,
activation_checkpointing: Optional[dict] = None,
aio: Optional[dict] = None,
train_micro_batch_size_per_gpu: Optional[int] = None,
gradient_accumulation_steps: Optional[int] = None,
# disable the log printed by deepspeed
steps_per_print: int = 10000000000000,
# the following args are for BaseStrategy
exclude_frozen_parameters: Optional[bool] = None,
**kwargs,
):
assert deepspeed is not None, \
'DeepSpeed is not installed. Please check ' \
'https://github.com/microsoft/DeepSpeed#installation.'
super().__init__(**kwargs)
self.config = self._parse_config(config)
if zero_optimization is not None:
self.config['zero_optimization'] = zero_optimization
if gradient_clipping is not None:
self.config['gradient_clipping'] = gradient_clipping
if fp16 is not None:
self.config['fp16'] = fp16
if bf16 is not None:
self.config['bf16'] = bf16
if amp is not None:
self.config['amp'] = amp
if activation_checkpointing is not None:
self.config['activation_checkpointing'] = activation_checkpointing
if aio is not None:
self.config['aio'] = aio
if train_micro_batch_size_per_gpu is not None:
self.config['train_micro_batch_size_per_gpu'] = \
train_micro_batch_size_per_gpu
if gradient_accumulation_steps is not None:
self.config['gradient_accumulation_steps'] = \
gradient_accumulation_steps
else:
self.config.setdefault('gradient_accumulation_steps', 1)
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
assert (exclude_frozen_parameters is None or
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
), ('DeepSpeed >= 0.13.2 is required to enable '
'exclude_frozen_parameters')
self.exclude_frozen_parameters = exclude_frozen_parameters
register_deepspeed_optimizers()
def _parse_config(self, config):
if config is None:
config = dict()
elif isinstance(config, str):
with open(config) as f:
config = json.load(f)
return config
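# Example (assumed, for illustration only): a JSON file passed as
# ``config`` could look like
#
#     {
#         "fp16": {"enabled": true},
#         "zero_optimization": {"stage": 2},
#         "train_micro_batch_size_per_gpu": 2
#     }
#
# Keyword arguments given to ``DeepSpeedStrategy`` are then merged into
# (and take precedence over) the parsed dict.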
def _setup_distributed( # type: ignore
self,
launcher: Optional[str] = None,
backend: str = 'nccl',
**kwargs,
):
"""Setup distributed environment.
Args:
launcher (str, optional): Way to launch multiple processes.
DeepSpeedStrategy does not support the launcher argument.
backend (str): Communication Backends. Supported backends are
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: Other arguments for :func:`deepspeed.init_distributed`.
"""
init_dist(launcher, backend, init_backend='deepspeed', **kwargs)
def prepare(
self,
model: Union[nn.Module, dict],
*,
optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
compile: Union[dict, bool] = False,
dispatch_kwargs: Optional[dict] = None,
):
"""Prepare model and some components.
Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It
can be a dict used to build a model.
Keyword Args:
optim_wrapper (BaseOptimWrapper or dict, optional): Computing the
gradient of model parameters and updating them.
Defaults to None.
See :meth:`build_optim_wrapper` for examples.
param_scheduler (_ParamScheduler or dict or list, optional):
Parameter scheduler for updating optimizer parameters. If
specified, :attr:`optim_wrapper` should also be specified.
Defaults to None.
See :meth:`build_param_scheduler` for examples.
compile (dict, optional): Config to compile model.
Defaults to False. Requires PyTorch>=2.0.
dispatch_kwargs (dict, optional): Kwargs to be passed to other
methods of Strategy. Defaults to None.
"""
if self._prepared:
return self._prepared_components()
assert dispatch_kwargs is not None
self.dispatch_kwargs.update(dispatch_kwargs)
model = self.build_model(model)
model = self._init_model_weights(model)
if optim_wrapper is not None:
self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
self.model = self._wrap_model(model)
self.optim_wrapper.model = self.model # type: ignore
else:
self.model = self._wrap_model(model)
if param_scheduler is not None:
self.param_schedulers = self.build_param_scheduler(
param_scheduler, self.optim_wrapper)
self._prepared = True
return self._prepared_components()
def _wrap_model(self, model: nn.Module) -> nn.Module:
if hasattr(self, 'optim_wrapper'):
engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize(
model=model,
optimizer=self.optim_wrapper.optimizer,
config=self.config)
else:
engine, *_ = deepspeed.initialize(model=model, config=self.config)
wrapper = MMDeepSpeedEngineWrapper(
model=engine, inputs_to_half=self._inputs_to_half)
return wrapper
def load_checkpoint(
self,
filename: str,
*,
map_location: Union[str, Callable] = 'cpu',
strict: bool = False,
revise_keys: list = [(r'^module.', '')],
callback: Optional[Callable] = None,
) -> dict:
"""Load checkpoint from given ``filename``.
Warning:
`map_location` and `callback` parameters are not supported yet.
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""
self.logger.info(f'Load checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
load_optimizer_states=False,
load_module_strict=not self.exclude_frozen_parameters)
else:
_, extra_ckpt = self.model.load_checkpoint(
dirname, tag=basename, load_optimizer_states=False)
return extra_ckpt
def resume(
self,
filename: str,
*,
resume_optimizer: bool = True,
resume_param_scheduler: bool = True,
map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None,
) -> dict:
"""Resume training from given ``filename``.
Warning:
`map_location` and `callback` parameters are not supported yet.
Args:
filename (str): Accept local filepath.
Keyword Args:
resume_optimizer (bool): Whether to resume optimizer state.
Defaults to True.
resume_param_scheduler (bool): Whether to resume param scheduler
state. Defaults to True.
"""
self.logger.info(f'Resume checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
load_optimizer_states=resume_optimizer,
load_module_strict=not self.exclude_frozen_parameters)
else:
_, extra_ckpt = self.model.load_checkpoint(
dirname, tag=basename, load_optimizer_states=resume_optimizer)
if resume_optimizer:
self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper'))
if resume_param_scheduler and hasattr(self, 'param_schedulers'):
param_schedulers = extra_ckpt.pop('param_schedulers')
self.load_scheduler_state_dict(param_schedulers)
# resume random seed
resumed_seed = extra_ckpt['meta'].get('seed', None)
current_seed = self._randomness.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)
return extra_ckpt
def save_checkpoint(
self,
filename: str,
*,
save_optimizer: bool = True,
save_param_scheduler: bool = True,
extra_ckpt: Optional[dict] = None,
callback: Optional[Callable] = None,
) -> None:
"""Save checkpoint to given ``filename``.
Warning:
`callback` parameter is not supported yet.
Args:
filename (str): Filename to save checkpoint.
Keyword Args:
save_param_scheduler (bool): Whether to save the param_scheduler
to the checkpoint. Defaults to True.
extra_ckpt (dict, optional): Extra checkpoint to save.
Defaults to None.
"""
if extra_ckpt is None:
extra_ckpt = dict()
if 'meta' not in extra_ckpt:
extra_ckpt['meta'] = dict()
extra_ckpt['meta'].update(
seed=self.seed,
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
mmengine=mmengine.__version__ + get_git_hash(),
)
if save_param_scheduler and hasattr(self, 'param_schedulers'):
extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
if (not save_optimizer
and self.model.zero_optimization_partition_weights()
and not self.model.zero_gather_16bit_weights_on_model_save()):
print_log(
'Configured to `save_optimizer=False`, but currently using '
"DeepSpeed's ZeRO stage 3 with "
'`gather_16bit_weights_on_model_save=False`. In '
'this configuration, the model cannot be saved properly '
'and will be saved with the optimizer state. '
'To support `save_optimizer=False`, please set '
'`gather_16bit_weights_on_model_save=True` in your '
'DeepSpeed config.',
logger='current',
level=logging.WARNING)
save_optimizer = True
state_dict_kwargs = {}
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
state_dict_kwargs[
'exclude_frozen_parameters'] = self.exclude_frozen_parameters
if save_optimizer:
if hasattr(self, 'optim_wrapper'):
# The key cannot be 'optimizer'; otherwise an error will be thrown
# when loading or resuming the checkpoint.
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
dirname, basename = osp.split(filename)
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False,
**state_dict_kwargs)
else:
if self.model.zero_optimization_partition_weights():
state_dict = self.model._zero3_consolidated_16bit_state_dict(
**state_dict_kwargs)
else:
state_dict = self.model.module_state_dict(**state_dict_kwargs)
if is_main_process():
ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
save_checkpoint(ckpt, filename)
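# Illustrative end-to-end sketch (not part of the original module): building
# the strategy directly and preparing a model. All concrete values (ZeRO
# stage, fp16 settings, batch size, optimizer, dispatch_kwargs keys) are
# assumptions for demonstration; in typical use the strategy is constructed
# by ``mmengine.runner.FlexibleRunner`` from a config and run under a
# distributed launcher.
def _example_prepare_with_deepspeed(model: nn.Module):
    strategy = DeepSpeedStrategy(
        fp16=dict(enabled=True),
        inputs_to_half=['inputs'],
        zero_optimization=dict(stage=2),
        train_micro_batch_size_per_gpu=2)
    # ``prepare`` wraps the model via deepspeed.initialize and attaches the
    # resulting engine to the optim wrapper.
    return strategy.prepare(
        model,
        optim_wrapper=dict(
            type='DeepSpeedOptimWrapper',
            optimizer=dict(type='FusedAdam', lr=1e-3)),
        dispatch_kwargs=dict(max_iters=10000, epoch_length=1000))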