
FlopAnalyzer

class mmengine.analysis.FlopAnalyzer(model, inputs)

Provides access to per-submodule flop counts of a model, obtained by tracing the model with PyTorch’s JIT tracing functionality.

By default, it comes with standard flop counters for a few common operators.

Note

  • Flop is not a well-defined concept. We just produce our best estimate.

  • We count one fused multiply-add as one flop.
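Under the one-multiply-add-one-flop convention, the counts for the two layers in the example further down can be reproduced by hand. The helpers `linear_flops` and `conv2d_flops` are hypothetical and written only for illustration; they assume the bias addition is not charged separately, which is consistent with the example's `fc` count of exactly 1000 × 10 = 10000.

```python
def linear_flops(batch: int, in_features: int, out_features: int) -> int:
    """Flops of y = x @ W.T + b, counting one multiply-add as one flop.

    Each of the batch * out_features outputs is a dot product of length
    in_features; the bias addition is not counted separately (assumption).
    """
    return batch * in_features * out_features


def conv2d_flops(batch: int, in_channels: int, out_channels: int,
                 kernel_size: tuple, out_hw: tuple, groups: int = 1) -> int:
    """Flops of a 2D convolution under the same convention.

    Every output element needs (in_channels // groups) * kh * kw multiply-adds.
    """
    kh, kw = kernel_size
    oh, ow = out_hw
    return batch * out_channels * oh * ow * (in_channels // groups) * kh * kw


# Layers from the example: Conv2d(3, 10, kernel_size=1) on a 1x3x10x10 input
# (a 1x1 kernel with stride 1 keeps the 10x10 spatial size) and Linear(1000, 10).
conv = conv2d_flops(1, 3, 10, (1, 1), (10, 10))
fc = linear_flops(1, 1000, 10)
print(conv, fc, conv + fc)  # 3000 10000 13000
```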

Handles for additional operators may be added, or the default ones overwritten, using the .set_op_handle(name, func) method. See the method documentation for details. Flop counts can be obtained as:

  • .total(module_name=""): total flop count for the module

  • .by_operator(module_name=""): flop counts for the module, as a Counter over different operator types

  • .by_module(): Counter of flop counts for all submodules

  • .by_module_and_operator(): dictionary, indexed by submodule name, of Counters over different operator types

An operator is treated as within a module if it is executed inside the module’s __call__ method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...).
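The roll-up behind these accessors can be modelled with `collections.Counter`. The sketch below is a toy model, not mmengine's implementation: it starts from hard-coded, hypothetical `(scope, operator, flops)` records that mirror the example further down (assigning each operator its scope is the tracer's job, as described above) and adds each operator's count to its own module and to every ancestor, so the root `""` ends up holding the total. The functions are standalone stand-ins that merely share names with the analyzer's methods.

```python
from collections import Counter
from typing import Dict, List, Tuple

# Hypothetical (scope, operator, flops) records mirroring the TestModel example.
# In the real analyzer these come from the traced graph.
RECORDS: List[Tuple[str, str, int]] = [
    ("conv", "conv", 3000),   # 1x1 convolution executed inside self.conv
    ("fc", "addmm", 10000),   # linear layer executed inside self.fc
]
# Every module name, including the root "" and modules with no counted ops.
MODULES: List[str] = ["", "fc", "conv", "act"]


def ancestors(scope: str) -> List[str]:
    """Return `scope` and every enclosing module path, ending with the root ''."""
    parts = scope.split(".") if scope else []
    return [".".join(parts[:i]) for i in range(len(parts), -1, -1)]


def by_module_and_operator(records, modules) -> Dict[str, Counter]:
    # Start every module at an empty Counter so uncounted modules still appear.
    stats: Dict[str, Counter] = {name: Counter() for name in modules}
    for scope, op, flops in records:
        # An operator's flops roll up into its own module and all its ancestors.
        for name in ancestors(scope):
            stats[name][op] += flops
    return stats


def by_module(stats: Dict[str, Counter]) -> Counter:
    # Collapse the operator dimension: one total per module.
    return Counter({name: sum(ops.values()) for name, ops in stats.items()})


def total(stats: Dict[str, Counter], module_name: str = "") -> int:
    return sum(stats[module_name].values())


stats = by_module_and_operator(RECORDS, MODULES)
print(total(stats))        # 13000
print(total(stats, "fc"))  # 10000
print(stats[""])           # per-operator counts for the root module
print(by_module(stats))
```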

Modified from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

Parameters:
  • model (nn.Module) – The model to analyze.

  • inputs (Union[Tensor, Tuple[Tensor, ...]]) – The input to the model.

Examples

>>> import torch
>>> import torch.nn as nn
>>> from mmengine.analysis import FlopAnalyzer
>>> class TestModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(in_features=1000, out_features=10)
...         self.conv = nn.Conv2d(
...             in_channels=3, out_channels=10, kernel_size=1
...         )
...         self.act = nn.ReLU()
...
...     def forward(self, x):
...         return self.fc(self.act(self.conv(x)).flatten(1))
...
>>> model = TestModel()
>>> inputs = (torch.randn((1, 3, 10, 10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total('fc')
10000
>>> flops.by_operator()
Counter({'addmm': 10000, 'conv': 3000})
>>> flops.by_module()
Counter({'': 13000, 'fc': 10000, 'conv': 3000, 'act': 0})
>>> flops.by_module_and_operator()
{'': Counter({'addmm': 10000, 'conv': 3000}),
 'fc': Counter({'addmm': 10000}),
 'conv': Counter({'conv': 3000}),
 'act': Counter()}