clipspace / app.py
borso271's picture
Add base64 API endpoint for direct backend integration
b48deab
import gradio as gr
import base64
import json
import os
from PIL import Image
import io
from handler import EndpointHandler
# Initialize handler
def _create_handler():
    """
    Construct the shared EndpointHandler, reporting the outcome on stdout.

    Returns:
        The EndpointHandler instance, or None if construction (or the success
        report, which reads ``.device``) raised. Callers check for None and
        answer with a "not initialized" message instead of crashing.
    """
    try:
        instance = EndpointHandler()
        print(f"Handler initialized successfully! Device: {instance.device}")
        return instance
    except Exception as init_error:
        print(f"Error initializing handler: {init_error}")
        return None


print("Initializing MobileCLIP handler...")
handler = _create_handler()
def classify_image(image, top_k=10):
    """
    Classify a PIL image for the public "Image Classification" tab.

    The image is PNG-encoded and base64-wrapped so it goes through the same
    ``{"inputs": {"image": ..., "top_k": ...}}`` handler payload used by the
    base64 API endpoint.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None.
        top_k: Number of top classifications to request (coerced with int()).

    Returns:
        tuple: ``(markdown_text, rows)`` where ``rows`` is a list of
        ``(label, score)`` tuples for the scores Dataframe. ``rows`` is None
        whenever a status or error message is returned instead of results.
    """
    if handler is None:
        return "Error: Handler not initialized", None
    if image is None:
        return "Please upload an image", None

    try:
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()

        payload = {"inputs": {"image": encoded, "top_k": int(top_k)}}
        result = handler(payload)

        # A list is a successful ranking; anything else is treated as an
        # error mapping carrying an optional 'error' key.
        if not isinstance(result, list):
            return f"Error: {result.get('error', 'Unknown error')}", None

        lines = []
        rows = []
        for rank, entry in enumerate(result, start=1):
            percent = entry['score'] * 100
            lines.append(
                f"{rank}. **{entry['label']}** (ID: {entry['id']}): {percent:.2f}%\n"
            )
            rows.append((entry['label'], entry['score']))

        header = f"**Top {len(result)} Classifications:**\n\n"
        return header + "".join(lines), rows
    except Exception as exc:
        return f"Error: {str(exc)}", None
