# Page header captured along with this file (not code):
# shrinusn77's picture
# Update app.py
# bcffefb verified
import re

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
# Path to the DenseNet-121 checkpoint, resolved relative to the current
# working directory. Change this value if your .pth file has a different name.
MODEL_PATH = "densenet121-a639ec97.pth"
# ------------------------------------------
# Load Model with Legacy-Key Compatibility
# ------------------------------------------
# Older DenseNet checkpoints name the per-layer BatchNorm/ReLU/Conv modules
# "norm.1", "relu.2", "conv.2", ... whereas the current torchvision DenseNet
# names them "norm1", "relu2", "conv2", ...  A non-strict load of such a file
# without renaming silently skips every dense-layer tensor, leaving those
# layers randomly initialised. This is the same pattern torchvision itself
# uses to upgrade these checkpoints.
_LEGACY_DENSENET_KEY = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def remap_legacy_densenet_keys(state_dict):
    """Rename legacy DenseNet layer keys in *state_dict* to the current scheme.

    ``features.denseblock1.denselayer1.norm.1.weight`` becomes
    ``features.denseblock1.denselayer1.norm1.weight``; keys that do not match
    the legacy pattern are left untouched.

    Args:
        state_dict: Mutable mapping of parameter names to values. It is
            modified in place, so any attributes it carries (such as the
            ``_metadata`` torch attaches) are preserved.

    Returns:
        The same mapping object, for convenient chaining.
    """
    # Iterate over a snapshot of the keys because the mapping is mutated below.
    for key in list(state_dict):
        match = _LEGACY_DENSENET_KEY.match(key)
        if match:
            state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)
    return state_dict


def load_legacy_densenet(path):
    """Build a DenseNet-121 and load its weights from a checkpoint file.

    Legacy-format keys are renamed first (``remap_legacy_densenet_keys``).
    The weights are then loaded with ``strict=False``, so a checkpoint with a
    few extra or missing entries still loads. The missing/unexpected key
    counts are printed so that a mismatched checkpoint is visible in the logs.

    Args:
        path: Path to a ``.pth`` file. Assumed to contain a plain state dict
            (not a wrapper dict or a pickled model); confirm for new files.

    Returns:
        The ``torchvision`` DenseNet-121 model on CPU, switched to eval mode.
    """
    print("Loading model:", path)
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True where the torch version supports it.
    state_dict = torch.load(path, map_location="cpu")
    state_dict = remap_legacy_densenet_keys(state_dict)
    model = models.densenet121(weights=None)
    # Non-strict load: tolerate minor mismatches, but report them below.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("\n=== LOADING SUMMARY ===")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))
    if missing:
        print("  first missing:", list(missing)[:5])
    if unexpected:
        print("  first unexpected:", list(unexpected)[:5])
    model.eval()
    return model
model = load_legacy_densenet(MODEL_PATH)

# ------------------------------------------
# Preprocessing
# ------------------------------------------
# ImageNet-style evaluation pipeline: scale the short side to 256 px (keeping
# the aspect ratio), take the central 224x224 crop, convert to a float tensor,
# and normalise with the ImageNet per-channel mean/std. Resizing straight to
# (224, 224) would squash non-square images, which does not match how these
# weights are normally evaluated.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# ------------------------------------------
# Prediction function
# ------------------------------------------
def predict(image, top_k=5):
    """Classify an image and return its most likely class indices.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")`` (anything
            ``PIL.Image.fromarray`` accepts), or ``None`` when nothing was
            uploaded.
        top_k: Number of predictions to return; capped at the number of
            classes the model outputs.

    Returns:
        A list of ``{"class_index": int, "confidence": float}`` dicts in
        descending order of confidence, or an empty list when ``image`` is
        ``None``.
    """
    # Without this guard, a submit with no uploaded image would crash in
    # Image.fromarray(None).
    if image is None:
        return []

    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)

    with torch.no_grad():
        logits = model(batch)

    probs = F.softmax(logits, dim=1)
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probs.shape[1])
    top_probs, top_indices = torch.topk(probs, k)

    return [
        {"class_index": int(idx), "confidence": float(p)}
        for p, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
    ]
# ------------------------------------------
# Gradio UI
# ------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.JSON(label="Top-5 Predictions (Class Index Only)"),
    title="DenseNet-121 Model Classifier (No Label File)",
    description="This app uses a legacy DenseNet121 .pth model and returns top-5 class indices + confidence.",
)

# Start the server only when this file is executed as a script
# (e.g. `python app.py`), so that merely importing the module does not
# launch a blocking web server.
if __name__ == "__main__":
    demo.launch()