Shortcuts

Source code for mmengine.structures.label_data

# Copyright (c) OpenMMLab. All rights reserved.

import torch

from .base_data_element import BaseDataElement


class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions."""

    @staticmethod
    def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
        """Convert the one-hot input to label.

        Args:
            onehot (torch.Tensor): The one-hot input. It must be a 1-D tensor
                whose elements are all exactly 0 or 1. Several non-zero
                entries are allowed; each one yields a label.

        Returns:
            torch.Tensor: A 1-D tensor holding the indices of the non-zero
            elements of ``onehot``. An empty input gives an empty result.

        Raises:
            ValueError: If ``onehot`` is not 1-D or contains any value other
                than 0 or 1.
        """
        assert isinstance(onehot, torch.Tensor)
        # Require every element to be exactly 0 or 1. The former min/max range
        # check let fractional values such as 0.5 through, and ``Tensor.max``
        # raises on empty tensors, whereas ``all()`` over an empty tensor is
        # simply True.
        if onehot.ndim == 1 and bool(((onehot == 0) | (onehot == 1)).all()):
            # ``nonzero`` on a 1-D tensor returns shape (k, 1); drop the last
            # dim so the result is a flat vector of indices.
            return onehot.nonzero().squeeze(-1)
        raise ValueError('input is not one-hot and can not convert to label')

    @staticmethod
    def label_to_onehot(label: torch.Tensor,
                        num_classes: int) -> torch.Tensor:
        """Convert the label-format input to one-hot.

        Args:
            label (torch.Tensor): Integer class indices. Every value must lie
                in ``[0, num_classes)``. A 0-d tensor is treated as a single
                label.
            num_classes (int): The number of classes, i.e. the length of the
                returned tensor.

        Returns:
            torch.Tensor: A 1-D tensor of length ``num_classes``, created with
            ``label.new_zeros`` (same dtype and device as ``label``). It holds
            1 at every index listed in ``label`` and 0 elsewhere.
        """
        assert isinstance(label, torch.Tensor)
        onehot = label.new_zeros((num_classes, ))
        if label.numel() > 0:
            # Use tensor reductions rather than the builtin ``max``. The
            # builtin iterates element by element in Python and fails on 0-d
            # tensors. Negative indices are rejected because advanced
            # indexing would otherwise wrap them to the end of ``onehot``.
            min_label = label.min().item()
            max_label = label.max().item()
            assert 0 <= min_label and max_label < num_classes, (
                f'label values must be in [0, {num_classes}), '
                f'got min {min_label} and max {max_label}')
        onehot[label] = 1
        return onehot

© Copyright 2022, mmengine contributors. Revision 6a56ca78.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.6.0
Versions
latest
stable
v0.6.0
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.