smi08 commited on
Commit
188f311
·
verified ·
1 Parent(s): 41fccda

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. proard/__init__.py +0 -0
  2. proard/classification/__init__.py +0 -0
  3. proard/classification/data_providers/__init__.py +3 -0
  4. proard/classification/data_providers/base_provider.py +58 -0
  5. proard/classification/data_providers/cifar10.py +264 -0
  6. proard/classification/data_providers/cifar100.py +264 -0
  7. proard/classification/data_providers/imagenet.py +310 -0
  8. proard/classification/elastic_nn/__init__.py +0 -0
  9. proard/classification/elastic_nn/modules/__init__.py +6 -0
  10. proard/classification/elastic_nn/modules/dynamic_layers.py +841 -0
  11. proard/classification/elastic_nn/modules/dynamic_op.py +401 -0
  12. proard/classification/elastic_nn/networks/__init__.py +7 -0
  13. proard/classification/elastic_nn/networks/dyn_mbv3.py +780 -0
  14. proard/classification/elastic_nn/networks/dyn_proxyless.py +774 -0
  15. proard/classification/elastic_nn/networks/dyn_resnets.py +678 -0
  16. proard/classification/elastic_nn/training/__init__.py +6 -0
  17. proard/classification/elastic_nn/training/progressive_shrinking.py +463 -0
  18. proard/classification/elastic_nn/utils.py +83 -0
  19. proard/classification/networks/__init__.py +25 -0
  20. proard/classification/networks/mobilenet_v3.py +559 -0
  21. proard/classification/networks/proxyless_nets.py +490 -0
  22. proard/classification/networks/resnet_trades.py +115 -0
  23. proard/classification/networks/resnets.py +490 -0
  24. proard/classification/networks/wide_resnet.py +93 -0
  25. proard/classification/run_manager/__init__.py +7 -0
  26. proard/classification/run_manager/distributed_run_manager.py +505 -0
  27. proard/classification/run_manager/run_config.py +414 -0
  28. proard/classification/run_manager/run_manager.py +484 -0
  29. proard/model_zoo.py +162 -0
  30. proard/nas/__init__.py +0 -0
  31. proard/nas/accuracy_predictor/__init__.py +11 -0
  32. proard/nas/accuracy_predictor/acc_dataset.py +213 -0
  33. proard/nas/accuracy_predictor/acc_predictor.py +68 -0
  34. proard/nas/accuracy_predictor/acc_rob_dataset.py +219 -0
  35. proard/nas/accuracy_predictor/acc_rob_predictor.py +77 -0
  36. proard/nas/accuracy_predictor/arch_encoder.py +372 -0
  37. proard/nas/accuracy_predictor/rob_dataset.py +211 -0
  38. proard/nas/accuracy_predictor/rob_predictor.py +66 -0
  39. proard/nas/efficiency_predictor/__init__.py +78 -0
  40. proard/nas/efficiency_predictor/latency_lookup_table.py +567 -0
  41. proard/nas/search_algorithm/__init__.py +6 -0
  42. proard/nas/search_algorithm/evolution.py +143 -0
  43. proard/nas/search_algorithm/multi_evolution.py +143 -0
  44. proard/utils/__init__.py +10 -0
  45. proard/utils/common_tools.py +307 -0
  46. proard/utils/flops_counter.py +97 -0
  47. proard/utils/layers.py +819 -0
  48. proard/utils/my_dataloader/__init__.py +2 -0
  49. proard/utils/my_dataloader/my_data_loader.py +1050 -0
  50. proard/utils/my_dataloader/my_data_worker.py +242 -0
proard/__init__.py ADDED
File without changes
proard/classification/__init__.py ADDED
File without changes
proard/classification/data_providers/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .cifar10 import *
2
+ from .cifar100 import *
3
+ from .imagenet import *
proard/classification/data_providers/base_provider.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ __all__ = ["DataProvider"]
9
+
10
+
11
+ class DataProvider:
12
+ SUB_SEED = 937162211 # random seed for sampling subset
13
+ VALID_SEED = 2147483647 # random seed for the validation set
14
+
15
+ @staticmethod
16
+ def name():
17
+ """Return name of the dataset"""
18
+ raise NotImplementedError
19
+
20
+ @property
21
+ def data_shape(self):
22
+ """Return shape as python list of one data entry"""
23
+ raise NotImplementedError
24
+
25
+ @property
26
+ def n_classes(self):
27
+ """Return `int` of num classes"""
28
+ raise NotImplementedError
29
+
30
+ @property
31
+ def save_path(self):
32
+ """local path to save the data"""
33
+ raise NotImplementedError
34
+
35
+ @property
36
+ def data_url(self):
37
+ """link to download the data"""
38
+ raise NotImplementedError
39
+
40
+ @staticmethod
41
+ def random_sample_valid_set(train_size, valid_size):
42
+ assert train_size > valid_size
43
+
44
+ g = torch.Generator()
45
+ g.manual_seed(
46
+ DataProvider.VALID_SEED
47
+ ) # set random seed before sampling validation set
48
+ rand_indexes = torch.randperm(train_size, generator=g).tolist()
49
+
50
+ valid_indexes = rand_indexes[:valid_size]
51
+ train_indexes = rand_indexes[valid_size:]
52
+ return train_indexes, valid_indexes
53
+
54
+ @staticmethod
55
+ def labels_to_one_hot(n_classes, labels):
56
+ new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
57
+ new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
58
+ return new_labels
proard/classification/data_providers/cifar10.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import os
3
+ import math
4
+ import numpy as np
5
+ import torch.utils.data
6
+ import torchvision.transforms as transforms
7
+ import torchvision.datasets as datasets
8
+ from .base_provider import DataProvider
9
+ from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
10
+
11
+ __all__ = ["Cifar10DataProvider"]
12
+
13
+ class Cifar10DataProvider(DataProvider):
14
+ DEFAULT_PATH = "./dataset/cifar10"
15
+ def __init__(
16
+ self,
17
+ save_path=None,
18
+ train_batch_size=256,
19
+ test_batch_size=512,
20
+ valid_size=None,
21
+ resize_scale=0.08,
22
+ distort_color=None,
23
+ n_worker=32,
24
+ image_size=32,
25
+ num_replicas=None,
26
+ rank=None,
27
+ ):
28
+
29
+ warnings.filterwarnings("ignore")
30
+ self._save_path = save_path
31
+
32
+ self.image_size = image_size # int or list of int
33
+
34
+
35
+ self._valid_transform_dict = {}
36
+ if not isinstance(self.image_size, int):
37
+ from proard.utils.my_dataloader.my_data_loader import MyDataLoader
38
+
39
+ assert isinstance(self.image_size, list)
40
+ self.image_size.sort() # e.g., 160 -> 224
41
+ MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
42
+ MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
43
+
44
+ for img_size in self.image_size:
45
+ self._valid_transform_dict[img_size] = self.build_valid_transform(
46
+ img_size
47
+ )
48
+ self.active_img_size = max(self.image_size) # active resolution for test
49
+ valid_transforms = self._valid_transform_dict[self.active_img_size]
50
+ train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
51
+ else:
52
+ self.active_img_size = self.image_size
53
+ valid_transforms = self.build_valid_transform()
54
+ train_loader_class = torch.utils.data.DataLoader
55
+
56
+ train_dataset = self.train_dataset(self.build_train_transform())
57
+
58
+ if valid_size is not None:
59
+ if not isinstance(valid_size, int):
60
+ assert isinstance(valid_size, float) and 0 < valid_size < 1
61
+ valid_size = int(len(train_dataset) * valid_size)
62
+
63
+ valid_dataset = self.train_dataset(valid_transforms)
64
+ train_indexes, valid_indexes = self.random_sample_valid_set(
65
+ len(train_dataset), valid_size
66
+ )
67
+
68
+ if num_replicas is not None:
69
+ train_sampler = MyDistributedSampler(
70
+ train_dataset, num_replicas, rank, True, np.array(train_indexes)
71
+ )
72
+ valid_sampler = MyDistributedSampler(
73
+ valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
74
+ )
75
+ else:
76
+ train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
77
+ train_indexes
78
+ )
79
+ valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
80
+ valid_indexes
81
+ )
82
+
83
+ self.train = train_loader_class(
84
+ train_dataset,
85
+ batch_size=train_batch_size,
86
+ sampler=train_sampler,
87
+ num_workers=n_worker,
88
+ pin_memory=False,
89
+ )
90
+ self.valid = torch.utils.data.DataLoader(
91
+ valid_dataset,
92
+ batch_size=test_batch_size,
93
+ sampler=valid_sampler,
94
+ num_workers=n_worker,
95
+ pin_memory=False,
96
+ )
97
+ else:
98
+ if num_replicas is not None:
99
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
100
+ train_dataset, num_replicas, rank
101
+ )
102
+ self.train = train_loader_class(
103
+ train_dataset,
104
+ batch_size=train_batch_size,
105
+ sampler=train_sampler,
106
+ num_workers=n_worker,
107
+ pin_memory=True,
108
+ )
109
+ else:
110
+ self.train = train_loader_class(
111
+ train_dataset,
112
+ batch_size=train_batch_size,
113
+ shuffle=True,
114
+ num_workers=n_worker,
115
+ pin_memory=False,
116
+ )
117
+ self.valid = None
118
+
119
+ test_dataset = self.test_dataset(valid_transforms)
120
+ if num_replicas is not None:
121
+ test_sampler = torch.utils.data.distributed.DistributedSampler(
122
+ test_dataset, num_replicas, rank
123
+ )
124
+ self.test = torch.utils.data.DataLoader(
125
+ test_dataset,
126
+ batch_size=test_batch_size,
127
+ sampler=test_sampler,
128
+ num_workers=n_worker,
129
+ pin_memory=False,
130
+ )
131
+ else:
132
+ self.test = torch.utils.data.DataLoader(
133
+ test_dataset,
134
+ batch_size=test_batch_size,
135
+ shuffle=True,
136
+ num_workers=n_worker,
137
+ pin_memory=False,
138
+ )
139
+
140
+ if self.valid is None:
141
+ self.valid = self.test
142
+
143
+ @staticmethod
144
+ def name():
145
+ return "cifar10"
146
+
147
+ @property
148
+ def data_shape(self):
149
+ return 3, self.active_img_size, self.active_img_size # C, H, W
150
+
151
+ @property
152
+ def n_classes(self):
153
+ return 10
154
+
155
+ @property
156
+ def save_path(self):
157
+ if self._save_path is None:
158
+ self._save_path = self.DEFAULT_PATH
159
+ if not os.path.exists(self._save_path):
160
+ self._save_path = os.path.expanduser("~/dataset/cifar10")
161
+ return self._save_path
162
+
163
+ @property
164
+ def data_url(self):
165
+ raise ValueError("unable to download %s" % self.name())
166
+
167
+ def train_dataset(self, _transforms):
168
+ return datasets.CIFAR10(self.train_path, train=True, transform=_transforms,download=True)
169
+
170
+ def test_dataset(self, _transforms):
171
+ return datasets.CIFAR10(self.valid_path, train=False, transform=_transforms,download=True)
172
+ @property
173
+ def train_path(self):
174
+ return os.path.join(self.save_path, "train")
175
+
176
+ @property
177
+ def valid_path(self):
178
+ return os.path.join(self.save_path, "val")
179
+
180
+ @property
181
+ def normalize(self):
182
+ return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
183
+
184
+ def build_train_transform(self, image_size=None, print_log=True):
185
+ if image_size is None:
186
+ image_size = self.image_size
187
+
188
+ # random_resize_crop -> random_horizontal_flip
189
+ train_transforms = [
190
+ transforms.RandomCrop(32,padding=4),
191
+ transforms.RandomHorizontalFlip(),
192
+ # AutoAugment(),
193
+ ]
194
+
195
+ train_transforms += [
196
+ transforms.ToTensor(),
197
+ # self.normalize,
198
+ ]
199
+
200
+ train_transforms = transforms.Compose(train_transforms)
201
+ return train_transforms
202
+
203
+ def build_valid_transform(self, image_size=None):
204
+ if image_size is None:
205
+ image_size = self.active_img_size
206
+ return transforms.Compose([
207
+ transforms.ToTensor(),
208
+ # self.normalize,
209
+ ])
210
+
211
+ def assign_active_img_size(self, new_img_size):
212
+ self.active_img_size = new_img_size
213
+ if self.active_img_size not in self._valid_transform_dict:
214
+ self._valid_transform_dict[
215
+ self.active_img_size
216
+ ] = self.build_valid_transform()
217
+ # change the transform of the valid and test set
218
+ self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
219
+ self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
220
+
221
+ def build_sub_train_loader(
222
+ self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
223
+ ):
224
+ # used for resetting BN running statistics
225
+ if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
226
+ if num_worker is None:
227
+ num_worker = self.train.num_workers
228
+
229
+ n_samples = len(self.train.dataset)
230
+ g = torch.Generator()
231
+ g.manual_seed(DataProvider.SUB_SEED)
232
+ rand_indexes = torch.randperm(n_samples, generator=g).tolist()
233
+
234
+ new_train_dataset = self.train_dataset(
235
+ self.build_train_transform(
236
+ image_size=self.active_img_size, print_log=False
237
+ )
238
+ )
239
+ chosen_indexes = rand_indexes[:n_images]
240
+ if num_replicas is not None:
241
+ sub_sampler = MyDistributedSampler(
242
+ new_train_dataset,
243
+ num_replicas,
244
+ rank,
245
+ True,
246
+ np.array(chosen_indexes),
247
+ )
248
+ else:
249
+ sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
250
+ chosen_indexes
251
+ )
252
+ sub_data_loader = torch.utils.data.DataLoader(
253
+ new_train_dataset,
254
+ batch_size=batch_size,
255
+ sampler=sub_sampler,
256
+ num_workers=num_worker,
257
+ pin_memory=False,
258
+ )
259
+ self.__dict__["sub_train_%d" % self.active_img_size] = []
260
+ for images, labels in sub_data_loader:
261
+ self.__dict__["sub_train_%d" % self.active_img_size].append(
262
+ (images, labels)
263
+ )
264
+ return self.__dict__["sub_train_%d" % self.active_img_size]
proard/classification/data_providers/cifar100.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import os
3
+ import math
4
+ import numpy as np
5
+ import torch.utils.data
6
+ import torchvision.transforms as transforms
7
+ import torchvision.datasets as datasets
8
+ from .base_provider import DataProvider
9
+ from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
10
+
11
+ __all__ = ["Cifar100DataProvider"]
12
+
13
+ class Cifar100DataProvider(DataProvider):
14
+ DEFAULT_PATH = "./dataset/cifar100"
15
+ def __init__(
16
+ self,
17
+ save_path=None,
18
+ train_batch_size=256,
19
+ test_batch_size=512,
20
+ resize_scale=0.08,
21
+ distort_color=None,
22
+ valid_size=None,
23
+ n_worker=32,
24
+ image_size=32,
25
+ num_replicas=None,
26
+ rank=None,
27
+ ):
28
+
29
+ warnings.filterwarnings("ignore")
30
+ self._save_path = save_path
31
+
32
+ self.image_size = image_size # int or list of int
33
+
34
+
35
+ self._valid_transform_dict = {}
36
+ if not isinstance(self.image_size, int):
37
+ from proard.utils.my_dataloader.my_data_loader import MyDataLoader
38
+
39
+ assert isinstance(self.image_size, list)
40
+ self.image_size.sort() # e.g., 160 -> 224
41
+ MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
42
+ MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
43
+
44
+ for img_size in self.image_size:
45
+ self._valid_transform_dict[img_size] = self.build_valid_transform(
46
+ img_size
47
+ )
48
+ self.active_img_size = max(self.image_size) # active resolution for test
49
+ valid_transforms = self._valid_transform_dict[self.active_img_size]
50
+ train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
51
+ else:
52
+ self.active_img_size = self.image_size
53
+ valid_transforms = self.build_valid_transform()
54
+ train_loader_class = torch.utils.data.DataLoader
55
+
56
+ train_dataset = self.train_dataset(self.build_train_transform())
57
+
58
+ if valid_size is not None:
59
+ if not isinstance(valid_size, int):
60
+ assert isinstance(valid_size, float) and 0 < valid_size < 1
61
+ valid_size = int(len(train_dataset) * valid_size)
62
+
63
+ valid_dataset = self.train_dataset(valid_transforms)
64
+ train_indexes, valid_indexes = self.random_sample_valid_set(
65
+ len(train_dataset), valid_size
66
+ )
67
+
68
+ if num_replicas is not None:
69
+ train_sampler = MyDistributedSampler(
70
+ train_dataset, num_replicas, rank, True, np.array(train_indexes)
71
+ )
72
+ valid_sampler = MyDistributedSampler(
73
+ valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
74
+ )
75
+ else:
76
+ train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
77
+ train_indexes
78
+ )
79
+ valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
80
+ valid_indexes
81
+ )
82
+
83
+ self.train = train_loader_class(
84
+ train_dataset,
85
+ batch_size=train_batch_size,
86
+ sampler=train_sampler,
87
+ num_workers=n_worker,
88
+ pin_memory=False,
89
+ )
90
+ self.valid = torch.utils.data.DataLoader(
91
+ valid_dataset,
92
+ batch_size=test_batch_size,
93
+ sampler=valid_sampler,
94
+ num_workers=n_worker,
95
+ pin_memory=False,
96
+ )
97
+ else:
98
+ if num_replicas is not None:
99
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
100
+ train_dataset, num_replicas, rank
101
+ )
102
+ self.train = train_loader_class(
103
+ train_dataset,
104
+ batch_size=train_batch_size,
105
+ sampler=train_sampler,
106
+ num_workers=n_worker,
107
+ pin_memory=True,
108
+ )
109
+ else:
110
+ self.train = train_loader_class(
111
+ train_dataset,
112
+ batch_size=train_batch_size,
113
+ shuffle=True,
114
+ num_workers=n_worker,
115
+ pin_memory=False,
116
+ )
117
+ self.valid = None
118
+
119
+ test_dataset = self.test_dataset(valid_transforms)
120
+ if num_replicas is not None:
121
+ test_sampler = torch.utils.data.distributed.DistributedSampler(
122
+ test_dataset, num_replicas, rank
123
+ )
124
+ self.test = torch.utils.data.DataLoader(
125
+ test_dataset,
126
+ batch_size=test_batch_size,
127
+ sampler=test_sampler,
128
+ num_workers=n_worker,
129
+ pin_memory=False,
130
+ )
131
+ else:
132
+ self.test = torch.utils.data.DataLoader(
133
+ test_dataset,
134
+ batch_size=test_batch_size,
135
+ shuffle=True,
136
+ num_workers=n_worker,
137
+ pin_memory=False,
138
+ )
139
+
140
+ if self.valid is None:
141
+ self.valid = self.test
142
+
143
+ @staticmethod
144
+ def name():
145
+ return "cifar100"
146
+
147
+ @property
148
+ def data_shape(self):
149
+ return 3, self.active_img_size, self.active_img_size # C, H, W
150
+
151
+ @property
152
+ def n_classes(self):
153
+ return 100
154
+
155
+ @property
156
+ def save_path(self):
157
+ if self._save_path is None:
158
+ self._save_path = self.DEFAULT_PATH
159
+ if not os.path.exists(self._save_path):
160
+ self._save_path = os.path.expanduser("~/dataset/cifar100")
161
+ return self._save_path
162
+
163
+ @property
164
+ def data_url(self):
165
+ raise ValueError("unable to download %s" % self.name())
166
+
167
+ def train_dataset(self, _transforms):
168
+ return datasets.CIFAR100(self.train_path, train=True, transform=_transforms,download=True)
169
+
170
+ def test_dataset(self, _transforms):
171
+ return datasets.CIFAR100(self.valid_path, train=False, transform=_transforms,download=True)
172
+ @property
173
+ def train_path(self):
174
+ return os.path.join(self.save_path, "train")
175
+
176
+ @property
177
+ def valid_path(self):
178
+ return os.path.join(self.save_path, "val")
179
+
180
+ @property
181
+ def normalize(self):
182
+ return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
183
+
184
+ def build_train_transform(self, image_size=None, print_log=True):
185
+ if image_size is None:
186
+ image_size = self.image_size
187
+
188
+ # random_resize_crop -> random_horizontal_flip
189
+ train_transforms = [
190
+ transforms.RandomCrop(32,padding=4),
191
+ transforms.RandomHorizontalFlip(),
192
+ # AutoAugment(),
193
+ ]
194
+
195
+ train_transforms += [
196
+ transforms.ToTensor(),
197
+ # self.normalize,
198
+ ]
199
+
200
+ train_transforms = transforms.Compose(train_transforms)
201
+ return train_transforms
202
+
203
+ def build_valid_transform(self, image_size=None):
204
+ if image_size is None:
205
+ image_size = self.active_img_size
206
+ return transforms.Compose([
207
+ transforms.ToTensor(),
208
+ # self.normalize,
209
+ ])
210
+
211
+ def assign_active_img_size(self, new_img_size):
212
+ self.active_img_size = new_img_size
213
+ if self.active_img_size not in self._valid_transform_dict:
214
+ self._valid_transform_dict[
215
+ self.active_img_size
216
+ ] = self.build_valid_transform()
217
+ # change the transform of the valid and test set
218
+ self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
219
+ self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
220
+
221
+ def build_sub_train_loader(
222
+ self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
223
+ ):
224
+ # used for resetting BN running statistics
225
+ if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
226
+ if num_worker is None:
227
+ num_worker = self.train.num_workers
228
+
229
+ n_samples = len(self.train.dataset)
230
+ g = torch.Generator()
231
+ g.manual_seed(DataProvider.SUB_SEED)
232
+ rand_indexes = torch.randperm(n_samples, generator=g).tolist()
233
+
234
+ new_train_dataset = self.train_dataset(
235
+ self.build_train_transform(
236
+ image_size=self.active_img_size, print_log=False
237
+ )
238
+ )
239
+ chosen_indexes = rand_indexes[:n_images]
240
+ if num_replicas is not None:
241
+ sub_sampler = MyDistributedSampler(
242
+ new_train_dataset,
243
+ num_replicas,
244
+ rank,
245
+ True,
246
+ np.array(chosen_indexes),
247
+ )
248
+ else:
249
+ sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
250
+ chosen_indexes
251
+ )
252
+ sub_data_loader = torch.utils.data.DataLoader(
253
+ new_train_dataset,
254
+ batch_size=batch_size,
255
+ sampler=sub_sampler,
256
+ num_workers=num_worker,
257
+ pin_memory=False,
258
+ )
259
+ self.__dict__["sub_train_%d" % self.active_img_size] = []
260
+ for images, labels in sub_data_loader:
261
+ self.__dict__["sub_train_%d" % self.active_img_size].append(
262
+ (images, labels)
263
+ )
264
+ return self.__dict__["sub_train_%d" % self.active_img_size]
proard/classification/data_providers/imagenet.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import warnings
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ import torch.utils.data
10
+ import torchvision.transforms as transforms
11
+ import torchvision.datasets as datasets
12
+
13
+ from .base_provider import DataProvider
14
+ from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
15
+
16
+ __all__ = ["ImagenetDataProvider"]
17
+
18
+
19
+ class ImagenetDataProvider(DataProvider):
20
+ DEFAULT_PATH = "./dataset/imagenet"
21
+
22
+ def __init__(
23
+ self,
24
+ save_path=None,
25
+ train_batch_size=256,
26
+ test_batch_size=512,
27
+ valid_size=None,
28
+ n_worker=32,
29
+ resize_scale=0.08,
30
+ distort_color=None,
31
+ image_size=224,
32
+ num_replicas=None,
33
+ rank=None,
34
+ ):
35
+
36
+ warnings.filterwarnings("ignore")
37
+ self._save_path = save_path
38
+
39
+ self.image_size = image_size # int or list of int
40
+ self.distort_color = "None" if distort_color is None else distort_color
41
+ self.resize_scale = resize_scale
42
+
43
+ self._valid_transform_dict = {}
44
+ if not isinstance(self.image_size, int):
45
+ from proard.utils.my_dataloader.my_data_loader import MyDataLoader
46
+
47
+ assert isinstance(self.image_size, list)
48
+ self.image_size.sort() # e.g., 160 -> 224
49
+ MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
50
+ MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
51
+
52
+ for img_size in self.image_size:
53
+ self._valid_transform_dict[img_size] = self.build_valid_transform(
54
+ img_size
55
+ )
56
+ self.active_img_size = max(self.image_size) # active resolution for test
57
+ valid_transforms = self._valid_transform_dict[self.active_img_size]
58
+ train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
59
+ else:
60
+ self.active_img_size = self.image_size
61
+ valid_transforms = self.build_valid_transform()
62
+ train_loader_class = torch.utils.data.DataLoader
63
+
64
+ train_dataset = self.train_dataset(self.build_train_transform())
65
+
66
+ if valid_size is not None:
67
+ if not isinstance(valid_size, int):
68
+ assert isinstance(valid_size, float) and 0 < valid_size < 1
69
+ valid_size = int(len(train_dataset) * valid_size)
70
+
71
+ valid_dataset = self.train_dataset(valid_transforms)
72
+ train_indexes, valid_indexes = self.random_sample_valid_set(
73
+ len(train_dataset), valid_size
74
+ )
75
+
76
+ if num_replicas is not None:
77
+ train_sampler = MyDistributedSampler(
78
+ train_dataset, num_replicas, rank, True, np.array(train_indexes)
79
+ )
80
+ valid_sampler = MyDistributedSampler(
81
+ valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
82
+ )
83
+ else:
84
+ train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
85
+ train_indexes
86
+ )
87
+ valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
88
+ valid_indexes
89
+ )
90
+
91
+ self.train = train_loader_class(
92
+ train_dataset,
93
+ batch_size=train_batch_size,
94
+ sampler=train_sampler,
95
+ num_workers=n_worker,
96
+ pin_memory=False,
97
+ )
98
+ self.valid = torch.utils.data.DataLoader(
99
+ valid_dataset,
100
+ batch_size=test_batch_size,
101
+ sampler=valid_sampler,
102
+ num_workers=n_worker,
103
+ pin_memory=False,
104
+ )
105
+ else:
106
+ if num_replicas is not None:
107
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
108
+ train_dataset, num_replicas, rank
109
+ )
110
+ self.train = train_loader_class(
111
+ train_dataset,
112
+ batch_size=train_batch_size,
113
+ sampler=train_sampler,
114
+ num_workers=n_worker,
115
+ pin_memory=True,
116
+ )
117
+ else:
118
+ self.train = train_loader_class(
119
+ train_dataset,
120
+ batch_size=train_batch_size,
121
+ shuffle=True,
122
+ num_workers=n_worker,
123
+ pin_memory=False,
124
+ )
125
+ self.valid = None
126
+
127
+ test_dataset = self.test_dataset(valid_transforms)
128
+ if num_replicas is not None:
129
+ test_sampler = torch.utils.data.distributed.DistributedSampler(
130
+ test_dataset, num_replicas, rank
131
+ )
132
+ self.test = torch.utils.data.DataLoader(
133
+ test_dataset,
134
+ batch_size=test_batch_size,
135
+ sampler=test_sampler,
136
+ num_workers=n_worker,
137
+ pin_memory=False,
138
+ )
139
+ else:
140
+ self.test = torch.utils.data.DataLoader(
141
+ test_dataset,
142
+ batch_size=test_batch_size,
143
+ shuffle=True,
144
+ num_workers=n_worker,
145
+ pin_memory=False,
146
+ )
147
+
148
+ if self.valid is None:
149
+ self.valid = self.test
150
+
151
+ @staticmethod
152
+ def name():
153
+ return "imagenet"
154
+
155
+ @property
156
+ def data_shape(self):
157
+ return 3, self.active_img_size, self.active_img_size # C, H, W
158
+
159
+ @property
160
+ def n_classes(self):
161
+ return 1000
162
+
163
+ @property
164
+ def save_path(self):
165
+ if self._save_path is None:
166
+ self._save_path = self.DEFAULT_PATH
167
+ if not os.path.exists(self._save_path):
168
+ self._save_path = os.path.expanduser("~/dataset/imagenet")
169
+ return self._save_path
170
+
171
+ @property
172
+ def data_url(self):
173
+ raise ValueError("unable to download %s" % self.name())
174
+
175
+ def train_dataset(self, _transforms):
176
+ return datasets.ImageFolder(self.train_path, _transforms)
177
+
178
+ def test_dataset(self, _transforms):
179
+ return datasets.ImageFolder(self.valid_path, _transforms)
180
+
181
+ @property
182
+ def train_path(self):
183
+ return os.path.join(self.save_path, "train")
184
+
185
+ @property
186
+ def valid_path(self):
187
+ return os.path.join(self.save_path, "val")
188
+
189
+ @property
190
+ def normalize(self):
191
+ return transforms.Normalize(
192
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
193
+ )
194
+
195
+ def build_train_transform(self, image_size=None, print_log=True):
196
+ if image_size is None:
197
+ image_size = self.image_size
198
+ if print_log:
199
+ print(
200
+ "Color jitter: %s, resize_scale: %s, img_size: %s"
201
+ % (self.distort_color, self.resize_scale, image_size)
202
+ )
203
+
204
+ if isinstance(image_size, list):
205
+ resize_transform_class = MyRandomResizedCrop
206
+ print(
207
+ "Use MyRandomResizedCrop: %s, \t %s"
208
+ % MyRandomResizedCrop.get_candidate_image_size(),
209
+ "sync=%s, continuous=%s"
210
+ % (
211
+ MyRandomResizedCrop.SYNC_DISTRIBUTED,
212
+ MyRandomResizedCrop.CONTINUOUS,
213
+ ),
214
+ )
215
+ else:
216
+ resize_transform_class = transforms.RandomResizedCrop
217
+
218
+ # random_resize_crop -> random_horizontal_flip
219
+ train_transforms = [
220
+ resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
221
+ transforms.RandomHorizontalFlip(),
222
+ ]
223
+
224
+ # color augmentation (optional)
225
+ color_transform = None
226
+ if self.distort_color == "torch":
227
+ color_transform = transforms.ColorJitter(
228
+ brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
229
+ )
230
+ elif self.distort_color == "tf":
231
+ color_transform = transforms.ColorJitter(
232
+ brightness=32.0 / 255.0, saturation=0.5
233
+ )
234
+ if color_transform is not None:
235
+ train_transforms.append(color_transform)
236
+
237
+ train_transforms += [
238
+ transforms.ToTensor(),
239
+ self.normalize,
240
+ ]
241
+
242
+ train_transforms = transforms.Compose(train_transforms)
243
+ return train_transforms
244
+
245
+ def build_valid_transform(self, image_size=None):
246
+ if image_size is None:
247
+ image_size = self.active_img_size
248
+ return transforms.Compose(
249
+ [
250
+ transforms.Resize(int(math.ceil(image_size / 0.875))),
251
+ transforms.CenterCrop(image_size),
252
+ transforms.ToTensor(),
253
+ self.normalize,
254
+ ]
255
+ )
256
+
257
+ def assign_active_img_size(self, new_img_size):
258
+ self.active_img_size = new_img_size
259
+ if self.active_img_size not in self._valid_transform_dict:
260
+ self._valid_transform_dict[
261
+ self.active_img_size
262
+ ] = self.build_valid_transform()
263
+ # change the transform of the valid and test set
264
+ self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
265
+ self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
266
+
267
+ def build_sub_train_loader(
268
+ self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
269
+ ):
270
+ # used for resetting BN running statistics
271
+ if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
272
+ if num_worker is None:
273
+ num_worker = self.train.num_workers
274
+
275
+ n_samples = len(self.train.dataset)
276
+ g = torch.Generator()
277
+ g.manual_seed(DataProvider.SUB_SEED)
278
+ rand_indexes = torch.randperm(n_samples, generator=g).tolist()
279
+
280
+ new_train_dataset = self.train_dataset(
281
+ self.build_train_transform(
282
+ image_size=self.active_img_size, print_log=False
283
+ )
284
+ )
285
+ chosen_indexes = rand_indexes[:n_images]
286
+ if num_replicas is not None:
287
+ sub_sampler = MyDistributedSampler(
288
+ new_train_dataset,
289
+ num_replicas,
290
+ rank,
291
+ True,
292
+ np.array(chosen_indexes),
293
+ )
294
+ else:
295
+ sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
296
+ chosen_indexes
297
+ )
298
+ sub_data_loader = torch.utils.data.DataLoader(
299
+ new_train_dataset,
300
+ batch_size=batch_size,
301
+ sampler=sub_sampler,
302
+ num_workers=num_worker,
303
+ pin_memory=False,
304
+ )
305
+ self.__dict__["sub_train_%d" % self.active_img_size] = []
306
+ for images, labels in sub_data_loader:
307
+ self.__dict__["sub_train_%d" % self.active_img_size].append(
308
+ (images, labels)
309
+ )
310
+ return self.__dict__["sub_train_%d" % self.active_img_size]
proard/classification/elastic_nn/__init__.py ADDED
File without changes
proard/classification/elastic_nn/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .dynamic_layers import *
6
+ from .dynamic_op import *
proard/classification/elastic_nn/modules/dynamic_layers.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import torch
7
+ import torch.nn as nn
8
+ from collections import OrderedDict
9
+
10
+ from proard.utils.layers import (
11
+ MBConvLayer,
12
+ ConvLayer,
13
+ IdentityLayer,
14
+ set_layer_from_config,
15
+ )
16
+ from proard.utils.layers import ResNetBottleneckBlock, LinearLayer
17
+ from proard.utils import (
18
+ MyModule,
19
+ val2list,
20
+ get_net_device,
21
+ build_activation,
22
+ make_divisible,
23
+ SEModule,
24
+ MyNetwork,
25
+ )
26
+ from .dynamic_op import (
27
+ DynamicSeparableConv2d,
28
+ DynamicConv2d,
29
+ DynamicBatchNorm2d,
30
+ DynamicSE,
31
+ DynamicGroupNorm,
32
+ )
33
+ from .dynamic_op import DynamicLinear
34
+
35
+ __all__ = [
36
+ "adjust_bn_according_to_idx",
37
+ "copy_bn",
38
+ "DynamicMBConvLayer",
39
+ "DynamicConvLayer",
40
+ "DynamicLinearLayer",
41
+ "DynamicResNetBottleneckBlock",
42
+ ]
43
+
44
+
45
+ def adjust_bn_according_to_idx(bn, idx):
46
+ bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
47
+ bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
48
+ if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
49
+ bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
50
+ bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)
51
+
52
+
53
+ def copy_bn(target_bn, src_bn):
54
+ feature_dim = (
55
+ target_bn.num_channels
56
+ if isinstance(target_bn, nn.GroupNorm)
57
+ else target_bn.num_features
58
+ )
59
+
60
+ target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
61
+ target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
62
+ if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
63
+ target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
64
+ target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])
65
+
66
+
67
+ class DynamicLinearLayer(MyModule):
68
+ def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
69
+ super(DynamicLinearLayer, self).__init__()
70
+
71
+ self.in_features_list = in_features_list
72
+ self.out_features = out_features
73
+ self.bias = bias
74
+ self.dropout_rate = dropout_rate
75
+
76
+ if self.dropout_rate > 0:
77
+ self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
78
+ else:
79
+ self.dropout = None
80
+ self.linear = DynamicLinear(
81
+ max_in_features=max(self.in_features_list),
82
+ max_out_features=self.out_features,
83
+ bias=self.bias,
84
+ )
85
+
86
+ def forward(self, x):
87
+ if self.dropout is not None:
88
+ x = self.dropout(x)
89
+ return self.linear(x)
90
+
91
+ @property
92
+ def module_str(self):
93
+ return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features)
94
+
95
+ @property
96
+ def config(self):
97
+ return {
98
+ "name": DynamicLinear.__name__,
99
+ "in_features_list": self.in_features_list,
100
+ "out_features": self.out_features,
101
+ "bias": self.bias,
102
+ "dropout_rate": self.dropout_rate,
103
+ }
104
+
105
+ @staticmethod
106
+ def build_from_config(config):
107
+ return DynamicLinearLayer(**config)
108
+
109
+ def get_active_subnet(self, in_features, preserve_weight=True):
110
+ sub_layer = LinearLayer(
111
+ in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate
112
+ )
113
+ sub_layer = sub_layer.to(get_net_device(self))
114
+ if not preserve_weight:
115
+ return sub_layer
116
+
117
+ sub_layer.linear.weight.data.copy_(
118
+ self.linear.get_active_weight(self.out_features, in_features).data
119
+ )
120
+ if self.bias:
121
+ sub_layer.linear.bias.data.copy_(
122
+ self.linear.get_active_bias(self.out_features).data
123
+ )
124
+ return sub_layer
125
+
126
+ def get_active_subnet_config(self, in_features):
127
+ return {
128
+ "name": LinearLayer.__name__,
129
+ "in_features": in_features,
130
+ "out_features": self.out_features,
131
+ "bias": self.bias,
132
+ "dropout_rate": self.dropout_rate,
133
+ }
134
+
135
+
136
+ class DynamicMBConvLayer(MyModule):
137
+ def __init__(
138
+ self,
139
+ in_channel_list,
140
+ out_channel_list,
141
+ kernel_size_list=3,
142
+ expand_ratio_list=6,
143
+ stride=1,
144
+ act_func="relu6",
145
+ use_se=False,
146
+ ):
147
+ super(DynamicMBConvLayer, self).__init__()
148
+
149
+ self.in_channel_list = in_channel_list
150
+ self.out_channel_list = out_channel_list
151
+
152
+ self.kernel_size_list = val2list(kernel_size_list)
153
+ self.expand_ratio_list = val2list(expand_ratio_list)
154
+
155
+ self.stride = stride
156
+ self.act_func = act_func
157
+ self.use_se = use_se
158
+
159
+ # build modules
160
+ max_middle_channel = make_divisible(
161
+ round(max(self.in_channel_list) * max(self.expand_ratio_list)),
162
+ MyNetwork.CHANNEL_DIVISIBLE,
163
+ )
164
+ if max(self.expand_ratio_list) == 1:
165
+ self.inverted_bottleneck = None
166
+ else:
167
+ self.inverted_bottleneck = nn.Sequential(
168
+ OrderedDict(
169
+ [
170
+ (
171
+ "conv",
172
+ DynamicConv2d(
173
+ max(self.in_channel_list), max_middle_channel
174
+ ),
175
+ ),
176
+ ("bn", DynamicBatchNorm2d(max_middle_channel)),
177
+ ("act", build_activation(self.act_func)),
178
+ ]
179
+ )
180
+ )
181
+
182
+ self.depth_conv = nn.Sequential(
183
+ OrderedDict(
184
+ [
185
+ (
186
+ "conv",
187
+ DynamicSeparableConv2d(
188
+ max_middle_channel, self.kernel_size_list, self.stride
189
+ ),
190
+ ),
191
+ ("bn", DynamicBatchNorm2d(max_middle_channel)),
192
+ ("act", build_activation(self.act_func)),
193
+ ]
194
+ )
195
+ )
196
+ if self.use_se:
197
+ self.depth_conv.add_module("se", DynamicSE(max_middle_channel))
198
+
199
+ self.point_linear = nn.Sequential(
200
+ OrderedDict(
201
+ [
202
+ (
203
+ "conv",
204
+ DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
205
+ ),
206
+ ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
207
+ ]
208
+ )
209
+ )
210
+
211
+ self.active_kernel_size = max(self.kernel_size_list)
212
+ self.active_expand_ratio = max(self.expand_ratio_list)
213
+ self.active_out_channel = max(self.out_channel_list)
214
+
215
+ def forward(self, x):
216
+ in_channel = x.size(1)
217
+
218
+ if self.inverted_bottleneck is not None:
219
+ self.inverted_bottleneck.conv.active_out_channel = make_divisible(
220
+ round(in_channel * self.active_expand_ratio),
221
+ MyNetwork.CHANNEL_DIVISIBLE,
222
+ )
223
+
224
+ self.depth_conv.conv.active_kernel_size = self.active_kernel_size
225
+ self.point_linear.conv.active_out_channel = self.active_out_channel
226
+
227
+ if self.inverted_bottleneck is not None:
228
+ x = self.inverted_bottleneck(x)
229
+ x = self.depth_conv(x)
230
+ x = self.point_linear(x)
231
+ return x
232
+
233
+ @property
234
+ def module_str(self):
235
+ if self.use_se:
236
+ return "SE(O%d, E%.1f, K%d)" % (
237
+ self.active_out_channel,
238
+ self.active_expand_ratio,
239
+ self.active_kernel_size,
240
+ )
241
+ else:
242
+ return "(O%d, E%.1f, K%d)" % (
243
+ self.active_out_channel,
244
+ self.active_expand_ratio,
245
+ self.active_kernel_size,
246
+ )
247
+
248
+ @property
249
+ def config(self):
250
+ return {
251
+ "name": DynamicMBConvLayer.__name__,
252
+ "in_channel_list": self.in_channel_list,
253
+ "out_channel_list": self.out_channel_list,
254
+ "kernel_size_list": self.kernel_size_list,
255
+ "expand_ratio_list": self.expand_ratio_list,
256
+ "stride": self.stride,
257
+ "act_func": self.act_func,
258
+ "use_se": self.use_se,
259
+ }
260
+
261
+ @staticmethod
262
+ def build_from_config(config):
263
+ return DynamicMBConvLayer(**config)
264
+
265
+ ############################################################################################
266
+
267
+ @property
268
+ def in_channels(self):
269
+ return max(self.in_channel_list)
270
+
271
+ @property
272
+ def out_channels(self):
273
+ return max(self.out_channel_list)
274
+
275
+ def active_middle_channel(self, in_channel):
276
+ return make_divisible(
277
+ round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE
278
+ )
279
+
280
+ ############################################################################################
281
+
282
+ def get_active_subnet(self, in_channel, preserve_weight=True):
283
+ # build the new layer
284
+ sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
285
+ sub_layer = sub_layer.to(get_net_device(self))
286
+ if not preserve_weight:
287
+ return sub_layer
288
+
289
+ middle_channel = self.active_middle_channel(in_channel)
290
+ # copy weight from current layer
291
+ if sub_layer.inverted_bottleneck is not None:
292
+ sub_layer.inverted_bottleneck.conv.weight.data.copy_(
293
+ self.inverted_bottleneck.conv.get_active_filter(
294
+ middle_channel, in_channel
295
+ ).data,
296
+ )
297
+ copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)
298
+
299
+ sub_layer.depth_conv.conv.weight.data.copy_(
300
+ self.depth_conv.conv.get_active_filter(
301
+ middle_channel, self.active_kernel_size
302
+ ).data
303
+ )
304
+ copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)
305
+
306
+ if self.use_se:
307
+ se_mid = make_divisible(
308
+ middle_channel // SEModule.REDUCTION,
309
+ divisor=MyNetwork.CHANNEL_DIVISIBLE,
310
+ )
311
+ sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
312
+ self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
313
+ )
314
+ sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
315
+ self.depth_conv.se.get_active_reduce_bias(se_mid).data
316
+ )
317
+
318
+ sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
319
+ self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
320
+ )
321
+ sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
322
+ self.depth_conv.se.get_active_expand_bias(middle_channel).data
323
+ )
324
+
325
+ sub_layer.point_linear.conv.weight.data.copy_(
326
+ self.point_linear.conv.get_active_filter(
327
+ self.active_out_channel, middle_channel
328
+ ).data
329
+ )
330
+ copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)
331
+
332
+ return sub_layer
333
+
334
+ def get_active_subnet_config(self, in_channel):
335
+ return {
336
+ "name": MBConvLayer.__name__,
337
+ "in_channels": in_channel,
338
+ "out_channels": self.active_out_channel,
339
+ "kernel_size": self.active_kernel_size,
340
+ "stride": self.stride,
341
+ "expand_ratio": self.active_expand_ratio,
342
+ "mid_channels": self.active_middle_channel(in_channel),
343
+ "act_func": self.act_func,
344
+ "use_se": self.use_se,
345
+ }
346
+
347
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
348
+ importance = torch.sum(
349
+ torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)
350
+ )
351
+ if isinstance(self.depth_conv.bn, DynamicGroupNorm):
352
+ channel_per_group = self.depth_conv.bn.channel_per_group
353
+ importance_chunks = torch.split(importance, channel_per_group)
354
+ for chunk in importance_chunks:
355
+ chunk.data.fill_(torch.mean(chunk))
356
+ importance = torch.cat(importance_chunks, dim=0)
357
+ if expand_ratio_stage > 0:
358
+ sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
359
+ sorted_expand_list.sort(reverse=True)
360
+ target_width_list = [
361
+ make_divisible(
362
+ round(max(self.in_channel_list) * expand),
363
+ MyNetwork.CHANNEL_DIVISIBLE,
364
+ )
365
+ for expand in sorted_expand_list
366
+ ]
367
+
368
+ right = len(importance)
369
+ base = -len(target_width_list) * 1e5
370
+ for i in range(expand_ratio_stage + 1):
371
+ left = target_width_list[i]
372
+ importance[left:right] += base
373
+ base += 1e5
374
+ right = left
375
+
376
+ sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
377
+ self.point_linear.conv.conv.weight.data = torch.index_select(
378
+ self.point_linear.conv.conv.weight.data, 1, sorted_idx
379
+ )
380
+
381
+ adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
382
+ self.depth_conv.conv.conv.weight.data = torch.index_select(
383
+ self.depth_conv.conv.conv.weight.data, 0, sorted_idx
384
+ )
385
+
386
+ if self.use_se:
387
+ # se expand: output dim 0 reorganize
388
+ se_expand = self.depth_conv.se.fc.expand
389
+ se_expand.weight.data = torch.index_select(
390
+ se_expand.weight.data, 0, sorted_idx
391
+ )
392
+ se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
393
+ # se reduce: input dim 1 reorganize
394
+ se_reduce = self.depth_conv.se.fc.reduce
395
+ se_reduce.weight.data = torch.index_select(
396
+ se_reduce.weight.data, 1, sorted_idx
397
+ )
398
+ # middle weight reorganize
399
+ se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
400
+ se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)
401
+
402
+ se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
403
+ se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
404
+ se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
405
+
406
+ if self.inverted_bottleneck is not None:
407
+ adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
408
+ self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
409
+ self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
410
+ )
411
+ return None
412
+ else:
413
+ return sorted_idx
414
+
415
+
416
+ class DynamicConvLayer(MyModule):
417
+ def __init__(
418
+ self,
419
+ in_channel_list,
420
+ out_channel_list,
421
+ kernel_size=3,
422
+ stride=1,
423
+ dilation=1,
424
+ use_bn=True,
425
+ act_func="relu6",
426
+ ):
427
+ super(DynamicConvLayer, self).__init__()
428
+
429
+ self.in_channel_list = in_channel_list
430
+ self.out_channel_list = out_channel_list
431
+ self.kernel_size = kernel_size
432
+ self.stride = stride
433
+ self.dilation = dilation
434
+ self.use_bn = use_bn
435
+ self.act_func = act_func
436
+
437
+ self.conv = DynamicConv2d(
438
+ max_in_channels=max(self.in_channel_list),
439
+ max_out_channels=max(self.out_channel_list),
440
+ kernel_size=self.kernel_size,
441
+ stride=self.stride,
442
+ dilation=self.dilation,
443
+ )
444
+ if self.use_bn:
445
+ self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
446
+ self.act = build_activation(self.act_func)
447
+
448
+ self.active_out_channel = max(self.out_channel_list)
449
+
450
+ def forward(self, x):
451
+ self.conv.active_out_channel = self.active_out_channel
452
+
453
+ x = self.conv(x)
454
+ if self.use_bn:
455
+ x = self.bn(x)
456
+ x = self.act(x)
457
+ return x
458
+
459
+ @property
460
+ def module_str(self):
461
+ return "DyConv(O%d, K%d, S%d)" % (
462
+ self.active_out_channel,
463
+ self.kernel_size,
464
+ self.stride,
465
+ )
466
+
467
+ @property
468
+ def config(self):
469
+ return {
470
+ "name": DynamicConvLayer.__name__,
471
+ "in_channel_list": self.in_channel_list,
472
+ "out_channel_list": self.out_channel_list,
473
+ "kernel_size": self.kernel_size,
474
+ "stride": self.stride,
475
+ "dilation": self.dilation,
476
+ "use_bn": self.use_bn,
477
+ "act_func": self.act_func,
478
+ }
479
+
480
+ @staticmethod
481
+ def build_from_config(config):
482
+ return DynamicConvLayer(**config)
483
+
484
+ ############################################################################################
485
+
486
+ @property
487
+ def in_channels(self):
488
+ return max(self.in_channel_list)
489
+
490
+ @property
491
+ def out_channels(self):
492
+ return max(self.out_channel_list)
493
+
494
+ ############################################################################################
495
+
496
+ def get_active_subnet(self, in_channel, preserve_weight=True):
497
+ sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
498
+ sub_layer = sub_layer.to(get_net_device(self))
499
+
500
+ if not preserve_weight:
501
+ return sub_layer
502
+
503
+ sub_layer.conv.weight.data.copy_(
504
+ self.conv.get_active_filter(self.active_out_channel, in_channel).data
505
+ )
506
+ if self.use_bn:
507
+ copy_bn(sub_layer.bn, self.bn.bn)
508
+
509
+ return sub_layer
510
+
511
+ def get_active_subnet_config(self, in_channel):
512
+ return {
513
+ "name": ConvLayer.__name__,
514
+ "in_channels": in_channel,
515
+ "out_channels": self.active_out_channel,
516
+ "kernel_size": self.kernel_size,
517
+ "stride": self.stride,
518
+ "dilation": self.dilation,
519
+ "use_bn": self.use_bn,
520
+ "act_func": self.act_func,
521
+ }
522
+
523
+
524
+ class DynamicResNetBottleneckBlock(MyModule):
525
+ def __init__(
526
+ self,
527
+ in_channel_list,
528
+ out_channel_list,
529
+ expand_ratio_list=0.25,
530
+ kernel_size=3,
531
+ stride=1,
532
+ act_func="relu",
533
+ downsample_mode="avgpool_conv",
534
+ ):
535
+ super(DynamicResNetBottleneckBlock, self).__init__()
536
+
537
+ self.in_channel_list = in_channel_list
538
+ self.out_channel_list = out_channel_list
539
+ self.expand_ratio_list = val2list(expand_ratio_list)
540
+
541
+ self.kernel_size = kernel_size
542
+ self.stride = stride
543
+ self.act_func = act_func
544
+ self.downsample_mode = downsample_mode
545
+
546
+ # build modules
547
+ max_middle_channel = make_divisible(
548
+ round(max(self.out_channel_list) * max(self.expand_ratio_list)),
549
+ MyNetwork.CHANNEL_DIVISIBLE,
550
+ )
551
+
552
+ self.conv1 = nn.Sequential(
553
+ OrderedDict(
554
+ [
555
+ (
556
+ "conv",
557
+ DynamicConv2d(max(self.in_channel_list), max_middle_channel),
558
+ ),
559
+ ("bn", DynamicBatchNorm2d(max_middle_channel)),
560
+ ("act", build_activation(self.act_func, inplace=True)),
561
+ ]
562
+ )
563
+ )
564
+
565
+ self.conv2 = nn.Sequential(
566
+ OrderedDict(
567
+ [
568
+ (
569
+ "conv",
570
+ DynamicConv2d(
571
+ max_middle_channel, max_middle_channel, kernel_size, stride
572
+ ),
573
+ ),
574
+ ("bn", DynamicBatchNorm2d(max_middle_channel)),
575
+ ("act", build_activation(self.act_func, inplace=True)),
576
+ ]
577
+ )
578
+ )
579
+
580
+ self.conv3 = nn.Sequential(
581
+ OrderedDict(
582
+ [
583
+ (
584
+ "conv",
585
+ DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
586
+ ),
587
+ ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
588
+ ]
589
+ )
590
+ )
591
+
592
+ if self.stride == 1 and self.in_channel_list == self.out_channel_list:
593
+ self.downsample = IdentityLayer(
594
+ max(self.in_channel_list), max(self.out_channel_list)
595
+ )
596
+ elif self.downsample_mode == "conv":
597
+ self.downsample = nn.Sequential(
598
+ OrderedDict(
599
+ [
600
+ (
601
+ "conv",
602
+ DynamicConv2d(
603
+ max(self.in_channel_list),
604
+ max(self.out_channel_list),
605
+ stride=stride,
606
+ ),
607
+ ),
608
+ ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
609
+ ]
610
+ )
611
+ )
612
+ elif self.downsample_mode == "avgpool_conv":
613
+ self.downsample = nn.Sequential(
614
+ OrderedDict(
615
+ [
616
+ (
617
+ "avg_pool",
618
+ nn.AvgPool2d(
619
+ kernel_size=stride,
620
+ stride=stride,
621
+ padding=0,
622
+ ceil_mode=True,
623
+ ),
624
+ ),
625
+ (
626
+ "conv",
627
+ DynamicConv2d(
628
+ max(self.in_channel_list), max(self.out_channel_list)
629
+ ),
630
+ ),
631
+ ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
632
+ ]
633
+ )
634
+ )
635
+ else:
636
+ raise NotImplementedError
637
+
638
+ self.final_act = build_activation(self.act_func, inplace=True)
639
+
640
+ self.active_expand_ratio = max(self.expand_ratio_list)
641
+ self.active_out_channel = max(self.out_channel_list)
642
+
643
+ def forward(self, x):
644
+ feature_dim = self.active_middle_channels
645
+
646
+ self.conv1.conv.active_out_channel = feature_dim
647
+ self.conv2.conv.active_out_channel = feature_dim
648
+ self.conv3.conv.active_out_channel = self.active_out_channel
649
+ if not isinstance(self.downsample, IdentityLayer):
650
+ self.downsample.conv.active_out_channel = self.active_out_channel
651
+
652
+ residual = self.downsample(x)
653
+
654
+ x = self.conv1(x)
655
+ x = self.conv2(x)
656
+ x = self.conv3(x)
657
+
658
+ x = x + residual
659
+ x = self.final_act(x)
660
+ return x
661
+
662
+ @property
663
+ def module_str(self):
664
+ return "(%s, %s)" % (
665
+ "%dx%d_BottleneckConv_in->%d->%d_S%d"
666
+ % (
667
+ self.kernel_size,
668
+ self.kernel_size,
669
+ self.active_middle_channels,
670
+ self.active_out_channel,
671
+ self.stride,
672
+ ),
673
+ "Identity"
674
+ if isinstance(self.downsample, IdentityLayer)
675
+ else self.downsample_mode,
676
+ )
677
+
678
+ @property
679
+ def config(self):
680
+ return {
681
+ "name": DynamicResNetBottleneckBlock.__name__,
682
+ "in_channel_list": self.in_channel_list,
683
+ "out_channel_list": self.out_channel_list,
684
+ "expand_ratio_list": self.expand_ratio_list,
685
+ "kernel_size": self.kernel_size,
686
+ "stride": self.stride,
687
+ "act_func": self.act_func,
688
+ "downsample_mode": self.downsample_mode,
689
+ }
690
+
691
+ @staticmethod
692
+ def build_from_config(config):
693
+ return DynamicResNetBottleneckBlock(**config)
694
+
695
+ ############################################################################################
696
+
697
+ @property
698
+ def in_channels(self):
699
+ return max(self.in_channel_list)
700
+
701
+ @property
702
+ def out_channels(self):
703
+ return max(self.out_channel_list)
704
+
705
+ @property
706
+ def active_middle_channels(self):
707
+ feature_dim = round(self.active_out_channel * self.active_expand_ratio)
708
+ feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
709
+ return feature_dim
710
+
711
+ ############################################################################################
712
+
713
+ def get_active_subnet(self, in_channel, preserve_weight=True):
714
+ # build the new layer
715
+ sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
716
+ sub_layer = sub_layer.to(get_net_device(self))
717
+ if not preserve_weight:
718
+ return sub_layer
719
+
720
+ # copy weight from current layer
721
+ sub_layer.conv1.conv.weight.data.copy_(
722
+ self.conv1.conv.get_active_filter(
723
+ self.active_middle_channels, in_channel
724
+ ).data
725
+ )
726
+ copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)
727
+
728
+ sub_layer.conv2.conv.weight.data.copy_(
729
+ self.conv2.conv.get_active_filter(
730
+ self.active_middle_channels, self.active_middle_channels
731
+ ).data
732
+ )
733
+ copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)
734
+
735
+ sub_layer.conv3.conv.weight.data.copy_(
736
+ self.conv3.conv.get_active_filter(
737
+ self.active_out_channel, self.active_middle_channels
738
+ ).data
739
+ )
740
+ copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)
741
+
742
+ if not isinstance(self.downsample, IdentityLayer):
743
+ sub_layer.downsample.conv.weight.data.copy_(
744
+ self.downsample.conv.get_active_filter(
745
+ self.active_out_channel, in_channel
746
+ ).data
747
+ )
748
+ copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)
749
+
750
+ return sub_layer
751
+
752
+ def get_active_subnet_config(self, in_channel):
753
+ return {
754
+ "name": ResNetBottleneckBlock.__name__,
755
+ "in_channels": in_channel,
756
+ "out_channels": self.active_out_channel,
757
+ "kernel_size": self.kernel_size,
758
+ "stride": self.stride,
759
+ "expand_ratio": self.active_expand_ratio,
760
+ "mid_channels": self.active_middle_channels,
761
+ "act_func": self.act_func,
762
+ "groups": 1,
763
+ "downsample_mode": self.downsample_mode,
764
+ }
765
+
766
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
767
+ # conv3 -> conv2
768
+ importance = torch.sum(
769
+ torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3)
770
+ )
771
+ if isinstance(self.conv2.bn, DynamicGroupNorm):
772
+ channel_per_group = self.conv2.bn.channel_per_group
773
+ importance_chunks = torch.split(importance, channel_per_group)
774
+ for chunk in importance_chunks:
775
+ chunk.data.fill_(torch.mean(chunk))
776
+ importance = torch.cat(importance_chunks, dim=0)
777
+ if expand_ratio_stage > 0:
778
+ sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
779
+ sorted_expand_list.sort(reverse=True)
780
+ target_width_list = [
781
+ make_divisible(
782
+ round(max(self.out_channel_list) * expand),
783
+ MyNetwork.CHANNEL_DIVISIBLE,
784
+ )
785
+ for expand in sorted_expand_list
786
+ ]
787
+ right = len(importance)
788
+ base = -len(target_width_list) * 1e5
789
+ for i in range(expand_ratio_stage + 1):
790
+ left = target_width_list[i]
791
+ importance[left:right] += base
792
+ base += 1e5
793
+ right = left
794
+
795
+ sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
796
+ self.conv3.conv.conv.weight.data = torch.index_select(
797
+ self.conv3.conv.conv.weight.data, 1, sorted_idx
798
+ )
799
+ adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
800
+ self.conv2.conv.conv.weight.data = torch.index_select(
801
+ self.conv2.conv.conv.weight.data, 0, sorted_idx
802
+ )
803
+
804
+ # conv2 -> conv1
805
+ importance = torch.sum(
806
+ torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3)
807
+ )
808
+ if isinstance(self.conv1.bn, DynamicGroupNorm):
809
+ channel_per_group = self.conv1.bn.channel_per_group
810
+ importance_chunks = torch.split(importance, channel_per_group)
811
+ for chunk in importance_chunks:
812
+ chunk.data.fill_(torch.mean(chunk))
813
+ importance = torch.cat(importance_chunks, dim=0)
814
+ if expand_ratio_stage > 0:
815
+ sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
816
+ sorted_expand_list.sort(reverse=True)
817
+ target_width_list = [
818
+ make_divisible(
819
+ round(max(self.out_channel_list) * expand),
820
+ MyNetwork.CHANNEL_DIVISIBLE,
821
+ )
822
+ for expand in sorted_expand_list
823
+ ]
824
+ right = len(importance)
825
+ base = -len(target_width_list) * 1e5
826
+ for i in range(expand_ratio_stage + 1):
827
+ left = target_width_list[i]
828
+ importance[left:right] += base
829
+ base += 1e5
830
+ right = left
831
+ sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
832
+
833
+ self.conv2.conv.conv.weight.data = torch.index_select(
834
+ self.conv2.conv.conv.weight.data, 1, sorted_idx
835
+ )
836
+ adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
837
+ self.conv1.conv.conv.weight.data = torch.index_select(
838
+ self.conv1.conv.conv.weight.data, 0, sorted_idx
839
+ )
840
+
841
+ return None
proard/classification/elastic_nn/modules/dynamic_op.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+ import torch
8
+ from torch.nn.parameter import Parameter
9
+
10
+ from proard.utils import (
11
+ get_same_padding,
12
+ sub_filter_start_end,
13
+ make_divisible,
14
+ SEModule,
15
+ MyNetwork,
16
+ MyConv2d,
17
+ )
18
+
19
+ __all__ = [
20
+ "DynamicSeparableConv2d",
21
+ "DynamicConv2d",
22
+ "DynamicGroupConv2d",
23
+ "DynamicBatchNorm2d",
24
+ "DynamicGroupNorm",
25
+ "DynamicSE",
26
+ "DynamicLinear",
27
+ ]
28
+
29
+ # Seprable conv consits of a depthwise and pointwise conv
30
+
31
+ class DynamicSeparableConv2d(nn.Module):
32
+ KERNEL_TRANSFORM_MODE = 1 # None or 1
33
+
34
+ def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
35
+ super(DynamicSeparableConv2d, self).__init__()
36
+
37
+ self.max_in_channels = max_in_channels
38
+ self.kernel_size_list = kernel_size_list # list of kernel size
39
+ self.stride = stride
40
+ self.dilation = dilation
41
+
42
+ self.conv = nn.Conv2d(
43
+ self.max_in_channels,
44
+ self.max_in_channels,
45
+ max(self.kernel_size_list),
46
+ self.stride,
47
+ groups=self.max_in_channels,
48
+ bias=False,
49
+ )
50
+
51
+ self._ks_set = list(set(self.kernel_size_list))
52
+ self._ks_set.sort() # e.g., [3, 5, 7]
53
+ # define a matrix for converting from damll kernel size to larger one
54
+ if self.KERNEL_TRANSFORM_MODE is not None:
55
+ # register scaling parameters
56
+ # 7to5_matrix, 5to3_matrix
57
+ scale_params = {}
58
+ for i in range(len(self._ks_set) - 1):
59
+ ks_small = self._ks_set[i]
60
+ ks_larger = self._ks_set[i + 1]
61
+ param_name = "%dto%d" % (ks_larger, ks_small)
62
+ # noinspection PyArgumentList
63
+ scale_params["%s_matrix" % param_name] = Parameter(
64
+ torch.eye(ks_small ** 2)
65
+ )
66
+ for name, param in scale_params.items():
67
+ self.register_parameter(name, param)
68
+
69
+ self.active_kernel_size = max(self.kernel_size_list)
70
+
71
+ def get_active_filter(self, in_channel, kernel_size):
72
+ out_channel = in_channel
73
+ max_kernel_size = max(self.kernel_size_list)
74
+
75
+ start, end = sub_filter_start_end(max_kernel_size, kernel_size)
76
+ filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
77
+ if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
78
+ start_filter = self.conv.weight[
79
+ :out_channel, :in_channel, :, :
80
+ ] # start with max kernel
81
+ for i in range(len(self._ks_set) - 1, 0, -1):
82
+ src_ks = self._ks_set[i]
83
+ if src_ks <= kernel_size:
84
+ break
85
+ target_ks = self._ks_set[i - 1]
86
+ start, end = sub_filter_start_end(src_ks, target_ks)
87
+ _input_filter = start_filter[:, :, start:end, start:end]
88
+ _input_filter = _input_filter.contiguous()
89
+ _input_filter = _input_filter.view(
90
+ _input_filter.size(0), _input_filter.size(1), -1
91
+ )
92
+ _input_filter = _input_filter.view(-1, _input_filter.size(2))
93
+ _input_filter = F.linear(
94
+ _input_filter,
95
+ self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
96
+ )
97
+ _input_filter = _input_filter.view(
98
+ filters.size(0), filters.size(1), target_ks ** 2
99
+ )
100
+ _input_filter = _input_filter.view(
101
+ filters.size(0), filters.size(1), target_ks, target_ks
102
+ )
103
+ start_filter = _input_filter
104
+ filters = start_filter
105
+ return filters
106
+
107
+ def forward(self, x, kernel_size=None):
108
+ if kernel_size is None:
109
+ kernel_size = self.active_kernel_size
110
+ in_channel = x.size(1)
111
+
112
+ filters = self.get_active_filter(in_channel, kernel_size).contiguous()
113
+
114
+ padding = get_same_padding(kernel_size)
115
+ filters = (
116
+ self.conv.weight_standardization(filters)
117
+ if isinstance(self.conv, MyConv2d)
118
+ else filters
119
+ )
120
+ y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
121
+ return y
122
+
123
+
124
+ class DynamicConv2d(nn.Module):
125
+ def __init__(
126
+ self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
127
+ ):
128
+ super(DynamicConv2d, self).__init__()
129
+
130
+ self.max_in_channels = max_in_channels
131
+ self.max_out_channels = max_out_channels
132
+ self.kernel_size = kernel_size
133
+ self.stride = stride
134
+ self.dilation = dilation
135
+
136
+ self.conv = nn.Conv2d(
137
+ self.max_in_channels,
138
+ self.max_out_channels,
139
+ self.kernel_size,
140
+ stride=self.stride,
141
+ bias=False,
142
+ )
143
+
144
+ self.active_out_channel = self.max_out_channels
145
+
146
+ def get_active_filter(self, out_channel, in_channel):
147
+ return self.conv.weight[:out_channel, :in_channel, :, :]
148
+
149
+ def forward(self, x, out_channel=None):
150
+ if out_channel is None:
151
+ out_channel = self.active_out_channel
152
+ in_channel = x.size(1)
153
+ filters = self.get_active_filter(out_channel, in_channel).contiguous()
154
+
155
+ padding = get_same_padding(self.kernel_size)
156
+ filters = (
157
+ self.conv.weight_standardization(filters)
158
+ if isinstance(self.conv, MyConv2d)
159
+ else filters
160
+ )
161
+ y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
162
+ return y
163
+
164
+
165
+ class DynamicGroupConv2d(nn.Module):
166
+ def __init__(
167
+ self,
168
+ in_channels,
169
+ out_channels,
170
+ kernel_size_list,
171
+ groups_list,
172
+ stride=1,
173
+ dilation=1,
174
+ ):
175
+ super(DynamicGroupConv2d, self).__init__()
176
+
177
+ self.in_channels = in_channels
178
+ self.out_channels = out_channels
179
+ self.kernel_size_list = kernel_size_list
180
+ self.groups_list = groups_list
181
+ self.stride = stride
182
+ self.dilation = dilation
183
+
184
+ self.conv = nn.Conv2d(
185
+ self.in_channels,
186
+ self.out_channels,
187
+ max(self.kernel_size_list),
188
+ self.stride,
189
+ groups=min(self.groups_list),
190
+ bias=False,
191
+ )
192
+
193
+ self.active_kernel_size = max(self.kernel_size_list)
194
+ self.active_groups = min(self.groups_list)
195
+
196
+ def get_active_filter(self, kernel_size, groups):
197
+ start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
198
+ filters = self.conv.weight[:, :, start:end, start:end]
199
+
200
+ sub_filters = torch.chunk(filters, groups, dim=0)
201
+ sub_in_channels = self.in_channels // groups
202
+ sub_ratio = filters.size(1) // sub_in_channels
203
+
204
+ filter_crops = []
205
+ for i, sub_filter in enumerate(sub_filters):
206
+ part_id = i % sub_ratio
207
+ start = part_id * sub_in_channels
208
+ filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :])
209
+ filters = torch.cat(filter_crops, dim=0)
210
+ return filters
211
+
212
+ def forward(self, x, kernel_size=None, groups=None):
213
+ if kernel_size is None:
214
+ kernel_size = self.active_kernel_size
215
+ if groups is None:
216
+ groups = self.active_groups
217
+
218
+ filters = self.get_active_filter(kernel_size, groups).contiguous()
219
+ padding = get_same_padding(kernel_size)
220
+ filters = (
221
+ self.conv.weight_standardization(filters)
222
+ if isinstance(self.conv, MyConv2d)
223
+ else filters
224
+ )
225
+ y = F.conv2d(
226
+ x,
227
+ filters,
228
+ None,
229
+ self.stride,
230
+ padding,
231
+ self.dilation,
232
+ groups,
233
+ )
234
+ return y
235
+
236
+
237
+ class DynamicBatchNorm2d(nn.Module):
238
+ SET_RUNNING_STATISTICS = False
239
+
240
+ def __init__(self, max_feature_dim):
241
+ super(DynamicBatchNorm2d, self).__init__()
242
+
243
+ self.max_feature_dim = max_feature_dim
244
+ self.bn = nn.BatchNorm2d(self.max_feature_dim)
245
+
246
+ @staticmethod
247
+ def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
248
+ if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
249
+ return bn(x)
250
+ else:
251
+ exponential_average_factor = 0.0
252
+
253
+ if bn.training and bn.track_running_stats:
254
+ if bn.num_batches_tracked is not None:
255
+ bn.num_batches_tracked += 1
256
+ if bn.momentum is None: # use cumulative moving average
257
+ exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
258
+ else: # use exponential moving average
259
+ exponential_average_factor = bn.momentum
260
+ return F.batch_norm(
261
+ x,
262
+ bn.running_mean[:feature_dim],
263
+ bn.running_var[:feature_dim],
264
+ bn.weight[:feature_dim],
265
+ bn.bias[:feature_dim],
266
+ bn.training or not bn.track_running_stats,
267
+ exponential_average_factor,
268
+ bn.eps,
269
+ )
270
+
271
+ def forward(self, x):
272
+ feature_dim = x.size(1)
273
+ y = self.bn_forward(x, self.bn, feature_dim)
274
+ return y
275
+
276
+
277
+ class DynamicGroupNorm(nn.GroupNorm):
278
+ def __init__(
279
+ self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
280
+ ):
281
+ super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
282
+ self.channel_per_group = channel_per_group
283
+
284
+ def forward(self, x):
285
+ n_channels = x.size(1)
286
+ n_groups = n_channels // self.channel_per_group
287
+ return F.group_norm(
288
+ x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
289
+ )
290
+
291
+ @property
292
+ def bn(self):
293
+ return self
294
+
295
+
296
+ class DynamicSE(SEModule):
297
+ def __init__(self, max_channel):
298
+ super(DynamicSE, self).__init__(max_channel)
299
+
300
+ def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
301
+ if groups is None or groups == 1:
302
+ return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
303
+ else:
304
+ assert in_channel % groups == 0
305
+ sub_in_channels = in_channel // groups
306
+ sub_filters = torch.chunk(
307
+ self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1
308
+ )
309
+ return torch.cat(
310
+ [sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters],
311
+ dim=1,
312
+ )
313
+
314
+ def get_active_reduce_bias(self, num_mid):
315
+ return (
316
+ self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
317
+ )
318
+
319
+ def get_active_expand_weight(self, num_mid, in_channel, groups=None):
320
+ if groups is None or groups == 1:
321
+ return self.fc.expand.weight[:in_channel, :num_mid, :, :]
322
+ else:
323
+ assert in_channel % groups == 0
324
+ sub_in_channels = in_channel // groups
325
+ sub_filters = torch.chunk(
326
+ self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0
327
+ )
328
+ return torch.cat(
329
+ [sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters],
330
+ dim=0,
331
+ )
332
+
333
+ def get_active_expand_bias(self, in_channel, groups=None):
334
+ if groups is None or groups == 1:
335
+ return (
336
+ self.fc.expand.bias[:in_channel]
337
+ if self.fc.expand.bias is not None
338
+ else None
339
+ )
340
+ else:
341
+ assert in_channel % groups == 0
342
+ sub_in_channels = in_channel // groups
343
+ sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
344
+ return torch.cat(
345
+ [sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0
346
+ )
347
+
348
+ def forward(self, x, groups=None):
349
+ in_channel = x.size(1)
350
+ num_mid = make_divisible(
351
+ in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
352
+ )
353
+
354
+ y = x.mean(3, keepdim=True).mean(2, keepdim=True)
355
+ # reduce
356
+ reduce_filter = self.get_active_reduce_weight(
357
+ num_mid, in_channel, groups=groups
358
+ ).contiguous()
359
+ reduce_bias = self.get_active_reduce_bias(num_mid)
360
+ y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
361
+ # relu
362
+ y = self.fc.relu(y)
363
+ # expand
364
+ expand_filter = self.get_active_expand_weight(
365
+ num_mid, in_channel, groups=groups
366
+ ).contiguous()
367
+ expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
368
+ y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
369
+ # hard sigmoid
370
+ y = self.fc.h_sigmoid(y)
371
+
372
+ return x * y
373
+
374
+
375
+ class DynamicLinear(nn.Module):
376
+ def __init__(self, max_in_features, max_out_features, bias=True):
377
+ super(DynamicLinear, self).__init__()
378
+
379
+ self.max_in_features = max_in_features
380
+ self.max_out_features = max_out_features
381
+ self.bias = bias
382
+
383
+ self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)
384
+
385
+ self.active_out_features = self.max_out_features
386
+
387
+ def get_active_weight(self, out_features, in_features):
388
+ return self.linear.weight[:out_features, :in_features]
389
+
390
+ def get_active_bias(self, out_features):
391
+ return self.linear.bias[:out_features] if self.bias else None
392
+
393
+ def forward(self, x, out_features=None):
394
+ if out_features is None:
395
+ out_features = self.active_out_features
396
+
397
+ in_features = x.size(1)
398
+ weight = self.get_active_weight(out_features, in_features).contiguous()
399
+ bias = self.get_active_bias(out_features)
400
+ y = F.linear(x, weight, bias)
401
+ return y
proard/classification/elastic_nn/networks/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .dyn_proxyless import DYNProxylessNASNets,DYNProxylessNASNets_Cifar
6
+ from .dyn_mbv3 import DYNMobileNetV3,DYNMobileNetV3_Cifar
7
+ from .dyn_resnets import DYNResNets,DYNResNets_Cifar
proard/classification/elastic_nn/networks/dyn_mbv3.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import random
7
+
8
+ from proard.classification.elastic_nn.modules.dynamic_layers import (
9
+ DynamicMBConvLayer,
10
+ )
11
+ from proard.utils.layers import (
12
+ ConvLayer,
13
+ IdentityLayer,
14
+ LinearLayer,
15
+ MBConvLayer,
16
+ ResidualBlock,
17
+ )
18
+ from proard.classification.networks import MobileNetV3,MobileNetV3_Cifar
19
+ from proard.utils import make_divisible, val2list, MyNetwork
20
+
21
+ __all__ = ["DYNMobileNetV3","DYNMobileNetV3_Cifar"]
22
+
23
+
24
+ class DYNMobileNetV3(MobileNetV3):
25
+ def __init__(
26
+ self,
27
+ n_classes=1000,
28
+ bn_param=(0.1, 1e-5),
29
+ dropout_rate=0.1,
30
+ base_stage_width=None,
31
+ width_mult=1.0,
32
+ ks_list=3,
33
+ expand_ratio_list=6,
34
+ depth_list=4,
35
+ ):
36
+
37
+ self.width_mult = width_mult
38
+ self.ks_list = val2list(ks_list, 1)
39
+ self.expand_ratio_list = val2list(expand_ratio_list, 1)
40
+ self.depth_list = val2list(depth_list, 1)
41
+
42
+ self.ks_list.sort()
43
+ self.expand_ratio_list.sort()
44
+ self.depth_list.sort()
45
+
46
+ base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]
47
+
48
+ final_expand_width = make_divisible(
49
+ base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
50
+ )
51
+ last_channel = make_divisible(
52
+ base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
53
+ )
54
+
55
+ stride_stages = [1, 2, 2, 2, 1, 2]
56
+ act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
57
+ se_stages = [False, False, True, False, True, True]
58
+ n_block_list = [1] + [max(self.depth_list)] * 5
59
+ width_list = []
60
+ for base_width in base_stage_width[:-2]:
61
+ width = make_divisible(
62
+ base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
63
+ )
64
+ width_list.append(width)
65
+
66
+ input_channel, first_block_dim = width_list[0], width_list[1]
67
+ # first conv layer
68
+ first_conv = ConvLayer(
69
+ 3, input_channel, kernel_size=3, stride=2, act_func="h_swish"
70
+ )
71
+ first_block_conv = MBConvLayer(
72
+ in_channels=input_channel,
73
+ out_channels=first_block_dim,
74
+ kernel_size=3,
75
+ stride=stride_stages[0],
76
+ expand_ratio=1,
77
+ act_func=act_stages[0],
78
+ use_se=se_stages[0],
79
+ )
80
+ first_block = ResidualBlock(
81
+ first_block_conv,
82
+ IdentityLayer(first_block_dim, first_block_dim)
83
+ if input_channel == first_block_dim
84
+ else None,
85
+ )
86
+
87
+ # inverted residual blocks
88
+ self.block_group_info = []
89
+ blocks = [first_block]
90
+ _block_index = 1
91
+ feature_dim = first_block_dim
92
+
93
+ for width, n_block, s, act_func, use_se in zip(
94
+ width_list[2:],
95
+ n_block_list[1:],
96
+ stride_stages[1:],
97
+ act_stages[1:],
98
+ se_stages[1:],
99
+ ):
100
+ self.block_group_info.append([_block_index + i for i in range(n_block)])
101
+ _block_index += n_block
102
+
103
+ output_channel = width
104
+ for i in range(n_block):
105
+ if i == 0:
106
+ stride = s
107
+ else:
108
+ stride = 1
109
+ mobile_inverted_conv = DynamicMBConvLayer(
110
+ in_channel_list=val2list(feature_dim),
111
+ out_channel_list=val2list(output_channel),
112
+ kernel_size_list=ks_list,
113
+ expand_ratio_list=expand_ratio_list,
114
+ stride=stride,
115
+ act_func=act_func,
116
+ use_se=use_se,
117
+ )
118
+ if stride == 1 and feature_dim == output_channel:
119
+ shortcut = IdentityLayer(feature_dim, feature_dim)
120
+ else:
121
+ shortcut = None
122
+ blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
123
+ feature_dim = output_channel
124
+ # final expand layer, feature mix layer & classifier
125
+ final_expand_layer = ConvLayer(
126
+ feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
127
+ )
128
+ feature_mix_layer = ConvLayer(
129
+ final_expand_width,
130
+ last_channel,
131
+ kernel_size=1,
132
+ bias=False,
133
+ use_bn=False,
134
+ act_func="h_swish",
135
+ )
136
+
137
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
138
+
139
+ super(DYNMobileNetV3, self).__init__(
140
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
141
+ )
142
+
143
+ # set bn param
144
+ self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
145
+
146
+ # runtime_depth
147
+ self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
148
+
149
+ """ MyNetwork required methods """
150
+
151
+ @staticmethod
152
+ def name():
153
+ return "DYNMobileNetV3"
154
+
155
+ def forward(self, x):
156
+ # first conv
157
+ x = self.first_conv(x)
158
+ # first block
159
+ x = self.blocks[0](x)
160
+ # blocks
161
+ for stage_id, block_idx in enumerate(self.block_group_info):
162
+ depth = self.runtime_depth[stage_id]
163
+ active_idx = block_idx[:depth]
164
+ for idx in active_idx:
165
+ x = self.blocks[idx](x)
166
+ x = self.final_expand_layer(x)
167
+ x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
168
+ x = self.feature_mix_layer(x)
169
+ x = x.view(x.size(0), -1)
170
+ x = self.classifier(x)
171
+ return x
172
+
173
+ @property
174
+ def module_str(self):
175
+ _str = self.first_conv.module_str + "\n"
176
+ _str += self.blocks[0].module_str + "\n"
177
+
178
+ for stage_id, block_idx in enumerate(self.block_group_info):
179
+ depth = self.runtime_depth[stage_id]
180
+ active_idx = block_idx[:depth]
181
+ for idx in active_idx:
182
+ _str += self.blocks[idx].module_str + "\n"
183
+
184
+ _str += self.final_expand_layer.module_str + "\n"
185
+ _str += self.feature_mix_layer.module_str + "\n"
186
+ _str += self.classifier.module_str + "\n"
187
+ return _str
188
+
189
+ @property
190
+ def config(self):
191
+ return {
192
+ "name": DYNMobileNetV3.__name__,
193
+ "bn": self.get_bn_param(),
194
+ "first_conv": self.first_conv.config,
195
+ "blocks": [block.config for block in self.blocks],
196
+ "final_expand_layer": self.final_expand_layer.config,
197
+ "feature_mix_layer": self.feature_mix_layer.config,
198
+ "classifier": self.classifier.config,
199
+ }
200
+
201
+ @staticmethod
202
+ def build_from_config(config):
203
+ raise ValueError("do not support this function")
204
+
205
+ @property
206
+ def grouped_block_index(self):
207
+ return self.block_group_info
208
+
209
+ def load_state_dict(self, state_dict, **kwargs):
210
+ model_dict = self.state_dict()
211
+ for key in state_dict:
212
+ if ".mobile_inverted_conv." in key:
213
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
214
+ else:
215
+ new_key = key
216
+ if new_key in model_dict:
217
+ pass
218
+ elif ".bn.bn." in new_key:
219
+ new_key = new_key.replace(".bn.bn.", ".bn.")
220
+ elif ".conv.conv.weight" in new_key:
221
+ new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
222
+ elif ".linear.linear." in new_key:
223
+ new_key = new_key.replace(".linear.linear.", ".linear.")
224
+ ##############################################################################
225
+ elif ".linear." in new_key:
226
+ new_key = new_key.replace(".linear.", ".linear.linear.")
227
+ elif "bn." in new_key:
228
+ new_key = new_key.replace("bn.", "bn.bn.")
229
+ elif "conv.weight" in new_key:
230
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
231
+ else:
232
+ raise ValueError(new_key)
233
+ assert new_key in model_dict, "%s" % new_key
234
+ model_dict[new_key] = state_dict[key]
235
+ super(DYNMobileNetV3, self).load_state_dict(model_dict)
236
+
237
+ """ set, sample and get active sub-networks """
238
+
239
+ def set_max_net(self):
240
+ self.set_active_subnet(
241
+ ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
242
+ )
243
+
244
+ def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
245
+ ks = val2list(ks, len(self.blocks) - 1)
246
+ expand_ratio = val2list(e, len(self.blocks) - 1)
247
+ depth = val2list(d, len(self.block_group_info))
248
+
249
+ for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
250
+ if k is not None:
251
+ block.conv.active_kernel_size = k
252
+ if e is not None:
253
+ block.conv.active_expand_ratio = e
254
+
255
+ for i, d in enumerate(depth):
256
+ if d is not None:
257
+ self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
258
+
259
+ def set_constraint(self, include_list, constraint_type="depth"):
260
+ if constraint_type == "depth":
261
+ self.__dict__["_depth_include_list"] = include_list.copy()
262
+ elif constraint_type == "expand_ratio":
263
+ self.__dict__["_expand_include_list"] = include_list.copy()
264
+ elif constraint_type == "kernel_size":
265
+ self.__dict__["_ks_include_list"] = include_list.copy()
266
+ else:
267
+ raise NotImplementedError
268
+
269
+ def clear_constraint(self):
270
+ self.__dict__["_depth_include_list"] = None
271
+ self.__dict__["_expand_include_list"] = None
272
+ self.__dict__["_ks_include_list"] = None
273
+
274
+ def sample_active_subnet(self):
275
+ ks_candidates = (
276
+ self.ks_list
277
+ if self.__dict__.get("_ks_include_list", None) is None
278
+ else self.__dict__["_ks_include_list"]
279
+ )
280
+ expand_candidates = (
281
+ self.expand_ratio_list
282
+ if self.__dict__.get("_expand_include_list", None) is None
283
+ else self.__dict__["_expand_include_list"]
284
+ )
285
+ depth_candidates = (
286
+ self.depth_list
287
+ if self.__dict__.get("_depth_include_list", None) is None
288
+ else self.__dict__["_depth_include_list"]
289
+ )
290
+
291
+ # sample kernel size
292
+ ks_setting = []
293
+ if not isinstance(ks_candidates[0], list):
294
+ ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
295
+ for k_set in ks_candidates:
296
+ k = random.choice(k_set)
297
+ ks_setting.append(k)
298
+
299
+ # sample expand ratio
300
+ expand_setting = []
301
+ if not isinstance(expand_candidates[0], list):
302
+ expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
303
+ for e_set in expand_candidates:
304
+ e = random.choice(e_set)
305
+ expand_setting.append(e)
306
+
307
+ # sample depth
308
+ depth_setting = []
309
+ if not isinstance(depth_candidates[0], list):
310
+ depth_candidates = [
311
+ depth_candidates for _ in range(len(self.block_group_info))
312
+ ]
313
+ for d_set in depth_candidates:
314
+ d = random.choice(d_set)
315
+ depth_setting.append(d)
316
+
317
+ self.set_active_subnet(ks_setting, expand_setting, depth_setting)
318
+
319
+ return {
320
+ "ks": ks_setting,
321
+ "e": expand_setting,
322
+ "d": depth_setting,
323
+ }
324
+
325
+ def get_active_subnet(self, preserve_weight=True):
326
+ first_conv = copy.deepcopy(self.first_conv)
327
+ blocks = [copy.deepcopy(self.blocks[0])]
328
+
329
+ final_expand_layer = copy.deepcopy(self.final_expand_layer)
330
+ feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
331
+ classifier = copy.deepcopy(self.classifier)
332
+
333
+ input_channel = blocks[0].conv.out_channels
334
+ # blocks
335
+ for stage_id, block_idx in enumerate(self.block_group_info):
336
+ depth = self.runtime_depth[stage_id]
337
+ active_idx = block_idx[:depth]
338
+ stage_blocks = []
339
+ for idx in active_idx:
340
+ stage_blocks.append(
341
+ ResidualBlock(
342
+ self.blocks[idx].conv.get_active_subnet(
343
+ input_channel, preserve_weight
344
+ ),
345
+ copy.deepcopy(self.blocks[idx].shortcut),
346
+ )
347
+ )
348
+ input_channel = stage_blocks[-1].conv.out_channels
349
+ blocks += stage_blocks
350
+
351
+ _subnet = MobileNetV3(
352
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
353
+ )
354
+ _subnet.set_bn_param(**self.get_bn_param())
355
+ return _subnet
356
+
357
+ def get_active_net_config(self):
358
+ # first conv
359
+ first_conv_config = self.first_conv.config
360
+ first_block_config = self.blocks[0].config
361
+ final_expand_config = self.final_expand_layer.config
362
+ feature_mix_layer_config = self.feature_mix_layer.config
363
+ classifier_config = self.classifier.config
364
+
365
+ block_config_list = [first_block_config]
366
+ input_channel = first_block_config["conv"]["out_channels"]
367
+ for stage_id, block_idx in enumerate(self.block_group_info):
368
+ depth = self.runtime_depth[stage_id]
369
+ active_idx = block_idx[:depth]
370
+ stage_blocks = []
371
+ for idx in active_idx:
372
+ stage_blocks.append(
373
+ {
374
+ "name": ResidualBlock.__name__,
375
+ "conv": self.blocks[idx].conv.get_active_subnet_config(
376
+ input_channel
377
+ ),
378
+ "shortcut": self.blocks[idx].shortcut.config
379
+ if self.blocks[idx].shortcut is not None
380
+ else None,
381
+ }
382
+ )
383
+ input_channel = self.blocks[idx].conv.active_out_channel
384
+ block_config_list += stage_blocks
385
+
386
+ return {
387
+ "name": MobileNetV3.__name__,
388
+ "bn": self.get_bn_param(),
389
+ "first_conv": first_conv_config,
390
+ "blocks": block_config_list,
391
+ "final_expand_layer": final_expand_config,
392
+ "feature_mix_layer": feature_mix_layer_config,
393
+ "classifier": classifier_config,
394
+ }
395
+
396
+ """ Width Related Methods """
397
+
398
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
399
+ for block in self.blocks[1:]:
400
+ block.conv.re_organize_middle_weights(expand_ratio_stage)
401
+
402
+
403
+
404
+ class DYNMobileNetV3_Cifar(MobileNetV3_Cifar):
405
+ def __init__(
406
+ self,
407
+ n_classes=10,
408
+ bn_param=(0.1, 1e-5),
409
+ dropout_rate=0.1,
410
+ base_stage_width=None,
411
+ width_mult=1.0,
412
+ ks_list=3,
413
+ expand_ratio_list=6,
414
+ depth_list=4,
415
+ ):
416
+
417
+ self.width_mult = width_mult
418
+ self.ks_list = val2list(ks_list, 1)
419
+ self.expand_ratio_list = val2list(expand_ratio_list, 1)
420
+ self.depth_list = val2list(depth_list, 1)
421
+
422
+ self.ks_list.sort()
423
+ self.expand_ratio_list.sort()
424
+ self.depth_list.sort()
425
+
426
+ base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]
427
+
428
+ final_expand_width = make_divisible(
429
+ base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
430
+ )
431
+ last_channel = make_divisible(
432
+ base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
433
+ )
434
+
435
+ stride_stages = [1, 1, 2, 2, 1, 2]
436
+ act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
437
+ se_stages = [False, False, True, False, True, True]
438
+ n_block_list = [1] + [max(self.depth_list)] * 5
439
+ width_list = []
440
+ for base_width in base_stage_width[:-2]:
441
+ width = make_divisible(
442
+ base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
443
+ )
444
+ width_list.append(width)
445
+
446
+ input_channel, first_block_dim = width_list[0], width_list[1]
447
+ # first conv layer
448
+ first_conv = ConvLayer(
449
+ 3, input_channel, kernel_size=3, stride=1, act_func="h_swish"
450
+ )
451
+ first_block_conv = MBConvLayer(
452
+ in_channels=input_channel,
453
+ out_channels=first_block_dim,
454
+ kernel_size=3,
455
+ stride=stride_stages[0],
456
+ expand_ratio=1,
457
+ act_func=act_stages[0],
458
+ use_se=se_stages[0],
459
+ )
460
+ first_block = ResidualBlock(
461
+ first_block_conv,
462
+ IdentityLayer(first_block_dim, first_block_dim)
463
+ if input_channel == first_block_dim
464
+ else None,
465
+ )
466
+
467
+ # inverted residual blocks
468
+ self.block_group_info = []
469
+ blocks = [first_block]
470
+ _block_index = 1
471
+ feature_dim = first_block_dim
472
+
473
+ for width, n_block, s, act_func, use_se in zip(
474
+ width_list[2:],
475
+ n_block_list[1:],
476
+ stride_stages[1:],
477
+ act_stages[1:],
478
+ se_stages[1:],
479
+ ):
480
+ self.block_group_info.append([_block_index + i for i in range(n_block)])
481
+ _block_index += n_block
482
+
483
+ output_channel = width
484
+ for i in range(n_block):
485
+ if i == 0:
486
+ stride = s
487
+ else:
488
+ stride = 1
489
+ mobile_inverted_conv = DynamicMBConvLayer(
490
+ in_channel_list=val2list(feature_dim),
491
+ out_channel_list=val2list(output_channel),
492
+ kernel_size_list=ks_list,
493
+ expand_ratio_list=expand_ratio_list,
494
+ stride=stride,
495
+ act_func=act_func,
496
+ use_se=use_se,
497
+ )
498
+ if stride == 1 and feature_dim == output_channel:
499
+ shortcut = IdentityLayer(feature_dim, feature_dim)
500
+ else:
501
+ shortcut = None
502
+ blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
503
+ feature_dim = output_channel
504
+ # final expand layer, feature mix layer & classifier
505
+ final_expand_layer = ConvLayer(
506
+ feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
507
+ )
508
+ feature_mix_layer = ConvLayer(
509
+ final_expand_width,
510
+ last_channel,
511
+ kernel_size=1,
512
+ bias=False,
513
+ use_bn=False,
514
+ act_func="h_swish",
515
+ )
516
+
517
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
518
+
519
+ super(DYNMobileNetV3_Cifar, self).__init__(
520
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
521
+ )
522
+
523
+ # set bn param
524
+ self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
525
+
526
+ # runtime_depth
527
+ self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
528
+
529
+ """ MyNetwork required methods """
530
+
531
+ @staticmethod
532
+ def name():
533
+ return "DYNMobileNetV3_Cifar"
534
+
535
+ def forward(self, x):
536
+ # first conv
537
+ x = self.first_conv(x)
538
+ # first block
539
+ x = self.blocks[0](x)
540
+ # blocks
541
+ for stage_id, block_idx in enumerate(self.block_group_info):
542
+ depth = self.runtime_depth[stage_id]
543
+ active_idx = block_idx[:depth]
544
+ for idx in active_idx:
545
+ x = self.blocks[idx](x)
546
+ x = self.final_expand_layer(x)
547
+ x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
548
+ x = self.feature_mix_layer(x)
549
+ x = x.view(x.size(0), -1)
550
+ x = self.classifier(x)
551
+ return x
552
+
553
+ @property
554
+ def module_str(self):
555
+ _str = self.first_conv.module_str + "\n"
556
+ _str += self.blocks[0].module_str + "\n"
557
+
558
+ for stage_id, block_idx in enumerate(self.block_group_info):
559
+ depth = self.runtime_depth[stage_id]
560
+ active_idx = block_idx[:depth]
561
+ for idx in active_idx:
562
+ _str += self.blocks[idx].module_str + "\n"
563
+
564
+ _str += self.final_expand_layer.module_str + "\n"
565
+ _str += self.feature_mix_layer.module_str + "\n"
566
+ _str += self.classifier.module_str + "\n"
567
+ return _str
568
+
569
+ @property
570
+ def config(self):
571
+ return {
572
+ "name": DYNMobileNetV3_Cifar.__name__,
573
+ "bn": self.get_bn_param(),
574
+ "first_conv": self.first_conv.config,
575
+ "blocks": [block.config for block in self.blocks],
576
+ "final_expand_layer": self.final_expand_layer.config,
577
+ "feature_mix_layer": self.feature_mix_layer.config,
578
+ "classifier": self.classifier.config,
579
+ }
580
+
581
+ @staticmethod
582
+ def build_from_config(config):
583
+ raise ValueError("do not support this function")
584
+
585
+ @property
586
+ def grouped_block_index(self):
587
+ return self.block_group_info
588
+
589
+ def load_state_dict(self, state_dict, **kwargs):
590
+ model_dict = self.state_dict()
591
+ for key in state_dict:
592
+ if ".mobile_inverted_conv." in key:
593
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
594
+ else:
595
+ new_key = key
596
+ if new_key in model_dict:
597
+ pass
598
+ elif ".bn.bn." in new_key:
599
+ new_key = new_key.replace(".bn.bn.", ".bn.")
600
+ elif ".conv.conv.weight" in new_key:
601
+ new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
602
+ elif ".linear.linear." in new_key:
603
+ new_key = new_key.replace(".linear.linear.", ".linear.")
604
+ ##############################################################################
605
+ elif ".linear." in new_key:
606
+ new_key = new_key.replace(".linear.", ".linear.linear.")
607
+ elif "bn." in new_key:
608
+ new_key = new_key.replace("bn.", "bn.bn.")
609
+ elif "conv.weight" in new_key:
610
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
611
+ else:
612
+ raise ValueError(new_key)
613
+ assert new_key in model_dict, "%s" % new_key
614
+ model_dict[new_key] = state_dict[key]
615
+ super(DYNMobileNetV3_Cifar, self).load_state_dict(model_dict)
616
+
617
+ """ set, sample and get active sub-networks """
618
+
619
+ def set_max_net(self):
620
+ self.set_active_subnet(
621
+ ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
622
+ )
623
+
624
+ def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
625
+ ks = val2list(ks, len(self.blocks) - 1)
626
+ expand_ratio = val2list(e, len(self.blocks) - 1)
627
+ depth = val2list(d, len(self.block_group_info))
628
+
629
+ for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
630
+ if k is not None:
631
+ block.conv.active_kernel_size = k
632
+ if e is not None:
633
+ block.conv.active_expand_ratio = e
634
+
635
+ for i, d in enumerate(depth):
636
+ if d is not None:
637
+ self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
638
+
639
+ def set_constraint(self, include_list, constraint_type="depth"):
640
+ if constraint_type == "depth":
641
+ self.__dict__["_depth_include_list"] = include_list.copy()
642
+ elif constraint_type == "expand_ratio":
643
+ self.__dict__["_expand_include_list"] = include_list.copy()
644
+ elif constraint_type == "kernel_size":
645
+ self.__dict__["_ks_include_list"] = include_list.copy()
646
+ else:
647
+ raise NotImplementedError
648
+
649
+ def clear_constraint(self):
650
+ self.__dict__["_depth_include_list"] = None
651
+ self.__dict__["_expand_include_list"] = None
652
+ self.__dict__["_ks_include_list"] = None
653
+
654
+ def sample_active_subnet(self):
655
+ ks_candidates = (
656
+ self.ks_list
657
+ if self.__dict__.get("_ks_include_list", None) is None
658
+ else self.__dict__["_ks_include_list"]
659
+ )
660
+ expand_candidates = (
661
+ self.expand_ratio_list
662
+ if self.__dict__.get("_expand_include_list", None) is None
663
+ else self.__dict__["_expand_include_list"]
664
+ )
665
+ depth_candidates = (
666
+ self.depth_list
667
+ if self.__dict__.get("_depth_include_list", None) is None
668
+ else self.__dict__["_depth_include_list"]
669
+ )
670
+
671
+ # sample kernel size
672
+ ks_setting = []
673
+ if not isinstance(ks_candidates[0], list):
674
+ ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
675
+ for k_set in ks_candidates:
676
+ k = random.choice(k_set)
677
+ ks_setting.append(k)
678
+
679
+ # sample expand ratio
680
+ expand_setting = []
681
+ if not isinstance(expand_candidates[0], list):
682
+ expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
683
+ for e_set in expand_candidates:
684
+ e = random.choice(e_set)
685
+ expand_setting.append(e)
686
+
687
+ # sample depth
688
+ depth_setting = []
689
+ if not isinstance(depth_candidates[0], list):
690
+ depth_candidates = [
691
+ depth_candidates for _ in range(len(self.block_group_info))
692
+ ]
693
+ for d_set in depth_candidates:
694
+ d = random.choice(d_set)
695
+ depth_setting.append(d)
696
+
697
+ self.set_active_subnet(ks_setting, expand_setting, depth_setting)
698
+
699
+ return {
700
+ "ks": ks_setting,
701
+ "e": expand_setting,
702
+ "d": depth_setting,
703
+ }
704
+
705
+ def get_active_subnet(self, preserve_weight=True):
706
+ first_conv = copy.deepcopy(self.first_conv)
707
+ blocks = [copy.deepcopy(self.blocks[0])]
708
+
709
+ final_expand_layer = copy.deepcopy(self.final_expand_layer)
710
+ feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
711
+ classifier = copy.deepcopy(self.classifier)
712
+
713
+ input_channel = blocks[0].conv.out_channels
714
+ # blocks
715
+ for stage_id, block_idx in enumerate(self.block_group_info):
716
+ depth = self.runtime_depth[stage_id]
717
+ active_idx = block_idx[:depth]
718
+ stage_blocks = []
719
+ for idx in active_idx:
720
+ stage_blocks.append(
721
+ ResidualBlock(
722
+ self.blocks[idx].conv.get_active_subnet(
723
+ input_channel, preserve_weight
724
+ ),
725
+ copy.deepcopy(self.blocks[idx].shortcut),
726
+ )
727
+ )
728
+ input_channel = stage_blocks[-1].conv.out_channels
729
+ blocks += stage_blocks
730
+
731
+ _subnet = MobileNetV3_Cifar(
732
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
733
+ )
734
+ _subnet.set_bn_param(**self.get_bn_param())
735
+ return _subnet
736
+
737
+ def get_active_net_config(self):
738
+ # first conv
739
+ first_conv_config = self.first_conv.config
740
+ first_block_config = self.blocks[0].config
741
+ final_expand_config = self.final_expand_layer.config
742
+ feature_mix_layer_config = self.feature_mix_layer.config
743
+ classifier_config = self.classifier.config
744
+
745
+ block_config_list = [first_block_config]
746
+ input_channel = first_block_config["conv"]["out_channels"]
747
+ for stage_id, block_idx in enumerate(self.block_group_info):
748
+ depth = self.runtime_depth[stage_id]
749
+ active_idx = block_idx[:depth]
750
+ stage_blocks = []
751
+ for idx in active_idx:
752
+ stage_blocks.append(
753
+ {
754
+ "name": ResidualBlock.__name__,
755
+ "conv": self.blocks[idx].conv.get_active_subnet_config(
756
+ input_channel
757
+ ),
758
+ "shortcut": self.blocks[idx].shortcut.config
759
+ if self.blocks[idx].shortcut is not None
760
+ else None,
761
+ }
762
+ )
763
+ input_channel = self.blocks[idx].conv.active_out_channel
764
+ block_config_list += stage_blocks
765
+
766
+ return {
767
+ "name": MobileNetV3_Cifar.__name__,
768
+ "bn": self.get_bn_param(),
769
+ "first_conv": first_conv_config,
770
+ "blocks": block_config_list,
771
+ "final_expand_layer": final_expand_config,
772
+ "feature_mix_layer": feature_mix_layer_config,
773
+ "classifier": classifier_config,
774
+ }
775
+
776
+ """ Width Related Methods """
777
+
778
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
779
+ for block in self.blocks[1:]:
780
+ block.conv.re_organize_middle_weights(expand_ratio_stage)
proard/classification/elastic_nn/networks/dyn_proxyless.py ADDED
@@ -0,0 +1,774 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import random
7
+
8
+ from proard.utils import make_divisible, val2list, MyNetwork
9
+ from proard.classification.elastic_nn.modules import DynamicMBConvLayer
10
+ from proard.utils.layers import (
11
+ ConvLayer,
12
+ IdentityLayer,
13
+ LinearLayer,
14
+ MBConvLayer,
15
+ ResidualBlock,
16
+ )
17
+ from proard.classification.networks.proxyless_nets import ProxylessNASNets,ProxylessNASNets_Cifar
18
+
19
+ __all__ = ["DYNProxylessNASNets","DYNProxylessNASNets_Cifar"]
20
+
21
+
22
+ class DYNProxylessNASNets(ProxylessNASNets):
23
+ def __init__(
24
+ self,
25
+ n_classes=1000,
26
+ bn_param=(0.1, 1e-3),
27
+ dropout_rate=0.1,
28
+ base_stage_width=None,
29
+ width_mult=1.0,
30
+ ks_list=3,
31
+ expand_ratio_list=6,
32
+ depth_list=4,
33
+ ):
34
+
35
+ self.width_mult = width_mult
36
+ self.ks_list = val2list(ks_list, 1)
37
+ self.expand_ratio_list = val2list(expand_ratio_list, 1)
38
+ self.depth_list = val2list(depth_list, 1)
39
+
40
+ self.ks_list.sort()
41
+ self.expand_ratio_list.sort()
42
+ self.depth_list.sort()
43
+
44
+ if base_stage_width == "google":
45
+ # MobileNetV2 Stage Width
46
+ base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
47
+ else:
48
+ # ProxylessNAS Stage Width
49
+ base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]
50
+
51
+ input_channel = make_divisible(
52
+ base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
53
+ )
54
+ first_block_width = make_divisible(
55
+ base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
56
+ )
57
+ last_channel = make_divisible(
58
+ base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
59
+ )
60
+
61
+ # first conv layer
62
+ first_conv = ConvLayer(
63
+ 3,
64
+ input_channel,
65
+ kernel_size=3,
66
+ stride=2,
67
+ use_bn=True,
68
+ act_func="relu6",
69
+ ops_order="weight_bn_act",
70
+ )
71
+ # first block
72
+ first_block_conv = MBConvLayer(
73
+ in_channels=input_channel,
74
+ out_channels=first_block_width,
75
+ kernel_size=3,
76
+ stride=1,
77
+ expand_ratio=1,
78
+ act_func="relu6",
79
+ )
80
+ first_block = ResidualBlock(first_block_conv, None)
81
+
82
+ input_channel = first_block_width
83
+ # inverted residual blocks
84
+ self.block_group_info = []
85
+ blocks = [first_block]
86
+ _block_index = 1
87
+
88
+ stride_stages = [2, 2, 2, 1, 2, 1]
89
+ n_block_list = [max(self.depth_list)] * 5 + [1]
90
+
91
+ width_list = []
92
+ for base_width in base_stage_width[2:-1]:
93
+ width = make_divisible(
94
+ base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
95
+ )
96
+ width_list.append(width)
97
+
98
+ for width, n_block, s in zip(width_list, n_block_list, stride_stages):
99
+ self.block_group_info.append([_block_index + i for i in range(n_block)])
100
+ _block_index += n_block
101
+
102
+ output_channel = width
103
+ for i in range(n_block):
104
+ if i == 0:
105
+ stride = s
106
+ else:
107
+ stride = 1
108
+
109
+ mobile_inverted_conv = DynamicMBConvLayer(
110
+ in_channel_list=val2list(input_channel, 1),
111
+ out_channel_list=val2list(output_channel, 1),
112
+ kernel_size_list=ks_list,
113
+ expand_ratio_list=expand_ratio_list,
114
+ stride=stride,
115
+ act_func="relu6",
116
+ )
117
+
118
+ if stride == 1 and input_channel == output_channel:
119
+ shortcut = IdentityLayer(input_channel, input_channel)
120
+ else:
121
+ shortcut = None
122
+
123
+ mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)
124
+
125
+ blocks.append(mb_inverted_block)
126
+ input_channel = output_channel
127
+ # 1x1_conv before global average pooling
128
+ feature_mix_layer = ConvLayer(
129
+ input_channel,
130
+ last_channel,
131
+ kernel_size=1,
132
+ use_bn=True,
133
+ act_func="relu6",
134
+ )
135
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
136
+
137
+ super(DYNProxylessNASNets, self).__init__(
138
+ first_conv, blocks, feature_mix_layer, classifier
139
+ )
140
+
141
+ # set bn param
142
+ self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
143
+
144
+ # runtime_depth
145
+ self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
146
+
147
+ """ MyNetwork required methods """
148
+
149
+ @staticmethod
150
+ def name():
151
+ return "DYNProxylessNASNets"
152
+
153
+ def forward(self, x):
154
+ # first conv
155
+ x = self.first_conv(x)
156
+ # first block
157
+ x = self.blocks[0](x)
158
+
159
+ # blocks
160
+ for stage_id, block_idx in enumerate(self.block_group_info):
161
+ depth = self.runtime_depth[stage_id]
162
+ active_idx = block_idx[:depth]
163
+ for idx in active_idx:
164
+ x = self.blocks[idx](x)
165
+
166
+ # feature_mix_layer
167
+ x = self.feature_mix_layer(x)
168
+ x = x.mean(3).mean(2)
169
+
170
+ x = self.classifier(x)
171
+ return x
172
+
173
+ @property
174
+ def module_str(self):
175
+ _str = self.first_conv.module_str + "\n"
176
+ _str += self.blocks[0].module_str + "\n"
177
+
178
+ for stage_id, block_idx in enumerate(self.block_group_info):
179
+ depth = self.runtime_depth[stage_id]
180
+ active_idx = block_idx[:depth]
181
+ for idx in active_idx:
182
+ _str += self.blocks[idx].module_str + "\n"
183
+ _str += self.feature_mix_layer.module_str + "\n"
184
+ _str += self.classifier.module_str + "\n"
185
+ return _str
186
+
187
+ @property
188
+ def config(self):
189
+ return {
190
+ "name": DYNProxylessNASNets.__name__,
191
+ "bn": self.get_bn_param(),
192
+ "first_conv": self.first_conv.config,
193
+ "blocks": [block.config for block in self.blocks],
194
+ "feature_mix_layer": None
195
+ if self.feature_mix_layer is None
196
+ else self.feature_mix_layer.config,
197
+ "classifier": self.classifier.config,
198
+ }
199
+
200
+ @staticmethod
201
+ def build_from_config(config):
202
+ raise ValueError("do not support this function")
203
+
204
+ @property
205
+ def grouped_block_index(self):
206
+ return self.block_group_info
207
+
208
+ def load_state_dict(self, state_dict, **kwargs):
209
+ model_dict = self.state_dict()
210
+ for key in state_dict:
211
+ if ".mobile_inverted_conv." in key:
212
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
213
+ else:
214
+ new_key = key
215
+ if new_key in model_dict:
216
+ pass
217
+ elif ".bn.bn." in new_key:
218
+ new_key = new_key.replace(".bn.bn.", ".bn.")
219
+ elif ".conv.conv.weight" in new_key:
220
+ new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
221
+ elif ".linear.linear." in new_key:
222
+ new_key = new_key.replace(".linear.linear.", ".linear.")
223
+ ##############################################################################
224
+ elif ".linear." in new_key:
225
+ new_key = new_key.replace(".linear.", ".linear.linear.")
226
+ elif "bn." in new_key:
227
+ new_key = new_key.replace("bn.", "bn.bn.")
228
+ elif "conv.weight" in new_key:
229
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
230
+ else:
231
+ raise ValueError(new_key)
232
+ assert new_key in model_dict, "%s" % new_key
233
+ model_dict[new_key] = state_dict[key]
234
+ super(DYNProxylessNASNets, self).load_state_dict(model_dict)
235
+
236
+ """ set, sample and get active sub-networks """
237
+
238
+ def set_max_net(self):
239
+ self.set_active_subnet(
240
+ ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
241
+ )
242
+
243
+ def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
244
+ ks = val2list(ks, len(self.blocks) - 1)
245
+ expand_ratio = val2list(e, len(self.blocks) - 1)
246
+ depth = val2list(d, len(self.block_group_info))
247
+
248
+ for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
249
+ if k is not None:
250
+ block.conv.active_kernel_size = k
251
+ if e is not None:
252
+ block.conv.active_expand_ratio = e
253
+
254
+ for i, d in enumerate(depth):
255
+ if d is not None:
256
+ self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
257
+
258
+ def set_constraint(self, include_list, constraint_type="depth"):
259
+ if constraint_type == "depth":
260
+ self.__dict__["_depth_include_list"] = include_list.copy()
261
+ elif constraint_type == "expand_ratio":
262
+ self.__dict__["_expand_include_list"] = include_list.copy()
263
+ elif constraint_type == "kernel_size":
264
+ self.__dict__["_ks_include_list"] = include_list.copy()
265
+ else:
266
+ raise NotImplementedError
267
+
268
+ def clear_constraint(self):
269
+ self.__dict__["_depth_include_list"] = None
270
+ self.__dict__["_expand_include_list"] = None
271
+ self.__dict__["_ks_include_list"] = None
272
+
273
+ def sample_active_subnet(self):
274
+ ks_candidates = (
275
+ self.ks_list
276
+ if self.__dict__.get("_ks_include_list", None) is None
277
+ else self.__dict__["_ks_include_list"]
278
+ )
279
+ expand_candidates = (
280
+ self.expand_ratio_list
281
+ if self.__dict__.get("_expand_include_list", None) is None
282
+ else self.__dict__["_expand_include_list"]
283
+ )
284
+ depth_candidates = (
285
+ self.depth_list
286
+ if self.__dict__.get("_depth_include_list", None) is None
287
+ else self.__dict__["_depth_include_list"]
288
+ )
289
+
290
+ # sample kernel size
291
+ ks_setting = []
292
+ if not isinstance(ks_candidates[0], list):
293
+ ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
294
+ for k_set in ks_candidates:
295
+ k = random.choice(k_set)
296
+ ks_setting.append(k)
297
+
298
+ # sample expand ratio
299
+ expand_setting = []
300
+ if not isinstance(expand_candidates[0], list):
301
+ expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
302
+ for e_set in expand_candidates:
303
+ e = random.choice(e_set)
304
+ expand_setting.append(e)
305
+
306
+ # sample depth
307
+ depth_setting = []
308
+ if not isinstance(depth_candidates[0], list):
309
+ depth_candidates = [
310
+ depth_candidates for _ in range(len(self.block_group_info))
311
+ ]
312
+ for d_set in depth_candidates:
313
+ d = random.choice(d_set)
314
+ depth_setting.append(d)
315
+
316
+ depth_setting[-1] = 1
317
+ self.set_active_subnet(ks_setting, expand_setting, depth_setting)
318
+
319
+ return {
320
+ "ks": ks_setting,
321
+ "e": expand_setting,
322
+ "d": depth_setting,
323
+ }
324
+
325
+ def get_active_subnet(self, preserve_weight=True):
326
+ first_conv = copy.deepcopy(self.first_conv)
327
+ blocks = [copy.deepcopy(self.blocks[0])]
328
+ feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
329
+ classifier = copy.deepcopy(self.classifier)
330
+
331
+ input_channel = blocks[0].conv.out_channels
332
+ # blocks
333
+ for stage_id, block_idx in enumerate(self.block_group_info):
334
+ depth = self.runtime_depth[stage_id]
335
+ active_idx = block_idx[:depth]
336
+ stage_blocks = []
337
+ for idx in active_idx:
338
+ stage_blocks.append(
339
+ ResidualBlock(
340
+ self.blocks[idx].conv.get_active_subnet(
341
+ input_channel, preserve_weight
342
+ ),
343
+ copy.deepcopy(self.blocks[idx].shortcut),
344
+ )
345
+ )
346
+ input_channel = stage_blocks[-1].conv.out_channels
347
+ blocks += stage_blocks
348
+
349
+ _subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
350
+ _subnet.set_bn_param(**self.get_bn_param())
351
+ return _subnet
352
+
353
+ def get_active_net_config(self):
354
+ first_conv_config = self.first_conv.config
355
+ first_block_config = self.blocks[0].config
356
+ feature_mix_layer_config = self.feature_mix_layer.config
357
+ classifier_config = self.classifier.config
358
+
359
+ block_config_list = [first_block_config]
360
+ input_channel = first_block_config["conv"]["out_channels"]
361
+ for stage_id, block_idx in enumerate(self.block_group_info):
362
+ depth = self.runtime_depth[stage_id]
363
+ active_idx = block_idx[:depth]
364
+ stage_blocks = []
365
+ for idx in active_idx:
366
+ stage_blocks.append(
367
+ {
368
+ "name": ResidualBlock.__name__,
369
+ "conv": self.blocks[idx].conv.get_active_subnet_config(
370
+ input_channel
371
+ ),
372
+ "shortcut": self.blocks[idx].shortcut.config
373
+ if self.blocks[idx].shortcut is not None
374
+ else None,
375
+ }
376
+ )
377
+ try:
378
+ input_channel = self.blocks[idx].conv.active_out_channel
379
+ except Exception:
380
+ input_channel = self.blocks[idx].conv.out_channels
381
+ block_config_list += stage_blocks
382
+
383
+ return {
384
+ "name": ProxylessNASNets.__name__,
385
+ "bn": self.get_bn_param(),
386
+ "first_conv": first_conv_config,
387
+ "blocks": block_config_list,
388
+ "feature_mix_layer": feature_mix_layer_config,
389
+ "classifier": classifier_config,
390
+ }
391
+
392
+ """ Width Related Methods """
393
+
394
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
395
+ for block in self.blocks[1:]:
396
+ block.conv.re_organize_middle_weights(expand_ratio_stage)
397
+
398
+
399
+
400
+ class DYNProxylessNASNets_Cifar(ProxylessNASNets_Cifar):
401
+ def __init__(
402
+ self,
403
+ n_classes=10,
404
+ bn_param=(0.1, 1e-3),
405
+ dropout_rate=0.1,
406
+ base_stage_width=None,
407
+ width_mult=1.0,
408
+ ks_list=3,
409
+ expand_ratio_list=6,
410
+ depth_list=4,
411
+ ):
412
+
413
+ self.width_mult = width_mult
414
+ self.ks_list = val2list(ks_list, 1)
415
+ self.expand_ratio_list = val2list(expand_ratio_list, 1)
416
+ self.depth_list = val2list(depth_list, 1)
417
+
418
+ self.ks_list.sort()
419
+ self.expand_ratio_list.sort()
420
+ self.depth_list.sort()
421
+
422
+ if base_stage_width == "MBV2":
423
+ # MobileNetV2 Stage Width
424
+ base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
425
+ else:
426
+ # ProxylessNAS Stage Width
427
+ base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]
428
+
429
+ input_channel = make_divisible(
430
+ base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
431
+ )
432
+ first_block_width = make_divisible(
433
+ base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
434
+ )
435
+ last_channel = make_divisible(
436
+ base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
437
+ )
438
+
439
+ # first conv layer
440
+ first_conv = ConvLayer(
441
+ 3,
442
+ input_channel,
443
+ kernel_size=3,
444
+ stride=1,
445
+ use_bn=True,
446
+ act_func="relu6",
447
+ ops_order="weight_bn_act",
448
+ )
449
+ # first block
450
+ first_block_conv = MBConvLayer(
451
+ in_channels=input_channel,
452
+ out_channels=first_block_width,
453
+ kernel_size=3,
454
+ stride=1,
455
+ expand_ratio=1,
456
+ act_func="relu6",
457
+ )
458
+ first_block = ResidualBlock(first_block_conv, None)
459
+
460
+ input_channel = first_block_width
461
+ # inverted residual blocks
462
+ self.block_group_info = []
463
+ blocks = [first_block]
464
+ _block_index = 1
465
+
466
+ stride_stages = [1, 2, 2, 1, 2, 1]
467
+ n_block_list = [max(self.depth_list)] * 5 + [1]
468
+
469
+ width_list = []
470
+ for base_width in base_stage_width[2:-1]:
471
+ width = make_divisible(
472
+ base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
473
+ )
474
+ width_list.append(width)
475
+
476
+ for width, n_block, s in zip(width_list, n_block_list, stride_stages):
477
+ self.block_group_info.append([_block_index + i for i in range(n_block)])
478
+ _block_index += n_block
479
+
480
+ output_channel = width
481
+ for i in range(n_block):
482
+ if i == 0:
483
+ stride = s
484
+ else:
485
+ stride = 1
486
+
487
+ mobile_inverted_conv = DynamicMBConvLayer(
488
+ in_channel_list=val2list(input_channel, 1),
489
+ out_channel_list=val2list(output_channel, 1),
490
+ kernel_size_list=ks_list,
491
+ expand_ratio_list=expand_ratio_list,
492
+ stride=stride,
493
+ act_func="relu6",
494
+ )
495
+
496
+ if stride == 1 and input_channel == output_channel:
497
+ shortcut = IdentityLayer(input_channel, input_channel)
498
+ else:
499
+ shortcut = None
500
+
501
+ mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)
502
+
503
+ blocks.append(mb_inverted_block)
504
+ input_channel = output_channel
505
+ # 1x1_conv before global average pooling
506
+ feature_mix_layer = ConvLayer(
507
+ input_channel,
508
+ last_channel,
509
+ kernel_size=1,
510
+ use_bn=True,
511
+ act_func="relu6",
512
+ )
513
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
514
+
515
+ super(DYNProxylessNASNets_Cifar, self).__init__(
516
+ first_conv, blocks, feature_mix_layer, classifier
517
+ )
518
+
519
+ # set bn param
520
+ self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
521
+
522
+ # runtime_depth
523
+ self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
524
+
525
+ """ MyNetwork required methods """
526
+
527
+ @staticmethod
528
+ def name():
529
+ return "DYNProxylessNASNets_Cifar"
530
+
531
+ def forward(self, x):
532
+ # first conv
533
+ x = self.first_conv(x)
534
+ # first block
535
+ x = self.blocks[0](x)
536
+
537
+ # blocks
538
+ for stage_id, block_idx in enumerate(self.block_group_info):
539
+ depth = self.runtime_depth[stage_id]
540
+ active_idx = block_idx[:depth]
541
+ for idx in active_idx:
542
+ x = self.blocks[idx](x)
543
+
544
+ # feature_mix_layer
545
+ x = self.feature_mix_layer(x)
546
+ x = x.mean(3).mean(2)
547
+
548
+ x = self.classifier(x)
549
+ return x
550
+
551
+ @property
552
+ def module_str(self):
553
+ _str = self.first_conv.module_str + "\n"
554
+ _str += self.blocks[0].module_str + "\n"
555
+
556
+ for stage_id, block_idx in enumerate(self.block_group_info):
557
+ depth = self.runtime_depth[stage_id]
558
+ active_idx = block_idx[:depth]
559
+ for idx in active_idx:
560
+ _str += self.blocks[idx].module_str + "\n"
561
+ _str += self.feature_mix_layer.module_str + "\n"
562
+ _str += self.classifier.module_str + "\n"
563
+ return _str
564
+
565
+ @property
566
+ def config(self):
567
+ return {
568
+ "name": DYNProxylessNASNets_Cifar.__name__,
569
+ "bn": self.get_bn_param(),
570
+ "first_conv": self.first_conv.config,
571
+ "blocks": [block.config for block in self.blocks],
572
+ "feature_mix_layer": None
573
+ if self.feature_mix_layer is None
574
+ else self.feature_mix_layer.config,
575
+ "classifier": self.classifier.config,
576
+ }
577
+
578
+ @staticmethod
579
+ def build_from_config(config):
580
+ raise ValueError("do not support this function")
581
+
582
+ @property
583
+ def grouped_block_index(self):
584
+ return self.block_group_info
585
+
586
+ def load_state_dict(self, state_dict, **kwargs):
587
+ model_dict = self.state_dict()
588
+ for key in state_dict:
589
+ if ".mobile_inverted_conv." in key:
590
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
591
+ else:
592
+ new_key = key
593
+ if new_key in model_dict:
594
+ pass
595
+ elif ".bn.bn." in new_key:
596
+ new_key = new_key.replace(".bn.bn.", ".bn.")
597
+ elif ".conv.conv.weight" in new_key:
598
+ new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
599
+ elif ".linear.linear." in new_key:
600
+ new_key = new_key.replace(".linear.linear.", ".linear.")
601
+ ##############################################################################
602
+ elif ".linear." in new_key:
603
+ new_key = new_key.replace(".linear.", ".linear.linear.")
604
+ elif "bn." in new_key:
605
+ new_key = new_key.replace("bn.", "bn.bn.")
606
+ elif "conv.weight" in new_key:
607
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
608
+ else:
609
+ raise ValueError(new_key)
610
+ assert new_key in model_dict, "%s" % new_key
611
+ model_dict[new_key] = state_dict[key]
612
+ super(DYNProxylessNASNets_Cifar, self).load_state_dict(model_dict)
613
+
614
+ """ set, sample and get active sub-networks """
615
+
616
+ def set_max_net(self):
617
+ self.set_active_subnet(
618
+ ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
619
+ )
620
+
621
+ def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
622
+ ks = val2list(ks, len(self.blocks) - 1)
623
+ expand_ratio = val2list(e, len(self.blocks) - 1)
624
+ depth = val2list(d, len(self.block_group_info))
625
+
626
+ for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
627
+ if k is not None:
628
+ block.conv.active_kernel_size = k
629
+ if e is not None:
630
+ block.conv.active_expand_ratio = e
631
+
632
+ for i, d in enumerate(depth):
633
+ if d is not None:
634
+ self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
635
+
636
+ def set_constraint(self, include_list, constraint_type="depth"):
637
+ if constraint_type == "depth":
638
+ self.__dict__["_depth_include_list"] = include_list.copy()
639
+ elif constraint_type == "expand_ratio":
640
+ self.__dict__["_expand_include_list"] = include_list.copy()
641
+ elif constraint_type == "kernel_size":
642
+ self.__dict__["_ks_include_list"] = include_list.copy()
643
+ else:
644
+ raise NotImplementedError
645
+
646
+ def clear_constraint(self):
647
+ self.__dict__["_depth_include_list"] = None
648
+ self.__dict__["_expand_include_list"] = None
649
+ self.__dict__["_ks_include_list"] = None
650
+
651
+ def sample_active_subnet(self):
652
+ ks_candidates = (
653
+ self.ks_list
654
+ if self.__dict__.get("_ks_include_list", None) is None
655
+ else self.__dict__["_ks_include_list"]
656
+ )
657
+ expand_candidates = (
658
+ self.expand_ratio_list
659
+ if self.__dict__.get("_expand_include_list", None) is None
660
+ else self.__dict__["_expand_include_list"]
661
+ )
662
+ depth_candidates = (
663
+ self.depth_list
664
+ if self.__dict__.get("_depth_include_list", None) is None
665
+ else self.__dict__["_depth_include_list"]
666
+ )
667
+
668
+ # sample kernel size
669
+ ks_setting = []
670
+ if not isinstance(ks_candidates[0], list):
671
+ ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
672
+ for k_set in ks_candidates:
673
+ k = random.choice(k_set)
674
+ ks_setting.append(k)
675
+
676
+ # sample expand ratio
677
+ expand_setting = []
678
+ if not isinstance(expand_candidates[0], list):
679
+ expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
680
+ for e_set in expand_candidates:
681
+ e = random.choice(e_set)
682
+ expand_setting.append(e)
683
+
684
+ # sample depth
685
+ depth_setting = []
686
+ if not isinstance(depth_candidates[0], list):
687
+ depth_candidates = [
688
+ depth_candidates for _ in range(len(self.block_group_info))
689
+ ]
690
+ for d_set in depth_candidates:
691
+ d = random.choice(d_set)
692
+ depth_setting.append(d)
693
+
694
+ depth_setting[-1] = 1
695
+ self.set_active_subnet(ks_setting, expand_setting, depth_setting)
696
+
697
+ return {
698
+ "ks": ks_setting,
699
+ "e": expand_setting,
700
+ "d": depth_setting,
701
+ }
702
+
703
+ def get_active_subnet(self, preserve_weight=True):
704
+ first_conv = copy.deepcopy(self.first_conv)
705
+ blocks = [copy.deepcopy(self.blocks[0])]
706
+ feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
707
+ classifier = copy.deepcopy(self.classifier)
708
+
709
+ input_channel = blocks[0].conv.out_channels
710
+ # blocks
711
+ for stage_id, block_idx in enumerate(self.block_group_info):
712
+ depth = self.runtime_depth[stage_id]
713
+ active_idx = block_idx[:depth]
714
+ stage_blocks = []
715
+ for idx in active_idx:
716
+ stage_blocks.append(
717
+ ResidualBlock(
718
+ self.blocks[idx].conv.get_active_subnet(
719
+ input_channel, preserve_weight
720
+ ),
721
+ copy.deepcopy(self.blocks[idx].shortcut),
722
+ )
723
+ )
724
+ input_channel = stage_blocks[-1].conv.out_channels
725
+ blocks += stage_blocks
726
+
727
+ _subnet = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier)
728
+ _subnet.set_bn_param(**self.get_bn_param())
729
+ return _subnet
730
+
731
+ def get_active_net_config(self):
732
+ first_conv_config = self.first_conv.config
733
+ first_block_config = self.blocks[0].config
734
+ feature_mix_layer_config = self.feature_mix_layer.config
735
+ classifier_config = self.classifier.config
736
+
737
+ block_config_list = [first_block_config]
738
+ input_channel = first_block_config["conv"]["out_channels"]
739
+ for stage_id, block_idx in enumerate(self.block_group_info):
740
+ depth = self.runtime_depth[stage_id]
741
+ active_idx = block_idx[:depth]
742
+ stage_blocks = []
743
+ for idx in active_idx:
744
+ stage_blocks.append(
745
+ {
746
+ "name": ResidualBlock.__name__,
747
+ "conv": self.blocks[idx].conv.get_active_subnet_config(
748
+ input_channel
749
+ ),
750
+ "shortcut": self.blocks[idx].shortcut.config
751
+ if self.blocks[idx].shortcut is not None
752
+ else None,
753
+ }
754
+ )
755
+ try:
756
+ input_channel = self.blocks[idx].conv.active_out_channel
757
+ except Exception:
758
+ input_channel = self.blocks[idx].conv.out_channels
759
+ block_config_list += stage_blocks
760
+
761
+ return {
762
+ "name": ProxylessNASNets_Cifar.__name__,
763
+ "bn": self.get_bn_param(),
764
+ "first_conv": first_conv_config,
765
+ "blocks": block_config_list,
766
+ "feature_mix_layer": feature_mix_layer_config,
767
+ "classifier": classifier_config,
768
+ }
769
+
770
+ """ Width Related Methods """
771
+
772
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
773
+ for block in self.blocks[1:]:
774
+ block.conv.re_organize_middle_weights(expand_ratio_stage)
proard/classification/elastic_nn/networks/dyn_resnets.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from proard.classification.elastic_nn.modules.dynamic_layers import (
4
+ DynamicConvLayer,
5
+ DynamicLinearLayer,
6
+ )
7
+ from proard.classification.elastic_nn.modules.dynamic_layers import (
8
+ DynamicResNetBottleneckBlock,
9
+ )
10
+ from proard.utils.layers import IdentityLayer, ResidualBlock
11
+ from proard.classification.networks import ResNets,ResNets_Cifar
12
+ from proard.utils import make_divisible, val2list, MyNetwork
13
+
14
+ __all__ = ["DYNResNets","DYNResNets_Cifar"]
15
+
16
+
17
+ class DYNResNets(ResNets):
18
+ def __init__(
19
+ self,
20
+ n_classes=1000,
21
+ bn_param=(0.1, 1e-5),
22
+ dropout_rate=0,
23
+ depth_list=2,
24
+ expand_ratio_list=0.25,
25
+ width_mult_list=1.0,
26
+ ):
27
+
28
+ self.depth_list = val2list(depth_list)
29
+ self.expand_ratio_list = val2list(expand_ratio_list)
30
+ self.width_mult_list = val2list(width_mult_list)
31
+ # sort
32
+ self.depth_list.sort()
33
+ self.expand_ratio_list.sort()
34
+ self.width_mult_list.sort()
35
+
36
+ input_channel = [
37
+ make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
38
+ for width_mult in self.width_mult_list
39
+ ]
40
+ mid_input_channel = [
41
+ make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
42
+ for channel in input_channel
43
+ ]
44
+
45
+ stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
46
+ for i, width in enumerate(stage_width_list):
47
+ stage_width_list[i] = [
48
+ make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
49
+ for width_mult in self.width_mult_list
50
+ ]
51
+
52
+ n_block_list = [
53
+ base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST
54
+ ]
55
+ stride_list = [1, 2, 2, 2]
56
+
57
+ # build input stem
58
+ input_stem = [
59
+ DynamicConvLayer(
60
+ val2list(3),
61
+ mid_input_channel,
62
+ 3,
63
+ stride=2,
64
+ use_bn=True,
65
+ act_func="relu",
66
+ ),
67
+ ResidualBlock(
68
+ DynamicConvLayer(
69
+ mid_input_channel,
70
+ mid_input_channel,
71
+ 3,
72
+ stride=1,
73
+ use_bn=True,
74
+ act_func="relu",
75
+ ),
76
+ IdentityLayer(mid_input_channel, mid_input_channel),
77
+ ),
78
+ DynamicConvLayer(
79
+ mid_input_channel,
80
+ input_channel,
81
+ 3,
82
+ stride=1,
83
+ use_bn=True,
84
+ act_func="relu",
85
+ ),
86
+ ]
87
+
88
+ # blocks
89
+ blocks = []
90
+ for d, width, s in zip(n_block_list, stage_width_list, stride_list):
91
+ for i in range(d):
92
+ stride = s if i == 0 else 1
93
+ bottleneck_block = DynamicResNetBottleneckBlock(
94
+ input_channel,
95
+ width,
96
+ expand_ratio_list=self.expand_ratio_list,
97
+ kernel_size=3,
98
+ stride=stride,
99
+ act_func="relu",
100
+ downsample_mode="avgpool_conv",
101
+ )
102
+ blocks.append(bottleneck_block)
103
+ input_channel = width
104
+ # classifier
105
+ classifier = DynamicLinearLayer(
106
+ input_channel, n_classes, dropout_rate=dropout_rate
107
+ )
108
+
109
+ super(DYNResNets, self).__init__(input_stem, blocks, classifier)
110
+
111
+ # set bn param
112
+ self.set_bn_param(*bn_param)
113
+
114
+ # runtime_depth
115
+ self.input_stem_skipping = 0
116
+ self.runtime_depth = [0] * len(n_block_list)
117
+
118
+ @property
119
+ def ks_list(self):
120
+ return [3]
121
+
122
+ @staticmethod
123
+ def name():
124
+ return "DYNResNets"
125
+
126
+ def forward(self, x):
127
+ for layer in self.input_stem:
128
+ if (
129
+ self.input_stem_skipping > 0
130
+ and isinstance(layer, ResidualBlock)
131
+ and isinstance(layer.shortcut, IdentityLayer)
132
+ ):
133
+ pass
134
+ else:
135
+ x = layer(x)
136
+ x = self.max_pooling(x)
137
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
138
+ depth_param = self.runtime_depth[stage_id]
139
+ active_idx = block_idx[: len(block_idx) - depth_param]
140
+ for idx in active_idx:
141
+ x = self.blocks[idx](x)
142
+ x = self.global_avg_pool(x)
143
+ x = self.classifier(x)
144
+ return x
145
+
146
+ @property
147
+ def module_str(self):
148
+ _str = ""
149
+ for layer in self.input_stem:
150
+ if (
151
+ self.input_stem_skipping > 0
152
+ and isinstance(layer, ResidualBlock)
153
+ and isinstance(layer.shortcut, IdentityLayer)
154
+ ):
155
+ pass
156
+ else:
157
+ _str += layer.module_str + "\n"
158
+ _str += "max_pooling(ks=3, stride=2)\n"
159
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
160
+ depth_param = self.runtime_depth[stage_id]
161
+ active_idx = block_idx[: len(block_idx) - depth_param]
162
+ for idx in active_idx:
163
+ _str += self.blocks[idx].module_str + "\n"
164
+ _str += self.global_avg_pool.__repr__() + "\n"
165
+ _str += self.classifier.module_str
166
+ return _str
167
+
168
+ @property
169
+ def config(self):
170
+ return {
171
+ "name": DYNResNets.__name__,
172
+ "bn": self.get_bn_param(),
173
+ "input_stem": [layer.config for layer in self.input_stem],
174
+ "blocks": [block.config for block in self.blocks],
175
+ "classifier": self.classifier.config,
176
+ }
177
+
178
+ @staticmethod
179
+ def build_from_config(config):
180
+ raise ValueError("do not support this function")
181
+
182
+ def load_state_dict(self, state_dict, **kwargs):
183
+ model_dict = self.state_dict()
184
+ for key in state_dict:
185
+ new_key = key
186
+ if new_key in model_dict:
187
+ pass
188
+ elif ".linear." in new_key:
189
+ new_key = new_key.replace(".linear.", ".linear.linear.")
190
+ elif "bn." in new_key:
191
+ new_key = new_key.replace("bn.", "bn.bn.")
192
+ elif "conv.weight" in new_key:
193
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
194
+ else:
195
+ raise ValueError(new_key)
196
+ assert new_key in model_dict, "%s" % new_key
197
+ model_dict[new_key] = state_dict[key]
198
+ super(DYNResNets, self).load_state_dict(model_dict)
199
+
200
+ """ set, sample and get active sub-networks """
201
+
202
+ def set_max_net(self):
203
+ self.set_active_subnet(
204
+ d=max(self.depth_list),
205
+ e=max(self.expand_ratio_list),
206
+ w=len(self.width_mult_list) - 1,
207
+ )
208
+
209
+ def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
210
+ depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
211
+ expand_ratio = val2list(e, len(self.blocks))
212
+ width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)
213
+
214
+ for block, e in zip(self.blocks, expand_ratio):
215
+ if e is not None:
216
+ block.active_expand_ratio = e
217
+
218
+ if width_mult[0] is not None:
219
+ self.input_stem[1].conv.active_out_channel = self.input_stem[
220
+ 0
221
+ ].active_out_channel = self.input_stem[0].out_channel_list[width_mult[0]]
222
+ if width_mult[1] is not None:
223
+ self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[
224
+ width_mult[1]
225
+ ]
226
+
227
+ if depth[0] is not None:
228
+ self.input_stem_skipping = depth[0] != max(self.depth_list)
229
+ for stage_id, (block_idx, d, w) in enumerate(
230
+ zip(self.grouped_block_index, depth[1:], width_mult[2:])
231
+ ):
232
+ if d is not None:
233
+ self.runtime_depth[stage_id] = max(self.depth_list) - d
234
+ if w is not None:
235
+ for idx in block_idx:
236
+ self.blocks[idx].active_out_channel = self.blocks[
237
+ idx
238
+ ].out_channel_list[w]
239
+
240
+ def sample_active_subnet(self):
241
+ # sample expand ratio
242
+ expand_setting = []
243
+ for block in self.blocks:
244
+ expand_setting.append(random.choice(block.expand_ratio_list))
245
+
246
+ # sample depth
247
+ depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
248
+ for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
249
+ depth_setting.append(random.choice(self.depth_list))
250
+
251
+ # sample width_mult
252
+ width_mult_setting = [
253
+ random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
254
+ random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
255
+ ]
256
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
257
+ stage_first_block = self.blocks[block_idx[0]]
258
+ width_mult_setting.append(
259
+ random.choice(list(range(len(stage_first_block.out_channel_list))))
260
+ )
261
+
262
+ arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting}
263
+ self.set_active_subnet(**arch_config)
264
+ return arch_config
265
+
266
+ def get_active_subnet(self, preserve_weight=True):
267
+ input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
268
+ if self.input_stem_skipping <= 0:
269
+ input_stem.append(
270
+ ResidualBlock(
271
+ self.input_stem[1].conv.get_active_subnet(
272
+ self.input_stem[0].active_out_channel, preserve_weight
273
+ ),
274
+ IdentityLayer(
275
+ self.input_stem[0].active_out_channel,
276
+ self.input_stem[0].active_out_channel,
277
+ ),
278
+ )
279
+ )
280
+ input_stem.append(
281
+ self.input_stem[2].get_active_subnet(
282
+ self.input_stem[0].active_out_channel, preserve_weight
283
+ )
284
+ )
285
+ input_channel = self.input_stem[2].active_out_channel
286
+
287
+ blocks = []
288
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
289
+ depth_param = self.runtime_depth[stage_id]
290
+ active_idx = block_idx[: len(block_idx) - depth_param]
291
+ for idx in active_idx:
292
+ blocks.append(
293
+ self.blocks[idx].get_active_subnet(input_channel, preserve_weight)
294
+ )
295
+ input_channel = self.blocks[idx].active_out_channel
296
+ classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
297
+ subnet = ResNets(input_stem, blocks, classifier)
298
+
299
+ subnet.set_bn_param(**self.get_bn_param())
300
+ return subnet
301
+
302
+ def get_active_net_config(self):
303
+ input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
304
+ if self.input_stem_skipping <= 0:
305
+ input_stem_config.append(
306
+ {
307
+ "name": ResidualBlock.__name__,
308
+ "conv": self.input_stem[1].conv.get_active_subnet_config(
309
+ self.input_stem[0].active_out_channel
310
+ ),
311
+ "shortcut": IdentityLayer(
312
+ self.input_stem[0].active_out_channel,
313
+ self.input_stem[0].active_out_channel,
314
+ ),
315
+ }
316
+ )
317
+ input_stem_config.append(
318
+ self.input_stem[2].get_active_subnet_config(
319
+ self.input_stem[0].active_out_channel
320
+ )
321
+ )
322
+ input_channel = self.input_stem[2].active_out_channel
323
+
324
+ blocks_config = []
325
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
326
+ depth_param = self.runtime_depth[stage_id]
327
+ active_idx = block_idx[: len(block_idx) - depth_param]
328
+ for idx in active_idx:
329
+ blocks_config.append(
330
+ self.blocks[idx].get_active_subnet_config(input_channel)
331
+ )
332
+ input_channel = self.blocks[idx].active_out_channel
333
+ classifier_config = self.classifier.get_active_subnet_config(input_channel)
334
+ return {
335
+ "name": ResNets.__name__,
336
+ "bn": self.get_bn_param(),
337
+ "input_stem": input_stem_config,
338
+ "blocks": blocks_config,
339
+ "classifier": classifier_config,
340
+ }
341
+
342
+ """ Width Related Methods """
343
+
344
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
345
+ for block in self.blocks:
346
+ block.re_organize_middle_weights(expand_ratio_stage)
347
+
348
+
349
+
350
+ class DYNResNets_Cifar(ResNets_Cifar):
351
+ def __init__(
352
+ self,
353
+ n_classes=10,
354
+ bn_param=(0.1, 1e-5),
355
+ dropout_rate=0,
356
+ depth_list=0,
357
+ expand_ratio_list=0.25,
358
+ width_mult_list=1.0,
359
+ ):
360
+
361
+ self.depth_list = val2list(depth_list)
362
+ self.expand_ratio_list = val2list(expand_ratio_list)
363
+ self.width_mult_list = val2list(width_mult_list)
364
+ # sort
365
+ self.depth_list.sort()
366
+ self.expand_ratio_list.sort()
367
+ self.width_mult_list.sort()
368
+
369
+ input_channel = [
370
+ make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
371
+ for width_mult in self.width_mult_list
372
+ ]
373
+ mid_input_channel = [
374
+ make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
375
+ for channel in input_channel
376
+ ]
377
+
378
+ stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy()
379
+ for i, width in enumerate(stage_width_list):
380
+ stage_width_list[i] = [
381
+ make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
382
+ for width_mult in self.width_mult_list
383
+ ]
384
+
385
+ n_block_list = [
386
+ base_depth + max(self.depth_list) for base_depth in ResNets_Cifar.BASE_DEPTH_LIST
387
+ ]
388
+ stride_list = [1, 2, 2, 2]
389
+
390
+ # build input stem
391
+ input_stem = [
392
+ DynamicConvLayer(
393
+ val2list(3),
394
+ mid_input_channel,
395
+ 3,
396
+ stride=1,
397
+ use_bn=True,
398
+ act_func="relu",
399
+ ),
400
+ ResidualBlock(
401
+ DynamicConvLayer(
402
+ mid_input_channel,
403
+ mid_input_channel,
404
+ 3,
405
+ stride=1,
406
+ use_bn=True,
407
+ act_func="relu",
408
+ ),
409
+ IdentityLayer(mid_input_channel, mid_input_channel),
410
+ ),
411
+ DynamicConvLayer(
412
+ mid_input_channel,
413
+ input_channel,
414
+ 3,
415
+ stride=1,
416
+ use_bn=True,
417
+ act_func="relu",
418
+ ),
419
+ ]
420
+
421
+ # blocks
422
+ blocks = []
423
+ for d, width, s in zip(n_block_list, stage_width_list, stride_list):
424
+ for i in range(d):
425
+ stride = s if i == 0 else 1
426
+ bottleneck_block = DynamicResNetBottleneckBlock(
427
+ input_channel,
428
+ width,
429
+ expand_ratio_list=self.expand_ratio_list,
430
+ kernel_size=3,
431
+ stride=stride,
432
+ act_func="relu",
433
+ downsample_mode="conv",
434
+ )
435
+ blocks.append(bottleneck_block)
436
+ input_channel = width
437
+ # classifier
438
+ classifier = DynamicLinearLayer(
439
+ input_channel, n_classes, dropout_rate=dropout_rate
440
+ )
441
+
442
+ super(DYNResNets_Cifar, self).__init__(input_stem, blocks, classifier)
443
+
444
+ # set bn param
445
+ self.set_bn_param(*bn_param)
446
+
447
+ # runtime_depth
448
+ self.input_stem_skipping = 0
449
+ self.runtime_depth = [0] * len(n_block_list)
450
+
451
+ @property
452
+ def ks_list(self):
453
+ return [3]
454
+
455
+ @staticmethod
456
+ def name():
457
+ return "DYNResNets_Cifar"
458
+
459
+ def forward(self, x):
460
+ for layer in self.input_stem:
461
+ if (
462
+ self.input_stem_skipping > 0
463
+ and isinstance(layer, ResidualBlock)
464
+ and isinstance(layer.shortcut, IdentityLayer)
465
+ ):
466
+ pass
467
+ else:
468
+ x = layer(x)
469
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
470
+ depth_param = self.runtime_depth[stage_id]
471
+ active_idx = block_idx[: len(block_idx) - depth_param]
472
+ for idx in active_idx:
473
+ x = self.blocks[idx](x)
474
+ x = self.global_avg_pool(x)
475
+ x = self.classifier(x)
476
+ return x
477
+
478
+ @property
479
+ def module_str(self):
480
+ _str = ""
481
+ for layer in self.input_stem:
482
+ if (
483
+ self.input_stem_skipping > 0
484
+ and isinstance(layer, ResidualBlock)
485
+ and isinstance(layer.shortcut, IdentityLayer)
486
+ ):
487
+ pass
488
+ else:
489
+ _str += layer.module_str + "\n"
490
+ # _str += "max_pooling(ks=3, stride=2)\n"
491
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
492
+ depth_param = self.runtime_depth[stage_id]
493
+ active_idx = block_idx[: len(block_idx) - depth_param]
494
+ for idx in active_idx:
495
+ _str += self.blocks[idx].module_str + "\n"
496
+ _str += self.global_avg_pool.__repr__() + "\n"
497
+ _str += self.classifier.module_str
498
+ return _str
499
+
500
+ @property
501
+ def config(self):
502
+ return {
503
+ "name": DYNResNets_Cifar.__name__,
504
+ "bn": self.get_bn_param(),
505
+ "input_stem": [layer.config for layer in self.input_stem],
506
+ "blocks": [block.config for block in self.blocks],
507
+ "classifier": self.classifier.config,
508
+ }
509
+
510
+ @staticmethod
511
+ def build_from_config(config):
512
+ raise ValueError("do not support this function")
513
+
514
+ def load_state_dict(self, state_dict, **kwargs):
515
+ model_dict = self.state_dict()
516
+ for key in state_dict:
517
+ new_key = key
518
+ if new_key in model_dict:
519
+ pass
520
+ elif ".linear." in new_key:
521
+ new_key = new_key.replace(".linear.", ".linear.linear.")
522
+ elif "bn." in new_key:
523
+ new_key = new_key.replace("bn.", "bn.bn.")
524
+ elif "conv.weight" in new_key:
525
+ new_key = new_key.replace("conv.weight", "conv.conv.weight")
526
+ else:
527
+ raise ValueError(new_key)
528
+ assert new_key in model_dict, "%s" % new_key
529
+ model_dict[new_key] = state_dict[key]
530
+ super(DYNResNets_Cifar, self).load_state_dict(model_dict)
531
+
532
+ """ set, sample and get active sub-networks """
533
+
534
+ def set_max_net(self):
535
+ self.set_active_subnet(
536
+ d=max(self.depth_list),
537
+ e=max(self.expand_ratio_list),
538
+ w=len(self.width_mult_list) - 1,
539
+ )
540
+
541
+ def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
542
+ depth = val2list(d, len(ResNets_Cifar.BASE_DEPTH_LIST) + 1)
543
+ expand_ratio = val2list(e, len(self.blocks))
544
+ width_mult = val2list(w, len(ResNets_Cifar.BASE_DEPTH_LIST) + 2)
545
+
546
+ for block, e in zip(self.blocks, expand_ratio):
547
+ if e is not None:
548
+ block.active_expand_ratio = e
549
+
550
+ if width_mult[0] is not None:
551
+ self.input_stem[1].conv.active_out_channel = self.input_stem[
552
+ 0
553
+ ].active_out_channel = self.input_stem[0].out_channel_list[int(width_mult[0])]
554
+ if width_mult[1] is not None:
555
+ self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[
556
+ int(width_mult[1])
557
+ ]
558
+
559
+ if depth[0] is not None:
560
+ self.input_stem_skipping = depth[0] != max(self.depth_list)
561
+ for stage_id, (block_idx, d, w) in enumerate(
562
+ zip(self.grouped_block_index, depth[1:], width_mult[2:])
563
+ ):
564
+ if d is not None:
565
+ self.runtime_depth[stage_id] = max(self.depth_list) - d
566
+ if w is not None:
567
+ for idx in block_idx:
568
+ self.blocks[idx].active_out_channel = self.blocks[
569
+ idx
570
+ ].out_channel_list[int(w)]
571
+
572
+ def sample_active_subnet(self):
573
+ # sample expand ratio
574
+ expand_setting = []
575
+ for block in self.blocks:
576
+ expand_setting.append(random.choice(block.expand_ratio_list))
577
+
578
+ # sample depth
579
+ depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
580
+ for stage_id in range(len(ResNets_Cifar.BASE_DEPTH_LIST)):
581
+ depth_setting.append(random.choice(self.depth_list))
582
+
583
+ # sample width_mult
584
+ width_mult_setting = [
585
+ random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
586
+ random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
587
+ ]
588
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
589
+ stage_first_block = self.blocks[block_idx[0]]
590
+ width_mult_setting.append(
591
+ random.choice(list(range(len(stage_first_block.out_channel_list))))
592
+ )
593
+
594
+ arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting}
595
+ self.set_active_subnet(**arch_config)
596
+ return arch_config
597
+
598
+ def get_active_subnet(self, preserve_weight=True):
599
+ input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
600
+ if self.input_stem_skipping <= 0:
601
+ input_stem.append(
602
+ ResidualBlock(
603
+ self.input_stem[1].conv.get_active_subnet(
604
+ self.input_stem[0].active_out_channel, preserve_weight
605
+ ),
606
+ IdentityLayer(
607
+ self.input_stem[0].active_out_channel,
608
+ self.input_stem[0].active_out_channel,
609
+ ),
610
+ )
611
+ )
612
+ input_stem.append(
613
+ self.input_stem[2].get_active_subnet(
614
+ self.input_stem[0].active_out_channel, preserve_weight
615
+ )
616
+ )
617
+ input_channel = self.input_stem[2].active_out_channel
618
+
619
+ blocks = []
620
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
621
+ depth_param = self.runtime_depth[stage_id]
622
+ active_idx = block_idx[: len(block_idx) - depth_param]
623
+ for idx in active_idx:
624
+ blocks.append(
625
+ self.blocks[idx].get_active_subnet(input_channel, preserve_weight)
626
+ )
627
+ input_channel = self.blocks[idx].active_out_channel
628
+ classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
629
+ subnet = ResNets_Cifar(input_stem, blocks, classifier)
630
+
631
+ subnet.set_bn_param(**self.get_bn_param())
632
+ return subnet
633
+
634
+ def get_active_net_config(self):
635
+ input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
636
+ if self.input_stem_skipping <= 0:
637
+ input_stem_config.append(
638
+ {
639
+ "name": ResidualBlock.__name__,
640
+ "conv": self.input_stem[1].conv.get_active_subnet_config(
641
+ self.input_stem[0].active_out_channel
642
+ ),
643
+ "shortcut": IdentityLayer(
644
+ self.input_stem[0].active_out_channel,
645
+ self.input_stem[0].active_out_channel,
646
+ ),
647
+ }
648
+ )
649
+ input_stem_config.append(
650
+ self.input_stem[2].get_active_subnet_config(
651
+ self.input_stem[0].active_out_channel
652
+ )
653
+ )
654
+ input_channel = self.input_stem[2].active_out_channel
655
+
656
+ blocks_config = []
657
+ for stage_id, block_idx in enumerate(self.grouped_block_index):
658
+ depth_param = self.runtime_depth[stage_id]
659
+ active_idx = block_idx[: len(block_idx) - int(depth_param)]
660
+ for idx in active_idx:
661
+ blocks_config.append(
662
+ self.blocks[idx].get_active_subnet_config(input_channel)
663
+ )
664
+ input_channel = self.blocks[idx].active_out_channel
665
+ classifier_config = self.classifier.get_active_subnet_config(input_channel)
666
+ return {
667
+ "name": ResNets_Cifar.__name__,
668
+ "bn": self.get_bn_param(),
669
+ "input_stem": input_stem_config,
670
+ "blocks": blocks_config,
671
+ "classifier": classifier_config,
672
+ }
673
+
674
+ """ Width Related Methods """
675
+
676
+ def re_organize_middle_weights(self, expand_ratio_stage=0):
677
+ for block in self.blocks:
678
+ block.re_organize_middle_weights(expand_ratio_stage)
proard/classification/elastic_nn/training/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .progressive_shrinking import *
6
+ from .progressive_shrinking import *
proard/classification/elastic_nn/training/progressive_shrinking.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import torch.nn as nn
6
+ import random
7
+ import time
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from tqdm import tqdm
11
+ from attacks.utils import ctx_noparamgrad_and_eval
12
+ from robust_loss.rslad import rslad_inner_loss,kl_loss
13
+ from robust_loss.trades import trades_loss
14
+ from attacks import create_attack
15
+ import copy
16
+ from proard.utils import AverageMeter, cross_entropy_loss_with_soft_target
17
+ from proard.utils import (
18
+ DistributedMetric,
19
+ list_mean,
20
+ subset_mean,
21
+ val2list,
22
+ MyRandomResizedCrop,
23
+ )
24
+ from proard.classification.run_manager import DistributedRunManager
25
+
26
+ __all__ = [
27
+ "validate",
28
+ "train_one_epoch",
29
+ "train",
30
+ "load_models",
31
+ "train_elastic_depth",
32
+ "train_elastic_expand",
33
+ "train_elastic_width_mult",
34
+ ]
35
+
36
+
37
+ def validate(
38
+ run_manager,
39
+ epoch=0,
40
+ is_test=False,
41
+ image_size_list=None,
42
+ ks_list=None,
43
+ expand_ratio_list=None,
44
+ depth_list=None,
45
+ width_mult_list=None,
46
+ additional_setting=None,
47
+ ):
48
+ dynamic_net = run_manager.net
49
+ if isinstance(dynamic_net, nn.DataParallel):
50
+ dynamic_net = dynamic_net.module
51
+
52
+ dynamic_net.eval()
53
+
54
+ if image_size_list is None:
55
+ image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
56
+ if ks_list is None:
57
+ ks_list = dynamic_net.ks_list
58
+ if expand_ratio_list is None:
59
+ expand_ratio_list = dynamic_net.expand_ratio_list
60
+ if depth_list is None:
61
+ depth_list = dynamic_net.depth_list
62
+ if width_mult_list is not None:
63
+ if "width_mult_list" in dynamic_net.__dict__:
64
+ width_mult_list = list(range(len(dynamic_net.width_mult_list)))
65
+ else:
66
+ width_mult_list = [0]
67
+
68
+ subnet_settings = []
69
+ for d in depth_list:
70
+ for e in expand_ratio_list:
71
+ for k in ks_list:
72
+ for w in width_mult_list:
73
+ for img_size in image_size_list:
74
+ subnet_settings.append(
75
+ [
76
+ {
77
+ "image_size": img_size,
78
+ "d": d,
79
+ "e": e,
80
+ "ks": k,
81
+ "w": w,
82
+ },
83
+ "R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w),
84
+ ]
85
+ )
86
+ if additional_setting is not None:
87
+ subnet_settings += additional_setting
88
+
89
+ losses_of_subnets, top1_of_subnets, top5_of_subnets , robust1_of_subnets , robust5_of_subnets = [], [], [],[],[]
90
+
91
+ valid_log = ""
92
+ for setting, name in subnet_settings:
93
+ run_manager.write_log(
94
+ "-" * 30 + " Validate %s " % name + "-" * 30, "train", should_print=False
95
+ )
96
+ run_manager.run_config.data_provider.assign_active_img_size(
97
+ setting.pop("image_size")
98
+ )
99
+ dynamic_net.set_active_subnet(**setting)
100
+ run_manager.write_log(dynamic_net.module_str, "train", should_print=False)
101
+
102
+ run_manager.reset_running_statistics(dynamic_net)
103
+ loss, (top1, top5,robust1,robust5) = run_manager.validate(
104
+ epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net
105
+ )
106
+ losses_of_subnets.append(loss)
107
+ top1_of_subnets.append(top1)
108
+ top5_of_subnets.append(top5)
109
+ robust1_of_subnets.append(robust1)
110
+ robust5_of_subnets.append(robust5)
111
+ valid_log += "%s (%.3f) (%.3f), " % (name, top1,robust1)
112
+
113
+ return (
114
+ list_mean(losses_of_subnets),
115
+ list_mean(top1_of_subnets),
116
+ list_mean(top5_of_subnets),
117
+ list_mean(robust1_of_subnets),
118
+ list_mean(robust5_of_subnets),
119
+ valid_log,
120
+ )
121
+
122
+
123
+ def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
124
+ dynamic_net = run_manager.network
125
+ distributed = isinstance(run_manager, DistributedRunManager)
126
+
127
+ # switch to train mode
128
+ dynamic_net.train()
129
+ if distributed:
130
+ run_manager.run_config.train_loader.sampler.set_epoch(epoch)
131
+ MyRandomResizedCrop.EPOCH = epoch
132
+
133
+ nBatch = len(run_manager.run_config.train_loader)
134
+
135
+ data_time = AverageMeter()
136
+ losses = DistributedMetric("train_loss") if distributed else AverageMeter()
137
+ metric_dict = run_manager.get_metric_dict()
138
+
139
+ with tqdm(
140
+ total=nBatch,
141
+ desc="Train Epoch #{}".format(epoch + 1),
142
+ disable=distributed and not run_manager.is_root,
143
+ ) as t:
144
+ end = time.time()
145
+ subnet_str = ""
146
+ j=0
147
+ for _ in range(args.dynamic_batch_size):
148
+ # set random seed before sampling
149
+ subnet_seed = int("%d%.3d%.3d" % (epoch * nBatch + j, _, 0))
150
+ random.seed(subnet_seed)
151
+ subnet_settings = dynamic_net.sample_active_subnet()
152
+ subnet_str += (
153
+ "%d: " % _
154
+ + ",".join(
155
+ [
156
+ "%s_%s"
157
+ % (
158
+ key,
159
+ "%.1f" % subset_mean(val, 0)
160
+ if isinstance(val, list)
161
+ else val,
162
+ )
163
+ for key, val in subnet_settings.items()
164
+ ]
165
+ )
166
+ + " || "
167
+ )
168
+
169
+ for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
170
+ MyRandomResizedCrop.BATCH = i
171
+ data_time.update(time.time() - end)
172
+ if epoch < warmup_epochs:
173
+ new_lr = run_manager.run_config.warmup_adjust_learning_rate(
174
+ run_manager.optimizer,
175
+ warmup_epochs * nBatch,
176
+ nBatch,
177
+ epoch,
178
+ i,
179
+ warmup_lr,
180
+ )
181
+ else:
182
+ new_lr = run_manager.run_config.adjust_learning_rate(
183
+ run_manager.optimizer, epoch - warmup_epochs, i, nBatch
184
+ )
185
+
186
+ images, labels = images.cuda(), labels.cuda()
187
+ target = labels
188
+
189
+ # soft target
190
+ if args.kd_ratio > 0:
191
+ args.teacher_model.eval()
192
+ with torch.no_grad():
193
+ soft_logits = args.teacher_model(images).detach()
194
+ soft_label = F.softmax(soft_logits, dim=1)
195
+
196
+ # clean gradients
197
+ dynamic_net.zero_grad()
198
+
199
+ loss_of_subnets = []
200
+ # compute output
201
+
202
+
203
+ output = dynamic_net(images)
204
+
205
+ if args.kd_ratio == 0:
206
+ if run_manager.run_config.robust_mode:
207
+ loss = run_manager.train_criterion(dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train,run_manager.run_config.distance_train)
208
+ loss_type = run_manager.run_config.train_criterion_loss.__name__
209
+ else:
210
+ loss = torch.nn.CrossEntropyLoss(output,labels)
211
+ loss_type = 'ce'
212
+ else:
213
+ if run_manager.run_config.robust_mode:
214
+ loss = run_manager.kd_criterion(args.teacher_model,dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train)
215
+ loss_type = run_manager.run_config.kd_criterion_loss.__name__
216
+ else:
217
+ if args.kd_type == "ce":
218
+ kd_loss = cross_entropy_loss_with_soft_target(
219
+ output, soft_label
220
+ )
221
+ else:
222
+ kd_loss = F.mse_loss(output, soft_logits)
223
+ loss = args.kd_ratio * kd_loss + loss
224
+ loss_type = "%.1fkd+ce" % args.kd_ratio
225
+ # measure accuracy and record loss
226
+ loss_of_subnets.append(loss)
227
+ run_manager.update_metric(metric_dict, output,output, target)
228
+
229
+ loss.backward()
230
+ run_manager.optimizer.step()
231
+
232
+ losses.update(list_mean(loss_of_subnets), images.size(0))
233
+
234
+ t.set_postfix(
235
+ {
236
+ "loss": losses.avg.item(),
237
+ **run_manager.get_metric_vals(metric_dict, return_dict=True),
238
+ "R": images.size(2),
239
+ "lr": new_lr,
240
+ "loss_type": loss_type,
241
+ "seed": str(subnet_seed),
242
+ "str": subnet_str,
243
+ "data_time": data_time.avg,
244
+ }
245
+ )
246
+ t.update(1)
247
+ end = time.time()
248
+ j+=1
249
+ return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
250
+
251
+
252
+ def train(run_manager, args, validate_func=None):
253
+ distributed = isinstance(run_manager, DistributedRunManager)
254
+ if validate_func is None:
255
+ validate_func = validate
256
+
257
+ for epoch in range(
258
+ run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs
259
+ ):
260
+ train_loss, (train_top1, train_top5 , train_robust1 , train_robust5) = train_one_epoch(
261
+ run_manager, args, epoch, args.warmup_epochs, args.warmup_lr
262
+ )
263
+
264
+ if (epoch + 1) % args.validation_frequency == 0:
265
+ val_loss, val_acc, val_acc5, val_robust1, val_robust5, _val_log = validate_func(
266
+ run_manager, epoch=epoch, is_test=True
267
+ )
268
+ # best_acc
269
+ is_best = val_acc > run_manager.best_acc
270
+ is_best_robust = val_robust1 > run_manager.best_robustness
271
+ run_manager.best_acc = max(run_manager.best_acc, val_acc)
272
+ run_manager.best_robustness = max(run_manager.best_robustness, val_robust1)
273
+ if not distributed or run_manager.is_root:
274
+ val_log = (
275
+ "Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f}) , robust-1 = {4:.3f} ({5:.3f}) ".format(
276
+ epoch + 1 - args.warmup_epochs,
277
+ run_manager.run_config.n_epochs,
278
+ val_loss,
279
+ val_acc,
280
+ run_manager.best_acc,
281
+ val_robust1,
282
+ run_manager.best_robustness,
283
+ )
284
+ )
285
+ val_log += ", Train top-1 {top1:.3f}, Train robust-1 {robust1:.3f}, Train loss {loss:.3f}\t".format(
286
+ top1=train_top1, robust1 = train_robust1, loss=train_loss
287
+ )
288
+ val_log += _val_log
289
+ run_manager.write_log(val_log, "valid", should_print=False)
290
+
291
+ run_manager.save_model(
292
+ {
293
+ "epoch": epoch,
294
+ "best_acc": run_manager.best_acc,
295
+ "optimizer": run_manager.optimizer.state_dict(),
296
+ "state_dict": run_manager.network.state_dict(),
297
+ },
298
+ is_best=is_best,
299
+ )
300
+
301
+
302
+ def load_models(run_manager, dynamic_net, model_path=None):
303
+ # specify init path
304
+ init = torch.load(model_path, map_location="cpu")["state_dict"]
305
+ dynamic_net.load_state_dict(init)
306
+ run_manager.write_log("Loaded init from %s" % model_path, "valid")
307
+
308
+
309
+ def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
310
+ dynamic_net = run_manager.net
311
+ if isinstance(dynamic_net, nn.DataParallel):
312
+ dynamic_net = dynamic_net.module
313
+
314
+ depth_stage_list = dynamic_net.depth_list.copy()
315
+ depth_stage_list.sort(reverse=True)
316
+ n_stages = len(depth_stage_list) - 1
317
+ current_stage = n_stages - 1
318
+
319
+ # load pretrained models
320
+ if run_manager.start_epoch == 0 and not args.resume:
321
+ validate_func_dict["depth_list"] = sorted(dynamic_net.depth_list)
322
+
323
+ load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
324
+ # validate after loading weights
325
+ run_manager.write_log(
326
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
327
+ % validate(run_manager, is_test=True, **validate_func_dict),
328
+ "valid",
329
+ )
330
+ else:
331
+ assert args.resume
332
+
333
+ run_manager.write_log(
334
+ "-" * 30
335
+ + "Supporting Elastic Depth: %s -> %s"
336
+ % (depth_stage_list[: current_stage + 1], depth_stage_list[: current_stage + 2])
337
+ + "-" * 30,
338
+ "valid",
339
+ )
340
+ # add depth list constraints
341
+ if (
342
+ len(set(dynamic_net.ks_list)) == 1
343
+ and len(set(dynamic_net.expand_ratio_list)) == 1
344
+ ):
345
+ validate_func_dict["depth_list"] = depth_stage_list
346
+ else:
347
+ validate_func_dict["depth_list"] = sorted(
348
+ {min(depth_stage_list), max(depth_stage_list)}
349
+ )
350
+
351
+ # train
352
+ train_func(
353
+ run_manager,
354
+ args,
355
+ lambda _run_manager, epoch, is_test: validate(
356
+ _run_manager, epoch, is_test, **validate_func_dict
357
+ ),
358
+ )
359
+
360
+
361
+ def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
362
+ dynamic_net = run_manager.net
363
+ if isinstance(dynamic_net, nn.DataParallel):
364
+ dynamic_net = dynamic_net.module
365
+
366
+ expand_stage_list = dynamic_net.expand_ratio_list.copy()
367
+ expand_stage_list.sort(reverse=True)
368
+ n_stages = len(expand_stage_list) - 1
369
+ current_stage = n_stages - 1
370
+
371
+ # load pretrained models
372
+ if run_manager.start_epoch == 0 and not args.resume:
373
+ validate_func_dict["expand_ratio_list"] = sorted(dynamic_net.expand_ratio_list)
374
+
375
+ load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
376
+ dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
377
+ run_manager.write_log(
378
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
379
+ % validate(run_manager, is_test=True, **validate_func_dict),
380
+ "valid",
381
+ )
382
+ else:
383
+ assert args.resume
384
+
385
+ run_manager.write_log(
386
+ "-" * 30
387
+ + "Supporting Elastic Expand Ratio: %s -> %s"
388
+ % (
389
+ expand_stage_list[: current_stage + 1],
390
+ expand_stage_list[: current_stage + 2],
391
+ )
392
+ + "-" * 30,
393
+ "valid",
394
+ )
395
+ if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
396
+ validate_func_dict["expand_ratio_list"] = expand_stage_list
397
+ else:
398
+ validate_func_dict["expand_ratio_list"] = sorted(
399
+ {min(expand_stage_list), max(expand_stage_list)}
400
+ )
401
+
402
+ # train
403
+ train_func(
404
+ run_manager,
405
+ args,
406
+ lambda _run_manager, epoch, is_test: validate(
407
+ _run_manager, epoch, is_test, **validate_func_dict
408
+ ),
409
+ )
410
+
411
+
412
+ def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
413
+ dynamic_net = run_manager.net
414
+ if isinstance(dynamic_net, nn.DataParallel):
415
+ dynamic_net = dynamic_net.module
416
+
417
+ width_stage_list = dynamic_net.width_mult_list.copy()
418
+ width_stage_list.sort(reverse=True)
419
+ n_stages = len(width_stage_list) - 1
420
+ current_stage = n_stages - 1
421
+
422
+ if run_manager.start_epoch == 0 and not args.resume:
423
+ load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path)
424
+ if current_stage == 0:
425
+ dynamic_net.re_organize_middle_weights(
426
+ expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1
427
+ )
428
+ run_manager.write_log(
429
+ "reorganize_middle_weights (expand_ratio_stage=%d)"
430
+ % (len(dynamic_net.expand_ratio_list) - 1),
431
+ "valid",
432
+ )
433
+ try:
434
+ dynamic_net.re_organize_outer_weights()
435
+ run_manager.write_log("reorganize_outer_weights", "valid")
436
+ except Exception:
437
+ pass
438
+ validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})
439
+ run_manager.write_log(
440
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
441
+ % validate(run_manager, is_test=True, **validate_func_dict),
442
+ "valid",
443
+ )
444
+ else:
445
+ assert args.resume
446
+
447
+ run_manager.write_log(
448
+ "-" * 30
449
+ + "Supporting Elastic Width Mult: %s -> %s"
450
+ % (width_stage_list[: current_stage + 1], width_stage_list[: current_stage + 2])
451
+ + "-" * 30,
452
+ "valid",
453
+ )
454
+ validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1})
455
+
456
+ # train
457
+ train_func(
458
+ run_manager,
459
+ args,
460
+ lambda _run_manager, epoch, is_test: validate(
461
+ _run_manager, epoch, is_test, **validate_func_dict
462
+ ),
463
+ )
proard/classification/elastic_nn/utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ import torch
9
+ from attacks import create_attack
10
+ from attacks.utils import ctx_noparamgrad_and_eval
11
+ from proard.utils import AverageMeter, get_net_device, DistributedTensor
12
+ from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d
13
+
14
+ __all__ = ["set_running_statistics"]
15
+
16
+ def set_running_statistics(model, data_loader, distributed=False):
17
+ bn_mean = {}
18
+ bn_var = {}
19
+
20
+ forward_model = copy.deepcopy(model)
21
+ for name, m in forward_model.named_modules():
22
+ if isinstance(m, nn.BatchNorm2d):
23
+ if distributed:
24
+ bn_mean[name] = DistributedTensor(name + "#mean")
25
+ bn_var[name] = DistributedTensor(name + "#var")
26
+ else:
27
+ bn_mean[name] = AverageMeter()
28
+ bn_var[name] = AverageMeter()
29
+
30
+ def new_forward(bn, mean_est, var_est):
31
+ def lambda_forward(x):
32
+ batch_mean = (
33
+ x.mean(0, keepdim=True)
34
+ .mean(2, keepdim=True)
35
+ .mean(3, keepdim=True)
36
+ ) # 1, C, 1, 1
37
+ batch_var = (x - batch_mean) * (x - batch_mean)
38
+ batch_var = (
39
+ batch_var.mean(0, keepdim=True)
40
+ .mean(2, keepdim=True)
41
+ .mean(3, keepdim=True)
42
+ )
43
+
44
+ batch_mean = torch.squeeze(batch_mean)
45
+ batch_var = torch.squeeze(batch_var)
46
+
47
+ mean_est.update(batch_mean.data, x.size(0))
48
+ var_est.update(batch_var.data, x.size(0))
49
+
50
+ # bn forward using calculated mean & var
51
+ _feature_dim = batch_mean.size(0)
52
+ return F.batch_norm(
53
+ x,
54
+ batch_mean,
55
+ batch_var,
56
+ bn.weight[:_feature_dim],
57
+ bn.bias[:_feature_dim],
58
+ False,
59
+ 0.0,
60
+ bn.eps,
61
+ )
62
+
63
+ return lambda_forward
64
+
65
+ m.forward = new_forward(m, bn_mean[name], bn_var[name])
66
+
67
+ if len(bn_mean) == 0:
68
+ # skip if there is no batch normalization layers in the network
69
+ return
70
+
71
+ with torch.no_grad():
72
+ DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
73
+ for images, labels in data_loader:
74
+ images = images.to(get_net_device(forward_model))
75
+ forward_model(images)
76
+ DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False
77
+
78
+ for name, m in model.named_modules():
79
+ if name in bn_mean and bn_mean[name].count > 0:
80
+ feature_dim = bn_mean[name].avg.size(0)
81
+ assert isinstance(m, nn.BatchNorm2d)
82
+ m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
83
+ m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
proard/classification/networks/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .proxyless_nets import *
6
+ from .mobilenet_v3 import *
7
+ from .resnets import *
8
+ from .wide_resnet import WideResNet
9
+ from .resnet_trades import *
10
+
11
+ def get_net_by_name(name):
12
+ if name == ProxylessNASNets.__name__:
13
+ return ProxylessNASNets
14
+ elif name == MobileNetV3.__name__:
15
+ return MobileNetV3
16
+ elif name == ResNets.__name__:
17
+ return ResNets
18
+ if name == ProxylessNASNets_Cifar.__name__:
19
+ return ProxylessNASNets_Cifar
20
+ elif name == MobileNetV3_Cifar.__name__:
21
+ return MobileNetV3
22
+ elif name == ResNets_Cifar.__name__:
23
+ return ResNets_Cifar
24
+ else:
25
+ raise ValueError("unrecognized type of network: %s" % name)
proard/classification/networks/mobilenet_v3.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import torch.nn as nn
7
+
8
+ from proard.utils.layers import (
9
+ set_layer_from_config,
10
+ MBConvLayer,
11
+ ConvLayer,
12
+ IdentityLayer,
13
+ LinearLayer,
14
+ ResidualBlock,
15
+ )
16
+ from proard.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d
17
+
18
+ __all__ = ["MobileNetV3", "MobileNetV3Large","MobileNetV3_Cifar", "MobileNetV3Large_Cifar"]
19
+
20
+
21
+ class MobileNetV3(MyNetwork):
22
+ def __init__(
23
+ self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
24
+ ):
25
+ super(MobileNetV3, self).__init__()
26
+
27
+ self.first_conv = first_conv
28
+ self.blocks = nn.ModuleList(blocks)
29
+ self.final_expand_layer = final_expand_layer
30
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
31
+ self.feature_mix_layer = feature_mix_layer
32
+ self.classifier = classifier
33
+
34
+ def forward(self, x):
35
+ x = self.first_conv(x)
36
+ for block in self.blocks:
37
+ x = block(x)
38
+ x = self.final_expand_layer(x)
39
+ x = self.global_avg_pool(x) # global average pooling
40
+ x = self.feature_mix_layer(x)
41
+ x = x.view(x.size(0), -1)
42
+ x = self.classifier(x)
43
+ return x
44
+
45
+ @property
46
+ def module_str(self):
47
+ _str = self.first_conv.module_str + "\n"
48
+ for block in self.blocks:
49
+ _str += block.module_str + "\n"
50
+ _str += self.final_expand_layer.module_str + "\n"
51
+ _str += self.global_avg_pool.__repr__() + "\n"
52
+ _str += self.feature_mix_layer.module_str + "\n"
53
+ _str += self.classifier.module_str
54
+ return _str
55
+
56
+ @property
57
+ def config(self):
58
+ return {
59
+ "name": MobileNetV3.__name__,
60
+ "bn": self.get_bn_param(),
61
+ "first_conv": self.first_conv.config,
62
+ "blocks": [block.config for block in self.blocks],
63
+ "final_expand_layer": self.final_expand_layer.config,
64
+ "feature_mix_layer": self.feature_mix_layer.config,
65
+ "classifier": self.classifier.config,
66
+ }
67
+
68
+ @staticmethod
69
+ def build_from_config(config):
70
+ first_conv = set_layer_from_config(config["first_conv"])
71
+ final_expand_layer = set_layer_from_config(config["final_expand_layer"])
72
+ feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
73
+ classifier = set_layer_from_config(config["classifier"])
74
+
75
+ blocks = []
76
+ for block_config in config["blocks"]:
77
+ blocks.append(ResidualBlock.build_from_config(block_config))
78
+
79
+ net = MobileNetV3(
80
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
81
+ )
82
+ if "bn" in config:
83
+ net.set_bn_param(**config["bn"])
84
+ else:
85
+ net.set_bn_param(momentum=0.1, eps=1e-5)
86
+
87
+ return net
88
+
89
+ def zero_last_gamma(self):
90
+ for m in self.modules():
91
+ if isinstance(m, ResidualBlock):
92
+ if isinstance(m.conv, MBConvLayer) and isinstance(
93
+ m.shortcut, IdentityLayer
94
+ ):
95
+ m.conv.point_linear.bn.weight.data.zero_()
96
+
97
+ @property
98
+ def grouped_block_index(self):
99
+ info_list = []
100
+ block_index_list = []
101
+ for i, block in enumerate(self.blocks[1:], 1):
102
+ if block.shortcut is None and len(block_index_list) > 0:
103
+ info_list.append(block_index_list)
104
+ block_index_list = []
105
+ block_index_list.append(i)
106
+ if len(block_index_list) > 0:
107
+ info_list.append(block_index_list)
108
+ return info_list
109
+
110
+ @staticmethod
111
+ def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
112
+ # first conv layer
113
+ first_conv = ConvLayer(
114
+ 3,
115
+ input_channel,
116
+ kernel_size=3,
117
+ stride=2,
118
+ use_bn=True,
119
+ act_func="h_swish",
120
+ ops_order="weight_bn_act",
121
+ )
122
+ # build mobile blocks
123
+ feature_dim = input_channel
124
+ blocks = []
125
+ for stage_id, block_config_list in cfg.items():
126
+ for (
127
+ k,
128
+ mid_channel,
129
+ out_channel,
130
+ use_se,
131
+ act_func,
132
+ stride,
133
+ expand_ratio,
134
+ ) in block_config_list:
135
+ mb_conv = MBConvLayer(
136
+ feature_dim,
137
+ out_channel,
138
+ k,
139
+ stride,
140
+ expand_ratio,
141
+ mid_channel,
142
+ act_func,
143
+ use_se,
144
+ )
145
+ if stride == 1 and out_channel == feature_dim:
146
+ shortcut = IdentityLayer(out_channel, out_channel)
147
+ else:
148
+ shortcut = None
149
+ blocks.append(ResidualBlock(mb_conv, shortcut))
150
+ feature_dim = out_channel
151
+ # final expand layer
152
+ final_expand_layer = ConvLayer(
153
+ feature_dim,
154
+ feature_dim * 6,
155
+ kernel_size=1,
156
+ use_bn=True,
157
+ act_func="h_swish",
158
+ ops_order="weight_bn_act",
159
+ )
160
+ # feature mix layer
161
+ feature_mix_layer = ConvLayer(
162
+ feature_dim * 6,
163
+ last_channel,
164
+ kernel_size=1,
165
+ bias=False,
166
+ use_bn=False,
167
+ act_func="h_swish",
168
+ )
169
+ # classifier
170
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
171
+
172
+ return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
173
+
174
+ @staticmethod
175
+ def adjust_cfg(
176
+ cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None
177
+ ):
178
+ for i, (stage_id, block_config_list) in enumerate(cfg.items()):
179
+ for block_config in block_config_list:
180
+ if ks is not None and stage_id != "0":
181
+ block_config[0] = ks
182
+ if expand_ratio is not None and stage_id != "0":
183
+ block_config[-1] = expand_ratio
184
+ block_config[1] = None
185
+ if stage_width_list is not None:
186
+ block_config[2] = stage_width_list[i]
187
+ if depth_param is not None and stage_id != "0":
188
+ new_block_config_list = [block_config_list[0]]
189
+ new_block_config_list += [
190
+ copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)
191
+ ]
192
+ cfg[stage_id] = new_block_config_list
193
+ return cfg
194
+
195
+ def load_state_dict(self, state_dict, **kwargs):
196
+ current_state_dict = self.state_dict()
197
+
198
+ for key in state_dict:
199
+ if key not in current_state_dict:
200
+ assert ".mobile_inverted_conv." in key
201
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
202
+ else:
203
+ new_key = key
204
+ current_state_dict[new_key] = state_dict[key]
205
+ super(MobileNetV3, self).load_state_dict(current_state_dict)
206
+
207
+
208
+ class MobileNetV3Large(MobileNetV3):
209
+ def __init__(
210
+ self,
211
+ n_classes=1000,
212
+ width_mult=1.0,
213
+ bn_param=(0.1, 1e-5),
214
+ dropout_rate=0.2,
215
+ ks=None,
216
+ expand_ratio=None,
217
+ depth_param=None,
218
+ stage_width_list=None,
219
+ ):
220
+ input_channel = 16
221
+ last_channel = 1280
222
+
223
+ input_channel = make_divisible(
224
+ input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
225
+ )
226
+ last_channel = (
227
+ make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
228
+ if width_mult > 1.0
229
+ else last_channel
230
+ )
231
+
232
+ cfg = {
233
+ # k, exp, c, se, nl, s, e,
234
+ "0": [
235
+ [3, 16, 16, False, "relu", 1, 1],
236
+ ],
237
+ "1": [
238
+ [3, 64, 24, False, "relu", 2, None], # 4
239
+ [3, 72, 24, False, "relu", 1, None], # 3
240
+ ],
241
+ "2": [
242
+ [5, 72, 40, True, "relu", 2, None], # 3
243
+ [5, 120, 40, True, "relu", 1, None], # 3
244
+ [5, 120, 40, True, "relu", 1, None], # 3
245
+ ],
246
+ "3": [
247
+ [3, 240, 80, False, "h_swish", 2, None], # 6
248
+ [3, 200, 80, False, "h_swish", 1, None], # 2.5
249
+ [3, 184, 80, False, "h_swish", 1, None], # 2.3
250
+ [3, 184, 80, False, "h_swish", 1, None], # 2.3
251
+ ],
252
+ "4": [
253
+ [3, 480, 112, True, "h_swish", 1, None], # 6
254
+ [3, 672, 112, True, "h_swish", 1, None], # 6
255
+ ],
256
+ "5": [
257
+ [5, 672, 160, True, "h_swish", 2, None], # 6
258
+ [5, 960, 160, True, "h_swish", 1, None], # 6
259
+ [5, 960, 160, True, "h_swish", 1, None], # 6
260
+ ],
261
+ }
262
+
263
+ cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
264
+ # width multiplier on mobile setting, change `exp: 1` and `c: 2`
265
+ for stage_id, block_config_list in cfg.items():
266
+ for block_config in block_config_list:
267
+ if block_config[1] is not None:
268
+ block_config[1] = make_divisible(
269
+ block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
270
+ )
271
+ block_config[2] = make_divisible(
272
+ block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
273
+ )
274
+
275
+ (
276
+ first_conv,
277
+ blocks,
278
+ final_expand_layer,
279
+ feature_mix_layer,
280
+ classifier,
281
+ ) = self.build_net_via_cfg(
282
+ cfg, input_channel, last_channel, n_classes, dropout_rate
283
+ )
284
+ super(MobileNetV3Large, self).__init__(
285
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
286
+ )
287
+ # set bn param
288
+ self.set_bn_param(*bn_param)
289
+
290
+
291
+
292
+ class MobileNetV3_Cifar(MyNetwork):
293
+ def __init__(
294
+ self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
295
+ ):
296
+ super(MobileNetV3_Cifar, self).__init__()
297
+
298
+ self.first_conv = first_conv
299
+ self.blocks = nn.ModuleList(blocks)
300
+ self.final_expand_layer = final_expand_layer
301
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
302
+ self.feature_mix_layer = feature_mix_layer
303
+ self.classifier = classifier
304
+
305
+ def forward(self, x):
306
+ x = self.first_conv(x)
307
+ for block in self.blocks:
308
+ x = block(x)
309
+ x = self.final_expand_layer(x)
310
+ x = self.global_avg_pool(x) # global average pooling
311
+ x = self.feature_mix_layer(x)
312
+ x = x.view(x.size(0), -1)
313
+ x = self.classifier(x)
314
+ return x
315
+
316
+ @property
317
+ def module_str(self):
318
+ _str = self.first_conv.module_str + "\n"
319
+ for block in self.blocks:
320
+ _str += block.module_str + "\n"
321
+ _str += self.final_expand_layer.module_str + "\n"
322
+ _str += self.global_avg_pool.__repr__() + "\n"
323
+ _str += self.feature_mix_layer.module_str + "\n"
324
+ _str += self.classifier.module_str
325
+ return _str
326
+
327
+ @property
328
+ def config(self):
329
+ return {
330
+ "name": MobileNetV3_Cifar.__name__,
331
+ "bn": self.get_bn_param(),
332
+ "first_conv": self.first_conv.config,
333
+ "blocks": [block.config for block in self.blocks],
334
+ "final_expand_layer": self.final_expand_layer.config,
335
+ "feature_mix_layer": self.feature_mix_layer.config,
336
+ "classifier": self.classifier.config,
337
+ }
338
+
339
+ @staticmethod
340
+ def build_from_config(config):
341
+ first_conv = set_layer_from_config(config["first_conv"])
342
+ final_expand_layer = set_layer_from_config(config["final_expand_layer"])
343
+ feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
344
+ classifier = set_layer_from_config(config["classifier"])
345
+
346
+ blocks = []
347
+ for block_config in config["blocks"]:
348
+ blocks.append(ResidualBlock.build_from_config(block_config))
349
+
350
+ net = MobileNetV3_Cifar(
351
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
352
+ )
353
+ if "bn" in config:
354
+ net.set_bn_param(**config["bn"])
355
+ else:
356
+ net.set_bn_param(momentum=0.1, eps=1e-5)
357
+
358
+ return net
359
+
360
+ def zero_last_gamma(self):
361
+ for m in self.modules():
362
+ if isinstance(m, ResidualBlock):
363
+ if isinstance(m.conv, MBConvLayer) and isinstance(
364
+ m.shortcut, IdentityLayer
365
+ ):
366
+ m.conv.point_linear.bn.weight.data.zero_()
367
+
368
+ @property
369
+ def grouped_block_index(self):
370
+ info_list = []
371
+ block_index_list = []
372
+ for i, block in enumerate(self.blocks[1:], 1):
373
+ if block.shortcut is None and len(block_index_list) > 0:
374
+ info_list.append(block_index_list)
375
+ block_index_list = []
376
+ block_index_list.append(i)
377
+ if len(block_index_list) > 0:
378
+ info_list.append(block_index_list)
379
+ return info_list
380
+
381
+ @staticmethod
382
+ def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
383
+ # first conv layer
384
+ first_conv = ConvLayer(
385
+ 3,
386
+ input_channel,
387
+ kernel_size=3,
388
+ stride=1,
389
+ use_bn=True,
390
+ act_func="h_swish",
391
+ ops_order="weight_bn_act",
392
+ )
393
+ # build mobile blocks
394
+ feature_dim = input_channel
395
+ blocks = []
396
+ for stage_id, block_config_list in cfg.items():
397
+ for (
398
+ k,
399
+ mid_channel,
400
+ out_channel,
401
+ use_se,
402
+ act_func,
403
+ stride,
404
+ expand_ratio,
405
+ ) in block_config_list:
406
+ mb_conv = MBConvLayer(
407
+ feature_dim,
408
+ out_channel,
409
+ k,
410
+ stride,
411
+ expand_ratio,
412
+ mid_channel,
413
+ act_func,
414
+ use_se,
415
+ )
416
+ if stride == 1 and out_channel == feature_dim:
417
+ shortcut = IdentityLayer(out_channel, out_channel)
418
+ else:
419
+ shortcut = None
420
+ blocks.append(ResidualBlock(mb_conv, shortcut))
421
+ feature_dim = out_channel
422
+ # final expand layer
423
+ final_expand_layer = ConvLayer(
424
+ feature_dim,
425
+ feature_dim * 6,
426
+ kernel_size=1,
427
+ use_bn=True,
428
+ act_func="h_swish",
429
+ ops_order="weight_bn_act",
430
+ )
431
+ # feature mix layer
432
+ feature_mix_layer = ConvLayer(
433
+ feature_dim * 6,
434
+ last_channel,
435
+ kernel_size=1,
436
+ bias=False,
437
+ use_bn=False,
438
+ act_func="h_swish",
439
+ )
440
+ # classifier
441
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
442
+
443
+ return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
444
+
445
+ @staticmethod
446
+ def adjust_cfg(
447
+ cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None
448
+ ):
449
+ for i, (stage_id, block_config_list) in enumerate(cfg.items()):
450
+ for block_config in block_config_list:
451
+ if ks is not None and stage_id != "0":
452
+ block_config[0] = ks
453
+ if expand_ratio is not None and stage_id != "0":
454
+ block_config[-1] = expand_ratio
455
+ block_config[1] = None
456
+ if stage_width_list is not None:
457
+ block_config[2] = stage_width_list[i]
458
+ if depth_param is not None and stage_id != "0":
459
+ new_block_config_list = [block_config_list[0]]
460
+ new_block_config_list += [
461
+ copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)
462
+ ]
463
+ cfg[stage_id] = new_block_config_list
464
+ return cfg
465
+
466
+ def load_state_dict(self, state_dict, **kwargs):
467
+ current_state_dict = self.state_dict()
468
+
469
+ for key in state_dict:
470
+ if key not in current_state_dict:
471
+ assert ".mobile_inverted_conv." in key
472
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
473
+ else:
474
+ new_key = key
475
+ current_state_dict[new_key] = state_dict[key]
476
+ super(MobileNetV3_Cifar, self).load_state_dict(current_state_dict)
477
+
478
+
479
+ class MobileNetV3Large_Cifar(MobileNetV3_Cifar):
480
+ def __init__(
481
+ self,
482
+ n_classes=10,
483
+ width_mult=1.0,
484
+ bn_param=(0.1, 1e-5),
485
+ dropout_rate=0.2,
486
+ ks=None,
487
+ expand_ratio=None,
488
+ depth_param=None,
489
+ stage_width_list=None,
490
+ ):
491
+ input_channel = 16
492
+ last_channel = 1280
493
+
494
+ input_channel = make_divisible(
495
+ input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
496
+ )
497
+ last_channel = (
498
+ make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
499
+ if width_mult > 1.0
500
+ else last_channel
501
+ )
502
+
503
+ cfg = {
504
+ # k, exp, c, se, nl, s, e,
505
+ "0": [
506
+ [3, 16, 16, False, "relu", 1, 1],
507
+ ],
508
+ "1": [
509
+ [3, 64, 24, False, "relu", 1, None], # 4
510
+ [3, 72, 24, False, "relu", 1, None], # 3
511
+ ],
512
+ "2": [
513
+ [5, 72, 40, True, "relu", 2, None], # 3
514
+ [5, 120, 40, True, "relu", 1, None], # 3
515
+ [5, 120, 40, True, "relu", 1, None], # 3
516
+ ],
517
+ "3": [
518
+ [3, 240, 80, False, "h_swish", 2, None], # 6
519
+ [3, 200, 80, False, "h_swish", 1, None], # 2.5
520
+ [3, 184, 80, False, "h_swish", 1, None], # 2.3
521
+ [3, 184, 80, False, "h_swish", 1, None], # 2.3
522
+ ],
523
+ "4": [
524
+ [3, 480, 112, True, "h_swish", 1, None], # 6
525
+ [3, 672, 112, True, "h_swish", 1, None], # 6
526
+ ],
527
+ "5": [
528
+ [5, 672, 160, True, "h_swish", 2, None], # 6
529
+ [5, 960, 160, True, "h_swish", 1, None], # 6
530
+ [5, 960, 160, True, "h_swish", 1, None], # 6
531
+ ],
532
+ }
533
+
534
+ cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
535
+ # width multiplier on mobile setting, change `exp: 1` and `c: 2`
536
+ for stage_id, block_config_list in cfg.items():
537
+ for block_config in block_config_list:
538
+ if block_config[1] is not None:
539
+ block_config[1] = make_divisible(
540
+ block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
541
+ )
542
+ block_config[2] = make_divisible(
543
+ block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE
544
+ )
545
+
546
+ (
547
+ first_conv,
548
+ blocks,
549
+ final_expand_layer,
550
+ feature_mix_layer,
551
+ classifier,
552
+ ) = self.build_net_via_cfg(
553
+ cfg, input_channel, last_channel, n_classes, dropout_rate
554
+ )
555
+ super(MobileNetV3Large_Cifar, self).__init__(
556
+ first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
557
+ )
558
+ # set bn param
559
+ self.set_bn_param(*bn_param)
proard/classification/networks/proxyless_nets.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import json
6
+ import torch.nn as nn
7
+
8
+ from proard.utils.layers import (
9
+ set_layer_from_config,
10
+ MBConvLayer,
11
+ ConvLayer,
12
+ IdentityLayer,
13
+ LinearLayer,
14
+ ResidualBlock,
15
+ )
16
+ from proard.utils import (
17
+ download_url,
18
+ make_divisible,
19
+ val2list,
20
+ MyNetwork,
21
+ MyGlobalAvgPool2d,
22
+ )
23
+
24
+ __all__ = ["proxyless_base_cifar","proxyless_base", "ProxylessNASNets", "MobileNetV2", "ProxylessNASNets_Cifar", "MobileNetV2_Cifar"]
25
+
26
+
27
+ def proxyless_base(
28
+ net_config=None,
29
+ n_classes=None,
30
+ bn_param=None,
31
+ dropout_rate=None,
32
+ local_path="~/.torch/proxylessnas/",
33
+ ):
34
+ assert net_config is not None, "Please input a network config"
35
+ if "http" in net_config:
36
+ net_config_path = download_url(net_config, local_path)
37
+ else:
38
+ net_config_path = net_config
39
+ net_config_json = json.load(open(net_config_path, "r"))
40
+
41
+ if n_classes is not None:
42
+ net_config_json["classifier"]["out_features"] = n_classes
43
+ if dropout_rate is not None:
44
+ net_config_json["classifier"]["dropout_rate"] = dropout_rate
45
+
46
+ net = ProxylessNASNets.build_from_config(net_config_json)
47
+ if bn_param is not None:
48
+ net.set_bn_param(*bn_param)
49
+
50
+ return net
51
+
52
+
53
+ class ProxylessNASNets(MyNetwork):
54
+ def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
55
+ super(ProxylessNASNets, self).__init__()
56
+
57
+ self.first_conv = first_conv
58
+ self.blocks = nn.ModuleList(blocks)
59
+ self.feature_mix_layer = feature_mix_layer
60
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
61
+ self.classifier = classifier
62
+
63
+ def forward(self, x):
64
+ x = self.first_conv(x)
65
+ for block in self.blocks:
66
+ x = block(x)
67
+ if self.feature_mix_layer is not None:
68
+ x = self.feature_mix_layer(x)
69
+ x = self.global_avg_pool(x)
70
+ x = self.classifier(x)
71
+ return x
72
+
73
+ @property
74
+ def module_str(self):
75
+ _str = self.first_conv.module_str + "\n"
76
+ for block in self.blocks:
77
+ _str += block.module_str + "\n"
78
+ _str += self.feature_mix_layer.module_str + "\n"
79
+ _str += self.global_avg_pool.__repr__() + "\n"
80
+ _str += self.classifier.module_str
81
+ return _str
82
+
83
+ @property
84
+ def config(self):
85
+ return {
86
+ "name": ProxylessNASNets.__name__,
87
+ "bn": self.get_bn_param(),
88
+ "first_conv": self.first_conv.config,
89
+ "blocks": [block.config for block in self.blocks],
90
+ "feature_mix_layer": None
91
+ if self.feature_mix_layer is None
92
+ else self.feature_mix_layer.config,
93
+ "classifier": self.classifier.config,
94
+ }
95
+
96
+ @staticmethod
97
+ def build_from_config(config):
98
+ first_conv = set_layer_from_config(config["first_conv"])
99
+ feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
100
+ classifier = set_layer_from_config(config["classifier"])
101
+
102
+ blocks = []
103
+ for block_config in config["blocks"]:
104
+ blocks.append(ResidualBlock.build_from_config(block_config))
105
+
106
+ net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
107
+ if "bn" in config:
108
+ net.set_bn_param(**config["bn"])
109
+ else:
110
+ net.set_bn_param(momentum=0.1, eps=1e-3)
111
+
112
+ return net
113
+
114
+ def zero_last_gamma(self):
115
+ for m in self.modules():
116
+ if isinstance(m, ResidualBlock):
117
+ if isinstance(m.conv, MBConvLayer) and isinstance(
118
+ m.shortcut, IdentityLayer
119
+ ):
120
+ m.conv.point_linear.bn.weight.data.zero_()
121
+
122
+ @property
123
+ def grouped_block_index(self):
124
+ info_list = []
125
+ block_index_list = []
126
+ for i, block in enumerate(self.blocks[1:], 1):
127
+ if block.shortcut is None and len(block_index_list) > 0:
128
+ info_list.append(block_index_list)
129
+ block_index_list = []
130
+ block_index_list.append(i)
131
+ if len(block_index_list) > 0:
132
+ info_list.append(block_index_list)
133
+ return info_list
134
+
135
+ def load_state_dict(self, state_dict, **kwargs):
136
+ current_state_dict = self.state_dict()
137
+
138
+ for key in state_dict:
139
+ if key not in current_state_dict:
140
+ assert ".mobile_inverted_conv." in key
141
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
142
+ else:
143
+ new_key = key
144
+ current_state_dict[new_key] = state_dict[key]
145
+ super(ProxylessNASNets, self).load_state_dict(current_state_dict)
146
+
147
+
148
+ class MobileNetV2(ProxylessNASNets):
149
+ def __init__(
150
+ self,
151
+ n_classes=1000,
152
+ width_mult=1.0,
153
+ bn_param=(0.1, 1e-3),
154
+ dropout_rate=0.2,
155
+ ks=None,
156
+ expand_ratio=None,
157
+ depth_param=None,
158
+ stage_width_list=None,
159
+ ):
160
+
161
+ ks = 3 if ks is None else ks
162
+ expand_ratio = 6 if expand_ratio is None else expand_ratio
163
+
164
+ input_channel = 32
165
+ last_channel = 1280
166
+
167
+ input_channel = make_divisible(
168
+ input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
169
+ )
170
+ last_channel = (
171
+ make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
172
+ if width_mult > 1.0
173
+ else last_channel
174
+ )
175
+
176
+ inverted_residual_setting = [
177
+ # t, c, n, s
178
+ [1, 16, 1, 1],
179
+ [expand_ratio, 24, 2, 2],
180
+ [expand_ratio, 32, 3, 2],
181
+ [expand_ratio, 64, 4, 2],
182
+ [expand_ratio, 96, 3, 1],
183
+ [expand_ratio, 160, 3, 2],
184
+ [expand_ratio, 320, 1, 1],
185
+ ]
186
+
187
+ if depth_param is not None:
188
+ assert isinstance(depth_param, int)
189
+ for i in range(1, len(inverted_residual_setting) - 1):
190
+ inverted_residual_setting[i][2] = depth_param
191
+
192
+ if stage_width_list is not None:
193
+ for i in range(len(inverted_residual_setting)):
194
+ inverted_residual_setting[i][1] = stage_width_list[i]
195
+
196
+ ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
197
+ _pt = 0
198
+
199
+ # first conv layer
200
+ first_conv = ConvLayer(
201
+ 3,
202
+ input_channel,
203
+ kernel_size=3,
204
+ stride=2,
205
+ use_bn=True,
206
+ act_func="relu6",
207
+ ops_order="weight_bn_act",
208
+ )
209
+ # inverted residual blocks
210
+ blocks = []
211
+ for t, c, n, s in inverted_residual_setting:
212
+ output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
213
+ for i in range(n):
214
+ if i == 0:
215
+ stride = s
216
+ else:
217
+ stride = 1
218
+ if t == 1:
219
+ kernel_size = 3
220
+ else:
221
+ kernel_size = ks[_pt]
222
+ _pt += 1
223
+ mobile_inverted_conv = MBConvLayer(
224
+ in_channels=input_channel,
225
+ out_channels=output_channel,
226
+ kernel_size=kernel_size,
227
+ stride=stride,
228
+ expand_ratio=t,
229
+ )
230
+ if stride == 1:
231
+ if input_channel == output_channel:
232
+ shortcut = IdentityLayer(input_channel, input_channel)
233
+ else:
234
+ shortcut = ConvLayer(input_channel,output_channel,kernel_size=1,stride=1,bias=False,act_func=None)
235
+ else:
236
+ shortcut = None
237
+ blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
238
+ input_channel = output_channel
239
+ # 1x1_conv before global average pooling
240
+ feature_mix_layer = ConvLayer(
241
+ input_channel,
242
+ last_channel,
243
+ kernel_size=1,
244
+ use_bn=True,
245
+ act_func="relu6",
246
+ ops_order="weight_bn_act",
247
+ )
248
+
249
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
250
+
251
+ super(MobileNetV2, self).__init__(
252
+ first_conv, blocks, feature_mix_layer, classifier
253
+ )
254
+
255
+ # set bn param
256
+ self.set_bn_param(*bn_param)
257
+
258
+
259
+
260
+ def proxyless_base_cifar(
261
+ net_config=None,
262
+ n_classes=None,
263
+ bn_param=None,
264
+ dropout_rate=None,
265
+ local_path="~/.torch/proxylessnas/",
266
+ ):
267
+ assert net_config is not None, "Please input a network config"
268
+ if "http" in net_config:
269
+ net_config_path = download_url(net_config, local_path)
270
+ else:
271
+ net_config_path = net_config
272
+ net_config_json = json.load(open(net_config_path, "r"))
273
+
274
+ if n_classes is not None:
275
+ net_config_json["classifier"]["out_features"] = n_classes
276
+ if dropout_rate is not None:
277
+ net_config_json["classifier"]["dropout_rate"] = dropout_rate
278
+
279
+ net = ProxylessNASNets_Cifar.build_from_config(net_config_json)
280
+ if bn_param is not None:
281
+ net.set_bn_param(*bn_param)
282
+
283
+ return net
284
+
285
+
286
+ class ProxylessNASNets_Cifar(MyNetwork):
287
+ def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
288
+ super(ProxylessNASNets_Cifar, self).__init__()
289
+
290
+ self.first_conv = first_conv
291
+ self.blocks = nn.ModuleList(blocks)
292
+ self.feature_mix_layer = feature_mix_layer
293
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
294
+ self.classifier = classifier
295
+
296
+ def forward(self, x):
297
+ x = self.first_conv(x)
298
+ for block in self.blocks:
299
+ x = block(x)
300
+ if self.feature_mix_layer is not None:
301
+ x = self.feature_mix_layer(x)
302
+ x = self.global_avg_pool(x)
303
+ x = self.classifier(x)
304
+ return x
305
+
306
+ @property
307
+ def module_str(self):
308
+ _str = self.first_conv.module_str + "\n"
309
+ for block in self.blocks:
310
+ _str += block.module_str + "\n"
311
+ _str += self.feature_mix_layer.module_str + "\n"
312
+ _str += self.global_avg_pool.__repr__() + "\n"
313
+ _str += self.classifier.module_str
314
+ return _str
315
+
316
+ @property
317
+ def config(self):
318
+ return {
319
+ "name": ProxylessNASNets_Cifar.__name__,
320
+ "bn": self.get_bn_param(),
321
+ "first_conv": self.first_conv.config,
322
+ "blocks": [block.config for block in self.blocks],
323
+ "feature_mix_layer": None
324
+ if self.feature_mix_layer is None
325
+ else self.feature_mix_layer.config,
326
+ "classifier": self.classifier.config,
327
+ }
328
+
329
+ @staticmethod
330
+ def build_from_config(config):
331
+ first_conv = set_layer_from_config(config["first_conv"])
332
+ feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
333
+ classifier = set_layer_from_config(config["classifier"])
334
+
335
+ blocks = []
336
+ for block_config in config["blocks"]:
337
+ blocks.append(ResidualBlock.build_from_config(block_config))
338
+
339
+ net = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier)
340
+ if "bn" in config:
341
+ net.set_bn_param(**config["bn"])
342
+ else:
343
+ net.set_bn_param(momentum=0.1, eps=1e-3)
344
+
345
+ return net
346
+
347
+ def zero_last_gamma(self):
348
+ for m in self.modules():
349
+ if isinstance(m, ResidualBlock):
350
+ if isinstance(m.conv, MBConvLayer) and isinstance(
351
+ m.shortcut, IdentityLayer
352
+ ):
353
+ m.conv.point_linear.bn.weight.data.zero_()
354
+
355
+ @property
356
+ def grouped_block_index(self):
357
+ info_list = []
358
+ block_index_list = []
359
+ for i, block in enumerate(self.blocks[1:], 1):
360
+ if block.shortcut is None and len(block_index_list) > 0:
361
+ info_list.append(block_index_list)
362
+ block_index_list = []
363
+ block_index_list.append(i)
364
+ if len(block_index_list) > 0:
365
+ info_list.append(block_index_list)
366
+ return info_list
367
+
368
+ def load_state_dict(self, state_dict, **kwargs):
369
+ current_state_dict = self.state_dict()
370
+
371
+ for key in state_dict:
372
+ if key not in current_state_dict:
373
+ assert ".mobile_inverted_conv." in key
374
+ new_key = key.replace(".mobile_inverted_conv.", ".conv.")
375
+ else:
376
+ new_key = key
377
+ current_state_dict[new_key] = state_dict[key]
378
+ super(ProxylessNASNets_Cifar, self).load_state_dict(current_state_dict)
379
+
380
+
381
+ class MobileNetV2_Cifar(ProxylessNASNets_Cifar):
382
+ def __init__(
383
+ self,
384
+ n_classes=10,
385
+ width_mult=1.0,
386
+ bn_param=(0.1, 1e-3),
387
+ dropout_rate=0.2,
388
+ ks=None,
389
+ expand_ratio=None,
390
+ depth_param=None,
391
+ stage_width_list=None,
392
+ ):
393
+
394
+ ks = 3 if ks is None else ks
395
+ expand_ratio = 6 if expand_ratio is None else expand_ratio
396
+
397
+ input_channel = 32
398
+ last_channel = 1280
399
+
400
+ input_channel = make_divisible(
401
+ input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
402
+ )
403
+ last_channel = (
404
+ make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
405
+ if width_mult > 1.0
406
+ else last_channel
407
+ )
408
+
409
+ inverted_residual_setting = [
410
+ # t, c, n, s
411
+ [1, 16, 1, 1],
412
+ [expand_ratio, 24, 2, 1],
413
+ [expand_ratio, 32, 3, 2],
414
+ [expand_ratio, 64, 4, 2],
415
+ [expand_ratio, 96, 3, 1],
416
+ [expand_ratio, 160, 3, 2],
417
+ [expand_ratio, 320, 1, 1],
418
+ ]
419
+
420
+ if depth_param is not None:
421
+ assert isinstance(depth_param, int)
422
+ for i in range(1, len(inverted_residual_setting) - 1):
423
+ inverted_residual_setting[i][2] = depth_param
424
+
425
+ if stage_width_list is not None:
426
+ for i in range(len(inverted_residual_setting)):
427
+ inverted_residual_setting[i][1] = stage_width_list[i]
428
+
429
+ ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
430
+ _pt = 0
431
+
432
+ # first conv layer
433
+ first_conv = ConvLayer(
434
+ 3,
435
+ input_channel,
436
+ kernel_size=3,
437
+ stride=1,
438
+ use_bn=True,
439
+ act_func="relu6",
440
+ ops_order="weight_bn_act",
441
+ )
442
+ # inverted residual blocks
443
+ blocks = []
444
+ for t, c, n, s in inverted_residual_setting:
445
+ output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
446
+ for i in range(n):
447
+ if i == 0:
448
+ stride = s
449
+ else:
450
+ stride = 1
451
+ if t == 1:
452
+ kernel_size = 3
453
+ else:
454
+ kernel_size = ks[_pt]
455
+ _pt += 1
456
+ mobile_inverted_conv = MBConvLayer(
457
+ in_channels=input_channel,
458
+ out_channels=output_channel,
459
+ kernel_size=kernel_size,
460
+ stride=stride,
461
+ expand_ratio=t,
462
+ )
463
+ if stride == 1:
464
+ if input_channel == output_channel:
465
+ shortcut = IdentityLayer(input_channel, input_channel)
466
+ else:
467
+ shortcut = None #ConvLayer(input_channel,output_channel,kernel_size=1,stride=1,bias=False,act_func=None)
468
+ else:
469
+ shortcut = None
470
+ blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
471
+ input_channel = output_channel
472
+ # 1x1_conv before global average pooling
473
+ feature_mix_layer = ConvLayer(
474
+ input_channel,
475
+ last_channel,
476
+ kernel_size=1,
477
+ stride=1,
478
+ use_bn=True,
479
+ act_func="relu6",
480
+ ops_order="weight_bn_act",
481
+ )
482
+
483
+ classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
484
+
485
+ super(MobileNetV2_Cifar, self).__init__(
486
+ first_conv, blocks, feature_mix_layer, classifier
487
+ )
488
+
489
+ # set bn param
490
+ self.set_bn_param(*bn_param)
proard/classification/networks/resnet_trades.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicBlock(nn.Module):
7
+ expansion = 1
8
+
9
+ def __init__(self, in_planes, planes, stride=1):
10
+ super(BasicBlock, self).__init__()
11
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
12
+ self.bn1 = nn.BatchNorm2d(planes)
13
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14
+ self.bn2 = nn.BatchNorm2d(planes)
15
+
16
+ self.shortcut = nn.Sequential()
17
+ if stride != 1 or in_planes != self.expansion * planes:
18
+ self.shortcut = nn.Sequential(
19
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
20
+ nn.BatchNorm2d(self.expansion * planes)
21
+ )
22
+
23
+ def forward(self, x):
24
+ out = F.relu(self.bn1(self.conv1(x)))
25
+ out = self.bn2(self.conv2(out))
26
+ out += self.shortcut(x)
27
+ out = F.relu(out)
28
+ return out
29
+
30
+
31
+ class Bottleneck(nn.Module):
32
+ expansion = 4
33
+
34
+ def __init__(self, in_planes, planes, stride=1):
35
+ super(Bottleneck, self).__init__()
36
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(planes)
38
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
39
+ self.bn2 = nn.BatchNorm2d(planes)
40
+ self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
41
+ self.bn3 = nn.BatchNorm2d(self.expansion * planes)
42
+
43
+ self.shortcut = nn.Sequential()
44
+ if stride != 1 or in_planes != self.expansion * planes:
45
+ self.shortcut = nn.Sequential(
46
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
47
+ nn.BatchNorm2d(self.expansion * planes)
48
+ )
49
+
50
+ def forward(self, x):
51
+ out = F.relu(self.bn1(self.conv1(x)))
52
+ out = F.relu(self.bn2(self.conv2(out)))
53
+ out = self.bn3(self.conv3(out))
54
+ out += self.shortcut(x)
55
+ out = F.relu(out)
56
+ return out
57
+
58
+ from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d
59
+ class ResNet(MyNetwork):
60
+ def __init__(self, block, num_blocks, num_classes=10):
61
+ super(ResNet, self).__init__()
62
+ self.in_planes = 64
63
+
64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(64)
66
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
67
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
68
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
69
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
70
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
71
+
72
+ def _make_layer(self, block, planes, num_blocks, stride):
73
+ strides = [stride] + [1] * (num_blocks - 1)
74
+ layers = []
75
+ for stride in strides:
76
+ layers.append(block(self.in_planes, planes, stride))
77
+ self.in_planes = planes * block.expansion
78
+ return nn.Sequential(*layers)
79
+
80
+ def forward(self, x):
81
+ out = F.relu(self.bn1(self.conv1(x)))
82
+ out = self.layer1(out)
83
+ out = self.layer2(out)
84
+ out = self.layer3(out)
85
+ out = self.layer4(out)
86
+ out = F.avg_pool2d(out, 4)
87
+ out = out.view(out.size(0), -1)
88
+ out = self.linear(out)
89
+ return out
90
+
91
+
92
+ def ResNet18_trades():
93
+ return ResNet(BasicBlock, [2, 2, 2, 2])
94
+
95
+
96
+ def ResNet34_trades():
97
+ return ResNet(BasicBlock, [3, 4, 6, 3])
98
+
99
+
100
+ def ResNet50_trades():
101
+ return ResNet(Bottleneck, [3, 4, 6, 3])
102
+
103
+
104
+ def ResNet101_trades():
105
+ return ResNet(Bottleneck, [3, 4, 23, 3])
106
+
107
+
108
+ def ResNet152_trades():
109
+ return ResNet(Bottleneck, [3, 8, 36, 3])
110
+
111
+
112
+ def test():
113
+ net = ResNet18_trades()
114
+ y = net(torch.randn(1, 3, 32, 32))
115
+ print(y.size())
proard/classification/networks/resnets.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from proard.utils.layers import (
4
+ set_layer_from_config,
5
+ ConvLayer,
6
+ IdentityLayer,
7
+ LinearLayer,
8
+ )
9
+ from proard.utils.layers import ResNetBottleneckBlock, ResidualBlock
10
+ from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d
11
+
12
+ __all__ = ["ResNets", "ResNet50", "ResNet50D","ResNets_Cifar","ResNet50_Cifar", "ResNet50D_Cifar"]
13
+
14
+
15
+ class ResNets(MyNetwork):
16
+ BASE_DEPTH_LIST = [2, 2, 4, 2]
17
+ STAGE_WIDTH_LIST = [256, 512, 1024, 2048]
18
+
19
+ def __init__(self, input_stem, blocks, classifier):
20
+ super(ResNets, self).__init__()
21
+
22
+ self.input_stem = nn.ModuleList(input_stem)
23
+ self.max_pooling = nn.MaxPool2d(
24
+ kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
25
+ )
26
+ self.blocks = nn.ModuleList(blocks)
27
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
28
+ self.classifier = classifier
29
+
30
+ def forward(self, x):
31
+ for layer in self.input_stem:
32
+ x = layer(x)
33
+ x = self.max_pooling(x)
34
+ for block in self.blocks:
35
+ x = block(x)
36
+ x = self.global_avg_pool(x)
37
+ x = self.classifier(x)
38
+ return x
39
+
40
+ @property
41
+ def module_str(self):
42
+ _str = ""
43
+ for layer in self.input_stem:
44
+ _str += layer.module_str + "\n"
45
+ _str += "max_pooling(ks=3, stride=2)\n"
46
+ for block in self.blocks:
47
+ _str += block.module_str + "\n"
48
+ _str += self.global_avg_pool.__repr__() + "\n"
49
+ _str += self.classifier.module_str
50
+ return _str
51
+
52
+ @property
53
+ def config(self):
54
+ return {
55
+ "name": ResNets.__name__,
56
+ "bn": self.get_bn_param(),
57
+ "input_stem": [layer.config for layer in self.input_stem],
58
+ "blocks": [block.config for block in self.blocks],
59
+ "classifier": self.classifier.config,
60
+ }
61
+
62
+ @staticmethod
63
+ def build_from_config(config):
64
+ classifier = set_layer_from_config(config["classifier"])
65
+
66
+ input_stem = []
67
+ for layer_config in config["input_stem"]:
68
+ input_stem.append(set_layer_from_config(layer_config))
69
+ blocks = []
70
+ for block_config in config["blocks"]:
71
+ blocks.append(set_layer_from_config(block_config))
72
+
73
+ net = ResNets(input_stem, blocks, classifier)
74
+ if "bn" in config:
75
+ net.set_bn_param(**config["bn"])
76
+ else:
77
+ net.set_bn_param(momentum=0.1, eps=1e-5)
78
+
79
+ return net
80
+
81
+ def zero_last_gamma(self):
82
+ for m in self.modules():
83
+ if isinstance(m, ResNetBottleneckBlock) and isinstance(
84
+ m.downsample, IdentityLayer
85
+ ):
86
+ m.conv3.bn.weight.data.zero_()
87
+
88
+ @property
89
+ def grouped_block_index(self):
90
+ info_list = []
91
+ block_index_list = []
92
+ for i, block in enumerate(self.blocks):
93
+ if (
94
+ not isinstance(block.downsample, IdentityLayer)
95
+ and len(block_index_list) > 0
96
+ ):
97
+ info_list.append(block_index_list)
98
+ block_index_list = []
99
+ block_index_list.append(i)
100
+ if len(block_index_list) > 0:
101
+ info_list.append(block_index_list)
102
+ return info_list
103
+
104
+ def load_state_dict(self, state_dict, **kwargs):
105
+ super(ResNets, self).load_state_dict(state_dict)
106
+
107
+
108
+ class ResNet50(ResNets):
109
+ def __init__(
110
+ self,
111
+ n_classes=1000,
112
+ width_mult=1.0,
113
+ bn_param=(0.1, 1e-5),
114
+ dropout_rate=0,
115
+ expand_ratio=None,
116
+ depth_param=None,
117
+ ):
118
+
119
+ expand_ratio = 0.25 if expand_ratio is None else expand_ratio
120
+
121
+ input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
122
+ stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
123
+ for i, width in enumerate(stage_width_list):
124
+ stage_width_list[i] = make_divisible(
125
+ width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
126
+ )
127
+
128
+ depth_list = [3, 4, 6, 3]
129
+ if depth_param is not None:
130
+ for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
131
+ depth_list[i] = depth + depth_param
132
+
133
+ stride_list = [1, 2, 2, 2]
134
+
135
+ # build input stem
136
+ input_stem = [
137
+ ConvLayer(
138
+ 3,
139
+ input_channel,
140
+ kernel_size=7,
141
+ stride=2,
142
+ use_bn=True,
143
+ act_func="relu",
144
+ ops_order="weight_bn_act",
145
+ )
146
+ ]
147
+
148
+ # blocks
149
+ blocks = []
150
+ for d, width, s in zip(depth_list, stage_width_list, stride_list):
151
+ for i in range(d):
152
+ stride = s if i == 0 else 1
153
+ bottleneck_block = ResNetBottleneckBlock(
154
+ input_channel,
155
+ width,
156
+ kernel_size=3,
157
+ stride=stride,
158
+ expand_ratio=expand_ratio,
159
+ act_func="relu",
160
+ downsample_mode="conv",
161
+ )
162
+ blocks.append(bottleneck_block)
163
+ input_channel = width
164
+ # classifier
165
+ classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
166
+
167
+ super(ResNet50, self).__init__(input_stem, blocks, classifier)
168
+
169
+ # set bn param
170
+ self.set_bn_param(*bn_param)
171
+
172
+
173
+ class ResNet50D(ResNets):
174
+ def __init__(
175
+ self,
176
+ n_classes=1000,
177
+ width_mult=1.0,
178
+ bn_param=(0.1, 1e-5),
179
+ dropout_rate=0,
180
+ expand_ratio=None,
181
+ depth_param=None,
182
+ ):
183
+
184
+ expand_ratio = 0.25 if expand_ratio is None else expand_ratio
185
+
186
+ input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
187
+ mid_input_channel = make_divisible(
188
+ input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE
189
+ )
190
+ stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
191
+ for i, width in enumerate(stage_width_list):
192
+ stage_width_list[i] = make_divisible(
193
+ width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
194
+ )
195
+
196
+ depth_list = [3, 4, 6, 3]
197
+ if depth_param is not None:
198
+ for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
199
+ depth_list[i] = depth + depth_param
200
+
201
+ stride_list = [1, 2, 2, 2]
202
+
203
+ # build input stem
204
+ input_stem = [
205
+ ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func="relu"),
206
+ ResidualBlock(
207
+ ConvLayer(
208
+ mid_input_channel,
209
+ mid_input_channel,
210
+ 3,
211
+ stride=1,
212
+ use_bn=True,
213
+ act_func="relu",
214
+ ),
215
+ IdentityLayer(mid_input_channel, mid_input_channel),
216
+ ),
217
+ ConvLayer(
218
+ mid_input_channel,
219
+ input_channel,
220
+ 3,
221
+ stride=1,
222
+ use_bn=True,
223
+ act_func="relu",
224
+ ),
225
+ ]
226
+
227
+ # blocks
228
+ blocks = []
229
+ for d, width, s in zip(depth_list, stage_width_list, stride_list):
230
+ for i in range(d):
231
+ stride = s if i == 0 else 1
232
+ bottleneck_block = ResNetBottleneckBlock(
233
+ input_channel,
234
+ width,
235
+ kernel_size=3,
236
+ stride=stride,
237
+ expand_ratio=expand_ratio,
238
+ act_func="relu",
239
+ downsample_mode="avgpool_conv",
240
+ )
241
+ blocks.append(bottleneck_block)
242
+ input_channel = width
243
+ # classifier
244
+ classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
245
+
246
+ super(ResNet50D, self).__init__(input_stem, blocks, classifier)
247
+
248
+ # set bn param
249
+ self.set_bn_param(*bn_param)
250
+
251
+
252
+
253
+ class ResNets_Cifar(MyNetwork):
254
+
255
+ BASE_DEPTH_LIST = [2, 2, 4, 2]
256
+ STAGE_WIDTH_LIST = [256, 512, 1024, 2048]
257
+
258
+ def __init__(self, input_stem, blocks, classifier):
259
+ super(ResNets_Cifar, self).__init__()
260
+
261
+ self.input_stem = nn.ModuleList(input_stem)
262
+ self.blocks = nn.ModuleList(blocks)
263
+ self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
264
+ self.classifier = classifier
265
+
266
+ def forward(self, x):
267
+ for layer in self.input_stem:
268
+ x = layer(x)
269
+ for block in self.blocks:
270
+ x = block(x)
271
+ x = self.global_avg_pool(x)
272
+ x = self.classifier(x)
273
+ return x
274
+
275
+ @property
276
+ def module_str(self):
277
+ _str = ""
278
+ for layer in self.input_stem:
279
+ _str += layer.module_str + "\n"
280
+ # _str += "max_pooling(ks=3, stride=2)\n"
281
+ for block in self.blocks:
282
+ _str += block.module_str + "\n"
283
+ _str += self.global_avg_pool.__repr__() + "\n"
284
+ _str += self.classifier.module_str
285
+ return _str
286
+
287
+ @property
288
+ def config(self):
289
+ return {
290
+ "name": ResNets_Cifar.__name__,
291
+ "bn": self.get_bn_param(),
292
+ "input_stem": [layer.config for layer in self.input_stem],
293
+ "blocks": [block.config for block in self.blocks],
294
+ "classifier": self.classifier.config,
295
+ }
296
+
297
+ @staticmethod
298
+ def build_from_config(config):
299
+ classifier = set_layer_from_config(config["classifier"])
300
+
301
+ input_stem = []
302
+ for layer_config in config["input_stem"]:
303
+ input_stem.append(set_layer_from_config(layer_config))
304
+ blocks = []
305
+ for block_config in config["blocks"]:
306
+ blocks.append(set_layer_from_config(block_config))
307
+
308
+ net = ResNets(input_stem, blocks, classifier)
309
+ if "bn" in config:
310
+ net.set_bn_param(**config["bn"])
311
+ else:
312
+ net.set_bn_param(momentum=0.1, eps=1e-5)
313
+
314
+ return net
315
+
316
+ def zero_last_gamma(self):
317
+ for m in self.modules():
318
+ if isinstance(m, ResNetBottleneckBlock) and isinstance(
319
+ m.downsample, IdentityLayer
320
+ ):
321
+ m.conv3.bn.weight.data.zero_()
322
+
323
+ @property
324
+ def grouped_block_index(self):
325
+ info_list = []
326
+ block_index_list = []
327
+ for i, block in enumerate(self.blocks):
328
+ if (
329
+ not isinstance(block.downsample, IdentityLayer)
330
+ and len(block_index_list) > 0
331
+ ):
332
+ info_list.append(block_index_list)
333
+ block_index_list = []
334
+ block_index_list.append(i)
335
+ if len(block_index_list) > 0:
336
+ info_list.append(block_index_list)
337
+ return info_list
338
+
339
+ def load_state_dict(self, state_dict, **kwargs):
340
+ super(ResNets_Cifar, self).load_state_dict(state_dict)
341
+
342
+
343
+ class ResNet50_Cifar(ResNets_Cifar):
344
+ def __init__(
345
+ self,
346
+ n_classes=10,
347
+ width_mult=1.0,
348
+ bn_param=(0.1, 1e-5),
349
+ dropout_rate=0,
350
+ expand_ratio=None,
351
+ depth_param=None,
352
+ ):
353
+
354
+ expand_ratio = 0.25 if expand_ratio is None else expand_ratio
355
+
356
+ input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
357
+ stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy()
358
+ for i, width in enumerate(stage_width_list):
359
+ stage_width_list[i] = make_divisible(
360
+ width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
361
+ )
362
+
363
+ depth_list = [3, 4, 6, 3]
364
+ if depth_param is not None:
365
+ for i, depth in enumerate(ResNets_Cifar.BASE_DEPTH_LIST):
366
+ depth_list[i] = depth + depth_param
367
+
368
+ stride_list = [1, 2, 2, 2]
369
+
370
+ # build input stem
371
+ input_stem = [
372
+ ConvLayer(
373
+ 3,
374
+ input_channel,
375
+ kernel_size=3,
376
+ stride=1,
377
+ use_bn=True,
378
+ act_func="relu",
379
+ ops_order="weight_bn_act",
380
+ )
381
+ ]
382
+
383
+ # blocks
384
+ blocks = []
385
+ for d, width, s in zip(depth_list, stage_width_list, stride_list):
386
+ for i in range(d):
387
+ stride = s if i == 0 else 1
388
+ bottleneck_block = ResNetBottleneckBlock(
389
+ input_channel,
390
+ width,
391
+ kernel_size=3,
392
+ stride=stride,
393
+ expand_ratio=expand_ratio,
394
+ act_func="relu",
395
+ downsample_mode="conv",
396
+ )
397
+ blocks.append(bottleneck_block)
398
+ input_channel = width
399
+ # classifier
400
+ classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
401
+
402
+ super(ResNet50_Cifar, self).__init__(input_stem, blocks, classifier)
403
+
404
+ # set bn param
405
+ self.set_bn_param(*bn_param)
406
+
407
+
408
+ class ResNet50D_Cifar(ResNets_Cifar):
409
+ def __init__(
410
+ self,
411
+ n_classes=10,
412
+ width_mult=1.0,
413
+ bn_param=(0.1, 1e-5),
414
+ dropout_rate=0,
415
+ expand_ratio=None,
416
+ depth_param=None,
417
+ ):
418
+
419
+ expand_ratio = 0.25 if expand_ratio is None else expand_ratio
420
+
421
+ input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
422
+ mid_input_channel = make_divisible(
423
+ input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE
424
+ )
425
+ stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
426
+ for i, width in enumerate(stage_width_list):
427
+ stage_width_list[i] = make_divisible(
428
+ width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
429
+ )
430
+
431
+ depth_list = [3, 4, 6, 3]
432
+ if depth_param is not None:
433
+ for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
434
+ depth_list[i] = depth + depth_param
435
+
436
+ stride_list = [1, 2, 2, 2]
437
+
438
+ # build input stem
439
+ input_stem = [
440
+ ConvLayer(3, mid_input_channel, 3, stride=1, use_bn=True, act_func="relu"),
441
+ ResidualBlock(
442
+ ConvLayer(
443
+ mid_input_channel,
444
+ mid_input_channel,
445
+ 3,
446
+ stride=1,
447
+ use_bn=True,
448
+ act_func="relu",
449
+ ),
450
+ IdentityLayer(mid_input_channel, mid_input_channel),
451
+ ),
452
+ ConvLayer(
453
+ mid_input_channel,
454
+ input_channel,
455
+ 3,
456
+ stride=1,
457
+ use_bn=True,
458
+ act_func="relu",
459
+ ),
460
+ ]
461
+
462
+ # blocks
463
+ blocks = []
464
+ for d, width, s in zip(depth_list, stage_width_list, stride_list):
465
+ for i in range(d):
466
+ stride = s if i == 0 else 1
467
+ bottleneck_block = ResNetBottleneckBlock(
468
+ input_channel,
469
+ width,
470
+ kernel_size=3,
471
+ stride=stride,
472
+ expand_ratio=expand_ratio,
473
+ act_func="relu",
474
+ downsample_mode="avgpool_conv",
475
+ )
476
+ blocks.append(bottleneck_block)
477
+ input_channel = width
478
+ # classifier
479
+ classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
480
+
481
+ super(ResNet50D_Cifar, self).__init__(input_stem, blocks, classifier)
482
+
483
+ # set bn param
484
+ self.set_bn_param(*bn_param)
485
+ if __name__=="__main__":
486
+ import torch
487
+ resnet = ResNet50_Cifar()
488
+ x = torch.randn((1,3,32,32))
489
+ resnet(x)
490
+
proard/classification/networks/wide_resnet.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d
6
+
7
+ class BasicBlock(nn.Module):
8
+ def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9
+ super(BasicBlock, self).__init__()
10
+ self.bn1 = nn.BatchNorm2d(in_planes)
11
+ self.relu1 = nn.ReLU(inplace=True)
12
+ self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13
+ padding=1, bias=False)
14
+ self.bn2 = nn.BatchNorm2d(out_planes)
15
+ self.relu2 = nn.ReLU(inplace=True)
16
+ self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17
+ padding=1, bias=False)
18
+ self.droprate = dropRate
19
+ self.equalInOut = (in_planes == out_planes)
20
+ self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21
+ padding=0, bias=False) or None
22
+
23
+ def forward(self, x):
24
+ if not self.equalInOut:
25
+ x = self.relu1(self.bn1(x))
26
+ else:
27
+ out = self.relu1(self.bn1(x))
28
+ out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29
+ if self.droprate > 0:
30
+ out = F.dropout(out, p=self.droprate, training=self.training)
31
+ out = self.conv2(out)
32
+ return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33
+
34
+
35
+ class NetworkBlock(nn.Module):
36
+ def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37
+ super(NetworkBlock, self).__init__()
38
+ self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39
+
40
+ def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41
+ layers = []
42
+ for i in range(int(nb_layers)):
43
+ layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44
+ return nn.Sequential(*layers)
45
+
46
+ def forward(self, x):
47
+ return self.layer(x)
48
+
49
+
50
+ class WideResNet(MyNetwork):
51
+ def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0):
52
+ super(WideResNet, self).__init__()
53
+ nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
54
+ assert ((depth - 4) % 6 == 0)
55
+ n = (depth - 4) / 6
56
+ block = BasicBlock
57
+ # 1st conv before any network block
58
+ self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
59
+ padding=1, bias=False)
60
+ # 1st block
61
+ self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
62
+ # 1st sub-block
63
+ self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
64
+ # 2nd block
65
+ self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
66
+ # 3rd block
67
+ self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
68
+ # global average pooling and classifier
69
+ self.bn1 = nn.BatchNorm2d(nChannels[3])
70
+ self.relu = nn.ReLU(inplace=True)
71
+ self.fc = nn.Linear(nChannels[3], num_classes)
72
+ self.nChannels = nChannels[3]
73
+
74
+ for m in self.modules():
75
+ if isinstance(m, nn.Conv2d):
76
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
77
+ m.weight.data.normal_(0, math.sqrt(2. / n))
78
+ elif isinstance(m, nn.BatchNorm2d):
79
+ m.weight.data.fill_(1)
80
+ m.bias.data.zero_()
81
+ elif isinstance(m, nn.Linear):
82
+ m.bias.data.zero_()
83
+
84
+ def forward(self, x):
85
+ out = self.conv1(x)
86
+ out = self.block1(out)
87
+ out = self.block2(out)
88
+ out = self.block3(out)
89
+ out = self.relu(self.bn1(out))
90
+ out = F.avg_pool2d(out, 8)
91
+ out = out.view(-1, self.nChannels)
92
+ return self.fc(out)
93
+
proard/classification/run_manager/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .run_config import *
6
+ from .run_manager import *
7
+ from .distributed_run_manager import *
proard/classification/run_manager/distributed_run_manager.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import json
7
+ import time
8
+ import random
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from attacks import create_attack
13
+ import torch.backends.cudnn as cudnn
14
+ from tqdm import tqdm
15
+ from attacks.utils import ctx_noparamgrad_and_eval
16
+ from proard.utils import (
17
+ cross_entropy_with_label_smoothing,
18
+ cross_entropy_loss_with_soft_target,
19
+ write_log,
20
+ init_models,
21
+ )
22
+ from proard.utils import (
23
+ DistributedMetric,
24
+ list_mean,
25
+ get_net_info,
26
+ accuracy,
27
+ AverageMeter,
28
+ mix_labels,
29
+ mix_images,
30
+ )
31
+ from proard.utils import MyRandomResizedCrop
32
+
33
+ __all__ = ["DistributedRunManager"]
34
+
35
+
36
+ class DistributedRunManager:
37
+ def __init__(
38
+ self,
39
+ path,
40
+ net,
41
+ run_config,
42
+ hvd_compression,
43
+ backward_steps=1,
44
+ is_root=False,
45
+ init=True,
46
+ ):
47
+ import horovod.torch as hvd
48
+
49
+ self.path = path
50
+ self.net = net
51
+ self.run_config = run_config
52
+ self.is_root = is_root
53
+
54
+ self.best_acc = 0.0
55
+ self.best_robustness = 0.0
56
+ self.start_epoch = 0
57
+
58
+ os.makedirs(self.path, exist_ok=True)
59
+
60
+ self.net.cuda()
61
+ cudnn.benchmark = True
62
+ if init and self.is_root:
63
+ init_models(self.net, self.run_config.model_init)
64
+ if self.is_root:
65
+ # print net info
66
+ net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
67
+ with open("%s/net_info.txt" % self.path, "w") as fout:
68
+ fout.write(json.dumps(net_info, indent=4) + "\n")
69
+ try:
70
+ fout.write(self.net.module_str + "\n")
71
+ except Exception:
72
+ fout.write("%s do not support `module_str`" % type(self.net))
73
+ fout.write(
74
+ "%s\n" % self.run_config.data_provider.train.dataset.transform
75
+ )
76
+ fout.write(
77
+ "%s\n" % self.run_config.data_provider.test.dataset.transform
78
+ )
79
+ fout.write("%s\n" % self.net)
80
+
81
+ # criterion
82
+ self.train_criterion = self.run_config.train_criterion_loss
83
+ self.test_criterion = self.run_config.test_criterion_loss
84
+ self.kd_criterion = self.run_config.kd_criterion_loss
85
+
86
+ # optimizer
87
+ if self.run_config.no_decay_keys:
88
+ keys = self.run_config.no_decay_keys.split("#")
89
+ net_params = [
90
+ self.net.get_parameters(
91
+ keys, mode="exclude"
92
+ ), # parameters with weight decay
93
+ self.net.get_parameters(
94
+ keys, mode="include"
95
+ ), # parameters without weight decay
96
+ ]
97
+ else:
98
+ # noinspection PyBroadException
99
+ try:
100
+ net_params = self.network.weight_parameters()
101
+ except Exception:
102
+ net_params = []
103
+ for param in self.network.parameters():
104
+ if param.requires_grad:
105
+ net_params.append(param)
106
+ self.optimizer = self.run_config.build_optimizer(net_params)
107
+ self.optimizer = hvd.DistributedOptimizer(
108
+ self.optimizer,
109
+ named_parameters=self.net.named_parameters(),
110
+ compression=hvd_compression,
111
+ backward_passes_per_step=backward_steps,
112
+ )
113
+
114
+ """ save path and log path """
115
+
116
+ @property
117
+ def save_path(self):
118
+ if self.__dict__.get("_save_path", None) is None:
119
+ save_path = os.path.join(self.path, "checkpoint")
120
+ os.makedirs(save_path, exist_ok=True)
121
+ self.__dict__["_save_path"] = save_path
122
+ return self.__dict__["_save_path"]
123
+
124
+ @property
125
+ def logs_path(self):
126
+ if self.__dict__.get("_logs_path", None) is None:
127
+ logs_path = os.path.join(self.path, "logs")
128
+ os.makedirs(logs_path, exist_ok=True)
129
+ self.__dict__["_logs_path"] = logs_path
130
+ return self.__dict__["_logs_path"]
131
+
132
+ @property
133
+ def network(self):
134
+ return self.net
135
+
136
+ @network.setter
137
+ def network(self, new_val):
138
+ self.net = new_val
139
+
140
+ def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
141
+ if self.is_root:
142
+ write_log(self.logs_path, log_str, prefix, should_print, mode)
143
+
144
+ """ save & load model & save_config & broadcast """
145
+
146
+ def save_config(self, extra_run_config=None, extra_net_config=None):
147
+ if self.is_root:
148
+ run_save_path = os.path.join(self.path, "run.config")
149
+ if not os.path.isfile(run_save_path):
150
+ run_config = self.run_config.config
151
+ if extra_run_config is not None:
152
+ run_config.update(extra_run_config)
153
+ json.dump(run_config, open(run_save_path, "w"), indent=4)
154
+ print("Run configs dump to %s" % run_save_path)
155
+
156
+ try:
157
+ net_save_path = os.path.join(self.path, "net.config")
158
+ net_config = self.net.config
159
+ if extra_net_config is not None:
160
+ net_config.update(extra_net_config)
161
+ json.dump(net_config, open(net_save_path, "w"), indent=4)
162
+ print("Network configs dump to %s" % net_save_path)
163
+ except Exception:
164
+ print("%s do not support net config" % type(self.net))
165
+
166
+ def save_model(self, checkpoint=None, is_best=False, model_name=None):
167
+ if self.is_root:
168
+ if checkpoint is None:
169
+ checkpoint = {"state_dict": self.net.state_dict()}
170
+
171
+ if model_name is None:
172
+ model_name = "checkpoint.pth.tar"
173
+
174
+ latest_fname = os.path.join(self.save_path, "latest.txt")
175
+ model_path = os.path.join(self.save_path, model_name)
176
+ with open(latest_fname, "w") as _fout:
177
+ _fout.write(model_path + "\n")
178
+ torch.save(checkpoint, model_path)
179
+
180
+ if is_best:
181
+ best_path = os.path.join(self.save_path, "model_best.pth.tar")
182
+ torch.save({"state_dict": checkpoint["state_dict"]}, best_path)
183
+
184
+ def load_model(self, model_fname=None):
185
+ if self.is_root:
186
+ latest_fname = os.path.join(self.save_path, "latest.txt")
187
+ if model_fname is None and os.path.exists(latest_fname):
188
+ with open(latest_fname, "r") as fin:
189
+ model_fname = fin.readline()
190
+ if model_fname[-1] == "\n":
191
+ model_fname = model_fname[:-1]
192
+ # noinspection PyBroadException
193
+ try:
194
+ if model_fname is None or not os.path.exists(model_fname):
195
+ model_fname = "%s/checkpoint.pth.tar" % self.save_path
196
+ with open(latest_fname, "w") as fout:
197
+ fout.write(model_fname + "\n")
198
+ print("=> loading checkpoint '{}'".format(model_fname))
199
+ checkpoint = torch.load(model_fname, map_location="cpu")
200
+ except Exception:
201
+ self.write_log(
202
+ "fail to load checkpoint from %s" % self.save_path, "valid"
203
+ )
204
+ return
205
+
206
+ self.net.load_state_dict(checkpoint["state_dict"])
207
+ if "epoch" in checkpoint:
208
+ self.start_epoch = checkpoint["epoch"] + 1
209
+ if "best_acc" in checkpoint:
210
+ self.best_acc = checkpoint["best_acc"]
211
+ if "optimizer" in checkpoint:
212
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
213
+
214
+ self.write_log("=> loaded checkpoint '{}'".format(model_fname), "valid")
215
+
216
+ # noinspection PyArgumentList
217
+ def broadcast(self):
218
+ import horovod.torch as hvd
219
+
220
+ self.start_epoch = hvd.broadcast(
221
+ torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name="start_epoch"
222
+ ).item()
223
+ self.best_acc = hvd.broadcast(
224
+ torch.Tensor(1).fill_(self.best_acc)[0], 0, name="best_acc"
225
+ ).item()
226
+ hvd.broadcast_parameters(self.net.state_dict(), 0)
227
+ hvd.broadcast_optimizer_state(self.optimizer, 0)
228
+
229
+ """ metric related """
230
+
231
+ def get_metric_dict(self):
232
+ return {
233
+ "top1": DistributedMetric("top1"),
234
+ "top5": DistributedMetric("top5"),
235
+ "robust1" : DistributedMetric("robust1"),
236
+ "robust5": DistributedMetric("robust5")
237
+ }
238
+
239
+ def update_metric(self, metric_dict, output, output_adv , labels):
240
+ acc1, acc5 = accuracy(output, labels, topk=(1, 5))
241
+ robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
242
+ metric_dict["top1"].update(acc1[0], output.size(0))
243
+ metric_dict["top5"].update(acc5[0], output.size(0))
244
+ metric_dict["robust1"].update(robust1[0], output.size(0))
245
+ metric_dict["robust5"].update(robust5[0], output.size(0))
246
+
247
+ def get_metric_vals(self, metric_dict, return_dict=False):
248
+ if return_dict:
249
+ return {key: metric_dict[key].avg.item() for key in metric_dict}
250
+ else:
251
+ return [metric_dict[key].avg.item() for key in metric_dict]
252
+
253
+ def get_metric_names(self):
254
+ return "top1", "top5", "robust1" ,"robust5"
255
+
256
+ """ train & validate """
257
+
258
+ def validate(
259
+ self,
260
+ epoch=0,
261
+ is_test=False,
262
+ run_str="",
263
+ net=None,
264
+ data_loader=None,
265
+ no_logs=False,
266
+ ):
267
+ if net is None:
268
+ net = self.net
269
+ if data_loader is None:
270
+ if is_test:
271
+ data_loader = self.run_config.test_loader
272
+ else:
273
+ data_loader = self.run_config.valid_loader
274
+
275
+ net.eval()
276
+ if self.run_config.robust_mode:
277
+ eval_attack = create_attack(net, self.test_criterion.cuda(), self.run_config.attack_type,self.run_config.epsilon_test,self.run_config.num_steps_test, self.run_config.step_size_test)
278
+ losses = DistributedMetric("val_loss")
279
+ metric_dict = self.get_metric_dict()
280
+
281
+ with tqdm(
282
+ total=len(data_loader),
283
+ desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
284
+ disable=no_logs or not self.is_root,
285
+ ) as t:
286
+ for i, (images, labels) in enumerate(data_loader):
287
+ images, labels = images.cuda(), labels.cuda()
288
+ # compute output
289
+ output = net(images)
290
+ if self.run_config.robust_mode:
291
+ with ctx_noparamgrad_and_eval(net):
292
+ images_adv,_ = eval_attack.perturb(images, labels)
293
+ output_adv = net(images_adv)
294
+ loss = self.test_criterion(output_adv,labels)
295
+ else:
296
+ output_adv = output
297
+ loss = self.test_criterion(output,labels)
298
+
299
+ # measure accuracy and record loss
300
+ losses.update(loss, images.size(0))
301
+ self.update_metric(metric_dict, output, output_adv, labels)
302
+ t.set_postfix(
303
+ {
304
+ "loss": losses.avg.item(),
305
+ **self.get_metric_vals(metric_dict, return_dict=True),
306
+ "img_size": images.size(2),
307
+ }
308
+ )
309
+ t.update(1)
310
+ return losses.avg.item(), self.get_metric_vals(metric_dict)
311
+
312
+ def validate_all_resolution(self, epoch=0, is_test=False, net=None):
313
+ if net is None:
314
+ net = self.net
315
+ if isinstance(self.run_config.data_provider.image_size, list):
316
+ img_size_list, loss_list, top1_list, top5_list ,robust1_list, robust5_list = [], [], [], [],[],[]
317
+ for img_size in self.run_config.data_provider.image_size:
318
+ img_size_list.append(img_size)
319
+ self.run_config.data_provider.assign_active_img_size(img_size)
320
+ self.reset_running_statistics(net=net) # I am not sure that this is good fot robustness or not
321
+ loss, (top1, top5 ,robust1, robust5) = self.validate(epoch, is_test, net=net)
322
+ loss_list.append(loss)
323
+ top1_list.append(top1)
324
+ top5_list.append(top5)
325
+ robust1_list.append(robust1)
326
+ robust5_list.append(robust5)
327
+
328
+ return img_size_list, loss_list, top1_list, top5_list,robust1_list,robust5_list
329
+ else:
330
+ self.reset_running_statistics(net=net)
331
+ loss, (top1, top5 , robust1 ,robust5) = self.validate(epoch, is_test, net=net)
332
+ return (
333
+ [self.run_config.data_provider.active_img_size],
334
+ [loss],
335
+ [top1],
336
+ [top5],
337
+ [robust1],
338
+ [robust5],
339
+ )
340
+
341
+ def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
342
+ self.net.train()
343
+ self.run_config.train_loader.sampler.set_epoch(
344
+ epoch
345
+ ) # required by distributed sampler
346
+ MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
347
+
348
+ nBatch = len(self.run_config.train_loader)
349
+
350
+ losses = DistributedMetric("train_loss")
351
+ metric_dict = self.get_metric_dict()
352
+ data_time = AverageMeter()
353
+
354
+ with tqdm(
355
+ total=nBatch,
356
+ desc="Train Epoch #{}".format(epoch + 1),
357
+ disable=not self.is_root,
358
+ ) as t:
359
+ end = time.time()
360
+ for i, (images, labels) in enumerate(self.run_config.train_loader):
361
+ MyRandomResizedCrop.BATCH = i
362
+ data_time.update(time.time() - end)
363
+ if epoch < warmup_epochs:
364
+ new_lr = self.run_config.warmup_adjust_learning_rate(
365
+ self.optimizer,
366
+ warmup_epochs * nBatch,
367
+ nBatch,
368
+ epoch,
369
+ i,
370
+ warmup_lr,
371
+ )
372
+ else:
373
+ new_lr = self.run_config.adjust_learning_rate(
374
+ self.optimizer, epoch - warmup_epochs, i, nBatch
375
+ )
376
+
377
+ images, labels = images.cuda(), labels.cuda()
378
+ target = labels
379
+ if isinstance(self.run_config.mixup_alpha, float):
380
+ # transform data
381
+ random.seed(int("%d%.3d" % (i, epoch)))
382
+ lam = random.betavariate(
383
+ self.run_config.mixup_alpha, self.run_config.mixup_alpha
384
+ )
385
+ images = mix_images(images, lam)
386
+ labels = mix_labels(
387
+ labels,
388
+ lam,
389
+ self.run_config.data_provider.n_classes,
390
+ self.run_config.label_smoothing,
391
+ )
392
+
393
+ # soft target
394
+ if args.teacher_model is not None:
395
+ args.teacher_model.train()
396
+ with torch.no_grad():
397
+ soft_logits = args.teacher_model(images).detach()
398
+ soft_label = F.softmax(soft_logits, dim=1)
399
+
400
+ # compute output
401
+ output = self.net(images)
402
+ if args.teacher_model is None:
403
+ if self.run_config.robust_mode:
404
+ loss = self.train_criterion(self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train,self.run_config.distance_train)
405
+ loss_type = self.train_criterion.__name__
406
+ else:
407
+ loss = torch.nn.CrossEntropyLoss(output,labels)
408
+ loss_type = 'ce'
409
+
410
+ else:
411
+ if self.run_config.robust_mode:
412
+ loss = self.kd_criterion(args.teacher_model,self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train)
413
+ loss_type = self.kd_criterion_loss.__name__
414
+ else:
415
+ if args.kd_type == "ce":
416
+ kd_loss = cross_entropy_loss_with_soft_target(
417
+ output, soft_label
418
+ )
419
+ else:
420
+ kd_loss = F.mse_loss(output, soft_logits)
421
+ loss = args.kd_ratio * kd_loss + loss
422
+ loss_type = "%.1fkd+ce" % args.kd_ratio
423
+
424
+
425
+ # update
426
+ self.optimizer.zero_grad()
427
+ loss.backward()
428
+ self.optimizer.step()
429
+
430
+ # measure accuracy and record loss
431
+ losses.update(loss, images.size(0))
432
+ self.update_metric(metric_dict, output, output, target)
433
+
434
+ t.set_postfix(
435
+ {
436
+ "loss": losses.avg.item(),
437
+ **self.get_metric_vals(metric_dict, return_dict=True),
438
+ "img_size": images.size(2),
439
+ "lr": new_lr,
440
+ "loss_type": loss_type,
441
+ "data_time": data_time.avg,
442
+ }
443
+ )
444
+ t.update(1)
445
+ end = time.time()
446
+ return losses.avg.item(), self.get_metric_vals(metric_dict)
447
+
448
+ def train(self, args, warmup_epochs=5, warmup_lr=0):
449
+ for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
450
+ train_loss, (train_top1, train_top5, train_robust1, train_robust5) = self.train_one_epoch(
451
+ args, epoch, warmup_epochs, warmup_lr
452
+ )
453
+ img_size, val_loss, val_top1, val_top5 , val_robust1, val_robust5= self.validate_all_resolution(
454
+ epoch, is_test=False
455
+ )
456
+
457
+ is_best = list_mean(val_top1) > self.best_acc
458
+ is_best_robust = list_mean(val_robust1) > self.best_robustness
459
+ self.best_robustness = max(self.best_robustness, list_mean(val_robust1))
460
+ self.best_acc = max(self.best_acc, list_mean(val_top1))
461
+ if self.is_root:
462
+ val_log = (
463
+ "[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t {8} robust {10:.3f} ({4:.3f})\t{9} robust {11:.3f} "
464
+ "Train {6} {top1:.3f}\tloss {train_loss:.3f}\t robust1 {8} {robust1:.3f}\t".format(
465
+ epoch + 1 - warmup_epochs,
466
+ self.run_config.n_epochs,
467
+ list_mean(val_loss),
468
+ list_mean(val_top1),
469
+ self.best_acc,
470
+ list_mean(val_top5),
471
+ *self.get_metric_names(),
472
+ list_mean(val_robust1),
473
+ list_mean(val_robust5),
474
+ top1=train_top1,
475
+ train_loss=train_loss,
476
+ robust1 = train_robust1,
477
+ )
478
+ )
479
+ for i_s, v_a in zip(img_size, val_top1):
480
+ val_log += "(%d, %.3f), " % (i_s, v_a)
481
+ self.write_log(val_log, prefix="valid", should_print=False)
482
+
483
+ self.save_model(
484
+ {
485
+ "epoch": epoch,
486
+ "best_acc": self.best_acc,
487
+ "optimizer": self.optimizer.state_dict(),
488
+ "state_dict": self.net.state_dict(),
489
+ },
490
+ is_best=is_best,
491
+ )
492
+
493
+ def reset_running_statistics(
494
+ self, net=None, subset_size=4000, subset_batch_size=200, data_loader=None
495
+ ):
496
+ from proard.classification.elastic_nn.utils import set_running_statistics
497
+
498
+ if net is None:
499
+ net = self.net
500
+ if data_loader is None:
501
+ data_loader = self.run_config.random_sub_train_loader(
502
+ subset_size, subset_batch_size
503
+ )
504
+
505
+ set_running_statistics(net, data_loader)
proard/classification/run_manager/run_config.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from proard.utils import calc_learning_rate, build_optimizer
6
+ from proard.classification.data_providers import ImagenetDataProvider
7
+ from proard.classification.data_providers import Cifar10DataProvider
8
+ from proard.classification.data_providers import Cifar100DataProvider
9
+ from robust_loss.trades import trades_loss
10
+ from robust_loss.adaad import adaad_loss
11
+ from robust_loss.ard import ard_loss
12
+ from robust_loss.hat import hat_loss
13
+ from robust_loss.mart import mart_loss
14
+ from robust_loss.sat import sat_loss
15
+ from robust_loss.rslad import rslad_loss
16
+ import torch
17
+ __all__ = ["RunConfig", "ClassificationRunConfig", "DistributedClassificationRunConfig"]
18
+
19
+
20
+ class RunConfig:
21
+ def __init__(
22
+ self,
23
+ n_epochs,
24
+ init_lr,
25
+ lr_schedule_type,
26
+ lr_schedule_param,
27
+ dataset,
28
+ train_batch_size,
29
+ test_batch_size,
30
+ valid_size,
31
+ opt_type,
32
+ opt_param,
33
+ weight_decay,
34
+ label_smoothing,
35
+ no_decay_keys,
36
+ mixup_alpha,
37
+ model_init,
38
+ validation_frequency,
39
+ print_frequency,
40
+ ):
41
+ self.n_epochs = n_epochs
42
+ self.init_lr = init_lr
43
+ self.lr_schedule_type = lr_schedule_type
44
+ self.lr_schedule_param = lr_schedule_param
45
+
46
+ self.dataset = dataset
47
+ self.train_batch_size = train_batch_size
48
+ self.test_batch_size = test_batch_size
49
+ self.valid_size = valid_size
50
+
51
+ self.opt_type = opt_type
52
+ self.opt_param = opt_param
53
+ self.weight_decay = weight_decay
54
+ self.label_smoothing = label_smoothing
55
+ self.no_decay_keys = no_decay_keys
56
+
57
+ self.mixup_alpha = mixup_alpha
58
+
59
+ self.model_init = model_init
60
+ self.validation_frequency = validation_frequency
61
+ self.print_frequency = print_frequency
62
+
63
+ @property
64
+ def config(self):
65
+ config = {}
66
+ for key in self.__dict__:
67
+ if not key.startswith("_"):
68
+ config[key] = self.__dict__[key]
69
+ return config
70
+
71
+ def copy(self):
72
+ return RunConfig(**self.config)
73
+
74
+ """ learning rate """
75
+
76
+ def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
77
+ """adjust learning of a given optimizer and return the new learning rate"""
78
+ new_lr = calc_learning_rate(
79
+ epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type
80
+ )
81
+ for param_group in optimizer.param_groups:
82
+ param_group["lr"] = new_lr
83
+ return new_lr
84
+
85
+ def warmup_adjust_learning_rate(
86
+ self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0
87
+ ):
88
+ T_cur = epoch * nBatch + batch + 1
89
+ new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
90
+ for param_group in optimizer.param_groups:
91
+ param_group["lr"] = new_lr
92
+ return new_lr
93
+
94
+ """ data provider """
95
+
96
+ @property
97
+ def data_provider(self):
98
+ raise NotImplementedError
99
+
100
+ @property
101
+ def train_loader(self):
102
+ return self.data_provider.train
103
+
104
+ @property
105
+ def valid_loader(self):
106
+ return self.data_provider.valid
107
+
108
+ @property
109
+ def test_loader(self):
110
+ return self.data_provider.test
111
+
112
+ def random_sub_train_loader(
113
+ self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
114
+ ):
115
+ return self.data_provider.build_sub_train_loader(
116
+ n_images, batch_size, num_worker, num_replicas, rank
117
+ )
118
+
119
+ """ optimizer """
120
+
121
+ def build_optimizer(self, net_params):
122
+ return build_optimizer(
123
+ net_params,
124
+ self.opt_type,
125
+ self.opt_param,
126
+ self.init_lr,
127
+ self.weight_decay,
128
+ self.no_decay_keys,
129
+ )
130
+
131
+
132
+
133
+ class ClassificationRunConfig(RunConfig):
134
+ def __init__(
135
+ self,
136
+ n_epochs=150,
137
+ init_lr=0.05,
138
+ lr_schedule_type="cosine",
139
+ lr_schedule_param=None,
140
+ dataset="imagenet", # 'cifar10' or 'cifar100'
141
+ train_batch_size=256,
142
+ test_batch_size=500,
143
+ valid_size=None,
144
+ opt_type="sgd",
145
+ opt_param=None,
146
+ weight_decay=4e-5,
147
+ label_smoothing=0.1,
148
+ no_decay_keys=None,
149
+ mixup_alpha=None,
150
+ model_init="he_fout",
151
+ validation_frequency=1,
152
+ print_frequency=10,
153
+ n_worker=32,
154
+ resize_scale=0.08,
155
+ distort_color="tf",
156
+ image_size=224, # 32
157
+ robust_mode = False,
158
+ epsilon_train = 0.031,
159
+ num_steps_train = 10,
160
+ step_size_train = 0.0078,
161
+ clip_min_train = 0 ,
162
+ clip_max_train = 1,
163
+ const_init_train = False,
164
+ beta_train = 6.0,
165
+ distance_train ="l_inf",
166
+ epsilon_test = 0.031,
167
+ num_steps_test = 20,
168
+ step_size_test = 0.0078,
169
+ clip_min_test = 0,
170
+ clip_max_test = 1,
171
+ const_init_test = False,
172
+ beta_test = 6.0,
173
+ distance_test = "l_inf",
174
+ train_criterion = "trades",
175
+ test_criterion = "ce",
176
+ kd_criterion = 'rslad',
177
+ attack_type = "linf-pgd",
178
+ **kwargs
179
+ ):
180
+ super(ClassificationRunConfig, self).__init__(
181
+ n_epochs,
182
+ init_lr,
183
+ lr_schedule_type,
184
+ lr_schedule_param,
185
+ dataset,
186
+ train_batch_size,
187
+ test_batch_size,
188
+ valid_size,
189
+ opt_type,
190
+ opt_param,
191
+ weight_decay,
192
+ label_smoothing,
193
+ no_decay_keys,
194
+ mixup_alpha,
195
+ model_init,
196
+ validation_frequency,
197
+ print_frequency,
198
+ )
199
+
200
+ self.n_worker = n_worker
201
+ self.resize_scale = resize_scale
202
+ self.distort_color = distort_color
203
+ self.image_size = image_size
204
+ self.epsilon_train = epsilon_train
205
+ self.num_steps_train = num_steps_train
206
+ self.step_size_train = step_size_train
207
+ self.clip_min_train = clip_min_train
208
+ self.clip_max_train = clip_max_train
209
+ self.const_init_train = const_init_train
210
+ self.beta_train = beta_train
211
+ self.distance_train = distance_train
212
+ self.epsilon_test = epsilon_test
213
+ self.num_steps_test = num_steps_test
214
+ self.step_size_test = step_size_test
215
+ self.clip_min_test = clip_min_test
216
+ self.clip_max_test = clip_max_test
217
+ self.const_init_test = const_init_test
218
+ self.beta_test = beta_test
219
+ self.distance_test = distance_test
220
+ self.train_criterion = train_criterion
221
+ self.test_criterion = test_criterion
222
+ self.kd_criterion = kd_criterion
223
+ self.attack_type = attack_type
224
+ self.robust_mode = robust_mode
225
+ @property
226
+ def data_provider(self):
227
+ if self.__dict__.get("_data_provider", None) is None:
228
+ if self.dataset == ImagenetDataProvider.name():
229
+ DataProviderClass = ImagenetDataProvider
230
+ elif self.dataset == Cifar10DataProvider.name():
231
+ DataProviderClass = Cifar10DataProvider
232
+ elif self.dataset == Cifar100DataProvider.name():
233
+ DataProviderClass = Cifar100DataProvider
234
+ else:
235
+ raise NotImplementedError
236
+ self.__dict__["_data_provider"] = DataProviderClass(
237
+ train_batch_size=self.train_batch_size,
238
+ test_batch_size=self.test_batch_size,
239
+ valid_size=self.valid_size,
240
+ n_worker=self.n_worker,
241
+ resize_scale=self.resize_scale,
242
+ distort_color=self.distort_color,
243
+ image_size=self.image_size,
244
+ )
245
+ return self.__dict__["_data_provider"]
246
+ @property
247
+ def train_criterion_loss (self):
248
+ if self.train_criterion == "trades" :
249
+ return trades_loss
250
+ elif self.train_criterion == "mart" :
251
+ return mart_loss
252
+ elif self.train_criterion == "sat" :
253
+ return sat_loss
254
+ elif self.train_criterion == "hat" :
255
+ return hat_loss
256
+ @property
257
+ def test_criterion_loss (self) :
258
+ if self.test_criterion == "ce" :
259
+ return torch.nn.CrossEntropyLoss()
260
+ @property
261
+ def kd_criterion_loss (self) :
262
+ if self.kd_criterion =="ard" :
263
+ return ard_loss
264
+ elif self.kd_criterion == "adaad" :
265
+ return adaad_loss
266
+ elif self.kd_criterion == "rslad" :
267
+ return rslad_loss
268
+ class DistributedClassificationRunConfig(ClassificationRunConfig):
269
+ def __init__(
270
+ self,
271
+ n_epochs=150,
272
+ init_lr=0.05,
273
+ lr_schedule_type="cosine",
274
+ lr_schedule_param=None,
275
+ dataset="imagenet",
276
+ train_batch_size=64,
277
+ test_batch_size=64,
278
+ valid_size=None,
279
+ opt_type="sgd",
280
+ opt_param=None,
281
+ weight_decay=4e-5,
282
+ label_smoothing=0.1,
283
+ no_decay_keys=None,
284
+ mixup_alpha=None,
285
+ model_init="he_fout",
286
+ validation_frequency=1,
287
+ print_frequency=10,
288
+ n_worker=8,
289
+ resize_scale=0.08,
290
+ distort_color="tf",
291
+ image_size=224,
292
+ robust_mode = False,
293
+ epsilon = 0.031,
294
+ num_steps = 10,
295
+ step_size = 0.0078,
296
+ clip_min = 0,
297
+ clip_max = 1,
298
+ const_init = False,
299
+ beta = 6.0,
300
+ distance = "l_inf",
301
+ train_criterion = "trades",
302
+ test_criterion = "ce",
303
+ kd_criterion = 'rslad',
304
+ attack_type = "linf-pgd",
305
+ **kwargs
306
+ ):
307
+ super(DistributedClassificationRunConfig, self).__init__(
308
+ n_epochs,
309
+ init_lr,
310
+ lr_schedule_type,
311
+ lr_schedule_param,
312
+ dataset,
313
+ train_batch_size,
314
+ test_batch_size,
315
+ valid_size,
316
+ opt_type,
317
+ opt_param,
318
+ weight_decay,
319
+ label_smoothing,
320
+ no_decay_keys,
321
+ mixup_alpha,
322
+ model_init,
323
+ validation_frequency,
324
+ print_frequency,
325
+ n_worker,
326
+ resize_scale,
327
+ distort_color,
328
+ image_size,
329
+ robust_mode,
330
+ epsilon,
331
+ num_steps,
332
+ step_size,
333
+ clip_min,
334
+ clip_max,
335
+ const_init,
336
+ beta,
337
+ distance,
338
+ epsilon,
339
+ num_steps * 2,
340
+ step_size,
341
+ clip_min,clip_max,
342
+ const_init,
343
+ beta,
344
+ distance,
345
+ train_criterion,
346
+ test_criterion,
347
+ kd_criterion,
348
+ attack_type,
349
+ **kwargs
350
+ )
351
+
352
+ self._num_replicas = kwargs["num_replicas"]
353
+ self._rank = kwargs["rank"]
354
+
355
+ @property
356
+ def data_provider(self):
357
+ if self.__dict__.get("_data_provider", None) is None:
358
+ if self.dataset == ImagenetDataProvider.name():
359
+ DataProviderClass = ImagenetDataProvider
360
+ elif self.dataset == Cifar10DataProvider.name():
361
+ DataProviderClass = Cifar10DataProvider
362
+ elif self.dataset == Cifar100DataProvider.name():
363
+ DataProviderClass = Cifar100DataProvider
364
+ else:
365
+ raise NotImplementedError
366
+ if self.dataset == "imagenet":
367
+ self.__dict__["_data_provider"] = DataProviderClass(
368
+ train_batch_size=self.train_batch_size,
369
+ test_batch_size=self.test_batch_size,
370
+ valid_size=self.valid_size,
371
+ n_worker=self.n_worker,
372
+ resize_scale=self.resize_scale,
373
+ distort_color=self.distort_color,
374
+ image_size=self.image_size,
375
+ num_replicas=self._num_replicas,
376
+ rank=self._rank,
377
+ )
378
+ else:
379
+ self.__dict__["_data_provider"] = DataProviderClass(
380
+ train_batch_size=self.train_batch_size,
381
+ test_batch_size=self.test_batch_size,
382
+ valid_size=self.valid_size,
383
+ n_worker=self.n_worker,
384
+ resize_scale=None,
385
+ distort_color=None,
386
+ image_size=self.image_size,
387
+ num_replicas=self._num_replicas,
388
+ rank=self._rank,
389
+ )
390
+ return self.__dict__["_data_provider"]
391
+ @property
392
+ def train_criterion_loss (self):
393
+ if self.train_criterion == "trades" :
394
+ return trades_loss
395
+ elif self.train_criterion == "mart" :
396
+ return mart_loss
397
+ elif self.train_criterion == "sat" :
398
+ return sat_loss
399
+ elif self.train_criterion == "hat" :
400
+ return hat_loss
401
+ @property
402
+ def test_criterion_loss (self) :
403
+ if self.test_criterion == "ce" :
404
+ return torch.nn.CrossEntropyLoss()
405
+ @property
406
+ def kd_criterion_loss (self) :
407
+ if self.kd_criterion =="ard" :
408
+ return ard_loss
409
+ elif self.kd_criterion == "adaad" :
410
+ return adaad_loss
411
+ elif self.kd_criterion == "rslad" :
412
+ return rslad_loss
413
+
414
+
proard/classification/run_manager/run_manager.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import random
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.nn.parallel
13
+ import torch.backends.cudnn as cudnn
14
+ import torch.optim
15
+ from tqdm import tqdm
16
+ from attacks.utils import ctx_noparamgrad_and_eval
17
+ from robust_loss.rslad import rslad_inner_loss,kl_loss
18
+ from robust_loss.trades import trades_loss
19
+ from attacks import create_attack
20
+ from proard.utils import (
21
+ get_net_info,
22
+ cross_entropy_loss_with_soft_target,
23
+ cross_entropy_with_label_smoothing,
24
+ )
25
+ from proard.utils import (
26
+ AverageMeter,
27
+ accuracy,
28
+ write_log,
29
+ mix_images,
30
+ mix_labels,
31
+ init_models,
32
+ )
33
+ from proard.utils import MyRandomResizedCrop
34
+
35
+ __all__ = ["RunManager"]
36
+
37
+
38
+ class RunManager:
39
+ def __init__(
40
+ self, path, net, run_config, init=True, measure_latency=None, no_gpu=False
41
+ ):
42
+ self.path = path
43
+ self.net = net
44
+ self.run_config = run_config
45
+
46
+ self.best_acc = 0
47
+ self.best_robustness = 0
48
+ self.start_epoch = 0
49
+
50
+ os.makedirs(self.path, exist_ok=True)
51
+
52
+ # move network to GPU if available
53
+ if torch.cuda.is_available() and (not no_gpu):
54
+ self.device = torch.device("cuda")
55
+ self.net = self.net.to(self.device)
56
+ cudnn.benchmark = True
57
+ else:
58
+ self.device = torch.device("cpu")
59
+ # initialize model (default)
60
+ if init:
61
+ init_models(net,run_config.model_init)
62
+
63
+ # net info
64
+ net_info = get_net_info(
65
+ self.net, self.run_config.data_provider.data_shape, measure_latency, True
66
+ )
67
+ with open("%s/net_info.txt" % self.path, "w") as fout:
68
+ fout.write(json.dumps(net_info, indent=4) + "\n")
69
+ # noinspection PyBroadException
70
+ try:
71
+ fout.write(self.network.module_str + "\n")
72
+ except Exception:
73
+ pass
74
+ fout.write("%s\n" % self.run_config.data_provider.train.dataset.transform)
75
+ fout.write("%s\n" % self.run_config.data_provider.test.dataset.transform)
76
+ fout.write("%s\n" % self.network)
77
+
78
+ self.train_criterion = self.run_config.train_criterion_loss
79
+ self.test_criterion = self.run_config.test_criterion_loss
80
+ self.kd_criterion = self.run_config.kd_criterion_loss
81
+
82
+ # optimizer
83
+ if self.run_config.no_decay_keys:
84
+ keys = self.run_config.no_decay_keys.split("#")
85
+ net_params = [
86
+ self.network.get_parameters(
87
+ keys, mode="exclude"
88
+ ), # parameters with weight decay
89
+ self.network.get_parameters(
90
+ keys, mode="include"
91
+ ), # parameters without weight decay
92
+ ]
93
+ else:
94
+ # noinspection PyBroadException
95
+ try:
96
+ net_params = self.network.weight_parameters()
97
+ except Exception:
98
+ net_params = []
99
+ for param in self.network.parameters():
100
+ if param.requires_grad:
101
+ net_params.append(param)
102
+ self.optimizer = self.run_config.build_optimizer(net_params)
103
+
104
+ self.net = torch.nn.DataParallel(self.net)
105
+
106
+ """ save path and log path """
107
+
108
+ @property
109
+ def save_path(self):
110
+ if self.__dict__.get("_save_path", None) is None:
111
+ save_path = os.path.join(self.path, "checkpoint")
112
+ os.makedirs(save_path, exist_ok=True)
113
+ self.__dict__["_save_path"] = save_path
114
+ return self.__dict__["_save_path"]
115
+
116
+ @property
117
+ def logs_path(self):
118
+ if self.__dict__.get("_logs_path", None) is None:
119
+ logs_path = os.path.join(self.path, "logs")
120
+ os.makedirs(logs_path, exist_ok=True)
121
+ self.__dict__["_logs_path"] = logs_path
122
+ return self.__dict__["_logs_path"]
123
+
124
+ @property
125
+ def network(self):
126
+ return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
127
+
128
+ def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
129
+ write_log(self.logs_path, log_str, prefix, should_print, mode)
130
+
131
+ """ save and load models """
132
+
133
+ def save_model(self, checkpoint=None, is_best=False, model_name=None):
134
+ if checkpoint is None:
135
+ checkpoint = {"state_dict": self.network.state_dict()}
136
+
137
+ if model_name is None:
138
+ model_name = "checkpoint.pth.tar"
139
+
140
+ checkpoint[
141
+ "dataset"
142
+ ] = self.run_config.dataset # add `dataset` info to the checkpoint
143
+ latest_fname = os.path.join(self.save_path, "latest.txt")
144
+ model_path = os.path.join(self.save_path, model_name)
145
+ with open(latest_fname, "w") as fout:
146
+ fout.write(model_path + "\n")
147
+ torch.save(checkpoint, model_path)
148
+
149
+ if is_best:
150
+ best_path = os.path.join(self.save_path, "model_best.pth.tar")
151
+ torch.save({"state_dict": checkpoint["state_dict"]}, best_path)
152
+
153
+ def load_model(self, model_fname=None):
154
+ latest_fname = os.path.join(self.save_path, "latest.txt")
155
+ if model_fname is None and os.path.exists(latest_fname):
156
+ with open(latest_fname, "r") as fin:
157
+ model_fname = fin.readline()
158
+ if model_fname[-1] == "\n":
159
+ model_fname = model_fname[:-1]
160
+ # noinspection PyBroadException
161
+ try:
162
+ if model_fname is None or not os.path.exists(model_fname):
163
+ model_fname = "%s/checkpoint.pth.tar" % self.save_path
164
+ with open(latest_fname, "w") as fout:
165
+ fout.write(model_fname + "\n")
166
+ print("=> loading checkpoint '{}'".format(model_fname))
167
+ checkpoint = torch.load(model_fname, map_location="cpu")
168
+ except Exception:
169
+ print("fail to load checkpoint from %s" % self.save_path)
170
+ return {}
171
+
172
+ self.network.load_state_dict(checkpoint["state_dict"])
173
+ if "epoch" in checkpoint:
174
+ self.start_epoch = checkpoint["epoch"] + 1
175
+ if "best_acc" in checkpoint:
176
+ self.best_acc = checkpoint["best_acc"]
177
+ if "optimizer" in checkpoint:
178
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
179
+
180
+ print("=> loaded checkpoint '{}'".format(model_fname))
181
+ return checkpoint
182
+
183
+ def save_config(self, extra_run_config=None, extra_net_config=None):
184
+ """dump run_config and net_config to the model_folder"""
185
+ run_save_path = os.path.join(self.path, "run.config")
186
+ if not os.path.isfile(run_save_path):
187
+ run_config = self.run_config.config
188
+ if extra_run_config is not None:
189
+ run_config.update(extra_run_config)
190
+ json.dump(run_config, open(run_save_path, "w"), indent=4)
191
+ print("Run configs dump to %s" % run_save_path)
192
+
193
+ try:
194
+ net_save_path = os.path.join(self.path, "net.config")
195
+ net_config = self.network.config
196
+ if extra_net_config is not None:
197
+ net_config.update(extra_net_config)
198
+ json.dump(net_config, open(net_save_path, "w"), indent=4)
199
+ print("Network configs dump to %s" % net_save_path)
200
+ except Exception:
201
+ print("%s do not support net config" % type(self.network))
202
+
203
+ """ metric related """
204
+
205
+ def get_metric_dict(self):
206
+ return {
207
+ "top1": AverageMeter(),
208
+ "top5": AverageMeter(),
209
+ "robust1" :AverageMeter(),
210
+ "robust5" :AverageMeter(),
211
+ }
212
+
213
+ def update_metric(self, metric_dict, output, output_adv, labels):
214
+ acc1, acc5 = accuracy(output, labels, topk=(1, 5))
215
+ robust1,robust5 = accuracy(output_adv,labels,topk=(1,5))
216
+ metric_dict["top1"].update(acc1[0].item(), output.size(0))
217
+ metric_dict["top5"].update(acc5[0].item(), output.size(0))
218
+ metric_dict["robust1"].update(robust1[0].item(), output.size(0))
219
+ metric_dict["robust5"].update(robust5[0].item(), output.size(0))
220
+
221
+
222
+ def get_metric_vals(self, metric_dict, return_dict=False):
223
+ if return_dict:
224
+ return {key: metric_dict[key].avg for key in metric_dict}
225
+ else:
226
+ return [metric_dict[key].avg for key in metric_dict]
227
+
228
+ def get_metric_names(self):
229
+ return "top1", "top5" , "robust1" , "robust5"
230
+
231
+ """ train and test """
232
+
233
+ def validate(
234
+ self,
235
+ epoch=0,
236
+ is_test=False,
237
+ run_str="",
238
+ net=None,
239
+ data_loader=None,
240
+ no_logs=False,
241
+ train_mode=False,
242
+ ):
243
+ if net is None:
244
+ net = self.net
245
+ if not isinstance(net, nn.DataParallel):
246
+ net = nn.DataParallel(net)
247
+ if data_loader is None:
248
+ data_loader = (
249
+ self.run_config.test_loader if is_test else self.run_config.valid_loader
250
+ )
251
+
252
+ if train_mode:
253
+ net.train()
254
+ else:
255
+ net.eval()
256
+ if self.run_config.robust_mode:
257
+ eval_attack = create_attack(net, self.test_criterion.cuda(), self.run_config.attack_type,self.run_config.epsilon_test,self.run_config.num_steps_test, self.run_config.step_size_test)
258
+ losses = AverageMeter()
259
+ metric_dict = self.get_metric_dict()
260
+
261
+ with tqdm(
262
+ total=len(data_loader),
263
+ desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
264
+ disable=no_logs,
265
+ ) as t:
266
+ for i, (images, labels) in enumerate(data_loader):
267
+ images, labels = images.to(self.device), labels.to(self.device)
268
+ # compute output
269
+ output = net(images)
270
+ if self.run_config.robust_mode:
271
+ with ctx_noparamgrad_and_eval(net):
272
+ images_adv,_ = eval_attack.perturb(images, labels)
273
+ output_adv = net(images_adv)
274
+ loss = nn.CrossEntropyLoss()(output_adv,labels)
275
+ else:
276
+ output_adv = output
277
+ loss = nn.CrossEntropyLoss()(output,labels)
278
+
279
+ # measure accuracy and record loss
280
+ self.update_metric(metric_dict, output, output_adv , labels)
281
+
282
+ losses.update(loss.item(), images.size(0))
283
+ t.set_postfix(
284
+ {
285
+ "loss": losses.avg,
286
+ **self.get_metric_vals(metric_dict, return_dict=True),
287
+ "img_size": images.size(2),
288
+ }
289
+ )
290
+ t.update(1)
291
+ return losses.avg, self.get_metric_vals(metric_dict)
292
+
293
+ def validate_all_resolution(self, epoch=0, is_test=False, net=None):
294
+ if net is None:
295
+ net = self.network
296
+ if isinstance(self.run_config.data_provider.image_size, list):
297
+ img_size_list, loss_list, top1_list, top5_list , robust1_list , robust5_list = [], [], [], [],[],[]
298
+ for img_size in self.run_config.data_provider.image_size:
299
+ img_size_list.append(img_size)
300
+ self.run_config.data_provider.assign_active_img_size(img_size)
301
+ self.reset_running_statistics(net=net)
302
+ loss, (top1, top5 , robust1,robust5) = self.validate(epoch, is_test, net=net)
303
+ loss_list.append(loss)
304
+ top1_list.append(top1)
305
+ top5_list.append(top5)
306
+ robust1_list.append(robust1)
307
+ robust5_list.append(robust5)
308
+ return img_size_list, loss_list, top1_list, top5_list ,robust1_list ,robust5_list
309
+ else:
310
+ loss, (top1, top5 , robust1 , robust5) = self.validate(epoch, is_test, net=net)
311
+ return (
312
+ [self.run_config.data_provider.active_img_size],
313
+ [loss],
314
+ [top1],
315
+ [top5],
316
+ [robust1],
317
+ [robust5]
318
+ )
319
+
320
+ def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
321
+ # switch to train mode
322
+ self.net.train()
323
+ MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
324
+
325
+ nBatch = len(self.run_config.train_loader)
326
+
327
+ losses = AverageMeter()
328
+ metric_dict = self.get_metric_dict()
329
+ data_time = AverageMeter()
330
+
331
+ with tqdm(
332
+ total=nBatch,
333
+ desc="{} Train Epoch #{}".format(self.run_config.dataset, epoch + 1),
334
+ ) as t:
335
+ end = time.time()
336
+ for i, (images, labels) in enumerate(self.run_config.train_loader):
337
+ MyRandomResizedCrop.BATCH = i
338
+ data_time.update(time.time() - end)
339
+ if epoch < warmup_epochs:
340
+ new_lr = self.run_config.warmup_adjust_learning_rate(
341
+ self.optimizer,
342
+ warmup_epochs * nBatch,
343
+ nBatch,
344
+ epoch,
345
+ i,
346
+ warmup_lr,
347
+ )
348
+ else:
349
+ new_lr = self.run_config.adjust_learning_rate(
350
+ self.optimizer, epoch - warmup_epochs, i, nBatch
351
+ )
352
+
353
+ images, labels = images.to(self.device), labels.to(self.device)
354
+ target = labels
355
+ if isinstance(self.run_config.mixup_alpha, float):
356
+ # transform data
357
+ lam = random.betavariate(
358
+ self.run_config.mixup_alpha, self.run_config.mixup_alpha
359
+ )
360
+ images = mix_images(images, lam)
361
+ labels = mix_labels(
362
+ labels,
363
+ lam,
364
+ self.run_config.data_provider.n_classes,
365
+ self.run_config.label_smoothing,
366
+ )
367
+
368
+ # soft target
369
+ if args.teacher_model is not None:
370
+ args.teacher_model.train()
371
+ with torch.no_grad():
372
+ soft_logits = args.teacher_model(images).detach()
373
+ soft_label = F.softmax(soft_logits, dim=1)
374
+
375
+ # compute output
376
+ output = self.net(images)
377
+
378
+ if args.teacher_model is None:
379
+ if self.run_config.robust_mode:
380
+ loss = self.train_criterion(self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train,self.run_config.distance_train)
381
+ loss_type = self.run_config.train_criterion
382
+ else:
383
+ loss = torch.nn.CrossEntropyLoss(output,labels)
384
+ loss_type = 'ce'
385
+
386
+ else:
387
+ if self.run_config.robust_mode:
388
+ loss = self.kd_criterion(args.teacher_model,self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train)
389
+ loss_type = self.run_config.train_criterion
390
+ else:
391
+ if args.kd_type == "ce":
392
+ kd_loss = cross_entropy_loss_with_soft_target(
393
+ output, soft_label
394
+ )
395
+ else:
396
+ kd_loss = F.mse_loss(output, soft_logits)
397
+ loss = args.kd_ratio * kd_loss + loss
398
+ loss_type = "%.1fkd+ce" % args.kd_ratio
399
+
400
+ # compute gradient and do SGD step
401
+ self.net.zero_grad() # or self.optimizer.zero_grad()
402
+ loss.backward()
403
+ self.optimizer.step()
404
+
405
+ # measure accuracy and record loss
406
+ losses.update(loss.item(), images.size(0))
407
+ self.update_metric(metric_dict, output, output ,target)
408
+
409
+ t.set_postfix(
410
+ {
411
+ "loss": losses.avg,
412
+ **self.get_metric_vals(metric_dict, return_dict=True),
413
+ "img_size": images.size(2),
414
+ "lr": new_lr,
415
+ "loss_type": loss_type,
416
+ "data_time": data_time.avg,
417
+ }
418
+ )
419
+ t.update(1)
420
+ end = time.time()
421
+ return losses.avg, self.get_metric_vals(metric_dict)
422
+
423
+ def train(self, args, warmup_epoch=0, warmup_lr=0):
424
+ for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
425
+ train_loss, (train_top1, train_top5 , train_robust1 , train_robust5) = self.train_one_epoch(
426
+ args, epoch, warmup_epoch, warmup_lr
427
+ )
428
+
429
+ if (epoch + 1) % self.run_config.validation_frequency == 0:
430
+ img_size, val_loss, val_acc, val_acc5 ,val_robust, val_robust5 = self.validate_all_resolution(
431
+ epoch=epoch, is_test=False
432
+ )
433
+
434
+ is_best = np.mean(val_acc) > self.best_acc
435
+ is_best_robust = np.mean(val_robust) > self.best_robustness
436
+ self.best_acc = max(self.best_acc, np.mean(val_acc))
437
+ self.best_robustness = max(self.best_robustness, np.mean(val_robust))
438
+ val_log = "Valid [{0}/{1}]\tloss {2:.3f} \t{7} {3:.3f} ({5:.3f}) \t{8} {4:.3f} ({6:.3f})".format(
439
+ epoch + 1 - warmup_epoch,
440
+ self.run_config.n_epochs,
441
+ np.mean(val_loss),
442
+ np.mean(val_acc),
443
+ np.mean(val_robust),
444
+ self.best_acc,
445
+ self.best_robustness,
446
+ self.get_metric_names()[0],
447
+ self.get_metric_names()[2],
448
+ )
449
+ val_log += "\t{2} {0:.3f} \tTrain {1} {top1:.3f}\t {3} {robust:.3f} \t loss {train_loss:.3f}\t".format(
450
+ np.mean(val_acc5),
451
+ *self.get_metric_names(),
452
+ top1=train_top1,
453
+ robust = train_robust1,
454
+ train_loss=train_loss
455
+ )
456
+ for i_s, v_a in zip(img_size, val_acc):
457
+ val_log += "(%d, %.3f), " % (i_s, v_a)
458
+ self.write_log(val_log, prefix="valid", should_print=False)
459
+ else:
460
+ is_best = False
461
+ is_best_robust = False
462
+
463
+ self.save_model(
464
+ {
465
+ "epoch": epoch,
466
+ "best_acc": self.best_acc,
467
+ "optimizer": self.optimizer.state_dict(),
468
+ "state_dict": self.network.state_dict(),
469
+ },
470
+ is_best=is_best,
471
+ )
472
+
473
+ def reset_running_statistics(
474
+ self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None
475
+ ):
476
+ from proard.classification.elastic_nn.utils import set_running_statistics
477
+
478
+ if net is None:
479
+ net = self.network
480
+ if data_loader is None:
481
+ data_loader = self.run_config.random_sub_train_loader(
482
+ subset_size, subset_batch_size
483
+ )
484
+ set_running_statistics(net, data_loader)
proard/model_zoo.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import json
6
+ import torch
7
+ import gdown
8
+
9
+ from proard.classification.networks import get_net_by_name, ResNet50
10
+ from proard.classification.elastic_nn.networks import (
11
+ DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar
12
+ )
13
+ from proard.classification.networks import (WideResNet,ResNet50_Cifar,ResNet50,MobileNetV3_Cifar,MobileNetV3Large_Cifar,MobileNetV3Large,ProxylessNASNets_Cifar,ProxylessNASNets,MobileNetV2_Cifar,MobileNetV2)
14
+ __all__ = [
15
+ "DYN_net",
16
+ ]
17
+
18
+
19
+
20
+ def DYN_net(net_id, robust_mode, dataset,train_criterion, pretrained=True,run_config=None,WPS=False,base=False):
21
+ if net_id == "ResNet50":
22
+ if not base:
23
+ if dataset == "cifar10" or dataset == "cifar100":
24
+ net = DYNResNets_Cifar(n_classes=run_config.data_provider.n_classes,
25
+ dropout_rate=0,
26
+ depth_list=[0, 1, 2],
27
+ expand_ratio_list=[0.2, 0.25, 0.35],
28
+ width_mult_list=[0.65, 0.8, 1.0],
29
+ )
30
+ else:
31
+ net = DYNResNets(n_classes=run_config.data_provider.n_classes,
32
+ dropout_rate=0,
33
+ depth_list=[0, 1, 2],
34
+ expand_ratio_list=[0.2, 0.25, 0.35],
35
+ width_mult_list=[0.65, 0.8, 1.0],
36
+ )
37
+ else:
38
+ if dataset == "cifar10" or dataset == "cifar100":
39
+ net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes)
40
+ else:
41
+ net = ResNet50(n_classes=run_config.data_provider.n_classes)
42
+
43
+ elif net_id == "MBV3":
44
+ if not base:
45
+ if dataset == "cifar10" or dataset == "cifar100":
46
+ net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,
47
+ dropout_rate=0,
48
+ width_mult=1.0,
49
+ ks_list=[3, 5, 7],
50
+ expand_ratio_list=[3, 4, 6],
51
+ depth_list=[2, 3, 4],
52
+ )
53
+ else:
54
+ net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,
55
+ dropout_rate=0,
56
+ width_mult=1.0,
57
+ ks_list=[3, 5, 7],
58
+ expand_ratio_list=[3, 4, 6],
59
+ depth_list=[2, 3, 4],
60
+ )
61
+ else:
62
+ if dataset == "cifar10" or dataset == "cifar100":
63
+ net = MobileNetV3Large_Cifar(n_classes=run_config.data_provider.n_classes)
64
+ else:
65
+ net = MobileNetV3Large(n_classes=run_config.data_provider.n_classes)
66
+ elif net_id == "ProxylessNASNet":
67
+ if not base:
68
+ if dataset == "cifar10" or dataset == "cifar100":
69
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
70
+ dropout_rate=0,
71
+ width_mult=1.0,
72
+ ks_list=[3, 5, 7],
73
+ expand_ratio_list=[3, 4, 6],
74
+ depth_list=[2, 3, 4],
75
+ )
76
+ else:
77
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
78
+ dropout_rate=0,
79
+ width_mult=1.0,
80
+ ks_list=[3, 5, 7],
81
+ expand_ratio_list=[3, 4, 6],
82
+ depth_list=[2, 3, 4],
83
+ )
84
+ else:
85
+ if dataset == "cifar10" or dataset == "cifar100":
86
+ net = ProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes)
87
+ else:
88
+ net = ProxylessNASNets(n_classes=run_config.data_provider.n_classes)
89
+ elif net_id == "MBV2":
90
+ if not base:
91
+ if dataset == "cifar10" or dataset == "cifar100":
92
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
93
+ dropout_rate=0,
94
+ base_stage_width="google",
95
+ width_mult=1.0,
96
+ ks_list=[3, 5, 7],
97
+ expand_ratio_list=[3, 4, 6],
98
+ depth_list=[2, 3, 4],
99
+ )
100
+ else:
101
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
102
+ dropout_rate=0,
103
+ base_stage_width="google",
104
+ width_mult=1.0,
105
+ ks_list=[3, 5, 7],
106
+ expand_ratio_list=[3, 4, 6],
107
+ depth_list=[2, 3, 4],
108
+ )
109
+ else:
110
+ if dataset == "cifar10" or dataset == "cifar100":
111
+ net = MobileNetV2_Cifar(n_classes=run_config.data_provider.n_classes)
112
+ else:
113
+ net = MobileNetV2(n_classes=run_config.data_provider.n_classes)
114
+ elif net_id == "WideResNet":
115
+ if dataset == "cifar10" or dataset == "cifar100":
116
+ net = WideResNet(num_classes=run_config.data_provider.n_classes)
117
+ else:
118
+ raise ValueError("Not supported: %s" % net_id)
119
+
120
+ else:
121
+ raise ValueError("Not supported: %s" % net_id)
122
+
123
+ if pretrained and not WPS and not base:
124
+ if net_id == "ResNet50":
125
+ if robust_mode:
126
+ pt_path = "exp/robust/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
127
+ else:
128
+ pt_path = "exp/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
129
+ else:
130
+ if robust_mode:
131
+ pt_path = "exp/robust/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
132
+
133
+ else:
134
+ pt_path = "exp/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
135
+ elif pretrained and WPS and not base:
136
+ if net_id == "ResNet50":
137
+ if robust_mode:
138
+ pt_path = "exp/robust/WPS/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
139
+ else:
140
+ pt_path = "exp/WPS/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
141
+ else:
142
+ if robust_mode:
143
+ pt_path = "exp/robust/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
144
+
145
+ else:
146
+ pt_path = "exp/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar"
147
+ else:
148
+ if not base:
149
+ pt_path = "exp/robust/teacher/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar"
150
+ else:
151
+ pt_path = "exp/robust/base/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar"
152
+ print(pt_path)
153
+ init = torch.load(pt_path, map_location="cuda")["state_dict"]
154
+ # from collections import OrderedDict
155
+ # new_state_dict = OrderedDict()
156
+ # for k, v in init.items():
157
+ # name = k[7:] # remove `module.`
158
+ # new_state_dict[name] = v
159
+ net.load_state_dict(init)
160
+ return net
161
+
162
+
proard/nas/__init__.py ADDED
File without changes
proard/nas/accuracy_predictor/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .acc_dataset import *
6
+ from .acc_predictor import *
7
+ from .arch_encoder import *
8
+ from .rob_dataset import *
9
+ from .rob_predictor import *
10
+ from .acc_rob_dataset import *
11
+ from .acc_rob_predictor import *
proard/nas/accuracy_predictor/acc_dataset.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import json
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import torch
10
+ import torch.utils.data
11
+
12
+ from proard.utils import list_mean
13
+
14
+ __all__ = ["net_setting2id", "net_id2setting", "AccuracyDataset"]
15
+
16
+
17
+ def net_setting2id(net_setting):
18
+ return json.dumps(net_setting)
19
+
20
+
21
+ def net_id2setting(net_id):
22
+ return json.loads(net_id)
23
+
24
+
25
+ class RegDataset(torch.utils.data.Dataset):
26
+ def __init__(self, inputs, targets):
27
+ super(RegDataset, self).__init__()
28
+ self.inputs = inputs
29
+ self.targets = targets
30
+
31
+ def __getitem__(self, index):
32
+ return self.inputs[index], self.targets[index]
33
+
34
+ def __len__(self):
35
+ return self.inputs.size(0)
36
+
37
+
38
+ class AccuracyDataset:
39
+ def __init__(self, path):
40
+ self.path = path
41
+ os.makedirs(self.path, exist_ok=True)
42
+
43
+ @property
44
+ def net_id_path(self):
45
+ return os.path.join(self.path, "net_id.dict")
46
+
47
+ @property
48
+ def acc_src_folder(self):
49
+ return os.path.join(self.path, "src")
50
+ @property
51
+ def acc_dict_path(self):
52
+ return os.path.join(self.path, "src/acc.dict")
53
+
54
+
55
+ # TODO: support parallel building
56
+ def build_acc_dataset(
57
+ self, run_manager, dyn_network, n_arch=2000, image_size_list=None
58
+ ):
59
+ # load net_id_list, random sample if not exist
60
+ if os.path.isfile(self.net_id_path):
61
+ net_id_list = json.load(open(self.net_id_path))
62
+ else:
63
+ net_id_list = set()
64
+ while len(net_id_list) < n_arch:
65
+ net_setting = dyn_network.sample_active_subnet()
66
+ net_id = net_setting2id(net_setting)
67
+ net_id_list.add(net_id)
68
+ net_id_list = list(net_id_list)
69
+ net_id_list.sort()
70
+ json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)
71
+
72
+ image_size_list = (
73
+ [128, 160, 192, 224] if image_size_list is None else image_size_list
74
+ )
75
+ print(image_size_list)
76
+ with tqdm(
77
+ total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset"
78
+ ) as t:
79
+ for image_size in image_size_list:
80
+ # load val dataset into memory
81
+ val_dataset = []
82
+ run_manager.run_config.data_provider.assign_active_img_size(image_size)
83
+ for images, labels in run_manager.run_config.valid_loader:
84
+ val_dataset.append((images, labels))
85
+ # save path
86
+ os.makedirs(self.acc_src_folder, exist_ok=True)
87
+ acc_save_path = os.path.join(
88
+ self.acc_src_folder, "%d.dict" % image_size
89
+ )
90
+ acc_dict = {}
91
+ # load existing acc dict
92
+ if os.path.isfile(acc_save_path):
93
+ existing_acc_dict = json.load(open(acc_save_path, "r"))
94
+ else:
95
+ existing_acc_dict = {}
96
+ for net_id in net_id_list:
97
+ net_setting = net_id2setting(net_id)
98
+ key = net_setting2id({**net_setting, "image_size": image_size})
99
+ if key in existing_acc_dict:
100
+ acc_dict[key] = existing_acc_dict[key]
101
+ t.set_postfix(
102
+ {
103
+ "net_id": net_id,
104
+ "image_size": image_size,
105
+ "info_val": acc_dict[key],
106
+ "status": "loading",
107
+ }
108
+ )
109
+ t.update()
110
+ continue
111
+ dyn_network.set_active_subnet(**net_setting)
112
+ run_manager.reset_running_statistics(dyn_network)
113
+ net_setting_str = ",".join(
114
+ [
115
+ "%s_%s"
116
+ % (
117
+ key,
118
+ "%.1f" % list_mean(val)
119
+ if isinstance(val, list)
120
+ else val,
121
+ )
122
+ for key, val in net_setting.items()
123
+ ]
124
+ )
125
+ loss, (top1, top5,robust1,robust5) = run_manager.validate(
126
+ run_str=net_setting_str,
127
+ net=dyn_network,
128
+ data_loader=val_dataset,
129
+ no_logs=True,
130
+ )
131
+ info_val = top1
132
+ t.set_postfix(
133
+ {
134
+ "net_id": net_id,
135
+ "image_size": image_size,
136
+ "info_val": info_val,
137
+ }
138
+ )
139
+ t.update()
140
+
141
+ acc_dict.update({key: info_val})
142
+ json.dump(acc_dict, open(acc_save_path, "w"), indent=4)
143
+
144
+
145
+ def merge_acc_dataset(self, image_size_list=None):
146
+ # load existing data
147
+ merged_acc_dict = {}
148
+ for fname in os.listdir(self.acc_src_folder):
149
+ if ".dict" not in fname:
150
+ continue
151
+ image_size = int(fname.split(".dict")[0])
152
+ if image_size_list is not None and image_size not in image_size_list:
153
+ print("Skip ", fname)
154
+ continue
155
+ full_path = os.path.join(self.acc_src_folder, fname)
156
+ partial_acc_dict = json.load(open(full_path))
157
+ merged_acc_dict.update(partial_acc_dict)
158
+ print("loaded %s" % full_path)
159
+ json.dump(merged_acc_dict, open(self.acc_dict_path, "w"), indent=4)
160
+ return merged_acc_dict
161
+
162
+ def build_acc_data_loader(
163
+ self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
164
+ ):
165
+ # load data
166
+ acc_dict = json.load(open(self.acc_dict_path))
167
+ X_all = []
168
+ Y_all = []
169
+
170
+ with tqdm(total=len(acc_dict), desc="Loading data") as t:
171
+ for k, v in acc_dict.items():
172
+ dic = json.loads(k)
173
+ X_all.append(arch_encoder.arch2feature(dic))
174
+ Y_all.append(v / 100.0) # range: 0 - 1
175
+ t.update()
176
+ base_acc = np.mean(Y_all)
177
+ # convert to torch tensor
178
+ X_all = torch.tensor(X_all, dtype=torch.float)
179
+ Y_all = torch.tensor(Y_all)
180
+
181
+
182
+ # random shuffle
183
+ shuffle_idx = torch.randperm(len(X_all))
184
+ X_all = X_all[shuffle_idx]
185
+ Y_all = Y_all[shuffle_idx]
186
+ # split data
187
+ idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
188
+ val_idx = X_all.size(0) // 5 * 4
189
+ X_train, Y_train = X_all[:idx], Y_all[:idx]
190
+ X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
191
+ print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test))
192
+
193
+ # build data loader
194
+ train_dataset = RegDataset(X_train, Y_train)
195
+ val_dataset = RegDataset(X_test, Y_test)
196
+ train_loader = torch.utils.data.DataLoader(
197
+ train_dataset,
198
+ batch_size=batch_size,
199
+ shuffle=True,
200
+ pin_memory=False,
201
+ num_workers=n_workers,
202
+ )
203
+ valid_loader = torch.utils.data.DataLoader(
204
+ val_dataset,
205
+ batch_size=batch_size,
206
+ shuffle=False,
207
+ pin_memory=False,
208
+ num_workers=n_workers,
209
+ )
210
+
211
+ return train_loader, valid_loader, base_acc
212
+
213
+
proard/nas/accuracy_predictor/acc_predictor.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ __all__ = ["AccuracyPredictor"]
11
+
12
+
13
+ class AccuracyPredictor(nn.Module):
14
+ def __init__(
15
+ self,
16
+ arch_encoder,
17
+ hidden_size=400,
18
+ n_layers=3,
19
+ checkpoint_path=None,
20
+ device="cuda:0",
21
+ base_acc_val = None
22
+ ):
23
+ super(AccuracyPredictor, self).__init__()
24
+ self.arch_encoder = arch_encoder
25
+ self.hidden_size = hidden_size
26
+ self.n_layers = n_layers
27
+ self.device = device
28
+ self.base_acc_val = base_acc_val
29
+ # build layers
30
+ layers = []
31
+ for i in range(self.n_layers):
32
+ layers.append(
33
+ nn.Sequential(
34
+ nn.Linear(
35
+ self.arch_encoder.n_dim if i == 0 else self.hidden_size,
36
+ self.hidden_size,
37
+ ),
38
+ nn.ReLU(inplace=True),
39
+ )
40
+ )
41
+ layers.append(nn.Linear(self.hidden_size, 1, bias=False))
42
+ self.layers = nn.Sequential(*layers)
43
+ if self.base_acc_val!=None :
44
+ self.base_acc = nn.Parameter(
45
+ torch.tensor(self.base_acc_val, device=self.device), requires_grad=False
46
+ )
47
+ else:
48
+ self.base_acc = nn.Parameter(
49
+ torch.zeros(1, device=self.device), requires_grad=False
50
+ )
51
+
52
+ if checkpoint_path is not None and os.path.exists(checkpoint_path):
53
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
54
+ if "state_dict" in checkpoint:
55
+ checkpoint = checkpoint["state_dict"]
56
+ self.load_state_dict(checkpoint)
57
+ print("Loaded checkpoint from %s" % checkpoint_path)
58
+
59
+ self.layers = self.layers.to(self.device)
60
+
61
+ def forward(self, x):
62
+ y = self.layers(x).squeeze()
63
+ return y + self.base_acc
64
+
65
+ def predict_acc(self, arch_dict_list):
66
+ X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
67
+ X = torch.tensor(np.array(X)).float().to(self.device)
68
+ return self.forward(X)
proard/nas/accuracy_predictor/acc_rob_dataset.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import json
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import torch
10
+ import torch.utils.data
11
+
12
+ from proard.utils import list_mean
13
+
14
+ __all__ = ["net_setting2id", "net_id2setting", "AccuracyRobustnessDataset"]
15
+
16
+
17
+ def net_setting2id(net_setting):
18
+ return json.dumps(net_setting)
19
+
20
+
21
+ def net_id2setting(net_id):
22
+ return json.loads(net_id)
23
+
24
+
25
+ class TwoRegDataset(torch.utils.data.Dataset):
26
+ def __init__(self, inputs, targets_acc , targets_rob ):
27
+ super(TwoRegDataset, self).__init__()
28
+ self.inputs = inputs
29
+ self.targets_acc = targets_acc
30
+ self.targets_rob = targets_rob
31
+
32
+ def __getitem__(self, index):
33
+ return self.inputs[index], self.targets_acc[index] , self.targets_rob[index]
34
+
35
+ def __len__(self):
36
+ return self.inputs.size(0)
37
+
38
+
39
+ class AccuracyRobustnessDataset:
40
+ def __init__(self, path):
41
+ self.path = path
42
+ os.makedirs(self.path, exist_ok=True)
43
+
44
+ @property
45
+ def net_id_path(self):
46
+ return os.path.join(self.path, "net_id.dict")
47
+
48
+ @property
49
+ def acc_rob_src_folder(self):
50
+ return os.path.join(self.path, "src")
51
+ @property
52
+ def acc_rob_dict_path(self):
53
+ return os.path.join(self.path, "src/acc_robust.dict")
54
+
55
+
56
+ # TODO: support parallel building
57
+ def build_acc_rob_dataset(
58
+ self, run_manager, dyn_network, n_arch=2000, image_size_list=None
59
+ ):
60
+ # load net_id_list, random sample if not exist
61
+ if os.path.isfile(self.net_id_path):
62
+ net_id_list = json.load(open(self.net_id_path))
63
+ else:
64
+ net_id_list = set()
65
+ while len(net_id_list) < n_arch:
66
+ net_setting = dyn_network.sample_active_subnet()
67
+ net_id = net_setting2id(net_setting)
68
+ net_id_list.add(net_id)
69
+ net_id_list = list(net_id_list)
70
+ net_id_list.sort()
71
+ json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)
72
+
73
+ image_size_list = (
74
+ [128, 160, 192, 224] if image_size_list is None else image_size_list
75
+ )
76
+ print(image_size_list)
77
+ with tqdm(
78
+ total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset"
79
+ ) as t:
80
+ for image_size in image_size_list:
81
+ # load val dataset into memory
82
+ val_dataset = []
83
+ run_manager.run_config.data_provider.assign_active_img_size(image_size)
84
+ for images, labels in run_manager.run_config.valid_loader:
85
+ val_dataset.append((images, labels))
86
+ # save path
87
+ os.makedirs(self.acc_rob_src_folder, exist_ok=True)
88
+ acc_rob_save_path = os.path.join(
89
+ self.acc_rob_src_folder, "%d.dict" % image_size
90
+ )
91
+ acc_rob_dict = {}
92
+ # load existing acc dict
93
+ if os.path.isfile(acc_rob_save_path):
94
+ existing_acc_rob_dict = json.load(open(acc_rob_save_path, "r"))
95
+ else:
96
+ existing_acc_rob_dict = {}
97
+ for net_id in net_id_list:
98
+ net_setting = net_id2setting(net_id)
99
+ key = net_setting2id({**net_setting, "image_size": image_size})
100
+ if key in existing_acc_rob_dict:
101
+ acc_rob_dict[key] = existing_acc_rob_dict[key]
102
+ t.set_postfix(
103
+ {
104
+ "net_id": net_id,
105
+ "image_size": image_size,
106
+ "info_val": acc_rob_dict[key],
107
+ "status": "loading",
108
+ }
109
+ )
110
+ t.update()
111
+ continue
112
+ dyn_network.set_active_subnet(**net_setting)
113
+ run_manager.reset_running_statistics(dyn_network)
114
+ net_setting_str = ",".join(
115
+ [
116
+ "%s_%s"
117
+ % (
118
+ key,
119
+ "%.1f" % list_mean(val)
120
+ if isinstance(val, list)
121
+ else val,
122
+ )
123
+ for key, val in net_setting.items()
124
+ ]
125
+ )
126
+ loss, (top1, top5,robust1,robust5) = run_manager.validate(
127
+ run_str=net_setting_str,
128
+ net=dyn_network,
129
+ data_loader=val_dataset,
130
+ no_logs=True,
131
+ )
132
+ info_val = [top1,robust1]
133
+ t.set_postfix(
134
+ {
135
+ "net_id": net_id,
136
+ "image_size": image_size,
137
+ "info_val": info_val,
138
+ }
139
+ )
140
+ t.update()
141
+
142
+ acc_rob_dict.update({key: info_val})
143
+ json.dump(acc_rob_dict, open(acc_rob_save_path, "w"), indent=4)
144
+
145
+
146
+ def merge_acc_dataset(self, image_size_list=None):
147
+ # load existing data
148
+ merged_acc_rob_dict = {}
149
+ for fname in os.listdir(self.acc_rob_src_folder):
150
+ if ".dict" not in fname:
151
+ continue
152
+ image_size = int(fname.split(".dict")[0])
153
+ if image_size_list is not None and image_size not in image_size_list:
154
+ print("Skip ", fname)
155
+ continue
156
+ full_path = os.path.join(self.acc_rob_src_folder, fname)
157
+ partial_acc_rob_dict = json.load(open(full_path))
158
+ merged_acc_rob_dict.update(partial_acc_rob_dict)
159
+ print("loaded %s" % full_path)
160
+ json.dump(merged_acc_rob_dict, open(self.acc_rob_dict_path, "w"), indent=4)
161
+ return merged_acc_rob_dict
162
+
163
+ def build_acc_data_loader(
164
+ self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
165
+ ):
166
+ # load data
167
+ acc_rob_dict = json.load(open(self.acc_rob_dict_path))
168
+ X_all = []
169
+ Y_acc_all = []
170
+ Y_rob_all = []
171
+
172
+ with tqdm(total=len(acc_rob_dict), desc="Loading data") as t:
173
+ for k, v in acc_rob_dict.items():
174
+ dic = json.loads(k)
175
+ X_all.append(arch_encoder.arch2feature(dic))
176
+ Y_acc_all.append(v[0] / 100.0) # range: 0 - 1
177
+ Y_rob_all.append(v[1] / 100.0)
178
+ t.update()
179
+ base_acc = np.mean(Y_acc_all)
180
+ base_rob = np.mean(Y_rob_all)
181
+ # convert to torch tensor
182
+ X_all = torch.tensor(X_all, dtype=torch.float)
183
+ Y_acc_all = torch.tensor(Y_acc_all)
184
+ Y_rob_all = torch.tensor(Y_rob_all)
185
+
186
+
187
+ # random shuffle
188
+ shuffle_idx = torch.randperm(len(X_all))
189
+ X_all = X_all[shuffle_idx]
190
+ Y_acc_all = Y_acc_all[shuffle_idx]
191
+ Y_rob_all = Y_rob_all[shuffle_idx]
192
+ # split data
193
+ idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
194
+ val_idx = X_all.size(0) // 5 * 4
195
+ X_train, Y_acc_train, Y_rob_train = X_all[:idx], Y_acc_all[:idx], Y_rob_all[:idx]
196
+ X_test, Y_acc_test , Y_rob_test = X_all[val_idx:], Y_acc_all[val_idx:] , Y_rob_all[val_idx:]
197
+ print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test))
198
+
199
+ # build data loader
200
+ train_dataset = TwoRegDataset(X_train, Y_acc_train , Y_rob_train)
201
+ val_dataset = TwoRegDataset(X_test, Y_acc_test ,Y_rob_test )
202
+ train_loader = torch.utils.data.DataLoader(
203
+ train_dataset,
204
+ batch_size=batch_size,
205
+ shuffle=True,
206
+ pin_memory=False,
207
+ num_workers=n_workers,
208
+ )
209
+ valid_loader = torch.utils.data.DataLoader(
210
+ val_dataset,
211
+ batch_size=batch_size,
212
+ shuffle=False,
213
+ pin_memory=False,
214
+ num_workers=n_workers,
215
+ )
216
+
217
+ return train_loader, valid_loader, base_acc, base_rob
218
+
219
+
proard/nas/accuracy_predictor/acc_rob_predictor.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ __all__ = ["Accuracy_Robustness_Predictor"]
11
+
12
+
13
+ class Accuracy_Robustness_Predictor(nn.Module):
14
+ def __init__(
15
+ self,
16
+ arch_encoder,
17
+ hidden_size=400,
18
+ n_layers=6,
19
+ checkpoint_path=None,
20
+ device="cuda:0",
21
+ base_acc_val = None,
22
+ base_rob_val = None
23
+ ):
24
+ super(Accuracy_Robustness_Predictor, self).__init__()
25
+ self.arch_encoder = arch_encoder
26
+ self.hidden_size = hidden_size
27
+ self.n_layers = n_layers
28
+ self.device = device
29
+ self.base_acc_val = base_acc_val
30
+ self.base_rob_val = base_rob_val
31
+ # build layers
32
+ layers = []
33
+ for i in range(self.n_layers):
34
+ layers.append(
35
+ nn.Sequential(
36
+ nn.Linear(
37
+ self.arch_encoder.n_dim if i == 0 else self.hidden_size,
38
+ self.hidden_size,
39
+ ),
40
+ nn.ReLU(inplace=True),
41
+ )
42
+ )
43
+ layers.append(nn.Linear(self.hidden_size, 2, bias=False))
44
+ self.layers = nn.Sequential(*layers)
45
+ if self.base_acc_val!=None :
46
+ self.base_acc = nn.Parameter(
47
+ torch.tensor(self.base_acc_val, device=self.device), requires_grad=False
48
+ )
49
+ else:
50
+ self.base_acc = nn.Parameter(
51
+ torch.zeros(1, device=self.device), requires_grad=False
52
+ )
53
+ if self.base_rob_val!=None :
54
+ self.base_rob = nn.Parameter(
55
+ torch.tensor(self.base_rob_val, device=self.device), requires_grad=False
56
+ )
57
+ else:
58
+ self.base_rob = nn.Parameter(
59
+ torch.zeros(1, device=self.device), requires_grad=False
60
+ )
61
+ if checkpoint_path is not None and os.path.exists(checkpoint_path):
62
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
63
+ if "state_dict" in checkpoint:
64
+ checkpoint = checkpoint["state_dict"]
65
+ self.load_state_dict(checkpoint)
66
+ print("Loaded checkpoint from %s" % checkpoint_path)
67
+
68
+ self.layers = self.layers.to(self.device)
69
+
70
+ def forward(self, x):
71
+ y = self.layers(x).squeeze()
72
+ return y + self.base_acc
73
+
74
+ def predict_acc_rob(self, arch_dict_list):
75
+ X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
76
+ X = torch.tensor(np.array(X)).float().to(self.device)
77
+ return self.forward(X)
proard/nas/accuracy_predictor/arch_encoder.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+
6
+ import random
7
+ import numpy as np
8
+ from proard.classification.networks import ResNets
9
+
10
+ __all__ = ["MobileNetArchEncoder", "ResNetArchEncoder"]
11
+
12
+
13
+ class MobileNetArchEncoder:
14
+ SPACE_TYPE = "mbv3"
15
+
16
+ def __init__(
17
+ self,
18
+ image_size_list=None,
19
+ ks_list=None,
20
+ expand_list=None,
21
+ depth_list=None,
22
+ n_stage=None,
23
+ ):
24
+ self.image_size_list = [224] if image_size_list is None else image_size_list
25
+ self.ks_list = [3, 5, 7] if ks_list is None else ks_list
26
+ self.expand_list = (
27
+ [3, 4, 6]
28
+ if expand_list is None
29
+ else [int(expand) for expand in expand_list]
30
+ )
31
+ self.depth_list = [2, 3, 4] if depth_list is None else depth_list
32
+ if n_stage is not None:
33
+ self.n_stage = n_stage
34
+ elif self.SPACE_TYPE == "mbv2":
35
+ self.n_stage = 6
36
+ elif self.SPACE_TYPE == "mbv3":
37
+ self.n_stage = 5
38
+ else:
39
+ raise NotImplementedError
40
+
41
+ # build info dict
42
+ self.n_dim = 0
43
+ self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
44
+ self._build_info_dict(target="r")
45
+ self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
46
+ self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
47
+ self._build_info_dict(target="k")
48
+ self._build_info_dict(target="e")
49
+
50
+ @property
51
+ def max_n_blocks(self):
52
+ if self.SPACE_TYPE == "mbv3":
53
+ return self.n_stage * max(self.depth_list)
54
+ elif self.SPACE_TYPE == "mbv2":
55
+ return (self.n_stage - 1) * max(self.depth_list) + 1
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ def _build_info_dict(self, target):
60
+ if target == "r":
61
+ target_dict = self.r_info
62
+ target_dict["L"].append(self.n_dim)
63
+ for img_size in self.image_size_list:
64
+ target_dict["val2id"][img_size] = self.n_dim
65
+ target_dict["id2val"][self.n_dim] = img_size
66
+ self.n_dim += 1
67
+ target_dict["R"].append(self.n_dim)
68
+ else:
69
+ if target == "k":
70
+ target_dict = self.k_info
71
+ choices = self.ks_list
72
+ elif target == "e":
73
+ target_dict = self.e_info
74
+ choices = self.expand_list
75
+ else:
76
+ raise NotImplementedError
77
+ for i in range(self.max_n_blocks):
78
+ target_dict["val2id"].append({})
79
+ target_dict["id2val"].append({})
80
+ target_dict["L"].append(self.n_dim)
81
+ for k in choices:
82
+ target_dict["val2id"][i][k] = self.n_dim
83
+ target_dict["id2val"][i][self.n_dim] = k
84
+ self.n_dim += 1
85
+ target_dict["R"].append(self.n_dim)
86
+
87
+ def arch2feature(self, arch_dict):
88
+ ks, e, d, r = (
89
+ arch_dict["ks"],
90
+ arch_dict["e"],
91
+ arch_dict["d"],
92
+ arch_dict["image_size"],
93
+ )
94
+ feature = np.zeros(self.n_dim)
95
+ for i in range(self.max_n_blocks):
96
+ nowd = i % max(self.depth_list)
97
+ stg = i // max(self.depth_list)
98
+ if nowd < d[stg]:
99
+ feature[self.k_info["val2id"][i][ks[i]]] = 1
100
+ feature[self.e_info["val2id"][i][e[i]]] = 1
101
+ feature[self.r_info["val2id"][r[0]]] = 1
102
+ return feature
103
+
104
+ def feature2arch(self, feature):
105
+ img_sz = self.r_info["id2val"][
106
+ int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]]))
107
+ + self.r_info["L"][0]
108
+ ]
109
+ assert img_sz in self.image_size_list
110
+ arch_dict = {"ks": [], "e": [], "d": [], "image_size": img_sz}
111
+
112
+ d = 0
113
+ for i in range(self.max_n_blocks):
114
+ skip = True
115
+ for j in range(self.k_info["L"][i], self.k_info["R"][i]):
116
+ if feature[j] == 1:
117
+ arch_dict["ks"].append(self.k_info["id2val"][i][j])
118
+ skip = False
119
+ break
120
+
121
+ for j in range(self.e_info["L"][i], self.e_info["R"][i]):
122
+ if feature[j] == 1:
123
+ arch_dict["e"].append(self.e_info["id2val"][i][j])
124
+ assert not skip
125
+ skip = False
126
+ break
127
+
128
+ if skip:
129
+ arch_dict["e"].append(0)
130
+ arch_dict["ks"].append(0)
131
+ else:
132
+ d += 1
133
+
134
+ if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
135
+ arch_dict["d"].append(d)
136
+ d = 0
137
+ return arch_dict
138
+
139
+ def random_sample_arch(self):
140
+ return {
141
+ "ks": random.choices(self.ks_list, k=self.max_n_blocks),
142
+ "e": random.choices(self.expand_list, k=self.max_n_blocks),
143
+ "d": random.choices(self.depth_list, k=self.n_stage),
144
+ "image_size": [random.choice(self.image_size_list)],
145
+ }
146
+
147
+ def mutate_resolution(self, arch_dict, mutate_prob):
148
+ if random.random() < mutate_prob:
149
+ arch_dict["image_size"] = random.choice(self.image_size_list)
150
+ return arch_dict
151
+
152
+ def mutate_arch(self, arch_dict, mutate_prob):
153
+ for i in range(self.max_n_blocks):
154
+ if random.random() < mutate_prob:
155
+ arch_dict["ks"][i] = random.choice(self.ks_list)
156
+ arch_dict["e"][i] = random.choice(self.expand_list)
157
+
158
+ for i in range(self.n_stage):
159
+ if random.random() < mutate_prob:
160
+ arch_dict["d"][i] = random.choice(self.depth_list)
161
+ return arch_dict
162
+
163
+
164
+ class ResNetArchEncoder:
165
+ def __init__(
166
+ self,
167
+ image_size_list=None,
168
+ depth_list=None,
169
+ expand_list=None,
170
+ width_mult_list=None,
171
+ base_depth_list=None,
172
+ ):
173
+ self.image_size_list = [224] if image_size_list is None else image_size_list
174
+ self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
175
+ self.depth_list = [0, 1, 2] if depth_list is None else depth_list
176
+ self.width_mult_list = (
177
+ [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list
178
+ )
179
+
180
+ self.base_depth_list = (
181
+ ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list
182
+ )
183
+
184
+ """" build info dict """
185
+ self.n_dim = 0
186
+ # resolution
187
+ self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
188
+ self._build_info_dict(target="r")
189
+ # input stem skip
190
+ self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
191
+ self._build_info_dict(target="input_stem_d")
192
+ # width_mult
193
+ self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
194
+ self._build_info_dict(target="width_mult")
195
+ # expand ratio
196
+ self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
197
+ self._build_info_dict(target="e")
198
+
199
+ @property
200
+ def n_stage(self):
201
+ return len(self.base_depth_list)
202
+
203
+ @property
204
+ def max_n_blocks(self):
205
+ return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)
206
+
207
+ def _build_info_dict(self, target):
208
+ if target == "r":
209
+ target_dict = self.r_info
210
+ target_dict["L"].append(self.n_dim)
211
+ for img_size in self.image_size_list:
212
+ target_dict["val2id"][img_size] = self.n_dim
213
+ target_dict["id2val"][self.n_dim] = img_size
214
+ self.n_dim += 1
215
+ target_dict["R"].append(self.n_dim)
216
+ elif target == "input_stem_d":
217
+ target_dict = self.input_stem_d_info
218
+ target_dict["L"].append(self.n_dim)
219
+ for skip in [0, 1]:
220
+ target_dict["val2id"][skip] = self.n_dim
221
+ target_dict["id2val"][self.n_dim] = skip
222
+ self.n_dim += 1
223
+ target_dict["R"].append(self.n_dim)
224
+ elif target == "e":
225
+ target_dict = self.e_info
226
+ choices = self.expand_list
227
+ for i in range(self.max_n_blocks):
228
+ target_dict["val2id"].append({})
229
+ target_dict["id2val"].append({})
230
+ target_dict["L"].append(self.n_dim)
231
+ for e in choices:
232
+ target_dict["val2id"][i][e] = self.n_dim
233
+ target_dict["id2val"][i][self.n_dim] = e
234
+ self.n_dim += 1
235
+ target_dict["R"].append(self.n_dim)
236
+ elif target == "width_mult":
237
+ target_dict = self.width_mult_info
238
+ choices = list(range(len(self.width_mult_list)))
239
+ for i in range(self.n_stage + 2):
240
+ target_dict["val2id"].append({})
241
+ target_dict["id2val"].append({})
242
+ target_dict["L"].append(self.n_dim)
243
+ for w in choices:
244
+ target_dict["val2id"][i][w] = self.n_dim
245
+ target_dict["id2val"][i][self.n_dim] = w
246
+ self.n_dim += 1
247
+ target_dict["R"].append(self.n_dim)
248
+
249
+ def arch2feature(self, arch_dict):
250
+ d, e, w, r = (
251
+ arch_dict["d"],
252
+ arch_dict["e"],
253
+ arch_dict["w"],
254
+ arch_dict["image_size"],
255
+ )
256
+ input_stem_skip = 1 if d[0] > 0 else 0
257
+ d = d[1:]
258
+
259
+ feature = np.zeros(self.n_dim)
260
+ feature[self.r_info["val2id"][r]] = 1
261
+ feature[self.input_stem_d_info["val2id"][input_stem_skip]] = 1
262
+ for i in range(self.n_stage + 2):
263
+ feature[self.width_mult_info["val2id"][i][w[i]]] = 1
264
+
265
+ start_pt = 0
266
+ for i, base_depth in enumerate(self.base_depth_list):
267
+ depth = base_depth + d[i]
268
+ for j in range(start_pt, start_pt + depth):
269
+ feature[self.e_info["val2id"][j][e[j]]] = 1
270
+ start_pt += max(self.depth_list) + base_depth
271
+ return feature
272
+
273
+ def feature2arch(self, feature):
274
+ img_sz = self.r_info["id2val"][
275
+ int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]]))
276
+ + self.r_info["L"][0]
277
+ ]
278
+ input_stem_skip = (
279
+ self.input_stem_d_info["id2val"][
280
+ int(
281
+ np.argmax(
282
+ feature[
283
+ self.input_stem_d_info["L"][0] : self.input_stem_d_info[
284
+ "R"
285
+ ][0]
286
+ ]
287
+ )
288
+ )
289
+ + self.input_stem_d_info["L"][0]
290
+ ]
291
+ * 2
292
+ )
293
+ assert img_sz in self.image_size_list
294
+ arch_dict = {"d": [input_stem_skip], "e": [], "w": [], "image_size": img_sz}
295
+
296
+ for i in range(self.n_stage + 2):
297
+ arch_dict["w"].append(
298
+ self.width_mult_info["id2val"][i][
299
+ int(
300
+ np.argmax(
301
+ feature[
302
+ self.width_mult_info["L"][i] : self.width_mult_info[
303
+ "R"
304
+ ][i]
305
+ ]
306
+ )
307
+ )
308
+ + self.width_mult_info["L"][i]
309
+ ]
310
+ )
311
+
312
+ d = 0
313
+ skipped = 0
314
+ stage_id = 0
315
+ for i in range(self.max_n_blocks):
316
+ skip = True
317
+ for j in range(self.e_info["L"][i], self.e_info["R"][i]):
318
+ if feature[j] == 1:
319
+ arch_dict["e"].append(self.e_info["id2val"][i][j])
320
+ skip = False
321
+ break
322
+ if skip:
323
+ arch_dict["e"].append(0)
324
+ skipped += 1
325
+ else:
326
+ d += 1
327
+
328
+ if (
329
+ i + 1 == self.max_n_blocks
330
+ or (skipped + d)
331
+ % (max(self.depth_list) + self.base_depth_list[stage_id])
332
+ == 0
333
+ ):
334
+ arch_dict["d"].append(d - self.base_depth_list[stage_id])
335
+ d, skipped = 0, 0
336
+ stage_id += 1
337
+ return arch_dict
338
+
339
+ def random_sample_arch(self):
340
+ return {
341
+ "d": [random.choice([0, 2])]
342
+ + random.choices(self.depth_list, k=self.n_stage),
343
+ "e": random.choices(self.expand_list, k=self.max_n_blocks),
344
+ "w": random.choices(
345
+ list(range(len(self.width_mult_list))), k=self.n_stage + 2
346
+ ),
347
+ "image_size": random.choice(self.image_size_list),
348
+ }
349
+
350
+ def mutate_resolution(self, arch_dict, mutate_prob):
351
+ if random.random() < mutate_prob:
352
+ arch_dict["image_size"] = random.choice(self.image_size_list)
353
+ return arch_dict
354
+
355
+ def mutate_arch(self, arch_dict, mutate_prob):
356
+ # input stem skip
357
+ if random.random() < mutate_prob:
358
+ arch_dict["d"][0] = random.choice([0, 2])
359
+ # depth
360
+ for i in range(1, len(arch_dict["d"])):
361
+ if random.random() < mutate_prob:
362
+ arch_dict["d"][i] = random.choice(self.depth_list)
363
+ # width_mult
364
+ for i in range(len(arch_dict["w"])):
365
+ if random.random() < mutate_prob:
366
+ arch_dict["w"][i] = random.choice(
367
+ list(range(len(self.width_mult_list)))
368
+ )
369
+ # expand ratio
370
+ for i in range(len(arch_dict["e"])):
371
+ if random.random() < mutate_prob:
372
+ arch_dict["e"][i] = random.choice(self.expand_list)
proard/nas/accuracy_predictor/rob_dataset.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import json
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import torch
10
+ import torch.utils.data
11
+
12
+ from proard.utils import list_mean
13
+
14
+ __all__ = ["net_setting2id", "net_id2setting", "RobustnessDataset"]
15
+
16
+
17
+ def net_setting2id(net_setting):
18
+ return json.dumps(net_setting)
19
+
20
+
21
+ def net_id2setting(net_id):
22
+ return json.loads(net_id)
23
+
24
+
25
+ class RegDataset(torch.utils.data.Dataset):
26
+ def __init__(self, inputs, targets):
27
+ super(RegDataset, self).__init__()
28
+ self.inputs = inputs
29
+ self.targets = targets
30
+
31
+ def __getitem__(self, index):
32
+ return self.inputs[index], self.targets[index]
33
+
34
+ def __len__(self):
35
+ return self.inputs.size(0)
36
+
37
+
38
+ class RobustnessDataset:
39
+ def __init__(self, path):
40
+ self.path = path
41
+ os.makedirs(self.path, exist_ok=True)
42
+
43
+ @property
44
+ def net_id_path(self):
45
+ return os.path.join(self.path, "net_id.dict")
46
+
47
+ @property
48
+ def rob_src_folder(self):
49
+ return os.path.join(self.path, "src_rob")
50
+ @property
51
+ def rob_dict_path(self):
52
+ return os.path.join(self.path, "src_rob/rob.dict")
53
+
54
+ # TODO: support parallel building
55
+ def build_rob_dataset(
56
+ self, run_manager, dyn_network, n_arch=2000, image_size_list=None
57
+ ):
58
+ # load net_id_list, random sample if not exist
59
+ if os.path.isfile(self.net_id_path):
60
+ net_id_list = json.load(open(self.net_id_path))
61
+ else:
62
+ net_id_list = set()
63
+ while len(net_id_list) < n_arch:
64
+ net_setting = dyn_network.sample_active_subnet()
65
+ net_id = net_setting2id(net_setting)
66
+ net_id_list.add(net_id)
67
+ net_id_list = list(net_id_list)
68
+ net_id_list.sort()
69
+ json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)
70
+
71
+ image_size_list = (
72
+ [128, 160, 192, 224] if image_size_list is None else image_size_list
73
+ )
74
+
75
+ with tqdm(
76
+ total=len(net_id_list) * len(image_size_list), desc="Building Robustness Dataset"
77
+ ) as t:
78
+ for image_size in image_size_list:
79
+ # load val dataset into memory
80
+ val_dataset = []
81
+ run_manager.run_config.data_provider.assign_active_img_size(image_size)
82
+ for images, labels in run_manager.run_config.valid_loader:
83
+ val_dataset.append((images, labels))
84
+ # save path
85
+ os.makedirs(self.rob_src_folder, exist_ok=True)
86
+
87
+ rob_save_path = os.path.join(
88
+ self.rob_src_folder, "%d.dict" % image_size
89
+ )
90
+
91
+ rob_dict ={}
92
+ # load existing rob dict
93
+ if os.path.isfile(rob_save_path):
94
+ existing_rob_dict = json.load(open(rob_save_path,"r"))
95
+ else:
96
+ existing_rob_dict = {}
97
+ for net_id in net_id_list:
98
+ net_setting = net_id2setting(net_id)
99
+ key = net_setting2id({**net_setting, "image_size": image_size})
100
+ if key in existing_rob_dict:
101
+ rob_dict[key] = existing_rob_dict[key]
102
+ t.set_postfix(
103
+ {
104
+ "net_id": net_id,
105
+ "image_size": image_size,
106
+ "info_rob" : rob_dict[key],
107
+ "status": "loading",
108
+ }
109
+ )
110
+ t.update()
111
+ continue
112
+ dyn_network.set_active_subnet(**net_setting)
113
+ run_manager.reset_running_statistics(dyn_network)
114
+ net_setting_str = ",".join(
115
+ [
116
+ "%s_%s"
117
+ % (
118
+ key,
119
+ "%.1f" % list_mean(val)
120
+ if isinstance(val, list)
121
+ else val,
122
+ )
123
+ for key, val in net_setting.items()
124
+ ]
125
+ )
126
+ loss, (top1, top5,robust1,robust5) = run_manager.validate(
127
+ run_str=net_setting_str,
128
+ net=dyn_network,
129
+ data_loader=val_dataset,
130
+ no_logs=True,
131
+ )
132
+ info_robust = robust1
133
+ t.set_postfix(
134
+ {
135
+ "net_id": net_id,
136
+ "image_size": image_size,
137
+ "info_rob" : info_robust,
138
+ "info_robust" : info_robust,
139
+ }
140
+ )
141
+ t.update()
142
+
143
+ rob_dict.update({key:info_robust})
144
+ json.dump(rob_dict, open(rob_save_path, "w"), indent=4)
145
+
146
+ def merge_rob_dataset(self, image_size_list=None):
147
+ # load existing data
148
+ merged_rob_dict = {}
149
+ for fname in os.listdir(self.rob_src_folder):
150
+ if ".dict" not in fname:
151
+ continue
152
+ image_size = int(fname.split(".dict")[0])
153
+ if image_size_list is not None and image_size not in image_size_list:
154
+ print("Skip ", fname)
155
+ continue
156
+ full_path = os.path.join(self.rob_src_folder, fname)
157
+ partial_rob_dict = json.load(open(full_path))
158
+ merged_rob_dict.update(partial_rob_dict)
159
+ print("loaded %s" % full_path)
160
+ json.dump(merged_rob_dict, open(self.rob_dict_path, "w"), indent=4)
161
+ return merged_rob_dict
162
+
163
+ def build_rob_data_loader(
164
+ self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
165
+ ):
166
+ # load data
167
+ rob_dict = json.load(open(self.rob_dict_path))
168
+ X_all_rob = []
169
+ Y_all_rob = []
170
+ with tqdm(total=len(rob_dict), desc="Loading data") as t:
171
+ for k, v in rob_dict.items():
172
+ dic = json.loads(k)
173
+ X_all_rob.append(arch_encoder.arch2feature(dic))
174
+ Y_all_rob.append(v / 100.0) # range: 0 - 1
175
+ t.update()
176
+ base_rob = np.mean(Y_all_rob)
177
+ # convert to torch tensor
178
+ X_all_rob = torch.tensor(X_all_rob, dtype=torch.float)
179
+ Y_all_rob = torch.tensor(Y_all_rob)
180
+
181
+ # random shuffle
182
+ shuffle_idx_rob = torch.randperm(len(X_all_rob))
183
+ X_all_rob = X_all_rob[shuffle_idx_rob]
184
+ Y_all_rob = Y_all_rob[shuffle_idx_rob]
185
+ # split data
186
+ idx_rob = X_all_rob.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
187
+ val_idx_rob = X_all_rob.size(0) // 5 * 4
188
+ X_train_rob, Y_train_rob = X_all_rob[:idx_rob], Y_all_rob[:idx_rob]
189
+ X_test_rob, Y_test_rob = X_all_rob[val_idx_rob:], Y_all_rob[val_idx_rob:]
190
+ print("Train Robustness Size: %d," % len(X_train_rob), "Valid Robustness Size: %d" % len(X_test_rob))
191
+ # build data loader
192
+ train_dataset_rob = RegDataset(X_train_rob, Y_train_rob)
193
+ val_dataset_rob = RegDataset(X_test_rob, Y_test_rob)
194
+
195
+ train_loader_rob = torch.utils.data.DataLoader(
196
+ train_dataset_rob,
197
+ batch_size=batch_size,
198
+ shuffle=True,
199
+ pin_memory=False,
200
+ num_workers=n_workers,
201
+ )
202
+ valid_loader_rob = torch.utils.data.DataLoader(
203
+ val_dataset_rob,
204
+ batch_size=batch_size,
205
+ shuffle=False,
206
+ pin_memory=False,
207
+ num_workers=n_workers,
208
+ )
209
+ return train_loader_rob, valid_loader_rob , base_rob
210
+
211
+
proard/nas/accuracy_predictor/rob_predictor.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ __all__ = ["RobustnessPredictor"]
11
+
12
+ class RobustnessPredictor(nn.Module):
13
+ def __init__(
14
+ self,
15
+ arch_encoder,
16
+ hidden_size=400,
17
+ n_layers=3,
18
+ checkpoint_path=None,
19
+ device="cuda:0",
20
+ base_rob_val = None,
21
+ ):
22
+ super(RobustnessPredictor, self).__init__()
23
+ self.arch_encoder = arch_encoder
24
+ self.hidden_size = hidden_size
25
+ self.n_layers = n_layers
26
+ self.device = device
27
+ self.base_rob_val = base_rob_val
28
+ # build layers
29
+ layers = []
30
+ for i in range(self.n_layers):
31
+ layers.append(
32
+ nn.Sequential(
33
+ nn.Linear(
34
+ self.arch_encoder.n_dim if i == 0 else self.hidden_size,
35
+ self.hidden_size,
36
+ ),
37
+ nn.ReLU(inplace=True),
38
+ )
39
+ )
40
+ layers.append(nn.Linear(self.hidden_size, 1, bias=False))
41
+ self.layers = nn.Sequential(*layers)
42
+ if self.base_rob_val !=None :
43
+ self.base_rob = nn.Parameter(
44
+ torch.tensor(self.base_rob_val,device=self.device), requires_grad=False
45
+ )
46
+ else:
47
+ self.base_rob = nn.Parameter(
48
+ torch.zeros(1, device=self.device), requires_grad=False
49
+ )
50
+ if checkpoint_path is not None and os.path.exists(checkpoint_path):
51
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
52
+ if "state_dict" in checkpoint:
53
+ checkpoint = checkpoint["state_dict"]
54
+ self.load_state_dict(checkpoint)
55
+ print("Loaded checkpoint from %s" % checkpoint_path)
56
+
57
+ self.layers = self.layers.to(self.device)
58
+
59
+ def forward(self, x):
60
+ y = self.layers(x).squeeze()
61
+ return y + self.base_rob
62
+
63
+ def predict_rob(self, arch_dict_list):
64
+ X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
65
+ X = torch.tensor(np.array(X)).float().to(self.device)
66
+ return self.forward(X)
proard/nas/efficiency_predictor/__init__.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import copy
7
+ from .latency_lookup_table import *
8
+
9
+
10
+ class BaseEfficiencyModel:
11
+ def __init__(self, dyn_net):
12
+ self.dyn_net = dyn_net
13
+
14
+ def get_active_subnet_config(self, arch_dict):
15
+ arch_dict = copy.deepcopy(arch_dict)
16
+ image_size = arch_dict.pop("image_size")
17
+ self.dyn_net.set_active_subnet(**arch_dict)
18
+ active_net_config = self.dyn_net.get_active_net_config()
19
+ return active_net_config, image_size
20
+
21
+ def get_efficiency(self, arch_dict):
22
+ raise NotImplementedError
23
+
24
+
25
+ class ProxylessNASFLOPsModel(BaseEfficiencyModel):
26
+ def get_efficiency(self, arch_dict):
27
+ active_net_config, image_size = self.get_active_subnet_config(arch_dict)
28
+ return ProxylessNASLatencyTable.count_flops_given_config(
29
+ active_net_config, image_size
30
+ )
31
+
32
+
33
+ class Mbv3FLOPsModel(BaseEfficiencyModel):
34
+ def get_efficiency(self, arch_dict):
35
+ active_net_config, image_size = self.get_active_subnet_config(arch_dict)
36
+ return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size[0])
37
+
38
+
39
+ class ResNet50FLOPsModel(BaseEfficiencyModel):
40
+ def get_efficiency(self, arch_dict):
41
+ active_net_config, image_size = self.get_active_subnet_config(arch_dict)
42
+ return ResNet50LatencyTable.count_flops_given_config(
43
+ active_net_config, image_size
44
+ )
45
+
46
+
47
+ class ProxylessNASLatencyModel(BaseEfficiencyModel):
48
+ def __init__(self, dyn_net, lookup_table_path_dict):
49
+ super(ProxylessNASLatencyModel, self).__init__(dyn_net)
50
+ self.latency_tables = {}
51
+ for image_size, path in lookup_table_path_dict.items():
52
+ self.latency_tables[image_size] = ProxylessNASLatencyTable(
53
+ local_dir="/tmp/.dyn_latency_tools/",
54
+ url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
55
+ )
56
+
57
+ def get_efficiency(self, arch_dict):
58
+ active_net_config, image_size = self.get_active_subnet_config(arch_dict)
59
+ return self.latency_tables[image_size].predict_network_latency_given_config(
60
+ active_net_config, image_size
61
+ )
62
+
63
+
64
+ class Mbv3LatencyModel(BaseEfficiencyModel):
65
+ def __init__(self, dyn_net, lookup_table_path_dict):
66
+ super(Mbv3LatencyModel, self).__init__(dyn_net)
67
+ self.latency_tables = {}
68
+ for image_size, path in lookup_table_path_dict.items():
69
+ self.latency_tables[image_size] = MBv3LatencyTable(
70
+ local_dir="/tmp/.dyn_latency_tools/",
71
+ url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
72
+ )
73
+
74
+ def get_efficiency(self, arch_dict):
75
+ active_net_config, image_size = self.get_active_subnet_config(arch_dict)
76
+ return self.latency_tables[image_size].predict_network_latency_given_config(
77
+ active_net_config, image_size
78
+ )
proard/nas/efficiency_predictor/latency_lookup_table.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import yaml
6
+ from proard.utils import download_url, make_divisible, MyNetwork
7
+
8
+ __all__ = [
9
+ "count_conv_flop",
10
+ "ProxylessNASLatencyTable",
11
+ "MBv3LatencyTable",
12
+ "ResNet50LatencyTable",
13
+ ]
14
+
15
+
16
+ def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
17
+ out_h = out_w = out_size
18
+ delta_ops = (
19
+ in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
20
+ )
21
+ return delta_ops
22
+
23
+
24
+ class LatencyTable(object):
25
+ def __init__(
26
+ self,
27
+ local_dir="~/.dyn/latency_tools/",
28
+ url="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/mobile_trim.yaml",
29
+ ):
30
+ if url.startswith("http"):
31
+ fname = download_url(url, local_dir, overwrite=True)
32
+ else:
33
+ fname = url
34
+ with open(fname, "r") as fp:
35
+ self.lut = yaml.load(fp)
36
+
37
+ @staticmethod
38
+ def repr_shape(shape):
39
+ if isinstance(shape, (list, tuple)):
40
+ return "x".join(str(_) for _ in shape)
41
+ elif isinstance(shape, str):
42
+ return shape
43
+ else:
44
+ return TypeError
45
+
46
+ def query(self, **kwargs):
47
+ raise NotImplementedError
48
+
49
+ def predict_network_latency(self, net, image_size):
50
+ raise NotImplementedError
51
+
52
+ def predict_network_latency_given_config(self, net_config, image_size):
53
+ raise NotImplementedError
54
+
55
+ @staticmethod
56
+ def count_flops_given_config(net_config, image_size=224):
57
+ raise NotImplementedError
58
+
59
+
60
+ class ProxylessNASLatencyTable(LatencyTable):
61
+ def query(
62
+ self,
63
+ l_type: str,
64
+ input_shape,
65
+ output_shape,
66
+ expand=None,
67
+ ks=None,
68
+ stride=None,
69
+ id_skip=None,
70
+ ):
71
+ """
72
+ :param l_type:
73
+ Layer type must be one of the followings
74
+ 1. `Conv`: The initial 3x3 conv with stride 2.
75
+ 2. `Conv_1`: feature_mix_layer
76
+ 3. `Logits`: All operations after `Conv_1`.
77
+ 4. `expanded_conv`: MobileInvertedResidual
78
+ :param input_shape: input shape (h, w, #channels)
79
+ :param output_shape: output shape (h, w, #channels)
80
+ :param expand: expansion ratio
81
+ :param ks: kernel size
82
+ :param stride:
83
+ :param id_skip: indicate whether has the residual connection
84
+ """
85
+ infos = [
86
+ l_type,
87
+ "input:%s" % self.repr_shape(input_shape),
88
+ "output:%s" % self.repr_shape(output_shape),
89
+ ]
90
+
91
+ if l_type in ("expanded_conv",):
92
+ assert None not in (expand, ks, stride, id_skip)
93
+ infos += [
94
+ "expand:%d" % expand,
95
+ "kernel:%d" % ks,
96
+ "stride:%d" % stride,
97
+ "idskip:%d" % id_skip,
98
+ ]
99
+ key = "-".join(infos)
100
+ return self.lut[key]["mean"]
101
+
102
+ def predict_network_latency(self, net, image_size=224):
103
+ predicted_latency = 0
104
+ # first conv
105
+ predicted_latency += self.query(
106
+ "Conv",
107
+ [image_size, image_size, 3],
108
+ [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels],
109
+ )
110
+ # blocks
111
+ fsize = (image_size + 1) // 2
112
+ for block in net.blocks:
113
+ mb_conv = block.conv
114
+ shortcut = block.shortcut
115
+
116
+ if mb_conv is None:
117
+ continue
118
+ if shortcut is None:
119
+ idskip = 0
120
+ else:
121
+ idskip = 1
122
+ out_fz = int((fsize - 1) / mb_conv.stride + 1) # fsize // mb_conv.stride
123
+ block_latency = self.query(
124
+ "expanded_conv",
125
+ [fsize, fsize, mb_conv.in_channels],
126
+ [out_fz, out_fz, mb_conv.out_channels],
127
+ expand=mb_conv.expand_ratio,
128
+ ks=mb_conv.kernel_size,
129
+ stride=mb_conv.stride,
130
+ id_skip=idskip,
131
+ )
132
+ predicted_latency += block_latency
133
+ fsize = out_fz
134
+ # feature mix layer
135
+ predicted_latency += self.query(
136
+ "Conv_1",
137
+ [fsize, fsize, net.feature_mix_layer.in_channels],
138
+ [fsize, fsize, net.feature_mix_layer.out_channels],
139
+ )
140
+ # classifier
141
+ predicted_latency += self.query(
142
+ "Logits",
143
+ [fsize, fsize, net.classifier.in_features],
144
+ [net.classifier.out_features], # 1000
145
+ )
146
+ return predicted_latency
147
+
148
+ def predict_network_latency_given_config(self, net_config, image_size=224):
149
+ predicted_latency = 0
150
+ # first conv
151
+ predicted_latency += self.query(
152
+ "Conv",
153
+ [image_size, image_size, 3],
154
+ [
155
+ (image_size + 1) // 2,
156
+ (image_size + 1) // 2,
157
+ net_config["first_conv"]["out_channels"],
158
+ ],
159
+ )
160
+ # blocks
161
+ fsize = (image_size + 1) // 2
162
+ for block in net_config["blocks"]:
163
+ mb_conv = (
164
+ block["mobile_inverted_conv"]
165
+ if "mobile_inverted_conv" in block
166
+ else block["conv"]
167
+ )
168
+ shortcut = block["shortcut"]
169
+
170
+ if mb_conv is None:
171
+ continue
172
+ if shortcut is None:
173
+ idskip = 0
174
+ else:
175
+ idskip = 1
176
+ out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
177
+ block_latency = self.query(
178
+ "expanded_conv",
179
+ [fsize, fsize, mb_conv["in_channels"]],
180
+ [out_fz, out_fz, mb_conv["out_channels"]],
181
+ expand=mb_conv["expand_ratio"],
182
+ ks=mb_conv["kernel_size"],
183
+ stride=mb_conv["stride"],
184
+ id_skip=idskip,
185
+ )
186
+ predicted_latency += block_latency
187
+ fsize = out_fz
188
+ # feature mix layer
189
+ predicted_latency += self.query(
190
+ "Conv_1",
191
+ [fsize, fsize, net_config["feature_mix_layer"]["in_channels"]],
192
+ [fsize, fsize, net_config["feature_mix_layer"]["out_channels"]],
193
+ )
194
+ # classifier
195
+ predicted_latency += self.query(
196
+ "Logits",
197
+ [fsize, fsize, net_config["classifier"]["in_features"]],
198
+ [net_config["classifier"]["out_features"]], # 1000
199
+ )
200
+ return predicted_latency
201
+
202
+ @staticmethod
203
+ def count_flops_given_config(net_config, image_size=224):
204
+ flops = 0
205
+ # first conv
206
+ flops += count_conv_flop(
207
+ (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1
208
+ )
209
+ # blocks
210
+ fsize = (image_size + 1) // 2
211
+ for block in net_config["blocks"]:
212
+ mb_conv = (
213
+ block["mobile_inverted_conv"]
214
+ if "mobile_inverted_conv" in block
215
+ else block["conv"]
216
+ )
217
+ if mb_conv is None:
218
+ continue
219
+ out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
220
+ if mb_conv["mid_channels"] is None:
221
+ mb_conv["mid_channels"] = round(
222
+ mb_conv["in_channels"] * mb_conv["expand_ratio"]
223
+ )
224
+ if mb_conv["expand_ratio"] != 1:
225
+ # inverted bottleneck
226
+ flops += count_conv_flop(
227
+ fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1
228
+ )
229
+ # depth conv
230
+ flops += count_conv_flop(
231
+ out_fz,
232
+ mb_conv["mid_channels"],
233
+ mb_conv["mid_channels"],
234
+ mb_conv["kernel_size"],
235
+ mb_conv["mid_channels"],
236
+ )
237
+ # point linear
238
+ flops += count_conv_flop(
239
+ out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1
240
+ )
241
+ fsize = out_fz
242
+ # feature mix layer
243
+ flops += count_conv_flop(
244
+ fsize,
245
+ net_config["feature_mix_layer"]["in_channels"],
246
+ net_config["feature_mix_layer"]["out_channels"],
247
+ 1,
248
+ 1,
249
+ )
250
+ # classifier
251
+ flops += count_conv_flop(
252
+ 1,
253
+ net_config["classifier"]["in_features"],
254
+ net_config["classifier"]["out_features"],
255
+ 1,
256
+ 1,
257
+ )
258
+ return flops / 1e6 # MFLOPs
259
+
260
+
261
+ class MBv3LatencyTable(LatencyTable):
262
+ def query(
263
+ self,
264
+ l_type: str,
265
+ input_shape,
266
+ output_shape,
267
+ mid=None,
268
+ ks=None,
269
+ stride=None,
270
+ id_skip=None,
271
+ se=None,
272
+ h_swish=None,
273
+ ):
274
+ infos = [
275
+ l_type,
276
+ "input:%s" % self.repr_shape(input_shape),
277
+ "output:%s" % self.repr_shape(output_shape),
278
+ ]
279
+
280
+ if l_type in ("expanded_conv",):
281
+ assert None not in (mid, ks, stride, id_skip, se, h_swish)
282
+ infos += [
283
+ "expand:%d" % mid,
284
+ "kernel:%d" % ks,
285
+ "stride:%d" % stride,
286
+ "idskip:%d" % id_skip,
287
+ "se:%d" % se,
288
+ "hs:%d" % h_swish,
289
+ ]
290
+ key = "-".join(infos)
291
+ return self.lut[key]["mean"]
292
+
293
+ def predict_network_latency(self, net, image_size=224):
294
+ predicted_latency = 0
295
+ # first conv
296
+ predicted_latency += self.query(
297
+ "Conv",
298
+ [image_size, image_size, 3],
299
+ [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels],
300
+ )
301
+ # blocks
302
+ fsize = (image_size + 1) // 2
303
+ for block in net.blocks:
304
+ mb_conv = block.conv
305
+ shortcut = block.shortcut
306
+
307
+ if mb_conv is None:
308
+ continue
309
+ if shortcut is None:
310
+ idskip = 0
311
+ else:
312
+ idskip = 1
313
+ out_fz = int((fsize - 1) / mb_conv.stride + 1)
314
+ block_latency = self.query(
315
+ "expanded_conv",
316
+ [fsize, fsize, mb_conv.in_channels],
317
+ [out_fz, out_fz, mb_conv.out_channels],
318
+ mid=mb_conv.depth_conv.conv.in_channels,
319
+ ks=mb_conv.kernel_size,
320
+ stride=mb_conv.stride,
321
+ id_skip=idskip,
322
+ se=1 if mb_conv.use_se else 0,
323
+ h_swish=1 if mb_conv.act_func == "h_swish" else 0,
324
+ )
325
+ predicted_latency += block_latency
326
+ fsize = out_fz
327
+ # final expand layer
328
+ predicted_latency += self.query(
329
+ "Conv_1",
330
+ [fsize, fsize, net.final_expand_layer.in_channels],
331
+ [fsize, fsize, net.final_expand_layer.out_channels],
332
+ )
333
+ # global average pooling
334
+ predicted_latency += self.query(
335
+ "AvgPool2D",
336
+ [fsize, fsize, net.final_expand_layer.out_channels],
337
+ [1, 1, net.final_expand_layer.out_channels],
338
+ )
339
+ # feature mix layer
340
+ predicted_latency += self.query(
341
+ "Conv_2",
342
+ [1, 1, net.feature_mix_layer.in_channels],
343
+ [1, 1, net.feature_mix_layer.out_channels],
344
+ )
345
+ # classifier
346
+ predicted_latency += self.query(
347
+ "Logits", [1, 1, net.classifier.in_features], [net.classifier.out_features]
348
+ )
349
+ return predicted_latency
350
+
351
+ def predict_network_latency_given_config(self, net_config, image_size=224):
352
+ predicted_latency = 0
353
+ # first conv
354
+ predicted_latency += self.query(
355
+ "Conv",
356
+ [image_size, image_size, 3],
357
+ [
358
+ (image_size + 1) // 2,
359
+ (image_size + 1) // 2,
360
+ net_config["first_conv"]["out_channels"],
361
+ ],
362
+ )
363
+ # blocks
364
+ fsize = (image_size + 1) // 2
365
+ for block in net_config["blocks"]:
366
+ mb_conv = (
367
+ block["mobile_inverted_conv"]
368
+ if "mobile_inverted_conv" in block
369
+ else block["conv"]
370
+ )
371
+ shortcut = block["shortcut"]
372
+
373
+ if mb_conv is None:
374
+ continue
375
+ if shortcut is None:
376
+ idskip = 0
377
+ else:
378
+ idskip = 1
379
+ out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
380
+ if mb_conv["mid_channels"] is None:
381
+ mb_conv["mid_channels"] = round(
382
+ mb_conv["in_channels"] * mb_conv["expand_ratio"]
383
+ )
384
+ block_latency = self.query(
385
+ "expanded_conv",
386
+ [fsize, fsize, mb_conv["in_channels"]],
387
+ [out_fz, out_fz, mb_conv["out_channels"]],
388
+ mid=mb_conv["mid_channels"],
389
+ ks=mb_conv["kernel_size"],
390
+ stride=mb_conv["stride"],
391
+ id_skip=idskip,
392
+ se=1 if mb_conv["use_se"] else 0,
393
+ h_swish=1 if mb_conv["act_func"] == "h_swish" else 0,
394
+ )
395
+ predicted_latency += block_latency
396
+ fsize = out_fz
397
+ # final expand layer
398
+ predicted_latency += self.query(
399
+ "Conv_1",
400
+ [fsize, fsize, net_config["final_expand_layer"]["in_channels"]],
401
+ [fsize, fsize, net_config["final_expand_layer"]["out_channels"]],
402
+ )
403
+ # global average pooling
404
+ predicted_latency += self.query(
405
+ "AvgPool2D",
406
+ [fsize, fsize, net_config["final_expand_layer"]["out_channels"]],
407
+ [1, 1, net_config["final_expand_layer"]["out_channels"]],
408
+ )
409
+ # feature mix layer
410
+ predicted_latency += self.query(
411
+ "Conv_2",
412
+ [1, 1, net_config["feature_mix_layer"]["in_channels"]],
413
+ [1, 1, net_config["feature_mix_layer"]["out_channels"]],
414
+ )
415
+ # classifier
416
+ predicted_latency += self.query(
417
+ "Logits",
418
+ [1, 1, net_config["classifier"]["in_features"]],
419
+ [net_config["classifier"]["out_features"]],
420
+ )
421
+ return predicted_latency
422
+
423
+ @staticmethod
424
+ def count_flops_given_config(net_config, image_size=224):
425
+ flops = 0
426
+ # first conv
427
+ flops += count_conv_flop(
428
+ (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1
429
+ )
430
+ # blocks
431
+ fsize = (image_size + 1) // 2
432
+ for block in net_config["blocks"]:
433
+ mb_conv = (
434
+ block["mobile_inverted_conv"]
435
+ if "mobile_inverted_conv" in block
436
+ else block["conv"]
437
+ )
438
+ if mb_conv is None:
439
+ continue
440
+ out_fz = int((fsize - 1) / mb_conv["stride"] + 1)
441
+ if mb_conv["mid_channels"] is None:
442
+ mb_conv["mid_channels"] = round(
443
+ mb_conv["in_channels"] * mb_conv["expand_ratio"]
444
+ )
445
+ if mb_conv["expand_ratio"] != 1:
446
+ # inverted bottleneck
447
+ flops += count_conv_flop(
448
+ fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1
449
+ )
450
+ # depth conv
451
+ flops += count_conv_flop(
452
+ out_fz,
453
+ mb_conv["mid_channels"],
454
+ mb_conv["mid_channels"],
455
+ mb_conv["kernel_size"],
456
+ mb_conv["mid_channels"],
457
+ )
458
+ if mb_conv["use_se"]:
459
+ # SE layer
460
+ se_mid = make_divisible(
461
+ mb_conv["mid_channels"] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE
462
+ )
463
+ flops += count_conv_flop(1, mb_conv["mid_channels"], se_mid, 1, 1)
464
+ flops += count_conv_flop(1, se_mid, mb_conv["mid_channels"], 1, 1)
465
+ # point linear
466
+ flops += count_conv_flop(
467
+ out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1
468
+ )
469
+ fsize = out_fz
470
+ # final expand layer
471
+ flops += count_conv_flop(
472
+ fsize,
473
+ net_config["final_expand_layer"]["in_channels"],
474
+ net_config["final_expand_layer"]["out_channels"],
475
+ 1,
476
+ 1,
477
+ )
478
+ # feature mix layer
479
+ flops += count_conv_flop(
480
+ 1,
481
+ net_config["feature_mix_layer"]["in_channels"],
482
+ net_config["feature_mix_layer"]["out_channels"],
483
+ 1,
484
+ 1,
485
+ )
486
+ # classifier
487
+ flops += count_conv_flop(
488
+ 1,
489
+ net_config["classifier"]["in_features"],
490
+ net_config["classifier"]["out_features"],
491
+ 1,
492
+ 1,
493
+ )
494
+ return flops / 1e6 # MFLOPs
495
+
496
+
497
+ class ResNet50LatencyTable(LatencyTable):
498
+ def query(self, **kwargs):
499
+ raise NotImplementedError
500
+
501
+ def predict_network_latency(self, net, image_size):
502
+ raise NotImplementedError
503
+
504
+ def predict_network_latency_given_config(self, net_config, image_size):
505
+ raise NotImplementedError
506
+
507
+ @staticmethod
508
+ def count_flops_given_config(net_config, image_size=32):
509
+ flops = 0
510
+ # input stem
511
+ for layer_config in net_config["input_stem"]:
512
+ if layer_config["name"] != "ConvLayer":
513
+ layer_config = layer_config["conv"]
514
+ in_channel = layer_config["in_channels"]
515
+ out_channel = layer_config["out_channels"]
516
+ out_image_size = int((image_size - 1) / layer_config["stride"] + 1)
517
+
518
+ flops += count_conv_flop(
519
+ out_image_size,
520
+ in_channel,
521
+ out_channel,
522
+ layer_config["kernel_size"],
523
+ layer_config.get("groups", 1),
524
+ )
525
+ image_size = out_image_size
526
+ # max pooling
527
+ # image_size = int((image_size - 1) / 2 + 1)
528
+ # ResNetBottleneckBlocks
529
+ for block_config in net_config["blocks"]:
530
+ in_channel = block_config["in_channels"]
531
+ out_channel = block_config["out_channels"]
532
+
533
+ out_image_size = int((image_size - 1) / block_config["stride"] + 1)
534
+ mid_channel = (
535
+ block_config["mid_channels"]
536
+ if block_config["mid_channels"] is not None
537
+ else round(out_channel * block_config["expand_ratio"])
538
+ )
539
+ mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)
540
+
541
+ # conv1
542
+ flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
543
+ # conv2
544
+ flops += count_conv_flop(
545
+ out_image_size,
546
+ mid_channel,
547
+ mid_channel,
548
+ block_config["kernel_size"],
549
+ block_config["groups"],
550
+ )
551
+ # conv3
552
+ flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
553
+ # downsample
554
+ if block_config["stride"] == 1 and in_channel == out_channel:
555
+ pass
556
+ else:
557
+ flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
558
+ image_size = out_image_size
559
+ # final classifier
560
+ flops += count_conv_flop(
561
+ 1,
562
+ net_config["classifier"]["in_features"],
563
+ net_config["classifier"]["out_features"],
564
+ 1,
565
+ 1,
566
+ )
567
+ return flops / 1e6 # MFLOPs
proard/nas/search_algorithm/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .evolution import *
6
+ from .multi_evolution import *
proard/nas/search_algorithm/evolution.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import copy
6
+ import random
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+ __all__ = ["EvolutionFinder"]
11
+
12
+
13
+ class EvolutionFinder:
14
+ def __init__(self, efficiency_predictor, accuracy_predictor, Robustness_predictor, **kwargs):
15
+ self.efficiency_predictor = efficiency_predictor
16
+ self.accuracy_predictor = accuracy_predictor
17
+ self.robustness_predictor = Robustness_predictor
18
+
19
+ # evolution hyper-parameters
20
+ self.arch_mutate_prob = kwargs.get("arch_mutate_prob", 0.1)
21
+ self.resolution_mutate_prob = kwargs.get("resolution_mutate_prob", 0.5)
22
+ self.population_size = kwargs.get("population_size", 100)
23
+ self.max_time_budget = kwargs.get("max_time_budget", 500)
24
+ self.parent_ratio = kwargs.get("parent_ratio", 0.25)
25
+ self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)
26
+
27
+ @property
28
+ def arch_manager(self):
29
+ return self.accuracy_predictor.arch_encoder
30
+
31
+ def update_hyper_params(self, new_param_dict):
32
+ self.__dict__.update(new_param_dict)
33
+
34
+ def random_valid_sample(self, constraint):
35
+ while True:
36
+ sample = self.arch_manager.random_sample_arch()
37
+ efficiency = self.efficiency_predictor.get_efficiency(sample)
38
+ if efficiency <= constraint:
39
+ return sample, efficiency
40
+
41
+ def mutate_sample(self, sample, constraint):
42
+ while True:
43
+ new_sample = copy.deepcopy(sample)
44
+ self.arch_manager.mutate_resolution(new_sample, self.resolution_mutate_prob)
45
+ self.arch_manager.mutate_arch(new_sample, self.arch_mutate_prob)
46
+
47
+ efficiency = self.efficiency_predictor.get_efficiency(new_sample)
48
+ if efficiency <= constraint:
49
+ return new_sample, efficiency
50
+
51
+ def crossover_sample(self, sample1, sample2, constraint):
52
+ while True:
53
+ new_sample = copy.deepcopy(sample1)
54
+ for key in new_sample.keys():
55
+ if not isinstance(new_sample[key], list):
56
+ new_sample[key] = random.choice([sample1[key], sample2[key]])
57
+ else:
58
+ for i in range(len(new_sample[key])):
59
+ new_sample[key][i] = random.choice(
60
+ [sample1[key][i], sample2[key][i]]
61
+ )
62
+
63
+ efficiency = self.efficiency_predictor.get_efficiency(new_sample)
64
+ if efficiency <= constraint:
65
+ return new_sample, efficiency
66
+
67
+ def run_evolution_search(self, constraint, verbose=False, **kwargs):
68
+ """Run a single roll-out of regularized evolution to a fixed time budget."""
69
+ self.update_hyper_params(kwargs)
70
+
71
+ mutation_numbers = int(round(self.mutation_ratio * self.population_size))
72
+ parents_size = int(round(self.parent_ratio * self.population_size))
73
+
74
+ best_valids = [-100]
75
+ population = [] # (validation, robustness, sample, latency) tuples
76
+ child_pool = []
77
+ efficiency_pool = []
78
+ best_info = None
79
+ if verbose:
80
+ print("Generate random population...")
81
+ for _ in range(self.population_size):
82
+ sample, efficiency = self.random_valid_sample(constraint)
83
+ child_pool.append(sample)
84
+ efficiency_pool.append(efficiency)
85
+
86
+ accs = self.accuracy_predictor.predict_acc(child_pool)
87
+ robs = self.robustness_predictor.predict_rob(child_pool)
88
+ for i in range(self.population_size):
89
+ population.append((accs[i].item(), robs[i].item(), child_pool[i], efficiency_pool[i]))
90
+
91
+ if verbose:
92
+ print("Start Evolution...")
93
+ # After the population is seeded, proceed with evolving the population.
94
+ with tqdm(
95
+ total=self.max_time_budget,
96
+ desc="Searching with constraint (%s)" % constraint,
97
+ disable=(not verbose),
98
+ ) as t:
99
+ for i in range(self.max_time_budget):
100
+ parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
101
+ acc = parents[0][0]
102
+ rob = parents[0][1]
103
+ t.set_postfix({"acc": parents[0][0] , "rob":parents[0][1]})
104
+ if not verbose and (i + 1) % 100 == 0:
105
+ print("Iter: {} Acc: {} Rob: {}".format(i + 1, parents[0][0],parents[0][1]))
106
+
107
+ if acc > best_valids[-1]:
108
+ best_valids.append(acc)
109
+ best_info = parents[0]
110
+ else:
111
+ best_valids.append(best_valids[-1])
112
+
113
+ population = parents
114
+ child_pool = []
115
+ efficiency_pool = []
116
+
117
+ for j in range(mutation_numbers):
118
+ par_sample = population[np.random.randint(parents_size)][2]
119
+ # Mutate
120
+ new_sample, efficiency = self.mutate_sample(par_sample, constraint)
121
+ child_pool.append(new_sample)
122
+ efficiency_pool.append(efficiency)
123
+
124
+ for j in range(self.population_size - mutation_numbers):
125
+ par_sample1 = population[np.random.randint(parents_size)][2]
126
+ par_sample2 = population[np.random.randint(parents_size)][2]
127
+ # Crossover
128
+ new_sample, efficiency = self.crossover_sample(
129
+ par_sample1, par_sample2, constraint
130
+ )
131
+ child_pool.append(new_sample)
132
+ efficiency_pool.append(efficiency)
133
+
134
+ accs = self.accuracy_predictor.predict_acc(child_pool)
135
+ robs = self.robustness_predictor.predict_rob(child_pool)
136
+ for j in range(self.population_size):
137
+ population.append(
138
+ (accs[j].item(), robs[j].item(), child_pool[j], efficiency_pool[j])
139
+ )
140
+
141
+ t.update(1)
142
+
143
+ return best_valids, best_info
proard/nas/search_algorithm/multi_evolution.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pymoo.core.individual import Individual
3
+ from pymoo.core.problem import Problem
4
+ from pymoo.core.sampling import Sampling
5
+ from pymoo.core.variable import Choice
6
+ __all__ = ["individual_to_arch_mbv","DynIndividual_mbv","DynProblem_mbv","individual_to_arch_res","DynIndividual_res","DynProblem_res","DynSampling","DynRandomSampler"]
7
+ def individual_to_arch_mbv(population, n_blocks):
8
+ archs = []
9
+ for individual in population:
10
+ archs.append(
11
+ {
12
+ "ks": individual[0:n_blocks],
13
+ "e": individual[n_blocks : 2 * n_blocks],
14
+ "d": individual[2 * n_blocks : -1],
15
+ "image_size": individual[-1:],
16
+ }
17
+ )
18
+ return archs
19
+ class DynIndividual_mbv(Individual):
20
+ def __init__(self, individual, accuracy_predictor,Robustness_predictor, config=None, **kwargs):
21
+ super().__init__(config=None, **kwargs)
22
+ self.X = np.concatenate(
23
+ (
24
+ individual[0]["ks"],
25
+ individual[0]["e"],
26
+ individual[0]["d"],
27
+ individual[0]["image_size"],
28
+ )
29
+ )
30
+ self.flops = individual[1]
31
+ self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]])
32
+ self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]])
33
+ self.F = np.concatenate(([self.flops], [self.accuracy.squeeze().cpu().detach().numpy()],[self.robustness.squeeze().cpu().detach().numpy()]))
34
+
35
+
36
+
37
+ class DynProblem_mbv(Problem):
38
+ def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars):
39
+ self.ks = Choice(options=search_vars.get('ks'))
40
+ self.e = Choice(options=search_vars.get('e'))
41
+ self.d = Choice(options=search_vars.get('d'))
42
+ self.r = Choice(options=search_vars.get('image_size'))
43
+
44
+ super().__init__(
45
+ vars= dict(zip(range(len(num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])),
46
+ n_obj=3,
47
+ n_constr=0,
48
+ )
49
+ self.efficiency_predictor = efficiency_predictor
50
+ self.accuracy_predictor = accuracy_predictor
51
+ self.robustness_predictor = robustness_predictor
52
+ self.blocks = num_blocks
53
+ self.stages = num_stages
54
+ self.search_vars = search_vars
55
+
56
+ def _evaluate(self, x, out, *args, **kwargs):
57
+ f1=[]
58
+ # x.shape = (population_size, n_var) = (100, 4)
59
+ arch = individual_to_arch_mbv(x, self.blocks)
60
+ for arc in arch:
61
+ f1.append(self.efficiency_predictor.get_efficiency(arc))
62
+ f2 = 100 - self.accuracy_predictor.predict_acc(arch).detach().cpu().numpy()
63
+ f3 = 100 - self.robustness_predictor.predict_rob(arch).detach().cpu().numpy()
64
+ out["F"] = np.column_stack([f1, f2,f3])
65
+
66
+
67
+ def individual_to_arch_res(population, n_blocks):
68
+ archs = []
69
+ for individual in population:
70
+ archs.append(
71
+ {
72
+ "e": individual[n_blocks : 2 * n_blocks],
73
+ "d": individual[2 * n_blocks : -1],
74
+ "w": individual[0:n_blocks],
75
+ "r": individual[-1:],
76
+ }
77
+ )
78
+ return archs
79
+ class DynIndividual_res(Individual):
80
+ def __init__(self, individual, accuracy_predictor,Robustness_predictor, config=None, **kwargs):
81
+ super().__init__(config=None, **kwargs)
82
+ self.X = np.concatenate(
83
+ (
84
+ individual[0]["e"],
85
+ individual[0]["d"],
86
+ individual[0]["w"],
87
+ [individual[0]["image_size"]],
88
+ )
89
+ )
90
+ self.flops = individual[1]
91
+ self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]])
92
+ self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]])
93
+ self.F = np.concatenate(([self.flops], [self.accuracy.squeeze().cpu().detach().numpy()],[self.robustness.squeeze().cpu().detach().numpy()]))
94
+
95
+
96
+
97
+ class DynProblem_res(Problem):
98
+ def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars):
99
+ self.e = Choice(options=search_vars.get('e'))
100
+ self.d = Choice(options=search_vars.get('d'))
101
+ self.w = Choice(options=search_vars.get('w'))
102
+ self.r = Choice(options=search_vars.get('image_size'))
103
+ super().__init__(
104
+ vars= dict(zip(range(len(num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])),
105
+ n_obj=3,
106
+ n_constr=0,
107
+ )
108
+ self.efficiency_predictor = efficiency_predictor
109
+ self.accuracy_predictor = accuracy_predictor
110
+ self.robustness_predictor = robustness_predictor
111
+ self.blocks = num_blocks
112
+ self.stages = num_stages
113
+ self.search_vars = search_vars
114
+
115
+ def _evaluate(self, x, out, *args, **kwargs):
116
+ f1={}
117
+ # x.shape = (population_size, n_var) = (100, 4)
118
+ arch = individual_to_arch_res(x, self.blocks)
119
+ for arc in arch:
120
+ f1.append(self.efficiency_predictor.get_efficiency(arc))
121
+ f2 = 100 - self.accuracy_predictor.predict_acc(arch)
122
+ f3 = 100 - self.robustness_predictor.predict_rob(arch)
123
+ out["F"] = np.column_stack([f1, f2,f3])
124
+
125
+
126
+
127
+ class DynSampling(Sampling):
128
+ def _do(self, problem, n_samples, **kwargs):
129
+ return [
130
+ [np.random.choice(var.options) for key,var in problem.vars.items()]
131
+ for _ in range(n_samples)
132
+ ]
133
+
134
+
135
+ class DynRandomSampler:
136
+ def __init__(self, arch_manager, efficiency_predictor):
137
+ self.arch_manager = arch_manager
138
+ self.efficiency_predictor = efficiency_predictor
139
+
140
+ def random_sample(self):
141
+ sample = self.arch_manager.random_sample_arch()
142
+ efficiency = self.efficiency_predictor.get_efficiency(sample)
143
+ return sample, efficiency
proard/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ from .pytorch_modules import *
6
+ from .pytorch_utils import *
7
+ from .my_modules import *
8
+ from .flops_counter import *
9
+ from .common_tools import *
10
+ from .my_dataloader import *
proard/utils/common_tools.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ import torch
9
+
10
+ try:
11
+ from urllib import urlretrieve
12
+ except ImportError:
13
+ from urllib.request import urlretrieve
14
+
15
+ __all__ = [
16
+ "sort_dict",
17
+ "get_same_padding",
18
+ "get_split_list",
19
+ "list_sum",
20
+ "list_mean",
21
+ "list_join",
22
+ "subset_mean",
23
+ "sub_filter_start_end",
24
+ "min_divisible_value",
25
+ "val2list",
26
+ "download_url",
27
+ "write_log",
28
+ "pairwise_accuracy",
29
+ "accuracy",
30
+ "AverageMeter",
31
+ "MultiClassAverageMeter",
32
+ "DistributedMetric",
33
+ "DistributedTensor",
34
+ ]
35
+
36
+
37
+ def sort_dict(src_dict, reverse=False, return_dict=True):
38
+ output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
39
+ if return_dict:
40
+ return dict(output)
41
+ else:
42
+ return output
43
+
44
+
45
+ def get_same_padding(kernel_size):
46
+ if isinstance(kernel_size, tuple):
47
+ assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size
48
+ p1 = get_same_padding(kernel_size[0])
49
+ p2 = get_same_padding(kernel_size[1])
50
+ return p1, p2
51
+ assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
52
+ assert kernel_size % 2 > 0, "kernel size should be odd number"
53
+ return kernel_size // 2
54
+
55
+
56
+ def get_split_list(in_dim, child_num, accumulate=False):
57
+ in_dim_list = [in_dim // child_num] * child_num
58
+ for _i in range(in_dim % child_num):
59
+ in_dim_list[_i] += 1
60
+ if accumulate:
61
+ for i in range(1, child_num):
62
+ in_dim_list[i] += in_dim_list[i - 1]
63
+ return in_dim_list
64
+
65
+
66
+ def list_sum(x):
67
+ return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
68
+
69
+
70
+ def list_mean(x):
71
+ return list_sum(x) / len(x)
72
+
73
+
74
+ def list_join(val_list, sep="\t"):
75
+ return sep.join([str(val) for val in val_list])
76
+
77
+
78
+ def subset_mean(val_list, sub_indexes):
79
+ sub_indexes = val2list(sub_indexes, 1)
80
+ return list_mean([val_list[idx] for idx in sub_indexes])
81
+
82
+
83
+ def sub_filter_start_end(kernel_size, sub_kernel_size):
84
+ center = kernel_size // 2
85
+ dev = sub_kernel_size // 2
86
+ start, end = center - dev, center + dev + 1
87
+ assert end - start == sub_kernel_size
88
+ return start, end
89
+
90
+
91
+ def min_divisible_value(n1, v1):
92
+ """make sure v1 is divisible by n1, otherwise decrease v1"""
93
+ if v1 >= n1:
94
+ return n1
95
+ while n1 % v1 != 0:
96
+ v1 -= 1
97
+ return v1
98
+
99
+
100
+ def val2list(val, repeat_time=1):
101
+ if isinstance(val, list) or isinstance(val, np.ndarray):
102
+ return val
103
+ elif isinstance(val, tuple):
104
+ return list(val)
105
+ else:
106
+ return [val for _ in range(repeat_time)]
107
+
108
+
109
+ def download_url(url, model_dir="~/.torch/", overwrite=False):
110
+ target_dir = url.split("/")[-1]
111
+ model_dir = os.path.expanduser(model_dir)
112
+ try:
113
+ if not os.path.exists(model_dir):
114
+ os.makedirs(model_dir)
115
+ model_dir = os.path.join(model_dir, target_dir)
116
+ cached_file = model_dir
117
+ if not os.path.exists(cached_file) or overwrite:
118
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
119
+ urlretrieve(url, cached_file)
120
+ return cached_file
121
+ except Exception as e:
122
+ # remove lock file so download can be executed next time.
123
+ os.remove(os.path.join(model_dir, "download.lock"))
124
+ sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n")
125
+ return None
126
+
127
+
128
+ def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"):
129
+ if not os.path.exists(logs_path):
130
+ os.makedirs(logs_path, exist_ok=True)
131
+ """ prefix: valid, train, test """
132
+ if prefix in ["valid", "test"]:
133
+ with open(os.path.join(logs_path, "valid_console.txt"), mode) as fout:
134
+ fout.write(log_str + "\n")
135
+ fout.flush()
136
+ if prefix in ["valid", "test", "train"]:
137
+ with open(os.path.join(logs_path, "train_console.txt"), mode) as fout:
138
+ if prefix in ["valid", "test"]:
139
+ fout.write("=" * 10)
140
+ fout.write(log_str + "\n")
141
+ fout.flush()
142
+ else:
143
+ with open(os.path.join(logs_path, "%s.txt" % prefix), mode) as fout:
144
+ fout.write(log_str + "\n")
145
+ fout.flush()
146
+ if should_print:
147
+ print(log_str)
148
+
149
+
150
+ def pairwise_accuracy(la, lb, n_samples=200000):
151
+ n = len(la)
152
+ assert n == len(lb)
153
+ total = 0
154
+ count = 0
155
+ for _ in range(n_samples):
156
+ i = np.random.randint(n)
157
+ j = np.random.randint(n)
158
+ while i == j:
159
+ j = np.random.randint(n)
160
+ if la[i] >= la[j] and lb[i] >= lb[j]:
161
+ count += 1
162
+ if la[i] < la[j] and lb[i] < lb[j]:
163
+ count += 1
164
+ total += 1
165
+ return float(count) / total
166
+
167
+
168
+
169
+
170
+ def accuracy(output, target, topk=(1,)):
171
+ """Computes the precision@k for the specified values of k"""
172
+ maxk = max(topk)
173
+ batch_size = target.size(0)
174
+
175
+ _, pred = output.topk(maxk, 1, True, True)
176
+ pred = pred.t()
177
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
178
+
179
+ res = []
180
+ for k in topk:
181
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
182
+ res.append(correct_k.mul_(100.0 / batch_size))
183
+ return res
184
+
185
+
186
+ class AverageMeter(object):
187
+ """
188
+ Computes and stores the average and current value
189
+ Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
190
+ """
191
+
192
+ def __init__(self):
193
+ self.val = 0
194
+ self.avg = 0
195
+ self.sum = 0
196
+ self.count = 0
197
+
198
+ def reset(self):
199
+ self.val = 0
200
+ self.avg = 0
201
+ self.sum = 0
202
+ self.count = 0
203
+
204
+ def update(self, val, n=1):
205
+ self.val = val
206
+ self.sum += val * n
207
+ self.count += n
208
+ self.avg = self.sum / self.count
209
+
210
+
211
+ class MultiClassAverageMeter:
212
+
213
+ """Multi Binary Classification Tasks"""
214
+
215
+ def __init__(self, num_classes, balanced=False, **kwargs):
216
+
217
+ super(MultiClassAverageMeter, self).__init__()
218
+ self.num_classes = num_classes
219
+ self.balanced = balanced
220
+
221
+ self.counts = []
222
+ for k in range(self.num_classes):
223
+ self.counts.append(np.ndarray((2, 2), dtype=np.float32))
224
+
225
+ self.reset()
226
+
227
+ def reset(self):
228
+ for k in range(self.num_classes):
229
+ self.counts[k].fill(0)
230
+
231
+ def add(self, outputs, targets):
232
+ outputs = outputs.data.cpu().numpy()
233
+ targets = targets.data.cpu().numpy()
234
+
235
+ for k in range(self.num_classes):
236
+ output = np.argmax(outputs[:, k, :], axis=1)
237
+ target = targets[:, k]
238
+
239
+ x = output + 2 * target
240
+ bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)
241
+
242
+ self.counts[k] += bincount.reshape((2, 2))
243
+
244
+ def value(self):
245
+ mean = 0
246
+ for k in range(self.num_classes):
247
+ if self.balanced:
248
+ value = np.mean(
249
+ (
250
+ self.counts[k]
251
+ / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]
252
+ ).diagonal()
253
+ )
254
+ else:
255
+ value = np.sum(self.counts[k].diagonal()) / np.maximum(
256
+ np.sum(self.counts[k]), 1
257
+ )
258
+
259
+ mean += value / self.num_classes * 100.0
260
+ return mean
261
+
262
+
263
+ class DistributedMetric(object):
264
+ """
265
+ Horovod: average metrics from distributed training.
266
+ """
267
+
268
+ def __init__(self, name):
269
+ self.name = name
270
+ self.sum = torch.zeros(1)[0]
271
+ self.count = torch.zeros(1)[0]
272
+
273
+ def update(self, val, delta_n=1):
274
+ import horovod.torch as hvd
275
+
276
+ val *= delta_n
277
+ self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
278
+ self.count += delta_n
279
+
280
+ @property
281
+ def avg(self):
282
+ return self.sum / self.count
283
+
284
+
285
+ class DistributedTensor(object):
286
+ def __init__(self, name):
287
+ self.name = name
288
+ self.sum = None
289
+ self.count = torch.zeros(1)[0]
290
+ self.synced = False
291
+
292
+ def update(self, val, delta_n=1):
293
+ val *= delta_n
294
+ if self.sum is None:
295
+ self.sum = val.detach()
296
+ else:
297
+ self.sum += val.detach()
298
+ self.count += delta_n
299
+
300
+ @property
301
+ def avg(self):
302
+ import horovod.torch as hvd
303
+
304
+ if not self.synced:
305
+ self.sum = hvd.allreduce(self.sum, name=self.name)
306
+ self.synced = True
307
+ return self.sum / self.count
proard/utils/flops_counter.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .my_modules import MyConv2d
9
+
10
+ __all__ = ["profile"]
11
+
12
+
13
+ def count_convNd(m, _, y):
14
+ cin = m.in_channels
15
+
16
+ kernel_ops = m.weight.size()[2] * m.weight.size()[3]
17
+ ops_per_element = kernel_ops
18
+ output_elements = y.nelement()
19
+
20
+ # cout x oW x oH
21
+ total_ops = cin * output_elements * ops_per_element // m.groups
22
+ m.total_ops = torch.zeros(1).fill_(total_ops)
23
+
24
+
25
+ def count_linear(m, _, __):
26
+ total_ops = m.in_features * m.out_features
27
+
28
+ m.total_ops = torch.zeros(1).fill_(total_ops)
29
+
30
+
31
+ register_hooks = {
32
+ nn.Conv1d: count_convNd,
33
+ nn.Conv2d: count_convNd,
34
+ nn.Conv3d: count_convNd,
35
+ MyConv2d: count_convNd,
36
+ ######################################
37
+ nn.Linear: count_linear,
38
+ ######################################
39
+ nn.Dropout: None,
40
+ nn.Dropout2d: None,
41
+ nn.Dropout3d: None,
42
+ nn.BatchNorm2d: None,
43
+ }
44
+
45
+
46
+ def profile(model, input_size, custom_ops=None):
47
+ handler_collection = []
48
+ custom_ops = {} if custom_ops is None else custom_ops
49
+
50
+ def add_hooks(m_):
51
+ if len(list(m_.children())) > 0:
52
+ return
53
+
54
+ m_.register_buffer("total_ops", torch.zeros(1))
55
+ m_.register_buffer("total_params", torch.zeros(1))
56
+
57
+ for p in m_.parameters():
58
+ m_.total_params += torch.zeros(1).fill_(p.numel())
59
+
60
+ m_type = type(m_)
61
+ fn = None
62
+
63
+ if m_type in custom_ops:
64
+ fn = custom_ops[m_type]
65
+ elif m_type in register_hooks:
66
+ fn = register_hooks[m_type]
67
+
68
+ if fn is not None:
69
+ _handler = m_.register_forward_hook(fn)
70
+ handler_collection.append(_handler)
71
+
72
+ original_device = model.parameters().__next__().device
73
+ training = model.training
74
+
75
+ model.eval()
76
+ model.apply(add_hooks)
77
+
78
+ x = torch.zeros(input_size).to(original_device)
79
+ with torch.no_grad():
80
+ model(x)
81
+
82
+ total_ops = 0
83
+ total_params = 0
84
+ for m in model.modules():
85
+ if len(list(m.children())) > 0: # skip for non-leaf module
86
+ continue
87
+ total_ops += m.total_ops
88
+ total_params += m.total_params
89
+
90
+ total_ops = total_ops.item()
91
+ total_params = total_params.item()
92
+
93
+ model.train(training).to(original_device)
94
+ for handler in handler_collection:
95
+ handler.remove()
96
+
97
+ return total_ops, total_params
proard/utils/layers.py ADDED
@@ -0,0 +1,819 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from collections import OrderedDict
9
+ from proard.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer
10
+ from proard.utils import MyNetwork, MyModule
11
+ from proard.utils import build_activation, make_divisible
12
+
13
+ __all__ = [
14
+ "set_layer_from_config",
15
+ "ConvLayer",
16
+ "IdentityLayer",
17
+ "LinearLayer",
18
+ "MultiHeadLinearLayer",
19
+ "ZeroLayer",
20
+ "MBConvLayer",
21
+ "ResidualBlock",
22
+ "ResNetBottleneckBlock",
23
+ ]
24
+
25
+
26
+ def set_layer_from_config(layer_config):
27
+ if layer_config is None:
28
+ return None
29
+
30
+ name2layer = {
31
+ ConvLayer.__name__: ConvLayer,
32
+ IdentityLayer.__name__: IdentityLayer,
33
+ LinearLayer.__name__: LinearLayer,
34
+ MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
35
+ ZeroLayer.__name__: ZeroLayer,
36
+ MBConvLayer.__name__: MBConvLayer,
37
+ "MBInvertedConvLayer": MBConvLayer,
38
+ ##########################################################
39
+ ResidualBlock.__name__: ResidualBlock,
40
+ ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
41
+ }
42
+
43
+ layer_name = layer_config.pop("name")
44
+ layer = name2layer[layer_name]
45
+ return layer.build_from_config(layer_config)
46
+
47
+
48
+ class My2DLayer(MyModule):
49
+ def __init__(
50
+ self,
51
+ in_channels,
52
+ out_channels,
53
+ use_bn=True,
54
+ act_func="relu",
55
+ dropout_rate=0,
56
+ ops_order="weight_bn_act",
57
+ ):
58
+ super(My2DLayer, self).__init__()
59
+ self.in_channels = in_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.use_bn = use_bn
63
+ self.act_func = act_func
64
+ self.dropout_rate = dropout_rate
65
+ self.ops_order = ops_order
66
+
67
+ """ modules """
68
+ modules = {}
69
+ # batch norm
70
+ if self.use_bn:
71
+ if self.bn_before_weight:
72
+ modules["bn"] = nn.BatchNorm2d(in_channels)
73
+ else:
74
+ modules["bn"] = nn.BatchNorm2d(out_channels)
75
+ else:
76
+ modules["bn"] = None
77
+ # activation
78
+ modules["act"] = build_activation(
79
+ self.act_func, self.ops_list[0] != "act" and self.use_bn
80
+ )
81
+ # dropout
82
+ if self.dropout_rate > 0:
83
+ modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)
84
+ else:
85
+ modules["dropout"] = None
86
+ # weight
87
+ modules["weight"] = self.weight_op()
88
+
89
+ # add modules
90
+ for op in self.ops_list:
91
+ if modules[op] is None:
92
+ continue
93
+ elif op == "weight":
94
+ # dropout before weight operation
95
+ if modules["dropout"] is not None:
96
+ self.add_module("dropout", modules["dropout"])
97
+ for key in modules["weight"]:
98
+ self.add_module(key, modules["weight"][key])
99
+ else:
100
+ self.add_module(op, modules[op])
101
+
102
+ @property
103
+ def ops_list(self):
104
+ return self.ops_order.split("_")
105
+
106
+ @property
107
+ def bn_before_weight(self):
108
+ for op in self.ops_list:
109
+ if op == "bn":
110
+ return True
111
+ elif op == "weight":
112
+ return False
113
+ raise ValueError("Invalid ops_order: %s" % self.ops_order)
114
+
115
+ def weight_op(self):
116
+ raise NotImplementedError
117
+
118
+ """ Methods defined in MyModule """
119
+
120
+ def forward(self, x):
121
+ # similar to nn.Sequential
122
+ for module in self._modules.values():
123
+ x = module(x)
124
+ return x
125
+
126
+ @property
127
+ def module_str(self):
128
+ raise NotImplementedError
129
+
130
+ @property
131
+ def config(self):
132
+ return {
133
+ "in_channels": self.in_channels,
134
+ "out_channels": self.out_channels,
135
+ "use_bn": self.use_bn,
136
+ "act_func": self.act_func,
137
+ "dropout_rate": self.dropout_rate,
138
+ "ops_order": self.ops_order,
139
+ }
140
+
141
+ @staticmethod
142
+ def build_from_config(config):
143
+ raise NotImplementedError
144
+
145
+
146
+ class ConvLayer(My2DLayer):
147
+ def __init__(
148
+ self,
149
+ in_channels,
150
+ out_channels,
151
+ kernel_size=3,
152
+ stride=1,
153
+ dilation=1,
154
+ groups=1,
155
+ bias=False,
156
+ has_shuffle=False,
157
+ use_se=False,
158
+ use_bn=True,
159
+ act_func="relu",
160
+ dropout_rate=0,
161
+ ops_order="weight_bn_act",
162
+ ):
163
+ # default normal 3x3_Conv with bn and relu
164
+ self.kernel_size = kernel_size
165
+ self.stride = stride
166
+ self.dilation = dilation
167
+ self.groups = groups
168
+ self.bias = bias
169
+ self.has_shuffle = has_shuffle
170
+ self.use_se = use_se
171
+
172
+ super(ConvLayer, self).__init__(
173
+ in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
174
+ )
175
+ if self.use_se:
176
+ self.add_module("se", SEModule(self.out_channels))
177
+
178
+ def weight_op(self):
179
+ padding = get_same_padding(self.kernel_size)
180
+ if isinstance(padding, int):
181
+ padding *= self.dilation
182
+ else:
183
+ padding[0] *= self.dilation
184
+ padding[1] *= self.dilation
185
+
186
+ weight_dict = OrderedDict(
187
+ {
188
+ "conv": nn.Conv2d(
189
+ self.in_channels,
190
+ self.out_channels,
191
+ kernel_size=self.kernel_size,
192
+ stride=self.stride,
193
+ padding=padding,
194
+ dilation=self.dilation,
195
+ groups=min_divisible_value(self.in_channels, self.groups),
196
+ bias=self.bias,
197
+ )
198
+ }
199
+ )
200
+ if self.has_shuffle and self.groups > 1:
201
+ weight_dict["shuffle"] = ShuffleLayer(self.groups)
202
+
203
+ return weight_dict
204
+
205
+ @property
206
+ def module_str(self):
207
+ if isinstance(self.kernel_size, int):
208
+ kernel_size = (self.kernel_size, self.kernel_size)
209
+ else:
210
+ kernel_size = self.kernel_size
211
+ if self.groups == 1:
212
+ if self.dilation > 1:
213
+ conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1])
214
+ else:
215
+ conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1])
216
+ else:
217
+ if self.dilation > 1:
218
+ conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1])
219
+ else:
220
+ conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1])
221
+ conv_str += "_O%d" % self.out_channels
222
+ if self.use_se:
223
+ conv_str = "SE_" + conv_str
224
+ conv_str += "_" + self.act_func.upper()
225
+ if self.use_bn:
226
+ if isinstance(self.bn, nn.GroupNorm):
227
+ conv_str += "_GN%d" % self.bn.num_groups
228
+ elif isinstance(self.bn, nn.BatchNorm2d):
229
+ conv_str += "_BN"
230
+ return conv_str
231
+
232
+ @property
233
+ def config(self):
234
+ return {
235
+ "name": ConvLayer.__name__,
236
+ "kernel_size": self.kernel_size,
237
+ "stride": self.stride,
238
+ "dilation": self.dilation,
239
+ "groups": self.groups,
240
+ "bias": self.bias,
241
+ "has_shuffle": self.has_shuffle,
242
+ "use_se": self.use_se,
243
+ **super(ConvLayer, self).config,
244
+ }
245
+
246
+ @staticmethod
247
+ def build_from_config(config):
248
+ return ConvLayer(**config)
249
+
250
+
251
+ class IdentityLayer(My2DLayer):
252
+ def __init__(
253
+ self,
254
+ in_channels,
255
+ out_channels,
256
+ use_bn=False,
257
+ act_func=None,
258
+ dropout_rate=0,
259
+ ops_order="weight_bn_act",
260
+ ):
261
+ super(IdentityLayer, self).__init__(
262
+ in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
263
+ )
264
+
265
+ def weight_op(self):
266
+ return None
267
+
268
+ @property
269
+ def module_str(self):
270
+ return "Identity"
271
+
272
+ @property
273
+ def config(self):
274
+ return {
275
+ "name": IdentityLayer.__name__,
276
+ **super(IdentityLayer, self).config,
277
+ }
278
+
279
+ @staticmethod
280
+ def build_from_config(config):
281
+ return IdentityLayer(**config)
282
+
283
+
284
+ class LinearLayer(MyModule):
285
+ def __init__(
286
+ self,
287
+ in_features,
288
+ out_features,
289
+ bias=True,
290
+ use_bn=False,
291
+ act_func=None,
292
+ dropout_rate=0,
293
+ ops_order="weight_bn_act",
294
+ ):
295
+ super(LinearLayer, self).__init__()
296
+
297
+ self.in_features = in_features
298
+ self.out_features = out_features
299
+ self.bias = bias
300
+
301
+ self.use_bn = use_bn
302
+ self.act_func = act_func
303
+ self.dropout_rate = dropout_rate
304
+ self.ops_order = ops_order
305
+
306
+ """ modules """
307
+ modules = {}
308
+ # batch norm
309
+ if self.use_bn:
310
+ if self.bn_before_weight:
311
+ modules["bn"] = nn.BatchNorm1d(in_features)
312
+ else:
313
+ modules["bn"] = nn.BatchNorm1d(out_features)
314
+ else:
315
+ modules["bn"] = None
316
+ # activation
317
+ modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act")
318
+ # dropout
319
+ if self.dropout_rate > 0:
320
+ modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True)
321
+ else:
322
+ modules["dropout"] = None
323
+ # linear
324
+ modules["weight"] = {
325
+ "linear": nn.Linear(self.in_features, self.out_features, self.bias)
326
+ }
327
+
328
+ # add modules
329
+ for op in self.ops_list:
330
+ if modules[op] is None:
331
+ continue
332
+ elif op == "weight":
333
+ if modules["dropout"] is not None:
334
+ self.add_module("dropout", modules["dropout"])
335
+ for key in modules["weight"]:
336
+ self.add_module(key, modules["weight"][key])
337
+ else:
338
+ self.add_module(op, modules[op])
339
+
340
+ @property
341
+ def ops_list(self):
342
+ return self.ops_order.split("_")
343
+
344
+ @property
345
+ def bn_before_weight(self):
346
+ for op in self.ops_list:
347
+ if op == "bn":
348
+ return True
349
+ elif op == "weight":
350
+ return False
351
+ raise ValueError("Invalid ops_order: %s" % self.ops_order)
352
+
353
+ def forward(self, x):
354
+ for module in self._modules.values():
355
+ x = module(x)
356
+ return x
357
+
358
+ @property
359
+ def module_str(self):
360
+ return "%dx%d_Linear" % (self.in_features, self.out_features)
361
+
362
+ @property
363
+ def config(self):
364
+ return {
365
+ "name": LinearLayer.__name__,
366
+ "in_features": self.in_features,
367
+ "out_features": self.out_features,
368
+ "bias": self.bias,
369
+ "use_bn": self.use_bn,
370
+ "act_func": self.act_func,
371
+ "dropout_rate": self.dropout_rate,
372
+ "ops_order": self.ops_order,
373
+ }
374
+
375
+ @staticmethod
376
+ def build_from_config(config):
377
+ return LinearLayer(**config)
378
+
379
+
380
+ class MultiHeadLinearLayer(MyModule):
381
+ def __init__(
382
+ self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0
383
+ ):
384
+ super(MultiHeadLinearLayer, self).__init__()
385
+ self.in_features = in_features
386
+ self.out_features = out_features
387
+ self.num_heads = num_heads
388
+
389
+ self.bias = bias
390
+ self.dropout_rate = dropout_rate
391
+
392
+ if self.dropout_rate > 0:
393
+ self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
394
+ else:
395
+ self.dropout = None
396
+
397
+ self.layers = nn.ModuleList()
398
+ for k in range(num_heads):
399
+ layer = nn.Linear(in_features, out_features, self.bias)
400
+ self.layers.append(layer)
401
+
402
+ def forward(self, inputs):
403
+ if self.dropout is not None:
404
+ inputs = self.dropout(inputs)
405
+
406
+ outputs = []
407
+ for layer in self.layers:
408
+ output = layer.forward(inputs)
409
+ outputs.append(output)
410
+
411
+ outputs = torch.stack(outputs, dim=1)
412
+ return outputs
413
+
414
+ @property
415
+ def module_str(self):
416
+ return self.__repr__()
417
+
418
+ @property
419
+ def config(self):
420
+ return {
421
+ "name": MultiHeadLinearLayer.__name__,
422
+ "in_features": self.in_features,
423
+ "out_features": self.out_features,
424
+ "num_heads": self.num_heads,
425
+ "bias": self.bias,
426
+ "dropout_rate": self.dropout_rate,
427
+ }
428
+
429
+ @staticmethod
430
+ def build_from_config(config):
431
+ return MultiHeadLinearLayer(**config)
432
+
433
+ def __repr__(self):
434
+ return (
435
+ "MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)"
436
+ % (
437
+ self.in_features,
438
+ self.out_features,
439
+ self.num_heads,
440
+ self.bias,
441
+ self.dropout_rate,
442
+ )
443
+ )
444
+
445
+
446
+ class ZeroLayer(MyModule):
447
+ def __init__(self):
448
+ super(ZeroLayer, self).__init__()
449
+
450
+ def forward(self, x):
451
+ raise ValueError
452
+
453
+ @property
454
+ def module_str(self):
455
+ return "Zero"
456
+
457
+ @property
458
+ def config(self):
459
+ return {
460
+ "name": ZeroLayer.__name__,
461
+ }
462
+
463
+ @staticmethod
464
+ def build_from_config(config):
465
+ return ZeroLayer()
466
+
467
+
468
+ class MBConvLayer(MyModule):
469
+ def __init__(
470
+ self,
471
+ in_channels,
472
+ out_channels,
473
+ kernel_size=3,
474
+ stride=1,
475
+ expand_ratio=6,
476
+ mid_channels=None,
477
+ act_func="relu6",
478
+ use_se=False,
479
+ groups=None,
480
+ ):
481
+ super(MBConvLayer, self).__init__()
482
+
483
+ self.in_channels = in_channels
484
+ self.out_channels = out_channels
485
+
486
+ self.kernel_size = kernel_size
487
+ self.stride = stride
488
+ self.expand_ratio = expand_ratio
489
+ self.mid_channels = mid_channels
490
+ self.act_func = act_func
491
+ self.use_se = use_se
492
+ self.groups = groups
493
+
494
+ if self.mid_channels is None:
495
+ feature_dim = round(self.in_channels * self.expand_ratio)
496
+ else:
497
+ feature_dim = self.mid_channels
498
+
499
+ if self.expand_ratio == 1:
500
+ self.inverted_bottleneck = None
501
+ else:
502
+ self.inverted_bottleneck = nn.Sequential(
503
+ OrderedDict(
504
+ [
505
+ (
506
+ "conv",
507
+ nn.Conv2d(
508
+ self.in_channels, feature_dim, 1, 1, 0, bias=False
509
+ ),
510
+ ),
511
+ ("bn", nn.BatchNorm2d(feature_dim)),
512
+ ("act", build_activation(self.act_func, inplace=True)),
513
+ ]
514
+ )
515
+ )
516
+
517
+ pad = get_same_padding(self.kernel_size)
518
+ groups = (
519
+ feature_dim
520
+ if self.groups is None
521
+ else min_divisible_value(feature_dim, self.groups)
522
+ )
523
+ depth_conv_modules = [
524
+ (
525
+ "conv",
526
+ nn.Conv2d(
527
+ feature_dim,
528
+ feature_dim,
529
+ kernel_size,
530
+ stride,
531
+ pad,
532
+ groups=groups,
533
+ bias=False,
534
+ ),
535
+ ),
536
+ ("bn", nn.BatchNorm2d(feature_dim)),
537
+ ("act", build_activation(self.act_func, inplace=True)),
538
+ ]
539
+ if self.use_se:
540
+ depth_conv_modules.append(("se", SEModule(feature_dim)))
541
+ self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))
542
+
543
+ self.point_linear = nn.Sequential(
544
+ OrderedDict(
545
+ [
546
+ ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
547
+ ("bn", nn.BatchNorm2d(out_channels)),
548
+ ]
549
+ )
550
+ )
551
+
552
+ def forward(self, x):
553
+ if self.inverted_bottleneck:
554
+ x = self.inverted_bottleneck(x)
555
+ x = self.depth_conv(x)
556
+ x = self.point_linear(x)
557
+ return x
558
+
559
+ @property
560
+ def module_str(self):
561
+ if self.mid_channels is None:
562
+ expand_ratio = self.expand_ratio
563
+ else:
564
+ expand_ratio = self.mid_channels // self.in_channels
565
+ layer_str = "%dx%d_MBConv%d_%s" % (
566
+ self.kernel_size,
567
+ self.kernel_size,
568
+ expand_ratio,
569
+ self.act_func.upper(),
570
+ )
571
+ if self.use_se:
572
+ layer_str = "SE_" + layer_str
573
+ layer_str += "_O%d" % self.out_channels
574
+ if self.groups is not None:
575
+ layer_str += "_G%d" % self.groups
576
+ if isinstance(self.point_linear.bn, nn.GroupNorm):
577
+ layer_str += "_GN%d" % self.point_linear.bn.num_groups
578
+ elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
579
+ layer_str += "_BN"
580
+
581
+ return layer_str
582
+
583
+ @property
584
+ def config(self):
585
+ return {
586
+ "name": MBConvLayer.__name__,
587
+ "in_channels": self.in_channels,
588
+ "out_channels": self.out_channels,
589
+ "kernel_size": self.kernel_size,
590
+ "stride": self.stride,
591
+ "expand_ratio": self.expand_ratio,
592
+ "mid_channels": self.mid_channels,
593
+ "act_func": self.act_func,
594
+ "use_se": self.use_se,
595
+ "groups": self.groups,
596
+ }
597
+
598
+ @staticmethod
599
+ def build_from_config(config):
600
+ return MBConvLayer(**config)
601
+
602
+
603
+ class ResidualBlock(MyModule):
604
+ def __init__(self, conv, shortcut):
605
+ super(ResidualBlock, self).__init__()
606
+
607
+ self.conv = conv
608
+ self.shortcut = shortcut
609
+
610
+ def forward(self, x):
611
+ if self.conv is None or isinstance(self.conv, ZeroLayer):
612
+ res = x
613
+ elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
614
+ res = self.conv(x)
615
+ else:
616
+ res = self.conv(x) + self.shortcut(x)
617
+ return res
618
+
619
+ @property
620
+ def module_str(self):
621
+ return "(%s, %s)" % (
622
+ self.conv.module_str if self.conv is not None else None,
623
+ self.shortcut.module_str if self.shortcut is not None else None,
624
+ )
625
+
626
+ @property
627
+ def config(self):
628
+ return {
629
+ "name": ResidualBlock.__name__,
630
+ "conv": self.conv.config if self.conv is not None else None,
631
+ "shortcut": self.shortcut.config if self.shortcut is not None else None,
632
+ }
633
+
634
+ @staticmethod
635
+ def build_from_config(config):
636
+ conv_config = (
637
+ config["conv"] if "conv" in config else config["mobile_inverted_conv"]
638
+ )
639
+ conv = set_layer_from_config(conv_config)
640
+ shortcut = set_layer_from_config(config["shortcut"])
641
+ return ResidualBlock(conv, shortcut)
642
+
643
+ @property
644
+ def mobile_inverted_conv(self):
645
+ return self.conv
646
+
647
+
648
+ class ResNetBottleneckBlock(MyModule):
649
+ def __init__(
650
+ self,
651
+ in_channels,
652
+ out_channels,
653
+ kernel_size=3,
654
+ stride=1,
655
+ expand_ratio=0.25,
656
+ mid_channels=None,
657
+ act_func="relu",
658
+ groups=1,
659
+ downsample_mode="avgpool_conv",
660
+ ):
661
+ super(ResNetBottleneckBlock, self).__init__()
662
+
663
+ self.in_channels = in_channels
664
+ self.out_channels = out_channels
665
+
666
+ self.kernel_size = kernel_size
667
+ self.stride = stride
668
+ self.expand_ratio = expand_ratio
669
+ self.mid_channels = mid_channels
670
+ self.act_func = act_func
671
+ self.groups = groups
672
+
673
+ self.downsample_mode = downsample_mode
674
+
675
+ if self.mid_channels is None:
676
+ feature_dim = round(self.out_channels * self.expand_ratio)
677
+ else:
678
+ feature_dim = self.mid_channels
679
+
680
+ feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
681
+ self.mid_channels = feature_dim
682
+
683
+ # build modules
684
+ self.conv1 = nn.Sequential(
685
+ OrderedDict(
686
+ [
687
+ (
688
+ "conv",
689
+ nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False),
690
+ ),
691
+ ("bn", nn.BatchNorm2d(feature_dim)),
692
+ ("act", build_activation(self.act_func, inplace=True)),
693
+ ]
694
+ )
695
+ )
696
+
697
+ pad = get_same_padding(self.kernel_size)
698
+ self.conv2 = nn.Sequential(
699
+ OrderedDict(
700
+ [
701
+ (
702
+ "conv",
703
+ nn.Conv2d(
704
+ feature_dim,
705
+ feature_dim,
706
+ kernel_size,
707
+ stride,
708
+ pad,
709
+ groups=groups,
710
+ bias=False,
711
+ ),
712
+ ),
713
+ ("bn", nn.BatchNorm2d(feature_dim)),
714
+ ("act", build_activation(self.act_func, inplace=True)),
715
+ ]
716
+ )
717
+ )
718
+
719
+ self.conv3 = nn.Sequential(
720
+ OrderedDict(
721
+ [
722
+ (
723
+ "conv",
724
+ nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False),
725
+ ),
726
+ ("bn", nn.BatchNorm2d(self.out_channels)),
727
+ ]
728
+ )
729
+ )
730
+
731
+ if stride == 1 and in_channels == out_channels:
732
+ self.downsample = IdentityLayer(in_channels, out_channels)
733
+ elif self.downsample_mode == "conv":
734
+ self.downsample = nn.Sequential(
735
+ OrderedDict(
736
+ [
737
+ (
738
+ "conv",
739
+ nn.Conv2d(
740
+ in_channels, out_channels, 1, stride, 0, bias=False
741
+ ),
742
+ ),
743
+ ("bn", nn.BatchNorm2d(out_channels)),
744
+ ]
745
+ )
746
+ )
747
+ elif self.downsample_mode == "avgpool_conv":
748
+ self.downsample = nn.Sequential(
749
+ OrderedDict(
750
+ [
751
+ (
752
+ "avg_pool",
753
+ nn.AvgPool2d(
754
+ kernel_size=stride,
755
+ stride=stride,
756
+ padding=0,
757
+ ceil_mode=True,
758
+ ),
759
+ ),
760
+ (
761
+ "conv",
762
+ nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
763
+ ),
764
+ ("bn", nn.BatchNorm2d(out_channels)),
765
+ ]
766
+ )
767
+ )
768
+ else:
769
+ raise NotImplementedError
770
+
771
+ self.final_act = build_activation(self.act_func, inplace=True)
772
+
773
+ def forward(self, x):
774
+ residual = self.downsample(x)
775
+
776
+ x = self.conv1(x)
777
+ x = self.conv2(x)
778
+ x = self.conv3(x)
779
+
780
+ x = x + residual
781
+ x = self.final_act(x)
782
+ return x
783
+
784
+ @property
785
+ def module_str(self):
786
+ return "(%s, %s)" % (
787
+ "%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d"
788
+ % (
789
+ self.kernel_size,
790
+ self.kernel_size,
791
+ self.in_channels,
792
+ self.mid_channels,
793
+ self.out_channels,
794
+ self.stride,
795
+ self.groups,
796
+ ),
797
+ "Identity"
798
+ if isinstance(self.downsample, IdentityLayer)
799
+ else self.downsample_mode,
800
+ )
801
+
802
+ @property
803
+ def config(self):
804
+ return {
805
+ "name": ResNetBottleneckBlock.__name__,
806
+ "in_channels": self.in_channels,
807
+ "out_channels": self.out_channels,
808
+ "kernel_size": self.kernel_size,
809
+ "stride": self.stride,
810
+ "expand_ratio": self.expand_ratio,
811
+ "mid_channels": self.mid_channels,
812
+ "act_func": self.act_func,
813
+ "groups": self.groups,
814
+ "downsample_mode": self.downsample_mode,
815
+ }
816
+
817
+ @staticmethod
818
+ def build_from_config(config):
819
+ return ResNetBottleneckBlock(**config)
proard/utils/my_dataloader/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .my_distributed_sampler import *
2
+ from .my_random_resize_crop import *
proard/utils/my_dataloader/my_data_loader.py ADDED
@@ -0,0 +1,1050 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
2
+
3
+ To support these two classes, in `./_utils` we define many utility methods and
4
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
5
+ in `./_utils/worker.py`.
6
+ """
7
+
8
+ import threading
9
+ import itertools
10
+ import warnings
11
+ import multiprocessing as python_multiprocessing
12
+ import torch
13
+ import torch.multiprocessing as multiprocessing
14
+ from torch._utils import ExceptionWrapper
15
+ from torch.multiprocessing import Queue as queue
16
+ from torch._six import string_classes
17
+ from torch.utils.data.dataset import IterableDataset
18
+ from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler
19
+ from torch.utils.data import _utils
20
+
21
+ from .my_data_worker import worker_loop
22
+
23
+ __all__ = ["MyDataLoader"]
24
+
25
+ get_worker_info = _utils.worker.get_worker_info
26
+
27
+ # This function used to be defined in this file. However, it was moved to
28
+ # _utils/collate.py. Although it is rather hard to access this from user land
29
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
30
+ # probably is user code out there using it. This aliasing maintains BC in this
31
+ # aspect.
32
+ default_collate = _utils.collate.default_collate
33
+
34
+
35
+ class _DatasetKind(object):
36
+ Map = 0
37
+ Iterable = 1
38
+
39
+ @staticmethod
40
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
41
+ if kind == _DatasetKind.Map:
42
+ return _utils.fetch._MapDatasetFetcher(
43
+ dataset, auto_collation, collate_fn, drop_last
44
+ )
45
+ else:
46
+ return _utils.fetch._IterableDatasetFetcher(
47
+ dataset, auto_collation, collate_fn, drop_last
48
+ )
49
+
50
+
51
+ class _InfiniteConstantSampler(Sampler):
52
+ r"""Analogous to ``itertools.repeat(None, None)``.
53
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
54
+
55
+ Arguments:
56
+ data_source (Dataset): dataset to sample from
57
+ """
58
+
59
+ def __init__(self):
60
+ super(_InfiniteConstantSampler, self).__init__(None)
61
+
62
+ def __iter__(self):
63
+ while True:
64
+ yield None
65
+
66
+
67
+ class MyDataLoader(object):
68
+ r"""
69
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
70
+ the given dataset.
71
+
72
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
73
+ iterable-style datasets with single- or multi-process loading, customizing
74
+ loading order and optional automatic batching (collation) and memory pinning.
75
+
76
+ See :py:mod:`torch.utils.data` documentation page for more details.
77
+
78
+ Arguments:
79
+ dataset (Dataset): dataset from which to load the data.
80
+ batch_size (int, optional): how many samples per batch to load
81
+ (default: ``1``).
82
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
83
+ at every epoch (default: ``False``).
84
+ sampler (Sampler, optional): defines the strategy to draw samples from
85
+ the dataset. If specified, :attr:`shuffle` must be ``False``.
86
+ batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
87
+ indices at a time. Mutually exclusive with :attr:`batch_size`,
88
+ :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
89
+ num_workers (int, optional): how many subprocesses to use for data
90
+ loading. ``0`` means that the data will be loaded in the main process.
91
+ (default: ``0``)
92
+ collate_fn (callable, optional): merges a list of samples to form a
93
+ mini-batch of Tensor(s). Used when using batched loading from a
94
+ map-style dataset.
95
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
96
+ into CUDA pinned memory before returning them. If your data elements
97
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
98
+ see the example below.
99
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
100
+ if the dataset size is not divisible by the batch size. If ``False`` and
101
+ the size of dataset is not divisible by the batch size, then the last batch
102
+ will be smaller. (default: ``False``)
103
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
104
+ from workers. Should always be non-negative. (default: ``0``)
105
+ worker_init_fn (callable, optional): If not ``None``, this will be called on each
106
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
107
+ input, after seeding and before data loading. (default: ``None``)
108
+
109
+
110
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
111
+ cannot be an unpicklable object, e.g., a lambda function. See
112
+ :ref:`multiprocessing-best-practices` on more details related
113
+ to multiprocessing in PyTorch.
114
+
115
+ .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
116
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
117
+ ``len(dataset)`` (if implemented) is returned instead, regardless
118
+ of multi-process loading configurations, because PyTorch trust
119
+ user :attr:`dataset` code in correctly handling multi-process
120
+ loading to avoid duplicate data. See `Dataset Types`_ for more
121
+ details on these two types of datasets and how
122
+ :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
123
+ """
124
+
125
+ __initialized = False
126
+
127
+ def __init__(
128
+ self,
129
+ dataset,
130
+ batch_size=1,
131
+ shuffle=False,
132
+ sampler=None,
133
+ batch_sampler=None,
134
+ num_workers=0,
135
+ collate_fn=None,
136
+ pin_memory=False,
137
+ drop_last=False,
138
+ timeout=0,
139
+ worker_init_fn=None,
140
+ multiprocessing_context=None,
141
+ ):
142
+ torch._C._log_api_usage_once("python.data_loader")
143
+
144
+ if num_workers < 0:
145
+ raise ValueError(
146
+ "num_workers option should be non-negative; "
147
+ "use num_workers=0 to disable multiprocessing."
148
+ )
149
+
150
+ if timeout < 0:
151
+ raise ValueError("timeout option should be non-negative")
152
+
153
+ self.dataset = dataset
154
+ self.num_workers = num_workers
155
+ self.pin_memory = pin_memory
156
+ self.timeout = timeout
157
+ self.worker_init_fn = worker_init_fn
158
+ self.multiprocessing_context = multiprocessing_context
159
+
160
+ # Arg-check dataset related before checking samplers because we want to
161
+ # tell users that iterable-style datasets are incompatible with custom
162
+ # samplers first, so that they don't learn that this combo doesn't work
163
+ # after spending time fixing the custom sampler errors.
164
+ if isinstance(dataset, IterableDataset):
165
+ self._dataset_kind = _DatasetKind.Iterable
166
+ # NOTE [ Custom Samplers and `IterableDataset` ]
167
+ #
168
+ # `IterableDataset` does not support custom `batch_sampler` or
169
+ # `sampler` since the key is irrelevant (unless we support
170
+ # generator-style dataset one day...).
171
+ #
172
+ # For `sampler`, we always create a dummy sampler. This is an
173
+ # infinite sampler even when the dataset may have an implemented
174
+ # finite `__len__` because in multi-process data loading, naive
175
+ # settings will return duplicated data (which may be desired), and
176
+ # thus using a sampler with length matching that of dataset will
177
+ # cause data lost (you may have duplicates of the first couple
178
+ # batches, but never see anything afterwards). Therefore,
179
+ # `Iterabledataset` always uses an infinite sampler, an instance of
180
+ # `_InfiniteConstantSampler` defined above.
181
+ #
182
+ # A custom `batch_sampler` essentially only controls the batch size.
183
+ # However, it is unclear how useful it would be since an iterable-style
184
+ # dataset can handle that within itself. Moreover, it is pointless
185
+ # in multi-process data loading as the assignment order of batches
186
+ # to workers is an implementation detail so users can not control
187
+ # how to batchify each worker's iterable. Thus, we disable this
188
+ # option. If this turns out to be useful in future, we can re-enable
189
+ # this, and support custom samplers that specify the assignments to
190
+ # specific workers.
191
+ if shuffle is not False:
192
+ raise ValueError(
193
+ "DataLoader with IterableDataset: expected unspecified "
194
+ "shuffle option, but got shuffle={}".format(shuffle)
195
+ )
196
+ elif sampler is not None:
197
+ # See NOTE [ Custom Samplers and IterableDataset ]
198
+ raise ValueError(
199
+ "DataLoader with IterableDataset: expected unspecified "
200
+ "sampler option, but got sampler={}".format(sampler)
201
+ )
202
+ elif batch_sampler is not None:
203
+ # See NOTE [ Custom Samplers and IterableDataset ]
204
+ raise ValueError(
205
+ "DataLoader with IterableDataset: expected unspecified "
206
+ "batch_sampler option, but got batch_sampler={}".format(
207
+ batch_sampler
208
+ )
209
+ )
210
+ else:
211
+ self._dataset_kind = _DatasetKind.Map
212
+
213
+ if sampler is not None and shuffle:
214
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
215
+
216
+ if batch_sampler is not None:
217
+ # auto_collation with custom batch_sampler
218
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
219
+ raise ValueError(
220
+ "batch_sampler option is mutually exclusive "
221
+ "with batch_size, shuffle, sampler, and "
222
+ "drop_last"
223
+ )
224
+ batch_size = None
225
+ drop_last = False
226
+ elif batch_size is None:
227
+ # no auto_collation
228
+ if shuffle or drop_last:
229
+ raise ValueError(
230
+ "batch_size=None option disables auto-batching "
231
+ "and is mutually exclusive with "
232
+ "shuffle, and drop_last"
233
+ )
234
+
235
+ if sampler is None: # give default samplers
236
+ if self._dataset_kind == _DatasetKind.Iterable:
237
+ # See NOTE [ Custom Samplers and IterableDataset ]
238
+ sampler = _InfiniteConstantSampler()
239
+ else: # map-style
240
+ if shuffle:
241
+ sampler = RandomSampler(dataset)
242
+ else:
243
+ sampler = SequentialSampler(dataset)
244
+
245
+ if batch_size is not None and batch_sampler is None:
246
+ # auto_collation without custom batch_sampler
247
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
248
+
249
+ self.batch_size = batch_size
250
+ self.drop_last = drop_last
251
+ self.sampler = sampler
252
+ self.batch_sampler = batch_sampler
253
+
254
+ if collate_fn is None:
255
+ if self._auto_collation:
256
+ collate_fn = _utils.collate.default_collate
257
+ else:
258
+ collate_fn = _utils.collate.default_convert
259
+
260
+ self.collate_fn = collate_fn
261
+ self.__initialized = True
262
+ self._IterableDataset_len_called = (
263
+ None # See NOTE [ IterableDataset and __len__ ]
264
+ )
265
+
266
+ @property
267
+ def multiprocessing_context(self):
268
+ return self.__multiprocessing_context
269
+
270
+ @multiprocessing_context.setter
271
+ def multiprocessing_context(self, multiprocessing_context):
272
+ if multiprocessing_context is not None:
273
+ if self.num_workers > 0:
274
+ if not multiprocessing._supports_context:
275
+ raise ValueError(
276
+ "multiprocessing_context relies on Python >= 3.4, with "
277
+ "support for different start methods"
278
+ )
279
+
280
+ if isinstance(multiprocessing_context, string_classes):
281
+ valid_start_methods = multiprocessing.get_all_start_methods()
282
+ if multiprocessing_context not in valid_start_methods:
283
+ raise ValueError(
284
+ (
285
+ "multiprocessing_context option "
286
+ "should specify a valid start method in {}, but got "
287
+ "multiprocessing_context={}"
288
+ ).format(valid_start_methods, multiprocessing_context)
289
+ )
290
+ multiprocessing_context = multiprocessing.get_context(
291
+ multiprocessing_context
292
+ )
293
+
294
+ if not isinstance(
295
+ multiprocessing_context, python_multiprocessing.context.BaseContext
296
+ ):
297
+ raise ValueError(
298
+ (
299
+ "multiprocessing_context option should be a valid context "
300
+ "object or a string specifying the start method, but got "
301
+ "multiprocessing_context={}"
302
+ ).format(multiprocessing_context)
303
+ )
304
+ else:
305
+ raise ValueError(
306
+ (
307
+ "multiprocessing_context can only be used with "
308
+ "multi-process loading (num_workers > 0), but got "
309
+ "num_workers={}"
310
+ ).format(self.num_workers)
311
+ )
312
+
313
+ self.__multiprocessing_context = multiprocessing_context
314
+
315
+ def __setattr__(self, attr, val):
316
+ if self.__initialized and attr in (
317
+ "batch_size",
318
+ "batch_sampler",
319
+ "sampler",
320
+ "drop_last",
321
+ "dataset",
322
+ ):
323
+ raise ValueError(
324
+ "{} attribute should not be set after {} is "
325
+ "initialized".format(attr, self.__class__.__name__)
326
+ )
327
+
328
+ super(MyDataLoader, self).__setattr__(attr, val)
329
+
330
+ def __iter__(self):
331
+ if self.num_workers == 0:
332
+ return _SingleProcessDataLoaderIter(self)
333
+ else:
334
+ return _MultiProcessingDataLoaderIter(self)
335
+
336
+ @property
337
+ def _auto_collation(self):
338
+ return self.batch_sampler is not None
339
+
340
+ @property
341
+ def _index_sampler(self):
342
+ # The actual sampler used for generating indices for `_DatasetFetcher`
343
+ # (see _utils/fetch.py) to read data at each time. This would be
344
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
345
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
346
+ # reasons.
347
+ if self._auto_collation:
348
+ return self.batch_sampler
349
+ else:
350
+ return self.sampler
351
+
352
+ def __len__(self):
353
+ if self._dataset_kind == _DatasetKind.Iterable:
354
+ # NOTE [ IterableDataset and __len__ ]
355
+ #
356
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
357
+ # does multi-processing data loading, since the samples will be duplicated.
358
+ # However, no real use case should be actually using that behavior, so
359
+ # it should count as a user error. We should generally trust user
360
+ # code to do the proper thing (e.g., configure each replica differently
361
+ # in `__iter__`), and give us the correct `__len__` if they choose to
362
+ # implement it (this will still throw if the dataset does not implement
363
+ # a `__len__`).
364
+ #
365
+ # To provide a further warning, we track if `__len__` was called on the
366
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
367
+ # if the iterator ends up yielding more than this number of samples.
368
+ length = self._IterableDataset_len_called = len(self.dataset)
369
+ return length
370
+ else:
371
+ return len(self._index_sampler)
372
+
373
+
374
+ class _BaseDataLoaderIter(object):
375
+ def __init__(self, loader):
376
+ self._dataset = loader.dataset
377
+ self._dataset_kind = loader._dataset_kind
378
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
379
+ self._auto_collation = loader._auto_collation
380
+ self._drop_last = loader.drop_last
381
+ self._index_sampler = loader._index_sampler
382
+ self._num_workers = loader.num_workers
383
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
384
+ self._timeout = loader.timeout
385
+ self._collate_fn = loader.collate_fn
386
+ self._sampler_iter = iter(self._index_sampler)
387
+ self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
388
+ self._num_yielded = 0
389
+
390
+ def __iter__(self):
391
+ return self
392
+
393
+ def _next_index(self):
394
+ return next(self._sampler_iter) # may raise StopIteration
395
+
396
+ def _next_data(self):
397
+ raise NotImplementedError
398
+
399
+ def __next__(self):
400
+ data = self._next_data()
401
+ self._num_yielded += 1
402
+ if (
403
+ self._dataset_kind == _DatasetKind.Iterable
404
+ and self._IterableDataset_len_called is not None
405
+ and self._num_yielded > self._IterableDataset_len_called
406
+ ):
407
+ warn_msg = (
408
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
409
+ "samples have been fetched. "
410
+ ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
411
+ if self._num_workers > 0:
412
+ warn_msg += (
413
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
414
+ "IterableDataset replica at each worker. Please see "
415
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
416
+ )
417
+ warnings.warn(warn_msg)
418
+ return data
419
+
420
+ next = __next__ # Python 2 compatibility
421
+
422
+ def __len__(self):
423
+ return len(self._index_sampler)
424
+
425
+ def __getstate__(self):
426
+ # across multiple threads for HOGWILD.
427
+ # Probably the best way to do this is by moving the sample pushing
428
+ # to a separate thread and then just sharing the data queue
429
+ # but signalling the end is tricky without a non-blocking API
430
+ raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
431
+
432
+
433
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
434
+ def __init__(self, loader):
435
+ super(_SingleProcessDataLoaderIter, self).__init__(loader)
436
+ assert self._timeout == 0
437
+ assert self._num_workers == 0
438
+
439
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
440
+ self._dataset_kind,
441
+ self._dataset,
442
+ self._auto_collation,
443
+ self._collate_fn,
444
+ self._drop_last,
445
+ )
446
+
447
+ def _next_data(self):
448
+ index = self._next_index() # may raise StopIteration
449
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
450
+ if self._pin_memory:
451
+ data = _utils.pin_memory.pin_memory(data)
452
+ return data
453
+
454
+
455
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
456
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
457
+
458
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
459
+ #
460
+ # Preliminary:
461
+ #
462
+ # Our data model looks like this (queues are indicated with curly brackets):
463
+ #
464
+ # main process ||
465
+ # | ||
466
+ # {index_queue} ||
467
+ # | ||
468
+ # worker processes || DATA
469
+ # | ||
470
+ # {worker_result_queue} || FLOW
471
+ # | ||
472
+ # pin_memory_thread of main process || DIRECTION
473
+ # | ||
474
+ # {data_queue} ||
475
+ # | ||
476
+ # data output \/
477
+ #
478
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
479
+ # `pin_memory=False`.
480
+ #
481
+ #
482
+ # Terminating multiprocessing logic requires very careful design. In
483
+ # particular, we need to make sure that
484
+ #
485
+ # 1. The iterator gracefully exits the workers when its last reference is
486
+ # gone or it is depleted.
487
+ #
488
+ # In this case, the workers should be gracefully exited because the
489
+ # main process may still need to continue to run, and we want cleaning
490
+ # up code in the workers to be executed (e.g., releasing GPU memory).
491
+ # Naturally, we implement the shutdown logic in `__del__` of
492
+ # DataLoaderIterator.
493
+ #
494
+ # We delay the discussion on the logic in this case until later.
495
+ #
496
+ # 2. The iterator exits the workers when the loader process and/or worker
497
+ # processes exits normally or with error.
498
+ #
499
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
500
+ #
501
+ # You may ask, why can't we make the workers non-daemonic, and
502
+ # gracefully exit using the same logic as we have in `__del__` when the
503
+ # iterator gets deleted (see 1 above)?
504
+ #
505
+ # First of all, `__del__` is **not** guaranteed to be called when
506
+ # interpreter exits. Even if it is called, by the time it executes,
507
+ # many Python core library resources may alreay be freed, and even
508
+ # simple things like acquiring an internal lock of a queue may hang.
509
+ # Therefore, in this case, we actually need to prevent `__del__` from
510
+ # being executed, and rely on the automatic termination of daemonic
511
+ # children. Thus, we register an `atexit` hook that sets a global flag
512
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
513
+ # reverse order of registration, we are guaranteed that this flag is
514
+ # set before library resources we use are freed. (Hooks freeing those
515
+ # resources are registered at importing the Python core libraries at
516
+ # the top of this file.) So in `__del__`, we check if
517
+ # `_utils.python_exit_status` is set or `None` (freed), and perform
518
+ # no-op if so.
519
+ #
520
+ # Another problem with `__del__` is also related to the library cleanup
521
+ # calls. When a process ends, it shuts the all its daemonic children
522
+ # down with a SIGTERM (instead of joining them without a timeout).
523
+ # Simiarly for threads, but by a different mechanism. This fact,
524
+ # together with a few implementation details of multiprocessing, forces
525
+ # us to make workers daemonic. All of our problems arise when a
526
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
527
+ # code which looks more or less like this:
528
+ #
529
+ # try:
530
+ # your_function_using_a_dataloader()
531
+ # finally:
532
+ # multiprocessing.util._exit_function()
533
+ #
534
+ # The joining/termination mentioned above happens inside
535
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
536
+ # throws, the stack trace stored in the exception will prevent the
537
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
538
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
539
+ # its `__del__`, which starts the shutdown procedure, will not be
540
+ # called. That, in turn, means that workers aren't notified. Attempting
541
+ # to join in `_exit_function` will then result in a hang.
542
+ #
543
+ # For context, `_exit_function` is also registered as an `atexit` call.
544
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
545
+ # The code dates back to 2008 and there is no comment on the original
546
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
547
+ # the finally block and the `atexit` registration) that explains this.
548
+ #
549
+ # Another choice is to just shutdown workers with logic in 1 above
550
+ # whenever we see an error in `next`. This isn't ideal because
551
+ # a. It prevents users from using try-catch to resume data loading.
552
+ # b. It doesn't prevent hanging if users have references to the
553
+ # iterator.
554
+ #
555
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
556
+ #
557
+ # As shown above, the workers are set as daemonic children of the main
558
+ # process. However, automatic cleaning-up of such child processes only
559
+ # happens if the parent process exits gracefully (e.g., not via fatal
560
+ # signals like SIGKILL). So we must ensure that each process will exit
561
+ # even the process that should send/receive data to/from it were
562
+ # killed, i.e.,
563
+ #
564
+ # a. A process won't hang when getting from a queue.
565
+ #
566
+ # Even with carefully designed data dependencies (i.e., a `put()`
567
+ # always corresponding to a `get()`), hanging on `get()` can still
568
+ # happen when data in queue is corrupted (e.g., due to
569
+ # `cancel_join_thread` or unexpected exit).
570
+ #
571
+ # For child exit, we set a timeout whenever we try to get data
572
+ # from `data_queue`, and check the workers' status on each timeout
573
+ # and error.
574
+ # See `_DataLoaderiter._get_batch()` and
575
+ # `_DataLoaderiter._try_get_data()` for details.
576
+ #
577
+ # Additionally, for child exit on non-Windows platforms, we also
578
+ # register a SIGCHLD handler (which is supported on Windows) on
579
+ # the main process, which checks if any of the workers fail in the
580
+ # (Python) handler. This is more efficient and faster in detecting
581
+ # worker failures, compared to only using the above mechanism.
582
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
583
+ #
584
+ # For `.get()` calls where the sender(s) is not the workers, we
585
+ # guard them with timeouts, and check the status of the sender
586
+ # when timeout happens:
587
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
588
+ # checks the status of the main process.
589
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
590
+ # check `pin_memory_thread` status periodically until `.get()`
591
+ # returns or see that `pin_memory_thread` died.
592
+ #
593
+ # b. A process won't hang when putting into a queue;
594
+ #
595
+ # We use `mp.Queue` which has a separate background thread to put
596
+ # objects from an unbounded buffer array. The background thread is
597
+ # daemonic and usually automatically joined when the process
598
+ # exits.
599
+ #
600
+ # However, in case that the receiver has ended abruptly while
601
+ # reading from the pipe, the join will hang forever. Therefore,
602
+ # for both `worker_result_queue` (worker -> main process/pin_memory_thread)
603
+ # and each `index_queue` (main process -> worker), we use
604
+ # `q.cancel_join_thread()` in sender process before any `q.put` to
605
+ # prevent this automatic join.
606
+ #
607
+ # Moreover, having all queues called `cancel_join_thread` makes
608
+ # implementing graceful shutdown logic in `__del__` much easier.
609
+ # It won't need to get from any queue, which would also need to be
610
+ # guarded by periodic status checks.
611
+ #
612
+ # Nonetheless, `cancel_join_thread` must only be called when the
613
+ # queue is **not** going to be read from or write into by another
614
+ # process, because it may hold onto a lock or leave corrupted data
615
+ # in the queue, leading other readers/writers to hang.
616
+ #
617
+ # `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
618
+ # a blocking `put` if the queue is full. So there is no above
619
+ # problem, but we do need to wrap the `put` in a loop that breaks
620
+ # not only upon success, but also when the main process stops
621
+ # reading, i.e., is shutting down.
622
+ #
623
+ #
624
+ # Now let's get back to 1:
625
+ # how we gracefully exit the workers when the last reference to the
626
+ # iterator is gone.
627
+ #
628
+ # To achieve this, we implement the following logic along with the design
629
+ # choices mentioned above:
630
+ #
631
+ # `workers_done_event`:
632
+ # A `multiprocessing.Event` shared among the main process and all worker
633
+ # processes. This is used to signal the workers that the iterator is
634
+ # shutting down. After it is set, they will not send processed data to
635
+ # queues anymore, and only wait for the final `None` before exiting.
636
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
637
+ # from the input queue, but it allows us to skip wasting resources
638
+ # processing data if we are already shutting down.
639
+ #
640
+ # `pin_memory_thread_done_event`:
641
+ # A `threading.Event` for a similar purpose to that of
642
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
643
+ # that separate events are needed is that `pin_memory_thread` reads from
644
+ # the output queue of the workers. But the workers, upon seeing that
645
+ # `workers_done_event` is set, only wants to see the final `None`, and is
646
+ # not required to flush all data in the output queue (e.g., it may call
647
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
648
+ # happens to exhaust coincidentally, which is out of the control of the
649
+ # main process). Thus, since we will exit `pin_memory_thread` before the
650
+ # workers (see below), two separete events are used.
651
+ #
652
+ # NOTE: In short, the protocol is that the main process will set these
653
+ # `done_event`s and then the corresponding processes/threads a `None`,
654
+ # and that they may exit at any time after receiving the `None`.
655
+ #
656
+ # NOTE: Using `None` as the final signal is valid, since normal data will
657
+ # always be a 2-tuple with the 1st element being the index of the data
658
+ # transferred (different from dataset index/key), and the 2nd being
659
+ # either the dataset key or the data sample (depending on which part
660
+ # of the data model the queue is at).
661
+ #
662
+ # [ worker processes ]
663
+ # While loader process is alive:
664
+ # Get from `index_queue`.
665
+ # If get anything else,
666
+ # Check `workers_done_event`.
667
+ # If set, continue to next iteration
668
+ # i.e., keep getting until see the `None`, then exit.
669
+ # Otherwise, process data:
670
+ # If is fetching from an `IterableDataset` and the iterator
671
+ # is exhausted, send an `_IterableDatasetStopIteration`
672
+ # object to signal iteration end. The main process, upon
673
+ # receiving such an object, will send `None` to this
674
+ # worker and not use the corresponding `index_queue`
675
+ # anymore.
676
+ # If timed out,
677
+ # No matter `workers_done_event` is set (still need to see `None`)
678
+ # or not, must continue to next iteration.
679
+ # (outside loop)
680
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
681
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
682
+ # main process won't read from it;
683
+ # other workers will also call
684
+ # `cancel_join_thread`.)
685
+ #
686
+ # [ pin_memory_thread ]
687
+ # # No need to check main thread. If this thread is alive, the main loader
688
+ # # thread must be alive, because this thread is set as daemonic.
689
+ # While `pin_memory_thread_done_event` is not set:
690
+ # Get from `index_queue`.
691
+ # If timed out, continue to get in the next iteration.
692
+ # Otherwise, process data.
693
+ # While `pin_memory_thread_done_event` is not set:
694
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
695
+ # If timed out, continue to put in the next iteration.
696
+ # Otherwise, break, i.e., continuing to the out loop.
697
+ #
698
+ # NOTE: we don't check the status of the main thread because
699
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
700
+ # ends.
701
+ # 2. in other cases, either the cleaning-up in __del__ or the
702
+ # automatic exit of daemonic thread will take care of it.
703
+ # This won't busy-wait either because `.get(timeout)` does not
704
+ # busy-wait.
705
+ #
706
+ # [ main process ]
707
+ # In the DataLoader Iter's `__del__`
708
+ # b. Exit `pin_memory_thread`
709
+ # i. Set `pin_memory_thread_done_event`.
710
+ # ii Put `None` in `worker_result_queue`.
711
+ # iii. Join the `pin_memory_thread`.
712
+ # iv. `worker_result_queue.cancel_join_thread()`.
713
+ #
714
+ # c. Exit the workers.
715
+ # i. Set `workers_done_event`.
716
+ # ii. Put `None` in each worker's `index_queue`.
717
+ # iii. Join the workers.
718
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
719
+ #
720
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
721
+ # data in `worker_result_queue`, which `pin_memory_thread`
722
+ # reads from, in which case the `pin_memory_thread` can only
723
+ # happen at timeing out, which is slow. Nonetheless, same thing
724
+ # happens if a worker is killed by signal at unfortunate times,
725
+ # but in other cases, we are better off having a non-corrupted
726
+ # `worker_result_queue` for `pin_memory_thread`.
727
+ #
728
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
729
+ # can be omitted
730
+ #
731
+ # NB: `done_event`s isn't strictly needed. E.g., we can just check for
732
+ # `None` from `index_queue`, but it allows us to skip wasting resources
733
+ # processing indices already in `index_queue` if we are already shutting
734
+ # down.
735
+
736
+ def __init__(self, loader):
737
+ super(_MultiProcessingDataLoaderIter, self).__init__(loader)
738
+
739
+ assert self._num_workers > 0
740
+
741
+ if loader.multiprocessing_context is None:
742
+ multiprocessing_context = multiprocessing
743
+ else:
744
+ multiprocessing_context = loader.multiprocessing_context
745
+
746
+ self._worker_init_fn = loader.worker_init_fn
747
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
748
+ self._worker_result_queue = multiprocessing_context.Queue()
749
+ self._worker_pids_set = False
750
+ self._shutdown = False
751
+ self._send_idx = 0 # idx of the next task to be sent to workers
752
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
753
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
754
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
755
+ # \ (worker_id, data) if data is already fetched (out-of-order)
756
+ self._task_info = {}
757
+ self._tasks_outstanding = (
758
+ 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
759
+ )
760
+ self._workers_done_event = multiprocessing_context.Event()
761
+
762
+ self._index_queues = []
763
+ self._workers = []
764
+ # A list of booleans representing whether each worker still has work to
765
+ # do, i.e., not having exhausted its iterable dataset object. It always
766
+ # contains all `True`s if not using an iterable-style dataset
767
+ # (i.e., if kind != Iterable).
768
+ self._workers_status = []
769
+ for i in range(self._num_workers):
770
+ index_queue = multiprocessing_context.Queue()
771
+ # index_queue.cancel_join_thread()
772
+ w = multiprocessing_context.Process(
773
+ target=worker_loop,
774
+ args=(
775
+ self._dataset_kind,
776
+ self._dataset,
777
+ index_queue,
778
+ self._worker_result_queue,
779
+ self._workers_done_event,
780
+ self._auto_collation,
781
+ self._collate_fn,
782
+ self._drop_last,
783
+ self._base_seed + i,
784
+ self._worker_init_fn,
785
+ i,
786
+ self._num_workers,
787
+ ),
788
+ )
789
+ w.daemon = True
790
+ # NB: Process.start() actually take some time as it needs to
791
+ # start a process and pass the arguments over via a pipe.
792
+ # Therefore, we only add a worker to self._workers list after
793
+ # it started, so that we do not call .join() if program dies
794
+ # before it starts, and __del__ tries to join but will get:
795
+ # AssertionError: can only join a started process.
796
+ w.start()
797
+ self._index_queues.append(index_queue)
798
+ self._workers.append(w)
799
+ self._workers_status.append(True)
800
+
801
+ if self._pin_memory:
802
+ self._pin_memory_thread_done_event = threading.Event()
803
+ self._data_queue = queue()
804
+ pin_memory_thread = threading.Thread(
805
+ target=_utils.pin_memory._pin_memory_loop,
806
+ args=(
807
+ self._worker_result_queue,
808
+ self._data_queue,
809
+ torch.cuda.current_device(),
810
+ self._pin_memory_thread_done_event,
811
+ ),
812
+ )
813
+ pin_memory_thread.daemon = True
814
+ pin_memory_thread.start()
815
+ # Similar to workers (see comment above), we only register
816
+ # pin_memory_thread once it is started.
817
+ self._pin_memory_thread = pin_memory_thread
818
+ else:
819
+ self._data_queue = self._worker_result_queue
820
+
821
+ _utils.signal_handling._set_worker_pids(
822
+ id(self), tuple(w.pid for w in self._workers)
823
+ )
824
+ _utils.signal_handling._set_SIGCHLD_handler()
825
+ self._worker_pids_set = True
826
+
827
+ # prime the prefetch loop
828
+ for _ in range(2 * self._num_workers):
829
+ self._try_put_index()
830
+
831
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
832
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
833
+ # This can also be used as inner loop of fetching without timeout, with
834
+ # the sender status as the loop condition.
835
+ #
836
+ # This raises a `RuntimeError` if any worker died expectedly. This error
837
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
838
+ # (only for non-Windows platforms), or the manual check below on errors
839
+ # and timeouts.
840
+ #
841
+ # Returns a 2-tuple:
842
+ # (bool: whether successfully get data, any: data if successful else None)
843
+ try:
844
+ data = self._data_queue.get(timeout=timeout)
845
+ return (True, data)
846
+ except Exception as e:
847
+ # At timeout and error, we manually check whether any worker has
848
+ # failed. Note that this is the only mechanism for Windows to detect
849
+ # worker failures.
850
+ failed_workers = []
851
+ for worker_id, w in enumerate(self._workers):
852
+ if self._workers_status[worker_id] and not w.is_alive():
853
+ failed_workers.append(w)
854
+ self._shutdown_worker(worker_id)
855
+ if len(failed_workers) > 0:
856
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
857
+ raise RuntimeError(
858
+ "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
859
+ )
860
+ if isinstance(e, queue.Empty):
861
+ return (False, None)
862
+ raise
863
+
864
+ def _get_data(self):
865
+ # Fetches data from `self._data_queue`.
866
+ #
867
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
868
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
869
+ # in a loop. This is the only mechanism to detect worker failures for
870
+ # Windows. For other platforms, a SIGCHLD handler is also used for
871
+ # worker failure detection.
872
+ #
873
+ # If `pin_memory=True`, we also need check if `pin_memory_thread` had
874
+ # died at timeouts.
875
+ if self._timeout > 0:
876
+ success, data = self._try_get_data(self._timeout)
877
+ if success:
878
+ return data
879
+ else:
880
+ raise RuntimeError(
881
+ "DataLoader timed out after {} seconds".format(self._timeout)
882
+ )
883
+ elif self._pin_memory:
884
+ while self._pin_memory_thread.is_alive():
885
+ success, data = self._try_get_data()
886
+ if success:
887
+ return data
888
+ else:
889
+ # while condition is false, i.e., pin_memory_thread died.
890
+ raise RuntimeError("Pin memory thread exited unexpectedly")
891
+ # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
892
+ # need to call `.task_done()` because we don't use `.join()`.
893
+ else:
894
+ while True:
895
+ success, data = self._try_get_data()
896
+ if success:
897
+ return data
898
+
899
+ def _next_data(self):
900
+ while True:
901
+ # If the worker responsible for `self._rcvd_idx` has already ended
902
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
903
+ # we try to advance `self._rcvd_idx` to find the next valid index.
904
+ #
905
+ # This part needs to run in the loop because both the `self._get_data()`
906
+ # call and `_IterableDatasetStopIteration` check below can mark
907
+ # extra worker(s) as dead.
908
+ while self._rcvd_idx < self._send_idx:
909
+ info = self._task_info[self._rcvd_idx]
910
+ worker_id = info[0]
911
+ if (
912
+ len(info) == 2 or self._workers_status[worker_id]
913
+ ): # has data or is still active
914
+ break
915
+ del self._task_info[self._rcvd_idx]
916
+ self._rcvd_idx += 1
917
+ else:
918
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
919
+ self._shutdown_workers()
920
+ raise StopIteration
921
+
922
+ # Now `self._rcvd_idx` is the batch index we want to fetch
923
+
924
+ # Check if the next sample has already been generated
925
+ if len(self._task_info[self._rcvd_idx]) == 2:
926
+ data = self._task_info.pop(self._rcvd_idx)[1]
927
+ return self._process_data(data)
928
+
929
+ assert not self._shutdown and self._tasks_outstanding > 0
930
+ idx, data = self._get_data()
931
+ self._tasks_outstanding -= 1
932
+
933
+ if self._dataset_kind == _DatasetKind.Iterable:
934
+ # Check for _IterableDatasetStopIteration
935
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
936
+ self._shutdown_worker(data.worker_id)
937
+ self._try_put_index()
938
+ continue
939
+
940
+ if idx != self._rcvd_idx:
941
+ # store out-of-order samples
942
+ self._task_info[idx] += (data,)
943
+ else:
944
+ del self._task_info[idx]
945
+ return self._process_data(data)
946
+
947
+ def _try_put_index(self):
948
+ assert self._tasks_outstanding < 2 * self._num_workers
949
+ try:
950
+ index = self._next_index()
951
+ except StopIteration:
952
+ return
953
+ for _ in range(self._num_workers): # find the next active worker, if any
954
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
955
+ if self._workers_status[worker_queue_idx]:
956
+ break
957
+ else:
958
+ # not found (i.e., didn't break)
959
+ return
960
+
961
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
962
+ self._task_info[self._send_idx] = (worker_queue_idx,)
963
+ self._tasks_outstanding += 1
964
+ self._send_idx += 1
965
+
966
+ def _process_data(self, data):
967
+ self._rcvd_idx += 1
968
+ self._try_put_index()
969
+ if isinstance(data, ExceptionWrapper):
970
+ data.reraise()
971
+ return data
972
+
973
+ def _shutdown_worker(self, worker_id):
974
+ # Mark a worker as having finished its work and dead, e.g., due to
975
+ # exhausting an `IterableDataset`. This should be used only when this
976
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
977
+
978
+ assert self._workers_status[worker_id]
979
+
980
+ # Signal termination to that specific worker.
981
+ q = self._index_queues[worker_id]
982
+ # Indicate that no more data will be put on this queue by the current
983
+ # process.
984
+ q.put(None)
985
+
986
+ # Note that we don't actually join the worker here, nor do we remove the
987
+ # worker's pid from C side struct because (1) joining may be slow, and
988
+ # (2) since we don't join, the worker may still raise error, and we
989
+ # prefer capturing those, rather than ignoring them, even though they
990
+ # are raised after the worker has finished its job.
991
+ # Joinning is deferred to `_shutdown_workers`, which it is called when
992
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
993
+ # when this iterator is garbage collected.
994
+ self._workers_status[worker_id] = False
995
+
996
+ def _shutdown_workers(self):
997
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
998
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
999
+ # the logic of this function.
1000
+ python_exit_status = _utils.python_exit_status
1001
+ if python_exit_status is True or python_exit_status is None:
1002
+ # See (2) of the note. If Python is shutting down, do no-op.
1003
+ return
1004
+ # Normal exit when last reference is gone / iterator is depleted.
1005
+ # See (1) and the second half of the note.
1006
+ if not self._shutdown:
1007
+ self._shutdown = True
1008
+ try:
1009
+ # Exit `pin_memory_thread` first because exiting workers may leave
1010
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1011
+ # reads from.
1012
+ if hasattr(self, "_pin_memory_thread"):
1013
+ # Use hasattr in case error happens before we set the attribute.
1014
+ self._pin_memory_thread_done_event.set()
1015
+ # Send something to pin_memory_thread in case it is waiting
1016
+ # so that it can wake up and check `pin_memory_thread_done_event`
1017
+ self._worker_result_queue.put((None, None))
1018
+ self._pin_memory_thread.join()
1019
+ self._worker_result_queue.close()
1020
+
1021
+ # Exit workers now.
1022
+ self._workers_done_event.set()
1023
+ for worker_id in range(len(self._workers)):
1024
+ # Get number of workers from `len(self._workers)` instead of
1025
+ # `self._num_workers` in case we error before starting all
1026
+ # workers.
1027
+ if self._workers_status[worker_id]:
1028
+ self._shutdown_worker(worker_id)
1029
+ for w in self._workers:
1030
+ w.join()
1031
+ for q in self._index_queues:
1032
+ q.cancel_join_thread()
1033
+ q.close()
1034
+ finally:
1035
+ # Even though all this function does is putting into queues that
1036
+ # we have called `cancel_join_thread` on, weird things can
1037
+ # happen when a worker is killed by a signal, e.g., hanging in
1038
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1039
+ # and remove pids from the C side data structure only at the
1040
+ # end.
1041
+ #
1042
+ # FIXME: Unfortunately, for Windows, we are missing a worker
1043
+ # error detection mechanism here in this function, as it
1044
+ # doesn't provide a SIGCHLD handler.
1045
+ if self._worker_pids_set:
1046
+ _utils.signal_handling._remove_worker_pids(id(self))
1047
+ self._worker_pids_set = False
1048
+
1049
+ def __del__(self):
1050
+ self._shutdown_workers()
proard/utils/my_dataloader/my_data_worker.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
2
+
3
+ These **needs** to be in global scope since Py2 doesn't support serializing
4
+ static methods.
5
+ """
6
+
7
+ import torch
8
+ import random
9
+ import os
10
+ from collections import namedtuple
11
+ # from torch._six import queue
12
+ from torch.multiprocessing import Queue as queue
13
+ from torch._utils import ExceptionWrapper
14
+ from torch.utils.data._utils import (
15
+ signal_handling,
16
+ MP_STATUS_CHECK_INTERVAL,
17
+ IS_WINDOWS,
18
+ )
19
+
20
+ from .my_random_resize_crop import MyRandomResizedCrop
21
+
22
+ __all__ = ["worker_loop"]
23
+
24
+ if IS_WINDOWS:
25
+ import ctypes
26
+ from ctypes.wintypes import DWORD, BOOL, HANDLE
27
+
28
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
29
+ # is gone, and the only way to check it through OS is to let the worker have a process handle
30
+ # of the manager and ask if the process status has changed.
31
+ class ManagerWatchdog(object):
32
+ def __init__(self):
33
+ self.manager_pid = os.getppid()
34
+
35
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
36
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
37
+ self.kernel32.OpenProcess.restype = HANDLE
38
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
39
+ self.kernel32.WaitForSingleObject.restype = DWORD
40
+
41
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
42
+ SYNCHRONIZE = 0x00100000
43
+ self.manager_handle = self.kernel32.OpenProcess(
44
+ SYNCHRONIZE, 0, self.manager_pid
45
+ )
46
+
47
+ if not self.manager_handle:
48
+ raise ctypes.WinError(ctypes.get_last_error())
49
+
50
+ self.manager_dead = False
51
+
52
+ def is_alive(self):
53
+ if not self.manager_dead:
54
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
55
+ self.manager_dead = (
56
+ self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
57
+ )
58
+ return not self.manager_dead
59
+
60
+
61
+ else:
62
+
63
+ class ManagerWatchdog(object):
64
+ def __init__(self):
65
+ self.manager_pid = os.getppid()
66
+ self.manager_dead = False
67
+
68
+ def is_alive(self):
69
+ if not self.manager_dead:
70
+ self.manager_dead = os.getppid() != self.manager_pid
71
+ return not self.manager_dead
72
+
73
+
74
+ _worker_info = None
75
+
76
+
77
+ class WorkerInfo(object):
78
+ __initialized = False
79
+
80
+ def __init__(self, **kwargs):
81
+ for k, v in kwargs.items():
82
+ setattr(self, k, v)
83
+ self.__initialized = True
84
+
85
+ def __setattr__(self, key, val):
86
+ if self.__initialized:
87
+ raise RuntimeError(
88
+ "Cannot assign attributes to {} objects".format(self.__class__.__name__)
89
+ )
90
+ return super(WorkerInfo, self).__setattr__(key, val)
91
+
92
+
93
+ def get_worker_info():
94
+ r"""Returns the information about the current
95
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
96
+
97
+ When called in a worker, this returns an object guaranteed to have the
98
+ following attributes:
99
+
100
+ * :attr:`id`: the current worker id.
101
+ * :attr:`num_workers`: the total number of workers.
102
+ * :attr:`seed`: the random seed set for the current worker. This value is
103
+ determined by main process RNG and the worker id. See
104
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
105
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
106
+ that this will be a different object in a different process than the one
107
+ in the main process.
108
+
109
+ When called in the main process, this returns ``None``.
110
+
111
+ .. note::
112
+ When used in a :attr:`worker_init_fn` passed over to
113
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
114
+ set up each worker process differently, for instance, using ``worker_id``
115
+ to configure the ``dataset`` object to only read a specific fraction of a
116
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
117
+ code (e.g., NumPy).
118
+ """
119
+ return _worker_info
120
+
121
+
122
+ r"""Dummy class used to signal the end of an IterableDataset"""
123
+ _IterableDatasetStopIteration = namedtuple(
124
+ "_IterableDatasetStopIteration", ["worker_id"]
125
+ )
126
+
127
+
128
+ def worker_loop(
129
+ dataset_kind,
130
+ dataset,
131
+ index_queue,
132
+ data_queue,
133
+ done_event,
134
+ auto_collation,
135
+ collate_fn,
136
+ drop_last,
137
+ seed,
138
+ init_fn,
139
+ worker_id,
140
+ num_workers,
141
+ ):
142
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
143
+ # logic of this function.
144
+
145
+ try:
146
+ # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
147
+ # module's handlers are executed after Python returns from C low-level
148
+ # handlers, likely when the same fatal signal had already happened
149
+ # again.
150
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
151
+ signal_handling._set_worker_signal_handlers()
152
+
153
+ torch.set_num_threads(1)
154
+ random.seed(seed)
155
+ torch.manual_seed(seed)
156
+
157
+ global _worker_info
158
+ _worker_info = WorkerInfo(
159
+ id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
160
+ )
161
+
162
+ from torch.utils.data import _DatasetKind
163
+
164
+ init_exception = None
165
+
166
+ try:
167
+ if init_fn is not None:
168
+ init_fn(worker_id)
169
+
170
+ fetcher = _DatasetKind.create_fetcher(
171
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
172
+ )
173
+ except Exception:
174
+ init_exception = ExceptionWrapper(
175
+ where="in DataLoader worker process {}".format(worker_id)
176
+ )
177
+
178
+ # When using Iterable mode, some worker can exit earlier than others due
179
+ # to the IterableDataset behaving differently for different workers.
180
+ # When such things happen, an `_IterableDatasetStopIteration` object is
181
+ # sent over to the main process with the ID of this worker, so that the
182
+ # main process won't send more tasks to this worker, and will send
183
+ # `None` to this worker to properly exit it.
184
+ #
185
+ # Note that we cannot set `done_event` from a worker as it is shared
186
+ # among all processes. Instead, we set the `iteration_end` flag to
187
+ # signify that the iterator is exhausted. When either `done_event` or
188
+ # `iteration_end` is set, we skip all processing step and just wait for
189
+ # `None`.
190
+ iteration_end = False
191
+
192
+ watchdog = ManagerWatchdog()
193
+
194
+ while watchdog.is_alive():
195
+ try:
196
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
197
+ except queue.Empty:
198
+ continue
199
+ if r is None:
200
+ # Received the final signal
201
+ assert done_event.is_set() or iteration_end
202
+ break
203
+ elif done_event.is_set() or iteration_end:
204
+ # `done_event` is set. But I haven't received the final signal
205
+ # (None) yet. I will keep continuing until get it, and skip the
206
+ # processing steps.
207
+ continue
208
+ idx, index = r
209
+ """ Added """
210
+ MyRandomResizedCrop.sample_image_size(idx)
211
+ """ Added """
212
+ if init_exception is not None:
213
+ data = init_exception
214
+ init_exception = None
215
+ else:
216
+ try:
217
+ data = fetcher.fetch(index)
218
+ except Exception as e:
219
+ if (
220
+ isinstance(e, StopIteration)
221
+ and dataset_kind == _DatasetKind.Iterable
222
+ ):
223
+ data = _IterableDatasetStopIteration(worker_id)
224
+ # Set `iteration_end`
225
+ # (1) to save future `next(...)` calls, and
226
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
227
+ iteration_end = True
228
+ else:
229
+ # It is important that we don't store exc_info in a variable.
230
+ # `ExceptionWrapper` does the correct thing.
231
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
232
+ data = ExceptionWrapper(
233
+ where="in DataLoader worker process {}".format(worker_id)
234
+ )
235
+ data_queue.put((idx, data))
236
+ del data, idx, index, r # save memory
237
+ except KeyboardInterrupt:
238
+ # Main process will raise KeyboardInterrupt anyways.
239
+ pass
240
+ if done_event.is_set():
241
+ data_queue.cancel_join_thread()
242
+ data_queue.close()