mmengine.runner.autocast
- mmengine.runner.autocast(device_type=None, dtype=None, enabled=True, cache_enabled=None)[source]
A wrapper of torch.autocast and torch.cuda.amp.autocast.

PyTorch 1.5.0 provides torch.cuda.amp.autocast for running in mixed precision, and PyTorch 1.10.0 updates it to torch.autocast. The two interfaces take different arguments, and torch.autocast additionally supports running on CPU.

This function provides a unified interface by wrapping torch.autocast and torch.cuda.amp.autocast, which resolves the compatibility issues that torch.cuda.amp.autocast does not support mixed precision on CPU and that the two contexts take different arguments. We suggest using this function in your code to achieve maximum compatibility across different PyTorch versions.

Note

autocast requires PyTorch >= 1.5.0. If the PyTorch version is lower than 1.10.0 and CUDA is not available, it raises an error when enabled=True, since torch.cuda.amp.autocast only supports CUDA mode.

Example
>>> # case1: 1.10 > Pytorch version >= 1.5.0
>>> with autocast():
>>>     # run in mixed precision context
>>>     pass
>>> with autocast(device_type='cpu'):
>>>     # raise error, torch.cuda.amp.autocast only supports cuda mode.
>>>     pass
>>> # case2: Pytorch version >= 1.10.0
>>> with autocast():
>>>     # default cuda mixed precision context
>>>     pass
>>> with autocast(device_type='cpu'):
>>>     # cpu mixed precision context
>>>     pass
>>> with autocast(
>>>         device_type='cuda', enabled=True, cache_enabled=True):
>>>     # enable precision context with more specific arguments.
>>>     pass
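The context manager also composes with the standard PyTorch gradient-scaling workflow. The following is a minimal sketch of one mixed-precision training step that pairs this wrapper with torch.cuda.amp.GradScaler; the model, optimizer, and tensors are placeholders chosen for illustration, not part of mmengine.

>>> import torch
>>> import torch.nn as nn
>>> from mmengine.runner import autocast
>>> # placeholder model, optimizer and data, only for illustration
>>> model = nn.Linear(4, 2).cuda()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> scaler = torch.cuda.amp.GradScaler()
>>> inputs = torch.randn(8, 4).cuda()
>>> targets = torch.randn(8, 2).cuda()
>>> optimizer.zero_grad()
>>> with autocast():
>>>     # forward pass runs in mixed precision
>>>     loss = nn.functional.mse_loss(model(inputs), targets)
>>> # scale the loss before backward, then step the optimizer and update the scaler
>>> scaler.scale(loss).backward()
>>> scaler.step(optimizer)
>>> scaler.update()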
- Parameters:
device_type (str, required) – Device type to run on, either ‘cuda’ or ‘cpu’.
enabled (bool) – Whether autocasting should be enabled in the region. Defaults to True.
dtype (torch.dtype, optional) – Whether to use torch.float16 or torch.bfloat16.
cache_enabled (bool, optional) – Whether the weight cache inside autocast should be enabled.
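As a rough illustration of dtype and cache_enabled, the sketch below requests torch.bfloat16 on CPU and torch.float16 with the weight cache enabled on CUDA; which dtypes are accepted depends on the PyTorch version and the device, so treat this as an assumption rather than guaranteed behavior.

>>> import torch
>>> from mmengine.runner import autocast
>>> x = torch.randn(8, 4)
>>> # bfloat16 autocast on cpu, requires torch.autocast (PyTorch >= 1.10.0)
>>> with autocast(device_type='cpu', dtype=torch.bfloat16):
>>>     y = x @ x.t()
>>> # float16 autocast on cuda with the weight cache enabled (needs a GPU)
>>> with autocast(device_type='cuda', dtype=torch.float16, cache_enabled=True):
>>>     z = x.cuda() @ x.cuda().t()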