
OptimWrapperDict

class mmengine.optim.OptimWrapperDict(**optim_wrapper_dict)

A dictionary container of OptimWrapper.

If the runner trains with multiple optimizers, all optimizer wrappers should be managed by an OptimWrapperDict, which is built by CustomOptimWrapperConstructor. OptimWrapperDict loads and saves the state dictionaries of all optimizer wrappers.

Because calling update_params() or backward() on all optimizer wrappers at once would be semantically ambiguous, OptimWrapperDict does not implement these methods.

Examples

>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapperDict, OptimWrapper
>>> model1 = nn.Linear(1, 1)
>>> model2 = nn.Linear(1, 1)
>>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
>>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
>>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
...                                       model2=optim_wrapper2)

Note

The optimizer wrappers contained in OptimWrapperDict can be accessed in the same way as entries of a dict.

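Parameters

**optim_wrapper_dict (OptimWrapper) – Optimizer wrappers passed as keyword arguments; each keyword is the name under which the wrapper is stored and accessed.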
backward(loss, **kwargs)

Since OptimWrapperDict does not know which optimizer wrapper's backward method should be called (the loss_scaler may differ between AmpOptimWrapper instances), this method is not implemented.

Access the individual optimizer wrapper in OptimWrapperDict and call its backward() instead.

Parameters

loss (torch.Tensor) –

Return type

None

get_lr()

Get the learning rate of all optimizers.

Returns

Learning rate of all optimizers.

Return type

Dict[str, List[float]]

get_momentum()

Get the momentum of all optimizers.

Returns

Momentum of all optimizers.

Return type

Dict[str, List[float]]

initialize_count_status(model, cur_iter, max_iters)

Does nothing; it only provides an interface consistent with OptimWrapper.

Since OptimWrapperDict does not know the correspondence between models and optimizer wrappers, this method does nothing here; call initialize_count_status on each optimizer wrapper separately.

Parameters

model (torch.nn.modules.module.Module) –

Return type

None

items()

A generator that yields each name and its corresponding OptimWrapper.

Return type

Iterator[Tuple[str, mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper]]

keys()

A generator that yields the name of each OptimWrapper.

Return type

Iterator[str]

load_state_dict(state_dict)

Load the state dictionaries of all optimizer wrappers from state_dict.

Parameters

state_dict (dict) – Each key-value pair in state_dict represents the name and the state dictionary of corresponding OptimWrapper.

Return type

None

optim_context(model)

optim_context should be called on each optimizer wrapper separately.

Parameters

model (torch.nn.modules.module.Module) –

property param_groups

Returns the parameter groups of each OptimWrapper.

state_dict()

Get the state dictionary of all optimizer wrappers.

Returns

Each key-value pair in the dictionary represents the name and state dictionary of corresponding OptimWrapper.

Return type

dict

step(**kwargs)

Since backward() is not implemented, step() is not implemented either.

Return type

None

update_params(loss, step_kwargs=None, zero_kwargs=None)

Updating all optimizer wrappers at once would call backward more than once on the same loss, which raises an error, and OptimWrapperDict does not know which optimizer wrapper should be updated.

Therefore, this method is not implemented. Access the individual optimizer wrapper in OptimWrapperDict and call its update_params() instead.

Parameters

loss (torch.Tensor) –

step_kwargs (Optional[Dict]) –

zero_kwargs (Optional[Dict]) –

Return type

None

values()

A generator that yields each OptimWrapper.

Return type

Iterator[mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper]

zero_grad(**kwargs)

Set the gradients of all optimizer wrappers to zero.

Return type

None

Documentation version: v0.8.3