Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
WGSL_TEMPLATE_PARAMETER(is_apple, is_apple_),
WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
Expand All @@ -221,7 +220,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_),
WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_));
}

Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
Expand Down Expand Up @@ -486,6 +486,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
bool is_apple = context.AdapterInfo().vendor == std::string_view{"apple"};
bool has_subgroups = context.HasFeature(wgpu::FeatureName::Subgroups);
bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
bool has_head_sink = head_sink != nullptr;
Expand All @@ -498,6 +499,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.is_unidirectional_,
is_nvidia,
is_apple,
has_subgroups,
q_BNSH,
use_seqlen_k,
has_head_sink};
Expand Down Expand Up @@ -532,7 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co

program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
.SetWorkgroupSize(prefill_tile_size)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step())
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step())
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(present_sequence_length)},
Expand Down Expand Up @@ -584,7 +586,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
return !parameters.is_packed_qkv_ &&
parameters.head_size_ == parameters.v_head_size_ &&
context.HasFeature(wgpu::FeatureName::Subgroups) &&
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}

Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool is_unidirectional,
bool is_nvidia,
bool is_apple,
bool has_subgroups,
bool q_BNSH,
bool use_seqlen_k = false,
bool has_head_sink = false)
Expand All @@ -88,12 +89,12 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
qkv_num_heads_(qkv_num_heads),
is_unidirectional_(is_unidirectional),
is_nvidia_(is_nvidia),
is_apple_(is_apple),
use_shm_path_(is_apple || is_nvidia || !has_subgroups),
q_BNSH_(q_BNSH),
use_seqlen_k_(use_seqlen_k),
has_head_sink_(has_head_sink) {
if (is_apple || is_nvidia) {
// On Apple and NVIDIA, use an optimized loop-based path with dynamic max_k_step.
if (use_shm_path_) {
// Use shared-memory loop-based path with dynamic max_k_step.
// Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step
const int element_size = is_fp16 ? 2 : 4;
constexpr int kMinWorkgroupStorageBudgetBytes = 16384;
Expand Down Expand Up @@ -130,7 +131,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_num_heads_;
bool is_unidirectional_;
bool is_nvidia_;
bool is_apple_;
bool use_shm_path_;
bool q_BNSH_;
bool use_seqlen_k_;
bool has_head_sink_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

#param has_attention_bias
#param has_head_sink
#param is_apple
#param is_fp16
#param is_qualcomm
#param is_unidirectional
Expand All @@ -10,6 +9,7 @@
#param qkv_head_size
#param qkv_num_heads
#param use_seqlen_k
#param use_shm_path
#param max_k_step_param

const head_size : u32 = qkv_head_size;
Expand Down Expand Up @@ -61,7 +61,7 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_
}
}

#if is_apple
#if use_shm_path

var<private> qk_scores : array<q_element_t, max_k_step>;

Expand Down Expand Up @@ -240,7 +240,7 @@ $MAIN {
let seq_causal_length = total_sequence_length;
#endif

#if is_apple
#if use_shm_path

for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) {
workgroupBarrier();
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "core/common/inlined_containers.h"
#include "core/common/logging/logging.h"
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/flatbuffers/flatbuffers_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/model.h"
#include "core/graph/model_editor_api_types.h"
#include "core/graph/model_helpers.h"
#include "core/graph/model_load_utils.h"

#ifdef _MSC_VER
Expand Down Expand Up @@ -129,6 +136,8 @@ Model::Model(const std::string& graph_name,
func_ptr);
}

ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_));

model_local_function_templates_maps_.reserve(model_proto_.functions().size());
for (auto& func : model_proto_.functions()) {
auto func_schema_ptr = function_utils::CreateSchema(func.domain(),
Expand Down Expand Up @@ -261,6 +270,8 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name(), func.overload()), &func);
}

ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_));

