Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ExpandOpBuilder : public BaseOpBuilder {
void ExpandOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
const auto& input_defs = node.InputDefs();
const auto& shape_name = input_defs[1]->Name();
// Only skip the shape input when it is a constant initializer AND the input has static shape.
// Skip the shape input when it is a constant initializer AND the input has static shape.
// When the input has dynamic shape, we need the shape operand for dynamicExpand even if it's constant.
if (model_builder.GetGraphViewer().GetConstantInitializer(shape_name) &&
!HasDynamicShape(*input_defs[0])) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}
}

// For per-axis quantization/dequantization and axis is not equal to input_rank - 1,
// we need to reshape the scale and zero_point tensors to make them broadcastable with the input tensor.
// For per-axis quantization/dequantization, the scale is 1-D.
// WebNN requires the scale and zero_point tensors to have the same rank as the input tensor.
// We need to reshape them to make them broadcastable with the input tensor.
if (scale_shape.size() == 1 && input_rank > 1 &&
block_size == 0 && axis != static_cast<int32_t>(input_rank - 1)) {
block_size == 0) {
// Insert ones before and after the axis dimension for broadcasting of scale tensor.
// Use emscripten::val::array() to support dynamic axis dim via input["shape"][axis].
emscripten::val target_shape = emscripten::val::array();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ Status ModelBuilder::AddOperations() {
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));

if (const auto* op_builder = GetOpBuilder(*node)) {
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, logger_));
} else {
Expand Down
89 changes: 72 additions & 17 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view

const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);

std::unordered_set<const Node*> supported_nodes_with_folded = supported_nodes;

const auto gen_metadef_name = [&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(WEBNN, "_", model_hash, "_", metadef_id);
};

auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes_with_folded, {},
gen_metadef_name, WEBNN, kWebNNExecutionProvider,
&node_unit_map, /*drop_constant_initializers*/ true);

Expand Down Expand Up @@ -286,7 +288,9 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
ORT_UNUSED_PARAMETER(state);
};

compute_info.compute_func = [dim_param_to_input_dim, fixed_dim_param_values, fused_output_shapes, output_dim_params](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
// Use additive dim_param fallback when computeShapes() API is not yet available.
const bool use_additive_dim_fallback = wnn_context_["computeShapes"].isUndefined();
compute_info.compute_func = [dim_param_to_input_dim, fixed_dim_param_values, fused_output_shapes, output_dim_params, use_additive_dim_fallback](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
Ort::KernelContext ctx(context);

const size_t num_inputs = ctx.GetInputCount();
Expand Down Expand Up @@ -429,6 +433,40 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
}
}
}

// Try to parse additive expressions like "dim_a + dim_b"
// (e.g., "past_sequence_length + sequence_length").
if (use_additive_dim_fallback && output_shape[dim_idx] == webnn::kDynamicDim) {
auto plus_pos = dim_param.find('+');
if (plus_pos != std::string::npos) {
const std::string left = utils::TrimString(std::string_view(dim_param).substr(0, plus_pos));
const std::string right = utils::TrimString(std::string_view(dim_param).substr(plus_pos + 1));

// Resolve each operand (from runtime inputs or fixed bounds).
auto resolve_operand = [&](const std::string& operand) -> int64_t {
auto it = dim_param_to_input_dim.find(operand);
if (it != dim_param_to_input_dim.end()) {
const size_t src_idx = it->second.first;
const size_t src_dim = it->second.second;
if (src_idx < runtime_input_shapes.size() &&
src_dim < runtime_input_shapes[src_idx].size()) {
return runtime_input_shapes[src_idx][src_dim];
}
}
auto fixed_it = fixed_dim_param_values.find(operand);
if (fixed_it != fixed_dim_param_values.end()) {
return fixed_it->second;
}
return -1; // unresolved
};

int64_t left_val = resolve_operand(left);
int64_t right_val = resolve_operand(right);
if (left_val >= 0 && right_val >= 0) {
output_shape[dim_idx] = left_val + right_val;
}
}
}
}
}

Expand Down Expand Up @@ -458,28 +496,45 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
}
}

// Hard fail if dynamic dimensions remain unresolved.
// TODO: When WebNN supports the dispatch() API that returns output MLTensors with
// shapes inferred from actual input tensors, we can query the real output shapes
// at runtime instead of relying on symbolic dim_param matching. This would eliminate
// the need for the simplify_dynamic_shapes.py preprocessing step and handle all
// dynamic shape cases natively (including data-dependent output shapes).
// If dynamic dimensions remain unresolved, try to infer from the max bounds
// of known dims (e.g., use batch_size=1 for dim 0, sequence_length for others).
// This handles intermediate outputs (like Expand's causal mask) whose shapes are
// data-dependent and not annotated with a resolvable dim_param.
for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) {
if (output_shape[dim_idx] == webnn::kDynamicDim) {
std::string unresolved_dim_param;
if (output_idx < output_dim_params.size() && dim_idx < output_dim_params[output_idx].size()) {
unresolved_dim_param = output_dim_params[output_idx][dim_idx];
}

LOGS_DEFAULT(ERROR) << "[WebNN] Failed to resolve dynamic output dimension for output ["
<< output_name << "] at dim index [" << dim_idx
<< "], dim_param: [" << unresolved_dim_param
<< "]. Please ensure this dim_param can be inferred from graph inputs"
<< " or provide pre-allocated output tensors via session.run(feeds, fetches).";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"[WebNN] Failed to resolve dynamic output dimension for output: ", output_name,
" at dim index: ", dim_idx,
". dim_param: ", unresolved_dim_param);
// Instead of hard-failing, use a heuristic:
// - Try to match the unresolved dim to any input dim at the same index.
// - As a last resort, copy from the largest input shape at this dim index.
int64_t inferred = 0;
for (size_t inp_idx = 0; inp_idx < runtime_input_shapes.size() && inferred == 0; ++inp_idx) {
if (dim_idx < runtime_input_shapes[inp_idx].size()) {
int64_t candidate = runtime_input_shapes[inp_idx][dim_idx];
if (candidate > inferred) {
inferred = candidate;
}
}
}

if (inferred > 0) {
LOGS_DEFAULT(WARNING) << "[WebNN] Unresolved output dim for [" << output_name
<< "] at index " << dim_idx << " (dim_param: [" << unresolved_dim_param
<< "]). Inferred from runtime inputs: " << inferred;
output_shape[dim_idx] = inferred;
} else {
LOGS_DEFAULT(ERROR) << "[WebNN] Failed to resolve dynamic output dimension for output ["
<< output_name << "] at dim index [" << dim_idx
<< "], dim_param: [" << unresolved_dim_param
<< "]. No input dims available for inference.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"[WebNN] Failed to resolve dynamic output dimension for output: ", output_name,
" at dim index: ", dim_idx,
". dim_param: ", unresolved_dim_param);
}
}
}

Expand Down