Shortcuts

Source code for mmengine.hooks.empty_cache_hook

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import torch

from mmengine.registry import HOOKS
from ..device import is_cuda_available, is_musa_available
from .hook import Hook

# Type alias for the ``data_batch`` argument passed to hook callbacks:
# a dict, tuple or list produced by the dataloader, or None.
DATA_BATCH = Union[dict, tuple, list, None]


@HOOKS.register_module()
class EmptyCacheHook(Hook):
    """Releases all unoccupied cached GPU memory during the process of training.

    Cached memory is released with ``torch.cuda.empty_cache()`` when CUDA is
    available, otherwise with ``torch.musa.empty_cache()`` when MUSA is
    available. If neither backend is available, the hook does nothing.

    The ``mode`` argument received by the callbacks is not inspected, so the
    flags below apply to every mode for which the runner invokes these
    callbacks (presumably train/val/test via ``Hook`` — confirm there).

    Args:
        before_epoch (bool): Whether to release cache before an epoch.
            Defaults to False.
        after_epoch (bool): Whether to release cache after an epoch.
            Defaults to True.
        after_iter (bool): Whether to release cache after an iteration.
            Defaults to False.
    """

    priority = 'NORMAL'

    def __init__(self,
                 before_epoch: bool = False,
                 after_epoch: bool = True,
                 after_iter: bool = False) -> None:
        self._do_before_epoch = before_epoch
        self._do_after_epoch = after_epoch
        self._do_after_iter = after_iter

    @staticmethod
    def _empty_cache() -> None:
        """Release cached memory on the first available accelerator backend.

        CUDA is checked first, then MUSA. This is a no-op when neither is
        available. Keeping this in one place ensures every callback below
        targets the same backends in the same order.
        """
        if is_cuda_available():
            torch.cuda.empty_cache()
        elif is_musa_available():
            torch.musa.empty_cache()

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Empty cache after an iteration.

        Only acts when the hook was constructed with ``after_iter=True``.
        All arguments besides ``self`` are accepted for interface
        compatibility and are not used.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
            outputs (dict or sequence, optional): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_after_iter:
            self._empty_cache()

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache before an epoch.

        Only acts when the hook was constructed with ``before_epoch=True``.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_before_epoch:
            self._empty_cache()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache after an epoch.

        Only acts when the hook was constructed with ``after_epoch=True``
        (the default).

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_after_epoch:
            self._empty_cache()

© Copyright 2022, mmengine contributors. Revision c423d0c1.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.