
feat/dav dhu#194

Merged
pathfinder-pf merged 10 commits into main from feat/dav_dhu on Apr 14, 2026

Conversation

@pathfinder-pf
Collaborator

@pathfinder-pf pathfinder-pf commented Apr 14, 2026

Description

Add dAv and dhu kernels for chunk_kda.

Related Issue

Closes #

Change Type

  • feat — New feature
  • fix — Bug fix
  • refactor — Code refactoring
  • docs — Documentation
  • ci — CI/CD changes
  • test — Tests
  • perf — Performance improvement

Checklist

  • Code passes uv run ruff check src/ tests/ and uv run ruff format src/ tests/
  • New/modified public APIs have complete docstrings (tensor shapes, dimension semantics, business logic)
  • Public functions have input assertions (assert or assert_shape_or_none)
  • Tests added at the appropriate layer (tests/ops/, tests/modules/, tests/layers/, or tests/ref/)
  • If tops/cpu/ is modified, core developers have been notified and PR is labeled cpu-ref

Test Results

Paste relevant test output here.

Summary by CodeRabbit

  • New Features

    • Added chunked backward-pass support for KDA attention on CPU and GPU, enabling gradient computation for chunked attention and optional initial-state handling.
    • Causal masking respected in backward outputs.
  • Tests

    • Added comprehensive correctness tests comparing implementations across float32 and bfloat16.
    • Covers shape/validation checks, causal-mask behavior, optional carry handling, and various tensor shapes.
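The causal-mask behavior mentioned above can be illustrated with a minimal sketch (generic mask logic, not lifted from the PR's kernels): a lower-triangular mask ensures no contribution flows from future positions, in forward scores and, symmetrically, in their gradients.

```python
import numpy as np

# Illustrative only: a lower-triangular mask zeroes out contributions from
# future positions, which is what "causal masking respected in backward
# outputs" implies for the intra-chunk attention gradients.
T = 4
scores = np.ones((T, T), dtype=np.float32)
causal = np.tril(np.ones((T, T), dtype=bool))  # position t attends to <= t
masked = np.where(causal, scores, 0.0)

assert masked[0, 1] == 0.0  # nothing flows from the future
assert masked[3, 0] == 1.0  # past positions remain visible
```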

@beaver-infiscale

Warning

Size limit exceeded

This PR changes 653 lines of core code, exceeding the 200-line limit (excluding tests and documentation).

Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.

@pathfinder-pf pathfinder-pf changed the title from Feat/dav dhu to feat/dav dhu on Apr 14, 2026
@coderabbitai

coderabbitai bot commented Apr 14, 2026

📝 Walkthrough
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 66.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: the title 'feat/dav dhu' is vague and does not clearly convey what the pull request accomplishes; it uses abbreviations without context. Resolution: revise the title to be more descriptive and specific, such as 'Add CPU/Pallas backward kernels for KDA: dAv and dhu stages'.

✅ Passed checks (1 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.




Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements CPU reference and Pallas kernels for chunk_kda_bwd_dAv and chunk_gated_delta_rule_bwd_dhu, including comprehensive tests comparing outputs against Triton references. Review feedback highlights a shape inconsistency in the Pallas dh output and the need for functional parity regarding the initial state gradient (dh0). Additionally, suggestions were made to simplify the CPU reference by removing redundant padding logic and to improve API consistency by exposing the chunk_size parameter in the gated delta rule backward implementation.


# --- Reshape back ---
# dh: [BH, NT, K, V] -> [B, H, NT, K, V] -> [B, NT, H, K, V]
dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)
Contributor


high

The shape of the returned dh tensor is [B, NT, H, K, V], which is inconsistent with the CPU reference implementation in tops/cpu/ops/kda/chunk_bwd.py (line 650) that returns [B * NT, H, K, V]. This inconsistency should be resolved by flattening the batch and chunk dimensions to match the CPU reference and common conventions for these kernels. Please also update the docstring at line 557 to reflect this change.

Suggested change
dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)
dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4).reshape(B * NT, H, K, V)
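To make the suggested layout change concrete, here is a small numpy sketch (all sizes hypothetical) verifying that flattening the batch and chunk dimensions after the transpose yields the [B*NT, H, K, V] layout with chunk t of batch b at row b*NT + t; jnp follows the same reshape/transpose semantics.

```python
import numpy as np

# Hypothetical sizes for illustration only.
B, H, NT, K, V = 2, 3, 4, 5, 6

# dh_r as produced by the kernel: [B*H, NT, K, V] with batch-head flattened.
dh_r = np.arange(B * H * NT * K * V, dtype=np.float32).reshape(B * H, NT, K, V)

# Original: [B, H, NT, K, V] -> [B, NT, H, K, V]
dh_5d = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)

# Suggested: additionally flatten batch and chunk dims to [B*NT, H, K, V]
dh_flat = dh_5d.reshape(B * NT, H, K, V)

assert dh_5d.shape == (B, NT, H, K, V)
assert dh_flat.shape == (B * NT, H, K, V)
# Flattening preserves element identity: chunk t of batch b lands at row b*NT + t.
assert np.array_equal(dh_flat[1 * NT + 2], dh_5d[1, 2])
```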

Comment on lines +406 to +410
q_bh = pad_to_multiple(q_bh, T_padded, axis=0)
k_bh = pad_to_multiple(k_bh, T_padded, axis=0)
v_bh = pad_to_multiple(v_bh, T_padded, axis=0)
do_bh = pad_to_multiple(do_bh, T_padded, axis=0)
A_bh = pad_to_multiple(A_bh, T_padded, axis=0)
Contributor


medium

The padding logic here is redundant because T % BT == 0 is asserted at line 393, ensuring that T_padded (which is NT * BT) will always equal T_actual. This code can be simplified by removing the pad_to_multiple calls to improve readability and avoid unnecessary operations.
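The no-op claim is easy to check. The sketch below uses a hypothetical pad_to_multiple with the contract the review implies (zero-pad along an axis up to a target length); when T % BT == 0, T_padded equals T and the helper returns its input untouched.

```python
import numpy as np

def pad_to_multiple(x, target_len, axis=0):
    # Hypothetical stand-in mirroring the helper's assumed contract:
    # zero-pad `x` along `axis` up to `target_len`.
    pad = target_len - x.shape[axis]
    if pad <= 0:
        return x
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return np.pad(x, widths)

T, BT = 64, 16
assert T % BT == 0          # the assertion the review cites
NT = T // BT
T_padded = NT * BT          # == T, so padding is a no-op
x = np.ones((T, 8), dtype=np.float32)
assert pad_to_multiple(x, T_padded, axis=0) is x  # input returned unchanged
```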

Comment on lines +481 to +492
def chunk_gated_delta_rule_bwd_dhu(
q: jax.Array, # [B, T, H, K] (gated query)
k: jax.Array, # [B, T, H, K] (gated key)
w: jax.Array, # [B, T, H, K] (delta-rule erase weight)
gk: jax.Array, # [B, T, H, K] (per-key gate, cumsum'd, log2-space)
h0: Optional[jax.Array], # [B, H, K, V] or None (initial state)
dht: Optional[jax.Array], # [B, H, K, V] or None (gradient of final state)
do: jax.Array, # [B, T, H, V] (gradient of output)
dv: jax.Array, # [B, T, H, V] (gradient of v from dAv stage)
scale: float,
use_exp2: bool = True,
) -> Tuple[jax.Array, Optional[jax.Array], jax.Array]:
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The chunk_gated_delta_rule_bwd_dhu function is missing the chunk_size parameter in its signature, which is inconsistent with chunk_kda_bwd_dAv and the corresponding Pallas implementation. It currently hardcodes BT = DEFAULT_BT (line 516). Adding chunk_size as an argument would improve flexibility and consistency across the API.

def chunk_gated_delta_rule_bwd_dhu(
    q: jax.Array,                   # [B, T, H, K]  (gated query)
    k: jax.Array,                   # [B, T, H, K]  (gated key)
    w: jax.Array,                   # [B, T, H, K]  (delta-rule erase weight)
    gk: jax.Array,                  # [B, T, H, K]  (per-key gate, cumsum'd, log2-space)
    h0: Optional[jax.Array],        # [B, H, K, V] or None (initial state)
    dht: Optional[jax.Array],       # [B, H, K, V] or None (gradient of final state)
    do: jax.Array,                  # [B, T, H, V]  (gradient of output)
    dv: jax.Array,                  # [B, T, H, V]  (gradient of v from dAv stage)
    scale: float,
    chunk_size: int = DEFAULT_BT,
    use_exp2: bool = True,
) -> Tuple[jax.Array, Optional[jax.Array], jax.Array]:

B, T, H, K = q.shape
V = do.shape[-1]
BT = DEFAULT_BT
Contributor


medium

Use the chunk_size parameter instead of hardcoding DEFAULT_BT to ensure consistency with the updated function signature.

Suggested change
BT = DEFAULT_BT
BT = chunk_size

Comment on lines +539 to +544
q_bh = pad_to_multiple(q_bh, T_padded, 0)
k_bh = pad_to_multiple(k_bh, T_padded, 0)
w_bh = pad_to_multiple(w_bh, T_padded, 0)
gk_bh = pad_to_multiple(gk_bh, T_padded, 0)
do_bh = pad_to_multiple(do_bh, T_padded, 0)
dv_bh = pad_to_multiple(dv_bh, T_padded, 0)
Contributor


medium

Similar to the dAv function, the padding logic here is redundant due to the T % BT == 0 assertion at line 530. Removing these calls will simplify the implementation.


# dh0: not computed inside the Pallas kernel (would need final carry).
# For now, return None if h0 is None.
dh0 = None
Contributor


medium

The Pallas implementation returns None for dh0 (the gradient of the initial state), while the CPU reference computes it. dh0 is the final carry of the lax.scan operation and should be retrieved and returned to maintain functional parity with the CPU reference.
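The pattern the comment describes — returning the reverse scan's final carry as dh0 — can be sketched with a toy pure-Python stand-in for lax.scan (names b_dh/dh0 are borrowed from the discussion; the step function is illustrative, not the real recurrence):

```python
import numpy as np

def scan(f, init, xs, reverse=False):
    # Minimal pure-Python mimic of jax.lax.scan's interface:
    # returns (final_carry, stacked_outputs).
    idx = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    carry, ys = init, [None] * len(xs)
    for i in idx:
        carry, ys[i] = f(carry, xs[i])
    return carry, np.stack(ys)

# Toy stand-in for the reverse chunk recurrence: the carry plays the role of
# the running b_dh; the per-step output plays the role of the stored dh chunk.
def step(b_dh, x):
    y = b_dh            # emit the state *before* the update
    return b_dh + x, y  # accumulate gradient flowing backward in time

xs = np.array([1.0, 2.0, 3.0, 4.0])
final_carry, dh_chunks = scan(step, 0.0, xs, reverse=True)

# The final carry after the reverse sweep is exactly the dh0 analogue:
assert final_carry == xs.sum()   # 10.0
assert dh_chunks[-1] == 0.0      # last chunk sees no future gradient
```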


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
tops/ops/kda/chunk_bwd.py (1)

371-374: Parameters q and k are accepted but unused.

The function signature includes q and k parameters, and shape assertions are performed on them (Lines 400-401), but they are never used in the actual computation — only v, do, and A are reshaped and passed to the kernel.

If these parameters are kept for API compatibility with the Triton reference, consider documenting this in the docstring. Otherwise, they could be removed to avoid confusion.

📝 Suggested docstring update
     Args:
-        q:     [B, T, H, K]   query tensor.
-        k:     [B, T, H, K]   key tensor.
+        q:     [B, T, H, K]   query tensor (unused, for API compatibility).
+        k:     [B, T, H, K]   key tensor (unused, for API compatibility).
         v:     [B, T, H, V]   value tensor (v_new in the full backward).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/kda/chunk_bwd.py` around lines 371 - 374, The function
chunk_kda_bwd_dAv_kernel currently accepts parameters q and k but never uses
them; either remove q and k from the signature (and from any assertions like the
ones around lines referencing q/k) to avoid confusion, or explicitly document in
the chunk_kda_bwd_dAv_kernel docstring that q and k are accepted only for API
compatibility with the Triton reference and are intentionally unused; update the
signature or docstring and adjust/remove related shape assertions and comments
in the function body accordingly.
tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py (2)

282-312: Same dh0 coverage gap as CPU tests — use underscore prefix for unused variables.

Same issue as the CPU test file: dh0_ref and dh0_pallas are unpacked but never compared. Use underscore prefix to suppress warnings and document intent.

✅ Proposed fix
         # Triton ground truth
-        dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
+        dh_ref, _dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
             ...
         )

         # Pallas under test
-        dh_pallas, dh0_pallas, dv2_pallas = pallas_dhu(
+        dh_pallas, _dh0_pallas, dv2_pallas = pallas_dhu(
             ...
         )

Apply similar changes to Lines 325, 335, and 404.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py` around lines 282 - 312, The
test unpacks dh_ref, dh0_ref, dv2_ref and dh_pallas, dh0_pallas, dv2_pallas from
triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu but never uses
dh0_ref/dh0_pallas; update the unpacking to use an underscore-prefixed name for
the unused second return (e.g., replace dh0_ref/dh0_pallas with
_dh0_ref/_dh0_pallas) wherever these functions are called (references:
triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu, and variables dh_ref,
dh0_ref, dv2_ref, dh_pallas, dh0_pallas, dv2_pallas) so the intent is explicit
and linter warnings are suppressed; apply the same change to the other
occurrences noted in the review.

47-99: Consider extracting shared _generate_inputs helper to reduce duplication.

The _generate_inputs function is nearly identical between this file and test_chunk_bwd_dAv_dhu.py. Consider extracting it to a shared test utility module to reduce maintenance burden and ensure consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py` around lines 47 - 99, Extract
the duplicated _generate_inputs helper into a shared test utility module and
update both tests to import and call it; specifically move the function
(including its signature parameters B, T, H, K, V, chunk_size,
gate_logit_normalizer, dtype and its use of triton_chunk_kda_fwd,
triton_recompute_w_u_fwd, triton_chunk_gated_delta_rule_fwd_h, SEED, device)
into a new test helper (e.g., tests/util/test_helpers.py), export the function
name _generate_inputs, ensure the returned dict keys remain identical (q, k, v,
g, beta, h0, scale, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state, o,
final_state, do, dht, chunk_size), and update both test files to import
_generate_inputs from the new module so signatures and behavior are preserved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py`:
- Around line 282-312: The test unpacks dh0_ref and dh0_pallas from
triton_chunk_gated_delta_rule_bwd_dhu and cpu_dhu but never validates them, so a
missing dh0 from cpu_dhu can pass silently; update the test to either
(preferred) assert dh0 equality using compare_tensor("dh0 (fp32)", dh0_pallas,
dh0_ref, atol=1e-4, rtol=1e-4) alongside the existing dh/dv2 checks, or if dh0
is intentionally not produced yet, rename the unpacked variables to _ (e.g.,
replace dh0_ref/dh0_pallas with _) and add a one-line TODO comment referencing
cpu_dhu/triton_chunk_gated_delta_rule_bwd_dhu explaining why dh0 is skipped.
Ensure you modify the unpack at the triton_chunk_gated_delta_rule_bwd_dhu and
cpu_dhu call sites (lines where dh_ref, dh0_ref, dv2_ref and dh_pallas,
dh0_pallas, dv2_pallas are assigned).

In `@tops/cpu/ops/kda/chunk_bwd.py`:
- Around line 480-492: The function chunk_gated_delta_rule_bwd_dhu is missing a
chunk_size parameter (tests pass chunk_size=data["chunk_size"]) so add
chunk_size: int (with a default if desired) to the function signature and use
that parameter instead of the hardcoded DEFAULT_BT when computing BT (replace
the hardcoded BT = DEFAULT_BT usage inside the function). Update any internal
references that rely on BT to use the new chunk_size variable (e.g., where BT is
set/used) so the function accepts and respects the caller-provided chunk size.

In `@tops/ops/kda/chunk_bwd.py`:
- Around line 646-648: The Pallas kernel always sets dh0 = None (dh0 in
chunk_bwd.py) which mismatches the CPU reference; update the Pallas path to
capture the final reverse-scan carry (the running b_dh in scan_fn) and return
dh0 when h0 was provided: extract the final b_dh from the kernel/scan output,
reshape it to (B, H, K, V) to match the CPU reference (as CPU does with
dh0_flat.reshape(B, H, K, V)), and return that instead of None; if you
intentionally defer this, add a clear TODO in chunk_bwd.py noting that dh0 is
not yet implemented and why.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 47341f66-8661-463a-afaf-e0df4bd59d99

📥 Commits

Reviewing files that changed from the base of the PR and between 147efe7 and 6635c6c.

📒 Files selected for processing (4)
  • tests/ref/kda/test_chunk_bwd_dAv_dhu.py
  • tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
  • tops/cpu/ops/kda/chunk_bwd.py
  • tops/ops/kda/chunk_bwd.py

Comment on lines +282 to +312
dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
q=data["qg"], k=data["kg"], w=data["w"],
gk=data["g"],
h0=data["initial_state"],
dht=data["dht"],
do=data["do"], dv=dv,
scale=data["scale"],
use_exp2=True,
)

# Pallas under test
dh_pallas, dh0_pallas, dv2_pallas = cpu_dhu(
q=torch_to_jax(data["qg"]),
k=torch_to_jax(data["kg"]),
w=torch_to_jax(data["w"]),
gk=torch_to_jax(data["g"]),
h0=torch_to_jax(data["initial_state"]),
dht=torch_to_jax(data["dht"]),
do=torch_to_jax(data["do"]),
dv=torch_to_jax(dv),
scale=data["scale"],
chunk_size=data["chunk_size"],
)

# dh: Triton returns [B*NT, H, K, V], Pallas returns [B, NT, H, K, V]
NT = T // data["chunk_size"]
assert compare_tensor("dh (fp32)", dh_pallas,
dh_ref.reshape(B, NT, H, K, V),
atol=1e-4, rtol=1e-4)
assert compare_tensor("dv2 (fp32)", dv2_pallas, dv2_ref,
atol=1e-4, rtol=1e-4)


⚠️ Potential issue | 🟡 Minor

dh0 is extracted but never validated in tests.

The tests unpack dh0_ref and dh0_pallas (Lines 282, 293) but only compare dh and dv2 (Lines 308-312). This means:

  1. The dh0 output is never validated against the reference
  2. The discrepancy where Pallas returns None but Triton returns an actual tensor goes undetected

Consider either:

  • Adding dh0 comparison when the implementation is complete
  • Renaming to _ to acknowledge it's intentionally unused
  • Adding a TODO comment explaining why dh0 isn't tested
✅ Proposed fix to suppress warnings and document intent
         # Triton ground truth
-        dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
+        dh_ref, _dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
             q=data["qg"], k=data["kg"], w=data["w"],
             ...
         )

         # Pallas under test
-        dh_pallas, dh0_pallas, dv2_pallas = cpu_dhu(
+        dh_pallas, _dh0_pallas, dv2_pallas = cpu_dhu(
             ...
         )

         # dh: Triton returns [B*NT, H, K, V], Pallas returns [B, NT, H, K, V]
         NT = T // data["chunk_size"]
         assert compare_tensor("dh (fp32)", dh_pallas,
                               dh_ref.reshape(B, NT, H, K, V),
                               atol=1e-4, rtol=1e-4)
         assert compare_tensor("dv2 (fp32)", dv2_pallas, dv2_ref,
                               atol=1e-4, rtol=1e-4)
+        # TODO: Add dh0 comparison once Pallas kernel computes dh0
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 282-282: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


[warning] 293-293: Unpacked variable dh0_pallas is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


@beaver-infiscale

Warning

Size limit exceeded

This PR changes 656 lines of core code, exceeding the 200-line limit (excluding tests and documentation).

Suggested action: split this PR into multiple smaller PRs, each keeping core code changes under 200 lines.



@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/ref/kda/test_chunk_bwd_dAv_dhu.py (1)

397-418: dh0 shape assertion missing in test_dhu_shapes.

The shape test validates dh and dv2 shapes but omits dh0. For consistency with test_chunk_bwd_dAv_dhu_pallas.py:test_dhu_shapes (which asserts dh0.shape == (B, H, K, V)), consider adding:

-        dh, dh0, dv2 = cpu_dhu(
+        dh, _dh0, dv2 = cpu_dhu(
             ...
         )
         assert dh.shape == (B, NT, H, K, V)
         assert dv2.shape == (B, T, H, V)
         assert dh.dtype == jnp.float32
+        # TODO: assert _dh0.shape == (B, H, K, V) when dh0 validation is added
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py` around lines 397 - 418, Add an
assertion for dh0's shape in test_dhu_shapes: after calling cpu_dhu (the
function under test) assert that dh0.shape == (B, H, K, V) to match the
corresponding test in test_chunk_bwd_dAv_dhu_pallas.py and ensure consistency
with dh and dv2 checks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 066e0d3f-0ac3-4075-80b6-b719c11046e5

📥 Commits

Reviewing files that changed from the base of the PR and between 6635c6c and 31f4ea5.

📒 Files selected for processing (6)
  • tests/ref/kda/test_chunk_bwd.py
  • tests/ref/kda/test_chunk_bwd_dAv_dhu.py
  • tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
  • tests/ref/kda/test_chunk_bwd_pallas.py
  • tops/cpu/ops/kda/chunk_bwd.py
  • tops/ops/kda/chunk_bwd.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/cpu/ops/kda/chunk_bwd.py


@pathfinder-pf pathfinder-pf added this pull request to the merge queue Apr 14, 2026
Merged via the queue into main with commit 2a2e87c Apr 14, 2026
3 of 4 checks passed
@pathfinder-pf pathfinder-pf deleted the feat/dav_dhu branch April 14, 2026 05:42