BaseDataPreprocessor
- class mmengine.model.BaseDataPreprocessor(non_blocking=False)
Base data pre-processor used for copying data to the target device.
Subclasses that inherit from BaseDataPreprocessor can override the forward method to implement custom data pre-processing, such as batch resizing, MixUp, or CutMix.
- Parameters:
non_blocking (bool) – Whether to transfer data to the device without blocking the current process. Defaults to False. New in version 0.3.0.
Note
The data returned by the dataloader must be a dict that contains at least the inputs key.
- cast_data(data)
Copies data to the target device.
- Parameters:
data (dict) – Data returned by the DataLoader.
- Returns:
Inputs and data samples on the target device.
- Return type:
CollatedResult
- cpu(*args, **kwargs)
Overrides this method to also set the device attribute to the CPU.
- Returns:
The module itself.
- Return type:
nn.Module
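A plausible reason for overriding cpu(), cuda(), and the other device methods is that a pre-processor often has no parameters or buffers whose device could be queried, so it has to remember where it was moved. The hypothetical DeviceTrackingModule below is a simplified, plain-PyTorch sketch of that pattern; it is not mmengine's actual source, and its move() helper is illustrative only.

```python
import torch
from torch import nn


class DeviceTrackingModule(nn.Module):
    """Hypothetical module that records the device it was last moved to."""

    def __init__(self):
        super().__init__()
        self._device = torch.device('cpu')

    def cpu(self, *args, **kwargs):
        # Record the new target device, then defer to nn.Module.cpu().
        self._device = torch.device('cpu')
        return super().cpu()

    def cuda(self, *args, **kwargs):
        # Record the current CUDA device, then defer to nn.Module.cuda().
        self._device = torch.device('cuda', torch.cuda.current_device())
        return super().cuda(*args, **kwargs)

    def move(self, tensor):
        # Send a tensor to whichever device the module currently targets.
        return tensor.to(self._device)


module = DeviceTrackingModule()
assert module.cpu() is module  # returns the module itself, like nn.Module
print(module.move(torch.ones(2)).device)  # cpu
```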
- cuda(*args, **kwargs)
Overrides this method to also set the device attribute to a CUDA device.
- Returns:
The module itself.
- Return type:
nn.Module
- forward(data, training=False)
Preprocesses the data into the model input format.
After the data is pre-processed by cast_data(), forward stacks the list of input tensors into a batch tensor along the first dimension.
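The stacking step described for forward can be shown with torch.stack alone. This sketch assumes each sample is a (C, H, W) tensor and that all samples already share the same shape, which torch.stack requires.

```python
import torch

# A batch as it might arrive after cast_data(): a list of per-sample tensors.
inputs = [torch.rand(3, 4, 5) for _ in range(8)]

# Stacking along a new first dimension yields an (N, C, H, W) batch tensor.
batch = torch.stack(inputs, dim=0)
print(batch.shape)  # torch.Size([8, 3, 4, 5])

# torch.stack needs identical shapes, so variable-sized samples must be
# resized or padded beforehand.
try:
    torch.stack([torch.rand(3, 4, 5), torch.rand(3, 6, 5)], dim=0)
except RuntimeError as err:
    print('cannot stack:', err)
```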
- mlu(*args, **kwargs)
Overrides this method to also set the device attribute to an MLU device.
- Returns:
The module itself.
- Return type:
nn.Module
- musa(*args, **kwargs)
Overrides this method to also set the device attribute to a MUSA device.
- Returns:
The module itself.
- Return type:
nn.Module