FlopAnalyzer
- class mmengine.analysis.FlopAnalyzer(model, inputs)
Provides access to per-submodule flop counts obtained by tracing a model with PyTorch's JIT tracing functionality.
By default, it comes with standard flop counters for a few common operators.
Note
Flop is not a well-defined concept. We just produce our best estimate.
We count one fused multiply-add as one flop.
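To make the multiply-add convention concrete, the sketch below works out the arithmetic by hand for a linear layer and a 2D convolution. The helpers `linear_flops` and `conv2d_flops` are hypothetical names used only for illustration and are not part of mmengine; they also assume that bias additions are not counted, which is consistent with the numbers in the Examples on this page but is otherwise an assumption.

```python
def linear_flops(in_features: int, out_features: int, batch: int = 1) -> int:
    """Hypothetical helper: one multiply-add per (input, output) feature pair."""
    return batch * in_features * out_features


def conv2d_flops(
    in_channels: int,
    out_channels: int,
    kernel_h: int,
    kernel_w: int,
    out_h: int,
    out_w: int,
    batch: int = 1,
    groups: int = 1,
) -> int:
    """Hypothetical helper: multiply-adds needed for every output element."""
    # Each output element sums over (in_channels / groups) * kernel_h * kernel_w products.
    macs_per_output = (in_channels // groups) * kernel_h * kernel_w
    return batch * out_channels * out_h * out_w * macs_per_output


# Linear(1000 -> 10) on a single sample
print(linear_flops(1000, 10))             # 10000
# 1x1 Conv2d(3 -> 10) producing a 10x10 output map
print(conv2d_flops(3, 10, 1, 1, 10, 10))  # 3000
```

Under a convention that counts the multiply and the add separately, both figures would double.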
Handles for additional operators may be added, or the default ones overwritten, using the .set_op_handle(name, func) method. See the method documentation for details.
Flop counts can be obtained as:
- .total(module_name=""): total flop count for the module
- .by_operator(module_name=""): flop counts for the module, as a Counter over different operator types
- .by_module(): Counter of flop counts for all submodules
- .by_module_and_operator(): dictionary, indexed by descendant module name, of Counters over different operator types
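The four views appear to be different summaries of the same data. The sketch below uses only `collections.Counter` and the values printed in the Examples on this page to show how they relate. The relationships are inferred from that sample output, not from the implementation, and no mmengine call is made here.

```python
from collections import Counter

# Values copied from the by_module_and_operator() output in the Examples.
by_module_and_operator = {
    "": Counter({"addmm": 10000, "conv": 3000}),
    "fc": Counter({"addmm": 10000}),
    "conv": Counter({"conv": 3000}),
    "act": Counter(),
}

# Summing each module's per-operator Counter gives the by_module() view.
by_module = Counter(
    {name: sum(ops.values()) for name, ops in by_module_and_operator.items()}
)

# The root module "" holds the whole-model per-operator view.
by_operator = by_module_and_operator[""]

# The total is the sum over operators of the root module.
total = sum(by_operator.values())

print(by_module[""], by_module["fc"], by_module["conv"], by_module["act"])
print(total)
```

A module that runs no counted operator, such as the ReLU here, still shows up with a count of 0.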
An operator is treated as within a module if it is executed inside the module's __call__ method. Note that this does not include calls to other methods of the module or explicit calls to module.forward(...).
Modified from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py
- Parameters
model (nn.Module) – The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]) – The input to the model.
- Return type
None
Examples
>>> import torch.nn as nn
>>> import torch
>>> from mmengine.analysis import FlopAnalyzer
>>> class TestModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(in_features=1000, out_features=10)
...         self.conv = nn.Conv2d(
...             in_channels=3, out_channels=10, kernel_size=1
...         )
...         self.act = nn.ReLU()
...     def forward(self, x):
...         return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({"addmm" : 10000, "conv" : 3000})
>>> flops.by_module()
Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
>>> flops.by_module_and_operator()
{"" : Counter({"addmm" : 10000, "conv" : 3000}),
 "fc" : Counter({"addmm" : 10000}),
 "conv" : Counter({"conv" : 3000}),
 "act" : Counter()
}