InstanceData
- class mmengine.structures.InstanceData(*, metainfo=None, **kwargs)
Data structure for instance-level annotations or predictions.

Subclass of BaseDataElement. All values in data_fields must have the same length. The design refers to https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py. InstanceData also supports extra operations on data fields: index, slice and cat. A value in a data field can be a basic data structure such as torch.Tensor, numpy.ndarray, list, str or tuple, or a custom data structure that has __len__, __getitem__ and cat attributes.

Examples
>>> import itertools
>>> # custom data structure
>>> class TmpObject:
...     def __init__(self, tmp) -> None:
...         assert isinstance(tmp, list)
...         self.tmp = tmp
...     def __len__(self):
...         return len(self.tmp)
...     def __getitem__(self, item):
...         if isinstance(item, int):
...             if item >= len(self) or item < -len(self):  # type:ignore
...                 raise IndexError(f'Index {item} out of range!')
...             else:
...                 # keep the dimension
...                 item = slice(item, None, len(self))
...         return TmpObject(self.tmp[item])
...     @staticmethod
...     def cat(tmp_objs):
...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
...         if len(tmp_objs) == 1:
...             return tmp_objs[0]
...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
...         tmp_list = list(itertools.chain(*tmp_list))
...         new_data = TmpObject(tmp_list)
...         return new_data
...     def __repr__(self):
...         return str(self.tmp)
>>> from mmengine.structures import InstanceData
>>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
2
>>> print(instance_data)
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([2, 3])
    det_scores: tensor([0.8000, 0.7000])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263]])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([2])
    det_scores: tensor([0.8000])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
    polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([], dtype=torch.int64)
    det_scores: tensor([])
    bboxes: tensor([], size=(0, 4))
    polygons: []
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([2, 3, 2, 3])
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263],
                [0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263]])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0>
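The equal-length rule and the "keep the dimension" indexing used by TmpObject can be illustrated without torch or mmengine. The following is only a minimal sketch: MiniInstances, set_field and field are hypothetical names invented for illustration, and the class is a plain-Python stand-in, not the InstanceData implementation.

```python
from typing import Any, Dict, Union


class MiniInstances:
    """Hypothetical stand-in for InstanceData (not the mmengine class)."""

    def __init__(self, **fields: Any) -> None:
        self._fields: Dict[str, Any] = {}
        for name, value in fields.items():
            self.set_field(name, value)

    def set_field(self, name: str, value: Any) -> None:
        # Every field must describe the same number of instances.
        if self._fields and len(value) != len(self):
            raise ValueError(
                f"{name} has length {len(value)}, expected {len(self)}")
        self._fields[name] = value

    def field(self, name: str) -> Any:
        return self._fields[name]

    def __len__(self) -> int:
        # All fields agree on length, so the first one is representative.
        for value in self._fields.values():
            return len(value)
        return 0

    def __getitem__(self, item: Union[int, slice]) -> "MiniInstances":
        if isinstance(item, int):
            if not -len(self) <= item < len(self):
                raise IndexError(f"Index {item} out of range!")
            # Convert the integer to a one-element slice so that the
            # result is still a container of length 1 (keeps the dimension).
            item = slice(item, None if item == -1 else item + 1)
        # Apply the same index to every field so they stay aligned.
        return MiniInstances(
            **{name: value[item] for name, value in self._fields.items()})


inst = MiniInstances(labels=[2, 3, 5], scores=[0.8, 0.7, 0.9])
first = inst[0]   # one instance, still a container
tail = inst[1:]   # the last two instances
```

Because one index is applied to every field, the labels and scores of an instance cannot drift apart, which is the reason for requiring equal lengths in the first place.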
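- Parameters:
metainfo (dict | None) – Meta information about the data, for example dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)). Defaults to None.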
- static cat(instances_list)
Concatenate the instances of all InstanceData objects in the list.

Note: For cat to return the expected result, all elements in the list must have exactly the same keys.
- Parameters:
instances_list (list[InstanceData]) – A list of InstanceData.
- Returns: The concatenated InstanceData.
- Return type: InstanceData
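Why cat needs identical keys can be seen in a small stand-alone sketch. It models each InstanceData as a plain dict of lists; cat_fields is a hypothetical helper written for illustration, and raising on an empty list or on mismatched keys is an assumption of this sketch rather than documented mmengine behaviour.

```python
import itertools
from typing import Dict, List


def cat_fields(instances_list: List[Dict[str, list]]) -> Dict[str, list]:
    """Hypothetical helper: concatenate dicts of per-instance lists."""
    if not instances_list:
        raise ValueError("instances_list must not be empty")
    keys = set(instances_list[0])
    for inst in instances_list[1:]:
        # A field missing from one element would leave the result with
        # fields of different lengths, so reject mismatched keys early.
        if set(inst) != keys:
            raise KeyError(
                f"keys differ: {sorted(inst)} vs {sorted(keys)}")
    # Concatenate field by field, preserving the order of the input list.
    return {
        key: list(itertools.chain.from_iterable(
            inst[key] for inst in instances_list))
        for key in instances_list[0]
    }


a = {"labels": [2, 3], "scores": [0.8, 0.7]}
b = {"labels": [1], "scores": [0.5]}
merged = cat_fields([a, b])
```

Here merged holds three instances, and every field has length three. If b lacked "scores", a field-wise concatenation would yield three labels but only two scores, which violates the equal-length rule of the class.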