Skip to content

Add shape subgraph folding and dynamic output dim resolution for WebNN EP#21

Open
arisha07 wants to merge 3 commits into
Honry:dynamic-dim-pocfrom
arisha07:nogqa-shape-folder
Open

Add shape subgraph folding and dynamic output dim resolution for WebNN EP#21
arisha07 wants to merge 3 commits into
Honry:dynamic-dim-pocfrom
arisha07:nogqa-shape-folder

Conversation

@arisha07
Copy link
Copy Markdown

@arisha07 arisha07 commented May 29, 2026

Description

  • Add ShapeSubgraphFolder to pre-evaluate shape subgraphs (Where/Equal/Range/ConstantOfShape chains) so Reshape/Expand see constant shapes at build time

  • Integrate folded shapes into Reshape and Expand op builders

  • Support additive dim_param expressions (e.g. past_sequence_length + sequence_length)

  • Add heuristic fallback for unresolved output dimensions from runtime inputs

  • Fix QDQ per-axis reshape to handle all axes (not just last axis)

  • Claim folded nodes in GetCapability to keep them in WebNN partition

Motivation and Context

When running LLMs (e.g., Llama 3.2 1B exported from HuggingFace optimum-nncf flow) with dynamic shapes through the WebNN execution provider, a large portion of the model graph consists of "shape-computing" subgraphs : chains of Shape → Gather → Concat → Cast → Unsqueeze nodes that produce shape tensors consumed by Reshape, Expand, and ConstantOfShape ops.

Problem 1: Graph partitioning overhead. These shape ops (often using int64 or bool) fall back to the CPU execution provider. This creates dozens of partition boundaries per layer, forcing expensive CPU↔GPU tensor transfers on every inference. This dominates latency at decode time.

Problem 2: Unresolved output dimensions. Models exported from HuggingFace Optimum contain symbolic dim_param expressions like past_sequence_length + sequence_length that the EP couldn't evaluate, causing hard failures at runtime.

Problem 3: QDQ broadcasting bug. Per-axis quantize/dequantize only reshaped scale/zero_point tensors when the axis was the last dim, breaking int4-quantized models on non-last axes.

What does this PR do?

  1. ShapeSubgraphFolder (shape_subgraph_folder.cc/h): A compile-time pre-pass that traces backward from shape-consuming input slots, evaluates the subgraph using known constant initializers + freeDimensionOverrides, and caches the result. Folded nodes are skipped during AddOperations(). Runs once at session creation — zero per-inference cost.

  2. Op builder integration (expand_op_builder.cc, reshape_op_builder.cc): When a shape input has been folded, emit the WebNN op with the constant shape directly instead of requiring a runtime shape tensor or dynamic op variant.

  3. Additive dim_param resolution (webnn_execution_provider.cc): Parses expressions like a + b in symbolic dim_params and resolves them from runtime input shapes.

  4. Heuristic fallback (webnn_execution_provider.cc): For intermediate outputs with data-dependent shapes that can't be symbolically resolved, infers dimensions from the largest matching input dim instead of hard-failing.

  5. QDQ per-axis fix (qdq_op_builder.cc): Removes the axis != input_rank - 1 guard so scale/zero_point reshaping works for all quantization axes.

  6. GetCapability integration (webnn_execution_provider.cc): Claims folded nodes as "supported" so they stay in the WebNN partition (even though they won't be executed — they'll be skipped in AddOperations()).

Performance impact (More testing to follow)

On Llama 3.2 1B (non-GQA, int4) with WebNN GPU (OpenVINO EP):

Before: ~5 tok/s
After: enables 25+ tok/s (when combined with Chromium-side dynamic shape dispatch support)

@arisha07 arisha07 marked this pull request as draft May 29, 2026 23:47
…N EP

- Add ShapeSubgraphFolder to pre-evaluate shape subgraphs (Where/Equal/Range/ConstantOfShape chains) so Reshape/Expand see constant shapes at build time

- Integrate folded shapes into Reshape and Expand op builders

- Support additive dim_param expressions (e.g. past_sequence_length + sequence_length)

- Add heuristic fallback for unresolved output dimensions from runtime inputs

- Fix QDQ per-axis reshape to handle all axes (not just last axis)

- Claim folded nodes in GetCapability to keep them in WebNN partition
@arisha07 arisha07 force-pushed the nogqa-shape-folder branch from 77791de to eec8bb6 Compare May 30, 2026 00:02
@arisha07 arisha07 marked this pull request as ready for review May 30, 2026 00:04
@arisha07 arisha07 marked this pull request as draft May 30, 2026 00:05
@arisha07 arisha07 marked this pull request as ready for review June 2, 2026 17:52
@Honry
Copy link
Copy Markdown
Owner

Honry commented Jun 3, 2026

