Shortcuts

Source code for mmengine.runner.utils

# Copyright (c) OpenMMLab. All rights reserved.
import logging
import random
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
from mmengine.utils.dl_utils import TORCH_VERSION


def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Expand a starting interval plus milestone overrides into two lists.

    Args:
        start_interval (int): Interval in effect from milestone ``0``.
        dynamic_interval_list (List[Tuple[int, int]], optional): Pairs of
            ``(milestone, interval)``; each interval applies from its
            milestone onward. Defaults to None, meaning the starting
            interval is used throughout.

    Returns:
        Tuple[List[int], List[int]]: ``(milestones, intervals)`` where
        ``milestones[i]`` is the point at which ``intervals[i]`` takes
        effect. Both lists always begin with ``0`` and ``start_interval``.
    """
    if dynamic_interval_list is None:
        return [0], [start_interval]

    assert is_list_of(dynamic_interval_list, tuple)

    # Prepend the implicit "from the start" entry to each column.
    milestones = [0] + [pair[0] for pair in dynamic_interval_list]
    intervals = [start_interval] + [pair[1] for pair in dynamic_interval_list]
    return milestones, intervals


def set_random_seed(seed: Optional[int] = None,
                    deterministic: bool = False,
                    diff_rank_seed: bool = False) -> int:
    """Seed Python, NumPy and PyTorch RNGs, optionally forcing determinism.

    Args:
        seed (int, optional): Seed to apply. When None, a seed is obtained
            from ``sync_random_seed()``. Defaults to None.
        deterministic (bool): When True, set
            ``torch.backends.cudnn.deterministic`` to True and
            ``torch.backends.cudnn.benchmark`` to False, and on
            torch >= 1.10.0 also call
            ``torch.use_deterministic_algorithms(True)``. Defaults to False.
        diff_rank_seed (bool): When True, add the current rank (from
            ``get_rank()``) to the seed so that each rank is seeded
            differently. Defaults to False.

    Returns:
        int: The seed that was actually applied (after any rank offset).
    """
    if seed is None:
        # NOTE(review): sync_random_seed presumably shares one seed across
        # all ranks; confirm in mmengine.dist.
        seed = sync_random_seed()
    if diff_rank_seed:
        seed += get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible accelerator device of the available backend.
    if is_cuda_available():
        torch.cuda.manual_seed_all(seed)
    elif is_musa_available():
        torch.musa.manual_seed_all(seed)

    if not deterministic:
        return seed

    # Warn only when the user-visible benchmark flag is about to flip.
    if torch.backends.cudnn.benchmark:
        print_log(
            'torch.backends.cudnn.benchmark is going to be set as '
            '`False` to cause cuDNN to deterministically select an '
            'algorithm',
            logger='current',
            level=logging.WARNING)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
        torch.use_deterministic_algorithms(True)
    return seed
def _get_batch_size(dataloader: Union[dict, DataLoader]) -> int:
    """Return the batch size of a dataloader config or a built DataLoader.

    Args:
        dataloader (dict or DataLoader): Either a dataloader config dict,
            which must set ``batch_size`` at the top level or inside a
            ``batch_sampler`` dict, or a built
            :class:`torch.utils.data.DataLoader`.

    Returns:
        int: The batch size. For a config dict, a top-level
        ``batch_size`` takes precedence over
        ``batch_sampler['batch_size']``. For a built ``DataLoader``, the
        value is read from its ``batch_sampler``.

    Raises:
        ValueError: If no batch size can be found, or if ``dataloader`` is
            neither a dict nor a ``DataLoader``.
    """
    if isinstance(dataloader, dict):
        if 'batch_size' in dataloader:
            return dataloader['batch_size']
        batch_sampler = dataloader.get('batch_sampler')
        # ``batch_sampler`` may be present but None (or a non-dict value);
        # a bare ``'batch_size' in None`` would raise an unhelpful TypeError
        # instead of the ValueError below.
        if isinstance(batch_sampler, dict) and 'batch_size' in batch_sampler:
            return batch_sampler['batch_size']
        raise ValueError('Please set batch_size in `Dataloader` or '
                         '`batch_sampler`')

    if isinstance(dataloader, DataLoader):
        batch_sampler = dataloader.batch_sampler
        # A DataLoader built with ``batch_size=None`` disables automatic
        # batching and has no batch sampler to read a size from.
        if batch_sampler is None:
            raise ValueError('The DataLoader has no batch_sampler (was it '
                             'built with `batch_size=None`?), so its batch '
                             'size cannot be inferred')
        return batch_sampler.batch_size

    raise ValueError('dataloader should be a dict or a Dataloader '
                     f'instance, but got {type(dataloader)}')

© Copyright 2022, mmengine contributors. Revision 66fb81f7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.