EMAHook
- class mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)
A hook that applies an Exponential Moving Average (EMA) to the model's parameters during training.
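The sketch below shows the kind of averaging step an EMA strategy performs. It uses plain Python, with parameters held in a dict of floats instead of tensors. The `ema_update` helper is hypothetical. The convention that `momentum` is the weight given to the newest value is an assumption made for this sketch. Check the class chosen through `ema_type` for the exact formula it uses.

```python
def ema_update(ema_params, source_params, momentum=0.0002):
    """Blend the current source parameters into the running average in place.

    Assumed convention: ema = (1 - momentum) * ema + momentum * source,
    so a small momentum makes the average change slowly.
    """
    for name, value in source_params.items():
        ema_params[name] = (1 - momentum) * ema_params[name] + momentum * value


# Usage: the average moves toward the latest weights step by step.
ema = {"w": 0.0}
for _ in range(3):
    source = {"w": 1.0}  # pretend the optimizer pushed w to 1.0
    ema_update(ema, source, momentum=0.5)
print(ema["w"])  # 0.875
```

A large `momentum` is used here only to make the drift visible in three steps. Real EMA settings normally use a much smaller value so the average smooths over many iterations.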
Note
EMAHook takes priority over CheckpointHook.
After training, the original (non-averaged) model parameters are actually saved in the ema field.
begin_iter and begin_epoch cannot be set at the same time.
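The rule that `begin_iter` and `begin_epoch` are mutually exclusive can be sketched as a small validation and gating function. `ema_enabled` is a hypothetical helper. A value of 0 is assumed to mean "not set". The boundary convention is also an assumption rather than documented behavior: EMA becomes active once the 0-based counter reaches the threshold.

```python
def ema_enabled(cur_iter, cur_epoch, begin_iter=0, begin_epoch=0):
    """Return True once EMA should be active (hypothetical helper).

    Enforces the documented rule that begin_iter and begin_epoch cannot
    both be set; 0 is treated as "not set" (assumption).
    """
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError("begin_iter and begin_epoch cannot be set at the same time")
    if begin_epoch != 0:
        # Epoch-based start: compare against the current 0-based epoch.
        return cur_epoch >= begin_epoch
    # Iteration-based start (the default begin_iter=0 means "from the start").
    return cur_iter >= begin_iter


print(ema_enabled(cur_iter=50, cur_epoch=0, begin_iter=100))  # False
```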
- Parameters
ema_type (str) – The type of EMA strategy to use. The supported strategies can be found in mmengine.model.averaged_model. Defaults to 'ExponentialMovingAverage'.
strict_load (bool) – Whether to strictly enforce that the keys of state_dict in the checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.
begin_iter (int) – The iteration at which to enable EMAHook. Defaults to 0.
begin_epoch (int) – The epoch at which to enable EMAHook. Defaults to 0.
**kwargs – Keyword arguments passed to subclasses of BaseAveragedModel.
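In an MMEngine-style config, the hook is typically registered through `custom_hooks`. The fragment below is an illustrative sketch. The `momentum` key assumes that the chosen `ema_type` accepts it, since it is forwarded through `**kwargs`. The numeric values are placeholders, not recommendations.

```python
# Illustrative config fragment; values are placeholders.
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',
        momentum=0.0002,   # forwarded via **kwargs to the EMA class (assumed to accept it)
        begin_epoch=5,     # start averaging at epoch 5; begin_iter stays at its default 0
        strict_load=False,
    )
]
```

Only one of `begin_iter` and `begin_epoch` is set here, which matches the restriction in the note above.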
- after_test_epoch(runner, metrics=None)
Recover the source model's parameters from the EMA model after testing.
- after_val_epoch(runner, metrics=None)
Recover the source model's parameters from the EMA model after validation.
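Recovering the source parameters after validation or testing is simple if the earlier load is a swap, because swapping twice restores the original state. The sketch below assumes that design and represents parameters as plain dicts. `swap_parameters` is a hypothetical helper, not part of the documented API.

```python
def swap_parameters(source_params, ema_params):
    """Exchange values between two parameter dicts in place (hypothetical helper)."""
    for name in source_params:
        source_params[name], ema_params[name] = ema_params[name], source_params[name]


model = {"w": 1.0, "b": 0.5}  # weights produced by training
ema = {"w": 0.8, "b": 0.4}    # running average

swap_parameters(model, ema)   # before evaluation: the model now holds EMA weights
assert model == {"w": 0.8, "b": 0.4}
swap_parameters(model, ema)   # after evaluation: training weights are recovered
assert model == {"w": 1.0, "b": 0.5}
```

A swap avoids keeping a third backup copy of the weights. The trade-off is that the two calls must always be paired.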
- before_run(runner)
Create an EMA copy of the model.
- Parameters
runner (Runner) – The runner of the training process.
- Return type
None
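The EMA copy has to be independent of the source model. Otherwise, optimizer updates would leak straight into the average. The sketch below applies `copy.deepcopy` to a hypothetical `TinyModel` stand-in. In MMEngine, the copy is built by the class selected through `ema_type`, so this only illustrates the idea.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a model: just holds a dict of parameters."""

    def __init__(self):
        self.params = {"w": 1.0}


source = TinyModel()
ema_model = copy.deepcopy(source)  # independent copy, created once before training

source.params["w"] = 2.0           # training updates only the source model
print(ema_model.params["w"])       # 1.0: the copy is unaffected
```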
- before_test_epoch(runner)
Load parameter values from the EMA model into the source model before testing.
- Parameters
runner (Runner) – The runner of the training process.
- Return type
None
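A toy hook can put the documented callbacks in order. It shows that testing sees the averaged weights while the training weights survive. `ToyEMAHook` is hypothetical and simplified:

- It operates on a dict instead of a Runner.
- It uses an assumed `momentum` convention.
- Its `after_train_iter` update step is an assumption added to make the example complete.

```python
class ToyEMAHook:
    """Hypothetical, simplified hook illustrating the evaluation swap order."""

    def __init__(self, momentum=0.5):
        self.momentum = momentum
        self.ema = None

    def before_run(self, model):
        # Shallow copy is enough here because the values are plain floats.
        self.ema = dict(model)

    def after_train_iter(self, model):
        # Assumed update rule: ema = (1 - momentum) * ema + momentum * current.
        for k, v in model.items():
            self.ema[k] = (1 - self.momentum) * self.ema[k] + self.momentum * v

    def _swap(self, model):
        for k in model:
            model[k], self.ema[k] = self.ema[k], model[k]

    def before_test_epoch(self, model):
        self._swap(model)  # load EMA values into the model for testing

    def after_test_epoch(self, model):
        self._swap(model)  # recover the training values


model = {"w": 0.0}
hook = ToyEMAHook(momentum=0.5)
hook.before_run(model)
model["w"] = 1.0              # pretend an optimizer step happened
hook.after_train_iter(model)  # ema["w"] becomes 0.5

hook.before_test_epoch(model)
tested = model["w"]           # the EMA value is what gets evaluated
hook.after_test_epoch(model)
print(tested, model["w"])     # 0.5 1.0
```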