felix2703 commited on
Commit
5a1026e
·
verified ·
1 Parent(s): 1e33668

Update README

Browse files
Files changed (1) hide show
  1. README.md +49 -2
README.md CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  # 🧠 TinyCNN for MNIST (94K params)
2
 
3
  This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task.
@@ -49,12 +57,51 @@ pip install torch torchvision
49
  ## 🚀 Load Model From Hub
50
  ```
51
  import torch
52
- from model import TinyCNN # Ensure this file is included in your repo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  model = TinyCNN(num_classes=10)
55
 
56
  state_dict = torch.hub.load_state_dict_from_url(
57
- "https://huggingface.co/<your-username>/<your-model-repo>/resolve/main/tinycnn_mnist.pth"
58
  )
59
  model.load_state_dict(state_dict)
60
  model.eval()
 
1
+ ---
2
+ datasets:
3
+ - ylecun/mnist
4
+ language:
5
+ - en
6
+ metrics:
7
+ - accuracy
8
+ ---
9
  # 🧠 TinyCNN for MNIST (94K params)
10
 
11
  This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task.
 
57
  ## 🚀 Load Model From Hub
58
  ```
59
  import torch
60
+
61
+ class TinyCNN(nn.Module):
62
+ """
63
+ Tiny CNN for MNIST using Global Avg Pooling.
64
+
65
+ Trainable parameters: 94,410
66
+ """
67
+
68
+ def __init__(self, num_classes=10):
69
+ super(TinyCNN, self).__init__()
70
+
71
+ # First conv block
72
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
73
+ self.bn1 = nn.BatchNorm2d(32)
74
+ self.pool1 = nn.MaxPool2d(2, 2)
75
+
76
+ # Second conv block
77
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
78
+ self.bn2 = nn.BatchNorm2d(64)
79
+ self.pool2 = nn.MaxPool2d(2, 2)
80
+
81
+ # Third conv block
82
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
83
+ self.bn3 = nn.BatchNorm2d(128)
84
+ self.pool3 = nn.MaxPool2d(2, 2)
85
+
86
+ # Global average pooling
87
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
88
+
89
+ # Final FC (input = 128 channels after GAP)
90
+ self.fc = nn.Linear(128, num_classes)
91
+
92
+ def forward(self, x):
93
+ x = self.pool1(F.relu(self.bn1(self.conv1(x))))
94
+ x = self.pool2(F.relu(self.bn2(self.conv2(x))))
95
+ x = self.pool3(F.relu(self.bn3(self.conv3(x))))
96
+ x = self.avgpool(x) # (batch, 64, 1, 1)
97
+ x = x.view(x.size(0), -1) # (batch, 64)
98
+ x = self.fc(x) # (batch, num_classes)
99
+ return x
100
 
101
  model = TinyCNN(num_classes=10)
102
 
103
  state_dict = torch.hub.load_state_dict_from_url(
104
+ "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth"
105
  )
106
  model.load_state_dict(state_dict)
107
  model.eval()