Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tests/ops/kda/__init__.py
Empty file.
276 changes: 276 additions & 0 deletions tests/ops/kda/test_pallas_chunk_kda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
"""chunk_kda: Pallas kernel vs CPU reference tests for Kimi Delta Attention.

Usage:
uv run pytest tests/ops/kda/test_pallas_chunk_kda.py -v
"""

from __future__ import annotations

import numpy as np
import pytest
import jax
import jax.numpy as jnp

from tops.cpu.ops.kda import naive_recurrent_kda, naive_chunk_kda
from tops.ops.kda import chunk_kda_fwd


def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, **kwargs):
    """Print a PASS/FAIL comparison of two array-likes and return the verdict.

    Both inputs are promoted to float64 before comparing; NaNs compare equal.
    Extra keyword arguments are accepted and ignored so callers may pass
    framework-specific options without breaking.
    """
    # None handling: both-None is a match, exactly one None is a mismatch.
    if gold is None and tensor is None:
        print(f"[{name}] Both are None. MATCH.")
        return True
    if gold is None or tensor is None:
        print(f"[{name}] One is None! MISMATCH.")
        return False

    ref = np.array(gold, dtype=np.float64)
    got = np.array(tensor, dtype=np.float64)

    # Shapes must agree exactly before any elementwise comparison.
    if ref.shape != got.shape:
        print(f"[{name}] Shape mismatch: {ref.shape} vs {got.shape}. FAIL.")
        return False

    abs_err = np.abs(ref - got)
    ok = np.allclose(ref, got, atol=atol, rtol=rtol, equal_nan=True)
    status = "PASS" if ok else "FAIL"
    print(
        f"[{name}] {status} max_val={np.max(np.abs(got)):.6e} "
        f"max_diff={np.max(abs_err):.6e}"
    )
    if not ok:
        # Report the single worst element to ease debugging.
        worst = np.unravel_index(np.argmax(abs_err), abs_err.shape)
        print(
            f" worst at {worst}: gold={ref[worst]:.8e} pallas={got[worst]:.8e} diff={abs_err[worst]:.8e}"
        )
    return ok


# ============================================================================
# Test configs
# ============================================================================

# Each case is a kwargs dict consumed by _make_inputs / test_cpu_vs_pallas:
#   B, T, H, K, V       : batch, sequence length, heads, key dim, value dim
#   seed                : PRNG seed for the random inputs
#   h0 (optional)       : when truthy, also draw a random initial state
#   scale (optional)    : query scale forwarded to both implementations
#   atol/rtol (optional): per-case tolerance overrides (test default: 5e-2)
CASES = [
    # ── standard shapes ──
    dict(B=2, T=64, H=4, K=64, V=64, seed=42),
    dict(B=2, T=64, H=4, K=64, V=64, seed=13, h0=True),
    dict(B=1, T=128, H=2, K=64, V=128, seed=7),
    # ── single head ──
    dict(B=2, T=64, H=1, K=64, V=64, seed=10),
    # ── K != V ──
    dict(B=2, T=64, H=4, K=64, V=128, seed=20),
    dict(B=2, T=64, H=4, K=128, V=64, seed=21),
    # ── very short T ──
    dict(B=1, T=64, H=2, K=64, V=64, seed=30),
    # ── T > chunk_size (multiple chunks) ──
    dict(B=1, T=128, H=2, K=64, V=64, seed=31),
    dict(B=1, T=192, H=2, K=64, V=64, seed=32),
    # ── with initial state ──
    dict(B=2, T=128, H=4, K=64, V=64, seed=40, h0=True),
    dict(B=1, T=64, H=2, K=128, V=128, seed=41, h0=True),
    # ── large batch ──
    dict(B=4, T=64, H=4, K=64, V=64, seed=50),
    # ── many heads ──
    dict(B=1, T=64, H=8, K=64, V=64, seed=60),
    # ── custom scale ──
    dict(B=2, T=64, H=4, K=64, V=64, seed=200, scale=0.1),
    # ── longer sequence ──
    dict(B=1, T=256, H=2, K=64, V=64, seed=300),
    dict(B=1, T=256, H=2, K=64, V=64, seed=301, h0=True),
    # ── scale + h0 ──
    dict(B=2, T=64, H=4, K=64, V=64, seed=140, scale=0.1, h0=True),
]


