| | import json |
| | import os |
| | from functools import lru_cache |
| | from pathlib import Path |
| | from typing import Dict, List, Tuple |
| |
|
| | import torch |
| | import gradio as gr |
| | from transformers import AutoModelForZeroShotImageClassification, AutoProcessor |
| |
|
| | from utils.cache_manager import cached_inference |
| | from utils.modality_router import detect_modality |
| |
|
| |
|
# Directory that contains this script; other paths are resolved relative to it
# so the app works regardless of the current working directory.
BASE_DIR = Path(__file__).resolve().parent
# Folder holding the JSON label lists.
LABEL_DIR = BASE_DIR / "labels"
# Hugging Face model identifier passed to the processor and model loaders.
MODEL_ID = "google/medsiglip-448"
| |
|
| |
|
# Optional Hugging Face access token read from the environment (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Limit PyTorch to a single CPU thread for intra-op parallelism.
torch.set_num_threads(1)

# Prefer CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision on CUDA, full precision on CPU. Derived from `device` so the
# dtype and the device selection can never disagree.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
| |
|
# Load the processor and the zero-shot classification model once, at import
# time. HF_TOKEN is forwarded to both loaders.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=model_dtype,
).to(device)
# Put the model in evaluation mode; this app only runs inference.
model.eval()
| |
|
| |
|
# Maps a modality name to the label file to try when no
# "<modality>_labels.json" file exists in LABEL_DIR.
LABEL_OVERRIDES: Dict[str, str] = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}
| |
|
| |
|
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Load and return the parsed JSON content of a label file in LABEL_DIR.

    Results are memoized per ``file_name`` with no size limit, so each file
    is read from disk at most once per process. Repeated calls return the
    same cached object; callers should treat it as read-only.

    Args:
        file_name: Name of the JSON file, relative to ``LABEL_DIR``.

    Returns:
        Whatever ``json.load`` produces for the file (expected to be a list
        of label strings -- the content is not validated here).

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    label_path = LABEL_DIR / file_name
    with label_path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
| |
|
| |
|
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Pick the label list that matches the detected modality of an image.

    Lookup order inside ``LABEL_DIR``:

    1. ``"<modality>_labels.json"``
    2. the file named in ``LABEL_OVERRIDES`` for that modality, if any
    3. ``"general_labels.json"``

    The first file that exists is used; if none of the first two exist the
    general file is loaded unconditionally.

    Args:
        image_path: Path of the image, forwarded to ``detect_modality``.

    Returns:
        The labels of the selected file as an immutable tuple.
    """
    modality = detect_modality(image_path)

    label_file = LABEL_DIR / f"{modality}_labels.json"
    if not label_file.exists():
        override_name = LABEL_OVERRIDES.get(modality)
        if override_name:
            label_file = LABEL_DIR / override_name
        # Only reached when the modality-specific file is missing, so a hit
        # on the primary file costs a single filesystem check.
        if not label_file.exists():
            label_file = LABEL_DIR / "general_labels.json"

    # load_labels resolves names relative to LABEL_DIR, so pass the bare name.
    return tuple(load_labels(label_file.name))
| |
|
| |
|
def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Return the five highest-scoring labels for a medical image.

    Args:
        image_path: Filesystem path of the uploaded image. A falsy value
            (``None`` or ``""``) means nothing was uploaded.

    Returns:
        A mapping of label to score for at most five labels, ordered from
        highest to lowest score. Empty when there is no image or no scores.
    """
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, candidate_labels, model, processor)

    # Test for None explicitly rather than by truthiness: `not scores` raises
    # for multi-element tensors/arrays. An empty sequence needs no special
    # case because zip() below then yields nothing and the result is {}.
    if scores is None:
        return {}

    # NOTE(review): assumes cached_inference returns one score per candidate
    # label, in the same order -- confirm against utils.cache_manager.
    ranked = sorted(
        zip(candidate_labels, scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return {label: float(score) for label, score in ranked[:5]}
| |
|
| |
|
# Gradio UI: one image upload (handed to the callback as a file path) and a
# label widget showing up to five predictions.
demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="📤 Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
    title="🩻 MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
)
| |
|
| |
|
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and hide the API docs page.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
| |
|