"""Gradio demo that classifies an image with a DenseNet-121 checkpoint.

The checkpoint is loaded once at import time; the UI returns the top-5 class
indices and their softmax confidences (no label file is used).
"""

import re

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# Path to the DenseNet-121 checkpoint; change this to point at your own .pth file.
MODEL_PATH = "densenet121-a639ec97.pth"

# Legacy DenseNet checkpoints name dense-layer parameters "norm.1", "conv.2", ...
# while current torchvision modules are called "norm1", "conv2", ...  This is
# the same pattern torchvision uses internally to translate the old names.
_LEGACY_KEY_PATTERN = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))"
    r"\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


# ------------------------------------------
# Load Model with Legacy-Key Compatibility
# ------------------------------------------
def _remap_legacy_keys(state_dict):
    """Return a copy of *state_dict* with legacy DenseNet key names modernised.

    Keys such as ``...denselayer1.norm.1.weight`` become
    ``...denselayer1.norm1.weight``; every other key is kept unchanged.
    """
    remapped = {}
    for key, value in state_dict.items():
        match = _LEGACY_KEY_PATTERN.match(key)
        new_key = match.group(1) + match.group(2) if match else key
        remapped[new_key] = value
    return remapped


def load_legacy_densenet(path):
    """Build a DenseNet-121 and load the weights stored at *path*.

    Args:
        path: Filesystem path of a checkpoint containing a DenseNet-121
            ``state_dict`` (legacy or current key naming).

    Returns:
        The model in evaluation mode, on the CPU.
    """
    print("Loading model:", path)

    # NOTE(security): torch.load unpickles the file. Only load checkpoints you
    # trust, or pass weights_only=True if your torch version supports it.
    state_dict = torch.load(path, map_location="cpu")
    state_dict = _remap_legacy_keys(state_dict)

    model = models.densenet121(weights=None)

    # strict=False keeps loading tolerant, but without the key remapping above
    # it would silently skip every dense-layer weight of a legacy checkpoint.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    print("\n=== LOADING SUMMARY ===")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))
    if missing or unexpected:
        # Any mismatch means some layers keep their random initialisation.
        print("WARNING: checkpoint did not fully match the model; "
              "predictions may be unreliable.")

    model.eval()
    return model


model = load_legacy_densenet(MODEL_PATH)


# ------------------------------------------
# Preprocessing
# ------------------------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# ------------------------------------------
# Prediction function
# ------------------------------------------
def predict(image):
    """Classify *image* and return its top predictions.

    Args:
        image: Image as a NumPy array (as delivered by ``gr.Image`` with
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        A list of up to five dicts with keys ``"class_index"`` (int) and
        ``"confidence"`` (float), ordered from most to least likely. The list
        is empty when *image* is ``None``.
    """
    if image is None:
        # Submitting without an upload passes None; avoid crashing in PIL.
        return []

    pil_image = Image.fromarray(image).convert("RGB")
    img_tensor = transform(pil_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(img_tensor)
        probs = F.softmax(logits, dim=1)

    # Clamp k so models with fewer than five classes do not raise.
    k = min(5, probs.shape[1])
    top_prob, top_idx = torch.topk(probs, k)

    return [
        {"class_index": int(idx), "confidence": float(p)}
        for p, idx in zip(top_prob[0].tolist(), top_idx[0].tolist())
    ]


# ------------------------------------------
# Gradio UI
# ------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.JSON(label="Top-5 Predictions (Class Index Only)"),
    title="DenseNet-121 Model Classifier (No Label File)",
    description="This app uses a legacy DenseNet121 .pth model and returns top-5 class indices + confidence.",
)

if __name__ == "__main__":
    demo.launch()