Hook programming is a programming pattern in which a mount point is set in one or more locations of a program. When the program runs to a mount point, all methods registered to it at runtime are automatically called. Hook programming can increase the flexibility and extensibility of the program, since users can register custom methods to the mount point to be called without modifying the code in the program.
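The pattern can be shown in a few lines of plain Python. This is a minimal sketch, not how MMEngine implements it. `MountPoints` and `train` are hypothetical names, and mount points are identified by plain strings.

```python
from collections import defaultdict


class MountPoints:
    """A tiny registry mapping mount-point names to callbacks (illustrative only)."""

    def __init__(self):
        self._hooks = defaultdict(list)

    def register(self, point, fn):
        # Callbacks registered to the same mount point run in registration order.
        self._hooks[point].append(fn)

    def call(self, point, *args, **kwargs):
        # Invoke every callback registered at this mount point.
        for fn in self._hooks[point]:
            fn(*args, **kwargs)


def train(mount_points, num_iters):
    """A toy training loop that exposes two mount points."""
    mount_points.call('before_train')
    for i in range(num_iters):
        # ... one training step would happen here ...
        mount_points.call('after_train_iter', i)


events = []
hooks = MountPoints()
hooks.register('before_train', lambda: events.append('start'))
hooks.register('after_train_iter', lambda i: events.append(f'iter {i}'))
train(hooks, num_iters=2)
print(events)  # ['start', 'iter 0', 'iter 1']
```

The training loop itself never changes: new behavior is added purely by registering more callbacks.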

Built-in Hooks

MMEngine encapsulates many utilities as built-in hooks. These hooks fall into two categories: default hooks and custom hooks. Default hooks are registered with the Runner automatically, while custom hooks are registered by the user on demand.

Each hook has a corresponding priority. At each mount point, the Runner calls hooks with higher priority earlier; a smaller number means a higher priority. Hooks that share the same priority are called in their registration order. The priority levels are as follows.

  • HIGHEST (0)

  • VERY_HIGH (10)

  • HIGH (30)

  • ABOVE_NORMAL (40)

  • NORMAL (50)

  • BELOW_NORMAL (60)

  • LOW (70)

  • VERY_LOW (90)

  • LOWEST (100)
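The ordering rule can be sketched with a hypothetical `IntEnum` holding the numeric levels (MMEngine's own priority enum also defines ABOVE_NORMAL = 40 and BELOW_NORMAL = 60, which the examples later in this page use). This sketch assumes the call order is equivalent to a stable sort by numeric value; `MyHookA` and `MyHookB` are made-up hook names.

```python
from enum import IntEnum


class Priority(IntEnum):
    """Priority levels; a smaller value means a higher priority."""
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    ABOVE_NORMAL = 40
    NORMAL = 50
    BELOW_NORMAL = 60
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def call_order(registered):
    """Return hook names in the order they would be called at a mount point.

    `registered` is a list of (name, priority) pairs in registration order.
    sorted() is stable, so hooks sharing a priority keep their registration order.
    """
    return [name for name, _ in sorted(registered, key=lambda item: item[1])]


hooks = [
    ('CheckpointHook', Priority.VERY_LOW),
    ('MyHookA', Priority.NORMAL),
    ('RuntimeInfoHook', Priority.VERY_HIGH),
    ('MyHookB', Priority.NORMAL),
]
print(call_order(hooks))
# ['RuntimeInfoHook', 'MyHookA', 'MyHookB', 'CheckpointHook']
```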

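Default hooks

  • RuntimeInfoHook, priority VERY_HIGH (10): update runtime information into the message hub

  • IterTimerHook, priority NORMAL (50): update the time spent during each iteration into the message hub

  • DistSamplerSeedHook, priority NORMAL (50): ensure that the shuffle of the distributed sampler takes effect

  • LoggerHook, priority BELOW_NORMAL (60): collect logs from different components of the Runner and write them to the terminal, JSON files, TensorBoard, wandb, etc.

  • ParamSchedulerHook, priority LOW (70): update some hyper-parameters of the optimizer

  • CheckpointHook, priority VERY_LOW (90): save checkpoints periodically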

Custom hooks

  • EMAHook, priority NORMAL (50): apply an Exponential Moving Average (EMA) to the model during training

  • EmptyCacheHook, priority NORMAL (50): release all unoccupied cached GPU memory during training

  • SyncBuffersHook, priority NORMAL (50): synchronize model buffers at the end of each epoch



It is not recommended to modify the priority of the default hooks, as hooks with lower priority may depend on hooks with higher priority. For example, CheckpointHook needs to have a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. Also, the priority of custom hooks defaults to NORMAL (50).

The two types of hooks are configured differently in the Runner: the configuration of default hooks is passed to the default_hooks parameter of the Runner, and the configuration of custom hooks is passed to the custom_hooks parameter, as follows.

from mmengine.runner import Runner

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1),
)
custom_hooks = [dict(type='EmptyCacheHook')]
# `...` stands for the other Runner arguments (model, work_dir, etc.)
runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)




CheckpointHook saves checkpoints at a given interval. In distributed training, only the master process saves checkpoints. The main features of CheckpointHook are as follows.

  • Save checkpoints by interval, and support saving them by epoch or iteration

  • Save the most recent checkpoints

  • Save the best checkpoints

  • Specify the path to save the checkpoints

  • Make checkpoints for publishing

  • Control the epoch number or iteration number at which checkpoint saving begins

For more features, please read the CheckpointHook API documentation.

The six features mentioned above are described below.

  • Save checkpoints by interval, and support saving them by epoch or iteration

    Suppose we train for a total of 20 epochs and want to save a checkpoint every 5 epochs. The following configuration achieves this.

    # the default value of by_epoch is True
    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))

    To save checkpoints by iteration instead, set by_epoch to False; with interval=5, a checkpoint is then saved every 5 iterations.

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
  • Save the most recent checkpoints

    If you only want to keep a certain number of checkpoints, set the max_keep_ckpts parameter. When the number of saved checkpoints exceeds max_keep_ckpts, the oldest checkpoints are deleted.

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))

    With the above config and a total of 20 training epochs, checkpoints are saved at epochs 5, 10, 15, and 20. However, epoch_5.pth is deleted at epoch 15 and epoch_10.pth is deleted at epoch 20, so only epoch_15.pth and epoch_20.pth are kept.

  • Save the best checkpoints

    To save the best checkpoint on the validation set during training, set the save_best parameter. If it is set to 'auto', whether the current checkpoint is the best is judged by the first evaluation metric on the validation set (the evaluation metrics returned by the evaluator form an ordered dictionary).

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))

    You can also set save_best directly to the name of an evaluation metric. For example, in a classification task, save_best='top-1' judges whether the current checkpoint is the best by the value of 'top-1'.

    Besides save_best, the parameters rule, greater_keys and less_keys also relate to saving the best checkpoint; they indicate whether a larger or a smaller metric value is better. For example, with save_best='top-1', you can set rule='greater' to indicate that a larger value means a better checkpoint.

  • Specify the path to save the checkpoints

    The checkpoints are saved in work_dir by default, but the path can be changed by setting out_dir.

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
  • Make checkpoints for publishing

    To automatically generate publishable checkpoints after training (with unnecessary keys, such as the optimizer state, removed), set the published_keys parameter to choose which information to keep. Note that you also need to set save_best or save_last for publishable checkpoints to be generated: save_best produces a publishable version of the best checkpoint, and save_last produces a publishable version of the final checkpoint. Both parameters can be set at the same time.

    # accuracy: larger is better, hence rule='greater'
    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='greater', published_keys=['meta', 'state_dict']))
  • Control the epoch number or iteration number at which checkpoint saving begins

    To control the epoch or iteration at which checkpoint saving begins, set the save_begin parameter. It defaults to 0, which means checkpoints are saved from the beginning of training. For example, if you train for a total of 10 epochs with interval=1 and save_begin=5, checkpoints are saved at epochs 5, 6, 7, 8, 9, and 10. With interval=2, checkpoints are saved only at epochs 5, 7, and 9 (if save_last is enabled, the final epoch 10 is saved as well).

    default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
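The interval, save_begin, and max_keep_ckpts rules above can be checked with a small simulation. This is a hedged sketch, not CheckpointHook's code: `saved_epochs` and `kept_checkpoints` are hypothetical helpers, epochs are counted from 1, and a save_last flag is modeled but left off by default so the results line up with the numbers in the text.

```python
def saved_epochs(max_epochs, interval, save_begin=0, save_last=False):
    """Return the 1-based epochs at which a checkpoint would be written.

    Hypothetical helper: an epoch is saved when it lies k * interval epochs
    after save_begin (k >= 0); with save_last=True the final epoch is saved too.
    """
    epochs = []
    for epoch in range(1, max_epochs + 1):
        offset = epoch - save_begin
        on_interval = offset >= 0 and offset % interval == 0
        if on_interval or (save_last and epoch == max_epochs):
            epochs.append(epoch)
    return epochs


def kept_checkpoints(epochs, max_keep_ckpts):
    """Simulate max_keep_ckpts: once the limit is exceeded, delete the oldest file.

    In this sketch a non-positive max_keep_ckpts keeps every checkpoint.
    """
    kept = []
    for epoch in epochs:
        kept.append(f'epoch_{epoch}.pth')
        if max_keep_ckpts > 0 and len(kept) > max_keep_ckpts:
            kept.pop(0)  # remove the oldest checkpoint
    return kept


print(saved_epochs(20, interval=5))                          # [5, 10, 15, 20]
print(kept_checkpoints(saved_epochs(20, interval=5), 2))     # ['epoch_15.pth', 'epoch_20.pth']
print(saved_epochs(10, interval=1, save_begin=5))            # [5, 6, 7, 8, 9, 10]
print(saved_epochs(10, interval=2, save_begin=5))            # [5, 7, 9]
print(saved_epochs(10, interval=2, save_begin=5, save_last=True))  # [5, 7, 9, 10]
```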

LoggerHook collects logs from different components of the Runner and writes them to the terminal, JSON files, TensorBoard, wandb, etc.

To output (or save) the logs every 20 iterations, set the interval parameter as follows.

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

If you are interested in how MMEngine manages logging, refer to the logging documentation.


ParamSchedulerHook iterates through all the optimizer parameter schedulers of the Runner and calls their step methods one by one to update the optimizer's hyper-parameters. See Parameter Schedulers for more details about what parameter schedulers are.

ParamSchedulerHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
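What "calling each scheduler's step method" does can be sketched without PyTorch. `ToyOptimizer`, `ToyStepLR`, and `after_train_epoch` are hypothetical stand-ins; this sketch only models schedulers stepped once per epoch, whereas in MMEngine whether a scheduler is stepped per epoch or per iteration depends on how it is configured.

```python
class ToyOptimizer:
    """Stand-in for an optimizer: holds one hyper-parameter group."""

    def __init__(self, lr):
        self.param_groups = [{'lr': lr}]


class ToyStepLR:
    """Hypothetical scheduler: multiply lr by `gamma` every `step_size` steps."""

    def __init__(self, optimizer, step_size, gamma):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.num_steps = 0

    def step(self):
        self.num_steps += 1
        if self.num_steps % self.step_size == 0:
            for group in self.optimizer.param_groups:
                group['lr'] *= self.gamma


def after_train_epoch(schedulers):
    """Roughly what a ParamSchedulerHook-like hook does: step every scheduler in order."""
    for scheduler in schedulers:
        scheduler.step()


optim = ToyOptimizer(lr=0.1)
schedulers = [ToyStepLR(optim, step_size=2, gamma=0.5)]
for epoch in range(4):
    after_train_epoch(schedulers)
print(optim.param_groups[0]['lr'])  # 0.025
```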


IterTimerHook records the time spent loading data and the time taken by each iteration.

IterTimerHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
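The split between data-loading time and total iteration time can be sketched with `time.perf_counter`. This is an approximation of the idea, not IterTimerHook's code: `timed_loop` is a hypothetical helper, and the comments only indicate which mount point each measurement roughly corresponds to.

```python
import time


def timed_loop(batches, train_step):
    """Return a list of (data_time, iter_time) pairs in seconds."""
    records = []
    last = time.perf_counter()  # like before_train_epoch: start the clock
    for batch in batches:  # fetching the batch is the data-loading phase
        data_time = time.perf_counter() - last  # like before_train_iter
        train_step(batch)
        now = time.perf_counter()
        iter_time = now - last  # like after_train_iter: loading + training
        last = now  # restart the clock for the next batch
        records.append((data_time, iter_time))
    return records


records = timed_loop(range(3), lambda batch: sum(range(10000)))
for data_time, iter_time in records:
    print(f'data_time={data_time:.6f}s  time={iter_time:.6f}s')
```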


DistSamplerSeedHook calls the set_epoch method of the sampler before each training epoch during distributed training, so that the shuffle operation takes effect and the data order changes from epoch to epoch.

DistSamplerSeedHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
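Why setting the epoch matters can be sketched with a hypothetical sampler whose shuffle is seeded with `seed + epoch`, similar in spirit to PyTorch's `DistributedSampler`. `ToyShuffleSampler` is made up for illustration; without the `set_epoch` call, every epoch would replay the same order.

```python
import random


class ToyShuffleSampler:
    """Hypothetical sampler whose shuffle order depends on (seed, epoch)."""

    def __init__(self, size, seed=0):
        self.size = size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called by the hook before each epoch so the order can change.
        self.epoch = epoch

    def __iter__(self):
        rng = random.Random(self.seed + self.epoch)
        indices = list(range(self.size))
        rng.shuffle(indices)
        return iter(indices)


sampler = ToyShuffleSampler(size=8, seed=42)
orders = []
for epoch in range(3):
    sampler.set_epoch(epoch)  # what the hook does before each training epoch
    orders.append(list(sampler))
print(orders)
```

Because the seed is derived from the epoch, all processes that share the same seed still agree on the order within an epoch.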


RuntimeInfoHook will update the current runtime information (e.g. epoch, iter, max_epochs, max_iters, lr, metrics, etc.) to the message hub at different mount points in the Runner so that other modules without access to the Runner can obtain this information.

RuntimeInfoHook is registered to the Runner by default and has no configurable parameters, so there is no need to configure it.
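The point of the message hub is that code without a Runner reference can still read runtime information. A hedged sketch with a hypothetical `ToyMessageHub` (its method names loosely resemble MMEngine's MessageHub, but this is not that class):

```python
class ToyMessageHub:
    """Hypothetical global key-value store standing in for a message hub."""

    _instance = None

    def __init__(self):
        self.runtime_info = {}

    @classmethod
    def get_current_instance(cls):
        # One shared instance, so any module can reach it without the Runner.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def update_info(self, key, value):
        self.runtime_info[key] = value

    def get_info(self, key, default=None):
        return self.runtime_info.get(key, default)


def publish_runtime_info(hub, epoch, it, max_iters):
    # What a RuntimeInfoHook-like hook might publish at a mount point.
    hub.update_info('epoch', epoch)
    hub.update_info('iter', it)
    hub.update_info('max_iters', max_iters)


def some_module_far_from_the_runner():
    # This function never sees the Runner, yet can read the current iteration.
    return ToyMessageHub.get_current_instance().get_info('iter')


publish_runtime_info(ToyMessageHub.get_current_instance(), epoch=0, it=7, max_iters=100)
print(some_module_far_from_the_runner())  # 7
```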


EMAHook performs an exponential moving average operation on the model during training, with the aim of improving the robustness of the model. Note that the model generated by exponential moving average is only used for validation and testing, and does not affect training.

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)

EMAHook uses ExponentialMovingAverage by default. Other averaging strategies, such as StochasticWeightAverage and MomentumAnnealingEMA, can be selected by setting ema_type.

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

See EMAHook API Reference for more usage.
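The averaging itself can be sketched on plain floats. This assumes the update `averaged = (1 - momentum) * averaged + momentum * current`, where a small momentum makes the average move slowly; check the EMAHook API reference for the exact formula and default momentum. `ema_update` is a hypothetical helper. Note that the training weights (`params`) are never modified by the update, which is why the averaged model does not affect training.

```python
def ema_update(averaged, current, momentum):
    """One EMA step on a dict of scalar parameters."""
    return {name: (1 - momentum) * averaged[name] + momentum * current[name]
            for name in averaged}


params = {'w': 0.0}
ema_params = dict(params)  # the EMA copy starts from the initial weights
for step in range(1, 4):
    params = {'w': float(step)}  # pretend training moved the weight to 1, 2, 3
    ema_params = ema_update(ema_params, params, momentum=0.5)
print(params['w'], ema_params['w'])  # 3.0 2.125
```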


EmptyCacheHook calls torch.cuda.empty_cache() to release all unoccupied cached GPU memory. When memory is released can be controlled with the before_epoch, after_iter, and after_epoch parameters, which release it before each epoch, after each iteration, and after each epoch, respectively.

# The release operation is performed at the end of each epoch
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)


SyncBuffersHook synchronizes the model's buffers, e.g. running_mean and running_var of BN layers, at the end of each epoch during distributed training.

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
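Conceptually, synchronization makes every process hold the same buffer values. The sketch below simulates this in a single process, assuming the buffers are averaged across processes (an all-reduce with a mean); `sync_buffers` is a hypothetical helper, and a real implementation communicates between processes instead of reading a list.

```python
def sync_buffers(per_rank_buffers):
    """Average each named buffer across ranks (a stand-in for an all-reduce mean).

    `per_rank_buffers` is a list with one {name: value} dict per process.
    Every rank ends up holding the same averaged values.
    """
    world_size = len(per_rank_buffers)
    names = per_rank_buffers[0].keys()
    averaged = {name: sum(buf[name] for buf in per_rank_buffers) / world_size
                for name in names}
    return [dict(averaged) for _ in range(world_size)]


# Two processes saw different data, so their BN running statistics drifted apart.
ranks = [
    {'bn.running_mean': 0.2, 'bn.running_var': 1.0},
    {'bn.running_mean': 0.4, 'bn.running_var': 3.0},
]
synced = sync_buffers(ranks)
print(synced[0])
```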

Customize Your Hooks

If the built-in hooks provided by MMEngine do not cover your demands, you are encouraged to customize your own hooks by simply inheriting the base hook class and overriding the corresponding mount point methods.

For example, to check during training whether the loss value is valid, i.e. finite (neither infinite nor NaN), you can simply override the after_train_iter method as below. The method runs after each training iteration and checks the loss every interval iterations.

import torch
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()  # register so that type='CheckInvalidLossHook' works in configs
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Check the loss every ``self.interval`` training iterations.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                'loss became infinite or NaN!'
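To see which iterations such a check fires on, here is a framework-free sketch. It assumes every_n_train_iters counts iterations from 1, i.e. it is true when (iteration index + 1) is a multiple of n; `every_n_iters` and `check_losses` are hypothetical helpers that work on plain floats instead of tensors.

```python
import math


def every_n_iters(current_iter, n, start=0):
    """Return True on every n-th iteration, counting iterations from 1.

    `current_iter` is 0-based, as a loop counter usually is.
    """
    dividend = current_iter + 1 - start
    return dividend >= 0 and n > 0 and dividend % n == 0


def check_losses(losses, interval):
    """Raise if a checked loss is infinite or NaN; return the checked iterations."""
    checked = []
    for i, loss in enumerate(losses):
        if every_n_iters(i, interval):
            checked.append(i)
            if not math.isfinite(loss):
                raise ValueError(f'loss became infinite or NaN at iteration {i}')
    return checked


print(check_losses([1.0] * 6, interval=3))  # [2, 5]
```

A NaN that appears between two checks goes unnoticed, so a smaller interval trades some overhead for earlier detection.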

We simply pass the hook config to the custom_hooks parameter of the Runner, which will register the hooks when the Runner is initialized.

from mmengine.runner import Runner

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()  # start training

The loss value is then checked every 50 training iterations.

Note that the priority of a custom hook is NORMAL (50) by default. To change it, set the priority key in the config.

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]

You can also set the priority when defining the class.

class CheckInvalidLossHook(Hook):
    priority = 'ABOVE_NORMAL'