def upsert_labels_admin(admin_token, new_items_json):
    """
    Admin function to add new labels.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            actual authorization check (it answers ``error == "unauthorized"``
            on a bad token).
        new_items_json: JSON text describing the labels to add, e.g.
            ``[{"id": 100, "name": "...", "prompt": "..."}]``. A single JSON
            object is accepted and treated as a one-item array. Empty or
            whitespace-only text is sent as an empty list.

    Returns:
        str: A human-readable status message (Markdown).
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"
    try:
        # Parse the JSON input. Blank text means "no items" rather than being
        # reported as malformed JSON.
        if new_items_json and new_items_json.strip():
            items = json.loads(new_items_json)
        else:
            items = []
        # The UI asks for a JSON array of label objects. Wrap a lone object and
        # reject any other shape here instead of forwarding it to the handler
        # (which presumably iterates over a list — TODO confirm in handler.py).
        if isinstance(items, dict):
            items = [items]
        elif not isinstance(items, list):
            return "❌ Error: JSON must be an array of label objects"
        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items
            }
        })
        if result.get("status") == "ok":
            return f"✅ Success! Added {result.get('added', 0)} new labels. Current version: {result.get('labels_version', 'unknown')}"
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def reload_labels_admin(admin_token, version):
    """
    Admin function to reload a specific label version.

    Args:
        admin_token: Admin token forwarded to the handler for authorization.
        version: Label version to load. ``None`` or a blank string falls back
            to version 1. Any other value must be an integer (an integral float
            such as ``2.0`` is accepted too).

    Returns:
        str: A human-readable status message (Markdown).
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"

    # Only a *missing* value defaults to 1. The previous `if version` test
    # also replaced an explicit 0 with 1, reloading a version nobody asked for.
    if version is None or (isinstance(version, str) and not version.strip()):
        target_version = 1
    else:
        try:
            target_version = int(version)
        except (TypeError, ValueError, OverflowError):
            return "❌ Error: Invalid version number"
        # int(2.7) would silently truncate to 2; refuse non-integral floats.
        if isinstance(version, float) and version != target_version:
            return "❌ Error: Invalid version number"

    try:
        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": target_version
            }
        })
        if result.get("status") == "ok":
            return f"✅ Labels reloaded successfully! Current version: {result.get('labels_version', 'unknown')}"
        if result.get("status") == "nochange":
            return f"ℹ️ No change needed. Current version: {result.get('labels_version', 'unknown')}"
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        if result.get("error") == "invalid_version":
            return "❌ Error: Invalid version number"
        # Prefer the handler's 'detail' text when present, like the other
        # admin functions do.
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def get_current_stats():
    """
    Build the Markdown summary shown in the "Current Statistics" accordion.

    Reads the label count, labels version and device from the global handler,
    using fallbacks (0, 1, "unknown") when an attribute is absent. It appends
    up to five sample label names when any are loaded.

    Returns:
        str: Markdown text, or an error/status string if the handler is
        missing or reading its attributes fails.
    """
    if handler is None:
        return "Handler not initialized"
    try:
        label_count = 0
        if hasattr(handler, 'class_ids'):
            label_count = len(handler.class_ids)
        labels_version = getattr(handler, 'labels_version', 1)
        device_name = getattr(handler, 'device', "unknown")

        summary_lines = [
            "",
            "**Current Statistics:**",
            f"- Number of labels: {label_count}",
            f"- Labels version: {labels_version}",
            f"- Device: {device_name}",
            "- Model: MobileCLIP-B",
            "",
        ]
        summary = "\n".join(summary_lines)

        if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
            preview = ', '.join(handler.class_names[:5])
            summary += f"\n- Sample labels: {preview}"
            # Signal that the preview is truncated.
            if len(handler.class_names) > 5:
                summary += "..."
        return summary
    except Exception as exc:
        return f"Error getting stats: {str(exc)}"
def get_labels_table():
    """
    Get all current labels as rows for the "Current Labels" Dataframe.

    Returns:
        list[list]: ``[id, name]`` rows, one per loaded label. When there is
        nothing to list (handler not initialized, no labels loaded, or an
        error while reading them), a single placeholder row ``["", message]``
        is returned instead of a bare string. The result is fed straight into
        ``gr.Dataframe`` (both as its initial ``value`` and on refresh), and
        Gradio interprets a plain str Dataframe value as a CSV file path rather
        than as text to display.
    """
    if handler is None:
        return [["", "Handler not initialized"]]
    try:
        class_ids = getattr(handler, 'class_ids', None)
        if class_ids is None or len(class_ids) == 0:
            return [["", "No labels currently loaded"]]
        # int() normalizes IDs that may not be plain Python ints.
        return [
            [int(label_id), name]
            for label_id, name in zip(class_ids, handler.class_names)
        ]
    except Exception as e:
        return [["", f"Error getting labels: {str(e)}"]]
