🧠 TinyCNN for MNIST (94K params)

This repository contains a lightweight convolutional neural network (CNN) for MNIST handwritten digit classification. The model was trained on a shifted version of the MNIST handwritten digit dataset. It is designed to be small, fast, and easy to deploy, which makes it suitable for both research and teaching.


📌 Model Summary

| Attribute | Value |
|---|---|
| Model Name | TinyCNN |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | 94,410 |
| Architecture | (Conv-BN-ReLU-MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |

πŸ— Architecture Overview

Input: 1×28×28

Conv Block 1: Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Conv Block 2: Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Conv Block 3: Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Global Average Pooling

Fully Connected Layer → 10 output classes
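
To make the layer list above concrete, here is a minimal sketch that traces the tensor shape through each stage using only arithmetic. It assumes `padding=1` and stride 1 for every convolution and unpadded 2×2 max pooling, as in the model definition later in this card; the helper names (`conv2d_out`, `maxpool_out`, `trace_shapes`) are hypothetical and exist only for this illustration.

```python
def conv2d_out(size: int, kernel: int = 3, padding: int = 1, stride: int = 1) -> int:
    """Spatial output size of a square convolution (floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Spatial output size of unpadded max pooling (floor mode)."""
    return (size - kernel) // stride + 1


def trace_shapes(size: int = 28, channels=(32, 64, 128)):
    """Return (stage name, (channels, height, width)) for each stage."""
    shapes = [("input", (1, size, size))]
    for i, c in enumerate(channels, start=1):
        # Each block: 3x3 conv keeps the size, 2x2 max pool halves it (rounding down).
        size = maxpool_out(conv2d_out(size))
        shapes.append((f"block{i}", (c, size, size)))
    # Global average pooling collapses the spatial dimensions to 1x1.
    shapes.append(("gap", (channels[-1], 1, 1)))
    return shapes


for name, shape in trace_shapes():
    print(f"{name:7s} {shape}")
```

Under these assumptions the spatial size goes 28 → 14 → 7 → 3 before global average pooling, so the fully connected layer receives a 128-dimensional vector.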

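This architecture emphasizes parameter efficiency: global average pooling replaces a flattened dense layer, so the classifier head needs only 128×10 weights plus 10 biases (1,290 parameters), and nearly all of the model's capacity sits in the three convolution blocks.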
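
The parameter total quoted in this card can be checked by hand. The sketch below recomputes it layer by layer from the channel sizes listed above, assuming 3×3 kernels with bias terms and BatchNorm layers with a learnable scale and shift (the PyTorch defaults). The function names are hypothetical helpers for this illustration, not part of the repository.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def bn_params(channels: int) -> int:
    """Learnable scale and shift; running statistics are buffers, not parameters."""
    return 2 * channels


def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out


def tinycnn_param_breakdown(num_classes: int = 10) -> dict:
    """Per-layer trainable parameter counts for the architecture described above."""
    breakdown = {}
    c_in = 1
    for i, c_out in enumerate((32, 64, 128), start=1):
        breakdown[f"conv{i}"] = conv_params(c_in, c_out)
        breakdown[f"bn{i}"] = bn_params(c_out)
        c_in = c_out
    breakdown["fc"] = linear_params(c_in, num_classes)
    return breakdown


breakdown = tinycnn_param_breakdown()
for name, count in breakdown.items():
    print(f"{name:6s} {count:>7,}")
print(f"{'total':6s} {sum(breakdown.values()):>7,}")
```

The third convolution alone accounts for 73,856 of the parameters, while the classifier head contributes 1,290.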


βš™οΈ Installation

pip install torch torchvision

🚀 Load Model From Hub

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.

    Trainable parameters: 94,410
    """
    
    def __init__(self, num_classes=10):
        super(TinyCNN, self).__init__()
        
        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.avgpool(x)              # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)        # (batch, 128)
        x = self.fc(x)                   # (batch, num_classes)
        return x

model = TinyCNN(num_classes=10)

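# Note: Hugging Face serves raw files from URLs of the form
# https://huggingface.co/<org>/<repo>/resolve/<revision>/<filename>;
# if the URL below does not download, copy the direct link from the repository's file list.
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",  # load onto CPU regardless of where the checkpoint was saved
)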
model.load_state_dict(state_dict)
model.eval()

🖼 Example Inference

import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

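# MNIST digits are light strokes on a dark background;
# invert the image first if yours is dark-on-light.
img = Image.open("digit.png")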
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
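
The example above reports only the top class. If a confidence score is also wanted, the logits can be turned into probabilities with a softmax. The following is a minimal pure-Python sketch of that step; `softmax` and `top_k` are hypothetical helpers written for this illustration, and the logit values are made up rather than taken from the model.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (digit, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only; real values come from model(x).
example_logits = [0.1, 0.2, 8.5, 0.3, -1.0, 0.0, 0.4, 2.1, 0.2, -0.5]

for digit, prob in top_k(example_logits):
    print(f"digit {digit}: {prob:.4f}")
```

With the real model, the input would be `logits[0].tolist()` from the inference snippet. In PyTorch itself, `torch.softmax(logits, dim=1)` computes the same probabilities in one call.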