Runner
- class mmengine.runner.Runner(model, work_dir, train_dataloader=None, val_dataloader=None, test_dataloader=None, train_cfg=None, val_cfg=None, test_cfg=None, auto_scale_lr=None, optim_wrapper=None, param_scheduler=None, val_evaluator=None, test_evaluator=None, default_hooks=None, custom_hooks=None, data_preprocessor=None, load_from=None, resume=False, launcher='none', env_cfg={'dist_cfg': {'backend': 'nccl'}}, log_processor=None, log_level='INFO', visualizer=None, default_scope='mmengine', randomness={'seed': None}, experiment_name=None, cfg=None)
A training helper for PyTorch.
A Runner object can be built from a config by
runner = Runner.from_cfg(cfg)
where cfg usually contains training, validation, and test-related configurations to build the corresponding components. The same config is usually used to launch training, validation, and testing tasks. However, only some of these components are needed at any one time; for example, testing a model does not need training or validation-related components.
To avoid repeatedly modifying the config, the construction of Runner adopts lazy initialization and only initializes components when they are about to be used. The model is always initialized at the beginning, whereas training, validation, and testing-related components are only initialized when calling runner.train(), runner.val(), and runner.test(), respectively.
- Parameters:
model (torch.nn.Module or dict) – The model to be run. It can be a dict used to build a model.
work_dir (str) – The working directory to save checkpoints. The logs will be saved in the subdirectory of work_dir named timestamp.
train_dataloader (Dataloader or dict, optional) – A dataloader object or a dict to build a dataloader. If None is given, training steps are skipped. Defaults to None. See build_dataloader() for more details.
val_dataloader (Dataloader or dict, optional) – A dataloader object or a dict to build a dataloader. If None is given, validation steps are skipped. Defaults to None. See build_dataloader() for more details.
test_dataloader (Dataloader or dict, optional) – A dataloader object or a dict to build a dataloader. If None is given, test steps are skipped. Defaults to None. See build_dataloader() for more details.
train_cfg (dict, optional) – A dict to build a training loop. If it does not provide a "type" key, it should contain "by_epoch" to decide which type of training loop, EpochBasedTrainLoop or IterBasedTrainLoop, should be used. If train_cfg is specified, train_dataloader should also be specified. Defaults to None. See build_train_loop() for more details.
val_cfg (dict, optional) – A dict to build a validation loop. If it does not provide a "type" key, ValLoop will be used by default. If val_cfg is specified, val_dataloader should also be specified. If ValLoop is built with fp16=True, runner.val() will be performed under fp16 precision. Defaults to None. See build_val_loop() for more details.
test_cfg (dict, optional) – A dict to build a test loop. If it does not provide a "type" key, TestLoop will be used by default. If test_cfg is specified, test_dataloader should also be specified. If TestLoop is built with fp16=True, runner.test() will be performed under fp16 precision. Defaults to None. See build_test_loop() for more details.
auto_scale_lr (dict, optional) – Config to scale the learning rate automatically. It includes base_batch_size and enable. base_batch_size is the batch size that the optimizer lr is based on. enable is the switch to turn the feature on and off.
optim_wrapper (OptimWrapper or dict, optional) – The optimizer wrapper used to compute gradients of model parameters. If specified, train_dataloader should also be specified. If automatic mixed precision or gradient accumulation training is required, the type of optim_wrapper should be AmpOptimWrapper. See build_optim_wrapper() for examples. Defaults to None.
param_scheduler (_ParamScheduler or dict or list, optional) – Parameter scheduler for updating optimizer parameters. If specified, optim_wrapper should also be specified. Defaults to None. See build_param_scheduler() for examples.
val_evaluator (Evaluator or dict or list, optional) – An evaluator object used for computing metrics for validation. It can be a dict or a list of dicts to build an evaluator. If specified, val_dataloader should also be specified. Defaults to None.
test_evaluator (Evaluator or dict or list, optional) – An evaluator object used for computing metrics for test steps. It can be a dict or a list of dicts to build an evaluator. If specified, test_dataloader should also be specified. Defaults to None.
default_hooks (dict[str, dict] or dict[str, Hook], optional) – Hooks to execute default actions like updating model parameters and saving checkpoints. Default hooks are RuntimeInfoHook, IterTimerHook, DistSamplerSeedHook, LoggerHook, ParamSchedulerHook and CheckpointHook. Defaults to None. See register_default_hooks() for more details.
custom_hooks (list[dict] or list[Hook], optional) – Hooks to execute custom actions like visualizing images processed by pipeline. Defaults to None.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. If the model argument is a dict and doesn't contain the key data_preprocessor, this argument is set as the data_preprocessor of the model dict. Defaults to None.
load_from (str, optional) – The checkpoint file to load from. Defaults to None.
resume (bool) – Whether to resume training. Defaults to False. If resume is True and load_from is None, the latest checkpoint is searched for automatically in work_dir. If none is found, resuming does nothing.
launcher (str) – The way to launch multiple processes. Supported launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided, a non-distributed environment will be launched.
env_cfg (dict) – A dict used for setting the environment. Defaults to dict(dist_cfg=dict(backend='nccl')).
log_processor (dict, optional) – A processor to format logs. Defaults to None.
log_level (int or str) – The log level of MMLogger handlers. Defaults to ‘INFO’.
visualizer (Visualizer or dict, optional) – A Visualizer object or a dict to build a Visualizer object. Defaults to None. If not specified, the default config will be used.
default_scope (str) – Used to reset registries location. Defaults to “mmengine”.
randomness (dict) – Settings that make the experiment as reproducible as possible, such as seed and deterministic. Defaults to dict(seed=None). If seed is None, a random number will be generated and, in a distributed environment, broadcast to all other processes. If cudnn_benchmark is True in env_cfg but deterministic is True in randomness, the value of torch.backends.cudnn.benchmark will end up as False.
experiment_name (str, optional) – Name of the current experiment. If not specified, the timestamp will be used as experiment_name. Defaults to None.
cfg (dict or ConfigDict or Config, optional) – Full config. Defaults to None.
Note
Since PyTorch 2.0.0, you can enable torch.compile by passing in cfg.compile = True. If you want to control compile options, you can pass a dict, e.g. cfg.compile = dict(backend='eager'). Refer to the PyTorch API documentation for more valid options.
Examples
>>> from mmengine.runner import Runner
>>> cfg = dict(
...     model=dict(type='ToyModel'),
...     work_dir='path/of/work_dir',
...     train_dataloader=dict(
...         dataset=dict(type='ToyDataset'),
...         sampler=dict(type='DefaultSampler', shuffle=True),
...         batch_size=1,
...         num_workers=0),
...     val_dataloader=dict(
...         dataset=dict(type='ToyDataset'),
...         sampler=dict(type='DefaultSampler', shuffle=False),
...         batch_size=1,
...         num_workers=0),
...     test_dataloader=dict(
...         dataset=dict(type='ToyDataset'),
...         sampler=dict(type='DefaultSampler', shuffle=False),
...         batch_size=1,
...         num_workers=0),
...     auto_scale_lr=dict(base_batch_size=16, enable=False),
...     optim_wrapper=dict(type='OptimWrapper', optimizer=dict(
...         type='SGD', lr=0.01)),
...     param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
...     val_evaluator=dict(type='ToyEvaluator'),
...     test_evaluator=dict(type='ToyEvaluator'),
...     train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
...     val_cfg=dict(),
...     test_cfg=dict(),
...     custom_hooks=[],
...     default_hooks=dict(
...         timer=dict(type='IterTimerHook'),
...         checkpoint=dict(type='CheckpointHook', interval=1),
...         logger=dict(type='LoggerHook'),
...         param_scheduler=dict(type='ParamSchedulerHook')),
...     launcher='none',
...     env_cfg=dict(dist_cfg=dict(backend='nccl')),
...     log_processor=dict(window_size=20),
...     visualizer=dict(type='Visualizer',
...                     vis_backends=[dict(type='LocalVisBackend',
...                                        save_dir='temp_dir')])
... )
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
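The lazy initialization described in the class overview can be illustrated with a small, self-contained sketch. This is not mmengine's implementation: LazyRunner, _build and _get_loader are hypothetical names, and "building" a component is reduced to marking its config dict as built. It only shows the pattern of building the model eagerly and everything else on first use.

```python
class LazyRunner:
    """Toy stand-in for a runner; not mmengine's Runner."""

    def __init__(self, model, train_dataloader=None, test_dataloader=None):
        # The model is built immediately.
        self.model = self._build(model)
        # Dataloader configs are only stored; nothing is built yet.
        self._loader_cfgs = dict(train=train_dataloader, test=test_dataloader)
        self._loaders = {}

    @staticmethod
    def _build(cfg):
        # Hypothetical builder: mark a config dict as "built".
        return dict(cfg, built=True)

    def _get_loader(self, name):
        # Build the requested dataloader on first use and cache it.
        if name not in self._loaders:
            cfg = self._loader_cfgs[name]
            if cfg is None:
                raise RuntimeError(f'{name}_dataloader was not configured')
            self._loaders[name] = self._build(cfg)
        return self._loaders[name]

    def train(self):
        return self._get_loader('train')

    def test(self):
        return self._get_loader('test')


runner = LazyRunner(
    model=dict(type='ToyModel'),
    train_dataloader=dict(batch_size=1),
    test_dataloader=dict(batch_size=1))
built_before = sorted(runner._loaders)  # no dataloader has been built yet
runner.test()
built_after = sorted(runner._loaders)   # only the test dataloader exists
```

Calling test() never touches the training dataloader, which is why one config can serve training, validation, and testing runs.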
- static build_dataloader(dataloader, seed=None, diff_rank_seed=False)
Build dataloader.
The method builds three components:
Dataset
Sampler
Dataloader
An example of dataloader:
dataloader = dict(
    dataset=dict(type='ToyDataset'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_size=1,
    num_workers=9
)
- Parameters:
dataloader (DataLoader or dict) – A Dataloader object or a dict to build a Dataloader object. If dataloader is a Dataloader object, it is returned as is.
seed (int, optional) – Random seed. Defaults to None.
diff_rank_seed (bool) – Whether to set different seeds for different ranks. If True, the seed passed to the sampler is set to None, in order to synchronize the seeds used in samplers across different ranks. Defaults to False.
- Returns:
DataLoader built from dataloader.
- Return type:
Dataloader
- build_evaluator(evaluator)
Build evaluator.
Examples of evaluator:
# evaluator could be a built Evaluator instance
evaluator = Evaluator(metrics=[ToyMetric()])
# evaluator can also be a list of dict
evaluator = [
    dict(type='ToyMetric1'),
    dict(type='ToyMetric2')
]
# evaluator can also be a list of built metric
evaluator = [ToyMetric1(), ToyMetric2()]
# evaluator can also be a dict with key metrics
evaluator = dict(metrics=ToyMetric())
# metric is a list
evaluator = dict(metrics=[ToyMetric()])
- build_log_processor(log_processor)
Build log processor.
Examples of log_processor:
# LogProcessor will be used
log_processor = dict()
# custom log_processor
log_processor = dict(type='CustomLogProcessor')
- Parameters:
log_processor (LogProcessor or dict) – A log processor or a dict to build a log processor. If log_processor is a log processor object, it is returned as is.
- Returns:
Log processor object built from log_processor.
- Return type:
LogProcessor
- build_logger(log_level='INFO', log_file=None, **kwargs)
Build a globally accessible MMLogger.
- build_message_hub(message_hub=None)
Build a globally accessible MessageHub.
- Parameters:
message_hub (dict, optional) – A dict to build a MessageHub object. If not specified, the default config will be used to build the MessageHub object. Defaults to None.
- Returns:
A MessageHub object built from message_hub.
- Return type:
MessageHub
- build_model(model)
Build model.
If model is a dict, it will be used to build an nn.Module object. Otherwise, if model is an nn.Module object, it will be returned directly.
An example of model:
model = dict(type='ResNet')
- Parameters:
model (nn.Module or dict) – An nn.Module object or a dict to build an nn.Module object. If model is an nn.Module object, it is returned as is.
- Return type:
nn.Module
Note
The returned model must implement train_step if runner.train will be called, and test_step if runner.test will be called. If runner.val will be called or val_cfg is configured, the model must implement val_step.
- build_optim_wrapper(optim_wrapper)
Build optimizer wrapper.
If optim_wrapper is a config dict for only one optimizer, the keys must contain optimizer, and type is optional. An OptimWrapper will be built by default.
If optim_wrapper is a config dict for multiple optimizers, i.e., it has multiple keys and each key is for an optimizer wrapper, the constructor must be specified, since DefaultOptimizerConstructor cannot handle building multiple optimizers.
If optim_wrapper is a dict of pre-built optimizer wrappers, i.e., each value of optim_wrapper is an OptimWrapper instance, build_optim_wrapper will directly build the OptimWrapperDict instance from optim_wrapper.
- Parameters:
optim_wrapper (OptimWrapper or dict) – An OptimWrapper object or a dict to build OptimWrapper objects. If optim_wrapper is an OptimWrapper, it is returned as is.
Note
For single optimizer training, if optim_wrapper is a config dict, type is optional (defaults to OptimWrapper) and it must contain optimizer to build the corresponding optimizer.
Examples
>>> # build an optimizer
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
...     type='SGD', lr=0.01))
>>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> # is also valid.
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0
)
>>> # build optimizer without `type`
>>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)
>>> # build multiple optimizers
>>> optim_wrapper_cfg = dict(
...     generator=dict(type='OptimWrapper', optimizer=dict(
...         type='SGD', lr=0.01)),
...     discriminator=dict(type='OptimWrapper', optimizer=dict(
...         type='Adam', lr=0.001)),
...     # need to customize a multiple optimizer constructor
...     constructor='CustomMultiOptimizerConstructor',
... )
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
name: generator
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0
)
name: discriminator
Type: OptimWrapper
accumulative_counts: 1
optimizer:
Adam (
Parameter Group 0
    lr: 0.001
    ...
)
Important
If you need to build multiple optimizers, you should implement a MultiOptimWrapperConstructor which receives the parameters passed to the corresponding optimizers and composes the OptimWrapperDict. More details about how to customize OptimizerConstructor can be found at optimizer-docs.
- Returns:
Optimizer wrapper built from optim_wrapper.
- Return type:
OptimWrapper or OptimWrapperDict
- build_param_scheduler(scheduler)
Build parameter schedulers.
build_param_scheduler should be called after build_optim_wrapper because the building logic changes according to the number of optimizers built by the runner. The cases are as follows:
Single optimizer: When only one optimizer is built and used in the runner, build_param_scheduler will return a list of parameter schedulers.
Multiple optimizers: When two or more optimizers are built and used in the runner, build_param_scheduler will return a dict with the same keys as the optimizers, where each value is a list of parameter schedulers. Note that if you want different optimizers to use different parameter schedulers to update their hyper-parameters, the input parameter scheduler also needs to be a dict whose keys are consistent with those of the optimizers. Otherwise, the same parameter schedulers will be used to update every optimizer's hyper-parameters.
- Parameters:
scheduler (_ParamScheduler or dict or list) – A parameter scheduler object, or a dict or list of dicts to build parameter schedulers.
Examples
>>> # build one scheduler
>>> optim_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> runner.optim_wrapper = runner.build_optim_wrapper(
...     optim_cfg)
>>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2])
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
>>> schedulers
[<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f6966290>]
>>> # build multiple schedulers
>>> scheduler_cfg = [
...     dict(type='MultiStepLR', milestones=[1, 2]),
...     dict(type='StepLR', step_size=1)
... ]
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
>>> schedulers
[<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f60dd3d0>,
<mmengine.optim.scheduler.lr_scheduler.StepLR at 0x7f70f6eb6150>]
The examples above only cover one optimizer with one or multiple schedulers. To learn how to set parameter schedulers when using multiple optimizers, see the examples in optimizer-docs.
- Returns:
List of parameter schedulers, or a dictionary containing lists of parameter schedulers, built from scheduler.
- Return type:
list[_ParamScheduler] or dict[str, list[_ParamScheduler]]
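The two return shapes described above (a list for a single optimizer, a dict of lists for multiple optimizers) can be sketched without mmengine. The helper below is hypothetical and works on plain config dicts only. In particular, deciding that scheduler is keyed per optimizer by comparing its keys with the optimizer names is an assumption of this sketch, not the library's documented rule.

```python
def plan_param_schedulers(scheduler, optimizer_names=None):
    """Return scheduler configs in the shape the runner would return them.

    optimizer_names is None for single-optimizer training, otherwise the
    keys of the multiple optimizers (hypothetical argument).
    """
    def as_list(cfg):
        # A single config becomes a one-element list.
        return list(cfg) if isinstance(cfg, (list, tuple)) else [cfg]

    if optimizer_names is None:
        # Single optimizer: a flat list of schedulers.
        return as_list(scheduler)

    # Assumption: a dict whose keys equal the optimizer names holds
    # separate schedulers for each optimizer.
    per_optimizer = (isinstance(scheduler, dict)
                     and set(scheduler) == set(optimizer_names))
    if per_optimizer:
        return {name: as_list(scheduler[name]) for name in optimizer_names}
    # Otherwise every optimizer shares the same scheduler configs.
    return {name: as_list(scheduler) for name in optimizer_names}


names = ['generator', 'discriminator']
single = plan_param_schedulers(dict(type='MultiStepLR', milestones=[1, 2]))
shared = plan_param_schedulers(dict(type='StepLR', step_size=1), names)
separate = plan_param_schedulers(
    dict(generator=[dict(type='MultiStepLR', milestones=[1, 2]),
                    dict(type='StepLR', step_size=1)],
         discriminator=dict(type='StepLR', step_size=1)),
    names)
```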
- build_test_loop(loop)
Build test loop.
Examples of loop:
# `TestLoop` will be used
loop = dict()
# custom test loop
loop = dict(type='CustomTestLoop')
- build_train_loop(loop)
Build training loop.
Examples of loop:
# `EpochBasedTrainLoop` will be used
loop = dict(by_epoch=True, max_epochs=3)
# `IterBasedTrainLoop` will be used
loop = dict(by_epoch=False, max_iters=3)
# custom training loop
loop = dict(type='CustomTrainLoop', max_epochs=3)
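The selection rule stated for train_cfg (use "type" if present, otherwise let "by_epoch" choose between the two built-in loops) can be written as a short sketch. resolve_train_loop_type is a hypothetical helper that only returns the loop name; it does not build a loop and is not part of mmengine.

```python
def resolve_train_loop_type(loop_cfg):
    """Return the name of the training loop a config dict would select."""
    if 'type' in loop_cfg:
        # An explicit type always wins.
        return loop_cfg['type']
    if 'by_epoch' not in loop_cfg:
        raise ValueError('loop config needs either "type" or "by_epoch"')
    if loop_cfg['by_epoch']:
        return 'EpochBasedTrainLoop'
    return 'IterBasedTrainLoop'


epoch_loop = resolve_train_loop_type(dict(by_epoch=True, max_epochs=3))
iter_loop = resolve_train_loop_type(dict(by_epoch=False, max_iters=3))
custom_loop = resolve_train_loop_type(dict(type='CustomTrainLoop', max_epochs=3))
```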
- build_val_loop(loop)
Build validation loop.
Examples of loop:
# ValLoop will be used
loop = dict()
# custom validation loop
loop = dict(type='CustomValLoop')
- build_visualizer(visualizer=None)
Build a globally accessible Visualizer.
- Parameters:
visualizer (Visualizer or dict, optional) – A Visualizer object or a dict to build a Visualizer object. If visualizer is a Visualizer object, it is returned as is. If not specified, the default config will be used to build the Visualizer object. Defaults to None.
- Returns:
A Visualizer object built from visualizer.
- Return type:
Visualizer
- call_hook(fn_name, **kwargs)
Call all hooks.
- Parameters:
fn_name (str) – The function name in each hook to be called, such as “before_train_epoch”.
**kwargs – Keyword arguments passed to hook.
- Return type:
None
- classmethod from_cfg(cfg)
Build a runner from config.
- Parameters:
cfg (ConfigType) – A config used for building the runner. For the supported keys of cfg, see __init__().
- Returns:
A runner built from cfg.
- Return type:
Runner
- property hooks
A list of registered hooks.
- Type:
list[
Hook
]
- load_checkpoint(filename, map_location='cpu', strict=False, revise_keys=[('^module.', '')])
Load checkpoint from given filename.
- Parameters:
filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.
map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'cpu'.
strict (bool) – Whether to allow different params for the model and checkpoint. Defaults to False.
revise_keys (list) – A list of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair for a regular expression substitution. Defaults to stripping the prefix 'module.' with [(r'^module.', '')].
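How a list of (pattern, replacement) pairs rewrites checkpoint keys can be shown with the standard re module. The function below is a hypothetical sketch rather than the loader's actual code, and it escapes the dot in the pattern (r'^module\.') so that only a literal "module." prefix is stripped; that escaping is an assumption of this example.

```python
import re
from collections import OrderedDict


def revise_state_dict_keys(state_dict, revise_keys=((r'^module\.', ''),)):
    """Apply each (pattern, replacement) pair to every key in order."""
    revised = OrderedDict(state_dict)
    for pattern, replacement in revise_keys:
        revised = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in revised.items())
    return revised


# Keys as they might look after saving a wrapped (data-parallel) model.
checkpoint = OrderedDict([
    ('module.conv.weight', 1),
    ('module.fc.bias', 2),
    ('submodule.scale', 3),
])
cleaned = revise_state_dict_keys(checkpoint)
```

Because the pattern is anchored with ^, a key such as 'submodule.scale' is left untouched.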
- register_default_hooks(hooks=None)
Register default hooks into hook list.
hooks will be registered into the runner to execute some default actions like updating model parameters or saving checkpoints.
Default hooks and their priorities:
Hooks                  Priority
RuntimeInfoHook        VERY_HIGH (10)
IterTimerHook          NORMAL (50)
DistSamplerSeedHook    NORMAL (50)
LoggerHook             BELOW_NORMAL (60)
ParamSchedulerHook     LOW (70)
CheckpointHook         VERY_LOW (90)
If hooks is None, the above hooks will be registered by default:
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)
If not None, hooks will be merged into default_hooks. If a value is None after merging, the corresponding item will be popped from default_hooks:
hooks = dict(timer=None)
The final registered default hooks will be RuntimeInfoHook, DistSamplerSeedHook, LoggerHook, ParamSchedulerHook and CheckpointHook.
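The merge-then-pop behaviour described above can be sketched with plain dicts. merge_default_hooks is a hypothetical helper, not the runner's method; it only reproduces the rule that user-supplied entries override the defaults and that a None value removes the corresponding default hook.

```python
import copy

DEFAULT_HOOKS = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)


def merge_default_hooks(hooks=None):
    """Merge user hooks into a copy of the defaults; None drops an entry."""
    merged = copy.deepcopy(DEFAULT_HOOKS)
    for name, hook in (hooks or {}).items():
        if hook is None:
            merged.pop(name, None)
        else:
            merged[name] = hook
    return merged


merged = merge_default_hooks(dict(
    timer=None,
    checkpoint=dict(type='CheckpointHook', interval=5)))
remaining = [cfg['type'] for cfg in merged.values()]
```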
- register_hook(hook, priority=None)
Register a hook into the hook list.
The hook will be inserted into a priority queue with the specified priority (see Priority for details of priorities). Hooks with the same priority are triggered in the order in which they are registered.
The priority of a hook is decided in the following order:
priority argument. If priority is given, it will be the priority of the hook.
If the hook argument is a dict and contains priority, the priority will be the value of hook['priority'].
If the hook argument is a dict that does not contain priority, or hook is an instance of Hook, the priority will be hook.priority.
- Parameters:
hook (Hook or dict) – The hook to be registered.
priority (int or str or Priority, optional) – Hook priority. A lower value means a higher priority.
- Return type:
None
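The priority resolution order and the stable insertion described for register_hook can be sketched in plain Python. This is a simplified illustration, not the runner's code: hooks are represented as dicts, the PRIORITIES mapping only contains the levels listed in the default-hooks table, and a dict without a priority key falls back to 'NORMAL', whereas the real method would use the priority attribute of the built hook.

```python
# Only the levels listed in the default-hooks table are included here.
PRIORITIES = {'VERY_HIGH': 10, 'NORMAL': 50, 'BELOW_NORMAL': 60,
              'LOW': 70, 'VERY_LOW': 90}


def resolve_priority(hook, priority=None, default='NORMAL'):
    """The priority argument wins, then hook['priority'], then a default."""
    if priority is None:
        priority = hook.get('priority', default)
    return PRIORITIES[priority] if isinstance(priority, str) else int(priority)


def register_hook(hooks, hook, priority=None):
    """Insert hook so the list stays sorted and equal priorities keep order."""
    value = resolve_priority(hook, priority)
    index = len(hooks)
    # Walk left past every hook with a strictly larger priority value.
    while index > 0 and hooks[index - 1][0] > value:
        index -= 1
    hooks.insert(index, (value, hook))


hooks = []
register_hook(hooks, dict(type='A'))                        # NORMAL (50)
register_hook(hooks, dict(type='C', priority='VERY_LOW'))   # 90
register_hook(hooks, dict(type='B'))                        # NORMAL, after A
register_hook(hooks, dict(type='D', priority='LOW'), priority=10)
order = [hook['type'] for _, hook in hooks]
```

Hook D ends up first because the explicit priority argument (10) overrides the 'LOW' value stored in its dict.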
- register_hooks(default_hooks=None, custom_hooks=None)
Register default hooks and custom hooks into hook list.
- Parameters:
default_hooks (dict[str, dict] or dict[str, Hook], optional) – Hooks to execute default actions like updating model parameters and saving checkpoints. Defaults to None.
custom_hooks (list[dict] or list[Hook], optional) – Hooks to execute custom actions like visualizing images processed by pipeline. Defaults to None.
- Return type:
None
- resume(filename, resume_optimizer=True, resume_param_scheduler=True, map_location='default')
Resume model from checkpoint.
- Parameters:
filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.
resume_optimizer (bool) – Whether to resume optimizer state. Defaults to True.
resume_param_scheduler (bool) – Whether to resume param scheduler state. Defaults to True.
map_location (str or callable) – A string or a callable function specifying how to remap storage locations. Defaults to 'default'.
- Return type:
None
- save_checkpoint(out_dir, filename, file_client_args=None, save_optimizer=True, save_param_scheduler=True, meta=None, by_epoch=True, backend_args=None)
Save checkpoints.
CheckpointHook invokes this method to save checkpoints periodically.
- Parameters:
out_dir (str) – The directory in which checkpoints are saved.
filename (str) – The checkpoint filename.
file_client_args (dict, optional) – Arguments to instantiate a FileClient. See mmengine.fileio.FileClient for details. Defaults to None. It will be deprecated in the future; please use backend_args instead.
save_optimizer (bool) – Whether to save the optimizer to the checkpoint. Defaults to True.
save_param_scheduler (bool) – Whether to save the param_scheduler to the checkpoint. Defaults to True.
meta (dict, optional) – The meta information to be saved in the checkpoint. Defaults to None.
by_epoch (bool) – Decides whether the number of epochs or the number of iterations is saved in the checkpoint. Defaults to True.
backend_args (dict, optional) – Arguments to instantiate the backend corresponding to the URI prefix. Defaults to None. New in v0.2.0.
- scale_lr(optim_wrapper, auto_scale_lr=None)
Automatically scale the learning rate in training according to the ratio of the real batch size to base_batch_size in auto_scale_lr.
It scales the learning rate linearly according to the paper.
Note
scale_lr must be called after building optimizer wrappers and before building parameter schedulers.
- Parameters:
optim_wrapper (OptimWrapper) – An OptimWrapper object whose parameter groups' learning rates need to be scaled.
auto_scale_lr (dict, optional) – Config to scale the learning rate automatically. It includes base_batch_size and enable. base_batch_size is the batch size that the optimizer lr is based on. enable is the switch to turn the feature on and off.
- Return type:
None
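Linear learning-rate scaling can be sketched as follows. The helper is hypothetical and operates on plain parameter-group dicts rather than an OptimWrapper. It also assumes that the real batch size is the per-process batch size multiplied by the number of processes, which the text above does not spell out.

```python
def scale_lr_linearly(param_groups, batch_size, world_size, auto_scale_lr):
    """Scale every group's lr by real_batch_size / base_batch_size.

    Returns the ratio that was applied (1.0 when scaling is disabled).
    """
    if not auto_scale_lr.get('enable', False):
        return 1.0
    # Assumption: real batch size = per-process batch size * process count.
    real_batch_size = batch_size * world_size
    ratio = real_batch_size / auto_scale_lr['base_batch_size']
    for group in param_groups:
        group['lr'] *= ratio
    return ratio


groups = [dict(lr=0.01), dict(lr=0.1)]
ratio = scale_lr_linearly(
    groups, batch_size=8, world_size=4,
    auto_scale_lr=dict(base_batch_size=16, enable=True))
```

With 32 samples per step against a base of 16, every learning rate is doubled. Scaling before the parameter schedulers are built matters because schedulers typically record the initial learning rate when they are created.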
- set_randomness(seed, diff_rank_seed=False, deterministic=False)
Set random seed to guarantee reproducible results.
- Parameters:
seed (int) – The seed used for the random modules.
diff_rank_seed (bool) – Whether to set different seeds according to global rank. Defaults to False.
deterministic (bool) – Whether to set the deterministic option for CUDNN backend, i.e., set torch.backends.cudnn.deterministic to True and torch.backends.cudnn.benchmark to False. Defaults to False. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
- Return type:
None
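A minimal sketch of rank-dependent seeding, using only Python's random module. set_random_seed here is a hypothetical helper: offsetting the seed by the rank is one common way to give each process a different seed and is an assumption of this example, and a full implementation would presumably seed other random sources (for instance NumPy and PyTorch) and handle the deterministic flag as well.

```python
import random


def set_random_seed(seed=None, rank=0, diff_rank_seed=False):
    """Seed Python's random module and return the seed actually used."""
    if seed is None:
        # No seed given: draw one at random.
        seed = random.SystemRandom().randint(0, 2**31 - 1)
    if diff_rank_seed:
        # Assumption: offset by the global rank so each process differs.
        seed += rank
    random.seed(seed)
    return seed


used_seed = set_random_seed(42, rank=3, diff_rank_seed=True)
first_draw = random.random()
set_random_seed(42, rank=3, diff_rank_seed=True)
second_draw = random.random()  # same seed, so the same value is drawn again
```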
- setup_env(env_cfg)
Set up the environment.
An example of env_cfg:
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(
        mp_start_method='fork',
        opencv_num_threads=0
    ),
    dist_cfg=dict(backend='nccl', timeout=1800),
    resource_limit=4096
)
- Parameters:
env_cfg (dict) – Config for setting environment.
- Return type:
None
- property test_dataloader
The data loader for testing.
- property test_evaluator
An evaluator for testing.
- Type:
Evaluator
- property train_dataloader
The data loader for training.
- property val_dataloader
The data loader for validation.
- property val_evaluator
An evaluator for validation.
- Type:
Evaluator
- wrap_model(model_wrapper_cfg, model)
Wrap the model to MMDistributedDataParallel or other custom distributed data-parallel module wrappers.
An example of model_wrapper_cfg:
model_wrapper_cfg = dict(
    broadcast_buffers=False,
    find_unused_parameters=False
)
- Parameters:
model_wrapper_cfg (dict, optional) – Config to wrap the model. If not specified, DistributedDataParallel will be used in a distributed environment. Defaults to None.
model (nn.Module) – Model to be wrapped.
- Returns:
nn.Module or subclass of DistributedDataParallel.
- Return type:
nn.Module or DistributedDataParallel