EMAHook
- class mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)[source]
A hook that applies Exponential Moving Average (EMA) to the model during training.
Note
EMAHook takes priority over CheckpointHook.
After training, the original (non-averaged) model parameters are actually saved in the ema field of the checkpoint.
begin_iter and begin_epoch cannot be set at the same time.
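The start condition and the mutual-exclusion rule for begin_iter and begin_epoch can be sketched in plain Python as follows. This is a hypothetical illustration, not mmengine's code: check_begin_args and ema_started are invented helper names, and the sketch assumes that 0 means "not set" and that the runner's iteration and epoch counters are 0-based.

```python
def check_begin_args(begin_iter: int = 0, begin_epoch: int = 0) -> None:
    """Hypothetical validation: reject configs that set both start points."""
    if begin_iter < 0 or begin_epoch < 0:
        raise ValueError("begin_iter and begin_epoch must be non-negative")
    # Assumption: 0 means "not set", so only one of the two may be non-zero.
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError("begin_iter and begin_epoch cannot be set at the same time")


def ema_started(cur_iter: int, cur_epoch: int,
                begin_iter: int = 0, begin_epoch: int = 0) -> bool:
    """Hypothetical start check, assuming 0-based iteration/epoch counters.

    The (n+1)-th iteration or epoch is counted as n + 1 before comparing.
    """
    return cur_iter + 1 >= begin_iter and cur_epoch + 1 >= begin_epoch


check_begin_args(begin_epoch=5)  # valid: only one start point is set
print(ema_started(cur_iter=10, cur_epoch=3, begin_epoch=5))  # False
print(ema_started(cur_iter=10, cur_epoch=4, begin_epoch=5))  # True
```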
- Parameters:
ema_type (str) – The type of EMA strategy to use. The supported strategies can be found in mmengine.model.averaged_model. Defaults to 'ExponentialMovingAverage'.
strict_load (bool) – Whether to strictly enforce that the keys of state_dict in the checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.
begin_iter (int) – The iteration from which EMAHook is enabled. Defaults to 0.
begin_epoch (int) – The epoch from which EMAHook is enabled. Defaults to 0.
**kwargs – Keyword arguments passed to subclasses of BaseAveragedModel.
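A minimal sketch of an mmengine-style config fragment that registers the hook. The concrete values (begin_epoch=5, momentum=0.0002) are illustrative assumptions; momentum is not an EMAHook argument itself but is forwarded through **kwargs to the averaged-model class named by ema_type.

```python
custom_hooks = [
    dict(
        type='EMAHook',
        # Class name looked up in mmengine.model.averaged_model.
        ema_type='ExponentialMovingAverage',
        strict_load=False,
        # Start EMA from epoch 5; begin_iter is left at its default (0)
        # because the two cannot be set at the same time.
        begin_epoch=5,
        # Not an EMAHook argument: forwarded via **kwargs to the
        # ExponentialMovingAverage constructor (illustrative value).
        momentum=0.0002,
    )
]
```

The list would typically be passed to Runner through its custom_hooks argument or placed in a config file.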
- after_test_epoch(runner, metrics=None)[source]
Recover the source model's parameters from the EMA model after testing.
- after_val_epoch(runner, metrics=None)[source]
Recover the source model's parameters from the EMA model after validation.
- before_run(runner)[source]
Create an EMA copy of the model.
- Parameters:
runner (Runner) – The runner of the training process.
- Return type:
None
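What before_run sets up can be illustrated with a framework-free sketch: the EMA model starts as an independent copy of the source parameters and is then updated with the convention documented for mmengine's ExponentialMovingAverage, ema = (1 - momentum) * ema + momentum * param. Models are represented here as plain dicts of floats, and create_ema_copy and ema_update are hypothetical helper names; the real hook operates on torch modules.

```python
import copy
from typing import Dict

# Hypothetical stand-in for a model: parameter name -> value.
Params = Dict[str, float]


def create_ema_copy(params: Params) -> Params:
    """Mimic before_run: the EMA model starts as an independent deep copy."""
    return copy.deepcopy(params)


def ema_update(ema: Params, params: Params, momentum: float = 0.0002) -> None:
    """In-place update: ema = (1 - momentum) * ema + momentum * param."""
    for name, value in params.items():
        ema[name] = (1.0 - momentum) * ema[name] + momentum * value


model = {"w": 1.0}
ema = create_ema_copy(model)
model["w"] = 3.0                      # training changes the source model
ema_update(ema, model, momentum=0.5)  # large momentum only to make the math visible
print(ema["w"])  # 2.0
```

Under this convention a small momentum (such as 0.0002) makes the average move slowly, which is the opposite of how the momentum argument of optimizers is usually read.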
- before_test_epoch(runner)[source]
Load parameter values from the EMA model into the source model before testing.
- Parameters:
runner (Runner) – The runner of the training process.
- Return type:
None
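The load/recover pair described by before_test_epoch and after_test_epoch (and after_val_epoch) can be pictured as a parameter swap. This sketch assumes the hook exchanges values rather than copying them one way, so swapping once puts the EMA weights into the source model for evaluation and swapping again restores the original weights. Parameters are plain dicts, and swap_params is a hypothetical name.

```python
from typing import Dict

Params = Dict[str, float]


def swap_params(source: Params, ema: Params) -> None:
    """Exchange values in place so each dict holds the other's parameters."""
    for name in source:
        source[name], ema[name] = ema[name], source[name]


source = {"w": 3.0}
ema = {"w": 2.0}
swap_params(source, ema)  # before testing: evaluate with EMA weights
print(source["w"])  # 2.0
swap_params(source, ema)  # after testing/validation: recover original weights
print(source["w"])  # 3.0
```

A swap keeps the round trip exact without allocating a third copy of the weights, which is one reason to prefer it over copying back and forth.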