
BaseAveragedModel

class mmengine.model.BaseAveragedModel(model, interval=1, device=None, update_buffers=False)

A base class for averaging model weights.

Weight averaging, such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), is a widely used technique for training neural networks. This class implements the averaging process for a model, and all subclasses must implement the avg_func method. It creates a copy of the provided module model on the given device and computes running averages of that model's parameters.
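For reference, the two averaging schemes named above are commonly written as follows. Here θ_t denotes the source model's parameters at update t, θ̄_t the averaged parameters, and m a momentum coefficient. This notation is assumed for illustration and is not taken from the mmengine source.

```latex
\begin{aligned}
\text{SWA (uniform running mean):}\quad & \bar\theta_0 = \theta_0, \qquad
  \bar\theta_t = \bar\theta_{t-1} + \frac{\theta_t - \bar\theta_{t-1}}{t + 1}
  = \frac{1}{t+1}\sum_{i=0}^{t} \theta_i, \\
\text{EMA:}\quad & \bar\theta_t = (1 - m)\,\bar\theta_{t-1} + m\,\theta_t .
\end{aligned}
```

Some libraries write EMA with the roles of m and 1 - m swapped, so it is worth checking which convention a particular subclass uses.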

The code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

Unlike AveragedModel in PyTorch, this class updates the parameters with in-place operations, which makes parameter updating about 5 times faster than the non-in-place version.
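The difference between in-place and out-of-place updates can be seen in plain PyTorch. An in-place call such as Tensor.lerp_ writes the result into the averaged tensor's existing storage, while the equivalent out-of-place expression allocates a new tensor on every update. The sketch below uses an EMA-style rule with an assumed momentum of 0.1. It only checks storage reuse and numerical agreement; it does not measure the speed-up quoted above.

```python
import torch

momentum = 0.1  # assumed value, for illustration only
source = torch.randn(1000)

# Out-of-place: the right-hand side allocates a new tensor, then the name is rebound.
avg_out = torch.zeros(1000)
ptr_before = avg_out.data_ptr()
avg_out = (1 - momentum) * avg_out + momentum * source
out_of_place_reused_storage = avg_out.data_ptr() == ptr_before

# In-place: lerp_ computes avg + momentum * (source - avg) into avg's own storage.
avg_in = torch.zeros(1000)
ptr_before = avg_in.data_ptr()
avg_in.lerp_(source, momentum)
in_place_reused_storage = avg_in.data_ptr() == ptr_before

print(out_of_place_reused_storage, in_place_reused_storage)
print(torch.allclose(avg_in, avg_out))
```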

mmengine provides two ways to use model averaging:

  1. Use model averaging in a hook: mmengine provides mmengine.hooks.EMAHook to apply model averaging during training. Add custom_hooks=[dict(type='EMAHook')] to the config or the runner.

  2. Use the model averaging module directly in an algorithm. Take the EMA teacher in semi-supervised learning as an example:

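    >>> from mmengine.model import ExponentialMovingAverage
    >>> # ResNet is assumed to come from a downstream codebase such as MMDetection
    >>> from mmdet.models import ResNet
    >>> student = ResNet(depth=50)
    >>> # use the EMA model as the teacher
    >>> ema_teacher = ExponentialMovingAverage(student)
    >>> # after each training step, refresh the teacher from the student
    >>> ema_teacher.update_parameters(student)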
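Without mmengine, the same teacher-student pattern can be sketched in plain PyTorch. The teacher is a frozen deep copy of the student, and after each optimizer step its parameters are moved a small step toward the student's. The tiny nn.Linear model, the random data and the momentum of 0.01 are assumptions for illustration. This mirrors the idea only and is not how ExponentialMovingAverage is implemented.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

student = nn.Linear(4, 2)
initial_weight = student.weight.detach().clone()

# The teacher starts as an exact copy of the student and is never trained directly.
teacher = copy.deepcopy(student)
teacher.requires_grad_(False)

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
momentum = 0.01  # assumed EMA momentum: weight given to the student on each update

for _ in range(5):
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 2)
    loss = nn.functional.mse_loss(student(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # EMA update of the teacher, done in place and outside autograd.
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.lerp_(s_param, momentum)
```

Because the teacher only ever receives small weighted updates, it changes more slowly than the student, which is what makes it useful as a stable source of targets.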
    
Parameters
  • model (nn.Module) – The model to be averaged.

  • interval (int) – Interval between two updates. Defaults to 1.

  • device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.

  • update_buffers (bool) – If True, running averages are computed for both the parameters and the buffers of the model. Defaults to False.

Return type

None
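Whether buffers are averaged matters mainly for layers that keep running statistics, such as batch normalization. The plain PyTorch snippet below separates the tensors a model exposes as parameters from the extra entries in its state_dict(). The small nn.Sequential model is an arbitrary example. Reading update_buffers=True as "also average the non-parameter state" is an interpretation of the parameter description above, not a statement about mmengine's internals.

```python
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

param_names = {name for name, _ in model.named_parameters()}
state_names = set(model.state_dict().keys())

# Entries present in the state dict but not among the parameters are buffers.
buffer_names = state_names - param_names

print(sorted(param_names))
print(sorted(buffer_names))

# num_batches_tracked is an integer counter, so an averaging routine would
# typically skip it and only average floating-point tensors.
print(model.state_dict()["1.num_batches_tracked"].dtype)
```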

abstract avg_func(averaged_param, source_param, steps)

Compute the average of the parameters using in-place operations. All subclasses must implement this method.

Parameters
  • averaged_param (Tensor) – The averaged parameters.

  • source_param (Tensor) – The source parameters.

  • steps (int) – The number of times the parameters have been updated.

Return type

None
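Two in-place averaging rules that fit the avg_func(averaged_param, source_param, steps) signature are sketched below as free functions. The first is an EMA step with an assumed fixed momentum. The second is a uniform running mean in which steps is taken to be the number of values already averaged. The function names and the momentum default are hypothetical, and mmengine's own subclasses may define steps and their coefficients differently.

```python
import torch


def ema_avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor,
                 steps: int, momentum: float = 0.1) -> None:
    """Hypothetical EMA rule: avg <- (1 - momentum) * avg + momentum * source."""
    # steps is unused because a plain EMA does not depend on the step count.
    averaged_param.lerp_(source_param, momentum)


def uniform_avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor,
                     steps: int) -> None:
    """Hypothetical running mean: assumes averaged_param already holds the mean
    of `steps` earlier values, so this call folds in value number steps + 1."""
    averaged_param.add_(source_param - averaged_param, alpha=1.0 / (steps + 1))


# Running mean of 1, 2, 3, 4: the first value is copied, later ones are averaged in.
values = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0]), torch.tensor([4.0])]
avg = values[0].clone()
for steps, value in enumerate(values[1:], start=1):
    uniform_avg_func(avg, value, steps)
print(avg)
```

Both functions modify averaged_param and return None, matching the documented return type.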

forward(*args, **kwargs)

Forward method of the averaged model.

update_parameters(model)

Update the parameters of the averaged model. This method calls avg_func to compute the new averaged parameters from model and writes the results into the averaged model's parameters.

Parameters

model (nn.Module) – The model whose parameters will be averaged.

Return type

None
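The bookkeeping around avg_func can be sketched in plain PyTorch as follows. This is a simplified stand-in, not mmengine's code. The class name SketchAveragedModel and its momentum argument are hypothetical. The rules it encodes are assumptions chosen to match the parameter descriptions above: copy the source parameters on the first call, call avg_func only when the step counter is a multiple of interval, and, when update_buffers is False, copy buffers straight from the source model.

```python
import copy

import torch
from torch import nn


class SketchAveragedModel(nn.Module):
    """Hypothetical, simplified averaged-model wrapper (not mmengine's class)."""

    def __init__(self, model: nn.Module, interval: int = 1, device=None,
                 update_buffers: bool = False, momentum: float = 0.5) -> None:
        super().__init__()
        # Frozen copy of the model, optionally moved to another device.
        self.module = copy.deepcopy(model).requires_grad_(False)
        if device is not None:
            self.module = self.module.to(device)
        self.interval = interval
        self.update_buffers = update_buffers
        self.momentum = momentum  # assumed EMA coefficient used by avg_func below
        self.steps = 0

    def avg_func(self, averaged_param, source_param, steps):
        averaged_param.lerp_(source_param, self.momentum)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    @torch.no_grad()
    def update_parameters(self, model: nn.Module) -> None:
        if self.update_buffers:
            avg_tensors = self.module.state_dict()  # shares storage with the module
            src_tensors = model.state_dict()
        else:
            avg_tensors = dict(self.module.named_parameters())
            src_tensors = dict(model.named_parameters())
        for name, avg in avg_tensors.items():
            src = src_tensors[name].to(avg.device)
            if self.steps == 0:
                avg.copy_(src)  # first call: start from the source weights
            elif self.steps % self.interval == 0 and avg.is_floating_point():
                self.avg_func(avg, src, self.steps)
        if not self.update_buffers:
            # Buffers are not averaged, so keep them identical to the source model.
            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
                b_avg.copy_(b_src.to(b_avg.device))
        self.steps += 1


# Usage: with interval=2, only every second call after the first one averages.
source = nn.Linear(2, 1)
averaged = SketchAveragedModel(source, interval=2, momentum=0.5)
with torch.no_grad():
    source.weight.fill_(0.0)
averaged.update_parameters(source)  # steps 0: copy, weight becomes 0.0
with torch.no_grad():
    source.weight.fill_(4.0)
averaged.update_parameters(source)  # steps 1: 1 % 2 != 0, weight stays 0.0
averaged.update_parameters(source)  # steps 2: EMA, 0.5 * 0.0 + 0.5 * 4.0 = 2.0
print(averaged.module.weight)
```

Keeping the update logic in update_parameters and only the arithmetic in avg_func is what lets a subclass swap EMA for another averaging rule by overriding a single method.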

Documentation version: v0.8.3