Shortcuts

Source code for mmengine.model.wrappers.test_time_aug

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict, List, Union

import torch
import torch.nn as nn

from mmengine import MODELS
from mmengine.structures import BaseDataElement

# multi-batch inputs processed by different augmentations from the same batch.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]
# multi-batch data samples processed by different augmentations from the same
# batch. The inner list stands for different augmentations and the outer list
# stands for batch.
EnhancedBatchDataSamples = List[List[BaseDataElement]]
DATA_BATCH = Union[Dict[str, Union[EnhancedBatchInputs,
                                   EnhancedBatchDataSamples]], tuple, dict]
MergedDataSamples = List[BaseDataElement]


[docs]@MODELS.register_module() class BaseTTAModel(nn.Module): """Base model for inference with test-time augmentation. ``BaseTTAModel`` is a wrapper for inference given multi-batch data. It implements the :meth:`test_step` for multi-batch data inference. ``multi-batch`` data means data processed by different augmentation from the same batch. During test time augmentation, the data processed by :obj:`mmcv.transforms.TestTimeAug`, and then collated by ``pseudo_collate`` will have the following format: .. code-block:: result = dict( inputs=[ [image1_aug1, image2_aug1], [image1_aug2, image2_aug2] ], data_samples=[ [data_sample1_aug1, data_sample2_aug1], [data_sample1_aug2, data_sample2_aug2], ] ) ``image{i}_aug{j}`` means the i-th image of the batch, which is augmented by the j-th augmentation. ``BaseTTAModel`` will collate the data to: .. code-block:: data1 = dict( inputs=[image1_aug1, image2_aug1], data_samples=[data_sample1_aug1, data_sample2_aug1] ) data2 = dict( inputs=[image1_aug2, image2_aug2], data_samples=[data_sample1_aug2, data_sample2_aug2] ) ``data1`` and ``data2`` will be passed to model, and the results will be merged by :meth:`merge_preds`. Note: :meth:`merge_preds` is an abstract method, all subclasses should implement it. Args: module (dict or nn.Module): Tested model. """ def __init__(self, module: Union[dict, nn.Module]): super().__init__() if isinstance(module, nn.Module): self.module = module elif isinstance(module, dict): self.module = MODELS.build(module) else: raise TypeError('The type of module should be a `nn.Module` ' f'instance or a dict, but got {module}') assert hasattr(self.module, 'test_step'), ( 'Model wrapped by BaseTTAModel must implement `test_step`!')
[docs] @abstractmethod def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \ -> MergedDataSamples: """Merge predictions of enhanced data to one prediction. Args: data_samples_list (EnhancedBatchDataSamples): List of predictions of all enhanced data. Returns: List[BaseDataElement]: Merged prediction. """
[docs] def test_step(self, data: DATA_BATCH) -> MergedDataSamples: """Get predictions of each enhanced data, a multiple predictions. Args: data (DataBatch): Enhanced data batch sampled from dataloader. Returns: MergedDataSamples: Merged prediction. """ data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) data_list = [{key: value[idx] for key, value in data.items()} for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] for idx in range(num_augs)] else: raise TypeError('data given by dataLoader should be a dict, ' f'tuple or a list, but got {type(data)}') predictions = [] for data in data_list: # type: ignore predictions.append(self.module.test_step(data)) return self.merge_preds(list(zip(*predictions))) # type: ignore

© Copyright 2022, mmengine contributors. Revision 4e685931.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.3.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.