feat(kda): add KDA Pallas kernel with chunked forward + backward#174
Conversation
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Update project description to be architecture-agnostic
- Remove numpy from base deps (transitive via flax)
- Move pytest back to base deps (needed everywhere)
- Remove unused torchvision dependency
- Remove non-existent jax[gpu] and flash-linear-attention[gpu] extras
- Remove unused profile optional-dependency group
- Remove torchvision from uv sources
- Switch CI GPU tests to uv sync
- Complete tpuv7 K8s Pod uv setup flow
- Update CLAUDE.md install commands

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rule state Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Migrates kda_intra_kernel.py into tops/ops/kda/chunk_intra.py, replacing hardcoded interpret=True with get_interpret() and adding docstrings plus assert_shape validation on all public launcher functions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add chunk_kda function wrapping chunk_kda_fwd with jax.custom_vjp for backward pass support. The initial backward uses JAX autograd through the forward; explicit backward kernels will follow. Update exports in tops/ops/kda/__init__.py and tops/ops/__init__.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Also fix fused_recurrent_kda to use acc_dtype (fp64 for fp64 inputs, fp32 otherwise) matching naive_kda's precision contract, instead of hardcoding float32 for all dtypes. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Fix exp/exp2 mismatch in chunk_gla_fwd_o_gk_pl_kernel: the midpoint stabilization factor for h was always using jnp.exp even when USE_EXP2=True (KDA log2 space gates), causing wrong output for multi-chunk sequences. - Fix backward pass: use fused_recurrent_kda (lax.scan, fully differentiable) instead of chunk_kda_fwd (pallas_call, not differentiable) for gradient computation. - Adjust T=128 forward test case to K=16 (from K=32) to avoid float32 precision issues with extreme value growth. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
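The exp/exp2 distinction above is easy to check: KDA stores gate cumsums in log2 space, so the stabilization factor must use `exp2`; applying `exp` to a log2-space value silently rescales it. A quick numpy illustration (the gate value is made up):

```python
import numpy as np

# A gate value in natural-log space, and the same gate stored in log2 space
# (as KDA's chunked kernel does). The value is illustrative.
g_ln = np.float32(-0.7)
g_log2 = g_ln / np.log(2.0)  # log2-space representation: g_ln / ln(2)

# exp2 on the log2-space value recovers the intended factor...
assert np.isclose(np.exp2(g_log2), np.exp(g_ln))
# ...while exp on the log2-space value does not (the reported bug).
assert not np.isclose(np.exp(g_log2), np.exp(g_ln))
```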
Important: Review skipped (draft detected). Configuration used: defaults. Review profile: CHILL. Plan: Pro.
📝 Walkthrough

Adds chunked and fused-recurrent KDA implementations (TPU Pallas kernels and CPU helpers), intra-/inter-chunk kernels and an orchestrator with a custom VJP, new tests and design docs, and re-exports to the public ops packages.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User
    participant ChunkOrch as chunk_kda_fwd<br/>(Orchestrator)
    participant GateCumsum as kda_gate_chunk_cumsum<br/>(Gate Processing)
    participant IntraChunk as kda_intra_chunk_fwd<br/>(Intra-chunk Attention)
    participant InterChunk as chunk_gated_delta_rule_fwd_h<br/>(Inter-chunk State)
    participant OutputFusion as chunk_gla_fwd_o_gk<br/>(Output Assembly)
    User->>ChunkOrch: q, k, v, g, beta, initial_state
    ChunkOrch->>GateCumsum: g (natural-log)
    GateCumsum-->>ChunkOrch: g_cumsum (log2)
    ChunkOrch->>IntraChunk: q,k,v,g_cumsum,beta (per-chunk layout)
    IntraChunk-->>ChunkOrch: u, w, qg, kg, Aqk, Akk_inv
    ChunkOrch->>InterChunk: w, u, g_cumsum, initial_state
    InterChunk-->>ChunkOrch: h (per-chunk), v_new, final_state
    ChunkOrch->>OutputFusion: Aqk, v_new, qg, kg
    OutputFusion-->>ChunkOrch: o
    ChunkOrch-->>User: o, final_state
```
Estimated Code Review Effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces Kimi Delta Attention (KDA) Pallas TPU kernels, implementing a four-stage pipeline: gate cumsum, intra-chunk triangular solve, inter-chunk state propagation, and output computation. It also includes comprehensive tests for these new operations and fixes a bug in the GLA chunked kernel where the incorrect exponential function was used during hidden state scaling. Additionally, the PR updates CI scripts and project metadata to utilize uv for more efficient dependency management. Feedback focuses on optimizing the Pallas kernels to avoid potential VMEM overflows and reducing overhead from inefficient array operations like stacking and splitting within TPU loops.
```python
g_diff = g_f32[:, None, :] - g_f32[None, :, :]  # (C, C, D)
g_diff_safe = jnp.where(causal_mask_qk[:, :, None], g_diff, 0.0)
exp_g_diff = jnp.exp2(g_diff_safe)  # (C, C, D)

# Aqk[i,j] = scale * sum_d q[i,d] * k[j,d] * exp(g[i,d] - g[j,d])
Aqk_raw = jnp.sum(
    q.astype(jnp.float32)[:, None, :] * k.astype(jnp.float32)[None, :, :] * exp_g_diff,
    axis=-1,
).astype(dtype)  # (C, C)
```
The current implementation of Aqk_raw and Akk_raw creates a large intermediate tensor exp_g_diff of shape [C, C, D]. For a typical chunk_size=128 and head_dim=256, this tensor alone consumes 128 * 128 * 256 * 4 bytes = 16MB. Since Pallas kernels on TPU have limited VMEM (often 16MB or 32MB total), having multiple such intermediates (for Aqk and Akk) will likely lead to out-of-memory errors during compilation or execution.
Consider refactoring this to compute the gated dot product in blocks over the head dimension D to keep the VMEM footprint manageable.
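A sketch of that blocked accumulation, using numpy as a stand-in for the Pallas/jnp code (`C`, `D`, and the block width `D_BLK` are illustrative, not the kernel's actual values):

```python
import numpy as np

rng = np.random.default_rng(0)
C, D, D_BLK = 8, 32, 8  # illustrative sizes
q = rng.standard_normal((C, D)).astype(np.float32)
k = rng.standard_normal((C, D)).astype(np.float32)
g = rng.standard_normal((C, D)).astype(np.float32)

# Reference: materializes the full (C, C, D) intermediate.
exp_g_diff = np.exp2(g[:, None, :] - g[None, :, :])
Aqk_full = np.sum(q[:, None, :] * k[None, :, :] * exp_g_diff, axis=-1)

# Blocked: accumulate into (C, C), never holding more than (C, C, D_BLK).
Aqk_blocked = np.zeros((C, C), dtype=np.float32)
for d0 in range(0, D, D_BLK):
    qs, ks, gs = q[:, d0:d0 + D_BLK], k[:, d0:d0 + D_BLK], g[:, d0:d0 + D_BLK]
    Aqk_blocked += np.sum(
        qs[:, None, :] * ks[None, :, :] * np.exp2(gs[:, None, :] - gs[None, :, :]),
        axis=-1,
    )

assert np.allclose(Aqk_full, Aqk_blocked, atol=1e-4)
```

The sum over `d` is linear in the head dimension, so splitting it into blocks changes only the peak memory footprint, not the result.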
```python
rows = [x_block[r] for r in range(B)]
for j in range(B):
    if j > 0:
        vec = A_ii[j, :j][None, :]
        mat = jnp.stack(rows[:j])
        correction = jax.lax.dot_general(
            vec, mat,
            (((1,), (0,)), ((), ())),
            preferred_element_type=jnp.float32
        ).squeeze(axis=0)
        rows[j] = rows[j] - correction
```
Using a list of arrays and jnp.stack inside a loop within a Pallas kernel is inefficient and may lead to excessive VMEM allocations or compilation overhead. In Pallas/XLA, it is generally better to work with a single array and use indexing or jax.lax.dynamic_update_slice (or at[j].set if supported in the context) to update the solution in-place.
Since B=16 is small and static, the loop will unroll, but the repeated stacking of rows creates many intermediate arrays.
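A minimal sketch of the single-array alternative, in numpy for brevity (in the actual JAX kernel the update line would become `x = x.at[j].set(x[j] - correction)` or a `jax.lax.dynamic_update_slice`; sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
N, D = 16, 4  # illustrative sizes
A = np.tril(rng.standard_normal((N, N)), k=-1)  # strictly lower triangular
b = rng.standard_normal((N, D))

# Forward substitution for (I + A) x = b on a single array:
# no per-row Python list and no repeated stacking.
x = b.copy()
for j in range(1, N):
    x[j] = x[j] - A[j, :j] @ x[:j]

# The solution satisfies the original system.
assert np.allclose((np.eye(N) + A) @ x, b, atol=1e-8)
```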
```python
blocks = jnp.split(b, num_blocks, axis=0)

for i in range(num_blocks):
    start = i * B
    end = (i + 1) * B

    A_ii = A[start:end, start:end]
    x_block = blocks[i]

    rows = [x_block[r] for r in range(B)]
    for j in range(B):
        if j > 0:
            vec = A_ii[j, :j][None, :]
            mat = jnp.stack(rows[:j])
            correction = jax.lax.dot_general(
                vec, mat,
                (((1,), (0,)), ((), ())),
                preferred_element_type=jnp.float32
            ).squeeze(axis=0)
            rows[j] = rows[j] - correction

    x_block = jnp.stack(rows)
    blocks[i] = x_block

    if i < num_blocks - 1:
        rest_start = (i + 1) * B

        x_rest = jnp.concatenate(blocks[i+1:], axis=0)
        A_rest = A[rest_start:, start:end]

        update = jax.lax.dot_general(
            A_rest, x_block,
            (((1,), (0,)), ((), ())),
            preferred_element_type=jnp.float32
        )
        x_rest = x_rest - update

        remaining_blocks_count = num_blocks - 1 - i
        new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)
```
The solve_unit_lower_triangular function uses jnp.split and jnp.concatenate on VMEM arrays. These operations are often expensive on TPU as they can trigger data movement or reshapes that are not well-optimized by the Pallas compiler compared to direct indexing.
Consider refactoring the block-based substitution to use direct indexing into the b and A arrays instead of splitting them into lists of blocks.
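One possible shape of that refactor, sketched in numpy with direct slicing into `A` and the solution array instead of maintaining a list of blocks (the block size `B` and sizes are illustrative, and divisibility of `N` by `B` is assumed):

```python
import numpy as np

rng = np.random.default_rng(2)
N, D, B = 32, 4, 16  # illustrative; assumes B divides N
A = np.tril(rng.standard_normal((N, N)), k=-1)  # strictly lower triangular
b = rng.standard_normal((N, D))

x = b.copy()
for i in range(0, N, B):
    # Solve the unit-lower-triangular diagonal block in place.
    for j in range(i + 1, i + B):
        x[j] = x[j] - A[j, i:j] @ x[i:j]
    # Propagate this block's solution into all remaining rows at once.
    if i + B < N:
        x[i + B:] = x[i + B:] - A[i + B:, i:i + B] @ x[i:i + B]

assert np.allclose((np.eye(N) + A) @ x, b, atol=1e-6)
```

Every read and write is a contiguous slice, so no split/concatenate round-trips are needed.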
Actionable comments posted: 2
🧹 Nitpick comments (9)
scripts/launch-tpuv7.yml (1)
**61-62**: Minor: inconsistent comment language. Lines 61-62 have Chinese comments while the rest of the file uses English. Consider translating for consistency:
```diff
-      secretName: ssh-key-secret   # 必须与你 kubectl create secret 时起的名字一致
-      defaultMode: 0600            # 尝试直接设定权限 (虽然 cp 方式更保险)
+      secretName: ssh-key-secret   # Must match the name used in kubectl create secret
+      defaultMode: 0600            # Set permissions directly (though the cp method is safer)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/launch-tpuv7.yml` around lines 61-62: the inline comments next to the secret configuration keys (secretName and defaultMode) are in Chinese while the rest of the file uses English; update the comments adjacent to secretName and defaultMode to English for consistency (e.g., "must match the name used when creating the kubectl secret" and "attempt to set permissions directly (cp approach is safer)") so the file's comment language is uniform.

tops/ops/kda/fused_recurrent.py (1)
**104-105**: Note: `final_state` dtype differs from the input dtype. The output `o` is cast back to `q.dtype` (line 102), but `final_state` remains in the accumulator dtype (`float32` or `float64`). This is likely intentional for precision in state continuation, but worth documenting if callers expect matching dtypes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/fused_recurrent.py` around lines 104-105: the returned final_state currently remains in the accumulator dtype (float32/float64) while the output o is cast to q.dtype; update the function to either (A) cast final_S to q.dtype before assigning final_state (e.g. final_state = final_S.astype(q.dtype) when output_final_state is True) to ensure matching dtypes, or (B) explicitly document in the function docstring/inline comment (referencing final_state, final_S, o, and q.dtype) that final_state intentionally preserves the accumulator precision so callers know it may differ from q.dtype.

tops/ops/kda/chunk_intra.py (1)
**253-270**: Rename ambiguous variable `l` to `c` or `chunk_idx`. The variable `l` in the lambdas is flagged by static analysis as ambiguous (it can be confused with the digit `1`). Consider renaming for clarity.

🔧 Suggested fix (showing the pattern for all occurrences)
```diff
         in_specs=[
-            pl.BlockSpec(index_map=lambda i, j, l: (i, j, l, 0, 0), block_shape=(1, 1, 1, chunk_size, K)),  # q
+            pl.BlockSpec(index_map=lambda i, j, c: (i, j, c, 0, 0), block_shape=(1, 1, 1, chunk_size, K)),  # q
             # ... apply the same rename to all other BlockSpec lambdas
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/chunk_intra.py` around lines 253-270: the index_map lambdas in the pl.BlockSpec entries use the ambiguous variable name "l"; rename that parameter to a clearer name like "c" or "chunk_idx" across all in_specs and out_specs to avoid confusion (e.g., change index_map=lambda i, j, l: ... to index_map=lambda i, j, c: ...), and update any corresponding tuple construction inside those lambdas (references in pl.BlockSpec usages for q, k, g, beta, v, segment_ids, u, w, qg, kg, Aqk, Akk_inv) so behavior remains identical.

tops/ops/kda/chunk.py (3)
**216-236**: Prefix unused unpacked variables with an underscore. `d_final_state`, `chunk_size`, and `cu_seqlens` are unpacked but unused in the backward pass. This is flagged by static analysis.

🔧 Suggested fix
```diff
 def _chunk_kda_bwd_custom(saved, grad_outputs):
-    do, d_final_state = grad_outputs
-    q, k, v, g, beta, scale, initial_state, chunk_size, cu_seqlens = saved
+    do, _d_final_state = grad_outputs
+    q, k, v, g, beta, scale, initial_state, _chunk_size, _cu_seqlens = saved
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/chunk.py` around lines 216 - 236, In _chunk_kda_bwd_custom, some items from the saved tuple are unpacked but not used (d_final_state, chunk_size, cu_seqlens); update the unpacking to prefix these with underscores (e.g., _d_final_state, _chunk_size, _cu_seqlens) so they are clearly marked as unused and satisfy static analysis, leaving the rest of the logic (use of q, k, v, g, beta, initial_state and the jax.grad call) unchanged.
**111-125**: Prefix unused variables with an underscore. `qg_bht` and `Akk_inv` are unpacked but never used. Prefix them with `_` to indicate they are intentionally unused.

🔧 Suggested fix
```diff
-    u_bht, w_bht, qg_bht, kg_bht, Aqk, Akk_inv = kda_intra_chunk_fwd(
+    u_bht, w_bht, _qg_bht, kg_bht, Aqk, _Akk_inv = kda_intra_chunk_fwd(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/chunk.py` around lines 111 - 125, Unpack kda_intra_chunk_fwd results but mark the unused outputs as intentionally unused by prefixing qg_bht and Akk_inv with underscores (e.g., change qg_bht -> _qg_bht and Akk_inv -> _Akk_inv) in the assignment where kda_intra_chunk_fwd(...) is called; keep the other variable names (u_bht, w_bht, kg_bht, Aqk) unchanged so behavior is identical but linter/reader signals the two values are unused (also update the inline comment if present to reflect the renamed unused symbols).
**22**: Remove unused import `pad_to_multiple`. This import is not used anywhere in the file.
🔧 Suggested fix
```diff
-from tops.utils import assert_shape, assert_shape_or_none, pad_to_multiple
+from tops.utils import assert_shape, assert_shape_or_none
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/chunk.py` at line 22: the import list in tops/ops/kda/chunk.py includes an unused symbol pad_to_multiple; remove pad_to_multiple from the from tops.utils import ... statement so only the used helpers (assert_shape, assert_shape_or_none) are imported, ensuring no other references to pad_to_multiple remain in the file.

tops/ops/common/chunk_delta_h.py (1)
**217**: Use `assert IS_VARLEN` instead of `assert IS_VARLEN == True`. Per static analysis and Python idiom, use the truthiness check directly.
🔧 Suggested fix
```diff
-    assert IS_VARLEN == True
+    assert IS_VARLEN
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/chunk_delta_h.py` at line 217: the assertion currently uses an explicit comparison "assert IS_VARLEN == True"; change it to a simple truthiness check by replacing that expression with "assert IS_VARLEN" to follow Python idioms and satisfy static analysis; locate the assertion referencing the IS_VARLEN symbol in chunk_delta_h.py (the assertion line that currently compares IS_VARLEN to True) and update it accordingly.

tests/ops/kda/test_chunk_kda_tpu.py (2)
**104-109**: Add `strict=True` to `zip()` for defensive iteration. Per the static analysis hint and best practice, adding `strict=True` ensures the loop fails fast if the iterables have mismatched lengths.

🔧 Suggested fix
```diff
     names = ["dq", "dk", "dv", "dg", "dbeta"]
-    for name, gn, gc in zip(names, grads_naive, grads_chunk):
+    for name, gn, gc in zip(names, grads_naive, grads_chunk, strict=True):
         assert compare_tensor(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/kda/test_chunk_kda_tpu.py` around lines 104 - 109, The loop that zips names, grads_naive, and grads_chunk should use zip(..., strict=True) to fail fast on mismatched lengths; update the for loop that iterates over names, grads_naive, grads_chunk (the block invoking compare_tensor for each name, gn, gc) to call zip(names, grads_naive, grads_chunk, strict=True) so any length mismatch raises immediately and prevents silent mis-alignment.
**75-83**: Tolerances are quite loose at 5e-2; verify this is expected. The forward test uses `atol=5e-2, rtol=5e-2`, which is relatively permissive compared to the fused recurrent tests (1e-10 for fp64, 5e-5 for fp32). If this is due to the chunked approximation introducing more numerical divergence, consider adding a brief comment explaining why these tolerances are necessary.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/kda/test_chunk_kda_tpu.py` around lines 75 - 83, The test uses loose tolerances (atol=5e-2, rtol=5e-2) in the compare_tensor checks for "output" and "final_state" (variables o_naive, o_chunk, s_naive, s_chunk) — either tighten these to match the fused recurrent test thresholds (e.g., 1e-10 for fp64 or 5e-5 for fp32) if chunking should be numerically equivalent, or add a short inline comment above the assertions explaining that the chunked KDA/TPU implementation intentionally incurs greater numerical divergence and therefore requires the relaxed tolerances; update the comment to reference why fp dtype and chunking cause the chosen atol/rtol and keep the compare_tensor calls unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 37ed064d-ae9b-445c-ba5a-0fbe7c90a7d9
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (16)
- .github/ci/gpu-tests.sky.yaml
- CLAUDE.md
- pyproject.toml
- scripts/launch-tpuv7.yml
- tests/ops/kda/__init__.py
- tests/ops/kda/test_chunk_kda_tpu.py
- tests/ops/kda/test_fused_recurrent_kda_tpu.py
- tops/ops/__init__.py
- tops/ops/common/chunk_delta_h.py
- tops/ops/gla/chunk.py
- tops/ops/kda/__init__.py
- tops/ops/kda/chunk.py
- tops/ops/kda/chunk_intra.py
- tops/ops/kda/fused_recurrent.py
- tops/ops/kda/gate.py
- tops/ops/kda/naive.py
```python
def solve_unit_lower_triangular(A, b):
    """
    Solves (I + A) x = b for x, where A is strictly lower triangular.
    Uses block-based forward substitution for better performance on TPU.

    Args:
        A: (N, N) strictly lower triangular matrix in VMEM.
        b: (N, D) matrix in VMEM.

    Returns:
        x: (N, D) solution matrix.
    """
    N, D = b.shape
    B = 16
    num_blocks = N // B
    A = A.astype(jnp.float32)
    b = b.astype(jnp.float32)

    blocks = jnp.split(b, num_blocks, axis=0)

    for i in range(num_blocks):
        start = i * B
        end = (i + 1) * B

        A_ii = A[start:end, start:end]
        x_block = blocks[i]

        rows = [x_block[r] for r in range(B)]
        for j in range(B):
            if j > 0:
                vec = A_ii[j, :j][None, :]
                mat = jnp.stack(rows[:j])
                correction = jax.lax.dot_general(
                    vec, mat,
                    (((1,), (0,)), ((), ())),
                    preferred_element_type=jnp.float32
                ).squeeze(axis=0)
                rows[j] = rows[j] - correction

        x_block = jnp.stack(rows)
        blocks[i] = x_block

        if i < num_blocks - 1:
            rest_start = (i + 1) * B

            x_rest = jnp.concatenate(blocks[i+1:], axis=0)
            A_rest = A[rest_start:, start:end]

            update = jax.lax.dot_general(
                A_rest, x_block,
                (((1,), (0,)), ((), ())),
                preferred_element_type=jnp.float32
            )
            x_rest = x_rest - update

            remaining_blocks_count = num_blocks - 1 - i
            new_blocks = jnp.split(x_rest, remaining_blocks_count, axis=0)

            for k, nb in enumerate(new_blocks):
                blocks[i + 1 + k] = nb

    x = jnp.concatenate(blocks, axis=0)
    return x
```
Block size B=16 assumes chunk_size is divisible by 16.
The solver uses a hardcoded block size of 16 and num_blocks = N // B. If chunk_size is not divisible by 16, jnp.split will fail. Consider adding an assertion or handling non-divisible cases.
🛡️ Suggested defensive assertion

```diff
 def solve_unit_lower_triangular(A, b):
     """..."""
     N, D = b.shape
     B = 16
+    assert N % B == 0, f"Matrix size N={N} must be divisible by block size B={B}"
     num_blocks = N // B
```
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 38-38: Unpacked variable D is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/chunk_intra.py` around lines 28 - 88, The function
solve_unit_lower_triangular assumes B=16 divides N by using num_blocks = N // B
and jnp.split, which will fail when chunk size isn't divisible by 16; either
assert divisibility or handle remainders. Fix by checking N % B == 0 at the
start (raise/assert with clear message) or change splitting logic to compute
num_blocks = math.ceil(N / B) and replace jnp.split with jnp.array_split (or pad
the last block) and ensure the block-processing loops (A_ii, rows logic, and
update of x_rest) correctly handle a final block size < B; update references to
B, num_blocks, blocks, A_ii, x_block, and x_rest accordingly.
```python
idx = jnp.arange(chunk_size, dtype=jnp.int32)
causal_mask = idx[:, None] >= idx[None, :]
causal_mask_qk = idx[:, None] >= idx[None, :]
segment_mask = segment_ids[:, None] == segment_ids[None, :]

mask_akk = causal_mask & segment_mask
mask_aqk = causal_mask_qk & segment_mask
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
# Check if the forward kernel has different mask definitions
rg -n "causal_mask.*>" tops/ops/kda/chunk_intra.py | head -20
```

Repository: primatrix/pallas-kernel

Length of output: 344
Fix inconsistent mask definitions: `causal_mask` should use `>` in the backward kernel to match the forward kernel.

Line 314 defines `causal_mask` with `>=` (includes the diagonal), but the forward kernel (line 111) uses `>` (strictly lower triangular). Line 315's `causal_mask_qk` correctly uses `>=` to match the forward kernel. Align line 314 to use `>` for consistency:

```python
causal_mask = idx[:, None] > idx[None, :]      # strictly lower triangular
causal_mask_qk = idx[:, None] >= idx[None, :]  # lower triangular with diagonal
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/kda/chunk_intra.py` around lines 313 - 319, Fix the inconsistent
mask: in tops/ops/kda/chunk_intra.py update the definition of causal_mask used
for the backward kernel (the variable named causal_mask that feeds mask_akk) to
use a strict greater-than comparison (idx[:, None] > idx[None, :]) instead of >=
so it matches the forward kernel behavior referenced at line 111; leave
causal_mask_qk (used for mask_aqk) as >= so the two masks (causal_mask ->
mask_akk and causal_mask_qk -> mask_aqk) reflect strictly lower-triangular vs
lower-triangular-with-diagonal semantics consistently.
Add detailed design documentation for KDA chunk-parallel forward and backward kernels, with mathematical derivations and FLA optimized implementation mapping. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 2
🧹 Nitpick comments (1)
docs/design-docs/ops/kda/chunk-fwd.md (1)
11-11: Fix heading level jump at Line 11.
`###` appears without a preceding `##`, which triggers MD001 and breaks heading structure consistency.

Proposed markdown fix:

```diff
-### 核心思想与背景
+## 核心思想与背景
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/design-docs/ops/kda/chunk-fwd.md` at line 11, The heading "核心思想与背景" uses "###" causing a level jump (MD001); change its markdown to a level-2 heading ("## 核心思想与背景") or add an appropriate level-2 parent heading above it so the document follows a consistent heading hierarchy; locate the heading text "核心思想与背景" and adjust its leading hashes (or insert the missing parent heading) to restore correct heading levels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Around line 9-11: The heading at "### 核心思想与背景" jumps a level (MD001); fix by
promoting it to a top-level section or adding its parent: either change "###
核心思想与背景" to "## 核心思想与背景" or insert a preceding "##" parent heading (e.g., "##
背景与概述") immediately before the existing "### 核心思想与背景" so the heading hierarchy
is valid.
In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Around line 324-327: The pseudocode incorrectly re-applies gk gating to the
value path: remove the b_v *= exp2(b_gk_last - b_gk) line so v_new is not gated
again (k-path already handles gating), and keep b_h *= exp2(b_gk_last) for state
decay only; reference the b_gk_last, b_gk, b_v, b_h and v_new symbols and align
the doc to the behavior in tops/ops/common/chunk_delta_h.py where gk only
affects state decay.
---
Nitpick comments:
In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Line 11: The heading "核心思想与背景" uses "###" causing a level jump (MD001); change
its markdown to a level-2 heading ("## 核心思想与背景") or add an appropriate level-2
parent heading above it so the document follows a consistent heading hierarchy;
locate the heading text "核心思想与背景" and adjust its leading hashes (or insert the
missing parent heading) to restore correct heading levels.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1f8917b0-2fbf-423d-8330-c64e33ec51b2
📒 Files selected for processing (2)
- docs/design-docs/ops/kda/chunk-bwd.md
- docs/design-docs/ops/kda/chunk-fwd.md
> ### 核心思想与背景
>
> The goal of the backward pass: given the gradient of the loss with respect to the output, $`\frac{\partial \mathcal{L}}{\partial \mathbf{o}}`$ (written $`d\mathbf{o}`$), compute the gradients of the loss with respect to all inputs $`\mathbf{q}, \mathbf{k}, \mathbf{v}, \mathbf{g}, \boldsymbol{\beta}`$, as well as the gradient with respect to the initial state $`\mathbf{h}_0`$.
Fix heading level jump to satisfy MD001
Line 9 starts with `###` without a preceding `##` section level, which breaks the heading hierarchy. Please introduce an `##` parent heading (or downgrade this heading level) before Line 9.
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3
(MD001, heading-increment)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 9 - 11, The heading at
"### 核心思想与背景" jumps a level (MD001); fix by promoting it to a top-level section
or adding its parent: either change "### 核心思想与背景" to "## 核心思想与背景" or insert a
preceding "##" parent heading (e.g., "## 背景与概述") immediately before the existing
"### 核心思想与背景" so the heading hierarchy is valid.
```text
# 3. Gated decay
b_gk_last = gk[chunk-final position]
b_v *= exp2(b_gk_last - b_gk)  # apply relative decay to v_new
b_h *= exp2(b_gk_last)         # apply chunk-end decay to the state
```
Correct gk-path pseudocode: v_new should not be gated again.
At Line 324-Line 327, the pseudocode applies exp2(b_gk_last - b_gk) to b_v in the gk flow. That conflicts with the documented/reference behavior in tops/ops/common/chunk_delta_h.py:44-73, where gk is used for state decay and k is already gated.
Proposed doc correction

```diff
 # 3. Gated decay
 b_gk_last = gk[chunk-final position]
-b_v *= exp2(b_gk_last - b_gk)  # apply relative decay to v_new
+# no extra gating of v_new under the gk path (k/kg already carry positional gating)
 b_h *= exp2(b_gk_last)         # apply chunk-end decay to the state only
```
+ b_h *= exp2(b_gk_last) # 仅对状态应用 chunk 末尾衰减📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # 3. 门控衰减 | |
| b_gk_last = gk[chunk 末尾位置] | |
| b_v *= exp2(b_gk_last - b_gk) # 对 v_new 应用相对衰减 | |
| b_h *= exp2(b_gk_last) # 对状态应用 chunk 末尾衰减 | |
| # 3. 门控衰减 | |
| b_gk_last = gk[chunk 末尾位置] | |
| # gk 路径下不对 v_new 额外门控(k/kg 已完成位置门控) | |
| b_h *= exp2(b_gk_last) # 仅对状态应用 chunk 末尾衰减 |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-fwd.md` around lines 324 - 327, The pseudocode
incorrectly re-applies gk gating to the value path: remove the b_v *=
exp2(b_gk_last - b_gk) line so v_new is not gated again (k-path already handles
gating), and keep b_h *= exp2(b_gk_last) for state decay only; reference the
b_gk_last, b_gk, b_v, b_h and v_new symbols and align the doc to the behavior in
tops/ops/common/chunk_delta_h.py where gk only affects state decay.
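The corrected recurrence can be written out in a toy numpy sketch (dimensions and names are illustrative; the real kernel operates per chunk in `jnp` with log2-space gates): `gk` decays only the state, while the value path receives no extra gating.

```python
import numpy as np

rng = np.random.default_rng(0)
C, K, V = 4, 3, 2               # toy chunk length, key dim, value dim
h = rng.normal(size=(K, V))     # running state entering the chunk
u = rng.normal(size=(C, V))     # pseudo-values from the intra-chunk stage
w = rng.normal(size=(C, K))     # beta/gate-weighted keys
kg = rng.normal(size=(C, K))    # gated keys (positional gating already folded in)
gk_last = rng.normal(size=(K,)) # cumulative log2 gate at the chunk end

v_new = u - w @ h                                  # value path: no extra gating
h = np.exp2(gk_last)[:, None] * h + kg.T @ v_new   # gk decays the state only

print(v_new.shape, h.shape)  # (4, 2) (3, 2)
```

This mirrors the structure described for `tops/ops/common/chunk_delta_h.py`, not that file's actual code.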
> Warning: size limit exceeded. This PR contains 2458 lines of core code changes, exceeding the 200-line limit (excluding tests and docs). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes within 200 lines.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tops/cpu/ops/kda/chunk.py (1)
549-564: ⚠️ Potential issue | 🟡 Minor — Add dtype assertions to the public API.
These checks only validate shapes right now. `chunk_kda` still accepts mixed dtypes across `q`/`k`/`v`/`g`/`beta`/`initial_state`, which can silently change promotion and make the CPU reference diverge from the kernel path.

🛡️ Suggested assertions

```diff
 assert beta.ndim == 3 and beta.shape == q.shape[:3], (
     f"beta shape {beta.shape} != {q.shape[:3]}"
 )
+assert k.dtype == q.dtype, f"k dtype {k.dtype} != q dtype {q.dtype}"
+assert v.dtype == q.dtype, f"v dtype {v.dtype} != q dtype {q.dtype}"
+assert g.dtype == q.dtype, f"g dtype {g.dtype} != q dtype {q.dtype}"
+assert beta.dtype == q.dtype, f"beta dtype {beta.dtype} != q dtype {q.dtype}"
 if initial_state is not None:
     assert initial_state.shape == (B, H, K, V), (
         f"initial_state shape {initial_state.shape} != ({B}, {H}, {K}, {V})"
     )
+    assert initial_state.dtype in (q.dtype, acc_dt), (
+        f"initial_state dtype {initial_state.dtype} must match q dtype {q.dtype} "
+        f"or accumulator dtype {acc_dt}"
+    )
```

As per coding guidelines, "All public functions must enforce strict input assertions on shape and types before executing main logic using `assert` instructions or utilities like `assert_shape_or_none` from `tops.utils`".
assertinstructions or utilities likeassert_shape_or_nonefromtops.utils".🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/cpu/ops/kda/chunk.py` around lines 549 - 564, Add strict dtype assertions in chunk_kda: assert q,k,v,g,beta all have the same dtype and that dtype is a supported floating type (e.g., torch.float16/float32/bfloat16), and if initial_state is not None assert initial_state.dtype == v.dtype and initial_state.dtype is the same supported float type; use the existing symbols q, k, v, g, beta, initial_state and add concise assert statements near the existing shape checks so mixed dtypes are rejected before main logic runs.
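The suggested guard can be sketched as a standalone helper. This is a hedged numpy illustration (`check_dtypes` is a hypothetical name; the real asserts would sit inline next to the existing shape checks in `chunk_kda`):

```python
import numpy as np

def check_dtypes(q, k, v, g, beta, initial_state=None):
    """Reject mixed or unsupported dtypes before the main logic runs."""
    assert q.dtype in (np.float16, np.float32, np.float64), f"unsupported dtype {q.dtype}"
    for name, t in (("k", k), ("v", v), ("g", g), ("beta", beta)):
        assert t.dtype == q.dtype, f"{name} dtype {t.dtype} != q dtype {q.dtype}"
    if initial_state is not None:
        assert initial_state.dtype == q.dtype, (
            f"initial_state dtype {initial_state.dtype} != q dtype {q.dtype}"
        )

x = np.zeros(3, dtype=np.float32)
check_dtypes(x, x, x, x, x)  # passes; a float64 k would raise AssertionError
```

Failing fast here keeps silent dtype promotion from making the CPU reference diverge from the kernel path.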
♻️ Duplicate comments (2)
docs/design-docs/ops/kda/chunk-fwd.md (1)
373-376: ⚠️ Potential issue | 🟠 Major — The `gk` pseudocode still gates `v_new` twice.
`tops/ops/common/chunk_delta_h.py:44-80` computes `v_new = u - w @ h` and then updates the state with gated keys; it does not apply an extra `exp2(b_gk_last - b_gk)` factor to `v_new`. Keeping Line 375 here documents a different recurrence than the implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/design-docs/ops/kda/chunk-fwd.md` around lines 373 - 376, The doc's pseudocode incorrectly applies an extra exp2(b_gk_last - b_gk) gating to v_new, which mismatches the implementation in tops/ops/common/chunk_delta_h.py where v_new is computed as v_new = u - w @ h and only the state update uses gated keys; remove the line that multiplies b_v by exp2(b_gk_last - b_gk) (or clarify that only b_h is decayed by exp2(b_gk_last)) so the pseudocode matches the implementation (referencing symbols gk, b_v, b_h, v_new, u, w, h and the function/logic in chunk_delta_h.py).

docs/design-docs/ops/kda/chunk-bwd.md (1)
9-9: ⚠️ Potential issue | 🟡 Minor — Fix the heading level jump.
This still starts with `###` before any `##` section, so MD001 remains unresolved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/design-docs/ops/kda/chunk-bwd.md` at line 9, The document has a level-jump: the heading "核心思想与背景" uses ### but no preceding ## exists, causing MD001; fix it by promoting that header to a top-level section (change ### 核心思想与背景 to ## 核心思想与背景) or add a preceding ## section before it so heading levels increase incrementally; update the heading token for "核心思想与背景" accordingly to resolve the MD001 rule.
🧹 Nitpick comments (1)
tops/cpu/ops/kda/chunk.py (1)
593-600: Avoid computing backward intermediates on the forward-only path.
`chunk_kda` now always materializes `Aqk` and `Akk`, then drops both values. That adds O(B·H·NT·C²) work and memory to every CPU reference forward even when callers only need `o` and `final_state`. A `return_intermediates=False` fast path would keep the reference cheaper and also clear the unused-variable warning.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/cpu/ops/kda/chunk.py` around lines 593 - 600, The forward path in chunk_kda is always materializing Aqk and Akk via chunk_kda_fwd then discarding them, which wastes O(B·H·NT·C²) compute/memory; add a fast-path flag (e.g., return_intermediates=False) to chunk_kda and only request/compute Aqk and Akk from chunk_kda_fwd when that flag is true (or when callers explicitly need intermediates), otherwise call chunk_kda_fwd with output_intermediates=False (or an equivalent param) and only capture o and final_state; update the chunk_kda signature and callers to default to not materializing intermediates and remove the allocations/assignments of Aqk/Akk on the forward-only path.
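The fast path suggested above can be sketched as follows. The `chunk_kda_fwd` below is a toy stand-in (its `output_intermediates` flag and return shapes are assumptions, not the repository's actual signature):

```python
import numpy as np

def chunk_kda_fwd(q, output_intermediates=True):
    """Toy stand-in for the real forward: returns (o, final_state[, Aqk, Akk])."""
    o, final_state = 2.0 * q, q.sum()
    if not output_intermediates:
        return o, final_state
    C = q.shape[-1]
    return o, final_state, np.zeros((C, C)), np.zeros((C, C))

def chunk_kda(q, return_intermediates=False):
    # forward-only callers skip the O(B·H·NT·C²) Aqk/Akk buffers entirely
    if return_intermediates:
        return chunk_kda_fwd(q, output_intermediates=True)
    return chunk_kda_fwd(q, output_intermediates=False)

q = np.ones((2, 4))
print(len(chunk_kda(q)), len(chunk_kda(q, return_intermediates=True)))  # 2 4
```

The key point is that the flag is forwarded into the forward function, so the intermediates are never allocated on the fast path rather than allocated and dropped.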
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Around line 671-676: The tables incorrectly list dAqk and dAkk as [B, T, H, C]
but they are chunk-local CxC matrices; update their shapes in the Stage 1 (and
also the entries at 686-695) to [B, H, NT, C, C] to match the symbol definitions
and referenced kernels (symbols: dAqk, dAkk); also verify dv remains [B, T, H,
V] and adjust any surrounding textual descriptions that rely on the old shape
notation.
- Around line 15-20: The fenced code blocks containing the pseudocode (the block
with "Step 0: g = cumsum(g_raw) ... Step 3: o = chunk_gla_fwd_o_gk(...)") and
the other similar blocks at the ranges mentioned should include an info string
(e.g., ```text or ```text+pseudo) to satisfy MD040; update each triple-backtick
opening fence for those blocks (including the blocks around lines referenced as
48-55 and 638-652) to include a language token so the linter stops flagging
them.
In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Line 11: The heading "### 核心思想与背景" uses h3 without a parent h2 which fails
MD001; update this heading to "## 核心思想与背景" (or insert a higher-level parent
section above it) so the document has a proper hierarchy—locate the line with
the exact text "### 核心思想与背景" and promote it to an h2 (change the leading ### to
##) or add an appropriate parent h2 before it.
---
Outside diff comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 549-564: Add strict dtype assertions in chunk_kda: assert
q,k,v,g,beta all have the same dtype and that dtype is a supported floating type
(e.g., torch.float16/float32/bfloat16), and if initial_state is not None assert
initial_state.dtype == v.dtype and initial_state.dtype is the same supported
float type; use the existing symbols q, k, v, g, beta, initial_state and add
concise assert statements near the existing shape checks so mixed dtypes are
rejected before main logic runs.
---
Duplicate comments:
In `@docs/design-docs/ops/kda/chunk-bwd.md`:
- Line 9: The document has a level-jump: the heading "核心思想与背景" uses ### but no
preceding ## exists, causing MD001; fix it by promoting that header to a
top-level section (change ### 核心思想与背景 to ## 核心思想与背景) or add a preceding ##
section before it so heading levels increase incrementally; update the heading
token for "核心思想与背景" accordingly to resolve the MD001 rule.
In `@docs/design-docs/ops/kda/chunk-fwd.md`:
- Around line 373-376: The doc's pseudocode incorrectly applies an extra
exp2(b_gk_last - b_gk) gating to v_new, which mismatches the implementation in
tops/ops/common/chunk_delta_h.py where v_new is computed as v_new = u - w @ h
and only the state update uses gated keys; remove the line that multiplies b_v
by exp2(b_gk_last - b_gk) (or clarify that only b_h is decayed by
exp2(b_gk_last)) so the pseudocode matches the implementation (referencing
symbols gk, b_v, b_h, v_new, u, w, h and the function/logic in
chunk_delta_h.py).
---
Nitpick comments:
In `@tops/cpu/ops/kda/chunk.py`:
- Around line 593-600: The forward path in chunk_kda is always materializing Aqk
and Akk via chunk_kda_fwd then discarding them, which wastes O(B·H·NT·C²)
compute/memory; add a fast-path flag (e.g., return_intermediates=False) to
chunk_kda and only request/compute Aqk and Akk from chunk_kda_fwd when that flag
is true (or when callers explicitly need intermediates), otherwise call
chunk_kda_fwd with output_intermediates=False (or an equivalent param) and only
capture o and final_state; update the chunk_kda signature and callers to default
to not materializing intermediates and remove the allocations/assignments of
Aqk/Akk on the forward-only path.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ebe2df40-8dad-48d8-ac2c-0595d4c26435
📒 Files selected for processing (4)
- docs/design-docs/ops/kda/chunk-bwd.md
- docs/design-docs/ops/kda/chunk-fwd.md
- tops/cpu/ops/kda/__init__.py
- tops/cpu/ops/kda/chunk.py
✅ Files skipped from review due to trivial changes (1)
- tops/cpu/ops/kda/__init__.py
```text
Step 0: g = cumsum(g_raw)                                       [chunk_local_cumsum]
Step 1: w, u, kg, Aqk, Akk = chunk_kda_fwd_intra(q, k, v, g, β) [chunk_kda_fwd_intra]
Step 2: h, v_new = chunk_gated_delta_rule_fwd_h(kg, w, u, gk=g) [chunk_gated_delta_rule_fwd_h]
Step 3: o = chunk_gla_fwd_o_gk(q, v_new, g, Aqk, h)             [chunk_gla_fwd_o_gk]
```
Add languages to the fenced code blocks.
These fences are missing an info string, so MD040 will keep firing. Even `text` is enough where the block is pseudocode.
Also applies to: 48-55, 638-652
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 15-15: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 15 - 20, The fenced code
blocks containing the pseudocode (the block with "Step 0: g = cumsum(g_raw) ...
Step 3: o = chunk_gla_fwd_o_gk(...)") and the other similar blocks at the ranges
mentioned should include an info string (e.g., ```text or ```text+pseudo) to
satisfy MD040; update each triple-backtick opening fence for those blocks
(including the blocks around lines referenced as 48-55 and 638-652) to include a
language token so the linter stops flagging them.
Stage 1 outputs:

| Tensor | Shape | Description |
|---|---|---|
| `dAqk` | `[B, T, H, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$ (lower triangular) |
| `dv` | `[B, T, H, V]` | $`\mathbf{A}_{qk}^\top \cdot d\mathbf{o}`$ |
The summary tables list the matrix-gradient shapes incorrectly.
`dAqk` and `dAkk` are chunk-local C × C matrices, so these entries should be `[B, H, NT, C, C]`, not `[B, T, H, C]`. The current table contradicts the symbol definitions above and the referenced kernels.

📝 Suggested doc fix

```diff
-| `dAqk` | `[B, T, H, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$ (lower triangular) |
+| `dAqk` | `[B, H, NT, C, C]` | $`d\mathbf{o} \cdot \mathbf{v}_{\text{new}}^\top`$ (lower triangular) |
 ...
-| `dAkk` | `[B, T, H, C]` | gradient of the inverse key-key matrix |
+| `dAkk` | `[B, H, NT, C, C]` | gradient of the inverse key-key matrix |
```

Also applies to: 686-695
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-bwd.md` around lines 671 - 676, The tables
incorrectly list dAqk and dAkk as [B, T, H, C] but they are chunk-local CxC
matrices; update their shapes in the Stage 1 (and also the entries at 686-695)
to [B, H, NT, C, C] to match the symbol definitions and referenced kernels
(symbols: dAqk, dAkk); also verify dv remains [B, T, H, V] and adjust any
surrounding textual descriptions that rely on the old shape notation.
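The corrected shape falls directly out of the chunked contraction. A numpy check with toy sizes (the chunked layout `[B, H, NT, C, ·]` is assumed here for illustration):

```python
import numpy as np

B, H, NT, C, V = 2, 3, 2, 4, 5
rng = np.random.default_rng(0)
do = rng.normal(size=(B, H, NT, C, V))      # d o, reshaped into NT chunks of length C
v_new = rng.normal(size=(B, H, NT, C, V))

# one C x C gradient matrix per (batch, head, chunk): contract over the value dim
dAqk = np.einsum("bhncv,bhnmv->bhncm", do, v_new)
print(dAqk.shape)  # (2, 3, 2, 4, 4) -> [B, H, NT, C, C]
```

The `[B, T, H, C]` shape in the table cannot hold a per-chunk C × C matrix, which is what makes the review's correction necessary.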
> This design document records the design and implementation logic of the chunk-parallel forward kernel for Kimi Delta Attention (KDA).
>
> ### 核心思想与背景
Fix the heading hierarchy.
`###` jumps straight from the title block to an h3. Promote this to `##` or insert a parent section so the document passes MD001.
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 11-11: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3
(MD001, heading-increment)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/design-docs/ops/kda/chunk-fwd.md` at line 11, The heading "### 核心思想与背景"
uses h3 without a parent h2 which fails MD001; update this heading to "##
核心思想与背景" (or insert a higher-level parent section above it) so the document has
a proper hierarchy—locate the line with the exact text "### 核心思想与背景" and promote
it to an h2 (change the leading ### to ##) or add an appropriate parent h2
before it.
# Conflicts:
#   tops/cpu/ops/kda/__init__.py
#   tops/cpu/ops/kda/chunk.py
> Warning: size limit exceeded. This PR contains 2497 lines of core code changes, exceeding the 200-line limit (excluding tests and docs). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes within 200 lines.
- Extract chunk_kda_bwd_intra as standalone Stage 4 function from chunk_kda_bwd - Add vectorized PyTorch reference (chunk_bwd_intra_ref.py): replace triple loop with 4 batched matmuls using decay factorization exp(g_i-g_j)=exp(g_i)·exp(-g_j) - Add test_chunk_kda.py: cross-validation, symmetry, and CPU-ref vs Triton tests - Fix OOM in large-shape Triton test: use numpy with contiguous layout for JAX→FLA conversion, run JAX ref on CPU for large tensors Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
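The decay factorization named in the commit, exp(g_i - g_j) = exp(g_i) · exp(-g_j), is what lets the triple loop collapse into batched matmuls. A toy numpy check (single head, one chunk, natural-exp gates; the real reference works in log2 space and batches over B, H, NT):

```python
import numpy as np

rng = np.random.default_rng(0)
C, K = 4, 3
g = np.cumsum(rng.normal(scale=0.1, size=(C, K)), axis=0)  # cumulative log-gates
q = rng.normal(size=(C, K))
k = rng.normal(size=(C, K))

# loop version: A[i, j] = sum_d q[i,d] * k[j,d] * exp(g[i,d] - g[j,d]) for j <= i
A_loop = np.zeros((C, C))
for i in range(C):
    for j in range(i + 1):
        A_loop[i, j] = np.sum(q[i] * k[j] * np.exp(g[i] - g[j]))

# batched version: pre-scale q by exp(g) and k by exp(-g), then one matmul + tril
A_mm = np.tril((q * np.exp(g)) @ (k * np.exp(-g)).T)

print(np.allclose(A_loop, A_mm))  # True
```

Numerically, `exp(-g)` can overflow for large negative cumulative gates, which is one reason the real kernels keep the factorization chunk-local.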
> Warning: size limit exceeded. This PR contains 2612 lines of core code changes, exceeding the 200-line limit (excluding tests and docs). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes within 200 lines.
> Warning: size limit exceeded. This PR contains 1955 lines of core code changes, exceeding the 200-line limit (excluding tests and docs). Suggested action: split this PR into multiple smaller PRs, each keeping core code changes within 200 lines.
Summary
- Chunk-parallel forward (`chunk_kda_fwd`) and backward pass via `custom_vjp`
- Chunked delta-rule state kernel (`tops/ops/common/chunk_delta_h.py`) — shared by KDA and future Gated Delta Rule ops
- Fused recurrent KDA (`fused_recurrent_kda`) via `lax.scan` as CPU-compatible reference
- Fix `USE_EXP2` midpoint stabilization bug in `chunk_gla_fwd_o_gk_pl_kernel` — `jnp.exp` was used instead of `jnp.exp2` for the h factor when `USE_EXP2=True`

New files
- `tops/ops/kda/` — gate.py, chunk.py, chunk_intra.py, fused_recurrent.py, naive.py
- `tops/ops/common/chunk_delta_h.py` — pure JAX ref + Pallas TPU kernel (fwd)
- `tests/ops/kda/` — test_chunk_kda_tpu.py, test_fused_recurrent_kda_tpu.py

Test plan
- `PALLAS_INTERPRET=1`

🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Improvements
Tests
Documentation