
mmengine.hooks.param_scheduler_hook 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]

@HOOKS.register_module()
class ParamSchedulerHook(Hook):
    """A hook to update some hyper-parameters in optimizer, e.g., learning
    rate and momentum.

    Iteration-based schedulers (``by_epoch`` falsy) are stepped after every
    training iteration; epoch-based schedulers (``by_epoch`` truthy) are
    stepped after every training epoch.
    """

    priority = 'LOW'

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Call step function for each scheduler after each iteration.

        Only schedulers whose ``by_epoch`` attribute is falsy are stepped.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. In order to keep this interface consistent with
                other hooks, we keep ``data_batch`` here.
            outputs (dict, optional): Outputs from model. In order to keep
                this interface consistent with other hooks, we keep
                ``outputs`` here.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.
        """
        self._step_schedulers(runner, by_epoch=False)

    def after_train_epoch(self, runner) -> None:
        """Call step function for each scheduler after each epoch.

        Only schedulers whose ``by_epoch`` attribute is truthy are stepped.

        Args:
            runner (Runner): The runner of the training process.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.
        """
        self._step_schedulers(runner, by_epoch=True)

    @staticmethod
    def _step_schedulers(runner, by_epoch: bool) -> None:
        """Step every scheduler of ``runner`` whose ``by_epoch`` matches.

        ``runner.param_schedulers`` may be ``None`` (nothing to do), a list
        of schedulers, or a dict whose values are lists of schedulers (e.g.
        one list per optimizer). Within each list, schedulers are stepped in
        order.

        Args:
            runner (Runner): The runner of the training process.
            by_epoch (bool): Step schedulers whose ``by_epoch`` attribute has
                this truthiness.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.
            AssertionError: If a dict value is not a list.
        """
        param_schedulers = runner.param_schedulers
        if param_schedulers is None:
            return

        if isinstance(param_schedulers, list):
            groups = [param_schedulers]
        elif isinstance(param_schedulers, dict):
            groups = param_schedulers.values()
        else:
            raise TypeError(
                'runner.param_schedulers should be list of ParamScheduler or '
                'a dict containing list of ParamScheduler, '
                f'but got {runner.param_schedulers}')

        for group in groups:
            # Validate each group lazily so earlier groups are stepped before
            # a malformed later group is reported.
            assert isinstance(group, list)
            for scheduler in group:
                # Compare truthiness, since ``by_epoch`` is not guaranteed
                # to be a strict bool.
                if bool(scheduler.by_epoch) == by_epoch:
                    scheduler.step()

© Copyright 2022, mmengine contributors. Revision 6af88783.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.4.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.