Conversation
Warning: Size limit exceeded. This PR changes 653 lines of core code, exceeding the 200-line limit (excluding tests and documentation). Suggested action: split this PR into multiple smaller PRs, each with at most 200 lines of core code changes.
📝 Walkthrough

🚥 Pre-merge checks
- ❌ Failed checks (1 warning, 1 inconclusive)
- ✅ Passed checks (1 passed)
Code Review
This pull request implements CPU reference and Pallas kernels for chunk_kda_bwd_dAv and chunk_gated_delta_rule_bwd_dhu, including comprehensive tests comparing outputs against Triton references. Review feedback highlights a shape inconsistency in the Pallas dh output and the need for functional parity regarding the initial state gradient (dh0). Additionally, suggestions were made to simplify the CPU reference by removing redundant padding logic and to improve API consistency by exposing the chunk_size parameter in the gated delta rule backward implementation.
```python
# --- Reshape back ---
# dh: [BH, NT, K, V] -> [B, H, NT, K, V] -> [B, NT, H, K, V]
dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)
```
The shape of the returned dh tensor is [B, NT, H, K, V], which is inconsistent with the CPU reference implementation in tops/cpu/ops/kda/chunk_bwd.py (line 650) that returns [B * NT, H, K, V]. This inconsistency should be resolved by flattening the batch and chunk dimensions to match the CPU reference and common conventions for these kernels. Please also update the docstring at line 557 to reflect this change.
```diff
-dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)
+dh = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4).reshape(B * NT, H, K, V)
```
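As a sanity check on the suggested flatten, a minimal NumPy sketch (sizes are arbitrary illustration values, not taken from the PR) confirms that chunk `nt` of batch `b` lands at flat row `b * NT + nt`, matching the `[B * NT, H, K, V]` convention of the CPU reference:

```python
import numpy as np

# Arbitrary small sizes for illustration only.
B, H, NT, K, V = 2, 3, 4, 5, 6
dh_r = np.arange(B * H * NT * K * V, dtype=np.float32).reshape(B * H, NT, K, V)

# Kernel output layout: [BH, NT, K, V] -> [B, NT, H, K, V]
dh_5d = dh_r.reshape(B, H, NT, K, V).transpose(0, 2, 1, 3, 4)

# Suggested fix: flatten batch and chunk dims to match the CPU reference.
dh_flat = dh_5d.reshape(B * NT, H, K, V)

# Chunk nt of batch b lands at flat row b * NT + nt.
assert np.array_equal(dh_flat[1 * NT + 2], dh_5d[1, 2])
```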
```python
q_bh = pad_to_multiple(q_bh, T_padded, axis=0)
k_bh = pad_to_multiple(k_bh, T_padded, axis=0)
v_bh = pad_to_multiple(v_bh, T_padded, axis=0)
do_bh = pad_to_multiple(do_bh, T_padded, axis=0)
A_bh = pad_to_multiple(A_bh, T_padded, axis=0)
```
```python
def chunk_gated_delta_rule_bwd_dhu(
    q: jax.Array,              # [B, T, H, K] (gated query)
    k: jax.Array,              # [B, T, H, K] (gated key)
    w: jax.Array,              # [B, T, H, K] (delta-rule erase weight)
    gk: jax.Array,             # [B, T, H, K] (per-key gate, cumsum'd, log2-space)
    h0: Optional[jax.Array],   # [B, H, K, V] or None (initial state)
    dht: Optional[jax.Array],  # [B, H, K, V] or None (gradient of final state)
    do: jax.Array,             # [B, T, H, V] (gradient of output)
    dv: jax.Array,             # [B, T, H, V] (gradient of v from dAv stage)
    scale: float,
    use_exp2: bool = True,
) -> Tuple[jax.Array, Optional[jax.Array], jax.Array]:
```
The chunk_gated_delta_rule_bwd_dhu function is missing the chunk_size parameter in its signature, which is inconsistent with chunk_kda_bwd_dAv and the corresponding Pallas implementation. It currently hardcodes BT = DEFAULT_BT (line 516). Adding chunk_size as an argument would improve flexibility and consistency across the API.
```python
def chunk_gated_delta_rule_bwd_dhu(
    q: jax.Array,              # [B, T, H, K] (gated query)
    k: jax.Array,              # [B, T, H, K] (gated key)
    w: jax.Array,              # [B, T, H, K] (delta-rule erase weight)
    gk: jax.Array,             # [B, T, H, K] (per-key gate, cumsum'd, log2-space)
    h0: Optional[jax.Array],   # [B, H, K, V] or None (initial state)
    dht: Optional[jax.Array],  # [B, H, K, V] or None (gradient of final state)
    do: jax.Array,             # [B, T, H, V] (gradient of output)
    dv: jax.Array,             # [B, T, H, V] (gradient of v from dAv stage)
    scale: float,
    chunk_size: int = DEFAULT_BT,
    use_exp2: bool = True,
) -> Tuple[jax.Array, Optional[jax.Array], jax.Array]:
```
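If `chunk_size` is threaded through, the chunk partitioning it controls can be sketched as follows. This is a NumPy illustration of the padding convention; `pad_to_chunks` is a hypothetical stand-in for the repo's `pad_to_multiple` helper, not code from the PR:

```python
import math
import numpy as np

def pad_to_chunks(x: np.ndarray, chunk_size: int):
    """Pad the time axis (axis 0) up to a multiple of chunk_size."""
    T = x.shape[0]
    NT = math.ceil(T / chunk_size)          # number of chunks
    T_padded = NT * chunk_size
    pad_widths = [(0, T_padded - T)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_widths), NT

x = np.ones((10, 4))                        # T = 10
x_padded, NT = pad_to_chunks(x, chunk_size=4)
assert x_padded.shape == (12, 4) and NT == 3
```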
tops/cpu/ops/kda/chunk_bwd.py
```python
    """
    B, T, H, K = q.shape
    V = do.shape[-1]
    BT = DEFAULT_BT
```
```python
q_bh = pad_to_multiple(q_bh, T_padded, 0)
k_bh = pad_to_multiple(k_bh, T_padded, 0)
w_bh = pad_to_multiple(w_bh, T_padded, 0)
gk_bh = pad_to_multiple(gk_bh, T_padded, 0)
do_bh = pad_to_multiple(do_bh, T_padded, 0)
dv_bh = pad_to_multiple(dv_bh, T_padded, 0)
```
tops/ops/kda/chunk_bwd.py
```python
# dh0: not computed inside the Pallas kernel (would need final carry).
# For now, return None if h0 is None.
dh0 = None
```
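The review's point, that `dh0` is exactly the carry left over after the reverse scan, can be illustrated with a small NumPy loop mirroring what a `jax.lax.scan(..., reverse=True)` formulation would do. The additive chunk update here is a simplified placeholder, not the real kernel math:

```python
import numpy as np

def reverse_scan_dh0(dh_contribs: np.ndarray, dht: np.ndarray):
    """Walk chunks last-to-first, threading a running carry b_dh.

    The carry remaining after the first chunk is the gradient with
    respect to the initial state h0 (i.e. dh0). The additive update
    below is a simplified placeholder for the real per-chunk math.
    """
    NT = dh_contribs.shape[0]
    dh = np.empty_like(dh_contribs)
    b_dh = dht.copy()                  # initial carry: grad of final state
    for t in range(NT - 1, -1, -1):
        dh[t] = b_dh                   # per-chunk output before the update
        b_dh = b_dh + dh_contribs[t]   # propagate one chunk backward
    return dh, b_dh                    # b_dh plays the role of dh0

NT, K, V = 4, 3, 2
dh, dh0 = reverse_scan_dh0(np.ones((NT, K, V)), np.zeros((K, V)))
assert float(dh0[0, 0]) == NT          # carry accumulated all NT chunks
```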
Actionable comments posted: 3
🧹 Nitpick comments (3)
tops/ops/kda/chunk_bwd.py (1)
371-374: Parameters `q` and `k` are accepted but unused. The function signature includes `q` and `k`, and shape assertions are performed on them (Lines 400-401), but they are never used in the actual computation — only `v`, `do`, and `A` are reshaped and passed to the kernel. If these parameters are kept for API compatibility with the Triton reference, consider documenting this in the docstring. Otherwise, they could be removed to avoid confusion.

📝 Suggested docstring update

```diff
 Args:
-    q: [B, T, H, K] query tensor.
-    k: [B, T, H, K] key tensor.
+    q: [B, T, H, K] query tensor (unused, for API compatibility).
+    k: [B, T, H, K] key tensor (unused, for API compatibility).
     v: [B, T, H, V] value tensor (v_new in the full backward).
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/kda/chunk_bwd.py` around lines 371 - 374, The function chunk_kda_bwd_dAv_kernel currently accepts parameters q and k but never uses them; either remove q and k from the signature (and from any assertions like the ones around lines referencing q/k) to avoid confusion, or explicitly document in the chunk_kda_bwd_dAv_kernel docstring that q and k are accepted only for API compatibility with the Triton reference and are intentionally unused; update the signature or docstring and adjust/remove related shape assertions and comments in the function body accordingly.

tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py (2)
282-312: Same `dh0` coverage gap as CPU tests — use underscore prefix for unused variables. Same issue as the CPU test file: `dh0_ref` and `dh0_pallas` are unpacked but never compared. Use underscore prefix to suppress warnings and document intent.

✅ Proposed fix

```diff
 # Triton ground truth
-dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
+dh_ref, _dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
     ...
 )
 # Pallas under test
-dh_pallas, dh0_pallas, dv2_pallas = pallas_dhu(
+dh_pallas, _dh0_pallas, dv2_pallas = pallas_dhu(
     ...
 )
```

Apply similar changes to Lines 325, 335, and 404.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py` around lines 282 - 312, The test unpacks dh_ref, dh0_ref, dv2_ref and dh_pallas, dh0_pallas, dv2_pallas from triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu but never uses dh0_ref/dh0_pallas; update the unpacking to use an underscore-prefixed name for the unused second return (e.g., replace dh0_ref/dh0_pallas with _dh0_ref/_dh0_pallas) wherever these functions are called (references: triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu, and variables dh_ref, dh0_ref, dv2_ref, dh_pallas, dh0_pallas, dv2_pallas) so the intent is explicit and linter warnings are suppressed; apply the same change to the other occurrences noted in the review.
47-99: Consider extracting a shared `_generate_inputs` helper to reduce duplication. The `_generate_inputs` function is nearly identical between this file and `test_chunk_bwd_dAv_dhu.py`. Consider extracting it to a shared test utility module to reduce maintenance burden and ensure consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py` around lines 47 - 99, Extract the duplicated _generate_inputs helper into a shared test utility module and update both tests to import and call it; specifically move the function (including its signature parameters B, T, H, K, V, chunk_size, gate_logit_normalizer, dtype and its use of triton_chunk_kda_fwd, triton_recompute_w_u_fwd, triton_chunk_gated_delta_rule_fwd_h, SEED, device) into a new test helper (e.g., tests/util/test_helpers.py), export the function name _generate_inputs, ensure the returned dict keys remain identical (q, k, v, g, beta, h0, scale, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state, o, final_state, do, dht, chunk_size), and update both test files to import _generate_inputs from the new module so signatures and behavior are preserved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py`:
- Around line 282-312: The test unpacks dh0_ref and dh0_pallas from
triton_chunk_gated_delta_rule_bwd_dhu and cpu_dhu but never validates them, so a
missing dh0 from cpu_dhu can pass silently; update the test to either
(preferred) assert dh0 equality using compare_tensor("dh0 (fp32)", dh0_pallas,
dh0_ref, atol=1e-4, rtol=1e-4) alongside the existing dh/dv2 checks, or if dh0
is intentionally not produced yet, rename the unpacked variables to _ (e.g.,
replace dh0_ref/dh0_pallas with _) and add a one-line TODO comment referencing
cpu_dhu/triton_chunk_gated_delta_rule_bwd_dhu explaining why dh0 is skipped.
Ensure you modify the unpack at the triton_chunk_gated_delta_rule_bwd_dhu and
cpu_dhu call sites (lines where dh_ref, dh0_ref, dv2_ref and dh_pallas,
dh0_pallas, dv2_pallas are assigned).
In `@tops/cpu/ops/kda/chunk_bwd.py`:
- Around line 480-492: The function chunk_gated_delta_rule_bwd_dhu is missing a
chunk_size parameter (tests pass chunk_size=data["chunk_size"]) so add
chunk_size: int (with a default if desired) to the function signature and use
that parameter instead of the hardcoded DEFAULT_BT when computing BT (replace
the hardcoded BT = DEFAULT_BT usage inside the function). Update any internal
references that rely on BT to use the new chunk_size variable (e.g., where BT is
set/used) so the function accepts and respects the caller-provided chunk size.
In `@tops/ops/kda/chunk_bwd.py`:
- Around line 646-648: The Pallas kernel always sets dh0 = None (dh0 in
chunk_bwd.py) which mismatches the CPU reference; update the Pallas path to
capture the final reverse-scan carry (the running b_dh in scan_fn) and return
dh0 when h0 was provided: extract the final b_dh from the kernel/scan output,
reshape it to (B, H, K, V) to match the CPU reference (as CPU does with
dh0_flat.reshape(B, H, K, V)), and return that instead of None; if you
intentionally defer this, add a clear TODO in chunk_bwd.py noting that dh0 is
not yet implemented and why.
---
Nitpick comments:
In `@tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py`:
- Around line 282-312: The test unpacks dh_ref, dh0_ref, dv2_ref and dh_pallas,
dh0_pallas, dv2_pallas from triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu
but never uses dh0_ref/dh0_pallas; update the unpacking to use an
underscore-prefixed name for the unused second return (e.g., replace
dh0_ref/dh0_pallas with _dh0_ref/_dh0_pallas) wherever these functions are
called (references: triton_chunk_gated_delta_rule_bwd_dhu and pallas_dhu, and
variables dh_ref, dh0_ref, dv2_ref, dh_pallas, dh0_pallas, dv2_pallas) so the
intent is explicit and linter warnings are suppressed; apply the same change to
the other occurrences noted in the review.
- Around line 47-99: Extract the duplicated _generate_inputs helper into a
shared test utility module and update both tests to import and call it;
specifically move the function (including its signature parameters B, T, H, K,
V, chunk_size, gate_logit_normalizer, dtype and its use of triton_chunk_kda_fwd,
triton_recompute_w_u_fwd, triton_chunk_gated_delta_rule_fwd_h, SEED, device)
into a new test helper (e.g., tests/util/test_helpers.py), export the function
name _generate_inputs, ensure the returned dict keys remain identical (q, k, v,
g, beta, h0, scale, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state, o,
final_state, do, dht, chunk_size), and update both test files to import
_generate_inputs from the new module so signatures and behavior are preserved.
In `@tops/ops/kda/chunk_bwd.py`:
- Around line 371-374: The function chunk_kda_bwd_dAv_kernel currently accepts
parameters q and k but never uses them; either remove q and k from the signature
(and from any assertions like the ones around lines referencing q/k) to avoid
confusion, or explicitly document in the chunk_kda_bwd_dAv_kernel docstring that
q and k are accepted only for API compatibility with the Triton reference and
are intentionally unused; update the signature or docstring and adjust/remove
related shape assertions and comments in the function body accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 47341f66-8661-463a-afaf-e0df4bd59d99
📒 Files selected for processing (4)
- tests/ref/kda/test_chunk_bwd_dAv_dhu.py
- tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
- tops/cpu/ops/kda/chunk_bwd.py
- tops/ops/kda/chunk_bwd.py
```python
dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
    q=data["qg"], k=data["kg"], w=data["w"],
    gk=data["g"],
    h0=data["initial_state"],
    dht=data["dht"],
    do=data["do"], dv=dv,
    scale=data["scale"],
    use_exp2=True,
)

# Pallas under test
dh_pallas, dh0_pallas, dv2_pallas = cpu_dhu(
    q=torch_to_jax(data["qg"]),
    k=torch_to_jax(data["kg"]),
    w=torch_to_jax(data["w"]),
    gk=torch_to_jax(data["g"]),
    h0=torch_to_jax(data["initial_state"]),
    dht=torch_to_jax(data["dht"]),
    do=torch_to_jax(data["do"]),
    dv=torch_to_jax(dv),
    scale=data["scale"],
    chunk_size=data["chunk_size"],
)

# dh: Triton returns [B*NT, H, K, V], Pallas returns [B, NT, H, K, V]
NT = T // data["chunk_size"]
assert compare_tensor("dh (fp32)", dh_pallas,
                      dh_ref.reshape(B, NT, H, K, V),
                      atol=1e-4, rtol=1e-4)
assert compare_tensor("dv2 (fp32)", dv2_pallas, dv2_ref,
                      atol=1e-4, rtol=1e-4)
```
dh0 is extracted but never validated in tests.
The tests unpack dh0_ref and dh0_pallas (Lines 282, 293) but only compare dh and dv2 (Lines 308-312). This means:
- The `dh0` output is never validated against the reference
- The discrepancy where Pallas returns `None` but Triton returns an actual tensor goes undetected

Consider either:
- Adding `dh0` comparison when the implementation is complete
- Renaming to `_` to acknowledge it's intentionally unused
- Adding a TODO comment explaining why `dh0` isn't tested
✅ Proposed fix to suppress warnings and document intent
```diff
 # Triton ground truth
-dh_ref, dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
+dh_ref, _dh0_ref, dv2_ref = triton_chunk_gated_delta_rule_bwd_dhu(
     q=data["qg"], k=data["kg"], w=data["w"],
     ...
 )
 # Pallas under test
-dh_pallas, dh0_pallas, dv2_pallas = cpu_dhu(
+dh_pallas, _dh0_pallas, dv2_pallas = cpu_dhu(
     ...
 )
 # dh: Triton returns [B*NT, H, K, V], Pallas returns [B, NT, H, K, V]
 NT = T // data["chunk_size"]
 assert compare_tensor("dh (fp32)", dh_pallas,
                       dh_ref.reshape(B, NT, H, K, V),
                       atol=1e-4, rtol=1e-4)
 assert compare_tensor("dv2 (fp32)", dv2_pallas, dv2_ref,
                       atol=1e-4, rtol=1e-4)
+# TODO: Add dh0 comparison once Pallas kernel computes dh0
```

🧰 Tools
🪛 Ruff (0.15.9)
[warning] 282-282: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 293-293: Unpacked variable dh0_pallas is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py` around lines 282 - 312, The test
unpacks dh0_ref and dh0_pallas from triton_chunk_gated_delta_rule_bwd_dhu and
cpu_dhu but never validates them, so a missing dh0 from cpu_dhu can pass
silently; update the test to either (preferred) assert dh0 equality using
compare_tensor("dh0 (fp32)", dh0_pallas, dh0_ref, atol=1e-4, rtol=1e-4)
alongside the existing dh/dv2 checks, or if dh0 is intentionally not produced
yet, rename the unpacked variables to _ (e.g., replace dh0_ref/dh0_pallas with
_) and add a one-line TODO comment referencing
cpu_dhu/triton_chunk_gated_delta_rule_bwd_dhu explaining why dh0 is skipped.
Ensure you modify the unpack at the triton_chunk_gated_delta_rule_bwd_dhu and
cpu_dhu call sites (lines where dh_ref, dh0_ref, dv2_ref and dh_pallas,
dh0_pallas, dv2_pallas are assigned).
Warning: Size limit exceeded. This PR changes 656 lines of core code, exceeding the 200-line limit (excluding tests and documentation). Suggested action: split this PR into multiple smaller PRs, each with at most 200 lines of core code changes.
🧹 Nitpick comments (1)
tests/ref/kda/test_chunk_bwd_dAv_dhu.py (1)
397-418: `dh0` shape assertion missing in `test_dhu_shapes`. The shape test validates `dh` and `dv2` shapes but omits `dh0`. For consistency with `test_chunk_bwd_dAv_dhu_pallas.py:test_dhu_shapes` (which asserts `dh0.shape == (B, H, K, V)`), consider adding:

```diff
-dh, dh0, dv2 = cpu_dhu(
+dh, _dh0, dv2 = cpu_dhu(
     ...
 )
 assert dh.shape == (B, NT, H, K, V)
 assert dv2.shape == (B, T, H, V)
 assert dh.dtype == jnp.float32
+# TODO: assert _dh0.shape == (B, H, K, V) when dh0 validation is added
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py` around lines 397 - 418, Add an assertion for dh0's shape in test_dhu_shapes: after calling cpu_dhu (the function under test) assert that dh0.shape == (B, H, K, V) to match the corresponding test in test_chunk_bwd_dAv_dhu_pallas.py and ensure consistency with dh and dv2 checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/ref/kda/test_chunk_bwd_dAv_dhu.py`:
- Around line 397-418: Add an assertion for dh0's shape in test_dhu_shapes:
after calling cpu_dhu (the function under test) assert that dh0.shape == (B, H,
K, V) to match the corresponding test in test_chunk_bwd_dAv_dhu_pallas.py and
ensure consistency with dh and dv2 checks.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 066e0d3f-0ac3-4075-80b6-b719c11046e5
📒 Files selected for processing (6)
- tests/ref/kda/test_chunk_bwd.py
- tests/ref/kda/test_chunk_bwd_dAv_dhu.py
- tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
- tests/ref/kda/test_chunk_bwd_pallas.py
- tops/cpu/ops/kda/chunk_bwd.py
- tops/ops/kda/chunk_bwd.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tops/cpu/ops/kda/chunk_bwd.py
Description
add dav and dhu kernel for chunk_kda
Related Issue
Closes #
Change Type
- feat — New feature
- fix — Bug fix
- refactor — Code refactoring
- docs — Documentation
- ci — CI/CD changes
- test — Tests
- perf — Performance improvement

Checklist
- `uv run ruff check src/ tests/` and `uv run ruff format src/ tests/`
- (`assert` or `assert_shape_or_none`)
- (`tests/ops/`, `tests/modules/`, `tests/layers/`, or `tests/ref/`)
- If `tops/cpu/` is modified, core developers have been notified and PR is labeled `cpu-ref`

Test Results
Summary by CodeRabbit
New Features
Tests