diff --git a/cgmanifests/README.md b/cgmanifests/README.md index a7d816a401a95..5e356e4507141 100644 --- a/cgmanifests/README.md +++ b/cgmanifests/README.md @@ -1,3 +1,7 @@ # CGManifest Files This directory contains CGManifest (cgmanifest.json) files. -See [here](https://docs.opensource.microsoft.com/tools/cg/cgmanifest.html) for details. \ No newline at end of file +See [here](https://docs.opensource.microsoft.com/tools/cg/cgmanifest.html) for details. + +The WebGPU-specific manifest is in `webgpu/cgmanifest.webgpu.json`. It is intentionally not named `cgmanifest.json` +so default whole-repository Component Governance scans do not pick it up automatically. WebGPU packaging or +NOTICE-generation pipelines should stage it as `cgmanifest.json` in their scan input. diff --git a/cgmanifests/webgpu/README.md b/cgmanifests/webgpu/README.md new file mode 100644 index 0000000000000..cf03477ea6bbe --- /dev/null +++ b/cgmanifests/webgpu/README.md @@ -0,0 +1,61 @@ +# WebGPU Component Governance manifest + +This directory contains the WebGPU-specific Component Governance manifest for ONNX Runtime. It covers Dawn and the +Dawn-derived dependency graph used when building the WebGPU Execution Provider. + +The manifest is named `cgmanifest.webgpu.json`, not `cgmanifest.json`, so default whole-repository Component +Governance scans do not pick it up automatically. WebGPU packaging and NOTICE-generation pipelines should stage or copy +this file as `cgmanifest.json` in the source directory that they scan for WebGPU package notices. + +## Classification policy + +The Component Governance manifest schema provides a `developmentDependency` boolean, but it does not provide separate +first-class fields for runtime, build-tool, test-only, or conditional dependencies. This manifest uses: + +- no `developmentDependency` field for components that are redistributed, statically linked, or otherwise part of the + WebGPU package/runtime dependency closure; +- `developmentDependency: true` for Dawn dependencies that are only build tools, tests, disabled optional backends, or + source inputs that current WebGPU packages do not redistribute; +- `comments` to preserve the more precise classification and Dawn `DEPS` path/condition. + +If a WebGPU package starts redistributing a component currently marked as a development dependency, update that +registration and explain the packaging path in `comments` and `detectedComponentLocations`. + +## Maintenance + +When rolling Dawn or changing WebGPU packaging: + +1. Update the Dawn registration to match the `dawn` entry in `cmake/deps.txt`. +2. Re-audit the Dawn dependency graph for the pinned Dawn commit: + - Start from the Dawn commit in `cmake/deps.txt`; do not audit Dawn `main` or a different roll. + - Inspect Dawn's `tools/fetch_dawn_dependencies.py` at that commit. For ORT's normal source-fetch path, + `cmake/external/onnxruntime_external_deps.cmake` enables `DAWN_FETCH_DEPENDENCIES`, so the script's + `required_submodules` list is the primary set of Dawn source dependencies fetched for the build. + - Cross-reference each fetched submodule path with Dawn's `DEPS` file to get the public upstream repository URL, + commit, and condition. Use public upstream identities in this manifest, not internal mirrors. + - Compare that fetched set against this manifest. Add new fetched components, update changed commits or repository + URLs, and remove entries that are no longer fetched or relevant unless CG/legal guidance requires keeping them. + - Cross-check ORT's Dawn CMake options in `cmake/external/onnxruntime_external_deps.cmake` and Dawn's + `third_party/CMakeLists.txt` before classifying a component. Components that are redistributed, statically linked, + or otherwise part of the WebGPU package/runtime closure should not be marked as development dependencies; build + tools, test inputs, disabled optional backends, and unfetched conditional dependencies should be marked + `developmentDependency: true` if they remain registered. + - Verify actual WebGPU package contents, especially platform-specific artifacts. For example, the Windows WebGPU + plugin pipeline downloads and redistributes DXC DLLs separately from Dawn's `third_party/dxc` source dependency, so + both the Dawn build-input registration and the redistributed DXC release registration may need review. + - Keep Dawn-derived registrations connected to the Dawn root with `dependencyRoots`. +3. If the Windows WebGPU plugin pipeline changes the downloaded DXC release, update the DirectXShaderCompiler release + registration to match `tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml`. +4. Run: + + ```powershell + python cgmanifests/webgpu/validate_webgpu_cgmanifest.py + ``` + +The validator checks for stale Dawn and DXC pins, but it does not replace the manual dependency classification review +in step 2. + +Non-git Dawn toolchain packages from CIPD/GCS, such as GN, Ninja, CMake, Go, Siso, reclient, and sysroots, are +intentionally not registered here unless they become redistributed or CG/legal guidance requires build input coverage. +They do not have stable public upstream source identities in the Dawn `DEPS` file and are not part of current WebGPU +package contents. diff --git a/cgmanifests/webgpu/cgmanifest.webgpu.json b/cgmanifests/webgpu/cgmanifest.webgpu.json new file mode 100644 index 0000000000000..90448c9b4a68e --- /dev/null +++ b/cgmanifests/webgpu/cgmanifest.webgpu.json @@ -0,0 +1,1110 @@ +{ + "$schema": "https://json.schemastore.org/component-detection-manifest.json", + "version": 1, + "registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + }, + "comments": "runtime; WebGPU EP root dependency pinned in cmake/deps.txt and patched by cmake/external/onnxruntime_external_deps.cmake.", + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/deps.txt", + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake", + "{SourceFileRoot}/cmake/patches/dawn" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4711839eb9a87da7c3436d9b212e0492359fbbd", + "repositoryUrl": "https://github.com/microsoft/DirectXShaderCompiler.git", + "tag": "v1.8.2502" + } + }, + "comments": "runtime; redistributed by Windows WebGPU plugin packages as dxil.dll and dxcompiler.dll. Release zip: https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2502/dxc_2025_02_20.zip; SHA256: 70B1913A1BFCE4A3E1A5311D16246F4ECDF3A3E613ABEC8AA529E57668426F85.", + "detectedComponentLocations": [ + "{SourceFileRoot}/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml", + "{SourceFileRoot}/plugin-ep-webgpu/csharp/pack_nuget.py", + "{SourceFileRoot}/plugin-ep-webgpu/python/build_wheel.py" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7ef32bbacabd0d04a6cfac92a542841c531e1b21", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/abseil-cpp" + } + }, + "comments": "runtime; Dawn DEPS third_party/abseil-cpp. ORT static WebGPU builds point Dawn at ORT's Abseil source when available.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a0f4dc977fa2ef7f47708aec914a4fbfeefc6103", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/protobuf" + } + }, + "comments": "runtime; Dawn DEPS third_party/protobuf. ORT static WebGPU builds point Dawn at ORT's Protobuf source when available.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "f31ca173eff866369e54d35e53375fadbabd58f4", + "repositoryUrl": "https://github.com/KhronosGroup/SPIRV-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/spirv-headers/src used by Dawn/Tint SPIR-V support.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cb38b2342beedde25bcff582dc3528a135cf6e67", + "repositoryUrl": "https://github.com/KhronosGroup/SPIRV-Tools.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/spirv-tools/src used by Dawn/Tint SPIR-V support.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "49f1a381e2aec33ef32adf4a377b5a39ec016ec4", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan-headers/src and ORT Dawn port dependency for Vulkan-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "50af38b6cd43afb1462f9ad26b8d015382d11a3d", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Utility-Libraries.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan-utility-libraries/src and ORT Dawn port dependency for Vulkan-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cb0597213b0fcb999caa9ed08c2f88dc45eb7d50", + "repositoryUrl": "https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/vulkan_memory_allocator used by Vulkan-enabled Dawn builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7eda07b1e067ef3fd7eea0419c88b5af45c9a776", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/zlib" + } + }, + "comments": "runtime; Dawn DEPS third_party/zlib.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "008e4fdd7e31d9133d028659348e054d350ccc3e", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/base/allocator/partition_allocator.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/partition_alloc used by Dawn standalone builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "3e6e148537683c22e3e74977d56516f16f39c7be", + "repositoryUrl": "https://github.com/microsoft/DirectXShaderCompiler.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/dxc used when ORT builds Dawn's built DXC path for Windows WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "980971e835876dc0cde415e8f9bc646e64667bf7", + "repositoryUrl": "https://github.com/microsoft/DirectX-Headers.git" + } + }, + "comments": "runtime; Dawn DEPS third_party/dxheaders and ORT Dawn port dependency for D3D12/DXC-enabled WebGPU builds.", + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6a18683f555b4ac8b05ac8395c29c84483ac9588", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/buildtools" + } + }, + "comments": "build-tool; Dawn DEPS buildtools, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c2725e0622e1a86d55f14514f2177a39efea4a0e", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/clang/tools/clang-format.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/clang-format/script, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "425882d8c0acaab53bf2f8abbe7efcf5db5b168b", + "repositoryUrl": "https://chromium.googlesource.com/chromium/tools/depot_tools.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/depot_tools, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7ab65651aed6802d2599dcb7a73b1f82d5179d05", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxx.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libc++/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "8f11bb1d4438d0239d0dfc1bd9456a9f31629dda", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxxabi.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libc++abi/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d38523b674e26b7c8d61ed2e48d6cfe248b12da0", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libc.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/llvm-libc/src required by libc++, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "369990d9660a387f618d0eedc341eb285016243b", + "repositoryUrl": "https://chromium.googlesource.com/chromiumos/third_party/libdrm.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/libdrm/src for Linux build support, condition: dawn_standalone and host_os == \"linux\".", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4c2c31b6776c1fe03a029f66ef530796f0add90d", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/build" + } + }, + "comments": "build-tool; Dawn DEPS build, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7fd7d7092fa5ee06380f06f66f1b7bd03fca71a8", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/clang" + } + }, + "comments": "build-tool; Dawn DEPS tools/clang, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b635f27e932356a2e29450e5cfa544cdcc9ea6bb", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/memory" + } + }, + "comments": "build-tool; Dawn DEPS tools/memory, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "da34b95fdbf2032df6cda5f3828c2ba421592644", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/valgrind" + } + }, + "comments": "build-tool; Dawn DEPS tools/valgrind, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "baacfc6d5986b07abe0503216b491e234b94ba79", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/win" + } + }, + "comments": "build-tool; Dawn DEPS tools/win, condition: checkout_win and not build_with_chromium.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a975ec0340bd4b7dab6c8e43b15dbc638621a23c", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/mb" + } + }, + "comments": "build-tool; Dawn DEPS tools/mb, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4d438b31b58e2dc84b592a052b6b97e05ceb6497", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/testing" + } + }, + "comments": "test-only; Dawn DEPS testing, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "bea408a6e01f0f7e6c82a43121fe3af4506c932e", + "repositoryUrl": "https://chromium.googlesource.com/external/github.com/llvm/llvm-project/compiler-rt/lib/fuzzer.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/libFuzzer/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4fe3307fb2d9f86d19777c7eb0e4809e9694dde7", + "repositoryUrl": "https://github.com/google/googletest.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/googletest, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "59090f1f5e2b3ad9c90e4dc5fc8e79aed9110587", + "repositoryUrl": "https://chromium.googlesource.com/catapult.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/catapult, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "188e8278990a9069ffc84441cb5a024fd0bede37", + "repositoryUrl": "https://github.com/google/benchmark.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/google_benchmark/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "2e683eb7385c54f872acc47b371210d2282bc103", + "repositoryUrl": "https://gitlab.freedesktop.org/mesa/mesa.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/mesa/src, condition: dawn_standalone and checkout_mesa.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d389906a136c2aac9820ded0f38d1e25ef25fb9a", + "repositoryUrl": "https://github.com/mesonbuild/meson.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/meson/src, condition: dawn_standalone and checkout_mesa.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c3027d884967773057bf74b957e3fea87e5df4d7", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/jinja2" + } + }, + "comments": "build-tool; Dawn DEPS third_party/jinja2 for code generation, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4256084ae14175d38a3ff7d739dca83ae49ccec6", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/third_party/markupsafe" + } + }, + "comments": "build-tool; Dawn DEPS third_party/markupsafe for code generation, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b35641f4a3c62aa86a0b3c983d163bc0fe36026d", + "repositoryUrl": "https://github.com/glfw/glfw.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/glfw. ORT disables GLFW unless onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "cce16dfb64c7525c6a417f98c67423330db8f3d7", + "repositoryUrl": "https://chromium.googlesource.com/angle/angle" + } + }, + "comments": "conditional; Dawn DEPS third_party/angle. ORT disables Dawn desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b7b7fd22e5f28079b92412f47f6da4df43e4cd37", + "repositoryUrl": "https://swiftshader.googlesource.com/SwiftShader" + } + }, + "comments": "conditional; Dawn DEPS third_party/swiftshader. Not redistributed by ORT WebGPU packages.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "a26b8836968dc480ad283234823e6ffc62052489", + "repositoryUrl": "https://chromium.googlesource.com/vulkan-deps" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-deps roll metadata. Concrete Vulkan components are registered separately.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "022de31e7ffa5230068858d9e6cd85ae11170bda", + "repositoryUrl": "https://github.com/KhronosGroup/glslang.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/glslang/src. ORT disables GLSL writer/validator unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "09a024d4e422f8e603412f582d76c2051ef51cfc", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Loader.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/vulkan-loader/src and ORT Dawn port dependency. Not redistributed by current WebGPU plugin packages.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "39a19dccf79d28951516c3c7c9f1ee4a606fb733", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-Tools.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-tools/src.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "145be10eff68bf41f1b556026ecf7da9a7c8d15b", + "repositoryUrl": "https://github.com/KhronosGroup/Vulkan-ValidationLayers.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/vulkan-validation-layers/src. ORT disables Dawn SPIR-V validation.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "5bae8738b23d06968e7c3a41308568120943ae77", + "repositoryUrl": "https://github.com/KhronosGroup/OpenGL-Registry.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/khronos/OpenGL-Registry. ORT disables desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7dea2ed79187cd13f76183c4b9100159b9e3e071", + "repositoryUrl": "https://github.com/KhronosGroup/EGL-Registry.git" + } + }, + "comments": "conditional; Dawn DEPS third_party/khronos/EGL-Registry. ORT disables desktop GL/OpenGLES unless PIX support is enabled.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ], + "detectedComponentLocations": [ + "{SourceFileRoot}/cmake/external/onnxruntime_external_deps.cmake" + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "dbe37c7d554fd72651510c362cf62992e5f45e1f", + "repositoryUrl": "https://github.com/gpuweb/cts.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/webgpu-cts, condition: build_with_chromium or dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4258c35121c8d0e12f53568ffb22236d7816723", + "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/emsdk, condition: dawn_wasm.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d5cfe19da8b974ca35764dd1c73b91d57cd3c4ce", + "repositoryUrl": "https://github.com/nodejs/node-api-headers.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/node-api-headers, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "1e26dcb52829a74260ec262edb41fc22998669b6", + "repositoryUrl": "https://github.com/nodejs/node-addon-api.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/node-addon-api, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b4b5752ff755fe33bf6a67fb6e5964ba9d40dcdc", + "repositoryUrl": "https://github.com/gpuweb/gpuweb.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/gpuweb, condition: dawn_node.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "0bfcdc4f487023d85e33597de0a94fc523e30fca", + "repositoryUrl": "https://github.com/webgpu-native/webgpu-headers.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/webgpu-headers/src for testing purposes.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "3438d4183bfc7c0d6850e8b970204cc8189f0323", + "repositoryUrl": "https://chromium.googlesource.com/chromium/src/tools/protoc_wrapper" + } + }, + "comments": "build-tool; Dawn DEPS tools/protoc_wrapper, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7bf98f78a30b067e22420ff699348f084f802e12", + "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" + } + }, + "comments": "test-only; Dawn DEPS third_party/libprotobuf-mutator/src, condition: dawn_standalone.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "42e892d96e47b1f6e29844cc705e148ec4856448", + "repositoryUrl": "https://github.com/open-source-parsers/jsoncpp.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/jsoncpp, condition: dawn_tintd.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "303c526231a90049a3e384549720f3fbd453cf66", + "repositoryUrl": "https://github.com/google/langsvr.git" + } + }, + "comments": "build-tool; Dawn DEPS third_party/langsvr, condition: dawn_tintd.", + "developmentDependency": true, + "dependencyRoots": [ + { + "type": "git", + "git": { + "commitHash": "ec7b457e5bb1fcec6f59733c4f3dd84d2f885a38", + "repositoryUrl": "https://github.com/google/dawn.git" + } + } + ] + } + ] +} diff --git a/cgmanifests/webgpu/validate_webgpu_cgmanifest.py b/cgmanifests/webgpu/validate_webgpu_cgmanifest.py new file mode 100644 index 0000000000000..ed4f4b19035cc --- /dev/null +++ b/cgmanifests/webgpu/validate_webgpu_cgmanifest.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Validate WebGPU Component Governance manifest drift.""" + +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path +from typing import Any + +REPO_ROOT = Path(__file__).resolve().parents[2] +WEBGPU_CGMANIFEST = Path(__file__).resolve().with_name("cgmanifest.webgpu.json") +DEPS_TXT = REPO_ROOT / "cmake" / "deps.txt" +PLUGIN_WIN_WEBGPU_STAGE = ( + REPO_ROOT / "tools" / "ci_build" / "github" / "azure-pipelines" / "stages" / "plugin-win-webgpu-stage.yml" +) + +DAWN_REPOSITORY_URL = "https://github.com/google/dawn.git" +DXC_REPOSITORY_URL = "https://github.com/microsoft/DirectXShaderCompiler.git" + + +def _load_manifest() -> dict[str, Any]: + with WEBGPU_CGMANIFEST.open(encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + + registrations = manifest.get("registrations") + if not isinstance(registrations, list): + raise ValueError(f"{WEBGPU_CGMANIFEST} must contain a registrations array") + + return manifest + + +def _git_component(registration: dict[str, Any]) -> dict[str, str] | None: + component = registration.get("component") + if not isinstance(component, dict) or component.get("type") != "git": + return None + + git = component.get("git") + if not isinstance(git, dict): + return None + + repository_url = git.get("repositoryUrl") + commit_hash = git.get("commitHash") + if not isinstance(repository_url, str) or not isinstance(commit_hash, str): + return None + + result = {"repositoryUrl": repository_url, "commitHash": commit_hash} + tag = git.get("tag") + if isinstance(tag, str): + result["tag"] = tag + return result + + +def _registrations(manifest: dict[str, Any]) -> list[dict[str, Any]]: + return manifest["registrations"] + + +def _find_git_registration(manifest: dict[str, Any], repository_url: str, *, tag: str | None = None) -> dict[str, Any]: + matches = [] + for registration in _registrations(manifest): + git = _git_component(registration) + if git is None or git["repositoryUrl"] != repository_url: + continue + if tag is not None and git.get("tag") != tag: + continue + matches.append(registration) + + if len(matches) != 1: + suffix = f" with tag {tag}" if tag is not None else "" + raise ValueError(f"expected exactly one registration for {repository_url}{suffix}, found {len(matches)}") + return matches[0] + + +def _dawn_commit_from_deps_txt() -> str: + deps_text = DEPS_TXT.read_text(encoding="utf-8") + match = re.search(r"^dawn;https://github\.com/google/dawn/archive/([0-9a-f]{40})\.zip;", deps_text, re.MULTILINE) + if not match: + raise ValueError(f"could not find Dawn commit in {DEPS_TXT}") + return match.group(1) + + +def _dxc_release_from_pipeline() -> tuple[str, str, str]: + pipeline_text = PLUGIN_WIN_WEBGPU_STAGE.read_text(encoding="utf-8") + url_match = re.search(r'\$dxcZipUrl = "([^"]+)"', pipeline_text) + hash_match = re.search(r'\$expectedHash = "([0-9A-Fa-f]+)"', pipeline_text) + if not url_match or not hash_match: + raise ValueError(f"could not find DXC release URL/hash in {PLUGIN_WIN_WEBGPU_STAGE}") + + tag_match = re.search(r"/download/(v[^/]+)/", url_match.group(1)) + if not tag_match: + raise ValueError(f"could not find DXC release tag in {url_match.group(1)}") + + return tag_match.group(1), url_match.group(1), hash_match.group(1).upper() + + +def _validate_dawn_root(manifest: dict[str, Any]) -> None: + registration = _find_git_registration(manifest, DAWN_REPOSITORY_URL) + git = _git_component(registration) + if git is None: + raise ValueError("Dawn registration must be a git component") + + expected_commit = _dawn_commit_from_deps_txt() + if git["commitHash"] != expected_commit: + raise ValueError(f"Dawn manifest commit {git['commitHash']} does not match {DEPS_TXT} commit {expected_commit}") + + +def _validate_dxc_release(manifest: dict[str, Any]) -> None: + expected_tag, expected_url, expected_hash = _dxc_release_from_pipeline() + registration = _find_git_registration(manifest, DXC_REPOSITORY_URL, tag=expected_tag) + git = _git_component(registration) + if git is None: + raise ValueError(f"DXC {expected_tag} registration must be a git component") + + comments = registration.get("comments", "") + if expected_url not in comments or expected_hash not in comments: + raise ValueError( + f"DXC {expected_tag} registration comments must contain pipeline URL {expected_url} " + f"and SHA256 {expected_hash}" + ) + + +def _validate_dawn_dependency_roots(manifest: dict[str, Any]) -> None: + dawn_commit = _dawn_commit_from_deps_txt() + + for registration in _registrations(manifest): + comments = registration.get("comments", "") + if not isinstance(comments, str) or "Dawn DEPS" not in comments: + continue + + dependency_roots = registration.get("dependencyRoots") + if not isinstance(dependency_roots, list) or len(dependency_roots) != 1: + raise ValueError(f"Dawn-derived registration is missing one dependencyRoots entry: {comments}") + + root = dependency_roots[0] + if not isinstance(root, dict): + raise ValueError(f"Dawn dependency root must be an object: {comments}") + + root_git = root.get("git") + if root.get("type") != "git" or not isinstance(root_git, dict): + raise ValueError(f"Dawn dependency root must be a git component: {comments}") + if root_git.get("repositoryUrl") != DAWN_REPOSITORY_URL or root_git.get("commitHash") != dawn_commit: + raise ValueError(f"Dawn dependency root does not match {DAWN_REPOSITORY_URL}@{dawn_commit}: {comments}") + + +def main() -> int: + try: + manifest = _load_manifest() + _validate_dawn_root(manifest) + _validate_dxc_release(manifest) + _validate_dawn_dependency_roots(manifest) + except (OSError, ValueError) as ex: + print(f"ERROR: {ex}", file=sys.stderr) + return 1 + + print(f"Validated {WEBGPU_CGMANIFEST}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1ac1d52231577..f1126c2dce79e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -90,6 +90,7 @@ option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF) option(onnxruntime_USE_RVV "Build with RISC-V Vector support in MLAS" OFF) +option(onnxruntime_USE_RVV_ZVFH "Build with RISC-V Zvfh (FP16 vector) support in MLAS" OFF) option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7ae18db235ccb..8c7df780735f1 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -57,6 +57,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/layernorm.cpp ${MLAS_SRC_DIR}/rotary_embedding.h ${MLAS_SRC_DIR}/rotary_embedding.cpp ${MLAS_SRC_DIR}/softmax.h @@ -959,6 +960,8 @@ endif() ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp ) list(REMOVE_ITEM mlas_platform_srcs "${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp") @@ -968,8 +971,22 @@ endif() ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp ${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d") list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1) + + if(onnxruntime_USE_RVV_ZVFH) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp + ) + set_source_files_properties( + ${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp + PROPERTIES COMPILE_FLAGS "-march=rv64gcv_zvfh -mabi=lp64d") + list(APPEND mlas_private_compile_definitions MLAS_USE_RVV_ZVFH=1) + endif() else() message( WARNING diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f3c2d8b947968..b28c35fd502ed 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -93,6 +93,17 @@ endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) + elseif(onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL) + # The ONNX domain CUDA Attention kernel (core/providers/cuda/llm/attention.cc) depends on + # attention infrastructure in contrib_ops/cuda/bert/ (flash attention, memory efficient + # attention, unfused attention helpers, etc.). Include the bert attention infrastructure + # even when contrib ops are disabled so that the ONNX Attention kernel can compile and link. + set(onnxruntime_cuda_bert_cc_srcs ${onnxruntime_cuda_contrib_ops_cc_srcs}) + list(FILTER onnxruntime_cuda_bert_cc_srcs INCLUDE REGEX ".*/contrib_ops/cuda/bert/.*") + set(onnxruntime_cuda_bert_cu_srcs ${onnxruntime_cuda_contrib_ops_cu_srcs}) + list(FILTER onnxruntime_cuda_bert_cu_srcs INCLUDE REGEX ".*/contrib_ops/cuda/bert/.*") + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_bert_cc_srcs} ${onnxruntime_cuda_bert_cu_srcs}) + list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_bert_cc_srcs} ${onnxruntime_cuda_bert_cu_srcs}) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a061858fa068f..9f9356d4dff8d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1450,6 +1450,50 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) target_compile_definitions(onnxruntime_mlas_softmax_riscv_compare PRIVATE ${mlas_private_compile_definitions}) set_target_properties(onnxruntime_mlas_softmax_riscv_compare PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_halfgemm_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/halfgemm_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_halfgemm_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_halfgemm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_cast_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/cast_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_cast_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_cast_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_cast_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_cast_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_rope_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/rope_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_rope_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_rope_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_rope_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_rope_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_rmsnorm_rvv_bench + ${MLAS_RISCV64_BENCH_DIR}/rmsnorm_rvv_bench.cpp) + target_include_directories(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE + ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib) + target_link_libraries( + onnxruntime_mlas_rmsnorm_rvv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_rmsnorm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest") endif() if(WIN32) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 5ea4261840299..cc83b7bca50c5 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -10,6 +10,10 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Key for the execution provider OS driver version. +// Value should be a 4-part dot-separated version string in the format "a.b.c.d" (e.g., "31.0.101.4502"). +// This maps to the Windows DXCore adapter property DXCoreAdapterProperty::DriverVersion +// (https://learn.microsoft.com/en-us/windows/win32/api/dxcore_interface/ne-dxcore_interface-dxcoreadapterproperty). +// On non-Windows platforms, the EP should provide an equivalent OS-level driver version if available. static const char* const kOrtEpDevice_EpMetadataKey_OSDriverVersion = "os_driver_version"; // Prefix for execution provider compatibility information stored in model metadata. diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index ea0049c28e31b..4df5f6a349599 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -194,6 +194,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (do_rotary_) { // When kv_sequence_length == 0 (shared KV), only Q needs RoPE — K is skipped below. ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true"); + // Validation of seqlens_k against rotary cache size is performed in CheckInputs() + // when seqlens_k is on CPU. GPU EPs where seqlens_k resides on device rely on + // RunRotaryEmbedding's position_ids validation for OOB protection. + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index ed910e3510fed..3429ca5f5be52 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -310,6 +310,22 @@ Status CheckInputs(const T* query, int rotary_dim = 0; if (cos_cache != nullptr && sin_cache != nullptr) { ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim)); + + // Validate seqlens_k against rotary cache size when rotary embeddings are enabled. + // This prevents OOB access when deriving position IDs from seqlens_k during rotary embedding. + const bool is_seqlens_k_on_cpu = (seqlens_k->Location().device.Type() == OrtDevice::CPU); + if (is_seqlens_k_on_cpu) { + const int64_t rotary_cache_max_seq = std::min(cos_cache->Shape().GetDims()[0], + sin_cache->Shape().GetDims()[0]); + const int32_t* seqlens_k_data = seqlens_k->template Data(); + for (int b = 0; b < batch_size; b++) { + if (seqlens_k_data[b] < 0 || static_cast(seqlens_k_data[b]) >= rotary_cache_max_seq) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k[", b, "] = ", seqlens_k_data[b], + " is out of range for rotary cache dimension 0 (", rotary_cache_max_seq, ")"); + } + } + } } else if (cos_cache != nullptr || sin_cache != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index f910abb538821..d051c5423c367 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -80,6 +80,12 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, if (y->Shape().Size() == 0) return Status::OK(); + if (bias_tensor != nullptr) { + ORT_RETURN_IF_NOT(bias_tensor->Shape().Size() == static_cast(helper.N()), + "bias tensor's element count must equal B's last dimension (", + helper.N(), "), but got ", bias_tensor->Shape().Size()); + } + auto* y_data = y->MutableData(); const auto* bias_data = bias_tensor != nullptr ? bias_tensor->Data() : nullptr; @@ -306,8 +312,12 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { // This evaluates to true if bias data was not provided as constant data for prepacking stage if (!dynamic_quant_mlas_bias_data_was_packed_) { if (ctx->Input(IN_BIAS) != nullptr) { - const auto biases = std::vector(&ctx->Input(IN_BIAS)->Data()[0], - &ctx->Input(IN_BIAS)->Data()[gemm_shape.N]); + const Tensor* bias_t = ctx->Input(IN_BIAS); + ORT_RETURN_IF_NOT(bias_t->Shape().Size() == static_cast(gemm_shape.N), + "bias tensor's element count must equal B's last dimension (", + gemm_shape.N, "), but got ", bias_t->Shape().Size()); + const auto biases = std::vector(&bias_t->Data()[0], + &bias_t->Data()[gemm_shape.N]); // deferred adding of bias for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index 4c0c86aa60729..22cdcb75de126 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -202,6 +202,11 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(ValidateInputShape(sequence_shape, w_conv_shape, w_char_embedding_shape)); + const TensorShape& b_conv_shape = b_conv.Shape(); + ORT_RETURN_IF_NOT(b_conv_shape.NumDimensions() == 1 && b_conv_shape[0] == w_conv_shape[0], + "WordConvEmbedding: conv bias B must be a 1-D tensor of length ", + w_conv_shape[0], ", but got shape ", b_conv_shape); + int64_t seq_len = sequence_shape[0]; int64_t word_len = sequence_shape[1]; int64_t char_embedding_size = w_char_embedding_shape[1]; diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp index cd5c71f83ac27..8ba877aa21a68 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp @@ -171,10 +171,10 @@ class EpilogueMoeFusedFinalize { auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); + constexpr auto mma_tile_m = decltype(tile_size<0>(tiled_mma)){}; + constexpr auto mma_tile_n = decltype(tile_size<1>(tiled_mma)){}; + constexpr auto epi_tile_m = size<0>(EpilogueTile{}); + constexpr auto epi_tile_n = size<1>(EpilogueTile{}); CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); @@ -248,16 +248,17 @@ class EpilogueMoeFusedFinalize { Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) // Make a tiled copy vectorized along major direction of D + constexpr int TiledMmaThreads = decltype(cute::size(tiled_mma))::value; auto tiled_s2r = [&]() { if constexpr (cutlass::gemm::detail::is_k_major()) { constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor; return make_tiled_copy(CopyAtomS2R{}, Layout, Int>, Stride, _1>>{}, Layout>>{}); } else if constexpr (cutlass::gemm::detail::is_mn_major()) { constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor; return make_tiled_copy(CopyAtomS2R{}, Layout, Int>, Stride<_1, Int>>{}, Layout, _1>>{}); @@ -274,11 +275,11 @@ class EpilogueMoeFusedFinalize { Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rBias = make_tensor(shape(tSR_gBias(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rScale = make_tensor(shape(tSR_gScale(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N) // Make an identity coordinate tensor for predicating our output MN tile Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index e28d2b859a2f0..ab8ae054db048 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -586,8 +586,12 @@ struct MoeFCGemm { run_kernel(params, shared_storage); } #else - static_assert( - false, "Invalid architecture being compiled. Only Ampere+ supported in weight-only quantization kernels."); + // Pre-Ampere device compile pass: the MoeFCGemm body is unsupported on these archs, + // but NVCC must still emit *some* body for each requested target. Runtime dispatch + // in MoeGemmRunner::dispatchToArch() never invokes this kernel when sm_ < 80, so a + // device-side trap is safe and lets the same .cu compile cleanly under mixed arch + // lists (e.g. 52;61;75;86;89;90 in packaging pipelines). + CUTLASS_NOT_IMPLEMENTED(); #endif #else CUTLASS_NOT_IMPLEMENTED(); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index 46a6bc6388a27..19bb1a0975720 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -77,7 +77,9 @@ ReturnType construct_if_true(Args&&... args) { if constexpr (FLAG) { - return ReturnType{std::forward(args)...}; + // Use parenthesized aggregate init (C++20) instead of brace-init to avoid + // MSVC C2397 narrowing conversion errors (e.g. size_t -> FastDivmod(int)). + return ReturnType(std::forward(args)...); } else { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index 5979f17e5abcc..f6bf5bbb1f0e3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -1113,9 +1113,11 @@ void QMoE::PrePackSwizzleBlockScales(const Tensor& tensor, cudaStream_t stream, p_src = temp_src_gpu.get(); } + // QMoEBlockScaleInterleaveKernel writes every byte of the output buffer + // (the (batch, row, col) -> offset map is a bijection over + // [0, batch_size) x [0, rows_padded) x [0, cols_padded), and padded + // source positions are written as 0), so no explicit memset is required. packed_buf = IAllocator::MakeUniquePtr(alloc, dst_bytes, true); - // Zero-fill for padding regions (kernel only writes within bounds) - CUDA_CALL_THROW(cudaMemsetAsync(packed_buf.get(), 0, dst_bytes, stream)); int multi_processor_count = 0; int device_id = 0; @@ -1250,16 +1252,23 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat return; } - bool is_fp16 = is_fp16_; - bool is_bf16 = !is_fp16_; - ORT_ENFORCE(shape.NumDimensions() == 3, "Expected 3D zeros for block-wise 4-bit"); + ORT_ENFORCE(shape[0] > 0 && shape[1] > 0 && shape[2] > 0, + "4-bit block-wise zeros must have positive dimensions, got ", shape.ToString()); + // packed_k_blocks is doubled to k_blocks below; constrain it to half of INT_MAX to keep the + // doubled value (and the int dims passed into LaunchQMoEScaledZP4BitBatched) within int range. + constexpr int64_t kMaxPackedKBlocks = std::numeric_limits::max() / 2; + ORT_ENFORCE(shape[0] <= std::numeric_limits::max() && + shape[1] <= std::numeric_limits::max() && + shape[2] <= kMaxPackedKBlocks, + "4-bit block-wise zeros dimensions exceed CUDA launch int range, got ", shape.ToString()); const int experts = static_cast(shape[0]); const int n = static_cast(shape[1]); const int packed_k_blocks = static_cast(shape[2]); const int k_blocks = packed_k_blocks * 2; + // QMoE only supports FP16/BF16 inputs (is_fp16_ is set in the ctor), both of which are 2 bytes. size_t output_count = static_cast(experts) * static_cast(k_blocks) * static_cast(n); - size_t bytes = output_count * (is_fp16 || is_bf16 ? 2 : 4); + size_t bytes = output_count * sizeof(uint16_t); packed_bias = IAllocator::MakeUniquePtr(alloc, bytes, true); const void* p_src_zp = tensor.DataRaw(); @@ -1272,20 +1281,18 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat const uint8_t* zp_ptr = static_cast(p_src_zp); constexpr float kDefaultZeroPoint4Bit = 8.0f; - if (is_fp16) { + if (is_fp16_) { LaunchQMoEScaledZP4BitBatched( zp_ptr, static_cast(packed_scale.get()), static_cast(packed_bias.get()), experts, n, k_blocks, kDefaultZeroPoint4Bit, stream); - } else if (is_bf16) { + } else { LaunchQMoEScaledZP4BitBatched( zp_ptr, static_cast(packed_scale.get()), static_cast<__nv_bfloat16*>(packed_bias.get()), experts, n, k_blocks, kDefaultZeroPoint4Bit, stream); - } else { - ORT_THROW("Unsupported type for 4-bit block-wise ZP prepack. Expected FP16/BF16."); } } CUDA_CALL_THROW(cudaStreamSynchronize(stream)); diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu index cd59fd248b3a2..28fd4fb1516fb 100644 --- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu @@ -3,25 +3,20 @@ // Licensed under the MIT License. #include "contrib_ops/cuda/moe/qmoe_kernels.h" +#include "core/common/narrow.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h" #include #include #include #include -#include namespace onnxruntime { namespace contrib { namespace cuda { int Compute1DGridSize(int num_elements, int block_size) { - ORT_ENFORCE(num_elements >= 0, "CUDA launch element count must be non-negative, got ", num_elements); - ORT_ENFORCE(block_size > 0, "CUDA launch block size must be positive, got ", block_size); - int64_t grid_size = (static_cast(num_elements) + block_size - 1) / block_size; - ORT_ENFORCE(grid_size <= std::numeric_limits::max(), - "CUDA launch grid size exceeds int range: ", grid_size); - return static_cast(grid_size); + return (num_elements + block_size - 1) / block_size; } template @@ -698,11 +693,7 @@ void LaunchQMoEDequantizeFp4WeightsImpl( cudaStream_t stream) { int64_t total = static_cast(num_experts) * n * k; constexpr int block = 256; - ORT_ENFORCE(total >= 0, "QMoEDequantizeFp4Weights: negative element count, got ", total); - int64_t grid_i64 = (total + block - 1) / block; - ORT_ENFORCE(grid_i64 <= std::numeric_limits::max(), - "QMoEDequantizeFp4Weights: grid size exceeds int range: ", grid_i64); - int grid = static_cast(grid_i64); + int grid = onnxruntime::narrow((total + block - 1) / block); QMoEDequantizeFp4WeightsKernel<<>>( packed_weights, block_scales, global_scales, output, num_experts, n, k); } @@ -785,11 +776,7 @@ void LaunchQMoEDequantizeFp8WeightsImpl( cudaStream_t stream) { int64_t total = static_cast(num_experts) * n * k; constexpr int block = 256; - ORT_ENFORCE(total >= 0, "QMoEDequantizeFp8Weights: negative element count, got ", total); - int64_t grid_i64 = (total + block - 1) / block; - ORT_ENFORCE(grid_i64 <= std::numeric_limits::max(), - "QMoEDequantizeFp8Weights: grid size exceeds int range: ", grid_i64); - int grid = static_cast(grid_i64); + int grid = onnxruntime::narrow((total + block - 1) / block); QMoEDequantizeFp8WeightsKernel<<>>( weights, global_scales, output, num_experts, n, k); } @@ -862,16 +849,10 @@ void LaunchQMoERepackFP4ColToRow( int64_t k, int64_t n, cudaStream_t stream) { - ORT_ENFORCE(experts > 0, "LaunchQMoERepackFP4ColToRow requires positive expert count, got ", experts); - ORT_ENFORCE(k > 0 && n > 0, "LaunchQMoERepackFP4ColToRow requires positive k and n, got k=", k, ", n=", n); - ORT_ENFORCE(k % 2 == 0 && n % 2 == 0, - "LaunchQMoERepackFP4ColToRow requires even k and n, got k=", k, ", n=", n); const int64_t total = static_cast(experts) * n * (k / 2); constexpr int kThreads = 256; - int64_t blocks = (total + kThreads - 1) / kThreads; - ORT_ENFORCE(blocks <= static_cast(std::numeric_limits::max()), - "LaunchQMoERepackFP4ColToRow grid size exceeds int range"); - QMoERepackFP4ColToRowKernel<<(blocks), kThreads, 0, stream>>>( + int blocks = onnxruntime::narrow((total + kThreads - 1) / kThreads); + QMoERepackFP4ColToRowKernel<<>>( input, output, experts, k, n); } @@ -901,10 +882,7 @@ __global__ void BatchedTransposeKernel(const T* __restrict__ input, T* __restric void LaunchBatchedTranspose(cudaStream_t stream, const void* input, void* output, int batch, int rows, int cols, int element_size) { int64_t total_elements = static_cast(batch) * rows * cols; int threads = 256; - int64_t blocks_i64 = (total_elements + threads - 1) / threads; - ORT_ENFORCE(blocks_i64 <= std::numeric_limits::max(), - "LaunchBatchedTranspose grid size exceeds int range: ", blocks_i64); - int blocks = static_cast(blocks_i64); + int blocks = onnxruntime::narrow((total_elements + threads - 1) / threads); if (element_size == 1) { BatchedTransposeKernel<<>>(static_cast(input), static_cast(output), batch, rows, cols); diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index f4935aaeb6b74..5a0d1e4841f05 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -202,16 +202,6 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { TensorShapeVector state_shape({batch_size, kv_num_heads_, head_dim_k, head_dim_v}); Tensor* present_state = context.Output(1, state_shape); - // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values - // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, - // reduces shared memory access overhead, and enables coalesced memory reads/writes. - const int components = (head_dim_v % 4 == 0 && head_dim_v >= 4) ? 4 : 1; - int tile_v = (components == 4) ? 1 : 4; - if (components == 1 && head_dim_v <= 4) { - tile_v = onnxruntime::narrow(head_dim_v); - } - const int head_dim_v_vectorized = onnxruntime::narrow(head_dim_v) / components; - constexpr uint32_t kMaxSupportedWorkgroupSize = 256; ORT_RETURN_IF_NOT(head_dim_k <= static_cast(kMaxSupportedWorkgroupSize), "LinearAttention WebGPU kernel requires head_dim_k <= ", @@ -225,6 +215,31 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { // Cap at GPU limits workgroup_size = std::min(workgroup_size, kMaxSupportedWorkgroupSize); + // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values + // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, + // reduces shared memory access overhead, and enables coalesced memory reads/writes. + // TODO: support components == 2 (vec2) for head_dim_v divisible by 2 but not 4. + const int components = (head_dim_v % 4 == 0) ? 4 : 1; + int tile_v = (components == 4) ? 1 : std::min(4, onnxruntime::narrow(head_dim_v)); + + // subgroup_min_size > 0 enables subgroup-based reduction; 0 falls back to barrier-tree. + int subgroup_min_size = context.HasFeature(wgpu::FeatureName::Subgroups) + ? static_cast(context.AdapterInfo().subgroupMinSize) + : 0; + // When subgroup is enabled, use larger tile_v for better data reuse. + // Only expand for longer sequences (>=16) where the benefit outweighs the + // increased register pressure and shared memory usage. + if (subgroup_min_size > 0 && seq_length >= 16) { + // Ensure the vectorized dimension is wide enough to warrant a larger tile. + if (head_dim_v / components >= tile_v * 4) { + tile_v *= 4; + } + } + // Clamp to workgroup_size since the shader assigns one thread per tile_v + // column (threads with dk_idx >= TILE_V are idle for output/state writes). + tile_v = std::min(tile_v, static_cast(workgroup_size)); + + const int head_dim_v_vectorized = onnxruntime::narrow(head_dim_v) / components; const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; const uint32_t num_workgroups = onnxruntime::narrow(batch_size * kv_num_heads_ * num_dv_tiles); @@ -243,11 +258,6 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { } } - // subgroup_min_size > 0 enables subgroup-based reduction; 0 falls back to barrier-tree. - int subgroup_min_size = context.HasFeature(wgpu::FeatureName::Subgroups) - ? static_cast(context.AdapterInfo().subgroupMinSize) - : 0; - LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components, subgroup_min_size}; program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template index 941793ecc7e79..206d9bb6eb20f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -90,9 +90,6 @@ $MAIN { // Initialize state tile in private memory var state: array; - for (var j = 0u; j < TILE_V; j++) { - state[j] = vtype(0.0); - } // Load initial state if provided #if has_initial_state diff --git a/onnxruntime/core/common/cpuid_arch_definition.h b/onnxruntime/core/common/cpuid_arch_definition.h index 5946b8ca27067..973c50b5dda38 100644 --- a/onnxruntime/core/common/cpuid_arch_definition.h +++ b/onnxruntime/core/common/cpuid_arch_definition.h @@ -12,3 +12,7 @@ #if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 + +#if defined(__riscv) && __riscv_xlen == 64 +#define CPUIDINFO_ARCH_RISCV64 +#endif diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 5990013c925c5..96dc427ad766c 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -47,6 +47,16 @@ #endif // ARM +#if defined(CPUIDINFO_ARCH_RISCV64) +#include +#ifndef RISCV_HWPROBE_EXT_ZVFH +#define RISCV_HWPROBE_EXT_ZVFH (1 << 30) +#endif +#ifndef RISCV_HWPROBE_IMA_V +#define RISCV_HWPROBE_IMA_V (1 << 2) +#endif +#endif // RISCV64 + #endif // Linux #if _WIN32 @@ -334,6 +344,17 @@ void CPUIDInfo::ArmAppleInit() { #endif // defined(CPUIDINFO_ARCH_ARM) +#if defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__) +void CPUIDInfo::RiscvLinuxInit() { + struct riscv_hwprobe pairs[] = { + {RISCV_HWPROBE_KEY_IMA_EXT_0, 0}, + }; + if (syscall(__NR_riscv_hwprobe, pairs, 1, 0, nullptr, 0) == 0) { + has_fp16_ = (pairs[0].value & RISCV_HWPROBE_EXT_ZVFH) != 0; + } +} +#endif // defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__) + uint32_t CPUIDInfo::GetCurrentCoreIdx() const { #ifdef _WIN32 return GetCurrentProcessorNumber(); @@ -377,5 +398,11 @@ CPUIDInfo::CPUIDInfo() { ArmAppleInit(); #endif #endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUIDINFO_ARCH_RISCV64) +#if defined(__linux__) + RiscvLinuxInit(); +#endif +#endif // defined(CPUIDINFO_ARCH_RISCV64) } } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index be301019df5c0..bf502c645c9eb 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -135,6 +135,12 @@ class CPUIDInfo { #endif // defined(CPUIDINFO_ARCH_ARM) +#if defined(CPUIDINFO_ARCH_RISCV64) +#if defined(__linux__) + void RiscvLinuxInit(); +#endif +#endif // defined(CPUIDINFO_ARCH_RISCV64) + #if defined(CPUINFO_SUPPORTED) bool pytorch_cpuinfo_init_{false}; #endif // defined(CPUINFO_SUPPORTED) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ddb9daa5e244b..99b72dc756663 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1665,6 +1665,27 @@ MlasRotaryEmbedOneRow( T* output ); +/** + * @brief Compute LayerNorm or RMSNorm (simplified) for one row of float data. + * Uses platform-optimized kernel if available, otherwise returns false. + * Any platform (AMD64/ARM64/RISC-V) can register a LayerNormF32Kernel. + * + * @return true if an optimized kernel was used, false if caller should fall back + */ +bool +MLASCALL +MlasLayerNormF32( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +); + /** * @brief Supply matrices data information to half precision gemm functions */ diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 66a335665d024..05cde92d9f9d7 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -27,6 +27,8 @@ MlasFp16AccelerationSupported() { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); +#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); #else return false; #endif diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 529db48f58e6f..3f63e00f05f12 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -503,12 +503,21 @@ extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; #endif +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchRvv; +#endif + MLAS_FORCEINLINE const MLAS_HALFGEMM_DISPATCH* MlasHalfGemmGetDispatch() { #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) return &MlasHalfGemmDispatchNeon; +#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration()) { + return &MlasHalfGemmDispatchRvv; + } + return &MlasHalfGemmDispatchDefault; #else return &MlasHalfGemmDispatchDefault; #endif diff --git a/onnxruntime/core/mlas/lib/layernorm.cpp b/onnxruntime/core/mlas/lib/layernorm.cpp new file mode 100644 index 0000000000000..34258436d60a0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/layernorm.cpp @@ -0,0 +1,41 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + layernorm.cpp + +Abstract: + + This module implements the dispatch for platform-optimized + LayerNorm/RMSNorm kernels. + +--*/ + +#include "mlasi.h" + +bool + MLASCALL + MlasLayerNormF32( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified + ) +{ + auto kernel = GetMlasPlatform().LayerNormF32Kernel; + if (kernel == nullptr) { + return false; + } + + kernel(Input, Scale, Bias, Output, MeanOut, InvStdDevOut, NormSize, Epsilon, Simplified); + return true; +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index dbb414505ff38..bf4f3f6e2de2d 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -691,6 +691,18 @@ typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( size_t Count ); +typedef void(MLASCALL MLAS_LAYERNORM_F32_KERNEL)( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -1230,6 +1242,15 @@ extern "C" { MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; #endif + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelRvv; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelRvv; +#endif + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) + MLAS_LAYERNORM_F32_KERNEL MlasLayerNormKernelRvv; +#endif } // @@ -1388,6 +1409,10 @@ struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchRvv; +#endif + // // half gemm dispatch structure // @@ -1631,6 +1656,7 @@ MLAS_COMPUTE_TANH_FP16_KERNEL* TanhFP16KernelRoutine = nullptr; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + MLAS_LAYERNORM_F32_KERNEL* LayerNormF32Kernel{nullptr}; const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 466fa9a3e9497..6eb53684065a4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -292,6 +292,15 @@ Return Value: this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelRvv; this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelRvv; this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelRvv; + this->RopeDispatch = &MlasRopeDispatchRvv; + this->LayerNormF32Kernel = &MlasLayerNormKernelRvv; + +#if defined(MLAS_USE_RVV_ZVFH) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration()) { + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelRvv; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelRvv; + } +#endif // NCHWc kernels require VLEN>=128 so that vfloat32m4_t holds 16 floats. if (__riscv_vlenb() >= 16) { diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index d3681ff6bfdff..8bec2d350afa5 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -42,16 +42,16 @@ DequantInt4x8(const uint8_t* src, size_t col, bool per_channel, const float* sca // Load 4 packed bytes safely without strict-aliasing / alignment UB. // Compilers optimize memcpy of 4 bytes to a single mov instruction. - int raw_bytes; + uint32_t raw_bytes; std::memcpy(&raw_bytes, base, sizeof(raw_bytes)); - __m128i packed = _mm_cvtsi32_si128(raw_bytes); + __m128i packed = _mm_cvtsi32_si128(static_cast(raw_bytes)); // Low nibbles (even columns): AND with 0x0F __m128i lo_mask = _mm_set1_epi8(0x0F); __m128i lo = _mm_and_si128(packed, lo_mask); - // High nibbles (odd columns): shift right 4 using 32-bit granularity - // to prevent bit bleeding across 16-bit lane boundaries, then mask. + // High nibbles (odd columns): shift right by 4 within 32-bit lanes, then mask. + // Any cross-byte bits from the shift land in the upper nibble and are discarded by the mask. __m128i hi = _mm_and_si128(_mm_srli_epi32(packed, 4), lo_mask); // Interleave low and high nibbles: [lo0,hi0, lo1,hi1, lo2,hi2, lo3,hi3] @@ -126,19 +126,19 @@ FusedDotInt8( acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } } else { - __m256 scale_vec = _mm256_broadcast_ss(scales); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 8 elements in the hot loop. for (; k < vec_end; k += 16) { __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadl_epi64(reinterpret_cast(b_row + k + 8)); __m256i i32_1 = _mm256_cvtepi8_epi32(raw1); __m256 bf1 = _mm256_cvtepi32_ps(i32_1); - bf1 = _mm256_mul_ps(bf1, scale_vec); __m256 a1 = _mm256_loadu_ps(a_row + k + 8); acc1 = _mm256_fmadd_ps(a1, bf1, acc1); } @@ -146,7 +146,6 @@ FusedDotInt8( __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } @@ -161,9 +160,15 @@ FusedDotInt8( float dot = _mm_cvtss_f32(sum4); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } @@ -326,7 +331,7 @@ SVGemm_Avx2( } } } else { - __m256 scale_vec = _mm256_broadcast_ss(Scales); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -337,15 +342,25 @@ SVGemm_Avx2( __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); __m256i i32 = _mm256_cvtepi8_epi32(raw); __m256 bf = _mm256_cvtepi32_ps(i32); - bf = _mm256_mul_ps(bf, scale_vec); __m256 c_vec = _mm256_loadu_ps(c_row + n); c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); _mm256_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m256 scale_vec = _mm256_broadcast_ss(Scales); + n = 0; + for (; n < vec_end_n; n += 8) { + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_mul_ps(c_vec, scale_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 fused path diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index ac23a0703ddff..fa5aff0165897 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -83,19 +83,21 @@ QuantizeRowToU8(const float* src, uint8_t* dst, size_t len) i = 0; for (; i < vec_end; i += 16) { __m512 v = _mm512_loadu_ps(src + i); - // q = round(v * inv_scale) + 128, clamped to [0, 255] + // q = (v * inv_scale) + 128, clamped to [0, 255] __m512 scaled = _mm512_fmadd_ps(v, inv_scale_vec, zp_vec); - scaled = _mm512_roundscale_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); scaled = _mm512_max_ps(scaled, min_val); scaled = _mm512_min_ps(scaled, max_clamp); - __m512i qi = _mm512_cvtps_epi32(scaled); + // Round-to-nearest-even and convert to int32 in a single instruction + // (AVX-512 embedded rounding eliminates a separate vrndscaleps). + __m512i qi = _mm512_cvt_roundps_epi32(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); // Pack 16 int32 -> 16 uint8 __m128i packed = _mm512_cvtepi32_epi8(qi); _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), packed); } - // Scalar tail + // Scalar tail (use nearbyintf for round-to-nearest-even, matching the + // AVX-512 embedded rounding semantics above). for (; i < len; ++i) { - float q = std::round(src[i] * inv_scale) + 128.0f; + float q = std::nearbyintf(src[i] * inv_scale) + 128.0f; q = std::max(0.0f, std::min(255.0f, q)); dst[i] = static_cast(q); } @@ -169,9 +171,11 @@ VnniDotInt8PerTensor( // Correction: dpbusd computed sum(a_u8 * b_s8). // We want sum((a_u8 - 128) * b_s8) = sum(a_u8 * b_s8) - 128 * sum(b_s8) - float corrected = static_cast(dot_i32) - 128.0f * static_cast(b_sum_i32); + // Perform correction in int32 to preserve precision (avoids float rounding + // when |dot_i32| or |128*b_sum_i32| exceed 2^24). + int32_t corrected = dot_i32 - (128 * b_sum_i32); - return corrected * scale_a * scale_b; + return static_cast(corrected) * scale_a * scale_b; } // @@ -221,19 +225,19 @@ FusedDotInt8_Avx512( acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } } else { - __m512 scale_vec = _mm512_set1_ps(scales[0]); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 16 elements in the hot loop. for (; k < vec_end; k += 32) { __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadu_si128(reinterpret_cast(b_row + k + 16)); __m512i i32_1 = _mm512_cvtepi8_epi32(raw1); __m512 bf1 = _mm512_cvtepi32_ps(i32_1); - bf1 = _mm512_mul_ps(bf1, scale_vec); __m512 a1 = _mm512_loadu_ps(a_row + k + 16); acc1 = _mm512_fmadd_ps(a1, bf1, acc1); } @@ -241,7 +245,6 @@ FusedDotInt8_Avx512( __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } @@ -251,9 +254,15 @@ FusedDotInt8_Avx512( float dot = _mm512_reduce_add_ps(acc0); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } @@ -402,11 +411,11 @@ VnniMultiDot4Int8PerTensor( bs[3] += static_cast(b3[k]); } - const float zp = 128.0f; - out[0] = (static_cast(dot[0]) - zp * static_cast(bs[0])) * combined_scale; - out[1] = (static_cast(dot[1]) - zp * static_cast(bs[1])) * combined_scale; - out[2] = (static_cast(dot[2]) - zp * static_cast(bs[2])) * combined_scale; - out[3] = (static_cast(dot[3]) - zp * static_cast(bs[3])) * combined_scale; + // Zero-point correction in int32 for precision (see VnniDotInt8PerTensor). + out[0] = static_cast(dot[0] - 128 * bs[0]) * combined_scale; + out[1] = static_cast(dot[1] - 128 * bs[1]) * combined_scale; + out[2] = static_cast(dot[2] - 128 * bs[2]) * combined_scale; + out[3] = static_cast(dot[3] - 128 * bs[3]) * combined_scale; } // ============================================================================ @@ -569,7 +578,7 @@ SVGemm_Avx512Vnni( } } } else { - __m512 scale_vec = _mm512_set1_ps(Scales[0]); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -580,15 +589,25 @@ SVGemm_Avx512Vnni( __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); __m512i i32 = _mm512_cvtepi8_epi32(raw); __m512 bf = _mm512_cvtepi32_ps(i32); - bf = _mm512_mul_ps(bf, scale_vec); __m512 c_vec = _mm512_loadu_ps(c_row + n); c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); _mm512_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m512 scale_vec = _mm512_set1_ps(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 16) { + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_mul_ps(c_vec, scale_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 path: 512-bit wide diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp index ae5a56028bbf9..1aabbd8ca39cb 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp @@ -29,12 +29,13 @@ using namespace MlasKVQuantInternal; namespace { // -// Dequantize 8 INT8 values starting at `col` and scale them. +// Dequantize 8 INT8 values starting at `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // Produces two float32x4_t (8 floats total) stored into dst. // inline void DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { // Load 8 int8 values int8x8_t raw = vld1_s8(src + col); @@ -52,7 +53,7 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -64,10 +65,11 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, // // Dequantize 8 INT4 values (4 packed bytes) starting at even column `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // inline void DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { const uint8_t* base = src + col / 2; @@ -94,7 +96,7 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -106,6 +108,8 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, // // Dequantize one row of length `cols` from packed quantized buffer into `dst`. +// `apply_per_tensor_scale=false` leaves per-tensor rows unscaled so callers can +// factor the single scale out of an outer accumulation loop. // void DequantRow_Neon( @@ -113,7 +117,8 @@ DequantRow_Neon( float* dst, size_t cols, MLAS_KV_QUANT_TYPE qt, - const float* scales) + const float* scales, + bool apply_per_tensor_scale) { const bool int4 = IsInt4Mode(qt); const bool per_channel = IsPerChannelMode(qt); @@ -124,22 +129,32 @@ DequantRow_Neon( if (!int4) { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt8x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt8x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { - float sc = per_channel ? scales[c] : scales[0]; - dst[c] = static_cast(src[c]) * sc; + if (per_channel) { + dst[c] = static_cast(src[c]) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(src[c]) * scales[0]; + } else { + dst[c] = static_cast(src[c]); + } } } else { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt4x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt4x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { uint8_t packed = src[c / 2]; int nibble = (c & 1) == 0 ? (packed & 0x0F) : ((packed >> 4) & 0x0F); - float sc = per_channel ? scales[c] : scales[0]; - dst[c] = static_cast(nibble - kInt4Bias) * sc; + if (per_channel) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[0]; + } else { + dst[c] = static_cast(nibble - kInt4Bias); + } } } } @@ -174,7 +189,7 @@ QKGemm_Neon( for (size_t n = 0; n < N; ++n) { const uint8_t* b_row = B_bytes + n * row_bytes; - DequantRow_Neon(b_row, b_buf, K, QuantType, Scales); + DequantRow_Neon(b_row, b_buf, K, QuantType, Scales, true); for (size_t m = 0; m < M; ++m) { const float* a_row = A + m * lda; @@ -246,6 +261,7 @@ SVGemm_Neon( { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); + const bool per_channel = IsPerChannelMode(QuantType); float b_stack[256]; float* b_buf = b_stack; @@ -272,7 +288,7 @@ SVGemm_Neon( for (size_t k = 0; k < K; ++k) { const uint8_t* b_row_packed = B_bytes + k * row_bytes; - DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales); + DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel); const float a_val = a_row[k]; float32x4_t a_broadcast = vdupq_n_f32(a_val); @@ -288,6 +304,19 @@ SVGemm_Neon( c_row[n] += a_val * b_buf[n]; } } + + if (!per_channel) { + const float32x4_t scale_vec = vdupq_n_f32(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 4) { + float32x4_t c_vec = vld1q_f32(c_row + n); + c_vec = vmulq_f32(c_vec, scale_vec); + vst1q_f32(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } + } } } diff --git a/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp new file mode 100644 index 0000000000000..038b7873637db --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/cast_kernel_rvv.cpp @@ -0,0 +1,62 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast_kernel_rvv.cpp + +Abstract: + + This module implements FP16<->FP32 cast kernels using RISC-V Vector + Extension (RVV). Uses Zvfhmin conversion instructions, but is gated + on Zvfh at build time (no separate Zvfhmin-only cmake probe). + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV_ZVFH) + +#include + +void + MLASCALL + MlasCastF16ToF32KernelRvv( + const unsigned short* Source, + float* Destination, + size_t Count + ) +{ + size_t i = 0; + while (i < Count) { + size_t vl = __riscv_vsetvl_e16m2(Count - i); + vuint16m2_t raw = __riscv_vle16_v_u16m2(Source + i, vl); + vfloat16m2_t fp16 = __riscv_vreinterpret_v_u16m2_f16m2(raw); + vfloat32m4_t fp32 = __riscv_vfwcvt_f_f_v_f32m4(fp16, vl); + __riscv_vse32_v_f32m4(Destination + i, fp32, vl); + i += vl; + } +} + +void + MLASCALL + MlasCastF32ToF16KernelRvv( + const float* Source, + unsigned short* Destination, + size_t Count + ) +{ + size_t i = 0; + while (i < Count) { + size_t vl = __riscv_vsetvl_e32m4(Count - i); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(Source + i, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2(Destination + i, __riscv_vreinterpret_v_f16m2_u16m2(fp16), vl); + i += vl; + } +} + +#endif // MLAS_USE_RVV_ZVFH diff --git a/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp new file mode 100644 index 0000000000000..f9fb2bbba96bf --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/halfgemm_kernel_rvv.cpp @@ -0,0 +1,239 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_rvv.cpp + +Abstract: + + This module implements half precision GEMM kernel for RISC-V Vector + Extension (RVV) with Zvfh (vector half-precision floating-point). + + The kernel vectorizes along the N dimension using vsetvl, so it adapts + automatically to any VLEN >= 128. Up to 4 rows of A are processed per + call (KernelMaxM = 4). + +--*/ + +#include "halfgemm.h" +#include "mlasi.h" + +#if defined(MLAS_USE_RVV_ZVFH) + +#include + +#include + +namespace +{ + +MLAS_FORCEINLINE +_Float16 +Fp16BitsToScalar(_mlas_fp16_ bits) +{ + _Float16 f; + memcpy(&f, &bits, sizeof(f)); + return f; +} + +MLAS_FORCEINLINE +vfloat16m4_t +LoadFp16(const _mlas_fp16_* ptr, size_t vl) +{ + return __riscv_vreinterpret_v_u16m4_f16m4(__riscv_vle16_v_u16m4(ptr, vl)); +} + +MLAS_FORCEINLINE +void +StoreFp16(_mlas_fp16_* ptr, vfloat16m4_t vec, size_t vl) +{ + __riscv_vse16_v_u16m4(ptr, __riscv_vreinterpret_v_f16m4_u16m4(vec), vl); +} + +template +MLAS_FORCEINLINE void +HalfGemmKernelRvvImpl( + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + bool ZeroMode +) +{ + static_assert(Rows >= 1 && Rows <= 4, "unsupported tile height"); + + size_t n = 0; + while (n < CountN) { + size_t vl = __riscv_vsetvl_e16m4(CountN - n); + + vfloat16m4_t acc0, acc1, acc2, acc3; + + if (ZeroMode) { + if (Bias != nullptr) { + vfloat16m4_t bv = LoadFp16(Bias + n, vl); + acc0 = bv; + if constexpr (Rows > 1) acc1 = bv; + if constexpr (Rows > 2) acc2 = bv; + if constexpr (Rows > 3) acc3 = bv; + } else { + vfloat16m4_t z = __riscv_vfmv_v_f_f16m4((_Float16)0.0f, vl); + acc0 = z; + if constexpr (Rows > 1) acc1 = z; + if constexpr (Rows > 2) acc2 = z; + if constexpr (Rows > 3) acc3 = z; + } + } else { + acc0 = LoadFp16(C + n, vl); + if constexpr (Rows > 1) acc1 = LoadFp16(C + ldc + n, vl); + if constexpr (Rows > 2) acc2 = LoadFp16(C + 2 * ldc + n, vl); + if constexpr (Rows > 3) acc3 = LoadFp16(C + 3 * ldc + n, vl); + if (Bias != nullptr) { + vfloat16m4_t bv = LoadFp16(Bias + n, vl); + acc0 = __riscv_vfadd_vv_f16m4(acc0, bv, vl); + if constexpr (Rows > 1) acc1 = __riscv_vfadd_vv_f16m4(acc1, bv, vl); + if constexpr (Rows > 2) acc2 = __riscv_vfadd_vv_f16m4(acc2, bv, vl); + if constexpr (Rows > 3) acc3 = __riscv_vfadd_vv_f16m4(acc3, bv, vl); + } + } + + for (size_t k = 0; k < CountK; k++) { + vfloat16m4_t bv = LoadFp16(B + k * ldb + n, vl); + acc0 = __riscv_vfmacc_vf_f16m4(acc0, Fp16BitsToScalar(A[k]), bv, vl); + if constexpr (Rows > 1) + acc1 = __riscv_vfmacc_vf_f16m4(acc1, Fp16BitsToScalar(A[lda + k]), bv, vl); + if constexpr (Rows > 2) + acc2 = __riscv_vfmacc_vf_f16m4(acc2, Fp16BitsToScalar(A[2 * lda + k]), bv, vl); + if constexpr (Rows > 3) + acc3 = __riscv_vfmacc_vf_f16m4(acc3, Fp16BitsToScalar(A[3 * lda + k]), bv, vl); + } + + StoreFp16(C + n, acc0, vl); + if constexpr (Rows > 1) StoreFp16(C + ldc + n, acc1, vl); + if constexpr (Rows > 2) StoreFp16(C + 2 * ldc + n, acc2, vl); + if constexpr (Rows > 3) StoreFp16(C + 3 * ldc + n, acc3, vl); + + n += vl; + } +} + +} // namespace + +struct MLAS_HALF_GEMM_KERNEL_RVV { + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 4; + static constexpr size_t PackedK = 1; + static constexpr MLAS_HALF_GEMM_STRIDES Strides{16, 128, 256}; +}; + +// FP32->FP16 conversion routines for when AIsfp32/BIsfp32 is set. +// PackNeeded=false means no packing, but these are still called +// to convert FP32 inputs to FP16 on the fly (see matmul.cc). +template <> +MLAS_FORCEINLINE void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + for (size_t m = 0; m < CountM; m++) { + const float* src = A + m * lda; + _mlas_fp16_* dst = D + m * CountK; + size_t k = 0; + while (k < CountK) { + size_t vl = __riscv_vsetvl_e32m4(CountK - k); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(src + k, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2( + dst + k, + __riscv_vreinterpret_v_f16m2_u16m2(fp16), + vl + ); + k += vl; + } + } +} + +template <> +MLAS_FORCEINLINE void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + for (size_t k = 0; k < CountK; k++) { + const float* src = B + k * ldb; + _mlas_fp16_* dst = D + k * CountN; + size_t n = 0; + while (n < CountN) { + size_t vl = __riscv_vsetvl_e32m4(CountN - n); + vfloat32m4_t fp32 = __riscv_vle32_v_f32m4(src + n, vl); + vfloat16m2_t fp16 = __riscv_vfncvt_f_f_w_f16m2(fp32, vl); + __riscv_vse16_v_u16m2( + dst + n, + __riscv_vreinterpret_v_f16m2_u16m2(fp16), + vl + ); + n += vl; + } + } +} + +template <> +MLAS_FORCEINLINE void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode +) +{ + size_t rows = std::min(CountM, MLAS_HALF_GEMM_KERNEL_RVV::KernelMaxM); + + switch (rows) { + case 1: + HalfGemmKernelRvvImpl<1>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + case 2: + HalfGemmKernelRvvImpl<2>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + case 3: + HalfGemmKernelRvvImpl<3>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + default: + HalfGemmKernelRvvImpl<4>(CountN, CountK, C, ldc, Bias, A, lda, B, ldb, ZeroMode); + break; + } +} + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchRvv = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_RVV::PackedK, + MLAS_HALF_GEMM_KERNEL_RVV::KernelMaxM, + 0 +}; + +#endif // MLAS_USE_RVV_ZVFH diff --git a/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp new file mode 100644 index 0000000000000..2bfeba1f1c993 --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/layernorm_kernel_rvv.cpp @@ -0,0 +1,109 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + layernorm_kernel_rvv.cpp + +Abstract: + + This module implements LayerNorm/RMSNorm kernels using RISC-V Vector + Extension (RVV). Processes one normalization row at a time. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV) + +#include + +#include +#include + +// Processes one normalization row. A multi-row variant that fuses +// several rows would reduce dispatch overhead for small NormSize. +void MLASCALL +MlasLayerNormKernelRvv( + const float* Input, + const float* Scale, + const float* Bias, + float* Output, + float* MeanOut, + float* InvStdDevOut, + size_t NormSize, + float Epsilon, + bool Simplified +) +{ + assert(!Simplified || Bias == nullptr); + const size_t n = NormSize; + + size_t maxvl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vacc_sum = __riscv_vfmv_v_f_f32m4(0.0f, maxvl); + vfloat32m4_t vacc_sumsq = __riscv_vfmv_v_f_f32m4(0.0f, maxvl); + + size_t i = 0; + while (i < n) { + size_t vl = __riscv_vsetvl_e32m4(n - i); + vfloat32m4_t vx = __riscv_vle32_v_f32m4(Input + i, vl); + vacc_sum = __riscv_vfadd_vv_f32m4_tu(vacc_sum, vacc_sum, vx, vl); + vfloat32m4_t vx2 = __riscv_vfmul_vv_f32m4(vx, vx, vl); + vacc_sumsq = __riscv_vfadd_vv_f32m4_tu(vacc_sumsq, vacc_sumsq, vx2, vl); + i += vl; + } + + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1)); + float mean_val = __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(vacc_sum, vzero, maxvl) + ) / + static_cast(n); + float ms_val = __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(vacc_sumsq, vzero, maxvl) + ); + float denom; + if (Simplified) { + denom = sqrtf(ms_val / static_cast(n) + Epsilon); + } else { + denom = sqrtf(ms_val / static_cast(n) - mean_val * mean_val + Epsilon); + } + float inv_denom = 1.0f / denom; + + i = 0; + while (i < n) { + size_t vl = __riscv_vsetvl_e32m4(n - i); + vfloat32m4_t vx = __riscv_vle32_v_f32m4(Input + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(Scale + i, vl); + + if (Simplified) { + vfloat32m4_t vy = __riscv_vfmul_vf_f32m4(vx, inv_denom, vl); + vy = __riscv_vfmul_vv_f32m4(vy, vs, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } else if (Bias == nullptr) { + vfloat32m4_t vy = __riscv_vfsub_vf_f32m4(vx, mean_val, vl); + vy = __riscv_vfmul_vf_f32m4(vy, inv_denom, vl); + vy = __riscv_vfmul_vv_f32m4(vy, vs, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } else { + vfloat32m4_t vb = __riscv_vle32_v_f32m4(Bias + i, vl); + vfloat32m4_t vy = __riscv_vfsub_vf_f32m4(vx, mean_val, vl); + vy = __riscv_vfmul_vf_f32m4(vy, inv_denom, vl); + vy = __riscv_vfmadd_vv_f32m4(vy, vs, vb, vl); + __riscv_vse32_v_f32m4(Output + i, vy, vl); + } + + i += vl; + } + + if (MeanOut != nullptr) { + *MeanOut = mean_val; + } + if (InvStdDevOut != nullptr) { + *InvStdDevOut = inv_denom; + } +} + +#endif // MLAS_USE_RVV diff --git a/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp new file mode 100644 index 0000000000000..3cc00624c76bd --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/rotary_embedding_kernel_rvv.cpp @@ -0,0 +1,108 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_rvv.cpp + +Abstract: + + This module implements rotary embedding kernels for RISC-V Vector + Extension (RVV). + + For the non-interleaved case: + output[i] = input[i] * cos[i] - input[i + half] * sin[i] + output[i + half] = input[i + half] * cos[i] + input[i] * sin[i] + + For the interleaved case: + output[2i] = input[2i] * cos[i] - input[2i+1] * sin[i] + output[2i+1] = input[2i+1] * cos[i] + input[2i] * sin[i] + +--*/ + +#include + +#include "rotary_embedding.h" + +#if defined(MLAS_USE_RVV) + +#include + +namespace rope_rvv +{ + +void +RopeKernel_Fp32( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) +{ + assert(dim % 2 == 0); + const size_t half_dim = dim / 2; + + if (!interleaved) { + size_t i = 0; + while (i < half_dim) { + size_t vl = __riscv_vsetvl_e32m4(half_dim - i); + + vfloat32m4_t vc = __riscv_vle32_v_f32m4(cos_data + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(sin_data + i, vl); + vfloat32m4_t v0 = __riscv_vle32_v_f32m4(input + i, vl); + vfloat32m4_t v1 = __riscv_vle32_v_f32m4(input + i + half_dim, vl); + + // output[i] = input[i] * cos - input[i+half] * sin + vfloat32m4_t r0 = __riscv_vfmul_vv_f32m4(v0, vc, vl); + r0 = __riscv_vfnmsac_vv_f32m4(r0, vs, v1, vl); + + // output[i+half] = input[i+half] * cos + input[i] * sin + vfloat32m4_t r1 = __riscv_vfmul_vv_f32m4(v1, vc, vl); + r1 = __riscv_vfmacc_vv_f32m4(r1, vs, v0, vl); + + __riscv_vse32_v_f32m4(output + i, r0, vl); + __riscv_vse32_v_f32m4(output + i + half_dim, r1, vl); + + i += vl; + } + } else { + size_t i = 0; + while (i < half_dim) { + size_t vl = __riscv_vsetvl_e32m4(half_dim - i); + + vfloat32m4_t vc = __riscv_vle32_v_f32m4(cos_data + i, vl); + vfloat32m4_t vs = __riscv_vle32_v_f32m4(sin_data + i, vl); + + vfloat32m4x2_t seg = __riscv_vlseg2e32_v_f32m4x2(input + 2 * i, vl); + vfloat32m4_t v_even = __riscv_vget_v_f32m4x2_f32m4(seg, 0); + vfloat32m4_t v_odd = __riscv_vget_v_f32m4x2_f32m4(seg, 1); + + // output[2i] = even * cos - odd * sin + vfloat32m4_t r_even = __riscv_vfmul_vv_f32m4(v_even, vc, vl); + r_even = __riscv_vfnmsac_vv_f32m4(r_even, vs, v_odd, vl); + + // output[2i+1] = odd * cos + even * sin + vfloat32m4_t r_odd = __riscv_vfmul_vv_f32m4(v_odd, vc, vl); + r_odd = __riscv_vfmacc_vv_f32m4(r_odd, vs, v_even, vl); + + vfloat32m4x2_t out = __riscv_vcreate_v_f32m4x2(r_even, r_odd); + __riscv_vsseg2e32_v_f32m4x2(output + 2 * i, out, vl); + + i += vl; + } + } +} + +} // namespace rope_rvv + +const MLAS_ROPE_DISPATCH MlasRopeDispatchRvv = { + rope_rvv::RopeKernel_Fp32, + nullptr, +}; + +#endif // MLAS_USE_RVV diff --git a/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp index c9253bb033a1a..51b3e24dddff7 100644 --- a/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp +++ b/onnxruntime/core/mlas/lib/riscv64/sconv_depthwise_kernel_rvv.cpp @@ -142,6 +142,7 @@ MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1( assert(pad_bottom <= 1); assert(pad_left <= 1); assert(pad_right <= 1); + MLAS_UNREFERENCED_PARAMETER(pad_bottom); const float beta = Parameters->Beta; const bool accumulate_output = beta != 0.0f; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 8abc6da27a64c..ab491c134b5e5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -371,6 +371,21 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } + // Do not propagate DQ forward when its data input is a constant (graph initializer or + // Constant op output). Propagation would insert a Q -> DQ pair after any downstream + // reshape-like node; later passes (e.g. S8-to-U8 weight transformer) may flip the + // existing DQ to uint8 without touching the inserted Q, causing int8 negatives to + // be clamped to zero. See GitHub issue #28491. + const NodeArg* dq_data_input = dq_node.InputDefs()[QDQ::InputIndex::INPUT_ID]; + const bool is_initializer_constant = graph_utils::NodeArgIsConstant(graph, *dq_data_input); + const Node* dq_data_producer = graph.GetProducerNode(dq_data_input->Name()); + const bool is_constant_op_output = dq_data_producer != nullptr && + dq_data_producer->OpType() == "Constant" && + dq_data_producer->Domain() == kOnnxDomain; + if (is_initializer_constant || is_constant_op_output) { + continue; + } + auto& dq_scale = *dq_node.MutableInputDefs()[QDQ::InputIndex::SCALE_ID]; auto* dq_zero_point = dq_zero_point_exists ? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID] diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index f88ce56fe36fa..f50c0e2e635bc 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -470,6 +470,19 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { break; } + // If next_node is a Reshape with allowzero=1, the fused node cannot represent this + // correctly: the fused node inherits attributes from the first node in the chain + // (which has allowzero=0 or no allowzero attribute). Bailing out here prevents + // incorrect fusion such as Reshape([0,8,2]->[4,2,-1]) + Reshape([0,0,4],allowzero=1) + // being collapsed into Reshape([0,8,2]->[0,0,4],allowzero=0), which would silently + // copy dims from the original input instead of preserving the explicit zeros. + if (next_node->OpType() == "Reshape") { + const auto* az_attr = graph_utils::GetNodeAttribute(*next_node, "allowzero"); + if ((nullptr != az_attr) && az_attr->has_i() && az_attr->i() != 0) { + break; + } + } + auto shape = next_node->OutputDefs()[0]->Shape(); if (!shape) { break; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 8b58f5dc6c927..5059c6c9edd8f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -30,27 +30,121 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /*logger*/) const { + const logging::Logger& logger) const { + const auto axis = GetAxisAttribute(node); + const auto& data_def = *node.InputDefs()[0]; + const auto& indices_def = *node.InputDefs()[1]; + const auto& output_def = *node.OutputDefs()[0]; + + std::vector data_shape, indices_shape; + ORT_RETURN_IF_NOT(GetShape(data_def, data_shape, logger), "Failed to get 'data' shape"); + ORT_RETURN_IF_NOT(GetShape(indices_def, indices_shape, logger), "Failed to get 'indices' shape"); + + // ONNX Gather: out_shape = data_shape[:axis] + indices_shape + data_shape[axis+1:] + // CoreML's gather requires rank-1+ indices, so for scalar indices we promote + // them to [1], gather, and then squeeze the resulting axis to restore the + // original output rank. The positive axis after wrapping is needed for the + // squeeze axis below regardless of path. + const bool scalar_indices = indices_shape.empty(); + const int64_t pos_axis = HandleNegativeAxis(axis, data_shape.size()); + if (model_builder.CreateMLProgram()) { using CoreML::Specification::MILSpec::Operation; - std::unique_ptr op = model_builder.CreateOperation(node, "gather"); - - const auto axis = GetAxisAttribute(node); + // IsOpSupportedImpl gates indices to INT32 or INT64, so we can pass the + // dtype straight through to the reshape's intermediate output. + int32_t indices_dtype{}; + ORT_RETURN_IF_NOT(GetType(indices_def, indices_dtype, logger), + "Failed to get 'indices' dtype"); + const int32_t output_dtype = static_cast(output_def.TypeAsProto()->tensor_type().elem_type()); + + std::string indices_name = indices_def.Name(); + + if (scalar_indices) { + // [] -> [1] via reshape. We use reshape rather than expand_dims because + // CoreML internally pads scalars; expand_dims on the padded tensor can + // push the apparent rank past the rank-5 limit on high-rank `data`. + auto reshape = model_builder.CreateOperation(node, "reshape", "indices"); + AddOperationInput(*reshape, "x", indices_def.Name()); + const std::vector indices_1d_shape = {1}; + AddOperationInput(*reshape, "shape", + model_builder.AddConstant(reshape->type(), "shape", indices_1d_shape)); + + indices_name = model_builder.GetUniqueName(node, "indices_1d"); + AddIntermediateOperationOutput(*reshape, indices_name, indices_dtype, indices_1d_shape); + model_builder.AddOperation(std::move(reshape)); + } + + std::unique_ptr gather = model_builder.CreateOperation(node, "gather"); // coreml docs claims validate_indices is optional but in practice it is required const auto validate_indices = false; - AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); // data - AddOperationInput(*op, "indices", node.InputDefs()[1]->Name()); // indices - AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); // axis attr - AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices)); - AddOperationOutput(*op, *node.OutputDefs()[0]); // output - model_builder.AddOperation(std::move(op)); + AddOperationInput(*gather, "x", data_def.Name()); + AddOperationInput(*gather, "indices", indices_name); + AddOperationInput(*gather, "axis", model_builder.AddScalarConstant(gather->type(), "axis", axis)); + AddOperationInput(*gather, "validate_indices", + model_builder.AddScalarConstant(gather->type(), "validate_indices", validate_indices)); + + if (!scalar_indices) { + AddOperationOutput(*gather, output_def); + model_builder.AddOperation(std::move(gather)); + } else { + // gather output here has the data's rank (one more than ONNX scalar-gather output); + // squeeze the inserted axis to recover the original output shape. + TensorShapeVector gather_shape{data_shape.begin(), data_shape.end()}; + gather_shape[pos_axis] = 1; + const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); + AddIntermediateOperationOutput(*gather, gather_out_name, output_dtype, gather_shape); + model_builder.AddOperation(std::move(gather)); + + auto squeeze = model_builder.CreateOperation(node, "squeeze", "post"); + AddOperationInput(*squeeze, "x", gather_out_name); + const std::vector sq_axes = {pos_axis}; + AddOperationInput(*squeeze, "axes", model_builder.AddConstant(squeeze->type(), "axes", sq_axes)); + AddOperationOutput(*squeeze, output_def); + model_builder.AddOperation(std::move(squeeze)); + } } else { - auto layer = model_builder.CreateNNLayer(node); - layer->mutable_gather()->set_axis(GetAxisAttribute(node)); - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data - *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); // output - model_builder.AddLayer(std::move(layer)); + if (!scalar_indices) { + auto layer = model_builder.CreateNNLayer(node); + layer->mutable_gather()->set_axis(axis); + *layer->mutable_input()->Add() = data_def.Name(); + *layer->mutable_input()->Add() = indices_def.Name(); + *layer->mutable_output()->Add() = output_def.Name(); + model_builder.AddLayer(std::move(layer)); + } else { + // expand_dims indices: [] -> [1]. Unlike the MLProgram reshape path + // above, NN's expand_dims doesn't internally pad rank, so we don't run + // into the apparent-rank inflation that forced reshape+gather there; + // expand_dims is the natural choice on this path. + const std::string& indices_1d_name = model_builder.GetUniqueName(node, "indices_1d"); + { + auto expand_layer = model_builder.CreateNNLayer(node, "_indices_expand"); + expand_layer->mutable_expanddims()->add_axes(0); + *expand_layer->mutable_input()->Add() = indices_def.Name(); + *expand_layer->mutable_output()->Add() = indices_1d_name; + model_builder.AddLayer(std::move(expand_layer)); + } + + // gather with the promoted indices + const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); + { + auto gather_layer = model_builder.CreateNNLayer(node); + gather_layer->mutable_gather()->set_axis(axis); + *gather_layer->mutable_input()->Add() = data_def.Name(); + *gather_layer->mutable_input()->Add() = indices_1d_name; + *gather_layer->mutable_output()->Add() = gather_out_name; + model_builder.AddLayer(std::move(gather_layer)); + } + + // squeeze the inserted axis + { + auto squeeze_layer = model_builder.CreateNNLayer(node, "_post_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(pos_axis); + squeeze_layer->mutable_squeeze()->set_squeezeall(false); + *squeeze_layer->mutable_input()->Add() = gather_out_name; + *squeeze_layer->mutable_output()->Add() = output_def.Name(); + model_builder.AddLayer(std::move(squeeze_layer)); + } + } } return Status::OK(); } @@ -87,14 +181,45 @@ bool GatherOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } - // Don't allow scalar 'indices' input. - // We convert scalar inputs to tensors with shape [1] before providing them to CoreML. - // This modification changes the shape of the Gather output. - if (indices_shape.empty()) { - LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices'"; + // ONNX Gather schema constrains indices to int32 or int64. Validate here so + // AddToModelBuilderImpl can trust the dtype rather than silently defaulting + // on an unexpected value. + int32_t indices_dtype{}; + if (!GetType(*node.InputDefs()[1], indices_dtype, logger)) { return false; } + if (indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT32 && + indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "Gather 'indices' dtype [" << indices_dtype + << "] is not supported (expected INT32 or INT64)"; + return false; + } + + // For scalar indices we internally emit gather with promoted [1] indices + // then squeeze. That requires us to claim a static intermediate shape, so + // we only handle scalar indices when the data shape itself is fully + // static. (Dynamic-shape scalar Gather still falls back to CPU.) + if (indices_shape.empty()) { + if (!IsStaticShape(data_shape)) { + LOGS(logger, VERBOSE) << "Gather with scalar 'indices' requires static 'data' shape"; + return false; + } + // The pre-squeeze intermediate has the same rank as `data`. CoreML's + // compiler reports "Invalid rank: 6" when a rank-5 intermediate is + // produced via reshape+gather, even though rank-5 intermediates are + // accepted in other op chains. Cap scalar-indices Gather at data rank 4 + // until that compiler limit is lifted. + // + // TODO: re-test on newer macOS / CoreML versions; if Apple lifts the + // intermediate rank limit, this cap can be raised to 5 (matching the + // general Gather output-rank check below). + if (data_shape.size() > 4) { + LOGS(logger, VERBOSE) << "Gather with scalar 'indices' supports 'data' rank up to 4"; + return false; + } + } + // Output rank = data_rank + indices_rank - 1. The rank-5 limit applies. if (data_shape.size() + indices_shape.size() - 1 > 5) { LOGS(logger, VERBOSE) << "Gather does not support output with rank greater than 5"; return false; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 7dd9d994e52b4..f8ea6d4003619 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -38,6 +38,20 @@ void ComputeJob( ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload ORT_UNUSED_PARAMETER(alloc); + int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); + + if constexpr (std::is_same_v) { + if (MlasLayerNormF32( + X_data + task_idx * norm_size, scale_data + i, + (simplified || !bias_data) ? nullptr : bias_data + i, + Y_data + task_idx * norm_size, + mean_data ? &mean_data[task_idx] : nullptr, + inv_std_dev_data ? &inv_std_dev_data[task_idx] : nullptr, + static_cast(norm_size), epsilon, simplified)) { + return; + } + } + const T* p_input = X_data + task_idx * norm_size; T* p_output = Y_data + task_idx * norm_size; @@ -57,9 +71,6 @@ void ComputeJob( mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); } - // Compute the offset of gamma and beta to support broadcasting. - int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); - for (int64_t h = 0; h < norm_size; h++, i++) { if (simplified) { p_output[h] = p_output[h] / mean_square * scale_data[i]; diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index 9916c426a54fe..8d0ab6a53b88e 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -208,6 +208,9 @@ class MatMulIntegerBase : public OpKernel { } if (ctx.bias != nullptr) { + if (ctx.bias->Shape().Size() != static_cast(ctx.N)) { + return false; + } dynamic_quant_mlas_bias_data_was_packed_ = true; } diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index c7a2005924836..5b5011d2a0814 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scatter +#include #include #include @@ -10,6 +11,7 @@ #include "core/framework/element_type_lists.h" #include "core/framework/op_kernel.h" #include "core/framework/op_kernel_type_control_utils.h" +#include "core/platform/threadpool.h" #include "core/providers/common.h" #include "core/providers/op_kernel_type_control.h" #if defined(ENABLE_TRAINING_OPS) @@ -236,29 +238,44 @@ struct Func_Max { template Status GetIndices( const Tensor& data_input, const Tensor& indices_input, int64_t axis, + concurrency::ThreadPool* tp, std::vector& indices_data) { const auto& input_data_shape = data_input.Shape(); const auto* indices_data_raw = indices_input.Data(); const auto num_indices = indices_input.Shape().Size(); const auto axis_dim_limit = input_data_shape[narrow(axis)]; - std::vector indices_data_result; - indices_data_result.reserve(narrow(num_indices)); - - for (int64_t i = 0; i < num_indices; ++i) { - const int64_t idx = static_cast(indices_data_raw[i]); - - if (idx < -axis_dim_limit || idx >= axis_dim_limit) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices element out of data bounds, idx=", idx, - " must be within the inclusive range [", -axis_dim_limit, - ",", axis_dim_limit - 1, "]"); - } - - indices_data_result.push_back(idx < 0 ? idx + axis_dim_limit : idx); + indices_data.resize(narrow(num_indices)); + + // When multiple indices are out-of-bounds, the reported index is nondeterministic + // (whichever thread wins the CAS). This is acceptable—we only need to report that + // validation failed and provide one example of a bad index. + std::atomic found_error{false}; + std::atomic first_bad_idx{0}; + + concurrency::ThreadPool::TryParallelFor( + tp, narrow(num_indices), 1.0, + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + for (std::ptrdiff_t i = first; i < last; ++i) { + const int64_t idx = static_cast(indices_data_raw[i]); + if (idx < -axis_dim_limit || idx >= axis_dim_limit) { + bool expected = false; + if (found_error.compare_exchange_strong(expected, true)) { + first_bad_idx.store(idx, std::memory_order_relaxed); + } + return; + } + indices_data[narrow(i)] = idx < 0 ? idx + axis_dim_limit : idx; + } + }); + + if (found_error.load()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices element out of data bounds, idx=", first_bad_idx.load(), + " must be within the inclusive range [", -axis_dim_limit, + ",", axis_dim_limit - 1, "]"); } - indices_data = std::move(indices_data_result); return Status::OK(); } @@ -266,6 +283,7 @@ template Status ScatterData( const FuncT& func, const Tensor* data_input, const std::vector& indices_data, const Tensor* updates_input, int64_t axis, + concurrency::ThreadPool* tp, Tensor* data_output) { const TensorShape& input_data_shape = data_input->Shape(); @@ -296,103 +314,129 @@ Status ScatterData( const auto num_dims = input_data_shape.NumDimensions(); ORT_RETURN_IF_NOT(num_dims > 0, "ScatterElements op: input tensor must have at least one dimension"); - // Allocate and zero out counts. The input/output is of the same rank as - // indices/updates but the actual dimensions of indices/updates must be less or equal - // than that of input/output because we can update no more elements than - // the input contains. As we walk through the indices/updates - // we maintain dimension count as we will need to use it - // to compute output offset but using input/output dim values. - // We treat the whole array as a number where each element having - // different cardinality according to the upd_shape dimensions. - // As each counter reaches its max (upd_shape) it resets to zero - // and we carry to the more significant dim (right to left) - std::vector dim_counters(num_dims); - - // This vector contains number of elements under the dimension. - // For example, for the dimensions of [4, 2, 3] the vector - // would contain [6, 3, 1] since for each count of dim 1 it - // contains 3 elements of dim 2. - // For each count of dim 0 we would have 2x3=6 elements. - // The last value is always 1. - // We use it to compute output element offset. For a given value of - // counters we multiple each counter value per corresponding entry of dim_block_size value - // and add up resulting the output element offset. However, for dimensions - // that are equal to the specified axis value we take indices_data[index] - // instead of the counter value. - // E.g. for 3-dim and axis=0 - // output[indices[i][j][k]][j][k] = updates[i][j][k] - // for axis 1 - // output[i][indices[i][j][k]][k] = updates[i][j][k] - // and so on - std::vector dim_block_size(num_dims); - - dim_block_size.back() = 1; + if (num_indices == 0) { + return Status::OK(); + } + + const auto* update_data = static_cast(updates_input->DataRaw()); + + // Compute outer_size (product of dims before axis) and inner_size (product of dims after axis). + // For ScatterElements with axis=a: + // output[i0]...[indices[i0..iN]][...][iN] = updates[i0][...][iN] + // Work units identified by (outer_idx, inner_idx) are completely independent: + // they never write to the same output element, even with reductions. + // This allows safe parallelization over outer_size * inner_size work units. + int64_t outer_size = 1; + for (int64_t i = 0; i < axis; ++i) { + outer_size *= upd_shape[narrow(i)]; + } + const int64_t axis_size = upd_shape[narrow(axis)]; + int64_t inner_size = 1; + for (size_t i = narrow(axis) + 1; i < num_dims; ++i) { + inner_size *= upd_shape[i]; + } + + // Compute strides for the input/output tensor + std::vector input_strides(num_dims); + input_strides.back() = 1; if (num_dims > 1) { - // We start at num_dims - 2 because we already pre-populated - // the last element above for (auto i = int64_t(num_dims - 2); i >= 0; --i) { - dim_block_size[narrow(i)] = input_data_shape[SafeInt(i) + 1] * dim_block_size[SafeInt(i) + 1]; + input_strides[narrow(i)] = input_data_shape[SafeInt(i) + 1] * input_strides[SafeInt(i) + 1]; } } - const auto* update_data = static_cast(updates_input->DataRaw()); - // For every update we compute the destination offset and copy it there - for (int64_t index = 0; index < num_indices;) { - const auto axis_idx = indices_data[narrow(index)]; - - // Compute the offset - // See comments above for dim_block_size - size_t dst_offset = 0; - for (size_t i = 0; i < num_dims; ++i) { - if (i == size_t(axis)) { - // replace the counter with the update index for this dim - dst_offset += narrow(axis_idx * dim_block_size[narrow(i)]); - } else { - dst_offset += narrow(dim_counters[narrow(i)] * dim_block_size[narrow(i)]); - } + // Compute strides for the updates/indices tensor + std::vector upd_strides(num_dims); + upd_strides.back() = 1; + if (num_dims > 1) { + for (auto i = int64_t(num_dims - 2); i >= 0; --i) { + upd_strides[narrow(i)] = upd_shape[SafeInt(i) + 1] * upd_strides[SafeInt(i) + 1]; } + } - func(dst_base + dst_offset, update_data + index); + const int64_t total_work_units = outer_size * inner_size; + const int64_t input_axis_stride = input_strides[narrow(axis)]; + const int64_t upd_axis_stride = upd_strides[narrow(axis)]; + + // Parallelize over independent work units. + // Each work unit processes axis_size elements along the scatter axis. + // Cost per unit is proportional to axis_size (number of scatter ops per work unit). + concurrency::ThreadPool::TryParallelFor( + tp, narrow(total_work_units), static_cast(axis_size), + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + for (std::ptrdiff_t work_idx = first; work_idx < last; ++work_idx) { + // Decompose work_idx into outer_idx and inner_idx + const int64_t outer_idx = static_cast(work_idx) / inner_size; + const int64_t inner_idx = static_cast(work_idx) % inner_size; + + // Compute the base offset in the output for dimensions outside the axis. + // For dims before axis: determined by outer_idx + // For dims after axis: determined by inner_idx + int64_t dst_base_offset = 0; + int64_t outer_remain = outer_idx; + for (int64_t d = axis - 1; d >= 0; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = outer_remain % dim_size; + outer_remain /= dim_size; + dst_base_offset += coord * input_strides[narrow(d)]; + } + int64_t inner_remain = inner_idx; + for (int64_t d = int64_t(num_dims) - 1; d > axis; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = inner_remain % dim_size; + inner_remain /= dim_size; + dst_base_offset += coord * input_strides[narrow(d)]; + } + + // Compute the base index into the updates/indices flat array + int64_t upd_base_offset = 0; + outer_remain = outer_idx; + for (int64_t d = axis - 1; d >= 0; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = outer_remain % dim_size; + outer_remain /= dim_size; + upd_base_offset += coord * upd_strides[narrow(d)]; + } + inner_remain = inner_idx; + for (int64_t d = int64_t(num_dims) - 1; d > axis; --d) { + const auto dim_size = upd_shape[narrow(d)]; + const auto coord = inner_remain % dim_size; + inner_remain /= dim_size; + upd_base_offset += coord * upd_strides[narrow(d)]; + } + + // Process axis_size elements along the axis + for (int64_t a = 0; a < axis_size; ++a) { + const int64_t upd_flat_idx = upd_base_offset + a * upd_axis_stride; + const int64_t axis_idx = indices_data[narrow(upd_flat_idx)]; + const int64_t dst_offset = dst_base_offset + axis_idx * input_axis_stride; + func(dst_base + dst_offset, update_data + upd_flat_idx); + } + } + }); - if (++index == num_indices) { - break; - } - // Increment counters - // See comments for dim_counters above - for (auto i = int64_t(num_dims - 1); i >= 0; --i) { - auto v = ++dim_counters[narrow(i)]; - assert(v <= upd_shape[narrow(i)]); - if (v < upd_shape[narrow(i)]) { - // No carry, done - break; - } - // No carry for the most significant dim - assert(i > 0); - dim_counters[narrow(i)] = 0; - } - } return Status::OK(); } template struct ScatterDataDispatchTarget { Status operator()(const Tensor* data_input, const std::vector& indices_data, const Tensor* updates_input, int64_t axis, - const std::string& reduction, Tensor* data_output) const { + const std::string& reduction, concurrency::ThreadPool* tp, Tensor* data_output) const { if (reduction == "add") return ScatterData( - Func_Add(), data_input, indices_data, updates_input, axis, data_output); + Func_Add(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "mul") return ScatterData( - Func_Mul(), data_input, indices_data, updates_input, axis, data_output); + Func_Mul(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "min") return ScatterData( - Func_Min(), data_input, indices_data, updates_input, axis, data_output); + Func_Min(), data_input, indices_data, updates_input, axis, tp, data_output); else if (reduction == "max") return ScatterData( - Func_Max(), data_input, indices_data, updates_input, axis, data_output); + Func_Max(), data_input, indices_data, updates_input, axis, tp, data_output); else // if (reduction == "none") return ScatterData( - Func_Assignment(), data_input, indices_data, updates_input, axis, data_output); + Func_Assignment(), data_input, indices_data, updates_input, axis, tp, data_output); } }; @@ -444,11 +488,12 @@ Status Scatter::Compute(OpKernelContext* context) const { Status status{}; const auto index_type = indices_input->GetElementType(); std::vector indices_data{}; + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); if (index_type == utils::ToTensorProtoElementType()) { - status = GetIndices(*data_input, *indices_input, axis, indices_data); + status = GetIndices(*data_input, *indices_input, axis, tp, indices_data); } else if (index_type == utils::ToTensorProtoElementType()) { - status = GetIndices(*data_input, *indices_input, axis, indices_data); + status = GetIndices(*data_input, *indices_input, axis, tp, indices_data); } else { status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices type is not supported."); } @@ -462,7 +507,7 @@ Status Scatter::Compute(OpKernelContext* context) const { utils::MLTypeCallDispatcherFromTypeList dispatcher{data_type}; status = dispatcher.template InvokeRet( - data_input, indices_data, updates_input, axis, this->reduction_, data_output); + data_input, indices_data, updates_input, axis, this->reduction_, tp, data_output); return status; } @@ -482,8 +527,8 @@ template Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input, const int64_t axis, Tensor* data_output) { std::vector indices_data{}; - ORT_RETURN_IF_ERROR(GetIndices(*data_output, *indices_input, axis, indices_data)); - return ScatterData(Func_Add(), data_output, indices_data, updates_input, axis, data_output); + ORT_RETURN_IF_ERROR(GetIndices(*data_output, *indices_input, axis, nullptr, indices_data)); + return ScatterData(Func_Add(), data_output, indices_data, updates_input, axis, nullptr, data_output); } #define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \ diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 1b2e8494f5f99..537c14fd2b3b3 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -16,10 +16,7 @@ #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/tunable/cuda_tuning_context.h" - -#ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/cuda/bert/attention_kernel_options.h" -#endif namespace onnxruntime { @@ -91,13 +88,11 @@ class CUDAExecutionProvider : public IExecutionProvider { bool IsFuseConvBias() const { return info_.fuse_conv_bias; } bool UseTF32() const { return info_.use_tf32; } -#ifndef DISABLE_CONTRIB_OPS // Attention kernel options parsed from sdpa_kernel cuda provider option. const AttentionKernelOptions* GetAttentionKernelOptions() const { attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true, true); return &attention_kernel_options_; } -#endif ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); @@ -143,10 +138,8 @@ class CUDAExecutionProvider : public IExecutionProvider { // the tuning context might be altered when calling into a TunableOp mutable cuda::tunable::CudaTuningContext tuning_context_; -#ifndef DISABLE_CONTRIB_OPS // Attention kernel options parsed from sdpa_kernel cuda provider option. mutable AttentionKernelOptions attention_kernel_options_; -#endif class PerThreadContext final { public: diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 13bf5b37490e0..1d891f204b9bd 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -172,11 +172,9 @@ class CudaKernel : public OpKernel { return provider_->UseTF32(); } -#ifndef DISABLE_CONTRIB_OPS const AttentionKernelOptions* GetAttentionKernelOptions() const { return provider_->GetAttentionKernelOptions(); } -#endif tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index fa3b6bb840854..c68ad570fbc44 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,7 +101,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. - MIGraphX + MIGraphX, + CPU }; struct EpToAppend { @@ -110,7 +111,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -123,7 +124,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, - EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}, + EpToAppend{EpID::CPU, "CPU", kCpuExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -197,6 +199,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, ep_to_append.canonical_name)); switch (ep_to_append.id) { + case EpID::CPU: { + // CPU EP is always available by default. Accept the name as valid but do nothing, + // since the CPU EP is implicitly registered in every session. + break; + } case EpID::DML: { #if defined(USE_DML) options->provider_factories.push_back( diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 6137b23111bf9..e003f3bd93786 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -11,6 +11,7 @@ #include "ep_data_transfer.h" #include "ep_stream_support.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger) @@ -141,6 +142,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // random example using made up values factory->ort_api.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+"); + // Example os_driver_version. A real EP would read the OS driver version from the device. + // The format is a 4-part dot-separated version matching the DXCore DriverVersion property. + factory->ort_api.AddKeyValuePair(ep_metadata, kOrtEpDevice_EpMetadataKey_OSDriverVersion, "31.0.101.1000"); factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); // OrtEpDevice copies ep_metadata and ep_options. @@ -171,6 +175,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // Ort::KeyValuePairs ep_metadata; // Ort::KeyValuePairs ep_options; // ep_metadata.Add("supported_devices", "CrackGriffin 7+"); + // ep_metadata.Add(kOrtEpDevice_EpMetadataKey_OSDriverVersion, "31.0.101.1000"); // ep_options.Add("run_really_fast", "true"); // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; // ep_devices[num_ep_devices++] = ep_device.release(); diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 79bc34572a6f7..40ac1670b07dc 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -70,6 +70,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { auto metadata = test_ep_device->EpMetadata(); ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); + // Verify the example plugin's expected os_driver_version value. + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_OSDriverVersion), "31.0.101.1000"); auto options = test_ep_device->EpOptions(); ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index fb64d6fa9b66d..5287859292f1f 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -421,6 +421,47 @@ TEST(DynamicQuantizeMatMul, KleidiRejectsUnsupportedBShape) { test.Run(); } +// 6. Mismatched bias (runtime tensor) -> must be rejected at compute time. +TEST(DynamicQuantizeMatMul, KleidiBiasRuntimeShapeMismatch) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + // Bias has only 1 element but N=3 — this must be rejected. + const std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {1}, bad_bias, false /*runtime*/); + test.AddOutput("Y", {data.M, data.N}, std::vector(data.M * data.N, 0.0f)); + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + +// 7. Mismatched bias (constant initializer) -> KleidiAI pre-pack rejects -> falls back to ComputeCommon +// -> rejected +TEST(DynamicQuantizeMatMul, KleidiBiasInitializerShapeMismatch) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + // Bias has only 1 element but N=3 — this must be rejected. + const std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {1}, bad_bias, true /*initializer*/); + test.AddOutput("Y", {data.M, data.N}, std::vector(data.M * data.N, 0.0f)); + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + #endif // USE_KLEIDIAI TEST(DynamicQuantizeMatMul, B_PerColumn_ND) { @@ -486,5 +527,35 @@ TEST(DynamicQuantizeMatMul, B_PerColumn_ND) { test_case({15, 14, 13}, {15, 13, 27}, {15, 1, 27}); } +// Test that a bias tensor with length mismatched to B's last dimension is rejected. +// This reproduces a heap OOB read when bias is shorter than N. +TEST(DynamicQuantizeMatMul, BiasShapeMismatch) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 1.0f); + std::vector B_data(K * N, 128); + std::vector B_scale = {0.5f}; + std::vector B_zero_point = {128}; + + // Bias has only 1 element but N=8 — this must be rejected. + std::vector bad_bias = {1.0f}; + + OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index a6dd471ce639f..112d6f1eecc72 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -1578,5 +1578,174 @@ TEST(GroupQueryAttentionTest, QuantizedKV_MissingScale) { {}, nullptr, &execution_providers); } +// Regression: seqlens_k valid for KV cache but exceeding cos_cache.shape[0] must be rejected +// when do_rotary is enabled. Without this check, the position ID derived from seqlens_k +// would index out of bounds in the cos/sin cache, leaking heap memory into output. +TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_OOB) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; // cos/sin cache dim-1 = 8 + + constexpr int cos_cache_max_seq = 4; // small rotary cache + constexpr int past_seq_len = 16; // large KV cache + constexpr int seqlens_k_val = 10; // valid for KV (10 < 16) but OOB for cos (10 >= 4) + constexpr int total_seq_len = 4; // passes CheckRotaryCaches (4 <= cos_cache_max_seq) + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {1, 1, hidden_size}, std::vector(hidden_size, 1.0f)); + tester.AddInput("key", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + tester.AddInput("value", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + + // Past KV cache is large enough for seqlens_k=10 + tester.AddInput("past_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + + tester.AddInput("seqlens_k", {1}, {seqlens_k_val}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + // cos/sin cache with only 4 rows — seqlens_k=10 exceeds this + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {1, 1, hidden_size}, std::vector(hidden_size, 0.0f)); + tester.AddOutput("present_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, "is out of range for rotary cache dimension 0", + {}, nullptr, &execution_providers); +} + +// Positive test: seqlens_k within cos/sin cache bounds with do_rotary enabled should succeed. +TEST(GroupQueryAttentionTest, SeqlensKWithinCosCache_Rotary) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; + + constexpr int cos_cache_max_seq = 16; // rotary cache large enough + constexpr int past_seq_len = 16; + constexpr int seqlens_k_val = 3; // valid: 3 < 16 (cos cache) and 3 < 16 (KV cache) + constexpr int total_seq_len = 4; // seqlens_k + 1 + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {1, 1, hidden_size}, std::vector(hidden_size, 1.0f)); + tester.AddInput("key", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + tester.AddInput("value", {1, 1, kv_hidden_size}, std::vector(kv_hidden_size, 1.0f)); + + tester.AddInput("past_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.5f)); + + tester.AddInput("seqlens_k", {1}, {seqlens_k_val}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {1, 1, hidden_size}, std::vector(hidden_size, 0.0f)); + tester.AddOutput("present_key", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {1, kv_num_heads, past_seq_len, head_size}, + std::vector(kv_num_heads * past_seq_len * head_size, 0.0f)); + + tester.SetOutputTolerance(1e6f); // shape acceptance test, not numerical correctness + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", + {}, nullptr, &execution_providers); +} + +// Multi-batch test: one valid and one OOB seqlens_k value. +// Verifies the validation loop correctly identifies the offending batch index. +TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) { + constexpr int num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int rotary_half_dim = head_size / 2; + + constexpr int cos_cache_max_seq = 4; + constexpr int past_seq_len = 16; + constexpr int total_seq_len = 4; + constexpr int batch_size = 2; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + tester.AddInput("query", {batch_size, 1, hidden_size}, + std::vector(batch_size * hidden_size, 1.0f)); + tester.AddInput("key", {batch_size, 1, kv_hidden_size}, + std::vector(batch_size * kv_hidden_size, 1.0f)); + tester.AddInput("value", {batch_size, 1, kv_hidden_size}, + std::vector(batch_size * kv_hidden_size, 1.0f)); + + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f)); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f)); + + // seqlens_k: batch 0 is valid (3 < 4), batch 1 is OOB (10 >= 4) + tester.AddInput("seqlens_k", {batch_size}, {3, 10}); + tester.AddInput("total_sequence_length", {1}, {total_seq_len}); + + tester.AddInput("cos_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 1.0f)); + tester.AddInput("sin_cache", {cos_cache_max_seq, rotary_half_dim}, + std::vector(cos_cache_max_seq * rotary_half_dim, 0.0f)); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + tester.AddOutput("output", {batch_size, 1, hidden_size}, + std::vector(batch_size * hidden_size, 0.0f)); + tester.AddOutput("present_key", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + std::vector(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + // Error should reference batch index 1: seqlens_k[1] = 10 + tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k[1] = 10", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 30b0c0fcf73c3..7142358a4e02c 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -489,5 +489,72 @@ TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { test_case({15, 14, 13}, {15, 13, 27}, {15, 1, 27}); } +// Test that a bias tensor with length mismatched to B's last dimension is rejected. +// This reproduces a heap OOB read when bias is shorter than N. +TEST(MatMulIntegerToFloat, BiasShapeMismatch) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 128); + std::vector B_data(K * N, 128); + std::vector A_scale = {0.5f}; + std::vector B_scale = {0.5f}; + std::vector A_zero_point = {128}; + std::vector B_zero_point = {128}; + + // Bias has only 1 element but N=8. This must be rejected. + std::vector bad_bias = {1.0f}; + + OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + +// Test that a bias tensor with length larger than B's last dimension is rejected. +TEST(MatMulIntegerToFloat, BiasShapeMismatch_LargerBias) { + constexpr int64_t M = 2; + constexpr int64_t K = 4; + constexpr int64_t N = 8; + + std::vector A_data(M * K, 128); + std::vector B_data(K * N, 128); + std::vector A_scale = {0.5f}; + std::vector B_scale = {0.5f}; + std::vector A_zero_point = {128}; + std::vector B_zero_point = {128}; + + // Bias has length > N, which must be rejected. + std::vector bad_bias(static_cast(N + 1), 1.0f); + + OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N + 1}, bad_bias); + + test.AddOutput("Y", {M, N}, std::vector(M * N, 0.0f)); + + test.ConfigEp(DefaultCpuExecutionProvider()) + .Config(OpTester::ExpectResult::kExpectFailure, + "bias tensor's element count must equal B's last dimension") + .RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc index 3f50166438190..7d159e934c927 100644 --- a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc +++ b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc @@ -167,5 +167,31 @@ TEST(ContribOpTest, WordConvEmbedding_rejects_sequence_rank_one) { test.Run(OpTester::ExpectResult::kExpectFailure, "Sequence input must have rank greater than 1"); } +TEST(ContribOpTest, WordConvEmbedding_rejects_undersized_bias) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // W has 2 filters but B has only 1 element + test.AddInput("Sequence", {1, 2}, {1, 2}); + test.AddInput("W", {2, 1, 2, 1}, {1.0f, 1.0f, 1.0f, 1.0f}); + test.AddInput("B", {1}, {0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 2}, {0.0f, 0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "conv bias B must be a 1-D tensor of length 2"); +} + +TEST(ContribOpTest, WordConvEmbedding_rejects_2d_bias) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // B has correct element count but wrong rank + test.AddInput("Sequence", {1, 2}, {1, 2}); + test.AddInput("W", {2, 1, 2, 1}, {1.0f, 1.0f, 1.0f, 1.0f}); + test.AddInput("B", {1, 2}, {0.0f, 0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 2}, {0.0f, 0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "conv bias B must be a 1-D tensor of length 2"); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp new file mode 100644 index 0000000000000..bfdcb1d3c8cfc --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/cast_rvv_bench.cpp @@ -0,0 +1,165 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of FP16<->FP32 cast kernels. + + Scalar path: ORT's internal fallback in cast.cpp + (MLAS_Half2Float / MLAS_Float2Half loop when CastKernel == nullptr) + Dispatch path: MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer + (dispatches to registered RVV kernel via platform.CastF16ToF32Kernel) + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t count = 1024 * 64; + size_t iters = 200; + size_t warmup = 20; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--count") + options.count = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + const size_t N = opts.count; + + std::cout << "=== FP16<->FP32 Cast: RVV Dispatch vs ORT Scalar Fallback ===\n" + << " count=" << N << " iters=" << opts.iters << " warmup=" << opts.warmup << "\n\n"; + + std::vector fp32_src(N); + std::vector<_mlas_fp16_> fp16_src(N); + for (size_t i = 0; i < N; ++i) { + fp32_src[i] = MakeValue(i); + fp16_src[i] = MLAS_Float2Half(fp32_src[i]); + } + + std::vector f16_to_f32_fallback(N), f16_to_f32_dispatch(N); + std::vector<_mlas_fp16_> f32_to_f16_fallback(N), f32_to_f16_dispatch(N); + + // ORT scalar fallback: same as cast.cpp when CastF16ToF32Kernel == nullptr + // for (i) Destination[i] = Source[i].ToFloat(); // calls MLAS_Half2Float + auto fallback_h2f = [&]() { + for (size_t i = 0; i < N; ++i) + f16_to_f32_fallback[i] = MLAS_Half2Float(fp16_src[i]); + }; + auto fallback_f2h = [&]() { + for (size_t i = 0; i < N; ++i) + f32_to_f16_fallback[i] = MLAS_Float2Half(fp32_src[i]); + }; + + // ORT dispatch path: MlasConvertHalfToFloatBuffer (uses registered RVV kernel) + auto dispatch_h2f = [&]() { + MlasConvertHalfToFloatBuffer( + reinterpret_cast(fp16_src.data()), + f16_to_f32_dispatch.data(), N); + }; + auto dispatch_f2h = [&]() { + MlasConvertFloatToHalfBuffer( + fp32_src.data(), + reinterpret_cast(f32_to_f16_dispatch.data()), N); + }; + + // --- Correctness --- + fallback_h2f(); + dispatch_h2f(); + fallback_f2h(); + dispatch_f2h(); + + size_t h2f_mismatches = 0; + for (size_t i = 0; i < N; ++i) { + if (f16_to_f32_fallback[i] != f16_to_f32_dispatch[i]) h2f_mismatches++; + } + size_t f2h_mismatches = 0; + for (size_t i = 0; i < N; ++i) { + if (f32_to_f16_fallback[i] != f32_to_f16_dispatch[i]) f2h_mismatches++; + } + + std::cout << "Correctness:\n" + << " F16->F32: mismatches=" << h2f_mismatches << "/" << N + << (h2f_mismatches == 0 ? " PASS" : " FAIL") << "\n" + << " F32->F16: mismatches=" << f2h_mismatches << "/" << N + << (f2h_mismatches == 0 ? " PASS" : " FAIL") << "\n"; + + // --- Performance --- + for (size_t i = 0; i < opts.warmup; ++i) { + fallback_h2f(); + dispatch_h2f(); + fallback_f2h(); + dispatch_f2h(); + } + + double s_h2f = TimeLoop(opts.iters, fallback_h2f) / opts.iters; + double d_h2f = TimeLoop(opts.iters, dispatch_h2f) / opts.iters; + double s_f2h = TimeLoop(opts.iters, fallback_f2h) / opts.iters; + double d_f2h = TimeLoop(opts.iters, dispatch_f2h) / opts.iters; + + std::cout << std::fixed << std::setprecision(3) + << "\nF16->F32 (" << N << " elements):\n" + << " ORT Fallback: " << s_h2f << " ms\n" + << " ORT Dispatch: " << d_h2f << " ms\n" + << " Speedup: " << s_h2f / d_h2f << "x\n" + << "\nF32->F16 (" << N << " elements):\n" + << " ORT Fallback: " << s_f2h << " ms\n" + << " ORT Dispatch: " << d_f2h << " ms\n" + << " Speedup: " << s_f2h / d_f2h << "x\n"; + + return (h2f_mismatches + f2h_mismatches > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp new file mode 100644 index 0000000000000..0f74c4ec7017b --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/halfgemm_rvv_bench.cpp @@ -0,0 +1,253 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RVV-accelerated FP16 GEMM + against ORT's built-in scalar FP16 GEMM dispatch (MlasHalfGemmDispatchDefault). + + Both paths use the same MLAS HalfGemm dispatch interface with FP16 I/O. + + Usage: + ./onnxruntime_mlas_halfgemm_rvv_bench [--m=N] [--n=N] [--k=N] + [--iters=N] [--warmup=N] [--bias=0|1] + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" +#include "halfgemm.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t m = 64; + size_t n = 768; + size_t k = 768; + size_t iters = 20; + size_t warmup = 3; + bool use_bias = false; +}; + +void PrintUsage(const char* argv0) { + std::cout + << "Usage: " << argv0 + << " [--m=N] [--n=N] [--k=N] [--iters=N] [--warmup=N] [--bias=0|1]\n"; +} + +bool ParseBool(std::string_view value) { + return value == "1" || value == "true" || value == "on" || value == "yes"; +} + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + if (arg == "--help" || arg == "-h") { + PrintUsage(argv[0]); + std::exit(0); + } + const auto split = arg.find('='); + if (split == std::string_view::npos || split == 0 || split + 1 >= arg.size()) { + continue; + } + const std::string_view key = arg.substr(0, split); + const std::string_view value = arg.substr(split + 1); + if (key == "--m") + options.m = std::strtoull(value.data(), nullptr, 10); + else if (key == "--n") + options.n = std::strtoull(value.data(), nullptr, 10); + else if (key == "--k") + options.k = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + else if (key == "--bias") + options.use_bias = ParseBool(value); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + const uint32_t bucket = x % 2048u; + return (static_cast(bucket) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +void RunDispatch( + const MLAS_HALFGEMM_DISPATCH& dispatch, + size_t M, size_t N, size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* data) { + dispatch.Operation(N, K, data, 0, M, 0, N); +} + +} // namespace + +int main(int argc, char** argv) { + const Options options = ParseArgs(argc, argv); + + if (options.m == 0 || options.n == 0 || options.k == 0 || options.iters == 0) { + std::cerr << "m, n, k, and iters must be > 0\n"; + return 1; + } + + const bool fp16_supported = MlasFp16AccelerationSupported(); + + std::cout << "=== FP16 GEMM: RVV vs ORT Scalar Dispatch ===\n" + << " M=" << options.m << " N=" << options.n << " K=" << options.k + << " bias=" << (options.use_bias ? "yes" : "no") << "\n" + << " iters=" << options.iters << " warmup=" << options.warmup << "\n" + << " FP16 acceleration: " << (fp16_supported ? "YES (RVV)" : "NO") << "\n\n"; + + const size_t a_size = options.m * options.k; + const size_t b_size = options.k * options.n; + const size_t c_size = options.m * options.n; + + std::vector<_mlas_fp16_> a_fp16(a_size); + std::vector<_mlas_fp16_> b_fp16(b_size); + std::vector<_mlas_fp16_> bias_fp16(options.n); + std::vector<_mlas_fp16_> c_rvv(c_size); + std::vector<_mlas_fp16_> c_scalar(c_size); + + for (size_t i = 0; i < a_size; ++i) { + a_fp16[i] = MLAS_Float2Half(MakeValue(i) * 0.1f); + } + for (size_t i = 0; i < b_size; ++i) { + b_fp16[i] = MLAS_Float2Half(MakeValue(i + a_size) * 0.1f); + } + if (options.use_bias) { + for (size_t i = 0; i < options.n; ++i) { + bias_fp16[i] = MLAS_Float2Half(MakeValue(i + a_size + b_size) * 0.01f); + } + } + + MLAS_HALF_GEMM_DATA_PARAMS params_scalar; + params_scalar.A = a_fp16.data(); + params_scalar.lda = options.k; + params_scalar.B = b_fp16.data(); + params_scalar.ldb = options.n; + params_scalar.C = reinterpret_cast(c_scalar.data()); + params_scalar.ldc = options.n; + params_scalar.Bias = options.use_bias + ? reinterpret_cast(bias_fp16.data()) + : nullptr; + params_scalar.AIsfp32 = false; + params_scalar.BIsfp32 = false; + params_scalar.OutputProcessor = nullptr; + + MLAS_HALF_GEMM_DATA_PARAMS params_rvv = params_scalar; + params_rvv.C = reinterpret_cast(c_rvv.data()); + + // --- Run both dispatches --- + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_scalar); + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + RunDispatch(MlasHalfGemmDispatchRvv, options.m, options.n, options.k, ¶ms_rvv); +#else + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_rvv); + std::cout << " (RVV dispatch not available, comparing scalar vs scalar)\n\n"; +#endif + + // --- Correctness: RVV vs ORT Scalar --- + double max_abs_err = 0.0; + double max_rel_err = 0.0; + size_t error_count = 0; + + for (size_t i = 0; i < c_size; ++i) { + float ref = MLAS_Half2Float(c_scalar[i]); + float got = MLAS_Half2Float(c_rvv[i]); + double abs_err = std::abs(ref - got); + double rel_err = (std::abs(ref) > 1e-6) ? abs_err / std::abs(ref) : abs_err; + + if (abs_err > max_abs_err) max_abs_err = abs_err; + if (rel_err > max_rel_err) max_rel_err = rel_err; + + if (rel_err > 0.10 && abs_err > 0.005) { + if (error_count < 10) { + std::cerr << " MISMATCH [" << i / options.n << "," << i % options.n + << "]: scalar=" << ref << " rvv=" << got + << " abs=" << abs_err << " rel=" << rel_err << "\n"; + } + error_count++; + } + } + + std::cout << "Correctness (RVV vs ORT Scalar):\n" + << " max abs error: " << max_abs_err << "\n" + << " max rel error: " << max_rel_err << "\n" + << " mismatches (>10% rel && >0.005 abs): " << error_count + << " / " << c_size << "\n"; + + if (error_count > 0) { + std::cout << " STATUS: FAIL\n\n"; + } else { + std::cout << " STATUS: PASS\n\n"; + } + + // --- Performance --- + auto run_scalar_fn = [&]() { + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_scalar); + }; + + auto run_rvv_fn = [&]() { +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH) + RunDispatch(MlasHalfGemmDispatchRvv, options.m, options.n, options.k, ¶ms_rvv); +#else + RunDispatch(MlasHalfGemmDispatchDefault, options.m, options.n, options.k, ¶ms_rvv); +#endif + }; + + for (size_t i = 0; i < options.warmup; ++i) { + run_scalar_fn(); + run_rvv_fn(); + } + + const double scalar_ms = TimeLoop(options.iters, run_scalar_fn); + const double scalar_avg = scalar_ms / static_cast(options.iters); + + const double rvv_ms = TimeLoop(options.iters, run_rvv_fn); + const double rvv_avg = rvv_ms / static_cast(options.iters); + + const double flops = 2.0 * options.m * options.n * options.k; + const double scalar_gflops = flops / (scalar_avg * 1e6); + const double rvv_gflops = flops / (rvv_avg * 1e6); + const double speedup = scalar_avg / rvv_avg; + + std::cout << std::fixed << std::setprecision(3) + << "Performance:\n" + << " ORT Scalar: " << scalar_avg << " ms (" << scalar_gflops << " GFLOPS)\n" + << " RVV: " << rvv_avg << " ms (" << rvv_gflops << " GFLOPS)\n" + << " Speedup: " << speedup << "x\n"; + + return (error_count > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp new file mode 100644 index 0000000000000..df777744da4cb --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/rmsnorm_rvv_bench.cpp @@ -0,0 +1,169 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rmsnorm_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RMSNorm (SimplifiedLayerNorm). + + Scalar path: ORT's ComputeJob with simplified=true + (anonymous namespace in layer_norm_impl.cc, reproduced here verbatim) + MLAS path: MlasLayerNormF32 dispatch (uses RVV kernel when available) + +--*/ + +#include "mlas.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t hidden = 1024; + size_t iters = 500; + size_t warmup = 50; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--hidden") + options.hidden = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +// +// ORT scalar path: verbatim from layer_norm_impl.cc ComputeJob +// with simplified=true (RMSNorm). +// +void OrtRmsNormScalar( + const float* input, + const float* scale, + size_t norm_size, + float epsilon, + float* output) { + float mean_square = 0.0f; + for (size_t h = 0; h < norm_size; h++) { + output[h] = input[h]; + mean_square += input[h] * input[h]; + } + mean_square = sqrtf(mean_square / static_cast(norm_size) + epsilon); + for (size_t h = 0; h < norm_size; h++) { + output[h] = output[h] / mean_square * scale[h]; + } +} + +void OrtRmsNormMlas( + const float* input, + const float* scale, + size_t norm_size, + float epsilon, + float* output) { + if (!MlasLayerNormF32(input, scale, nullptr, output, nullptr, nullptr, + norm_size, epsilon, true)) { + OrtRmsNormScalar(input, scale, norm_size, epsilon, output); + } +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + const size_t N = opts.hidden; + const float epsilon = 1e-6f; + + std::cout << "=== RMSNorm: MLAS Dispatch vs ORT Scalar ===\n" + << " hidden=" << N << " iters=" << opts.iters << " warmup=" << opts.warmup << "\n\n"; + + std::vector input(N), scale(N); + std::vector out_scalar(N), out_rvv(N); + + for (size_t i = 0; i < N; i++) { + input[i] = MakeValue(i) * 0.1f; + scale[i] = 1.0f + MakeValue(i + N) * 0.01f; + } + + // --- Correctness --- + OrtRmsNormScalar(input.data(), scale.data(), N, epsilon, out_scalar.data()); + OrtRmsNormMlas(input.data(), scale.data(), N, epsilon, out_rvv.data()); + + double max_abs = 0.0, max_rel = 0.0; + size_t mismatches = 0; + for (size_t i = 0; i < N; i++) { + double abs_err = std::abs(out_scalar[i] - out_rvv[i]); + double rel_err = (std::abs(out_scalar[i]) > 1e-7) ? abs_err / std::abs(out_scalar[i]) : abs_err; + if (abs_err > max_abs) max_abs = abs_err; + if (rel_err > max_rel) max_rel = rel_err; + if (abs_err > 1e-5) mismatches++; + } + + std::cout << "Correctness:\n" + << " max_abs=" << max_abs << " max_rel=" << max_rel + << " mismatches=" << mismatches << "/" << N + << (mismatches == 0 ? " PASS" : " FAIL") << "\n"; + + // --- Performance --- + auto run_scalar = [&]() { + OrtRmsNormScalar(input.data(), scale.data(), N, epsilon, out_scalar.data()); + }; + auto run_rvv = [&]() { + OrtRmsNormMlas(input.data(), scale.data(), N, epsilon, out_rvv.data()); + }; + + for (size_t i = 0; i < opts.warmup; i++) { + run_scalar(); + run_rvv(); + } + + double scalar_ms = TimeLoop(opts.iters, run_scalar) / opts.iters; + double rvv_ms = TimeLoop(opts.iters, run_rvv) / opts.iters; + + std::cout << std::fixed << std::setprecision(4) + << "\nPerformance:\n" + << " ORT Scalar: " << scalar_ms * 1000 << " us\n" + << " RVV: " << rvv_ms * 1000 << " us\n" + << " Speedup: " << scalar_ms / rvv_ms << "x\n"; + + return (mismatches > 0) ? 1 : 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp new file mode 100644 index 0000000000000..cbe6941bf8ca7 --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/rope_rvv_bench.cpp @@ -0,0 +1,152 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rope_rvv_bench.cpp + +Abstract: + + Correctness and performance comparison of RotaryEmbedding. + + Scalar path: MlasRotaryEmbedOneRow_FallBack (ORT's internal scalar fallback) + Dispatch path: MlasRotaryEmbedOneRow (dispatches to RVV kernel via platform) + +--*/ + +#include "mlas.h" +#include "mlas_float16.h" +#include "rotary_embedding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t dim = 128; + size_t iters = 500; + size_t warmup = 50; +}; + +Options ParseArgs(int argc, char** argv) { + Options options; + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + const auto split = arg.find('='); + if (split == std::string_view::npos) continue; + const auto key = arg.substr(0, split); + const auto value = arg.substr(split + 1); + if (key == "--dim") + options.dim = std::strtoull(value.data(), nullptr, 10); + else if (key == "--iters") + options.iters = std::strtoull(value.data(), nullptr, 10); + else if (key == "--warmup") + options.warmup = std::strtoull(value.data(), nullptr, 10); + } + return options; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + return (static_cast(x % 2048u) / 1024.0f) - 1.0f; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +void CompareResults(const float* ref, const float* got, size_t n) { + double max_abs = 0.0, max_rel = 0.0; + size_t mismatches = 0; + for (size_t i = 0; i < n; i++) { + double abs_err = std::abs(ref[i] - got[i]); + double rel_err = (std::abs(ref[i]) > 1e-7) ? abs_err / std::abs(ref[i]) : abs_err; + if (abs_err > max_abs) max_abs = abs_err; + if (rel_err > max_rel) max_rel = rel_err; + if (abs_err > 1e-5) mismatches++; + } + std::cout << " max_abs=" << max_abs << " max_rel=" << max_rel + << " mismatches=" << mismatches << "/" << n + << (mismatches == 0 ? " PASS" : " FAIL") << "\n"; +} + +void BenchRoPE(const char* label, size_t dim, bool interleaved, size_t iters, size_t warmup) { + if (dim % 2 != 0) { + std::cerr << "Error: dim must be even, got " << dim << "\n"; + return; + } + const size_t half = dim / 2; + + std::vector input(dim), sin_data(half), cos_data(half); + std::vector out_fallback(dim), out_dispatch(dim); + + for (size_t i = 0; i < dim; i++) input[i] = MakeValue(i); + for (size_t i = 0; i < half; i++) { + sin_data[i] = sinf(static_cast(i) * 0.01f); + cos_data[i] = cosf(static_cast(i) * 0.01f); + } + + // ORT scalar fallback + MlasRotaryEmbedOneRow_FallBack( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_fallback.data()); + // ORT dispatch (→ RVV) + MlasRotaryEmbedOneRow( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_dispatch.data()); + + std::cout << "--- " << label << " (dim=" << dim << ") ---\n"; + CompareResults(out_fallback.data(), out_dispatch.data(), dim); + + auto run_fallback = [&]() { + MlasRotaryEmbedOneRow_FallBack( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_fallback.data()); + }; + auto run_dispatch = [&]() { + MlasRotaryEmbedOneRow( + input.data(), sin_data.data(), cos_data.data(), dim, interleaved, out_dispatch.data()); + }; + + for (size_t i = 0; i < warmup; i++) { + run_fallback(); + run_dispatch(); + } + + double fallback_ms = TimeLoop(iters, run_fallback) / iters; + double dispatch_ms = TimeLoop(iters, run_dispatch) / iters; + + std::cout << std::fixed << std::setprecision(4) + << " ORT Fallback: " << fallback_ms * 1000 << " us\n" + << " ORT Dispatch: " << dispatch_ms * 1000 << " us\n" + << " Speedup: " << fallback_ms / dispatch_ms << "x\n\n"; +} + +} // namespace + +int main(int argc, char** argv) { + const Options opts = ParseArgs(argc, argv); + + std::cout << "=== RotaryEmbedding: RVV Dispatch vs ORT Scalar Fallback ===\n\n"; + + BenchRoPE("RoPE non-interleaved", opts.dim, false, opts.iters, opts.warmup); + BenchRoPE("RoPE interleaved", opts.dim, true, opts.iters, opts.warmup); + + return 0; +} diff --git a/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp b/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp new file mode 100644 index 0000000000000..1b8126b384f1e --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_cast_fp16.cpp @@ -0,0 +1,120 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_cast_fp16.cpp + +Abstract: + + Tests for MLAS FP16<->FP32 cast kernels. + Verifies bit-exactness against MLAS_Half2Float / MLAS_Float2Half. + +--*/ + +#include "test_util.h" +#include "mlas.h" +#include "mlas_float16.h" + +#include + +class MlasCastFp16Test : public MlasTestBase { + public: + void TestF16ToF32(size_t count) { + std::vector<_mlas_fp16_> input(count); + std::vector output_ref(count); + std::vector output_dispatch(count); + + for (size_t i = 0; i < count; i++) { + float val = (static_cast(i % 2048) / 1024.0f) - 1.0f; + input[i] = MLAS_Float2Half(val); + output_ref[i] = MLAS_Half2Float(input[i]); + } + + MlasConvertHalfToFloatBuffer( + reinterpret_cast(input.data()), + output_dispatch.data(), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(output_dispatch[i], output_ref[i]) + << "F16->F32 mismatch at [" << i << "], count=" << count; + } + } + + void TestF32ToF16(size_t count) { + std::vector input(count); + std::vector<_mlas_fp16_> output_ref(count); + std::vector<_mlas_fp16_> output_dispatch(count); + + for (size_t i = 0; i < count; i++) { + input[i] = (static_cast(i % 2048) / 1024.0f) - 1.0f; + output_ref[i] = MLAS_Float2Half(input[i]); + } + + MlasConvertFloatToHalfBuffer( + input.data(), + reinterpret_cast(output_dispatch.data()), count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(output_dispatch[i], output_ref[i]) + << "F32->F16 mismatch at [" << i << "], count=" << count; + } + } +}; + +class CastFp16ShortExecuteTest : public MlasTestFixture { + public: + CastFp16ShortExecuteTest(size_t count, bool f16_to_f32) + : count_(count), f16_to_f32_(f16_to_f32) {} + + void TestBody() override { + if (f16_to_f32_) { + MlasTestFixture::mlas_tester->TestF16ToF32(count_); + } else { + MlasTestFixture::mlas_tester->TestF32ToF16(count_); + } + } + + static size_t RegisterSingleTest(size_t count, bool f16_to_f32) { + std::stringstream ss; + ss << "/" << (f16_to_f32 ? "F16toF32" : "F32toF16") + << "/count" << count; + auto test_name = ss.str(); + + testing::RegisterTest( + "CastFp16", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new CastFp16ShortExecuteTest(count, f16_to_f32); + }); + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t cnt = 0; + for (size_t n : {1, 7, 15, 16, 31, 32, 63, 64, 128, 255, 256, 1024, 65536}) { + cnt += RegisterSingleTest(n, true); + cnt += RegisterSingleTest(n, false); + } + return cnt; + } + + private: + size_t count_; + bool f16_to_f32_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return CastFp16ShortExecuteTest::RegisterShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/mlas/unittest/test_layernorm.cpp b/onnxruntime/test/mlas/unittest/test_layernorm.cpp new file mode 100644 index 0000000000000..7475f082bb443 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_layernorm.cpp @@ -0,0 +1,157 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_layernorm.cpp + +Abstract: + + Tests for MLAS LayerNorm/RMSNorm (MlasLayerNormF32). + +--*/ + +#include "test_util.h" +#include "mlas.h" + +#include +#include + +class MlasLayerNormTest : public MlasTestBase { + private: + void ScalarLayerNorm( + const float* input, + const float* scale, + const float* bias, + float* output, + float* mean_out, + float* inv_std_out, + size_t norm_size, + float epsilon, + bool simplified) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (size_t i = 0; i < norm_size; i++) { + sum += input[i]; + sum_sq += input[i] * input[i]; + } + float mean = sum / static_cast(norm_size); + float denom; + if (simplified) { + denom = std::sqrt(sum_sq / static_cast(norm_size) + epsilon); + } else { + denom = std::sqrt(sum_sq / static_cast(norm_size) - mean * mean + epsilon); + } + float inv_denom = 1.0f / denom; + + for (size_t i = 0; i < norm_size; i++) { + if (simplified) { + output[i] = input[i] * inv_denom * scale[i]; + } else if (bias == nullptr) { + output[i] = (input[i] - mean) * inv_denom * scale[i]; + } else { + output[i] = (input[i] - mean) * inv_denom * scale[i] + bias[i]; + } + } + if (mean_out) *mean_out = mean; + if (inv_std_out) *inv_std_out = inv_denom; + } + + public: + void Test(size_t norm_size, bool simplified, bool with_bias) { + std::vector input(norm_size); + std::vector scale(norm_size); + std::vector bias(norm_size); + std::vector output_ref(norm_size); + std::vector output_mlas(norm_size); + float mean_ref = 0, mean_mlas = 0; + float inv_std_ref = 0, inv_std_mlas = 0; + + for (size_t i = 0; i < norm_size; i++) { + input[i] = (static_cast(i % 127) - 63.0f) * 0.01f; + scale[i] = 1.0f + (static_cast(i % 31) - 15.0f) * 0.001f; + bias[i] = (static_cast(i % 17) - 8.0f) * 0.005f; + } + + const float* bias_ptr = (with_bias && !simplified) ? bias.data() : nullptr; + + ScalarLayerNorm(input.data(), scale.data(), bias_ptr, + output_ref.data(), &mean_ref, &inv_std_ref, + norm_size, 1e-5f, simplified); + + bool used = MlasLayerNormF32(input.data(), scale.data(), bias_ptr, + output_mlas.data(), &mean_mlas, &inv_std_mlas, + norm_size, 1e-5f, simplified); + + if (!used) { + // No optimized kernel available, skip comparison + return; + } + + for (size_t i = 0; i < norm_size; i++) { + ASSERT_NEAR(output_mlas[i], output_ref[i], 1e-4f) + << "output mismatch at [" << i << "], norm_size=" << norm_size + << " simplified=" << simplified << " bias=" << with_bias; + } + ASSERT_NEAR(mean_mlas, mean_ref, 1e-4f) << "mean mismatch"; + ASSERT_NEAR(inv_std_mlas, inv_std_ref, 1e-4f) << "inv_std_dev mismatch"; + } +}; + +class LayerNormShortExecuteTest : public MlasTestFixture { + public: + LayerNormShortExecuteTest(size_t norm_size, bool simplified, bool with_bias) + : norm_size_(norm_size), simplified_(simplified), with_bias_(with_bias) {} + + void TestBody() override { + MlasTestFixture::mlas_tester->Test(norm_size_, simplified_, with_bias_); + } + + static size_t RegisterSingleTest(size_t norm_size, bool simplified, bool with_bias) { + std::stringstream ss; + ss << "/norm_size" << norm_size + << "/simplified" << simplified + << "/bias" << with_bias; + auto test_name = ss.str(); + + testing::RegisterTest( + "LayerNorm", + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new LayerNormShortExecuteTest(norm_size, simplified, with_bias); + }); + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t count = 0; + for (size_t n : {1, 7, 32, 63, 64, 127, 128, 256, 1024}) { + for (bool simplified : {true, false}) { + for (bool with_bias : {true, false}) { + count += RegisterSingleTest(n, simplified, with_bias); + } + } + } + return count; + } + + private: + size_t norm_size_; + bool simplified_; + bool with_bias_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return LayerNormShortExecuteTest::RegisterShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/mlas/unittest/test_rope.cpp b/onnxruntime/test/mlas/unittest/test_rope.cpp index eeb369224d523..c6e4b3d6545ae 100644 --- a/onnxruntime/test/mlas/unittest/test_rope.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope.cpp @@ -124,8 +124,8 @@ class RoPEShortExecuteTest : public MlasTestFixture> { bool interleaved_; }; -// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. -#ifdef MLAS_TARGET_AMD64 +// Enable RoPE tests on platforms where RopeDispatch is assigned. +#if defined(MLAS_TARGET_AMD64) || (defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV)) static size_t RoPERegisterAllShortExecuteTests() { return RoPEShortExecuteTest::RegisterShortExecuteTests() + RoPEShortExecuteTest::RegisterShortExecuteTests(); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 0e4ab5c2d3b73..924ceaa19a47d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5027,6 +5027,96 @@ TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDim) { EXPECT_EQ(y_shape->dim(2).dim_value(), 3); } +// Execution regression test: a chained Reshape with allowzero=1 on a zero-element tensor +// must produce the correct output shape at runtime. +// Input X: float[0, 8, 2] -> Reshape([4, 2, -1]) -> mid -> Reshape([0, 0, 4], allowzero=1) -> Y +// Expected Y shape: (0, 0, 4). Without the fix FuseContiguousReshapes would collapse the +// two nodes into one (losing allowzero=1) and emit (0, 8, 4) instead. +// See https://github.com/microsoft/onnxruntime/issues/28348. +TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDimExecution) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 18; + Model model("ReshapeFusionContiguousReshapesWithZeroDimExecution", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), *logger_); + auto& graph = model.MainGraph(); + + // X: float[0, 8, 2] + TypeProto x_type; + x_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(0); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(8); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + TypeProto y_type; + y_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + auto& X = graph.GetOrCreateNodeArg("X", &x_type); + auto& mid = graph.GetOrCreateNodeArg("mid", &y_type); + auto& Y = graph.GetOrCreateNodeArg("Y", &y_type); + + // shape1 = [4, 2, -1] -> mid shape (4, 2, 0) + ONNX_NAMESPACE::TensorProto shape1_proto; + shape1_proto.set_name("shape1"); + shape1_proto.set_data_type(TensorProto_DataType_INT64); + shape1_proto.add_dims(3); + for (int64_t v : {4, 2, -1}) shape1_proto.add_int64_data(v); + graph.AddInitializedTensor(shape1_proto); + + // shape2 = [0, 0, 4] with allowzero=1 -> Y shape (0, 0, 4) + ONNX_NAMESPACE::TensorProto shape2_proto; + shape2_proto.set_name("shape2"); + shape2_proto.set_data_type(TensorProto_DataType_INT64); + shape2_proto.add_dims(3); + for (int64_t v : {0, 0, 4}) shape2_proto.add_int64_data(v); + graph.AddInitializedTensor(shape2_proto); + + auto& shape1 = graph.GetOrCreateNodeArg("shape1", nullptr); + auto& shape2 = graph.GetOrCreateNodeArg("shape2", nullptr); + + graph.AddNode("reshape1", "Reshape", "first reshape", {&X, &shape1}, {&mid}); + auto& reshape2 = graph.AddNode("reshape2", "Reshape", "second reshape (allowzero=1)", + {&mid, &shape2}, {&Y}); + reshape2.AddAttribute("allowzero", static_cast(1)); + + graph.SetInputs({&X}); + graph.SetOutputs({&Y}); + + ASSERT_STATUS_OK(graph.Resolve()); + + // Serialize and run via InferenceSession to exercise the full execution path. + auto model_proto = model.ToProto(); + std::string serialized_model; + ASSERT_TRUE(model_proto.SerializeToString(&serialized_model)); + + SessionOptions so; + InferenceSession session_object{so, GetEnvironment()}; + std::stringstream model_stream(serialized_model); + ASSERT_STATUS_OK(session_object.Load(model_stream)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Input: zero-element float tensor with shape [0, 8, 2]. + OrtValue input_val; + std::vector input_dims = {0, 8, 2}; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + input_dims, std::vector(), &input_val); + + NameMLValMap feeds = {{"X", input_val}}; + std::vector output_names = {"Y"}; + std::vector fetches; + RunOptions run_options; + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); + + // Output shape must be (0, 0, 4), not (0, 8, 4). + ASSERT_EQ(fetches.size(), 1U); + const auto& output_tensor = fetches[0].Get(); + const TensorShape& output_shape = output_tensor.Shape(); + ASSERT_EQ(output_shape.NumDimensions(), 3U); + EXPECT_EQ(output_shape[0], 0); + EXPECT_EQ(output_shape[1], 0); + EXPECT_EQ(output_shape[2], 4); +} + TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index da933464bb66b..bdbd2c488584d 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -4058,6 +4058,95 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward) { #endif } +// Regression test for GitHub issue #28491. +// When a DQ node's data input is a constant (graph initializer or Constant op output), +// PropagateDQForward must not insert a Q -> DQ pair downstream of a reshape-like node. +// Doing so can cause subsequent S8-to-U8 weight transformers to flip the DQ dtype while +// leaving the inserted Q node in its original dtype, clamping int8 negatives to zero. +TEST(QDQTransformerTests, QDQPropagation_DQForward_ConstantInput_NoPropagation) { + // Case 1: DQ data input is a graph initializer. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + // int8 constant weight as a graph initializer + auto* weight = builder.MakeInitializer({4}, {-10, 0, 10, 20}); + auto* output_arg = builder.MakeOutput(); + + // DQ node that dequantizes the constant weight + constexpr float qdq_scale = 0.1f; + constexpr int8_t qdq_zero_point = 0; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, qdq_scale, qdq_zero_point, dq_output); + + // Reshape downstream of DQ + auto* reshape_shape = builder.Make1DInitializer({2, 2}); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const auto op_types = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + // No Q or DQ should have been inserted after Reshape. + // Expected order: DequantizeLinear -> Reshape (no trailing Q/DQ). + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + const std::vector expected{qdq_keys.dequantize_linear, "Reshape"}; + EXPECT_EQ(op_types, expected); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 12); + } + + // Case 2: DQ data input is the output of a Constant op node. + // Run QDQPropagationTransformer directly (bypassing ConstantFolding) so the + // Constant op node is still present when PropagateDQForward evaluates it. + // Using TransformerTester would fold the Constant into an initializer first, + // masking the is_constant_op_output code path under test. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* output_arg = builder.MakeOutput(); + + // Create a Constant op node that produces an int8 tensor. + ONNX_NAMESPACE::TensorProto constant_tensor; + constant_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); + constant_tensor.add_dims(4); + const std::vector raw_vals = {-10, 0, 10, 20}; + constant_tensor.set_raw_data(raw_vals.data(), raw_vals.size() * sizeof(int8_t)); + + auto* constant_output = builder.MakeIntermediate(); + constant_tensor.set_name(constant_output->Name()); + builder.AddNode("Constant", {}, {constant_output}).AddAttribute("value", constant_tensor); + + // DQ node that dequantizes the Constant op output + constexpr float qdq_scale = 0.1f; + constexpr int8_t qdq_zero_point = 0; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(constant_output, qdq_scale, qdq_zero_point, dq_output); + + // Reshape downstream of DQ + auto* reshape_shape = builder.Make1DInitializer({2, 2}); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {output_arg}); + }; + + // post_graph_checker runs on Graph& directly, after only QDQPropagationTransformer. + // QuantizeLinear must not have been inserted anywhere. + auto post_graph_checker = [&](Graph& graph) -> Status { + const auto op_counts = CountOpsInGraph(graph); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + TEST_RETURN_IF_NOT(op_counts.count(qdq_keys.quantize_linear) == 0 || + op_counts.at(qdq_keys.quantize_linear) == 0); + return Status::OK(); + }; + + const auto& logger = DefaultLoggingManager().DefaultLogger(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, logger, + std::make_unique(), + TransformerLevel::Level1, 1, + nullptr, post_graph_checker)); + } +} + TEST(QDQTransformerTests, QDQPropagation_StopAtOtherQDQ) { auto test_case = [&](const std::vector& input_shape, bool same_scale, bool same_zp, bool use_contrib_qdq) { diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 94b41149a32d1..13b75f3c6e4fa 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -241,9 +241,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { } TEST(CoreMLExecutionProviderTest, GatherWithScalarIndices) { - // For scalar inputs, the input shape is modified from [] -> [1] before passing the input to CoreML. - // This won't work for Gather because the output shape depends on the `indices` input shape which could be a scalar. - // Currently, we expect the CoreML EP to only take the Shape node in this graph (Gather -> Shape). + // The CoreML EP supports scalar 'indices' for Gather only when the 'data' input has a fully + // static shape (it needs to claim a static intermediate shape for the post-gather squeeze). + // This model's 'data' input is dynamic ([M, N, K]) so Gather still falls back to CPU and the + // CoreML EP only takes the Shape node. const auto model_file_name = ORT_TSTR("testdata/gather_with_scalar_indices_then_shape.onnx"); #if defined(__APPLE__) @@ -2359,6 +2360,656 @@ TEST(CoreMLExecutionProviderTest, Split11SingleOutputNotSupported) { TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); } +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesAxis1) { + // ai.onnx:Gather with rank-0 (scalar) 'indices'. ONNX output rank = + // data_rank + indices_rank - 1 = 2. The CoreML builder internally promotes + // indices to [1], runs gather, then squeezes the inserted axis. Pattern + // produced by StyleGAN-family generators (e.g. GFPGAN) that pick a + // per-layer style code with a scalar index. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_axis1", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {1, 4, 8} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(1); + data_shape->add_dim()->set_dim_value(4); + data_shape->add_dim()->set_dim_value(8); + + // output Y: {1, 8} + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(8); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + // Scalar int64 index with value 2. + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + // No dims => rank-0 tensor. + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar", "Gather", "Gather with scalar indices", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 4, 8}; + std::vector input_data(1 * 4 * 8); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.25f - 1.0f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis1_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis1_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesAxis0) { + // Scalar Gather along axis 0 — squeeze axis is 0; covers a different + // squeeze position than the axis=1 test. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_axis0", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {6, 5} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(6); + data_shape->add_dim()->set_dim_value(5); + + // output Y: {5} + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(5); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(4); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_axis0", "Gather", "Gather scalar idx axis=0", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {6, 5}; + std::vector input_data(6 * 5); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) - 12.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis0_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesAxis0_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesNegativeAxis) { + // Scalar Gather with negative axis (-1) — verifies HandleNegativeAxis is + // applied when computing the squeeze axis. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_negative_axis", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // data X: {2, 3, 4} float + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(2); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + // output Y: {2, 3} (axis=-1 == axis 2; output drops that axis) + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(2); + output_shape->add_dim()->set_dim_value(3); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(1); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_neg_axis", "Gather", "Gather scalar idx axis=-1", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(-1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {2, 3, 4}; + std::vector input_data(2 * 3 * 4); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesNegativeAxis_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesNegativeAxis_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesFloat16) { + // FLOAT16 'data' input. HasSupportedInputsImpl restricts fp16 Gather to + // MLProgram on CoreML 6+, so this test only runs the MLProgram path. + // Exercises the MLFloat16 branch of the static intermediate shape claim. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_fp16", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(1); + data_shape->add_dim()->set_dim_value(4); + data_shape->add_dim()->set_dim_value(8); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(8); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_fp16", "Gather", "Gather scalar idx fp16 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 4, 8}; + std::vector input_data; + input_data.reserve(1 * 4 * 8); + for (size_t i = 0; i < 1 * 4 * 8; ++i) { + input_data.emplace_back(static_cast(i) * 0.25f - 1.0f); + } + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesFloat16_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesInt64Data) { + // INT64 'data' input. HasSupportedInputsImpl allows int64 in both NN and + // MLProgram; verify both formats correctly route int64 through the + // expand/gather/squeeze chain. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_int64_data", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(1); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_int64", "Gather", "Gather scalar idx int64 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {3, 4}; + std::vector input_data; + input_data.reserve(3 * 4); + for (int64_t i = 0; i < 3 * 4; ++i) input_data.push_back(i * 1000 - 5000); + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt64Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt64Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesInt32Indices) { + // INT32 'indices'. The other scalar-indices tests use INT64 indices (the + // PyTorch default); this one exercises the INT32 branch through both the + // dtype gating in IsOpSupportedImpl and the indices_dtype path-through to + // the reshape's intermediate output dtype in AddToModelBuilderImpl. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_int32_indices", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_value(3); + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + idx_init.add_int32_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_int32_idx", "Gather", "Gather scalar int32 idx", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {3, 4}; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt32Indices_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesInt32Indices_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank4Data) { + // Rank-4 'data' input — the supported maximum for scalar Gather (the + // pre-squeeze intermediate is rank 4; CoreML's compiler rejects scalar + // Gather at rank 5 with "Invalid rank: 6"). Output is rank 3. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank4", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + for (int64_t d : {2, 5, 3, 4}) data_shape->add_dim()->set_dim_value(d); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + // Gather on axis=1 with scalar idx removes that axis: {2,3,4} + for (int64_t d : {2, 3, 4}) output_shape->add_dim()->set_dim_value(d); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(3); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank4", "Gather", "Gather scalar idx rank-4 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {2, 5, 3, 4}; + std::vector input_data(2 * 5 * 3 * 4); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) * 0.1f - 5.0f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank4Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank4Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank1Data) { + // Rank-1 'data' input with scalar indices — output is rank-0 (the pre-squeeze + // intermediate is rank 1, squeezed to a scalar). Confirms CoreML actually + // produces a rank-0 result on both NN and MLProgram paths. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank1", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + data_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(6); + + // Output is rank-0: TypeProto with a shape that has no dims. + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape(); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank1", "Gather", "Gather scalar idx rank-1 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(0)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {6}; + std::vector input_data(6); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = static_cast(i) - 2.5f; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank1Data_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "GatherScalarIndicesRank1Data_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesDynamicDataNotSupported) { + // The scalar-indices path emits a reshape-+squeeze chain whose intermediate + // shape we have to claim statically. IsOpSupportedImpl rejects the node + // when 'data' has any unknown dim so it falls back to CPU rather than + // produce an ill-formed CoreML program. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_dynamic_data", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + data_shape->add_dim()->set_dim_param("N"); // dynamic leading dim + data_shape->add_dim()->set_dim_value(4); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("N"); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(0); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_dyn", "Gather", "Gather scalar idx, dynamic data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + +TEST(CoreMLExecutionProviderTest, GatherScalarIndicesRank5DataNotSupported) { + // Scalar-indices Gather caps data rank at 4 (CoreML compiler reports + // "Invalid rank: 6" on the rank-5 reshape+gather intermediate). Rank-5 + // 'data' must fall back to CPU. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("gather_scalar_indices_rank5", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto data_type; + data_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* data_shape = data_type.mutable_tensor_type()->mutable_shape(); + for (int64_t d : {2, 3, 4, 5, 6}) data_shape->add_dim()->set_dim_value(d); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + // axis=2 with scalar idx removes that axis: {2,3,5,6} + for (int64_t d : {2, 3, 5, 6}) output_shape->add_dim()->set_dim_value(d); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &data_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &output_type); + + ONNX_NAMESPACE::TensorProto idx_init; + idx_init.set_name("idx"); + idx_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + idx_init.add_int64_data(2); + graph.AddInitializedTensor(idx_init); + auto& idx_arg = graph.GetOrCreateNodeArg("idx", nullptr); + + auto& node = graph.AddNode("gather_scalar_rank5", "Gather", "Gather scalar idx rank-5 data", + {&input_arg, &idx_arg}, {&output_arg}); + node.AddAttribute("axis", static_cast(2)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + #endif // !(ORT_MINIMAL_BUILD) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index a54c35accbdc7..2e34fd58a2628 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -349,14 +349,14 @@ TEST(NvExecutionProviderTest, LoadUnloadPluginLibrary) { size_t num_devices = 0; ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); - // should be one device for the example EP + // should be at least one device for the example EP auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, [®istration_name, &c_api](const OrtEpDevice* device) { // the example uses the registration name for the EP name // but that is not a requirement and the two can differ. return c_api->EpDevice_EpName(device) == registration_name; }); - ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + ASSERT_GE(num_test_ep_devices, 1) << "Expected at least one OrtEpDevice to have been created by the test library."; // and this should unload it ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env,