import torch from torch import nn import numpy as np from PIL import Image import json, os import gradio as gr import torchvision.transforms.functional as TF from safetensors.torch import load_file # Import the load_file function from safetensors from matplotlib import cm from models import get_model from utils import resize_density_map, init_seeds mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) alpha = 0.8 init_seeds(42) # ----------------------------- # Define the model architecture # ----------------------------- truncation = 4 reduction = 8 granularity = "fine" anchor_points = "average" model_name = "clip_vit_l_14" input_size = 224 # Comment the lines below to test non-CLIP models. prompt_type = "word" num_vpt = 32 vpt_drop = 0. deep_vpt = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if truncation is None: # regression, no truncation. bins, anchor_points = None, None else: with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f: config = json.load(f)[str(truncation)]["nwpu"] bins = config["bins"][granularity] anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"] bins = [(float(b[0]), float(b[1])) for b in bins] anchor_points = [float(p) for p in anchor_points] model = get_model( backbone=model_name, input_size=input_size, reduction=reduction, bins=bins, anchor_points=anchor_points, # CLIP parameters prompt_type=prompt_type, num_vpt=num_vpt, vpt_drop=vpt_drop, deep_vpt=deep_vpt ) weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors") state_dict = load_file(weights_path) new_state_dict = {} for k, v in state_dict.items(): new_state_dict[k.replace("model.", "")] = v model.load_state_dict(new_state_dict) model.to(device) model.eval() # ----------------------------- # Preprocessing function # ----------------------------- # Adjust the image transforms to match what your model expects. def transform(image: Image.Image): assert isinstance(image, Image.Image), "Input must be a PIL Image" image_tensor = TF.to_tensor(image) image_height, image_width = image_tensor.shape[-2:] if image_height < input_size or image_width < input_size: # Find the ratio to resize the image while maintaining the aspect ratio ratio = max(input_size / image_height, input_size / image_width) new_height = int(image_height * ratio) + 1 new_width = int(image_width * ratio) + 1 image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True) image_tensor = TF.normalize(image_tensor, mean=mean, std=std) return image_tensor.unsqueeze(0) # Add batch dimension # ----------------------------- # Inference function # ----------------------------- def predict(image: Image.Image): """ Given an input image, preprocess it, run the model to obtain a density map, compute the total crowd count, and prepare the density map for display. """ # Preprocess the image input_width, input_height = image.size input_tensor = transform(image).to(device) # shape: (1, 3, H, W) with torch.no_grad(): density_map = model(input_tensor) # expected shape: (1, 1, H, W) total_count = density_map.sum().item() resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy() # Normalize the density map for display purposes eps = 1e-8 density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps) # Apply a colormap (e.g., 'jet') to get an RGBA image colormap = cm.get_cmap("jet") # The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8. density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8) density_map_color_img = Image.fromarray(density_map_color).convert("RGBA") # Ensure the original image is in RGBA format. image_rgba = image.convert("RGBA") overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha) return image, overlayed_image, f"Predicted Count: {total_count:.2f}" # ----------------------------- # Build Gradio Interface using Blocks for a two-column layout # ----------------------------- with gr.Blocks() as demo: gr.Markdown("# Crowd Counting Demo") gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.") with gr.Row(): with gr.Column(): input_img = gr.Image( label="Input Image", sources=["upload", "clipboard"], type="pil", ) submit_btn = gr.Button("Predict") with gr.Column(): output_img = gr.Image(label="Predicted Density Map", type="pil") output_text = gr.Textbox(label="Total Count") submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text]) # Optional: add example images. Ensure these files are in your repo. gr.Examples( examples=[ ["example1.jpg"], ["example2.jpg"] ], inputs=input_img, label="Try an example" ) # Launch the app demo.launch(share=True)