Thanks @arisha07,

Problem 1: Graph partitioning overhead. These shape ops (often using int64 or bool) fall back to the CPU execution provider. This creates dozens of partition boundaries per layer, forcing expensive CPU↔GPU tensor transfers on every inference. This dominates latency at decode time.

These shape ops should support int64 or bool (webnn uses uint8 instead) data type for ort backend, there should be no fallback I believe.

Problem 2: Unresolved output dimensions. Models exported from HuggingFace Optimum contain symbolic dim_param expressions like past_sequence_length + sequence_length that the EP couldn't evaluate, causing hard failures at runtime.

This is a good one to reinforce current output expressions, but please note that we are going to propose a new API computeShapes() in WebNN which would help framework to calculate the WebNN output shape before dispatch(), it is a more general solution and able to resolve all unknown output dims.

Add ShapeSubgraphFolder to pre-evaluate shape subgraphs (Where/Equal/Range/ConstantOfShape chains) so Reshape/Expand see constant shapes at build time

Looks like all the folding are computed for static shape, what if the shape is dynamic?

@arisha07
Copy link
Copy Markdown
Author

arisha07 commented Jun 3, 2026

Thanks for the review @Honry! Let me address each point:

  1. You're right, for our current test models (Qwen2.5 0.5B and Llama 3.2 1B non-GQA, going to test more models), all ops are marked supported and form a single partition. The ShapeSubgraphFolder reports "0 folded" because ORT's standard constant folding (enabled by freeDimensionOverrides) already resolves the shape subgraphs before the folder runs. The folder was designed as a safety net for models from other export pipelines that may produce unresolved shape chains.

  2. On computeShapes() being a more general solution:
    Agreed that computeShapes() would be the proper long-term solution for unknown output dims. This PR's additive dim_param resolution and heuristic fallback are interim workarounds for models that exist today. Happy to gate them behind a flag or remove them once computeShapes() lands in the spec/browsers.

  3. On "what if the shape is dynamic?":

    The folder uses FreeDimensionOverrides (passed via session options) to resolve symbolic dim_param names. For LLM KV-cache models, the user sets overrides like past_sequence_length=1024, sequence_length=1 which makes the shape subgraphs fully evaluable at session creation.

    When a dimension truly can't be resolved (no override, no constant), GetResolvedShape() returns false → the subgraph is not folded and remains in the graph as-is. So it's a best-effort optimization that gracefully falls back to the existing behavior for genuinely dynamic shapes.

@Honry
Copy link
Copy Markdown
Owner

Honry commented Jun 4, 2026

other export pipelines that may produce unresolved shape chains.

For static ShapeSubgraphFolder, ORT has already done the same work when users set the freeDimensionOverrides, it's a general solution for all onnx models with dynamic input shape. Do you have concret example that have unresolved shape chains after applying freeDimensionOverrides? And that maybe a ORT bug, we don't need to do duplicated shape folding in WebNN EP.

Agreed that computeShapes() would be the proper long-term solution for unknown output dims. This PR's additive dim_param resolution and heuristic fallback are interim workarounds for models that exist today. Happy to gate them behind a flag or remove them once computeShapes() lands in the spec/browsers.

All these dynamic shape proposed APIs are still under discussion and haven't been landed in the spec.
@miaobin is implementing the computeShapes() in his personal Chromium repo and I will apply it to the WebNN EP after that, and yes, maybe a flag may help to avoid the computeShapes() not be accepted by the WG.

@arisha07
Copy link
Copy Markdown
Author

arisha07 commented Jun 4, 2026

After retesting, I do not currently have a concrete model where unresolved shape chains remain after freeDimensionOverrides are applied. For the models I validated (Qwen2.5 0.5B non-GQA and Llama 3.2 1B non-GQA), ORT standard constant folding already resolves them and ShapeSubgraphFolder reports 0 folded.

So I agree this may duplicate existing ORT behavior for current pipelines. I will trim this PR to avoid redundancy.

Comment thread onnxruntime/core/providers/webnn/webnn_execution_provider.cc Outdated
Comment thread onnxruntime/core/providers/webnn/builders/model_builder.h Outdated
…de and enable_additive_dim_param option

- Remove IsFoldedShape(), GetFoldedShape(), IsFoldedNode() declarations and
  implementations from model_builder.h/cc (dead code from removed ShapeSubgraphFolder)
- Remove IsFoldedShape/IsFoldedNode call sites in expand_op_builder.cc and
  reshape_op_builder.cc
- Remove enable_additive_dim_param constructor param, member variable, and
  option parsing from webnn_execution_provider.h/cc and webnn_provider_factory.cc
- Remove enableAdditiveDimParam session option mapping from session-options.ts
- Keep additive dim_param fallback logic guarded by runtime computeShapes check
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants