Skip to content

core/ops/scan: reuse body state across iterations (skip per-timestep plan churn)#2257

Draft
czoli1976 wants to merge 2 commits into
sonos:mainfrom
czoli1976:feature/scan-iter-reuse
Draft

core/ops/scan: reuse body state across iterations (skip per-timestep plan churn)#2257
czoli1976 wants to merge 2 commits into
sonos:mainfrom
czoli1976:feature/scan-iter-reuse

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

Summary

Draft for feedback. Trims the per-iteration scaffolding in the optimized Scan
body loop: it runs the same plan with the same shapes every timestep, so we
can resolve the body's symbols once, reset between iterations without
discarding them, and reuse one drained input buffer — instead of a full
model_state.run() cycle (set_inputs → resolve_symbols → exec → outputs →
reset_turn) per timestep.

Behavior-preserving (it's a pure optimization). Modest but broad: it helps every
Scan/Loop/RNN model, most where the per-iter overhead is a larger fraction
of the body work.

What changes

  • core/src/plan.rs: reset_turn_keep_symbols() (light reset that keeps
    resolved_symbols), clear_resolved_symbols(), set_inputs_drain() (reuse
    the caller's buffer), and resolve_symbols_with_states() made pub(crate).
  • core/src/ops/scan/optimized.rs: the body loop resolves symbols on the first
    iteration only, resets with reset_turn_keep_symbols(), and reuses a single
    iter_inputs buffer (drained) — replacing the Option/Vec/flatten
    per-iter construction.
  • Gated by TRACT_DISABLE_SCAN_ITER_REUSE (defaults on) — both for A/B and as a
    safety switch on this core path. Happy to drop the gate if you'd rather.

Correctness — proven behavior-preserving

A/B in one binary (new = default, old = TRACT_DISABLE_SCAN_ITER_REUSE=1) gives
bit-identical results across GRU / LSTM / vanilla-RNN (× 2 sizes) and DFN3
df_dec — same outlier profile vs ORT every time (e.g. GRU 382/382, LSTM 30/30,
RNN 9975/9975, df_dec 8/96000). Full tract-core lib suite (240 tests) passes.

Performance (Apple M1 Pro, single-thread, median; vs the gated old path)

Native:

model old new Δ
gru 128/50 0.429 0.411 −4.2%
lstm 128/50 0.562 0.545 −3.0%
gru 256/100 1.910 1.889 −1.1%
lstm 256/100 2.518 2.526 ~par
df_dec (S=100) 5.999 5.988 −0.2%

WASM (wasmtime, +simd128,+relaxed-simd):

model old new Δ
lstm 256/100 3.68 3.60 −2.2%
gru 256/100 2.76 2.74 −0.8%
rnn 256/100 0.925 0.92 −0.8%
  • The win scales with the overhead fraction: bigger on small/lightweight RNNs,
    parity where the body matmul dominates (gru_256, df_dec). It's also more
    valuable as the body kernels get faster
    — e.g. once linalg/arm64/apple_amx: shape-aware AMX dispatch (5-43% canary wins, no regressions) #2220's AMX dispatch
    shrank the body matmul, gru_256 native went from 0% to −1.1%.
  • For context vs ORT on this machine: tract's RNN execution is competitive-to-
    faster (native 1-thread tract gru_256 1.89 ms vs ORT 2.34; vanilla RNN ~3×
    faster). This PR shaves a bit more off an already-decent position.

Measurement note

The A/B above isolates the run-path change: both arms share the cleaned-up
iter_inputs build, so the numbers slightly understate the full change vs
pristine main (which also did an extra Vec<Option> alloc + flatten per
iter). The pristine-vs-new delta is marginally larger.

Questions

  1. OK to keep the TRACT_DISABLE_SCAN_ITER_REUSE gate (safety/A-B), or drop it
    for a single clean path?
  2. Any concern with keeping resolved_symbols across body iterations? Body
    shapes are constant within a Scan eval; I clear them up front each eval
    since the body state persists across outer calls.

Test plan

  • Bit-identical new-vs-old across GRU/LSTM/RNN × 2 sizes + df_dec/erb_dec
  • tract-core lib suite (240 tests)
  • cargo fmt --check + cargo clippy clean (changed files)
  • Wider RNN/Loop conformance (ONNX backend suite) before leaving draft

🤖 Generated with Claude Code

@czoli1976
Copy link
Copy Markdown
Contributor Author

For context, there's an alternative to this PR worth your call before you spend review time here.

This PR is the investigation's "path B": keep the generic Scan, just make the per-iteration loop lighter. There's also a "path C" I prototyped earlier — a fused OptGru op that replaces the Scan-decomposed GRU entirely:

  • ONNX GRU lowers (via a try_wire_fused hook, gated TRACT_ENABLE_OPT_GRU=1) to a single OptGru op that owns its recurrence — no Scan, no per-timestep sub-plan re-entry. One hoisted input projection X·Wᵀ (linalg mmm), then per timestep a single combined recurrent mmv R·h_prev, fused sigmoid/tanh (the linalg SIMD kernels), and the (1−z)·h̃ + z·h recurrence — all in one Rust loop. Stateful, caches the packed W/R.
  • It's implemented and correct (differential tests bit-exact vs Scan, 4/4; df_dec real-input parity identical to Scan).

But on honest wall-clock it's only parity with Scan on large GRUs (DFN3 df_dec, hidden=256/seq=100), native and WASM. It removes the Scan scaffolding (~2.3 µs/iter) but adds ~equivalent own per-iter overhead (the mmv call + the activation kernel calls), so it nets out. That's why I shelved it rather than PR'ing it.

The one place the two genuinely diverge is small / lightweight RNNs, where the per-step scaffolding dominates the tiny matmul: a 128-hidden / 50-step GRU is ~1.4× ORT on native purely from the per-step Scan machinery (≈2.3 µs/iter × 50 ≈ the whole gap). There a fused op could win where this PR can only shave. On large GRUs like DFN3's, both are parity.

So the trade-off:

  • This PR — small, generic, low-risk; helps every Scan/Loop/RNN a little; keeps one code path.
  • OptGru — a new GRU-specific op (bigger surface, GRU-only) that's the only thing that closes the small-RNN per-step-overhead gap, but parity elsewhere.

Given you've leaned toward simple/generic over special-casing (and the cost-model discussion on #2253), my guess is you'd prefer this Scan tweak and not a fused op — but I wanted to surface OptGru explicitly rather than decide for you. Do you prefer this PR, the fused OptGru, both, or neither? Happy to open OptGru as a draft if you want to look, or drop it.

kali
kali previously approved these changes May 20, 2026
@kali
Copy link
Copy Markdown
Collaborator

kali commented May 20, 2026

Curious to see how some TDNN will react to that in the bench.

FWIW, I like that we can optimise Scan, even if we make it (moslty) obsolete. We have some whacky LSTM variants lying around, so the Scan is here to stay, even if GRU (and eventually some standard LSTM forms) get promoted down the road.

I think I have tiny models in the portfolio with recurring ops. OptGru may happen to be more interesting for these that it is for DFN.

@czoli1976
Copy link
Copy Markdown
Contributor Author

A small extra commit coming for this, will also share the optGru separately

@czoli1976
Copy link
Copy Markdown
Contributor Author

Heads-up: the red macOS / Nemotron speech streaming en 0.6b Large-models check on this PR is not caused by this change — it's a pre-existing breakage on main.

The nemotron encoder's ROI attention emits DiagGather + ScaledMaskedSoftmax, which aren't in the nemotron --metal/--cuda --assert-op-only allowlists in ci.sh, so the strict GPU op-coverage assert fails (Model has 48 unexpected op(s)). The CPU -O pass of all four sub-models — including the decoder, where this PR's Scan change actually runs — passes. The same job is red on other branches off current main (e.g. task/onnx-dtype-panic), and this PR doesn't touch those ops, the harness, or GPU coverage.

Fix is in #2259 (mirrors parakeet's allowlist, which already permits DiagGather).

@czoli1976
Copy link
Copy Markdown
Contributor Author

Filed the fused OptGru op (path C) as a draft for your look: #2260.

It lowers ONNX GRU to a single op that owns its recurrence (no Scan, no per-timestep sub-plan re-entry). Honest caveat (also in that PR): it's parity with Scan on large GRUs (df_dec) — it removes the Scan scaffolding but adds ~equivalent own per-iter overhead (mmv call + activation kernel calls) — so it's filed for direction, not as a perf win. It bit-matches Scan (differential tests, opt-in via TRACT_ENABLE_OPT_GRU=1).

So this PR (#2257, path B = small generic Scan tweak) and #2260 (path C = RNN-specific fused op) are the two alternatives we discussed; happy to take whichever direction you prefer, or neither.

@kali
Copy link
Copy Markdown
Collaborator

kali commented May 26, 2026

Couly you rebase this one over main ? the main position if was forked from breaks the bench, the bug has been fixed on main in the meantime. (not the nemotron one, another one)

czoli1976 and others added 2 commits May 26, 2026 09:42
The optimized Scan body runs the same plan with the same shapes every
timestep, so resolve its symbols once, reset between iters without
discarding them (reset_turn_keep_symbols), and reuse one drained input
buffer -- instead of a full model_state.run() cycle (set_inputs ->
resolve_symbols -> exec -> outputs -> reset_turn) per timestep.

Pure optimization: bit-identical to the old path across GRU/LSTM/RNN +
df_dec. Gated by TRACT_DISABLE_SCAN_ITER_REUSE (safety + A/B). Helps every
Scan/Loop/RNN model, most where per-iter overhead is a larger fraction of
the body work.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The TRACT_DISABLE_SCAN_ITER_REUSE check was an env::var per Scan eval,
which is pure overhead on single-iteration Scans (e.g. an autoregressive
RNN-T decoder stepped one token per call -- nemotron's decoder). Read it
once via OnceLock so the change is strictly neutral there.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@czoli1976 czoli1976 force-pushed the feature/scan-iter-reuse branch from f0edae5 to 857439e Compare May 26, 2026 09:32
@czoli1976
Copy link
Copy Markdown
Contributor Author

Rebased onto current main (was forked from f070a6d7f, now on cb9b1b15c).

Re-ran the one-binary A/B (new = default, old = TRACT_DISABLE_SCAN_ITER_REUSE=1) on Apple Silicon, single-thread (RAYON_NUM_THREADS=1), tract -O bench --allow-random-input, interleaved old/new per outer iter (31 runs, each self-averaged over 1.2 s):

model old p50 (ms) new p50 (ms) Δ p50 Δ min
gru 128/50 0.397 0.396 −0.4% −0.2%
lstm 128/50 0.500 0.495 −0.9% −0.3%
gru 256/100 1.757 1.753 −0.2% −0.7%
lstm 256/100 2.330 2.331 +0.0% −0.4%
rnn 256/100 0.620 0.621 +0.1% −1.4%
df_dec 5.95 5.92 −0.5% −0.3%

You were right — on fixed main the win is within noise (±1%). The −3/−4% I'd measured on small RNNs was an artifact of the pre-fix base, and it doesn't survive the rebase. It's still bit-identical new-vs-old and the tract-core suite passes, so the change is correct — but as a performance PR it no longer earns its keep.

So I'd lean toward closing it, unless you think the per-iter simplification stands on its own as a cleanup (in which case I'd drop the TRACT_DISABLE_SCAN_ITER_REUSE gate and re-present it as such). Your call.

(Models here are freshly-generated synthetic GRU/LSTM/RNN, so absolute ms differ slightly from my original table; the new-vs-old delta is the comparable quantity.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants