
ColossalAIStrategy

class mmengine._strategy.ColossalAIStrategy(*, config=None, mixed_precision=None, plugin='gemini', model_wrapper=None, **kwargs)[source]
Parameters:
  • config (str | dict | None) – The ColossalAI config file or dict used to set up the distributed environment. Defaults to None. See the ColossalAI config tutorial for more details.

  • mixed_precision (str or MixedPrecision) – The mixed precision mode used for training. Defaults to None. If a string is given, it can be ‘fp16’, ‘fp16_apex’, ‘bf16’, or ‘fp8’; ‘fp16’ uses PyTorch AMP, while ‘fp16_apex’ uses NVIDIA Apex.

  • plugin (Plugin) –

    The plugin to run the training. The type of plugin could be:

    • str: The available plugins are gemini and lowlevel-zero.

      gemini is a ZeRO implementation with chunk-based memory management; see the ColossalAI Gemini tutorial for details. lowlevel-zero implements ZeRO-1 and ZeRO-2. Although gemini saves more memory, it may raise unexpected errors for some special model structures; lowlevel-zero is more stable.

    • dict: a dict-style config used to build a ColossalAI plugin.

      See the booster plugin tutorial for more details.

  • model_wrapper (dict, optional) – Dict for model wrapper. Defaults to None.

  • work_dir (str) – The working directory for saving checkpoints. Logs are saved in a subdirectory of work_dir named after the timestamp. Defaults to ‘work_dirs’.

  • experiment_name (str, optional) – Name of the current experiment. If not specified, the timestamp is used as experiment_name. Defaults to None.

  • env_kwargs (dict, optional) – Environment config passed in setup_env(). Defaults to None.

  • log_kwargs (dict, optional) – Logger config passed in build_logger(). Defaults to None.

  • auto_scale_lr (dict, optional) – Config for scaling the learning rate automatically. It contains base_batch_size and enable: base_batch_size is the batch size on which the optimizer lr is based, and enable turns the feature on or off.
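A minimal sketch of how these arguments might be combined in config dicts, assuming the strategy is selected by type='ColossalAIStrategy' and that ColossalAI's HybridAdam optimizer is registered; both names are assumptions, not defined on this page. The scaled_lr helper is hypothetical and assumes auto_scale_lr applies a linear rule, multiplying the lr by the ratio of the real total batch size to base_batch_size.

```python
# Hedged sketch of config dicts; key names follow the documented parameters.
strategy = dict(
    type='ColossalAIStrategy',  # assumed registry name of this class
    plugin='lowlevel-zero',      # or 'gemini' for chunk-based ZeRO
    mixed_precision='bf16',
)
# 'HybridAdam' is an assumed ColossalAI optimizer type, not defined on this page.
optim_wrapper = dict(optimizer=dict(type='HybridAdam', lr=1e-3))
# lr=1e-3 is assumed to be tuned for a total batch size of 256.
auto_scale_lr = dict(base_batch_size=256, enable=True)


def scaled_lr(base_lr: float, world_size: int, batch_size_per_gpu: int,
              base_batch_size: int) -> float:
    """Hypothetical helper: linear scaling by real / base batch size."""
    real_batch_size = world_size * batch_size_per_gpu
    return base_lr * real_batch_size / base_batch_size


# 8 GPUs x 16 samples = 128, half of 256, so the lr is halved.
print(scaled_lr(1e-3, world_size=8, batch_size_per_gpu=16,
                base_batch_size=auto_scale_lr['base_batch_size']))
```

Under this linear assumption, 8 GPUs with 16 samples each give a real batch size of 128, half of base_batch_size, so the sketch halves the lr to 0.0005.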

load_checkpoint(filename, *, map_location='cpu', strict=False, revise_keys=[('^module.', '')], callback=None)[source]

Load a checkpoint from the given filename.

Warning

The map_location and callback parameters are not supported yet.

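Parameters:
  • filename (str) – Path to the checkpoint. Accepts a local file path, a URL, torchvision://xxx, or open-mmlab://xxx.

Keyword Arguments:
  • map_location (str | Callable) – How to remap storage locations when loading. Defaults to ‘cpu’.

  • strict (bool) – Whether the keys in the checkpoint must exactly match the model's state dict. Defaults to False.

  • revise_keys (list) – A list of (pattern, replacement) regular-expression pairs applied to the keys of the checkpoint's state dict. The default strips the ‘module.’ prefix.

  • callback (Callable | None) – Callback function to modify the checkpoint after loading. Defaults to None.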

Return type:

dict
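The revise_keys rewriting can be illustrated without a real checkpoint. The sketch below assumes each (pattern, replacement) pair is applied with re.sub to every state-dict key, in list order; revise_state_dict_keys is a hypothetical helper, not an mmengine function. The pattern escapes the dot (r'^module\.') so that it matches a literal ‘.’ rather than any character.

```python
import re
from collections import OrderedDict


def revise_state_dict_keys(state_dict, revise_keys=((r'^module\.', ''),)):
    """Hypothetical helper: apply each (pattern, replacement) pair to every key."""
    for pattern, replacement in revise_keys:
        # Rebuild the dict so that key order is preserved after renaming.
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items())
    return state_dict


# Keys saved from a wrapped model (e.g. DDP) carry a 'module.' prefix.
ckpt = OrderedDict([('module.backbone.conv.weight', 1),
                    ('module.head.fc.bias', 2)])
print(list(revise_state_dict_keys(ckpt)))
# ['backbone.conv.weight', 'head.fc.bias']

# Pairs are applied in order: strip the prefix, then rename 'head.'.
custom = [(r'^module\.', ''), (r'^head\.', 'decode_head.')]
print(list(revise_state_dict_keys(ckpt, custom)))
# ['backbone.conv.weight', 'decode_head.fc.bias']
```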

prepare(model, *, optim_wrapper=None, param_scheduler=None, compile=False, dispatch_kwargs=None)[source]

Prepare the model and related components.

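Parameters:
  • model (torch.nn.Module or dict) – The model to be prepared. A dict is treated as a config used to build the model.

Keyword Arguments: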
  • optim_wrapper (BaseOptimWrapper or dict, optional) – Computing the gradient of model parameters and updating them. Defaults to None. See build_optim_wrapper() for examples.

  • param_scheduler (_ParamScheduler or dict or list, optional) – Parameter scheduler for updating optimizer parameters. If specified, optim_wrapper should also be specified. Defaults to None. See build_param_scheduler() for examples.

  • compile (dict or bool, optional) – Config to compile the model. Defaults to False. Requires PyTorch>=2.0.

  • dispatch_kwargs (dict, optional) – Keyword arguments passed to other methods of the strategy. Defaults to None. If accumulative_counts is set in optim_wrapper, max_iters must be provided in dispatch_kwargs.
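The two constraints above (param_scheduler needs optim_wrapper; accumulative_counts in optim_wrapper needs max_iters in dispatch_kwargs) can be sketched as a pre-flight check. check_prepare_kwargs is a hypothetical helper, and the HybridAdam and LinearLR configs are illustrative assumptions rather than values required by this class.

```python
from typing import Optional


def check_prepare_kwargs(optim_wrapper: Optional[dict] = None,
                         param_scheduler=None,
                         dispatch_kwargs: Optional[dict] = None) -> None:
    """Hypothetical pre-flight check mirroring the rules documented for prepare()."""
    # Rule 1: a parameter scheduler needs an optimizer wrapper to act on.
    if param_scheduler is not None and optim_wrapper is None:
        raise ValueError('param_scheduler requires optim_wrapper')
    # Rule 2: gradient accumulation needs the total number of iterations.
    if optim_wrapper is not None and 'accumulative_counts' in optim_wrapper:
        if not dispatch_kwargs or 'max_iters' not in dispatch_kwargs:
            raise ValueError(
                'accumulative_counts requires max_iters in dispatch_kwargs')


# Illustrative configs; the type names are assumptions.
optim_wrapper = dict(
    optimizer=dict(type='HybridAdam', lr=1e-3), accumulative_counts=4)
param_scheduler = dict(
    type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=500)
dispatch_kwargs = dict(max_iters=10000)

check_prepare_kwargs(optim_wrapper, param_scheduler, dispatch_kwargs)  # passes
```

The same three dicts would then be passed to prepare as its optim_wrapper, param_scheduler, and dispatch_kwargs keyword arguments.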

resume(filename, *, resume_optimizer=True, resume_param_scheduler=True, map_location='default', callback=None)[source]

Overrides the base implementation because ColossalAI resumes the optimizer directly from filename.

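Parameters:
  • filename (str) – Path of the checkpoint to resume from.

Keyword Arguments:
  • resume_optimizer (bool) – Whether to resume the optimizer state. Defaults to True.

  • resume_param_scheduler (bool) – Whether to resume the parameter scheduler state. Defaults to True.

  • map_location (str | Callable) – How to remap storage locations when loading. Defaults to ‘default’.

  • callback (Callable | None) – Callback function to modify the checkpoint after loading. Defaults to None.

Return type: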

dict

save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]

Save a checkpoint to the given filename.

Parameters:
  • filename (str) – Filename to save the checkpoint to.

Keyword Arguments:
  • save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.

  • save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.

  • extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.

  • callback (callable, optional) – Callback function to modify the checkpoint before it is saved. Defaults to None.

Return type:

None
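A sketch of a save_checkpoint callback, assuming the strategy passes the checkpoint dict to the callback and saves whatever the callback leaves in it (in-place editing). The callback name add_run_metadata and the 'debug_info' key are hypothetical, and the toy dict only stands in for a real checkpoint.

```python
from typing import Any, Dict


def add_run_metadata(checkpoint: Dict[str, Any]) -> None:
    """Hypothetical callback: edit the checkpoint dict in place before saving."""
    checkpoint.setdefault('meta', {})['note'] = 'exported for fine-tuning'
    # Drop a (hypothetical) key that should not be written to disk.
    checkpoint.pop('debug_info', None)


# Toy stand-in for the dict a strategy would hand to the callback.
checkpoint = {'meta': {'epoch': 3}, 'debug_info': {}, 'state_dict': {}}
add_run_metadata(checkpoint)
print(sorted(checkpoint))  # ['meta', 'state_dict']
```

Editing in place and returning None keeps the callback usable whether or not the caller inspects its return value; it would be passed as callback=add_run_metadata.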
