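"""Parity tests for the flex_sae Triton SAE loss kernels.

Each test evaluates a fused Triton loss (hierarchical and top-k variants)
and its reference implementation on the same random inputs, comparing the
scalar losses as well as the gradients of every differentiable input.
"""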
from typing import Callable
import pytest
import torch

# Skip the whole module early (before importing flex_sae below) on machines
# without a CUDA device; importorskip("torch.cuda") would not catch this,
# since torch.cuda imports fine even when no GPU is present.
if not torch.cuda.is_available():
    pytest.skip("CUDA is required for these tests", allow_module_level=True)
from .test_setup import DTYPES, DTYPE_TO_TOLS, PARAMS, SEED
from flex_sae import (
    triton_hierarchical_sae_loss,
    hierarchical_sae_loss,
    triton_topk_sae_loss,
    topk_sae_loss,
)


@pytest.fixture(autouse=True)
def _set_cuda_default_device():
    # Create all tensors in this module on the GPU by default, then restore
    # the stock CPU default so later test modules are unaffected.
    torch.set_default_device("cuda")
    yield
    torch.set_default_device("cpu")


def run_funcs(B, K, F, D, dtype, *, kernel_foo: Callable, ref_foo: Callable):
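    """Check a Triton kernel against its reference implementation.

    Builds identical random inputs for both, then asserts that the scalar
    losses and the gradients of vals/decoder/bias agree within the
    tolerances registered for ``dtype`` in DTYPE_TO_TOLS.
    """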
    if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
        pytest.skip("BF16 not supported on this GPU")

    torch.manual_seed(SEED)

    # Random active-feature indices; integer tensors never carry gradients.
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")

    vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
    decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(B, D, dtype=dtype, device="cuda")

    # Detached leaf copies so the reference path accumulates its own grads.
    sv_ref = vals.clone().detach().requires_grad_()
    dec_ref = decoder.clone().detach().requires_grad_()
    bias_ref = bias.clone().detach().requires_grad_()

    loss_f = kernel_foo(indices, decoder, vals, bias, target)
    loss_r = ref_foo(indices, dec_ref, sv_ref, bias_ref, target)

    torch.testing.assert_close(loss_f, loss_r, **DTYPE_TO_TOLS[dtype])

    # Backprop a random upstream gradient rather than the implicit 1.0.
    grad_out = torch.randn((), device="cuda", dtype=torch.float32)
    loss_f.backward(grad_out)
    loss_r.backward(grad_out.clone())

    torch.testing.assert_close(vals.grad, sv_ref.grad, **DTYPE_TO_TOLS[dtype])
    torch.testing.assert_close(decoder.grad, dec_ref.grad, **DTYPE_TO_TOLS[dtype])
    torch.testing.assert_close(bias.grad, bias_ref.grad, **DTYPE_TO_TOLS[dtype])

    # Integer indices must never have received a gradient.
    assert indices.grad is None


@pytest.mark.parametrize("B, K, F, D", PARAMS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_triton_hierarchical_sae_loss_and_grads(B, K, F, D, dtype):
    run_funcs(
        B, K, F, D, dtype,
        kernel_foo=triton_hierarchical_sae_loss,
        ref_foo=hierarchical_sae_loss,
    )
    # Drop cached allocations between large parametrized cases.
    torch.cuda.empty_cache()


@pytest.mark.parametrize("B, K, F, D", PARAMS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_triton_topk_sae_loss_and_grads(B, K, F, D, dtype):
    run_funcs(
        B, K, F, D, dtype,
        kernel_foo=triton_topk_sae_loss,
        ref_foo=topk_sae_loss,
    )
    torch.cuda.empty_cache()