Yiming-M committed
Commit 2366e68 · 1 Parent(s): 83196f2

Renamed Weights Folder Names; Added app.py

app.py ADDED
@@ -0,0 +1,160 @@
import torch
from torch import nn
import numpy as np
from PIL import Image
import json, os
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # Import the load_file function from safetensors
from matplotlib import cm

from models import get_model
from utils import resize_density_map, init_seeds


mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)

# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"

model_name = "clip_vit_l_14"
input_size = 224

# Comment out the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]


model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt
)
weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
state_dict = load_file(weights_path)
# Checkpoint keys carry a "model." prefix; strip it so they match this model's state_dict.
new_state_dict = {}
for k, v in state_dict.items():
    new_state_dict[k.replace("model.", "")] = v
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()


# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)

    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension


# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    # Preprocess the image
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)

    with torch.no_grad():
        density_map = model(input_tensor)  # expected shape: (1, 1, H, W)
    total_count = density_map.sum().item()
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Normalize the density map for display purposes
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)

    # Apply a colormap (e.g., 'jet') to get an RGBA image
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Ensure the original image is in RGBA format.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"


# -----------------------------
# Build the Gradio interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])

    # Optional: add example images. Ensure these files are in your repo.
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"]
        ],
        inputs=input_img,
        label="Try an example"
    )

# Launch the app
demo.launch(share=True)
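
For readers who want to see the display path in isolation, below is a minimal sketch (not part of this commit) of the normalize → colormap → blend steps that predict() performs, driven by a random array so it runs without the model, the weights, or Gradio. The array shape and the output filename are arbitrary choices for the sketch. Note that cm.get_cmap is deprecated in recent matplotlib releases; newer environments may need matplotlib.colormaps["jet"] instead.

# Illustrative only -- not part of app.py. Stand-in density map and output name.
import numpy as np
from PIL import Image
from matplotlib import cm

density = np.random.rand(240, 320).astype(np.float32)            # stand-in density map
norm = (density - density.min()) / (density.max() - density.min() + 1e-8)
rgba = (cm.get_cmap("jet")(norm) * 255).astype(np.uint8)          # H x W x 4, uint8
heatmap = Image.fromarray(rgba).convert("RGBA")

base = Image.new("RGBA", heatmap.size, (128, 128, 128, 255))      # stand-in input image
overlay = Image.blend(base, heatmap, alpha=0.8)                   # 80% heatmap, 20% image
overlay.save("overlay_demo.png")
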
pre-trained weights/{CLIP-EBC-ViT-B:16 (NWPU) → CLIP-EBC-ViT-B-16-NWPU}/README.md RENAMED
File without changes
pre-trained weights/{CLIP-EBC-ViT-B:16 (NWPU) → CLIP-EBC-ViT-B-16-NWPU}/config.json RENAMED
File without changes
pre-trained weights/{CLIP-EBC-ViT-B:16 (NWPU) → CLIP-EBC-ViT-B-16-NWPU}/model.safetensors RENAMED
File without changes
pre-trained weights/{CLIP-EBC-ViT-L:14 (NWPU) → CLIP-EBC-ViT-L-14-NWPU}/README.md RENAMED
File without changes
pre-trained weights/{CLIP-EBC-ViT-L:14 (NWPU) → CLIP-EBC-ViT-L-14-NWPU}/config.json RENAMED
File without changes
pre-trained weights/{CLIP-EBC-ViT-L:14 (NWPU) → CLIP-EBC-ViT-L-14-NWPU}/model.safetensors RENAMED
File without changes
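
The folder renames above matter because app.py hard-codes the new names (e.g. "pre-trained weights/CLIP-EBC-ViT-L-14-NWPU/model.safetensors"). A quick sanity check along the lines below could confirm that the renamed checkpoint resolves and that the "model." prefix stripping used in app.py applies cleanly; the snippet is illustrative and not part of the commit.

# Illustrative check only: verifies the renamed weights path used in app.py
# and previews the effect of stripping the "model." prefix from checkpoint keys.
import os
from safetensors.torch import load_file

weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
assert os.path.exists(weights_path), f"Checkpoint not found: {weights_path}"

state_dict = load_file(weights_path)
stripped = {k.replace("model.", ""): v for k, v in state_dict.items()}
print(f"{len(stripped)} tensors; sample keys: {list(stripped)[:3]}")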