Shortcuts

设置日志、权重保存、验证的频率

MMEngine 支持两种训练模式,基于轮次的 EpochBased 方式和基于迭代次数的 IterBased 方式,这两种方式在下游算法库均有使用,例如 MMDetection 默认使用 EpochBased 方式,MMSegmentation 默认使用 IterBased 方式。

在不同的训练模式下,MMEngine 间隔(interval)的语义会有区别,EpochBased 的间隔以 Epoch 为单位,IterBasedIteration 为单位。

设置训练和验证的间隔

设置 Runner 初始化参数 train_cfg 中的 val_interval 值即可定制训练和验证的间隔。

  • EpochBased

EpochBased 模式下,val_interval 的默认值为 1,表示训练一个 Epoch,验证一次。

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
  • IterBased

IterBased 模式下,val_interval 的默认值为 1000,表示训练迭代 1000 次,验证一次。

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=False, max_iters=10000, val_interval=2000),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()

设置保存权重的间隔

设置 CheckpointHookinterval 值即可定制保存权重的间隔。

  • EpochBased

EpochBased 模式下,interval 的默认值为 1,表示训练一个 Epoch,保存一次权重。

# 将 interval 设置为 2,表示每 2 个 epoch 保存一次权重
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2))
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    default_hooks=default_hooks,
)
runner.train()
  • IterBased

默认以 Epoch 为单位保存权重,如果希望以 Iteration 为单位,需设置 by_epoch=False

# 设置 by_epoch=False 以及 interval = 500,表示每 500 个 iteration 保存一次权重
default_hooks = dict(checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=500))
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=False, max_iters=10000, val_interval=1000),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    default_hooks=default_hooks,
)
runner.train()

CheckpointHook 的更多用法可查看 CheckpointHook 教程

设置打印日志的间隔

默认情况下,每迭代 10 次往终端打印 1 次日志,可以通过设置 LoggerHookinterval 参数进行设置。

# 设置每 20 次打印一次
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    default_hooks=default_hooks,
)
runner.train()

LoggerHook 的更多用法可查看 LoggerHook 教程

Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.