# Switching from EpochBased to IterBased
MMEngine supports two training modes: epoch-based (EpochBased) and iteration-based (IterBased). Both are used by downstream algorithm libraries; for example, MMDetection uses the EpochBased mode by default, while MMSegmentation defaults to the IterBased mode.
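Under the hood, the `by_epoch` switch maps to the two training loops mentioned in the notes at the end of this section. As a minimal sketch (the numeric values are only examples), the shorthand and the explicit loop types correspond roughly as follows:

```python
# Shorthand form: by_epoch selects the training loop.
train_cfg = dict(by_epoch=True, max_epochs=10, val_interval=2)

# Roughly equivalent explicit forms naming the loop type directly:
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=2)
train_cfg = dict(type='IterBasedTrainLoop', max_iters=10000, val_interval=2000)
```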
Many MMEngine modules run in EpochBased mode by default, such as ParamScheduler, LoggerHook, and CheckpointHook. A typical EpochBased configuration looks like this:
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch defaults to True; it is written out explicitly here only for ease of comparison
)

default_hooks = dict(
    logger=dict(type='LoggerHook'),
    checkpoint=dict(type='CheckpointHook', interval=2),
)

train_cfg = dict(
    by_epoch=True,  # by_epoch defaults to True; it is written out explicitly here only for ease of comparison
    max_epochs=10,
    val_interval=2
)

log_processor = dict(
    by_epoch=True
)  # by_epoch of log_processor defaults to True; it is written out here only for comparison and does not actually need to be set

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)
```
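With the configuration above in place, training is started through the runner, as the full example at the end of this section also shows:

```python
# Launch training: with by_epoch=True the runner trains for 10 epochs,
# validating every 2 epochs.
runner.train()
```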
To train the model by iteration instead, make the following changes:
1. Set `by_epoch` in `train_cfg` to `False`, set `max_iters` to the total number of training iterations, and set `val_interval` to the validation interval in iterations.

   ```python
   train_cfg = dict(
       by_epoch=False,
       max_iters=10000,
       val_interval=2000
   )
   ```
2. Set `log_metric_by_epoch` of the `logger` in `default_hooks` to `False`, and set `by_epoch` of the `checkpoint` hook to `False`.

   ```python
   default_hooks = dict(
       logger=dict(type='LoggerHook', log_metric_by_epoch=False),
       checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
   )
   ```
3. Set `by_epoch` in `param_scheduler` to `False` and convert the epoch-related parameters to iterations.

   ```python
   param_scheduler = dict(
       type='MultiStepLR',
       milestones=[6000, 8000],
       by_epoch=False,
   )
   ```

   Alternatively, if you can guarantee that the total number of iterations in IterBased training matches that of EpochBased training, simply setting `convert_to_iter_based` to `True` is enough (the conversion arithmetic is sketched just after this list).

   ```python
   param_scheduler = dict(
       type='MultiStepLR',
       milestones=[6, 8],
       convert_to_iter_based=True
   )
   ```
4. Set `by_epoch` of `log_processor` to `False`.

   ```python
   log_processor = dict(
       by_epoch=False
   )
   ```
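The epoch-to-iteration conversion in step 3 is plain arithmetic. Here is a minimal standalone sketch, assuming 1000 iterations per epoch, which matches the numbers used throughout this section (10 epochs ↔ 10000 iterations); `convert_to_iter_based=True` performs essentially this scaling internally using the length of the train dataloader:

```python
# Hypothetical illustration of the epoch -> iteration conversion.
iters_per_epoch = 1000  # e.g. len(train_dataloader)
milestones_by_epoch = [6, 8]
milestones_by_iter = [m * iters_per_epoch for m in milestones_by_epoch]
print(milestones_by_iter)  # [6000, 8000]
```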
Putting everything together, the complete examples for training by epoch and training by iteration are compared step by step below.

**Build model** (identical for both modes)

```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```

**Build dataloader** (identical for both modes)

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))

val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
```

**Prepare metric** (identical for both modes)

```python
from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # save the intermediate results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # return a dict containing the eval results,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
```

**Configure default hooks**

Training by epoch:

```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)
```

Training by iteration:

```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
```

**Configure parameter scheduler**

Training by epoch:

```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)
```

Training by iteration:

```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
```

**Configure log_processor**

Training by epoch:

```python
# The default configuration of log_processor already suits epoch-based
# training; it is defined here only so that the runner is built the same
# way in both modes.
log_processor = dict(by_epoch=True)
```

Training by iteration:

```python
log_processor = dict(by_epoch=False)
```

**Configure train_cfg**

Training by epoch:

```python
train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)
```

Training by iteration:

```python
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
```

**Build Runner** (identical for both modes)

```python
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=train_cfg,
    log_processor=log_processor,
    default_hooks=default_hooks,
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
```
Note

If the base config file configures a sampler for train_dataloader that samples by iteration/epoch, you need to change it to a sampler of the required type in the current config file, or set it to None. When the sampler in the dataloader is None, MMEngine will choose InfiniteSampler (when by_epoch is False) or DefaultSampler (when by_epoch is True) according to the by_epoch argument in train_cfg.
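A minimal sketch of such an override (other dataloader fields are elided and merged from the base config; `InfiniteSampler` is the iteration-based choice named above):

```python
# Option 1: explicitly request the iteration-based sampler.
train_dataloader = dict(
    sampler=dict(type='InfiniteSampler', shuffle=True),
)

# Option 2: set the sampler to None and let MMEngine pick
# InfiniteSampler or DefaultSampler based on train_cfg's by_epoch.
train_dataloader = dict(
    sampler=None,
)
```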
Note

If the base config file specifies a type in train_cfg, you must override that type with the target one (IterBasedTrainLoop or EpochBasedTrainLoop) in the current config file; simply setting the by_epoch argument is not enough.
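A minimal sketch of such an override, assuming the base config pinned the loop type to EpochBasedTrainLoop:

```python
# The loop type itself must be overridden; setting by_epoch=False alone
# would not replace a loop type fixed in the base config.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=10000,
    val_interval=2000,
)
```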