shrinusn77 commited on
Commit
ee80624
·
verified ·
1 Parent(s): b79cb71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+ import json
7
+
8
+ # --------------------------
9
+ # Load the model
10
+ # --------------------------
11
+ model_path = "densenet121-a639ec97.pth" # put this file in the Space
12
+
13
+ state_dict = torch.load(model_path, map_location="cpu")
14
+
15
+ model = models.densenet121(weights=None)
16
+ model.load_state_dict(state_dict, strict=True)
17
+ model.eval()
18
+
19
+ # --------------------------
20
+ # Load ImageNet class labels
21
+ # --------------------------
22
+ # Downloaded from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
23
+ with open("imagenet_classes.json", "r") as f:
24
+ idx_to_class = json.load(f)
25
+
26
+ # --------------------------
27
+ # Preprocessing function
28
+ # --------------------------
29
+ preprocess = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(
33
+ mean=[0.485, 0.456, 0.406],
34
+ std=[0.229, 0.224, 0.225]
35
+ )
36
+ ])
37
+
38
+ # --------------------------
39
+ # Prediction function
40
+ # --------------------------
41
+ def predict(image):
42
+ image = Image.fromarray(image).convert("RGB")
43
+ img_tensor = preprocess(image).unsqueeze(0)
44
+
45
+ with torch.no_grad():
46
+ outputs = model(img_tensor)
47
+ probs = torch.nn.functional.softmax(outputs, dim=1)
48
+ top5_prob, top5_idx = torch.topk(probs, 5)
49
+
50
+ results = []
51
+ for p, idx in zip(top5_prob[0], top5_idx[0]):
52
+ results.append({
53
+ "class": idx_to_class[str(idx.item())],
54
+ "confidence": float(p.item())
55
+ })
56
+
57
+ return results
58
+
59
+ # --------------------------
60
+ # Gradio UI
61
+ # --------------------------
62
+ demo = gr.Interface(
63
+ fn=predict,
64
+ inputs=gr.Image(type="numpy", label="Upload Image"),
65
+ outputs=gr.JSON(label="Top-5 Predictions"),
66
+ title="DenseNet-121 ImageNet Classifier",
67
+ description="Upload an image and get top-5 ImageNet predictions.",
68
+ )
69
+
70
+ demo.launch()