# Uploaded by smi08 using huggingface_hub (commit 188f311, verified).
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import numpy as np
import torch
import torch.nn as nn
# Public API of this module: only the predictor class is exported by `import *`.
__all__ = ["AccuracyPredictor"]
class AccuracyPredictor(nn.Module):
    """MLP that maps an encoded architecture to a single scalar prediction.

    The network is ``n_layers`` blocks of ``Linear -> ReLU`` followed by a
    bias-free ``Linear(hidden_size, 1)``. A frozen ``base_acc`` parameter is
    added to the network output in :meth:`forward`.

    Args:
        arch_encoder: Object providing ``n_dim`` (input feature width) and
            ``arch2feature(arch_dict)`` (feature vector for one architecture).
        hidden_size: Width of every hidden ``Linear`` layer.
        n_layers: Number of hidden ``Linear -> ReLU`` blocks.
        checkpoint_path: Optional path to a checkpoint. It is loaded only if
            the path exists; a missing path is silently ignored.
        device: Device on which the layers and ``base_acc`` are placed.
        base_acc_val: Optional initial value for ``base_acc``. When ``None``,
            ``base_acc`` is a one-element zero tensor.
    """

    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=3,
        checkpoint_path=None,
        device="cuda:0",
        base_acc_val=None,
    ):
        super(AccuracyPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_acc_val = base_acc_val

        # Build the hidden stack. Each block stays wrapped in its own
        # nn.Sequential so state_dict keys (layers.<i>.0.weight, ...) match
        # checkpoints saved with this layout.
        in_features = self.arch_encoder.n_dim
        layers = []
        for _ in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_features, self.hidden_size),
                    nn.ReLU(inplace=True),
                )
            )
            in_features = self.hidden_size
        # Output head: one scalar per input row, no bias (base_acc is the offset).
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)

        # Identity check against None: `!= None` is evaluated elementwise for
        # array-like values and makes the `if` raise for multi-element arrays.
        if self.base_acc_val is not None:
            base_acc = torch.tensor(self.base_acc_val, device=self.device)
        else:
            base_acc = torch.zeros(1, device=self.device)
        # Registered as a parameter so it is saved/loaded with the state dict,
        # but frozen so optimizers do not update it.
        self.base_acc = nn.Parameter(base_acc, requires_grad=False)

        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            # Accept both a raw state dict and a {"state_dict": ...} wrapper.
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            self.load_state_dict(checkpoint)
            print("Loaded checkpoint from %s" % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        """Return ``self.layers(x)`` with size-1 dims squeezed, plus ``base_acc``."""
        # squeeze() drops every size-1 dimension, including a batch of one.
        y = self.layers(x).squeeze()
        return y + self.base_acc

    def predict_acc(self, arch_dict_list):
        """Encode each architecture dict and run the predictor on the batch.

        Args:
            arch_dict_list: Iterable of architecture descriptions accepted by
                ``arch_encoder.arch2feature``.

        Returns:
            The result of :meth:`forward` on the float feature tensor.
        """
        features = [
            self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list
        ]
        # Go through np.array first so a list of per-arch vectors becomes one
        # contiguous 2-D array before conversion to a tensor.
        features = torch.tensor(np.array(features)).float().to(self.device)
        return self.forward(features)