import torch
from torch import nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler

from argparse import ArgumentParser, Namespace
import os, json

current_dir = os.path.abspath(os.path.dirname(__file__))

from datasets import standardize_dataset_name
from models import get_model

from utils import setup, cleanup, init_seeds, get_logger, get_config, barrier
from utils import get_dataloader, get_loss_fn, get_optimizer, load_checkpoint, save_checkpoint
from utils import get_writer, update_train_result, update_eval_result, log
from train import train
from eval import evaluate


parser = ArgumentParser(description="Train an EBC model.")
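
# Example invocation (illustrative only: the script name, dataset name, and
# truncation value are placeholders that must match your local setup):
#   python main.py --dataset <dataset> --model vgg19_ae --truncation 4 --amp

# Model configuration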
parser.add_argument("--model", type=str, default="vgg19_ae", help="The model to train.")
parser.add_argument("--input_size", type=int, default=448, help="The size of the input image.")
parser.add_argument("--reduction", type=int, default=8, choices=[8, 16, 32], help="The reduction factor of the model.")
parser.add_argument("--regression", action="store_true", help="Use blockwise regression instead of classification.")
parser.add_argument("--truncation", type=int, default=None, help="The truncation of the count.")
parser.add_argument("--anchor_points", type=str, default="average", choices=["average", "middle"], help="The representative count values of bins.")
parser.add_argument("--prompt_type", type=str, default="word", choices=["word", "number"], help="The prompt type for CLIP.")
parser.add_argument("--granularity", type=str, default="fine", choices=["fine", "dynamic", "coarse"], help="The granularity of bins.")
parser.add_argument("--num_vpt", type=int, default=32, help="The number of visual prompt tokens.")
parser.add_argument("--vpt_drop", type=float, default=0.0, help="The dropout rate for visual prompt tokens.")
parser.add_argument("--shallow_vpt", action="store_true", help="Use shallow visual prompt tokens.")
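
# Data and augmentation settings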
parser.add_argument("--dataset", type=str, required=True, help="The dataset to train on.")
parser.add_argument("--batch_size", type=int, default=8, help="The training batch size.")
parser.add_argument("--num_crops", type=int, default=1, help="The number of crops for multi-crop training.")
parser.add_argument("--min_scale", type=float, default=1.0, help="The minimum scale for random scale augmentation.")
parser.add_argument("--max_scale", type=float, default=2.0, help="The maximum scale for random scale augmentation.")
parser.add_argument("--brightness", type=float, default=0.1, help="The brightness factor for random color jitter augmentation.")
parser.add_argument("--contrast", type=float, default=0.1, help="The contrast factor for random color jitter augmentation.")
parser.add_argument("--saturation", type=float, default=0.1, help="The saturation factor for random color jitter augmentation.")
parser.add_argument("--hue", type=float, default=0.0, help="The hue factor for random color jitter augmentation.")
parser.add_argument("--kernel_size", type=int, default=5, help="The kernel size for Gaussian blur augmentation.")
parser.add_argument("--saltiness", type=float, default=1e-3, help="The saltiness for salt-and-pepper noise augmentation.")
parser.add_argument("--spiciness", type=float, default=1e-3, help="The spiciness for salt-and-pepper noise augmentation.")
parser.add_argument("--jitter_prob", type=float, default=0.2, help="The probability for random color jitter augmentation.")
parser.add_argument("--blur_prob", type=float, default=0.2, help="The probability for Gaussian blur augmentation.")
parser.add_argument("--noise_prob", type=float, default=0.5, help="The probability for salt-and-pepper noise augmentation.")
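
# Evaluation settings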
parser.add_argument("--sliding_window", action="store_true", help="Use the sliding window strategy for evaluation.")
parser.add_argument("--stride", type=int, default=None, help="The stride for the sliding window strategy.")
parser.add_argument("--window_size", type=int, default=None, help="The window size for prediction.")
parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
parser.add_argument("--zero_pad_to_multiple", action="store_true", help="Zero-pad the image to the nearest multiple of the input size.")
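
# Loss settings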
parser.add_argument("--weight_count_loss", type=float, default=1.0, help="The weight for the count loss.")
parser.add_argument("--count_loss", type=str, default="mae", choices=["mae", "mse", "dmcount"], help="The loss function for the count.")
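
# Optimizer settings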
parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate.")
parser.add_argument("--weight_decay", type=float, default=1e-4, help="The weight decay.")
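
# Scheduler settings (a warmup phase, then what appears to be cosine annealing
# with warm restarts, judging by the T_0/T_mult/eta_min parameters)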
parser.add_argument("--warmup_epochs", type=int, default=50, help="Number of epochs for warmup. The learning rate will increase from eta_min to lr.")
parser.add_argument("--warmup_lr", type=float, default=1e-6, help="Learning rate for warmup.")
parser.add_argument("--T_0", type=int, default=5, help="Number of epochs until the first restart.")
parser.add_argument("--T_mult", type=int, default=2, help="The factor by which T_0 increases after each restart.")
parser.add_argument("--eta_min", type=float, default=1e-7, help="Minimum learning rate.")
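
# Training schedule and runtime settings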
parser.add_argument("--total_epochs", type=int, default=2600, help="Number of epochs to train.")
parser.add_argument("--eval_start", type=int, default=50, help="Start evaluating after this many epochs.")
parser.add_argument("--eval_freq", type=int, default=1, help="Evaluate every this many epochs.")
parser.add_argument("--save_freq", type=int, default=5, help="Save a checkpoint every this many epochs. A larger value can reduce I/O.")
parser.add_argument("--save_best_k", type=int, default=3, help="Keep the best k checkpoints.")
parser.add_argument("--amp", action="store_true", help="Use automatic mixed precision training.")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")


def run(local_rank: int, nprocs: int, args: Namespace) -> None:
    """Entry point for one training process (one per GPU under DDP)."""
    print(f"Rank {local_rank} process among {nprocs} processes.")
    init_seeds(args.seed + local_rank)
    setup(local_rank, nprocs)
    print(f"Initialized successfully. Training with {nprocs} GPU(s).")
    device = f"cuda:{local_rank}" if local_rank != -1 else "cuda:0"
    print(f"Using device: {device}.")

    ddp = nprocs > 1
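
    # Classification models need bin edges and per-bin anchor points from the
    # reduction-specific config file; pure regression models need neither.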
    if args.regression:
        bins, anchor_points = None, None
    else:
        with open(os.path.join(current_dir, "configs", f"reduction_{args.reduction}.json"), "r") as f:
            config = json.load(f)[str(args.truncation)][args.dataset]
        bins = config["bins"][args.granularity]
        anchor_points = config["anchor_points"][args.granularity][args.anchor_points]
        bins = [(float(b[0]), float(b[1])) for b in bins]
        anchor_points = [float(p) for p in anchor_points]

    args.bins = bins
    args.anchor_points = anchor_points

    model = get_model(
        backbone=args.model,
        input_size=args.input_size,
        reduction=args.reduction,
        bins=bins,
        anchor_points=anchor_points,
        prompt_type=args.prompt_type,
        num_vpt=args.num_vpt,
        vpt_drop=args.vpt_drop,
        deep_vpt=not args.shallow_vpt,
    ).to(device)

    grad_scaler = GradScaler() if args.amp else None

    loss_fn = get_loss_fn(args).to(device)
    optimizer, scheduler = get_optimizer(args, model)
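
    # The checkpoint directory name encodes the key hyperparameters of the run.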
    ckpt_dir_name = f"{args.model}_{args.prompt_type}_" if "clip" in args.model else f"{args.model}_"
    ckpt_dir_name += f"{args.input_size}_{args.reduction}_{args.truncation}_{args.granularity}_"
    ckpt_dir_name += f"{args.weight_count_loss}_{args.count_loss}"

    args.ckpt_dir = os.path.join(current_dir, "checkpoints", args.dataset, ckpt_dir_name)
    os.makedirs(args.ckpt_dir, exist_ok=True)
    model, optimizer, scheduler, grad_scaler, start_epoch, loss_info, hist_val_scores, best_val_scores = load_checkpoint(
        args, model, optimizer, scheduler, grad_scaler
    )
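
    # Only rank 0 writes logs and summaries and runs validation.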
    if local_rank == 0:
        model_without_ddp = model
        writer = get_writer(args.ckpt_dir)
        logger = get_logger(os.path.join(args.ckpt_dir, "train.log"))
        logger.info(get_config(vars(args), mute=False))
        val_loader = get_dataloader(args, split="val", ddp=False)
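
    # Split the global batch size and worker count evenly across processes.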
    args.batch_size = int(args.batch_size / nprocs)
    args.num_workers = int(args.num_workers / nprocs)
    train_loader, sampler = get_dataloader(args, split="train", ddp=ddp)

    model = DDP(nn.SyncBatchNorm.convert_sync_batchnorm(model), device_ids=[local_rank], output_device=local_rank) if ddp else model

    for epoch in range(start_epoch, args.total_epochs + 1):
        if local_rank == 0:
            message = f"\tlr: {optimizer.param_groups[0]['lr']:.3e}"
            log(logger, epoch, args.total_epochs, message=message)

        if sampler is not None:
            sampler.set_epoch(epoch)

        model, optimizer, grad_scaler, loss_info = train(model, train_loader, loss_fn, optimizer, grad_scaler, device, local_rank, nprocs)
        scheduler.step()
        barrier(ddp)

        if local_rank == 0:
            do_eval = (epoch >= args.eval_start) and ((epoch - args.eval_start) % args.eval_freq == 0)
            update_train_result(epoch, loss_info, writer)
            log(logger, None, None, loss_info=loss_info, message="\n" * 2 if not do_eval else None)
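
            # Evaluation runs on an un-wrapped copy of the model that is loaded
            # with the current training weights.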
            if do_eval:
                print("Evaluating")
                state_dict = model.module.state_dict() if ddp else model.state_dict()
                model_without_ddp.load_state_dict(state_dict)
                curr_val_scores = evaluate(
                    model_without_ddp,
                    val_loader,
                    device,
                    args.sliding_window,
                    args.input_size,
                    args.stride,
                )
                hist_val_scores, best_val_scores = update_eval_result(epoch, curr_val_scores, hist_val_scores, best_val_scores, writer, state_dict, args.ckpt_dir)
                log(logger, None, None, None, curr_val_scores, best_val_scores, message="\n" * 3)
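
            # Checkpoint only every save_freq epochs to reduce I/O.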
            if epoch % args.save_freq == 0:
                save_checkpoint(
                    epoch + 1,
                    model.module.state_dict() if ddp else model.state_dict(),
                    optimizer.state_dict(),
                    scheduler.state_dict() if scheduler is not None else None,
                    grad_scaler.state_dict() if grad_scaler is not None else None,
                    loss_info,
                    hist_val_scores,
                    best_val_scores,
                    args.ckpt_dir,
                )

        barrier(ddp)

    if local_rank == 0:
        writer.close()
        print("Training completed. Best scores:")
        for k, v in best_val_scores.items():
            scores = "; ".join(f"{s:.4f}" for s in v)
            print(f"    {k}: {scores}")

    cleanup(ddp)


def main():
    args = parser.parse_args()
    args.model = args.model.lower()
    args.dataset = standardize_dataset_name(args.dataset)

    if args.regression:
        args.truncation = None
        args.anchor_points = None
        args.bins = None
        args.prompt_type = None
        args.granularity = None
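
    # Visual prompt tuning options only apply to CLIP ViT backbones, and the
    # prompt type only applies to CLIP models in general.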
    if "clip_vit" not in args.model:
        args.num_vpt = None
        args.vpt_drop = None
        args.shallow_vpt = None

    if "clip" not in args.model:
        args.prompt_type = None

    if args.sliding_window:
        args.window_size = args.input_size if args.window_size is None else args.window_size
        args.stride = args.input_size if args.stride is None else args.stride
        assert not (args.zero_pad_to_multiple and args.resize_to_multiple), "Cannot use both zero pad and resize to multiple."
    else:
        args.window_size = None
        args.stride = None
        args.zero_pad_to_multiple = False
        args.resize_to_multiple = False
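
    # Launch one process per visible GPU; otherwise train in a single process.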
    args.nprocs = torch.cuda.device_count()
    print(f"Using {args.nprocs} GPU(s).")
    if args.nprocs > 1:
        mp.spawn(run, nprocs=args.nprocs, args=(args.nprocs, args))
    else:
        run(0, 1, args)


if __name__ == "__main__":
    main()