# Once for All: Train One Network and Specialize it for Efficient Deployment # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han # International Conference on Learning Representations (ICLR), 2020. import torch import torch.nn as nn from .my_modules import MyConv2d __all__ = ["profile"] def count_convNd(m, _, y): cin = m.in_channels kernel_ops = m.weight.size()[2] * m.weight.size()[3] ops_per_element = kernel_ops output_elements = y.nelement() # cout x oW x oH total_ops = cin * output_elements * ops_per_element // m.groups m.total_ops = torch.zeros(1).fill_(total_ops) def count_linear(m, _, __): total_ops = m.in_features * m.out_features m.total_ops = torch.zeros(1).fill_(total_ops) register_hooks = { nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd, MyConv2d: count_convNd, ###################################### nn.Linear: count_linear, ###################################### nn.Dropout: None, nn.Dropout2d: None, nn.Dropout3d: None, nn.BatchNorm2d: None, } def profile(model, input_size, custom_ops=None): handler_collection = [] custom_ops = {} if custom_ops is None else custom_ops def add_hooks(m_): if len(list(m_.children())) > 0: return m_.register_buffer("total_ops", torch.zeros(1)) m_.register_buffer("total_params", torch.zeros(1)) for p in m_.parameters(): m_.total_params += torch.zeros(1).fill_(p.numel()) m_type = type(m_) fn = None if m_type in custom_ops: fn = custom_ops[m_type] elif m_type in register_hooks: fn = register_hooks[m_type] if fn is not None: _handler = m_.register_forward_hook(fn) handler_collection.append(_handler) original_device = model.parameters().__next__().device training = model.training model.eval() model.apply(add_hooks) x = torch.zeros(input_size).to(original_device) with torch.no_grad(): model(x) total_ops = 0 total_params = 0 for m in model.modules(): if len(list(m.children())) > 0: # skip for non-leaf module continue total_ops += m.total_ops total_params += m.total_params total_ops = total_ops.item() total_params = total_params.item() model.train(training).to(original_device) for handler in handler_collection: handler.remove() return total_ops, total_params