shrinusn77 commited on
Commit
bcffefb
·
verified ·
1 Parent(s): 761722d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -29
app.py CHANGED
@@ -1,43 +1,27 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
  from torchvision import models, transforms
5
  from PIL import Image
6
- import json
7
- import os
8
 
9
- MODEL_PATH = "densenet121-a639ec97.pth" # rename your .pth to model.pth
10
-
11
- # ------------------------------------------
12
- # Load ImageNet Label Mapping
13
- # ------------------------------------------
14
- with open("imagenet_classes.json", "r") as f:
15
- idx_to_class = json.load(f)
16
 
17
  # ------------------------------------------
18
  # Load Model with Legacy-Key Compatibility
19
  # ------------------------------------------
20
  def load_legacy_densenet(path):
21
  print("Loading model:", path)
22
-
23
- # Load state_dict (old torchvision format)
24
  state_dict = torch.load(path, map_location="cpu")
25
 
26
- # Create new DenseNet model
27
  model = models.densenet121(weights=None)
28
 
29
- # Try loading with strict=False (legacy fix)
30
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
31
 
32
  print("\n=== LOADING SUMMARY ===")
33
  print("Missing keys:", len(missing))
34
  print("Unexpected keys:", len(unexpected))
35
 
36
- if len(missing) > 0:
37
- print("⚠ NOTE: Missing keys detected (normal for legacy checkpoint)")
38
- if len(unexpected) > 0:
39
- print("⚠ NOTE: Unexpected keys detected (normal for legacy checkpoint)")
40
-
41
  model.eval()
42
  return model
43
 
@@ -63,29 +47,28 @@ def predict(image):
63
  img_tensor = transform(image).unsqueeze(0)
64
 
65
  with torch.no_grad():
66
- outputs = model(img_tensor)
67
- probs = torch.softmax(outputs, dim=1)
68
  top5_prob, top5_idx = torch.topk(probs, 5)
69
 
70
  results = []
71
  for p, idx in zip(top5_prob[0], top5_idx[0]):
72
- cls_name = idx_to_class.get(str(idx.item()), "Unknown")
73
  results.append({
74
- "class": cls_name,
75
  "confidence": float(p.item())
76
  })
77
 
78
  return results
79
 
80
  # ------------------------------------------
81
- # Gradio Interface
82
  # ------------------------------------------
83
- interface = gr.Interface(
84
  fn=predict,
85
  inputs=gr.Image(type="numpy", label="Upload Image"),
86
- outputs=gr.JSON(label="Top-5 Predictions"),
87
- title="DenseNet-121 Legacy Model Classifier (ImageNet)",
88
- description="Upload any image. Model returns top-5 ImageNet predictions.",
89
  )
90
 
91
- interface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  from torchvision import models, transforms
5
  from PIL import Image
 
 
6
 
7
+ MODEL_PATH = "densenet121-a639ec97.pth" # rename your .pth file to model.pth
 
 
 
 
 
 
8
 
9
  # ------------------------------------------
10
  # Load Model with Legacy-Key Compatibility
11
  # ------------------------------------------
12
  def load_legacy_densenet(path):
13
  print("Loading model:", path)
 
 
14
  state_dict = torch.load(path, map_location="cpu")
15
 
 
16
  model = models.densenet121(weights=None)
17
 
18
+ # Legacy checkpoints must use strict=False
19
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
20
 
21
  print("\n=== LOADING SUMMARY ===")
22
  print("Missing keys:", len(missing))
23
  print("Unexpected keys:", len(unexpected))
24
 
 
 
 
 
 
25
  model.eval()
26
  return model
27
 
 
47
  img_tensor = transform(image).unsqueeze(0)
48
 
49
  with torch.no_grad():
50
+ logits = model(img_tensor)
51
+ probs = F.softmax(logits, dim=1)
52
  top5_prob, top5_idx = torch.topk(probs, 5)
53
 
54
  results = []
55
  for p, idx in zip(top5_prob[0], top5_idx[0]):
 
56
  results.append({
57
+ "class_index": int(idx.item()),
58
  "confidence": float(p.item())
59
  })
60
 
61
  return results
62
 
63
  # ------------------------------------------
64
+ # Gradio UI
65
  # ------------------------------------------
66
+ demo = gr.Interface(
67
  fn=predict,
68
  inputs=gr.Image(type="numpy", label="Upload Image"),
69
+ outputs=gr.JSON(label="Top-5 Predictions (Class Index Only)"),
70
+ title="DenseNet-121 Model Classifier (No Label File)",
71
+ description="This app uses a legacy DenseNet121 .pth model and returns top-5 class indices + confidence.",
72
  )
73
 
74
+ demo.launch()