import os
import random
import time
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm

from attacks.utils import ctx_noparamgrad_and_eval
from robust_loss.rslad import rslad_inner_loss, kl_loss
from robust_loss.trades import trades_loss
from attacks import create_attack
from proard.utils import (
    get_net_info,
    cross_entropy_loss_with_soft_target,
    cross_entropy_with_label_smoothing,
)
from proard.utils import (
    AverageMeter,
    accuracy,
    write_log,
    mix_images,
    mix_labels,
    init_models,
)
from proard.utils import MyRandomResizedCrop

__all__ = ["RunManager"]
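
# RunManager bundles the training loop, (optionally adversarial) validation,
# checkpointing, and logging for one network / run_config pair. It relies on the
# helpers imported above (AverageMeter, accuracy, create_attack, the robust
# losses) and on the run_config object supplied by the rest of the codebase.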


class RunManager:
    def __init__(
        self, path, net, run_config, init=True, measure_latency=None, no_gpu=False
    ):
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.best_robustness = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # use the GPU when available and not explicitly disabled
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device("cuda")
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")

        # initialize model weights
        if init:
            init_models(net, run_config.model_init)

        # dump the network profile (params / FLOPs / optional latency) and the data transforms
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, True
        )
        with open("%s/net_info.txt" % self.path, "w") as fout:
            fout.write(json.dumps(net_info, indent=4) + "\n")
            try:
                fout.write(self.network.module_str + "\n")
            except Exception:
                pass
            fout.write("%s\n" % self.run_config.data_provider.train.dataset.transform)
            fout.write("%s\n" % self.run_config.data_provider.test.dataset.transform)
            fout.write("%s\n" % self.network)

        # criteria
        self.train_criterion = self.run_config.train_criterion_loss
        self.test_criterion = self.run_config.test_criterion_loss
        self.kd_criterion = self.run_config.kd_criterion_loss

        # optimizer: optionally split parameters by the no-decay keys
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split("#")
            net_params = [
                self.network.get_parameters(keys, mode="exclude"),  # names excluding the keys
                self.network.get_parameters(keys, mode="include"),  # names including the keys
            ]
        else:
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

    """ save path and log path """

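    # `save_path` and `logs_path` lazily create "<path>/checkpoint" and
    # "<path>/logs" on first access and cache the resolved directory on the
    # instance; the `network` property always returns the unwrapped module,
    # even after the DataParallel wrapping done at the end of __init__.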
    @property
    def save_path(self):
        if self.__dict__.get("_save_path", None) is None:
            save_path = os.path.join(self.path, "checkpoint")
            os.makedirs(save_path, exist_ok=True)
            self.__dict__["_save_path"] = save_path
        return self.__dict__["_save_path"]

    @property
    def logs_path(self):
        if self.__dict__.get("_logs_path", None) is None:
            logs_path = os.path.join(self.path, "logs")
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__["_logs_path"] = logs_path
        return self.__dict__["_logs_path"]

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

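    # save_model records the latest checkpoint path in "latest.txt" so that
    # load_model can resume without an explicit filename; when is_best is set,
    # the weights are additionally snapshotted to "model_best.pth.tar".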
    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {"state_dict": self.network.state_dict()}

        if model_name is None:
            model_name = "checkpoint.pth.tar"

        checkpoint["dataset"] = self.run_config.dataset
        latest_fname = os.path.join(self.save_path, "latest.txt")
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, "w") as fout:
            fout.write(model_path + "\n")
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, "model_best.pth.tar")
            torch.save({"state_dict": checkpoint["state_dict"]}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline()
                if model_fname[-1] == "\n":
                    model_fname = model_fname[:-1]

        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = "%s/checkpoint.pth.tar" % self.save_path
                with open(latest_fname, "w") as fout:
                    fout.write(model_fname + "\n")
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location="cpu")
        except Exception:
            print("failed to load checkpoint from %s" % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
        if "best_acc" in checkpoint:
            self.best_acc = checkpoint["best_acc"]
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

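    # save_config dumps run.config (and, when the network supports it, net.config)
    # into the run directory so the experiment settings stay alongside the checkpoints.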
    def save_config(self, extra_run_config=None, extra_net_config=None):
        """dump run_config and net_config to the model_folder"""
        run_save_path = os.path.join(self.path, "run.config")
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, "w"), indent=4)
            print("Run configs dumped to %s" % run_save_path)

        try:
            net_save_path = os.path.join(self.path, "net.config")
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, "w"), indent=4)
            print("Network configs dumped to %s" % net_save_path)
        except Exception:
            print("%s does not support net config" % type(self.network))

    """ metric related """

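    # Metric meters: top-1/top-5 accuracy on clean inputs plus "robust1"/"robust5"
    # on adversarial inputs. When no adversarial output is available, callers pass
    # the clean output twice, so the robust meters simply mirror the clean ones.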
    def get_metric_dict(self):
        return {
            "top1": AverageMeter(),
            "top5": AverageMeter(),
            "robust1": AverageMeter(),
            "robust5": AverageMeter(),
        }

    def update_metric(self, metric_dict, output, output_adv, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
        metric_dict["top1"].update(acc1[0].item(), output.size(0))
        metric_dict["top5"].update(acc5[0].item(), output.size(0))
        metric_dict["robust1"].update(robust1[0].item(), output.size(0))
        metric_dict["robust5"].update(robust5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {key: metric_dict[key].avg for key in metric_dict}
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return "top1", "top5", "robust1", "robust5"

    """ train and test """

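    # validate() evaluates `net` on the validation (or test) loader. With
    # run_config.robust_mode enabled, each batch is additionally perturbed with
    # the attack built by create_attack(...) (configured via attack_type,
    # epsilon_test, num_steps_test, step_size_test), and the reported loss is the
    # cross-entropy on the adversarial examples.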
    def validate(
        self,
        epoch=0,
        is_test=False,
        run_str="",
        net=None,
        data_loader=None,
        no_logs=False,
        train_mode=False,
    ):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)
        if data_loader is None:
            data_loader = (
                self.run_config.test_loader if is_test else self.run_config.valid_loader
            )

        if train_mode:
            net.train()
        else:
            net.eval()
        if self.run_config.robust_mode:
            eval_attack = create_attack(
                net,
                self.test_criterion.to(self.device),  # criterion on the same device as the data
                self.run_config.attack_type,
                self.run_config.epsilon_test,
                self.run_config.num_steps_test,
                self.run_config.step_size_test,
            )
        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        with tqdm(
            total=len(data_loader),
            desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
            disable=no_logs,
        ) as t:
            for i, (images, labels) in enumerate(data_loader):
                images, labels = images.to(self.device), labels.to(self.device)

                output = net(images)
                if self.run_config.robust_mode:
                    # craft adversarial examples without tracking parameter gradients
                    with ctx_noparamgrad_and_eval(net):
                        images_adv, _ = eval_attack.perturb(images, labels)
                    output_adv = net(images_adv)
                    loss = nn.CrossEntropyLoss()(output_adv, labels)
                else:
                    output_adv = output
                    loss = nn.CrossEntropyLoss()(output, labels)

                self.update_metric(metric_dict, output, output_adv, labels)

                losses.update(loss.item(), images.size(0))
                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                    }
                )
                t.update(1)
        return losses.avg, self.get_metric_vals(metric_dict)

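    # validate_all_resolution() repeats validation for every image size exposed by
    # the data provider, resetting BN running statistics before each size (see
    # reset_running_statistics below), and returns per-size lists of losses and
    # clean/robust accuracies.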
    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list = [], []
            top1_list, top5_list, robust1_list, robust5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5, robust1, robust5) = self.validate(
                    epoch, is_test, net=net
                )
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
                robust1_list.append(robust1)
                robust5_list.append(robust5)
            return (
                img_size_list,
                loss_list,
                top1_list,
                top5_list,
                robust1_list,
                robust5_list,
            )
        else:
            loss, (top1, top5, robust1, robust5) = self.validate(epoch, is_test, net=net)
            return (
                [self.run_config.data_provider.active_img_size],
                [loss],
                [top1],
                [top5],
                [robust1],
                [robust5],
            )

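    # train_one_epoch() picks the training loss per batch:
    #   no teacher, robust_mode on   -> run_config.train_criterion_loss
    #                                   (an adversarial loss such as TRADES, see robust_loss)
    #   no teacher, robust_mode off  -> plain cross-entropy
    #   teacher,    robust_mode on   -> run_config.kd_criterion_loss
    #                                   (robust distillation such as RSLAD)
    #   teacher,    robust_mode off  -> kd_ratio * KD loss + cross-entropy
    # Optional mixup is applied to images/labels before the forward pass.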
    def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
        self.net.train()
        MyRandomResizedCrop.EPOCH = epoch

        nBatch = len(self.run_config.train_loader)

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(
            total=nBatch,
            desc="{} Train Epoch #{}".format(self.run_config.dataset, epoch + 1),
        ) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer,
                        warmup_epochs * nBatch,
                        nBatch,
                        epoch,
                        i,
                        warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(
                        self.optimizer, epoch - warmup_epochs, i, nBatch
                    )

                images, labels = images.to(self.device), labels.to(self.device)
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # mixup: blend images and soften labels with a Beta-sampled coefficient
                    lam = random.betavariate(
                        self.run_config.mixup_alpha, self.run_config.mixup_alpha
                    )
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels,
                        lam,
                        self.run_config.data_provider.n_classes,
                        self.run_config.label_smoothing,
                    )

                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                output = self.net(images)

                if args.teacher_model is None:
                    if self.run_config.robust_mode:
                        loss = self.train_criterion(
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                            self.run_config.distance_train,
                        )
                        loss_type = self.run_config.train_criterion
                    else:
                        loss = nn.CrossEntropyLoss()(output, labels)
                        loss_type = "ce"
                else:
                    if self.run_config.robust_mode:
                        loss = self.kd_criterion(
                            args.teacher_model,
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                        )
                        loss_type = self.run_config.train_criterion
                    else:
                        if args.kd_type == "ce":
                            kd_loss = cross_entropy_loss_with_soft_target(
                                output, soft_label
                            )
                        else:
                            kd_loss = F.mse_loss(output, soft_logits)
                        # base cross-entropy term, combined with the KD loss below
                        loss = nn.CrossEntropyLoss()(output, labels)
                        loss = args.kd_ratio * kd_loss + loss
                        loss_type = "%.1fkd+ce" % args.kd_ratio

                self.net.zero_grad()
                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), images.size(0))
                self.update_metric(metric_dict, output, output, target)

                t.set_postfix(
                    {
                        "loss": losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                        "lr": new_lr,
                        "loss_type": loss_type,
                        "data_time": data_time.avg,
                    }
                )
                t.update(1)
                end = time.time()
        return losses.avg, self.get_metric_vals(metric_dict)

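    # train() runs warmup_epoch warm-up epochs followed by run_config.n_epochs of
    # training; every validation_frequency epochs it validates across all active
    # resolutions, tracks the best clean/robust accuracy, logs a summary line,
    # and checkpoints (writing model_best.pth.tar when clean accuracy improves).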
    def train(self, args, warmup_epoch=0, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
            train_loss, (train_top1, train_top5, train_robust1, train_robust5) = self.train_one_epoch(
                args, epoch, warmup_epoch, warmup_lr
            )

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                img_size, val_loss, val_acc, val_acc5, val_robust, val_robust5 = self.validate_all_resolution(
                    epoch=epoch, is_test=False
                )

                is_best = np.mean(val_acc) > self.best_acc
                is_best_robust = np.mean(val_robust) > self.best_robustness
                self.best_acc = max(self.best_acc, np.mean(val_acc))
                self.best_robustness = max(self.best_robustness, np.mean(val_robust))
                val_log = "Valid [{0}/{1}]\tloss {2:.3f} \t{7} {3:.3f} ({5:.3f}) \t{8} {4:.3f} ({6:.3f})".format(
                    epoch + 1 - warmup_epoch,
                    self.run_config.n_epochs,
                    np.mean(val_loss),
                    np.mean(val_acc),
                    np.mean(val_robust),
                    self.best_acc,
                    self.best_robustness,
                    self.get_metric_names()[0],
                    self.get_metric_names()[2],
                )
                val_log += "\t{2} {0:.3f} \tTrain {1} {top1:.3f}\t {3} {robust:.3f} \t loss {train_loss:.3f}\t".format(
                    np.mean(val_acc5),
                    *self.get_metric_names(),
                    top1=train_top1,
                    robust=train_robust1,
                    train_loss=train_loss,
                )
                for i_s, v_a in zip(img_size, val_acc):
                    val_log += "(%d, %.3f), " % (i_s, v_a)
                self.write_log(val_log, prefix="valid", should_print=False)
            else:
                is_best = False
                is_best_robust = False

            self.save_model(
                {
                    "epoch": epoch,
                    "best_acc": self.best_acc,
                    "optimizer": self.optimizer.state_dict(),
                    "state_dict": self.network.state_dict(),
                },
                is_best=is_best,
            )

    def reset_running_statistics(
        self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None
    ):
        from proard.classification.elastic_nn.utils import set_running_statistics

        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(
                subset_size, subset_batch_size
            )
        set_running_statistics(net, data_loader)
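

# Minimal usage sketch (illustrative only: `build_run_config`, `build_network`,
# and the `args` namespace are hypothetical stand-ins for objects constructed
# elsewhere in the repo):
#
#     run_config = build_run_config(dataset="cifar10", robust_mode=True)
#     net = build_network()
#     manager = RunManager("exp/cifar10_robust", net, run_config, init=True)
#     manager.save_config()
#     manager.train(args)                      # args.teacher_model may be None
#     loss, (top1, top5, rob1, rob5) = manager.validate(is_test=True)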