# Shortcuts

# Source code for mmengine.runner.utils

# Copyright (c) OpenMMLab. All rights reserved.
import logging
import random
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader

from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
from mmengine.utils.dl_utils import TORCH_VERSION


def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Build the milestone list and the interval list that goes with it.

    Both returned lists always start with an implicit entry: milestone ``0``
    paired with ``start_interval``. Every ``(milestone, interval)`` tuple in
    ``dynamic_interval_list`` then contributes one further entry to each
    list, in the order given.

    Args:
        start_interval (int): The interval used in the beginning.
        dynamic_interval_list (List[Tuple[int, int]], optional): Tuples
            whose first element is a milestone and whose second element is
            the interval used after that milestone. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: The milestones and their corresponding
        intervals, as two lists of equal length.
    """
    milestones = [0]
    intervals = [start_interval]
    if dynamic_interval_list is None:
        return milestones, intervals

    assert is_list_of(dynamic_interval_list, tuple)

    # Index explicitly instead of unpacking so tuples with more than two
    # elements are still accepted.
    for entry in dynamic_interval_list:
        milestones.append(entry[0])
    for entry in dynamic_interval_list:
        intervals.append(entry[1])
    return milestones, intervals


def set_random_seed(seed: Optional[int] = None,
                    deterministic: bool = False,
                    diff_rank_seed: bool = False) -> int:
    """Set random seed.

    Seeds Python's ``random`` module, NumPy and PyTorch with the same value.
    When CUDA is available all CUDA devices are seeded as well; otherwise,
    when MUSA is available, all MUSA devices are seeded.

    Args:
        seed (int, optional): Seed to be used. If None, the value returned
            by ``sync_random_seed()`` is used. Defaults to None.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False. For
            PyTorch >= 1.10.0, ``torch.use_deterministic_algorithms(True)``
            is called as well. Defaults to False.
        diff_rank_seed (bool): Whether to add the rank number to the random
            seed so that each rank (process) gets a different seed.
            Defaults to False.

    Returns:
        int: The seed that was actually applied, including the rank offset
        when ``diff_rank_seed`` is True.
    """
    if seed is None:
        seed = sync_random_seed()

    if diff_rank_seed:
        rank = get_rank()
        seed += rank

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_cuda_available():
        torch.cuda.manual_seed_all(seed)
    elif is_musa_available():
        torch.musa.manual_seed_all(seed)
    # NOTE: PYTHONHASHSEED is not modified by this function.

    if deterministic:
        # Only warn when the call actually flips the benchmark flag.
        if torch.backends.cudnn.benchmark:
            print_log(
                'torch.backends.cudnn.benchmark is going to be set as '
                '`False` to cause cuDNN to deterministically select an '
                'algorithm',
                logger='current',
                level=logging.WARNING)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
            torch.use_deterministic_algorithms(True)
    return seed


def _get_batch_size(dataloader: dict): if isinstance(dataloader, dict): if 'batch_size' in dataloader: return dataloader['batch_size'] elif ('batch_sampler' in dataloader and 'batch_size' in dataloader['batch_sampler']): return dataloader['batch_sampler']['batch_size'] else: raise ValueError('Please set batch_size in `Dataloader` or ' '`batch_sampler`') elif isinstance(dataloader, DataLoader): return dataloader.batch_sampler.batch_size else: raise ValueError('dataloader should be a dict or a Dataloader ' f'instance, but got {type(dataloader)}')