diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index f1391ba1e3528..8217a07448266 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -212,7 +212,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(is_apple, is_apple_), WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), @@ -221,7 +220,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_), - WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_)); + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), + WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -486,6 +486,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; bool is_apple = context.AdapterInfo().vendor == std::string_view{"apple"}; + bool has_subgroups = context.HasFeature(wgpu::FeatureName::Subgroups); bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool has_head_sink = head_sink != nullptr; @@ -498,6 +499,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.is_unidirectional_, is_nvidia, is_apple, + has_subgroups, q_BNSH, use_seqlen_k, has_head_sink}; @@ -532,7 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(prefill_tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step()) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step()) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, @@ -584,7 +586,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && - context.HasFeature(wgpu::FeatureName::Subgroups) && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 27fa56e333874..e75b6378f67c6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -77,6 +77,7 @@ class FlashAttentionProgram final : public Program { bool is_unidirectional, bool is_nvidia, bool is_apple, + bool has_subgroups, bool q_BNSH, bool use_seqlen_k = false, bool has_head_sink = false) @@ -88,12 +89,12 @@ class FlashAttentionProgram final : public Program { qkv_num_heads_(qkv_num_heads), is_unidirectional_(is_unidirectional), is_nvidia_(is_nvidia), - is_apple_(is_apple), + use_shm_path_(is_apple || is_nvidia || !has_subgroups), q_BNSH_(q_BNSH), use_seqlen_k_(use_seqlen_k), has_head_sink_(has_head_sink) { - if (is_apple || is_nvidia) { - // On Apple and NVIDIA, use an optimized loop-based path with dynamic max_k_step. + if (use_shm_path_) { + // Use shared-memory loop-based path with dynamic max_k_step. // Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step const int element_size = is_fp16 ? 2 : 4; constexpr int kMinWorkgroupStorageBudgetBytes = 16384; @@ -130,7 +131,7 @@ class FlashAttentionProgram final : public Program { int qkv_num_heads_; bool is_unidirectional_; bool is_nvidia_; - bool is_apple_; + bool use_shm_path_; bool q_BNSH_; bool use_seqlen_k_; bool has_head_sink_; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index db41ac12ce268..6b620043413e3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -1,7 +1,6 @@ #param has_attention_bias #param has_head_sink -#param is_apple #param is_fp16 #param is_qualcomm #param is_unidirectional @@ -10,6 +9,7 @@ #param qkv_head_size #param qkv_num_heads #param use_seqlen_k +#param use_shm_path #param max_k_step_param const head_size : u32 = qkv_head_size; @@ -61,7 +61,7 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ } } -#if is_apple +#if use_shm_path var qk_scores : array; @@ -240,7 +240,7 @@ $MAIN { let seq_causal_length = total_sequence_length; #endif -#if is_apple +#if use_shm_path for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index fd4d266dc51f0..bfa25a5cb2e9a 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include +#include +#include +#include + +#include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" #include "core/graph/model_editor_api_types.h" +#include "core/graph/model_helpers.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -129,6 +136,8 @@ Model::Model(const std::string& graph_name, func_ptr); } + ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_)); + model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -261,6 +270,8 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name(), func.overload()), &func); } + ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_)); + model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index c86aac44806bd..9a877ac6bba95 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -343,7 +343,7 @@ class Model { // map from function id to pointer of model local function proto // FunctionProto is hosted in ModelProto. // this map will be used for the local functions' schema's type/shape inference. - // This container is used by ONNX code and must be an std::unordered_map. + // Must be std::unordered_map to match ONNX_NAMESPACE::shape_inference::ModelLocalFunctionsMap. std::unordered_map model_local_functions_; // this is the map from function id to the local function template. // this map will be used by graph to instantiate the function body. diff --git a/onnxruntime/core/graph/model_helpers.cc b/onnxruntime/core/graph/model_helpers.cc new file mode 100644 index 0000000000000..c3214d488ff0d --- /dev/null +++ b/onnxruntime/core/graph/model_helpers.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/graph/model_helpers.h" + +#include +#include +#include +#include +#include + +#include "core/graph/function_utils.h" +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { + +namespace { + +// Iterative collection of local function calls from a sequence of nodes, +// including nodes inside nested subgraph attributes. Avoids recursion to +// prevent stack overflow from maliciously deep subgraph nesting. +template +void CollectLocalFunctionCalls( + const NodeRange& nodes, + const std::unordered_map& model_local_functions, + InlinedHashSet& seen_calls, + InlinedVector& called_functions) { + InlinedVector pending_graphs; + + auto process_nodes = [&](const auto& node_range) { + for (const auto& node : node_range) { + const auto function_id = function_utils::GetFunctionIdentifier( + node.domain(), node.op_type(), node.overload()); + auto it = model_local_functions.find(function_id); + if (it != model_local_functions.end()) { + // Use string_view into the map key (stable storage). + std::string_view key_view = it->first; + if (seen_calls.insert(key_view).second) { + called_functions.push_back(key_view); + } + } + + for (const auto& attr : node.attribute()) { + if (attr.has_g()) { + pending_graphs.push_back(&attr.g()); + } + for (const auto& sub_graph : attr.graphs()) { + pending_graphs.push_back(&sub_graph); + } + } + } + }; + + process_nodes(nodes); + + while (!pending_graphs.empty()) { + const auto* graph = pending_graphs.back(); + pending_graphs.pop_back(); + process_nodes(graph->node()); + } +} + +} // namespace + +Status BuildLocalFunctionCallGraph( + const std::unordered_map& model_local_functions, + LocalFunctionCallGraph& call_graph) { + call_graph.reserve(model_local_functions.size()); + + for (const auto& [function_id, function_proto] : model_local_functions) { + if (function_proto == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Null function proto for function id: ", function_id); + } + + InlinedHashSet seen_calls; + InlinedVector callees; + CollectLocalFunctionCalls(function_proto->node(), model_local_functions, seen_calls, callees); + + call_graph.emplace(std::string_view(function_id), std::move(callees)); + } + + return Status::OK(); +} + +Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph) { + enum class VisitState { kNotVisited, + kVisiting, + kVisited }; + + InlinedHashMap visit_states; + visit_states.reserve(call_graph.size()); + for (const auto& [function_id, _] : call_graph) { + ORT_UNUSED_PARAMETER(_); + visit_states.emplace(function_id, VisitState::kNotVisited); + } + + // Each frame records the function being visited and a pointer to its callees vector + // in the call graph (no per-frame allocation). + struct DfsFrame { + std::string_view function_id; + const InlinedVector* callees; + size_t next_callee_index; + }; + + std::vector dfs_stack; + + for (const auto& [root_id, root_callees] : call_graph) { + auto root_state_it = visit_states.find(root_id); + if (root_state_it == visit_states.end() || root_state_it->second == VisitState::kVisited) { + continue; + } + + root_state_it->second = VisitState::kVisiting; + dfs_stack.push_back({root_id, &root_callees, 0}); + + while (!dfs_stack.empty()) { + auto& frame = dfs_stack.back(); + + if (frame.next_callee_index >= frame.callees->size()) { + // All callees processed — mark as fully visited and pop. + auto it = visit_states.find(frame.function_id); + ORT_ENFORCE(it != visit_states.end()); + it->second = VisitState::kVisited; + dfs_stack.pop_back(); + continue; + } + + std::string_view callee_id = (*frame.callees)[frame.next_callee_index]; + frame.next_callee_index++; + + auto callee_state_it = visit_states.find(callee_id); + if (callee_state_it == visit_states.end()) { + // Callee not in the graph — skip. + continue; + } + + if (callee_state_it->second == VisitState::kVisited) { + continue; + } + + if (callee_state_it->second == VisitState::kVisiting) { + // Cycle detected. Build cycle description from the stack. + std::string cycle; + bool in_cycle = false; + for (const auto& f : dfs_stack) { + if (f.function_id == callee_id) { + in_cycle = true; + } + if (in_cycle) { + if (!cycle.empty()) { + cycle.append(" -> "); + } + cycle.append(f.function_id); + } + } + cycle.append(" -> "); + cycle.append(callee_id); + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Model local function definitions must not be recursive. Cycle detected: ", cycle); + } + + // Push callee onto the DFS stack. + auto callee_graph_it = call_graph.find(callee_id); + if (callee_graph_it == call_graph.end()) { + continue; + } + + callee_state_it->second = VisitState::kVisiting; + dfs_stack.push_back({callee_id, &callee_graph_it->second, 0}); + } + } + + return Status::OK(); +} + +Status ValidateModelLocalFunctionAcyclic( + const std::unordered_map& model_local_functions) { + LocalFunctionCallGraph call_graph; + ORT_RETURN_IF_ERROR(BuildLocalFunctionCallGraph(model_local_functions, call_graph)); + return ValidateCallGraphAcyclic(call_graph); +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/graph/model_helpers.h b/onnxruntime/core/graph/model_helpers.h new file mode 100644 index 0000000000000..777f2ac611c15 --- /dev/null +++ b/onnxruntime/core/graph/model_helpers.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" + +namespace ONNX_NAMESPACE { +class FunctionProto; +} + +namespace onnxruntime { + +/// Adjacency list representation of a local function call graph. +/// Keys and values are string_views into stable storage (e.g. map keys that outlive this structure). +using LocalFunctionCallGraph = InlinedHashMap>; + +/// Build a call graph adjacency list from model local functions. +/// String views in the returned graph point into the keys of @p model_local_functions. +Status BuildLocalFunctionCallGraph( + const std::unordered_map& model_local_functions, + LocalFunctionCallGraph& call_graph); + +/// Validate that a call graph contains no cycles. +/// Returns an error with the cycle path if a cycle is detected. +Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph); + +/// Convenience: build the call graph from model local functions and validate acyclicity. +Status ValidateModelLocalFunctionAcyclic( + const std::unordered_map& model_local_functions); + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 8ed9a40097d4b..2530a1f73f81a 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -115,7 +115,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, - const InlinedVector& node_tree_ids, InlinedVector> indices); + const InlinedVector& node_tree_ids, const InlinedVector>& indices); size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, @@ -383,7 +383,7 @@ bool TreeEnsembleCommon::CheckIfSubtreesAr const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, - const InlinedVector& node_tree_ids, InlinedVector> indices) { + const InlinedVector& node_tree_ids, const InlinedVector>& indices) { if (left_id == right_id) { return true; } diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 93f2ea704a729..ee3b0a6ec2133 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -1,15 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include + #include "core/graph/onnx_protobuf.h" +#include "onnx/checker.h" #include "onnx/defs/parser.h" #include "core/common/span_utils.h" #include "core/framework/customregistry.h" #include "core/framework/op_kernel.h" #include "core/graph/model.h" +#include "core/graph/model_helpers.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" @@ -87,6 +92,34 @@ static void Check(const char* source, } } +static Status LoadModel(const char* source) { + ONNX_NAMESPACE::OnnxParser parser(source); + ONNX_NAMESPACE::ModelProto model; + auto parse_status = parser.Parse(model); + if (!parse_status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to parse test model: ", parse_status.ErrorMessage()); + } + if (!parser.EndOfInput()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Extra unparsed input unexpected."); + } + + try { + ONNX_NAMESPACE::checker::check_model(model); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX model check failed: ", e.what()); + } + + std::string serialized_model; + if (!model.SerializeToString(&serialized_model) || serialized_model.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to serialize test model."); + } + + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + std::istringstream sstr(serialized_model); + return session_object.Load(sstr); +} + namespace { const char* basic_code = R"( < @@ -303,6 +336,370 @@ TEST(FunctionTest, CallInConditional) { Check(code, "x", {1.0, 2.0, 3.0}, "y", {6.0, 12.0, 18.0}); } +TEST(FunctionTest, RejectsSelfRecursiveLocalFunction) { + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.self_recursive (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + self_recursive (lx) => (ly) { + ly = local.self_recursive (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, RejectsMutuallyRecursiveLocalFunctions) { + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.first (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + first (lx) => (ly) { + ly = local.second (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + second (lx) => (ly) { + ly = local.first (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, RejectsRecursionThroughSubgraph) { + // A local function that calls itself inside an If subgraph (then_branch). + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.recursive_if (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + recursive_if (lx) => (ly) { + temp = Identity (lx) + cond = Constant () + ly = If (cond) < + then_branch = then_graph () => (float[N] then_out) + { + then_out = local.recursive_if (temp) + }, + else_branch = else_graph () => (float[N] else_out) + { + else_out = Identity (temp) + } + > + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +// --- Synthetic adjacency-list tests for ValidateCallGraphAcyclic --- +// These test the cycle detection algorithm directly without constructing ONNX models. + +TEST(FunctionTest, CallGraphAcyclic_EmptyGraph) { + onnxruntime::LocalFunctionCallGraph call_graph; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_SingleNodeNoCalls) { + // Single function with no callees. + std::string a = "A"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_SelfCycle) { + std::string a = "A"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("A -> A")); +} + +TEST(FunctionTest, CallGraphAcyclic_MutualCycle) { + std::string a = "A", b = "B"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, CallGraphAcyclic_LongerCycle) { + // A -> B -> C -> A + std::string a = "A", b = "B", c = "C"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {c}; + call_graph[c] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); + // The cycle path should include all three participants. + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("A")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("B")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("C")); +} + +TEST(FunctionTest, CallGraphAcyclic_DiamondNoCycle) { + // A -> B, A -> C, B -> D, C -> D (no cycle) + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b, c}; + call_graph[b] = {d}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_DeepChainNoCycle) { + // A -> B -> C -> D (no cycle) + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {c}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_MultipleIndependentCycles) { + // Two independent cycles: A -> B -> A, C -> D -> C + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {a}; + call_graph[c] = {d}; + call_graph[d] = {c}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, CallGraphAcyclic_SharedCallsDiamondNoCycle) { + // Regression test: acyclic model with shared function calls (diamond pattern). + // E -> A, E -> B, A -> C, B -> C, C -> D (no cycle despite shared references to C) + std::string a = "A", b = "B", c = "C", d = "D", e = "E"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[e] = {a, b}; + call_graph[a] = {c}; + call_graph[b] = {c}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +// --- Model-level integration tests --- + +TEST(FunctionTest, RejectsLongerCycle) { + // A -> B -> C -> A (three-function cycle) + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.func_a (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + ly = local.func_b (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_c (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_a (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, AcceptsAcyclicDiamond) { + // A -> B, A -> C, B -> D, C -> D (diamond, no cycle) + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.func_a (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + t1 = local.func_b (lx) + ly = local.func_c (t1) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16 ], + domain: "local" + > + func_d (lx) => (ly) { + ly = Identity (lx) + } + )"; + + ASSERT_STATUS_OK(LoadModel(code)); +} + +TEST(FunctionTest, AcceptsTrivialSingleNodeFunction) { + // A local function with a single Identity node — verifies that trivial + // (but non-empty) function bodies pass acyclicity validation. + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.trivial_func (x) + } + + < + opset_import: [ "" : 16 ], + domain: "local" + > + trivial_func (lx) => (ly) { + ly = Identity (lx) + } + )"; + + ASSERT_STATUS_OK(LoadModel(code)); +} + +TEST(FunctionTest, RejectsMultipleIndependentCycles) { + // Two independent cycles in the same model: A -> B -> A, C -> D -> C + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + t = local.func_a (x) + y = local.func_c (t) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + ly = local.func_b (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_a (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_d (lx) => (ly) { + ly = local.func_c (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + // Test use of attibute references, especially where source/target attribute // names are not the same. In this example, the "start : int = @s" attribute-reference // binds the attribute named "start" of the Shape op to the attribute named "s" diff --git a/plugin-ep-webgpu/_packaging_utils.py b/plugin-ep-webgpu/_packaging_utils.py index 201b3342ff39c..84850e4dee5fe 100644 --- a/plugin-ep-webgpu/_packaging_utils.py +++ b/plugin-ep-webgpu/_packaging_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + """Shared utilities for the WebGPU plugin EP packaging scripts. Not a public API.""" from __future__ import annotations diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj index 58860c46b9c16..5bfbac0308e01 100644 --- a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj @@ -16,7 +16,8 @@ ONNX;ONNX Runtime;Machine Learning;AI;Deep Learning;WebGPU - MIT + LICENSE + https://onnxruntime.ai https://github.com/microsoft/onnxruntime git © Microsoft Corporation. All rights reserved. @@ -29,6 +30,9 @@ + + +