Shortcuts

Source code for mmengine.hooks.iter_timer_hook

# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Sequence, Union

from mmengine.registry import HOOKS
from .hook import Hook

# Type alias used to annotate the ``data_batch`` argument of the iteration
# callbacks below: a dict, tuple or list, or ``None`` when no batch is given.
DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class IterTimerHook(Hook):
    """Hook that times every iteration and publishes an ETA estimate.

    For each iteration two scalars are written to ``runner.message_hub``:
    ``{mode}/data_time`` (elapsed time up to the start of the iteration,
    i.e. the wait for the batch) and ``{mode}/time`` (elapsed time up to
    the end of the iteration). The ``eta`` info entry is refreshed after
    every iteration.
    """

    priority = 'NORMAL'

    def __init__(self):
        # Accumulated iteration time of the training loop.
        self.time_sec_tot = 0
        # Accumulated iteration time of the current val/test epoch.
        self.time_sec_test_val = 0
        # Value of ``runner.iter`` when training started.
        self.start_iter = 0

    def before_train(self, runner) -> None:
        """Remember the iteration at which training starts.

        The runner's iteration counter may be non-zero (e.g. after
        resuming from a checkpoint), so it is stored and later subtracted
        when averaging the iteration time.

        Args:
            runner: The runner of the training, validation or testing
                process.
        """
        self.start_iter = runner.iter

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Store the reference timestamp at the start of an epoch.

        Args:
            runner (Runner): The runner of the training, validation and
                testing process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        self.t = time.time()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Reset the val/test time accumulator at the end of an epoch.

        Args:
            runner (Runner): The runner of the training, validation and
                testing process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        self.time_sec_test_val = 0

    def _before_iter(self,
                     runner,
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     mode: str = 'train') -> None:
        """Publish the data loading time as the ``{mode}/data_time`` scalar.

        The value is the time elapsed since the reference timestamp
        ``self.t``.

        Args:
            runner (Runner): The runner of the training, validation and
                testing process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        hub = runner.message_hub
        data_time_key = f'{mode}/data_time'
        hub.update_scalar(data_time_key, time.time() - self.t)

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Publish the iteration time and refresh the ``eta`` estimate.

        The ``{mode}/time`` scalar is updated with the time elapsed since
        ``self.t``, the reference timestamp is restarted, and the
        remaining time is estimated as the mean iteration time multiplied
        by the number of iterations left.

        Args:
            runner (Runner): The runner of the training, validation and
                testing process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
            outputs (dict or sequence, optional): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        hub = runner.message_hub
        time_key = f'{mode}/time'
        hub.update_scalar(time_key, time.time() - self.t)
        # Restart the clock so the next measurement begins here.
        self.t = time.time()
        history = hub.get_scalar(time_key)

        if mode == 'train':
            self.time_sec_tot += history.current()
            # Mean over the iterations run since ``before_train``.
            iters_done = runner.iter - self.start_iter + 1
            mean_iter_sec = self.time_sec_tot / iters_done
            iters_left = runner.max_iters - runner.iter - 1
            eta_sec = mean_iter_sec * iters_left
        else:
            loader = (
                runner.val_dataloader
                if mode == 'val' else runner.test_dataloader)
            self.time_sec_test_val += history.current()
            # Mean over the batches seen so far in this val/test epoch.
            mean_iter_sec = self.time_sec_test_val / (batch_idx + 1)
            eta_sec = mean_iter_sec * (len(loader) - batch_idx - 1)

        runner.message_hub.update_info('eta', eta_sec)

© Copyright 2022, mmengine contributors. Revision ef4c68de.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.7.0
Versions
latest
stable
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.