| """ |
| triton_svd3.py — Fused Triton SVD kernel for batched M×3 matrices. |
| |
| Target use case: CIFAR-sized images (32×32=1024 pixels, 3 channels) |
| where cuSOLVER overhead dominates because the "thin" dimension is only 3. |
| |
| Architecture: |
| 1. Each program handles one (M×3) matrix from the batch |
| 2. Compute 3×3 Gram matrix G = A^T A (6 unique values, in registers) |
| 3. Diagonalize G via Jacobi rotations (3×3 converges in ≤6 sweeps) |
| 4. S = sqrt(eigenvalues), V = eigenvectors |
| 5. U = A @ V @ diag(1/S) (tiled reduction over M) |
| |
| The entire 3×3 eigensolver lives in scalar registers — zero shared memory, |
| zero global memory round-trips. The only bandwidth cost is loading A (twice: |
| once to build the Gram matrix, once to form U) and writing back U, S, Vh. |
| |
| Author: AbstractPhil / Claude |
| """ |
|
|
| import triton |
| import triton.language as tl |
| import torch |
| import math |
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _svd3_kernel(
    A_ptr,
    U_ptr,
    S_ptr,
    Vh_ptr,
    M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    JACOBI_ITERS: tl.constexpr,
    EPS: tl.constexpr,
):
    """
    One program instance = one (M, 3) matrix in the batch.

    Addressing assumes densely packed, row-major buffers: element (b, r, c)
    of A and U is at offset b*M*3 + r*3 + c, S[b, k] is at b*3 + k and
    Vh[b, i, j] is at b*9 + i*3 + j.

    Steps:
      1. Accumulate the 6 unique entries of G = A^T A over tiles of BLOCK_M rows.
      2. Run JACOBI_ITERS cyclic Jacobi sweeps (pairs (0,1), (0,2), (1,2)) on G,
         accumulating the rotations into V.
      3. S = sqrt(max(diag(G), EPS)), sorted descending together with the
         columns of V.
      4. Store S and Vh = V^T, then make a second pass over A to write
         U = A @ V @ diag(1 / (S + EPS)).

    Args:
        A_ptr: input matrices (values are converted to float32 after loading).
        U_ptr, S_ptr, Vh_ptr: output buffers.
        M: number of rows per matrix.
        BLOCK_M: rows processed per tile (size of the tl.arange used for tiling).
        JACOBI_ITERS: number of Jacobi sweeps.
        EPS: threshold below which an off-diagonal entry is treated as zero;
            also the floor for eigenvalues and the guard added to S before
            dividing.
    """
    # tl.program_id is 32-bit; widen it so that bid * M * 3 (and every pointer
    # offset derived from it) cannot wrap for very large batches.
    bid = tl.program_id(0).to(tl.int64)

    # ---- Pass 1: Gram matrix G = A^T A (symmetric, 6 unique entries) --------
    g00 = tl.zeros([], dtype=tl.float32)
    g01 = tl.zeros([], dtype=tl.float32)
    g02 = tl.zeros([], dtype=tl.float32)
    g11 = tl.zeros([], dtype=tl.float32)
    g12 = tl.zeros([], dtype=tl.float32)
    g22 = tl.zeros([], dtype=tl.float32)

    # Offset of this program's matrix; identical for A and U.
    base = bid * M * 3
    offs = tl.arange(0, BLOCK_M)

    for block_start in range(0, M, BLOCK_M):
        row_idx = block_start + offs
        mask = row_idx < M  # the last tile may run past M
        row_off = base + row_idx * 3

        # Masked-out rows load as 0 and therefore add nothing to the sums.
        a0 = tl.load(A_ptr + row_off + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + row_off + 1, mask=mask, other=0.0).to(tl.float32)
        a2 = tl.load(A_ptr + row_off + 2, mask=mask, other=0.0).to(tl.float32)

        g00 += tl.sum(a0 * a0)
        g01 += tl.sum(a0 * a1)
        g02 += tl.sum(a0 * a2)
        g11 += tl.sum(a1 * a1)
        g12 += tl.sum(a1 * a2)
        g22 += tl.sum(a2 * a2)

    # ---- Jacobi eigensolver on the 3x3 Gram matrix ---------------------------
    # V starts as the identity and accumulates every rotation.
    v00 = 1.0
    v01 = 0.0
    v02 = 0.0
    v10 = 0.0
    v11 = 1.0
    v12 = 0.0
    v20 = 0.0
    v21 = 0.0
    v22 = 1.0

    for _sweep in range(JACOBI_ITERS):
        # -- Rotation in the (0, 1) plane: zero g01 --
        off_diag = g01
        diag_diff = g11 - g00
        abs_off = tl.abs(off_diag)
        # When the off-diagonal is already ~0 the rotation is skipped (t = 0,
        # so c = 1 and s = 0); tl.where also discards the division by ~0.
        tau_01 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
        abs_tau = tl.abs(tau_01)
        t_01 = tl.where(
            abs_off > EPS,
            tl.where(tau_01 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_01 * tau_01)),
            0.0,
        )
        c01 = 1.0 / tl.sqrt(1.0 + t_01 * t_01)
        s01 = t_01 * c01

        # Apply the rotation to G (both sides).
        new_g00 = c01*c01*g00 - 2.0*s01*c01*g01 + s01*s01*g11
        new_g11 = s01*s01*g00 + 2.0*s01*c01*g01 + c01*c01*g11
        new_g01 = 0.0
        new_g02 = c01*g02 - s01*g12
        new_g12 = s01*g02 + c01*g12

        g00 = new_g00
        g11 = new_g11
        g01 = new_g01
        g02 = new_g02
        g12 = new_g12

        # Accumulate into columns 0 and 1 of V.
        nv00 = c01*v00 - s01*v01
        nv01 = s01*v00 + c01*v01
        nv10 = c01*v10 - s01*v11
        nv11 = s01*v10 + c01*v11
        nv20 = c01*v20 - s01*v21
        nv21 = s01*v20 + c01*v21
        v00 = nv00
        v01 = nv01
        v10 = nv10
        v11 = nv11
        v20 = nv20
        v21 = nv21

        # -- Rotation in the (0, 2) plane: zero g02 --
        off_diag = g02
        diag_diff = g22 - g00
        abs_off = tl.abs(off_diag)
        tau_02 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
        abs_tau = tl.abs(tau_02)
        t_02 = tl.where(
            abs_off > EPS,
            tl.where(tau_02 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_02 * tau_02)),
            0.0,
        )
        c02 = 1.0 / tl.sqrt(1.0 + t_02 * t_02)
        s02 = t_02 * c02

        new_g00 = c02*c02*g00 - 2.0*s02*c02*g02 + s02*s02*g22
        new_g22 = s02*s02*g00 + 2.0*s02*c02*g02 + c02*c02*g22
        new_g02 = 0.0
        new_g01 = c02*g01 - s02*g12
        new_g12_b = s02*g01 + c02*g12
        g00 = new_g00
        g22 = new_g22
        g02 = new_g02
        g01 = new_g01
        g12 = new_g12_b

        # Accumulate into columns 0 and 2 of V.
        nv00 = c02*v00 - s02*v02
        nv02 = s02*v00 + c02*v02
        nv10 = c02*v10 - s02*v12
        nv12 = s02*v10 + c02*v12
        nv20 = c02*v20 - s02*v22
        nv22 = s02*v20 + c02*v22
        v00 = nv00
        v02 = nv02
        v10 = nv10
        v12 = nv12
        v20 = nv20
        v22 = nv22

        # -- Rotation in the (1, 2) plane: zero g12 --
        off_diag = g12
        diag_diff = g22 - g11
        abs_off = tl.abs(off_diag)
        tau_12 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
        abs_tau = tl.abs(tau_12)
        t_12 = tl.where(
            abs_off > EPS,
            tl.where(tau_12 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_12 * tau_12)),
            0.0,
        )
        c12 = 1.0 / tl.sqrt(1.0 + t_12 * t_12)
        s12 = t_12 * c12

        new_g11 = c12*c12*g11 - 2.0*s12*c12*g12 + s12*s12*g22
        new_g22 = s12*s12*g11 + 2.0*s12*c12*g12 + c12*c12*g22
        new_g12 = 0.0
        new_g01 = c12*g01 - s12*g02
        new_g02_b = s12*g01 + c12*g02
        g11 = new_g11
        g22 = new_g22
        g12 = new_g12
        g01 = new_g01
        g02 = new_g02_b

        # Accumulate into columns 1 and 2 of V.
        nv01 = c12*v01 - s12*v02
        nv02 = s12*v01 + c12*v02
        nv11 = c12*v11 - s12*v12
        nv12 = s12*v11 + c12*v12
        nv21 = c12*v21 - s12*v22
        nv22 = s12*v21 + c12*v22
        v01 = nv01
        v02 = nv02
        v11 = nv11
        v12 = nv12
        v21 = nv21
        v22 = nv22

    # ---- Singular values ------------------------------------------------------
    # The diagonal of G now holds the eigenvalue estimates. Floor them at EPS
    # so the square root never sees a non-positive value.
    eig0 = tl.maximum(g00, EPS)
    eig1 = tl.maximum(g11, EPS)
    eig2 = tl.maximum(g22, EPS)

    s0 = tl.sqrt(eig0)
    s1 = tl.sqrt(eig1)
    s2 = tl.sqrt(eig2)

    # ---- Sort descending ------------------------------------------------------
    # Three compare-exchange steps, (0,1), (0,2), (1,2), give s0 >= s1 >= s2.
    # Each exchange also swaps the corresponding columns of V.
    do_swap = s0 < s1
    s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
    tv00 = v00
    tv10 = v10
    tv20 = v20
    v00 = tl.where(do_swap, v01, v00)
    v01 = tl.where(do_swap, tv00, v01)
    v10 = tl.where(do_swap, v11, v10)
    v11 = tl.where(do_swap, tv10, v11)
    v20 = tl.where(do_swap, v21, v20)
    v21 = tl.where(do_swap, tv20, v21)

    do_swap = s0 < s2
    s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)
    tv00 = v00
    tv10 = v10
    tv20 = v20
    v00 = tl.where(do_swap, v02, v00)
    v02 = tl.where(do_swap, tv00, v02)
    v10 = tl.where(do_swap, v12, v10)
    v12 = tl.where(do_swap, tv10, v12)
    v20 = tl.where(do_swap, v22, v20)
    v22 = tl.where(do_swap, tv20, v22)

    do_swap = s1 < s2
    s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)
    tv01 = v01
    tv11 = v11
    tv21 = v21
    v01 = tl.where(do_swap, v02, v01)
    v02 = tl.where(do_swap, tv01, v02)
    v11 = tl.where(do_swap, v12, v11)
    v12 = tl.where(do_swap, tv11, v12)
    v21 = tl.where(do_swap, v22, v21)
    v22 = tl.where(do_swap, tv21, v22)

    # ---- Store S --------------------------------------------------------------
    s_base = bid * 3
    tl.store(S_ptr + s_base + 0, s0)
    tl.store(S_ptr + s_base + 1, s1)
    tl.store(S_ptr + s_base + 2, s2)

    # ---- Store Vh = V^T: row i of Vh is column i of V -------------------------
    vh_base = bid * 9
    tl.store(Vh_ptr + vh_base + 0, v00)
    tl.store(Vh_ptr + vh_base + 1, v10)
    tl.store(Vh_ptr + vh_base + 2, v20)
    tl.store(Vh_ptr + vh_base + 3, v01)
    tl.store(Vh_ptr + vh_base + 4, v11)
    tl.store(Vh_ptr + vh_base + 5, v21)
    tl.store(Vh_ptr + vh_base + 6, v02)
    tl.store(Vh_ptr + vh_base + 7, v12)
    tl.store(Vh_ptr + vh_base + 8, v22)

    # ---- Pass 2: U = A @ V @ diag(1 / S) --------------------------------------
    # EPS in the denominator guards the division for (near-)zero singular values.
    inv_s0 = 1.0 / (s0 + EPS)
    inv_s1 = 1.0 / (s1 + EPS)
    inv_s2 = 1.0 / (s2 + EPS)

    for block_start in range(0, M, BLOCK_M):
        row_idx = block_start + offs
        mask = row_idx < M
        row_off = base + row_idx * 3

        a0 = tl.load(A_ptr + row_off + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + row_off + 1, mask=mask, other=0.0).to(tl.float32)
        a2 = tl.load(A_ptr + row_off + 2, mask=mask, other=0.0).to(tl.float32)

        # Column k of U is (A @ V[:, k]) / s_k.
        u0 = (a0 * v00 + a1 * v10 + a2 * v20) * inv_s0
        u1 = (a0 * v01 + a1 * v11 + a2 * v21) * inv_s1
        u2 = (a0 * v02 + a1 * v12 + a2 * v22) * inv_s2

        tl.store(U_ptr + row_off + 0, u0, mask=mask)
        tl.store(U_ptr + row_off + 1, u1, mask=mask)
        tl.store(U_ptr + row_off + 2, u2, mask=mask)
