Shortcuts

Source code for mmengine.hooks.iter_timer_hook

# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Sequence, Union

from mmengine.registry import HOOKS
from .hook import Hook

# Type accepted for the ``data_batch`` argument of the iteration hooks:
# a dict, tuple or list taken from the dataloader, or ``None`` when no
# batch is supplied.
DATA_BATCH = Union[dict, tuple, list, None]


@HOOKS.register_module()
class IterTimerHook(Hook):
    """Hook that measures how long each iteration takes.

    For every iteration of every mode (``'train'``, ``'val'``, ``'test'``)
    it pushes two scalars to ``runner.message_hub``:

    - ``'{mode}/data_time'``: seconds from the last recorded timestamp to
      the start of the iteration, i.e. the time spent waiting for data.
    - ``'{mode}/time'``: seconds from the last recorded timestamp to the
      end of the iteration, i.e. data loading plus the model step.

    After every iteration it also refreshes the ``'eta'`` info entry (in
    seconds), extrapolated from the average iteration time seen so far.

    Note:
        Timestamps are taken with :func:`time.time` (wall-clock time).
    """

    priority = 'NORMAL'

    def __init__(self):
        # Total seconds spent in training iterations since ``start_iter``.
        self.time_sec_tot = 0
        # Total seconds spent in val/test iterations of the running epoch;
        # zeroed again by ``_after_epoch``.
        self.time_sec_test_val = 0
        # ``runner.iter`` at the moment training began (may be non-zero
        # after resuming from a checkpoint).
        self.start_iter = 0

    def before_train(self, runner) -> None:
        """Align ``start_iter`` with the runner's current iteration.

        After resuming from a checkpoint ``runner.iter`` may already be
        past 0. Only the iterations executed from this point on are used
        when averaging the training iteration time.

        Args:
            runner: The runner driving training, validation or testing.
        """
        self.start_iter = runner.iter

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Take the reference timestamp at the start of an epoch.

        Args:
            runner (Runner): The runner driving training, validation or
                testing.
            mode (str): Current runner mode. Defaults to 'train'.
        """
        self.t = time.time()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Reset the val/test time accumulator once an epoch ends.

        Args:
            runner (Runner): The runner driving training, validation or
                testing.
            mode (str): Current runner mode. Defaults to 'train'.
        """
        self.time_sec_test_val = 0

    def _before_iter(self,
                     runner,
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     mode: str = 'train') -> None:
        """Record how long the upcoming iteration waited for its data.

        The seconds elapsed since the last timestamp are written to the
        ``'{mode}/data_time'`` history buffer of ``runner.message_hub``.

        Args:
            runner (Runner): The runner driving training, validation or
                testing.
            batch_idx (int): Index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Batch taken from
                the dataloader.
            mode (str): Current runner mode. Defaults to 'train'.
        """
        waited = time.time() - self.t
        runner.message_hub.update_scalar(f'{mode}/data_time', waited)

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Record the iteration time and refresh the ETA estimate.

        The seconds elapsed since the last timestamp are written to the
        ``'{mode}/time'`` history buffer of ``runner.message_hub``, and the
        ``'eta'`` info entry is set to the average iteration time multiplied
        by the number of iterations still to run:

        - train: average over iterations since ``start_iter``, times
          ``runner.max_iters - runner.iter - 1``;
        - val/test: average over the batches of the current epoch, times
          the batches left in the corresponding dataloader.

        Args:
            runner (Runner): The runner driving training, validation or
                testing.
            batch_idx (int): Index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Batch taken from
                the dataloader.
            outputs (dict or sequence, optional): Outputs of the model.
            mode (str): Current runner mode. Defaults to 'train'.
        """
        hub = runner.message_hub
        time_key = f'{mode}/time'
        hub.update_scalar(time_key, time.time() - self.t)
        # The next iteration (both its data_time and time) is measured
        # from this point on.
        self.t = time.time()
        iter_time = hub.get_scalar(time_key)

        if mode == 'train':
            self.time_sec_tot += iter_time.current()
            # Only iterations run since ``start_iter`` enter the average, so
            # iterations restored from a checkpoint are not counted.
            avg_sec = self.time_sec_tot / (runner.iter - self.start_iter + 1)
            iters_left = runner.max_iters - runner.iter - 1
        else:
            # Every non-train mode other than 'val' uses the test loader.
            loader = (runner.val_dataloader
                      if mode == 'val' else runner.test_dataloader)
            self.time_sec_test_val += iter_time.current()
            avg_sec = self.time_sec_test_val / (batch_idx + 1)
            iters_left = len(loader) - batch_idx - 1

        hub.update_info('eta', avg_sec * iters_left)

© Copyright 2022, mmengine contributors. Revision 66fb81f7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.