borso271 commited on
Commit
2a97c1d
Β·
1 Parent(s): 42efcad

Add remove labels functionality with admin panel UI

Browse files
Files changed (4) hide show
  1. app.py +115 -0
  2. app_backup.py +280 -0
  3. app_fixed.py +301 -0
  4. handler.py +49 -0
app.py CHANGED
@@ -155,6 +155,81 @@ def get_current_stats():
155
  except Exception as e:
156
  return f"Error getting stats: {str(e)}"
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # Create Gradio interface
159
  print("Creating Gradio interface...")
160
  with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
@@ -265,6 +340,46 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
265
  inputs=[admin_token_input, version_input],
266
  outputs=reload_output
267
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  with gr.Tab("ℹ️ About"):
270
  gr.Markdown("""
 
155
  except Exception as e:
156
  return f"Error getting stats: {str(e)}"
157
 
158
+ def get_labels_table():
159
+ """
160
+ Get all current labels as a formatted table for display.
161
+ """
162
+ if handler is None:
163
+ return "Handler not initialized"
164
+
165
+ if not hasattr(handler, 'class_ids') or len(handler.class_ids) == 0:
166
+ return "No labels currently loaded"
167
+
168
+ try:
169
+ # Create a formatted table of labels
170
+ table_data = []
171
+ for id, name in zip(handler.class_ids, handler.class_names):
172
+ table_data.append([int(id), name])
173
+
174
+ return table_data
175
+ except Exception as e:
176
+ return f"Error getting labels: {str(e)}"
177
+
178
+ def remove_labels_admin(admin_token, ids_to_remove_str):
179
+ """
180
+ Admin function to remove labels by ID.
181
+ """
182
+ if handler is None:
183
+ return "Error: Handler not initialized"
184
+
185
+ if not admin_token:
186
+ return "Error: Admin token required"
187
+
188
+ try:
189
+ # Parse the IDs from comma-separated string
190
+ if not ids_to_remove_str or ids_to_remove_str.strip() == "":
191
+ return "❌ Error: Please provide IDs to remove (comma-separated)"
192
+
193
+ ids_to_remove = []
194
+ for id_str in ids_to_remove_str.split(','):
195
+ id_str = id_str.strip()
196
+ if id_str:
197
+ ids_to_remove.append(int(id_str))
198
+
199
+ if not ids_to_remove:
200
+ return "❌ Error: No valid IDs provided"
201
+
202
+ # Get names of items to be removed for confirmation
203
+ removed_names = []
204
+ if hasattr(handler, 'class_ids'):
205
+ for id in ids_to_remove:
206
+ if id in handler.class_ids:
207
+ idx = handler.class_ids.index(id)
208
+ removed_names.append(f"{id}: {handler.class_names[idx]}")
209
+
210
+ result = handler({
211
+ "inputs": {
212
+ "op": "remove_labels",
213
+ "token": admin_token,
214
+ "ids": ids_to_remove
215
+ }
216
+ })
217
+
218
+ if result.get("status") == "ok":
219
+ removed_list = "\n".join(removed_names) if removed_names else "None found"
220
+ return f"βœ… Success! Removed {result.get('removed', 0)} labels. Current version: {result.get('labels_version', 'unknown')}\n\nRemoved items:\n{removed_list}"
221
+ elif result.get("error") == "unauthorized":
222
+ return "❌ Error: Invalid admin token"
223
+ elif result.get("error") == "no_ids_provided":
224
+ return "❌ Error: No IDs provided"
225
+ else:
226
+ return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
227
+
228
+ except ValueError:
229
+ return "❌ Error: Invalid ID format. Please provide comma-separated numbers (e.g., 1001,1002,1003)"
230
+ except Exception as e:
231
+ return f"❌ Error: {str(e)}"
232
+
233
  # Create Gradio interface
234
  print("Creating Gradio interface...")
235
  with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
 
340
  inputs=[admin_token_input, version_input],
341
  outputs=reload_output
342
  )
343
+
344
+ with gr.Accordion("πŸ—‘οΈ Remove Labels", open=False):
345
+ gr.Markdown("Remove specific labels by their IDs")
346
+
347
+ # Display current labels
348
+ labels_table = gr.Dataframe(
349
+ value=get_labels_table(),
350
+ headers=["ID", "Name"],
351
+ label="Current Labels",
352
+ interactive=False,
353
+ height=300
354
+ )
355
+
356
+ refresh_labels_btn = gr.Button("πŸ”„ Refresh Label List", size="sm")
357
+ refresh_labels_btn.click(
358
+ fn=get_labels_table,
359
+ inputs=[],
360
+ outputs=labels_table
361
+ )
362
+
363
+ gr.Markdown("Enter IDs to remove (comma-separated):")
364
+ ids_to_remove_input = gr.Textbox(
365
+ label="IDs to Remove",
366
+ placeholder="e.g., 1001, 1002, 1003",
367
+ lines=1
368
+ )
369
+
370
+ remove_btn = gr.Button("πŸ—‘οΈ Remove Selected Labels", variant="stop")
371
+ remove_output = gr.Markdown()
372
+
373
+ def remove_and_refresh(token, ids):
374
+ result = remove_labels_admin(token, ids)
375
+ updated_table = get_labels_table()
376
+ return result, updated_table
377
+
378
+ remove_btn.click(
379
+ fn=remove_and_refresh,
380
+ inputs=[admin_token_input, ids_to_remove_input],
381
+ outputs=[remove_output, labels_table]
382
+ )
383
 
384
  with gr.Tab("ℹ️ About"):
385
  gr.Markdown("""
app_backup.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import base64
3
+ import json
4
+ import os
5
+ from PIL import Image
6
+ import io
7
+ from handler import EndpointHandler
8
+
9
+ handler = EndpointHandler()
10
+
11
+ def classify_image(image, top_k=10):
12
+ """
13
+ Main classification function for public interface.
14
+ """
15
+ if image is None:
16
+ return None, "Please upload an image"
17
+
18
+ try:
19
+ # Convert PIL image to base64
20
+ buffered = io.BytesIO()
21
+ image.save(buffered, format="PNG")
22
+ img_b64 = base64.b64encode(buffered.getvalue()).decode()
23
+
24
+ # Call handler
25
+ result = handler({
26
+ "inputs": {
27
+ "image": img_b64,
28
+ "top_k": int(top_k)
29
+ }
30
+ })
31
+
32
+ # Format results for display
33
+ if isinstance(result, list):
34
+ # Create formatted output
35
+ output_text = "**Top {} Classifications:**\n\n".format(len(result))
36
+
37
+ # Create a dictionary for the bar chart
38
+ chart_data = {}
39
+
40
+ for i, item in enumerate(result, 1):
41
+ score_pct = item['score'] * 100
42
+ output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
43
+ chart_data[item['label']] = item['score']
44
+
45
+ return chart_data, output_text
46
+ else:
47
+ return None, f"Error: {result.get('error', 'Unknown error')}"
48
+
49
+ except Exception as e:
50
+ return None, f"Error: {str(e)}"
51
+
52
+ def upsert_labels_admin(admin_token, new_items_json):
53
+ """
54
+ Admin function to add new labels.
55
+ """
56
+ if not admin_token:
57
+ return "Error: Admin token required"
58
+
59
+ try:
60
+ # Parse the JSON input
61
+ items = json.loads(new_items_json) if new_items_json else []
62
+
63
+ result = handler({
64
+ "inputs": {
65
+ "op": "upsert_labels",
66
+ "token": admin_token,
67
+ "items": items
68
+ }
69
+ })
70
+
71
+ if result.get("status") == "ok":
72
+ return f"βœ… Success! Added {result.get('added', 0)} new labels. Current version: {result.get('labels_version', 'unknown')}"
73
+ elif result.get("error") == "unauthorized":
74
+ return "❌ Error: Invalid admin token"
75
+ else:
76
+ return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
77
+
78
+ except json.JSONDecodeError:
79
+ return "❌ Error: Invalid JSON format"
80
+ except Exception as e:
81
+ return f"❌ Error: {str(e)}"
82
+
83
+ def reload_labels_admin(admin_token, version):
84
+ """
85
+ Admin function to reload a specific label version.
86
+ """
87
+ if not admin_token:
88
+ return "Error: Admin token required"
89
+
90
+ try:
91
+ result = handler({
92
+ "inputs": {
93
+ "op": "reload_labels",
94
+ "token": admin_token,
95
+ "version": int(version) if version else 1
96
+ }
97
+ })
98
+
99
+ if result.get("status") == "ok":
100
+ return f"βœ… Labels reloaded successfully! Current version: {result.get('labels_version', 'unknown')}"
101
+ elif result.get("status") == "nochange":
102
+ return f"ℹ️ No change needed. Current version: {result.get('labels_version', 'unknown')}"
103
+ elif result.get("error") == "unauthorized":
104
+ return "❌ Error: Invalid admin token"
105
+ elif result.get("error") == "invalid_version":
106
+ return "❌ Error: Invalid version number"
107
+ else:
108
+ return f"❌ Error: {result.get('error', 'Unknown error')}"
109
+
110
+ except Exception as e:
111
+ return f"❌ Error: {str(e)}"
112
+
113
+ def get_current_stats():
114
+ """
115
+ Get current label statistics.
116
+ """
117
+ try:
118
+ num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
119
+ version = getattr(handler, 'labels_version', 1)
120
+ device = handler.device if hasattr(handler, 'device') else "unknown"
121
+
122
+ stats = f"""
123
+ **Current Statistics:**
124
+ - Number of labels: {num_labels}
125
+ - Labels version: {version}
126
+ - Device: {device}
127
+ - Model: MobileCLIP-B
128
+ """
129
+
130
+ if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
131
+ stats += f"\n- Sample labels: {', '.join(handler.class_names[:5])}"
132
+ if len(handler.class_names) > 5:
133
+ stats += "..."
134
+
135
+ return stats
136
+ except Exception as e:
137
+ return f"Error getting stats: {str(e)}"
138
+
139
+ # Create Gradio interface
140
+ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
141
+ gr.Markdown("""
142
+ # πŸ–ΌοΈ MobileCLIP-B Zero-Shot Image Classifier
143
+
144
+ Upload an image to classify it using MobileCLIP-B model with dynamic label management.
145
+ """)
146
+
147
+ with gr.Tab("πŸ” Image Classification"):
148
+ with gr.Row():
149
+ with gr.Column():
150
+ input_image = gr.Image(
151
+ type="pil",
152
+ label="Upload Image"
153
+ )
154
+ top_k_slider = gr.Slider(
155
+ minimum=1,
156
+ maximum=50,
157
+ value=10,
158
+ step=1,
159
+ label="Number of top results to show"
160
+ )
161
+ classify_btn = gr.Button("πŸš€ Classify Image", variant="primary")
162
+
163
+ with gr.Column():
164
+ output_chart = gr.BarPlot(
165
+ label="Classification Confidence",
166
+ x_label="Label",
167
+ y_label="Confidence",
168
+ vertical=False,
169
+ height=400
170
+ )
171
+ output_text = gr.Markdown(label="Classification Results")
172
+
173
+ gr.Examples(
174
+ examples=[
175
+ ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/cheetah.jpg"],
176
+ ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/elephant.jpg"],
177
+ ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/giraffe.jpg"]
178
+ ],
179
+ inputs=input_image,
180
+ label="Example Images"
181
+ )
182
+
183
+ classify_btn.click(
184
+ classify_image,
185
+ inputs=[input_image, top_k_slider],
186
+ outputs=[output_chart, output_text]
187
+ )
188
+
189
+ with gr.Tab("πŸ”§ Admin Panel"):
190
+ gr.Markdown("""
191
+ ### Admin Functions
192
+ **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
193
+ """)
194
+
195
+ with gr.Row():
196
+ admin_token_input = gr.Textbox(
197
+ label="Admin Token",
198
+ type="password",
199
+ placeholder="Enter admin token"
200
+ )
201
+
202
+ with gr.Accordion("πŸ“Š Current Statistics", open=True):
203
+ stats_display = gr.Markdown(value=get_current_stats())
204
+ refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats")
205
+ refresh_stats_btn.click(
206
+ get_current_stats,
207
+ outputs=stats_display
208
+ )
209
+
210
+ with gr.Accordion("βž• Add New Labels", open=False):
211
+ gr.Markdown("""
212
+ Add new labels by providing JSON array:
213
+ ```json
214
+ [
215
+ {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
216
+ {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
217
+ ]
218
+ ```
219
+ """)
220
+ new_items_input = gr.Code(
221
+ label="New Items JSON",
222
+ language="json",
223
+ lines=5,
224
+ value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
225
+ )
226
+ upsert_btn = gr.Button("βž• Add Labels", variant="primary")
227
+ upsert_output = gr.Markdown()
228
+
229
+ upsert_btn.click(
230
+ upsert_labels_admin,
231
+ inputs=[admin_token_input, new_items_input],
232
+ outputs=upsert_output
233
+ )
234
+
235
+ with gr.Accordion("πŸ”„ Reload Label Version", open=False):
236
+ gr.Markdown("Reload labels from a specific version stored in the Hub")
237
+ version_input = gr.Number(
238
+ label="Version Number",
239
+ value=1,
240
+ precision=0
241
+ )
242
+ reload_btn = gr.Button("πŸ”„ Reload Version", variant="primary")
243
+ reload_output = gr.Markdown()
244
+
245
+ reload_btn.click(
246
+ reload_labels_admin,
247
+ inputs=[admin_token_input, version_input],
248
+ outputs=reload_output
249
+ )
250
+
251
+ with gr.Tab("ℹ️ About"):
252
+ gr.Markdown("""
253
+ ## About MobileCLIP-B Classifier
254
+
255
+ This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.
256
+
257
+ ### Features:
258
+ - πŸš€ **Fast inference**: < 30ms on GPU
259
+ - 🏷️ **Dynamic labels**: Add/update labels without redeployment
260
+ - πŸ”„ **Version control**: Track and reload label versions
261
+ - πŸ“Š **Visual results**: Bar charts and confidence scores
262
+
263
+ ### Environment Variables (set in Space Settings):
264
+ - `ADMIN_TOKEN`: Secret token for admin operations
265
+ - `HF_LABEL_REPO`: Hub repository for label storage (e.g., "username/labels")
266
+ - `HF_WRITE_TOKEN`: Token with write permissions to label repo
267
+ - `HF_READ_TOKEN`: Token with read permissions (optional, defaults to write token)
268
+
269
+ ### Model Details:
270
+ - **Architecture**: MobileCLIP-B with MobileOne blocks
271
+ - **Text Encoder**: Transformer-based, 77 token context
272
+ - **Image Size**: 224x224
273
+ - **Embedding Dim**: 512
274
+
275
+ ### License:
276
+ Model weights are licensed under Apple Sample Code License (ASCL).
277
+ """)
278
+
279
+ if __name__ == "__main__":
280
+ demo.launch()
app_fixed.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import base64
3
+ import json
4
+ import os
5
+ from PIL import Image
6
+ import io
7
+ from handler import EndpointHandler
8
+
9
+ # Initialize handler
10
+ print("Initializing MobileCLIP handler...")
11
+ try:
12
+ handler = EndpointHandler()
13
+ print(f"Handler initialized successfully! Device: {handler.device}")
14
+ except Exception as e:
15
+ print(f"Error initializing handler: {e}")
16
+ handler = None
17
+
18
+ def classify_image(image, top_k=10):
19
+ """
20
+ Main classification function for public interface.
21
+ """
22
+ if handler is None:
23
+ return "Error: Handler not initialized", None
24
+
25
+ if image is None:
26
+ return "Please upload an image", None
27
+
28
+ try:
29
+ # Convert PIL image to base64
30
+ buffered = io.BytesIO()
31
+ image.save(buffered, format="PNG")
32
+ img_b64 = base64.b64encode(buffered.getvalue()).decode()
33
+
34
+ # Call handler
35
+ result = handler({
36
+ "inputs": {
37
+ "image": img_b64,
38
+ "top_k": int(top_k)
39
+ }
40
+ })
41
+
42
+ # Format results for display
43
+ if isinstance(result, list):
44
+ # Create formatted output
45
+ output_text = "**Top {} Classifications:**\n\n".format(len(result))
46
+
47
+ # Create data for bar chart (list of tuples)
48
+ chart_data = []
49
+
50
+ for i, item in enumerate(result, 1):
51
+ score_pct = item['score'] * 100
52
+ output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
53
+ chart_data.append((item['label'], item['score']))
54
+
55
+ return output_text, chart_data
56
+ else:
57
+ return f"Error: {result.get('error', 'Unknown error')}", None
58
+
59
+ except Exception as e:
60
+ return f"Error: {str(e)}", None
61
+
62
+ def upsert_labels_admin(admin_token, new_items_json):
63
+ """
64
+ Admin function to add new labels.
65
+ """
66
+ if handler is None:
67
+ return "Error: Handler not initialized"
68
+
69
+ if not admin_token:
70
+ return "Error: Admin token required"
71
+
72
+ try:
73
+ # Parse the JSON input
74
+ items = json.loads(new_items_json) if new_items_json else []
75
+
76
+ result = handler({
77
+ "inputs": {
78
+ "op": "upsert_labels",
79
+ "token": admin_token,
80
+ "items": items
81
+ }
82
+ })
83
+
84
+ if result.get("status") == "ok":
85
+ return f"βœ… Success! Added {result.get('added', 0)} new labels. Current version: {result.get('labels_version', 'unknown')}"
86
+ elif result.get("error") == "unauthorized":
87
+ return "❌ Error: Invalid admin token"
88
+ else:
89
+ return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
90
+
91
+ except json.JSONDecodeError:
92
+ return "❌ Error: Invalid JSON format"
93
+ except Exception as e:
94
+ return f"❌ Error: {str(e)}"
95
+
96
+ def reload_labels_admin(admin_token, version):
97
+ """
98
+ Admin function to reload a specific label version.
99
+ """
100
+ if handler is None:
101
+ return "Error: Handler not initialized"
102
+
103
+ if not admin_token:
104
+ return "Error: Admin token required"
105
+
106
+ try:
107
+ result = handler({
108
+ "inputs": {
109
+ "op": "reload_labels",
110
+ "token": admin_token,
111
+ "version": int(version) if version else 1
112
+ }
113
+ })
114
+
115
+ if result.get("status") == "ok":
116
+ return f"βœ… Labels reloaded successfully! Current version: {result.get('labels_version', 'unknown')}"
117
+ elif result.get("status") == "nochange":
118
+ return f"ℹ️ No change needed. Current version: {result.get('labels_version', 'unknown')}"
119
+ elif result.get("error") == "unauthorized":
120
+ return "❌ Error: Invalid admin token"
121
+ elif result.get("error") == "invalid_version":
122
+ return "❌ Error: Invalid version number"
123
+ else:
124
+ return f"❌ Error: {result.get('error', 'Unknown error')}"
125
+
126
+ except Exception as e:
127
+ return f"❌ Error: {str(e)}"
128
+
129
+ def get_current_stats():
130
+ """
131
+ Get current label statistics.
132
+ """
133
+ if handler is None:
134
+ return "Handler not initialized"
135
+
136
+ try:
137
+ num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
138
+ version = getattr(handler, 'labels_version', 1)
139
+ device = handler.device if hasattr(handler, 'device') else "unknown"
140
+
141
+ stats = f"""
142
+ **Current Statistics:**
143
+ - Number of labels: {num_labels}
144
+ - Labels version: {version}
145
+ - Device: {device}
146
+ - Model: MobileCLIP-B
147
+ """
148
+
149
+ if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
150
+ stats += f"\n- Sample labels: {', '.join(handler.class_names[:5])}"
151
+ if len(handler.class_names) > 5:
152
+ stats += "..."
153
+
154
+ return stats
155
+ except Exception as e:
156
+ return f"Error getting stats: {str(e)}"
157
+
158
+ # Create Gradio interface
159
+ print("Creating Gradio interface...")
160
+ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
161
+ gr.Markdown("""
162
+ # πŸ–ΌοΈ MobileCLIP-B Zero-Shot Image Classifier
163
+
164
+ Upload an image to classify it using MobileCLIP-B model with dynamic label management.
165
+ """)
166
+
167
+ with gr.Tab("πŸ” Image Classification"):
168
+ with gr.Row():
169
+ with gr.Column():
170
+ input_image = gr.Image(
171
+ type="pil",
172
+ label="Upload Image"
173
+ )
174
+ top_k_slider = gr.Slider(
175
+ minimum=1,
176
+ maximum=50,
177
+ value=10,
178
+ step=1,
179
+ label="Number of top results to show"
180
+ )
181
+ classify_btn = gr.Button("πŸš€ Classify Image", variant="primary")
182
+
183
+ with gr.Column():
184
+ output_text = gr.Markdown(label="Classification Results")
185
+ # Simplified bar chart using Dataframe
186
+ output_chart = gr.Dataframe(
187
+ headers=["Label", "Confidence"],
188
+ label="Classification Scores",
189
+ interactive=False
190
+ )
191
+
192
+ # Event handler for classification
193
+ classify_btn.click(
194
+ fn=classify_image,
195
+ inputs=[input_image, top_k_slider],
196
+ outputs=[output_text, output_chart]
197
+ )
198
+
199
+ # Also trigger on image upload
200
+ input_image.change(
201
+ fn=classify_image,
202
+ inputs=[input_image, top_k_slider],
203
+ outputs=[output_text, output_chart]
204
+ )
205
+
206
+ with gr.Tab("πŸ”§ Admin Panel"):
207
+ gr.Markdown("""
208
+ ### Admin Functions
209
+ **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
210
+ """)
211
+
212
+ with gr.Row():
213
+ admin_token_input = gr.Textbox(
214
+ label="Admin Token",
215
+ type="password",
216
+ placeholder="Enter admin token"
217
+ )
218
+
219
+ with gr.Accordion("πŸ“Š Current Statistics", open=True):
220
+ stats_display = gr.Markdown(value=get_current_stats())
221
+ refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats")
222
+ refresh_stats_btn.click(
223
+ fn=get_current_stats,
224
+ inputs=[],
225
+ outputs=stats_display
226
+ )
227
+
228
+ with gr.Accordion("βž• Add New Labels", open=False):
229
+ gr.Markdown("""
230
+ Add new labels by providing JSON array:
231
+ ```json
232
+ [
233
+ {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
234
+ {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
235
+ ]
236
+ ```
237
+ """)
238
+ new_items_input = gr.Code(
239
+ label="New Items JSON",
240
+ language="json",
241
+ lines=5,
242
+ value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
243
+ )
244
+ upsert_btn = gr.Button("βž• Add Labels", variant="primary")
245
+ upsert_output = gr.Markdown()
246
+
247
+ upsert_btn.click(
248
+ fn=upsert_labels_admin,
249
+ inputs=[admin_token_input, new_items_input],
250
+ outputs=upsert_output
251
+ )
252
+
253
+ with gr.Accordion("πŸ”„ Reload Label Version", open=False):
254
+ gr.Markdown("Reload labels from a specific version stored in the Hub")
255
+ version_input = gr.Number(
256
+ label="Version Number",
257
+ value=1,
258
+ precision=0
259
+ )
260
+ reload_btn = gr.Button("πŸ”„ Reload Version", variant="primary")
261
+ reload_output = gr.Markdown()
262
+
263
+ reload_btn.click(
264
+ fn=reload_labels_admin,
265
+ inputs=[admin_token_input, version_input],
266
+ outputs=reload_output
267
+ )
268
+
269
+ with gr.Tab("ℹ️ About"):
270
+ gr.Markdown("""
271
+ ## About MobileCLIP-B Classifier
272
+
273
+ This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.
274
+
275
+ ### Features:
276
+ - πŸš€ **Fast inference**: < 30ms on GPU
277
+ - 🏷️ **Dynamic labels**: Add/update labels without redeployment
278
+ - πŸ”„ **Version control**: Track and reload label versions
279
+ - πŸ“Š **Visual results**: Classification scores and confidence
280
+
281
+ ### Environment Variables (set in Space Settings):
282
+ - `ADMIN_TOKEN`: Secret token for admin operations
283
+ - `HF_LABEL_REPO`: Hub repository for label storage
284
+ - `HF_WRITE_TOKEN`: Token with write permissions to label repo
285
+ - `HF_READ_TOKEN`: Token with read permissions (optional)
286
+
287
+ ### Model Details:
288
+ - **Architecture**: MobileCLIP-B with MobileOne blocks
289
+ - **Text Encoder**: Transformer-based, 77 token context
290
+ - **Image Size**: 224x224
291
+ - **Embedding Dim**: 512
292
+
293
+ ### License:
294
+ Model weights are licensed under Apple Sample Code License (ASCL).
295
+ """)
296
+
297
+ print("Gradio interface created successfully!")
298
+
299
+ if __name__ == "__main__":
300
+ print("Launching Gradio app...")
301
+ demo.launch()
handler.py CHANGED
@@ -93,6 +93,24 @@ class EndpointHandler:
93
  ok = self._load_snapshot_from_hub_version(ver)
94
  return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Freshness guard (optional)
97
  min_ver = payload.get("min_labels_version")
98
  if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
@@ -151,6 +169,37 @@ class EndpointHandler:
151
  self._to_device()
152
  return len(batch)
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def _persist_snapshot_to_hub(self, version: int):
155
  if not HF_LABEL_REPO:
156
  raise RuntimeError("HF_LABEL_REPO not set")
 
93
  ok = self._load_snapshot_from_hub_version(ver)
94
  return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}
