import json
import os
from argparse import ArgumentParser, Namespace

import torch
from tqdm import tqdm

from datasets import NWPUTest, Resize2Multiple
from models import get_model
from utils import get_config, sliding_window_predict

current_dir = os.path.abspath(os.path.dirname(__file__))
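
# This script writes a submission file of predicted counts for the NWPU-Crowd
# test split, whose ground-truth annotations are withheld by the benchmark.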
parser = ArgumentParser(description="Test a trained model on the NWPU-Crowd test set.")
# Parameters for model
parser.add_argument("--model", type=str, default="vgg19_ae", help="The model to train.")
parser.add_argument("--input_size", type=int, default=448, help="The size of the input image.")
parser.add_argument("--reduction", type=int, default=8, choices=[8, 16, 32], help="The reduction factor of the model.")
parser.add_argument("--regression", action="store_true", help="Use blockwise regression instead of classification.")
parser.add_argument("--truncation", type=int, default=None, help="The truncation of the count.")
parser.add_argument("--anchor_points", type=str, default="average", choices=["average", "middle"], help="The representative count values of bins.")
parser.add_argument("--prompt_type", type=str, default="word", choices=["word", "number"], help="The prompt type for CLIP.")
parser.add_argument("--granularity", type=str, default="fine", choices=["fine", "dynamic", "coarse"], help="The granularity of bins.")
parser.add_argument("--num_vpt", type=int, default=32, help="The number of visual prompt tokens.")
parser.add_argument("--vpt_drop", type=float, default=0.0, help="The dropout rate for visual prompt tokens.")
parser.add_argument("--shallow_vpt", action="store_true", help="Use shallow visual prompt tokens.")
parser.add_argument("--weight_path", type=str, required=True, help="The path to the weights of the model.")
# Parameters for evaluation
parser.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.")
parser.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.")
parser.add_argument("--window_size", type=int, default=None, help="The window size for in prediction.")
parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
parser.add_argument("--zero_pad_to_multiple", action="store_true", help="Zero pad the image to the nearest multiple of the input size.")
parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.")
parser.add_argument("--num_workers", type=int, default=4, help="The number of workers for the data loader.")


def main(args: Namespace):
    print("Testing a trained model on the NWPU-Crowd test set.")
    device = torch.device(args.device)
    _ = get_config(vars(args).copy(), mute=False)
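
    # Classification models discretize counts into bins with representative
    # anchor-point values loaded from a reduction-specific config; regression
    # models predict counts directly and need neither.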
    if args.regression:
        bins, anchor_points = None, None
    else:
        with open(os.path.join(current_dir, "configs", f"reduction_{args.reduction}.json"), "r") as f:
            config = json.load(f)[str(args.truncation)]["nwpu"]
        bins = [(float(b[0]), float(b[1])) for b in config["bins"][args.granularity]]
        anchor_points = [float(p) for p in config["anchor_points"][args.granularity][args.anchor_points]]
    args.bins = bins
    args.anchor_points = anchor_points
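
    # These hyperparameters must match the ones used at training time, or
    # loading the checkpoint below with strict=True will fail.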
    model = get_model(
        backbone=args.model,
        input_size=args.input_size,
        reduction=args.reduction,
        bins=bins,
        anchor_points=anchor_points,
        prompt_type=args.prompt_type,
        num_vpt=args.num_vpt,
        vpt_drop=args.vpt_drop,
        deep_vpt=not args.shallow_vpt,
    )
    # Checkpoints whose filename contains "best" are assumed to hold the bare
    # state dict; other checkpoints wrap it under "model_state_dict".
    state_dict = torch.load(args.weight_path, map_location="cpu")
    if "best" not in os.path.basename(args.weight_path):
        state_dict = state_dict["model_state_dict"]
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()
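
    # Optional sliding-window evaluation: predict on fixed-size crops taken at
    # the given stride and merge the results, instead of a single full-image
    # forward pass.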
    sliding_window = args.sliding_window
    if args.sliding_window:
        window_size = args.input_size if args.window_size is None else args.window_size
        stride = window_size // 2 if args.stride is None else args.stride
        transforms = Resize2Multiple(base=args.input_size) if args.resize_to_multiple else None
    else:
        window_size, stride = None, None
        transforms = None
    dataset = NWPUTest(transforms=transforms, return_filename=True)
    image_ids = []
    preds = []
    for idx in tqdm(range(len(dataset)), desc="Testing on NWPU"):
        image, image_path = dataset[idx]
        image = image.unsqueeze(0)  # add batch dimension
        image = image.to(device)
        with torch.no_grad():
            if sliding_window:
                pred_density = sliding_window_predict(model, image, window_size, stride)
            else:
                pred_density = model(image)
        pred_count = pred_density.sum(dim=(1, 2, 3)).item()  # integrate the density map
        image_ids.append(os.path.basename(image_path).split(".")[0])
        preds.append(pred_count)
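
    # Save predictions as one "<image_id> <count>" line per image, the format
    # expected for NWPU-Crowd test submissions.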
    result_dir = os.path.join(current_dir, "nwpu_test_results")
    os.makedirs(result_dir, exist_ok=True)
    weights_dir, weights_name = os.path.split(args.weight_path)
    model_name = os.path.split(weights_dir)[-1]
    result_path = os.path.join(result_dir, f"{model_name}_{weights_name.split('.')[0]}.txt")
    with open(result_path, "w") as f:
        for idx, (image_id, pred) in enumerate(zip(image_ids, preds)):
            if idx != len(image_ids) - 1:
                f.write(f"{image_id} {pred}\n")
            else:
                f.write(f"{image_id} {pred}")  # no newline at the end of the file


if __name__ == "__main__":
    args = parser.parse_args()
    args.model = args.model.lower()
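
    # Clear arguments that do not apply to the chosen model/head so the
    # configuration recorded by get_config reflects what is actually used.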
    if args.regression:
        args.truncation = None
        args.anchor_points = None
        args.bins = None
        args.prompt_type = None
        args.granularity = None
    if "clip_vit" not in args.model:
        args.num_vpt = None
        args.vpt_drop = None
        args.shallow_vpt = None
    if "clip" not in args.model:
        args.prompt_type = None
    if args.sliding_window:
        args.window_size = args.input_size if args.window_size is None else args.window_size
        args.stride = args.input_size if args.stride is None else args.stride
        assert not (args.zero_pad_to_multiple and args.resize_to_multiple), "Cannot use both zero pad and resize to multiple."
    else:
        args.window_size = None
        args.stride = None
        args.zero_pad_to_multiple = False
        args.resize_to_multiple = False
    main(args)

# Example usage:
# python test_nwpu.py --model vgg19_ae --truncation 4 --weight_path ./checkpoints/sha/vgg19_ae_448_4_1.0_dmcount_aug/best_mae.pth --device cuda:0
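#
# Sliding-window variant (illustrative; reuses the checkpoint path above,
# substitute your own weights):
# python test_nwpu.py --model vgg19_ae --truncation 4 --sliding_window --stride 224 --weight_path ./checkpoints/sha/vgg19_ae_448_4_1.0_dmcount_aug/best_mae.pth --device cuda:0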