Shortcuts

Source code for mmengine.structures.instance_data

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import List, Union

import numpy as np
import torch

from .base_data_element import BaseDataElement

# Every kind of ``item`` accepted by ``InstanceData.__getitem__``: a field
# name (``str``), a position (``int``) or range (``slice``), a Python
# ``list`` or ``np.ndarray`` (converted to a tensor before indexing), or a
# 1-D integer / boolean tensor on CPU or CUDA.
IndexType = Union[
    str,
    slice,
    int,
    list,
    torch.LongTensor,
    torch.cuda.LongTensor,
    torch.BoolTensor,
    torch.cuda.BoolTensor,
    np.ndarray,
]


# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class InstanceData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in ``data_fields``
    must have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501

    ``InstanceData`` also supports indexing, slicing and ``cat`` on its
    data fields. A data field value can be a basic data structure such as
    ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` or ``tuple``, or
    a custom data structure that has ``__len__``, ``__getitem__`` and
    ``cat`` attributes.

    Examples:
        >>> import itertools
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> # custom data structure
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if type(item) == int:
        ...             if item >= len(self) or item < -len(self):
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(r, TmpObject) for r in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> img_meta = dict(img_shape=(800, 1196, 3),
        ...                 pad_shape=(800, 1216, 3))
        >>> instance_data = InstanceData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> order = instance_data.det_scores.sort().indices
        >>> sorted_results = instance_data[order]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Sized):
        """Set a data field, or initialize a private field container once.

        ``_metainfo_fields`` and ``_data_fields`` can each be assigned only
        once; a second assignment raises ``AttributeError``. Any other name
        is treated as data: ``value`` must implement ``__len__`` and, if
        this instance already holds data, its length must equal
        ``len(self)``. It is then forwarded to the parent ``__setattr__``.

        Args:
            name (str): Attribute name.
            value (Sized): Value to store.

        Raises:
            AttributeError: If a private field container is reassigned.
            AssertionError: If ``value`` has no ``__len__`` or its length is
                inconsistent with ``len(self)``.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            assert isinstance(value, Sized), \
                'value must contain `__len__` attribute'
            # An empty instance accepts any length; afterwards every data
            # field must have the same number of instances.
            if len(self) > 0:
                assert len(value) == len(self), (
                    f'the length of values {len(value)} is not consistent '
                    f'with the length of this :obj:`InstanceData` '
                    f'{len(self)} ')
            super().__setattr__(name, value)

    # ``instance_data['key'] = value`` goes through the same validation as
    # ``instance_data.key = value``.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'InstanceData':
        """Get a field by name, or select instances.

        Args:
            item (str, int, slice, list, np.ndarray, :obj:`torch.LongTensor`,
                :obj:`torch.BoolTensor`): A ``str`` returns the value of
                that field. Any other type selects instances along the first
                dimension of every data field:

                - an ``int`` selects one instance (the instance dimension is
                  kept);
                - a ``slice`` selects a range;
                - a 1-D ``LongTensor`` selects a set of positions;
                - a 1-D ``BoolTensor`` is a mask whose length must equal
                  ``len(self)``.

                ``list`` and ``np.ndarray`` items are converted to tensors
                first. Signed-integer arrays, and empty arrays such as one
                built from ``[]``, are cast to int64.

        Returns:
            :obj:`InstanceData`: A new instance built with this instance's
            ``metainfo`` and holding the selected instances. When ``item``
            is a ``str``, the value of that field is returned instead.

        Raises:
            IndexError: If an ``int`` item is out of range.
            ValueError: If ``item`` is a tensor and a data field is not a
                tensor, an array, a ``str``/``list``/``tuple`` or an object
                with both ``__getitem__`` and ``cat``.
        """
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore every signed integer array is
            # converted to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            # An empty array (e.g. built from `[]`) gets numpy's default
            # float64 dtype. It selects nothing, so it is cast to int64 too.
            if item.dtype.kind == 'i' or (item.size == 0
                                          and item.dtype.kind == 'f'):
                item = item.astype(np.int64, copy=False)
            item = torch.from_numpy(item)
        assert isinstance(
            item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
                   torch.BoolTensor, torch.cuda.BoolTensor))

        if isinstance(item, str):
            return getattr(self, item)

        # Exact type check (not `isinstance`), so `bool` items are not
        # converted here.
        if type(item) is int:
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            # Keep the dimension: with a step of len(self), this slice
            # yields exactly the element at `item`.
            item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the' \
                                    ' values along the first dimension.'
            is_mask = isinstance(item,
                                 (torch.BoolTensor, torch.cuda.BoolTensor))
            if is_mask:
                assert len(item) == len(self), (
                    f'The shape of the input(BoolTensor)) {len(item)} '
                    f' does not match the shape of the indexed tensor '
                    f'in results_filed {len(self)} at first dimension. ')

            # Python list of selected positions. It is only needed for
            # fields indexed element by element, so it is computed lazily
            # (once), which avoids a host copy when every field is a tensor.
            indexes = None
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    if indexes is None:
                        if is_mask:
                            # convert to indexes from boolTensor
                            indexes = torch.nonzero(item).view(
                                -1).cpu().numpy().tolist()
                        else:
                            indexes = item.cpu().numpy().tolist()
                    if indexes:
                        # For an in-range `index`, `slice(index, None,
                        # len(v))` yields just that element while keeping
                        # the container type, so custom objects can `cat`.
                        slice_list = [
                            slice(index, None, len(v)) for index in indexes
                        ]
                    else:
                        # Nothing selected: an empty slice keeps the type.
                        slice_list = [slice(None, 0, None)]
                    r_list = [v[s] for s in slice_list]
                    # Concatenate in one pass rather than with repeated `+`,
                    # which would be quadratic in the number of instances.
                    if isinstance(v, str):
                        new_value = ''.join(r_list)
                    elif isinstance(v, list):
                        new_value = list(
                            itertools.chain.from_iterable(r_list))
                    elif isinstance(v, tuple):
                        new_value = tuple(
                            itertools.chain.from_iterable(r_list))
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        f'support slice with `bool`')
        else:
            # item is a slice
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data  # type:ignore

    @staticmethod
    def cat(instances_list: List['InstanceData']) -> 'InstanceData':
        """Concatenate a list of :obj:`InstanceData` along the instances.

        Note:
            All elements in the list must have exactly the same keys
            (metainfo and data fields). The result is built with the
            metainfo of the first element. A single-element list returns
            that element itself, not a copy.

        Args:
            instances_list (list[:obj:`InstanceData`]): A non-empty list
                of :obj:`InstanceData`.

        Returns:
            :obj:`InstanceData`: The concatenated instances.

        Raises:
            AssertionError: If the list is empty, contains a non
                :obj:`InstanceData` item, or the elements' keys differ.
            ValueError: If a data field is not a tensor, an array, a
                ``str``/``list``/``tuple`` and has no ``cat`` attribute.
        """
        assert all(
            isinstance(results, InstanceData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        field_keys_list = [
            instances.all_keys() for instances in instances_list
        ]
        # Every element has the same number of keys, and the union of all
        # keys is no larger than the first element's key list.
        same_keys = (
            len({len(field_keys) for field_keys in field_keys_list}) == 1
            and len(set(itertools.chain(*field_keys_list))) == len(
                field_keys_list[0]))
        assert same_keys, ('There are different keys in '
                           '`instances_list`, which may '
                           'cause the cat operation '
                           'to fail. Please make sure all '
                           'elements in `instances_list` '
                           'have the exact same key ')

        new_data = instances_list[0].__class__(
            metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Copy first so `+=` on a list does not mutate the first
                # element's field in place.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, 'cat'):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f'The type of `{k}` is `{type(v0)}` which has no '
                    'attribute of `cat`')
            new_data[k] = new_values
        return new_data  # type:ignore

    def __len__(self) -> int:
        """int: The number of instances, i.e. the length of the first data
        field, or 0 when there are no data fields."""
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        else:
            return 0

© Copyright 2022, mmengine contributors. Revision 4e685931.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.3.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.