
BaseModel

class mmengine.model.BaseModel(data_preprocessor=None, init_cfg=None)

Base class for all algorithmic models.

BaseModel implements the basic functions of an algorithmic model, such as weight initialization, batch input preprocessing (see BaseDataPreprocessor for more information), loss parsing, and model parameter updates.

Subclasses of BaseModel only need to implement the forward method, which contains the logic for calculating losses and predictions; the model can then be trained in the runner.

Examples

>>> import torch
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel
>>> from mmengine.registry import MODELS
>>>
>>> @MODELS.register_module()
>>> class ToyModel(BaseModel):
>>>
>>>     def __init__(self):
>>>         super().__init__()
>>>         # LeNet-style backbone for 3x32x32 inputs.
>>>         self.backbone = nn.Sequential()
>>>         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
>>>         self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
>>>         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
>>>         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
>>>         # Flatten the (N, 16, 5, 5) feature maps before the linear layers.
>>>         self.backbone.add_module('flatten', nn.Flatten())
>>>         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
>>>         self.backbone.add_module('fc2', nn.Linear(120, 84))
>>>         self.backbone.add_module('fc3', nn.Linear(84, 10))
>>>
>>>         self.criterion = nn.CrossEntropyLoss()
>>>
>>>     def forward(self, inputs, data_samples=None, mode='tensor'):
>>>         feats = self.backbone(inputs)
>>>         if mode == 'tensor':
>>>             return feats
>>>         elif mode == 'predict':
>>>             predictions = torch.argmax(feats, 1)
>>>             return predictions
>>>         elif mode == 'loss':
>>>             # data_samples is a list of scalar label tensors; stack it
>>>             # only here, since it may be None in the other modes.
>>>             labels = torch.stack(data_samples)
>>>             loss = self.criterion(feats, labels)
>>>             return dict(loss=loss)
>>>         else:
>>>             raise ValueError(f'Invalid mode: {mode}')

Parameters
data_preprocessor

Used to pre-process data sampled by the dataloader into the format accepted by forward().

Type

BaseDataPreprocessor

init_cfg

Initialization config dict.

Type

dict, optional
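As a hedged illustration, the config below shows how these two arguments are commonly passed when a model is built from the MODELS registry. It assumes a hypothetical subclass MyModel whose __init__ forwards data_preprocessor and init_cfg to BaseModel.__init__ (the ToyModel example above does not). 'BaseDataPreprocessor', 'Kaiming', and 'Normal' are registered names in MMEngine to the best of our knowledge, but check them against the version you use.

```python
# Hypothetical model config; 'MyModel' is an illustrative name, not a real class.
model = dict(
    type='MyModel',
    # Built into a BaseDataPreprocessor instance by BaseModel.__init__.
    data_preprocessor=dict(type='BaseDataPreprocessor'),
    # A list of weight-initialization configs used by init_weights().
    init_cfg=[
        dict(type='Kaiming', layer='Conv2d'),
        dict(type='Normal', layer='Linear', std=0.01),
    ],
)
```

If data_preprocessor is left as None, BaseModel falls back to a plain BaseDataPreprocessor.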

cpu(*args, **kwargs)

Overrides this method to additionally call BaseDataPreprocessor.cpu().

Returns

The model itself.

Return type

nn.Module

cuda(device=None)

Overrides this method to additionally call BaseDataPreprocessor.cuda().

Returns

The model itself.

Return type

nn.Module

Parameters

device (Optional[Union[int, str, torch.device]]) – The target device. Defaults to None.

abstract forward(inputs, data_samples=None, mode='tensor')

Returns losses or predictions for the training, validation, testing, and simple inference processes.

The forward method of BaseModel is abstract, so subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.

During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.

During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.

Parameters
  • inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.

  • data_samples (list, optional) – Data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of loss, predict, and tensor:

    • loss: Called by train_step; returns a dict of losses used for logging.

    • predict: Called by val_step and test_step; returns a list of results used to compute metrics.

    • tensor: Called by custom code to get results of Tensor type.

Returns

  • If mode == loss, returns a dict of loss tensors used for backward and logging.

  • If mode == predict, returns a list of inference results.

  • If mode == tensor, returns a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

dict or list

mlu(device=None)

Overrides this method to additionally call BaseDataPreprocessor.mlu().

Returns

The model itself.

Return type

nn.Module

Parameters

device (Optional[Union[int, str, torch.device]]) – The target device. Defaults to None.

npu(device=None)

Overrides this method to additionally call BaseDataPreprocessor.npu().

Returns

The model itself.

Return type

nn.Module

Parameters

device (Optional[Union[int, str, torch.device]]) – The target device. Defaults to None.

Note

The current generation of NPU (Ascend910) does not support using multiple cards in a single process, so the device index here must be consistent with the default device.

parse_losses(losses)

Parses the raw outputs (losses) of the network.

Parameters

losses (dict) – Raw output of the network, which usually contains losses and other necessary information.

Returns

A tuple of two elements. The first is the loss tensor passed to optim_wrapper, which may be a weighted sum of all losses; the second is log_vars, which will be sent to the logger.

Return type

tuple[Tensor, dict]
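The parsing step can be pictured with a small, torch-free sketch. It is hypothetical: parse_losses_sketch is an invented name, plain Python numbers stand in for tensors (the real method reduces each tensor with a mean first), and it assumes the convention that only entries whose key contains 'loss' are summed into the total, while other entries such as accuracy are only logged.

```python
from collections import OrderedDict
from numbers import Number


def parse_losses_sketch(losses):
    """Hypothetical pure-Python stand-in for BaseModel.parse_losses."""
    log_vars = []
    for name, value in losses.items():
        if isinstance(value, Number):
            log_vars.append((name, float(value)))
        elif isinstance(value, (list, tuple)) and all(
                isinstance(v, Number) for v in value):
            # A list of loss terms is reduced to their sum.
            log_vars.append((name, float(sum(value))))
        else:
            raise TypeError(f'{name} is not a number or a list of numbers')
    # Only entries whose key contains 'loss' contribute to the total.
    total = sum(v for k, v in log_vars if 'loss' in k)
    # The total is logged first, under the key 'loss'.
    log_vars.insert(0, ('loss', total))
    return total, OrderedDict(log_vars)


total, log_vars = parse_losses_sketch(
    {'loss_cls': 0.5, 'loss_bbox': [0.25, 0.25], 'acc': 0.9})
print(total)           # 1.0
print(list(log_vars))  # ['loss', 'loss_cls', 'loss_bbox', 'acc']
```

Keeping non-loss entries in log_vars lets a model report metrics such as accuracy without them affecting back-propagation.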

test_step(data)

BaseModel implements test_step in the same way as val_step.

Parameters

data (dict or tuple or list) – Data sampled from the dataset.

Returns

The predictions of the given data.

Return type

list

to(*args, **kwargs)

Overrides this method to additionally call BaseDataPreprocessor.to().

Returns

The model itself.

Return type

nn.Module
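The cpu, cuda, mlu, npu, and to overrides share one idea: the data preprocessor must know which device the model is on so that it can move each batch there. Our understanding is that nn.Module.to() moves submodules through its internal _apply() hook without calling their own to(), so the preprocessor's device bookkeeping would go stale without an explicit call. The sketch below is a hypothetical, torch-free illustration of that pattern; PreprocessorSketch and DeviceSketchModel are invented names, and device strings stand in for torch.device objects.

```python
class PreprocessorSketch:
    """Hypothetical stand-in for BaseDataPreprocessor that tracks its device."""

    def __init__(self):
        self.device = 'cpu'

    def to(self, device):
        self.device = device
        return self

    def __call__(self, batch):
        # The real preprocessor casts data to its device; here items are only tagged.
        return [(item, self.device) for item in batch]


class DeviceSketchModel:
    """Hypothetical model showing why the device overrides exist."""

    def __init__(self):
        self.device = 'cpu'
        self.data_preprocessor = PreprocessorSketch()

    def to(self, device):
        # Stand-in for nn.Module.to(), which moves parameters and buffers.
        self.device = device
        # The extra step the overrides add: keep the preprocessor's device in sync.
        self.data_preprocessor.to(device)
        return self  # the model itself, as documented


model = DeviceSketchModel().to('cuda:0')
print(model.data_preprocessor([1, 2]))  # [(1, 'cuda:0'), (2, 'cuda:0')]
```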

train_step(data, optim_wrapper)

Implements the default model training process, including preprocessing, model forward propagation, loss calculation, back-propagation, and optimization.

During non-distributed training, if subclasses do not override train_step(), EpochBasedTrainLoop or IterBasedTrainLoop calls this method to update model parameters. The default parameter update process is as follows:

  1. Calls self.data_preprocessor(data, training=True) to collect inputs and the corresponding data_samples (labels).

  2. Calls self(inputs, data_samples, mode='loss') to get the raw losses.

  3. Calls self.parse_losses to get the parsed_losses tensor used for backward and a dict of loss tensors used for logging.

  4. Calls optim_wrapper.update_params(parsed_losses) to update the model.

Parameters
  • data (dict or tuple or list) – Data sampled from the dataset.

  • optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
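The four default steps of train_step can be traced in a small, framework-free sketch. Everything here is hypothetical: TrainSketchModel and RecordingOptimWrapper are illustrative stand-ins, plain Python numbers replace tensors, forward is called directly where the real method goes through nn.Module.__call__, and parse_losses is simplified. In the MMEngine source we have seen, steps 1 and 2 also run inside optim_wrapper.optim_context(self), which this sketch omits.

```python
class RecordingOptimWrapper:
    """Hypothetical stand-in for OptimWrapper that records the loss it gets."""

    def __init__(self):
        self.updated_with = None

    def update_params(self, loss):
        # A real OptimWrapper would run backward, the optimizer step and zero_grad.
        self.updated_with = loss


class TrainSketchModel:
    """Hypothetical model that mimics the default train_step flow."""

    def data_preprocessor(self, data, training=False):
        # Stand-in for BaseDataPreprocessor: passes the fields through unchanged.
        return {'inputs': data['inputs'], 'data_samples': data['data_samples']}

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode != 'loss':
            raise ValueError(f'this sketch only supports mode="loss", got {mode!r}')
        # Toy loss: mean absolute error between inputs and labels.
        errors = [abs(x - y) for x, y in zip(inputs, data_samples)]
        return {'loss_mae': sum(errors) / len(errors)}

    def parse_losses(self, losses):
        # Simplified: sum every entry whose key contains 'loss'.
        total = sum(v for k, v in losses.items() if 'loss' in k)
        log_vars = {'loss': total, **losses}
        return total, log_vars

    def train_step(self, data, optim_wrapper):
        data = self.data_preprocessor(data, training=True)   # step 1
        losses = self.forward(**data, mode='loss')           # step 2
        parsed_losses, log_vars = self.parse_losses(losses)  # step 3
        optim_wrapper.update_params(parsed_losses)           # step 4
        return log_vars


wrapper = RecordingOptimWrapper()
log_vars = TrainSketchModel().train_step(
    {'inputs': [1.0, 2.0], 'data_samples': [1.5, 1.0]}, wrapper)
print(log_vars)  # {'loss': 0.75, 'loss_mae': 0.75}
```

Only the parsed total reaches the optimizer wrapper; the returned log_vars are what the training loop hands to the logger.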

val_step(data)

Gets the predictions of the given data.

Calls self.data_preprocessor(data, False) and self(inputs, data_samples, mode='predict') in order, and returns the predictions, which will be passed to the evaluator.

Parameters

data (dict or tuple or list) – Data sampled from the dataset.

Returns

The predictions of the given data.

Return type

list
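The evaluation path is shorter than training: preprocess without training-time behavior, then run forward in predict mode. The hypothetical, torch-free sketch below mirrors that order and aliases test_step to val_step, matching the statement that BaseModel implements them the same way. EvalSketchModel is an invented name, nested lists of scores stand in for a batch of logits, and the argmax is written in plain Python.

```python
class EvalSketchModel:
    """Hypothetical model that mimics the default val_step/test_step flow."""

    def data_preprocessor(self, data, training=False):
        # Stand-in for BaseDataPreprocessor; data_samples may be absent at test time.
        return {'inputs': data['inputs'], 'data_samples': data.get('data_samples')}

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode != 'predict':
            raise ValueError(f'this sketch only supports mode="predict", got {mode!r}')
        # One prediction per sample: the index of the highest score (argmax).
        return [max(range(len(scores)), key=scores.__getitem__) for scores in inputs]

    def val_step(self, data):
        data = self.data_preprocessor(data, False)
        return self.forward(**data, mode='predict')

    # test_step behaves exactly like val_step.
    test_step = val_step


model = EvalSketchModel()
print(model.val_step({'inputs': [[0.1, 0.9], [0.8, 0.2]]}))  # [1, 0]
```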