def remove_labels_admin(admin_token, ids_to_remove_str):
    """
    Admin function to remove labels by ID.

    Args:
        admin_token: Admin token forwarded to the handler for authorization.
        ids_to_remove_str: Comma-separated label IDs, e.g. ``"1001, 1002"``.
            Blank entries (such as from a trailing comma) are ignored.

    Returns:
        str: A human-readable status message (Markdown). On success it lists
        the ``id: name`` pairs that were loaded locally for the requested IDs
        just before the removal request was sent.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"
    try:
        if not ids_to_remove_str or not ids_to_remove_str.strip():
            return "❌ Error: Please provide IDs to remove (comma-separated)"

        # Parse in a dedicated try so that only malformed IDs get the format
        # hint; a ValueError raised later (e.g. by the handler) is reported
        # with its own message instead of being misattributed to user input.
        try:
            ids_to_remove = [
                int(part)
                for part in (piece.strip() for piece in ids_to_remove_str.split(','))
                if part
            ]
        except ValueError:
            return "❌ Error: Invalid ID format. Please provide comma-separated numbers (e.g., 1001,1002,1003)"

        if not ids_to_remove:
            return "❌ Error: No valid IDs provided"

        # Names for the confirmation message, via one id->name map instead of
        # `class_ids.index` per ID. This is linear overall and does not
        # require class_ids to provide `.index`. setdefault keeps the first
        # occurrence of a duplicated ID, as `.index` did. The preview is
        # informational only, so missing names simply produce no entry.
        removed_names = []
        if hasattr(handler, 'class_ids'):
            names_by_id = {}
            for class_id, class_name in zip(handler.class_ids, getattr(handler, 'class_names', [])):
                names_by_id.setdefault(class_id, class_name)
            removed_names = [
                f"{label_id}: {names_by_id[label_id]}"
                for label_id in ids_to_remove
                if label_id in names_by_id
            ]

        result = handler({
            "inputs": {
                "op": "remove_labels",
                "token": admin_token,
                "ids": ids_to_remove
            }
        })
        if result.get("status") == "ok":
            removed_list = "\n".join(removed_names) if removed_names else "None found"
            return f"✅ Success! Removed {result.get('removed', 0)} labels. Current version: {result.get('labels_version', 'unknown')}\n\nRemoved items:\n{removed_list}"
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        if result.get("error") == "no_ids_provided":
            return "❌ Error: No IDs provided"
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Create Gradio interface
# Builds the three-tab UI (classification, admin panel, about). Emoji in the
# labels below were previously stored as UTF-8 bytes mis-decoded into Latin /
# Greek characters (e.g. "πŸ"„" for 🔄) and have been restored.
print("Creating Gradio interface...")
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
# 🖼️ MobileCLIP-B Zero-Shot Image Classifier
Upload an image to classify it using MobileCLIP-B model with dynamic label management.
""")

    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show"
                )
                classify_btn = gr.Button("🚀 Classify Image", variant="primary")
            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                # Simplified bar chart using Dataframe; rows are the
                # (label, score) tuples returned by classify_image.
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False
                )
        # Event handler for classification
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
            api_name="classify_image"
        )
        # Also trigger on image change (upload or clear). A second api_name
        # is used so it does not collide with the button's endpoint.
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
            api_name="classify_image_1"
        )

    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
### Admin Functions
**Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
""")
        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token"
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            # Evaluated once while the UI is built; the button re-queries it.
            stats_display = gr.Markdown(value=get_current_stats())
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
Add new labels by providing JSON array:
```json
[
{"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
{"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
]
```
""")
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()
            upsert_btn.click(
                fn=upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output,
                api_name="upsert_labels_admin"
            )

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()
            reload_btn.click(
                fn=reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output
            )

        with gr.Accordion("🗑️ Remove Labels", open=False):
            gr.Markdown("Remove specific labels by their IDs")
            # Display current labels (snapshot taken at build time; use the
            # refresh button, or a removal, to update it).
            labels_table = gr.Dataframe(
                value=get_labels_table(),
                headers=["ID", "Name"],
                label="Current Labels",
                interactive=False,
                height=300
            )
            refresh_labels_btn = gr.Button("🔄 Refresh Label List", size="sm")
            refresh_labels_btn.click(
                fn=get_labels_table,
                inputs=[],
                outputs=labels_table
            )
            gr.Markdown("Enter IDs to remove (comma-separated):")
            ids_to_remove_input = gr.Textbox(
                label="IDs to Remove",
                placeholder="e.g., 1001, 1002, 1003",
                lines=1
            )
            remove_btn = gr.Button("🗑️ Remove Selected Labels", variant="stop")
            remove_output = gr.Markdown()

            def remove_and_refresh(token, ids):
                """Remove labels, then return (status message, refreshed table rows)."""
                result = remove_labels_admin(token, ids)
                updated_table = get_labels_table()
                return result, updated_table

            remove_btn.click(
                fn=remove_and_refresh,
                inputs=[admin_token_input, ids_to_remove_input],
                outputs=[remove_output, labels_table]
            )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## About MobileCLIP-B Classifier
