可视化训练日志¶
MMEngine 集成了 TensorBoard、Weights & Biases (WandB)、MLflow 和 ClearML 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
下面基于15 分钟上手 MMENGINE中的例子介绍如何一行配置实验管理工具。
TensorBoard¶
设置 Runner
初始化参数中的 visualizer
,并将 vis_backends
设置为 TensorboardVisBackend
。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]),
)
runner.train()
WandB¶
使用 WandB 前需安装依赖库 wandb
并登录至 wandb。
pip install wandb
wandb login
设置 Runner
初始化参数中的 visualizer
,并将 vis_backends
设置为 WandbVisBackend
。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='WandbVisBackend')]),
)
runner.train()
可以点击 WandbVisBackend API 查看 WandbVisBackend
可配置的参数。例如 init_kwargs
,该参数会传给 wandb.init 方法。
runner = Runner(
...
visualizer=dict(
type='Visualizer',
vis_backends=[
dict(
type='WandbVisBackend',
init_kwargs=dict(project='toy-example')
),
],
),
...
)
runner.train()
MLflow (WIP)¶
ClearML¶
使用 ClearML 前需安装依赖库 clearml
并参考 Connect ClearML SDK to the Server 进行配置。
pip install clearml
clearml-init
设置 Runner
初始化参数中的 visualizer
,并将 vis_backends
设置为 ClearMLVisBackend
。
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='ClearMLVisBackend')]),
)
runner.train()