# @title Experiment 8.23 — DeiT-Small + SVD Taps
#
# Proper transformer: 12 layers, 384-d, 6 heads, patch 4×4
# SVD taps at layers 3, 6, 9, 12 (every 3 layers)
# Same detached SVD observation as 8.21 — proven to work
#
# DeiT-Small on CIFAR-100 typically hits ~75-78% from scratch.
# NOTE: in the run logged below the baseline peaked at 42.74%, well short of
# that figure, so the comparison is between two under-trained models.
# Question: does SVD structural observation push it further?
#
# This cell relies on names defined in earlier cells: torch, nn, F, time, T,
# torchvision, DataLoader, device and gram_eigh_svd.
"""
Result: the standard DeiT baseline (no SVD features), trained for 100 epochs,
is beaten. The SVD-feature variant overtakes the baseline's best accuracy by
epoch 52. The wall-clock cost is not insignificant (roughly 22 s vs 5.4 s per
epoch in the log below), but this should improve considerably as the process
improves.

[DATA] CIFAR-100: 50000 train, 10000 val, bs=256

======================================================================
 PHASE 1: DeiT-Small BASELINE (no SVD)
======================================================================
[MODEL] DeiT-Small Baseline: 21,526,372 params

======================================================================
[EXP] deit_small_baseline | 21,526,372 params | 100 epochs
======================================================================
 E  1 | Tr   1.0% Va   1.5% | L=4.636 gap=-0.5 | Best 1.5%@E1 | 5.8s
 E  2 | Tr   1.8% Va   3.1% | L=4.526 gap=-1.3 | Best 3.1%@E2 | 5.4s
 E  3 | Tr   2.9% Va   3.9% | L=4.389 gap=-1.0 | Best 3.9%@E3 | 5.4s
 E  4 | Tr   3.6% Va   4.9% | L=4.335 gap=-1.3 | Best 4.9%@E4 | 5.4s
 E  5 | Tr   4.1% Va   5.3% | L=4.282 gap=-1.2 | Best 5.3%@E5 | 5.4s
 E 10 | Tr  10.5% Va  13.7% | L=3.861 gap=-3.2 | Best 13.7%@E10 | 5.4s
 E 15 | Tr  15.2% Va  20.5% | L=3.585 gap=-5.2 | Best 20.5%@E15 | 5.4s
 E 20 | Tr  18.8% Va  25.2% | L=3.375 gap=-6.4 | Best 25.2%@E20 | 5.4s
 E 25 | Tr  21.9% Va  26.9% | L=3.218 gap=-5.1 | Best 26.9%@E25 | 5.4s
 E 30 | Tr  24.4% Va  30.1% | L=3.075 gap=-5.7 | Best 30.1%@E30 | 5.4s
 E 35 | Tr  27.2% Va  32.8% | L=2.941 gap=-5.6 | Best 32.8%@E35 | 5.4s
 E 40 | Tr  29.5% Va  34.8% | L=2.808 gap=-5.2 | Best 34.8%@E40 | 5.4s
 E 45 | Tr  31.9% Va  36.1% | L=2.685 gap=-4.2 | Best 36.1%@E45 | 5.4s
 E 50 | Tr  35.1% Va  37.5% | L=2.547 gap=-2.4 | Best 37.5%@E50 | 5.4s
 E 55 | Tr  37.6% Va  38.8% | L=2.428 gap=-1.2 | Best 38.8%@E55 | 5.4s
 E 60 | Tr  40.1% Va  40.0% | L=2.313 gap=+0.1 | Best 40.0%@E60 | 5.4s
 E 65 | Tr  42.5% Va  41.1% | L=2.202 gap=+1.4 | Best 41.1%@E65 | 5.4s
 E 70 | Tr  44.8% Va  41.2% | L=2.104 gap=+3.6 | Best 41.2%@E70 | 5.4s
 E 75 | Tr  46.9% Va  41.5% | L=2.023 gap=+5.3 | Best 41.6%@E72 | 5.4s
 E 80 | Tr  48.5% Va  41.7% | L=1.955 gap=+6.8 | Best 41.9%@E79 | 5.4s
 E 85 | Tr  49.6% Va  42.3% | L=1.914 gap=+7.2 | Best 42.5%@E83 | 5.4s
 E 90 | Tr  50.4% Va  42.3% | L=1.877 gap=+8.1 | Best 42.7%@E87 | 5.4s
 E 95 | Tr  50.6% Va  42.5% | L=1.860 gap=+8.1 | Best 42.7%@E87 | 5.4s
 E100 | Tr  50.7% Va  42.5% | L=1.855 gap=+8.2 | Best 42.7%@E87 | 5.4s

[RESULT] deit_small_baseline: Best Val = 42.74% @E87 | Params: 21,526,372

======================================================================
 PHASE 2: DeiT-Small + SVD TAPS
======================================================================
[MODEL] DeiT-Small + SVD: 21,676,900 params
 SVD taps at layers: (3, 6, 9, 12)
 SVD features: 264 = 4×66
 Classifier input: 384 + 264 = 648

======================================================================
[EXP] deit_small_svd | 21,676,900 params | 100 epochs
======================================================================
 E  1 | Tr   2.3% Va   3.7% | L=4.471 gap=-1.3 | Best 3.7%@E1 | 21.4s
 E  2 | Tr   4.4% Va   7.9% | L=4.261 gap=-3.5 | Best 7.9%@E2 | 22.7s
 E  3 | Tr   6.4% Va   9.0% | L=4.108 gap=-2.6 | Best 9.0%@E3 | 21.1s
 E  4 | Tr   8.4% Va  12.0% | L=3.972 gap=-3.6 | Best 12.0%@E4 | 21.1s
 E  5 | Tr  10.1% Va  14.6% | L=3.878 gap=-4.6 | Best 14.6%@E5 | 21.1s
 E 10 | Tr  15.9% Va  20.7% | L=3.533 gap=-4.8 | Best 20.7%@E10 | 21.5s
 E 15 | Tr  19.7% Va  24.7% | L=3.321 gap=-5.0 | Best 24.7%@E15 | 21.4s
 E 20 | Tr  22.8% Va  28.5% | L=3.170 gap=-5.7 | Best 28.5%@E20 | 22.1s
 E 25 | Tr  25.3% Va  30.5% | L=3.023 gap=-5.2 | Best 30.5%@E25 | 21.8s
 E 30 | Tr  28.0% Va  33.2% | L=2.882 gap=-5.2 | Best 33.2%@E30 | 22.0s
 E 35 | Tr  31.0% Va  35.2% | L=2.741 gap=-4.2 | Best 35.2%@E35 | 22.0s
 E 40 | Tr  34.0% Va  37.6% | L=2.588 gap=-3.6 | Best 37.6%@E40 | 21.7s
 E 45 | Tr  37.5% Va  39.1% | L=2.432 gap=-1.6 | Best 39.4%@E44 | 22.0s
 E 50 | Tr  40.1% Va  41.0% | L=2.301 gap=-0.9 | Best 41.2%@E49 | 21.9s
 E 55 | Tr  43.4% Va  43.5% | L=2.138 gap=-0.1 | Best 43.5%@E55 | 22.2s
 E 60 | Tr  46.9% Va  44.1% | L=2.007 gap=+2.8 | Best 44.1%@E60 | 22.2s
 E 65 | Tr  50.0% Va  44.7% | L=1.869 gap=+5.3 | Best 45.0%@E64 | 22.2s
 E 70 | Tr  52.5% Va  45.8% | L=1.768 gap=+6.7 | Best 45.9%@E69 | 23.1s
 E 75 | Tr  54.7% Va  46.3% | L=1.669 gap=+8.4 | Best 46.3%@E75 | 21.8s
 E 80 | Tr  57.3% Va  45.8% | L=1.574 gap=+11.5 | Best 46.3%@E79 | 22.4s
 E 85 | Tr  58.4% Va  46.7% | L=1.532 gap=+11.7 | Best 46.8%@E84 | 22.4s
 E 90 | Tr  59.5% Va  46.8% | L=1.490 gap=+12.7 | Best 46.9%@E87 | 22.4s
 E 95 | Tr  60.0% Va  46.9% | L=1.475 gap=+13.2 | Best 46.9%@E87 | 22.2s
 E100 | Tr  60.6% Va  46.8% | L=1.452 gap=+13.8 | Best 46.9%@E87 | 21.9s

[RESULT] deit_small_svd: Best Val = 46.90% @E87 | Params: 21,676,900

======================================================================
 HEAD-TO-HEAD COMPARISON
======================================================================
 Model                             Val%       Params
 -------------------------------------------------------
 DeiT-Small baseline             42.74%   21,526,372
 DeiT-Small + SVD                46.90%   21,676,900

 SVD contribution: +4.16 points

================================================================================
SCOREBOARD
================================================================================
Experiment                                       Val%     Params  Epoch
--------------------------------------------- ------- ---------- ------
svd_classification_test                        70.92%  3,878,820     93
vit_svd_classification_test                    53.57%  6,705,828     86
deit_small_baseline                            42.74% 21,526,372     87
================================================================================

================================================================================
SCOREBOARD
================================================================================
Experiment                                       Val%     Params  Epoch
--------------------------------------------- ------- ---------- ------
svd_classification_test                        70.92%  3,878,820     93
vit_svd_classification_test                    53.57%  6,705,828     86
deit_small_svd                                 46.90% 21,676,900     87
deit_small_baseline                            42.74% 21,526,372     87
================================================================================
"""


