# CLIP-EBC / app.py
# Last change (2366e68) by Yiming-M: Renamed Weights Folder Names; Added app.py
import torch
from torch import nn
import numpy as np
from PIL import Image
import json, os
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file # Import the load_file function from safetensors
from matplotlib import cm
from models import get_model
from utils import resize_density_map, init_seeds
# Per-channel normalization statistics applied in `transform`. These are the
# common ImageNet values (presumably matching training preprocessing -- confirm).
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# Weight of the colored density map when `predict` blends it over the input
# (Image.blend: out = input * (1 - alpha) + density * alpha).
alpha = 0.8
# Seed helper from `utils` (presumably seeds the RNGs for reproducibility).
init_seeds(42)
# -----------------------------
# Define the model architecture
# -----------------------------
# Counting-head configuration. `truncation`, `reduction` and `granularity`
# select an entry of configs/reduction_{reduction}.json (NWPU section).
truncation = 4  # None selects plain regression (no bins / anchor points).
reduction = 8
granularity = "fine"
# Which anchor-point variant to read from the config ("average" or "middle").
# Kept separate from `anchor_points`, which ends up holding the numeric list.
anchor_point_type = "average"
model_name = "clip_vit_l_14"
input_size = 224
# CLIP-only options. To test a non-CLIP backbone, remove these lines AND the
# matching keyword arguments passed to get_model(...) below; removing only
# these lines leaves get_model referencing undefined names.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    config_path = os.path.join("configs", f"reduction_{reduction}.json")
    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    # Each bin becomes a tuple of its first two entries as floats
    # (presumably the lower/upper edges of the bin).
    bins = [(float(b[0]), float(b[1])) for b in config["bins"][granularity]]
    # Index by the requested variant directly: an unknown value now raises
    # KeyError instead of silently falling back to "middle".
    anchor_points = [float(p) for p in config["anchor_points"][granularity][anchor_point_type]]
# Build the network from the configuration values defined above.
model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP-specific keyword arguments
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)

# Load the safetensors checkpoint and rename its keys to match `model`.
weights_path = os.path.join(
    "pre-trained weights",
    "CLIP-EBC-ViT-L-14-NWPU",
    "model.safetensors",
)
state_dict = load_file(weights_path)
# NOTE(review): str.replace removes *every* "model." substring, not only a
# leading prefix (e.g. "clip_model.x" -> "clip_x"). Behavior kept unchanged;
# confirm the checkpoint's key names before switching to a prefix-only strip.
new_state_dict = {name.replace("model.", ""): tensor for name, tensor in state_dict.items()}
model.load_state_dict(new_state_dict)
# nn.Module.to() and .eval() both return the module itself, so they chain.
model.to(device).eval()
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched tensor for the model.

    The image is converted to RGB, so that RGBA, grayscale or palette inputs
    yield exactly the three channels described by ``mean``/``std``. It is then
    turned into a float tensor in [0, 1]. If either side is shorter than
    ``input_size``, the tensor is upscaled with bicubic interpolation while
    keeping the aspect ratio. Finally it is normalized.

    Args:
        image: Input image.

    Returns:
        Tensor of shape ``(1, 3, H, W)``.

    Raises:
        TypeError: If ``image`` is not a ``PIL.Image.Image``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image")
    # TF.normalize applies one mean/std entry per channel; a 4-channel (RGBA)
    # or 1-channel (L) tensor would not match the 3-element statistics.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Scale so the shorter side reaches input_size while keeping the
        # aspect ratio; the +1 guards against float truncation leaving a
        # side one pixel short.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(
            image_tensor,
            (new_height, new_width),
            interpolation=TF.InterpolationMode.BICUBIC,
            antialias=True,
        )
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
    """Count the crowd in ``image`` and render its density map as an overlay.

    Args:
        image: Image delivered by the ``gr.Image(type="pil")`` input. Gradio
            passes ``None`` when Predict is clicked with the input empty.

    Returns:
        ``(image, overlayed_image, count_text)``:
        - the input image, unchanged;
        - an RGBA blend of the input with the "jet"-colored density map,
          where ``alpha`` is the density-map weight;
        - the predicted count as ``"Predicted Count: <count:.2f>"``.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of an ``AttributeError`` on ``None.size``.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Predict.")

    # PIL reports size as (width, height).
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)
    with torch.no_grad():
        density_map = model(input_tensor)  # presumably (1, 1, h, w)
    # The count is taken from the raw model output, before any resizing.
    total_count = density_map.sum().item()
    resized_density_map = (
        resize_density_map(density_map, (input_height, input_width))
        .cpu()
        .squeeze()
        .numpy()
    )

    # Min-max normalize to [0, 1] for display; eps avoids division by zero
    # when the map is constant (e.g. all zeros).
    eps = 1e-8
    density_min = resized_density_map.min()
    density_max = resized_density_map.max()
    density_map_norm = (resized_density_map - density_min) / (density_max - density_min + eps)

    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; use the
    # colormap registry and fall back to get_cmap only on matplotlib < 3.5.
    try:
        from matplotlib import colormaps
        colormap = colormaps["jet"]
    except ImportError:
        colormap = cm.get_cmap("jet")
    # Colormaps map [0, 1] floats to RGBA floats in [0, 1]; scale to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Image.blend needs matching modes and sizes. The density map was resized
    # to the input's (height, width) above.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
    with gr.Row():
        # Left column: input image and the button that triggers inference.
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        # Right column: density-map overlay and predicted count.
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")
    # predict returns (input image, overlay, count text), so the input image
    # is written back to its own component.
    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
    # Optional: add example images. Ensure these files are in your repo.
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
        ],
        inputs=input_img,
        label="Try an example",
    )

# Launch the app only when run as a script (`python app.py`), so importing
# this module does not start and block on a web server.
if __name__ == "__main__":
    demo.launch(share=True)