Skip to content

feat(kda): add KDA Pallas kernel with chunked forward + backward#174

Draft
0xaskr wants to merge 22 commits intomainfrom
feat/kda-pallas-kernel
Draft

feat(kda): add KDA Pallas kernel with chunked forward + backward#174
0xaskr wants to merge 22 commits intomainfrom
feat/kda-pallas-kernel

Conversation

@0xaskr
Copy link
Copy Markdown
Collaborator

@0xaskr 0xaskr commented Apr 7, 2026

Summary

  • Add KDA (Kimi Delta Attention) chunked Pallas TPU kernel: forward pass (chunk_kda_fwd) and backward pass via custom_vjp
  • Add delta-rule inter-chunk state propagation (tops/ops/common/chunk_delta_h.py) — shared by KDA and future Gated Delta Rule ops
  • Add fused recurrent KDA (fused_recurrent_kda) via lax.scan as CPU-compatible reference
  • Add intra-chunk triangular solve kernel with block-based forward substitution (B=16)
  • Fix USE_EXP2 midpoint stabilization bug in chunk_gla_fwd_o_gk_pl_kerneljnp.exp was used instead of jnp.exp2 for the h factor when USE_EXP2=True

New files

  • tops/ops/kda/ — gate.py, chunk.py, chunk_intra.py, fused_recurrent.py, naive.py
  • tops/ops/common/chunk_delta_h.py — pure JAX ref + Pallas TPU kernel (fwd)
  • tests/ops/kda/ — test_chunk_kda_tpu.py, test_fused_recurrent_kda_tpu.py

Test plan

  • 4 chunk_kda tests pass (3 forward + 1 backward) via PALLAS_INTERPRET=1
  • 7 fused_recurrent_kda tests pass (fp64, fp32, state continuity)
  • GLA tests unaffected (USE_EXP2=False path unchanged)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added chunked KDA forward/backward, forward-only chunked KDA, fused recurrent KDA, and a CPU reference export.
  • Improvements

    • Faster, chunk-parallel intra-chunk solves and improved inter-chunk state propagation with optional final-state outputs.
  • Tests

    • Deterministic tests validating chunked and fused recurrent KDA (forward and backward) against the CPU reference.
  • Documentation

    • Added forward and backward chunk-parallel KDA design docs.

0xaskr and others added 15 commits April 7, 2026 16:18
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Update project description to be architecture-agnostic
- Remove numpy from base deps (transitive via flax)
- Move pytest back to base deps (needed everywhere)
- Remove unused torchvision dependency
- Remove non-existent jax[gpu] and flash-linear-attention[gpu] extras
- Remove unused profile optional-dependency group
- Remove torchvision from uv sources
- Switch CI GPU tests to uv sync
- Complete tpuv7 K8s Pod uv setup flow
- Update CLAUDE.md install commands

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rule state

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Migrates kda_intra_kernel.py into tops/ops/kda/chunk_intra.py, replacing
hardcoded interpret=True with get_interpret() and adding docstrings plus
assert_shape validation on all public launcher functions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add chunk_kda function wrapping chunk_kda_fwd with jax.custom_vjp for
backward pass support. The initial backward uses JAX autograd through
the forward; explicit backward kernels will follow. Update exports in
tops/ops/kda/__init__.py and tops/ops/__init__.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Also fix fused_recurrent_kda to use acc_dtype (fp64 for fp64 inputs,
fp32 otherwise) matching naive_kda's precision contract, instead of
hardcoding float32 for all dtypes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Fix exp/exp2 mismatch in chunk_gla_fwd_o_gk_pl_kernel: the midpoint
  stabilization factor for h was always using jnp.exp even when
  USE_EXP2=True (KDA log2 space gates), causing wrong output for
  multi-chunk sequences.
- Fix backward pass: use fused_recurrent_kda (lax.scan, fully
  differentiable) instead of chunk_kda_fwd (pallas_call, not
  differentiable) for gradient computation.
- Adjust T=128 forward test case to K=16 (from K=32) to avoid
  float32 precision issues with extreme value growth.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Apr 7, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2f4fa974-61b1-49dd-9dc9-7c0fa5527941

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds chunked and fused-recurrent KDA implementations (TPU Pallas kernels and CPU helpers), intra-/inter-chunk kernels and orchestrator with a custom VJP, new tests and design docs, and re-exports to public ops packages.

Changes

Cohort / File(s) Summary
Tests
tests/ops/kda/test_chunk_kda_tpu.py, tests/ops/kda/test_fused_recurrent_kda_tpu.py
New TPU tests comparing chunked and fused-recurrent KDA against CPU naive_kda for forward, backward (gradients), and state-continuity across fp64/fp32.
KDA Public API
tops/ops/kda/__init__.py, tops/ops/kda/naive.py, tops/ops/__init__.py
New package exports: chunk_kda, chunk_kda_fwd, fused_recurrent_kda; naive_kda re-exported; top-level ops exports updated.
Chunked Forward Orchestrator & VJP
tops/ops/kda/chunk.py
Adds chunk_kda_fwd orchestrator (gate cumsum → intra-chunk solve → inter-chunk state → output assembly) and chunk_kda with a custom VJP that recomputes via fused-recurrent kernel for backward.
Intra-chunk TPU Kernels
tops/ops/kda/chunk_intra.py
Adds Pallas intra-chunk forward/backward kernels, a block forward-substitution solver, and JIT wrappers producing per-chunk Aqk/Akk/u/w/qg/kg.
Inter-chunk State Propagation
tops/ops/common/chunk_delta_h.py
Adds reference and Pallas (varlen/non-varlen) implementations for chunked gated delta-rule hidden-state propagation with optional saved v_new/final state.
Gate Processing
tops/ops/kda/gate.py
Adds kda_gate_chunk_cumsum converting natural-log gates to log2 and computing chunk-local cumsums (varlen support).
Output Fusion Fix
tops/ops/gla/chunk.py
Runtime selection between exp2 and exp in the chunk output kernel based on USE_EXP2.
CPU Chunked Helpers & Backward
tops/cpu/ops/kda/chunk.py, tops/cpu/ops/kda/__init__.py
Refactors CPU chunk_kda to use new chunk_kda_fwd/chunk_kda_bwd helpers; exports updated.
Fused Recurrent Implementation
tops/ops/kda/fused_recurrent.py
Adds fused_recurrent_kda implemented with lax.scan delta-rule recurrence, returning outputs and optional final state.
Docs
docs/design-docs/ops/kda/chunk-fwd.md, docs/design-docs/ops/kda/chunk-bwd.md
Design docs for chunk-parallel forward and backward pipelines, math, pseudocode, tensor shapes, and gradient-flow.

Sequence Diagram(s)

