Shortcuts

模型(Model)

在训练深度学习任务时,我们通常需要定义一个模型来实现算法的主体。在基于 MMEngine 开发时,模型由执行器管理,需要实现 train_stepval_steptest_step 方法。

对于检测、识别、分割一类的深度学习任务,上述方法通常为标准的流程,例如在 train_step 里更新参数,返回损失;val_steptest_step 返回预测结果。因此 MMEngine 抽象出模型基类 BaseModel,实现了上述接口的标准流程。我们只需要让模型继承自模型基类,并按照一定的规范实现 forward,就能让模型在执行器中运行起来。

模型基类继承自模块基类,能够通过配置 init_cfg 灵活的选择初始化方式。

接口约定

forward: forward 的入参需要和 DataLoader 的输出保持一致 (自定义数据预处理器除外),如果 DataLoader 返回元组类型的数据 dataforward 需要能够接受 *data 的解包后的参数;如果返回字典类型的数据 dataforward 需要能够接受 **data 解包后的参数。 mode 参数用于控制 forward 的返回结果:

  • mode='loss'loss 模式通常在训练阶段启用,并返回一个损失字典。损失字典的 key-value 分别为损失名和可微的 torch.Tensor。字典中记录的损失会被用于更新参数和记录日志。模型基类会在 train_step 方法中调用该模式的 forward

  • mode='predict'predict 模式通常在验证、测试阶段启用,并返回列表/元组形式的预测结果,预测结果需要和 process 接口的参数相匹配。OpenMMLab 系列算法对 predict 模式的输出有着更加严格的约定,需要输出列表形式的数据元素。模型基类会在 val_steptest_step 方法中调用该模式的 forward

  • mode='tensor'tensorpredict 模式均返回模型的前向推理结果,区别在于 tensor 模式下,forward 会返回未经后处理的张量,例如返回未经非极大值抑制(nms)处理的检测结果,返回未经 argmax 处理的分类结果。我们可以基于 tensor 模式的结果进行自定义的后处理。

train_step: 调用 loss 模式的 forward 接口,得到损失字典。模型基类基于优化器封装 实现了标准的梯度计算、参数更新、梯度清零流程。

val_step: 调用 predict 模式的 forward,返回预测结果,预测结果会被进一步传给评测器process 接口和钩子(Hook)after_val_iter 接口。

test_step: 同 val_step,预测结果会被进一步传给 after_test_iter 接口。

基于上述接口约定,我们定义了继承自模型基类的 NeuralNetwork,配合执行器来训练 FashionMNIST

from torch.utils.data import DataLoader
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from mmengine.model import BaseModel
from mmengine.evaluator import BaseMetric
from mmengine.runner import Runner


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(dataset=training_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)


class NeuralNetwork(BaseModel):
    def __init__(self, data_preprocessor=None):
        super(NeuralNetwork, self).__init__(data_preprocessor)
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, img, label, mode='tensor'):
        x = self.flatten(img)
        pred = self.linear_relu_stack(x)
        loss = self.loss(pred, label)
        if mode == 'loss':
            return dict(loss=loss)
        elif mode=='predict':
            return pred.argmax(1), loss.item()
        else:
            return pred


class FashionMnistMetric(BaseMetric):
    def process(self, data, preds) -> None:
        # data 参数为 Dataloader 返回的元组,即 (img, label)
        # predict 为模型 `predict` 模式下,返回的元组,分别为 `pred.argmax(1) 和 `loss``
        self.results.append(((data[1] == preds[0].cpu()).sum(), preds[1], len(preds[0])))

    def compute_metrics(self, results):
        correct, loss, batch_size = zip(*results)
        test_loss, correct = sum(loss) / len(self.results), sum(correct) / sum(batch_size)
        return dict(Accuracy=correct, Avg_loss=test_loss)


runner = Runner(
    model=NeuralNetwork(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=1e-3)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_cfg=dict(fp16=True),
    val_dataloader=test_dataloader,
    val_evaluator=dict(metrics=FashionMnistMetric()))
runner.train()

在这个例子中,NeuralNetwork.forward 存在着以下跨模块的接口约定:

  • 由于 train_dataloader 会返回一个 (img, label) 形式的元组,因此 forward 接口的前两个参数分别需要为 imglabel

  • 由于 forwardpredict 模式下会返回 (pred, loss) 形式的元组,因此 processpreds 参数应当同样为相同形式的元组。

相比于 Pytorch 官方示例,MMEngine 的代码更加简洁,记录的日志也更加丰富。

数据预处理器(DataPreprocessor)

如果你的电脑配有 GPU(或其他能够加速训练的硬件,如 mps、ipu 等),并运行了上节的代码示例。你会发现 Pytorch 的示例是在 CPU 上运行的,而 MMEngine 的示例是在 GPU 上运行的。MMEngine 是在何时把数据和模型从 CPU 搬运到 GPU 的呢?

事实上,执行器会在构造阶段将模型搬运到指定设备,而数据则会在 train_stepval_steptest_step 中,被基础数据预处理器(BaseDataPreprocessor)搬运到指定设备,进一步将处理好的数据传给模型。数据预处理器作为模型基类的一个属性,会在模型基类的构造过程中被实例化。

为了体现数据预处理器起到的作用,我们仍然以上一节训练 FashionMNIST 为例, 实现了一个简易的数据预处理器,用于搬运数据和归一化:

from torch.optim import SGD
from mmengine.model import BaseDataPreprocessor, BaseModel


class NeuralNetwork1(NeuralNetwork):

    def __init__(self, data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)
        self.data_preprocessor = data_preprocessor

    def train_step(self, data, optimizer):
        img, label = self.data_preprocessor(data)
        loss = self(img, label, mode='loss')['loss'].sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return dict(loss=loss)

    def test_step(self, data):
        img, label = self.data_preprocessor(data)
        return self(img, label, mode='predict')

    def val_step(self, data):
        img, label = self.data_preprocessor(data)
        return self(img, label, mode='predict')


class NormalizeDataPreprocessor(BaseDataPreprocessor):

    def forward(self, data, training=False):
        img, label = [item for item in data]
        img = (img - 127.5) / 127.5
        return img, label


model = NeuralNetwork1(data_preprocessor=NormalizeDataPreprocessor())
optimizer = SGD(model.parameters(), lr=0.01)
data = (torch.full((3, 28, 28), fill_value=127.5), torch.ones(3, 10))

model.train_step(data, optimizer)
model.val_step(data)
model.test_step(data)

上例中,我们实现了 BaseModel.train_stepBaseModel.val_stepBaseModel.test_step 的简化版。数据经 NormalizeDataPreprocessor.forward 归一化处理,解包后传给 NeuralNetwork.forward,进一步返回损失或者预测结果。如果想实现自定义的参数优化或预测逻辑,可以自行实现 train_stepval_steptest_step,具体例子可以参考:使用 MMEngine 训练生成对抗网络

注解

上例中数据预处理器的 training 参数用于区分训练、测试阶段不同的批增强策略,train_step 会传入 training=Truetest_stepval_step 则会传入 trainig=Fasle

注解

通常情况下,我们要求 DataLoader 的 data 数据解包后(字典类型的被 **data 解包,元组列表类型被 *data 解包)能够直接传给模型的 forward。但是如果数据预处理器修改了 data 的数据类型,则要求数据预处理器的 forward 的返回值与模型 forward 的入参相匹配。

Read the Docs v: v0.3.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.