
Switching from EpochBased to IterBased

MMEngine supports two training modes: EpochBased (epoch-based) and IterBased (iteration-based). Both are used by downstream algorithm libraries; for example, MMDetection uses the EpochBased mode by default, while MMSegmentation defaults to the IterBased mode.

Many MMEngine modules run in EpochBased mode by default, such as ParamScheduler, LoggerHook, and CheckpointHook. A typical EpochBased configuration looks like this:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch defaults to True; written explicitly here only for easier comparison
)

default_hooks = dict(
    logger=dict(type='LoggerHook'),
    checkpoint=dict(type='CheckpointHook', interval=2),
)

train_cfg = dict(
    by_epoch=True,  # by_epoch defaults to True; written explicitly here only for easier comparison
    max_epochs=10,
    val_interval=2
)

log_processor = dict(
    by_epoch=True
)  # by_epoch of log_processor defaults to True; written explicitly here only for easier comparison, it does not actually need to be set

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)

To train the model by iteration instead, make the following changes:

  1. Set by_epoch in train_cfg to False, set max_iters to the total number of training iterations, and set val_interval to the validation interval in iterations.

    train_cfg = dict(
        by_epoch=False,
        max_iters=10000,
        val_interval=2000
    )
    
  2. Set log_metric_by_epoch of logger in default_hooks to False, and set by_epoch of checkpoint to False.

    default_hooks = dict(
        logger=dict(type='LoggerHook', log_metric_by_epoch=False),
        checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
    )
    
  3. Set by_epoch in param_scheduler to False, and convert the epoch-based parameters (e.g. milestones) to their iteration-based equivalents.

    param_scheduler = dict(
        type='MultiStepLR',
        milestones=[6000, 8000],
        by_epoch=False,
    )
    

    Alternatively, if you can guarantee that IterBased and EpochBased training use the same total number of iterations, you can simply set convert_to_iter_based to True:

    param_scheduler = dict(
        type='MultiStepLR',
        milestones=[6, 8],
        convert_to_iter_based=True
    )
    
  4. Set by_epoch of log_processor to False.

    log_processor = dict(
        by_epoch=False
    )
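The epoch-to-iteration conversion in step 3 is plain arithmetic: each epoch milestone is multiplied by the number of iterations per epoch. A minimal sketch of that calculation (the helper name `epochs_to_iters` and the 32000-sample dataset are illustrative, not part of MMEngine):

```python
import math

def epochs_to_iters(milestones, dataset_size, batch_size):
    # One epoch is ceil(dataset_size / batch_size) iterations
    # (assuming no samples are dropped); scale each milestone by that count.
    iters_per_epoch = math.ceil(dataset_size / batch_size)
    return [m * iters_per_epoch for m in milestones]

# With a hypothetical dataset of 32000 samples and batch size 32,
# one epoch is 1000 iterations, so epoch milestones [6, 8] become
# iteration milestones [6000, 8000], matching the config above.
print(epochs_to_iters([6, 8], dataset_size=32000, batch_size=32))
```

This is exactly the computation that `convert_to_iter_based=True` asks the scheduler to do for you.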
    

Take training on the CIFAR10 dataset from the 15-minute tutorial as an example:

The steps are compared below. Steps shared by both modes appear once; for steps where the modes differ, the epoch-based configuration is listed first, followed by the iteration-based one.

Build model

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

Build dataloader

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))

val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))

Prepare metric

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # save the middle result of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # return the dict containing the eval results
        # the key is the name of the metric name
        return dict(accuracy=100 * total_correct / total_size)
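The aggregation in compute_metrics is a plain weighted sum over the saved batch results. A dependency-free check of that arithmetic (the batch numbers below are made up for illustration):

```python
# Mimic the per-batch entries that Accuracy.process accumulates in self.results.
results = [
    {'batch_size': 32, 'correct': 24},
    {'batch_size': 32, 'correct': 28},
    {'batch_size': 16, 'correct': 12},
]

total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
accuracy = 100 * total_correct / total_size  # 100 * (24+28+12) / (32+32+16)
print(accuracy)  # 80.0
```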

Configure default hooks

Training by epoch:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)

Training by iteration:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)

Configure parameter scheduler

Training by epoch:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)

Training by iteration:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)

Configure log_processor

Training by epoch:

# The default log_processor configuration already targets epoch-based training,
# so this line is not strictly needed; it is written out only so the runner is
# built the same way in both modes.
log_processor = dict(by_epoch=True)

Training by iteration:

log_processor = dict(by_epoch=False)

Configure train_cfg

Training by epoch:

train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)

Training by iteration:

train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)

Build Runner

from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=train_cfg,
    log_processor=log_processor,
    default_hooks=default_hooks,
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()

Note

If the base config file sets an iteration-based or epoch-based sampler for train_dataloader, you must either override it with a sampler of the appropriate type in the current config, or set it to None. When the sampler in the dataloader is None, MMEngine selects the sampler according to by_epoch in train_cfg: InfiniteSampler when it is False, DefaultSampler when it is True.
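The sampler override described above can be sketched as a config fragment (the surrounding keys such as batch_size are illustrative, and the dataset config is elided):

```python
# Leaving sampler as None lets MMEngine choose it from train_cfg:
# InfiniteSampler when by_epoch=False, DefaultSampler when by_epoch=True.
train_dataloader = dict(
    batch_size=32,
    num_workers=2,
    dataset=...,   # dataset config elided
    sampler=None,  # overrides any sampler inherited from the base config
)
```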

Note

If the base config file specifies type in train_cfg, you must override type in the current config with IterBasedTrainLoop or EpochBasedTrainLoop; it is not enough to simply set the by_epoch parameter.
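A sketch of such an override in the current config (the max_iters and val_interval values are the ones used earlier on this page):

```python
# Override the loop type explicitly instead of only flipping by_epoch,
# since the base config already pins a loop type.
train_cfg = dict(
    type='IterBasedTrainLoop',  # replaces e.g. 'EpochBasedTrainLoop' from the base config
    max_iters=10000,
    val_interval=2000,
)
```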
