
feat: KDA Pallas intra-chunk kernel #188

Merged
0xaskr merged 7 commits into main from feat/chunk_kda_bwd_intra
Apr 14, 2026

Conversation

@0xaskr
Collaborator

@0xaskr 0xaskr commented Apr 13, 2026

Summary

  • Add Pallas intra-chunk fwd/bwd kernel with segment_ids and g-centered optimization (tops/ops/kda/intra_chunk.py)
  • Add chunk pipeline: 3-stage parallel algorithm (intra -> inter -> output) (tops/cpu/ops/kda/chunk_pipeline.py)
  • Add chunk backward intra reference and CPU chunk KDA implementation
  • Add tests: Pallas vs JAX reference (8 tests), chunk vs recurrent alignment (5 tests), GPU vs FLA comparison
  • Add design docs for chunk fwd/bwd

Test plan

  • uv run pytest tests/ops/kda/test_kda_alignment.py -v — 5/5 passed
  • uv run pytest tests/ops/kda/test_pallas_intra_chunk.py -v — 8/8 passed
  • uv run ruff check — all checks passed
  • GPU vs FLA tests (requires GPU + FLA environment)

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added chunked Kimi Delta Attention (KDA) implementation with CPU reference and TPU kernel support for optimized attention computation.
  • Documentation

    • Added comprehensive design documents for KDA forward and backward passes with mathematical derivations and implementation specifications.
  • Tests

    • Added test suites validating CPU reference implementation and TPU kernel against each other with comprehensive shape and edge-case coverage.
  • Chores

    • Updated gitignore configuration.

Migrate KDA (Kimi Delta Attention) code from ant-pretrain:
- Pallas intra-chunk fwd/bwd kernel with segment_ids and g-centered optimization
- Chunk pipeline: 3-stage parallel algorithm (intra -> inter -> output)
- Chunk backward intra reference implementation
- CPU reference chunk KDA implementation
- Tests: Pallas vs JAX reference, chunk vs recurrent alignment, GPU vs FLA
- Design docs for chunk fwd/bwd

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2246 lines of core code, exceeding the 200-line limit (tests and docs excluded).

Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.

@coderabbitai

coderabbitai bot commented Apr 13, 2026

📝 Walkthrough

Walkthrough

This PR introduces Chunk-Parallel forward and backward implementations for KDA (Kimi Delta Attention), adding a CPU reference implementation in JAX (chunk_kda), a TPU Pallas kernel for intra-chunk backward propagation, comprehensive mathematical design documents, and validation test suites for CPU-to-TPU correctness.

Changes

Cohort / File(s) Summary
Design Documentation
docs/design-docs/ops/kda/chunk-fwd.md, docs/design-docs/ops/kda/chunk-bwd-intra.md
New design docs formalizing the chunk-parallel forward kernel math (chunking strategy, interaction matrices, triangular solve, effective key/value computation) and intra-chunk backward kernel derivations (gradient propagation, reference-point normalization, accumulation logic).
CPU Reference Implementation
tops/cpu/ops/kda/chunk.py, tops/cpu/ops/kda/__init__.py
New JAX CPU reference for chunk-based KDA: chunk_kda public API validates inputs, pads/reshapes tensors into chunked layout, applies chunk-local cumsum to gates, invokes forward/backward passes with intra-chunk interaction matrices and inter-chunk state recurrence, and reshapes outputs back; internal chunk_kda_fwd, _chunk_kda_bwd_intra, and chunk_kda_bwd handle core forward/backward logic including matrix solve and gradient accumulation.
TPU Pallas Kernel
tops/ops/kda/intra_chunk.py, tops/ops/kda/__init__.py
New Pallas TPU kernel kda_intra_chunk_bwd for intra-chunk backward: reads per-chunk slices of q/k/g/beta and upstream gradients, constructs causal/segment masks, computes interaction matrices, propagates gradients to q/k/beta/g via dot products and corrected terms, with block specs and grid-based execution; wrapper reshapes tensors and applies jit compilation with static arguments.
Test Suites
tests/ref/kda/test_chunk_kda.py, tests/ops/kda/test_chunk_kda_tpu.py
CPU reference test suite validates chunk_kda_bwd_intra shapes, zero-gradient symmetries, edge cases (zero gates, extreme beta), linearity, and cross-validates against Triton FLA kernel when available; TPU test suite validates Pallas kernel against CPU reference with bfloat16 conversion and numeric comparison over varied batch/head/chunk configurations.
Module Exports & Housekeeping
tops/cpu/ops/kda/__init__.py, tops/ops/kda/__init__.py, .gitignore
Updated module __all__ to re-export chunk_kda_bwd_intra and kda_intra_chunk_bwd; added .claude/ to gitignore.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant chunk_kda
    participant Reshape
    participant chunk_kda_fwd
    participant InterChunk as Inter-Chunk Recurrence
    participant Output

    User->>chunk_kda: q, k, v, g, beta, chunk_size
    chunk_kda->>Reshape: Pad T, reshape to [B,H,NT,C,*]
    Reshape->>chunk_kda: Reshaped tensors
    chunk_kda->>chunk_kda: Chunk-local cumsum(g)
    chunk_kda->>chunk_kda_fwd: Reshaped q,k,v,g,beta
    
    loop For each chunk
        chunk_kda_fwd->>chunk_kda_fwd: Build A_raw (interaction matrix)
        chunk_kda_fwd->>chunk_kda_fwd: Triangular solve: Akk=(I+A_raw)^-1
        chunk_kda_fwd->>chunk_kda_fwd: Compute effective w,u from Akk
        chunk_kda_fwd->>InterChunk: Intra-chunk attention Aqk, state
    end
    
    InterChunk->>InterChunk: Update state S (delta-rule corrected)
    InterChunk->>InterChunk: Fuse outputs with inter-chunk state
    InterChunk->>chunk_kda_fwd: Chunk outputs, Aqk, Akk, final_state
    chunk_kda_fwd->>Reshape: Stack outputs
    Reshape->>Output: Reshape [B,T_orig,H,V], trim padding
    Output->>User: output, final_state
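The per-chunk forward steps in the diagram above can be sketched in NumPy. This is an illustrative sketch only: the gate terms (g) are omitted, the scale placement follows the simplified delta-rule form rather than the exact KDA kernel, and all names are hypothetical.

```python
import numpy as np

def chunk_fwd_step(q, k, v, beta, S0, scale):
    """One chunk of the forward pass (gates omitted for brevity).

    q, k: [C, K], v: [C, V], beta: [C], S0: [K, V] incoming state.
    """
    C = q.shape[0]
    # Step 1: raw interaction matrix between keys inside the chunk
    # (strictly lower triangular, scaled by beta).
    A_raw = np.tril(beta[:, None] * (k @ k.T), -1)
    # Step 2: triangular solve. I + A_raw is unit lower triangular,
    # so the inverse always exists.
    Akk = np.linalg.inv(np.eye(C) + A_raw)
    # Step 3: effective keys/values after the delta-rule correction.
    w = Akk @ (beta[:, None] * k)
    u = Akk @ (beta[:, None] * v)
    # Step 4: intra-chunk attention plus inter-chunk state contribution.
    Aqk = np.tril(q @ k.T)
    o = scale * (q @ S0 + Aqk @ (u - w @ S0))
    # Step 5: delta-rule corrected state update for the next chunk.
    S1 = S0 + k.T @ (u - w @ S0)
    return o, S1
```

In the real kernel each of these steps is batched over [B, H, NT] and fused with the gate cumsum; the sketch only shows the per-chunk matrix structure.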
sequenceDiagram
    participant User
    participant chunk_kda_bwd
    participant Recompute as Recompute Forward
    participant Recurrence as Reverse-Time Recurrence
    participant dAkk_Grad as dAkk Gradient Path
    participant chunk_kda_bwd_intra
    participant Merge as Merge & Post-Process

    User->>chunk_kda_bwd: dL/doutput, q, k, v, g, beta, chunk_size
    chunk_kda_bwd->>Recompute: Recompute w, u, kg, delta-rule scan
    Recompute->>chunk_kda_bwd: Forward intermediates
    
    chunk_kda_bwd->>chunk_kda_bwd: Form dAqk, dv from output gradient
    
    loop Reverse time over chunks
        chunk_kda_bwd->>Recurrence: Reverse-time state recurrence
        Recurrence->>Recurrence: Accumulate hidden-state gradients
        Recurrence->>Recurrence: Update dv with state corrections
    end
    
    chunk_kda_bwd->>dAkk_Grad: Compute dAkk via inverse-matrix gradient
    dAkk_Grad->>chunk_kda_bwd_intra: dAqk, dAkk
    chunk_kda_bwd_intra->>chunk_kda_bwd_intra: Intra-chunk backward (per-chunk)
    chunk_kda_bwd_intra->>chunk_kda_bwd: dq_intra, dk_intra, dg_intra, dbeta_intra
    
    chunk_kda_bwd->>Merge: Merge inter + intra gradients
    Merge->>Merge: Fused dq, dk, dv, dbeta
    Merge->>Merge: Reverse cumsum(dg) to match forward cumsum
    Merge->>User: dq, dk, dv, dbeta, dg, dinitial_state
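The final "reverse cumsum(dg)" step in the diagram has a simple closed form: since the forward pass applies a chunk-local inclusive cumsum to g, its VJP is an inclusive cumsum taken in reverse order along the same axis. A NumPy sketch (helper name hypothetical):

```python
import numpy as np

def reverse_cumsum(dg, axis=-2):
    # VJP of an inclusive cumsum: dL/dg[i] = sum over j >= i of dL/dgc[j],
    # i.e. an inclusive cumsum taken in reverse order along the axis.
    return np.flip(np.cumsum(np.flip(dg, axis=axis), axis=axis), axis=axis)
```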

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

cpu-ref

Poem

Chunks are split, then stitched with care,
Deltas flowing everywhere—
Matrices solve, states transfer,
CPU reference? We defer
To TPU's Pallas, fast and fair! 🐰✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 65.63% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat: KDA Pallas intra-chunk kernel' directly addresses the primary change: adding a Pallas TPU intra-chunk backward kernel for KDA with supporting CPU reference implementations, documentation, and tests.


@0xaskr 0xaskr marked this pull request as draft April 13, 2026 09:09
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the design documentation and reference implementations for the Chunk-Parallel forward and backward kernels of Kimi Delta Attention (KDA). The changes include detailed mathematical derivations for the chunked operations, JAX-based CPU reference implementations, and a Pallas-based TPU kernel for intra-chunk computations. Comprehensive tests are also added to ensure alignment between the recurrent and chunked implementations, as well as parity with the FLA library. The review feedback identifies inconsistencies in the mathematical formulas regarding the application of the scaling factor s across inter-chunk and intra-chunk terms, and notes that the Triton code snippets for matrix inverse gradients are oversimplified compared to the actual implementation logic.


Recall the output formula from forward Step 3 (within each chunk):

$$\mathbf{o} = \underbrace{s \cdot (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h}}_{\text{inter-chunk term}} + \underbrace{\text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}}}_{\text{intra-chunk term}}$$


medium

The formula for the output o appears to be inconsistent with the implementation regarding the scaling factor s. The formula is:
$$\mathbf{o} = \underbrace{s \cdot (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h}}_{\text{inter-chunk term}} + \underbrace{\text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}}}_{\text{intra-chunk term}}$$
This suggests only the inter-chunk term is scaled. However, the backward pass implementation (and the forward pass code) scales both the inter-chunk and intra-chunk contributions. For consistency and clarity, please consider updating the formula to reflect that s applies to both terms, for example:
$$\mathbf{o} = s \cdot \left( (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h} + \text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}} \right)$$
This would also align with the gradient calculations shown later in this document, such as the scaling of dAqk on line 120.

# Path E: matrix inverse gradient
b_dA = tl.where(row > col, b_dA, 0) # strictly lower triangular
b_dA = tl.dot(b_dA, b_A) # right-multiply by A
b_dA = tl.dot(b_A, b_dA) # left-multiply by A


medium

The Triton kernel code snippet for the matrix inverse gradient calculation seems to be an oversimplification and potentially misleading. The line b_dA = tl.dot(b_A, b_dA) suggests a left multiplication by A_kk, but the correct mathematical formula is dM = -A_kk.T @ dA_kk @ A_kk.T, which involves transposes. The JAX implementation in tops/cpu/ops/kda/chunk.py correctly implements this with transposes.

To avoid confusion, could you please update the code snippet to more accurately reflect the computation, for instance by using tl.trans?
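For reference, the VJP described here can be written directly. Given the forward result A = (I + M)^{-1} and an upstream gradient dA, the differential dA = -A·dM·A implies the cotangent w.r.t. M is -Aᵀ·dA·Aᵀ. A NumPy sketch with hypothetical names:

```python
import numpy as np

def inv_vjp(A, dA):
    # Forward: A = (I + M)^{-1}. Differential: dA = -A @ dM @ A,
    # so the cotangent w.r.t. M is -A.T @ dA @ A.T (transposes required,
    # as in the JAX reference, unlike the simplified Triton snippet).
    return -A.T @ dA @ A.T
```

The transposes are easy to confirm against a finite-difference check on a small strictly lower-triangular M.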


Substituting the result from the second step yields the final chunk-parallel output formula:

$$\mathbf{o}_r = \underbrace{(\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0}_{\text{inter-chunk}} + \underbrace{\sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0)}_{\text{intra-chunk}}$$


medium

The output formula for o_r seems to be missing the scaling factor s for the intra-chunk term. The formula is given as:
$$\mathbf{o}_r = \underbrace{(\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0}_{\text{inter-chunk}} + \underbrace{\sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0)}_{\text{intra-chunk}}$$
However, the implementation appears to apply the scaling factor to both the inter-chunk and intra-chunk components. To maintain consistency with the code and the backward pass documentation, please consider updating the formula to include the scale factor on both terms. For example:
$$\mathbf{o}_r = s \cdot \left( (\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0 + \sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0) \right)$$
This also applies to other output formulas in this document, like the one on line 269.
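The corrected formula, with s applied to both terms, maps to code as follows (hypothetical names, per-chunk shapes: q, g, w: [C, K]; S0: [K, V]; Aqk: [C, C]; u: [C, V]):

```python
import numpy as np

def chunk_output(q, g, S0, Aqk, u, w, scale):
    # o_r = s * ((q * exp(g)) @ S0 + tril(Aqk) @ (u - w @ S0)):
    # the scale factor multiplies BOTH the inter- and intra-chunk terms.
    qg = q * np.exp(g)
    return scale * (qg @ S0 + np.tril(Aqk) @ (u - w @ S0))
```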


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (6)
tests/ops/kda/test_kda_alignment.py (1)

83-89: Consider using compare_tensor utility per coding guidelines.

The test uses np.testing.assert_allclose directly. The coding guidelines specify using the compare_tensor utility from tests/utils.py for kernel output comparisons, which provides consistent tolerance handling across the test suite.

This pattern repeats in all test methods (lines 109-114, 119-124, 141-146, 163-168, 185-190).

♻️ Example refactor for one assertion
+from tests.utils import compare_tensor
+
 ...
-        np.testing.assert_allclose(
-            np.asarray(out_chunk),
-            np.asarray(out_recur),
-            rtol=1e-4,
-            atol=1e-4,
-            err_msg="chunk vs recurrent output mismatch (basic)",
-        )
+        compare_tensor(
+            out_chunk,
+            out_recur,
+            rtol=1e-4,
+            atol=1e-4,
+            msg="chunk vs recurrent output mismatch (basic)",
+        )

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_kda_alignment.py` around lines 83 - 89, Replace direct
uses of np.testing.assert_allclose in this test (e.g., the assertion comparing
out_chunk and out_recur) with the compare_tensor utility from tests/utils.py;
locate each occurrence (assert_allclose calls comparing out_chunk vs out_recur
across the test methods) and call compare_tensor(out_chunk, out_recur,
atol=1e-4, rtol=1e-4, max_ulp=None) (or appropriate max_ulp if required by
project standards) so that all kernel comparisons use the centralized tolerance
handling.
tops/cpu/ops/kda/chunk.py (2)

641-648: Prefix unused unpacked variables with underscore.

Aqk and Akk are returned from chunk_kda_fwd but not used in the public chunk_kda wrapper (they're needed for backward, which isn't called here). Prefix with underscore to indicate intentional discard.

♻️ Proposed fix
-  o, Aqk, Akk, final_state = chunk_kda_fwd(
+  o, _Aqk, _Akk, final_state = chunk_kda_fwd(
     q_c, k_c, v_c, g_c, beta_c,
     scale=scale,
     initial_state=initial_state,
     output_final_state=output_final_state,
     C=C,
     acc_dt=acc_dt,
   )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 641 - 648, The tuple returned by
chunk_kda_fwd in the chunk_kda wrapper unpacks Aqk and Akk but never uses them;
update the unpacking to prefix these unused variables with underscores (e.g.,
_Aqk, _Akk) in the call that currently does "o, Aqk, Akk, final_state =
chunk_kda_fwd(...)" so their discard is explicit and linter-friendly; keep the
rest of the call and parameter names (q_c, k_c, v_c, g_c, beta_c, scale,
initial_state, output_final_state, C, acc_dt) unchanged.

172-219: Add input shape assertions.

The function has excellent documentation but lacks runtime assertions for input validation.

🛡️ Proposed input validation
 `@cpu_reference`
 def chunk_kda_bwd_intra(
   q_c: jax.Array,
   k_c: jax.Array,
   g_c: jax.Array,
   beta_c: jax.Array,
   dAqk: jax.Array,
   dAkk: jax.Array,
   C: int,
   acc_dt: jnp.dtype,
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
   """..."""
+  # Shape assertions
+  assert q_c.ndim == 5, f"q_c must be 5D [B,H,NT,C,K], got {q_c.ndim}D"
+  assert k_c.shape == q_c.shape, f"k_c shape mismatch"
+  assert g_c.shape == q_c.shape, f"g_c shape mismatch"
+  B, H, NT, C_actual, K = q_c.shape
+  assert C_actual == C, f"chunk size mismatch: {C_actual} vs {C}"
+  assert beta_c.shape == (B, H, NT, C), f"beta_c shape mismatch"
+  assert dAqk.shape == (B, H, NT, C, C), f"dAqk shape mismatch"
+  assert dAkk.shape == (B, H, NT, C, C), f"dAkk shape mismatch"
+
   NT = q_c.shape[2]

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 172 - 219, Add runtime assertions at
the start of the chunk_kda_bwd_intra function to validate the shapes and types
of all input arrays q_c, k_c, g_c, beta_c, dAqk, and dAkk according to their
documented shapes. Ensure batch size B, heads H, number of tokens NT, chunk size
C, and feature dimension K are consistent across inputs. Also check that all
inputs have the expected dtypes, especially acc_dt for accumulation and the
array data types for q_c, k_c, g_c, beta_c, dAqk, and dAkk.
tests/ops/kda/test_pallas_intra_chunk.py (1)

149-166: Consider using compare_tensor utility per coding guidelines.

Similar to test_kda_alignment.py, this test uses np.testing.assert_allclose directly. The coding guidelines recommend using compare_tensor from tests/utils.py for consistent tolerance handling.

♻️ Example refactor
+from tests.utils import compare_tensor
+
 ...
-    np.testing.assert_allclose(
-        np.asarray(out_ref),
-        np.asarray(out_pallas),
-        rtol=rtol,
-        atol=atol,
-        err_msg="Pallas vs Reference output mismatch",
-    )
+    compare_tensor(
+        out_ref,
+        out_pallas,
+        rtol=rtol,
+        atol=atol,
+        msg="Pallas vs Reference output mismatch",
+    )

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_pallas_intra_chunk.py` around lines 149 - 166, Replace
direct np.testing.assert_allclose calls with the compare_tensor test utility:
import compare_tensor and call compare_tensor(np.asarray(out_ref),
np.asarray(out_pallas), atol=atol, rtol=rtol, max_ulp=<appropriate value>)
instead of the first assert_allclose, and for the final state (inside if
output_final_state) call compare_tensor(np.asarray(state_ref),
np.asarray(state_pallas), atol=atol, rtol=rtol, max_ulp=<appropriate value>)
after keeping the output_final_state existence checks; remove the two
np.testing.assert_allclose blocks and ensure compare_tensor is imported at top
of the test file.
tops/cpu/ops/kda/chunk_bwd_intra_ref.py (1)

32-46: Add input shape assertions per coding guidelines.

The function lacks runtime assertions to validate input tensor shapes and types. Per project coding standards, public functions should enforce strict input assertions before executing main logic.

🛡️ Proposed input validation
 def chunk_kda_bwd_intra_ref(
     q: torch.Tensor,      # [B, H, NT, C, K]
     k: torch.Tensor,      # [B, H, NT, C, K]
     g: torch.Tensor,      # [B, H, NT, C, K]  chunk-local cumsummed gates
     beta: torch.Tensor,   # [B, H, NT, C]
    dAqk: torch.Tensor,   # [B, H, NT, C, C]  lower triangular (incl. diagonal)
    dAkk: torch.Tensor,   # [B, H, NT, C, C]  strictly lower triangular
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
    Returns:
         dq:   [B, H, NT, C, K]
         dk:   [B, H, NT, C, K]
         db:   [B, H, NT, C]
         dg:   [B, H, NT, C, K]
     """
+    assert q.ndim == 5, f"q must be 5D [B,H,NT,C,K], got {q.ndim}D"
+    assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
+    assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
+    assert beta.ndim == 4 and beta.shape[:3] == q.shape[:3], (
+        f"beta shape {beta.shape} incompatible with q"
+    )
+    B, H, NT, C, K = q.shape
+    assert dAqk.shape == (B, H, NT, C, C), f"dAqk shape mismatch"
+    assert dAkk.shape == (B, H, NT, C, C), f"dAkk shape mismatch"
+
    # -- auxiliary quantities --
     eg  = torch.exp(g)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd_intra_ref.py` around lines 32 - 46, Add strict
runtime input assertions at the start of chunk_kda_bwd_intra_ref: verify q, k, g
are 5-D tensors with shape [..., C, K]; beta is 4-D with matching leading dims
[B, H, NT, C]; dAqk and dAkk are 5-D with shape [..., C, C]; ensure all tensors
are torch.Tensor, share the same dtype and device, and that C and K are positive
integers and consistent across tensors (e.g., q.shape[-2]==beta.shape[-1],
q.shape[-1]==k.shape[-1], dAqk.shape[-2]==dAqk.shape[-1]==q.shape[-2], same for
dAkk). Place these checks at the top of chunk_kda_bwd_intra_ref before any
computation.
tests/ops/kda/test_gpu_kda_vs_fla.py (1)

288-294: Use compare_tensor for these cross-framework checks.

This pattern is repeated throughout the file. Switching the assertions to tests/utils.py::compare_tensor keeps tolerance handling and failure output consistent with the rest of the repo.

As per coding guidelines, tests/**/*.py: Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_gpu_kda_vs_fla.py` around lines 288 - 294, Replace the
direct NumPy assertion with the repository's compare utility: call
compare_tensor(pt2np(out_pt), jax2np(out_jax), atol=atol, rtol=rtol,
max_ulp=<appropriate_value>) instead of np.testing.assert_allclose to keep
tolerance handling and failure messages consistent; ensure compare_tensor is
imported from tests.utils at the top of the file and choose a sensible max_ulp
(or pass None if not applicable) and include the same err_msg context (e.g.,
f"fused_kda_gate mismatch (dt_bias={with_dt_bias})") when invoking
compare_tensor.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/kda/test_gpu_kda_vs_fla.py`:
- Around line 28-33: Replace the broad "except Exception" around the JAX
import/device check with targeted exception handling: catch ImportError (for
missing jax) and the JAX-specific initialization errors (e.g., RuntimeError or
jax.lib.xla_extension.XlaRuntimeError) so only genuine unavailability/init
failures set _HAS_JAX_GPU = False, while letting other unexpected exceptions
propagate; update the try/except that encloses the import and the any("cuda"...
for d in jax.devices()) check and reference the _HAS_JAX_GPU symbol and jax
import in your changes.

In `@tops/cpu/ops/kda/chunk_pipeline.py`:
- Around line 152-158: The current activation handling silently defaults to gate
= g_2d for unknown names (using variables activation, g_2d, gate), which hides
typos; instead validate activation explicitly and raise a clear exception for
unsupported values. Replace the final else branch with a fail-fast check that
raises a ValueError (or custom error) listing allowed options ("sigmoid",
"swish", "silu"/"silu" alias handling if desired), so callers get an immediate
error when activation is invalid rather than silently changing semantics.
- Around line 692-706: chunk_kda_reference (and similarly
fused_recurrent_kda_reference) must validate inputs up-front: add assertions
(use tops.utils.assert_shape_or_none where appropriate) to check that q, k, v,
g, beta are arrays of expected rank (e.g., rank >= 2) and share the same leading
dimensions, validate initial_state is either None or has the correct shape and
dtype consistent with the internal state, and ensure cu_seqlens (if provided) is
a rank-1 array and strictly non-decreasing/monotonic; fail fast with clear
assert messages before any reshape/scan/indexing. Use the function names
q/k/v/g/beta/initial_state/cu_seqlens in your checks and apply identical guards
in fused_recurrent_kda_reference (lines referenced) to enforce the coding
guideline for public entrypoints.

In `@tops/ops/kda/intra_chunk.py`:
- Around line 273-324: The index_map lambdas inside the BlockSpec declarations
(e.g., the in_specs/out_specs and their counterparts in the backward call) use
the variable name "l" which triggers Ruff E741; rename that parameter to a
non-conflicting name such as "chunk_idx" in every index_map lambda (e.g., change
"lambda i, j, l: ..." to "lambda i, j, chunk_idx: ..." and update any uses
inside the lambda accordingly) for all BlockSpec entries referenced (including
u, w, qg, kg, Aqk, Akk_inv and the corresponding in_specs and backward
definitions) so the file is lint-clean.
- Around line 221-235: Add strict input assertions before any reshaping in the
intra-chunk attention entrypoint: assert T % chunk_size == 0 (already present)
and additionally assert q, k, g, v have identical shapes (B, H, T, D); assert
beta has shape (B, H, T) or (B, H, T, 1); assert segment_ids is either None or
shape (B, T); assert chunk_size is a multiple of 16; use the project utility
assert_shape_or_none from tops.utils where appropriate to validate optional
arrays and types. Place these checks immediately above the current reshapes
(referencing variables q, k, g, v, beta, segment_ids, chunk_size) so failures
surface with clear messages before any reshape/split operations.
- Around line 387-396: The backward mask for Akk is using a non-strict >= and
thus includes diagonal gradients; change the mask used for Akk to be strictly
lower-triangular to match forward by replacing the >= with > when constructing
mask_akk (e.g. compute mask_akk using idx[:, None] > idx[None, :] combined with
segment_mask or build causal_mask strictly with > and use that only for
mask_akk); leave mask_aqk (the qk mask) as-is so qk causal behavior is
unchanged.
- Around line 245-253: The Pallas kernels are forced into interpreter mode by
passing interpret=True to pl.pallas_call; remove the interpret=True argument
from the pl.pallas_call invocations that wrap kda_intra_chunk_kernel (forward)
and its backward counterpart so the calls can emit compiled backend code for
TPU/GPU execution (keep other args like functools.partial, chunk_size,
head_dim/D, scale, and out_shape unchanged).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4646ca06-d487-461c-a6d0-d121767a9248

📥 Commits

Reviewing files that changed from the base of the PR and between 734cf52 and 23f0bdd.

📒 Files selected for processing (12)
  • docs/design-docs/ops/kda/chunk-bwd.md
  • docs/design-docs/ops/kda/chunk-fwd.md
  • tests/ops/kda/__init__.py
  • tests/ops/kda/test_gpu_kda_vs_fla.py
  • tests/ops/kda/test_kda_alignment.py
  • tests/ops/kda/test_pallas_intra_chunk.py
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
  • tops/cpu/ops/kda/chunk_bwd_intra_ref.py
  • tops/cpu/ops/kda/chunk_pipeline.py
  • tops/ops/kda/__init__.py
  • tops/ops/kda/intra_chunk.py

Comment on lines +28 to +33
try:
    import jax

    _HAS_JAX_GPU = any("cuda" in str(d).lower() for d in jax.devices())
except Exception:
    _HAS_JAX_GPU = False

⚠️ Potential issue | 🟡 Minor

Don't turn arbitrary JAX failures into a skip.

except Exception will silently skip this whole module on unrelated local regressions too. Catch the expected availability/init errors only so real failures still surface.

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 32-32: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/kda/test_gpu_kda_vs_fla.py` around lines 28 - 33, Replace the broad
"except Exception" around the JAX import/device check with targeted exception
handling: catch ImportError (for missing jax) and the JAX-specific
initialization errors (e.g., RuntimeError or
jax.lib.xla_extension.XlaRuntimeError) so only genuine unavailability/init
failures set _HAS_JAX_GPU = False, while letting other unexpected exceptions
propagate; update the try/except that encloses the import and the any("cuda"...
for d in jax.devices()) check and reference the _HAS_JAX_GPU symbol and jax
import in your changes.

Comment on lines +152 to +158
if activation in ("swish", "silu"):
    gate = g_2d * jax.nn.sigmoid(g_2d)  # swish(x) = x * sigmoid(x)
elif activation == "sigmoid":
    gate = jax.nn.sigmoid(g_2d)
else:
    gate = g_2d


⚠️ Potential issue | 🟡 Minor

Reject unsupported activations instead of silently changing semantics.

The docstring only advertises "sigmoid", "swish", and "silu", but any typo currently falls through to gate = g_2d. That makes a bad config look valid while changing layer behavior.

Fail fast on invalid activation names
     if activation in ("swish", "silu"):
         gate = g_2d * jax.nn.sigmoid(g_2d)  # swish(x) = x * sigmoid(x)
     elif activation == "sigmoid":
         gate = jax.nn.sigmoid(g_2d)
     else:
-        gate = g_2d
+        raise ValueError(
+            f"Unsupported activation {activation!r}. Expected 'sigmoid', 'swish', or 'silu'."
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     if activation in ("swish", "silu"):
         gate = g_2d * jax.nn.sigmoid(g_2d)  # swish(x) = x * sigmoid(x)
     elif activation == "sigmoid":
         gate = jax.nn.sigmoid(g_2d)
     else:
-        gate = g_2d
+        raise ValueError(
+            f"Unsupported activation {activation!r}. Expected 'sigmoid', 'swish', or 'silu'."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_pipeline.py` around lines 152 - 158, The current
activation handling silently defaults to gate = g_2d for unknown names (using
variables activation, g_2d, gate), which hides typos; instead validate
activation explicitly and raise a clear exception for unsupported values.
Replace the final else branch with a fail-fast check that raises a ValueError
(or custom error) listing allowed options ("sigmoid", "swish", "silu"/"silu"
alias handling if desired), so callers get an immediate error when activation is
invalid rather than silently changing semantics.

Comment on lines +692 to +706
def chunk_kda_reference(
    q: Array,
    k: Array,
    v: Array,
    g: Array,
    beta: Array,
    scale: float | None = None,
    initial_state: Array | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: Array | None = None,
    use_pallas: bool = False,
    **kwargs: Any,
) -> tuple[Array, Array | None]:

🛠️ Refactor suggestion | 🟠 Major

Public KDA entrypoints need explicit shape/type checks.

chunk_kda_reference and fused_recurrent_kda_reference currently rely on downstream reshape, scan, and indexing errors to catch bad inputs. Please assert the expected ranks and shared leading dimensions for q/k/v/g/beta, validate initial_state, and check that cu_seqlens is rank-1 and monotonic before preprocessing.

As per coding guidelines, **/*.py: All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.

Also applies to: 795-808

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_pipeline.py` around lines 692 - 706,
chunk_kda_reference (and similarly fused_recurrent_kda_reference) must validate
inputs up-front: add assertions (use tops.utils.assert_shape_or_none where
appropriate) to check that q, k, v, g, beta are arrays of expected rank (e.g.,
rank >= 2) and share the same leading dimensions, validate initial_state is
either None or has the correct shape and dtype consistent with the internal
state, and ensure cu_seqlens (if provided) is a rank-1 array and strictly
non-decreasing/monotonic; fail fast with clear assert messages before any
reshape/scan/indexing. Use the function names
q/k/v/g/beta/initial_state/cu_seqlens in your checks and apply identical guards
in fused_recurrent_kda_reference (lines referenced) to enforce the coding
guideline for public entrypoints.
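One way to phrase those guards, sketched with NumPy arrays (the helper name `validate_kda_inputs` and the rank-4 layout are illustrative, not the repo's actual contract; the real entrypoint would use `assert_shape_or_none` from `tops.utils` per the coding guidelines):

```python
import numpy as np

def validate_kda_inputs(q, k, v, g, beta, initial_state=None, cu_seqlens=None):
    # Illustrative up-front guards for a public KDA entrypoint.
    assert q.ndim == 4, f"q must be rank-4, got {q.ndim}"
    for name, arr in (("k", k), ("g", g)):
        assert arr.shape == q.shape, f"{name} shape {arr.shape} != {q.shape}"
    # v may have a different feature dim; leading dims must still match q.
    assert v.shape[:3] == q.shape[:3], f"v leading dims {v.shape[:3]} != {q.shape[:3]}"
    assert beta.shape == q.shape[:3], f"beta shape {beta.shape} != {q.shape[:3]}"
    if cu_seqlens is not None:
        assert cu_seqlens.ndim == 1, "cu_seqlens must be rank-1"
        assert np.all(np.diff(cu_seqlens) >= 0), "cu_seqlens must be non-decreasing"

# Passes on consistent shapes:
q = np.zeros((2, 4, 3, 8)); k = np.zeros_like(q); g = np.zeros_like(q)
v = np.zeros((2, 4, 3, 8)); beta = np.zeros((2, 4, 3))
validate_kda_inputs(q, k, v, g, beta, cu_seqlens=np.array([0, 2, 4]))
```

A non-monotonic `cu_seqlens` (e.g. `[0, 3, 1]`) then fails with a clear `AssertionError` instead of a downstream reshape or scan error.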

Comment on lines +221 to +235
B, H, T, D = k.shape
assert T % chunk_size == 0, "Sequence length must be divisible by chunk_size"
num_chunks = T // chunk_size

if segment_ids is None:
# Default: all tokens belong to segment 0
segment_ids = jnp.zeros((B, T), dtype=jnp.int32)

q_reshaped = q.reshape(B, H, num_chunks, chunk_size, D)
k_reshaped = k.reshape(B, H, num_chunks, chunk_size, D)
g_reshaped = g.reshape(B, H, num_chunks, chunk_size, D)
beta_reshaped = beta.reshape(B, H, num_chunks, chunk_size, 1)
v_reshaped = v.reshape(B, H, num_chunks, chunk_size, D)
segment_ids_reshaped = segment_ids.reshape(B, 1, num_chunks, chunk_size, 1)


🛠️ Refactor suggestion | 🟠 Major

Validate the full tensor contract before reshaping.

These entrypoints only guard T % chunk_size == 0, but the implementation assumes much more: q/k/g/v share [B, H, T, D], beta is [B, H, T], segment_ids is [B, T], and the forward path’s block solver only works when chunk_size is a multiple of 16. Right now bad inputs fail later inside reshape/jnp.split with hard-to-diagnose errors.

Suggested guardrail block
     B, H, T, D = k.shape
     assert T % chunk_size == 0, "Sequence length must be divisible by chunk_size"
+    assert q.shape == k.shape == g.shape == v.shape, "q/k/g/v must all be [B, H, T, D]"
+    assert beta.shape == (B, H, T), "beta must be [B, H, T]"
+    assert chunk_size % 16 == 0, "chunk_size must be a multiple of 16 for the block solver"
+    if segment_ids is not None:
+        assert segment_ids.shape == (B, T), "segment_ids must be [B, T]"
As per coding guidelines, `**/*.py`: All public functions must enforce strict input assertions on shape and types before executing main logic using `assert` instructions or utilities like `assert_shape_or_none` from `tops.utils`.

Also applies to: 491-503

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 221 - 235, Add strict input
assertions before any reshaping in the intra-chunk attention entrypoint: assert
T % chunk_size == 0 (already present) and additionally assert q, k, g, v have
identical shapes (B, H, T, D); assert beta has shape (B, H, T) or (B, H, T, 1);
assert segment_ids is either None or shape (B, T); assert chunk_size is a
multiple of 16; use the project utility assert_shape_or_none from tops.utils
where appropriate to validate optional arrays and types. Place these checks
immediately above the current reshapes (referencing variables q, k, g, v, beta,
segment_ids, chunk_size) so failures surface with clear messages before any
reshape/split operations.

Comment on lines +245 to +253
) = pl.pallas_call(
    functools.partial(
        kda_intra_chunk_kernel,
        chunk_size=chunk_size,
        head_dim=D,
        scale=scale,
    ),
    interpret=True,
    out_shape=[

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In JAX Pallas, what does `interpret=True` do in `pl.pallas_call(...)`? Does it bypass backend code generation and run through the interpreter/debug path instead of the compiled TPU/GPU kernel?

💡 Result:

In JAX Pallas, interpret=True in pl.pallas_call(...) runs the kernel as a jax.jit of a scan over the grid, where the body is the kernel lowered as a JAX function. This bypasses the backend-specific code generation (e.g., compiled TPU/GPU kernels via Mosaic or other backends) and instead executes through an interpreter/emulation path, primarily for debugging. It enables running Pallas kernels on CPU (the only way on CPU) without requiring TPU/GPU hardware, simulating TPU behaviors like shared memory and synchronization when using TPU interpret mode.


🏁 Script executed:

# Check the file and specific lines
wc -l tops/ops/kda/intra_chunk.py

Repository: primatrix/pallas-kernel

Length of output: 98


🏁 Script executed:

# Read the specific sections mentioned
sed -n '240,260p' tops/ops/kda/intra_chunk.py

Repository: primatrix/pallas-kernel

Length of output: 686


🏁 Script executed:

# Also check the second location
sed -n '500,520p' tops/ops/kda/intra_chunk.py

Repository: primatrix/pallas-kernel

Length of output: 838


🏁 Script executed:

# Search for other pl.pallas_call usages in the file to understand the pattern
rg "pl.pallas_call" tops/ops/kda/intra_chunk.py -A 5 -B 2

Repository: primatrix/pallas-kernel

Length of output: 546


🏁 Script executed:

# Check for TODO or FIXME comments near these lines
sed -n '240,260p' tops/ops/kda/intra_chunk.py | grep -i "todo\|fixme\|interpret\|debug"

Repository: primatrix/pallas-kernel

Length of output: 90


🏁 Script executed:

# Check the kernel definitions and their structure
grep -n "def kda_intra_chunk_kernel\|def kda_intra_chunk_bwd_kernel" tops/ops/kda/intra_chunk.py -A 10

Repository: primatrix/pallas-kernel

Length of output: 509


🏁 Script executed:

# Look for any comments in the wider context around these pallas_call invocations
sed -n '230,270p' tops/ops/kda/intra_chunk.py

Repository: primatrix/pallas-kernel

Length of output: 1538


Remove interpret=True to enable compiled Pallas kernel execution.

Setting interpret=True in pl.pallas_call() bypasses backend-specific code generation and runs the kernel through an interpreter/emulation path instead of compiling to actual TPU/GPU code. Both the forward pass (lines 245-253) and backward pass (lines 506-513) are pinned to interpreter mode, preventing these Pallas kernels from executing on accelerated hardware.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 245 - 253, The Pallas kernels are
forced into interpreter mode by passing interpret=True to pl.pallas_call; remove
the interpret=True argument from the pl.pallas_call invocations that wrap
kda_intra_chunk_kernel (forward) and its backward counterpart so the calls can
emit compiled backend code for TPU/GPU execution (keep other args like
functools.partial, chunk_size, head_dim/D, scale, and out_shape unchanged).
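An alternative to deleting the flag outright is to gate it on the backend, since interpreter mode is the only way Pallas runs on CPU. A sketch (the helper name is illustrative; real code would feed it `jax.default_backend()`):

```python
def pallas_interpret_mode(backend: str) -> bool:
    # Compile on accelerators; fall back to the interpreter only on CPU,
    # which has no compiled Pallas path. The result would be passed as
    # pl.pallas_call(..., interpret=pallas_interpret_mode(backend)).
    return backend == "cpu"
```

This keeps CPU-based tests working while letting TPU/GPU runs emit compiled kernels.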

Comment on lines +273 to +324
in_specs=[
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # q
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # k
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # g
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, 1),
    ),  # beta
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # v
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, 0, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, 1),
    ),  # segment_ids
],
out_specs=[
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # u
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # w
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # qg
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, D),
    ),  # kg
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, chunk_size),
    ),  # Aqk
    pl.BlockSpec(
        index_map=lambda i, j, l: (i, j, l, 0, 0),
        block_shape=(1, 1, 1, chunk_size, chunk_size),
    ),  # Akk_inv
],

⚠️ Potential issue | 🟡 Minor

Rename the l grid axis before this hits Ruff.

The index_map lambdas use l, which triggers E741 on every BlockSpec here and in the backward call. Renaming it to something like chunk_idx keeps the file lint-clean.

Also applies to: 528-575

🧰 Tools
🪛 Ruff (0.15.9)

[error] 275-275: Ambiguous variable name: l

(E741)


[error] 279-279: Ambiguous variable name: l

(E741)


[error] 283-283: Ambiguous variable name: l

(E741)


[error] 287-287: Ambiguous variable name: l

(E741)


[error] 291-291: Ambiguous variable name: l

(E741)


[error] 295-295: Ambiguous variable name: l

(E741)


[error] 301-301: Ambiguous variable name: l

(E741)


[error] 305-305: Ambiguous variable name: l

(E741)


[error] 309-309: Ambiguous variable name: l

(E741)


[error] 313-313: Ambiguous variable name: l

(E741)


[error] 317-317: Ambiguous variable name: l

(E741)


[error] 321-321: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 273 - 324, The index_map lambdas
inside the BlockSpec declarations (e.g., the in_specs/out_specs and their
counterparts in the backward call) use the variable name "l" which triggers Ruff
E741; rename that parameter to a non-conflicting name such as "chunk_idx" in
every index_map lambda (e.g., change "lambda i, j, l: ..." to "lambda i, j,
chunk_idx: ..." and update any uses inside the lambda accordingly) for all
BlockSpec entries referenced (including u, w, qg, kg, Aqk, Akk_inv and the
corresponding in_specs and backward definitions) so the file is lint-clean.
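The rename is mechanical; a before/after sketch of one index_map:

```python
# Before (triggers Ruff E741, ambiguous name 'l'):
#   index_map=lambda i, j, l: (i, j, l, 0, 0)
# After (lint-clean, identical mapping):
def index_map(i, j, chunk_idx):
    return (i, j, chunk_idx, 0, 0)
```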

@beaver-infiscale

Warning

Size limit exceeded

This PR contains 2334 lines of core code changes, exceeding the 200-line limit (tests and docs excluded).

Suggested action: split this PR into multiple smaller PRs, each with no more than 200 lines of core code changes.

@0xaskr 0xaskr changed the title feat: KDA chunk pipeline + Pallas intra-chunk kernel feat: KDA Pallas intra-chunk kernel Apr 13, 2026
@beaver-infiscale

Warning

Size limit exceeded

This PR contains 2015 lines of core code changes, exceeding the 200-line limit (tests and docs excluded).

Suggested action: split this PR into multiple smaller PRs, each with no more than 200 lines of core code changes.


@0xaskr 0xaskr marked this pull request as ready for review April 13, 2026 15:58

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (1)
tops/ops/kda/intra_chunk.py (1)

68-77: ⚠️ Potential issue | 🔴 Critical

Keep dAkk strictly lower-triangular.

mask_akk is built from >=, so diagonal dAkk terms survive even though the forward Akk graph only has strictly-lower entries. That leaks spurious diagonal gradient into dbeta, dk, and dg.

🔧 Minimal fix
     idx = jnp.arange(chunk_size, dtype=jnp.int32)
-    causal_mask = idx[:, None] >= idx[None, :]
     causal_mask_qk = idx[:, None] >= idx[None, :]
+    causal_mask_akk = idx[:, None] > idx[None, :]
     segment_mask = segment_ids[:, None] == segment_ids[None, :]

-    mask_akk = causal_mask & segment_mask
+    mask_akk = causal_mask_akk & segment_mask
     mask_aqk = causal_mask_qk & segment_mask
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 68 - 77, The current mask_akk uses
>= so diagonal elements pass through and allow spurious diagonal gradient into
dbeta/dk/dg; change the construction of mask_akk to be strictly lower-triangular
by using > instead of >= (i.e., set mask_akk = idx[:, None] > idx[None, :]) so
dAkk becomes strictly lower-triangular before computing dAkk_masked; keep
mask_aqk as-is if cross terms should include the diagonal. Ensure references:
idx, causal_mask/causal_mask_qk, mask_akk, dAkk, and dAkk_masked are updated
accordingly.
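The distinction between the two masks is easy to check numerically; a small NumPy sketch:

```python
import numpy as np

C = 4  # chunk size for the demo
idx = np.arange(C)
causal_mask_qk = idx[:, None] >= idx[None, :]   # Aqk path: diagonal included
causal_mask_akk = idx[:, None] > idx[None, :]   # Akk path: strictly lower

# The Akk mask must zero the diagonal so no spurious diagonal gradient
# leaks into dbeta, dk, or dg; the Aqk mask keeps the diagonal.
strictly_lower = np.tril(np.ones((C, C), dtype=bool), k=-1)
```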
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 637-666: Move all input validation/assertions to the top of the
function before reading q.shape, computing acc dtype or doing any padding math
(i.e., before uses of q, _acc_dtype, chunk_size C and _cdiv). Specifically,
validate q.ndim, chunk_size > 0, and types/shapes for k, v, g, beta and
initial_state using assert statements or the helper assert_shape_or_none from
tops.utils; only after those checks compute orig_dtype = v.dtype, acc_dt =
_acc_dtype(q.dtype), unpack B,T_orig,H,K = q.shape and compute T_padded =
_cdiv(T_orig, C) * C. Ensure the same unique symbols (q, k, v, g, beta,
initial_state, chunk_size/C, _acc_dtype, _cdiv) are referenced so the assertions
guard all downstream operations.
- Around line 299-309: Reorder and strengthen the input guards so you validate
rank and chunk_size before unpacking/using shapes: assert q.ndim == 4 (or use
assert_shape_or_none) and assert chunk_size > 0 (C != 0) up front, then unpack
B, T, H, K = q.shape and only after that check T % C == 0 and the remaining
shape assertions for k, g, beta, dAqk, dAkk; replace any implicit operations
that can raise raw exceptions with explicit asserts (use assert_shape_or_none
from tops.utils if available) so q, chunk_size, and all tensor shapes are
validated in the function in chunk.py before any unpack/division occurs.

In `@tops/ops/kda/intra_chunk.py`:
- Around line 19-20: The module docstring incorrectly states the tensor layout
as [B, H, T, D]; update it to the correct layout [B, T, H, D] to match how
kda_intra_chunk_bwd and its callers expect tensors, ensuring the documented
transpose contract aligns with the implementation in functions like
kda_intra_chunk_bwd.
- Around line 181-205: The code unpacks B, T, H, D from k.shape before
validating tensor rank/dtype and pallas_call currently hard-codes k.dtype while
kernel reads dq from q.dtype and dbeta from beta.dtype; add upfront assertions
for q.ndim, k.ndim, g.ndim and beta.ndim and explicit dtype checks
(q.dtype==k.dtype==g.dtype where appropriate, and beta.dtype checked for dbeta),
or use tops.utils.assert_shape_or_none to validate shapes before B,T,H,D
assignment, and change the pallas_call invocation to use the correct output
dtypes (derive dq dtype from q.dtype and dbeta dtype from beta.dtype) rather
than always using k.dtype so mixed-dtype or bad-rank inputs fail early and
deterministically (refer to symbols: k, q, g, beta, dAqk, dAkk, segment_ids,
pallas_call, dq, dbeta, assert_shape_or_none).

---

Duplicate comments:
In `@tops/ops/kda/intra_chunk.py`:
- Around line 68-77: The current mask_akk uses >= so diagonal elements pass
through and allow spurious diagonal gradient into dbeta/dk/dg; change the
construction of mask_akk to be strictly lower-triangular by using > instead of
>= (i.e., set mask_akk = idx[:, None] > idx[None, :]) so dAkk becomes strictly
lower-triangular before computing dAkk_masked; keep mask_aqk as-is if cross
terms should include the diagonal. Ensure references: idx,
causal_mask/causal_mask_qk, mask_akk, dAkk, and dAkk_masked are updated
accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 624f24f5-8342-4cb0-b3fe-98211b852542

📥 Commits

Reviewing files that changed from the base of the PR and between 23f0bdd and 831d0cc.

📒 Files selected for processing (6)
  • tests/ops/kda/test_chunk_kda_tpu.py
  • tests/ref/kda/test_chunk_kda.py
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
  • tops/ops/kda/__init__.py
  • tops/ops/kda/intra_chunk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/cpu/ops/kda/__init__.py

Comment on lines +299 to +309
B, T, H, K = q.shape
C = chunk_size

assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"
assert dAqk.shape == (B, T, H, C), f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
assert dAkk.shape == (B, T, H, C), f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"


⚠️ Potential issue | 🟡 Minor

Validate rank and chunk_size before unpacking/using them.

B, T, H, K = q.shape and T % C both run before the public contract is checked, so a bad-rank q or chunk_size=0 raises a raw unpack/division error instead of the intended assertion.

🛡️ Suggested guard ordering
-  B, T, H, K = q.shape
   C = chunk_size
-
   assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
+  assert C > 0, f"chunk_size must be positive, got {C}"
+  B, T, H, K = q.shape
   assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
   assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
   assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
As per coding guidelines, `**/*.py`: All public functions must enforce strict input assertions on shape and types before executing main logic using `assert` instructions or utilities like `assert_shape_or_none` from `tops.utils`.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-  B, T, H, K = q.shape
   C = chunk_size
   assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
+  assert C > 0, f"chunk_size must be positive, got {C}"
+  B, T, H, K = q.shape
   assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
   assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
   assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
   assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"
   assert dAqk.shape == (B, T, H, C), f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
   assert dAkk.shape == (B, T, H, C), f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 299-299: Unpacked variable K is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 299 - 309, Reorder and strengthen the
input guards so you validate rank and chunk_size before unpacking/using shapes:
assert q.ndim == 4 (or use assert_shape_or_none) and assert chunk_size > 0 (C !=
0) up front, then unpack B, T, H, K = q.shape and only after that check T % C ==
0 and the remaining shape assertions for k, g, beta, dAqk, dAkk; replace any
implicit operations that can raise raw exceptions with explicit asserts (use
assert_shape_or_none from tops.utils if available) so q, chunk_size, and all
tensor shapes are validated in the function in chunk.py before any
unpack/division occurs.

Comment on lines +637 to +666
orig_dtype = v.dtype
acc_dt = _acc_dtype(q.dtype)
B, T_orig, H, K = q.shape
V = v.shape[-1]
C = chunk_size

# Shape assertions (project coding standard)
assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
assert v.ndim == 4 and v.shape[:3] == q.shape[:3], (
    f"v shape {v.shape} incompatible with q shape {q.shape}"
)
assert g.ndim == 4 and g.shape == q.shape, (
    f"g shape {g.shape} != q shape {q.shape}"
)
assert beta.ndim == 3 and beta.shape == q.shape[:3], (
    f"beta shape {beta.shape} != {q.shape[:3]}"
)
if initial_state is not None:
    assert initial_state.shape == (B, H, K, V), (
        f"initial_state shape {initial_state.shape} != ({B}, {H}, {K}, {V})"
    )

if scale is None:
    scale = K ** -0.5

# --- Pad T to multiple of chunk_size ---
T = T_orig
T_padded = _cdiv(T_orig, C) * C
if T_padded > T_orig:

⚠️ Potential issue | 🟡 Minor

Move API validation ahead of shape unpacking and padding math.

This entrypoint reads q.shape and computes _cdiv(T_orig, C) before any guardrails. Invalid ranks or chunk_size <= 0 will fail with low-level exceptions instead of a clear contract error.

🛡️ Suggested guard ordering
   orig_dtype = v.dtype
   acc_dt = _acc_dtype(q.dtype)
-  B, T_orig, H, K = q.shape
-  V = v.shape[-1]
   C = chunk_size
+  assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
+  assert C > 0, f"chunk_size must be positive, got {C}"
+  B, T_orig, H, K = q.shape
+  V = v.shape[-1]

   # Shape assertions (project coding standard)
-  assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
   assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
As per coding guidelines, `**/*.py`: All public functions must enforce strict input assertions on shape and types before executing main logic using `assert` instructions or utilities like `assert_shape_or_none` from `tops.utils`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk.py` around lines 637 - 666, Move all input
validation/assertions to the top of the function before reading q.shape,
computing acc dtype or doing any padding math (i.e., before uses of q,
_acc_dtype, chunk_size C and _cdiv). Specifically, validate q.ndim, chunk_size >
0, and types/shapes for k, v, g, beta and initial_state using assert statements
or the helper assert_shape_or_none from tops.utils; only after those checks
compute orig_dtype = v.dtype, acc_dt = _acc_dtype(q.dtype), unpack B,T_orig,H,K
= q.shape and compute T_padded = _cdiv(T_orig, C) * C. Ensure the same unique
symbols (q, k, v, g, beta, initial_state, chunk_size/C, _acc_dtype, _cdiv) are
referenced so the assertions guard all downstream operations.

Comment on lines +19 to +20
Tensor layout: [B, H, T, D] (batch, heads, time, head_dim).
"""

⚠️ Potential issue | 🟡 Minor

Fix the tensor layout in the module docstring.

The file header says [B, H, T, D], but kda_intra_chunk_bwd and its callers operate on [B, T, H, D]. Leaving this inverted will mislead the next caller into the wrong transpose contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 19 - 20, The module docstring
incorrectly states the tensor layout as [B, H, T, D]; update it to the correct
layout [B, T, H, D] to match how kda_intra_chunk_bwd and its callers expect
tensors, ensuring the documented transpose contract aligns with the
implementation in functions like kda_intra_chunk_bwd.

Comment on lines +181 to +205
B, T, H, D = k.shape
C = chunk_size

assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
assert q.shape == (B, T, H, D), f"q shape {q.shape} != k shape {k.shape}"
assert g.shape == (B, T, H, D), f"g shape {g.shape} != ({B}, {T}, {H}, {D})"
assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"

NT = T // C

assert dAqk.shape == (B, T, H, C), (
    f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
)
assert dAkk.shape == (B, T, H, C), (
    f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"
)
if segment_ids is not None:
    assert segment_ids.shape == (B, T), (
        f"segment_ids shape {segment_ids.shape} != ({B}, {T})"
    )

if segment_ids is None:
    segment_ids = jnp.zeros((B, T), dtype=jnp.int32)


⚠️ Potential issue | 🟠 Major

Validate rank and dtype before deriving kernel shapes.

B, T, H, D = k.shape runs before any rank check, and pallas_call hard-codes k.dtype for every output while the kernel computes dq from q.dtype and dbeta from beta.dtype. Bad-rank or mixed-dtype calls will fail late or silently coerce buffers.

🛡️ Suggested contract checks
-    B, T, H, D = k.shape
     C = chunk_size
-
-    assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
+    assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
+    assert k.ndim == 4, f"k must be 4D [B,T,H,D], got {k.ndim}D"
+    assert C > 0, f"chunk_size must be positive, got {C}"
+    assert q.dtype == k.dtype == g.dtype == beta.dtype == dAqk.dtype == dAkk.dtype, (
+        "q/k/g/beta/dAqk/dAkk must share dtype"
+    )
+    B, T, H, D = k.shape
     assert q.shape == (B, T, H, D), f"q shape {q.shape} != k shape {k.shape}"
     assert g.shape == (B, T, H, D), f"g shape {g.shape} != ({B}, {T}, {H}, {D})"
As per coding guidelines, `**/*.py`: All public functions must enforce strict input assertions on shape and types before executing main logic using `assert` instructions or utilities like `assert_shape_or_none` from `tops.utils`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/intra_chunk.py` around lines 181 - 205, The code unpacks B, T,
H, D from k.shape before validating tensor rank/dtype and pallas_call currently
hard-codes k.dtype while kernel reads dq from q.dtype and dbeta from beta.dtype;
add upfront assertions for q.ndim, k.ndim, g.ndim and beta.ndim and explicit
dtype checks (q.dtype==k.dtype==g.dtype where appropriate, and beta.dtype
checked for dbeta), or use tops.utils.assert_shape_or_none to validate shapes
before B,T,H,D assignment, and change the pallas_call invocation to use the
correct output dtypes (derive dq dtype from q.dtype and dbeta dtype from
beta.dtype) rather than always using k.dtype so mixed-dtype or bad-rank inputs
fail early and deterministically (refer to symbols: k, q, g, beta, dAqk, dAkk,
segment_ids, pallas_call, dq, dbeta, assert_shape_or_none).
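As a sketch of the pattern this review asks for — rank and dtype checks before any shape unpacking — here is a minimal stand-alone illustration. NumPy arrays stand in for JAX arrays; `validate_intra_chunk_inputs` and the `beta` shape `(B, T, H)` are assumptions for illustration, not the actual `tops` API:

```python
import numpy as np

def validate_intra_chunk_inputs(q, k, g, beta, chunk_size):
    # Hypothetical helper: validate rank and dtype *before* unpacking shapes,
    # so bad-rank or mixed-dtype inputs fail early and deterministically.
    assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
    assert k.ndim == 4, f"k must be 4D [B,T,H,D], got {k.ndim}D"
    assert chunk_size > 0, f"chunk_size must be positive, got {chunk_size}"
    assert q.dtype == k.dtype == g.dtype == beta.dtype, "inputs must share dtype"
    B, T, H, D = k.shape  # safe to unpack only after the rank checks above
    assert q.shape == (B, T, H, D), f"q shape {q.shape} != {(B, T, H, D)}"
    assert g.shape == (B, T, H, D), f"g shape {g.shape} != {(B, T, H, D)}"
    assert beta.shape == (B, T, H), f"beta shape {beta.shape} != {(B, T, H)}"
    return B, T, H, D
```

Output dtypes would then be derived per tensor (e.g. `dq` from `q.dtype`, `dbeta` from `beta.dtype`) rather than hard-coding `k.dtype` for every output.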

@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2015 lines of core code, exceeding the limit of 200 (tests and docs excluded).

Suggested action: split this PR into multiple smaller PRs, each with its core code changes kept under 200 lines.

@beaver-infiscale

Warning

Size limit exceeded

This PR changes 1028 lines of core code, exceeding the limit of 200 (tests and docs excluded).

Suggested action: split this PR into multiple smaller PRs, each with its core code changes kept under 200 lines.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd-intra.md`:
- Around line 15-27: The fenced code blocks containing the ASCII diagram
(starting with "chunk_kda_bwd_wy_dqkg_fused ──→ ..." and the block starting with
"Grid: (NK * NC, NT, B * H)") are missing language identifiers; update each
triple-backtick fence to include a language tag (e.g., ```text) so linters
(MD040) and readers correctly render them, and apply the same fix to the other
similar block later in the file (the block around the "Grid: ..." section
referenced in the review).
- Line 184: The implementation uses two different variable names for the same
reverse cumsum input causing ambiguity; standardize to one name (either dg or
dg2) across the file by replacing usages of the other with the chosen identifier
so both calls to chunk_local_cumsum use the same variable (e.g., ensure
chunk_local_cumsum(dg, reverse=True) and any downstream code that reads/writes
dg2 are updated to dg), and update any related variable assignment sites and
comments referencing dg2 to use dg so chunk_local_cumsum and its consumers are
consistent.
- Line 11: The heading "### 定位与职责" jumps directly to H3 after the front matter;
change it to an H2 (replace "###" with "## 定位与职责") so the document follows a
proper top-level section after the front matter and update any subsequent
sibling headings to maintain consistent incremental levels relative to this
section.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5c5ab626-fcfe-4d3d-9ffc-b7ef7d93d6c5

📥 Commits

Reviewing files that changed from the base of the PR and between 831d0cc and d185e75.

📒 Files selected for processing (3)
  • .gitignore
  • docs/design-docs/ops/kda/chunk-bwd-intra.md
  • tops/cpu/ops/kda/__init__.py
✅ Files skipped from review due to trivial changes (1)
  • .gitignore
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/cpu/ops/kda/__init__.py


---

### 定位与职责

⚠️ Potential issue | 🟡 Minor

Fix heading level jump to satisfy Markdown structure

Line 11 starts at ### directly after front matter/separators; this trips MD001 and hurts doc outline consistency.

Proposed fix

```diff
-### 定位与职责
+## 定位与职责
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

```diff
-### 定位与职责
+## 定位与职责
```
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3

(MD001, heading-increment)


Comment on lines +15 to +27
```
chunk_kda_bwd_wy_dqkg_fused ──→ dq_inter, dk_inter, dg_inter, db_inter, dAkk
                                     │
chunk_kda_bwd_dAv ──→ dAqk           │
                                     ▼
                    chunk_kda_bwd_intra
                         │
                         ▼
              dq = dq_inter + dq_intra
              dk = dk_inter + dk_intra
              db = db_inter + db_intra
              dg = reverse_cumsum(dg_inter + dg_intra)
```

⚠️ Potential issue | 🟡 Minor

Add language identifiers to fenced code blocks

Both fenced blocks are missing a language tag (MD040), which reduces readability and lint compliance.

Proposed fix

````diff
-```
+```text
 chunk_kda_bwd_wy_dqkg_fused ──→ dq_inter, dk_inter, dg_inter, db_inter, dAkk
                                      │
 chunk_kda_bwd_dAv ──→ dAqk           │
                                      ▼
                     chunk_kda_bwd_intra
                          │
                          ▼
               dq = dq_inter + dq_intra
               dk = dk_inter + dk_intra
               db = db_inter + db_intra
               dg = reverse_cumsum(dg_inter + dg_intra)
````

````diff
-```
+```text
 Grid: (NK * NC, NT, B * H)
   NK = ceil(K / BK)     number of tiles along head_dim
   NC = BT / BC          number of sub-blocks per chunk (typically 64/16 = 4)
   NT                    number of chunks
   B * H                 batch × head
````


Also applies to: 192-198
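To make the grid layout above concrete, a small sketch with illustrative sizes (the specific values below are assumptions for the example, not taken from the kernel):

```python
import math

# Illustrative sizes (assumptions, not values from the kernel itself)
K, BK = 128, 64     # head_dim and its tile size
BT, BC = 64, 16     # chunk size and sub-block size
T, B, H = 1024, 2, 8

NK = math.ceil(K / BK)   # tiles along head_dim
NC = BT // BC            # sub-blocks per chunk
NT = T // BT             # number of chunks

grid = (NK * NC, NT, B * H)
print(grid)  # (8, 16, 16)
```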

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 15-15: Fenced code blocks should have a language specified

(MD040, fenced-code-language)



$$dg_r^{\text{raw}} = \sum_{j=r}^{C} dg_j$$

In the implementation this is done outside the kernel via `chunk_local_cumsum(dg, reverse=True)`. This step converts the per-chunk $`d\mathbf{g}`$ into the gradient with respect to the original $`\mathbf{g}_{\text{raw}}`$.
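A minimal NumPy reference for this chunk-local reverse cumulative sum (an illustrative sketch; the real `chunk_local_cumsum` operates on batched, multi-head JAX arrays):

```python
import numpy as np

def chunk_local_reverse_cumsum(x, chunk_size):
    # Within each chunk, out[r] = sum_{j >= r} x[j]; chunks are independent.
    T = x.shape[0]
    assert T % chunk_size == 0, "T must be a multiple of chunk_size"
    xc = x.reshape(T // chunk_size, chunk_size)
    out = np.flip(np.cumsum(np.flip(xc, axis=1), axis=1), axis=1)
    return out.reshape(T)

print(chunk_local_reverse_cumsum(np.arange(4.0), 2))  # [1. 1. 5. 3.]
```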

⚠️ Potential issue | 🟡 Minor

Unify reverse-cumsum variable naming (dg vs dg2)

Line 184 uses chunk_local_cumsum(dg, reverse=True), but Line 333 uses chunk_local_cumsum(dg2, reverse=True). Keep one name to avoid implementation ambiguity.

Proposed fix

```diff
-In the implementation this is done outside the kernel via `chunk_local_cumsum(dg, reverse=True)`. This step converts the per-chunk $`d\mathbf{g}`$ into the gradient with respect to the original $`\mathbf{g}_{\text{raw}}`$.
+In the implementation this is done outside the kernel via `chunk_local_cumsum(dg2, reverse=True)`. This step converts the per-chunk $`d\mathbf{g}`$ into the gradient with respect to the original $`\mathbf{g}_{\text{raw}}`$.
```
📝 Committable suggestion

Suggested change

```diff
-In the implementation this is done outside the kernel via `chunk_local_cumsum(dg, reverse=True)`. This step converts the per-chunk $`d\mathbf{g}`$ into the gradient with respect to the original $`\mathbf{g}_{\text{raw}}`$.
+In the implementation this is done outside the kernel via `chunk_local_cumsum(dg2, reverse=True)`. This step converts the per-chunk $`d\mathbf{g}`$ into the gradient with respect to the original $`\mathbf{g}_{\text{raw}}`$.
```


@0xaskr 0xaskr added this pull request to the merge queue Apr 14, 2026
Merged via the queue into main with commit df9db22 Apr 14, 2026
3 of 4 checks passed
@0xaskr 0xaskr deleted the feat/chunk_kda_bwd_intra branch April 14, 2026 03:22