EMAHook
- class mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)
A hook that applies an Exponential Moving Average (EMA) to the model parameters during training.
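To make the averaging concrete, the sketch below shows one EMA step over a flat dictionary of parameters. It is a minimal illustration, not the mmengine implementation: it assumes the update convention averaged = (1 - momentum) * averaged + momentum * source, applied once per training iteration, with plain floats standing in for tensors. The helper name ema_update is hypothetical, and the default momentum of 0.0002 is assumed here to mirror the ExponentialMovingAverage strategy.

```python
from typing import Dict


def ema_update(averaged: Dict[str, float], source: Dict[str, float],
               momentum: float = 0.0002) -> None:
    """Hypothetical helper: one in-place EMA step over a flat parameter dict.

    Assumed convention: averaged <- (1 - momentum) * averaged + momentum * source,
    so a small momentum makes the average move slowly toward the source.
    """
    for name, value in source.items():
        averaged[name] = (1.0 - momentum) * averaged[name] + momentum * value


# Toy example: the source "weight" jumps to 1.0 and stays there.
ema = {"weight": 0.0}
model = {"weight": 1.0}
for _ in range(3):
    ema_update(ema, model, momentum=0.5)
print(ema["weight"])  # 0.875
```

With a realistic momentum such as 0.0002, the averaged weights lag far behind the raw weights, which is what makes them smoother.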
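Note
EMAHook takes priority over CheckpointHook, so it is called before CheckpointHook at each hook point.
In saved checkpoints, the original model parameters are actually stored in the ema field (ema_state_dict), while the state_dict field holds the EMA parameters.
begin_iter and begin_epoch cannot be set at the same time.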
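The second note is easy to misread, so here is a minimal sketch of what it implies for a saved checkpoint. It assumes that, when a checkpoint is written, the EMA model's state dict is stored under an ema_state_dict key whose entries carry a "module." prefix, and that the values of matching keys are then exchanged with those under state_dict. Plain floats stand in for tensors; swap_ema_state_dict is a hypothetical name, not a documented API.

```python
from typing import Dict


def swap_ema_state_dict(checkpoint: Dict[str, Dict[str, float]]) -> None:
    """Hypothetical helper: exchange EMA and source values inside a checkpoint.

    Assumes EMA keys look like 'module.<name>' and source keys look like '<name>'.
    Keys without the 'module.' prefix (e.g. a step counter) are left untouched.
    """
    model_state = checkpoint["state_dict"]
    ema_state = checkpoint["ema_state_dict"]
    prefix = "module."
    for key in ema_state:
        if key.startswith(prefix):
            src_key = key[len(prefix):]
            # Swap values only; no keys are added or removed while iterating.
            ema_state[key], model_state[src_key] = model_state[src_key], ema_state[key]


checkpoint = {
    "state_dict": {"weight": 1.0},                          # original (source) value
    "ema_state_dict": {"module.weight": 0.9, "steps": 10},  # averaged value
}
swap_ema_state_dict(checkpoint)
print(checkpoint["state_dict"])      # {'weight': 0.9}
print(checkpoint["ema_state_dict"])  # {'module.weight': 1.0, 'steps': 10}
```

Under these assumptions, loading state_dict alone yields the averaged weights, which is convenient for deployment, while the raw training weights remain recoverable from the ema field.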
- Parameters:
ema_type (str) – The type of EMA strategy to use. You can find the supported strategies in mmengine.model.averaged_model. Defaults to 'ExponentialMovingAverage'.
strict_load (bool) – Whether to strictly enforce that the keys of state_dict in the checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.
begin_iter (int) – The iteration from which EMAHook is enabled. Defaults to 0.
begin_epoch (int) – The epoch from which EMAHook is enabled. Defaults to 0.
**kwargs – Keyword arguments passed to subclasses of BaseAveragedModel.
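The parameters above can be illustrated with a hypothetical mmengine-style config fragment plus two small checks. This is a sketch under stated assumptions: momentum is assumed to be a keyword argument of the ExponentialMovingAverage strategy, forwarded through **kwargs; check_begin and ema_started are hypothetical helpers, and the start condition assumes 0-based iteration and epoch counters compared as 1-based numbers.

```python
# Hypothetical config fragment registering the hook. momentum is not an
# EMAHook argument itself; it is forwarded via **kwargs to the averaged-model
# class selected by ema_type (assumed to accept it).
custom_hooks = [
    dict(
        type="EMAHook",
        ema_type="ExponentialMovingAverage",
        momentum=0.0002,
        begin_epoch=5,  # begin_iter stays at its default of 0
    )
]


def check_begin(begin_iter: int = 0, begin_epoch: int = 0) -> None:
    """Hypothetical check mirroring the note: only one start point may be set."""
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError("begin_iter and begin_epoch cannot be set at the same time")


def ema_started(cur_iter: int, cur_epoch: int,
                begin_iter: int = 0, begin_epoch: int = 0) -> bool:
    """Hypothetical gate: assumes 0-based counters compared as 1-based numbers."""
    return cur_iter + 1 >= begin_iter and cur_epoch + 1 >= begin_epoch


hook_cfg = custom_hooks[0]
check_begin(hook_cfg.get("begin_iter", 0), hook_cfg.get("begin_epoch", 0))
print(ema_started(cur_iter=100, cur_epoch=3, begin_epoch=5))  # False
print(ema_started(cur_iter=500, cur_epoch=4, begin_epoch=5))  # True
```

Leaving one of begin_iter and begin_epoch at 0 makes that side of the condition always true, so only the other one decides when averaging starts.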
- after_test_epoch(runner, metrics=None)
Restore the source model's original parameters from the EMA model after testing.
- after_val_epoch(runner, metrics=None)
Restore the source model's original parameters from the EMA model after validation.
- before_run(runner)
Create an EMA copy of the model.
- Parameters:
runner (Runner) – The runner of the training process.
- Return type:
None
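A minimal sketch of what "an EMA copy" needs to be, assuming the copy is a deep copy (typical for averaged-model wrappers) so that later EMA updates never alias the source parameters. TinyModel is a hypothetical stand-in for a real model; only the standard-library copy.deepcopy is used.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a model holding a state_dict-like dict."""

    def __init__(self) -> None:
        self.params = {"weight": 1.0, "bias": 0.0}


src_model = TinyModel()
# An independent deep copy: the EMA model gets its own parameter storage.
ema_model = copy.deepcopy(src_model)

src_model.params["weight"] = 2.0  # a training step changes the source model
print(ema_model.params["weight"])  # 1.0, the copy is unaffected
```

A shallow copy would share the inner dict (or tensors), so every optimizer step would silently overwrite the "averaged" values.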
- before_test_epoch(runner)
Load the parameter values of the EMA model into the source model before testing.
- Parameters:
runner (Runner) – The runner of the training process.
- Return type:
None
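The before_test_epoch / after_test_epoch pair (and likewise the validation pair) can be read as one swap applied twice. This sketch assumes the hook exchanges values between the source and EMA parameters rather than copying in one direction, so the second swap restores the source model exactly. Plain dicts of floats stand in for model parameters, and swap_parameters is a hypothetical helper.

```python
from typing import Dict


def swap_parameters(src: Dict[str, float], ema: Dict[str, float]) -> None:
    """Hypothetical helper: exchange source and EMA values in place."""
    for name in src:
        src[name], ema[name] = ema[name], src[name]


src_params = {"weight": 1.0}  # raw weights being trained
ema_params = {"weight": 0.9}  # averaged weights

swap_parameters(src_params, ema_params)  # before testing: evaluate with EMA weights
print(src_params["weight"])  # 0.9
swap_parameters(src_params, ema_params)  # after testing: swapping again restores them
print(src_params["weight"])  # 1.0
```

Swapping instead of copying means no extra buffer has to be kept around to remember the raw weights between the two calls.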