95
 
96
+ # Admin op: remove_labels
97
+ if op == "remove_labels":
98
+ if payload.get("token") != ADMIN_TOKEN:
99
+ return {"error": "unauthorized"}
100
+ ids_to_remove = set(payload.get("ids", []))
101
+ if not ids_to_remove:
102
+ return {"error": "no_ids_provided"}
103
+
104
+ removed = self._remove_items(ids_to_remove)
105
+ if removed > 0:
106
+ new_ver = int(getattr(self, "labels_version", 1)) + 1
107
+ try:
108
+ self._persist_snapshot_to_hub(new_ver)
109
+ self.labels_version = new_ver
110
+ except Exception as e:
111
+ return {"status": "error", "removed": removed, "detail": str(e)}
112
+ return {"status": "ok", "removed": removed, "labels_version": getattr(self, "labels_version", 1)}
113
+
114
  # Freshness guard (optional)
115
  min_ver = payload.get("min_labels_version")
116
  if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
 
169
  self._to_device()
170
  return len(batch)
171
 
172
+ def _remove_items(self, ids_to_remove):
173
+ if not ids_to_remove or not hasattr(self, "class_ids"):
174
+ return 0
175
+ with self._lock:
176
+ ids_to_remove = set(int(id) for id in ids_to_remove)
177
+ # Find indices to keep
178
+ indices_to_keep = []
179
+ removed_count = 0
180
+ for i, class_id in enumerate(self.class_ids):
181
+ if class_id not in ids_to_remove:
182
+ indices_to_keep.append(i)
183
+ else:
184
+ removed_count += 1
185
+
186
+ if removed_count == 0:
187
+ return 0
188
+
189
+ # Filter the tensors and lists
190
+ if indices_to_keep:
191
+ self.text_features_cpu = self.text_features_cpu[indices_to_keep].contiguous()
192
+ self.class_ids = [self.class_ids[i] for i in indices_to_keep]
193
+ self.class_names = [self.class_names[i] for i in indices_to_keep]
194
+ else:
195
+ # All items removed, reset to empty
196
+ self.text_features_cpu = torch.empty(0, self.text_features_cpu.shape[1])
197
+ self.class_ids = []
198
+ self.class_names = []
199
+
200
+ self._to_device()
201
+ return removed_count
202
+
203
  def _persist_snapshot_to_hub(self, version: int):
204
  if not HF_LABEL_REPO:
205
  raise RuntimeError("HF_LABEL_REPO not set")