Source code for mmengine.runner.loops
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import logging
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
"""Loop for epoch-based training.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_epochs (int): Total training epochs.
val_begin (int): The epoch that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
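Examples:
    A minimal config-style sketch, assuming the loop is built by
    ``Runner`` from a ``train_cfg`` dict (the values below are
    placeholders, not defaults of this class)::

        train_cfg = dict(
            type='EpochBasedTrainLoop',
            max_epochs=12,
            val_begin=1,
            val_interval=1,
            # switch to validating every 2 epochs once epoch 8 is reached
            dynamic_intervals=[(8, 2)])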
"""
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_epochs = int(max_epochs)
assert self._max_epochs == max_epochs, \
f'`max_epochs` should be an integer number, but got {max_epochs}.'
self._max_iters = self._max_epochs * len(self.dataloader)
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
# This attribute will be updated by `EarlyStoppingHook`
# when it is enabled.
self.stop_training = False
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.',
logger='current',
level=logging.WARNING)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self._max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> torch.nn.Module:
"""Launch training."""
self.runner.call_hook('before_train')
while self._epoch < self._max_epochs and not self.stop_training:
self.run_epoch()
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and (self._epoch % self.val_interval == 0
or self._epoch == self._max_epochs)):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
return self.runner.model
def run_epoch(self) -> None:
"""Iterate one epoch."""
self.runner.call_hook('before_train_epoch')
self.runner.model.train()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch')
self._epoch += 1
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one min-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter', batch_idx=idx, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during the gradient accumulation process.
# outputs should be a dict of losses.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.call_hook(
'after_train_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
self.val_interval = self.dynamic_intervals[step - 1]
class _InfiniteDataloaderIterator:
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop.
It resets the dataloader to continue iterating when the iterator has
iterated over all the data. However, this approach is not efficient, as the
workers need to be restarted every time the dataloader is reset. It is
recommended to use `mmengine.dataset.InfiniteSampler` to enable the
dataloader to iterate infinitely.
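Examples:
    A minimal sketch with a toy in-memory dataset; the wrapper keeps
    yielding batches and restarts the dataloader (with a warning) once
    it is exhausted::

        loader = DataLoader(range(4), batch_size=2)
        infinite_iter = _InfiniteDataloaderIterator(loader)
        # 4 batches from a 2-batch dataloader: the loader restarts once
        batches = [next(infinite_iter) for _ in range(4)]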
"""
def __init__(self, dataloader: DataLoader) -> None:
self._dataloader = dataloader
self._iterator = iter(self._dataloader)
self._epoch = 0
def __iter__(self):
return self
def __next__(self) -> Sequence[dict]:
try:
data = next(self._iterator)
except StopIteration:
print_log(
'Reached the end of the dataloader; it will be '
'restarted to continue iterating. It is '
'recommended to use '
'`mmengine.dataset.InfiniteSampler` to enable the '
'dataloader to iterate infinitely.',
logger='current',
level=logging.WARNING)
self._epoch += 1
if hasattr(self._dataloader, 'sampler') and hasattr(
self._dataloader.sampler, 'set_epoch'):
# In case the `_SingleProcessDataLoaderIter` has no sampler,
# or the dataloader uses `SequentialSampler` in PyTorch.
self._dataloader.sampler.set_epoch(self._epoch)
elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
self._dataloader.batch_sampler.sampler, 'set_epoch'):
# In case the `_SingleProcessDataLoaderIter` has no batch
# sampler. The batch sampler in PyTorch wraps the sampler as its
# attribute.
self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self._iterator = iter(self._dataloader)
data = next(self._iterator)
return data
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
"""Loop for iter-based training.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_iters (int): Total training iterations.
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
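Examples:
    A minimal config-style sketch, assuming the loop is built by
    ``Runner`` from a ``train_cfg`` dict (the values below are
    placeholders)::

        train_cfg = dict(
            type='IterBasedTrainLoop',
            max_iters=90000,
            val_interval=1000,
            # switch to validating every 500 iterations after 80000 iters
            dynamic_intervals=[(80000, 500)])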
"""
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_iters = int(max_iters)
assert self._max_iters == max_iters, \
f'`max_iters` should be an integer number, but got {max_iters}.'
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
# This attribute will be updated by `EarlyStoppingHook`
# when it is enabled.
self.stop_training = False
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.',
logger='current',
level=logging.WARNING)
# get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self._max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> torch.nn.Module:
"""Launch training."""
self.runner.call_hook('before_train')
# In the iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
if self._iter > 0:
print_log(
f'Advance dataloader {self._iter} steps to skip the data '
'that has already been trained on',
logger='current',
level=logging.WARNING)
for _ in range(self._iter):
next(self.dataloader_iterator)
while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and (self._iter % self.val_interval == 0
or self._iter == self._max_iters)):
self.runner.val_loop.run()
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
return self.runner.model
def run_iter(self, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during the gradient accumulation process.
# outputs should be a dict of losses.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.call_hook(
'after_train_iter',
batch_idx=self._iter,
data_batch=data_batch,
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
self.val_interval = self.dynamic_intervals[step - 1]
@LOOPS.register_module()
class ValLoop(BaseLoop):
"""Loop for validation.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
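Examples:
    A minimal config-style sketch, assuming ``Runner`` builds the loop
    from ``val_cfg`` and supplies the dataloader and evaluator given as
    ``val_dataloader``/``val_evaluator`` (the metric type below is a
    placeholder from a downstream codebase)::

        val_cfg = dict(type='ValLoop', fp16=True)
        val_evaluator = dict(type='Accuracy')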
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, (dict, list)):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
assert isinstance(evaluator, Evaluator), (
'evaluator must be one of dict, list or Evaluator instance, '
f'but got {type(evaluator)}.')
self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.val_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict:
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
# clear val loss
self.val_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.val_loss:
loss_dict = _parse_losses(self.val_loss, 'val')
metrics.update(loss_dict)
self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]):
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data
from dataloader.
"""
self.runner.call_hook(
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be a sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_val_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
@LOOPS.register_module()
class TestLoop(BaseLoop):
"""Loop for test.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 testing. Defaults to
False.
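Examples:
    A minimal config-style sketch, assuming ``Runner`` builds the loop
    from ``test_cfg`` together with ``test_dataloader`` and
    ``test_evaluator`` (the metric type below is a placeholder)::

        test_cfg = dict(type='TestLoop', fp16=False)
        test_evaluator = dict(type='Accuracy')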
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False):
super().__init__(runner, dataloader)
if isinstance(evaluator, (dict, list)):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.test_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict:
"""Launch test."""
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()
# clear test loss
self.test_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.test_loss:
loss_dict = _parse_losses(self.test_loss, 'test')
metrics.update(loss_dict)
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be a sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
def _parse_losses(losses: Dict[str, HistoryBuffer],
stage: str) -> Dict[str, float]:
"""Parses the raw losses of the network.
Args:
losses (dict): Raw losses of the network, keyed by loss name.
stage (str): The stage of loss, e.g., 'val' or 'test'.
Returns:
dict[str, float]: The key is the loss name, and the value is the
average loss.
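Examples:
    A minimal sketch with two recorded losses; every key containing
    ``'loss'`` is summed into ``'<stage>_loss'``::

        cls_buf, bbox_buf = HistoryBuffer(), HistoryBuffer()
        cls_buf.update(0.5)
        cls_buf.update(0.3)
        bbox_buf.update(1.0)
        _parse_losses({'loss_cls': cls_buf, 'loss_bbox': bbox_buf}, 'val')
        # -> {'loss_cls': 0.4, 'loss_bbox': 1.0, 'val_loss': 1.4}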
"""
all_loss = 0
loss_dict: Dict[str, float] = dict()
for loss_name, loss_value in losses.items():
avg_loss = loss_value.mean()
loss_dict[loss_name] = avg_loss
if 'loss' in loss_name:
all_loss += avg_loss
loss_dict[f'{stage}_loss'] = all_loss
return loss_dict
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
"""Update and record the losses of the network.
Args:
outputs (list): The outputs of the network.
losses (dict): The losses of the network.
Returns:
list: The updated outputs of the network.
dict: The updated losses of the network.
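Examples:
    A minimal sketch: when the last element of ``outputs`` is a
    ``BaseDataElement`` whose only key is ``'loss'``, it is popped and
    its values are recorded into ``HistoryBuffer`` objects::

        pred = BaseDataElement()
        loss_sample = BaseDataElement(loss=dict(loss_cls=torch.tensor(0.5)))
        outputs, losses = _update_losses([pred, loss_sample], dict())
        # outputs == [pred]; losses == {'loss_cls': HistoryBuffer(...)}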
"""
if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
for loss_name, loss_value in loss.items():
if loss_name not in losses:
losses[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
losses[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
losses[loss_name].update(loss_value_i.item())
return outputs, losses