Shortcuts

Source code for mmengine.hooks.checkpoint_hook

# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os.path as osp
import pickle
from collections import deque
from math import inf
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.dist import is_main_process, master_only
from mmengine.fileio import FileClient, get_file_backend
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
from .hook import Hook

# Type alias for the ``data_batch`` argument accepted by iteration-level hook
# methods in this module: a dict, tuple or list, or ``None`` when absent.
DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Defaults to True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        save_param_scheduler (bool): Whether to save param_scheduler
            state_dict in the checkpoint. It is usually used for resuming
            experiments. Defaults to True.
        out_dir (str, Path, Optional): The root directory to save checkpoints.
            If not specified, ``runner.work_dir`` will be used by default.
            If specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            For example, if the input ``out_dir`` is ``./tmp`` and
            ``runner.work_dir`` is ``./work_dir/cur_exp``, then the ckpt will
            be saved in ``./tmp/cur_exp``. Defaults to None.
        max_keep_ckpts (int): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Defaults to -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be
            saved regardless of interval. Defaults to True.
        save_best (str, List[str], optional): If a metric is specified, it
            would measure the best checkpoint during evaluation. If a list of
            metrics is passed, it would measure a group of best checkpoints
            corresponding to the passed metrics. The information about best
            checkpoint(s) would be saved in ``runner.message_hub`` to keep
            best score value and best checkpoint path, which will be also
            loaded when resuming checkpoint. Options are the evaluation
            metrics on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for
            bbox detection and instance segmentation. ``AR@100`` for proposal
            recall. If ``save_best`` is ``auto``, the first key of the
            returned ``OrderedDict`` result will be used. Defaults to None.
        rule (str, List[str], optional): Comparison rule for best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc',
            'top', etc. will be inferred by 'greater' rule. Keys containing
            'loss' will be inferred by 'less' rule. If ``save_best`` is a list
            of metrics and ``rule`` is a str, all metrics in ``save_best``
            will share the comparison rule. If ``save_best`` and ``rule`` are
            both lists, their length must be the same, and metrics in
            ``save_best`` will use the corresponding comparison rule in
            ``rule``. Options are 'greater', 'less', None and list which
            contains 'greater' and 'less'. Defaults to None.
        greater_keys (List[str], optional): Metric keys that will be
            inferred by 'greater' comparison rule. If ``None``,
            _default_greater_keys will be used. Defaults to None.
        less_keys (List[str], optional): Metric keys that will be
            inferred by 'less' comparison rule. If ``None``,
            _default_less_keys will be used. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None. It will be deprecated in future. Please use
            ``backend_args`` instead.
        filename_tmpl (str, optional): String template to indicate checkpoint
            name. If specified, must contain one and only one "{}", which
            will be replaced with ``epoch + 1`` if ``by_epoch=True`` else
            ``iteration + 1``.
            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
            accordingly.
        backend_args (dict, optional): Arguments to instantiate the
            prefix of uri corresponding backend. Defaults to None.
            `New in version 0.2.0.`
        published_keys (str, List[str], optional): If ``save_last`` is
            ``True`` or ``save_best`` is not ``None``, it will automatically
            publish model with keys in the list after training.
            Defaults to None.
            `New in version 0.7.1.`
        save_begin (int): Control the epoch number or iteration number
            at which checkpoint saving begins. Defaults to 0, which means
            saving at the beginning.
            `New in version 0.8.3.`

    Examples:
        >>> # Save best based on single metric
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less')
        >>> # Save best based on multi metrics with the same comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['acc', 'mIoU'], rule='greater')
        >>> # Save best based on multi metrics with different comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['FID', 'IS'], rule=['less', 'greater'])
        >>> # Save best based on single metric and publish model after training
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less', published_keys=['meta', 'state_dict'])
    """
    out_dir: str

    priority = 'VERY_LOW'

    # logic to save best checkpoints
    # Since the key for determining greater or less is related to the
    # downstream tasks, downstream repositories may need to overwrite
    # the following inner variables accordingly.

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 save_param_scheduler: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 save_best: Union[str, List[str], None] = None,
                 rule: Union[str, List[str], None] = None,
                 greater_keys: Optional[Sequence[str]] = None,
                 less_keys: Optional[Sequence[str]] = None,
                 file_client_args: Optional[dict] = None,
                 filename_tmpl: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 published_keys: Union[str, List[str], None] = None,
                 save_begin: int = 0,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_param_scheduler = save_param_scheduler
        self.out_dir = out_dir  # type: ignore
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Extra keyword arguments are forwarded to ``runner.save_checkpoint``
        # by ``_save_checkpoint_with_step``.
        self.args = kwargs

        if file_client_args is not None:
            print_log(
                '"file_client_args" will be deprecated in future. '
                'Please use "backend_args" instead',
                logger='current',
                level=logging.WARNING)
            if backend_args is not None:
                raise ValueError(
                    '"file_client_args" and "backend_args" cannot be set '
                    'at the same time.')

        self.file_client_args = file_client_args
        self.backend_args = backend_args

        if filename_tmpl is None:
            if self.by_epoch:
                self.filename_tmpl = 'epoch_{}.pth'
            else:
                self.filename_tmpl = 'iter_{}.pth'
        else:
            self.filename_tmpl = filename_tmpl

        # save best logic
        assert (isinstance(save_best, str) or is_list_of(save_best, str)
                or (save_best is None)), (
                    '"save_best" should be a str or list of str or None, '
                    f'but got {type(save_best)}')

        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, (
                    'Only support one "auto" in "save_best" list.')
            assert len(save_best) == len(
                set(save_best)), ('Find duplicate element in "save_best".')
        else:
            # convert str to list[str]
            if save_best is not None:
                save_best = [save_best]  # type: ignore # noqa: F401
        self.save_best = save_best

        # rule logic
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or (rule is None)), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
        if isinstance(rule, list):
            # check the length of rule list
            assert len(rule) in [
                1,
                len(self.save_best)  # type: ignore
            ], ('Number of "rule" must be 1 or the same as number of '
                f'"save_best", but got {len(rule)}.')
        else:
            # convert str/None to list
            rule = [rule]  # type: ignore # noqa: F401

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )  # type: ignore
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys  # type: ignore

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )  # type: ignore
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys  # type: ignore

        if self.save_best is not None:
            self.is_better_than: Dict[str, Callable] = dict()
            self._init_rule(rule, self.save_best)
            # A single tracked metric uses ``best_ckpt_path``; several
            # metrics use ``best_ckpt_path_dict`` keyed by metric name.
            if len(self.key_indicators) == 1:
                self.best_ckpt_path: Optional[str] = None
            else:
                self.best_ckpt_path_dict: Dict = dict()

        # published keys
        if not (isinstance(published_keys, str)
                or is_seq_of(published_keys, str) or published_keys is None):
            raise TypeError(
                '"published_keys" should be a str or a sequence of str or '
                f'None, but got {type(published_keys)}')

        if isinstance(published_keys, str):
            published_keys = [published_keys]
        elif isinstance(published_keys, (list, tuple)):
            assert len(published_keys) == len(set(published_keys)), (
                'Find duplicate elements in "published_keys".')
        self.published_keys = published_keys

        self.last_ckpt = None
        if save_begin < 0:
            # The f-prefix is required so the offending value is reported
            # instead of the literal text ``{save_begin}``.
            raise ValueError(
                f'save_begin should not be less than 0, but got {save_begin}')
        self.save_begin = save_begin

    def before_train(self, runner) -> None:
        """Finish all operations, related to checkpoint.

        This function will get the appropriate file client, and the directory
        to save these checkpoints of the model.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is None:
            self.out_dir = runner.work_dir

        # If self.file_client_args is None, self.file_client will not
        # used in CheckpointHook. To avoid breaking backward compatibility,
        # it will not be removed until the release of MMEngine1.0
        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        if self.file_client_args is None:
            self.file_backend = get_file_backend(
                self.out_dir, backend_args=self.backend_args)
        else:
            self.file_backend = self.file_client

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_backend.join_path(
                self.out_dir, basename)  # type: ignore  # noqa: E501

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

        # Restore the best-checkpoint path(s) recorded in the message hub,
        # falling back to ``None`` when nothing was recorded.
        if self.save_best is not None:
            if len(self.key_indicators) == 1:
                if 'best_ckpt' not in runner.message_hub.runtime_info:
                    self.best_ckpt_path = None
                else:
                    self.best_ckpt_path = runner.message_hub.get_info(
                        'best_ckpt')
            else:
                for key_indicator in self.key_indicators:
                    best_ckpt_name = f'best_ckpt_{key_indicator}'
                    if best_ckpt_name not in runner.message_hub.runtime_info:
                        self.best_ckpt_path_dict[key_indicator] = None
                    else:
                        self.best_ckpt_path_dict[
                            key_indicator] = runner.message_hub.get_info(
                                best_ckpt_name)

        if self.max_keep_ckpts > 0:
            keep_ckpt_ids = []
            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')

                # Drop the oldest recorded checkpoints until the record fits
                # within ``max_keep_ckpts``.
                while len(keep_ckpt_ids) > self.max_keep_ckpts:
                    step = keep_ckpt_ids.pop(0)
                    if is_main_process():
                        path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(step))
                        if self.file_backend.isfile(path):
                            self.file_backend.remove(path)
                        elif self.file_backend.isdir(path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(path)

            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
                                              self.max_keep_ckpts)

    def after_train_epoch(self, runner) -> None:
        """Save the checkpoint and synchronize buffers after each epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs which start at ``self.save_begin``
        # 2. reach the last epoch of training
        if self.every_n_epochs(runner, self.interval, self.save_begin) or (
                self.save_last and self.is_last_train_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            self._save_checkpoint(runner)

    def after_val_epoch(self, runner, metrics):
        """Save the checkpoint and synchronize buffers after each evaluation
        epoch.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics
        """
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return

        self._save_best_checkpoint(runner, metrics)

    def after_train(self, runner) -> None:
        """Publish the checkpoint after training.

        Nothing is published when ``published_keys`` is ``None``. Otherwise
        the last checkpoint (if ``save_last`` is set and one was saved) and
        every recorded best checkpoint are published.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.published_keys is None:
            return

        if self.save_last and self.last_ckpt is not None:
            self._publish_model(runner, self.last_ckpt)

        if getattr(self, 'best_ckpt_path', None) is not None:
            self._publish_model(runner, str(self.best_ckpt_path))
        if getattr(self, 'best_ckpt_path_dict', None) is not None:
            for best_ckpt in self.best_ckpt_path_dict.values():
                # ``before_train`` initialises entries to ``None``; they stay
                # ``None`` until a best checkpoint is saved for that metric,
                # so there is nothing to publish for them.
                if best_ckpt is not None:
                    self._publish_model(runner, best_ckpt)

    @master_only
    def _publish_model(self, runner, ckpt_path: str) -> None:
        """Remove unnecessary keys from ckpt_path and save the new checkpoint.

        The published file is written next to ``ckpt_path`` with the first
        eight hex digits of the SHA-256 of the pickled, filtered checkpoint
        appended to the file stem.

        Args:
            runner (Runner): The runner of the training process.
            ckpt_path (str): The checkpoint path that ought to be published.
        """
        from mmengine.runner import save_checkpoint
        from mmengine.runner.checkpoint import _load_checkpoint
        checkpoint = _load_checkpoint(ckpt_path)
        assert self.published_keys is not None
        removed_keys = []
        # Iterate over a snapshot of the keys because entries are popped
        # while looping.
        for key in list(checkpoint.keys()):
            if key not in self.published_keys:
                removed_keys.append(key)
                checkpoint.pop(key)
        if removed_keys:
            print_log(
                f'Key {removed_keys} will be removed because they are not '
                'found in published_keys. If you want to keep them, '
                f'please set `{removed_keys}` in published_keys',
                logger='current')
        # pickle is only used here to serialise the checkpoint for hashing.
        checkpoint_data = pickle.dumps(checkpoint)
        sha = hashlib.sha256(checkpoint_data).hexdigest()
        final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
        save_checkpoint(checkpoint, final_path)
        print_log(
            f'The checkpoint ({ckpt_path}) is published to '
            f'{final_path}.',
            logger='current')

    def _save_checkpoint_with_step(self, runner, step, meta):
        """Save a checkpoint named after ``step`` and prune outdated ones.

        Args:
            runner (Runner): The runner of the training process.
            step (int): The number substituted into ``self.filename_tmpl``.
            meta (dict): Meta information passed to
                ``runner.save_checkpoint``.
        """
        # remove other checkpoints before save checkpoint to make the
        # self.keep_ckpt_ids are saved as expected
        if self.max_keep_ckpts > 0:
            # _save_checkpoint and _save_best_checkpoint may call this
            # _save_checkpoint_with_step in one epoch, so a step that is
            # already the newest recorded one must not be recorded twice.
            already_recorded = (
                len(self.keep_ckpt_ids) > 0
                and self.keep_ckpt_ids[-1] == step)
            if not already_recorded:
                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                    _step = self.keep_ckpt_ids.popleft()
                    if is_main_process():
                        ckpt_path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(_step))

                        if self.file_backend.isfile(ckpt_path):
                            self.file_backend.remove(ckpt_path)
                        elif self.file_backend.isdir(ckpt_path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(ckpt_path)

                self.keep_ckpt_ids.append(step)
                runner.message_hub.update_info('keep_ckpt_ids',
                                               list(self.keep_ckpt_ids))

        ckpt_filename = self.filename_tmpl.format(step)
        self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                     ckpt_filename)
        runner.message_hub.update_info('last_ckpt', self.last_ckpt)

        runner.save_checkpoint(
            self.out_dir,
            ckpt_filename,
            self.file_client_args,
            save_optimizer=self.save_optimizer,
            save_param_scheduler=self.save_param_scheduler,
            meta=meta,
            by_epoch=self.by_epoch,
            backend_args=self.backend_args,
            **self.args)

        # Model parallel-like training should involve pulling sharded states
        # from all ranks, but skip the following procedure.
        if not is_main_process():
            return

        # Record the path of the newest checkpoint in the work directory.
        save_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(save_file, 'w') as f:
            f.write(self.last_ckpt)  # type: ignore

    def _save_checkpoint(self, runner) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save the best checkpoint(s) according to ``metrics`` and delete
        the previous best checkpoint(s) they replace.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        meta = dict(epoch=runner.epoch, iter=runner.iter)

        # handle auto in self.key_indicators and self.rules before the loop
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        best_ckpt_updated = False
        # save best logic
        # get score from messagehub
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            key_score = metrics[key_indicator]

            if len(self.key_indicators) == 1:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_ckpt_path = self.best_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_ckpt_updated = True

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            if best_ckpt_path and is_main_process():
                is_removed = False
                if self.file_backend.isfile(best_ckpt_path):
                    self.file_backend.remove(best_ckpt_path)
                    is_removed = True
                elif self.file_backend.isdir(best_ckpt_path):
                    # checkpoints saved by deepspeed are directories
                    self.file_backend.rmtree(best_ckpt_path)
                    is_removed = True

                if is_removed:
                    runner.logger.info(
                        f'The previous best checkpoint {best_ckpt_path} '
                        'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            # Replace illegal characters for filename with `_`
            best_ckpt_name = best_ckpt_name.replace('/', '_')
            if len(self.key_indicators) == 1:
                self.best_ckpt_path = self.file_backend.join_path(  # type: ignore # noqa: E501
                    self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(runtime_best_ckpt_key,
                                               self.best_ckpt_path)
            else:
                self.best_ckpt_path_dict[
                    key_indicator] = self.file_backend.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(
                    runtime_best_ckpt_key,
                    self.best_ckpt_path_dict[key_indicator])
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                meta=meta,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

        # save checkpoint again to update the best_score and best_ckpt stored
        # in message_hub because the checkpoint saved in `after_train_epoch`
        # or `after_train_iter` stage only keep the previous best checkpoint
        # not the current best checkpoint which causes the current best
        # checkpoint can not be removed when resuming training.
        if best_ckpt_updated and self.last_ckpt is not None:
            self._save_checkpoint_with_step(runner, cur_time, meta)

    def _init_rule(self, rules, key_indicators) -> None:
        """Initialize rule, key_indicator, comparison_func, and best score. If
        key_indicator is a list of string and rule is a string, all metric in
        the key_indicator will share the same rule.

        Here is the rule to determine which rule is used for key indicator when
        the rule is not specific (note that the key indicator matching is case-
        insensitive):

        1. If the key indicator is in ``self.greater_keys``, the rule
            will be specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule
            will be specified as 'less'.
        3. Or if any one item in ``self.greater_keys`` is a substring of
            key_indicator, the rule will be specified as 'greater'.
        4. Or if any one item in ``self.less_keys`` is a substring of
            key_indicator, the rule will be specified as 'less'.

        Args:
            rules (List[Optional[str]]): Comparison rules for best score.
                A single-element list is shared by all key indicators.
            key_indicators (List[str]): Key indicators to determine the
                comparison rule.

        Raises:
            KeyError: If a rule is neither ``None`` nor a key of
                ``self.rule_map``.
            ValueError: If a rule is ``None`` and cannot be inferred from
                the key indicator.
        """
        if len(rules) == 1:
            rules = rules * len(key_indicators)

        self.rules = []
        for rule, key_indicator in zip(rules, key_indicators):

            if rule not in self.rule_map and rule is not None:
                raise KeyError('rule must be greater, less or None, '
                               f'but got {rule}.')

            # 'auto' is resolved later, once the metric names are known.
            if rule is None and key_indicator != 'auto':
                # `_lc` here means we use the lower case of keys for
                # case-insensitive matching
                key_indicator_lc = key_indicator.lower()
                greater_keys = {key.lower() for key in self.greater_keys}
                less_keys = {key.lower() for key in self.less_keys}

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError('Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     'must be specified.')
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]
            self.rules.append(rule)

        self.key_indicators = key_indicators

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Save the checkpoint and synchronize buffers after each iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        #       which start at ``self.save_begin``
        # 2. reach the last iteration of training
        if self.every_n_train_iters(runner, self.interval,
                                    self.save_begin) or \
                (self.save_last and self.is_last_train_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            self._save_checkpoint(runner)

© Copyright 2022, mmengine contributors. Revision 2c4516c6.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.