
InstanceData

class mmengine.structures.InstanceData(*, metainfo=None, **kwargs)[source]

Data structure for instance-level annotations or predictions.

Subclass of BaseDataElement. All values in data_fields must have the same length. The design follows https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py. InstanceData also supports indexing, slicing and concatenation (cat) of its data fields. A data field value can be a basic data structure such as torch.Tensor, numpy.ndarray, list, str or tuple, or a custom data structure that implements __len__, __getitem__ and cat.
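The same-length rule can be illustrated with a small, self-contained toy model. ToyInstances below is hypothetical and is not part of mmengine; it only mimics the behaviour described above (every data field has the same length, and a boolean mask keeps the same rows in every field) using plain Python lists. The choice of ValueError for a length mismatch is an assumption of this sketch.

```python
from typing import Any, Dict, List


class ToyInstances:
    """Hypothetical stand-in for InstanceData that stores plain lists."""

    def __init__(self) -> None:
        self._fields: Dict[str, List[Any]] = {}

    def __len__(self) -> int:
        # All fields share one length; an empty container has length 0.
        for value in self._fields.values():
            return len(value)
        return 0

    def set_field(self, name: str, value: List[Any]) -> None:
        # Enforce the invariant: a new field must match the existing length.
        if self._fields and len(value) != len(self):
            raise ValueError(
                f'length {len(value)} does not match {len(self)}')
        self._fields[name] = list(value)

    def get_field(self, name: str) -> List[Any]:
        return self._fields[name]

    def select(self, mask: List[bool]) -> 'ToyInstances':
        # A boolean mask keeps the same rows in every field.
        if len(mask) != len(self):
            raise ValueError('mask length does not match')
        out = ToyInstances()
        for name, value in self._fields.items():
            out.set_field(name, [v for v, keep in zip(value, mask) if keep])
        return out
```

For example, after setting labels=[2, 3] and scores=[0.8, 0.7], select([True, False]) keeps [2] and [0.8], while adding a field of length 1 raises ValueError.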

Examples

>>> import itertools
>>> # custom data structure
>>> class TmpObject:
...     def __init__(self, tmp) -> None:
...         assert isinstance(tmp, list)
...         self.tmp = tmp
...     def __len__(self):
...         return len(self.tmp)
...     def __getitem__(self, item):
...         if type(item) == int:
...             if item >= len(self) or item < -len(self):  # type:ignore
...                 raise IndexError(f'Index {item} out of range!')
...             else:
...                 # keep the dimension
...                 item = slice(item, None, len(self))
...         return TmpObject(self.tmp[item])
...     @staticmethod
...     def cat(tmp_objs):
...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
...         if len(tmp_objs) == 1:
...             return tmp_objs[0]
...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
...         tmp_list = list(itertools.chain(*tmp_list))
...         new_data = TmpObject(tmp_list)
...         return new_data
...     def __repr__(self):
...         return str(self.tmp)
>>> import numpy as np
>>> import torch
>>> from mmengine.structures import InstanceData
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
2
>>> print(instance_data)
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([2, 3])
    det_scores: tensor([0.8000, 0.7000])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263]])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(
    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)
    DATA FIELDS
    det_labels: tensor([2])
    det_scores: tensor([0.8000])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
    polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<InstanceData(
    META INFORMATION
    pad_shape: (800, 1216, 3)
    img_shape: (800, 1196, 3)
    DATA FIELDS
    det_labels: tensor([], dtype=torch.int64)
    det_scores: tensor([])
    bboxes: tensor([], size=(0, 4))
    polygons: [[]]
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData(
    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)
    DATA FIELDS
    det_labels: tensor([2, 3, 2, 3])
    bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263],
                [0.4997, 0.7707, 0.0595, 0.4188],
                [0.8101, 0.3105, 0.5123, 0.6263]])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
) at 0x7f203542feb0>
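
TmpObject.__getitem__ in the example turns an integer index into slice(item, None, len(self)). Because the step equals the length, that slice visits exactly one position, so the result is still a sequence of length 1 rather than a bare element. The plain-list sketch below shows the trick and how the length-1 pieces selected by a boolean mask can be joined back together. It is a simplified illustration of the idea, not mmengine's implementation; the helper names take_keep_dim and gather_rows are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar('T')


def take_keep_dim(seq: Sequence[T], index: int) -> Sequence[T]:
    """Return a length-1 slice of ``seq`` at ``index`` (hypothetical helper)."""
    n = len(seq)
    if index >= n or index < -n:
        raise IndexError(f'Index {index} out of range!')
    # The step equals the length, so only the start position is visited.
    return seq[slice(index, None, n)]


def gather_rows(seq: List[T], mask: List[bool]) -> List[T]:
    """Select rows with a boolean mask by joining length-1 slices."""
    indexes = [i for i, keep in enumerate(mask) if keep]
    result: List[T] = []
    for i in indexes:
        result = result + list(take_keep_dim(seq, i))
    return result
```

With polygons [[1, 2, 3, 4], [5, 6, 7, 8]], take_keep_dim(polygons, -1) gives [[5, 6, 7, 8]], and an all-False mask yields an empty list.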

Parameters

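metainfo (Optional[dict]) – Meta information about the instances, e.g. dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)). Defaults to None.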

Return type

None

static cat(instances_list)[source]

Concatenate all InstanceData objects in the list into a single InstanceData.

Note: For cat to return the expected result, all elements in the list must have exactly the same keys.
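Identical keys matter because concatenation works field by field: a key present in only some of the inputs has nothing to be paired with. The stdlib-only sketch below uses plain dicts of lists as a hypothetical stand-in for InstanceData objects; cat_fields is not an mmengine function, and raising ValueError on a key mismatch is an assumption of this sketch.

```python
from typing import Any, Dict, List


def cat_fields(items: List[Dict[str, List[Any]]]) -> Dict[str, List[Any]]:
    """Concatenate dicts of equal-length lists key by key (toy model)."""
    if not items:
        raise ValueError('need at least one element to concatenate')
    keys = set(items[0])
    for item in items[1:]:
        # Every element must expose exactly the same set of keys.
        if set(item) != keys:
            raise ValueError(
                f'all elements must have the same keys, got {sorted(keys)} '
                f'and {sorted(item)}')
    out: Dict[str, List[Any]] = {}
    for key in items[0]:
        merged: List[Any] = []
        for item in items:
            merged.extend(item[key])
        out[key] = merged
    return out
```

Concatenating {'labels': [2, 3]} with itself gives {'labels': [2, 3, 2, 3]}, which matches the shape of the cat output in the example.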

Parameters

instances_list (list[InstanceData]) – A list of InstanceData.

Returns

The concatenated InstanceData.

Return type

InstanceData

Read the Docs v: v0.3.0