ColossalAIStrategy

class mmengine._strategy.ColossalAIStrategy(*, config=None, mixed_precision=None, plugin='gemini', model_wrapper=None, **kwargs)[source]
Parameters:
  • config (str | dict | None) – The ColossalAI config file used to set up the distributed environment. See more details in the colossalai config tutorial.

  • mixed_precision (str or MixedPrecision) – The mixed precision to run the training with. Defaults to None. If the argument is a string, it can be ‘fp16’, ‘fp16_apex’, ‘bf16’, or ‘fp8’; ‘fp16’ uses PyTorch AMP, while ‘fp16_apex’ uses NVIDIA Apex.

  • plugin (Plugin) –

    The plugin to run the training. The type of plugin could be:

    • str: The available plugins are gemini and lowlevel-zero.

      gemini means a ZeRO implementation with chunk-based memory management; you can find more details in the colossalai gemini tutorial. lowlevel-zero means a ZeRO-1 and ZeRO-2 implementation. Although gemini saves more memory, unexpected errors can occur with some special model structures; lowlevel-zero is more stable.

    • dict: a dict-style config used to build a colossalai plugin.

      See the booster plugin tutorial for more details.

  • model_wrapper (dict, optional) – Dict for model wrapper. Defaults to None.

  • work_dir (str) – The working directory to save checkpoints. The logs will be saved in a subdirectory of work_dir named after the timestamp. Defaults to ‘work_dirs’.

  • experiment_name (str, optional) – Name of current experiment. If not specified, timestamp will be used as experiment_name. Defaults to None.

  • env_kwargs (dict, optional) – Environment config passed in setup_env(). Defaults to None.

  • log_kwargs (dict, optional) – Logger config passed in build_logger(). Defaults to None.

  • auto_scale_lr (dict, optional) – Config to scale the learning rate automatically. It includes base_batch_size and enable. base_batch_size is the batch size that the optimizer lr is based on; enable is the switch to turn the feature on and off.
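
A minimal construction sketch using the documented string-style options; it assumes the distributed environment is launched externally (e.g. with torchrun), and the work_dir value is illustrative:

    from mmengine._strategy import ColossalAIStrategy

    # Select the gemini plugin and bf16 mixed precision, both documented above.
    strategy = ColossalAIStrategy(
        plugin='gemini',         # or 'lowlevel-zero' for the more stable ZeRO-1/2
        mixed_precision='bf16',  # one of 'fp16', 'fp16_apex', 'bf16', 'fp8'
        work_dir='work_dirs',
    )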

load_checkpoint(filename, *, map_location='cpu', strict=False, revise_keys=[('^module.', '')], callback=None)[source]

Load checkpoint from given filename.

Warning

The map_location and callback parameters are not supported yet.

Parameters:
  • filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.

  • map_location (str | Callable) – Target device for the loaded checkpoint (not supported yet; see the warning above). Defaults to ‘cpu’.

  • strict (bool) – Whether to strictly enforce that the keys of the checkpoint match the keys of the model’s state_dict. Defaults to False.

  • revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular expressions. Defaults to stripping the prefix ‘module.’.

  • callback (Callable | None) – Callback function to modify the checkpoint after loading (not supported yet; see the warning above). Defaults to None.

Return type:

dict
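
A short usage sketch with the strategy constructed earlier; the checkpoint path is a placeholder, and map_location/callback are omitted since the warning above marks them as unsupported:

    # Load weights, stripping a 'module.' prefix left by DDP-style wrappers
    # (this is also the default revise_keys value). The raw checkpoint dict
    # is returned.
    ckpt = strategy.load_checkpoint(
        'work_dirs/epoch_10.pth',
        strict=False,
        revise_keys=[(r'^module.', '')],
    )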

prepare(model, *, optim_wrapper=None, param_scheduler=None, compile=False, dispatch_kwargs=None)[source]

Prepare model and some components.

Parameters:
  • model (torch.nn.Module or dict) – The model to be run. It can be a dict used for building a model.

Keyword Arguments:
  • optim_wrapper (BaseOptimWrapper or dict, optional) – Computing the gradient of model parameters and updating them. Defaults to None. See build_optim_wrapper() for examples.

  • param_scheduler (_ParamScheduler or dict or list, optional) – Parameter scheduler for updating optimizer parameters. If specified, optim_wrapper should also be specified. Defaults to None. See build_param_scheduler() for examples.

  • compile (dict or bool, optional) – Config to compile the model. Defaults to False. Requires PyTorch>=2.0.

  • dispatch_kwargs (dict, optional) – Kwargs to be passed to other methods of Strategy. Defaults to None. If accumulative_counts is set in optim_wrapper, you need to provide max_iters in dispatch_kwargs.
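
A sketch of preparing a toy model with the strategy constructed earlier; the optimizer config and max_iters value are placeholders, and it assumes prepare() returns the prepared model and optim wrapper when only those two are passed:

    import torch.nn as nn

    # prepare() builds the optim wrapper from the dict config and wraps the
    # model with the configured plugin; max_iters goes in via dispatch_kwargs.
    model, optim_wrapper = strategy.prepare(
        nn.Linear(8, 8),
        optim_wrapper=dict(optimizer=dict(type='AdamW', lr=1e-3)),
        dispatch_kwargs=dict(max_iters=1000),
    )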

resume(filename, *, resume_optimizer=True, resume_param_scheduler=True, map_location='default', callback=None)[source]

Overrides the base method since ColossalAI resumes the optimizer from filename directly.

Parameters:
  • filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.

  • resume_optimizer (bool) – Whether to resume the optimizer from the checkpoint. Defaults to True.

  • resume_param_scheduler (bool) – Whether to resume the param scheduler from the checkpoint. Defaults to True.

  • map_location (str) – Target device for the loaded checkpoint. Defaults to ‘default’.

  • callback (Callable | None) – Callback function to modify the checkpoint before resuming. Defaults to None.
Return type:

dict
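
A resume sketch; the checkpoint path is a placeholder:

    # Restores model weights and optimizer/scheduler state, and returns the
    # remaining checkpoint contents (e.g. anything saved via extra_ckpt).
    extra = strategy.resume(
        'work_dirs/latest.pth',
        resume_optimizer=True,
        resume_param_scheduler=True,
    )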

save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]

Save checkpoint to given filename.

Parameters:
  • filename (str) – Filename to save checkpoint.

Keyword Arguments:
  • save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.

  • save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.

  • extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.

  • callback (callable, optional) – Callback function to modify the checkpoint before saving it. Defaults to None.

Return type:

None
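
A saving sketch to close the loop; the filename and extra metadata are placeholders:

    # Save model, optimizer, and scheduler state, attaching extra bookkeeping
    # that resume() will later hand back.
    strategy.save_checkpoint(
        'work_dirs/iter_1000.pth',
        save_optimizer=True,
        save_param_scheduler=True,
        extra_ckpt=dict(meta=dict(iter=1000)),
    )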