diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 077090bb10911..79149f9a84dbe 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -149,6 +149,9 @@ class AttentionCPUBase : public AttentionBase { OpKernelContext* context, int beam_width, Tensor* output_qk) const { + ORT_RETURN_IF_ERROR(ValidateCacheIndirectionValues(cache_indir->Data(), batch_size, beam_width, + past_sequence_length, max_sequence_length)); + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -186,6 +189,33 @@ class AttentionCPUBase : public AttentionBase { } private: + static Status ValidateCacheIndirectionValues(const int32_t* cache_indirection_data, + int batch_beam_size, + int beam_width, + int past_sequence_length, + int max_sequence_length) { + if (cache_indirection_data == nullptr || beam_width <= 0 || past_sequence_length <= 0) { + return Status::OK(); + } + + for (int batch_beam_index = 0; batch_beam_index < batch_beam_size; ++batch_beam_index) { + const int32_t* beam_indices = cache_indirection_data + + static_cast(batch_beam_index) * max_sequence_length; + for (int position = 0; position < past_sequence_length; ++position) { + const int32_t beam_index = beam_indices[position]; + if (beam_index < 0 || beam_index >= beam_width) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cache_indirection beam index out of range. Expected [0, ", beam_width, + "), got ", beam_index, + " at flattened batch_beam index ", batch_beam_index, + ", sequence position ", position); + } + } + } + + return Status::OK(); + } + // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) + // 1 x mask_data(B, N, S, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc index 0d2de59c05394..de6f47ca2626c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc @@ -178,6 +178,20 @@ Status DecoderMaskedMultiHeadAttention::Compute(OpKernelContext* context) con "If beam width is greater than 1, then cache indirection buffer MUST be present"); } + if (cache_indir != nullptr) { + // Read beam width from cache_indirection shape directly. + // DecoderMaskedMultiHeadAttentionParameters shadows AttentionParameters::beam_width, + // so the value set by CheckInputs on the base class is not visible here. + int cache_beam_width = static_cast(cache_indir->Shape().GetDims()[1]); + if (beam_width != nullptr && beam_width_value != cache_beam_width) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'beam_width' should match cache_indirection dimension 1, got ", + beam_width_value, " and ", cache_beam_width); + } + + beam_width_value = cache_beam_width; + } + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); diff --git a/onnxruntime/contrib_ops/cpu/maxpool_with_mask.h b/onnxruntime/contrib_ops/cpu/maxpool_with_mask.h index 7210a9a7c6859..7dfb2770a6979 100644 --- a/onnxruntime/contrib_ops/cpu/maxpool_with_mask.h +++ b/onnxruntime/contrib_ops/cpu/maxpool_with_mask.h @@ -200,10 +200,22 @@ class MaxpoolWithMask : public OpKernel, public PoolBase { const TensorShape& x_shape = X->Shape(); const TensorShape& m_shape = M->Shape(); ORT_RETURN_IF_NOT(x_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3."); - - // TODO: fix this checker later - // ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape - // mismatch: ", x_shape, " vs ", m_shape); + ORT_RETURN_IF_NOT(m_shape.NumDimensions() == x_shape.NumDimensions(), + "Mask and input must have the same number of dimensions. Got mask dims: ", + m_shape.NumDimensions(), " input dims: ", x_shape.NumDimensions()); + const bool input_has_nonzero_channels = x_shape[0] > 0 && x_shape[1] > 0; + // Mask N and C dimensions may differ from input (broadcasting via modulo). + // Only require them to be nonzero to prevent division-by-zero in total_mask_channels. + ORT_RETURN_IF_NOT(!input_has_nonzero_channels || (m_shape[0] > 0 && m_shape[1] > 0), + "Mask N and C dimensions must be greater than 0 when input N and C are greater than 0. " + "Got mask N=", + m_shape[0], " C=", m_shape[1], + " input N=", x_shape[0], " C=", x_shape[1]); + for (size_t i = 2; i < x_shape.NumDimensions(); ++i) { + ORT_RETURN_IF_NOT(m_shape[i] == x_shape[i], + "Mask and input spatial dimensions mismatch at dimension ", i, + ": mask=", m_shape[i], " input=", x_shape[i]); + } TensorShapeVector pads = pool_attrs_.pads; TensorShapeVector kernel_shape = pool_attrs_.kernel_shape; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index f1391ba1e3528..8217a07448266 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -212,7 +212,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(is_apple, is_apple_), WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), @@ -221,7 +220,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_), - WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_)); + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), + WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -486,6 +486,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; bool is_apple = context.AdapterInfo().vendor == std::string_view{"apple"}; + bool has_subgroups = context.HasFeature(wgpu::FeatureName::Subgroups); bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool has_head_sink = head_sink != nullptr; @@ -498,6 +499,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.is_unidirectional_, is_nvidia, is_apple, + has_subgroups, q_BNSH, use_seqlen_k, has_head_sink}; @@ -532,7 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(prefill_tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step()) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step()) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, @@ -584,7 +586,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && - context.HasFeature(wgpu::FeatureName::Subgroups) && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 27fa56e333874..e75b6378f67c6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -77,6 +77,7 @@ class FlashAttentionProgram final : public Program { bool is_unidirectional, bool is_nvidia, bool is_apple, + bool has_subgroups, bool q_BNSH, bool use_seqlen_k = false, bool has_head_sink = false) @@ -88,12 +89,12 @@ class FlashAttentionProgram final : public Program { qkv_num_heads_(qkv_num_heads), is_unidirectional_(is_unidirectional), is_nvidia_(is_nvidia), - is_apple_(is_apple), + use_shm_path_(is_apple || is_nvidia || !has_subgroups), q_BNSH_(q_BNSH), use_seqlen_k_(use_seqlen_k), has_head_sink_(has_head_sink) { - if (is_apple || is_nvidia) { - // On Apple and NVIDIA, use an optimized loop-based path with dynamic max_k_step. + if (use_shm_path_) { + // Use shared-memory loop-based path with dynamic max_k_step. // Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step const int element_size = is_fp16 ? 2 : 4; constexpr int kMinWorkgroupStorageBudgetBytes = 16384; @@ -130,7 +131,7 @@ class FlashAttentionProgram final : public Program { int qkv_num_heads_; bool is_unidirectional_; bool is_nvidia_; - bool is_apple_; + bool use_shm_path_; bool q_BNSH_; bool use_seqlen_k_; bool has_head_sink_; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index db41ac12ce268..6b620043413e3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -1,7 +1,6 @@ #param has_attention_bias #param has_head_sink -#param is_apple #param is_fp16 #param is_qualcomm #param is_unidirectional @@ -10,6 +9,7 @@ #param qkv_head_size #param qkv_num_heads #param use_seqlen_k +#param use_shm_path #param max_k_step_param const head_size : u32 = qkv_head_size; @@ -61,7 +61,7 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ } } -#if is_apple +#if use_shm_path var qk_scores : array; @@ -240,7 +240,7 @@ $MAIN { let seq_causal_length = total_sequence_length; #endif -#if is_apple +#if use_shm_path for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e06b99eea9fd7..0d3a84f30e1fb 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -1007,8 +1007,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + // Determine whether we need to resolve/validate a file system path for the output model. + // A path is needed when: + // - Writing the output model to a file (not to a buffer or write function) + // - Writing initializers to an external file (needs the model path to compute the external file location) + const bool output_is_to_file = (output_buffer_holder == nullptr && output_write_func_holder == nullptr); + const bool needs_path_for_external_initializers = + (ep_context_gen_options.TryGetExternalInitializerFileInfo() != nullptr); + std::filesystem::path valid_output_model_path; - if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + if ((output_is_to_file || needs_path_for_external_initializers) && + (output_model_path_ptr != nullptr || !graph.ModelPath().empty())) { std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr : std::filesystem::path(""); ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 6f73456742160..275fa837a7257 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -2097,6 +2097,33 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { } #if !defined(DISABLE_SPARSE_TENSORS) + +// Validates that a TensorProto's external data path does not escape the model directory. +// Also validates that the file exists when filesystem access is available (skipped on WASM without a virtual FS). +// Returns Status::OK() (no-op) for tensors that do not use file-based external data. +static Status ValidateExternalDataPathForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path) { + // Gates on data_location == EXTERNAL directly instead of using HasExternalData()/HasExternalDataInFile(), + // which also require data_type != UNDEFINED. That check is appropriate for data processing (can't unpack + // without a type), but too narrow for security validation: we must validate any declared external path + // regardless of data_type. + if (tensor_proto.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + return Status::OK(); + } + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + const auto& rel_path = external_data_info->GetRelPath(); + + // In-memory external data uses special marker locations — skip file path validation for those. + if (rel_path == kTensorProtoLittleEndianMemoryAddressTag || + rel_path == kTensorProtoNativeEndianMemoryAddressTag) { + return Status::OK(); + } + + return utils::ValidateExternalDataPath(model_path, rel_path); +} + static Status CopySparseData(const std::string& name, int64_t nnz_elements, const ONNX_NAMESPACE::TensorProto& indices, @@ -2115,10 +2142,18 @@ static Status CopySparseData(const std::string& name, switch (indices.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT64: if (needs_unpack) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int64_t), - "Sparse tensor: ", name, " indices raw data size does not match expected: ", - indices_elements * sizeof(int64_t)); + // For inline raw_data, validate size before unpacking to avoid a large allocation from a + // malformed tensor with small indices shape but oversized raw_data. For external data, + // raw_data is empty so we can only validate after unpacking. + if (!utils::HasExternalData(indices)) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int64_t), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int64_t)); + } ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); + ORT_RETURN_IF_NOT(unpack_buffer.size() == SafeInt(indices_elements) * sizeof(int64_t), + "Sparse tensor: ", name, " indices data size does not match expected: ", + indices_elements * sizeof(int64_t)); indices_data = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); } else { ORT_RETURN_IF_NOT(indices.int64_data_size() == indices_elements, @@ -2129,10 +2164,15 @@ static Status CopySparseData(const std::string& name, break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: { if (needs_unpack) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int32_t), - "Sparse tensor: ", name, " indices raw data size does not match expected: ", - indices_elements * sizeof(int32_t)); + if (!utils::HasExternalData(indices)) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int32_t), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int32_t)); + } ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); + ORT_RETURN_IF_NOT(unpack_buffer.size() == SafeInt(indices_elements) * sizeof(int32_t), + "Sparse tensor: ", name, " indices data size does not match expected: ", + indices_elements * sizeof(int32_t)); auto int32_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int32_span.begin(), int32_span.end()); unpack_buffer.clear(); @@ -2148,10 +2188,15 @@ static Status CopySparseData(const std::string& name, } case ONNX_NAMESPACE::TensorProto_DataType_INT16: { if (needs_unpack) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int16_t), - "Sparse tensor: ", name, " indices raw data size does not match expected: ", - indices_elements * sizeof(int16_t)); + if (!utils::HasExternalData(indices)) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == SafeInt(indices_elements) * sizeof(int16_t), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int16_t)); + } ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); + ORT_RETURN_IF_NOT(unpack_buffer.size() == SafeInt(indices_elements) * sizeof(int16_t), + "Sparse tensor: ", name, " indices data size does not match expected: ", + indices_elements * sizeof(int16_t)); auto int16_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int16_span.begin(), int16_span.end()); unpack_buffer.clear(); @@ -2167,10 +2212,15 @@ static Status CopySparseData(const std::string& name, } case ONNX_NAMESPACE::TensorProto_DataType_INT8: { if (needs_unpack) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == narrow(indices_elements), - "Sparse tensor: ", name, " indices raw data size does not match expected: ", - indices_elements * sizeof(int8_t)); + if (!utils::HasExternalData(indices)) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == narrow(indices_elements), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int8_t)); + } ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); + ORT_RETURN_IF_NOT(unpack_buffer.size() == narrow(indices_elements), + "Sparse tensor: ", name, " indices data size does not match expected: ", + indices_elements * sizeof(int8_t)); auto int8_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int8_span.begin(), int8_span.end()); unpack_buffer.clear(); @@ -2318,6 +2368,12 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT } } + // Validate external data paths before any early returns or allocations. + // This ensures malicious paths are rejected even for zero-element tensors, + // and prevents large allocations before an invalid path is caught. + ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(sparse_values, model_path)); + ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(indices, model_path)); + if (dense_elements == 0) { // if there are no elements in the dense tensor, we can return early with an empty tensor proto return status; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index fd4d266dc51f0..bfa25a5cb2e9a 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include +#include +#include +#include + +#include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" #include "core/graph/model_editor_api_types.h" +#include "core/graph/model_helpers.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -129,6 +136,8 @@ Model::Model(const std::string& graph_name, func_ptr); } + ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_)); + model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -261,6 +270,8 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name(), func.overload()), &func); } + ORT_THROW_IF_ERROR(ValidateModelLocalFunctionAcyclic(model_local_functions_)); + model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index c86aac44806bd..9a877ac6bba95 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -343,7 +343,7 @@ class Model { // map from function id to pointer of model local function proto // FunctionProto is hosted in ModelProto. // this map will be used for the local functions' schema's type/shape inference. - // This container is used by ONNX code and must be an std::unordered_map. + // Must be std::unordered_map to match ONNX_NAMESPACE::shape_inference::ModelLocalFunctionsMap. std::unordered_map model_local_functions_; // this is the map from function id to the local function template. // this map will be used by graph to instantiate the function body. diff --git a/onnxruntime/core/graph/model_helpers.cc b/onnxruntime/core/graph/model_helpers.cc new file mode 100644 index 0000000000000..c3214d488ff0d --- /dev/null +++ b/onnxruntime/core/graph/model_helpers.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/graph/model_helpers.h" + +#include +#include +#include +#include +#include + +#include "core/graph/function_utils.h" +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { + +namespace { + +// Iterative collection of local function calls from a sequence of nodes, +// including nodes inside nested subgraph attributes. Avoids recursion to +// prevent stack overflow from maliciously deep subgraph nesting. +template +void CollectLocalFunctionCalls( + const NodeRange& nodes, + const std::unordered_map& model_local_functions, + InlinedHashSet& seen_calls, + InlinedVector& called_functions) { + InlinedVector pending_graphs; + + auto process_nodes = [&](const auto& node_range) { + for (const auto& node : node_range) { + const auto function_id = function_utils::GetFunctionIdentifier( + node.domain(), node.op_type(), node.overload()); + auto it = model_local_functions.find(function_id); + if (it != model_local_functions.end()) { + // Use string_view into the map key (stable storage). + std::string_view key_view = it->first; + if (seen_calls.insert(key_view).second) { + called_functions.push_back(key_view); + } + } + + for (const auto& attr : node.attribute()) { + if (attr.has_g()) { + pending_graphs.push_back(&attr.g()); + } + for (const auto& sub_graph : attr.graphs()) { + pending_graphs.push_back(&sub_graph); + } + } + } + }; + + process_nodes(nodes); + + while (!pending_graphs.empty()) { + const auto* graph = pending_graphs.back(); + pending_graphs.pop_back(); + process_nodes(graph->node()); + } +} + +} // namespace + +Status BuildLocalFunctionCallGraph( + const std::unordered_map& model_local_functions, + LocalFunctionCallGraph& call_graph) { + call_graph.reserve(model_local_functions.size()); + + for (const auto& [function_id, function_proto] : model_local_functions) { + if (function_proto == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Null function proto for function id: ", function_id); + } + + InlinedHashSet seen_calls; + InlinedVector callees; + CollectLocalFunctionCalls(function_proto->node(), model_local_functions, seen_calls, callees); + + call_graph.emplace(std::string_view(function_id), std::move(callees)); + } + + return Status::OK(); +} + +Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph) { + enum class VisitState { kNotVisited, + kVisiting, + kVisited }; + + InlinedHashMap visit_states; + visit_states.reserve(call_graph.size()); + for (const auto& [function_id, _] : call_graph) { + ORT_UNUSED_PARAMETER(_); + visit_states.emplace(function_id, VisitState::kNotVisited); + } + + // Each frame records the function being visited and a pointer to its callees vector + // in the call graph (no per-frame allocation). + struct DfsFrame { + std::string_view function_id; + const InlinedVector* callees; + size_t next_callee_index; + }; + + std::vector dfs_stack; + + for (const auto& [root_id, root_callees] : call_graph) { + auto root_state_it = visit_states.find(root_id); + if (root_state_it == visit_states.end() || root_state_it->second == VisitState::kVisited) { + continue; + } + + root_state_it->second = VisitState::kVisiting; + dfs_stack.push_back({root_id, &root_callees, 0}); + + while (!dfs_stack.empty()) { + auto& frame = dfs_stack.back(); + + if (frame.next_callee_index >= frame.callees->size()) { + // All callees processed — mark as fully visited and pop. + auto it = visit_states.find(frame.function_id); + ORT_ENFORCE(it != visit_states.end()); + it->second = VisitState::kVisited; + dfs_stack.pop_back(); + continue; + } + + std::string_view callee_id = (*frame.callees)[frame.next_callee_index]; + frame.next_callee_index++; + + auto callee_state_it = visit_states.find(callee_id); + if (callee_state_it == visit_states.end()) { + // Callee not in the graph — skip. + continue; + } + + if (callee_state_it->second == VisitState::kVisited) { + continue; + } + + if (callee_state_it->second == VisitState::kVisiting) { + // Cycle detected. Build cycle description from the stack. + std::string cycle; + bool in_cycle = false; + for (const auto& f : dfs_stack) { + if (f.function_id == callee_id) { + in_cycle = true; + } + if (in_cycle) { + if (!cycle.empty()) { + cycle.append(" -> "); + } + cycle.append(f.function_id); + } + } + cycle.append(" -> "); + cycle.append(callee_id); + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Model local function definitions must not be recursive. Cycle detected: ", cycle); + } + + // Push callee onto the DFS stack. + auto callee_graph_it = call_graph.find(callee_id); + if (callee_graph_it == call_graph.end()) { + continue; + } + + callee_state_it->second = VisitState::kVisiting; + dfs_stack.push_back({callee_id, &callee_graph_it->second, 0}); + } + } + + return Status::OK(); +} + +Status ValidateModelLocalFunctionAcyclic( + const std::unordered_map& model_local_functions) { + LocalFunctionCallGraph call_graph; + ORT_RETURN_IF_ERROR(BuildLocalFunctionCallGraph(model_local_functions, call_graph)); + return ValidateCallGraphAcyclic(call_graph); +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/graph/model_helpers.h b/onnxruntime/core/graph/model_helpers.h new file mode 100644 index 0000000000000..777f2ac611c15 --- /dev/null +++ b/onnxruntime/core/graph/model_helpers.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" + +namespace ONNX_NAMESPACE { +class FunctionProto; +} + +namespace onnxruntime { + +/// Adjacency list representation of a local function call graph. +/// Keys and values are string_views into stable storage (e.g. map keys that outlive this structure). +using LocalFunctionCallGraph = InlinedHashMap>; + +/// Build a call graph adjacency list from model local functions. +/// String views in the returned graph point into the keys of @p model_local_functions. +Status BuildLocalFunctionCallGraph( + const std::unordered_map& model_local_functions, + LocalFunctionCallGraph& call_graph); + +/// Validate that a call graph contains no cycles. +/// Returns an error with the cycle path if a cycle is detected. +Status ValidateCallGraphAcyclic(const LocalFunctionCallGraph& call_graph); + +/// Convenience: build the call graph from model local functions and validate acyclicity. +Status ValidateModelLocalFunctionAcyclic( + const std::unordered_map& model_local_functions); + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 8ed9a40097d4b..2530a1f73f81a 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -115,7 +115,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, - const InlinedVector& node_tree_ids, InlinedVector> indices); + const InlinedVector& node_tree_ids, const InlinedVector>& indices); size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, @@ -383,7 +383,7 @@ bool TreeEnsembleCommon::CheckIfSubtreesAr const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, gsl::span nodes_values_as_tensor, gsl::span node_values, gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, - const InlinedVector& node_tree_ids, InlinedVector> indices) { + const InlinedVector& node_tree_ids, const InlinedVector>& indices) { if (left_id == right_id) { return true; } diff --git a/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh index c8ddbadb12fb2..5959482e5664e 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh @@ -14,11 +14,11 @@ __global__ void _UnaryElementWise( const InT* input_data, OutT* output_data, const FuncT functor, - CUDA_LONG N) { - CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + int64_t N) { + int64_t start = static_cast(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x; InT value[NumElementsPerThread]; - CUDA_LONG id = start; + int64_t id = start; #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { @@ -47,8 +47,10 @@ void UnaryElementWiseImpl( if (count == 0) // special case where there's a dim value of 0 in the shape return; - int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); - CUDA_LONG N = static_cast(count); + size_t blocksPerGridSize = CeilDiv(count, static_cast(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread); + ORT_ENFORCE(blocksPerGridSize <= static_cast(INT32_MAX), "Grid size exceeds CUDA limits"); + int blocksPerGrid = static_cast(blocksPerGridSize); + int64_t N = static_cast(count); _UnaryElementWise <<>>( input_data, diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu index a8cd6caaa5d5f..c56d613e25241 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu @@ -220,8 +220,8 @@ struct CastStd { #endif // DISABLE_FLOAT4_TYPES template -__global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) { - CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; +__global__ void CastKernelStd(const InT* input, OutT* output, int64_t N, CastStd cast) { + int64_t id = static_cast(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x; #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { @@ -237,11 +237,13 @@ Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t n if (num_of_elements <= 0) return Status::OK(); - int blocksPerGrid = static_cast(CeilDiv(num_of_elements, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + size_t blocksPerGridSize = CeilDiv(num_of_elements, static_cast(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread); + ORT_RETURN_IF_NOT(blocksPerGridSize <= static_cast(INT32_MAX), "Grid size exceeds CUDA limits"); + int blocksPerGrid = static_cast(blocksPerGridSize); CastKernelStd<<>>( input, output, - static_cast(num_of_elements), + static_cast(num_of_elements), CastStd()); return Status::OK(); } @@ -251,10 +253,10 @@ Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t n template __global__ void CudaCastPairwiseKernel(const InPairType* input, OutPairType* output, - CUDA_LONG pair_count, + int64_t pair_count, CastStd pair_caster, CastStd singleton_caster) { - CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + int64_t id = static_cast(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x; #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { @@ -284,9 +286,11 @@ Status CudaCastPairwise(cudaStream_t stream, const Float4E2M1x2* input, float* o bool is_odd = (num_of_elements & 0x01) != 0; - int pair_count = static_cast(num_of_elements / 2); + size_t pair_count = num_of_elements / 2; - int blocksPerGrid = static_cast(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + size_t blocksPerGridSize = CeilDiv(pair_count, static_cast(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread); + ORT_RETURN_IF_NOT(blocksPerGridSize <= static_cast(INT32_MAX), "Grid size exceeds CUDA limits"); + int blocksPerGrid = static_cast(blocksPerGridSize); if (pair_count == 0) { blocksPerGrid = 1; @@ -296,14 +300,14 @@ Status CudaCastPairwise(cudaStream_t stream, const Float4E2M1x2* input, float* o CudaCastPairwiseKernel <<>>( - input, reinterpret_cast(output), pair_count, + input, reinterpret_cast(output), static_cast(pair_count), CastStd(), CastStd()); } else { CudaCastPairwiseKernel <<>>( - input, reinterpret_cast(output), pair_count, + input, reinterpret_cast(output), static_cast(pair_count), CastStd(), CastStd()); } @@ -318,9 +322,11 @@ Status CudaCastPairwise(cudaStream_t stream, const float* input, Float4E2M1x2* o bool is_odd = (num_of_elements & 0x01) != 0; - int pair_count = static_cast(num_of_elements / 2); + size_t pair_count = num_of_elements / 2; - int blocksPerGrid = static_cast(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + size_t blocksPerGridSize = CeilDiv(pair_count, static_cast(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread); + ORT_RETURN_IF_NOT(blocksPerGridSize <= static_cast(INT32_MAX), "Grid size exceeds CUDA limits"); + int blocksPerGrid = static_cast(blocksPerGridSize); if (pair_count == 0) { blocksPerGrid = 1; @@ -330,14 +336,14 @@ Status CudaCastPairwise(cudaStream_t stream, const float* input, Float4E2M1x2* o CudaCastPairwiseKernel <<>>( - reinterpret_cast(input), output, pair_count, + reinterpret_cast(input), output, static_cast(pair_count), CastStd(), CastStd()); } else { CudaCastPairwiseKernel <<>>( - reinterpret_cast(input), output, pair_count, + reinterpret_cast(input), output, static_cast(pair_count), CastStd(), CastStd()); } @@ -353,8 +359,8 @@ template Status CudaCastPairwise(cudaStream_t stream, const #if !defined(DISABLE_FLOAT8_TYPES) template -__global__ void CastKernelSat(const InT* input, OutT* output, CUDA_LONG N, CastSat cast, bool saturate) { - CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; +__global__ void CastKernelSat(const InT* input, OutT* output, int64_t N, CastSat cast, bool saturate) { + int64_t id = static_cast(NumElementsPerThread) * NumThreadsPerBlock * blockIdx.x + threadIdx.x; #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { @@ -370,11 +376,13 @@ Status CudaCastSat(cudaStream_t stream, const InT* input, OutT* output, size_t n if (num_of_element <= 0) return Status::OK(); - int blocksPerGrid = static_cast(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + size_t blocksPerGridSize = CeilDiv(num_of_element, static_cast(GridDim::maxThreadsPerBlock) * GridDim::maxElementsPerThread); + ORT_RETURN_IF_NOT(blocksPerGridSize <= static_cast(INT32_MAX), "Grid size exceeds CUDA limits"); + int blocksPerGrid = static_cast(blocksPerGridSize); CastKernelSat<<>>( input, output, - static_cast(num_of_element), + static_cast(num_of_element), CastSat(), saturate); return Status::OK(); diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 9b78e943122de..f8992b13c8f5d 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -36,6 +36,10 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra const NodeArg* A_arg = input_defs[0]; const NodeArg* B_arg = input_defs[1]; const NodeArg* C_arg = input_defs.size() == 2 ? nullptr : input_defs[2]; + // Single source of truth for "is C actually present?". Matches the kernel + // constructor's C_matrix_exists_ = C_arg && C_arg->Exists() contract and the + // has_bias convention used in xnnpack/nn/conv_base.cc. + const bool has_c = (C_arg != nullptr && C_arg->Exists()); // we only support float currently const auto* A_type = A_arg->TypeAsProto(); @@ -51,14 +55,13 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra break; } - if (input_defs.size() == 3 && !graph.IsConstantInitializer(C_arg->Name(), true)) { + if (has_c && !graph.IsConstantInitializer(C_arg->Name(), true)) { break; } // making sure we are dealing with MatMul const ONNX_NAMESPACE::TensorShapeProto* A_shape = A_arg->Shape(); const ONNX_NAMESPACE::TensorShapeProto* B_shape = B_arg->Shape(); - const ONNX_NAMESPACE::TensorShapeProto* C_shape = C_arg->Shape(); if (!A_shape || A_shape->dim_size() >= 3) { break; @@ -68,12 +71,27 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra break; } - if (!C_shape || C_shape->dim_size() >= 3) { - break; - } - - if (C_arg && C_arg->Exists() && (C_shape->dim(0).dim_value() != B_shape->dim(1).dim_value() && C_shape->dim(0).dim_value() != B_shape->dim(0).dim_value())) { - break; + // Optional C: if the input slot is absent (2-input Gemm) C_arg is null and we must not + // call Shape() on it. If C_arg exists but Exists() is false (empty optional input slot) + // we treat it identically: per ONNX, an empty optional input is equivalent to omitting + // the input. The kernel constructor's C_matrix_exists_ contract agrees. + if (has_c) { + const ONNX_NAMESPACE::TensorShapeProto* C_shape = C_arg->Shape(); + if (!C_shape || C_shape->dim_size() >= 3) { + break; + } + + // Rank-0 C would be out of bounds on the C_shape->dim(0) check below and the + // xnn_create_fully_connected_nc_* bias path requires a length-N vector, so reject + // and fall back to the CPU EP. + if (C_shape->dim_size() == 0) { + break; + } + + if (C_shape->dim(0).dim_value() != B_shape->dim(1).dim_value() && + C_shape->dim(0).dim_value() != B_shape->dim(0).dim_value()) { + break; + } } supported = true; diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 93b509eec6982..0cba82e1135c9 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -7,6 +7,7 @@ import argparse import logging import os +import re import warnings import onnx @@ -309,6 +310,18 @@ def parse_arguments(argv=None): ) quant_args.set_defaults(quantize_symmetric=False) + quant_args.add_argument( + "--quant_method", + required=False, + type=str, + default="k_quant", + choices=["k_quant", "k_quant_mixed"], + help="Quantization method for INT4 precision. " + "k_quant = k_quant algorithm with all nodes at INT4. " + "k_quant_mixed = k_quant with mixed precision (sensitive layers at INT8, rest at INT4). " + "Inspired by llama.cpp k-quant mixed strategy.", + ) + args = parser.parse_args(argv) # Collect cross QKs if either flag is enabled @@ -320,20 +333,121 @@ def parse_arguments(argv=None): return args -# quant_method is reserved for mixed precision in future -def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None): +def get_sensitive_node_names(matmul_nodes: list[str], encoder_layers: int, decoder_layers: int): + """Identify sensitive MatMul nodes that should use INT8 in k_quant_mixed. + + Follows the llama.cpp k-quant mixed strategy adapted for Whisper encoder-decoder: + - First/last ~12.5% of layers + every 3rd layer in between are "sensitive layers" + - Within sensitive layers: attention Q/K/V projections and FFN fc2 (down projection) get INT8 + - proj_out (LM head) always gets INT8 + + Reference: llama.cpp/src/llama-quant.cpp#L136 + + Args: + matmul_nodes: list of MatMul node names from the ONNX graph. + encoder_layers: number of encoder layers in the model. + decoder_layers: number of decoder layers in the model. + + Returns: + list of node names that should be quantized to INT8. + """ + + def get_sensitive_layer_indices(num_layers): + return [ + i + for i in range(num_layers) + if i < num_layers / 8 or i >= 7 * num_layers / 8 or (i - round(num_layers / 8)) % 3 == 2 + ] + + enc_sensitive_layers = set(get_sensitive_layer_indices(encoder_layers)) + dec_sensitive_layers = set(get_sensitive_layer_indices(decoder_layers)) + + # Patterns for sensitive MatMul types within a sensitive layer: + # - Attention projections: q_proj, k_proj, v_proj (most sensitive to quantization) + # - FFN fc2 / out_proj equivalent (the down projection) + # - Cross-attention k_proj (sensitive based on weight distribution analysis) + sensitive_matmul_patterns = [ + "/self_attn/q_proj/", + "/self_attn/k_proj/", + "/self_attn/v_proj/", + "/self_attn/out_proj/", + "/encoder_attn/q_proj/", + "/encoder_attn/k_proj/", + "/encoder_attn/v_proj/", + "/encoder_attn/out_proj/", + "/fc2/", + ] + + sensitive = [] + for name in matmul_nodes: + # proj_out (LM head equivalent) is always sensitive + if "proj_out" in name: + sensitive.append(name) + continue + + # Determine if this is an encoder or decoder node, and extract layer index + layer_match = re.search(r"layers\.(\d+)", name) + if not layer_match: + # Cross-attention KV projections outside layer hierarchy (e.g. /k_proj/MatMul) + # These are always run once; keep them at INT8 for accuracy + if any(p.strip("/") in name for p in ["/k_proj/", "/v_proj/"]): + sensitive.append(name) + continue + + layer_idx = int(layer_match.group(1)) + + is_encoder = "/encoder/" in name + is_decoder = "/decoder/" in name + + # Check if this layer is in the sensitive set + if is_encoder and layer_idx in enc_sensitive_layers: + if any(pat in name for pat in sensitive_matmul_patterns): + sensitive.append(name) + elif is_decoder and layer_idx in dec_sensitive_layers: + if any(pat in name for pat in sensitive_matmul_patterns): + sensitive.append(name) + + return sensitive + + +def make_quant_algo_config( + precision: Precision, + quant_method: str, + matmul_nodes: list[str] | None = None, + encoder_layers: int = 0, + decoder_layers: int = 0, +): + """Create quantization algorithm config for Whisper models. + + Args: + precision: Precision enum (INT4 or INT8). + quant_method: "k_quant" or "k_quant_mixed". + matmul_nodes: list of MatMul node names from the ONNX graph. + encoder_layers: number of encoder layers (needed for k_quant_mixed). + decoder_layers: number of decoder layers (needed for k_quant_mixed). + + Returns: + KQuantWeightOnlyQuantConfig with appropriate customized_weight_config. + """ customized_weight_config = {} - quant_algo_config = None - # need to use k_quant for int8 if precision == Precision.INT8: + # INT8: set every MatMul to 8-bit for node_name in matmul_nodes: customized_weight_config[node_name] = {"bits": 8} - quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) - else: - quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) + elif precision == Precision.INT4 and quant_method == "k_quant_mixed": + # k_quant_mixed: sensitive layers at INT8, rest at INT4 + sensitive_names = get_sensitive_node_names(matmul_nodes, encoder_layers, decoder_layers) + for node_name in sensitive_names: + customized_weight_config[node_name] = {"bits": 8} + logger.info( + f"k_quant_mixed: {len(sensitive_names)} sensitive nodes (INT8) " + f"out of {len(matmul_nodes)} total MatMul nodes" + ) + for name in sensitive_names: + logger.info(f" INT8: {name}") - return quant_algo_config + return KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) def export_onnx_models( @@ -356,6 +470,7 @@ def export_onnx_models( accuracy_level: int = 0, quantize_symmetric: bool = False, provider: str = "cpu", + quant_method: str = "k_quant", ): device = torch.device("cuda" if use_gpu else "cpu") if not use_gpu: @@ -458,7 +573,13 @@ def export_onnx_models( if precision in (Precision.INT8, Precision.INT4): onnx_model = onnx.load(onnx_path, load_external_data=True) matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"] - quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes) + quant_algo_config = make_quant_algo_config( + precision, + quant_method, + matmul_nodes, + encoder_layers=config.encoder_layers, + decoder_layers=config.decoder_layers, + ) quant = MatMulNBitsQuantizer( model=onnx_model, @@ -533,6 +654,7 @@ def main(argv=None): args.accuracy_level, args.quantize_symmetric, args.provider, + args.quant_method, ) max_diff = 0 diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index ec3a7a2f39928..2ab9f85258a4f 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -9,7 +9,11 @@ import onnx import torch -from transformers.modeling_utils import Conv1D + +try: + from transformers.pytorch_utils import Conv1D +except ImportError: + from transformers.modeling_utils import Conv1D logger = logging.getLogger(__name__) diff --git a/onnxruntime/test/autoep/test_handle_leak.cc b/onnxruntime/test/autoep/test_handle_leak.cc new file mode 100644 index 0000000000000..c3e9da7e746fc --- /dev/null +++ b/onnxruntime/test/autoep/test_handle_leak.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/file_util.h" +#include "test/util/include/temp_dir.h" + +#if defined(_WIN32) +#include +#else +#include +#endif + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +namespace { + +// Returns whether the library is currently mapped in the process, or std::nullopt if the platform +// does not support querying loaded-library state without side effects. +// On Windows, GetModuleHandleW queries by filename without incrementing the refcount. +// On Linux/macOS, dlopen with RTLD_NOLOAD probes without loading; if it succeeds it adds a +// refcount that we immediately release with dlclose. +std::optional IsLibraryLoaded(const std::filesystem::path& library_path) { +#if defined(_WIN32) + return GetModuleHandleW(library_path.filename().wstring().c_str()) != nullptr; +#else +#ifdef RTLD_NOLOAD + void* handle = dlopen(library_path.c_str(), RTLD_NOLOAD | RTLD_NOW); + if (handle) { + dlclose(handle); // Undo the refcount added by the RTLD_NOLOAD probe. + return true; + } + return false; +#else + // RTLD_NOLOAD is not available on this platform; cannot probe without loading. + static_cast(library_path); + return std::nullopt; +#endif +#endif +} + +} // namespace + +// Verify that registering and unregistering a plugin EP library does not leak the library handle. +// +// ProviderLibrary::Load() loads the library then probes for the "GetProvider" symbol. Most plugin EP +// libraries do not export "GetProvider", so the probe fails. Without the fix (PR #28396), +// Load() returned the error without calling Unload(), leaving a leaked refcount. After +// UnregisterExecutionProviderLibrary released only the EpLibraryPlugin's reference, the library +// remained mapped in the process. +// +// To ensure this test is independent of process state (other tests may load the same EP library), +// we copy the library to a temporary directory with a unique filename. This guarantees the copy +// has never been loaded, so we can reliably detect refcount leaks via IsLibraryLoaded. +TEST(OrtEpLibrary, RegisterUnregisterDoesNotLeakLibraryHandle) { + const std::filesystem::path& original_library_path = Utils::example_ep_info.library_path; + + // Use a unique registration name to avoid conflicts with other tests that may have + // registered the same EP library and failed to unregister it. + const std::string registration_name = "handle_leak_test_ep"; + + // Copy the EP library to the temp directory with a unique filename so it is guaranteed to + // not already be loaded in this process. + TemporaryDirectory temp_dir(ORT_TSTR("test_handle_leak_temp")); + const std::filesystem::path temp_library_path = + std::filesystem::path(temp_dir.Path()) / + GetSharedLibraryFileName(ORT_TSTR("handle_leak_test_ep")); + + std::error_code ec; + std::filesystem::copy_file(original_library_path, temp_library_path, + std::filesystem::copy_options::overwrite_existing, ec); + ASSERT_FALSE(ec) << "Failed to copy EP library to temp directory: " << ec.message(); + + std::optional loaded_before = IsLibraryLoaded(temp_library_path); + if (!loaded_before.has_value()) { + GTEST_SKIP() << "Platform does not support querying loaded-library state."; + } + + // The copy should not be loaded yet since we just created it with a unique name. + ASSERT_FALSE(*loaded_before) << "Freshly copied library should not already be loaded in the process."; + + // Register the plugin EP library inside a smaller scope so that the gsl::finally cleanup + // calls UnregisterExecutionProviderLibrary exactly once when leaving the scope. + { + ASSERT_NO_THROW(ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), temp_library_path.c_str())); + auto cleanup_lib = gsl::finally([®istration_name] { + Ort::Status ignored{Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, registration_name.c_str())}; + }); + + // The library should be loaded now. + ASSERT_TRUE(IsLibraryLoaded(temp_library_path).value_or(false)) << "Library should be loaded after registration."; + } + + // If the fix is applied, the library should be fully unloaded (refcount == 0). + // Without the fix, ProviderLibrary::Load() leaks a refcount so the library remains mapped. + EXPECT_FALSE(IsLibraryLoaded(temp_library_path).value_or(true)) + << "Library handle leaked: EP library is still loaded after UnregisterExecutionProviderLibrary. " + "This indicates ProviderLibrary::Load() did not call Unload() on GetProvider symbol miss."; +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 7cdbad3ef80a7..2451f7e03a281 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -933,5 +933,33 @@ TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) { TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false, /* use_cuda = */ false); } +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cache_indirection_beam_index_out_of_range) { + OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", 1); + tester.AddAttribute("past_present_share_buffer", 1); + + tester.AddInput("query", {2, 1, 4}, std::vector(8, 0.1f)); + tester.AddInput("key", {2, 1, 4}, std::vector(8, 0.2f)); + tester.AddInput("value", {2, 1, 4}, std::vector(8, 0.3f)); + tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); + tester.AddInput("past_key", {2, 1, 4, 4}, std::vector(32, 0.4f)); + tester.AddInput("past_value", {2, 1, 4, 4}, std::vector(32, 0.5f)); + tester.AddInput("past_sequence_length", {1}, {2}); + tester.AddInput("beam_width", {1}, {2}); + tester.AddInput("cache_indirection", {1, 2, 4}, {0, 2, 0, 0, 0, 0, 0, 0}); + tester.AddOptionalInputEdge(); + + tester.AddOutput("output", {2, 1, 4}, std::vector(8, 0.0f)); + tester.AddOutput("present_key", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOutput("present_value", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, "cache_indirection beam index out of range", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/maxpool_mask_test.cc b/onnxruntime/test/contrib_ops/maxpool_mask_test.cc index 7ff1c9919fcad..ed65700ccd336 100644 --- a/onnxruntime/test/contrib_ops/maxpool_mask_test.cc +++ b/onnxruntime/test/contrib_ops/maxpool_mask_test.cc @@ -80,5 +80,86 @@ TEST(ContribOpTest, MaxPoolWithMask) { test.Run(); } +TEST(ContribOpTest, MaxPoolWithMask_SpatialDimMismatch) { + OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain); + + test.AddAttribute("auto_pad", ""); + test.AddAttribute("strides", std::vector{1, 1}); + test.AddAttribute("pads", std::vector{0, 0, 0, 0}); + test.AddAttribute("kernel_shape", std::vector{8, 8}); + + // Input X has shape {1, 1, 8, 8} + std::vector x_dims = {1, 1, 8, 8}; + std::vector x_vals(64, 1.0f); + + // Mask M has wrong spatial dimensions: {1, 1, 4, 8} instead of {1, 1, 8, 8} + std::vector m_dims = {1, 1, 4, 8}; + std::vector m_vals(32, 1); + + // Placeholder output shape and values (not validated since we expect failure) + std::vector expected_dims = {1, 1, 1, 1}; + std::vector expected_vals = {1.0f}; + + test.AddInput("X", x_dims, x_vals); + test.AddInput("M", m_dims, m_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(BaseTester::ExpectResult::kExpectFailure, + "Mask and input spatial dimensions mismatch at dimension 2"); +} + +TEST(ContribOpTest, MaxPoolWithMask_DimCountMismatch) { + OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain); + + test.AddAttribute("auto_pad", ""); + test.AddAttribute("strides", std::vector{1, 1}); + test.AddAttribute("pads", std::vector{0, 0, 0, 0}); + test.AddAttribute("kernel_shape", std::vector{8, 8}); + + // Input X has shape {1, 1, 8, 8} (4D) + std::vector x_dims = {1, 1, 8, 8}; + std::vector x_vals(64, 1.0f); + + // Mask M has wrong number of dimensions: {1, 1, 8} (3D) instead of 4D + std::vector m_dims = {1, 1, 8}; + std::vector m_vals(8, 1); + + // Placeholder output shape and values (not validated since we expect failure) + std::vector expected_dims = {1, 1, 1, 1}; + std::vector expected_vals = {1.0f}; + + test.AddInput("X", x_dims, x_vals); + test.AddInput("M", m_dims, m_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(BaseTester::ExpectResult::kExpectFailure, + "Mask and input must have the same number of dimensions"); +} + +TEST(ContribOpTest, MaxPoolWithMask_MaskEmptyBatchDim) { + OpTester test("MaxpoolWithMask", 1, onnxruntime::kMSDomain); + + test.AddAttribute("auto_pad", ""); + test.AddAttribute("strides", std::vector{1, 1}); + test.AddAttribute("pads", std::vector{0, 0, 0, 0}); + test.AddAttribute("kernel_shape", std::vector{8, 8}); + + // Input X has shape {1, 1, 8, 8} (non-empty) + std::vector x_dims = {1, 1, 8, 8}; + std::vector x_vals(64, 1.0f); + + // Mask M has N=0: should trigger the nonzero N/C guard + std::vector m_dims = {0, 1, 8, 8}; + std::vector m_vals; // 0 elements + + // Placeholder output shape and values (not validated since we expect failure) + std::vector expected_dims = {1, 1, 1, 1}; + std::vector expected_vals = {1.0f}; + + test.AddInput("X", x_dims, x_vals); + test.AddInput("M", m_dims, m_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(BaseTester::ExpectResult::kExpectFailure, + "Mask N and C dimensions must be greater than 0"); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index c740959105977..9c974e03119f9 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -562,6 +562,58 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, } } +TEST(MultiHeadAttentionTest, CacheIndirectionBeamIndexOutOfRange) { + OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", 1); + + tester.AddInput("query", {2, 1, 4}, std::vector(8, 0.1f)); + tester.AddInput("key", {2, 1, 4}, std::vector(8, 0.2f)); + tester.AddInput("value", {2, 1, 4}, std::vector(8, 0.3f)); + tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); + tester.AddInput("past_key", {2, 1, 4, 4}, std::vector(32, 0.4f)); + tester.AddInput("past_value", {2, 1, 4, 4}, std::vector(32, 0.5f)); + tester.AddInput("past_sequence_length", {1}, {2}); + tester.AddInput("cache_indirection", {1, 2, 4}, {0, 2, 0, 0, 0, 0, 0, 0}); + + tester.AddOutput("output", {2, 1, 4}, std::vector(8, 0.0f)); + tester.AddOutput("present_key", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOutput("present_value", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, "cache_indirection beam index out of range", + {}, nullptr, &execution_providers); +} + +TEST(MultiHeadAttentionTest, CacheIndirectionBeamWidthOneInvalidIndex) { + OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", 1); + + tester.AddInput("query", {2, 1, 4}, std::vector(8, 0.1f)); + tester.AddInput("key", {2, 1, 4}, std::vector(8, 0.2f)); + tester.AddInput("value", {2, 1, 4}, std::vector(8, 0.3f)); + tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); + tester.AddInput("past_key", {2, 1, 4, 4}, std::vector(32, 0.4f)); + tester.AddInput("past_value", {2, 1, 4, 4}, std::vector(32, 0.5f)); + tester.AddInput("past_sequence_length", {1}, {2}); + tester.AddInput("cache_indirection", {2, 1, 4}, {0, 1, 0, 0, 0, 0, 0, 0}); + + tester.AddOutput("output", {2, 1, 4}, std::vector(8, 0.0f)); + tester.AddOutput("present_key", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOutput("present_value", {2, 1, 4, 4}, std::vector(32, 0.0f)); + tester.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, "cache_indirection beam index out of range", + {}, nullptr, &execution_providers); +} + // Test fused cross attention kernel // It requires head_size > 32 and head_size <= 64 for T4 GPU; hidden_size == v_hidden_size. TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 93f2ea704a729..ee3b0a6ec2133 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -1,15 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include + #include "core/graph/onnx_protobuf.h" +#include "onnx/checker.h" #include "onnx/defs/parser.h" #include "core/common/span_utils.h" #include "core/framework/customregistry.h" #include "core/framework/op_kernel.h" #include "core/graph/model.h" +#include "core/graph/model_helpers.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" @@ -87,6 +92,34 @@ static void Check(const char* source, } } +static Status LoadModel(const char* source) { + ONNX_NAMESPACE::OnnxParser parser(source); + ONNX_NAMESPACE::ModelProto model; + auto parse_status = parser.Parse(model); + if (!parse_status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to parse test model: ", parse_status.ErrorMessage()); + } + if (!parser.EndOfInput()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Extra unparsed input unexpected."); + } + + try { + ONNX_NAMESPACE::checker::check_model(model); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX model check failed: ", e.what()); + } + + std::string serialized_model; + if (!model.SerializeToString(&serialized_model) || serialized_model.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to serialize test model."); + } + + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + std::istringstream sstr(serialized_model); + return session_object.Load(sstr); +} + namespace { const char* basic_code = R"( < @@ -303,6 +336,370 @@ TEST(FunctionTest, CallInConditional) { Check(code, "x", {1.0, 2.0, 3.0}, "y", {6.0, 12.0, 18.0}); } +TEST(FunctionTest, RejectsSelfRecursiveLocalFunction) { + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.self_recursive (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + self_recursive (lx) => (ly) { + ly = local.self_recursive (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, RejectsMutuallyRecursiveLocalFunctions) { + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.first (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + first (lx) => (ly) { + ly = local.second (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + second (lx) => (ly) { + ly = local.first (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, RejectsRecursionThroughSubgraph) { + // A local function that calls itself inside an If subgraph (then_branch). + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.recursive_if (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + recursive_if (lx) => (ly) { + temp = Identity (lx) + cond = Constant () + ly = If (cond) < + then_branch = then_graph () => (float[N] then_out) + { + then_out = local.recursive_if (temp) + }, + else_branch = else_graph () => (float[N] else_out) + { + else_out = Identity (temp) + } + > + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +// --- Synthetic adjacency-list tests for ValidateCallGraphAcyclic --- +// These test the cycle detection algorithm directly without constructing ONNX models. + +TEST(FunctionTest, CallGraphAcyclic_EmptyGraph) { + onnxruntime::LocalFunctionCallGraph call_graph; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_SingleNodeNoCalls) { + // Single function with no callees. + std::string a = "A"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_SelfCycle) { + std::string a = "A"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("A -> A")); +} + +TEST(FunctionTest, CallGraphAcyclic_MutualCycle) { + std::string a = "A", b = "B"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, CallGraphAcyclic_LongerCycle) { + // A -> B -> C -> A + std::string a = "A", b = "B", c = "C"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {c}; + call_graph[c] = {a}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); + // The cycle path should include all three participants. + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("A")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("B")); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("C")); +} + +TEST(FunctionTest, CallGraphAcyclic_DiamondNoCycle) { + // A -> B, A -> C, B -> D, C -> D (no cycle) + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b, c}; + call_graph[b] = {d}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_DeepChainNoCycle) { + // A -> B -> C -> D (no cycle) + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {c}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +TEST(FunctionTest, CallGraphAcyclic_MultipleIndependentCycles) { + // Two independent cycles: A -> B -> A, C -> D -> C + std::string a = "A", b = "B", c = "C", d = "D"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[a] = {b}; + call_graph[b] = {a}; + call_graph[c] = {d}; + call_graph[d] = {c}; + const auto status = onnxruntime::ValidateCallGraphAcyclic(call_graph); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, CallGraphAcyclic_SharedCallsDiamondNoCycle) { + // Regression test: acyclic model with shared function calls (diamond pattern). + // E -> A, E -> B, A -> C, B -> C, C -> D (no cycle despite shared references to C) + std::string a = "A", b = "B", c = "C", d = "D", e = "E"; + onnxruntime::LocalFunctionCallGraph call_graph; + call_graph[e] = {a, b}; + call_graph[a] = {c}; + call_graph[b] = {c}; + call_graph[c] = {d}; + call_graph[d] = {}; + ASSERT_STATUS_OK(onnxruntime::ValidateCallGraphAcyclic(call_graph)); +} + +// --- Model-level integration tests --- + +TEST(FunctionTest, RejectsLongerCycle) { + // A -> B -> C -> A (three-function cycle) + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.func_a (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + ly = local.func_b (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_c (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_a (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + +TEST(FunctionTest, AcceptsAcyclicDiamond) { + // A -> B, A -> C, B -> D, C -> D (diamond, no cycle) + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.func_a (x) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + t1 = local.func_b (lx) + ly = local.func_c (t1) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16 ], + domain: "local" + > + func_d (lx) => (ly) { + ly = Identity (lx) + } + )"; + + ASSERT_STATUS_OK(LoadModel(code)); +} + +TEST(FunctionTest, AcceptsTrivialSingleNodeFunction) { + // A local function with a single Identity node — verifies that trivial + // (but non-empty) function bodies pass acyclicity validation. + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y = local.trivial_func (x) + } + + < + opset_import: [ "" : 16 ], + domain: "local" + > + trivial_func (lx) => (ly) { + ly = Identity (lx) + } + )"; + + ASSERT_STATUS_OK(LoadModel(code)); +} + +TEST(FunctionTest, RejectsMultipleIndependentCycles) { + // Two independent cycles in the same model: A -> B -> A, C -> D -> C + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + t = local.func_a (x) + y = local.func_c (t) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_a (lx) => (ly) { + ly = local.func_b (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_b (lx) => (ly) { + ly = local.func_a (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_c (lx) => (ly) { + ly = local.func_d (lx) + } + + < + opset_import: [ "" : 16, "local" : 1 ], + domain: "local" + > + func_d (lx) => (ly) { + ly = local.func_c (lx) + } + )"; + + const auto status = LoadModel(code); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("must not be recursive")); +} + // Test use of attibute references, especially where source/target attribute // names are not the same. In this example, the "start : int = @s" attribute-reference // binds the attribute named "start" of the Shape op to the attribute named "s" diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 59ec8f51b4f4e..9efaed8ac7bd6 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -2539,6 +2539,284 @@ TEST(SparseTensorConversionTests, SparseCooToDense_2DRowOutOfRange) { EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid COO 2D index")); } +// Positive tests for SparseTensorProtoToDenseTensorProto with external data. +// These verify end-to-end conversion succeeds when values and/or indices are stored +// in legitimate external files within the model directory. + +// Helper: write data to a temp file and configure a TensorProto to reference it as external data. +// The file is created in the current working directory using CreateTestFile. +// The ScopedFileDeleter is assigned immediately after file creation to ensure cleanup on any failure. +template +static void SetupExternalDataTensor(TensorProto_DataType type, + const std::vector& data, + PathString& filename, + TensorProto& tensor_proto, + ScopedFileDeleter& file_deleter) { + size_t size_in_bytes = data.size() * sizeof(T); + std::vector le_data(size_in_bytes); + + auto src_span = gsl::make_span(data.data(), data.size()); + auto dst_span = gsl::make_span(le_data.data(), le_data.size()); + ASSERT_STATUS_OK(onnxruntime::utils::WriteLittleEndian(src_span, dst_span)); + + FILE* fp; + CreateTestFile(fp, filename); + file_deleter = ScopedFileDeleter(filename); + ASSERT_EQ(size_in_bytes, fwrite(le_data.data(), 1, size_in_bytes, fp)); + ASSERT_EQ(0, fclose(fp)); + + tensor_proto.set_data_type(type); + tensor_proto.set_data_location(TensorProto_DataLocation_EXTERNAL); + + auto* loc = tensor_proto.mutable_external_data()->Add(); + loc->set_key("location"); + loc->set_value(ToUTF8String(filename)); + + auto* len = tensor_proto.mutable_external_data()->Add(); + len->set_key("length"); + len->set_value(std::to_string(size_in_bytes)); +} + +// External values + inline indices (INT64), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ExternalValues_InlineIndices) { + // Dense shape [2, 3] = 6 elements. + // NNZ=3 values at linear indices [0, 2, 5]. + // Expected dense: [1.0, 0, 2.0, 0, 0, 3.0] + std::vector values = {1.0f, 2.0f, 3.0f}; + PathString values_file(ORT_TSTR("ext_val_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(2); + sparse.add_dims(3); + + ScopedFileDeleter values_deleter; + SetupExternalDataTensor(TensorProto_DataType_FLOAT, values, values_file, *sparse.mutable_values(), + values_deleter); + sparse.mutable_values()->set_name("ext_values_test"); + sparse.mutable_values()->add_dims(3); // NNZ + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(TensorProto_DataType_INT64); + indices->add_dims(3); + indices->add_int64_data(0); + indices->add_int64_data(2); + indices->add_int64_data(5); + + // model_path in CWD so external files are within the model directory + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + ASSERT_EQ(dense.dims_size(), 2); + EXPECT_EQ(dense.dims(0), 2); + EXPECT_EQ(dense.dims(1), 3); + + std::vector unpacked(6); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {1.0f, 0.0f, 2.0f, 0.0f, 0.0f, 3.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Inline values + external indices (INT64), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InlineValues_ExternalIndicesInt64) { + // Dense shape [4] = 4 elements. + // NNZ=2 at indices [1, 3]. + // Expected dense: [0, 10.0, 0, 20.0] + std::vector indices_data = {1, 3}; + PathString indices_file(ORT_TSTR("ext_idx_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("ext_indices_test"); + values->set_data_type(TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(10.0f); + values->add_float_data(20.0f); + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT64, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + std::vector unpacked(4); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {0.0f, 10.0f, 0.0f, 20.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Inline values + external indices (INT32), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InlineValues_ExternalIndicesInt32) { + std::vector indices_data = {0, 3}; + PathString indices_file(ORT_TSTR("ext_i32_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(2); + sparse.add_dims(2); + + auto* values = sparse.mutable_values(); + values->set_name("ext_int32_idx_test"); + values->set_data_type(TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(5.0f); + values->add_float_data(6.0f); + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT32, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + std::vector unpacked(4); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {5.0f, 0.0f, 0.0f, 6.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Inline values + external indices (INT16), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InlineValues_ExternalIndicesInt16) { + std::vector indices_data = {1, 2}; + PathString indices_file(ORT_TSTR("ext_i16_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("ext_int16_idx_test"); + values->set_data_type(TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(7.0f); + values->add_float_data(8.0f); + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT16, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + std::vector unpacked(4); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {0.0f, 7.0f, 8.0f, 0.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Inline values + external indices (INT8), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InlineValues_ExternalIndicesInt8) { + std::vector indices_data = {0, 2}; + PathString indices_file(ORT_TSTR("ext_i8_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(3); + + auto* values = sparse.mutable_values(); + values->set_name("ext_int8_idx_test"); + values->set_data_type(TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(9.0f); + values->add_float_data(11.0f); + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT8, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + std::vector unpacked(3); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {9.0f, 0.0f, 11.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Both external values and external indices (INT64), rank-1 COO. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ExternalValues_ExternalIndicesInt64) { + // Dense shape [3, 2] = 6 elements. + // NNZ=2 at linear indices [1, 4]. + // Expected dense: [0, 100.0, 0, 0, 200.0, 0] + std::vector values_data = {100.0f, 200.0f}; + std::vector indices_data = {1, 4}; + PathString values_file(ORT_TSTR("ext_bv_XXXXXX")); + PathString indices_file(ORT_TSTR("ext_bi_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(3); + sparse.add_dims(2); + + ScopedFileDeleter values_deleter; + SetupExternalDataTensor(TensorProto_DataType_FLOAT, values_data, values_file, *sparse.mutable_values(), + values_deleter); + sparse.mutable_values()->set_name("ext_both_test"); + sparse.mutable_values()->add_dims(2); + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT64, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + ASSERT_EQ(dense.dims_size(), 2); + EXPECT_EQ(dense.dims(0), 3); + EXPECT_EQ(dense.dims(1), 2); + + std::vector unpacked(6); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {0.0f, 100.0f, 0.0f, 0.0f, 200.0f, 0.0f}; + EXPECT_EQ(unpacked, expected); +} + +// Both external values and external indices (INT64), rank-2 COO indices. +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ExternalValues_ExternalIndicesInt64_Rank2) { + // Dense shape [3, 3] = 9 elements. + // NNZ=2 with 2D indices: [[0, 2], [2, 0]] -> positions (0,2)=2, (2,0)=6. + // Expected dense: [0, 0, 50.0, 0, 0, 0, 60.0, 0, 0] + std::vector values_data = {50.0f, 60.0f}; + // Rank-2 indices: flattened as [row0, col0, row1, col1] + std::vector indices_data = {0, 2, 2, 0}; + PathString values_file(ORT_TSTR("ext_r2v_XXXXXX")); + PathString indices_file(ORT_TSTR("ext_r2i_XXXXXX")); + + SparseTensorProto sparse; + sparse.add_dims(3); + sparse.add_dims(3); + + ScopedFileDeleter values_deleter; + SetupExternalDataTensor(TensorProto_DataType_FLOAT, values_data, values_file, *sparse.mutable_values(), + values_deleter); + sparse.mutable_values()->set_name("ext_rank2_test"); + sparse.mutable_values()->add_dims(2); // NNZ + + ScopedFileDeleter indices_deleter; + SetupExternalDataTensor(TensorProto_DataType_INT64, indices_data, indices_file, + *sparse.mutable_indices(), indices_deleter); + sparse.mutable_indices()->add_dims(2); // NNZ + sparse.mutable_indices()->add_dims(2); // rank of dense tensor + + std::filesystem::path model_path = std::filesystem::current_path() / "model.onnx"; + TensorProto dense; + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense)); + + std::vector unpacked(9); + ASSERT_STATUS_OK(utils::UnpackTensor(dense, model_path, unpacked.data(), unpacked.size())); + std::vector expected = {0.0f, 0.0f, 50.0f, 0.0f, 0.0f, 0.0f, 60.0f, 0.0f, 0.0f}; + EXPECT_EQ(unpacked, expected); +} + #endif // !defined(DISABLE_SPARSE_TENSORS) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 71ac5b49e9718..06cc3ea6ad8d2 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -835,6 +835,324 @@ TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_ResolvesDotDot) { } #endif // defined(_WIN32) +#if !defined(DISABLE_SPARSE_TENSORS) +// Regression test: SparseTensorProtoToDenseTensorProto must reject external_data paths +// that escape the model directory (path traversal via "../" in location). +TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_Values) { + // Create model directory and a "secret" file outside it. + auto model_dir = base_dir_ / "model_dir"; + std::error_code ec; + std::filesystem::create_directories(model_dir, ec); + ASSERT_FALSE(ec) << "Failed to create model_dir: " << ec.message(); + + // Write known float data to a file outside the model directory. + auto secret_file = base_dir_ / "secret.bin"; + { + std::ofstream ofs(secret_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << secret_file; + float secret_data[] = {42.0f, 99.0f}; + ofs.write(reinterpret_cast(secret_data), sizeof(secret_data)); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << secret_file; + } + + // Construct a SparseTensorProto whose values use external data with a path-traversal location. + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); // dense shape: [4] + + // Values tensor: 2 non-zero float values stored in external file. + auto* values = sparse.mutable_values(); + values->set_name("sparse_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); // 2 non-zero elements + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = values->add_external_data(); + loc->set_key("location"); + loc->set_value("../secret.bin"); // path traversal! + + auto* len_entry = values->add_external_data(); + len_entry->set_key("length"); + len_entry->set_value(std::to_string(2 * sizeof(float))); + + // Indices: positions 0 and 1 in the dense tensor. + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->add_int64_data(0); + indices->add_int64_data(1); + + // Attempt to convert — this should fail with a path validation error. + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = model_dir / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject path-traversal " + "in values external_data location, but it succeeded (reading " + "arbitrary file outside model directory)."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + +// Same as above but for path traversal in the indices external data. +TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_Indices) { + auto model_dir = base_dir_ / "model_dir"; + std::error_code ec; + std::filesystem::create_directories(model_dir, ec); + ASSERT_FALSE(ec) << "Failed to create model_dir: " << ec.message(); + + // Write indices data (2 x int64) to a file outside the model directory. + auto secret_file = base_dir_ / "indices_secret.bin"; + { + std::ofstream ofs(secret_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << secret_file; + int64_t idx_data[] = {0, 1}; + ofs.write(reinterpret_cast(idx_data), sizeof(idx_data)); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << secret_file; + } + + // Also need a valid values file inside the model directory. + auto values_file = model_dir / "values.bin"; + { + std::ofstream ofs(values_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << values_file; + float val_data[] = {1.0f, 2.0f}; + ofs.write(reinterpret_cast(val_data), sizeof(val_data)); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << values_file; + } + + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + // Values: legitimate external data within model directory. + auto* values = sparse.mutable_values(); + values->set_name("sparse_idx_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* val_loc = values->add_external_data(); + val_loc->set_key("location"); + val_loc->set_value("values.bin"); + + auto* val_len = values->add_external_data(); + val_len->set_key("length"); + val_len->set_value(std::to_string(2 * sizeof(float))); + + // Indices: external data with path traversal. + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* idx_loc = indices->add_external_data(); + idx_loc->set_key("location"); + idx_loc->set_value("../indices_secret.bin"); // path traversal! + + auto* idx_len = indices->add_external_data(); + idx_len->set_key("length"); + idx_len->set_value(std::to_string(2 * sizeof(int64_t))); + + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = model_dir / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject path-traversal " + "in indices external_data location, but it succeeded."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + +// Regression test: SparseTensorProtoToDenseTensorProto must reject absolute paths +// in values external_data location. +TEST_F(PathValidationTest, SparseTensorExternalDataAbsolutePathBlocked_Values) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("abs_path_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = values->add_external_data(); + loc->set_key("location"); + loc->set_value("/data.bin"); // absolute path + + auto* len_entry = values->add_external_data(); + len_entry->set_key("length"); + len_entry->set_value(std::to_string(2 * sizeof(float))); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->add_int64_data(0); + indices->add_int64_data(1); + + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = base_dir_ / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject absolute path " + "in values external_data location."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); + +#ifdef _WIN32 + // Also verify Windows-style absolute path. + loc->set_value("C:\\data.bin"); + status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject Windows absolute path " + "in values external_data location."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); +#endif +} + +// Regression test: SparseTensorProtoToDenseTensorProto must reject absolute paths +// in indices external_data location. +TEST_F(PathValidationTest, SparseTensorExternalDataAbsolutePathBlocked_Indices) { + // Create a valid values file inside base_dir_ so values validation passes. + auto values_file = base_dir_ / "values.bin"; + { + std::ofstream ofs(values_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << values_file; + float val_data[] = {1.0f, 2.0f}; + ofs.write(reinterpret_cast(val_data), sizeof(val_data)); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << values_file; + } + + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + // Values: legitimate external data within base_dir_. + auto* values = sparse.mutable_values(); + values->set_name("abs_path_idx_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* val_loc = values->add_external_data(); + val_loc->set_key("location"); + val_loc->set_value("values.bin"); + + auto* val_len = values->add_external_data(); + val_len->set_key("length"); + val_len->set_value(std::to_string(2 * sizeof(float))); + + // Indices: external data with absolute path. + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* idx_loc = indices->add_external_data(); + idx_loc->set_key("location"); + idx_loc->set_value("/data.bin"); // absolute path + + auto* idx_len = indices->add_external_data(); + idx_len->set_key("length"); + idx_len->set_value(std::to_string(2 * sizeof(int64_t))); + + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = base_dir_ / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject absolute path " + "in indices external_data location."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); + +#ifdef _WIN32 + idx_loc->set_value("C:\\data.bin"); + status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "SparseTensorProtoToDenseTensorProto should reject Windows absolute path " + "in indices external_data location."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); +#endif +} + +// Regression test: validation must still reject escaping paths for zero-element dense tensors, +// which previously returned early before path validation ran. +TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_ZeroDenseElements) { + auto model_dir = base_dir_ / "model_dir"; + std::error_code ec; + std::filesystem::create_directories(model_dir, ec); + ASSERT_FALSE(ec) << "Failed to create model_dir: " << ec.message(); + + // Create the escaping file so that a "file not found" error would NOT be raised. + auto secret_file = base_dir_ / "secret.bin"; + { + std::ofstream ofs(secret_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << secret_file; + ofs.put('\0'); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << secret_file; + } + + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(0); // dense shape [0] → dense_elements == 0 + + auto* values = sparse.mutable_values(); + values->set_name("zero_dense_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(0); // NNZ=0 + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = values->add_external_data(); + loc->set_key("location"); + loc->set_value("../secret.bin"); // path traversal + + auto* len_entry = values->add_external_data(); + len_entry->set_key("length"); + len_entry->set_value("0"); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(0); + + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = model_dir / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "Should reject path-traversal in values even when dense_elements == 0."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + +// Regression test: validation must reject escaping paths in indices even when NNZ == 0. +TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_ZeroNNZ) { + auto model_dir = base_dir_ / "model_dir"; + std::error_code ec; + std::filesystem::create_directories(model_dir, ec); + ASSERT_FALSE(ec) << "Failed to create model_dir: " << ec.message(); + + // Create the escaping file so that a "file not found" error would NOT be raised. + auto secret_file = base_dir_ / "indices_secret.bin"; + { + std::ofstream ofs(secret_file, std::ios::binary); + ASSERT_TRUE(ofs.is_open()) << "Failed to open " << secret_file; + ofs.put('\0'); + ASSERT_TRUE(ofs.good()) << "Failed to write to " << secret_file; + } + + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); // dense shape [4] → non-zero dense_elements + + auto* values = sparse.mutable_values(); + values->set_name("zero_nnz_test"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(0); // NNZ=0 + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(0); + indices->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* idx_loc = indices->add_external_data(); + idx_loc->set_key("location"); + idx_loc->set_value("../indices_secret.bin"); // path traversal + + auto* idx_len = indices->add_external_data(); + idx_len->set_key("length"); + idx_len->set_value("0"); + + ONNX_NAMESPACE::TensorProto dense; + std::filesystem::path model_path = model_dir / "model.onnx"; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, model_path, dense); + ASSERT_FALSE(status.IsOK()) << "Should reject path-traversal in indices even when NNZ == 0."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + +#endif // !defined(DISABLE_SPARSE_TENSORS) + TEST(TensorProtoUtilsTest, GetNodeProtoLayeringAnnotation) { // Case 1: Annotation exists { diff --git a/onnxruntime/test/mlas/unittest/test_activation.cpp b/onnxruntime/test/mlas/unittest/test_activation.cpp index a4334c6c80477..73d18d8a7dc38 100644 --- a/onnxruntime/test/mlas/unittest/test_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_activation.cpp @@ -247,7 +247,8 @@ class MlasActivationTest : public MlasTestBase { for (unsigned i = 0; i < _countof(TestData); i++) { // Sensitive to comparing positive/negative zero and NaNs. float error = std::min(std::fabs((Buffer[i].f - TestData[i][kind].f) / TestData[i][kind].f), std::fabs(Buffer[i].f - TestData[i][kind].f)); - EXPECT_TRUE(Buffer[i].u == TestData[i][kind].u || Buffer[i].f == TestData[i][kind].f || error < 0.000001f) + EXPECT_TRUE(Buffer[i].u == TestData[i][kind].u || Buffer[i].f == TestData[i][kind].f || error < 0.000001f || + (std::isnan(Buffer[i].f) && std::isnan(TestData[i][kind].f))) << ", Scalar Activation Kind:" << (int)kind << ", i=" << i << ", value:" << std::setw(8) << std::setfill('0') << std::hex << Buffer[i].u << ", expecting:" << std::setw(8) << std::setfill('0') << std::hex << TestData[i][kind].u; diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 0e14bc59a09c9..038a8eaade116 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -3127,6 +3127,63 @@ TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) { } } +// Correctness test for Cast kernel with a moderately large tensor. +// Exercises the same kernel code path as tensors > 2^31 elements but stays within +// CI GPU memory limits. For the actual overflow scenario, see the host-side test below. +TEST(CastOpTest, CastKernelCorrectness_ModerateSize) { + constexpr int64_t num_elements = 1 << 24; // 16M elements + const std::vector shape = {num_elements}; + + std::vector input(num_elements); + std::vector expected(num_elements); + for (int64_t i = 0; i < num_elements; ++i) { + input[i] = static_cast(i % 1000); + expected[i] = static_cast(i % 1000); + } + + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape); +} + +// Host-side regression test that verifies the grid launch arithmetic uses 64-bit +// types for element counts exceeding INT32_MAX. This validates the fix without +// needing to allocate > 8 GB of GPU memory. +// The fix changed: +// CUDA_LONG N = static_cast(count) // was int32 truncation +// to: +// int64_t N = static_cast(count) // correct 64-bit +TEST(CastOpTest, CastKernel_Int64IndexArithmetic_NoOverflow) { + // Simulate the grid launch calculation from UnaryElementWiseImpl / CudaCastStd + // with a count that exceeds INT32_MAX. + constexpr size_t count = static_cast(INT32_MAX) + 65536; // 2^31 + 65536 + constexpr int maxThreadsPerBlock = 256; + constexpr int maxElementsPerThread = 4; + + // Verify N is correctly represented (not truncated to int32) + int64_t N = static_cast(count); + ASSERT_GT(N, static_cast(INT32_MAX)); + ASSERT_EQ(N, static_cast(count)); + + // Verify blocksPerGrid calculation doesn't overflow + // (uses size_t arithmetic for the divisor) + size_t elements_per_block = static_cast(maxThreadsPerBlock) * maxElementsPerThread; + int blocksPerGrid = static_cast((count + elements_per_block - 1) / elements_per_block); + ASSERT_GT(blocksPerGrid, 0); + // For count = 2^31 + 65536, elements_per_block = 1024, we expect ~2M blocks + ASSERT_EQ(blocksPerGrid, static_cast((count + 1023) / 1024)); + + // Verify that the per-thread index calculation doesn't overflow in int64_t + // Simulate the last block's thread 0: id = NumElementsPerThread * NumThreadsPerBlock * (blocksPerGrid-1) + 0 + int64_t last_block_start = static_cast(maxElementsPerThread) * maxThreadsPerBlock * + (blocksPerGrid - 1); + ASSERT_GT(last_block_start, 0); // Positive (no overflow) + ASSERT_LE(last_block_start, N); // Within bounds + + // Verify the old int32 code would have failed: + // static_cast(count) would silently wrap + int32_t truncated_N = static_cast(count); + ASSERT_LT(truncated_N, 0); // Proves the old code was broken (wraps negative) +} + #if !defined(DISABLE_FLOAT8_TYPES) float FloatFromBits(uint32_t bits) { diff --git a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc index 9ca081a74c850..47c9978b3a9c8 100644 --- a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc +++ b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc @@ -574,6 +574,89 @@ TEST(XnnpackEP, DISABLED_TestResize_u8_and_s8_NHWC_pytorch_half_pixel) { // [ON {ExpectedEPNodeAssignment::Some, 1e-2f /* fp32_abs_err */}); } +// Regression test for https://github.com/microsoft/onnxruntime/issues/28541. +// A two-input Gemm (no optional C bias) used to dereference a null NodeArg pointer in +// Gemm::IsOnnxNodeSupported, segfaulting InferenceSession::Initialize before any kernel +// ran. The capability check must accept the missing-C case and let the node be assigned +// to XNNPACK without crashing. +TEST(XnnpackEP, TestGemm_NoC_NoSegfault) { + const std::vector a_shape = {2, 3}; + const std::vector b_shape = {3, 4}; + auto modelBuilder = [&](ModelTestBuilder& builder) { + auto* input_a = builder.MakeInput(a_shape, -1.f, 1.f); + auto* input_b = builder.MakeInitializer(b_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + auto& gemm_node = builder.AddNode("Gemm", {input_a, input_b}, {output_arg}); + gemm_node.AddAttribute("alpha", 1.0f); + gemm_node.AddAttribute("beta", 1.0f); + gemm_node.AddAttribute("transA", static_cast(0)); + gemm_node.AddAttribute("transB", static_cast(0)); + }; + // ExpectedEPNodeAssignment::All asserts both that the session initialized without + // segfaulting AND that XNNPACK accepted the 2-input Gemm node. + RunModelTest(modelBuilder, "xnnpack_test_graph_gemm_no_c", + { + ExpectedEPNodeAssignment::All, + 1e-4f /* fp32_abs_err */, + }); +} + +// Regression test for https://github.com/microsoft/onnxruntime/issues/28542. +// A Gemm with a scalar (rank 0) C bias used to skip past the dim_size() >= 3 guard and +// then crash on C_shape->dim(0) inside Gemm::IsOnnxNodeSupported. The capability check +// must reject rank-0 C cleanly so the node falls back to the CPU EP and the session +// initializes without segfaulting. RunModelTest compares the XNNPACK + CPU run against +// the pure CPU baseline, so this also verifies numerical correctness end to end. +TEST(XnnpackEP, TestGemm_ScalarC_NoSegfault) { + const std::vector a_shape = {2, 3}; + const std::vector b_shape = {3, 4}; + auto modelBuilder = [&](ModelTestBuilder& builder) { + auto* input_a = builder.MakeInput(a_shape, -1.f, 1.f); + auto* input_b = builder.MakeInitializer(b_shape, -1.f, 1.f); + auto* input_c = builder.MakeScalarInitializer(0.5f); + auto* output_arg = builder.MakeOutput(); + auto& gemm_node = builder.AddNode("Gemm", {input_a, input_b, input_c}, {output_arg}); + gemm_node.AddAttribute("alpha", 1.0f); + gemm_node.AddAttribute("beta", 1.0f); + gemm_node.AddAttribute("transA", static_cast(0)); + gemm_node.AddAttribute("transB", static_cast(0)); + }; + RunModelTest(modelBuilder, "xnnpack_test_graph_gemm_scalar_c", + { + ExpectedEPNodeAssignment::None, + 1e-4f /* fp32_abs_err */, + }); +} + +// Defense-in-depth regression test for the C_arg->Exists() == false branch. A 3-input +// Gemm whose C slot is an empty optional input (Exists() == false) is semantically +// equivalent to a 2-input Gemm per ONNX (empty optional input == omitted input). The +// XNNPACK support check now treats them identically: both are accepted and routed to +// the bias=nullptr path in xnn_create_fully_connected_nc_*. This locks in the +// consistency between the 2-input case (TestGemm_NoC_NoSegfault) and the empty-optional +// case, matching the kernel constructor's C_matrix_exists_ = C_arg && C_arg->Exists() +// contract. +TEST(XnnpackEP, TestGemm_EmptyC_NoSegfault) { + const std::vector a_shape = {2, 3}; + const std::vector b_shape = {3, 4}; + auto modelBuilder = [&](ModelTestBuilder& builder) { + auto* input_a = builder.MakeInput(a_shape, -1.f, 1.f); + auto* input_b = builder.MakeInitializer(b_shape, -1.f, 1.f); + auto* input_c = builder.MakeEmptyInput(); + auto* output_arg = builder.MakeOutput(); + auto& gemm_node = builder.AddNode("Gemm", {input_a, input_b, input_c}, {output_arg}); + gemm_node.AddAttribute("alpha", 1.0f); + gemm_node.AddAttribute("beta", 1.0f); + gemm_node.AddAttribute("transA", static_cast(0)); + gemm_node.AddAttribute("transB", static_cast(0)); + }; + RunModelTest(modelBuilder, "xnnpack_test_graph_gemm_empty_c", + { + ExpectedEPNodeAssignment::All, + 1e-4f /* fp32_abs_err */, + }); +} + #endif } // namespace test diff --git a/plugin-ep-webgpu/_packaging_utils.py b/plugin-ep-webgpu/_packaging_utils.py index 201b3342ff39c..84850e4dee5fe 100644 --- a/plugin-ep-webgpu/_packaging_utils.py +++ b/plugin-ep-webgpu/_packaging_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + """Shared utilities for the WebGPU plugin EP packaging scripts. Not a public API.""" from __future__ import annotations diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj index 58860c46b9c16..5bfbac0308e01 100644 --- a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj @@ -16,7 +16,8 @@ ONNX;ONNX Runtime;Machine Learning;AI;Deep Learning;WebGPU - MIT + LICENSE + https://onnxruntime.ai https://github.com/microsoft/onnxruntime git © Microsoft Corporation. All rights reserved. @@ -29,6 +30,9 @@ + + +