
FlopAnalyzer

class mmengine.analysis.FlopAnalyzer(model, inputs)

Provides access to per-submodule model flop counts, obtained by tracing the model with PyTorch's JIT tracing functionality.

By default, it comes with standard flop counters for a few common operators.
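As a rough mental model, the default counters can be pictured as a table from operator names to counting functions ("handles"), consulted once per traced operator. The sketch below is a hypothetical pure-Python toy, not mmengine's implementation: the names count, addmm_handle, relu_handle and DEFAULT_HANDLES are invented for illustration, and these handles receive plain shape tuples rather than the JIT values the real handles see.

```python
from collections import Counter
from math import prod
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[int, ...]
# Hypothetical handle signature: (input shapes, output shapes) -> flop count.
Handle = Callable[[List[Shape], List[Shape]], int]


def addmm_handle(inputs: List[Shape], outputs: List[Shape]) -> int:
    # addmm(bias, x, w) with x of shape (N, C_in) and w of shape (C_in, C_out):
    # N * C_in * C_out multiply-adds, each counted as one flop.
    n, c_in = inputs[1]
    _, c_out = inputs[2]
    return n * c_in * c_out


def relu_handle(inputs: List[Shape], outputs: List[Shape]) -> int:
    # Hypothetical custom rule: one flop per output element.
    return prod(outputs[0])


DEFAULT_HANDLES: Dict[str, Optional[Handle]] = {"aten::addmm": addmm_handle}


def count(trace, handles: Dict[str, Optional[Handle]]) -> Counter:
    """Sum flops per operator type over a list of (name, inputs, outputs)."""
    totals: Counter = Counter()
    for name, ins, outs in trace:
        handle = handles.get(name)
        if handle is None:
            # No handle registered (or explicitly set to None): skip the op.
            continue
        totals[name.replace("aten::", "")] += handle(ins, outs)
    return totals


# Shapes taken from the TestModel in the Examples section.
trace = [
    ("aten::relu", [(1, 10, 10, 10)], [(1, 10, 10, 10)]),
    ("aten::addmm", [(10,), (1, 1000), (1000, 10)], [(1, 10)]),
]
print(count(trace, DEFAULT_HANDLES))  # Counter({'addmm': 10000})

# Extending the table, analogous in spirit to .set_op_handle(name, func).
custom = {**DEFAULT_HANDLES, "aten::relu": relu_handle}
print(count(trace, custom))  # Counter({'addmm': 10000, 'relu': 1000})
```

In this toy, an operator without a handle simply contributes nothing; swapping or adding table entries plays the role that .set_op_handle plays for the real analyzer.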

Note

  • Flop is not a well-defined concept. We just produce our best estimate.

  • We count one fused multiply-add as one flop.
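Under the one-FMA-one-flop convention, a linear layer costs batch * in_features * out_features flops, and a 2-D convolution costs one flop per weight element per output position. The helpers below are hypothetical (not part of mmengine) and assume that bias additions and element-wise activations are not counted, which agrees with the figures in the Examples section.

```python
from typing import Tuple


def linear_flops(batch: int, in_features: int, out_features: int) -> int:
    # Each of the batch * out_features outputs is a length-in_features dot
    # product: in_features multiply-adds, each counted as one flop.
    return batch * in_features * out_features


def conv2d_flops(batch: int, in_channels: int, out_channels: int,
                 kernel_size: Tuple[int, int], out_hw: Tuple[int, int],
                 groups: int = 1) -> int:
    kh, kw = kernel_size
    out_h, out_w = out_hw
    # One multiply-add per weight element (C_out * C_in/groups * kh * kw)
    # at every output position (batch * out_h * out_w).
    weight_elems = out_channels * (in_channels // groups) * kh * kw
    return batch * out_h * out_w * weight_elems


# The layers of TestModel from the Examples section, for a 1x3x10x10 input.
conv = conv2d_flops(1, 3, 10, kernel_size=(1, 1), out_hw=(10, 10))
fc = linear_flops(1, 1000, 10)
print(conv, fc, conv + fc)  # 3000 10000 13000
```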

Handles for additional operators may be added, or the default ones overwritten, using the .set_op_handle(name, func) method. See the method documentation for details. Flop counts can be obtained as:

  • .total(module_name=""): total flop count for the module

  • .by_operator(module_name=""): flop counts for the module, as a Counter over different operator types

  • .by_module(): Counter of flop counts for all submodules

  • .by_module_and_operator(): dictionary, indexed by descendant module name, of Counters over different operator types
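The four accessors are views of the same data. Assuming each module's Counter already includes the operators of all its descendants (so the root entry "" covers the whole model), they relate as in this hypothetical pure-Python sketch. It reuses the breakdown printed in the Examples section and is not mmengine's implementation; BY_MODULE_AND_OPERATOR and the three functions are stand-ins for illustration.

```python
from collections import Counter
from typing import Dict

# Hypothetical data shaped like .by_module_and_operator() for TestModel;
# each module's Counter includes the operators of all its descendants.
BY_MODULE_AND_OPERATOR: Dict[str, Counter] = {
    "": Counter({"addmm": 10000, "conv": 3000}),
    "fc": Counter({"addmm": 10000}),
    "conv": Counter({"conv": 3000}),
    "act": Counter(),
}


def by_operator(module_name: str = "") -> Counter:
    # Operator breakdown for one module (the root "" means the whole model).
    return Counter(BY_MODULE_AND_OPERATOR[module_name])


def total(module_name: str = "") -> int:
    # Total flops of one module: the sum over its operator types.
    return sum(by_operator(module_name).values())


def by_module() -> Counter:
    # One total per module, collapsing the operator dimension.
    return Counter({name: total(name) for name in BY_MODULE_AND_OPERATOR})


print(total(), total("fc"))  # 13000 10000
print(by_module())  # Counter({'': 13000, 'fc': 10000, 'conv': 3000, 'act': 0})
```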

An operator is treated as within a module if it is executed inside the module’s __call__ method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...).

Modified from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

Parameters
  • model (nn.Module) – The model to analyze.

  • inputs (Union[Tensor, Tuple[Tensor, ...]]) – The input to the model.

Return type

None

Examples

>>> import torch
>>> import torch.nn as nn
>>> from mmengine.analysis import FlopAnalyzer
>>> class TestModel(nn.Module):
...    def __init__(self):
...        super().__init__()
...        self.fc = nn.Linear(in_features=1000, out_features=10)
...        self.conv = nn.Conv2d(
...            in_channels=3, out_channels=10, kernel_size=1
...        )
...        self.act = nn.ReLU()
...    def forward(self, x):
...        return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({'addmm': 10000, 'conv': 3000})
>>> flops.by_module()
Counter({'': 13000, 'fc': 10000, 'conv': 3000, 'act': 0})
>>> flops.by_module_and_operator()
{'': Counter({'addmm': 10000, 'conv': 3000}),
 'fc': Counter({'addmm': 10000}),
 'conv': Counter({'conv': 3000}),
 'act': Counter()}