# Source code for mmengine.structures.label_data
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from .base_data_element import BaseDataElement
class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions."""

    @staticmethod
    def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
        """Convert the one-hot input to label.

        Args:
            onehot (torch.Tensor): The one-hot input. It must be a 1-D
                tensor whose values all lie in ``[0, 1]``.

        Returns:
            torch.Tensor: A 1-D tensor holding the indices of the non-zero
            entries of ``onehot``. An empty ``onehot`` yields an empty
            result.

        Raises:
            ValueError: If ``onehot`` is not 1-D or has values outside
                ``[0, 1]``.
        """
        assert isinstance(onehot, torch.Tensor)
        # ``Tensor.max``/``Tensor.min`` raise on an empty tensor, so the
        # value-range check only runs when there are values to check. The
        # ``<= 1 and >= 0`` form (rather than ``> 1 or < 0``) is kept so that
        # NaN entries still fail the check.
        is_one_hot = onehot.ndim == 1 and (
            onehot.numel() == 0
            or (onehot.max().item() <= 1 and onehot.min().item() >= 0))
        if not is_one_hot:
            raise ValueError(
                'input is not one-hot and can not convert to label')
        # ``nonzero`` on a 1-D tensor returns shape (k, 1); drop the last dim.
        return onehot.nonzero().squeeze(-1)

    @staticmethod
    def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Convert the label-format input to one-hot.

        Args:
            label (torch.Tensor): Integer class indices, typically a 1-D
                tensor (possibly empty); a 0-D tensor holding a single index
                is also accepted. Every index must lie in
                ``[0, num_classes)``.
            num_classes (int): The number of classes.

        Returns:
            torch.Tensor: A 1-D tensor of length ``num_classes`` created with
            ``label.new_zeros`` (so it shares ``label``'s dtype and device),
            set to 1 at every labelled index and 0 elsewhere.
        """
        assert isinstance(label, torch.Tensor)
        onehot = label.new_zeros((num_classes, ))
        # Use tensor reductions instead of the builtin ``max``: iterating a
        # 0-D tensor raises ``TypeError``, and Python-level iteration is slow.
        # Negative indices are rejected because ``onehot[label]`` would
        # otherwise silently wrap them around to the last classes.
        if label.numel() > 0:
            assert label.max().item() < num_classes
            assert label.min().item() >= 0
        onehot[label] = 1
        return onehot