import gradio as gr
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)


def calculate_ctr(mask):
    """Estimate the cardiothoracic ratio (CTR) from a segmentation mask.

    Approximates the classic measurement as the horizontal extent of the
    heart mask divided by the horizontal extent of the lung masks. A CTR
    above ~0.5 on a PA film is the conventional threshold for cardiomegaly.
    """
    # Labels 1 and 2 are the lungs; label 3 is the heart.
    lungs = np.zeros_like(mask, dtype=np.uint8)
    lungs[(mask == 1) | (mask == 2)] = 1
    heart = (mask == 3).astype("uint8")

    _, lung_x = np.where(lungs == 1)
    _, heart_x = np.where(heart == 1)
    if lung_x.size == 0 or heart_x.size == 0:
        return None, None, None, None, None

    thorax_left = int(lung_x.min())
    thorax_right = int(lung_x.max())
    heart_left = int(heart_x.min())
    heart_right = int(heart_x.max())

    lung_range = thorax_right - thorax_left
    heart_range = heart_right - heart_left
    ctr = None if lung_range == 0 else float(heart_range / lung_range)
    return ctr, thorax_left, thorax_right, heart_left, heart_right


def _run_model(image):
    """Shared logic: PIL image -> (img_gray, mask, h, w, CTR coords, view_idx, age, female_prob)."""
    img = np.array(image.convert("L"))
    h, w = img.shape[:2]

    x = model.preprocess(img)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()  # (1, 1, H, W)
    with torch.inference_mode():
        out = model(x.to(device))

    # Upsample the predicted mask back to the original image size.
    mask_small = out["mask"].argmax(1)[0].cpu().numpy()
    mask = cv2.resize(mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST)

    view_idx = out["view"].argmax(1).item()
    age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())

    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)
    return (
        img, mask, h, w,
        ctr, thorax_left, thorax_right, heart_left, heart_right,
        view_idx, age_pred, female_prob,
    )


# ---------- 1) Visual demo (what you already have) ----------
def analyze(image):
    if image is None:
        return None, "No image uploaded."

    (
        img, mask, h, w,
        ctr, thorax_left, thorax_right, heart_left, heart_right,
        view_idx, age_pred, female_prob,
    ) = _run_model(image)

    # Tint the lungs green/orange and the heart red, then blend with the input.
    color = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    overlay = color.copy()
    overlay[mask == 1] = [0, 255, 0]
    overlay[mask == 2] = [0, 128, 255]
    overlay[mask == 3] = [255, 0, 0]
    blended = cv2.addWeighted(color, 0.7, overlay, 0.3, 0)

    view_map = {0: "AP", 1: "PA", 2: "lateral"}
    view = view_map.get(view_idx, "unknown")

    lines = []
    if ctr is not None:
        lines.append(f"CTR: {ctr:.2f}")
    else:
        lines.append("CTR: could not be reliably calculated (segmentation issue).")
    lines.extend([
        f"View (model): {view}",
        f"Predicted age: {age_pred:.0f} years",
        f"Predicted sex: {'Female' if female_prob >= 0.5 else 'Male'} (prob={female_prob:.2f})",
        "",
        "⚠️ Research/educational use only, NOT for clinical decision-making.",
    ])
    if view != "PA":
        lines.append("⚠️ CTR is normally interpreted on PA view. Interpret with caution.")
    return blended, "\n".join(lines)


visual_demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=[
        gr.Image(label="Segmentation overlay"),
        gr.Textbox(label="AI output"),
    ],
    title="AI CTR helper (research only)",
    description=(
        "Segments heart and lungs and estimates CTR using 'ianpan/chest-x-ray-basic'. "
        "Research use only."
    ),
)
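
# Optional sketch (not wired into the interfaces here): draw the measured
# x-coordinates onto the blended overlay so the CTR extents are visible.
# `draw_ctr_lines` is a hypothetical helper added only to illustrate what the
# pixel coordinates returned by calculate_ctr refer to.
def draw_ctr_lines(blended, thorax_left, thorax_right, heart_left, heart_right):
    annotated = blended.copy()
    h = annotated.shape[0]
    # Thoracic extent as white vertical lines, cardiac extent as yellow ones.
    for x_pos in (thorax_left, thorax_right):
        cv2.line(annotated, (x_pos, 0), (x_pos, h - 1), (255, 255, 255), 2)
    for x_pos in (heart_left, heart_right):
        cv2.line(annotated, (x_pos, 0), (x_pos, h - 1), (255, 255, 0), 2)
    return annotated
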
# ---------- 2) JSON points API (for your Lovable app) ----------
def get_points(image):
    if image is None:
        return {"error": "No image uploaded"}

    (
        img, mask, h, w,
        ctr, thorax_left, thorax_right, heart_left, heart_right,
        view_idx, age_pred, female_prob,
    ) = _run_model(image)

    return {
        "image_width": w,
        "image_height": h,
        "ctr": ctr,
        "thorax_left_px": thorax_left,
        "thorax_right_px": thorax_right,
        "heart_left_px": heart_left,
        "heart_right_px": heart_right,
        "view_idx": int(view_idx),
    }


points_api = gr.Interface(
    fn=get_points,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=gr.JSON(label="CTR points JSON"),
    title="CTR points API",
    description="Returns thorax/heart x-coordinates and CTR as JSON.",
    api_name="ctr_points",  # important for programmatic calls
)

demo = gr.TabbedInterface(
    [visual_demo, points_api],
    ["Viewer", "JSON API"],
)

if __name__ == "__main__":
    demo.launch()
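
# ---------- Example: calling the JSON API from another process ----------
# A minimal client-side sketch, kept as a comment so this file stays a single
# runnable app. It assumes the app is reachable at http://127.0.0.1:7860
# (Gradio's default local address) and that a local file "xray.png" exists;
# both the URL and the file name are placeholders. The `gradio_client`
# package resolves the endpoint registered above via api_name="ctr_points".
#
#     from gradio_client import Client, handle_file
#
#     client = Client("http://127.0.0.1:7860")  # placeholder URL
#     points = client.predict(
#         handle_file("xray.png"),  # placeholder input image
#         api_name="/ctr_points",
#     )
#     print(points)  # image size, CTR, and thorax/heart x-coordinates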