
Switching from EpochBased to IterBased

MMEngine supports two training modes: EpochBased, which counts training progress in epochs, and IterBased, which counts it in iterations. Both are used in downstream algorithm libraries; for example, MMDetection uses the EpochBased mode by default, while MMSegmentation uses the IterBased mode by default.
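Conceptually, the two modes differ only in what the outer training loop counts. The sketch below is plain Python for illustration, not MMEngine internals; `run_iter` stands in for a single training step:

```python
# Plain-Python sketch of the two scheduling styles (illustrative, not MMEngine code).

def train_epoch_based(dataloader, max_epochs, run_iter):
    """Outer loop counts epochs; one epoch = one full pass over the dataloader."""
    for epoch in range(max_epochs):
        for batch in dataloader:
            run_iter(batch)

def train_iter_based(dataloader, max_iters, run_iter):
    """Outer loop counts iterations; the dataloader is cycled until max_iters is reached."""
    it = 0
    while it < max_iters:
        for batch in dataloader:
            if it >= max_iters:
                break
            run_iter(batch)
            it += 1

# Example: a 4-batch "dataloader" trained for 2 epochs vs. 5 iterations.
seen = []
train_epoch_based([1, 2, 3, 4], max_epochs=2, run_iter=seen.append)
print(len(seen))  # 8 steps: 2 epochs x 4 batches

seen = []
train_iter_based([1, 2, 3, 4], max_iters=5, run_iter=seen.append)
print(len(seen))  # exactly 5 steps, crossing an epoch boundary
```

Note how iteration-based training stops mid-pass over the data; this is why validation and checkpoint intervals must also be expressed in iterations, as shown in the steps below.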

Many MMEngine modules run in EpochBased mode by default, such as ParamScheduler, LoggerHook, and CheckpointHook. A typical EpochBased configuration looks like this:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch defaults to True; it is written out explicitly here only for comparison
)

default_hooks = dict(
    logger=dict(type='LoggerHook'),
    checkpoint=dict(type='CheckpointHook', interval=2),
)

train_cfg = dict(
    by_epoch=True,  # by_epoch defaults to True; it is written out explicitly here only for comparison
    max_epochs=10,
    val_interval=2
)

log_processor = dict(
    by_epoch=True
)  # by_epoch of log_processor defaults to True; it is written out here only for comparison and does not actually need to be set

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)

To train the model by iteration instead, make the following changes:

  1. Set by_epoch in train_cfg to False, set max_iters to the total number of training iterations, and set val_interval to the validation interval in iterations.

    train_cfg = dict(
        by_epoch=False,
        max_iters=10000,
        val_interval=2000
    )
    
  2. Set log_metric_by_epoch of logger in default_hooks to False, and set by_epoch of checkpoint to False.

    default_hooks = dict(
        logger=dict(type='LoggerHook', log_metric_by_epoch=False),
        checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
    )
    
  3. Set by_epoch in param_scheduler to False, and convert the epoch-based parameters to iterations.

    param_scheduler = dict(
        type='MultiStepLR',
        milestones=[6000, 8000],
        by_epoch=False,
    )
    

    Alternatively, if you can guarantee that the total number of iterations is the same for IterBased and EpochBased training, simply setting convert_to_iter_based to True is enough.

    param_scheduler = dict(
        type='MultiStepLR',
        milestones=[6, 8],
        convert_to_iter_based=True
    )
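The epoch-to-iteration conversion in step 3 is simply the milestones in epochs multiplied by the number of iterations per epoch. A small sketch, assuming 1000 iterations per epoch (consistent with max_epochs=10 becoming max_iters=10000 above):

```python
# Convert epoch-based scheduler milestones to iteration-based ones.
# Assumes 1000 iterations per epoch, matching max_epochs=10 -> max_iters=10000 above.

def epochs_to_iters(milestones, iters_per_epoch):
    """Scale each epoch milestone by the number of iterations in one epoch."""
    return [m * iters_per_epoch for m in milestones]

iters_per_epoch = 1000  # e.g. len(dataset) // batch_size
print(epochs_to_iters([6, 8], iters_per_epoch))  # [6000, 8000]
```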
    
  4. Set by_epoch of log_processor to False.

    log_processor = dict(
        by_epoch=False
    )
    

Take training on the CIFAR10 dataset from the 15-minute tutorial as an example:

The steps below walk through the example; where the epoch-based and iteration-based configurations differ, both versions are shown (epoch-based first).
Build model
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
Build dataloader
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))

val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
Prepare metric
from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # save the middle result of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # return the dict containing the eval results
        # the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
Configure default hooks

Training by epoch:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)

Training by iteration:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
Configure parameter scheduler

Training by epoch:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)

Training by iteration:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
Configure log_processor

Training by epoch:

# The default configuration of log_processor is already epoch-based.
# It is defined here explicitly only so the runner can be built the same way in both modes.
log_processor = dict(by_epoch=True)

Training by iteration:

log_processor = dict(by_epoch=False)
Configure train_cfg

Training by epoch:

train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)

Training by iteration:

train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
Build Runner
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=train_cfg,
    log_processor=log_processor,
    default_hooks=default_hooks,
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()

Note

If the base config file configures an iteration-based or epoch-based sampler for train_dataloader, you need to change it to a sampler of the desired type in the current config file, or set it to None. When the sampler of the dataloader is None, MMEngine will choose InfiniteSampler (by_epoch=False) or DefaultSampler (by_epoch=True) according to the by_epoch parameter in train_cfg.
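As a sketch, the override in the current config file could take either form below. InfiniteSampler and DefaultSampler are the MMEngine samplers named above; treating shuffle as a sampler argument is an assumption here, and a real config would carry the other dataloader fields as well:

```python
# Option 1: clear the inherited sampler so MMEngine picks one from train_cfg.by_epoch.
train_dataloader = dict(
    sampler=None,  # by_epoch=False -> InfiniteSampler, by_epoch=True -> DefaultSampler
)

# Option 2: pin the sampler type explicitly, e.g. for iteration-based training.
train_dataloader = dict(
    sampler=dict(type='InfiniteSampler', shuffle=True),
)
```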

Note

If the base config file specifies type in train_cfg, you must override type in the current config file (with IterBasedTrainLoop or EpochBasedTrainLoop) rather than simply setting the by_epoch parameter.
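For example, a minimal override in the current config file might look like this (max_iters and val_interval reuse the illustrative values from the steps above):

```python
# Override the inherited loop type explicitly; setting by_epoch alone is not enough
# once the base config has pinned train_cfg['type'].
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=10000,
    val_interval=2000,
)
```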
