
Source code for mmengine.model.test_time_aug

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from .base_model import BaseModel

# multi-batch inputs processed by different augmentations from the same batch.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]
# multi-batch data samples processed by different augmentations from the same
# batch. The inner list stands for different augmentations and the outer list
# stands for batch.
EnhancedBatchDataSamples = List[List[BaseDataElement]]
DATA_BATCH = Union[Dict[str, Union[EnhancedBatchInputs,
                                   EnhancedBatchDataSamples]], tuple, dict]
MergedDataSamples = List[BaseDataElement]


@MODELS.register_module()
class BaseTTAModel(BaseModel):
    """Base model for inference with test-time augmentation.

    ``BaseTTAModel`` is a wrapper for inference given multi-batch data. It
    implements :meth:`test_step` for multi-batch data inference.
    ``multi-batch`` data means data processed by different augmentations
    from the same batch.

    During test time augmentation, data processed by
    :obj:`mmcv.transforms.TestTimeAug` and then collated by
    ``pseudo_collate`` has the following format:

    .. code-block::

        result = dict(
            inputs=[
                [image1_aug1, image2_aug1],
                [image1_aug2, image2_aug2]
            ],
            data_samples=[
                [data_sample1_aug1, data_sample2_aug1],
                [data_sample1_aug2, data_sample2_aug2],
            ]
        )

    ``image{i}_aug{j}`` means the i-th image of the batch, augmented by the
    j-th augmentation.

    ``BaseTTAModel`` will regroup the data into:

    .. code-block::

        data1 = dict(
            inputs=[image1_aug1, image2_aug1],
            data_samples=[data_sample1_aug1, data_sample2_aug1]
        )

        data2 = dict(
            inputs=[image1_aug2, image2_aug2],
            data_samples=[data_sample1_aug2, data_sample2_aug2]
        )

    ``data1`` and ``data2`` will be passed to the model, and the results will
    be merged by :meth:`merge_preds`.

    Note:
        :meth:`merge_preds` is an abstract method; all subclasses must
        implement it.

    Warning:
        If ``data_preprocessor`` is not None, it will overwrite the model's
        ``data_preprocessor``.

    Args:
        module (dict or nn.Module): The model to be tested.
        data_preprocessor (dict or :obj:`BaseDataPreprocessor`, optional):
            If the model does not define ``data_preprocessor``, this will be
            used as its default ``data_preprocessor``.
    """

    def __init__(
        self,
        module: Union[dict, nn.Module],
        data_preprocessor: Union[dict, nn.Module, None] = None,
    ):
        super().__init__()
        if isinstance(module, nn.Module):
            self.module = module
        elif isinstance(module, dict):
            if data_preprocessor is not None:
                module['data_preprocessor'] = data_preprocessor
            self.module = MODELS.build(module)
        else:
            raise TypeError('The type of module should be a `nn.Module` '
                            f'instance or a dict, but got {module}')
        assert hasattr(self.module, 'test_step'), (
            'Model wrapped by BaseTTAModel must implement `test_step`!')
    @abstractmethod
    def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \
            -> MergedDataSamples:
        """Merge predictions of enhanced data into one prediction.

        Args:
            data_samples_list (EnhancedBatchDataSamples): List of predictions
                of all enhanced data.

        Returns:
            List[BaseDataElement]: Merged prediction.
        """
    def test_step(self, data):
        """Get predictions of each enhanced data sample and merge them.

        Args:
            data (DataBatch): Enhanced data batch sampled from the dataloader.

        Returns:
            MergedDataSamples: Merged prediction.
        """
        data_list: Union[List[dict], List[list]]
        if isinstance(data, dict):
            num_augs = len(data[next(iter(data))])
            data_list = [{key: value[idx]
                          for key, value in data.items()}
                         for idx in range(num_augs)]
        elif isinstance(data, (tuple, list)):
            num_augs = len(data[0])
            data_list = [[_data[idx] for _data in data]
                         for idx in range(num_augs)]
        else:
            raise TypeError('data given by the dataloader should be a dict, '
                            f'tuple or a list, but got {type(data)}')

        predictions = []
        for data in data_list:  # type: ignore
            predictions.append(self.module.test_step(data))
        return self.merge_preds(list(zip(*predictions)))  # type: ignore
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """``BaseTTAModel.forward`` should not be called."""
        raise NotImplementedError(
            '`BaseTTAModel.forward` will not be called during training or '
            'testing. Please call `test_step` instead. If you want to use '
            '`BaseTTAModel.forward`, please implement this method.')
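

# ---------------------------------------------------------------------------
# Example (not part of mmengine): a minimal sketch of a concrete subclass.
# ``BaseTTAModel`` leaves the merging strategy to subclasses, so the class
# below averages classification scores across augmentations. The class name
# ``MeanTTAModel`` and the ``pred_score``/``pred_label`` field names are
# assumptions for illustration; real models may store predictions under
# different fields of their data samples.
# ---------------------------------------------------------------------------
from typing import List

import torch

from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement


@MODELS.register_module()
class MeanTTAModel(BaseTTAModel):
    """Hypothetical TTA wrapper that averages classification scores."""

    def merge_preds(self, data_samples_list: List[List[BaseDataElement]]
                    ) -> List[BaseDataElement]:
        merged = []
        for data_samples in data_samples_list:
            # ``data_samples`` holds the predictions of one image under every
            # augmentation; we assume the wrapped model's ``test_step`` fills
            # ``pred_score`` with a 1-D score tensor.
            scores = torch.stack([ds.pred_score for ds in data_samples])
            mean_score = scores.mean(dim=0)
            result = BaseDataElement()
            result.pred_score = mean_score
            result.pred_label = mean_score.argmax().item()
            merged.append(result)
        return merged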

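# ---------------------------------------------------------------------------
# Example (not part of mmengine): how ``test_step`` regroups a collated
# multi-batch. The toy tensors below stand in for real images; the dict layout
# mirrors the format produced by ``TestTimeAug`` + ``pseudo_collate`` (the
# outer list indexes augmentations, the inner list indexes images of the
# batch). Shapes and sizes here are arbitrary.
# ---------------------------------------------------------------------------
import torch

from mmengine.structures import BaseDataElement

# A batch of 2 images, each enhanced by 3 augmentations.
data = dict(
    inputs=[[torch.rand(3, 8, 8) for _ in range(2)] for _ in range(3)],
    data_samples=[[BaseDataElement() for _ in range(2)] for _ in range(3)],
)

# The same regrouping that ``test_step`` applies to dict data: one sub-batch
# per augmentation, each still holding every image of the original batch.
num_augs = len(data[next(iter(data))])
data_list = [{key: value[idx]
              for key, value in data.items()} for idx in range(num_augs)]
assert len(data_list) == num_augs == 3
assert len(data_list[0]['inputs']) == 2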