def _case_id(c):
parts = [f"B{c['B']}_T{c['T']}_H{c['H']}_K{c['K']}_V{c['V']}"]
if c.get("h0"):
parts.append("h0")
if c.get("scale") is not None:
parts.append(f"scale={c['scale']}")
return "-".join(parts)


# ============================================================================
# Helpers
# ============================================================================


def _l2_normalize_last_dim(x):
"""L2-normalize along the last dimension with epsilon for stability."""
denom = jnp.sqrt(jnp.sum(x * x, axis=-1, keepdims=True) + 1e-6)
return (x / denom).astype(x.dtype)


def _make_inputs(cfg):
    """Draw random bf16 KDA inputs ``(q, k, v, g, beta, h0, scale)`` from cfg.

    ``h0`` is None unless ``cfg["h0"]`` is truthy; ``scale`` is None unless
    set in cfg (both are forwarded untouched to the implementations).
    """
    B, T, H, K, V = (cfg[name] for name in ("B", "T", "H", "K", "V"))
    subkeys = jax.random.split(jax.random.PRNGKey(cfg["seed"]), 6)
    dtype = jnp.bfloat16

    # q and k are L2-normalized per row so the Neumann series truncation
    # stays valid: with unit rows each element is ~1/√K (σ = 1/√K), keeping
    # the interaction matrix element magnitude L ~ O(K · σ² · β) small.
    q = _l2_normalize_last_dim(jax.random.normal(subkeys[0], (B, T, H, K), dtype=dtype))
    k = _l2_normalize_last_dim(jax.random.normal(subkeys[1], (B, T, H, K), dtype=dtype))
    v = jax.random.normal(subkeys[2], (B, T, H, V), dtype=dtype) * 0.1
    # Decay gates must be negative, hence log-sigmoid of a normal draw.
    g = jax.nn.log_sigmoid(jax.random.normal(subkeys[3], (B, T, H, K), dtype=dtype))
    # Per-(token, head) mixing coefficients in (0, 1).
    beta = jax.nn.sigmoid(jax.random.normal(subkeys[4], (B, T, H), dtype=dtype))

    if cfg.get("h0"):
        h0 = jax.random.normal(subkeys[5], (B, H, K, V), dtype=dtype) * 0.1
    else:
        h0 = None

    return q, k, v, g, beta, h0, cfg.get("scale", None)


def _run_cpu_ref(q, k, v, g, beta, h0, scale):
    """CPU reference path: chunked naive KDA, always returning the final state."""
    return naive_chunk_kda(
        q, k, v, g, beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
    )


def _run_pallas(q, k, v, g, beta, h0, scale, output_final_state=True):
    """Device path: Pallas ``chunk_kda_fwd`` forward pass."""
    return chunk_kda_fwd(
        q, k, v, g, beta,
        scale=scale,
        initial_state=h0,
        output_final_state=output_final_state,
    )


# ============================================================================
# Parametrized test — CPU reference vs Pallas
# ============================================================================


@pytest.mark.parametrize("cfg", CASES, ids=[_case_id(c) for c in CASES])
def test_cpu_vs_pallas(cfg):
    """Pallas forward must agree with the CPU chunked reference for each case."""
    # Per-case tolerance overrides are honored; 5e-2 suits bf16 inputs.
    tol = dict(atol=cfg.get("atol", 5e-2), rtol=cfg.get("rtol", 5e-2))

    q, k, v, g, beta, h0, scale = _make_inputs(cfg)
    ref_o, ref_s = _run_cpu_ref(q, k, v, g, beta, h0, scale)
    got_o, got_s = _run_pallas(q, k, v, g, beta, h0, scale)

    assert compare_tensor("output", ref_o, got_o, **tol)
    assert compare_tensor("final_state", ref_s, got_s, **tol)


# ============================================================================
# Structural tests
# ============================================================================


