ActivationAnalyzer
- class mmengine.analysis.ActivationAnalyzer(model, inputs)
Provides access to per-submodule activation counts, obtained by tracing a model with PyTorch's JIT tracing functionality.
By default, it comes with standard activation counters for convolutional and dot-product operators. Handles for additional operators may be added, or the default ones overwritten, using the .set_op_handle(name, func) method; see the method documentation for details. Activation counts can be obtained with:
- .total(module_name=""): total activation count for a module.
- .by_operator(module_name=""): activation counts for the module, as a Counter over operator types.
- .by_module(): Counter of activation counts for all submodules.
- .by_module_and_operator(): dictionary keyed by the name of each descendant module, mapping it to a Counter over operator types.
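The four accessors are different views of one table. As a minimal sketch, assuming the per-module, per-operator counts have already been collected into a plain dict of Counters (the shape that by_module_and_operator() returns), the other three views can be derived as below. The helper functions are hypothetical illustrations, not mmengine's implementation; the numbers are the ones from the example further down.

```python
from collections import Counter
from typing import Dict

# Hypothetical per-module, per-operator table, shaped like the value
# returned by by_module_and_operator(). The root module is keyed by "".
Table = Dict[str, Counter]

table: Table = {
    "": Counter({"conv": 1000, "addmm": 10}),
    "fc": Counter({"addmm": 10}),
    "conv": Counter({"conv": 1000}),
    "act": Counter(),
}


def total(table: Table, module_name: str = "") -> int:
    # Total activation count of one module, summed over its operators.
    return sum(table[module_name].values())


def by_operator(table: Table, module_name: str = "") -> Counter:
    # Operator-type breakdown for one module (a copy, so callers may mutate it).
    return Counter(table[module_name])


def by_module(table: Table) -> Counter:
    # One total per module; modules with no counted operators map to 0.
    return Counter({name: sum(ops.values()) for name, ops in table.items()})


print(total(table))             # 1010
print(total(table, "fc"))       # 10
print(by_module(table)["act"])  # 0
```

Keeping the full table as the single source of truth means every other view is a cheap aggregation, so the views can never disagree with each other.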
An operator is treated as within a module if it is executed inside the module's __call__ method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...).
Modified from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py
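The scoping rule above can be imitated without PyTorch. This is a minimal, self-contained sketch that assumes a hypothetical ToyModule whose __call__ pushes the module's name onto a scope stack before running forward, and a hypothetical record_op that credits an operator's count to every module on that stack. Because only __call__ opens a scope, an explicit forward() call is credited to the caller rather than to the callee.

```python
from collections import Counter
from typing import Dict, List

# Names of modules currently inside __call__ (the root module is "").
_scope: List[str] = []
# Per-module, per-operator counts, like by_module_and_operator().
counts: Dict[str, Counter] = {}


def record_op(op: str, n: int) -> None:
    # Credit the operator to every module whose __call__ is active.
    for name in _scope:
        counts.setdefault(name, Counter())[op] += n


class ToyModule:
    """Hypothetical stand-in for nn.Module, for illustration only."""

    def __init__(self, name: str) -> None:
        self.name = name
        counts.setdefault(name, Counter())  # listed even if it records nothing

    def __call__(self, *args):
        _scope.append(self.name)  # entering __call__ opens a scope
        try:
            return self.forward(*args)
        finally:
            _scope.pop()

    def forward(self, *args):
        raise NotImplementedError


class Conv(ToyModule):
    def forward(self):
        record_op("conv", 1000)


class Root(ToyModule):
    def __init__(self) -> None:
        super().__init__("")
        self.conv_a = Conv("conv_a")
        self.conv_b = Conv("conv_b")

    def forward(self):
        self.conv_a()          # via __call__: credited to "conv_a" and ""
        self.conv_b.forward()  # explicit forward: credited to "" only


Root()()
print(counts["conv_a"], counts["conv_b"], counts[""])
# Counter({'conv': 1000}) Counter() Counter({'conv': 2000})
```

The same stack-of-scopes idea explains why the root module "" always accumulates the counts of everything that runs inside the model.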
- Parameters:
model (nn.Module) – The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]) – The input to the model.
Examples
>>> import torch.nn as nn
>>> import torch
>>> from mmengine.analysis import ActivationAnalyzer
>>> class TestModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(in_features=1000, out_features=10)
...         self.conv = nn.Conv2d(
...             in_channels=3, out_channels=10, kernel_size=1
...         )
...         self.act = nn.ReLU()
...     def forward(self, x):
...         return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> acts = ActivationAnalyzer(model, inputs)
>>> acts.total()
1010
>>> acts.total("fc")
10
>>> acts.by_operator()
Counter({"conv" : 1000, "addmm" : 10})
>>> acts.by_module()
Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
>>> acts.by_module_and_operator()
{"" : Counter({"conv" : 1000, "addmm" : 10}),
 "fc" : Counter({"addmm" : 10}),
 "conv" : Counter({"conv" : 1000}),
 "act" : Counter()
}
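The counts in the example follow directly from the output shapes. As a rough sketch, assuming (as the default handles in fvcore's activation_count.py appear to do) that an operator's activation count is the number of elements in its first output, and that operators without a registered handle, here the ReLU, contribute nothing, the numbers can be reproduced by hand. The handles dict, the trace list and the short operator names are hypothetical stand-ins for the real operator-handle table and JIT trace; passing an extra entry plays the role that .set_op_handle plays in the real class.

```python
import math
from collections import Counter
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
# Hypothetical handle: maps an operator's output shape to an activation count.
Handle = Callable[[Shape], int]


def numel(shape: Shape) -> int:
    # Activation count = number of elements in the output tensor.
    return math.prod(shape)


# Hypothetical trace of TestModel on a (1, 3, 10, 10) input:
# (module that ran the op, operator name, output shape).
trace: List[Tuple[str, str, Shape]] = [
    ("conv", "conv", (1, 10, 10, 10)),  # 1x1 conv keeps 10x10, 10 channels
    ("act", "relu", (1, 10, 10, 10)),
    ("fc", "addmm", (1, 10)),           # Linear(1000 -> 10) after flatten(1)
]


def count(handles: Dict[str, Handle]) -> Dict[str, Counter]:
    result: Dict[str, Counter] = {"": Counter()}
    for module, op, shape in trace:
        result.setdefault(module, Counter())
        if op not in handles:  # no handle registered: the op is skipped
            continue
        n = handles[op](shape)
        result[module][op] += n
        result[""][op] += n  # every op also runs inside the root module
    return result


default = count({"conv": numel, "addmm": numel})
print(sum(default[""].values()), dict(default["act"]))  # 1010 {}

# Registering one more handle (analogous in spirit to set_op_handle):
with_relu = count({"conv": numel, "addmm": numel, "relu": numel})
print(sum(with_relu[""].values()))  # 2010
```

In this toy model, the "act" entry exists but stays empty by default, which matches the Counter() reported for "act" in the example.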