|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .my_modules import MyConv2d |
|
|
|
|
|
__all__ = ["profile"] |
|
|
|
|
|
|
|
|
def count_convNd(m, _, y):
    """Forward hook that records the multiply-accumulate count of a conv layer.

    The count is ``in_channels * y.nelement() * prod(kernel dims) // groups``
    and is stored on the module as ``m.total_ops`` (a one-element float
    tensor). Bias additions are not counted.

    Args:
        m: The convolution module. It must expose ``in_channels``, ``groups``
            and ``weight``; every weight dimension from index 2 onward is
            treated as a kernel extent.
        _: The forward inputs (unused).
        y: The forward output tensor; its element count is used as the number
            of output positions.
    """
    # Multiply every kernel dimension instead of assuming exactly two, so the
    # same hook works for 1-D, 2-D and 3-D convolutions.
    kernel_ops = 1
    for extent in m.weight.size()[2:]:
        kernel_ops *= extent

    output_elements = y.nelement()
    total_ops = m.in_channels * output_elements * kernel_ops // m.groups

    m.total_ops = torch.zeros(1).fill_(total_ops)
|
|
|
|
|
|
|
|
def count_linear(m, _, __):
    """Forward hook that stores ``in_features * out_features`` on the module.

    The product is written to ``m.total_ops`` as a one-element float tensor.
    The forward inputs and outputs (second and third arguments) are ignored,
    so the value does not depend on the batch size.
    """
    ops = torch.zeros(1)
    ops.fill_(m.in_features * m.out_features)
    m.total_ops = ops
|
|
|
|
|
|
|
|
# Default mapping from exact module type to the forward hook that counts its
# operations. A value of None means no hook is attached for that type, so it
# contributes zero ops.
register_hooks = {
    # Convolutions share one counter.
    **{
        conv_type: count_convNd
        for conv_type in (nn.Conv1d, nn.Conv2d, nn.Conv3d, MyConv2d)
    },
    # Fully connected layers.
    nn.Linear: count_linear,
    # Layers listed explicitly with no counter.
    **dict.fromkeys((nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.BatchNorm2d)),
}
|
|
|
|
|
|
|
|
def profile(model, input_size, custom_ops=None):
    """Run one forward pass on zeros and return ``(total_ops, total_params)``.

    Every leaf module (a module with no children) gets two buffers,
    ``total_ops`` and ``total_params``. A counting hook is attached to a leaf
    when its exact type is a key in ``custom_ops`` or, failing that, in the
    module-level ``register_hooks`` table. The model is put in eval mode and
    called once under ``torch.no_grad()`` on ``torch.zeros(input_size)``.
    Afterwards the hooks are removed and the previous train/eval mode is
    restored, even if the forward pass raises.

    Note: the ``total_ops`` / ``total_params`` buffers stay registered on the
    leaf modules after this function returns.

    Args:
        model: The ``nn.Module`` to profile.
        input_size: Shape passed to ``torch.zeros`` to build the dummy input.
        custom_ops: Optional mapping from module type to a forward hook
            ``fn(module, inputs, output)``. It takes precedence over
            ``register_hooks``; a value of ``None`` disables counting for
            that type.

    Returns:
        A ``(total_ops, total_params)`` tuple of Python floats summed over all
        leaf modules.
    """
    custom_ops = {} if custom_ops is None else custom_ops
    handler_collection = []

    def add_hooks(m_):
        # Only leaf modules are counted; containers are skipped so nothing is
        # counted twice.
        if next(m_.children(), None) is not None:
            return

        n_params = sum(p.numel() for p in m_.parameters())
        m_.register_buffer("total_ops", torch.zeros(1))
        m_.register_buffer("total_params", torch.zeros(1).fill_(n_params))

        # Exact-type lookup (not isinstance): subclasses need their own entry.
        m_type = type(m_)
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        else:
            fn = register_hooks.get(m_type)

        if fn is not None:
            handler_collection.append(m_.register_forward_hook(fn))

    # Pick the device from the first parameter, then the first buffer, and
    # fall back to CPU so parameter-free models do not raise StopIteration.
    anchor = next(model.parameters(), None)
    if anchor is None:
        anchor = next(model.buffers(), None)
    original_device = anchor.device if anchor is not None else torch.device("cpu")
    training = model.training

    try:
        model.eval()
        model.apply(add_hooks)

        x = torch.zeros(input_size).to(original_device)
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hooks and restore the caller's mode and device,
        # including when the forward pass or a hook fails.
        for handler in handler_collection:
            handler.remove()
        model.train(training).to(original_device)

    # Sum as Python floats: accumulating in a float32 tensor loses integer
    # precision once the running total exceeds 2**24.
    total_ops = 0.0
    total_params = 0.0
    for m in model.modules():
        if next(m.children(), None) is not None:
            continue
        total_ops += m.total_ops.item()
        total_params += m.total_params.item()

    return total_ops, total_params
|
|
|