# TinyCNN for MNIST (94K params)
This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task. The model was trained on a shifted MNIST handwritten digit dataset and is optimized to be small, fast, and easy to deploy, making it suitable for both research and educational purposes.
## Model Summary
| Attribute | Value |
|---|---|
| Model Name | TinyCNN |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | ~94,410 |
| Architecture | Conv-BN-ReLU ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |
## Architecture Overview

```
Input: 1×28×28
Conv Block 1: Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
Conv Block 2: Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
Conv Block 3: Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
Global Average Pooling
Fully Connected Layer → 10 output classes
```
This architecture emphasizes parameter efficiency while maintaining strong representation capability.
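The per-block output sizes and the parameter total in the summary table can be checked with a minimal sketch. It mirrors the layer configuration used in the Load Model section below (3×3 convolutions with `padding=1`); the `nn.Sequential` arrangement here is just for illustration:

```python
import torch
import torch.nn as nn

# Same layers as TinyCNN, laid out sequentially to trace shapes and count parameters.
blocks = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),    # 28x28 -> 14x14
    nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),   # 14x14 -> 7x7
    nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2), # 7x7  -> 3x3
    nn.AdaptiveAvgPool2d((1, 1)),                                                         # 3x3  -> 1x1
    nn.Flatten(),
    nn.Linear(128, 10),
)

x = torch.randn(1, 1, 28, 28)
print(blocks(x).shape)                              # torch.Size([1, 10])
print(sum(p.numel() for p in blocks.parameters()))  # 94410
```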
## Installation

```bash
pip install torch torchvision
```
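An optional sanity check that both packages import correctly:

```python
# Verify that torch and torchvision are installed and importable.
import torch
import torchvision

print(torch.__version__, torchvision.__version__)
```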
## Load Model From Hub
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.
    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super(TinyCNN, self).__init__()
        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.avgpool(x)        # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # (batch, 128)
        x = self.fc(x)             # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()
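```

If downloading from the raw URL is inconvenient, an alternative sketch is to fetch the checkpoint via `huggingface_hub` (requires `pip install huggingface_hub`). The `repo_id` and `filename` below are inferred from the URL above, so treat them as assumptions:

```python
# Alternative download sketch using huggingface_hub.
# repo_id / filename are inferred from the URL above, not independently verified.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="FinOS-Internship/ShiftedTinyCNN",
    filename="TinyCNN_model_acc_98.97.pth",
)
state_dict = torch.load(ckpt_path, map_location="cpu")

model = TinyCNN(num_classes=10)
model.load_state_dict(state_dict)
model.eval()

# Sanity check: the parameter count should match the summary table (94,410).
print(sum(p.numel() for p in model.parameters()))  # 94410
```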
## Example Inference
```python
import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```