core/ops/scan: reuse body state across iterations (skip per-timestep plan churn) #2257
czoli1976 wants to merge 2 commits into
|
For context, there's an alternative to this PR worth your call before you spend review time here. This PR is the investigation's "path B": keep the generic
But on honest wall-clock it's only parity with The one place the two genuinely diverge is small / lightweight RNNs, where the per-step scaffolding dominates the tiny matmul: a 128-hidden / 50-step GRU is ~1.4× ORT on native purely from the per-step So the trade-off:
Given you've leaned toward simple/generic over special-casing (and the cost-model discussion on #2253), my guess is you'd prefer this Scan tweak and not a fused op — but I wanted to surface |
|
Curious to see how some TDNN will react to that in the bench. FWIW, I like that we can optimise Scan, even if we make it (mostly) obsolete. We have some whacky LSTM variants lying around, so the Scan is here to stay, even if GRU (and eventually some standard LSTM forms) get promoted down the road. I think I have tiny models in the portfolio with recurring ops. OptGru may happen to be more interesting for these than it is for DFN. |
|
A small extra commit coming for this, will also share the optGru separately |
|
Heads-up: the red macOS / Nemotron speech streaming en 0.6b Large-models check on this PR is not caused by this change — it's a pre-existing breakage on The nemotron encoder's ROI attention emits Fix is in #2259 (mirrors |
|
Filed the fused It lowers ONNX So this PR (#2257, path B = small generic Scan tweak) and #2260 (path C = RNN-specific fused op) are the two alternatives we discussed; happy to take whichever direction you prefer, or neither. |
|
Could you rebase this one over main? The main position it was forked from breaks the bench; the bug has been fixed on main in the meantime. (Not the nemotron one, another one.)
The optimized Scan body runs the same plan with the same shapes every timestep, so resolve its symbols once, reset between iters without discarding them (reset_turn_keep_symbols), and reuse one drained input buffer -- instead of a full model_state.run() cycle (set_inputs -> resolve_symbols -> exec -> outputs -> reset_turn) per timestep. Pure optimization: bit-identical to the old path across GRU/LSTM/RNN + df_dec. Gated by TRACT_DISABLE_SCAN_ITER_REUSE (safety + A/B). Helps every Scan/Loop/RNN model, most where per-iter overhead is a larger fraction of the body work. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The TRACT_DISABLE_SCAN_ITER_REUSE check was an env::var per Scan eval, which is pure overhead on single-iteration Scans (e.g. an autoregressive RNN-T decoder stepped one token per call -- nemotron's decoder). Read it once via OnceLock so the change is strictly neutral there. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
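The read-once pattern described in this commit message can be sketched as follows. This is a minimal illustration, not the code from the commit: the helper name `scan_iter_reuse_disabled` is hypothetical, and treating "variable is set to anything" as "disabled" is an assumption about the gate's semantics.

```rust
use std::sync::OnceLock;

/// Hypothetical helper: reports whether the reuse path is disabled.
/// The environment is queried only on the first call; later calls
/// return the cached value, so a hot loop pays no env lookup.
fn scan_iter_reuse_disabled() -> bool {
    static DISABLED: OnceLock<bool> = OnceLock::new();
    *DISABLED.get_or_init(|| {
        // Assumption: any value (even an empty one) counts as "disabled".
        std::env::var_os("TRACT_DISABLE_SCAN_ITER_REUSE").is_some()
    })
}
```

One consequence of caching is that changing the variable after the first call has no effect for the lifetime of the process, which is usually acceptable for an A/B switch set at launch.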
f0edae5 to
857439e
|
Rebased onto current Re-ran the one-binary A/B (new = default, old =
You were right — on fixed So I'd lean toward closing it, unless you think the per-iter simplification stands on its own as a cleanup (in which case I'd drop the (Models here are freshly-generated synthetic GRU/LSTM/RNN, so absolute ms differ slightly from my original table; the new-vs-old delta is the comparable quantity.) |
Summary
Draft for feedback. Trims the per-iteration scaffolding in the optimized
Scan body loop: it runs the same plan with the same shapes every timestep, so we
can resolve the body's symbols once, reset between iterations without
discarding them, and reuse one drained input buffer, instead of a full
model_state.run() cycle (set_inputs → resolve_symbols → exec → outputs → reset_turn) per timestep.

Behavior-preserving (it's a pure optimization). Modest but broad: it helps every
Scan/Loop/RNN model, most where the per-iter overhead is a larger fraction of the body work.
What changes
- core/src/plan.rs: reset_turn_keep_symbols() (light reset that keeps resolved_symbols), clear_resolved_symbols(), set_inputs_drain() (reuse the caller's buffer), and resolve_symbols_with_states() made pub(crate).
- core/src/ops/scan/optimized.rs: the body loop resolves symbols on the first iteration only, resets with reset_turn_keep_symbols(), and reuses a single iter_inputs buffer (drained), replacing the Option/Vec/flatten per-iter construction.
- TRACT_DISABLE_SCAN_ITER_REUSE (defaults on): both for A/B and as a safety switch on this core path. Happy to drop the gate if you'd rather.
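To make the shape of the loop concrete, here is a toy sketch of the idea: resolve once, keep the resolution across a light reset, and feed each step through one reused buffer. `ToyState` and `run_scan` are hypothetical stand-ins written for this illustration; only the method names echo the ones listed above, and none of this is tract's actual code or signatures.

```rust
/// Toy stand-in for a plan state. NOT tract's real type.
#[derive(Default)]
struct ToyState {
    resolved_symbols: Option<usize>, // stands in for the resolved symbol table
    inputs: Vec<f32>,
    resolutions: usize, // counts how often symbols were resolved
}

impl ToyState {
    fn resolve_symbols(&mut self) {
        self.resolved_symbols = Some(self.inputs.len());
        self.resolutions += 1;
    }
    /// Moves values out of the caller's buffer; the buffer keeps its allocation.
    fn set_inputs_drain(&mut self, buf: &mut Vec<f32>) {
        self.inputs.clear();
        self.inputs.extend(buf.drain(..));
    }
    /// Placeholder for running the body: here, just a sum.
    fn exec(&self) -> f32 {
        self.inputs.iter().sum()
    }
    /// Light reset: drops per-turn data but keeps the resolved symbols.
    fn reset_turn_keep_symbols(&mut self) {
        self.inputs.clear();
    }
    fn clear_resolved_symbols(&mut self) {
        self.resolved_symbols = None;
    }
}

fn run_scan(state: &mut ToyState, steps: &[[f32; 2]]) -> Vec<f32> {
    // The state outlives one eval, so start each eval with fresh symbols.
    state.clear_resolved_symbols();
    let mut iter_inputs: Vec<f32> = Vec::with_capacity(2); // allocated once
    let mut outputs = Vec::with_capacity(steps.len());
    for step in steps {
        iter_inputs.extend_from_slice(step);
        state.set_inputs_drain(&mut iter_inputs);
        if state.resolved_symbols.is_none() {
            state.resolve_symbols(); // first iteration only
        }
        outputs.push(state.exec());
        state.reset_turn_keep_symbols();
    }
    outputs
}
```

In this sketch, a three-step eval resolves symbols once rather than three times, and a second eval on the same state resolves them once more because they are cleared up front.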
Correctness — proven behavior-preserving
A/B in one binary (new = default, old = TRACT_DISABLE_SCAN_ITER_REUSE=1) gives
bit-identical results across GRU / LSTM / vanilla-RNN (× 2 sizes) and DFN3
df_dec, with the same outlier profile vs ORT every time (e.g. GRU 382/382, LSTM 30/30,
RNN 9975/9975, df_dec 8/96000). Full tract-core lib suite (240 tests) passes.

Performance (Apple M1 Pro, single-thread, median; vs the gated old path)
Native:
WASM (wasmtime, +simd128,+relaxed-simd):
parity where the body matmul dominates (gru_256, df_dec). It's also more
valuable as the body kernels get faster: once the AMX dispatch from #2220 (linalg/arm64/apple_amx: shape-aware AMX dispatch)
shrank the body matmul, gru_256 native went from 0% to −1.1%.
faster (native 1-thread tract gru_256 1.89 ms vs ORT 2.34; vanilla RNN ~3×
faster). This PR shaves a bit more off an already-decent position.
Measurement note
The A/B above isolates the run-path change: both arms share the cleaned-up
iter_inputs build, so the numbers slightly understate the full change vs pristine
main (which also did an extra Vec<Option> alloc + flatten per iter). The pristine-vs-new delta is marginally larger.
Questions
TRACT_DISABLE_SCAN_ITER_REUSE gate (safety/A-B), or drop it for a single clean path?
resolved_symbols across body iterations? Body shapes are constant within a
Scan eval; I clear them up front each eval since the body state persists across outer calls.
Test plan
- tract-core lib suite (240 tests)
- cargo fmt --check + cargo clippy clean (changed files)

🤖 Generated with Claude Code