Migrate Hook from MMCV to MMEngine
Introduction
As our architecture design has been upgraded and user demands keep growing, the existing hook mount points in MMCV can no longer meet the requirements. We therefore redesigned the mount points in MMEngine and adjusted the functions of hooks accordingly. Reading the Hook Design tutorial before you migrate is strongly recommended.
This tutorial compares hooks in MMCV v1.6.0 and MMEngine v0.5.0 in terms of function, mount point, usage and implementation.
Function Comparison
Function | MMCV | MMEngine
---|---|---
Backpropagation and gradient update | OptimizerHook, GradientCumulativeOptimizerHook | Backpropagation and gradient update are unified into OptimWrapper rather than hooks
Learning rate adjustment | LrUpdaterHook | ParamSchedulerHook and subclasses of _ParamScheduler adjust the optimizer hyperparameters
Momentum adjustment | MomentumUpdaterHook | ParamSchedulerHook and subclasses of _ParamScheduler adjust the optimizer hyperparameters
Saving model weights at specified interval | CheckpointHook | CheckpointHook saves weights at the specified interval and also saves the optimal weights; the model evaluation function of EvalHook is delegated to ValLoop or TestLoop
Model evaluation and optimal weights saving | EvalHook | CheckpointHook saves the optimal weights; the model evaluation function of EvalHook is delegated to ValLoop or TestLoop
Log printing | LoggerHook and its subclasses can print logs, save logs and visualize data | LoggerHook
Visualization | LoggerHook and its subclasses can print logs, save logs and visualize data | NaiveVisualizationHook
Adding runtime information | LoggerHook and its subclasses can print logs, save logs and visualize data | RuntimeInfoHook
Model weights exponential moving average (EMA) | EMAHook | EMAHook
Ensuring that the shuffle functionality of the distributed Sampler takes effect | DistSamplerSeedHook | DistSamplerSeedHook
Synchronizing model buffers | SyncBufferHook | SyncBufferHook
Emptying the PyTorch CUDA cache | EmptyCacheHook | EmptyCacheHook
Measuring the time consumed by each iteration | IterTimerHook | IterTimerHook
Analyzing bottlenecks of training time | ProfilerHook | Not yet available
Providing the most concise way to register a function | ClosureHook | Not yet available
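To illustrate the first row of the table, the sketch below moves the options that MMCV passes to OptimizerHook or GradientCumulativeOptimizerHook into an MMEngine-style optim_wrapper config. The helper `to_optim_wrapper` is hypothetical and belongs to neither library. It assumes that the MMCV keys `grad_clip` and `cumulative_iters` correspond to the `clip_grad` and `accumulative_counts` arguments of OptimWrapper, and it ignores every other option.

```python
def to_optim_wrapper(optimizer, optimizer_config=None):
    """Build an MMEngine-style optim_wrapper dict from MMCV-style settings.

    Hypothetical helper for illustration only; it handles just the two
    options named below and ignores everything else.
    """
    optimizer_config = dict(optimizer_config or {})
    wrapper = dict(type='OptimWrapper', optimizer=dict(optimizer))

    # MMCV `grad_clip` is assumed to map to OptimWrapper's `clip_grad`.
    grad_clip = optimizer_config.get('grad_clip')
    if grad_clip is not None:
        wrapper['clip_grad'] = dict(grad_clip)

    # MMCV `cumulative_iters` is assumed to map to `accumulative_counts`.
    cumulative_iters = optimizer_config.get('cumulative_iters', 1)
    if cumulative_iters > 1:
        wrapper['accumulative_counts'] = cumulative_iters
    return wrapper


mmcv_optimizer = dict(type='SGD', lr=0.001, momentum=0.9)
mmcv_optimizer_config = dict(
    type='GradientCumulativeOptimizerHook',
    cumulative_iters=4,
    grad_clip=dict(max_norm=35, norm_type=2))

optim_wrapper = to_optim_wrapper(mmcv_optimizer, mmcv_optimizer_config)
print(optim_wrapper)
```

The resulting dict can be passed as the `optim_wrapper` argument of the MMEngine Runner, so no optimizer hook needs to be registered.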
Mount Point Comparison
Category | Mount point | MMCV | MMEngine
---|---|---|---
Global mount points | before run | before_run | before_run
Global mount points | after run | after_run | after_run
Checkpoint related | after loading checkpoints | None | after_load_checkpoint
Checkpoint related | before saving checkpoints | None | before_save_checkpoint
Training related | triggered before training | None | before_train
Training related | triggered after training | None | after_train
Training related | before each epoch | before_train_epoch | before_train_epoch
Training related | after each epoch | after_train_epoch | after_train_epoch
Training related | before each iteration | before_train_iter | before_train_iter, with additional args: batch_idx and data_batch
Training related | after each iteration | after_train_iter | after_train_iter, with additional args: batch_idx, data_batch and outputs
Validation related | before validation | None | before_val
Validation related | after validation | None | after_val
Validation related | before each epoch | before_val_epoch | before_val_epoch
Validation related | after each epoch | after_val_epoch | after_val_epoch
Validation related | before each iteration | before_val_iter | before_val_iter, with additional args: batch_idx and data_batch
Validation related | after each iteration | after_val_iter | after_val_iter, with additional args: batch_idx, data_batch and outputs
Test related | before test | None | before_test
Test related | after test | None | after_test
Test related | before each epoch | None | before_test_epoch
Test related | after each epoch | None | after_test_epoch
Test related | before each iteration | None | before_test_iter, with additional args: batch_idx and data_batch
Test related | after each iteration | None | after_test_iter, with additional args: batch_idx, data_batch and outputs
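The following sketch shows how the training-related mount points of MMEngine might fire during one epoch-based training run without validation. Both `RecordingHook` and `simulate_training` are hypothetical stand-ins written in plain Python, not MMEngine classes; the call order is a simplified assumption about how a runner drives its hooks, and the `outputs` dict is a placeholder for what a model's training step would return.

```python
class RecordingHook:
    """Stand-in hook that records every mount point it is called at."""

    def __init__(self):
        self.calls = []

    def before_run(self, runner):
        self.calls.append('before_run')

    def before_train(self, runner):
        self.calls.append('before_train')

    def before_train_epoch(self, runner):
        self.calls.append('before_train_epoch')

    # MMEngine-style signature: batch_idx and data_batch are passed in.
    def before_train_iter(self, runner, batch_idx, data_batch=None):
        self.calls.append(f'before_train_iter[{batch_idx}]')

    # MMEngine-style signature: outputs is passed in as well.
    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        self.calls.append(f'after_train_iter[{batch_idx}]')

    def after_train_epoch(self, runner):
        self.calls.append('after_train_epoch')

    def after_train(self, runner):
        self.calls.append('after_train')

    def after_run(self, runner):
        self.calls.append('after_run')


def simulate_training(hook, batches, max_epochs=1):
    """Call the hook in the assumed order of an epoch-based training run."""
    runner = None  # a real Runner instance would be passed to the hook
    hook.before_run(runner)
    hook.before_train(runner)
    for _ in range(max_epochs):
        hook.before_train_epoch(runner)
        for batch_idx, data_batch in enumerate(batches):
            hook.before_train_iter(runner, batch_idx, data_batch)
            outputs = dict(loss=0.0)  # placeholder for the training step
            hook.after_train_iter(runner, batch_idx, data_batch, outputs)
        hook.after_train_epoch(runner)
    hook.after_train(runner)
    hook.after_run(runner)


hook = RecordingHook()
simulate_training(hook, batches=['batch_a', 'batch_b'], max_epochs=1)
print(hook.calls)
```

A hook migrated from MMCV only needs to override the mount points it uses; the main change is that the per-iteration methods must accept the additional arguments listed in the table.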
Usage Comparison
In MMCV, you register hooks by calling the Runner’s register_training_hooks method after the Runner has been constructed. In MMEngine, you register hooks by passing them as arguments to the Runner’s initialization method.
MMCV

    import torch.optim as optim
    from mmcv.runner import EpochBasedRunner

    # ResNet18 and trainloader are assumed to be defined by the user.
    model = ResNet18()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Hook settings are plain dicts that are registered after construction.
    lr_config = dict(policy='step', step=[2, 3])
    optimizer_config = dict(grad_clip=None)
    checkpoint_config = dict(interval=5)
    log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
    custom_hooks = [dict(type='NumClassCheckHook')]

    runner = EpochBasedRunner(
        model=model,
        optimizer=optimizer,
        work_dir='./work_dir',
        max_epochs=3,
        # xxx: remaining arguments omitted
    )

    # Hooks are attached in a separate step.
    runner.register_training_hooks(
        lr_config=lr_config,
        optimizer_config=optimizer_config,
        checkpoint_config=checkpoint_config,
        log_config=log_config,
        custom_hooks_config=custom_hooks,
    )
    runner.run([trainloader], [('train', 1)])
MMEngine

    from mmengine.runner import Runner

    # ResNet18 is assumed to be defined by the user.
    model = ResNet18()

    # Backpropagation and gradient update are configured through OptimWrapper.
    optim_wrapper = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.001, momentum=0.9))
    param_scheduler = dict(type='MultiStepLR', milestones=[2, 3])

    default_hooks = dict(
        logger=dict(type='LoggerHook'),
        param_scheduler=dict(type='ParamSchedulerHook'),
        checkpoint=dict(type='CheckpointHook', interval=5),
    )
    custom_hooks = [dict(type='NumClassCheckHook')]

    # Hooks are passed directly to the Runner's initialization method.
    runner = Runner(
        model=model,
        work_dir='./work_dir',
        optim_wrapper=optim_wrapper,
        param_scheduler=param_scheduler,
        train_cfg=dict(by_epoch=True, max_epochs=3),
        default_hooks=default_hooks,
        custom_hooks=custom_hooks,
        # xxx: remaining arguments omitted
    )
    runner.train()
For more details of MMEngine hooks, please refer to Usage of Hooks.
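As a rough aid for the migration, the sketch below converts the MMCV-style `checkpoint_config` and `log_config` dicts from the example above into an MMEngine-style `default_hooks` dict. The helper `to_default_hooks` is hypothetical and not part of either library. It assumes that only the `interval` of `log_config` carries over to LoggerHook and drops the `hooks` list, so logger backends have to be migrated separately.

```python
def to_default_hooks(checkpoint_config=None, log_config=None):
    """Translate MMCV-style hook configs into an MMEngine default_hooks dict.

    Hypothetical helper for illustration only.
    """
    # The parameter scheduler hook is always included, as in the example.
    default_hooks = dict(param_scheduler=dict(type='ParamSchedulerHook'))

    if checkpoint_config is not None:
        default_hooks['checkpoint'] = dict(
            type='CheckpointHook', **checkpoint_config)

    if log_config is not None:
        # Only the logging interval is carried over in this sketch; the
        # `hooks` list of log_config is intentionally dropped.
        logger = dict(type='LoggerHook')
        if 'interval' in log_config:
            logger['interval'] = log_config['interval']
        default_hooks['logger'] = logger

    return default_hooks


default_hooks = to_default_hooks(
    checkpoint_config=dict(interval=5),
    log_config=dict(interval=100, hooks=[dict(type='TextLoggerHook')]))
print(default_hooks)
```

Keys that are left out of `default_hooks` fall back to the Runner's own defaults, so the sketch sets only the hooks that were configured explicitly in MMCV.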
Implementation Comparison
Take CheckpointHook as an example. Compared with CheckpointHook in MMCV, CheckpointHook in MMEngine additionally implements the after_val_epoch method, because the new CheckpointHook also supports saving the optimal weights. In MMCV, that function is provided by EvalHook.
MMCV

    class CheckpointHook(Hook):

        def before_run(self, runner):
            """Initialize out_dir and file_client"""

        def after_train_epoch(self, runner):
            """Synchronize buffers and save model weights, for tasks trained in epochs"""

        def after_train_iter(self, runner):
            """Synchronize buffers and save model weights, for tasks trained in iterations"""
MMEngine

    class CheckpointHook(Hook):

        def before_run(self, runner):
            """Initialize out_dir and file_client"""

        def after_train_epoch(self, runner):
            """Synchronize buffers and save model weights, for tasks trained in epochs"""

        def after_train_iter(self, runner, batch_idx, data_batch, outputs):
            """Synchronize buffers and save model weights, for tasks trained in iterations"""

        def after_val_epoch(self, runner, metrics):
            """Save optimal weights according to metrics"""
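To make the role of after_val_epoch more concrete, here is a minimal sketch of how a hook could use the `metrics` argument to decide when the optimal weights should be saved. `BestMetricTracker` is a hypothetical stand-in and not the CheckpointHook implementation: it records `(epoch, score)` pairs instead of writing checkpoint files, and the `key` and `rule` parameters as well as the `runner.epoch` attribute used here are assumptions made for the example.

```python
from types import SimpleNamespace


class BestMetricTracker:
    """Simplified stand-in for the optimal-weights logic of a hook."""

    def __init__(self, key, rule='greater'):
        if rule not in ('greater', 'less'):
            raise ValueError(f"rule must be 'greater' or 'less', got {rule!r}")
        self.key = key
        self.rule = rule
        self.best_score = None
        self.saved = []  # (epoch, score) pairs instead of checkpoint files

    def _is_better(self, score):
        # The first score always counts as the best one seen so far.
        if self.best_score is None:
            return True
        if self.rule == 'greater':
            return score > self.best_score
        return score < self.best_score

    def after_val_epoch(self, runner, metrics):
        """Record the epoch whenever the tracked metric improves."""
        score = metrics[self.key]
        if self._is_better(score):
            self.best_score = score
            self.saved.append((runner.epoch, score))


tracker = BestMetricTracker(key='accuracy', rule='greater')
for epoch, accuracy in enumerate([0.61, 0.72, 0.70]):
    fake_runner = SimpleNamespace(epoch=epoch)  # stand-in for a Runner
    tracker.after_val_epoch(fake_runner, dict(accuracy=accuracy))
print(tracker.saved)
```

Because the metrics are handed to the hook directly, the hook no longer has to run the evaluation itself, which matches the split described above: the loops evaluate and the hook saves.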