可视化训练日志¶
MMEngine 集成了 TensorBoard、Weights & Biases (WandB)、MLflow 和 ClearML 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
下面基于15 分钟上手 MMENGINE中的例子介绍如何一行配置实验管理工具。
TensorBoard¶
设置 Runner 初始化参数中的 visualizer,并将 vis_backends 设置为 TensorboardVisBackend。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]),
)
runner.train()
WandB¶
使用 WandB 前需安装依赖库 wandb 并登录至 wandb。
pip install wandb
wandb login
设置 Runner 初始化参数中的 visualizer,并将 vis_backends 设置为 WandbVisBackend。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='WandbVisBackend')]),
)
runner.train()

可以点击 WandbVisBackend API 查看 WandbVisBackend 可配置的参数。例如 init_kwargs,该参数会传给 wandb.init 方法。
runner = Runner(
...
visualizer=dict(
type='Visualizer',
vis_backends=[
dict(
type='WandbVisBackend',
init_kwargs=dict(project='toy-example')
),
],
),
...
)
runner.train()
MLflow (WIP)¶
ClearML¶
使用 ClearML 前需安装依赖库 clearml 并参考 Connect ClearML SDK to the Server 进行配置。
pip install clearml
clearml-init
设置 Runner 初始化参数中的 visualizer,并将 vis_backends 设置为 ClearMLVisBackend。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='ClearMLVisBackend')]),
)
runner.train()