# Uploaded to the Hugging Face Hub by smi08 ("Upload folder using huggingface_hub"),
# commit 188f311 (verified).
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from robust_loss.rslad import rslad_inner_loss,kl_loss
from robust_loss.trades import trades_loss
from attacks import create_attack
from proard.utils import (
get_net_info,
cross_entropy_loss_with_soft_target,
cross_entropy_with_label_smoothing,
)
from proard.utils import (
AverageMeter,
accuracy,
write_log,
mix_images,
mix_labels,
init_models,
)
from proard.utils import MyRandomResizedCrop
# Public API of this module: only RunManager is exported by `from ... import *`.
__all__ = ["RunManager"]
class RunManager:
    """Drive training, (adversarial) validation, checkpointing and logging for a network.

    On construction the network is moved to CUDA when available (unless
    ``no_gpu``), optionally initialized via ``init_models``, summarized into
    ``<path>/net_info.txt``, given an optimizer built by
    ``run_config.build_optimizer`` and finally wrapped in
    ``torch.nn.DataParallel``.
    """

    def __init__(
        self, path, net, run_config, init=True, measure_latency=None, no_gpu=False
    ):
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.best_robustness = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device("cuda")
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")

        # initialize model (default)
        if init:
            init_models(net, run_config.model_init)

        # net info
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, True
        )
        with open("%s/net_info.txt" % self.path, "w") as fout:
            fout.write(json.dumps(net_info, indent=4) + "\n")
            # noinspection PyBroadException
            try:
                fout.write(self.network.module_str + "\n")
            except Exception:
                pass  # best effort: not every network exposes `module_str`
            fout.write("%s\n" % self.run_config.data_provider.train.dataset.transform)
            fout.write("%s\n" % self.run_config.data_provider.test.dataset.transform)
            fout.write("%s\n" % self.network)

        # criteria supplied by the run config
        self.train_criterion = self.run_config.train_criterion_loss
        self.test_criterion = self.run_config.test_criterion_loss
        self.kd_criterion = self.run_config.kd_criterion_loss

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split("#")
            net_params = [
                # parameters with weight decay
                self.network.get_parameters(keys, mode="exclude"),
                # parameters without weight decay
                self.network.get_parameters(keys, mode="include"),
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                # fall back to every trainable parameter
                net_params = [
                    param for param in self.network.parameters() if param.requires_grad
                ]
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

    # ------------------------------------------------------------------
    # save path and log path
    # ------------------------------------------------------------------
    @property
    def save_path(self):
        """``<path>/checkpoint``, created on first access and cached."""
        if self.__dict__.get("_save_path", None) is None:
            save_path = os.path.join(self.path, "checkpoint")
            os.makedirs(save_path, exist_ok=True)
            self.__dict__["_save_path"] = save_path
        return self.__dict__["_save_path"]

    @property
    def logs_path(self):
        """``<path>/logs``, created on first access and cached."""
        if self.__dict__.get("_logs_path", None) is None:
            logs_path = os.path.join(self.path, "logs")
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__["_logs_path"] = logs_path
        return self.__dict__["_logs_path"]

    @property
    def network(self):
        """The underlying module, unwrapped from ``nn.DataParallel`` if wrapped."""
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
        """Forward to the module-level ``write_log`` helper using ``self.logs_path``."""
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    # ------------------------------------------------------------------
    # save and load models
    # ------------------------------------------------------------------
    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        """Save ``checkpoint`` into ``save_path`` and record its path in ``latest.txt``.

        Args:
            checkpoint: dict to save; defaults to ``{"state_dict": ...}`` of the
                unwrapped network. A ``"dataset"`` entry is added in place.
            is_best: also write the state dict to ``model_best.pth.tar``.
            model_name: file name inside ``save_path`` (default
                ``checkpoint.pth.tar``).
        """
        if checkpoint is None:
            checkpoint = {"state_dict": self.network.state_dict()}

        if model_name is None:
            model_name = "checkpoint.pth.tar"

        # add `dataset` info to the checkpoint
        checkpoint["dataset"] = self.run_config.dataset
        latest_fname = os.path.join(self.save_path, "latest.txt")
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, "w") as fout:
            fout.write(model_path + "\n")
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, "model_best.pth.tar")
            torch.save({"state_dict": checkpoint["state_dict"]}, best_path)

    def load_model(self, model_fname=None):
        """Load a checkpoint into the network (and optimizer / counters when present).

        When ``model_fname`` is None the path recorded in ``latest.txt`` is used;
        if that is missing, empty or points to a non-existent file,
        ``<save_path>/checkpoint.pth.tar`` is tried instead.

        Returns:
            The loaded checkpoint dict, or ``{}`` if loading failed.
        """
        latest_fname = os.path.join(self.save_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                # an empty latest.txt yields "" here and falls back below
                model_fname = fin.readline().rstrip("\n")
        # noinspection PyBroadException
        try:
            if not model_fname or not os.path.exists(model_fname):
                model_fname = "%s/checkpoint.pth.tar" % self.save_path
                with open(latest_fname, "w") as fout:
                    fout.write(model_fname + "\n")
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location="cpu")
        except Exception:
            print("fail to load checkpoint from %s" % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
        if "best_acc" in checkpoint:
            self.best_acc = checkpoint["best_acc"]
        if "best_robustness" in checkpoint:
            self.best_robustness = checkpoint["best_robustness"]
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """Dump run_config and net_config to the model folder.

        ``run.config`` is written only if it does not exist yet. ``net.config``
        is best effort: networks without a ``config`` attribute are reported
        and skipped.
        """
        run_save_path = os.path.join(self.path, "run.config")
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            with open(run_save_path, "w") as fout:
                json.dump(run_config, fout, indent=4)
            print("Run configs dump to %s" % run_save_path)

        try:
            net_save_path = os.path.join(self.path, "net.config")
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            with open(net_save_path, "w") as fout:
                json.dump(net_config, fout, indent=4)
            print("Network configs dump to %s" % net_save_path)
        except Exception:
            print("%s do not support net config" % type(self.network))

    # ------------------------------------------------------------------
    # metric related
    # ------------------------------------------------------------------
    def get_metric_dict(self):
        """Fresh meters for clean (top1/top5) and adversarial (robust1/robust5) accuracy."""
        return {
            "top1": AverageMeter(),
            "top5": AverageMeter(),
            "robust1": AverageMeter(),
            "robust5": AverageMeter(),
        }

    def update_metric(self, metric_dict, output, output_adv, labels):
        """Accumulate top-1/top-5 accuracy of ``output`` and ``output_adv`` against ``labels``."""
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
        batch = output.size(0)
        metric_dict["top1"].update(acc1[0].item(), batch)
        metric_dict["top5"].update(acc5[0].item(), batch)
        metric_dict["robust1"].update(robust1[0].item(), batch)
        metric_dict["robust5"].update(robust5[0].item(), batch)

    def get_metric_vals(self, metric_dict, return_dict=False):
        """Averages of every meter, as a dict or as a list in insertion order."""
        if return_dict:
            return {key: metric_dict[key].avg for key in metric_dict}
        return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return "top1", "top5", "robust1", "robust5"

    # ------------------------------------------------------------------
    # train and test
    # ------------------------------------------------------------------
    def validate(
        self,
        epoch=0,
        is_test=False,
        run_str="",
        net=None,
        data_loader=None,
        no_logs=False,
        train_mode=False,
    ):
        """Evaluate clean and, in robust mode, adversarial accuracy.

        Args:
            epoch: used only for the progress-bar label.
            is_test: use ``run_config.test_loader`` instead of ``valid_loader``
                when ``data_loader`` is not given.
            run_str: extra text for the progress-bar label.
            net: network to evaluate (default ``self.net``); wrapped in
                ``nn.DataParallel`` if it is not already.
            data_loader: iterable of ``(images, labels)`` batches.
            no_logs: disable the tqdm progress bar.
            train_mode: put the network in train mode instead of eval mode.

        Returns:
            ``(avg_loss, [top1, top5, robust1, robust5])``. In robust mode the
            loss is cross-entropy on adversarial outputs; otherwise it is on
            clean outputs and the robust metrics equal the clean ones.
        """
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)

        if data_loader is None:
            data_loader = (
                self.run_config.test_loader if is_test else self.run_config.valid_loader
            )

        if train_mode:
            net.train()
        else:
            net.eval()

        robust_mode = self.run_config.robust_mode
        if robust_mode:
            # `.to(self.device)` (not `.cuda()`) so the CPU / no_gpu path works too
            eval_attack = create_attack(
                net,
                self.test_criterion.to(self.device),
                self.run_config.attack_type,
                self.run_config.epsilon_test,
                self.run_config.num_steps_test,
                self.run_config.step_size_test,
            )

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        with tqdm(
            total=len(data_loader),
            desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
            disable=no_logs,
        ) as t:
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output; evaluation forwards need no autograd graph
                with torch.no_grad():
                    output = net(images)
                if robust_mode:
                    # the attack needs input gradients, so it runs with grad enabled
                    with ctx_noparamgrad_and_eval(net):
                        images_adv, _ = eval_attack.perturb(images, labels)
                    with torch.no_grad():
                        output_adv = net(images_adv)
                    loss = F.cross_entropy(output_adv, labels)
                else:
                    output_adv = output
                    loss = F.cross_entropy(output, labels)
                # measure accuracy and record loss
                self.update_metric(metric_dict, output, output_adv, labels)
                losses.update(loss.item(), images.size(0))
                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                    }
                )
                t.update(1)
        return losses.avg, self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        """Validate at every configured image size (or the active one).

        When ``data_provider.image_size`` is a list, each size is activated in
        turn and BN running statistics are reset before validating.

        Returns:
            Six parallel lists: image sizes, losses, top1, top5, robust1, robust5.
        """
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list = [], []
            top1_list, top5_list, robust1_list, robust5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5, robust1, robust5) = self.validate(
                    epoch, is_test, net=net
                )
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
                robust1_list.append(robust1)
                robust5_list.append(robust5)
            return (
                img_size_list,
                loss_list,
                top1_list,
                top5_list,
                robust1_list,
                robust5_list,
            )
        loss, (top1, top5, robust1, robust5) = self.validate(epoch, is_test, net=net)
        return (
            [self.run_config.data_provider.active_img_size],
            [loss],
            [top1],
            [top5],
            [robust1],
            [robust5],
        )

    def _compute_train_loss(
        self, args, images, labels, output, soft_logits=None, soft_label=None
    ):
        """Return ``(loss, loss_type)`` for one training batch.

        Robust mode delegates to ``train_criterion`` (no teacher) or
        ``kd_criterion`` (teacher), both of which receive the network, batch
        and optimizer. Otherwise the loss is cross-entropy on ``output``, plus
        ``args.kd_ratio`` times a distillation term when a teacher is given
        (soft-target CE if ``args.kd_type == "ce"``, else MSE on logits).
        """
        rc = self.run_config
        if args.teacher_model is None:
            if rc.robust_mode:
                loss = self.train_criterion(
                    self.net,
                    images,
                    labels,
                    self.optimizer,
                    rc.step_size_train,
                    rc.epsilon_train,
                    rc.num_steps_train,
                    rc.beta_train,
                    rc.distance_train,
                )
                return loss, rc.train_criterion
            return F.cross_entropy(output, labels), "ce"

        if rc.robust_mode:
            loss = self.kd_criterion(
                args.teacher_model,
                self.net,
                images,
                labels,
                self.optimizer,
                rc.step_size_train,
                rc.epsilon_train,
                rc.num_steps_train,
                rc.beta_train,
            )
            return loss, rc.train_criterion

        if args.kd_type == "ce":
            kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
        else:
            kd_loss = F.mse_loss(output, soft_logits)
        ce_loss = F.cross_entropy(output, labels)
        return args.kd_ratio * kd_loss + ce_loss, "%.1fkd+ce" % args.kd_ratio

    def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
        """Train for one epoch over ``run_config.train_loader``.

        Args:
            args: object exposing ``teacher_model`` (or None), and for the
                non-robust KD path ``kd_type`` and ``kd_ratio``.
            epoch: current epoch index (0-based, warmup included).
            warmup_epochs: epochs using ``warmup_adjust_learning_rate``.
            warmup_lr: starting LR passed to the warmup schedule.

        Returns:
            ``(avg_loss, [top1, top5, robust1, robust5])``; here the robust
            metrics are computed on the clean outputs.
        """
        # switch to train mode
        self.net.train()
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(
            total=nBatch,
            desc="{} Train Epoch #{}".format(self.run_config.dataset, epoch + 1),
        ) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer,
                        warmup_epochs * nBatch,
                        nBatch,
                        epoch,
                        i,
                        warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(
                        self.optimizer, epoch - warmup_epochs, i, nBatch
                    )

                images, labels = images.to(self.device), labels.to(self.device)
                target = labels  # hard labels kept for metrics (before mixup)
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    lam = random.betavariate(
                        self.run_config.mixup_alpha, self.run_config.mixup_alpha
                    )
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels,
                        lam,
                        self.run_config.data_provider.n_classes,
                        self.run_config.label_smoothing,
                    )

                # soft target
                soft_logits = soft_label = None
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)
                loss, loss_type = self._compute_train_loss(
                    args, images, labels, output, soft_logits, soft_label
                )

                # compute gradient and do SGD step
                self.net.zero_grad()  # or self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss.item(), images.size(0))
                self.update_metric(metric_dict, output, output, target)

                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                        "lr": new_lr,
                        "loss_type": loss_type,
                        "data_time": data_time.avg,
                    }
                )
                t.update(1)
                end = time.time()
        return losses.avg, self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epoch=0, warmup_lr=0):
        """Full training loop with periodic validation and per-epoch checkpoints.

        Best clean accuracy decides ``model_best.pth.tar``; best clean accuracy
        and best robust accuracy are both stored in the checkpoint (as Python
        floats) so ``load_model`` can restore them.
        """
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
            train_loss, (train_top1, _, train_robust1, _) = self.train_one_epoch(
                args, epoch, warmup_epoch, warmup_lr
            )

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                (
                    img_size,
                    val_loss,
                    val_acc,
                    val_acc5,
                    val_robust,
                    _,
                ) = self.validate_all_resolution(epoch=epoch, is_test=False)

                mean_acc = float(np.mean(val_acc))
                mean_robust = float(np.mean(val_robust))
                is_best = mean_acc > self.best_acc
                self.best_acc = max(self.best_acc, mean_acc)
                self.best_robustness = max(self.best_robustness, mean_robust)
                metric_names = self.get_metric_names()
                val_log = "Valid [{0}/{1}]\tloss {2:.3f} \t{7} {3:.3f} ({5:.3f}) \t{8} {4:.3f} ({6:.3f})".format(
                    epoch + 1 - warmup_epoch,
                    self.run_config.n_epochs,
                    np.mean(val_loss),
                    mean_acc,
                    mean_robust,
                    self.best_acc,
                    self.best_robustness,
                    metric_names[0],
                    metric_names[2],
                )
                val_log += "\t{2} {0:.3f} \tTrain {1} {top1:.3f}\t {3} {robust:.3f} \t loss {train_loss:.3f}\t".format(
                    np.mean(val_acc5),
                    *metric_names,
                    top1=train_top1,
                    robust=train_robust1,
                    train_loss=train_loss,
                )
                for i_s, v_a in zip(img_size, val_acc):
                    val_log += "(%d, %.3f), " % (i_s, v_a)
                self.write_log(val_log, prefix="valid", should_print=False)
            else:
                is_best = False

            self.save_model(
                {
                    "epoch": epoch,
                    "best_acc": self.best_acc,
                    "best_robustness": self.best_robustness,
                    "optimizer": self.optimizer.state_dict(),
                    "state_dict": self.network.state_dict(),
                },
                is_best=is_best,
            )

    def reset_running_statistics(
        self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None
    ):
        """Recompute BN running statistics of ``net`` on a random training subset.

        Defaults to ``self.network`` and to
        ``run_config.random_sub_train_loader(subset_size, subset_batch_size)``.
        """
        from proard.classification.elastic_nn.utils import set_running_statistics

        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(
                subset_size, subset_batch_size
            )
        set_running_statistics(net, data_loader)