class DeiTSmallSVD(nn.Module):
    """DeiT-Small with SVD observation taps.

    Architecture:
        Patch embed (4×4) → 64 tokens + CLS → 384-d
        12 transformer layers (384-d, 6 heads, MLP ratio 4)
        SVD tap after layers 3, 6, 9, 12
        CLS token + SVD features → classify

    SVD observes token-space structure at 4 depths. The decomposition is
    computed under ``torch.no_grad()``, so no gradient flows through it: the
    transformer learns normally and the SVD summary is only an extra,
    constant-with-respect-to-autograd input to the classifier head. A direct
    consequence is that ``svd_projs`` never receive a gradient and stay at
    their initial values.

    Args:
        num_classes: Number of output logits.
        img_size: Input image side length in pixels.
        patch_size: Side length of each square patch.
        embed_dim: Token width.
        depth: Number of transformer encoder layers.
        n_heads: Attention heads per layer.
        mlp_ratio: Feed-forward width as a multiple of ``embed_dim``.
        dropout: Dropout used for the positional dropout and encoder layers.
            The classifier head uses a fixed 0.1.
        svd_rank: Width ``k`` of the projection fed to the SVD.
        tap_layers: 1-indexed layer numbers after which an SVD tap runs. Must
            be strictly increasing and lie in ``1..depth``; may be empty.

    Raises:
        ValueError: If ``tap_layers`` is not strictly increasing or contains a
            value outside ``1..depth``.
    """

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0,
                 dropout=0.1, svd_rank=32, tap_layers=(3, 6, 9, 12)):
        super().__init__()

        # forward() walks tap_layers with a single running pointer, so an
        # unsorted or out-of-range entry would silently drop a tap and only
        # surface later as a shape mismatch in the head. Fail early instead.
        tap_layers = tuple(tap_layers)
        previous = 0
        for tap in tap_layers:
            if not previous < tap <= depth:
                raise ValueError(
                    f"tap_layers must be strictly increasing within 1..{depth}, "
                    f"got {tap_layers}")
            previous = tap

        self.embed_dim = embed_dim
        self.svd_rank = svd_rank
        self.tap_layers = tap_layers
        self.n_taps = len(tap_layers)
        k = svd_rank

        # ── Patch embedding ──
        self.n_patches = (img_size // patch_size) ** 2  # 64 with the defaults
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2),  # (B, embed_dim, n_patches)
        )
        # NOTE: patch_norm is registered but never applied in forward(). It is
        # kept so the parameter count and state-dict keys match earlier runs.
        self.patch_norm = nn.LayerNorm(embed_dim)

        # CLS token + positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)

        # ── Transformer layers ──
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        # ── SVD projections at tap points ──
        self.svd_projs = nn.ModuleList([
            nn.Linear(embed_dim, k, bias=False)
            for _ in range(self.n_taps)])

        # ── SVD feature extraction ──
        # Per-tap: S_norm(k) + Vh_diag(k) + offdiag(1) + entropy(1) = 2k+2
        # Exposed as attributes so callers do not have to hard-code the width.
        self.svd_feat_dim = 2 * k + 2
        self.total_svd_feat = self.svd_feat_dim * self.n_taps

        # ── Classifier: CLS token + concatenated SVD features ──
        total_dim = embed_dim + self.total_svd_feat
        self.head = nn.Sequential(
            nn.Linear(total_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))

        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear, LayerNorm and Conv2d submodules in place."""
        # Trunc normal for linear, ones/zeros for norms
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _extract_svd_features(self, S, Vh):
        """Compact SVD summary. Same as 8.21.

        Args:
            S: Singular values; assumed to be (B, k) — TODO confirm against
                ``gram_eigh_svd``.
            Vh: Right singular vectors; assumed to be (B, k, k) — TODO confirm.

        Returns:
            Concatenation along the last dim of: sum-normalised singular
            values, the diagonal of ``Vh``, the off-diagonal squared mass of
            ``Vh`` (1 value) and the entropy of the normalised singular values
            (1 value). Non-finite entries are replaced by zero.
        """
        S_safe = S.clamp(min=1e-6)
        s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)
        # Total squared mass minus diagonal squared mass = off-diagonal mass.
        # Clamped at 0 to absorb tiny negative values from rounding.
        vh_offdiag = (Vh.pow(2).sum((-2, -1))
                      - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
        s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)
        out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
        return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch fed to the patch-embedding convolution (3 channels).

        Returns:
            Logits produced by the head from the CLS token concatenated with
            the SVD features of every tap.
        """
        B = x.shape[0]

        # Patch embed + CLS
        patches = self.patch_embed(x).transpose(1, 2)  # (B, n_patches, embed_dim)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, patches], dim=1)  # (B, n_patches+1, embed_dim)
        tokens = self.pos_drop(tokens + self.pos_embed)

        # Transformer layers with SVD taps
        svd_feats = []
        tap_idx = 0
        for layer_idx, layer in enumerate(self.layers):
            tokens = layer(tokens)

            # SVD tap at designated layers (tap_layers is 1-indexed)
            if tap_idx < self.n_taps and (layer_idx + 1) == self.tap_layers[tap_idx]:
                # Project tokens (excluding CLS) to SVD space
                patch_tokens = tokens[:, 1:]  # (B, n_patches, embed_dim)
                h_proj = self.svd_projs[tap_idx](patch_tokens)  # (B, n_patches, k)

                # The decomposition runs in float32 with autocast switched off
                # and without autograd tracking.
                with torch.amp.autocast('cuda', enabled=False):
                    with torch.no_grad():
                        _, S, Vh = gram_eigh_svd(h_proj.float())
                    S = S.clamp(min=1e-6)
                    # clamp() leaves NaN untouched, so scrub non-finite values
                    # explicitly afterwards.
                    S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                    Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
                    svd_feats.append(self._extract_svd_features(S, Vh))
                tap_idx += 1

        # Final norm + CLS token
        tokens = self.norm(tokens)
        cls_out = tokens[:, 0]  # (B, embed_dim)

        # Concatenate CLS + SVD features
        all_feats = torch.cat([cls_out] + svd_feats, dim=-1)
        return self.head(all_feats)


