|
|
import torch |
|
|
from torch import nn |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import json, os |
|
|
import gradio as gr |
|
|
import torchvision.transforms.functional as TF |
|
|
from safetensors.torch import load_file |
|
|
from matplotlib import cm |
|
|
|
|
|
from models import get_model |
|
|
from utils import resize_density_map, init_seeds |
|
|
|
|
|
|
|
|
# Per-channel RGB normalisation statistics consumed by transform(); the values
# are the widely used ImageNet mean/std.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Overlay weight used by predict(): Image.blend scales the second image (the
# coloured density map) by alpha and the input photo by 1 - alpha.
alpha = 0.8

# Project helper from utils; presumably seeds the RNGs for reproducibility.
init_seeds(42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Count-bin configuration -------------------------------------------------
# truncation selects the top-level key of configs/reduction_<reduction>.json
# (None disables bins entirely); granularity and anchor_points choose the bin
# set and anchor-point variant ("average" or "middle") inside that entry.
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"

# --- Backbone ------------------------------------------------------------------
model_name = "clip_vit_l_14"  # forwarded to get_model as `backbone`
input_size = 224  # also the minimum side length enforced by transform()

# --- Prompt settings forwarded unchanged to get_model --------------------------
# NOTE(review): "vpt" presumably means visual prompt tuning — confirm in models/.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.0
deep_vpt = True

# Run on the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
if truncation is None:
    # Without truncation the model is built with no count bins.
    bins, anchor_points = None, None
else:
    # Before this point anchor_points names the variant to read; afterwards it
    # is rebound to the list of numeric anchor points. Reject unknown variants
    # instead of silently falling back to "middle".
    if anchor_points not in ("average", "middle"):
        raise ValueError(f"anchor_points must be 'average' or 'middle', got {anchor_points!r}")

    # The config is keyed by truncation, then dataset; the "nwpu" entry matches
    # the CLIP-EBC-ViT-L-14-NWPU checkpoint loaded below.
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r", encoding="utf-8") as f:
        config = json.load(f)[str(truncation)]["nwpu"]

    bins = [(float(b[0]), float(b[1])) for b in config["bins"][granularity]]
    anchor_points = [float(p) for p in config["anchor_points"][granularity][anchor_points]]
|
|
|
|
|
|
|
|
# Build the counting model from the module-level hyper-parameters; bins and
# anchor_points come from the config section (both None when truncation is None).
model = get_model(
    backbone=model_name, input_size=input_size, reduction=reduction,
    bins=bins, anchor_points=anchor_points,
    prompt_type=prompt_type, num_vpt=num_vpt, vpt_drop=vpt_drop, deep_vpt=deep_vpt,
)
|
|
weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
state_dict = load_file(weights_path)

# Checkpoint keys are expected to carry a leading "model." (presumably from a
# training-time wrapper). Strip only that prefix: str.replace would also rewrite
# any "model." occurring later in a key (e.g. "clip_model.w" -> "clip_w").
new_state_dict = {
    (key[len("model."):] if key.startswith("model.") else key): value
    for key, value in state_dict.items()
}

# Strict load: any key mismatch raises instead of leaving weights uninitialised.
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transform(image: Image.Image):
    """Convert a PIL image into a normalised, batched tensor for the model.

    The image is converted to RGB if needed and turned into a float tensor. If
    either side is shorter than ``input_size``, it is upscaled (bicubic,
    antialiased) so both sides are at least ``input_size``. It is then
    normalised with the module-level ``mean``/``std``.

    Args:
        image: input PIL image in any mode PIL can convert to RGB.

    Returns:
        Tensor of shape (1, 3, H, W).

    Raises:
        TypeError: if ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image")

    # mean/std have three entries, so RGBA / greyscale / palette images must be
    # converted first or TF.normalize fails on the channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = TF.to_tensor(image)

    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Scale by the larger ratio so the shorter side reaches input_size; the
        # +1 keeps int() truncation from landing just below the minimum.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(image: Image.Image):
    """Count people in *image* and render its density map as an overlay.

    The image is preprocessed with ``transform`` and passed through ``model`` on
    ``device``. The predicted density map is summed to give the count. For
    display, the density map is resized to the input's size, min-max normalised
    to [0, 1], coloured with the "jet" colormap and blended over the input with
    weight ``alpha``.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input; ``None``
            when Predict is clicked with no image loaded.

    Returns:
        A tuple ``(image, overlay, text)``:
        - the input image, unchanged;
        - the RGBA overlay image;
        - ``"Predicted Count: <count>"``, with the count formatted to two decimals.

    Raises:
        gr.Error: if ``image`` is ``None``; Gradio shows the message in the UI.
    """
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; the
    # colormap registry is the supported lookup.
    import matplotlib

    if image is None:
        raise gr.Error("Please upload an image first.")

    input_width, input_height = image.size  # PIL reports (width, height)
    input_tensor = transform(image).to(device)

    with torch.no_grad():
        density_map = model(input_tensor)

    # The count is the sum of the raw model output, taken before any resizing for display.
    total_count = density_map.sum().item()
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Min-max normalise to [0, 1] for the colormap; eps avoids 0/0 on a constant map.
    eps = 1e-8
    low, high = resized_density_map.min(), resized_density_map.max()
    density_map_norm = (resized_density_map - low) / (high - low + eps)

    colormap = matplotlib.colormaps["jet"]
    # The colormap yields float RGBA in [0, 1]; scale it to 8-bit for PIL.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Image.blend needs equal size and mode. Both images are RGBA here; this
    # assumes resize_density_map returns exactly (input_height, input_width).
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Page header: title and usage hint.
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        # Left column: image input plus the button that triggers inference.
        with gr.Column():
            input_img = gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image")
            submit_btn = gr.Button("Predict")

        # Right column: blended density overlay and the count read-out.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Predicted Density Map")
            output_text = gr.Textbox(label="Total Count")

    # predict returns (input image, overlay, count text); the first element is
    # written back into the input component.
    submit_btn.click(predict, inputs=[input_img], outputs=[input_img, output_img, output_text])

    # Clicking an example only fills input_img; the user still presses Predict.
    gr.Examples(
        label="Try an example",
        inputs=input_img,
        examples=[[path] for path in ("example1.jpg", "example2.jpg")],
    )

demo.launch(share=True)
|
|
|