# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
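"""Accuracy-dataset utilities for training an accuracy predictor.

Subnet settings sampled from a dynamic network are evaluated on the
validation set at several image sizes; the measured accuracies are stored
as JSON dictionaries and later converted into regression data loaders
(encoded architecture -> accuracy) for the predictor.
"""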
import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from proard.utils import list_mean
__all__ = ["net_setting2id", "net_id2setting", "AccuracyDataset"]

def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)
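
# Illustrative round trip (the exact setting keys depend on the dynamic network):
#   net_id = net_setting2id({"ks": [3, 5, 7], "e": [4, 6], "d": [2, 3, 4]})
#   net_id2setting(net_id)  -> {"ks": [3, 5, 7], "e": [4, 6], "d": [2, 3, 4]}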

class RegDataset(torch.utils.data.Dataset):
    """Simple in-memory dataset over paired (input, target) tensors,
    used as the regression data for the accuracy predictor."""

    def __init__(self, inputs, targets):
        super(RegDataset, self).__init__()
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return self.inputs.size(0)
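
# AccuracyDataset drives a three-step workflow:
#   1. build_acc_dataset     - evaluate sampled subnets at each image size
#   2. merge_acc_dataset     - merge the per-size accuracy dictionaries
#   3. build_acc_data_loader - turn (encoded arch, accuracy) pairs into
#      train/valid DataLoaders for the accuracy-predictor regressor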

class AccuracyDataset:
    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, "net_id.dict")

    @property
    def acc_src_folder(self):
        return os.path.join(self.path, "src")

    @property
    def acc_dict_path(self):
        return os.path.join(self.path, "src/acc.dict")
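
    # On-disk layout under `path`:
    #   net_id.dict      - sorted JSON list of sampled subnet ids
    #   src/<size>.dict  - {net_setting + image_size: top-1 accuracy} per image size
    #   src/acc.dict     - merged accuracy dictionary across image sizes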

    # TODO: support parallel building
    def build_acc_dataset(
        self, run_manager, dyn_network, n_arch=2000, image_size_list=None
    ):
        # load net_id_list; randomly sample n_arch subnets if it does not exist yet
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = dyn_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, "w"), indent=4)

        image_size_list = (
            [128, 160, 192, 224] if image_size_list is None else image_size_list
        )
        print(image_size_list)
        with tqdm(
            total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset"
        ) as t:
            for image_size in image_size_list:
                # load the validation set into memory at the active image size
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # path where this image size's accuracy dict is saved
                os.makedirs(self.acc_src_folder, exist_ok=True)
                acc_save_path = os.path.join(
                    self.acc_src_folder, "%d.dict" % image_size
                )
                acc_dict = {}
                # load existing acc dict
                if os.path.isfile(acc_save_path):
                    existing_acc_dict = json.load(open(acc_save_path, "r"))
                else:
                    existing_acc_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, "image_size": image_size})
                    # reuse accuracies that were already measured in a previous run
                    if key in existing_acc_dict:
                        acc_dict[key] = existing_acc_dict[key]
                        t.set_postfix(
                            {
                                "net_id": net_id,
                                "image_size": image_size,
                                "info_val": acc_dict[key],
                                "status": "loading",
                            }
                        )
                        t.update()
                        continue
                    dyn_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(dyn_network)
                    net_setting_str = ",".join(
                        [
                            "%s_%s"
                            % (
                                key,
                                "%.1f" % list_mean(val)
                                if isinstance(val, list)
                                else val,
                            )
                            for key, val in net_setting.items()
                        ]
                    )
                    loss, (top1, top5, robust1, robust5) = run_manager.validate(
                        run_str=net_setting_str,
                        net=dyn_network,
                        data_loader=val_dataset,
                        no_logs=True,
                    )
                    # only the clean top-1 accuracy is stored in the dataset
                    info_val = top1
                    t.set_postfix(
                        {
                            "net_id": net_id,
                            "image_size": image_size,
                            "info_val": info_val,
                        }
                    )
                    t.update()
                    acc_dict.update({key: info_val})
                    # rewrite the per-size dict after every architecture so an
                    # interrupted build can be resumed
                    json.dump(acc_dict, open(acc_save_path, "w"), indent=4)

    def merge_acc_dataset(self, image_size_list=None):
        # load existing data
        merged_acc_dict = {}
        for fname in os.listdir(self.acc_src_folder):
            if ".dict" not in fname:
                continue
            # skip files not named by an image size (e.g. a previously merged acc.dict)
            if not fname.split(".dict")[0].isdigit():
                continue
            image_size = int(fname.split(".dict")[0])
            if image_size_list is not None and image_size not in image_size_list:
                print("Skip ", fname)
                continue
            full_path = os.path.join(self.acc_src_folder, fname)
            partial_acc_dict = json.load(open(full_path))
            merged_acc_dict.update(partial_acc_dict)
            print("loaded %s" % full_path)
        json.dump(merged_acc_dict, open(self.acc_dict_path, "w"), indent=4)
        return merged_acc_dict
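
    # Each merged key embeds "image_size", so a single predictor can be trained
    # across all evaluated input resolutions.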

    def build_acc_data_loader(
        self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16
    ):
        # load data
        acc_dict = json.load(open(self.acc_dict_path))
        X_all = []
        Y_all = []
        with tqdm(total=len(acc_dict), desc="Loading data") as t:
            for k, v in acc_dict.items():
                dic = json.loads(k)
                X_all.append(arch_encoder.arch2feature(dic))
                Y_all.append(v / 100.0)  # accuracy rescaled to the range 0 - 1
                t.update()
        base_acc = np.mean(Y_all)
        # convert to torch tensors
        X_all = torch.tensor(X_all, dtype=torch.float)
        Y_all = torch.tensor(Y_all)
        # random shuffle
        shuffle_idx = torch.randperm(len(X_all))
        X_all = X_all[shuffle_idx]
        Y_all = Y_all[shuffle_idx]
        # split data: by default 80% train; the last 20% is always held out for validation
        idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
        val_idx = X_all.size(0) // 5 * 4
        X_train, Y_train = X_all[:idx], Y_all[:idx]
        X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
        print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test))
        # build data loaders
        train_dataset = RegDataset(X_train, Y_train)
        val_dataset = RegDataset(X_test, Y_test)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=n_workers,
        )
        valid_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=n_workers,
        )
        return train_loader, valid_loader, base_acc
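
# Usage sketch: `run_manager`, `dyn_network` and `arch_encoder` are assumed to be
# created elsewhere in proard (arch_encoder only needs an arch2feature() method);
# the directory path below is hypothetical.
#
#   acc_dataset = AccuracyDataset("exp/acc_dataset")
#   acc_dataset.build_acc_dataset(run_manager, dyn_network,
#                                 n_arch=2000, image_size_list=[128, 160, 192, 224])
#   acc_dataset.merge_acc_dataset()
#   train_loader, valid_loader, base_acc = acc_dataset.build_acc_data_loader(
#       arch_encoder, batch_size=256, n_workers=16
#   )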