BaseAveragedModel
- class mmengine.model.BaseAveragedModel(model, interval=1, device=None, update_buffers=False)[source]
A base class for averaging model weights.
Weight averaging, such as SWA (stochastic weight averaging) and EMA (exponential moving average), is a widely used technique for training neural networks. This class implements the averaging process for a model: it creates a copy of the provided model on the given device and maintains running averages of the model's parameters. All subclasses must implement the avg_func method.
The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
Unlike PyTorch's AveragedModel, this implementation updates the averaged parameters in place, which is about 5 times faster than the out-of-place version.
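To illustrate what "in-place" means here, below is a minimal, framework-free sketch of an in-place EMA update on plain Python lists. The function name and momentum value are illustrative only; the real class operates on torch tensors with in-place ops such as mul_ and add_.

```python
# Minimal sketch: an in-place EMA update mutates the averaged buffer
# instead of allocating a new one each step.
# Hypothetical helper; the real class works on torch parameters.

def ema_update_inplace(averaged, source, momentum=0.5):
    """Update `averaged` in place: avg = (1 - momentum) * avg + momentum * src."""
    for i in range(len(averaged)):
        averaged[i] = (1.0 - momentum) * averaged[i] + momentum * source[i]

avg = [1.0, 2.0]
src = [3.0, 4.0]
ema_update_inplace(avg, src, momentum=0.5)
print(avg)  # the existing list object is mutated; no new list is created
```

An out-of-place version would instead build and return a fresh container on every update, which is the extra allocation the in-place variant avoids.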
In mmengine, we provide two ways to use model averaging:
- Use the model averaging module in a hook: we provide mmengine.hooks.EMAHook to apply model averaging during training. Add custom_hooks=[dict(type='EMAHook')] to the config or the runner.
- Use the model averaging module directly in the algorithm. Take the EMA teacher in semi-supervised learning as an example:
>>> from mmengine.model import ExponentialMovingAverage
>>> student = ResNet(depth=50)
>>> # use ema model as teacher
>>> ema_teacher = ExponentialMovingAverage(student)
- Parameters
model (nn.Module) – The model to be averaged.
interval (int) – Interval between two updates. Defaults to 1.
device (torch.device, optional) – If provided, the averaged model will be stored on device. Defaults to None.
update_buffers (bool) – If True, compute running averages for both the parameters and the buffers of the model. Defaults to False.
- Return type
None
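To make the subclassing contract concrete, here is a hedged, framework-free sketch of how a subclass supplies avg_func and how the interval parameter gates updates. All names (SketchAveragedModel, SketchEMA, the list-based state) are illustrative, not the actual mmengine API, which operates on nn.Module parameters.

```python
# Sketch of the BaseAveragedModel contract: a subclass defines avg_func,
# and the averaged values are refreshed only every `interval` steps.
# Illustrative names only; not the real mmengine implementation.

class SketchAveragedModel:
    def __init__(self, params, interval=1):
        self.avg = list(params)   # copy of the source parameters
        self.interval = interval
        self.steps = 0

    def avg_func(self, avg_value, src_value, steps):
        raise NotImplementedError  # subclasses define the averaging rule

    def update_parameters(self, params):
        # apply the averaging rule only on every `interval`-th step
        if self.steps % self.interval == 0:
            for i, p in enumerate(params):
                self.avg[i] = self.avg_func(self.avg[i], p, self.steps)
        self.steps += 1


class SketchEMA(SketchAveragedModel):
    def __init__(self, params, momentum=0.5, interval=1):
        super().__init__(params, interval)
        self.momentum = momentum

    def avg_func(self, avg_value, src_value, steps):
        # exponential moving average rule
        return (1 - self.momentum) * avg_value + self.momentum * src_value


ema = SketchEMA([0.0], momentum=0.5, interval=2)
ema.update_parameters([4.0])  # step 0: 0 % 2 == 0, update applied
ema.update_parameters([8.0])  # step 1: skipped by the interval gate
```

With interval=2, only every other call actually changes the averaged values, which is the behavior described by the interval parameter above.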