model_local_function_templates_maps_.reserve(model_proto_.functions().size());
for (auto& func : model_proto_.functions()) {
auto func_schema_ptr = function_utils::CreateSchema(func.domain(),
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class Model {
// map from function id to pointer of model local function proto
// FunctionProto is hosted in ModelProto.
// this map will be used for the local functions' schema's type/shape inference.
// This container is used by ONNX code and must be an std::unordered_map.
// Must be std::unordered_map to match ONNX_NAMESPACE::shape_inference::ModelLocalFunctionsMap.
std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
// this is the map from function id to the local function template.
// this map will be used by graph to instantiate the function body.
Expand Down
189 changes: 189 additions & 0 deletions onnxruntime/core/graph/model_helpers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)

#include "core/graph/model_helpers.h"

#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/graph/function_utils.h"
#include "core/graph/onnx_protobuf.h"

namespace onnxruntime {

namespace {

// Iterative collection of local function calls from a sequence of nodes,
// including nodes inside nested subgraph attributes. Avoids recursion to
// prevent stack overflow from maliciously deep subgraph nesting.
template <typename NodeRange>
void CollectLocalFunctionCalls(
const NodeRange& nodes,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
InlinedHashSet<std::string_view>& seen_calls,
InlinedVector<std::string_view>& called_functions) {
InlinedVector<const ONNX_NAMESPACE::GraphProto*> pending_graphs;

auto process_nodes = [&](const auto& node_range) {
for (const auto& node : node_range) {
const auto function_id = function_utils::GetFunctionIdentifier(
node.domain(), node.op_type(), node.overload());
auto it = model_local_functions.find(function_id);
if (it != model_local_functions.end()) {
// Use string_view into the map key (stable storage).
std::string_view key_view = it->first;
if (seen_calls.insert(key_view).second) {
called_functions.push_back(key_view);
}
}

for (const auto& attr : node.attribute()) {
if (attr.has_g()) {
pending_graphs.push_back(&attr.g());
}
for (const auto& sub_graph : attr.graphs()) {
pending_graphs.push_back(&sub_graph);
}
}
}
};

process_nodes(nodes);

while (!pending_graphs.empty()) {
const auto* graph = pending_graphs.back();
pending_graphs.pop_back();
process_nodes(graph->node());
}
}

} // namespace

Status BuildLocalFunctionCallGraph(
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
LocalFunctionCallGraph& call_graph) {
call_graph.reserve(model_local_functions.size());

for (const auto& [function_id, function_proto] : model_local_functions) {
if (function_proto == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Null function proto for function id: ", function_id);
}

InlinedHashSet<std::string_view> seen_calls;
InlinedVector<std::string_view> callees;
CollectLocalFunctionCalls(function_proto->node(), model_local_functions, seen_calls, callees);

call_graph.emplace(std::string_view(function_id), std::move(callees));
}

return Status::OK();
}

Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph) {
enum class VisitState { kNotVisited,
kVisiting,
kVisited };

InlinedHashMap<std::string_view, VisitState> visit_states;
visit_states.reserve(call_graph.size());
for (const auto& [function_id, _] : call_graph) {
ORT_UNUSED_PARAMETER(_);
visit_states.emplace(function_id, VisitState::kNotVisited);
}

// Each frame records the function being visited and a pointer to its callees vector
// in the call graph (no per-frame allocation).
struct DfsFrame {
std::string_view function_id;
const InlinedVector<std::string_view>* callees;
size_t next_callee_index;
};

std::vector<DfsFrame> dfs_stack;

for (const auto& [root_id, root_callees] : call_graph) {
auto root_state_it = visit_states.find(root_id);
if (root_state_it == visit_states.end() || root_state_it->second == VisitState::kVisited) {
continue;
}

root_state_it->second = VisitState::kVisiting;
dfs_stack.push_back({root_id, &root_callees, 0});

while (!dfs_stack.empty()) {
auto& frame = dfs_stack.back();

if (frame.next_callee_index >= frame.callees->size()) {
// All callees processed — mark as fully visited and pop.
auto it = visit_states.find(frame.function_id);
ORT_ENFORCE(it != visit_states.end());
it->second = VisitState::kVisited;
dfs_stack.pop_back();
continue;
}

std::string_view callee_id = (*frame.callees)[frame.next_callee_index];
frame.next_callee_index++;

auto callee_state_it = visit_states.find(callee_id);
if (callee_state_it == visit_states.end()) {
// Callee not in the graph — skip.
continue;
}

if (callee_state_it->second == VisitState::kVisited) {
continue;
}

if (callee_state_it->second == VisitState::kVisiting) {
// Cycle detected. Build cycle description from the stack.
std::string cycle;
bool in_cycle = false;
for (const auto& f : dfs_stack) {
if (f.function_id == callee_id) {
in_cycle = true;
}
if (in_cycle) {
if (!cycle.empty()) {
cycle.append(" -> ");
}
cycle.append(f.function_id);
}
}
cycle.append(" -> ");
cycle.append(callee_id);

return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
"Model local function definitions must not be recursive. Cycle detected: ", cycle);
}

// Push callee onto the DFS stack.
auto callee_graph_it = call_graph.find(callee_id);
if (callee_graph_it == call_graph.end()) {
continue;
}

callee_state_it->second = VisitState::kVisiting;
dfs_stack.push_back({callee_id, &callee_graph_it->second, 0});
}
}

return Status::OK();
}

Status ValidateModelLocalFunctionAcyclic(
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions) {
LocalFunctionCallGraph call_graph;
ORT_RETURN_IF_ERROR(BuildLocalFunctionCallGraph(model_local_functions, call_graph));
return ValidateCallGraphAcyclic(call_graph);
}

} // namespace onnxruntime

#endif // !defined(ORT_MINIMAL_BUILD)
41 changes: 41 additions & 0 deletions onnxruntime/core/graph/model_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#if !defined(ORT_MINIMAL_BUILD)

#include <string>
#include <string_view>
#include <unordered_map>

#include "core/common/common.h"
#include "core/common/inlined_containers.h"

namespace ONNX_NAMESPACE {
class FunctionProto;
}

namespace onnxruntime {

/// Adjacency list representation of a local function call graph.
/// Keys and values are string_views into stable storage (e.g. map keys that outlive this structure).
using LocalFunctionCallGraph = InlinedHashMap<std::string_view, InlinedVector<std::string_view>>;

/// Build a call graph adjacency list from model local functions.
/// String views in the returned graph point into the keys of @p model_local_functions.
Status BuildLocalFunctionCallGraph(
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
LocalFunctionCallGraph& call_graph);

/// Validate that a call graph contains no cycles.
/// Returns an error with the cycle path if a cycle is detected.
Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph);

/// Convenience: build the call graph from model local functions and validate acyclicity.
Status ValidateModelLocalFunctionAcyclic(
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions);

} // namespace onnxruntime

#endif // !defined(ORT_MINIMAL_BUILD)
Loading
Loading