
BaseAveragedModel

class mmengine.model.BaseAveragedModel(model, interval=1, device=None, update_buffers=False)

A base class for averaging model weights.

Weight averaging, such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), is a widely used technique for training neural networks. This class implements the averaging process for a model, and every subclass must implement the avg_func method. It creates a copy of the given model on the given device and uses that copy to keep running averages of the model's parameters.
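As a framework-free illustration of the copy-then-average idea, the sketch below represents a model as a plain dict of float lists (a stand-in for tensors, not mmengine's data structure). It deep-copies the source weights once, then folds later snapshots into an equal-weight running mean using the SWA-style rule avg = avg + (src - avg) / (n + 1), where n is the number of snapshots already averaged. The helpers make_averaged_copy and swa_update are hypothetical names used only for this sketch.

```python
import copy


def make_averaged_copy(weights):
    # Hypothetical helper: an independent deep copy, so later in-place
    # changes to the live weights do not leak into the averaged copy.
    return copy.deepcopy(weights)


def swa_update(averaged, source, num_averaged):
    # Hypothetical helper: fold one more snapshot into an equal-weight running
    # mean, in place. `num_averaged` is how many snapshots `averaged` holds.
    for name, avg_values in averaged.items():
        for i, src in enumerate(source[name]):
            avg_values[i] += (src - avg_values[i]) / (num_averaged + 1)


# Toy "model" with a single weight vector (a stand-in for real tensors).
model = {"weight": [0.0, 4.0]}
averaged = make_averaged_copy(model)         # snapshot 1

model["weight"][:] = [2.0, 8.0]              # simulated training step, in place
swa_update(averaged, model, num_averaged=1)  # snapshot 2
model["weight"][:] = [4.0, 0.0]              # another simulated step
swa_update(averaged, model, num_averaged=2)  # snapshot 3

print(averaged["weight"])  # -> [2.0, 4.0], the mean of the three snapshots
```

Because the live weights are mutated in place, a shallow copy would have tracked them; the deep copy is what keeps the averaged weights separate.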

The code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

Unlike AveragedModel in PyTorch, this class updates the averaged parameters with in-place operations, which makes each update about 5 times faster than the non-in-place version.
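To make the in-place distinction concrete, the sketch below applies the same exponential-moving-average step two ways on plain Python lists standing in for tensors. It assumes the update rule avg = (1 - momentum) * avg + momentum * src, and the function names are hypothetical. In PyTorch, the in-place form corresponds to trailing-underscore methods such as Tensor.mul_ and Tensor.add_, which write into existing storage instead of allocating a new tensor. The sketch only illustrates the semantics; it does not reproduce the speed-up quoted above.

```python
def ema_out_of_place(avg, src, momentum):
    # Allocates and returns a new list; the caller has to rebind its reference.
    return [(1 - momentum) * a + momentum * s for a, s in zip(avg, src)]


def ema_in_place(avg, src, momentum):
    # Overwrites the existing buffer element by element and returns None,
    # so every reference to `avg` observes the updated values.
    for i, s in enumerate(src):
        avg[i] = (1 - momentum) * avg[i] + momentum * s


averaged = [0.0, 10.0]
alias = averaged  # a second reference to the same buffer

fresh = ema_out_of_place(averaged, [1.0, 0.0], momentum=0.5)
print(fresh, alias)  # -> [0.5, 5.0] [0.0, 10.0]  (alias is unchanged)

ema_in_place(averaged, [1.0, 0.0], momentum=0.5)
print(alias)         # -> [0.5, 5.0]  (alias sees the in-place update)
```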

MMEngine provides two ways to use model averaging:

  1. Use the model averaging module in a hook: mmengine.hooks.EMAHook applies model averaging during training. To enable it, add custom_hooks=[dict(type='EMAHook')] to the config or to the runner.

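  2. Use the model averaging module directly in the algorithm. Take the EMA teacher in semi-supervised learning as an example:

    >>> from mmengine.model import ExponentialMovingAverage
    >>> # ResNet is not defined here; any nn.Module backbone works
    >>> student = ResNet(depth=50)
    >>> # use the EMA of the student as the teacher
    >>> ema_teacher = ExponentialMovingAverage(student)
    >>> # after each optimizer step, refresh the teacher from the student
    >>> ema_teacher.update_parameters(student)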
    
Parameters:
  • model (nn.Module) – The model to be averaged.

  • interval (int) – Interval between two updates. Defaults to 1.

  • device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.

  • update_buffers (bool) – If True, running averages are computed for both the parameters and the buffers of the model. Defaults to False.
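The sketch below shows one way the update_buffers flag could select which entries are averaged. It uses hypothetical helper names and a dict of floats in place of a state dict, assumes the EMA rule avg = (1 - momentum) * avg + momentum * src, and assumes that when update_buffers is False, buffers (such as BatchNorm running statistics) are simply copied from the source model rather than averaged; the parameter description above does not specify that case.

```python
def entries_to_average(param_names, buffer_names, update_buffers):
    # Hypothetical helper: names that receive a running average.
    return list(param_names) + (list(buffer_names) if update_buffers else [])


def update(averaged, source, param_names, buffer_names, update_buffers, momentum):
    # Hypothetical helper operating on dicts of floats, not real state dicts.
    averaged_names = set(entries_to_average(param_names, buffer_names, update_buffers))
    for name in list(param_names) + list(buffer_names):
        if name in averaged_names:
            # Running average (EMA rule assumed for this sketch).
            averaged[name] = (1 - momentum) * averaged[name] + momentum * source[name]
        else:
            # Assumption: non-averaged buffers are kept in sync by copying.
            averaged[name] = source[name]


params, buffers = ["conv.weight"], ["bn.running_mean"]
source = {"conv.weight": 1.0, "bn.running_mean": 1.0}

avg_a = {"conv.weight": 0.0, "bn.running_mean": 0.0}
update(avg_a, source, params, buffers, update_buffers=False, momentum=0.25)
print(avg_a)  # -> {'conv.weight': 0.25, 'bn.running_mean': 1.0}

avg_b = {"conv.weight": 0.0, "bn.running_mean": 0.0}
update(avg_b, source, params, buffers, update_buffers=True, momentum=0.25)
print(avg_b)  # -> {'conv.weight': 0.25, 'bn.running_mean': 0.25}
```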

abstract avg_func(averaged_param, source_param, steps)

Compute the average of the parameters using in-place operations. All subclasses must implement this method.

Parameters:
  • averaged_param (Tensor) – The averaged parameters.

  • source_param (Tensor) – The source parameters.

  • steps (int) – The number of times the parameters have been updated.

Return type:

None
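The contract above (update averaged_param in place, return None) can be sketched without PyTorch. ToyAveragedModel and ToyEMA below are hypothetical stand-ins on float lists, not mmengine classes, and the EMA rule avg += momentum * (src - avg) is an assumption of this sketch.

```python
from abc import ABC, abstractmethod


class ToyAveragedModel(ABC):
    """Hypothetical stand-in for the averaging base class, on float lists."""

    @abstractmethod
    def avg_func(self, averaged_param, source_param, steps):
        """Update `averaged_param` in place and return None."""


class ToyEMA(ToyAveragedModel):
    def __init__(self, momentum):
        self.momentum = momentum

    def avg_func(self, averaged_param, source_param, steps):
        # This EMA ignores `steps`: each element moves a fixed fraction
        # (`momentum`) of the way toward the source value, in place.
        for i, src in enumerate(source_param):
            averaged_param[i] += self.momentum * (src - averaged_param[i])


ema = ToyEMA(momentum=0.5)
averaged = [0.0, 8.0]
result = ema.avg_func(averaged, [4.0, 0.0], steps=1)
print(averaged, result)  # -> [2.0, 4.0] None
```

Marking avg_func with abstractmethod means the base class itself cannot be instantiated, which mirrors the requirement that every subclass provide its own averaging rule.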

forward(*args, **kwargs)

Forward method of the averaged model.

update_parameters(model)

Update the parameters of the averaged model. This method calls avg_func to compute the new averaged parameters from the given model and writes them into the averaged model's parameters.

Parameters:

model (nn.Module) – The model whose parameters will be averaged.

Return type:

None
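The sketch below combines update_parameters with the interval parameter in the EMA-teacher pattern from the second usage mode above. ToyEMAModel is hypothetical and works on a dict of float lists; it assumes that avg_func runs only on every interval-th call and that the step counter advances on every call, which is one plausible reading of "Interval between two updates" rather than a description of mmengine's exact code.

```python
import copy


class ToyEMAModel:
    """Hypothetical sketch of update_parameters() on a dict of float lists."""

    def __init__(self, weights, momentum=0.5, interval=1):
        self.averaged = copy.deepcopy(weights)  # independent copy of the source
        self.momentum = momentum
        self.interval = interval
        self.steps = 0

    def avg_func(self, averaged_param, source_param, steps):
        # In-place EMA step (rule assumed for this sketch); returns None.
        for i, src in enumerate(source_param):
            averaged_param[i] += self.momentum * (src - averaged_param[i])

    def update_parameters(self, weights):
        # Assumption: average only on every `interval`-th call; count every call.
        if self.steps % self.interval == 0:
            for name, avg_param in self.averaged.items():
                self.avg_func(avg_param, weights[name], self.steps)
        self.steps += 1


student = {"w": [0.0]}
teacher = ToyEMAModel(student, momentum=0.5, interval=2)
for value in [4.0, 8.0, 8.0, 8.0]:
    student["w"][0] = value           # stand-in for an optimizer step
    teacher.update_parameters(student)

print(teacher.averaged, teacher.steps)  # -> {'w': [5.0]} 4
```

With interval=2, only the first and third calls change the teacher (0.0 to 2.0, then 2.0 to 5.0), so the teacher lags the student and smooths its trajectory.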
