
BaseAveragedModel

class mmengine.model.BaseAveragedModel(model, interval=1, device=None, update_buffers=False)

A base class for averaging model weights.

Weight averaging, such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), is a widely used technique for training neural networks. This class implements the averaging process for a model; all subclasses must implement the avg_func method. It creates a copy of model on device and uses that copy to keep running averages of the model's parameters.

The code is adapted from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

Unlike AveragedModel in PyTorch, this class updates the parameters with in-place operations, which makes each update about 5 times faster than the non-in-place version.
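The speed difference comes from avoiding temporary tensors. The minimal sketch below, in plain PyTorch, contrasts an out-of-place update, which allocates new tensors, with an in-place update, which writes into the existing storage. The exponential-moving-average formula and the momentum value of 0.1 are illustrative assumptions, not necessarily the exact update used by any mmengine subclass; both variants give the same values.

```python
import torch

torch.manual_seed(0)
momentum = 0.1  # illustrative weight on the new parameter (assumption)
source = torch.randn(1000)
averaged = torch.randn(1000)

# Out-of-place: builds temporaries and binds the result to a brand-new tensor.
out_of_place = (1 - momentum) * averaged + momentum * source

# In-place: reuses the storage of `in_place`; no new result tensor is allocated.
in_place = averaged.clone()
storage_before = in_place.data_ptr()
in_place.mul_(1 - momentum).add_(source, alpha=momentum)
```

Because the averaged copy keeps its original storage, the optimizer-independent averaging step does not churn memory on every update.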

mmengine provides two ways to use model averaging:

  1. Use model averaging through a hook: mmengine provides mmengine.hooks.EMAHook to apply model averaging during training. Add custom_hooks=[dict(type='EMAHook')] to the config or to the runner.

  2. Use the model averaging module directly in an algorithm. Take the EMA teacher in semi-supervised learning as an example:

    >>> from mmengine.model import ExponentialMovingAverage
    >>> student = ResNet(depth=50)
    >>> # use ema model as teacher
    >>> ema_teacher = ExponentialMovingAverage(student)
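The EMA-teacher pattern can be sketched with plain PyTorch when the ResNet backbone is not at hand. In this minimal sketch the teacher is a frozen deep copy of the student, and after every optimizer step its weights move toward the student's. The Linear model, the momentum of 0.5, and the ema_update helper are hypothetical stand-ins; with mmengine, the equivalent step would be calling update_parameters(student) on the ExponentialMovingAverage instance.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
student = nn.Linear(4, 2)
# Teacher: an independent copy of the student that is never trained directly.
teacher = copy.deepcopy(student).requires_grad_(False)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
momentum = 0.5  # illustrative weight on the student's new weights (assumption)


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, momentum: float) -> None:
    """Hypothetical helper: in-place EMA of the student's parameters."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(1 - momentum).add_(p_s, alpha=momentum)


for _ in range(3):
    x = torch.randn(8, 4)
    loss = student(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_update(teacher, student, momentum)  # the teacher trails the student

# The smoothed teacher then produces targets (here, pseudo-labels) for the student.
with torch.no_grad():
    pseudo_labels = teacher(torch.randn(8, 4)).argmax(dim=1)
```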
    
Parameters
  • model (nn.Module) – The model to be averaged.

  • interval (int) – Interval between two updates. Defaults to 1.

  • device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.

  • update_buffers (bool) – If True, running averages are computed for both the parameters and the buffers of the model. Defaults to False.

Return type

None

abstract avg_func(averaged_param, source_param, steps)

Compute the average of the parameters using in-place operations. All subclasses must implement this method.

Parameters
  • averaged_param (Tensor) – The averaged parameters.

  • source_param (Tensor) – The source parameters.

  • steps (int) – The number of times the parameters have been updated.

Return type

None
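A subclass only has to define how one averaged tensor absorbs one source tensor. The two standalone functions below are a minimal sketch of what an avg_func body could look like: an exponential moving average with a fixed momentum, and an equal-weight running mean in the style of SWA. They assume that the first snapshot has already been copied into averaged_param and that steps counts the snapshots absorbed before the current one; these step semantics and the momentum convention are assumptions, not taken from this page. Both modify averaged_param in place and return None, matching the documented return type.

```python
import torch


def ema_avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor,
                 steps: int, momentum: float = 0.1) -> None:
    """EMA: averaged <- (1 - momentum) * averaged + momentum * source."""
    # `steps` is unused: this EMA weight does not depend on the step count.
    averaged_param.lerp_(source_param, momentum)


def swa_avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor,
                 steps: int) -> None:
    """Running mean: afterwards, averaged holds the mean of steps + 1 snapshots."""
    averaged_param.add_(source_param - averaged_param, alpha=1.0 / (steps + 1))


snapshots = [torch.full((3,), v) for v in (1.0, 2.0, 3.0, 6.0)]
running_mean = snapshots[0].clone()  # first snapshot is copied, not averaged
for steps, snap in enumerate(snapshots[1:], start=1):
    swa_avg_func(running_mean, snap, steps)
# running_mean now holds the mean of 1, 2, 3 and 6, i.e. 3.0 in every entry
```

The running-mean weight 1 / (steps + 1) shrinks over time so that every snapshot ends up with equal weight, whereas the EMA weight stays constant and favors recent snapshots.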

forward(*args, **kwargs)

Forward method of the averaged model.

update_parameters(model)

Update the parameters of the averaged model. This method calls avg_func to compute the new averaged parameters from model and updates the averaged model's parameters accordingly.

Parameters

model (nn.Module) – The model whose parameters will be averaged.

Return type

None
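To show how avg_func, forward, and update_parameters fit together, here is a simplified, hypothetical re-implementation of the documented interface in plain PyTorch. It is a sketch under stated assumptions, not mmengine's source: the class names SketchAveragedModel and SketchEMA are invented, the first call copies the source weights, later calls average only when steps is a multiple of interval, non-floating-point tensors are skipped, and when update_buffers is False the buffers are copied from the source model instead of being averaged.

```python
import copy
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn


class SketchAveragedModel(nn.Module, ABC):
    """Hypothetical, simplified stand-in for BaseAveragedModel."""

    def __init__(self, model: nn.Module, interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__()
        # Independent copy, so averaging never touches the training model.
        self.module = copy.deepcopy(model).requires_grad_(False)
        if device is not None:
            self.module = self.module.to(device)
        self.interval = interval
        self.update_buffers = update_buffers
        self.steps = 0

    @abstractmethod
    def avg_func(self, averaged_param: torch.Tensor,
                 source_param: torch.Tensor, steps: int) -> None:
        """Update ``averaged_param`` in place."""

    def forward(self, *args, **kwargs):
        # Inference runs through the averaged copy.
        return self.module(*args, **kwargs)

    @torch.no_grad()
    def update_parameters(self, model: nn.Module) -> None:
        # Average parameters only, or parameters and buffers together.
        if self.update_buffers:
            avg, src = self.module.state_dict(), model.state_dict()
        else:
            avg = dict(self.module.named_parameters())
            src = dict(model.named_parameters())
        for name, p_avg in avg.items():
            p_src = src[name].to(p_avg.device)
            if self.steps == 0:
                p_avg.copy_(p_src)  # first call: start from the source weights
            elif self.steps % self.interval == 0 and p_avg.is_floating_point():
                self.avg_func(p_avg, p_src, self.steps)
        if not self.update_buffers:
            # Buffers (e.g. BatchNorm statistics) simply track the source model.
            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
                b_avg.copy_(b_src.to(b_avg.device))
        self.steps += 1


class SketchEMA(SketchAveragedModel):
    """Hypothetical subclass: exponential moving average with fixed momentum."""

    def __init__(self, model: nn.Module, momentum: float = 0.5, **kwargs) -> None:
        super().__init__(model, **kwargs)
        self.momentum = momentum

    def avg_func(self, averaged_param, source_param, steps) -> None:
        averaged_param.lerp_(source_param, self.momentum)


net = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
ema = SketchEMA(net, momentum=0.5, interval=1)
with torch.no_grad():
    net[0].weight.fill_(1.0)
ema.update_parameters(net)  # first call (steps == 0): weights copied, all 1.0
with torch.no_grad():
    net[0].weight.fill_(3.0)
ema.update_parameters(net)  # steps == 1: halfway toward 3.0, all 2.0
```

Keeping the averaging rule in avg_func and the bookkeeping (copying, interval, buffers) in update_parameters is what lets a subclass stay a few lines long.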

Read the Docs v: v0.4.0