|
|
|
|
| |
| |
| |
|
|
def batched_svd3(
    A: torch.Tensor,
    block_m: int = 128,
    jacobi_iters: int = 6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Batched thin SVD for (B, M, 3) tensors, computed in float32.

    The input is made contiguous and converted to float32 before the kernel
    launch; all outputs are float32 regardless of the input dtype.

    Args:
        A: Input tensor of shape (B, M, 3) on a CUDA device. M can be anything
            (1024 for CIFAR).
        block_m: Tile size for the spatial dimension. Must be a positive power
            of two (it sizes a `tl.arange` in the kernel). 128 is good for
            M=1024.
        jacobi_iters: Number of cyclic Jacobi sweeps. 6 is overkill for 3×3.

    Returns:
        U:  (B, M, 3) — thin left singular vectors
        S:  (B, 3)    — singular values, descending
        Vh: (B, 3, 3) — right singular vectors transposed

    Raises:
        ValueError: if `A` is not (B, M, 3), is not a CUDA tensor, or
            `block_m` is not a positive power of two.
    """
    # Explicit raises instead of `assert`: these checks protect the kernel's
    # pointer arithmetic and must survive `python -O`.
    if A.ndim != 3 or A.shape[2] != 3:
        raise ValueError(f"Expected (B, M, 3), got {A.shape}")
    if block_m < 1 or (block_m & (block_m - 1)) != 0:
        raise ValueError(f"block_m must be a positive power of two, got {block_m}")
    if not A.is_cuda:
        raise ValueError("Input must be on CUDA")

    B, M, _ = A.shape

    U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)
    S = torch.empty((B, 3), dtype=torch.float32, device=A.device)
    Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)

    # Nothing to compute for an empty batch; avoid launching an empty grid.
    if B == 0:
        return U, S, Vh

    A_f32 = A.contiguous().float()

    # Launch on the device that owns A, which may not be the current device.
    with torch.cuda.device(A.device):
        _svd3_kernel[(B,)](
            A_f32, U, S, Vh,
            M=M,
            BLOCK_M=block_m,
            JACOBI_ITERS=jacobi_iters,
            EPS=1e-12,
        )

    return U, S, Vh
|
|
|
|
| |
| |
| |
|
|
def _test_correctness(B=256, M=1024):
    """
    Check batched_svd3 against torch.linalg.svd on random (B, M, 3) input.

    Prints the maximum singular-value difference, the maximum reconstruction
    error (alongside the reference's own), and the maximum deviation of
    U^T U from the identity, then asserts each is within tolerance.
    """

    def peak(x):
        # Largest absolute entry, as a Python float.
        return x.abs().max().item()

    A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)

    U, S, Vh = batched_svd3(A)
    U_ref, S_ref, Vh_ref = torch.linalg.svd(A, full_matrices=False)

    # Singular values are directly comparable (both sorted descending).
    s_err = peak(S - S_ref)
    print(f"[correctness] S max error: {s_err:.2e}")

    # Singular vectors are only defined up to sign, so compare through the
    # reconstruction U diag(S) Vh rather than element-wise.
    r_err = peak(A - torch.bmm(U * S.unsqueeze(1), Vh))
    r_ref_err = peak(A - torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref))
    print(f"[correctness] Recon error: {r_err:.2e} (ref: {r_ref_err:.2e})")

    # Columns of U should be orthonormal: U^T U == I.
    identity = torch.eye(3, device="cuda").expand(B, -1, -1)
    orth_err = peak(torch.bmm(U.transpose(1, 2), U) - identity)
    print(f"[correctness] U orthog error: {orth_err:.2e}")

    assert s_err < 1e-3, f"Singular value error {s_err} > 1e-3"
    assert r_err < r_ref_err * 2.0 + 1e-5, (
        f"Recon error {r_err:.2e} > 2x torch ref {r_ref_err:.2e}"
    )
    assert orth_err < 1e-3, f"Orthogonality error {orth_err} > 1e-3"
    print("[correctness] PASSED\n")
|
|
|
|
| def _cuda_timer(fn, warmup=25, iters=100): |
| """ |
| CUDA-event-timed benchmark. Returns (mean_ms, std_ms) over `iters` runs. |
| Uses per-iteration events for proper distribution stats. |
| """ |
| |
| for _ in range(warmup): |
| fn() |
| torch.cuda.synchronize() |
|
|
| starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] |
| ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] |
|
|
| for i in range(iters): |
| starts[i].record() |
| fn() |
| ends[i].record() |
| torch.cuda.synchronize() |
|
|
| times = [starts[i].elapsed_time(ends[i]) for i in range(iters)] |
| t = torch.tensor(times) |
| return t.mean().item(), t.std().item(), t.median().item(), t.min().item(), t.max().item() |
|
|
|
|
def _profile_sweep():
    """
    Full profiling sweep: batch sizes × spatial dims × block sizes.
    Compares Triton SVD3 vs torch.linalg.svd with CUDA event timing.

    Prints five sections (batch scaling, spatial scaling, BLOCK_M tuning,
    throughput, peak memory) and writes the first three sweeps plus the
    device name to `svd3_profile.json` in the current directory.
    """
    import json

    heavy_rule = "=" * 72
    rule = "─" * 72

    def time_both(A, block_m=128):
        # Time the Triton kernel and torch.linalg.svd on the same input.
        tri = _cuda_timer(lambda: batched_svd3(A, block_m=block_m))
        tch = _cuda_timer(lambda: torch.linalg.svd(A, full_matrices=False))
        return tri, tch

    device_name = torch.cuda.get_device_name(0)
    print(heavy_rule)
    print(f" Triton SVD3 Profiling — {device_name}")
    print(f"{heavy_rule}\n")

    # ---- Sweep 1: batch scaling ----------------------------------------------
    print(rule)
    print(" SWEEP 1: Batch scaling (M=1024, BLOCK_M=128)")
    print(rule)
    print(f" {'B':>6} {'Triton ms':>12} {'±std':>8} {'Torch ms':>12} {'±std':>8} {'Speedup':>8}")
    print(f" {'─'*62}")

    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
    M = 1024
    batch_results = []

    for B in batch_sizes:
        A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)

        (tri_mean, tri_std, tri_med, tri_min, _), (tch_mean, tch_std, tch_med, tch_min, _) = (
            time_both(A)
        )
        speedup = tch_mean / (tri_mean + 1e-9)

        print(f" {B:>6} {tri_mean:>10.3f}ms {tri_std:>6.3f} {tch_mean:>10.3f}ms {tch_std:>6.3f} {speedup:>7.2f}x")
        batch_results.append({
            "B": B, "M": M,
            "triton_mean_ms": round(tri_mean, 4), "triton_std_ms": round(tri_std, 4),
            "triton_median_ms": round(tri_med, 4), "triton_min_ms": round(tri_min, 4),
            "torch_mean_ms": round(tch_mean, 4), "torch_std_ms": round(tch_std, 4),
            "torch_median_ms": round(tch_med, 4), "torch_min_ms": round(tch_min, 4),
            "speedup": round(speedup, 3),
        })
        del A
        torch.cuda.empty_cache()

    # ---- Sweep 2: spatial scaling --------------------------------------------
    print(f"\n{rule}")
    print(" SWEEP 2: Spatial scaling (B=1024, BLOCK_M=128)")
    print(rule)
    print(f" {'M':>6} {'Triton ms':>12} {'±std':>8} {'Torch ms':>12} {'±std':>8} {'Speedup':>8}")
    print(f" {'─'*62}")

    spatial_dims = [64, 256, 512, 1024, 2048, 4096]
    B = 1024
    spatial_results = []

    for M in spatial_dims:
        A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)

        (tri_mean, tri_std, *_), (tch_mean, tch_std, *_) = time_both(A)
        speedup = tch_mean / (tri_mean + 1e-9)

        # Label perfect squares with the equivalent image size; isqrt is exact
        # where a float square root could round the wrong way.
        equiv_hw = math.isqrt(M)
        tag = f"~{equiv_hw}x{equiv_hw}" if equiv_hw * equiv_hw == M else f" {M}"
        print(f" {tag:>6} {tri_mean:>10.3f}ms {tri_std:>6.3f} {tch_mean:>10.3f}ms {tch_std:>6.3f} {speedup:>7.2f}x")
        spatial_results.append({
            "B": B, "M": M,
            "triton_mean_ms": round(tri_mean, 4), "torch_mean_ms": round(tch_mean, 4),
            "speedup": round(speedup, 3),
        })
        del A
        torch.cuda.empty_cache()

    # ---- Sweep 3: BLOCK_M tuning ---------------------------------------------
    print(f"\n{rule}")
    print(" SWEEP 3: BLOCK_M tuning (B=4096, M=1024)")
    print(rule)
    print(f" {'BLOCK_M':>8} {'Triton ms':>12} {'±std':>8} {'tiles/img':>10}")
    print(f" {'─'*44}")

    block_sizes = [32, 64, 128, 256, 512, 1024]
    B, M = 4096, 1024
    A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
    block_results = []

    for bm in block_sizes:
        # The lambda is called immediately by _cuda_timer, so capturing the
        # loop variable is safe here.
        tri_mean, tri_std, *_ = _cuda_timer(lambda: batched_svd3(A, block_m=bm))
        n_tiles = (M + bm - 1) // bm
        print(f" {bm:>8} {tri_mean:>10.3f}ms {tri_std:>6.3f} {n_tiles:>10}")
        block_results.append({
            "block_m": bm, "triton_mean_ms": round(tri_mean, 4),
            "triton_std_ms": round(tri_std, 4), "n_tiles": n_tiles,
        })

    del A
    torch.cuda.empty_cache()

    # ---- Sweep 4: throughput -------------------------------------------------
    print(f"\n{rule}")
    print(" SWEEP 4: Throughput (images/sec)")
    print(rule)

    for B in [4096, 16384]:
        A = torch.randn(B, 1024, 3, device="cuda", dtype=torch.float32)
        (tri_mean, *_), (tch_mean, *_) = time_both(A)

        tri_ips = B / (tri_mean / 1000)
        tch_ips = B / (tch_mean / 1000)
        print(f" B={B:>5}: Triton {tri_ips:>12,.0f} img/s | Torch {tch_ips:>12,.0f} img/s")
        del A
        torch.cuda.empty_cache()

    # ---- Peak memory ---------------------------------------------------------
    print(f"\n{rule}")
    print(" MEMORY: Peak allocation (B=4096, M=1024)")
    print(rule)

    B, M = 4096, 1024
    A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)

    torch.cuda.reset_peak_memory_stats()
    outputs = batched_svd3(A)
    torch.cuda.synchronize()
    tri_peak = torch.cuda.max_memory_allocated() / 1024**2
    # Release the Triton outputs before measuring torch; otherwise they stay
    # allocated and inflate the torch peak, biasing the comparison.
    del outputs

    torch.cuda.reset_peak_memory_stats()
    outputs = torch.linalg.svd(A, full_matrices=False)
    torch.cuda.synchronize()
    tch_peak = torch.cuda.max_memory_allocated() / 1024**2
    del outputs, A
    torch.cuda.empty_cache()

    print(f" Triton: {tri_peak:.1f} MB")
    print(f" Torch: {tch_peak:.1f} MB")
    print(f" Ratio: {tch_peak / (tri_peak + 1e-9):.2f}x\n")

    # ---- JSON report ---------------------------------------------------------
    report = {
        "device": device_name,
        "batch_sweep": batch_results,
        "spatial_sweep": spatial_results,
        "block_m_sweep": block_results,
    }
    with open("svd3_profile.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(" Full results written to svd3_profile.json\n")
|
|
|
|
if __name__ == "__main__":
    # Both the correctness test and the profiler allocate on "cuda"; fail with
    # a clear message instead of an error from deep inside torch.
    if not torch.cuda.is_available():
        raise SystemExit("triton_svd3: a CUDA device is required to run the test and profiling sweep")
    _test_correctness()
    _profile_sweep()