File size: 5,500 Bytes
2366e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
from torch import nn
import numpy as np
from PIL import Image
import json, os
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # Import the load_file function from safetensors
from matplotlib import cm

from models import get_model
from utils import resize_density_map, init_seeds


# Weight of the coloured density map when `predict` overlays it on the input
# photo with Image.blend (the photo gets 1 - alpha).
alpha: float = 0.8

# Per-channel (R, G, B) normalisation statistics used by `transform`
# (the standard ImageNet mean / std values).
mean: tuple = (0.485, 0.456, 0.406)
std: tuple = (0.229, 0.224, 0.225)

# Project helper from utils; presumably seeds the RNGs for reproducible runs.
init_seeds(42)

# -----------------------------
# Model configuration
# -----------------------------
# `truncation` (None => plain regression, no bins) and `reduction` select the
# bin definitions read from configs/reduction_<reduction>.json below;
# `granularity` picks the bin set inside that file and `anchor_points`
# chooses between its "average" and "middle" anchor values.
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"

# Backbone name, and the minimum side length `transform` upscales images to.
model_name: str = "clip_vit_l_14"
input_size: int = 224

# CLIP-specific prompt settings. These names are passed unconditionally to
# get_model below, so commenting them out raises a NameError there; set
# values appropriate for the backbone instead (non-CLIP backbones presumably
# ignore them -- confirm in models.get_model).
prompt_type: str = "word"
num_vpt: int = 32
vpt_drop: float = 0.0
deep_vpt: bool = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    # Fail loudly on a typo instead of silently falling back to "middle".
    if anchor_points not in ("average", "middle"):
        raise ValueError(
            f"anchor_points must be 'average' or 'middle', got {anchor_points!r}"
        )
    # The per-reduction config is keyed by truncation level, then dataset
    # ("nwpu", matching the NWPU checkpoint loaded below).
    config_path = os.path.join("configs", f"reduction_{reduction}.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = [(float(b[0]), float(b[1])) for b in config["bins"][granularity]]
    # NOTE: `anchor_points` is rebound here from the mode string to the list
    # of anchor values that get_model consumes.
    anchor_points = [float(p) for p in config["anchor_points"][granularity][anchor_points]]


model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)
weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
state_dict = load_file(weights_path)

# Checkpoint keys carry a leading "model." (presumably the network was wrapped
# when it was saved). Strip only that leading prefix: str.replace would also
# mangle any key that merely contains "model." further in (e.g. a submodule
# named "text_model"), which would then fail the strict load below.
_CKPT_PREFIX = "model."
new_state_dict = {
    (k[len(_CKPT_PREFIX):] if k.startswith(_CKPT_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)  # strict by default: keys must match exactly
model.to(device)
model.eval()


# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised, batched model input tensor.

    Images whose height or width is smaller than ``input_size`` are upscaled
    (bicubic, aspect ratio preserved) so both sides are at least
    ``input_size``; larger images are kept at their native resolution.

    Args:
        image: Input picture. Any PIL mode is accepted; it is converted to RGB
            because ``mean``/``std`` hold exactly three channel statistics.

    Returns:
        Float tensor of shape ``(1, 3, H, W)``.

    Raises:
        TypeError: If ``image`` is not a ``PIL.Image.Image``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image")
    if image.mode != "RGB":
        # RGBA / greyscale / palette / integer images would give 4- or 1-channel
        # (or int) tensors that cannot be normalised with the 3-channel stats.
        image = image.convert("RGB")
    image_tensor = TF.to_tensor(image)

    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # The larger ratio belongs to the shorter side, so after scaling both
        # sides reach input_size; +1 guards against int() truncating below it.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension



# -----------------------------
# Inference function
# -----------------------------
def _jet_colormap():
    """Return Matplotlib's "jet" colormap, working across Matplotlib versions.

    ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    3.9; the ``matplotlib.colormaps`` registry (3.5+) replaces it.
    """
    import matplotlib

    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry["jet"]
    return cm.get_cmap("jet")  # Matplotlib < 3.5


def predict(image: Image.Image):
    """Run crowd counting on ``image`` and build a density-map overlay.

    Args:
        image: Input picture from the Gradio ``Image`` component
            (``type="pil"``); ``None`` when nothing has been uploaded.

    Returns:
        Tuple ``(image, overlay, text)``: the unchanged input image, an RGBA
        image blending the input with a "jet"-coloured density map (weight
        ``alpha`` on the map), and ``"Predicted Count: <n>"`` where ``<n>`` is
        the sum of the predicted density map.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Preprocess the image
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)

    with torch.no_grad():
        density_map = model(input_tensor)  # expected shape: (1, 1, h, w)
        total_count = density_map.sum().item()
        # Upsample to the original image size so it can be blended with it.
        resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Min-max normalise to [0, 1] for display; eps avoids 0/0 on a flat map.
    eps = 1e-8
    d_min, d_max = resized_density_map.min(), resized_density_map.max()
    density_map_norm = (resized_density_map - d_min) / (d_max - d_min + eps)

    # The colormap turns [0, 1] values into RGBA floats in [0, 1]; scale to uint8.
    colormap = _jet_colormap()
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Image.blend requires both images to share mode and size.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"


# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        # Left column: input image and trigger button.
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        # Right column: overlay image and the count text returned by predict.
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    # predict returns (input image, overlay, count text) in this order.
    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])

    # Optional: add example images. Ensure these files are in your repo.
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
        ],
        inputs=input_img,
        label="Try an example",
    )

# Launch only when executed as a script, so importing this module (e.g. to
# reuse `demo` or `predict`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)