Shortcuts

mmengine.utils.misc 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import functools
import itertools
import logging
import re
import subprocess
import textwrap
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec, ismodule
from itertools import repeat
from typing import Any, Callable, Optional, Type, Union


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


[文档]def is_str(x): """Whether the input is an string instance. Note: This method is deprecated since python 2 is no longer supported. """ return isinstance(x, str)
[文档]def import_modules_from_strings(imports, allow_failed_imports=False): """Import modules from the given list of strings. Args: imports (list | str | None): The given module names to be imported. allow_failed_imports (bool): If True, the failed imports will return None. Otherwise, an ImportError is raise. Defaults to False. Returns: list[module] | module | None: The imported modules. Examples: >>> osp, sys = import_modules_from_strings( ... ['os.path', 'sys']) >>> import os.path as osp_ >>> import sys as sys_ >>> assert osp == osp_ >>> assert sys == sys_ """ if not imports: return single_import = False if isinstance(imports, str): single_import = True imports = [imports] if not isinstance(imports, list): raise TypeError( f'custom_imports must be a list but got type {type(imports)}') imported = [] for imp in imports: if not isinstance(imp, str): raise TypeError( f'{imp} is of type {type(imp)} and cannot be imported.') try: imported_tmp = import_module(imp) except ImportError: if allow_failed_imports: warnings.warn(f'{imp} failed to import and is ignored.', UserWarning) imported_tmp = None else: raise ImportError(f'Failed to import {imp}') imported.append(imported_tmp) if single_import: imported = imported[0] return imported
[文档]def iter_cast(inputs, dst_type, return_type=None): """Cast elements of an iterable object into some type. Args: inputs (Iterable): The input object. dst_type (type): Destination type. return_type (type, optional): If specified, the output object will be converted to this type, otherwise an iterator. Returns: iterator or specified type: The converted object. """ if not isinstance(inputs, abc.Iterable): raise TypeError('inputs must be an iterable object') if not isinstance(dst_type, type): raise TypeError('"dst_type" must be a valid type') out_iterable = map(dst_type, inputs) if return_type is None: return out_iterable else: return return_type(out_iterable)
[文档]def list_cast(inputs, dst_type): """Cast elements of an iterable object into a list of some type. A partial method of :func:`iter_cast`. """ return iter_cast(inputs, dst_type, return_type=list)
[文档]def tuple_cast(inputs, dst_type): """Cast elements of an iterable object into a tuple of some type. A partial method of :func:`iter_cast`. """ return iter_cast(inputs, dst_type, return_type=tuple)
[文档]def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None) -> bool: """Check whether it is a sequence of some type. Args: seq (Sequence): The sequence to be checked. expected_type (type or tuple): Expected type of sequence items. seq_type (type, optional): Expected sequence type. Defaults to None. Returns: bool: Return True if ``seq`` is valid else False. Examples: >>> from mmengine.utils import is_seq_of >>> seq = ['a', 'b', 'c'] >>> is_seq_of(seq, str) True >>> is_seq_of(seq, int) False """ if seq_type is None: exp_seq_type = abc.Sequence else: assert isinstance(seq_type, type) exp_seq_type = seq_type if not isinstance(seq, exp_seq_type): return False for item in seq: if not isinstance(item, expected_type): return False return True
[文档]def is_list_of(seq, expected_type): """Check whether it is a list of some type. A partial method of :func:`is_seq_of`. """ return is_seq_of(seq, expected_type, seq_type=list)
[文档]def is_tuple_of(seq, expected_type): """Check whether it is a tuple of some type. A partial method of :func:`is_seq_of`. """ return is_seq_of(seq, expected_type, seq_type=tuple)
[文档]def slice_list(in_list, lens): """Slice a list into several sub lists by a list of given length. Args: in_list (list): The list to be sliced. lens(int or list): The expected length of each out list. Returns: list: A list of sliced list. """ if isinstance(lens, int): assert len(in_list) % lens == 0 lens = [lens] * int(len(in_list) / lens) if not isinstance(lens, list): raise TypeError('"indices" must be an integer or a list of integers') elif sum(lens) != len(in_list): raise ValueError('sum of lens and list length does not ' f'match: {sum(lens)} != {len(in_list)}') out_list = [] idx = 0 for i in range(len(lens)): out_list.append(in_list[idx:idx + lens[i]]) idx += lens[i] return out_list
[文档]def concat_list(in_list): """Concatenate a list of list into a single list. Args: in_list (list): The list of list to be merged. Returns: list: The concatenated flat list. """ return list(itertools.chain(*in_list))
[文档]def apply_to(data: Any, expr: Callable, apply_func: Callable): """Apply function to each element in dict, list or tuple that matches with the expression. For examples, if you want to convert each element in a list of dict from `np.ndarray` to `Tensor`. You can use the following code: Examples: >>> from mmengine.utils import apply_to >>> import numpy as np >>> import torch >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]} >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x)) >>> print(result) # {'array': [tensor(1)]} Args: data (Any): Data to be applied. expr (Callable): Expression to tell which data should be applied with the function. It should return a boolean. apply_func (Callable): Function applied to data. Returns: Any: The data after applying. """ # noqa: E501 if isinstance(data, dict): # Keep the original dict type res = type(data)() for key, value in data.items(): res[key] = apply_to(value, expr, apply_func) return res elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, (tuple, list)): return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # noqa: E501 # yapf:disable elif expr(data): return apply_func(data) else: return data
[文档]def check_prerequisites( prerequisites, checker, msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' 'found, please install them first.'): # yapf: disable """A decorator factory to check if prerequisites are satisfied. Args: prerequisites (str of list[str]): Prerequisites to be checked. checker (callable): The checker method that returns True if a prerequisite is meet, False otherwise. msg_tmpl (str): The message template with two variables. Returns: decorator: A specific decorator. """ def wrap(func): @functools.wraps(func) def wrapped_func(*args, **kwargs): requirements = [prerequisites] if isinstance( prerequisites, str) else prerequisites missing = [] for item in requirements: if not checker(item): missing.append(item) if missing: print(msg_tmpl.format(', '.join(missing), func.__name__)) raise RuntimeError('Prerequisites not meet.') else: return func(*args, **kwargs) return wrapped_func return wrap
def _check_py_package(package): try: import_module(package) except ImportError: return False else: return True def _check_executable(cmd): if subprocess.call(f'which {cmd}', shell=True) != 0: return False else: return True
[文档]def requires_package(prerequisites): """A decorator to check if some python packages are installed. Example: >>> @requires_package('numpy') >>> func(arg1, args): >>> return numpy.zeros(1) array([0.]) >>> @requires_package(['numpy', 'non_package']) >>> func(arg1, args): >>> return numpy.zeros(1) ImportError """ return check_prerequisites(prerequisites, checker=_check_py_package)
[文档]def requires_executable(prerequisites): """A decorator to check if some executable files are installed. Example: >>> @requires_executable('ffmpeg') >>> func(arg1, args): >>> print(1) 1 """ return check_prerequisites(prerequisites, checker=_check_executable)
[文档]def deprecated_api_warning(name_dict: dict, cls_name: Optional[str] = None) -> Callable: """A decorator to check if some arguments are deprecate and try to replace deprecate src_arg_name to dst_arg_name. Args: name_dict(dict): key (str): Deprecate argument names. val (str): Expected argument names. Returns: func: New function. """ def api_warning_wrapper(old_func): @functools.wraps(old_func) def new_func(*args, **kwargs): # get the arg spec of the decorated method args_info = getfullargspec(old_func) # get name of the function func_name = old_func.__name__ if cls_name is not None: func_name = f'{cls_name}.{func_name}' if args: arg_names = args_info.args[:len(args)] for src_arg_name, dst_arg_name in name_dict.items(): if src_arg_name in arg_names: warnings.warn( f'"{src_arg_name}" is deprecated in ' f'`{func_name}`, please use "{dst_arg_name}" ' 'instead', DeprecationWarning) arg_names[arg_names.index(src_arg_name)] = dst_arg_name if kwargs: for src_arg_name, dst_arg_name in name_dict.items(): if src_arg_name in kwargs: assert dst_arg_name not in kwargs, ( f'The expected behavior is to replace ' f'the deprecated key `{src_arg_name}` to ' f'new key `{dst_arg_name}`, but got them ' f'in the arguments at the same time, which ' f'is confusing. `{src_arg_name} will be ' f'deprecated in the future, please ' f'use `{dst_arg_name}` instead.') warnings.warn( f'"{src_arg_name}" is deprecated in ' f'`{func_name}`, please use "{dst_arg_name}" ' 'instead', DeprecationWarning) kwargs[dst_arg_name] = kwargs.pop(src_arg_name) # apply converted arguments to the decorated method output = old_func(*args, **kwargs) return output return new_func return api_warning_wrapper
[文档]def is_method_overridden(method: str, base_class: type, derived_class: Union[type, Any]) -> bool: """Check if a method of base class is overridden in derived class. Args: method (str): the method name to check. base_class (type): the class of the base class. derived_class (type | Any): the class or instance of the derived class. """ assert isinstance(base_class, type), \ "base_class doesn't accept instance, Please pass class instead." if not isinstance(derived_class, type): derived_class = derived_class.__class__ base_method = getattr(base_class, method) derived_method = getattr(derived_class, method) return derived_method != base_method
[文档]def has_method(obj: object, method: str) -> bool: """Check whether the object has a method. Args: method (str): The method name to check. obj (object): The object to check. Returns: bool: True if the object has the method else False. """ return hasattr(obj, method) and callable(getattr(obj, method))
[文档]def deprecated_function(since: str, removed_in: str, instructions: str) -> Callable: """Marks functions as deprecated. Throw a warning when a deprecated function is called, and add a note in the docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py Args: since (str): The version when the function was first deprecated. removed_in (str): The version when the function will be removed. instructions (str): The action users should take. Returns: Callable: A new function, which will be deprecated soon. """ # noqa: E501 from mmengine import print_log def decorator(function): @functools.wraps(function) def wrapper(*args, **kwargs): print_log( f"'{function.__module__}.{function.__name__}' " f'is deprecated in version {since} and will be ' f'removed in version {removed_in}. Please {instructions}.', logger='current', level=logging.WARNING, ) return function(*args, **kwargs) indent = ' ' # Add a deprecation note to the docstring. docstring = function.__doc__ or '' # Add a note to the docstring. deprecation_note = textwrap.dedent(f"""\ .. deprecated:: {since} Deprecated and will be removed in version {removed_in}. Please {instructions}. """) # Split docstring at first occurrence of newline pattern = '\n\n' summary_and_body = re.split(pattern, docstring, 1) if len(summary_and_body) > 1: summary, body = summary_and_body body = textwrap.indent(textwrap.dedent(body), indent) summary = '\n'.join( [textwrap.dedent(string) for string in summary.split('\n')]) summary = textwrap.indent(summary, prefix=indent) # Dedent the body. We cannot do this with the presence of the # summary because the body contains leading whitespaces when the # summary does not. new_docstring_parts = [ deprecation_note, '\n\n', summary, '\n\n', body ] else: summary = summary_and_body[0] summary = '\n'.join( [textwrap.dedent(string) for string in summary.split('\n')]) summary = textwrap.indent(summary, prefix=indent) new_docstring_parts = [deprecation_note, '\n\n', summary] wrapper.__doc__ = ''.join(new_docstring_parts) return wrapper return decorator
def get_object_from_string(obj_name: str): """Get object from name. Args: obj_name (str): The name of the object. Examples: >>> get_object_from_string('torch.optim.sgd.SGD') >>> torch.optim.sgd.SGD """ parts = iter(obj_name.split('.')) module_name = next(parts) # import module while True: try: module = import_module(module_name) part = next(parts) # mmcv.ops has nms.py has nms function at the same time. So the # function will have a higher priority obj = getattr(module, part, None) if obj is not None and not ismodule(obj): break module_name = f'{module_name}.{part}' except StopIteration: # if obj is a module return module except ImportError: return None # get class or attribute from module while True: try: obj_cls = getattr(module, part) part = next(parts) except StopIteration: return obj_cls except AttributeError: return None

© Copyright 2022, mmengine contributors. Revision 317d8f31.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.8.0
Versions
latest
stable
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.