Weight initialization¶
Usually, we customize our modules by subclassing nn.Module, which is implemented by native PyTorch, and torch.nn.init helps us initialize the parameters of the model easily. To simplify the process of model construction and initialization, MMEngine designed the BaseModule to help us define and initialize a model from config easily.
Initialize the model from config¶
The core function of `BaseModule` is that it could help us initialize the model from config. Subclasses of `BaseModule` can accept an `init_cfg` argument in their `__init__` function, and we can choose the method of initialization by configuring `init_cfg`.
Currently, we support the following initialization methods:
Initializer | Registered name | Function |
---|---|---|
ConstantInit | Constant | Initialize the weight and bias with a constant, commonly used for Convolution |
XavierInit | Xavier | Initialize the weight by Xavier initialization, and initialize the bias with a constant |
NormalInit | Normal | Initialize the weight by normal distribution, and initialize the bias with a constant |
TruncNormalInit | TruncNormal | Initialize the weight by truncated normal distribution, and initialize the bias with a constant, commonly used for Transformer |
UniformInit | Uniform | Initialize the weight by uniform distribution, and initialize the bias with a constant, commonly used for convolution |
KaimingInit | Kaiming | Initialize the weight by Kaiming initialization, and initialize the bias with a constant. Commonly used for convolution |
Caffe2XavierInit | Caffe2Xavier | Xavier initialization in Caffe2, which corresponds to Kaiming initialization in PyTorch with "fan_in" and "normal" mode. Commonly used for convolution |
PretrainedInit | Pretrained | Initialize the model with the pretrained model |
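Each registered name can be used directly as the `type` of an `init_cfg`. As a minimal sketch (the `DummyNet` module below is hypothetical, defined only for illustration), `Constant` resolves to `ConstantInit` and fills the matched layers with fixed values:

```python
import torch.nn as nn
from mmengine.model import BaseModule


class DummyNet(BaseModule):  # hypothetical module, for illustration only

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(1, 1, 1)


# Fill the weights of every `Conv2d` with 1 and the biases with 0.
net = DummyNet(init_cfg=dict(type='Constant', layer='Conv2d', val=1, bias=0))
net.init_weights()
```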
Initialize the model with a pretrained model¶
Define `ToyNet` as below:
```python
import torch
import torch.nn as nn
from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Linear(1, 1)


# Save the checkpoint.
toy_net = ToyNet()
torch.save(toy_net.state_dict(), './pretrained.pth')
pretrained = './pretrained.pth'
toy_net = ToyNet(init_cfg=dict(type='Pretrained', checkpoint=pretrained))
```
With `init_cfg` configured this way, the model loads the pretrained checkpoint when we call `init_weights()` after its construction.
```python
# Initialize the model with the saved checkpoint.
toy_net.init_weights()
```
```
08/19 16:50:24 - mmengine - INFO - load model from: ./pretrained.pth
08/19 16:50:24 - mmengine - INFO - local loads checkpoint from path: ./pretrained.pth
```
If `init_cfg` is a `dict`, `type` means a kind of initializer registered in `WEIGHT_INITIALIZERS`. Here `Pretrained` refers to `PretrainedInit`, which helps us load the target checkpoint. All initializers follow the same mapping relationship as `Pretrained` -> `PretrainedInit`: the registered name is the class name with the `Init` suffix stripped. The `checkpoint` argument of `PretrainedInit` is the path of the checkpoint; it could be a local path or a URL.
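For instance, a remotely hosted checkpoint can be referenced directly; a minimal sketch with a purely hypothetical URL:

```python
# The URL below is hypothetical; any reachable checkpoint address
# works the same way as a local path.
toy_net = ToyNet(
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://example.com/checkpoints/toy_net.pth'))
```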
Note

`PretrainedInit` has a higher priority than any other initializer. The loaded pretrained weights will overwrite the previously initialized weights.
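To illustrate this priority, here is a minimal sketch reusing the `ToyNet` and the `./pretrained.pth` checkpoint saved above. Even though the `Pretrained` entry comes first in the list, it is applied last, so the checkpoint weights overwrite the Xavier-initialized ones:

```python
# `Xavier` runs first and `Pretrained` runs last, regardless of order.
toy_net = ToyNet(
    init_cfg=[
        dict(type='Pretrained', checkpoint='./pretrained.pth'),
        dict(type='Xavier', layer='Linear'),
    ])
toy_net.init_weights()
```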
Commonly used initialization methods¶
Similarly, we could use the `Kaiming` initializer just like the `Pretrained` initializer. For example, we could set `init_cfg=dict(type='Kaiming', layer='Conv2d')` to initialize all `Conv2d` modules with Kaiming initialization.
Sometimes we need to apply different initialization methods to different kinds of modules. For example, we could initialize `Conv2d` modules with Kaiming initialization and `Linear` modules with Xavier initialization. To do this, we could make `init_cfg` a list of configs:
```python
import torch.nn as nn
from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(1, 1)
        self.conv = nn.Conv2d(1, 1, 1)


# Apply `Kaiming` initialization to `Conv2d` module and `Xavier` initialization to `Linear` module.
toy_net = ToyNet(
    init_cfg=[
        dict(type='Kaiming', layer='Conv2d'),
        dict(type='Xavier', layer='Linear')
    ], )
toy_net.init_weights()
```
```
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
```
`layer` could also be a list, each element of which specifies a type of module to which the initialization applies.
```python
# Apply Kaiming initialization to `Conv2d` and `Linear` module.
toy_net = ToyNet(init_cfg=[dict(type='Kaiming', layer=['Conv2d', 'Linear'])], )
toy_net.init_weights()
```
```
08/19 16:50:24 - mmengine - INFO -
linear.weight - torch.Size([1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
linear.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
```
More fine-grained initialization¶
Sometimes we need to initialize modules of the same type with different initialization methods. For example, we’ve defined the `conv1` and `conv2` submodules, and we want to initialize `conv1` with Kaiming initialization and `conv2` with Xavier initialization. We could configure the `init_cfg` with `override`:
```python
import torch.nn as nn
from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)


# Apply `Kaiming` initialization to `conv1` and `Xavier` initialization to `conv2`.
toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier')),
    ], )
toy_net.init_weights()
```
```
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
```
`override` could be understood as a nested `init_cfg`: it could also be a `list` or `dict`, and we must set `type` for it as well. The difference is that we must also set `name` in `override` to specify the submodule it applies to. In the example above, we set `name='conv2'` to specify that the Xavier initialization applies to the whole `toy_net.conv2` submodule.
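Since `override` is itself a nested `init_cfg`, it also accepts the arguments of the initializer it names. A minimal sketch reusing the `ToyNet` above (the `val` argument belongs to the `Constant` initializer):

```python
# Kaiming initialization for all `Conv2d` modules, except `conv2`,
# whose weights are filled with the constant 0.5 instead.
toy_net = ToyNet(
    init_cfg=dict(
        type='Kaiming',
        layer=['Conv2d'],
        override=dict(name='conv2', type='Constant', val=0.5)))
toy_net.init_weights()
```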
Customize the initialization method¶
Although `init_cfg` could control the initialization method for different modules, we would have to register a new initializer in `WEIGHT_INITIALIZERS` to customize the initialization process, which is not always convenient. Instead, we could override the `init_weights` method to customize the initialization process.
Assuming we’ve defined the following modules:

- `ToyConv` inherits from `nn.Module` and implements `init_weights`, which initializes `custom_weight` (a parameter of `ToyConv`) with 1 and `custom_bias` with 0.
- `ToyNet` defines a `ToyConv` submodule.

`ToyNet.init_weights` will call `init_weights` of all submodules sequentially.
```python
import torch
import torch.nn as nn
from mmengine.model import BaseModule


class ToyConv(nn.Module):

    def __init__(self):
        super().__init__()
        self.custom_weight = nn.Parameter(torch.empty(1, 1, 1, 1))
        self.custom_bias = nn.Parameter(torch.empty(1))

    def init_weights(self):
        with torch.no_grad():
            self.custom_weight = self.custom_weight.fill_(1)
            self.custom_bias = self.custom_bias.fill_(0)


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.custom_conv = ToyConv()


toy_net = ToyNet(
    init_cfg=[
        dict(
            type='Kaiming',
            layer=['Conv2d'],
            override=dict(name='conv2', type='Xavier'))
    ])
toy_net.init_weights()
```
```
08/19 16:50:24 - mmengine - INFO -
conv1.weight - torch.Size([1, 1, 1, 1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv1.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.weight - torch.Size([1, 1, 1, 1]):
XavierInit: gain=1, distribution=normal, bias=0

08/19 16:50:24 - mmengine - INFO -
conv2.bias - torch.Size([1]):
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_weight - torch.Size([1, 1, 1, 1]):
Initialized by user-defined `init_weights` in ToyConv

08/19 16:50:24 - mmengine - INFO -
custom_conv.custom_bias - torch.Size([1]):
Initialized by user-defined `init_weights` in ToyConv
```
Conclusion¶
1. Configure `init_cfg` to initialize the model

   - Commonly used for the initialization of `Conv2d`, `Linear` and other low-level modules. All initialization methods should be managed by `WEIGHT_INITIALIZERS`.
   - Supports dynamic initialization, controlled by `init_cfg`.

2. Customize `init_weights`

   - Compared to configuring `init_cfg`, implementing `init_weights` is simpler and does not require registration. However, it is not as flexible as `init_cfg`, and it cannot initialize the module dynamically.
Note

- The priority of `init_weights` is higher than that of `init_cfg`.
- The Runner will call `init_weights` in `Runner.train()`.
Initialize modules with functions¶
As mentioned in the prior section, we could customize our initialization in `init_weights`. To make it more convenient to initialize modules, MMEngine provides a series of module initialization functions, built on `torch.nn.init`, that initialize a whole module at once. For example, suppose we want to initialize the weights of a convolutional layer with a normal distribution and its bias with a constant. With `torch.nn.init`, the implementation is as follows:
```python
from torch.nn.init import normal_, constant_
import torch.nn as nn

model = nn.Conv2d(1, 1, 1)
normal_(model.weight, mean=0, std=0.01)
constant_(model.bias, val=0)
```

```
Parameter containing:
tensor([0.], requires_grad=True)
```
The above is actually the standard process for initializing a convolutional module with a normal distribution, so MMEngine simplifies it by implementing a series of common module initialization functions. Compared with `torch.nn.init`, the module initialization functions accept the convolution module directly:
```python
from mmengine.model import normal_init

normal_init(model, mean=0, std=0.01, bias=0)
```
Similarly, we could also use Kaiming initialization and Xavier initialization:
```python
from mmengine.model import kaiming_init, xavier_init

kaiming_init(model)
xavier_init(model)
```
Currently, MMEngine provides the following initialization functions:
Initialization function | Function |
---|---|
constant_init | Initialize the weight and bias with a constant, commonly used for Convolution |
xavier_init | Initialize the weight by Xavier initialization, and initialize the bias with a constant |
normal_init | Initialize the weight by normal distribution, and initialize the bias with a constant |
trunc_normal_init | Initialize the weight by truncated normal distribution, and initialize the bias with a constant, commonly used for Transformer |
uniform_init | Initialize the weight by uniform distribution, and initialize the bias with a constant, commonly used for convolution |
kaiming_init | Initialize the weight by Kaiming initialization, and initialize the bias with a constant. Commonly used for convolution |
caffe2_xavier_init | Xavier initialization in Caffe2, which corresponds to Kaiming initialization in PyTorch with "fan_in" and "normal" mode. Commonly used for convolution |
bias_init_with_prob | Initialize the bias according to a given prior probability |
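As a worked example of the last entry, `bias_init_with_prob` computes the bias value whose sigmoid equals a given prior probability, i.e. `-log((1 - p) / p)`. A minimal sketch with an arbitrary head shape, combining it with `normal_init`:

```python
import torch.nn as nn
from mmengine.model import bias_init_with_prob, normal_init

# Bias value such that sigmoid(bias) == 0.01, roughly -4.59.
bias_cls = bias_init_with_prob(0.01)
conv_cls = nn.Conv2d(256, 80, 3, padding=1)  # arbitrary shape, for illustration
normal_init(conv_cls, std=0.01, bias=bias_cls)
```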