
fix: resolve chunk_bwd_dhu signature mismatch and test OOM #197

Merged
0xaskr merged 6 commits into main from fix/chunk-bwd-dhu-signature
Apr 14, 2026
Merged

fix: resolve chunk_bwd_dhu signature mismatch and test OOM#197
0xaskr merged 6 commits intomainfrom
fix/chunk-bwd-dhu-signature

Conversation

@0xaskr 0xaskr commented Apr 14, 2026

Summary

  • Fix common/chunk_delta_h.py: add missing imports (lax, cdiv, exp2, assert_shape, pad_to_multiple), remove invalid @cpu_reference decorator, align function signature to keyword-only style matching the caller convention, replace undefined DEFAULT_BT with 64
  • Fix test_chunk_bwd.py: release GPU tensors between Triton and JAX runs to prevent CUDA OOM on the B1_T4096_H8_K128_V128 configuration (6GB GPU)
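
The OOM fix follows a common pattern when two frameworks share one small GPU: copy results to host memory and drop the device references before the second framework allocates. A minimal sketch with an illustrative helper (the helper name and dict layout are assumptions, not the PR's actual code):

```python
import gc
import numpy as np

def to_host_and_free(tensors: dict) -> dict:
    """Copy device arrays to host memory, then drop the device references
    so the allocator can reuse GPU memory before the next framework runs.
    (Hypothetical helper; tensor names are illustrative.)"""
    host = {name: np.asarray(t) for name, t in tensors.items()}
    tensors.clear()  # drop the last Python references to the device buffers
    gc.collect()     # collect promptly so allocator-held memory is returned
    return host
```

With Triton/PyTorch tensors one would typically also call the framework's cache-release hook after collection; the numpy conversion above stands in for the generic copy-to-host step.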

Test plan

  • All 6 e2e tests in test_chunk_bwd.py pass (including the previously OOM B1_T4096_H8 case)
  • CI passes

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added WY-based recompute and hidden-state propagation support for backward KDA flows.
  • Tests

    • Added end-to-end backward validation and skipped an internal sub-function test; aligned CPU/TPU test conventions and adjusted test comparisons.
  • Refactor

    • Reorganized KDA backward into stage-based components, expanded public backward utilities, and removed the legacy monolithic CPU reference forward/backward implementation.
  • Chores

    • Updated CI/test invocation to exclude specific heavy test files.

fix: resolve chunk_bwd_dhu signature mismatch and test OOM

- Fix common/chunk_delta_h.py: add missing imports (lax, cdiv, exp2, assert_shape,
  pad_to_multiple), remove invalid @cpu_reference decorator, change function
  signature to keyword-only style matching the caller convention, replace
  undefined DEFAULT_BT with 64
- Fix test_chunk_bwd.py: release GPU tensors between Triton and JAX runs to
  prevent CUDA OOM on the B1_T4096_H8 configuration (6GB GPU)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2066 lines of core code, exceeding the 200-line limit (tests and documentation excluded).

Suggested action: split this PR into multiple smaller PRs, keeping each PR's core-code changes under 200 lines.

@coderabbitai

coderabbitai bot commented Apr 14, 2026

Warning

Rate limit exceeded

@0xaskr has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 12 minutes and 35 seconds before requesting another review.

Your organization is not enrolled in usage-based pricing. Contact your admin to enable usage-based pricing to continue reviews beyond the rate limit, or try again in 12 minutes and 35 seconds.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b64b98b-cfd8-45b6-85a8-ba89f3046040

📥 Commits

Reviewing files that changed from the base of the PR and between 41fe955 and 2b125e2.

📒 Files selected for processing (2)
  • .github/ci/gpu-tests.sky.yaml
  • .github/workflows/gpu_tests.yml
📝 Walkthrough

Refactors KDA backward into staged components, adds CPU reference modules for gated-delta hidden-state and WY representation, removes the legacy monolithic CPU chunk forward/backward, updates package exports, and adds an end-to-end Triton vs JAX CPU backward test comparing gradients via recomputed intermediates.

Changes

  • Common gated-delta helpers (tops/cpu/ops/common/chunk_delta_h.py, tops/cpu/ops/common/__init__.py):
    Adds gated-delta inter-chunk forward/backward helpers (chunk_gated_delta_rule_fwd_h, chunk_gated_delta_rule_bwd_dhu) and re-exports them from the common package.
  • New CPU KDA reference modules (tops/cpu/ops/kda/wy_fast.py, tops/cpu/ops/kda/chunk_intra.py):
    Introduces WY recompute/backprop (recompute_w_u_fwd, prepare_wy_repr_bwd) and intra-chunk backward (chunk_kda_bwd_intra) CPU implementations with chunked reshaping, masking, and dA handling.
  • KDA backward orchestrator & API updates (tops/cpu/ops/kda/chunk_bwd.py, tops/cpu/ops/kda/__init__.py):
    Removes the gated-delta implementation from this file; adds a stage-based chunk_kda_bwd orchestrator, makes A/scale optional in sub-stages, refactors the WY-fused stage, and updates kda exports.
  • Removed legacy CPU chunk implementation (tops/cpu/ops/kda/chunk.py):
    Deletes the previous monolithic CPU reference chunk.py (full forward/backward reference implementation removed).
  • Tests: Triton/JAX backward E2E & callsite adjustments (tests/ref/kda/special/test_chunk_bwd.py, tests/ref/kda/test_chunk_kda.py, tests/ops/kda/test_chunk_kda_tpu.py, pyproject.toml):
    Adds an end-to-end Triton vs JAX backward test (uses Triton recompute with disable_recompute=True), adjusts gate/sign conventions and call signatures in tests, and updates pytest ignore entries.
  • CI/test invocation tweak (.github/ci/gpu-tests.sky.yaml):
    Adjusts the pytest invocation to exclude specific tests via --ignore= instead of -k triton selection.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Runner
    participant TritonFwd as Triton Forward
    participant GPU as GPU Memory/Kernels
    participant TritonBwd as Triton Backward
    participant JAX as JAX CPU Reference

    Test->>TritonFwd: call triton_chunk_kda_fwd(..., disable_recompute=True)
    TritonFwd->>GPU: execute forward kernel, save intermediates
    TritonFwd->>Test: return intermediates (on GPU)
    Test->>TritonBwd: call triton_chunk_kda_bwd(..., disable_recompute=True, intermediates)
    TritonBwd->>GPU: execute backward, produce gradients
    TritonBwd->>Test: copy gradients -> CPU (numpy)
    Test->>GPU: free GPU memory
    Test->>JAX: call jax_chunk_kda_bwd(..., disable_recompute=True, intermediates converted to CPU)
    JAX->>Test: return CPU gradients
    Test->>Test: compare gradients (dq/dk/dv/db/dg[, dh0])
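
The final comparison step in the diagram can be sketched as follows (a hedged sketch: the gradient names dq/dk/dv/db/dg come from the diagram, while the helper name and tolerances are illustrative assumptions):

```python
import numpy as np

def compare_grads(triton_grads: dict, jax_grads: dict,
                  rtol: float = 2e-3, atol: float = 1e-3) -> None:
    # Compare each gradient produced by the Triton backward against the
    # JAX CPU reference; raises AssertionError naming the first mismatch.
    for name in ("dq", "dk", "dv", "db", "dg"):
        np.testing.assert_allclose(
            np.asarray(triton_grads[name]), np.asarray(jax_grads[name]),
            rtol=rtol, atol=atol, err_msg=f"gradient mismatch: {name}",
        )
```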

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~80 minutes

Possibly related PRs

  • Fix/new api #196 — Modifies the same test tests/ref/kda/special/test_chunk_bwd.py adapting to the new forward/backward API and test harness.
  • add chunk bwd wy-dqkg-fused #190 — Related WY-fused backward refactor and tests (chunk_kda_bwd_wy_dqkg_fused changes).
  • update pallas kernel #38 — Related gated-delta forward/backward logic and interface changes (chunk gated-delta helpers).

Suggested labels

cpu-ref

Suggested reviewers

  • labyrinth-ssr

Poem

🐰 Hopping through chunks where gradients play,

gates and WY dance their clever way,
Triton runs, JAX recomputes the tune,
numbers match beneath the moon,
a little rabbit cheers — bravo, smooth! 🥕

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Title check: ⚠️ Warning. The title references resolving the chunk_bwd_dhu signature mismatch and test OOM, but based on the changeset summary the main changes involve refactoring KDA backward operations, extracting hidden-state propagation to a new module, and adding end-to-end tests. Resolution: revise the title to better reflect the primary changes, e.g. "refactor: extract hidden-state propagation and add KDA backward end-to-end tests".
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 62.50%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the KDA backward pass implementation for JAX CPU reference kernels to align with the Triton 6-stage pipeline. It introduces new modules for inter-chunk hidden state propagation, intra-chunk gradients, and WY representation logic, alongside comprehensive end-to-end tests. Review feedback points out a missing ln(2) factor in the gate gradient calculation, suggests using the DEFAULT_BT constant for consistency in chunk size defaults, and recommends safer handling of keyword arguments to prevent potential KeyError exceptions when recomputation is disabled.

dk_out = dk_out.at[:, n, j_start:j_end].add(dk_kk_j)

# dg: from the exp2(g_i - g_j) gate gradient
dg_chunk = q_c[:, n] * dq_out[:, n] - k_c[:, n] * dk_out[:, n]

high

The gradient calculation for the gate g appears to be missing the ln(2) factor. Since jnp.exp2 (base 2) is used in the forward pass, the derivative of $2^g$ with respect to $g$ is $\ln(2) \cdot 2^g$. Without this factor, the gradient will be scaled incorrectly by approximately $0.693$. If this is an intentional convention to match a specific Triton kernel implementation, please add a comment explaining it; otherwise, it should be included for mathematical correctness.
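
The claim can be checked numerically with a central finite difference (a standalone sketch, independent of the kernel code):

```python
import numpy as np

g = np.linspace(-2.0, 2.0, 9)
analytic = np.log(2.0) * np.exp2(g)  # d/dg 2**g = ln(2) * 2**g
eps = 1e-6
numeric = (np.exp2(g + eps) - np.exp2(g - eps)) / (2.0 * eps)
# The ln(2)-scaled derivative matches the finite difference...
assert np.allclose(analytic, numeric, rtol=1e-4)
# ...while the un-scaled 2**g is off by a factor of ln(2) ≈ 0.693
assert not np.allclose(np.exp2(g), numeric, rtol=1e-2)
```

As the review notes, some Triton kernels fold the ln(2) factor into an adjacent term by convention, so the check above only establishes the mathematical derivative, not which convention the kernel uses.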

g_org: jax.Array | None = None,
cu_seqlens: jax.Array | None = None,
chunk_indices: jax.Array | None = None,
chunk_size: int = 64,

medium

The default value for chunk_size is hardcoded to 64 here, whereas other functions in this file use the DEFAULT_BT constant defined on line 23. For consistency and easier maintenance, please use DEFAULT_BT.

Suggested change
chunk_size: int = 64,
chunk_size: int = DEFAULT_BT,

Comment on lines +613 to +618
w = kwargs["w"]
u = kwargs["u"]
qg = kwargs["qg"]
kg = kwargs["kg"]
v_new = kwargs["v_new"]
h = kwargs["h"]

medium

Accessing kwargs directly for required intermediates when disable_recompute is True can lead to KeyError if the caller forgets to provide them. It is safer to use .get() with a clear error message or, preferably, include these as optional keyword-only arguments in the function signature for better discoverability and type safety.

    w = kwargs.get("w")
    u = kwargs.get("u")
    qg = kwargs.get("qg")
    kg = kwargs.get("kg")
    v_new = kwargs.get("v_new")
    h = kwargs.get("h")
    if any(x is None for x in [w, u, qg, kg, v_new, h]):
        raise ValueError("Intermediates (w, u, qg, kg, v_new, h) must be provided in kwargs when disable_recompute=True")


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (10)
tops/cpu/ops/common/chunk_delta_h.py (2)

303-306: Optional: Sort __all__ alphabetically.

Static analysis flagged that __all__ is not sorted.

✨ Proposed fix
 __all__ = [
+  "chunk_gated_delta_rule_bwd_dhu",
   "chunk_gated_delta_rule_fwd_h",
-  "chunk_gated_delta_rule_bwd_dhu",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/common/chunk_delta_h.py` around lines 303 - 306, The __all__
list is not alphabetically ordered; update the __all__ definition to be sorted
so exports are deterministic and satisfy static analysis—specifically reorder
the symbols so "chunk_gated_delta_rule_bwd_dhu" appears before
"chunk_gated_delta_rule_fwd_h" in the __all__ list.

22-58: Add input shape assertions for chunk_gated_delta_rule_fwd_h.

Per coding guidelines, public functions need strict input assertions. The backward function (chunk_gated_delta_rule_bwd_dhu) has proper assertions, but the forward function lacks them.

🔧 Proposed fix
 def chunk_gated_delta_rule_fwd_h(
   ...
 ) -> tuple[jax.Array, jax.Array | None, jax.Array | None]:
   """..."""
   del g, cu_seqlens, transpose_state_layout

   B, T, H, K = k.shape
   V = u.shape[-1]
   C = chunk_size
   NT = T // C
+  
+  # =================== input shape assertions ===================
+  assert_shape(k, (B, T, H, K), "k")
+  assert_shape(w, (B, T, H, K), "w")
+  assert_shape(u, (B, T, H, V), "u")
+  if gk is not None:
+      assert_shape(gk, (B, T, H, K), "gk")
+  if initial_state is not None:
+      assert_shape(initial_state, (B, H, K, V), "initial_state")
+  assert T % C == 0, f"T={T} must be divisible by chunk_size={C}"
+  # ==============================================================
+  
   acc = acc_dtype(k.dtype)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/common/chunk_delta_h.py` around lines 22 - 58, The forward
function chunk_gated_delta_rule_fwd_h lacks the strict input shape/type
assertions required for public APIs; add the same set of validations used in
chunk_gated_delta_rule_bwd_dhu (or equivalent) at the start of
chunk_gated_delta_rule_fwd_h: assert k, w, u are jax.Array with matching
batch/time/feature dims (k: [B,T,H,K], w: [B,T,H,K], u: [B,T,H,V]), assert
optional gk has shape [B,T,H,K] when provided, initial_state has shape [B,H,K,V]
or [N,H,K,V] when provided, chunk_size is positive int,
output_final_state/save_new_value are bools, and raise clear AssertionError
messages on mismatch; place these checks before any computation (e.g., right
after the docstring and before del g,...).
tops/cpu/ops/kda/chunk_bwd.py (2)

1-6: Remove duplicate jax import.

jax is imported twice (lines 1 and 6).

✨ Proposed fix
-import jax
-
 import math
 from typing import Optional, Tuple

 import jax
 import jax.numpy as jnp
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd.py` around lines 1 - 6, There is a duplicate
import of the module name jax; remove the redundant import statement so jax is
imported only once (keep the first occurrence and delete the second), and ensure
the remaining imports (jax, math, Optional, Tuple) are correctly ordered and
unique in chunk_bwd.py.

526-608: Add input shape assertions for chunk_kda_bwd orchestrator.

Per coding guidelines, public functions should have input assertions. While this orchestrator delegates to sub-functions that have their own assertions, adding top-level validation provides better error messages and catches mismatches early.

🔧 Proposed fix
 def chunk_kda_bwd(
   ...
 ) -> tuple[...]:
   """Full chunk KDA backward..."""
   del cu_seqlens, chunk_indices, safe_gate, cp_context
   del transpose_state_layout

   BT = chunk_size
+  B, T, H, K = q.shape
+  V = v.shape[-1]
+  NT = T // BT
+
+  # =================== input shape assertions ===================
+  assert_shape(q, (B, T, H, K), "q")
+  assert_shape(k, (B, T, H, K), "k")
+  assert_shape(v, (B, T, H, V), "v")
+  assert_shape(beta, (B, T, H), "beta")
+  assert_shape(Aqk, (B, T, H, BT), "Aqk")
+  assert_shape(Akk, (B, T, H, BT), "Akk")
+  assert_shape(do, (B, T, H, V), "do")
+  if initial_state is not None:
+      assert_shape(initial_state, (B, H, K, V), "initial_state")
+  if dht is not None:
+      assert_shape(dht, (B, H, K, V), "dht")
+  if g is not None:
+      assert_shape(g, (B, T, H, K), "g")
+  assert T % BT == 0, f"T={T} must be divisible by chunk_size={BT}"
+  # ==============================================================

   # ---- Stage 0: Recompute forward intermediates ----

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd.py` around lines 526 - 608, Add top-level input
assertions in chunk_kda_bwd to validate shapes and types before orchestration:
check q/k shapes [B,T,H,K] match each other, v shape [B,T,H,V], beta shape
[B,T,H], Aqk/Akk shapes [B,T,H,BT] match T/block expectations, do shape
[B,T,H,V], initial_state (if not None) shape [B,H,K,V], dht (if not None) shape
[B,H,K,V], and optional g/g_org shapes [B,T,H,K]; also assert A_log shape [H]
and dt_bias shape [H*K] when provided, and that scale is a scalar and chunk_size
is int>0; raise clear ValueError messages referencing chunk_kda_bwd and the
offending parameter (e.g., "chunk_kda_bwd: q and k must have same shape") so
mismatches are caught early.
tests/ref/kda/special/test_chunk_bwd.py (2)

266-271: Prefix unused variable with underscore.

Static analysis flagged final_state as unused. Prefix with _ to indicate intentional non-use.

✨ Proposed fix
-    (o, final_state, g_cumsum, Aqk, Akk,
+    (o, _final_state, g_cumsum, Aqk, Akk,
      w, u, qg, kg, v_new, h, initial_state) = triton_chunk_kda_fwd(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 266 - 271, The tuple
unpacking from triton_chunk_kda_fwd currently binds an unused variable
final_state; to satisfy static analysis, prefix it with an underscore (e.g.,
_final_state) in the unpacking expression where triton_chunk_kda_fwd(...) is
called and ensure any later references (if none exist) are either removed or
updated to use the new name.

1-3: Setting environment variable at import time affects all tests in the module.

Setting TRITON_F32_DEFAULT at module import time is a side effect that could affect other tests or be sensitive to import order. Consider using a pytest fixture with monkeypatch.setenv for better isolation.

✨ Alternative using pytest fixture
@pytest.fixture(autouse=True)
def set_triton_f32_ieee(monkeypatch):
    monkeypatch.setenv("TRITON_F32_DEFAULT", "ieee")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 1 - 3, Replace the
module-level side-effect os.environ["TRITON_F32_DEFAULT"] = "ieee" with an
autouse pytest fixture so the environment change is scoped to each test; add a
fixture named e.g. set_triton_f32_ieee that uses
monkeypatch.setenv("TRITON_F32_DEFAULT", "ieee") (and remove the import-time
os.environ assignment and unnecessary module-level import if no longer needed)
so tests are isolated and import-time side effects are avoided.
tops/cpu/ops/kda/wy_fast.py (3)

227-230: Optional: Sort __all__ alphabetically.

Static analysis flagged that __all__ is not sorted. This is a minor style issue.

✨ Proposed fix
 __all__ = [
+  "prepare_wy_repr_bwd",
   "recompute_w_u_fwd",
-  "prepare_wy_repr_bwd",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/wy_fast.py` around lines 227 - 230, The __all__ list
currently contains ["recompute_w_u_fwd", "prepare_wy_repr_bwd"] unsorted; please
reorder the entries alphabetically so exported symbols are in lexicographic
order (i.e., update the module-level __all__ to list "prepare_wy_repr_bwd"
before "recompute_w_u_fwd") while keeping the same symbols and preserving
surrounding formatting.

15-46: Add input shape assertions for recompute_w_u_fwd.

Per coding guidelines, public functions must enforce strict input assertions. Add shape validation before the main logic.

🔧 Proposed fix
+from tops.utils import assert_shape
+
 def recompute_w_u_fwd(
   ...
 ) -> tuple[jax.Array, jax.Array, jax.Array | None, jax.Array | None]:
   """..."""
   del cu_seqlens, chunk_indices

   B, T, H, K = k.shape
   V = v.shape[-1]
   BT = A.shape[-1]
   NT = T // BT
+  
+  # =================== input shape assertions ===================
+  assert_shape(k, (B, T, H, K), "k")
+  assert_shape(v, (B, T, H, V), "v")
+  assert_shape(beta, (B, T, H), "beta")
+  assert_shape(A, (B, T, H, BT), "A")
+  if q is not None:
+      assert_shape(q, (B, T, H, K), "q")
+  if gk is not None:
+      assert_shape(gk, (B, T, H, K), "gk")
+  assert T % BT == 0, f"T={T} must be divisible by chunk_size={BT}"
+  # ==============================================================
+  
   acc = acc_dtype(k.dtype)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic."


96-145: Add input shape assertions for prepare_wy_repr_bwd.

Same guideline applies - add shape validation for this backward function.

🔧 Proposed fix
 def prepare_wy_repr_bwd(
   ...
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
   """..."""
   del cu_seqlens, chunk_indices

   B, T, H, K = k.shape
   V = v.shape[-1]
   BT = A.shape[-1]
   NT = T // BT
+  
+  # =================== input shape assertions ===================
+  assert_shape(k, (B, T, H, K), "k")
+  assert_shape(v, (B, T, H, V), "v")
+  assert_shape(beta, (B, T, H), "beta")
+  assert_shape(gk, (B, T, H, K), "gk")
+  assert_shape(A, (B, T, H, BT), "A")
+  assert_shape(dk, (B, T, H, K), "dk")
+  assert_shape(dw, (B, T, H, K), "dw")
+  assert_shape(du, (B, T, H, V), "du")
+  assert_shape(dg, (B, T, H, K), "dg")
+  assert T % BT == 0, f"T={T} must be divisible by chunk_size={BT}"
+  # ==============================================================
+  
   acc = acc_dtype(k.dtype)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/wy_fast.py` around lines 96 - 145, Add strict input
assertions at the start of prepare_wy_repr_bwd: verify each argument is a
jax.Array (k, v, beta, gk, A, dk, dw, du, dg) and that their ranks and shapes
match the docstring contracts (k and gk and dk and dw and dg -> [B,T,H,K]; v and
du -> [B,T,H,V]; beta -> [B,T,H]; A -> [B,T,H,BT] or the concrete trailing dim
used elsewhere), and that batch/time/head dims (B,T,H) are consistent across
tensors; also assert cu_seqlens and chunk_indices are None or raise ValueError.
Use explicit, clear error messages naming the symbol (e.g., "expected k shape
[B,T,H,K], got ...") so callers can debug shape mismatches before the main logic
runs.
tops/cpu/ops/kda/chunk_intra.py (1)

16-31: Add input shape assertions for public function.

Per coding guidelines, all public functions must enforce strict input assertions on shape and types before executing main logic. This function lacks assert_shape calls for the input tensors.

🔧 Proposed fix to add input assertions
+from tops.utils import assert_shape
+
 def chunk_kda_bwd_intra(
   q: jax.Array,
   ...
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
   """Stage 4: Intra-chunk backward through sub-chunk triangular system.
   ...
   """
   del cu_seqlens, chunk_indices, safe_gate

   B, T, H, K = q.shape
   BT = chunk_size
+  
+  # =================== input shape assertions ===================
+  assert_shape(q, (B, T, H, K), "q")
+  assert_shape(k, (B, T, H, K), "k")
+  assert_shape(g, (B, T, H, K), "g")
+  assert_shape(beta, (B, T, H), "beta")
+  assert_shape(dAqk, (B, T, H, BT), "dAqk")
+  assert_shape(dAkk, (B, T, H, BT), "dAkk")
+  assert_shape(dq, (B, T, H, K), "dq")
+  assert_shape(dk, (B, T, H, K), "dk")
+  assert_shape(db, (B, T, H), "db")
+  assert_shape(dg, (B, T, H, K), "dg")
+  assert T % BT == 0, f"T={T} must be divisible by chunk_size={BT}"
+  # ==============================================================
+  
   BC = min(16, BT)

As per coding guidelines: "All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils"


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6257e251-058f-4d8b-881e-83e2f186013f

📥 Commits

Reviewing files that changed from the base of the PR and between 0e0f4e0 and a177453.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (8)
  • tests/ref/kda/special/test_chunk_bwd.py
  • tops/cpu/ops/common/__init__.py
  • tops/cpu/ops/common/chunk_delta_h.py
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
  • tops/cpu/ops/kda/chunk_bwd.py
  • tops/cpu/ops/kda/chunk_intra.py
  • tops/cpu/ops/kda/wy_fast.py
💤 Files with no reviewable changes (1)
  • tops/cpu/ops/kda/chunk.py

@0xaskr 0xaskr marked this pull request as draft April 14, 2026 10:50
@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2242 lines of core code, exceeding the 200-line limit (tests and documentation excluded).

Suggested action: split this PR into multiple smaller PRs, keeping each PR's core-code changes under 200 lines.

@0xaskr 0xaskr marked this pull request as ready for review April 14, 2026 11:16
@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2243 lines of core code, exceeding the 200-line limit (tests and documentation excluded).

Suggested action: split this PR into multiple smaller PRs, keeping each PR's core-code changes under 200 lines.

@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2543 lines of core code, exceeding the 200-line limit (tests and documentation excluded).

Suggested action: split this PR into multiple smaller PRs, keeping each PR's core-code changes under 200 lines.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/cpu/ops/common/chunk_delta_h.py`:
- Around line 34-35: The public helper chunk_gated_delta_rule_fwd_h currently
always applies jnp.exp2 and reshapes inputs before validating them; update it to
(1) add a comprehensive docstring describing business semantics and exact tensor
shapes/dimensions for inputs k/w/u/gk and outputs, (2) perform strict input
assertions up front (use assert_shape_or_none or explicit assert statements and
type checks) and validate T % chunk_size before any reshape/broadcast, and (3)
honor the use_exp2 flag so that jnp.exp2 is applied only when use_exp2 is True
(remove the unconditional exp2 on line ~99). Also add/adjust
transpose_state_layout handling after validation. Ensure the assertions
reference the function name chunk_gated_delta_rule_fwd_h for clarity.
- Around line 122-123: The signature currently allows gk: jax.Array | None and
h0: jax.Array | None but the code later calls _to_bh(gk, K) (and similarly for
h0) unconditionally; update the implementation in chunk_delta_h.py to either (A)
require gk and h0 be non-None by changing the signature/types and validating at
entry, or (B) add explicit branching before flattening: if gk is None take the
nongated path (skip _to_bh(gk, K) and any gated logic), otherwise call
_to_bh(gk, K); do the same for h0 so no None value is passed to _to_bh. Ensure
the chosen branch preserves existing behavior for gated vs nongated execution.

In `@tops/cpu/ops/kda/chunk_bwd.py`:
- Around line 40-41: The helper signatures were widened to accept optional
parameters (e.g., scale: float | None in chunk_kda_bwd_wy_dqkg_fused and A/scale
made optional in chunk_kda_bwd_dAv) but the function bodies still assume
non-None (multiplying by scale, asserting A, etc.); fix by either restoring the
parameters to required types (remove | None from scale and make A required) or
explicitly handle None at the API boundary: add explicit guards/defaults at the
start of chunk_kda_bwd_wy_dqkg_fused and chunk_kda_bwd_dAv (e.g., raise a clear
TypeError if A or scale is None, or define and apply a concrete fallback
behavior) so no None value will cause hidden failures later in the body.
- Around line 525-550: Add upfront validation in chunk_kda_bwd: assert required
shapes/types for q, k, v, beta, Aqk, Akk, do, and (if provided) initial_state
and dht using assert or tops.utils.assert_shape_or_none before any processing,
and check that when disable_recompute is True the cached tensors expected from
kwargs (e.g., kwargs["w"], kwargs["kv_cache"] or other cache keys used later)
are present and correctly shaped; replicate the same front-door checks for the
related overload/variant around the chunk_kda_bwd usage at the second region
(lines 608-615) so callers get immediate, clear assertion errors instead of late
reshape/KeyError failures.
- Around line 91-92: Ensure fixed block sizes BK/BV computed from
next_power_of_2(K)/next_power_of_2(V) never exceed their respective feature
dimensions: clamp BK = min(next_power_of_2(K), DEFAULT_BK, K) and BV =
min(next_power_of_2(V), DEFAULT_BV, V) or alternatively compute BK/BV as the
largest power-of-two ≤ min(DEFAULT_BK, K) (and similarly for V) so that
lax.dynamic_slice start indices cannot be clamped and tail blocks don't overlap;
apply this guard wherever BK/BV are used (notably in chunk_kda_bwd_wy_dqkg_fused
and chunk_kda_bwd_dAv). Also add explicit input shape and type assertions in the
public orchestrator chunk_kda_bwd for all primary parameters (Q, K, V, head_dim,
seq_len, etc.) and when disable_recompute=True validate that required tensors
exist in kwargs and have the expected shapes/dtypes before indexing them.

In `@tops/cpu/ops/kda/chunk_intra.py`:
- Around line 67-71: The code must validate chunk_size and input shapes before
performing reshapes: assert chunk_size is a positive int, chunk_size <= T and T
% chunk_size == 0 (so BT == chunk_size divides T) and additionally assert BT %
BC == 0 (where BC = min(16,BT)) to avoid silently dropping positions when BT >
16 but not a multiple of 16; use the project's assertion helper (e.g.
assert_shape_or_none or plain assert) at the start of the public helper in
chunk_intra.py (the block that reads B, T, H, K = q.shape and computes BT, BC,
NT, NC) so the function fails fast with a clear message before any
reshape/permute operations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4e48b09d-7fad-4c5f-afaf-a5a264d60cc9

📥 Commits

Reviewing files that changed from the base of the PR and between a177453 and b32e763.

📒 Files selected for processing (9)
  • pyproject.toml
  • tests/ops/kda/test_chunk_kda_tpu.py
  • tests/ref/kda/special/test_chunk_bwd.py
  • tests/ref/kda/test_chunk_kda.py
  • tops/cpu/ops/common/chunk_delta_h.py
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
  • tops/cpu/ops/kda/chunk_bwd.py
  • tops/cpu/ops/kda/chunk_intra.py
💤 Files with no reviewable changes (1)
  • tops/cpu/ops/kda/chunk.py
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/cpu/ops/kda/__init__.py

Comment on lines +34 to +35
use_exp2: bool = False,
transpose_state_layout: bool = False,

⚠️ Potential issue | 🟠 Major

chunk_gated_delta_rule_fwd_h's API contract is only partially implemented.

use_exp2=False never changes behavior—Line 99 always applies jnp.exp2—and this public helper reshapes k/w/u/gk before checking shapes or T % chunk_size, so bad calls fail as raw reshape/broadcast errors.

As per coding guidelines All public functions must have a comprehensive docstring that explains the business semantics of the function and clearly details tensor shapes and dimension meanings for every input and output argument and All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.

Also applies to: 61-72, 95-99

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/common/chunk_delta_h.py` around lines 34 - 35, The public helper
chunk_gated_delta_rule_fwd_h currently always applies jnp.exp2 and reshapes
inputs before validating them; update it to (1) add a comprehensive docstring
describing business semantics and exact tensor shapes/dimensions for inputs
k/w/u/gk and outputs, (2) perform strict input assertions up front (use
assert_shape_or_none or explicit assert statements and type checks) and validate
T % chunk_size before any reshape/broadcast, and (3) honor the use_exp2 flag so
that jnp.exp2 is applied only when use_exp2 is True (remove the unconditional
exp2 on line ~99). Also add/adjust transpose_state_layout handling after
validation. Ensure the assertions reference the function name
chunk_gated_delta_rule_fwd_h for clarity.
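The `use_exp2` fix amounts to branching on the flag instead of applying `jnp.exp2` unconditionally. A minimal scalar sketch of that branch, with `math` standing in for `jnp` (the real code operates on arrays, and the function name here is illustrative):

```python
import math

def apply_gate_decay(g: float, use_exp2: bool = False) -> float:
    # Honor the flag: base-2 exponential only when use_exp2=True,
    # natural exponential otherwise. Scalar stand-in for jnp.exp2/jnp.exp.
    return 2.0 ** g if use_exp2 else math.exp(g)
```

The point is that `use_exp2=False` must observably change the result, which the current unconditional `jnp.exp2` on line ~99 never does.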

Comment on lines +122 to +123
gk: jax.Array | None = None,
h0: jax.Array | None = None,

⚠️ Potential issue | 🟠 Major

gk=None is still unsupported here.

The signature accepts gk: jax.Array | None, but Line 269 unconditionally flattens it. A None call-site makes it through the earlier guards and then dies on _to_bh(gk, K). Either require gk or branch to a nongated path before flattening.

Also applies to: 269-270

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/common/chunk_delta_h.py` around lines 122 - 123, The signature
currently allows gk: jax.Array | None and h0: jax.Array | None but the code
later calls _to_bh(gk, K) (and similarly for h0) unconditionally; update the
implementation in chunk_delta_h.py to either (A) require gk and h0 be non-None
by changing the signature/types and validating at entry, or (B) add explicit
branching before flattening: if gk is None take the nongated path (skip
_to_bh(gk, K) and any gated logic), otherwise call _to_bh(gk, K); do the same
for h0 so no None value is passed to _to_bh. Ensure the chosen branch preserves
existing behavior for gated vs nongated execution.
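Option (B) can be sketched as one branch applied before any flattening, so the helper never receives `None`. `to_bh` here is a placeholder for the repo's `_to_bh`, not its real implementation:

```python
def prepare_optional_inputs(gk, h0, K, to_bh):
    # Branch before flattening so the flatten helper never sees None:
    # gk is None selects the nongated path; h0 is None means no initial state.
    gk_bh = to_bh(gk, K) if gk is not None else None
    h0_bh = to_bh(h0, K) if h0 is not None else None
    gated = gk_bh is not None
    return gk_bh, h0_bh, gated
```

Downstream code then dispatches on the returned `gated` flag instead of crashing inside the flatten call.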

Comment on lines +40 to +41
scale: float | None = None,
chunk_size: int = DEFAULT_BT,

⚠️ Potential issue | 🟠 Major

The widened helper signatures still behave as required-only.

chunk_kda_bwd_wy_dqkg_fused now accepts scale: float | None, and chunk_kda_bwd_dAv accepts A/scale as optional, but Line 236 multiplies by scale, Line 432 asserts A, and Line 492 multiplies by scale. Any None call-path still fails inside the body instead of at the API boundary. Either restore required annotations or define the None behavior explicitly.

Minimal fix if `None` is not actually supported
-  scale: float | None = None,
+  scale: float,
...
-    A: jax.Array | None = None,
-    scale: float | None = None,
+    A: jax.Array,
+    scale: float,

Also applies to: 236-237, 397-398, 432-433, 492-492

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd.py` around lines 40 - 41, The helper signatures
were widened to accept optional parameters (e.g., scale: float | None in
chunk_kda_bwd_wy_dqkg_fused and A/scale made optional in chunk_kda_bwd_dAv) but
the function bodies still assume non-None (multiplying by scale, asserting A,
etc.); fix by either restoring the parameters to required types (remove | None
from scale and make A required) or explicitly handle None at the API boundary:
add explicit guards/defaults at the start of chunk_kda_bwd_wy_dqkg_fused and
chunk_kda_bwd_dAv (e.g., raise a clear TypeError if A or scale is None, or
define and apply a concrete fallback behavior) so no None value will cause
hidden failures later in the body.
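If `None` is kept in the signatures, the "explicit guard at the API boundary" option can be a one-line check at the top of each helper. A hedged sketch (the helper name and message are illustrative):

```python
def require_non_none(fn_name: str, **named) -> None:
    # Front-door guard: raise at the API boundary instead of failing
    # deep inside the kernel body on a None multiply or assert.
    missing = [name for name, value in named.items() if value is None]
    if missing:
        raise TypeError(f"{fn_name}: required arguments are None: {missing}")
```

For example, `require_non_none("chunk_kda_bwd_dAv", A=A, scale=scale)` as the first statement makes the contract failure immediate and named.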

Comment on lines +91 to +92
BK = min(next_power_of_2(K), DEFAULT_BK)
BV = min(next_power_of_2(V), DEFAULT_BV)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In JAX, does `lax.dynamic_slice` require each `slice_size` to be <= the operand dimension, and do `lax.dynamic_slice`/`lax.dynamic_update_slice` clamp out-of-bounds start indices so a tail block can overlap an earlier region?

💡 Result:

Yes, lax.dynamic_slice requires each slice_size to be <= the corresponding operand dimension size. And both lax.dynamic_slice and lax.dynamic_update_slice clamp out-of-bounds start indices: when a slice would overrun the array bounds, the start index is adjusted backward so the slice fits, which means a tail block can overlap an earlier region.


🏁 Script executed:

# Get file size and read the full file to examine the flagged lines
wc -l tops/cpu/ops/kda/chunk_bwd.py
cat -n tops/cpu/ops/kda/chunk_bwd.py | head -200

Repository: primatrix/pallas-kernel

Length of output: 8844


🏁 Script executed:

# Continue reading the file to see the flagged lines and public function
cat -n tops/cpu/ops/kda/chunk_bwd.py | sed -n '200,350p'

Repository: primatrix/pallas-kernel

Length of output: 6401


🏁 Script executed:

# Continue reading to find the flagged lines and the public function
cat -n tops/cpu/ops/kda/chunk_bwd.py | sed -n '350,500p'

Repository: primatrix/pallas-kernel

Length of output: 6543


🏁 Script executed:

# Continue reading to find the public function chunk_kda_bwd and check around line 525-604
cat -n tops/cpu/ops/kda/chunk_bwd.py | sed -n '500,674p'

Repository: primatrix/pallas-kernel

Length of output: 6860


Add dimension guards for fixed-size feature blocks and validate public function inputs.

The block slicing pattern used across this file requires each slice_size parameter to be ≤ the corresponding operand dimension. Lines 91–92 and 425 derive fixed BK/BV from next_power_of_2(K) / next_power_of_2(V), but for non-power-of-two dimensions (e.g., V=48 → BV=64, K=96 → BK=128), these slice sizes can exceed the feature dimension, causing lax.dynamic_slice to clamp the start index and overlap earlier regions instead of covering the tail block.

This affects all K-block and V-block slicing throughout chunk_kda_bwd_wy_dqkg_fused (lines 158–180, 268–270) and chunk_kda_bwd_dAv (lines 425, 475–484).

Additionally, the public orchestrator chunk_kda_bwd violates the strict input assertions guideline. When disable_recompute=True, it directly indexes kwargs (lines 609–615) without validating that required cached tensors are present or correctly shaped.

Add guards to ensure K and V are either powers of two ≤ DEFAULT_BK/DEFAULT_BV, or multiples of those constants; add input shape assertions to chunk_kda_bwd on all primary parameters and on kwargs keys when disable_recompute=True.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd.py` around lines 91 - 92, Ensure fixed block sizes
BK/BV computed from next_power_of_2(K)/next_power_of_2(V) never exceed their
respective feature dimensions: clamp BK = min(next_power_of_2(K), DEFAULT_BK, K)
and BV = min(next_power_of_2(V), DEFAULT_BV, V) or alternatively compute BK/BV
as the largest power-of-two ≤ min(DEFAULT_BK, K) (and similarly for V) so that
lax.dynamic_slice start indices cannot be clamped and tail blocks don't overlap;
apply this guard wherever BK/BV are used (notably in chunk_kda_bwd_wy_dqkg_fused
and chunk_kda_bwd_dAv). Also add explicit input shape and type assertions in the
public orchestrator chunk_kda_bwd for all primary parameters (Q, K, V, head_dim,
seq_len, etc.) and when disable_recompute=True validate that required tensors
exist in kwargs and have the expected shapes/dtypes before indexing them.
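The clamp described above (largest power of two ≤ min(DEFAULT, dim)) can be sketched as follows; the `DEFAULT_BK`/`DEFAULT_BV` values are illustrative, not the repo's actual constants:

```python
DEFAULT_BK, DEFAULT_BV = 128, 64  # illustrative defaults

def prev_power_of_2(n: int) -> int:
    # Largest power of two <= n (n >= 1).
    return 1 << (n.bit_length() - 1)

def block_sizes(K: int, V: int) -> tuple[int, int]:
    # A block size can never exceed its feature dimension, so
    # lax.dynamic_slice start indices are never clamped backward
    # and tail blocks cannot overlap earlier regions.
    BK = prev_power_of_2(min(DEFAULT_BK, K))
    BV = prev_power_of_2(min(DEFAULT_BV, V))
    return BK, BV
```

With the current `min(next_power_of_2(K), DEFAULT_BK)` formula, K=96 yields BK=128 > K; the clamped version yields BK=64, which always fits.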

Comment on lines +525 to +550
def chunk_kda_bwd(
q: jax.Array,
k: jax.Array,
v: jax.Array,
beta: jax.Array,
Aqk: jax.Array,
Akk: jax.Array,
scale: float,
initial_state: jax.Array | None,
do: jax.Array,
dht: jax.Array | None,
*,
g: jax.Array | None = None,
g_org: jax.Array | None = None,
cu_seqlens: jax.Array | None = None,
chunk_indices: jax.Array | None = None,
chunk_size: int = 64,
safe_gate: bool = False,
lower_bound: float | None = None,
use_gate_in_kernel: bool = False,
A_log: jax.Array | None = None,
dt_bias: jax.Array | None = None,
disable_recompute: bool = False,
cp_context=None,
transpose_state_layout: bool = False,
**kwargs,

⚠️ Potential issue | 🟠 Major

Add front-door validation to chunk_kda_bwd.

This new public entrypoint never checks shapes or the cached tensors required by disable_recompute=True before Stage 0. Right now a bad caller gets a deep reshape/assertion later or a bare KeyError from kwargs["w"], which is much harder to diagnose than a direct contract failure.

As per coding guidelines All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.

Also applies to: 608-615

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_bwd.py` around lines 525 - 550, Add upfront validation
in chunk_kda_bwd: assert required shapes/types for q, k, v, beta, Aqk, Akk, do,
and (if provided) initial_state and dht using assert or
tops.utils.assert_shape_or_none before any processing, and check that when
disable_recompute is True the cached tensors expected from kwargs (e.g.,
kwargs["w"], kwargs["kv_cache"] or other cache keys used later) are present and
correctly shaped; replicate the same front-door checks for the related
overload/variant around the chunk_kda_bwd usage at the second region (lines
608-615) so callers get immediate, clear assertion errors instead of late
reshape/KeyError failures.
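The front-door checks requested above can be sketched at the shape level; the `(B, T, H, K)`/`(B, T, H, V)` layout and the `"w"` cache key are assumptions taken from the review text, not the repo's verified contract:

```python
def validate_chunk_kda_bwd(q_shape, k_shape, v_shape, do_shape,
                           disable_recompute=False, cached=None):
    # Assert the shape contract before any processing so a bad caller
    # gets an immediate, named failure instead of a deep reshape error.
    B, T, H, K = q_shape
    assert k_shape == (B, T, H, K), f"k: expected {(B, T, H, K)}, got {k_shape}"
    V = v_shape[-1]
    assert v_shape == (B, T, H, V), f"v: expected {(B, T, H, V)}, got {v_shape}"
    assert do_shape == v_shape, f"do must match v: {do_shape} vs {v_shape}"
    if disable_recompute:
        # Check cached tensors up front instead of a late KeyError.
        assert cached and "w" in cached, \
            "disable_recompute=True requires the cached tensor 'w'"
    return True
```

In the real function these checks would run on the arrays themselves (e.g. via `assert_shape_or_none` from `tops.utils`) before Stage 0.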

Comment on lines +67 to +71
B, T, H, K = q.shape
BT = chunk_size
BC = min(16, BT)
NT = T // BT
NC = BT // BC

⚠️ Potential issue | 🟠 Major

Guard unsupported chunk_size values before chunking.

NC = BT // BC silently drops the last BT % 16 positions whenever chunk_size > 16 and is not a multiple of 16, and this public helper still does no front-door shape checks before the reshapes.

Possible guard
  B, T, H, K = q.shape
  BT = chunk_size
+ assert k.shape == (B, T, H, K)
+ assert g.shape == (B, T, H, K)
+ assert beta.shape == (B, T, H)
+ assert dAqk.shape == (B, T, H, BT)
+ assert dAkk.shape == (B, T, H, BT)
+ assert T % BT == 0, f"T={T} must be divisible by chunk_size={BT}"
  BC = min(16, BT)
+ assert BT <= 16 or BT % BC == 0, (
+   f"chunk_size={BT} leaves a partial BC-sized tail that this kernel never visits"
+ )

As per coding guidelines All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/chunk_intra.py` around lines 67 - 71, The code must validate
chunk_size and input shapes before performing reshapes: assert chunk_size is a
positive int, chunk_size <= T and T % chunk_size == 0 (so BT == chunk_size
divides T) and additionally assert BT % BC == 0 (where BC = min(16,BT)) to avoid
silently dropping positions when BT > 16 but not a multiple of 16; use the
project's assertion helper (e.g. assert_shape_or_none or plain assert) at the
start of the public helper in chunk_intra.py (the block that reads B, T, H, K =
q.shape and computes BT, BC, NT, NC) so the function fails fast with a clear
message before any reshape/permute operations.

@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2546 lines of core code, exceeding the 200-line cap (tests and docs excluded).

Suggested action: split this PR into several smaller PRs, each keeping core code changes under 200 lines.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In @.github/ci/gpu-tests.sky.yaml:
- Around line 31-32: The CI currently permanently excludes important
Triton↔Pallas parity tests by using the --ignore flags for
tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py and
tests/ref/kda/special/test_chunk_bwd_pallas.py; revert removing these tests from
the CI command (remove those --ignore entries) and instead handle instability by
marking the tests themselves with pytest markers/xfail or adding a separate
gated CI job that runs them, or conditionally skip them via an env-driven CI
flag; update either the test files (add `@pytest.mark.xfail` or conditional skip)
or the CI YAML to run them in a dedicated job rather than using --ignore so the
parity coverage remains in CI.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 73db91c6-6280-4511-9dbe-d228f795847a

📥 Commits

Reviewing files that changed from the base of the PR and between b32e763 and 41fe955.

📒 Files selected for processing (1)
  • .github/ci/gpu-tests.sky.yaml

Comment on lines +31 to +32
--ignore=tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py \
--ignore=tests/ref/kda/special/test_chunk_bwd_pallas.py \

⚠️ Potential issue | 🟠 Major

Avoid permanently excluding core Triton↔Pallas parity tests from CI.

Line 31 and Line 32 remove two high-signal regression suites (including chunk_gated_delta_rule_bwd_dhu parity), which creates a CI blind spot in the same backward path this PR modifies. Prefer keeping these tests in CI and handling flakes/OOM with targeted markers/xfail or a separate gated job instead of --ignore.

Suggested CI command adjustment
   uv run python -m pytest \
     -o "addopts=--strict-markers" \
     tests/ref/kda/special/ \
-    --ignore=tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py \
-    --ignore=tests/ref/kda/special/test_chunk_bwd_pallas.py \
     -v
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
--ignore=tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py \
--ignore=tests/ref/kda/special/test_chunk_bwd_pallas.py \
uv run python -m pytest \
-o "addopts=--strict-markers" \
tests/ref/kda/special/ \
-v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/ci/gpu-tests.sky.yaml around lines 31 - 32, The CI currently
permanently excludes important Triton↔Pallas parity tests by using the --ignore
flags for tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py and
tests/ref/kda/special/test_chunk_bwd_pallas.py; revert removing these tests from
the CI command (remove those --ignore entries) and instead handle instability by
marking the tests themselves with pytest markers/xfail or adding a separate
gated CI job that runs them, or conditionally skip them via an env-driven CI
flag; update either the test files (add `@pytest.mark.xfail` or conditional skip)
or the CI YAML to run them in a dedicated job rather than using --ignore so the
parity coverage remains in CI.
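The env-driven gate suggested above can be reduced to a small condition that a `pytest.mark.skipif` marker (or the dedicated CI job) would consume; `RUN_PALLAS_PARITY` is a hypothetical variable name, not one the repo currently defines:

```python
import os

def parity_tests_enabled(env=None) -> bool:
    # Gate for the Triton<->Pallas parity suites: run only when the
    # dedicated CI job opts in via RUN_PALLAS_PARITY=1. Defaults to off
    # so heavy/flaky runs are explicit rather than silently ignored.
    env = os.environ if env is None else env
    return env.get("RUN_PALLAS_PARITY", "0") == "1"
```

Unlike `--ignore`, this keeps the tests collected and visible as skips, so the parity coverage gap stays on the CI dashboard instead of disappearing.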

- Set TRITON_AUTOTUNE=0 to skip kernel autotuning (avoids compiling
  10+ ptxas variants per kernel, each taking 1-3 minutes)
- Remove --idle-minutes-to-autostop to allow K8s cluster usage
  (K8s doesn't support autostop, causing fallback to GCP spot)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@beaver-infiscale

Warning

Size limit exceeded

This PR changes 2548 lines of core code, exceeding the 200-line cap (tests and docs excluded).

Suggested action: split this PR into several smaller PRs, each keeping core code changes under 200 lines.

@0xaskr 0xaskr added this pull request to the merge queue Apr 14, 2026
Merged via the queue into main with commit 343d3f9 Apr 14, 2026
2 of 4 checks passed
@0xaskr 0xaskr deleted the fix/chunk-bwd-dhu-signature branch April 14, 2026 14:51