This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.
### Features:
- 🚀 **Fast inference**: < 30ms on GPU
- 🏷️ **Dynamic labels**: Add/update labels without redeployment
- 🔄 **Version control**: Track and reload label versions
- 📊 **Visual results**: Classification scores and confidence
### Environment Variables (set in Space Settings):
- `ADMIN_TOKEN`: Secret token for admin operations
- `HF_LABEL_REPO`: Hub repository for label storage
- `HF_WRITE_TOKEN`: Token with write permissions to label repo
- `HF_READ_TOKEN`: Token with read permissions (optional)
### Model Details:
- **Architecture**: MobileCLIP-B with MobileOne blocks
- **Text Encoder**: Transformer-based, 77 token context
- **Image Size**: 224x224
- **Embedding Dim**: 512
### License:
Model weights are licensed under Apple Sample Code License (ASCL).
""")
print("Gradio interface created successfully!")
# Pure API endpoint for base64 classification, so backends can send an encoded
# image directly instead of uploading a file.
def classify_base64(image_b64: str, top_k: int = 10):
    """
    API-only endpoint that accepts base64 images directly.

    This enables direct API calls from backends without file uploads.

    Args:
        image_b64: Base64-encoded image bytes. Surrounding whitespace and a
            ``data:<mime>;base64,`` URL prefix are tolerated and stripped
            before the payload is forwarded to the handler.
        top_k: Number of results to request. ``None`` (e.g. when the hidden
            numeric input is omitted by an API caller) falls back to 10.

    Returns:
        The handler's response for a classification request, or a dict
        ``{"error": <message>}`` when the handler is missing, no image was
        given, or the call raised.
    """
    if handler is None:
        return {"error": "handler not initialized"}
    if not image_b64:
        return {"error": "no image provided"}
    if isinstance(image_b64, str):
        image_b64 = image_b64.strip()
        # ':' is not a base64 character, so a "data:" prefix can only be a
        # data-URL header; keep just the payload after the first comma.
        if image_b64.startswith("data:"):
            _header, sep, payload = image_b64.partition(",")
            if sep:
                image_b64 = payload
    if top_k is None:
        top_k = 10
    try:
        # Same payload shape classify_image uses for uploaded images.
        result = handler({
            "inputs": {
                "image": image_b64,
                "top_k": int(top_k)
            }
        })
        return result
    except Exception as e:
        return {"error": str(e)}
# Register the API endpoint (no UI)
# Reopens the already-built `demo` Blocks context and creates a gr.Interface
# inside it that wraps classify_base64 under api_name "classify_base64".
# Every component (and the Interface itself) is visible=False, so nothing is
# shown in the page layout.
# NOTE(review): the hidden components are constructed inline in the Interface
# call; building them beforehand inside `with demo:` would presumably render
# them into `demo` rather than the Interface — confirm before restructuring.
# NOTE(review): the top_k gr.Number has no explicit default value, so API
# callers should pass it (classify_base64 falls back only if it handles None).
with demo:
    gr.Interface(
        fn=classify_base64,
        inputs=[
            gr.Textbox(label="image_b64", visible=False),
            gr.Number(label="top_k", visible=False)
        ],
        outputs=gr.JSON(visible=False),
        api_name="classify_base64",
        visible=False
    )
def _launch_app():
    """Print a startup notice, then start the Gradio server for ``demo``."""
    print("Launching Gradio app...")
    demo.launch()


if __name__ == "__main__":
    _launch_app()