core/onnx: fused OptGru op (draft — path-C alternative to #2257)#2260
Draft
czoli1976 wants to merge 1 commit into
Draft
core/onnx: fused OptGru op (draft — path-C alternative to #2257)#2260czoli1976 wants to merge 1 commit into
czoli1976 wants to merge 1 commit into
Conversation
Lowers ONNX GRU to a single OptGru op owning its recurrence (no Scan, no per-timestep sub-plan re-entry): hoisted X*W^T (mmm) + per-step combined recurrent mmv R*h_prev + fused sigmoid/tanh + recurrence, in one allocation-free loop. Stateful, caches packed W/R. Gated TRACT_ENABLE_OPT_GRU=1; falls back to Scan otherwise. Bit-exact vs Scan (differential tests + scalar-ref unit tests). Wall-clock PARITY with Scan on large GRUs (df_dec) -- filed as a draft for direction (path C, the alternative to sonos#2257's path B), not as a perf win. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary — draft for direction, not for merge
This is the fused
OptGruop — the "path C" alternative to #2257's lighter-Scan ("path B"). Filing it as a draft so you can see the shape, per the discussion on #2257. I'm not advocating for it (see perf below); it's here for your call on whether a fused RNN op is a direction you'd want, vs the generic-Scan tweak.What it is
ONNX
GRUlowers (via aWireBody::try_wire_fusedhook, gatedTRACT_ENABLE_OPT_GRU=1) to a singleOptGruop that owns its recurrence — noScan, no per-timestep sub-plan re-entry:X·Wᵀ(linalgmmm),mmvR·h_prev(N=1),(1−z)·h̃ + z·hrecurrence,all in one allocation-free Rust loop. Stateful (
OptGruState) — caches the packed constantW/Racross calls.The fast path covers the common config (f32, batch=1,
linear_before_reset=1, concrete dims, no peepholes/seq_lens); anything else falls back toScan, so it's purely additive.Correctness
onnx/tests/opt_gru_vs_scan.rs): bit-exact vsScan(single, symbolic-seq, df_dec dims 256/100, two stacked).opt_gru.rs.Scan.Performance — honest: PARITY on large GRUs, not a win
On DFN3 df_dec (hidden=256, seq=100), wall-clock OptGru is parity with
Scan(native and WASM). It removes theScanper-iteration scaffolding but adds ~equivalent own per-iter overhead (themmvcall + the activation kernel calls), so it nets out — which is why I haven't pushed it as a perf PR.The one regime where a fused op could win is small/overhead-dominated RNNs (small hidden, where the per-step
Scanmachinery alone is a large fraction — e.g. a 128-hidden GRU is ~1.4× ORT on native purely from per-step overhead). Large GRUs (DFN3) are parity either way.Caveats (why it's draft)
OptGruisn't pulse-aware; the lowering would need to gate on non-pulse (orOptGruneeds a per-step state API). Currently moot — opt-in, off by default.batch_first=TrueY-wiring is implemented but untested (no canary).Relationship to #2257
#2257 (path B) is the small, generic Scan-loop optimization (helps every
Scan/Loopuser a little, neutral on single-iteration streaming decoders). This (path C) is the bigger, RNN-specific alternative — potentially better on small GRUs, parity on large ones, GRU-only. Mutually informative; pick whichever direction you prefer (or neither).🤖 Generated with Claude Code