core/onnx: fused OptGru op (draft — path-C alternative to #2257) #2260

Draft
czoli1976 wants to merge 1 commit into sonos:main from czoli1976:feature/opt-gru-fused

@czoli1976
Contributor

Summary — draft for direction, not for merge

This is the fused OptGru op — the "path C" alternative to #2257's lighter-Scan ("path B"). Filing it as a draft so you can see the shape, per the discussion on #2257. I'm not advocating for it (see perf below); it's here for your call on whether a fused RNN op is a direction you'd want, vs the generic-Scan tweak.

What it is

ONNX GRU lowers (via a WireBody::try_wire_fused hook, gated behind TRACT_ENABLE_OPT_GRU=1) to a single OptGru op that owns its recurrence, with no Scan and no per-timestep sub-plan re-entry:

  • one hoisted input projection X·Wᵀ (linalg mmm),
  • per timestep one combined recurrent mmv R·h_prev (N=1),
  • fused sigmoid/tanh (the linalg SIMD kernels),
  • the (1−z)·h̃ + z·h recurrence,

all in one allocation-free Rust loop. The op is stateful (OptGruState): it caches the packed constant W/R across calls.
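For readers who want the loop shape spelled out, here is a minimal scalar sketch of the steps listed above. It is not the code in opt_gru.rs: the function names (`gru_forward`, `matvec`, `sigmoid`), the row-major layouts, and the plain-Rust loops standing in for the linalg mmm/mmv and SIMD activation kernels are all assumptions made for illustration. It assumes batch=1, linear_before_reset=1, and the ONNX gate order [z, r, h].

```rust
/// Logistic sigmoid (scalar stand-in for the SIMD kernel).
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Row-major (rows x cols) matrix times vector, written into `out`.
fn matvec(m: &[f32], rows: usize, cols: usize, v: &[f32], out: &mut [f32]) {
    for i in 0..rows {
        let row = &m[i * cols..(i + 1) * cols];
        out[i] = row.iter().zip(v).map(|(a, b)| a * b).sum::<f32>();
    }
}

/// Hypothetical scalar GRU forward pass.
/// x: seq x input, w: 3H x input, r: 3H x H, wb/rb: 3H biases, h0: H.
/// Returns every hidden state, laid out as seq x H.
fn gru_forward(
    x: &[f32],
    seq: usize,
    input: usize,
    hidden: usize,
    w: &[f32],
    r: &[f32],
    wb: &[f32],
    rb: &[f32],
    h0: &[f32],
) -> Vec<f32> {
    let g = 3 * hidden;

    // Hoisted input projection: X·Wᵀ for all timesteps, computed before the loop.
    let mut xw = vec![0.0f32; seq * g];
    for t in 0..seq {
        let x_t = &x[t * input..(t + 1) * input];
        matvec(w, g, input, x_t, &mut xw[t * g..(t + 1) * g]);
    }

    // Buffers are allocated once; the timestep loop itself does not allocate.
    let mut h = h0.to_vec();
    let mut rh = vec![0.0f32; g];
    let mut y = vec![0.0f32; seq * hidden];

    for t in 0..seq {
        // One combined recurrent product R·h_prev covering all three gates.
        matvec(r, g, hidden, &h, &mut rh);
        let xw_t = &xw[t * g..(t + 1) * g];
        for i in 0..hidden {
            let (zi, ri, hi) = (i, hidden + i, 2 * hidden + i);
            let z = sigmoid(xw_t[zi] + wb[zi] + rh[zi] + rb[zi]);
            let reset = sigmoid(xw_t[ri] + wb[ri] + rh[ri] + rb[ri]);
            // linear_before_reset=1: the reset gate scales (R_h·h_prev + Rb_h).
            let cand = (xw_t[hi] + wb[hi] + reset * (rh[hi] + rb[hi])).tanh();
            // The (1−z)·h̃ + z·h recurrence; h[i] still holds h_prev here.
            h[i] = (1.0 - z) * cand + z * h[i];
        }
        y[t * hidden..(t + 1) * hidden].copy_from_slice(&h);
    }
    y
}
```

Updating `h` in place is safe in this sketch because `rh` is computed from the full previous state before any element of `h` is overwritten.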

The fast path covers the common config (f32, batch=1, linear_before_reset=1, concrete dims, no peepholes/seq_lens); anything else falls back to Scan, so it's purely additive.
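The gating described above could be sketched roughly as follows. The `GruConfig` struct, its field names, and both helper functions are hypothetical and do not reflect the actual types in this PR; treating the flag as enabled only when it equals exactly "1" is also an assumption.

```rust
use std::env;

/// Hypothetical summary of the properties the fast path depends on.
#[derive(Debug, Clone, PartialEq)]
struct GruConfig {
    is_f32: bool,
    batch_is_one: bool,
    linear_before_reset: bool,
    dims_concrete: bool,
    has_peepholes: bool,
    has_seq_lens: bool,
}

/// True when the config matches the "common config" listed in the text.
fn fast_path_eligible(c: &GruConfig) -> bool {
    c.is_f32
        && c.batch_is_one
        && c.linear_before_reset
        && c.dims_concrete
        && !c.has_peepholes
        && !c.has_seq_lens
}

/// Assumption: the opt-in flag counts as set only when its value is "1".
fn opt_gru_enabled(flag: Option<&str>) -> bool {
    flag == Some("1")
}

/// Fuse only when opted in AND eligible; every other case keeps the Scan lowering.
fn should_fuse(c: &GruConfig) -> bool {
    let flag = env::var("TRACT_ENABLE_OPT_GRU").ok();
    opt_gru_enabled(flag.as_deref()) && fast_path_eligible(c)
}
```

Keeping the decision in a single predicate is what makes the change additive: when it returns false, nothing about the existing Scan path is touched.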

Correctness

  • Differential tests (onnx/tests/opt_gru_vs_scan.rs): bit-exact vs Scan (single, symbolic-seq, df_dec dims 256/100, two stacked).
  • Scalar-reference unit tests in opt_gru.rs.
  • df_dec real-input reference parity identical to Scan.
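To make "bit-exact" concrete: a comparison along these lines checks raw f32 bit patterns rather than using `==` or a tolerance. The helper name `first_bit_mismatch` is hypothetical and is not taken from onnx/tests/opt_gru_vs_scan.rs.

```rust
/// Returns the index of the first element whose bit pattern differs,
/// or None when both slices are bit-identical.
/// A length mismatch is reported at the end of the shorter slice.
fn first_bit_mismatch(a: &[f32], b: &[f32]) -> Option<usize> {
    if a.len() != b.len() {
        return Some(a.len().min(b.len()));
    }
    a.iter()
        .zip(b)
        .position(|(x, y)| x.to_bits() != y.to_bits())
}
```

Comparing bits is stricter than `==`: it distinguishes 0.0 from -0.0 and treats two NaNs with the same payload as equal, which is the behaviour a differential test between two implementations of the same arithmetic usually wants.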

Performance — honest: PARITY on large GRUs, not a win

On DFN3 df_dec (hidden=256, seq=100), OptGru's wall-clock time is at parity with Scan (native and WASM). It removes the Scan per-iteration scaffolding but adds roughly equivalent per-iteration overhead of its own (the mmv call plus the activation kernel calls), so it nets out, which is why I haven't pushed it as a perf PR.

The one regime where a fused op could win is small/overhead-dominated RNNs (small hidden, where the per-step Scan machinery alone is a large fraction of the runtime; e.g. a 128-hidden GRU is ~1.4× ORT on native purely from per-step overhead). Large GRUs (DFN3) are parity either way.
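A toy cost model shows why hidden size matters here. It is only a sketch: the fixed per-step overhead c and the time per multiply-accumulate k are hypothetical symbols, not measured values, and the O(H) activation work is ignored. The combined recurrent matrix R is 3H × H, so the recurrent mmv costs 3H² multiply-accumulates per step:

```latex
\text{MACs per step} = 3H \cdot H = 3H^2,
\qquad
H = 256:\ 3 \cdot 256^2 = 196{,}608,
\qquad
H = 128:\ 3 \cdot 128^2 = 49{,}152

f(H) = \frac{c}{c + 3kH^2}
```

Here f(H) is the fraction of a step spent on fixed overhead. Halving H quarters the arithmetic term while c stays the same, so f grows quickly as the hidden size shrinks.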

Caveats (why it's draft)

  • Pulse: OptGru isn't pulse-aware; the lowering would need to gate on non-pulse (or OptGru needs a per-step state API). Currently moot — opt-in, off by default.
  • batch_first=True Y-wiring is implemented but untested (no canary).
  • The real question is whether you'd want a GRU-specific op at all vs the generic-Scan path.

Relationship to #2257

#2257 (path B) is the small, generic Scan-loop optimization (helps every Scan/Loop user a little, neutral on single-iteration streaming decoders). This (path C) is the bigger, RNN-specific alternative — potentially better on small GRUs, parity on large ones, GRU-only. Mutually informative; pick whichever direction you prefer (or neither).

🤖 Generated with Claude Code

Lowers ONNX GRU to a single OptGru op owning its recurrence (no Scan, no
per-timestep sub-plan re-entry): hoisted X*W^T (mmm) + per-step combined
recurrent mmv R*h_prev + fused sigmoid/tanh + recurrence, in one
allocation-free loop. Stateful, caches packed W/R. Gated
TRACT_ENABLE_OPT_GRU=1; falls back to Scan otherwise.

Bit-exact vs Scan (differential tests + scalar-ref unit tests). Wall-clock
PARITY with Scan on large GRUs (df_dec) -- filed as a draft for direction
(path C, the alternative to sonos#2257's path B), not as a perf win.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>