File size: 2,189 Bytes
ee80624
 
bcffefb
ee80624
 
 
bcffefb
ee80624
28a21c6
 
 
 
 
 
 
 
 
bcffefb
28a21c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee80624
 
 
 
 
 
 
 
28a21c6
ee80624
28a21c6
ee80624
 
28a21c6
ee80624
 
bcffefb
 
ee80624
 
 
 
 
bcffefb
ee80624
 
 
 
 
28a21c6
bcffefb
28a21c6
bcffefb
ee80624
 
bcffefb
 
 
ee80624
 
bcffefb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import re

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# Checkpoint loaded at startup, resolved relative to the working directory.
# Point this at your own .pth file (or rename your file to match this name).
MODEL_PATH = "densenet121-a639ec97.pth"

# ------------------------------------------
# Load Model with Legacy-Key Compatibility
# ------------------------------------------
# The original torchvision DenseNet checkpoints (e.g. densenet121-a639ec97.pth)
# name dense-layer sub-modules "norm.1", "relu.1", "conv.2", ...; current
# torchvision names them "norm1", "relu1", "conv2", ... because '.' is no longer
# allowed in module names. This is the same pattern torchvision itself uses to
# split such a key into the two halves that are re-joined without the dot.
_LEGACY_DENSELAYER_KEY = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\."
    r"((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def _remap_legacy_keys(state_dict):
    """Return a new dict with legacy DenseNet state-dict keys renamed.

    * A leading ``module.`` (present when a model was saved while wrapped in
      ``nn.DataParallel``) is stripped.
    * ``...denselayerN.norm.1.weight``-style keys become
      ``...denselayerN.norm1.weight``.

    All other keys and every value pass through unchanged; the input is not
    mutated.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        match = _LEGACY_DENSELAYER_KEY.match(key)
        if match:
            key = match.group(1) + match.group(2)
        remapped[key] = value
    return remapped


def load_legacy_densenet(path):
    """Build a torchvision DenseNet-121 and load weights from a ``.pth`` file.

    Handles current-format state dicts as well as the legacy torchvision
    format (``norm.1`` / ``conv.2`` style keys), checkpoints that nest the
    weights under a ``"state_dict"`` entry, and ``nn.DataParallel`` saves.

    Args:
        path: Filesystem path of the checkpoint.

    Returns:
        The DenseNet-121 model on CPU, switched to eval mode.

    Raises:
        TypeError: If the file does not contain a dict-like state dict.
    """
    print("Loading model:", path)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")

    # Training scripts often save {"state_dict": ..., "epoch": ..., ...}.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a state dict in {path!r}, got {type(checkpoint).__name__}"
        )

    # Without this remap the legacy keys never match, and strict=False would
    # silently leave every dense layer at its random initialisation.
    state_dict = _remap_legacy_keys(checkpoint)

    model = models.densenet121(weights=None)

    # strict=False keeps partially matching checkpoints (e.g. a different
    # head) loadable; any key reported below kept its random initialisation.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    print("\n=== LOADING SUMMARY ===")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))
    if missing:
        print("  e.g. missing:", list(missing[:5]))
    if unexpected:
        print("  e.g. unexpected:", list(unexpected[:5]))

    model.eval()
    return model

# Build the classifier once at import time; predict() reuses this instance.
model = load_legacy_densenet(path=MODEL_PATH)

# ------------------------------------------
# Preprocessing
# ------------------------------------------
_CHANNEL_MEAN = [0.485, 0.456, 0.406]  # usual ImageNet per-channel mean (R, G, B)
_CHANNEL_STD = [0.229, 0.224, 0.225]   # usual ImageNet per-channel std (R, G, B)

# PIL image -> float tensor of shape (C, 224, 224), normalised per channel.
# NOTE(review): Resize((224, 224)) stretches non-square images; torchvision's
# reference eval pipeline resizes the short side to 256 and center-crops 224,
# which may match the pretrained weights better -- confirm before changing.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# ------------------------------------------
# Prediction function
# ------------------------------------------
def predict(image):
    """Classify one uploaded image and return its top-scoring class indices.

    Args:
        image: The upload as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``, or ``None`` when the user submits
            without choosing an image.

    Returns:
        A list of ``{"class_index": int, "confidence": float}`` dicts ordered
        by descending confidence: five entries, or fewer if the model has
        fewer than five output classes. Empty when no image was supplied.
    """
    # Gradio passes None for an empty input; Image.fromarray(None) would raise.
    if image is None:
        return []

    # convert("RGB") also normalises grayscale / RGBA uploads to 3 channels.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        # Clamp k so a classifier head with < 5 classes doesn't make topk raise.
        k = min(5, probs.shape[1])
        top_probs, top_indices = torch.topk(probs, k)

    return [
        {"class_index": int(idx.item()), "confidence": float(p.item())}
        for p, idx in zip(top_probs[0], top_indices[0])
    ]

# ------------------------------------------
# Gradio UI
# ------------------------------------------
# Single-image upload in, JSON list of {class_index, confidence} out.
_image_input = gr.Image(type="numpy", label="Upload Image")
_json_output = gr.JSON(label="Top-5 Predictions (Class Index Only)")

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_json_output,
    title="DenseNet-121 Model Classifier (No Label File)",
    description=(
        "This app uses a legacy DenseNet121 .pth model and returns "
        "top-5 class indices + confidence."
    ),
)

# Starts the web server; runs whenever this module is executed or imported.
demo.launch()