Shortcuts

欢迎来到 MMEngine 的中文文档!

您可以在页面左下角切换中英文文档。

介绍

MMEngine 是一个用于深度学习模型训练的基础库,基于 PyTorch,支持在 Linux、Windows、macOS 上运行。它具有如下三个亮点:

  1. 通用:MMEngine 实现了一个高级的通用训练器,它能够:

    • 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 imagenet(pytorch example 400 行)

    • 轻松兼容流行的算法库如 TIMM、TorchVision 和 Detectron2 中的模型

  2. 统一:MMEngine 设计了一个接口统一的开放架构,使得

    • 用户可以仅依赖一份代码实现所有任务的轻量化,例如 MMRazor 1.x 相比 MMRazor 0.x 优化了 40% 的代码量

    • 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。

  3. 灵活:MMEngine 实现了“乐高”式的训练流程,支持了

    • 根据迭代数、 loss 和评测结果等动态调整的训练流程、优化策略和数据增强策略,例如早停(early stopping)机制等

    • 任意形式的模型权重平均,如 Exponential Momentum Average (EMA) 和 Stochastic Weight Averaging (SWA)

    • 训练过程中针对任意数据和任意节点的灵活可视化和日志控制

    • 对神经网络模型中各个层的优化配置进行细粒度调整

    • 混合精度训练的灵活控制

架构

openmmlab-2 0-arch

上图展示了 MMEngine 在 OpenMMLab 2.0 中的层次。MMEngine 实现了 OpenMMLab 算法库的新一代训练架构,为 OpenMMLab 中的 30 多个算法库提供了统一的执行基座。其核心组件包含训练引擎、评测引擎和模块管理等。

模块介绍

模块关系

MMEngine 将训练过程中涉及的组件和它们的关系进行了抽象,如上图所示。不同算法库中的同类型组件具有相同的接口定义。

核心模块与相关组件

训练引擎的核心模块是执行器(Runner)。 执行器负责执行训练、测试和推理任务并管理这些过程中所需要的各个组件。在训练、测试、推理任务执行过程中的特定位置,执行器设置了钩子(Hook) 来允许用户拓展、插入和执行自定义逻辑。执行器主要调用如下组件来完成训练和推理过程中的循环:

  • 数据集(Dataset):负责在训练、测试、推理任务中构建数据集,并将数据送给模型。实际使用过程中会被数据加载器(DataLoader)封装一层,数据加载器会启动多个子进程来加载数据。

  • 模型(Model):在训练过程中接受数据并输出 loss;在测试、推理任务中接受数据,并进行预测。分布式训练等情况下会被模型的封装器(Model Wrapper,如MMDistributedDataParallel)封装一层。

  • 优化器封装(Optimizer):优化器封装负责在训练过程中执行反向传播优化模型,并且以统一的接口支持了混合精度训练和梯度累加。

  • 参数调度器(Parameter Scheduler):训练过程中,对学习率、动量等优化器超参数进行动态调整。

在训练间隙或者测试阶段,评测指标与评测器(Metrics & Evaluator)会负责对模型性能进行评测。其中评测器负责基于数据集对模型的预测进行评估。评测器内还有一层抽象是评测指标,负责计算具体的一个或多个评测指标(如召回率、正确率等)。

为了统一接口,OpenMMLab 2.0 中各个算法库的评测器,模型和数据之间交流的接口都使用了数据元素(Data Element)来进行封装。

在训练、推理执行过程中,上述各个组件都可以调用日志管理模块和可视化器进行结构化和非结构化日志的存储与展示。日志管理(Logging Modules):负责管理执行器运行过程中产生的各种日志信息。其中消息枢纽 (MessageHub)负责实现组件与组件、执行器与执行器之间的数据共享,日志处理器(Log Processor)负责对日志信息进行处理,处理后的日志会分别发送给执行器的日志器(Logger)和可视化器(Visualizer)进行日志的管理与展示。可视化器(Visualizer):可视化器负责对模型的特征图、预测结果和训练过程中产生的结构化日志进行可视化,支持 Tensorboard 和 WanDB 等多种可视化后端。

公共基础模块

MMEngine 中还实现了各种算法模型执行过程中需要用到的公共基础模块,包括

  • 配置类(Config):在 OpenMMLab 算法库中,用户可以通过编写 config 来配置训练、测试过程以及相关的组件。

  • 注册器(Registry):负责管理算法库中具有相同功能的模块。MMEngine 根据对算法库模块的抽象,定义了一套根注册器,算法库中的注册器可以继承自这套根注册器,实现模块的跨算法库调用。

  • 文件读写(File I/O):为各个模块的文件读写提供了统一的接口,以统一的形式支持了多种文件读写后端和多种文件格式,并具备扩展性。

  • 分布式通信原语(Distributed Communication Primitives):负责在程序分布式运行过程中不同进程间的通信。这套接口屏蔽了分布式和非分布式环境的区别,同时也自动处理了数据的设备和通信后端。

  • 其他工具(Utils):还有一些工具性的模块,如 ManagerMixin,它实现了一种全局变量的创建和获取方式,执行器内很多全局可见对象的基类就是 ManagerMixin。

用户可以进一步阅读教程来了解这些模块的高级用法,也可以参考设计文档 了解它们的设计思路与细节。

安装

环境依赖

  • Python 3.6+

  • PyTorch 1.6+

  • CUDA 9.2+

  • GCC 5.4+

准备环境

  1. 使用 conda 新建虚拟环境,并进入该虚拟环境;

    conda create -n open-mmlab python=3.7 -y
    conda activate open-mmlab
    
  2. 安装 PyTorch

    在安装 MMEngine 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。使用以下命令验证 PyTorch 是否安装

    python -c 'import torch;print(torch.__version__)'
    

安装 MMEngine

使用 mim 安装

mim 是 OpenMMLab 项目的包管理工具,使用它可以很方便地安装 OpenMMLab 项目。

pip install -U openmim
mim install mmengine

使用 pip 安装

pip install mmengine

使用 docker 镜像

  1. 构建镜像

    docker build -t mmengine https://github.com/open-mmlab/mmengine.git#main:docker/release
    

    更多构建方式请参考 mmengine/docker

  2. 运行镜像

    docker run --gpus all --shm-size=8g -it mmengine
    

源码安装

# 如果克隆代码仓库的速度过慢,可以从 https://gitee.com/open-mmlab/mmengine.git 克隆
git clone https://github.com/open-mmlab/mmengine.git
cd mmengine
pip install -e . -v

验证安装

为了验证是否正确安装了 MMEngine 和所需的环境,我们可以运行以下命令

python -c 'import mmengine;print(mmengine.__version__)'

15 分钟上手 MMEngine

以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、 可配置的训练和验证流程,整个流程包含如下步骤:

  1. 构建模型

  2. 构建数据集和数据加载器

  3. 构建评测指标

  4. 构建执行器并执行任务

构建模型

首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 “loss”,并返回一个包含 “loss” 字段的字典;对于验证,我们需要 mode 接受字符串 “predict”,并返回同时包含预测信息和真实信息的结果。

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

构建数据集和数据加载器

其次,我们需要构建训练和验证所需要的数据集 (Dataset)数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

构建评测指标

为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 processcompute_metrics 方法。其中 process 方法接受数据集的输出和模型 mode="predict" 时的输出,此时的数据为一个批次的数据,对这一批次的数据进行处理后,保存信息至 self.results 属性。 而 compute_metrics 接受 results 参数,这一参数的输入为 process 中保存的所有信息 (如果是分布式环境,results 中为已收集的,包括各个进程 process 保存信息的结果),利用这些信息计算并返回保存有评测指标结果的字典。

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的中间结果保存至 `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回保存有评测指标结果的字典,其中键为指标名称
        return dict(accuracy=100 * total_correct / total_size)

构建执行器并执行任务

最后,我们利用构建好的模型数据加载器评测指标构建一个执行器 (Runner),同时在其中配置 优化器工作路径训练与验证配置等选项,即可通过调用 train() 接口启动训练:

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 用以训练和验证的模型,需要满足特定的接口需求
    model=MMResNet50(),
    # 工作路径,用以保存训练日志、权重文件信息
    work_dir='./work_dir',
    # 训练数据加载器,需要满足 PyTorch 数据加载器协议
    train_dataloader=train_dataloader,
    # 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 训练配置,用于指定训练周期、验证间隔等信息
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # 验证数据加载器,需要满足 PyTorch 数据加载器协议
    val_dataloader=val_dataloader,
    # 验证配置,用于指定验证所需要的额外参数
    val_cfg=dict(),
    # 用于验证的评测器,这里使用默认评测器,并评测指标
    val_evaluator=dict(type=Accuracy),
)

runner.train()

最后,让我们把以上部分汇总成为一个完整的,利用 MMEngine 执行器进行训练和验证的脚本:

在 Colab 中打开

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()

输出的训练日志如下:

2022/08/22 15:51:53 - mmengine - INFO -
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]
    CUDA available: True
    numpy_random_seed: 1513128759
    GPU 0: NVIDIA GeForce GTX 1660 SUPER
    CUDA_HOME: /usr/local/cuda
...

2022/08/22 15:51:54 - mmengine - INFO - Checkpoints will be saved to /home/mazerun/work_dir by HardDiskBackend.
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][10/1563]  lr: 1.0000e-03  eta: 0:18:23  time: 0.1414  data_time: 0.0077  memory: 392  loss: 5.3465
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][20/1563]  lr: 1.0000e-03  eta: 0:11:29  time: 0.0354  data_time: 0.0077  memory: 392  loss: 2.7734
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][30/1563]  lr: 1.0000e-03  eta: 0:09:10  time: 0.0352  data_time: 0.0076  memory: 392  loss: 2.7789
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][40/1563]  lr: 1.0000e-03  eta: 0:08:00  time: 0.0353  data_time: 0.0073  memory: 392  loss: 2.5725
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][50/1563]  lr: 1.0000e-03  eta: 0:07:17  time: 0.0347  data_time: 0.0073  memory: 392  loss: 2.7382
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][60/1563]  lr: 1.0000e-03  eta: 0:06:49  time: 0.0347  data_time: 0.0072  memory: 392  loss: 2.5956
2022/08/22 15:51:58 - mmengine - INFO - Epoch(train) [1][70/1563]  lr: 1.0000e-03  eta: 0:06:28  time: 0.0348  data_time: 0.0072  memory: 392  loss: 2.7351
...
2022/08/22 15:52:50 - mmengine - INFO - Saving checkpoint at 1 epochs
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][10/313]    eta: 0:00:03  time: 0.0122  data_time: 0.0047  memory: 392
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][20/313]    eta: 0:00:03  time: 0.0122  data_time: 0.0047  memory: 308
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][30/313]    eta: 0:00:03  time: 0.0123  data_time: 0.0047  memory: 308
...
2022/08/22 15:52:54 - mmengine - INFO - Epoch(val) [1][313/313]  accuracy: 35.7000

基于 PyTorch 和基于 MMEngine 的训练流程对比如下:

output

除了以上基础组件,你还可以利用执行器轻松地组合配置各种训练技巧,如开启混合精度训练和梯度累积(见 优化器封装(OptimWrapper))、配置学习率衰减曲线(见 评测指标与评测器(Metrics & Evaluator))等。

恢复训练

恢复训练是指从之前某次训练保存下来的状态开始继续训练,这里的状态包括模型的权重、优化器和优化器参数调整策略的状态。

自动恢复训练

用户可以设置 Runnerresume 参数开启自动恢复训练的功能。在启动训练时,设置 Runnerresume 等于 TrueRunner 会从 work_dir 中加载最新的 checkpoint。如果 work_dir 中有最新的 checkpoint(例如该训练在上一次训练时被中断),则会从该 checkpoint 恢复训练,否则(例如上一次训练还没来得及保存 checkpoint 或者启动了新的训练任务)会重新开始训练。下面是一个开启自动恢复训练的示例

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
    resume=True,
)
runner.train()

指定 checkpoint 路径

如果希望指定恢复训练的路径,除了设置 resume=True,还需要设置 load_from 参数。需要注意的是,如果只设置了 load_from 而没有设置 resume=True,则只会加载 checkpoint 中的权重并重新开始训练,而不是接着之前的状态继续训练。

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
    load_from='./work_dir/epoch_2.pth',
    resume=True,
)
runner.train()

加速训练

分布式训练

MMEngine 支持 CPU、单卡、单机多卡以及多机多卡的训练。当环境中有多张显卡时,我们可以使用以下命令开启单机多卡或者多机多卡的方式从而缩短模型的训练时间。

  • 单机多卡

    假设当前机器有 8 张显卡,可以使用以下命令开启多卡训练

    python -m torch.distributed.launch --nproc_per_node=8 examples/train.py --launcher pytorch
    

    如果需要指定显卡的编号,可以设置 CUDA_VISIBLE_DEVICES 环境变量,例如使用第 0 和第 3 张卡

    CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 examples/train.py --launcher pytorch
    
  • 多机多卡

    假设有 2 台机器,每台机器有 8 张卡。

    第一台机器运行以下命令

    python -m torch.distributed.launch \
        --nnodes 8 \
        --node_rank 0 \
        --master_addr 127.0.0.1 \
        --master_port 29500 \
        --nproc_per_node=8 \
        examples/train.py --launcher pytorch
    

    第 2 台机器运行以下命令

    python -m torch.distributed.launch \
        --nnodes 8 \
        --node_rank 1 \
        --master_addr 127.0.0.1 \
        --master_port 29500 \
        --nproc_per_node=8 \
        examples/train.py --launcher pytorch
    

    如果在 slurm 集群运行 MMEngine,只需运行以下命令即可开启 2 机 16 卡的训练

    srun -p mm_dev \
        --job-name=test \
        --gres=gpu:8 \
        --ntasks=16 \
        --ntasks-per-node=8 \
        --cpus-per-task=5 \
        --kill-on-bad-exit=1 \
        python examples/train.py --launcher="slurm"
    

混合精度训练

Nvidia 在 Volta 和 Turing 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。开启自动混合精度训练后,部分算子的操作精度是 FP16,其余算子的操作精度是 FP32。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。

PyTorch 从 1.6 开始官方支持 amp。如果你对自动混合精度的实现感兴趣,可以阅读 torch.cuda.amp: 自动混合精度详解

MMEngine 提供自动混合精度的封装 AmpOptimWrapper ,只需在 optim_wrapper 设置 type='AmpOptimWrapper' 即可开启自动混合精度训练,无需对代码做其他修改。

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()

节省显存

在深度学习训练推理过程中显存容量至关重要,其决定了模型是否能成功运行。常见的节省显存办法包括:

  • 梯度累加

    梯度累加是指在每计算一个批次的梯度后,不进行清零而是进行梯度累加,当累加到一定的次数之后,再更新网络参数和梯度清零。 通过这种参数延迟更新的手段,实现与采用大 batch 尺寸相近的效果,达到节省显存的目的。但是需要注意如果模型中包含 batch normalization 层,使用梯度累加会对性能有一定影响。

  • 梯度检查点

    梯度检查点是一种以时间换空间的方法,通过减少保存的激活值来压缩模型占用空间,但是在计算梯度时必须重新计算没有存储的激活值。在 torch.utils.checkpoint 包中已经实现了对应功能。简要实现过程是:在前向阶段传递到 checkpoint 中的 forward 函数会以 torch.no_grad 模式运行,并且仅仅保存 forward 函数的输入和输出,然后在反向阶段重新计算中间层的激活值 (intermediate activations)。

  • 大模型训练技术

    最近的研究表明大型模型训练将有利于提高模型质量,但是训练如此大的模型需要巨大的资源,单卡显存已经越来越难以满足存放整个模型,因此诞生了大模型训练技术,典型的如 DeepSpeed ZeRO 和 FairScale 的完全分片数据并行(Fully Sharded Data Parallel, FSDP)技术,其允许在数据并行进程之间分片模型的参数、梯度和优化器状态,并同时仍然保持数据并行的简单性。

MMEngine 目前支持梯度累加大模型训练 FSDP 技术 。下面说明其用法。

梯度累加

配置写法如下所示:

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9),
    # 累加 4 次参数更新一次
    accumulative_counts=4)

配合 Runner 使用的完整例子如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01),
                       accumulative_counts=4)
)
runner.train()

大模型训练

PyTorch 1.11 中已经原生支持了 FSDP 技术。配置写法如下所示:

# 位于 cfg 配置文件中
model_wrapper_cfg=dict(type='MMFullyShardedDataParallel', cpu_offload=True)

配合 Runner 使用的完整例子如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    cfg=dict(model_wrapper_cfg=dict(type='MMFullyShardedDataParallel', cpu_offload=True))
)
runner.train()

注意必须在分布式训练环境中 FSDP 才能生效。

训练生成对抗网络

生成对抗网络(Generative Adversarial Network, GAN)可以用来生成图像视频等数据。这篇教程将带你一步步用 MMEngine 训练 GAN !

我们可以通过以下步骤来训练一个生成对抗网络。

构建数据加载器

构建数据集

接下来, 我们为 MNIST 数据集构建一个数据集类 MNISTDataset, 继承自数据集基类 BaseDataset, 并且重载数据集基类的 load_data_list 函数, 保证返回值为 list[dict],其中每个 dict 代表一个数据样本。更多关于 MMEngine 中数据集的用法,可以参考数据集教程

import numpy as np
from mmcv.transforms import to_tensor
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from mmengine.dataset import BaseDataset


class MNISTDataset(BaseDataset):

    def __init__(self, data_root, pipeline, test_mode=False):
        # 下载 MNIST 数据集
        if test_mode:
            mnist_full = MNIST(data_root, train=True, download=True)
            self.mnist_dataset, _ = random_split(mnist_full, [55000, 5000])
        else:
            self.mnist_dataset = MNIST(data_root, train=False, download=True)

        super().__init__(
            data_root=data_root, pipeline=pipeline, test_mode=test_mode)

    @staticmethod
    def totensor(img):
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        img = np.ascontiguousarray(img.transpose(2, 0, 1))
        return to_tensor(img)

    def load_data_list(self):
        return [
            dict(inputs=self.totensor(np.array(x[0]))) for x in self.mnist_dataset
        ]


dataset = MNISTDataset("./data", [])

使用 Runner 中的函数 build_dataloader 来构建数据加载器。

import os
import torch
from mmengine.runner import Runner

NUM_WORKERS = int(os.cpu_count() / 2)
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

train_dataloader = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dataset)
train_dataloader = Runner.build_dataloader(train_dataloader)

构建生成器网络和判别器网络

下面的代码构建并实例化了一个生成器(Generator)和一个判别器(Discriminator)。

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_size, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.noise_size = noise_size

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(noise_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
generator = Generator(100, (1, 28, 28))
discriminator = Discriminator((1, 28, 28))

构建一个生成对抗网络模型

在使用 MMEngine 时,我们用 ImgDataPreprocessor 来对数据进行归一化和颜色通道的转换。

from mmengine.model import ImgDataPreprocessor

data_preprocessor = ImgDataPreprocessor(mean=([127.5]), std=([127.5]))

下面的代码实现了基础 GAN 的算法。使用 MMEngine 实现算法类,需要继承 BaseModel 基类,在 train_step 中实现训练过程。GAN 需要交替训练生成器和判别器,分别由 train_discriminator 和 train_generator 实现,并实现 disc_loss 和 gen_loss 计算判别器损失函数和生成器损失函数。 关于 BaseModel 的更多信息,请参考模型教程.

import torch.nn.functional as F
from mmengine.model import BaseModel

class GAN(BaseModel):

    def __init__(self, generator, discriminator, noise_size,
                 data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)
        assert generator.noise_size == noise_size
        self.generator = generator
        self.discriminator = discriminator
        self.noise_size = noise_size

    def train_step(self, data, optim_wrapper):
        # 获取数据和数据预处理
        inputs_dict = self.data_preprocessor(data, True)
        # 训练判别器
        disc_optimizer_wrapper = optim_wrapper['discriminator']
        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars = self.train_discriminator(inputs_dict,
                                                disc_optimizer_wrapper)

        # 训练生成器
        set_requires_grad(self.discriminator, False)
        gen_optimizer_wrapper = optim_wrapper['generator']
        with gen_optimizer_wrapper.optim_context(self.generator):
            log_vars_gen = self.train_generator(inputs_dict,
                                                gen_optimizer_wrapper)

        set_requires_grad(self.discriminator, True)
        log_vars.update(log_vars_gen)

        return log_vars

    def forward(self, batch_inputs, data_samples=None, mode=None):
        return self.generator(batch_inputs)

    def disc_loss(self, disc_pred_fake, disc_pred_real):
        losses_dict = dict()
        losses_dict['loss_disc_fake'] = F.binary_cross_entropy(
            disc_pred_fake, 0. * torch.ones_like(disc_pred_fake))
        losses_dict['loss_disc_real'] = F.binary_cross_entropy(
            disc_pred_real, 1. * torch.ones_like(disc_pred_real))

        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def gen_loss(self, disc_pred_fake):
        losses_dict = dict()
        losses_dict['loss_gen'] = F.binary_cross_entropy(
            disc_pred_fake, 1. * torch.ones_like(disc_pred_fake))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def train_discriminator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(
            (real_imgs.shape[0], self.noise_size)).type_as(real_imgs)
        with torch.no_grad():
            fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars

    def train_generator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(real_imgs.shape[0], self.noise_size).type_as(real_imgs)

        fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

其中一个函数 set_requires_grad 用来锁定训练生成器时判别器的权重。

def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

model = GAN(generator, discriminator, 100, data_preprocessor)

构建优化器

MMEngine 使用 OptimWrapper 来封装优化器,对于多个优化器的情况,使用 OptimWrapperDict 对 OptimWrapper 再进行一次封装。 关于优化器的更多信息,请参考优化器教程.

from mmengine.optim import OptimWrapper, OptimWrapperDict

opt_g = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_g_wrapper = OptimWrapper(opt_g)

opt_d = torch.optim.Adam(
    discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_d_wrapper = OptimWrapper(opt_d)

opt_wrapper_dict = OptimWrapperDict(
    generator=opt_g_wrapper, discriminator=opt_d_wrapper)

使用执行器进行训练

下面的代码演示了如何使用 Runner 进行模型训练。关于 Runner 的更多信息,请参考执行器教程

train_cfg = dict(by_epoch=True, max_epochs=220)
runner = Runner(
    model,
    work_dir='runs/gan/',
    train_dataloader=train_dataloader,
    train_cfg=train_cfg,
    optim_wrapper=opt_wrapper_dict)
runner.train()

到这里,我们就完成了一个 GAN 的训练,通过下面的代码可以查看刚才训练的 GAN 生成的结果。

z = torch.randn(64, 100).cuda()
img = model(z)

from torchvision.utils import save_image
save_image(img, "result.png", normalize=True)

GAN生成图像

如果你想了解更多如何使用 MMEngine 实现 GAN 和生成模型,我们强烈建议你使用同样基于 MMEngine 开发的生成框架 MMGen

执行器(Runner)

欢迎来到 MMEngine 用户界面的核心——执行器!

作为 MMEngine 中的“集大成者”,执行器涵盖了整个框架的方方面面,肩负着串联所有组件的重要责任;因此,其中的代码和实现逻辑需要兼顾各种情景,相对庞大复杂。但是不用担心!在这篇教程中,我们将隐去繁杂的细节,速览执行器常用的接口、功能、示例,为你呈现一个清晰易懂的用户界面。阅读完本篇教程,你将会:

  • 掌握执行器的常见参数与使用方式

  • 了解执行器的最佳实践——配置文件的写法

  • 了解执行器基本数据流与简要执行逻辑

  • 亲身感受使用执行器的优越性(也许)

执行器示例

使用执行器构建属于你自己的训练流程,通常有两种开始方式:

  • 参考 API 文档,逐项确认和配置参数

  • 在已有配置(如 15 分钟上手MMDet 等下游算法库)的基础上,进行定制化修改

两种方式各有利弊。使用前者,初学者很容易迷失在茫茫多的参数项中不知所措;而使用后者,一份过度精简或过度详细的参考配置都不利于初学者快速找到所需内容。

解决上述问题的关键在于,把执行器作为备忘录:掌握其中最常用的部分,并在有特殊需求时聚焦感兴趣的部分,其余部分使用缺省值。下面我们将通过一个适合初学者参考的例子,说明其中最常用的参数,并为一些不常用参数给出进阶指引。

面向初学者的示例代码

提示

我们希望你在本教程中更多地关注整体结构,而非具体模块的实现。这种“自顶向下”的思考方式是我们所倡导的。别担心,之后你将有充足的机会和指引,聚焦于自己想要改进的模块

运行下面的示例前,请先执行本段代码准备模型、数据集与评测指标;但是在本教程中,暂时无需关注它们的具体实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from mmengine.model import BaseModel
from mmengine.evaluator import BaseMetric
from mmengine.registry import MODELS, DATASETS, METRICS


@MODELS.register_module()
class MyAwesomeModel(BaseModel):
    def __init__(self, layers=4, activation='relu') -> None:
        super().__init__()
        if activation == 'relu':
            act_type = nn.ReLU
        elif activation == 'silu':
            act_type = nn.SiLU
        elif activation == 'none':
            act_type = nn.Identity
        else:
            raise NotImplementedError
        sequence = [nn.Linear(2, 64), act_type()]
        for _ in range(layers-1):
            sequence.extend([nn.Linear(64, 64), act_type()])
        self.mlp = nn.Sequential(*sequence)
        self.classifier = nn.Linear(64, 2)

    def forward(self, data, labels, mode):
        x = self.mlp(data)
        x = self.classifier(x)
        if mode == 'tensor':
            return x
        elif mode == 'predict':
            return F.softmax(x, dim=1), labels
        elif mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}


@DATASETS.register_module()
class MyDataset(Dataset):
    def __init__(self, is_train, size):
        self.is_train = is_train
        if self.is_train:
            torch.manual_seed(0)
            self.labels = torch.randint(0, 2, (size,))
        else:
            torch.manual_seed(3407)
            self.labels = torch.randint(0, 2, (size,))
        r = 3 * (self.labels+1) + torch.randn(self.labels.shape)
        theta = torch.rand(self.labels.shape) * 2 * torch.pi
        self.data = torch.vstack([r*torch.cos(theta), r*torch.sin(theta)]).T

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


@METRICS.register_module()
class Accuracy(BaseMetric):
    def __init__(self):
        super().__init__()

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(r['correct'] for r in results)
        total_size = sum(r['batch_size'] for r in results)
        return dict(accuracy=100*total_correct/total_size)
点击展开一段长长的示例代码。做好准备
from torch.utils.data import DataLoader, default_collate
from torch.optim import Adam
from mmengine.runner import Runner


runner = Runner(
    # 你的模型
    model=MyAwesomeModel(
        layers=2,
        activation='relu'),
    # 模型检查点、日志等都将存储在工作路径中
    work_dir='exp/my_awesome_model',

    # 训练所用数据
    train_dataloader=DataLoader(
        dataset=MyDataset(
            is_train=True,
            size=10000),
        shuffle=True,
        collate_fn=default_collate,
        batch_size=64,
        pin_memory=True,
        num_workers=2),
    # 训练相关配置
    train_cfg=dict(
        by_epoch=True,   # 根据 epoch 计数而非 iteration
        max_epochs=10,
        val_begin=2,     # 从第 2 个 epoch 开始验证
        val_interval=1), # 每隔 1 个 epoch 进行一次验证

    # 优化器封装,MMEngine 中的新概念,提供更丰富的优化选择。
    # 通常使用默认即可,可缺省。有特殊需求可查阅文档更换,如
    # 'AmpOptimWrapper' 开启混合精度训练
    optim_wrapper=dict(
        optimizer=dict(
            type=Adam,
            lr=0.001)),
    # 参数调度器,用于在训练中调整学习率/动量等参数
    param_scheduler=dict(
        type='MultiStepLR',
        by_epoch=True,
        milestones=[4, 8],
        gamma=0.1),

    # 验证所用数据
    val_dataloader=DataLoader(
        dataset=MyDataset(
            is_train=False,
            size=1000),
        shuffle=False,
        collate_fn=default_collate,
        batch_size=1000,
        pin_memory=True,
        num_workers=2),
    # 验证相关配置,通常为空即可
    val_cfg=dict(),
    # 验证指标与验证器封装,可自由实现与配置
    val_evaluator=dict(type=Accuracy),

    # 以下为其他进阶配置,无特殊需要时尽量缺省
    # 钩子属于进阶用法,如无特殊需要,尽量缺省
    default_hooks=dict(
        # 最常用的默认钩子,可修改保存 checkpoint 的间隔
        checkpoint=dict(type='CheckpointHook', interval=1)),

    # `luancher` 与 `env_cfg` 共同构成分布式训练环境配置
    launcher='none',
    env_cfg=dict(
        cudnn_benchmark=False,   # 是否使用 cudnn_benchmark
        backend='nccl',   # 分布式通信后端
        mp_cfg=dict(mp_start_method='fork')),  # 多进程设置
    log_level='INFO',

    # 加载权重的路径 (None 表示不加载)
    load_from=None
    # 从加载的权重文件中恢复训练
    resume=False
)

# 开始训练你的模型吧
runner.train()

示例代码讲解

真是一段长长的代码!但是如果你通读了上述样例,即使不了解实现细节,你也一定大体理解了这个训练流程,并感叹于执行器代码的紧凑与可读性(也许)。这也是 MMEngine 所期望的:结构化、模块化、标准化的训练流程,使得复现更加可靠、对比更加清晰。

上述例子可能会让你产生如下问题:

参数项实在是太多了!

不用担心,正如我们前面所说,把执行器作为备忘录。执行器涵盖了方方面面,防止你漏掉重要内容,但是这并不意味着你需要配置所有参数。如15分钟上手中的极简例子(甚至,舍去 val_evaluator val_dataloaderval_cfg)也可以正常运行。所有的参数由你的需求驱动,不关注的内容往往缺省值也可以工作得很好。

为什么有些传入参数是 dict?

是的,这与 MMEngine 的风格相关。在 MMEngine 中我们提供了两种不同风格的执行器构建方式:a)基于手动构建的,以及 b)基于注册机制的。如果你感到迷惑,下面的例子将给出一个对比:

from mmengine.model import BaseModel
from mmengine.runner import Runner
from mmengine.registry import MODELS # 模型根注册器,你的自定义模型需要注册到这个根注册器中

@MODELS.register_module() # 用于注册的装饰器
class MyAwesomeModel(BaseModel): # 你的自定义模型
    def __init__(self, layers=18, activation='silu'):
        ...

# 基于注册机制的例子
runner = Runner(
    model=dict(
        type='MyAwesomeModel',
        layers=50,
        activation='relu'),
    ...
)

# 基于手动构建的例子
model = MyAwesomeModel(layers=18, activation='relu')
runner = Runner(
    model=model,
    ...
)

类似上述例子,执行器中的参数大多同时支持两种输入类型。以上两种写法基本是等价的,区别在于:前者以 dict 作为输入时,该模块会在需要时在执行器内部被构建;而后者是构建完成后传递给执行器。如果你对于注册机制并不了解,下面的示意图展示了它的核心思想:注册器维护着模块的构建方式和它的名字之间的映射。如果你在使用中发现问题,或者想要进一步了解完整用法,我们推荐阅读注册器(Registry)文档。

Runner Registry 示意图

看到这你可能仍然很疑惑,为什么我要传入字典让 Runner 来构建实例,这样又有什么好处?如果你有产生这样的疑问,那我们就会很自豪的回答:“当然!(没有好处)”。事实上,基于注册机制的构建方式只有在结合配置文件时才会发挥它的最大优势。这里直接传入字典的写法也并非使用执行器的最佳实践。在这里,我们希望你能够通过这个例子读懂并习惯这种写法,方便理解我们马上将要讲到的执行器最佳实践——配置文件。敬请期待!

如果你作为初学者无法立刻理解,使用手动构建的方式依然不失为一种好选择,甚至在小规模使用、试错和调试时是一种更加推荐的方式,因为对于 IDE 更加友好。但我们也希望你能够读懂并习惯基于注册机制的写法,并且在后续教程中不会因此而产生不必要的混淆和疑惑。

我应该去哪里找到 xxx 参数的可能配置选项?

你可以在对应模块的教程中找到丰富的说明和示例,你也可以在 API 文档 中找到 Runner 的所有参数。如果上述两种方式都无法解决你的疑问,你随时可以在我们的讨论区中发起话题,帮助我们更好地改进文档。

我来自 MMDet/MMCls...下游库,为什么例子写法与我接触的不同?

OpenMMLab 下游库广泛采用了配置文件的方式。我们将在下个章节,基于上述示例稍微变换,从而展示配置文件——MMEngine 中执行器的最佳实践——的用法。

执行器最佳实践——配置文件

MMEngine 提供了一套支持 Python 语法的、功能强大的配置文件系统。你可以从之前的示例代码中近乎(我们将在下面说明)无缝地转换到配置文件。下面给出一段示例代码:

# 以下代码存放在 example_config.py 文件中
# 基本拷贝自上面的示例,并将每项结尾的逗号删去
model = dict(type='MyAwesomeModel',
    layers=2,
    activation='relu')
work_dir = 'exp/my_awesome_model'

train_dataloader = dict(
    dataset=dict(type='MyDataset',
        is_train=True,
        size=10000),
    sampler=dict(
        type='DefaultSampler',
        shuffle=True),
    collate_fn=dict(type='default_collate'),
    batch_size=64,
    pin_memory=True,
    num_workers=2)
train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_begin=2,
    val_interval=1)
optim_wrapper = dict(
    optimizer=dict(
        type='Adam',
        lr=0.001))
param_scheduler = dict(
    type='MultiStepLR',
    by_epoch=True,
    milestones=[4, 8],
    gamma=0.1)

val_dataloader = dict(
    dataset=dict(type='MyDataset',
        is_train=False,
        size=1000),
    sampler=dict(
        type='DefaultSampler',
        shuffle=False),
    collate_fn=dict(type='default_collate'),
    batch_size=1000,
    pin_memory=True,
    num_workers=2)
val_cfg = dict()
val_evaluator = dict(type='Accuracy')

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1))
launcher = 'none'
env_cfg = dict(
    cudnn_benchmark=False,
    backend='nccl',
    mp_cfg=dict(mp_start_method='fork'))
log_level = 'INFO'
load_from = None
resume = False

此时,我们只需要在训练代码中加载配置,然后运行即可

from mmengine.config import Config
from mmengine.runner import Runner
config = Config.fromfile('example_config.py')
runner = Runner.from_cfg(config)
runner.train()

注解

虽然是 Python 语法,但合法的配置文件需要满足以下条件:所有的变量必须是基本类型(例如 str dict int等)。因此,配置文件系统高度依赖于注册机制,以实现从基本类型到其他类型(如 nn.Module)的构建。

注解

使用配置文件时,你通常不需要手动注册所有模块。例如,torch.optim 中的所有优化器(如 Adam SGD等)都已经在 mmengine.optim 中注册完成。使用时的经验法则是:尝试直接使用 PyTorch 中的组件,只有当出现报错时再手动注册。

注解

当使用配置文件写法时,你的自定义模块的实现代码通常存放在独立文件中,可能并未被正确注册,进而导致构建失败。我们推荐你阅读注册器文档中 custom_imports 相关的内容以更好地使用配置文件系统。

警告

虽然与示例中的写法一致,但 from_cfg__init__ 的缺省值处理可能存在些微不同,例如 env_cfg 参数。

执行器配置文件已经在 OpenMMLab 的众多下游库(MMCls,MMDet…)中被广泛使用,并成为事实标准与最佳实践。配置文件的功能远不止如此,如果你对于继承、覆写等进阶功能感兴趣,请参考配置(Config)文档。

基本数据流

提示

在本章节中,我们将会介绍执行器内部各模块之间的数据传递流向与格式约定。如果你还没有基于 MMEngine 构建一个训练流程,本章节的部分内容可能会比较抽象、枯燥;你也可以暂时跳过,并在将来有需要时结合实践进行阅读。

接下来,我们将稍微深入执行器的内部,结合图示来理清其中数据的流向与格式约定。

基本数据流

上图是执行器的基本数据流,其中虚线边框、灰色填充的不同形状代表不同的数据格式,实线方框代表模块或方法。由于 MMEngine 强大的灵活性与可扩展性,你总可以继承某些关键基类并重载其中的方法,因此上图并不总是成立。只有当你没有自定义 RunnerTrainLoop ,并且你的自定义模型没有重载 train_stepval_steptest_step 方法时上图才会成立(而这在检测、分割等任务上是常见的,参考模型教程)。

可以确切地说明图中传递的每项数据的具体类型吗?

很遗憾,这一点无法做到。虽然 MMEngine 做了大量类型注释,但 Python 是一门高度动态化的编程语言,同时以数据为核心的深度学习系统也需要足够的灵活性来处理纷繁复杂的数据源,你有充分的自由决定何时需要(有时是必须)打破类型约定。因此,在你自定义某一或某几个模块(如 val_evaluator )时,你需要确保它的输入与上游(如 model 的输出)兼容,同时输出可以被下游解析。MMEngine 将处理数据的灵活性交给了用户,因而也需要用户保证数据流的兼容性——当然,实际上手后会发现,这一点并不十分困难。

数据一致性的考验一直存在于深度学习领域,MMEngine 也在尝试用自己的方式改进。如果你有兴趣,可以参考数据集基类抽象数据接口文档——但是请注意,它们主要面向进阶用户

dataloader、model 和 evaluator 之间的数据格式是如何约定的?

针对图中所展示的基本数据流,上述三个模块之间的数据传递可以用如下伪代码表示

# 训练过程
for data_batch in train_dataloader:
    data_batch = data_preprocessor(data_batch)
    if isinstance(data_batch, dict):
        losses = model.forward(**data_batch, mode='loss')
    elif isinstance(data_batch, (list, tuple)):
        losses = model.forward(*data_batch, mode='loss')
    else:
        raise TypeError()

# 验证过程
for data_batch in val_dataloader:
    data_batch = data_preprocessor(data_batch)
    if isinstance(data_batch, dict):
        outputs = model.forward(**data_batch, mode='predict')
    elif isinstance(data_batch, (list, tuple)):
        outputs = model.forward(**data_batch, mode='predict')
    else:
        raise TypeError()
    evaluator.process(data_samples=outputs, data_batch=data_batch)
metrics = evaluator.evaluate(len(val_dataloader.dataset))

上述伪代码的关键点在于:

  • data_preprocessor 的输出需要经过解包后传递给 model

  • evaluator 的 data_samples 参数接收模型的预测结果,而 data_batch 参数接收 dataloader 的原始数据

什么是 data_preprocessor?我可以用它做裁减缩放等图像预处理吗?

虽然图中的 data preprocessor 与 model 是分离的,但在实际中前者是后者的一部分,因此可以在模型文档中的数据处理器章节找到。

通常来说,数据处理器不需要额外关注和指定,默认的数据处理器只会自动将数据搬运到 GPU 中。但是,如果你的模型与数据加载器的数据格式不匹配,你也可以自定义一个数据处理器来进行格式转换。

裁减缩放等图像预处理更推荐在数据变换中进行,但如果是 batch 相关的数据处理(如 batch-resize 等),可以在这里实现。

为什么 model 产生了 3 个不同的输出? loss、predict、tensor 是什么含义?

15 分钟上手对此有一定的描述,你需要在自定义模型的 forward 函数中实现 3 条数据通路,适配训练、验证等不同需求。模型文档中对此有详细解释。

我可以看出红线是训练流程,蓝线是验证/测试流程,但绿线是什么?

在目前的执行器流程中,'tensor' 模式的输出并未被使用,大多数情况下用户无需实现。但一些情况下输出中间结果可以方便地进行 debug

如果我重载了 train_step 等方法,上图会完全失效吗?

默认的 train_stepval_steptest_step 的行为,覆盖了从数据进入 data preprocessormodel 输出 losspredict 结果的这一段流程,不影响其余部分。

为什么使用执行器(可选)

提示

这一部分内容并不能教会你如何使用执行器乃至整个 MMEngine,如果你正在被雇主/教授/DDL催促着几个小时内拿出成果,那这部分可能无法帮助到你,请随意跳过。但我们仍强烈推荐抽出时间阅读本章节,这可以帮助你更好地理解并使用 MMEngine

放轻松,接下来是闲聊时间

恭喜你通关了执行器!这真是一篇长长的、但还算有趣(希望如此)的教程。无论如何,请相信这些都是为了让你更加轻松——不论是本篇教程、执行器,还是 MMEngine。

执行器是 MMEngine 中所有模块的“管理者”。所有的独立模块——不论是模型、数据集这些看得见摸的着的,还是日志记录、分布式训练、随机种子等相对隐晦的——都在执行器中被统一调度、产生关联。事物之间的关系是复杂的,但执行器为你处理了一切,并提供了一个清晰易懂的配置式接口。这样做的好处主要有:

  1. 你可以轻易地在已搭建流程上修改/添加所需配置,而不会搅乱整个代码。也许你起初只有单卡训练,但你随时可以添加1、2行的分布式配置,切换到多卡甚至多机训练

  2. 你可以享受 MMEngine 不断引入的新特性,而不必担心后向兼容性。混合精度训练、可视化、崭新的分布式训练方式、多种设备后端……我们会在保证后向兼容性的前提下不断吸收社区的优秀建议与前沿技术,并以简洁明了的方式提供给你

  3. 你可以集中关注并实现自己的惊人想法,而不必受限于其他恼人的、不相关的细节。执行器的缺省值会为你处理绝大多数的情况

所以,MMEngine 与执行器会确实地让你更加轻松。只要花费一点点努力完成迁移,你的代码与实验会随着 MMEngine 的发展而与时俱进;如果再花费一点努力,MMEngine 的配置系统可以让你更加高效地管理数据、模型、实验。便利性与可靠性,这些正是我们努力的目标。

蓝色药丸,还是红色药丸——你准备好加入吗?

下一步的建议

如果你想要进一步地:

实现自己的模型结构

参考模型(Model)

使用自己的数据集

参考数据集(Dataset)与数据加载器(DataLoader)

更换模型评测/验证指标

参考模型精度评测(Evaluation)

调整优化器封装(如开启混合精度训练、梯度累积等)与更换优化器

参考优化器封装(OptimWrapper)

动态调整学习率等参数(如 warmup )

参考优化器参数调整策略(Parameter Scheduler)

其他
  • 左侧的“常用功能”中包含更多常用的与新特性的示例代码可供参考

  • “进阶教程”中有更多面向资深开发者的内容,可以更加灵活地配置训练流程、日志、可视化等

  • 如果以上所有内容都无法实现你的新想法,那么钩子(Hook)值得一试

  • 欢迎在我们的 讨论版 中发起话题求助!

数据集(Dataset)与数据加载器(DataLoader)

提示

如果你没有接触过 PyTorch 的数据集与数据加载器,我们推荐先浏览 PyTorch 官方教程以了解一些基本概念

数据集与数据加载器是 MMEngine 中训练流程的必要组件,它们的概念来源于 PyTorch,并且在含义上与 PyTorch 保持一致。通常来说,数据集定义了数据的总体数量、读取方式以及预处理,而数据加载器则在不同的设置下迭代地加载数据,如批次大小(batch_size)、随机乱序(shuffle)、并行(num_workers)等。数据集经过数据加载器封装后构成了数据源。在本篇教程中,我们将按照从外(数据加载器)到内(数据集)的顺序,逐步介绍它们在 MMEngine 执行器中的用法,并给出一些常用示例。读完本篇教程,你将会:

  • 掌握如何在 MMEngine 的执行器中配置数据加载器

  • 学会在配置文件中使用已有(如 torchvision)数据集

  • 了解如何使用自己的数据集

数据加载器详解

在执行器(Runner)中,你可以分别配置以下 3 个参数来指定对应的数据加载器

  • train_dataloader:在 Runner.train() 中被使用,为模型提供训练数据

  • val_dataloader:在 Runner.val() 中被使用,也会在 Runner.train() 中每间隔一段时间被使用,用于模型的验证评测

  • test_dataloader:在 Runner.test() 中被使用,用于模型的测试

MMEngine 完全支持 PyTorch 的原生 DataLoader,因此上述 3 个参数均可以直接传入构建好的 DataLoader,如15分钟上手中的例子所示。同时,借助 MMEngine 的注册机制,以上参数也可以传入 dict,如下面代码(以下简称例 1)所示。字典中的键值与 DataLoader 的构造参数一一对应。

runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=torchvision.datasets.CIFAR10(...),
        collate_fn=dict(type='default_collate')
    )
)

在这种情况下,数据加载器会在实际被用到时,在执行器内部被构建。

注解

关于 DataLoader 的更多可配置参数,你可以参考 PyTorch API 文档

注解

如果你对于构建的具体细节感兴趣,你可以参考 build_dataloader

细心的你可能会发现,例 1 并非直接由15分钟上手中的示例代码简单修改而来。你可能本以为将 DataLoader 简单替换为 dict 就可以无缝切换,但遗憾的是,基于注册机制构建时 MMEngine 会有一些隐式的转换和约定。我们将介绍其中的不同点,以避免你使用配置文件时产生不必要的疑惑。

sampler 与 shuffle

与 15 分钟上手明显不同,例 1 中我们添加了 sampler 参数,这是由于在 MMEngine 中我们要求通过 dict 传入的数据加载器的配置必须包含 sampler 参数。同时,shuffle 参数也从 DataLoader 中移除,这是由于在 PyTorch 中 samplershuffle 参数是互斥的,见 PyTorch API 文档

注解

事实上,在 PyTorch 的实现中,shuffle 只是一个便利记号。当设置为 TrueDataLoader 会自动在内部使用 RandomSampler

当考虑 sampler 时,例 1 代码基本可以认为等价于下面的代码块

from mmengine.dataset import DefaultSampler

dataset = torchvision.datasets.CIFAR10(...)
sampler = DefaultSampler(dataset, shuffle=True)

runner = Runner(
    train_dataloader=DataLoader(
        batch_size=32,
        sampler=sampler,
        dataset=dataset,
        collate_fn=default_collate
    )
)

警告

上述代码的等价性只有在:1)使用单进程训练,以及 2)没有配置执行器的 randomness 参数时成立。这是由于使用 dict 传入 sampler 时,执行器会保证它在分布式训练环境设置完成后才被惰性构造,并接收到正确的随机种子。这两点在手动构造时需要额外工作且极易出错。因此,上述的写法只是一个示意而非推荐写法。我们强烈建议 samplerdict 的形式传入,让执行器处理构造顺序,以避免出现问题。

DefaultSampler

上面例子可能会让你好奇:DefaultSampler 是什么,为什么要使用它,是否有其他选项?事实上,DefaultSampler 是 MMEngine 内置的一种采样器,它屏蔽了单进程训练与多进程训练的细节差异,使得单卡与多卡训练可以无缝切换。如果你有过使用 PyTorch DistributedDataParallel 的经验,你一定会对其中更换数据加载器的 sampler 参数有所印象。但在 MMEngine 中,这一细节通过 DefaultSampler 而被屏蔽。

除了 Dataset 本身之外,DefaultSampler 还支持以下参数配置:

  • shuffle 设置为 True 时会打乱数据集的读取顺序

  • seed 打乱数据集所用的随机种子,通常不需要在此手动设置,会从 Runnerrandomness 入参中读取

  • round_up 设置为 True 时,与 PyTorch DataLoader 中设置 drop_last=False 行为一致。如果你在迁移 PyTorch 的项目,你可能需要注意这一点。

注解

更多关于 DefaultSampler 的内容可以参考 API 文档

DefaultSampler 适用于绝大部分情况,并且我们保证在执行器中使用它时,随机数等容易出错的细节都被正确地处理,防止你陷入多进程训练的常见陷阱。如果你想要使用基于迭代次数 (iteration-based) 的训练流程,你也许会对 InfiniteSampler 感兴趣。如果你有更多的进阶需求,你可能会想要参考上述两个内置 sampler 的代码,实现一个自定义的 sampler 并注册到 DATA_SAMPLERS 根注册器中。

@DATA_SAMPLERS.register_module()
class MySampler(Sampler):
    pass

runner = Runner(
    train_dataloader=dict(
        sampler=dict(type='MySampler'),
        ...
    )
)

不起眼的 collate_fn

PyTorch 的 DataLoader 中,collate_fn 这一参数常常被使用者忽略,但在 MMEngine 中你需要额外注意:当你传入 dict 来构造数据加载器时,MMEngine 会默认使用内置的 pseudo_collate,这一点明显区别于 PyTorch 默认的 default_collate。因此,当你迁移 PyTorch 项目时,需要在配置文件中手动指明 collate_fn 以保持行为一致。

注解

MMEngine 中使用 pseudo_collate 作为默认值,主要是由于历史兼容性原因,你可以不必过于深究,只需了解并避免错误使用即可。

MMengine 中提供了 2 种内置的 collate_fn

  • pseudo_collate,缺省时的默认参数。它不会将数据沿着 batch 的维度合并。详细说明可以参考 pseudo_collate

  • default_collate,与 PyTorch 中的 default_collate 行为几乎完全一致,会将数据转化为 Tensor 并沿着 batch 维度合并。一些细微不同和详细说明可以参考 default_collate

如果你想要使用自定义的 collate_fn,你也可以将它注册到 COLLATE_FUNCTIONS 根注册器中来使用

@COLLATE_FUNCTIONS.register_module()
def my_collate_func(data_batch: Sequence) -> Any:
    pass

runner = Runner(
    train_dataloader=dict(
        ...
        collate_fn=dict(type='my_collate_func')
    )
)

数据集详解

数据集通常定义了数据的数量、读取方式与预处理,并作为参数传递给数据加载器供后者分批次加载。由于我们使用了 PyTorch 的 DataLoader,因此数据集也自然与 PyTorch Dataset 完全兼容。同时得益于注册机制,当数据加载器使用 dict 在执行器内部构建时,dataset 参数也可以使用 dict 传入并在内部被构建。这一点使得编写配置文件成为可能。

使用 torchvision 数据集

torchvision 中提供了丰富的公开数据集,它们都可以在 MMEngine 中直接使用,例如 15 分钟上手中的示例代码就使用了其中的 Cifar10 数据集,并且使用了 torchvision 中内置的数据预处理模块。

但是,当需要将上述示例转换为配置文件时,你需要对 torchvision 中的数据集进行额外的注册。如果你同时用到了 torchvision 中的数据预处理模块,那么你也需要编写额外代码来对它们进行注册和构建。下面我们将给出一个等效的例子来展示如何做到这一点。

import torchvision.transforms as tvt
from mmengine.registry import DATASETS, TRANSFORMS
from mmengine.dataset.base_dataset import Compose

# 注册 torchvision 的 CIFAR10 数据集
# 数据预处理也需要在此一起构建
@DATASETS.register_module(name='Cifar10', force=False)
def build_torchvision_cifar10(transform=None, **kwargs):
    if isinstance(transform, dict):
        transform = [transform]
    if isinstance(transform, (list, tuple)):
        transform = Compose(transform)
    return torchvision.datasets.CIFAR10(**kwargs, transform=transform)

# 注册 torchvision 中用到的数据预处理模块
DATA_TRANSFORMS.register_module('RandomCrop', module=tvt.RandomCrop)
DATA_TRANSFORMS.register_module('RandomHorizontalFlip', module=tvt.RandomHorizontalFlip)
DATA_TRANSFORMS.register_module('ToTensor', module=tvt.ToTensor)
DATA_TRANSFORMS.register_module('Normalize', module=tvt.Normalize)

# 在 Runner 中使用
runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=dict(type='Cifar10',
            root='data/cifar10',
            train=True,
            download=True,
            transform=[
                dict(type='RandomCrop', size=32, padding=4),
                dict(type='RandomHorizontalFlip'),
                dict(type='ToTensor'),
                dict(type='Normalize', **norm_cfg)])
    )
)

注解

上述例子中大量使用了注册机制,并且用到了 MMEngine 中的 Compose。如果你急需在配置文件中使用 torchvision 数据集,你可以参考上述代码并略作修改。但我们更加推荐你有需要时在下游库(如 MMDetMMCls 等)中寻找对应的数据集实现,从而获得更好的使用体验。

自定义数据集

你可以像使用 PyTorch 一样,自由地定义自己的数据集,或将之前 PyTorch 项目中的数据集拷贝过来。如果你想要了解如何自定义数据集,可以参考 PyTorch 官方教程

使用 MMEngine 的数据集基类

除了直接使用 PyTorch 的 Dataset 来自定义数据集之外,你也可以使用 MMEngine 内置的 BaseDataset,参考数据集基类文档。它对标注文件的格式做了一些约定,使得数据接口更加统一、多任务训练更加便捷。同时,数据集基类也可以轻松地搭配内置的数据变换使用,减轻你从头搭建训练流程的工作量。

目前,BaseDataset 已经在 OpenMMLab 2.0 系列的下游仓库中被广泛使用。

模型(Model)

Runner 与 model

Runner 教程的基本数据流中我们提到,DataLoader、model 和 evaluator 之间的数据流通遵循了一些规则,我们先来回顾一下基本数据流的伪代码:

# 训练过程
for data_batch in train_dataloader:
    data_batch = model.data_preprocessor(data_batch, training=True)
    if isinstance(data_batch, dict):
        losses = model(**data_batch, mode='loss')
    elif isinstance(data_batch, (list, tuple)):
        losses = model(*data_batch, mode='loss')
    else:
        raise TypeError()
# 验证过程
for data_batch in val_dataloader:
    data_batch = model.data_preprocessor(data_batch, training=False)
    if isinstance(data_batch, dict):
        outputs = model(**data_batch, mode='predict')
    elif isinstance(data_batch, (list, tuple)):
        outputs = model(**data_batch, mode='predict')
    else:
        raise TypeError()
    evaluator.process(data_samples=outputs, data_batch=data_batch)
metrics = evaluator.evaluate(len(val_dataloader.dataset))

在 Runner 的教程中,我们简单介绍了模型和前后组件之间的数据流通关系,提到了 data_preprocessor 的概念,对 model 有了一定的了解。然而在 Runner 实际运行的过程中,模型的功能和调用关系,其复杂程度远超上述伪代码。为了让你能够不感知模型和外部组件的复杂关系,进而聚焦精力到算法本身,我们设计了 BaseModel。大多数情况下你只需要让 model 继承 BaseModel,并按照要求实现 forward 接口,就能完成训练、测试、验证的逻辑。

在继续阅读模型教程之前,我们先抛出两个问题,希望你在阅读完 model 教程后能够找到相应的答案:

  1. 我们在什么位置更新模型参数?如果我有一些非常复杂的参数更新逻辑,又该如何实现?

  2. 为什么要有 data_preprocessor 的概念?它又可以实现哪些功能?

接口约定

在训练深度学习任务时,我们通常需要定义一个模型来实现算法的主体。在基于 MMEngine 开发时,定义的模型由执行器管理,且需要实现 train_stepval_steptest_step 方法。 对于检测、识别、分割一类的深度学习任务,上述方法通常为标准的流程,例如在 train_step 里更新参数,返回损失;val_steptest_step 返回预测结果。因此 MMEngine 抽象出模型基类 BaseModel,实现了上述接口的标准流程。

得益于 BaseModel 我们只需要让模型继承自模型基类,并按照一定的规范实现 forward,就能让模型在执行器中运行起来。

注解

模型基类继承自模块基类,能够通过配置 init_cfg 灵活地选择初始化方式。

forward: forward 的入参需通常需要和 DataLoader 的输出保持一致 (自定义数据预处理器除外),如果 DataLoader 返回元组类型的数据 dataforward 需要能够接受 *data 的解包后的参数;如果返回字典类型的数据 dataforward 需要能够接受 **data 解包后的参数。 mode 参数用于控制 forward 的返回结果:

  • mode='loss'loss 模式通常在训练阶段启用,并返回一个损失字典。损失字典的 key-value 分别为损失名和可微的 torch.Tensor。字典中记录的损失会被用于更新参数和记录日志。模型基类会在 train_step 方法中调用该模式的 forward

  • mode='predict'predict 模式通常在验证、测试阶段启用,并返回列表/元组形式的预测结果,预测结果需要和 process 接口的参数相匹配。OpenMMLab 系列算法对 predict 模式的输出有着更加严格的约定,需要输出列表形式的数据元素。模型基类会在 val_steptest_step 方法中调用该模式的 forward

  • mode='tensor'tensorpredict 模式均返回模型的前向推理结果,区别在于 tensor 模式下,forward 会返回未经后处理的张量,例如返回未经非极大值抑制(nms)处理的检测结果,返回未经 argmax 处理的分类结果。我们可以基于 tensor 模式的结果进行自定义的后处理。

train_step: 执行 forward 方法的 loss 分支,得到损失字典。模型基类基于优化器封装 实现了标准的梯度计算、参数更新、梯度清零流程。其等效伪代码如下:

def train_step(self, data, optim_wrapper):
    data = self.data_preprocessor(data, training=True)  # 按下不表,详见数据与处理器一节
    loss = self(**data, mode='loss')  # loss 模式,返回损失字典,假设 data 是字典,使用 ** 进行解析。事实上 train_step 兼容 tuple 和 dict 类型的输入。
    parsed_losses, log_vars = self.parse_losses() # 解析损失字典,返回可以 backward 的损失以及可以被日志记录的损失
    optim_wrapper.update_params(parsed_losses)  # 更新参数
    return log_vars

val_step: 执行 forward 方法的 predict 分支,返回预测结果:

def val_step(self, data, optim_wrapper):
    data = self.data_preprocessor(data, training=False)
    outputs = self(**data, mode='predict') # 预测模式,返回预测结果
    return outputs

test_step: 同 val_step

看到这我们就可以给出一份 基本数据流伪代码 plus

# 训练过程
for data_batch in train_dataloader:
    loss_dict = model.train_step(data_batch)
# 验证过程
for data_batch in val_dataloader:
    preds = model.test_step(data_batch)
    evaluator.process(data_samples=outputs, data_batch=data_batch)
metrics = evaluator.evaluate(len(val_dataloader.dataset))

没错,抛开 Hook 不谈,loop 调用 model 的过程和上述代码一模一样!看到这,我们再回过头去看 15 分钟上手 MMEngine 里的模型定义部分,就有一种看山不是山的感觉:

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

    # 下面的 3 个方法已在 BaseModel 实现,这里列出是为了
    # 解释调用过程
    def train_step(self, data, optim_wrapper):
        data = self.data_preprocessor(data)
        loss = self(*data, mode='loss')  # CIFAR10 返回 tuple,因此用 * 解包
        parsed_losses, log_vars = self.parse_losses()
        optim_wrapper.update_params(parsed_losses)
        return log_vars

    def val_step(self, data, optim_wrapper):
        data = self.data_preprocessor(data)
        outputs = self(*data, mode='predict')
        return outputs

    def test_step(self, data, optim_wrapper):
        data = self.data_preprocessor(data)
        outputs = self(*data, mode='predict')
        return outputs

看到这里,相信你对数据流有了更加深刻的理解,也能够回答 Runner 与 model 里提到的第一个问题:

BaseModel.train_step 里实现了默认的参数更新逻辑,如果我们想实现自定义的参数更新流程,可以重写 train_step 方法。但是需要注意的是,我们需要保证 train_step 最后能够返回损失字典。

数据预处理器(DataPreprocessor)

如果你的电脑配有 GPU(或其他能够加速训练的硬件,如 MPS、IPU 等),并且运行了 15 分钟上手 MMEngine 的代码示例,你会发现程序是在 GPU 上运行的,那么 MMEngine 是在何时把数据和模型从 CPU 搬运到 GPU 的呢?

事实上,执行器会在构造阶段将模型搬运到指定设备,而数据则会在上一节提到的 self.data_preprocessor 这一行搬运到指定设备,进一步将处理好的数据传给模型。看到这里相信你会疑惑:

  1. MMResNet50 并没有配置 data_preprocessor,为什么却可以访问到 data_preprocessor,并且把数据搬运到 GPU?

  2. 为什么不直接在模型里调用 data.to(device) 搬运数据,而需要有 data_preprocessor 这一层抽象?它又能实现哪些功能?

首先回答第一个问题:MMResNet50 继承了 BaseModel。在执行 super().__init__ 时,如果不传入任何参数,会构造一个默认的 BaseDataPreprocessor,其等效简易实现如下:

class BaseDataPreprocessor(nn.Module):
    def forward(self, data, training=True):  # 先忽略 training 参数
        # 假设 data 是 CIFAR10 返回的 tuple 类型数据,事实上
        # BaseDataPreprocessor 可以处理任意类型的数
        # BaseDataPreprocessor 同样可以把数据搬运到多种设备,这边方便
        # 起见写成 .cuda()
        return tuple(_data.cuda() for _data in data)

BaseDataPreprocessor 会在训练过程中,将各种类型的数据搬运到指定设备。

在回答第二个问题之前,我们不妨先再思考几个问题

  1. 数据归一化操作应该在哪里进行,transform 还是 model?

    听上去好像都挺合理,放在 transform 里可以利用 Dataloader 的多进程加速,放在 model 里可以搬运到 GPU 上,利用GPU 资源加速归一化。然而在我们纠结 CPU 归一化快还是 GPU 归一化快的时候,CPU 到 GPU 的数据搬运耗时相较于前者,可算的上是“降维打击”。 事实上对于归一化这类计算量较低的操作,其耗时会远低于数据搬运,因此优化数据搬运的效率就显得更加重要。设想一下,如果我能够在数据仍处于 uint8 时、归一化之前将其搬运到指定设备上(归一化后的 float 型数据大小是 unit8 的 4 倍),就能降低带宽,大大提升数据搬运的效率。这种“滞后”归一化的行为,也是我们设计数据预处理器(data preprocessor) 的主要原因之一。数据预处理器会先搬运数据,再做归一化,提升数据搬运的效率。

  2. 我们应该如何实现 MixUp、Mosaic 一类的数据增强?

    尽管看上去 MixUp 和 Mosaic 只是一种特殊的数据变换,按理说应该在 transform 里实现。考虑到这两种增强会涉及到“将多张图片融合成一张图片”的操作,在 transform 里实现他们的难度就会很大,因为目前 transform 的范式是对一张图片做各种增强,我们很难在一个 transform 里去额外读取其他图片(transform 里无法访问到 dataset)。然而如果基于 Dataloader 采样得到的 batch_data 去实现 Mosaic 或者 Mixup,事情就会变得非常简单,因为这个时候我们能够同时访问多张图片,可以轻而易举的完成图片融合的操作:

    class MixUpDataPreprocessor(nn.Module):
        def __init__(self, num_class, alpha):
            self.alpha = alpha
    
        def forward(self, data, training=True):
            data = tuple(_data.cuda() for _data in data)
            # 验证阶段无需进行 MixUp 数据增强
            if not training:
                return data
    
            label = F.one_hot(label)  # label 转 onehot 编码
            batch_size = len(label)
            index = torch.randperm(batch_size)  # 计算用于叠加的图片数
            img, label = data
            lam = np.random.beta(self.alpha, self.alpha)  # 融合因子
    
            # 原图和标签的 MixUp.
            img = lam * img + (1 - lam) * img[index, :]
            label = lam * batch_scores + (1 - lam) * batch_scores[index, :]
            # 由于此时返回的是 onehot 编码的 label,model 的 forward 也需要做相应调整
            return tuple(img, label)
    

    因此,除了数据搬运和归一化,data_preprocessor 另一大功能就是数据批增强(BatchAugmentation)。数据预处理器的模块化也能帮助我们实现算法和数据增强之间的自由组合。

  3. 如果 DataLoader 的输出和模型的输入类型不匹配怎么办,是修改 DataLoader 还是修改模型接口?

    答案是都不合适。理想的解决方案是我们能够在不破坏模型和数据已有接口的情况下完成适配。这个时候数据预处理器也能承担类型转换的工作,例如将传入的 data 从 tuple 转换成指定字段的 dict

看到这里,相信你已经能够理解数据预处理器存在的合理性,并且也能够自信地回答教程最初提出的两个问题!但是你可能还会疑惑 train_step 接口中传入的 optim_wrapper 又是什么,test_stepval_step 返回的结果和 evaluator 又有怎样的关系,这些问题会在模型精度评测教程优化器封装得到解答。

模型精度评测(Evaluation)

在模型验证和模型测试中,通常需要对模型精度做定量评测。我们可以通过在配置文件中指定评测指标(Metric)来实现这一功能。

在模型训练或测试中进行评测

使用单个评测指标

在基于 MMEngine 进行模型训练或测试时,用户只需要在配置文件中通过 val_evaluatortest_evaluator 2 个字段分别指定模型验证和测试阶段的评测指标即可。例如,用户在使用 MMClassification 训练分类模型时,希望在模型验证阶段评测 top-1 和 top-5 分类正确率,可以按以下方式配置:

val_evaluator = dict(type='Accuracy', top_k=(1, 5))  # 使用分类正确率评测指标

关于具体评测指标的参数设置,用户可以查阅相关算法库的文档。如上例中的 Accuracy 文档

使用多个评测指标

如果需要同时评测多个指标,也可以将 val_evaluatortest_evaluator 设置为一个列表,其中每一项为一个评测指标的配置信息。例如,在使用 MMDetection 训练全景分割模型时,希望在模型测试阶段同时评测模型的目标检测(COCO AP/AR)和全景分割精度,可以按以下方式配置:

test_evaluator = [
    # 目标检测指标
    dict(
        type='CocoMetric',
        metric=['bbox', 'segm'],
        ann_file='annotations/instances_val2017.json',
    ),
    # 全景分割指标
    dict(
        type='CocoPanopticMetric',
        ann_file='annotations/panoptic_val2017.json',
        seg_prefix='annotations/panoptic_val2017',
    )
]

自定义评测指标

如果算法库中提供的常用评测指标无法满足需求,用户也可以增加自定义的评测指标。我们以简化的分类正确率为例,介绍实现自定义评测指标的方法:

  1. 在定义新的评测指标类时,需要继承基类 BaseMetric(关于该基类的介绍,可以参考设计文档)。此外,评测指标类需要用注册器 METRICS 进行注册(关于注册器的说明请参考 Registry 文档)。

  2. 实现 process() 方法。该方法有 2 个输入参数,分别是一个批次的测试数据样本 data_batch 和模型预测结果 data_samples。我们从中分别取出样本类别标签和分类预测结果,并存放在 self.results 中。

  3. 实现 compute_metrics() 方法。该方法有 1 个输入参数 results,里面存放了所有批次测试数据经过 process() 方法处理后得到的结果。从中取出样本类别标签和分类预测结果,即可计算得到分类正确率 acc。最终,将计算得到的评测指标以字典的形式返回。

  4. (可选)可以为类属性 default_prefix 赋值。该属性会自动作为输出的评测指标名前缀(如 defaut_prefix='my_metric',则实际输出的评测指标名为 'my_metric/acc'),用以进一步区分不同的评测指标。该前缀也可以在配置文件中通过 prefix 参数改写。我们建议在 docstring 中说明该评测指标类的 default_prefix 值以及所有的返回指标名称。

具体实现如下:

from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS

import numpy as np


@METRICS.register_module()  # 将 Accuracy 类注册到 METRICS 注册器
class SimpleAccuracy(BaseMetric):
    """ Accuracy Evaluator

    Default prefix: ACC

    Metrics:
        - accuracy (float): classification accuracy
    """

    default_prefix = 'ACC'  # 设置 default_prefix

    def process(self, data_batch: Sequence[dict], data_samples: Sequence[dict]):
        """Process one batch of data and predictions. The processed
        Results should be stored in `self.results`, which will be used
        to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
                from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """

        # 取出分类预测结果和类别标签
        result = {
            'pred': data_samples['pred_label'],
            'gt': data_samples['data_sample']['gt_label']
        }

        # 将当前 batch 的结果存进 self.results
        self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (dict): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """

        # 汇总所有样本的分类预测结果和类别标签
        preds = np.concatenate([res['pred'] for res in results])
        gts = np.concatenate([res['gt'] for res in results])

        # 计算分类正确率
        acc = (preds == gts).sum() / preds.size

        # 返回评测指标结果
        return {'accuracy': acc}

使用离线结果进行评测

另一种常见的模型评测方式,是利用提前保存在文件中的模型预测结果进行离线评测。此时,用户需要手动构建评测器,并调用评测器的相应接口完成评测。关于离线评测的详细说明,以及评测器和评测指标的关系,可以参考设计文档。我们仅在此给出一个离线评测示例:

from mmengine.evaluator import Evaluator
from mmengine.fileio import load

# 构建评测器。参数 `metrics` 为评测指标配置
evaluator = Evaluator(metrics=dict(type='Accuracy', top_k=(1, 5)))

# 从文件中读取测试数据。数据格式需要参考具使用的 metric。
data = load('test_data.pkl')

# 从文件中读取模型预测结果。该结果由待评测算法在测试数据集上推理得到。
# 数据格式需要参考具使用的 metric。
data_samples = load('prediction.pkl')

# 调用评测器离线评测接口,得到评测结果
# chunk_size 表示每次处理的样本数量,可根据内存大小调整
results = evaluator.offline_evaluate(data, data_samples, chunk_size=128)

优化器封装(OptimWrapper)

执行器教程模型教程中,我们或多或少地提到了优化器封装(OptimWrapper)的概念,但是却没有介绍为什么我们需要优化器封装,相比于 Pytorch 原生的优化器,优化器封装又有怎样的优势,这些问题会在本教程中得到一一解答。我们将通过对比的方式帮助大家理解,优化器封装的优势,以及如何使用它。

优化器封装顾名思义,是 Pytorch 原生优化器(Optimizer)高级抽象,它在增加了更多功能的同时,提供了一套统一的接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。我们可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。

优化器封装 vs 优化器

这里我们分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装进行单精度训练、混合精度训练和梯度累加,对比二者实现上的区别。

训练模型

1.1 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

1.2 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

image

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。

2.1 基于 Pytorch 的 SGD 优化器实现混合精度训练

from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

2.2 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

image

开启混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍。

3.1 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

3.2 基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

image

我们只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

优化器封装同样提供了更细粒度的接口,方便用户实现一些自定义的参数更新逻辑:

  • backward:传入损失,用于计算参数梯度。

  • step:同 optimizer.step,用于更新参数。

  • zero_grad:同 optimizer.zero_grad,用于参数的梯度。

我们可以使用上述接口实现和 Pytorch 优化器相同的参数更新逻辑:

for idx, (input, target) in enumerate(zip(inputs, targets)):
    optimizer.zero_grad()
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.backward(loss)
    if idx % 2 == 0:
        optim_wrapper.step()
        optim_wrapper.zero_grad()

我们同样可以为优化器封装配置梯度裁减策略:

# 基于 torch.nn.utils.clip_grad_norm_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer, clip_grad=dict(max_norm=1))

# 基于 torch.nn.utils.clip_grad_value_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
    optimizer=optimizer, clip_grad=dict(clip_value=0.2))

获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer)

print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}
0.01
0
{'lr': [0.01]}
{'momentum': [0]}

导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper

model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)

optim_wrapper = OptimWrapper(optimizer=optimizer)
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()

print(optim_state_dict)
print(amp_optim_state_dict)
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)

# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1]}]}
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1]}], 'loss_scaler': {'scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5, 'growth_interval': 2000, '_growth_tracker': 0}}

使用多个优化器

考虑到生成对抗网络之类的算法通常需要使用多个优化器来训练生成器和判别器,因此优化器封装提供了优化器封装的容器类:OptimWrapperDict 来管理多个优化器封装。OptimWrapperDict 以字典的形式存储优化器封装,并允许用户像字典一样访问、遍历其中的元素,即优化器封装实例。

与普通的优化器封装不同,OptimWrapperDict 没有实现 update_paramsoptim_context, backwardstep 等方法,无法被直接用于训练模型。我们建议直接访问 OptimWrapperDict 管理的优化器实例,来实现参数更新逻辑。

你或许会好奇,既然 OptimWrapperDict 没有训练的功能,那为什么不直接使用 dict 来管理多个优化器?事实上,OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDictMMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nn

from mmengine.optim import OptimWrapper, OptimWrapperDict

gen = nn.Linear(1, 1)
disc = nn.Linear(1, 1)
optimizer_gen = SGD(gen.parameters(), lr=0.01)
optimizer_disc = SGD(disc.parameters(), lr=0.01)

optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)

print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}
{'gen.lr': [0.01], 'disc.lr': [0.01]}
{'gen.momentum': [0], 'disc.momentum': [0]}

如上例所示,OptimWrapperDict 可以非常方便的导出所有优化器封装的学习率和动量,同样的,优化器封装也能够导出/加载所有优化器封装的状态字典。

执行器中配置优化器封装

优化器封装需要接受 optimizer 参数,因此我们首先需要为优化器封装配置 optimizer。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 OPTIMIZERS 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 PyTorch 优化器列表

以配置一个 SGD 优化器封装为例:

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

这样我们就配置好了一个优化器类型为 SGD 的优化器封装,学习率、动量等参数如配置所示。考虑到 OptimWrapper 为标准的单精度训练,因此我们也可以不配置 type 字段:

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(optimizer=optimizer)

要想开启混合精度训练和梯度累加,需要将 type 切换成 AmpOptimWrapper,并指定 accumulative_counts 参数:

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer, accumulative_counts=2)

注解

如果你是第一次阅读 MMEngine 的教程文档,并且尚未了解配置类注册器 等概念,建议可以先跳过以下进阶教程,先去阅读其他文档。当然了,如果你已经具备了这些储备知识,我们强烈建议阅读进阶教程,在进阶教程中,我们将学会:

  1. 如何在配置文件中定制化地在优化器中配置模型参数的学习率、衰减系数等。

  2. 如何自定义一个优化器构造策略,实现真正意义上的“优化器配置自由”。

除了配置类和注册器等前置知识,我们建议在开始进阶教程之前,先深入了解 Pytorch 原生优化器构造时的 params 参数。

进阶配置

PyTorch 的优化器支持对模型中的不同参数设置不同的超参数,例如对一个分类模型的骨干(backbone)和分类头(head)设置不同的学习率:

from torch.optim import SGD
import torch.nn as nn

model = nn.ModuleDict(dict(backbone=nn.Linear(1, 1), head=nn.Linear(1, 1)))
optimizer = SGD([{'params': model.backbone.parameters()},
     {'params': model.head.parameters(), 'lr': 1e-3}],
    lr=0.01,
    momentum=0.9)

上面的例子中,模型的骨干部分使用了 0.01 学习率,而模型的头部则使用了 1e-3 学习率。用户可以将模型的不同部分参数和对应的超参组成一个字典的列表传给优化器,来实现对模型优化的细粒度调整。

在 MMEngine 中,我们通过优化器封装构造器(optimizer wrapper constructor),让用户能够直接通过设置优化器封装配置文件中的 paramwise_cfg 字段而非修改代码来实现对模型的不同部分设置不同的超参。

为不同类型的参数设置不同的超参系数

MMEngine 提供的默认优化器封装构造器支持对模型中不同类型的参数设置不同的超参系数。例如,我们可以在 paramwise_cfg 中设置 norm_decay_mult=0,从而将正则化层(normalization layer)的权重(weight)和偏置(bias)的权值衰减系数(weight decay)设置为 0,来实现 Bag of Tricks 论文中提到的不对正则化层进行权值衰减的技巧。

具体示例如下,我们将 ToyModel 中所有正则化层(head.bn)的权重衰减系数设置为 0:

from mmengine.optim import build_optim_wrapper
from collections import OrderedDict

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.ModuleDict(
            dict(layer0=nn.Linear(1, 1), layer1=nn.Linear(1, 1)))
        self.head = nn.Sequential(
            OrderedDict(
                linear=nn.Linear(1, 1),
                bn=nn.BatchNorm1d(1)))


optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:weight_decay=0.0

除了可以对正则化层的权重衰减进行配置外,MMEngine 的默认优化器封装构造器的 paramwise_cfg 还支持对更多不同类型的参数设置超参系数,支持的配置如下:

lr_mult:所有参数的学习率系数

decay_mult:所有参数的衰减系数

bias_lr_mult:偏置的学习率系数(不包括正则化层的偏置以及可变形卷积的 offset)

bias_decay_mult:偏置的权值衰减系数(不包括正则化层的偏置以及可变形卷积的 offset)

norm_decay_mult:正则化层权重和偏置的权值衰减系数

flat_decay_mult:一维参数的权值衰减系数

dwconv_decay_mult:Depth-wise 卷积的权值衰减系数

bypass_duplicate:是否跳过重复的参数,默认为 False

dcn_offset_lr_mult:可变形卷积(Deformable Convolution)的学习率系数

为模型不同部分的参数设置不同的超参系数

此外,与上文 PyTorch 的示例一样,在 MMEngine 中我们也同样可以对模型中的任意模块设置不同的超参,只需要在 paramwise_cfg 中设置 custom_keys 即可。

例如我们想将 backbone.layer0 所有参数的学习率设置为 0,衰减系数设置为 0,backbone 其余子模块的学习率设置为 0.01;head 所有参数的学习率设置为 0.001,可以这样配置:

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(
        custom_keys={
            'backbone.layer0': dict(lr_mult=0, decay_mult=0),
            'backbone': dict(lr_mult=1),
            'head': dict(lr_mult=0.1)
        }))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:lr=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:lr_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:decay_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:decay_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:lr_mult=1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr_mult=1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:lr_mult=0.1

上例中,模型的状态字典的 key 如下:

for name, val in ToyModel().named_parameters():
    print(name)
backbone.layer0.weight
backbone.layer0.bias
backbone.layer1.weight
backbone.layer1.bias
head.linear.weight
head.linear.bias
head.bn.weight
head.bn.bias

custom_keys 中每一个字段的含义如下:

  1. 'backbone': dict(lr_mult=1):将名字前缀为 backbone 的参数的学习率系数设置为 1

  2. 'backbone.layer0': dict(lr_mult=0, decay_mult=0):将名字前缀为 backbone.layer0 的参数学习率系数设置为 0,衰减系数设置为 0,该配置优先级比第一条高

  3. 'head': dict(lr_mult=0.1):将名字前缀为 head 的参数的学习率系数设置为 0.1

自定义优化器构造策略

与 MMEngine 中的其他模块一样,优化器封装构造器也同样由注册表管理。我们可以通过实现自定义的优化器封装构造器来实现自定义的超参设置策略。

例如,我们想实现一个叫做 LayerDecayOptimWrapperConstructor 的优化器封装构造器,能够对模型不同深度的层自动设置递减的学习率:

from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS
from mmengine.logging import print_log


@OPTIM_WRAPPER_CONSTRUCTORS.register_module(force=True)
class LayerDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):

    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
        super().__init__(optim_wrapper_cfg, paramwise_cfg=None)
        self.decay_factor = paramwise_cfg.get('decay_factor', 0.5)

        super().__init__(optim_wrapper_cfg, paramwise_cfg)

    def add_params(self, params, module, prefix='' ,lr=None):
        if lr is None:
            lr = self.base_lr

        for name, param in module.named_parameters(recurse=False):
            param_group = dict()
            param_group['params'] = [param]
            param_group['lr'] = lr
            params.append(param_group)
            full_name = f'{prefix}.{name}' if prefix else name
            print_log(f'{full_name} : lr={lr}', logger='current')

        for name, module in module.named_children():
            chiled_prefix = f'{prefix}.{name}' if prefix else name
            self.add_params(
                params, module, chiled_prefix, lr=lr * self.decay_factor)


class ToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.ModuleDict(dict(linear=nn.Linear(1, 1)))
        self.linear = nn.Linear(1, 1)


model = ToyModel()

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(decay_factor=0.5),
    constructor='LayerDecayOptimWrapperConstructor')

optimizer = build_optim_wrapper(model, optim_wrapper)
08/23 22:20:26 - mmengine - INFO - layer.linear.weight : lr=0.0025
08/23 22:20:26 - mmengine - INFO - layer.linear.bias : lr=0.0025
08/23 22:20:26 - mmengine - INFO - linear.weight : lr=0.005
08/23 22:20:26 - mmengine - INFO - linear.bias : lr=0.005

add_params 被第一次调用时,params 参数为空列表(list),module 为模型(model)。详细的重载规则参考优化器封装构造器文档

类似地,如果想构造多个优化器,也需要实现自定义的构造器:

@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MultipleOptimiWrapperConstructor:
    ...

在训练过程中调整超参

优化器中的超参数在构造时只能设置为一个定值,仅仅使用优化器封装,并不能在训练过程中调整学习率等参数。在 MMEngine 中,我们实现了参数调度器(Parameter Scheduler),以便能够在训练过程中调整参数。关于参数调度器的用法请见优化器参数调整策略

优化器参数调整策略(Parameter Scheduler)

在模型训练过程中,我们往往不是采用固定的优化参数,例如学习率等,会随着训练轮数的增加进行调整。最简单常见的学习率调整策略就是阶梯式下降,例如每隔一段时间将学习率降低为原来的几分之一。PyTorch 中有学习率调度器 LRScheduler 来对各种不同的学习率调整方式进行抽象,但支持仍然比较有限,在 MMEngine 中,我们对其进行了拓展,实现了更通用的参数调度器,可以对学习率、动量等优化器相关的参数进行调整,并且支持多个调度器进行组合,应用更复杂的调度策略。

参数调度器的使用

我们先简单介绍一下如何使用 PyTorch 内置的学习率调度器来进行学习率的调整:

如何使用 PyTorch 内置的学习率调度器调整学习率

下面是参考 PyTorch 官方文档 实现的一个例子,我们构造一个 ExponentialLR,并且在每个 epoch 结束后调用 scheduler.step(),实现了随 epoch 指数下降的学习率调整策略。

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(1, 1)
dataset = [torch.randn((1, 1, 1)) for _ in range(20)]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    for data in dataset:
        optimizer.zero_grad()
        output = model(data)
        loss = 1 - output
        loss.backward()
        optimizer.step()
    scheduler.step()

mmengine.optim.scheduler 中,我们支持大部分 PyTorch 中的学习率调度器,例如 ExponentialLRLinearLRStepLRMultiStepLR 等,使用方式也基本一致,所有支持的调度器见调度器接口文档。同时增加了对动量的调整,在类名中将 LR 替换成 Momentum 即可,例如 ExponentialMomentumLinearMomentum。更进一步地,我们实现了通用的参数调度器 ParamScheduler,用于调整优化器的中的其他参数,包括 weight_decay 等。这个特性可以很方便地配置一些新算法中复杂的调整策略。

和 PyTorch 文档中所给示例不同,MMEngine 中通常不需要手动来实现训练循环以及调用 optimizer.step(),而是在执行器(Runner)中对训练流程进行自动管理,同时通过 ParamSchedulerHook 来控制参数调度器的执行。

使用单一的学习率调度器

如果整个训练过程只需要使用一个学习率调度器, 那么和 PyTorch 自带的学习率调度器没有差异。

# 基于手动构建学习率调度器的例子
from torch.optim import SGD
from mmengine.runner import Runner
from mmengine.optim.scheduler import MultiStepLR

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
param_scheduler = MultiStepLR(optimizer, milestones=[8, 11], gamma=0.1)

runner = Runner(
    model=model,
    optim_wrapper=dict(
        optimizer=optimizer),
    param_scheduler=param_scheduler,
    ...
    )

image

如果配合注册器和配置文件使用的话,我们可以设置配置文件中的 param_scheduler 字段来指定调度器, 执行器(Runner)会根据此字段以及执行器中的优化器自动构建学习率调度器:

# 在配置文件中设置学习率调度器字段
param_scheduler = dict(type='MultiStepLR', by_epoch=True, milestones=[8, 11], gamma=0.1)

注意这里增加了初始化参数 by_epoch,控制的是学习率调整频率,当其为 True 时表示按轮次(epoch)调整,为 False 时表示按迭代次数(iteration)调整,默认值为 True。在上面的例子中,表示按照轮次进行调整,此时其他参数的单位均为 epoch,例如 milestones 中的 [8, 11] 表示第 8 和 11 轮次结束时,学习率将会被调整为上一轮次的 0.1 倍。

当修改了学习率调整频率后,调度器中与计数相关设置的含义也会相应被改变。当 by_epoch=True 时,milestones 中的数字表示在哪些轮次进行学习率衰减,而当 by_epoch=False 时则表示在进行到第几次迭代时进行学习率衰减。下面是一个按照迭代次数进行调整的例子,在第 600 和 800 次迭代结束时,学习率将会被调整为原来的 0.1 倍。

param_scheduler = dict(type='MultiStepLR', by_epoch=False, milestones=[600, 800], gamma=0.1)

image

若用户希望在配置调度器时按轮次填写参数的同时使用基于迭代的更新频率,MMEngine 的调度器也提供了自动换算的方式。用户可以调用 build_iter_from_epoch 方法,并提供每个训练轮次的迭代次数,即可构造按迭代次数更新的调度器对象:

epoch_length = len(train_dataloader)
param_scheduler = MultiStepLR.build_iter_from_epoch(optimizer, milestones=[8, 11], gamma=0.1, epoch_length=epoch_length)

如果使用配置文件构建调度器,只需要在配置中加入 convert_to_iter_based=True,执行器会自动调用 build_iter_from_epoch 将基于轮次的配置文件转换为基于迭代次数的调度器对象:

param_scheduler = dict(type='MultiStepLR', by_epoch=True, milestones=[8, 11], gamma=0.1, convert_to_iter_based=True)

为了能直观感受这两种模式的区别,我们这里再举一个例子。下面是一个按轮次更新的余弦退火(CosineAnnealing)学习率调度器,学习率仅在每个轮次结束后被修改:

param_scheduler = dict(type='CosineAnnealingLR', by_epoch=True, T_max=12)

image

而在使用自动换算后,学习率会在每次迭代后被修改。从下图可以看出,学习率的变化更为平滑。

param_scheduler = dict(type='CosineAnnealingLR', by_epoch=True, T_max=12, convert_to_iter_based=True)

image

组合多个学习率调度器(以学习率预热为例)

有些算法在训练过程中,并不是自始至终按照某个调度策略进行学习率调整的。最常见的例子是学习率预热,比如在训练刚开始的若干迭代次数使用线性的调整策略将学习率从一个较小的值增长到正常,然后按照另外的调整策略进行正常训练。

MMEngine 支持组合多个调度器一起使用,只需将配置文件中的 scheduler 字段修改为一组调度器配置的列表,SchedulerStepHook 可以自动对调度器列表进行处理。下面的例子便实现了学习率预热。

param_scheduler = [
    # 线性学习率预热调度器
    dict(type='LinearLR',
         start_factor=0.001,
         by_epoch=False,  # 按迭代更新学习率
         begin=0,
         end=50),  # 预热前 50 次迭代
    # 主学习率调度器
    dict(type='MultiStepLR',
         by_epoch=True,  # 按轮次更新学习率
         milestones=[8, 11],
         gamma=0.1)
]

image

注意这里增加了 beginend 参数,这两个参数指定了调度器的生效区间。生效区间通常只在多个调度器组合时才需要去设置,使用单个调度器时可以忽略。当指定了 beginend 参数时,表示该调度器只在 [begin, end) 区间内生效,其单位是由 by_epoch 参数决定。上述例子中预热阶段 LinearLRby_epoch 为 False,表示该调度器只在前 50 次迭代生效,超过 50 次迭代后此调度器不再生效,由第二个调度器来控制学习率,即 MultiStepLR。在组合不同调度器时,各调度器的 by_epoch 参数不必相同。

这里再举一个例子:

param_scheduler = [
    # 在 [0, 100) 迭代时使用线性学习率
    dict(type='LinearLR',
         start_factor=0.001,
         by_epoch=False,
         begin=0,
         end=100),
    # 在 [100, 900) 迭代时使用余弦学习率
    dict(type='CosineAnnealingLR',
         T_max=800,
         by_epoch=False,
         begin=100,
         end=900)
]

image

上述例子表示在训练的前 100 次迭代时使用线性的学习率预热,然后在第 100 到第 900 次迭代时使用周期为 800 的余弦退火学习率调度器使学习率按照余弦函数逐渐下降为 0 。

我们可以组合任意多个调度器,既可以使用 MMEngine 中已经支持的调度器,也可以实现自定义的调度器。 如果相邻两个调度器的生效区间没有紧邻,而是有一段区间没有被覆盖,那么这段区间的学习率维持不变。而如果两个调度器的生效区间发生了重叠,则对多组调度器叠加使用,学习率的调整会按照调度器配置文件中的顺序触发(行为与 PyTorch 中 ChainedScheduler 一致)。 在一般情况下,我们推荐用户在训练的不同阶段使用不同的学习率调度策略来避免调度器的生效区间发生重叠。如果确实需要将两个调度器叠加使用,则需要十分小心,避免学习率的调整与预期不符。

如何调整其他参数

动量

和学习率一样, 动量也是优化器参数组中一组可以调度的参数。 动量调度器(momentum scheduler)的使用方法和学习率调度器完全一样。同样也只需要将动量调度器的配置添加进配置文件中的 param_scheduler 字段的列表中即可。

示例:

param_scheduler = [
    # the lr scheduler
    dict(type='LinearLR', ...),
    # 动量调度器
    dict(type='LinearMomentum',
         start_factor=0.001,
         by_epoch=False,
         begin=0,
         end=1000)
]

通用的参数调度器

MMEngine 还提供了一组通用的参数调度器用于调度优化器的 param_groups 中的其他参数,将学习率调度器类名中的 LR 改为 Param 即可,例如 LinearParamScheduler。用户可以通过设置参数调度器的 param_name 变量来选择想要调度的参数。

下面是一个通过自定义参数名来调度的例子:

param_scheduler = [
    dict(type='LinearParamScheduler',
         param_name='lr',  # 调度 `optimizer.param_groups` 中名为 'lr' 的变量
         start_factor=0.001,
         by_epoch=False,
         begin=0,
         end=1000)
]

这里设置的参数名是 lr,因此这个调度器的作用等同于直接使用学习率调度器 LinearLRScheduler

除了动量之外,用户也可以对 optimizer.param_groups 中的其他参数名进行调度,可调度的参数取决于所使用的优化器。例如,当使用带 weight_decay 的 SGD 优化器时,可以按照以下示例对调整 weight_decay

param_scheduler = [
    dict(type='LinearParamScheduler',
         param_name='weight_decay',  # 调度 `optimizer.param_groups` 中名为 'weight_decay' 的变量
         start_factor=0.001,
         by_epoch=False,
         begin=0,
         end=1000)
]

钩子(Hook)

钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。

内置钩子

MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默认钩子以及自定义钩子,前者表示会默认往执行器注册,后者表示需要用户自己注册。

每个钩子都有对应的优先级,在同一位点,钩子的优先级越高,越早被执行器调用,如果优先级一样,被调用的顺序和钩子注册的顺序一致。优先级列表如下:

  • HIGHEST (0)

  • VERY_HIGH (10)

  • HIGH (30)

  • ABOVE_NORMAL (40)

  • NORMAL (50)

  • BELOW_NORMAL (60)

  • LOW (70)

  • VERY_LOW (90)

  • LOWEST (100)

默认钩子

名称

用途

优先级

RuntimeInfoHook

往 message hub 更新运行时信息

VERY_HIGH (10)

IterTimerHook

统计迭代耗时

NORMAL (50)

DistSamplerSeedHook

确保分布式 Sampler 的 shuffle 生效

NORMAL (50)

LoggerHook

打印日志

BELOW_NORMAL (60)

ParamSchedulerHook

调用 ParamScheduler 的 step 方法

LOW (70)

CheckpointHook

按指定间隔保存权重

VERY_LOW (90)

自定义钩子

名称

用途

优先级

EMAHook

模型参数指数滑动平均

NORMAL (50)

EmptyCacheHook

PyTorch CUDA 缓存清理

NORMAL (50)

SyncBuffersHook

同步模型的 buffer

NORMAL (50)

注解

不建议修改默认钩子的优先级,因为优先级低的钩子可能会依赖优先级高的钩子。例如 CheckpointHook 的优先级需要比 ParamSchedulerHook 低,这样保存的优化器状态才是正确的状态。另外,自定义钩子的优先级默认为 NORMAL (50)

两种钩子在执行器中的设置不同,默认钩子的配置传给执行器的 default_hooks 参数,自定义钩子的配置传给 custom_hooks 参数,如下所示:

from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)

custom_hooks = [dict(type='EmptyCacheHook')]

runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()

下面逐一介绍 MMEngine 中内置钩子的用法。

CheckpointHook

CheckpointHook 按照给定间隔保存模型的权重,如果是分布式多卡训练,则只有主(master)进程会保存权重。CheckpointHook 的主要功能如下:

  • 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重

  • 保存最新的多个权重

  • 保存最优权重

  • 指定保存权重的路径

如需了解其他功能,请阅读 CheckpointHook API 文档

下面介绍上面提到的 4 个功能。

  • 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重

    假设我们一共训练 20 个 epoch 并希望每隔 5 个 epoch 保存一次权重,下面的配置即可帮我们实现该需求。

    # by_epoch 的默认值为 True
    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))
    

    如果想以迭代次数作为保存间隔,则可以将 by_epoch 设为 False,interval=5 则表示每迭代 5 次保存一次权重。

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
    
  • 保存最新的多个权重

    如果只想保存一定数量的权重,可以通过设置 max_keep_ckpts 参数实现最多保存 max_keep_ckpts 个权重,当保存的权重数超过 max_keep_ckpts 时,前面的权重会被删除。

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))
    

    上述例子表示,假如一共训练 20 个 epoch,那么会在第 5, 10, 15, 20 个 epoch 保存模型,但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。

  • 保存最优权重

    如果想要保存训练过程验证集的最优权重,可以设置 save_best 参数,如果设置为 'auto',则会根据验证集的第一个评价指标(验证集返回的评价指标是一个有序字典)判断当前权重是否最优。

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))
    

    也可以直接指定 save_best 的值为评价指标,例如在分类任务中,可以指定为 save_best='top-1',则会根据 'top-1' 的值判断当前权重是否最优。

    除了 save_best 参数,和保存最优权重相关的参数还有 rulegreater_keysless_keys,这三者用来判断 save_best 的值是越大越好还是越小越好。例如指定了 save_best='top-1',可以指定 rule='greater',则表示该值越大表示权重越好。

  • 指定保存权重的路径

    权重默认保存在工作目录(work_dir),但可以通过设置 out_dir 改变保存路径。

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
    

LoggerHook

LoggerHook 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。

如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

如果你对日志的管理感兴趣,可以阅读记录日志(logging)

ParamSchedulerHook

ParamSchedulerHook 遍历执行器的所有优化器参数调整策略(Parameter Scheduler)并逐个调用 step 方法更新优化器的参数。如需了解优化器参数调整策略的用法请阅读文档ParamSchedulerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

IterTimerHook

IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

DistSamplerSeedHook

DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

RuntimeInfoHook

RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

EMAHook

EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverageMomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

更多用法请阅读 EMAHook API 文档

EmptyCacheHook

EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,第一个参数表示在每个 epoch 开始之前,第二参数表示在每次迭代之后,第三个参数表示在每个 epoch 之后。

# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

SyncBuffersHook

SyncBuffersHook 在分布式训练每一轮(epoch)结束时同步模型的 buffer,例如 BN 层的 running_mean 以及 running_var

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

自定义钩子

如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。

例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 after_train_iter 位点。

import torch

from mmengine.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')

我们只需将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子,

from mmengine.runner import Runner

custom_hooks = dict(
    dict(type='CheckInvalidLossHook', interval=50)
)
runner = Runner(custom_hooks=custom_hooks, ...)  # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train()  # 执行器开始训练

便会在每次模型前向计算后检查损失值。

注意,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段。

custom_hooks = dict(
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
)

也可以在定义类时给定优先级

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):

    priority = 'ABOVE_NORMAL'

你可能还想阅读钩子的设计或者钩子的 API 文档

注册器(Registry)

OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 ResNetSEResNet 类,这些类有相似的功能和接口,都属于算法库中的模型组件。为了管理这些功能相似的模块,MMEngine 实现了 注册器。OpenMMLab 大多数算法库均使用注册器来管理它们的代码模块,包括 MMDetectionMMDetection3DMMClassificationMMEditing 等。

什么是注册器

MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 "ResNet"ResNet 类或函数的映射,使得用户可以通过 "ResNet" 找到 ResNet 类;而模块构建方法则定义了如何根据字符串查找到对应的类或函数以及如何实例化这个类或者调用这个函数,例如,通过字符串 "bn" 找到 nn.BatchNorm2d 并实例化 BatchNorm2d 模块;又或者通过字符串 "build_batchnorm2d" 找到 build_batchnorm2d 函数并返回该函数的调用结果。MMEngine 中的注册器默认使用 build_from_cfg 函数来查找并实例化字符串对应的类或者函数。

一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 MODELS 可以被视作所有模型的抽象,管理了 ResNetSEResNetRegNetX 等分类网络的类以及 build_ResNet, build_SEResNetbuild_RegNetX 等分类网络的构建函数。

入门用法

使用注册器管理代码库中的模块,需要以下三个步骤。

  1. 创建注册器

  2. 创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)

  3. 将模块加入注册器中

假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。

首先创建注册器,

from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])

locations 指定的模块 mmengine.models.activations 对应了 mmengine/models/activations.py 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 mmengine/models/activations.py 文件中实现不同的激活函数,例如 SigmoidReLUSoftmax

import torch.nn as nn

# 使用注册器管理模块
@ACTIVATION.register_module()
class Sigmoid(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print('call Sigmoid.forward')
        return x

@ACTIVATION.register_module()
class ReLU(nn.Module):
    def __init__(self, inplace=False):
        super().__init__()

    def forward(self, x):
        print('call ReLU.forward')
        return x

@ACTIVATION.register_module()
class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print('call Softmax.forward')
        return x

使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 ACTIVATION 中。通过 @ACTIVATION.register_module() 装饰所实现的模块,字符串和类或函数之间的映射就可以由 ACTIVATION 构建和维护,我们也可以通过 ACTIVATION.register_module(module=ReLU) 实现同样的功能。

通过注册,我们就可以通过 ACTIVATION 建立字符串与类或函数之间的映射,

print(ACTIVATION.module_dict)
# {
#     'Sigmoid': __main__.Sigmoid,
#     'ReLU': __main__.ReLU,
#     'Softmax': __main__.Softmax
# }

注解

只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:

  1. locations 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 REGISTRY.build(cfg)

  2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。

  3. 在配置中使用 custom_imports 字段。 详情请参考导入自定义Python模块

模块成功注册后,我们可以通过配置文件使用这个激活模块。

import torch

input = torch.randn(2)

act_cfg = dict(type='Sigmoid')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call Sigmoid.forward
print(output)

如果我们想使用 ReLU,仅需修改配置。

act_cfg = dict(type='ReLU', inplace=True)
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call ReLU.forward
print(output)

如果我们希望在创建实例前检查输入参数的类型(或者任何其他操作),我们可以实现一个构建方法并将其传递给注册器从而实现自定义构建流程。

创建一个构建方法,


def build_activation(cfg, registry, *args, **kwargs):
    cfg_ = cfg.copy()
    act_type = cfg_.pop('type')
    print(f'build activation: {act_type}')
    act_cls = registry.get(act_type)
    act = act_cls(*args, **kwargs, **cfg_)
    return act

并将 build_activation 传递给 build_func 参数

ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])

@ACTIVATION.register_module()
class Tanh(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print('call Tanh.forward')
        return x

act_cfg = dict(type='Tanh')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# build activation: Tanh
# call Tanh.forward
print(output)

注解

在这个例子中,我们演示了如何使用参数 build_func 自定义构建类的实例的方法。 该功能类似于默认的 build_from_cfg 方法。在大多数情况下,使用默认的方法就可以了。

MMEngine 的注册器除了可以注册类,也可以注册函数。

FUNCTION = Registry('function', scope='mmengine')

@FUNCTION.register_module()
def print_args(**kwargs):
    print(kwargs)

func_cfg = dict(type='print_args', a=1, b=2)
func_res = FUNCTION.build(func_cfg)

进阶用法

MMEngine 的注册器支持层级注册,利用该功能可实现跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。

为了方便跨库调用,MMEngine 提供了 20 个根注册器:

  • RUNNERS: Runner 的注册器

  • RUNNER_CONSTRUCTORS: Runner 的构造器

  • LOOPS: 管理训练、验证以及测试流程,如 EpochBasedTrainLoop

  • HOOKS: 钩子,如 CheckpointHook, ParamSchedulerHook

  • DATASETS: 数据集

  • DATA_SAMPLERS: DataLoaderSampler,用于采样数据

  • TRANSFORMS: 各种数据预处理,如 Resize, Reshape

  • MODELS: 模型的各种模块

  • MODEL_WRAPPERS: 模型的包装器,如 MMDistributedDataParallel,用于对分布式数据并行

  • WEIGHT_INITIALIZERS: 权重初始化的工具

  • OPTIMIZERS: 注册了 PyTorch 中所有的 Optimizer 以及自定义的 Optimizer

  • OPTIM_WRAPPER: 对 Optimizer 相关操作的封装,如 OptimWrapperAmpOptimWrapper

  • OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper 的构造器

  • PARAM_SCHEDULERS: 各种参数调度器,如 MultiStepLR

  • METRICS: 用于计算模型精度的评估指标,如 Accuracy

  • EVALUATOR: 用于计算模型精度的一个或多个评估指标

  • TASK_UTILS: 任务强相关的一些组件,如 AnchorGenerator, BboxCoder

  • VISUALIZERS: 管理绘制模块,如 DetVisualizer 可在图片上绘制预测框

  • VISBACKENDS: 存储训练日志的后端,如 LocalVisBackend, TensorboardVisBackend

  • LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 LogProcessor,如有特殊需求可自定义 LogProcessor

调用父节点的模块

MMEngine 中定义模块 RReLU,并往 MODELS 根注册器注册。

import torch.nn as nn
from mmengine import Registry, MODELS

@MODELS.register_module()
class RReLU(nn.Module):
    def __init__(self, lower=0.125, upper=0.333, inplace=False):
        super().__init__()

    def forward(self, x):
        print('call RReLU.forward')
        return x

假设有个项目叫 MMAlpha,它也定义了 MODELS,并设置其父节点为 MMEngineMODELS,这样就建立了层级结构。

from mmengine import Registry, MODELS as MMENGINE_MODELS

MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])

下图是 MMEngineMMAlpha 的注册器层级结构。

可以调用 count_registered_modules 函数打印已注册到 MMEngine 的模块以及层级结构。

from mmengine.registry import count_registered_modules

count_registered_modules()

MMAlpha 中定义模块 LogSoftmax,并往 MMAlphaMODELS 注册。

@MODELS.register_module()
class LogSoftmax(nn.Module):
    def __init__(self, dim=None):
        super().__init__()

    def forward(self, x):
        print('call LogSoftmax.forward')
        return x

MMAlpha 中使用配置调用 LogSoftmax

model = MODELS.build(cfg=dict(type='LogSoftmax'))

也可以在 MMAlpha 中调用父节点 MMEngine 的模块。

model = MODELS.build(cfg=dict(type='RReLU', lower=0.2))
# 也可以加 scope
model = MODELS.build(cfg=dict(type='mmengine.RReLU'))

如果不加前缀,build 方法首先查找当前节点是否存在该模块,如果存在则返回该模块,否则会继续向上查找父节点甚至祖先节点直到找到该模块,因此,如果当前节点和父节点存在同一模块并且希望调用父节点的模块,我们需要指定 scope 前缀。

import torch

input = torch.randn(2)
output = model(input)
# call RReLU.forward
print(output)

调用兄弟节点的模块

除了可以调用父节点的模块,也可以调用兄弟节点的模块。

假设有另一个项目叫 MMBeta,它和 MMAlpha 一样,定义了 MODELS 以及设置其父节点为 MMEngineMODELS

from mmengine import Registry, MODELS as MMENGINE_MODELS

MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmbeta')

下图是 MMEngine,MMAlpha 和 MMBeta 的注册器层级结构。

MMBeta 中调用兄弟节点 MMAlpha 的模块,

model = MODELS.build(cfg=dict(type='mmalpha.LogSoftmax'))
output = model(input)
# call LogSoftmax.forward
print(output)

调用兄弟节点的模块需要在 type 中指定 scope 前缀,所以上面的配置需要加前缀 mmalpha

如果需要调用兄弟节点的数个模块,每个模块都加前缀,这需要做大量的修改。于是 MMEngine 引入了 DefaultScopeRegistry 借助它可以很方便地支持临时切换当前节点为指定的节点。

如果需要临时切换当前节点为指定的节点,只需在 cfg 设置 _scope_ 为指定节点的作用域。

model = MODELS.build(cfg=dict(type='LogSoftmax', _scope_='mmalpha'))
output = model(input)
# call LogSoftmax.forward
print(output)

配置(Config)

MMEngine 实现了抽象的配置类(Config),为用户提供统一的配置访问接口。配置类能够支持不同格式的配置文件,包括 pythonjsonyaml,用户可以根据需求选择自己偏好的格式。配置类提供了类似字典或者 Python 对象属性的访问接口,用户可以十分自然地进行配置字段的读取和修改。为了方便算法框架管理配置文件,配置类也实现了一些特性,例如配置文件的字段继承等。

在开始教程之前,我们先将教程中需要用到的配置文件下载到本地(建议在临时目录下执行,方便后续删除示例配置文件):

wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/config_sgd.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/cross_repo.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/custom_imports.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/demo_train.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/example.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/learn_read_config.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/my_module.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/optimizer_cfg.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/predefined_var.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/refer_base_var.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/resnet50_delete_key.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/resnet50_lr0.01.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/resnet50_runtime.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/resnet50.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/runtime_cfg.py
wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/modify_base_var.py

配置文件读取

配置类提供了统一的接口 Config.fromfile(),来读取和解析配置文件。

合法的配置文件应该定义一系列键值对,这里举几个不同格式配置文件的例子。

Python 格式:

test_int = 1
test_list = [1, 2, 3]
test_dict = dict(key1='value1', key2=0.1)

Json 格式:

{
  "test_int": 1,
  "test_list": [1, 2, 3],
  "test_dict": {"key1": "value1", "key2": 0.1}
}

YAML 格式:

test_int: 1
test_list: [1, 2, 3]
test_dict:
  key1: "value1"
  key2: 0.1

对于以上三种格式的文件,假设文件名分别为 config.pyconfig.jsonconfig.yml,调用 Config.fromfile('config.xxx') 接口加载这三个文件都会得到相同的结果,构造了包含 3 个字段的配置对象。我们以 config.py 为例,我们先将示例配置文件下载到本地:

然后通过配置类的 fromfile 接口读取配置文件:

from mmengine.config import Config

cfg = Config.fromfile('learn_read_config.py')
print(cfg)
Config (path: learn_read_config.py): {'test_int': 1, 'test_list': [1, 2, 3], 'test_dict': {'key1': 'value1', 'key2': 0.1}}

配置文件的使用

通过读取配置文件来初始化配置对象后,就可以像使用普通字典或者 Python 类一样来使用这个变量了。我们提供了两种访问接口,即类似字典的接口 cfg['key'] 或者类似 Python 对象属性的接口 cfg.key。这两种接口都支持读写。

print(cfg.test_int)
print(cfg.test_list)
print(cfg.test_dict)
cfg.test_int = 2

print(cfg['test_int'])
print(cfg['test_list'])
print(cfg['test_dict'])
cfg['test_list'][1] = 3
print(cfg['test_list'])
1
[1, 2, 3]
{'key1': 'value1', 'key2': 0.1}
2
[1, 2, 3]
{'key1': 'value1', 'key2': 0.1}
[1, 3, 3]

注意,配置文件中定义的嵌套字段(即类似字典的字段),在 Config 中会将其转化为 ConfigDict 类,该类继承了 Python 内置字典类型的全部接口,同时也支持以对象属性的方式访问数据。

在算法库中,可以将配置与注册器结合起来使用,达到通过配置文件来控制模块构造的目的。这里举一个在配置文件中定义优化器的例子。

假设我们已经定义了一个优化器的注册器 OPTIMIZERS,包括了各种优化器。那么首先写一个 config_sgd.py

optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)

然后在算法库中可以通过如下代码构造优化器对象。

from mmengine import Config, optim
from mmengine.registry import OPTIMIZERS

import torch.nn as nn

cfg = Config.fromfile('config_sgd.py')

model = nn.Conv2d(1, 1, 1)
cfg.optimizer.params = model.parameters()
optimizer = OPTIMIZERS.build(cfg.optimizer)
print(optimizer)
SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.1
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)

配置文件的继承

有时候,两个不同的配置文件之间的差异很小,可能仅仅只改了一个字段,我们就需要将所有内容复制粘贴一次,而且在后续观察的时候,不容易定位到具体差异的字段。又有些情况下,多个配置文件可能都有相同的一批字段,我们不得不在这些配置文件中进行复制粘贴,给后续的修改和维护带来了不便。

为了解决这些问题,我们给配置文件增加了继承的机制,即一个配置文件 A 可以将另一个配置文件 B 作为自己的基础,直接继承了 B 中所有字段,而不必显式复制粘贴。

继承机制概述

这里我们举一个例子来说明继承机制。定义如下两个配置文件,

optimizer_cfg.py

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

resnet50.py

_base_ = ['optimizer_cfg.py']
model = dict(type='ResNet', depth=50)

虽然我们在 resnet50.py 中没有定义 optimizer 字段,但由于我们写了 _base_ = ['optimizer_cfg.py'],会使这个配置文件获得 optimizer_cfg.py 中的所有字段。

cfg = Config.fromfile('resnet50.py')
print(cfg.optimizer)
{'type': 'SGD', 'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}

这里 _base_ 是配置文件的保留字段,指定了该配置文件的继承来源。支持继承多个文件,将同时获得这多个文件中的所有字段,但是要求继承的多个文件中没有相同名称的字段,否则会报错。

runtime_cfg.py

gpu_ids = [0, 1]

resnet50_runtime.py

_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)

这时,读取配置文件 resnet50_runtime.py 会获得 3 个字段 modeloptimizergpu_ids

cfg = Config.fromfile('resnet50_runtime.py')
print(cfg.optimizer)
{'type': 'SGD', 'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}

通过这种方式,我们可以将配置文件进行拆分,定义一些通用配置文件,在实际配置文件中继承各种通用配置文件,可以减少具体任务的配置流程。

修改继承字段

有时候,我们继承一个配置文件之后,可能需要对其中个别字段进行修改,例如继承了 optimizer_cfg.py 之后,想将学习率从 0.02 修改为 0.01。

这时候,只需要在新的配置文件中,重新定义一下需要修改的字段即可。注意由于 optimizer 这个字段是一个字典,我们只需要重新定义这个字典里面需修改的下级字段即可。这个规则也适用于增加一些下级字段。

resnet50_lr0.01.py

_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
optimizer = dict(lr=0.01)

读取这个配置文件之后,就可以得到期望的结果。

cfg = Config.fromfile('resnet50_lr0.01.py')
print(cfg.optimizer)
{'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0001}

对于非字典类型的字段,例如整数,字符串,列表等,重新定义即可完全覆盖,例如下面的写法就将 gpu_ids 这个字段的值修改成了 [0]

_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
gpu_ids = [0]

删除字典中的 key

有时候我们对于继承过来的字典类型字段,不仅仅是想修改其中某些 key,可能还需要删除其中的一些 key。这时候在重新定义这个字典时,需要指定 _delete_=True,表示将没有在新定义的字典中出现的 key 全部删除。

resnet50_delete_key.py

_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
optimizer = dict(_delete_=True, type='SGD', lr=0.01)

这时候,optimizer 这个字典中就只有 typelr 这两个 key,momentumweight_decay 将不再被继承。

cfg = Config.fromfile('resnet50_delete_key.py')
print(cfg.optimizer)
{'type': 'SGD', 'lr': 0.01}

引用被继承文件中的变量

有时我们想重复利用 _base_ 中定义的字段内容,就可以通过 {{_base_.xxxx}} 获取来获取对应变量的拷贝。例如:

refer_base_var.py

_base_ = ['resnet50.py']
a = {{_base_.model}}

解析后发现,a 的值变成了 resnet50.py 中定义的 model

cfg = Config.fromfile('refer_base_var.py')
print(cfg.a)
{'type': 'ResNet', 'depth': 50}

我们可以在 jsonyamlpython 三种类型的配置文件中,使用这种方式来获取 _base_ 中定义的变量。

尽管这种获取 _base_ 中定义变量的方式非常通用,但是在语法上存在一些限制,无法充分利用 python 类配置文件的动态特性。比如我们想在 python 类配置文件中,修改 _base_ 中定义的变量:

_base_ = ['resnet50.py']
a = {{_base_.model}}
a['type'] = 'MobileNet'

配置类是无法解析这样的配置文件的(解析时报错)。配置类提供了一种更 pythonic 的方式,让我们能够在 python 类配置文件中修改 _base_ 中定义的变量(python 类配置文件专属特性,目前不支持在 jsonyaml 配置文件中修改 _base_ 中定义的变量)。

modify_base_var.py

_base_ = ['resnet50.py']
a = _base_.model
a.type = 'MobileNet'
cfg = Config.fromfile('modify_base_var.py')
print(cfg.a)
{'type': 'MobileNet', 'depth': 50}

解析后发现,a 的 type 变成了 MobileNet

配置文件的导出

在启动训练脚本时,用户可能通过传参的方式来修改配置文件的部分字段,为此我们提供了 dump 接口来导出更改后的配置文件。与读取配置文件类似,用户可以通过 cfg.dump('config.xxx') 来选择导出文件的格式。dump 同样可以导出有继承关系的配置文件,导出的文件可以被独立使用,不再依赖于 _base_ 中定义的文件。

基于继承一节定义的 resnet50.py,我们将其加载后导出:

cfg = Config.fromfile('resnet50.py')
cfg.dump('resnet50_dump.py')

resnet50_dump.py

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
model = dict(type='ResNet', depth=50)

类似的,我们可以导出 json、yaml 格式的配置文件

resnet50_dump.yaml

model:
  depth: 50
  type: ResNet
optimizer:
  lr: 0.02
  momentum: 0.9
  type: SGD
  weight_decay: 0.0001

resnet50_dump.json

{"optimizer": {"type": "SGD", "lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001}, "model": {"type": "ResNet", "depth": 50}}

此外,dump 不仅能导出加载自文件的 cfg,还能导出加载自字典的 cfg

cfg = Config(dict(a=1, b=2))
cfg.dump('dump_dict.py')

dump_dict.py

a=1
b=2

其他进阶用法

这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。

预定义字段

有时候我们希望配置文件中的一些字段和当前路径或者文件名等相关,这里举一个典型使用场景的例子。在训练模型时,我们会在配置文件中定义一个工作目录,存放这组实验配置的模型和日志,那么对于不同的配置文件,我们期望定义不同的工作目录。用户的一种常见选择是,直接使用配置文件名作为工作目录名的一部分,例如对于配置文件 predefined_var.py,工作目录就是 ./work_dir/predefined_var

使用预定义字段可以方便地实现这种需求,在配置文件 predefined_var.py 中可以这样写:

work_dir = './work_dir/{{fileBasenameNoExtension}}'

这里 {{fileBasenameNoExtension}} 表示该配置文件的文件名(不含拓展名),在配置类读取配置文件的时候,会将这种用双花括号包起来的字符串自动解析为对应的实际值。

cfg = Config.fromfile('./predefined_var.py')
print(cfg.work_dir)
./work_dir/predefined_var

目前支持的预定义字段有以下四种,变量名参考自 VS Code 中的相关字段:

  • {{fileDirname}} - 当前文件的目录名,例如 /home/your-username/your-project/folder

  • {{fileBasename}} - 当前文件的文件名,例如 file.py

  • {{fileBasenameNoExtension}} - 当前文件不包含扩展名的文件名,例如 file

  • {{fileExtname}} - 当前文件的扩展名,例如 .py

命令行修改配置

有时候我们只希望修改部分配置,而不想修改配置文件本身,例如实验过程中想更换学习率,但是又不想重新写一个配置文件,常用的做法是在命令行传入参数来覆盖相关配置。考虑到我们想修改的配置通常是一些内层参数,如优化器的学习率、模型卷积层的通道数等,因此 MMEngine 提供了一套标准的流程,让我们能够在命令行里轻松修改配置文件中任意层级的参数。

  1. 使用 argparse 解析脚本运行的参数

  2. 使用 argparse.ArgumentParser.add_argument 方法时,让 action 参数的值为 DictAction,用它来进一步解析命令行参数中用于修改配置文件的参数

  3. 使用配置类的 merge_from_dict 方法来更新配置

启动脚本示例如下:

demo_train.py

import argparse

from mmengine.config import Config, DictAction


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    print(cfg)


if __name__ == '__main__':
    main()

示例配置文件如下:

example.py

model = dict(type='CustomModel', in_channels=[1, 2, 3])
optimizer = dict(type='SGD', lr=0.01)

我们在命令行里通过 . 的方式来访问配置文件中的深层配置,例如我们想修改学习率,只需要在命令行执行:

python demo_train.py ./example.py --cfg-options optimizer.lr=0.1
Config (path: ./example.py): {'model': {'type': 'CustomModel', 'in_channels': [1, 2, 3]}, 'optimizer': {'type': 'SGD', 'lr': 0.1}}

我们成功地把学习率从 0.01 修改成 0.1。如果想改变列表、元组类型的配置,如上例中的 in_channels,则需要在命令行赋值时给 ()[] 外加上双引号:

python demo_train.py ./example.py --cfg-options model.in_channels="[1, 1, 1]"
Config (path: ./example.py): {'model': {'type': 'CustomModel', 'in_channels': [1, 1, 1]}, 'optimizer': {'type': 'SGD', 'lr': 0.01}}

model.in_channels 已经从 [1, 2, 3] 修改成 [1, 1, 1]。

注解

上述流程只支持在命令行里修改字符串、整型、浮点型、布尔型、None、列表、元组类型的配置项。对于列表、元组类型的配置,里面每个元素的类型也必须为上述七种类型之一。

注解

DictAction 的行为与 "extend" 相似,支持多次传递,并保存在同一个列表中。如

python demo_train.py ./example.py --cfg-options optimizer.type="Adam" --cfg-options model.in_channels="[1, 1, 1]"
Config (path: ./example.py): {'model': {'type': 'CustomModel', 'in_channels': [1, 1, 1]}, 'optimizer': {'type': 'Adam', 'lr': 0.01}}

导入自定义 Python 模块

将配置与注册器结合起来使用时,如果我们往注册器中注册了一些自定义的类,就可能会遇到一些问题。因为读取配置文件的时候,这部分代码可能还没有被执行到,所以并未完成注册过程,从而导致构建自定义类的时候报错。

例如我们新实现了一种优化器 CustomOptim,相应代码在 my_module.py 中。

from mmengine.registry import OPTIMIZERS

@OPTIMIZERS.register_module()
class CustomOptim:
    pass

我们为这个优化器的使用写了一个新的配置文件 custom_imports.py

optimizer = dict(type='CustomOptim')

那么就需要在读取配置文件和构造优化器之前,增加一行 import my_module 来保证将自定义的类 CustomOptim 注册到 OPTIMIZERS 注册器中:为了解决这个问题,我们给配置文件定义了一个保留字段 custom_imports,用于将需要提前导入的 Python 模块,直接写在配置文件中。对于上述例子,就可以将配置文件写成如下:

custom_imports.py

custom_imports = dict(imports=['my_module'], allow_failed_imports=False)
optimizer = dict(type='CustomOptim')

这样我们就不用在训练代码中增加对应的 import 语句,只需要修改配置文件就可以实现非侵入式导入自定义注册模块。

cfg = Config.fromfile('custom_imports.py')

from mmengine.registry import OPTIMIZERS

custom_optim = OPTIMIZERS.build(cfg.optimizer)
print(custom_optim)
<my_module.CustomOptim object at 0x7f6983a87970>

跨项目继承配置文件

为了避免基于已有算法库开发新项目时需要复制大量的配置文件,MMEngine 的配置类支持配置文件的跨项目继承。例如我们基于 MMDetection 开发新的算法库,需要使用以下 MMDetection 的配置文件:

configs/_base_/schedules/schedule_1x.py
configs/_base_/datasets.coco_instance.py
configs/_base_/default_runtime.py
configs/_base_/models/faster-rcnn_r50_fpn.py

如果没有配置文件跨项目继承的功能,我们就需要把 MMDetection 的配置文件拷贝到当前项目,而我们现在只需要安装 MMDetection(如使用 mim install mmdet),在新项目的配置文件中按照以下方式继承 MMDetection 的配置文件:

cross_repo.py

_base_ = [
    'mmdet::_base_/schedules/schedule_1x.py',
    'mmdet::_base_/datasets/coco_instance.py',
    'mmdet::_base_/default_runtime.py',
    'mmdet::_base_/models/faster-rcnn_r50_fpn.py',
]

我们可以像加载普通配置文件一样加载 cross_repo.py

cfg = Config.fromfile('cross_repo.py')
print(cfg.train_cfg)
{'type': 'EpochBasedTrainLoop', 'max_epochs': 12, 'val_interval': 1, '_scope_': 'mmdet'}

通过指定 mmdet::,Config 类会去检索 mmdet 包中的配置文件目录,并继承指定的配置文件。实际上,只要算法库的 setup.py 文件符合 MMEngine 安装规范,在正确安装算法库以后,新的项目就可以使用上述用法去继承已有算法库的配置文件而无需拷贝。

跨项目获取配置文件

MMEngine 还提供了 get_configget_model 两个接口,支持对符合 MMEngine 安装规范 的算法库中的模型和配置文件做索引并进行 API 调用。通过 get_model 接口可以获得构建好的模型。通过 get_config 接口可以获得配置文件。

get_model 的使用样例如下所示,使用和跨项目继承配置文件相同的语法,指定 mmdet::,即可在 mmdet 包中检索对应的配置文件并构建和初始化相应模型。用户可以通过指定 pretrained=True 获得已经加载预训练权重的模型以进行训练或者推理。

from mmengine.hub import get_model

model = get_model(
    'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
print(type(model))
http loads checkpoint from path: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
<class 'mmdet.models.detectors.faster_rcnn.FasterRCNN'>

get_config 的使用样例如下所示,使用和跨项目继承配置文件相同的语法,指定 mmdet::,即可实现去 mmdet 包中检索并加载对应的配置文件。用户可以基于这样得到的配置文件进行推理修改并自定义自己的算法模型。同时,如果用户指定 pretrained=True,得到的配置文件中会新增 model_path 字段,指定了对应模型预训练权重的路径。

from mmengine.hub import get_config

cfg = get_config(
    'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
print(cfg.model_path)

https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth

数据集基类(BaseDataset)

基本介绍

算法库中的数据集类负责在训练/测试过程中为模型提供输入数据,OpenMMLab 下各个算法库中的数据集有一些共同的特点和需求,比如需要高效的内部数据存储格式,需要支持数据集拼接、数据集重复采样等功能。

因此 MMEngine 实现了一个数据集基类(BaseDataset)并定义了一些基本接口,且基于这套接口实现了一些数据集包装(DatasetWrapper)。OpenMMLab 算法库中的大部分数据集都会满足这套数据集基类定义的接口,并使用统一的数据集包装。

数据集基类的基本功能是加载数据集信息,这里我们将数据集信息分成两类,一种是元信息 (meta information),代表数据集自身相关的信息,有时需要被模型或其他外部组件获取,比如在图像分类任务中,数据集的元信息一般包含类别信息 classes,因为分类模型 model 一般需要记录数据集的类别信息;另一种为数据信息 (data information),在数据信息中,定义了具体样本的文件路径、对应标签等的信息。除此之外,数据集基类的另一个功能为不断地将数据送入数据流水线(data pipeline)中,进行数据预处理。

数据标注文件规范

为了统一不同任务的数据集接口,便于多任务的算法模型训练,OpenMMLab 制定了 OpenMMLab 2.0 数据集格式规范, 数据集标注文件需符合该规范,数据集基类基于该规范去读取与解析数据标注文件。如果用户提供的数据标注文件不符合规定格式,用户可以选择将其转化为规定格式,并使用 OpenMMLab 的算法库基于该数据标注文件进行算法训练和测试。

OpenMMLab 2.0 数据集格式规范规定,标注文件必须为 jsonyamlymlpicklepkl 格式;标注文件中存储的字典必须包含 metainfodata_list 两个字段。其中 metainfo 是一个字典,里面包含数据集的元信息;data_list 是一个列表,列表中每个元素是一个字典,该字典定义了一个原始数据(raw data),每个原始数据包含一个或若干个训练/测试样本。

以下是一个 JSON 标注文件的例子(该例子中每个原始数据只包含一个训练/测试样本):


{
    'metainfo':
        {
            'classes': ('cat', 'dog'),
            ...
        },
    'data_list':
        [
            {
                'img_path': "xxx/xxx_0.jpg",
                'img_label': 0,
                ...
            },
            {
                'img_path': "xxx/xxx_1.jpg",
                'img_label': 1,
                ...
            },
            ...
        ]
}

同时假设数据存放路径如下:

data
├── annotations
│   ├── train.json
├── train
│   ├── xxx/xxx_0.jpg
│   ├── xxx/xxx_1.jpg
│   ├── ...

数据集基类的初始化流程

数据集基类的初始化流程如下图所示:

  1. load metainfo:获取数据集的元信息,元信息有三种来源,优先级从高到低为:

  • __init__() 方法中用户传入的 metainfo 字典;改动频率最高,因为用户可以在实例化数据集时,传入该参数;

  • 类属性 BaseDataset.METAINFO 字典;改动频率中等,因为用户可以改动自定义数据集类中的类属性 BaseDataset.METAINFO

  • 标注文件中包含的 metainfo 字典;改动频率最低,因为标注文件一般不做改动。

    如果三种来源中有相同的字段,优先级最高的来源决定该字段的值,这些字段的优先级比较是:用户传入的 metainfo 字典里的字段 > BaseDataset.METAINFO 字典里的字段 > 标注文件中 metainfo 字典里的字段。

  1. join path:处理数据与标注文件的路径;

  2. build pipeline:构建数据流水线(data pipeline),用于数据预处理与数据准备;

  3. full init:完全初始化数据集类,该步骤主要包含以下操作:

  • load data list:读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会调用 parse_data_info() 方法,该方法负责解析标注文件里的每个原始数据;

  • filter data (可选):根据 filter_cfg 过滤无用数据,比如不包含标注的样本等;默认不做过滤操作,下游子类可以按自身所需对其进行重写;

  • get subset (可选):根据给定的索引或整数值采样数据,比如只取前 10 个样本参与训练/测试;默认不采样数据,即使用全部数据样本;

  • serialize data (可选):序列化全部样本,以达到节省内存的效果,详情请参考节省内存;默认操作为序列化全部样本。

数据集基类中包含的 parse_data_info() 方法用于将标注文件里的一个原始数据处理成一个或若干个训练/测试样本的方法。因此对于自定义数据集类,用户需要实现 parse_data_info() 方法。

数据集基类提供的接口

torch.utils.data.Dataset 类似,数据集初始化后,支持 __getitem__ 方法,用来索引数据,以及 __len__ 操作获取数据集大小,除此之外,OpenMMLab 的数据集基类主要提供了以下接口来访问具体信息:

  • metainfo:返回元信息,返回值为字典

  • get_data_info(idx):返回指定 idx 的样本全量信息,返回值为字典

  • __getitem__(idx):返回指定 idx 的样本经过 pipeline 之后的结果(也就是送入模型的数据),返回值为字典

  • __len__():返回数据集长度,返回值为整数型

  • get_subset_(indices):根据 indices 以 inplace 的方式修改原数据集类。如果 indicesint,则原数据集类只包含前若干个数据样本;如果 indicesSequence[int],则原数据集类包含根据 Sequence[int] 指定的数据样本。

  • get_subset(indices):根据 indices inplace 的方式返回子数据集类,即重新复制一份子数据集。如果 indicesint,则返回的子数据集类只包含前若干个数据样本;如果 indicesSequence[int],则返回的子数据集类包含根据 Sequence[int] 指定的数据样本。

使用数据集基类自定义数据集类

在了解了数据集基类的初始化流程与提供的接口之后,就可以基于数据集基类自定义数据集类。

对于满足 OpenMMLab 2.0 数据集格式规范的标注文件

如上所述,对于满足 OpenMMLab 2.0 数据集格式规范的标注文件,用户可以重载 parse_data_info() 来加载标签。以下是一个使用数据集基类来实现某一具体数据集的例子。

import os.path as osp

from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):

    # 以上面标注文件为例,在这里 raw_data_info 代表 `data_list` 对应列表里的某个字典:
    # {
    #    'img_path': "xxx/xxx_0.jpg",
    #    'img_label': 0,
    #    ...
    # }
    def parse_data_info(self, raw_data_info):
        data_info = raw_data_info
        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            data_info['img_path'] = osp.join(
                img_prefix, data_info['img_path'])
        return data_info

使用自定义数据集类

在定义了数据集类后,就可以通过如下配置实例化 ToyDataset


class LoadImage:

    def __call__(self, results):
        results['img'] = cv2.imread(results['img_path'])
        return results

class ParseImage:

    def __call__(self, results):
        results['img_shape'] = results['img'].shape
        return results

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

同时可以使用数据集类提供的对外接口访问具体的样本信息:

toy_dataset.metainfo
# dict(classes=('cat', 'dog'))

toy_dataset.get_data_info(0)
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     ...
# }

len(toy_dataset)
# 2

toy_dataset[0]
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     'img': a ndarray with shape (H, W, 3), which denotes the value of the image,
#     'img_shape': (H, W, 3) ,
#     ...
# }

# `get_subset` 接口不对原数据集类做修改,即完全复制一份新的
sub_toy_dataset = toy_dataset.get_subset(1)
len(toy_dataset), len(sub_toy_dataset)
# 2, 1

# `get_subset_` 接口会对原数据集类做修改,即 inplace 的方式
toy_dataset.get_subset_(1)
len(toy_dataset)
# 1

经过以上步骤,可以了解基于数据集基类如何自定义新的数据集类,以及如何使用自定义数据集类。

自定义视频的数据集类

在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 parse_data_info() 的返回值为 list[dict] 即可:

from mmengine.dataset import BaseDataset


class ToyVideoDataset(BaseDataset):

    # raw_data_info 仍为一个字典,但它包含了多个样本
    def parse_data_info(self, raw_data_info):
        data_list = []

        ...

        for ... :

            data_info = dict()

            ...

            data_list.append(data_info)

        return data_list

ToyVideoDataset 使用方法与 ToyDataset 类似,在此不做赘述。

对于不满足 OpenMMLab 2.0 数据集格式规范的标注文件

对于不满足 OpenMMLab 2.0 数据集格式规范的标注文件,有两种方式来使用数据集基类:

  1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。

  2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 load_data_list(self): 函数,处理不满足规范的标注文件,并保证返回值为 list[dict],其中每个 dict 代表一个数据样本。

数据集基类的其它特性

数据集基类还包含以下特性:

懒加载(lazy init)

在数据集类实例化时,需要读取并解析标注文件,因此会消耗一定时间。然而在某些情况比如预测可视化时,往往只需要数据集类的元信息,可能并不需要读取与解析标注文件。为了节省这种情况下数据集类实例化的时间,数据集基类支持懒加载:

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline,
    # 在这里传入 lazy_init 变量
    lazy_init=True)

lazy_init=True 时,ToyDataset 的初始化方法只执行了数据集基类的初始化流程中的 1、2、3 步骤,此时 toy_dataset 并未被完全初始化,因为 toy_dataset 并不会读取与解析标注文件,只会设置数据集类的元信息(metainfo)。

自然的,如果之后需要访问具体的数据信息,可以手动调用 toy_dataset.full_init() 接口来执行完整的初始化过程,在这个过程中数据标注文件将被读取与解析。调用 get_data_info(idx), __len__(), __getitem__(idx)get_subset_(indices)get_subset(indices) 接口也会自动地调用 full_init() 接口来执行完整的初始化过程(仅在第一次调用时,之后调用不会重复地调用 full_init() 接口):

# 完整初始化
toy_dataset.full_init()

# 初始化完毕,现在可以访问具体数据
len(toy_dataset)
# 2
toy_dataset[0]
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     'img': a ndarray with shape (H, W, 3), which denotes the value the image,
#     'img_shape': (H, W, 3) ,
#     ...
# }

注意:

通过直接调用 __getitem__() 接口来执行完整初始化会带来一定风险:如果一个数据集类首先通过设置 lazy_init=True 未进行完全初始化,然后直接送入数据加载器(dataloader)中,在后续读取数据的过程中,不同的 worker 会同时读取与解析标注文件,虽然这样可能可以正常运行,但是会消耗大量的时间与内存。因此,建议在需要访问具体数据之前,提前手动调用 full_init() 接口来执行完整的初始化过程。

以上通过设置 lazy_init=True 未进行完全初始化,之后根据需求再进行完整初始化的方式,称为懒加载。

节省内存

在具体的读取数据过程中,数据加载器(dataloader)通常会起多个 worker 来预取数据,多个 worker 都拥有完整的数据集类备份,因此内存中会存在多份相同的 data_list,为了节省这部分内存消耗,数据集基类可以提前将 data_list 序列化存入内存中,使得多个 worker 可以共享同一份 data_list,以达到节省内存的目的。

数据集基类默认是将 data_list 序列化存入内存,也可以通过 serialize_data 变量(默认为 True)来控制是否提前将 data_list 序列化存入内存中:

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline,
    # 在这里传入 serialize_data 变量
    serialize_data=False)

上面例子不会提前将 data_list 序列化存入内存中,因此不建议在使用数据加载器开多个 worker 加载数据的情况下,使用这种方式实例化数据集类。

数据集基类包装

除了数据集基类,MMEngine 也提供了若干个数据集基类包装:ConcatDataset, RepeatDataset, ClassBalancedDataset。这些数据集基类包装同样也支持懒加载与拥有节省内存的特性。

ConcatDataset

MMEngine 提供了 ConcatDataset 包装来拼接多个数据集,使用方法如下:

from mmengine.dataset import ConcatDataset

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset_1 = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_2 = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='val/'),
    ann_file='annotations/val.json',
    pipeline=pipeline)

toy_dataset_12 = ConcatDataset(datasets=[toy_dataset_1, toy_dataset_2])

上述例子将数据集的 train 部分与 val 部分合成一个大的数据集。

RepeatDataset

MMEngine 提供了 RepeatDataset 包装来重复采样某个数据集若干次,使用方法如下:

from mmengine.dataset import RepeatDataset

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_repeat = RepeatDataset(dataset=toy_dataset, times=5)

上述例子将数据集的 train 部分重复采样了 5 次。

ClassBalancedDataset

MMEngine 提供了 ClassBalancedDataset 包装,来基于数据集中类别出现频率,重复采样相应样本。

注意:

ClassBalancedDataset 包装假设了被包装的数据集类支持 get_cat_ids(idx) 方法,get_cat_ids(idx) 方法返回一个列表,该列表包含了 idx 指定的 data_info 包含的样本类别,使用方法如下:

from mmengine.dataset import BaseDataset, ClassBalancedDataset

class ToyDataset(BaseDataset):

    def parse_data_info(self, raw_data_info):
        data_info = raw_data_info
        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            data_info['img_path'] = osp.join(
                img_prefix, data_info['img_path'])
        return data_info

    # 必须支持的方法,需要返回样本的类别
    def get_cat_ids(self, idx):
        data_info = self.get_data_info(idx)
        return [int(data_info['img_label'])]

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e-3)

上述例子将数据集的 train 部分以 oversample_thr=1e-3 重新采样,具体地,对于数据集中出现频率低于 1e-3 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 ClassBalancedDataset API 文档。

自定义数据集类包装

由于数据集基类实现了懒加载的功能,因此在自定义数据集类包装时,需要遵循一些规则,下面以一个例子的方式来展示如何自定义数据集类包装:

from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS


@DATASETS.register_module()
class ExampleDatasetWrapper:

    def __init__(self, dataset, lazy_init=False, ...):
        # 构建原数据集(self.dataset)
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        # 记录原数据集的元信息
        self._metainfo = self.dataset.metainfo

        '''
        1. 在这里实现一些代码,来记录用于包装数据集的一些超参。
        '''

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def full_init(self):
        if self._fully_initialized:
            return

        # 将原数据集完全初始化
        self.dataset.full_init()

        '''
        2. 在这里实现一些代码,来包装原数据集。
        '''

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int):

        '''
        3. 在这里实现一些代码,来将包装的索引 `idx` 映射到原数据集的索引 `ori_idx`。
        '''
        ori_idx = ...

        return ori_idx

    # 提供与 `self.dataset` 一样的对外接口。
    @force_full_init
    def get_data_info(self, idx):
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    # 提供与 `self.dataset` 一样的对外接口。
    def __getitem__(self, idx):
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset[sample_idx]

    # 提供与 `self.dataset` 一样的对外接口。
    @force_full_init
    def __len__(self):

        '''
        4. 在这里实现一些代码,来计算包装数据集之后的长度。
        '''
        len_wrapper = ...

        return len_wrapper

    # 提供与 `self.dataset` 一样的对外接口。
    @property
    def metainfo(self)
        return copy.deepcopy(self._metainfo)

数据变换 (Data Transform)

在 OpenMMLab 算法库中,数据集的构建和数据的准备是相互解耦的。通常,数据集的构建只对数据集进行解析,记录每个样本的基本信息;而数据的准备则是通过一系列的数据变换,根据样本的基本信息进行数据加载、预处理、格式化等操作。

使用数据变换类

在 MMEngine 中,我们使用各种可调用的数据变换类来进行数据的操作。这些数据变换类可以接受若干配置参数进行实例化,之后通过调用的方式对输入的数据字典进行处理。同时,我们约定所有数据变换都接受一个字典作为输入,并将处理后的数据输出为一个字典。一个简单的例子如下:

注解

MMEngine 中仅约定了数据变换类的规范,常用的数据变换类实现及基类都在 MMCV 中,因此在本篇教程需要提前安装好 MMCV,参见 MMCV 安装教程

>>> import numpy as np
>>> from mmcv.transforms import Resize
>>>
>>> transform = Resize(scale=(224, 224))
>>> data_dict = {'img': np.random.rand(256, 256, 3)}
>>> data_dict = transform(data_dict)
>>> print(data_dict['img'].shape)
(224, 224, 3)

在配置文件中使用

在配置文件中,我们将一系列数据变换组合成为一个列表,称为数据流水线(Data Pipeline),传给数据集的 pipeline 参数。通常数据流水线由以下几个部分组成:

  1. 数据加载,通常使用 LoadImageFromFile

  2. 标签加载,通常使用 LoadAnnotations

  3. 数据处理及增强,例如 RandomResize

  4. 数据格式化,根据任务不同,在各个仓库使用自己的变换操作,通常名为 PackXXXInputs,其中 XXX 是任务的名称,如分类任务中的 PackClsInputs

以分类任务为例,我们在下图展示了一个典型的数据流水线。对每个样本,数据集中保存的基本信息是一个如图中最左侧所示的字典,之后每经过一个由蓝色块代表的数据变换操作,数据字典中都会加入新的字段(标记为绿色)或更新现有的字段(标记为橙色)。

如果我们希望在测试中使用上述数据流水线,则配置文件如下所示:

test_dataloader = dict(
    batch_size=32,
    dataset=dict(
        type='ImageNet',
        data_root='data/imagenet',
        pipeline = [
            dict(type='LoadImageFromFile'),
            dict(type='Resize', size=256, keep_ratio=True),
            dict(type='CenterCrop', crop_size=224),
            dict(type='PackClsInputs'),
        ]
    )
)

常用的数据变换类

按照功能,常用的数据变换类可以大致分为数据加载、数据预处理与增强、数据格式化。我们在 MMCV 中提供了一系列常用的数据变换类:

数据加载

为了支持大规模数据集的加载,通常在数据集初始化时不加载数据,只加载相应的路径。因此需要在数据流水线中进行具体数据的加载。

数据变换类

功能

LoadImageFromFile

根据路径加载图像

LoadAnnotations

加载和组织标注信息,如 bbox、语义分割图等

数据预处理及增强

数据预处理和增强通常是对图像本身进行变换,如裁剪、填充、缩放等。

数据变换类

功能

Pad

填充图像边缘

CenterCrop

居中裁剪

Normalize

对图像进行归一化

Resize

按照指定尺寸或比例缩放图像

RandomResize

缩放图像至指定范围的随机尺寸

RandomChoiceResize

缩放图像至多个尺寸中的随机一个尺寸

RandomGrayscale

随机灰度化

RandomFlip

图像随机翻转

数据格式化

数据格式化操作通常是对数据进行的类型转换。

数据变换类

功能

ToTensor

将指定的数据转换为 torch.Tensor

ImageToTensor

将图像转换为 torch.Tensor

自定义数据变换类

要实现一个新的数据变换类,需要继承 BaseTransform,并实现 transform 方法。这里,我们使用一个简单的翻转变换(MyFlip)作为示例:

import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

从而,我们可以实例化一个 MyFlip 对象,并将之作为一个可调用对象,来处理我们的数据字典。

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

又或者,在配置文件的 pipeline 中使用 MyFlip 变换

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

需要注意的是,如需在配置文件中使用,需要保证 MyFlip 类所在的文件在运行时能够被导入。

初始化

基于 Pytorch 构建模型时,我们通常会选择 nn.Module 作为模型的基类,搭配使用 Pytorch 的初始化模块 torch.nn.init,完成模型的初始化。MMEngine 在此基础上抽象出基础模块(BaseModule),让我们能够通过传参或配置文件来选择模型的初始化方式。此外,MMEngine 还提供了一系列模块初始化函数,让我们能够更加方便灵活地初始化模型参数。

配置式初始化

为了能够更加灵活地初始化模型权重,MMEngine 抽象出了模块基类 BaseModule。模块基类继承自 nn.Module,在具备 nn.Module 基础功能的同时,还支持在构造时接受参数,以此来选择权重初始化方式。继承自 BaseModule 的模型可以在实例化阶段接受 init_cfg 参数,我们可以通过配置 init_cfg 为模型中任意组件灵活地选择初始化方式。目前我们可以在 init_cfg 中配置以下初始化器:

Initializer Registered name Function
ConstantInit Constant 将 weight 和 bias 初始化为指定常量,通常用于初始化卷积
XavierInit Xavier 将 weight Xavier 方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
NormalInit Normal 将 weight 以正态分布的方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
TruncNormalInit TruncNormal 将 weight 以被截断的正态分布的方式初始化,参数 a 和 b 为正态分布的有效区域;将 bias 初始化成指定常量,通常用于初始化 Transformer
UniformInit Uniform 将 weight 以均匀分布的方式初始化,参数 a 和 b 为均匀分布的范围;将 bias 初始化为指定常量,通常用于初始化卷积
KaimingInit Kaiming 将 weight 以 Kaiming 的方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
Caffe2XavierInit Caffe2Xavier Caffe2 中 Xavier 初始化方式,在 Pytorch 中对应 "fan_in", "normal" 模式的 Kaiming 初始化,,通常用于初始化卷
Pretrained PretrainedInit 加载预训练权重

我们通过几个例子来理解如何在 init_cfg 里配置初始化器,来选择模型的初始化方式。

使用预训练权重初始化

假设我们定义了模型类 ToyNet,它继承自模块基类(BaseModule)。此时我们可以在 ToyNet 初始化时传入 init_cfg 参数来选择模型的初始化方式,实例化后再调用 init_weights 方法,完成权重的初始化。以加载预训练权重为例:

import torch
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Linear(1, 1)


# 保存预训练权重
toy_net = ToyNet()
torch.save(toy_net.state_dict(), './pretrained.pth')
pretrained = './pretrained.pth'

# 配置加载预训练权重的初始化方式
toy_net = ToyNet(init_cfg=dict(type='Pretrained', checkpoint=pretrained))
# 加载权重
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO - load model from: ./pretrained.pth
08/19 16:50:24 - mmengine - INFO - local loads checkpoint from path: ./pretrained.pth

init_cfg 是一个字典时,type 字段就表示一种初始化器,它需要被注册到 WEIGHT_INITIALIZERS 注册器。我们可以通过指定 init_cfg=dict(type='Pretrained', checkpoint='path/to/ckpt') 来加载预训练权重,其中 PretrainedPretrainedInit 初始化器的缩写,这个映射名由 WEIGHT_INITIALIZERS 维护;checkpointPretrainedInit 的初始化参数,用于指定权重的加载路径,它可以是本地磁盘路径,也可以是 URL。

常用的初始化方式

和使用 PretrainedInit 初始化器类似,如果我们想对卷积做 Kaiming 初始化,需要令 init_cfg=dict(type='Kaiming', layer='Conv2d')。这样模型初始化时,就会以 Kaiming 初始化的方式来初始化类型为 Conv2d 的模块。

有时候我们可能需要用不同的初始化方式去初始化不同类型的模块,例如对卷积使用 Kaiming 初始化,对线性层使用 Xavier 初始化。此时我们可以使 init_cfg 成为一个列表,其中的每一个元素都表示对某些层使用特定的初始化方式。

import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(1, 1)
        self.conv = nn.Conv2d(1, 1, 1)


# 对卷积做 Kaiming 初始化,线性层做 Xavier 初始化
toy_net = ToyNet(
    init_cfg=[
        dict(type='Kaiming', layer='Conv2d'),
        dict(type='Xavier', layer='Linear')
    ], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

类似地,layer 参数也可以是一个列表,表示列表中的多种不同的 layer 均使用 type 指定的初始化方式

# 对卷积和线性层做 Kaiming 初始化
toy_net = ToyNet(init_cfg=[dict(type='Kaiming', layer=['Conv2d', 'Linear'])], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

更细粒度的初始化

有时同一类型的不同模块有不同初始化方式,例如现在有 conv1conv2 两个模块,他们的类型均为 Conv2d 。我们需要对 conv1 进行 Kaiming 初始化,conv2 进行 Xavier 初始化,则可以通过配置 override 参数来满足这样的需求:

import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)


# 对 conv1 做 Kaiming 初始化,对 从 conv2 做 Xavier 初始化
toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier')),
    ], )
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

override 可以理解成一个嵌套的 init_cfg, 他同样可以是 list 或者 dict,也需要通过 type 字段指定初始化方式。不同的是 override 必须指定 namename 相当于 override 的作用域,如上例中,override 的作用域为 toy_net.conv2,我们会以 Xavier 初始化方式初始化 toy_net.conv2 下的所有参数,而不会影响作用域以外的模块。

自定义的初始化方式

尽管 init_cfg 能够控制各个模块的初始化方式,但是在不扩展 WEIGHT_INITIALIZERS 的情况下,我们是无法初始化一些自定义模块的,例如表格中提到的大多数初始化器,都需要对应的模块有 weightbias 属性 。对于这种情况,我们建议让自定义模块实现 init_weights 方法。模型调用 init_weights 时,会链式地调用所有子模块的 init_weights

假设我们定义了以下模块:

  • 继承自 nn.ModuleToyConv,实现了 init_weights 方法,让 custom_weight 初始化为 1,custom_bias 初始化为 0

  • 继承自模块基类的模型 ToyNet,且含有 ToyConv 子模块

我们在调用 ToyNetinit_weights 方法时,会链式的调用的子模块 ToyConvinit_weights 方法,实现自定义模块的初始化。

import torch
import torch.nn as nn

from mmengine.model import BaseModule


class ToyConv(nn.Module):

    def __init__(self):
        super().__init__()
        self.custom_weight = nn.Parameter(torch.empty(1, 1, 1, 1))
        self.custom_bias = nn.Parameter(torch.empty(1))

    def init_weights(self):
        with torch.no_grad():
            self.custom_weight = self.custom_weight.fill_(1)
            self.custom_bias = self.custom_bias.fill_(0)


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.custom_conv = ToyConv()


toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier'))
    ])
# 链式调用 `ToyConv.init_weights()`,以自定义的方式初始化
toy_net.init_weights()
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_weight - torch.Size([1, 1, 1, 1]):
Initialized by user-defined `init_weights` in ToyConv

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_bias - torch.Size([1]):
Initialized by user-defined `init_weights` in ToyConv

小结

最后我们对 init_cfginit_weights 两种初始化方式做一些总结:

1. 配置 init_cfg 控制初始化

  • 通常用于初始化一些比较底层的模块,例如卷积、线性层等。如果想通过 init_cfg 配置自定义模块的初始化方式,需要将相应的初始化器注册到 WEIGHT_INITIALIZERS 里。

  • 动态初始化特性,初始化方式随 init_cfg 的值改变。

2. 实现 init_weights 方法

  • 通常用于初始化自定义模块。相比于 init_cfg 的自定义初始化,实现 init_weights 方法更加简单,无需注册。然而,它的灵活性不及 init_cfg,无法动态地指定任意模块的初始化方式。

注解

  • init_weights 的优先级比 init_cfg

  • 执行器会在 train() 函数中调用 init_weights。

函数式初始化

自定义的初始化方式一节提到,我们可以在 init_weights 里实现自定义的参数初始化逻辑。为了能够更加方便地实现参数初始化,MMEngine 在 torch.nn.init的基础上,提供了一系列模块初始化函数来初始化整个模块。例如我们对卷积层的权重(weight)进行正态分布初始化,卷积层的偏置(bias)进行常数初始化,基于 torch.nn.init 的实现如下:

from torch.nn.init import normal_, constant_
import torch.nn as nn

model = nn.Conv2d(1, 1, 1)
normal_(model.weight, mean=0, std=0.01)
constant_(model.bias, val=0)
Parameter containing:
tensor([0.], requires_grad=True)

上述流程实际上是卷积正态分布初始化的标准流程,因此 MMEngine 在此基础上做了进一步地简化,实现了一系列常用的模块初始化函数。相比 torch.nn.init,MMEngine 提供的初始化函数直接接受卷积模块,一行代码能实现同样的初始化逻辑:

from mmengine.model import normal_init

normal_init(model, mean=0, std=0.01, bias=0)

类似地,我们也可以用 Kaiming 初始化和 Xavier 初始化:

from mmengine.model import kaiming_init, xavier_init

kaiming_init(model)
xavier_init(model)

目前 MMEngine 提供了以下初始化函数:

初始化函数 功能
constant_init 将 weight 和 bias 初始化为指定常量,通常用于初始化卷积
xavier_init 将 weight 以 Xavier 方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
normal_init 将 weight 以正态分布的方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
trunc_normal_init 将 weight 以被截断的正态分布的方式初始化,参数 a 和 b 为正态分布的有效区域;将 bias 初始化成指定常量,通常用于初始化 Transformer
uniform_init 将 weight 以均匀分布的方式初始化,参数 a 和 b 为均匀分布的范围;将 bias 初始化为指定常量,通常用于初始化卷积
kaiming_init 将 weight 以 Kaiming 方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
caffe2_xavier_init Caffe2 中 Xavier 初始化方式,在 Pytorch 中对应 "fan_in", "normal" 模式的 Kaiming 初始化,通常用于初始化卷积
bias_init_with_prob 以概率值的形式初始化 bias

可视化

可视化可以给深度学习的模型训练和测试过程提供直观解释。

MMEngine 提供了 Visualizer 可视化器用以可视化和存储模型训练和测试过程中的状态以及中间结果,具备如下功能:

  • 支持基础绘图接口以及特征图可视化

  • 支持本地、TensorBoard 以及 WandB 等多种后端,可以将训练状态例如 loss 、lr 或者性能评估指标以及可视化的结果写入指定的单一或多个后端

  • 允许在代码库任意位置调用,对任意位置的特征、图像和状态等进行可视化和存储。

基础绘制接口

可视化器提供了常用对象的绘制接口,例如绘制检测框、点、文本、线、圆、多边形和二值掩码。这些基础 API 支持以下特性:

  • 可以多次调用,实现叠加绘制需求

  • 均支持多输入,除了要求文本输入的绘制接口外,其余接口同时支持 Tensor 以及 Numpy array 的输入

常见用法如下:

(1) 绘制检测框、掩码和文本等

import torch
import mmcv
from mmengine.visualization import Visualizer

image = mmcv.imread('docs/en/_static/image/cat_dog.png', channel_order='rgb')
visualizer = Visualizer(image=image)
# 绘制单个检测框, xyxy 格式
visualizer.draw_bboxes(torch.tensor([72, 13, 179, 147]))
# 绘制多个检测框
visualizer.draw_bboxes(torch.tensor([[33, 120, 209, 220], [72, 13, 179, 147]]))
visualizer.show()
visualizer.set_image(image=image)
visualizer.draw_texts("cat and dog", torch.tensor([10, 20]))
visualizer.show()

你也可以通过各个绘制接口中提供的参数来定制绘制对象的颜色和宽度等等

visualizer.set_image(image=image)
visualizer.draw_bboxes(torch.tensor([72, 13, 179, 147]), edge_colors='r', line_widths=3)
visualizer.draw_bboxes(torch.tensor([[33, 120, 209, 220]]),line_styles='--')
visualizer.show()

(2) 叠加显示

上述绘制接口可以多次调用,从而实现叠加显示需求

visualizer.set_image(image=image)
visualizer.draw_bboxes(torch.tensor([[33, 120, 209, 220], [72, 13, 179, 147]]))
visualizer.draw_texts("cat and dog",
                      torch.tensor([10, 20])).draw_circles(torch.tensor([40, 50]), torch.tensor([20]))
visualizer.show()

特征图绘制

特征图可视化功能较多,目前只支持单张特征图的可视化,为了方便理解,将其对外接口梳理如下:

@staticmethod
def draw_featmap(featmap: torch.Tensor, # 输入格式要求为 CHW
                 overlaid_image: Optional[np.ndarray] = None, # 如果同时输入了 image 数据,则特征图会叠加到 image 上绘制
                 channel_reduction: Optional[str] = 'squeeze_mean', # 多个通道压缩为单通道的策略
                 topk: int = 10, # 可选择激活度最高的 topk 个特征图显示
                 arrangement: Tuple[int, int] = (5, 2), # 多通道展开为多张图时候布局
                 resize_shape:Optional[tuple] = None, # 可以指定 resize_shape 参数来缩放特征图
                 alpha: float = 0.5) -> np.ndarray: # 图片和特征图绘制的叠加比例

其功能可以归纳如下

  • 输入的 Tensor 一般是包括多个通道的,channel_reduction 参数可以将多个通道压缩为单通道,然后和图片进行叠加显示

    • squeeze_mean 将输入的 C 维度采用 mean 函数压缩为一个通道,输出维度变成 (1, H, W)

    • select_max 从输入的 C 维度中先在空间维度 sum,维度变成 (C, ),然后选择值最大的通道

    • None 表示不需要压缩,此时可以通过 topk 参数可选择激活度最高的 topk 个特征图显示

  • 在 channel_reduction 参数为 None 的情况下,topk 参数生效,其会按照激活度排序选择 topk 个通道,然后和图片进行叠加显示,并且此时会通过 arrangement 参数指定显示的布局

    • 如果 topk 不是 -1,则会按照激活度排序选择 topk 个通道显示

    • 如果 topk = -1,此时通道 C 必须是 1 或者 3 表示输入数据是图片,否则报错提示用户应该设置 channel_reduction来压缩通道。

  • 考虑到输入的特征图通常非常小,函数支持输入 resize_shape 参数,方便将特征图进行上采样后进行可视化。

常见用法如下:以预训练好的 ResNet18 模型为例,通过提取 layer4 层输出进行特征图可视化

(1) 将多通道特征图采用 select_max 参数压缩为单通道并显示

import numpy as np
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

def preprocess_image(img, mean, std):
    preprocessing = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    return preprocessing(img.copy()).unsqueeze(0)

model = resnet18(pretrained=True)

def _forward(x):
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x1 = model.layer1(x)
    x2 = model.layer2(x1)
    x3 = model.layer3(x2)
    x4 = model.layer4(x3)
    return x4

model.forward = _forward

image_norm = np.float32(image) / 255
input_tensor = preprocess_image(image_norm,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
feat = model(input_tensor)[0]

visualizer = Visualizer()
drawn_img = visualizer.draw_featmap(feat, channel_reduction='select_max')
visualizer.show(drawn_img)

由于输出的 feat 特征图尺寸为 7x7,直接可视化效果不佳,用户可以通过叠加输入图片或者 resize_shape 参数来缩放特征图。如果传入图片尺寸和特征图大小不一致,会强制将特征图采样到和输入图片相同空间尺寸

drawn_img = visualizer.draw_featmap(feat, image, channel_reduction='select_max')
visualizer.show(drawn_img)

(2) 利用 topk=5 参数选择多通道特征图中激活度最高的 5 个通道并采用 2x3 布局显示

drawn_img = visualizer.draw_featmap(feat, image, channel_reduction=None, topk=5, arrangement=(2, 3))
visualizer.show(drawn_img)

用户可以通过 arrangement 参数选择自己想要的布局

drawn_img = visualizer.draw_featmap(feat, image, channel_reduction=None, topk=5, arrangement=(4, 2))
visualizer.show(drawn_img)

基础存储接口

在绘制完成后,可以选择本地窗口显示,也可以存储到不同后端中,目前 MMEngine 内置了本地存储、Tensorboard 存储和 WandB 存储 3 个后端,且支持存储绘制后的图片、loss 等标量数据和配置文件。

(1) 存储绘制后的图片

假设存储后端为本地存储

visualizer = Visualizer(image=image, vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')

visualizer.draw_bboxes(torch.tensor([[33, 120, 209, 220], [72, 13, 179, 147]]))
visualizer.draw_texts("cat and dog", torch.tensor([10, 20]))
visualizer.draw_circles(torch.tensor([40, 50]), torch.tensor([20]))

# 会生成 temp_dir/vis_data/vis_image/demo_0.png
visualizer.add_image('demo', visualizer.get_image())

其中生成的后缀 0 是用来区分不同 step 场景

# 会生成 temp_dir/vis_data/vis_image/demo_1.png
visualizer.add_image('demo', visualizer.get_image(), step=1)
# 会生成 temp_dir/vis_data/vis_image/demo_3.png
visualizer.add_image('demo', visualizer.get_image(), step=3)

如果想使用其他后端,则只需要修改配置文件即可

# TensorboardVisBackend
visualizer = Visualizer(image=image, vis_backends=[dict(type='TensorboardVisBackend')], save_dir='temp_dir')
# 或者 WandbVisBackend
visualizer = Visualizer(image=image, vis_backends=[dict(type='WandbVisBackend')], save_dir='temp_dir')

(2) 存储特征图

visualizer = Visualizer(vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
drawn_img = visualizer.draw_featmap(feat, image, channel_reduction=None, topk=5, arrangement=(2, 3))
# 会生成 temp_dir/vis_data/vis_image/feat_0.png
visualizer.add_image('feat', drawn_img)

(3) 存储 loss 等标量数据

# 会生成 temp_dir/vis_data/scalars.json
# 保存 loss
visualizer.add_scalar('loss', 0.2, step=0)
visualizer.add_scalar('loss', 0.1, step=1)
# 保存 acc
visualizer.add_scalar('acc', 0.7, step=0)
visualizer.add_scalar('acc', 0.8, step=1)

也可以一次性保存多个标量数据

# 会将内容追加到 temp_dir/vis_data/scalars.json
visualizer.add_scalars({'loss': 0.3, 'acc': 0.8}, step=3)

(4) 保存配置文件

from mmengine import Config
cfg=Config.fromfile('tests/data/config/py_config/config.py')
# 会生成 temp_dir/vis_data/config.py
visualizer.add_config(cfg)

多后端存储

实际上,任何一个可视化器都可以配置任意多个存储后端,可视化器会循环调用配置好的多个存储后端,从而将结果保存到多后端中。

visualizer = Visualizer(image=image, vis_backends=[dict(type='TensorboardVisBackend'),
                                                   dict(type='LocalVisBackend')],
                        save_dir='temp_dir')
# 会生成 temp_dir/vis_data/events.out.tfevents.xxx 文件
visualizer.draw_bboxes(torch.tensor([[33, 120, 209, 220], [72, 13, 179, 147]]))
visualizer.draw_texts("cat and dog", torch.tensor([10, 20]))
visualizer.draw_circles(torch.tensor([40, 50]), torch.tensor([20]))

visualizer.add_image('demo', visualizer.get_image())

注意:如果多个存储后端中存在同一个类的多个后端,那么必须指定 name 字段,否则无法区分是哪个存储后端

visualizer = Visualizer(image=image, vis_backends=[dict(type='TensorboardVisBackend', name='tb_1', save_dir='temp_dir_1'),
                                                   dict(type='TensorboardVisBackend', name='tb_2', save_dir='temp_dir_2'),
                                                   dict(type='LocalVisBackend', name='local')],
                        save_dir='temp_dir')

任意点位进行可视化

在深度学习过程中,会存在在某些代码位置插入可视化函数,并将其保存到不同后端的需求,这类需求主要用于可视化分析和调试阶段。MMEngine 设计的可视化器支持在任意点位获取同一个可视化器然后进行可视化的功能。 用户只需要在初始化时候通过 get_instance 接口实例化可视化对象,此时该可视化对象即为全局可获取唯一对象,后续通过 Visualizer.get_current_instance() 即可在代码任意位置获取。

# 在程序初始化时候调用
visualizer1 = Visualizer.get_instance(name='vis', vis_backends=[dict(type='LocalVisBackend')])

# 在任何代码位置都可调用
visualizer2 = Visualizer.get_current_instance()
visualizer2.add_scalar('map', 0.7, step=0)

assert id(visualizer1) == id(visualizer2)

也可以通过字段配置方式全局初始化

from mmengine.registry import VISUALIZERS

visualizer_cfg=dict(
                type='Visualizer',
                name='vis_new',
                vis_backends=[dict(type='LocalVisBackend')])
VISUALIZERS.build(visualizer_cfg)

扩展存储后端和可视化器

(1) 调用特定存储后端

目前存储后端仅仅提供了保存配置、保存标量等基本功能,但是由于 WandB 和 Tensorboard 这类存储后端功能非常强大, 用户可能会希望利用到这类存储后端的其他功能。因此,存储后端提供了 experiment 属性来方便用户获取后端对象,满足各类定制化功能。 例如 WandB 提供了表格显示的 API 接口,用户可以通过 experiment属性获取 WandB 对象,然后调用特定的 API 来将自定义数据保存为表格显示

visualizer = Visualizer(image=image, vis_backends=[dict(type='WandbVisBackend')],
                        save_dir='temp_dir')

# 获取 WandB 对象
wandb = visualizer.get_backend('WandbVisBackend').experiment
# 追加表格数据
table = wandb.Table(columns=["step", "mAP"])
table.add_data(1, 0.2)
table.add_data(2, 0.5)
table.add_data(3, 0.9)
# 保存
wandb.log({"table": table})

(2) 扩展存储后端

用户可以方便快捷的扩展存储后端。只需要继承自 BaseVisBackend 并实现各类 add_xx 方法即可

from mmengine.registry import VISBACKENDS
from mmengine.visualization import BaseVisBackend

@VISBACKENDS.register_module()
class DemoVisBackend(BaseVisBackend):
    def add_image(self, **kwargs):
        pass

visualizer = Visualizer(vis_backends=[dict(type='DemoVisBackend')], save_dir='temp_dir')
visualizer.add_image('demo',image)

(3) 扩展可视化器

同样的,用户可以通过继承 Visualizer 并实现想覆写的函数来方便快捷的扩展可视化器。大部分情况下,用户需要覆写 add_datasample来进行拓展。数据中通常包括标注或模型预测的检测框和实例掩码,该接口为各个下游库绘制 datasample 数据的抽象接口。以 MMDetection 为例,datasample 数据中通常包括标注 bbox、标注 mask 、预测 bbox 或者预测 mask 等数据,MMDetection 会继承 Visualizer 并实现 add_datasample 接口,在该接口内部会针对检测任务相关数据进行可视化绘制,从而简化检测任务可视化需求。

from mmengine.registry import VISUALIZERS

@VISUALIZERS.register_module()
class DetLocalVisualizer(Visualizer):
    def add_datasample(self,
                       name,
                       image: np.ndarray,
                       data_sample: Optional['BaseDataElement'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: int = 0,
                       step: int = 0) -> None:
        pass

visualizer_cfg = dict(
    type='DetLocalVisualizer', vis_backends=[dict(type='WandbVisBackend')], name='visualizer')

# 全局初始化
VISUALIZERS.build(visualizer_cfg)

# 任意代码位置
det_local_visualizer = Visualizer.get_current_instance()
det_local_visualizer.add_datasample('det', image, data_sample)

抽象数据接口

在模型的训练/测试过程中,组件之间往往有大量的数据需要传递,不同的算法需要传递的数据经常是不一样的,例如,训练单阶段检测器需要获得数据集的标注框(ground truth bounding boxes)和标签(ground truth box labels),训练 Mask R-CNN 时还需要实例掩码(instance masks)。 训练这些模型时的代码如下所示

for img, img_metas, gt_bboxes, gt_labels in data_loader:
    loss = retinanet(img, img_metas, gt_bboxes, gt_labels)
for img, img_metas, gt_bboxes, gt_masks, gt_labels in data_loader:
    loss = mask_rcnn(img, img_metas, gt_bboxes, gt_masks, gt_labels)

可以发现,在不加封装的情况下,不同算法所需数据的不一致导致了不同算法模块之间接口的不一致,影响了算法库的拓展性,同时一个算法库内的模块为了保持兼容性往往在接口上存在冗余。 上述弊端在算法库之间会体现地更加明显,导致在实现多任务(同时进行如语义分割、检测、关键点检测等多个任务)感知模型时模块难以复用,接口难以拓展。

为了解决上述问题,MMEngine 定义了一套抽象的数据接口来封装模型运行过程中的各种数据。假设将上述不同的数据封装进 data_sample ,不同算法的训练都可以被抽象和统一成如下代码

for img, data_sample in dataloader:
    loss = model(img, data_sample)

通过对各种数据提供统一的封装,抽象数据接口统一并简化了算法库中各个模块的接口,可以被用于算法库中 dataset,model,visualizer,和 evaluator 组件之间,或者 model 内各个模块之间的数据传递。 抽象数据接口实现了基本的增/删/改/查功能,同时支持不同设备之间的迁移,支持类字典和张量的操作,可以充分满足算法库对于这些数据的使用要求。 基于 MMEngine 的算法库可以继承这套抽象数据接口并实现自己的抽象数据接口来适应不同算法中数据的特点与实际需要,在保持统一接口的同时提高了算法模块的拓展性。

在实际实现过程中,算法库中的各个组件所具备的数据接口,一般为如下两个种:

  • 一个训练或测试样本(例如一张图像)的所有的标注信息和预测信息的集合,例如数据集的输出、模型以及可视化器的输入一般为单个训练或测试样本的所有信息。MMEngine将其定义为数据样本(DataSample)

  • 单一类型的预测或标注,一般是算法模型中某个子模块的输出, 例如二阶段检测中RPN的输出、语义分割模型的输出、关键点分支的输出, GAN中生成器的输出等。MMengine将其定义为数据元素(XXXData)

下边首先介绍一下数据样本与数据元素的基类 BaseDataElement

数据基类(BaseDataElement)

BaseDataElement 中存在两种类型的数据,一种是 data 类型,如标注框、框的标签、和实例掩码等;另一种是 metainfo 类型,包含数据的元信息以确保数据的完整性,如 img_shape, img_id 等数据所在图片的一些基本信息,方便可视化等情况下对数据进行恢复和使用。用户在创建 BaseDataElement 的过程中需要对这两类属性的数据进行显式地区分和声明。

为了能够更加方便地使用 BaseDataElementdatametainfo 中的数据均为 BaseDataElement 的属性。我们可以通过访问类属性的方式直接访问 datametainfo 中的数据。此外,BaseDataElement 还提供了很多方法,方便我们操作 data 内的数据:

  • 增/删/改/查 data 中不同字段的数据

  • data 迁移至目标设备

  • 支持像访问字典/张量一样访问 data 内的数据 以充分满足算法库对于这些数据的使用要求。

1. 数据元素的创建

BaseDataElement 的 data 参数可以直接通过 key=value 的方式自由添加,metainfo 的字段需要显式通过关键字 metainfo 指定。

import torch
from mmengine.structures import BaseDataElement
# 可以声明一个空的 object
data_element = BaseDataElement()

bboxes = torch.rand((5, 4))  # 假定 bboxes 是一个 Nx4 维的 tensor,N 代表框的个数
scores = torch.rand((5,))  # 假定框的分数是一个 N 维的 tensor,N 代表框的个数
img_id = 0  # 图像的 ID
H = 800  # 图像的高度
W = 1333  # 图像的宽度

# 直接设置 BaseDataElement 的 data 参数
data_element = BaseDataElement(bboxes=bboxes, scores=scores)

# 显式声明来设置 BaseDataElement 的参数 metainfo
data_element = BaseDataElement(
    bboxes=bboxes,
    scores=scores,
    metainfo=dict(img_id=img_id, img_shape=(H, W)))

2. newclone 函数

用户可以使用 new() 函数通过已有的数据接口创建一个具有相同状态和数据的抽象数据接口。用户可以在创建新 BaseDataElement 时设置 metainfodata,用于创建仅 datametainfo 具有相同状态和数据的抽象接口。比如 new(metainfo=xx) 使得新的 BaseDataElement 与被 clone 的 BaseDataElement 包含相同的 data 内容,但 metainfo 为新设置的内容。 也可以直接使用 clone() 来获得一份深拷贝,clone() 函数的行为与 PyTorch 中 Tensor 的 clone() 参数保持一致。

data_element = BaseDataElement(
    bboxes=torch.rand((5, 4)),
    scores=torch.rand((5,)),
    metainfo=dict(img_id=1, img_shape=(640, 640)))

# 可以在创建新 `BaseDataElement` 时设置 metainfo 和 data,使得新的 BaseDataElement 有相同未被设置的数据
data_element1 = data_element.new(metainfo=dict(img_id=2, img_shape=(320, 320)))
print('bboxes is in data_element1:', 'bboxes' in data_element1) # True
print('bboxes in data_element1 is same as bbox in data_element', (data_element1.bboxes == data_element.bboxes).all())
print('img_id in data_element1 is', data_element1.img_id == 2) # True

data_element2 = data_element.new(label=torch.rand(5,))
print('bboxes is not in data_element2', 'bboxes' not in data_element2) # True
print('img_id in data_element2 is same as img_id in data_element', data_element2.img_id == data_element.img_id)
print('label in data_element2 is', 'label' in data_element2)

# 也可以通过 `clone` 构建一个新的 object,新的 object 会拥有和 data_element 相同的 data 和 metainfo 内容以及状态。
data_element2 = data_element1.clone()
bboxes is in data_element1: True
bboxes in data_element1 is same as bbox in data_element tensor(True)
img_id in data_element1 is True
bboxes is not in data_element2 True
img_id in data_element2 is same as img_id in data_element True
label in data_element2 is True

3. 属性的增加与查询

对增加属性而言,用户可以像增加类属性那样增加 data 内的属性;对metainfo 而言,一般储存的为一些图像的元信息,一般情况下不会修改,如果需要增加,用户应当使用 set_metainfo 接口显示地修改。

对查询而言,用户可以可以通过 keysvalues,和 items 来访问只存在于 data 中的键值,也可以通过 metainfo_keysmetainfo_values,和metainfo_items 来访问只存在于 metainfo 中的键值。 用户还能通过 all_keysall_valuesall_items 来访问 BaseDataElement 的所有的属性并且不区分他们的类型。

同时为了方便使用,用户可以像访问类属性一样访问 data 与 metainfo 内的数据,或着类字典方式通过 get() 接口访问数据。

注意:

  1. BaseDataElement 不支持 metainfo 和 data 属性中有同名的字段,所以用户应当避免 metainfo 和 data 属性中设置相同的字段,否则 BaseDataElement 会报错。

  2. 考虑到 InstanceDataPixelData 支持对数据进行切片操作,为了避免 [] 用法的不一致,同时减少同种需求的不同方法,BaseDataElement 不支持像字典那样访问和设置它的属性,所以类似 BaseDataElement[name] 的取值赋值操作是不被支持的。

data_element = BaseDataElement()
# 通过 `set_metainfo`设置 data_element 的 metainfo 字段,
# 同时 img_id 和 img_shape 成为 data_element 的属性
data_element.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
# 查看 metainfo 的 key, value 和 item
print("metainfo'keys are", data_element.metainfo_keys())
print("metainfo'values are", data_element.metainfo_values())
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

print("通过类属性查看 img_id 和 img_shape")
print('img_id:', data_element.img_id)
print('img_shape:', data_element.img_shape)
metainfo'keys are ['img_id', 'img_shape']
metainfo'values are [9, (100, 100)]
img_id: 9
img_shape: (100, 100)
通过类属性查看 img_id  img_shape
img_id: 9
img_shape: (100, 100)

# 通过类属性直接设置 BaseDataElement 中的 data 字段
data_element.scores = torch.rand((5,))
data_element.bboxes = torch.rand((5, 4))

print("data's key is:", data_element.keys())
print("data's value is:", data_element.values())
for k, v in data_element.items():
    print(f'{k}: {v}')

print("通过类属性查看 scores 和 bboxes")
print('scores:', data_element.scores)
print('bboxes:', data_element.bboxes)

print("通过 get() 查看 scores 和 bboxes")
print('scores:', data_element.get('scores', None))
print('bboxes:', data_element.get('bboxes', None))
print('fake:', data_element.get('fake', 'not exist'))
data's key is: ['scores', 'bboxes']
data's value is: [tensor([0.7937, 0.6307, 0.3682, 0.4425, 0.8515]), tensor([[0.9204, 0.2110, 0.2886, 0.7925],
        [0.7993, 0.8982, 0.5698, 0.4120],
        [0.7085, 0.7016, 0.3069, 0.3216],
        [0.0206, 0.5253, 0.1376, 0.9322],
        [0.2512, 0.7683, 0.3010, 0.2672]])]
scores: tensor([0.7937, 0.6307, 0.3682, 0.4425, 0.8515])
bboxes: tensor([[0.9204, 0.2110, 0.2886, 0.7925],
        [0.7993, 0.8982, 0.5698, 0.4120],
        [0.7085, 0.7016, 0.3069, 0.3216],
        [0.0206, 0.5253, 0.1376, 0.9322],
        [0.2512, 0.7683, 0.3010, 0.2672]])
通过类属性查看 scores  bboxes
scores: tensor([0.7937, 0.6307, 0.3682, 0.4425, 0.8515])
bboxes: tensor([[0.9204, 0.2110, 0.2886, 0.7925],
        [0.7993, 0.8982, 0.5698, 0.4120],
        [0.7085, 0.7016, 0.3069, 0.3216],
        [0.0206, 0.5253, 0.1376, 0.9322],
        [0.2512, 0.7683, 0.3010, 0.2672]])
通过 get() 查看 scores  bboxes
scores: tensor([0.7937, 0.6307, 0.3682, 0.4425, 0.8515])
bboxes: tensor([[0.9204, 0.2110, 0.2886, 0.7925],
        [0.7993, 0.8982, 0.5698, 0.4120],
        [0.7085, 0.7016, 0.3069, 0.3216],
        [0.0206, 0.5253, 0.1376, 0.9322],
        [0.2512, 0.7683, 0.3010, 0.2672]])
fake: not exist

print("All key in data_element is:", data_element.all_keys())
print("The length of values in data_element is", len(data_element.all_values()))
for k, v in data_element.all_items():
    print(f'{k}: {v}')
All key in data_element is: ['img_id', 'img_shape', 'scores', 'bboxes']
The length of values in data_element is 4
img_id: 9
img_shape: (100, 100)
scores: tensor([0.7937, 0.6307, 0.3682, 0.4425, 0.8515])
bboxes: tensor([[0.9204, 0.2110, 0.2886, 0.7925],
        [0.7993, 0.8982, 0.5698, 0.4120],
        [0.7085, 0.7016, 0.3069, 0.3216],
        [0.0206, 0.5253, 0.1376, 0.9322],
        [0.2512, 0.7683, 0.3010, 0.2672]])

4. 属性的删改

用户可以像修改实例属性一样修改 BaseDataElementdata, 对metainfo 而言 一般储存的为一些图像的元信息,一般情况下不会修改,如果需要修改,用户应当使用 set_metainfo 接口显示的修改。

同时为了操作的便捷性,对 datametainfo 中的数据可以通过 del 直接删除,也支持 pop 在访问属性后删除属性。

data_element = BaseDataElement(
    bboxes=torch.rand((6, 4)), scores=torch.rand((6,)),
    metainfo=dict(img_id=0, img_shape=(640, 640))
)
for k, v in data_element.all_items():
    print(f'{k}: {v}')
img_id: 0
img_shape: (640, 640)
scores: tensor([0.8445, 0.6678, 0.8172, 0.9125, 0.7186, 0.5462])
bboxes: tensor([[0.5773, 0.0289, 0.4793, 0.7573],
        [0.8187, 0.8176, 0.3455, 0.3368],
        [0.6947, 0.5592, 0.7285, 0.0281],
        [0.7710, 0.9867, 0.7172, 0.5815],
        [0.3999, 0.9192, 0.7817, 0.2535],
        [0.2433, 0.0132, 0.1757, 0.6196]])
# 对 data 进行修改
data_element.bboxes = data_element.bboxes * 2
data_element.scores = data_element.scores * -1
for k, v in data_element.items():
    print(f'{k}: {v}')

# 删除 data 中的属性
del data_element.bboxes
for k, v in data_element.items():
    print(f'{k}: {v}')

data_element.pop('scores', None)
print('The keys in data is', data_element.keys())
scores: tensor([-0.8445, -0.6678, -0.8172, -0.9125, -0.7186, -0.5462])
bboxes: tensor([[1.1546, 0.0578, 0.9586, 1.5146],
        [1.6374, 1.6352, 0.6911, 0.6735],
        [1.3893, 1.1185, 1.4569, 0.0562],
        [1.5420, 1.9734, 1.4344, 1.1630],
        [0.7999, 1.8384, 1.5635, 0.5070],
        [0.4867, 0.0264, 0.3514, 1.2392]])
scores: tensor([-0.8445, -0.6678, -0.8172, -0.9125, -0.7186, -0.5462])
The keys in data is []
# 对 metainfo 进行修改
data_element.set_metainfo(dict(img_shape = (1280, 1280), img_id=10))
print(data_element.img_shape)  # (1280, 1280)
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

# 提供了便捷的属性删除和访问操作 pop
del data_element.img_shape
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

data_element.pop('img_id')
print('The keys in metainfo is', data_element.metainfo_keys())
(1280, 1280)
img_id: 10
img_shape: (1280, 1280)
img_id: 10
The keys in metainfo is []

5. 类张量操作

用户可以像 torch.Tensor 那样对 BaseDataElement 的 data 进行状态转换,目前支持 cudacputonumpy 等操作。 其中,to 函数拥有和 torch.Tensor.to() 相同的接口,使得用户可以灵活地将被封装的 tensor 进行状态转换。 注意: 这些接口只会处理类型为 np.array,torch.Tensor,或者数字的序列,其他属性的数据(如字符串)会被跳过处理。

data_element = BaseDataElement(
    bboxes=torch.rand((6, 4)), scores=torch.rand((6,)),
    metainfo=dict(img_id=0, img_shape=(640, 640))
)
# 将所有 data 转移到 GPU 上
cuda_element_1 = data_element.cuda()
print('cuda_element_1 is on the device of', cuda_element_1.bboxes.device)  # cuda:0
cuda_element_2 = data_element.to('cuda:0')
print('cuda_element_1 is on the device of', cuda_element_2.bboxes.device)  # cuda:0

# 将所有 data 转移到 cpu 上
cpu_element_1 = cuda_element_1.cpu()
print('cpu_element_1 is on the device of', cpu_element_1.bboxes.device)  # cpu
cpu_element_2 = cuda_element_2.to('cpu')
print('cpu_element_2 is on the device of', cpu_element_2.bboxes.device)  # cpu

# 将所有 data 变成 FP16
fp16_instances = cuda_element_1.to(
    device=None, dtype=torch.float16, non_blocking=False, copy=False,
    memory_format=torch.preserve_format)
print('The type of bboxes in fp16_instances is', fp16_instances.bboxes.dtype)  # torch.float16

# 阻断所有 data 的梯度
cuda_element_3 = cuda_element_2.detach()
print('The data in cuda_element_3 requires grad: ', cuda_element_3.bboxes.requires_grad)
# 转移 data 到 numpy array
np_instances = cpu_element_1.numpy()
print('The type of cpu_element_1 is convert to', type(np_instances.bboxes))
cuda_element_1 is on the device of cuda:0
cuda_element_1 is on the device of cuda:0
cpu_element_1 is on the device of cpu
cpu_element_2 is on the device of cpu
The type of bboxes in fp16_instances is torch.float16
The data in cuda_element_3 requires grad:  False
The type of cpu_element_1 is convert to <class 'numpy.ndarray'>

6. 属性的展示

BaseDataElement 还实现了 __repr__,因此,用户可以直接通过 print 函数看到其中的所有数据信息。 同时,为了便捷开发者 debug,BaseDataElement 中的属性都会添加进 __dict__ 中,方便用户在 IDE 界面可以直观看到 BaseDataElement 中的内容。 一个完整的属性展示如下

img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = BaseDataElement(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([0, 1, 2, 3])
instance_data.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3])
print(instance_data)
<BaseDataElement(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([0, 1, 2, 3])
    det_scores: tensor([0.0100, 0.1000, 0.2000, 0.3000])
) at 0x7f9f339f85b0>

数据元素(xxxData)

MMEngine 将数据元素情况划分为三个类别:

  • 实例数据(InstanceData): 主要针对的是上层任务(high-level)中,对图像中所有实例相关的数据进行封装,比如检测框(bounding boxes), 物体类别(box labels),实例掩码(instance masks), 关键点(key points), 文字边界(polygons), 跟踪id(tracking ids) 等. 所有实例相关的数据的长度一致,均为图像中实例的个数。

  • 像素数据(PixelData): 主要针对底层任务(low-level) 以及需要感知像素级别标签的部分上层任务。像素数据对像素级相关的数据进行封装,比如语义分割中的分割图(segmentation map), 光流任务中的光流图(flow map), 全景分割中的全景分割图(panoptic seg map);底层任务中生成的各种图像,比如超分辨图,去噪图,以及生成的各种风格图。这些数据的特点是都是三维或四维数组,最后两维度为数据的高度(height)和宽度(width),且具有相同的height和width

  • 标签数据(LabelData): 主要标签级别的数据进行封装,比如图像分类,多分类中的类别,图像生成中生成图像的类别内容,或者文字识别中的文本等。

InstanceData

InstanceDataBaseDataElement 的基础上,对 data 存储的数据做了限制,即要求存储在 data 中的数据的长度一致。比如在目标检测中, 假设一张图像中有 N 个目标(instance),可以将图像的所有边界框(bbox),类别(label)等存储在 InstanceData 中, InstanceData 的 bbox 和 label 的长度相同。 基于上述假定对 InstanceData进行了扩展,包括:

  • InstanceData 中 data 所存储的数据进行了长度校验

  • data 部分支持类字典访问和设置它的属性

  • 支持基础索引,切片以及高级索引功能

  • 支持具有相同的 key 但是不同 InstanceData 的拼接功能。 这些扩展功能除了支持基础的数据结构, 比如torch.tensor, numpy.dnarray, list, str, tuple, 也可以是自定义的数据结构,只要自定义数据结构实现了 __len__, __getitem__ and cat.

数据校验

InstanceData 中 data 的数据长度要保持一致,如果传入不同长度的新数据,将会报错。

from mmengine.structures import InstanceData
import torch
import numpy as np

img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data.det_scores = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print('The length of instance_data is', len(instance_data))  # 2

instance_data.bboxes = torch.rand((3, 4))
The length of instance_data is 2
AssertionError: the length of values 3 is not consistent with the length of this :obj:`InstanceData` 2

类字典访问和设置属性

InstanceData 支持类似字典的操作访问和设置其 data 属性。

img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data["det_labels"] = torch.LongTensor([2, 3])
instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print(instance_data)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2, 3])
    det_scores: tensor([0.8000, 0.7000])
    bboxes: tensor([[0.6576, 0.5435, 0.5253, 0.8273],
                [0.4533, 0.6848, 0.7230, 0.9279]])
) at 0x7f9f339f8ca0>
索引与切片

InstanceData 支持 Python 中类似列表的索引与切片,同时也支持类似 numpy 的高级索引操作。

img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data.det_scores = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print(instance_data)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2, 3])
    det_scores: tensor([0.8000, 0.7000])
    bboxes: tensor([[0.1872, 0.1669, 0.7563, 0.8777],
                [0.3421, 0.7104, 0.6000, 0.1518]])
) at 0x7f9f312b4dc0>
  1. 索引

print(instance_data[1])
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([3])
    det_scores: tensor([0.7000])
    bboxes: tensor([[0.3421, 0.7104, 0.6000, 0.1518]])
) at 0x7f9f312b4610>
  1. 切片

print(instance_data[0:1])
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2])
    det_scores: tensor([0.8000])
    bboxes: tensor([[0.1872, 0.1669, 0.7563, 0.8777]])
) at 0x7f9f312b4e20>
  1. 高级索引

  • 列表索引

sorted_results = instance_data[instance_data.det_scores.sort().indices]
print(sorted_results)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([3, 2])
    det_scores: tensor([0.7000, 0.8000])
    bboxes: tensor([[0.3421, 0.7104, 0.6000, 0.1518],
                [0.1872, 0.1669, 0.7563, 0.8777]])
) at 0x7f9f312b4a90>
  • 布尔索引

filter_results = instance_data[instance_data.det_scores > 0.75]
print(filter_results)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2])
    det_scores: tensor([0.8000])
    bboxes: tensor([[0.1872, 0.1669, 0.7563, 0.8777]])
) at 0x7fa061299dc0>
  1. 结果为空

empty_results = instance_data[instance_data.det_scores > 1]
print(empty_results)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([], dtype=torch.int64)
    det_scores: tensor([])
    bboxes: tensor([], size=(0, 4))
) at 0x7f9f439cccd0>
拼接(cat)

用户可以将两个具有相同 key 的 InstanceData 拼接成一个 InstanceData。对于长度分别为 N 和 M 的两个 InstanceData, 拼接后为长度 N + M 的新的 InstanceData

img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data.det_scores = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print('The length of instance_data is', len(instance_data))
cat_results = InstanceData.cat([instance_data, instance_data])
print('The length of instance_data is', len(cat_results))
print(cat_results)
The length of instance_data is 2
The length of instance_data is 4
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2, 3, 2, 3])
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
    bboxes: tensor([[0.5341, 0.8962, 0.9043, 0.2824],
                [0.3864, 0.2215, 0.7610, 0.7060],
                [0.5341, 0.8962, 0.9043, 0.2824],
                [0.3864, 0.2215, 0.7610, 0.7060]])
) at 0x7fa061d4a9d0>
自定义数据结构

对于自定义结构如果想使用上述扩展要求需要实现__len__, __getitem__cat三个接口.

import itertools

class TmpObject:
    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        self.tmp = tmp

    def __len__(self):
        return len(self.tmp)

    def __getitem__(self, item):
        if type(item) == int:
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # keep the dimension
                item = slice(item, None, len(self))
        return TmpObject(self.tmp[item])

    @staticmethod
    def cat(tmp_objs):
        assert all(isinstance(results, TmpObject) for results in tmp_objs)
        if len(tmp_objs) == 1:
            return tmp_objs[0]
        tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        tmp_list = list(itertools.chain(*tmp_list))
        new_data = TmpObject(tmp_list)
        return new_data

    def __repr__(self):
        return str(self.tmp)
img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
print(instance_data)
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    det_labels: tensor([2, 3])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
    det_scores: tensor([0.8000, 0.7000])
    bboxes: tensor([[0.4207, 0.0778, 0.9959, 0.1967],
                [0.4679, 0.7934, 0.5372, 0.4655]])
) at 0x7fa061b5d2b0>
# 高级索引
print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    bboxes: tensor([[0.4207, 0.0778, 0.9959, 0.1967]])
    det_labels: tensor([2])
    det_scores: tensor([0.8000])
    polygons: [[1, 2, 3, 4]]
) at 0x7f9f312716d0>
# 拼接
print(InstanceData.cat([instance_data, instance_data]))
<InstanceData(

    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)

    DATA FIELDS
    bboxes: tensor([[0.4207, 0.0778, 0.9959, 0.1967],
                [0.4679, 0.7934, 0.5372, 0.4655],
                [0.4207, 0.0778, 0.9959, 0.1967],
                [0.4679, 0.7934, 0.5372, 0.4655]])
    det_labels: tensor([2, 3, 2, 3])
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f9f31271490>

PixelData

PixelDataBaseDataElement 的基础上,同样对对 data 中存储的数据做了限制:

  • 所有 data 内的数据均为 3 维,并且顺序为 (通道,高, 宽)

  • 所有在 data 内的数据要有相同的长和宽 基于上述假定对 PixelData进行了扩展,包括:

  • PixelData 中 data 所存储的数据进行了尺寸的校验

  • 支持对 data 部分的数据对实例进行空间维度的索引和切片。

数据校验

PixelData 会对传入到 data 的数据进行维度与长宽的校验。

from mmengine.structures import PixelData
import random
import torch
import numpy as np
metainfo = dict(
    img_id=random.randint(0, 100),
    img_shape=(random.randint(400, 600), random.randint(400, 600)))
image = np.random.randint(0, 255, (4, 20, 40))
featmap = torch.randint(0, 255, (10, 20, 40))
pixel_data = PixelData(metainfo=metainfo,
                       image=image,
                       featmap=featmap)
print('The shape of pixel_data is', pixel_data.shape)
# set
pixel_data.map3 = torch.randint(0, 255, (20, 40))
print('The shape of pixel_data is', pixel_data.map3.shape)
The shape of pixel_data is (20, 40)
The shape of pixel_data is torch.Size([1, 20, 40])
pixel_data.map2 = torch.randint(0, 255, (3, 20, 30))
# AssertionError: the height and width of values (20, 30) is not consistent with the length of this :obj:`PixelData` (20, 40)
AssertionError: the height and width of values (20, 30) is not consistent with the length of this :obj:`PixelData` (20, 40)
pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
# AssertionError: The dim of value must be 2 or 3, but got 4
AssertionError: The dim of value must be 2 or 3, but got 4

空间维度索引

PixelData 支持对 data 部分的数据对实例进行空间维度的索引和切片,只需传入长宽的索引即可。

metainfo = dict(
    img_id=random.randint(0, 100),
    img_shape=(random.randint(400, 600), random.randint(400, 600)))
image = np.random.randint(0, 255, (4, 20, 40))
featmap = torch.randint(0, 255, (10, 20, 40))
pixel_data = PixelData(metainfo=metainfo,
                       image=image,
                       featmap=featmap)
print('The shape of pixel_data is', pixel_data.shape)
The shape of pixel_data is (20, 40)
  • 索引

index_data = pixel_data[10, 20]
print('The shape of index_data is', index_data.shape)
The shape of index_data is (1, 1)
  • 切片

slice_data = pixel_data[10:20, 20:40]
print('The shape of slice_data is', slice_data.shape)
The shape of slice_data is (10, 20)

LabelData

LabelData 主要用来封装标签数据,如场景分类标签,文字识别标签等。LabelData 没有对 data 做任何限制,只提供了两个额外功能:onehot 与 index 的转换。

from mmengine.structures import LabelData
import torch

item = torch.tensor([1], dtype=torch.int64)
num_classes = 10

onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes)
print(f'{num_classes} is convert to ', onehot)

index = LabelData.onehot_to_label(onehot=onehot)
print(f'{onehot} is convert to ', index)
10 is convert to  tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) is convert to tensor([1])

数据样本(DataSample)

数据样本作为不同模块最外层的接口,提供了 xxxDataSample 用于单任务中各模块之间统一格式的传递,同时为了各个模块从统一字段获取或写入信息,数据样本中的命名以及类型要进行约束和统一,保证各模块接口的统一性。 OpenMMLab 中各个算法库的命名规范可以参考 OpenMMLab 中的命名规范

下游库使用

以 MMDet 为例,说明下游库中数据样本的使用,以及数据样本字段的约束和命名。MMDet 中定义了 DetDataSample, 同时定义了 7 个字段,分别为:

  • 标注信息

    • gt_instance(InstanceData): 实例标注信息,包括实例的类别、边界框等, 类型约束为 InstanceData

    • gt_panoptic_seg(PixelData): 全景分割的标注信息,类型约束为 PixelData

    • gt_semantic_seg(PixelData): 语义分割的标注信息, 类型约束为 PixelData

  • 预测结果

    • pred_instance(InstanceData): 实例预测结果,包括实例的类别、边界框等, 类型约束为 InstanceData

    • pred_panoptic_seg(PixelData): 全景分割的预测结果,类型约束为 PixelData

    • pred_semantic_seg(PixelData): 语义分割的预测结果, 类型约束为 PixelData

  • 中间结果

    • proposal(InstanceData): 主要为二阶段中 RPN 的预测结果, 类型约束为 InstanceData

from mmengine.structures import BaseDataElement
import torch

class DetDataSample(BaseDataElement):

    # 标注
    @property
    def gt_instances(self) -> InstanceData:
        return self._gt_instances

    @gt_instances.setter
    def gt_instances(self, value: InstanceData):
        self.set_field(value, '_gt_instances', dtype=InstanceData)

    @gt_instances.deleter
    def gt_instances(self):
        del self._gt_instances

    @property
    def gt_panoptic_seg(self) -> PixelData:
        return self._gt_panoptic_seg

    @gt_panoptic_seg.setter
    def gt_panoptic_seg(self, value: PixelData):
        self.set_field(value, '_gt_panoptic_seg', dtype=PixelData)

    @gt_panoptic_seg.deleter
    def gt_panoptic_seg(self):
        del self._gt_panoptic_seg

    @property
    def gt_sem_seg(self) -> PixelData:
        return self._gt_sem_seg

    @gt_sem_seg.setter
    def gt_sem_seg(self, value: PixelData):
        self.set_field(value, '_gt_sem_seg', dtype=PixelData)

    @gt_sem_seg.deleter
    def gt_sem_seg(self):
        del self._gt_sem_seg

    # 预测
    @property
    def pred_instances(self) -> InstanceData:
        return self._pred_instances

    @pred_instances.setter
    def pred_instances(self, value: InstanceData):
        self.set_field(value, '_pred_instances', dtype=InstanceData)

    @pred_instances.deleter
    def pred_instances(self):
        del self._pred_instances

    @property
    def pred_panoptic_seg(self) -> PixelData:
        return self._pred_panoptic_seg

    @pred_panoptic_seg.setter
    def pred_panoptic_seg(self, value: PixelData):
        self.set_field(value, '_pred_panoptic_seg', dtype=PixelData)

    @pred_panoptic_seg.deleter
    def pred_panoptic_seg(self):
        del self._pred_panoptic_seg

    # 中间结果
    @property
    def pred_sem_seg(self) -> PixelData:
        return self._pred_sem_seg

    @pred_sem_seg.setter
    def pred_sem_seg(self, value: PixelData):
        self.set_field(value, '_pred_sem_seg', dtype=PixelData)

    @pred_sem_seg.deleter
    def pred_sem_seg(self):
        del self._pred_sem_seg

    @property
    def proposals(self) -> InstanceData:
        return self._proposals

    @proposals.setter
    def proposals(self, value: InstanceData):
        self.set_field(value, '_proposals', dtype=InstanceData)

    @proposals.deleter
    def proposals(self):
        del self._proposals

类型约束

DetDataSample 的用法如下所示,在数据类型不符合要求的时候(例如用 torch.Tensor 而非 InstanceData 定义 proposals 时),DetDataSample 就会报错。

data_sample = DetDataSample()

data_sample.proposals = InstanceData(data=dict(bboxes=torch.rand((5,4))))
print(data_sample)
<DetDataSample(

    META INFORMATION

    DATA FIELDS
    proposals: <InstanceData(

            META INFORMATION

            DATA FIELDS
            data:
                bboxes: tensor([[0.7513, 0.9275, 0.6169, 0.5581],
                            [0.6019, 0.6861, 0.7915, 0.0221],
                            [0.5977, 0.8987, 0.9541, 0.7877],
                            [0.0309, 0.1680, 0.1374, 0.0556],
                            [0.3842, 0.9965, 0.0747, 0.6546]])
        ) at 0x7f9f1c090310>
) at 0x7f9f1c090430>
data_sample.proposals = torch.rand((5, 4))
AssertionError: tensor([[0.4370, 0.1661, 0.0902, 0.8421],
        [0.4947, 0.1668, 0.0083, 0.1111],
        [0.2041, 0.8663, 0.0563, 0.3279],
        [0.7817, 0.1938, 0.2499, 0.6748],
        [0.4524, 0.8265, 0.4262, 0.2215]]) should be a <class 'mmengine.data.instance_data.InstanceData'> but got <class 'torch.Tensor'>

接口的简化

下面以 MMDetection 为例更具体地说明 OpenMMLab 的算法库将如何迁移使用抽象数据接口,以简化模块和组件接口的。我们假定 MMDetection 和 MMEngine 中实现了 DetDataSample 和 InstanceData。

1. 组件接口的简化

检测器的外部接口可以得到显著的简化和统一。MMDet 2.X 中单阶段检测器和单阶段分割算法的接口如下。在训练过程中,SingleStageDetector 需要获取 imgimg_metasgt_bboxesgt_labelsgt_bboxes_ignore 作为输入,但是 SingleStageInstanceSegmentor 还需要 gt_masks,导致 detector 的训练接口不一致,影响了代码的灵活性。


class SingleStageDetector(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):


class SingleStageInstanceSegmentor(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      img_metas,
                      gt_masks,
                      gt_labels,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None,
                      **kwargs):

在 MMDet 3.0 中,所有检测器的训练接口都可以使用 DetDataSample 统一简化为 imgdata_samples,不同模块可以根据需要去访问 data_samples 封装的各种所需要的属性。

class SingleStageDetector(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      data_samples):

class SingleStageInstanceSegmentor(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      data_samples):

2. 模块接口的简化

MMDet 2.X 中 HungarianAssignerMaskHungarianAssigner 分别用于在训练过程中将检测框和实例掩码和标注的实例进行匹配。他们内部的匹配逻辑实现是一样的,只是接口和损失函数的计算不同。 但是,接口的不同使得 HungarianAssigner 中的代码无法被复用,MaskHungarianAssigner 中重写了很多冗余的逻辑。

class HungarianAssigner(BaseAssigner):

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):

class MaskHungarianAssigner(BaseAssigner):

    def assign(self,
               cls_pred,
               mask_pred,
               gt_labels,
               gt_mask,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):

InstanceData 可以封装实例的框、分数、和掩码,将 HungarianAssigner 的核心参数简化成 pred_instancesgt_instancess,和 gt_instances_ignore 使得 HungarianAssignerMaskHungarianAssigner 可以合并成一个通用的 HungarianAssigner

class HungarianAssigner(BaseAssigner):

    def assign(self,
               pred_instances,
               gt_instancess,
               gt_instances_ignore=None,
               eps=1e-7):

分布式通信原语

在分布式训练或测试的过程中,不同进程有时需要根据分布式的环境信息执行不同的代码逻辑,同时不同进程之间也经常会有相互通信的需求,对一些数据进行同步等操作。 PyTorch 提供了一套基础的通信原语用于多进程之间张量的通信,基于这套原语,MMEngine 实现了更高层次的通信原语封装以满足更加丰富的需求。基于 MMEngine 的通信原语,算法库中的模块可以

  1. 在使用通信原语封装时不显式区分分布式/非分布式环境

  2. 进行除 Tensor 以外类型数据的多进程通信

  3. 无需了解底层通信后端或框架

这些通信原语封装的接口和功能可以大致归类为如下三种,我们在后续章节中逐个介绍

  1. 分布式初始化:init_dist 负责初始化执行器的分布式环境

  2. 分布式信息获取与控制:包括 get_world_size 等函数获取当前的 rankworld_size 等信息

  3. 分布式通信接口:包括如 all_reduce 等通信函数(collective functions)

分布式初始化

  • init_dist: 是分布式训练的启动函数,目前支持 pytorch,slurm,MPI 3 种分布式启动方式,同时允许设置通信的后端,默认使用 NCCL。

分布式信息获取与控制

分布式信息的获取与控制函数没有参数,这些函数兼容非分布式训练的情况,功能如下

  • get_world_size:获取当前进程组的进程总数,非分布式情况下返回 1

  • get_rank:获取当前进程对应的全局 rank 数,非分布式情况下返回 0

  • get_backend:获取当前通信使用的后端,非分布式情况下返回 None

  • get_local_rank:获取当前进程对应到当前机器的 rank 数,非分布式情况下返回 0

  • get_local_size:获取当前进程所在机器的总进程数,非分布式情况下返回 0

  • get_dist_info:获取当前任务的进程总数和当前进程对应到全局的 rank 数,非分布式情况下 word_size = 1,rank = 0

  • is_main_process:判断是否为 0 号主进程,非分布式情况下返回 True

  • master_only:函数装饰器,用于修饰只需要全局 0 号进程(rank 0 而不是 local rank 0)执行的函数

  • barrier:同步所有进程到达相同位置

分布式通信函数

通信函数 (Collective functions),主要用于进程间数据的通信,基于 PyTorch 原生的 all_reduce,all_gather,gather,broadcast 接口,MMEngine 提供了如下接口,兼容非分布式训练的情况,并支持更丰富数据类型的通信。

  • all_reduce: 对进程间 tensor 进行 AllReduce 操作

  • all_gather:对进程间 tensor 进行 AllGather 操作

  • gather:将进程的 tensor 收集到一个目标 rank

  • broadcast:对某个进程的 tensor 进行广播

  • sync_random_seed:同步进程之间的随机种子

  • broadcast_object_list:支持对任意可被 Pickle 序列化的 Python 对象列表进行广播,基于 broadcast 接口实现

  • all_reduce_dict:对 dict 中的内容进行 all_reduce 操作,基于 broadcast 和 all_reduce 接口实现

  • all_gather_object:基于 all_gather 实现对任意可以被 Pickle 序列化的 Python 对象进行 all_gather 操作

  • gather_object:将 group 里每个 rank 中任意可被 Pickle 序列化的 Python 对象 gather 到指定的目标 rank

  • collect_results:支持基于 CPU 通信或者 GPU 通信对不同进程间的列表数据进行收集

记录日志

执行器(Runner)在运行过程中会产生很多日志,例如损失、迭代时间、学习率等。MMEngine 实现了一套灵活的日志系统让我们能够在配置执行器时,选择不同类型日志的统计方式;在代码的任意位置,新增需要被统计的日志。

灵活的日志统计方式

我们可以通过在构建执行器时候配置日志处理器,来灵活地选择日志统计方式。如果不为执行器配置日志处理器,则会按照日志处理器的默认参数构建实例,效果等价于:

log_processor = dict(window_size=10, by_epoch=True, custom_cfg=None, num_digits=4)

其输出的日志格式如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01))
)
runner.train()
08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0019  data_time: 0.0004  loss1: 0.8381  loss2: 0.9007  loss: 1.7388
08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0029  data_time: 0.0010  loss1: 0.1978  loss2: 0.4312  loss: 0.6290

以训练阶段为例,日志处理器默认会按照以下方式统计执行器输出的日志:

  • 日志前缀:

    • Epoch 模式(by_epoch=True): Epoch(train) [{当前epoch次数}][{当前迭代次数}/{Dataloader 总长度}]

    • Iter 模式(by_epoch=False): Iter(train) [{当前迭代次数}/{总迭代次数}]

  • 学习率(lr):统计最近一次迭代,参数更新的学习率

  • 时间

    • 迭代时间(time):最近 window_size(日志处理器参数) 次迭代,处理一个 batch 数据(包括数据加载和模型前向推理)的平局时间

    • 数据时间(data_time):最近 window_size 次迭代,加载一个 batch 数据的平局时间

    • 剩余时间(eta):根据总迭代次数和历次迭代时间计算出来的总剩余时间,剩余时间随着迭代次数增加逐渐趋于稳定

  • 损失:模型前向推理得到的各种字段的损失,默认统计最近 window_size 次迭代的平均损失。

默认情况下,window_size=10,日志处理器会统计最近 10 次迭代,损失、迭代时间、数据时间的均值。

默认情况下,所有日志的有效位数(num_digits 参数)为 4。

默认情况下,输出所有自定义日志最近一次迭代的值。

基于上述规则,代码示例中的日志处理器会输出 loss1loss2 每 10 次迭代的均值。如果我们想统计 loss1 从第一次迭代开始至今的全局均值,可以这样配置:

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(  # 配置日志处理器
        custom_cfg=[
            dict(data_src='loss1',  # 原日志名:loss1
                 method_name='mean',  # 统计方法:均值统计
                 window_size='global')])  # 统计窗口:全局
)
runner.train()
08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0026  data_time: 0.0007  loss1: 0.7381  loss2: 0.8446  loss: 1.5827
08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0030  data_time: 0.0012  loss1: 0.4521  loss2: 0.3939  loss: 0.5600

注解

log_processor 默认输出 by_epoch=True 格式的日志。日志格式需要和 train_cfg 中的 by_epoch 参数保持一致,例如我们想按迭代次数输出日志,就需要另 log_processortrain_cfgby_epoch=False

其中 data_src 为原日志名,mean 为统计方法,global 为统计方法的参数。这样的话,日志中统计的 loss1 就是全局均值。我们可以在日志处理器中配置以下统计方法:

统计方法 参数 功能
mean window_size 统计窗口内日志的均值
min window_size 统计窗口内日志的最小值
max window_size 统计窗口内日志的最大值
current / 返回最近一次更新的日志

其中 window_size 的值可以是:

  • 数字:表示统计窗口的大小

  • global:统计全局的最大、最小和均值

  • epoch:统计一个 epoch 内的最大、最小和均值

当然我们也可以选择自定义的统计方法,详细步骤见日志设计

如果我们既想统计窗口为 10 的 loss1 的局部均值,又想统计 loss1 的全局均值,则需要额外指定 log_name

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(
        custom_cfg=[
            # log_name 表示 loss1 重新统计后的日志名
            dict(data_src='loss1', log_name='loss1_global', method_name='mean', window_size='global')])
)
runner.train()
08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0016  data_time: 0.0004  loss1: 0.1512  loss2: 0.3751  loss: 0.5264  loss1_global: 0.1512
08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0051  data_time: 0.0036  loss1: 0.0113  loss2: 0.0856  loss: 0.0970  loss1_global: 0.0813

类似地,我们也可以统计 loss1 的局部最大值和全局最大值:

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(custom_cfg=[
        # 统计 loss1 的局部最大值,统计窗口为 10,并在日志中重命名为 loss1_local_max
        dict(data_src='loss1',
             log_name='loss1_local_max',
             window_size=10,
             method_name='max'),
        # 统计 loss1 的全局最大值,并在日志中重命名为 loss1_local_max
        dict(
            data_src='loss1',
            log_name='loss1_global_max',
            method_name='max',
            window_size='global')
    ]))
runner.train()
08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0021  data_time: 0.0006  loss1: 1.8495  loss2: 1.3427  loss: 3.1922  loss1_local_max: 2.8872  loss1_global_max: 2.8872
08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0024  data_time: 0.0010  loss1: 0.5464  loss2: 0.7251  loss: 1.2715  loss1_local_max: 2.8872  loss1_global_max: 2.8872

更多配置规则见日志处理器文档

自定义统计内容

除了 MMEngine 默认的日志统计类型,如损失、迭代时间、学习率,用户也可以自行添加日志的统计内容。例如我们想统计损失的中间结果,可以这样做:

from mmengine.logging import MessageHub


class ToyModel(BaseModel):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss_tmp = (feat - label).abs()
        loss = loss_tmp.pow(2)

        message_hub = MessageHub.get_current_instance()
        # 在日志中额外统计 `loss_tmp`
        message_hub.update_scalar('train/loss_tmp', loss_tmp.sum())
        return dict(loss=loss)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(
        custom_cfg=[
        # 统计 loss_tmp 的局部均值
            dict(
                data_src='loss_tmp',
                window_size=10,
                method_name='mean')
        ]
    )
)
runner.train()
08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0026  data_time: 0.0008  loss_tmp: 0.0097  loss: 0.0000
08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0028  data_time: 0.0013  loss_tmp: 0.0065  loss: 0.0000

通过调用消息枢纽的接口实现自定义日志的统计,具体步骤如下:

  1. 调用 get_current_instance 接口获取执行器的消息枢纽。

  2. 调用 update_scalar 接口更新日志内容,其中第一个参数为日志的名称,日志名称以 train/val/test/ 前缀打头,用于区分训练状态,然后才是实际的日志名,如上例中的 train/loss_tmp,这样统计的日志中就会出现 loss_tmp

  3. 配置日志处理器,以均值的方式统计 loss_tmp。如果不配置,日志里显示 loss_tmp 最近一次更新的值。

输出调试日志

初始化执行器(Runner)时,将 log_level 设置成 debug。这样终端上就会额外输出日志等级为 debug 的日志

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    log_level='DEBUG',
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()
08/21 18:16:22 - mmengine - DEBUG - Get class `LocalVisBackend` from "vis_backend" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `LocalVisBackend` instance is built from registry, its implementation can be found in mmengine.visualization.vis_backend
08/21 18:16:22 - mmengine - DEBUG - Get class `RuntimeInfoHook` from "hook" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `RuntimeInfoHook` instance is built from registry, its implementation can be found in mmengine.hooks.runtime_info_hook
08/21 18:16:22 - mmengine - DEBUG - Get class `IterTimerHook` from "hook" registry in "mmengine"
...

此外,分布式训练时,DEBUG 模式还会分进程存储日志。单机多卡,或者多机多卡但是共享存储的情况下,导出的分布式日志路径如下

#  共享存储
./tmp
├── tmp.log
├── tmp_rank1.log
├── tmp_rank2.log
├── tmp_rank3.log
├── tmp_rank4.log
├── tmp_rank5.log
├── tmp_rank6.log
└── tmp_rank7.log
...
└── tmp_rank63.log

多机多卡,独立存储的情况:

# 独立存储
# 设备0:
work_dir/
└── exp_name_logs
    ├── exp_name.log
    ├── exp_name_rank1.log
    ├── exp_name_rank2.log
    ├── exp_name_rank3.log
    ...
    └── exp_name_rank7.log

# 设备7:
work_dir/
└── exp_name_logs
    ├── exp_name_rank56.log
    ├── exp_name_rank57.log
    ├── exp_name_rank58.log
    ...
    └── exp_name_rank63.log

如果想要更加深入的了解 MMEngine 的日志系统,可以参考日志系统设计

文件读写

MMEngine 实现了一套统一的文件读写接口,可以用同一个函数来处理不同的文件格式,如 jsonyamlpickle,并且可以方便地拓展其它的文件格式。除此之外,文件读写模块还支持从多种文件存储后端读写文件,包括本地磁盘、Petrel(内部使用)、Memcached、LMDB 和 HTTP。

读取和保存数据

MMEngine 提供了两个通用的接口用于读取和保存数据,目前支持的格式有 jsonyamlpickle

从硬盘读取数据或者将数据保存至硬盘

from mmengine import load, dump

# 从文件中读取数据
data = load('test.json')
data = load('test.yaml')
data = load('test.pkl')
# 从文件对象中读取数据
with open('test.json', 'r') as f:
    data = load(f, file_format='json')

# 将数据序列化为字符串
json_str = dump(data, file_format='json')

# 将数据保存至文件 (根据文件名后缀反推文件类型)
dump(data, 'out.pkl')

# 将数据保存至文件对象
with open('test.yaml', 'w') as f:
    data = dump(data, f, file_format='yaml')

从其它文件存储后端读写文件

from mmengine import load, dump

# 从 s3 文件读取数据
data = load('s3://bucket-name/test.json')
data = load('s3://bucket-name/test.yaml')
data = load('s3://bucket-name/test.pkl')

# 将数据保存至 s3 文件 (根据文件名后缀反推文件类型)
dump(data, 's3://bucket-name/out.pkl')

我们提供了易于拓展的方式以支持更多的文件格式,我们只需要创建一个继承自 BaseFileHandler 的文件句柄类,句柄类至少需要重写三个方法。然后使用使用 register_handler 装饰器将句柄类注册为对应文件格式的读写句柄。

from mmengine import register_handler, BaseFileHandler

# 支持为文件句柄类注册多个文件格式
# @register_handler(['txt', 'log'])
@register_handler('txt')
class TxtHandler1(BaseFileHandler):

    def load_from_fileobj(self, file):
        return file.read()

    def dump_to_fileobj(self, obj, file):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)

PickleHandler 为例

from mmengine import BaseFileHandler
import pickle

class PickleHandler(BaseFileHandler):

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super(PickleHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('protocol', 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        super(PickleHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)

读取文件并返回列表或字典

例如, a.txt 是文本文件,一共有5行内容。

a
b
c
d
e

从硬盘读取

使用 list_from_file 读取 a.txt

from mmengine import list_from_file

print(list_from_file('a.txt'))
# ['a', 'b', 'c', 'd', 'e']
print(list_from_file('a.txt', offset=2))
# ['c', 'd', 'e']
print(list_from_file('a.txt', max_num=2))
# ['a', 'b']
print(list_from_file('a.txt', prefix='/mnt/'))
# ['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']

同样, b.txt 也是文本文件,一共有3行内容

1 cat
2 dog cow
3 panda

使用 dict_from_file 读取 b.txt

from mmengine import dict_from_file

print(dict_from_file('b.txt'))
# {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
print(dict_from_file('b.txt', key_type=int))
# {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}

从其他存储后端读取

使用 list_from_file 读取 s3://bucket-name/a.txt

from mmengine import list_from_file

print(list_from_file('s3://bucket-name/a.txt'))
# ['a', 'b', 'c', 'd', 'e']
print(list_from_file('s3://bucket-name/a.txt', offset=2))
# ['c', 'd', 'e']
print(list_from_file('s3://bucket-name/a.txt', max_num=2))
# ['a', 'b']
print(list_from_file('s3://bucket-name/a.txt', prefix='/mnt/'))
# ['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']

使用 dict_from_file 读取 b.txt

from mmengine import dict_from_file

print(dict_from_file('s3://bucket-name/b.txt'))
# {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
print(dict_from_file('s3://bucket-name/b.txt', key_type=int))
# {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}

读取和保存权重文件

通常情况下,我们可以通过下面的方式从磁盘或者网络远端读取权重文件。

import torch

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = 'http://path/of/your/checkpoint3.pth'

# 从本地磁盘读取权重文件
checkpoint = torch.load(filepath1)
# 保存权重文件到本地磁盘
torch.save(checkpoint, filepath1)

# 从网络远端读取权重文件
checkpoint = torch.utils.model_zoo.load_url(filepath2)

MMEngine 中,得益于多文件存储后端的支持,不同存储形式的权重文件读写可以通过 load_checkpointsave_checkpoint 来统一实现。

from mmengine import load_checkpoint, save_checkpoint

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = 's3://bucket-name/path/of/your/checkpoint1.pth'
filepath3 = 'http://path/of/your/checkpoint3.pth'

# 从本地磁盘读取权重文件
checkpoint = load_checkpoint(filepath1)
# 保存权重文件到本地磁盘
save_checkpoint(checkpoint, filepath1)

# 从 s3 读取权重文件
checkpoint = load_checkpoint(filepath2)
# 保存权重文件到 s3
save_checkpoint(checkpoint, filepath2)

# 从网络远端读取权重文件
checkpoint = load_checkpoint(filepath3)

全局管理器(ManagerMixin)

Runner 在训练过程中,难免会使用全局变量来共享信息,例如我们会在 model 中获取全局的 logger 来打印初始化信息;在 model 中获取全局的 Visualizer 来可视化预测结果、特征图;在 Registry 中获取全局的 DefaultScope 来确定注册域。为了管理这些功能相似的模块,MMEngine 设计了管理器 ManagerMix 来统一全局变量的创建和获取方式。

ManagerMixin

接口介绍

  • get_instance(name=’’, **kwargs):创建或者返回对应名字的的实例。

  • get_current_instance():返回最近被创建的实例。

  • instance_name:获取对应实例的 name。

使用方法

  1. 定义有全局访问需求的类

from mmengine.utils import ManagerMixin


class GlobalClass(ManagerMixin):
    def __init__(self, name, value):
        super().__init__(name)
        self.value = value

注意全局类的构造函数必须带有 name 参数,并在构造函数中调用 super().__init__(name),以确保后续能够根据 name 来获取对应的实例。

  1. 在任意位置实例化该对象,以 Hook 为例(要确保访问该实例时,对象已经被创建):

from mmengine import Hook

class CustomHook(Hook):
    def before_run(self, runner):
        GlobalClass.get_instance('mmengine', value=50)
        GlobalClass.get_instance(runner.experiment_name, value=100)

当我们调用子类的 get_instance 接口时,ManagerMixin 会根据名字来判断对应实例是否已经存在,进而创建/获取实例。如上例所示,当我们第一次调用 GlobalClass.get_instance('mmengine', value=50) 时,会创建一个名为 “mmengine” 的 GlobalClass 实例,其初始 value 为 50。为了方便后续介绍 get_current_instance 接口,这里我们创建了两个 GlobalClass 实例。

  1. 在任意组件中访问该实例

import torch.nn as nn


class CustomModule(nn.Module):
    def forward(self, x):
        value = GlobalClass.get_current_instance().value  # 最近一次被创建的实例 value 为 100(步骤二中按顺序创建)
        value = GlobalClass.get_instance('mmengine').value  # 名为 mmengine 的实例 value 为 50
        # value = GlobalClass.get_instance('mmengine', 1000).value  # mmengine 已经被创建,不能再接受额外参数

在同一进程里,我们可以在不同组件中访问 GlobalClass 实例。例如我们在 CustomModule 中,调用 get_instanceget_current_instance 接口来获取对应名字的实例和最近被创建的实例。需要注意的是,由于 “mmengine” 实例已经被创建,再次调用时不能再传入额外参数,否则会报错。

跨库调用模块

通过使用 MMEngine 的注册器(Registry)配置文件(Config),用户可以实现跨软件包的模块构建。 例如,在 MMDetection 中使用 MMClassification 的 Backbone,或者在 MMRotate 中使用 MMDetection 的 Transform,或者在 MMTracking 中使用 MMDetection 的 Detector。 一般来说,同类模块都可以进行跨库调用,只需要在配置文件的模块类型前加上软件包名的前缀即可。下面举几个常见的例子:

跨库调用 Backbone:

以在 MMDetection 中调用 MMClassification 的 ConvNeXt 为例,首先需要在配置中加入 custom_imports 字段将 MMClassification 的 Backbone 添加进注册器,然后只需要在 Backbone 的配置中的 type 加上 MMClassification 的软件包名 mmcls 作为前缀,即 mmcls.ConvNeXt 即可:

# 使用 custom_imports 将 mmcls 的 models 添加进注册器
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)

model = dict(
  type='MaskRCNN',
  data_preprocessor=dict(...),
  backbone=dict(
      type='mmcls.ConvNeXt',  # 添加 mmcls 前缀完成跨库调用
      arch='tiny',
      out_indices=[0, 1, 2, 3],
      drop_path_rate=0.4,
      layer_scale_init_value=1.0,
      gap_before_final_norm=False,
      init_cfg=dict(
          type='Pretrained',
          checkpoint=
          'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth',
          prefix='backbone.')),
  neck=dict(...),
  rpn_head=dict(...))

跨库调用 Transform:

与上文的跨库调用 Backbone 一样,使用 custom_imports 和添加前缀即可实现跨库调用:

# 使用 custom_imports 将 mmdet 的 transforms 添加进注册器
custom_imports = dict(imports=['mmdet.datasets.transforms'], allow_failed_imports=False)

# 添加 mmdet 前缀完成跨库调用
train_pipeline=[
    dict(type='mmdet.LoadImageFromFile'),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(type='mmdet.Resize', scale=(1024, 2014), keep_ratio=True),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.PackDetInputs')
]

跨库调用 Detector:

跨库调用算法是一个比较复杂的例子,一个算法会包含多个子模块,因此每个子模块也需要在type中增加前缀,以在 MMTracking 中调用 MMDetection 的 YOLOX 为例:

# 使用 custom_imports 将 mmdet 的 models 添加进注册器
custom_imports = dict(imports=['mmdet.models'], allow_failed_imports=False)
model = dict(
    type='mmdet.YOLOX',
    backbone=dict(type='mmdet.CSPDarknet', deepen_factor=1.33, widen_factor=1.25),
    neck=dict(
        type='mmdet.YOLOXPAFPN',
        in_channels=[320, 640, 1280],
        out_channels=320,
        num_csp_blocks=4),
    bbox_head=dict(
        type='mmdet.YOLOXHead', num_classes=1, in_channels=320, feat_channels=320),
    train_cfg=dict(assigner=dict(type='mmdet.SimOTAAssigner', center_radius=2.5)))

为了避免给每个子模块手动增加前缀,配置文件中引入了 _scope_ 关键字,当某一模块的配置中添加了 _scope_ 关键字后,该模块配置文件下面的所有子模块配置都会从该关键字所对应的软件包内去构建:

# 使用 custom_imports 将 mmdet 的 models 添加进注册器
custom_imports = dict(imports=['mmdet.models'], allow_failed_imports=False)
model = dict(
    _scope_='mmdet',  # 使用 _scope_ 关键字,避免给所有子模块添加前缀
    type='YOLOX',
    backbone=dict(type='CSPDarknet', deepen_factor=1.33, widen_factor=1.25),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[320, 640, 1280],
        out_channels=320,
        num_csp_blocks=4),
    bbox_head=dict(
        type='YOLOXHead', num_classes=1, in_channels=320, feat_channels=320),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)))

以上这两种写法互相等价。

若希望了解更多关于注册器和配置文件的内容,请参考配置文件教程注册器教程

测试时增强(Test time augmentation)

测试时增强(Test time augmentation,后文简称 TTA)是一种测试阶段的数据增强策略,旨在测试过程中,对同一张图片做翻转、缩放等各种数据增强,将增强后每张图片预测的结果还原到原始尺寸并做融合,以获得更加准确的预测结果。为了让用户更加方便地使用 TTA,MMEngine 提供了 BaseTTAModel 类,用户只需按照任务需求,继承 BaseTTAModel 类,实现不同的 TTA 策略即可。

TTA 的核心实现通常分为两个部分:

  1. 测试时的数据增强:测试时数据增强主要在 MMCV 中实现,可以参考 TestTimeAug 的 API 文档,本文档不再赘述。

  2. 模型推理以及结果融合:BaseTTAModel 的主要功能就是实现这一部分,BaseTTAModel.test_step 会解析测试时增强后的数据并进行推理。用户继承 BaseTTAModel 后只需实现相应的融合策略即可。

快速上手

一个简单的支持 TTA 的示例可以参考 examples/test_time_augmentation.py

准备 TTA 数据增强

BaseTTAModel 需要配合 MMCV 中实现的 TestTimeAug 使用,这边简单给出一个样例配置:

tta_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='TestTimeAug',
        transforms=[
            [dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)],
            [dict(type='RandomFlip', flip_ratio=0.),
                dict(type='RandomFlip', flip_ratio=1.)],
            [dict(type='PackXXXInputs', keys=['img'])],
        ])
]

该配置表示在测试时,每张图片缩放(Resize)后都会进行翻转增强,变成两张图片。

定义 TTA 模型融合策略

BaseTTAModel 需要对翻转前后的图片进行推理,并将结果融合。merge_preds 方法接受一列表,列表中每一个元素表示 batch 中的某个数据反复增强后的结果。例如 batch_size=3,我们对 batch 中的每张图片做翻转增强,merge_preds 接受的参数为:

# data_{i}_{j} 表示对第 i 张图片做第 j 种增强后的结果,
# 例如 batch_size=3,那么 i 的 取值范围为 0,1,2,
# 增强方式有 2 种(翻转),那么 j 的取值范围为 0,1

demo_results = [
    [data_0_0, data_0_1],
    [data_1_0, data_1_1],
    [data_2_0, data_2_1],
]

merge_preds 需要将 demo_results 融合成整个 batch 的推理结果。以融合分类结果为例:

class AverageClsScoreTTA(BaseTTAModel):
    def merge_preds(
        self,
        data_samples_list: List[List[ClsDataSample]],
    ) -> List[ClsDataSample]:

        merged_data_samples = []
        for data_samples in data_samples_list:
            merged_data_sample: ClsDataSample = data_samples[0].new()
            merged_score = sum(data_sample.pred_label.score
                               for data_sample in data_samples) / len(data_samples)
            merged_data_sample.set_pred_score(merged_score)
            merged_data_samples.append(merged_data_sample)
        return merged_data_samples

相应的配置文件为:

tta_model = dict(type='AverageClsScoreTTA')

改写测试脚本

cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

进阶使用

一般情况下,用户继承 BaseTTAModel 后,只需要实现 merge_preds 方法,即可完成结果融合。但是对于复杂情况,例如融合多阶段检测器的推理结果,则可能会需要重写 test_step 方法。这就要求我们去进一步了解 BaseTTAModel 的数据流以及它和各组件之间的关系。

BaseTTAModel 和各组件的关系

BaseTTAModelDDPWrapperModel 的中间层。在执行 Runner.test() 的过程中,会先执行 DDPWrapper.test_step(),然后执行 TTAModel.test_step(),最后再执行 model.test_step()

运行过程中具体的调用栈如下所示:

数据流

数据经 TestTimeAug 增强后,其数据格式为:

image1  = dict(
    inputs=[data_1_1, data_1_2],
    data_sample=[data_sample1_1, data_sample1_2])
)

image2  = dict(
    inputs=[data_2_1, data_2_2],
    data_sample=[data_sample2_1, data_sample2_2])
)

image3  = dict(
    inputs=[data_3_1, data_3_2],
    data_sample=[data_sample3_1, data_sample3_2])
)

其中 data_{i}_{j} 为增强后的数据,data_sample_{i}_{j} 为增强后数据的标签信息。 数据经过 DataLoader 处理后,格式转变为:

data_batch = dict(
    inputs = [
              (data_1_1, data_2_1, data_3_1),
              (data_1_2, data_2_2, data_3_2),
             ]
    data_samples=[
         (data_samples1_1, data_samples2_1, data_samples3_1),
         (data_samples1_2, data_samples2_2, data_samples3_2)
     ]
)

为了方便模型推理,BaseTTAModel 会在模型推理前将将数据转换为:

data_batch_aug1 = dict(
    inputs = (data_1_1, data_2_1, data_3_1),
    data_samples=(data_samples1_1, data_samples2_1, data_samples3_1)
)

data_batch_aug2 = dict(
    inputs = (data_1_2, data_2_2, data_3_2),
    data_samples=(data_samples1_2, data_samples2_2, data_samples3_2)
)

此时每个 data_batch_aug 均可以直接传入模型进行推理。模型推理后,BaseTTAModel 会将推理结果整理成:

preds = [
    [data_samples1_1, data_samples_1_2],
    [data_samples2_1, data_samples_2_2],
    [data_samples3_1, data_samples_3_2],
]

方便用户进行结果融合。了解 TTA 的数据流后,我们就可以根据具体的需求,重载 BaseTTAModel.test_step(),以实现更加复杂的融合策略。

钩子

钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。

钩子示例

下面是钩子的简单示例。

pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]

def main():
    for func, arg in pre_hooks:
        func(arg)
    print('do something here')
    for func, arg in post_hooks:
        func(arg)

main()

下面是程序的输出:

hello
do something here
goodbye

可以看到,main 函数在两个位置调用钩子中的函数而无需做任何改动。

在 PyTorch 中,钩子的应用也随处可见,例如神经网络模块(nn.Module)中的钩子可以获得模块的前向输入输出以及反向的输入输出。以 register_forward_hook 方法为例,该方法往模块注册一个前向钩子,钩子可以获得模块的前向输入和输出。

下面是 register_forward_hook 用法的简单示例:

import torch
import torch.nn as nn

def forward_hook_fn(
    module,  # 被注册钩子的对象
    input,  # module 前向计算的输入
    output,  # module 前向计算的输出
):
    print(f'"forward_hook_fn" is invoked by {module.name}')
    print('weight:', module.weight.data)
    print('bias:', module.bias.data)
    print('input:', input)
    print('output:', output)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        y = self.fc(x)
        return y

model = Model()
# 将 forward_hook_fn 注册到 model 每个子模块
for module in model.children():
    module.register_forward_hook(forward_hook_fn)

x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)

下面是程序的输出:

"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077,  0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)

可以看到注册到 Linear 模块的 forward_hook_fn 钩子被调用,在该钩子中打印了 Linear 模块的权重、偏置、模块的输入以及输出。更多关于 PyTorch 钩子的用法可以阅读 nn.Module

MMEngine 中钩子的设计

在介绍 MMEngine 中钩子的设计之前,先简单介绍使用 PyTorch 实现模型训练的基本步骤(示例代码来自 PyTorch Tutorials):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    pass

class Net(nn.Module):
    pass

def main():
    transform = transforms.ToTensor()
    train_dataset = CustomDataset(transform=transform, ...)
    val_dataset = CustomDataset(transform=transform, ...)
    test_dataset = CustomDataset(transform=transform, ...)
    train_dataloader = DataLoader(train_dataset, ...)
    val_dataloader = DataLoader(val_dataset, ...)
    test_dataloader = DataLoader(test_dataset, ...)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for i in range(max_epochs):
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for inputs, labels in val_dataloader:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = net(inputs)
            accuracy = ...

上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 main 函数。为了提高 main 函数的灵活性和拓展性,我们可以在 main 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。

def main():
    ...
    call_hooks('before_run', hooks)  # 任务开始前执行的逻辑
    call_hooks('after_load_checkpoint', hooks)  # 加载权重后执行的逻辑
    call_hooks('before_train', hooks)  # 训练开始前执行的逻辑
    for i in range(max_epochs):
        call_hooks('before_train_epoch', hooks)  # 遍历训练数据集前执行的逻辑
        for inputs, labels in train_dataloader:
            call_hooks('before_train_iter', hooks)  # 模型前向计算前执行的逻辑
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            call_hooks('after_train_iter', hooks)  # 模型前向计算后执行的逻辑
            loss.backward()
            optimizer.step()
        call_hooks('after_train_epoch', hooks)  # 遍历完训练数据集后执行的逻辑

        call_hooks('before_val_epoch', hooks)  # 遍历验证数据集前执行的逻辑
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                call_hooks('before_val_iter', hooks)  # 模型前向计算前执行
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                call_hooks('after_val_iter', hooks)  # 模型前向计算后执行
        call_hooks('after_val_epoch', hooks)  # 遍历完验证数据集前执行

        call_hooks('before_save_checkpoint', hooks)  # 保存权重前执行的逻辑
    call_hooks('after_train', hooks)  # 训练结束后执行的逻辑

    call_hooks('before_test_epoch', hooks)  # 遍历测试数据集前执行的逻辑
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            call_hooks('before_test_iter', hooks)  # 模型前向计算后执行的逻辑
            outputs = net(inputs)
            accuracy = ...
            call_hooks('after_test_iter', hooks)  # 遍历完成测试数据集后执行的逻辑
    call_hooks('after_test_epoch', hooks)  # 遍历完测试数据集后执行

    call_hooks('after_run', hooks)  # 任务结束后执行的逻辑

在 MMEngine 中,我们将训练过程抽象成执行器(Runner),执行器除了完成环境的初始化,另一个功能是在特定的位点调用钩子完成定制化逻辑。更多关于执行器的介绍请阅读执行器文档

为了方便管理,MMEngine 将位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。

钩子中一共有 22 个位点:

  • before_run

  • after_run

  • before_train

  • after_train

  • before_train_epoch

  • after_train_epoch

  • before_train_iter

  • after_train_iter

  • before_val

  • after_val

  • before_val_epoch

  • after_val_epoch

  • before_val_iter

  • after_val_iter

  • before_test

  • after_test

  • before_test_epoch

  • after_test_epoch

  • before_test_iter

  • after_test_iter

  • before_save_checkpoint

  • after_load_checkpoint

你可能还想阅读钩子的用法或者钩子的 API 文档

执行器

深度学习算法的训练、验证和测试通常都拥有相似的流程,因此, MMEngine 抽象出了执行器来负责通用的算法模型的训练、测试、推理任务。用户一般可以直接使用 MMEngine 中的默认执行器,也可以对执行器进行修改以满足定制化需求。

在介绍执行器的设计之前,我们先举几个例子来帮助用户理解为什么需要执行器。下面是一段使用 PyTorch 进行模型训练的伪代码:

model = ResNet()
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
train_dataset = ImageNetDataset(...)
train_dataloader = DataLoader(train_dataset, ...)

for i in range(max_epochs):
    for data_batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model(data_batch)
        loss = loss_func(outputs, data_batch)
        loss.backward()
        optimizer.step()

下面是一段使用 PyTorch 进行模型测试的伪代码:

model = ResNet()
model.load_state_dict(torch.load(CKPT_PATH))
model.eval()

test_dataset = ImageNetDataset(...)
test_dataloader = DataLoader(test_dataset, ...)

for data_batch in test_dataloader:
    outputs = model(data_batch)
    acc = calculate_acc(outputs, data_batch)

下面是一段使用 PyTorch 进行模型推理的伪代码:

model = ResNet()
model.load_state_dict(torch.load(CKPT_PATH))
model.eval()

for img in imgs:
    prediction = model(img)

可以从上面的三段代码看出,这三个任务的执行流程都可以归纳为构建模型、读取数据、循环迭代等步骤。上述代码都是以图像分类为例,但不论是图像分类还是目标检测或是图像分割,都脱离不了这套范式。 因此,我们将模型的训练、验证、测试的流程整合起来,形成了执行器。在执行器中,我们只需要准备好模型、数据等任务必须的模块或是这些模块的配置文件,执行器会自动完成任务流程的准备和执行。 通过使用执行器以及 MMEngine 中丰富的功能模块,用户不再需要手动搭建训练测试的流程,也不再需要去处理分布式与非分布式训练的区别,可以专注于算法和模型本身。

Runner

MMEngine 的执行器内包含训练、测试、验证所需的各个模块,以及循环控制器(Loop)和钩子(Hook)。用户通过提供配置文件或已构建完成的模块,执行器将自动完成运行环境的配置,模块的构建和组合,最终通过循环控制器执行任务循环。执行器对外提供三个接口:trainvaltest,当调用这三个接口时,便会运行对应的循环控制器,并在循环的运行过程中调用钩子模块各个位点的钩子函数。

当用户构建一个执行器并调用训练、验证、测试的接口时,执行器的执行流程如下:创建工作目录 -> 配置运行环境 -> 准备任务所需模块 -> 注册钩子 -> 运行循环

runner_flow

执行器具有延迟初始化(Lazy Initialization)的特性,在初始化执行器时,并不需要依赖训练、验证和测试的全量模块,只有当运行某个循环控制器时,才会检查所需模块是否构建。因此,若用户只需要执行训练、验证或测试中的某一项功能,只需提供对应的模块或模块的配置即可。

循环控制器

在 MMEngine 中,我们将任务的执行流程抽象成循环控制器(Loop),因为大部分的深度学习任务执行流程都可以归纳为模型在一组或多组数据上进行循环迭代。 MMEngine 内提供了四种默认的循环控制器:

  • EpochBasedTrainLoop 基于轮次的训练循环

  • IterBasedTrainLoop 基于迭代次数的训练循环

  • ValLoop 标准的验证循环

  • TestLoop 标准的测试循环

Loop

MMEngine 中的默认执行器和循环控制器能够完成大部分的深度学习任务,但不可避免会存在无法满足的情况。有的用户希望能够对执行器进行更多自定义修改,因此,MMEngine 支持自定义模型的训练、验证以及测试的流程。

用户可以通过继承循环基类来实现自己的训练流程。循环基类需要提供两个输入:runner 执行器的实例和 dataloader 循环所需要迭代的迭代器。 用户如果有自定义的需求,也可以增加更多的输入参数。MMEngine 中同样提供了 LOOPS 注册器对循环类进行管理,用户可以向注册器内注册自定义的循环模块,然后在配置文件的 train_cfgval_cfgtest_cfg 中增加 type 字段来指定使用何种循环。 用户可以在自定义的循环中实现任意的执行逻辑,也可以增加或删减钩子(hook)点位,但需要注意的是一旦钩子点位被修改,默认的钩子函数可能不会被执行,导致一些训练过程中默认发生的行为发生变化。 因此,我们强烈建议用户按照本文档中定义的循环执行流程图以及钩子设计 去重载循环基类。

from mmengine.registry import LOOPS, HOOKS
from mmengine.runner import BaseLoop
from mmengine.hooks import Hook


# 自定义验证循环
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
    def __init__(self, runner, dataloader, evaluator, dataloader2):
        super().__init__(runner, dataloader, evaluator)
        self.dataloader2 = runner.build_dataloader(dataloader2)

    def run(self):
        self.runner.call_hooks('before_val_epoch')
        for idx, data_batch in enumerate(self.dataloader):
            self.runner.call_hooks(
                'before_val_iter', batch_idx=idx, data_batch=data_batch)
            outputs = self.run_iter(idx, data_batch)
            self.runner.call_hooks(
                'after_val_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
        metric = self.evaluator.evaluate()

        # 增加额外的验证循环
        for idx, data_batch in enumerate(self.dataloader2):
            # 增加额外的钩子点位
            self.runner.call_hooks(
                'before_valloader2_iter', batch_idx=idx, data_batch=data_batch)
            self.run_iter(idx, data_batch)
            # 增加额外的钩子点位
            self.runner.call_hooks(
                'after_valloader2_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
        metric2 = self.evaluator.evaluate()

        ...

        self.runner.call_hooks('after_val_epoch')


# 定义额外点位的钩子类
@HOOKS.register_module()
class CustomValHook(Hook):
    def before_valloader2_iter(self, batch_idx, data_batch):
        ...

    def after_valloader2_iter(self, batch_idx, data_batch, outputs):
        ...

上面的例子中实现了一个与默认验证循环不一样的自定义验证循环,它在两个不同的验证集上进行验证,同时对第二次验证增加了额外的钩子点位,并在最后对两个验证结果进行进一步的处理。在实现了自定义的循环类之后,只需要在配置文件的 val_cfg 内设置 type='CustomValLoop',并添加额外的配置即可。

# 自定义验证循环
val_cfg = dict(type='CustomValLoop', dataloader2=dict(dataset=dict(type='ValDataset2'), ...))
# 额外点位的钩子
custom_hooks = [dict(type='CustomValHook')]

自定义执行器

更进一步,如果默认执行器中依然有其他无法满足需求的部分,用户可以像自定义其他模块一样,通过继承重写的方式,实现自定义的执行器。执行器同样也可以通过注册器进行管理。具体实现流程与其他模块无异:继承 MMEngine 中的 Runner,重写需要修改的函数,添加进 RUNNERS 注册器中,最后在配置文件中指定 runner_type 即可。

from mmengine.registry import RUNNERS
from mmengine.runner import Runner

@RUNNERS.register_module()
class CustomRunner(Runner):

    def setup_env(self):
        ...

上述例子实现了一个自定义的执行器,并重写了 setup_env 函数,然后添加进了 RUNNERS 注册器中,完成了这些步骤之后,便可以在配置文件中设置 runner_type='CustomRunner' 来构建自定义的执行器。

你可能还想阅读执行器的教程或者执行器的 API 文档

模型精度评测

评测指标与评测器

在模型验证和模型测试中,通常需要对模型精度做定量评测。在 MMEngine 中实现了评测指标(Metric)和评测器(Evaluator)来完成这一功能。

  • 评测指标 用于根据测试数据和模型预测结果,完成特定模型精度指标的计算。在 OpenMMLab 各算法库中提供了对应任务的常用评测指标,如 MMClassification 中提供了Accuracy 用于计算分类模型的 Top-k 分类正确率;MMDetection 中提供了 COCOMetric 用于计算目标检测模型的 AP,AR 等评测指标。评测指标与数据集解耦,如 COCOMetric 也可用于 COCO 以外的目标检测数据集上。

  • 评测器 是评测指标的上层模块,通常包含一个或多个评测指标。评测器的作用是在模型评测时完成必要的数据格式转换,并调用评测指标计算模型精度。评测器通常由执行器或测试脚本构建,分别用于在线评测和离线评测。

评测指标基类 BaseMetric

评测指标基类 BaseMetric 是一个抽象类,初始化参数如下:

  • collect_device:在分布式评测中用于同步结果的设备名,如 'cpu''gpu'

  • prefix:评测指标名前缀,用以区别多个同名的评测指标。如果该参数未给定,则会尝试使用类属性 default_prefix 作为前缀。

class BaseMetric(metaclass=ABCMeta):

    default_prefix: Optional[str] = None

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        ...

BaseMetric 有以下 2 个重要的方法需要在子类中重写:

  • process() 用于处理每个批次的测试数据和模型预测结果。处理结果应存放在 self.results 列表中,用于在处理完所有测试数据后计算评测指标。该方法具有以下 2 个参数:

    • data_batch:一个批次的测试数据样本,通常直接来自与数据加载器

    • data_samples:对应的模型预测结果 该方法没有返回值。函数接口定义如下:

    @abstractmethod
    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.
        Args:
            data_batch (Any): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
    
  • compute_metrics() 用于计算评测指标,并将所评测指标存放在一个字典中返回。该方法有以下 1 个参数:

    • results:列表类型,存放了所有批次测试数据经过 process() 方法处理后得到的结果 该方法返回一个字典,里面保存了评测指标的名称和对应的评测值。函数接口定义如下:

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.
    
        Args:
            results (list): The processed results of each batch.
    
        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
    

其中,compute_metrics() 会在 evaluate() 方法中被调用;后者在计算评测指标前,会在分布式测试时收集和汇总不同 rank 的中间处理结果。

需要注意的是,self.results 中存放的具体类型取决于评测指标子类的实现。例如,当测试样本或模型输出数据量较大(如语义分割、图像生成等任务),不宜全部存放在内存中时,可以在 self.results 中存放每个批次计算得到的指标,并在 compute_metrics() 中汇总;或将每个批次的中间结果存储到临时文件中,并在 self.results 中存放临时文件路径,最后由 compute_metrics() 从文件中读取数据并计算指标。

模型精度评测流程

通常,模型精度评测的过程如下图所示。

在线评测:测试数据通常会被划分为若干批次(batch)。通过一个循环,依次将每个批次的数据送入模型,得到对应的预测结果,并将测试数据和模型预测结果送入评测器。评测器会调用评测指标的 process() 方法对数据和预测结果进行处理。当循环结束后,评测器会调用评测指标的 evaluate() 方法,可计算得到对应指标的模型精度。

离线评测:与在线评测过程类似,区别是直接读取预先保存的模型预测结果来进行评测。评测器提供了 offline_evaluate 接口,用于在离线方式下调用评测指标来计算模型精度。为了避免同时处理大量数据导致内存溢出,离线评测时会将测试数据和预测结果分成若干个块(chunk)进行处理,类似在线评测中的批次。

增加自定义评测指标

在 OpenMMLab 的各个算法库中,已经实现了对应方向的常用评测指标。如 MMDetection 中提供了 COCO 评测指标,MMClassification 中提供了 Accuracy、F1Score 等评测指标等。

用户也可以增加自定义的评测指标。具体方法可以参考教程文档中给出的示例。

可视化

1 总体设计

可视化可以给深度学习的模型训练和测试过程提供直观解释。在 OpenMMLab 算法库中,我们期望可视化功能的设计能满足以下需求:

  • 提供丰富的开箱即用可视化功能,能够满足大部分计算机视觉可视化任务

  • 高扩展性,可视化功能通常多样化,应该能够通过简单扩展实现定制需求

  • 能够在训练和测试流程的任意点位进行可视化

  • OpenMMLab 各个算法库具有统一可视化接口,利于用户理解和维护

基于上述需求,OpenMMLab 2.0 引入了可视化对象 Visualizer 和各个可视化存储后端 VisBackend 如 LocalVisBackendWandbVisBackendTensorboardVisBackend 等。此处的可视化不仅仅包括图片数据格式,还包括配置内容、标量和模型图等数据的可视化。

  • 为了方便调用,Visualizer 提供的接口实现了绘制和存储的功能。可视化存储后端 VisBackend 作为 Visualizer 的内部属性,会在需要的时候被 Visualizer 调用,将数据存到不同的后端

  • 考虑到绘制后会希望存储到多个后端,Visualizer 可以配置多个 VisBackend,当用户调用 Visualizer 的存储接口时候,Visualizer 内部会遍历的调用 VisBackend 存储接口

两者的 UML 关系图如下

2 可视化器 Visualizer

可视化对象 Visualizer 对外提供了所有接口。可以将其接口分成 3 大类,如下所示

(1) 绘制相关接口

上述接口除了 draw_featmap 外都可以链式调用,因为该方法调用后可能会导致图片尺寸发生改变。为了避免给用户带来困扰, draw_featmap 被设置为静态方法。

(2) 存储相关接口

  • add_config 写配置到特定存储后端

  • add_graph 写模型图到特定存储后端

  • add_image 写图片到特定存储后端

  • add_scalar 写标量到特定存储后端

  • add_scalars 一次性写多个标量到特定存储后端

  • add_datasample 各个下游库绘制 datasample 数据的抽象接口

以 add 前缀开头的接口表示存储接口。datasample 是 OpenMMLab 2.0 架构中设计的各个下游库统一的抽象数据接口,而 add_datasample 接口可以直接处理该数据格式,例如可视化预测结果、可视化 Dataset 或者 DataLoader 输出、可视化中间预测结果等等都可以直接调用下游库重写的 add_datasample 接口。 所有下游库都必须要继承 Visualizer 并实现 add_datasample 接口。以 MMDetection 为例,应该继承并通过该接口实现目标检测中所有预置任务的可视化功能,例如目标检测、实例分割、全景分割任务结果的绘制和存储。

(3) 其余功能性接口

  • set_image 设置原始图片数据,默认输入图片格式为 RGB

  • get_image 获取绘制后的 Numpy 格式图片数据,默认输出格式为 RGB

  • show 可视化

  • get_backend 通过 name 获取特定存储后端

  • close 关闭所有已经打开的资源,包括 VisBackend

关于其用法,可以参考 可视化器用户教程

3 可视化存储后端 VisBackend

在绘制后可以将绘制后的数据存储到多个可视化存储后端中。为了统一接口调用,MMEngine 提供了统一的抽象类 BaseVisBackend,和一些常用的 VisBackend 如 LocalVisBackendWandbVisBackendTensorboardVisBackend。 BaseVisBackend 定义了对外调用的接口规范,主要接口和属性如下:

  • add_config 写配置到特定存储后端

  • add_graph 写模型图到特定后端

  • add_image 写图片到特定后端

  • add_scalar 写标量到特定后端

  • add_scalars 一次性写多个标量到特定后端

  • close 关闭已经打开的资源

  • experiment 写后端对象,例如 WandB 对象和 Tensorboard 对象

BaseVisBackend 定义了 5 个常见的写数据接口,考虑到某些写后端功能非常强大,例如 WandB,其具备写表格,写视频等等功能,针对这类需求用户可以直接获取 experiment 对象,然后调用写后端对象本身的 API 即可。而 LocalVisBackendWandbVisBackendTensorboardVisBackend 等都是继承自 BaseVisBackend,并根据自身特性实现了对应的存储功能。用户也可以继承 BaseVisBackend 从而扩展存储后端,实现自定义存储需求。 关于其用法,可以参考 存储后端用户教程

日志系统

概述

执行器(Runner)在运行过程中会产生很多日志,例如加载的数据集信息、模型的初始化信息、训练过程中的学习率、损失等。为了让用户能够更加自由的获取这些日志信息,MMEngine 实现了消息枢纽(MessageHub)历史缓冲区(HistoryBuffer)日志处理器(LogProcessor)MMLogger 来支持以下功能:

  • 用户可以通过配置文件,根据个人偏好来选择日志统计方式,例如在终端输出整个训练过程中的平均损失而不是基于固定迭代次数平滑的损失

  • 用户可以在任意组件中获取当前的训练状态,例如当前的迭代次数、训练轮次等

  • 用户可以通过配置文件来控制是否保存分布式训练下的多进程日志

image

训练过程中的产生的损失、学习率等数据由历史缓冲区管理和封装,汇总后交给消息枢纽维护。日志处理器将消息枢纽中的数据进行格式化,最后通过记录器钩子(LoggerHook) 展示到各种可视化后端。一般情况下用户无需感知数据处理流程,可以直接通过配置日志处理器来选择日志的统计方式。在介绍 MMEngine 的日志系统的设计之前,可以先阅读记录日志教程 了解日志系统的基本用法。

历史缓冲区(HistoryBuffer)

MMEngine 实现了历史数据存储的抽象类历史缓冲区(HistoryBuffer),用于存储训练日志的历史轨迹,如模型损失、优化器学习率、迭代时间等。通常情况下,历史缓冲区作为内部类,配合消息枢纽(MessageHub)、记录器钩子(LoggerHook )和日志处理器(LogProcessor) 实现了训练日志的可配置化。

用户也可以单独使用历史缓冲区来管理训练日志,能够非常简单的使用不同方法来统计训练日志。我们先来介绍如何单独使用历史缓冲区,在消息枢纽一节再进一步介绍二者的联动。

历史缓冲区初始化

历史缓冲区的初始化可以接受 log_historycount_history 两个参数。log_history 表示日志的历史轨迹,例如前三次迭代的 loss 为 0.3,0.2,0.1。我们就可以记 log_history=[0.3, 0.2, 0.1]count_history 是一个比较抽象的概念,如果按照迭代次数来算,0.3,0.2,0.1 分别是三次迭代的结果,那么我们可以记 count_history=[1, 1, 1],其中 “1” 表示一次迭代。如果按照 batch 来算,例如每次迭代的 batch_size 为 8,那么 count_history=[8, 8, 8]count_history 只会在统计均值时用到,用于控制返回均值的粒度。就拿上面那个例子来说,count_history=[1, 1, 1] 时会统计每次迭代的平均 loss,而 count_history=[8, 8, 8] 则会统计每张图片的平均 loss。

from mmengine.logging import HistoryBuffer

history_buffer = HistoryBuffer()  # 空初始化
log_history, count_history = history_buffer.data
# [] []
history_buffer = HistoryBuffer([1, 2, 3], [1, 2, 3])  # list 初始化
log_history, count_history = history_buffer.data
# [1 2 3] [1 2 3]
history_buffer = HistoryBuffer([1, 2, 3], [1, 2, 3], max_length=2)
# The length of history buffer(3) exceeds the max_length(2), the first few elements will be ignored.
log_history, count_history = history_buffer.data  # 最大长度为2,只能存储 [2, 3]
# [2 3] [2 3]

我们可以通过 history_buffer.data 来返回日志的历史轨迹。此外,我们可以为历史缓冲区设置最大队列长度,当历史缓冲区的长度大于最大队列长度时,会自动丢弃最早更新的数据。

更新历史缓冲区

我们可以通过 update 接口来更新历史缓冲区。update 接受两个参数,第一个参数用于更新 log_history ,第二个参数用于更新 count_history

history_buffer = HistoryBuffer([1, 2, 3], [1, 1, 1])
history_buffer.update(4)  # 更新日志
log_history, count_history = history_buffer.data
# [1, 2, 3, 4] [1, 1, 1, 1]
history_buffer.update(5, 2)  # 更新日志
log_history, count_history = history_buffer.data
# [1, 2, 3, 4, 5] [1, 1, 1, 1, 2]

基本统计方法

历史缓冲区提供了基本的数据统计方法:

  • current():获取最新更新的数据。

  • mean(window_size=None):获取窗口内数据的均值,默认返回数据的全局均值

  • max(window_size=None):获取窗口内数据的最大值,默认返回全局最大值

  • min(window_size=None):获取窗口内数据的最小值,默认返回全局最小值

history_buffer = HistoryBuffer([1, 2, 3], [1, 1, 1])
history_buffer.min(2)
# 2,从 [2, 3] 中统计最小值
history_buffer.min()
# 返回全局最小值 1

history_buffer.max(2)
# 3,从 [2, 3] 中统计最大值
history_buffer.min()
# 返回全局最大值 3
history_buffer.mean(2)
# 2.5,从 [2, 3] 中统计均值, (2 + 3) / (1 + 1)
history_buffer.mean()  # (1 + 2 + 3) / (1 + 1 + 1)
# 返回全局均值 2
history_buffer = HistoryBuffer([1, 2, 3], [2, 2, 2])  # 当 count 不为 1时
history_buffer.mean()  # (1 + 2 + 3) / (2 + 2 + 2)
# 返回均值 1
history_buffer = HistoryBuffer([1, 2, 3], [1, 1, 1])
history_buffer.update(4, 1)
history_buffer.current()
# 4

统计方法的统一入口

要想支持在配置文件中通过配置 ‘max’,’min’ 等字段来选择日志的统计方式,那么 HistoryBuffer 就需要一个接口来接受 ‘min’,’max’ 等统计方法字符串和相应参数,进而找到对应的统计方法,最后输出统计结果。statistics(name, *args, **kwargs) 接口就起到了这个作用。其中 name 是已注册的方法名(已经注册 minmax 等基本统计方法),*arg**kwarg 用于接受对应方法的参数。

history_buffer = HistoryBuffer([1, 2, 3], [1, 1, 1])
history_buffer.statistics('mean')
# 2 返回全局均值
history_buffer.statistics('mean', 2)
# 2.5 返回 [2, 3] 的均值
history_buffer.statistics('mean', 2, 3)  # 错误!传入了不匹配的参数
history_buffer.statistics('data')  # 错误! data 方法未被注册,无法被调用

注册统计方法

为了保证历史缓冲区的可扩展性,用户可以通过 register_statistics 接口注册自定义的统计函数

from mmengine.logging import HistoryBuffer
import numpy as np


@HistoryBuffer.register_statistics
def weighted_mean(self, window_size, weight):
    assert len(weight) == window_size
    return (self._log_history[-window_size:] * np.array(weight)).sum() / \
            self._count_history[-window_size:]


history_buffer = HistoryBuffer([1, 2], [1, 1])
history_buffer.statistics('weighted_mean', 2, [2, 1])  # get (2 * 1 + 1 * 2) / (1 + 1)

用户可以通过 statistics 接口,传入方法名和对应参数来调用被注册的函数。

使用样例

用户可以独立使用历史缓冲区来记录日志,通过简单的接口调用就能得到期望的统计接口。

logs = dict(lr=HistoryBuffer(), loss=HistoryBuffer())  # 字典配合 HistoryBuffer 记录不同字段的日志
max_iter = 10
log_interval = 5
for iter in range(1, max_iter+1):
    lr = iter / max_iter * 0.1  # 线性学习率变化
    loss = 1 / iter  # loss
    logs['lr'].update(lr, 1)
    logs['loss'].update(loss, 1)
    if iter % log_interval == 0:
        latest_lr = logs['lr'].statistics('current')  # 通过字符串来选择统计方法
        mean_loss = logs['loss'].statistics('mean', log_interval)
        print(f'lr:   {latest_lr}\n'  # 返回最近一次更新的学习率。
              f'loss: {mean_loss}')   # 平滑最新更新的 log_interval 个数据。
# lr:   0.05
# loss: 0.45666666666666667
# lr:   0.1
# loss: 0.12912698412698415

MMEngine 利用历史缓冲区的特性,结合消息枢纽,实现了训练日志的高度可定制化。

消息枢纽(MessageHub)

历史缓冲区(HistoryBuffer)可以十分简单地完成单个日志的更新和统计,而在模型训练过程中,日志的种类繁多,并且来自于不同的组件,因此如何完成日志的分发和收集是需要考虑的问题。 MMEngine 使用消息枢纽(MessageHub)来实现组件与组件、执行器与执行器之间的数据共享。消息枢纽继承自全局管理器(ManagerMixin),支持跨模块访问。

消息枢纽存储了两种含义的数据:

  • 历史缓冲区字典:消息枢纽会收集各个模块更新的训练日志,如损失、学习率、迭代时间,并将其更新至内部的历史缓冲区字典中。历史缓冲区字典经消息处理器(LogProcessor)处理后,会被输出到终端/保存到本地。如果用户需要记录自定义日志,可以往历史缓冲区字典中更新相应内容。

  • 运行时信息字典:运行时信息字典用于存储迭代次数、训练轮次等运行时信息,方便 MMEngine 中所有组件共享这些信息。

注解

当用户想在终端输出自定义日志,或者想跨模块共享一些自定义数据时,才会用到消息枢纽。

为了方便用户理解消息枢纽在训练过程中更新信息以及分发信息的流程,我们通过几个例子来介绍消息枢纽的使用方法,以及如何使用消息枢纽向终端输出自定义日志。

更新/获取训练日志

历史缓冲区以字典的形式存储在消息枢纽中。当我们第一次调用 update_scalar 时,会初始化对应字段的历史缓冲区,后续的每次更新等价于调用对应字段历史缓冲区的 update 方法。同样的我们可以通过 get_scalar 来获取对应字段的历史缓冲区,并按需计算统计值。如果想获取消息枢纽的全部日志,可以访问其 log_scalars 属性。

from mmengine import MessageHub

message_hub = MessageHub.get_instance('task')
message_hub.update_scalar('train/loss', 1, 1)
message_hub.get_scalar('train/loss').current()  # 1,最近一次更新值为 1
message_hub.update_scalar('train/loss', 3, 1)
message_hub.get_scalar('train/loss').mean()  # 2,均值为 (3 + 1) / (1 + 1)
message_hub.update_scalar('train/lr', 0.1, 1)

message_hub.update_scalars({'train/time': {'value': 0.1, 'count': 1},
                            'train/data_time': {'value': 0.1, 'count': 1}})

train_time = message_hub.get_scalar('train/time')  # 获取单个日志

log_dict = message_hub.log_scalars  # 返回存储全部 HistoryBuffer 的字典
lr_buffer, loss_buffer, time_buffer, data_time_buffer = (
    log_dict['train/lr'], log_dict['train/loss'], log_dict['train/time'],
    log_dict['train/data_time'])

注解

损失、学习率、迭代时间等训练日志在执行器和钩子中自动更新,无需用户维护。

注解

消息枢纽的历史缓冲区字典对 key 没有特殊要求,但是 MMEngine 约定历史缓冲区字典的 key 要有 train/val/test 的前缀,只有带前缀的日志会被输出当终端。

更新/获取运行时信息

运行时信息以字典的形式存储在消息枢纽中,能够存储任意数据类型,每次更新都会被覆盖。

message_hub = MessageHub.get_instance('task')
message_hub.update_info('iter', 1)
message_hub.get_info('iter')  # 1
message_hub.update_info('iter', 2)
message_hub.get_info('iter')  # 2 覆盖上一次结果

消息枢纽的跨组件通讯

执行器运行过程中,各个组件会通过消息枢纽来分发、接受消息。RuntimeInfoHook 会汇总其他组件更新的学习率、损失等信息,将其导出到用户指定的输出端(Tensorboard,WandB 等)。由于上述流程较为复杂,这里用一个简单示例来模拟日志钩子和其他组件通讯的过程。

from mmengine import MessageHub

class LogProcessor:
    # 汇总不同模块更新的消息,类似 LoggerHook
    def __init__(self, name):
        self.message_hub = MessageHub.get_instance(name)  # 获取 MessageHub

    def run(self):
        print(f"Learning rate is {self.message_hub.get_scalar('train/lr').current()}")
        print(f"loss is {self.message_hub.get_scalar('train/loss').current()}")
        print(f"meta is {self.message_hub.get_info('meta')}")


class LrUpdater:
    # 更新学习率
    def __init__(self, name):
        self.message_hub = MessageHub.get_instance(name)  # 获取 MessageHub

    def run(self):
        self.message_hub.update_scalar('train/lr', 0.001)
        # 更新学习率,以 HistoryBuffer 形式存储


class MetaUpdater:
    # 更新元信息
    def __init__(self, name):
        self.message_hub = MessageHub.get_instance(name)

    def run(self):
        self.message_hub.update_info(
            'meta',
            dict(experiment='retinanet_r50_caffe_fpn_1x_coco.py',
                 repo='mmdetection'))    # 更新元信息,每次更新会覆盖上一次的信息


class LossUpdater:
    # 更新损失函数
    def __init__(self, name):
        self.message_hub = MessageHub.get_instance(name)

    def run(self):
        self.message_hub.update_scalar('train/loss', 0.1)

class ToyRunner:
    # 组合个各个模块
    def __init__(self, name):
        self.message_hub = MessageHub.get_instance(name)  # 创建 MessageHub
        self.log_processor = LogProcessor(name)
        self.updaters = [LossUpdater(name),
                         MetaUpdater(name),
                         LrUpdater(name)]

    def run(self):
        for updater in self.updaters:
            updater.run()
        self.log_processor.run()

if __name__ == '__main__':
    task = ToyRunner('name')
    task.run()
    # Learning rate is 0.001
    # loss is 0.1
    # meta {'experiment': 'retinanet_r50_caffe_fpn_1x_coco.py', 'repo': 'mmdetection'}

添加自定义日志

我们可以在任意模块里更新消息枢纽的历史缓冲区字典,历史缓冲区字典中所有的合法字段经统计后最后显示到终端。

注解

更新历史缓冲区字典时,需要保证更新的日志名带有 train,val,test 前缀,否则日志不会在终端显示。

class CustomModule:
    def __init__(self):
        self.message_hub = MessageHub.get_current_instance()

    def custom_method(self):
        self.message_hub.update_scalar('train/a', 100)
        self.message_hub.update_scalars({'train/b': 1, 'train/c': 2})

默认情况下,终端上额外显示 a、b、c 最后一次更新的结果。我们也可以通过配置日志处理器来切换自定义日志的统计方式。

日志处理器(LogProcessor)

用户可以通过配置日志处理器(LogProcessor)来控制日志的统计方法及其参数。默认配置下,日志处理器会统计最近一次更新的学习率、基于迭代次数平滑的损失和迭代时间。用户可以在日志处理器中配置已知字段的统计方式。

最简配置

log_processor = dict(
    window_size=10,
)

此时终端会输出每 10 次迭代的平均损失和平均迭代时间。假设此时终端的输出为

04/15 12:34:24 - mmengine - INFO - Iter [10/12]  , eta: 0:00:00, time: 0.003, data_time: 0.002, loss: 0.13

自定义的统计方式

我们可以通过配置 custom_cfg 列表来选择日志的统计方式。custom_cfg 中的每一个元素需要包括以下信息:

  • data_src:日志的数据源,用户通过指定 data_src 来选择需要被重新统计的日志,一份数据源可以有多种统计方式。默认的日志源包括模型输出的损失字典的 key、学习率(lr)和迭代时间(time/data_time),一切经消息枢纽的 update_scalar/update_scalars 更新的日志均为可以配置的数据源(需要去掉 train/val/ 前缀)。(必填项)

  • method_name:日志的统计方法,即历史缓冲区中的基本统计方法以及用户注册的自定义统计方法(必填项)

  • log_name:日志被重新统计后的名字,如果不定义 log_name,新日志会覆盖旧日志(选填项)

  • 其他参数:统计方法会用到的参数,其中 window_size 为特殊字段,可以为普通的整型、字符串 epoch 和字符串 global。LogProcessor 会实时解析这些参数,以返回基于 iteration、epoch 和全局平滑的统计结果(选填项)

  1. 覆盖旧的统计方式

log_processor = dict(
    window_size=10,
    by_epoch=True,
    custom_cfg=[
        dict(data_src='loss',
             method_name='mean',
             window_size=100)])

此时会无视日志处理器的默认窗口 10,用更大的窗口 100 去统计 loss 的均值,并将原有结果覆盖。

04/15 12:34:24 - mmengine - INFO - Iter [10/12]  , eta: 0:00:00, time: 0.003, data_time: 0.002, loss: 0.11
  1. 新增统计方式,不覆盖

log_processor = dict(
    window_size=10,
    by_epoch=True,
    custom_cfg=[
        dict(data_src='loss',
             log_name='loss_min',
             method_name='min',
             window_size=100)])
04/15 12:34:24 - mmengine - INFO - Iter [10/12]  , eta: 0:00:00, time: 0.003, data_time: 0.002, loss: 0.11, loss_min: 0.08

MMLogger

为了能够导出层次分明、格式统一、且不受三方库日志系统干扰的日志,MMEngine 在 logging 模块的基础上实现了 MMLoggerMMLogger 继承自全局管理器(ManagerMixin),相比于 logging.LoggerMMLogger 能够在无法获取 logger 的名字(logger name)的情况下,拿到当前执行器的 logger

创建 MMLogger

我们可以通过 get_instance 接口创建全局可获取的 logger,默认的日志格式如下

logger = MMLogger.get_instance('mmengine', log_level='INFO')
logger.info("this is a test")
# 04/15 14:01:11 - mmengine - INFO - this is a test

logger 除了输出消息外,还会额外输出时间戳、logger 的名字和日志等级。对于 ERROR 等级的日志,我们会用红色高亮日志等级,并额外输出错误日志的代码位置

logger = MMLogger.get_instance('mmengine', log_level='INFO')
logger.error('division by zero')
# 04/15 14:01:56 - mmengine - ERROR - /mnt/d/PythonCode/DeepLearning/OpenMMLab/mmengine/a.py - <module> - 4 - division by zero

导出日志

调用 get_instance 时,如果指定了 log_file,会将日志记录的信息以文本格式导出到本地。

logger = MMLogger.get_instance('mmengine', log_file='tmp.log', log_level='INFO')
logger.info("this is a test")
# 04/15 14:01:11 - mmengine - INFO - this is a test

tmp/tmp.log:

04/15 14:01:11 - mmengine - INFO - this is a test

由于分布式情况下会创建多个日志文件,因此我们在预定的导出路径下,增加一级和导出文件同名的目录,用于存储所有进程的日志。上例中导出路径为 tmp.log,实际存储路径为 tmp/tmp.log

分布式训练时导出日志

使用 pytorch 分布式训练时,我们可以通过配置 distributed=True 来导出分布式训练时各个进程的日志(默认关闭)。

logger = MMLogger.get_instance('mmengine', log_file='tmp.log', distributed=True, log_level='INFO')

单机多卡,或者多机多卡但是共享存储的情况下,导出的分布式日志路径如下

#  共享存储
./tmp
├── tmp.log
├── tmp_rank1.log
├── tmp_rank2.log
├── tmp_rank3.log
├── tmp_rank4.log
├── tmp_rank5.log
├── tmp_rank6.log
└── tmp_rank7.log
...
└── tmp_rank63.log

多机多卡,独立存储的情况:

# 独立存储
# 设备0:
work_dir/
└── exp_name_logs
    ├── exp_name.log
    ├── exp_name_rank1.log
    ├── exp_name_rank2.log
    ├── exp_name_rank3.log
    ...
    └── exp_name_rank7.log

# 设备7:
work_dir/
└── exp_name_logs
    ├── exp_name_rank56.log
    ├── exp_name_rank57.log
    ├── exp_name_rank58.log
    ...
    └── exp_name_rank63.log

迁移 MMCV 执行器到 MMEngine

简介

随着支持的深度学习任务越来越多,用户的需求不断增加,我们对 MMCV 已有的执行器(Runner)的灵活性和通用性有了更高的要求。 因此,MMEngine 在 MMCV 的基础上,实现了一个更加通用灵活的执行器以支持更多复杂的模型训练流程。 MMEngine 中的执行器扩大了作用域,也承担了更多的功能;我们抽象出了训练循环控制器(EpochBasedTrainLoop/IterBasedTrainLoop)验证循环控制器(ValLoop)测试循环控制器(TestLoop)来方便用户灵活拓展模型的执行流程。

我们将首先介绍算法库的执行入口该如何从 MMCV 迁移到 MMEngine, 以最大程度地简化和统一执行入口的代码。 然后我们将详细介绍在 MMCV 和 MMEngine 中构造执行器及其内部组件进行训练的差异。 在开始迁移前,我们建议用户先阅读执行器教程

执行入口

以 MMDet 为例,我们首先展示基于 MMEngine 重构前后,配置文件和训练启动脚本的区别:

配置文件的迁移

基于 MMCV 执行器的配置文件概览 基于 MMEngine 执行器的配置文件概览
# default_runtime.py
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]


opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
# default_runtime.py
default_scope = 'mmdet'

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False
# schedule.py

# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
# scheduler.py

# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
# coco_detection.py

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
# coco_detection.py

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'

file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    # If you don't have a gt annotation, delete the pipeline
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    format_only=False)
test_evaluator = val_evaluator

MMEngine 中的执行器提供了更多可自定义的部分,包括训练、验证、测试过程和数据加载器的配置,因此配置文件和之前相比会长一些。 为了方便用户的理解和阅读,我们遵循所见即所得的原则,重新调整了各个组件配置的层次,使得大部分一级字段都对应着执行器中关键属性的配置,例如数据加载器、评测器、钩子配置等。 这些配置在 OpenMMLab 2.0 算法库中都有默认配置,因此用户很多时候无需关心其中的大部分参数。

启动脚本的迁移

相比于 MMCV 的执行器,MMEngine 的执行器可以承担更多的功能,例如构建 DataLoader,构建分布式模型等。因此我们需要在配置文件中指定更多的参数,例如 DataLoadersamplerbatch_sampler,而无需在训练的启动脚本里实现构建 DataLoader 相关的代码。以 MMDet 的训练启动脚本为例:

基于 MMCV 执行器的训练启动脚本 基于 MMEngine 执行器的训练启动脚本
# tools/train.py

args = parse_args()

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)

if args.auto_scale_lr:
    if 'auto_scale_lr' in cfg and \
            'enable' in cfg.auto_scale_lr and \
            'base_batch_size' in cfg.auto_scale_lr:
        cfg.auto_scale_lr.enable = True
    else:
        warnings.warn('Can not find "auto_scale_lr" or '
                        '"auto_scale_lr.enable" or '
                        '"auto_scale_lr.base_batch_size" in your'
                        ' configuration file. Please update all the '
                        'configuration files to mmdet >= 2.24.1.')

# set multi-process settings
setup_multi_processes(cfg)

# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
    # update configs according to CLI args if args.work_dir is not None
    cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
    # use config filename as default work_dir if cfg.work_dir is None
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

if args.resume_from is not None:
    cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
    cfg.gpu_ids = range(1)
    warnings.warn('`--gpus` is deprecated because we only support '
                    'single GPU mode in non-distributed training. '
                    'Use `gpus=1` now.')
if args.gpu_ids is not None:
    cfg.gpu_ids = args.gpu_ids[0:1]
    warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                    'Because we only support single GPU mode in '
                    'non-distributed training. Use the first GPU '
                    'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
    cfg.gpu_ids = [args.gpu_id]

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
    distributed = False
else:
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)
    # re-set gpu_ids with distributed training mode
    _, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)

# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
            dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

cfg.device = get_device()
# set random seeds
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = build_detector(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
model.init_weights()

datasets = []
train_detector(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)
# tools/train.py

args = parse_args()

# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
    # update configs according to CLI args if args.work_dir is not None
    cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
    # use config filename as default work_dir if cfg.work_dir is None
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

# enable automatic-mixed-precision training
if args.amp is True:
    optim_wrapper = cfg.optim_wrapper.type
    if optim_wrapper == 'AmpOptimWrapper':
        print_log(
            'AMP training is already enabled in your config.',
            logger='current',
            level=logging.WARNING)
    else:
        assert optim_wrapper == 'OptimWrapper', (
            '`--amp` is only supported when the optimizer wrapper type is '
            f'`OptimWrapper` but got {optim_wrapper}.')
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.loss_scale = 'dynamic'

# enable automatically scaling LR
if args.auto_scale_lr:
    if 'auto_scale_lr' in cfg and \
            'enable' in cfg.auto_scale_lr and \
            'base_batch_size' in cfg.auto_scale_lr:
        cfg.auto_scale_lr.enable = True
    else:
        raise RuntimeError('Can not find "auto_scale_lr" or '
                            '"auto_scale_lr.enable" or '
                            '"auto_scale_lr.base_batch_size" in your'
                            ' configuration file.')

cfg.resume = args.resume

# build the runner from config
if 'runner_type' not in cfg:
    # build the default runner
    runner = Runner.from_cfg(cfg)
else:
    # build customized runner from the registry
    # if 'runner_type' is set in the cfg
    runner = RUNNERS.build(cfg)

# start training
runner.train()
def init_random_seed(...):
    ...

def set_random_seed(...):
    ...

# define function tools.
...


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):

    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)

        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation

        if val_dataloader_args['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
# `apis/train.py` is removed in `mmengine`

上表对比了基于 MMCV 执行器和 MMEngine 执行器 MMDet 启动脚本的区别。 OpenMMLab 1.x 中的算法库都实现了一套 runner 的构建和训练流程,其中存在着大量的冗余代码。因此,MMEngine 的执行器在内部实现了很多流程化的代码以统一各个算法库的执行流程,例如初始化随机种子、初始化分布式环境、构建 DataLoader 等,使得下游算法库从此无需在训练启动脚本里实现相关代码,只需配置执行器的构造参数,就能够执行相应的流程。 基于 MMEngine 执行器的启动脚本不仅简化了 tools/train.py 的代码,甚至可以直接删除 apis/train.py,极大程度的简化了训练启动脚本。同样的,我们在基于 MMEngine 开发自己的代码仓库时,可以通过配置执行器参数来设置随机种子、初始化分布式环境,无需自行实现相关代码。

迁移执行器(Runner)

本节主要介绍 MMCV 执行器和 MMEngine 执行器在训练、验证、测试流程上的区别。 在使用 MMCV 执行器和 MMEngine 执行器训练、测试模型时,以下流程有着明显的不同:

  1. 准备logger

  2. 设置随机种子

  3. 初始化环境变量

  4. 准备数据

  5. 准备模型

  6. 准备优化器

  7. 准备钩子

  8. 准备验证/测试模块

  9. 构建执行器

  10. 执行器加载检查点

  11. 开始训练开始测试

  12. 迁移自定义训练流程

后续的教程中,我们会对每个流程的差异进行详细介绍。

准备logger

MMCV 准备 logger

MMCV 需要在训练脚本里调用 get_logger 接口获得 logger,并用它输出、记录训练环境。

logger = get_logger(name='custom', log_file=log_file, log_level=cfg.log_level)
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
            dash_line)

执行器构造时,也需要传入 logger。

runner = Runner(
    ...
    logger=logger
    ...)

MMEngine 准备 logger

在执行器构建时传入 logger 的日志等级,执行器构建时会自动创建 logger,并输出、记录训练环境。

log_level = 'INFO'

设置随机种子

MMCV 设置随机种子

在训练脚本中手动设置随机种子:

...
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
...

MMEngine 设计随机种子

配置执行器的 randomness 参数,配置规则详见执行器 api 文档

OpenMMLab 系列算法库配置变更

MMCV 配置 MMEngine 配置
seed = 1
deterministic=False
diff_seed=False
randomness=dict(seed=1,
                deterministic=True,
                diff_rank_seed=False)

在本教程中,我们将 randomness 配置为:

randomness = dict(seed=5)

初始化训练环境

MMCV 初始化训练环境

MMCV 需要在训练脚本中配置多进程启动方式、多进程通信后端等环境变量,并在执行器构建之前初始化分布式环境,对模型进行分布式封装:

...
setup_multi_processes(cfg)
init_dist(cfg.launcher, **cfg.dist_params)
model = MMDistributedDataParallel(
    model,
    device_ids=[int(os.environ['LOCAL_RANK'])],
    broadcast_buffers=False,
    find_unused_parameters=find_unused_parameters)

MMEngine 初始化训练环境

MMEngine 通过配置 env_cfg 来选择多进程启动方式和多进程通信后端, 其默认值为 dict(dist_cfg=dict(backend='nccl')),配置方式详见执行器 api 文档

执行器构建时接受 launcher 参数,如果其值不为 'none',执行器构建时会自动执行分布式初始化,模型分布式封装。换句话说,使用 MMEngine 的执行器时,我们无需在执行器外做分布式相关的操作,只需配置 launcher 参数,选择训练的启动方式即可。

OpenMMLab 系列算法库配置变更

MMCV 配置 MMEngine 配置
launcher = 'pytorch'  # 开启分布式训练
dist_params = dict(backend='nccl')  # 选择多进程通信后端
launcher = 'pytorch'
env_cfg = dict(dist_cfg=dict(backend='nccl'))

在本教程中,我们将 env_cfg 配置为:

env_cfg = dict(dist_cfg=dict(backend='nccl'))

准备数据

MMCV 和 MMEngine 的执行器均可接受构建好的 DataLoader 实例。因此准备数据的流程没有差异:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(
    root='data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2)

val_dataset = CIFAR10(
    root='data', train=False, download=True, transform=transform)
val_dataloader = DataLoader(
    val_dataset, batch_size=128, shuffle=False, num_workers=2)

OpenMMLab 系列算法库配置变更

MMCV 配置 MMEngine 配置
data = dict(
    samples_per_gpu=2,  # 单卡 batch_size
    workers_per_gpu=2,  # Dataloader 采样进程数
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
train_dataloader = dict(
    batch_size=2, # samples_per_gpu -> batch_size,
    num_workers=2,
    # 遍历完 DataLoader 后,是否重启多进程采样
    persistent_workers=True,
    # 可配置的 sampler
    sampler=dict(type='DefaultSampler', shuffle=True),
    # 可配置的 batch_sampler
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1, # 验证阶段的 batch_size
    num_workers=2,
    persistent_workers=True,
    drop_last=False, # 是否丢弃最后一个 batch
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))

test_dataloader = val_dataloader

相比于 MMCV 的算法库配置,MMEngine 的配置更加复杂,但是也更加灵活。DataLoader 的配置过程由 Runner 负责,无需各个算法库实现。

准备模型

详见迁移 MMCV 模型至 MMEngine

import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel


class Model(BaseModel):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, img, label, mode):
        feat = self.pool(F.relu(self.conv1(img)))
        feat = self.pool(F.relu(self.conv2(feat)))
        feat = feat.view(-1, 16 * 5 * 5)
        feat = F.relu(self.fc1(feat))
        feat = F.relu(self.fc2(feat))
        feat = self.fc3(feat)
        if mode == 'loss':
            loss = self.loss_fn(feat, label)
            return dict(loss=loss)
        else:
            return [feat.argmax(1)]

model = Model()

需要注意的是,分布式训练时,MMCV 的执行器需要接受分布式封装后的模型,而 MMEngine 接受分布式封装前的模型,在执行器实例化阶段对其段进行分布式封装。

准备优化器

MMCV 准备优化器

MMCV 执行器构造时,可以直接接受 Pytorch 优化器,如

optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

对于复杂配置的优化器,MMCV 需要基于优化器构造器来构建优化器:


optimizer_cfg = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0))

def build_optimizer_constructor(cfg):
    constructor_type = cfg.get('type')
    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    else:
        raise KeyError(f'{constructor_type} is not registered '
                       'in the optimizer builder registry.')


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer

optimizer = build_optimizer(model, optimizer_cfg)

MMEngine 准备优化器

构建 MMEngine 执行器时,需要接受 optim_wrapper 参数,即优化器封装实例或者优化器封装配置,对于复杂配置的优化器封装,MMEngine 同样只需要配置 optim_wrapperoptim_wrapper 的详细介绍见执行器 api 文档

OpenMMLab 系列算法库配置变更

MMCV 配置 MMEngine 配置
optimizer = dict(
    constructor='CustomConstructor',
    type='AdamW',  # 优化器配置为一级字段
    lr=0.0001,  # 优化器配置为一级字段
    betas=(0.9, 0.999),  # 优化器配置为一级字段
    weight_decay=0.05,  # 优化器配置为一级字段
    paramwise_cfg={  # constructor 的参数
        'decay_rate': 0.95,
        'decay_type': 'layer_wise',
        'num_layers': 6
    })
# MMCV 还需要配置 `optim_config`
# 来构建优化器钩子,而 MMEngine 不需要
optimizer_config = dict(grad_clip=None)
optim_wrapper = dict(
    constructor='CustomConstructor',
    type='OptimWrapper',  # 指定优化器封装类型
    optimizer=dict(  # 将优化器配置集中在 optimizer 内
        type='AdamW',
        lr=0.0001,
        betas=(0.9, 0.999),
        weight_decay=0.05)
    paramwise_cfg={
        'decay_rate': 0.95,
        'decay_type': 'layer_wise',
        'num_layers': 6
    })
对于检测、分类一类的上层任务(High level)MMCV 需要配置 `optim_config` 来构建优化器钩子,而 MMEngine 不需要。

本教程使用的 optim_wrapper 如下:

from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
optim_wrapper = dict(optimizer=optimizer)

准备训练钩子

MMCV 准备训练钩子:

MMCV 常用训练钩子的配置如下:

# learning rate scheduler config
lr_config = dict(policy='step', step=[2, 3])
# configuration of optimizer
optimizer_config = dict(grad_clip=None)
# configuration of saving checkpoints periodically
checkpoint_config = dict(interval=1)
# save log periodically and multiple hooks can be used simultaneously
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
# register hooks to runner and those hooks will be invoked automatically
runner.register_training_hooks(
    lr_config=lr_config,
    optimizer_config=optimizer_config,
    checkpoint_config=checkpoint_config,
    log_config=log_config)

其中:

  • lr_config 用于配置 LrUpdaterHook

  • optimizer_config 用于配置 OptimizerHook

  • checkpoint_config 用于配置 CheckPointHook

  • log_config 用于配置 LoggerHook

除了上面提到的 4 个 Hook,MMCV 执行器自带 IterTimerHook。MMCV 需要先实例化执行器,再注册训练钩子,而 MMEngine 则在实例化阶段配置钩子。

MMEngine 准备训练钩子

MMEngine 执行器将 MMCV 常用的训练钩子配置成默认钩子:

对比上例中 MMCV 配置的训练钩子:

  • LrUpdaterHook 对应 MMEngine 中的 ParamSchedulerHook,二者对应关系详见迁移 scheduler 文档

  • MMEngine 在模型的 train_step 时更新参数,因此不需要配置优化器钩子(OptimizerHook

  • MMEngine 自带 CheckPointHook,可以使用默认配置

  • MMEngine 自带 LoggerHook,可以使用默认配置

因此我们只需要配置执行器优化器参数调整策略(param_scheduler),就能达到和 MMCV 示例一样的效果。 MMEngine 也支持注册自定义钩子,具体教程详见执行器教程迁移 hook 文档

MMCV 常用训练钩子 MMEngine 默认钩子
# MMCV 零散的配置训练钩子
# 配置 LrUpdaterHook,相当于 MMEngine 的参数调度器
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])

# 配置 OptimizerHook,MMEngine 不需要
optimizer_config = dict(grad_clip=None)

# 配置 LoggerHook
log_config = dict(  # LoggerHook
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

# 配置 CheckPointHook
checkpoint_config = dict(interval=1)  # CheckPointHook
# 配置参数调度器
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

# MMEngine 集中配置默认钩子
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

注解

MMEngine 移除了 OptimizerHook,优化步骤在 model 中执行。

本教程使用的 param_scheduler 如下:

from math import gamma

param_scheduler = dict(type='MultiStepLR', milestones=[2, 3], gamma=0.1)

准备验证模块

MMCV 借助 EvalHook 实现验证流程,受限于篇幅,这里不做进一步展开。MMEngine 通过验证循环控制器(ValLoop)评测器(Evaluator)实现执行流程,如果我们想基于自定义的评价指标完成验证流程,则需要定义一个 Metric,并将其注册至 METRICS 注册器:

import torch
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS

@METRICS.register_module(force=True)
class ToyAccuracyMetric(BaseMetric):

    def process(self, label, pred) -> None:
        self.results.append((label[1], pred, len(label[1])))

    def compute_metrics(self, results: list) -> dict:
        num_sample = 0
        acc = 0
        for label, pred, batch_size in results:
            acc += (label == torch.stack(pred)).sum()
            num_sample += batch_size
        return dict(Accuracy=acc / num_sample)

实现自定义 Metric 后,我们还需在执行器的构造参数中配置评测器和验证循环控制器,本教程中示例配置如下:

val_evaluator = dict(type='ToyAccuracyMetric')
val_cfg = dict(type='ValLoop')
MMCV 配置验证流程 MMEngine 配置验证流程
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook  # 配置 EvalHook
runner.register_hook(
    eval_hook(val_dataloader, **eval_cfg), priority='LOW')  # 注册 EvalHook
val_dataloader = val_dataloader  # 配置验证数据
val_evaluator = dict(type='ToyAccuracyMetric')  # 配置评测器
val_cfg = dict(type='ValLoop')  # 配置验证循环控制器

构建执行器

MMCV 构建执行器

runner = EpochBasedRunner(
    model=model,
    optimizer=optimizer,
    work_dir=work_dir,
    logger=logger,
    max_epochs=4
)

MMEngine 构建执行器

MMEngine 执行器的作用域比 MMCV 更广,将设置随机种子、启动分布式训练等流程参数化。除了前几节提到的参数,上例中出现的EpochBasedRunnermax_epochsval_iterval 现在由 train_cfg 配置:

  • by_epoch: True 时相当于 MMCV 的 EpochBasedRunner``,False 时相当于 IterBasedRunner

  • max_epoch/max_iters: 同 MMCV 执行器的配置

  • val_iterval: 同 EvalHookinterval 参数

train_cfg 实际上是训练循环控制器的构造参数,当 by_epoch=True 时,使用 EpochBasedTrainLoop

from mmengine.runner import Runner

runner = Runner(
    model=model,  # 待优化的模型
    work_dir='./work_dir',  # 配置工作目录
    randomness=randomness,  # 配置随机种子
    env_cfg=env_cfg,  # 配置环境变量
    launcher='none',  # 分布式训练启动方式
    optim_wrapper=optim_wrapper,  # 配置优化器
    param_scheduler=param_scheduler,  # 配置学习率调度器
    train_dataloader=train_dataloader,  # 配置训练数据
    train_cfg=dict(by_epoch=True, max_epochs=4, val_interval=1),  # 配置训练循环控制器
    val_dataloader=val_dataloader,  # 配置验证数据
    val_evaluator=val_evaluator,  # 配置评测器
    val_cfg=val_cfg)  # 配置验证循环控制器

执行器加载检查点

MMCV 加载检查点

在训练之前执行加载权重、恢复训练的流程。

if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)

MMEngine 加载检查点

runner = Runner(
    ...
    load_from='/path/to/checkpoint',
    resume=True
)
MMCV 加载检查点配置 MMEngine 加载检查点配置
load_from = 'path/to/ckpt'
load_from = 'path/to/ckpt'
resume = False
resume_from = 'path/to/ckpt'
load_from = 'path/to/ckpt'
resume = True

执行器训练流程

MMCV 执行器训练流程

在训练之前执行加载权重、恢复训练的流程。然后再执行 runner.run,并传入训练数据。

if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)

MMEngine 执行器训练流程

在执行器构建时配置加载权重、恢复训练参数

由于 MMEngine 的执行器在构造阶段就传入了训练数据,因此在调用 runner.train() 无需传入参数。

runner.train()

执行器测试流程

MMCV 的执行器没有测试功能,因此需要自行实现测试脚本。MMEngine 的执行器只需要在构建时配置 test_dataloadertest_cfgtest_evaluator,然后再调用 runner.test() 就能完成测试流程。

work_dir 和训练时一致,无需手动加载 checkpoint:

runner = Runner(
    model=model,
    work_dir='./work_dir',
    randomness=randomness,
    env_cfg=env_cfg,
    launcher='none',  # 不开启分布式训练
    optim_wrapper=optim_wrapper,
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_evaluator=val_evaluator,
    val_cfg=val_cfg,
    test_dataloader=val_dataloader,  # 假设测试和验证使用相同的数据和评测器
    test_evaluator=val_evaluator,
    test_cfg=dict(type='TestLoop'),
)
runner.test()

work_dir 和训练时不一致,需要额外指定 load_from:

runner = Runner(
    model=model,
    work_dir='./test_work_dir',
    load_from='./work_dir/epoch_5.pth',  # work_dir 不一致,指定 load_from,以加载指定的模型
    randomness=randomness,
    env_cfg=env_cfg,
    launcher='none',
    optim_wrapper=optim_wrapper,
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_evaluator=val_evaluator,
    val_cfg=val_cfg,
    test_dataloader=val_dataloader,
    test_evaluator=val_evaluator,
    test_cfg=dict(type='TestLoop'),
)
runner.test()

迁移自定义执行流程

使用 MMCV 执行器时,我们可能会重载 runner.train/runner.val 或者 runner.run_iter 实现自定义的训练、测试流程。以重载 runner.train 为例,假设我们想对每个批次的图片训练两遍,我们可以这样重载 MMCV 的执行器:

class CustomRunner(EpochBasedRunner):
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            for _ in range(2)
                self.call_hook('before_train_iter')
                self.run_iter(data_batch, train_mode=True, **kwargs)
                self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

在 MMEngine 中,要实现上述功能,我们需要重载一个新的循环控制器

from mmengine.registry import LOOPS
from mmengine.runner import EpochBasedTrainLoop


@LOOPS.register_module()
class CustomEpochBasedTrainLoop(EpochBasedTrainLoop):
    def run_iter(self, idx, data_batch) -> None:
        for _ in range(2):
            super().run_iter(idx, data_batch)

在构建执行器时,指定 train_cfgtypeCustomEpochBasedTrainLoop。需要注意的是,by_epochtype 不能同时配置,当配置 by_epoch 时,会推断训练循环控制器的类型为 EpochBasedTrainLoop

runner = Runner(
    model=model,
    work_dir='./test_work_dir',
    randomness=randomness,
    env_cfg=env_cfg,
    launcher='none',
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_dataloader=train_dataloader,
    train_cfg=dict(
        type='CustomEpochBasedTrainLoop',
        max_epochs=5,
        val_interval=1),
    val_dataloader=val_dataloader,
    val_evaluator=val_evaluator,
    val_cfg=val_cfg,
    test_dataloader=val_dataloader,
    test_evaluator=val_evaluator,
    test_cfg=dict(type='TestLoop'),
)
runner.train()

如果有更加复杂的执行器迁移需求,可以参考执行器教程执行器设计文档

迁移 MMCV 钩子到 MMEngine

简介

由于架构设计的更新和用户需求的不断增加,MMCV 的钩子(Hook)点位已经满足不了需求,因此在 MMEngine 中对钩子点位进行了重新设计以及对钩子的功能做了调整。在开始迁移前,阅读钩子的设计会很有帮助。

本文对比 MMCV v1.6.0MMEngine v0.5.0 的钩子在功能、点位、用法和实现上的差异。

功能差异

MMCV MMEngine
反向传播以及梯度更新 OptimizerHook 将反向传播以及梯度更新的操作抽象成 OptimWrapper 而不是钩子
GradientCumulativeOptimizerHook
学习率调整 LrUpdaterHook ParamSchdulerHook 以及 _ParamScheduler 的子类完成优化器超参的调整
动量调整 MomentumUpdaterHook
按指定间隔保存权重 CheckpointHook CheckpointHook 除了保存权重,还有保存最优权重的功能,而 EvalHook 的模型评估功能则交由 ValLoop 或 TestLoop 完成
模型评估并保存最优模型 EvalHook
打印日志 LoggerHook 及其子类实现打印日志、保存日志以及可视化功能 LoggerHook
可视化 NaiveVisualizationHook
添加运行时信息 RuntimeInfoHook
模型参数指数滑动平均 EMAHook EMAHook
确保分布式 Sampler 的 shuffle 生效 DistSamplerSeedHook DistSamplerSeedHook
同步模型的 buffer SyncBufferHook SyncBufferHook
PyTorch CUDA 缓存清理 EmptyCacheHook EmptyCacheHook
统计迭代耗时 IterTimerHook IterTimerHook
分析训练时间的瓶颈 ProfilerHook 暂未提供
提供注册方法给钩子点位的功能 ClosureHook 暂未提供

点位差异

MMCV MMEngine
全局位点 执行前 before_run before_run
执行后 after_run after_run
Checkpoint 相关 加载 checkpoint 后 after_load_checkpoint
保存 checkpoint 前 before_save_checkpoint
训练相关 训练前触发 before_train
训练后触发 after_train
每个 epoch 前 before_train_epoch before_train_epoch
每个 epoch 后 after_train_epoch after_train_epoch
每次迭代前 before_train_iter before_train_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_train_iter after_train_iter,新增 batch_idx、data_batch 和 outputs 参数
验证相关 验证前触发 before_val
验证后触发 after_val
每个 epoch 前 before_val_epoch before_val_epoch
每个 epoch 后 after_val_epoch after_val_epoch
每次迭代前 before_val_iter before_val_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_val_iter after_val_iter,新增 batch_idx、data_batch 和 outputs 参数
测试相关 测试前触发 before_test
测试后触发 after_test
每个 epoch 前 before_test_epoch
每个 epoch 后 after_test_epoch
每次迭代前 before_test_iter,新增 batch_idx 和 data_batch 参数
每次迭代后 after_test_iter,新增 batch_idx、data_batch 和 outputs 参数

用法差异

在 MMCV 中,将钩子注册到执行器(Runner),需调用执行器的 register_training_hooks 方法往执行器注册钩子,而在 MMEngine 中,可以通过参数传递给执行器的初始化方法进行注册。

  • MMCV

model = ResNet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
lr_config = dict(policy='step', step=[2, 3])
optimizer_config = dict(grad_clip=None)
checkpoint_config = dict(interval=5)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
runner = EpochBasedRunner(
    model=model,
    optimizer=optimizer,
    work_dir='./work_dir',
    max_epochs=3,
    xxx,
)
runner.register_training_hooks(
    lr_config=lr_config,
    optimizer_config=optimizer_config,
    checkpoint_config=checkpoint_config,
    log_config=log_config,
    custom_hooks_config=custom_hooks,
)
runner.run([trainloader], [('train', 1)])
  • MMEngine

model=ResNet18()
optim_wrapper=dict(
    type='OptimizerWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9))
param_scheduler = dict(type='MultiStepLR', milestones=[2, 3]),
default_hooks = dict(
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=5),
)
custom_hooks = [dict(type='NumClassCheckHook')]
runner = Runner(
    model=model,
    work_dir='./work_dir',
    optim_wrapper=optim_wrapper,
    param_scheduler=param_scheduler,
    train_cfg=dict(by_epoch=True, max_epochs=3),
    default_hooks=default_hooks,
    custom_hooks=custom_hooks,
    xxx,
)
runner.train()

MMEngine 钩子的更多用法请参考钩子的用法

实现差异

CheckpointHook 为例,MMEngine 的 CheckpointHook 相比 MMCV 的 CheckpointHook(新增保存最优权重的功能,在 MMCV 中,保存最优权重的功能由 EvalHook 提供),因此,它需要实现 after_val_epoch 点位。

  • MMCV

class CheckpointHook(Hook):
    def before_run(self, runner):
        """初始化 out_dir 和 file_client 属性"""

    def after_train_epoch(self, runner):
        """同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""

    def after_train_iter(self, runner):
        """同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""
  • MMEngine

class CheckpointHook(Hook):
    def before_run(self, runner):
        """初始化 out_dir 和 file_client 属性"""

    def after_train_epoch(self, runner):
        """同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""

    def after_val_epoch(self, runner, metrics):
        """根据 metrics 保存最优权重"""

迁移 MMCV 模型到 MMEngine

简介

MMCV 早期支持的计算机视觉任务,例如目标检测、物体识别等,都采用了一种典型的模型参数优化流程,可以被归纳为以下四个步骤:

  1. 计算损失

  2. 计算梯度

  3. 更新参数

  4. 梯度清零

上述流程的一大特点就是调用位置统一(在训练迭代后调用)、执行步骤统一(依次执行步骤 1->2->3->4),非常契合钩子(Hook)的设计原则,因此这类任务通常会使用 Hook 来优化模型。MMCV 为此实现了一系列的 Hook,例如 OptimizerHook(单精度训练)、Fp16OptimizerHook(混合精度训练) 和 GradientCumulativeFp16OptimizerHook(混合精度训练 + 梯度累加),为这类任务提供各种优化策略。

一些例如生成对抗网络(GAN),自监督(Self-supervision)等领域的算法一般有更加灵活的训练流程,这类流程并不满足调用位置统一、执行步骤统一的原则,难以使用 Hook 对参数进行优化。为了支持训练这类任务,MMCV 的执行器会在调用 model.train_step 时,额外传入 optimizer 参数,让模型在 train_step 里实现自定义的优化流程。这样虽然可以支持训练这类任务,但也会导致无法使用各种 OptimizerHook,需要算法在 train_step 中实现混合精度训练、梯度累加等训练策略。

为了统一深度学习任务的参数优化流程,MMEngine 设计了优化器封装,集成了混合精度训练、梯度累加等训练策略,各类深度学习任务一律在 model.train_step 里执行参数优化流程。

优化流程的迁移

常用的参数更新流程

考虑到目标检测、物体识别一类的深度学习任务参数优化的流程基本一致,我们可以通过继承模型基类来完成迁移。

基于 MMCV 执行器的模型

在介绍如何迁移模型之前,我们先来看一个基于 MMCV 执行器训练模型的最简示例:

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmcv.runner import Runner
from mmcv.utils.logging import get_logger


train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class MMCVToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, return_loss=False):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        loss = (loss1 + loss2).sum()
        return dict(loss=loss,
                    num_samples=len(img),
                    log_vars=dict(
                        loss1=loss1.sum().item(),
                        loss2=loss2.sum().item()))

    def train_step(self, data, optimizer=None):
        return self(*data, return_loss=True)

    def val_step(self, data, optimizer=None):
        return self(*data, return_loss=False)


model = MMCVToyModel()
optimizer = SGD(model.parameters(), lr=0.01)
logger = get_logger('demo')

lr_config = dict(policy='step', step=[2, 3])
optimizer_config = dict(grad_clip=None)
log_config = dict(interval=10, hooks=[dict(type='TextLoggerHook')])


runner = Runner(
    model=model,
    work_dir='tmp_dir',
    optimizer=optimizer,
    logger=logger,
    max_epochs=5)

runner.register_training_hooks(
    lr_config=lr_config,
    optimizer_config=optimizer_config,
    log_config=log_config)
runner.run([train_dataloader], [('train', 1)])

基于 MMCV 执行器训练模型时,我们必须实现 train_step 接口,并返回一个字典,字典需要包含以下三个字段:

  • loss:传给 OptimizerHook 计算梯度

  • num_samples:传给 LogBuffer,用于计算平滑后的损失

  • log_vars:传给 LogBuffer 用于计算平滑后的损失

基于 MMEngine 执行器的模型

基于 MMEngine 的执行器,实现同样逻辑的代码:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class MMEngineToyModel(BaseModel):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        # 被 `train_step` 调用,返回用于更新参数的损失字典
        if mode == 'loss':
            loss1 = (feat - label).pow(2)
            loss2 = (feat - label).abs()
            return dict(loss1=loss1, loss2=loss2)
        # 被 `val_step` 调用,返回传给 `evaluator` 的预测结果
        elif mode == 'predict':
            return [_feat for _feat in feat]
        # tensor 模式,功能详见模型教程文档: tutorials/model.md
        else:
            pass


runner = Runner(
    model=MMEngineToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=5),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()

MMEngine 实现了模型基类,模型基类在 train_step 里实现了 OptimizerHook 的优化流程。因此上例中,我们无需实现 train_step,运行时直接调用基类的 train_step

MMCV 模型 MMEngine 模型
class MMCVToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, return_loss=False):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        loss = (loss1 + loss2).sum()
        return dict(loss=loss,
                    num_samples=len(img),
                    log_vars=dict(
                        loss1=loss1.sum().item(),
                        loss2=loss2.sum().item()))

    def train_step(self, data, optimizer=None):
        return self(*data, return_loss=True)

    def val_step(self, data, optimizer=None):
        return self(*data, return_loss=False)
class MMEngineToyModel(BaseModel):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        if mode == 'loss':
            feat = self.linear(img)
            loss1 = (feat - label).pow(2)
            loss2 = (feat - label).abs()
            return dict(loss1=loss1, loss2=loss2)
        elif mode == 'predict':
            return [_feat for _feat in feat]
        else:
            pass

    # 模型基类 `train_step` 等效代码
    # def train_step(self, data, optim_wrapper):
    #     data = self.data_preprocessor(data)
    #     loss_dict = self(*data, mode='loss')
    #     loss_dict['loss1'] = loss_dict['loss1'].sum()
    #     loss_dict['loss2'] = loss_dict['loss2'].sum()
    #     loss = (loss_dict['loss1'] + loss_dict['loss2']).sum()
    #     调用优化器封装更新模型参数
    #     optim_wrapper.update_params(loss)
    #     return loss_dict

关于等效代码中的数据处理器(data_preprocessor)优化器封装(optim_wrapper) 的说明,详见模型教程优化器封装教程

模型具体差异如下:

  • MMCVToyModel 继承自 nn.Module,而 MMEngineToyModel 继承自 BaseModel

  • MMCVToyModel 必须实现 train_step,且必须返回损失字典,损失字典包含 losslog_varsnum_samples 字段。MMEngineToyModel 继承自 BaseModel,只需要实现 forward 接口,并返回损失字典,损失字典的每一个值必须是可微的张量

  • MMCVToyModelMMEngineModelforward 的接口需要匹配 train_step 中的调用方式,由于 MMEngineToyModel 直接调用基类的 train_step 方法,因此 forward 需要接受参数 mode,具体规则详见模型教程文档

自定义的参数更新流程

以训练生成对抗网络为例,生成器和判别器的优化需要交替进行,且优化流程可能会随着迭代次数的增多发生变化,因此很难使用 OptimizerHook 来满足这种需求。在基于 MMCV 训练生成对抗网络时,通常会在模型的 train_step 接口中传入 optimizer,然后在 train_step 里实现自定义的参数更新逻辑。这种训练流程和 MMEngine 非常相似,只不过 MMEngine 在 train_step 接口中传入优化器封装,能够更加简单地优化模型。

参考训练生成对抗网络,MMCV 和 MMEngine 的对比实现如下:

Training gan in MMCV Training gan in MMEngine
    def train_discriminator(self, inputs, optimizer):
        real_imgs = inputs['inputs']
        z = torch.randn(
            (real_imgs.shape[0], self.noise_size)).type_as(real_imgs)
        with torch.no_grad():
            fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        parsed_losses.backward()
        optimizer.step()
        optimizer.zero_grad()
        return log_vars

    def train_generator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(inputs['inputs'].shape[0], self.noise_size).type_as(
            real_imgs)

        fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        parsed_losses.backward()
        optimizer.step()
        optimizer.zero_grad()
        return log_vars
    def train_discriminator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(
            (real_imgs.shape[0], self.noise_size)).type_as(real_imgs)
        with torch.no_grad():
            fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars



    def train_generator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(real_imgs.shape[0], self.noise_size).type_as(real_imgs)

        fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

二者的区别主要在于优化器的使用方式。此外,train_step 接口返回值的差异和上一节提到的一致。

验证/测试流程的迁移

基于 MMCV 执行器实现的模型通常不需要为验证、测试流程提供独立的 val_steptest_step(测试流程由 EvalHook 实现,这里不做展开)。基于 MMEngine 执行器实现的模型则有所不同,ValLoopTestLoop 会分别调用模型的 val_steptest_step 接口,输出会进一步传给 Evaluator.process。因此模型的 val_steptest_step 接口输出需要和 Evaluator.process 的入参(第一个参数)对齐,即返回列表(推荐,也可以是其他可迭代类型)类型的结果。列表中的每一个元素代表一个批次(batch)数据中每个样本的预测结果。模型的 test_stepval_step 会调 forward 接口(详见模型教程文档),因此在上一节的模型示例中,模型 forwardpredict 模式会将 feat 切片后,以列表的形式返回预测结果。


class MMEngineToyModel(BaseModel):

    ...
    def forward(self, img, label, mode):
        if mode == 'loss':
            ...
        elif mode == 'predict':
            # 把一个 batch 的预测结果切片成列表,每个元素代表一个样本的预测结果
            return [_feat for _feat in feat]
        else:
            ...
            # tensor 模式,功能详见模型教程文档: tutorials/model.md

迁移分布式训练

MMCV 需要在执行器构建之前,使用 MMDistributedDataParallel 对模型进行分布式封装。MMEngine 实现了 MMDistributedDataParallelMMSeparateDistributedDataParallel 两种分布式模型封装,供不同类型的任务选择。执行器会在构建时对模型进行分布式封装。

  1. 常用训练流程

    对于简介中提到的常用优化流程的训练任务,即一次参数更新可以被拆解成梯度计算、参数优化、梯度清零的任务,使用 Runner 默认的 MMDistributedDataParallel 即可满足需求,无需为 runner 其他额外参数。

    MMCV 分布式训练构建模型 MMEngine 分布式训练
    model = MMDistributedDataParallel(
        model,
        device_ids=[int(os.environ['LOCAL_RANK'])],
        broadcast_buffers=False,
        find_unused_parameters=find_unused_parameters)
    ...
    runner = Runner(model=model, ...)
    
    runner = Runner(
        model=model,
        launcher='pytorch', #开启分布式训练
        ..., # 其他参数
    )
    

     

    1. 以自定义流程分模块优化模型的学习任务

      同样以训练生成对抗网络为例,生成对抗网络有两个需要分别优化的子模块,即生成器和判别器。因此需要使用 MMSeparateDistributedDataParallel 对模型进行封装。我们需要在构建执行器时指定:

      cfg = dict(model_wrapper_cfg='MMSeparateDistributedDataParallel')
      runner = Runner(
          model=model,
          ...,
          launcher='pytorch',
          cfg=cfg)
      

      即可进行分布式训练。

     

    1. 以自定义流程优化整个模型的深度学习任务

      有时候我们需要用自定义的优化流程来优化单个模块,这时候我们就不能复用模型基类的 train_step,而需要重新实现,例如我们想用同一批图片对模型优化两次,第一次开启批数据增强,第二次关闭:

      class CustomModel(BaseModel):
      
          def train_step(self, data, optim_wrapper):
              data = self.data_preprocessor(data, training=True)  # 开启批数据增强
              loss = self(data, mode='loss')
              optim_wrapper.update_params(loss)
              data = self.data_preprocessor(data, training=False)  # 关闭批数据增强
              loss = self(data, mode='loss')
              optim_wrapper.update_params(loss)
      

      要想启用分布式训练,我们就需要重载 MMSeparateDistributedDataParallel,并在 train_step 中实现和 CustomModel.train_step 相同的流程(test_stepval_step 同理)。

      class CustomDistributedDataParallel(MMSeparateDistributedDataParallel):
      
          def train_step(self, data, optim_wrapper):
              data = self.data_preprocessor(data, training=True)  # 开启批数据增强
              loss = self(data, mode='loss')
              optim_wrapper.update_params(loss)
              data = self.data_preprocessor(data, training=False)  # 关闭批数据增强
              loss = self(data, mode='loss')
              optim_wrapper.update_params(loss)
      

      最后在构建 runner 时指定:

      # 指定封装类型为 `CustomDistributedDataParallel`,并基于默认参数封装模型。
      cfg = dict(model_wrapper_cfg=dict(type='CustomDistributedDataParallel'))
      runner = Runner(
          model=model,
          ...,
          launcher='pytorch',
          cfg=cfg
      )
      

迁移 MMCV 参数调度器到 MMEngine

MMCV 1.x 版本使用 LrUpdaterHookMomentumUpdaterHook 来调整学习率和动量。 但随着深度学习算法训练方式的不断发展,使用 Hook 修改学习率已经难以满足更加丰富的自定义需求,因此 MMEngine 提供了参数调度器(ParamScheduler)。 一方面,参数调度器的接口与 PyTroch 的学习率调度器(LRScheduler)对齐,另一方面,参数调度器提供了更丰富的功能,详细请参考参数调度器使用指南

学习率调度器(LrUpdater)迁移

MMEngine 中使用 LRScheduler 替代 LrUpdaterHook,配置文件中的字段从原本的 lr_config 修改为 param_scheduler。 MMCV 中的学习率配置与 MMEngine 中的参数调度器配置对应关系如下:

学习率预热(Warmup)迁移

由于 MMEngine 中的学习率调度器在实现时增加了 begin 和 end 参数,指定了调度器的生效区间,所以可以通过调度器组合的方式实现学习率预热。MMCV 中有 3 种学习率预热方式,分别是 'constant', 'linear', 'exp',在 MMEngine 中对应的配置应修改为:

常数预热(constant)
MMCV-1.x MMEngine
lr_config = dict(
    warmup='constant',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
param_scheduler = [
    dict(type='ConstantLR',
         factor=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # 主学习率调度器配置
]
线性预热(linear)
MMCV-1.x MMEngine
lr_config = dict(
    warmup='linear',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
param_scheduler = [
    dict(type='LinearLR',
         start_factor=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # 主学习率调度器配置
]
指数预热(exp)
MMCV-1.x MMEngine
lr_config = dict(
    warmup='exp',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
param_scheduler = [
    dict(type='ExponentialLR',
         gamma=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # 主学习率调度器配置
]

fixed 学习率(FixedLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(policy='fixed')
param_scheduler = [
    dict(type='ConstantLR', factor=1)
]

step 学习率(StepLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(
    policy='step',
    step=[8, 11],
    gamma=0.1,
    by_epoch=True
)
param_scheduler = [
    dict(type='MultiStepLR',
         milestone=[8, 11],
         gamma=0.1,
         by_epoch=True)
]

poly 学习率(PolyLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(
    policy='poly',
    power=0.7,
    min_lr=0.001,
    by_epoch=True
)
param_scheduler = [
    dict(type='PolyLR',
         power=0.7,
         eta_min=0.001,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]

exp 学习率(ExpLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(
    policy='exp',
    power=0.5,
    by_epoch=True
)
param_scheduler = [
    dict(type='ExponentialLR',
         gamma=0.5,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]

CosineAnnealing 学习率(CosineAnnealingLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0.5,
    by_epoch=True
)
param_scheduler = [
    dict(type='CosineAnnealingLR',
         eta_min=0.5,
         T_max=num_epochs,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]

FlatCosineAnnealing 学习率(FlatCosineAnnealingLrUpdaterHook)迁移

像 FlatCosineAnnealing 这种由多个学习率策略拼接而成的学习率,原本需要重写 Hook 来实现,而在 MMEngine 中只需将两个参数调度器组合即可

MMCV-1.x MMEngine
lr_config = dict(
    policy='FlatCosineAnnealing',
    start_percent=0.5,
    min_lr=0.005,
    by_epoch=True
)
param_scheduler = [
    dict(type='ConstantLR', factor=1, begin=0, end=num_epochs * 0.75)
    dict(type='CosineAnnealingLR',
         eta_min=0.005,
         begin=num_epochs * 0.75,
         end=num_epochs,
         T_max=num_epochs * 0.25,
         by_epoch=True)
]

CosineRestart 学习率(CosineRestartLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(policy='CosineRestart',
                 periods=[5, 10, 15],
                 restart_weights=[1, 0.7, 0.3],
                 min_lr=0.001,
                 by_epoch=True)
param_scheduler = [
    dict(type='CosineRestartLR',
         periods=[5, 10, 15],
         restart_weights=[1, 0.7, 0.3],
         eta_min=0.001,
         by_epoch=True)
]

OneCycle 学习率(OneCycleLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(policy='OneCycle',
                 max_lr=0.02,
                 total_steps=90000,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25,
                 final_div_factor=1e4,
                 three_phase=True,
                 by_epoch=False)
param_scheduler = [
    dict(type='OneCycleLR',
         eta_max=0.02,
         total_steps=90000,
         pct_start=0.3,
         anneal_strategy='cos',
         div_factor=25,
         final_div_factor=1e4,
         three_phase=True,
         by_epoch=False)
]

需要注意的是 by_epoch 参数 MMCV 默认是 False, MMEngine 默认是 True

LinearAnnealing 学习率(LinearAnnealingLrUpdaterHook)迁移

MMCV-1.x MMEngine
lr_config = dict(
    policy='LinearAnnealing',
    min_lr_ratio=0.01,
    by_epoch=True
)
param_scheduler = [
    dict(type='LinearLR',
         start_factor=1,
         end_factor=0.01,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]

动量调度器(MomentumUpdater)迁移

MMCV 使用 momentum_config 字段和 MomentumUpdateHook 调整动量。 MMEngine 中动量同样由参数调度器控制。用户可以简单将学习率调度器后的 LR 修改为 Momentum,即可使用同样的策略来调整动量。动量调度器只需要和学习率调度器一样添加进 param_scheduler 列表中即可。举一个简单的例子:

MMCV-1.x MMEngine
lr_config = dict(...)
momentum_config = dict(
    policy='CosineAnnealing',
    min_momentum=0.1,
    by_epoch=True
)
param_scheduler = [
    # 学习率调度器配置
    dict(...),
    # 动量调度器配置
    dict(type='CosineAnnealingMomentum',
         eta_min=0.1,
         T_max=num_epochs,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]

参数更新频率相关配置迁移

如果在使用 epoch-based 训练循环且配置文件中按 epoch 设置生效区间(beginend)或周期(T_max)等变量的同时希望参数率按 iteration 更新,在 MMCV 中需要将 by_epoch 设置为 False。而在 MMEngine 中需要注意,配置中的 by_epoch 仍需设置为 True,通过在配置中添加 convert_to_iter_based=True 来构建按 iteration 更新的参数调度器,关于此配置详见参数调度器教程。 以迁移CosineAnnealing为例:

MMCV-1.x MMEngine
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0.5,
    by_epoch=False
)
param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        eta_min=0.5,
        T_max=num_epochs,
        by_epoch=True,  # 注意,by_epoch 需要设置为 True
        convert_to_iter_based=True  # 转换为按 iter 更新参数
    )
]

你可能还想阅读参数调度器的教程或者参数调度器的 API 文档

数据变换类的迁移

简介

在 TorchVision 的数据变换类接口约定中,数据变换类需要实现 __call__ 方法,而在 OpenMMLab 1.0 的接口约定中,进一步要求 __call__ 方法的输出应当是一个字典,在各种数据变换中对这个字典进行增删查改。在 OpenMMLab 2.0 中,为了提升后续的可扩展性,我们将原先的 __call__ 方法迁移为 transform 方法,并要求数据变换类应当继承 mmcv.transforms.BaseTransfrom。具体如何实现一个数据变换类,可以参见文档

由于在此次更新中,我们将部分共用的数据变换类统一迁移至 MMCV 中,因此本文将会对比这些数据变换在旧版本(MMClassification v0.23.2MMDetection v2.25.1)和新版本(MMCV v2.0.0rc0)中的功能、用法和实现上的差异。

功能差异

MMClassification (旧) MMDetection (旧) MMCV (新)
LoadImageFromFile 从 'img_prefix' 和 'img_info.filename' 字段组合获得文件路径并读取 从 'img_prefix' 和 'img_info.filename' 字段组合获得文件路径并读取,支持指定通道顺序 从 'img_path' 获得文件路径并读取,支持指定加载失败不报错,支持指定解码后端
LoadAnnotations 支持读取 bbox,label,mask(包括多边形样式),seg map,转换 bbox 坐标系 支持读取 bbox,label,mask(不包括多边形样式),seg map
Pad 填充 "img_fields" 中所有字段,不支持指定填充至整数倍 填充 "img_fields" 中所有字段,支持指定填充至整数倍 填充 "img" 字段,支持指定填充至整数倍
CenterCrop 裁切 "img_fields" 中所有字段,支持以 EfficientNet 方式进行裁切 裁切 "img" 字段的图像,"gt_bboxes" 字段的 bbox,"gt_seg_map" 字段的分割图,"gt_keypoints" 字段的关键点,支持自动填充裁切边缘
Normalize 图像归一化 无差异 无差异,但 MMEngine 推荐在数据预处理器中进行归一化
Resize 缩放 "img_fields" 中所有字段,允许指定根据某边长等比例缩放 功能由 Resize 实现。需要 ratio_range 为 None,img_scale 仅指定一个尺寸,且 multiscale_mode 为 "value" 。 缩放 "img" 字段的图像,"gt_bboxes" 字段的 bbox,"gt_seg_map" 字段的分割图,"gt_keypoints" 字段的关键点,支持指定缩放比例,支持等比例缩放图像至指定尺寸内
RandomResize 功能由 Resize 实现。需要 ratio_range 为 None,img_scale指定两个尺寸,且 multiscale_mode 为 "range",或 ratio_range 不为 None。
Resize(
    img_sacle=[(640, 480), (960, 720)],
    mode="range",
)
缩放功能同 Resize,支持从指定尺寸范围或指定比例范围随机采样缩放尺寸。
RandomResize(scale=[(640, 480), (960, 720)])
RandomChoiceResize 功能由 Resize 实现。需要 ratio_range 为 None,img_scale 指定多个尺寸,且 multiscale_mode 为 "value"。
Resize(
    img_sacle=[(640, 480), (960, 720)],
    mode="value",
)
缩放功能同 Resize,支持从若干指定尺寸中随机选择缩放尺寸。
RandomChoiceResize(scales=[(640, 480), (960, 720)])
RandomGrayscale 灰度化 "img_fields" 中所有字段,灰度化后保持通道数。 灰度化 "img" 字段,支持指定灰度化权重,支持指定是否在灰度化后保持通道数(默认不保持)。
RandomFlip 翻转 "img_fields" 中所有字段,支持指定水平或垂直翻转。 翻转 "img_fields", "bbox_fields", "mask_fields", "seg_fields" 中所有字段,支持指定水平、垂直或对角翻转,支持指定各类翻转概率。 翻转 "img", "gt_bboxes", "gt_seg_map", "gt_keypoints" 字段,支持指定水平、垂直或对角翻转,支持指定各类翻转概率。
MultiScaleFlipAug 用于测试时增强 使用 TestTimeAug
ToTensor 将指定字段转换为 torch.Tensor 无差异 无差异
ImageToTensor 将指定字段转换为 torch.Tensor,并调整通道顺序至 CHW。 无差异 无差异

实现差异

RandomFlip 为例,MMCV 的 RandomFlip 相比旧版 MMDetection 的 RandomFlip,需要继承 BaseTransfrom,将功能实现放在 transforms 方法,并将生成随机结果的部分放在单独的方法中,用 cache_randomness 包装。有关随机方法的包装相关功能,参见相关文档

  • MMDetection (旧)

class RandomFlip:
    def __call__(self, results):
        """调用时进行随机翻转"""
        ...
        # 随机选择翻转方向
        cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
        ...
        return results
  • MMCV

class RandomFlip(BaseTransfrom):
    def transform(self, results):
        """调用时进行随机翻转"""
        ...
        cur_dir = self._random_direction()
        ...
        return results

    @cache_randomness
    def _random_direction(self):
        """随机选择翻转方向"""
        ...
        return np.random.choice(direction_list, p=flip_ratio_list)

mmengine.registry

Registry

A registry to map strings to classes or functions.

DefaultScope

Scope of current task used to reset the current registry, which can be accessed globally.

build_from_cfg

Build a module from config dict when it is a class configuration, or call a function from config dict when it is a function configuration.

build_model_from_cfg

Build a PyTorch model from config dict(s).

build_runner_from_cfg

Build a Runner object.

build_scheduler_from_cfg

Builds a ParamScheduler instance from config.

count_registered_modules

Scan all modules in MMEngine’s root and child registries and dump to json.

traverse_registry_tree

Traverse the whole registry tree from any given node, and collect information of all registered modules in this registry tree.

init_default_scope

Initialize the given default scope.

mmengine.config

Config

A facility for config and config files.

ConfigDict

A dictionary for config which has the same interface as python’s built- in dictionary and can be used as a normal dictionary.

DictAction

argparse action to split an argument into KEY=VALUE form on the first = and append to a dictionary.

mmengine.runner

Runner

Runner

A training helper for PyTorch.

Loop

BaseLoop

Base loop class.

EpochBasedTrainLoop

Loop for epoch-based training.

IterBasedTrainLoop

Loop for iter-based training.

ValLoop

Loop for validation.

TestLoop

Loop for test.

Checkpoints

CheckpointLoader

A general checkpoint loader to manage all schemes.

find_latest_checkpoint

Find the latest checkpoint from the given path.

get_deprecated_model_names

get_external_models

get_mmcls_models

get_state_dict

Returns a dictionary containing a whole state of the module.

get_torchvision_models

load_checkpoint

Load checkpoint from a file or URI.

load_state_dict

Load state_dict to a module.

save_checkpoint

Save checkpoint to file.

weights_to_cpu

Copy a model state_dict to cpu.

AMP

autocast

A wrapper of torch.autocast and toch.cuda.amp.autocast.

Miscellaneous

LogProcessor

A log processor used to format log information collected from runner.message_hub.log_scalars.

Priority

Hook priority levels.

get_priority

Get priority value.

mmengine.hooks

Hook

Base hook class.

CheckpointHook

Save checkpoints periodically.

EMAHook

A Hook to apply Exponential Moving Average (EMA) on the model during training.

LoggerHook

Collect logs from different components of Runner and write them to terminal, JSON file, tensorboard and wandb .etc.

NaiveVisualizationHook

Show or Write the predicted results during the process of testing.

ParamSchedulerHook

A hook to update some hyper-parameters in optimizer, e.g., learning rate and momentum.

RuntimeInfoHook

A hook that updates runtime information into message hub.

DistSamplerSeedHook

Data-loading sampler for distributed training.

IterTimerHook

A hook that logs the time spent during iteration.

SyncBuffersHook

Synchronize model buffers such as running_mean and running_var in BN at the end of each epoch.

EmptyCacheHook

Releases all unoccupied cached GPU memory during the process of training.

ProfilerHook

A hook to analyze performance during training and inference.

PrepareTTAHook

Wraps runner.model with subclass of BaseTTAModel in before_test.

mmengine.model

Module

BaseModule

Base module for all modules in openmmlab.

ModuleDict

ModuleDict in openmmlab.

ModuleList

ModuleList in openmmlab.

Sequential

Sequential module in openmmlab.

Model

BaseModel

Base class for all algorithmic models.

BaseDataPreprocessor

Base data pre-processor used for copying data to the target device.

ImgDataPreprocessor

Image pre-processor for normalization and bgr to rgb conversion.

BaseTTAModel

Base model for inference with test-time augmentation.

EMA

BaseAveragedModel

A base class for averaging model weights.

ExponentialMovingAverage

Implements the exponential moving average (EMA) of the model.

MomentumAnnealingEMA

Exponential moving average (EMA) with momentum annealing strategy.

StochasticWeightAverage

Implements the stochastic weight averaging (SWA) of the model.

Model Wrapper

MMDistributedDataParallel

A distributed model wrapper used for training,testing and validation in loop.

MMSeparateDistributedDataParallel

A DistributedDataParallel wrapper for models in MMGeneration.

MMFullyShardedDataParallel

A wrapper for sharding Module parameters across data parallel workers.

is_model_wrapper

Check if a module is a model wrapper.

Weight Initialization

BaseInit

Caffe2XavierInit

ConstantInit

Initialize module parameters with constant values.

KaimingInit

Initialize module parameters with the values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K.

NormalInit

Initialize module parameters with the values drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\).

PretrainedInit

Initialize module by loading a pretrained model.

TruncNormalInit

Initialize module parameters with the values drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\) with values outside \([a, b]\).

UniformInit

Initialize module parameters with values drawn from the uniform distribution \(\mathcal{U}(a, b)\).

XavierInit

Initialize module parameters with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks - Glorot, X.

bias_init_with_prob

initialize conv/fc bias value according to a given probability value.

caffe2_xavier_init

constant_init

initialize

Initialize a module.

kaiming_init

normal_init

trunc_normal_init

uniform_init

update_init_info

Update the _params_init_info in the module if the value of parameters are changed.

xavier_init

Utils

detect_anomalous_params

merge_dict

Merge all dictionaries into one dictionary.

stack_batch

Stack multiple tensors to form a batch and pad the tensor to the max shape use the right bottom padding mode in these images.

revert_sync_batchnorm

Helper function to convert all SyncBatchNorm (SyncBN) and mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to `BatchNormXd layers.

convert_sync_batchnorm

Helper function to convert all BatchNorm layers in the model to SyncBatchNorm (SyncBN) or `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers. Adapted from <https://pytorch.org/docs/stable/generated/torch.nn.Sy ncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm>_.

mmengine.optim

mmengine.optim

Optimizer

AmpOptimWrapper

A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp.

OptimWrapper

Optimizer wrapper provides a common interface for updating parameters.

OptimWrapperDict

A dictionary container of OptimWrapper.

DefaultOptimWrapperConstructor

Default constructor for optimizers.

build_optim_wrapper

Build function of OptimWrapper.

Scheduler

_ParamScheduler

Base class for parameter schedulers.

ConstantLR

Decays the learning rate value of each parameter group by a small constant factor until the number of epoch reaches a pre-defined milestone: end.

ConstantMomentum

Decays the momentum value of each parameter group by a small constant factor until the number of epoch reaches a pre-defined milestone: end.

ConstantParamScheduler

Decays the parameter value of each parameter group by a small constant factor until the number of epoch reaches a pre-defined milestone: end.

CosineAnnealingLR

Set the learning rate of each parameter group using a cosine annealing schedule, where \(\eta_{max}\) is set to the initial value and \(T_{cur}\) is the number of epochs since the last restart in SGDR:

CosineAnnealingMomentum

Set the momentum of each parameter group using a cosine annealing schedule, where \(\eta_{max}\) is set to the initial value and \(T_{cur}\) is the number of epochs since the last restart in SGDR:

CosineAnnealingParamScheduler

Set the parameter value of each parameter group using a cosine annealing schedule, where \(\eta_{max}\) is set to the initial value and \(T_{cur}\) is the number of epochs since the last restart in SGDR:

ExponentialLR

Decays the learning rate of each parameter group by gamma every epoch.

ExponentialMomentum

Decays the momentum of each parameter group by gamma every epoch.

ExponentialParamScheduler

Decays the parameter value of each parameter group by gamma every epoch.

LinearLR

Decays the learning rate of each parameter group by linearly changing small multiplicative factor until the number of epoch reaches a pre-defined milestone: end.

LinearMomentum

Decays the momentum of each parameter group by linearly changing small multiplicative factor until the number of epoch reaches a pre-defined milestone: end.

LinearParamScheduler

Decays the parameter value of each parameter group by linearly changing small multiplicative factor until the number of epoch reaches a pre-defined milestone: end.

MultiStepLR

Decays the specified learning rate in each parameter group by gamma once the number of epoch reaches one of the milestones.

MultiStepMomentum

Decays the specified momentum in each parameter group by gamma once the number of epoch reaches one of the milestones.

MultiStepParamScheduler

Decays the specified parameter in each parameter group by gamma once the number of epoch reaches one of the milestones.

OneCycleLR

Sets the learning rate of each parameter group according to the 1cycle learning rate policy.

OneCycleParamScheduler

Sets the parameters of each parameter group according to the 1cycle learning rate policy.

PolyLR

Decays the learning rate of each parameter group in a polynomial decay scheme.

PolyMomentum

Decays the momentum of each parameter group in a polynomial decay scheme.

PolyParamScheduler

Decays the parameter value of each parameter group in a polynomial decay scheme.

StepLR

Decays the learning rate of each parameter group by gamma every step_size epochs.

StepMomentum

Decays the momentum of each parameter group by gamma every step_size epochs.

StepParamScheduler

Decays the parameter value of each parameter group by gamma every step_size epochs.

mmengine.evaluator

mmengine.evaluator

Evaluator

Evaluator

Wrapper class to compose multiple BaseMetric instances.

Metric

BaseMetric

Base class for a metric.

DumpResults

Dump model predictions to a pickle file for offline evaluation.

Utils

get_metric_value

Get the metric value specified by an indicator, which can be either a metric name or a full name with evaluator prefix.

mmengine.structures

BaseDataElement

A base data interface that supports Tensor-like and dict-like operations.

InstanceData

Data structure for instance-level annotations or predictions.

LabelData

Data structure for label-level annotations or predictions.

PixelData

Data structure for pixel-level annotations or predictions.

mmengine.dataset

Dataset

BaseDataset

BaseDataset for open source projects in OpenMMLab.

Compose

Compose multiple transforms sequentially.

Dataset Wrapper

ClassBalancedDataset

A wrapper of class balanced dataset.

ConcatDataset

A wrapper of concatenated dataset.

RepeatDataset

A wrapper of repeated dataset.

Sampler

DefaultSampler

The default data sampler for both distributed and non-distributed environment.

InfiniteSampler

It’s designed for iteration-based runner and yields a mini-batch indices each time.

Utils

default_collate

Convert list of data sampled from dataset into a batch of data, of which type consistent with the type of each data_itement in data_batch.

pseudo_collate

Convert list of data sampled from dataset into a batch of data, of which type consistent with the type of each data_itement in data_batch.

worker_init_fn

This function will be called on each worker subprocess after seeding and before data loading.

mmengine.device

get_device

Returns the currently existing device type.

get_max_cuda_memory

Returns the maximum GPU memory occupied by tensors in megabytes (MB) for a given device.

is_cuda_available

Returns True if cuda devices exist.

is_npu_available

Returns True if Ascend PyTorch and npu devices exist.

is_mlu_available

Returns True if Cambricon PyTorch and mlu devices exist.

is_mps_available

Return True if mps devices exist.

mmengine.hub

get_config

Get config from external package.

get_model

Get built model from external package.

mmengine.logging

MMLogger

Formatted logger used to record messages.

MessageHub

Message hub for component interaction.

HistoryBuffer

Unified storage format for different log types.

print_log

Print a log message.

mmengine.visualization

mmengine.visualization

Visualizer

Visualizer

MMEngine provides a Visualizer class that uses the Matplotlib library as the backend.

visualization Backend

BaseVisBackend

Base class for visualization backend.

LocalVisBackend

Local visualization backend class.

TensorboardVisBackend

Tensorboard visualization backend class.

WandbVisBackend

Wandb visualization backend class.

mmengine.fileio

File Backend

BaseStorageBackend

Abstract class of storage backends.

FileClient

A general file client to access files in different backends.

HardDiskBackend

Raw hard disks storage backend.

LocalBackend

Raw local storage backend.

HTTPBackend

HTTP and HTTPS storage bachend.

LmdbBackend

Lmdb storage backend.

MemcachedBackend

Memcached storage backend.

PetrelBackend

Petrel storage backend (for internal usage).

register_backend

Register a backend.

File IO

dump

Dump data to json/yaml/pickle strings or files.

load

Load data from json/yaml/pickle files.

copy_if_symlink_fails

Create a symbolic link pointing to src named dst.

copyfile

Copy a file src to dst and return the destination file.

copyfile_from_local

Copy a local file src to dst and return the destination file.

copyfile_to_local

Copy the file src to local dst and return the destination file.

copytree

Recursively copy an entire directory tree rooted at src to a directory named dst and return the destination directory.

copytree_from_local

Recursively copy an entire directory tree rooted at src to a directory named dst and return the destination directory.

copytree_to_local

Recursively copy an entire directory tree rooted at src to a local directory named dst and return the destination directory.

exists

Check whether a file path exists.

generate_presigned_url

Generate the presigned url of video stream which can be passed to mmcv.VideoReader.

get

Read bytes from a given filepath with ‘rb’ mode.

get_file_backend

Return a file backend based on the prefix of uri or backend_args.

get_local_path

Download data from filepath and write the data to local path.

get_text

Read text from a given filepath with ‘r’ mode.

isdir

Check whether a file path is a directory.

isfile

Check whether a file path is a file.

join_path

Concatenate all file paths.

list_dir_or_file

Scan a directory to find the interested directories or files in arbitrary order.

put

Write bytes to a given filepath with ‘wb’ mode.

put_text

Write text to a given filepath with ‘w’ mode.

remove

Remove a file.

rmtree

Recursively delete a directory tree.

Parse File

dict_from_file

Load a text file and parse the content as a dict.

list_from_file

Load a text file and parse the content as a list of strings.

mmengine.dist

mmengine.dist

dist

gather

Gather data from the whole group to dst process.

gather_object

Gathers picklable objects from the whole group in a single process.

all_gather

Gather data from the whole group in a list.

all_gather_object

Gather picklable objects from the whole group into a list.

all_reduce

Reduces the tensor data across all machines in such a way that all get the final result.

all_reduce_dict

Reduces the dict across all machines in such a way that all get the final result.

all_reduce_params

All-reduce parameters.

broadcast

Broadcast the data from src process to the whole group.

sync_random_seed

Synchronize a random seed to all processes.

broadcast_object_list

Broadcasts picklable objects in object_list to the whole group.

collect_results

Collected results in distributed environments.

collect_results_cpu

Collect results under cpu mode.

collect_results_gpu

Collect results under gpu mode.

utils

get_dist_info

Get distributed information of the given process group.

init_dist

Initialize distributed environment.

init_local_group

Setup the local process group.

get_backend

Return the backend of the given process group.

get_world_size

Return the number of the given process group.

get_rank

Return the rank of the given process group.

get_local_size

Return the number of the current node.

get_local_rank

Return the rank of current process in the current node.

is_main_process

Whether the current rank of the given process group is equal to 0.

master_only

Decorate those methods which should be executed in master process.

barrier

Synchronize all processes from the given process group.

is_distributed

Return True if distributed environment has been initialized.

get_local_group

Return local process group.

get_default_group

Return default process group.

get_data_device

Return the device of data.

get_comm_device

Return the device for communication among groups.

cast_data_device

Recursively convert Tensor in data to device.

mmengine.utils

Manager

ManagerMeta

The metaclass for global accessible class.

ManagerMixin

ManagerMixin is the base class for classes that have global access requirements.

Path

check_file_exist

fopen

is_abs

Check if path is an absolute path in different backends.

is_filepath

mkdir_or_exist

scandir

Scan a directory to find the interested files.

symlink

Package

call_command

install_package

get_installed_path

Get installed path of package.

is_installed

Check package whether installed.

Version

digit_version

Convert a version string into a tuple of integers.

get_git_hash

Get the git hash of the current repo.

Progress Bar

ProgressBar

A progress bar which can print the progress.

track_iter_progress

Track the progress of tasks iteration or enumeration with a progress bar.

track_parallel_progress

Track the progress of parallel task execution with a progress bar.

track_progress

Track the progress of tasks execution with a progress bar.

Miscellaneous

Timer

A flexible Timer class.

TimerError

is_list_of

Check whether it is a list of some type.

is_tuple_of

Check whether it is a tuple of some type.

is_seq_of

Check whether it is a sequence of some type.

is_str

Whether the input is an string instance.

iter_cast

Cast elements of an iterable object into some type.

list_cast

Cast elements of an iterable object into a list of some type.

tuple_cast

Cast elements of an iterable object into a tuple of some type.

concat_list

Concatenate a list of list into a single list.

slice_list

Slice a list into several sub lists by a list of given length.

to_1tuple

to_2tuple

to_3tuple

to_4tuple

to_ntuple

check_prerequisites

A decorator factory to check if prerequisites are satisfied.

deprecated_api_warning

A decorator to check if some arguments are deprecate and try to replace deprecate src_arg_name to dst_arg_name.

deprecated_function

Marks functions as deprecated.

has_method

Check whether the object has a method.

is_method_overridden

Check if a method of base class is overridden in derived class.

import_modules_from_strings

Import modules from the given list of strings.

requires_executable

A decorator to check if some executable files are installed.

requires_package

A decorator to check if some python packages are installed.

check_time

Add check points in a single line.

mmengine.utils.dl_utils

TimeCounter

A tool that counts the average running time of a function or a method.

collect_env

Collect the information of the running environments.

load_url

Loads the Torch serialized object at the given URL.

has_batch_norm

Detect whether model has a BatchNormalization layer.

is_norm

Check if a layer is a normalization layer.

mmcv_full_available

Check whether mmcv-full is installed.

tensor2imgs

Convert tensor to 3-channel images or 1-channel gray images.

TORCH_VERSION

A string with magic powers to compare to both Version and iterables! Prior to 1.10.0 torch.__version__ was stored as a str and so many did comparisons against torch.__version__ as if it were a str.

set_multi_processing

Set multi-processing related environment.

torch_meshgrid

A wrapper of torch.meshgrid to compat different PyTorch versions.

is_jit_tracing

Changelog of v0.x

v0.2.0 (11/10/2022)

New Features & Enhancements

  • Add SMDDP backend and support running on AWS by @austinmw in https://github.com/open-mmlab/mmengine/pull/579

  • Refactor FileIO but without breaking bc by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/533

  • Add test time augmentation base model by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/538

  • Use torch.lerp\_() to speed up EMA by @RangiLyu in https://github.com/open-mmlab/mmengine/pull/519

  • Support converting BN to SyncBN by config by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/506

  • Support defining metric name in wandb backend by @okotaku in https://github.com/open-mmlab/mmengine/pull/509

  • Add dockerfile by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/347

Docs

  • Fix API files of English documentation by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/525

  • Fix typo in instance_data.py by @Dai-Wenxun in https://github.com/open-mmlab/mmengine/pull/530

  • Fix the docstring of the model sub-package by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/573

  • Fix a spelling error in docs/zh_cn by @cxiang26 in https://github.com/open-mmlab/mmengine/pull/548

  • Fix typo in docstring by @MengzhangLI in https://github.com/open-mmlab/mmengine/pull/527

  • Update config.md by @Zhengfei-0311 in https://github.com/open-mmlab/mmengine/pull/562

Bug Fixes

  • Fix LogProcessor does not smooth loss if the name of loss doesn’t start with loss by @liuyanyi in https://github.com/open-mmlab/mmengine/pull/539

  • Fix failed to enable detect_anomalous_params in MMSeparateDistributedDataParallel by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/588

  • Fix CheckpointHook behavior unexpected if given filename_tmpl argument by @C1rN09 in https://github.com/open-mmlab/mmengine/pull/518

  • Fix error argument sequence in FSDP by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/520

  • Fix uploading image in wandb backend @okotaku in https://github.com/open-mmlab/mmengine/pull/510

  • Fix loading state dictionary in EMAHook by @okotaku in https://github.com/open-mmlab/mmengine/pull/507

  • Fix circle import in EMAHook by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/523

  • Fix unit test could fail caused by MultiProcessTestCase by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/535

  • Remove unnecessary “if statement” in Registry by @MambaWong in https://github.com/open-mmlab/mmengine/pull/536

  • Fix _save_to_state_dict by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/542

  • Support comparing NumPy array dataset meta in Runner.resume by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/511

  • Use get instead of pop to dump runner_type in build_runner_from_cfg by @nijkah in https://github.com/open-mmlab/mmengine/pull/549

  • Upgrade pre-commit hooks by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/576

  • Delete the error comment in registry.md by @vansin in https://github.com/open-mmlab/mmengine/pull/514

  • Fix Some out-of-date unit tests by @C1rN09 in https://github.com/open-mmlab/mmengine/pull/586

  • Fix typo in MMFullyShardedDataParallel by @yhna940 in https://github.com/open-mmlab/mmengine/pull/569

  • Update Github Action CI and CircleCI by @zhouzaida in https://github.com/open-mmlab/mmengine/pull/512

  • Fix unit test in windows by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/515

  • Fix merge ci & multiprocessing unit test by @HAOCHENYE in https://github.com/open-mmlab/mmengine/pull/529

New Contributors

  • @okotaku made their first contribution in https://github.com/open-mmlab/mmengine/pull/510

  • @MengzhangLI made their first contribution in https://github.com/open-mmlab/mmengine/pull/527

  • @MambaWong made their first contribution in https://github.com/open-mmlab/mmengine/pull/536

  • @cxiang26 made their first contribution in https://github.com/open-mmlab/mmengine/pull/548

  • @nijkah made their first contribution in https://github.com/open-mmlab/mmengine/pull/549

  • @Zhengfei-0311 made their first contribution in https://github.com/open-mmlab/mmengine/pull/562

  • @austinmw made their first contribution in https://github.com/open-mmlab/mmengine/pull/579

  • @yhna940 made their first contribution in https://github.com/open-mmlab/mmengine/pull/569

  • @liuyanyi made their first contribution in https://github.com/open-mmlab/mmengine/pull/539

贡献代码

欢迎加入 MMEngine 社区,我们致力于打造最前沿的深度学习模型训练的基础库,我们欢迎任何类型的贡献,包括但不限于

修复错误

修复代码实现错误的步骤如下:

  1. 如果提交的代码改动较大,建议先提交 issue,并正确描述 issue 的现象、原因和复现方式,讨论后确认修复方案。

  2. 修复错误并补充相应的单元测试,提交拉取请求。

新增功能或组件

  1. 如果新功能或模块涉及较大的代码改动,建议先提交 issue,确认功能的必要性。

  2. 实现新增功能并添单元测试,提交拉取请求。

文档补充

修复文档可以直接提交拉取请求

添加文档或将文档翻译成其他语言步骤如下

  1. 提交 issue,确认添加文档的必要性。

  2. 添加文档,提交拉取请求。

拉取请求工作流

如果你对拉取请求不了解,没关系,接下来的内容将会从零开始,一步一步地指引你如何创建一个拉取请求。如果你想深入了解拉取请求的开发模式,可以参考 github 官方文档

1. 复刻仓库

当你第一次提交拉取请求时,先复刻 OpenMMLab 原代码库,点击 GitHub 页面右上角的 Fork 按钮,复刻后的代码库将会出现在你的 GitHub 个人主页下。

将代码克隆到本地

git clone git@github.com:{username}/mmengine.git

添加原代码库为上游代码库

git remote add upstream git@github.com:open-mmlab/mmengine

检查 remote 是否添加成功,在终端输入 git remote -v

origin	git@github.com:{username}/mmengine.git (fetch)
origin	git@github.com:{username}/mmengine.git (push)
upstream	git@github.com:open-mmlab/mmengine (fetch)
upstream	git@github.com:open-mmlab/mmengine (push)

注解

这里对 origin 和 upstream 进行一个简单的介绍,当我们使用 git clone 来克隆代码时,会默认创建一个 origin 的 remote,它指向我们克隆的代码库地址,而 upstream 则是我们自己添加的,用来指向原始代码库地址。当然如果你不喜欢他叫 upstream,也可以自己修改,比如叫 open-mmlab。我们通常向 origin 提交代码(即 fork 下来的远程仓库),然后向 upstream 提交一个 pull request。如果提交的代码和最新的代码发生冲突,再从 upstream 拉取最新的代码,和本地分支解决冲突,再提交到 origin。

2. 配置 pre-commit

在本地开发环境中,我们使用 pre-commit 来检查代码风格,以确保代码风格的统一。在提交代码,需要先安装 pre-commit(需要在 mmengine 目录下执行):

pip install -U pre-commit
pre-commit install

检查 pre-commit 是否配置成功,并安装 .pre-commit-config.yaml 中的钩子:

pre-commit run --all-files

注解

如果你是中国用户,由于网络原因,可能会出现安装失败的情况,这时可以使用国内源 pre-commit install -c .pre-commit-config-zh-cn.yaml pre-commit run –all-files -c .pre-commit-config-zh-cn.yaml

如果安装过程被中断,可以重复执行 pre-commit run ... 继续安装。

如果提交的代码不符合代码风格规范,pre-commit 会发出警告,并自动修复部分错误。

如果我们想临时绕开 pre-commit 的检查提交一次代码,可以在 git commit 时加上 --no-verify(需要保证最后推送至远程仓库的代码能够通过 pre-commit 检查)。

git commit -m "xxx" --no-verify

3. 创建开发分支

安装完 pre-commit 之后,我们需要基于 master 创建开发分支,建议的分支命名规则为 username/pr_name

git checkout -b yhc/refactor_contributing_doc

在后续的开发中,如果本地仓库的 master 分支落后于 upstream 的 master 分支,我们需要先拉取 upstream 的代码进行同步,再执行上面的命令

git pull upstream master

4. 提交代码并在本地通过单元测试

  • MMEngine 引入了 mypy 来做静态类型检查,以增加代码的鲁棒性。因此我们在提交代码时,需要补充 Type Hints。具体规则可以参考教程

  • 提交的代码同样需要通过单元测试

    # 通过全量单元测试
    pytest tests
    
    # 我们需要保证提交的代码能够通过修改模块的单元测试,以 runner 为例
    pytest tests/test_runner/test_runner.py
    

    如果你由于缺少依赖无法运行修改模块的单元测试,可以参考指引-单元测试

  • 如果修改/添加了文档,参考指引确认文档渲染正常。

5. 推送代码到远程

代码通过单元测试和 pre-commit 检查后,将代码推送到远程仓库,如果是第一次推送,可以在 git push 后加上 -u 参数以关联远程分支

git push -u origin {branch_name}

这样下次就可以直接使用 git push 命令推送代码了,而无需指定分支和远程仓库。

6. 提交拉取请求(PR)

(1) 在 GitHub 的 Pull request 界面创建拉取请求

(2) 根据指引修改 PR 描述,以便于其他开发者更好地理解你的修改

描述规范详见拉取请求规范

 

注意事项

(a) PR 描述应该包含修改理由、修改内容以及修改后带来的影响,并关联相关 Issue(具体方式见文档

(b) 如果是第一次为 OpenMMLab 做贡献,需要签署 CLA

(c) 检查提交的 PR 是否通过 CI(集成测试)

MMEngine 会在不同的平台(Linux、Window、Mac),基于不同版本的 Python、PyTorch、CUDA 对提交的代码进行单元测试,以保证代码的正确性,如果有任何一个没有通过,我们可点击上图中的 Details 来查看具体的测试信息,以便于我们修改代码。

(3) 如果 PR 通过了 CI,那么就可以等待其他开发者的 review,并根据 reviewer 的意见,修改代码,并重复 4-5 步骤,直到 reviewer 同意合入 PR。

所有 reviewer 同意合入 PR 后,我们会尽快将 PR 合并到主分支。

7. 解决冲突

随着时间的推移,我们的代码库会不断更新,这时候,如果你的 PR 与主分支存在冲突,你需要解决冲突,解决冲突的方式有两种:

git fetch --all --prune
git rebase upstream/master

或者

git fetch --all --prune
git merge upstream/master

如果你非常善于处理冲突,那么可以使用 rebase 的方式来解决冲突,因为这能够保证你的 commit log 的整洁。如果你不太熟悉 rebase 的使用,那么可以使用 merge 的方式来解决冲突。

指引

单元测试

在提交修复代码错误或新增特性的拉取请求时,我们应该尽可能的让单元测试覆盖所有提交的代码,计算单元测试覆盖率的方法如下

python -m coverage run -m pytest /path/to/test_file
python -m coverage html
# check file in htmlcov/index.html

文档渲染

在提交修复代码错误或新增特性的拉取请求时,可能会需要修改/新增模块的 docstring。我们需要确认渲染后的文档样式是正确的。 本地生成渲染后的文档的方法如下

pip install -r requirements/docs.txt
cd docs/zh_cn/
# or docs/en
make html
# check file in ./docs/zh_cn/_build/html/index.html

Python 代码风格

PEP8 作为 OpenMMLab 算法库首选的代码规范,我们使用以下工具检查和格式化代码

  • flake8: Python 官方发布的代码规范检查工具,是多个检查工具的封装

  • isort: 自动调整模块导入顺序的工具

  • yapf: Google 发布的代码规范检查工具

  • codespell: 检查单词拼写是否有误

  • mdformat: 检查 markdown 文件的工具

  • docformatter: 格式化 docstring 的工具

yapf 和 isort 的配置可以在 setup.cfg 找到

通过配置 pre-commit hook ,我们可以在提交代码时自动检查和格式化 flake8yapfisorttrailing whitespacesmarkdown files,修复 end-of-filesdouble-quoted-stringspython-encoding-pragmamixed-line-ending,调整 requirments.txt 的包顺序。 pre-commit 钩子的配置可以在 .pre-commit-config 找到。

pre-commit 具体的安装使用方式见拉取请求

更具体的规范请参考 OpenMMLab 代码规范

拉取请求规范

  1. 使用 pre-commit hook,尽量减少代码风格相关问题

  2. 一个拉取请求对应一个短期分支

  3. 粒度要细,一个拉取请求只做一件事情,避免超大的拉取请求

    • Bad:实现 Faster R-CNN

    • Acceptable:给 Faster R-CNN 添加一个 box head

    • Good:给 box head 增加一个参数来支持自定义的 conv 层数

  4. 每次 Commit 时需要提供清晰且有意义 commit 信息

  5. 提供清晰且有意义的拉取请求描述

    • 标题写明白任务名称,一般格式:[Prefix] Short description of the pull request (Suffix)

    • prefix: 新增功能 [Feature], 修 bug [Fix], 文档相关 [Docs], 开发中 [WIP] (暂时不会被review)

    • 描述里介绍拉取请求的主要修改内容,结果,以及对其他部分的影响, 参考拉取请求模板

    • 关联相关的议题 (issue) 和其他拉取请求

  6. 如果引入了其他三方库,或借鉴了三方库的代码,请确认他们的许可证和 MMEngine 兼容,并在借鉴的代码上补充 This code is inspired from http://

代码规范

代码规范标准

PEP 8 —— Python 官方代码规范

Python 官方的代码风格指南,包含了以下几个方面的内容:

  • 代码布局,介绍了 Python 中空行、断行以及导入相关的代码风格规范。比如一个常见的问题:当我的代码较长,无法在一行写下时,何处可以断行?

  • 表达式,介绍了 Python 中表达式空格相关的一些风格规范。

  • 尾随逗号相关的规范。当列表较长,无法一行写下而写成如下逐行列表时,推荐在末项后加逗号,从而便于追加选项、版本控制等。

    # Correct:
    FILES = ['setup.cfg', 'tox.ini']
    # Correct:
    FILES = [
        'setup.cfg',
        'tox.ini',
    ]
    # Wrong:
    FILES = ['setup.cfg', 'tox.ini',]
    # Wrong:
    FILES = [
        'setup.cfg',
        'tox.ini'
    ]
    
  • 命名相关规范、注释相关规范、类型注解相关规范,我们将在后续章节中做详细介绍。

    “A style guide is about consistency. Consistency with this style guide is important. Consistency within a project is more important. Consistency within one module or function is the most important.” PEP 8 – Style Guide for Python Code

注解

PEP 8 的代码规范并不是绝对的,项目内的一致性要优先于 PEP 8 的规范。OpenMMLab 各个项目都在 setup.cfg 设定了一些代码规范的设置,请遵照这些设置。一个例子是在 PEP 8 中有如下一个例子:

# Correct:
hypot2 = x*x + y*y
# Wrong:
hypot2 = x * x + y * y

这一规范是为了指示不同优先级,但 OpenMMLab 的设置中通常没有启用 yapf 的 ARITHMETIC_PRECEDENCE_INDICATION 选项,因而格式规范工具不会按照推荐样式格式化,以设置为准。

Google 开源项目风格指南

Google 使用的编程风格指南,包括了 Python 相关的章节。相较于 PEP 8,该指南提供了更为详尽的代码指南。该指南包括了语言规范和风格规范两个部分。

其中,语言规范对 Python 中很多语言特性进行了优缺点的分析,并给出了使用指导意见,如异常、Lambda 表达式、列表推导式、metaclass 等。

风格规范的内容与 PEP 8 较为接近,大部分约定建立在 PEP 8 的基础上,也有一些更为详细的约定,如函数长度、TODO 注释、文件与 socket 对象的访问等。

推荐将该指南作为参考进行开发,但不必严格遵照,一来该指南存在一些 Python 2 兼容需求,例如指南中要求所有无基类的类应当显式地继承 Object, 而在仅使用 Python 3 的环境中,这一要求是不必要的,依本项目中的惯例即可。二来 OpenMMLab 的项目作为框架级的开源软件,不必对一些高级技巧过于避讳,尤其是 MMCV。但尝试使用这些技巧前应当认真考虑是否真的有必要,并寻求其他开发人员的广泛评估。

另外需要注意的一处规范是关于包的导入,在该指南中,要求导入本地包时必须使用路径全称,且导入的每一个模块都应当单独成行,通常这是不必要的,而且也不符合目前项目的开发惯例,此处进行如下约定:

# Correct
from mmcv.cnn.bricks import (Conv2d, build_norm_layer, DropPath, MaxPool2d,
                             Linear)
from ..utils import ext_loader

# Wrong
from mmcv.cnn.bricks import Conv2d, build_norm_layer, DropPath, MaxPool2d, \
                            Linear  # 使用括号进行连接,而不是反斜杠
from ...utils import is_str  # 最多向上回溯一层,过多的回溯容易导致结构混乱

OpenMMLab 项目使用 pre-commit 工具自动格式化代码,详情见贡献代码

命名规范

命名规范的重要性

优秀的命名是良好代码可读的基础。基础的命名规范对各类变量的命名做了要求,使读者可以方便地根据代码名了解变量是一个类 / 局部变量 / 全局变量等。而优秀的命名则需要代码作者对于变量的功能有清晰的认识,以及良好的表达能力,从而使读者根据名称就能了解其含义,甚至帮助了解该段代码的功能。

基础命名规范

类型

公有

私有

模块

lower_with_under

_lower_with_under

lower_with_under

CapWords

_CapWords

异常

CapWordsError

函数(方法)

lower_with_under

_lower_with_under

函数 / 方法参数

lower_with_under

全局 / 类内常量

CAPS_WITH_UNDER

_CAPS_WITH_UNDER

全局 / 类内变量

lower_with_under

_lower_with_under

变量

lower_with_under

_lower_with_under

局部变量

lower_with_under

注意:

  • 尽量避免变量名与保留字冲突,特殊情况下如不可避免,可使用一个后置下划线,如 class_

  • 尽量不要使用过于简单的命名,除了约定俗成的循环变量 i,文件变量 f,错误变量 e 等。

  • 不会被用到的变量可以命名为 _,逻辑检查器会将其忽略。

命名技巧

良好的变量命名需要保证三点:

  1. 含义准确,没有歧义

  2. 长短适中

  3. 前后统一

# Wrong
class Masks(metaclass=ABCMeta):  # 命名无法表现基类;Instance or Semantic?
    pass

# Correct
class BaseInstanceMasks(metaclass=ABCMeta):
    pass

# Wrong,不同地方含义相同的变量尽量用统一的命名
def __init__(self, inplanes, planes):
    pass

def __init__(self, in_channels, out_channels):
    pass

常见的函数命名方法:

  • 动宾命名法:crop_img, init_weights

  • 动宾倒置命名法:imread, bbox_flip

注意函数命名与参数的顺序,保证主语在前,符合语言习惯:

  • check_keys_exist(key, container)

  • check_keys_contain(container, key)

注意避免非常规或统一约定的缩写,如 nb -> num_blocks,in_nc -> in_channels

docstring 规范

为什么要写 docstring

docstring 是对一个类、一个函数功能与 API 接口的详细描述,有两个功能,一是帮助其他开发者了解代码功能,方便 debug 和复用代码;二是在 Readthedocs 文档中自动生成相关的 API reference 文档,帮助不了解源代码的社区用户使用相关功能。

如何写 docstring

与注释不同,一份规范的 docstring 有着严格的格式要求,以便于 Python 解释器以及 sphinx 进行文档解析,详细的 docstring 约定参见 PEP 257。此处以例子的形式介绍各种文档的标准格式,参考格式为 Google 风格

  1. 模块文档

    代码风格规范推荐为每一个模块(即 Python 文件)编写一个 docstring,但目前 OpenMMLab 项目大部分没有此类 docstring,因此不做硬性要求。

    """A one line summary of the module or program, terminated by a period.
    
    Leave one blank line. The rest of this docstring should contain an
    overall description of the module or program. Optionally, it may also
    contain a brief description of exported classes and functions and/or usage
    examples.
    
    Typical usage example:
    
    foo = ClassFoo()
    bar = foo.FunctionBar()
    """
    
  2. 类文档

    类文档是我们最常需要编写的,此处,按照 OpenMMLab 的惯例,我们使用了与 Google 风格不同的写法。如下例所示,文档中没有使用 Attributes 描述类属性,而是使用 Args 描述 init 函数的参数。

    在 Args 中,遵照 parameter (type): Description. 的格式,描述每一个参数类型和功能。其中,多种类型可使用 (float or str) 的写法,可以为 None 的参数可以写为 (int, optional)

    class BaseRunner(metaclass=ABCMeta):
        """The base class of Runner, a training helper for PyTorch.
    
        All subclasses should implement the following APIs:
    
        - ``run()``
        - ``train()``
        - ``val()``
        - ``save_checkpoint()``
    
        Args:
            model (:obj:`torch.nn.Module`): The model to be run.
            batch_processor (callable, optional): A callable method that process
                a data batch. The interface of this method should be
                ``batch_processor(model, data, train_mode) -> dict``.
                Defaults to None.
            optimizer (dict or :obj:`torch.optim.Optimizer`, optional): It can be
                either an optimizer (in most cases) or a dict of optimizers
                (in models that requires more than one optimizer, e.g., GAN).
                Defaults to None.
            work_dir (str, optional): The working directory to save checkpoints
                and logs. Defaults to None.
            logger (:obj:`logging.Logger`): Logger used during training.
                 Defaults to None. (The default value is just for backward
                 compatibility)
            meta (dict, optional): A dict records some import information such as
                environment info and seed, which will be logged in logger hook.
                Defaults to None.
            max_epochs (int, optional): Total training epochs. Defaults to None.
            max_iters (int, optional): Total training iterations. Defaults to None.
        """
    
        def __init__(self,
                     model,
                     batch_processor=None,
                     optimizer=None,
                     work_dir=None,
                     logger=None,
                     meta=None,
                     max_iters=None,
                     max_epochs=None):
            ...
    

    另外,在一些算法实现的主体类中,建议加入原论文的链接;如果参考了其他开源代码的实现,则应加入 modified from,而如果是直接复制了其他代码库的实现,则应加入 copied from ,并注意源码的 License。如有必要,也可以通过 .. math:: 来加入数学公式

    # 参考实现
    # This func is modified from `detectron2
    # <https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387>`_.
    
    # 复制代码
    # This code was copied from the `ubelt
    # library<https://github.com/Erotemic/ubelt>`_.
    
    # 引用论文 & 添加公式
    class LabelSmoothLoss(nn.Module):
        r"""Initializer for the label smoothed cross entropy loss.
    
        Refers to `Rethinking the Inception Architecture for Computer Vision
        <https://arxiv.org/abs/1512.00567>`_.
    
        This decreases gap between output scores and encourages generalization.
        Labels provided to forward can be one-hot like vectors (NxC) or class
        indices (Nx1).
        And this accepts linear combination of one-hot like labels from mixup or
        cutmix except multi-label task.
    
        Args:
            label_smooth_val (float): The degree of label smoothing.
            num_classes (int, optional): Number of classes. Defaults to None.
            mode (str): Refers to notes, Options are "original", "classy_vision",
                "multi_label". Defaults to "classy_vision".
            reduction (str): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            loss_weight (float):  Weight of the loss. Defaults to 1.0.
    
        Note:
            if the ``mode`` is "original", this will use the same label smooth
            method as the original paper as:
    
            .. math::
                (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}
    
            where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is
            the ``num_classes`` and :math:`\delta_{k,y}` is Dirac delta,
            which equals 1 for k=y and 0 otherwise.
    
            if the ``mode`` is "classy_vision", this will use the same label
            smooth method as the `facebookresearch/ClassyVision
            <https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/losses/label_smoothing_loss.py>`_ repo as:
    
            .. math::
                \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}
    
            if the ``mode`` is "multi_label", this will accept labels from
            multi-label task and smoothing them as:
    
            .. math::
                (1-2\epsilon)\delta_{k, y} + \epsilon
    

注解

注意 ``here``、`here`、”here” 三种引号功能是不同。

在 reStructured 语法中,``here`` 表示一段代码;`here` 表示斜体;”here” 无特殊含义,一般可用来表示字符串。其中 `here` 的用法与 Markdown 中不同,需要多加留意。 另外还有 :obj:`type` 这种更规范的表示类的写法,但鉴于长度,不做特别要求,一般仅用于表示非常用类型。

  1. 方法(函数)文档

    函数文档与类文档的结构基本一致,但需要加入返回值文档。对于较为复杂的函数和类,可以使用 Examples 字段加入示例;如果需要对参数加入一些较长的备注,可以加入 Note 字段进行说明。

    对于使用较为复杂的类或函数,比起看大段大段的说明文字和参数文档,添加合适的示例更能帮助用户迅速了解其用法。需要注意的是,这些示例最好是能够直接在 Python 交互式环境中运行的,并给出一些相对应的结果。如果存在多个示例,可以使用注释简单说明每段示例,也能起到分隔作用。

    def import_modules_from_strings(imports, allow_failed_imports=False):
        """Import modules from the given list of strings.
    
        Args:
            imports (list | str | None): The given module names to be imported.
            allow_failed_imports (bool): If True, the failed imports will return
                None. Otherwise, an ImportError is raise. Defaults to False.
    
        Returns:
            List[module] | module | None: The imported modules.
            All these three lines in docstring will be compiled into the same
            line in readthedocs.
    
        Examples:
            >>> osp, sys = import_modules_from_strings(
            ...     ['os.path', 'sys'])
            >>> import os.path as osp_
            >>> import sys as sys_
            >>> assert osp == osp_
            >>> assert sys == sys_
        """
        ...
    

    如果函数接口在某个版本发生了变化,需要在 docstring 中加入相关的说明,必要时添加 Note 或者 Warning 进行说明,例如:

    class CheckpointHook(Hook):
        """Save checkpoints periodically.
    
        Args:
            out_dir (str, optional): The root directory to save checkpoints. If
                not specified, ``runner.work_dir`` will be used by default. If
                specified, the ``out_dir`` will be the concatenation of
                ``out_dir`` and the last level directory of ``runner.work_dir``.
                Defaults to None. `Changed in version 1.3.15.`
            file_client_args (dict, optional): Arguments to instantiate a
                FileClient. See :class:`mmcv.fileio.FileClient` for details.
                Defaults to None. `New in version 1.3.15.`
    
        Warning:
            Before v1.3.15, the ``out_dir`` argument indicates the path where the
            checkpoint is stored. However, in v1.3.15 and later, ``out_dir``
            indicates the root directory and the final path to save checkpoint is
            the concatenation of out_dir and the last level directory of
            ``runner.work_dir``. Suppose the value of ``out_dir`` is
            "/path/of/A" and the value of ``runner.work_dir`` is "/path/of/B",
            then the final path will be "/path/of/A/B".
    

    如果参数或返回值里带有需要展开描述字段的 dict,则应该采用如下格式:

    def func(x):
        r"""
        Args:
            x (None): A dict with 2 keys, ``padded_targets``, and ``targets``.
    
                - ``targets`` (list[Tensor]): A list of tensors.
                  Each tensor has the shape of :math:`(T_i)`. Each
                  element is the index of a character.
                - ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
                  Each item is the length of a word.
    
        Returns:
            dict: A dict with 2 keys, ``padded_targets``, and ``targets``.
    
            - ``targets`` (list[Tensor]): A list of tensors.
              Each tensor has the shape of :math:`(T_i)`. Each
              element is the index of a character.
            - ``padded_targets`` (Tensor): A tensor of shape :math:`(N)`.
              Each item is the length of a word.
        """
        return x
    

重要

为了生成 readthedocs 文档,文档的编写需要按照 ReStructrued 文档格式,否则会产生文档渲染错误,在提交 PR 前,最好生成并预览一下文档效果。 语法规范参考:

注释规范

为什么要写注释

对于一个开源项目,团队合作以及社区之间的合作是必不可少的,因而尤其要重视合理的注释。不写注释的代码,很有可能过几个月自己也难以理解,造成额外的阅读和修改成本。

如何写注释

最需要写注释的是代码中那些技巧性的部分。如果你在下次代码审查的时候必须解释一下,那么你应该现在就给它写注释。对于复杂的操作,应该在其操作开始前写上若干行注释。对于不是一目了然的代码,应在其行尾添加注释。 —— Google 开源项目风格指南

# We use a weighted dictionary search to find out where i is in
# the array. We extrapolate position based on the largest num
# in the array and the array size and then do binary search to
# get the exact number.
if i & (i-1) == 0:  # True if i is 0 or a power of 2.

为了提高可读性, 注释应该至少离开代码2个空格. 另一方面, 绝不要描述代码. 假设阅读代码的人比你更懂Python, 他只是不知道你的代码要做什么. —— Google 开源项目风格指南

# Wrong:
# Now go through the b array and make sure whenever i occurs
# the next element is i+1

# Wrong:
if i & (i-1) == 0:  # True if i bitwise and i-1 is 0.

在注释中,可以使用 Markdown 语法,因为开发人员通常熟悉 Markdown 语法,这样可以便于交流理解,如可使用单反引号表示代码和变量(注意不要和 docstring 中的 ReStructured 语法混淆)

# `_reversed_padding_repeated_twice` is the padding to be passed to
# `F.pad` if needed (e.g., for non-zero padding types that are
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
# reverse order than the dimension.
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

注释示例

  1. 出自 mmcv/utils/registry.py,对于较为复杂的逻辑结构,通过注释,明确了优先级关系。

    # self.build_func will be set with the following priority:
    # 1. build_func
    # 2. parent.build_func
    # 3. build_from_cfg
    if build_func is None:
        if parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg
    else:
        self.build_func = build_func
    
  2. 出自 mmcv/runner/checkpoint.py,对于 bug 修复中的一些特殊处理,可以附带相关的 issue 链接,帮助其他人了解 bug 背景。

    def _save_ckpt(checkpoint, file):
        # The 1.6 release of PyTorch switched torch.save to use a new
        # zipfile-based file format. It will cause RuntimeError when a
        # checkpoint was saved in high version (PyTorch version>=1.6.0) but
        # loaded in low version (PyTorch version<1.6.0). More details at
        # https://github.com/open-mmlab/mmpose/issues/904
        if digit_version(TORCH_VERSION) >= digit_version('1.6.0'):
            torch.save(checkpoint, file, _use_new_zipfile_serialization=False)
        else:
            torch.save(checkpoint, file)
    

类型注解

为什么要写类型注解

类型注解是对函数中变量的类型做限定或提示,为代码的安全性提供保障、增强代码的可读性、避免出现类型相关的错误。 Python 没有对类型做强制限制,类型注解只起到一个提示作用,通常你的 IDE 会解析这些类型注解,然后在你调用相关代码时对类型做提示。另外也有类型注解检查工具,这些工具会根据类型注解,对代码中可能出现的问题进行检查,减少 bug 的出现。 需要注意的是,通常我们不需要注释模块中的所有函数:

  1. 公共的 API 需要注释

  2. 在代码的安全性,清晰性和灵活性上进行权衡是否注释

  3. 对于容易出现类型相关的错误的代码进行注释

  4. 难以理解的代码请进行注释

  5. 若代码中的类型已经稳定,可以进行注释. 对于一份成熟的代码,多数情况下,即使注释了所有的函数,也不会丧失太多的灵活性.

如何写类型注解

  1. 函数 / 方法类型注解,通常不对 self 和 cls 注释。

    from typing import Optional, List, Tuple
    
    # 全部位于一行
    def my_method(self, first_var: int) -> int:
        pass
    
    # 另起一行
    def my_method(
            self, first_var: int,
            second_var: float) -> Tuple[MyLongType1, MyLongType1, MyLongType1]:
        pass
    
    # 单独成行(具体的应用场合与行宽有关,建议结合 yapf 自动化格式使用)
    def my_method(
        self, first_var: int, second_var: float
    ) -> Tuple[MyLongType1, MyLongType1, MyLongType1]:
        pass
    
    # 引用尚未被定义的类型
    class MyClass:
        def __init__(self,
                     stack: List["MyClass"]) -> None:
            pass
    

    注:类型注解中的类型可以是 Python 内置类型,也可以是自定义类,还可以使用 Python 提供的 wrapper 类对类型注解进行装饰,一些常见的注解如下:

    # 数值类型
    from numbers import Number
    
    # 可选类型,指参数可以为 None
    from typing import Optional
    def foo(var: Optional[int] = None):
        pass
    
    # 联合类型,指同时接受多种类型
    from typing import Union
    def foo(var: Union[float, str]):
        pass
    
    from typing import Sequence  # 序列类型
    from typing import Iterable  # 可迭代类型
    from typing import Any  # 任意类型
    from typing import Callable  # 可调用类型
    
    from typing import List, Dict  # 列表和字典的泛型类型
    from typing import Tuple  # 元组的特殊格式
    # 虽然在 Python 3.9 中,list, tuple 和 dict 本身已支持泛型,但为了支持之前的版本
    # 我们在进行类型注解时还是需要使用 List, Tuple, Dict 类型
    # 另外,在对参数类型进行注解时,尽量使用 Sequence & Iterable & Mapping
    # List, Tuple, Dict 主要用于返回值类型注解
    # 参见 https://docs.python.org/3/library/typing.html#typing.List
    
  2. 变量类型注解,一般用于难以直接推断其类型时

    # Recommend: 带类型注解的赋值
    a: Foo = SomeUndecoratedFunction()
    a: List[int]: [1, 2, 3]         # List 只支持单一类型泛型,可使用 Union
    b: Tuple[int, int] = (1, 2)     # 长度固定为 2
    c: Tuple[int, ...] = (1, 2, 3)  # 变长
    d: Dict[str, int] = {'a': 1, 'b': 2}
    
    # Not Recommend:行尾类型注释
    # 虽然这种方式被写在了 Google 开源指南中,但这是一种为了支持 Python 2.7 版本
    # 而补充的注释方式,鉴于我们只支持 Python 3, 为了风格统一,不推荐使用这种方式。
    a = SomeUndecoratedFunction()  # type: Foo
    a = [1, 2, 3]  # type: List[int]
    b = (1, 2, 3)  # type: Tuple[int, ...]
    c = (1, "2", 3.5)  # type: Tuple[int, Text, float]
    
  3. 泛型

    上文中我们知道,typing 中提供了 list 和 dict 的泛型类型,那么我们自己是否可以定义类似的泛型呢?

    from typing import TypeVar, Generic
    
    KT = TypeVar('KT')
    VT = TypeVar('VT')
    
    class Mapping(Generic[KT, VT]):
        def __init__(self, data: Dict[KT, VT]):
            self._data = data
    
        def __getitem__(self, key: KT) -> VT:
            return self._data[key]
    

    使用上述方法,我们定义了一个拥有泛型能力的映射类,实际用法如下:

    mapping = Mapping[str, float]({'a': 0.5})
    value: float = example['a']
    

    另外,我们也可以利用 TypeVar 在函数签名中指定联动的多个类型:

    from typing import TypeVar, List
    
    T = TypeVar('T')  # Can be anything
    A = TypeVar('A', str, bytes)  # Must be str or bytes
    
    
    def repeat(x: T, n: int) -> List[T]:
        """Return a list containing n references to x."""
        return [x]*n
    
    
    def longest(x: A, y: A) -> A:
        """Return the longest of two strings."""
        return x if len(x) >= len(y) else y
    

更多关于类型注解的写法请参考 typing

类型注解检查工具

mypy 是一个 Python 静态类型检查工具。根据你的类型注解,mypy 会检查传参、赋值等操作是否符合类型注解,从而避免可能出现的 bug。

例如如下的一个 Python 脚本文件 test.py:

def foo(var: int) -> float:
    return float(var)

a: str = foo('2.0')
b: int = foo('3.0')  # type: ignore

运行 mypy test.py 可以得到如下检查结果,分别指出了第 4 行在函数调用和返回值赋值两处类型错误。而第 5 行同样存在两个类型错误,由于使用了 type: ignore 而被忽略了,只有部分特殊情况可能需要此类忽略。

test.py:4: error: Incompatible types in assignment (expression has type "float", variable has type "int")
test.py:4: error: Argument 1 to "foo" has incompatible type "str"; expected "int"
Found 2 errors in 1 file (checked 1 source file)

Indices and tables


© Copyright 2022, mmengine contributors. Revision 6af88783.

Built with Sphinx using a theme provided by Read the Docs.

开始你的第一步

常用功能

入门教程

进阶教程

架构设计

迁移指南

API 文档

说明

语言切换

Read the Docs v: v0.4.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.