From 158bdef0183fa25f8ddf9416f964146f38cc60a9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 21 May 2026 15:52:51 -0700 Subject: [PATCH 01/16] Fix oob bias access for MatMulIntegerToFloat and DynamicQuantizeMatMul (#28499) ### Description Fixes a heap out-of-bounds read vulnerability in `DynamicQuantizeMatMul` and `MatMulIntegerToFloat` where a bias tensor with an incorrect number of elements could cause memory reads beyond the allocated buffer. ## Changes - **`dynamic_quantize_matmul.cc`**: Added element count validation for the bias tensor in both the `ComputeCommon` path and the deferred bias addition path (KleidiAI). - **`matmul_integer_base.h`**: Added element count validation in the KleidiAI pre-pack path, causing fallback to `ComputeCommon` (which then rejects the invalid bias with a clear error). - **Tests**: Added regression tests covering runtime bias mismatch, initializer bias mismatch (KleidiAI fallback), and the generic (non-KleidiAI) path for both operators. ## Why we validate element count, not shape (rank) The validation checks `bias_tensor->Shape().Size() == N` (total element count) rather than enforcing that the bias is strictly 1D. This is intentional for several reasons: 1. **Backward compatibility with existing models.** It's possible that some models may have bias tensors with shape `(1, N)` instead of `(N)`. Enforcing rank == 1 would break these models at runtime. This exact issue occurred with the GroupQueryAttention operator, which required relaxing its shape validation in PR #28259. 2. **Consistent with ONNX standard practice.** Most official ONNX operator schemas (Conv, ConvTranspose, DeformConv, Gemm, LayerNormalization) do *not* validate bias shape in their schema's `TypeAndShapeInferenceFunction`; they only document "1D" in the input description text. `BatchNormalization` is the only exception. 3. **The kernel only needs N contiguous floats.** The compute implementation accesses bias via raw data pointer (`bias->Data()`) and reads exactly `N` elements. It never indexes into specific dimensions or assumes a particular rank. A bias of shape `(N)`, `(1, N)`, or `(1, 1, N)` all work identically. 4. **Schema constraints cannot be relaxed without a version bump.** If we added a strict rank check to the schema now and later discovered models using `(1, N)`, fixing it would probably require a new opset version (though we've never actually bumped the version for contrib ops ...). ## Motivation and Context Without this fix, passing a bias tensor with fewer elements than `B`'s last dimension causes the kernel to read past the end of the bias buffer, potentially exposing sensitive memory contents or causing a crash. --- .../quantization/dynamic_quantize_matmul.cc | 14 +++- .../cpu/quantization/matmul_integer_base.h | 3 + .../dynamic_quantize_matmul_test.cc | 71 +++++++++++++++++++ .../matmul_integer_to_float_test.cc | 67 +++++++++++++++++ 4 files changed, 153 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index f910abb538821..d051c5423c367 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -80,6 +80,12 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, if (y->Shape().Size() == 0) return Status::OK(); + if (bias_tensor != nullptr) { + ORT_RETURN_IF_NOT(bias_tensor->Shape().Size() == static_cast(helper.N()), + "bias tensor's element count must equal B's last dimension (", + helper.N(), "), but got ", bias_tensor->Shape().Size()); + } + auto* y_data = y->MutableData(); const auto* bias_data = bias_tensor != nullptr ? bias_tensor->Data() : nullptr; @@ -306,8 +312,12 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { // This evaluates to true if bias data was not provided as constant data for prepacking stage if (!dynamic_quant_mlas_bias_data_was_packed_) { if (ctx->Input(IN_BIAS) != nullptr) { - const auto biases = std::vector(&ctx->Input(IN_BIAS)->Data()[0], - &ctx->Input(IN_BIAS)->Data()[gemm_shape.N]); + const Tensor* bias_t = ctx->Input(IN_BIAS); + ORT_RETURN_IF_NOT(bias_t->Shape().Size() == static_cast(gemm_shape.N), + "bias tensor's element count must equal B's last dimension (", + gemm_shape.N, "), but got ", bias_t->Shape().Size()); + const auto biases = std::vector(&bias_t->Data()[0], + &bias_t->Data()[gemm_shape.N]); // deferred adding of bias for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index 9916c426a54fe..8d0ab6a53b88e 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -208,6 +208,9 @@ class MatMulIntegerBase : public OpKernel { } if (ctx.bias != nullptr) { + if (ctx.bias->Shape().Size() != static_cast(ctx.N)) { + return false; + } dynamic_quant_mlas_bias_data_was_packed_ = true; } diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index fb64d6fa9b66d..5287859292f1f 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -421,6 +421,47 @@ TEST(DynamicQuantizeMatMul, KleidiRejectsUnsupportedBShape) { test.Run(); } +// 6. Mismatched bias (runtime tensor) -> must be rejected at compute time. +TEST(DynamicQuantizeMatMul, KleidiBiasRuntimeShapeMismatch) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + // Bias has only 1 element but N=3 — this must be rejected. + const std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {1}, bad_bias, false /*runtime*/); + test.AddOutput("Y", {data.M, data.N}, std::vector(data.M * data.N, 0.0f)); + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + +// 7. Mismatched bias (constant initializer) -> KleidiAI pre-pack rejects -> falls back to ComputeCommon +// -> rejected +TEST(DynamicQuantizeMatMul, KleidiBiasInitializerShapeMismatch) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + // Bias has only 1 element but N=3 — this must be rejected. + const std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {1}, bad_bias, true /*initializer*/); + test.AddOutput("Y", {data.M, data.N}, std::vector(data.M * data.N, 0.0f)); + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + #endif // USE_KLEIDIAI TEST(DynamicQuantizeMatMul, B_PerColumn_ND) { @@ -486,5 +527,35 @@ TEST(DynamicQuantizeMatMul, B_PerColumn_ND) { test_case({15, 14, 13}, {15, 13, 27}, {15, 1, 27}); } +// Test that a bias tensor with length mismatched to B's last dimension is rejected. +// This reproduces a heap OOB read when bias is shorter than N. +TEST(DynamicQuantizeMatMul, BiasShapeMismatch) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 1.0f); + std::vector B_data(K * N, 128); + std::vector B_scale = {0.5f}; + std::vector B_zero_point = {128}; + + // Bias has only 1 element but N=8 — this must be rejected. + std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 30b0c0fcf73c3..7142358a4e02c 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -489,5 +489,72 @@ TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { test_case({15, 14, 13}, {15, 13, 27}, {15, 1, 27}); } +// Test that a bias tensor with length mismatched to B's last dimension is rejected. +// This reproduces a heap OOB read when bias is shorter than N. +TEST(MatMulIntegerToFloat, BiasShapeMismatch) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 128); + std::vector B_data(K * N, 128); + std::vector A_scale = {0.5f}; + std::vector B_scale = {0.5f}; + std::vector A_zero_point = {128}; + std::vector B_zero_point = {128}; + + // Bias has only 1 element but N=8. This must be rejected. + std::vector bad_bias = {1.0f}; + + OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + +// Test that a bias tensor with length larger than B's last dimension is rejected. +TEST(MatMulIntegerToFloat, BiasShapeMismatch_LargerBias) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 128); + std::vector B_data(K * N, 128); + std::vector A_scale = {0.5f}; + std::vector B_scale = {0.5f}; + std::vector A_zero_point = {128}; + std::vector B_zero_point = {128}; + + // Bias has length > N, which must be rejected. + std::vector bad_bias(static_cast(N + 1), 1.0f); + + OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N + 1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + } // namespace test } // namespace onnxruntime From ee444bdc3679b4002466a50643611fed4d9704cc Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 21 May 2026 16:22:13 -0700 Subject: [PATCH 02/16] Validate conv bias shape in WordConvEmbedding to prevent OOB read (#28279) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Add shape validation for the conv bias input (`B`) in `WordConvEmbedding::Compute()` to prevent out-of-bounds heap reads when a crafted model provides a bias tensor shorter than `num_filters`. ### Root Cause `WordConvEmbedding::Compute` passes `b_conv.Data()` directly to `ComputeConvMaxPoolWithActivation`, which iterates over `num_filters` (= `w_conv.shape[0]`) elements of the bias buffer. `ValidateInputShape` only checks the sequence, conv weight, and char embedding shapes — the bias shape is never validated. A model with `b_conv.shape[0] < w_conv.shape[0]` causes the inner loop to read past the bias buffer, and the leaked heap bytes propagate through tanh activation and max-pooling into the output tensor. ### Fix Add an inline check after `ValidateInputShape` that rejects bias tensors whose shape is not `[num_filters]`: ```cpp ORT_RETURN_IF_NOT(b_conv_shape.NumDimensions() == 1 && b_conv_shape[0] == w_conv_shape[0], "WordConvEmbedding: conv bias B must be a 1-D tensor of length ", w_conv_shape[0], ", but got shape ", b_conv_shape); --- .../contrib_ops/cpu/word_conv_embedding.cc | 5 ++++ .../contrib_ops/word_conv_embedding_test.cc | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index 4c0c86aa60729..22cdcb75de126 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -202,6 +202,11 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(ValidateInputShape(sequence_shape, w_conv_shape, w_char_embedding_shape)); + const TensorShape& b_conv_shape = b_conv.Shape(); + ORT_RETURN_IF_NOT(b_conv_shape.NumDimensions() == 1 && b_conv_shape[0] == w_conv_shape[0], + "WordConvEmbedding: conv bias B must be a 1-D tensor of length ", + w_conv_shape[0], ", but got shape ", b_conv_shape); + int64_t seq_len = sequence_shape[0]; int64_t word_len = sequence_shape[1]; int64_t char_embedding_size = w_char_embedding_shape[1]; diff --git a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc index 3f50166438190..7d159e934c927 100644 --- a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc +++ b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc @@ -167,5 +167,31 @@ TEST(ContribOpTest, WordConvEmbedding_rejects_sequence_rank_one) { test.Run(OpTester::ExpectResult::kExpectFailure, "Sequence input must have rank greater than 1"); } +TEST(ContribOpTest, WordConvEmbedding_rejects_undersized_bias) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // W has 2 filters but B has only 1 element + test.AddInput("Sequence", {1, 2}, {1, 2}); + test.AddInput("W", {2, 1, 2, 1}, {1.0f, 1.0f, 1.0f, 1.0f}); + test.AddInput("B", {1}, {0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 2}, {0.0f, 0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "conv bias B must be a 1-D tensor of length 2"); +} + +TEST(ContribOpTest, WordConvEmbedding_rejects_2d_bias) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // B has correct element count but wrong rank + test.AddInput("Sequence", {1, 2}, {1, 2}); + test.AddInput("W", {2, 1, 2, 1}, {1.0f, 1.0f, 1.0f, 1.0f}); + test.AddInput("B", {1, 2}, {0.0f, 0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 2}, {0.0f, 0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "conv bias B must be a 1-D tensor of length 2"); +} + } // namespace test } // namespace onnxruntime From cf8e4f5551cf188b434c9da9342553ebdfa67fe2 Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Fri, 22 May 2026 01:35:15 +0200 Subject: [PATCH 03/16] [CoreML EP] Support Gather with scalar 'indices' (#28278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description The CoreML \`GatherOpBuilder\` rejected rank-0 (scalar) \`indices\` because CoreML's \`gather\` requires rank-1+ indices and the obvious workaround would change the output rank — see the \`// Don't allow scalar 'indices' input.\` comment at \`gather_op_builder.cc:90\`. This PR performs the workaround internally: \`\`\` reshape(indices, shape=[1]) -> indices_1d gather(data, indices_1d, axis) -> data_shape with the gather axis = 1 squeeze(., axes=[axis]) -> ONNX gather output shape \`\`\` …in both the MLProgram and NeuralNetwork emitters. The squeeze restores the original ONNX output rank, so caller-visible Gather semantics are unchanged. \`reshape\` is used rather than \`expand_dims\` because CoreML internally pads scalars and \`expand_dims\` on the padded tensor can push the apparent rank past the rank-5 limit on high-rank \`data\`. Restrictions: - \`data\` must have a fully static shape — we claim a static intermediate shape between gather and squeeze. - \`data\` rank capped at 4. The rank-5 case still trips CoreML's compiler with \`Invalid rank: 6\`, so we keep the conservative bound. Dynamic-shape and rank-5+ scalar Gather still falls back to CPU (preserves the existing \`GatherWithScalarIndices\` test, whose data is dynamic-shape). Fixes #28180. ### Motivation StyleGAN-family generators (StyleGAN, StyleGAN2, GFPGAN, …) select per-layer style codes with a scalar-index Gather. The resulting graph alternates between Gather and the rest of the generator, splitting the CoreML subgraph repeatedly. On GFPGAN-1024 (`[1, 3, 512, 512]`), this PR moves all 16 scalar Gathers off CPU and the model lands on a single CoreML partition. **M3 Max, MLProgram, batch 1, 3 × 100-iter steady-state runs (n=300):** | | Partitions | Mean | StdDev | P50 | P95 | P99 | Max | |---|---|---|---|---|---|---|---| | origin/main | 2 | 89.68 ms | 3.67 | 87.82 | 96.71 | 105.00 | 108.01 | | **this PR** | **1** | **81.77 ms** | **1.85** | **80.97** | **85.98** | **87.24** | **88.03** | **Mean −8.8%, stddev −50%, P99 −17%, max −18.5%** — eliminating the CPU↔CoreML round-trip on every scalar Gather both speeds up the steady state and tightens the tail. Striking secondary effect: the worst-case run with the fix (**88.03 ms**) is faster than the *mean* run without it (**89.68 ms**). Every single fixed inference over n=300 lands below the unfixed average. ### Tests Six new tests in \`onnxruntime/test/providers/coreml/coreml_basic_test.cc\` covering distinct code paths, exercised on both NeuralNetwork and MLProgram emitters where the dtype is supported: - \`GatherScalarIndicesAxis1\` — axis=1, mid-rank squeeze. - \`GatherScalarIndicesAxis0\` — axis=0, leading-axis squeeze. - \`GatherScalarIndicesNegativeAxis\` — axis=-1, exercises \`HandleNegativeAxis\`. - \`GatherScalarIndicesFloat16\` — fp16 data (MLProgram only, as per \`HasSupportedInputsImpl\`). - \`GatherScalarIndicesInt64Data\` — int64 data, both formats. - \`GatherScalarIndicesRank4Data\` — rank-4 data, exercises the supported maximum. Each verifies CoreML output against the CPU EP reference and asserts \`ExpectedEPNodeAssignment::All\`. The existing \`GatherWithScalarIndices\` test (dynamic-shape data) is updated only in its comment to reflect the new precise condition; it still exercises the CPU fall-back as before. All pass locally on macOS 26.3 / M3 Max. --------- Co-authored-by: Claude Opus 4.7 (1M context) --- .../coreml/builders/impl/gather_op_builder.cc | 167 ++++- .../providers/coreml/coreml_basic_test.cc | 657 +++++++++++++++++- 2 files changed, 800 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 8b58f5dc6c927..5059c6c9edd8f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -30,27 +30,121 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /*logger*/) const { + const logging::Logger& logger) const { + const auto axis = GetAxisAttribute(node); + const auto& data_def = *node.InputDefs()[0]; + const auto& indices_def = *node.InputDefs()[1]; + const auto& output_def = *node.OutputDefs()[0]; + + std::vector data_shape, indices_shape; + ORT_RETURN_IF_NOT(GetShape(data_def, data_shape, logger), "Failed to get 'data' shape"); + ORT_RETURN_IF_NOT(GetShape(indices_def, indices_shape, logger), "Failed to get 'indices' shape"); + + // ONNX Gather: out_shape = data_shape[:axis] + indices_shape + data_shape[axis+1:] + // CoreML's gather requires rank-1+ indices, so for scalar indices we promote + // them to [1], gather, and then squeeze the resulting axis to restore the + // original output rank. The positive axis after wrapping is needed for the + // squeeze axis below regardless of path. + const bool scalar_indices = indices_shape.empty(); + const int64_t pos_axis = HandleNegativeAxis(axis, data_shape.size()); + if (model_builder.CreateMLProgram()) { using CoreML::Specification::MILSpec::Operation; - std::unique_ptr op = model_builder.CreateOperation(node, "gather"); - - const auto axis = GetAxisAttribute(node); + // IsOpSupportedImpl gates indices to INT32 or INT64, so we can pass the + // dtype straight through to the reshape's intermediate output. + int32_t indices_dtype{}; + ORT_RETURN_IF_NOT(GetType(indices_def, indices_dtype, logger), + "Failed to get 'indices' dtype"); + const int32_t output_dtype = static_cast(output_def.TypeAsProto()->tensor_type().elem_type()); + + std::string indices_name = indices_def.Name(); + + if (scalar_indices) { + // [] -> [1] via reshape. We use reshape rather than expand_dims because + // CoreML internally pads scalars; expand_dims on the padded tensor can + // push the apparent rank past the rank-5 limit on high-rank `data`. + auto reshape = model_builder.CreateOperation(node, "reshape", "indices"); + AddOperationInput(*reshape, "x", indices_def.Name()); + const std::vector indices_1d_shape = {1}; + AddOperationInput(*reshape, "shape", + model_builder.AddConstant(reshape->type(), "shape", indices_1d_shape)); + + indices_name = model_builder.GetUniqueName(node, "indices_1d"); + AddIntermediateOperationOutput(*reshape, indices_name, indices_dtype, indices_1d_shape); + model_builder.AddOperation(std::move(reshape)); + } + + std::unique_ptr gather = model_builder.CreateOperation(node, "gather"); // coreml docs claims validate_indices is optional but in practice it is required const auto validate_indices = false; - AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); // data - AddOperationInput(*op, "indices", node.InputDefs()[1]->Name()); // indices - AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); // axis attr - AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices)); - AddOperationOutput(*op, *node.OutputDefs()[0]); // output - model_builder.AddOperation(std::move(op)); + AddOperationInput(*gather, "x", data_def.Name()); + AddOperationInput(*gather, "indices", indices_name); + AddOperationInput(*gather, "axis", model_builder.AddScalarConstant(gather->type(), "axis", axis)); + AddOperationInput(*gather, "validate_indices", + model_builder.AddScalarConstant(gather->type(), "validate_indices", validate_indices)); + + if (!scalar_indices) { + AddOperationOutput(*gather, output_def); + model_builder.AddOperation(std::move(gather)); + } else { + // gather output here has the data's rank (one more than ONNX scalar-gather output); + // squeeze the inserted axis to recover the original output shape. + TensorShapeVector gather_shape{data_shape.begin(), data_shape.end()}; + gather_shape[pos_axis] = 1; + const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); + AddIntermediateOperationOutput(*gather, gather_out_name, output_dtype, gather_shape); + model_builder.AddOperation(std::move(gather)); + + auto squeeze = model_builder.CreateOperation(node, "squeeze", "post"); + AddOperationInput(*squeeze, "x", gather_out_name); + const std::vector sq_axes = {pos_axis}; + AddOperationInput(*squeeze, "axes", model_builder.AddConstant(squeeze->type(), "axes", sq_axes)); + AddOperationOutput(*squeeze, output_def); + model_builder.AddOperation(std::move(squeeze)); + } } else { - auto layer = model_builder.CreateNNLayer(node); - layer->mutable_gather()->set_axis(GetAxisAttribute(node)); - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data - *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); // output - model_builder.AddLayer(std::move(layer)); + if (!scalar_indices) { + auto layer = model_builder.CreateNNLayer(node); + layer->mutable_gather()->set_axis(axis); + *layer->mutable_input()->Add() = data_def.Name(); + *layer->mutable_input()->Add() = indices_def.Name(); + *layer->mutable_output()->Add() = output_def.Name(); + model_builder.AddLayer(std::move(layer)); + } else { + // expand_dims indices: [] -> [1]. Unlike the MLProgram reshape path + // above, NN's expand_dims doesn't internally pad rank, so we don't run + // into the apparent-rank inflation that forced reshape+gather there; + // expand_dims is the natural choice on this path. + const std::string& indices_1d_name = model_builder.GetUniqueName(node, "indices_1d"); + { + auto expand_layer = model_builder.CreateNNLayer(node, "_indices_expand"); + expand_layer->mutable_expanddims()->add_axes(0); + *expand_layer->mutable_input()->Add() = indices_def.Name(); + *expand_layer->mutable_output()->Add() = indices_1d_name; + model_builder.AddLayer(std::move(expand_layer)); + } + + // gather with the promoted indices + const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); + { + auto gather_layer = model_builder.CreateNNLayer(node); + gather_layer->mutable_gather()->set_axis(axis); + *gather_layer->mutable_input()->Add() = data_def.Name(); + *gather_layer->mutable_input()->Add() = indices_1d_name; + *gather_layer->mutable_output()->Add() = gather_out_name; + model_builder.AddLayer(std::move(gather_layer)); + } + + // squeeze the inserted axis + { + auto squeeze_layer = model_builder.CreateNNLayer(node, "_post_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(pos_axis); + squeeze_layer->mutable_squeeze()->set_squeezeall(false); + *squeeze_layer->mutable_input()->Add() = gather_out_name; + *squeeze_layer->mutable_output()->Add() = output_def.Name(); + model_builder.AddLayer(std::move(squeeze_layer)); + } + } } return Status::OK(); } @@ -87,14 +181,45 @@ bool GatherOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } - // Don't allow scalar 'indices' input. - // We convert scalar inputs to tensors with shape [1] before providing them to CoreML. - // This modification changes the shape of the Gather output. - if (indices_shape.empty()) { - LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices'"; + // ONNX Gather schema constrains indices to int32 or int64. Validate here so + // AddToModelBuilderImpl can trust the dtype rather than silently defaulting + // on an unexpected value. + int32_t indices_dtype{}; + if (!GetType(*node.InputDefs()[1], indices_dtype, logger)) { return false; } + if (indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT32 && + indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "Gather 'indices' dtype [" << indices_dtype + << "] is not supported (expected INT32 or INT64)"; + return false; + } + + // For scalar indices we internally emit gather with promoted [1] indices + // then squeeze. That requires us to claim a static intermediate shape, so + // we only handle scalar indices when the data shape itself is fully + // static. (Dynamic-shape scalar Gather still falls back to CPU.) + if (indices_shape.empty()) { + if (!IsStaticShape(data_shape)) { + LOGS(logger, VERBOSE) << "Gather with scalar 'indices' requires static 'data' shape"; + return false; + } + // The pre-squeeze intermediate has the same rank as `data`. CoreML's + // compiler reports "Invalid rank: 6" when a rank-5 intermediate is + // produced via reshape+gather, even though rank-5 intermediates are + // accepted in other op chains. Cap scalar-indices Gather at data rank 4 + // until that compiler limit is lifted. + // + // TODO: re-test on newer macOS / CoreML versions; if Apple lifts the + // intermediate rank limit, this cap can be raised to 5 (matching the + // general Gather output-rank check below). + if (data_shape.size() > 4) { + LOGS(logger, VERBOSE) << "Gather with scalar 'indices' supports 'data' rank up to 4"; + return false; + } + } + // Output rank = data_rank + indices_rank - 1. The rank-5 limit applies. if (data_shape.size() + indices_shape.size() - 1 > 5) { LOGS(logger, VERBOSE) << "Gather does not support output with rank greater than 5"; return false; diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 94b41149a32d1..13b75f3c6e4fa 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -241,9 +241,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { } TEST(CoreMLExecutionProviderTest, GatherWithScalarIndices) { - // For scalar inputs, the input shape is modified from [] -> [1] before passing the input to CoreML. - // This won't work for Gather because the output shape depends on the `indices` input shape which could be a scalar. - // Currently, we expect the CoreML EP to only take the Shape node in this graph (Gather -> Shape). + // The CoreML EP supports scalar 'indices' for Gather only when the 'data' input has a fully + // static shape (it needs to claim a static intermediate shape for the post-gather squeeze). + // This model's 'data' input is dynamic ([M, N, K]) so Gather still falls back to CPU and the + // CoreML EP only takes the Shape node. const auto model_file_name = ORT_TSTR("testdata/gather_with_scalar_indices_then_shape.onnx"); #if defined(__APPLE__) @@ -2359,6 +2360,656 @@ TEST(CoreMLExecutionProviderTest, Split11SingleOutputNotSupported) { TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); } +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesAxis1) { + // ai.onnx:Gather with rank-0 (scalar) 'indices'. ONNX output rank = + // data_rank + indices_rank - 1 = 2. The CoreML builder internally promotes + // indices to [1], runs gather, then squeezes the inserted axis. Pattern + // produced by StyleGAN-family generators (e.g. GFPGAN) that pick a + // per-layer style code with a scalar index. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_axis1", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {1, 4, 8} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(1); + data_shape->add_dim()->set_dim_value(4); + data_shape->add_dim()->set_dim_value(8); + + // output Y: {1, 8} + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(8); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + // Scalar int64 index with value 2. + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + // No dims => rank-0 tensor. + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar", "Gather", "Gather with scalar indices", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 4, 8}; + std::vector input_data(1 * 4 * 8); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.25f - 1.0f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis1_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis1_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesAxis0) { + // Scalar Gather along axis 0 — squeeze axis is 0; covers a different + // squeeze position than the axis=1 test. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_axis0", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {6, 5} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(6); + data_shape->add_dim()->set_dim_value(5); + + // output Y: {5} + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(5); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(4); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_axis0", "Gather", "Gather scalar idx axis=0", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {6, 5}; + std::vector input_data(6 * 5); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) - 12.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis0_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis0_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesNegativeAxis) { + // Scalar Gather with negative axis (-1) — verifies HandleNegativeAxis is + // applied when computing the squeeze axis. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_negative_axis", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {2, 3, 4} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(2); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + // output Y: {2, 3} (axis=-1 == axis 2; output drops that axis) + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(2); + output_shape->add_dim()->set_dim_value(3); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(1); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_neg_axis", "Gather", "Gather scalar idx axis=-1", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(-1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {2, 3, 4}; + std::vector input_data(2 * 3 * 4); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesNegativeAxis_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesNegativeAxis_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesFloat16) { + // FLOAT16 'data' input. HasSupportedInputsImpl restricts fp16 Gather to + // MLProgram on CoreML 6+, so this test only runs the MLProgram path. + // Exercises the MLFloat16 branch of the static intermediate shape claim. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_fp16", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(1); + data_shape->add_dim()->set_dim_value(4); + data_shape->add_dim()->set_dim_value(8); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(8); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_fp16", "Gather", "Gather scalar idx fp16 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 4, 8}; + std::vector input_data; + input_data.reserve(1 * 4 * 8); + for (size_t i = 0; i < 1 * 4 * 8; ++i) { + input_data.emplace_back(static_cast(i) * 0.25f - 1.0f); + } + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesFloat16_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesInt64Data) { + // INT64 'data' input. HasSupportedInputsImpl allows int64 in both NN and + // MLProgram; verify both formats correctly route int64 through the + // expand/gather/squeeze chain. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_int64_data", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(1); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_int64", "Gather", "Gather scalar idx int64 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {3, 4}; + std::vector input_data; + input_data.reserve(3 * 4); + for (int64_t i = 0; i < 3 * 4; ++i) input_data.push_back(i * 1000 - 5000); + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt64Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt64Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesInt32Indices) { + // INT32 'indices'. The other scalar-indices tests use INT64 indices (the + // PyTorch default); this one exercises the INT32 branch through both the + // dtype gating in IsOpSupportedImpl and the indices_dtype path-through to + // the reshape's intermediate output dtype in AddToModelBuilderImpl. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_int32_indices", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + idx_init.add_int32_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_int32_idx", "Gather", "Gather scalar int32 idx", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {3, 4}; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt32Indices_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt32Indices_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank4Data) { + // Rank-4 'data' input — the supported maximum for scalar Gather (the + // pre-squeeze intermediate is rank 4; CoreML's compiler rejects scalar + // Gather at rank 5 with "Invalid rank: 6"). Output is rank 3. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank4", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + for (int64_t d : {2, 5, 3, 4}) data_shape->add_dim()->set_dim_value(d); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + // Gather on axis=1 with scalar idx removes that axis: {2,3,4} + for (int64_t d : {2, 3, 4}) output_shape->add_dim()->set_dim_value(d); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(3); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank4", "Gather", "Gather scalar idx rank-4 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {2, 5, 3, 4}; + std::vector input_data(2 * 5 * 3 * 4); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.1f - 5.0f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank4Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank4Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank1Data) { + // Rank-1 'data' input with scalar indices — output is rank-0 (the pre-squeeze + // intermediate is rank 1, squeezed to a scalar). Confirms CoreML actually + // produces a rank-0 result on both NN and MLProgram paths. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank1", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + data_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(6); + + // Output is rank-0: TypeProto with a shape that has no dims. + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape(); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank1", "Gather", "Gather scalar idx rank-1 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {6}; + std::vector input_data(6); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) - 2.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank1Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank1Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesDynamicDataNotSupported) { + // The scalar-indices path emits a reshape-+squeeze chain whose intermediate + // shape we have to claim statically. IsOpSupportedImpl rejects the node + // when 'data' has any unknown dim so it falls back to CPU rather than + // produce an ill-formed CoreML program. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_dynamic_data", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_param("N"); // dynamic leading dim + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("N"); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(0); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_dyn", "Gather", "Gather scalar idx, dynamic data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank5DataNotSupported) { + // Scalar-indices Gather caps data rank at 4 (CoreML compiler reports + // "Invalid rank: 6" on the rank-5 reshape+gather intermediate). Rank-5 + // 'data' must fall back to CPU. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank5", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + for (int64_t d : {2, 3, 4, 5, 6}) data_shape->add_dim()->set_dim_value(d); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + // axis=2 with scalar idx removes that axis: {2,3,5,6} + for (int64_t d : {2, 3, 5, 6}) output_shape->add_dim()->set_dim_value(d); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank5", "Gather", "Gather scalar idx rank-5 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(2)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + #endif // !(ORT_MINIMAL_BUILD) } // namespace test } // namespace onnxruntime From e10b9a81a68e859324c74280c8382902d3a4883b Mon Sep 17 00:00:00 2001 From: adrastogi Date: Thu, 21 May 2026 17:22:15 -0700 Subject: [PATCH 04/16] Add component governance manifest for WebGPU EP (#28599) ### Description Added a WebGPU-specific Component Governance manifest for Dawn and related dependencies. Added documentation for the manifest scope, dependency classification, and maintenance steps. Added a validation script to catch Dawn and DXC pin drift. ### Motivation and Context WebGPU builds depend on Dawn and related components that are not part of vanilla ONNX Runtime builds. Downstream WebGPU packaging needs ORT-owned metadata to generate complete third-party notices without maintaining a duplicate dependency inventory. --------- Co-authored-by: Aditya Rastogi --- cgmanifests/README.md | 6 +- cgmanifests/webgpu/README.md | 61 + cgmanifests/webgpu/cgmanifest.webgpu.json | 1110 +++++++++++++++++ .../webgpu/validate_webgpu_cgmanifest.py | 165 +++ 4 files changed, 1341 insertions(+), 1 deletion(-) create mode 100644 cgmanifests/webgpu/README.md create mode 100644 cgmanifests/webgpu/cgmanifest.webgpu.json create mode 100644 cgmanifests/webgpu/validate_webgpu_cgmanifest.py diff --git a/cgmanifests/README.md b/cgmanifests/README.md index a7d816a401a95..5e356e4507141 100644 --- a/cgmanifests/README.md +++ b/cgmanifests/README.md @@ -1,3 +1,7 @@ # CGManifest Files This directory contains CGManifest (cgmanifest.json) files. -See [here](https://docs.opensource.microsoft.com/tools/cg/cgmanifest.html) for details. \ No newline at end of file +See [here](https://docs.opensource.microsoft.com/tools/cg/cgmanifest.html) for details. + +The WebGPU-specific manifest is in `webgpu/cgmanifest.webgpu.json`. It is intentionally not named `cgmanifest.json` +so default whole-repository Component Governance scans do not pick it up automatically. WebGPU packaging or +NOTICE-generation pipelines should stage it as `cgmanifest.json` in their scan input. diff --git a/cgmanifests/webgpu/README.md b/cgmanifests/webgpu/README.md new file mode 100644 index 0000000000000..cf03477ea6bbe --- /dev/null +++ b/cgmanifests/webgpu/README.md @@ -0,0 +1,61 @@ +# WebGPU Component Governance manifest + +This directory contains the WebGPU-specific Component Governance manifest for ONNX Runtime. It covers Dawn and the +Dawn-derived dependency graph used when building the WebGPU Execution Provider. + +The manifest is named `cgmanifest.webgpu.json`, not `cgmanifest.json`, so default whole-repository Component +Governance scans do not pick it up automatically. WebGPU packaging and NOTICE-generation pipelines should stage or copy +this file as `cgmanifest.json` in the source directory that they scan for WebGPU package notices. + +## Classification policy + +The Component Governance manifest schema provides a `developmentDependency` boolean, but it does not provide separate +first-class fields for runtime, build-tool, test-only, or conditional dependencies. This manifest uses: + +- no `developmentDependency` field for components that are redistributed, statically linked, or otherwise part of the + WebGPU package/runtime dependency closure; +- `developmentDependency: true` for Dawn dependencies that are only build tools, tests, disabled optional backends, or + source inputs that current WebGPU packages do not redistribute; +- `comments` to preserve the more precise classification and Dawn `DEPS` path/condition. + +If a WebGPU package starts redistributing a component currently marked as a development dependency, update that +registration and explain the packaging path in `comments` and `detectedComponentLocations`. + +## Maintenance + +When rolling Dawn or changing WebGPU packaging: + +1. Update the Dawn registration to match the `dawn` entry in `cmake/deps.txt`. +2. Re-audit the Dawn dependency graph for the pinned Dawn commit: + - Start from the Dawn commit in `cmake/deps.txt`; do not audit Dawn `main` or a different roll. + - Inspect Dawn's `tools/fetch_dawn_dependencies.py` at that commit. For ORT's normal source-fetch path, + `cmake/external/onnxruntime_external_deps.cmake` enables `DAWN_FETCH_DEPENDENCIES`, so the script's + `required_submodules` list is the primary set of Dawn source dependencies fetched for the build. + - Cross-reference each fetched submodule path with Dawn's `DEPS` file to get the public upstream repository URL, + commit, and condition. Use public upstream identities in this manifest, not internal mirrors. + - Compare that fetched set against this manifest. Add new fetched components, update changed commits or repository + URLs, and remove entries that are no longer fetched or relevant unless CG/legal guidance requires keeping them. + - Cross-check ORT's Dawn CMake options in `cmake/external/onnxruntime_external_deps.cmake` and Dawn's + `third_party/CMakeLists.txt` before classifying a component. Components that are redistributed, statically linked, + or otherwise part of the WebGPU package/runtime closure should not be marked as development dependencies; build + tools, test inputs, disabled optional backends, and unfetched conditional dependencies should be marked + `developmentDependency: true` if they remain registered. + - Verify actual WebGPU package contents, especially platform-specific artifacts. For example, the Windows WebGPU + plugin pipeline downloads and redistributes DXC DLLs separately from Dawn's `third_party/dxc` source dependency, so + both the Dawn build-input registration and the redistributed DXC release registration may need review. + - Keep Dawn-derived registrations connected to the Dawn root with `dependencyRoots`. +3. If the Windows WebGPU plugin pipeline changes the downloaded DXC release, update the DirectXShaderCompiler release + registration to match `tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml`. +4. Run: + + ```powershell + python cgmanifests/webgpu/validate_webgpu_cgmanifest.py + ``` + +The validator checks for stale Dawn and DXC pins, but it does not replace the manual dependency classification review +in step 2. + +Non-git Dawn toolchain packages from CIPD/GCS, such as GN, Ninja, CMake, Go, Siso, reclient, and sysroots, are +intentionally not registered here unless they become redistributed or CG/legal guidance requires build input coverage. +They do not have stable public upstream source identities in the Dawn `DEPS` file and are not part of current WebGPU +package contents. diff --git a/cgmanifests/webgpu/cgmanifest.webgpu.json b/cgmanifests/webgpu/cgmanifest.webgpu.json new file mode 100644 index 0000000000000..90448c9b4a68e --- /dev/null +++ b/cgmanifests/webgpu/cgmanifest.webgpu.json @@ -0,0 +1,1110 @@ +{ + "$schema": "https://json.schemastore.org/component-detection-manifest.json", + "version": 1, + "registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + }, + "comments": "runtime; WebGPU EP root dependency pinned in cmake/deps.txt and patched by cmake/external/onnxruntime_external_deps.cmake.", + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/deps.txt", + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake", + "{SourceFileRoot}/cmake/patches/dawn" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4711839eb9a87da7c3436d9b212e0492359fbbd", + "repositoryUrl": "https://github.com/microsoft/DirectXShaderCompiler.git", + "tag": "v1.8.2502" + } + }, + "comments": "runtime; redistributed by Windows WebGPU plugin packages as dxil.dll and dxcompiler.dll. Release zip: https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2502/dxc_2025_02_20.zip; SHA256: 70B1913A1BFCE4A3E1A5311D16246F4ECDF3A3E613ABEC8AA529E57668426F85.", + "detectedComponentLocations": [ + "{SourceFileRoot}/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml", + "{SourceFileRoot}/plugin-ep-webgpu/csharp/pack_nuget.py", + "{SourceFileRoot}/plugin-ep-webgpu/python/build_wheel.py" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7ef32bbacabd0d04a6cfac92a542841c531e1b21", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/abseil-cpp" + } + }, + "comments": "runtime; Dawn DEPS third_party/abseil-cpp. ORT static WebGPU builds point Dawn at ORT's Abseil source when available.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a0f4dc977fa2ef7f47708aec914a4fbfeefc6103", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/protobuf" + } + }, + "comments": "runtime; Dawn DEPS third_party/protobuf. ORT static WebGPU builds point Dawn at ORT's Protobuf source when available.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "f31ca173eff866369e54d35e53375fadbabd58f4", + "repositoryUrl": "https://github.com/KhronosGroup/SPIRV-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/spirv-headers/src used by Dawn/Tint SPIR-V support.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cb38b2342beedde25bcff582dc3528a135cf6e67", + "repositoryUrl": "https://github.com/KhronosGroup/SPIRV-Tools.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/spirv-tools/src used by Dawn/Tint SPIR-V support.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "49f1a381e2aec33ef32adf4a377b5a39ec016ec4", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan-headers/src and ORT Dawn port dependency for Vulkan-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "50af38b6cd43afb1462f9ad26b8d015382d11a3d", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Utility-Libraries.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan-utility-libraries/src and ORT Dawn port dependency for Vulkan-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cb0597213b0fcb999caa9ed08c2f88dc45eb7d50", + "repositoryUrl": "https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan_memory_allocator used by Vulkan-enabled Dawn builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7eda07b1e067ef3fd7eea0419c88b5af45c9a776", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/zlib" + } + }, + "comments": "runtime; Dawn DEPS third_party/zlib.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "008e4fdd7e31d9133d028659348e054d350ccc3e", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/base/allocator/partition_allocator.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/partition_alloc used by Dawn standalone builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "3e6e148537683c22e3e74977d56516f16f39c7be", + "repositoryUrl": "https://github.com/microsoft/DirectXShaderCompiler.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/dxc used when ORT builds Dawn's built DXC path for Windows WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "980971e835876dc0cde415e8f9bc646e64667bf7", + "repositoryUrl": "https://github.com/microsoft/DirectX-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/dxheaders and ORT Dawn port dependency for D3D12/DXC-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6a18683f555b4ac8b05ac8395c29c84483ac9588", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/buildtools" + } + }, + "comments": "build-tool; Dawn DEPS buildtools, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c2725e0622e1a86d55f14514f2177a39efea4a0e", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/clang/tools/clang-format.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/clang-format/script, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "425882d8c0acaab53bf2f8abbe7efcf5db5b168b", + "repositoryUrl": "https://chromium.googlesource.com/chromium/tools/depot_tools.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/depot_tools, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7ab65651aed6802d2599dcb7a73b1f82d5179d05", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxx.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libc++/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "8f11bb1d4438d0239d0dfc1bd9456a9f31629dda", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxxabi.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libc++abi/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d38523b674e26b7c8d61ed2e48d6cfe248b12da0", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libc.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/llvm-libc/src required by libc++, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "369990d9660a387f618d0eedc341eb285016243b", + "repositoryUrl": "https://chromium.googlesource.com/chromiumos/third_party/libdrm.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libdrm/src for Linux build support, condition: dawn_standalone and host_os == \"linux\".", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4c2c31b6776c1fe03a029f66ef530796f0add90d", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/build" + } + }, + "comments": "build-tool; Dawn DEPS build, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7fd7d7092fa5ee06380f06f66f1b7bd03fca71a8", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/clang" + } + }, + "comments": "build-tool; Dawn DEPS tools/clang, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b635f27e932356a2e29450e5cfa544cdcc9ea6bb", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/memory" + } + }, + "comments": "build-tool; Dawn DEPS tools/memory, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "da34b95fdbf2032df6cda5f3828c2ba421592644", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/valgrind" + } + }, + "comments": "build-tool; Dawn DEPS tools/valgrind, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "baacfc6d5986b07abe0503216b491e234b94ba79", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/win" + } + }, + "comments": "build-tool; Dawn DEPS tools/win, condition: checkout_win and not build_with_chromium.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a975ec0340bd4b7dab6c8e43b15dbc638621a23c", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/mb" + } + }, + "comments": "build-tool; Dawn DEPS tools/mb, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4d438b31b58e2dc84b592a052b6b97e05ceb6497", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/testing" + } + }, + "comments": "test-only; Dawn DEPS testing, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "bea408a6e01f0f7e6c82a43121fe3af4506c932e", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/compiler-rt/lib/fuzzer.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/libFuzzer/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4fe3307fb2d9f86d19777c7eb0e4809e9694dde7", + "repositoryUrl": "https://github.com/google/googletest.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/googletest, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "59090f1f5e2b3ad9c90e4dc5fc8e79aed9110587", + "repositoryUrl": "https://chromium.googlesource.com/catapult.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/catapult, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "188e8278990a9069ffc84441cb5a024fd0bede37", + "repositoryUrl": "https://github.com/google/benchmark.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/google_benchmark/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "2e683eb7385c54f872acc47b371210d2282bc103", + "repositoryUrl": "https://gitlab.freedesktop.org/mesa/mesa.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/mesa/src, condition: dawn_standalone and checkout_mesa.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d389906a136c2aac9820ded0f38d1e25ef25fb9a", + "repositoryUrl": "https://github.com/mesonbuild/meson.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/meson/src, condition: dawn_standalone and checkout_mesa.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c3027d884967773057bf74b957e3fea87e5df4d7", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/jinja2" + } + }, + "comments": "build-tool; Dawn DEPS third_party/jinja2 for code generation, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4256084ae14175d38a3ff7d739dca83ae49ccec6", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/markupsafe" + } + }, + "comments": "build-tool; Dawn DEPS third_party/markupsafe for code generation, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b35641f4a3c62aa86a0b3c983d163bc0fe36026d", + "repositoryUrl": "https://github.com/glfw/glfw.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/glfw. ORT disables GLFW unless onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cce16dfb64c7525c6a417f98c67423330db8f3d7", + "repositoryUrl": "https://chromium.googlesource.com/angle/angle" + } + }, + "comments": "conditional; Dawn DEPS third_party/angle. ORT disables Dawn desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b7b7fd22e5f28079b92412f47f6da4df43e4cd37", + "repositoryUrl": "https://swiftshader.googlesource.com/SwiftShader" + } + }, + "comments": "conditional; Dawn DEPS third_party/swiftshader. Not redistributed by ORT WebGPU packages.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a26b8836968dc480ad283234823e6ffc62052489", + "repositoryUrl": "https://chromium.googlesource.com/vulkan-deps" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-deps roll metadata. Concrete Vulkan components are registered separately.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "022de31e7ffa5230068858d9e6cd85ae11170bda", + "repositoryUrl": "https://github.com/KhronosGroup/glslang.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/glslang/src. ORT disables GLSL writer/validator unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "09a024d4e422f8e603412f582d76c2051ef51cfc", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Loader.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/vulkan-loader/src and ORT Dawn port dependency. Not redistributed by current WebGPU plugin packages.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "39a19dccf79d28951516c3c7c9f1ee4a606fb733", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Tools.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-tools/src.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "145be10eff68bf41f1b556026ecf7da9a7c8d15b", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-ValidationLayers.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-validation-layers/src. ORT disables Dawn SPIR-V validation.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "5bae8738b23d06968e7c3a41308568120943ae77", + "repositoryUrl": "https://github.com/KhronosGroup/OpenGL-Registry.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/khronos/OpenGL-Registry. ORT disables desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7dea2ed79187cd13f76183c4b9100159b9e3e071", + "repositoryUrl": "https://github.com/KhronosGroup/EGL-Registry.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/khronos/EGL-Registry. ORT disables desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "dbe37c7d554fd72651510c362cf62992e5f45e1f", + "repositoryUrl": "https://github.com/gpuweb/cts.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/webgpu-cts, condition: build_with_chromium or dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4258c35121c8d0e12f53568ffb22236d7816723", + "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/emsdk, condition: dawn_wasm.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d5cfe19da8b974ca35764dd1c73b91d57cd3c4ce", + "repositoryUrl": "https://github.com/nodejs/node-api-headers.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/node-api-headers, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "1e26dcb52829a74260ec262edb41fc22998669b6", + "repositoryUrl": "https://github.com/nodejs/node-addon-api.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/node-addon-api, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4b5752ff755fe33bf6a67fb6e5964ba9d40dcdc", + "repositoryUrl": "https://github.com/gpuweb/gpuweb.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/gpuweb, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "0bfcdc4f487023d85e33597de0a94fc523e30fca", + "repositoryUrl": "https://github.com/webgpu-native/webgpu-headers.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/webgpu-headers/src for testing purposes.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "3438d4183bfc7c0d6850e8b970204cc8189f0323", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/protoc_wrapper" + } + }, + "comments": "build-tool; Dawn DEPS tools/protoc_wrapper, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7bf98f78a30b067e22420ff699348f084f802e12", + "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/libprotobuf-mutator/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "42e892d96e47b1f6e29844cc705e148ec4856448", + "repositoryUrl": "https://github.com/open-source-parsers/jsoncpp.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/jsoncpp, condition: dawn_tintd.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "303c526231a90049a3e384549720f3fbd453cf66", + "repositoryUrl": "https://github.com/google/langsvr.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/langsvr, condition: dawn_tintd.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + } + ] +} diff --git a/cgmanifests/webgpu/validate_webgpu_cgmanifest.py b/cgmanifests/webgpu/validate_webgpu_cgmanifest.py new file mode 100644 index 0000000000000..ed4f4b19035cc --- /dev/null +++ b/cgmanifests/webgpu/validate_webgpu_cgmanifest.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Validate WebGPU Component Governance manifest drift.""" + +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path +from typing import Any + +REPO_ROOT = Path(__file__).resolve().parents[2] +WEBGPU_CGMANIFEST = Path(__file__).resolve().with_name("cgmanifest.webgpu.json") +DEPS_TXT = REPO_ROOT / "cmake" / "deps.txt" +PLUGIN_WIN_WEBGPU_STAGE = ( + REPO_ROOT / "tools" / "ci_build" / "github" / "azure-pipelines" / "stages" / "plugin-win-webgpu-stage.yml" +) + +DAWN_REPOSITORY_URL = "https://github.com/google/dawn.git" +DXC_REPOSITORY_URL = "https://github.com/microsoft/DirectXShaderCompiler.git" + + +def _load_manifest() -> dict[str, Any]: + with WEBGPU_CGMANIFEST.open(encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + + registrations = manifest.get("registrations") + if not isinstance(registrations, list): + raise ValueError(f"{WEBGPU_CGMANIFEST} must contain a registrations array") + + return manifest + + +def _git_component(registration: dict[str, Any]) -> dict[str, str] | None: + component = registration.get("component") + if not isinstance(component, dict) or component.get("type") != "git": + return None + + git = component.get("git") + if not isinstance(git, dict): + return None + + repository_url = git.get("repositoryUrl") + commit_hash = git.get("commitHash") + if not isinstance(repository_url, str) or not isinstance(commit_hash, str): + return None + + result = {"repositoryUrl": repository_url, "commitHash": commit_hash} + tag = git.get("tag") + if isinstance(tag, str): + result["tag"] = tag + return result + + +def _registrations(manifest: dict[str, Any]) -> list[dict[str, Any]]: + return manifest["registrations"] + + +def _find_git_registration(manifest: dict[str, Any], repository_url: str, *, tag: str | None = None) -> dict[str, Any]: + matches = [] + for registration in _registrations(manifest): + git = _git_component(registration) + if git is None or git["repositoryUrl"] != repository_url: + continue + if tag is not None and git.get("tag") != tag: + continue + matches.append(registration) + + if len(matches) != 1: + suffix = f" with tag {tag}" if tag is not None else "" + raise ValueError(f"expected exactly one registration for {repository_url}{suffix}, found {len(matches)}") + return matches[0] + + +def _dawn_commit_from_deps_txt() -> str: + deps_text = DEPS_TXT.read_text(encoding="utf-8") + match = re.search(r"^dawn;https://github\.com/google/dawn/archive/([0-9a-f]{40})\.zip;", deps_text, re.MULTILINE) + if not match: + raise ValueError(f"could not find Dawn commit in {DEPS_TXT}") + return match.group(1) + + +def _dxc_release_from_pipeline() -> tuple[str, str, str]: + pipeline_text = PLUGIN_WIN_WEBGPU_STAGE.read_text(encoding="utf-8") + url_match = re.search(r'\$dxcZipUrl = "([^"]+)"', pipeline_text) + hash_match = re.search(r'\$expectedHash = "([0-9A-Fa-f]+)"', pipeline_text) + if not url_match or not hash_match: + raise ValueError(f"could not find DXC release URL/hash in {PLUGIN_WIN_WEBGPU_STAGE}") + + tag_match = re.search(r"/download/(v[^/]+)/", url_match.group(1)) + if not tag_match: + raise ValueError(f"could not find DXC release tag in {url_match.group(1)}") + + return tag_match.group(1), url_match.group(1), hash_match.group(1).upper() + + +def _validate_dawn_root(manifest: dict[str, Any]) -> None: + registration = _find_git_registration(manifest, DAWN_REPOSITORY_URL) + git = _git_component(registration) + if git is None: + raise ValueError("Dawn registration must be a git component") + + expected_commit = _dawn_commit_from_deps_txt() + if git["commitHash"] != expected_commit: + raise ValueError(f"Dawn manifest commit {git['commitHash']} does not match {DEPS_TXT} commit {expected_commit}") + + +def _validate_dxc_release(manifest: dict[str, Any]) -> None: + expected_tag, expected_url, expected_hash = _dxc_release_from_pipeline() + registration = _find_git_registration(manifest, DXC_REPOSITORY_URL, tag=expected_tag) + git = _git_component(registration) + if git is None: + raise ValueError(f"DXC {expected_tag} registration must be a git component") + + comments = registration.get("comments", "") + if expected_url not in comments or expected_hash not in comments: + raise ValueError( + f"DXC {expected_tag} registration comments must contain pipeline URL {expected_url} " + f"and SHA256 {expected_hash}" + ) + + +def _validate_dawn_dependency_roots(manifest: dict[str, Any]) -> None: + dawn_commit = _dawn_commit_from_deps_txt() + + for registration in _registrations(manifest): + comments = registration.get("comments", "") + if not isinstance(comments, str) or "Dawn DEPS" not in comments: + continue + + dependency_roots = registration.get("dependencyRoots") + if not isinstance(dependency_roots, list) or len(dependency_roots) != 1: + raise ValueError(f"Dawn-derived registration is missing one dependencyRoots entry: {comments}") + + root = dependency_roots[0] + if not isinstance(root, dict): + raise ValueError(f"Dawn dependency root must be an object: {comments}") + + root_git = root.get("git") + if root.get("type") != "git" or not isinstance(root_git, dict): + raise ValueError(f"Dawn dependency root must be a git component: {comments}") + if root_git.get("repositoryUrl") != DAWN_REPOSITORY_URL or root_git.get("commitHash") != dawn_commit: + raise ValueError(f"Dawn dependency root does not match {DAWN_REPOSITORY_URL}@{dawn_commit}: {comments}") + + +def main() -> int: + try: + manifest = _load_manifest() + _validate_dawn_root(manifest) + _validate_dxc_release(manifest) + _validate_dawn_dependency_roots(manifest) + except (OSError, ValueError) as ex: + print(f"ERROR: {ex}", file=sys.stderr) + return 1 + + print(f"Validated {WEBGPU_CGMANIFEST}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 1e80c291fe1d2e4296d1e19da4d373692f377341 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 21 May 2026 20:10:01 -0700 Subject: [PATCH 05/16] =?UTF-8?q?Validate=20seqlens=5Fk=20against=20cos=5F?= =?UTF-8?q?cache=20bounds=20in=20GroupQueryAttention=20to=E2=80=A6=20(#282?= =?UTF-8?q?77)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Validate `seqlens_k` values against `cos_cache.shape[0]` in `GroupQueryAttention::Compute()` when `do_rotary` is enabled, to prevent out-of-bounds reads in the rotary embedding lookup. ### Root Cause `CheckRotaryCaches()` validates `cos_cache.shape[0] >= total_sequence_length`, but runtime position IDs are derived from `seqlens_k` (a separate, per-batch input). An attacker can set `total_sequence_length` small enough to pass the guard while setting `seqlens_k[b]` far beyond `cos_cache.shape[0]`, causing `position_id = seqlens_k[b]` to index out of bounds into the cos/sin cache. The resulting heap bytes are used as rotation values and propagate into the inference output. ### Fix Add an explicit bounds check in `Compute()` that rejects any `seqlens_k[b] >= cos_cache.shape[0]` before position IDs are computed. This is defense-in-depth alongside the existing `RunRotaryEmbedding` position_ids validation added in #27597. ### Security - **Impact:** Heap OOB read (CWE-125) — adjacent heap memory leaks into inference output via cos/sin rotation values. - **Attack vector:** Any GQA-based LLM serving endpoint (Llama, Phi, Mistral) that accepts `seqlens_k` as an inference input. No model modification required. ### Testing Verified that crafted inputs with `seqlens_k` exceeding `cos_cache` dimensions now return `INVALID_ARGUMENT` instead of silently producing results containing leaked heap data. --- .../cpu/bert/group_query_attention.cc | 4 + .../cpu/bert/group_query_attention_helper.h | 16 ++ .../group_query_attention_op_test.cc | 169 ++++++++++++++++++ 3 files changed, 189 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index ea0049c28e31b..4df5f6a349599 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -194,6 +194,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (do_rotary_) { // When kv_sequence_length == 0 (shared KV), only Q needs RoPE — K is skipped below. ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true"); + // Validation of seqlens_k against rotary cache size is performed in CheckInputs() + // when seqlens_k is on CPU. GPU EPs where seqlens_k resides on device rely on + // RunRotaryEmbedding's position_ids validation for OOB protection. + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index ed910e3510fed..3429ca5f5be52 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -310,6 +310,22 @@ Status CheckInputs(const T* query, int rotary_dim = 0; if (cos_cache != nullptr && sin_cache != nullptr) { ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim)); + + // Validate seqlens_k against rotary cache size when rotary embeddings are enabled. + // This prevents OOB access when deriving position IDs from seqlens_k during rotary embedding. + const bool is_seqlens_k_on_cpu = (seqlens_k->Location().device.Type() == OrtDevice::CPU); + if (is_seqlens_k_on_cpu) { + const int64_t rotary_cache_max_seq = std::min(cos_cache->Shape().GetDims()[0], + sin_cache->Shape().GetDims()[0]); + const int32_t* seqlens_k_data = seqlens_k->template Data(); + for (int b = 0; b < batch_size; b++) { + if (seqlens_k_data[b] < 0 || static_cast(seqlens_k_data[b]) >= rotary_cache_max_seq) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k[", b, "] = ", seqlens_k_data[b], + " is out of range for rotary cache dimension 0 (", rotary_cache_max_seq, ")"); + } + } + } } else if (cos_cache != nullptr || sin_cache != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index a6dd471ce639f..112d6f1eecc72 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -1578,5 +1578,174 @@ TEST(GroupQueryAttentionTest, QuantizedKV_MissingScale) { {}, nullptr, &execution_providers); } +// Regression: seqlens_k valid for KV cache but exceeding cos_cache.shape[0] must be rejected +// when do_rotary is enabled. Without this check, the position ID derived from seqlens_k +// would index out of bounds in the cos/sin cache, leaking heap memory into output. +TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_OOB) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; // cos/sin cache dim-1 = 8 + + constexpr int cos_cache_max_seq = 4; // small rotary cache + constexpr int past_seq_len = 16; // large KV cache + constexpr int seqlens_k_val = 10; // valid for KV (10 < 16) but OOB for cos (10 >= 4) + constexpr int total_seq_len = 4; // passes CheckRotaryCaches (4 <= cos_cache_max_seq) + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {1, 1, hidden_size}, std::vector(hidden_size, 1.0f)); + tester.AddInput("key", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + tester.AddInput("value", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + + // Past KV cache is large enough for seqlens_k=10 + tester.AddInput("past_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + + tester.AddInput("seqlens_k", {1}, {seqlens_k_val}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + // cos/sin cache with only 4 rows — seqlens_k=10 exceeds this + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {1, 1, hidden_size}, std::vector(hidden_size, 0.0f)); + tester.AddOutput("present_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, "is out of range for rotary cache dimension 0", + {}, nullptr, &execution_providers); +} + +// Positive test: seqlens_k within cos/sin cache bounds with do_rotary enabled should succeed. +TEST(GroupQueryAttentionTest, SeqlensKWithinCosCache_Rotary) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; + + constexpr int cos_cache_max_seq = 16; // rotary cache large enough + constexpr int past_seq_len = 16; + constexpr int seqlens_k_val = 3; // valid: 3 < 16 (cos cache) and 3 < 16 (KV cache) + constexpr int total_seq_len = 4; // seqlens_k + 1 + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {1, 1, hidden_size}, std::vector(hidden_size, 1.0f)); + tester.AddInput("key", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + tester.AddInput("value", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + + tester.AddInput("past_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + + tester.AddInput("seqlens_k", {1}, {seqlens_k_val}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {1, 1, hidden_size}, std::vector(hidden_size, 0.0f)); + tester.AddOutput("present_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + + tester.SetOutputTolerance(1e6f); // shape acceptance test, not numerical correctness + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", + {}, nullptr, &execution_providers); +} + +// Multi-batch test: one valid and one OOB seqlens_k value. +// Verifies the validation loop correctly identifies the offending batch index. +TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; + + constexpr int cos_cache_max_seq = 4; + constexpr int past_seq_len = 16; + constexpr int total_seq_len = 4; + constexpr int batch_size = 2; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {batch_size, 1, hidden_size}, + std::vector(batch_size * hidden_size, 1.0f)); + tester.AddInput("key", {batch_size, 1, kv_hidden_size}, + std::vector(batch_size * kv_hidden_size, 1.0f)); + tester.AddInput("value", {batch_size, 1, kv_hidden_size}, + std::vector(batch_size * kv_hidden_size, 1.0f)); + + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f)); + + // seqlens_k: batch 0 is valid (3 < 4), batch 1 is OOB (10 >= 4) + tester.AddInput("seqlens_k", {batch_size}, {3, 10}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {batch_size, 1, hidden_size}, + std::vector(batch_size * hidden_size, 0.0f)); + tester.AddOutput("present_key", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + // Error should reference batch index 1: seqlens_k[1] = 10 + tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k[1] = 10", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime From e3655b3583310366fa15c3bd6bd21ad55dbb657f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 21 May 2026 23:46:22 -0700 Subject: [PATCH 06/16] Parallelize CPU ScatterElements kernel via ThreadPool (#28588) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Parallelizes both `GetIndices` and `ScatterData` in the CPU `ScatterElements` implementation using `ThreadPool::TryParallelFor`. **Key insight**: For ScatterElements with `axis=a`, work units identified by coordinates orthogonal to the axis (`outer_size × inner_size`) are guaranteed to write to disjoint output elements—even with reductions. This enables lock-free parallelization without correctness concerns. Changes: - **`GetIndices`**: Index validation/normalization parallelized over the flat index array - **`ScatterData`**: Rewritten to decompose into `outer_size * inner_size` independent work units, each processing `axis_size` sequential scatter operations along the axis dimension - Thread pool plumbed through `ScatterDataDispatchTarget` from `OpKernelContext::GetOperatorThreadPool()` - Training `GatherElementsGradImpl` passes `nullptr` (sequential fallback preserved) For the reported workload (`axis=0`, indices shape `[481385, 80]`): 80 independent parallel streams, each processing 481385 elements—well-suited for multi-core execution. ### Motivation and Context The CPU `ScatterElements` kernel was entirely sequential—single-threaded index conversion followed by single-threaded scatter—yielding ~761ms on a 24-core ARM system for a workload that an optimized parallel implementation handles in ~6ms (129× gap). The kernel showed zero intra-op thread utilization in ORT profiling. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- .../core/providers/cpu/tensor/scatter.cc | 231 +++++++++++------- 1 file changed, 138 insertions(+), 93 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index c7a2005924836..5b5011d2a0814 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scatter +#include #include #include @@ -10,6 +11,7 @@ #include "core/framework/element_type_lists.h" #include "core/framework/op_kernel.h" #include "core/framework/op_kernel_type_control_utils.h" +#include "core/platform/threadpool.h" #include "core/providers/common.h" #include "core/providers/op_kernel_type_control.h" #if defined(ENABLE_TRAINING_OPS) @@ -236,29 +238,44 @@ struct Func_Max { template Status GetIndices( const Tensor& data_input, const Tensor& indices_input, int64_t axis, + concurrency::ThreadPool* tp, std::vector& indices_data) { const auto& input_data_shape = data_input.Shape(); const auto* indices_data_raw = indices_input.Data(); const auto num_indices = indices_input.Shape().Size(); const auto axis_dim_limit = input_data_shape[narrow(axis)]; - std::vector indices_data_result; - indices_data_result.reserve(narrow(num_indices)); - - for (int64_t i = 0; i < num_indices; ++i) { - const int64_t idx = static_cast(indices_data_raw[i]); - - if (idx < -axis_dim_limit || idx >= axis_dim_limit) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices element out of data bounds, idx=", idx, - " must be within the inclusive range [", -axis_dim_limit, - ",", axis_dim_limit - 1, "]"); - } - - indices_data_result.push_back(idx < 0 ? idx + axis_dim_limit : idx); + indices_data.resize(narrow(num_indices)); + + // When multiple indices are out-of-bounds, the reported index is nondeterministic + // (whichever thread wins the CAS). This is acceptable—we only need to report that + // validation failed and provide one example of a bad index. + std::atomic found_error{false}; + std::atomic first_bad_idx{0}; + + concurrency::ThreadPool::TryParallelFor( + tp, narrow(num_indices), 1.0, + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + for (std::ptrdiff_t i = first; i < last; ++i) { + const int64_t idx = static_cast(indices_data_raw[i]); + if (idx < -axis_dim_limit || idx >= axis_dim_limit) { + bool expected = false; + if (found_error.compare_exchange_strong(expected, true)) { + first_bad_idx.store(idx, std::memory_order_relaxed); + } + return; + } + indices_data[narrow(i)] = idx < 0 ? idx + axis_dim_limit : idx; + } + }); + + if (found_error.load()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices element out of data bounds, idx=", first_bad_idx.load(), + " must be within the inclusive range [", -axis_dim_limit, + ",", axis_dim_limit - 1, "]"); } - indices_data = std::move(indices_data_result); return Status::OK(); } @@ -266,6 +283,7 @@ template Status ScatterData( const FuncT& func, const Tensor* data_input, const std::vector& indices_data, const Tensor* updates_input, int64_t axis, + concurrency::ThreadPool* tp, Tensor* data_output) { const TensorShape& input_data_shape = data_input->Shape(); @@ -296,103 +314,129 @@ Status ScatterData( const auto num_dims = input_data_shape.NumDimensions(); ORT_RETURN_IF_NOT(num_dims > 0, "ScatterElements op: input tensor must have at least one dimension"); - // Allocate and zero out counts. The input/output is of the same rank as - // indices/updates but the actual dimensions of indices/updates must be less or equal - // than that of input/output because we can update no more elements than - // the input contains. As we walk through the indices/updates - // we maintain dimension count as we will need to use it - // to compute output offset but using input/output dim values. - // We treat the whole array as a number where each element having - // different cardinality according to the upd_shape dimensions. - // As each counter reaches its max (upd_shape) it resets to zero - // and we carry to the more significant dim (right to left) - std::vector dim_counters(num_dims); - - // This vector contains number of elements under the dimension. - // For example, for the dimensions of [4, 2, 3] the vector - // would contain [6, 3, 1] since for each count of dim 1 it - // contains 3 elements of dim 2. - // For each count of dim 0 we would have 2x3=6 elements. - // The last value is always 1. - // We use it to compute output element offset. For a given value of - // counters we multiple each counter value per corresponding entry of dim_block_size value - // and add up resulting the output element offset. However, for dimensions - // that are equal to the specified axis value we take indices_data[index] - // instead of the counter value. - // E.g. for 3-dim and axis=0 - // output[indices[i][j][k]][j][k] = updates[i][j][k] - // for axis 1 - // output[i][indices[i][j][k]][k] = updates[i][j][k] - // and so on - std::vector dim_block_size(num_dims); - - dim_block_size.back() = 1; + if (num_indices == 0) { + return Status::OK(); + } + + const auto* update_data = static_cast(updates_input->DataRaw()); + + // Compute outer_size (product of dims before axis) and inner_size (product of dims after axis). + // For ScatterElements with axis=a: + // output[i0]...[indices[i0..iN]][...][iN] = updates[i0][...][iN] + // Work units identified by (outer_idx, inner_idx) are completely independent: + // they never write to the same output element, even with reductions. + // This allows safe parallelization over outer_size * inner_size work units. + int64_t outer_size = 1; + for (int64_t i = 0; i < axis; ++i) { + outer_size *= upd_shape[narrow(i)]; + } + const int64_t axis_size = upd_shape[narrow(axis)]; + int64_t inner_size = 1; + for (size_t i = narrow(axis) + 1; i < num_dims; ++i) { + inner_size *= upd_shape[i]; + } + + // Compute strides for the input/output tensor + std::vector input_strides(num_dims); + input_strides.back() = 1; if (num_dims > 1) { - // We start at num_dims - 2 because we already pre-populated - // the last element above for (auto i = int64_t(num_dims - 2); i >= 0; --i) { - dim_block_size[narrow(i)] = input_data_shape[SafeInt(i) + 1] * dim_block_size[SafeInt(i) + 1]; + input_strides[narrow(i)] = input_data_shape[SafeInt(i) + 1] * input_strides[SafeInt(i) + 1]; } } - const auto* update_data = static_cast(updates_input->DataRaw()); - // For every update we compute the destination offset and copy it there - for (int64_t index = 0; index < num_indices;) { - const auto axis_idx = indices_data[narrow(index)]; - - // Compute the offset - // See comments above for dim_block_size - size_t dst_offset = 0; - for (size_t i = 0; i < num_dims; ++i) { - if (i == size_t(axis)) { - // replace the counter with the update index for this dim - dst_offset += narrow(axis_idx * dim_block_size[narrow(i)]); - } else { - dst_offset += narrow(dim_counters[narrow(i)] * dim_block_size[narrow(i)]); - } + // Compute strides for the updates/indices tensor + std::vector upd_strides(num_dims); + upd_strides.back() = 1; + if (num_dims > 1) { + for (auto i = int64_t(num_dims - 2); i >= 0; --i) { + upd_strides[narrow(i)] = upd_shape[SafeInt(i) + 1] * upd_strides[SafeInt(i) + 1]; } + } - func(dst_base + dst_offset, update_data + index); + const int64_t total_work_units = outer_size * inner_size; + const int64_t input_axis_stride = input_strides[narrow(axis)]; + const int64_t upd_axis_stride = upd_strides[narrow(axis)]; + + // Parallelize over independent work units. + // Each work unit processes axis_size elements along the scatter axis. + // Cost per unit is proportional to axis_size (number of scatter ops per work unit). + concurrency::ThreadPool::TryParallelFor( + tp, narrow(total_work_units), static_cast(axis_size), + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + for (std::ptrdiff_t work_idx = first; work_idx < last; ++work_idx) { + // Decompose work_idx into outer_idx and inner_idx + const int64_t outer_idx = static_cast(work_idx) / inner_size; + const int64_t inner_idx = static_cast(work_idx) % inner_size; + + // Compute the base offset in the output for dimensions outside the axis. + // For dims before axis: determined by outer_idx + // For dims after axis: determined by inner_idx + int64_t dst_base_offset = 0; + int64_t outer_remain = outer_idx; + for (int64_t d = axis - 1; d >= 0; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = outer_remain % dim_size; + outer_remain /= dim_size; + dst_base_offset += coord * input_strides[narrow(d)]; + } + int64_t inner_remain = inner_idx; + for (int64_t d = int64_t(num_dims) - 1; d > axis; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = inner_remain % dim_size; + inner_remain /= dim_size; + dst_base_offset += coord * input_strides[narrow(d)]; + } + + // Compute the base index into the updates/indices flat array + int64_t upd_base_offset = 0; + outer_remain = outer_idx; + for (int64_t d = axis - 1; d >= 0; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = outer_remain % dim_size; + outer_remain /= dim_size; + upd_base_offset += coord * upd_strides[narrow(d)]; + } + inner_remain = inner_idx; + for (int64_t d = int64_t(num_dims) - 1; d > axis; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = inner_remain % dim_size; + inner_remain /= dim_size; + upd_base_offset += coord * upd_strides[narrow(d)]; + } + + // Process axis_size elements along the axis + for (int64_t a = 0; a < axis_size; ++a) { + const int64_t upd_flat_idx = upd_base_offset + a * upd_axis_stride; + const int64_t axis_idx = indices_data[narrow(upd_flat_idx)]; + const int64_t dst_offset = dst_base_offset + axis_idx * input_axis_stride; + func(dst_base + dst_offset, update_data + upd_flat_idx); + } + } + }); - if (++index == num_indices) { - break; - } - // Increment counters - // See comments for dim_counters above - for (auto i = int64_t(num_dims - 1); i >= 0; --i) { - auto v = ++dim_counters[narrow(i)]; - assert(v <= upd_shape[narrow(i)]); - if (v < upd_shape[narrow(i)]) { - // No carry, done - break; - } - // No carry for the most significant dim - assert(i > 0); - dim_counters[narrow(i)] = 0; - } - } return Status::OK(); } template struct ScatterDataDispatchTarget { Status operator()(const Tensor* data_input, const std::vector& indices_data, const Tensor* updates_input, int64_t axis, - const std::string& reduction, Tensor* data_output) const { + const std::string& reduction, concurrency::ThreadPool* tp, Tensor* data_output) const { if (reduction == "add") return ScatterData( - Func_Add(), data_input, indices_data, updates_input, axis, data_output); + Func_Add(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "mul") return ScatterData( - Func_Mul(), data_input, indices_data, updates_input, axis, data_output); + Func_Mul(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "min") return ScatterData( - Func_Min(), data_input, indices_data, updates_input, axis, data_output); + Func_Min(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "max") return ScatterData( - Func_Max(), data_input, indices_data, updates_input, axis, data_output); + Func_Max(), data_input, indices_data, updates_input, axis, tp, data_output); else // if (reduction == "none") return ScatterData( - Func_Assignment(), data_input, indices_data, updates_input, axis, data_output); + Func_Assignment(), data_input, indices_data, updates_input, axis, tp, data_output); } }; @@ -444,11 +488,12 @@ Status Scatter::Compute(OpKernelContext* context) const { Status status{}; const auto index_type = indices_input->GetElementType(); std::vector indices_data{}; + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); if (index_type == utils::ToTensorProtoElementType()) { - status = GetIndices(*data_input, *indices_input, axis, indices_data); + status = GetIndices(*data_input, *indices_input, axis, tp, indices_data); } else if (index_type == utils::ToTensorProtoElementType()) { - status = GetIndices(*data_input, *indices_input, axis, indices_data); + status = GetIndices(*data_input, *indices_input, axis, tp, indices_data); } else { status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices type is not supported."); } @@ -462,7 +507,7 @@ Status Scatter::Compute(OpKernelContext* context) const { utils::MLTypeCallDispatcherFromTypeList dispatcher{data_type}; status = dispatcher.template InvokeRet( - data_input, indices_data, updates_input, axis, this->reduction_, data_output); + data_input, indices_data, updates_input, axis, this->reduction_, tp, data_output); return status; } @@ -482,8 +527,8 @@ template Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input, const int64_t axis, Tensor* data_output) { std::vector indices_data{}; - ORT_RETURN_IF_ERROR(GetIndices(*data_output, *indices_input, axis, indices_data)); - return ScatterData(Func_Add(), data_output, indices_data, updates_input, axis, data_output); + ORT_RETURN_IF_ERROR(GetIndices(*data_output, *indices_input, axis, nullptr, indices_data)); + return ScatterData(Func_Add(), data_output, indices_data, updates_input, axis, nullptr, data_output); } #define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \ From b850fcbb11276e6a580879a9de6b9004954b3a20 Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Fri, 22 May 2026 10:36:28 +0200 Subject: [PATCH 07/16] [NVEP]: fix test for multi-gpu situation (#27837) ### Description When a system has multiple Nvidia GPUs, then also multiple EpDevices for the NVEP should be created. ### Motivation and Context Fixes the following test failure on a multi-gpu system ``` [ RUN ] NvExecutionProviderTest.LoadUnloadPluginLibrary /home/stephan/projects/onnxruntime-winai/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc:359: Failure Expected equality of these values: num_test_ep_devices Which is: 2 1 Expected an OrtEpDevice to have been created by the test library. [ FAILED ] NvExecutionProviderTest.LoadUnloadPluginLibrary (0 ms) ``` --- onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index a54c35accbdc7..2e34fd58a2628 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -349,14 +349,14 @@ TEST(NvExecutionProviderTest, LoadUnloadPluginLibrary) { size_t num_devices = 0; ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); - // should be one device for the example EP + // should be at least one device for the example EP auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, [®istration_name, &c_api](const OrtEpDevice* device) { // the example uses the registration name for the EP name // but that is not a requirement and the two can differ. return c_api->EpDevice_EpName(device) == registration_name; }); - ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + ASSERT_GE(num_test_ep_devices, 1) << "Expected at least one OrtEpDevice to have been created by the test library."; // and this should unload it ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, From 43989a71997db9110c0cab633d397278a98c5624 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 22 May 2026 07:56:56 -0700 Subject: [PATCH 08/16] QMoE CUDA: input validation, prepack cleanups, and packaging pipeline fix (#28607) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Follow-up to #28583. Addresses review feedback that landed after merge (input validation, redundant memset, dead branches in `PrePackComputeBias`) and fixes a pre-existing latent CUTLASS issue that surfaced as a packaging pipeline failure once MoE GEMM kernels were built with a multi-arch `CMAKE_CUDA_ARCHITECTURES` list spanning pre-Ampere and Ampere+ targets. ## Summary of Changes ### Packaging pipeline build fix | File | Change | |------|--------| | `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h` | Replace the unconditional `static_assert(false, ...)` in the pre-Ampere `#else` branch of `MoeFCGemm::operator()` with `CUTLASS_NOT_IMPLEMENTED()` plus a comment explaining why this is safe. | Background: `moe_gemm_kernels_*.cu` instantiate `MoeFCGemm` through `MoeGemmRunner<...>::dispatchToArch`, which contains *runtime* (not `constexpr`) `if (sm_ >= 80 && sm_ < 90)` branches. NVCC therefore instantiates the kernel for every requested device target, including pre-Sm80 device compile passes. The old `static_assert(false, ...)` fired on those passes whenever `CMAKE_CUDA_ARCHITECTURES` contained any arch below 80 (e.g. the packaging pipeline list `52-real;61-real;75-real;86-real;89-real;90-virtual`). Replacing it with `CUTLASS_NOT_IMPLEMENTED()` lets NVCC emit a runtime trap stub for pre-Sm80, while runtime dispatch in `MoeGemmRunner::dispatchToArch()` already guarantees `sm_ >= 80` before the kernel is ever launched, so the stub is unreachable in practice. ### Address PR #28583 post-merge review | File | Change | |------|--------| | `onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu` | Add `ValidateScaledZP4BitBatchedArgs` (positive `experts`/`n`/`k_blocks`, `experts ≤ 65535` for the `gridDim.z` limit) and call it from both `LaunchQMoEScaledZP4BitBatched` overloads. Matches the validation style of `LaunchQMoERepackFP4ColToRow`. | | `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc` (`PrePackSwizzleBlockScales`) | Remove the redundant `cudaMemsetAsync` of the destination buffer. `QMoEBlockScaleInterleaveKernel`'s `(batch, row, col) -> offset` map is a bijection over the padded output extent and writes 0 for padded source positions, so every output byte is already written. Comment explains the invariant. | | `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc` (`PrePackComputeBias`, 4-bit block-wise) | Add `ORT_ENFORCE` checks for positive shape dims and an `INT_MAX/2` bound on `packed_k_blocks` (parity with `PrePackSwizzleBlockScales` / `PrePackRepackFP4Weights`). Drop the shadowed `bool is_fp16 = is_fp16_; bool is_bf16 = !is_fp16_;` locals in favour of `is_fp16_`. Replace the dead-branch ternary `(is_fp16 \|\| is_bf16 ? 2 : 4)` with `sizeof(uint16_t)` and a clarifying comment, and remove the unreachable `else ORT_THROW(...)` (the QMoE type path is strictly FP16/BF16). | ## Testing - Built locally with CUDA 12.8 against the failing CI arch list (`-DCMAKE_CUDA_ARCHITECTURES="52-real;61-real;75-real;86-real;89-real;90-virtual"`) and confirmed `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemm_kernels_bf16_bf16.cu.o` compiles cleanly (only an `sm_<75` deprecation warning, no `static_assert` failure). - Existing QMoE Python tests (`onnxruntime/test/python/transformers/test_qmoe_cuda.py`, `test_qmoe_cpu.py`) exercise the affected `PrePackSwizzleBlockScales` / `PrePackComputeBias` paths under `--config Debug` builds and continue to pass; the added `ORT_ENFORCE` checks only trigger on invalid shapes that are not produced by the supported QMoE input contract. - No behaviour change on supported devices: `dispatchToArch` already gates `MoeFCGemm` behind `sm_ >= 80`, so the new `CUTLASS_NOT_IMPLEMENTED()` stub is unreachable at runtime. ## Motivation and Context Once #28583 enabled the MoE GEMM kernels as part of the contrib CUDA build, packaging pipelines (which target a wide arch range to maximise GPU coverage) started failing on the pre-Ampere device compile passes. The kernel-side fix in this PR resolves the immediate breakage while keeping the cmake-level binary-size optimisation (per-kernel arch pinning, TensorRT-LLM style) as a follow-up — CMake's `CUDA_ARCHITECTURES` is target/directory-scoped only, so the proper way to restrict per-kernel archs is an OBJECT-library refactor, which is intentionally not in scope here. ## Checklist - [x] Tests added/updated (input validation covered by existing QMoE tests; the new `ORT_ENFORCE` checks fail loudly on out-of-contract shapes) - [x] No documentation changes needed - [x] No breaking changes - [x] Local packaging-pipeline arch list verified to compile --- .../collective/epilogue_moe_finalize.hpp | 23 ++++++------ .../gemm/kernel/moe_cutlass_kernel.h | 8 +++-- .../launchers/moe_gemm_tma_ws_launcher.inl | 4 ++- .../contrib_ops/cuda/moe/moe_quantization.cc | 27 ++++++++------ .../contrib_ops/cuda/moe/qmoe_kernels.cu | 36 ++++--------------- 5 files changed, 45 insertions(+), 53 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp index cd5c71f83ac27..8ba877aa21a68 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp @@ -171,10 +171,10 @@ class EpilogueMoeFusedFinalize { auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); + constexpr auto mma_tile_m = decltype(tile_size<0>(tiled_mma)){}; + constexpr auto mma_tile_n = decltype(tile_size<1>(tiled_mma)){}; + constexpr auto epi_tile_m = size<0>(EpilogueTile{}); + constexpr auto epi_tile_n = size<1>(EpilogueTile{}); CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); @@ -248,16 +248,17 @@ class EpilogueMoeFusedFinalize { Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) // Make a tiled copy vectorized along major direction of D + constexpr int TiledMmaThreads = decltype(cute::size(tiled_mma))::value; auto tiled_s2r = [&]() { if constexpr (cutlass::gemm::detail::is_k_major()) { constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor; return make_tiled_copy(CopyAtomS2R{}, Layout, Int>, Stride, _1>>{}, Layout>>{}); } else if constexpr (cutlass::gemm::detail::is_mn_major()) { constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor; return make_tiled_copy(CopyAtomS2R{}, Layout, Int>, Stride<_1, Int>>{}, Layout, _1>>{}); @@ -274,11 +275,11 @@ class EpilogueMoeFusedFinalize { Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rBias = make_tensor(shape(tSR_gBias(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rScale = make_tensor(shape(tSR_gScale(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N) // Make an identity coordinate tensor for predicating our output MN tile Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index e28d2b859a2f0..ab8ae054db048 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -586,8 +586,12 @@ struct MoeFCGemm { run_kernel(params, shared_storage); } #else - static_assert( - false, "Invalid architecture being compiled. Only Ampere+ supported in weight-only quantization kernels."); + // Pre-Ampere device compile pass: the MoeFCGemm body is unsupported on these archs, + // but NVCC must still emit *some* body for each requested target. Runtime dispatch + // in MoeGemmRunner::dispatchToArch() never invokes this kernel when sm_ < 80, so a + // device-side trap is safe and lets the same .cu compile cleanly under mixed arch + // lists (e.g. 52;61;75;86;89;90 in packaging pipelines). + CUTLASS_NOT_IMPLEMENTED(); #endif #else CUTLASS_NOT_IMPLEMENTED(); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index 46a6bc6388a27..19bb1a0975720 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -77,7 +77,9 @@ ReturnType construct_if_true(Args&&... args) { if constexpr (FLAG) { - return ReturnType{std::forward(args)...}; + // Use parenthesized aggregate init (C++20) instead of brace-init to avoid + // MSVC C2397 narrowing conversion errors (e.g. size_t -> FastDivmod(int)). + return ReturnType(std::forward(args)...); } else { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index 5979f17e5abcc..f6bf5bbb1f0e3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -1113,9 +1113,11 @@ void QMoE::PrePackSwizzleBlockScales(const Tensor& tensor, cudaStream_t stream, p_src = temp_src_gpu.get(); } + // QMoEBlockScaleInterleaveKernel writes every byte of the output buffer + // (the (batch, row, col) -> offset map is a bijection over + // [0, batch_size) x [0, rows_padded) x [0, cols_padded), and padded + // source positions are written as 0), so no explicit memset is required. packed_buf = IAllocator::MakeUniquePtr(alloc, dst_bytes, true); - // Zero-fill for padding regions (kernel only writes within bounds) - CUDA_CALL_THROW(cudaMemsetAsync(packed_buf.get(), 0, dst_bytes, stream)); int multi_processor_count = 0; int device_id = 0; @@ -1250,16 +1252,23 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat return; } - bool is_fp16 = is_fp16_; - bool is_bf16 = !is_fp16_; - ORT_ENFORCE(shape.NumDimensions() == 3, "Expected 3D zeros for block-wise 4-bit"); + ORT_ENFORCE(shape[0] > 0 && shape[1] > 0 && shape[2] > 0, + "4-bit block-wise zeros must have positive dimensions, got ", shape.ToString()); + // packed_k_blocks is doubled to k_blocks below; constrain it to half of INT_MAX to keep the + // doubled value (and the int dims passed into LaunchQMoEScaledZP4BitBatched) within int range. + constexpr int64_t kMaxPackedKBlocks = std::numeric_limits::max() / 2; + ORT_ENFORCE(shape[0] <= std::numeric_limits::max() && + shape[1] <= std::numeric_limits::max() && + shape[2] <= kMaxPackedKBlocks, + "4-bit block-wise zeros dimensions exceed CUDA launch int range, got ", shape.ToString()); const int experts = static_cast(shape[0]); const int n = static_cast(shape[1]); const int packed_k_blocks = static_cast(shape[2]); const int k_blocks = packed_k_blocks * 2; + // QMoE only supports FP16/BF16 inputs (is_fp16_ is set in the ctor), both of which are 2 bytes. size_t output_count = static_cast(experts) * static_cast(k_blocks) * static_cast(n); - size_t bytes = output_count * (is_fp16 || is_bf16 ? 2 : 4); + size_t bytes = output_count * sizeof(uint16_t); packed_bias = IAllocator::MakeUniquePtr(alloc, bytes, true); const void* p_src_zp = tensor.DataRaw(); @@ -1272,20 +1281,18 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat const uint8_t* zp_ptr = static_cast(p_src_zp); constexpr float kDefaultZeroPoint4Bit = 8.0f; - if (is_fp16) { + if (is_fp16_) { LaunchQMoEScaledZP4BitBatched( zp_ptr, static_cast(packed_scale.get()), static_cast(packed_bias.get()), experts, n, k_blocks, kDefaultZeroPoint4Bit, stream); - } else if (is_bf16) { + } else { LaunchQMoEScaledZP4BitBatched( zp_ptr, static_cast(packed_scale.get()), static_cast<__nv_bfloat16*>(packed_bias.get()), experts, n, k_blocks, kDefaultZeroPoint4Bit, stream); - } else { - ORT_THROW("Unsupported type for 4-bit block-wise ZP prepack. Expected FP16/BF16."); } } CUDA_CALL_THROW(cudaStreamSynchronize(stream)); diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu index cd59fd248b3a2..28fd4fb1516fb 100644 --- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu @@ -3,25 +3,20 @@ // Licensed under the MIT License. #include "contrib_ops/cuda/moe/qmoe_kernels.h" +#include "core/common/narrow.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h" #include #include #include #include -#include namespace onnxruntime { namespace contrib { namespace cuda { int Compute1DGridSize(int num_elements, int block_size) { - ORT_ENFORCE(num_elements >= 0, "CUDA launch element count must be non-negative, got ", num_elements); - ORT_ENFORCE(block_size > 0, "CUDA launch block size must be positive, got ", block_size); - int64_t grid_size = (static_cast(num_elements) + block_size - 1) / block_size; - ORT_ENFORCE(grid_size <= std::numeric_limits::max(), - "CUDA launch grid size exceeds int range: ", grid_size); - return static_cast(grid_size); + return (num_elements + block_size - 1) / block_size; } template @@ -698,11 +693,7 @@ void LaunchQMoEDequantizeFp4WeightsImpl( cudaStream_t stream) { int64_t total = static_cast(num_experts) * n * k; constexpr int block = 256; - ORT_ENFORCE(total >= 0, "QMoEDequantizeFp4Weights: negative element count, got ", total); - int64_t grid_i64 = (total + block - 1) / block; - ORT_ENFORCE(grid_i64 <= std::numeric_limits::max(), - "QMoEDequantizeFp4Weights: grid size exceeds int range: ", grid_i64); - int grid = static_cast(grid_i64); + int grid = onnxruntime::narrow((total + block - 1) / block); QMoEDequantizeFp4WeightsKernel<<>>( packed_weights, block_scales, global_scales, output, num_experts, n, k); } @@ -785,11 +776,7 @@ void LaunchQMoEDequantizeFp8WeightsImpl( cudaStream_t stream) { int64_t total = static_cast(num_experts) * n * k; constexpr int block = 256; - ORT_ENFORCE(total >= 0, "QMoEDequantizeFp8Weights: negative element count, got ", total); - int64_t grid_i64 = (total + block - 1) / block; - ORT_ENFORCE(grid_i64 <= std::numeric_limits::max(), - "QMoEDequantizeFp8Weights: grid size exceeds int range: ", grid_i64); - int grid = static_cast(grid_i64); + int grid = onnxruntime::narrow((total + block - 1) / block); QMoEDequantizeFp8WeightsKernel<<>>( weights, global_scales, output, num_experts, n, k); } @@ -862,16 +849,10 @@ void LaunchQMoERepackFP4ColToRow( int64_t k, int64_t n, cudaStream_t stream) { - ORT_ENFORCE(experts > 0, "LaunchQMoERepackFP4ColToRow requires positive expert count, got ", experts); - ORT_ENFORCE(k > 0 && n > 0, "LaunchQMoERepackFP4ColToRow requires positive k and n, got k=", k, ", n=", n); - ORT_ENFORCE(k % 2 == 0 && n % 2 == 0, - "LaunchQMoERepackFP4ColToRow requires even k and n, got k=", k, ", n=", n); const int64_t total = static_cast(experts) * n * (k / 2); constexpr int kThreads = 256; - int64_t blocks = (total + kThreads - 1) / kThreads; - ORT_ENFORCE(blocks <= static_cast(std::numeric_limits::max()), - "LaunchQMoERepackFP4ColToRow grid size exceeds int range"); - QMoERepackFP4ColToRowKernel<<(blocks), kThreads, 0, stream>>>( + int blocks = onnxruntime::narrow((total + kThreads - 1) / kThreads); + QMoERepackFP4ColToRowKernel<<>>( input, output, experts, k, n); } @@ -901,10 +882,7 @@ __global__ void BatchedTransposeKernel(const T* __restrict__ input, T* __restric void LaunchBatchedTranspose(cudaStream_t stream, const void* input, void* output, int batch, int rows, int cols, int element_size) { int64_t total_elements = static_cast(batch) * rows * cols; int threads = 256; - int64_t blocks_i64 = (total_elements + threads - 1) / threads; - ORT_ENFORCE(blocks_i64 <= std::numeric_limits::max(), - "LaunchBatchedTranspose grid size exceeds int range: ", blocks_i64); - int blocks = static_cast(blocks_i64); + int blocks = onnxruntime::narrow((total_elements + threads - 1) / threads); if (element_size == 1) { BatchedTransposeKernel<<>>(static_cast(input), static_cast(output), batch, rows, cols); From 5003d9322706160eae589ce27be210d5d338e5cf Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 22 May 2026 07:57:22 -0700 Subject: [PATCH 09/16] Fix CUDA build with contrib ops disabled (#28554) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description The CUDA Attention kernel (`core/providers/cuda/llm/attention.cc`) depends on contrib_ops internals (flash attention, memory efficient attention, unfused attention helpers) but was compiled unconditionally. When building with `--disable_contrib_ops`, `GetAttentionKernelOptions()` is unavailable (guarded by `#ifndef DISABLE_CONTRIB_OPS` in `cuda_kernel.h`), causing a compile error. Changes: - **`cmake/onnxruntime_providers_cuda.cmake`** — When contrib ops are disabled (and not in CUDA minimal mode), include the `contrib_ops/cuda/bert/` attention infrastructure files (flash attention, memory efficient attention, unfused attention helpers, etc.) so the ONNX domain Attention kernel can compile and link. Uses `elseif(onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL)` to avoid including these files in CUDA minimal builds where `llm/attention.cc` isn't compiled and `cudnn_frontend.h` isn't available. - **`onnxruntime/core/providers/cuda/cuda_execution_provider.h`** — Remove `#ifndef DISABLE_CONTRIB_OPS` guards from the `AttentionKernelOptions` include, `GetAttentionKernelOptions()` method, and `attention_kernel_options_` member variable - **`onnxruntime/core/providers/cuda/cuda_kernel.h`** — Remove `#ifndef DISABLE_CONTRIB_OPS` guard from `GetAttentionKernelOptions()` The CUDA Attention kernel and its underlying attention backends (flash, memory efficient, unfused) are now always available in full CUDA builds regardless of whether contrib ops are enabled. No changes are needed in `cuda_execution_provider.cc` since the Attention kernel registrations remain unconditional. ### Motivation and Context Building onnxruntime with CUDA enabled and `--disable_contrib_ops` fails: ``` error C2039: 'GetAttentionKernelOptions': is not a member of 'onnxruntime::cuda::Attention' ``` This is a valid build configuration (useful for reducing compile time) that should be supported. Rather than excluding the CUDA Attention kernel when contrib ops are disabled, the necessary attention infrastructure from `contrib_ops/cuda/bert/` is included in the build so the ONNX domain Attention op retains full CUDA acceleration. The fix is scoped to non-minimal CUDA builds only, since CUDA minimal builds use a non-recursive glob that doesn't include `llm/attention.cc` and don't have `cudnn_frontend` available. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- cmake/onnxruntime_providers_cuda.cmake | 11 +++++++++++ .../core/providers/cuda/cuda_execution_provider.h | 7 ------- onnxruntime/core/providers/cuda/cuda_kernel.h | 2 -- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f3c2d8b947968..b28c35fd502ed 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -93,6 +93,17 @@ endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) + elseif(onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL) + # The ONNX domain CUDA Attention kernel (core/providers/cuda/llm/attention.cc) depends on + # attention infrastructure in contrib_ops/cuda/bert/ (flash attention, memory efficient + # attention, unfused attention helpers, etc.). Include the bert attention infrastructure + # even when contrib ops are disabled so that the ONNX Attention kernel can compile and link. + set(onnxruntime_cuda_bert_cc_srcs ${onnxruntime_cuda_contrib_ops_cc_srcs}) + list(FILTER onnxruntime_cuda_bert_cc_srcs INCLUDE REGEX ".*/contrib_ops/cuda/bert/.*") + set(onnxruntime_cuda_bert_cu_srcs ${onnxruntime_cuda_contrib_ops_cu_srcs}) + list(FILTER onnxruntime_cuda_bert_cu_srcs INCLUDE REGEX ".*/contrib_ops/cuda/bert/.*") + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_bert_cc_srcs} ${onnxruntime_cuda_bert_cu_srcs}) + list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_bert_cc_srcs} ${onnxruntime_cuda_bert_cu_srcs}) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 1b2e8494f5f99..537c14fd2b3b3 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -16,10 +16,7 @@ #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/tunable/cuda_tuning_context.h" - -#ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/cuda/bert/attention_kernel_options.h" -#endif namespace onnxruntime { @@ -91,13 +88,11 @@ class CUDAExecutionProvider : public IExecutionProvider { bool IsFuseConvBias() const { return info_.fuse_conv_bias; } bool UseTF32() const { return info_.use_tf32; } -#ifndef DISABLE_CONTRIB_OPS // Attention kernel options parsed from sdpa_kernel cuda provider option. const AttentionKernelOptions* GetAttentionKernelOptions() const { attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true, true); return &attention_kernel_options_; } -#endif ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); @@ -143,10 +138,8 @@ class CUDAExecutionProvider : public IExecutionProvider { // the tuning context might be altered when calling into a TunableOp mutable cuda::tunable::CudaTuningContext tuning_context_; -#ifndef DISABLE_CONTRIB_OPS // Attention kernel options parsed from sdpa_kernel cuda provider option. mutable AttentionKernelOptions attention_kernel_options_; -#endif class PerThreadContext final { public: diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 13bf5b37490e0..1d891f204b9bd 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -172,11 +172,9 @@ class CudaKernel : public OpKernel { return provider_->UseTF32(); } -#ifndef DISABLE_CONTRIB_OPS const AttentionKernelOptions* GetAttentionKernelOptions() const { return provider_->GetAttentionKernelOptions(); } -#endif tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); From d2836a8f9a0814fd882c79cc91667c390537dad6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 22 May 2026 09:13:07 -0700 Subject: [PATCH 10/16] Optimize MLAS quantized KV-cache GEMM kernels (follow-up to #28578) (#28606) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Follow-up performance and correctness improvements to the MLAS quantized KV-cache GEMM kernels introduced in #28578. These changes target the AVX2, AVX512-VNNI, and NEON kernel files only. ### Changes 1. **Use embedded rounding in `QuantizeRowToU8` (AVX-512)** Replace `_mm512_roundscale_ps` + `_mm512_cvtps_epi32` with a single `_mm512_cvt_roundps_epi32` that combines round-to-nearest-even and float-to-int32 in one instruction, saving a `vrndscaleps` per loop iteration. 2. **Use int32 zero-point correction in VNNI dot products** Perform the `dot - 128*sum(b)` zero-point correction in int32 before converting to float. This avoids precision loss when operands exceed 2^24 (where float32 loses integer precision), preventing potential catastrophic cancellation. 3. **Defer per-tensor scale in `FusedDotInt8` (AVX2 + AVX-512)** Factor the constant per-tensor scale out of the inner loop: `sum(a*b*s) = s * sum(a*b)`. Saves one `vmulps` per 8/16 elements in the hot path. 4. **Defer per-tensor scale in SVGemm and NEON dequantization** - AVX2/AVX-512 `SVGemm`: accumulate unscaled dot products, multiply the output row by the per-tensor scale once after the K loop. - NEON: parameterize `DequantRow_Neon` with `apply_per_tensor_scale` to skip per-element scaling during dequantization when using per-tensor mode; scale the output row once after accumulation. - Also: clarify AVX2 INT4 nibble extraction comment and use `uint32_t` for the raw packed load. ### Motivation The per-tensor quantization paths were previously applying a constant scale factor on every element inside hot loops. By deferring the scalar multiplication to after accumulation (using the distributive property), we reduce instruction count in the inner loops without changing numerical results (within normal FP reordering tolerance). The int32 zero-point correction fix addresses a latent precision issue in AVX512-VNNI paths that could manifest at large K dimensions (K > ~512). ### Testing - `onnxruntime_mlas_test --gtest_filter=KVQuant.*` passes (Debug build, x86-64). - No new tests needed — existing `KVQuant.ShortExecute` exercises all modified code paths across INT8/INT4 per-tensor/per-channel modes. ### Benchmark Results Measured on Intel Xeon Platinum 8370C (8 cores, 16 threads, AVX-512 + VNNI), Release build. Each benchmark uses `--benchmark_min_time=0.3s --benchmark_repetitions=5`. **QKGemm (query × K_cache^T) — INT8 per-tensor (S8_PerTensor, QuantType:0)** This is the path most improved by the deferred-scale optimization (changes 3 and 4). | Shape | Before (ns) | After (ns) | Speedup | |---|---:|---:|---:| | M=1, N=512, K=64 | 2,926 | 2,803 | 1.04x | | M=1, N=512, K=128 | 5,914 | 5,074 | **1.17x** | | M=1, N=2048, K=128 | 22,401 | 19,937 | **1.12x** | | M=128, N=512, K=64 | 412,505 | 304,230 | **1.36x** | | M=128, N=512, K=128 | 911,508 | 788,198 | **1.16x** | | M=128, N=2048, K=64 | 1,662,547 | 1,242,441 | **1.34x** | | M=128, N=2048, K=128 | 3,660,599 | 3,176,911 | **1.15x** | **SVGemm (attn_probs × V_cache) — INT8 per-tensor (S8_PerTensor, QuantType:0)** | Shape | Before (ns) | After (ns) | Speedup | |---|---:|---:|---:| | M=1, N=64, K=512 | 4,707 | 4,122 | **1.14x** | | M=1, N=64, K=2048 | 18,516 | 16,533 | **1.12x** | | M=128, N=64, K=512 | 399,703 | 358,821 | **1.11x** | | M=128, N=64, K=2048 | 1,633,807 | 1,423,984 | **1.15x** | | M=128, N=128, K=512 | 775,205 | 761,527 | 1.02x | | M=128, N=128, K=2048 | 3,086,642 | 2,979,566 | **1.04x** | **Other quant types (S8_PerChannel, S4_PerTensor, S4_PerChannel) — neutral** Per-channel and INT4 paths are not affected by the deferred-scale optimization. Representative M=128 results: | Benchmark | QuantType | Before (ns) | After (ns) | Ratio | |---|---|---:|---:|---:| | QKGemm M=128, N=2048, K=128 | S8_PerChannel | 4,555,381 | 4,684,954 | 0.97x | | QKGemm M=128, N=2048, K=128 | S4_PerTensor | 3,841,759 | 3,819,387 | 1.01x | | QKGemm M=128, N=2048, K=128 | S4_PerChannel | 4,043,262 | 4,056,033 | 1.00x | | SVGemm M=128, N=128, K=2048 | S8_PerChannel | 4,449,839 | 4,290,344 | **1.04x** | | SVGemm M=128, N=128, K=2048 | S4_PerTensor | 2,989,684 | 2,998,154 | 1.00x | | SVGemm M=128, N=128, K=2048 | S4_PerChannel | 3,403,497 | 3,390,452 | 1.00x | **Summary**: The INT8 per-tensor paths (the most common decode configuration) see **12–36% QKGemm speedup** and **4–15% SVGemm speedup** at representative shapes. Other quantization modes are neutral within noise (±1–3%). --- .../core/mlas/lib/qkv_quant_kernel_avx2.cpp | 43 ++++++++----- .../mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 63 ++++++++++++------- .../core/mlas/lib/qkv_quant_kernel_neon.cpp | 57 ++++++++++++----- 3 files changed, 113 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index d3681ff6bfdff..8bec2d350afa5 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -42,16 +42,16 @@ DequantInt4x8(const uint8_t* src, size_t col, bool per_channel, const float* sca // Load 4 packed bytes safely without strict-aliasing / alignment UB. // Compilers optimize memcpy of 4 bytes to a single mov instruction. - int raw_bytes; + uint32_t raw_bytes; std::memcpy(&raw_bytes, base, sizeof(raw_bytes)); - __m128i packed = _mm_cvtsi32_si128(raw_bytes); + __m128i packed = _mm_cvtsi32_si128(static_cast(raw_bytes)); // Low nibbles (even columns): AND with 0x0F __m128i lo_mask = _mm_set1_epi8(0x0F); __m128i lo = _mm_and_si128(packed, lo_mask); - // High nibbles (odd columns): shift right 4 using 32-bit granularity - // to prevent bit bleeding across 16-bit lane boundaries, then mask. + // High nibbles (odd columns): shift right by 4 within 32-bit lanes, then mask. + // Any cross-byte bits from the shift land in the upper nibble and are discarded by the mask. __m128i hi = _mm_and_si128(_mm_srli_epi32(packed, 4), lo_mask); // Interleave low and high nibbles: [lo0,hi0, lo1,hi1, lo2,hi2, lo3,hi3] @@ -126,19 +126,19 @@ FusedDotInt8( acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } } else { - __m256 scale_vec = _mm256_broadcast_ss(scales); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 8 elements in the hot loop. for (; k < vec_end; k += 16) { __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadl_epi64(reinterpret_cast(b_row + k + 8)); __m256i i32_1 = _mm256_cvtepi8_epi32(raw1); __m256 bf1 = _mm256_cvtepi32_ps(i32_1); - bf1 = _mm256_mul_ps(bf1, scale_vec); __m256 a1 = _mm256_loadu_ps(a_row + k + 8); acc1 = _mm256_fmadd_ps(a1, bf1, acc1); } @@ -146,7 +146,6 @@ FusedDotInt8( __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } @@ -161,9 +160,15 @@ FusedDotInt8( float dot = _mm_cvtss_f32(sum4); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } @@ -326,7 +331,7 @@ SVGemm_Avx2( } } } else { - __m256 scale_vec = _mm256_broadcast_ss(Scales); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -337,15 +342,25 @@ SVGemm_Avx2( __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); __m256i i32 = _mm256_cvtepi8_epi32(raw); __m256 bf = _mm256_cvtepi32_ps(i32); - bf = _mm256_mul_ps(bf, scale_vec); __m256 c_vec = _mm256_loadu_ps(c_row + n); c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); _mm256_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m256 scale_vec = _mm256_broadcast_ss(Scales); + n = 0; + for (; n < vec_end_n; n += 8) { + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_mul_ps(c_vec, scale_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 fused path diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index ac23a0703ddff..fa5aff0165897 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -83,19 +83,21 @@ QuantizeRowToU8(const float* src, uint8_t* dst, size_t len) i = 0; for (; i < vec_end; i += 16) { __m512 v = _mm512_loadu_ps(src + i); - // q = round(v * inv_scale) + 128, clamped to [0, 255] + // q = (v * inv_scale) + 128, clamped to [0, 255] __m512 scaled = _mm512_fmadd_ps(v, inv_scale_vec, zp_vec); - scaled = _mm512_roundscale_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); scaled = _mm512_max_ps(scaled, min_val); scaled = _mm512_min_ps(scaled, max_clamp); - __m512i qi = _mm512_cvtps_epi32(scaled); + // Round-to-nearest-even and convert to int32 in a single instruction + // (AVX-512 embedded rounding eliminates a separate vrndscaleps). + __m512i qi = _mm512_cvt_roundps_epi32(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); // Pack 16 int32 -> 16 uint8 __m128i packed = _mm512_cvtepi32_epi8(qi); _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), packed); } - // Scalar tail + // Scalar tail (use nearbyintf for round-to-nearest-even, matching the + // AVX-512 embedded rounding semantics above). for (; i < len; ++i) { - float q = std::round(src[i] * inv_scale) + 128.0f; + float q = std::nearbyintf(src[i] * inv_scale) + 128.0f; q = std::max(0.0f, std::min(255.0f, q)); dst[i] = static_cast(q); } @@ -169,9 +171,11 @@ VnniDotInt8PerTensor( // Correction: dpbusd computed sum(a_u8 * b_s8). // We want sum((a_u8 - 128) * b_s8) = sum(a_u8 * b_s8) - 128 * sum(b_s8) - float corrected = static_cast(dot_i32) - 128.0f * static_cast(b_sum_i32); + // Perform correction in int32 to preserve precision (avoids float rounding + // when |dot_i32| or |128*b_sum_i32| exceed 2^24). + int32_t corrected = dot_i32 - (128 * b_sum_i32); - return corrected * scale_a * scale_b; + return static_cast(corrected) * scale_a * scale_b; } // @@ -221,19 +225,19 @@ FusedDotInt8_Avx512( acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } } else { - __m512 scale_vec = _mm512_set1_ps(scales[0]); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 16 elements in the hot loop. for (; k < vec_end; k += 32) { __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadu_si128(reinterpret_cast(b_row + k + 16)); __m512i i32_1 = _mm512_cvtepi8_epi32(raw1); __m512 bf1 = _mm512_cvtepi32_ps(i32_1); - bf1 = _mm512_mul_ps(bf1, scale_vec); __m512 a1 = _mm512_loadu_ps(a_row + k + 16); acc1 = _mm512_fmadd_ps(a1, bf1, acc1); } @@ -241,7 +245,6 @@ FusedDotInt8_Avx512( __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } @@ -251,9 +254,15 @@ FusedDotInt8_Avx512( float dot = _mm512_reduce_add_ps(acc0); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } @@ -402,11 +411,11 @@ VnniMultiDot4Int8PerTensor( bs[3] += static_cast(b3[k]); } - const float zp = 128.0f; - out[0] = (static_cast(dot[0]) - zp * static_cast(bs[0])) * combined_scale; - out[1] = (static_cast(dot[1]) - zp * static_cast(bs[1])) * combined_scale; - out[2] = (static_cast(dot[2]) - zp * static_cast(bs[2])) * combined_scale; - out[3] = (static_cast(dot[3]) - zp * static_cast(bs[3])) * combined_scale; + // Zero-point correction in int32 for precision (see VnniDotInt8PerTensor). + out[0] = static_cast(dot[0] - 128 * bs[0]) * combined_scale; + out[1] = static_cast(dot[1] - 128 * bs[1]) * combined_scale; + out[2] = static_cast(dot[2] - 128 * bs[2]) * combined_scale; + out[3] = static_cast(dot[3] - 128 * bs[3]) * combined_scale; } // ============================================================================ @@ -569,7 +578,7 @@ SVGemm_Avx512Vnni( } } } else { - __m512 scale_vec = _mm512_set1_ps(Scales[0]); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -580,15 +589,25 @@ SVGemm_Avx512Vnni( __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); __m512i i32 = _mm512_cvtepi8_epi32(raw); __m512 bf = _mm512_cvtepi32_ps(i32); - bf = _mm512_mul_ps(bf, scale_vec); __m512 c_vec = _mm512_loadu_ps(c_row + n); c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); _mm512_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m512 scale_vec = _mm512_set1_ps(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 16) { + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_mul_ps(c_vec, scale_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 path: 512-bit wide diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp index ae5a56028bbf9..1aabbd8ca39cb 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp @@ -29,12 +29,13 @@ using namespace MlasKVQuantInternal; namespace { // -// Dequantize 8 INT8 values starting at `col` and scale them. +// Dequantize 8 INT8 values starting at `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // Produces two float32x4_t (8 floats total) stored into dst. // inline void DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { // Load 8 int8 values int8x8_t raw = vld1_s8(src + col); @@ -52,7 +53,7 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -64,10 +65,11 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, // // Dequantize 8 INT4 values (4 packed bytes) starting at even column `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // inline void DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { const uint8_t* base = src + col / 2; @@ -94,7 +96,7 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -106,6 +108,8 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, // // Dequantize one row of length `cols` from packed quantized buffer into `dst`. +// `apply_per_tensor_scale=false` leaves per-tensor rows unscaled so callers can +// factor the single scale out of an outer accumulation loop. // void DequantRow_Neon( @@ -113,7 +117,8 @@ DequantRow_Neon( float* dst, size_t cols, MLAS_KV_QUANT_TYPE qt, - const float* scales) + const float* scales, + bool apply_per_tensor_scale) { const bool int4 = IsInt4Mode(qt); const bool per_channel = IsPerChannelMode(qt); @@ -124,22 +129,32 @@ DequantRow_Neon( if (!int4) { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt8x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt8x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { - float sc = per_channel ? scales[c] : scales[0]; - dst[c] = static_cast(src[c]) * sc; + if (per_channel) { + dst[c] = static_cast(src[c]) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(src[c]) * scales[0]; + } else { + dst[c] = static_cast(src[c]); + } } } else { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt4x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt4x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { uint8_t packed = src[c / 2]; int nibble = (c & 1) == 0 ? (packed & 0x0F) : ((packed >> 4) & 0x0F); - float sc = per_channel ? scales[c] : scales[0]; - dst[c] = static_cast(nibble - kInt4Bias) * sc; + if (per_channel) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[0]; + } else { + dst[c] = static_cast(nibble - kInt4Bias); + } } } } @@ -174,7 +189,7 @@ QKGemm_Neon( for (size_t n = 0; n < N; ++n) { const uint8_t* b_row = B_bytes + n * row_bytes; - DequantRow_Neon(b_row, b_buf, K, QuantType, Scales); + DequantRow_Neon(b_row, b_buf, K, QuantType, Scales, true); for (size_t m = 0; m < M; ++m) { const float* a_row = A + m * lda; @@ -246,6 +261,7 @@ SVGemm_Neon( { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); + const bool per_channel = IsPerChannelMode(QuantType); float b_stack[256]; float* b_buf = b_stack; @@ -272,7 +288,7 @@ SVGemm_Neon( for (size_t k = 0; k < K; ++k) { const uint8_t* b_row_packed = B_bytes + k * row_bytes; - DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales); + DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel); const float a_val = a_row[k]; float32x4_t a_broadcast = vdupq_n_f32(a_val); @@ -288,6 +304,19 @@ SVGemm_Neon( c_row[n] += a_val * b_buf[n]; } } + + if (!per_channel) { + const float32x4_t scale_vec = vdupq_n_f32(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 4) { + float32x4_t c_vec = vld1q_f32(c_row + n); + c_vec = vmulq_f32(c_vec, scale_vec); + vst1q_f32(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } + } } } From 359d9ab435d78c4c58b91e3342e7db30ce10c1f7 Mon Sep 17 00:00:00 2001 From: Rishi Dave <62260675+Rishi-Dave@users.noreply.github.com> Date: Fri, 22 May 2026 10:33:24 -0700 Subject: [PATCH 11/16] fix(qdq): skip DQ forward propagation when DQ input is constant (#28521) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Guard `QDQPropagationTransformer::PropagateDQForward` against propagating a DQ whose data input is a constant (graph initializer or `Constant` op output). - Prevents a stale `QuantizeLinear` insertion that the S8-to-U8 weight transformer fails to update, which silently clamps int8 negatives to zero under `ORT_ENABLE_ALL`. - Adds a regression test in `qdq_transformer_test.cc` covering both constant-input shapes. ## Motivation Fixes #28491. Reported scenario: a model containing `Constant(int8) -> DequantizeLinear -> Reshape` produces correct outputs under `ORT_DISABLE_ALL` but wrong outputs (negatives clamped to 0) under `ORT_ENABLE_ALL`. Root cause: `PropagateDQForward` inserts a `Q -> DQ` pair after the Reshape, then `QDQS8ToU8Transformer` (or the avx2 weight transformer) flips the upstream DQ from int8 to uint8 without touching the freshly inserted `Q`. The orphaned uint8 `Q` then clamps the int8 weight's negative values to 0. Propagating a DQ whose data is a constant weight has no benefit anyway: the constant is folded, so there is no runtime tensor that downstream nodes need re-quantized. ## Changes - `onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc`: inside `PropagateDQForward`'s per-DQ loop, after existing skip checks and before any graph mutation, `continue` if `dq_node.InputDefs()[QDQ::InputIndex::INPUT_ID]` is a graph initializer (`graph_utils::NodeArgIsConstant`) or the output of a `Constant` op node (`graph.GetProducerNode(...)->OpType() == "Constant"`). Both checks are required because `NodeArgIsConstant` only handles the initializer case. - `onnxruntime/test/optimizer/qdq_transformer_test.cc`: new test `QDQPropagation_DQForward_ConstantInput_NoPropagation` with two cases — DQ fed by an initializer, and DQ fed by an explicit `Constant` op node — asserting no `QuantizeLinear` is inserted after the downstream Reshape. `PropagateQBackward` is structurally not affected (its data input is a live activation, not a constant), so it does not need a symmetric guard. ## Test Plan - New `QDQTransformerTests.QDQPropagation_DQForward_ConstantInput_NoPropagation` (gtest) covers both code paths and would fail on `main`. - Existing `QDQPropagation_*` tests use `MakeInput` (graph inputs, not initializers) for their DQ data tensors, so the new guard does not regress them. - The reproducer from #28491 should now produce identical results between `ORT_DISABLE_ALL` and `ORT_ENABLE_ALL`. - CI will exercise `--gtest_filter=QDQTransformerTests.QDQPropagation*` under `onnxruntime_test_all`. Fixes #28491 --- .../qdq_transformer/qdq_propagation.cc | 15 ++++ .../test/optimizer/qdq_transformer_test.cc | 89 +++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 8abc6da27a64c..ab491c134b5e5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -371,6 +371,21 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } + // Do not propagate DQ forward when its data input is a constant (graph initializer or + // Constant op output). Propagation would insert a Q -> DQ pair after any downstream + // reshape-like node; later passes (e.g. S8-to-U8 weight transformer) may flip the + // existing DQ to uint8 without touching the inserted Q, causing int8 negatives to + // be clamped to zero. See GitHub issue #28491. + const NodeArg* dq_data_input = dq_node.InputDefs()[QDQ::InputIndex::INPUT_ID]; + const bool is_initializer_constant = graph_utils::NodeArgIsConstant(graph, *dq_data_input); + const Node* dq_data_producer = graph.GetProducerNode(dq_data_input->Name()); + const bool is_constant_op_output = dq_data_producer != nullptr && + dq_data_producer->OpType() == "Constant" && + dq_data_producer->Domain() == kOnnxDomain; + if (is_initializer_constant || is_constant_op_output) { + continue; + } + auto& dq_scale = *dq_node.MutableInputDefs()[QDQ::InputIndex::SCALE_ID]; auto* dq_zero_point = dq_zero_point_exists ? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID] diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index da933464bb66b..bdbd2c488584d 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -4058,6 +4058,95 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward) { #endif } +// Regression test for GitHub issue #28491. +// When a DQ node's data input is a constant (graph initializer or Constant op output), +// PropagateDQForward must not insert a Q -> DQ pair downstream of a reshape-like node. +// Doing so can cause subsequent S8-to-U8 weight transformers to flip the DQ dtype while +// leaving the inserted Q node in its original dtype, clamping int8 negatives to zero. +TEST(QDQTransformerTests, QDQPropagation_DQForward_ConstantInput_NoPropagation) { + // Case 1: DQ data input is a graph initializer. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + // int8 constant weight as a graph initializer + auto* weight = builder.MakeInitializer({4}, {-10, 0, 10, 20}); + auto* output_arg = builder.MakeOutput(); + + // DQ node that dequantizes the constant weight + constexpr float qdq_scale = 0.1f; + constexpr int8_t qdq_zero_point = 0; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, qdq_scale, qdq_zero_point, dq_output); + + // Reshape downstream of DQ + auto* reshape_shape = builder.Make1DInitializer({2, 2}); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const auto op_types = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + // No Q or DQ should have been inserted after Reshape. + // Expected order: DequantizeLinear -> Reshape (no trailing Q/DQ). + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + const std::vector expected{qdq_keys.dequantize_linear, "Reshape"}; + EXPECT_EQ(op_types, expected); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 12); + } + + // Case 2: DQ data input is the output of a Constant op node. + // Run QDQPropagationTransformer directly (bypassing ConstantFolding) so the + // Constant op node is still present when PropagateDQForward evaluates it. + // Using TransformerTester would fold the Constant into an initializer first, + // masking the is_constant_op_output code path under test. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* output_arg = builder.MakeOutput(); + + // Create a Constant op node that produces an int8 tensor. + ONNX_NAMESPACE::TensorProto constant_tensor; + constant_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); + constant_tensor.add_dims(4); + const std::vector raw_vals = {-10, 0, 10, 20}; + constant_tensor.set_raw_data(raw_vals.data(), raw_vals.size() * sizeof(int8_t)); + + auto* constant_output = builder.MakeIntermediate(); + constant_tensor.set_name(constant_output->Name()); + builder.AddNode("Constant", {}, {constant_output}).AddAttribute("value", constant_tensor); + + // DQ node that dequantizes the Constant op output + constexpr float qdq_scale = 0.1f; + constexpr int8_t qdq_zero_point = 0; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(constant_output, qdq_scale, qdq_zero_point, dq_output); + + // Reshape downstream of DQ + auto* reshape_shape = builder.Make1DInitializer({2, 2}); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {output_arg}); + }; + + // post_graph_checker runs on Graph& directly, after only QDQPropagationTransformer. + // QuantizeLinear must not have been inserted anywhere. + auto post_graph_checker = [&](Graph& graph) -> Status { + const auto op_counts = CountOpsInGraph(graph); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + TEST_RETURN_IF_NOT(op_counts.count(qdq_keys.quantize_linear) == 0 || + op_counts.at(qdq_keys.quantize_linear) == 0); + return Status::OK(); + }; + + const auto& logger = DefaultLoggingManager().DefaultLogger(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, logger, + std::make_unique(), + TransformerLevel::Level1, 1, + nullptr, post_graph_checker)); + } +} + TEST(QDQTransformerTests, QDQPropagation_StopAtOtherQDQ) { auto test_case = [&](const std::vector& input_shape, bool same_scale, bool same_zp, bool use_contrib_qdq) { From 1053327ed564ae53e4b57abc2db733faa09f36fb Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Sat, 23 May 2026 02:01:42 +0800 Subject: [PATCH 12/16] [WebGPU] LinearAttention: increase tile_v when subgroups are available (#28519) ### Description - Scale tile_v by 4x when subgroup is enabled and the vectorized dimension has enough columns, improving data reuse. - Gate the tile_v expansion on seq_length >= 16, where the prefill benefit outweighs increased register pressure. - Remove redundant zero-initialization of the state tile (WGSL default- initializes vars to zero). **Intel Panther Lake (xe-3lpg)** | Model | Prefill | Baseline (TPS) | Optimized (TPS) | Change | | :----------- | ------: | -------------: | --------------: | -----: | | Qwen3.5-0.8B | 128 | 1534.30 | 1681.00 | 9.56% | | Qwen3.5-0.8B | 1024 | 3267.30 | 3917.60 | 19.90% | | Qwen3.5-0.8B | 4096 | 2864.50 | 3563.40 | 24.40% | | Qwen3.5-2B | 128 | 1295.60 | 1344.60 | 3.78% | | Qwen3.5-2B | 1024 | 2177.10 | 2344.00 | 7.67% | | Qwen3.5-2B | 4096 | 1942.60 | 2247.00 | 15.67% | | Qwen3.5-4B | 128 | 701.30 | 736.20 | 4.98% | | Qwen3.5-4B | 1024 | 946.90 | 1036.30 | 9.44% | | Qwen3.5-4B | 4096 | 824.10 | 912.00 | 10.67% | ### Motivation and Context See above. --- .../webgpu/bert/linear_attention.cc | 40 ++++++++++++------- .../bert/linear_attention.wgsl.template | 3 -- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index f4935aaeb6b74..5a0d1e4841f05 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -202,16 +202,6 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { TensorShapeVector state_shape({batch_size, kv_num_heads_, head_dim_k, head_dim_v}); Tensor* present_state = context.Output(1, state_shape); - // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values - // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, - // reduces shared memory access overhead, and enables coalesced memory reads/writes. - const int components = (head_dim_v % 4 == 0 && head_dim_v >= 4) ? 4 : 1; - int tile_v = (components == 4) ? 1 : 4; - if (components == 1 && head_dim_v <= 4) { - tile_v = onnxruntime::narrow(head_dim_v); - } - const int head_dim_v_vectorized = onnxruntime::narrow(head_dim_v) / components; - constexpr uint32_t kMaxSupportedWorkgroupSize = 256; ORT_RETURN_IF_NOT(head_dim_k <= static_cast(kMaxSupportedWorkgroupSize), "LinearAttention WebGPU kernel requires head_dim_k <= ", @@ -225,6 +215,31 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { // Cap at GPU limits workgroup_size = std::min(workgroup_size, kMaxSupportedWorkgroupSize); + // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values + // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, + // reduces shared memory access overhead, and enables coalesced memory reads/writes. + // TODO: support components == 2 (vec2) for head_dim_v divisible by 2 but not 4. + const int components = (head_dim_v % 4 == 0) ? 4 : 1; + int tile_v = (components == 4) ? 1 : std::min(4, onnxruntime::narrow(head_dim_v)); + + // subgroup_min_size > 0 enables subgroup-based reduction; 0 falls back to barrier-tree. + int subgroup_min_size = context.HasFeature(wgpu::FeatureName::Subgroups) + ? static_cast(context.AdapterInfo().subgroupMinSize) + : 0; + // When subgroup is enabled, use larger tile_v for better data reuse. + // Only expand for longer sequences (>=16) where the benefit outweighs the + // increased register pressure and shared memory usage. + if (subgroup_min_size > 0 && seq_length >= 16) { + // Ensure the vectorized dimension is wide enough to warrant a larger tile. + if (head_dim_v / components >= tile_v * 4) { + tile_v *= 4; + } + } + // Clamp to workgroup_size since the shader assigns one thread per tile_v + // column (threads with dk_idx >= TILE_V are idle for output/state writes). + tile_v = std::min(tile_v, static_cast(workgroup_size)); + + const int head_dim_v_vectorized = onnxruntime::narrow(head_dim_v) / components; const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; const uint32_t num_workgroups = onnxruntime::narrow(batch_size * kv_num_heads_ * num_dv_tiles); @@ -243,11 +258,6 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { } } - // subgroup_min_size > 0 enables subgroup-based reduction; 0 falls back to barrier-tree. - int subgroup_min_size = context.HasFeature(wgpu::FeatureName::Subgroups) - ? static_cast(context.AdapterInfo().subgroupMinSize) - : 0; - LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components, subgroup_min_size}; program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template index 941793ecc7e79..206d9bb6eb20f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -90,9 +90,6 @@ $MAIN { // Initialize state tile in private memory var state: array; - for (var j = 0u; j < TILE_V; j++) { - state[j] = vtype(0.0); - } // Load initial state if provided #if has_initial_state From a4f40c11c608ef413c3cd556680bb744c2f79e1f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 22 May 2026 11:04:47 -0700 Subject: [PATCH 13/16] Fix Reshape with allowzero=1 producing wrong shape for zero-size tensor in chained Reshape (#28455) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description - **Optimizer fix** (`reshape_fusion.cc`): `FuseContiguousReshapes` now explicitly bails out in the while loop when it encounters a `Reshape` next-node that has `allowzero=1`. The fused node inherits attributes from the first node in the chain (which has `allowzero=0` or no `allowzero` attribute), so including an `allowzero=1` node in the fusion would silently drop that attribute and cause zeros in the shape tensor to be misinterpreted as "copy from input" at runtime instead of being preserved as explicit zero dims. A complementary zero-dim guard (already present) also prevents fusion when the inferred final output shape contains any literal zero dimension. - **New end-to-end execution test** (`graph_transform_test.cc`, `ReshapeFusionContiguousReshapesWithZeroDimExecution`): Exercises the exact scenario from the issue — `float[0,8,2] → Reshape([4,2,-1]) → Reshape([0,0,4], allowzero=1)` — through a full `InferenceSession` run, asserting the output shape is `(0,0,4)` and not `(0,8,4)`. The existing `ReshapeFusionContiguousReshapesWithZeroDim` test only validated the optimizer transformation; this test validates runtime correctness. ### Motivation and Context A chained `Reshape` model where the second node uses `allowzero=1` produced the wrong output shape when the fused shape contained zeros. Example reproducer: ```python # X: float[0, 8, 2] n1 = Reshape(X, shape=[4, 2, -1]) # mid: [4, 2, 0] n2 = Reshape(mid, shape=[0, 0, 4], allowzero=1) # expected Y: [0, 0, 4] # ORT returned: (0, 8, 4) — wrong; reference returned: (0, 0, 4) ``` `FuseContiguousReshapes` merged both nodes into `Reshape(X, [0, 0, 4])` with `allowzero=0` (inherited from n1), so the zeros were interpreted as "copy dim from `X`" (`X[1]=8`), yielding `(0, 8, 4)`. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- onnxruntime/core/optimizer/reshape_fusion.cc | 13 +++ .../test/optimizer/graph_transform_test.cc | 90 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index f88ce56fe36fa..f50c0e2e635bc 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -470,6 +470,19 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { break; } + // If next_node is a Reshape with allowzero=1, the fused node cannot represent this + // correctly: the fused node inherits attributes from the first node in the chain + // (which has allowzero=0 or no allowzero attribute). Bailing out here prevents + // incorrect fusion such as Reshape([0,8,2]->[4,2,-1]) + Reshape([0,0,4],allowzero=1) + // being collapsed into Reshape([0,8,2]->[0,0,4],allowzero=0), which would silently + // copy dims from the original input instead of preserving the explicit zeros. + if (next_node->OpType() == "Reshape") { + const auto* az_attr = graph_utils::GetNodeAttribute(*next_node, "allowzero"); + if ((nullptr != az_attr) && az_attr->has_i() && az_attr->i() != 0) { + break; + } + } + auto shape = next_node->OutputDefs()[0]->Shape(); if (!shape) { break; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 0e4ab5c2d3b73..924ceaa19a47d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5027,6 +5027,96 @@ TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDim) { EXPECT_EQ(y_shape->dim(2).dim_value(), 3); } +// Execution regression test: a chained Reshape with allowzero=1 on a zero-element tensor +// must produce the correct output shape at runtime. +// Input X: float[0, 8, 2] -> Reshape([4, 2, -1]) -> mid -> Reshape([0, 0, 4], allowzero=1) -> Y +// Expected Y shape: (0, 0, 4). Without the fix FuseContiguousReshapes would collapse the +// two nodes into one (losing allowzero=1) and emit (0, 8, 4) instead. +// See https://github.com/microsoft/onnxruntime/issues/28348. +TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDimExecution) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 18; + Model model("ReshapeFusionContiguousReshapesWithZeroDimExecution", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), *logger_); + auto& graph = model.MainGraph(); + + // X: float[0, 8, 2] + TypeProto x_type; + x_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(0); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(8); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + TypeProto y_type; + y_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + auto& X = graph.GetOrCreateNodeArg("X", &x_type); + auto& mid = graph.GetOrCreateNodeArg("mid", &y_type); + auto& Y = graph.GetOrCreateNodeArg("Y", &y_type); + + // shape1 = [4, 2, -1] -> mid shape (4, 2, 0) + ONNX_NAMESPACE::TensorProto shape1_proto; + shape1_proto.set_name("shape1"); + shape1_proto.set_data_type(TensorProto_DataType_INT64); + shape1_proto.add_dims(3); + for (int64_t v : {4, 2, -1}) shape1_proto.add_int64_data(v); + graph.AddInitializedTensor(shape1_proto); + + // shape2 = [0, 0, 4] with allowzero=1 -> Y shape (0, 0, 4) + ONNX_NAMESPACE::TensorProto shape2_proto; + shape2_proto.set_name("shape2"); + shape2_proto.set_data_type(TensorProto_DataType_INT64); + shape2_proto.add_dims(3); + for (int64_t v : {0, 0, 4}) shape2_proto.add_int64_data(v); + graph.AddInitializedTensor(shape2_proto); + + auto& shape1 = graph.GetOrCreateNodeArg("shape1", nullptr); + auto& shape2 = graph.GetOrCreateNodeArg("shape2", nullptr); + + graph.AddNode("reshape1", "Reshape", "first reshape", {&X, &shape1}, {&mid}); + auto& reshape2 = graph.AddNode("reshape2", "Reshape", "second reshape (allowzero=1)", + {&mid, &shape2}, {&Y}); + reshape2.AddAttribute("allowzero", static_cast(1)); + + graph.SetInputs({&X}); + graph.SetOutputs({&Y}); + + ASSERT_STATUS_OK(graph.Resolve()); + + // Serialize and run via InferenceSession to exercise the full execution path. + auto model_proto = model.ToProto(); + std::string serialized_model; + ASSERT_TRUE(model_proto.SerializeToString(&serialized_model)); + + SessionOptions so; + InferenceSession session_object{so, GetEnvironment()}; + std::stringstream model_stream(serialized_model); + ASSERT_STATUS_OK(session_object.Load(model_stream)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Input: zero-element float tensor with shape [0, 8, 2]. + OrtValue input_val; + std::vector input_dims = {0, 8, 2}; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + input_dims, std::vector(), &input_val); + + NameMLValMap feeds = {{"X", input_val}}; + std::vector output_names = {"Y"}; + std::vector fetches; + RunOptions run_options; + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); + + // Output shape must be (0, 0, 4), not (0, 8, 4). + ASSERT_EQ(fetches.size(), 1U); + const auto& output_tensor = fetches[0].Get(); + const TensorShape& output_shape = output_tensor.Shape(); + ASSERT_EQ(output_shape.NumDimensions(), 3U); + EXPECT_EQ(output_shape[0], 0); + EXPECT_EQ(output_shape[1], 0); + EXPECT_EQ(output_shape[2], 4); +} + TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx"; std::shared_ptr p_model; From a4f79e82040a7ce823d3993a02d8d385f6e5af68 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Fri, 22 May 2026 16:53:36 -0700 Subject: [PATCH 14/16] Fix: Accept 'CPU' as a valid provider name in SessionOptionsAppendExecutionProvider (#28625) ### Description Adds 'CPU'/'CPUExecutionProvider' to the supported provider names whitelist in `SessionOptionsAppendExecutionProvider`. The CPU case is handled as a no-op since the CPU EP is always implicitly registered. ### Motivation and Context The refactoring in #24433 introduced a strict whitelist of supported provider names but omitted 'CPU'. This caused a regression where explicitly appending the CPU provider (which was previously accepted) now throws: ``` RuntimeError: Unknown provider name 'CPU'. Currently supported values are 'DML'/'DmlExecutionProvider', ... ``` This breaks downstream consumers like onnxruntime-genai that call `AppendExecutionProvider('cpu')` when users request CPU execution. The issue was introduced in ORT builds after April 2025 and surfaces in onnxruntime-genai-winml >= 0.13.1. ### Changes - Added `CPU` to the `EpID` enum - Added `EpToAppend{EpID::CPU, 'CPU', kCpuExecutionProvider}` to the `supported_eps` array - Added a `case EpID::CPU:` no-op handler in the switch statement --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxruntime/core/session/provider_registration.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index fa3b6bb840854..c68ad570fbc44 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,7 +101,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. - MIGraphX + MIGraphX, + CPU }; struct EpToAppend { @@ -110,7 +111,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -123,7 +124,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, - EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}, + EpToAppend{EpID::CPU, "CPU", kCpuExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -197,6 +199,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, ep_to_append.canonical_name)); switch (ep_to_append.id) { + case EpID::CPU: { + // CPU EP is always available by default. Accept the name as valid but do nothing, + // since the CPU EP is implicitly registered in every session. + break; + } case EpID::DML: { #if defined(USE_DML) options->provider_factories.push_back( From b2f6e151ba816a6ea989d52fba19d8c0137d0837 Mon Sep 17 00:00:00 2001 From: velonica0 <47554626+velonica0@users.noreply.github.com> Date: Sat, 23 May 2026 08:27:46 +0800 Subject: [PATCH 15/16] [MLAS] RVV-Optimized LLM Operators for RISC-V (#28518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Added RVV implementations for a subset of LLM inference operators. Optimization of activation functions is in #28308. All tests were conducted on a Spacemit K3 CPU (VLEN=256). | Operator | File | Speedup vs Scalar | Precision | | :--- | :--- | :--- | :--- | | FP16 GEMM | `riscv64/halfgemm_kernel_rvv.cpp` | 51–191x | max_abs ≤ 0.0005 (PASS) | | FP16↔FP32 Cast | `riscv64/cast_kernel_rvv.cpp` | 4–12x | Bit-exact (PASS) | | RotaryEmbedding | `riscv64/rotary_embedding_kernel_rvv.cpp` | 3.1x | max_abs ~6e-08 (PASS) | | SimplifiedLayerNorm | `layer_norm_impl.cc` (inline RVV) | 4.3x | Bit-exact (PASS) | ## Operator Performance ### **FP16 GEMM** | Shape (M×N×K) | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | | 1×768×768 | 7.32 ms | 0.14 ms | 51.6x | 1.22e-04 | | 32×768×768 | 235 ms | 1.25 ms | 187.9x | 3.66e-04 | | 64×768×768 | 469 ms | 2.49 ms | 188.8x | 2.44e-04 | | 128×3072×768 | 5718 ms | 30.4 ms | 187.8x | 4.88e-04 | --- ### **FP16↔FP32 Cast** | Elements | Direction | ORT Scalar | RVV | Speedup | | :--- | :--- | :--- | :--- | :--- | | 1K | F16→F32 | 0.002 ms | 0.000 ms | 9.6x | | 1K | F32→F16 | 0.002 ms | 0.000 ms | 10.5x | | 64K | F16→F32 | 0.127 ms | 0.013 ms | 9.7x | | 64K | F32→F16 | 0.150 ms | 0.013 ms | 11.3x | | 1M | F16→F32 | 2.03 ms | 0.26 ms | 7.7x | | 1M | F32→F16 | 2.39 ms | 0.50 ms | 4.8x | --- ### **RotaryEmbedding** | Dim | Mode | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | :--- | | 64 | non-interleaved | 0.32 us | 0.05 us | 7.1x | 5.96e-08 | | 64 | interleaved | 0.22 us | 0.07 us | 3.2x | 0 | | 128 | non-interleaved | 0.64 us | 0.07 us | 9.7x | 5.96e-08 | | 128 | interleaved | 0.44 us | 0.13 us | 3.4x | 0 | | 256 | non-interleaved | 1.28 us | 0.10 us | 13.0x | 1.19e-07 | | 256 | interleaved | 1.05 us | 0.25 us | 4.3x | 0 | --- ### **RMSNorm** | Hidden | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | | 512 | 2.31 us | 0.38 us | 6.0x | 2.38e-07 | | 1024 | 4.64 us | 0.71 us | 6.5x | 2.38e-07 | | 2048 | 9.24 us | 1.42 us | 6.5x | 3.58e-07 | | 4096 | 18.5 us | 2.82 us | 6.6x | 3.58e-06 | > **Note**: ORT's LayerNorm ComputeJob is in an anonymous namespace — there's no public API to call it separately. So I rewrite the benchmark using the same algorithm as ORT's ComputeJob. ## Model Performance The ONNX model comes from: https://huggingface.co/onnx-community/Qwen3-0.6B-ONNX | Metric | FP32 | FP16 | | :--- | :--- | :--- | | Prompt processing | 61.1 tok/s (255 ms p50) | 58.8 tok/s (272 ms p50) | | Token generation | 6.5 tok/s (152 ms p50) | 6.1 tok/s (162 ms p50) | | E2E (16+32 tokens) | 4987 ms p50 | 5357 ms p50 | | Peak memory | 3.1 GB | 4.1 GB | > **Note**: FP32 is slightly faster because it runs SGEMM directly without the FP16↔FP32 cast overhead. FP16 uses ~1 GB less storage on disk but more runtime memory (the cast creates FP32 copies). Both use the RVV SGEMM kernel for the actual compute. --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_mlas.cmake | 17 ++ cmake/onnxruntime_unittests.cmake | 44 +++ .../core/common/cpuid_arch_definition.h | 4 + onnxruntime/core/common/cpuid_info.cc | 27 ++ onnxruntime/core/common/cpuid_info.h | 6 + onnxruntime/core/mlas/inc/mlas.h | 21 ++ onnxruntime/core/mlas/lib/halfgemm.cpp | 2 + onnxruntime/core/mlas/lib/halfgemm.h | 9 + onnxruntime/core/mlas/lib/layernorm.cpp | 41 +++ onnxruntime/core/mlas/lib/mlasi.h | 26 ++ onnxruntime/core/mlas/lib/platform.cpp | 9 + .../core/mlas/lib/riscv64/cast_kernel_rvv.cpp | 62 +++++ .../mlas/lib/riscv64/halfgemm_kernel_rvv.cpp | 239 +++++++++++++++++ .../mlas/lib/riscv64/layernorm_kernel_rvv.cpp | 109 ++++++++ .../riscv64/rotary_embedding_kernel_rvv.cpp | 108 ++++++++ .../riscv64/sconv_depthwise_kernel_rvv.cpp | 1 + .../core/providers/cpu/nn/layer_norm_impl.cc | 17 +- .../mlas/bench/riscv64/cast_rvv_bench.cpp | 165 ++++++++++++ .../mlas/bench/riscv64/halfgemm_rvv_bench.cpp | 253 ++++++++++++++++++ .../mlas/bench/riscv64/rmsnorm_rvv_bench.cpp | 169 ++++++++++++ .../mlas/bench/riscv64/rope_rvv_bench.cpp | 152 +++++++++++ .../test/mlas/unittest/test_cast_fp16.cpp | 120 +++++++++ .../test/mlas/unittest/test_layernorm.cpp | 157 +++++++++++ onnxruntime/test/mlas/unittest/test_rope.cpp | 4 +- 25 files changed, 1758 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/layernorm.cpp create mode 100644 onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp create mode 100644 onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp create mode 100644 onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp create mode 100644 onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp create mode 100644 onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp create mode 100644 onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp create mode 100644 onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp create mode 100644 onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_cast_fp16.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_layernorm.cpp diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1ac1d52231577..f1126c2dce79e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -90,6 +90,7 @@ option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF) option(onnxruntime_USE_RVV "Build with RISC-V Vector support in MLAS" OFF) +option(onnxruntime_USE_RVV_ZVFH "Build with RISC-V Zvfh (FP16 vector) support in MLAS" OFF) option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7ae18db235ccb..8c7df780735f1 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -57,6 +57,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/layernorm.cpp ${MLAS_SRC_DIR}/rotary_embedding.h ${MLAS_SRC_DIR}/rotary_embedding.cpp ${MLAS_SRC_DIR}/softmax.h @@ -959,6 +960,8 @@ endif() ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp ) list(REMOVE_ITEM mlas_platform_srcs "${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp") @@ -968,8 +971,22 @@ endif() ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d") list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1) + + if(onnxruntime_USE_RVV_ZVFH) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp + ) + set_source_files_properties( + ${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp + PROPERTIES COMPILE_FLAGS "-march=rv64gcv_zvfh -mabi=lp64d") + list(APPEND mlas_private_compile_definitions MLAS_USE_RVV_ZVFH=1) + endif() else() message( WARNING diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a061858fa068f..9f9356d4dff8d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1450,6 +1450,50 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) target_compile_definitions(onnxruntime_mlas_softmax_riscv_compare PRIVATE ${mlas_private_compile_definitions}) set_target_properties(onnxruntime_mlas_softmax_riscv_compare PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_halfgemm_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/halfgemm_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_halfgemm_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_halfgemm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_cast_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/cast_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_cast_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_cast_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_cast_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_cast_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_rope_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/rope_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_rope_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_rope_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_rope_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_rope_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_rmsnorm_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/rmsnorm_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_rmsnorm_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_rmsnorm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") endif() if(WIN32) diff --git a/onnxruntime/core/common/cpuid_arch_definition.h b/onnxruntime/core/common/cpuid_arch_definition.h index 5946b8ca27067..973c50b5dda38 100644 --- a/onnxruntime/core/common/cpuid_arch_definition.h +++ b/onnxruntime/core/common/cpuid_arch_definition.h @@ -12,3 +12,7 @@ #if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 + +#if defined(__riscv) && __riscv_xlen == 64 +#define CPUIDINFO_ARCH_RISCV64 +#endif diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 5990013c925c5..96dc427ad766c 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -47,6 +47,16 @@ #endif // ARM +#if defined(CPUIDINFO_ARCH_RISCV64) +#include +#ifndef RISCV_HWPROBE_EXT_ZVFH +#define RISCV_HWPROBE_EXT_ZVFH (1 << 30) +#endif +#ifndef RISCV_HWPROBE_IMA_V +#define RISCV_HWPROBE_IMA_V (1 << 2) +#endif +#endif // RISCV64 + #endif // Linux #if _WIN32 @@ -334,6 +344,17 @@ void CPUIDInfo::ArmAppleInit() { #endif // defined(CPUIDINFO_ARCH_ARM) +#if defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__) +void CPUIDInfo::RiscvLinuxInit() { + struct riscv_hwprobe pairs[] = { + {RISCV_HWPROBE_KEY_IMA_EXT_0, 0}, + }; + if (syscall(__NR_riscv_hwprobe, pairs, 1, 0, nullptr, 0) == 0) { + has_fp16_ = (pairs[0].value & RISCV_HWPROBE_EXT_ZVFH) != 0; + } +} +#endif // defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__) + uint32_t CPUIDInfo::GetCurrentCoreIdx() const { #ifdef _WIN32 return GetCurrentProcessorNumber(); @@ -377,5 +398,11 @@ CPUIDInfo::CPUIDInfo() { ArmAppleInit(); #endif #endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUIDINFO_ARCH_RISCV64) +#if defined(__linux__) + RiscvLinuxInit(); +#endif +#endif // defined(CPUIDINFO_ARCH_RISCV64) } } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index be301019df5c0..bf502c645c9eb 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -135,6 +135,12 @@ class CPUIDInfo { #endif // defined(CPUIDINFO_ARCH_ARM) +#if defined(CPUIDINFO_ARCH_RISCV64) +#if defined(__linux__) + void RiscvLinuxInit(); +#endif +#endif // defined(CPUIDINFO_ARCH_RISCV64) + #if defined(CPUINFO_SUPPORTED) bool pytorch_cpuinfo_init_{false}; #endif // defined(CPUINFO_SUPPORTED) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ddb9daa5e244b..99b72dc756663 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1665,6 +1665,27 @@ MlasRotaryEmbedOneRow( T* output ); +/** + * @brief Compute LayerNorm or RMSNorm (simplified) for one row of float data. + * Uses platform-optimized kernel if available, otherwise returns false. + * Any platform (AMD64/ARM64/RISC-V) can register a LayerNormF32Kernel. + * + * @return true if an optimized kernel was used, false if caller should fall back + */ +bool +MLASCALL +MlasLayerNormF32( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +); + /** * @brief Supply matrices data information to half precision gemm functions */ diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 66a335665d024..05cde92d9f9d7 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -27,6 +27,8 @@ MlasFp16AccelerationSupported() { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); +#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); #else return false; #endif diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 529db48f58e6f..3f63e00f05f12 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -503,12 +503,21 @@ extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; #endif +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchRvv; +#endif + MLAS_FORCEINLINE const MLAS_HALFGEMM_DISPATCH* MlasHalfGemmGetDispatch() { #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) return &MlasHalfGemmDispatchNeon; +#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration()) { + return &MlasHalfGemmDispatchRvv; + } + return &MlasHalfGemmDispatchDefault; #else return &MlasHalfGemmDispatchDefault; #endif diff --git a/onnxruntime/core/mlas/lib/layernorm.cpp b/onnxruntime/core/mlas/lib/layernorm.cpp new file mode 100644 index 0000000000000..34258436d60a0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/layernorm.cpp @@ -0,0 +1,41 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + layernorm.cpp + +Abstract: + + This module implements the dispatch for platform-optimized + LayerNorm/RMSNorm kernels. + +--*/ + +#include "mlasi.h" + +bool + MLASCALL + MlasLayerNormF32( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified + ) +{ + auto kernel = GetMlasPlatform().LayerNormF32Kernel; + if (kernel == nullptr) { + return false; + } + + kernel(Input, Scale, Bias, Output, MeanOut, InvStdDevOut, NormSize, Epsilon, Simplified); + return true; +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index dbb414505ff38..bf4f3f6e2de2d 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -691,6 +691,18 @@ typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( size_t Count ); +typedef void(MLASCALL MLAS_LAYERNORM_F32_KERNEL)( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -1230,6 +1242,15 @@ extern "C" { MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; #endif + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelRvv; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelRvv; +#endif + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) + MLAS_LAYERNORM_F32_KERNEL MlasLayerNormKernelRvv; +#endif } // @@ -1388,6 +1409,10 @@ struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchRvv; +#endif + // // half gemm dispatch structure // @@ -1631,6 +1656,7 @@ MLAS_COMPUTE_TANH_FP16_KERNEL* TanhFP16KernelRoutine = nullptr; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + MLAS_LAYERNORM_F32_KERNEL* LayerNormF32Kernel{nullptr}; const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 466fa9a3e9497..6eb53684065a4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -292,6 +292,15 @@ Return Value: this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelRvv; this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelRvv; this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelRvv; + this->RopeDispatch = &MlasRopeDispatchRvv; + this->LayerNormF32Kernel = &MlasLayerNormKernelRvv; + +#if defined(MLAS_USE_RVV_ZVFH) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration()) { + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelRvv; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelRvv; + } +#endif // NCHWc kernels require VLEN>=128 so that vfloat32m4_t holds 16 floats. if (__riscv_vlenb() >= 16) { diff --git a/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp new file mode 100644 index 0000000000000..038b7873637db --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp @@ -0,0 +1,62 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast_kernel_rvv.cpp + +Abstract: + + This module implements FP16<->FP32 cast kernels using RISC-V Vector + Extension (RVV). Uses Zvfhmin conversion instructions, but is gated + on Zvfh at build time (no separate Zvfhmin-only cmake probe). + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV_ZVFH) + +#include + +void + MLASCALL + MlasCastF16ToF32KernelRvv( + const unsigned short* Source, + float* Destination, + size_t Count + ) +{ + size_t i = 0; + while (i < Count) { + size_t vl = __riscv_vsetvl_e16m2(Count - i); + vuint16m2_t raw = __riscv_vle16_v_u16m2(Source + i, vl); + vfloat16m2_t fp16 = __riscv_vreinterpret_v_u16m2_f16m2(raw); + vfloat32m4_t fp32 = __riscv_vfwcvt_f_f_v_f32m4(fp16, vl); + __riscv_vse32_v_f32m4(Destination + i, fp32, vl); + i += vl; + } +} + +void + MLASCALL + MlasCastF32ToF16KernelRvv( + const float* Source, + unsigned short* Destination, + size_t Count + ) +{ + size_t i = 0; + while (i < Count) { + size_t vl = __riscv_vsetvl_e32m4(Count - i); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(Source + i, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2(Destination + i, __riscv_vreinterpret_v_f16m2_u16m2(fp16), vl); + i += vl; + } +} + +#endif // MLAS_USE_RVV_ZVFH diff --git a/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp new file mode 100644 index 0000000000000..f9fb2bbba96bf --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp @@ -0,0 +1,239 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_rvv.cpp + +Abstract: + + This module implements half precision GEMM kernel for RISC-V Vector + Extension (RVV) with Zvfh (vector half-precision floating-point). + + The kernel vectorizes along the N dimension using vsetvl, so it adapts + automatically to any VLEN >= 128. Up to 4 rows of A are processed per + call (KernelMaxM = 4). + +--*/ + +#include "halfgemm.h" +#include "mlasi.h" + +#if defined(MLAS_USE_RVV_ZVFH) + +#include + +#include + +namespace +{ + +MLAS_FORCEINLINE +_Float16 +Fp16BitsToScalar(_mlas_fp16_ bits) +{ + _Float16 f; + memcpy(&f, &bits, sizeof(f)); + return f; +} + +MLAS_FORCEINLINE +vfloat16m4_t +LoadFp16(const _mlas_fp16_* ptr, size_t vl) +{ + return __riscv_vreinterpret_v_u16m4_f16m4(__riscv_vle16_v_u16m4(ptr, vl)); +} + +MLAS_FORCEINLINE +void +StoreFp16(_mlas_fp16_* ptr, vfloat16m4_t vec, size_t vl) +{ + __riscv_vse16_v_u16m4(ptr, __riscv_vreinterpret_v_f16m4_u16m4(vec), vl); +} + +template +MLAS_FORCEINLINE void +HalfGemmKernelRvvImpl( + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + bool ZeroMode +) +{ + static_assert(Rows >= 1 && Rows <= 4, "unsupported tile height"); + + size_t n = 0; + while (n < CountN) { + size_t vl = __riscv_vsetvl_e16m4(CountN - n); + + vfloat16m4_t acc0, acc1, acc2, acc3; + + if (ZeroMode) { + if (Bias != nullptr) { + vfloat16m4_t bv = LoadFp16(Bias + n, vl); + acc0 = bv; + if constexpr (Rows > 1) acc1 = bv; + if constexpr (Rows > 2) acc2 = bv; + if constexpr (Rows > 3) acc3 = bv; + } else { + vfloat16m4_t z = __riscv_vfmv_v_f_f16m4((_Float16)0.0f, vl); + acc0 = z; + if constexpr (Rows > 1) acc1 = z; + if constexpr (Rows > 2) acc2 = z; + if constexpr (Rows > 3) acc3 = z; + } + } else { + acc0 = LoadFp16(C + n, vl); + if constexpr (Rows > 1) acc1 = LoadFp16(C + ldc + n, vl); + if constexpr (Rows > 2) acc2 = LoadFp16(C + 2 * ldc + n, vl); + if constexpr (Rows > 3) acc3 = LoadFp16(C + 3 * ldc + n, vl); + if (Bias != nullptr) { + vfloat16m4_t bv = LoadFp16(Bias + n, vl); + acc0 = __riscv_vfadd_vv_f16m4(acc0, bv, vl); + if constexpr (Rows > 1) acc1 = __riscv_vfadd_vv_f16m4(acc1, bv, vl); + if constexpr (Rows > 2) acc2 = __riscv_vfadd_vv_f16m4(acc2, bv, vl); + if constexpr (Rows > 3) acc3 = __riscv_vfadd_vv_f16m4(acc3, bv, vl); + } + } + + for (size_t k = 0; k < CountK; k++) { + vfloat16m4_t bv = LoadFp16(B + k * ldb + n, vl); + acc0 = __riscv_vfmacc_vf_f16m4(acc0, Fp16BitsToScalar(A[k]), bv, vl); + if constexpr (Rows > 1) + acc1 = __riscv_vfmacc_vf_f16m4(acc1, Fp16BitsToScalar(A[lda + k]), bv, vl); + if constexpr (Rows > 2) + acc2 = __riscv_vfmacc_vf_f16m4(acc2, Fp16BitsToScalar(A[2 * lda + k]), bv, vl); + if constexpr (Rows > 3) + acc3 = __riscv_vfmacc_vf_f16m4(acc3, Fp16BitsToScalar(A[3 * lda + k]), bv, vl); + } + + StoreFp16(C + n, acc0, vl); + if constexpr (Rows > 1) StoreFp16(C + ldc + n, acc1, vl); + if constexpr (Rows > 2) StoreFp16(C + 2 * ldc + n, acc2, vl); + if constexpr (Rows > 3) StoreFp16(C + 3 * ldc + n, acc3, vl); + + n += vl; + } +} + +} // namespace + +struct MLAS_HALF_GEMM_KERNEL_RVV { + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 4; + static constexpr size_t PackedK = 1; + static constexpr MLAS_HALF_GEMM_STRIDES Strides{16, 128, 256}; +}; + +// FP32->FP16 conversion routines for when AIsfp32/BIsfp32 is set. +// PackNeeded=false means no packing, but these are still called +// to convert FP32 inputs to FP16 on the fly (see matmul.cc). +template <> +MLAS_FORCEINLINE void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + for (size_t m = 0; m < CountM; m++) { + const float* src = A + m * lda; + _mlas_fp16_* dst = D + m * CountK; + size_t k = 0; + while (k < CountK) { + size_t vl = __riscv_vsetvl_e32m4(CountK - k); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(src + k, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2( + dst + k, + __riscv_vreinterpret_v_f16m2_u16m2(fp16), + vl + ); + k += vl; + } + } +} + +template <> +MLAS_FORCEINLINE void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + for (size_t k = 0; k < CountK; k++) { + const float* src = B + k * ldb; + _mlas_fp16_* dst = D + k * CountN; + size_t n = 0; + while (n < CountN) { + size_t vl = __riscv_vsetvl_e32m4(CountN - n); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(src + n, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2( + dst + n, + __riscv_vreinterpret_v_f16m2_u16m2(fp16), + vl + ); + n += vl; + } + } +} + +template <> +MLAS_FORCEINLINE void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode +) +{ + size_t rows = std::min(CountM, MLAS_HALF_GEMM_KERNEL_RVV::KernelMaxM); + + switch (rows) { + case 1: + HalfGemmKernelRvvImpl<1>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + case 2: + HalfGemmKernelRvvImpl<2>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + case 3: + HalfGemmKernelRvvImpl<3>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + default: + HalfGemmKernelRvvImpl<4>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + } +} + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchRvv = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_RVV::PackedK, + MLAS_HALF_GEMM_KERNEL_RVV::KernelMaxM, + 0 +}; + +#endif // MLAS_USE_RVV_ZVFH diff --git a/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp new file mode 100644 index 0000000000000..2bfeba1f1c993 --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp @@ -0,0 +1,109 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + layernorm_kernel_rvv.cpp + +Abstract: + + This module implements LayerNorm/RMSNorm kernels using RISC-V Vector + Extension (RVV). Processes one normalization row at a time. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV) + +#include + +#include +#include + +// Processes one normalization row. A multi-row variant that fuses +// several rows would reduce dispatch overhead for small NormSize. +void MLASCALL +MlasLayerNormKernelRvv( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +) +{ + assert(!Simplified || Bias == nullptr); + const size_t n = NormSize; + + size_t maxvl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vacc_sum = __riscv_vfmv_v_f_f32m4(0.0f, maxvl); + vfloat32m4_t vacc_sumsq = __riscv_vfmv_v_f_f32m4(0.0f, maxvl); + + size_t i = 0; + while (i < n) { + size_t vl = __riscv_vsetvl_e32m4(n - i); + vfloat32m4_t vx = __riscv_vle32_v_f32m4(Input + i, vl); + vacc_sum = __riscv_vfadd_vv_f32m4_tu(vacc_sum, vacc_sum, vx, vl); + vfloat32m4_t vx2 = __riscv_vfmul_vv_f32m4(vx, vx, vl); + vacc_sumsq = __riscv_vfadd_vv_f32m4_tu(vacc_sumsq, vacc_sumsq, vx2, vl); + i += vl; + } + + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1)); + float mean_val = __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(vacc_sum, vzero, maxvl) + ) / + static_cast(n); + float ms_val = __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(vacc_sumsq, vzero, maxvl) + ); + float denom; + if (Simplified) { + denom = sqrtf(ms_val / static_cast(n) + Epsilon); + } else { + denom = sqrtf(ms_val / static_cast(n) - mean_val * mean_val + Epsilon); + } + float inv_denom = 1.0f / denom; + + i = 0; + while (i < n) { + size_t vl = __riscv_vsetvl_e32m4(n - i); + vfloat32m4_t vx = __riscv_vle32_v_f32m4(Input + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(Scale + i, vl); + + if (Simplified) { + vfloat32m4_t vy = __riscv_vfmul_vf_f32m4(vx, inv_denom, vl); + vy = __riscv_vfmul_vv_f32m4(vy, vs, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } else if (Bias == nullptr) { + vfloat32m4_t vy = __riscv_vfsub_vf_f32m4(vx, mean_val, vl); + vy = __riscv_vfmul_vf_f32m4(vy, inv_denom, vl); + vy = __riscv_vfmul_vv_f32m4(vy, vs, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } else { + vfloat32m4_t vb = __riscv_vle32_v_f32m4(Bias + i, vl); + vfloat32m4_t vy = __riscv_vfsub_vf_f32m4(vx, mean_val, vl); + vy = __riscv_vfmul_vf_f32m4(vy, inv_denom, vl); + vy = __riscv_vfmadd_vv_f32m4(vy, vs, vb, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } + + i += vl; + } + + if (MeanOut != nullptr) { + *MeanOut = mean_val; + } + if (InvStdDevOut != nullptr) { + *InvStdDevOut = inv_denom; + } +} + +#endif // MLAS_USE_RVV diff --git a/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp new file mode 100644 index 0000000000000..3cc00624c76bd --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp @@ -0,0 +1,108 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_rvv.cpp + +Abstract: + + This module implements rotary embedding kernels for RISC-V Vector + Extension (RVV). + + For the non-interleaved case: + output[i] = input[i] * cos[i] - input[i + half] * sin[i] + output[i + half] = input[i + half] * cos[i] + input[i] * sin[i] + + For the interleaved case: + output[2i] = input[2i] * cos[i] - input[2i+1] * sin[i] + output[2i+1] = input[2i+1] * cos[i] + input[2i] * sin[i] + +--*/ + +#include + +#include "rotary_embedding.h" + +#if defined(MLAS_USE_RVV) + +#include + +namespace rope_rvv +{ + +void +RopeKernel_Fp32( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) +{ + assert(dim % 2 == 0); + const size_t half_dim = dim / 2; + + if (!interleaved) { + size_t i = 0; + while (i < half_dim) { + size_t vl = __riscv_vsetvl_e32m4(half_dim - i); + + vfloat32m4_t vc = __riscv_vle32_v_f32m4(cos_data + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(sin_data + i, vl); + vfloat32m4_t v0 = __riscv_vle32_v_f32m4(input + i, vl); + vfloat32m4_t v1 = __riscv_vle32_v_f32m4(input + i + half_dim, vl); + + // output[i] = input[i] * cos - input[i+half] * sin + vfloat32m4_t r0 = __riscv_vfmul_vv_f32m4(v0, vc, vl); + r0 = __riscv_vfnmsac_vv_f32m4(r0, vs, v1, vl); + + // output[i+half] = input[i+half] * cos + input[i] * sin + vfloat32m4_t r1 = __riscv_vfmul_vv_f32m4(v1, vc, vl); + r1 = __riscv_vfmacc_vv_f32m4(r1, vs, v0, vl); + + __riscv_vse32_v_f32m4(output + i, r0, vl); + __riscv_vse32_v_f32m4(output + i + half_dim, r1, vl); + + i += vl; + } + } else { + size_t i = 0; + while (i < half_dim) { + size_t vl = __riscv_vsetvl_e32m4(half_dim - i); + + vfloat32m4_t vc = __riscv_vle32_v_f32m4(cos_data + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(sin_data + i, vl); + + vfloat32m4x2_t seg = __riscv_vlseg2e32_v_f32m4x2(input + 2 * i, vl); + vfloat32m4_t v_even = __riscv_vget_v_f32m4x2_f32m4(seg, 0); + vfloat32m4_t v_odd = __riscv_vget_v_f32m4x2_f32m4(seg, 1); + + // output[2i] = even * cos - odd * sin + vfloat32m4_t r_even = __riscv_vfmul_vv_f32m4(v_even, vc, vl); + r_even = __riscv_vfnmsac_vv_f32m4(r_even, vs, v_odd, vl); + + // output[2i+1] = odd * cos + even * sin + vfloat32m4_t r_odd = __riscv_vfmul_vv_f32m4(v_odd, vc, vl); + r_odd = __riscv_vfmacc_vv_f32m4(r_odd, vs, v_even, vl); + + vfloat32m4x2_t out = __riscv_vcreate_v_f32m4x2(r_even, r_odd); + __riscv_vsseg2e32_v_f32m4x2(output + 2 * i, out, vl); + + i += vl; + } + } +} + +} // namespace rope_rvv + +const MLAS_ROPE_DISPATCH MlasRopeDispatchRvv = { + rope_rvv::RopeKernel_Fp32, + nullptr, +}; + +#endif // MLAS_USE_RVV diff --git a/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp index c9253bb033a1a..51b3e24dddff7 100644 --- a/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp +++ b/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp @@ -142,6 +142,7 @@ MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1( assert(pad_bottom <= 1); assert(pad_left <= 1); assert(pad_right <= 1); + MLAS_UNREFERENCED_PARAMETER(pad_bottom); const float beta = Parameters->Beta; const bool accumulate_output = beta != 0.0f; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 7dd9d994e52b4..f8ea6d4003619 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -38,6 +38,20 @@ void ComputeJob( ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload ORT_UNUSED_PARAMETER(alloc); + int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); + + if constexpr (std::is_same_v) { + if (MlasLayerNormF32( + X_data + task_idx * norm_size, scale_data + i, + (simplified || !bias_data) ? nullptr : bias_data + i, + Y_data + task_idx * norm_size, + mean_data ? &mean_data[task_idx] : nullptr, + inv_std_dev_data ? &inv_std_dev_data[task_idx] : nullptr, + static_cast(norm_size), epsilon, simplified)) { + return; + } + } + const T* p_input = X_data + task_idx * norm_size; T* p_output = Y_data + task_idx * norm_size; @@ -57,9 +71,6 @@ void ComputeJob( mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); } - // Compute the offset of gamma and beta to support broadcasting. - int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); - for (int64_t h = 0; h < norm_size; h++, i++) { if (simplified) { p_output[h] = p_output[h] / mean_square * scale_data[i]; diff --git a/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp new file mode 100644 index 0000000000000..bfdcb1d3c8cfc --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp @@ -0,0 +1,165 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of FP16<->FP32 cast kernels. + + Scalar path: ORT's internal fallback in cast.cpp + (MLAS_Half2Float / MLAS_Float2Half loop when CastKernel == nullptr) + Dispatch path: MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer + (dispatches to registered RVV kernel via platform.CastF16ToF32Kernel) + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t count = 1024 * 64; + size_t iters = 200; + size_t warmup = 20; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--count") + options.count = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + const size_t N = opts.count; + + std::cout << "=== FP16<->FP32 Cast: RVV Dispatch vs ORT Scalar Fallback ===\n" + << " count=" << N << " iters=" << opts.iters << " warmup=" << opts.warmup << "\n\n"; + + std::vector fp32_src(N); + std::vector<_mlas_fp16_> fp16_src(N); + for (size_t i = 0; i < N; ++i) { + fp32_src[i] = MakeValue(i); + fp16_src[i] = MLAS_Float2Half(fp32_src[i]); + } + + std::vector f16_to_f32_fallback(N), f16_to_f32_dispatch(N); + std::vector<_mlas_fp16_> f32_to_f16_fallback(N), f32_to_f16_dispatch(N); + + // ORT scalar fallback: same as cast.cpp when CastF16ToF32Kernel == nullptr + // for (i) Destination[i] = Source[i].ToFloat(); // calls MLAS_Half2Float + auto fallback_h2f = [&]() { + for (size_t i = 0; i < N; ++i) + f16_to_f32_fallback[i] = MLAS_Half2Float(fp16_src[i]); + }; + auto fallback_f2h = [&]() { + for (size_t i = 0; i < N; ++i) + f32_to_f16_fallback[i] = MLAS_Float2Half(fp32_src[i]); + }; + + // ORT dispatch path: MlasConvertHalfToFloatBuffer (uses registered RVV kernel) + auto dispatch_h2f = [&]() { + MlasConvertHalfToFloatBuffer( + reinterpret_cast(fp16_src.data()), + f16_to_f32_dispatch.data(), N); + }; + auto dispatch_f2h = [&]() { + MlasConvertFloatToHalfBuffer( + fp32_src.data(), + reinterpret_cast(f32_to_f16_dispatch.data()), N); + }; + + // --- Correctness --- + fallback_h2f(); + dispatch_h2f(); + fallback_f2h(); + dispatch_f2h(); + + size_t h2f_mismatches = 0; + for (size_t i = 0; i < N; ++i) { + if (f16_to_f32_fallback[i] != f16_to_f32_dispatch[i]) h2f_mismatches++; + } + size_t f2h_mismatches = 0; + for (size_t i = 0; i < N; ++i) { + if (f32_to_f16_fallback[i] != f32_to_f16_dispatch[i]) f2h_mismatches++; + } + + std::cout << "Correctness:\n" + << " F16->F32: mismatches=" << h2f_mismatches << "/" << N + << (h2f_mismatches == 0 ? " PASS" : " FAIL") << "\n" + << " F32->F16: mismatches=" << f2h_mismatches << "/" << N + << (f2h_mismatches == 0 ? " PASS" : " FAIL") << "\n"; + + // --- Performance --- + for (size_t i = 0; i < opts.warmup; ++i) { + fallback_h2f(); + dispatch_h2f(); + fallback_f2h(); + dispatch_f2h(); + } + + double s_h2f = TimeLoop(opts.iters, fallback_h2f) / opts.iters; + double d_h2f = TimeLoop(opts.iters, dispatch_h2f) / opts.iters; + double s_f2h = TimeLoop(opts.iters, fallback_f2h) / opts.iters; + double d_f2h = TimeLoop(opts.iters, dispatch_f2h) / opts.iters; + + std::cout << std::fixed << std::setprecision(3) + << "\nF16->F32 (" << N << " elements):\n" + << " ORT Fallback: " << s_h2f << " ms\n" + << " ORT Dispatch: " << d_h2f << " ms\n" + << " Speedup: " << s_h2f / d_h2f << "x\n" + << "\nF32->F16 (" << N << " elements):\n" + << " ORT Fallback: " << s_f2h << " ms\n" + << " ORT Dispatch: " << d_f2h << " ms\n" + << " Speedup: " << s_f2h / d_f2h << "x\n"; + + return (h2f_mismatches + f2h_mismatches > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp new file mode 100644 index 0000000000000..0f74c4ec7017b --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp @@ -0,0 +1,253 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RVV-accelerated FP16 GEMM + against ORT's built-in scalar FP16 GEMM dispatch (MlasHalfGemmDispatchDefault). + + Both paths use the same MLAS HalfGemm dispatch interface with FP16 I/O. + + Usage: + ./onnxruntime_mlas_halfgemm_rvv_bench [--m=N] [--n=N] [--k=N] + [--iters=N] [--warmup=N] [--bias=0|1] + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" +#include "halfgemm.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t m = 64; + size_t n = 768; + size_t k = 768; + size_t iters = 20; + size_t warmup = 3; + bool use_bias = false; +}; + +void PrintUsage(const char* argv0) { + std::cout + << "Usage: " << argv0 + << " [--m=N] [--n=N] [--k=N] [--iters=N] [--warmup=N] [--bias=0|1]\n"; +} + +bool ParseBool(std::string_view value) { + return value == "1" || value == "true" || value == "on" || value == "yes"; +} + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + if (arg == "--help" || arg == "-h") { + PrintUsage(argv[0]); + std::exit(0); + } + const auto split = arg.find('='); + if (split == std::string_view::npos || split == 0 || split + 1 >= arg.size()) { + continue; + } + const std::string_view key = arg.substr(0, split); + const std::string_view value = arg.substr(split + 1); + if (key == "--m") + options.m = std::strtoull(value.data(), nullptr, 10); + else if (key == "--n") + options.n = std::strtoull(value.data(), nullptr, 10); + else if (key == "--k") + options.k = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + else if (key == "--bias") + options.use_bias = ParseBool(value); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + const uint32_t bucket = x % 2048u; + return (static_cast(bucket) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +void RunDispatch( + const MLAS_HALFGEMM_DISPATCH& dispatch, + size_t M, size_t N, size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* data) { + dispatch.Operation(N, K, data, 0, M, 0, N); +} + +} // namespace + +int main(int argc, char** argv) { + const Options options = ParseArgs(argc, argv); + + if (options.m == 0 || options.n == 0 || options.k == 0 || options.iters == 0) { + std::cerr << "m, n, k, and iters must be > 0\n"; + return 1; + } + + const bool fp16_supported = MlasFp16AccelerationSupported(); + + std::cout << "=== FP16 GEMM: RVV vs ORT Scalar Dispatch ===\n" + << " M=" << options.m << " N=" << options.n << " K=" << options.k + << " bias=" << (options.use_bias ? "yes" : "no") << "\n" + << " iters=" << options.iters << " warmup=" << options.warmup << "\n" + << " FP16 acceleration: " << (fp16_supported ? "YES (RVV)" : "NO") << "\n\n"; + + const size_t a_size = options.m * options.k; + const size_t b_size = options.k * options.n; + const size_t c_size = options.m * options.n; + + std::vector<_mlas_fp16_> a_fp16(a_size); + std::vector<_mlas_fp16_> b_fp16(b_size); + std::vector<_mlas_fp16_> bias_fp16(options.n); + std::vector<_mlas_fp16_> c_rvv(c_size); + std::vector<_mlas_fp16_> c_scalar(c_size); + + for (size_t i = 0; i < a_size; ++i) { + a_fp16[i] = MLAS_Float2Half(MakeValue(i) * 0.1f); + } + for (size_t i = 0; i < b_size; ++i) { + b_fp16[i] = MLAS_Float2Half(MakeValue(i + a_size) * 0.1f); + } + if (options.use_bias) { + for (size_t i = 0; i < options.n; ++i) { + bias_fp16[i] = MLAS_Float2Half(MakeValue(i + a_size + b_size) * 0.01f); + } + } + + MLAS_HALF_GEMM_DATA_PARAMS params_scalar; + params_scalar.A = a_fp16.data(); + params_scalar.lda = options.k; + params_scalar.B = b_fp16.data(); + params_scalar.ldb = options.n; + params_scalar.C = reinterpret_cast(c_scalar.data()); + params_scalar.ldc = options.n; + params_scalar.Bias = options.use_bias + ? reinterpret_cast(bias_fp16.data()) + : nullptr; + params_scalar.AIsfp32 = false; + params_scalar.BIsfp32 = false; + params_scalar.OutputProcessor = nullptr; + + MLAS_HALF_GEMM_DATA_PARAMS params_rvv = params_scalar; + params_rvv.C = reinterpret_cast(c_rvv.data()); + + // --- Run both dispatches --- + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_scalar); + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + RunDispatch(MlasHalfGemmDispatchRvv, options.m, options.n, options.k, ¶ms_rvv); +#else + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_rvv); + std::cout << " (RVV dispatch not available, comparing scalar vs scalar)\n\n"; +#endif + + // --- Correctness: RVV vs ORT Scalar --- + double max_abs_err = 0.0; + double max_rel_err = 0.0; + size_t error_count = 0; + + for (size_t i = 0; i < c_size; ++i) { + float ref = MLAS_Half2Float(c_scalar[i]); + float got = MLAS_Half2Float(c_rvv[i]); + double abs_err = std::abs(ref - got); + double rel_err = (std::abs(ref) > 1e-6) ? abs_err / std::abs(ref) : abs_err; + + if (abs_err > max_abs_err) max_abs_err = abs_err; + if (rel_err > max_rel_err) max_rel_err = rel_err; + + if (rel_err > 0.10 && abs_err > 0.005) { + if (error_count < 10) { + std::cerr << " MISMATCH [" << i / options.n << "," << i % options.n + << "]: scalar=" << ref << " rvv=" << got + << " abs=" << abs_err << " rel=" << rel_err << "\n"; + } + error_count++; + } + } + + std::cout << "Correctness (RVV vs ORT Scalar):\n" + << " max abs error: " << max_abs_err << "\n" + << " max rel error: " << max_rel_err << "\n" + << " mismatches (>10% rel && >0.005 abs): " << error_count + << " / " << c_size << "\n"; + + if (error_count > 0) { + std::cout << " STATUS: FAIL\n\n"; + } else { + std::cout << " STATUS: PASS\n\n"; + } + + // --- Performance --- + auto run_scalar_fn = [&]() { + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_scalar); + }; + + auto run_rvv_fn = [&]() { +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + RunDispatch(MlasHalfGemmDispatchRvv, options.m, options.n, options.k, ¶ms_rvv); +#else + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_rvv); +#endif + }; + + for (size_t i = 0; i < options.warmup; ++i) { + run_scalar_fn(); + run_rvv_fn(); + } + + const double scalar_ms = TimeLoop(options.iters, run_scalar_fn); + const double scalar_avg = scalar_ms / static_cast(options.iters); + + const double rvv_ms = TimeLoop(options.iters, run_rvv_fn); + const double rvv_avg = rvv_ms / static_cast(options.iters); + + const double flops = 2.0 * options.m * options.n * options.k; + const double scalar_gflops = flops / (scalar_avg * 1e6); + const double rvv_gflops = flops / (rvv_avg * 1e6); + const double speedup = scalar_avg / rvv_avg; + + std::cout << std::fixed << std::setprecision(3) + << "Performance:\n" + << " ORT Scalar: " << scalar_avg << " ms (" << scalar_gflops << " GFLOPS)\n" + << " RVV: " << rvv_avg << " ms (" << rvv_gflops << " GFLOPS)\n" + << " Speedup: " << speedup << "x\n"; + + return (error_count > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp new file mode 100644 index 0000000000000..df777744da4cb --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp @@ -0,0 +1,169 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rmsnorm_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RMSNorm (SimplifiedLayerNorm). + + Scalar path: ORT's ComputeJob with simplified=true + (anonymous namespace in layer_norm_impl.cc, reproduced here verbatim) + MLAS path: MlasLayerNormF32 dispatch (uses RVV kernel when available) + +--*/ + +#include "mlas.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t hidden = 1024; + size_t iters = 500; + size_t warmup = 50; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--hidden") + options.hidden = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +// +// ORT scalar path: verbatim from layer_norm_impl.cc ComputeJob +// with simplified=true (RMSNorm). +// +void OrtRmsNormScalar( + const float* input, + const float* scale, + size_t norm_size, + float epsilon, + float* output) { + float mean_square = 0.0f; + for (size_t h = 0; h < norm_size; h++) { + output[h] = input[h]; + mean_square += input[h] * input[h]; + } + mean_square = sqrtf(mean_square / static_cast(norm_size) + epsilon); + for (size_t h = 0; h < norm_size; h++) { + output[h] = output[h] / mean_square * scale[h]; + } +} + +void OrtRmsNormMlas( + const float* input, + const float* scale, + size_t norm_size, + float epsilon, + float* output) { + if (!MlasLayerNormF32(input, scale, nullptr, output, nullptr, nullptr, + norm_size, epsilon, true)) { + OrtRmsNormScalar(input, scale, norm_size, epsilon, output); + } +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + const size_t N = opts.hidden; + const float epsilon = 1e-6f; + + std::cout << "=== RMSNorm: MLAS Dispatch vs ORT Scalar ===\n" + << " hidden=" << N << " iters=" << opts.iters << " warmup=" << opts.warmup << "\n\n"; + + std::vector input(N), scale(N); + std::vector out_scalar(N), out_rvv(N); + + for (size_t i = 0; i < N; i++) { + input[i] = MakeValue(i) * 0.1f; + scale[i] = 1.0f + MakeValue(i + N) * 0.01f; + } + + // --- Correctness --- + OrtRmsNormScalar(input.data(), scale.data(), N, epsilon, out_scalar.data()); + OrtRmsNormMlas(input.data(), scale.data(), N, epsilon, out_rvv.data()); + + double max_abs = 0.0, max_rel = 0.0; + size_t mismatches = 0; + for (size_t i = 0; i < N; i++) { + double abs_err = std::abs(out_scalar[i] - out_rvv[i]); + double rel_err = (std::abs(out_scalar[i]) > 1e-7) ? abs_err / std::abs(out_scalar[i]) : abs_err; + if (abs_err > max_abs) max_abs = abs_err; + if (rel_err > max_rel) max_rel = rel_err; + if (abs_err > 1e-5) mismatches++; + } + + std::cout << "Correctness:\n" + << " max_abs=" << max_abs << " max_rel=" << max_rel + << " mismatches=" << mismatches << "/" << N + << (mismatches == 0 ? " PASS" : " FAIL") << "\n"; + + // --- Performance --- + auto run_scalar = [&]() { + OrtRmsNormScalar(input.data(), scale.data(), N, epsilon, out_scalar.data()); + }; + auto run_rvv = [&]() { + OrtRmsNormMlas(input.data(), scale.data(), N, epsilon, out_rvv.data()); + }; + + for (size_t i = 0; i < opts.warmup; i++) { + run_scalar(); + run_rvv(); + } + + double scalar_ms = TimeLoop(opts.iters, run_scalar) / opts.iters; + double rvv_ms = TimeLoop(opts.iters, run_rvv) / opts.iters; + + std::cout << std::fixed << std::setprecision(4) + << "\nPerformance:\n" + << " ORT Scalar: " << scalar_ms * 1000 << " us\n" + << " RVV: " << rvv_ms * 1000 << " us\n" + << " Speedup: " << scalar_ms / rvv_ms << "x\n"; + + return (mismatches > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp new file mode 100644 index 0000000000000..cbe6941bf8ca7 --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp @@ -0,0 +1,152 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rope_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RotaryEmbedding. + + Scalar path: MlasRotaryEmbedOneRow_FallBack (ORT's internal scalar fallback) + Dispatch path: MlasRotaryEmbedOneRow (dispatches to RVV kernel via platform) + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" +#include "rotary_embedding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t dim = 128; + size_t iters = 500; + size_t warmup = 50; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--dim") + options.dim = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +void CompareResults(const float* ref, const float* got, size_t n) { + double max_abs = 0.0, max_rel = 0.0; + size_t mismatches = 0; + for (size_t i = 0; i < n; i++) { + double abs_err = std::abs(ref[i] - got[i]); + double rel_err = (std::abs(ref[i]) > 1e-7) ? abs_err / std::abs(ref[i]) : abs_err; + if (abs_err > max_abs) max_abs = abs_err; + if (rel_err > max_rel) max_rel = rel_err; + if (abs_err > 1e-5) mismatches++; + } + std::cout << " max_abs=" << max_abs << " max_rel=" << max_rel + << " mismatches=" << mismatches << "/" << n + << (mismatches == 0 ? " PASS" : " FAIL") << "\n"; +} + +void BenchRoPE(const char* label, size_t dim, bool interleaved, size_t iters, size_t warmup) { + if (dim % 2 != 0) { + std::cerr << "Error: dim must be even, got " << dim << "\n"; + return; + } + const size_t half = dim / 2; + + std::vector input(dim), sin_data(half), cos_data(half); + std::vector out_fallback(dim), out_dispatch(dim); + + for (size_t i = 0; i < dim; i++) input[i] = MakeValue(i); + for (size_t i = 0; i < half; i++) { + sin_data[i] = sinf(static_cast(i) * 0.01f); + cos_data[i] = cosf(static_cast(i) * 0.01f); + } + + // ORT scalar fallback + MlasRotaryEmbedOneRow_FallBack( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_fallback.data()); + // ORT dispatch (→ RVV) + MlasRotaryEmbedOneRow( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_dispatch.data()); + + std::cout << "--- " << label << " (dim=" << dim << ") ---\n"; + CompareResults(out_fallback.data(), out_dispatch.data(), dim); + + auto run_fallback = [&]() { + MlasRotaryEmbedOneRow_FallBack( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_fallback.data()); + }; + auto run_dispatch = [&]() { + MlasRotaryEmbedOneRow( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_dispatch.data()); + }; + + for (size_t i = 0; i < warmup; i++) { + run_fallback(); + run_dispatch(); + } + + double fallback_ms = TimeLoop(iters, run_fallback) / iters; + double dispatch_ms = TimeLoop(iters, run_dispatch) / iters; + + std::cout << std::fixed << std::setprecision(4) + << " ORT Fallback: " << fallback_ms * 1000 << " us\n" + << " ORT Dispatch: " << dispatch_ms * 1000 << " us\n" + << " Speedup: " << fallback_ms / dispatch_ms << "x\n\n"; +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + + std::cout << "=== RotaryEmbedding: RVV Dispatch vs ORT Scalar Fallback ===\n\n"; + + BenchRoPE("RoPE non-interleaved", opts.dim, false, opts.iters, opts.warmup); + BenchRoPE("RoPE interleaved", opts.dim, true, opts.iters, opts.warmup); + + return 0; +} diff --git a/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp b/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp new file mode 100644 index 0000000000000..1b8126b384f1e --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp @@ -0,0 +1,120 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_cast_fp16.cpp + +Abstract: + + Tests for MLAS FP16<->FP32 cast kernels. + Verifies bit-exactness against MLAS_Half2Float / MLAS_Float2Half. + +--*/ + +#include "test_util.h" +#include "mlas.h" +#include "mlas_float16.h" + +#include + +class MlasCastFp16Test : public MlasTestBase { + public: + void TestF16ToF32(size_t count) { + std::vector<_mlas_fp16_> input(count); + std::vector output_ref(count); + std::vector output_dispatch(count); + + for (size_t i = 0; i < count; i++) { + float val = (static_cast(i % 2048) / 1024.0f) - 1.0f; + input[i] = MLAS_Float2Half(val); + output_ref[i] = MLAS_Half2Float(input[i]); + } + + MlasConvertHalfToFloatBuffer( + reinterpret_cast(input.data()), + output_dispatch.data(), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(output_dispatch[i], output_ref[i]) + << "F16->F32 mismatch at [" << i << "], count=" << count; + } + } + + void TestF32ToF16(size_t count) { + std::vector input(count); + std::vector<_mlas_fp16_> output_ref(count); + std::vector<_mlas_fp16_> output_dispatch(count); + + for (size_t i = 0; i < count; i++) { + input[i] = (static_cast(i % 2048) / 1024.0f) - 1.0f; + output_ref[i] = MLAS_Float2Half(input[i]); + } + + MlasConvertFloatToHalfBuffer( + input.data(), + reinterpret_cast(output_dispatch.data()), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(output_dispatch[i], output_ref[i]) + << "F32->F16 mismatch at [" << i << "], count=" << count; + } + } +}; + +class CastFp16ShortExecuteTest : public MlasTestFixture { + public: + CastFp16ShortExecuteTest(size_t count, bool f16_to_f32) + : count_(count), f16_to_f32_(f16_to_f32) {} + + void TestBody() override { + if (f16_to_f32_) { + MlasTestFixture::mlas_tester->TestF16ToF32(count_); + } else { + MlasTestFixture::mlas_tester->TestF32ToF16(count_); + } + } + + static size_t RegisterSingleTest(size_t count, bool f16_to_f32) { + std::stringstream ss; + ss << "/" << (f16_to_f32 ? "F16toF32" : "F32toF16") + << "/count" << count; + auto test_name = ss.str(); + + testing::RegisterTest( + "CastFp16", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new CastFp16ShortExecuteTest(count, f16_to_f32); + }); + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t cnt = 0; + for (size_t n : {1, 7, 15, 16, 31, 32, 63, 64, 128, 255, 256, 1024, 65536}) { + cnt += RegisterSingleTest(n, true); + cnt += RegisterSingleTest(n, false); + } + return cnt; + } + + private: + size_t count_; + bool f16_to_f32_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return CastFp16ShortExecuteTest::RegisterShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/mlas/unittest/test_layernorm.cpp b/onnxruntime/test/mlas/unittest/test_layernorm.cpp new file mode 100644 index 0000000000000..7475f082bb443 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_layernorm.cpp @@ -0,0 +1,157 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_layernorm.cpp + +Abstract: + + Tests for MLAS LayerNorm/RMSNorm (MlasLayerNormF32). + +--*/ + +#include "test_util.h" +#include "mlas.h" + +#include +#include + +class MlasLayerNormTest : public MlasTestBase { + private: + void ScalarLayerNorm( + const float* input, + const float* scale, + const float* bias, + float* output, + float* mean_out, + float* inv_std_out, + size_t norm_size, + float epsilon, + bool simplified) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (size_t i = 0; i < norm_size; i++) { + sum += input[i]; + sum_sq += input[i] * input[i]; + } + float mean = sum / static_cast(norm_size); + float denom; + if (simplified) { + denom = std::sqrt(sum_sq / static_cast(norm_size) + epsilon); + } else { + denom = std::sqrt(sum_sq / static_cast(norm_size) - mean * mean + epsilon); + } + float inv_denom = 1.0f / denom; + + for (size_t i = 0; i < norm_size; i++) { + if (simplified) { + output[i] = input[i] * inv_denom * scale[i]; + } else if (bias == nullptr) { + output[i] = (input[i] - mean) * inv_denom * scale[i]; + } else { + output[i] = (input[i] - mean) * inv_denom * scale[i] + bias[i]; + } + } + if (mean_out) *mean_out = mean; + if (inv_std_out) *inv_std_out = inv_denom; + } + + public: + void Test(size_t norm_size, bool simplified, bool with_bias) { + std::vector input(norm_size); + std::vector scale(norm_size); + std::vector bias(norm_size); + std::vector output_ref(norm_size); + std::vector output_mlas(norm_size); + float mean_ref = 0, mean_mlas = 0; + float inv_std_ref = 0, inv_std_mlas = 0; + + for (size_t i = 0; i < norm_size; i++) { + input[i] = (static_cast(i % 127) - 63.0f) * 0.01f; + scale[i] = 1.0f + (static_cast(i % 31) - 15.0f) * 0.001f; + bias[i] = (static_cast(i % 17) - 8.0f) * 0.005f; + } + + const float* bias_ptr = (with_bias && !simplified) ? bias.data() : nullptr; + + ScalarLayerNorm(input.data(), scale.data(), bias_ptr, + output_ref.data(), &mean_ref, &inv_std_ref, + norm_size, 1e-5f, simplified); + + bool used = MlasLayerNormF32(input.data(), scale.data(), bias_ptr, + output_mlas.data(), &mean_mlas, &inv_std_mlas, + norm_size, 1e-5f, simplified); + + if (!used) { + // No optimized kernel available, skip comparison + return; + } + + for (size_t i = 0; i < norm_size; i++) { + ASSERT_NEAR(output_mlas[i], output_ref[i], 1e-4f) + << "output mismatch at [" << i << "], norm_size=" << norm_size + << " simplified=" << simplified << " bias=" << with_bias; + } + ASSERT_NEAR(mean_mlas, mean_ref, 1e-4f) << "mean mismatch"; + ASSERT_NEAR(inv_std_mlas, inv_std_ref, 1e-4f) << "inv_std_dev mismatch"; + } +}; + +class LayerNormShortExecuteTest : public MlasTestFixture { + public: + LayerNormShortExecuteTest(size_t norm_size, bool simplified, bool with_bias) + : norm_size_(norm_size), simplified_(simplified), with_bias_(with_bias) {} + + void TestBody() override { + MlasTestFixture::mlas_tester->Test(norm_size_, simplified_, with_bias_); + } + + static size_t RegisterSingleTest(size_t norm_size, bool simplified, bool with_bias) { + std::stringstream ss; + ss << "/norm_size" << norm_size + << "/simplified" << simplified + << "/bias" << with_bias; + auto test_name = ss.str(); + + testing::RegisterTest( + "LayerNorm", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new LayerNormShortExecuteTest(norm_size, simplified, with_bias); + }); + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t count = 0; + for (size_t n : {1, 7, 32, 63, 64, 127, 128, 256, 1024}) { + for (bool simplified : {true, false}) { + for (bool with_bias : {true, false}) { + count += RegisterSingleTest(n, simplified, with_bias); + } + } + } + return count; + } + + private: + size_t norm_size_; + bool simplified_; + bool with_bias_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return LayerNormShortExecuteTest::RegisterShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index eeb369224d523..c6e4b3d6545ae 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -124,8 +124,8 @@ class RoPEShortExecuteTest : public MlasTestFixture> { bool interleaved_; }; -// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. -#ifdef MLAS_TARGET_AMD64 +// Enable RoPE tests on platforms where RopeDispatch is assigned. +#if defined(MLAS_TARGET_AMD64) || (defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV)) static size_t RoPERegisterAllShortExecuteTests() { return RoPEShortExecuteTest::RegisterShortExecuteTests() + RoPEShortExecuteTest::RegisterShortExecuteTests(); } From d464b2a0a6d289a95c92dcfc1abaea05f87168f0 Mon Sep 17 00:00:00 2001 From: adrastogi Date: Sat, 23 May 2026 00:51:01 -0700 Subject: [PATCH 16/16] Add example and documentation for kOrtEpDevice_EpMetadataKey_OSDriverVersion (#28282) ### Description This change provides details on the expected value for `kOrtEpDevice_EpMetadataKey_OSDriverVersion` and provides an example in the plugin EP tests. ### Motivation and Context We introduced a new EP metadata key called `kOrtEpDevice_EpMetadataKey_OSDriverVersion` in #26616 but neglected to provide sufficient detail to guide EP authors in how to fill in the value for this key. This PR is an attempt to address that gap. --------- Co-authored-by: Aditya Rastogi Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../core/session/onnxruntime_ep_device_ep_metadata_keys.h | 4 ++++ .../test/autoep/library/example_plugin_ep/ep_factory.cc | 5 +++++ onnxruntime/test/autoep/test_registration.cc | 2 ++ 3 files changed, 11 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 5ea4261840299..cc83b7bca50c5 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -10,6 +10,10 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Key for the execution provider OS driver version. +// Value should be a 4-part dot-separated version string in the format "a.b.c.d" (e.g., "31.0.101.4502"). +// This maps to the Windows DXCore adapter property DXCoreAdapterProperty::DriverVersion +// (https://learn.microsoft.com/en-us/windows/win32/api/dxcore_interface/ne-dxcore_interface-dxcoreadapterproperty). +// On non-Windows platforms, the EP should provide an equivalent OS-level driver version if available. static const char* const kOrtEpDevice_EpMetadataKey_OSDriverVersion = "os_driver_version"; // Prefix for execution provider compatibility information stored in model metadata. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 6137b23111bf9..e003f3bd93786 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -11,6 +11,7 @@ #include "ep_data_transfer.h" #include "ep_stream_support.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger) @@ -141,6 +142,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // random example using made up values factory->ort_api.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+"); + // Example os_driver_version. A real EP would read the OS driver version from the device. + // The format is a 4-part dot-separated version matching the DXCore DriverVersion property. + factory->ort_api.AddKeyValuePair(ep_metadata, kOrtEpDevice_EpMetadataKey_OSDriverVersion, "31.0.101.1000"); factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); // OrtEpDevice copies ep_metadata and ep_options. @@ -171,6 +175,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // Ort::KeyValuePairs ep_metadata; // Ort::KeyValuePairs ep_options; // ep_metadata.Add("supported_devices", "CrackGriffin 7+"); + // ep_metadata.Add(kOrtEpDevice_EpMetadataKey_OSDriverVersion, "31.0.101.1000"); // ep_options.Add("run_really_fast", "true"); // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; // ep_devices[num_ep_devices++] = ep_device.release(); diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 79bc34572a6f7..40ac1670b07dc 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -70,6 +70,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { auto metadata = test_ep_device->EpMetadata(); ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); + // Verify the example plugin's expected os_driver_version value. + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_OSDriverVersion), "31.0.101.1000"); auto options = test_ep_device->EpOptions(); ASSERT_STREQ(options.GetValue("run_really_fast"), "true");