sequenceDiagram
    participant User as User
    participant ChunkOrch as chunk_kda_fwd<br/>(Orchestrator)
    participant GateCumsum as kda_gate_chunk_cumsum<br/>(Gate Processing)
    participant IntraChunk as kda_intra_chunk_fwd<br/>(Intra-chunk Attention)
    participant InterChunk as chunk_gated_delta_rule_fwd_h<br/>(Inter-chunk State)
    participant OutputFusion as chunk_gla_fwd_o_gk<br/>(Output Assembly)

    User->>ChunkOrch: q, k, v, g, beta, initial_state
    ChunkOrch->>GateCumsum: g (natural-log)
    GateCumsum-->>ChunkOrch: g_cumsum (log2)
    ChunkOrch->>IntraChunk: q,k,v,g_cumsum,beta (per-chunk layout)
    IntraChunk-->>ChunkOrch: u, w, qg, kg, Aqk, Akk_inv
    ChunkOrch->>InterChunk: w, u, g_cumsum, initial_state
    InterChunk-->>ChunkOrch: h (per-chunk), v_new, final_state
    ChunkOrch->>OutputFusion: Aqk, v_new, qg, kg
    OutputFusion-->>ChunkOrch: o
    ChunkOrch-->>User: o, final_state
Loading

Estimated Code Review Effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly Related PRs

Suggested labels

cpu-ref

Suggested reviewers

  • labyrinth-ssr
  • pathfinder-pf

Poem

🐰
I nibble gates and stitch the chunks,
Solve triangles where the data funks,
States hop forward, deltas play,
Outputs weave and gradients stay,
KDA hops—kernels saved the day! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 61.11% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The pull request title accurately and clearly describes the main addition: a chunked KDA Pallas kernel with both forward and backward pass implementations.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/kda-pallas-kernel

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Kimi Delta Attention (KDA) Pallas TPU kernels, implementing a four-stage pipeline: gate cumsum, intra-chunk triangular solve, inter-chunk state propagation, and output computation. It also includes comprehensive tests for these new operations and fixes a bug in the GLA chunked kernel where the incorrect exponential function was used during hidden state scaling. Additionally, the PR updates CI scripts and project metadata to utilize uv for more efficient dependency management. Feedback focuses on optimizing the Pallas kernels to avoid potential VMEM overflows and reducing overhead from inefficient array operations like stacking and splitting within TPU loops.

Comment on lines +125 to +133
g_diff = g_f32[:, None, :] - g_f32[None, :, :] # (C, C, D)
g_diff_safe = jnp.where(causal_mask_qk[:, :, None], g_diff, 0.0)
exp_g_diff = jnp.exp2(g_diff_safe) # (C, C, D)

# Aqk[i,j] = scale * sum_d q[i,d] * k[j,d] * exp(g[i,d] - g[j,d])
Aqk_raw = jnp.sum(
q.astype(jnp.float32)[:, None, :] * k.astype(jnp.float32)[None, :, :] * exp_g_diff,
axis=-1,
).astype(dtype) # (C, C)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of Aqk_raw and Akk_raw creates a large intermediate tensor exp_g_diff of shape [C, C, D]. For a typical chunk_size=128 and head_dim=256, this tensor alone consumes 128 * 128 * 256 * 4 bytes = 16MB. Since Pallas kernels on TPU have limited VMEM (often 16MB or 32MB total), having multiple such intermediates (for Aqk and Akk) will likely lead to out-of-memory errors during compilation or execution.

Consider refactoring this to compute the gated dot product in blocks over the head dimension D to keep the VMEM footprint manageable.

Comment on lines +53 to +63
rows = [x_block[r] for r in range(B)]
for j in range(B):
if j > 0:
vec = A_ii[j, :j][None, :]
mat = jnp.stack(rows[:j])
correction = jax.lax.dot_general(
vec, mat,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
).squeeze(axis=0)
rows[j] = rows[j] - correction
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a list of arrays and jnp.stack inside a loop within a Pallas kernel is inefficient and may lead to excessive VMEM allocations or compilation overhead. In Pallas/XLA, it is generally better to work with a single array and use indexing or jax.lax.dynamic_update_slice (or at[j].set if supported in the context) to update the solution in-place.

Since B=16 is small and static, the loop will unroll, but the repeated stacking of rows creates many intermediate arrays.

Comment on lines +44 to +82
blocks = jnp.split(b, num_blocks, axis=0)

for i in range(num_blocks):
start = i * B
end = (i + 1) * B

A_ii = A[start:end, start:end]
x_block = blocks[i]

rows = [x_block[r] for r in range(B)]
for j in range(B):
if j > 0:
vec = A_ii[j, :j][None, :]
mat = jnp.stack(rows[:j])
correction = jax.lax.dot_general(
vec, mat,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
).squeeze(axis=0)
rows[j] = rows[j] - correction

x_block = jnp.stack(rows)
blocks[i] = x_block

if i < num_blocks - 1:
rest_start = (i + 1) * B

x_rest = jnp.concatenate(blocks[i+1:], axis=0)
A_rest = A[rest_start:, start:end]

update = jax.lax.dot_general(
A_rest, x_block,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
)
x_rest = x_rest - update

remaining_blocks_count = num_blocks - 1 - i
new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The solve_unit_lower_triangular function uses jnp.split and jnp.concatenate on VMEM arrays. These operations are often expensive on TPU as they can trigger data movement or reshapes that are not well-optimized by the Pallas compiler compared to direct indexing.

Consider refactoring the block-based substitution to use direct indexing into the b and A arrays instead of splitting them into lists of blocks.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (9)
scripts/launch-tpuv7.yml (1)

61-62: Minor: Inconsistent comment language.

Lines 61-62 have Chinese comments while the rest of the file uses English. Consider translating for consistency:

