Shortcuts

Source code for mmengine.runner.utils

# Copyright (c) OpenMMLab. All rights reserved.
import logging
import random
from typing import List, Optional, Tuple

import numpy as np
import torch

from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
from mmengine.utils.dl_utils import TORCH_VERSION


def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Build the milestone/interval schedule for a dynamic interval.

    Args:
        start_interval (int): Interval in effect from the beginning, i.e.
            from milestone ``0``.
        dynamic_interval_list (List[Tuple[int, int]], optional): Pairs of
            ``(milestone, interval)``. Each interval is used after its
            corresponding milestone. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: Two parallel lists: the milestones
        (always starting with ``0``) and the interval paired with each
        milestone (starting with ``start_interval``).
    """
    milestones: List[int] = [0]
    intervals: List[int] = [start_interval]

    if dynamic_interval_list is not None:
        assert is_list_of(dynamic_interval_list, tuple)
        # Split the (milestone, interval) pairs into two parallel lists,
        # appended after the implicit starting entry.
        milestones += [pair[0] for pair in dynamic_interval_list]
        intervals += [pair[1] for pair in dynamic_interval_list]

    return milestones, intervals


def set_random_seed(seed: Optional[int] = None,
                    deterministic: bool = False,
                    diff_rank_seed: bool = False) -> int:
    """Set random seed.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    and CUDA generators with the same value.

    Args:
        seed (int, optional): Seed to be used. If None, the seed is obtained
            from ``sync_random_seed()`` (presumably a value shared across
            ranks; confirm against its implementation). Defaults to None.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False. On
            PyTorch >= 1.10.0 this also calls
            ``torch.use_deterministic_algorithms(True)``.
            Defaults to False.
        diff_rank_seed (bool): Whether to add the rank number returned by
            ``get_rank()`` to the random seed so that different ranks
            (processes) use different seeds. Defaults to False.

    Returns:
        int: The seed that was actually applied, including the rank offset
        when ``diff_rank_seed`` is True.
    """
    if seed is None:
        seed = sync_random_seed()

    if diff_rank_seed:
        rank = get_rank()
        seed += rank

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # ``manual_seed_all`` seeds every CUDA device (and is silently ignored
    # when CUDA is unavailable), so a separate per-device
    # ``torch.cuda.manual_seed`` call is unnecessary.
    torch.cuda.manual_seed_all(seed)
    # NOTE: ``PYTHONHASHSEED`` is deliberately not set here: it is only read
    # at interpreter start-up, so setting it at runtime would not affect the
    # hash randomization of the current process.

    if deterministic:
        if torch.backends.cudnn.benchmark:
            print_log(
                'torch.backends.cudnn.benchmark is going to be set as '
                '`False` to cause cuDNN to deterministically select an '
                'algorithm',
                logger='current',
                level=logging.WARNING)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Guarded by version because this code path was only enabled for
        # PyTorch >= 1.10.0.
        if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
            torch.use_deterministic_algorithms(True)

    return seed

© Copyright 2022, mmengine contributors. Revision a2e410bd.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.7.4
Versions
latest
stable
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.