|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import torch |
|
|
import gdown |
|
|
|
|
|
from proard.classification.networks import get_net_by_name, ResNet50 |
|
|
from proard.classification.elastic_nn.networks import ( |
|
|
DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar |
|
|
) |
|
|
from proard.classification.networks import (WideResNet,ResNet50_Cifar,ResNet50,MobileNetV3_Cifar,MobileNetV3Large_Cifar,MobileNetV3Large,ProxylessNASNets_Cifar,ProxylessNASNets,MobileNetV2_Cifar,MobileNetV2) |
|
|
__all__ = [ |
|
|
"DYN_net", |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
def DYN_net(net_id, robust_mode, dataset,train_criterion, pretrained=True,run_config=None,WPS=False,base=False): |
|
|
if net_id == "ResNet50": |
|
|
if not base: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = DYNResNets_Cifar(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
depth_list=[0, 1, 2], |
|
|
expand_ratio_list=[0.2, 0.25, 0.35], |
|
|
width_mult_list=[0.65, 0.8, 1.0], |
|
|
) |
|
|
else: |
|
|
net = DYNResNets(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
depth_list=[0, 1, 2], |
|
|
expand_ratio_list=[0.2, 0.25, 0.35], |
|
|
width_mult_list=[0.65, 0.8, 1.0], |
|
|
) |
|
|
else: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
net = ResNet50(n_classes=run_config.data_provider.n_classes) |
|
|
|
|
|
elif net_id == "MBV3": |
|
|
if not base: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = MobileNetV3Large_Cifar(n_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
net = MobileNetV3Large(n_classes=run_config.data_provider.n_classes) |
|
|
elif net_id == "ProxylessNASNet": |
|
|
if not base: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = ProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
net = ProxylessNASNets(n_classes=run_config.data_provider.n_classes) |
|
|
elif net_id == "MBV2": |
|
|
if not base: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
base_stage_width="google", |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, |
|
|
dropout_rate=0, |
|
|
base_stage_width="google", |
|
|
width_mult=1.0, |
|
|
ks_list=[3, 5, 7], |
|
|
expand_ratio_list=[3, 4, 6], |
|
|
depth_list=[2, 3, 4], |
|
|
) |
|
|
else: |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = MobileNetV2_Cifar(n_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
net = MobileNetV2(n_classes=run_config.data_provider.n_classes) |
|
|
elif net_id == "WideResNet": |
|
|
if dataset == "cifar10" or dataset == "cifar100": |
|
|
net = WideResNet(num_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
raise ValueError("Not supported: %s" % net_id) |
|
|
|
|
|
else: |
|
|
raise ValueError("Not supported: %s" % net_id) |
|
|
|
|
|
if pretrained and not WPS and not base: |
|
|
if net_id == "ResNet50": |
|
|
if robust_mode: |
|
|
pt_path = "exp/robust/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
pt_path = "exp/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
if robust_mode: |
|
|
pt_path = "exp/robust/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
|
|
|
else: |
|
|
pt_path = "exp/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
elif pretrained and WPS and not base: |
|
|
if net_id == "ResNet50": |
|
|
if robust_mode: |
|
|
pt_path = "exp/robust/WPS/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
pt_path = "exp/WPS/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
if robust_mode: |
|
|
pt_path = "exp/robust/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
|
|
|
else: |
|
|
pt_path = "exp/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
if not base: |
|
|
pt_path = "exp/robust/teacher/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar" |
|
|
else: |
|
|
pt_path = "exp/robust/base/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar" |
|
|
print(pt_path) |
|
|
init = torch.load(pt_path, map_location="cuda")["state_dict"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
net.load_state_dict(init) |
|
|
return net |
|
|
|
|
|
|
|
|
|