BaseInferencer¶
- class mmengine.infer.BaseInferencer(model=None, weights=None, device=None, scope=None, show_progress=True)[source]¶
Base inferencer for downstream tasks.
The BaseInferencer provides the standard workflow for inference as follows:

1. Preprocess the input data by preprocess().
2. Forward the data to the model by forward(). BaseInferencer assumes the model inherits from mmengine.models.BaseModel and will call model.test_step in forward() by default.
3. Visualize the results by visualize().
4. Postprocess and return the results by postprocess().

When we call a subclass inherited from BaseInferencer (one that does not override __call__), the workflow will be executed in order.

All subclasses of BaseInferencer could define the following class attributes for customization:

- preprocess_kwargs: The keys of the kwargs that will be passed to preprocess().
- forward_kwargs: The keys of the kwargs that will be passed to forward().
- visualize_kwargs: The keys of the kwargs that will be passed to visualize().
- postprocess_kwargs: The keys of the kwargs that will be passed to postprocess().
All attributes mentioned above should be a set of keys (strings), and no key may be duplicated across sets. __call__() dispatches its arguments to the corresponding methods according to the xxx_kwargs sets mentioned above; therefore, the keys must be unique to avoid ambiguous dispatching.

Warning

If a subclass defines the class attributes mentioned above with duplicated keys, an AssertionError will be raised during the import process.

Subclasses inherited from BaseInferencer should implement _init_pipeline(), visualize() and postprocess():

- _init_pipeline(): Return a callable object to preprocess the input data.
- visualize(): Visualize the results returned by forward().
- postprocess(): Postprocess the results returned by forward() and visualize().
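For example, a minimal sketch of such a subclass is shown below. Only the BaseInferencer API itself (the xxx_kwargs sets and the four methods) is real; everything else is illustrative.

```python
from mmengine.infer import BaseInferencer


class ToyInferencer(BaseInferencer):
    # __call__ dispatches matching keyword arguments to each step;
    # the keys must be unique across these four sets.
    preprocess_kwargs = set()
    forward_kwargs = set()
    visualize_kwargs = {'show'}
    postprocess_kwargs = {'print_result'}

    def _init_pipeline(self, cfg):
        # Return a callable that turns one raw input into model-ready data.
        def pipeline(single_input):
            return single_input  # hypothetical no-op transform
        return pipeline

    def visualize(self, inputs, preds, show=False):
        # Render predictions if requested; return images (or None).
        return None

    def postprocess(self, preds, visualization, return_datasample=False,
                    print_result=False, **kwargs):
        if print_result:
            print(preds)
        return dict(predictions=preds, visualization=visualization)
```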
- Parameters:
model (str, optional) – Path to the config file or the model name defined in a metafile. Taking the mmdet metafile as an example, the model could be retinanet_r18_fpn_1x_coco or its alias. If model is not specified, the user must provide weights saved by MMEngine, which contain the config string. Defaults to None.
weights (str, optional) – Path to the checkpoint. If it is not specified and model is a model name of metafile, the weights will be loaded from metafile. Defaults to None.
device (str, optional) – Device to run inference. If None, the available device will be automatically used. Defaults to None.
scope (str, optional) – The scope of the model. Defaults to None.
show_progress (bool) – Control whether to display the progress bar during the inference process. Defaults to True. New in version 0.7.4.
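A hedged usage sketch; MyInferencer stands in for any concrete BaseInferencer subclass shipped by a downstream package:

```python
# Resolve weights from the metafile by model name:
inferencer = MyInferencer(model='retinanet_r18_fpn_1x_coco', device='cuda:0')

# Or pass an explicit config/checkpoint pair:
inferencer = MyInferencer(model='path/to/config.py',
                          weights='path/to/checkpoint.pth')
results = inferencer('path/to/image.jpg')
```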
Note

Since an Inferencer could be used to infer batch data, a collate_fn should be defined. If collate_fn is not defined in the config file, pseudo_collate will be used by default.

- static list_models(scope=None, patterns='.*')[source]¶
List models defined in metafile of corresponding packages.
- Parameters:
scope (str, optional) – The scope of the models to list. Defaults to None.
patterns (str) – Regular expression matched against model names defined in the metafile. Defaults to '.*'.
- Returns:
Model dict with model name and its alias.
- Return type:
dict
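A hedged example, assuming a downstream package such as mmdet is installed so its metafiles can be found:

```python
from mmengine.infer import BaseInferencer

# List all models in the mmdet scope whose names start with "retinanet".
models = BaseInferencer.list_models(scope='mmdet', patterns='retinanet.*')
print(sorted(models))
```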
- abstract postprocess(preds, visualization, return_datasample=False, **kwargs)[source]¶
Process the predictions and visualization results from forward() and visualize().

This method should be responsible for the following tasks:

1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.

Customize your postprocess by overriding this method. Make sure postprocess returns a dict with both visualization results and inference results.

- Parameters:
preds (List[Dict]) – Predictions of the model.
visualization (np.ndarray) – Visualized predictions.
return_datasample (bool) – Whether to return results as datasamples. Defaults to False.
- Returns:
Inference and visualization results with keys predictions and visualization:

- visualization (Any): Returned by visualize().
- predictions (dict or DataSample): Returned by forward() and processed in postprocess(). If return_datasample=False, it usually should be a json-serializable dict containing only basic data elements such as strings and numbers.
- Return type:
dict
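A hedged sketch of a postprocess override, assuming the predictions are mmengine DataSample (BaseDataElement) instances, which provide to_dict(); note that tensors inside the resulting dict may still need converting before json serialization:

```python
from mmengine.infer import BaseInferencer


class MyInferencer(BaseInferencer):
    # _init_pipeline() and visualize() omitted for brevity.

    def postprocess(self, preds, visualization, return_datasample=False,
                    **kwargs):
        if not return_datasample:
            # Convert each DataSample into a plain dict.
            preds = [pred.to_dict() for pred in preds]
        return dict(predictions=preds, visualization=visualization)
```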
- preprocess(inputs, batch_size=1, **kwargs)[source]¶
Process the inputs into a model-feedable format.
Customize your preprocess by overriding this method. Preprocess should return an iterable object, of which each item will be used as the input of model.test_step.

BaseInferencer.preprocess will return an iterable of chunked data, which will be used in __call__ like this:

```python
def __call__(self, inputs, batch_size=1, **kwargs):
    chunked_data = self.preprocess(inputs, batch_size, **kwargs)
    for batch in chunked_data:
        preds = self.forward(batch, **kwargs)
```
- Parameters:
inputs (InputsType) – Inputs given by user.
batch_size (int) – Batch size. Defaults to 1.
- Yields:
Any – Data processed by the pipeline and collate_fn.
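A hedged sketch of a preprocess override, written as a method of a BaseInferencer subclass and mirroring the default behaviour: map each input through self.pipeline (built by _init_pipeline()), then batch with pseudo_collate, the default collate_fn mentioned in the Note above. It assumes inputs is a sequence:

```python
from mmengine.dataset import pseudo_collate


def preprocess(self, inputs, batch_size=1, **kwargs):
    # Yield one collated batch at a time, ready for model.test_step.
    for start in range(0, len(inputs), batch_size):
        chunk = [self.pipeline(inp)
                 for inp in inputs[start:start + batch_size]]
        yield pseudo_collate(chunk)
```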