
BaseTTAModel

class mmengine.model.BaseTTAModel(module)

Base model for inference with test-time augmentation.

BaseTTAModel is a wrapper for inference on multi-batch data: it implements test_step() for multi-batch inference. Here, multi-batch data means several copies of the same batch, each processed by a different augmentation.

During test-time augmentation, data processed by mmcv.transforms.TestTimeAug and then collated by pseudo_collate has the following format:

result = dict(
    inputs=[
        [image1_aug1, image2_aug1],
        [image1_aug2, image2_aug2]
    ],
    data_samples=[
        [data_sample1_aug1, data_sample2_aug1],
        [data_sample1_aug2, data_sample2_aug2],
    ]
)

image{i}_aug{j} denotes the i-th image of the batch after the j-th augmentation, and data_sample{i}_aug{j} is its corresponding data sample.

BaseTTAModel then regroups the data by augmentation into:

data1 = dict(
    inputs=[image1_aug1, image2_aug1],
    data_samples=[data_sample1_aug1, data_sample2_aug1]
)

data2 = dict(
    inputs=[image1_aug2, image2_aug2],
    data_samples=[data_sample1_aug2, data_sample2_aug2]
)

data1 and data2 are each passed to the wrapped model, and the resulting predictions are merged by merge_preds().
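The regrouping step above amounts to taking index j from every key to build the batch for the j-th augmentation. A minimal pure-Python sketch, with strings standing in for images and data samples; split_by_augmentation is a hypothetical helper name, not part of the mmengine API:

```python
from typing import Any, Dict, List


def split_by_augmentation(data: Dict[str, List[List[Any]]]) -> List[Dict[str, List[Any]]]:
    """Regroup pseudo-collated TTA data into one dict per augmentation.

    data[key][j] holds the whole batch under the j-th augmentation, so the
    j-th output dict takes index j from every key.
    """
    # All keys share the same number of augmentations; read it from the first key.
    num_augs = len(next(iter(data.values())))
    return [{key: value[j] for key, value in data.items()} for j in range(num_augs)]


result = dict(
    inputs=[['image1_aug1', 'image2_aug1'], ['image1_aug2', 'image2_aug2']],
    data_samples=[['data_sample1_aug1', 'data_sample2_aug1'],
                  ['data_sample1_aug2', 'data_sample2_aug2']],
)
data1, data2 = split_by_augmentation(result)
print(data1)
# {'inputs': ['image1_aug1', 'image2_aug1'], 'data_samples': ['data_sample1_aug1', 'data_sample2_aug1']}
```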

Note

merge_preds() is an abstract method; all subclasses must implement it.

Parameters

module (dict or nn.Module) – The model to be tested, given either as an nn.Module instance or as a config dict from which the model is built.
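Because merge_preds() is abstract, a subclass must override it before it can be instantiated. The pure-Python sketch below mimics that contract with abc.ABC; ToyTTAModel and FirstAugTTAModel are hypothetical names, not mmengine classes, and the input layout assumed for merge_preds (one inner list per sample, holding that sample's prediction under each augmentation) is an assumption of this sketch.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Sequence


class ToyTTAModel(ABC):
    """Hypothetical stand-in for BaseTTAModel that only models its abstract contract."""

    def __init__(self, module: Callable[[Any], List[Any]]):
        # The wrapped model; stored but not exercised in this sketch.
        self.module = module

    @abstractmethod
    def merge_preds(self, data_samples_list: Sequence[Sequence[Any]]) -> List[Any]:
        """Combine the per-augmentation predictions into one per sample."""


class FirstAugTTAModel(ToyTTAModel):
    """Trivial strategy: keep each sample's prediction from the first augmentation."""

    def merge_preds(self, data_samples_list):
        return [preds[0] for preds in data_samples_list]


try:
    ToyTTAModel(module=lambda batch: batch)  # abstract class: raises TypeError
except TypeError as err:
    print('cannot instantiate ToyTTAModel:', err)

model = FirstAugTTAModel(module=lambda batch: batch)
print(model.merge_preds([['a1', 'a2'], ['b1', 'b2']]))  # ['a1', 'b1']
```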

abstract merge_preds(data_samples_list)

Merge the predictions of all augmented data into one prediction.

Parameters

data_samples_list (EnhancedBatchDataSamples) – List of predictions for all augmented data.

Returns

Merged prediction.

Return type

List[BaseDataElement]
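One common strategy a subclass might use in merge_preds() is to average class scores over augmentations. The sketch below assumes data_samples_list holds one inner list per sample with that sample's predictions under each augmentation, and uses plain dicts with hypothetical 'scores' and 'label' keys in place of BaseDataElement; a real subclass would read and write whatever fields its task's data samples define.

```python
from typing import Dict, List, Sequence


def average_merge(data_samples_list: Sequence[Sequence[Dict[str, List[float]]]]) -> List[Dict[str, object]]:
    """Merge TTA predictions by averaging class scores over augmentations.

    Each inner sequence holds one sample's predictions, one dict per augmentation.
    """
    merged = []
    for preds in data_samples_list:
        num_augs = len(preds)
        num_classes = len(preds[0]['scores'])
        # Mean score of each class across all augmented views of this sample.
        mean_scores = [sum(p['scores'][c] for p in preds) / num_augs for c in range(num_classes)]
        # Predicted label is the class with the highest mean score.
        label = max(range(num_classes), key=mean_scores.__getitem__)
        merged.append({'scores': mean_scores, 'label': label})
    return merged


data_samples_list = [
    [{'scores': [0.25, 0.75]}, {'scores': [0.5, 0.5]}],  # sample 1, two augmentations
    [{'scores': [1.0, 0.0]}, {'scores': [0.5, 0.5]}],    # sample 2, two augmentations
]
print(average_merge(data_samples_list))
# [{'scores': [0.375, 0.625], 'label': 1}, {'scores': [0.75, 0.25], 'label': 0}]
```

Score averaging suits classification-style outputs; other tasks may need a different merging rule, so the choice is left to each subclass.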

test_step(data)

Get the predictions for each augmented copy of the data and merge them into a single prediction.

Parameters

data (DataBatch) – Augmented data batch sampled from the dataloader.

Returns

Merged prediction.

Return type

MergedDataSamples
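The flow described for test_step() (split the batch per augmentation, run the wrapped model on each part, merge the results) can be sketched in plain Python as follows. tta_test_step and its callable arguments are hypothetical stand-ins for BaseTTAModel.test_step(), the wrapped model's test_step() and merge_preds(); the transpose with zip, which hands the merge step one group of predictions per sample, reflects an assumed reading of the behavior rather than a documented guarantee.

```python
from typing import Any, Callable, Dict, List, Sequence, Union

Batch = Union[Dict[str, List[Any]], Sequence[List[Any]]]


def tta_test_step(data: Batch,
                  model_test_step: Callable[[Any], List[Any]],
                  merge_preds: Callable[[List[tuple]], List[Any]]) -> List[Any]:
    """Hypothetical re-creation of the TTA test-step flow."""
    # 1. Split the batch into one batch per augmentation.
    if isinstance(data, dict):
        num_augs = len(next(iter(data.values())))
        data_list = [{k: v[j] for k, v in data.items()} for j in range(num_augs)]
    elif isinstance(data, (list, tuple)):
        num_augs = len(data[0])
        data_list = [[field[j] for field in data] for j in range(num_augs)]
    else:
        raise TypeError(f'expected dict, list or tuple, got {type(data)}')

    # 2. Run the wrapped model: predictions[j][i] is sample i under augmentation j.
    predictions = [model_test_step(aug_batch) for aug_batch in data_list]
    # 3. zip(*...) regroups into per-sample tuples (one entry per augmentation), then merge.
    return merge_preds(list(zip(*predictions)))


def fake_test_step(batch):
    # Pretend the model "predicts" each input name in upper case.
    return [name.upper() for name in batch['inputs']]


data = dict(inputs=[['img1_aug1', 'img2_aug1'], ['img1_aug2', 'img2_aug2']])
merged = tta_test_step(data, fake_test_step, lambda per_sample: [list(p) for p in per_sample])
print(merged)  # [['IMG1_AUG1', 'IMG1_AUG2'], ['IMG2_AUG1', 'IMG2_AUG2']]
```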

Documentation version: v0.2.0