Conversation
Code Review
This pull request updates the KDA reference tests to use an updated triton_chunk_kda_fwd interface that returns intermediate tensors directly, removing the need for manual recomputation in the test utility functions. KDA tests have also been integrated into the GPU CI workflow. The review feedback identifies inconsistencies in the return dictionary structure of _generate_inputs across test files, noting that the original gate tensor g is shadowed by g_cumsum and that initial_state is assigned redundantly. It recommends aligning these structures with the other KDA tests so the original input tensors remain accessible.
```python
q=q, k=k, v=v, g=g_cumsum, beta=beta,
h0=initial_state, scale=scale,
Aqk=Aqk, Akk=Akk,
w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h,
initial_state=h0,
initial_state=initial_state,
```
The return dictionary in _generate_inputs is inconsistent with the structure used in tests/ref/kda/test_chunk_bwd.py. Specifically, g is assigned g_cumsum here (shadowing the original gate), and h0 is assigned initial_state. Additionally, initial_state is duplicated. It is recommended to follow the same structure as in other KDA tests (keeping g as the original gate and adding a separate g_cumsum key) to avoid confusion and maintain access to the original input tensors.
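The mismatch is easier to see side by side. Below is a minimal, hypothetical sketch of the two dictionary shapes the review contrasts; plain Python lists stand in for tensors, and only the key names (g, g_cumsum, h0, initial_state) come from the review itself:

```python
# Hypothetical sketch of the reviewer's point: keep the raw gate under "g"
# and expose the cumulative-sum gate under a separate "g_cumsum" key,
# instead of shadowing one with the other.

def generate_inputs_sketch():
    g = [0.1, 0.2, 0.3]          # raw gate values (placeholder data)
    g_cumsum = [0.1, 0.3, 0.6]   # running cumulative sum of g
    initial_state = [[0.0]]      # placeholder initial state

    # Problematic shape (what the review flags): "g" holds g_cumsum, so the
    # raw gate is lost, and initial_state appears under two keys.
    bad = dict(g=g_cumsum, h0=initial_state, initial_state=initial_state)

    # Suggested shape: every original input stays reachable under its own key.
    good = dict(g=g, g_cumsum=g_cumsum, initial_state=initial_state)
    return bad, good

bad, good = generate_inputs_sketch()
```

With the second shape, a test that later needs the raw gate (for example to recompute the cumulative sum independently) can still read it back out of the dictionary.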
```python
from tests.utils import compare_tensor
```

```python
from tops.ops.kda.chunk_bwd import (
```
Is this file for testing the TPU? Should it be in this directory?
📝 Walkthrough
The PR updates CI configuration to run additional KDA special test suites, increases the

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/ref/kda/test_chunk_bwd.py (1)
60-78: Extract the Triton-forward fixture before the copies drift further. This 12-value unpacking and `data` dict assembly now lives in four KDA backward suites, and the copies already disagree on key semantics: this helper keeps raw `g`/`h0` plus `g_cumsum`/`initial_state`, while the dAv+dhu helpers reuse `g` and `h0` for the transformed values. A shared helper would make the next FLA tuple change much safer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/test_chunk_bwd.py` around lines 60 - 78, Extract the repeated 12-value unpacking and dict assembly around triton_chunk_kda_fwd into a shared fixture/helper used by the KDA backward suites: create a helper function (e.g., make_triton_kda_fixture) that calls triton_chunk_kda_fwd(...) and returns a dict containing the raw inputs (q,k,v,g,h0,scale,beta,chunk_size), plus g_cumsum and initial_state and the computed outputs (o,final_state,Aqk,Akk,w,u,qg,kg,v_new,h); replace the in-place unpack + return dict in test_chunk_bwd.py and the other backward suites to call this helper and use the returned dict, ensuring the helper preserves the original semantics (keep raw g/h0 and g_cumsum/initial_state separate from any transformed values).
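The consolidation the bot suggests could look roughly like the sketch below. This is illustrative only: the real `triton_chunk_kda_fwd` is not importable here, so a stub stands in for it, and the 12-field unpacking order shown is an assumption, not the library's actual interface:

```python
# Sketch of a shared fixture helper for the KDA backward suites. The stub
# below mimics a forward call that returns a 12-tuple; the real function
# and its field order live in the project, not here.

def fake_triton_chunk_kda_fwd(q, k, v, g, beta, h0, scale):
    # Stub: 12 placeholder outputs in a fixed, assumed order
    # (g_cumsum, o, final_state, Aqk, Akk, w, u, qg, kg, v_new, h,
    # initial_state).
    return tuple(f"out{i}" for i in range(12))

def make_triton_kda_fixture(q, k, v, g, beta, h0, scale, chunk_size):
    (g_cumsum, o, final_state, Aqk, Akk,
     w, u, qg, kg, v_new, h, initial_state) = fake_triton_chunk_kda_fwd(
        q, k, v, g, beta, h0, scale)
    # Raw inputs and transformed values live under separate keys, so no
    # caller ever shadows g with g_cumsum or h0 with initial_state.
    return dict(
        q=q, k=k, v=v, g=g, beta=beta, h0=h0, scale=scale,
        chunk_size=chunk_size, g_cumsum=g_cumsum, initial_state=initial_state,
        o=o, final_state=final_state, Aqk=Aqk, Akk=Akk,
        w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h,
    )

data = make_triton_kda_fixture("q", "k", "v", "g", "b", "h0", 1.0, 64)
```

Because every suite would call the one helper, a future change to the forward tuple only needs one unpacking site updated.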
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d0565c7a-edd0-4b29-9ed7-22f71dd96758
📒 Files selected for processing (8)
- .github/ci/gpu-tests.sky.yaml
- tests/ref/kda/test_chunk_bwd.py
- tests/ref/kda/test_chunk_bwd_dAv_dhu.py
- tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
- tests/ref/kda/test_chunk_bwd_pallas.py
- tests/ref/kda/test_chunk_kda.py
- tests/ref/kda/test_fused_recurrent_kda.py
- tests/ref/kda/test_naive_kda.py
.github/ci/gpu-tests.sky.yaml (Outdated)

```yaml
tests/ref/simple_gla/ \
tests/ref/kda/ \
-k triton \
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, let's see the context of the gpu-tests.sky.yaml file around lines 30-32
fd -t f gpu-tests.sky.yaml
```

Repository: primatrix/pallas-kernel
Length of output: 49

🏁 Script executed:

```shell
# Once we find it, let's examine the full context around those lines
head -50 .github/ci/gpu-tests.sky.yaml
```

Repository: primatrix/pallas-kernel
Length of output: 734

🏁 Script executed:

```shell
# Now let's look at what test files exist in tests/ref/kda/
fd -t f '\.py$' tests/ref/kda/ | head -20
```

Repository: primatrix/pallas-kernel
Length of output: 356

🏁 Script executed:

```shell
# Let's examine the test names in the KDA directory to see if they contain "triton"
fd -t f '\.py$' tests/ref/kda/ -x grep -l "def test_\|class Test" {}
```

Repository: primatrix/pallas-kernel
Length of output: 330

🏁 Script executed:

```shell
# Let's check if any test names in KDA files actually contain "triton"
fd -t f '\.py$' tests/ref/kda/ -x grep -E "(def test_|class Test)" {} +
```

Repository: primatrix/pallas-kernel
Length of output: 5322

🏁 Script executed:

```shell
# Let's search for "triton" mentions in KDA test files
fd -t f '\.py$' tests/ref/kda/ -x grep -i "triton" {} +
```

Repository: primatrix/pallas-kernel
Length of output: 15042

🏁 Script executed:

```shell
# Let's check if there are any pytest markers defined in these files
fd -t f '\.py$' tests/ref/kda/ -x grep -E "@pytest|@mark" {} +
```

Repository: primatrix/pallas-kernel
Length of output: 2520

🏁 Script executed:

```shell
# Let's also check git to understand if this is a proposed change or current state
git log --oneline -20
```

Repository: primatrix/pallas-kernel
Length of output: 89
-k triton still misses part of the KDA directory.
pytest -k matches filenames, class names, and function names. While some KDA tests include "triton" in their identifiers (test_cpu_ref_vs_triton_bwd_intra, test_against_triton), others like the class methods in TestChunkKdaBwdDAvCpu and TestChunkKdaBwdDAvPallas do not. Adding tests/ref/kda/ without adjusting the filter will leave those tests out of GPU CI.
Consider using explicit test selection, defining a real pytest marker (@pytest.mark.triton) plus -m triton, or removing the -k triton filter if all KDA tests should run on GPU. (docs.pytest.org)
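A stdlib-only toy model of the two selection modes makes the gap concrete. pytest itself is not imported here; the node IDs and the `test_dav` method name are illustrative stand-ins for the real tests named above:

```python
# Mimics pytest's selection logic: `-k` does a case-insensitive substring
# match against the node ID (file, class, and function names), while `-m`
# matches only explicitly applied markers.

def select_by_keyword(node_ids, keyword):
    # Rough model of `pytest -k KEYWORD` for a plain-word expression.
    return [nid for nid in node_ids if keyword.lower() in nid.lower()]

def select_by_marker(markers_by_id, marker):
    # Rough model of `pytest -m MARKER`.
    return [nid for nid, marks in markers_by_id.items() if marker in marks]

node_ids = [
    "tests/ref/kda/test_chunk_bwd.py::test_cpu_ref_vs_triton_bwd_intra",
    "tests/ref/kda/test_chunk_kda.py::TestChunkKdaBwdDAvCpu::test_dav",
]

# Keyword selection silently drops the test whose name lacks "triton".
kept = select_by_keyword(node_ids, "triton")

# Marker selection keeps whatever is tagged, regardless of spelling.
markers = {nid: {"triton"} for nid in node_ids}
kept_by_marker = select_by_marker(markers, "triton")
```

In real pytest, the `triton` marker would also need to be registered under the `markers` ini option (for example in pyproject.toml) so that `-m triton` runs without unknown-marker warnings.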
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In @.github/ci/gpu-tests.sky.yaml around lines 30 - 32, The CI config currently
uses `-k triton` which only selects tests with "triton" in names and thus misses
tests under `tests/ref/kda/` (e.g., class methods in TestChunkKdaBwdDAvCpu and
TestChunkKdaBwdDAvPallas); update test selection to reliably include all
intended KDA GPU tests by either adding an explicit pytest marker (decorate
relevant tests like test_cpu_ref_vs_triton_bwd_intra, test_against_triton and
the KDA classes with `@pytest.mark.triton` and change the CI flag to `-m triton`)
or remove the `-k triton` filter so `tests/ref/kda/` runs unfiltered in GPU CI.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/ref/kda/special/test_chunk_bwd_pallas.py (1)
150-155: ⚠️ Potential issue | 🟡 Minor

Missing tolerance parameters in `compare_tensor` calls. These comparisons lack explicit `atol`/`rtol` parameters. The sibling tests (test_chunk_bwd_dAv_dhu.py, test_chunk_bwd_dAv_dhu_pallas.py) consistently specify tolerances (e.g., `atol=5e-5, rtol=5e-5` for float32).

Proposed fix to add tolerance parameters:

```diff
- compare_tensor("dq", dq_jax, dq_ref)
- compare_tensor("dk", dk_jax, dk_ref)
- compare_tensor("dv", dv_jax, dv_ref)
- compare_tensor("db", db_jax, db_ref)
- compare_tensor("dg", dg_jax, dg_ref)
- compare_tensor("dA", dA_jax, dA_ref)
+ compare_tensor("dq", dq_jax, dq_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dk", dk_jax, dk_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dv", dv_jax, dv_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("db", db_jax, db_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dg", dg_jax, dg_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dA", dA_jax, dA_ref, atol=5e-5, rtol=5e-5)
```

As per coding guidelines: "Use `compare_tensor` utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 150 - 155, The compare_tensor calls for dq/dk/dv/db/dg/dA (compare_tensor("dq", dq_jax, dq_ref), etc.) are missing explicit tolerance arguments; update each call to include the same tolerances used in the sibling tests (for float32 use atol=5e-5, rtol=5e-5, and optionally max_ulp if your project uses it) so the assertions match expected numerical slack—apply these parameters to every compare_tensor invocation in this block.

tests/ref/kda/special/test_chunk_bwd.py (1)
150-155: ⚠️ Potential issue | 🟡 Minor

Missing tolerance parameters in `compare_tensor` calls. Same issue as test_chunk_bwd_pallas.py. Without explicit tolerances, test behavior depends on `compare_tensor` defaults, which may be too strict or too lenient for these specific kernel comparisons.

Proposed fix:

```diff
- compare_tensor("dq", dq_jax, dq_ref)
- compare_tensor("dk", dk_jax, dk_ref)
- compare_tensor("dv", dv_jax, dv_ref)
- compare_tensor("db", db_jax, db_ref)
- compare_tensor("dg", dg_jax, dg_ref)
- compare_tensor("dA", dA_jax, dA_ref)
+ compare_tensor("dq", dq_jax, dq_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dk", dk_jax, dk_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dv", dv_jax, dv_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("db", db_jax, db_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dg", dg_jax, dg_ref, atol=5e-5, rtol=5e-5)
+ compare_tensor("dA", dA_jax, dA_ref, atol=5e-5, rtol=5e-5)
```

As per coding guidelines: "Use `compare_tensor` utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 150 - 155, The six compare_tensor calls (compare_tensor("dq", dq_jax, dq_ref), "dk", "dv", "db", "dg", "dA") are missing explicit tolerance arguments; update each call to pass explicit atol, rtol and max_ulp values when comparing kernel outputs (use the same tolerance choices as other kernel tests such as in test_chunk_bwd_pallas.py or the guidance in tests/utils.py) so comparisons are deterministic and appropriate for FP error; modify the compare_tensor invocations for dq, dk, dv, db, dg, and dA to include those three parameters.
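To illustrate why pinning the tolerances matters, here is a minimal stand-in for such a check. This is not the project's `compare_tensor` (its real signature lives in tests/utils.py); the `atol + rtol * |ref|` bound simply mirrors the usual allclose convention:

```python
# Minimal sketch of a tolerance-aware comparison. With atol=rtol=0 (a
# strict default), float32-scale noise fails the check; the explicit
# 5e-5 tolerances from the sibling suites accept it.

def allclose(name, got, ref, atol=0.0, rtol=0.0):
    # name mirrors compare_tensor's first argument; unused in this sketch.
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(got, ref))

ref = [1.0, 2.0, 3.0]
got = [1.0 + 3e-5, 2.0 - 4e-5, 3.0 + 9e-5]  # simulated kernel output noise

# Zero tolerance: the tiny deviations already fail...
strict = allclose("dq", got, ref)
# ...while the explicit tolerances used elsewhere in the suite pass.
loose = allclose("dq", got, ref, atol=5e-5, rtol=5e-5)
```

Making the tolerances explicit in each call also documents, at the assertion site, how much numerical slack a given kernel comparison is allowed.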
🧹 Nitpick comments (4)
tests/ref/kda/special/test_chunk_bwd_pallas.py (2)
37-39: Consider importing `torch_to_jax` from `tests.utils` for consistency. This file defines `torch_to_jax` locally, while test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py import it from `tests.utils`. Using the shared utility reduces duplication.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 37 - 39, The local function torch_to_jax defined in this test should be removed and replaced with a shared import from tests.utils to avoid duplication; delete the local def torch_to_jax(...) and add an import bringing torch_to_jax from tests.utils (ensure tests.utils exports torch_to_jax), updating the top-of-file imports to reference the symbol instead of redefining it.
60-82: Dictionary mappings differ from sibling test files. This file stores both `g` and `g_cumsum` separately (line 74), while test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py set `g=g_cumsum`. Similarly, this file uses the local `h0` (line 75), while the others use `h0=initial_state`. While the logic appears correct (downstream usages reference `g_cumsum` explicitly at line 103), the inconsistency may cause confusion during maintenance.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 60 - 82, The dict returned by the test currently exposes both g and g_cumsum and uses h0 directly; make it consistent with sibling tests by mapping g to g_cumsum (i.e., g=g_cumsum) and set h0 to the produced initial_state (i.e., h0=initial_state) in the returned dict so callers see the same keys/semantics as in test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py; update the dictionary entries that reference g, g_cumsum, h0, and initial_state accordingly.

tests/ref/kda/special/test_chunk_bwd.py (2)
73-82: Raw `g` is stored but appears unused in this test file. Line 74 stores both `g=g` and `g_cumsum=g_cumsum`, but all downstream usages reference `g_cumsum` (lines 103, 127, 139, 168, 194, 203). The raw gate values could be removed to reduce dictionary size, though this is minor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 73 - 82, The returned dict currently includes both g and g_cumsum but the raw gate g is never used elsewhere in the test; remove the unused g entry from the return dict (the dict creation that currently lists g=g, g_cumsum=g_cumsum) so only g_cumsum is returned, or alternatively replace any downstream references to g_cumsum with g if you intended to keep raw g—update the dict in the function that builds the return value accordingly.
37-39: Local `torch_to_jax` duplicates the shared utility in `tests.utils`. Same issue as test_chunk_bwd_pallas.py. Consider importing from `tests.utils` to maintain a single source of truth.

Proposed fix:

```diff
-def torch_to_jax(t: torch.Tensor) -> jax.Array:
-    """PyTorch CUDA tensor -> JAX array (via CPU)."""
-    return jnp.array(t.detach().cpu().float().numpy())
+from tests.utils import torch_to_jax
```

Note: This requires moving the import to line 8 with other test utility imports.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 37 - 39, The file defines a local utility function torch_to_jax that duplicates the shared helper in tests.utils; remove the local torch_to_jax function and instead import the shared helper from tests.utils (use the existing symbol name torch_to_jax) alongside the other test utility imports so there is a single source of truth (move/add the import with the other test utility imports near the top).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bbf7ee92-2276-4620-a9e3-24e390191c7d
📒 Files selected for processing (7)
- .github/ci/gpu-tests.sky.yaml
- pyproject.toml
- tests/ref/kda/special/test_chunk_bwd.py
- tests/ref/kda/special/test_chunk_bwd_dAv_dhu.py
- tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py
- tests/ref/kda/special/test_chunk_bwd_pallas.py
- tests/ref/kda/test_fused_recurrent_kda.py
✅ Files skipped from review due to trivial changes (1)
- pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/ref/kda/test_fused_recurrent_kda.py
- .github/ci/gpu-tests.sky.yaml
Description
use fla==0.4.2 to test
Related Issue
Closes #
Change Type
- feat — New feature
- fix — Bug fix
- refactor — Code refactoring
- docs — Documentation
- ci — CI/CD changes
- test — Tests
- perf — Performance improvement

Checklist
- `uv run ruff check src/ tests/` and `uv run ruff format src/ tests/`
- (`assert` or `assert_shape_or_none`)
- Tests in `tests/ops/`, `tests/modules/`, `tests/layers/`, or `tests/ref/`
- If `tops/cpu/` is modified, core developers have been notified and PR is labeled `cpu-ref`

Test Results
Summary by CodeRabbit
Release Notes
Dependencies
Tests
Chores