
BaseTTAModel

class mmengine.model.BaseTTAModel(module, data_preprocessor=None)

Base model for inference with test-time augmentation.

BaseTTAModel is a wrapper for inference on multi-batch data. It implements test_step() for multi-batch data inference, where multi-batch data means several versions of the same batch, each produced by a different augmentation.

During test-time augmentation, data processed by mmcv.transforms.TestTimeAug and then collated by pseudo_collate has the following format:

result = dict(
    inputs=[
        [image1_aug1, image2_aug1],
        [image1_aug2, image2_aug2]
    ],
    data_samples=[
        [data_sample1_aug1, data_sample2_aug1],
        [data_sample1_aug2, data_sample2_aug2],
    ]
)

image{i}_aug{j} denotes the i-th image of the batch after the j-th augmentation has been applied.

BaseTTAModel then regroups the data into one batch per augmentation:

data1 = dict(
    inputs=[image1_aug1, image2_aug1],
    data_samples=[data_sample1_aug1, data_sample2_aug1]
)

data2 = dict(
    inputs=[image1_aug2, image2_aug2],
    data_samples=[data_sample1_aug2, data_sample2_aug2]
)
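The regrouping step can be sketched in plain Python. This is a hedged illustration rather than mmengine's own code: split_by_augmentation is a hypothetical helper, and strings stand in for the image tensors and data samples. It assumes every field of the collated dict holds one inner list per augmentation, as in the format shown above.

```python
from typing import Any, Dict, List


def split_by_augmentation(batch: Dict[str, List[List[Any]]]) -> List[Dict[str, List[Any]]]:
    """Hypothetical helper: turn {key: [aug1_batch, aug2_batch, ...]} into one dict per augmentation."""
    # Every field holds one inner list per augmentation, so any key gives the count.
    num_augs = len(next(iter(batch.values())))
    # For augmentation idx, pick the idx-th inner list of every field.
    return [
        {key: value[idx] for key, value in batch.items()}
        for idx in range(num_augs)
    ]


result = dict(
    inputs=[
        ['image1_aug1', 'image2_aug1'],
        ['image1_aug2', 'image2_aug2'],
    ],
    data_samples=[
        ['data_sample1_aug1', 'data_sample2_aug1'],
        ['data_sample1_aug2', 'data_sample2_aug2'],
    ],
)

data1, data2 = split_by_augmentation(result)
print(data1)
# {'inputs': ['image1_aug1', 'image2_aug1'], 'data_samples': ['data_sample1_aug1', 'data_sample2_aug1']}
```

Each resulting dict is an ordinary batch, which is why the wrapped model can process it without knowing that test-time augmentation is involved.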

data1 and data2 are each passed to the wrapped model, and the results are merged by merge_preds().

Note

merge_preds() is an abstract method; all subclasses must implement it.
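The abstract-method contract can be illustrated with Python's standard abc module. This is a sketch under stated assumptions: TTAWrapperSketch and KeepFirstTTA are hypothetical names, the code does not import mmengine, and it only mirrors the rule that subclasses must provide merge_preds(). Keeping the first augmentation's prediction is a placeholder strategy chosen for brevity, not a recommended one.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class TTAWrapperSketch(ABC):
    """Hypothetical stand-in for BaseTTAModel's abstract contract."""

    @abstractmethod
    def merge_preds(self, data_samples_list: List[List[Any]]) -> List[Any]:
        """Collapse each sample's augmented predictions into one prediction."""


class KeepFirstTTA(TTAWrapperSketch):
    # Placeholder strategy: keep the prediction from the first augmentation only.
    def merge_preds(self, data_samples_list: List[List[Any]]) -> List[Any]:
        return [preds[0] for preds in data_samples_list]


try:
    TTAWrapperSketch()  # fails: merge_preds has no implementation
except TypeError as err:
    print(f'TypeError: {err}')

print(KeepFirstTTA().merge_preds([['s1_aug1', 's1_aug2'], ['s2_aug1', 's2_aug2']]))
# ['s1_aug1', 's2_aug1']
```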

Warning

If data_preprocessor is not None, it will overwrite the model’s data_preprocessor.

Parameters
  • module (dict or nn.Module) – The model to be tested, given either as an nn.Module or as a config dict used to build it.

  • data_preprocessor (dict or BaseDataPreprocessor, optional) – If the model does not define a data_preprocessor, this value is used as the model's default. Defaults to None.

forward(inputs, data_samples=None, mode='tensor')

BaseTTAModel.forward should not be called.

Return type

Union[Dict[str, torch.Tensor], list]

abstract merge_preds(data_samples_list)

Merge the predictions of all augmented data into a single prediction per sample.

Parameters

data_samples_list (EnhancedBatchDataSamples) – Predictions for all augmented data: for each sample in the batch, the list of its predictions under every augmentation.

Returns

Merged predictions, one per sample.

Return type

List[BaseDataElement]
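One common way to implement merge_preds() is to average per-class scores across augmentations. The sketch below is a hedged, stdlib-only illustration: plain dicts with a hypothetical 'scores' field stand in for BaseDataElement, average_scores is a hypothetical name, and it assumes data_samples_list holds, for each sample, the list of its predictions under every augmentation. Averaging is only one possible strategy; mmengine itself leaves the choice to the subclass.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in for a data sample: a dict holding per-class scores.
Prediction = Dict[str, List[float]]


def average_scores(data_samples_list: Sequence[Sequence[Prediction]]) -> List[Prediction]:
    """Average each sample's score vectors over all of its augmentations."""
    merged = []
    for aug_preds in data_samples_list:  # one entry per sample
        num_augs = len(aug_preds)
        # zip(*...) walks the score vectors position by position (class by class).
        summed = [sum(col) for col in zip(*(p['scores'] for p in aug_preds))]
        merged.append({'scores': [s / num_augs for s in summed]})
    return merged


preds = [
    [{'scores': [0.25, 0.75]}, {'scores': [0.75, 0.25]}],  # sample 1, two augmentations
    [{'scores': [1.0, 0.0]}, {'scores': [0.5, 0.5]}],      # sample 2, two augmentations
]
print(average_scores(preds))
# [{'scores': [0.5, 0.5]}, {'scores': [0.75, 0.25]}]
```

The result has one entry per sample, matching the List[BaseDataElement] return type above.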

test_step(data)

Get the predictions for each augmented version of the data and merge them into the final prediction.

Parameters

data (DataBatch) – Augmented data batch sampled from the dataloader.

Returns

Merged predictions, one per sample.

Return type

MergedDataSamples
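The overall flow of test_step() can be sketched end to end. This is a hedged illustration under assumptions, not mmengine's implementation: FakeModel, tta_test_step and collect are hypothetical names, strings stand in for tensors and data samples, and the sketch assumes the wrapped model's test_step() is called once per augmentation and that its outputs are regrouped per sample (via zip) before merge_preds() is called.

```python
from typing import Any, Callable, Dict, List, Tuple


class FakeModel:
    """Hypothetical wrapped model: 'predicts' by tagging each data sample."""

    def test_step(self, data: Dict[str, List[Any]]) -> List[str]:
        return [f'pred({sample})' for sample in data['data_samples']]


def tta_test_step(module: FakeModel,
                  data: Dict[str, List[List[Any]]],
                  merge_preds: Callable[[List[Tuple[Any, ...]]], List[Any]]) -> List[Any]:
    # 1. Split the collated batch into one ordinary batch per augmentation.
    num_augs = len(next(iter(data.values())))
    data_list = [{k: v[i] for k, v in data.items()} for i in range(num_augs)]
    # 2. Run the wrapped model on each augmented batch separately.
    predictions = [module.test_step(d) for d in data_list]
    # 3. Transpose from per-augmentation to per-sample, then merge.
    return merge_preds(list(zip(*predictions)))


def collect(per_sample: List[Tuple[Any, ...]]) -> List[List[Any]]:
    # Trivial "merge" that keeps every augmented prediction, to expose the grouping.
    return [list(p) for p in per_sample]


batch = dict(
    inputs=[['img1_aug1', 'img2_aug1'], ['img1_aug2', 'img2_aug2']],
    data_samples=[['ds1_aug1', 'ds2_aug1'], ['ds1_aug2', 'ds2_aug2']],
)
merged = tta_test_step(FakeModel(), batch, merge_preds=collect)
print(merged)
# [['pred(ds1_aug1)', 'pred(ds1_aug2)'], ['pred(ds2_aug1)', 'pred(ds2_aug2)']]
```

Swapping collect for a real merging strategy, such as averaging scores, yields one prediction per sample.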

Read the Docs v: v0.4.0