# ── Also build a baseline without SVD for fair comparison ────────────────────

class DeiTSmallBaseline(nn.Module):
    """Same DeiT-Small, no SVD taps. CLS → classify.

    Mirrors ``DeiTSmallSVD`` (same embedding, encoder stack, head shape and
    initialisation scheme) except that the head sees only the CLS token.

    Args:
        num_classes: Number of output logits.
        img_size: Input image side length in pixels.
        patch_size: Side length of each square patch.
        embed_dim: Token width.
        depth: Number of transformer encoder layers.
        n_heads: Attention heads per layer.
        mlp_ratio: Feed-forward width as a multiple of ``embed_dim``.
        dropout: Dropout used for the positional dropout and encoder layers.
            The classifier head uses a fixed 0.1.
    """

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2))
        # NOTE: registered but never applied in forward(); kept so the
        # parameter count and state-dict keys match earlier runs.
        self.patch_norm = nn.LayerNorm(embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))

        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear, LayerNorm and Conv2d submodules in place."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                # Zero the bias exactly as DeiTSmallSVD does, so both models
                # start from the same patch-embedding initialisation scheme.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return logits computed from the final-norm CLS token."""
        B = x.shape[0]
        patches = self.patch_embed(x).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = self.pos_drop(torch.cat([cls, patches], dim=1) + self.pos_embed)
        for layer in self.layers:
            tokens = layer(tokens)
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)


