BaseTTAModel
- class mmengine.model.BaseTTAModel(module, data_preprocessor=None)
Base model for inference with test-time augmentation.
BaseTTAModel is a wrapper for inference on multi-batch data. It implements test_step() for multi-batch data inference. Here, multi-batch data means data obtained by applying different augmentations to the same batch. During test-time augmentation, data processed by mmcv.transforms.TestTimeAug and then collated by pseudo_collate has the following format:

    result = dict(
        inputs=[
            [image1_aug1, image2_aug1],
            [image1_aug2, image2_aug2]
        ],
        data_samples=[
            [data_sample1_aug1, data_sample2_aug1],
            [data_sample1_aug2, data_sample2_aug2],
        ]
    )
image{i}_aug{j} denotes the i-th image of the batch, augmented by the j-th augmentation. BaseTTAModel regroups the data into:

    data1 = dict(
        inputs=[image1_aug1, image2_aug1],
        data_samples=[data_sample1_aug1, data_sample2_aug1]
    )
    data2 = dict(
        inputs=[image1_aug2, image2_aug2],
        data_samples=[data_sample1_aug2, data_sample2_aug2]
    )
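Both layouts can be reproduced with a short pure-Python sketch. The helpers collate_tta and split_by_augmentation are hypothetical: collate_tta only mimics the list-transposing step assumed for pseudo_collate, split_by_augmentation only illustrates the indexing described above, and neither is MMEngine code. Strings stand in for images and data samples.

```python
from typing import Any, Dict, List


def collate_tta(samples: List[Dict[str, List[Any]]]) -> Dict[str, List[List[Any]]]:
    """Hypothetical helper: turn per-sample TTA outputs (one entry per
    augmentation) into the augmentation-major layout of `result`."""
    collated: Dict[str, List[List[Any]]] = {}
    for key in samples[0]:
        # per_sample[i][j] is sample i under augmentation j.
        per_sample = [sample[key] for sample in samples]
        # Transposing gives collated[key][j][i].
        collated[key] = [list(group) for group in zip(*per_sample)]
    return collated


def split_by_augmentation(data: Dict[str, List[List[Any]]]) -> List[Dict[str, List[Any]]]:
    """Hypothetical helper: split the augmentation-major layout into one
    ordinary batch per augmentation (data1, data2, ...)."""
    # Every field holds one entry per augmentation; any key gives the count.
    num_augs = len(next(iter(data.values())))
    return [{key: value[j] for key, value in data.items()} for j in range(num_augs)]


# Two images, each expanded by two augmentations.
samples = [
    dict(inputs=['image1_aug1', 'image1_aug2'],
         data_samples=['data_sample1_aug1', 'data_sample1_aug2']),
    dict(inputs=['image2_aug1', 'image2_aug2'],
         data_samples=['data_sample2_aug1', 'data_sample2_aug2']),
]
result = collate_tta(samples)
data1, data2 = split_by_augmentation(result)
```

Each resulting batch has the same structure as an ordinary batch without TTA, so the wrapped model can handle it like any other input.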
data1 and data2 are each passed to the model, and the results are merged by merge_preds().

Note

merge_preds() is an abstract method; all subclasses must implement it.

Warning
If data_preprocessor is not None, it overwrites the model's data_preprocessor.
- Parameters:
module (dict or nn.Module) – The model to be tested.
data_preprocessor (dict or BaseDataPreprocessor, optional) – If the model does not define data_preprocessor, this value is used as the model's default.
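For a model given as a dict config, the data_preprocessor rule can be sketched as follows. resolve_module_cfg is a hypothetical helper and 'MyModel', 'PreprocA' and 'PreprocB' are placeholder type names; the sketch follows the Warning above (a non-None data_preprocessor replaces the model's own, or supplies one if the model has none) and does not model how an already-built nn.Module is handled.

```python
import copy
from typing import Any, Dict, Optional


def resolve_module_cfg(module_cfg: Dict[str, Any],
                       data_preprocessor: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: apply the documented data_preprocessor rule
    to a model config dict before the model is built."""
    cfg = copy.deepcopy(module_cfg)  # do not mutate the caller's config
    if data_preprocessor is not None:
        # A non-None value overwrites (or supplies) the model's preprocessor.
        cfg['data_preprocessor'] = data_preprocessor
    return cfg


model_cfg = dict(type='MyModel', data_preprocessor=dict(type='PreprocA'))
kept = resolve_module_cfg(model_cfg)  # None: keep PreprocA
replaced = resolve_module_cfg(model_cfg, dict(type='PreprocB'))  # overwrite
filled = resolve_module_cfg(dict(type='MyModel'), dict(type='PreprocB'))  # supply a default
```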
- abstract merge_preds(data_samples_list)
Merge the predictions of all augmented (enhanced) data into one prediction.
- Parameters:
data_samples_list (EnhancedBatchDataSamples) – List of predictions of all enhanced data.
- Returns:
Merged prediction.
- Return type:
List[BaseDataElement]
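The contract of merge_preds() can be illustrated with a self-contained sketch that does not depend on MMEngine. ToyTTAModel, predict_regrouped, MeanScoreTTA and fake_module are all hypothetical; plain dicts with a 'scores' field stand in for BaseDataElement, and the sketch assumes, for illustration, that the per-augmentation predictions are transposed so that each entry of data_samples_list holds one sample's predictions across all augmentations. Averaging class scores and taking the arg-max is only one possible merge strategy, not one prescribed by MMEngine.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Sequence

Batch = Dict[str, List[Any]]


class ToyTTAModel(ABC):
    """Hypothetical, simplified stand-in for BaseTTAModel (not the MMEngine class)."""

    def __init__(self, module: Callable[[Batch], List[Any]]):
        self.module = module

    def predict_regrouped(self, data_list: Sequence[Batch]) -> List[Any]:
        # One forward pass per augmented copy of the batch (data1, data2, ...).
        predictions = [self.module(data) for data in data_list]
        # predictions[j][i] is sample i under augmentation j; zip(*...) transposes
        # it so that entry i collects every augmentation of sample i.
        per_sample = [list(preds) for preds in zip(*predictions)]
        return self.merge_preds(per_sample)

    @abstractmethod
    def merge_preds(self, data_samples_list: List[List[Any]]) -> List[Any]:
        """Merge each sample's augmented predictions into one prediction."""


class MeanScoreTTA(ToyTTAModel):
    def merge_preds(self, data_samples_list):
        merged = []
        for aug_preds in data_samples_list:
            # Average the class scores over augmentations, then take the arg-max.
            num_classes = len(aug_preds[0]['scores'])
            mean_scores = [sum(p['scores'][c] for p in aug_preds) / len(aug_preds)
                           for c in range(num_classes)]
            label = max(range(num_classes), key=lambda c: mean_scores[c])
            merged.append(dict(scores=mean_scores, label=label))
        return merged


# Fake model: looks up fixed class scores for each input id.
SCORES = {
    'image1_aug1': [0.25, 0.75], 'image2_aug1': [1.0, 0.0],
    'image1_aug2': [0.5, 0.5], 'image2_aug2': [0.5, 0.5],
}


def fake_module(data: Batch) -> List[Dict[str, List[float]]]:
    return [dict(scores=SCORES[x]) for x in data['inputs']]


tta = MeanScoreTTA(fake_module)
merged = tta.predict_regrouped([
    dict(inputs=['image1_aug1', 'image2_aug1']),
    dict(inputs=['image1_aug2', 'image2_aug2']),
])
```

In this sketch, abc.abstractmethod prevents ToyTTAModel itself from being instantiated, and the merged output holds one prediction per sample, in the spirit of the List[BaseDataElement] return type.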