
Source code for mmengine._strategy.colossalai

# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os.path as osp
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

try:
    import colossalai
    import colossalai.booster.mixed_precision as colo_precision
    import colossalai.booster.plugin as colo_plugin
    import colossalai.nn.optimizer as colo_optimizer
    from colossalai.booster import Booster
    from colossalai.interface import ModelWrapper
except Exception as e:  # noqa: F841
    colossalai = None
    colo_precision = None
    colo_plugin = None
    colo_optimizer = None
    Booster = None
    ModelWrapper = None

import torch
import torch.nn as nn

import mmengine
from mmengine import mkdir_or_exist
from mmengine._strategy import BaseStrategy
from mmengine.device import get_device
from mmengine.dist import init_dist, is_main_process
from mmengine.fileio import join_path
from mmengine.model import BaseDataPreprocessor
from mmengine.optim import BaseOptimWrapper, OptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES, Registry
from mmengine.registry.root import MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS
from mmengine.runner.checkpoint import _load_checkpoint, save_checkpoint
from mmengine.utils import get_git_hash

# Registries for colossalai `plugins` and `mixed_precisions`
PLUGINS = Registry('plugin')
MIXED_PRECISIONS = Registry('mixed_precision')


def register_plugins():
    _plugins = inspect.getmembers(
        colo_plugin,
        lambda x: inspect.isclass(x) and issubclass(x, colo_plugin.Plugin))

    for name, plugin in _plugins:
        PLUGINS.register_module(name=name, module=plugin)


def register_optimizers():
    _colo_optimizer = inspect.getmembers(
        colo_optimizer,
        lambda x: inspect.isclass(x) and issubclass(x, torch.optim.Optimizer))
    for name, optim_type in _colo_optimizer:
        OPTIMIZERS.register_module(name=name, module=optim_type, force=True)


def register_mixed_precisions():
    _mixed_precisions = inspect.getmembers(
        colo_precision, lambda x: inspect.isclass(x) and issubclass(
            x, colo_precision.MixedPrecision))

    for name, mixed_precision in _mixed_precisions:
        MIXED_PRECISIONS.register_module(name=name, module=mixed_precision)
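

# Illustrative sketch (not part of the original module): once the register_*
# helpers above have populated the registries (``ColossalAIStrategy.__init__``
# calls them), dict-style configs can be resolved through PLUGINS and
# MIXED_PRECISIONS. This mirrors the dict branches of `_build_plugin` and
# `_build_mixed_precision` further below; the constructor arguments shown are
# the same ones used there.
def _example_build_from_registries():
    register_plugins()
    register_mixed_precisions()

    # Build a ColossalAI plugin and a mixed-precision object from dict configs.
    plugin = PLUGINS.build(
        dict(type='GeminiPlugin', precision='bf16', placement_policy='auto'))
    mixed_precision = MIXED_PRECISIONS.build(
        dict(type='FP16TorchMixedPrecision'))
    return plugin, mixed_precision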


@OPTIM_WRAPPERS.register_module()
class ColossalAIOptimWrapper(OptimWrapper):
    """OptimWrapper for ColossalAI.

    The available optimizers are:

    - CPUAdam
    - FusedAdam
    - FusedLAMB
    - FusedSGD
    - HybridAdam
    - Lamb
    - Lars

    You can find more details in the `colossalai tutorial`_.

    Args:
        optimizer (dict or torch.optim.Optimizer): The optimizer to be
            wrapped.
        accumulative_counts (int): The number of iterations to accumulate
            gradients. The parameters will be updated per
            ``accumulative_counts``.

    .. _colossalai tutorial: https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/optimizer
    """  # noqa: E501

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 booster: Optional[Booster] = None,
                 accumulative_counts: int = 1):
        super().__init__(optimizer, accumulative_counts=accumulative_counts)
        self.booster = booster

    @contextmanager
    def optim_context(self, model: nn.Module):
        assert isinstance(self.booster, Booster), \
            'Please set the booster attribute before using ' \
            '`ColossalAIOptimWrapper`.'
        if self.booster.plugin.support_no_sync():
            no_sync_context = self.booster.no_sync(model, self.optimizer)
        else:
            yield
            return

        if self.should_sync():
            yield
        else:
            with no_sync_context:
                yield

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        self._inner_count += 1
        self.optimizer.backward(loss, **kwargs)
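

# Illustrative sketch (not part of the original module): a typical dict-style
# config for ColossalAIOptimWrapper. `HybridAdam` is one of the ColossalAI
# optimizers listed in the class docstring; the learning rate and
# `accumulative_counts` values are placeholders chosen for illustration.
def _example_optim_wrapper_cfg() -> dict:
    return dict(
        type='ColossalAIOptimWrapper',
        optimizer=dict(type='HybridAdam', lr=1e-3),
        accumulative_counts=2)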


@MODEL_WRAPPERS.register_module(
    name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
class ColossalAIModelWrapper:

    def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
        self.model_wrapper = model_wrapper
        self.model = model

    def __call__(self, *args, **kwargs) -> Any:
        return self.model_wrapper(*args, **kwargs)

    def train_step(
        self,
        data: Union[dict, tuple, list],
        optim_wrapper: ColossalAIOptimWrapper,
    ) -> Dict[str, torch.Tensor]:
        data = self.model.data_preprocessor(data, training=True)
        with optim_wrapper.optim_context(self.model):
            losses = self._run_forward(data, mode='loss')
            parsed_loss, log_vars = self.model.parse_losses(losses)
            optim_wrapper.update_params(parsed_loss)
        return log_vars

    def val_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the prediction of module during validation process.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        data = self.model.data_preprocessor(data, False)
        return self._run_forward(data, mode='predict')

    test_step = val_step

    def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
        """Unpacks data for :meth:`forward`.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of training or testing mode.
        """
        if isinstance(data, dict):
            results = self.model_wrapper(**data, mode=mode)
        elif isinstance(data, (list, tuple)):
            results = self.model_wrapper(*data, mode=mode)
        else:
            raise TypeError('Output of `data_preprocessor` should be '
                            f'list, tuple or dict, but got {type(data)}')
        return results

    def __getattr__(self, name):
        if hasattr(self.model_wrapper, name):
            return getattr(self.model_wrapper, name)
        elif hasattr(self.model, name):
            return getattr(self.model, name)
        else:
            raise AttributeError(
                f'{self.model_wrapper} and {self.model} have no '
                f'attribute {name}')


@STRATEGIES.register_module()
class ColossalAIStrategy(BaseStrategy):
    """Training strategy for ColossalAI.

    Args:
        config (str or dict): The colossalai config file to setup distributed
            environment. See more details in the `colossalai config
            tutorial`_.
        mixed_precision (str or MixedPrecision): The mixed precision to run
            the training. Defaults to None. If the argument is a string, it
            can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' would use
            PyTorch AMP while 'fp16_apex' would use Nvidia Apex.
        plugin (Plugin): The plugin to run the training. The type of `plugin`
            could be:

            - str: The available plugins are ``gemini`` and ``lowlevel-zero``.
              ``gemini`` means a `ZeRO`_ implementation with chunk-based
              memory management. You could find more details in the
              `colossalai gemini tutorial`_. ``lowlevel-zero`` means a ZeRO-1
              and ZeRO-2 implementation. Although gemini saves more memory,
              unexpected errors could happen for some special model
              structures; lowlevel-zero is more stable.
            - dict: **dict-type style config to build a colossalai plugin**.
              See the `booster plugin tutorial`_ for more details.

        model_wrapper (dict, optional): Dict for model wrapper. Defaults to
            None.
        work_dir (str): The working directory to save checkpoints. The logs
            will be saved in the subdirectory of `work_dir` named
            :attr:`timestamp`. Defaults to 'work_dirs'.
        experiment_name (str, optional): Name of current experiment. If not
            specified, timestamp will be used as :attr:`experiment_name`.
            Defaults to None.
        env_kwargs (dict, optional): Environment config passed in
            :meth:`setup_env`. Defaults to None.
        log_kwargs (dict, optional): Logger config passed in
            :meth:`build_logger`. Defaults to None.
        auto_scale_lr (dict, optional): Config to scale the learning rate
            automatically. It includes ``base_batch_size`` and ``enable``.
            ``base_batch_size`` is the batch size that the optimizer lr is
            based on. ``enable`` is the switch to turn on and off the feature.

    .. _colossalai config tutorial: https://colossalai.org/docs/basics/configure_parallelization
    .. _ZeRO: https://arxiv.org/abs/1910.02054
    .. _colossalai gemini tutorial: https://colossalai.org/docs/features/zero_with_chunk/#geminiddp
    .. _booster plugin tutorial: https://colossalai.org/docs/basics/booster_plugins
    """  # noqa: E501

    OPTIMIZER_DIR = 'optimizer'  # directory to save optimizer state
    MODEL_DIR = 'model'  # directory to save model
    SCHEDULER_DIR = 'scheduler'  # directory to save schedulers

    model: ColossalAIModelWrapper  # type: ignore
    optim_wrapper: ColossalAIOptimWrapper  # type: ignore

    def __init__(
        self,
        *,
        config: Union[str, dict, None] = None,
        mixed_precision: Union[str, dict, None] = None,
        plugin: str = 'gemini',
        model_wrapper: Optional[dict] = None,
        **kwargs,
    ):
        if colossalai is None:
            raise ModuleNotFoundError(
                'Please install colossalai by `pip install -U colossalai`')

        register_plugins()
        register_mixed_precisions()
        register_optimizers()

        self.config = config or {}
        super().__init__(**kwargs)

        if mixed_precision is not None:
            mixed_precision = self._build_mixed_precision(mixed_precision)

        if plugin is not None:
            plugin = self._build_plugin(plugin)

        self.booster = Booster(mixed_precision=mixed_precision, plugin=plugin)
        self.model_wrapper = model_wrapper

    def prepare(
        self,
        model: Union[nn.Module, dict],
        *,
        optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
        param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
        compile: Union[dict, bool] = False,
        dispatch_kwargs: Optional[dict] = None,
    ):
        """Prepare model and some components.

        Args:
            model (:obj:`torch.nn.Module` or dict): The model to be run. It
                can be a dict used to build a model.

        Keyword Args:
            optim_wrapper (BaseOptimWrapper or dict, optional): Computing the
                gradient of model parameters and updating them.
                Defaults to None. See :meth:`build_optim_wrapper` for
                examples.
            param_scheduler (_ParamScheduler or dict or list, optional):
                Parameter scheduler for updating optimizer parameters. If
                specified, :attr:`optim_wrapper` should also be specified.
                Defaults to None. See :meth:`build_param_scheduler` for
                examples.
            compile (dict, optional): Config to compile model.
                Defaults to False. Requires PyTorch>=2.0.
            dispatch_kwargs (dict, optional): Kwargs to be passed to other
                methods of Strategy. Defaults to None.
                If ``accumulative_counts`` is set in ``optim_wrapper``, you
                need to provide ``max_iters`` in ``dispatch_kwargs``.
        """
        if self._prepared:
            return self._prepared_components()
        if dispatch_kwargs is not None:
            self.dispatch_kwargs.update(dispatch_kwargs)

        model = self.build_model(model)
        model = self._init_model_weights(model)

        # optim_wrapper is required by booster
        if optim_wrapper is not None and isinstance(optim_wrapper, dict):
            optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper')
            optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type'])
            if optim_wrapper_type is None:
                raise ValueError(f'Failed to find {optim_wrapper["type"]} in '
                                 '`OPTIM_WRAPPERS`.')
            if 'clip_grad' in optim_wrapper:
                raise ValueError('Please configure `clip_grad` in `plugin`')
            if not issubclass(optim_wrapper_type, ColossalAIOptimWrapper):
                raise ValueError(
                    'The type of `optim_wrapper` must be '
                    '`ColossalAIOptimWrapper` (or subclass), but got '
                    f'{optim_wrapper_type}')
            optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
            optim_wrapper.booster = self.booster  # type: ignore

        if optim_wrapper is not None:
            self.model, self.optim_wrapper = self._wrap(
                model, optim_wrapper)  # type: ignore
        else:
            self.model = self._wrap(model)  # type: ignore

        # TODO: Check whether `compile` is compatible with colossalai.

        if param_scheduler is not None:
            self.param_schedulers = self.build_param_scheduler(
                param_scheduler, optim_wrapper)  # type: ignore

        if optim_wrapper is not None:
            self._scale_lr()

            accumulative_counts = getattr(self.optim_wrapper,
                                          '_accumulative_counts', 1)
            if accumulative_counts > 1:
                if 'max_iters' not in self.dispatch_kwargs:
                    raise ValueError(
                        '"max_iters" must be specified because '
                        '"accumulative_counts" was set as '
                        f'{accumulative_counts} which is greater than 1.')

                self.optim_wrapper.initialize_count_status(  # type: ignore
                    self.model, 0, self.dispatch_kwargs['max_iters'])
        self._prepared = True
        return self._prepared_components()

    def resume(
        self,
        filename: str,
        *,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = 'default',
        callback: Optional[Callable] = None,
    ) -> dict:
        """Override this method since colossalai resumes the optimizer from
        ``filename`` directly."""
        self.logger.info(f'Resume checkpoint from {filename}')

        extra_ckpt = self.load_checkpoint(
            filename, map_location=map_location, callback=callback)

        if resume_optimizer:
            self.booster.load_optimizer(
                self.optim_wrapper.optimizer,
                join_path(filename, self.OPTIMIZER_DIR))

        if resume_param_scheduler:
            schedulers_dir = join_path(filename, self.SCHEDULER_DIR)
            for i, scheduler in enumerate(self.param_schedulers):
                self.booster.load_lr_scheduler(
                    scheduler, f'{schedulers_dir}/scheduler_{i}.pth')

        # resume random seed
        resumed_seed = extra_ckpt['meta'].get('seed', None)
        current_seed = self._randomness.get('seed')
        if resumed_seed is not None and resumed_seed != current_seed:
            if current_seed is not None:
                self.logger.warning(f'The value of random seed in the '
                                    f'checkpoint "{resumed_seed}" is '
                                    f'different from the value in '
                                    f'`randomness` config "{current_seed}"')
            self._randomness.update(seed=resumed_seed)
            self._set_randomness(**self._randomness)

        # resume iter
        self.dispatch_kwargs['cur_iter'] = extra_ckpt['meta']['iter']

        return extra_ckpt

    def load_checkpoint(
        self,
        filename: str,
        *,
        map_location: Union[str, Callable] = 'cpu',
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
        callback: Optional[Callable] = None,
    ) -> dict:
        """Load checkpoint from given ``filename``.

        Warning:
            The `map_location` and `callback` parameters are not supported
            yet.

        Args:
            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx``.
        """
        self.logger.info(f'Load checkpoint from {filename}')

        self.booster.load_model(self.model.model_wrapper,
                                join_path(filename, self.MODEL_DIR))

        meta = _load_checkpoint(osp.join(filename, 'meta.pth'))
        return meta

    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        # The checkpoint directory will be:
        # |--epoch_0.pth
        #    |---model/
        #    |---optimizer/
        #    |---scheduler/
        if extra_ckpt is None:
            extra_ckpt = dict()
        if 'meta' not in extra_ckpt:
            extra_ckpt['meta'] = dict()

        extra_ckpt['meta'].update(
            seed=self.seed,
            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            mmengine=mmengine.__version__ + get_git_hash())

        model_dir = join_path(filename, self.MODEL_DIR)
        optimizer_dir = join_path(filename, self.OPTIMIZER_DIR)
        schedulers_dir = join_path(filename, self.SCHEDULER_DIR)
        mkdir_or_exist(model_dir)
        mkdir_or_exist(optimizer_dir)
        mkdir_or_exist(schedulers_dir)

        self.booster.save_model(
            self.model.model_wrapper, checkpoint=model_dir, shard=True)

        if save_optimizer:
            self.booster.save_optimizer(
                self.optim_wrapper.optimizer,
                checkpoint=optimizer_dir,
                shard=True)

        if is_main_process() and save_param_scheduler:
            for i, scheduler in enumerate(self.param_schedulers):
                self.booster.save_lr_scheduler(
                    scheduler, f'{schedulers_dir}/scheduler_{i}.pth')

        save_checkpoint(extra_ckpt, join_path(filename, 'meta.pth'))

    def _build_plugin(self, plugin: Union[str, dict]):
        if isinstance(plugin, str):
            if plugin == 'gemini':
                try:
                    plugin = colo_plugin.GeminiPlugin(
                        precision='bf16', placement_policy='auto')
                except AssertionError:
                    from colossalai.zero.gemini.placement_policy import \
                        PlacementPolicyFactory as colo_placement
                    raise ValueError('placement policy must be one of ' +
                                     f'{list(colo_placement.policies.keys())}')
            elif plugin == 'lowlevel-zero':
                plugin = colo_plugin.LowLevelZeroPlugin()
            else:
                raise ValueError('`plugin` must be "gemini" or '
                                 '"lowlevel-zero"')
        elif isinstance(plugin, dict):
            plugin = PLUGINS.build(plugin)
        else:
            raise ValueError('`plugin` must be dict or str, but got a '
                             f'{type(plugin)} object')
        return plugin

    def _build_mixed_precision(self, mixed_precision: Union[str, dict]):
        if isinstance(mixed_precision, str):
            if mixed_precision == 'fp16':
                mixed_precision = colo_precision.FP16TorchMixedPrecision()
            elif mixed_precision == 'fp16_apex':
                mixed_precision = colo_precision.FP16ApexMixedPrecision()
            elif mixed_precision == 'bf16':
                mixed_precision = colo_precision.BF16MixedPrecision()
            elif mixed_precision == 'fp8':
                mixed_precision = colo_precision.FP8MixedPrecision()
            else:
                raise ValueError(
                    'If `mixed_precision` is a string, it must be one of '
                    '"fp16", "fp16_apex", "bf16" and "fp8", but got '
                    f'{mixed_precision}')
        elif isinstance(mixed_precision, dict):
            mixed_precision = MIXED_PRECISIONS.build(mixed_precision)
        else:
            raise ValueError('`mixed_precision` should be dict or str, but '
                             f'got a {type(mixed_precision)} object')
        return mixed_precision

    def _wrap(
        self,
        model: nn.Module,
        optim_wrapper: Optional[OptimWrapper] = None,
    ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
               ColossalAIModelWrapper]:  # type: ignore
        """Wrap model with :class:`ModelWrapper`."""
        if self.model_wrapper is None:
            self.model_wrapper = {'type': 'ColossalAIModelWrapper'}

        # For ZeRO series parallelism, it is reasonable to move
        # `data_preprocessor` to the current device. We need to call
        # `BaseDataPreprocessor.to` manually since frameworks like colossalai
        # and deepspeed cannot handle it, which would leave
        # `data_preprocessor` moving data to the CPU.
        for module in model.modules():
            if isinstance(module, BaseDataPreprocessor):
                module.to(get_device())

        if optim_wrapper is not None:
            optimizer = optim_wrapper.optimizer

            if not hasattr(optimizer, '_hook_for_profile'):
                # PyTorch 2.0 removes the `_hook_for_profile` in
                # `torch.optim.Optimizer`. We maintain this function here to
                # keep compatibility.
                # TODO: Remove this hardcode when ColossalAI supports
                # PyTorch 2.0
                optimizer.__class__._hook_for_profile = object

            # We do not pass `scheduler` and `Dataloader` here because:
            # 1. `Booster.boost` cannot accept a list of schedulers.
            # 2. `Strategy` cannot accept a dataloader for now.
            model_wrapper, optimizer, *_ = self.booster.boost(model, optimizer)
            optim_wrapper.optimizer = optimizer
            default_args = {'model_wrapper': model_wrapper, 'model': model}
            model_wrapper = MODEL_WRAPPERS.build(
                self.model_wrapper, default_args=default_args)
            return model_wrapper, optim_wrapper  # type: ignore
        else:
            model_wrapper, *_ = self.booster.boost(model)
            default_args = {'model_wrapper': model_wrapper, 'model': model}
            model_wrapper = MODEL_WRAPPERS.build(
                self.model_wrapper, default_args=default_args)
            return model_wrapper

    def _setup_distributed(  # type: ignore
        self,
        launcher: Optional[str] = None,
        backend: str = 'nccl',
        **kwargs,
    ):
        init_dist(
            launcher, backend, init_backend='colossalai', config=self.config)
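

# Illustrative end-to-end sketch (assumption, not part of the original
# module). In practice this strategy is usually selected through MMEngine's
# FlexibleRunner via ``strategy=dict(type='ColossalAIStrategy', ...)``; the
# direct calls below only mirror the flow implied by this file and assume a
# distributed environment has already been launched (e.g. with torchrun) and
# the usual Runner bookkeeping (random seed, logging) has been set up.
# ``toy_model``, the optimizer settings and ``max_iters`` are placeholders.
def _example_colossalai_strategy(toy_model: nn.Module, data: dict):
    strategy = ColossalAIStrategy(
        plugin='gemini',       # or 'lowlevel-zero', or a dict-style plugin
        mixed_precision=None,  # optionally 'fp16', 'fp16_apex', 'bf16', 'fp8'
        work_dir='work_dirs')
    strategy.prepare(
        toy_model,
        optim_wrapper=dict(
            type='ColossalAIOptimWrapper',
            optimizer=dict(type='HybridAdam', lr=1e-3)),
        dispatch_kwargs=dict(max_iters=1000))

    # One training iteration followed by a sharded checkpoint; the checkpoint
    # is a directory containing model/, optimizer/ and scheduler/ subfolders
    # plus a meta.pth file, as written by ``save_checkpoint`` above.
    log_vars = strategy.model.train_step(data, strategy.optim_wrapper)
    strategy.save_checkpoint('work_dirs/iter_1.pth')
    return log_vars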