BaseTTAModel
- class mmengine.model.BaseTTAModel(module, data_preprocessor=None)
Base model for inference with test-time augmentation.

BaseTTAModel is a wrapper for inference on multi-batch data. It implements test_step() for multi-batch data inference. Here, multi-batch data means data produced by applying different augmentations to the same batch. During test-time augmentation, data processed by mmcv.transforms.TestTimeAug and then collated by pseudo_collate has the following format:

    result = dict(
        inputs=[
            [image1_aug1, image2_aug1],
            [image1_aug2, image2_aug2]
        ],
        data_samples=[
            [data_sample1_aug1, data_sample2_aug1],
            [data_sample1_aug2, data_sample2_aug2],
        ]
    )
image{i}_aug{j} denotes the i-th image of the batch, augmented by the j-th augmentation. BaseTTAModel splits the data into one batch per augmentation:

    data1 = dict(
        inputs=[image1_aug1, image2_aug1],
        data_samples=[data_sample1_aug1, data_sample2_aug1]
    )
    data2 = dict(
        inputs=[image1_aug2, image2_aug2],
        data_samples=[data_sample1_aug2, data_sample2_aug2]
    )
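The regrouping from the collated result into data1, data2, ... can be sketched in plain Python. This is a minimal illustration, not mmengine's implementation: split_by_augmentation is a hypothetical helper, and strings stand in for image tensors and data samples.

```python
from typing import Any, Dict, List


def split_by_augmentation(data: Dict[str, List[List[Any]]]) -> List[Dict[str, List[Any]]]:
    """Hypothetical helper: turn the pseudo_collate output into one
    batch dict per augmentation (data1, data2, ...)."""
    # Every value is indexed first by augmentation, then by image, so the
    # number of augmentations is the length of any value.
    num_augs = len(next(iter(data.values())))
    return [{key: value[aug_idx] for key, value in data.items()}
            for aug_idx in range(num_augs)]


result = dict(
    inputs=[['image1_aug1', 'image2_aug1'],
            ['image1_aug2', 'image2_aug2']],
    data_samples=[['data_sample1_aug1', 'data_sample2_aug1'],
                  ['data_sample1_aug2', 'data_sample2_aug2']],
)
data1, data2 = split_by_augmentation(result)
print(data1)
# {'inputs': ['image1_aug1', 'image2_aug1'], 'data_samples': ['data_sample1_aug1', 'data_sample2_aug1']}
```

Because each value is already ordered by augmentation first, the split is just an index into every value; no per-image reshuffling is needed.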
data1 and data2 are each passed to the model, and the results are merged by merge_preds().

Note

merge_preds() is an abstract method; all subclasses must implement it.

Warning

If data_preprocessor is not None, it overwrites the model's data_preprocessor.

- Parameters
module (dict or nn.Module) – The model to test.
data_preprocessor (dict or BaseDataPreprocessor, optional) – If the model does not define data_preprocessor, this value is used as the model's default.
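A minimal sketch of the override described in the Warning, assuming module is passed as a config dict. resolve_module_cfg and the 'MyModel' type are hypothetical; the function only shows how a non-None data_preprocessor would replace the one in the model config, not how mmengine builds the model.

```python
import copy
from typing import Any, Dict, Optional


def resolve_module_cfg(
    module_cfg: Dict[str, Any],
    data_preprocessor: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: return the model config that would be built.

    A non-None data_preprocessor replaces whatever the model config
    defines; otherwise the model's own data_preprocessor (if any) is kept.
    """
    cfg = copy.deepcopy(module_cfg)  # do not mutate the caller's config
    if data_preprocessor is not None:
        cfg['data_preprocessor'] = data_preprocessor
    return cfg


model_cfg = dict(type='MyModel',  # hypothetical registered model
                 data_preprocessor=dict(type='BaseDataPreprocessor'))
print(resolve_module_cfg(model_cfg)['data_preprocessor'])
# {'type': 'BaseDataPreprocessor'}
print(resolve_module_cfg(model_cfg, dict(type='ImgDataPreprocessor'))['data_preprocessor'])
# {'type': 'ImgDataPreprocessor'}
```

Deep-copying the config is a choice of this sketch so the caller's dict stays unchanged; the library may handle the config differently.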
- forward(inputs, data_samples=None, mode='tensor')
BaseTTAModel.forward should not be called.
- Parameters
inputs (torch.Tensor) –
data_samples (Optional[list]) –
mode (str) –
- Return type
Union[Dict[str, torch.Tensor], list]
- abstract merge_preds(data_samples_list)
Merge the predictions of all enhanced (augmented) data into a single prediction.
- Parameters
data_samples_list (EnhancedBatchDataSamples) – List of predictions of all enhanced data.
- Returns
Merged prediction.
- Return type
List[BaseDataElement]
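A subclass only has to supply merge_preds. The sketch below is a stand-alone mock, not mmengine code: TTAModelSketch, AverageScoreTTA and ToyModel are hypothetical, plain dicts stand in for BaseDataElement, and it assumes data_samples_list is indexed first by sample and then by augmentation. It runs the wrapped model once per augmentation and averages a per-sample score.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

DataSample = Dict[str, Any]  # stand-in for BaseDataElement


class TTAModelSketch(ABC):
    """Hypothetical mock of the BaseTTAModel flow described above."""

    def __init__(self, module):
        self.module = module

    def test_step(self, data: Dict[str, List[List[Any]]]) -> List[DataSample]:
        num_augs = len(next(iter(data.values())))
        # Run the wrapped model once per augmentation (data1, data2, ...).
        predictions = [
            self.module.test_step({k: v[i] for k, v in data.items()})
            for i in range(num_augs)
        ]
        # Transpose from [augmentation][sample] to [sample][augmentation].
        return self.merge_preds([list(p) for p in zip(*predictions)])

    @abstractmethod
    def merge_preds(self, data_samples_list: List[List[DataSample]]) -> List[DataSample]:
        ...


class AverageScoreTTA(TTAModelSketch):
    def merge_preds(self, data_samples_list):
        merged = []
        for aug_preds in data_samples_list:  # all augmentations of one sample
            scores = [pred['score'] for pred in aug_preds]
            merged.append({'score': sum(scores) / len(scores)})
        return merged


class ToyModel:
    """Hypothetical model: 'predicts' a score equal to its numeric input."""

    def test_step(self, data):
        return [{'score': float(x)} for x in data['inputs']]


tta = AverageScoreTTA(ToyModel())
batch = dict(inputs=[[1, 3], [3, 5]],  # [augmentation][image]
             data_samples=[[None, None], [None, None]])
print(tta.test_step(batch))  # [{'score': 2.0}, {'score': 4.0}]
```

Averaging is only one possible merge strategy; a real subclass might vote over labels or fuse boxes instead, depending on the task.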