Shortcuts

Source code for mmengine.hooks.param_scheduler_hook

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union

from mmengine.optim import _ParamScheduler
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of
from .hook import Hook

# Accepted types for the ``data_batch`` argument of the hook methods below:
# a dict, tuple or list, or None (the default when no batch is passed).
DATA_BATCH = Union[dict, tuple, list, None]


@HOOKS.register_module()
class ParamSchedulerHook(Hook):
    """A hook to update some hyper-parameters in optimizer, e.g., learning
    rate and momentum.

    Schedulers with ``by_epoch=False`` are stepped after every training
    iteration, schedulers with ``by_epoch=True`` after every training epoch,
    and schedulers with ``by_epoch=True`` that also set ``need_val_args`` are
    stepped with the validation metrics after every validation epoch that
    produced metrics.
    """

    priority = 'LOW'

    def _scheduler_groups(self, runner) -> list:
        """Return the scheduler groups held by ``runner.param_schedulers``.

        ``runner.param_schedulers`` is either a single list of schedulers or
        a dict whose values are such lists. Both forms are normalized to a
        list of groups so every hook can iterate them the same way.

        Args:
            runner (Runner): The runner whose ``param_schedulers`` are read.
                Callers must have already checked it is not ``None``.

        Returns:
            list: ``[runner.param_schedulers]`` when it is a list, otherwise
            the dict's values in iteration order. The groups themselves are
            not validated here.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither a list nor a
                dict.
        """
        param_schedulers = runner.param_schedulers
        if isinstance(param_schedulers, list):
            return [param_schedulers]
        if isinstance(param_schedulers, dict):
            return list(param_schedulers.values())
        raise TypeError(
            'runner.param_schedulers should be list of ParamScheduler or '
            'a dict containing list of ParamScheduler, '
            f'but got {param_schedulers}')

    def _step_train_schedulers(self, runner, by_epoch: bool) -> None:
        """Call ``step()`` on every scheduler whose ``by_epoch`` matches.

        Shared by :meth:`after_train_iter` (``by_epoch=False``) and
        :meth:`after_train_epoch` (``by_epoch=True``). Does nothing when
        ``runner.param_schedulers`` is ``None``.

        Args:
            runner (Runner): The runner of the training process.
            by_epoch (bool): Which kind of scheduler to step.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither a list nor a
                dict, or if a group in it is not a list.
        """
        if runner.param_schedulers is None:
            return

        for group in self._scheduler_groups(runner):
            # Explicit check rather than ``assert`` so a malformed group is
            # still rejected when Python runs with ``-O``.
            if not isinstance(group, list):
                raise TypeError(
                    'Each group in runner.param_schedulers should be a list '
                    f'of ParamScheduler, but got {type(group)}')
            for scheduler in group:
                # ``bool()`` keeps the original truthiness test on
                # ``scheduler.by_epoch``.
                if bool(scheduler.by_epoch) == by_epoch:
                    scheduler.step()

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Call step function for each scheduler after each training
        iteration.

        Only schedulers with ``by_epoch=False`` are stepped.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                In order to keep this interface consistent with other hooks,
                we keep ``data_batch`` here.
            outputs (dict, optional): Outputs from model.
                In order to keep this interface consistent with other hooks,
                we keep ``outputs`` here.

        Raises:
            TypeError: If ``runner.param_schedulers`` is not ``None``, a list,
                or a dict of lists.
        """
        self._step_train_schedulers(runner, by_epoch=False)

    def after_train_epoch(self, runner) -> None:
        """Call step function for each scheduler after each training epoch.

        Only schedulers with ``by_epoch=True`` are stepped.

        Args:
            runner (Runner): The runner of the training process.

        Raises:
            TypeError: If ``runner.param_schedulers`` is not ``None``, a list,
                or a dict of lists.
        """
        self._step_train_schedulers(runner, by_epoch=True)

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Call step function for each scheduler which has attribute
        ``need_val_args`` after each validation epoch.

        Only schedulers with ``by_epoch=True`` and a truthy
        ``need_val_args`` attribute are stepped, and they receive
        ``metrics`` as the argument to ``step``.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.

        Note:
            Groups in ``runner.param_schedulers`` that are not lists of
            ``_ParamScheduler`` objects (e.g. not built yet) are skipped.
        """
        if runner.param_schedulers is None:
            return

        # Avoid counting ``scheduler._global_step`` again: it has already
        # been counted in the after_train_* hooks.
        if metrics is None:
            return

        for group in self._scheduler_groups(runner):
            # Skip groups that are not lists of built schedulers.
            if not is_list_of(group, _ParamScheduler):
                continue
            for scheduler in group:
                if (scheduler.by_epoch
                        and getattr(scheduler, 'need_val_args', False)):
                    scheduler.step(metrics)

© Copyright 2022, mmengine contributors. Revision 39ed23fa.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.