Add shape subgraph folding and dynamic output dim resolution for WebNN EP#21
Add shape subgraph folding and dynamic output dim resolution for WebNN EP#21arisha07 wants to merge 3 commits into
Conversation
…N EP - Add ShapeSubgraphFolder to pre-evaluate shape subgraphs (Where/Equal/Range/ConstantOfShape chains) so Reshape/Expand see constant shapes at build time - Integrate folded shapes into Reshape and Expand op builders - Support additive dim_param expressions (e.g. past_sequence_length + sequence_length) - Add heuristic fallback for unresolved output dimensions from runtime inputs - Fix QDQ per-axis reshape to handle all axes (not just last axis) - Claim folded nodes in GetCapability to keep them in WebNN partition
77791de to
eec8bb6
Compare
|
Thanks @arisha07,
These shape ops should support int64 or bool (webnn uses uint8 instead) data type for ort backend, there should be no fallback I believe.
This is a good one to reinforce current output expressions, but please note that we are going to propose a new API
Looks like all the folding are computed for static shape, what if the shape is dynamic? |
|
Thanks for the review @Honry! Let me address each point:
|
For static ShapeSubgraphFolder, ORT has already done the same work when users set the
All these dynamic shape proposed APIs are still under discussion and haven't been landed in the spec. |
|
After retesting, I do not currently have a concrete model where unresolved shape chains remain after freeDimensionOverrides are applied. For the models I validated (Qwen2.5 0.5B non-GQA and Llama 3.2 1B non-GQA), ORT standard constant folding already resolves them and ShapeSubgraphFolder reports 0 folded. So I agree this may duplicate existing ORT behavior for current pipelines. I will trim this PR to avoid redundancy. |
…de and enable_additive_dim_param option - Remove IsFoldedShape(), GetFoldedShape(), IsFoldedNode() declarations and implementations from model_builder.h/cc (dead code from removed ShapeSubgraphFolder) - Remove IsFoldedShape/IsFoldedNode call sites in expand_op_builder.cc and reshape_op_builder.cc - Remove enable_additive_dim_param constructor param, member variable, and option parsing from webnn_execution_provider.h/cc and webnn_provider_factory.cc - Remove enableAdditiveDimParam session option mapping from session-options.ts - Keep additive dim_param fallback logic guarded by runtime computeShapes check
Description
Add ShapeSubgraphFolder to pre-evaluate shape subgraphs (Where/Equal/Range/ConstantOfShape chains) so Reshape/Expand see constant shapes at build time
Integrate folded shapes into Reshape and Expand op builders
Support additive dim_param expressions (e.g. past_sequence_length + sequence_length)
Add heuristic fallback for unresolved output dimensions from runtime inputs
Fix QDQ per-axis reshape to handle all axes (not just last axis)
Claim folded nodes in GetCapability to keep them in WebNN partition
Motivation and Context
When running LLMs (e.g., Llama 3.2 1B exported from HuggingFace optimum-nncf flow) with dynamic shapes through the WebNN execution provider, a large portion of the model graph consists of "shape-computing" subgraphs : chains of Shape → Gather → Concat → Cast → Unsqueeze nodes that produce shape tensors consumed by Reshape, Expand, and ConstantOfShape ops.
Problem 1: Graph partitioning overhead. These shape ops (often using int64 or bool) fall back to the CPU execution provider. This creates dozens of partition boundaries per layer, forcing expensive CPU↔GPU tensor transfers on every inference. This dominates latency at decode time.
Problem 2: Unresolved output dimensions. Models exported from HuggingFace Optimum contain symbolic dim_param expressions like past_sequence_length + sequence_length that the EP couldn't evaluate, causing hard failures at runtime.
Problem 3: QDQ broadcasting bug. Per-axis quantize/dequantize only reshaped scale/zero_point tensors when the axis was the last dim, breaking int4-quantized models on non-last axes.
What does this PR do?
ShapeSubgraphFolder (shape_subgraph_folder.cc/h): A compile-time pre-pass that traces backward from shape-consuming input slots, evaluates the subgraph using known constant initializers + freeDimensionOverrides, and caches the result. Folded nodes are skipped during AddOperations(). Runs once at session creation — zero per-inference cost.
Op builder integration (expand_op_builder.cc, reshape_op_builder.cc): When a shape input has been folded, emit the WebNN op with the constant shape directly instead of requiring a runtime shape tensor or dynamic op variant.
Additive dim_param resolution (webnn_execution_provider.cc): Parses expressions like a + b in symbolic dim_params and resolves them from runtime input shapes.
Heuristic fallback (webnn_execution_provider.cc): For intermediate outputs with data-dependent shapes that can't be symbolically resolved, infers dimensions from the largest matching input dim instead of hard-failing.
QDQ per-axis fix (qdq_op_builder.cc): Removes the axis != input_rank - 1 guard so scale/zero_point reshaping works for all quantization axes.
GetCapability integration (webnn_execution_provider.cc): Claims folded nodes as "supported" so they stay in the WebNN partition (even though they won't be executed — they'll be skipped in AddOperations()).
Performance impact (More testing to follow)
On Llama 3.2 1B (non-GQA, int4) with WebNN GPU (OpenVINO EP):
Before: ~5 tok/s
After: enables 25+ tok/s (when combined with Chromium-side dynamic shape dispatch support)