File size: 5,500 Bytes
2366e68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import torch
from torch import nn
import numpy as np
from PIL import Image
import json, os
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file # Import the load_file function from safetensors
from matplotlib import cm
from models import get_model
from utils import resize_density_map, init_seeds
# Per-channel normalization statistics (the standard ImageNet mean/std),
# consumed by `transform` via TF.normalize.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# Weight given to the density-map overlay in Image.blend
# (out = image * (1 - alpha) + overlay * alpha).
alpha = 0.8
# Seed once at startup; delegates to utils.init_seeds.
init_seeds(42)

# -----------------------------
# Define the model architecture
# -----------------------------
# Select which entry of configs/reduction_{reduction}.json is read below
# (truncation=None switches to plain regression with no bins).
truncation, reduction = 4, 8
granularity, anchor_points = "fine", "average"
# Backbone name and input size; `transform` also upscales any image whose
# height or width falls below `input_size`.
model_name, input_size = "clip_vit_l_14", 224
# CLIP-specific settings. They are passed unconditionally to `get_model`
# further down, so trying a non-CLIP backbone also requires removing them
# from that call (otherwise a NameError is raised).
prompt_type, num_vpt, vpt_drop, deep_vpt = "word", 32, 0., True
# Prefer the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    # Read the NWPU bin boundaries and anchor points for this
    # reduction/truncation setting. JSON is UTF-8 by specification, so decode
    # it explicitly rather than relying on the platform's locale encoding.
    config_path = os.path.join("configs", f"reduction_{reduction}.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    # "average" picks the averaged anchor points; any other value falls back
    # to the "middle" ones (same mapping as before).
    anchor_key = "average" if anchor_points == "average" else "middle"
    anchor_points = config["anchor_points"][granularity][anchor_key]
    # Normalize to floats: each bin becomes a 2-tuple (presumably
    # (low, high) — confirm against the config), anchors a flat list.
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]
model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)
weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
state_dict = load_file(weights_path)
# The checkpoint keys carry a leading "model." (presumably from the wrapper
# module it was saved from). Strip only that leading prefix: str.replace
# would also rewrite later occurrences, e.g. "text_model.x" -> "text_x".
_key_prefix = "model."
new_state_dict = {
    (k[len(_key_prefix):] if k.startswith(_key_prefix) else k): v
    for k, v in state_dict.items()
}
# Strict loading (the default) raises if any key still does not match.
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched model input.

    Images whose height or width is below ``input_size`` are upscaled
    (bicubic, antialiased, aspect ratio preserved) so that both sides are at
    least ``input_size``. Larger images keep their native resolution.

    Args:
        image: Input image. Non-RGB modes (e.g. ``"L"``, ``"RGBA"``, ``"P"``)
            are converted to RGB first, because normalization uses the
            3-channel ``mean``/``std``.

    Returns:
        Float tensor of shape ``(1, 3, H, W)``.

    Raises:
        TypeError: If ``image`` is not a PIL image.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image")
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # The larger of the two ratios lifts the shorter side to input_size.
        # The +1 guards against int() truncating just below it.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(
            image_tensor,
            (new_height, new_width),
            interpolation=TF.InterpolationMode.BICUBIC,
            antialias=True,
        )
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(image, overlay, text)``:
        - ``image``: the original input, unchanged;
        - ``overlay``: an RGBA image of the jet-colored density map blended
          over the input with weight ``alpha``;
        - ``text``: the predicted count formatted to two decimals.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message in the UI.
    """
    if image is None:
        # Clicking "Predict" with an empty input passes None. Report it
        # cleanly instead of failing with an AttributeError on `.size`.
        raise gr.Error("Please upload an image first.")
    # PIL reports size as (width, height).
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
    with torch.no_grad():
        density_map = model(input_tensor)  # expected shape: (1, 1, H, W) — not verified here
    total_count = density_map.sum().item()
    # Bring the map back to the original image resolution for the overlay.
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
    # Min-max normalize to [0, 1] for display; eps avoids division by zero
    # when the map is constant.
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)
    # Matplotlib 3.9 removed `cm.get_cmap`. Use the colormap registry
    # (available since 3.5) and fall back for older releases.
    try:
        from matplotlib import colormaps
        colormap = colormaps["jet"]
    except ImportError:
        colormap = cm.get_cmap("jet")
    # The colormap returns RGBA floats in [0, 1]; scale to [0, 255] as uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")
    # Image.blend requires matching modes and sizes. Both images are RGBA at
    # (input_width, input_height).
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    # Side-by-side layout: input controls on the left, results on the right.
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image")
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Predicted Density Map")
            output_text = gr.Textbox(label="Total Count")

    # `predict` returns (input image, overlay, count text), which maps
    # positionally onto these three output components.
    submit_btn.click(
        fn=predict,
        inputs=input_img,
        outputs=[input_img, output_img, output_text],
    )

    # Optional clickable examples that fill `input_img`; the image files must
    # be shipped alongside the app.
    gr.Examples(
        label="Try an example",
        inputs=input_img,
        examples=[["example1.jpg"], ["example2.jpg"]],
    )

# Start the web server; share=True asks Gradio for a public share link.
demo.launch(share=True)
|