
feat: add fully-fused Pallas chunk KDA forward kernel #192

Draft

FENP wants to merge 3 commits into main from fenp/kda-impl

Conversation

@FENP (Contributor) commented Apr 13, 2026

Summary

  • Implement chunked Kimi Delta Attention (KDA) forward pass as a single fused Pallas TPU kernel
  • Preprocessing (cumulative gate, interaction matrix, Neumann-series inversion, effective keys/values) and inter-chunk recurrence are computed inline to avoid HBM round-trips for intermediate tensors (w, u, A)
  • Add 20 tests covering correctness vs CPU reference, state continuity, and naive recurrent equivalence

Test plan

  • 17 parametrized test_cpu_vs_pallas cases (various B/T/H/K/V, with/without h0, custom scale)
  • test_state_split_pallas — state continuity across split sequences
  • test_no_final_state_pallas — output_final_state=False returns None
  • test_matches_naive_recurrent — Pallas chunk matches naive recurrent ground truth
  • All 20 tests passed on V7 TPU pod
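For context, the "naive recurrent ground truth" the tests compare against can be sketched as a gated delta-rule recurrence. This is a hedged reconstruction from the PR description only, not the repository's actual reference code; the single-head shapes, the per-key log-gate `g`, and the name `naive_kda_recurrent` are all assumptions.

```python
import jax
import jax.numpy as jnp

def naive_kda_recurrent(q, k, v, g, beta, scale):
    """Hypothetical single-head KDA reference: gated delta-rule recurrence.

    q, k: (T, K)  queries / keys
    v:    (T, V)  values
    g:    (T, K)  per-key log-gates (typically <= 0)
    beta: (T,)    write strengths
    """
    K, V = k.shape[1], v.shape[1]

    def step(S, inputs):
        q_t, k_t, v_t, g_t, b_t = inputs
        S = S * jnp.exp(g_t)[:, None]      # per-key decay of the state
        err = v_t - S.T @ k_t              # delta-rule prediction error
        S = S + b_t * jnp.outer(k_t, err)  # rank-1 correction
        return S, (q_t * scale) @ S        # read-out for this step

    S0 = jnp.zeros((K, V), jnp.float32)
    S_final, o = jax.lax.scan(step, S0, (q, k, v, g, beta))
    return o, S_final
```

The chunked Pallas kernel should agree with this step-by-step recurrence up to numerical tolerance, which is what `test_matches_naive_recurrent` checks.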

🤖 Generated with Claude Code

FENP and others added 2 commits April 13, 2026 19:19
Implement the chunked Kimi Delta Attention (KDA) forward pass as a
single fused Pallas TPU kernel. Preprocessing (cumulative gate,
interaction matrix, Neumann-series inversion, effective keys/values)
and inter-chunk recurrence are computed inline to avoid HBM round-trips
for intermediate tensors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace `* 0.1` scaling on q and k with L2 normalization to keep the
Neumann series truncation valid. Add test run instructions to docstring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai bot commented Apr 13, 2026

Important

Review skipped: draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

@beaver-infiscale commented

Warning

Size limit exceeded

This PR changes 402 lines of core code, exceeding the 200-line limit (tests and documentation excluded).

Suggested action: split this PR into several smaller PRs, each keeping core code changes under 200 lines.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a fully-fused Pallas TPU kernel for Kimi Delta Attention (KDA), implementing a chunked delta-rule recurrence. It includes a comprehensive test suite comparing the Pallas implementation against a CPU reference. Feedback focuses on critical architectural issues regarding state persistence in VMEM across the sequence dimension, numerical instability in gate exponentiation, excessive memory overhead from tensor padding, and potential accuracy improvements for the matrix inversion approximation.

),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
grid=(B, H, NV, NT),

critical

The sequence dimension NT is included in the Pallas grid, and the hidden state S is stored in scratch_ref (VMEM). However, VMEM scratch is not guaranteed to persist across different grid points in Pallas. For sequences with multiple chunks (NT > 1), each grid point for i_t > 0 will likely start with an uninitialized or zeroed state in VMEM, breaking the inter-chunk recurrence. To fix this and ensure state persistence in VMEM without HBM round-trips, the NT dimension should be removed from the grid and handled via a loop inside the kernel. The BlockSpecs should be updated to cover the full sequence length (or use vmap within the kernel).
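A minimal sketch of the suggested fix, with the real chunk update replaced by a placeholder sum: the grid no longer iterates over NT; instead a `fori_loop` inside the kernel carries the state, so it never leaves the kernel between chunks. The shapes, the kernel name, and the `interpret=True` CPU mode are illustrative assumptions, not the PR's actual kernel.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def chunk_loop_kernel(x_ref, o_ref):
    # Load the whole (NT, BT, K) block; the recurrence then runs inside
    # the kernel, so the carried state S stays on-chip across chunks
    # instead of relying on VMEM scratch persisting across grid points.
    x = x_ref[...]
    nt = x.shape[0]

    def body(i, S):
        # Placeholder for the real inter-chunk update S <- f(S, chunk_i).
        return S + x[i].sum(axis=0)

    o_ref[...] = jax.lax.fori_loop(0, nt, body,
                                   jnp.zeros(o_ref.shape, o_ref.dtype))

x = jnp.ones((4, 8, 16), jnp.float32)  # NT=4 chunks of BT=8, K=16
final_state = pl.pallas_call(
    chunk_loop_kernel,
    out_shape=jax.ShapeDtypeStruct((16,), jnp.float32),
    interpret=True,  # CPU interpreter mode, for illustration only
)(x)
```

With this structure the B, H, and NV dimensions can remain in the grid, since no state is carried across them.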


# 2. interaction matrix A[c,i] = Σ_d k[c,d]·exp(G[c,d]-G[i,d])·k[i,d]
b_k_eg = b_k * jnp.exp(b_g) # (BT, K)
b_k_eng = b_k * jnp.exp(-b_g) # (BT, K)

high

Computing exp(-b_g) is numerically unstable. Since b_g is a cumulative sum of gates (which are typically negative in log-space), -b_g can become large and positive, leading to float32 overflow. For example, with an average gate value of -2.0, after 64 steps -b_g reaches 128, and exp(128) overflows. A more stable approach is to compute the interaction matrix elements using the difference in exponents, e.g., exp(G_i - G_j). While this is harder to do with a single jnp.dot, you should at least consider subtracting a constant (like the max of b_g in the chunk) before the exponentiation to improve the dynamic range.
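The suggested mitigation can be sketched as below: shifting the cumulative gate by its per-feature maximum before exponentiating leaves A mathematically unchanged (the shift cancels in exp(G[c,d] − G[i,d])) while bounding the forward factor by |k|. The function name and shapes are illustrative.

```python
import jax.numpy as jnp

def interaction_matrix(k, g):
    """A[c, i] = sum_d k[c,d] * exp(g[c,d] - g[i,d]) * k[i,d], stabilized.

    k, g: (BT, K); g is the within-chunk cumulative log-gate.
    """
    m = g.max(axis=0, keepdims=True)   # per-feature shift; cancels exactly in A
    k_eg = k * jnp.exp(g - m)          # exp(g - m) <= 1
    k_eng = k * jnp.exp(-(g - m))      # dynamic range limited to the chunk
    return k_eg @ k_eng.T
```

Note this bounds only the forward factor; the backward factor can still grow with the gate range inside one chunk, so it reduces rather than eliminates the overflow risk.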


# beta: [B, T, H] -> [B, H, NT, BT, 128] (last dim padded for TPU alignment)
_beta = beta.transpose(0, 2, 1).reshape(B, H, NT, BT)
_beta = jnp.pad(_beta[..., None], ((0, 0), (0, 0), (0, 0), (0, 0), (0, 127)))

medium

Padding the beta tensor to 128 in the last dimension introduces a significant memory overhead (128x for this tensor). While TPU alignment is important, it can usually be achieved by ensuring the block size is a multiple of 8 or 16. Since beta is a scalar per timestep, you can pass it as a 1D array per chunk and load it as a vector in the kernel without this extreme padding.
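The 128x overhead can be checked directly; the shapes below are illustrative, not taken from the PR's test configuration.

```python
import jax.numpy as jnp

B, H, T, BT = 2, 4, 256, 64
NT = T // BT
beta = jnp.ones((B, T, H))

# Compact layout: one scalar per timestep per head.
compact = beta.transpose(0, 2, 1).reshape(B, H, NT, BT)          # (B, H, NT, BT)

# Layout from the PR: last dim padded from 1 to 128 for TPU alignment.
padded = jnp.pad(compact[..., None],
                 ((0, 0), (0, 0), (0, 0), (0, 0), (0, 127)))     # (B, H, NT, BT, 128)

print(padded.size // compact.size)
```

Every element of the padded tensor except one lane per row is a zero that still has to be materialized and transferred.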

lower_mask = 1.0 - jnp.triu(jnp.ones((BT, BT), dtype=jnp.float32))
b_L = lower_mask * b_A
b_L_sq = jnp.dot(b_L, b_L, preferred_element_type=jnp.float32)
b_A = (jnp.eye(BT, dtype=jnp.float32) + b_L + b_L_sq) * b_beta[None, :]

medium

The 2nd-order Neumann series approximation (I + L + L^2) for (I - L)^{-1} may not be accurate enough when the interaction matrix L has large elements. This is likely why the tests require a high tolerance of atol=5e-2. Consider using a higher-order approximation or the iterative product method (I+L)(I+L^2)(I+L^4)... which can provide the exact inverse for a 64x64 triangular matrix in just 6 matrix multiplications (log2(64)).
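The iterative product mentioned above relies on L being strictly lower-triangular, hence nilpotent: (I − L)^{-1} = Σ_{i<BT} L^i = (I + L)(I + L²)(I + L⁴)…, which is exact after log₂(BT) squarings because the factors sum the full geometric series. A standalone sketch (names are illustrative):

```python
import numpy as np
import jax.numpy as jnp

def inv_I_minus_L(L):
    """Exact (I - L)^{-1} for strictly lower-triangular L of size 2^m x 2^m."""
    n = L.shape[0]
    acc = jnp.eye(n, dtype=L.dtype) + L
    P = L
    for _ in range(int(np.log2(n)) - 1):
        P = P @ P                                # L^2, L^4, L^8, ...
        acc = acc @ (jnp.eye(n, dtype=L.dtype) + P)
    return acc
```

For BT = 64 this uses log₂(64) = 6 product factors; the factors are polynomials in the same matrix, so they commute and the multiplication order is immaterial.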

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>