ColossalAIStrategy¶
- class mmengine._strategy.ColossalAIStrategy(*, config=None, mixed_precision=None, plugin='gemini', model_wrapper=None, **kwargs)[source]¶
- Parameters:
config (str | dict | None) – The colossalai config file used to set up the distributed environment. See more details in the colossalai config tutorial.
mixed_precision (str or MixedPrecision) – The mixed precision to run the training. Defaults to None. If the argument is a string, it can be ‘fp16’, ‘fp16_apex’, ‘bf16’, or ‘fp8’. ‘fp16’ would use PyTorch AMP, while ‘fp16_apex’ would use Nvidia Apex.
plugin (Plugin) – The plugin to run the training. The type of plugin could be:
str: The available plugins are ‘gemini’ and ‘lowlevel-zero’. ‘gemini’ means a ZeRO implementation with chunk-based memory management. You could find more details in the colossalai gemini tutorial. ‘lowlevel-zero’ means a ZeRO-1 and ZeRO-2 implementation. Although ‘gemini’ saves more memory, unexpected errors could happen with some special model structures; ‘lowlevel-zero’ is more stable.
dict: A dict-style config used to build a colossalai plugin. See the booster plugin tutorial for more details.
model_wrapper (dict, optional) – Dict for model wrapper. Defaults to None.
work_dir (str) – The working directory to save checkpoints. The logs will be saved in the subdirectory of work_dir named timestamp. Defaults to ‘work_dirs’.
experiment_name (str, optional) – Name of the current experiment. If not specified, the timestamp will be used as experiment_name. Defaults to None.
env_kwargs (dict, optional) – Environment config passed in setup_env(). Defaults to None.
log_kwargs (dict, optional) – Logger config passed in build_logger(). Defaults to None.
auto_scale_lr (dict, optional) – Config to scale the learning rate automatically. It includes base_batch_size and enable. base_batch_size is the batch size that the optimizer lr is based on, and enable is the switch to turn the feature on and off.
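A minimal construction sketch is shown below. It assumes colossalai is installed and the process has been launched in a distributed environment (for example via torchrun); the argument values are illustrative, not prescriptive.

from mmengine._strategy import ColossalAIStrategy

# Illustrative values only: pick the plugin and precision that match your setup.
strategy = ColossalAIStrategy(
    plugin='gemini',          # chunk-based ZeRO; 'lowlevel-zero' is the more stable ZeRO-1/2 option
    mixed_precision='fp16',   # PyTorch AMP; 'fp16_apex' would require Nvidia Apex
    work_dir='work_dirs',     # checkpoints and logs go to a timestamped subdirectory
)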
- load_checkpoint(filename, *, map_location='cpu', strict=False, revise_keys=[('^module.', '')], callback=None)[source]¶
Load checkpoint from given filename.
Warning
map_location and callback parameters are not supported yet.
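A hedged usage sketch: the checkpoint path is an assumption, and, per the warning above, map_location and callback are left at their defaults.

# Assumes strategy.prepare(...) has already been called so a model exists to load into.
strategy.load_checkpoint('work_dirs/epoch_1.pth', strict=False)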
- prepare(model, *, optim_wrapper=None, param_scheduler=None, compile=False, dispatch_kwargs=None)[source]¶
Prepare model and some components.
- Parameters:
model (torch.nn.Module or dict) – The model to be run. It can be a dict used to build a model.
optim_wrapper (BaseOptimWrapper | dict | None) –
param_scheduler (_ParamScheduler | Dict | List | None) –
dispatch_kwargs (dict | None) –
- Keyword Arguments:
optim_wrapper (BaseOptimWrapper or dict, optional) – Computes the gradients of model parameters and updates them. Defaults to None. See build_optim_wrapper() for examples.
param_scheduler (_ParamScheduler or dict or list, optional) – Parameter scheduler for updating optimizer parameters. If specified, optim_wrapper should also be specified. Defaults to None. See build_param_scheduler() for examples.
compile (dict, optional) – Config to compile the model. Defaults to False. Requires PyTorch >= 2.0.
dispatch_kwargs (dict, optional) – Kwargs to be passed to other methods of Strategy. Defaults to None. If accumulative_counts is set in optim_wrapper, you need to provide max_iters in dispatch_kwargs.
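The sketch below strings the documented arguments together, continuing from the construction sketch above. The toy model, the HybridAdam optimizer config, and the dispatch_kwargs values are assumptions for illustration, and the unpacking order of the returned components follows the base Strategy convention.

import torch.nn as nn

model = nn.Linear(16, 4)  # toy model; a config dict would also be accepted

# optim_wrapper and dispatch_kwargs use the documented dict forms; HybridAdam is
# an assumption (the Gemini plugin is usually paired with ColossalAI optimizers).
model, optim_wrapper, *_ = strategy.prepare(
    model,
    optim_wrapper=dict(optimizer=dict(type='HybridAdam', lr=1e-3)),
    dispatch_kwargs=dict(max_iters=1000),  # required here if accumulative_counts is set
)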
- resume(filename, *, resume_optimizer=True, resume_param_scheduler=True, map_location='default', callback=None)[source]¶
Override this method since ColossalAI resumes the optimizer from filename directly.
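For illustration, resuming a run might look like the sketch below; the checkpoint path is an assumption.

# Restores model, optimizer and param scheduler state from the given checkpoint.
strategy.resume('work_dirs/epoch_1.pth')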
- save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]¶
Save checkpoint to given filename.
- Parameters:
filename (str) –
- Keyword Arguments:
save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.
save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.
extra_ckpt (dict, optional) – Extra checkpoint to save. Defaults to None.
callback (callable, optional) – Callback function to modify the checkpoint before saving it. Defaults to None.
- Return type:
None
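A short, hedged saving sketch: the target path and the extra_ckpt payload are assumptions.

# Persist the current training state; extra_ckpt carries user-defined metadata.
strategy.save_checkpoint(
    'work_dirs/epoch_1.pth',
    save_optimizer=True,
    save_param_scheduler=True,
    extra_ckpt=dict(meta=dict(epoch=1)),  # assumed metadata layout
)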