Shortcuts

Source code for mmengine.model.wrappers.utils

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmengine.registry import MODEL_WRAPPERS, Registry


def is_model_wrapper(model: nn.Module,
                     registry: Registry = MODEL_WRAPPERS) -> bool:
    """Check if a module is a model wrapper.

    A module is regarded as a model wrapper if it is an instance of any
    class (or a subclass of any class) registered in ``registry`` or,
    recursively, in any of its child registries.

    The following 4 models in MMEngine (and their subclasses) are regarded
    as model wrappers: DataParallel, DistributedDataParallel,
    MMDataParallel, MMDistributedDataParallel. You may add your own model
    wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.

    Args:
        model (nn.Module): The model to be checked.
        registry (Registry): The parent registry to search for model
            wrappers. Its child registries are searched as well.
            Defaults to ``MODEL_WRAPPERS``.

    Returns:
        bool: True if the input model is a model wrapper.
    """
    # ``isinstance`` accepts a tuple of classes; an empty tuple (no
    # registered wrappers) simply yields False.
    module_wrappers = tuple(registry.module_dict.values())
    if isinstance(model, module_wrappers):
        return True
    # Recurse into child registries. ``any`` over an empty iterable is
    # False, so a registry without children needs no special case.
    return any(
        is_model_wrapper(model, child)
        for child in registry.children.values())

© Copyright 2022, mmengine contributors. Revision 6af88783.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.4.0
Versions
latest
stable
v0.5.0
v0.4.0
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.