Shortcuts

迁移 MMCV 钩子到 MMEngine

简介

由于架构设计的更新和用户需求的不断增加,MMCV 的钩子(Hook)点位已经满足不了需求,因此在 MMEngine 中对钩子点位进行了重新设计以及对钩子的功能做了调整。在开始迁移前,阅读钩子的设计会很有帮助。

本文对比 MMCV v1.6.0MMEngine v0.5.0 的钩子在功能、点位、用法和实现上的差异。

功能差异

MMCV MMEngine
反向传播以及梯度更新 OptimizerHook 将反向传播以及梯度更新的操作抽象成 OptimWrapper 而不是钩子
GradientCumulativeOptimizerHook
学习率调整 LrUpdaterHook ParamSchdulerHook 以及 _ParamScheduler 的子类完成优化器超参的调整
动量调整 MomentumUpdaterHook
按指定间隔保存权重 CheckpointHook CheckpointHook 除了保存权重,还有保存最优权重的功能,而 EvalHook 的模型评估功能则交由 ValLoop 或 TestLoop 完成
模型评估并保存最优模型 EvalHook
打印日志 LoggerHook 及其子类实现打印日志、保存日志以及可视化功能 LoggerHook
可视化 NaiveVisualizationHook
添加运行时信息 RuntimeInfoHook
模型参数指数滑动平均 EMAHook EMAHook
确保分布式 Sampler 的 shuffle 生效 DistSamplerSeedHook DistSamplerSeedHook
同步模型的 buffer SyncBufferHook SyncBufferHook
PyTorch CUDA 缓存清理 EmptyCacheHook EmptyCacheHook
统计迭代耗时 IterTimerHook IterTimerHook
分析训练时间的瓶颈 ProfilerHook 暂未提供
提供注册方法给钩子点位的功能 ClosureHook 暂未提供

点位差异

MMCV MMEngine
全局位点 执行前 before_run before_run
执行后 after_run after_run
Checkpoint 相关 加载 checkpoint 后 after_load_checkpoint
保存 checkpoint 前 before_save_checkpoint
训练相关 训练前触发 before_train
训练后触发 after_train
每个 epoch 前 before_train_epoch before_train_epoch
每个 epoch 后 after_train_epoch after_train_epoch
每次迭代前 before_train_iter before_train_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_train_iter after_train_iter,新增 batch_idx、data_batch 和 outputs 参数
验证相关 验证前触发 before_val
验证后触发 after_val
每个 epoch 前 before_val_epoch before_val_epoch
每个 epoch 后 after_val_epoch after_val_epoch
每次迭代前 before_val_iter before_val_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_val_iter after_val_iter,新增 batch_idx、data_batch 和 outputs 参数
测试相关 测试前触发 before_test
测试后触发 after_test
每个 epoch 前 before_test_epoch
每个 epoch 后 after_test_epoch
每次迭代前 before_test_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_test_iter,新增 batch_idx、data_batch 和 outputs 参数

用法差异

在 MMCV 中,将钩子注册到执行器(Runner),需调用执行器的 register_training_hooks 方法往执行器注册钩子,而在 MMEngine 中,可以通过参数传递给执行器的初始化方法进行注册。

  • MMCV

model = ResNet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
lr_config = dict(policy='step', step=[2, 3])
optimizer_config = dict(grad_clip=None)
checkpoint_config = dict(interval=5)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
runner = EpochBasedRunner(
    model=model,
    optimizer=optimizer,
    work_dir='./work_dir',
    max_epochs=3,
    xxx,
)
runner.register_training_hooks(
    lr_config=lr_config,
    optimizer_config=optimizer_config,
    checkpoint_config=checkpoint_config,
    log_config=log_config,
    custom_hooks_config=custom_hooks,
)
runner.run([trainloader], [('train', 1)])
  • MMEngine

model=ResNet18()
optim_wrapper=dict(
    type='OptimizerWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9))
param_scheduler = dict(type='MultiStepLR', milestones=[2, 3]),
default_hooks = dict(
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=5),
)
custom_hooks = [dict(type='NumClassCheckHook')]
runner = Runner(
    model=model,
    work_dir='./work_dir',
    optim_wrapper=optim_wrapper,
    param_scheduler=param_scheduler,
    train_cfg=dict(by_epoch=True, max_epochs=3),
    default_hooks=default_hooks,
    custom_hooks=custom_hooks,
    xxx,
)
runner.train()

MMEngine 钩子的更多用法请参考钩子的用法

实现差异

CheckpointHook 为例,MMEngine 的 CheckpointHook 相比 MMCV 的 CheckpointHook(新增保存最优权重的功能,在 MMCV 中,保存最优权重的功能由 EvalHook 提供),因此,它需要实现 after_val_epoch 点位。

  • MMCV

class CheckpointHook(Hook):
    def before_run(self, runner):
        """初始化 out_dir 和 file_client 属性"""

    def after_train_epoch(self, runner):
        """同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""

    def after_train_iter(self, runner):
        """同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""
  • MMEngine

class CheckpointHook(Hook):
    def before_run(self, runner):
        """初始化 out_dir 和 file_client 属性"""

    def after_train_epoch(self, runner):
        """同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""

    def after_val_epoch(self, runner, metrics):
        """根据 metrics 保存最优权重"""
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.