shrinusn77 commited on
Commit
28a21c6
·
verified ·
1 Parent(s): 6048a38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -30
app.py CHANGED
@@ -4,29 +4,49 @@ import torch.nn as nn
4
  from torchvision import models, transforms
5
  from PIL import Image
6
  import json
 
7
 
8
- # --------------------------
9
- # Load the model
10
- # --------------------------
11
- model_path = "densenet121-a639ec97.pth" # put this file in the Space
12
 
13
- state_dict = torch.load(model_path, map_location="cpu")
14
-
15
- model = models.densenet121(weights=None)
16
- model.load_state_dict(state_dict, strict=True)
17
- model.eval()
18
-
19
- # --------------------------
20
- # Load ImageNet class labels
21
- # --------------------------
22
- # Downloaded from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
23
  with open("imagenet_classes.json", "r") as f:
24
  idx_to_class = json.load(f)
25
 
26
- # --------------------------
27
- # Preprocessing function
28
- # --------------------------
29
- preprocess = transforms.Compose([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
32
  transforms.Normalize(
@@ -35,36 +55,37 @@ preprocess = transforms.Compose([
35
  )
36
  ])
37
 
38
- # --------------------------
39
  # Prediction function
40
- # --------------------------
41
  def predict(image):
42
  image = Image.fromarray(image).convert("RGB")
43
- img_tensor = preprocess(image).unsqueeze(0)
44
 
45
  with torch.no_grad():
46
  outputs = model(img_tensor)
47
- probs = torch.nn.functional.softmax(outputs, dim=1)
48
  top5_prob, top5_idx = torch.topk(probs, 5)
49
 
50
  results = []
51
  for p, idx in zip(top5_prob[0], top5_idx[0]):
 
52
  results.append({
53
- "class": idx_to_class[str(idx.item())],
54
  "confidence": float(p.item())
55
  })
56
 
57
  return results
58
 
59
- # --------------------------
60
- # Gradio UI
61
- # --------------------------
62
- demo = gr.Interface(
63
  fn=predict,
64
  inputs=gr.Image(type="numpy", label="Upload Image"),
65
  outputs=gr.JSON(label="Top-5 Predictions"),
66
- title="DenseNet-121 ImageNet Classifier",
67
- description="Upload an image and get top-5 ImageNet predictions.",
68
  )
69
 
70
- demo.launch()
 
4
  from torchvision import models, transforms
5
  from PIL import Image
6
  import json
7
+ import os
8
 
9
+ MODEL_PATH = "densenet121-a639ec97.pth" # rename your .pth to model.pth
 
 
 
10
 
11
+ # ------------------------------------------
12
+ # Load ImageNet Label Mapping
13
+ # ------------------------------------------
 
 
 
 
 
 
 
14
  with open("imagenet_classes.json", "r") as f:
15
  idx_to_class = json.load(f)
16
 
17
+ # ------------------------------------------
18
+ # Load Model with Legacy-Key Compatibility
19
+ # ------------------------------------------
20
+ def load_legacy_densenet(path):
21
+ print("Loading model:", path)
22
+
23
+ # Load state_dict (old torchvision format)
24
+ state_dict = torch.load(path, map_location="cpu")
25
+
26
+ # Create new DenseNet model
27
+ model = models.densenet121(weights=None)
28
+
29
+ # Try loading with strict=False (legacy fix)
30
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
31
+
32
+ print("\n=== LOADING SUMMARY ===")
33
+ print("Missing keys:", len(missing))
34
+ print("Unexpected keys:", len(unexpected))
35
+
36
+ if len(missing) > 0:
37
+ print("⚠ NOTE: Missing keys detected (normal for legacy checkpoint)")
38
+ if len(unexpected) > 0:
39
+ print("⚠ NOTE: Unexpected keys detected (normal for legacy checkpoint)")
40
+
41
+ model.eval()
42
+ return model
43
+
44
+ model = load_legacy_densenet(MODEL_PATH)
45
+
46
+ # ------------------------------------------
47
+ # Preprocessing
48
+ # ------------------------------------------
49
+ transform = transforms.Compose([
50
  transforms.Resize((224, 224)),
51
  transforms.ToTensor(),
52
  transforms.Normalize(
 
55
  )
56
  ])
57
 
58
+ # ------------------------------------------
59
  # Prediction function
60
+ # ------------------------------------------
61
  def predict(image):
62
  image = Image.fromarray(image).convert("RGB")
63
+ img_tensor = transform(image).unsqueeze(0)
64
 
65
  with torch.no_grad():
66
  outputs = model(img_tensor)
67
+ probs = torch.softmax(outputs, dim=1)
68
  top5_prob, top5_idx = torch.topk(probs, 5)
69
 
70
  results = []
71
  for p, idx in zip(top5_prob[0], top5_idx[0]):
72
+ cls_name = idx_to_class.get(str(idx.item()), "Unknown")
73
  results.append({
74
+ "class": cls_name,
75
  "confidence": float(p.item())
76
  })
77
 
78
  return results
79
 
80
+ # ------------------------------------------
81
+ # Gradio Interface
82
+ # ------------------------------------------
83
+ interface = gr.Interface(
84
  fn=predict,
85
  inputs=gr.Image(type="numpy", label="Upload Image"),
86
  outputs=gr.JSON(label="Top-5 Predictions"),
87
+ title="DenseNet-121 Legacy Model Classifier (ImageNet)",
88
+ description="Upload any image. Model returns top-5 ImageNet predictions.",
89
  )
90
 
91
+ interface.launch()