Conversation
Migrate KDA (Kimi Delta Attention) code from ant-pretrain:
- Pallas intra-chunk fwd/bwd kernel with segment_ids and g-centered optimization
- Chunk pipeline: 3-stage parallel algorithm (intra -> inter -> output)
- Chunk backward intra reference implementation
- CPU reference chunk KDA implementation
- Tests: Pallas vs JAX reference, chunk vs recurrent alignment, GPU vs FLA
- Design docs for chunk fwd/bwd

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Warning: size limit exceeded. This PR changes 2246 lines of core code, exceeding the 200-line limit (excluding tests and docs). Suggested action: split this PR into multiple smaller PRs, each with core code changes kept under 200 lines.
📝 Walkthrough

This PR introduces Chunk-Parallel forward and backward implementations for KDA (Kimi Delta Attention), adding a CPU reference implementation in JAX, a Pallas intra-chunk forward/backward kernel, a chunk-backward intra reference implementation, alignment tests (Pallas vs JAX reference, chunk vs recurrent, GPU vs FLA), and design docs for the chunked forward and backward passes.

Changes
Sequence Diagram(s)

sequenceDiagram
participant User
participant chunk_kda
participant Reshape
participant chunk_kda_fwd
participant InterChunk as Inter-Chunk Recurrence
participant Output
User->>chunk_kda: q, k, v, g, beta, chunk_size
chunk_kda->>Reshape: Pad T, reshape to [B,H,NT,C,*]
Reshape->>chunk_kda: Reshaped tensors
chunk_kda->>chunk_kda: Chunk-local cumsum(g)
chunk_kda->>chunk_kda_fwd: Reshaped q,k,v,g,beta
loop For each chunk
chunk_kda_fwd->>chunk_kda_fwd: Build A_raw (interaction matrix)
chunk_kda_fwd->>chunk_kda_fwd: Triangular solve: Akk=(I+A_raw)^-1
chunk_kda_fwd->>chunk_kda_fwd: Compute effective w,u from Akk
chunk_kda_fwd->>InterChunk: Intra-chunk attention Aqk, state
end
InterChunk->>InterChunk: Update state S (delta-rule corrected)
InterChunk->>InterChunk: Fuse outputs with inter-chunk state
InterChunk->>chunk_kda_fwd: Chunk outputs, Aqk, Akk, final_state
chunk_kda_fwd->>Reshape: Stack outputs
Reshape->>Output: Reshape [B,T_orig,H,V], trim padding
Output->>User: output, final_state
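The three stages in the diagram (intra-chunk interaction matrix and inverse, inter-chunk state recurrence, output fusion) can be checked against the plain token-by-token recurrence. Below is a minimal NumPy sketch of the chunked delta rule for a single head, with the decay gates g and the scale s omitted for brevity; the function names are illustrative, not this PR's API:

```python
import numpy as np

def delta_rule_recurrent(q, k, v, beta):
    """Token-by-token delta rule: S_t = S_{t-1} + k_t u_t^T, u_t = beta_t (v_t - S_{t-1}^T k_t)."""
    S = np.zeros((k.shape[1], v.shape[1]))
    out = np.zeros_like(v)
    for t in range(k.shape[0]):
        u = beta[t] * (v[t] - S.T @ k[t])  # delta-rule corrected value
        S = S + np.outer(k[t], u)
        out[t] = S.T @ q[t]
    return out

def delta_rule_chunked(q, k, v, beta, C):
    """Chunk-parallel form: intra (build A, invert I+A) -> inter (state) -> output."""
    T = k.shape[0]
    S = np.zeros((k.shape[1], v.shape[1]))
    out = np.zeros_like(v)
    for s0 in range(0, T, C):
        qc, kc, vc = q[s0:s0 + C], k[s0:s0 + C], v[s0:s0 + C]
        bc = beta[s0:s0 + C]
        A = np.tril(bc[:, None] * (kc @ kc.T), k=-1)  # strictly lower interaction matrix
        Akk = np.linalg.inv(np.eye(C) + A)            # a triangular solve in the real kernel
        u = Akk @ (bc[:, None] * (vc - kc @ S))       # effective values (u_j - w_j S0, fused)
        Aqk = np.tril(qc @ kc.T)                      # intra-chunk attention, diagonal included
        out[s0:s0 + C] = qc @ S + Aqk @ u             # inter-chunk + intra-chunk terms
        S = S + kc.T @ u                              # delta-rule corrected state update
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=(8, 3)), rng.normal(size=(8, 3))
v, beta = rng.normal(size=(8, 2)), rng.uniform(0.1, 1.0, size=8)
assert np.allclose(delta_rule_recurrent(q, k, v, beta),
                   delta_rule_chunked(q, k, v, beta, C=4))
```

The chunked path agrees with the recurrence to machine precision because the inverse of (I + A) solves exactly the same within-chunk dependency that the scan unrolls.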
sequenceDiagram
participant User
participant chunk_kda_bwd
participant Recompute as Recompute Forward
participant Recurrence as Reverse-Time Recurrence
participant dAkk_Grad as dAkk Gradient Path
participant chunk_kda_bwd_intra
participant Merge as Merge & Post-Process
User->>chunk_kda_bwd: dL/doutput, q, k, v, g, beta, chunk_size
chunk_kda_bwd->>Recompute: Recompute w, u, kg, delta-rule scan
Recompute->>chunk_kda_bwd: Forward intermediates
chunk_kda_bwd->>chunk_kda_bwd: Form dAqk, dv from output gradient
loop Reverse time over chunks
chunk_kda_bwd->>Recurrence: Reverse-time state recurrence
Recurrence->>Recurrence: Accumulate hidden-state gradients
Recurrence->>Recurrence: Update dv with state corrections
end
chunk_kda_bwd->>dAkk_Grad: Compute dAkk via inverse-matrix gradient
dAkk_Grad->>chunk_kda_bwd_intra: dAqk, dAkk
chunk_kda_bwd_intra->>chunk_kda_bwd_intra: Intra-chunk backward (per-chunk)
chunk_kda_bwd_intra->>chunk_kda_bwd: dq_intra, dk_intra, dg_intra, dbeta_intra
chunk_kda_bwd->>Merge: Merge inter + intra gradients
Merge->>Merge: Fused dq, dk, dv, dbeta
Merge->>Merge: Reverse cumsum(dg) to match forward cumsum
Merge->>User: dq, dk, dv, dbeta, dg, dinitial_state
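The final "reverse cumsum(dg)" step in the diagram follows from a simple identity: the forward pass takes a chunk-local cumulative sum of g, and the gradient of a cumsum is a suffix (reverse) cumulative sum. A tiny NumPy check:

```python
import numpy as np

# Forward: gc[i] = sum_{j <= i} g[j]  (chunk-local cumsum of gates).
# Backward: dg[j] = sum_{i >= j} dgc[i], i.e. a reverse (suffix) cumsum.
dgc = np.array([0.1, 0.2, 0.3])       # upstream gradient w.r.t. cumsum(g)
dg = np.cumsum(dgc[::-1])[::-1]       # reverse cumsum
assert np.allclose(dg, [0.6, 0.5, 0.3])
```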
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces the design documentation and reference implementations for the Chunk-Parallel forward and backward kernels of Kimi Delta Attention (KDA). The changes include detailed mathematical derivations for the chunked operations, JAX-based CPU reference implementations, and a Pallas-based TPU kernel for intra-chunk computations. Comprehensive tests are also added to ensure alignment between the recurrent and chunked implementations, as well as parity with the FLA library. The review feedback identifies inconsistencies in the mathematical formulas regarding the application of the scaling factor s across inter-chunk and intra-chunk terms, and notes that the Triton code snippets for matrix inverse gradients are oversimplified compared to the actual implementation logic.
Recalling the forward Step 3 output formula (within each chunk):

$$\mathbf{o} = \underbrace{s \cdot (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h}}_{\text{inter-chunk term}} + \underbrace{\text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}}}_{\text{intra-chunk term}}$$
The formula for the output `o` appears to be inconsistent with the implementation regarding the scaling factor `s`. The formula is:

$$\mathbf{o} = \underbrace{s \cdot (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h}}_{\text{inter-chunk term}} + \underbrace{\text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}}}_{\text{intra-chunk term}}$$
This suggests only the inter-chunk term is scaled. However, the backward pass implementation (and the forward pass code) scales both the inter-chunk and intra-chunk contributions. For consistency and clarity, please consider updating the formula to reflect that s applies to both terms, for example:
$$\mathbf{o} = s \cdot \left( (\mathbf{q} \odot \exp(\mathbf{g})) \cdot \mathbf{h} + \text{tril}(\mathbf{A}_{qk}) \cdot \mathbf{v}_{\text{new}} \right)$$
This would also align with the gradient calculations shown later in this document, such as the scaling of dAqk on line 120.
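Whichever form the document settles on, note that an outer scale distributes over the sum, so applying s once to both terms is algebraically the same as folding it into each term (e.g. into q or into Aqk). A quick NumPy check with illustrative shapes and names:

```python
import numpy as np

rng = np.random.default_rng(1)
C, K, V = 4, 3, 2
s = 0.5
qg = rng.normal(size=(C, K))               # stands in for q ⊙ exp(g)
h = rng.normal(size=(K, V))                # inter-chunk state
Aqk = np.tril(rng.normal(size=(C, C)))     # intra-chunk attention matrix
v_new = rng.normal(size=(C, V))            # delta-rule corrected values

o_outer = s * (qg @ h + Aqk @ v_new)           # one outer scale on both terms
o_folded = (s * qg) @ h + (s * Aqk) @ v_new    # s folded into each term
assert np.allclose(o_outer, o_folded)
```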
# Path E: matrix-inverse gradient
b_dA = tl.where(row > col, b_dA, 0)  # strictly lower triangular
b_dA = tl.dot(b_dA, b_A)  # right-multiply by A
b_dA = tl.dot(b_A, b_dA)  # left-multiply by A
The Triton kernel code snippet for the matrix inverse gradient calculation seems to be an oversimplification and potentially misleading. The line `b_dA = tl.dot(b_A, b_dA)` suggests a left multiplication by `A_kk`, but the correct mathematical formula is `dM = -A_kk.T @ dA_kk @ A_kk.T`, which involves transposes. The JAX implementation in `tops/cpu/ops/kda/chunk.py` correctly implements this with transposes.
To avoid confusion, could you please update the code snippet to more accurately reflect the computation, for instance by using `tl.trans`?
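The inverse-matrix gradient identity the comment refers to can be checked numerically, independently of the Triton kernel. A standalone NumPy sketch (the loss here is the inner product of the upstream gradient with A_kk, which is what backprop computes):

```python
import numpy as np

rng = np.random.default_rng(0)
C = 5
M = np.tril(rng.normal(size=(C, C)), k=-1)   # strictly lower, so I + M is invertible
A = np.linalg.inv(np.eye(C) + M)             # plays the role of A_kk = (I + M)^{-1}
dA = rng.normal(size=(C, C))                 # upstream gradient dL/dA_kk

# Analytic gradient of L = <dA, A> w.r.t. M: dM = -A^T @ dA @ A^T (transposes matter).
dM = -A.T @ dA @ A.T

# Finite-difference check on one strictly-lower entry (i, j).
eps, (i, j) = 1e-6, (3, 1)
Mp = M.copy()
Mp[i, j] += eps
fd = np.sum(dA * (np.linalg.inv(np.eye(C) + Mp) - A)) / eps
assert np.isclose(fd, dM[i, j], rtol=1e-3, atol=1e-3)
```

Dropping the transposes (as the simplified snippet suggests) makes the finite-difference check fail, which is exactly the inconsistency flagged above.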
Substituting the result of the second step gives the final chunk-parallel output formula:

$$\mathbf{o}_r = \underbrace{(\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0}_{\text{inter-chunk}} + \underbrace{\sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0)}_{\text{intra-chunk}}$$
The output formula for `o_r` seems to be missing the scaling factor `s` for the intra-chunk term. The formula is given as:
$$\mathbf{o}_r = \underbrace{(\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0}_{\text{inter-chunk}} + \underbrace{\sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0)}_{\text{intra-chunk}}$$
However, the implementation appears to apply the scaling factor to both the inter-chunk and intra-chunk components. To maintain consistency with the code and the backward pass documentation, please consider updating the formula to include the scale factor on both terms. For example:
$$\mathbf{o}_r = s \cdot \left( (\mathbf{q}_r \odot \exp(\mathbf{g}_r))^\top \mathbf{S}_0 + \sum_{j} \mathbf{A}_{qk}(r,j) \cdot (\mathbf{u}_j - \mathbf{w}_j \mathbf{S}_0) \right)$$
This also applies to other output formulas in this document, like the one on line 269.
Actionable comments posted: 7
🧹 Nitpick comments (6)
tests/ops/kda/test_kda_alignment.py (1)
83-89: Consider using `compare_tensor` utility per coding guidelines.

The test uses `np.testing.assert_allclose` directly. The coding guidelines specify using the `compare_tensor` utility from `tests/utils.py` for kernel output comparisons, which provides consistent tolerance handling across the test suite. This pattern repeats in all test methods (lines 109-114, 119-124, 141-146, 163-168, 185-190).
♻️ Example refactor for one assertion
+from tests.utils import compare_tensor
+
 ...
-np.testing.assert_allclose(
-    np.asarray(out_chunk),
-    np.asarray(out_recur),
-    rtol=1e-4,
-    atol=1e-4,
-    err_msg="chunk vs recurrent output mismatch (basic)",
-)
+compare_tensor(
+    out_chunk,
+    out_recur,
+    rtol=1e-4,
+    atol=1e-4,
+    msg="chunk vs recurrent output mismatch (basic)",
+)

As per coding guidelines: "Use `compare_tensor` utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/ops/kda/test_kda_alignment.py` around lines 83-89: replace direct uses of np.testing.assert_allclose in this test (e.g., the assertion comparing out_chunk and out_recur) with the compare_tensor utility from tests/utils.py; locate each occurrence across the test methods and call compare_tensor(out_chunk, out_recur, atol=1e-4, rtol=1e-4, max_ulp=None) (or an appropriate max_ulp if required by project standards) so that all kernel comparisons use the centralized tolerance handling.

tops/cpu/ops/kda/chunk.py (2)
641-648: Prefix unused unpacked variables with underscore.
`Aqk` and `Akk` are returned from `chunk_kda_fwd` but not used in the public `chunk_kda` wrapper (they're needed for backward, which isn't called here). Prefix with underscore to indicate intentional discard.

♻️ Proposed fix
-o, Aqk, Akk, final_state = chunk_kda_fwd(
+o, _Aqk, _Akk, final_state = chunk_kda_fwd(
     q_c, k_c, v_c, g_c, beta_c,
     scale=scale,
     initial_state=initial_state,
     output_final_state=output_final_state,
     C=C,
     acc_dt=acc_dt,
 )

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tops/cpu/ops/kda/chunk.py` around lines 641-648: the tuple returned by chunk_kda_fwd in the chunk_kda wrapper unpacks Aqk and Akk but never uses them; update the unpacking to prefix these unused variables with underscores (e.g., _Aqk, _Akk) so their discard is explicit and linter-friendly; keep the rest of the call and parameter names (q_c, k_c, v_c, g_c, beta_c, scale, initial_state, output_final_state, C, acc_dt) unchanged.
172-219: Add input shape assertions.

The function has excellent documentation but lacks runtime assertions for input validation.
🛡️ Proposed input validation
 @cpu_reference
 def chunk_kda_bwd_intra(
     q_c: jax.Array,
     k_c: jax.Array,
     g_c: jax.Array,
     beta_c: jax.Array,
     dAqk: jax.Array,
     dAkk: jax.Array,
     C: int,
     acc_dt: jnp.dtype,
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
     """..."""
+    # Shape assertions
+    assert q_c.ndim == 5, f"q_c must be 5D [B,H,NT,C,K], got {q_c.ndim}D"
+    assert k_c.shape == q_c.shape, "k_c shape mismatch"
+    assert g_c.shape == q_c.shape, "g_c shape mismatch"
+    B, H, NT, C_actual, K = q_c.shape
+    assert C_actual == C, f"chunk size mismatch: {C_actual} vs {C}"
+    assert beta_c.shape == (B, H, NT, C), "beta_c shape mismatch"
+    assert dAqk.shape == (B, H, NT, C, C), "dAqk shape mismatch"
+    assert dAkk.shape == (B, H, NT, C, C), "dAkk shape mismatch"
     NT = q_c.shape[2]

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/cpu/ops/kda/chunk.py` around lines 172-219: add runtime assertions at the start of the chunk_kda_bwd_intra function to validate the shapes and types of all input arrays q_c, k_c, g_c, beta_c, dAqk, and dAkk according to their documented shapes. Ensure batch size B, heads H, number of chunks NT, chunk size C, and feature dimension K are consistent across inputs. Also check that all inputs have the expected dtypes, especially acc_dt for accumulation and the array data types for q_c, k_c, g_c, beta_c, dAqk, and dAkk.

tests/ops/kda/test_pallas_intra_chunk.py (1)
149-166: Consider using `compare_tensor` utility per coding guidelines.

Similar to `test_kda_alignment.py`, this test uses `np.testing.assert_allclose` directly. The coding guidelines recommend using `compare_tensor` from `tests/utils.py` for consistent tolerance handling.

♻️ Example refactor
+from tests.utils import compare_tensor
+
 ...
-np.testing.assert_allclose(
-    np.asarray(out_ref),
-    np.asarray(out_pallas),
-    rtol=rtol,
-    atol=atol,
-    err_msg="Pallas vs Reference output mismatch",
-)
+compare_tensor(
+    out_ref,
+    out_pallas,
+    rtol=rtol,
+    atol=atol,
+    msg="Pallas vs Reference output mismatch",
+)

As per coding guidelines: "Use `compare_tensor` utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/kda/test_pallas_intra_chunk.py` around lines 149-166: replace direct np.testing.assert_allclose calls with the compare_tensor test utility: import compare_tensor and call compare_tensor(np.asarray(out_ref), np.asarray(out_pallas), atol=atol, rtol=rtol, max_ulp=<appropriate value>) instead of the first assert_allclose, and for the final state (inside if output_final_state) call compare_tensor(np.asarray(state_ref), np.asarray(state_pallas), atol=atol, rtol=rtol, max_ulp=<appropriate value>) after keeping the output_final_state existence checks; remove the two np.testing.assert_allclose blocks and ensure compare_tensor is imported at top of the test file.

tops/cpu/ops/kda/chunk_bwd_intra_ref.py (1)
32-46: Add input shape assertions per coding guidelines.

The function lacks runtime assertions to validate input tensor shapes and types. Per project coding standards, public functions should enforce strict input assertions before executing main logic.
🛡️ Proposed input validation
 def chunk_kda_bwd_intra_ref(
     q: torch.Tensor,     # [B, H, NT, C, K]
     k: torch.Tensor,     # [B, H, NT, C, K]
     g: torch.Tensor,     # [B, H, NT, C, K] chunk-local cumsummed gates
     beta: torch.Tensor,  # [B, H, NT, C]
     dAqk: torch.Tensor,  # [B, H, NT, C, C] lower triangular (incl. diagonal)
     dAkk: torch.Tensor,  # [B, H, NT, C, C] strictly lower triangular
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Returns:
         dq: [B, H, NT, C, K]
         dk: [B, H, NT, C, K]
         db: [B, H, NT, C]
         dg: [B, H, NT, C, K]
     """
+    assert q.ndim == 5, f"q must be 5D [B,H,NT,C,K], got {q.ndim}D"
+    assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
+    assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
+    assert beta.ndim == 4 and beta.shape[:3] == q.shape[:3], (
+        f"beta shape {beta.shape} incompatible with q"
+    )
+    B, H, NT, C, K = q.shape
+    assert dAqk.shape == (B, H, NT, C, C), "dAqk shape mismatch"
+    assert dAkk.shape == (B, H, NT, C, C), "dAkk shape mismatch"
+
     # -- auxiliary quantities --
     eg = torch.exp(g)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/cpu/ops/kda/chunk_bwd_intra_ref.py` around lines 32-46: add strict runtime input assertions at the start of chunk_kda_bwd_intra_ref: verify q, k, g are 5-D tensors with shape [..., C, K]; beta is 4-D with matching leading dims [B, H, NT, C]; dAqk and dAkk are 5-D with shape [..., C, C]; ensure all tensors are torch.Tensor, share the same dtype and device, and that C and K are positive integers and consistent across tensors (e.g., q.shape[-2]==beta.shape[-1], q.shape[-1]==k.shape[-1], dAqk.shape[-2]==dAqk.shape[-1]==q.shape[-2], same for dAkk). Place these checks at the top of chunk_kda_bwd_intra_ref before any computation.

tests/ops/kda/test_gpu_kda_vs_fla.py (1)
288-294: Use `compare_tensor` for these cross-framework checks.

This pattern is repeated throughout the file. Switching the assertions to `tests/utils.py::compare_tensor` keeps tolerance handling and failure output consistent with the rest of the repo.

As per coding guidelines, `tests/**/*.py`: use the `compare_tensor` utility from `tests/utils.py` with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/kda/test_gpu_kda_vs_fla.py` around lines 288 - 294, Replace the direct NumPy assertion with the repository's compare utility: call compare_tensor(pt2np(out_pt), jax2np(out_jax), atol=atol, rtol=rtol, max_ulp=<appropriate_value>) instead of np.testing.assert_allclose to keep tolerance handling and failure messages consistent; ensure compare_tensor is imported from tests.utils at the top of the file and choose a sensible max_ulp (or pass None if not applicable) and include the same err_msg context (e.g., f"fused_kda_gate mismatch (dt_bias={with_dt_bias})") when invoking compare_tensor.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ops/kda/test_gpu_kda_vs_fla.py`:
- Around line 28-33: Replace the broad "except Exception" around the JAX
import/device check with targeted exception handling: catch ImportError (for
missing jax) and the JAX-specific initialization errors (e.g., RuntimeError or
jax.lib.xla_extension.XlaRuntimeError) so only genuine unavailability/init
failures set _HAS_JAX_GPU = False, while letting other unexpected exceptions
propagate; update the try/except that encloses the import and the any("cuda"...
for d in jax.devices()) check and reference the _HAS_JAX_GPU symbol and jax
import in your changes.
In `@tops/cpu/ops/kda/chunk_pipeline.py`:
- Around line 152-158: The current activation handling silently defaults to gate
= g_2d for unknown names (using variables activation, g_2d, gate), which hides
typos; instead validate activation explicitly and raise a clear exception for
unsupported values. Replace the final else branch with a fail-fast check that
raises a ValueError (or custom error) listing allowed options ("sigmoid",
"swish", "silu"/"silu" alias handling if desired), so callers get an immediate
error when activation is invalid rather than silently changing semantics.
- Around line 692-706: chunk_kda_reference (and similarly
fused_recurrent_kda_reference) must validate inputs up-front: add assertions
(use tops.utils.assert_shape_or_none where appropriate) to check that q, k, v,
g, beta are arrays of expected rank (e.g., rank >= 2) and share the same leading
dimensions, validate initial_state is either None or has the correct shape and
dtype consistent with the internal state, and ensure cu_seqlens (if provided) is
a rank-1 array and strictly non-decreasing/monotonic; fail fast with clear
assert messages before any reshape/scan/indexing. Use the function names
q/k/v/g/beta/initial_state/cu_seqlens in your checks and apply identical guards
in fused_recurrent_kda_reference (lines referenced) to enforce the coding
guideline for public entrypoints.
In `@tops/ops/kda/intra_chunk.py`:
- Around line 273-324: The index_map lambdas inside the BlockSpec declarations
(e.g., the in_specs/out_specs and their counterparts in the backward call) use
the variable name "l" which triggers Ruff E741; rename that parameter to a
non-conflicting name such as "chunk_idx" in every index_map lambda (e.g., change
"lambda i, j, l: ..." to "lambda i, j, chunk_idx: ..." and update any uses
inside the lambda accordingly) for all BlockSpec entries referenced (including
u, w, qg, kg, Aqk, Akk_inv and the corresponding in_specs and backward
definitions) so the file is lint-clean.
- Around line 221-235: Add strict input assertions before any reshaping in the
intra-chunk attention entrypoint: assert T % chunk_size == 0 (already present)
and additionally assert q, k, g, v have identical shapes (B, H, T, D); assert
beta has shape (B, H, T) or (B, H, T, 1); assert segment_ids is either None or
shape (B, T); assert chunk_size is a multiple of 16; use the project utility
assert_shape_or_none from tops.utils where appropriate to validate optional
arrays and types. Place these checks immediately above the current reshapes
(referencing variables q, k, g, v, beta, segment_ids, chunk_size) so failures
surface with clear messages before any reshape/split operations.
- Around line 387-396: The backward mask for Akk is using a non-strict >= and
thus includes diagonal gradients; change the mask used for Akk to be strictly
lower-triangular to match forward by replacing the >= with > when constructing
mask_akk (e.g. compute mask_akk using idx[:, None] > idx[None, :] combined with
segment_mask or build causal_mask strictly with > and use that only for
mask_akk); leave mask_aqk (the qk mask) as-is so qk causal behavior is
unchanged.
- Around line 245-253: The Pallas kernels are forced into interpreter mode by
passing interpret=True to pl.pallas_call; remove the interpret=True argument
from the pl.pallas_call invocations that wrap kda_intra_chunk_kernel (forward)
and its backward counterpart so the calls can emit compiled backend code for
TPU/GPU execution (keep other args like functools.partial, chunk_size,
head_dim/D, scale, and out_shape unchanged).
---
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4646ca06-d487-461c-a6d0-d121767a9248
📒 Files selected for processing (12)
- docs/design-docs/ops/kda/chunk-bwd.md
- docs/design-docs/ops/kda/chunk-fwd.md
- tests/ops/kda/__init__.py
- tests/ops/kda/test_gpu_kda_vs_fla.py
- tests/ops/kda/test_kda_alignment.py
- tests/ops/kda/test_pallas_intra_chunk.py
- tops/cpu/ops/kda/__init__.py
- tops/cpu/ops/kda/chunk.py
- tops/cpu/ops/kda/chunk_bwd_intra_ref.py
- tops/cpu/ops/kda/chunk_pipeline.py
- tops/ops/kda/__init__.py
- tops/ops/kda/intra_chunk.py
tests/ops/kda/test_gpu_kda_vs_fla.py
Outdated
try:
    import jax

    _HAS_JAX_GPU = any("cuda" in str(d).lower() for d in jax.devices())
except Exception:
    _HAS_JAX_GPU = False
Don't turn arbitrary JAX failures into a skip.

`except Exception` will silently skip this whole module on unrelated local regressions too. Catch the expected availability/init errors only so real failures still surface.
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 32-32: Do not catch blind exception: Exception
(BLE001)
tops/cpu/ops/kda/chunk_pipeline.py
Outdated
if activation in ("swish", "silu"):
    gate = g_2d * jax.nn.sigmoid(g_2d)  # swish(x) = x * sigmoid(x)
elif activation == "sigmoid":
    gate = jax.nn.sigmoid(g_2d)
else:
    gate = g_2d
Reject unsupported activations instead of silently changing semantics.

The docstring only advertises "sigmoid", "swish", and "silu", but any typo currently falls through to `gate = g_2d`. That makes a bad config look valid while changing layer behavior.
Fail fast on invalid activation names
if activation in ("swish", "silu"):
gate = g_2d * jax.nn.sigmoid(g_2d) # swish(x) = x * sigmoid(x)
elif activation == "sigmoid":
gate = jax.nn.sigmoid(g_2d)
else:
- gate = g_2d
+ raise ValueError(
+ f"Unsupported activation {activation!r}. Expected 'sigmoid', 'swish', or 'silu'."
+ )

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if activation in ("swish", "silu"):
    gate = g_2d * jax.nn.sigmoid(g_2d)  # swish(x) = x * sigmoid(x)
elif activation == "sigmoid":
    gate = jax.nn.sigmoid(g_2d)
else:
    raise ValueError(
        f"Unsupported activation {activation!r}. Expected 'sigmoid', 'swish', or 'silu'."
    )
tops/cpu/ops/kda/chunk_pipeline.py
Outdated
def chunk_kda_reference(
    q: Array,
    k: Array,
    v: Array,
    g: Array,
    beta: Array,
    scale: float | None = None,
    initial_state: Array | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    use_gate_in_kernel: bool = False,
    cu_seqlens: Array | None = None,
    use_pallas: bool = False,
    **kwargs: Any,
) -> tuple[Array, Array | None]:
🛠️ Refactor suggestion | 🟠 Major
Public KDA entrypoints need explicit shape/type checks.
`chunk_kda_reference` and `fused_recurrent_kda_reference` currently rely on downstream reshape, scan, and indexing errors to catch bad inputs. Please assert the expected ranks and shared leading dimensions for `q`/`k`/`v`/`g`/`beta`, validate `initial_state`, and check that `cu_seqlens` is rank-1 and monotonic before preprocessing.
As per coding guidelines, **/*.py: All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.
Also applies to: 795-808
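The guards the comment asks for could be collected in a small helper along these lines (a hypothetical sketch; `validate_kda_inputs` and the assumed `[B, H, T, K]` layout are illustrative, not the repo's actual API):

```python
import numpy as np

def validate_kda_inputs(q, k, v, g, beta, initial_state=None, cu_seqlens=None):
    """Fail-fast input guards (hypothetical helper; layouts assumed [B, H, T, K])."""
    assert q.ndim == 4, f"q must be [B, H, T, K], got {q.ndim}D"
    assert k.shape == q.shape and g.shape == q.shape, "q/k/g must share a shape"
    B, H, T, K = q.shape
    assert v.shape[:3] == (B, H, T), f"v leading dims {v.shape[:3]} != {(B, H, T)}"
    assert beta.shape == (B, H, T), f"beta must be [B, H, T], got {beta.shape}"
    if initial_state is not None:
        assert initial_state.shape == (B, H, K, v.shape[-1]), "bad initial_state shape"
    if cu_seqlens is not None:
        assert cu_seqlens.ndim == 1, "cu_seqlens must be rank-1"
        assert np.all(np.diff(cu_seqlens) >= 0), "cu_seqlens must be non-decreasing"

# Example: valid inputs pass silently; a decreasing cu_seqlens would raise.
B, H, T, K, V = 2, 3, 8, 4, 4
q = np.zeros((B, H, T, K)); k = np.zeros_like(q); g = np.zeros_like(q)
v = np.zeros((B, H, T, V)); beta = np.zeros((B, H, T))
validate_kda_inputs(q, k, v, g, beta, cu_seqlens=np.array([0, 3, 8]))
```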
tops/ops/kda/intra_chunk.py
Outdated
B, H, T, D = k.shape
assert T % chunk_size == 0, "Sequence length must be divisible by chunk_size"
num_chunks = T // chunk_size

if segment_ids is None:
    # Default: all tokens belong to segment 0
    segment_ids = jnp.zeros((B, T), dtype=jnp.int32)

q_reshaped = q.reshape(B, H, num_chunks, chunk_size, D)
k_reshaped = k.reshape(B, H, num_chunks, chunk_size, D)
g_reshaped = g.reshape(B, H, num_chunks, chunk_size, D)
beta_reshaped = beta.reshape(B, H, num_chunks, chunk_size, 1)
v_reshaped = v.reshape(B, H, num_chunks, chunk_size, D)
segment_ids_reshaped = segment_ids.reshape(B, 1, num_chunks, chunk_size, 1)
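The reshaped `segment_ids` feed the intra-chunk masks discussed elsewhere in this review: `Aqk` needs a causal mask that includes the diagonal, while `Akk` needs a strictly lower-triangular one, and both must be restricted to same-segment pairs. A NumPy sketch for a single chunk:

```python
import numpy as np

C = 4
seg = np.array([0, 0, 1, 1])                          # segment ids for one chunk
idx = np.arange(C)
same_seg = seg[:, None] == seg[None, :]               # block-diagonal segment mask
mask_aqk = (idx[:, None] >= idx[None, :]) & same_seg  # causal, diagonal included
mask_akk = (idx[:, None] > idx[None, :]) & same_seg   # strictly lower triangular
```

Note how position 2 (first token of segment 1) attends to nothing in segment 0 under either mask, and how `mask_akk` excludes the diagonal that `mask_aqk` keeps.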
🛠️ Refactor suggestion | 🟠 Major
Validate the full tensor contract before reshaping.
These entrypoints only guard T % chunk_size == 0, but the implementation assumes much more: q/k/g/v share [B, H, T, D], beta is [B, H, T], segment_ids is [B, T], and the forward path’s block solver only works when chunk_size is a multiple of 16. Right now bad inputs fail later inside reshape/jnp.split with hard-to-diagnose errors.
Suggested guardrail block

```diff
 B, H, T, D = k.shape
 assert T % chunk_size == 0, "Sequence length must be divisible by chunk_size"
+assert q.shape == k.shape == g.shape == v.shape, "q/k/g/v must all be [B, H, T, D]"
+assert beta.shape == (B, H, T), "beta must be [B, H, T]"
+assert chunk_size % 16 == 0, "chunk_size must be a multiple of 16 for the block solver"
+if segment_ids is not None:
+    assert segment_ids.shape == (B, T), "segment_ids must be [B, T]"
```

Also applies to: 491-503
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/intra_chunk.py` around lines 221 - 235, Add strict input
assertions before any reshaping in the intra-chunk attention entrypoint: assert
T % chunk_size == 0 (already present) and additionally assert q, k, g, v have
identical shapes (B, H, T, D); assert beta has shape (B, H, T) or (B, H, T, 1);
assert segment_ids is either None or shape (B, T); assert chunk_size is a
multiple of 16; use the project utility assert_shape_or_none from tops.utils
where appropriate to validate optional arrays and types. Place these checks
immediately above the current reshapes (referencing variables q, k, g, v, beta,
segment_ids, chunk_size) so failures surface with clear messages before any
reshape/split operations.
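Read as a standalone checker, the suggested guardrail might look like this. `Arr` and `check_intra_chunk_contract` are illustrative stand-ins (not project code); only `.shape` is modeled, so the contract can be exercised without array libraries.

```python
from dataclasses import dataclass


@dataclass
class Arr:
    # minimal stand-in for a jax/numpy array: only .shape matters here
    shape: tuple


def check_intra_chunk_contract(q, k, g, v, beta, chunk_size, segment_ids=None):
    # mirror of the suggested guardrail block: fail before any reshape
    B, H, T, D = k.shape
    assert q.shape == k.shape == g.shape == v.shape, "q/k/g/v must all be [B, H, T, D]"
    assert beta.shape == (B, H, T), f"beta must be [B, H, T], got {beta.shape}"
    assert T % chunk_size == 0, "Sequence length must be divisible by chunk_size"
    assert chunk_size % 16 == 0, "chunk_size must be a multiple of 16 for the block solver"
    if segment_ids is not None:
        assert segment_ids.shape == (B, T), f"segment_ids must be [B, T], got {segment_ids.shape}"
```

With guards like these, a mismatched `beta` or an odd `chunk_size` fails with a contract message instead of a reshape error deep inside `jnp.split`.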
tops/ops/kda/intra_chunk.py (outdated)

```python
) = pl.pallas_call(
    functools.partial(
        kda_intra_chunk_kernel,
        chunk_size=chunk_size,
        head_dim=D,
        scale=scale,
    ),
    interpret=True,
    out_shape=[
```
🧩 Analysis chain
🌐 Web query:
In JAX Pallas, what does `interpret=True` do in `pl.pallas_call(...)`? Does it bypass backend code generation and run through the interpreter/debug path instead of the compiled TPU/GPU kernel?
💡 Result:
In JAX Pallas, interpret=True in pl.pallas_call(...) runs the kernel as a jax.jit of a scan over the grid, where the body is the kernel lowered as a JAX function. This bypasses the backend-specific code generation (e.g., compiled TPU/GPU kernels via Mosaic or other backends) and instead executes through an interpreter/emulation path, primarily for debugging. It enables running Pallas kernels on CPU (the only way on CPU) without requiring TPU/GPU hardware, simulating TPU behaviors like shared memory and synchronization when using TPU interpret mode.
Citations:
- 1: https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/pallas_call.py
- 2: https://docs.jax.dev/en/latest/_autosummary/jax.experimental.pallas.pallas_call.html
- 3: https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
- 4: Pallas emulation jax-ml/jax#20193
🏁 Script executed:
```shell
# Check the file and specific lines
wc -l tops/ops/kda/intra_chunk.py
```
Repository: primatrix/pallas-kernel
Length of output: 98

🏁 Script executed:
```shell
# Read the specific sections mentioned
sed -n '240,260p' tops/ops/kda/intra_chunk.py
```
Repository: primatrix/pallas-kernel
Length of output: 686

🏁 Script executed:
```shell
# Also check the second location
sed -n '500,520p' tops/ops/kda/intra_chunk.py
```
Repository: primatrix/pallas-kernel
Length of output: 838

🏁 Script executed:
```shell
# Search for other pl.pallas_call usages in the file to understand the pattern
rg "pl.pallas_call" tops/ops/kda/intra_chunk.py -A 5 -B 2
```
Repository: primatrix/pallas-kernel
Length of output: 546

🏁 Script executed:
```shell
# Check for TODO or FIXME comments near these lines
sed -n '240,260p' tops/ops/kda/intra_chunk.py | grep -i "todo\|fixme\|interpret\|debug"
```
Repository: primatrix/pallas-kernel
Length of output: 90

🏁 Script executed:
```shell
# Check the kernel definitions and their structure
grep -n "def kda_intra_chunk_kernel\|def kda_intra_chunk_bwd_kernel" tops/ops/kda/intra_chunk.py -A 10
```
Repository: primatrix/pallas-kernel
Length of output: 509

🏁 Script executed:
```shell
# Look for any comments in the wider context around these pallas_call invocations
sed -n '230,270p' tops/ops/kda/intra_chunk.py
```
Repository: primatrix/pallas-kernel
Length of output: 1538
Remove interpret=True to enable compiled Pallas kernel execution.
Setting interpret=True in pl.pallas_call() bypasses backend-specific code generation and runs the kernel through an interpreter/emulation path instead of compiling to actual TPU/GPU code. Both the forward pass (lines 245-253) and backward pass (lines 506-513) are pinned to interpreter mode, preventing these Pallas kernels from executing on accelerated hardware.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/intra_chunk.py` around lines 245 - 253, The Pallas kernels are
forced into interpreter mode by passing interpret=True to pl.pallas_call; remove
the interpret=True argument from the pl.pallas_call invocations that wrap
kda_intra_chunk_kernel (forward) and its backward counterpart so the calls can
emit compiled backend code for TPU/GPU execution (keep other args like
functools.partial, chunk_size, head_dim/D, scale, and out_shape unchanged).
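A toy dispatch function summarizes the behavior the web query describes. `pallas_execution_path` is illustrative only; `pl.pallas_call` makes this decision internally.

```python
def pallas_execution_path(backend: str, interpret: bool) -> str:
    """Which path a pl.pallas_call-like API takes for a given backend.

    interpret=True always forces the interpreter/emulation path; on CPU the
    interpreter is the only option anyway; otherwise the kernel is compiled
    for the accelerator (e.g., via Mosaic on TPU).
    """
    if interpret or backend == "cpu":
        return "interpreter"
    return "compiled"
```

This is why a hard-coded `interpret=True` pins TPU/GPU runs to the slow path. A common pattern is to pass something like `interpret=jax.default_backend() == "cpu"` so CPU tests keep working while accelerators get compiled kernels.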
tops/ops/kda/intra_chunk.py
Outdated
| in_specs=[ | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # q | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # k | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # g | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, 1), | ||
| ), # beta | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # v | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, 0, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, 1), | ||
| ), # segment_ids | ||
| ], | ||
| out_specs=[ | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # u | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # w | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # qg | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, D), | ||
| ), # kg | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, chunk_size), | ||
| ), # Aqk | ||
| pl.BlockSpec( | ||
| index_map=lambda i, j, l: (i, j, l, 0, 0), | ||
| block_shape=(1, 1, 1, chunk_size, chunk_size), | ||
| ), # Akk_inv | ||
| ], |
There was a problem hiding this comment.
Rename the l grid axis before this hits Ruff.
The index_map lambdas use l, which triggers E741 on every BlockSpec here and in the backward call. Renaming it to something like chunk_idx keeps the file lint-clean.
Also applies to: 528-575
🧰 Tools
🪛 Ruff (0.15.9)
[error] 275-275: Ambiguous variable name: l
(E741)
[error] 279-279: Ambiguous variable name: l
(E741)
[error] 283-283: Ambiguous variable name: l
(E741)
[error] 287-287: Ambiguous variable name: l
(E741)
[error] 291-291: Ambiguous variable name: l
(E741)
[error] 295-295: Ambiguous variable name: l
(E741)
[error] 301-301: Ambiguous variable name: l
(E741)
[error] 305-305: Ambiguous variable name: l
(E741)
[error] 309-309: Ambiguous variable name: l
(E741)
[error] 313-313: Ambiguous variable name: l
(E741)
[error] 317-317: Ambiguous variable name: l
(E741)
[error] 321-321: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/intra_chunk.py` around lines 273 - 324, The index_map lambdas
inside the BlockSpec declarations (e.g., the in_specs/out_specs and their
counterparts in the backward call) use the variable name "l" which triggers Ruff
E741; rename that parameter to a non-conflicting name such as "chunk_idx" in
every index_map lambda (e.g., change "lambda i, j, l: ..." to "lambda i, j,
chunk_idx: ..." and update any uses inside the lambda accordingly) for all
BlockSpec entries referenced (including u, w, qg, kg, Aqk, Akk_inv and the
corresponding in_specs and backward definitions) so the file is lint-clean.
> Warning: size limit exceeded. This PR's core code change is 2334 lines, exceeding the 200-line limit (tests and docs excluded). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.

> Warning: size limit exceeded. This PR's core code change is 2015 lines, exceeding the 200-line limit (tests and docs excluded). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.
Actionable comments posted: 4
♻️ Duplicate comments (1)
tops/ops/kda/intra_chunk.py (1)
68-77: ⚠️ Potential issue | 🔴 Critical
Keep `dAkk` strictly lower-triangular.
`mask_akk` is built from `>=`, so diagonal `dAkk` terms survive even though the forward Akk graph only has strictly-lower entries. That leaks spurious diagonal gradient into `dbeta`, `dk`, and `dg`.
🔧 Minimal fix

```diff
 idx = jnp.arange(chunk_size, dtype=jnp.int32)
-causal_mask = idx[:, None] >= idx[None, :]
+causal_mask_qk = idx[:, None] >= idx[None, :]
+causal_mask_akk = idx[:, None] > idx[None, :]
 segment_mask = segment_ids[:, None] == segment_ids[None, :]
-mask_akk = causal_mask & segment_mask
+mask_akk = causal_mask_akk & segment_mask
 mask_aqk = causal_mask_qk & segment_mask
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/intra_chunk.py` around lines 68 - 77, The current mask_akk uses >= so diagonal elements pass through and allow spurious diagonal gradient into dbeta/dk/dg; change the construction of mask_akk to be strictly lower-triangular by using > instead of >= (i.e., set mask_akk = idx[:, None] > idx[None, :]) so dAkk becomes strictly lower-triangular before computing dAkk_masked; keep mask_aqk as-is if cross terms should include the diagonal. Ensure references: idx, causal_mask/causal_mask_qk, mask_akk, dAkk, and dAkk_masked are updated accordingly.
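The `>=` versus `>` distinction is easy to see on a toy chunk. This pure-Python sketch mirrors the fix, with list comprehensions standing in for the `jnp.arange` comparisons:

```python
def causal_masks(chunk_size):
    # mask_aqk keeps the diagonal (row >= col): a query attends to its own key
    mask_aqk = [[r >= c for c in range(chunk_size)] for r in range(chunk_size)]
    # mask_akk is strictly lower-triangular (row > col): the forward Akk graph
    # has no diagonal entries, so dAkk must not receive diagonal gradient either
    mask_akk = [[r > c for c in range(chunk_size)] for r in range(chunk_size)]
    return mask_aqk, mask_akk


mask_aqk, mask_akk = causal_masks(4)
assert all(mask_aqk[i][i] for i in range(4))      # diagonal kept for Aqk
assert not any(mask_akk[i][i] for i in range(4))  # diagonal zeroed for dAkk
```

Masking `dAkk` with the strict version removes exactly the diagonal leak described above while leaving all strictly-lower entries untouched.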
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 637-666: Move all input validation/assertions to the top of the
function before reading q.shape, computing acc dtype or doing any padding math
(i.e., before uses of q, _acc_dtype, chunk_size C and _cdiv). Specifically,
validate q.ndim, chunk_size > 0, and types/shapes for k, v, g, beta and
initial_state using assert statements or the helper assert_shape_or_none from
tops.utils; only after those checks compute orig_dtype = v.dtype, acc_dt =
_acc_dtype(q.dtype), unpack B,T_orig,H,K = q.shape and compute T_padded =
_cdiv(T_orig, C) * C. Ensure the same unique symbols (q, k, v, g, beta,
initial_state, chunk_size/C, _acc_dtype, _cdiv) are referenced so the assertions
guard all downstream operations.
- Around line 299-309: Reorder and strengthen the input guards so you validate
rank and chunk_size before unpacking/using shapes: assert q.ndim == 4 (or use
assert_shape_or_none) and assert chunk_size > 0 (C != 0) up front, then unpack
B, T, H, K = q.shape and only after that check T % C == 0 and the remaining
shape assertions for k, g, beta, dAqk, dAkk; replace any implicit operations
that can raise raw exceptions with explicit asserts (use assert_shape_or_none
from tops.utils if available) so q, chunk_size, and all tensor shapes are
validated in the function in chunk.py before any unpack/division occurs.
In `@tops/ops/kda/intra_chunk.py`:
- Around line 19-20: The module docstring incorrectly states the tensor layout
as [B, H, T, D]; update it to the correct layout [B, T, H, D] to match how
kda_intra_chunk_bwd and its callers expect tensors, ensuring the documented
transpose contract aligns with the implementation in functions like
kda_intra_chunk_bwd.
- Around line 181-205: The code unpacks B, T, H, D from k.shape before
validating tensor rank/dtype and pallas_call currently hard-codes k.dtype while
kernel reads dq from q.dtype and dbeta from beta.dtype; add upfront assertions
for q.ndim, k.ndim, g.ndim and beta.ndim and explicit dtype checks
(q.dtype==k.dtype==g.dtype where appropriate, and beta.dtype checked for dbeta),
or use tops.utils.assert_shape_or_none to validate shapes before B,T,H,D
assignment, and change the pallas_call invocation to use the correct output
dtypes (derive dq dtype from q.dtype and dbeta dtype from beta.dtype) rather
than always using k.dtype so mixed-dtype or bad-rank inputs fail early and
deterministically (refer to symbols: k, q, g, beta, dAqk, dAkk, segment_ids,
pallas_call, dq, dbeta, assert_shape_or_none).
---
Duplicate comments:
In `@tops/ops/kda/intra_chunk.py`:
- Around line 68-77: The current mask_akk uses >= so diagonal elements pass
through and allow spurious diagonal gradient into dbeta/dk/dg; change the
construction of mask_akk to be strictly lower-triangular by using > instead of
>= (i.e., set mask_akk = idx[:, None] > idx[None, :]) so dAkk becomes strictly
lower-triangular before computing dAkk_masked; keep mask_aqk as-is if cross
terms should include the diagonal. Ensure references: idx,
causal_mask/causal_mask_qk, mask_akk, dAkk, and dAkk_masked are updated
accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 624f24f5-8342-4cb0-b3fe-98211b852542
📒 Files selected for processing (6)
- tests/ops/kda/test_chunk_kda_tpu.py
- tests/ref/kda/test_chunk_kda.py
- tops/cpu/ops/kda/__init__.py
- tops/cpu/ops/kda/chunk.py
- tops/ops/kda/__init__.py
- tops/ops/kda/intra_chunk.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tops/cpu/ops/kda/__init__.py
```python
B, T, H, K = q.shape
C = chunk_size

assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"
assert dAqk.shape == (B, T, H, C), f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
assert dAkk.shape == (B, T, H, C), f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"
```
Validate rank and chunk_size before unpacking/using them.
B, T, H, K = q.shape and T % C both run before the public contract is checked, so a bad-rank q or chunk_size=0 raises a raw unpack/division error instead of the intended assertion.
🛡️ Suggested guard ordering

```diff
-B, T, H, K = q.shape
 C = chunk_size
-
 assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
+assert C > 0, f"chunk_size must be positive, got {C}"
+B, T, H, K = q.shape
 assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
 assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
 assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
C = chunk_size
assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
assert C > 0, f"chunk_size must be positive, got {C}"
B, T, H, K = q.shape
assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
assert g.shape == q.shape, f"g shape {g.shape} != q shape {q.shape}"
assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"
assert dAqk.shape == (B, T, H, C), f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
assert dAkk.shape == (B, T, H, C), f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"
```
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 299-299: Unpacked variable K is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/cpu/ops/kda/chunk.py` around lines 299 - 309, Reorder and strengthen the
input guards so you validate rank and chunk_size before unpacking/using shapes:
assert q.ndim == 4 (or use assert_shape_or_none) and assert chunk_size > 0 (C !=
0) up front, then unpack B, T, H, K = q.shape and only after that check T % C ==
0 and the remaining shape assertions for k, g, beta, dAqk, dAkk; replace any
implicit operations that can raise raw exceptions with explicit asserts (use
assert_shape_or_none from tops.utils if available) so q, chunk_size, and all
tensor shapes are validated in the function in chunk.py before any
unpack/division occurs.
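The ordering point can be isolated in a small hypothetical helper (`checked_dims` is illustrative, not project code): with guards ahead of the unpack, bad inputs fail with the contract message instead of a raw `ValueError` or `ZeroDivisionError`.

```python
def checked_dims(q_shape, chunk_size):
    """Validate rank and chunk_size BEFORE tuple-unpacking the shape."""
    # these two asserts must run first: a 3D shape or chunk_size=0 would
    # otherwise surface as an unpack error or division-by-zero below
    assert len(q_shape) == 4, f"q must be 4D [B,T,H,K], got {len(q_shape)}D"
    assert chunk_size > 0, f"chunk_size must be positive, got {chunk_size}"
    B, T, H, K = q_shape
    assert T % chunk_size == 0, f"T={T} must be divisible by chunk_size={chunk_size}"
    return B, T, H, K
```

The same reordering applies to both the forward entrypoint and the backward helper flagged in this review: nothing about the computation changes, only which exception a bad caller sees.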
```python
orig_dtype = v.dtype
acc_dt = _acc_dtype(q.dtype)
B, T_orig, H, K = q.shape
V = v.shape[-1]
C = chunk_size

# Shape assertions (project coding standard)
assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
assert v.ndim == 4 and v.shape[:3] == q.shape[:3], (
    f"v shape {v.shape} incompatible with q shape {q.shape}"
)
assert g.ndim == 4 and g.shape == q.shape, (
    f"g shape {g.shape} != q shape {q.shape}"
)
assert beta.ndim == 3 and beta.shape == q.shape[:3], (
    f"beta shape {beta.shape} != {q.shape[:3]}"
)
if initial_state is not None:
    assert initial_state.shape == (B, H, K, V), (
        f"initial_state shape {initial_state.shape} != ({B}, {H}, {K}, {V})"
    )

if scale is None:
    scale = K ** -0.5

# --- Pad T to multiple of chunk_size ---
T = T_orig
T_padded = _cdiv(T_orig, C) * C
if T_padded > T_orig:
```
Move API validation ahead of shape unpacking and padding math.
This entrypoint reads q.shape and computes _cdiv(T_orig, C) before any guardrails. Invalid ranks or chunk_size <= 0 will fail with low-level exceptions instead of a clear contract error.
🛡️ Suggested guard ordering

```diff
 orig_dtype = v.dtype
 acc_dt = _acc_dtype(q.dtype)
-B, T_orig, H, K = q.shape
-V = v.shape[-1]
 C = chunk_size
+assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
+assert C > 0, f"chunk_size must be positive, got {C}"
+B, T_orig, H, K = q.shape
+V = v.shape[-1]
 # Shape assertions (project coding standard)
-assert q.ndim == 4, f"q must be 4D [B,T,H,K], got {q.ndim}D"
 assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/cpu/ops/kda/chunk.py` around lines 637 - 666, Move all input
validation/assertions to the top of the function before reading q.shape,
computing acc dtype or doing any padding math (i.e., before uses of q,
_acc_dtype, chunk_size C and _cdiv). Specifically, validate q.ndim, chunk_size >
0, and types/shapes for k, v, g, beta and initial_state using assert statements
or the helper assert_shape_or_none from tops.utils; only after those checks
compute orig_dtype = v.dtype, acc_dt = _acc_dtype(q.dtype), unpack B,T_orig,H,K
= q.shape and compute T_padded = _cdiv(T_orig, C) * C. Ensure the same unique
symbols (q, k, v, g, beta, initial_state, chunk_size/C, _acc_dtype, _cdiv) are
referenced so the assertions guard all downstream operations.
```python
Tensor layout: [B, H, T, D] (batch, heads, time, head_dim).
"""
```
Fix the tensor layout in the module docstring.
The file header says [B, H, T, D], but kda_intra_chunk_bwd and its callers operate on [B, T, H, D]. Leaving this inverted will mislead the next caller into the wrong transpose contract.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/intra_chunk.py` around lines 19 - 20, The module docstring
incorrectly states the tensor layout as [B, H, T, D]; update it to the correct
layout [B, T, H, D] to match how kda_intra_chunk_bwd and its callers expect
tensors, ensuring the documented transpose contract aligns with the
implementation in functions like kda_intra_chunk_bwd.
```python
B, T, H, D = k.shape
C = chunk_size

assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
assert q.shape == (B, T, H, D), f"q shape {q.shape} != k shape {k.shape}"
assert g.shape == (B, T, H, D), f"g shape {g.shape} != ({B}, {T}, {H}, {D})"
assert beta.shape == (B, T, H), f"beta shape {beta.shape} != ({B}, {T}, {H})"
assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"

NT = T // C

assert dAqk.shape == (B, T, H, C), (
    f"dAqk shape {dAqk.shape} != ({B}, {T}, {H}, {C})"
)
assert dAkk.shape == (B, T, H, C), (
    f"dAkk shape {dAkk.shape} != ({B}, {T}, {H}, {C})"
)
if segment_ids is not None:
    assert segment_ids.shape == (B, T), (
        f"segment_ids shape {segment_ids.shape} != ({B}, {T})"
    )

if segment_ids is None:
    segment_ids = jnp.zeros((B, T), dtype=jnp.int32)
```
Validate rank and dtype before deriving kernel shapes.
B, T, H, D = k.shape runs before any rank check, and pallas_call hard-codes k.dtype for every output while the kernel computes dq from q.dtype and dbeta from beta.dtype. Bad-rank or mixed-dtype calls will fail late or silently coerce buffers.
🛡️ Suggested contract checks

```diff
-B, T, H, D = k.shape
 C = chunk_size
-
-assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
+assert q.ndim == 4, f"q must be 4D [B,T,H,D], got {q.ndim}D"
+assert k.ndim == 4, f"k must be 4D [B,T,H,D], got {k.ndim}D"
+assert C > 0, f"chunk_size must be positive, got {C}"
+assert q.dtype == k.dtype == g.dtype == beta.dtype == dAqk.dtype == dAkk.dtype, (
+    "q/k/g/beta/dAqk/dAkk must share dtype"
+)
+B, T, H, D = k.shape
 assert q.shape == (B, T, H, D), f"q shape {q.shape} != k shape {k.shape}"
 assert g.shape == (B, T, H, D), f"g shape {g.shape} != ({B}, {T}, {H}, {D})"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/intra_chunk.py` around lines 181 - 205, The code unpacks B, T,
H, D from k.shape before validating tensor rank/dtype and pallas_call currently
hard-codes k.dtype while kernel reads dq from q.dtype and dbeta from beta.dtype;
add upfront assertions for q.ndim, k.ndim, g.ndim and beta.ndim and explicit
dtype checks (q.dtype==k.dtype==g.dtype where appropriate, and beta.dtype
checked for dbeta), or use tops.utils.assert_shape_or_none to validate shapes
before B,T,H,D assignment, and change the pallas_call invocation to use the
correct output dtypes (derive dq dtype from q.dtype and dbeta dtype from
beta.dtype) rather than always using k.dtype so mixed-dtype or bad-rank inputs
fail early and deterministically (refer to symbols: k, q, g, beta, dAqk, dAkk,
segment_ids, pallas_call, dq, dbeta, assert_shape_or_none).
> Warning: size limit exceeded. This PR's core code change is 1028 lines, exceeding the 200-line limit (tests and docs excluded). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd-intra.md`:
- Around line 15-27: The fenced code blocks containing the ASCII diagram
(starting with "chunk_kda_bwd_wy_dqkg_fused ──→ ..." and the block starting with
"Grid: (NK * NC, NT, B * H)") are missing language identifiers; update each
triple-backtick fence to include a language tag (e.g., ```text) so linters
(MD040) and readers correctly render them, and apply the same fix to the other
similar block later in the file (the block around the "Grid: ..." section
referenced in the review).
- Line 184: The implementation uses two different variable names for the same
reverse cumsum input causing ambiguity; standardize to one name (either dg or
dg2) across the file by replacing usages of the other with the chosen identifier
so both calls to chunk_local_cumsum use the same variable (e.g., ensure
chunk_local_cumsum(dg, reverse=True) and any downstream code that reads/writes
dg2 are updated to dg), and update any related variable assignment sites and
comments referencing dg2 to use dg so chunk_local_cumsum and its consumers are
consistent.
- Line 11: The heading "### 定位与职责" jumps directly to H3 after the front matter;
change it to an H2 (replace "###" with "## 定位与职责") so the document follows a
proper top-level section after the front matter and update any subsequent
sibling headings to maintain consistent incremental levels relative to this
section.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5c5ab626-fcfe-4d3d-9ffc-b7ef7d93d6c5
📒 Files selected for processing (3)
- .gitignore
- docs/design-docs/ops/kda/chunk-bwd-intra.md
- tops/cpu/ops/kda/__init__.py
✅ Files skipped from review due to trivial changes (1)
- .gitignore
🚧 Files skipped from review as they are similar to previous changes (1)
- tops/cpu/ops/kda/__init__.py
```markdown
---

### 定位与职责
```
Fix heading level jump to satisfy Markdown structure
Line 11 starts at ### directly after front matter/separators; this trips MD001 and hurts doc outline consistency.
Proposed fix

```diff
-### 定位与职责
+## 定位与职责
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```markdown
## 定位与职责
```
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3
(MD001, heading-increment)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd-intra.md` at line 11, The heading "###
定位与职责" jumps directly to H3 after the front matter; change it to an H2 (replace
"###" with "## 定位与职责") so the document follows a proper top-level section after
the front matter and update any subsequent sibling headings to maintain
consistent incremental levels relative to this section.
```text
chunk_kda_bwd_wy_dqkg_fused ──→ dq_inter, dk_inter, dg_inter, db_inter, dAkk
                                    │
chunk_kda_bwd_dAv ──→ dAqk          │
                                    ▼
                         chunk_kda_bwd_intra
                                    │
                                    ▼
dq = dq_inter + dq_intra
dk = dk_inter + dk_intra
db = db_inter + db_intra
dg = reverse_cumsum(dg_inter + dg_intra)
```
Add language identifiers to fenced code blocks
Both fenced blocks are missing a language tag (MD040), which reduces readability and lint compliance.
Proposed fix
Proposed fix

````diff
-```
+```text
 chunk_kda_bwd_wy_dqkg_fused ──→ dq_inter, dk_inter, dg_inter, db_inter, dAkk
                                     │
 chunk_kda_bwd_dAv ──→ dAqk          │
                                     ▼
                          chunk_kda_bwd_intra
                                     │
                                     ▼
 dq = dq_inter + dq_intra
 dk = dk_inter + dk_intra
 db = db_inter + db_intra
 dg = reverse_cumsum(dg_inter + dg_intra)
 ```
````

````diff
-```
+```text
 Grid: (NK * NC, NT, B * H)
 NK = ceil(K / BK)    number of blocks along head_dim
 NC = BT / BC         sub-blocks per chunk (typically 64/16 = 4)
 NT                   number of chunks
 B * H                batch × head
 ```
````

Also applies to: 192-198
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 15-15: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd-intra.md` around lines 15 - 27, The fenced
code blocks containing the ASCII diagram (starting with
"chunk_kda_bwd_wy_dqkg_fused ──→ ..." and the block starting with "Grid: (NK *
NC, NT, B * H)") are missing language identifiers; update each triple-backtick
fence to include a language tag (e.g., ```text) so linters (MD040) and readers
correctly render them, and apply the same fix to the other similar block later
in the file (the block around the "Grid: ..." section referenced in the review).
$$dg_r^{\text{raw}} = \sum_{j=r}^{C} dg_j$$

实现中在 kernel 外通过 `chunk_local_cumsum(dg, reverse=True)` 完成。这一步将 chunk 内的 $`d\mathbf{g}`$ 转换为对原始 $`\mathbf{g}_{\text{raw}}`$ 的梯度。
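The suffix sum the formula describes, written out in plain Python for illustration (`reverse_cumsum` here is a toy; `chunk_local_cumsum(..., reverse=True)` computes the same thing per chunk, vectorized):

```python
def reverse_cumsum(dg):
    """dg_raw[r] = sum over j >= r of dg[j], accumulated from the chunk's end."""
    out = [0.0] * len(dg)
    acc = 0.0
    for r in range(len(dg) - 1, -1, -1):
        acc += dg[r]
        out[r] = acc
    return out


# each position receives the total gradient of all positions at or after it
assert reverse_cumsum([1.0, 2.0, 3.0]) == [6.0, 5.0, 3.0]
```

This is the step that converts the per-position cumulative-gate gradient back into a gradient with respect to the raw per-token gates.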
Unify reverse-cumsum variable naming (dg vs dg2)
Line 184 uses chunk_local_cumsum(dg, reverse=True), but Line 333 uses chunk_local_cumsum(dg2, reverse=True). Keep one name to avoid implementation ambiguity.
Proposed fix

```diff
-实现中在 kernel 外通过 `chunk_local_cumsum(dg, reverse=True)` 完成。这一步将 chunk 内的 $`d\mathbf{g}`$ 转换为对原始 $`\mathbf{g}_{\text{raw}}`$ 的梯度。
+实现中在 kernel 外通过 `chunk_local_cumsum(dg2, reverse=True)` 完成。这一步将 chunk 内的 $`d\mathbf{g}`$ 转换为对原始 $`\mathbf{g}_{\text{raw}}`$ 的梯度。
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
实现中在 kernel 外通过 `chunk_local_cumsum(dg2, reverse=True)` 完成。这一步将 chunk 内的 $`d\mathbf{g}`$ 转换为对原始 $`\mathbf{g}_{\text{raw}}`$ 的梯度。
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd-intra.md` at line 184, The implementation
uses two different variable names for the same reverse cumsum input causing
ambiguity; standardize to one name (either dg or dg2) across the file by
replacing usages of the other with the chosen identifier so both calls to
chunk_local_cumsum use the same variable (e.g., ensure chunk_local_cumsum(dg,
reverse=True) and any downstream code that reads/writes dg2 are updated to dg),
and update any related variable assignment sites and comments referencing dg2 to
use dg so chunk_local_cumsum and its consumers are consistent.
Summary

- Pallas intra-chunk fwd/bwd kernel (tops/ops/kda/intra_chunk.py)
- Chunk pipeline CPU reference (tops/cpu/ops/kda/chunk_pipeline.py)

Test plan

- `uv run pytest tests/ops/kda/test_kda_alignment.py -v` — 5/5 passed
- `uv run pytest tests/ops/kda/test_pallas_intra_chunk.py -v` — 8/8 passed
- `uv run ruff check` — all checks passed

🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests
Chores