- class mmengine.infer.BaseInferencer(model=None, weights=None, device=None, scope=None, show_progress=True)
Base inferencer for downstream tasks.
The BaseInferencer provides the standard workflow for inference as follows:
1. Preprocess the input data by preprocess().
2. Forward the data to the model by forward(). BaseInferencer assumes the model inherits from mmengine.models.BaseModel and, by default, calls model.test_step in forward().
3. Visualize the results by visualize().
4. Postprocess and return the results by postprocess().
When a subclass of BaseInferencer that does not override __call__ is called, these steps are executed in order.
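The workflow above can be illustrated with a small, framework-free sketch. Everything in it is hypothetical: `ToyModel` stands in for a `mmengine.models.BaseModel` subclass, and `MiniInferencer` only mimics the order of calls described here. It is not mmengine's `BaseInferencer` and omits kwargs dispatching, progress bars and real data pipelines.

```python
from typing import Any, Dict, Iterator, List


class ToyModel:
    """Hypothetical stand-in for a BaseModel: test_step doubles each input."""

    def test_step(self, batch: List[int]) -> List[int]:
        return [x * 2 for x in batch]


class MiniInferencer:
    """Hypothetical sketch of the four-step workflow (not mmengine's code)."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model

    def preprocess(self, inputs: List[int], batch_size: int = 1) -> Iterator[List[int]]:
        # Step 1: split the inputs into chunks of at most `batch_size` items.
        for start in range(0, len(inputs), batch_size):
            yield inputs[start:start + batch_size]

    def forward(self, batch: List[int]) -> List[int]:
        # Step 2: delegate each batch to the model's test_step.
        return self.model.test_step(batch)

    def visualize(self, inputs: List[int], preds: List[int]) -> List[str]:
        # Step 3: a "visualization" can be any object; plain strings here.
        return [f"{x} -> {p}" for x, p in zip(inputs, preds)]

    def postprocess(self, preds: List[int], visualization: List[str]) -> Dict[str, Any]:
        # Step 4: pack predictions and visualizations into a single dict.
        return {"predictions": preds, "visualization": visualization}

    def __call__(self, inputs: List[int], batch_size: int = 1) -> Dict[str, Any]:
        # Run the four steps in order, collecting predictions across batches.
        preds: List[int] = []
        for batch in self.preprocess(inputs, batch_size):
            preds.extend(self.forward(batch))
        visualization = self.visualize(inputs, preds)
        return self.postprocess(preds, visualization)


result = MiniInferencer(ToyModel())([1, 2, 3], batch_size=2)
print(result)
# {'predictions': [2, 4, 6], 'visualization': ['1 -> 2', '2 -> 4', '3 -> 6']}
```

Collecting predictions with `extend` across batches lets the visualization and postprocessing steps see all results at once.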
All subclasses of BaseInferencer can define the following class attributes for customization:
- preprocess_kwargs: the keys of the kwargs that will be passed to preprocess().
- forward_kwargs: the keys of the kwargs that will be passed to forward().
- visualize_kwargs: the keys of the kwargs that will be passed to visualize().
- postprocess_kwargs: the keys of the kwargs that will be passed to postprocess().
Each of these attributes should be a set of keys (strings), and no key may appear in more than one of them. __call__() dispatches every keyword argument to the corresponding method according to these xxx_kwargs sets, so each key must be unique to avoid ambiguous dispatching.
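To make the dispatching rule concrete, here is a minimal sketch of routing keyword arguments by key sets. The key names (`pipeline_mode`, `show`, `wait_time`, `print_result`) and the function `dispatch_kwargs` are hypothetical, and rejecting unknown keys with `ValueError` is a choice made for this sketch rather than documented behavior.

```python
from typing import Any, Dict, Set, Tuple

# Hypothetical key sets, one per method, as the four class attributes describe.
PREPROCESS_KWARGS: Set[str] = {"pipeline_mode"}
FORWARD_KWARGS: Set[str] = set()
VISUALIZE_KWARGS: Set[str] = {"show", "wait_time"}
POSTPROCESS_KWARGS: Set[str] = {"print_result"}


def dispatch_kwargs(**kwargs: Any) -> Tuple[Dict[str, Any], ...]:
    """Split kwargs into (preprocess, forward, visualize, postprocess) dicts."""
    key_sets = (PREPROCESS_KWARGS, FORWARD_KWARGS,
                VISUALIZE_KWARGS, POSTPROCESS_KWARGS)
    unknown = set(kwargs) - set().union(*key_sets)
    if unknown:
        # This sketch rejects keys that no method declares.
        raise ValueError(f"unknown arguments: {sorted(unknown)}")
    # Each key is copied into every dict whose set contains it
    # (exactly one dict, as long as the keys are unique across sets).
    return tuple({k: v for k, v in kwargs.items() if k in keys}
                 for keys in key_sets)


pre, fwd, vis, post = dispatch_kwargs(show=True, print_result=False)
print(pre, fwd, vis, post)  # {} {} {'show': True} {'print_result': False}
```

If `show` appeared in two of the sets, this routing would copy it into both dicts, which is exactly the ambiguity the uniqueness rule prevents.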
Warning: if a subclass defines these class attributes with duplicated keys, an AssertionError is raised during import.
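One way to enforce the uniqueness check at class-definition time, and therefore during import, is a metaclass. The sketch below, with the hypothetical names `CheckedKwargsMeta`, `Base` and `Broken`, assumes that approach; it is an illustration, not mmengine's own code.

```python
from abc import ABCMeta


class CheckedKwargsMeta(ABCMeta):
    """Hypothetical metaclass: rejects classes whose *_kwargs sets share a key."""

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Attribute lookup on `cls` also finds values inherited from bases.
        groups = (cls.preprocess_kwargs, cls.forward_kwargs,
                  cls.visualize_kwargs, cls.postprocess_kwargs)
        for group in groups:
            assert isinstance(group, set), f"{name}: *_kwargs must be a set"
        union = set().union(*groups)
        # A key present in two groups makes the union smaller than the summed sizes.
        assert len(union) == sum(len(g) for g in groups), (
            f"{name}: duplicated keys across *_kwargs")


class Base(metaclass=CheckedKwargsMeta):
    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {"show"}
    postprocess_kwargs: set = set()


rejected_message = ""
try:
    class Broken(Base):
        postprocess_kwargs = {"show"}  # "show" is already a visualize key
except AssertionError as err:
    rejected_message = str(err)

print(rejected_message)  # Broken: duplicated keys across *_kwargs
```

Because this sketch relies on `assert`, the check is skipped when Python runs with `-O`; a stricter variant could raise an explicit exception instead.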
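Subclasses inherited from BaseInferencer should implement _init_pipeline(), visualize() and postprocess():
- _init_pipeline: return a callable object that preprocesses the input data.
- visualize: visualize the results returned by forward().
- postprocess: postprocess the results returned by forward() and visualize().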
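The three required hooks can be sketched with Python's `abc` module. `SketchInferencer`, `UpperCaseInferencer` and `Incomplete` are hypothetical, and the hook signatures are simplified (for example, `_init_pipeline` takes no arguments here). The sketch only shows that a subclass missing a hook cannot be instantiated.

```python
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List


class SketchInferencer(metaclass=ABCMeta):
    """Hypothetical stand-in base class declaring the three required hooks."""

    def __init__(self) -> None:
        # In this sketch the pipeline is built once, at construction time.
        self.pipeline = self._init_pipeline()

    @abstractmethod
    def _init_pipeline(self) -> Callable[[Any], Any]: ...

    @abstractmethod
    def visualize(self, inputs: List[Any], preds: List[Any], show: bool = False) -> Any: ...

    @abstractmethod
    def postprocess(self, preds: List[Any], visualization: Any) -> Dict[str, Any]: ...


class UpperCaseInferencer(SketchInferencer):
    """Hypothetical subclass that implements all three hooks."""

    def _init_pipeline(self) -> Callable[[str], str]:
        return str.upper  # toy "preprocessing" applied to each input

    def visualize(self, inputs, preds, show=False):
        lines = [f"{i!r} -> {p!r}" for i, p in zip(inputs, preds)]
        if show:
            print("\n".join(lines))
        return lines

    def postprocess(self, preds, visualization):
        return {"predictions": preds, "visualization": visualization}


class Incomplete(SketchInferencer):
    """Implements only _init_pipeline, so it stays abstract."""

    def _init_pipeline(self):
        return str.upper


inferencer = UpperCaseInferencer()
inputs = ["cat", "dog"]
preds = [inferencer.pipeline(x) for x in inputs]
print(inferencer.postprocess(preds, inferencer.visualize(inputs, preds)))

try:
    Incomplete()
    incomplete_rejected = False
except TypeError:
    # abc refuses to instantiate a class with unimplemented abstract methods.
    incomplete_rejected = True
print(incomplete_rejected)  # True
```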
- Parameters
model (str, optional) – Path to the config file, or the model name defined in a metafile. Taking the mmdet metafile as an example, model could be retinanet_r18_fpn_1x_coco or its alias. If model is not specified, the user must provide weights saved by MMEngine, which contain the config string. Defaults to None.
weights (str, optional) – Path to the checkpoint. If it is not specified and model is a model name from a metafile, the weights will be loaded from the metafile. Defaults to None.
device (str, optional) – Device to run inference. If None, the available device will be automatically used. Defaults to None.
scope (str, optional) – The scope of the model. Defaults to None.
show_progress (bool) – Control whether to display the progress bar during the inference process. Defaults to True. New in version 0.7.4.
Note: since Inferencer can be used to infer batch data, collate_fn should be defined. If collate_fn is not defined in the config file, pseudo_collate is used by default.
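The idea behind a collate function that does not stack samples can be sketched in pure Python. `simple_pseudo_collate` is a hypothetical, simplified stand-in that only handles lists of dicts with identical keys; it is not mmengine's `pseudo_collate`, whose exact behavior should be checked in its own documentation.

```python
from typing import Any, Dict, List


def simple_pseudo_collate(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical simplified collate: group values by key, never stack them."""
    if not samples:
        return {}
    keys = samples[0].keys()
    # Every sample is expected to share the same keys.
    assert all(s.keys() == keys for s in samples), "samples must share keys"
    return {key: [s[key] for s in samples] for key in keys}


batch = simple_pseudo_collate([
    {"inputs": [1, 2, 3], "img_id": 0},
    {"inputs": [4, 5], "img_id": 1},  # different length: stacking would fail
])
print(batch)  # {'inputs': [[1, 2, 3], [4, 5]], 'img_id': [0, 1]}
```

Keeping values in lists means variable-sized samples, such as the two `inputs` of different length here, can share a batch.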
- forward(inputs, **kwargs)
Feed the inputs to the model.
- static list_models(scope=None, patterns='.*')
List models defined in metafile of corresponding packages.
Model dict with model name and its alias.
- Return type
dict
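Filtering models by a regular expression matched against names and aliases can be sketched as follows. The in-memory `METAFILE` dict, its model names and aliases, and the `{name: aliases}` shape of the result are all hypothetical; `list_models` itself reads the metafiles of the corresponding packages for the given scope.

```python
import re
from typing import Dict, List

# Hypothetical metafile contents: model name -> list of aliases.
METAFILE: Dict[str, List[str]] = {
    "toydet_r18_1x": ["toydet-small"],
    "toydet_r50_1x": [],
    "toyseg_r50_1x": ["seg-base"],
}


def list_models_sketch(patterns: str = r".*") -> Dict[str, List[str]]:
    """Keep a model if `patterns` matches its name or any of its aliases."""
    matched: Dict[str, List[str]] = {}
    for name, aliases in METAFILE.items():
        # re.match only anchors at the start of the string.
        if any(re.match(patterns, candidate) for candidate in [name, *aliases]):
            matched[name] = aliases
    return matched


print(list_models_sketch(r"toydet"))
# {'toydet_r18_1x': ['toydet-small'], 'toydet_r50_1x': []}
print(list_models_sketch(r"seg-"))
# {'toyseg_r50_1x': ['seg-base']}
```

With the default pattern `'.*'`, every model matches, which is why the default lists everything.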
- abstract postprocess(preds, visualization, return_datasample=False, **kwargs)
Process the predictions and visualization results from forward() and visualize().
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Customize your postprocess by overriding this method. Make sure postprocess returns a dict containing both the visualization results and the inference results.
- Parameters
preds (List[Dict]) – Predictions of the model.
visualization (np.ndarray) – Visualized predictions.
return_datasample (bool) – Whether to return results as datasamples. Defaults to False.
Returns: inference and visualization results with the keys predictions and visualization:
  - visualization (Any): returned by visualize().
  - predictions (dict or DataSample): returned by forward() and processed in postprocess(). If return_datasample=False, it should usually be a json-serializable dict containing only basic data elements such as strings and numbers.
- Return type
dict
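The first two responsibilities listed above can be sketched with a dataclass standing in for a DataSample. `FakeDataSample` and `postprocess_sketch` are hypothetical; dumping or logging (task 3) is left out, and a real implementation would convert its own data-sample types rather than rely on `dataclasses.asdict`.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict, List


@dataclass
class FakeDataSample:
    """Hypothetical stand-in for a DataSample holding one prediction."""
    label: int
    score: float


def postprocess_sketch(preds: List[FakeDataSample], visualization: Any,
                       return_datasample: bool = False) -> Dict[str, Any]:
    """Pack predictions and visualizations; convert to plain dicts by default."""
    if return_datasample:
        predictions: Any = preds  # keep the rich objects untouched
    else:
        predictions = [asdict(p) for p in preds]  # json-serializable form
    return {"predictions": predictions, "visualization": visualization}


result = postprocess_sketch([FakeDataSample(label=3, score=0.9)], visualization=None)
print(json.dumps(result))
# {"predictions": [{"label": 3, "score": 0.9}], "visualization": null}
```

Keeping the conversion behind `return_datasample` lets callers choose between a JSON-friendly result and the original objects.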
- preprocess(inputs, batch_size=1, **kwargs)
Process the inputs into a model-feedable format.
Customize your preprocess by overriding this method. Preprocess should return an iterable object, each item of which will be used as the input of model.test_step.
BaseInferencer.preprocess returns an iterable of chunked data, which is used in __call__ like this:
    def __call__(self, inputs, batch_size=1, **kwargs):
        chunked_data = self.preprocess(inputs, batch_size, **kwargs)
        for batch in chunked_data:
            preds = self.forward(batch, **kwargs)
- Parameters
inputs (InputsType) – Inputs given by the user.
batch_size (int) – Batch size. Defaults to 1.
- Yields
Any – Data processed by the pipeline and collate_fn.
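The chunk-then-collate behavior described for preprocess can be sketched with `itertools.islice`. `chunked` and `preprocess_sketch` are hypothetical helpers, and `pipeline` and `collate_fn` are passed in as plain callables here rather than built from a config.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List


def chunked(items: Iterable[Any], chunk_size: int) -> Iterator[List[Any]]:
    """Yield lists of up to `chunk_size` items; the last one may be shorter."""
    iterator = iter(items)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk


def preprocess_sketch(inputs: Iterable[Any], batch_size: int,
                      pipeline: Callable[[Any], Any],
                      collate_fn: Callable[[List[Any]], Any]) -> Iterator[Any]:
    """Apply `pipeline` per item, then `collate_fn` per chunk, lazily."""
    for chunk in chunked(map(pipeline, inputs), batch_size):
        yield collate_fn(chunk)


batches = list(preprocess_sketch(range(5), batch_size=2,
                                 pipeline=lambda x: x * 10, collate_fn=tuple))
print(batches)  # [(0, 10), (20, 30), (40,)]
```

Because both functions are generators, nothing is computed until the caller (here, `list`) iterates over the batches.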
- abstract visualize(inputs, preds, show=False, **kwargs)
Customize your visualization by overriding this method. visualize should return visualization results, which could be np.ndarray or any other objects.