Upload folder using huggingface_hub
- proard/__init__.py +0 -0
- proard/classification/__init__.py +0 -0
- proard/classification/data_providers/__init__.py +3 -0
- proard/classification/data_providers/base_provider.py +58 -0
- proard/classification/data_providers/cifar10.py +264 -0
- proard/classification/data_providers/cifar100.py +264 -0
- proard/classification/data_providers/imagenet.py +310 -0
- proard/classification/elastic_nn/__init__.py +0 -0
- proard/classification/elastic_nn/modules/__init__.py +6 -0
- proard/classification/elastic_nn/modules/dynamic_layers.py +841 -0
- proard/classification/elastic_nn/modules/dynamic_op.py +401 -0
- proard/classification/elastic_nn/networks/__init__.py +7 -0
- proard/classification/elastic_nn/networks/dyn_mbv3.py +780 -0
- proard/classification/elastic_nn/networks/dyn_proxyless.py +774 -0
- proard/classification/elastic_nn/networks/dyn_resnets.py +678 -0
- proard/classification/elastic_nn/training/__init__.py +6 -0
- proard/classification/elastic_nn/training/progressive_shrinking.py +463 -0
- proard/classification/elastic_nn/utils.py +83 -0
- proard/classification/networks/__init__.py +25 -0
- proard/classification/networks/mobilenet_v3.py +559 -0
- proard/classification/networks/proxyless_nets.py +490 -0
- proard/classification/networks/resnet_trades.py +115 -0
- proard/classification/networks/resnets.py +490 -0
- proard/classification/networks/wide_resnet.py +93 -0
- proard/classification/run_manager/__init__.py +7 -0
- proard/classification/run_manager/distributed_run_manager.py +505 -0
- proard/classification/run_manager/run_config.py +414 -0
- proard/classification/run_manager/run_manager.py +484 -0
- proard/model_zoo.py +162 -0
- proard/nas/__init__.py +0 -0
- proard/nas/accuracy_predictor/__init__.py +11 -0
- proard/nas/accuracy_predictor/acc_dataset.py +213 -0
- proard/nas/accuracy_predictor/acc_predictor.py +68 -0
- proard/nas/accuracy_predictor/acc_rob_dataset.py +219 -0
- proard/nas/accuracy_predictor/acc_rob_predictor.py +77 -0
- proard/nas/accuracy_predictor/arch_encoder.py +372 -0
- proard/nas/accuracy_predictor/rob_dataset.py +211 -0
- proard/nas/accuracy_predictor/rob_predictor.py +66 -0
- proard/nas/efficiency_predictor/__init__.py +78 -0
- proard/nas/efficiency_predictor/latency_lookup_table.py +567 -0
- proard/nas/search_algorithm/__init__.py +6 -0
- proard/nas/search_algorithm/evolution.py +143 -0
- proard/nas/search_algorithm/multi_evolution.py +143 -0
- proard/utils/__init__.py +10 -0
- proard/utils/common_tools.py +307 -0
- proard/utils/flops_counter.py +97 -0
- proard/utils/layers.py +819 -0
- proard/utils/my_dataloader/__init__.py +2 -0
- proard/utils/my_dataloader/my_data_loader.py +1050 -0
- proard/utils/my_dataloader/my_data_worker.py +242 -0
proard/__init__.py
ADDED
File without changes
proard/classification/__init__.py
ADDED
File without changes
proard/classification/data_providers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .cifar10 import *
from .cifar100 import *
from .imagenet import *
proard/classification/data_providers/base_provider.py
ADDED
@@ -0,0 +1,58 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import torch

__all__ = ["DataProvider"]


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """Return name of the dataset"""
        raise NotImplementedError

    @property
    def data_shape(self):
        """Return shape as python list of one data entry"""
        raise NotImplementedError

    @property
    def n_classes(self):
        """Return `int` of num classes"""
        raise NotImplementedError

    @property
    def save_path(self):
        """local path to save the data"""
        raise NotImplementedError

    @property
    def data_url(self):
        """link to download the data"""
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        # set random seed before sampling validation set
        g.manual_seed(DataProvider.VALID_SEED)
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
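To make the two static helpers above concrete, here is a minimal usage sketch. It is not part of this commit; the import path is assumed from the file listing above, and everything else is plain numpy/torch:

import numpy as np

from proard.classification.data_providers.base_provider import DataProvider

# Deterministically split 50,000 training samples into 49,000 train / 1,000 valid.
# Because VALID_SEED is fixed, every run (and every distributed worker) gets
# exactly the same split.
train_idx, valid_idx = DataProvider.random_sample_valid_set(train_size=50000, valid_size=1000)
assert len(train_idx) == 49000 and len(valid_idx) == 1000

# One-hot encode integer class labels for a 10-class problem.
labels = np.array([3, 0, 7])
one_hot = DataProvider.labels_to_one_hot(10, labels)
assert one_hot.shape == (3, 10) and one_hot[0, 3] == 1.0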
proard/classification/data_providers/cifar10.py
ADDED
@@ -0,0 +1,264 @@
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .base_provider import DataProvider
from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ["Cifar10DataProvider"]


class Cifar10DataProvider(DataProvider):
    DEFAULT_PATH = "./dataset/cifar10"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        valid_size=None,
        resize_scale=0.08,
        distort_color=None,
        n_worker=32,
        image_size=32,
        num_replicas=None,
        rank=None,
    ):
        warnings.filterwarnings("ignore")
        self._save_path = save_path

        self.image_size = image_size  # int or list of int

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            # randomly sample an image size for each batch of training images
            train_loader_class = MyDataLoader
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return "cifar10"

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/cifar10")
        return self._save_path

    @property
    def data_url(self):
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        return datasets.CIFAR10(self.train_path, train=True, transform=_transforms, download=True)

    def test_dataset(self, _transforms):
        return datasets.CIFAR10(self.valid_path, train=False, transform=_transforms, download=True)

    @property
    def train_path(self):
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size

        # random_crop -> random_horizontal_flip
        train_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # AutoAugment(),
        ]

        train_transforms += [
            transforms.ToTensor(),
            # self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.ToTensor(),
            # self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        # used for resetting BN running statistics
        if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False)
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes)
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            self.__dict__["sub_train_%d" % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__["sub_train_%d" % self.active_img_size].append((images, labels))
        return self.__dict__["sub_train_%d" % self.active_img_size]
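A hypothetical usage sketch of this provider (not part of this commit; parameter values are illustrative): build it with a held-out validation split, pull one training batch, then grab the small cached loader that progressive shrinking uses for resetting BN statistics.

from proard.classification.data_providers.cifar10 import Cifar10DataProvider

provider = Cifar10DataProvider(
    save_path="./dataset/cifar10",
    train_batch_size=256,
    test_batch_size=512,
    valid_size=5000,   # carve 5,000 images out of the 50,000-image train set
    n_worker=4,
)
print(provider.n_classes, provider.data_shape)   # 10 (3, 32, 32)

images, labels = next(iter(provider.train))      # one training batch

# 2,000 images reloaded with train-time transforms and cached in memory,
# intended for recalibrating BatchNorm running statistics of a subnet.
sub_train = provider.build_sub_train_loader(n_images=2000, batch_size=200)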
proard/classification/data_providers/cifar100.py
ADDED
@@ -0,0 +1,264 @@
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .base_provider import DataProvider
from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ["Cifar100DataProvider"]


class Cifar100DataProvider(DataProvider):
    DEFAULT_PATH = "./dataset/cifar100"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        resize_scale=0.08,
        distort_color=None,
        valid_size=None,
        n_worker=32,
        image_size=32,
        num_replicas=None,
        rank=None,
    ):
        warnings.filterwarnings("ignore")
        self._save_path = save_path

        self.image_size = image_size  # int or list of int

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            # randomly sample an image size for each batch of training images
            train_loader_class = MyDataLoader
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return "cifar100"

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/cifar100")
        return self._save_path

    @property
    def data_url(self):
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        return datasets.CIFAR100(self.train_path, train=True, transform=_transforms, download=True)

    def test_dataset(self, _transforms):
        return datasets.CIFAR100(self.valid_path, train=False, transform=_transforms, download=True)

    @property
    def train_path(self):
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        # note: these are the CIFAR-10 channel statistics
        return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size

        # random_crop -> random_horizontal_flip
        train_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # AutoAugment(),
        ]

        train_transforms += [
            transforms.ToTensor(),
            # self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.ToTensor(),
            # self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        # used for resetting BN running statistics
        if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False)
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes)
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            self.__dict__["sub_train_%d" % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__["sub_train_%d" % self.active_img_size].append((images, labels))
        return self.__dict__["sub_train_%d" % self.active_img_size]
proard/classification/data_providers/imagenet.py
ADDED
@@ -0,0 +1,310 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .base_provider import DataProvider
from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ["ImagenetDataProvider"]


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = "./dataset/imagenet"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        valid_size=None,
        n_worker=32,
        resize_scale=0.08,
        distort_color=None,
        image_size=224,
        num_replicas=None,
        rank=None,
    ):
        warnings.filterwarnings("ignore")
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = "None" if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            # randomly sample an image size for each batch of training images
            train_loader_class = MyDataLoader
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return "imagenet"

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/imagenet")
        return self._save_path

    @property
    def data_url(self):
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print(
                "Color jitter: %s, resize_scale: %s, img_size: %s"
                % (self.distort_color, self.resize_scale, image_size)
            )

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print(
                "Use MyRandomResizedCrop: %s, \t %s"
                % MyRandomResizedCrop.get_candidate_image_size(),
                "sync=%s, continuous=%s"
                % (
                    MyRandomResizedCrop.SYNC_DISTRIBUTED,
                    MyRandomResizedCrop.CONTINUOUS,
                ),
            )
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == "torch":
            color_transform = transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
            )
        elif self.distort_color == "tf":
            color_transform = transforms.ColorJitter(
                brightness=32.0 / 255.0, saturation=0.5
            )
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose(
            [
                transforms.Resize(int(math.ceil(image_size / 0.875))),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        # used for resetting BN running statistics
        if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False)
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes)
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            self.__dict__["sub_train_%d" % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__["sub_train_%d" % self.active_img_size].append((images, labels))
        return self.__dict__["sub_train_%d" % self.active_img_size]
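A hypothetical sketch of the elastic-resolution path (not part of this commit; paths and sizes are illustrative). Passing a list as image_size makes the provider use MyDataLoader, which samples a random resolution from the list for each training batch, while assign_active_img_size switches the valid/test transforms to a new resolution:

from proard.classification.data_providers.imagenet import ImagenetDataProvider

provider = ImagenetDataProvider(
    save_path="./dataset/imagenet",   # expects train/ and val/ in ImageFolder layout
    image_size=[128, 160, 192, 224],  # elastic training resolutions
    valid_size=10000,
    n_worker=8,
)

# Evaluate at a lower test resolution: this swaps the valid/test transforms to
# Resize(ceil(160 / 0.875)) + CenterCrop(160), as built by build_valid_transform.
provider.assign_active_img_size(160)
print(provider.data_shape)   # (3, 160, 160)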
proard/classification/elastic_nn/__init__.py
ADDED
File without changes
proard/classification/elastic_nn/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .dynamic_layers import *
from .dynamic_op import *
proard/classification/elastic_nn/modules/dynamic_layers.py
ADDED
@@ -0,0 +1,841 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch
import torch.nn as nn
from collections import OrderedDict

from proard.utils.layers import (
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    set_layer_from_config,
)
from proard.utils.layers import ResNetBottleneckBlock, LinearLayer
from proard.utils import (
    MyModule,
    val2list,
    get_net_device,
    build_activation,
    make_divisible,
    SEModule,
    MyNetwork,
)
from .dynamic_op import (
    DynamicSeparableConv2d,
    DynamicConv2d,
    DynamicBatchNorm2d,
    DynamicSE,
    DynamicGroupNorm,
)
from .dynamic_op import DynamicLinear

__all__ = [
    "adjust_bn_according_to_idx",
    "copy_bn",
    "DynamicMBConvLayer",
    "DynamicConvLayer",
    "DynamicLinearLayer",
    "DynamicResNetBottleneckBlock",
]


def adjust_bn_according_to_idx(bn, idx):
    bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
    bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
    if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
        bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


def copy_bn(target_bn, src_bn):
    feature_dim = (
        target_bn.num_channels
        if isinstance(target_bn, nn.GroupNorm)
        else target_bn.num_features
    )

    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
    if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
        target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])


class DynamicLinearLayer(MyModule):
    def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
        super(DynamicLinearLayer, self).__init__()

        self.in_features_list = in_features_list
        self.out_features = out_features
        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None
        self.linear = DynamicLinear(
            max_in_features=max(self.in_features_list),
            max_out_features=self.out_features,
            bias=self.bias,
        )

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return self.linear(x)

    @property
    def module_str(self):
        return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features)

    @property
    def config(self):
        return {
            "name": DynamicLinear.__name__,
            "in_features_list": self.in_features_list,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicLinearLayer(**config)

    def get_active_subnet(self, in_features, preserve_weight=True):
        sub_layer = LinearLayer(
            in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate
        )
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        sub_layer.linear.weight.data.copy_(
            self.linear.get_active_weight(self.out_features, in_features).data
        )
        if self.bias:
            sub_layer.linear.bias.data.copy_(
                self.linear.get_active_bias(self.out_features).data
            )
        return sub_layer

    def get_active_subnet_config(self, in_features):
        return {
            "name": LinearLayer.__name__,
            "in_features": in_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }


class DynamicMBConvLayer(MyModule):
    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        kernel_size_list=3,
        expand_ratio_list=6,
        stride=1,
        act_func="relu6",
        use_se=False,
    ):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list)
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.in_channel_list) * max(self.expand_ratio_list)),
            MyNetwork.CHANNEL_DIVISIBLE,
        )
        if max(self.expand_ratio_list) == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list), max_middle_channel
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max_middle_channel)),
                        ("act", build_activation(self.act_func)),
                    ]
                )
            )

        self.depth_conv = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicSeparableConv2d(
                            max_middle_channel, self.kernel_size_list, self.stride
                        ),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func)),
                ]
            )
        )
        if self.use_se:
            self.depth_conv.add_module("se", DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
                    ),
                    ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                ]
            )
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = make_divisible(
                round(in_channel * self.active_expand_ratio),
                MyNetwork.CHANNEL_DIVISIBLE,
            )

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return "SE(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )
        else:
            return "(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )

    @property
    def config(self):
        return {
            "name": DynamicMBConvLayer.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "kernel_size_list": self.kernel_size_list,
            "expand_ratio_list": self.expand_ratio_list,
            "stride": self.stride,
            "act_func": self.act_func,
            "use_se": self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    def active_middle_channel(self, in_channel):
        return make_divisible(
            round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE
        )

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        middle_channel = self.active_middle_channel(in_channel)
        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.get_active_filter(
                    middle_channel, in_channel
                ).data,
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(
                middle_channel, self.active_kernel_size
            ).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(
                middle_channel // SEModule.REDUCTION,
                divisor=MyNetwork.CHANNEL_DIVISIBLE,
            )
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
                self.depth_conv.se.get_active_reduce_bias(se_mid).data
            )

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
                self.depth_conv.se.get_active_expand_bias(middle_channel).data
            )

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.get_active_filter(
                self.active_out_channel, middle_channel
            ).data
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": MBConvLayer.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.active_kernel_size,
            "stride": self.stride,
            "expand_ratio": self.active_expand_ratio,
|
| 342 |
+
"mid_channels": self.active_middle_channel(in_channel),
|
| 343 |
+
"act_func": self.act_func,
|
| 344 |
+
"use_se": self.use_se,
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
def re_organize_middle_weights(self, expand_ratio_stage=0):
|
| 348 |
+
importance = torch.sum(
|
| 349 |
+
torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)
|
| 350 |
+
)
|
| 351 |
+
if isinstance(self.depth_conv.bn, DynamicGroupNorm):
|
| 352 |
+
channel_per_group = self.depth_conv.bn.channel_per_group
|
| 353 |
+
importance_chunks = torch.split(importance, channel_per_group)
|
| 354 |
+
for chunk in importance_chunks:
|
| 355 |
+
chunk.data.fill_(torch.mean(chunk))
|
| 356 |
+
importance = torch.cat(importance_chunks, dim=0)
|
| 357 |
+
if expand_ratio_stage > 0:
|
| 358 |
+
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
|
| 359 |
+
sorted_expand_list.sort(reverse=True)
|
| 360 |
+
target_width_list = [
|
| 361 |
+
make_divisible(
|
| 362 |
+
round(max(self.in_channel_list) * expand),
|
| 363 |
+
MyNetwork.CHANNEL_DIVISIBLE,
|
| 364 |
+
)
|
| 365 |
+
for expand in sorted_expand_list
|
| 366 |
+
]
|
| 367 |
+
|
| 368 |
+
right = len(importance)
|
| 369 |
+
base = -len(target_width_list) * 1e5
|
| 370 |
+
for i in range(expand_ratio_stage + 1):
|
| 371 |
+
left = target_width_list[i]
|
| 372 |
+
importance[left:right] += base
|
| 373 |
+
base += 1e5
|
| 374 |
+
right = left
|
| 375 |
+
|
| 376 |
+
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
|
| 377 |
+
self.point_linear.conv.conv.weight.data = torch.index_select(
|
| 378 |
+
self.point_linear.conv.conv.weight.data, 1, sorted_idx
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
|
| 382 |
+
self.depth_conv.conv.conv.weight.data = torch.index_select(
|
| 383 |
+
self.depth_conv.conv.conv.weight.data, 0, sorted_idx
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if self.use_se:
|
| 387 |
+
# se expand: output dim 0 reorganize
|
| 388 |
+
se_expand = self.depth_conv.se.fc.expand
|
| 389 |
+
se_expand.weight.data = torch.index_select(
|
| 390 |
+
se_expand.weight.data, 0, sorted_idx
|
| 391 |
+
)
|
| 392 |
+
se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
|
| 393 |
+
# se reduce: input dim 1 reorganize
|
| 394 |
+
se_reduce = self.depth_conv.se.fc.reduce
|
| 395 |
+
se_reduce.weight.data = torch.index_select(
|
| 396 |
+
se_reduce.weight.data, 1, sorted_idx
|
| 397 |
+
)
|
| 398 |
+
# middle weight reorganize
|
| 399 |
+
se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
|
| 400 |
+
se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)
|
| 401 |
+
|
| 402 |
+
se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
|
| 403 |
+
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
|
| 404 |
+
se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
|
| 405 |
+
|
| 406 |
+
if self.inverted_bottleneck is not None:
|
| 407 |
+
adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
|
| 408 |
+
self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
|
| 409 |
+
self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
|
| 410 |
+
)
|
| 411 |
+
return None
|
| 412 |
+
else:
|
| 413 |
+
return sorted_idx
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class DynamicConvLayer(MyModule):
|
| 417 |
+
def __init__(
|
| 418 |
+
self,
|
| 419 |
+
in_channel_list,
|
| 420 |
+
out_channel_list,
|
| 421 |
+
kernel_size=3,
|
| 422 |
+
stride=1,
|
| 423 |
+
dilation=1,
|
| 424 |
+
use_bn=True,
|
| 425 |
+
act_func="relu6",
|
| 426 |
+
):
|
| 427 |
+
super(DynamicConvLayer, self).__init__()
|
| 428 |
+
|
| 429 |
+
self.in_channel_list = in_channel_list
|
| 430 |
+
self.out_channel_list = out_channel_list
|
| 431 |
+
self.kernel_size = kernel_size
|
| 432 |
+
self.stride = stride
|
| 433 |
+
self.dilation = dilation
|
| 434 |
+
self.use_bn = use_bn
|
| 435 |
+
self.act_func = act_func
|
| 436 |
+
|
| 437 |
+
self.conv = DynamicConv2d(
|
| 438 |
+
max_in_channels=max(self.in_channel_list),
|
| 439 |
+
max_out_channels=max(self.out_channel_list),
|
| 440 |
+
kernel_size=self.kernel_size,
|
| 441 |
+
stride=self.stride,
|
| 442 |
+
dilation=self.dilation,
|
| 443 |
+
)
|
| 444 |
+
if self.use_bn:
|
| 445 |
+
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
|
| 446 |
+
self.act = build_activation(self.act_func)
|
| 447 |
+
|
| 448 |
+
self.active_out_channel = max(self.out_channel_list)
|
| 449 |
+
|
| 450 |
+
def forward(self, x):
|
| 451 |
+
self.conv.active_out_channel = self.active_out_channel
|
| 452 |
+
|
| 453 |
+
x = self.conv(x)
|
| 454 |
+
if self.use_bn:
|
| 455 |
+
x = self.bn(x)
|
| 456 |
+
x = self.act(x)
|
| 457 |
+
return x
|
| 458 |
+
|
| 459 |
+
@property
|
| 460 |
+
def module_str(self):
|
| 461 |
+
return "DyConv(O%d, K%d, S%d)" % (
|
| 462 |
+
self.active_out_channel,
|
| 463 |
+
self.kernel_size,
|
| 464 |
+
self.stride,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
@property
|
| 468 |
+
def config(self):
|
| 469 |
+
return {
|
| 470 |
+
"name": DynamicConvLayer.__name__,
|
| 471 |
+
"in_channel_list": self.in_channel_list,
|
| 472 |
+
"out_channel_list": self.out_channel_list,
|
| 473 |
+
"kernel_size": self.kernel_size,
|
| 474 |
+
"stride": self.stride,
|
| 475 |
+
"dilation": self.dilation,
|
| 476 |
+
"use_bn": self.use_bn,
|
| 477 |
+
"act_func": self.act_func,
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
@staticmethod
|
| 481 |
+
def build_from_config(config):
|
| 482 |
+
return DynamicConvLayer(**config)
|
| 483 |
+
|
| 484 |
+
############################################################################################
|
| 485 |
+
|
| 486 |
+
@property
|
| 487 |
+
def in_channels(self):
|
| 488 |
+
return max(self.in_channel_list)
|
| 489 |
+
|
| 490 |
+
@property
|
| 491 |
+
def out_channels(self):
|
| 492 |
+
return max(self.out_channel_list)
|
| 493 |
+
|
| 494 |
+
############################################################################################
|
| 495 |
+
|
| 496 |
+
def get_active_subnet(self, in_channel, preserve_weight=True):
|
| 497 |
+
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
|
| 498 |
+
sub_layer = sub_layer.to(get_net_device(self))
|
| 499 |
+
|
| 500 |
+
if not preserve_weight:
|
| 501 |
+
return sub_layer
|
| 502 |
+
|
| 503 |
+
sub_layer.conv.weight.data.copy_(
|
| 504 |
+
self.conv.get_active_filter(self.active_out_channel, in_channel).data
|
| 505 |
+
)
|
| 506 |
+
if self.use_bn:
|
| 507 |
+
copy_bn(sub_layer.bn, self.bn.bn)
|
| 508 |
+
|
| 509 |
+
return sub_layer
|
| 510 |
+
|
| 511 |
+
def get_active_subnet_config(self, in_channel):
|
| 512 |
+
return {
|
| 513 |
+
"name": ConvLayer.__name__,
|
| 514 |
+
"in_channels": in_channel,
|
| 515 |
+
"out_channels": self.active_out_channel,
|
| 516 |
+
"kernel_size": self.kernel_size,
|
| 517 |
+
"stride": self.stride,
|
| 518 |
+
"dilation": self.dilation,
|
| 519 |
+
"use_bn": self.use_bn,
|
| 520 |
+
"act_func": self.act_func,
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class DynamicResNetBottleneckBlock(MyModule):
|
| 525 |
+
def __init__(
|
| 526 |
+
self,
|
| 527 |
+
in_channel_list,
|
| 528 |
+
out_channel_list,
|
| 529 |
+
expand_ratio_list=0.25,
|
| 530 |
+
kernel_size=3,
|
| 531 |
+
stride=1,
|
| 532 |
+
act_func="relu",
|
| 533 |
+
downsample_mode="avgpool_conv",
|
| 534 |
+
):
|
| 535 |
+
super(DynamicResNetBottleneckBlock, self).__init__()
|
| 536 |
+
|
| 537 |
+
self.in_channel_list = in_channel_list
|
| 538 |
+
self.out_channel_list = out_channel_list
|
| 539 |
+
self.expand_ratio_list = val2list(expand_ratio_list)
|
| 540 |
+
|
| 541 |
+
self.kernel_size = kernel_size
|
| 542 |
+
self.stride = stride
|
| 543 |
+
self.act_func = act_func
|
| 544 |
+
self.downsample_mode = downsample_mode
|
| 545 |
+
|
| 546 |
+
# build modules
|
| 547 |
+
max_middle_channel = make_divisible(
|
| 548 |
+
round(max(self.out_channel_list) * max(self.expand_ratio_list)),
|
| 549 |
+
MyNetwork.CHANNEL_DIVISIBLE,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
self.conv1 = nn.Sequential(
|
| 553 |
+
OrderedDict(
|
| 554 |
+
[
|
| 555 |
+
(
|
| 556 |
+
"conv",
|
| 557 |
+
DynamicConv2d(max(self.in_channel_list), max_middle_channel),
|
| 558 |
+
),
|
| 559 |
+
("bn", DynamicBatchNorm2d(max_middle_channel)),
|
| 560 |
+
("act", build_activation(self.act_func, inplace=True)),
|
| 561 |
+
]
|
| 562 |
+
)
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
self.conv2 = nn.Sequential(
|
| 566 |
+
OrderedDict(
|
| 567 |
+
[
|
| 568 |
+
(
|
| 569 |
+
"conv",
|
| 570 |
+
DynamicConv2d(
|
| 571 |
+
max_middle_channel, max_middle_channel, kernel_size, stride
|
| 572 |
+
),
|
| 573 |
+
),
|
| 574 |
+
("bn", DynamicBatchNorm2d(max_middle_channel)),
|
| 575 |
+
("act", build_activation(self.act_func, inplace=True)),
|
| 576 |
+
]
|
| 577 |
+
)
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
self.conv3 = nn.Sequential(
|
| 581 |
+
OrderedDict(
|
| 582 |
+
[
|
| 583 |
+
(
|
| 584 |
+
"conv",
|
| 585 |
+
DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
|
| 586 |
+
),
|
| 587 |
+
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
|
| 588 |
+
]
|
| 589 |
+
)
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
if self.stride == 1 and self.in_channel_list == self.out_channel_list:
|
| 593 |
+
self.downsample = IdentityLayer(
|
| 594 |
+
max(self.in_channel_list), max(self.out_channel_list)
|
| 595 |
+
)
|
| 596 |
+
elif self.downsample_mode == "conv":
|
| 597 |
+
self.downsample = nn.Sequential(
|
| 598 |
+
OrderedDict(
|
| 599 |
+
[
|
| 600 |
+
(
|
| 601 |
+
"conv",
|
| 602 |
+
DynamicConv2d(
|
| 603 |
+
max(self.in_channel_list),
|
| 604 |
+
max(self.out_channel_list),
|
| 605 |
+
stride=stride,
|
| 606 |
+
),
|
| 607 |
+
),
|
| 608 |
+
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
|
| 609 |
+
]
|
| 610 |
+
)
|
| 611 |
+
)
|
| 612 |
+
elif self.downsample_mode == "avgpool_conv":
|
| 613 |
+
self.downsample = nn.Sequential(
|
| 614 |
+
OrderedDict(
|
| 615 |
+
[
|
| 616 |
+
(
|
| 617 |
+
"avg_pool",
|
| 618 |
+
nn.AvgPool2d(
|
| 619 |
+
kernel_size=stride,
|
| 620 |
+
stride=stride,
|
| 621 |
+
padding=0,
|
| 622 |
+
ceil_mode=True,
|
| 623 |
+
),
|
| 624 |
+
),
|
| 625 |
+
(
|
| 626 |
+
"conv",
|
| 627 |
+
DynamicConv2d(
|
| 628 |
+
max(self.in_channel_list), max(self.out_channel_list)
|
| 629 |
+
),
|
| 630 |
+
),
|
| 631 |
+
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
|
| 632 |
+
]
|
| 633 |
+
)
|
| 634 |
+
)
|
| 635 |
+
else:
|
| 636 |
+
raise NotImplementedError
|
| 637 |
+
|
| 638 |
+
self.final_act = build_activation(self.act_func, inplace=True)
|
| 639 |
+
|
| 640 |
+
self.active_expand_ratio = max(self.expand_ratio_list)
|
| 641 |
+
self.active_out_channel = max(self.out_channel_list)
|
| 642 |
+
|
| 643 |
+
def forward(self, x):
|
| 644 |
+
feature_dim = self.active_middle_channels
|
| 645 |
+
|
| 646 |
+
self.conv1.conv.active_out_channel = feature_dim
|
| 647 |
+
self.conv2.conv.active_out_channel = feature_dim
|
| 648 |
+
self.conv3.conv.active_out_channel = self.active_out_channel
|
| 649 |
+
if not isinstance(self.downsample, IdentityLayer):
|
| 650 |
+
self.downsample.conv.active_out_channel = self.active_out_channel
|
| 651 |
+
|
| 652 |
+
residual = self.downsample(x)
|
| 653 |
+
|
| 654 |
+
x = self.conv1(x)
|
| 655 |
+
x = self.conv2(x)
|
| 656 |
+
x = self.conv3(x)
|
| 657 |
+
|
| 658 |
+
x = x + residual
|
| 659 |
+
x = self.final_act(x)
|
| 660 |
+
return x
|
| 661 |
+
|
| 662 |
+
@property
|
| 663 |
+
def module_str(self):
|
| 664 |
+
return "(%s, %s)" % (
|
| 665 |
+
"%dx%d_BottleneckConv_in->%d->%d_S%d"
|
| 666 |
+
% (
|
| 667 |
+
self.kernel_size,
|
| 668 |
+
self.kernel_size,
|
| 669 |
+
self.active_middle_channels,
|
| 670 |
+
self.active_out_channel,
|
| 671 |
+
self.stride,
|
| 672 |
+
),
|
| 673 |
+
"Identity"
|
| 674 |
+
if isinstance(self.downsample, IdentityLayer)
|
| 675 |
+
else self.downsample_mode,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
@property
|
| 679 |
+
def config(self):
|
| 680 |
+
return {
|
| 681 |
+
"name": DynamicResNetBottleneckBlock.__name__,
|
| 682 |
+
"in_channel_list": self.in_channel_list,
|
| 683 |
+
"out_channel_list": self.out_channel_list,
|
| 684 |
+
"expand_ratio_list": self.expand_ratio_list,
|
| 685 |
+
"kernel_size": self.kernel_size,
|
| 686 |
+
"stride": self.stride,
|
| 687 |
+
"act_func": self.act_func,
|
| 688 |
+
"downsample_mode": self.downsample_mode,
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
@staticmethod
|
| 692 |
+
def build_from_config(config):
|
| 693 |
+
return DynamicResNetBottleneckBlock(**config)
|
| 694 |
+
|
| 695 |
+
############################################################################################
|
| 696 |
+
|
| 697 |
+
@property
|
| 698 |
+
def in_channels(self):
|
| 699 |
+
return max(self.in_channel_list)
|
| 700 |
+
|
| 701 |
+
@property
|
| 702 |
+
def out_channels(self):
|
| 703 |
+
return max(self.out_channel_list)
|
| 704 |
+
|
| 705 |
+
@property
|
| 706 |
+
def active_middle_channels(self):
|
| 707 |
+
feature_dim = round(self.active_out_channel * self.active_expand_ratio)
|
| 708 |
+
feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
|
| 709 |
+
return feature_dim
|
| 710 |
+
|
| 711 |
+
############################################################################################
|
| 712 |
+
|
| 713 |
+
def get_active_subnet(self, in_channel, preserve_weight=True):
|
| 714 |
+
# build the new layer
|
| 715 |
+
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
|
| 716 |
+
sub_layer = sub_layer.to(get_net_device(self))
|
| 717 |
+
if not preserve_weight:
|
| 718 |
+
return sub_layer
|
| 719 |
+
|
| 720 |
+
# copy weight from current layer
|
| 721 |
+
sub_layer.conv1.conv.weight.data.copy_(
|
| 722 |
+
self.conv1.conv.get_active_filter(
|
| 723 |
+
self.active_middle_channels, in_channel
|
| 724 |
+
).data
|
| 725 |
+
)
|
| 726 |
+
copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)
|
| 727 |
+
|
| 728 |
+
sub_layer.conv2.conv.weight.data.copy_(
|
| 729 |
+
self.conv2.conv.get_active_filter(
|
| 730 |
+
self.active_middle_channels, self.active_middle_channels
|
| 731 |
+
).data
|
| 732 |
+
)
|
| 733 |
+
copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)
|
| 734 |
+
|
| 735 |
+
sub_layer.conv3.conv.weight.data.copy_(
|
| 736 |
+
self.conv3.conv.get_active_filter(
|
| 737 |
+
self.active_out_channel, self.active_middle_channels
|
| 738 |
+
).data
|
| 739 |
+
)
|
| 740 |
+
copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)
|
| 741 |
+
|
| 742 |
+
if not isinstance(self.downsample, IdentityLayer):
|
| 743 |
+
sub_layer.downsample.conv.weight.data.copy_(
|
| 744 |
+
self.downsample.conv.get_active_filter(
|
| 745 |
+
self.active_out_channel, in_channel
|
| 746 |
+
).data
|
| 747 |
+
)
|
| 748 |
+
copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)
|
| 749 |
+
|
| 750 |
+
return sub_layer
|
| 751 |
+
|
| 752 |
+
def get_active_subnet_config(self, in_channel):
|
| 753 |
+
return {
|
| 754 |
+
"name": ResNetBottleneckBlock.__name__,
|
| 755 |
+
"in_channels": in_channel,
|
| 756 |
+
"out_channels": self.active_out_channel,
|
| 757 |
+
"kernel_size": self.kernel_size,
|
| 758 |
+
"stride": self.stride,
|
| 759 |
+
"expand_ratio": self.active_expand_ratio,
|
| 760 |
+
"mid_channels": self.active_middle_channels,
|
| 761 |
+
"act_func": self.act_func,
|
| 762 |
+
"groups": 1,
|
| 763 |
+
"downsample_mode": self.downsample_mode,
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
def re_organize_middle_weights(self, expand_ratio_stage=0):
|
| 767 |
+
# conv3 -> conv2
|
| 768 |
+
importance = torch.sum(
|
| 769 |
+
torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3)
|
| 770 |
+
)
|
| 771 |
+
if isinstance(self.conv2.bn, DynamicGroupNorm):
|
| 772 |
+
channel_per_group = self.conv2.bn.channel_per_group
|
| 773 |
+
importance_chunks = torch.split(importance, channel_per_group)
|
| 774 |
+
for chunk in importance_chunks:
|
| 775 |
+
chunk.data.fill_(torch.mean(chunk))
|
| 776 |
+
importance = torch.cat(importance_chunks, dim=0)
|
| 777 |
+
if expand_ratio_stage > 0:
|
| 778 |
+
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
|
| 779 |
+
sorted_expand_list.sort(reverse=True)
|
| 780 |
+
target_width_list = [
|
| 781 |
+
make_divisible(
|
| 782 |
+
round(max(self.out_channel_list) * expand),
|
| 783 |
+
MyNetwork.CHANNEL_DIVISIBLE,
|
| 784 |
+
)
|
| 785 |
+
for expand in sorted_expand_list
|
| 786 |
+
]
|
| 787 |
+
right = len(importance)
|
| 788 |
+
base = -len(target_width_list) * 1e5
|
| 789 |
+
for i in range(expand_ratio_stage + 1):
|
| 790 |
+
left = target_width_list[i]
|
| 791 |
+
importance[left:right] += base
|
| 792 |
+
base += 1e5
|
| 793 |
+
right = left
|
| 794 |
+
|
| 795 |
+
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
|
| 796 |
+
self.conv3.conv.conv.weight.data = torch.index_select(
|
| 797 |
+
self.conv3.conv.conv.weight.data, 1, sorted_idx
|
| 798 |
+
)
|
| 799 |
+
adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
|
| 800 |
+
self.conv2.conv.conv.weight.data = torch.index_select(
|
| 801 |
+
self.conv2.conv.conv.weight.data, 0, sorted_idx
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
# conv2 -> conv1
|
| 805 |
+
importance = torch.sum(
|
| 806 |
+
torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3)
|
| 807 |
+
)
|
| 808 |
+
if isinstance(self.conv1.bn, DynamicGroupNorm):
|
| 809 |
+
channel_per_group = self.conv1.bn.channel_per_group
|
| 810 |
+
importance_chunks = torch.split(importance, channel_per_group)
|
| 811 |
+
for chunk in importance_chunks:
|
| 812 |
+
chunk.data.fill_(torch.mean(chunk))
|
| 813 |
+
importance = torch.cat(importance_chunks, dim=0)
|
| 814 |
+
if expand_ratio_stage > 0:
|
| 815 |
+
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
|
| 816 |
+
sorted_expand_list.sort(reverse=True)
|
| 817 |
+
target_width_list = [
|
| 818 |
+
make_divisible(
|
| 819 |
+
round(max(self.out_channel_list) * expand),
|
| 820 |
+
MyNetwork.CHANNEL_DIVISIBLE,
|
| 821 |
+
)
|
| 822 |
+
for expand in sorted_expand_list
|
| 823 |
+
]
|
| 824 |
+
right = len(importance)
|
| 825 |
+
base = -len(target_width_list) * 1e5
|
| 826 |
+
for i in range(expand_ratio_stage + 1):
|
| 827 |
+
left = target_width_list[i]
|
| 828 |
+
importance[left:right] += base
|
| 829 |
+
base += 1e5
|
| 830 |
+
right = left
|
| 831 |
+
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
|
| 832 |
+
|
| 833 |
+
self.conv2.conv.conv.weight.data = torch.index_select(
|
| 834 |
+
self.conv2.conv.conv.weight.data, 1, sorted_idx
|
| 835 |
+
)
|
| 836 |
+
adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
|
| 837 |
+
self.conv1.conv.conv.weight.data = torch.index_select(
|
| 838 |
+
self.conv1.conv.conv.weight.data, 0, sorted_idx
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
return None
|
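All three dynamic layers above follow the same pattern: the module stores the weights of its largest configuration, forward() reads the active_* attributes to run only a slice of them, and get_active_subnet() materializes that slice as a static layer. A minimal sketch of this workflow, assuming the imports of this file are available; the channel and ratio lists below are illustrative values, not ones prescribed by the repo:

import torch

# Hypothetical elastic MBConv block: kernel in {3, 5, 7}, expand ratio in {3, 4, 6}.
layer = DynamicMBConvLayer(
    in_channel_list=[24],
    out_channel_list=[40],
    kernel_size_list=[3, 5, 7],
    expand_ratio_list=[3, 4, 6],
    stride=2,
    act_func="relu",
    use_se=False,
)

# Shrink the active configuration; forward() reads these plain attributes.
layer.active_kernel_size = 5
layer.active_expand_ratio = 4

y = layer(torch.randn(2, 24, 32, 32))  # runs only the active weight slice

# Materialize a static MBConvLayer that copies the sliced weights.
static_layer = layer.get_active_subnet(in_channel=24, preserve_weight=True)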
proard/classification/elastic_nn/modules/dynamic_op.py
ADDED
@@ -0,0 +1,401 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter

from proard.utils import (
    get_same_padding,
    sub_filter_start_end,
    make_divisible,
    SEModule,
    MyNetwork,
    MyConv2d,
)

__all__ = [
    "DynamicSeparableConv2d",
    "DynamicConv2d",
    "DynamicGroupConv2d",
    "DynamicBatchNorm2d",
    "DynamicGroupNorm",
    "DynamicSE",
    "DynamicLinear",
]

# A separable conv consists of a depthwise and a pointwise conv


class DynamicSeparableConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list  # list of kernel sizes
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_in_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=self.max_in_channels,
            bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        # define a matrix for converting from a larger kernel size to the next smaller one
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = "%dto%d" % (ks_larger, ks_small)
                # noinspection PyArgumentList
                scale_params["%s_matrix" % param_name] = Parameter(
                    torch.eye(ks_small ** 2)
                )
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[
                :out_channel, :in_channel, :, :
            ]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(
                    _input_filter.size(0), _input_filter.size(1), -1
                )
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter,
                    self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks ** 2
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks, target_ks
                )
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
        return y


class DynamicConv2d(nn.Module):
    def __init__(
        self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
    ):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_out_channels,
            self.kernel_size,
            stride=self.stride,
            bias=False,
        )

        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        padding = get_same_padding(self.kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
        return y


class DynamicGroupConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size_list,
        groups_list,
        stride=1,
        dilation=1,
    ):
        super(DynamicGroupConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=min(self.groups_list),
            bias=False,
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        filters = self.conv.weight[:, :, start:end, start:end]

        sub_filters = torch.chunk(filters, groups, dim=0)
        sub_in_channels = self.in_channels // groups
        sub_ratio = filters.size(1) // sub_in_channels

        filter_crops = []
        for i, sub_filter in enumerate(sub_filters):
            part_id = i % sub_ratio
            start = part_id * sub_in_channels
            filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :])
        filters = torch.cat(filter_crops, dim=0)
        return filters

    def forward(self, x, kernel_size=None, groups=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        if groups is None:
            groups = self.active_groups

        filters = self.get_active_filter(kernel_size, groups).contiguous()
        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, groups)
        return y


class DynamicBatchNorm2d(nn.Module):
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()

        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        else:
            exponential_average_factor = 0.0

            if bn.training and bn.track_running_stats:
                if bn.num_batches_tracked is not None:
                    bn.num_batches_tracked += 1
                    if bn.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = bn.momentum
            return F.batch_norm(
                x,
                bn.running_mean[:feature_dim],
                bn.running_var[:feature_dim],
                bn.weight[:feature_dim],
                bn.bias[:feature_dim],
                bn.training or not bn.track_running_stats,
                exponential_average_factor,
                bn.eps,
            )

    def forward(self, x):
        feature_dim = x.size(1)
        y = self.bn_forward(x, self.bn, feature_dim)
        return y


class DynamicGroupNorm(nn.GroupNorm):
    def __init__(
        self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
    ):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        self.channel_per_group = channel_per_group

    def forward(self, x):
        n_channels = x.size(1)
        n_groups = n_channels // self.channel_per_group
        return F.group_norm(
            x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
        )

    @property
    def bn(self):
        return self


class DynamicSE(SEModule):
    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1
            )
            return torch.cat(
                [sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters],
                dim=1,
            )

    def get_active_reduce_bias(self, num_mid):
        return (
            self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
        )

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.weight[:in_channel, :num_mid, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0
            )
            return torch.cat(
                [sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters],
                dim=0,
            )

    def get_active_expand_bias(self, in_channel, groups=None):
        if groups is None or groups == 1:
            return (
                self.fc.expand.bias[:in_channel]
                if self.fc.expand.bias is not None
                else None
            )
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
            return torch.cat(
                [sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0
            )

    def forward(self, x, groups=None):
        in_channel = x.size(1)
        num_mid = make_divisible(
            in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )

        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # reduce
        reduce_filter = self.get_active_reduce_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        reduce_bias = self.get_active_reduce_bias(num_mid)
        y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
        # relu
        y = self.fc.relu(y)
        # expand
        expand_filter = self.get_active_expand_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
        y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
        # hard sigmoid
        y = self.fc.h_sigmoid(y)

        return x * y


class DynamicLinear(nn.Module):
    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()

        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias

        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        return self.linear.bias[:out_features] if self.bias else None

    def forward(self, x, out_features=None):
        if out_features is None:
            out_features = self.active_out_features

        in_features = x.size(1)
        weight = self.get_active_weight(out_features, in_features).contiguous()
        bias = self.get_active_bias(out_features)
        y = F.linear(x, weight, bias)
        return y
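Each op above keeps a single max-size parameter tensor and slices the part it needs per forward pass; with KERNEL_TRANSFORM_MODE = 1, DynamicSeparableConv2d additionally passes the center crop of a larger kernel through a learned (identity-initialized) linear transform such as 7to5_matrix. A short sketch, assuming only this module and PyTorch; the tensor sizes are arbitrary:

import torch

dw = DynamicSeparableConv2d(max_in_channels=32, kernel_size_list=[3, 5, 7])
x = torch.randn(1, 32, 16, 16)

y7 = dw(x)                 # active_kernel_size defaults to 7 (the max)
y5 = dw(x, kernel_size=5)  # 7x7 center crop, transformed by 7to5_matrix

pw = DynamicConv2d(max_in_channels=32, max_out_channels=64)
pw.active_out_channel = 48  # use only the first 48 output filters
z = pw(y5)
assert z.size(1) == 48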
proard/classification/elastic_nn/networks/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .dyn_proxyless import DYNProxylessNASNets, DYNProxylessNASNets_Cifar
from .dyn_mbv3 import DYNMobileNetV3, DYNMobileNetV3_Cifar
from .dyn_resnets import DYNResNets, DYNResNets_Cifar
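These re-exports are the elastic supernets defined in the files below. A minimal usage sketch, assuming the package is importable; the elastic dimension lists are illustrative (the constructor defaults are ks_list=3, expand_ratio_list=6, depth_list=4), and a 224x224 input is assumed for the ImageNet variant:

import torch
from proard.classification.elastic_nn.networks import DYNMobileNetV3

net = DYNMobileNetV3(
    n_classes=1000,
    ks_list=[3, 5, 7],
    expand_ratio_list=[3, 4, 6],
    depth_list=[2, 3, 4],
)

net.set_max_net()                 # activate the largest sub-network
cfg = net.sample_active_subnet()  # random config: {"ks": [...], "e": [...], "d": [...]}
out = net(torch.randn(1, 3, 224, 224))

subnet = net.get_active_subnet(preserve_weight=True)  # standalone MobileNetV3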
proard/classification/elastic_nn/networks/dyn_mbv3.py
ADDED
@@ -0,0 +1,780 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from proard.classification.elastic_nn.modules.dynamic_layers import (
    DynamicMBConvLayer,
)
from proard.utils.layers import (
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    MBConvLayer,
    ResidualBlock,
)
from proard.classification.networks import MobileNetV3, MobileNetV3_Cifar
from proard.utils import make_divisible, val2list, MyNetwork

__all__ = ["DYNMobileNetV3", "DYNMobileNetV3_Cifar"]


class DYNMobileNetV3(MobileNetV3):
    def __init__(
        self,
        n_classes=1000,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.1,
        base_stage_width=None,
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

        final_expand_width = make_divisible(
            base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        stride_stages = [1, 2, 2, 2, 1, 2]
        act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
        se_stages = [False, False, True, False, True, True]
        n_block_list = [1] + [max(self.depth_list)] * 5
        width_list = []
        for base_width in base_stage_width[:-2]:
            width = make_divisible(
                base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )
            width_list.append(width)

        input_channel, first_block_dim = width_list[0], width_list[1]
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, act_func="h_swish"
        )
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_dim,
            kernel_size=3,
            stride=stride_stages[0],
            expand_ratio=1,
            act_func=act_stages[0],
            use_se=se_stages[0],
        )
        first_block = ResidualBlock(
            first_block_conv,
            IdentityLayer(first_block_dim, first_block_dim)
            if input_channel == first_block_dim
            else None,
        )

        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1
        feature_dim = first_block_dim

        for width, n_block, s, act_func, use_se in zip(
            width_list[2:],
            n_block_list[1:],
            stride_stages[1:],
            act_stages[1:],
            se_stages[1:],
        ):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(feature_dim),
                    out_channel_list=val2list(output_channel),
                    kernel_size_list=ks_list,
                    expand_ratio_list=expand_ratio_list,
                    stride=stride,
                    act_func=act_func,
                    use_se=use_se,
                )
                if stride == 1 and feature_dim == output_channel:
                    shortcut = IdentityLayer(feature_dim, feature_dim)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                feature_dim = output_channel
        # final expand layer, feature mix layer & classifier
        final_expand_layer = ConvLayer(
            feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
        )
        feature_mix_layer = ConvLayer(
            final_expand_width,
            last_channel,
            kernel_size=1,
            bias=False,
            use_bn=False,
            act_func="h_swish",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(DYNMobileNetV3, self).__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return "DYNMobileNetV3"

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.final_expand_layer(x)
        x = x.mean(3, keepdim=True).mean(2, keepdim=True)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"

        _str += self.final_expand_layer.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": DYNMobileNetV3.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "final_expand_layer": self.final_expand_layer.config,
            "feature_mix_layer": self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNMobileNetV3, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]

        final_expand_layer = copy.deepcopy(self.final_expand_layer)
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = MobileNetV3(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        # first conv
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        final_expand_config = self.final_expand_layer.config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
|
| 380 |
+
else None,
|
| 381 |
+
}
|
| 382 |
+
)
|
| 383 |
+
input_channel = self.blocks[idx].conv.active_out_channel
|
| 384 |
+
block_config_list += stage_blocks
|
| 385 |
+
|
| 386 |
+
return {
|
| 387 |
+
"name": MobileNetV3.__name__,
|
| 388 |
+
"bn": self.get_bn_param(),
|
| 389 |
+
"first_conv": first_conv_config,
|
| 390 |
+
"blocks": block_config_list,
|
| 391 |
+
"final_expand_layer": final_expand_config,
|
| 392 |
+
"feature_mix_layer": feature_mix_layer_config,
|
| 393 |
+
"classifier": classifier_config,
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
""" Width Related Methods """
|
| 397 |
+
|
| 398 |
+
def re_organize_middle_weights(self, expand_ratio_stage=0):
|
| 399 |
+
for block in self.blocks[1:]:
|
| 400 |
+
block.conv.re_organize_middle_weights(expand_ratio_stage)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
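# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the uploaded file] set_constraint() above
# writes its include-lists through self.__dict__ so the plain Python lists
# bypass nn.Module's attribute registration. A progressive-shrinking loop can
# use it to restrict random sampling to the current phase's candidates; the
# candidate values below are illustrative assumptions only:
#
#     net = DYNMobileNetV3(ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6],
#                          depth_list=[2, 3, 4])
#     net.set_constraint([3, 4], constraint_type="depth")
#     arch = net.sample_active_subnet()   # arch["d"] entries drawn from [3, 4]
#     net.clear_constraint()              # back to the full depth_list
# ---------------------------------------------------------------------------
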
class DYNMobileNetV3_Cifar(MobileNetV3_Cifar):
    def __init__(
        self,
        n_classes=10,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.1,
        base_stage_width=None,
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

        final_expand_width = make_divisible(
            base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        stride_stages = [1, 1, 2, 2, 1, 2]
        act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
        se_stages = [False, False, True, False, True, True]
        n_block_list = [1] + [max(self.depth_list)] * 5
        width_list = []
        for base_width in base_stage_width[:-2]:
            width = make_divisible(
                base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )
            width_list.append(width)

        input_channel, first_block_dim = width_list[0], width_list[1]
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=1, act_func="h_swish"
        )
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_dim,
            kernel_size=3,
            stride=stride_stages[0],
            expand_ratio=1,
            act_func=act_stages[0],
            use_se=se_stages[0],
        )
        first_block = ResidualBlock(
            first_block_conv,
            IdentityLayer(first_block_dim, first_block_dim)
            if input_channel == first_block_dim
            else None,
        )

        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1
        feature_dim = first_block_dim

        for width, n_block, s, act_func, use_se in zip(
            width_list[2:],
            n_block_list[1:],
            stride_stages[1:],
            act_stages[1:],
            se_stages[1:],
        ):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(feature_dim),
                    out_channel_list=val2list(output_channel),
                    kernel_size_list=ks_list,
                    expand_ratio_list=expand_ratio_list,
                    stride=stride,
                    act_func=act_func,
                    use_se=use_se,
                )
                if stride == 1 and feature_dim == output_channel:
                    shortcut = IdentityLayer(feature_dim, feature_dim)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                feature_dim = output_channel
        # final expand layer, feature mix layer & classifier
        final_expand_layer = ConvLayer(
            feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
        )
        feature_mix_layer = ConvLayer(
            final_expand_width,
            last_channel,
            kernel_size=1,
            bias=False,
            use_bn=False,
            act_func="h_swish",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(DYNMobileNetV3_Cifar, self).__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return "DYNMobileNetV3_Cifar"

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.final_expand_layer(x)
        x = x.mean(3, keepdim=True).mean(2, keepdim=True)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"

        _str += self.final_expand_layer.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": DYNMobileNetV3_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "final_expand_layer": self.final_expand_layer.config,
            "feature_mix_layer": self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNMobileNetV3_Cifar, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]

        final_expand_layer = copy.deepcopy(self.final_expand_layer)
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = MobileNetV3_Cifar(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        # first conv
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        final_expand_config = self.final_expand_layer.config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
                        else None,
                    }
                )
                input_channel = self.blocks[idx].conv.active_out_channel
            block_config_list += stage_blocks

        return {
            "name": MobileNetV3_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": first_conv_config,
            "blocks": block_config_list,
            "final_expand_layer": final_expand_config,
            "feature_mix_layer": feature_mix_layer_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
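[Editor's note] The two elastic networks above share one workflow: activate a sub-network (by hand or at random), then extract it as a standalone static model. A minimal usage sketch follows; it is not part of the uploaded file, the constructor arguments and input size are illustrative assumptions, and only the class and method names come from dyn_mbv3.py:

    import torch

    from proard.classification.elastic_nn.networks.dyn_mbv3 import DYNMobileNetV3_Cifar

    # Build the supernet with hypothetical elastic dimensions.
    supernet = DYNMobileNetV3_Cifar(
        n_classes=10, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4]
    )
    supernet.set_active_subnet(ks=5, e=4, d=3)   # pick one architecture by hand
    arch = supernet.sample_active_subnet()       # or draw one at random
    # Extract the active sub-network as a standalone MobileNetV3_Cifar model.
    subnet = supernet.get_active_subnet(preserve_weight=True)
    logits = subnet(torch.randn(1, 3, 32, 32))   # CIFAR-sized input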
proard/classification/elastic_nn/networks/dyn_proxyless.py
ADDED
@@ -0,0 +1,774 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from proard.utils import make_divisible, val2list, MyNetwork
from proard.classification.elastic_nn.modules import DynamicMBConvLayer
from proard.utils.layers import (
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    MBConvLayer,
    ResidualBlock,
)
from proard.classification.networks.proxyless_nets import ProxylessNASNets, ProxylessNASNets_Cifar

__all__ = ["DYNProxylessNASNets", "DYNProxylessNASNets_Cifar"]


class DYNProxylessNASNets(ProxylessNASNets):
    def __init__(
        self,
        n_classes=1000,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.1,
        base_stage_width=None,
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        if base_stage_width == "google":
            # MobileNetV2 Stage Width
            base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
        else:
            # ProxylessNAS Stage Width
            base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

        input_channel = make_divisible(
            base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        first_block_width = make_divisible(
            base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=2,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # first block
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_width,
            kernel_size=3,
            stride=1,
            expand_ratio=1,
            act_func="relu6",
        )
        first_block = ResidualBlock(first_block_conv, None)

        input_channel = first_block_width
        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1

        stride_stages = [2, 2, 2, 1, 2, 1]
        n_block_list = [max(self.depth_list)] * 5 + [1]

        width_list = []
        for base_width in base_stage_width[2:-1]:
            width = make_divisible(
                base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )
            width_list.append(width)

        for width, n_block, s in zip(width_list, n_block_list, stride_stages):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1

                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(input_channel, 1),
                    out_channel_list=val2list(output_channel, 1),
                    kernel_size_list=ks_list,
                    expand_ratio_list=expand_ratio_list,
                    stride=stride,
                    act_func="relu6",
                )

                if stride == 1 and input_channel == output_channel:
                    shortcut = IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None

                mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)

                blocks.append(mb_inverted_block)
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            use_bn=True,
            act_func="relu6",
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(DYNProxylessNASNets, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return "DYNProxylessNASNets"

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)

        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)

        # feature_mix_layer
        x = self.feature_mix_layer(x)
        x = x.mean(3).mean(2)

        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": DYNProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNProxylessNASNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        depth_setting[-1] = 1
        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
                        else None,
                    }
                )
                try:
                    input_channel = self.blocks[idx].conv.active_out_channel
                except Exception:
                    input_channel = self.blocks[idx].conv.out_channels
            block_config_list += stage_blocks

        return {
            "name": ProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": first_conv_config,
            "blocks": block_config_list,
            "feature_mix_layer": feature_mix_layer_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)


class DYNProxylessNASNets_Cifar(ProxylessNASNets_Cifar):
    def __init__(
        self,
        n_classes=10,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.1,
        base_stage_width=None,
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        if base_stage_width == "MBV2":
            # MobileNetV2 Stage Width
            base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
        else:
            # ProxylessNAS Stage Width
            base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

        input_channel = make_divisible(
            base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        first_block_width = make_divisible(
            base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # first block
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_width,
            kernel_size=3,
            stride=1,
            expand_ratio=1,
            act_func="relu6",
        )
        first_block = ResidualBlock(first_block_conv, None)

        input_channel = first_block_width
        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1

        stride_stages = [1, 2, 2, 1, 2, 1]
        n_block_list = [max(self.depth_list)] * 5 + [1]

        width_list = []
        for base_width in base_stage_width[2:-1]:
            width = make_divisible(
                base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )
            width_list.append(width)

        for width, n_block, s in zip(width_list, n_block_list, stride_stages):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1

                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(input_channel, 1),
                    out_channel_list=val2list(output_channel, 1),
                    kernel_size_list=ks_list,
                    expand_ratio_list=expand_ratio_list,
                    stride=stride,
                    act_func="relu6",
                )

                if stride == 1 and input_channel == output_channel:
                    shortcut = IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None

                mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)

                blocks.append(mb_inverted_block)
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            use_bn=True,
            act_func="relu6",
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(DYNProxylessNASNets_Cifar, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return "DYNProxylessNASNets_Cifar"

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)

        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)

        # feature_mix_layer
        x = self.feature_mix_layer(x)
        x = x.mean(3).mean(2)

        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": DYNProxylessNASNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNProxylessNASNets_Cifar, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        depth_setting[-1] = 1
        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
                        else None,
                    }
                )
                try:
                    input_channel = self.blocks[idx].conv.active_out_channel
                except Exception:
                    input_channel = self.blocks[idx].conv.out_channels
            block_config_list += stage_blocks

        return {
            "name": ProxylessNASNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": first_conv_config,
            "blocks": block_config_list,
            "feature_mix_layer": feature_mix_layer_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
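[Editor's note] All the dynamic networks in this diff use the same load_state_dict key translation: static checkpoints name the MBConv module "mobile_inverted_conv", and the dynamic ops add one extra nesting level to conv/bn/linear submodules. A simplified standalone sketch of that translation for a single key; the helper name remap_key is hypothetical, and the rewrite order mirrors the elif chain above:

    def remap_key(key, model_keys):
        # First normalize the MBConv module name, then apply the first
        # matching nesting rewrite, in order.
        new_key = key.replace(".mobile_inverted_conv.", ".conv.")
        if new_key in model_keys:
            return new_key
        for old, new in [
            (".bn.bn.", ".bn."),                    # dynamic -> static direction
            (".conv.conv.weight", ".conv.weight"),
            (".linear.linear.", ".linear."),
            (".linear.", ".linear.linear."),        # static -> dynamic direction
            ("bn.", "bn.bn."),
            ("conv.weight", "conv.conv.weight"),
        ]:
            if old in new_key:
                return new_key.replace(old, new)
        raise ValueError(new_key)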
proard/classification/elastic_nn/networks/dyn_resnets.py
ADDED
@@ -0,0 +1,678 @@
import random

from proard.classification.elastic_nn.modules.dynamic_layers import (
    DynamicConvLayer,
    DynamicLinearLayer,
)
from proard.classification.elastic_nn.modules.dynamic_layers import (
    DynamicResNetBottleneckBlock,
)
from proard.utils.layers import IdentityLayer, ResidualBlock
from proard.classification.networks import ResNets, ResNets_Cifar
from proard.utils import make_divisible, val2list, MyNetwork

__all__ = ["DYNResNets", "DYNResNets_Cifar"]


class DYNResNets(ResNets):
    def __init__(
        self,
        n_classes=1000,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        depth_list=2,
        expand_ratio_list=0.25,
        width_mult_list=1.0,
    ):

        self.depth_list = val2list(depth_list)
        self.expand_ratio_list = val2list(expand_ratio_list)
        self.width_mult_list = val2list(width_mult_list)
        # sort
        self.depth_list.sort()
        self.expand_ratio_list.sort()
        self.width_mult_list.sort()

        input_channel = [
            make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for width_mult in self.width_mult_list
        ]
        mid_input_channel = [
            make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
            for channel in input_channel
        ]

        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = [
                make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
                for width_mult in self.width_mult_list
            ]

        n_block_list = [
            base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST
        ]
        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            DynamicConvLayer(
                val2list(3),
                mid_input_channel,
                3,
                stride=2,
                use_bn=True,
                act_func="relu",
            ),
            ResidualBlock(
                DynamicConvLayer(
                    mid_input_channel,
                    mid_input_channel,
                    3,
                    stride=1,
                    use_bn=True,
                    act_func="relu",
                ),
                IdentityLayer(mid_input_channel, mid_input_channel),
            ),
            DynamicConvLayer(
                mid_input_channel,
                input_channel,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
        ]

        # blocks
        blocks = []
        for d, width, s in zip(n_block_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = DynamicResNetBottleneckBlock(
                    input_channel,
                    width,
                    expand_ratio_list=self.expand_ratio_list,
                    kernel_size=3,
                    stride=stride,
                    act_func="relu",
                    downsample_mode="avgpool_conv",
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = DynamicLinearLayer(
            input_channel, n_classes, dropout_rate=dropout_rate
        )

        super(DYNResNets, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

        # runtime_depth
        self.input_stem_skipping = 0
        self.runtime_depth = [0] * len(n_block_list)

    @property
    def ks_list(self):
        return [3]

    @staticmethod
    def name():
        return "DYNResNets"

    def forward(self, x):
        for layer in self.input_stem:
            if (
                self.input_stem_skipping > 0
                and isinstance(layer, ResidualBlock)
                and isinstance(layer.shortcut, IdentityLayer)
            ):
                pass
            else:
                x = layer(x)
        x = self.max_pooling(x)
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ""
        for layer in self.input_stem:
            if (
                self.input_stem_skipping > 0
                and isinstance(layer, ResidualBlock)
                and isinstance(layer.shortcut, IdentityLayer)
            ):
                pass
            else:
                _str += layer.module_str + "\n"
        _str += "max_pooling(ks=3, stride=2)\n"
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": DYNResNets.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            new_key = key
            if new_key in model_dict:
                pass
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNResNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            d=max(self.depth_list),
            e=max(self.expand_ratio_list),
            w=len(self.width_mult_list) - 1,
        )

    def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
        depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
        expand_ratio = val2list(e, len(self.blocks))
        width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)

        for block, e in zip(self.blocks, expand_ratio):
            if e is not None:
                block.active_expand_ratio = e

        if width_mult[0] is not None:
            self.input_stem[1].conv.active_out_channel = self.input_stem[
                0
            ].active_out_channel = self.input_stem[0].out_channel_list[width_mult[0]]
        if width_mult[1] is not None:
            self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[
                width_mult[1]
            ]

        if depth[0] is not None:
            self.input_stem_skipping = depth[0] != max(self.depth_list)
        for stage_id, (block_idx, d, w) in enumerate(
            zip(self.grouped_block_index, depth[1:], width_mult[2:])
        ):
            if d is not None:
                self.runtime_depth[stage_id] = max(self.depth_list) - d
            if w is not None:
                for idx in block_idx:
                    self.blocks[idx].active_out_channel = self.blocks[
                        idx
                    ].out_channel_list[w]

    def sample_active_subnet(self):
        # sample expand ratio
        expand_setting = []
        for block in self.blocks:
            expand_setting.append(random.choice(block.expand_ratio_list))

        # sample depth
        depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
        for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
            depth_setting.append(random.choice(self.depth_list))

        # sample width_mult
        width_mult_setting = [
            random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
            random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
        ]
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            stage_first_block = self.blocks[block_idx[0]]
            width_mult_setting.append(
                random.choice(list(range(len(stage_first_block.out_channel_list))))
            )

        arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting}
        self.set_active_subnet(**arch_config)
        return arch_config

    def get_active_subnet(self, preserve_weight=True):
        input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
        if self.input_stem_skipping <= 0:
            input_stem.append(
                ResidualBlock(
                    self.input_stem[1].conv.get_active_subnet(
                        self.input_stem[0].active_out_channel, preserve_weight
                    ),
                    IdentityLayer(
                        self.input_stem[0].active_out_channel,
                        self.input_stem[0].active_out_channel,
                    ),
                )
            )
        input_stem.append(
            self.input_stem[2].get_active_subnet(
                self.input_stem[0].active_out_channel, preserve_weight
            )
        )
        input_channel = self.input_stem[2].active_out_channel

        blocks = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                blocks.append(
                    self.blocks[idx].get_active_subnet(input_channel, preserve_weight)
                )
                input_channel = self.blocks[idx].active_out_channel
        classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
        subnet = ResNets(input_stem, blocks, classifier)

        subnet.set_bn_param(**self.get_bn_param())
        return subnet

    def get_active_net_config(self):
        input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
        if self.input_stem_skipping <= 0:
            input_stem_config.append(
                {
                    "name": ResidualBlock.__name__,
                    "conv": self.input_stem[1].conv.get_active_subnet_config(
                        self.input_stem[0].active_out_channel
                    ),
                    "shortcut": IdentityLayer(
                        self.input_stem[0].active_out_channel,
                        self.input_stem[0].active_out_channel,
                    ),
                }
            )
        input_stem_config.append(
            self.input_stem[2].get_active_subnet_config(
                self.input_stem[0].active_out_channel
            )
        )
        input_channel = self.input_stem[2].active_out_channel

        blocks_config = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                blocks_config.append(
                    self.blocks[idx].get_active_subnet_config(input_channel)
                )
                input_channel = self.blocks[idx].active_out_channel
        classifier_config = self.classifier.get_active_subnet_config(input_channel)
        return {
            "name": ResNets.__name__,
            "bn": self.get_bn_param(),
            "input_stem": input_stem_config,
            "blocks": blocks_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks:
            block.re_organize_middle_weights(expand_ratio_stage)


class DYNResNets_Cifar(ResNets_Cifar):
    def __init__(
        self,
        n_classes=10,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        depth_list=0,
        expand_ratio_list=0.25,
        width_mult_list=1.0,
    ):

        self.depth_list = val2list(depth_list)
        self.expand_ratio_list = val2list(expand_ratio_list)
        self.width_mult_list = val2list(width_mult_list)
        # sort
        self.depth_list.sort()
        self.expand_ratio_list.sort()
        self.width_mult_list.sort()

        input_channel = [
            make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for width_mult in self.width_mult_list
        ]
        mid_input_channel = [
            make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
            for channel in input_channel
        ]

        stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = [
                make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
                for width_mult in self.width_mult_list
            ]

        n_block_list = [
            base_depth + max(self.depth_list) for base_depth in ResNets_Cifar.BASE_DEPTH_LIST
        ]
        stride_list = [1, 2, 2, 2]

        # build input stem (stride 1: CIFAR images are 32x32, so no early downsampling)
        input_stem = [
            DynamicConvLayer(
                val2list(3),
                mid_input_channel,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
            ResidualBlock(
                DynamicConvLayer(
                    mid_input_channel,
                    mid_input_channel,
                    3,
                    stride=1,
                    use_bn=True,
                    act_func="relu",
                ),
                IdentityLayer(mid_input_channel, mid_input_channel),
            ),
            DynamicConvLayer(
                mid_input_channel,
                input_channel,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
        ]

        # blocks
        blocks = []
        for d, width, s in zip(n_block_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = DynamicResNetBottleneckBlock(
                    input_channel,
                    width,
                    expand_ratio_list=self.expand_ratio_list,
                    kernel_size=3,
                    stride=stride,
                    act_func="relu",
                    downsample_mode="conv",
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = DynamicLinearLayer(
            input_channel, n_classes, dropout_rate=dropout_rate
        )

        super(DYNResNets_Cifar, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

        # runtime_depth
        self.input_stem_skipping = 0
        self.runtime_depth = [0] * len(n_block_list)

    @property
    def ks_list(self):
        return [3]

    @staticmethod
    def name():
        return "DYNResNets_Cifar"

    def forward(self, x):
        for layer in self.input_stem:
            if (
                self.input_stem_skipping > 0
                and isinstance(layer, ResidualBlock)
                and isinstance(layer.shortcut, IdentityLayer)
            ):
                pass
            else:
                x = layer(x)
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ""
        for layer in self.input_stem:
            if (
                self.input_stem_skipping > 0
                and isinstance(layer, ResidualBlock)
                and isinstance(layer.shortcut, IdentityLayer)
            ):
                pass
            else:
                _str += layer.module_str + "\n"
        # _str += "max_pooling(ks=3, stride=2)\n"
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": DYNResNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            new_key = key
            if new_key in model_dict:
                pass
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNResNets_Cifar, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(
            d=max(self.depth_list),
            e=max(self.expand_ratio_list),
            w=len(self.width_mult_list) - 1,
        )

    def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
        depth = val2list(d, len(ResNets_Cifar.BASE_DEPTH_LIST) + 1)
        expand_ratio = val2list(e, len(self.blocks))
        width_mult = val2list(w, len(ResNets_Cifar.BASE_DEPTH_LIST) + 2)

        for block, e in zip(self.blocks, expand_ratio):
            if e is not None:
                block.active_expand_ratio = e

        if width_mult[0] is not None:
            self.input_stem[1].conv.active_out_channel = self.input_stem[
                0
            ].active_out_channel = self.input_stem[0].out_channel_list[int(width_mult[0])]
        if width_mult[1] is not None:
            self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[
                int(width_mult[1])
            ]

        if depth[0] is not None:
            self.input_stem_skipping = depth[0] != max(self.depth_list)
        for stage_id, (block_idx, d, w) in enumerate(
            zip(self.grouped_block_index, depth[1:], width_mult[2:])
        ):
            if d is not None:
                self.runtime_depth[stage_id] = max(self.depth_list) - d
            if w is not None:
                for idx in block_idx:
                    self.blocks[idx].active_out_channel = self.blocks[
                        idx
                    ].out_channel_list[int(w)]

    def sample_active_subnet(self):
        # sample expand ratio
        expand_setting = []
        for block in self.blocks:
            expand_setting.append(random.choice(block.expand_ratio_list))

        # sample depth
        depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
        for stage_id in range(len(ResNets_Cifar.BASE_DEPTH_LIST)):
            depth_setting.append(random.choice(self.depth_list))

        # sample width_mult
        width_mult_setting = [
            random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
            random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
        ]
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            stage_first_block = self.blocks[block_idx[0]]
            width_mult_setting.append(
                random.choice(list(range(len(stage_first_block.out_channel_list))))
            )

        arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting}
        self.set_active_subnet(**arch_config)
        return arch_config

    def get_active_subnet(self, preserve_weight=True):
        input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
        if self.input_stem_skipping <= 0:
            input_stem.append(
                ResidualBlock(
                    self.input_stem[1].conv.get_active_subnet(
                        self.input_stem[0].active_out_channel, preserve_weight
                    ),
                    IdentityLayer(
                        self.input_stem[0].active_out_channel,
                        self.input_stem[0].active_out_channel,
                    ),
                )
            )
        input_stem.append(
            self.input_stem[2].get_active_subnet(
                self.input_stem[0].active_out_channel, preserve_weight
            )
        )
        input_channel = self.input_stem[2].active_out_channel

        blocks = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - depth_param]
            for idx in active_idx:
                blocks.append(
                    self.blocks[idx].get_active_subnet(input_channel, preserve_weight)
                )
                input_channel = self.blocks[idx].active_out_channel
        classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
        subnet = ResNets_Cifar(input_stem, blocks, classifier)

        subnet.set_bn_param(**self.get_bn_param())
        return subnet

    def get_active_net_config(self):
        input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
        if self.input_stem_skipping <= 0:
            input_stem_config.append(
                {
                    "name": ResidualBlock.__name__,
                    "conv": self.input_stem[1].conv.get_active_subnet_config(
                        self.input_stem[0].active_out_channel
                    ),
                    "shortcut": IdentityLayer(
                        self.input_stem[0].active_out_channel,
                        self.input_stem[0].active_out_channel,
                    ),
                }
            )
        input_stem_config.append(
            self.input_stem[2].get_active_subnet_config(
                self.input_stem[0].active_out_channel
            )
        )
        input_channel = self.input_stem[2].active_out_channel

        blocks_config = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[: len(block_idx) - int(depth_param)]
            for idx in active_idx:
                blocks_config.append(
                    self.blocks[idx].get_active_subnet_config(input_channel)
                )
                input_channel = self.blocks[idx].active_out_channel
        classifier_config = self.classifier.get_active_subnet_config(input_channel)
        return {
            "name": ResNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "input_stem": input_stem_config,
            "blocks": blocks_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks:
            block.re_organize_middle_weights(expand_ratio_stage)
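Both classes in dyn_resnets.py expose the same elastic interface: sample_active_subnet draws a random (depth, expand-ratio, width) configuration, set_active_subnet pins one, and get_active_subnet materializes it as a standalone static network. A minimal sketch of that round trip; the depth/expand/width list values are illustrative assumptions, not prescribed by this diff:

# Hedged sketch; the search-space lists below are illustrative assumptions.
import torch
from proard.classification.elastic_nn.networks.dyn_resnets import DYNResNets_Cifar

dyn_net = DYNResNets_Cifar(
    n_classes=10,
    depth_list=[0, 1, 2],
    expand_ratio_list=[0.2, 0.25, 0.35],
    width_mult_list=[0.65, 0.8, 1.0],
)
arch = dyn_net.sample_active_subnet()        # e.g. {"d": [...], "e": [...], "w": [...]}
logits = dyn_net(torch.randn(2, 3, 32, 32))  # forward runs only the active blocks
subnet = dyn_net.get_active_subnet(preserve_weight=True)  # standalone ResNets_Cifar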
proard/classification/elastic_nn/training/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .progressive_shrinking import *
proard/classification/elastic_nn/training/progressive_shrinking.py
ADDED
@@ -0,0 +1,463 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from robust_loss.rslad import rslad_inner_loss, kl_loss
from robust_loss.trades import trades_loss
from attacks import create_attack
import copy
from proard.utils import AverageMeter, cross_entropy_loss_with_soft_target
from proard.utils import (
    DistributedMetric,
    list_mean,
    subset_mean,
    val2list,
    MyRandomResizedCrop,
)
from proard.classification.run_manager import DistributedRunManager

__all__ = [
    "validate",
    "train_one_epoch",
    "train",
    "load_models",
    "train_elastic_depth",
    "train_elastic_expand",
    "train_elastic_width_mult",
]


def validate(
    run_manager,
    epoch=0,
    is_test=False,
    image_size_list=None,
    ks_list=None,
    expand_ratio_list=None,
    depth_list=None,
    width_mult_list=None,
    additional_setting=None,
):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list
    if width_mult_list is None:
        # fill in a default index list when the caller does not pin one
        if "width_mult_list" in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]

    subnet_settings = []
    for d in depth_list:
        for e in expand_ratio_list:
            for k in ks_list:
                for w in width_mult_list:
                    for img_size in image_size_list:
                        subnet_settings.append(
                            [
                                {
                                    "image_size": img_size,
                                    "d": d,
                                    "e": e,
                                    "ks": k,
                                    "w": w,
                                },
                                "R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w),
                            ]
                        )
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
    robust1_of_subnets, robust5_of_subnets = [], []

    valid_log = ""
    for setting, name in subnet_settings:
        run_manager.write_log(
            "-" * 30 + " Validate %s " % name + "-" * 30, "train", should_print=False
        )
        run_manager.run_config.data_provider.assign_active_img_size(
            setting.pop("image_size")
        )
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str, "train", should_print=False)

        run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5, robust1, robust5) = run_manager.validate(
            epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net
        )
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        robust1_of_subnets.append(robust1)
        robust5_of_subnets.append(robust5)
        valid_log += "%s (%.3f) (%.3f), " % (name, top1, robust1)

    return (
        list_mean(losses_of_subnets),
        list_mean(top1_of_subnets),
        list_mean(top5_of_subnets),
        list_mean(robust1_of_subnets),
        list_mean(robust5_of_subnets),
        valid_log,
    )


def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    # switch to train mode
    dynamic_net.train()
    if distributed:
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric("train_loss") if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(
        total=nBatch,
        desc="Train Epoch #{}".format(epoch + 1),
        disable=distributed and not run_manager.is_root,
    ) as t:
        end = time.time()
        subnet_str = ""
        j = 0
        for _ in range(args.dynamic_batch_size):
            # set random seed before sampling
            subnet_seed = int("%d%.3d%.3d" % (epoch * nBatch + j, _, 0))
            random.seed(subnet_seed)
            subnet_settings = dynamic_net.sample_active_subnet()
            subnet_str += (
                "%d: " % _
                + ",".join(
                    [
                        "%s_%s"
                        % (
                            key,
                            "%.1f" % subset_mean(val, 0)
                            if isinstance(val, list)
                            else val,
                        )
                        for key, val in subnet_settings.items()
                    ]
                )
                + " || "
            )

        for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer,
                    warmup_epochs * nBatch,
                    nBatch,
                    epoch,
                    i,
                    warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch
                )

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target
            if args.kd_ratio > 0:
                args.teacher_model.eval()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clean gradients
            dynamic_net.zero_grad()

            loss_of_subnets = []
            # compute output
            output = dynamic_net(images)

            if args.kd_ratio == 0:
                if run_manager.run_config.robust_mode:
                    loss = run_manager.train_criterion(
                        dynamic_net,
                        images,
                        labels,
                        run_manager.optimizer,
                        run_manager.run_config.step_size_train,
                        run_manager.run_config.epsilon_train,
                        run_manager.run_config.num_steps_train,
                        run_manager.run_config.beta_train,
                        run_manager.run_config.distance_train,
                    )
                    loss_type = run_manager.run_config.train_criterion_loss.__name__
                else:
                    loss = F.cross_entropy(output, labels)
                    loss_type = "ce"
            else:
                if run_manager.run_config.robust_mode:
                    loss = run_manager.kd_criterion(
                        args.teacher_model,
                        dynamic_net,
                        images,
                        labels,
                        run_manager.optimizer,
                        run_manager.run_config.step_size_train,
                        run_manager.run_config.epsilon_train,
                        run_manager.run_config.num_steps_train,
                        run_manager.run_config.beta_train,
                    )
                    loss_type = run_manager.run_config.kd_criterion_loss.__name__
                else:
                    if args.kd_type == "ce":
                        kd_loss = cross_entropy_loss_with_soft_target(
                            output, soft_label
                        )
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    # combine the distillation term with the plain cross-entropy loss
                    loss = args.kd_ratio * kd_loss + F.cross_entropy(output, labels)
                    loss_type = "%.1fkd+ce" % args.kd_ratio
            # measure accuracy and record loss
            loss_of_subnets.append(loss)
            run_manager.update_metric(metric_dict, output, output, target)

            loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))

            t.set_postfix(
                {
                    "loss": losses.avg.item(),
                    **run_manager.get_metric_vals(metric_dict, return_dict=True),
                    "R": images.size(2),
                    "lr": new_lr,
                    "loss_type": loss_type,
                    "seed": str(subnet_seed),
                    "str": subnet_str,
                    "data_time": data_time.avg,
                }
            )
            t.update(1)
            end = time.time()
            j += 1
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)


def train(run_manager, args, validate_func=None):
    distributed = isinstance(run_manager, DistributedRunManager)
    if validate_func is None:
        validate_func = validate

    for epoch in range(
        run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs
    ):
        train_loss, (train_top1, train_top5, train_robust1, train_robust5) = train_one_epoch(
            run_manager, args, epoch, args.warmup_epochs, args.warmup_lr
        )

        if (epoch + 1) % args.validation_frequency == 0:
            val_loss, val_acc, val_acc5, val_robust1, val_robust5, _val_log = validate_func(
                run_manager, epoch=epoch, is_test=True
            )
            # best_acc
            is_best = val_acc > run_manager.best_acc
            is_best_robust = val_robust1 > run_manager.best_robustness
            run_manager.best_acc = max(run_manager.best_acc, val_acc)
            run_manager.best_robustness = max(run_manager.best_robustness, val_robust1)
            if not distributed or run_manager.is_root:
                val_log = (
                    "Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f}), "
                    "robust-1={5:.3f} ({6:.3f})".format(
                        epoch + 1 - args.warmup_epochs,
                        run_manager.run_config.n_epochs,
                        val_loss,
                        val_acc,
                        run_manager.best_acc,
                        val_robust1,
                        run_manager.best_robustness,
                    )
                )
                val_log += ", Train top-1 {top1:.3f}, Train robust-1 {robust1:.3f}, Train loss {loss:.3f}\t".format(
                    top1=train_top1, robust1=train_robust1, loss=train_loss
                )
                val_log += _val_log
                run_manager.write_log(val_log, "valid", should_print=False)

                run_manager.save_model(
                    {
                        "epoch": epoch,
                        "best_acc": run_manager.best_acc,
                        "optimizer": run_manager.optimizer.state_dict(),
                        "state_dict": run_manager.network.state_dict(),
                    },
                    is_best=is_best,
                )


def load_models(run_manager, dynamic_net, model_path=None):
    # specify init path
    init = torch.load(model_path, map_location="cpu")["state_dict"]
    dynamic_net.load_state_dict(init)
    run_manager.write_log("Loaded init from %s" % model_path, "valid")


def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    depth_stage_list = dynamic_net.depth_list.copy()
    depth_stage_list.sort(reverse=True)
    n_stages = len(depth_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict["depth_list"] = sorted(dynamic_net.depth_list)

        load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
        # validate after loading weights
        run_manager.write_log(
            "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
            % validate(run_manager, is_test=True, **validate_func_dict),
            "valid",
        )
    else:
        assert args.resume

    run_manager.write_log(
        "-" * 30
        + "Supporting Elastic Depth: %s -> %s"
        % (depth_stage_list[: current_stage + 1], depth_stage_list[: current_stage + 2])
        + "-" * 30,
        "valid",
    )
    # add depth list constraints
    if (
        len(set(dynamic_net.ks_list)) == 1
        and len(set(dynamic_net.expand_ratio_list)) == 1
    ):
        validate_func_dict["depth_list"] = depth_stage_list
    else:
        validate_func_dict["depth_list"] = sorted(
            {min(depth_stage_list), max(depth_stage_list)}
        )

    # train
    train_func(
        run_manager,
        args,
        lambda _run_manager, epoch, is_test: validate(
            _run_manager, epoch, is_test, **validate_func_dict
        ),
    )


def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    expand_stage_list = dynamic_net.expand_ratio_list.copy()
    expand_stage_list.sort(reverse=True)
    n_stages = len(expand_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict["expand_ratio_list"] = sorted(dynamic_net.expand_ratio_list)

        load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
        dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
        run_manager.write_log(
            "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
            % validate(run_manager, is_test=True, **validate_func_dict),
            "valid",
        )
    else:
        assert args.resume

    run_manager.write_log(
        "-" * 30
        + "Supporting Elastic Expand Ratio: %s -> %s"
        % (
            expand_stage_list[: current_stage + 1],
            expand_stage_list[: current_stage + 2],
        )
        + "-" * 30,
        "valid",
    )
    if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
        validate_func_dict["expand_ratio_list"] = expand_stage_list
    else:
        validate_func_dict["expand_ratio_list"] = sorted(
            {min(expand_stage_list), max(expand_stage_list)}
        )

    # train
    train_func(
        run_manager,
        args,
        lambda _run_manager, epoch, is_test: validate(
            _run_manager, epoch, is_test, **validate_func_dict
        ),
    )


def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    width_stage_list = dynamic_net.width_mult_list.copy()
    width_stage_list.sort(reverse=True)
    n_stages = len(width_stage_list) - 1
    current_stage = n_stages - 1

    if run_manager.start_epoch == 0 and not args.resume:
        load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
        if current_stage == 0:
            dynamic_net.re_organize_middle_weights(
                expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1
            )
            run_manager.write_log(
                "reorganize_middle_weights (expand_ratio_stage=%d)"
                % (len(dynamic_net.expand_ratio_list) - 1),
                "valid",
            )
            try:
                dynamic_net.re_organize_outer_weights()
                run_manager.write_log("reorganize_outer_weights", "valid")
            except Exception:
                pass
        validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})
        run_manager.write_log(
            "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
            % validate(run_manager, is_test=True, **validate_func_dict),
            "valid",
        )
    else:
        assert args.resume

    run_manager.write_log(
        "-" * 30
        + "Supporting Elastic Width Mult: %s -> %s"
        % (width_stage_list[: current_stage + 1], width_stage_list[: current_stage + 2])
        + "-" * 30,
        "valid",
    )
    validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})

    # train
    train_func(
        run_manager,
        args,
        lambda _run_manager, epoch, is_test: validate(
            _run_manager, epoch, is_test, **validate_func_dict
        ),
    )
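progressive_shrinking.py supplies one phase helper per elastic dimension; each loads a checkpoint from the previous phase, widens the set of sub-networks that get validated, and hands control back to the shared train loop. A sketch of how a train script might dispatch between them; run_manager and args are built elsewhere in the repo, and the args.task flag is an assumption for illustration, not defined in this file:

# Hedged sketch; args.task is an assumed flag, and the dict keys mirror validate() above.
from proard.classification.elastic_nn.training.progressive_shrinking import (
    train,
    train_elastic_depth,
    train_elastic_expand,
    train_elastic_width_mult,
)

validate_func_dict = {"image_size_list": [32], "width_mult_list": [0]}
if args.task == "depth":
    train_elastic_depth(train, run_manager, args, validate_func_dict)
elif args.task == "expand":
    train_elastic_expand(train, run_manager, args, validate_func_dict)
elif args.task == "width_mult":
    train_elastic_width_mult(train, run_manager, args, validate_func_dict)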
proard/classification/elastic_nn/utils.py
ADDED
@@ -0,0 +1,83 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch.nn.functional as F
import torch.nn as nn
import torch
from attacks import create_attack
from attacks.utils import ctx_noparamgrad_and_eval
from proard.utils import AverageMeter, get_net_device, DistributedTensor
from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d

__all__ = ["set_running_statistics"]


def set_running_statistics(model, data_loader, distributed=False):
    bn_mean = {}
    bn_var = {}

    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            if distributed:
                bn_mean[name] = DistributedTensor(name + "#mean")
                bn_var[name] = DistributedTensor(name + "#var")
            else:
                bn_mean[name] = AverageMeter()
                bn_var[name] = AverageMeter()

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    batch_mean = (
                        x.mean(0, keepdim=True)
                        .mean(2, keepdim=True)
                        .mean(3, keepdim=True)
                    )  # 1, C, 1, 1
                    batch_var = (x - batch_mean) * (x - batch_mean)
                    batch_var = (
                        batch_var.mean(0, keepdim=True)
                        .mean(2, keepdim=True)
                        .mean(3, keepdim=True)
                    )

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.size(0)
                    return F.batch_norm(
                        x,
                        batch_mean,
                        batch_var,
                        bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim],
                        False,
                        0.0,
                        bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    if len(bn_mean) == 0:
        # skip if there are no batch normalization layers in the network
        return

    with torch.no_grad():
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        for images, labels in data_loader:
            images = images.to(get_net_device(forward_model))
            forward_model(images)
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, nn.BatchNorm2d)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
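Because every sub-network slices a different channel subset out of the shared layers, the supernet's stored BatchNorm running statistics do not match any single subnet; set_running_statistics recomputes them by streaming a few batches through a copy of the model. A typical call, sketched; the subset loader and the evaluation helper are assumptions, not part of this file:

# Hedged sketch; sub_train_loader and evaluate() are assumed helpers.
from proard.classification.elastic_nn.utils import set_running_statistics

dyn_net.set_active_subnet(d=1, e=0.25, w=1)        # illustrative setting
set_running_statistics(dyn_net, sub_train_loader)  # recalibrate BN mean/var for this subnet
top1 = evaluate(dyn_net, test_loader)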
proard/classification/networks/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .proxyless_nets import *
from .mobilenet_v3 import *
from .resnets import *
from .wide_resnet import WideResNet
from .resnet_trades import *

def get_net_by_name(name):
    if name == ProxylessNASNets.__name__:
        return ProxylessNASNets
    elif name == MobileNetV3.__name__:
        return MobileNetV3
    elif name == ResNets.__name__:
        return ResNets
    elif name == ProxylessNASNets_Cifar.__name__:
        return ProxylessNASNets_Cifar
    elif name == MobileNetV3_Cifar.__name__:
        return MobileNetV3_Cifar
    elif name == ResNets_Cifar.__name__:
        return ResNets_Cifar
    else:
        raise ValueError("unrecognized type of network: %s" % name)
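get_net_by_name maps the "name" field stored in a serialized network config back to its class, which pairs with each class's build_from_config. A sketch of the round trip; reading the config from a net.config JSON file follows the OFA convention and is an assumption here, and the path is purely illustrative:

# Hedged sketch; the net.config path and JSON layout are assumed, mirroring OFA.
import json
from proard.classification.networks import get_net_by_name

net_config = json.load(open("exp/teacher/net.config"))
net_cls = get_net_by_name(net_config["name"])  # e.g. MobileNetV3 or ResNets_Cifar
net = net_cls.build_from_config(net_config)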
proard/classification/networks/mobilenet_v3.py
ADDED
@@ -0,0 +1,559 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch.nn as nn

from proard.utils.layers import (
    set_layer_from_config,
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    ResidualBlock,
)
from proard.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d

__all__ = ["MobileNetV3", "MobileNetV3Large", "MobileNetV3_Cifar", "MobileNetV3Large_Cifar"]


class MobileNetV3(MyNetwork):
    def __init__(
        self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
    ):
        super(MobileNetV3, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.final_expand_layer = final_expand_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
        self.feature_mix_layer = feature_mix_layer
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_expand_layer(x)
        x = self.global_avg_pool(x)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.final_expand_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": MobileNetV3.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "final_expand_layer": self.final_expand_layer.config,
            "feature_mix_layer": self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        final_expand_layer = set_layer_from_config(config["final_expand_layer"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = MobileNetV3(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=2,
            use_bn=True,
            act_func="h_swish",
            ops_order="weight_bn_act",
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for (
                k,
                mid_channel,
                out_channel,
                use_se,
                act_func,
                stride,
                expand_ratio,
            ) in block_config_list:
                mb_conv = MBConvLayer(
                    feature_dim,
                    out_channel,
                    k,
                    stride,
                    expand_ratio,
                    mid_channel,
                    act_func,
                    use_se,
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim,
            feature_dim * 6,
            kernel_size=1,
            use_bn=True,
            act_func="h_swish",
            ops_order="weight_bn_act",
        )
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim * 6,
            last_channel,
            kernel_size=1,
            bias=False,
            use_bn=False,
            act_func="h_swish",
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def adjust_cfg(
        cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None
    ):
        for i, (stage_id, block_config_list) in enumerate(cfg.items()):
            for block_config in block_config_list:
                if ks is not None and stage_id != "0":
                    block_config[0] = ks
                if expand_ratio is not None and stage_id != "0":
                    block_config[-1] = expand_ratio
                    block_config[1] = None
                    if stage_width_list is not None:
                        block_config[2] = stage_width_list[i]
            if depth_param is not None and stage_id != "0":
                new_block_config_list = [block_config_list[0]]
                new_block_config_list += [
                    copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)
                ]
                cfg[stage_id] = new_block_config_list
        return cfg

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(MobileNetV3, self).load_state_dict(current_state_dict)


class MobileNetV3Large(MobileNetV3):
    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):
        input_channel = 16
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        cfg = {
            # k, exp, c, se, nl, s, e,
            "0": [
                [3, 16, 16, False, "relu", 1, 1],
            ],
            "1": [
                [3, 64, 24, False, "relu", 2, None],  # 4
                [3, 72, 24, False, "relu", 1, None],  # 3
            ],
            "2": [
                [5, 72, 40, True, "relu", 2, None],  # 3
                [5, 120, 40, True, "relu", 1, None],  # 3
                [5, 120, 40, True, "relu", 1, None],  # 3
            ],
            "3": [
                [3, 240, 80, False, "h_swish", 2, None],  # 6
                [3, 200, 80, False, "h_swish", 1, None],  # 2.5
                [3, 184, 80, False, "h_swish", 1, None],  # 2.3
                [3, 184, 80, False, "h_swish", 1, None],  # 2.3
            ],
            "4": [
                [3, 480, 112, True, "h_swish", 1, None],  # 6
                [3, 672, 112, True, "h_swish", 1, None],  # 6
            ],
            "5": [
                [5, 672, 160, True, "h_swish", 2, None],  # 6
                [5, 960, 160, True, "h_swish", 1, None],  # 6
                [5, 960, 160, True, "h_swish", 1, None],  # 6
            ],
        }

        cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
        # width multiplier on mobile setting, change `exp: 1` and `c: 2`
        for stage_id, block_config_list in cfg.items():
            for block_config in block_config_list:
                if block_config[1] is not None:
                    block_config[1] = make_divisible(
                        block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
                    )
                block_config[2] = make_divisible(
                    block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
                )

        (
            first_conv,
            blocks,
            final_expand_layer,
            feature_mix_layer,
            classifier,
        ) = self.build_net_via_cfg(
            cfg, input_channel, last_channel, n_classes, dropout_rate
        )
        super(MobileNetV3Large, self).__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        # set bn param
        self.set_bn_param(*bn_param)


class MobileNetV3_Cifar(MyNetwork):
    def __init__(
        self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
    ):
        super(MobileNetV3_Cifar, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.final_expand_layer = final_expand_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
        self.feature_mix_layer = feature_mix_layer
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_expand_layer(x)
        x = self.global_avg_pool(x)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.final_expand_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": MobileNetV3_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "final_expand_layer": self.final_expand_layer.config,
            "feature_mix_layer": self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        final_expand_layer = set_layer_from_config(config["final_expand_layer"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = MobileNetV3_Cifar(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=1,
            use_bn=True,
            act_func="h_swish",
            ops_order="weight_bn_act",
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for (
                k,
                mid_channel,
                out_channel,
                use_se,
                act_func,
                stride,
                expand_ratio,
            ) in block_config_list:
                mb_conv = MBConvLayer(
                    feature_dim,
                    out_channel,
                    k,
                    stride,
                    expand_ratio,
                    mid_channel,
                    act_func,
                    use_se,
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim,
            feature_dim * 6,
            kernel_size=1,
            use_bn=True,
            act_func="h_swish",
            ops_order="weight_bn_act",
        )
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim * 6,
            last_channel,
            kernel_size=1,
            bias=False,
            use_bn=False,
            act_func="h_swish",
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def adjust_cfg(
        cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None
    ):
        for i, (stage_id, block_config_list) in enumerate(cfg.items()):
            for block_config in block_config_list:
                if ks is not None and stage_id != "0":
                    block_config[0] = ks
                if expand_ratio is not None and stage_id != "0":
                    block_config[-1] = expand_ratio
                    block_config[1] = None
                    if stage_width_list is not None:
                        block_config[2] = stage_width_list[i]
            if depth_param is not None and stage_id != "0":
                new_block_config_list = [block_config_list[0]]
                new_block_config_list += [
                    copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)
                ]
                cfg[stage_id] = new_block_config_list
        return cfg

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(MobileNetV3_Cifar, self).load_state_dict(current_state_dict)


class MobileNetV3Large_Cifar(MobileNetV3_Cifar):
    def __init__(
        self,
        n_classes=10,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):
        input_channel = 16
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        cfg = {
            # k, exp, c, se, nl, s, e,
            "0": [
                [3, 16, 16, False, "relu", 1, 1],
            ],
            "1": [
                [3, 64, 24, False, "relu", 1, None],  # 4
                [3, 72, 24, False, "relu", 1, None],  # 3
            ],
            "2": [
                [5, 72, 40, True, "relu", 2, None],  # 3
                [5, 120, 40, True, "relu", 1, None],  # 3
                [5, 120, 40, True, "relu", 1, None],  # 3
            ],
            "3": [
                [3, 240, 80, False, "h_swish", 2, None],  # 6
                [3, 200, 80, False, "h_swish", 1, None],  # 2.5
                [3, 184, 80, False, "h_swish", 1, None],  # 2.3
                [3, 184, 80, False, "h_swish", 1, None],  # 2.3
            ],
            "4": [
                [3, 480, 112, True, "h_swish", 1, None],  # 6
                [3, 672, 112, True, "h_swish", 1, None],  # 6
            ],
            "5": [
                [5, 672, 160, True, "h_swish", 2, None],  # 6
                [5, 960, 160, True, "h_swish", 1, None],  # 6
                [5, 960, 160, True, "h_swish", 1, None],  # 6
            ],
        }

        cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
        # width multiplier on mobile setting, change `exp: 1` and `c: 2`
        for stage_id, block_config_list in cfg.items():
            for block_config in block_config_list:
                if block_config[1] is not None:
                    block_config[1] = make_divisible(
                        block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
                    )
                block_config[2] = make_divisible(
                    block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
                )

        (
            first_conv,
            blocks,
            final_expand_layer,
            feature_mix_layer,
            classifier,
        ) = self.build_net_via_cfg(
            cfg, input_channel, last_channel, n_classes, dropout_rate
        )
        super(MobileNetV3Large_Cifar, self).__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        # set bn param
        self.set_bn_param(*bn_param)
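Usage sketch (editor's illustration, not part of the uploaded file): the `*_Cifar` variants above differ from the ImageNet classes only in the stride-1 stem and the stride-1 first stage, so 32x32 inputs keep enough spatial resolution. A minimal shape check, assuming `proard` is importable with the package layout shown in this upload; inputs and arguments are arbitrary:

    # Editor's sketch (assumes the module path shown in this diff).
    import torch
    from proard.classification.networks.mobilenet_v3 import (
        MobileNetV3Large,
        MobileNetV3Large_Cifar,
    )

    imagenet_net = MobileNetV3Large(n_classes=1000)   # stride-2 stem, 224x224 inputs
    cifar_net = MobileNetV3Large_Cifar(n_classes=10)  # stride-1 stem, 32x32 inputs

    print(imagenet_net(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
    print(cifar_net(torch.rand(1, 3, 32, 32)).shape)       # torch.Size([1, 10])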
proard/classification/networks/proxyless_nets.py
ADDED
@@ -0,0 +1,490 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch.nn as nn

from proard.utils.layers import (
    set_layer_from_config,
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    ResidualBlock,
)
from proard.utils import (
    download_url,
    make_divisible,
    val2list,
    MyNetwork,
    MyGlobalAvgPool2d,
)

__all__ = ["proxyless_base_cifar", "proxyless_base", "ProxylessNASNets", "MobileNetV2", "ProxylessNASNets_Cifar", "MobileNetV2_Cifar"]


def proxyless_base(
    net_config=None,
    n_classes=None,
    bn_param=None,
    dropout_rate=None,
    local_path="~/.torch/proxylessnas/",
):
    assert net_config is not None, "Please input a network config"
    if "http" in net_config:
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, "r"))

    if n_classes is not None:
        net_config_json["classifier"]["out_features"] = n_classes
    if dropout_rate is not None:
        net_config_json["classifier"]["dropout_rate"] = dropout_rate

    net = ProxylessNASNets.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


class ProxylessNASNets(MyNetwork):
    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": ProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNets):
    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 2],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=2,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        shortcut = ConvLayer(
                            input_channel,
                            output_channel,
                            kernel_size=1,
                            stride=1,
                            bias=False,
                            act_func=None,
                        )
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(*bn_param)


def proxyless_base_cifar(
    net_config=None,
    n_classes=None,
    bn_param=None,
    dropout_rate=None,
    local_path="~/.torch/proxylessnas/",
):
    assert net_config is not None, "Please input a network config"
    if "http" in net_config:
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, "r"))

    if n_classes is not None:
        net_config_json["classifier"]["out_features"] = n_classes
    if dropout_rate is not None:
        net_config_json["classifier"]["dropout_rate"] = dropout_rate

    net = ProxylessNASNets_Cifar.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


class ProxylessNASNets_Cifar(MyNetwork):
    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets_Cifar, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": ProxylessNASNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets_Cifar, self).load_state_dict(current_state_dict)


class MobileNetV2_Cifar(ProxylessNASNets_Cifar):
    def __init__(
        self,
        n_classes=10,
        width_mult=1.0,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 1],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        shortcut = None  # ConvLayer(input_channel, output_channel, kernel_size=1, stride=1, bias=False, act_func=None)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            stride=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2_Cifar, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(*bn_param)
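Usage sketch (editor's illustration, not part of the uploaded file): every static network here supports the same `config` / `build_from_config` round-trip, which is exactly what `proxyless_base` and `proxyless_base_cifar` rely on after loading a JSON config. A minimal sketch, assuming `proard` is importable; arguments are arbitrary:

    # Editor's sketch (assumes the module path shown in this diff).
    import torch
    from proard.classification.networks.proxyless_nets import (
        MobileNetV2_Cifar,
        ProxylessNASNets_Cifar,
    )

    net = MobileNetV2_Cifar(n_classes=10, width_mult=1.0)
    print(net(torch.rand(2, 3, 32, 32)).shape)   # torch.Size([2, 10])

    # Architecture round-trip: serialize to a plain dict, then rebuild.
    rebuilt = ProxylessNASNets_Cifar.build_from_config(net.config)
    print(type(rebuilt).__name__)                # ProxylessNASNets_Cifar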
proard/classification/networks/resnet_trades.py
ADDED
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d


class ResNet(MyNetwork):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18_trades():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34_trades():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50_trades():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101_trades():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152_trades():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18_trades()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
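Usage sketch (editor's illustration, not part of the uploaded file): these are fixed, non-elastic CIFAR ResNet baselines of the kind commonly paired with TRADES-style adversarial training. A one-step FGSM check of the forward/backward path, assuming the definitions above are in scope; the 8/255 L-inf budget is only a conventional example, and the weights here are untrained:

    # Editor's sketch: a single FGSM step against ResNet18_trades.
    import torch
    import torch.nn.functional as F

    net = ResNet18_trades().eval()
    x = torch.rand(4, 3, 32, 32)            # CIFAR-sized batch in [0, 1]
    y = torch.randint(0, 10, (4,))
    x.requires_grad_(True)

    loss = F.cross_entropy(net(x), y)
    loss.backward()

    # Perturb each pixel by eps in the direction that increases the loss.
    x_adv = (x + (8 / 255) * x.grad.sign()).clamp(0, 1).detach()
    print((net(x_adv).argmax(dim=1) == y).float().mean())  # robust-accuracy proxy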
proard/classification/networks/resnets.py
ADDED
@@ -0,0 +1,490 @@
import torch.nn as nn

from proard.utils.layers import (
    set_layer_from_config,
    ConvLayer,
    IdentityLayer,
    LinearLayer,
)
from proard.utils.layers import ResNetBottleneckBlock, ResidualBlock
from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d

__all__ = ["ResNets", "ResNet50", "ResNet50D", "ResNets_Cifar", "ResNet50_Cifar", "ResNet50D_Cifar"]


class ResNets(MyNetwork):
    BASE_DEPTH_LIST = [2, 2, 4, 2]
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        super(ResNets, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.max_pooling = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
        )
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        x = self.max_pooling(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ""
        for layer in self.input_stem:
            _str += layer.module_str + "\n"
        _str += "max_pooling(ks=3, stride=2)\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": ResNets.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        classifier = set_layer_from_config(config["classifier"])

        input_stem = []
        for layer_config in config["input_stem"]:
            input_stem.append(set_layer_from_config(layer_config))
        blocks = []
        for block_config in config["blocks"]:
            blocks.append(set_layer_from_config(block_config))

        net = ResNets(input_stem, blocks, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResNetBottleneckBlock) and isinstance(
                m.downsample, IdentityLayer
            ):
                m.conv3.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks):
            if (
                not isinstance(block.downsample, IdentityLayer)
                and len(block_index_list) > 0
            ):
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        super(ResNets, self).load_state_dict(state_dict)


class ResNet50(ResNets):
    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(
                width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            ConvLayer(
                3,
                input_channel,
                kernel_size=7,
                stride=2,
                use_bn=True,
                act_func="relu",
                ops_order="weight_bn_act",
            )
        ]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel,
                    width,
                    kernel_size=3,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    act_func="relu",
                    downsample_mode="conv",
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)


class ResNet50D(ResNets):
    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        mid_input_channel = make_divisible(
            input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE
        )
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(
                width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func="relu"),
            ResidualBlock(
                ConvLayer(
                    mid_input_channel,
                    mid_input_channel,
                    3,
                    stride=1,
                    use_bn=True,
                    act_func="relu",
                ),
                IdentityLayer(mid_input_channel, mid_input_channel),
            ),
            ConvLayer(
                mid_input_channel,
                input_channel,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
        ]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel,
                    width,
                    kernel_size=3,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    act_func="relu",
                    downsample_mode="avgpool_conv",
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50D, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)


class ResNets_Cifar(MyNetwork):

    BASE_DEPTH_LIST = [2, 2, 4, 2]
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        super(ResNets_Cifar, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ""
        for layer in self.input_stem:
            _str += layer.module_str + "\n"
        # _str += "max_pooling(ks=3, stride=2)\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": ResNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        classifier = set_layer_from_config(config["classifier"])

        input_stem = []
        for layer_config in config["input_stem"]:
            input_stem.append(set_layer_from_config(layer_config))
        blocks = []
        for block_config in config["blocks"]:
            blocks.append(set_layer_from_config(block_config))

        net = ResNets(input_stem, blocks, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
|
| 318 |
+
if isinstance(m, ResNetBottleneckBlock) and isinstance(
|
| 319 |
+
m.downsample, IdentityLayer
|
| 320 |
+
):
|
| 321 |
+
m.conv3.bn.weight.data.zero_()
|
| 322 |
+
|
| 323 |
+
@property
|
| 324 |
+
def grouped_block_index(self):
|
| 325 |
+
info_list = []
|
| 326 |
+
block_index_list = []
|
| 327 |
+
for i, block in enumerate(self.blocks):
|
| 328 |
+
if (
|
| 329 |
+
not isinstance(block.downsample, IdentityLayer)
|
| 330 |
+
and len(block_index_list) > 0
|
| 331 |
+
):
|
| 332 |
+
info_list.append(block_index_list)
|
| 333 |
+
block_index_list = []
|
| 334 |
+
block_index_list.append(i)
|
| 335 |
+
if len(block_index_list) > 0:
|
| 336 |
+
info_list.append(block_index_list)
|
| 337 |
+
return info_list
|
| 338 |
+
|
| 339 |
+
def load_state_dict(self, state_dict, **kwargs):
|
| 340 |
+
super(ResNets_Cifar, self).load_state_dict(state_dict)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class ResNet50_Cifar(ResNets_Cifar):
|
| 344 |
+
def __init__(
|
| 345 |
+
self,
|
| 346 |
+
n_classes=10,
|
| 347 |
+
width_mult=1.0,
|
| 348 |
+
bn_param=(0.1, 1e-5),
|
| 349 |
+
dropout_rate=0,
|
| 350 |
+
expand_ratio=None,
|
| 351 |
+
depth_param=None,
|
| 352 |
+
):
|
| 353 |
+
|
| 354 |
+
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
|
| 355 |
+
|
| 356 |
+
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
|
| 357 |
+
stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy()
|
| 358 |
+
for i, width in enumerate(stage_width_list):
|
| 359 |
+
stage_width_list[i] = make_divisible(
|
| 360 |
+
width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
depth_list = [3, 4, 6, 3]
|
| 364 |
+
if depth_param is not None:
|
| 365 |
+
for i, depth in enumerate(ResNets_Cifar.BASE_DEPTH_LIST):
|
| 366 |
+
depth_list[i] = depth + depth_param
|
| 367 |
+
|
| 368 |
+
stride_list = [1, 2, 2, 2]
|
| 369 |
+
|
| 370 |
+
# build input stem
|
| 371 |
+
input_stem = [
|
| 372 |
+
ConvLayer(
|
| 373 |
+
3,
|
| 374 |
+
input_channel,
|
| 375 |
+
kernel_size=3,
|
| 376 |
+
stride=1,
|
| 377 |
+
use_bn=True,
|
| 378 |
+
act_func="relu",
|
| 379 |
+
ops_order="weight_bn_act",
|
| 380 |
+
)
|
| 381 |
+
]
|
| 382 |
+
|
| 383 |
+
# blocks
|
| 384 |
+
blocks = []
|
| 385 |
+
for d, width, s in zip(depth_list, stage_width_list, stride_list):
|
| 386 |
+
for i in range(d):
|
| 387 |
+
stride = s if i == 0 else 1
|
| 388 |
+
bottleneck_block = ResNetBottleneckBlock(
|
| 389 |
+
input_channel,
|
| 390 |
+
width,
|
| 391 |
+
kernel_size=3,
|
| 392 |
+
stride=stride,
|
| 393 |
+
expand_ratio=expand_ratio,
|
| 394 |
+
act_func="relu",
|
| 395 |
+
downsample_mode="conv",
|
| 396 |
+
)
|
| 397 |
+
blocks.append(bottleneck_block)
|
| 398 |
+
input_channel = width
|
| 399 |
+
# classifier
|
| 400 |
+
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
|
| 401 |
+
|
| 402 |
+
super(ResNet50_Cifar, self).__init__(input_stem, blocks, classifier)
|
| 403 |
+
|
| 404 |
+
# set bn param
|
| 405 |
+
self.set_bn_param(*bn_param)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class ResNet50D_Cifar(ResNets_Cifar):
|
| 409 |
+
def __init__(
|
| 410 |
+
self,
|
| 411 |
+
n_classes=10,
|
| 412 |
+
width_mult=1.0,
|
| 413 |
+
bn_param=(0.1, 1e-5),
|
| 414 |
+
dropout_rate=0,
|
| 415 |
+
expand_ratio=None,
|
| 416 |
+
depth_param=None,
|
| 417 |
+
):
|
| 418 |
+
|
| 419 |
+
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
|
| 420 |
+
|
| 421 |
+
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
|
| 422 |
+
mid_input_channel = make_divisible(
|
| 423 |
+
input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE
|
| 424 |
+
)
|
| 425 |
+
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
|
| 426 |
+
for i, width in enumerate(stage_width_list):
|
| 427 |
+
stage_width_list[i] = make_divisible(
|
| 428 |
+
width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
depth_list = [3, 4, 6, 3]
|
| 432 |
+
if depth_param is not None:
|
| 433 |
+
for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
|
| 434 |
+
depth_list[i] = depth + depth_param
|
| 435 |
+
|
| 436 |
+
stride_list = [1, 2, 2, 2]
|
| 437 |
+
|
| 438 |
+
# build input stem
|
| 439 |
+
input_stem = [
|
| 440 |
+
ConvLayer(3, mid_input_channel, 3, stride=1, use_bn=True, act_func="relu"),
|
| 441 |
+
ResidualBlock(
|
| 442 |
+
ConvLayer(
|
| 443 |
+
mid_input_channel,
|
| 444 |
+
mid_input_channel,
|
| 445 |
+
3,
|
| 446 |
+
stride=1,
|
| 447 |
+
use_bn=True,
|
| 448 |
+
act_func="relu",
|
| 449 |
+
),
|
| 450 |
+
IdentityLayer(mid_input_channel, mid_input_channel),
|
| 451 |
+
),
|
| 452 |
+
ConvLayer(
|
| 453 |
+
mid_input_channel,
|
| 454 |
+
input_channel,
|
| 455 |
+
3,
|
| 456 |
+
stride=1,
|
| 457 |
+
use_bn=True,
|
| 458 |
+
act_func="relu",
|
| 459 |
+
),
|
| 460 |
+
]
|
| 461 |
+
|
| 462 |
+
# blocks
|
| 463 |
+
blocks = []
|
| 464 |
+
for d, width, s in zip(depth_list, stage_width_list, stride_list):
|
| 465 |
+
for i in range(d):
|
| 466 |
+
stride = s if i == 0 else 1
|
| 467 |
+
bottleneck_block = ResNetBottleneckBlock(
|
| 468 |
+
input_channel,
|
| 469 |
+
width,
|
| 470 |
+
kernel_size=3,
|
| 471 |
+
stride=stride,
|
| 472 |
+
expand_ratio=expand_ratio,
|
| 473 |
+
act_func="relu",
|
| 474 |
+
downsample_mode="avgpool_conv",
|
| 475 |
+
)
|
| 476 |
+
blocks.append(bottleneck_block)
|
| 477 |
+
input_channel = width
|
| 478 |
+
# classifier
|
| 479 |
+
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
|
| 480 |
+
|
| 481 |
+
super(ResNet50D_Cifar, self).__init__(input_stem, blocks, classifier)
|
| 482 |
+
|
| 483 |
+
# set bn param
|
| 484 |
+
self.set_bn_param(*bn_param)
|
| 485 |
+
if __name__=="__main__":
|
| 486 |
+
import torch
|
| 487 |
+
resnet = ResNet50_Cifar()
|
| 488 |
+
x = torch.randn((1,3,32,32))
|
| 489 |
+
resnet(x)
|
| 490 |
+
|
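A minimal usage sketch for the CIFAR variants above (the import path follows this file's location; that the proard package is installed and importable is an assumption):

import torch
from proard.classification.networks.resnets import ResNet50_Cifar, ResNet50D_Cifar

x = torch.randn(2, 3, 32, 32)  # CIFAR-sized batch
# depth_param shifts every stage: per-stage depth = BASE_DEPTH_LIST[i] + depth_param
for depth_param in (0, 1, 2):
    net = ResNet50_Cifar(n_classes=10, width_mult=1.0, depth_param=depth_param)
    assert net(x).shape == (2, 10)

# ResNet50D_Cifar swaps the single 3x3 stem for a three-layer deep stem and uses
# avgpool_conv downsampling, but takes the same constructor arguments.
net_d = ResNet50D_Cifar(n_classes=10, depth_param=1)
print(net_d(x).shape)  # torch.Size([2, 10])
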
proard/classification/networks/wide_resnet.py
ADDED
@@ -0,0 +1,93 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut, only needed when the channel count changes
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(MyNetwork):
    def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 1st sub-block
        self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)
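The defaults above give the WideResNet-34-10 commonly used as a teacher network in adversarial training; a minimal shape check, again assuming the package is importable:

import torch
from proard.classification.networks.wide_resnet import WideResNet

model = WideResNet(depth=34, num_classes=10, widen_factor=10)
# depth=34 -> (34 - 4) / 6 = 5 BasicBlocks per NetworkBlock group
x = torch.randn(2, 3, 32, 32)
logits = model(x)  # blocks 2 and 3 halve the 32x32 input to 8x8 before the 8x8 average pool
print(logits.shape)  # torch.Size([2, 10])
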
proard/classification/run_manager/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .run_config import *
from .run_manager import *
from .distributed_run_manager import *
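Given these star-imports and the __all__ lists of the three modules below, the package level should expose the managers and run configs directly; a sketch of the assumed public surface:

from proard.classification.run_manager import (
    RunManager,                          # re-exported from run_manager
    DistributedRunManager,               # re-exported from distributed_run_manager
    ClassificationRunConfig,             # re-exported from run_config
    DistributedClassificationRunConfig,  # re-exported from run_config
)
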
proard/classification/run_manager/distributed_run_manager.py
ADDED
@@ -0,0 +1,505 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from attacks import create_attack
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from proard.utils import (
    cross_entropy_with_label_smoothing,
    cross_entropy_loss_with_soft_target,
    write_log,
    init_models,
)
from proard.utils import (
    DistributedMetric,
    list_mean,
    get_net_info,
    accuracy,
    AverageMeter,
    mix_labels,
    mix_images,
)
from proard.utils import MyRandomResizedCrop

__all__ = ["DistributedRunManager"]


class DistributedRunManager:
    def __init__(
        self,
        path,
        net,
        run_config,
        hvd_compression,
        backward_steps=1,
        is_root=False,
        init=True,
    ):
        import horovod.torch as hvd

        self.path = path
        self.net = net
        self.run_config = run_config
        self.is_root = is_root

        self.best_acc = 0.0
        self.best_robustness = 0.0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        self.net.cuda()
        cudnn.benchmark = True
        if init and self.is_root:
            init_models(self.net, self.run_config.model_init)
        if self.is_root:
            # print net info
            net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
            with open("%s/net_info.txt" % self.path, "w") as fout:
                fout.write(json.dumps(net_info, indent=4) + "\n")
                try:
                    fout.write(self.net.module_str + "\n")
                except Exception:
                    fout.write("%s do not support `module_str`" % type(self.net))
                fout.write(
                    "%s\n" % self.run_config.data_provider.train.dataset.transform
                )
                fout.write(
                    "%s\n" % self.run_config.data_provider.test.dataset.transform
                )
                fout.write("%s\n" % self.net)

        # criterion
        self.train_criterion = self.run_config.train_criterion_loss
        self.test_criterion = self.run_config.test_criterion_loss
        self.kd_criterion = self.run_config.kd_criterion_loss

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split("#")
            net_params = [
                self.net.get_parameters(
                    keys, mode="exclude"
                ),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode="include"
                ),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer,
            named_parameters=self.net.named_parameters(),
            compression=hvd_compression,
            backward_passes_per_step=backward_steps,
        )

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get("_save_path", None) is None:
            save_path = os.path.join(self.path, "checkpoint")
            os.makedirs(save_path, exist_ok=True)
            self.__dict__["_save_path"] = save_path
        return self.__dict__["_save_path"]

    @property
    def logs_path(self):
        if self.__dict__.get("_logs_path", None) is None:
            logs_path = os.path.join(self.path, "logs")
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__["_logs_path"] = logs_path
        return self.__dict__["_logs_path"]

    @property
    def network(self):
        return self.net

    @network.setter
    def network(self, new_val):
        self.net = new_val

    def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
        if self.is_root:
            write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save & load model & save_config & broadcast """

    def save_config(self, extra_run_config=None, extra_net_config=None):
        if self.is_root:
            run_save_path = os.path.join(self.path, "run.config")
            if not os.path.isfile(run_save_path):
                run_config = self.run_config.config
                if extra_run_config is not None:
                    run_config.update(extra_run_config)
                json.dump(run_config, open(run_save_path, "w"), indent=4)
                print("Run configs dump to %s" % run_save_path)

            try:
                net_save_path = os.path.join(self.path, "net.config")
                net_config = self.net.config
                if extra_net_config is not None:
                    net_config.update(extra_net_config)
                json.dump(net_config, open(net_save_path, "w"), indent=4)
                print("Network configs dump to %s" % net_save_path)
            except Exception:
                print("%s do not support net config" % type(self.net))

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if self.is_root:
            if checkpoint is None:
                checkpoint = {"state_dict": self.net.state_dict()}

            if model_name is None:
                model_name = "checkpoint.pth.tar"

            latest_fname = os.path.join(self.save_path, "latest.txt")
            model_path = os.path.join(self.save_path, model_name)
            with open(latest_fname, "w") as _fout:
                _fout.write(model_path + "\n")
            torch.save(checkpoint, model_path)

            if is_best:
                best_path = os.path.join(self.save_path, "model_best.pth.tar")
                torch.save({"state_dict": checkpoint["state_dict"]}, best_path)

    def load_model(self, model_fname=None):
        if self.is_root:
            latest_fname = os.path.join(self.save_path, "latest.txt")
            if model_fname is None and os.path.exists(latest_fname):
                with open(latest_fname, "r") as fin:
                    model_fname = fin.readline()
                    if model_fname[-1] == "\n":
                        model_fname = model_fname[:-1]
            # noinspection PyBroadException
            try:
                if model_fname is None or not os.path.exists(model_fname):
                    model_fname = "%s/checkpoint.pth.tar" % self.save_path
                    with open(latest_fname, "w") as fout:
                        fout.write(model_fname + "\n")
                print("=> loading checkpoint '{}'".format(model_fname))
                checkpoint = torch.load(model_fname, map_location="cpu")
            except Exception:
                self.write_log(
                    "fail to load checkpoint from %s" % self.save_path, "valid"
                )
                return

            self.net.load_state_dict(checkpoint["state_dict"])
            if "epoch" in checkpoint:
                self.start_epoch = checkpoint["epoch"] + 1
            if "best_acc" in checkpoint:
                self.best_acc = checkpoint["best_acc"]
            if "optimizer" in checkpoint:
                self.optimizer.load_state_dict(checkpoint["optimizer"])

            self.write_log("=> loaded checkpoint '{}'".format(model_fname), "valid")

    # noinspection PyArgumentList
    def broadcast(self):
        import horovod.torch as hvd

        self.start_epoch = hvd.broadcast(
            torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name="start_epoch"
        ).item()
        self.best_acc = hvd.broadcast(
            torch.Tensor(1).fill_(self.best_acc)[0], 0, name="best_acc"
        ).item()
        hvd.broadcast_parameters(self.net.state_dict(), 0)
        hvd.broadcast_optimizer_state(self.optimizer, 0)

    """ metric related """

    def get_metric_dict(self):
        return {
            "top1": DistributedMetric("top1"),
            "top5": DistributedMetric("top5"),
            "robust1": DistributedMetric("robust1"),
            "robust5": DistributedMetric("robust5"),
        }

    def update_metric(self, metric_dict, output, output_adv, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
        metric_dict["top1"].update(acc1[0], output.size(0))
        metric_dict["top5"].update(acc5[0], output.size(0))
        metric_dict["robust1"].update(robust1[0], output.size(0))
        metric_dict["robust5"].update(robust5[0], output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {key: metric_dict[key].avg.item() for key in metric_dict}
        else:
            return [metric_dict[key].avg.item() for key in metric_dict]

    def get_metric_names(self):
        return "top1", "top5", "robust1", "robust5"

    """ train & validate """

    def validate(
        self,
        epoch=0,
        is_test=False,
        run_str="",
        net=None,
        data_loader=None,
        no_logs=False,
    ):
        if net is None:
            net = self.net
        if data_loader is None:
            if is_test:
                data_loader = self.run_config.test_loader
            else:
                data_loader = self.run_config.valid_loader

        net.eval()
        if self.run_config.robust_mode:
            eval_attack = create_attack(
                net,
                self.test_criterion.cuda(),
                self.run_config.attack_type,
                self.run_config.epsilon_test,
                self.run_config.num_steps_test,
                self.run_config.step_size_test,
            )
        losses = DistributedMetric("val_loss")
        metric_dict = self.get_metric_dict()

        with tqdm(
            total=len(data_loader),
            desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
            disable=no_logs or not self.is_root,
        ) as t:
            for i, (images, labels) in enumerate(data_loader):
                images, labels = images.cuda(), labels.cuda()
                # compute output
                output = net(images)
                if self.run_config.robust_mode:
                    with ctx_noparamgrad_and_eval(net):
                        images_adv, _ = eval_attack.perturb(images, labels)
                    output_adv = net(images_adv)
                    loss = self.test_criterion(output_adv, labels)
                else:
                    output_adv = output
                    loss = self.test_criterion(output, labels)

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, output_adv, labels)
                t.set_postfix(
                    {
                        "loss": losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                    }
                )
                t.update(1)
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.net
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list = [], []
            top1_list, top5_list, robust1_list, robust5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                # not sure whether resetting BN running statistics here is good for robustness
                self.reset_running_statistics(net=net)
                loss, (top1, top5, robust1, robust5) = self.validate(
                    epoch, is_test, net=net
                )
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
                robust1_list.append(robust1)
                robust5_list.append(robust5)

            return img_size_list, loss_list, top1_list, top5_list, robust1_list, robust5_list
        else:
            self.reset_running_statistics(net=net)
            loss, (top1, top5, robust1, robust5) = self.validate(epoch, is_test, net=net)
            return (
                [self.run_config.data_provider.active_img_size],
                [loss],
                [top1],
                [top5],
                [robust1],
                [robust5],
            )

    def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
        self.net.train()
        self.run_config.train_loader.sampler.set_epoch(
            epoch
        )  # required by distributed sampler
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = DistributedMetric("train_loss")
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(
            total=nBatch,
            desc="Train Epoch #{}".format(epoch + 1),
            disable=not self.is_root,
        ) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer,
                        warmup_epochs * nBatch,
                        nBatch,
                        epoch,
                        i,
                        warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(
                        self.optimizer, epoch - warmup_epochs, i, nBatch
                    )

                images, labels = images.cuda(), labels.cuda()
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    random.seed(int("%d%.3d" % (i, epoch)))
                    lam = random.betavariate(
                        self.run_config.mixup_alpha, self.run_config.mixup_alpha
                    )
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels,
                        lam,
                        self.run_config.data_provider.n_classes,
                        self.run_config.label_smoothing,
                    )

                # soft target
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)
                if args.teacher_model is None:
                    if self.run_config.robust_mode:
                        loss = self.train_criterion(
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                            self.run_config.distance_train,
                        )
                        loss_type = self.train_criterion.__name__
                    else:
                        loss = F.cross_entropy(output, labels)
                        loss_type = "ce"

                else:
                    if self.run_config.robust_mode:
                        loss = self.kd_criterion(
                            args.teacher_model,
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                        )
                        loss_type = self.kd_criterion.__name__
                    else:
                        if args.kd_type == "ce":
                            kd_loss = cross_entropy_loss_with_soft_target(
                                output, soft_label
                            )
                        else:
                            kd_loss = F.mse_loss(output, soft_logits)
                        # distillation term plus the standard cross-entropy on hard labels
                        loss = args.kd_ratio * kd_loss + F.cross_entropy(output, labels)
                        loss_type = "%.1fkd+ce" % args.kd_ratio

                # update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, output, target)

                t.set_postfix(
                    {
                        "loss": losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                        "lr": new_lr,
                        "loss_type": loss_type,
                        "data_time": data_time.avg,
                    }
                )
                t.update(1)
                end = time.time()
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epochs=5, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
            train_loss, (train_top1, train_top5, train_robust1, train_robust5) = self.train_one_epoch(
                args, epoch, warmup_epochs, warmup_lr
            )
            img_size, val_loss, val_top1, val_top5, val_robust1, val_robust5 = self.validate_all_resolution(
                epoch, is_test=False
            )

            is_best = list_mean(val_top1) > self.best_acc
            is_best_robust = list_mean(val_robust1) > self.best_robustness
            self.best_robustness = max(self.best_robustness, list_mean(val_robust1))
            self.best_acc = max(self.best_acc, list_mean(val_top1))
            if self.is_root:
                val_log = (
                    "[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t"
                    "{8} robust {10:.3f}\t{9} robust {11:.3f}\t"
                    "Train {6} {top1:.3f}\tloss {train_loss:.3f}\t{8} {robust1:.3f}\t".format(
                        epoch + 1 - warmup_epochs,
                        self.run_config.n_epochs,
                        list_mean(val_loss),
                        list_mean(val_top1),
                        self.best_acc,
                        list_mean(val_top5),
                        *self.get_metric_names(),
                        list_mean(val_robust1),
                        list_mean(val_robust5),
                        top1=train_top1,
                        train_loss=train_loss,
                        robust1=train_robust1,
                    )
                )
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += "(%d, %.3f), " % (i_s, v_a)
                self.write_log(val_log, prefix="valid", should_print=False)

                self.save_model(
                    {
                        "epoch": epoch,
                        "best_acc": self.best_acc,
                        "optimizer": self.optimizer.state_dict(),
                        "state_dict": self.net.state_dict(),
                    },
                    is_best=is_best,
                )

    def reset_running_statistics(
        self, net=None, subset_size=4000, subset_batch_size=200, data_loader=None
    ):
        from proard.classification.elastic_nn.utils import set_running_statistics

        if net is None:
            net = self.net
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(
                subset_size, subset_batch_size
            )

        set_running_statistics(net, data_loader)
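validate() above delegates adversarial-example generation to attacks.create_attack, which is not part of this upload. As a self-contained illustration of what the eval_attack.perturb(images, labels) call is assumed to do for the default linf-pgd attack type, here is a minimal PGD sketch; the class name, defaults, and the (adv, None) return shape are assumptions, not the attacks API:

import torch
import torch.nn.functional as F

class SimplePGD:
    # l_inf PGD with random start; mirrors the call shape used in validate()
    def __init__(self, model, epsilon=0.031, num_steps=20, step_size=0.0078):
        self.model, self.epsilon = model, epsilon
        self.num_steps, self.step_size = num_steps, step_size

    def perturb(self, images, labels):
        x_adv = images + torch.empty_like(images).uniform_(-self.epsilon, self.epsilon)
        x_adv = x_adv.clamp(0, 1).detach()
        for _ in range(self.num_steps):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(self.model(x_adv), labels)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv.detach() + self.step_size * grad.sign()       # ascend the loss
            x_adv = images + (x_adv - images).clamp(-self.epsilon, self.epsilon)
            x_adv = x_adv.clamp(0, 1).detach()                          # stay in the valid image range
        return x_adv, None  # validate() unpacks `images_adv, _ = eval_attack.perturb(...)`

# usage mirrors the robust branch of validate():
#   images_adv, _ = SimplePGD(net).perturb(images, labels)
#   output_adv = net(images_adv)
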
proard/classification/run_manager/run_config.py
ADDED
@@ -0,0 +1,414 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from proard.utils import calc_learning_rate, build_optimizer
from proard.classification.data_providers import ImagenetDataProvider
from proard.classification.data_providers import Cifar10DataProvider
from proard.classification.data_providers import Cifar100DataProvider
from robust_loss.trades import trades_loss
from robust_loss.adaad import adaad_loss
from robust_loss.ard import ard_loss
from robust_loss.hat import hat_loss
from robust_loss.mart import mart_loss
from robust_loss.sat import sat_loss
from robust_loss.rslad import rslad_loss
import torch

__all__ = ["RunConfig", "ClassificationRunConfig", "DistributedClassificationRunConfig"]


class RunConfig:
    def __init__(
        self,
        n_epochs,
        init_lr,
        lr_schedule_type,
        lr_schedule_param,
        dataset,
        train_batch_size,
        test_batch_size,
        valid_size,
        opt_type,
        opt_param,
        weight_decay,
        label_smoothing,
        no_decay_keys,
        mixup_alpha,
        model_init,
        validation_frequency,
        print_frequency,
    ):
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.lr_schedule_type = lr_schedule_type
        self.lr_schedule_param = lr_schedule_param

        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.valid_size = valid_size

        self.opt_type = opt_type
        self.opt_param = opt_param
        self.weight_decay = weight_decay
        self.label_smoothing = label_smoothing
        self.no_decay_keys = no_decay_keys

        self.mixup_alpha = mixup_alpha

        self.model_init = model_init
        self.validation_frequency = validation_frequency
        self.print_frequency = print_frequency

    @property
    def config(self):
        config = {}
        for key in self.__dict__:
            if not key.startswith("_"):
                config[key] = self.__dict__[key]
        return config

    def copy(self):
        return RunConfig(**self.config)

    """ learning rate """

    def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """adjust the learning rate of a given optimizer and return the new learning rate"""
        new_lr = calc_learning_rate(
            epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return new_lr

    def warmup_adjust_learning_rate(
        self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0
    ):
        T_cur = epoch * nBatch + batch + 1
        new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return new_lr

    """ data provider """

    @property
    def data_provider(self):
        raise NotImplementedError

    @property
    def train_loader(self):
        return self.data_provider.train

    @property
    def valid_loader(self):
        return self.data_provider.valid

    @property
    def test_loader(self):
        return self.data_provider.test

    def random_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        return self.data_provider.build_sub_train_loader(
            n_images, batch_size, num_worker, num_replicas, rank
        )

    """ optimizer """

    def build_optimizer(self, net_params):
        return build_optimizer(
            net_params,
            self.opt_type,
            self.opt_param,
            self.init_lr,
            self.weight_decay,
            self.no_decay_keys,
        )


class ClassificationRunConfig(RunConfig):
    def __init__(
        self,
        n_epochs=150,
        init_lr=0.05,
        lr_schedule_type="cosine",
        lr_schedule_param=None,
        dataset="imagenet",  # 'cifar10' or 'cifar100'
        train_batch_size=256,
        test_batch_size=500,
        valid_size=None,
        opt_type="sgd",
        opt_param=None,
        weight_decay=4e-5,
        label_smoothing=0.1,
        no_decay_keys=None,
        mixup_alpha=None,
        model_init="he_fout",
        validation_frequency=1,
        print_frequency=10,
        n_worker=32,
        resize_scale=0.08,
        distort_color="tf",
        image_size=224,  # 32 for CIFAR
        robust_mode=False,
        epsilon_train=0.031,
        num_steps_train=10,
        step_size_train=0.0078,
        clip_min_train=0,
        clip_max_train=1,
        const_init_train=False,
        beta_train=6.0,
        distance_train="l_inf",
        epsilon_test=0.031,
        num_steps_test=20,
        step_size_test=0.0078,
        clip_min_test=0,
        clip_max_test=1,
        const_init_test=False,
        beta_test=6.0,
        distance_test="l_inf",
        train_criterion="trades",
        test_criterion="ce",
        kd_criterion="rslad",
        attack_type="linf-pgd",
        **kwargs
    ):
        super(ClassificationRunConfig, self).__init__(
            n_epochs,
            init_lr,
            lr_schedule_type,
            lr_schedule_param,
            dataset,
            train_batch_size,
            test_batch_size,
            valid_size,
            opt_type,
            opt_param,
            weight_decay,
            label_smoothing,
            no_decay_keys,
            mixup_alpha,
            model_init,
            validation_frequency,
            print_frequency,
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.epsilon_train = epsilon_train
        self.num_steps_train = num_steps_train
        self.step_size_train = step_size_train
        self.clip_min_train = clip_min_train
        self.clip_max_train = clip_max_train
        self.const_init_train = const_init_train
        self.beta_train = beta_train
        self.distance_train = distance_train
        self.epsilon_test = epsilon_test
        self.num_steps_test = num_steps_test
        self.step_size_test = step_size_test
        self.clip_min_test = clip_min_test
        self.clip_max_test = clip_max_test
        self.const_init_test = const_init_test
        self.beta_test = beta_test
        self.distance_test = distance_test
        self.train_criterion = train_criterion
        self.test_criterion = test_criterion
        self.kd_criterion = kd_criterion
        self.attack_type = attack_type
        self.robust_mode = robust_mode

    @property
    def data_provider(self):
        if self.__dict__.get("_data_provider", None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            elif self.dataset == Cifar10DataProvider.name():
                DataProviderClass = Cifar10DataProvider
            elif self.dataset == Cifar100DataProvider.name():
                DataProviderClass = Cifar100DataProvider
            else:
                raise NotImplementedError
            self.__dict__["_data_provider"] = DataProviderClass(
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
        return self.__dict__["_data_provider"]

    @property
    def train_criterion_loss(self):
        if self.train_criterion == "trades":
            return trades_loss
        elif self.train_criterion == "mart":
            return mart_loss
        elif self.train_criterion == "sat":
            return sat_loss
        elif self.train_criterion == "hat":
            return hat_loss

    @property
    def test_criterion_loss(self):
        if self.test_criterion == "ce":
            return torch.nn.CrossEntropyLoss()

    @property
    def kd_criterion_loss(self):
        if self.kd_criterion == "ard":
            return ard_loss
        elif self.kd_criterion == "adaad":
            return adaad_loss
        elif self.kd_criterion == "rslad":
            return rslad_loss


class DistributedClassificationRunConfig(ClassificationRunConfig):
    def __init__(
        self,
        n_epochs=150,
        init_lr=0.05,
        lr_schedule_type="cosine",
        lr_schedule_param=None,
        dataset="imagenet",
        train_batch_size=64,
        test_batch_size=64,
        valid_size=None,
        opt_type="sgd",
        opt_param=None,
        weight_decay=4e-5,
        label_smoothing=0.1,
        no_decay_keys=None,
        mixup_alpha=None,
        model_init="he_fout",
        validation_frequency=1,
        print_frequency=10,
        n_worker=8,
        resize_scale=0.08,
        distort_color="tf",
        image_size=224,
        robust_mode=False,
        epsilon=0.031,
        num_steps=10,
        step_size=0.0078,
        clip_min=0,
        clip_max=1,
        const_init=False,
        beta=6.0,
        distance="l_inf",
        train_criterion="trades",
        test_criterion="ce",
        kd_criterion="rslad",
        attack_type="linf-pgd",
        **kwargs
    ):
        # the shared attack settings fill the *_train slots of the parent signature
        # and, with twice the number of steps, the *_test slots
        super(DistributedClassificationRunConfig, self).__init__(
            n_epochs,
            init_lr,
            lr_schedule_type,
            lr_schedule_param,
            dataset,
            train_batch_size,
            test_batch_size,
            valid_size,
            opt_type,
            opt_param,
            weight_decay,
            label_smoothing,
            no_decay_keys,
            mixup_alpha,
            model_init,
            validation_frequency,
            print_frequency,
            n_worker,
            resize_scale,
            distort_color,
            image_size,
            robust_mode,
            epsilon,
            num_steps,
            step_size,
            clip_min,
            clip_max,
            const_init,
            beta,
            distance,
            epsilon,
            num_steps * 2,
            step_size,
            clip_min,
            clip_max,
            const_init,
            beta,
            distance,
            train_criterion,
            test_criterion,
            kd_criterion,
            attack_type,
            **kwargs
        )

        self._num_replicas = kwargs["num_replicas"]
        self._rank = kwargs["rank"]

    @property
    def data_provider(self):
        if self.__dict__.get("_data_provider", None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            elif self.dataset == Cifar10DataProvider.name():
                DataProviderClass = Cifar10DataProvider
            elif self.dataset == Cifar100DataProvider.name():
                DataProviderClass = Cifar100DataProvider
            else:
                raise NotImplementedError
            if self.dataset == "imagenet":
                self.__dict__["_data_provider"] = DataProviderClass(
                    train_batch_size=self.train_batch_size,
                    test_batch_size=self.test_batch_size,
                    valid_size=self.valid_size,
                    n_worker=self.n_worker,
                    resize_scale=self.resize_scale,
                    distort_color=self.distort_color,
                    image_size=self.image_size,
                    num_replicas=self._num_replicas,
                    rank=self._rank,
                )
            else:
                self.__dict__["_data_provider"] = DataProviderClass(
                    train_batch_size=self.train_batch_size,
                    test_batch_size=self.test_batch_size,
                    valid_size=self.valid_size,
                    n_worker=self.n_worker,
                    resize_scale=None,
                    distort_color=None,
                    image_size=self.image_size,
                    num_replicas=self._num_replicas,
                    rank=self._rank,
                )
        return self.__dict__["_data_provider"]

    @property
    def train_criterion_loss(self):
        if self.train_criterion == "trades":
            return trades_loss
        elif self.train_criterion == "mart":
            return mart_loss
        elif self.train_criterion == "sat":
            return sat_loss
        elif self.train_criterion == "hat":
            return hat_loss

    @property
    def test_criterion_loss(self):
        if self.test_criterion == "ce":
            return torch.nn.CrossEntropyLoss()

    @property
    def kd_criterion_loss(self):
        if self.kd_criterion == "ard":
            return ard_loss
        elif self.kd_criterion == "adaad":
            return adaad_loss
        elif self.kd_criterion == "rslad":
            return rslad_loss
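For orientation, a sketch of how these options compose for robust CIFAR-10 training; the values are illustrative, and it assumes proard and the external robust_loss package are importable:

from proard.classification.run_manager.run_config import ClassificationRunConfig

run_config = ClassificationRunConfig(
    dataset="cifar10",
    image_size=32,
    train_batch_size=128,
    robust_mode=True,
    train_criterion="trades",  # resolved to trades_loss by train_criterion_loss
    kd_criterion="rslad",      # resolved to rslad_loss by kd_criterion_loss
    attack_type="linf-pgd",
    epsilon_train=0.031,       # roughly 8/255 l_inf budget
    num_steps_test=20,         # evaluation PGD runs more steps than training PGD
)
print(run_config.train_criterion_loss.__name__)  # trades_loss
print(run_config.config["dataset"])              # cifar10
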
proard/classification/run_manager/run_manager.py
ADDED
@@ -0,0 +1,484 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from robust_loss.rslad import rslad_inner_loss, kl_loss
from robust_loss.trades import trades_loss
from attacks import create_attack
from proard.utils import (
    get_net_info,
    cross_entropy_loss_with_soft_target,
    cross_entropy_with_label_smoothing,
)
from proard.utils import (
    AverageMeter,
    accuracy,
    write_log,
    mix_images,
    mix_labels,
    init_models,
)
from proard.utils import MyRandomResizedCrop

__all__ = ["RunManager"]


class RunManager:
    def __init__(
        self, path, net, run_config, init=True, measure_latency=None, no_gpu=False
    ):
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.best_robustness = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device("cuda")
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")
        # initialize model (default)
        if init:
            init_models(net, run_config.model_init)

        # net info
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, True
        )
        with open("%s/net_info.txt" % self.path, "w") as fout:
            fout.write(json.dumps(net_info, indent=4) + "\n")
            # noinspection PyBroadException
            try:
                fout.write(self.network.module_str + "\n")
            except Exception:
                pass
            fout.write("%s\n" % self.run_config.data_provider.train.dataset.transform)
            fout.write("%s\n" % self.run_config.data_provider.test.dataset.transform)
            fout.write("%s\n" % self.network)

        self.train_criterion = self.run_config.train_criterion_loss
        self.test_criterion = self.run_config.test_criterion_loss
        self.kd_criterion = self.run_config.kd_criterion_loss

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split("#")
            net_params = [
                self.network.get_parameters(
                    keys, mode="exclude"
                ),  # parameters with weight decay
                self.network.get_parameters(
                    keys, mode="include"
                ),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get("_save_path", None) is None:
            save_path = os.path.join(self.path, "checkpoint")
            os.makedirs(save_path, exist_ok=True)
            self.__dict__["_save_path"] = save_path
        return self.__dict__["_save_path"]

    @property
    def logs_path(self):
        if self.__dict__.get("_logs_path", None) is None:
            logs_path = os.path.join(self.path, "logs")
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__["_logs_path"] = logs_path
        return self.__dict__["_logs_path"]

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {"state_dict": self.network.state_dict()}

        if model_name is None:
            model_name = "checkpoint.pth.tar"

        checkpoint[
            "dataset"
        ] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, "latest.txt")
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, "w") as fout:
            fout.write(model_path + "\n")
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, "model_best.pth.tar")
            torch.save({"state_dict": checkpoint["state_dict"]}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline()
                if model_fname[-1] == "\n":
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = "%s/checkpoint.pth.tar" % self.save_path
                with open(latest_fname, "w") as fout:
                    fout.write(model_fname + "\n")
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location="cpu")
        except Exception:
            print("fail to load checkpoint from %s" % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
        if "best_acc" in checkpoint:
            self.best_acc = checkpoint["best_acc"]
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """dump run_config and net_config to the model_folder"""
        run_save_path = os.path.join(self.path, "run.config")
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, "w"), indent=4)
            print("Run configs dump to %s" % run_save_path)

        try:
            net_save_path = os.path.join(self.path, "net.config")
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, "w"), indent=4)
            print("Network configs dump to %s" % net_save_path)
        except Exception:
            print("%s do not support net config" % type(self.network))

    """ metric related """

    def get_metric_dict(self):
        return {
            "top1": AverageMeter(),
            "top5": AverageMeter(),
            "robust1": AverageMeter(),
            "robust5": AverageMeter(),
        }

    def update_metric(self, metric_dict, output, output_adv, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
        metric_dict["top1"].update(acc1[0].item(), output.size(0))
        metric_dict["top5"].update(acc5[0].item(), output.size(0))
        metric_dict["robust1"].update(robust1[0].item(), output.size(0))
        metric_dict["robust5"].update(robust5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {key: metric_dict[key].avg for key in metric_dict}
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return "top1", "top5", "robust1", "robust5"

    """ train and test """

    def validate(
        self,
        epoch=0,
        is_test=False,
        run_str="",
        net=None,
        data_loader=None,
        no_logs=False,
        train_mode=False,
    ):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)
        if data_loader is None:
            data_loader = (
                self.run_config.test_loader if is_test else self.run_config.valid_loader
            )

        if train_mode:
            net.train()
        else:
            net.eval()
        if self.run_config.robust_mode:
            eval_attack = create_attack(
                net,
                self.test_criterion.cuda(),
                self.run_config.attack_type,
                self.run_config.epsilon_test,
                self.run_config.num_steps_test,
                self.run_config.step_size_test,
            )
        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        with tqdm(
            total=len(data_loader),
            desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
            disable=no_logs,
        ) as t:
            for i, (images, labels) in enumerate(data_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output
                output = net(images)
                if self.run_config.robust_mode:
                    with ctx_noparamgrad_and_eval(net):
                        images_adv, _ = eval_attack.perturb(images, labels)
                    output_adv = net(images_adv)
                    loss = nn.CrossEntropyLoss()(output_adv, labels)
                else:
                    output_adv = output
                    loss = nn.CrossEntropyLoss()(output, labels)

                # measure accuracy and record loss
                self.update_metric(metric_dict, output, output_adv, labels)

                losses.update(loss.item(), images.size(0))
                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                    }
                )
                t.update(1)
        return losses.avg, self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list, robust1_list, robust5_list = [], [], [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5, robust1, robust5) = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
                robust1_list.append(robust1)
                robust5_list.append(robust5)
            return img_size_list, loss_list, top1_list, top5_list, robust1_list, robust5_list
        else:
            loss, (top1, top5, robust1, robust5) = self.validate(epoch, is_test, net=net)
            return (
                [self.run_config.data_provider.active_img_size],
                [loss],
                [top1],
                [top5],
                [robust1],
                [robust5],
            )

    def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
        # switch to train mode
        self.net.train()
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(
            total=nBatch,
            desc="{} Train Epoch #{}".format(self.run_config.dataset, epoch + 1),
        ) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer,
                        warmup_epochs * nBatch,
                        nBatch,
                        epoch,
                        i,
                        warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(
                        self.optimizer, epoch - warmup_epochs, i, nBatch
                    )

                images, labels = images.to(self.device), labels.to(self.device)
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    lam = random.betavariate(
                        self.run_config.mixup_alpha, self.run_config.mixup_alpha
                    )
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels,
                        lam,
                        self.run_config.data_provider.n_classes,
                        self.run_config.label_smoothing,
                    )

                # soft target
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)

                if args.teacher_model is None:
                    if self.run_config.robust_mode:
                        loss = self.train_criterion(
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                            self.run_config.distance_train,
                        )
                        loss_type = self.run_config.train_criterion
                    else:
                        loss = torch.nn.CrossEntropyLoss()(output, labels)
                        loss_type = "ce"

                else:
                    if self.run_config.robust_mode:
                        loss = self.kd_criterion(
                            args.teacher_model,
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                        )
                        loss_type = self.run_config.train_criterion
                    else:
                        if args.kd_type == "ce":
                            kd_loss = cross_entropy_loss_with_soft_target(
                                output, soft_label
                            )
                        else:
                            kd_loss = F.mse_loss(output, soft_logits)
                        # distillation term combined with the standard cross-entropy term
                        loss = args.kd_ratio * kd_loss + torch.nn.CrossEntropyLoss()(output, labels)
                        loss_type = "%.1fkd+ce" % args.kd_ratio

                # compute gradient and do SGD step
                self.net.zero_grad()  # or self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss.item(), images.size(0))
                self.update_metric(metric_dict, output, output, target)

                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                        "lr": new_lr,
                        "loss_type": loss_type,
                        "data_time": data_time.avg,
                    }
                )
                t.update(1)
                end = time.time()
        return losses.avg, self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epoch=0, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
            train_loss, (train_top1, train_top5, train_robust1, train_robust5) = self.train_one_epoch(
                args, epoch, warmup_epoch, warmup_lr
            )

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                img_size, val_loss, val_acc, val_acc5, val_robust, val_robust5 = self.validate_all_resolution(
                    epoch=epoch, is_test=False
                )

                is_best = np.mean(val_acc) > self.best_acc
                is_best_robust = np.mean(val_robust) > self.best_robustness
                self.best_acc = max(self.best_acc, np.mean(val_acc))
                self.best_robustness = max(self.best_robustness, np.mean(val_robust))
                val_log = "Valid [{0}/{1}]\tloss {2:.3f} \t{7} {3:.3f} ({5:.3f}) \t{8} {4:.3f} ({6:.3f})".format(
                    epoch + 1 - warmup_epoch,
                    self.run_config.n_epochs,
                    np.mean(val_loss),
                    np.mean(val_acc),
                    np.mean(val_robust),
                    self.best_acc,
                    self.best_robustness,
                    self.get_metric_names()[0],
                    self.get_metric_names()[2],
                )
                val_log += "\t{2} {0:.3f} \tTrain {1} {top1:.3f}\t {3} {robust:.3f} \t loss {train_loss:.3f}\t".format(
                    np.mean(val_acc5),
                    *self.get_metric_names(),
                    top1=train_top1,
                    robust=train_robust1,
                    train_loss=train_loss,
                )
                for i_s, v_a in zip(img_size, val_acc):
                    val_log += "(%d, %.3f), " % (i_s, v_a)
                self.write_log(val_log, prefix="valid", should_print=False)
            else:
                is_best = False
                is_best_robust = False

            self.save_model(
                {
                    "epoch": epoch,
                    "best_acc": self.best_acc,
                    "optimizer": self.optimizer.state_dict(),
                    "state_dict": self.network.state_dict(),
                },
                is_best=is_best,
            )

    def reset_running_statistics(
        self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None
    ):
        from proard.classification.elastic_nn.utils import set_running_statistics

        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(
                subset_size, subset_batch_size
            )
        set_running_statistics(net, data_loader)
proard/model_zoo.py
ADDED
@@ -0,0 +1,162 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch
import gdown

from proard.classification.networks import get_net_by_name, ResNet50
from proard.classification.elastic_nn.networks import (
    DYNResNets,
    DYNMobileNetV3,
    DYNProxylessNASNets,
    DYNProxylessNASNets_Cifar,
    DYNMobileNetV3_Cifar,
    DYNResNets_Cifar,
)
from proard.classification.networks import (
    WideResNet,
    ResNet50_Cifar,
    ResNet50,
    MobileNetV3_Cifar,
    MobileNetV3Large_Cifar,
    MobileNetV3Large,
    ProxylessNASNets_Cifar,
    ProxylessNASNets,
    MobileNetV2_Cifar,
    MobileNetV2,
)

__all__ = [
    "DYN_net",
]


def DYN_net(net_id, robust_mode, dataset, train_criterion, pretrained=True, run_config=None, WPS=False, base=False):
    if net_id == "ResNet50":
        if not base:
            if dataset == "cifar10" or dataset == "cifar100":
                net = DYNResNets_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    depth_list=[0, 1, 2],
                    expand_ratio_list=[0.2, 0.25, 0.35],
                    width_mult_list=[0.65, 0.8, 1.0],
                )
            else:
                net = DYNResNets(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    depth_list=[0, 1, 2],
                    expand_ratio_list=[0.2, 0.25, 0.35],
                    width_mult_list=[0.65, 0.8, 1.0],
                )
        else:
            if dataset == "cifar10" or dataset == "cifar100":
                net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes)
            else:
                net = ResNet50(n_classes=run_config.data_provider.n_classes)

    elif net_id == "MBV3":
        if not base:
            if dataset == "cifar10" or dataset == "cifar100":
                net = DYNMobileNetV3_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
            else:
                net = DYNMobileNetV3(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
        else:
            if dataset == "cifar10" or dataset == "cifar100":
                net = MobileNetV3Large_Cifar(n_classes=run_config.data_provider.n_classes)
            else:
                net = MobileNetV3Large(n_classes=run_config.data_provider.n_classes)
    elif net_id == "ProxylessNASNet":
        if not base:
            if dataset == "cifar10" or dataset == "cifar100":
                net = DYNProxylessNASNets_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
            else:
                net = DYNProxylessNASNets(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
        else:
            if dataset == "cifar10" or dataset == "cifar100":
                net = ProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes)
            else:
                net = ProxylessNASNets(n_classes=run_config.data_provider.n_classes)
    elif net_id == "MBV2":
        if not base:
            if dataset == "cifar10" or dataset == "cifar100":
                net = DYNProxylessNASNets_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    base_stage_width="google",
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
            else:
                net = DYNProxylessNASNets(
                    n_classes=run_config.data_provider.n_classes,
                    dropout_rate=0,
                    base_stage_width="google",
                    width_mult=1.0,
                    ks_list=[3, 5, 7],
                    expand_ratio_list=[3, 4, 6],
                    depth_list=[2, 3, 4],
                )
        else:
            if dataset == "cifar10" or dataset == "cifar100":
                net = MobileNetV2_Cifar(n_classes=run_config.data_provider.n_classes)
            else:
                net = MobileNetV2(n_classes=run_config.data_provider.n_classes)
    elif net_id == "WideResNet":
        if dataset == "cifar10" or dataset == "cifar100":
            net = WideResNet(num_classes=run_config.data_provider.n_classes)
        else:
            raise ValueError("Not supported: %s" % net_id)

    else:
        raise ValueError("Not supported: %s" % net_id)

    if pretrained and not WPS and not base:
        if net_id == "ResNet50":
            if robust_mode:
                pt_path = "exp/robust/" + dataset + "/" + net_id + "/" + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
            else:
                pt_path = "exp/" + dataset + "/" + net_id + "/" + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
        else:
            if robust_mode:
                pt_path = "exp/robust/" + dataset + "/" + net_id + "/" + train_criterion + "/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
            else:
                pt_path = "exp/" + dataset + "/" + net_id + "/" + train_criterion + "/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
    elif pretrained and WPS and not base:
        if net_id == "ResNet50":
            if robust_mode:
                pt_path = "exp/robust/WPS/" + dataset + "/" + net_id + "/" + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
            else:
                pt_path = "exp/WPS/" + dataset + "/" + net_id + "/" + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
        else:
            if robust_mode:
                pt_path = "exp/robust/WPS/" + dataset + "/" + net_id + "/" + train_criterion + "/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
            else:
                pt_path = "exp/WPS/" + dataset + "/" + net_id + "/" + train_criterion + "/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
    else:
        if not base:
            pt_path = "exp/robust/teacher/" + dataset + "/" + net_id + "/" + train_criterion + "/checkpoint/model_best.pth.tar"
        else:
            pt_path = "exp/robust/base/" + dataset + "/" + net_id + "/" + train_criterion + "/checkpoint/model_best.pth.tar"
    print(pt_path)
    init = torch.load(pt_path, map_location="cuda")["state_dict"]
    # from collections import OrderedDict
    # new_state_dict = OrderedDict()
    # for k, v in init.items():
    #     name = k[7:]  # remove `module.`
    #     new_state_dict[name] = v
    net.load_state_dict(init)
    return net
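
A sketch of loading a dynamic supernet through the factory above; the run_config is assumed to carry a data provider, and the checkpoint must already exist under the `exp/` layout hard-coded in `DYN_net`:

# Hypothetical call, assuming a prepared run_config and on-disk checkpoints:
dyn_net = DYN_net(
    "MBV3",                 # one of: ResNet50, MBV3, ProxylessNASNet, MBV2, WideResNet
    robust_mode=True,
    dataset="cifar10",
    train_criterion="trades",
    pretrained=True,
    run_config=run_config,
)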
proard/nas/__init__.py
ADDED
File without changes
proard/nas/accuracy_predictor/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .acc_dataset import *
from .acc_predictor import *
from .arch_encoder import *
from .rob_dataset import *
from .rob_predictor import *
from .acc_rob_dataset import *
from .acc_rob_predictor import *
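
With these wildcard re-exports, callers can import the predictor components from the package root, e.g.:

from proard.nas.accuracy_predictor import AccuracyDataset, AccuracyPredictor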
proard/nas/accuracy_predictor/acc_dataset.py
ADDED
@@ -0,0 +1,213 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data

from proard.utils import list_mean

__all__ = ["net_setting2id", "net_id2setting", "AccuracyDataset"]


def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)


class RegDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        super(RegDataset, self).__init__()
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return self.inputs.size(0)


class AccuracyDataset:
    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, "net_id.dict")

    @property
    def acc_src_folder(self):
        return os.path.join(self.path, "src")

    @property
    def acc_dict_path(self):
        return os.path.join(self.path, "src/acc.dict")

    # TODO: support parallel building
    def build_acc_dataset(
        self, run_manager, dyn_network, n_arch=2000, image_size_list=None
    ):
        # load net_id_list, randomly sample if it does not exist
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = dyn_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)

        image_size_list = (
            [128, 160, 192, 224] if image_size_list is None else image_size_list
        )
        print(image_size_list)
        with tqdm(
            total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset"
        ) as t:
            for image_size in image_size_list:
                # load val dataset into memory
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # save path
                os.makedirs(self.acc_src_folder, exist_ok=True)
                acc_save_path = os.path.join(
                    self.acc_src_folder, "%d.dict" % image_size
                )
                acc_dict = {}
                # load existing acc dict
                if os.path.isfile(acc_save_path):
                    existing_acc_dict = json.load(open(acc_save_path, "r"))
                else:
                    existing_acc_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, "image_size": image_size})
                    if key in existing_acc_dict:
                        acc_dict[key] = existing_acc_dict[key]
                        t.set_postfix(
                            {
                                "net_id": net_id,
                                "image_size": image_size,
                                "info_val": acc_dict[key],
                                "status": "loading",
                            }
                        )
                        t.update()
                        continue
                    dyn_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(dyn_network)
                    net_setting_str = ",".join(
                        [
                            "%s_%s"
                            % (
                                key,
                                "%.1f" % list_mean(val)
                                if isinstance(val, list)
                                else val,
                            )
                            for key, val in net_setting.items()
                        ]
                    )
                    loss, (top1, top5, robust1, robust5) = run_manager.validate(
                        run_str=net_setting_str,
                        net=dyn_network,
                        data_loader=val_dataset,
                        no_logs=True,
                    )
                    info_val = top1
                    t.set_postfix(
                        {
                            "net_id": net_id,
                            "image_size": image_size,
                            "info_val": info_val,
                        }
                    )
                    t.update()

                    acc_dict.update({key: info_val})
                    json.dump(acc_dict, open(acc_save_path, "w"), indent=4)

    def merge_acc_dataset(self, image_size_list=None):
        # load existing data
        merged_acc_dict = {}
        for fname in os.listdir(self.acc_src_folder):
            if ".dict" not in fname:
                continue
            image_size = int(fname.split(".dict")[0])
            if image_size_list is not None and image_size not in image_size_list:
                print("Skip ", fname)
                continue
            full_path = os.path.join(self.acc_src_folder, fname)
            partial_acc_dict = json.load(open(full_path))
            merged_acc_dict.update(partial_acc_dict)
            print("loaded %s" % full_path)
        json.dump(merged_acc_dict, open(self.acc_dict_path, "w"), indent=4)
        return merged_acc_dict

    def build_acc_data_loader(
        self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
    ):
        # load data
        acc_dict = json.load(open(self.acc_dict_path))
        X_all = []
        Y_all = []

        with tqdm(total=len(acc_dict), desc="Loading data") as t:
            for k, v in acc_dict.items():
                dic = json.loads(k)
                X_all.append(arch_encoder.arch2feature(dic))
                Y_all.append(v / 100.0)  # range: 0 - 1
                t.update()
        base_acc = np.mean(Y_all)
        # convert to torch tensor
        X_all = torch.tensor(X_all, dtype=torch.float)
        Y_all = torch.tensor(Y_all)

        # random shuffle
        shuffle_idx = torch.randperm(len(X_all))
        X_all = X_all[shuffle_idx]
        Y_all = Y_all[shuffle_idx]
        # split data
        idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
        val_idx = X_all.size(0) // 5 * 4
        X_train, Y_train = X_all[:idx], Y_all[:idx]
        X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
        print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test))

        # build data loader
        train_dataset = RegDataset(X_train, Y_train)
        val_dataset = RegDataset(X_test, Y_test)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=n_workers,
        )
        valid_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=n_workers,
        )

        return train_loader, valid_loader, base_acc
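
A sketch of the intended pipeline around this class, assuming a prepared RunManager, dynamic network, and arch encoder; the CIFAR-style image size list is an illustrative choice:

# Hypothetical pipeline around AccuracyDataset:
acc_dataset = AccuracyDataset("acc_data")
acc_dataset.build_acc_dataset(run_manager, dyn_network, n_arch=2000, image_size_list=[32])
acc_dataset.merge_acc_dataset()
train_loader, valid_loader, base_acc = acc_dataset.build_acc_data_loader(arch_encoder)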
proard/nas/accuracy_predictor/acc_predictor.py
ADDED
@@ -0,0 +1,68 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import numpy as np
import torch
import torch.nn as nn

__all__ = ["AccuracyPredictor"]


class AccuracyPredictor(nn.Module):
    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=3,
        checkpoint_path=None,
        device="cuda:0",
        base_acc_val=None,
    ):
        super(AccuracyPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_acc_val = base_acc_val
        # build layers
        layers = []
        for i in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(
                        self.arch_encoder.n_dim if i == 0 else self.hidden_size,
                        self.hidden_size,
                    ),
                    nn.ReLU(inplace=True),
                )
            )
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)
        if self.base_acc_val is not None:
            self.base_acc = nn.Parameter(
                torch.tensor(self.base_acc_val, device=self.device), requires_grad=False
            )
        else:
            self.base_acc = nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=False
            )

        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            self.load_state_dict(checkpoint)
            print("Loaded checkpoint from %s" % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        y = self.layers(x).squeeze()
        return y + self.base_acc

    def predict_acc(self, arch_dict_list):
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)
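
A minimal training/inference sketch for the predictor; the MSE regression loop below is an assumption (the training script is not part of this diff), and the architecture-dict keys passed to predict_acc are hypothetical placeholders whose exact format depends on the arch encoder:

# Hypothetical regression loop over the loaders built by AccuracyDataset:
predictor = AccuracyPredictor(arch_encoder, base_acc_val=float(base_acc), device="cpu")
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, predictor.parameters()), lr=1e-3
)
for x, y in train_loader:
    optimizer.zero_grad()
    pred = predictor(x)
    loss = torch.nn.functional.mse_loss(pred, y.float())
    loss.backward()
    optimizer.step()
# query the trained predictor on explicit architecture settings:
arch = {"ks": [3] * 20, "e": [4] * 20, "d": [3] * 5, "image_size": 224}  # hypothetical format
scores = predictor.predict_acc([arch])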
proard/nas/accuracy_predictor/acc_rob_dataset.py
ADDED
@@ -0,0 +1,219 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data

from proard.utils import list_mean

__all__ = ["net_setting2id", "net_id2setting", "AccuracyRobustnessDataset"]


def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)


class TwoRegDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets_acc, targets_rob):
        super(TwoRegDataset, self).__init__()
        self.inputs = inputs
        self.targets_acc = targets_acc
        self.targets_rob = targets_rob

    def __getitem__(self, index):
        return self.inputs[index], self.targets_acc[index], self.targets_rob[index]

    def __len__(self):
        return self.inputs.size(0)


class AccuracyRobustnessDataset:
    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, "net_id.dict")

    @property
    def acc_rob_src_folder(self):
        return os.path.join(self.path, "src")

    @property
    def acc_rob_dict_path(self):
        return os.path.join(self.path, "src/acc_robust.dict")

    # TODO: support parallel building
    def build_acc_rob_dataset(
        self, run_manager, dyn_network, n_arch=2000, image_size_list=None
    ):
        # load net_id_list, randomly sample if it does not exist
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = dyn_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)

        image_size_list = (
            [128, 160, 192, 224] if image_size_list is None else image_size_list
        )
        print(image_size_list)
        with tqdm(
            total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset"
        ) as t:
            for image_size in image_size_list:
                # load val dataset into memory
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # save path
                os.makedirs(self.acc_rob_src_folder, exist_ok=True)
                acc_rob_save_path = os.path.join(
                    self.acc_rob_src_folder, "%d.dict" % image_size
                )
                acc_rob_dict = {}
                # load existing acc dict
                if os.path.isfile(acc_rob_save_path):
                    existing_acc_rob_dict = json.load(open(acc_rob_save_path, "r"))
                else:
                    existing_acc_rob_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, "image_size": image_size})
                    if key in existing_acc_rob_dict:
                        acc_rob_dict[key] = existing_acc_rob_dict[key]
                        t.set_postfix(
                            {
                                "net_id": net_id,
                                "image_size": image_size,
                                "info_val": acc_rob_dict[key],
                                "status": "loading",
                            }
                        )
                        t.update()
                        continue
                    dyn_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(dyn_network)
                    net_setting_str = ",".join(
                        [
                            "%s_%s"
                            % (
                                key,
                                "%.1f" % list_mean(val)
                                if isinstance(val, list)
                                else val,
                            )
                            for key, val in net_setting.items()
                        ]
                    )
                    loss, (top1, top5, robust1, robust5) = run_manager.validate(
                        run_str=net_setting_str,
                        net=dyn_network,
                        data_loader=val_dataset,
                        no_logs=True,
                    )
                    info_val = [top1, robust1]
                    t.set_postfix(
                        {
                            "net_id": net_id,
                            "image_size": image_size,
                            "info_val": info_val,
                        }
                    )
                    t.update()

                    acc_rob_dict.update({key: info_val})
                    json.dump(acc_rob_dict, open(acc_rob_save_path, "w"), indent=4)

    def merge_acc_dataset(self, image_size_list=None):
        # load existing data
        merged_acc_rob_dict = {}
        for fname in os.listdir(self.acc_rob_src_folder):
            if ".dict" not in fname:
                continue
            image_size = int(fname.split(".dict")[0])
            if image_size_list is not None and image_size not in image_size_list:
                print("Skip ", fname)
                continue
            full_path = os.path.join(self.acc_rob_src_folder, fname)
            partial_acc_rob_dict = json.load(open(full_path))
            merged_acc_rob_dict.update(partial_acc_rob_dict)
            print("loaded %s" % full_path)
        json.dump(merged_acc_rob_dict, open(self.acc_rob_dict_path, "w"), indent=4)
        return merged_acc_rob_dict

    def build_acc_data_loader(
        self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
    ):
        # load data
        acc_rob_dict = json.load(open(self.acc_rob_dict_path))
        X_all = []
        Y_acc_all = []
        Y_rob_all = []

        with tqdm(total=len(acc_rob_dict), desc="Loading data") as t:
            for k, v in acc_rob_dict.items():
                dic = json.loads(k)
                X_all.append(arch_encoder.arch2feature(dic))
                Y_acc_all.append(v[0] / 100.0)  # range: 0 - 1
                Y_rob_all.append(v[1] / 100.0)
                t.update()
        base_acc = np.mean(Y_acc_all)
        base_rob = np.mean(Y_rob_all)
        # convert to torch tensor
        X_all = torch.tensor(X_all, dtype=torch.float)
        Y_acc_all = torch.tensor(Y_acc_all)
        Y_rob_all = torch.tensor(Y_rob_all)

        # random shuffle
        shuffle_idx = torch.randperm(len(X_all))
        X_all = X_all[shuffle_idx]
        Y_acc_all = Y_acc_all[shuffle_idx]
        Y_rob_all = Y_rob_all[shuffle_idx]
        # split data
        idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
        val_idx = X_all.size(0) // 5 * 4
        X_train, Y_acc_train, Y_rob_train = X_all[:idx], Y_acc_all[:idx], Y_rob_all[:idx]
        X_test, Y_acc_test, Y_rob_test = X_all[val_idx:], Y_acc_all[val_idx:], Y_rob_all[val_idx:]
        print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test))

        # build data loader
        train_dataset = TwoRegDataset(X_train, Y_acc_train, Y_rob_train)
        val_dataset = TwoRegDataset(X_test, Y_acc_test, Y_rob_test)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=n_workers,
        )
        valid_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=n_workers,
        )

        return train_loader, valid_loader, base_acc, base_rob
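
The joint dataset is driven the same way as AccuracyDataset above, but each entry stores a [clean-accuracy, robust-accuracy] pair and the loaders yield (features, acc, rob) triples; a sketch with the same hypothetical setup:

# Hypothetical usage, mirroring the single-objective dataset:
acc_rob_dataset = AccuracyRobustnessDataset("acc_rob_data")
acc_rob_dataset.build_acc_rob_dataset(run_manager, dyn_network, n_arch=2000, image_size_list=[32])
acc_rob_dataset.merge_acc_dataset()
train_loader, valid_loader, base_acc, base_rob = acc_rob_dataset.build_acc_data_loader(arch_encoder)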
proard/nas/accuracy_predictor/acc_rob_predictor.py
ADDED
@@ -0,0 +1,77 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import numpy as np
import torch
import torch.nn as nn

__all__ = ["Accuracy_Robustness_Predictor"]


class Accuracy_Robustness_Predictor(nn.Module):
    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=6,
        checkpoint_path=None,
        device="cuda:0",
        base_acc_val=None,
        base_rob_val=None,
    ):
        super(Accuracy_Robustness_Predictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_acc_val = base_acc_val
        self.base_rob_val = base_rob_val
        # build layers
        layers = []
        for i in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(
                        self.arch_encoder.n_dim if i == 0 else self.hidden_size,
                        self.hidden_size,
                    ),
                    nn.ReLU(inplace=True),
                )
            )
        layers.append(nn.Linear(self.hidden_size, 2, bias=False))
        self.layers = nn.Sequential(*layers)
        if self.base_acc_val is not None:
            self.base_acc = nn.Parameter(
                torch.tensor(self.base_acc_val, device=self.device), requires_grad=False
            )
        else:
            self.base_acc = nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=False
            )
        if self.base_rob_val is not None:
            self.base_rob = nn.Parameter(
                torch.tensor(self.base_rob_val, device=self.device), requires_grad=False
            )
        else:
            self.base_rob = nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=False
            )
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            self.load_state_dict(checkpoint)
            print("Loaded checkpoint from %s" % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        y = self.layers(x).squeeze()
        # column 0 predicts accuracy, column 1 predicts robustness;
        # shift each output head by its stored base value
        base = torch.cat([self.base_acc.reshape(1), self.base_rob.reshape(1)])
        return y + base

    def predict_acc_rob(self, arch_dict_list):
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)
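
Prediction returns one row per architecture with two columns (accuracy, robustness); a sketch, assuming the same hypothetical architecture-dict format as in the single-objective example above:

# Hypothetical query of the joint predictor:
predictor = Accuracy_Robustness_Predictor(
    arch_encoder, base_acc_val=float(base_acc), base_rob_val=float(base_rob), device="cpu"
)
out = predictor.predict_acc_rob([arch])  # shape (2,) for a single arch
acc_pred, rob_pred = float(out[0]), float(out[1])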
proard/nas/accuracy_predictor/arch_encoder.py
ADDED
@@ -0,0 +1,372 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.


import random
import numpy as np
from proard.classification.networks import ResNets

__all__ = ["MobileNetArchEncoder", "ResNetArchEncoder"]


class MobileNetArchEncoder:
    SPACE_TYPE = "mbv3"

    def __init__(
        self,
        image_size_list=None,
        ks_list=None,
        expand_list=None,
        depth_list=None,
        n_stage=None,
    ):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.ks_list = [3, 5, 7] if ks_list is None else ks_list
        self.expand_list = (
            [3, 4, 6]
            if expand_list is None
            else [int(expand) for expand in expand_list]
        )
        self.depth_list = [2, 3, 4] if depth_list is None else depth_list
        if n_stage is not None:
            self.n_stage = n_stage
        elif self.SPACE_TYPE == "mbv2":
            self.n_stage = 6
        elif self.SPACE_TYPE == "mbv3":
            self.n_stage = 5
        else:
            raise NotImplementedError

        # build info dict
        self.n_dim = 0
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="r")
        self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="k")
        self._build_info_dict(target="e")

    @property
    def max_n_blocks(self):
        if self.SPACE_TYPE == "mbv3":
            return self.n_stage * max(self.depth_list)
        elif self.SPACE_TYPE == "mbv2":
            return (self.n_stage - 1) * max(self.depth_list) + 1
        else:
            raise NotImplementedError

    def _build_info_dict(self, target):
        if target == "r":
            target_dict = self.r_info
            target_dict["L"].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict["val2id"][img_size] = self.n_dim
                target_dict["id2val"][self.n_dim] = img_size
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)
        else:
            if target == "k":
                target_dict = self.k_info
                choices = self.ks_list
            elif target == "e":
                target_dict = self.e_info
                choices = self.expand_list
            else:
                raise NotImplementedError
            for i in range(self.max_n_blocks):
                target_dict["val2id"].append({})
                target_dict["id2val"].append({})
                target_dict["L"].append(self.n_dim)
                for k in choices:
                    target_dict["val2id"][i][k] = self.n_dim
                    target_dict["id2val"][i][self.n_dim] = k
                    self.n_dim += 1
                target_dict["R"].append(self.n_dim)

    def arch2feature(self, arch_dict):
        ks, e, d, r = (
            arch_dict["ks"],
            arch_dict["e"],
            arch_dict["d"],
            arch_dict["image_size"],
        )
        feature = np.zeros(self.n_dim)
        for i in range(self.max_n_blocks):
            nowd = i % max(self.depth_list)
            stg = i // max(self.depth_list)
            if nowd < d[stg]:
                feature[self.k_info["val2id"][i][ks[i]]] = 1
                feature[self.e_info["val2id"][i][e[i]]] = 1
        feature[self.r_info["val2id"][r[0]]] = 1
        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info["id2val"][
            int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]]))
            + self.r_info["L"][0]
        ]
        assert img_sz in self.image_size_list
        arch_dict = {"ks": [], "e": [], "d": [], "image_size": img_sz}

        d = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.k_info["L"][i], self.k_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["ks"].append(self.k_info["id2val"][i][j])
                    skip = False
                    break

            for j in range(self.e_info["L"][i], self.e_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["e"].append(self.e_info["id2val"][i][j])
                    assert not skip
                    skip = False
                    break

            if skip:
                arch_dict["e"].append(0)
                arch_dict["ks"].append(0)
            else:
                d += 1

            if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
                arch_dict["d"].append(d)
                d = 0
        return arch_dict

    def random_sample_arch(self):
        return {
            "ks": random.choices(self.ks_list, k=self.max_n_blocks),
            "e": random.choices(self.expand_list, k=self.max_n_blocks),
            "d": random.choices(self.depth_list, k=self.n_stage),
            "image_size": [random.choice(self.image_size_list)],
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict["image_size"] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        for i in range(self.max_n_blocks):
            if random.random() < mutate_prob:
                arch_dict["ks"][i] = random.choice(self.ks_list)
                arch_dict["e"][i] = random.choice(self.expand_list)

        for i in range(self.n_stage):
            if random.random() < mutate_prob:
                arch_dict["d"][i] = random.choice(self.depth_list)
        return arch_dict
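
# Round-trip sanity sketch (illustrative comment, not part of the original
# file): encoding a sampled architecture and decoding the one-hot vector
# recovers the per-stage depths and the resolution.
#
#   enc = MobileNetArchEncoder(image_size_list=[160, 192, 224])
#   arch = enc.random_sample_arch()
#   feat = enc.arch2feature(arch)   # one-hot vector of length enc.n_dim
#   arch2 = enc.feature2arch(feat)
#   assert arch2["d"] == arch["d"]
#   assert arch2["image_size"] == arch["image_size"][0]
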
class ResNetArchEncoder:
    def __init__(
        self,
        image_size_list=None,
        depth_list=None,
        expand_list=None,
        width_mult_list=None,
        base_depth_list=None,
    ):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
        self.depth_list = [0, 1, 2] if depth_list is None else depth_list
        self.width_mult_list = (
            [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list
        )

        self.base_depth_list = (
            ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list
        )

        # build info dict
        self.n_dim = 0
        # resolution
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="r")
        # input stem skip
        self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="input_stem_d")
        # width_mult
        self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="width_mult")
        # expand ratio
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="e")

    @property
    def n_stage(self):
        return len(self.base_depth_list)

    @property
    def max_n_blocks(self):
        return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)

    def _build_info_dict(self, target):
        if target == "r":
            target_dict = self.r_info
            target_dict["L"].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict["val2id"][img_size] = self.n_dim
                target_dict["id2val"][self.n_dim] = img_size
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)
        elif target == "input_stem_d":
            target_dict = self.input_stem_d_info
            target_dict["L"].append(self.n_dim)
            for skip in [0, 1]:
                target_dict["val2id"][skip] = self.n_dim
                target_dict["id2val"][self.n_dim] = skip
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)
        elif target == "e":
            target_dict = self.e_info
            choices = self.expand_list
            for i in range(self.max_n_blocks):
                target_dict["val2id"].append({})
                target_dict["id2val"].append({})
                target_dict["L"].append(self.n_dim)
                for e in choices:
                    target_dict["val2id"][i][e] = self.n_dim
                    target_dict["id2val"][i][self.n_dim] = e
                    self.n_dim += 1
                target_dict["R"].append(self.n_dim)
        elif target == "width_mult":
            target_dict = self.width_mult_info
            choices = list(range(len(self.width_mult_list)))
            for i in range(self.n_stage + 2):
                target_dict["val2id"].append({})
                target_dict["id2val"].append({})
                target_dict["L"].append(self.n_dim)
                for w in choices:
                    target_dict["val2id"][i][w] = self.n_dim
                    target_dict["id2val"][i][self.n_dim] = w
                    self.n_dim += 1
                target_dict["R"].append(self.n_dim)

    def arch2feature(self, arch_dict):
        d, e, w, r = (
            arch_dict["d"],
            arch_dict["e"],
            arch_dict["w"],
            arch_dict["image_size"],
        )
        input_stem_skip = 1 if d[0] > 0 else 0
        d = d[1:]

        feature = np.zeros(self.n_dim)
        feature[self.r_info["val2id"][r]] = 1
        feature[self.input_stem_d_info["val2id"][input_stem_skip]] = 1
        for i in range(self.n_stage + 2):
            feature[self.width_mult_info["val2id"][i][w[i]]] = 1

        start_pt = 0
        for i, base_depth in enumerate(self.base_depth_list):
            depth = base_depth + d[i]
            for j in range(start_pt, start_pt + depth):
                feature[self.e_info["val2id"][j][e[j]]] = 1
            start_pt += max(self.depth_list) + base_depth
        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info["id2val"][
            int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]]))
            + self.r_info["L"][0]
        ]
        input_stem_skip = (
            self.input_stem_d_info["id2val"][
                int(
                    np.argmax(
                        feature[
                            self.input_stem_d_info["L"][0] : self.input_stem_d_info["R"][0]
                        ]
                    )
                )
                + self.input_stem_d_info["L"][0]
            ]
            * 2
        )
        assert img_sz in self.image_size_list
        arch_dict = {"d": [input_stem_skip], "e": [], "w": [], "image_size": img_sz}

        for i in range(self.n_stage + 2):
            arch_dict["w"].append(
                self.width_mult_info["id2val"][i][
                    int(
                        np.argmax(
                            feature[
                                self.width_mult_info["L"][i] : self.width_mult_info["R"][i]
                            ]
                        )
                    )
                    + self.width_mult_info["L"][i]
                ]
            )

        d = 0
        skipped = 0
        stage_id = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.e_info["L"][i], self.e_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["e"].append(self.e_info["id2val"][i][j])
                    skip = False
                    break
            if skip:
                arch_dict["e"].append(0)
                skipped += 1
            else:
                d += 1

            if (
                i + 1 == self.max_n_blocks
                or (skipped + d)
                % (max(self.depth_list) + self.base_depth_list[stage_id])
                == 0
            ):
                arch_dict["d"].append(d - self.base_depth_list[stage_id])
                d, skipped = 0, 0
                stage_id += 1
        return arch_dict

    def random_sample_arch(self):
        return {
            "d": [random.choice([0, 2])]
            + random.choices(self.depth_list, k=self.n_stage),
            "e": random.choices(self.expand_list, k=self.max_n_blocks),
            "w": random.choices(
                list(range(len(self.width_mult_list))), k=self.n_stage + 2
            ),
            "image_size": random.choice(self.image_size_list),
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict["image_size"] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        # input stem skip
        if random.random() < mutate_prob:
            arch_dict["d"][0] = random.choice([0, 2])
        # depth
        for i in range(1, len(arch_dict["d"])):
            if random.random() < mutate_prob:
                arch_dict["d"][i] = random.choice(self.depth_list)
        # width_mult
        for i in range(len(arch_dict["w"])):
            if random.random() < mutate_prob:
                arch_dict["w"][i] = random.choice(
                    list(range(len(self.width_mult_list)))
                )
        # expand ratio
        for i in range(len(arch_dict["e"])):
            if random.random() < mutate_prob:
                arch_dict["e"][i] = random.choice(self.expand_list)
        return arch_dict
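For orientation, a note derived from _build_info_dict above (illustrative, not in the diff): ResNetArchEncoder lays out its one-hot feature vector as resolution choices, the input-stem skip flag, per-stage width-multiplier choices, and per-block expand-ratio choices, so its length decomposes as:

enc = ResNetArchEncoder(image_size_list=[32])
expected_n_dim = (
    len(enc.image_size_list)                        # resolution one-hot
    + 2                                             # input stem skip flag
    + (enc.n_stage + 2) * len(enc.width_mult_list)  # width multipliers
    + enc.max_n_blocks * len(enc.expand_list)       # expand ratios
)
assert enc.n_dim == expected_n_dim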
proard/nas/accuracy_predictor/rob_dataset.py ADDED
@@ -0,0 +1,211 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data

from proard.utils import list_mean

__all__ = ["net_setting2id", "net_id2setting", "RobustnessDataset"]


def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)


class RegDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        super(RegDataset, self).__init__()
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return self.inputs.size(0)


class RobustnessDataset:
    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, "net_id.dict")

    @property
    def rob_src_folder(self):
        return os.path.join(self.path, "src_rob")

    @property
    def rob_dict_path(self):
        return os.path.join(self.path, "src_rob/rob.dict")

    # TODO: support parallel building
    def build_rob_dataset(
        self, run_manager, dyn_network, n_arch=2000, image_size_list=None
    ):
        # load net_id_list, random sample if it does not exist
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = dyn_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)

        image_size_list = (
            [128, 160, 192, 224] if image_size_list is None else image_size_list
        )

        with tqdm(
            total=len(net_id_list) * len(image_size_list),
            desc="Building Robustness Dataset",
        ) as t:
            for image_size in image_size_list:
                # load val dataset into memory
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # save path
                os.makedirs(self.rob_src_folder, exist_ok=True)

                rob_save_path = os.path.join(
                    self.rob_src_folder, "%d.dict" % image_size
                )

                rob_dict = {}
                # load existing rob dict
                if os.path.isfile(rob_save_path):
                    existing_rob_dict = json.load(open(rob_save_path, "r"))
                else:
                    existing_rob_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, "image_size": image_size})
                    if key in existing_rob_dict:
                        rob_dict[key] = existing_rob_dict[key]
                        t.set_postfix(
                            {
                                "net_id": net_id,
                                "image_size": image_size,
                                "info_rob": rob_dict[key],
                                "status": "loading",
                            }
                        )
                        t.update()
                        continue
                    dyn_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(dyn_network)
                    net_setting_str = ",".join(
                        [
                            "%s_%s"
                            % (
                                key,
                                "%.1f" % list_mean(val)
                                if isinstance(val, list)
                                else val,
                            )
                            for key, val in net_setting.items()
                        ]
                    )
                    loss, (top1, top5, robust1, robust5) = run_manager.validate(
                        run_str=net_setting_str,
                        net=dyn_network,
                        data_loader=val_dataset,
                        no_logs=True,
                    )
                    info_robust = robust1
                    t.set_postfix(
                        {
                            "net_id": net_id,
                            "image_size": image_size,
                            "info_rob": info_robust,
                        }
                    )
                    t.update()

                    rob_dict.update({key: info_robust})
                    json.dump(rob_dict, open(rob_save_path, "w"), indent=4)

    def merge_rob_dataset(self, image_size_list=None):
        # load existing data
        merged_rob_dict = {}
        for fname in os.listdir(self.rob_src_folder):
            if ".dict" not in fname:
                continue
            image_size = int(fname.split(".dict")[0])
            if image_size_list is not None and image_size not in image_size_list:
                print("Skip ", fname)
                continue
            full_path = os.path.join(self.rob_src_folder, fname)
            partial_rob_dict = json.load(open(full_path))
            merged_rob_dict.update(partial_rob_dict)
            print("loaded %s" % full_path)
        json.dump(merged_rob_dict, open(self.rob_dict_path, "w"), indent=4)
        return merged_rob_dict

    def build_rob_data_loader(
        self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
    ):
        # load data
        rob_dict = json.load(open(self.rob_dict_path))
        X_all_rob = []
        Y_all_rob = []
        with tqdm(total=len(rob_dict), desc="Loading data") as t:
            for k, v in rob_dict.items():
                dic = json.loads(k)
                X_all_rob.append(arch_encoder.arch2feature(dic))
                Y_all_rob.append(v / 100.0)  # range: 0 - 1
                t.update()
        base_rob = np.mean(Y_all_rob)
        # convert to torch tensor
        X_all_rob = torch.tensor(X_all_rob, dtype=torch.float)
        Y_all_rob = torch.tensor(Y_all_rob)

        # random shuffle
        shuffle_idx_rob = torch.randperm(len(X_all_rob))
        X_all_rob = X_all_rob[shuffle_idx_rob]
        Y_all_rob = Y_all_rob[shuffle_idx_rob]
        # split data
        idx_rob = (
            X_all_rob.size(0) // 5 * 4
            if n_training_sample is None
            else n_training_sample
        )
        val_idx_rob = X_all_rob.size(0) // 5 * 4
        X_train_rob, Y_train_rob = X_all_rob[:idx_rob], Y_all_rob[:idx_rob]
        X_test_rob, Y_test_rob = X_all_rob[val_idx_rob:], Y_all_rob[val_idx_rob:]
        print(
            "Train Robustness Size: %d," % len(X_train_rob),
            "Valid Robustness Size: %d" % len(X_test_rob),
        )
        # build data loader
        train_dataset_rob = RegDataset(X_train_rob, Y_train_rob)
        val_dataset_rob = RegDataset(X_test_rob, Y_test_rob)

        train_loader_rob = torch.utils.data.DataLoader(
            train_dataset_rob,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=n_workers,
        )
        valid_loader_rob = torch.utils.data.DataLoader(
            val_dataset_rob,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=n_workers,
        )
        return train_loader_rob, valid_loader_rob, base_rob
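An end-to-end sketch of how these pieces fit together (illustrative; run_manager and dyn_network are assumed to come from the training code elsewhere in this upload, and the dataset path is a placeholder):

dataset = RobustnessDataset("exp/rob_dataset")
# dataset.build_rob_dataset(run_manager, dyn_network, n_arch=2000)  # slow: evaluates subnets
dataset.merge_rob_dataset()
encoder = ResNetArchEncoder(image_size_list=[32])
train_loader, valid_loader, base_rob = dataset.build_rob_data_loader(encoder)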
proard/nas/accuracy_predictor/rob_predictor.py ADDED
@@ -0,0 +1,66 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import numpy as np
import torch
import torch.nn as nn

__all__ = ["RobustnessPredictor"]


class RobustnessPredictor(nn.Module):
    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=3,
        checkpoint_path=None,
        device="cuda:0",
        base_rob_val=None,
    ):
        super(RobustnessPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_rob_val = base_rob_val
        # build layers
        layers = []
        for i in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(
                        self.arch_encoder.n_dim if i == 0 else self.hidden_size,
                        self.hidden_size,
                    ),
                    nn.ReLU(inplace=True),
                )
            )
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)
        if self.base_rob_val is not None:
            self.base_rob = nn.Parameter(
                torch.tensor(self.base_rob_val, device=self.device), requires_grad=False
            )
        else:
            self.base_rob = nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=False
            )
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            self.load_state_dict(checkpoint)
            print("Loaded checkpoint from %s" % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        y = self.layers(x).squeeze()
        return y + self.base_rob

    def predict_rob(self, arch_dict_list):
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)
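A minimal training sketch for this predictor (illustrative; it reuses encoder, train_loader, and base_rob from the RobustnessDataset example above, and the hyper-parameters are assumptions, not the repo's training script):

import torch.nn.functional as F

predictor = RobustnessPredictor(encoder, device="cpu", base_rob_val=base_rob)
opt = torch.optim.Adam(
    [p for p in predictor.parameters() if p.requires_grad], lr=1e-3
)
for epoch in range(10):
    for x, y in train_loader:
        opt.zero_grad()
        loss = F.mse_loss(predictor(x), y.float())
        loss.backward()
        opt.step()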
proard/nas/efficiency_predictor/__init__.py ADDED
@@ -0,0 +1,78 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import copy
from .latency_lookup_table import *


class BaseEfficiencyModel:
    def __init__(self, dyn_net):
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        arch_dict = copy.deepcopy(arch_dict)
        image_size = arch_dict.pop("image_size")
        self.dyn_net.set_active_subnet(**arch_dict)
        active_net_config = self.dyn_net.get_active_net_config()
        return active_net_config, image_size

    def get_efficiency(self, arch_dict):
        raise NotImplementedError


class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ProxylessNASLatencyTable.count_flops_given_config(
            active_net_config, image_size
        )


class Mbv3FLOPsModel(BaseEfficiencyModel):
    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        # mbv3 arch dicts store image_size as a one-element list
        return MBv3LatencyTable.count_flops_given_config(
            active_net_config, image_size[0]
        )


class ResNet50FLOPsModel(BaseEfficiencyModel):
    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ResNet50LatencyTable.count_flops_given_config(
            active_net_config, image_size
        )


class ProxylessNASLatencyModel(BaseEfficiencyModel):
    def __init__(self, dyn_net, lookup_table_path_dict):
        super(ProxylessNASLatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir="/tmp/.dyn_latency_tools/",
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(
            active_net_config, image_size
        )


class Mbv3LatencyModel(BaseEfficiencyModel):
    def __init__(self, dyn_net, lookup_table_path_dict):
        super(Mbv3LatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir="/tmp/.dyn_latency_tools/",
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(
            active_net_config, image_size
        )
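A usage sketch (illustrative; dyn_network stands for one of the dynamic networks from proard/classification/elastic_nn, paired with the ResNetArchEncoder from the earlier sketches):

flops_model = ResNet50FLOPsModel(dyn_network)
arch = encoder.random_sample_arch()          # arch dict with d / e / w / image_size
mflops = flops_model.get_efficiency(arch)    # MFLOPs of the active subnet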
proard/nas/efficiency_predictor/latency_lookup_table.py ADDED
@@ -0,0 +1,567 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import yaml
from proard.utils import download_url, make_divisible, MyNetwork

__all__ = [
    "count_conv_flop",
    "ProxylessNASLatencyTable",
    "MBv3LatencyTable",
    "ResNet50LatencyTable",
]


def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
    out_h = out_w = out_size
    delta_ops = (
        in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
    )
    return delta_ops
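
# Worked example (illustrative, not part of the original file): a 3x3
# convolution taking 16 channels to 32 on a 112x112 output map costs
#   16 * 32 * 3 * 3 * 112 * 112 / 1 ~= 57.8e6 multiply-accumulates,
# i.e. count_conv_flop(112, 16, 32, 3, 1) / 1e6 ~= 57.8 MFLOPs.
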
class LatencyTable(object):
    def __init__(
        self,
        local_dir="~/.dyn/latency_tools/",
        url="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/mobile_trim.yaml",
    ):
        if url.startswith("http"):
            fname = download_url(url, local_dir, overwrite=True)
        else:
            fname = url
        with open(fname, "r") as fp:
            self.lut = yaml.safe_load(fp)

    @staticmethod
    def repr_shape(shape):
        if isinstance(shape, (list, tuple)):
            return "x".join(str(_) for _ in shape)
        elif isinstance(shape, str):
            return shape
        else:
            raise TypeError(shape)

    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        raise NotImplementedError


class ProxylessNASLatencyTable(LatencyTable):
    def query(
        self,
        l_type: str,
        input_shape,
        output_shape,
        expand=None,
        ks=None,
        stride=None,
        id_skip=None,
    ):
        """
        :param l_type:
            Layer type. Must be one of the following:
            1. `Conv`: The initial 3x3 conv with stride 2.
            2. `Conv_1`: feature_mix_layer
            3. `Logits`: All operations after `Conv_1`.
            4. `expanded_conv`: MobileInvertedResidual
        :param input_shape: input shape (h, w, #channels)
        :param output_shape: output shape (h, w, #channels)
        :param expand: expansion ratio
        :param ks: kernel size
        :param stride:
        :param id_skip: whether the block has a residual connection
        """
        infos = [
            l_type,
            "input:%s" % self.repr_shape(input_shape),
            "output:%s" % self.repr_shape(output_shape),
        ]

        if l_type in ("expanded_conv",):
            assert None not in (expand, ks, stride, id_skip)
            infos += [
                "expand:%d" % expand,
                "kernel:%d" % ks,
                "stride:%d" % stride,
                "idskip:%d" % id_skip,
            ]
        # builds a key such as
        # "expanded_conv-input:112x112x16-output:112x112x24-expand:6-kernel:5-stride:1-idskip:0";
        # the LUT stores measured latency statistics under each such key
        key = "-".join(infos)
        return self.lut[key]["mean"]

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            "Conv",
            [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels],
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)  # i.e. ceil(fsize / stride)
            block_latency = self.query(
                "expanded_conv",
                [fsize, fsize, mb_conv.in_channels],
                [out_fz, out_fz, mb_conv.out_channels],
                expand=mb_conv.expand_ratio,
                ks=mb_conv.kernel_size,
                stride=mb_conv.stride,
                id_skip=idskip,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            "Conv_1",
            [fsize, fsize, net.feature_mix_layer.in_channels],
            [fsize, fsize, net.feature_mix_layer.out_channels],
        )
        # classifier
        predicted_latency += self.query(
            "Logits",
            [fsize, fsize, net.classifier.in_features],
            [net.classifier.out_features],  # 1000
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            "Conv",
            [image_size, image_size, 3],
            [
                (image_size + 1) // 2,
                (image_size + 1) // 2,
                net_config["first_conv"]["out_channels"],
            ],
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config["blocks"]:
            mb_conv = (
                block["mobile_inverted_conv"]
                if "mobile_inverted_conv" in block
                else block["conv"]
            )
            shortcut = block["shortcut"]

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
            block_latency = self.query(
                "expanded_conv",
                [fsize, fsize, mb_conv["in_channels"]],
                [out_fz, out_fz, mb_conv["out_channels"]],
                expand=mb_conv["expand_ratio"],
                ks=mb_conv["kernel_size"],
                stride=mb_conv["stride"],
                id_skip=idskip,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            "Conv_1",
            [fsize, fsize, net_config["feature_mix_layer"]["in_channels"]],
            [fsize, fsize, net_config["feature_mix_layer"]["out_channels"]],
        )
        # classifier
        predicted_latency += self.query(
            "Logits",
            [fsize, fsize, net_config["classifier"]["in_features"]],
            [net_config["classifier"]["out_features"]],  # 1000
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop(
            (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config["blocks"]:
            mb_conv = (
                block["mobile_inverted_conv"]
                if "mobile_inverted_conv" in block
                else block["conv"]
            )
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
            if mb_conv["mid_channels"] is None:
                mb_conv["mid_channels"] = round(
                    mb_conv["in_channels"] * mb_conv["expand_ratio"]
                )
            if mb_conv["expand_ratio"] != 1:
                # inverted bottleneck
                flops += count_conv_flop(
                    fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1
                )
            # depth conv
            flops += count_conv_flop(
                out_fz,
                mb_conv["mid_channels"],
                mb_conv["mid_channels"],
                mb_conv["kernel_size"],
                mb_conv["mid_channels"],
            )
            # point linear
            flops += count_conv_flop(
                out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1
            )
            fsize = out_fz
        # feature mix layer
        flops += count_conv_flop(
            fsize,
            net_config["feature_mix_layer"]["in_channels"],
            net_config["feature_mix_layer"]["out_channels"],
            1,
            1,
        )
        # classifier
        flops += count_conv_flop(
            1,
            net_config["classifier"]["in_features"],
            net_config["classifier"]["out_features"],
            1,
            1,
        )
        return flops / 1e6  # MFLOPs


class MBv3LatencyTable(LatencyTable):
    def query(
        self,
        l_type: str,
        input_shape,
        output_shape,
        mid=None,
        ks=None,
        stride=None,
        id_skip=None,
        se=None,
        h_swish=None,
    ):
        infos = [
            l_type,
            "input:%s" % self.repr_shape(input_shape),
            "output:%s" % self.repr_shape(output_shape),
        ]

        if l_type in ("expanded_conv",):
            assert None not in (mid, ks, stride, id_skip, se, h_swish)
            infos += [
                "expand:%d" % mid,
                "kernel:%d" % ks,
                "stride:%d" % stride,
                "idskip:%d" % id_skip,
                "se:%d" % se,
                "hs:%d" % h_swish,
            ]
        key = "-".join(infos)
        return self.lut[key]["mean"]

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            "Conv",
            [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels],
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)
            block_latency = self.query(
                "expanded_conv",
                [fsize, fsize, mb_conv.in_channels],
                [out_fz, out_fz, mb_conv.out_channels],
                mid=mb_conv.depth_conv.conv.in_channels,
                ks=mb_conv.kernel_size,
                stride=mb_conv.stride,
                id_skip=idskip,
                se=1 if mb_conv.use_se else 0,
                h_swish=1 if mb_conv.act_func == "h_swish" else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            "Conv_1",
            [fsize, fsize, net.final_expand_layer.in_channels],
            [fsize, fsize, net.final_expand_layer.out_channels],
        )
        # global average pooling
        predicted_latency += self.query(
            "AvgPool2D",
            [fsize, fsize, net.final_expand_layer.out_channels],
            [1, 1, net.final_expand_layer.out_channels],
        )
        # feature mix layer
        predicted_latency += self.query(
            "Conv_2",
            [1, 1, net.feature_mix_layer.in_channels],
            [1, 1, net.feature_mix_layer.out_channels],
        )
        # classifier
        predicted_latency += self.query(
            "Logits", [1, 1, net.classifier.in_features], [net.classifier.out_features]
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            "Conv",
            [image_size, image_size, 3],
            [
                (image_size + 1) // 2,
                (image_size + 1) // 2,
                net_config["first_conv"]["out_channels"],
            ],
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config["blocks"]:
            mb_conv = (
                block["mobile_inverted_conv"]
                if "mobile_inverted_conv" in block
                else block["conv"]
            )
            shortcut = block["shortcut"]

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
            if mb_conv["mid_channels"] is None:
                mb_conv["mid_channels"] = round(
                    mb_conv["in_channels"] * mb_conv["expand_ratio"]
                )
            block_latency = self.query(
                "expanded_conv",
                [fsize, fsize, mb_conv["in_channels"]],
                [out_fz, out_fz, mb_conv["out_channels"]],
                mid=mb_conv["mid_channels"],
                ks=mb_conv["kernel_size"],
                stride=mb_conv["stride"],
                id_skip=idskip,
                se=1 if mb_conv["use_se"] else 0,
                h_swish=1 if mb_conv["act_func"] == "h_swish" else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            "Conv_1",
            [fsize, fsize, net_config["final_expand_layer"]["in_channels"]],
            [fsize, fsize, net_config["final_expand_layer"]["out_channels"]],
        )
        # global average pooling
        predicted_latency += self.query(
            "AvgPool2D",
            [fsize, fsize, net_config["final_expand_layer"]["out_channels"]],
            [1, 1, net_config["final_expand_layer"]["out_channels"]],
        )
        # feature mix layer
        predicted_latency += self.query(
            "Conv_2",
            [1, 1, net_config["feature_mix_layer"]["in_channels"]],
            [1, 1, net_config["feature_mix_layer"]["out_channels"]],
        )
        # classifier
        predicted_latency += self.query(
            "Logits",
            [1, 1, net_config["classifier"]["in_features"]],
            [net_config["classifier"]["out_features"]],
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop(
            (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config["blocks"]:
            mb_conv = (
                block["mobile_inverted_conv"]
                if "mobile_inverted_conv" in block
                else block["conv"]
            )
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
            if mb_conv["mid_channels"] is None:
                mb_conv["mid_channels"] = round(
                    mb_conv["in_channels"] * mb_conv["expand_ratio"]
                )
            if mb_conv["expand_ratio"] != 1:
                # inverted bottleneck
                flops += count_conv_flop(
                    fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1
                )
            # depth conv
            flops += count_conv_flop(
                out_fz,
                mb_conv["mid_channels"],
                mb_conv["mid_channels"],
                mb_conv["kernel_size"],
                mb_conv["mid_channels"],
            )
            if mb_conv["use_se"]:
                # SE layer
                se_mid = make_divisible(
                    mb_conv["mid_channels"] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE
                )
                flops += count_conv_flop(1, mb_conv["mid_channels"], se_mid, 1, 1)
                flops += count_conv_flop(1, se_mid, mb_conv["mid_channels"], 1, 1)
            # point linear
            flops += count_conv_flop(
                out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1
            )
            fsize = out_fz
        # final expand layer
        flops += count_conv_flop(
            fsize,
            net_config["final_expand_layer"]["in_channels"],
            net_config["final_expand_layer"]["out_channels"],
            1,
            1,
        )
        # feature mix layer
        flops += count_conv_flop(
            1,
            net_config["feature_mix_layer"]["in_channels"],
            net_config["feature_mix_layer"]["out_channels"],
            1,
            1,
        )
        # classifier
        flops += count_conv_flop(
            1,
            net_config["classifier"]["in_features"],
            net_config["classifier"]["out_features"],
            1,
            1,
        )
        return flops / 1e6  # MFLOPs


class ResNet50LatencyTable(LatencyTable):
    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=32):
        flops = 0
        # input stem
        for layer_config in net_config["input_stem"]:
            if layer_config["name"] != "ConvLayer":
                layer_config = layer_config["conv"]
            in_channel = layer_config["in_channels"]
            out_channel = layer_config["out_channels"]
            out_image_size = int((image_size - 1) / layer_config["stride"] + 1)

            flops += count_conv_flop(
                out_image_size,
                in_channel,
                out_channel,
                layer_config["kernel_size"],
                layer_config.get("groups", 1),
            )
            image_size = out_image_size
        # max pooling
        # image_size = int((image_size - 1) / 2 + 1)
        # ResNetBottleneckBlocks
        for block_config in net_config["blocks"]:
            in_channel = block_config["in_channels"]
            out_channel = block_config["out_channels"]

            out_image_size = int((image_size - 1) / block_config["stride"] + 1)
            mid_channel = (
                block_config["mid_channels"]
                if block_config["mid_channels"] is not None
                else round(out_channel * block_config["expand_ratio"])
            )
            mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)

            # conv1
            flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
            # conv2
            flops += count_conv_flop(
                out_image_size,
                mid_channel,
                mid_channel,
                block_config["kernel_size"],
                block_config["groups"],
            )
            # conv3
            flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
            # downsample
            if block_config["stride"] == 1 and in_channel == out_channel:
                pass
            else:
                flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
            image_size = out_image_size
        # final classifier
        flops += count_conv_flop(
            1,
            net_config["classifier"]["in_features"],
            net_config["classifier"]["out_features"],
            1,
            1,
        )
        return flops / 1e6  # MFLOPs
proard/nas/search_algorithm/__init__.py ADDED
@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .evolution import *
from .multi_evolution import *
proard/nas/search_algorithm/evolution.py ADDED
@@ -0,0 +1,143 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random
import numpy as np
from tqdm import tqdm

__all__ = ["EvolutionFinder"]


class EvolutionFinder:
    def __init__(
        self, efficiency_predictor, accuracy_predictor, Robustness_predictor, **kwargs
    ):
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor
        self.robustness_predictor = Robustness_predictor

        # evolution hyper-parameters
        self.arch_mutate_prob = kwargs.get("arch_mutate_prob", 0.1)
        self.resolution_mutate_prob = kwargs.get("resolution_mutate_prob", 0.5)
        self.population_size = kwargs.get("population_size", 100)
        self.max_time_budget = kwargs.get("max_time_budget", 500)
        self.parent_ratio = kwargs.get("parent_ratio", 0.25)
        self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)

    @property
    def arch_manager(self):
        return self.accuracy_predictor.arch_encoder

    def update_hyper_params(self, new_param_dict):
        self.__dict__.update(new_param_dict)

    def random_valid_sample(self, constraint):
        # rejection-sample architectures until one meets the efficiency constraint
        while True:
            sample = self.arch_manager.random_sample_arch()
            efficiency = self.efficiency_predictor.get_efficiency(sample)
            if efficiency <= constraint:
                return sample, efficiency

    def mutate_sample(self, sample, constraint):
        while True:
            new_sample = copy.deepcopy(sample)
            self.arch_manager.mutate_resolution(new_sample, self.resolution_mutate_prob)
            self.arch_manager.mutate_arch(new_sample, self.arch_mutate_prob)

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def crossover_sample(self, sample1, sample2, constraint):
        while True:
            new_sample = copy.deepcopy(sample1)
            for key in new_sample.keys():
                if not isinstance(new_sample[key], list):
                    new_sample[key] = random.choice([sample1[key], sample2[key]])
                else:
                    for i in range(len(new_sample[key])):
                        new_sample[key][i] = random.choice(
                            [sample1[key][i], sample2[key][i]]
                        )

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def run_evolution_search(self, constraint, verbose=False, **kwargs):
        """Run a single roll-out of regularized evolution to a fixed time budget."""
        self.update_hyper_params(kwargs)

        mutation_numbers = int(round(self.mutation_ratio * self.population_size))
        parents_size = int(round(self.parent_ratio * self.population_size))

        best_valids = [-100]
        population = []  # (validation, robustness, sample, latency) tuples
        child_pool = []
        efficiency_pool = []
        best_info = None
        if verbose:
            print("Generate random population...")
        for _ in range(self.population_size):
            sample, efficiency = self.random_valid_sample(constraint)
            child_pool.append(sample)
            efficiency_pool.append(efficiency)

        accs = self.accuracy_predictor.predict_acc(child_pool)
|
| 87 |
+
robs = self.robustness_predictor.predict_rob(child_pool)
|
| 88 |
+
for i in range(self.population_size):
|
| 89 |
+
population.append((accs[i].item(), robs[i].item(), child_pool[i], efficiency_pool[i]))
|
| 90 |
+
|
| 91 |
+
if verbose:
|
| 92 |
+
print("Start Evolution...")
|
| 93 |
+
# After the population is seeded, proceed with evolving the population.
|
| 94 |
+
with tqdm(
|
| 95 |
+
total=self.max_time_budget,
|
| 96 |
+
desc="Searching with constraint (%s)" % constraint,
|
| 97 |
+
disable=(not verbose),
|
| 98 |
+
) as t:
|
| 99 |
+
for i in range(self.max_time_budget):
|
| 100 |
+
parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
|
| 101 |
+
acc = parents[0][0]
|
| 102 |
+
rob = parents[0][1]
|
| 103 |
+
t.set_postfix({"acc": parents[0][0] , "rob":parents[0][1]})
|
| 104 |
+
if not verbose and (i + 1) % 100 == 0:
|
| 105 |
+
print("Iter: {} Acc: {} Rob: {}".format(i + 1, parents[0][0],parents[0][1]))
|
| 106 |
+
|
| 107 |
+
if acc > best_valids[-1]:
|
| 108 |
+
best_valids.append(acc)
|
| 109 |
+
best_info = parents[0]
|
| 110 |
+
else:
|
| 111 |
+
best_valids.append(best_valids[-1])
|
| 112 |
+
|
| 113 |
+
population = parents
|
| 114 |
+
child_pool = []
|
| 115 |
+
efficiency_pool = []
|
| 116 |
+
|
| 117 |
+
for j in range(mutation_numbers):
|
| 118 |
+
par_sample = population[np.random.randint(parents_size)][2]
|
| 119 |
+
# Mutate
|
| 120 |
+
new_sample, efficiency = self.mutate_sample(par_sample, constraint)
|
| 121 |
+
child_pool.append(new_sample)
|
| 122 |
+
efficiency_pool.append(efficiency)
|
| 123 |
+
|
| 124 |
+
for j in range(self.population_size - mutation_numbers):
|
| 125 |
+
par_sample1 = population[np.random.randint(parents_size)][2]
|
| 126 |
+
par_sample2 = population[np.random.randint(parents_size)][2]
|
| 127 |
+
# Crossover
|
| 128 |
+
new_sample, efficiency = self.crossover_sample(
|
| 129 |
+
par_sample1, par_sample2, constraint
|
| 130 |
+
)
|
| 131 |
+
child_pool.append(new_sample)
|
| 132 |
+
efficiency_pool.append(efficiency)
|
| 133 |
+
|
| 134 |
+
accs = self.accuracy_predictor.predict_acc(child_pool)
|
| 135 |
+
robs = self.robustness_predictor.predict_rob(child_pool)
|
| 136 |
+
for j in range(self.population_size):
|
| 137 |
+
population.append(
|
| 138 |
+
(accs[j].item(), robs[j].item(), child_pool[j], efficiency_pool[j])
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
t.update(1)
|
| 142 |
+
|
| 143 |
+
return best_valids, best_info
|
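A minimal sketch of driving EvolutionFinder end to end, with throwaway stand-ins for the three predictors (the real ones live under proard/nas). The only interfaces assumed are the ones the class actually calls: get_efficiency, predict_acc / predict_rob returning torch tensors, and an arch_encoder exposing random_sample_arch / mutate_resolution / mutate_arch.

import torch
from proard.nas.search_algorithm import EvolutionFinder

class ToyEncoder:  # hypothetical stand-in for an ArchEncoder
    def random_sample_arch(self):
        return {"ks": [3, 5, 7], "image_size": 160}
    def mutate_resolution(self, arch, prob):
        pass  # no-op: keeps the sketch simple
    def mutate_arch(self, arch, prob):
        pass

class ToyAcc:
    arch_encoder = ToyEncoder()
    def predict_acc(self, archs):
        return torch.rand(len(archs)) * 100

class ToyRob:
    def predict_rob(self, archs):
        return torch.rand(len(archs)) * 100

class ToyEff:
    def get_efficiency(self, arch):
        return float(arch["image_size"])  # stand-in for MFLOPs / latency

finder = EvolutionFinder(ToyEff(), ToyAcc(), ToyRob(), population_size=20, max_time_budget=10)
best_valids, best_info = finder.run_evolution_search(constraint=200, verbose=True)
# best_info is (predicted acc, predicted rob, arch dict, efficiency)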
proard/nas/search_algorithm/multi_evolution.py
ADDED
@@ -0,0 +1,143 @@
import numpy as np
from pymoo.core.individual import Individual
from pymoo.core.problem import Problem
from pymoo.core.sampling import Sampling
from pymoo.core.variable import Choice

__all__ = [
    "individual_to_arch_mbv",
    "DynIndividual_mbv",
    "DynProblem_mbv",
    "individual_to_arch_res",
    "DynIndividual_res",
    "DynProblem_res",
    "DynSampling",
    "DynRandomSampler",
]


def individual_to_arch_mbv(population, n_blocks):
    archs = []
    for individual in population:
        archs.append(
            {
                "ks": individual[0:n_blocks],
                "e": individual[n_blocks : 2 * n_blocks],
                "d": individual[2 * n_blocks : -1],
                "image_size": individual[-1:],
            }
        )
    return archs


class DynIndividual_mbv(Individual):
    def __init__(self, individual, accuracy_predictor, Robustness_predictor, config=None, **kwargs):
        super().__init__(config=None, **kwargs)
        self.X = np.concatenate(
            (
                individual[0]["ks"],
                individual[0]["e"],
                individual[0]["d"],
                individual[0]["image_size"],
            )
        )
        self.flops = individual[1]
        # predictors return scores in [0, 100]; store 100 - score so all objectives are minimized
        self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]])
        self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]])
        self.F = np.concatenate(
            (
                [self.flops],
                [self.accuracy.squeeze().cpu().detach().numpy()],
                [self.robustness.squeeze().cpu().detach().numpy()],
            )
        )


class DynProblem_mbv(Problem):
    def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars):
        self.ks = Choice(options=search_vars.get("ks"))
        self.e = Choice(options=search_vars.get("e"))
        self.d = Choice(options=search_vars.get("d"))
        self.r = Choice(options=search_vars.get("image_size"))

        variables = num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r]
        super().__init__(
            vars=dict(zip(range(len(variables)), variables)),
            n_obj=3,
            n_constr=0,
        )
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor
        self.robustness_predictor = robustness_predictor
        self.blocks = num_blocks
        self.stages = num_stages
        self.search_vars = search_vars

    def _evaluate(self, x, out, *args, **kwargs):
        f1 = []
        # x.shape = (population_size, n_var)
        arch = individual_to_arch_mbv(x, self.blocks)
        for arc in arch:
            f1.append(self.efficiency_predictor.get_efficiency(arc))
        f2 = 100 - self.accuracy_predictor.predict_acc(arch).detach().cpu().numpy()
        f3 = 100 - self.robustness_predictor.predict_rob(arch).detach().cpu().numpy()
        out["F"] = np.column_stack([f1, f2, f3])


def individual_to_arch_res(population, n_blocks):
    archs = []
    for individual in population:
        archs.append(
            {
                "e": individual[n_blocks : 2 * n_blocks],
                "d": individual[2 * n_blocks : -1],
                "w": individual[0:n_blocks],
                "r": individual[-1:],
            }
        )
    return archs


class DynIndividual_res(Individual):
    def __init__(self, individual, accuracy_predictor, Robustness_predictor, config=None, **kwargs):
        super().__init__(config=None, **kwargs)
        self.X = np.concatenate(
            (
                individual[0]["e"],
                individual[0]["d"],
                individual[0]["w"],
                [individual[0]["image_size"]],
            )
        )
        self.flops = individual[1]
        self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]])
        self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]])
        self.F = np.concatenate(
            (
                [self.flops],
                [self.accuracy.squeeze().cpu().detach().numpy()],
                [self.robustness.squeeze().cpu().detach().numpy()],
            )
        )


class DynProblem_res(Problem):
    def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars):
        self.e = Choice(options=search_vars.get("e"))
        self.d = Choice(options=search_vars.get("d"))
        self.w = Choice(options=search_vars.get("w"))
        self.r = Choice(options=search_vars.get("image_size"))
        # the variable layout must mirror individual_to_arch_res: width (w), expand ratio (e),
        # depth (d), then resolution (r); the original referenced an undefined self.ks here
        variables = num_blocks * [self.w] + num_blocks * [self.e] + num_stages * [self.d] + [self.r]
        super().__init__(
            vars=dict(zip(range(len(variables)), variables)),
            n_obj=3,
            n_constr=0,
        )
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor
        self.robustness_predictor = robustness_predictor
        self.blocks = num_blocks
        self.stages = num_stages
        self.search_vars = search_vars

    def _evaluate(self, x, out, *args, **kwargs):
        f1 = []  # must be a list (the original initialized a dict and then called append on it)
        # x.shape = (population_size, n_var)
        arch = individual_to_arch_res(x, self.blocks)
        for arc in arch:
            f1.append(self.efficiency_predictor.get_efficiency(arc))
        f2 = 100 - self.accuracy_predictor.predict_acc(arch).detach().cpu().numpy()
        f3 = 100 - self.robustness_predictor.predict_rob(arch).detach().cpu().numpy()
        out["F"] = np.column_stack([f1, f2, f3])


class DynSampling(Sampling):
    def _do(self, problem, n_samples, **kwargs):
        return [
            [np.random.choice(var.options) for key, var in problem.vars.items()]
            for _ in range(n_samples)
        ]


class DynRandomSampler:
    def __init__(self, arch_manager, efficiency_predictor):
        self.arch_manager = arch_manager
        self.efficiency_predictor = efficiency_predictor

    def random_sample(self):
        sample = self.arch_manager.random_sample_arch()
        efficiency = self.efficiency_predictor.get_efficiency(sample)
        return sample, efficiency
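DynProblem_mbv plugs into pymoo's mixed-variable machinery; the sketch below wires it to mock predictors (hypothetical stand-ins) and drives _evaluate directly, just to show the objective layout — FLOPs, 100 - accuracy, 100 - robustness — without committing to a particular pymoo algorithm setup. For a full search the same problem object would be handed to a multi-objective algorithm such as NSGA-II with pymoo's mixed-variable operators.

import numpy as np
import torch
from proard.nas.search_algorithm import DynProblem_mbv, DynSampling

class MockEff:  # hypothetical efficiency predictor
    def get_efficiency(self, arch):
        return float(np.sum(arch["e"]) + np.sum(arch["ks"]))

class MockScore:  # hypothetical accuracy/robustness predictor, scores in [0, 100]
    def predict_acc(self, archs):
        return torch.rand(len(archs)) * 100
    predict_rob = predict_acc

search_vars = {"ks": [3, 5, 7], "e": [3, 4, 6], "d": [2, 3, 4], "image_size": [128, 160, 192, 224]}
problem = DynProblem_mbv(MockEff(), MockScore(), MockScore(),
                         num_blocks=20, num_stages=5, search_vars=search_vars)

X = np.array(DynSampling()._do(problem, n_samples=8), dtype=object)
out = {}
problem._evaluate(X, out)
print(out["F"].shape)  # (8, 3): one (FLOPs, 100 - acc, 100 - rob) row per sampled arch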
proard/utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .pytorch_modules import *
from .pytorch_utils import *
from .my_modules import *
from .flops_counter import *
from .common_tools import *
from .my_dataloader import *
proard/utils/common_tools.py
ADDED
@@ -0,0 +1,307 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import os
import sys
import torch

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

__all__ = [
    "sort_dict",
    "get_same_padding",
    "get_split_list",
    "list_sum",
    "list_mean",
    "list_join",
    "subset_mean",
    "sub_filter_start_end",
    "min_divisible_value",
    "val2list",
    "download_url",
    "write_log",
    "pairwise_accuracy",
    "accuracy",
    "AverageMeter",
    "MultiClassAverageMeter",
    "DistributedMetric",
    "DistributedTensor",
]


def sort_dict(src_dict, reverse=False, return_dict=True):
    output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
    if return_dict:
        return dict(output)
    else:
        return output


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2


def get_split_list(in_dim, child_num, accumulate=False):
    in_dim_list = [in_dim // child_num] * child_num
    for _i in range(in_dim % child_num):
        in_dim_list[_i] += 1
    if accumulate:
        for i in range(1, child_num):
            in_dim_list[i] += in_dim_list[i - 1]
    return in_dim_list


def list_sum(x):
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


def list_mean(x):
    return list_sum(x) / len(x)


def list_join(val_list, sep="\t"):
    return sep.join([str(val) for val in val_list])


def subset_mean(val_list, sub_indexes):
    sub_indexes = val2list(sub_indexes, 1)
    return list_mean([val_list[idx] for idx in sub_indexes])


def sub_filter_start_end(kernel_size, sub_kernel_size):
    center = kernel_size // 2
    dev = sub_kernel_size // 2
    start, end = center - dev, center + dev + 1
    assert end - start == sub_kernel_size
    return start, end


def min_divisible_value(n1, v1):
    """make sure v1 is divisible by n1, otherwise decrease v1"""
    if v1 >= n1:
        return n1
    while n1 % v1 != 0:
        v1 -= 1
    return v1


def val2list(val, repeat_time=1):
    if isinstance(val, list) or isinstance(val, np.ndarray):
        return val
    elif isinstance(val, tuple):
        return list(val)
    else:
        return [val for _ in range(repeat_time)]


def download_url(url, model_dir="~/.torch/", overwrite=False):
    target_dir = url.split("/")[-1]
    model_dir = os.path.expanduser(model_dir)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_dir = os.path.join(model_dir, target_dir)
        cached_file = model_dir
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove lock file so download can be executed next time.
        os.remove(os.path.join(model_dir, "download.lock"))
        sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n")
        return None


def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"):
    """prefix: valid, train, test"""
    if not os.path.exists(logs_path):
        os.makedirs(logs_path, exist_ok=True)
    if prefix in ["valid", "test"]:
        with open(os.path.join(logs_path, "valid_console.txt"), mode) as fout:
            fout.write(log_str + "\n")
            fout.flush()
    if prefix in ["valid", "test", "train"]:
        with open(os.path.join(logs_path, "train_console.txt"), mode) as fout:
            if prefix in ["valid", "test"]:
                fout.write("=" * 10)
            fout.write(log_str + "\n")
            fout.flush()
    else:
        with open(os.path.join(logs_path, "%s.txt" % prefix), mode) as fout:
            fout.write(log_str + "\n")
            fout.flush()
    if should_print:
        print(log_str)


def pairwise_accuracy(la, lb, n_samples=200000):
    n = len(la)
    assert n == len(lb)
    total = 0
    count = 0
    for _ in range(n_samples):
        i = np.random.randint(n)
        j = np.random.randint(n)
        while i == j:
            j = np.random.randint(n)
        if la[i] >= la[j] and lb[i] >= lb[j]:
            count += 1
        if la[i] < la[j] and lb[i] < lb[j]:
            count += 1
        total += 1
    return float(count) / total


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MultiClassAverageMeter:
    """Multi Binary Classification Tasks"""

    def __init__(self, num_classes, balanced=False, **kwargs):
        super(MultiClassAverageMeter, self).__init__()
        self.num_classes = num_classes
        self.balanced = balanced

        self.counts = []
        for k in range(self.num_classes):
            self.counts.append(np.ndarray((2, 2), dtype=np.float32))

        self.reset()

    def reset(self):
        for k in range(self.num_classes):
            self.counts[k].fill(0)

    def add(self, outputs, targets):
        outputs = outputs.data.cpu().numpy()
        targets = targets.data.cpu().numpy()

        for k in range(self.num_classes):
            output = np.argmax(outputs[:, k, :], axis=1)
            target = targets[:, k]

            x = output + 2 * target
            bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)

            self.counts[k] += bincount.reshape((2, 2))

    def value(self):
        mean = 0
        for k in range(self.num_classes):
            if self.balanced:
                value = np.mean(
                    (
                        self.counts[k]
                        / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]
                    ).diagonal()
                )
            else:
                value = np.sum(self.counts[k].diagonal()) / np.maximum(
                    np.sum(self.counts[k]), 1
                )

            mean += value / self.num_classes * 100.0
        return mean


class DistributedMetric(object):
    """
    Horovod: average metrics from distributed training.
    """

    def __init__(self, name):
        self.name = name
        self.sum = torch.zeros(1)[0]
        self.count = torch.zeros(1)[0]

    def update(self, val, delta_n=1):
        import horovod.torch as hvd

        val *= delta_n
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.count += delta_n

    @property
    def avg(self):
        return self.sum / self.count


class DistributedTensor(object):
    def __init__(self, name):
        self.name = name
        self.sum = None
        self.count = torch.zeros(1)[0]
        self.synced = False

    def update(self, val, delta_n=1):
        val *= delta_n
        if self.sum is None:
            self.sum = val.detach()
        else:
            self.sum += val.detach()
        self.count += delta_n

    @property
    def avg(self):
        import horovod.torch as hvd

        if not self.synced:
            self.sum = hvd.allreduce(self.sum, name=self.name)
            self.synced = True
        return self.sum / self.count
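The two workhorses here compose as in any validation loop; a quick self-contained sketch with random tensors standing in for model outputs:

import torch
from proard.utils.common_tools import AverageMeter, accuracy

logits = torch.randn(32, 10)           # fake batch of predictions
labels = torch.randint(0, 10, (32,))   # fake targets
top1, top5 = accuracy(logits, labels, topk=(1, 5))

meter = AverageMeter()
meter.update(top1.item(), n=logits.size(0))
print("running top-1: %.2f%%" % meter.avg)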
proard/utils/flops_counter.py
ADDED
@@ -0,0 +1,97 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn

from .my_modules import MyConv2d

__all__ = ["profile"]


def count_convNd(m, _, y):
    cin = m.in_channels

    kernel_ops = m.weight.size()[2] * m.weight.size()[3]
    ops_per_element = kernel_ops
    output_elements = y.nelement()

    # cout x oW x oH
    total_ops = cin * output_elements * ops_per_element // m.groups
    m.total_ops = torch.zeros(1).fill_(total_ops)


def count_linear(m, _, __):
    total_ops = m.in_features * m.out_features

    m.total_ops = torch.zeros(1).fill_(total_ops)


register_hooks = {
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    MyConv2d: count_convNd,
    ######################################
    nn.Linear: count_linear,
    ######################################
    nn.Dropout: None,
    nn.Dropout2d: None,
    nn.Dropout3d: None,
    nn.BatchNorm2d: None,
}


def profile(model, input_size, custom_ops=None):
    handler_collection = []
    custom_ops = {} if custom_ops is None else custom_ops

    def add_hooks(m_):
        if len(list(m_.children())) > 0:
            return

        m_.register_buffer("total_ops", torch.zeros(1))
        m_.register_buffer("total_params", torch.zeros(1))

        for p in m_.parameters():
            m_.total_params += torch.zeros(1).fill_(p.numel())

        m_type = type(m_)
        fn = None

        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]

        if fn is not None:
            _handler = m_.register_forward_hook(fn)
            handler_collection.append(_handler)

    original_device = model.parameters().__next__().device
    training = model.training

    model.eval()
    model.apply(add_hooks)

    x = torch.zeros(input_size).to(original_device)
    with torch.no_grad():
        model(x)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    model.train(training).to(original_device)
    for handler in handler_collection:
        handler.remove()

    return total_ops, total_params
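A minimal sketch of the hook-based counter on a toy network (the architecture is arbitrary and assumes the package is importable); profile returns raw op and parameter counts, so the caller does the unit conversion:

import torch.nn as nn
from proard.utils.flops_counter import profile

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
total_ops, total_params = profile(net, input_size=(1, 3, 32, 32))
print("MFLOPs: %.3f, params: %.1fK" % (total_ops / 1e6, total_params / 1e3))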
proard/utils/layers.py
ADDED
@@ -0,0 +1,819 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn

from collections import OrderedDict
from proard.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer
from proard.utils import MyNetwork, MyModule
from proard.utils import build_activation, make_divisible

__all__ = [
    "set_layer_from_config",
    "ConvLayer",
    "IdentityLayer",
    "LinearLayer",
    "MultiHeadLinearLayer",
    "ZeroLayer",
    "MBConvLayer",
    "ResidualBlock",
    "ResNetBottleneckBlock",
]


def set_layer_from_config(layer_config):
    if layer_config is None:
        return None

    name2layer = {
        ConvLayer.__name__: ConvLayer,
        IdentityLayer.__name__: IdentityLayer,
        LinearLayer.__name__: LinearLayer,
        MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
        ZeroLayer.__name__: ZeroLayer,
        MBConvLayer.__name__: MBConvLayer,
        "MBInvertedConvLayer": MBConvLayer,
        ##########################################################
        ResidualBlock.__name__: ResidualBlock,
        ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
    }

    layer_name = layer_config.pop("name")
    layer = name2layer[layer_name]
    return layer.build_from_config(layer_config)


class My2DLayer(MyModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        use_bn=True,
        act_func="relu",
        dropout_rate=0,
        ops_order="weight_bn_act",
    ):
        super(My2DLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.use_bn = use_bn
        self.act_func = act_func
        self.dropout_rate = dropout_rate
        self.ops_order = ops_order

        """ modules """
        modules = {}
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                modules["bn"] = nn.BatchNorm2d(in_channels)
            else:
                modules["bn"] = nn.BatchNorm2d(out_channels)
        else:
            modules["bn"] = None
        # activation
        modules["act"] = build_activation(
            self.act_func, self.ops_list[0] != "act" and self.use_bn
        )
        # dropout
        if self.dropout_rate > 0:
            modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)
        else:
            modules["dropout"] = None
        # weight
        modules["weight"] = self.weight_op()

        # add modules
        for op in self.ops_list:
            if modules[op] is None:
                continue
            elif op == "weight":
                # dropout before weight operation
                if modules["dropout"] is not None:
                    self.add_module("dropout", modules["dropout"])
                for key in modules["weight"]:
                    self.add_module(key, modules["weight"][key])
            else:
                self.add_module(op, modules[op])

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def weight_op(self):
        raise NotImplementedError

    """ Methods defined in MyModule """

    def forward(self, x):
        # similar to nn.Sequential
        for module in self._modules.values():
            x = module(x)
        return x

    @property
    def module_str(self):
        raise NotImplementedError

    @property
    def config(self):
        return {
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
            "dropout_rate": self.dropout_rate,
            "ops_order": self.ops_order,
        }

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError


class ConvLayer(My2DLayer):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        has_shuffle=False,
        use_se=False,
        use_bn=True,
        act_func="relu",
        dropout_rate=0,
        ops_order="weight_bn_act",
    ):
        # default normal 3x3_Conv with bn and relu
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.has_shuffle = has_shuffle
        self.use_se = use_se

        super(ConvLayer, self).__init__(
            in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
        )
        if self.use_se:
            self.add_module("se", SEModule(self.out_channels))

    def weight_op(self):
        padding = get_same_padding(self.kernel_size)
        if isinstance(padding, int):
            padding *= self.dilation
        else:
            # get_same_padding returns an (immutable) tuple for tuple kernel sizes,
            # so rebuild it instead of mutating it in place
            padding = (padding[0] * self.dilation, padding[1] * self.dilation)

        weight_dict = OrderedDict(
            {
                "conv": nn.Conv2d(
                    self.in_channels,
                    self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=padding,
                    dilation=self.dilation,
                    groups=min_divisible_value(self.in_channels, self.groups),
                    bias=self.bias,
                )
            }
        )
        if self.has_shuffle and self.groups > 1:
            weight_dict["shuffle"] = ShuffleLayer(self.groups)

        return weight_dict

    @property
    def module_str(self):
        if isinstance(self.kernel_size, int):
            kernel_size = (self.kernel_size, self.kernel_size)
        else:
            kernel_size = self.kernel_size
        if self.groups == 1:
            if self.dilation > 1:
                conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1])
            else:
                conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1])
        else:
            if self.dilation > 1:
                conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1])
            else:
                conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1])
        conv_str += "_O%d" % self.out_channels
        if self.use_se:
            conv_str = "SE_" + conv_str
        conv_str += "_" + self.act_func.upper()
        if self.use_bn:
            if isinstance(self.bn, nn.GroupNorm):
                conv_str += "_GN%d" % self.bn.num_groups
            elif isinstance(self.bn, nn.BatchNorm2d):
                conv_str += "_BN"
        return conv_str

    @property
    def config(self):
        return {
            "name": ConvLayer.__name__,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "dilation": self.dilation,
            "groups": self.groups,
            "bias": self.bias,
            "has_shuffle": self.has_shuffle,
            "use_se": self.use_se,
            **super(ConvLayer, self).config,
        }

    @staticmethod
    def build_from_config(config):
        return ConvLayer(**config)


class IdentityLayer(My2DLayer):
    def __init__(
        self,
        in_channels,
        out_channels,
        use_bn=False,
        act_func=None,
        dropout_rate=0,
        ops_order="weight_bn_act",
    ):
        super(IdentityLayer, self).__init__(
            in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
        )

    def weight_op(self):
        return None

    @property
    def module_str(self):
        return "Identity"

    @property
    def config(self):
        return {
            "name": IdentityLayer.__name__,
            **super(IdentityLayer, self).config,
        }

    @staticmethod
    def build_from_config(config):
        return IdentityLayer(**config)


class LinearLayer(MyModule):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        use_bn=False,
        act_func=None,
        dropout_rate=0,
        ops_order="weight_bn_act",
    ):
        super(LinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        self.use_bn = use_bn
        self.act_func = act_func
        self.dropout_rate = dropout_rate
        self.ops_order = ops_order

        """ modules """
        modules = {}
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                modules["bn"] = nn.BatchNorm1d(in_features)
            else:
                modules["bn"] = nn.BatchNorm1d(out_features)
        else:
            modules["bn"] = None
        # activation
        modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act")
        # dropout
        if self.dropout_rate > 0:
            modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            modules["dropout"] = None
        # linear
        modules["weight"] = {
            "linear": nn.Linear(self.in_features, self.out_features, self.bias)
        }

        # add modules
        for op in self.ops_list:
            if modules[op] is None:
                continue
            elif op == "weight":
                if modules["dropout"] is not None:
                    self.add_module("dropout", modules["dropout"])
                for key in modules["weight"]:
                    self.add_module(key, modules["weight"][key])
            else:
                self.add_module(op, modules[op])

    @property
    def ops_list(self):
        return self.ops_order.split("_")

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == "bn":
                return True
            elif op == "weight":
                return False
        raise ValueError("Invalid ops_order: %s" % self.ops_order)

    def forward(self, x):
        for module in self._modules.values():
            x = module(x)
        return x

    @property
    def module_str(self):
        return "%dx%d_Linear" % (self.in_features, self.out_features)

    @property
    def config(self):
        return {
            "name": LinearLayer.__name__,
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
            "dropout_rate": self.dropout_rate,
            "ops_order": self.ops_order,
        }

    @staticmethod
    def build_from_config(config):
        return LinearLayer(**config)


class MultiHeadLinearLayer(MyModule):
    def __init__(
        self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0
    ):
        super(MultiHeadLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads

        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for _ in range(num_heads):
            layer = nn.Linear(in_features, out_features, self.bias)
            self.layers.append(layer)

    def forward(self, inputs):
        if self.dropout is not None:
            inputs = self.dropout(inputs)

        outputs = []
        for layer in self.layers:
            output = layer.forward(inputs)
            outputs.append(output)

        outputs = torch.stack(outputs, dim=1)
        return outputs

    @property
    def module_str(self):
        return self.__repr__()

    @property
    def config(self):
        return {
            "name": MultiHeadLinearLayer.__name__,
            "in_features": self.in_features,
            "out_features": self.out_features,
            "num_heads": self.num_heads,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return MultiHeadLinearLayer(**config)

    def __repr__(self):
        return (
            "MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)"
            % (
                self.in_features,
                self.out_features,
                self.num_heads,
                self.bias,
                self.dropout_rate,
            )
        )


class ZeroLayer(MyModule):
    def __init__(self):
        super(ZeroLayer, self).__init__()

    def forward(self, x):
        raise ValueError

    @property
    def module_str(self):
        return "Zero"

    @property
    def config(self):
        return {
            "name": ZeroLayer.__name__,
        }

    @staticmethod
    def build_from_config(config):
        return ZeroLayer()


class MBConvLayer(MyModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        expand_ratio=6,
        mid_channels=None,
        act_func="relu6",
        use_se=False,
        groups=None,
    ):
        super(MBConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.use_se = use_se
        self.groups = groups

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            nn.Conv2d(
                                self.in_channels, feature_dim, 1, 1, 0, bias=False
                            ),
                        ),
                        ("bn", nn.BatchNorm2d(feature_dim)),
                        ("act", build_activation(self.act_func, inplace=True)),
                    ]
                )
            )

        pad = get_same_padding(self.kernel_size)
        groups = (
            feature_dim
            if self.groups is None
            else min_divisible_value(feature_dim, self.groups)
        )
        depth_conv_modules = [
            (
                "conv",
                nn.Conv2d(
                    feature_dim,
                    feature_dim,
                    kernel_size,
                    stride,
                    pad,
                    groups=groups,
                    bias=False,
                ),
            ),
            ("bn", nn.BatchNorm2d(feature_dim)),
            ("act", build_activation(self.act_func, inplace=True)),
        ]
        if self.use_se:
            depth_conv_modules.append(("se", SEModule(feature_dim)))
        self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

        self.point_linear = nn.Sequential(
            OrderedDict(
                [
                    ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
                    ("bn", nn.BatchNorm2d(out_channels)),
                ]
            )
        )

    def forward(self, x):
        if self.inverted_bottleneck:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.mid_channels is None:
            expand_ratio = self.expand_ratio
        else:
            expand_ratio = self.mid_channels // self.in_channels
        layer_str = "%dx%d_MBConv%d_%s" % (
            self.kernel_size,
            self.kernel_size,
            expand_ratio,
            self.act_func.upper(),
        )
        if self.use_se:
            layer_str = "SE_" + layer_str
        layer_str += "_O%d" % self.out_channels
        if self.groups is not None:
            layer_str += "_G%d" % self.groups
        if isinstance(self.point_linear.bn, nn.GroupNorm):
            layer_str += "_GN%d" % self.point_linear.bn.num_groups
        elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
            layer_str += "_BN"

        return layer_str

    @property
    def config(self):
        return {
            "name": MBConvLayer.__name__,
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "expand_ratio": self.expand_ratio,
            "mid_channels": self.mid_channels,
            "act_func": self.act_func,
            "use_se": self.use_se,
            "groups": self.groups,
        }

    @staticmethod
    def build_from_config(config):
        return MBConvLayer(**config)


class ResidualBlock(MyModule):
    def __init__(self, conv, shortcut):
        super(ResidualBlock, self).__init__()

        self.conv = conv
        self.shortcut = shortcut

    def forward(self, x):
        if self.conv is None or isinstance(self.conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.conv(x)
        else:
            res = self.conv(x) + self.shortcut(x)
        return res

    @property
    def module_str(self):
        return "(%s, %s)" % (
            self.conv.module_str if self.conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None,
        )

    @property
    def config(self):
        return {
            "name": ResidualBlock.__name__,
            "conv": self.conv.config if self.conv is not None else None,
            "shortcut": self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        conv_config = (
            config["conv"] if "conv" in config else config["mobile_inverted_conv"]
        )
        conv = set_layer_from_config(conv_config)
        shortcut = set_layer_from_config(config["shortcut"])
        return ResidualBlock(conv, shortcut)

    @property
    def mobile_inverted_conv(self):
        return self.conv


class ResNetBottleneckBlock(MyModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        expand_ratio=0.25,
        mid_channels=None,
        act_func="relu",
        groups=1,
        downsample_mode="avgpool_conv",
    ):
        super(ResNetBottleneckBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.groups = groups

        self.downsample_mode = downsample_mode

        if self.mid_channels is None:
            feature_dim = round(self.out_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        self.mid_channels = feature_dim

        # build modules
        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False),
                    ),
                    ("bn", nn.BatchNorm2d(feature_dim)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        pad = get_same_padding(self.kernel_size)
        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(
                            feature_dim,
                            feature_dim,
                            kernel_size,
                            stride,
                            pad,
                            groups=groups,
                            bias=False,
                        ),
                    ),
                    ("bn", nn.BatchNorm2d(feature_dim)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        self.conv3 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False),
                    ),
                    ("bn", nn.BatchNorm2d(self.out_channels)),
                ]
            )
        )

        if stride == 1 and in_channels == out_channels:
            self.downsample = IdentityLayer(in_channels, out_channels)
        elif self.downsample_mode == "conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            nn.Conv2d(
                                in_channels, out_channels, 1, stride, 0, bias=False
                            ),
                        ),
                        ("bn", nn.BatchNorm2d(out_channels)),
                    ]
                )
            )
        elif self.downsample_mode == "avgpool_conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "avg_pool",
                            nn.AvgPool2d(
                                kernel_size=stride,
                                stride=stride,
                                padding=0,
                                ceil_mode=True,
                            ),
                        ),
                        (
                            "conv",
                            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
                        ),
                        ("bn", nn.BatchNorm2d(out_channels)),
                    ]
                )
            )
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

    def forward(self, x):
        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
|
| 781 |
+
x = self.final_act(x)
|
| 782 |
+
return x
|
| 783 |
+
|
| 784 |
+
@property
|
| 785 |
+
def module_str(self):
|
| 786 |
+
return "(%s, %s)" % (
|
| 787 |
+
"%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d"
|
| 788 |
+
% (
|
| 789 |
+
self.kernel_size,
|
| 790 |
+
self.kernel_size,
|
| 791 |
+
self.in_channels,
|
| 792 |
+
self.mid_channels,
|
| 793 |
+
self.out_channels,
|
| 794 |
+
self.stride,
|
| 795 |
+
self.groups,
|
| 796 |
+
),
|
| 797 |
+
"Identity"
|
| 798 |
+
if isinstance(self.downsample, IdentityLayer)
|
| 799 |
+
else self.downsample_mode,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
@property
|
| 803 |
+
def config(self):
|
| 804 |
+
return {
|
| 805 |
+
"name": ResNetBottleneckBlock.__name__,
|
| 806 |
+
"in_channels": self.in_channels,
|
| 807 |
+
"out_channels": self.out_channels,
|
| 808 |
+
"kernel_size": self.kernel_size,
|
| 809 |
+
"stride": self.stride,
|
| 810 |
+
"expand_ratio": self.expand_ratio,
|
| 811 |
+
"mid_channels": self.mid_channels,
|
| 812 |
+
"act_func": self.act_func,
|
| 813 |
+
"groups": self.groups,
|
| 814 |
+
"downsample_mode": self.downsample_mode,
|
| 815 |
+
}
|
| 816 |
+
|
| 817 |
+
@staticmethod
|
| 818 |
+
def build_from_config(config):
|
| 819 |
+
return ResNetBottleneckBlock(**config)
|
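
Every layer in the file above shares one serialization contract: a `config` property that emits a plain dict and a static `build_from_config` that consumes it. A minimal round-trip sketch (editor's addition, not part of the diff; it assumes `proard` is installed so `proard.utils.layers` resolves, and the `name` key must be dropped before calling `build_from_config` directly, since the constructors do not accept it):

    # Sketch only: config round-trip for ResNetBottleneckBlock as defined above.
    import torch
    from proard.utils.layers import ResNetBottleneckBlock

    block = ResNetBottleneckBlock(in_channels=64, out_channels=256, stride=1)
    cfg = {k: v for k, v in block.config.items() if k != "name"}  # drop the class tag
    clone = ResNetBottleneckBlock.build_from_config(cfg)  # fresh weights, same shapes

    x = torch.randn(1, 64, 32, 32)
    print(block.module_str)   # e.g. (3x3_BottleneckConv_64->64->256_S1_G1, avgpool_conv)
    print(block(x).shape)     # torch.Size([1, 256, 32, 32])
    print(clone(x).shape)     # torch.Size([1, 256, 32, 32]); parameters are new, not copied
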
proard/utils/my_dataloader/__init__.py ADDED
@@ -0,0 +1,2 @@
from .my_distributed_sampler import *
from .my_random_resize_crop import *
proard/utils/my_dataloader/my_data_loader.py ADDED
@@ -0,0 +1,1050 @@
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter

To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""

import threading
import itertools
import warnings
import queue  # stdlib: provides queue.Queue (pin-memory data queue) and queue.Empty
import multiprocessing as python_multiprocessing
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import string_classes
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data import _utils

from .my_data_worker import worker_loop

__all__ = ["MyDataLoader"]

get_worker_info = _utils.worker.get_worker_info

# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate


class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(
                dataset, auto_collation, collate_fn, drop_last
            )
        else:
            return _utils.fetch._IterableDatasetFetcher(
                dataset, auto_collation, collate_fn, drop_last
            )


class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None

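# ---------------------------------------------------------------------------
# Editor's sketch (not part of the uploaded file): the sampler above is just
# ``itertools.repeat(None, None)`` in class form. It never raises
# StopIteration, so for iterable-style datasets the end of an epoch is
# signaled by the fetcher, not the sampler. A standalone check:
#
#     import itertools
#
#     def infinite_constant():  # stand-in for _InfiniteConstantSampler.__iter__
#         while True:
#             yield None
#
#     print(list(itertools.islice(infinite_constant(), 3)))  # [None, None, None]
# ---------------------------------------------------------------------------
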
class MyDataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of the dataset is not divisible by the batch size, then the last
            batch will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. note:: The ``len(dataloader)`` heuristic is based on the length of the sampler used.
              When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
              ``len(dataset)`` (if implemented) is returned instead, regardless
              of multi-process loading configurations, because PyTorch trusts
              user :attr:`dataset` code to correctly handle multi-process
              loading and avoid duplicate data. See `Dataset Types`_ for more
              details on these two types of datasets and how
              :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
    """

    __initialized = False

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        multiprocessing_context=None,
    ):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError(
                "num_workers option should be non-negative; "
                "use num_workers=0 to disable multiprocessing."
            )

        if timeout < 0:
            raise ValueError("timeout option should be non-negative")

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and `IterableDataset` ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of the dataset will
            # cause data loss (you may have duplicates of the first couple of
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users cannot control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle)
                )
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler)
                )
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(
                        batch_sampler
                    )
                )
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError("sampler option is mutually exclusive with shuffle")

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError(
                    "batch_sampler option is mutually exclusive "
                    "with batch_size, shuffle, sampler, and "
                    "drop_last"
                )
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError(
                    "batch_size=None option disables auto-batching "
                    "and is mutually exclusive with "
                    "shuffle, and drop_last"
                )

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.__initialized = True
        self._IterableDataset_len_called = (
            None  # See NOTE [ IterableDataset and __len__ ]
        )

    @property
    def multiprocessing_context(self):
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError(
                        "multiprocessing_context relies on Python >= 3.4, with "
                        "support for different start methods"
                    )

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            (
                                "multiprocessing_context option "
                                "should specify a valid start method in {}, but got "
                                "multiprocessing_context={}"
                            ).format(valid_start_methods, multiprocessing_context)
                        )
                    multiprocessing_context = multiprocessing.get_context(
                        multiprocessing_context
                    )

                if not isinstance(
                    multiprocessing_context, python_multiprocessing.context.BaseContext
                ):
                    raise ValueError(
                        (
                            "multiprocessing_context option should be a valid context "
                            "object or a string specifying the start method, but got "
                            "multiprocessing_context={}"
                        ).format(multiprocessing_context)
                    )
            else:
                raise ValueError(
                    (
                        "multiprocessing_context can only be used with "
                        "multi-process loading (num_workers > 0), but got "
                        "num_workers={}"
                    ).format(self.num_workers)
                )

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        if self.__initialized and attr in (
            "batch_size",
            "batch_sampler",
            "sampler",
            "drop_last",
            "dataset",
        ):
            raise ValueError(
                "{} attribute should not be set after {} is "
                "initialized".format(attr, self.__class__.__name__)
            )

        super(MyDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.
            length = self._IterableDataset_len_called = len(self.dataset)
            return length
        else:
            return len(self._index_sampler)

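# ---------------------------------------------------------------------------
# Editor's sketch (not part of the uploaded file): MyDataLoader keeps the
# stock torch.utils.data.DataLoader interface, so driving it is identical;
# the custom behavior lives in `worker_loop` from `my_data_worker`. With
# num_workers=0 only the single-process iterator defined below is exercised:
#
#     import torch
#     from torch.utils.data import TensorDataset
#     from proard.utils.my_dataloader.my_data_loader import MyDataLoader
#
#     dataset = TensorDataset(torch.randn(100, 3, 32, 32),
#                             torch.randint(0, 10, (100,)))
#     loader = MyDataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)
#
#     print(len(loader))  # 7: len(BatchSampler) = ceil(100 / 16), drop_last=False
#     images, labels = next(iter(loader))
#     print(images.shape, labels.shape)  # [16, 3, 32, 32] and [16]
# ---------------------------------------------------------------------------
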
class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._num_yielded = 0

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self):
        data = self._next_data()
        self._num_yielded += 1
        if (
            self._dataset_kind == _DatasetKind.Iterable
            and self._IterableDataset_len_called is not None
            and self._num_yielded > self._IterableDataset_len_called
        ):
            warn_msg = (
                "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                "samples have been fetched. "
            ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
            if self._num_workers > 0:
                warn_msg += (
                    "For multiprocessing data-loading, this could be caused by not properly configuring the "
                    "IterableDataset replica at each worker. Please see "
                    "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
                )
            warnings.warn(warn_msg)
        return data

    next = __next__  # Python 2 compatibility

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError(
            "{} cannot be pickled".format(self.__class__.__name__)
        )


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

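# ---------------------------------------------------------------------------
# Editor's sketch (not part of the uploaded file): the single-process iterator
# above is the whole data path in miniature -- draw an index batch from the
# sampler, fetch samples, collate, optionally pin. The same loop with
# illustrative stand-ins (not the module's real helpers):
#
#     import torch
#     from torch.utils.data.dataloader import default_collate
#
#     def simple_epoch(dataset, batch_size):
#         indices = list(range(len(dataset)))                # SequentialSampler
#         for start in range(0, len(indices), batch_size):
#             batch_idx = indices[start:start + batch_size]  # what _next_index() yields
#             samples = [dataset[i] for i in batch_idx]      # what the fetcher does
#             yield default_collate(samples)                 # what collate_fn does
#
#     toy = [(torch.randn(3), i % 2) for i in range(10)]
#     for xs, ys in simple_epoch(toy, batch_size=4):
#         print(xs.shape, ys.tolist())  # [4, 3], [4, 3], then the tail [2, 3]
# ---------------------------------------------------------------------------
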
| 455 |
+
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
| 456 |
+
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
|
| 457 |
+
|
| 458 |
+
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
|
| 459 |
+
#
|
| 460 |
+
# Preliminary:
|
| 461 |
+
#
|
| 462 |
+
# Our data model looks like this (queues are indicated with curly brackets):
|
| 463 |
+
#
|
| 464 |
+
# main process ||
|
| 465 |
+
# | ||
|
| 466 |
+
# {index_queue} ||
|
| 467 |
+
# | ||
|
| 468 |
+
# worker processes || DATA
|
| 469 |
+
# | ||
|
| 470 |
+
# {worker_result_queue} || FLOW
|
| 471 |
+
# | ||
|
| 472 |
+
# pin_memory_thread of main process || DIRECTION
|
| 473 |
+
# | ||
|
| 474 |
+
# {data_queue} ||
|
| 475 |
+
# | ||
|
| 476 |
+
# data output \/
|
| 477 |
+
#
|
| 478 |
+
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
|
| 479 |
+
# `pin_memory=False`.
|
| 480 |
+
#
|
| 481 |
+
#
|
| 482 |
+
# Terminating multiprocessing logic requires very careful design. In
|
| 483 |
+
# particular, we need to make sure that
|
| 484 |
+
#
|
| 485 |
+
# 1. The iterator gracefully exits the workers when its last reference is
|
| 486 |
+
# gone or it is depleted.
|
| 487 |
+
#
|
| 488 |
+
# In this case, the workers should be gracefully exited because the
|
| 489 |
+
# main process may still need to continue to run, and we want cleaning
|
| 490 |
+
# up code in the workers to be executed (e.g., releasing GPU memory).
|
| 491 |
+
# Naturally, we implement the shutdown logic in `__del__` of
|
| 492 |
+
# DataLoaderIterator.
|
| 493 |
+
#
|
| 494 |
+
# We delay the discussion on the logic in this case until later.
|
| 495 |
+
#
|
| 496 |
+
# 2. The iterator exits the workers when the loader process and/or worker
|
| 497 |
+
# processes exits normally or with error.
|
| 498 |
+
#
|
| 499 |
+
# We set all workers and `pin_memory_thread` to have `daemon=True`.
|
| 500 |
+
#
|
| 501 |
+
# You may ask, why can't we make the workers non-daemonic, and
|
| 502 |
+
# gracefully exit using the same logic as we have in `__del__` when the
|
| 503 |
+
# iterator gets deleted (see 1 above)?
|
| 504 |
+
#
|
| 505 |
+
# First of all, `__del__` is **not** guaranteed to be called when
|
| 506 |
+
# interpreter exits. Even if it is called, by the time it executes,
|
| 507 |
+
# many Python core library resources may alreay be freed, and even
|
| 508 |
+
# simple things like acquiring an internal lock of a queue may hang.
|
| 509 |
+
# Therefore, in this case, we actually need to prevent `__del__` from
|
| 510 |
+
# being executed, and rely on the automatic termination of daemonic
|
| 511 |
+
# children. Thus, we register an `atexit` hook that sets a global flag
|
| 512 |
+
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
|
| 513 |
+
# reverse order of registration, we are guaranteed that this flag is
|
| 514 |
+
# set before library resources we use are freed. (Hooks freeing those
|
| 515 |
+
# resources are registered at importing the Python core libraries at
|
| 516 |
+
# the top of this file.) So in `__del__`, we check if
|
| 517 |
+
# `_utils.python_exit_status` is set or `None` (freed), and perform
|
| 518 |
+
# no-op if so.
|
| 519 |
+
#
|
| 520 |
+
# Another problem with `__del__` is also related to the library cleanup
|
| 521 |
+
# calls. When a process ends, it shuts the all its daemonic children
|
| 522 |
+
# down with a SIGTERM (instead of joining them without a timeout).
|
| 523 |
+
# Simiarly for threads, but by a different mechanism. This fact,
|
| 524 |
+
# together with a few implementation details of multiprocessing, forces
|
| 525 |
+
# us to make workers daemonic. All of our problems arise when a
|
| 526 |
+
# DataLoader is used in a subprocess, and are caused by multiprocessing
|
| 527 |
+
# code which looks more or less like this:
|
| 528 |
+
#
|
| 529 |
+
# try:
|
| 530 |
+
# your_function_using_a_dataloader()
|
| 531 |
+
# finally:
|
| 532 |
+
# multiprocessing.util._exit_function()
|
| 533 |
+
#
|
| 534 |
+
# The joining/termination mentioned above happens inside
|
| 535 |
+
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
|
| 536 |
+
# throws, the stack trace stored in the exception will prevent the
|
| 537 |
+
# frame which uses `DataLoaderIter` to be freed. If the frame has any
|
| 538 |
+
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
|
| 539 |
+
# its `__del__`, which starts the shutdown procedure, will not be
|
| 540 |
+
# called. That, in turn, means that workers aren't notified. Attempting
|
| 541 |
+
# to join in `_exit_function` will then result in a hang.
|
| 542 |
+
#
|
| 543 |
+
# For context, `_exit_function` is also registered as an `atexit` call.
|
| 544 |
+
# So it is unclear to me (@ssnl) why this is needed in a finally block.
|
| 545 |
+
# The code dates back to 2008 and there is no comment on the original
|
| 546 |
+
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
|
| 547 |
+
# the finally block and the `atexit` registration) that explains this.
|
| 548 |
+
#
|
| 549 |
+
# Another choice is to just shutdown workers with logic in 1 above
|
| 550 |
+
# whenever we see an error in `next`. This isn't ideal because
|
| 551 |
+
# a. It prevents users from using try-catch to resume data loading.
|
| 552 |
+
# b. It doesn't prevent hanging if users have references to the
|
| 553 |
+
# iterator.
|
| 554 |
+
#
|
| 555 |
+
# 3. All processes exit if any of them die unexpectedly by fatal signals.
|
| 556 |
+
#
|
| 557 |
+
# As shown above, the workers are set as daemonic children of the main
|
| 558 |
+
# process. However, automatic cleaning-up of such child processes only
|
| 559 |
+
# happens if the parent process exits gracefully (e.g., not via fatal
|
| 560 |
+
# signals like SIGKILL). So we must ensure that each process will exit
|
| 561 |
+
# even the process that should send/receive data to/from it were
|
| 562 |
+
# killed, i.e.,
|
| 563 |
+
#
|
| 564 |
+
# a. A process won't hang when getting from a queue.
|
| 565 |
+
#
|
| 566 |
+
# Even with carefully designed data dependencies (i.e., a `put()`
|
| 567 |
+
# always corresponding to a `get()`), hanging on `get()` can still
|
| 568 |
+
# happen when data in queue is corrupted (e.g., due to
|
| 569 |
+
# `cancel_join_thread` or unexpected exit).
|
| 570 |
+
#
|
| 571 |
+
# For child exit, we set a timeout whenever we try to get data
|
| 572 |
+
# from `data_queue`, and check the workers' status on each timeout
|
| 573 |
+
# and error.
|
| 574 |
+
# See `_DataLoaderiter._get_batch()` and
|
| 575 |
+
# `_DataLoaderiter._try_get_data()` for details.
|
| 576 |
+
#
|
| 577 |
+
# Additionally, for child exit on non-Windows platforms, we also
|
| 578 |
+
# register a SIGCHLD handler (which is supported on Windows) on
|
| 579 |
+
# the main process, which checks if any of the workers fail in the
|
| 580 |
+
# (Python) handler. This is more efficient and faster in detecting
|
| 581 |
+
# worker failures, compared to only using the above mechanism.
|
| 582 |
+
# See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
|
| 583 |
+
#
|
| 584 |
+
# For `.get()` calls where the sender(s) is not the workers, we
|
| 585 |
+
# guard them with timeouts, and check the status of the sender
|
| 586 |
+
# when timeout happens:
|
| 587 |
+
# + in the workers, the `_utils.worker.ManagerWatchdog` class
|
| 588 |
+
# checks the status of the main process.
|
| 589 |
+
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
|
| 590 |
+
# check `pin_memory_thread` status periodically until `.get()`
|
| 591 |
+
# returns or see that `pin_memory_thread` died.
|
| 592 |
+
#
|
| 593 |
+
# b. A process won't hang when putting into a queue;
|
| 594 |
+
#
|
| 595 |
+
# We use `mp.Queue` which has a separate background thread to put
|
| 596 |
+
# objects from an unbounded buffer array. The background thread is
|
| 597 |
+
# daemonic and usually automatically joined when the process
|
| 598 |
+
# exits.
|
| 599 |
+
#
|
| 600 |
+
# However, in case that the receiver has ended abruptly while
|
| 601 |
+
# reading from the pipe, the join will hang forever. Therefore,
|
| 602 |
+
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
|
| 603 |
+
# and each `index_queue` (main process -> worker), we use
|
| 604 |
+
# `q.cancel_join_thread()` in sender process before any `q.put` to
|
| 605 |
+
# prevent this automatic join.
|
| 606 |
+
#
|
| 607 |
+
# Moreover, having all queues called `cancel_join_thread` makes
|
| 608 |
+
# implementing graceful shutdown logic in `__del__` much easier.
|
| 609 |
+
# It won't need to get from any queue, which would also need to be
|
| 610 |
+
# guarded by periodic status checks.
|
| 611 |
+
#
|
| 612 |
+
# Nonetheless, `cancel_join_thread` must only be called when the
|
| 613 |
+
# queue is **not** going to be read from or write into by another
|
| 614 |
+
# process, because it may hold onto a lock or leave corrupted data
|
| 615 |
+
# in the queue, leading other readers/writers to hang.
|
| 616 |
+
#
|
| 617 |
+
# `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
|
| 618 |
+
# a blocking `put` if the queue is full. So there is no above
|
| 619 |
+
# problem, but we do need to wrap the `put` in a loop that breaks
|
| 620 |
+
# not only upon success, but also when the main process stops
|
| 621 |
+
# reading, i.e., is shutting down.
|
| 622 |
+
#
|
| 623 |
+
#
|
| 624 |
+
# Now let's get back to 1:
|
| 625 |
+
# how we gracefully exit the workers when the last reference to the
|
| 626 |
+
# iterator is gone.
|
| 627 |
+
#
|
| 628 |
+
# To achieve this, we implement the following logic along with the design
|
| 629 |
+
# choices mentioned above:
|
| 630 |
+
#
|
| 631 |
+
# `workers_done_event`:
|
| 632 |
+
# A `multiprocessing.Event` shared among the main process and all worker
|
| 633 |
+
# processes. This is used to signal the workers that the iterator is
|
| 634 |
+
# shutting down. After it is set, they will not send processed data to
|
| 635 |
+
# queues anymore, and only wait for the final `None` before exiting.
|
| 636 |
+
# `done_event` isn't strictly needed. I.e., we can just check for `None`
|
| 637 |
+
# from the input queue, but it allows us to skip wasting resources
|
| 638 |
+
# processing data if we are already shutting down.
|
| 639 |
+
#
|
| 640 |
+
# `pin_memory_thread_done_event`:
|
| 641 |
+
# A `threading.Event` for a similar purpose to that of
|
| 642 |
+
# `workers_done_event`, but is for the `pin_memory_thread`. The reason
|
| 643 |
+
# that separate events are needed is that `pin_memory_thread` reads from
|
| 644 |
+
# the output queue of the workers. But the workers, upon seeing that
|
| 645 |
+
# `workers_done_event` is set, only wants to see the final `None`, and is
|
| 646 |
+
# not required to flush all data in the output queue (e.g., it may call
|
| 647 |
+
# `cancel_join_thread` on that queue if its `IterableDataset` iterator
|
| 648 |
+
# happens to exhaust coincidentally, which is out of the control of the
|
| 649 |
+
# main process). Thus, since we will exit `pin_memory_thread` before the
|
| 650 |
+
# workers (see below), two separete events are used.
|
| 651 |
+
#
|
| 652 |
+
# NOTE: In short, the protocol is that the main process will set these
|
| 653 |
+
# `done_event`s and then the corresponding processes/threads a `None`,
|
| 654 |
+
# and that they may exit at any time after receiving the `None`.
|
| 655 |
+
#
|
| 656 |
+
# NOTE: Using `None` as the final signal is valid, since normal data will
|
| 657 |
+
# always be a 2-tuple with the 1st element being the index of the data
|
| 658 |
+
# transferred (different from dataset index/key), and the 2nd being
|
| 659 |
+
# either the dataset key or the data sample (depending on which part
|
| 660 |
+
# of the data model the queue is at).
|
| 661 |
+
#
|
| 662 |
+
# [ worker processes ]
|
| 663 |
+
# While loader process is alive:
|
| 664 |
+
# Get from `index_queue`.
|
| 665 |
+
# If get anything else,
|
| 666 |
+
# Check `workers_done_event`.
|
| 667 |
+
# If set, continue to next iteration
|
| 668 |
+
# i.e., keep getting until see the `None`, then exit.
|
| 669 |
+
# Otherwise, process data:
|
| 670 |
+
# If is fetching from an `IterableDataset` and the iterator
|
| 671 |
+
# is exhausted, send an `_IterableDatasetStopIteration`
|
| 672 |
+
# object to signal iteration end. The main process, upon
|
| 673 |
+
# receiving such an object, will send `None` to this
|
| 674 |
+
# worker and not use the corresponding `index_queue`
|
| 675 |
+
# anymore.
|
| 676 |
+
# If timed out,
|
| 677 |
+
# No matter `workers_done_event` is set (still need to see `None`)
|
| 678 |
+
# or not, must continue to next iteration.
|
| 679 |
+
# (outside loop)
|
| 680 |
+
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
|
| 681 |
+
# `data_queue.cancel_join_thread()`. (Everything is ending here:
|
| 682 |
+
# main process won't read from it;
|
| 683 |
+
# other workers will also call
|
| 684 |
+
# `cancel_join_thread`.)
|
| 685 |
+
#
|
| 686 |
+
# [ pin_memory_thread ]
|
| 687 |
+
# # No need to check main thread. If this thread is alive, the main loader
|
| 688 |
+
# # thread must be alive, because this thread is set as daemonic.
|
| 689 |
+
# While `pin_memory_thread_done_event` is not set:
|
| 690 |
+
# Get from `index_queue`.
|
| 691 |
+
# If timed out, continue to get in the next iteration.
|
| 692 |
+
# Otherwise, process data.
|
| 693 |
+
# While `pin_memory_thread_done_event` is not set:
|
| 694 |
+
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
|
| 695 |
+
# If timed out, continue to put in the next iteration.
|
| 696 |
+
# Otherwise, break, i.e., continuing to the out loop.
|
| 697 |
+
#
|
| 698 |
+
# NOTE: we don't check the status of the main thread because
|
| 699 |
+
# 1. if the process is killed by fatal signal, `pin_memory_thread`
|
| 700 |
+
# ends.
|
| 701 |
+
# 2. in other cases, either the cleaning-up in __del__ or the
|
| 702 |
+
# automatic exit of daemonic thread will take care of it.
|
| 703 |
+
# This won't busy-wait either because `.get(timeout)` does not
|
| 704 |
+
# busy-wait.
|
| 705 |
+
#
|
| 706 |
+
# [ main process ]
|
| 707 |
+
# In the DataLoader Iter's `__del__`
|
| 708 |
+
# b. Exit `pin_memory_thread`
|
| 709 |
+
# i. Set `pin_memory_thread_done_event`.
|
| 710 |
+
# ii Put `None` in `worker_result_queue`.
|
| 711 |
+
# iii. Join the `pin_memory_thread`.
|
| 712 |
+
# iv. `worker_result_queue.cancel_join_thread()`.
|
| 713 |
+
#
|
| 714 |
+
# c. Exit the workers.
|
| 715 |
+
# i. Set `workers_done_event`.
|
| 716 |
+
# ii. Put `None` in each worker's `index_queue`.
|
| 717 |
+
# iii. Join the workers.
|
| 718 |
+
# iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
|
| 719 |
+
#
|
| 720 |
+
# NOTE: (c) is better placed after (b) because it may leave corrupted
|
| 721 |
+
# data in `worker_result_queue`, which `pin_memory_thread`
|
| 722 |
+
# reads from, in which case the `pin_memory_thread` can only
|
| 723 |
+
# happen at timeing out, which is slow. Nonetheless, same thing
|
| 724 |
+
# happens if a worker is killed by signal at unfortunate times,
|
| 725 |
+
# but in other cases, we are better off having a non-corrupted
|
| 726 |
+
# `worker_result_queue` for `pin_memory_thread`.
|
| 727 |
+
#
|
| 728 |
+
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
|
| 729 |
+
# can be omitted
|
| 730 |
+
#
|
| 731 |
+
# NB: `done_event`s isn't strictly needed. E.g., we can just check for
|
| 732 |
+
# `None` from `index_queue`, but it allows us to skip wasting resources
|
| 733 |
+
# processing indices already in `index_queue` if we are already shutting
|
| 734 |
+
# down.
|
| 735 |
+
|
| 736 |
+
def __init__(self, loader):
|
| 737 |
+
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
|
| 738 |
+
|
| 739 |
+
assert self._num_workers > 0
|
| 740 |
+
|
| 741 |
+
if loader.multiprocessing_context is None:
|
| 742 |
+
multiprocessing_context = multiprocessing
|
| 743 |
+
else:
|
| 744 |
+
multiprocessing_context = loader.multiprocessing_context
|
| 745 |
+
|
| 746 |
+
self._worker_init_fn = loader.worker_init_fn
|
| 747 |
+
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
|
| 748 |
+
self._worker_result_queue = multiprocessing_context.Queue()
|
| 749 |
+
self._worker_pids_set = False
|
| 750 |
+
self._shutdown = False
|
| 751 |
+
self._send_idx = 0 # idx of the next task to be sent to workers
|
| 752 |
+
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
|
| 753 |
+
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
|
| 754 |
+
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
|
| 755 |
+
# \ (worker_id, data) if data is already fetched (out-of-order)
|
| 756 |
+
self._task_info = {}
|
| 757 |
+
self._tasks_outstanding = (
|
| 758 |
+
0 # always equal to count(v for v in task_info.values() if len(v) == 1)
|
| 759 |
+
)
|
| 760 |
+
self._workers_done_event = multiprocessing_context.Event()
|
| 761 |
+
|
| 762 |
+
self._index_queues = []
|
| 763 |
+
self._workers = []
|
| 764 |
+
# A list of booleans representing whether each worker still has work to
|
| 765 |
+
# do, i.e., not having exhausted its iterable dataset object. It always
|
| 766 |
+
# contains all `True`s if not using an iterable-style dataset
|
| 767 |
+
# (i.e., if kind != Iterable).
|
| 768 |
+
self._workers_status = []
|
| 769 |
+
for i in range(self._num_workers):
|
| 770 |
+
index_queue = multiprocessing_context.Queue()
|
| 771 |
+
# index_queue.cancel_join_thread()
|
| 772 |
+
w = multiprocessing_context.Process(
|
| 773 |
+
target=worker_loop,
|
| 774 |
+
args=(
|
| 775 |
+
self._dataset_kind,
|
| 776 |
+
self._dataset,
|
| 777 |
+
index_queue,
|
| 778 |
+
self._worker_result_queue,
|
| 779 |
+
self._workers_done_event,
|
| 780 |
+
self._auto_collation,
|
| 781 |
+
self._collate_fn,
|
| 782 |
+
self._drop_last,
|
| 783 |
+
self._base_seed + i,
|
| 784 |
+
self._worker_init_fn,
|
| 785 |
+
i,
|
| 786 |
+
self._num_workers,
|
| 787 |
+
),
|
| 788 |
+
)
|
| 789 |
+
w.daemon = True
|
| 790 |
+
# NB: Process.start() actually take some time as it needs to
|
| 791 |
+
# start a process and pass the arguments over via a pipe.
|
| 792 |
+
# Therefore, we only add a worker to self._workers list after
|
| 793 |
+
# it started, so that we do not call .join() if program dies
|
| 794 |
+
# before it starts, and __del__ tries to join but will get:
|
| 795 |
+
# AssertionError: can only join a started process.
|
| 796 |
+
w.start()
|
| 797 |
+
self._index_queues.append(index_queue)
|
| 798 |
+
self._workers.append(w)
|
| 799 |
+
self._workers_status.append(True)
|
| 800 |
+
|
| 801 |
+
if self._pin_memory:
|
| 802 |
+
self._pin_memory_thread_done_event = threading.Event()
|
| 803 |
+
self._data_queue = queue()
|
| 804 |
+
pin_memory_thread = threading.Thread(
|
| 805 |
+
target=_utils.pin_memory._pin_memory_loop,
|
| 806 |
+
args=(
|
| 807 |
+
self._worker_result_queue,
|
| 808 |
+
self._data_queue,
|
| 809 |
+
torch.cuda.current_device(),
|
| 810 |
+
self._pin_memory_thread_done_event,
|
| 811 |
+
),
|
| 812 |
+
)
|
| 813 |
+
pin_memory_thread.daemon = True
|
| 814 |
+
pin_memory_thread.start()
|
| 815 |
+
# Similar to workers (see comment above), we only register
|
| 816 |
+
# pin_memory_thread once it is started.
|
| 817 |
+
self._pin_memory_thread = pin_memory_thread
|
| 818 |
+
else:
|
| 819 |
+
self._data_queue = self._worker_result_queue
|
| 820 |
+
|
| 821 |
+
_utils.signal_handling._set_worker_pids(
|
| 822 |
+
id(self), tuple(w.pid for w in self._workers)
|
| 823 |
+
)
|
| 824 |
+
_utils.signal_handling._set_SIGCHLD_handler()
|
| 825 |
+
self._worker_pids_set = True
|
| 826 |
+
|
| 827 |
+
# prime the prefetch loop
|
| 828 |
+
for _ in range(2 * self._num_workers):
|
| 829 |
+
self._try_put_index()
|
| 830 |
+
|
| 831 |
+
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
|
| 832 |
+
# Tries to fetch data from `self._data_queue` once for a given timeout.
|
| 833 |
+
# This can also be used as inner loop of fetching without timeout, with
|
| 834 |
+
# the sender status as the loop condition.
|
| 835 |
+
#
|
| 836 |
+
# This raises a `RuntimeError` if any worker died expectedly. This error
|
| 837 |
+
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
|
| 838 |
+
# (only for non-Windows platforms), or the manual check below on errors
|
| 839 |
+
# and timeouts.
|
| 840 |
+
#
|
| 841 |
+
# Returns a 2-tuple:
|
| 842 |
+
# (bool: whether successfully get data, any: data if successful else None)
|
| 843 |
+
try:
|
| 844 |
+
data = self._data_queue.get(timeout=timeout)
|
| 845 |
+
return (True, data)
|
| 846 |
+
except Exception as e:
|
| 847 |
+
# At timeout and error, we manually check whether any worker has
|
| 848 |
+
# failed. Note that this is the only mechanism for Windows to detect
|
| 849 |
+
# worker failures.
|
| 850 |
+
failed_workers = []
|
| 851 |
+
for worker_id, w in enumerate(self._workers):
|
| 852 |
+
if self._workers_status[worker_id] and not w.is_alive():
|
| 853 |
+
failed_workers.append(w)
|
| 854 |
+
self._shutdown_worker(worker_id)
|
| 855 |
+
if len(failed_workers) > 0:
|
| 856 |
+
pids_str = ", ".join(str(w.pid) for w in failed_workers)
|
| 857 |
+
raise RuntimeError(
|
| 858 |
+
"DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
|
| 859 |
+
)
|
| 860 |
+
if isinstance(e, queue.Empty):
|
| 861 |
+
return (False, None)
|
| 862 |
+
raise
|
| 863 |
+
|
| 864 |
+
def _get_data(self):
|
| 865 |
+
# Fetches data from `self._data_queue`.
|
| 866 |
+
#
|
| 867 |
+
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
|
| 868 |
+
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
|
| 869 |
+
# in a loop. This is the only mechanism to detect worker failures for
|
| 870 |
+
# Windows. For other platforms, a SIGCHLD handler is also used for
|
| 871 |
+
# worker failure detection.
|
| 872 |
+
#
|
| 873 |
+
# If `pin_memory=True`, we also need check if `pin_memory_thread` had
|
| 874 |
+
# died at timeouts.
|
| 875 |
+
if self._timeout > 0:
|
| 876 |
+
success, data = self._try_get_data(self._timeout)
|
| 877 |
+
if success:
|
| 878 |
+
return data
|
| 879 |
+
else:
|
| 880 |
+
raise RuntimeError(
|
| 881 |
+
"DataLoader timed out after {} seconds".format(self._timeout)
|
| 882 |
+
)
|
| 883 |
+
elif self._pin_memory:
|
| 884 |
+
while self._pin_memory_thread.is_alive():
|
| 885 |
+
success, data = self._try_get_data()
|
| 886 |
+
if success:
|
| 887 |
+
return data
|
| 888 |
+
else:
|
| 889 |
+
# while condition is false, i.e., pin_memory_thread died.
|
| 890 |
+
raise RuntimeError("Pin memory thread exited unexpectedly")
|
| 891 |
+
# In this case, `self._data_queue` is a `queue.Queue`,. But we don't
|
| 892 |
+
# need to call `.task_done()` because we don't use `.join()`.
|
| 893 |
+
else:
|
| 894 |
+
while True:
|
| 895 |
+
success, data = self._try_get_data()
|
| 896 |
+
if success:
|
| 897 |
+
return data
|
| 898 |
+
|
| 899 |
+
    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    self._shutdown_worker(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    def _try_put_index(self):
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _shutdown_worker(self, worker_id):
        # Mark a worker as having finished its work and dead, e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id]

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from the C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise an error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.
        self._workers_status[worker_id] = False

    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do nothing.
            return
        # Normal exit when the last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue`, which `pin_memory_thread`
                # reads from.
                if hasattr(self, "_pin_memory_thread"):
                    # Use hasattr in case an error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`.
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get the number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    if self._workers_status[worker_id]:
                        self._shutdown_worker(worker_id)
                for w in self._workers:
                    w.join()
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with the SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                # error detection mechanism here in this function, as it
                # doesn't provide a SIGCHLD handler.
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False

    def __del__(self):
        self._shutdown_workers()
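Aside: the subtle part of `_next_data`/`_try_put_index` above is in-order delivery despite out-of-order completion. Workers put `(idx, data)` pairs on the result queue in whatever order they finish; `_task_info`, `_rcvd_idx`, and `_send_idx` buffer early arrivals so batches still come out in send order. A minimal, torch-free sketch of that bookkeeping (all names here are illustrative, not part of the proard API):

def reorder(results, num_sent):
    """Yield `data` in send order; `results` is (idx, data) pairs in arrival order."""
    task_info = {i: () for i in range(num_sent)}  # idx -> buffered result, if any
    rcvd_idx = 0
    for idx, data in results:
        if idx != rcvd_idx:
            task_info[idx] = (data,)          # buffer an out-of-order result
        else:
            del task_info[idx]
            yield data                        # the batch we were waiting for
            rcvd_idx += 1
            while task_info.get(rcvd_idx):    # flush buffered successors
                yield task_info.pop(rcvd_idx)[0]
                rcvd_idx += 1

print(list(reorder([(1, "b"), (0, "a"), (2, "c")], 3)))  # -> ['a', 'b', 'c']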
proard/utils/my_dataloader/my_data_worker.py
ADDED
@@ -0,0 +1,242 @@
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import random
import os
from collections import namedtuple

# `torch._six` was removed from recent PyTorch releases. The stdlib `queue`
# module provides the `queue.Empty` exception that multiprocessing queues
# raise on timeout; importing `Queue` as `queue` would break `except queue.Empty`.
# from torch._six import queue
import queue

from torch._utils import ExceptionWrapper
from torch.utils.data._utils import (
    signal_handling,
    MP_STATUS_CHECK_INTERVAL,
    IS_WINDOWS,
)

from .my_random_resize_crop import MyRandomResizedCrop

__all__ = ["worker_loop"]

if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through the OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(
                SYNCHRONIZE, 0, self.manager_pid
            )

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = (
                    self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
                )
            return not self.manager_dead


else:

    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead


_worker_info = None


class WorkerInfo(object):
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(
                "Cannot assign attributes to {} objects".format(self.__class__.__name__)
            )
        return super(WorkerInfo, self).__setattr__(key, val)


def get_worker_info():
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by the main process RNG and the worker id. See
      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
       When used in a :attr:`worker_init_fn` passed over to
       :class:`~torch.utils.data.DataLoader`, this method can be useful to
       set up each worker process differently, for instance, using ``worker_id``
       to configure the ``dataset`` object to only read a specific fraction of a
       sharded dataset, or use ``seed`` to seed other libraries used in dataset
       code (e.g., NumPy).
    """
    return _worker_info


r"""Dummy class used to signal the end of an IterableDataset"""
|
| 123 |
+
_IterableDatasetStopIteration = namedtuple(
|
| 124 |
+
"_IterableDatasetStopIteration", ["worker_id"]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def worker_loop(
    dataset_kind,
    dataset,
    index_queue,
    data_queue,
    done_event,
    auto_collation,
    collate_fn,
    drop_last,
    seed,
    init_fn,
    worker_id,
    num_workers,
):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(
            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
        )

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id)
            )

        # When using Iterable mode, some workers can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing steps and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set, but this worker hasn't received the final
                # signal (`None`) yet, so it keeps looping and skips the
                # processing steps.
                continue
            idx, index = r
            # Added: sample the image size for this batch index, synchronized
            # across workers (see `MyRandomResizedCrop`).
            MyRandomResizedCrop.sample_image_size(idx)
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if (
                        isinstance(e, StopIteration)
                        and dataset_kind == _DatasetKind.Iterable
                    ):
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id)
                        )
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
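Why the `MyRandomResizedCrop.sample_image_size(idx)` hook matters: it seeds the resolution choice from the batch index rather than from each worker's private RNG, so every sample of one batch is cropped to the same randomly chosen size no matter which worker fetches it. A rough sketch of that idea (the candidate list, the epoch/seed mixing, and the module-level state are assumptions for illustration, not the actual MyRandomResizedCrop internals):

import random

CANDIDATE_SIZES = [128, 160, 192, 224]   # assumed elastic-resolution candidates
ACTIVE_SIZE = CANDIDATE_SIZES[-1]        # size the transform would crop to next

def sample_image_size(batch_idx, epoch=0, seed=0):
    """Pick a size deterministically from the batch index, identically in every worker."""
    global ACTIVE_SIZE
    rng = random.Random(seed + epoch * 100000 + batch_idx)
    ACTIVE_SIZE = rng.choice(CANDIDATE_SIZES)

for idx in range(3):
    sample_image_size(idx)
    print(idx, ACTIVE_SIZE)  # same sequence in every process for the same seed/epoch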