feat: add fully-fused Pallas chunk KDA forward kernel (#192)
Conversation
Implement the chunked Kimi Delta Attention (KDA) forward pass as a single fused Pallas TPU kernel. Preprocessing (cumulative gate, interaction matrix, Neumann-series inversion, effective keys/values) and the inter-chunk recurrence are computed inline to avoid HBM round-trips for intermediate tensors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace `* 0.1` scaling on q and k with L2 normalization to keep the Neumann-series truncation valid. Add test run instructions to the docstring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Warning: size limit exceeded. This PR changes 402 lines of core code, above the 200-line limit (tests and documentation excluded). Suggested action: split this PR into several smaller PRs, each keeping core code changes under 200 lines.
Code Review
This pull request introduces a fully-fused Pallas TPU kernel for Kimi Delta Attention (KDA), implementing a chunked delta-rule recurrence. It includes a comprehensive test suite comparing the Pallas implementation against a CPU reference. Feedback focuses on critical architectural issues regarding state persistence in VMEM across the sequence dimension, numerical instability in gate exponentiation, excessive memory overhead from tensor padding, and potential accuracy improvements for the matrix inversion approximation.
```python
    ),
    grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        grid=(B, H, NV, NT),
```
The sequence dimension NT is included in the Pallas grid, and the hidden state S is stored in scratch_ref (VMEM). However, VMEM scratch is not guaranteed to persist across different grid points in Pallas. For sequences with multiple chunks (NT > 1), each grid point for i_t > 0 will likely start with an uninitialized or zeroed state in VMEM, breaking the inter-chunk recurrence. To fix this and ensure state persistence in VMEM without HBM round-trips, the NT dimension should be removed from the grid and handled via a loop inside the kernel. The BlockSpecs should be updated to cover the full sequence length (or use vmap within the kernel).
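The fix the review suggests can be illustrated with a pure-JAX sketch (hypothetical names and shapes, not the PR's actual kernel): the chunk loop runs inside one invocation via `jax.lax.fori_loop`, so the running state `S` is threaded through the carry explicitly rather than depending on VMEM scratch surviving across grid points. Decay and the intra-chunk (triangular) term are omitted for brevity.

```python
import jax
import jax.numpy as jnp

def chunk_recurrence(k_eff, v_eff, q, S0):
    """Carry state S across NT chunks inside a single traced loop.

    k_eff, q: (NT, BT, K); v_eff: (NT, BT, V); S0: (K, V).
    Illustrative sketch only -- decay and intra-chunk terms are omitted.
    """
    NT, BT, V = v_eff.shape

    def body(i_t, carry):
        S, o = carry
        b_o = q[i_t] @ S                    # inter-chunk contribution
        S = S + k_eff[i_t].T @ v_eff[i_t]   # accumulate state for next chunk
        return S, o.at[i_t].set(b_o)

    o0 = jnp.zeros((NT, BT, V), dtype=v_eff.dtype)
    return jax.lax.fori_loop(0, NT, body, (S0, o0))
```

In a Pallas kernel the same structure would read each chunk's block from refs inside the loop body; the key point is that the recurrence lives inside one grid point, so no persistence guarantee is needed from VMEM scratch.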
```python
    # 2. interaction matrix A[c,i] = Σ_d k[c,d]·exp(G[c,d]-G[i,d])·k[i,d]
    b_k_eg = b_k * jnp.exp(b_g)    # (BT, K)
    b_k_eng = b_k * jnp.exp(-b_g)  # (BT, K)
```
Computing exp(-b_g) is numerically unstable. Since b_g is a cumulative sum of gates (which are typically negative in log-space), -b_g can become large and positive, leading to float32 overflow. For example, with an average gate value of -2.0, after 64 steps -b_g reaches 128, and exp(128) overflows. A more stable approach is to compute the interaction matrix elements using the difference in exponents, e.g., exp(G_i - G_j). While this is harder to do with a single jnp.dot, you should at least consider subtracting a constant (like the max of b_g in the chunk) before the exponentiation to improve the dynamic range.
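A NumPy toy (synthetic gates and keys, not the PR's tensors) demonstrates both the overflow and the stable pairwise form. Because the cumulative gate `G` is non-increasing along the chunk, the differences `G[c,d] - G[i,d]` needed for the strictly-lower triangle are all non-positive, so exponentiating the difference directly never overflows:

```python
import numpy as np

rng = np.random.default_rng(0)
BT, K = 64, 8
k = rng.standard_normal((BT, K)).astype(np.float32)
# cumulative log-gates: strongly negative, -g grows to ~128 by chunk end
g = np.cumsum(-4.0 * rng.random((BT, K)), axis=0).astype(np.float32)

# naive factorization: exp(-g) overflows float32 (exp(x) = inf for x > ~88.7)
with np.errstate(over="ignore"):
    assert np.isinf(np.exp(-g)).any()

# stable pairwise form: only the strictly-lower triangle is needed, and there
# g[c,d] - g[i,d] <= 0 for c > i, so every exponent is <= 0
diff = g[:, None, :] - g[None, :, :]             # (BT, BT, K)
A = np.einsum("cd,cid,id->ci", k, np.exp(np.minimum(diff, 0.0)), k)
A = np.tril(A, -1)                               # keep only c > i entries
assert np.isfinite(A).all()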
```python
    # beta: [B, T, H] -> [B, H, NT, BT, 128] (last dim padded for TPU alignment)
    _beta = beta.transpose(0, 2, 1).reshape(B, H, NT, BT)
    _beta = jnp.pad(_beta[..., None], ((0, 0), (0, 0), (0, 0), (0, 0), (0, 127)))
Padding the beta tensor to 128 in the last dimension introduces a significant memory overhead (128x for this tensor). While TPU alignment is important, it can usually be achieved by ensuring the block size is a multiple of 8 or 16. Since beta is a scalar per timestep, you can pass it as a 1D array per chunk and load it as a vector in the kernel without this extreme padding.
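The overhead is easy to quantify with illustrative shapes (B, T, H, and BT below are assumptions, not values from the PR): padding a per-timestep scalar out to 128 lanes multiplies its footprint by exactly 128 relative to the compact `(B, H, NT, BT)` layout the review suggests.

```python
import numpy as np

# illustrative shapes, not taken from the PR
B, T, H, BT = 2, 1024, 4, 64
NT = T // BT

beta = np.ones((B, T, H), dtype=np.float32)

# current layout: trailing scalar dim padded to 128 lanes
padded = np.zeros((B, H, NT, BT, 128), dtype=np.float32)
# suggested layout: one (BT,)-vector of betas per chunk, no lane padding
compact = beta.transpose(0, 2, 1).reshape(B, H, NT, BT)

print(padded.nbytes // compact.nbytes)  # -> 128
```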
```python
    lower_mask = 1.0 - jnp.triu(jnp.ones((BT, BT), dtype=jnp.float32))
    b_L = lower_mask * b_A
    b_L_sq = jnp.dot(b_L, b_L, preferred_element_type=jnp.float32)
    b_A = (jnp.eye(BT, dtype=jnp.float32) + b_L + b_L_sq) * b_beta[None, :]
```
The 2nd-order Neumann series approximation (I + L + L^2) for (I - L)^{-1} may not be accurate enough when the interaction matrix L has large elements. This is likely why the tests require a high tolerance of atol=5e-2. Consider using a higher-order approximation or the iterative product method (I+L)(I+L^2)(I+L^4)... which can provide the exact inverse for a 64x64 triangular matrix in just 6 matrix multiplications (log2(64)).
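The doubling product the review mentions can be checked on a toy strictly-lower-triangular matrix (the scale below is illustrative; the real `L` comes from beta-weighted key interactions). Since `L` is nilpotent with `L**64 == 0`, the product `(I + L)(I + L²)(I + L⁴)…` accumulates every power `L**k` for `k < 64`, giving the exact inverse of `(I - L)` in `log2(64) = 6` doubling steps:

```python
import numpy as np

rng = np.random.default_rng(0)
BT = 64
I = np.eye(BT)
# toy strictly-lower-triangular interaction matrix (scale is illustrative)
L = np.tril(rng.random((BT, BT)), -1) * 0.02

inv = I.copy()
P = L.copy()
for _ in range(6):            # log2(64) doubling steps
    inv = inv + inv @ P       # doubles the number of Neumann terms held
    P = P @ P                 # L^(2^j) -> L^(2^(j+1))

err_exact = np.abs(inv @ (I - L) - I).max()
err_order2 = np.abs((I + L + L @ L) @ (I - L) - I).max()  # truncated series
assert err_exact < 1e-12 and err_exact < err_order2
```

The truncated 2nd-order series leaves a residual of `L³` and higher powers, which is consistent with the loose `atol=5e-2` tolerance the review points out.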
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Test plan
- `test_cpu_vs_pallas` — cases (various B/T/H/K/V, with/without h0, custom scale)
- `test_state_split_pallas` — state continuity across split sequences
- `test_no_final_state_pallas` — `output_final_state=False` returns None
- `test_matches_naive_recurrent` — Pallas chunk matches naive recurrent ground truth

🤖 Generated with Claude Code