Source code for mmengine.hooks.early_stopping_hook
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from math import inf, isfinite
from typing import Optional, Tuple, Union
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
class EarlyStoppingHook(Hook):
"""Early stop the training when the monitored metric reached a plateau.
Args:
monitor (str): The monitored metric key to decide early stopping.
rule (str, optional): Comparison rule. Options are 'greater',
'less'. Defaults to None.
        min_delta (float, optional): Minimum change in the monitored metric
            to qualify as an improvement. Defaults to 0.1.
        strict (bool, optional): Whether to raise an error when `monitor`
            is not found in the evaluation `metrics`. Defaults to False.
        check_finite (bool, optional): Whether to stop training when the
            monitored metric becomes NaN or infinite. Defaults to True.
        patience (int, optional): The number of validation rounds with no
            improvement after which training will be stopped. Defaults to 5.
stopping_threshold (float, optional): Stop training immediately once
the monitored quantity reaches this threshold. Defaults to None.
Note:
`New in version 0.7.0.`
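
    Examples:
        A minimal sketch of enabling the hook through ``custom_hooks`` in a
        runner config. The monitored key depends on the evaluator in use;
        ``coco/bbox_mAP`` below is only an illustrative placeholder::

            custom_hooks = [
                dict(
                    type='EarlyStoppingHook',
                    monitor='coco/bbox_mAP',
                    rule='greater',
                    min_delta=0.005,
                    patience=5)
            ]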
"""
priority = 'LOWEST'
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
_default_less_keys = ['loss']
def __init__(
self,
monitor: str,
rule: Optional[str] = None,
min_delta: float = 0.1,
strict: bool = False,
check_finite: bool = True,
patience: int = 5,
stopping_threshold: Optional[float] = None,
):
self.monitor = monitor
if rule is not None:
if rule not in ['greater', 'less']:
raise ValueError(
'`rule` should be either "greater" or "less", '
f'but got {rule}')
else:
rule = self._init_rule(monitor)
self.rule = rule
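        # Negate `min_delta` under the 'less' rule so that the improvement
        # check in `_check_stop_condition` works for both rules unchanged.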
self.min_delta = min_delta if rule == 'greater' else -1 * min_delta
self.strict = strict
self.check_finite = check_finite
self.patience = patience
self.stopping_threshold = stopping_threshold
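        # Internal state: validation rounds without improvement so far, and
        # the best score observed.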
self.wait_count = 0
self.best_score = -inf if rule == 'greater' else inf
    def _init_rule(self, monitor: str) -> str:
        """Infer the comparison rule from the monitored metric name."""
greater_keys = {key.lower() for key in self._default_greater_keys}
less_keys = {key.lower() for key in self._default_less_keys}
monitor_lc = monitor.lower()
if monitor_lc in greater_keys:
rule = 'greater'
elif monitor_lc in less_keys:
rule = 'less'
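        # Fall back to substring matching: e.g. 'coco/bbox_mAP' lowercases
        # to 'coco/bbox_map', which contains 'map' and implies 'greater'.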
elif any(key in monitor_lc for key in greater_keys):
rule = 'greater'
elif any(key in monitor_lc for key in less_keys):
rule = 'less'
else:
            raise ValueError(f'Cannot infer the rule for `{monitor}`, thus '
                             '`rule` must be specified.')
return rule
    def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]:
        """Check whether training should stop and compose a reason message."""
compare = self.rule_map[self.rule]
stop_training = False
reason_message = ''
if self.check_finite and not isfinite(current_score):
stop_training = True
            reason_message = (f'Monitored metric {self.monitor} = '
                              f'{current_score} is not finite. '
                              f'Previous best value was '
                              f'{self.best_score:.3f}.')
elif self.stopping_threshold is not None and compare(
current_score, self.stopping_threshold):
stop_training = True
self.best_score = current_score
reason_message = (f'Stopping threshold reached: '
f'`{self.monitor}` = {current_score} is '
f'{self.rule} than {self.stopping_threshold}.')
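        # No improvement: the previous best score, offset by `min_delta`,
        # still compares favourably against the current score.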
elif compare(self.best_score + self.min_delta, current_score):
self.wait_count += 1
if self.wait_count >= self.patience:
                reason_message = (f'The monitored metric did not improve '
                                  f'in the last {self.wait_count} records. '
                                  f'Best score: {self.best_score:.3f}.')
stop_training = True
else:
self.best_score = current_score
self.wait_count = 0
return stop_training, reason_message
    def before_run(self, runner) -> None:
"""Check `stop_training` variable in `runner.train_loop`.
Args:
runner (Runner): The runner of the training process.
"""
assert hasattr(runner.train_loop, 'stop_training'), \
'`train_loop` should contain `stop_training` variable.'
    def after_val_epoch(self, runner, metrics):
"""Decide whether to stop the training process.
Args:
runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
"""
if self.monitor not in metrics:
if self.strict:
                raise RuntimeError(
                    'Early stopping conditioned on metric '
                    f'`{self.monitor}` is not available. Please check '
                    f'available metrics {metrics}, or set `strict=False` in '
                    '`EarlyStoppingHook`.')
            warnings.warn(
                'Skipping early stopping since the evaluation results '
                f'({metrics.keys()}) do not include the monitored metric '
                f'(`{self.monitor}`).')
return
current_score = metrics[self.monitor]
stop_training, message = self._check_stop_condition(current_score)
if stop_training:
runner.train_loop.stop_training = True
runner.logger.info(message)