
Fix/new api#196

Merged
pathfinder-pf merged 8 commits into main from fix/new_api
Apr 14, 2026

Conversation

Collaborator

@pathfinder-pf pathfinder-pf commented Apr 14, 2026

Description

Use fla==0.4.2 to test.

Related Issue

Closes #

Change Type

  • feat — New feature
  • fix — Bug fix
  • refactor — Code refactoring
  • docs — Documentation
  • ci — CI/CD changes
  • test — Tests
  • perf — Performance improvement

Checklist

  • Code passes uv run ruff check src/ tests/ and uv run ruff format src/ tests/
  • New/modified public APIs have complete docstrings (tensor shapes, dimension semantics, business logic)
  • Public functions have input assertions (assert or assert_shape_or_none)
  • Tests added at the appropriate layer (tests/ops/, tests/modules/, tests/layers/, or tests/ref/)
  • If tops/cpu/ is modified, core developers have been notified and PR is labeled cpu-ref

Test Results

Paste relevant test output here.

Summary by CodeRabbit

Release Notes

  • Dependencies

    • Updated flash-linear-attention to version >=0.4.2 for improved compatibility
  • Tests

    • Refined numerical tolerance thresholds for floating-point precision testing
    • Enhanced test coverage and execution flow for GPU-accelerated operations
  • Chores

    • Optimized CI GPU test pipeline with expanded execution scope

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the KDA reference tests to utilize an updated triton_chunk_kda_fwd interface that returns intermediate tensors directly, thereby removing the need for manual recomputation within the test utility functions. Additionally, KDA tests have been integrated into the GPU CI workflow. The review feedback identifies inconsistencies in the return dictionary structure of _generate_inputs across different test files, specifically noting that the original gate tensor g is being shadowed by g_cumsum and that initial_state is being redundantly assigned. It is recommended to align these structures with the other KDA tests to maintain access to original input tensors.

Comment on lines +79 to +83
q=q, k=k, v=v, g=g_cumsum, beta=beta,
h0=initial_state, scale=scale,
Aqk=Aqk, Akk=Akk,
w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h,
initial_state=h0,
initial_state=initial_state,
Contributor

medium

The return dictionary in _generate_inputs is inconsistent with the structure used in tests/ref/kda/test_chunk_bwd.py. Specifically, g is assigned g_cumsum here (shadowing the original gate), and h0 is assigned initial_state. Additionally, initial_state is duplicated. It is recommended to follow the same structure as in other KDA tests (keeping g as the original gate and adding a separate g_cumsum key) to avoid confusion and maintain access to the original input tensors.
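A minimal sketch of the structure the comment recommends (key names are taken from the review; the function name and placeholder values are hypothetical stand-ins for the real tensors produced in `_generate_inputs`):

```python
# Hypothetical sketch of the recommended _generate_inputs return value:
# keep the raw gate under "g", expose the cumulative sum under its own
# "g_cumsum" key, and assign initial_state exactly once.
def build_inputs_dict(q, k, v, g, g_cumsum, beta, initial_state, scale,
                      Aqk, Akk, w, u, qg, kg, v_new, h):
    return dict(
        q=q, k=k, v=v,
        g=g,                    # original gate, not shadowed by g_cumsum
        g_cumsum=g_cumsum,      # transformed gate under a separate key
        beta=beta,
        h0=initial_state,       # mirrors tests/ref/kda/test_chunk_bwd.py
        scale=scale,
        Aqk=Aqk, Akk=Akk,
        w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h,
        initial_state=initial_state,
    )
```

This keeps both the raw and transformed tensors reachable and removes the duplicate `initial_state` assignment.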

Comment on lines +80 to +84
q=q, k=k, v=v, g=g_cumsum, beta=beta,
h0=initial_state, scale=scale,
Aqk=Aqk, Akk=Akk,
w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h,
initial_state=h0,
initial_state=initial_state,
Contributor

medium

The return dictionary in _generate_inputs is inconsistent with the structure used in tests/ref/kda/test_chunk_bwd.py. Specifically, g is assigned g_cumsum here (shadowing the original gate), and h0 is assigned initial_state. Additionally, initial_state is duplicated. It is recommended to follow the same structure as in other KDA tests (keeping g as the original gate and adding a separate g_cumsum key) to avoid confusion and maintain access to the original input tensors.


from tests.utils import compare_tensor

from tops.ops.kda.chunk_bwd import (
Collaborator

Is this file for testing the TPU? Should it be in this directory?


coderabbitai bot commented Apr 14, 2026

📝 Walkthrough

The PR updates CI configuration to run additional KDA special test suites, increases the flash-linear-attention dependency version, loosens numerical tolerance thresholds in a fused recurrent test, and refactors multiple KDA test files to use an expanded forward kernel API that returns intermediates directly instead of requiring separate recomputation calls.

Changes

Cohort / File(s) — Summary

CI Configuration — .github/ci/gpu-tests.sky.yaml
Extended the GPU test pipeline to run pytest twice: once targeting tests/ref/kda/special/ with triton filters, then tests/ref/simple_gla/ with the original options.

Dependency Upgrade — pyproject.toml
Raised the minimum flash-linear-attention version from >=0.4.1 to >=0.4.2 in the GPU optional dependencies.

Numerical Tolerance Adjustment — tests/ref/kda/test_fused_recurrent_kda.py
Loosened comparison tolerances (atol/rtol) from 5e-3 to 5e-2 for non-float32 precision cases in the Triton vs. CPU reference test.

KDA Test Refactoring (Forward API Migration) — tests/ref/kda/special/test_chunk_bwd.py, tests/ref/kda/special/test_chunk_bwd_dAv_dhu.py, tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py, tests/ref/kda/special/test_chunk_bwd_pallas.py
Migrated forward kernel imports to fla.ops.kda.chunk_fwd.chunk_kda_fwd and restructured the tests to destructure the expanded return tuple (g_cumsum, w, u, qg, kg, v_new, h, initial_state) with disable_recompute=True, eliminating separate recomputation helper calls.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

beaver/needs-split

Suggested reviewers

  • 0xaskr

Poem

🐰 The forward kernel blooms anew,
No recomputation makes tests cleaner, true!
Intermediates flow in one return,
Fresh tolerances let differences churn,
A kernel dance, simplified and bright! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

Title check — ❓ Inconclusive
The title 'Fix/new api' is too vague and generic to convey meaningful information about the changeset, which involves multiple test updates, a dependency version bump, and API signature changes.
Resolution: use a more specific title that captures the main change, such as 'Update KDA tests for flash-linear-attention 0.4.2 API' or 'Refactor KDA tests to use new forward/backward function signatures'.

✅ Passed checks (2 passed)

Description Check — ✅ Passed
Check skipped - CodeRabbit's high-level summary is enabled.

Docstring Coverage — ✅ Passed
Docstring coverage is 100.00%, which is sufficient (required threshold: 80.00%).



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/ref/kda/test_chunk_bwd.py (1)

60-78: Extract the Triton-forward fixture before the copies drift further.

This 12-value unpacking and data dict assembly now lives in four KDA backward suites, and the copies already disagree on key semantics: this helper keeps raw g/h0 plus g_cumsum/initial_state, while the dAv+dhu helpers reuse g and h0 for the transformed values. A shared helper would make the next FLA tuple change much safer.
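A pure-Python sketch of the suggested consolidation (`make_triton_kda_fixture` is the reviewer's proposed name; `run_fwd` is a hypothetical stand-in for the real `triton_chunk_kda_fwd`, which needs fla and a GPU):

```python
# Hypothetical shared helper: one place that calls the forward kernel and
# assembles the data dict, so all four backward suites agree on key semantics.
def make_triton_kda_fixture(run_fwd, *, q, k, v, g, beta, h0, scale, chunk_size):
    # raw inputs first, under their original names
    data = dict(q=q, k=k, v=v, g=g, beta=beta, h0=h0,
                scale=scale, chunk_size=chunk_size)
    # run_fwd is assumed to return a dict of outputs/intermediates, e.g.
    # g_cumsum, initial_state, o, final_state, Aqk, Akk, w, u, qg, kg, v_new, h
    outputs = run_fwd(q=q, k=k, v=v, g=g, beta=beta, h0=h0, scale=scale)
    overlap = set(outputs) & set(data)
    assert not overlap, f"forward outputs would shadow raw inputs: {overlap}"
    data.update(outputs)
    return data
```

With the shadowing guard, the next FLA return-tuple change fails loudly in one place instead of silently diverging across the four copies.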

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/test_chunk_bwd.py` around lines 60 - 78, Extract the repeated
12-value unpacking and dict assembly around triton_chunk_kda_fwd into a shared
fixture/helper used by the KDA backward suites: create a helper function (e.g.,
make_triton_kda_fixture) that calls triton_chunk_kda_fwd(...) and returns a dict
containing the raw inputs (q,k,v,g,h0,scale,beta,chunk_size), plus g_cumsum and
initial_state and the computed outputs
(o,final_state,Aqk,Akk,w,u,qg,kg,v_new,h); replace the in-place unpack + return
dict in test_chunk_bwd.py and the other backward suites to call this helper and
use the returned dict, ensuring the helper preserves the original semantics
(keep raw g/h0 and g_cumsum/initial_state separate from any transformed values).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d0565c7a-edd0-4b29-9ed7-22f71dd96758

📥 Commits

Reviewing files that changed from the base of the PR and between 2a2e87c and e2f4377.

📒 Files selected for processing (8)
  • .github/ci/gpu-tests.sky.yaml
  • tests/ref/kda/test_chunk_bwd.py
  • tests/ref/kda/test_chunk_bwd_dAv_dhu.py
  • tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py
  • tests/ref/kda/test_chunk_bwd_pallas.py
  • tests/ref/kda/test_chunk_kda.py
  • tests/ref/kda/test_fused_recurrent_kda.py
  • tests/ref/kda/test_naive_kda.py

Comment on lines 30 to 32
tests/ref/simple_gla/ \
tests/ref/kda/ \
-k triton \

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's see the context of the gpu-tests.sky.yaml file around lines 30-32
fd -t f gpu-tests.sky.yaml

Repository: primatrix/pallas-kernel

Length of output: 49


🏁 Script executed:

# Once we find it, let's examine the full context around those lines
head -50 .github/ci/gpu-tests.sky.yaml

Repository: primatrix/pallas-kernel

Length of output: 734


🏁 Script executed:

# Now let's look at what test files exist in tests/ref/kda/
fd -t f '\.py$' tests/ref/kda/ | head -20

Repository: primatrix/pallas-kernel

Length of output: 356


🏁 Script executed:

# Let's examine the test names in the KDA directory to see if they contain "triton"
fd -t f '\.py$' tests/ref/kda/ -x grep -l "def test_\|class Test" {}

Repository: primatrix/pallas-kernel

Length of output: 330


🏁 Script executed:

# Let's check if any test names in KDA files actually contain "triton"
fd -t f '\.py$' tests/ref/kda/ -x grep -E "(def test_|class Test)" {} +

Repository: primatrix/pallas-kernel

Length of output: 5322


🏁 Script executed:

# Let's search for "triton" mentions in KDA test files
fd -t f '\.py$' tests/ref/kda/ -x grep -i "triton" {} +

Repository: primatrix/pallas-kernel

Length of output: 15042


🏁 Script executed:

# Let's check if there are any pytest markers defined in these files
fd -t f '\.py$' tests/ref/kda/ -x grep -E "@pytest|@mark" {} +

Repository: primatrix/pallas-kernel

Length of output: 2520


🏁 Script executed:

# Let's also check git to understand if this is a proposed change or current state
git log --oneline -20

Repository: primatrix/pallas-kernel

Length of output: 89


-k triton still misses part of the KDA directory.

pytest -k matches filenames, class names, and function names. While some KDA tests include "triton" in their identifiers (test_cpu_ref_vs_triton_bwd_intra, test_against_triton), others like the class methods in TestChunkKdaBwdDAvCpu and TestChunkKdaBwdDAvPallas do not. Adding tests/ref/kda/ without adjusting the filter will leave those tests out of GPU CI.

Consider using explicit test selection, defining a real pytest marker (@pytest.mark.triton) plus -m triton, or removing the -k triton filter if all KDA tests should run on GPU. (docs.pytest.org)
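The two selection modes can be contrasted with a toy model (the first two test ids come from the comment; the TestChunkKdaBwd* method names are hypothetical, and real `-k` matching is richer than a substring check, but the failure mode is the same):

```python
# Crude model of `pytest -k triton`: selection is driven by the names in the
# test id, so tests whose file/class/function names never say "triton" are
# skipped even if they exercise the Triton kernels.
test_ids = [
    "tests/ref/kda/test_chunk_bwd.py::test_cpu_ref_vs_triton_bwd_intra",
    "tests/ref/kda/test_chunk_kda.py::test_against_triton",
    "tests/ref/kda/test_chunk_bwd_dAv_dhu.py::TestChunkKdaBwdDAvCpu::test_bwd",
    "tests/ref/kda/test_chunk_bwd_dAv_dhu_pallas.py::TestChunkKdaBwdDAvPallas::test_bwd",
]

selected_by_k = [t for t in test_ids if "triton" in t.lower()]
missed_by_k = [t for t in test_ids if "triton" not in t.lower()]

# With an explicit marker, e.g.
#     @pytest.mark.triton
#     class TestChunkKdaBwdDAvCpu: ...
# `-m triton` selects by marker membership rather than by naming accident:
marked = set(test_ids)  # suppose all four suites carry @pytest.mark.triton
selected_by_m = [t for t in test_ids if t in marked]
```

Note that a custom marker must also be registered (in pytest.ini or pyproject.toml) to avoid unknown-marker warnings.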

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/ci/gpu-tests.sky.yaml around lines 30 - 32, The CI config currently
uses `-k triton` which only selects tests with "triton" in names and thus misses
tests under `tests/ref/kda/` (e.g., class methods in TestChunkKdaBwdDAvCpu and
TestChunkKdaBwdDAvPallas); update test selection to reliably include all
intended KDA GPU tests by either adding an explicit pytest marker (decorate
relevant tests like test_cpu_ref_vs_triton_bwd_intra, test_against_triton and
the KDA classes with `@pytest.mark.triton` and change the CI flag to `-m triton`)
or remove the `-k triton` filter so `tests/ref/kda/` runs unfiltered in GPU CI.


@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/ref/kda/special/test_chunk_bwd_pallas.py (1)

150-155: ⚠️ Potential issue | 🟡 Minor

Missing tolerance parameters in compare_tensor calls.

These comparisons lack explicit atol/rtol parameters. The sibling tests (test_chunk_bwd_dAv_dhu.py, test_chunk_bwd_dAv_dhu_pallas.py) consistently specify tolerances (e.g., atol=5e-5, rtol=5e-5 for float32).

Proposed fix to add tolerance parameters
-        compare_tensor("dq", dq_jax, dq_ref)
-        compare_tensor("dk", dk_jax, dk_ref)
-        compare_tensor("dv", dv_jax, dv_ref)
-        compare_tensor("db", db_jax, db_ref)
-        compare_tensor("dg", dg_jax, dg_ref)
-        compare_tensor("dA", dA_jax, dA_ref)
+        compare_tensor("dq", dq_jax, dq_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dk", dk_jax, dk_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dv", dv_jax, dv_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("db", db_jax, db_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dg", dg_jax, dg_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dA", dA_jax, dA_ref, atol=5e-5, rtol=5e-5)

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"
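For reference, the check such a helper performs reduces to the usual mixed-tolerance test (a hypothetical scalar analogue; the real compare_tensor in tests/utils.py operates on tensors and may also support max_ulp):

```python
# Hypothetical scalar analogue of a compare_tensor-style check:
# pass when |actual - expected| <= atol + rtol * |expected|.
def compare_scalar(name, actual, expected, atol=5e-5, rtol=5e-5):
    err = abs(actual - expected)
    bound = atol + rtol * abs(expected)
    assert err <= bound, f"{name}: error {err:.2e} exceeds bound {bound:.2e}"

compare_scalar("dq", 1.00002, 1.0)  # within float32-scale tolerance
```

Passing tolerances explicitly, as the diff above does, makes the accepted numerical slack visible at the call site instead of hiding it in defaults.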

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 150 - 155, The
compare_tensor calls for dq/dk/dv/db/dg/dA (compare_tensor("dq", dq_jax,
dq_ref), etc.) are missing explicit tolerance arguments; update each call to
include the same tolerances used in the sibling tests (for float32 use
atol=5e-5, rtol=5e-5, and optionally max_ulp if your project uses it) so the
assertions match expected numerical slack—apply these parameters to every
compare_tensor invocation in this block.
tests/ref/kda/special/test_chunk_bwd.py (1)

150-155: ⚠️ Potential issue | 🟡 Minor

Missing tolerance parameters in compare_tensor calls.

Same issue as test_chunk_bwd_pallas.py. Without explicit tolerances, test behavior depends on compare_tensor defaults which may be too strict or too lenient for these specific kernel comparisons.

Proposed fix
-        compare_tensor("dq", dq_jax, dq_ref)
-        compare_tensor("dk", dk_jax, dk_ref)
-        compare_tensor("dv", dv_jax, dv_ref)
-        compare_tensor("db", db_jax, db_ref)
-        compare_tensor("dg", dg_jax, dg_ref)
-        compare_tensor("dA", dA_jax, dA_ref)
+        compare_tensor("dq", dq_jax, dq_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dk", dk_jax, dk_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dv", dv_jax, dv_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("db", db_jax, db_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dg", dg_jax, dg_ref, atol=5e-5, rtol=5e-5)
+        compare_tensor("dA", dA_jax, dA_ref, atol=5e-5, rtol=5e-5)

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 150 - 155, The six
compare_tensor calls (compare_tensor("dq", dq_jax, dq_ref), "dk", "dv", "db",
"dg", "dA") are missing explicit tolerance arguments; update each call to pass
explicit atol, rtol and max_ulp values when comparing kernel outputs (use the
same tolerance choices as other kernel tests such as in test_chunk_bwd_pallas.py
or the guidance in tests/utils.py) so comparisons are deterministic and
appropriate for FP error; modify the compare_tensor invocations for dq, dk, dv,
db, dg, and dA to include those three parameters.
🧹 Nitpick comments (4)
tests/ref/kda/special/test_chunk_bwd_pallas.py (2)

37-39: Consider importing torch_to_jax from tests.utils for consistency.

This file defines torch_to_jax locally, while test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py import it from tests.utils. Using the shared utility reduces duplication.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 37 - 39, The
local function torch_to_jax defined in this test should be removed and replaced
with a shared import from tests.utils to avoid duplication; delete the local def
torch_to_jax(...) and add an import bringing torch_to_jax from tests.utils
(ensure tests.utils exports torch_to_jax), updating the top-of-file imports to
reference the symbol instead of redefining it.

60-82: Dictionary mappings differ from sibling test files.

This file stores both g and g_cumsum separately (line 74), while test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py set g=g_cumsum. Similarly, this file uses the local h0 (line 75), while the others use h0=initial_state.

While the logic appears correct (downstream usages reference g_cumsum explicitly at line 103), the inconsistency may cause confusion during maintenance.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd_pallas.py` around lines 60 - 82, The
dict returned by the test currently exposes both g and g_cumsum and uses h0
directly; make it consistent with sibling tests by mapping g to g_cumsum (i.e.,
g=g_cumsum) and set h0 to the produced initial_state (i.e., h0=initial_state) in
the returned dict so callers see the same keys/semantics as in
test_chunk_bwd_dAv_dhu.py and test_chunk_bwd_dAv_dhu_pallas.py; update the
dictionary entries that reference g, g_cumsum, h0, and initial_state
accordingly.
tests/ref/kda/special/test_chunk_bwd.py (2)

73-82: Raw g is stored but appears unused in this test file.

Line 74 stores both g=g and g_cumsum=g_cumsum, but all downstream usages reference g_cumsum (lines 103, 127, 139, 168, 194, 203). The raw gate values could be removed to reduce dictionary size, though this is minor.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 73 - 82, The returned
dict currently includes both g and g_cumsum but the raw gate g is never used
elsewhere in the test; remove the unused g entry from the return dict (the dict
creation that currently lists g=g, g_cumsum=g_cumsum) so only g_cumsum is
returned, or alternatively replace any downstream references to g_cumsum with g
if you intended to keep raw g—update the dict in the function that builds the
return value accordingly.

37-39: Local torch_to_jax duplicates the shared utility in tests.utils.

Same issue as test_chunk_bwd_pallas.py. Consider importing from tests.utils to maintain a single source of truth.

Proposed fix
-def torch_to_jax(t: torch.Tensor) -> jax.Array:
-    """PyTorch CUDA tensor -> JAX array (via CPU)."""
-    return jnp.array(t.detach().cpu().float().numpy())
+from tests.utils import torch_to_jax

Note: This requires moving the import to line 8 with other test utility imports.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ref/kda/special/test_chunk_bwd.py` around lines 37 - 39, The file
defines a local utility function torch_to_jax that duplicates the shared helper
in tests.utils; remove the local torch_to_jax function and instead import the
shared helper from tests.utils (use the existing symbol name torch_to_jax)
alongside the other test utility imports so there is a single source of truth
(move/add the import with the other test utility imports near the top).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bbf7ee92-2276-4620-a9e3-24e390191c7d

📥 Commits

Reviewing files that changed from the base of the PR and between e2f4377 and f618003.

📒 Files selected for processing (7)
  • .github/ci/gpu-tests.sky.yaml
  • pyproject.toml
  • tests/ref/kda/special/test_chunk_bwd.py
  • tests/ref/kda/special/test_chunk_bwd_dAv_dhu.py
  • tests/ref/kda/special/test_chunk_bwd_dAv_dhu_pallas.py
  • tests/ref/kda/special/test_chunk_bwd_pallas.py
  • tests/ref/kda/test_fused_recurrent_kda.py
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ref/kda/test_fused_recurrent_kda.py
  • .github/ci/gpu-tests.sky.yaml

@pathfinder-pf pathfinder-pf added this pull request to the merge queue Apr 14, 2026
Merged via the queue into main with commit 0e0f4e0 Apr 14, 2026
3 of 4 checks passed
@pathfinder-pf pathfinder-pf deleted the fix/new_api branch April 14, 2026 08:25