# ── Training loop ────────────────────────────────────────────────────────────

def _make_grad_scaler(enabled):
    """Build a CUDA gradient scaler, preferring the non-deprecated API."""
    scaler_cls = getattr(torch.amp, "GradScaler", None)
    if scaler_cls is not None:
        return scaler_cls('cuda', enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def train_model(model, train_loader, val_loader, device, epochs=100, lr=3e-4, label=""):
    """Train ``model`` with AdamW + cosine schedule and report best val accuracy.

    Runs under CUDA autocast (bf16 when supported, otherwise fp16), clips the
    gradient norm to 1.0 and evaluates on ``val_loader`` after every epoch.
    Progress is printed for the first five epochs, every fifth epoch and the
    last one.

    Args:
        model: Module exposing an ``n_params`` attribute.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        device: Device the model and batches are moved to.
        epochs: Number of epochs; also used as ``T_max`` of the cosine schedule.
        lr: Initial AdamW learning rate.
        label: Experiment name used in log lines and in the result.

    Returns:
        Dict with keys ``experiment``, ``best_val_acc``, ``best_epoch`` and
        ``params``.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # fp16 has a narrow exponent range, so gradients must be loss-scaled or
    # small values flush to zero. With bf16 the scaler is disabled and every
    # scaler call below is a pass-through.
    scaler = _make_grad_scaler(enabled=(amp_dtype == torch.float16))
    best_val = 0.0
    best_epoch = 0

    print(f"\n{'='*70}")
    print(f"[EXP] {label} | {model.n_params:,} params | {epochs} epochs")
    print(f"{'='*70}")

    for epoch in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        correct = total = 0
        loss_sum = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                logits = model(images)
                loss = F.cross_entropy(logits, labels)
            scaler.scale(loss).backward()
            # Clip on true (unscaled) gradients, not on loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            correct += (logits.argmax(-1) == labels).sum().item()
            total += labels.size(0)
            loss_sum += loss.item()
        scheduler.step()
        # Train accuracy is measured in train mode on augmented batches.
        train_acc = 100.0 * correct / total

        model.eval()
        val_correct = val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    logits = model(images)
                val_correct += (logits.argmax(-1) == labels).sum().item()
                val_total += labels.size(0)
        val_acc = 100.0 * val_correct / val_total

        if val_acc > best_val:
            best_val = val_acc
            best_epoch = epoch

        elapsed = time.time() - t0
        gap = train_acc - val_acc
        if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
            print(f" E{epoch:>3} | Tr {train_acc:5.1f}% Va {val_acc:5.1f}%"
                  f" | L={loss_sum/len(train_loader):.3f} gap={gap:+.1f}"
                  f" | Best {best_val:.1f}%@E{best_epoch} | {elapsed:.1f}s")

    print(f"\n[RESULT] {label}: Best Val = {best_val:.2f}% @E{best_epoch} | Params: {model.n_params:,}")
    return {'experiment': label, 'best_val_acc': best_val,
            'best_epoch': best_epoch, 'params': model.n_params}


# ── Launch ───────────────────────────────────────────────────────────────────

# Single source of truth for both loaders and the [DATA] log line.
BATCH_SIZE = 256

tf_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.autoaugment.RandAugment(num_ops=2, magnitude=9),
    T.ToTensor()])
tf_val = T.Compose([T.ToTensor()])

train_ds = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=tf_train)
# The CIFAR-100 test split is used as the validation set.
val_ds = torchvision.datasets.CIFAR100(
    root="./data", train=False, download=True, transform=tf_val)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
    pin_memory=True, drop_last=True, persistent_workers=True)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
    pin_memory=True, persistent_workers=True)
print(f"[DATA] CIFAR-100: {len(train_ds)} train, {len(val_ds)} val, bs={BATCH_SIZE}")

# ── Run baseline first ──
print("\n" + "="*70)
print(" PHASE 1: DeiT-Small BASELINE (no SVD)")
print("="*70)
model_baseline = DeiTSmallBaseline(num_classes=100)
print(f"[MODEL] DeiT-Small Baseline: {model_baseline.n_params:,} params")
result_baseline = train_model(
    model_baseline, train_loader, val_loader, device,
    epochs=100, label="deit_small_baseline")

# ── Then SVD version ──
print("\n" + "="*70)
print(" PHASE 2: DeiT-Small + SVD TAPS")
print("="*70)
model_svd = DeiTSmallSVD(num_classes=100, svd_rank=32)
print(f"[MODEL] DeiT-Small + SVD: {model_svd.n_params:,} params")
print(f" SVD taps at layers: {model_svd.tap_layers}")
# Widths come from the model so the report stays correct for any svd_rank.
print(f" SVD features: {model_svd.total_svd_feat} = {model_svd.n_taps}×{model_svd.svd_feat_dim}")
print(f" Classifier input: {model_svd.embed_dim} + {model_svd.total_svd_feat} = {model_svd.embed_dim + model_svd.total_svd_feat}")
result_svd = train_model(
    model_svd, train_loader, val_loader, device,
    epochs=100, label="deit_small_svd")

# ── Compare ──
print(f"\n{'='*70}")
print(f" HEAD-TO-HEAD COMPARISON")
print(f"{'='*70}")
print(f" {'Model':<30} {'Val%':>7} {'Params':>12}")
print(f" {'-'*55}")
print(f" {'DeiT-Small baseline':<30} {result_baseline['best_val_acc']:>6.2f}% {result_baseline['params']:>12,}")
print(f" {'DeiT-Small + SVD':<30} {result_svd['best_val_acc']:>6.2f}% {result_svd['params']:>12,}")
delta = result_svd['best_val_acc'] - result_baseline['best_val_acc']
print(f"\n SVD contribution: {delta:+.2f} points")