# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .my_modules import MyNetwork

__all__ = [
    "make_divisible",
    "build_activation",
    "ShuffleLayer",
    "MyGlobalAvgPool2d",
    "Hswish",
    "Hsigmoid",
    "SEModule",
    "MultiHeadCrossEntropyLoss",
]


def make_divisible(v, divisor, min_val=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the given divisor
    (8 in the original repo). It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: original channel number
    :param divisor: the value the result must be divisible by
    :param min_val: lower bound on the result (defaults to the divisor)
    :return: the adjusted channel number
    """
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
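
# For example, make_divisible(30, 8) == 32 and make_divisible(37, 8) == 40;
# the rounded value is never more than 10% below the original.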


def build_activation(act_func, inplace=True):
    if act_func == "relu":
        return nn.ReLU(inplace=inplace)
    elif act_func == "relu6":
        return nn.ReLU6(inplace=inplace)
    elif act_func == "tanh":
        return nn.Tanh()
    elif act_func == "sigmoid":
        return nn.Sigmoid()
    elif act_func == "h_swish":
        return Hswish(inplace=inplace)
    elif act_func == "h_sigmoid":
        return Hsigmoid(inplace=inplace)
    elif act_func is None or act_func == "none":
        return None
    else:
        raise ValueError("unsupported activation function: %s" % act_func)
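
# For example, build_activation("h_swish") returns an Hswish module, while
# build_activation("none") returns None so callers can skip the activation entirely.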


class ShuffleLayer(nn.Module):
    def __init__(self, groups):
        super(ShuffleLayer, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        # reshape
        x = x.view(batch_size, self.groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten
        x = x.view(batch_size, -1, height, width)
        return x

    def __repr__(self):
        return "ShuffleLayer(groups=%d)" % self.groups


class MyGlobalAvgPool2d(nn.Module):
    def __init__(self, keep_dim=True):
        super(MyGlobalAvgPool2d, self).__init__()
        self.keep_dim = keep_dim

    def forward(self, x):
        return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

    def __repr__(self):
        return "MyGlobalAvgPool2d(keep_dim=%s)" % self.keep_dim


class Hswish(nn.Module):
    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hswish()"


class Hsigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hsigmoid()"


class SEModule(nn.Module):
    REDUCTION = 4

    def __init__(self, channel, reduction=None):
        super(SEModule, self).__init__()
        self.channel = channel
        self.reduction = SEModule.REDUCTION if reduction is None else reduction
        num_mid = make_divisible(
            self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )
        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
                    ("h_sigmoid", Hsigmoid(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        y = self.fc(y)
        return x * y

    def __repr__(self):
        return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)


class MultiHeadCrossEntropyLoss(nn.Module):
    def forward(self, outputs, targets):
        assert outputs.dim() == 3, outputs
        assert targets.dim() == 2, targets
        assert outputs.size(1) == targets.size(1), (outputs, targets)

        num_heads = targets.size(1)
        loss = 0
        for k in range(num_heads):
            loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
        return loss
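
# MultiHeadCrossEntropyLoss expects outputs of shape (batch, num_heads, num_classes)
# and targets of shape (batch, num_heads); the result is the mean of the per-head
# cross-entropy losses.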