
BaseTTAModel

class mmengine.model.BaseTTAModel(module, data_preprocessor=None)[source]

Base model for inference with test-time augmentation.

BaseTTAModel is a wrapper for inference on multi-batch data. It implements test_step() for multi-batch data inference, where multi-batch data means several copies of the same batch, each processed by a different augmentation.

During test-time augmentation, data processed by mmcv.transforms.TestTimeAug and then collated by pseudo_collate has the following format:

result = dict(
    inputs=[
        [image1_aug1, image2_aug1],
        [image1_aug2, image2_aug2]
    ],
    data_samples=[
        [data_sample1_aug1, data_sample2_aug1],
        [data_sample1_aug2, data_sample2_aug2],
    ]
)
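This nested layout arises because each sample leaves TestTimeAug holding one entry per augmentation, and collation transposes the batch so that the outer index becomes the augmentation. The sketch below reproduces that transposition in plain Python, with strings standing in for images. `transpose_collate` is a hypothetical, simplified stand-in for illustration, not the actual implementation of pseudo_collate.

```python
from typing import Any, Dict, List


def transpose_collate(samples: List[Dict[str, List[Any]]]) -> Dict[str, List[List[Any]]]:
    """Hypothetical simplified collate: for every key, turn a per-sample list of
    augmented views into a per-augmentation list of samples."""
    keys = samples[0].keys()
    # zip(*...) transposes the [sample][aug] nesting into [aug][sample]
    return {key: [list(group) for group in zip(*(s[key] for s in samples))]
            for key in keys}


# Each sample, as produced by TestTimeAug, holds one view per augmentation.
sample1 = dict(inputs=['image1_aug1', 'image1_aug2'],
               data_samples=['data_sample1_aug1', 'data_sample1_aug2'])
sample2 = dict(inputs=['image2_aug1', 'image2_aug2'],
               data_samples=['data_sample2_aug1', 'data_sample2_aug2'])

result = transpose_collate([sample1, sample2])
print(result['inputs'])  # [['image1_aug1', 'image2_aug1'], ['image1_aug2', 'image2_aug2']]
```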

image{i}_aug{j} denotes the i-th image of the batch after the j-th augmentation has been applied.

BaseTTAModel then regroups the data by augmentation:

data1 = dict(
    inputs=[image1_aug1, image2_aug1],
    data_samples=[data_sample1_aug1, data_sample2_aug1]
)

data2 = dict(
    inputs=[image1_aug2, image2_aug2],
    data_samples=[data_sample1_aug2, data_sample2_aug2]
)
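The regrouping step can be sketched in plain Python: for each augmentation index j, take the j-th entry of every key. This is an illustrative sketch of the behavior described above; the helper name `split_by_augmentation` is hypothetical.

```python
from typing import Any, Dict, List


def split_by_augmentation(data: Dict[str, List[List[Any]]]) -> List[Dict[str, List[Any]]]:
    """Hypothetical helper: turn {key: [aug][sample]} into one dict per augmentation."""
    # Every key holds one entry per augmentation, so any key gives the count.
    num_augs = len(next(iter(data.values())))
    return [{key: value[j] for key, value in data.items()} for j in range(num_augs)]


result = dict(
    inputs=[['image1_aug1', 'image2_aug1'], ['image1_aug2', 'image2_aug2']],
    data_samples=[['data_sample1_aug1', 'data_sample2_aug1'],
                  ['data_sample1_aug2', 'data_sample2_aug2']],
)
data1, data2 = split_by_augmentation(result)
print(data1)
# {'inputs': ['image1_aug1', 'image2_aug1'], 'data_samples': ['data_sample1_aug1', 'data_sample2_aug1']}
```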

data1 and data2 are each passed to the wrapped model, and the resulting predictions are merged by merge_preds().

Note

merge_preds() is an abstract method; all subclasses must implement it.
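The abstract-method contract can be illustrated with Python's `abc` module. The classes below are hypothetical stand-ins, not mmengine's own code: the base class cannot be instantiated until a subclass supplies `merge_preds`.

```python
from abc import ABCMeta, abstractmethod
from typing import Any, List


class TTABase(metaclass=ABCMeta):
    """Hypothetical stand-in for a TTA wrapper with an abstract merge step."""

    @abstractmethod
    def merge_preds(self, data_samples_list: List[List[Any]]) -> List[Any]:
        """Merge the augmented predictions of each sample into one prediction."""


class KeepFirstTTA(TTABase):
    """Hypothetical subclass: keeps the prediction of the first augmentation."""

    def merge_preds(self, data_samples_list: List[List[Any]]) -> List[Any]:
        return [preds[0] for preds in data_samples_list]


try:
    TTABase()  # fails: merge_preds has no implementation
except TypeError as err:
    print('cannot instantiate:', err)

print(KeepFirstTTA().merge_preds([['a1', 'a2'], ['b1', 'b2']]))  # ['a1', 'b1']
```

A subclass that forgets merge_preds fails at construction time rather than in the middle of a test loop, which is the practical benefit of declaring it abstract.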

Warning

If data_preprocessor is not None, it will overwrite the model’s data_preprocessor.

Parameters:
  • module (dict or nn.Module) – The model to be tested.

  • data_preprocessor (dict or BaseDataPreprocessor, optional) – Used as the model's default data_preprocessor if the model does not define one.

forward(inputs, data_samples=None, mode='tensor')[source]

BaseTTAModel.forward should not be called.

Parameters:
  • inputs (Tensor) –

  • data_samples (list | None) –

  • mode (str) –

Return type:

Dict[str, Tensor] | list

abstract merge_preds(data_samples_list)[source]

Merge the predictions of all augmented copies of the data into a single prediction.

Parameters:

data_samples_list (EnhancedBatchDataSamples) – List of predictions for all augmented data.

Returns:

Merged prediction.

Return type:

List[BaseDataElement]
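As an illustration of one possible merge strategy, the sketch below assumes `data_samples_list` holds one entry per sample, each entry containing that sample's predictions under every augmentation, and averages class scores across augmentations. Plain dicts with a hypothetical `scores` key stand in for BaseDataElement; averaging is only one choice, and voting or taking the maximum are equally valid.

```python
from typing import Dict, List


def average_merge(
        data_samples_list: List[List[Dict[str, List[float]]]]) -> List[Dict[str, List[float]]]:
    """Hypothetical merge: average the 'scores' of each sample over its augmentations."""
    merged = []
    for aug_preds in data_samples_list:  # one list of augmented predictions per sample
        num_augs = len(aug_preds)
        # Sum each class score across augmentations, then divide by the count.
        summed = [sum(vals) for vals in zip(*(p['scores'] for p in aug_preds))]
        merged.append({'scores': [v / num_augs for v in summed]})
    return merged


preds = [
    [{'scores': [0.2, 0.8]}, {'scores': [0.4, 0.6]}],  # sample 1, two augmentations
    [{'scores': [0.9, 0.1]}, {'scores': [0.7, 0.3]}],  # sample 2, two augmentations
]
print(average_merge(preds))
```

The output has one merged prediction per sample, matching the List[BaseDataElement] return type above.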

test_step(data)[source]

Get the predictions for each augmented copy of the data and merge them into a single prediction.

Parameters:

data (DataBatch) – Augmented data batch sampled from the dataloader.

Returns:

Merged prediction.

Return type:

MergedDataSamples
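A minimal sketch of the test_step flow, assuming the input uses the dict layout shown earlier: split the batch by augmentation, run the wrapped model once per augmentation, transpose the results so that each sample's augmented predictions sit together, then merge. `ToyModel` and `MajorityVoteTTA` are hypothetical classes for illustration, not mmengine APIs, and only `inputs` is included in the batch for brevity.

```python
from collections import Counter
from typing import Any, Dict, List


class ToyModel:
    """Hypothetical model: labels each input 'pos' if it is positive, else 'neg'."""

    def test_step(self, data: Dict[str, List[Any]]) -> List[str]:
        return ['pos' if x > 0 else 'neg' for x in data['inputs']]


class MajorityVoteTTA:
    """Hypothetical TTA wrapper mirroring the split / predict / merge flow."""

    def __init__(self, module: ToyModel) -> None:
        self.module = module

    def merge_preds(self, data_samples_list: List[List[str]]) -> List[str]:
        # Keep the most common label among each sample's augmented predictions.
        return [Counter(preds).most_common(1)[0][0] for preds in data_samples_list]

    def test_step(self, data: Dict[str, List[List[Any]]]) -> List[str]:
        num_augs = len(next(iter(data.values())))
        # One dict per augmentation: {key: [sample_0, sample_1, ...]}
        data_list = [{k: v[j] for k, v in data.items()} for j in range(num_augs)]
        # predictions[j][i]: prediction for sample i under augmentation j
        predictions = [self.module.test_step(d) for d in data_list]
        # zip(*...) regroups into [i][j] so each sample's predictions sit together
        return self.merge_preds([list(p) for p in zip(*predictions)])


batch = dict(inputs=[[1.0, -2.0], [0.5, -1.0], [-0.3, 3.0]])  # 3 augmentations, 2 images
print(MajorityVoteTTA(ToyModel()).test_step(batch))  # ['pos', 'neg']
```

Each image is labeled by a vote over its three augmented views, so the single dissenting view per image is outvoted.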
