diff --git a/tests/ops/kda/__init__.py b/tests/ops/kda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ops/kda/test_pallas_chunk_kda.py b/tests/ops/kda/test_pallas_chunk_kda.py new file mode 100644 index 00000000..1c1ebd09 --- /dev/null +++ b/tests/ops/kda/test_pallas_chunk_kda.py @@ -0,0 +1,276 @@ +"""chunk_kda: Pallas kernel vs CPU reference tests for Kimi Delta Attention. + +Usage: + uv run pytest tests/ops/kda/test_pallas_chunk_kda.py -v +""" + +from __future__ import annotations + +import numpy as np +import pytest +import jax +import jax.numpy as jnp + +from tops.cpu.ops.kda import naive_recurrent_kda, naive_chunk_kda +from tops.ops.kda import chunk_kda_fwd + + +def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, **kwargs): + """Lightweight compare_tensor without torch dependency.""" + if gold is None and tensor is None: + print(f"[{name}] Both are None. MATCH.") + return True + if gold is None or tensor is None: + print(f"[{name}] One is None! MISMATCH.") + return False + + a = np.array(gold, dtype=np.float64) + b = np.array(tensor, dtype=np.float64) + + if a.shape != b.shape: + print(f"[{name}] Shape mismatch: {a.shape} vs {b.shape}. 
FAIL.") + return False + + diff = np.abs(a - b) + max_diff = np.max(diff) + max_val = np.max(np.abs(b)) + ok = np.allclose(a, b, atol=atol, rtol=rtol, equal_nan=True) + status = "PASS" if ok else "FAIL" + print(f"[{name}] {status} max_val={max_val:.6e} max_diff={max_diff:.6e}") + if not ok: + idx = np.unravel_index(np.argmax(diff), diff.shape) + print( + f" worst at {idx}: gold={a[idx]:.8e} pallas={b[idx]:.8e} diff={diff[idx]:.8e}" + ) + return ok + + +# ============================================================================ +# Test configs +# ============================================================================ + +CASES = [ + # ── standard shapes ── + dict(B=2, T=64, H=4, K=64, V=64, seed=42), + dict(B=2, T=64, H=4, K=64, V=64, seed=13, h0=True), + dict(B=1, T=128, H=2, K=64, V=128, seed=7), + # ── single head ── + dict(B=2, T=64, H=1, K=64, V=64, seed=10), + # ── K != V ── + dict(B=2, T=64, H=4, K=64, V=128, seed=20), + dict(B=2, T=64, H=4, K=128, V=64, seed=21), + # ── very short T ── + dict(B=1, T=64, H=2, K=64, V=64, seed=30), + # ── T > chunk_size (multiple chunks) ── + dict(B=1, T=128, H=2, K=64, V=64, seed=31), + dict(B=1, T=192, H=2, K=64, V=64, seed=32), + # ── with initial state ── + dict(B=2, T=128, H=4, K=64, V=64, seed=40, h0=True), + dict(B=1, T=64, H=2, K=128, V=128, seed=41, h0=True), + # ── large batch ── + dict(B=4, T=64, H=4, K=64, V=64, seed=50), + # ── many heads ── + dict(B=1, T=64, H=8, K=64, V=64, seed=60), + # ── custom scale ── + dict(B=2, T=64, H=4, K=64, V=64, seed=200, scale=0.1), + # ── longer sequence ── + dict(B=1, T=256, H=2, K=64, V=64, seed=300), + dict(B=1, T=256, H=2, K=64, V=64, seed=301, h0=True), + # ── scale + h0 ── + dict(B=2, T=64, H=4, K=64, V=64, seed=140, scale=0.1, h0=True), +] + + +def _case_id(c): + parts = [f"B{c['B']}_T{c['T']}_H{c['H']}_K{c['K']}_V{c['V']}"] + if c.get("h0"): + parts.append("h0") + if c.get("scale") is not None: + parts.append(f"scale={c['scale']}") + return "-".join(parts) + + +# 
# ============================================================================
# Helpers
# ============================================================================


def _l2_normalize_last_dim(x):
    """L2-normalize along the last dimension with epsilon for stability.

    The squared-sum, square root and division are carried out in float32 and
    only the final result is cast back to ``x.dtype``. Doing the whole
    computation in bfloat16 would accumulate rounding error in the norm.

    Args:
        x: Array whose last axis is normalized.

    Returns:
        Array of the same shape and dtype as ``x``; an all-zero row stays zero
        because of the epsilon inside the square root.
    """
    x32 = x.astype(jnp.float32)
    denom = jnp.sqrt(jnp.sum(x32 * x32, axis=-1, keepdims=True) + 1e-6)
    return (x32 / denom).astype(x.dtype)


def _make_inputs(cfg):
    """Generate random inputs for KDA in bfloat16.

    Args:
        cfg: Case dict with keys B, T, H, K, V, seed and optional h0 / scale.

    Returns:
        Tuple ``(q, k, v, g, beta, h0, scale)`` where q, k, g are
        ``(B, T, H, K)``, v is ``(B, T, H, V)``, beta is ``(B, T, H)``, h0 is
        ``(B, H, K, V)`` or None, and scale is ``cfg["scale"]`` or None.
    """
    B, T, H, K, V = cfg["B"], cfg["T"], cfg["H"], cfg["K"], cfg["V"]
    key = jax.random.PRNGKey(cfg["seed"])
    keys = jax.random.split(key, 6)

    dtype = jnp.bfloat16

    # L2-normalize q and k so that each row has unit norm. This keeps
    # the Neumann series truncation valid: the interaction matrix element
    # magnitude L ~ O(K · σ² · β) stays small because after L2 norm
    # each element has magnitude ~ 1/√K, i.e. σ = 1/√K.
    q = _l2_normalize_last_dim(jax.random.normal(keys[0], (B, T, H, K), dtype=dtype))
    k = _l2_normalize_last_dim(jax.random.normal(keys[1], (B, T, H, K), dtype=dtype))
    v = jax.random.normal(keys[2], (B, T, H, V), dtype=dtype) * 0.1
    # g should be negative (decay gates) — use log-sigmoid
    g = jax.nn.log_sigmoid(jax.random.normal(keys[3], (B, T, H, K), dtype=dtype))
    beta = jax.nn.sigmoid(jax.random.normal(keys[4], (B, T, H), dtype=dtype))

    h0 = None
    if cfg.get("h0"):
        h0 = jax.random.normal(keys[5], (B, H, K, V), dtype=dtype) * 0.1

    scale = cfg.get("scale")
    return q, k, v, g, beta, h0, scale


def _run_cpu_ref(q, k, v, g, beta, h0, scale, output_final_state=True):
    """Run CPU reference naive_chunk_kda.

    Mirrors the signature of ``_run_pallas`` so both back ends can be driven
    with the same arguments. Returns the ``(output, final_state)`` pair
    produced by the reference implementation.
    """
    o, ht = naive_chunk_kda(
        q,
        k,
        v,
        g,
        beta,
        scale=scale,
        initial_state=h0,
        output_final_state=output_final_state,
    )
    return o, ht


def _run_pallas(q, k, v, g, beta, h0, scale, output_final_state=True):
    """Run Pallas chunk_kda_fwd and return its ``(output, final_state)`` pair."""
    return chunk_kda_fwd(
        q,
        k,
        v,
        g,
        beta,
        scale=scale,
        initial_state=h0,
        output_final_state=output_final_state,
    )


#
# ============================================================================
# Parametrized test — CPU reference vs Pallas
# ============================================================================


@pytest.mark.parametrize("cfg", CASES, ids=[_case_id(c) for c in CASES])
def test_cpu_vs_pallas(cfg):
    """Pallas output and final state match the CPU chunk reference."""
    # NOTE(review): the default 5e-2 tolerance is large relative to the
    # magnitude of these inputs (v and h0 are scaled by 0.1) — confirm it is
    # tight enough to catch real regressions.
    atol = cfg.get("atol", 5e-2)
    rtol = cfg.get("rtol", 5e-2)

    q, k, v, g, beta, h0, scale = _make_inputs(cfg)

    o_cpu, s_cpu = _run_cpu_ref(q, k, v, g, beta, h0, scale)
    o_pallas, s_pallas = _run_pallas(q, k, v, g, beta, h0, scale)

    assert compare_tensor("output", o_cpu, o_pallas, atol=atol, rtol=rtol)
    assert compare_tensor("final_state", s_cpu, s_pallas, atol=atol, rtol=rtol)


# ============================================================================
# Structural tests
# ============================================================================


def test_state_split_pallas():
    """Split sequence in 2 halves: state continuity via Pallas.

    Running the full sequence must agree with running the first half, then
    feeding its final state as the initial state of the second half — both for
    the final state and for the concatenated outputs.
    """
    cfg = dict(B=1, T=128, H=2, K=64, V=64, seed=77)
    q, k, v, g, beta, _, _ = _make_inputs(cfg)
    T1 = cfg["T"] // 2

    o_full, s_full = _run_pallas(q, k, v, g, beta, None, None)
    o1, s1 = _run_pallas(
        q[:, :T1],
        k[:, :T1],
        v[:, :T1],
        g[:, :T1],
        beta[:, :T1],
        None,
        None,
    )
    o2, s2 = _run_pallas(
        q[:, T1:],
        k[:, T1:],
        v[:, T1:],
        g[:, T1:],
        beta[:, T1:],
        s1,
        None,
    )

    assert compare_tensor("pallas: full vs split", s_full, s2, atol=5e-2, rtol=5e-2)
    o_split = jnp.concatenate([o1, o2], axis=1)
    assert compare_tensor(
        "pallas: full vs split output", o_full, o_split, atol=5e-2, rtol=5e-2
    )


def test_no_final_state_pallas():
    """output_final_state=False returns None for final_state."""
    cfg = dict(B=2, T=64, H=4, K=64, V=64, seed=210)
    q, k, v, g, beta, _, _ = _make_inputs(cfg)

    o_pallas, s_pallas = _run_pallas(
        q, k, v, g, beta, None, None, output_final_state=False
    )

    assert s_pallas is None, f"final_state should be None, got {type(s_pallas)}"
    assert o_pallas.shape == (cfg["B"], cfg["T"], cfg["H"], cfg["V"])


def test_matches_naive_recurrent():
    """Verify Pallas chunk matches naive recurrent (ground truth) at bf16."""
    cfg = dict(B=1, T=128, H=2, K=64, V=64, seed=999, h0=True)
    q, k, v, g, beta, h0, _ = _make_inputs(cfg)

    o_recurrent, s_recurrent = naive_recurrent_kda(
        q,
        k,
        v,
        g,
        beta,
        initial_state=h0,
        output_final_state=True,
    )
    o_pallas, s_pallas = _run_pallas(q, k, v, g, beta, h0, None)

    assert compare_tensor(
        "output vs recurrent", o_recurrent, o_pallas, atol=5e-2, rtol=5e-2
    )
    assert compare_tensor(
        "state vs recurrent", s_recurrent, s_pallas, atol=5e-2, rtol=5e-2
    )


if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run is visible to the shell.
    raise SystemExit(pytest.main([__file__]))

# diff --git a/tops/ops/kda/__init__.py b/tops/ops/kda/__init__.py
# new file mode 100644
# index 00000000..3b326293
# ---
/dev/null +++ b/tops/ops/kda/__init__.py @@ -0,0 +1 @@ +from tops.ops.kda.chunk import chunk_kda_fwd as chunk_kda_fwd diff --git a/tops/ops/kda/chunk.py b/tops/ops/kda/chunk.py new file mode 100644 index 00000000..51050257 --- /dev/null +++ b/tops/ops/kda/chunk.py @@ -0,0 +1,401 @@ +"""Chunk KDA (Kimi Delta Attention) forward — fully-fused Pallas TPU kernel. + +Implements the chunked delta-rule recurrence as a single fused Pallas kernel, +performing both preprocessing and inter-chunk recurrence inline to avoid HBM +round-trips for intermediate tensors (w, u, A). + +Algorithm overview (per chunk): + + 1. Preprocessing (inline, no HBM materialization): + - Cumulative gate: G = cumsum(g) + - Interaction matrix: A = (k·exp(G)) @ (k·exp(-G))^T · diag(β) + - Neumann series inversion: A_resolved = (I + L + L²) · diag(β) + where L = strict_lower(A), truncated at 2nd order. + - Effective keys/values: w = A_resolved @ (exp(G)·k), u = A_resolved @ v + + 2. Inter-chunk recurrence (sequential over NT): + - Delta correction: v_corr = u - w @ S + - Output: o = (q·exp(G)) @ S + A_qk @ v_corr + - State update: S ← S·exp(G_last) + (exp(G_last-G)·k)^T @ v_corr + +Grid: (B, H, NV, NT) — B, H, NV parallel; NT sequential (arbitrary). +Each grid point processes one V-tile (BV=128) of the hidden state. +K is loaded fully (not tiled) because the delta-rule correction ``w @ S`` +couples all K dimensions. 
+""" + +from __future__ import annotations + +import functools + +import jax +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +from tops.ops.utils import is_tpu_runtime +from tops.utils import ( + assert_shape, + assert_shape_or_none, + cdiv, + align_up, + pad_to_multiple, +) + + +# ============================================================================ +# Pallas kernel +# ============================================================================ + + +def _chunk_kda_fwd_kernel( + # --- inputs --- + q_ref, # (1, 1, 1, BT, K) + k_ref, # (1, 1, 1, BT, K) + v_ref, # (1, 1, 1, BT, BV) — V-tiled + g_ref, # (1, 1, 1, BT, K) — raw gate (not cumsum'd) + beta_ref, # (1, 1, 1, BT, 128) — padded to 128 for TPU alignment + h0_ref, # (1, 1, K, BV) or None + # --- outputs --- + o_ref, # (1, 1, 1, BT, BV) + ht_ref, # (1, 1, K, BV) or None + # --- scratch --- + scratch_ref, # (K, BV) VMEM float32 — running hidden state S + *, + BT: int, + K: int, + NT: int, + OUTPUT_FINAL_STATE: bool, +): + """Fully-fused Pallas kernel: preprocessing + inter-chunk recurrence. + + Grid: (B, H, NV, NT) — first 3 parallel, NT arbitrary (sequential). + + At each chunk step the kernel: + 1. Computes cumulative gate G, interaction matrix A, effective w and u. + 2. Builds causal intra-chunk attention A_qk. + 3. Computes delta-corrected values: v_corr = u - w @ S. + 4. Computes output: o = (q·exp(G)) @ S + A_qk @ v_corr. + 5. Updates hidden state S. 
+ + Refs (after BlockSpec indexing): + q_ref, k_ref : (1, 1, 1, BT, K) — full K, time-stepped + v_ref : (1, 1, 1, BT, BV) — V-tiled, time-stepped + g_ref : (1, 1, 1, BT, K) — raw gate, time-stepped + beta_ref : (1, 1, 1, BT, 128) — learning rate, time-stepped + h0_ref : (1, 1, K, BV) — initial state tile + o_ref : (1, 1, 1, BT, BV) — output tile + ht_ref : (1, 1, K, BV) — final state tile + scratch_ref : (K, BV) — running hidden state in VMEM + """ + BV = v_ref.shape[4] + i_t = pl.program_id(3) + + # ---- init hidden state ---- + @pl.when(i_t == 0) + def _init(): + if h0_ref is not None: + scratch_ref[:, :] = h0_ref[0, 0].astype(jnp.float32) + else: + scratch_ref[:, :] = jnp.zeros((K, BV), dtype=jnp.float32) + + # ---- load chunk data ---- + b_q = q_ref[0, 0, 0].astype(jnp.float32) # (BT, K) + b_k = k_ref[0, 0, 0].astype(jnp.float32) # (BT, K) + b_v = v_ref[0, 0, 0].astype(jnp.float32) # (BT, BV) + b_g_raw = g_ref[0, 0, 0].astype(jnp.float32) # (BT, K) + b_beta = beta_ref[0, 0, 0, :, 0].astype(jnp.float32) # (BT,) + b_h = scratch_ref[...] # (K, BV) + + # =================== preprocessing (fused) =================== + + # 1. cumulative gate (tril matmul — jnp.cumsum unavailable in Pallas) + tril = jnp.tril(jnp.ones((BT, BT), dtype=jnp.float32)) + b_g = jnp.dot(tril, b_g_raw, preferred_element_type=jnp.float32) # (BT, K) + + # 2. interaction matrix A[c,i] = Σ_d k[c,d]·exp(G[c,d]-G[i,d])·k[i,d] + b_k_eg = b_k * jnp.exp(b_g) # (BT, K) + b_k_eng = b_k * jnp.exp(-b_g) # (BT, K) + b_A = jnp.dot(b_k_eg, b_k_eng.T, preferred_element_type=jnp.float32) # (BT, BT) + b_A = b_A * b_beta[:, None] + + # 3. Neumann-series inversion: (I - L)^{-1} ≈ I + L + L² + # L element magnitude ~ O(K · σ² · β), where σ is the input scale. + # Requires σ << 1 so that L³ is negligible (e.g. σ=0.1 → L~0.32). 
+ lower_mask = 1.0 - jnp.triu(jnp.ones((BT, BT), dtype=jnp.float32)) + b_L = lower_mask * b_A + b_L_sq = jnp.dot(b_L, b_L, preferred_element_type=jnp.float32) + b_A = (jnp.eye(BT, dtype=jnp.float32) + b_L + b_L_sq) * b_beta[None, :] + + # 4. effective keys / values + b_w = jnp.dot(b_A, b_k_eg, preferred_element_type=jnp.float32) # (BT, K) + b_u = jnp.dot(b_A, b_v, preferred_element_type=jnp.float32) # (BT, BV) + + # =================== inter-chunk recurrence =================== + + # 5. intra-chunk causal attention A_qk[c,j] = (q·exp(G))[c] · (k·exp(-G))[j] + b_qg = b_q * jnp.exp(b_g) + b_A_qk = jnp.dot(b_qg, b_k_eng.T, preferred_element_type=jnp.float32) + causal = (jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]).astype(jnp.float32) + b_A_qk = b_A_qk * causal + + # 6. delta-rule correction: v_corr = u - w @ S + b_v_corr = b_u - jnp.dot(b_w, b_h, preferred_element_type=jnp.float32) + + # 7. output: o = (q·exp(G)) @ S + A_qk @ v_corr + b_o = jnp.dot(b_qg, b_h, preferred_element_type=jnp.float32) + b_o = b_o + jnp.dot( + b_A_qk, + b_v_corr, + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + o_ref[0, 0, 0] = b_o.astype(o_ref.dtype) + + # 8. state update: S ← S·exp(G_last) + (exp(G_last-G)·k)^T @ v_corr + b_g_last = b_g[BT - 1, :] # (K,) + b_h = b_h * jnp.exp(b_g_last)[:, None] # decay + b_k_decay = b_k * jnp.exp(b_g_last[None, :] - b_g) # (BT, K) + b_h = b_h + jnp.dot( + b_k_decay.T, + b_v_corr, + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + scratch_ref[...] 
= b_h + + # ---- store final state ---- + @pl.when(i_t == NT - 1) + def _store_ht(): + if OUTPUT_FINAL_STATE and ht_ref is not None: + ht_ref[0, 0] = scratch_ref[...].astype(ht_ref.dtype) + + +# ============================================================================ +# Pallas launcher +# ============================================================================ + + +@functools.partial( + jax.jit, + static_argnames=("output_final_state", "chunk_size", "interpret"), +) +def _chunk_kda_fwd_launcher( + q: jax.Array, + k: jax.Array, + v: jax.Array, + g: jax.Array, + beta: jax.Array, + *, + h0: jax.Array | None = None, + scale: float, + output_final_state: bool = False, + chunk_size: int = 64, + interpret: bool = False, +) -> tuple[jax.Array, jax.Array | None]: + """Reshape inputs, launch the fused Pallas kernel, reshape outputs. + + Args: + q: [B, T, H, K] — queries. + k: [B, T, H, K] — keys. + v: [B, T, H, V] — values. + g: [B, T, H, K] — per-element gate in log-space. + beta: [B, T, H] — learning-rate scalars. + h0: [B, H, K, V] — initial hidden state (optional). + scale: attention scale. + output_final_state: if True, also return the final hidden state. + chunk_size: block size (T must be divisible by chunk_size after padding). + interpret: Pallas interpret mode. + + Returns: + o: [B, T, H, V] — output (float32). + ht: [B, H, K, V] or None — final hidden state (float32). 
+ """ + B, T, H, K = q.shape + V = v.shape[-1] + BT = chunk_size + NT = T // BT + BV = 128 + NV = V // BV + + # ---- reshape to chunked layout [B, H, NT, BT, D] ---- + def _to_chunks(x, D): + return x.transpose(0, 2, 1, 3).reshape(B, H, NT, BT, D) + + _q = _to_chunks(q, K) + _k = _to_chunks(k, K) + _v = _to_chunks(v, V) + _g = _to_chunks(g, K) + + # beta: [B, T, H] -> [B, H, NT, BT, 128] (last dim padded for TPU alignment) + _beta = beta.transpose(0, 2, 1).reshape(B, H, NT, BT) + _beta = jnp.pad(_beta[..., None], ((0, 0), (0, 0), (0, 0), (0, 0), (0, 127))) + + _q = (_q * scale).astype(jnp.float32) + _h0 = h0.astype(jnp.float32) if h0 is not None else None + + # ---- BlockSpecs ---- + def qk_map(b, h, iv, nt): + return (b, h, nt, 0, 0) + + def v_map(b, h, iv, nt): + return (b, h, nt, 0, iv) + + def beta_map(b, h, iv, nt): + return (b, h, nt, 0, 0) + + def state_map(b, h, iv, nt): + return (b, h, 0, iv) + + spec_qk = pl.BlockSpec((1, 1, 1, BT, K), qk_map) + spec_v = pl.BlockSpec((1, 1, 1, BT, BV), v_map) + spec_beta = pl.BlockSpec((1, 1, 1, BT, 128), beta_map) + spec_h0 = pl.BlockSpec((1, 1, K, BV), state_map) if h0 is not None else None + spec_ht = pl.BlockSpec((1, 1, K, BV), state_map) if output_final_state else None + spec_o = pl.BlockSpec((1, 1, 1, BT, BV), v_map) + + # ---- launch ---- + results = pl.pallas_call( + functools.partial( + _chunk_kda_fwd_kernel, + BT=BT, + K=K, + NT=NT, + OUTPUT_FINAL_STATE=output_final_state, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=(B, H, NV, NT), + in_specs=[spec_qk, spec_qk, spec_v, spec_qk, spec_beta, spec_h0], + out_specs=[spec_o, spec_ht], + scratch_shapes=[pltpu.VMEM((K, BV), jnp.float32)], + ), + out_shape=[ + jax.ShapeDtypeStruct((B, H, NT, BT, V), jnp.float32), + jax.ShapeDtypeStruct((B, H, K, V), jnp.float32) if output_final_state else None, + ], + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "parallel", "parallel", "arbitrary"), + ), + 
interpret=interpret, + )(_q, _k, _v, _g, _beta, _h0) + + o_chunked, ht = results + + # ---- reshape output: (B, H, NT, BT, V) -> (B, T, H, V) ---- + o = o_chunked.reshape(B, H, NT * BT, V).transpose(0, 2, 1, 3) + return o, ht + + +# ============================================================================ +# Public API +# ============================================================================ + + +def chunk_kda_fwd( + q: jax.Array, + k: jax.Array, + v: jax.Array, + g: jax.Array, + beta: jax.Array, + scale: float | None = None, + initial_state: jax.Array | None = None, + output_final_state: bool = False, + chunk_size: int = 64, +) -> tuple[jax.Array, jax.Array | None]: + """Chunk KDA forward — fully-fused Pallas TPU kernel. + + Processes the sequence in chunks of ``chunk_size``. Intra-chunk delta-rule + dependencies are resolved via a 2nd-order Neumann series approximation; + the hidden state is propagated across chunks by a fused Pallas kernel that + keeps S in VMEM scratch throughout. + + Core recurrence (per timestep): + S' = S_{t-1} * exp(g_t) — decay + residual = v_t - k_t^T @ S' — prediction error + S_t = S' + beta_t * k_t (x) residual — delta update + o_t = (q_t * scale)^T @ S_t — output + + Args: + q: [B, T, H, K] — queries. + k: [B, T, H, K] — keys. + v: [B, T, H, V] — values. + g: [B, T, H, K] — per-element gate in log-space. + beta: [B, T, H] — learning rate for delta rule. + scale: Scalar query scale. Defaults to K ** -0.5. + initial_state: [B, H, K, V] — initial hidden state (optional). + output_final_state: Whether to return the final hidden state. + chunk_size: Block size for chunked computation. + + Returns: + o: [B, T, H, V] — output (v.dtype). + final_state: [B, H, K, V] in float32, or None. 
+ """ + orig_dtype = v.dtype + B, T, H, K = q.shape + V = v.shape[-1] + BT = chunk_size + BV = 128 + + if scale is None: + scale = K**-0.5 + + # ---- input validation ---- + assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D" + assert_shape(k, (B, T, H, K), "k") + assert v.ndim == 4 and v.shape[:3] == q.shape[:3], ( + f"v shape {v.shape} incompatible with q shape {q.shape}" + ) + assert_shape(g, (B, T, H, K), "g") + assert beta.ndim == 3 and beta.shape == (B, T, H), ( + f"beta shape {beta.shape} != ({B}, {T}, {H})" + ) + assert_shape_or_none(initial_state, (B, H, K, V), "initial_state") + assert BT > 0, f"chunk_size must be positive, got {BT}" + + # ---- pad T to multiple of chunk_size ---- + T_orig = T + T_padded = cdiv(T, BT) * BT + if T_padded > T: + q = pad_to_multiple(q, BT, 1, 0) + k = pad_to_multiple(k, BT, 1, 0) + v = pad_to_multiple(v, BT, 1, 0) + g = pad_to_multiple(g, BT, 1, 0) + beta = pad_to_multiple(beta, BT, 1, 0) + T = T_padded + + # ---- pad K / V to multiples of 128 ---- + origin_K, origin_V = K, V + K = align_up(K, 128) + V = align_up(V, BV) + + if K > origin_K: + q = pad_to_multiple(q, 128, 3, 0) + k = pad_to_multiple(k, 128, 3, 0) + g = pad_to_multiple(g, 128, 3, 0) + if V > origin_V: + v = pad_to_multiple(v, BV, 3, 0) + if initial_state is not None and (K > origin_K or V > origin_V): + initial_state = pad_to_multiple(initial_state, [128, BV], [2, 3], 0) + + # ---- launch ---- + interpret = not is_tpu_runtime() + o, ht = _chunk_kda_fwd_launcher( + q, + k, + v, + g, + beta, + h0=initial_state, + scale=scale, + output_final_state=output_final_state, + chunk_size=BT, + interpret=interpret, + ) + + # ---- trim padding and cast ---- + o = o[:, :T_orig, :, :origin_V].astype(orig_dtype) + if ht is not None: + ht = ht[:, :, :origin_K, :origin_V] + + return o, ht