def test_state_split_pallas():
    """Running two halves with carried state must match one full Pallas pass."""
    B, T, H, K, V = 1, 128, 2, 64, 64
    dtype = jnp.bfloat16
    keys = jax.random.split(jax.random.PRNGKey(77), 5)

    q = _l2_normalize_last_dim(jax.random.normal(keys[0], (B, T, H, K), dtype=dtype))
    k = _l2_normalize_last_dim(jax.random.normal(keys[1], (B, T, H, K), dtype=dtype))
    v = jax.random.normal(keys[2], (B, T, H, V), dtype=dtype) * 0.1
    g = jax.nn.log_sigmoid(jax.random.normal(keys[3], (B, T, H, K), dtype=dtype))
    beta = jax.nn.sigmoid(jax.random.normal(keys[4], (B, T, H), dtype=dtype))

    mid = T // 2
    first = [x[:, :mid] for x in (q, k, v, g, beta)]
    second = [x[:, mid:] for x in (q, k, v, g, beta)]

    _, s_full = _run_pallas(q, k, v, g, beta, None, None)
    _, s_half = _run_pallas(*first, None, None)
    # Feed the first half's final state in as the second half's initial state.
    _, s_resumed = _run_pallas(*second, s_half, None)

    assert compare_tensor("pallas: full vs split", s_full, s_resumed, atol=5e-2, rtol=5e-2)


def test_no_final_state_pallas():
    """output_final_state=False returns None for final_state."""
    B, T, H, K, V = 2, 64, 4, 64, 64
    dtype = jnp.bfloat16
    keys = jax.random.split(jax.random.PRNGKey(210), 5)

    # Same input recipe as the parametrized cases: unit-norm q/k, scaled v,
    # negative gates g, sigmoid beta.
    raw = [
        jax.random.normal(kk, shape, dtype=dtype)
        for kk, shape in zip(
            keys,
            [(B, T, H, K), (B, T, H, K), (B, T, H, V), (B, T, H, K), (B, T, H)],
        )
    ]
    q = _l2_normalize_last_dim(raw[0])
    k = _l2_normalize_last_dim(raw[1])
    v = raw[2] * 0.1
    g = jax.nn.log_sigmoid(raw[3])
    beta = jax.nn.sigmoid(raw[4])

    out, final = _run_pallas(q, k, v, g, beta, None, None, output_final_state=False)

    assert final is None, f"final_state should be None, got {type(final)}"
    assert out.shape == (B, T, H, V)


def test_matches_naive_recurrent():
    """Pallas chunked path vs the step-by-step recurrent ground truth (bf16)."""
    B, T, H, K, V = 1, 128, 2, 64, 64
    dtype = jnp.bfloat16
    keys = jax.random.split(jax.random.PRNGKey(999), 6)

    q = _l2_normalize_last_dim(jax.random.normal(keys[0], (B, T, H, K), dtype=dtype))
    k = _l2_normalize_last_dim(jax.random.normal(keys[1], (B, T, H, K), dtype=dtype))
    v = jax.random.normal(keys[2], (B, T, H, V), dtype=dtype) * 0.1
    g = jax.nn.log_sigmoid(jax.random.normal(keys[3], (B, T, H, K), dtype=dtype))
    beta = jax.nn.sigmoid(jax.random.normal(keys[4], (B, T, H), dtype=dtype))
    h0 = jax.random.normal(keys[5], (B, H, K, V), dtype=dtype) * 0.1

    gold_o, gold_s = naive_recurrent_kda(
        q, k, v, g, beta, initial_state=h0, output_final_state=True
    )
    got_o, got_s = _run_pallas(q, k, v, g, beta, h0, None)

    assert compare_tensor(
        "output vs recurrent", gold_o, got_o, atol=5e-2, rtol=5e-2
    )
    assert compare_tensor(
        "state vs recurrent", gold_s, got_s, atol=5e-2, rtol=5e-2
    )


if __name__ == "__main__":
    # Propagate pytest's exit status so shell/CI callers see failures;
    # a bare pytest.main(...) call discards the return code and exits 0.
    raise SystemExit(pytest.main([__file__]))
1 change: 1 addition & 0 deletions tops/ops/kda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tops.ops.kda.chunk import chunk_kda_fwd as chunk_kda_fwd
Loading
Loading