
EMAHook

class mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)[source]

A hook that applies an Exponential Moving Average (EMA) to the model parameters during training.

Note

  • EMAHook takes priority over CheckpointHook.

  • In saved checkpoints, the original (non-EMA) model parameters are actually stored in the ema_state_dict field, while state_dict holds the EMA parameters.

  • begin_iter and begin_epoch cannot be set at the same time.
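The rule that begin_iter and begin_epoch cannot both be set can be illustrated with a small validation helper. This is a plain-Python sketch, not MMEngine code: resolve_begin is a hypothetical name, and treating 0 as "not set" is an assumption based on both parameters defaulting to 0.

```python
from typing import Tuple


def resolve_begin(begin_iter: int = 0, begin_epoch: int = 0) -> Tuple[str, int]:
    """Hypothetical helper (not part of MMEngine): pick the EMA start mode.

    Enforces that begin_iter and begin_epoch are not both set,
    treating 0 as "not set" (an assumption of this sketch).
    """
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError(
            'begin_iter and begin_epoch cannot be set at the same time')
    # A non-zero begin_epoch means the start point is counted in epochs;
    # otherwise it is counted in iterations (possibly starting at 0).
    if begin_epoch != 0:
        return 'epoch', begin_epoch
    return 'iter', begin_iter
```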

Parameters:
  • ema_type (str) – The type of EMA strategy to use. You can find the supported strategies in mmengine.model.averaged_model. Defaults to ‘ExponentialMovingAverage’.

  • strict_load (bool) – Whether to strictly enforce that the keys of state_dict in checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.

  • begin_iter (int) – The iteration from which EMAHook is enabled. Defaults to 0.

  • begin_epoch (int) – The epoch from which EMAHook is enabled. Defaults to 0.

  • **kwargs – Keyword arguments passed to the constructor of the ema_type class, which is a subclass of BaseAveragedModel.
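As a sketch, assuming the usual MMEngine convention of registering extra hooks through a custom_hooks list in the config, EMAHook might be enabled as follows. The momentum key is assumed to be a keyword argument of ExponentialMovingAverage that reaches it through **kwargs; the values shown are illustrative only.

```python
# MMEngine-style config fragment (assumes hooks are registered via custom_hooks).
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',  # averaging strategy (the default)
        momentum=0.0002,  # assumed kwarg, forwarded to the strategy via **kwargs
        begin_epoch=10,   # delay EMA until epoch 10; begin_iter stays at 0
        strict_load=False,
    )
]
```

Setting begin_iter in the same dict would break the rule that only one of the two start points may be set.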

after_load_checkpoint(runner, checkpoint)[source]

Resume the EMA parameters from the checkpoint.

Parameters:
  • runner (Runner) – The runner that loads the checkpoint.

  • checkpoint (dict) – The checkpoint being loaded.

Return type:

None

after_test_epoch(runner, metrics=None)[source]

Restore the source model's original parameters after testing by swapping back the EMA parameters loaded in before_test_epoch.

Parameters:
  • runner (Runner) – The runner of the testing process.

  • metrics (Dict[str, float], optional) – Evaluation results of all metrics on test dataset. The keys are the names of the metrics, and the values are corresponding results.

Return type:

None

after_train_iter(runner, batch_idx, data_batch=None, outputs=None)[source]

Update the EMA parameters.

Parameters:
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.

  • outputs (dict, optional) – Outputs from model. Defaults to None.

Return type:

None
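The update performed in after_train_iter can be sketched in plain Python. The sketch assumes the ExponentialMovingAverage convention ema = (1 - momentum) * ema + momentum * src, and assumes that before the configured start iteration the EMA copy simply tracks the source weights. Parameters are plain floats in a dict instead of tensors, and both function names are hypothetical.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], src: Dict[str, float],
               momentum: float) -> None:
    """In-place EMA step: ema = (1 - momentum) * ema + momentum * src."""
    for name, value in src.items():
        ema[name] = (1.0 - momentum) * ema[name] + momentum * value


def after_train_iter_sketch(ema: Dict[str, float], src: Dict[str, float],
                            cur_iter: int, begin_iter: int,
                            momentum: float) -> None:
    """Hypothetical hook body for an iteration-based start point.

    cur_iter is assumed to be the 0-indexed iteration that just finished,
    so cur_iter + 1 is the number of completed iterations.
    """
    if cur_iter + 1 >= begin_iter:
        ema_update(ema, src, momentum)
    else:
        # EMA not started yet: keep the copy identical to the source weights.
        ema.update(src)
```

With a small momentum such as 0.0002, each step moves the average only slightly toward the current weights, which is what smooths out the noise of individual optimizer updates.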

after_val_epoch(runner, metrics=None)[source]

Restore the source model's original parameters after validation by swapping back the EMA parameters loaded in before_val_epoch.

Parameters:
  • runner (Runner) – The runner of the validation process.

  • metrics (Dict[str, float], optional) – Evaluation results of all metrics on validation dataset. The keys are the names of the metrics, and the values are corresponding results.

Return type:

None

before_run(runner)[source]

Create an EMA copy of the model.

Parameters:

runner (Runner) – The runner of the training process.

Return type:

None
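The essential property of the EMA copy created in before_run is that it must not share storage with the source model; otherwise every optimizer step would also overwrite the average. A minimal sketch with copy.deepcopy, using lists of floats as stand-ins for parameter tensors (make_ema_copy is a hypothetical name, not an MMEngine function):

```python
import copy
from typing import Dict, List


def make_ema_copy(params: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Return an independent (deep) copy of the parameters for the EMA model."""
    return copy.deepcopy(params)
```

A shallow copy such as dict(params) would share the inner lists, so in-place updates to the source would silently change the EMA copy as well.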

before_save_checkpoint(runner, checkpoint)[source]

Save the EMA parameters to the checkpoint.

Parameters:
  • runner (Runner) – The runner of the training process.

  • checkpoint (dict) – The checkpoint to be saved.

Return type:

None
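The note above that the original parameters end up in the ema field can be pictured as follows: before_save_checkpoint stores the EMA state under a separate key and then swaps values so that state_dict holds the EMA weights (convenient for deployment), and after_load_checkpoint swaps them back when resuming. This is a sketch under assumptions: the key names state_dict and ema_state_dict, the 'module.' prefix on EMA keys, and the helper names swap_ema_state_dict and build_checkpoint are all assumed or hypothetical.

```python
from typing import Any, Dict

EMA_PREFIX = 'module.'  # assumed: the EMA wrapper stores the model as .module


def swap_ema_state_dict(checkpoint: Dict[str, Any]) -> None:
    """Swap matching entries of checkpoint['state_dict'] and checkpoint['ema_state_dict']."""
    model_state = checkpoint['state_dict']
    ema_state = checkpoint['ema_state_dict']
    for key in ema_state:
        if key.startswith(EMA_PREFIX):
            src_key = key[len(EMA_PREFIX):]
            ema_state[key], model_state[src_key] = model_state[src_key], ema_state[key]
        # Keys without the prefix (e.g. an update counter) are left untouched.


def build_checkpoint(model_state: Dict[str, Any],
                     ema_state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical save path: store both states, then swap so state_dict holds EMA weights."""
    checkpoint = {'state_dict': dict(model_state),
                  'ema_state_dict': dict(ema_state)}
    swap_ema_state_dict(checkpoint)
    return checkpoint
```

Because the operation is a swap, applying it a second time on load restores the original layout, so resuming can recover both the source and the EMA weights.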

before_test_epoch(runner)[source]

Load the EMA parameter values into the source model before testing, so that the EMA weights are evaluated.

Parameters:

runner (Runner) – The runner of the testing process.

Return type:

None
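The load-before, restore-after pattern shared by before_test_epoch, after_test_epoch, before_val_epoch and after_val_epoch can be sketched as an in-place swap. Swapping rather than copying means one operation both installs the EMA weights and restores the original ones. Plain floats stand in for tensors, and swap_parameters is a hypothetical name.

```python
from typing import Dict


def swap_parameters(src: Dict[str, float], ema: Dict[str, float]) -> None:
    """Exchange parameter values between the source model and the EMA copy in place."""
    for name in src:
        src[name], ema[name] = ema[name], src[name]


src_params = {'w': 1.0}
ema_params = {'w': 0.5}
swap_parameters(src_params, ema_params)  # before evaluation: model now uses EMA weights
# ... run validation or testing here ...
swap_parameters(src_params, ema_params)  # after evaluation: original weights restored
```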

before_train(runner)[source]

Check that begin_epoch (or begin_iter) does not exceed max_epochs (or max_iters).

Parameters:

runner (Runner) – The runner of the training process.

Return type:

None
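The check in before_train amounts to requiring that the EMA start point lies within the planned training length. A sketch, with max_iters and max_epochs passed in explicitly rather than read from the runner; check_begin_point is a hypothetical name.

```python
def check_begin_point(begin_iter: int, begin_epoch: int,
                      max_iters: int, max_epochs: int) -> None:
    """Raise ValueError if the EMA start point lies beyond the end of training."""
    if begin_epoch != 0:
        # Epoch-based start: compare against the total number of epochs.
        if begin_epoch > max_epochs:
            raise ValueError(
                f'begin_epoch ({begin_epoch}) must not exceed '
                f'max_epochs ({max_epochs})')
    elif begin_iter > max_iters:
        # Iteration-based start: compare against the total number of iterations.
        raise ValueError(
            f'begin_iter ({begin_iter}) must not exceed max_iters ({max_iters})')
```

Without such a check, a misconfigured start point would mean EMA never takes effect, which is easy to overlook.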

before_val_epoch(runner)[source]

Load the EMA parameter values into the source model before validation, so that the EMA weights are evaluated.

Parameters:

runner (Runner) – The runner of the validation process.

Return type:

None