-        secretName: ssh-key-secret # 必须与你 kubectl create secret 时起的名字一致
-        defaultMode: 0600 # 尝试直接设定权限 (虽然 cp 方式更保险)
+        secretName: ssh-key-secret # Must match the name used in kubectl create secret
+        defaultMode: 0600 # Set permissions directly (though cp method is safer)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/launch-tpuv7.yml` around lines 61 - 62, The inline comments next to
the secret configuration keys (secretName and defaultMode) are in Chinese while
the rest of the file uses English; update the comments adjacent to secretName
and defaultMode to English for consistency (e.g., "must match the name used when
creating the kubectl secret" and "attempt to set permissions directly (cp
approach is safer)") so the file's comment language is uniform.
tops/ops/kda/fused_recurrent.py (1)

104-105: Note: final_state dtype differs from input dtype.

The output o is cast back to q.dtype (line 102), but final_state remains in the accumulator dtype (float32 or float64). This is likely intentional for precision in state continuation, but worth documenting if callers expect matching dtypes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/fused_recurrent.py` around lines 104 - 105, The returned
final_state currently remains in the accumulator dtype (float32/float64) while
the output o is cast to q.dtype; update the function to either (A) cast final_S
to q.dtype before assigning final_state (e.g. final_state =
final_S.astype(q.dtype) when output_final_state is True) to ensure matching
dtypes, or (B) explicitly document in the function docstring/inline comment
(referencing final_state, final_S, o, and q.dtype) that final_state
intentionally preserves the accumulator precision so callers know it may differ
from q.dtype.
tops/ops/kda/chunk_intra.py (1)

253-270: Rename ambiguous variable l to c or chunk_idx.

The variable l in lambdas is flagged by static analysis as ambiguous (can be confused with digit 1). Consider renaming for clarity.

🔧 Suggested fix (showing pattern for all occurrences)
         in_specs=[
-            pl.BlockSpec(index_map=lambda i, j, l: (i, j, l, 0, 0), block_shape=(1, 1, 1, chunk_size, K)), # q
+            pl.BlockSpec(index_map=lambda i, j, c: (i, j, c, 0, 0), block_shape=(1, 1, 1, chunk_size, K)), # q
             # ... apply same rename to all other BlockSpec lambdas
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk_intra.py` around lines 253 - 270, The index_map lambdas in
the pl.BlockSpec entries use the ambiguous variable name "l"; rename that
parameter to a clearer name like "c" or "chunk_idx" across all in_specs and
out_specs to avoid confusion (e.g., change index_map=lambda i, j, l: ... to
index_map=lambda i, j, c: ...), and update any corresponding tuple construction
inside those lambdas (references in pl.BlockSpec usages for q, k, g, beta, v,
segment_ids, u, w, qg, kg, Aqk, Akk_inv) so behavior remains identical.
tops/ops/kda/chunk.py (3)

216-236: Prefix unused unpacked variables with underscore.

d_final_state, chunk_size, and cu_seqlens are unpacked but unused in the backward pass. This is flagged by static analysis.

🔧 Suggested fix
 def _chunk_kda_bwd_custom(saved, grad_outputs):
-    do, d_final_state = grad_outputs
-    q, k, v, g, beta, scale, initial_state, chunk_size, cu_seqlens = saved
+    do, _d_final_state = grad_outputs
+    q, k, v, g, beta, scale, initial_state, _chunk_size, _cu_seqlens = saved
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk.py` around lines 216 - 236, In _chunk_kda_bwd_custom, some
items from the saved tuple are unpacked but not used (d_final_state, chunk_size,
cu_seqlens); update the unpacking to prefix these with underscores (e.g.,
_d_final_state, _chunk_size, _cu_seqlens) so they are clearly marked as unused
and satisfy static analysis, leaving the rest of the logic (use of q, k, v, g,
beta, initial_state and the jax.grad call) unchanged.

111-125: Prefix unused variables with underscore.

qg_bht and Akk_inv are unpacked but never used. Prefix them with _ to indicate they are intentionally unused.

🔧 Suggested fix
-    u_bht, w_bht, qg_bht, kg_bht, Aqk, Akk_inv = kda_intra_chunk_fwd(
+    u_bht, w_bht, _qg_bht, kg_bht, Aqk, _Akk_inv = kda_intra_chunk_fwd(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk.py` around lines 111 - 125, Unpack kda_intra_chunk_fwd
results but mark the unused outputs as intentionally unused by prefixing qg_bht
and Akk_inv with underscores (e.g., change qg_bht -> _qg_bht and Akk_inv ->
_Akk_inv) in the assignment where kda_intra_chunk_fwd(...) is called; keep the
other variable names (u_bht, w_bht, kg_bht, Aqk) unchanged so behavior is
identical but linter/reader signals the two values are unused (also update the
inline comment if present to reflect the renamed unused symbols).

22-22: Remove unused import pad_to_multiple.

This import is not used anywhere in the file.

🔧 Suggested fix
-from tops.utils import assert_shape, assert_shape_or_none, pad_to_multiple
+from tops.utils import assert_shape, assert_shape_or_none
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk.py` at line 22, The import list in tops/ops/kda/chunk.py
includes an unused symbol pad_to_multiple; remove pad_to_multiple from the from
tops.utils import ... statement so only the used helpers (assert_shape,
assert_shape_or_none) are imported, ensuring no other references to
pad_to_multiple remain in the file.
tops/ops/common/chunk_delta_h.py (1)

217-217: Use assert IS_VARLEN instead of assert IS_VARLEN == True.

Per static analysis and Python idiom, use truthiness check directly.

🔧 Suggested fix
-  assert IS_VARLEN == True
+  assert IS_VARLEN
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_delta_h.py` at line 217, The assertion currently uses
an explicit comparison "assert IS_VARLEN == True"; change it to a simple
truthiness check by replacing that expression with "assert IS_VARLEN" to follow
Python idioms and satisfy static analysis; locate the assertion referencing the
IS_VARLEN symbol in chunk_delta_h.py (the assertion line that currently compares
IS_VARLEN to True) and update it accordingly.
tests/ops/kda/test_chunk_kda_tpu.py (2)

104-109: Add strict=True to zip() for defensive iteration.

Per static analysis hint and best practice, adding strict=True ensures the loop fails fast if the iterables have mismatched lengths.

🔧 Suggested fix
     names = ["dq", "dk", "dv", "dg", "dbeta"]
-    for name, gn, gc in zip(names, grads_naive, grads_chunk):
+    for name, gn, gc in zip(names, grads_naive, grads_chunk, strict=True):
         assert compare_tensor(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_chunk_kda_tpu.py` around lines 104 - 109, The loop that
zips names, grads_naive, and grads_chunk should use zip(..., strict=True) to
fail fast on mismatched lengths; update the for loop that iterates over names,
grads_naive, grads_chunk (the block invoking compare_tensor for each name, gn,
gc) to call zip(names, grads_naive, grads_chunk, strict=True) so any length
mismatch raises immediately and prevents silent mis-alignment.

75-83: Tolerances are quite loose at 5e-2 — verify this is expected.

The forward test uses atol=5e-2, rtol=5e-2 which is relatively permissive compared to the fused recurrent tests (1e-10 for fp64, 5e-5 for fp32). If this is due to the chunked approximation introducing more numerical divergence, consider adding a brief comment explaining why these tolerances are necessary.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_chunk_kda_tpu.py` around lines 75 - 83, The test uses
loose tolerances (atol=5e-2, rtol=5e-2) in the compare_tensor checks for
"output" and "final_state" (variables o_naive, o_chunk, s_naive, s_chunk) —
either tighten these to match the fused recurrent test thresholds (e.g., 1e-10
for fp64 or 5e-5 for fp32) if chunking should be numerically equivalent, or add
a short inline comment above the assertions explaining that the chunked KDA/TPU
implementation intentionally incurs greater numerical divergence and therefore
requires the relaxed tolerances; update the comment to reference why fp dtype
and chunking cause the chosen atol/rtol and keep the compare_tensor calls
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/kda/chunk_intra.py`:
- Around line 28-88: The function solve_unit_lower_triangular assumes B=16
divides N by using num_blocks = N // B and jnp.split, which will fail when chunk
size isn't divisible by 16; either assert divisibility or handle remainders. Fix
by checking N % B == 0 at the start (raise/assert with clear message) or change
splitting logic to compute num_blocks = math.ceil(N / B) and replace jnp.split
with jnp.array_split (or pad the last block) and ensure the block-processing
loops (A_ii, rows logic, and update of x_rest) correctly handle a final block
size < B; update references to B, num_blocks, blocks, A_ii, x_block, and x_rest
accordingly.
- Around line 313-319: Fix the inconsistent mask: in tops/ops/kda/chunk_intra.py
update the definition of causal_mask used for the backward kernel (the variable
named causal_mask that feeds mask_akk) to use a strict greater-than comparison
(idx[:, None] > idx[None, :]) instead of >= so it matches the forward kernel
behavior referenced at line 111; leave causal_mask_qk (used for mask_aqk) as >=
so the two masks (causal_mask -> mask_akk and causal_mask_qk -> mask_aqk)
reflect strictly lower-triangular vs lower-triangular-with-diagonal semantics
consistently.

---

Nitpick comments:
In `@scripts/launch-tpuv7.yml`:
- Around line 61-62: The inline comments next to the secret configuration keys
(secretName and defaultMode) are in Chinese while the rest of the file uses
English; update the comments adjacent to secretName and defaultMode to English
for consistency (e.g., "must match the name used when creating the kubectl
secret" and "attempt to set permissions directly (cp approach is safer)") so the
file's comment language is uniform.

In `@tests/ops/kda/test_chunk_kda_tpu.py`:
- Around line 104-109: The loop that zips names, grads_naive, and grads_chunk
should use zip(..., strict=True) to fail fast on mismatched lengths; update the
for loop that iterates over names, grads_naive, grads_chunk (the block invoking
compare_tensor for each name, gn, gc) to call zip(names, grads_naive,
grads_chunk, strict=True) so any length mismatch raises immediately and prevents
silent mis-alignment.
- Around line 75-83: The test uses loose tolerances (atol=5e-2, rtol=5e-2) in
the compare_tensor checks for "output" and "final_state" (variables o_naive,
o_chunk, s_naive, s_chunk) — either tighten these to match the fused recurrent
test thresholds (e.g., 1e-10 for fp64 or 5e-5 for fp32) if chunking should be
numerically equivalent, or add a short inline comment above the assertions
explaining that the chunked KDA/TPU implementation intentionally incurs greater
numerical divergence and therefore requires the relaxed tolerances; update the
comment to reference why fp dtype and chunking cause the chosen atol/rtol and
keep the compare_tensor calls unchanged.

In `@tops/ops/common/chunk_delta_h.py`:
- Line 217: The assertion currently uses an explicit comparison "assert
IS_VARLEN == True"; change it to a simple truthiness check by replacing that
expression with "assert IS_VARLEN" to follow Python idioms and satisfy static
analysis; locate the assertion referencing the IS_VARLEN symbol in
chunk_delta_h.py (the assertion line that currently compares IS_VARLEN to True)
and update it accordingly.

In `@tops/ops/kda/chunk_intra.py`:
- Around line 253-270: The index_map lambdas in the pl.BlockSpec entries use the
ambiguous variable name "l"; rename that parameter to a clearer name like "c" or
"chunk_idx" across all in_specs and out_specs to avoid confusion (e.g., change
index_map=lambda i, j, l: ... to index_map=lambda i, j, c: ...), and update any
corresponding tuple construction inside those lambdas (references in
pl.BlockSpec usages for q, k, g, beta, v, segment_ids, u, w, qg, kg, Aqk,
Akk_inv) so behavior remains identical.

In `@tops/ops/kda/chunk.py`:
- Around line 216-236: In _chunk_kda_bwd_custom, some items from the saved tuple
are unpacked but not used (d_final_state, chunk_size, cu_seqlens); update the
unpacking to prefix these with underscores (e.g., _d_final_state, _chunk_size,
_cu_seqlens) so they are clearly marked as unused and satisfy static analysis,
leaving the rest of the logic (use of q, k, v, g, beta, initial_state and the
jax.grad call) unchanged.
- Around line 111-125: Unpack kda_intra_chunk_fwd results but mark the unused
outputs as intentionally unused by prefixing qg_bht and Akk_inv with underscores
(e.g., change qg_bht -> _qg_bht and Akk_inv -> _Akk_inv) in the assignment where
kda_intra_chunk_fwd(...) is called; keep the other variable names (u_bht, w_bht,
kg_bht, Aqk) unchanged so behavior is identical but linter/reader signals the
two values are unused (also update the inline comment if present to reflect the
renamed unused symbols).
- Line 22: The import list in tops/ops/kda/chunk.py includes an unused symbol
pad_to_multiple; remove pad_to_multiple from the from tops.utils import ...
statement so only the used helpers (assert_shape, assert_shape_or_none) are
imported, ensuring no other references to pad_to_multiple remain in the file.

In `@tops/ops/kda/fused_recurrent.py`:
- Around line 104-105: The returned final_state currently remains in the
accumulator dtype (float32/float64) while the output o is cast to q.dtype;
update the function to either (A) cast final_S to q.dtype before assigning
final_state (e.g. final_state = final_S.astype(q.dtype) when output_final_state
is True) to ensure matching dtypes, or (B) explicitly document in the function
docstring/inline comment (referencing final_state, final_S, o, and q.dtype) that
final_state intentionally preserves the accumulator precision so callers know it
may differ from q.dtype.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 37ed064d-ae9b-445c-ba5a-0fbe7c90a7d9

📥 Commits

Reviewing files that changed from the base of the PR and between dfe6ffa and 523a595.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (16)
  • .github/ci/gpu-tests.sky.yaml
  • CLAUDE.md
  • pyproject.toml
  • scripts/launch-tpuv7.yml
  • tests/ops/kda/__init__.py
  • tests/ops/kda/test_chunk_kda_tpu.py
  • tests/ops/kda/test_fused_recurrent_kda_tpu.py
  • tops/ops/__init__.py
  • tops/ops/common/chunk_delta_h.py
  • tops/ops/gla/chunk.py
  • tops/ops/kda/__init__.py
  • tops/ops/kda/chunk.py
  • tops/ops/kda/chunk_intra.py
  • tops/ops/kda/fused_recurrent.py
  • tops/ops/kda/gate.py
  • tops/ops/kda/naive.py

Comment on lines +28 to +88
def solve_unit_lower_triangular(A, b):
"""
Solves (I + A) x = b for x, where A is strictly lower triangular.
Uses block-based forward substitution for better performance on TPU.
Args:
A: (N, N) strictly lower triangular matrix in VMEM.
b: (N, D) matrix in VMEM.
Returns:
x: (N, D) solution matrix.
"""
N, D = b.shape
B = 16
num_blocks = N // B
A = A.astype(jnp.float32)
b = b.astype(jnp.float32)

blocks = jnp.split(b, num_blocks, axis=0)

for i in range(num_blocks):
start = i * B
end = (i + 1) * B

A_ii = A[start:end, start:end]
x_block = blocks[i]

rows = [x_block[r] for r in range(B)]
for j in range(B):
if j > 0:
vec = A_ii[j, :j][None, :]
mat = jnp.stack(rows[:j])
correction = jax.lax.dot_general(
vec, mat,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
).squeeze(axis=0)
rows[j] = rows[j] - correction

x_block = jnp.stack(rows)
blocks[i] = x_block

if i < num_blocks - 1:
rest_start = (i + 1) * B

x_rest = jnp.concatenate(blocks[i+1:], axis=0)
A_rest = A[rest_start:, start:end]

update = jax.lax.dot_general(
A_rest, x_block,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
)
x_rest = x_rest - update

remaining_blocks_count = num_blocks - 1 - i
new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)

for k, nb in enumerate(new_blocks):
blocks[i + 1 + k] = nb

x = jnp.concatenate(blocks, axis=0)
return x
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Block size B=16 assumes chunk_size is divisible by 16.

The solver uses a hardcoded block size of 16 and num_blocks = N // B. If chunk_size is not divisible by 16, jnp.split will fail. Consider adding an assertion or handling non-divisible cases.

🛡️ Suggested defensive assertion
 def solve_unit_lower_triangular(A, b):
     """..."""
     N, D = b.shape
     B = 16
+    assert N % B == 0, f"Matrix size N={N} must be divisible by block size B={B}"
     num_blocks = N // B
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def solve_unit_lower_triangular(A, b):
"""
Solves (I + A) x = b for x, where A is strictly lower triangular.
Uses block-based forward substitution for better performance on TPU.
Args:
A: (N, N) strictly lower triangular matrix in VMEM.
b: (N, D) matrix in VMEM.
Returns:
x: (N, D) solution matrix.
"""
N, D = b.shape
B = 16
num_blocks = N // B
A = A.astype(jnp.float32)
b = b.astype(jnp.float32)
blocks = jnp.split(b, num_blocks, axis=0)
for i in range(num_blocks):
start = i * B
end = (i + 1) * B
A_ii = A[start:end, start:end]
x_block = blocks[i]
rows = [x_block[r] for r in range(B)]
for j in range(B):
if j > 0:
vec = A_ii[j, :j][None, :]
mat = jnp.stack(rows[:j])
correction = jax.lax.dot_general(
vec, mat,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
).squeeze(axis=0)
rows[j] = rows[j] - correction
x_block = jnp.stack(rows)
blocks[i] = x_block
if i < num_blocks - 1:
rest_start = (i + 1) * B
x_rest = jnp.concatenate(blocks[i+1:], axis=0)
A_rest = A[rest_start:, start:end]
update = jax.lax.dot_general(
A_rest, x_block,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
)
x_rest = x_rest - update
remaining_blocks_count = num_blocks - 1 - i
new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)
for k, nb in enumerate(new_blocks):
blocks[i + 1 + k] = nb
x = jnp.concatenate(blocks, axis=0)
return x
def solve_unit_lower_triangular(A, b):
"""
Solves (I + A) x = b for x, where A is strictly lower triangular.
Uses block-based forward substitution for better performance on TPU.
Args:
A: (N, N) strictly lower triangular matrix in VMEM.
b: (N, D) matrix in VMEM.
Returns:
x: (N, D) solution matrix.
"""
N, D = b.shape
B = 16
assert N % B == 0, f"Matrix size N={N} must be divisible by block size B={B}"
num_blocks = N // B
A = A.astype(jnp.float32)
b = b.astype(jnp.float32)
blocks = jnp.split(b, num_blocks, axis=0)
for i in range(num_blocks):
start = i * B
end = (i + 1) * B
A_ii = A[start:end, start:end]
x_block = blocks[i]
rows = [x_block[r] for r in range(B)]
for j in range(B):
if j > 0:
vec = A_ii[j, :j][None, :]
mat = jnp.stack(rows[:j])
correction = jax.lax.dot_general(
vec, mat,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
).squeeze(axis=0)
rows[j] = rows[j] - correction
x_block = jnp.stack(rows)
blocks[i] = x_block
if i < num_blocks - 1:
rest_start = (i + 1) * B
x_rest = jnp.concatenate(blocks[i+1:], axis=0)
A_rest = A[rest_start:, start:end]
update = jax.lax.dot_general(
A_rest, x_block,
(((1,), (0,)), ((), ())),
preferred_element_type=jnp.float32
)
x_rest = x_rest - update
remaining_blocks_count = num_blocks - 1 - i
new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)
for k, nb in enumerate(new_blocks):
blocks[i + 1 + k] = nb
x = jnp.concatenate(blocks, axis=0)
return x
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 38-38: Unpacked variable D is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk_intra.py` around lines 28 - 88, The function
solve_unit_lower_triangular assumes B=16 divides N by using num_blocks = N // B
and jnp.split, which will fail when chunk size isn't divisible by 16; either
assert divisibility or handle remainders. Fix by checking N % B == 0 at the
start (raise/assert with clear message) or change splitting logic to compute
num_blocks = math.ceil(N / B) and replace jnp.split with jnp.array_split (or pad
the last block) and ensure the block-processing loops (A_ii, rows logic, and
update of x_rest) correctly handle a final block size < B; update references to
B, num_blocks, blocks, A_ii, x_block, and x_rest accordingly.

Comment on lines +313 to +319
idx = jnp.arange(chunk_size, dtype=jnp.int32)
causal_mask = idx[:, None] >= idx[None, :]
causal_mask_qk = idx[:, None] >= idx[None, :]
segment_mask = segment_ids[:, None] == segment_ids[None, :]

mask_akk = causal_mask & segment_mask
mask_aqk = causal_mask_qk & segment_mask
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if the forward kernel has different mask definitions
rg -n "causal_mask.*>" tops/ops/kda/chunk_intra.py | head -20

Repository: primatrix/pallas-kernel

Length of output: 344


Fix inconsistent mask definitions: causal_mask should use > in backward to match forward kernel.

Line 314 defines causal_mask with >= (includes diagonal), but the forward kernel (line 111) uses > (strictly lower triangular). Line 315's causal_mask_qk correctly uses >= to match the forward kernel. Align line 314 to use > for consistency:

causal_mask = idx[:, None] > idx[None, :]        # strictly lower triangular
causal_mask_qk = idx[:, None] >= idx[None, :]   # lower triangular with diagonal
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk_intra.py` around lines 313 - 319, Fix the inconsistent
mask: in tops/ops/kda/chunk_intra.py update the definition of causal_mask used
for the backward kernel (the variable named causal_mask that feeds mask_akk) to
use a strict greater-than comparison (idx[:, None] > idx[None, :]) instead of >=
so it matches the forward kernel behavior referenced at line 111; leave
causal_mask_qk (used for mask_aqk) as >= so the two masks (causal_mask ->
mask_akk and causal_mask_qk -> mask_aqk) reflect strictly lower-triangular vs
lower-triangular-with-diagonal semantics consistently.

0xaskr and others added 3 commits April 7, 2026 23:08
Add detailed design documentation for KDA chunk-parallel forward and
backward kernels, with mathematical derivations and FLA optimized
implementation mapping.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
docs/design-docs/ops/kda/chunk-fwd.md (1)

11-11: Fix heading level jump at Line 11.

### appears without a preceding ##, which triggers MD001 and breaks heading structure consistency.

Proposed markdown fix
-### 核心思想与背景
+## 核心思想与背景
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-fwd.md` at line 11, The heading "核心思想与背景" uses
"###" causing a level jump (MD001); change its markdown to a level-2 heading
("## 核心思想与背景") or add an appropriate level-2 parent heading above it so the
document follows a consistent heading hierarchy; locate the heading text
"核心思想与背景" and adjust its leading hashes (or insert the missing parent heading)
to restore correct heading levels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Around line 9-11: The heading at "### 核心思想与背景" jumps a level (MD001); fix by
promoting it to a top-level section or adding its parent: either change "###
核心思想与背景" to "## 核心思想与背景" or insert a preceding "##" parent heading (e.g., "##
背景与概述") immediately before the existing "### 核心思想与背景" so the heading hierarchy
is valid.

In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Around line 324-327: The pseudocode incorrectly re-applies gk gating to the
value path: remove the b_v *= exp2(b_gk_last - b_gk) line so v_new is not gated
again (k-path already handles gating), and keep b_h *= exp2(b_gk_last) for state
decay only; reference the b_gk_last, b_gk, b_v, b_h and v_new symbols and align
the doc to the behavior in tops/ops/common/chunk_delta_h.py where gk only
affects state decay.

---

Nitpick comments:
In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Line 11: The heading "核心思想与背景" uses "###" causing a level jump (MD001); change
its markdown to a level-2 heading ("## 核心思想与背景") or add an appropriate level-2
parent heading above it so the document follows a consistent heading hierarchy;
locate the heading text "核心思想与背景" and adjust its leading hashes (or insert the
missing parent heading) to restore correct heading levels.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1f8917b0-2fbf-423d-8330-c64e33ec51b2

📥 Commits

Reviewing files that changed from the base of the PR and between 523a595 and 337a2d5.

📒 Files selected for processing (2)
  • docs/design-docs/ops/kda/chunk-bwd.md
  • docs/design-docs/ops/kda/chunk-fwd.md

Comment on lines +9 to +11
### 核心思想与背景

反向传播的目标:给定损失函数对输出的梯度 $`\frac{\partial \mathcal{L}}{\partial \mathbf{o}}`$(记为 $`d\mathbf{o}`$),计算损失函数对所有输入 $`\mathbf{q}, \mathbf{k}, \mathbf{v}, \mathbf{g}, \boldsymbol{\beta}`$ 的梯度,以及对初始状态 $`\mathbf{h}_0`$ 的梯度。
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix heading level jump to satisfy MD001

Line 9 starts with ### without a preceding ## section level, which breaks heading hierarchy. Please introduce an ## parent heading (or downgrade this heading level) before Line 9.

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3

(MD001, heading-increment)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 9 - 11, The heading at
"### 核心思想与背景" jumps a level (MD001); fix by promoting it to a top-level section
or adding its parent: either change "### 核心思想与背景" to "## 核心思想与背景" or insert a
preceding "##" parent heading (e.g., "## 背景与概述") immediately before the existing
"### 核心思想与背景" so the heading hierarchy is valid.

Comment on lines +324 to +327
# 3. 门控衰减
b_gk_last = gk[chunk 末尾位置]
b_v *= exp2(b_gk_last - b_gk) # 对 v_new 应用相对衰减
b_h *= exp2(b_gk_last) # 对状态应用 chunk 末尾衰减
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Correct gk-path pseudocode: v_new should not be gated again.

At Line 324-Line 327, the pseudocode applies exp2(b_gk_last - b_gk) to b_v in the gk flow. That conflicts with the documented/reference behavior in tops/ops/common/chunk_delta_h.py:44-73, where gk is used for state decay and k is already gated.

Proposed doc correction
-    # 3. 门控衰减
-    b_gk_last = gk[chunk 末尾位置]
-    b_v *= exp2(b_gk_last - b_gk)                    # 对 v_new 应用相对衰减
-    b_h *= exp2(b_gk_last)                           # 对状态应用 chunk 末尾衰减
+    # 3. 门控衰减
+    b_gk_last = gk[chunk 末尾位置]
+    # gk 路径下不对 v_new 额外门控(k/kg 已完成位置门控)
+    b_h *= exp2(b_gk_last)                           # 仅对状态应用 chunk 末尾衰减
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# 3. 门控衰减
b_gk_last = gk[chunk 末尾位置]
b_v *= exp2(b_gk_last - b_gk) # 对 v_new 应用相对衰减
b_h *= exp2(b_gk_last) # 对状态应用 chunk 末尾衰减
# 3. 门控衰减
b_gk_last = gk[chunk 末尾位置]
# gk 路径下不对 v_new 额外门控(k/kg 已完成位置门控)
b_h *= exp2(b_gk_last) # 仅对状态应用 chunk 末尾衰减
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-fwd.md` around lines 324 - 327, The pseudocode
incorrectly re-applies gk gating to the value path: remove the b_v *=
exp2(b_gk_last - b_gk) line so v_new is not gated again (k-path already handles
gating), and keep b_h *= exp2(b_gk_last) for state decay only; reference the
b_gk_last, b_gk, b_v, b_h and v_new symbols and align the doc to the behavior in
tops/ops/common/chunk_delta_h.py where gk only affects state decay.

@beaver-infiscale
Copy link
Copy Markdown

Warning

规模超限

此 PR 核心代码变更行数为 2458,超出上限 200(不含测试与文档)。

建议操作: 将此 PR 拆分为多个小 PR,每个 PR 的核心代码变更控制在 200 行以内。

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tops/cpu/ops/kda/chunk.py (1)

549-564: ⚠️ Potential issue | 🟡 Minor

Add dtype assertions to the public API.

These checks only validate shapes right now. chunk_kda still accepts mixed dtypes across q/k/v/g/beta/initial_state, which can silently change promotion and make the CPU reference diverge from the kernel path.

🛡️ Suggested assertions
   assert beta.ndim == 3 and beta.shape == q.shape[:3], (
     f"beta shape {beta.shape} != {q.shape[:3]}"
   )
+  assert k.dtype == q.dtype, f"k dtype {k.dtype} != q dtype {q.dtype}"
+  assert v.dtype == q.dtype, f"v dtype {v.dtype} != q dtype {q.dtype}"
+  assert g.dtype == q.dtype, f"g dtype {g.dtype} != q dtype {q.dtype}"
+  assert beta.dtype == q.dtype, f"beta dtype {beta.dtype} != q dtype {q.dtype}"
   if initial_state is not None:
     assert initial_state.shape == (B, H, K, V), (
       f"initial_state shape {initial_state.shape} != ({B}, {H}, {K}, {V})"
     )
+    assert initial_state.dtype in (q.dtype, acc_dt), (
+      f"initial_state dtype {initial_state.dtype} must match q dtype {q.dtype} "
+      f"or accumulator dtype {acc_dt}"
+    )

As per coding guidelines, "All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 549 - 564, Add strict dtype
assertions in chunk_kda: assert q,k,v,g,beta all have the same dtype and that
dtype is a supported floating type (e.g., torch.float16/float32/bfloat16), and
if initial_state is not None assert initial_state.dtype == v.dtype and
initial_state.dtype is the same supported float type; use the existing symbols
q, k, v, g, beta, initial_state and add concise assert statements near the
existing shape checks so mixed dtypes are rejected before main logic runs.
♻️ Duplicate comments (2)
docs/design-docs/ops/kda/chunk-fwd.md (1)

373-376: ⚠️ Potential issue | 🟠 Major

The gk pseudocode still gates v_new twice.

tops/ops/common/chunk_delta_h.py:44-80 computes v_new = u - w @ h and then updates the state with gated keys; it does not apply an extra exp2(b_gk_last - b_gk) factor to v_new. Keeping Line 375 here documents a different recurrence than the implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-fwd.md` around lines 373 - 376, The doc's
pseudocode incorrectly applies an extra exp2(b_gk_last - b_gk) gating to v_new,
which mismatches the implementation in tops/ops/common/chunk_delta_h.py where
v_new is computed as v_new = u - w @ h and only the state update uses gated
keys; remove the line that multiplies b_v by exp2(b_gk_last - b_gk) (or clarify
that only b_h is decayed by exp2(b_gk_last)) so the pseudocode matches the
implementation (referencing symbols gk, b_v, b_h, v_new, u, w, h and the
function/logic in chunk_delta_h.py).
docs/design-docs/ops/kda/chunk-bwd.md (1)

9-9: ⚠️ Potential issue | 🟡 Minor

Fix the heading level jump.

This still starts with ### before any ## section, so MD001 remains unresolved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-bwd.md` at line 9, The document has a
level-jump: the heading "核心思想与背景" uses ### but no preceding ## exists, causing
MD001; fix it by promoting that header to a top-level section (change ###
核心思想与背景 to ## 核心思想与背景) or add a preceding ## section before it so heading levels
increase incrementally; update the heading token for "核心思想与背景" accordingly to
resolve the MD001 rule.
🧹 Nitpick comments (1)
tops/cpu/ops/kda/chunk.py (1)

593-600: Avoid computing backward intermediates on the forward-only path.

chunk_kda now always materializes Aqk and Akk, then drops both values. That adds O(B·H·NT·C²) work and memory to every CPU reference forward even when callers only need o and final_state. A return_intermediates=False fast path would keep the reference cheaper and also clear the unused-variable warning.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 593 - 600, The forward path in
chunk_kda is always materializing Aqk and Akk via chunk_kda_fwd then discarding
them, which wastes O(B·H·NT·C²) compute/memory; add a fast-path flag (e.g.,
return_intermediates=False) to chunk_kda and only request/compute Aqk and Akk
from chunk_kda_fwd when that flag is true (or when callers explicitly need
intermediates), otherwise call chunk_kda_fwd with output_intermediates=False (or
an equivalent param) and only capture o and final_state; update the chunk_kda
signature and callers to default to not materializing intermediates and remove
the allocations/assignments of Aqk/Akk on the forward-only path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Around line 671-676: The tables incorrectly list dAqk and dAkk as [B, T, H, C]
but they are chunk-local CxC matrices; update their shapes in the Stage 1 (and
also the entries at 686-695) to [B, H, NT, C, C] to match the symbol definitions
and referenced kernels (symbols: dAqk, dAkk); also verify dv remains [B, T, H,
V] and adjust any surrounding textual descriptions that rely on the old shape
notation.
- Around line 15-20: The fenced code blocks containing the pseudocode (the block
with "Step 0: g = cumsum(g_raw) ... Step 3: o = chunk_gla_fwd_o_gk(...)") and
the other similar blocks at the ranges mentioned should include an info string
(e.g., ```text or ```text+pseudo) to satisfy MD040; update each triple-backtick
opening fence for those blocks (including the blocks around lines referenced as
48-55 and 638-652) to include a language token so the linter stops flagging
them.

In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Line 11: The heading "### 核心思想与背景" uses h3 without a parent h2 which fails
MD001; update this heading to "## 核心思想与背景" (or insert a higher-level parent
section above it) so the document has a proper hierarchy—locate the line with
the exact text "### 核心思想与背景" and promote it to an h2 (change the leading ### to
##) or add an appropriate parent h2 before it.

---

Outside diff comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 549-564: Add strict dtype assertions in chunk_kda: assert
q,k,v,g,beta all have the same dtype and that dtype is a supported floating type
(e.g., torch.float16/float32/bfloat16), and if initial_state is not None assert
initial_state.dtype == v.dtype and initial_state.dtype is the same supported
float type; use the existing symbols q, k, v, g, beta, initial_state and add
concise assert statements near the existing shape checks so mixed dtypes are
rejected before main logic runs.

---

Duplicate comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Line 9: The document has a level-jump: the heading "核心思想与背景" uses ### but no
preceding ## exists, causing MD001; fix it by promoting that header to a
top-level section (change ### 核心思想与背景 to ## 核心思想与背景) or add a preceding ##
section before it so heading levels increase incrementally; update the heading
token for "核心思想与背景" accordingly to resolve the MD001 rule.

In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Around line 373-376: The doc's pseudocode incorrectly applies an extra
exp2(b_gk_last - b_gk) gating to v_new, which mismatches the implementation in
tops/ops/common/chunk_delta_h.py where v_new is computed as v_new = u - w @ h
and only the state update uses gated keys; remove the line that multiplies b_v
by exp2(b_gk_last - b_gk) (or clarify that only b_h is decayed by
exp2(b_gk_last)) so the pseudocode matches the implementation (referencing
symbols gk, b_v, b_h, v_new, u, w, h and the function/logic in
chunk_delta_h.py).

---

Nitpick comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 593-600: The forward path in chunk_kda is always materializing Aqk
and Akk via chunk_kda_fwd then discarding them, which wastes O(B·H·NT·C²)
compute/memory; add a fast-path flag (e.g., return_intermediates=False) to
chunk_kda and only request/compute Aqk and Akk from chunk_kda_fwd when that flag
is true (or when callers explicitly need intermediates), otherwise call
chunk_kda_fwd with output_intermediates=False (or an equivalent param) and only
capture o and final_state; update the chunk_kda signature and callers to default
to not materializing intermediates and remove the allocations/assignments of
Aqk/Akk on the forward-only path.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ebe2df40-8dad-48d8-ac2c-0595d4c26435

📥 Commits

Reviewing files that changed from the base of the PR and between 337a2d5 and 5b7ffcd.

📒 Files selected for processing (4)
  • docs/design-docs/ops/kda/chunk-bwd.md
  • docs/design-docs/ops/kda/chunk-fwd.md
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
✅ Files skipped from review due to trivial changes (1)
  • tops/cpu/ops/kda/init.py

Comment on lines +15 to +20
```
Step 0: g = cumsum(g_raw) [chunk_local_cumsum]
Step 1: w, u, kg, Aqk, Akk = chunk_kda_fwd_intra(q, k, v, g, β) [chunk_kda_fwd_intra]
Step 2: h, v_new = chunk_gated_delta_rule_fwd_h(kg, w, u, gk=g) [chunk_gated_delta_rule_fwd_h]
Step 3: o = chunk_gla_fwd_o_gk(q, v_new, g, Aqk, h) [chunk_gla_fwd_o_gk]
```
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add languages to the fenced code blocks.

These fences are missing an info string, so MD040 will keep firing. Even text is enough where the block is pseudocode.

Also applies to: 48-55, 638-652

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 15-15: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 15 - 20, The fenced code
blocks containing the pseudocode (the block with "Step 0: g = cumsum(g_raw) ...
Step 3: o = chunk_gla_fwd_o_gk(...)") and the other similar blocks at the ranges
mentioned should include an info string (e.g., ```text or ```text+pseudo) to
satisfy MD040; update each triple-backtick opening fence for those blocks
(including the blocks around lines referenced as 48-55 and 638-652) to include a
language token so the linter stops flagging them.

Comment on lines +671 to +676
阶段 1 输出:

| 张量 | Shape | 说明 |
|---|---|---|
| `dAqk` | `[B, T, H, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$(下三角) |
| `dv` | `[B, T, H, V]` | $`\mathbf{A}_{qk}^\top \cdot d\mathbf{o}`$ |
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

The summary tables list the matrix-gradient shapes incorrectly.

dAqk and dAkk are chunk-local C x C matrices, so these entries should be [B, H, NT, C, C], not [B, T, H, C]. The current table contradicts the symbol definitions above and the referenced kernels.

📝 Suggested doc fix
-| `dAqk` | `[B, T, H, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$(下三角) |
+| `dAqk` | `[B, H, NT, C, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$(下三角) |
...
-| `dAkk` | `[B, T, H, C]` | 键-键矩阵逆的梯度 |
+| `dAkk` | `[B, H, NT, C, C]` | 键-键矩阵逆的梯度 |

Also applies to: 686-695

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 671 - 676, The tables
incorrectly list dAqk and dAkk as [B, T, H, C] but they are chunk-local CxC
matrices; update their shapes in the Stage 1 (and also the entries at 686-695)
to [B, H, NT, C, C] to match the symbol definitions and referenced kernels
(symbols: dAqk, dAkk); also verify dv remains [B, T, H, V] and adjust any
surrounding textual descriptions that rely on the old shape notation.


本设计文档记录了 Kimi Delta Attention (KDA) 的 Chunk-Parallel 前向内核的设计与实现逻辑。

### 核心思想与背景
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix the heading hierarchy.

### jumps straight from the title block to an h3. Promote this to ## or insert a parent section so the document passes MD001.

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3

(MD001, heading-increment)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/design-docs/ops/kda/chunk-fwd.md` at line 11, The heading "### 核心思想与背景"
uses h3 without a parent h2 which fails MD001; update this heading to "##
核心思想与背景" (or insert a higher-level parent section above it) so the document has
a proper hierarchy—locate the line with the exact text "### 核心思想与背景" and promote
it to an h2 (change the leading ### to ##) or add an appropriate parent h2
before it.

# Conflicts:
#	tops/cpu/ops/kda/__init__.py
#	tops/cpu/ops/kda/chunk.py
@beaver-infiscale
Copy link
Copy Markdown

Warning

规模超限

此 PR 核心代码变更行数为 2497,超出上限 200(不含测试与文档)。

建议操作: 将此 PR 拆分为多个小 PR,每个 PR 的核心代码变更控制在 200 行以内。

@0xaskr 0xaskr marked this pull request as draft April 13, 2026 02:38
- Extract chunk_kda_bwd_intra as standalone Stage 4 function from chunk_kda_bwd
- Add vectorized PyTorch reference (chunk_bwd_intra_ref.py): replace triple
  loop with 4 batched matmuls using decay factorization exp(g_i-g_j)=exp(g_i)·exp(-g_j)
- Add test_chunk_kda.py: cross-validation, symmetry, and CPU-ref vs Triton tests
- Fix OOM in large-shape Triton test: use numpy with contiguous layout for
  JAX→FLA conversion, run JAX ref on CPU for large tensors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@beaver-infiscale
Copy link
Copy Markdown

Warning

规模超限

此 PR 核心代码变更行数为 2612,超出上限 200(不含测试与文档)。

建议操作: 将此 PR 拆分为多个小 PR,每个 PR 的核心代码变更控制在 200 行以内。

@beaver-infiscale
Copy link
Copy Markdown

Warning

规模超限

此 PR 核心代码变更行数为 1955,超出上限 200(不含测试与文档)。

建议操作: 将此 PR 拆分为多个小 PR,每个 PR 的核心代码变更控制在 200 行以内。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant