Create svd_triton.py
Browse files- svd_triton.py +576 -0
svd_triton.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
triton_svd3.py — Fused Triton SVD kernel for batched M×3 matrices.
|
| 3 |
+
|
| 4 |
+
Target use case: CIFAR-sized images (32×32=1024 pixels, 3 channels)
|
| 5 |
+
where cuSOLVER overhead dominates because the "thin" dimension is only 3.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
1. Each program handles one (M×3) matrix from the batch
|
| 9 |
+
2. Compute 3×3 Gram matrix G = A^T A (6 unique values, in registers)
|
| 10 |
+
3. Diagonalize G via Jacobi rotations (3×3 converges in ≤6 sweeps)
|
| 11 |
+
4. S = sqrt(eigenvalues), V = eigenvectors
|
| 12 |
+
5. U = A @ V @ diag(1/S) (tiled reduction over M)
|
| 13 |
+
|
| 14 |
+
The entire 3×3 eigensolver lives in scalar registers — zero shared memory,
|
| 15 |
+
zero global memory round-trips. The only bandwidth cost is loading A and
|
| 16 |
+
writing back U, S, Vh.
|
| 17 |
+
|
| 18 |
+
Author: AbstractPhil / Claude
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import triton
|
| 22 |
+
import triton.language as tl
|
| 23 |
+
import torch
|
| 24 |
+
import math
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ============================================================================
|
| 28 |
+
# Core kernel: batched SVD for (B, M, 3) tensors
|
| 29 |
+
# ============================================================================
|
| 30 |
+
|
| 31 |
+
@triton.jit
|
| 32 |
+
def _svd3_kernel(
|
| 33 |
+
# Pointers
|
| 34 |
+
A_ptr, # (B, M, 3) input
|
| 35 |
+
U_ptr, # (B, M, 3) left singular vectors (thin)
|
| 36 |
+
S_ptr, # (B, 3) singular values
|
| 37 |
+
Vh_ptr, # (B, 3, 3) right singular vectors transposed
|
| 38 |
+
# Dimensions
|
| 39 |
+
M: tl.constexpr, # spatial dim (1024 for CIFAR)
|
| 40 |
+
BLOCK_M: tl.constexpr, # tile size for M-dimension loads
|
| 41 |
+
JACOBI_ITERS: tl.constexpr, # number of full Jacobi sweeps
|
| 42 |
+
EPS: tl.constexpr, # numerical floor
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
One program instance = one (M, 3) matrix in the batch.
|
| 46 |
+
"""
|
| 47 |
+
bid = tl.program_id(0) # batch index
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------
|
| 50 |
+
# Stage 1: Compute 3×3 Gram matrix G = A^T @ A
|
| 51 |
+
#
|
| 52 |
+
# G is symmetric, so we only need 6 accumulators:
|
| 53 |
+
# g00 g01 g02
|
| 54 |
+
# g11 g12
|
| 55 |
+
# g22
|
| 56 |
+
# ---------------------------------------------------------------
|
| 57 |
+
g00 = tl.zeros([], dtype=tl.float32)
|
| 58 |
+
g01 = tl.zeros([], dtype=tl.float32)
|
| 59 |
+
g02 = tl.zeros([], dtype=tl.float32)
|
| 60 |
+
g11 = tl.zeros([], dtype=tl.float32)
|
| 61 |
+
g12 = tl.zeros([], dtype=tl.float32)
|
| 62 |
+
g22 = tl.zeros([], dtype=tl.float32)
|
| 63 |
+
|
| 64 |
+
base = bid * M * 3
|
| 65 |
+
|
| 66 |
+
# Tiled accumulation over the spatial dimension
|
| 67 |
+
for block_start in range(0, M, BLOCK_M):
|
| 68 |
+
offs = tl.arange(0, BLOCK_M) # [0, 1, ..., BLOCK_M-1]
|
| 69 |
+
row_idx = block_start + offs # actual row indices
|
| 70 |
+
mask = row_idx < M
|
| 71 |
+
|
| 72 |
+
# Load 3 columns for this tile
|
| 73 |
+
# A[bid, row, c] = A_ptr[bid*M*3 + row*3 + c]
|
| 74 |
+
ptr0 = base + row_idx * 3 + 0
|
| 75 |
+
ptr1 = base + row_idx * 3 + 1
|
| 76 |
+
ptr2 = base + row_idx * 3 + 2
|
| 77 |
+
|
| 78 |
+
a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)
|
| 79 |
+
a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)
|
| 80 |
+
a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)
|
| 81 |
+
|
| 82 |
+
# Accumulate outer products
|
| 83 |
+
g00 += tl.sum(a0 * a0)
|
| 84 |
+
g01 += tl.sum(a0 * a1)
|
| 85 |
+
g02 += tl.sum(a0 * a2)
|
| 86 |
+
g11 += tl.sum(a1 * a1)
|
| 87 |
+
g12 += tl.sum(a1 * a2)
|
| 88 |
+
g22 += tl.sum(a2 * a2)
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------
|
| 91 |
+
# Stage 2: 3×3 Jacobi eigendecomposition (all in scalars)
|
| 92 |
+
#
|
| 93 |
+
# Cyclic Jacobi: rotate pairs (0,1), (0,2), (1,2) each sweep.
|
| 94 |
+
# For 3×3 symmetric PSD, 4-6 sweeps is overkill-level convergence.
|
| 95 |
+
#
|
| 96 |
+
# We maintain:
|
| 97 |
+
# - g_ij: the evolving matrix entries (symmetric, 6 values)
|
| 98 |
+
# - v_ij: the eigenvector matrix (9 values, starts as I)
|
| 99 |
+
# ---------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
# Eigenvector accumulator V (row-major: v_rc = V[r,c])
|
| 102 |
+
v00 = 1.0; v01 = 0.0; v02 = 0.0
|
| 103 |
+
v10 = 0.0; v11 = 1.0; v12 = 0.0
|
| 104 |
+
v20 = 0.0; v21 = 0.0; v22 = 1.0
|
| 105 |
+
|
| 106 |
+
for _sweep in range(JACOBI_ITERS):
|
| 107 |
+
# --- Rotation (p=0, q=1): zero out g01 ---
|
| 108 |
+
#
|
| 109 |
+
# Jacobi rotation angle:
|
| 110 |
+
# if |g_pq| < eps: skip
|
| 111 |
+
# tau = (g_qq - g_pp) / (2 * g_pq)
|
| 112 |
+
# t = sign(tau) / (|tau| + sqrt(1 + tau^2))
|
| 113 |
+
# c = 1/sqrt(1+t^2), s = t*c
|
| 114 |
+
|
| 115 |
+
# -- pair (0, 1) --
|
| 116 |
+
off_diag = g01
|
| 117 |
+
diag_diff = g11 - g00
|
| 118 |
+
abs_off = tl.abs(off_diag)
|
| 119 |
+
# Compute rotation
|
| 120 |
+
tau_01 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
|
| 121 |
+
abs_tau = tl.abs(tau_01)
|
| 122 |
+
t_01 = tl.where(
|
| 123 |
+
abs_off > EPS,
|
| 124 |
+
tl.where(tau_01 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_01 * tau_01)),
|
| 125 |
+
0.0,
|
| 126 |
+
)
|
| 127 |
+
c01 = 1.0 / tl.sqrt(1.0 + t_01 * t_01)
|
| 128 |
+
s01 = t_01 * c01
|
| 129 |
+
|
| 130 |
+
# Apply Givens to G (symmetric update for p=0, q=1)
|
| 131 |
+
new_g00 = c01*c01*g00 - 2.0*s01*c01*g01 + s01*s01*g11
|
| 132 |
+
new_g11 = s01*s01*g00 + 2.0*s01*c01*g01 + c01*c01*g11
|
| 133 |
+
new_g01 = 0.0 # This is the point
|
| 134 |
+
new_g02 = c01*g02 - s01*g12
|
| 135 |
+
new_g12 = s01*g02 + c01*g12
|
| 136 |
+
# g22 unchanged
|
| 137 |
+
g00 = new_g00; g11 = new_g11; g01 = new_g01
|
| 138 |
+
g02 = new_g02; g12 = new_g12
|
| 139 |
+
|
| 140 |
+
# Apply to V columns: V[:, 0], V[:, 1]
|
| 141 |
+
nv00 = c01*v00 - s01*v01; nv01 = s01*v00 + c01*v01
|
| 142 |
+
nv10 = c01*v10 - s01*v11; nv11 = s01*v10 + c01*v11
|
| 143 |
+
nv20 = c01*v20 - s01*v21; nv21 = s01*v20 + c01*v21
|
| 144 |
+
v00 = nv00; v01 = nv01; v10 = nv10; v11 = nv11; v20 = nv20; v21 = nv21
|
| 145 |
+
|
| 146 |
+
# -- pair (0, 2) --
|
| 147 |
+
off_diag = g02
|
| 148 |
+
diag_diff = g22 - g00
|
| 149 |
+
abs_off = tl.abs(off_diag)
|
| 150 |
+
tau_02 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
|
| 151 |
+
abs_tau = tl.abs(tau_02)
|
| 152 |
+
t_02 = tl.where(
|
| 153 |
+
abs_off > EPS,
|
| 154 |
+
tl.where(tau_02 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_02 * tau_02)),
|
| 155 |
+
0.0,
|
| 156 |
+
)
|
| 157 |
+
c02 = 1.0 / tl.sqrt(1.0 + t_02 * t_02)
|
| 158 |
+
s02 = t_02 * c02
|
| 159 |
+
|
| 160 |
+
new_g00 = c02*c02*g00 - 2.0*s02*c02*g02 + s02*s02*g22
|
| 161 |
+
new_g22 = s02*s02*g00 + 2.0*s02*c02*g02 + c02*c02*g22
|
| 162 |
+
new_g02 = 0.0
|
| 163 |
+
new_g01 = c02*g01 - s02*g12
|
| 164 |
+
new_g12_b = s02*g01 + c02*g12
|
| 165 |
+
g00 = new_g00; g22 = new_g22; g02 = new_g02
|
| 166 |
+
g01 = new_g01; g12 = new_g12_b
|
| 167 |
+
|
| 168 |
+
nv00 = c02*v00 - s02*v02; nv02 = s02*v00 + c02*v02
|
| 169 |
+
nv10 = c02*v10 - s02*v12; nv12 = s02*v10 + c02*v12
|
| 170 |
+
nv20 = c02*v20 - s02*v22; nv22 = s02*v20 + c02*v22
|
| 171 |
+
v00 = nv00; v02 = nv02; v10 = nv10; v12 = nv12; v20 = nv20; v22 = nv22
|
| 172 |
+
|
| 173 |
+
# -- pair (1, 2) --
|
| 174 |
+
off_diag = g12
|
| 175 |
+
diag_diff = g22 - g11
|
| 176 |
+
abs_off = tl.abs(off_diag)
|
| 177 |
+
tau_12 = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
|
| 178 |
+
abs_tau = tl.abs(tau_12)
|
| 179 |
+
t_12 = tl.where(
|
| 180 |
+
abs_off > EPS,
|
| 181 |
+
tl.where(tau_12 >= 0, 1.0, -1.0) / (abs_tau + tl.sqrt(1.0 + tau_12 * tau_12)),
|
| 182 |
+
0.0,
|
| 183 |
+
)
|
| 184 |
+
c12 = 1.0 / tl.sqrt(1.0 + t_12 * t_12)
|
| 185 |
+
s12 = t_12 * c12
|
| 186 |
+
|
| 187 |
+
new_g11 = c12*c12*g11 - 2.0*s12*c12*g12 + s12*s12*g22
|
| 188 |
+
new_g22 = s12*s12*g11 + 2.0*s12*c12*g12 + c12*c12*g22
|
| 189 |
+
new_g12 = 0.0
|
| 190 |
+
new_g01 = c12*g01 - s12*g02
|
| 191 |
+
new_g02_b = s12*g01 + c12*g02
|
| 192 |
+
g11 = new_g11; g22 = new_g22; g12 = new_g12
|
| 193 |
+
g01 = new_g01; g02 = new_g02_b
|
| 194 |
+
|
| 195 |
+
nv01 = c12*v01 - s12*v02; nv02 = s12*v01 + c12*v02
|
| 196 |
+
nv11 = c12*v11 - s12*v12; nv12 = s12*v11 + c12*v12
|
| 197 |
+
nv21 = c12*v21 - s12*v22; nv22 = s12*v21 + c12*v22
|
| 198 |
+
v01 = nv01; v02 = nv02; v11 = nv11; v12 = nv12; v21 = nv21; v22 = nv22
|
| 199 |
+
|
| 200 |
+
# ---------------------------------------------------------------
|
| 201 |
+
# Stage 2b: Extract eigenvalues, sort descending
|
| 202 |
+
# Diagonal of G now holds eigenvalues of A^T A
|
| 203 |
+
# ---------------------------------------------------------------
|
| 204 |
+
eig0 = tl.maximum(g00, EPS)
|
| 205 |
+
eig1 = tl.maximum(g11, EPS)
|
| 206 |
+
eig2 = tl.maximum(g22, EPS)
|
| 207 |
+
|
| 208 |
+
s0 = tl.sqrt(eig0)
|
| 209 |
+
s1 = tl.sqrt(eig1)
|
| 210 |
+
s2 = tl.sqrt(eig2)
|
| 211 |
+
|
| 212 |
+
# Sorting network for 3 elements (descending)
|
| 213 |
+
# We need to sort S and permute V columns accordingly.
|
| 214 |
+
# Approach: compare-and-swap on the scalar eigenvalues + V columns
|
| 215 |
+
#
|
| 216 |
+
# 3-element sorting network: (0,1), (0,2), (1,2)
|
| 217 |
+
|
| 218 |
+
# swap(0, 1) if s0 < s1
|
| 219 |
+
do_swap = s0 < s1
|
| 220 |
+
s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
|
| 221 |
+
# Swap V columns 0 and 1
|
| 222 |
+
tv00 = v00; tv10 = v10; tv20 = v20
|
| 223 |
+
v00 = tl.where(do_swap, v01, v00); v01 = tl.where(do_swap, tv00, v01)
|
| 224 |
+
v10 = tl.where(do_swap, v11, v10); v11 = tl.where(do_swap, tv10, v11)
|
| 225 |
+
v20 = tl.where(do_swap, v21, v20); v21 = tl.where(do_swap, tv20, v21)
|
| 226 |
+
|
| 227 |
+
# swap(0, 2) if s0 < s2
|
| 228 |
+
do_swap = s0 < s2
|
| 229 |
+
s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)
|
| 230 |
+
tv00 = v00; tv10 = v10; tv20 = v20
|
| 231 |
+
v00 = tl.where(do_swap, v02, v00); v02 = tl.where(do_swap, tv00, v02)
|
| 232 |
+
v10 = tl.where(do_swap, v12, v10); v12 = tl.where(do_swap, tv10, v12)
|
| 233 |
+
v20 = tl.where(do_swap, v22, v20); v22 = tl.where(do_swap, tv20, v22)
|
| 234 |
+
|
| 235 |
+
# swap(1, 2) if s1 < s2
|
| 236 |
+
do_swap = s1 < s2
|
| 237 |
+
s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)
|
| 238 |
+
tv01 = v01; tv11 = v11; tv21 = v21
|
| 239 |
+
v01 = tl.where(do_swap, v02, v01); v02 = tl.where(do_swap, tv01, v02)
|
| 240 |
+
v11 = tl.where(do_swap, v12, v11); v12 = tl.where(do_swap, tv11, v12)
|
| 241 |
+
v21 = tl.where(do_swap, v22, v21); v22 = tl.where(do_swap, tv21, v22)
|
| 242 |
+
|
| 243 |
+
# ---------------------------------------------------------------
|
| 244 |
+
# Stage 2c: Write S and Vh
|
| 245 |
+
# ---------------------------------------------------------------
|
| 246 |
+
s_base = bid * 3
|
| 247 |
+
tl.store(S_ptr + s_base + 0, s0)
|
| 248 |
+
tl.store(S_ptr + s_base + 1, s1)
|
| 249 |
+
tl.store(S_ptr + s_base + 2, s2)
|
| 250 |
+
|
| 251 |
+
# Vh = V^T (Vh[i,j] = V[j,i])
|
| 252 |
+
vh_base = bid * 9
|
| 253 |
+
tl.store(Vh_ptr + vh_base + 0, v00) # Vh[0,0] = V[0,0]
|
| 254 |
+
tl.store(Vh_ptr + vh_base + 1, v10) # Vh[0,1] = V[1,0]
|
| 255 |
+
tl.store(Vh_ptr + vh_base + 2, v20) # Vh[0,2] = V[2,0]
|
| 256 |
+
tl.store(Vh_ptr + vh_base + 3, v01) # Vh[1,0] = V[0,1]
|
| 257 |
+
tl.store(Vh_ptr + vh_base + 4, v11) # Vh[1,1] = V[1,1]
|
| 258 |
+
tl.store(Vh_ptr + vh_base + 5, v21) # Vh[1,2] = V[2,1]
|
| 259 |
+
tl.store(Vh_ptr + vh_base + 6, v02) # Vh[2,0] = V[0,2]
|
| 260 |
+
tl.store(Vh_ptr + vh_base + 7, v12) # Vh[2,1] = V[1,2]
|
| 261 |
+
tl.store(Vh_ptr + vh_base + 8, v22) # Vh[2,2] = V[2,2]
|
| 262 |
+
|
| 263 |
+
# ---------------------------------------------------------------
|
| 264 |
+
# Stage 3: Recover U = A @ V @ diag(1/S)
|
| 265 |
+
#
|
| 266 |
+
# U[:, c] = (1/S[c]) * A @ V[:, c]
|
| 267 |
+
# Tiled over M to keep memory pressure bounded.
|
| 268 |
+
# ---------------------------------------------------------------
|
| 269 |
+
inv_s0 = 1.0 / (s0 + EPS)
|
| 270 |
+
inv_s1 = 1.0 / (s1 + EPS)
|
| 271 |
+
inv_s2 = 1.0 / (s2 + EPS)
|
| 272 |
+
|
| 273 |
+
for block_start in range(0, M, BLOCK_M):
|
| 274 |
+
offs = tl.arange(0, BLOCK_M)
|
| 275 |
+
row_idx = block_start + offs
|
| 276 |
+
mask = row_idx < M
|
| 277 |
+
|
| 278 |
+
ptr0 = base + row_idx * 3 + 0
|
| 279 |
+
ptr1 = base + row_idx * 3 + 1
|
| 280 |
+
ptr2 = base + row_idx * 3 + 2
|
| 281 |
+
|
| 282 |
+
a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)
|
| 283 |
+
a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)
|
| 284 |
+
a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)
|
| 285 |
+
|
| 286 |
+
# U[:, 0] = (A @ V[:, 0]) / s0
|
| 287 |
+
u0 = (a0 * v00 + a1 * v10 + a2 * v20) * inv_s0
|
| 288 |
+
# U[:, 1] = (A @ V[:, 1]) / s1
|
| 289 |
+
u1 = (a0 * v01 + a1 * v11 + a2 * v21) * inv_s1
|
| 290 |
+
# U[:, 2] = (A @ V[:, 2]) / s2
|
| 291 |
+
u2 = (a0 * v02 + a1 * v12 + a2 * v22) * inv_s2
|
| 292 |
+
|
| 293 |
+
u_base = bid * M * 3
|
| 294 |
+
tl.store(U_ptr + u_base + row_idx * 3 + 0, u0, mask=mask)
|
| 295 |
+
tl.store(U_ptr + u_base + row_idx * 3 + 1, u1, mask=mask)
|
| 296 |
+
tl.store(U_ptr + u_base + row_idx * 3 + 2, u2, mask=mask)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# ============================================================================
|
| 300 |
+
# Python wrapper
|
| 301 |
+
# ============================================================================
|
| 302 |
+
|
| 303 |
+
def batched_svd3(
|
| 304 |
+
A: torch.Tensor,
|
| 305 |
+
block_m: int = 128,
|
| 306 |
+
jacobi_iters: int = 6,
|
| 307 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 308 |
+
"""
|
| 309 |
+
Batched thin SVD for (B, M, 3) float32 tensors.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
A: Input tensor of shape (B, M, 3). M can be anything (1024 for CIFAR).
|
| 313 |
+
block_m: Tile size for the spatial dimension. 128 is good for M=1024.
|
| 314 |
+
jacobi_iters: Number of cyclic Jacobi sweeps. 6 is overkill for 3×3.
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
U: (B, M, 3) — thin left singular vectors
|
| 318 |
+
S: (B, 3) — singular values, descending
|
| 319 |
+
Vh: (B, 3, 3) — right singular vectors transposed
|
| 320 |
+
"""
|
| 321 |
+
assert A.ndim == 3 and A.shape[2] == 3, f"Expected (B, M, 3), got {A.shape}"
|
| 322 |
+
assert A.is_cuda, "Input must be on CUDA"
|
| 323 |
+
|
| 324 |
+
B, M, _ = A.shape
|
| 325 |
+
A_f32 = A.contiguous().float()
|
| 326 |
+
|
| 327 |
+
U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)
|
| 328 |
+
S = torch.empty((B, 3), dtype=torch.float32, device=A.device)
|
| 329 |
+
Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)
|
| 330 |
+
|
| 331 |
+
_svd3_kernel[(B,)](
|
| 332 |
+
A_f32, U, S, Vh,
|
| 333 |
+
M=M,
|
| 334 |
+
BLOCK_M=block_m,
|
| 335 |
+
JACOBI_ITERS=jacobi_iters,
|
| 336 |
+
EPS=1e-12,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
return U, S, Vh
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ============================================================================
|
| 343 |
+
# Benchmark & validation harness
|
| 344 |
+
# ============================================================================
|
| 345 |
+
|
| 346 |
+
def _test_correctness(B=256, M=1024):
|
| 347 |
+
"""Validate against torch.linalg.svd."""
|
| 348 |
+
A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
|
| 349 |
+
|
| 350 |
+
U, S, Vh = batched_svd3(A)
|
| 351 |
+
|
| 352 |
+
# Reference: torch.linalg.svd on 3D is batched
|
| 353 |
+
U_ref, S_ref, Vh_ref = torch.linalg.svd(A, full_matrices=False)
|
| 354 |
+
|
| 355 |
+
# Singular values should match (tight — these are just 3 values)
|
| 356 |
+
s_err = (S - S_ref).abs().max().item()
|
| 357 |
+
print(f"[correctness] S max error: {s_err:.2e}")
|
| 358 |
+
|
| 359 |
+
# Reconstruction: A ≈ U @ diag(S) @ Vh
|
| 360 |
+
# Compare against torch's own recon error as the floor —
|
| 361 |
+
# f32 accumulation over M=1024 rows means ~1e-3 is physical.
|
| 362 |
+
recon = torch.bmm(U * S.unsqueeze(1), Vh)
|
| 363 |
+
recon_ref = torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref)
|
| 364 |
+
r_err = (A - recon).abs().max().item()
|
| 365 |
+
r_ref_err = (A - recon_ref).abs().max().item()
|
| 366 |
+
print(f"[correctness] Recon error: {r_err:.2e} (ref: {r_ref_err:.2e})")
|
| 367 |
+
|
| 368 |
+
# Orthogonality: U^T U ≈ I
|
| 369 |
+
UtU = torch.bmm(U.transpose(1, 2), U)
|
| 370 |
+
eye = torch.eye(3, device="cuda").expand(B, -1, -1)
|
| 371 |
+
orth_err = (UtU - eye).abs().max().item()
|
| 372 |
+
print(f"[correctness] U orthog error: {orth_err:.2e}")
|
| 373 |
+
|
| 374 |
+
# S should be tight; recon only needs to match torch's own floor
|
| 375 |
+
assert s_err < 1e-3, f"Singular value error {s_err} > 1e-3"
|
| 376 |
+
assert r_err < r_ref_err * 2.0 + 1e-5, (
|
| 377 |
+
f"Recon error {r_err:.2e} > 2x torch ref {r_ref_err:.2e}"
|
| 378 |
+
)
|
| 379 |
+
assert orth_err < 1e-3, f"Orthogonality error {orth_err} > 1e-3"
|
| 380 |
+
print("[correctness] PASSED\n")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _cuda_timer(fn, warmup=25, iters=100):
|
| 384 |
+
"""
|
| 385 |
+
CUDA-event-timed benchmark. Returns (mean_ms, std_ms) over `iters` runs.
|
| 386 |
+
Uses per-iteration events for proper distribution stats.
|
| 387 |
+
"""
|
| 388 |
+
# Warmup — let triton autotuner / cublas handle settle
|
| 389 |
+
for _ in range(warmup):
|
| 390 |
+
fn()
|
| 391 |
+
torch.cuda.synchronize()
|
| 392 |
+
|
| 393 |
+
starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
| 394 |
+
ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
| 395 |
+
|
| 396 |
+
for i in range(iters):
|
| 397 |
+
starts[i].record()
|
| 398 |
+
fn()
|
| 399 |
+
ends[i].record()
|
| 400 |
+
torch.cuda.synchronize()
|
| 401 |
+
|
| 402 |
+
times = [starts[i].elapsed_time(ends[i]) for i in range(iters)]
|
| 403 |
+
t = torch.tensor(times)
|
| 404 |
+
return t.mean().item(), t.std().item(), t.median().item(), t.min().item(), t.max().item()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _profile_sweep():
|
| 408 |
+
"""
|
| 409 |
+
Full profiling sweep: batch sizes × spatial dims × block sizes.
|
| 410 |
+
Compares Triton SVD3 vs torch.linalg.svd with CUDA event timing.
|
| 411 |
+
"""
|
| 412 |
+
import json
|
| 413 |
+
|
| 414 |
+
device_name = torch.cuda.get_device_name(0)
|
| 415 |
+
print(f"{'='*72}")
|
| 416 |
+
print(f" Triton SVD3 Profiling — {device_name}")
|
| 417 |
+
print(f"{'='*72}\n")
|
| 418 |
+
|
| 419 |
+
# ------------------------------------------------------------------
|
| 420 |
+
# Sweep 1: Batch scaling at fixed M=1024 (CIFAR spatial)
|
| 421 |
+
# ------------------------------------------------------------------
|
| 422 |
+
print(f"{'─'*72}")
|
| 423 |
+
print(f" SWEEP 1: Batch scaling (M=1024, BLOCK_M=128)")
|
| 424 |
+
print(f"{'─'*72}")
|
| 425 |
+
print(f" {'B':>6} {'Triton ms':>12} {'±std':>8} {'Torch ms':>12} {'±std':>8} {'Speedup':>8}")
|
| 426 |
+
print(f" {'─'*62}")
|
| 427 |
+
|
| 428 |
+
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
|
| 429 |
+
M = 1024
|
| 430 |
+
batch_results = []
|
| 431 |
+
|
| 432 |
+
for B in batch_sizes:
|
| 433 |
+
A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
|
| 434 |
+
|
| 435 |
+
tri_mean, tri_std, tri_med, tri_min, tri_max = _cuda_timer(
|
| 436 |
+
lambda: batched_svd3(A, block_m=128)
|
| 437 |
+
)
|
| 438 |
+
tch_mean, tch_std, tch_med, tch_min, tch_max = _cuda_timer(
|
| 439 |
+
lambda: torch.linalg.svd(A, full_matrices=False)
|
| 440 |
+
)
|
| 441 |
+
speedup = tch_mean / (tri_mean + 1e-9)
|
| 442 |
+
|
| 443 |
+
print(f" {B:>6} {tri_mean:>10.3f}ms {tri_std:>6.3f} {tch_mean:>10.3f}ms {tch_std:>6.3f} {speedup:>7.2f}x")
|
| 444 |
+
batch_results.append({
|
| 445 |
+
"B": B, "M": M,
|
| 446 |
+
"triton_mean_ms": round(tri_mean, 4), "triton_std_ms": round(tri_std, 4),
|
| 447 |
+
"triton_median_ms": round(tri_med, 4), "triton_min_ms": round(tri_min, 4),
|
| 448 |
+
"torch_mean_ms": round(tch_mean, 4), "torch_std_ms": round(tch_std, 4),
|
| 449 |
+
"torch_median_ms": round(tch_med, 4), "torch_min_ms": round(tch_min, 4),
|
| 450 |
+
"speedup": round(speedup, 3),
|
| 451 |
+
})
|
| 452 |
+
del A
|
| 453 |
+
torch.cuda.empty_cache()
|
| 454 |
+
|
| 455 |
+
# ------------------------------------------------------------------
|
| 456 |
+
# Sweep 2: Spatial dim scaling at fixed B=1024
|
| 457 |
+
# ------------------------------------------------------------------
|
| 458 |
+
print(f"\n{'─'*72}")
|
| 459 |
+
print(f" SWEEP 2: Spatial scaling (B=1024, BLOCK_M=128)")
|
| 460 |
+
print(f"{'─'*72}")
|
| 461 |
+
print(f" {'M':>6} {'Triton ms':>12} {'±std':>8} {'Torch ms':>12} {'±std':>8} {'Speedup':>8}")
|
| 462 |
+
print(f" {'─'*62}")
|
| 463 |
+
|
| 464 |
+
spatial_dims = [64, 256, 512, 1024, 2048, 4096] # 8×8 to 64×64
|
| 465 |
+
B = 1024
|
| 466 |
+
spatial_results = []
|
| 467 |
+
|
| 468 |
+
for M in spatial_dims:
|
| 469 |
+
A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
|
| 470 |
+
|
| 471 |
+
tri_mean, tri_std, tri_med, tri_min, tri_max = _cuda_timer(
|
| 472 |
+
lambda: batched_svd3(A, block_m=128)
|
| 473 |
+
)
|
| 474 |
+
tch_mean, tch_std, tch_med, tch_min, tch_max = _cuda_timer(
|
| 475 |
+
lambda: torch.linalg.svd(A, full_matrices=False)
|
| 476 |
+
)
|
| 477 |
+
speedup = tch_mean / (tri_mean + 1e-9)
|
| 478 |
+
|
| 479 |
+
equiv_hw = int(M**0.5)
|
| 480 |
+
tag = f"~{equiv_hw}x{equiv_hw}" if equiv_hw * equiv_hw == M else f" {M}"
|
| 481 |
+
print(f" {tag:>6} {tri_mean:>10.3f}ms {tri_std:>6.3f} {tch_mean:>10.3f}ms {tch_std:>6.3f} {speedup:>7.2f}x")
|
| 482 |
+
spatial_results.append({
|
| 483 |
+
"B": B, "M": M,
|
| 484 |
+
"triton_mean_ms": round(tri_mean, 4), "torch_mean_ms": round(tch_mean, 4),
|
| 485 |
+
"speedup": round(speedup, 3),
|
| 486 |
+
})
|
| 487 |
+
del A
|
| 488 |
+
torch.cuda.empty_cache()
|
| 489 |
+
|
| 490 |
+
# ------------------------------------------------------------------
|
| 491 |
+
# Sweep 3: BLOCK_M tuning at fixed B=4096, M=1024
|
| 492 |
+
# ------------------------------------------------------------------
|
| 493 |
+
print(f"\n{'─'*72}")
|
| 494 |
+
print(f" SWEEP 3: BLOCK_M tuning (B=4096, M=1024)")
|
| 495 |
+
print(f"{'─'*72}")
|
| 496 |
+
print(f" {'BLOCK_M':>8} {'Triton ms':>12} {'±std':>8} {'tiles/img':>10}")
|
| 497 |
+
print(f" {'─'*44}")
|
| 498 |
+
|
| 499 |
+
block_sizes = [32, 64, 128, 256, 512, 1024]
|
| 500 |
+
B, M = 4096, 1024
|
| 501 |
+
A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
|
| 502 |
+
block_results = []
|
| 503 |
+
|
| 504 |
+
for bm in block_sizes:
|
| 505 |
+
tri_mean, tri_std, tri_med, tri_min, tri_max = _cuda_timer(
|
| 506 |
+
lambda: batched_svd3(A, block_m=bm)
|
| 507 |
+
)
|
| 508 |
+
n_tiles = (M + bm - 1) // bm
|
| 509 |
+
print(f" {bm:>8} {tri_mean:>10.3f}ms {tri_std:>6.3f} {n_tiles:>10}")
|
| 510 |
+
block_results.append({
|
| 511 |
+
"block_m": bm, "triton_mean_ms": round(tri_mean, 4),
|
| 512 |
+
"triton_std_ms": round(tri_std, 4), "n_tiles": n_tiles,
|
| 513 |
+
})
|
| 514 |
+
|
| 515 |
+
del A
|
| 516 |
+
torch.cuda.empty_cache()
|
| 517 |
+
|
| 518 |
+
# ------------------------------------------------------------------
|
| 519 |
+
# Sweep 4: Throughput — images/sec at peak batch
|
| 520 |
+
# ------------------------------------------------------------------
|
| 521 |
+
print(f"\n{'─'*72}")
|
| 522 |
+
print(f" SWEEP 4: Throughput (images/sec)")
|
| 523 |
+
print(f"{'─'*72}")
|
| 524 |
+
|
| 525 |
+
for B in [4096, 16384]:
|
| 526 |
+
A = torch.randn(B, 1024, 3, device="cuda", dtype=torch.float32)
|
| 527 |
+
tri_mean, *_ = _cuda_timer(lambda: batched_svd3(A, block_m=128))
|
| 528 |
+
tch_mean, *_ = _cuda_timer(lambda: torch.linalg.svd(A, full_matrices=False))
|
| 529 |
+
|
| 530 |
+
tri_ips = B / (tri_mean / 1000)
|
| 531 |
+
tch_ips = B / (tch_mean / 1000)
|
| 532 |
+
print(f" B={B:>5}: Triton {tri_ips:>12,.0f} img/s | Torch {tch_ips:>12,.0f} img/s")
|
| 533 |
+
del A
|
| 534 |
+
torch.cuda.empty_cache()
|
| 535 |
+
|
| 536 |
+
# ------------------------------------------------------------------
|
| 537 |
+
# Memory: peak allocation comparison
|
| 538 |
+
# ------------------------------------------------------------------
|
| 539 |
+
print(f"\n{'─'*72}")
|
| 540 |
+
print(f" MEMORY: Peak allocation (B=4096, M=1024)")
|
| 541 |
+
print(f"{'─'*72}")
|
| 542 |
+
|
| 543 |
+
B, M = 4096, 1024
|
| 544 |
+
A = torch.randn(B, M, 3, device="cuda", dtype=torch.float32)
|
| 545 |
+
|
| 546 |
+
torch.cuda.reset_peak_memory_stats()
|
| 547 |
+
_ = batched_svd3(A)
|
| 548 |
+
torch.cuda.synchronize()
|
| 549 |
+
tri_peak = torch.cuda.max_memory_allocated() / 1024**2
|
| 550 |
+
|
| 551 |
+
torch.cuda.reset_peak_memory_stats()
|
| 552 |
+
_ = torch.linalg.svd(A, full_matrices=False)
|
| 553 |
+
torch.cuda.synchronize()
|
| 554 |
+
tch_peak = torch.cuda.max_memory_allocated() / 1024**2
|
| 555 |
+
|
| 556 |
+
print(f" Triton: {tri_peak:.1f} MB")
|
| 557 |
+
print(f" Torch: {tch_peak:.1f} MB")
|
| 558 |
+
print(f" Ratio: {tch_peak / (tri_peak + 1e-9):.2f}x\n")
|
| 559 |
+
|
| 560 |
+
# ------------------------------------------------------------------
|
| 561 |
+
# Dump JSON for further analysis
|
| 562 |
+
# ------------------------------------------------------------------
|
| 563 |
+
report = {
|
| 564 |
+
"device": device_name,
|
| 565 |
+
"batch_sweep": batch_results,
|
| 566 |
+
"spatial_sweep": spatial_results,
|
| 567 |
+
"block_m_sweep": block_results,
|
| 568 |
+
}
|
| 569 |
+
with open("svd3_profile.json", "w") as f:
|
| 570 |
+
json.dump(report, f, indent=2)
|
| 571 |
+
print(f" Full results written to svd3_profile.json\n")
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
if __name__ == "__main__":
|
| 575 |
+
_test_correctness()
|
| 576 |
+
_profile_sweep()
|