# ProArd/proard/utils/flops_counter.py
# Hugging Face Hub page residue: uploaded by smi08 via huggingface_hub
# ("Upload folder using huggingface_hub", commit 188f311, verified).
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch
import torch.nn as nn
from .my_modules import MyConv2d
__all__ = ["profile"]
def count_convNd(m, _, y):
    """Forward hook that records the multiply-accumulate count of a conv layer.

    The count is ``in_channels * kernel_volume * y.nelement() // groups`` and
    is stored on the module as ``m.total_ops`` (a one-element tensor). The
    bias term is not counted.

    Args:
        m: The convolution module; must expose ``in_channels``, ``groups``
            and ``weight``.
        _: The module's forward inputs (unused).
        y: The module's forward output; every element of it is counted,
            including any batch dimension it carries.
    """
    cin = m.in_channels

    # Multiply over every kernel dimension instead of assuming exactly two:
    # indexing dims 2 and 3 fails for a 3-D (Conv1d) weight and drops the
    # depth of a 5-D (Conv3d) weight. Dims 0 and 1 are the channel dims.
    kernel_ops = 1
    for kernel_dim in m.weight.size()[2:]:
        kernel_ops *= kernel_dim

    # Number of output positions the kernel is applied at (all dims of y).
    output_elements = y.nelement()

    # Each output element needs (cin / groups) * kernel_ops multiply-adds.
    total_ops = cin * output_elements * kernel_ops // m.groups
    m.total_ops = torch.zeros(1).fill_(total_ops)
def count_linear(m, _, __):
    """Forward hook that records the multiply count of a linear layer.

    Stores ``in_features * out_features`` on the module as ``m.total_ops``
    (a one-element tensor). The forward inputs and output are ignored, so
    the figure does not scale with the number of input rows, and the bias
    is not counted.
    """
    ops = torch.zeros(1)
    ops.fill_(m.in_features * m.out_features)
    m.total_ops = ops
# Default mapping from an exact module type to the forward hook that counts
# its operations. A value of None means the type is known but contributes no
# operations, so `profile` attaches no hook for it.
register_hooks = {
    # Convolutions of every rank share one counting hook.
    **dict.fromkeys((nn.Conv1d, nn.Conv2d, nn.Conv3d, MyConv2d), count_convNd),
    # Fully connected layers.
    nn.Linear: count_linear,
    # Layers deliberately left uncounted.
    **dict.fromkeys(
        (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.BatchNorm2d), None
    ),
}
def profile(model, input_size, custom_ops=None):
    """Estimate the operation count and parameter count of ``model``.

    A zero tensor of shape ``input_size`` is pushed through the model in eval
    mode under ``torch.no_grad()``. Forward hooks attached to the leaf modules
    write a per-module ``total_ops`` value, and the per-module values are then
    summed.

    Args:
        model: The ``nn.Module`` to profile.
        input_size: Shape passed to ``torch.zeros`` to build the dummy input.
        custom_ops: Optional mapping from module type to a forward hook
            ``fn(module, inputs, output)`` that sets ``module.total_ops``.
            Entries take precedence over the module-level ``register_hooks``
            table. Lookup is by exact type, not ``isinstance``.

    Returns:
        A ``(total_ops, total_params)`` tuple of Python floats.

    Note:
        The ``total_ops`` and ``total_params`` buffers registered on each leaf
        module are left in place after the call (and therefore show up in the
        model's ``state_dict``); existing callers may read them per layer.
    """
    handler_collection = []
    custom_ops = {} if custom_ops is None else custom_ops

    def add_hooks(m_):
        # Only leaf modules are counted; containers just delegate to children.
        if next(m_.children(), None) is not None:
            return

        m_.register_buffer("total_ops", torch.zeros(1))
        m_.register_buffer("total_params", torch.zeros(1))

        for p in m_.parameters():
            m_.total_params += torch.zeros(1).fill_(p.numel())

        # Exact-type dispatch: caller-supplied hooks win over the defaults.
        m_type = type(m_)
        fn = None
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]

        # A None entry means "known type, nothing to count": attach no hook.
        if fn is not None:
            handler_collection.append(m_.register_forward_hook(fn))

    # Determine the model's device before any counter buffers are added.
    # Fall back to existing buffers, then to CPU, so that a model without
    # parameters does not raise StopIteration.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    if reference is not None:
        original_device = reference.device
    else:
        original_device = torch.device("cpu")

    training = model.training
    model.eval()
    try:
        model.apply(add_hooks)

        x = torch.zeros(input_size).to(original_device)
        with torch.no_grad():
            model(x)

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if next(m.children(), None) is not None:  # skip non-leaf modules
                continue
            total_ops += m.total_ops
            total_params += m.total_params

        total_ops = total_ops.item()
        total_params = total_params.item()
    finally:
        # Always detach the hooks and restore the caller's train/eval mode,
        # even when the forward pass raises. The `.to` call also moves the
        # freshly registered CPU counter buffers onto the model's device.
        for handler in handler_collection:
            handler.remove()
        model.train(training).to(original_device)

    return total_ops, total_params