Shortcuts

mmengine.runner.amp 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Optional

import torch

from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.logging import print_log
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


@contextmanager
def autocast(device_type: Optional[str] = None,
             dtype: Optional[torch.dtype] = None,
             enabled: bool = True,
             cache_enabled: Optional[bool] = None):
    """A wrapper of ``torch.autocast`` and ``torch.cuda.amp.autocast``.

    Pytorch 1.5.0 provides ``torch.cuda.amp.autocast`` for running in mixed
    precision, and updated it to ``torch.autocast`` in 1.10.0. Both
    interfaces take different arguments, and ``torch.autocast`` additionally
    supports running with cpu. This function provides a unified interface by
    wrapping ``torch.autocast`` and ``torch.cuda.amp.autocast``. It resolves
    two compatibility issues: ``torch.cuda.amp.autocast`` does not support
    running mixed precision with cpu, and the two contexts take different
    arguments. We suggest using this function to achieve maximal
    compatibility across PyTorch versions.

    Note:
        ``autocast`` requires pytorch version >= 1.5.0. If the pytorch
        version is < 1.10.0 and none of npu/mlu/cuda is available, it will
        raise an error with ``enabled=True``, since the legacy
        ``torch.cuda.amp.autocast`` only supports cuda mode.

    Examples:
        >>> # case1: 1.10 > Pytorch version >= 1.5.0
        >>> with autocast():
        >>>    # run in mixed precision context
        >>>    pass
        >>> with autocast(device_type='cpu'):
        >>>    # raise error, torch.cuda.amp.autocast only supports cuda mode.
        >>>    pass
        >>> # case2: Pytorch version >= 1.10.0
        >>> with autocast():
        >>>    # default cuda mixed precision context
        >>>    pass
        >>> with autocast(device_type='cpu'):
        >>>    # cpu mixed precision context
        >>>    pass
        >>> with autocast(
        >>>     device_type='cuda', enabled=True, cache_enabled=True):
        >>>    # enable precision context with more specific arguments.
        >>>    pass

    Args:
        device_type (str, optional): Device to run autocast on, e.g.
            'cuda', 'cpu', 'mlu', 'npu' or 'musa'.
            - With torch >= 1.10.0, None means the device returned by
              ``get_device()`` is used.
            - With torch < 1.10.0, only 'cuda', 'mlu' or None are accepted,
              and the backend is chosen by availability (npu, then mlu,
              then cuda).
            Defaults to None.
        dtype (torch.dtype, optional): Whether to use ``torch.float16`` or
            ``torch.bfloat16``. Ignored (with a warning) when torch < 1.10.0.
            Defaults to None.
        enabled (bool): Whether autocasting should be enabled in the region.
            Defaults to True.
        cache_enabled (bool, optional): Whether the weight cache inside
            autocast should be enabled. Ignored (with a warning) when
            torch < 1.10.0. If None, ``torch.is_autocast_cache_enabled()`` is
            used. Defaults to None.

    Raises:
        AssertionError:
            - The torch version is < 1.5.0.
            - torch < 1.10.0 is given a device_type other than
              'cuda'/'mlu'/None.
            - cpu autocast is requested with a dtype other than
              ``torch.bfloat16``.
        RuntimeError:
            - torch < 1.10.0 with ``enabled=True`` and no npu/mlu/cuda
              device available.
            - bfloat16 is requested on a CUDA device that does not support
              it.
        ValueError: An unsupported ``device_type`` is given with
            ``enabled=True``.
    """
    # If `enabled` is False, an empty (disabled) autocast context is entered
    # and all calculations are performed under fp32.
    assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
        'The minimum pytorch version requirements of mmengine is 1.5.0, but '
        f'got {TORCH_VERSION}')

    if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
            digit_version('1.10.0')):
        # If pytorch version is between 1.5.0 and 1.10.0, the default value of
        # dtype for `torch.cuda.amp.autocast` is torch.float16.
        assert (
            device_type == 'cuda' or device_type == 'mlu'
            or device_type is None), (
                'Pytorch version under 1.10.0 only supports running automatic '
                'mixed training with cuda or mlu')
        if dtype is not None or cache_enabled is not None:
            # The legacy per-backend contexts below only accept `enabled`, so
            # report exactly the arguments that are being dropped.
            print_log(
                f'dtype ({dtype}) and cache_enabled ({cache_enabled}) will '
                'not work for `autocast` since your Pytorch version: '
                f'{TORCH_VERSION} < 1.10.0',
                logger='current',
                level=logging.WARNING)

        # The backend is chosen by availability, not by `device_type`.
        if is_npu_available():
            with torch.npu.amp.autocast(enabled=enabled):
                yield
        elif is_mlu_available():
            with torch.mlu.amp.autocast(enabled=enabled):
                yield
        elif is_cuda_available():
            with torch.cuda.amp.autocast(enabled=enabled):
                yield
        else:
            if not enabled:
                yield
            else:
                raise RuntimeError(
                    'If pytorch versions is between 1.5.0 and 1.10, '
                    '`autocast` is only available in gpu mode')
    else:
        # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py  # noqa: E501
        # This code should update with the `torch.autocast`.
        if cache_enabled is None:
            cache_enabled = torch.is_autocast_cache_enabled()
        # Only query the current device when the caller did not choose one.
        if device_type is None:
            device_type = get_device()

        if device_type == 'cuda':
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()

            if dtype == torch.bfloat16 and not \
                    torch.cuda.is_bf16_supported():
                raise RuntimeError(
                    'Current CUDA Device does not support bfloat16. Please '
                    'switch dtype to float16.')

        elif device_type == 'cpu':
            if dtype is None:
                dtype = torch.bfloat16
            assert dtype == torch.bfloat16, (
                'In CPU autocast, only support `torch.bfloat16` dtype')

        elif device_type in ('mlu', 'npu'):
            # Passed through to `torch.autocast` with dtype/cache_enabled as
            # given by the caller.
            pass
        elif device_type == 'musa':
            # MUSA uses its own autocast entry point instead of
            # `torch.autocast`.
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()
            with torch.musa.amp.autocast(
                    enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
                yield
            return
        else:
            # Device like MPS does not support fp16 training or testing.
            # If an inappropriate device is set and fp16 is enabled, an error
            # will be thrown.
            if not enabled:
                yield
                return
            raise ValueError(
                'User specified autocast device_type must be one of cuda, '
                f'cpu, mlu, npu or musa, but got {device_type}')

        with torch.autocast(
                device_type=device_type,
                enabled=enabled,
                dtype=dtype,
                cache_enabled=cache_enabled):
            yield

© Copyright 2022, mmengine contributors. Revision 66fb81f7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.10.3
v0.10.2
v0.10.1
v0.10.0
v0.9.1
v0.9.0
v0.8.5
v0.8.4
v0.8.3
v0.8.2
v0.8.1
v0.8.0
v0.7.4
v0.7.3
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.