diff --git a/.github/configs/ascend.yml b/.github/configs/ascend.yml new file mode 100644 index 0000000000..03fc5acaf5 --- /dev/null +++ b/.github/configs/ascend.yml @@ -0,0 +1,15 @@ +# Huawei Ascend NPU configuration +image: ascend-infer:ubuntu18.04 +labels: + - npu + - ascend +docker_options: | + --device /dev/davinci0 + --device /dev/davinci1 + --device /dev/davinci2 + --device /dev/davinci3 + --device /dev/davinci_manager + --device /dev/devmm_svm + --device /dev/hisi_hdc + --volume /usr/local/Ascend/driver:/usr/local/Ascend/driver + --volume /usr/local/Ascend/add-ons:/usr/local/Ascend/add-ons \ No newline at end of file diff --git a/.github/configs/cuda.yml b/.github/configs/cuda.yml new file mode 100644 index 0000000000..1c77fe6c25 --- /dev/null +++ b/.github/configs/cuda.yml @@ -0,0 +1,70 @@ +# CUDA Hardware Configuration for TransformerEngine-FL +# Refactored for A100 Nodes +# This file defines environment variables, volumes, and test filters for TE tests. + +hardware_name: cuda +display_name: 'NVIDIA CUDA (A100)' + +# CI image for online env +ci_image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + +# Runner labels for self-hosted A100 node +# runner_labels: +# - self-hosted +# - Linux +# - X64 +# - nvidia +# - gpu-8 + +# Runner labels for online env +runner_labels: + - nv-8g-cicd-te + +# Container volumes +container_volumes: + - /home/flagscale_cicd/flask/static:/workspace/report + +# Container options +container_options: >- + --privileged + --gpus all + --shm-size=500g + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --user root + +# Platform-specific environment setup script +setup_script: .github/scripts/setup_cuda.sh + +# Build environment variables (platform-specific) +build_env: + TE_FL_SKIP_CUDA: '0' + SKIP_CUDA_BUILD: '0' + NVTE_WITH_CUDA: '1' + NVTE_WITH_MACA: '0' + TE_WITH_NCCL: '1' + NVTE_FRAMEWORK: pytorch + CUDA_HOME: /usr/local/cuda-12.8 + NVCC: /usr/local/cuda-12.8/bin/nvcc + +# Device types to run tests on +device_types: + - a100 + +# Test matrix configuration +test_matrix: + l0_pytorch: + path: 'qa/L0_pytorch_unittest/test.sh' + ignored_tests: + - test_sanity_layernorm_mlp + - test_sanity_gpt + - test_sanity_bert + - test_sanity_T5 + - test_sanity_amp_and_nvfuser + - test_sanity_drop_path + - test_layernorm_mlp_accuracy + - test_grouped_linear_accuracy + - test_gpt_accuracy + - test_basic_linear + - test_layer_norm diff --git a/.github/configs/metax.yml b/.github/configs/metax.yml new file mode 100644 index 0000000000..00b4e1df34 --- /dev/null +++ b/.github/configs/metax.yml @@ -0,0 +1,68 @@ +# Metax Hardware Configuration for TE-FL +# This file defines CI/CD settings for Metax-based testing +# This file defines environment variables, volumes, and test filters for TE tests. + +hardware_name: metax +display_name: 'Metax Tests' + +# CI image for Metax dev env +# ci_image: localhost:5000/megatron-lm-with-te:v1 + +# CI image for online env +ci_image: harbor.baai.ac.cn/flagscale/megatron-lm-with-te:202603231839 + +# Runner labels for self-hosted Metax node +# runner_labels: +# - self-hosted +# - Linux +# - X64 +# - metax +# - dev + +# Runner labels for online env +runner_labels: + - mx-4g-cicd-te + +# Container volumes +container_volumes: + - /nfs/metax_fs:/nfs/metax_fs + +# Container options +container_options: >- + --uts=host + --ipc=host + --privileged=true + --group-add video + --shm-size=100gb + --ulimit memlock=-1 + --user root + --ulimit nofile=65535:65535 + -e PLATFORM=metax + -e TORCH_DISTRIBUTED_BACKEND=mccl + -e LD_LIBRARY_PATH=/opt/maca/lib:/usr/local/lib:$LD_LIBRARY_PATH + +# Platform-specific environment setup script +setup_script: .github/scripts/setup_metax.sh + +# Build environment variables (platform-specific) +build_env: + TE_FL_SKIP_CUDA: '1' + NVTE_WITH_MACA: '1' + CUDA_HOME: /opt/maca + MACA_HOME: /opt/maca + +# Device types to run tests on +device_types: + - c500 + +# Test matrix configuration +test_matrix: + unit: + devices: + - c500 + # Ignored test files for unit tests + # These files will be skipped when running pytest + ignored_tests: + # example: tests/unit_tests/test_example.py + # - tests/unit_tests/test_inference.py + # - tests/unit_tests/test_rl_utils.py diff --git a/.github/configs/template.yml b/.github/configs/template.yml new file mode 100644 index 0000000000..c7ec56b3e9 --- /dev/null +++ b/.github/configs/template.yml @@ -0,0 +1,16 @@ +# Configuration Template +# This file describes the structure for hardware-specific configurations. +# +# Fields: +# - image: Docker image to use for the runner +# - labels: List of labels for the runner +# - docker_options: Additional Docker options for mounting devices, volumes, etc. +# +# Example: +# image: +# labels: +# - +# - +# docker_options: | +# --option1 value1 +# --option2 value2 \ No newline at end of file diff --git a/.github/scripts/setup_cuda.sh b/.github/scripts/setup_cuda.sh new file mode 100755 index 0000000000..f9e289c6d0 --- /dev/null +++ b/.github/scripts/setup_cuda.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# CUDA Platform Environment Setup Script +# Called by unit_tests_common.yml for CUDA platforms (A100, H100, etc.) +set -euo pipefail + +echo "===== Step 0: Activate Python environment =====" +source /opt/miniconda3/etc/profile.d/conda.sh +conda activate flagscale-train +echo "PATH=$PATH" >> $GITHUB_ENV +echo "Python: $(which python3) ($(python3 --version 2>&1))" + +echo "===== Step 1: Remove Existing TransformerEngine =====" +pip uninstall transformer_engine transformer_engine_torch -y || true + +echo "===== Step 2: Build & Install TransformerEngine =====" +cd $GITHUB_WORKSPACE + +pip install nvdlfw-inspect --quiet +pip install expecttest --quiet +pip install . -v --no-deps --no-build-isolation + +echo "===== Step 3: Verify Installation =====" +python3 tests/pytorch/test_sanity_import.py + +echo "===== Environment Setup Complete =====" diff --git a/.github/scripts/setup_metax.sh b/.github/scripts/setup_metax.sh new file mode 100755 index 0000000000..a2d0b0a4cf --- /dev/null +++ b/.github/scripts/setup_metax.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Metax Platform Environment Setup Script +# Called by unit_tests_common.yml for Metax platforms (C500, etc.) +set -euo pipefail + +echo "===== Step 0: Activate Python environment =====" +source /opt/conda/etc/profile.d/conda.sh +conda activate base +echo "PATH=$PATH" >> $GITHUB_ENV +echo "Python: $(which python3) ($(python3 --version 2>&1))" + +echo "===== Step 1: Base Environment Setup =====" +# Configure MACA toolchain paths +export PATH=/opt/maca/bin:$PATH +export LD_LIBRARY_PATH=/opt/maca/lib:$LD_LIBRARY_PATH +service ssh restart + +echo "===== Step 2: Create nvcc Symlink (cucc -> nvcc) =====" +# TransformerEngine expects nvcc, but MACA provides cucc +ln -sf /opt/maca/tools/cu-bridge/bin/cucc /opt/maca/tools/cu-bridge/bin/nvcc +which nvcc || true + +echo "===== Step 3: Install Required System Tools =====" +# Use apt to install git, curl +sed -i 's|http://mirrors.aliyun.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list +apt-get update -qq || true +apt-get install -y -qq git curl +# Install cmake and ninja via pip (more reliable than apt in this env) +python3 -m pip install cmake ninja torch --no-cache-dir + +echo "===== Step 4: Remove Existing TransformerEngine =====" +# Prevent conflicts with preinstalled or incompatible versions +python3 -m pip uninstall transformer_engine -y || true +python3 -m pip install nvdlfw-inspect --no-deps || true + +echo "===== Step 5: Install TE-FL Plugin Layer =====" +# Install TransformerEngine-FL Python layer (plugin logic) +cd $GITHUB_WORKSPACE +TE_FL_SKIP_CUDA=1 python3 setup.py install + +echo "===== Step 6: Final Verification =====" +# Verify both TE Python API and backend are functional +python3 - <<'EOF' +import transformer_engine +import transformer_engine_torch as te +print("transformer_engine:", transformer_engine) +print("transformer_engine_torch:", te) +EOF + +echo "===== Environment Setup Complete =====" diff --git a/.github/workflows/all_tests_ascend.yml b/.github/workflows/all_tests_ascend.yml new file mode 100644 index 0000000000..04e8f3cba0 --- /dev/null +++ b/.github/workflows/all_tests_ascend.yml @@ -0,0 +1,32 @@ +name: ascend_tests + +on: + # push: + # branches: ["main"] + # pull_request: + # branches: ["main"] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run_tests: + # Package manager and environment settings are read from .github/configs/ascend.yml + uses: ./.github/workflows/all_tests_common.yml + with: + platform: ascend + + all_tests: + needs: run_tests + runs-on: ubuntu-latest + if: always() + steps: + - name: Verify workflow status + run: | + if [ "${{ needs.run_tests.result }}" != "success" ]; then + echo "❌ Tests workflow failed" + exit 1 + fi + echo "✅ All tests passed!" diff --git a/.github/workflows/all_tests_common.yml b/.github/workflows/all_tests_common.yml new file mode 100644 index 0000000000..606a0d3e86 --- /dev/null +++ b/.github/workflows/all_tests_common.yml @@ -0,0 +1,187 @@ +name: Common All Tests + +on: + workflow_call: + inputs: + platform: + required: true + type: string + description: Platform name (e.g., cuda, default) + run_unit_tests: + required: false + type: boolean + default: true + description: Whether to run unit tests in this workflow + run_integration_tests: + required: false + type: boolean + default: true + description: Whether to run integration tests in this workflow + +jobs: + checkout_and_config: + name: checkout_and_config + defaults: + run: + shell: bash + runs-on: ubuntu-latest + outputs: + ci_image: ${{ steps.config.outputs.ci_image }} + runs_on: ${{ steps.config.outputs.runs_on }} + container_volumes: ${{ steps.config.outputs.container_volumes }} + container_options: ${{ steps.config.outputs.container_options }} + device_types: ${{ steps.config.outputs.device_types }} + setup_script: ${{ steps.config.outputs.setup_script }} + build_env: ${{ steps.config.outputs.build_env }} + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Load platform configuration + id: config + run: | + set -euo pipefail + + PLATFORM="${{ inputs.platform }}" + CONFIG_FILE=".github/configs/${PLATFORM}.yml" + + # Install mikefarah/yq (v4) for YAML parsing + sudo wget -qO /usr/local/bin/yq https://github.com/mikefarah/yq/releases/download/v4.45.1/yq_linux_amd64 + sudo chmod +x /usr/local/bin/yq + /usr/local/bin/yq --version + echo "Loading configuration from $CONFIG_FILE" + + # Read CI image + CI_IMAGE=$(yq '.ci_image' "$CONFIG_FILE") + echo "ci_image=$CI_IMAGE" >> $GITHUB_OUTPUT + + # Read runner labels and format as JSON array + RUNS_ON=$(yq '.runner_labels | tojson(0)' "$CONFIG_FILE") + echo "runs_on=$RUNS_ON" >> $GITHUB_OUTPUT + + # Read container volumes and format as JSON array + VOLUMES=$(yq '.container_volumes | tojson(0)' "$CONFIG_FILE") + echo "container_volumes=$VOLUMES" >> $GITHUB_OUTPUT + + # Read container options + OPTIONS=$(yq '.container_options' "$CONFIG_FILE") + echo "container_options=$OPTIONS" >> $GITHUB_OUTPUT + + # Read device types + DEVICE_TYPES=$(yq '.device_types | tojson(0)' "$CONFIG_FILE") + echo "device_types=$DEVICE_TYPES" >> $GITHUB_OUTPUT + + # Read setup script path + SETUP_SCRIPT=$(yq '.setup_script // ""' "$CONFIG_FILE") + echo "setup_script=$SETUP_SCRIPT" >> $GITHUB_OUTPUT + + # Read build environment variables (default to empty object if not defined) + BUILD_ENV=$(yq '.build_env // {} | tojson(0)' "$CONFIG_FILE") + echo "build_env=$BUILD_ENV" >> $GITHUB_OUTPUT + + unit_tests: + name: unit_tests + if: inputs.run_unit_tests + needs: + - checkout_and_config + strategy: + fail-fast: false + matrix: + device: ${{ fromJson(needs.checkout_and_config.outputs.device_types) }} + uses: ./.github/workflows/unit_tests_common.yml + with: + platform: ${{ inputs.platform }} + device: ${{ matrix.device }} + image: ${{ needs.checkout_and_config.outputs.ci_image }} + runs_on: ${{ needs.checkout_and_config.outputs.runs_on }} + container_volumes: ${{ needs.checkout_and_config.outputs.container_volumes }} + container_options: ${{ needs.checkout_and_config.outputs.container_options }} + setup_script: ${{ needs.checkout_and_config.outputs.setup_script }} + build_env: ${{ needs.checkout_and_config.outputs.build_env }} + + unit_tests_complete: + name: unit_tests_complete + needs: + - unit_tests + runs-on: ubuntu-latest + if: always() && inputs.run_unit_tests + steps: + - name: Check unit tests result + run: | + if [ "${{ needs.unit_tests.result }}" != "success" ] && \ + [ "${{ needs.unit_tests.result }}" != "skipped" ]; then + echo "❌ Unit tests failed: ${{ needs.unit_tests.result }}" + exit 1 + fi + echo "✅ Unit tests passed" + + integration_tests: + name: integration_tests + if: inputs.run_integration_tests + needs: + - checkout_and_config + - unit_tests_complete + strategy: + fail-fast: false + matrix: + device: ${{ fromJson(needs.checkout_and_config.outputs.device_types) }} + uses: ./.github/workflows/integration_tests_common.yml + with: + platform: ${{ inputs.platform }} + device: ${{ matrix.device }} + image: ${{ needs.checkout_and_config.outputs.ci_image }} + runs_on: ${{ needs.checkout_and_config.outputs.runs_on }} + container_volumes: ${{ needs.checkout_and_config.outputs.container_volumes }} + container_options: ${{ needs.checkout_and_config.outputs.container_options }} + setup_script: ${{ needs.checkout_and_config.outputs.setup_script }} + build_env: ${{ needs.checkout_and_config.outputs.build_env }} + + integration_tests_complete: + name: integration_tests_complete + if: always() && inputs.run_integration_tests + needs: + - integration_tests + runs-on: ubuntu-latest + steps: + - name: Check integration tests result + run: | + if [ "${{ needs.integration_tests.result }}" != "success" ] && \ + [ "${{ needs.integration_tests.result }}" != "skipped" ]; then + echo "❌ Integration tests failed: ${{ needs.integration_tests.result }}" + exit 1 + fi + echo "✅ Integration tests passed" + + all_tests_complete: + defaults: + run: + shell: bash + needs: + - checkout_and_config + - unit_tests_complete + - integration_tests_complete + runs-on: ubuntu-latest + if: always() + steps: + - name: Verify all tests passed + run: | + # Check all test jobs (skip if not run) + failed=false + + if [ "${{ needs.unit_tests_complete.result }}" != "success" ] && \ + [ "${{ needs.unit_tests_complete.result }}" != "skipped" ]; then + echo "❌ Unit tests failed or cancelled: ${{ needs.unit_tests_complete.result }}" + failed=true + fi + + if [ "${{ needs.integration_tests_complete.result }}" != "success" ] && \ + [ "${{ needs.integration_tests_complete.result }}" != "skipped" ]; then + echo "❌ Integration tests failed or cancelled: ${{ needs.integration_tests_complete.result }}" + failed=true + fi + + if [ "$failed" = "true" ]; then + exit 1 + fi + + echo "✅ All tests completed successfully!" \ No newline at end of file diff --git a/.github/workflows/all_tests_cuda.yml b/.github/workflows/all_tests_cuda.yml new file mode 100644 index 0000000000..cc7ade9f50 --- /dev/null +++ b/.github/workflows/all_tests_cuda.yml @@ -0,0 +1,34 @@ +name: cuda_tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run_tests: + # Package manager and environment settings are read from .github/configs/cuda.yml + uses: ./.github/workflows/all_tests_common.yml + with: + platform: cuda + run_unit_tests: true + run_integration_tests: true + + all_tests: + needs: run_tests + runs-on: ubuntu-latest + if: always() + steps: + - name: Verify workflow status + run: | + if [ "${{ needs.run_tests.result }}" != "success" ]; then + echo "❌ Tests workflow failed" + exit 1 + fi + echo "✅ All tests passed!" diff --git a/.github/workflows/all_tests_metax.yml b/.github/workflows/all_tests_metax.yml new file mode 100644 index 0000000000..0af545e291 --- /dev/null +++ b/.github/workflows/all_tests_metax.yml @@ -0,0 +1,34 @@ +name: metax_tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run_tests: + # Package manager and environment settings are read from .github/configs/metax.yml + uses: ./.github/workflows/all_tests_common.yml + with: + platform: metax + run_unit_tests: true + run_integration_tests: true + + all_tests: + needs: run_tests + runs-on: ubuntu-latest + if: always() + steps: + - name: Verify workflow status + run: | + if [ "${{ needs.run_tests.result }}" != "success" ]; then + echo "❌ Metax Tests workflow failed" + exit 1 + fi + echo "✅ All Metax tests passed!" \ No newline at end of file diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 88719231ef..cf8f1450d3 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -3,10 +3,12 @@ # See LICENSE for license information. # A workflow to trigger ci on hybrid infra (github + self hosted runner) + +# DISABLED in FlagOS name: Blossom-CI on: issue_comment: - types: [created] + types: [__disabled_do_not_remove__] workflow_dispatch: inputs: platform: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0f05dbc40a..2ef6d1893d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,160 +1,66 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. # A workflow to trigger TE build on GitHub + name: 'Build' on: pull_request: workflow_dispatch: -concurrency: - # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes) - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true jobs: - core: - name: 'Core' - runs-on: ubuntu-latest - container: - image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 - options: --user root - steps: - - name: 'Dependencies' - run: | - apt-get update - apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - name: ccache - uses: mozilla-actions/sccache-action@7d986dd989559c6ecdb630a3fd2557667be217ad - - name: 'Build' - run: NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache pip install --no-build-isolation . -v - env: - NVTE_FRAMEWORK: none - MAX_JOBS: 1 - SCCACHE_GHA_ENABLED: "true" - - name: 'Sanity check' - run: python3 -c "import transformer_engine" - working-directory: / pytorch: name: 'PyTorch' - runs-on: ubuntu-latest + runs-on: [ nv-8g-cicd-te ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull never steps: - - name: Move /var/lib/docker/ - shell: bash -euxo pipefail {0} - run: sudo mv /var/lib/docker/ "${GITHUB_WORKSPACE}/docker" - - - name: Maximize build space - uses: easimon/maximize-build-space@c28619d8999a147d5e09c1199f84ff6af6ad5794 - with: - root-reserve-mb: 5120 - temp-reserve-mb: 32 - swap-size-mb: 10240 - remove-dotnet: 'true' - remove-android: 'true' - remove-haskell: 'true' - remove-codeql: 'true' - build-mount-path: '/var/lib/docker/' - - - name: Restore /var/lib/docker/ - shell: bash -euxo pipefail {0} - run: sudo sh -c "mv ${GITHUB_WORKSPACE}/docker/* /var/lib/docker" + - name: Configure Git Safe Directory on Cuda + run: /usr/bin/git config --global safe.directory '*' - name: 'Checkout' - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + fetch-depth: 0 submodules: recursive + set-safe-directory: true - - name: Start named container - run: | - docker run -v $(pwd):$(pwd) -w $(pwd) --name builder -d nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04 sleep infinity - - - name: 'Dependencies' + - name: 'Setup Environment' run: | - docker exec builder bash -c '\ - apt-get update && \ - apt-get install -y git python3.9 pip cudnn9-cuda-12 && \ - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript && \ - apt-get clean \ - ' + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + echo "PATH=$PATH" >> $GITHUB_ENV - - name: 'Build' - run: docker exec builder bash -c 'pip install --no-build-isolation . -v --no-deps' - env: - NVTE_FRAMEWORK: pytorch - MAX_JOBS: 1 - - name: 'Sanity check' - run: docker exec builder bash -c 'python3 tests/pytorch/test_sanity_import.py' - jax: - name: 'JAX' - runs-on: ubuntu-latest - container: - image: ghcr.io/nvidia/jax:jax - options: --user root - steps: - - name: 'Dependencies' - run: pip install cmake==3.21.0 pybind11[global] - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - name: ccache - uses: mozilla-actions/sccache-action@7d986dd989559c6ecdb630a3fd2557667be217ad - name: 'Build' run: | - NVTE_CCACHE_BIN=sccache NVTE_USE_CCACHE=1 pip install --no-build-isolation . -v + pip uninstall transformer_engine transformer_engine_torch -y || true + echo "GITHUB_WORKSPACE=$GITHUB_WORKSPACE" + cd $GITHUB_WORKSPACE + pip install nvdlfw-inspect + pip install expecttest + pip install . -v --no-deps --no-build-isolation env: - NVTE_FRAMEWORK: jax - MAX_JOBS: 1 - SCCACHE_GHA_ENABLED: "true" - - name: 'Sanity check' - run: python3 tests/jax/test_sanity_import.py - all: - name: 'All' - runs-on: ubuntu-latest - steps: - - name: Move /var/lib/docker/ - shell: bash -euxo pipefail {0} - run: sudo mv /var/lib/docker/ "${GITHUB_WORKSPACE}/docker" - - - name: Maximize build space - uses: easimon/maximize-build-space@c28619d8999a147d5e09c1199f84ff6af6ad5794 - with: - root-reserve-mb: 5120 - temp-reserve-mb: 32 - swap-size-mb: 10240 - remove-dotnet: 'true' - remove-android: 'true' - remove-haskell: 'true' - remove-codeql: 'true' - build-mount-path: '/var/lib/docker/' - - - name: Restore /var/lib/docker/ - shell: bash -euxo pipefail {0} - run: sudo sh -c "mv ${GITHUB_WORKSPACE}/docker/* /var/lib/docker" - - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - - name: Start named container - run: | - docker run -v $(pwd):$(pwd) -w $(pwd) --name builder -d ghcr.io/nvidia/jax:jax sleep infinity + NVTE_FRAMEWORK: pytorch + TE_WITH_NCCL: '1' + NVTE_WITH_CUDA: '1' + CUDA_HOME: /usr/local/cuda-12.8 + NVCC: /usr/local/cuda-12.8/bin/nvcc - - name: 'Dependencies' - run: | - docker exec builder bash -c '\ - pip install cmake==3.21.0 pybind11[global] einops onnxscript && \ - pip install torch --no-cache-dir --index-url https://download.pytorch.org/whl/cu130 - ' - - name: 'Build' - run: docker exec builder bash -c 'pip install --no-cache-dir --no-build-isolation . -v --no-deps' - env: - NVTE_FRAMEWORK: all - MAX_JOBS: 1 - name: 'Sanity check' - run: docker exec builder bash -c 'python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py' + run: + python3 tests/pytorch/test_sanity_import.py diff --git a/.github/workflows/integration_tests_common.yml b/.github/workflows/integration_tests_common.yml new file mode 100644 index 0000000000..25f18c866d --- /dev/null +++ b/.github/workflows/integration_tests_common.yml @@ -0,0 +1,134 @@ +name: Common Integration Tests + +on: + workflow_call: + inputs: + platform: + required: true + type: string + device: + required: true + type: string + image: + required: true + type: string + runs_on: + required: true + type: string + container_volumes: + required: true + type: string + container_options: + required: true + type: string + # Platform-specific environment setup script path (from platform config) + setup_script: + required: false + type: string + default: '' + # Platform-specific build environment variables (JSON object from config) + build_env: + required: false + type: string + default: '{}' + +jobs: + integration_test: + defaults: + run: + shell: bash + runs-on: ${{ fromJson(inputs.runs_on) }} + strategy: + fail-fast: false + matrix: + test_group: + - name: pytorch_mcore_integration + path: "qa/L1_pytorch_mcore_integration/test.sh" + test_type: "integration" + name: integration-${{ inputs.device }}-${{ matrix.test_group.name }} + container: + image: ${{ inputs.image }} + volumes: ${{ fromJson(inputs.container_volumes) }} + options: --pull never ${{ inputs.container_options }} + + steps: + # Cuda requires git safe.directory configuration and 3 checkout attempts to handle submodule-heavy repos + - name: Configure Git Safe Directory on Cuda + if: inputs.platform == 'cuda' + run: /usr/bin/git config --global safe.directory '*' + + - name: Checkout Source Code on Cuda (attempt 1) + id: checkout1 + if: inputs.platform == 'cuda' + uses: actions/checkout@v4 + continue-on-error: true + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + - name: Checkout Source Code on Cuda (attempt 2) + id: checkout2 + if: inputs.platform == 'cuda' && steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + - name: Checkout Source Code on Cuda (attempt 3) + id: checkout3 + if: inputs.platform == 'cuda' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + # Metax requires to clean vscode-remote-container + - name: Configure Clean Git Env on Metax + if: inputs.platform == 'metax' + run: | + git config --global --unset-all credential.helper 2>/dev/null || true + git config --system --unset-all credential.helper 2>/dev/null || true + + # Metax no need submodules + - name: Checkout Source Code on Metax + if: inputs.platform == 'metax' + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Environment Setup + if: inputs.setup_script != '' + run: | + bash $GITHUB_WORKSPACE/${{ inputs.setup_script }} + + - name: Execute Tests + env: + TE_PATH: ${{ github.workspace }} + TE_FL_PREFER: vendor + MCORE_REPO_URL: https://github.com/flagos-ai/Megatron-LM-FL.git + MCORE_REF: main + run: | + set -euo pipefail + + # Activate conda environment + if ${{inputs.platform == 'metax'}}; then + source /opt/conda/etc/profile.d/conda.sh + conda activate base + else + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + fi + echo "PATH=$PATH" >> $GITHUB_ENV + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + echo "=== Running L1 PyTorch Megatron-FL MCore Integration Test ===" + # python3 --version + # pip list | grep -E "regex|six|torch" || true + + bash ${{ matrix.test_group.path }} + timeout-minutes: 30 + \ No newline at end of file diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml index e12f50991f..c40ae1af43 100644 --- a/.github/workflows/license.yml +++ b/.github/workflows/license.yml @@ -5,7 +5,8 @@ # A workflow to trigger the TE license check on GitHub name: 'License' on: - pull_request: + pull_request: + branches: [ "__disabled_do_not_remove__" ] workflow_dispatch: jobs: check: diff --git a/.github/workflows/qa-format.yml b/.github/workflows/qa-format.yml new file mode 100644 index 0000000000..ff1cddf312 --- /dev/null +++ b/.github/workflows/qa-format.yml @@ -0,0 +1,32 @@ +name: format_check + +on: + pull_request: + branches: [ "main" ] + types: [opened, synchronize, reopened] + +jobs: + format: + runs-on: ubuntu-22.04 + env: + PRID: ${{ github.event.pull_request.number }} + BRANCH: main + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.base.ref }} + + - name: Merge PR to sub-branch + run: | + git fetch origin pull/${PRID}/merge + git checkout -b test FETCH_HEAD + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Run pre-commit + run: bash ./qa/format.sh \ No newline at end of file diff --git a/.github/workflows/qa-l0-pytorch-wheel.yml b/.github/workflows/qa-l0-pytorch-wheel.yml new file mode 100644 index 0000000000..aef4396ae8 --- /dev/null +++ b/.github/workflows/qa-l0-pytorch-wheel.yml @@ -0,0 +1,78 @@ +name: QA Pytorch Wheel + +on: + push: + branches: + - __disabled_do_not_remove__ + pull_request: + branches: + - __disabled_do_not_remove__ + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + qa-l0-pytorch-wheel: + runs-on: [ self-hosted, Linux, X64, nvidia, gpu-8 ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull always + + steps: + - name: Checkout Code + uses: actions/checkout@v6.0.1 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + ssh-strict: true + ssh-user: git + persist-credentials: true + clean: true + sparse-checkout-cone-mode: true + fetch-tags: false + show-progress: true + lfs: false + submodules: recursive + set-safe-directory: true + + - name: L0 Pytorch Wheel + id: L0_pytoech_wheel + # timeout-minutes: 50 + env: + TE_PATH: . + RUN_LOG: /logs/pytorch/wheel + run: | + echo "TE_PATH: ${TE_PATH}" + sed -i "s/^cd transformer_engine\/pytorch\s*$/pushd transformer_engine\/pytorch/" qa/L0_pytorch_wheel/test.sh + sed -i '44 s/^cd \s*\$TE_PATH\s*$/popd/' qa/L0_pytorch_wheel/test.sh + + cat qa/L0_pytorch_wheel/test.sh + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + pip uninstall -y transformer_engine + + bash qa/L0_pytorch_wheel/test.sh | tee ${RUN_LOG}/pytorch_wheel-${{ github.run_id }}.log + + - name: Upload Installation Logs + if: always() && steps.L0_pytoech_wheel.outcome == 'failure' + uses: actions/upload-artifact@v4 + with: + name: L0-pytorch-logs-${{ github.run_id }} + path: /logs/pytorch/wheel + retention-days: 7 + if-no-files-found: warn diff --git a/.github/workflows/qa-l0-te-cpp-unittest-pytorch-lint.yml b/.github/workflows/qa-l0-te-cpp-unittest-pytorch-lint.yml new file mode 100644 index 0000000000..f214990581 --- /dev/null +++ b/.github/workflows/qa-l0-te-cpp-unittest-pytorch-lint.yml @@ -0,0 +1,179 @@ +name: QA L0 - Core Unit & Lint Tests + +on: + push: + branches: + - __disabled_do_not_remove__ + pull_request: + branches: + - __disabled_do_not_remove__ + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run-qa-l0-core-tests: + runs-on: [ self-hosted, Linux, X64, nvidia, gpu-8 ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull always + steps: + - name: Checkout Code + uses: actions/checkout@v6.0.1 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + ssh-strict: true + ssh-user: git + persist-credentials: true + clean: true + sparse-checkout-cone-mode: true + fetch-tags: false + show-progress: true + lfs: false + submodules: recursive + set-safe-directory: true + + - name: Install Dependencies & Build Transformer Engine + # timeout-minutes: 40 + env: + NVTE_FRAMEWORK: pytorch + TE_WITH_NCCL: 1 + run: | + # Activate conda environment + echo "=== Activating Conda Environment ===" + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Install Python dependencies with version pinning + echo "=== Installing Python Dependencies ===" + pip install transformers expecttest nvdlfw-inspect --quiet + + # Build and install transformer_engine with verbose output + echo "=== Building & Installing Transformer Engine ===" + pip install --no-build-isolation -vvv . --no-deps + + # Verify TE installation with version check + echo "=== Verifying Transformer Engine Installation ===" + python3 tests/pytorch/test_sanity_import.py + + - name: Verify GPU Availability & Health + run: | + # Execute GPU check + echo "=== Checking GPU Status ===" + source .github/workflows/scripts/gpu_check.sh + wait_for_gpu + + # too heavy, disabled for now + # - name: Run L0 C++ Unit Tests + # # timeout-minutes: 60 + # env: + # TE_PATH: . + # run: | + # # Activate conda environment + # source /opt/miniconda3/etc/profile.d/conda.sh + # conda activate flagscale-train + + # # Get TE library paths with robust detection + # TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') + # TE_CPP_LIB_PATH="${TE_LIB_PATH}/transformer_engine" + + # # Set environment variables for build + # export CMAKE_PREFIX_PATH="${TE_CPP_LIB_PATH}:${CMAKE_PREFIX_PATH}" + # export LD_LIBRARY_PATH="${TE_CPP_LIB_PATH}:${LD_LIBRARY_PATH}" + # NUM_PHYSICAL_CORES=$(nproc) + # NUM_PARALLEL_JOBS=$(nproc) + + # # Build and run C++ tests + # cd $TE_PATH/tests/cpp + # cmake -GNinja -Bbuild . -DTE_LIB_PATH="${TE_CPP_LIB_PATH}" + # cmake --build build + # export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) + + # # Run C++ tests with verbose output + # echo "=== Running C++ Unit Tests ===" + # ctest --test-dir build -j$NUM_PARALLEL_JOBS + + - name: PyTorch C++ Lint + # timeout-minutes: 5 + env: + CPP_ONLY: 1 + TE_PATH: . + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Run C++ lint checks + echo "=== Running C++ Lint Checks ===" + bash ./qa/L0_pytorch_lint/test.sh || true + + echo "" + echo "-----------------------------------------------------" + echo "Note: Pylint check ignores errors C0411 (incorrect import position) and W0611 (unused import), which can be achieved by adding the parameter --disable=C0411,W0611" + echo "-----------------------------------------------------" + continue-on-error: true + + - name: PyTorch Python Lint + # timeout-minutes: 5 + env: + PYTHON_ONLY: 1 + TE_PATH: . + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Run PyTorch lint checks + echo "=== Running PyTorch Lint Checks ===" + bash ./qa/L0_pytorch_lint/test.sh || true + + echo "" + echo "-----------------------------------------------------" + echo "Note: Pylint check ignores errors C0411 (incorrect import position) and W0611 (unused import), which can be achieved by adding the parameter --disable=C0411,W0611" + echo "-----------------------------------------------------" + continue-on-error: true + + - name: Run L0 PyTorch Debug Unit Tests + # timeout-minutes: 10 + env: + TE_PATH: . + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Run debug unit tests + echo "=== Running L0 PyTorch Debug Unit Tests ===" + bash ./qa/L0_pytorch_debug_unittest/test.sh + + - name: Run L0 PyTorch Core Unit Tests + # timeout-minutes: 10 + env: + TE_PATH: . + TE_FL_PREFER: vendor + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + # Run core unit tests + echo "=== Running L0 PyTorch Core Unit Tests ===" + bash ./qa/L0_pytorch_unittest/test.sh diff --git a/.github/workflows/qa-l1-te-cpp-pytorch-tests.yml b/.github/workflows/qa-l1-te-cpp-pytorch-tests.yml new file mode 100644 index 0000000000..32a13813ff --- /dev/null +++ b/.github/workflows/qa-l1-te-cpp-pytorch-tests.yml @@ -0,0 +1,165 @@ +name: QA L1 - Comprehensive Integration Tests + +on: + push: + branches: + - __disabled_do_not_remove__ + pull_request: + branches: + - __disabled_do_not_remove__ + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run-qa-l1-comprehensive-tests: + runs-on: [ self-hosted, Linux, X64, nvidia, gpu-8 ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull always + steps: + - name: Checkout Code + uses: actions/checkout@v6.0.1 + with: + repository: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name || github.repository }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.ref || github.ref_name }} + ssh-strict: true + ssh-user: git + persist-credentials: true + clean: true + sparse-checkout-cone-mode: true + fetch-tags: false + show-progress: true + lfs: false + submodules: recursive + set-safe-directory: true + + - name: Install Dependencies & Build Transformer Engine + # timeout-minutes: 40 + env: + NVTE_FRAMEWORK: pytorch + TE_WITH_NCCL: 1 + run: | + # Activate conda environment + echo "=== Activating Conda Environment ===" + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Install MPI + apt update + apt install -y libopenmpi-dev openmpi-bin openmpi-common + apt install -y libmpich-dev mpich + + # Verify the MPI header file + mpicxx -show | awk '{for(i=1;i<=NF;i++) if($i ~ /-I/) print substr($i,3)}' + + # Verify whether the MPI C++ environment is ready + # 1. Verify whether the MPI C++ compiler (mpicxx) exists + mpicxx --version + # 2. Verify if the MPI library file exists + ls /usr/lib/x86_64-linux-gnu/libmpi_cxx.so + + # Install dependencies + pip install optree looseversion opt_einsum lightning_utilities + + # Clone lightning-thunder + git clone --recurse-submodules https://github.com/Lightning-AI/lightning-thunder.git + + echo "Install transformer_engine" + pip install --no-build-isolation -vvv . --no-deps + + # Verify installation + python3 tests/pytorch/test_sanity_import.py + + - name: Verify GPU Availability & Health + run: | + # Execute GPU check + echo "=== Checking GPU Status ===" + source .github/workflows/scripts/gpu_check.sh + wait_for_gpu + + - name: Run L1 PyTorch Thunder Integration Tests + env: + XML_LOG_DIR: "/logs/pytorch/thunder" + THUNDER_PATH: "lightning-thunder" + TE_PATH: . + TE_FL_PREFER: vendor + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + # Run thunder integration tests + echo "=== Running L1 PyTorch Thunder Integration Tests ===" + bash ./qa/L1_pytorch_thunder_integration/test.sh + # timeout-minutes: 5 + + - name: Run L1 PyTorch Distributed Unit Tests + continue-on-error: true + env: + XML_LOG_DIR: "/logs/pytorch/distributed" + TE_PATH: . + TE_FL_PREFER: vendor + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + # Run distributed unit tests + echo "=== Running L1 PyTorch Distributed Unit Tests ===" + bash ./qa/L1_pytorch_distributed_unittest/test.sh + # timeout-minutes: 5 + + - name: Run L1 PyTorch ONNX Unit Tests + env: + XML_LOG_DIR: "/logs/pytorch/onnx" + TE_PATH: . + TE_FL_PREFER: vendor + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + # Run ONNX unit tests + echo "=== Running L1 PyTorch ONNX Unit Tests ===" + bash ./qa/L1_pytorch_onnx_unittest/test.sh + # timeout-minutes: 30 + + + - name: Run L1 PyTorch Megatron-FL MCore Integration Test + env: + TE_PATH: . + TE_FL_PREFER: vendor + MCORE_REPO_URL: https://github.com/flagos-ai/Megatron-LM-FL.git + MCORE_REF: main + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + echo "=== Running L1 PyTorch Megatron-FL MCore Integration Test ===" + bash ./qa/L1_pytorch_mcore_integration/test.sh + timeout-minutes: 30 diff --git a/.github/workflows/qa-l3-te-pytorch-fa-versions-test.yml b/.github/workflows/qa-l3-te-pytorch-fa-versions-test.yml new file mode 100644 index 0000000000..bb3e0a73fe --- /dev/null +++ b/.github/workflows/qa-l3-te-pytorch-fa-versions-test.yml @@ -0,0 +1,120 @@ +# disabled for requireing hopper or higher Compute Capabilities GPUs +name: QA L3 - Attention Tests + +on: + push: + branches: + - __disabled_do_not_remove__ + pull_request: + branches: + - __disabled_do_not_remove__ + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run-qa-l3-attention-tests: + runs-on: [ self-hosted, Linux, X64, nvidia, gpu-8 ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull always + steps: + - name: Checkout Code + uses: actions/checkout@v6.0.1 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + ssh-strict: true + ssh-user: git + persist-credentials: true + clean: true + sparse-checkout-cone-mode: true + fetch-tags: false + show-progress: true + lfs: false + submodules: recursive + set-safe-directory: true + + - name: Install Dependencies & Build Transformer Engine + # timeout-minutes: 40 + env: + NVTE_FRAMEWORK: pytorch + TE_WITH_NCCL: 1 + run: | + # Activate conda environment + echo "=== Activating Conda Environment ===" + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # System dependencies installation with cleanup + echo "=== Installing System Dependencies (MPI) ===" + apt update + apt install -y libopenmpi-dev openmpi-bin openmpi-common + apt install -y libmpich-dev mpich + + # Verify MPI installation comprehensively + echo "=== Verifying MPI Installation ===" + echo "MPI Compiler Path: $(which mpicxx)" + mpicxx --version + echo "MPI Header Paths:" + mpicxx -show | awk '{for(i=1;i<=NF;i++) if($i ~ /-I/) print substr($i,3)}' + + # Verify whether the MPI C++ environment is ready + # 1. Verify whether the MPI C++ compiler (mpicxx) exists + mpicxx --version + # 2. Verify if the MPI library file exists + ls /usr/lib/x86_64-linux-gnu/libmpi_cxx.so + + # Install dependencies + pip install optree looseversion opt_einsum lightning_utilities + + # Clone lightning-thunder + git clone --recurse-submodules https://github.com/Lightning-AI/lightning-thunder.git + + echo "Install transformer_engine" + pip install --no-build-isolation -vvv . --no-deps + + # Verify installation + python3 tests/pytorch/test_sanity_import.py + + - name: Verify GPU Availability & Health + run: | + # Execute GPU check + echo "=== Checking GPU Status ===" + source .github/workflows/scripts/gpu_check.sh + wait_for_gpu + + - name: Run QA L3 PyTorch FlashAttention Versions Test + # timeout-minutes: 30 + env: + XML_LOG_DIR: "/logs/pytorch/attention" + TE_PATH: . + MAX_JOBS: 32 + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Create log directory with proper permissions + echo "=== Preparing Test Environment ===" + mkdir -p "$XML_LOG_DIR" + chmod 777 "$XML_LOG_DIR" + + export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + + bash ./qa/L3_pytorch_FA_versions_test/test.sh diff --git a/.github/workflows/scripts/gpu_check.sh b/.github/workflows/scripts/gpu_check.sh new file mode 100644 index 0000000000..f7f533b95c --- /dev/null +++ b/.github/workflows/scripts/gpu_check.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Function to wait for GPU availability using nvidia-smi +# This version uses integer arithmetic instead of bc for better compatibility +wait_for_gpu_nvidia() { + local gpu_count + gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + + while true; do + local memory_usage_array=() + local memory_total_array=() + # Query GPU memory usage and total memory, suppress stderr to prevent exit on failure + mapfile -t memory_usage_array < <(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits 2>/dev/null) + mapfile -t memory_total_array < <(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null) + + local need_wait=false + local max_usage_percent=0 + + # Iterate through each GPU to calculate memory usage percentage + for ((i=0; i<${#memory_usage_array[@]}; i++)); do + # Remove whitespace from nvidia-smi output + local memory_usage_i=${memory_usage_array[$i]// /} + local memory_total_i=${memory_total_array[$i]// /} + + # Validate that memory values are numeric and total memory is greater than 0 + if [[ $memory_usage_i =~ ^[0-9]+$ ]] && [[ $memory_total_i =~ ^[0-9]+$ ]] && [ "$memory_total_i" -gt 0 ]; then + # Calculate percentage using integer arithmetic (multiply by 100 first to avoid precision loss) + local usage_percent=$((memory_usage_i * 100 / memory_total_i)) + # Track the maximum usage percentage across all GPUs + if [ $usage_percent -gt $max_usage_percent ]; then + max_usage_percent=$usage_percent + fi + else + # Log warning for invalid values and continue waiting + echo "Warning: Invalid memory values - usage: '$memory_usage_i', total: '$memory_total_i'" + need_wait=true + break + fi + done + + # If max usage percentage does not exceed 10%, we can proceed + # 10% threshold = 10 (since we're using integer percentages) + if [ "$need_wait" = false ] && [ $max_usage_percent -le 10 ]; then + break + fi + + # Wait and show current status + echo "Waiting for GPU memory usage to drop below 50% (current max usage: ${max_usage_percent}%)" + sleep 1m + done + + echo "All GPUs have sufficient free memory, GPU memory usage ratio is below 50% (current max usage: ${max_usage_percent}%)" +} + +# Main function to detect GPU tool and call appropriate wait function +# Future: Additional chip types can be added here by extending the detection logic +# and implementing corresponding wait functions (e.g., wait_for_gpu_amd, wait_for_gpu_intel, etc.) +wait_for_gpu() { + if command -v nvidia-smi &> /dev/null; then + echo "Detected nvidia-smi, using NVIDIA GPU monitoring" + wait_for_gpu_nvidia + else + echo "Error: Neither nvidia-smi nor mx-smi is available" + echo "Note: If you are using a new chip type, please add GPU idle detection method for your chip" + exit 1 + fi +} diff --git a/.github/workflows/te-plugin-tests.yml b/.github/workflows/te-plugin-tests.yml new file mode 100644 index 0000000000..9b640fcce8 --- /dev/null +++ b/.github/workflows/te-plugin-tests.yml @@ -0,0 +1,107 @@ +name: Plugin - Unit Tests + +on: + push: + branches: main + paths: + - 'transformer_engine/plugin/**' + - '.github/workflows/te-plugin-tests.yml' + pull_request: + branches: main + paths: + - 'transformer_engine/plugin/**' + - '.github/workflows/te-plugin-tests.yml' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ github.actor }} + cancel-in-progress: true + +jobs: + run-plugin-tests: + runs-on: [ nv-8g-cicd-te ] + defaults: + run: + shell: bash + container: + image: harbor.baai.ac.cn/flagscale/cuda12.8.1-torch2.7.1-python3.10-te2.9:20260209 + ports: + - 80:80 + options: >- + --gpus all + --shm-size=500g + --privileged + --ipc=host + --ulimit memlock=-1 + --ulimit stack=67108864 + --ulimit nofile=65535:65535 + --user root + --pull never + steps: + - name: Checkout Code + uses: actions/checkout@v6.0.1 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + ssh-strict: true + ssh-user: git + persist-credentials: true + clean: true + sparse-checkout-cone-mode: true + fetch-tags: false + show-progress: true + lfs: false + submodules: recursive + set-safe-directory: true + + - name: Install Dependencies & Build Transformer Engine + # timeout-minutes: 40 + env: + NVTE_FRAMEWORK: pytorch + TE_WITH_NCCL: 1 + run: | + # Activate conda environment + echo "Activating conda environment..." + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Print environment information for debugging + echo "=== Environment Info ===" + conda info + python --version + pip --version + gcc --version + nvcc --version + cmake --version + cat /usr/local/cuda-12.8/include/cudnn_version.h | grep -E "CUDNN_MAJOR|CUDNN_MINOR|CUDNN_PATCHLEVEL" + + # Install dependencies + echo "=== Installing Dependencies ===" + pip install transformers expecttest pytest + + # Build and install transformer_engine + echo "=== Building Transformer Engine ===" + pip install --no-build-isolation -vvv . --no-deps + + # Verify installation + echo "=== Verifying Installation ===" + python3 tests/pytorch/test_sanity_import.py + python3 -c "import transformer_engine; print('TE Version:', transformer_engine.__version__)" + + - name: Verify GPU Availability & Health + run: | + # Execute GPU check + echo "=== Checking GPU Status ===" + source .github/workflows/scripts/gpu_check.sh + wait_for_gpu + + - name: Plugin Test + # timeout-minutes: 10 + run: | + # Activate conda environment + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + + # Execute tests (optimized parameters with enhanced output and error capture) + torchrun --nproc_per_node=8 -m pytest -q -x -p no:warnings transformer_engine/plugin/tests + + echo "=== All Plugin Tests Completed Successfully ===" diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index de26531a98..3539f76ee9 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -6,7 +6,7 @@ name: TE-CI Trigger on: issue_comment: - types: [created] + types: [__disabled_do_not_remove__] jobs: Authorization: name: Authorization diff --git a/.github/workflows/unit_tests_common.yml b/.github/workflows/unit_tests_common.yml new file mode 100644 index 0000000000..10a070d9df --- /dev/null +++ b/.github/workflows/unit_tests_common.yml @@ -0,0 +1,201 @@ +name: Common Unit Tests + +on: + workflow_call: + inputs: + platform: + required: true + type: string + device: + required: true + type: string + image: + required: true + type: string + runs_on: + required: true + type: string + container_volumes: + required: true + type: string + container_options: + required: true + type: string + # Platform-specific environment setup script path (from platform config) + setup_script: + required: false + type: string + default: '' + # Platform-specific build environment variables (JSON object from config) + build_env: + required: false + type: string + default: '{}' + +jobs: + unit_test: + defaults: + run: + shell: bash + runs-on: ${{ fromJson(inputs.runs_on) }} + strategy: + fail-fast: false + matrix: + test_group: + - name: pytorch_debug + path: "qa/L0_pytorch_debug_unittest/test.sh" + test_type: "debug" + - name: pytorch_unittest + path: "qa/L0_pytorch_unittest/test.sh" + test_type: "unittest" + - name: pytorch_distributed_unittest + path: "qa/L1_pytorch_distributed_unittest/test.sh" + test_type: "unittest" + name: unit-${{ inputs.device }}-${{ matrix.test_group.name }} + container: + image: ${{ inputs.image }} + volumes: ${{ fromJson(inputs.container_volumes) }} + options: --pull never ${{ inputs.container_options }} + + steps: + # Cuda requires git safe.directory configuration and 3 checkout attempts to handle submodule-heavy repos + - name: Configure Git Safe Directory on Cuda + if: inputs.platform == 'cuda' + run: /usr/bin/git config --global safe.directory '*' + + - name: Checkout Source Code on Cuda (attempt 1) + id: checkout1 + if: inputs.platform == 'cuda' + uses: actions/checkout@v4 + continue-on-error: true + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + - name: Checkout Source Code on Cuda (attempt 2) + id: checkout2 + if: inputs.platform == 'cuda' && steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + - name: Checkout Source Code on Cuda (attempt 3) + id: checkout3 + if: inputs.platform == 'cuda' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: recursive + set-safe-directory: true + + # Metax requires to clean vscode-remote-container + - name: Configure Clean Git Env on Metax + if: inputs.platform == 'metax' + run: | + git config --global --unset-all credential.helper 2>/dev/null || true + git config --system --unset-all credential.helper 2>/dev/null || true + + # Metax no need submodules + - name: Checkout Source Code on Metax + if: inputs.platform == 'metax' + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Environment Setup + if: inputs.setup_script != '' + run: | + bash $GITHUB_WORKSPACE/${{ inputs.setup_script }} + + - name: Execute Tests + working-directory: ${{ github.workspace }} + run: | + set -euo pipefail + + # Load platform-specific environment variables + while IFS='=' read -r key value; do + [ -n "$key" ] && export "$key=$value" + done < <(echo '${{ inputs.build_env }}' | python3 -c " + import json, sys + env = json.load(sys.stdin) + for k, v in env.items(): + print(f'{k}={v}') + ") + + # Activate conda environment + if ${{inputs.platform == 'metax'}}; then + source /opt/conda/etc/profile.d/conda.sh + conda activate base + else + source /opt/miniconda3/etc/profile.d/conda.sh + conda activate flagscale-train + fi + echo "PATH=$PATH" >> $GITHUB_ENV + + export TE_PATH=$GITHUB_WORKSPACE + export TE_LIB_PATH=$(python3 -c "import site; print(site.getsitepackages()[0])") + export PYTHONPATH=$GITHUB_WORKSPACE:${PYTHONPATH:-} + export PATH=${CUDA_HOME:-/usr/local/cuda}/bin:$PATH + export LD_LIBRARY_PATH=${CUDA_HOME:-/usr/local/cuda}/lib:${LD_LIBRARY_PATH:-} + + # check envs before running tests + echo "TE_PATH=$TE_PATH" + echo "TE_LIB_PATH=$TE_LIB_PATH" + echo "PYTHONPATH=$PYTHONPATH" + echo "PATH=$PATH" + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" + + # Ensure log directory exists regardless of volume mount state + mkdir -p /logs + + # Coverage setup: install once + configure collection via PYTEST_ADDOPTS + COVERAGE_ENABLED=false + if pip3 install coverage pytest-cov --quiet 2>/dev/null; then + export PYTEST_ADDOPTS="--cov=transformer_engine --cov-append --cov-report=" + COVERAGE_ENABLED=true + else + echo "WARNING: Failed to install coverage/pytest-cov, coverage collection disabled" + fi + + if [[ "${{ matrix.test_group.name }}" != *"debug"* ]]; then + # Fail fast on backend/API mismatch before running the full test group. + # Skip for debug group (does not use FP8/optimizer symbols). + python3 -c "import sys, importlib; import transformer_engine.common as _te_common; tex = importlib.import_module('transformer_engine_torch'); required=['multi_tensor_scale','multi_tensor_compute_scale_and_scale_inv']; missing=[n for n in required if not hasattr(tex, n)]; print('[TE check] module:', tex); print('[TE check] file:', getattr(tex, '__file__', 'N/A')); print('[TE check] missing:', ', '.join(missing) if missing else 'none'); sys.exit(1 if missing else 0)" + fi + + bash ${{ matrix.test_group.path }} + exit_code=$? + + # Combine coverage fragments and generate JSON report + if [ "$COVERAGE_ENABLED" = "true" ]; then + python3 -m coverage combine --keep 2>/dev/null || true + python3 -m coverage json \ + -o "coverage-${{ inputs.platform }}-${{ inputs.device }}-${{ matrix.test_group.name }}.json" \ + --include="transformer_engine/*" 2>/dev/null \ + || echo "WARNING: No coverage data found" + fi + exit $exit_code + timeout-minutes: 60 + + - name: Upload Coverage Report + uses: actions/upload-artifact@v4 + continue-on-error: true + with: + name: coverage-${{ inputs.platform }}-${{ inputs.device }}-${{ matrix.test_group.name }} + path: | + coverage-${{ inputs.platform }}-${{ inputs.device }}-${{ matrix.test_group.name }}.json + + - name: Upload Coverage Report to FlagCICD + uses: flagos-ai/FlagOps/actions/post-pytest-report@v2 + continue-on-error: true + env: + NO_PROXY: "flagcicd-inner.flagos.net" + with: + backend_url: 'http://flagcicd-inner.flagos.net:8000/metrics/' + user_id: '000000000000000000' + report_path: 'coverage-${{ inputs.platform }}-${{ inputs.device }}-${{ matrix.test_group.name }}.json' + fail_on_error: 'false' \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8a627a7e76..605a85a8c9 100644 --- a/.gitignore +++ b/.gitignore @@ -41,5 +41,8 @@ compile_commands.json .nfs tensor_dumps/ artifacts/ +# Auto-generated build configuration (specific to each environment) +transformer_engine/plugin/core/_build_config.py +# Mac OS .DS_Store .claude/ diff --git a/README.rst b/README.rst index 5a6721b04c..13f60bd72e 100644 --- a/README.rst +++ b/README.rst @@ -5,6 +5,9 @@ |License| + +**TransformerEngine-FL is a fork of TransformerEngine that introduces a plugin-based architecture for supporting diverse AI chips, built on top of** `FlagOS `_, **a unified open-source AI system software stack.** + Transformer Engine ================== diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..a086a238bf 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -92,8 +92,10 @@ def setup_pytorch_extension( include_dirs = [str(path) for path in include_dirs] from torch.utils.cpp_extension import CppExtension + # Use transformer_engine_torch_nv as the native NVIDIA module name + # This allows the plugin system to use transformer_engine_torch as the unified interface return CppExtension( - name="transformer_engine_torch", + name="transformer_engine_torch_nv", sources=[str(src) for src in sources], include_dirs=[str(inc) for inc in include_dirs], extra_compile_args={"cxx": cxx_flags}, diff --git a/build_tools/utils.py b/build_tools/utils.py index 885901068a..1ec3895d84 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -251,8 +251,16 @@ def get_cuda_include_dirs() -> Tuple[str, str]: ] +@functools.lru_cache(maxsize=None) +def skip_cuda_build() -> bool: + """Check if CUDA build should be skipped (for AMD/ROCm or pure FL backend)""" + return bool(int(os.getenv("TE_FL_SKIP_CUDA", "0"))) + + @functools.lru_cache(maxsize=None) def cuda_archs() -> str: + if skip_cuda_build(): + return "" # Return empty string when skipping CUDA build archs = os.getenv("NVTE_CUDA_ARCHS") if archs is None: version = cuda_version() diff --git a/docs/Doxyfile b/docs/Doxyfile index f17ffc297b..7f42e5b0ab 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -1606,7 +1606,7 @@ FORMULA_MACROFILE = # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. -USE_MATHJAX = NO +USE_MATHJAX = YES # When MathJax is enabled you can set the default output format to be used for # the MathJax output. See the MathJax site (see: diff --git a/docs/conf.py b/docs/conf.py index d2bba9825a..3785072716 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -112,23 +112,51 @@ autoapi_generate_api_docs = False autoapi_dirs = [root_path / "transformer_engine"] -autoapi_ignore = ["*test*"] +autoapi_ignore = ["*test*", "*/benchmarks/*"] + +suppress_warnings = [ + "autoapi.python_import_resolution", + "autoapi", +] # There are 2 warnings about the same namespace (transformer_engine) in two different c++ api -# docs pages. This seems to be the only way to suppress these warnings. +# docs pages, and "Unknown type: placeholder" warnings from autoapi/breathe. +# Install logging filters at module load time so they catch warnings emitted +# before setup() is called. +import logging as _logging +import warnings as _warnings + +_warnings.filterwarnings("ignore", message=".*Unknown type.*placeholder.*") + + +class _KnownWarningFilter(_logging.Filter): + def filter(self, record): + message = record.getMessage() + if "Duplicate C++ declaration" in message and "transformer_engine" in message: + return False + if "Unknown type" in message and "placeholder" in message: + return False + return True + + +for _logger_name in ["sphinx", "sphinx.application", "autoapi", ""]: + _logging.getLogger(_logger_name).addFilter(_KnownWarningFilter()) + + def setup(app): """Custom Sphinx setup to filter warnings.""" - import logging - - # Filter out duplicate C++ declaration warnings - class DuplicateDeclarationFilter(logging.Filter): - def filter(self, record): - message = record.getMessage() - if "Duplicate C++ declaration" in message and "transformer_engine" in message: - return False - return True - - # Apply filter to Sphinx logger - logger = logging.getLogger("sphinx") - logger.addFilter(DuplicateDeclarationFilter()) + import sphinx.util.logging + + # Monkey-patch Sphinx's warning handler to filter known warnings + original_warning = sphinx.util.logging.SphinxLoggerAdapter.warning + + def filtered_warning(self, msg, *args, **kwargs): + msg_str = str(msg) + if "Unknown type" in msg_str and "placeholder" in msg_str: + return + if "Duplicate C++ declaration" in msg_str and "transformer_engine" in msg_str: + return + return original_warning(self, msg, *args, **kwargs) + + sphinx.util.logging.SphinxLoggerAdapter.warning = filtered_warning diff --git a/qa/L0_pytorch_debug_unittest/README.rst b/qa/L0_pytorch_debug_unittest/README.rst new file mode 100644 index 0000000000..2ba6e9fb0c --- /dev/null +++ b/qa/L0_pytorch_debug_unittest/README.rst @@ -0,0 +1,26 @@ +L0 PyTorch Debug Unittest +========================= + +This directory contains the L0 PyTorch debug unittest runner. + +MetaX ignore rules +------------------ + +MetaX-specific ignored tests are maintained in one place in ``test.sh`` through +the ``METAX_IGNORED_TESTS`` list. + +The main execution flow only calls a helper to decide whether a test should be +skipped, instead of embedding platform-specific matching rules directly in the +main logic. + +This keeps the script easier to maintain and makes it simpler to add new +ignored cases later if needed. + +How to extend +------------- + +If a new test needs to be skipped on MetaX: + +1. Add the full test path to ``METAX_IGNORED_TESTS`` in ``test.sh``. +2. Avoid adding new platform-specific matching logic directly into the main + execution flow. \ No newline at end of file diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index ce65bc4305..2ab7340986 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -1,24 +1,13 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -function error_exit() { - echo "Error: $1" - exit 1 -} - -function test_fail() { - RET=1 - FAILED_CASES="$FAILED_CASES $1" - echo "Error: sub-test failed: $1" -} -RET=0 -FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} + : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" @@ -26,22 +15,83 @@ mkdir -p "$XML_LOG_DIR" # Nvinspect will be disabled if no feature is active. : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} -pip install pytest==8.2.1 || error_exit "Failed to install pytest" - -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" - -# standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" - -if [ "$RET" -ne 0 ]; then - echo "Error in the following test cases:$FAILED_CASES" - exit 1 -fi -echo "All tests passed" -exit 0 +FAIL=0 + +# It is not installed as a requirement, +# because it is not available on PyPI. +pip install pytest==8.2.1 + +METAX_IGNORED_TESTS=( + "$TE_PATH/tests/pytorch/test_numerics.py" + "$TE_PATH/tests/pytorch/test_sanity.py" +) + +should_skip_on_metax() { + local test_path=$1 + + [ "$PLATFORM" = "metax" ] || return 1 + + local ignored_test + for ignored_test in "${METAX_IGNORED_TESTS[@]}"; do + if [ "$test_path" = "$ignored_test" ]; then + echo "[SKIP] Platform MetaX: Ignoring $test_path" + return 0 + fi + done + + return 1 +} + + +run_test_step() { + local xml_file=$1 + local test_path=$2 + local cmd=$3 + + if should_skip_on_metax "$test_path"; then + return 0 + fi + + echo "-------------------------------------------------------" + echo "[RUN] Executing: $test_path" + eval "$cmd" || FAIL=1 +} + + + +# Step 1: Sanity +run_test_step "test_sanity.xml" "$TE_PATH/tests/pytorch/debug/test_sanity.py" \ +"pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS" + +# Step 2: Config +run_test_step "test_config.xml" "$TE_PATH/tests/pytorch/debug/test_config.py" \ +"pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS" + +# Step 3: Numerics +run_test_step "test_numerics.xml" "$TE_PATH/tests/pytorch/debug/test_numerics.py" \ +"pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS" + +# Step 4: Log +run_test_step "test_log.xml" "$TE_PATH/tests/pytorch/debug/test_log.py" \ +"pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR" + +# Step 5: API Features +run_test_step "test_api_features.xml" "$TE_PATH/tests/pytorch/debug/test_api_features.py" \ +"NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py -k \"not (test_per_tensor_scaling or test_fake_quant or test_statistics_collection or test_statistics_multi_run)\" --no-header --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR" + +# Step 6: Performance +run_test_step "test_perf.xml" "$TE_PATH/tests/pytorch/debug/test_perf.py" \ +"pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR" + + +# Step 7: Sanity 2 +run_test_step "test_sanity_2.xml" "$TE_PATH/tests/pytorch/test_sanity.py" \ +"NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 \ +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py -k \"not (test_sanity_grouped_linear or test_inference_mode)\" --no-header" + +# Step 8: Numerics 2 +run_test_step "test_numerics_2.xml" "$TE_PATH/tests/pytorch/test_numerics.py" \ +"NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 \ +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py -k \"not (test_linear_accuracy or test_layernorm_linear_accuracy or test_layernorm_mlp_accuracy or test_grouped_linear_accuracy or test_transformer_layer_hidden_states_format or test_grouped_gemm)\" --no-header" + +exit $FAIL diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index f08dd8a03d..8af10cdfeb 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -6,7 +6,7 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip3 install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 pylint==3.3.4 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e67cf1bc04..bc4362e23d 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -1,66 +1,130 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. +#!/bin/bash -function error_exit() { - echo "Error: $1" - exit 1 -} -function test_fail() { - RET=1 - FAILED_CASES="$FAILED_CASES $1" +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +pip install pytest==8.2.1 +FAIL=0 + +IS_CUDA_BACKEND=$(python3 -c "import torch; print('cuda' if torch.cuda.is_available() else 'cpu')" 2>/dev/null) + +test_fail() { + FAIL=1 echo "Error: sub-test failed: $1" } -RET=0 -FAILED_CASES="" -set -x +run_test_step() { + local xml_file=$1 + local test_path=$2 + local cmd=$3 + local label=$4 -: ${TE_PATH:=/opt/transformerengine} -: ${XML_LOG_DIR:=/logs} -mkdir -p "$XML_LOG_DIR" + if [ "$PLATFORM" = "metax" ]; then + case "$test_path" in + *"test_numerics.py" | \ + *"test_sanity.py" | \ + *"test_parallel_cross_entropy.py" | \ + *"test_fused_rope.py" | \ + *"test_gqa.py" | \ + *"test_fused_optimizer.py" | \ + *"test_multi_tensor.py" | \ + *"test_cpu_offloading.py" | \ + *"test_attention.py" | \ + *"test_kv_cache.py" | \ + *"test_checkpoint.py" | \ + *"test_fused_router.py") + echo "-------------------------------------------------------" + echo "[SKIP] Platform MetaX: Ignoring $label" + echo "-------------------------------------------------------" + return 0 + ;; + esac + fi + + if [[ "$IS_CUDA_BACKEND" == *"cuda"* ]]; then + if [[ "$test_path" == *"test_checkpoint.py" || "$test_path" == *"test_cpu_offloading.py" || "$test_path" == *"test_attention.py" ]]; then + echo "-------------------------------------------------------" + echo "[SKIP] CUDA Backend detected: Ignoring $label" + echo "-------------------------------------------------------" + return 0 + fi + fi + + + echo "-------------------------------------------------------" + echo "[RUN] Executing: $label" + + eval "$cmd" || test_fail "$label" +} + + +# Step: Sanity +run_test_step "pytest_test_sanity.xml" "$TE_PATH/tests/pytorch/test_sanity.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py -k \"not (test_sanity_layernorm_mlp or test_sanity_gpt or test_sanity_bert or test_sanity_T5 or test_sanity_amp_and_nvfuser or test_sanity_drop_path or test_sanity_fused_qkv_params or test_sanity_gradient_accumulation_fusion or test_inference_mode or test_sanity_normalization_amp or test_sanity_layernorm_linear or test_sanity_linear_with_zero_tokens or test_sanity_grouped_linear)\" --no-header" "test_sanity.py" + +# Step: Recipe +run_test_step "pytest_test_recipe.xml" "$TE_PATH/tests/pytorch/test_recipe.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py" "test_recipe.py" + +# Step: Deferred Init +run_test_step "pytest_test_deferred_init.xml" "$TE_PATH/tests/pytorch/test_deferred_init.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py" "test_deferred_init.py" + +# Step: Numerics +run_test_step "pytest_test_numerics.xml" "$TE_PATH/tests/pytorch/test_numerics.py" \ +"PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py -k \"not (test_layernorm_mlp_accuracy or test_grouped_linear_accuracy or test_gpt_cuda_graph or test_transformer_layer_hidden_states_format or test_grouped_gemm or test_noncontiguous or test_gpt_checkpointing or test_gpt_accuracy or test_mha_accuracy or test_linear_accuracy or test_linear_accuracy_delay_wgrad_compute or test_rmsnorm_accuracy or test_layernorm_accuracy or test_layernorm_linear_accuracy)\" --no-header" "test_numerics.py" + +# Step: CUDA Graphs +run_test_step "pytest_test_cuda_graphs.xml" "$TE_PATH/tests/pytorch/test_cuda_graphs.py" \ +"PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py" "test_cuda_graphs.py" + +# Step: JIT +run_test_step "pytest_test_jit.xml" "$TE_PATH/tests/pytorch/test_jit.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py -k \"not (test_torch_dynamo)\"" "test_jit.py" + +# Step: Fused Rope +run_test_step "pytest_test_fused_rope.xml" "$TE_PATH/tests/pytorch/test_fused_rope.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py" "test_fused_rope.py" + +# Step: NVFP4 (Directory) +run_test_step "pytest_test_nvfp4.xml" "$TE_PATH/tests/pytorch/nvfp4" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4" "test_nvfp4" + +# Step: Float8 Tensors +run_test_step "pytest_test_float8tensor.xml" "$TE_PATH/tests/pytorch/test_float8tensor.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py" "test_float8tensor.py" + +# Step: GQA +run_test_step "pytest_test_gqa.xml" "$TE_PATH/tests/pytorch/test_gqa.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py" "test_gqa.py" + +# Step: Fused Optimizer +run_test_step "pytest_test_fused_optimizer.xml" "$TE_PATH/tests/pytorch/test_fused_optimizer.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py" "test_fused_optimizer.py" + +# Step: Parallel Cross Entropy +run_test_step "pytest_test_parallel_cross_entropy.xml" "$TE_PATH/tests/pytorch/test_parallel_cross_entropy.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py" "test_parallel_cross_entropy.py" + +# Step: CPU Offloading +run_test_step "pytest_test_cpu_offloading.xml" "$TE_PATH/tests/pytorch/test_cpu_offloading.py" \ +"NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py" "test_cpu_offloading.py" + +# Step: Attention +run_test_step "pytest_test_attention.xml" "$TE_PATH/tests/pytorch/attention/test_attention.py" \ +"python3 -m pytest -s -v --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py" "test_attention.py" + +# Step: Checkpoint +run_test_step "pytest_test_checkpoint.xml" "$TE_PATH/tests/pytorch/test_checkpoint.py" \ +"NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py" "test_checkpoint.py" -pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" - -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint -if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then - python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files" -fi -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" -if [ "$RET" -ne 0 ]; then - echo "Error in the following test cases:$FAILED_CASES" +if [ "$FAIL" -ne 0 ]; then + echo "Some tests failed." exit 1 fi -echo "All tests passed" +echo "All assigned tests passed (some might have been skipped)." exit 0 diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index fe4aab456e..cd7633822f 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -27,6 +27,7 @@ VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. +rm -rf dist/*.whl 2>/dev/null || true # Clean up any existing wheels NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel" python3 -m wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" @@ -44,6 +45,8 @@ pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to cd $TE_PATH pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps" +export TE_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/transformer_engine + python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" if [ "$RET" -ne 0 ]; then diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..0a11a129de 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -15,23 +15,134 @@ function test_fail() { RET=0 FAILED_CASES="" +DEBUG_TESTS_READY=0 : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" +# The current CUDA 12.8 test container hits a fused-attention runtime loader +# issue, so keep the distributed numerics suite on the unfused attention path. +export NVTE_FLASH_ATTN="${NVTE_FLASH_ATTN:-0}" +export NVTE_FUSED_ATTN="${NVTE_FUSED_ATTN:-0}" +export NVTE_UNFUSED_ATTN="${NVTE_UNFUSED_ATTN:-1}" + +# Make CUDA runtime libraries discoverable for fused attention kernels. +if [ -z "${CUDA_HOME:-}" ]; then + if [ -d /usr/local/cuda ]; then + export CUDA_HOME=/usr/local/cuda + elif [ -d /usr/local/cuda-12.8 ]; then + export CUDA_HOME=/usr/local/cuda-12.8 + fi +fi +export CUDA_PATH="${CUDA_PATH:-${CUDA_HOME:-}}" + +CUDA_LIB_DIRS=() +for path in \ + "${CUDA_HOME:-}/lib64" \ + "${CUDA_HOME:-}/targets/x86_64-linux/lib" \ + "$(python3 - <<'PY' +import site +from pathlib import Path + +for root in site.getsitepackages(): + candidate = Path(root) / "torch" / "lib" + if candidate.exists(): + print(candidate) + break +PY +)" \ + "$(python3 - <<'PY' +import site +from pathlib import Path + +for root in site.getsitepackages(): + candidate = Path(root) / "nvidia" / "cuda_runtime" / "lib" + if candidate.exists(): + print(candidate) + break +PY +)"; do + if [ -n "$path" ] && [ -d "$path" ]; then + CUDA_LIB_DIRS+=("$path") + fi +done + +if [ "${#CUDA_LIB_DIRS[@]}" -gt 0 ]; then + CUDA_LIB_PATH="$(IFS=:; echo "${CUDA_LIB_DIRS[*]}")" + export LD_LIBRARY_PATH="${CUDA_LIB_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" +fi + +python3 - <<'PY' +import ctypes + +for name in ("libcudart.so", "libcudart.so.12"): + try: + ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL) + print(f"[CUDA] Preloaded {name}") + break + except OSError as exc: + print(f"[CUDA] Failed to preload {name}: {exc}") +PY + + +# It is not installed as a requirement, +# because it is not available on PyPI. +pip uninstall -y nvdlfw-inspect +if pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git && \ + python3 -c "import nvdlfw_inspect.api" >/dev/null 2>&1; then + DEBUG_TESTS_READY=1 +else + echo "Warning: nvdlfw_inspect is unavailable; debug numerics test will be skipped" +fi + pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +run_test_step() { + local xml_file=$1 + local test_path=$2 + local cmd=$3 + local label=$4 + + if [ "$PLATFORM" = "metax" ]; then + case "$test_path" in + *"test_numerics.py" | \ + *"test_numerics_exact.py" | \ + *"test_torch_fsdp2.py" | \ + *"test_cast_master_weights_to_fp8.py") + echo "-------------------------------------------------------" + echo "[SKIP] Platform MetaX: Ignoring $label" + echo "-------------------------------------------------------" + return 0 + ;; + esac + fi + + echo "-------------------------------------------------------" + echo "[RUN] Executing: $label" + eval "$cmd" || test_fail "$label" +} + +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" +run_test_step "pytest_test_numerics.xml" "$TE_PATH/tests/pytorch/distributed/test_numerics.py" \ +"python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py" \ +"test_numerics.py" +run_test_step "pytest_test_numerics_exact.xml" "$TE_PATH/tests/pytorch/distributed/test_numerics_exact.py" \ +"python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py" \ +"test_numerics_exact.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +run_test_step "pytest_test_torch_fsdp2.xml" "$TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py" \ +"python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py -k 'not (test_distributed)'" \ +"test_torch_fsdp2.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" +# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +run_test_step "pytest_test_cp_utils.xml" "$TE_PATH/tests/pytorch/attention/test_cp_utils.py" \ +"python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py" \ +"test_cp_utils.py" +run_test_step "pytest_test_cast_master_weights_to_fp8.xml" "$TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py" \ +"python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py" \ +"test_cast_master_weights_to_fp8.py" # debug tests @@ -42,9 +153,15 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +# pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +if [ "$DEBUG_TESTS_READY" -eq 1 ]; then + run_test_step "pytest_test_numerics_2.xml" "$TE_PATH/tests/pytorch/distributed/test_numerics.py" \ + "NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py" \ + "test_numerics.py (debug)" +else + echo "Skipping debug test_numerics.py because nvdlfw_inspect is unavailable" +fi if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh index 06beba8864..7405cdbb47 100644 --- a/qa/L1_pytorch_mcore_integration/test.sh +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -4,69 +4,149 @@ set -e +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) + +retry_command() { + local attempts=$1 + local delay_seconds=$2 + shift 2 + + local attempt + for attempt in $(seq 1 "${attempts}"); do + if "$@"; then + return 0 + fi + if [ "${attempt}" -lt "${attempts}" ]; then + echo "Command failed (attempt ${attempt}/${attempts}): $*" + echo "Retrying in ${delay_seconds}s..." + sleep "${delay_seconds}" + fi + done + + echo "Command failed after ${attempts} attempts: $*" + return 1 +} + # Paths -: ${TE_PATH:=/opt/transformerengine} -: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} +: "${TE_PATH:=$(cd -- "${SCRIPT_DIR}/../.." && pwd)}" +: "${MCORE_PATH:=/workspace/Megatron-LM-FL}" +: "${MCORE_REPO_URL:=https://github.com/flagos-ai/Megatron-LM-FL.git}" +: "${MCORE_REF:=main}" +: "${OUTPUT_DIR:=${TE_PATH}/qa/L1_pytorch_mcore_integration/output}" +: "${DATA_CACHE_PATH:=/tmp/data_cache}" # Check whether FP8 is supported -DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') -if [[ ${DEVICE_ARCH} -ge 89 ]]; then - WITH_FP8=1 +WITH_FP8= +if command -v nvidia-smi &>/dev/null; then + DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') + if [[ ${DEVICE_ARCH} -ge 89 ]]; then + WITH_FP8=1 + fi +elif command -v mx-smi &>/dev/null; then + # Metax hardware does not support FP8; leave WITH_FP8 unset + : fi -# Download Megatron-LM if needed +# Download or sync Megatron-LM-FL to the requested repo/ref. if [ ! -d "${MCORE_PATH}" ]; then pushd $(dirname ${MCORE_PATH}) - git clone -b core_r0.12.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + git config --global --unset-all credential.helper 2>/dev/null || true + git config --system --unset-all credential.helper 2>/dev/null || true + retry_command 3 5 git clone --depth 1 -b "${MCORE_REF}" "${MCORE_REPO_URL}" $(basename ${MCORE_PATH}) popd fi -# Create mock vocab -VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json -printf "" > ${VOCAB_FILE} -printf "{" >> ${VOCAB_FILE} -printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} -seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} -printf "}" >> ${VOCAB_FILE} +if [ -d "${MCORE_PATH}/.git" ]; then + git -C "${MCORE_PATH}" remote set-url origin "${MCORE_REPO_URL}" + retry_command 3 5 git -C "${MCORE_PATH}" fetch --depth 1 origin "${MCORE_REF}" + git -C "${MCORE_PATH}" checkout -B "${MCORE_REF}" "FETCH_HEAD" +fi + +# Megatron-LM-FL tokenizer imports happen at module import time, so direct +# source execution needs these Python deps available before pretrain_gpt.py +# starts. +python3 - <<'PY' || python3 -m pip install --disable-pip-version-check six regex +import regex +import six +print(f"six available: {six.__version__}") +print(f"regex available: {regex.__version__}") +PY + +CHECKPOINT_DIR=${OUTPUT_DIR}/checkpoints +TENSORBOARD_DIR=${OUTPUT_DIR}/tensorboard +mkdir -p "${CHECKPOINT_DIR}" "${TENSORBOARD_DIR}" "${DATA_CACHE_PATH}" /tmp/checkpoints + +echo "Using Megatron-LM-FL repo: ${MCORE_REPO_URL}" +echo "Using Megatron-LM-FL ref: ${MCORE_REF}" +git -C "${MCORE_PATH}" rev-parse --short HEAD -# Megatron-LM invocation +# Megatron-LM-FL invocation. Keep the argument shape aligned with the +# previously validated tp1/pp1 mock-data GPT functional case while letting CI +# exit after a few steps. COMMAND=" NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 -NVTE_FLASH_ATTN=1 -NVTE_FWD_LAYERNORM_SM_MARGIN=0 -NVTE_BWD_LAYERNORM_SM_MARGIN=0 CUDA_DEVICE_MAX_CONNECTIONS=1 -NVTE_BIAS_GELU_NVFUSION=0 -NVTE_BIAS_DROPOUT_FUSION=0 +NCCL_ALGO=Ring +CUBLAS_WORKSPACE_CONFIG=:4096:8 -python3 --m torch.distributed.launch ---use_env +torchrun --nnodes=1 --nproc_per_node=1 ${MCORE_PATH}/pretrain_gpt.py --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 ---use-cpu-initialization ---num-layers 2 ---hidden-size 128 +--num-layers 12 +--hidden-size 512 --num-attention-heads 8 ---seq-length 128 ---max-position-embeddings 128 ---micro-batch-size 1 ---global-batch-size 8 ---train-iters 10 +--log-params-norm +--log-num-zeros-in-grad +--log-validation-ppl-to-tensorboard +--log-timers-to-tensorboard +--seq-length 1024 +--max-position-embeddings 1024 +--micro-batch-size 4 +--global-batch-size 32 +--train-iters 50 --eval-iters 10 ---lr 1e-4 +--timing-log-level 0 +--lr-decay-iters 320000 +--save ${CHECKPOINT_DIR} +--split 949,50,1 +--tokenizer-type NullTokenizer +--vocab-size 8192 --mock-data ---vocab-file ${VOCAB_FILE} ---merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt +--distributed-backend nccl +--lr 0.00015 +--lr-decay-style cosine +--min-lr 1.0e-5 +--weight-decay 1e-2 +--clip-grad 1.0 +--lr-warmup-fraction .01 +--log-interval 1 +--save-interval 10000 +--eval-interval 1000 --transformer-impl transformer_engine +--recompute-granularity full +--recompute-method uniform +--recompute-num-layers 1 +--deterministic-mode +--no-gradient-accumulation-fusion +--attention-softmax-in-fp32 +--use-mcore-models +--ckpt-format torch_dist +--dist-ckpt-optim-fully-reshardable +--dist-ckpt-strictness log_all +--data-cache-path ${DATA_CACHE_PATH} +--bf16 +--attention-backend unfused +--log-memory-to-tensorboard +--tensorboard-dir ${TENSORBOARD_DIR} +--exit-interval 4 ${WITH_FP8:+--fp8-format hybrid} " COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') -# Launch Megatron-LM +# Launch Megatron-LM-FL bash -c "${COMMAND}" diff --git a/qa/L1_pytorch_mcore_integration/test_bak.sh b/qa/L1_pytorch_mcore_integration/test_bak.sh new file mode 100644 index 0000000000..ec0b47b695 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/test_bak.sh @@ -0,0 +1,79 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} + +# Check whether FP8 is supported +DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') +if [[ ${DEVICE_ARCH} -ge 89 ]]; then + WITH_FP8=1 +fi + +# Download Megatron-LM if needed +if [ ! -d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.12.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Megatron tokenizer import chain pulls in bert_tokenization at module import +# time, which unconditionally depends on `six`. +python3 - <<'PY' || python3 -m pip install --disable-pip-version-check six +import six +print(f"six available: {six.__version__}") +PY + +# Create mock vocab +VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json +printf "" > ${VOCAB_FILE} +printf "{" >> ${VOCAB_FILE} +printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} +seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} +printf "}" >> ${VOCAB_FILE} + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python3 +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 128 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file ${VOCAB_FILE} +--merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt +--transformer-impl transformer_engine +${WITH_FP8:+--fp8-format hybrid} +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/setup.py b/setup.py index 3a66e624e3..16acac9bab 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,9 @@ from setuptools.command.build_ext import build_ext as BuildExtension +from setuptools.command.install import install as InstallCommand +from datetime import datetime +import platform os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -43,6 +46,61 @@ archs = cuda_archs() +def generate_build_config(skip_cuda_build): + """Generate build-time configuration file.""" + config_template_path = ( + current_file_path / "transformer_engine" / "plugin" / "core" / "_build_config.py.template" + ) + config_output_path = ( + current_file_path / "transformer_engine" / "plugin" / "core" / "_build_config.py" + ) + + if config_template_path.exists(): + with open(config_template_path, "r") as f: + template = f.read() + + config_content = template.format( + skip_cuda=skip_cuda_build, + build_time=datetime.now().isoformat(), + platform=platform.platform(), + ) + + with open(config_output_path, "w") as f: + f.write(config_content) + + print(f"Generated build config: {config_output_path}") + print(f" SKIP_CUDA_BUILD = {skip_cuda_build}") + else: + # Fallback: create minimal config if template doesn't exist + config_content = f"""# Auto-generated build configuration +SKIP_CUDA_BUILD = {skip_cuda_build} +BUILD_TIME = "{datetime.now().isoformat()}" +BUILD_PLATFORM = "{platform.platform()}" +""" + with open(config_output_path, "w") as f: + f.write(config_content) + print(f"Generated minimal build config: {config_output_path}") + + +class CustomInstall(InstallCommand): + """Custom install command to generate build config.""" + + user_options = InstallCommand.user_options + [ + ("skip-cuda-build", None, "Skip CUDA build"), + ] + + def initialize_options(self): + super().initialize_options() + self.skip_cuda_build = bool(int(os.getenv("TE_FL_SKIP_CUDA", "0"))) + + def run(self): + # Run the standard install + super().run() + + # Generate build config after installation + generate_build_config(self.skip_cuda_build) + + class TimedBdist(bdist_wheel): """Helper class to measure build time""" @@ -184,6 +242,14 @@ def git_check_submodules() -> None: with open("README.rst", encoding="utf-8") as f: long_description = f.read() + # Check if we should skip CUDA build (for AMD/ROCm or pure FL backend usage) + skip_cuda_build = bool(int(os.getenv("TE_FL_SKIP_CUDA", "0"))) + if skip_cuda_build: + print("=" * 60) + print("TE_FL_SKIP_CUDA=1: Skipping CUDA/native backend compilation") + print("Only FL (Flag-Gems/Triton) backend will be available") + print("=" * 60) + # Settings for building top level empty package for dependency management. if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): assert bool( @@ -200,6 +266,13 @@ def git_check_submodules() -> None: "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } + elif skip_cuda_build: + # Skip CUDA build - only install Python packages for FL backend + install_requires, test_requires = setup_requirements() + ext_modules = [] # No CUDA extensions + package_data = {"": ["VERSION.txt"]} + include_package_data = True + extras_require = {"test": test_requires} else: install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] @@ -229,6 +302,9 @@ def git_check_submodules() -> None: ) ) + # Generate build config before setup + generate_build_config(skip_cuda_build) + # Configure package setuptools.setup( name="transformer_engine", @@ -245,7 +321,11 @@ def git_check_submodules() -> None: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={ + "build_ext": CMakeBuildExtension, + "bdist_wheel": TimedBdist, + "install": CustomInstall, + }, python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000000..600fcf223d --- /dev/null +++ b/tests/README.md @@ -0,0 +1,35 @@ +# TransformerEngine-FL Test Suite + +## Quick Start + +```bash +# Run tests +bash qa//test.sh +``` + +## Directory Structure + +``` +tests/ +├── cpp/ # C++ core functionality tests +│ ├── operator/ # C++ operator layer tests (basic/core operator validation) +│ └── util/ # C++ utility function tests (common helper unit tests) +├── cpp_distributed/ # C++ distributed functionality tests (communication/parallelism) +├── jax/ # JAX framework adaptation tests (JAX backend validation) +└── pytorch/ # Full PyTorch framework tests + ├── attention/ # PyTorch attention mechanism tests (FlashAttention/MLA etc.) + ├── debug/ # Debug-specific tests (issue reproduction/debug tooling) + │ └── test_configs/ # Debug test configurations (params/cases for different scenarios) + ├── distributed/ # PyTorch distributed tests (DDP/FSDP/communication) + ├── nvfp4/ # NVFP4 quantization tests (NVIDIA FP4 operator/inference) + └── references/ # Reference implementation tests (consistency vs baseline) +``` + +## Adding Tests + +### Unit Test +Add test file: +- `tests/cpp/test_.cpp` & `tests/cpp/CMakeLists.txt` +- `tests/cpp_distributed/test_.py` & `tests/cpp_distributed/CMakeLists.txt` +- `tests/jax/test_.py` +- `tests/pytorch/test_.py` diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 0175f04e2e..744d33c8eb 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -10,6 +10,35 @@ from importlib import metadata import transformer_engine.common +import torch + +# Public, simple global (kept for backward compatibility). +TE_DEVICE_TYPE = "cuda" +TE_PLATFORM = torch.cuda + +# Apply MUSA (VENDOR) Patches, such as torch.cuda.device -> torch.musa.device +try: + from .plugin.core.backends.vendor.musa.patches import apply_patch as _musa_apply_patch + + _musa_apply_patch() +except Exception as e: + pass + + +def te_device_type(default: str = "cuda") -> str: + try: + return TE_DEVICE_TYPE + except Exception: + return default + + +def te_platform(default=torch.cuda): + try: + return TE_PLATFORM + except Exception: + return default + + try: from . import pytorch except ImportError: diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 40933f17a9..5264938059 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -18,6 +18,34 @@ from typing import Optional, Tuple +def skip_cuda_build() -> bool: + """Check if CUDA build was skipped (FL-only mode). + + First checks environment variable (for runtime override), + then falls back to build-time configuration. + """ + # Environment variable takes precedence (allows runtime override) + if os.environ.get("TE_FL_SKIP_CUDA"): + return bool(int(os.environ.get("TE_FL_SKIP_CUDA", "0"))) + + # Fall back to build-time configuration + try: + from transformer_engine.plugin.core._build_config import SKIP_CUDA_BUILD + + return SKIP_CUDA_BUILD + except ImportError: + # If build config doesn't exist, default to False + return False + + +# Load plugin system - this handles module registration and backend initialization +# The _module_setup inside core will: +# 1. Register modules under both full and short names for relative imports +# 2. Load all available backends (flagos, reference, vendor/cuda, etc.) +# 3. Register transformer_engine_torch module from the selected backend +import transformer_engine.plugin.core # noqa: F401 # pylint: disable=wrong-import-position + + @functools.lru_cache(maxsize=None) def _is_package_installed(package) -> bool: """Check if the given package is installed via pip.""" @@ -107,7 +135,7 @@ def _get_shared_object_file(library: str) -> Path: """ # Check provided input and determine the correct prefix for .so. - assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}." + assert library in ("core", "torch_nv", "jax"), f"Unsupported TE library {library}." if library == "core": so_prefix = "libtransformer_engine" else: @@ -146,14 +174,21 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: @functools.lru_cache(maxsize=None) def load_framework_extension(framework: str) -> None: """ - Load shared library with Transformer Engine framework bindings - and check verify correctness if installed via PyPI. + Load shared library with Transformer Engine framework bindings. + + For PyTorch: The native module is now named transformer_engine_torch_nv, + and transformer_engine_torch is provided by the plugin system. + This function is kept for backward compatibility but does nothing for torch. """ + # Skip loading native extensions if CUDA build was skipped (FL-only mode) + if skip_cuda_build(): + return + # Supported frameworks. - assert framework in ("jax", "torch"), f"Unsupported framework {framework}" + assert framework in ("jax", "torch_nv"), f"Unsupported framework {framework}" - # Name of the framework extension library. + # For jax: load the native module as before module_name = f"transformer_engine_{framework}" # Name of the pip extra dependency for framework extensions from PyPI. @@ -161,6 +196,10 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" + # Skip if already loaded + if module_name in sys.modules: + return + # Find the TE packages. The core and framework packages can only be installed via PyPI. # For the `transformer-engine` package, we need to check explicity. te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() @@ -195,6 +234,10 @@ def load_framework_extension(framework: str) -> None: def sanity_checks_for_pypi_installation() -> None: """Ensure that package is installed correctly if using PyPI.""" + # Skip sanity checks if CUDA build was skipped (FL-only mode) + if skip_cuda_build(): + return + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() te_installed = _is_package_installed("transformer_engine") te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") @@ -363,24 +406,26 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): sanity_checks_for_pypi_installation() - # `_load_cuda_library` is used for packages that must be loaded - # during runtime. Both system and pypi packages are searched - # and an error is thrown if not found. - _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") - system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") - system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") - - # This additional step is necessary to be able to install TE wheels - # and import TE (without any guards) in an environment where the cuda - # toolkit might be absent without being guarded - load_libs_for_no_ctk = not system_nvrtc and not system_curand - if load_libs_for_no_ctk: - _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) - _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) - _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) - - _TE_LIB_CTYPES = _load_core_library() - - # Needed to find the correct headers for NVRTC kernels. - if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): - os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir() + # Skip loading CUDA libraries if CUDA build was skipped (FL-only mode) + if not skip_cuda_build(): + # `_load_cuda_library` is used for packages that must be loaded + # during runtime. Both system and pypi packages are searched + # and an error is thrown if not found. + _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") + system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") + system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand") + + # This additional step is necessary to be able to install TE wheels + # and import TE (without any guards) in an environment where the cuda + # toolkit might be absent without being guarded + load_libs_for_no_ctk = not system_nvrtc and not system_curand + if load_libs_for_no_ctk: + _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True) + _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True) + _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True) + + _TE_LIB_CTYPES = _load_core_library() + + # Needed to find the correct headers for NVRTC kernels. + if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): + os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir() diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index f48b49b725..ffefd87974 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.debug.features.api import TEConfigAPIMapper from transformer_engine.common.recipe import Format from transformer_engine.pytorch.tensor import Quantizer @@ -30,7 +31,9 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): torch.float16, torch.bfloat16, ), "[NVTORCH INSPECT ERROR] Unsupported tensor type." - assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor." + assert ( + tensor.device.type == te_device_type() + ), f"[NVTORCH INSPECT ERROR] Must be a {te_device_type()} tensor." assert fp8_format in { "FP8E4M3", "FP8E5M2", diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index cf11964e25..85ab069483 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -15,6 +15,7 @@ from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method +from transformer_engine import te_device_type from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor @@ -58,7 +59,10 @@ def _get_new_quantizer(recipe_name, fp8_dtype): return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) if recipe_name == "fp8_current_scaling": return Float8CurrentScalingQuantizer( - fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True + fp8_dtype=fp8_dtype, + device=torch.device(te_device_type()), + rowwise=True, + columnwise=True, ) if recipe_name == "mxfp8": return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index a4bab4eaf5..a3c2eae8a8 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -11,7 +11,9 @@ import nvdlfw_inspect.api as debug_api from nvdlfw_inspect.registry import Registry, api_method + import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, @@ -33,7 +35,9 @@ def per_tensor_cast( torch.float16, torch.bfloat16, ), "[NVTORCH INSPECT ERROR] Unsupported tensor type for per tensor current scaling" - assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor." + assert ( + tensor.device.type == te_device_type() + ), f"[NVTORCH INSPECT ERROR] Must be a {te_device_type()} tensor." assert fp8_dtype in { tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index ca7f22e2de..51cc5a0c1f 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -16,6 +16,7 @@ from nvdlfw_inspect.utils import gather_along_first_dim from nvdlfw_inspect.logging import MetricLogger +from transformer_engine import te_device_type from transformer_engine.debug.features.utils.stats_computation import ( STATS, DEPENDENCIES, @@ -41,14 +42,14 @@ def __init__(self, layer_name, tensor_name, stats, reduction_group, reduce_withi for stat in stats: self.stats_to_compute = self.stats_to_compute | DEPENDENCIES[stat] - self._buffer = torch.zeros(len(STATS), dtype=torch.float32).cuda() + self._buffer = torch.zeros(len(STATS), dtype=torch.float32).to(te_device_type()) self._new_buffer = self._buffer.clone() self._tmp_buffer = self._buffer.clone() # in case of data parallelism it is possible that layer will not be run on one node # modified is set to True if node is run # we do not take not run nodes into account - self.modified = torch.tensor([False], dtype=torch.bool).cuda() + self.modified = torch.tensor([False], dtype=torch.bool).to(te_device_type()) self.iteration = None self.skip_reduction = False diff --git a/transformer_engine/plugin/__init__.py b/transformer_engine/plugin/__init__.py new file mode 100644 index 0000000000..2c6533b713 --- /dev/null +++ b/transformer_engine/plugin/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .core import ( + TEFLBackendBase, + TEFLModule, + get_tefl_module as _get_tefl_module, + get_registry, +) + + +def __getattr__(name): + if name == "tefl": + return _get_tefl_module() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "TEFLBackendBase", + "TEFLModule", + "get_tefl_module", + "get_registry", + "tefl", +] diff --git a/transformer_engine/plugin/benchmarks/__init__.py b/transformer_engine/plugin/benchmarks/__init__.py new file mode 100644 index 0000000000..caaec47482 --- /dev/null +++ b/transformer_engine/plugin/benchmarks/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +__all__ = [] diff --git a/transformer_engine/plugin/benchmarks/benchmark_all_backends.py b/transformer_engine/plugin/benchmarks/benchmark_all_backends.py new file mode 100644 index 0000000000..f111cf0498 --- /dev/null +++ b/transformer_engine/plugin/benchmarks/benchmark_all_backends.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +import os +import sys +import torch +import time +import numpy as np +from datetime import datetime +from typing import Dict, List + + +from transformer_engine.plugin.test_utils import get_available_backends, get_backend + + +class BenchmarkResult: + def __init__( + self, + backend_name: str, + operation_name: str, + shape: tuple, + mean_time: float, + std_time: float, + min_time: float, + max_time: float, + gflops: float = None, + bandwidth: float = None, + ): + self.backend_name = backend_name + self.operation_name = operation_name + self.shape = shape + self.mean_time = mean_time + self.std_time = std_time + self.min_time = min_time + self.max_time = max_time + self.gflops = gflops + self.bandwidth = bandwidth + + def __str__(self): + gflops_str = f"{self.gflops:.2f} GFLOPS" if self.gflops else "N/A" + bandwidth_str = f"{self.bandwidth:.2f} GB/s" if self.bandwidth else "N/A" + return ( + f"{self.backend_name:12s} {self.mean_time:8.4f}±{self.std_time:6.4f} ms " + f"[{self.min_time:7.4f}, {self.max_time:7.4f}] " + f"{gflops_str:15s} {bandwidth_str:12s}" + ) + + +def time_operation(func, warmup_iters=10, benchmark_iters=100): + for _ in range(warmup_iters): + func() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + times = [] + for _ in range(benchmark_iters): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start = time.perf_counter() + func() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end = time.perf_counter() + times.append((end - start) * 1000) + + return { + "mean": np.mean(times), + "std": np.std(times), + "min": np.min(times), + "max": np.max(times), + } + + +def compute_gflops(operation: str, shape: tuple, time_ms: float) -> float: + if operation in ["gelu", "relu", "silu"]: + flops = np.prod(shape) * 5 + elif operation == "layernorm": + total_elements = np.prod(shape) + hidden_size = shape[-1] + flops = total_elements * (3 + 2 * hidden_size) + elif operation == "rmsnorm": + total_elements = np.prod(shape) + hidden_size = shape[-1] + flops = total_elements * (2 + hidden_size) + elif operation == "gemm": + M, N, K = shape + flops = 2 * M * N * K + else: + return None + + return (flops / 1e9) / (time_ms / 1000) + + +def compute_bandwidth(operation: str, shape: tuple, time_ms: float) -> float: + bytes_per_element = 4 + + if operation in ["gelu", "relu", "silu"]: + total_bytes = np.prod(shape) * 2 * bytes_per_element + elif operation in ["layernorm", "rmsnorm"]: + total_bytes = np.prod(shape) * 5 * bytes_per_element + elif operation == "gemm": + M, N, K = shape + total_bytes = (M * K + K * N + M * N) * bytes_per_element + else: + return None + + return (total_bytes / 1e9) / (time_ms / 1000) + + +def benchmark_activations( + backends: List[str], shapes: List[tuple], device: str +) -> List[BenchmarkResult]: + print("\n" + "=" * 80) + print("Activation Function Performance Test") + print("=" * 80) + + results = [] + operations = [ + ("gelu", "GELU"), + ("relu", "ReLU"), + ("silu", "SiLU"), + ] + + for shape in shapes: + print(f"\nShape: {shape}") + x = torch.randn(shape, dtype=torch.float32, device=device) + + for op_method, op_name in operations: + print(f"\n {op_name}:") + print( + f" {'Backend':<12s} {'Time (ms)':<20s} {'Range (ms)':<25s} {'GFLOPS':<15s} {'Bandwidth'}" + ) + print(f" {'-'*85}") + + for backend_name in backends: + backend = get_backend(backend_name) + + try: + func = lambda: getattr(backend, op_method)(x, None) + timing = time_operation(func) + + gflops = compute_gflops(op_method, shape, timing["mean"]) + bandwidth = compute_bandwidth(op_method, shape, timing["mean"]) + + result = BenchmarkResult( + backend_name, + op_method, + shape, + timing["mean"], + timing["std"], + timing["min"], + timing["max"], + gflops, + bandwidth, + ) + results.append(result) + print(f" {result}") + + except Exception as e: + print(f" {backend_name:12s} SKIPPED ({type(e).__name__}: {str(e)[:40]})") + + return results + + +def benchmark_normalization( + backends: List[str], shapes: List[tuple], device: str +) -> List[BenchmarkResult]: + print("\n" + "=" * 80) + print("Normalization Performance Test") + print("=" * 80) + + results = [] + eps = 1e-5 + + for shape in shapes: + print(f"\nShape: {shape}") + hidden_size = shape[-1] + x = torch.randn(shape, dtype=torch.float32, device=device) + weight = torch.ones(hidden_size, dtype=torch.float32, device=device) + bias = torch.zeros(hidden_size, dtype=torch.float32, device=device) + + print(f"\n LayerNorm forward:") + print( + f" {'Backend':<12s} {'Time (ms)':<20s} {'Range (ms)':<25s} {'GFLOPS':<15s} {'Bandwidth'}" + ) + print(f" {'-'*85}") + + for backend_name in backends: + backend = get_backend(backend_name) + + try: + func = lambda: backend.layernorm_fwd( + x, weight, bias, eps, None, None, torch.float32, 0, False + ) + timing = time_operation(func) + + gflops = compute_gflops("layernorm", shape, timing["mean"]) + bandwidth = compute_bandwidth("layernorm", shape, timing["mean"]) + + result = BenchmarkResult( + backend_name, + "layernorm_fwd", + shape, + timing["mean"], + timing["std"], + timing["min"], + timing["max"], + gflops, + bandwidth, + ) + results.append(result) + print(f" {result}") + + except Exception as e: + print(f" {backend_name:12s} SKIPPED ({type(e).__name__})") + + print(f"\n RMSNorm forward:") + print( + f" {'Backend':<12s} {'Time (ms)':<20s} {'Range (ms)':<25s} {'GFLOPS':<15s} {'Bandwidth'}" + ) + print(f" {'-'*85}") + + for backend_name in backends: + backend = get_backend(backend_name) + + try: + func = lambda: backend.rmsnorm_fwd( + x, weight, eps, None, None, torch.float32, 0, False + ) + timing = time_operation(func) + + gflops = compute_gflops("rmsnorm", shape, timing["mean"]) + bandwidth = compute_bandwidth("rmsnorm", shape, timing["mean"]) + + result = BenchmarkResult( + backend_name, + "rmsnorm_fwd", + shape, + timing["mean"], + timing["std"], + timing["min"], + timing["max"], + gflops, + bandwidth, + ) + results.append(result) + print(f" {result}") + + except Exception as e: + print(f" {backend_name:12s} SKIPPED ({type(e).__name__})") + + return results + + +def benchmark_gemm(backends: List[str], configs: List[tuple], device: str) -> List[BenchmarkResult]: + print("\n" + "=" * 80) + print("GEMM Performance Test") + print("=" * 80) + + results = [] + + for M, N, K in configs: + print(f"\nConfig: M={M}, N={N}, K={K}") + print( + f" {'Backend':<12s} {'Time (ms)':<20s} {'Range (ms)':<25s} {'GFLOPS':<15s} {'Bandwidth'}" + ) + print(f" {'-'*85}") + + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + D = torch.empty(M, N, dtype=torch.float32, device=device) + workspace = torch.empty(1024, dtype=torch.uint8, device=device) + + for backend_name in backends: + backend = get_backend(backend_name) + + try: + func = lambda: backend.generic_gemm( + A, + False, + B, + False, + D, + None, + torch.float32, + None, + None, + False, + None, + False, + workspace, + 1024, + False, + False, + ) + timing = time_operation(func) + + gflops = compute_gflops("gemm", (M, N, K), timing["mean"]) + bandwidth = compute_bandwidth("gemm", (M, N, K), timing["mean"]) + + result = BenchmarkResult( + backend_name, + "gemm", + (M, N, K), + timing["mean"], + timing["std"], + timing["min"], + timing["max"], + gflops, + bandwidth, + ) + results.append(result) + print(f" {result}") + + except Exception as e: + print(f" {backend_name:12s} SKIPPED ({type(e).__name__})") + + return results + + +def print_summary(all_results: List[BenchmarkResult]): + print("\n" + "=" * 80) + print("Performance Comparison Summary") + print("=" * 80) + + from collections import defaultdict + + by_operation = defaultdict(lambda: defaultdict(list)) + + for result in all_results: + by_operation[result.operation_name][result.backend_name].append(result) + + print("\nAverage Performance (all shapes):") + print(f"{'Operation':<20s} {'Backend':<12s} {'Avg Time (ms)':<15s} {'Avg GFLOPS':<15s}") + print("-" * 65) + + for op_name, backends_data in sorted(by_operation.items()): + for backend_name, results in sorted(backends_data.items()): + avg_time = np.mean([r.mean_time for r in results]) + gflops_list = [r.gflops for r in results if r.gflops is not None] + avg_gflops = np.mean(gflops_list) if gflops_list else None + + gflops_str = f"{avg_gflops:.2f}" if avg_gflops else "N/A" + print(f"{op_name:<20s} {backend_name:<12s} {avg_time:<15.4f} {gflops_str:<15s}") + + print("\n" + "=" * 80) + print("Fastest Backend (by operation)") + print("=" * 80) + + for op_name, backends_data in sorted(by_operation.items()): + backend_avg_times = {} + for backend_name, results in backends_data.items(): + backend_avg_times[backend_name] = np.mean([r.mean_time for r in results]) + + if backend_avg_times: + fastest = min(backend_avg_times.items(), key=lambda x: x[1]) + print(f"{op_name:<20s} → {fastest[0]:<12s} ({fastest[1]:.4f} ms)") + + +def save_results_csv(results: List[BenchmarkResult], filename: str): + import csv + + with open(filename, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "Backend", + "Operation", + "Shape", + "Mean(ms)", + "Std(ms)", + "Min(ms)", + "Max(ms)", + "GFLOPS", + "GB/s", + ] + ) + + for result in results: + writer.writerow( + [ + result.backend_name, + result.operation_name, + str(result.shape), + f"{result.mean_time:.4f}", + f"{result.std_time:.4f}", + f"{result.min_time:.4f}", + f"{result.max_time:.4f}", + f"{result.gflops:.2f}" if result.gflops else "N/A", + f"{result.bandwidth:.2f}" if result.bandwidth else "N/A", + ] + ) + + print(f"\nResults saved to: {filename}") + + +def main(): + print("\n" + "=" * 80) + print(" " * 25 + "Multi-Backend Performance Comparison Test") + print("=" * 80) + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + print(f"\nDevice: CUDA - {torch.cuda.get_device_name(0)}") + print(f"CUDA version: {torch.version.cuda}") + else: + print(f"\nDevice: CPU") + print(f"PyTorch version: {torch.__version__}") + + backends = get_available_backends() + print(f"\nAvailable backends: {', '.join(backends)}") + print(f"Total: {len(backends)} backends") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = f"benchmark_results_{timestamp}" + os.makedirs(output_dir, exist_ok=True) + print(f"Results will be saved to: {output_dir}/") + + activation_shapes = [ + (1024, 1024), + (2048, 2048), + (4096, 4096), + ] + + normalization_shapes = [ + (8, 512, 768), + (16, 512, 1024), + (32, 512, 2048), + ] + + gemm_configs = [ + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + ] + + all_results = [] + + results = benchmark_activations(backends, activation_shapes, device) + all_results.extend(results) + save_results_csv(results, f"{output_dir}/activations.csv") + + results = benchmark_normalization(backends, normalization_shapes, device) + all_results.extend(results) + save_results_csv(results, f"{output_dir}/normalization.csv") + + results = benchmark_gemm(backends, gemm_configs, device) + all_results.extend(results) + save_results_csv(results, f"{output_dir}/gemm.csv") + + print_summary(all_results) + + save_results_csv(all_results, f"{output_dir}/all_results.csv") + + print("\n" + "=" * 80) + print("Testing complete!") + print("=" * 80 + "\n") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/core/__init__.py b/transformer_engine/plugin/core/__init__.py new file mode 100644 index 0000000000..21a94e5f1e --- /dev/null +++ b/transformer_engine/plugin/core/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .types import BackendImplKind, OpImpl, match_token + +from .ops import ( + TEFLBackendBase, + TEFLModule, + get_tefl_module, + reset_tefl_module, + get_registry, + get_manager, + reset_registry, +) + +from .logger_manager import Logger, LoggerManager +from .policy import ( + SelectionPolicy, + PolicyManager, + get_policy, + set_global_policy, + reset_global_policy, + policy_context, + policy_from_env, + get_policy_epoch, + bump_policy_epoch, + with_strict_mode, + with_preference, + with_allowed_vendors, + with_denied_vendors, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, + VALID_PREFER_VALUES, +) + +from .manager import OpManager, get_default_manager, reset_default_manager +from .registry import OpRegistry + + +from .discovery import ( + discover_plugin, + discover_from_entry_points, + discover_from_env_modules, + get_discovered_plugin, + clear_discovered_plugin, + PLUGIN_GROUP, + PLUGIN_MODULES_ENV, +) + +# Setup module aliases BEFORE importing backends to support relative imports +from ._module_setup import setup_module_aliases, register_as_transformer_engine_torch + +setup_module_aliases() + +# Import backends - this loads all available backends (flagos, reference, vendor/cuda, etc.) +from . import backends + +# Register transformer_engine_torch AFTER backends are loaded +# so that get_tefl_module() can find a registered backend +register_as_transformer_engine_torch() diff --git a/transformer_engine/plugin/core/_build_config.py.template b/transformer_engine/plugin/core/_build_config.py.template new file mode 100644 index 0000000000..27b90f5080 --- /dev/null +++ b/transformer_engine/plugin/core/_build_config.py.template @@ -0,0 +1,22 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Build-time Configuration (Auto-generated) + +This file is automatically generated during package installation. +DO NOT EDIT MANUALLY. + +Configuration settings are determined at build time and should not +be changed at runtime. +""" + +# Whether CUDA backend was skipped during build +SKIP_CUDA_BUILD = {skip_cuda} + +# Build timestamp +BUILD_TIME = "{build_time}" + +# Build platform +BUILD_PLATFORM = "{platform}" diff --git a/transformer_engine/plugin/core/_module_setup.py b/transformer_engine/plugin/core/_module_setup.py new file mode 100644 index 0000000000..74acad26cc --- /dev/null +++ b/transformer_engine/plugin/core/_module_setup.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Module setup for core plugin system. + +This module handles the registration of core modules in sys.modules +with both full and short names to support relative imports in backends. +""" + +import sys +from pathlib import Path + + +def setup_module_aliases(): + """ + Register core modules under both full and short names. + + This allows backends to use relative imports like: + from ...ops import TEFLBackendBase + from ...types import OpImpl, BackendImplKind + + And ensures they work correctly regardless of how the module is imported. + """ + # Get the current package + current_package = sys.modules.get("transformer_engine.plugin.core") + if current_package is None: + return + + # Register the main package under short name + sys.modules["core"] = current_package + + # List of submodules to register + submodule_names = [ + "ops", + "logger", + "types", + "logger_manager", + "policy", + "operator_registry", + "registry", + "discovery", + ] + + # Register each submodule under short name + for name in submodule_names: + full_name = f"transformer_engine.plugin.core.{name}" + short_name = f"core.{name}" + + if full_name in sys.modules and short_name not in sys.modules: + sys.modules[short_name] = sys.modules[full_name] + + # Register backends package + backends_full = "transformer_engine.plugin.core.backends" + backends_short = "core.backends" + if backends_full in sys.modules and backends_short not in sys.modules: + sys.modules[backends_short] = sys.modules[backends_full] + + # Register parent plugin package if needed + if "transformer_engine.plugin" not in sys.modules: + import types + + plugin_dir = Path(__file__).parent.parent + plugin_pkg = types.ModuleType("transformer_engine.plugin") + plugin_pkg.__path__ = [str(plugin_dir)] + sys.modules["transformer_engine.plugin"] = plugin_pkg + + +def register_as_transformer_engine_torch(): + """ + Register the tefl module as transformer_engine_torch. + + This provides backward compatibility with code that expects + transformer_engine_torch to be available. + """ + # Only register if not already present + if "transformer_engine_torch" in sys.modules: + return + + try: + from .ops import get_tefl_module + + tefl_module = get_tefl_module() + sys.modules["transformer_engine_torch"] = tefl_module + except Exception as e: + import traceback + + print(f"[TEFL Setup] Warning: Could not register transformer_engine_torch: {e}") + traceback.print_exc() + + # Create a minimal placeholder module to avoid import errors + # This allows the system to at least import without crashing + import types + + placeholder = types.ModuleType("transformer_engine_torch") + placeholder.__doc__ = "Placeholder module - TEFL backend not available" + sys.modules["transformer_engine_torch"] = placeholder diff --git a/transformer_engine/plugin/core/backends/__init__.py b/transformer_engine/plugin/core/backends/__init__.py new file mode 100644 index 0000000000..7729afc3af --- /dev/null +++ b/transformer_engine/plugin/core/backends/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/plugin/core/backends/fa_utils.py b/transformer_engine/plugin/core/backends/fa_utils.py new file mode 100644 index 0000000000..c24b377631 --- /dev/null +++ b/transformer_engine/plugin/core/backends/fa_utils.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +"""Common utilities for Flash Attention backends with Context Parallelism support.""" + +from typing import Any, Tuple + +import torch +import torch.distributed as dist + + +class AllGatherFunc(torch.autograd.Function): + """Autograd function for all-gather along sequence dimension with proper backward.""" + + @staticmethod + def forward(ctx, input_tensor: torch.Tensor, cp_group: Any, seq_dim: int) -> torch.Tensor: + world_size = dist.get_world_size(cp_group) + gathered_list = [torch.empty_like(input_tensor) for _ in range(world_size)] + dist.all_gather(gathered_list, input_tensor, group=cp_group) + ctx.cp_group = cp_group + ctx.world_size = world_size + ctx.seq_dim = seq_dim + return torch.cat(gathered_list, dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: + # Split the gradient and reduce_scatter + grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.seq_dim) + local_grad = torch.zeros_like(grad_chunks[0]) + grad_list = [chunk.contiguous() for chunk in grad_chunks] + dist.reduce_scatter(local_grad, grad_list, group=ctx.cp_group) + return local_grad, None, None + + +def all_gather_along_seq( + tensor: torch.Tensor, + cp_group: Any, + seq_dim: int = 2, +) -> torch.Tensor: + """All-gather tensor along sequence dimension across CP group. + + Args: + tensor: Input tensor to gather. + cp_group: Context parallelism process group. + seq_dim: Sequence dimension (default: 2 for BHSD format). + + Returns: + Gathered tensor with sequence dimension scaled by CP world size. + """ + world_size = dist.get_world_size(cp_group) + if world_size == 1: + return tensor + + tensor = tensor.contiguous() + return AllGatherFunc.apply(tensor, cp_group, seq_dim) + + +def reduce_scatter_along_seq( + tensor: torch.Tensor, + cp_group: Any, + seq_dim: int = 2, +) -> torch.Tensor: + """Reduce-scatter tensor along sequence dimension across CP group. + + Args: + tensor: Input tensor to reduce-scatter. + cp_group: Context parallelism process group. + seq_dim: Sequence dimension (default: 2 for BHSD format). + + Returns: + Reduced tensor with sequence dimension divided by CP world size. + """ + world_size = dist.get_world_size(cp_group) + if world_size == 1: + return tensor + + tensor = tensor.contiguous() + seq_len = tensor.shape[seq_dim] + chunk_size = seq_len // world_size + + output = torch.empty( + *tensor.shape[:seq_dim], + chunk_size, + *tensor.shape[seq_dim + 1 :], + dtype=tensor.dtype, + device=tensor.device + ) + + dist.reduce_scatter_tensor(output, tensor, group=cp_group) + return output + + +def create_cp_causal_mask( + local_seq_len_q: int, + full_seq_len_kv: int, + cp_rank: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Create causal mask for context parallelism. + + In CP mode, each rank processes a different chunk of the query sequence, + so the causal mask needs to account for global positions. + + Args: + local_seq_len_q: Local query sequence length (per rank). + full_seq_len_kv: Full key/value sequence length (after all-gather). + cp_rank: Current rank in CP group. + device: Device to create mask on. + dtype: Data type for mask. + + Returns: + Causal mask tensor of shape [local_seq_len_q, full_seq_len_kv]. + """ + # Calculate global query position offset + q_start = cp_rank * local_seq_len_q + + # Create position indices + q_indices = ( + torch.arange(local_seq_len_q, device=device, dtype=torch.long).unsqueeze(1) + q_start + ) + kv_indices = torch.arange(full_seq_len_kv, device=device, dtype=torch.long).unsqueeze(0) + + # Create causal mask: mask out positions where kv_idx > q_idx + causal_mask = torch.zeros(local_seq_len_q, full_seq_len_kv, dtype=dtype, device=device) + causal_mask.masked_fill_(kv_indices > q_indices, float("-inf")) + + return causal_mask + + +def create_cp_window_mask( + local_seq_len_q: int, + full_seq_len_kv: int, + cp_rank: int, + window_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Create sliding window mask for context parallelism. + + Args: + local_seq_len_q: Local query sequence length (per rank). + full_seq_len_kv: Full key/value sequence length (after all-gather). + cp_rank: Current rank in CP group. + window_size: Tuple of (left_window, right_window). -1 means no limit. + device: Device to create mask on. + dtype: Data type for mask. + + Returns: + Window mask tensor of shape [local_seq_len_q, full_seq_len_kv]. + """ + left_window, right_window = window_size + + # Calculate global query position offset + q_start = cp_rank * local_seq_len_q + + # Create position indices + q_indices = ( + torch.arange(local_seq_len_q, device=device, dtype=torch.long).unsqueeze(1) + q_start + ) + kv_indices = torch.arange(full_seq_len_kv, device=device, dtype=torch.long).unsqueeze(0) + + # Create window mask + window_mask = torch.zeros(local_seq_len_q, full_seq_len_kv, dtype=dtype, device=device) + + if left_window >= 0: + window_mask.masked_fill_(kv_indices < q_indices - left_window, float("-inf")) + if right_window >= 0: + window_mask.masked_fill_(kv_indices > q_indices + right_window, float("-inf")) + + return window_mask + + +def get_cp_info(cp_group: Any) -> Tuple[int, int, bool]: + """Get context parallelism information from process group. + + Args: + cp_group: Context parallelism process group. + + Returns: + Tuple of (cp_size, cp_rank, use_cp). + """ + if cp_group is None: + return 1, 0, False + + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + use_cp = cp_size > 1 + + return cp_size, cp_rank, use_cp diff --git a/transformer_engine/plugin/core/backends/flagos/__init__.py b/transformer_engine/plugin/core/backends/flagos/__init__.py new file mode 100644 index 0000000000..86126aa3e0 --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .flagos import FlagOSBackend + +__all__ = ["FlagOSBackend"] diff --git a/transformer_engine/plugin/core/backends/flagos/attention/__init__.py b/transformer_engine/plugin/core/backends/flagos/attention/__init__.py new file mode 100644 index 0000000000..7729afc3af --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/attention/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/__init__.py b/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/__init__.py new file mode 100644 index 0000000000..7729afc3af --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/backends.py b/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/backends.py new file mode 100644 index 0000000000..1b0e72b6f7 --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/attention/dot_product_attention/backends.py @@ -0,0 +1,386 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import warnings +from packaging.version import Version as PkgVersion + +import torch +from transformer_engine import te_device_type +from transformer_engine.pytorch.utils import ( + get_device_compute_capability, +) +from transformer_engine.pytorch.utils import nvtx_range_push, nvtx_range_pop + +from transformer_engine.pytorch.quantized_tensor import ( + prepare_for_saving, + restore_from_saved, +) +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.constants import ( + TE_DType, + QKVLayouts, + dist_group_type, +) + +from transformer_engine.pytorch.distributed import get_distributed_world_size +from transformer_engine.pytorch.jit import no_torch_dynamo +from transformer_engine.pytorch.attention.inference import InferenceParams + +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils + +from transformer_engine.plugin.core.ops import FlashAttentionBase + +import flag_gems + + +class AttnFuncFL(torch.autograd.Function): + @staticmethod + def forward( + ctx, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + page_table_k, + page_table_v, + q, + k, + v, + attn_scale, + dropout_p, + qkv_layout, + attn_mask_type, + window_size, + rng_gen, + deterministic, + layer_number, + ): + nvtx_label = "transformer_engine.AttnFuncFL.forward" + nvtx_range_push(f"{nvtx_label}") + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + + out_nominal_dtype = q.dtype + + max_logit = None + + is_causal = attn_mask_type == "causal" + + q_permuted = q.permute(1, 2, 0, 3).contiguous() + k_permuted = k.permute(1, 2, 0, 3).contiguous() + v_permuted = v.permute(1, 2, 0, 3).contiguous() + + (out_permuted, m) = flag_gems.scaled_dot_product_attention_forward( + q_permuted, + k_permuted, + v_permuted, + attn_mask=None, + dropout_p=dropout_p, + is_causal=is_causal, + scale=attn_scale, + enable_gqa=True, + ) + # Must be contiguous for .view() in FlashAttentionFL.forward + out = out_permuted.permute(2, 0, 1, 3).contiguous() + + aux_ctx_tensors = [out_permuted, m] + out_ret = out + qkvo_tensors = (q_permuted, k_permuted, v_permuted, out_permuted) + + nvtx_range_pop(f"{nvtx_label}") + + ctx.nominal_dtype = out_nominal_dtype + + from transformer_engine.pytorch.cpu_offload import ( + is_cpu_offload_enabled, + mark_activation_offload, + ) + + if is_cpu_offload_enabled(): + tensor_list = [q, k, v, out] + + mark_activation_offload(*tensor_list) + mark_activation_offload(*aux_ctx_tensors) + + tensors_to_save, tensor_objects = prepare_for_saving( + *qkvo_tensors, + cu_seqlens_q, + cu_seqlens_kv, + *aux_ctx_tensors, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.layer_number = layer_number + + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + + ctx.qkv_layout = qkv_layout + ctx.attn_mask_type = attn_mask_type + ctx.window_size = window_size + ctx.deterministic = deterministic + + return out_ret + + @staticmethod + def backward(ctx, d_out, *_args): + d_out = d_out.contiguous() + ( + q_permuted, + k_permuted, + v_permuted, + out_permuted, + cu_seqlens_q, + cu_seqlens_kv, + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + + aux_ctx_tensors = other_tensors + + if not aux_ctx_tensors[0].is_contiguous(): + aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() + if not aux_ctx_tensors[1].is_contiguous(): + aux_ctx_tensors[1] = aux_ctx_tensors[1].contiguous() + out_permuted, m = aux_ctx_tensors + rest = [None] + + with torch.cuda.nvtx.range("AttnFuncFL.backward"): + dqkv_nominal_dtype = ctx.nominal_dtype + + dqkv_te_dtype = TE_DType[d_out.dtype] + + q_permuted = q_permuted.contiguous() if not q_permuted.is_contiguous() else q_permuted + k_permuted = k_permuted.contiguous() if not k_permuted.is_contiguous() else k_permuted + v_permuted = v_permuted.contiguous() if not v_permuted.is_contiguous() else v_permuted + out_permuted = ( + out_permuted.contiguous() if not out_permuted.is_contiguous() else out_permuted + ) + m = m.contiguous() if not m.is_contiguous() else m + + # d_out is (seq, batch, heads, dim) from autograd, permute to (batch, heads, seq, dim) + d_out_permuted = d_out.permute(1, 2, 0, 3).contiguous() + + dq_permuted, dk_permuted, dv_permuted = flag_gems.scaled_dot_product_attention_backward( + d_out_permuted, + q_permuted, + k_permuted, + v_permuted, + out_permuted, + m, + attn_mask=None, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.attn_scale, + enable_gqa=True, + ) + + dq = dq_permuted.permute(2, 0, 1, 3) + dk = dk_permuted.permute(2, 0, 1, 3) + dv = dv_permuted.permute(2, 0, 1, 3) + + rest = None + + return ( + None, + None, + None, + None, + None, + None, + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class FlashAttentionFL(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + self.use_FAv2_bwd = os.getenv( + "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0" + ) == "1" and get_device_compute_capability() == (9, 0) + + def remove_extra_states_check(self, incompatible_keys): + for key in incompatible_keys.missing_keys: + if "fused_attention._extra_state" in key: + incompatible_keys.missing_keys.remove(key) + for key in incompatible_keys.unexpected_keys: + if "fused_attention._extra_state" in key: + incompatible_keys.unexpected_keys.remove(key) + warnings.warn( + "fused_attention._extra_state is not loaded from checkpoint. Please map " + "FusedAttention's _extra_state to DotProductAttention's _extra_state." + ) + + self.register_load_state_dict_post_hook(remove_extra_states_check) + + @property + def backend_name(self) -> str: + return "flagos" + + @no_torch_dynamo() + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, + cp_global_ranks: List[int] = None, + cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + inference_params: Optional[InferenceParams] = None, + flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FLAttention only supports FP16 and BF16 data types, or Float8Tensors." + assert ( + query_layer.device.type == te_device_type() + and key_layer.device.type == te_device_type() + and value_layer.device.type == te_device_type() + ), f"FLAttention only supports {te_device_type()} tensors." + assert qkv_layout in QKVLayouts, f"FLAttention does not support qkv_layout = {qkv_layout}!" + + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) + context_parallel = cp_size > 1 + assert not context_parallel, "FLAttention do not support context parallel now" + + qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params) + + if q_format in ["bshd", "sbhd"] or kv_format in ["bshd", "sbhd"]: + batch_size = query_layer.shape[0] if q_format == "bshd" else query_layer.shape[1] + if cu_seqlens_q is not None: + cu_seqlens_q = cu_seqlens_q[: batch_size + 1] + if cu_seqlens_kv is not None: + cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] + + page_table = None + if inference_params is None: + if qkv_format in ["sbhd", "bshd"]: + if qkv_format == "sbhd": + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] + max_seqlen_kv = key_layer.shape[0] + if qkv_format == "bshd": + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] + max_seqlen_kv = key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + if "padding" in attn_mask_type: + assert ( + not context_parallel + ), "Padding mask not supported with context parallelism!" + if cu_seqlens_q is None or cu_seqlens_kv is None: + if attention_mask is None: + raise RuntimeError( + "Please provide attention_mask or cu_seqlens for padding!" + ) + if self.attention_type == "self": + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) + cu_seqlens_kv = cu_seqlens_q + else: + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) + else: + if cu_seqlens_q is None: + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + if cu_seqlens_kv is None: + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + if qkv_format == "thd": + assert ( + max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" + elif inference_params.is_paged: + page_table = inference_params.cache_manager.page_table + + with self.attention_dropout_ctx(): + _attn_impl = AttnFuncFL + output = _attn_impl.apply( + self.training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + page_table, + page_table, + query_layer, + key_layer, + value_layer, + self.softmax_scale, + self.attention_dropout if self.training else 0.0, + qkv_layout, + attn_mask_type, + window_size, + None, + self.deterministic, + self.layer_number, + ) + + return output.view(*output.shape[:-2], -1) diff --git a/transformer_engine/plugin/core/backends/flagos/flagos.py b/transformer_engine/plugin/core/backends/flagos/flagos.py new file mode 100644 index 0000000000..1083928721 --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/flagos.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +from typing import Any, List, Optional, Tuple, Union + +import torch + +from ...ops import * + +from .impl import ( + rmsnorm_fwd_fl, + rmsnorm_bwd_fl, + multi_tensor_scale_fl, + multi_tensor_adam_fl, + multi_tensor_adam_param_remainder_fl, + multi_tensor_l2_norm_fl, + generic_gemm_fl, + scaled_masked_softmax_forward_fl, + scaled_masked_softmax_backward_fl, + te_general_grouped_gemm_fl, +) + + +def _check_flagos_available() -> bool: + return True + + +class FlagOSBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_flagos_available() + + def is_available(self) -> bool: + return _check_flagos_available() + + def get_attention_backend(self, attention_params=None): + from packaging.version import Version as PkgVersion + from ...logger_manager import get_logger + + logger = get_logger() + + # Read environment variables to determine which backends to enable + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + + # Log disabled backends + if not use_flash_attention: + logger.info_once("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.info_once("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.info_once("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + flash_attention_backend = PkgVersion("2.6.0") if use_flash_attention else None + fused_attention_backend = NVTE_Fused_Attn_Backend.NVTE_No_Backend + + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + return ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + return generic_gemm_fl( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + return te_general_grouped_gemm_fl( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + # Other granular functions + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return rmsnorm_fwd_fl( + input=input, + weight=weight, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + odtype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return rmsnorm_bwd_fl( + dy=dz, + x=x, + rsigma=rsigma, + gamma=gamma, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + def get_fused_attn_backend(self, *args, **kwargs) -> int: + return NVTE_Fused_Attn_Backend.NVTE_No_Backend + + # Softmax functions + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: Union[float, torch.Tensor], + ) -> torch.Tensor: + return scaled_masked_softmax_forward_fl(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_masked_softmax_backward_fl(output_grad_, softmax_results_, scale_factor) + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + return multi_tensor_scale_fl(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return multi_tensor_l2_norm_fl(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + return multi_tensor_adam_fl( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + return multi_tensor_adam_param_remainder_fl( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + # Misc + def get_cublasLt_version(self) -> int: + return 110000 + + def get_cudnn_version(self) -> int: + return 90000 + + def get_num_cublas_streams(self) -> int: + return 4 # keep consistent with transformer_engine/common/util/multi_stream.cpp, get_num_compute_streams() + + ############## class func ################################# + def get_flash_attention_class(self): + from .attention.dot_product_attention.backends import FlashAttentionFL + + return FlashAttentionFL diff --git a/transformer_engine/plugin/core/backends/flagos/impl/__init__.py b/transformer_engine/plugin/core/backends/flagos/impl/__init__.py new file mode 100644 index 0000000000..d4853b6fdd --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .gemm import * +from .rmsnorm import * +from .fused_adam import * +from .multi_tensor import * +from .softmax import * diff --git a/transformer_engine/plugin/core/backends/flagos/impl/fused_adam.py b/transformer_engine/plugin/core/backends/flagos/impl/fused_adam.py new file mode 100644 index 0000000000..95602c731f --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/fused_adam.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import List +import torch +import flag_gems + + +def multi_tensor_adam_fl( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + eps: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, +) -> None: + + num_lists = len(tensor_lists) + assert num_lists in [4, 5], f"Expected 4 or 5 tensor lists, got {num_lists}" + + num_tensors = len(tensor_lists[0]) + assert num_tensors > 0, "No tensors provided" + + for i, lst in enumerate(tensor_lists): + assert len(lst) == num_tensors, f"List {i} has {len(lst)} tensors, expected {num_tensors}" + + bias_correction1 = 1.0 + bias_correction2 = 1.0 + if bias_correction == 1: + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + is_adamw = mode == 1 + + for i in range(num_tensors): + g = tensor_lists[0][i] + p = tensor_lists[1][i] + m = tensor_lists[2][i] + v = tensor_lists[3][i] + p_master = tensor_lists[4][i] if num_lists == 5 else None + + if not g.is_contiguous(): + g = g.contiguous() + + m = flag_gems.add_(flag_gems.mul_(m, beta1), g, alpha=1 - beta1) + v = flag_gems.add_( + flag_gems.mul_(v, beta2), flag_gems.mul_(flag_gems.mul_(g, g), 1 - beta2) + ) + + m_corr = m.clone() + v_corr = v.clone() + if bias_correction == 1: + m_corr = flag_gems.true_divide(m_corr, bias_correction1) + v_corr = flag_gems.true_divide(v_corr, bias_correction2) + + update = flag_gems.true_divide(m_corr, flag_gems.add(flag_gems.sqrt(v_corr), eps)) + + if is_adamw: + p = flag_gems.mul_(p, 1 - lr * weight_decay) + else: + update = flag_gems.add_(update, p, alpha=weight_decay) + + p = flag_gems.add_(p, update, alpha=-lr) + + if p_master is not None: + flag_gems.copy_(p_master, p) + + +def multi_tensor_adam_param_remainder_fl( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + eps: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, +) -> None: + """ + Adam optimizer with parameter remainders for BF16 precision (FlagOS implementation). + """ + if noop_flag.item() != 0: + return + + num_lists = len(tensor_lists) + assert num_lists == 5, f"Expected 5 tensor lists, got {num_lists}" + + num_tensors = len(tensor_lists[0]) + assert num_tensors > 0, "No tensors provided" + + for i, lst in enumerate(tensor_lists): + assert len(lst) == num_tensors, f"List {i} has {len(lst)} tensors, expected {num_tensors}" + + bias_correction1 = 1.0 + bias_correction2 = 1.0 + if bias_correction == 1: + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + is_adamw = mode == 1 + + for i in range(num_tensors): + g = tensor_lists[0][i] + p = tensor_lists[1][i] # int16 parameter (high 16 bits of FP32) + m = tensor_lists[2][i] # FP32 first moment + v = tensor_lists[3][i] # FP32 second moment + p_remainder = tensor_lists[4][i] # int16 remainder (low 16 bits of FP32) + + if not g.is_contiguous(): + g = g.contiguous() + + # Convert gradient to float + g_float = g.float() + + # Reconstruct FP32 master weight from int16 param + int16 remainder using bit manipulation + # This matches the CUDA implementation exactly: + # 1. If p_remainder < 0, decrement p (undo rounding) + # 2. Combine high 16 bits (p) and low 16 bits (p_remainder) into FP32 + # Note: Use PyTorch native ops for bit manipulation (int16/int32 operations) + + local_p = p.view(torch.int16).clone() + local_p_rem = p_remainder.clone() + + # Undo rounding: if remainder < 0, decrement p + local_p = torch.where(local_p_rem < 0, local_p - 1, local_p) + + # Combine into FP32 using bit shift operations + # local_p is high 16 bits, local_p_rem is low 16 bits + high_bits = local_p.to(torch.int32) << 16 + low_bits = local_p_rem.to(torch.int32) & 0xFFFF # Mask off sign extension + param_int32 = high_bits | low_bits + param_master = param_int32.view(torch.float32) + + # L2 mode: add weight decay to gradient before updating moments + if not is_adamw and weight_decay != 0: + g_float = flag_gems.add(g_float, param_master, alpha=weight_decay) + + # Update first moment: m = beta1 * m + (1 - beta1) * g + flag_gems.add_(flag_gems.mul_(m, beta1), g_float, alpha=1 - beta1) + + # Update second moment: v = beta2 * v + (1 - beta2) * g^2 + flag_gems.add_(flag_gems.mul_(v, beta2), flag_gems.mul(g_float, g_float), alpha=1 - beta2) + + # Apply bias correction + m_corr = flag_gems.true_divide(m, bias_correction1) + v_corr = flag_gems.true_divide(v, bias_correction2) + + # Compute denominator: sqrt(v_corr) + eps + denom = flag_gems.add(flag_gems.sqrt(v_corr), eps) + + # Compute update + update = flag_gems.true_divide(m_corr, denom) + + # AdamW mode: add decoupled weight decay to update + if is_adamw and weight_decay != 0: + update = flag_gems.add(update, param_master, alpha=weight_decay) + + # Update master weight: p = p - lr * update + param_master = flag_gems.sub(param_master, flag_gems.mul(update, lr)) + + # Split FP32 back into int16 param + int16 remainder using bit manipulation + # This matches the CUDA implementation exactly: + # 1. Extract high 16 bits as p + # 2. Extract low 16 bits as p_remainder + # 3. If p_remainder < 0, increment p (round up) + # Note: Use PyTorch native ops for bit manipulation (int32 operations) + + param_int32 = param_master.view(torch.int32) + # Extract low 16 bits (remainder) and high 16 bits (param) + new_p_rem = (param_int32 & 0xFFFF).to(torch.int16) + new_p = ((param_int32 >> 16) & 0xFFFF).to(torch.int16) + + # Round up: if remainder < 0, increment p + new_p = torch.where(new_p_rem < 0, new_p + 1, new_p) + + # Write back + flag_gems.copy_(p, new_p.view(torch.bfloat16)) + flag_gems.copy_(p_remainder, new_p_rem) diff --git a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py new file mode 100644 index 0000000000..e190af5c5d --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Dict, List, Optional, Tuple, Union +import torch + +import flag_gems + + +__all__ = [ + "generic_gemm_fl", + "te_general_grouped_gemm_fl", +] + +_DTYPE_TO_TORCH = { + 0: torch.uint8, + 2: torch.int32, + 4: torch.float32, + 5: torch.float16, + 6: torch.bfloat16, + 7: torch.float8_e4m3fn, + 8: torch.float8_e5m2, +} + + +def validate_gemm_scale(scale: Optional[float], required: bool) -> float: + if required: + return scale if scale is not None else 1.0 + if scale not in (0.0, None): + raise ValueError("scale must be zero") + return 0.0 + + +def _convert_dtype(dtype: Union[int, torch.dtype, None]) -> Optional[torch.dtype]: + if dtype is None: + return None + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, int): + return _DTYPE_TO_TORCH.get(dtype, None) + if hasattr(dtype, "value"): + return _DTYPE_TO_TORCH.get(dtype.value, None) + return None + + +def generic_gemm_fl( + A: torch.Tensor, + transA: bool, + B: torch.Tensor, + transB: bool, + D: Optional[torch.Tensor], + quantizer: Any, + output_dtype: Any, + bias: Optional[torch.Tensor], + bias_type: Any, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[Any] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + + assert not gelu and gelu_in is None, "Triton-Based General Gemm do not support gelu now" + assert quantizer is None, "Triton-Based General Gemm do not support quantization now" + assert bias is None, "Triton-Based General Gemm do not support bias now" + + alpha = validate_gemm_scale(alpha, True) + beta = validate_gemm_scale(beta, accumulate) + + s = -1 + b = -1 + orig_A_shape = A.shape + orig_B_shape = B.shape + shape_a_changed = False + shape_b_changed = False + + if A.ndim == 3: + A = A.view(-1, A.shape[-1]) + shape_a_changed = True + + if B.ndim == 3: + s, b, _ = B.shape + B = B.view(-1, B.shape[-1]) + shape_b_changed = True + + A_comp = A.T if transA else A + B_comp = B.T if transB else B + + out1 = flag_gems.mm(B_comp, A_comp) + + if shape_b_changed: + out1 = out1.view(s, b, -1) + + torch_out_dtype = _convert_dtype(output_dtype) + if torch_out_dtype is not None and out1.dtype != torch_out_dtype: + out1 = out1.to(torch_out_dtype) + + bias_grad = None + gelu_input = None + extra_output_ret = None + + if D is not None: + if accumulate: + flag_gems.add_(D, out1) + else: + flag_gems.copy_(D, out1) + return D, bias_grad, gelu_input, extra_output_ret + else: + return out1, bias_grad, gelu_input, extra_output_ret + + +# This function can represent both forward and backward computations. +# When grad is False (forward computation), the 'bias' is bias; +# When grad is True (backward computation/gradient calculation), the 'bias' is grad_bias; +def te_general_grouped_gemm_fl( + B: List[torch.Tensor], + transb: bool, + A: List[torch.Tensor], + transa: bool, + D: Optional[List[torch.Tensor]], + D_type: Any, + m_splits: List[int], + bias: List[torch.Tensor], # bias or grad_bias + bias_type: Any, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSize: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, +) -> Optional[List[torch.Tensor]]: + if single_output and D is None: + raise ValueError("not implemented, D should be allocated for single output case.") + + num_gemms = len(A) + if D is None: + D = [] + for i in range(num_gemms): + m = A[i].shape[1] if transa else A[i].shape[0] + n = B[i].shape[0] if transb else B[i].shape[1] + D.append(torch.empty((m, n), dtype=D[i].dtype, device=A[0].device)) + + temp_D = [] + for i in range(num_gemms): + # Handle the special case of zero-element inputs + if A[i].numel() == 0 or B[i].numel() == 0: + if not single_output: + if D[i].numel() != 0 and not accumulate: + flag_gems.copy_(D[i], flag_gems.zeros(D[i].shape)) + else: + out = flag_gems.zeros((A[i].shape[0], B[i].shape[1])) + if grad and len(bias) > i and bias[i] is not None and bias[i].numel() != 0: + flag_gems.copy_(bias[i], flag_gems.zeros(bias[i].shape)) + if ( + len(pre_gelu_out) > i + and pre_gelu_out[i] is not None + and pre_gelu_out[i].numel() != 0 + ): + flag_gems.copy_(pre_gelu_out[i], flag_gems.zeros(pre_gelu_out[i].shape)) + continue + + a = A[i].t() if transa else A[i] + b = B[i].t() if transb else B[i] + # Determine presence of epilogue tensors + has_bias = len(bias) > i and bias[i] is not None and bias[i].numel() > 0 + has_pre_gelu = ( + len(pre_gelu_out) > i and pre_gelu_out[i] is not None and pre_gelu_out[i].numel() > 0 + ) + + # Forward Pass calculation + if not grad: + if has_bias: + # Fused matrix multiplication and bias addition + out = flag_gems.addmm(bias[i], a, b) + else: + out = flag_gems.mm(a, b) + + # Apply GELU epilogue if pre_gelu_out is provided + if has_pre_gelu: + flag_gems.copy_(pre_gelu_out[i], out) + out = flag_gems.gelu(out) + else: + out = flag_gems.mm(a, b) + + # Apply dGELU epilogue if requested + if has_pre_gelu: + out = flag_gems.gelu_backward(out, pre_gelu_out[i]) + + # Compute bias gradients if requested + if has_bias: + bias_grad = flag_gems.sum_dim(out, dim=[0]) + if accumulate: + flag_gems.add_(bias[i], bias_grad) + else: + flag_gems.copy_(bias[i], bias_grad) + + if not single_output: + # Store output + if accumulate: + flag_gems.add_(D[i], out.to(D[i].dtype)) + else: + flag_gems.copy_(D[i], out.to(D[i].dtype)) + else: + temp_D.append(out.to(D[0].dtype)) + + if single_output: + if temp_D: + temp = flag_gems.cat(temp_D, dim=0) + flag_gems.copy_(D[0], temp) + + return bias diff --git a/transformer_engine/plugin/core/backends/flagos/impl/multi_tensor.py b/transformer_engine/plugin/core/backends/flagos/impl/multi_tensor.py new file mode 100644 index 0000000000..d728a76242 --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/multi_tensor.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import List, Tuple +import torch +import flag_gems + + +def multi_tensor_l2_norm_fl( + _chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute L2 norm of tensors using flag_gems. + + Returns: + Tuple of (total_norm, per_tensor_norms_or_dummy) + - total_norm: The combined L2 norm of all tensors + - per_tensor_norms_or_dummy: Per-tensor norms stacked if per_tensor=True, else dummy tensor + """ + device = tensor_lists[0][0].device if tensor_lists and tensor_lists[0] else "cpu" + + if noop_flag.item() != 0: + return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device) + + tensors = tensor_lists[0] + + # Compute per-tensor norms + per_tensor_norms = [] + total_norm_sq = torch.tensor(0.0, device=device) + + for tensor in tensors: + t_float = tensor.float() + norm_sq = flag_gems.sum(flag_gems.mul(t_float, t_float)) + # Check for inf/nan (matches CUDA behavior) + if not torch.isfinite(norm_sq): + noop_flag.fill_(1) + total_norm_sq = flag_gems.add(total_norm_sq, norm_sq) + if per_tensor: + per_tensor_norms.append(flag_gems.sqrt(norm_sq)) + + total_norm = flag_gems.sqrt(total_norm_sq) + + if per_tensor: + per_tensor_result = torch.stack(per_tensor_norms) + else: + per_tensor_result = torch.tensor(0.0, device=device) + + return total_norm, per_tensor_result + + +def multi_tensor_scale_fl( + _chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, +) -> None: + if noop_flag.item() != 0: + return + + for src, dst in zip(tensor_lists[0], tensor_lists[1]): + # Check for inf/nan (matches CUDA behavior for AMP gradient scaling) + if not torch.isfinite(src).all(): + noop_flag.fill_(1) + flag_gems.copy_(dst, flag_gems.mul(src, scale)) diff --git a/transformer_engine/plugin/core/backends/flagos/impl/rmsnorm.py b/transformer_engine/plugin/core/backends/flagos/impl/rmsnorm.py new file mode 100644 index 0000000000..12fda567ed --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/rmsnorm.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import torch +import flag_gems + + +def rmsnorm_fwd_fl( + input, + weight, + eps, + ln_out, + quantizer, + odtype, + sm_margin, + zero_centered_gamma, +): + if zero_centered_gamma: + # weight_adj = 1 + weight + weight_adj = flag_gems.add(1, weight) + else: + weight_adj = weight + + y, rstdevs = flag_gems.rms_norm_forward( + input, + [input.shape[-1]], + weight_adj, + eps, + ) + + if rstdevs.shape != input.shape[:-1]: + rstdevs = rstdevs.view(input.shape[:-1]) + + return y, None, rstdevs + + +def rmsnorm_bwd_fl( + dy, + x, + rsigma, + gamma, + sm_margin, + zero_centered_gamma, + eps=1e-5, +): + # When zero_centered_gamma is True, forward uses (1 + gamma) as weight + # So backward needs to use (1 + gamma) for computing dx + if zero_centered_gamma: + gamma_adj = flag_gems.add(1, gamma) + else: + gamma_adj = gamma + + dx, dw = flag_gems.rms_norm_backward( + dy, + x, + rsigma, + [x.shape[-1]], + gamma_adj, + eps, + ) + return dx, dw diff --git a/transformer_engine/plugin/core/backends/flagos/impl/softmax.py b/transformer_engine/plugin/core/backends/flagos/impl/softmax.py new file mode 100644 index 0000000000..31564b224f --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/impl/softmax.py @@ -0,0 +1,61 @@ +import torch +from typing import Union +import flag_gems + + +def scaled_masked_softmax_forward_fl( + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: Union[float, torch.Tensor], +) -> torch.Tensor: + # Ensure `mask` and 'scale_factor' is on the same device as `input`. + if mask.device != input.device: + mask = flag_gems.to_copy(mask, device=input.device) + if isinstance(scale_factor, torch.Tensor): + if scale_factor.device != input.device: + scale_factor = flag_gems.to_copy(scale_factor, device=input.device) + + # Keep semantics aligned with TE CUDA scaled_masked_softmax: + # - integer/bool mask: masked iff mask == 1, masked logits set to -10000.0 + # - float mask: treated as additive bias in logit space + if mask.dim() == 4 and mask.size(1) == 1 and input.dim() == 4: + mask = mask.expand_as(input) + + scaled = flag_gems.mul(input, scale_factor) + if mask.is_floating_point(): + mask_f = flag_gems.to_copy(mask, device=input.device, dtype=scaled.dtype) + scaled = flag_gems.add(scaled, mask_f) + return flag_gems.softmax(scaled, dim=-1) + + # Avoid using `mask == 1` (torch op) since on some devices it may fall back to CPU, + # which would break Triton kernels inside flag_gems. + cond = flag_gems.eq_scalar(mask, 1) + scaled = flag_gems.masked_fill(scaled, cond, -10000.0) + all_masked = flag_gems.all_dim(cond, dim=-1, keepdim=True) + out = flag_gems.softmax(scaled, dim=-1) + return flag_gems.masked_fill(out, all_masked, 0.0) + + +def scaled_masked_softmax_backward_fl( + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, +) -> torch.Tensor: + orig_dtype = output_grad_.dtype + # Compute in float32 for numerical stability. + output_grad_f32 = flag_gems.to_copy(output_grad_, dtype=torch.float32) + softmax_output_f32 = flag_gems.to_copy( + softmax_results_, dtype=torch.float32, device=output_grad_.device + ) + if isinstance(scale_factor, torch.Tensor): + if scale_factor.device != output_grad_.device: + scale_factor = flag_gems.to_copy(scale_factor, device=output_grad_.device) + + # term = softmax_output_f32 * output_grad_f32 + term = flag_gems.mul(softmax_output_f32, output_grad_f32) + # sum_term = sum(term, dim=-1, keepdim=True) + sum_term = flag_gems.sum_dim(term, dim=[-1], keepdim=True) + # grad_softmax = softmax_output_f32 * (output_grad_f32 - sum_term) + grad_softmax = flag_gems.mul(softmax_output_f32, flag_gems.sub(output_grad_f32, sum_term)) + grad_scaled = flag_gems.mul(grad_softmax, scale_factor) + return flag_gems.to_copy(grad_scaled, dtype=orig_dtype) diff --git a/transformer_engine/plugin/core/backends/flagos/register_ops.py b/transformer_engine/plugin/core/backends/flagos/register_ops.py new file mode 100644 index 0000000000..153012c501 --- /dev/null +++ b/transformer_engine/plugin/core/backends/flagos/register_ops.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +FlagOS backend operator registrations. + +This module registers all DEFAULT (FlagOS) implementations. +""" + +from __future__ import annotations + +import functools + +from ...types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all FlagOS (DEFAULT) operator implementations. + + Args: + registry: Registry to register into + """ + from .flagos import FlagOSBackend + + # Create a backend instance to access the methods + backend = FlagOSBackend() + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + OpImpl( + op_name="rmsnorm_fwd", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="generic_gemm", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor=None, + priority=150, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor=None, + priority=150, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="get_fused_attn_backend", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor=None, + priority=150, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor=None, + priority=150, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/reference/__init__.py b/transformer_engine/plugin/core/backends/reference/__init__.py new file mode 100644 index 0000000000..08844be51b --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .reference import ReferenceBackend + +__all__ = ["ReferenceBackend"] diff --git a/transformer_engine/plugin/core/backends/reference/flash_attention.py b/transformer_engine/plugin/core/backends/reference/flash_attention.py new file mode 100644 index 0000000000..10a730ac52 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/flash_attention.py @@ -0,0 +1,424 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +from transformer_engine.plugin.core.ops import FlashAttentionBase +from transformer_engine.plugin.core.backends.fa_utils import ( + all_gather_along_seq, + reduce_scatter_along_seq, + create_cp_causal_mask, + create_cp_window_mask, + get_cp_info, +) + + +class FlashAttentionTorch(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + @property + def backend_name(self) -> str: + return "torch_sdpa" + + def _convert_layout_to_bhsd( + self, + tensor: torch.Tensor, + layout: str, + ) -> torch.Tensor: + """Convert tensor from various layouts to [batch, heads, seq, dim] format.""" + layout = layout.lower() + + # Handle combined layouts like "sbhd_sbhd_sbhd" - extract the first part + if "_" in layout: + layout = layout.split("_")[0] + + if layout in ("sbhd", "sbh3d", "sb3hd"): + return tensor.permute(1, 2, 0, 3) + elif layout in ("bshd", "bsh3d", "bs3hd"): + return tensor.permute(0, 2, 1, 3) + elif layout in ("bhsd",): + return tensor + elif layout in ("thd",): + # thd is packed format, should not reach here for 4D tensors + raise ValueError(f"thd layout requires 3D tensor, got {tensor.dim()}D") + else: + raise ValueError(f"Unsupported qkv_layout: {layout}") + + def _convert_bhsd_to_layout( + self, + tensor: torch.Tensor, + layout: str, + ) -> torch.Tensor: + """Convert tensor from [batch, heads, seq, dim] back to original layout.""" + layout = layout.lower() + + # Handle combined layouts like "sbhd_sbhd_sbhd" - extract the first part + if "_" in layout: + layout = layout.split("_")[0] + + if layout in ("sbhd", "sbh3d", "sb3hd"): + return tensor.permute(2, 0, 1, 3) + elif layout in ("bshd", "bsh3d", "bs3hd"): + return tensor.permute(0, 2, 1, 3) + elif layout in ("bhsd",): + return tensor + elif layout in ("thd",): + raise ValueError(f"thd layout requires 3D tensor, got {tensor.dim()}D") + else: + raise ValueError(f"Unsupported qkv_layout: {layout}") + + def _create_sliding_window_mask( + self, + seq_len_q: int, + seq_len_kv: int, + window_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Create a sliding window attention mask.""" + left_window, right_window = window_size + + if left_window == -1 and right_window == -1: + return torch.zeros(seq_len_q, seq_len_kv, dtype=dtype, device=device) + + q_idx = torch.arange(seq_len_q, device=device).unsqueeze(1) + kv_idx = torch.arange(seq_len_kv, device=device).unsqueeze(0) + + mask_bool = torch.zeros(seq_len_q, seq_len_kv, dtype=torch.bool, device=device) + + if left_window >= 0: + mask_bool = mask_bool | (kv_idx < q_idx - left_window) + + if right_window >= 0: + mask_bool = mask_bool | (kv_idx > q_idx + right_window) + + mask = torch.zeros(seq_len_q, seq_len_kv, dtype=dtype, device=device) + mask.masked_fill_(mask_bool, float("-inf")) + + return mask + + def _unpack_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert packed tensor to padded tensor format.""" + batch_size = cu_seqlens.shape[0] - 1 + device = tensor.device + original_shape = tensor.shape + + if tensor.dim() == 4: + if tensor.shape[1] == 1: + tensor = tensor.squeeze(1) + else: + raise ValueError( + f"Unexpected 4D tensor shape {original_shape}. " + "Expected [total_tokens, 1, num_heads, head_dim]" + ) + + if tensor.dim() != 3: + raise ValueError( + f"Expected tensor to be 3D or 4D after processing, got shape {original_shape}" + ) + + total_tokens, num_heads, head_dim = tensor.shape + + expected_total = cu_seqlens[-1].item() + if total_tokens != expected_total: + raise ValueError( + f"Tensor has {total_tokens} tokens but cu_seqlens indicates {expected_total} tokens" + ) + + padded_tensor = torch.zeros( + batch_size, num_heads, max_seqlen, head_dim, dtype=tensor.dtype, device=device + ) + + padding_mask = torch.ones(batch_size, max_seqlen, dtype=torch.bool, device=device) + + # Vectorized unpacking - avoid Python loop and .item() calls + cu_seqlens_cpu = cu_seqlens.cpu() + for i in range(batch_size): + start = cu_seqlens_cpu[i].item() + end = cu_seqlens_cpu[i + 1].item() + seq_len = end - start + + seq_data = tensor[start:end].permute(1, 0, 2) + padded_tensor[i, :, :seq_len, :] = seq_data + padding_mask[i, :seq_len] = False + + return padded_tensor, padding_mask + + def _pack_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + """Convert padded tensor back to packed tensor format.""" + batch_size = tensor.shape[0] + num_heads = tensor.shape[1] + head_dim = tensor.shape[3] + total_tokens = cu_seqlens[-1].item() + device = tensor.device + + packed_tensor = torch.zeros( + total_tokens, num_heads, head_dim, dtype=tensor.dtype, device=device + ) + + # Vectorized packing - avoid repeated .item() calls + cu_seqlens_cpu = cu_seqlens.cpu() + for i in range(batch_size): + start = cu_seqlens_cpu[i].item() + end = cu_seqlens_cpu[i + 1].item() + seq_len = end - start + + seq_data = tensor[i, :, :seq_len, :].permute(1, 0, 2) + packed_tensor[start:end, :, :] = seq_data + + return packed_tensor + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + """Flash Attention implementation using PyTorch's scaled_dot_product_attention. + + Supports Context Parallelism (CP) by all-gathering key/value across the CP group. + """ + if fp8: + raise NotImplementedError("FP8 is not supported in PyTorch SDPA backend") + + if alibi_slopes is not None: + raise NotImplementedError("ALiBi slopes are not supported in PyTorch SDPA backend") + + query_original_shape = query_layer.shape + cp_size, cp_rank, use_cp = get_cp_info(cp_group) + + is_standard_4d = query_layer.dim() == 4 + + if is_standard_4d: + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + use_packed_format = False + padding_mask_q = None + padding_mask_kv = None + else: + use_packed_format = cu_seqlens_q is not None or cu_seqlens_kv is not None + padding_mask_q = None + padding_mask_kv = None + + if use_packed_format: + if cu_seqlens_q is not None: + query, padding_mask_q = self._unpack_tensor( + query_layer, cu_seqlens_q, max_seqlen_q + ) + else: + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + + if cu_seqlens_kv is not None: + key, padding_mask_kv = self._unpack_tensor( + key_layer, cu_seqlens_kv, max_seqlen_kv + ) + value, _ = self._unpack_tensor(value_layer, cu_seqlens_kv, max_seqlen_kv) + else: + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + else: + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + + batch_size, num_heads_q, seq_len_q, head_dim = query.shape + local_seq_len_q = seq_len_q + + if use_cp: + # All-gather key/value along sequence dimension for full context + key = all_gather_along_seq(key, cp_group, seq_dim=2) + value = all_gather_along_seq(value, cp_group, seq_dim=2) + + num_heads_kv = key.shape[1] + seq_len_kv = key.shape[2] + + if num_heads_q != num_heads_kv: + num_groups = num_heads_q // num_heads_kv + if num_heads_q % num_heads_kv != 0: + raise ValueError( + f"num_heads_q ({num_heads_q}) must be divisible by num_heads_kv" + f" ({num_heads_kv})" + ) + key = key.repeat_interleave(num_groups, dim=1) + value = value.repeat_interleave(num_groups, dim=1) + + attn_mask = None + is_causal = False + + if use_packed_format and padding_mask_kv is not None: + attn_mask = torch.zeros( + batch_size, seq_len_q, seq_len_kv, dtype=query.dtype, device=query.device + ) + padding_broadcast = padding_mask_kv.unsqueeze(1) + attn_mask.masked_fill_(padding_broadcast, float("-inf")) + + if attn_mask_type == "causal": + if use_cp: + # Use shared utility for CP causal mask creation + causal_mask = create_cp_causal_mask( + local_seq_len_q, seq_len_kv, cp_rank, query.device, query.dtype + ) + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask + causal_mask + else: + attn_mask = attn_mask + causal_mask.unsqueeze(0) + else: + attn_mask = causal_mask + elif window_size is None and not use_packed_format: + is_causal = True + else: + causal_mask = torch.zeros( + seq_len_q, seq_len_kv, dtype=query.dtype, device=query.device + ) + causal_mask.masked_fill_( + torch.triu( + torch.ones(seq_len_q, seq_len_kv, device=query.device, dtype=torch.bool), + diagonal=1, + ), + float("-inf"), + ) + + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask + causal_mask + else: + attn_mask = attn_mask + causal_mask.unsqueeze(0) + else: + attn_mask = causal_mask + + if window_size is not None and not is_causal: + if use_cp: + # Use shared utility for CP window mask creation + window_mask = create_cp_window_mask( + local_seq_len_q, seq_len_kv, cp_rank, window_size, query.device, query.dtype + ) + else: + window_mask = self._create_sliding_window_mask( + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + window_size=window_size, + device=query.device, + dtype=query.dtype, + ) + + if attn_mask is not None: + attn_mask = ( + attn_mask + window_mask.unsqueeze(0) + if window_mask.dim() == 2 + else attn_mask + window_mask + ) + else: + attn_mask = window_mask + + if attention_mask is not None and attn_mask_type != "causal": + if isinstance(attention_mask, tuple): + explicit_mask = attention_mask[0] + else: + explicit_mask = attention_mask + + if explicit_mask.dtype == torch.bool: + float_mask = torch.zeros_like(explicit_mask, dtype=query.dtype) + float_mask.masked_fill_(~explicit_mask, float("-inf")) + explicit_mask = float_mask + + if explicit_mask.dim() == 2: + explicit_mask = explicit_mask.unsqueeze(0).unsqueeze(0) + elif explicit_mask.dim() == 3: + explicit_mask = explicit_mask.unsqueeze(1) + + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) + elif attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask + explicit_mask + else: + attn_mask = explicit_mask + elif attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) + elif attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(1) + + with self.attention_dropout_ctx(): + dropout_p = self.attention_dropout if self.training else 0.0 + + output = F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=self.softmax_scale, + ) + + if use_packed_format and padding_mask_q is not None: + mask_expanded = padding_mask_q.unsqueeze(1).unsqueeze(3) + output = output.masked_fill(mask_expanded, 0.0) + + if use_packed_format and cu_seqlens_q is not None: + output = self._pack_tensor(output, cu_seqlens_q) + + if len(query_original_shape) == 4: + total_tokens = output.shape[0] + hidden_size = output.shape[1] * output.shape[2] + output = output.contiguous().view(total_tokens, 1, hidden_size) + else: + output = self._convert_bhsd_to_layout(output, qkv_layout) + output = output.contiguous().view(*output.shape[:-2], -1) + + return output diff --git a/transformer_engine/plugin/core/backends/reference/impl/__init__.py b/transformer_engine/plugin/core/backends/reference/impl/__init__.py new file mode 100644 index 0000000000..f467767d61 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/__init__.py @@ -0,0 +1,111 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .gemm import general_gemm_torch + +from .rmsnorm import rmsnorm_fwd_torch, rmsnorm_bwd_torch +from .normalization import layernorm_fwd_torch, layernorm_bwd_torch + +from .activation import ( + gelu_torch, + geglu_torch, + qgelu_torch, + qgeglu_torch, + relu_torch, + reglu_torch, + srelu_torch, + sreglu_torch, + silu_torch, + swiglu_torch, + clamped_swiglu_torch, + dgelu_torch, + dgeglu_torch, + dqgelu_torch, + dqgeglu_torch, + drelu_torch, + dreglu_torch, + dsrelu_torch, + dsreglu_torch, + dsilu_torch, + dswiglu_torch, + clamped_dswiglu_torch, + dbias_dgelu_torch, + dbias_dsilu_torch, + dbias_drelu_torch, + dbias_dqgelu_torch, + dbias_dsrelu_torch, +) + +from .softmax import ( + scaled_softmax_forward_torch, + scaled_softmax_backward_torch, + scaled_masked_softmax_forward_torch, + scaled_masked_softmax_backward_torch, + scaled_upper_triang_masked_softmax_forward_torch, + scaled_upper_triang_masked_softmax_backward_torch, + scaled_aligned_causal_masked_softmax_forward_torch, + scaled_aligned_causal_masked_softmax_backward_torch, +) + +from .dropout import dropout_fwd_torch, dropout_bwd_torch + +from .optimizer import ( + multi_tensor_scale_torch, + multi_tensor_l2norm_torch, + multi_tensor_adam_torch, + multi_tensor_adam_param_remainder_torch, + multi_tensor_sgd_torch, + multi_tensor_compute_scale_and_scale_inv_torch, +) + +__all__ = [ + "general_gemm_torch", + "rmsnorm_fwd_torch", + "rmsnorm_bwd_torch", + "layernorm_fwd_torch", + "layernorm_bwd_torch", + "gelu_torch", + "geglu_torch", + "qgelu_torch", + "qgeglu_torch", + "relu_torch", + "reglu_torch", + "srelu_torch", + "sreglu_torch", + "silu_torch", + "swiglu_torch", + "clamped_swiglu_torch", + "dgelu_torch", + "dgeglu_torch", + "dqgelu_torch", + "dqgeglu_torch", + "drelu_torch", + "dreglu_torch", + "dsrelu_torch", + "dsreglu_torch", + "dsilu_torch", + "dswiglu_torch", + "clamped_dswiglu_torch", + "dbias_dgelu_torch", + "dbias_dsilu_torch", + "dbias_drelu_torch", + "dbias_dqgelu_torch", + "dbias_dsrelu_torch", + "scaled_softmax_forward_torch", + "scaled_softmax_backward_torch", + "scaled_masked_softmax_forward_torch", + "scaled_masked_softmax_backward_torch", + "scaled_upper_triang_masked_softmax_forward_torch", + "scaled_upper_triang_masked_softmax_backward_torch", + "scaled_aligned_causal_masked_softmax_forward_torch", + "scaled_aligned_causal_masked_softmax_backward_torch", + "dropout_fwd_torch", + "dropout_bwd_torch", + "multi_tensor_scale_torch", + "multi_tensor_l2norm_torch", + "multi_tensor_adam_torch", + "multi_tensor_adam_param_remainder_torch", + "multi_tensor_sgd_torch", + "multi_tensor_compute_scale_and_scale_inv_torch", +] diff --git a/transformer_engine/plugin/core/backends/reference/impl/activation.py b/transformer_engine/plugin/core/backends/reference/impl/activation.py new file mode 100644 index 0000000000..919c3718cb --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/activation.py @@ -0,0 +1,286 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Optional, Tuple +import torch +import torch.nn.functional as F + +__all__ = [ + "gelu_torch", + "geglu_torch", + "qgelu_torch", + "qgeglu_torch", + "relu_torch", + "reglu_torch", + "srelu_torch", + "sreglu_torch", + "silu_torch", + "swiglu_torch", + "clamped_swiglu_torch", + "dgelu_torch", + "dgeglu_torch", + "dqgelu_torch", + "dqgeglu_torch", + "drelu_torch", + "dreglu_torch", + "dsrelu_torch", + "dsreglu_torch", + "dsilu_torch", + "dswiglu_torch", + "clamped_dswiglu_torch", + "dbias_dgelu_torch", + "dbias_dsilu_torch", + "dbias_drelu_torch", + "dbias_dqgelu_torch", + "dbias_dsrelu_torch", +] + + +def gelu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return F.gelu(input, approximate="tanh") + + +def geglu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = input.chunk(2, dim=-1) + return F.gelu(a, approximate="tanh") * b + + +def qgelu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return input * torch.sigmoid(1.702 * input) + + +def qgeglu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = input.chunk(2, dim=-1) + return a * torch.sigmoid(1.702 * a) * b + + +def relu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return F.relu(input) + + +def reglu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = input.chunk(2, dim=-1) + return F.relu(a) * b + + +def srelu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return torch.square(F.relu(input)) + + +def sreglu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = input.chunk(2, dim=-1) + return torch.square(F.relu(a)) * b + + +def silu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return F.silu(input) + + +def swiglu_torch(input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = input.chunk(2, dim=-1) + return F.silu(a) * b + + +def clamped_swiglu_torch( + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, +) -> torch.Tensor: + """Clamped SwiGLU matching CUDA implementation. + + CUDA implementation: + - a (activation): clamp to upper bound only: min(a, limit) + - b (gate): clamp to [-limit, limit], then add 1 + - output = (a_clamped * sigmoid(alpha * a_clamped)) * b_clamped + """ + a, b = input.chunk(2, dim=-1) + # CUDA only clamps a to upper bound + a_clamped = torch.clamp(a, max=limit) + # CUDA clamps b to [-limit, limit] and adds 1 + b_clamped = torch.clamp(b, -limit, limit) + 1 + return a_clamped * torch.sigmoid(alpha * a_clamped) * b_clamped + + +def dgelu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + x = fwd_input.detach().requires_grad_(True) + with torch.enable_grad(): + y = F.gelu(x, approximate="tanh") + y.backward(grad) + return x.grad + + +def dgeglu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = fwd_input.chunk(2, dim=-1) + a = a.detach().requires_grad_(True) + b = b.detach().requires_grad_(True) + + with torch.enable_grad(): + y = F.gelu(a, approximate="tanh") * b + y.backward(grad) + + return torch.cat([a.grad, b.grad], dim=-1) + + +def dqgelu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + x = fwd_input.detach().requires_grad_(True) + with torch.enable_grad(): + y = x * torch.sigmoid(1.702 * x) + y.backward(grad) + return x.grad + + +def dqgeglu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = fwd_input.chunk(2, dim=-1) + a = a.detach().requires_grad_(True) + b = b.detach().requires_grad_(True) + + with torch.enable_grad(): + y = a * torch.sigmoid(1.702 * a) * b + y.backward(grad) + + return torch.cat([a.grad, b.grad], dim=-1) + + +def drelu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + return grad * (fwd_input > 0).to(grad.dtype) + + +def dreglu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = fwd_input.chunk(2, dim=-1) + + grad_a = grad * b * (a > 0).to(grad.dtype) + grad_b = grad * F.relu(a) + + return torch.cat([grad_a, grad_b], dim=-1) + + +def dsrelu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + relu_x = F.relu(fwd_input) + return 2 * grad * relu_x * (fwd_input > 0).to(grad.dtype) + + +def dsreglu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = fwd_input.chunk(2, dim=-1) + + relu_a = F.relu(a) + grad_a = grad * b * 2 * relu_a * (a > 0).to(grad.dtype) + grad_b = grad * torch.square(relu_a) + + return torch.cat([grad_a, grad_b], dim=-1) + + +def dsilu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + x = fwd_input.detach().requires_grad_(True) + with torch.enable_grad(): + y = F.silu(x) + y.backward(grad) + return x.grad + + +def dswiglu_torch(grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> torch.Tensor: + a, b = fwd_input.chunk(2, dim=-1) + a = a.detach().requires_grad_(True) + b = b.detach().requires_grad_(True) + + with torch.enable_grad(): + y = F.silu(a) * b + y.backward(grad) + + return torch.cat([a.grad, b.grad], dim=-1) + + +def clamped_dswiglu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, +) -> torch.Tensor: + """Backward pass for clamped SwiGLU matching CUDA implementation. + + CUDA implementation: + - a (activation): clamp to upper bound only, derivative is 0 if a > limit + - b (gate): clamp to [-limit, limit] and add 1, derivative is 0 outside range + """ + a, b = fwd_input.chunk(2, dim=-1) + + # CUDA only clamps a to upper bound + a_clamped = torch.clamp(a, max=limit) + # CUDA clamps b to [-limit, limit] and adds 1 + b_clamped = torch.clamp(b, -limit, limit) + 1 + + a_clamped = a_clamped.detach().requires_grad_(True) + b_clamped = b_clamped.detach().requires_grad_(True) + + with torch.enable_grad(): + y = a_clamped * torch.sigmoid(alpha * a_clamped) * b_clamped + y.backward(grad) + + # Derivative of a clamp (upper bound only): 0 if a > limit + grad_a = a_clamped.grad * (a <= limit).to(grad.dtype) + # Derivative of b clamp ([-limit, limit]): 0 outside range + grad_b = b_clamped.grad * ((b >= -limit) & (b <= limit)).to(grad.dtype) + + return torch.cat([grad_a, grad_b], dim=-1) + + +def dbias_dgelu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = dgelu_torch(grad, fwd_input, quantizer) + + grad_bias = grad.sum(dim=tuple(range(grad.ndim - 1))) + + return grad_input, grad_bias + + +def dbias_dsilu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = dsilu_torch(grad, fwd_input, quantizer) + + grad_bias = grad.sum(dim=tuple(range(grad.ndim - 1))) + + return grad_input, grad_bias + + +def dbias_drelu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = drelu_torch(grad, fwd_input, quantizer) + + grad_bias = grad.sum(dim=tuple(range(grad.ndim - 1))) + + return grad_input, grad_bias + + +def dbias_dqgelu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = dqgelu_torch(grad, fwd_input, quantizer) + + grad_bias = grad.sum(dim=tuple(range(grad.ndim - 1))) + + return grad_input, grad_bias + + +def dbias_dsrelu_torch( + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = dsrelu_torch(grad, fwd_input, quantizer) + + grad_bias = grad.sum(dim=tuple(range(grad.ndim - 1))) + + return grad_input, grad_bias diff --git a/transformer_engine/plugin/core/backends/reference/impl/dropout.py b/transformer_engine/plugin/core/backends/reference/impl/dropout.py new file mode 100644 index 0000000000..f671ff6c5d --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/dropout.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Optional, Tuple +import torch +import torch.nn.functional as F + +__all__ = [ + "dropout_fwd_torch", + "dropout_bwd_torch", +] + + +def dropout_fwd_torch( + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if dropout_probability == 0.0: + output = input.clone() if out is None else input.clone().to(out) + mask = torch.ones_like(input, dtype=torch.uint8) + return output, mask + + mask = torch.bernoulli(torch.full_like(input, 1.0 - dropout_probability)).to(torch.uint8) + + scale = 1.0 / (1.0 - dropout_probability) + output = input * mask.to(input.dtype) * scale + + if out is not None: + out.copy_(output) + output = out + + return output, mask + + +def dropout_bwd_torch( + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dropout_probability == 0.0: + return grad_output.clone() if grad_input is None else grad_output.clone().to(grad_input) + + scale = 1.0 / (1.0 - dropout_probability) + grad = grad_output * mask.to(grad_output.dtype) * scale + + if grad_input is not None: + grad_input.copy_(grad) + grad = grad_input + + return grad diff --git a/transformer_engine/plugin/core/backends/reference/impl/gemm.py b/transformer_engine/plugin/core/backends/reference/impl/gemm.py new file mode 100644 index 0000000000..65a3f1cc52 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/gemm.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Optional, Tuple, Union +import torch + +__all__ = [ + "general_gemm_torch", +] + +_DTYPE_TO_TORCH = { + 0: torch.uint8, + 2: torch.int32, + 4: torch.float32, + 5: torch.float16, + 6: torch.bfloat16, + 7: torch.float8_e4m3fn, + 8: torch.float8_e5m2, +} + + +def _convert_dtype(dtype: Union[int, torch.dtype, None]) -> Optional[torch.dtype]: + if dtype is None: + return None + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, int): + return _DTYPE_TO_TORCH.get(dtype, None) + if hasattr(dtype, "value"): + return _DTYPE_TO_TORCH.get(dtype.value, None) + return None + + +def general_gemm_torch( + A: torch.Tensor, + transA: bool, + B: torch.Tensor, + transB: bool, + D: Optional[torch.Tensor], + quantizer: Any, + output_dtype: Any, + bias: Optional[torch.Tensor], + bias_type: Any, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[Any] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + import torch.nn.functional as F + + target_device = B.device + + if A.device != target_device: + A = A.to(target_device) + + original_B_shape = None + if B.ndim == 3: + original_B_shape = B.shape + B = B.reshape(-1, B.shape[-1]) + + if A.ndim == 3: + A = A.reshape(-1, A.shape[-1]) + + A_comp = A.T if transA else A + B_comp = B.T if transB else B + + if A_comp.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + compute_dtype = torch.bfloat16 + A_comp = A_comp.to(compute_dtype) + B_comp = B_comp.to(compute_dtype) + + out = torch.mm(B_comp, A_comp) + + if alpha != 1.0: + out = out * alpha + + if original_B_shape is not None: + out = out.view(original_B_shape[0], original_B_shape[1], -1) + + gelu_input_ret = None + if gelu and gelu_in is not None: + pass + + if bias is not None: + if bias.device != target_device: + bias = bias.to(target_device) + out = out + bias + + if gelu: + if gelu_in is not None: + gelu_in.copy_(out) + gelu_input_ret = gelu_in + else: + gelu_input_ret = out.clone() + out = F.gelu(out, approximate="tanh") + + torch_out_dtype = _convert_dtype(output_dtype) + if torch_out_dtype is not None and out.dtype != torch_out_dtype: + out = out.to(torch_out_dtype) + + if D is not None: + if D.device != target_device: + D = D.to(target_device) + if accumulate: + beta_val = beta if beta is not None else 1.0 + D.mul_(beta_val).add_(out) + out = D + else: + D.copy_(out) + out = D + + bias_grad = None + if grad and bias is not None: + pass + + extra_output_ret = None + + return out, bias_grad, gelu_input_ret, extra_output_ret diff --git a/transformer_engine/plugin/core/backends/reference/impl/normalization.py b/transformer_engine/plugin/core/backends/reference/impl/normalization.py new file mode 100644 index 0000000000..c9ca2e1ae3 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/normalization.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Optional, Tuple +import torch +import torch.nn.functional as F +from ....ops import DType + +__all__ = [ + "layernorm_fwd_torch", + "layernorm_bwd_torch", +] + +# Mapping from DType enum to torch.dtype +_DTYPE_TO_TORCH_DTYPE = { + DType.kByte: torch.uint8, + DType.kInt16: torch.int16, + DType.kInt32: torch.int32, + DType.kInt64: torch.int64, + DType.kFloat32: torch.float32, + DType.kFloat16: torch.float16, + DType.kBFloat16: torch.bfloat16, + DType.kFloat8E4M3: torch.float8_e4m3fn, + DType.kFloat8E5M2: torch.float8_e5m2, +} + + +def _to_torch_dtype(dtype): + """Convert DType enum to torch.dtype.""" + if dtype is None: + return None + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, (int, DType)): + dtype_enum = DType(dtype) + if dtype_enum in _DTYPE_TO_TORCH_DTYPE: + return _DTYPE_TO_TORCH_DTYPE[dtype_enum] + raise ValueError(f"Unsupported dtype: {dtype}") + + +def layernorm_fwd_torch( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Optional[torch.Tensor], + quantizer: Any, + odtype: DType, + sm_margin: int, + zero_centered_gamma: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + odtype = _to_torch_dtype(odtype) + mean = input.mean(dim=-1, keepdim=True) + var = input.var(dim=-1, keepdim=True, unbiased=False) + rsigma = torch.rsqrt(var + eps) + + normalized = (input - mean) * rsigma + + if zero_centered_gamma: + output = normalized * (1.0 + weight) + else: + output = normalized * weight + + if bias is not None: + output = output + bias + + if output.dtype != odtype: + output = output.to(odtype) + + mean = mean.squeeze(-1) + rsigma = rsigma.squeeze(-1) + + return output, mean, rsigma + + +def layernorm_bwd_torch( + dy: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int = 0, + zero_centered_gamma: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if mu.ndim < x.ndim: + mu = mu.unsqueeze(-1) + if rsigma.ndim < x.ndim: + rsigma = rsigma.unsqueeze(-1) + + x_normalized = (x - mu) * rsigma + + N = x.shape[-1] + + if zero_centered_gamma: + gamma_adj = 1.0 + gamma + else: + gamma_adj = gamma + + dy_gamma = dy * gamma_adj + + mean_dy_gamma = dy_gamma.mean(dim=-1, keepdim=True) + + mean_dy_gamma_x = (dy_gamma * x_normalized).mean(dim=-1, keepdim=True) + + dx = rsigma * (dy_gamma - mean_dy_gamma - x_normalized * mean_dy_gamma_x) + + dgamma = (dy * x_normalized).sum(dim=tuple(range(dy.ndim - 1))) + + dbeta = dy.sum(dim=tuple(range(dy.ndim - 1))) + + return dx, dgamma, dbeta diff --git a/transformer_engine/plugin/core/backends/reference/impl/optimizer.py b/transformer_engine/plugin/core/backends/reference/impl/optimizer.py new file mode 100644 index 0000000000..890ae9a563 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/optimizer.py @@ -0,0 +1,394 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import List, Tuple, Union +import torch + +__all__ = [ + "multi_tensor_scale_torch", + "multi_tensor_l2norm_torch", + "multi_tensor_adam_torch", + "multi_tensor_adam_param_remainder_torch", + "multi_tensor_sgd_torch", + "multi_tensor_compute_scale_and_scale_inv_torch", +] + + +def multi_tensor_scale_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, +) -> None: + if noop_flag.item() != 0: + return + + if len(tensor_lists) != 2: + raise ValueError("tensor_lists should contain [input_tensors, output_tensors]") + + input_tensors, output_tensors = tensor_lists + + if len(output_tensors) != len(input_tensors): + raise ValueError("Output and input tensor lists must have the same length") + + for in_tensor, out_tensor in zip(input_tensors, output_tensors): + # Check for inf/nan (matches CUDA behavior for AMP gradient scaling) + if not torch.isfinite(in_tensor).all(): + noop_flag.fill_(1) + out_tensor.copy_(in_tensor * scale) + + +def multi_tensor_l2norm_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute L2 norm of tensors. + + Returns: + Tuple of (total_norm, per_tensor_norms_or_dummy) + - total_norm: The combined L2 norm of all tensors + - per_tensor_norms_or_dummy: Per-tensor norms stacked if per_tensor=True, else dummy tensor + """ + device = tensor_lists[0][0].device if tensor_lists and tensor_lists[0] else "cpu" + + if noop_flag.item() != 0: + return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device) + + tensors = tensor_lists[0] + + # Compute per-tensor norms + per_tensor_norms = [] + total_norm_sq = torch.tensor(0.0, device=device) + + for tensor in tensors: + norm_sq = torch.sum(tensor.float() ** 2) + # Check for inf/nan (matches CUDA behavior) + if not torch.isfinite(norm_sq): + noop_flag.fill_(1) + total_norm_sq = total_norm_sq + norm_sq + if per_tensor: + per_tensor_norms.append(torch.sqrt(norm_sq)) + + total_norm = torch.sqrt(total_norm_sq) + + if per_tensor: + per_tensor_result = torch.stack(per_tensor_norms) + else: + per_tensor_result = torch.tensor(0.0, device=device) + + return total_norm, per_tensor_result + + +def multi_tensor_adam_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, +) -> None: + """ + Adam optimizer implementation matching CUDA exactly. + + mode == 0: L2 regularization (add weight_decay * param to gradient before moment update) + mode == 1: AdamW (add weight_decay * param to update after moment computation) + """ + if noop_flag.item() != 0: + return + + if len(tensor_lists) != 4: + raise ValueError("tensor_lists should contain [grads, params, exp_avgs, exp_avg_sqs]") + + grads, params, exp_avgs, exp_avg_sqs = tensor_lists + + if not (len(params) == len(grads) == len(exp_avgs) == len(exp_avg_sqs)): + raise ValueError("All tensor lists must have the same length") + + if bias_correction: + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + else: + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + for grad, param, exp_avg, exp_avg_sq in zip(grads, params, exp_avgs, exp_avg_sqs): + if grad is None: + continue + + # Convert to float for computation (matches CUDA's MATH_T = float) + g = grad.float() + p = param.float() + + if mode == 0: # L2 regularization + # Add weight decay to gradient before moment update + g = g + weight_decay * p + + # Update moments with modified gradient + exp_avg.mul_(beta1).add_(g, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2) + + # Bias correction + m_corr = exp_avg / bias_correction1 + v_corr = exp_avg_sq / bias_correction2 + + # Compute update + denom = v_corr.sqrt().add_(epsilon) + update = m_corr / denom + + # Update parameter + param.add_(update, alpha=-lr) + else: # mode == 1, AdamW (decoupled weight decay) + # Update moments with original gradient + exp_avg.mul_(beta1).add_(g, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2) + + # Bias correction + m_corr = exp_avg / bias_correction1 + v_corr = exp_avg_sq / bias_correction2 + + # Compute update with weight decay added (matches CUDA exactly) + denom = v_corr.sqrt().add_(epsilon) + update = (m_corr / denom) + (weight_decay * p) + + # Update parameter + param.add_(update, alpha=-lr) + + +def multi_tensor_adam_param_remainder_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, +) -> None: + """ + Adam optimizer with parameter remainders for BF16 precision. + + This variant stores BF16 parameters + int16 remainders to reconstruct FP32 master weights + using bit manipulation, matching the CUDA implementation exactly. + + The CUDA implementation stores: + - p: int16 representing the high 16 bits of FP32 (viewed as BF16) + - p_remainder: int16 representing the low 16 bits of FP32 + + To reconstruct FP32: + - If p_remainder < 0, decrement p (undo rounding) + - Combine: fp32.int16[1] = p, fp32.int16[0] = p_remainder + + To split FP32 back: + - p = fp32.int16[1] (high 16 bits) + - p_remainder = fp32.int16[0] (low 16 bits) + - If p_remainder < 0, increment p (round up) + + Args: + chunk_size: Chunk size for processing (unused in PyTorch implementation) + noop_flag: If non-zero, skip computation + tensor_lists: [grads, params (int16/bf16), exp_avgs (fp32), exp_avg_sqs (fp32), param_remainders (int16)] + lr: Learning rate + beta1: First moment decay rate + beta2: Second moment decay rate + epsilon: Epsilon for numerical stability + step: Current optimization step + mode: 0 = L2 regularization, 1 = AdamW (decoupled weight decay) + bias_correction: Whether to apply bias correction (1 = yes, 0 = no) + weight_decay: Weight decay coefficient + """ + if noop_flag.item() != 0: + return + + if len(tensor_lists) != 5: + raise ValueError( + "tensor_lists should contain [grads, params, exp_avgs, exp_avg_sqs, param_remainders]" + ) + + grads, params, exp_avgs, exp_avg_sqs, param_remainders = tensor_lists + + if not ( + len(params) == len(grads) == len(exp_avgs) == len(exp_avg_sqs) == len(param_remainders) + ): + raise ValueError("All tensor lists must have the same length") + + if bias_correction: + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + else: + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + is_adamw = mode == 1 + + for grad, param, exp_avg, exp_avg_sq, param_remainder in zip( + grads, params, exp_avgs, exp_avg_sqs, param_remainders + ): + # Convert gradient to float + g_float = grad.float() + + # Reconstruct FP32 master weight from int16 param + int16 remainder using bit manipulation + # This matches the CUDA implementation exactly: + # 1. If p_remainder < 0, decrement p (undo rounding) + # 2. Combine high 16 bits (p) and low 16 bits (p_remainder) into FP32 + + local_p = param.view(torch.int16).clone() + local_p_rem = param_remainder.clone() + + # Undo rounding: if remainder < 0, decrement p + local_p = torch.where(local_p_rem < 0, local_p - 1, local_p) + + # Combine into FP32 using bit shift operations + # local_p is high 16 bits, local_p_rem is low 16 bits + high_bits = local_p.to(torch.int32) << 16 + low_bits = local_p_rem.to(torch.int32) & 0xFFFF # Mask off sign extension + param_int32 = high_bits | low_bits + param_master = param_int32.view(torch.float32) + + # L2 mode: add weight decay to gradient before updating moments + if not is_adamw and weight_decay != 0: + g_float = g_float + weight_decay * param_master + + # Update first moment: m = beta1 * m + (1 - beta1) * g + exp_avg.mul_(beta1).add_(g_float, alpha=1 - beta1) + + # Update second moment: v = beta2 * v + (1 - beta2) * g^2 + exp_avg_sq.mul_(beta2).addcmul_(g_float, g_float, value=1 - beta2) + + # Apply bias correction + m_corr = exp_avg / bias_correction1 + v_corr = exp_avg_sq / bias_correction2 + + # Compute denominator: sqrt(v_corr) + epsilon + denom = torch.sqrt(v_corr) + epsilon + + # Compute update + update = m_corr / denom + + # AdamW mode: add decoupled weight decay to update + if is_adamw and weight_decay != 0: + update = update + weight_decay * param_master + + # Update master weight: p = p - lr * update + param_master = param_master - lr * update + + # Split FP32 back into int16 param + int16 remainder using bit manipulation + # This matches the CUDA implementation exactly: + # 1. Extract high 16 bits as p + # 2. Extract low 16 bits as p_remainder + # 3. If p_remainder < 0, increment p (round up) + + param_int32 = param_master.view(torch.int32) + # Extract low 16 bits (remainder) and high 16 bits (param) + new_p_rem = (param_int32 & 0xFFFF).to(torch.int16) + new_p = ((param_int32 >> 16) & 0xFFFF).to(torch.int16) + + # Round up: if remainder < 0, increment p + new_p = torch.where(new_p_rem < 0, new_p + 1, new_p) + + # Write back + param.view(torch.int16).copy_(new_p) + param_remainder.copy_(new_p_rem) + + +def multi_tensor_sgd_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + momentum: float, + dampening: float, + weight_decay: float, + nesterov: bool, +) -> None: + if noop_flag.item() != 0: + return + + if len(tensor_lists) != 3: + raise ValueError("tensor_lists should contain [params, grads, momentum_buffers]") + + params, grads, momentum_buffers = tensor_lists + + if not (len(params) == len(grads) == len(momentum_buffers)): + raise ValueError("All tensor lists must have the same length") + + for param, grad, buf in zip(params, grads, momentum_buffers): + if grad is None: + continue + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if momentum != 0: + if buf is None or buf.numel() == 0: + buf = grad.clone().detach() + else: + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + param.add_(grad, alpha=-lr) + + +def multi_tensor_compute_scale_and_scale_inv_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> None: + """ + Compute scale and scale_inv from amax values for FP8 quantization. + + Args: + chunk_size: Chunk size (unused in PyTorch implementation) + noop_flag: If non-zero, skip computation + tensor_lists: [amaxes, scales, scale_invs] + max_fp8: Maximum representable value in FP8 format (e.g., 448.0 for E4M3) + force_pow_2_scales: If True, force scales to be powers of 2 + amax_epsilon: Small epsilon to add to amax to avoid division by zero + """ + if noop_flag.item() != 0: + return + + if len(tensor_lists) != 3: + raise ValueError("tensor_lists should contain [amaxes, scales, scale_invs]") + + amaxes, scales, scale_invs = tensor_lists + + if not (len(amaxes) == len(scales) == len(scale_invs)): + raise ValueError("All tensor lists must have the same length") + + for amax, scale, scale_inv in zip(amaxes, scales, scale_invs): + # Add epsilon to avoid division by zero + amax_val = amax + amax_epsilon + + # Compute scale: max_fp8 / amax + # Clamp amax to avoid very small values + amax_val = torch.clamp(amax_val, min=1e-12) + computed_scale = max_fp8 / amax_val + + if force_pow_2_scales: + # Round scale to nearest power of 2 + log2_scale = torch.log2(computed_scale) + log2_scale = torch.round(log2_scale) + computed_scale = torch.pow(2.0, log2_scale) + + # Update scale and scale_inv + scale.copy_(computed_scale) + scale_inv.copy_(1.0 / computed_scale) diff --git a/transformer_engine/plugin/core/backends/reference/impl/rmsnorm.py b/transformer_engine/plugin/core/backends/reference/impl/rmsnorm.py new file mode 100644 index 0000000000..0aebdae2fe --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/rmsnorm.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import torch + +__all__ = [ + "rmsnorm_fwd_torch", + "rmsnorm_bwd_torch", +] + + +def rmsnorm_fwd_torch( + input, + weight, + eps, + ln_out, + quantizer, + odtype, + sm_margin, + zero_centered_gamma, +): + if weight.device != input.device: + weight = weight.to(input.device) + + variance = input.pow(2).mean(-1, keepdim=True) + inv_rms = torch.rsqrt(variance + eps) + y = input * inv_rms + if zero_centered_gamma: + y = y * (1 + weight) + else: + y = y * weight + + rstdevs = inv_rms.squeeze(-1) + + return y, None, rstdevs + + +def rmsnorm_bwd_torch( + dy, + x, + rsigma, + gamma, + sm_margin, + zero_centered_gamma, +): + inv_rms = rsigma.unsqueeze(-1) + + x_norm = x * inv_rms + + if zero_centered_gamma: + weight = 1 + gamma + else: + weight = gamma + + dw = (dy * x_norm).sum(dim=tuple(range(dy.ndim - 1))) + + dy_weighted = dy * weight + + mean_term = (dy_weighted * x_norm).mean(-1, keepdim=True) + dx = inv_rms * (dy_weighted - x_norm * mean_term) + return dx, dw diff --git a/transformer_engine/plugin/core/backends/reference/impl/softmax.py b/transformer_engine/plugin/core/backends/reference/impl/softmax.py new file mode 100644 index 0000000000..2689ab938a --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/impl/softmax.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Optional +import torch +import torch.nn.functional as F + +__all__ = [ + "scaled_softmax_forward_torch", + "scaled_softmax_backward_torch", + "scaled_masked_softmax_forward_torch", + "scaled_masked_softmax_backward_torch", + "scaled_upper_triang_masked_softmax_forward_torch", + "scaled_upper_triang_masked_softmax_backward_torch", + "scaled_aligned_causal_masked_softmax_forward_torch", + "scaled_aligned_causal_masked_softmax_backward_torch", +] + + +def scaled_softmax_forward_torch(input: torch.Tensor, scale: float) -> torch.Tensor: + return F.softmax(input * scale, dim=-1) + + +def scaled_softmax_backward_torch( + output_grad: torch.Tensor, + softmax_output: torch.Tensor, + scale: float, +) -> torch.Tensor: + # Compute in float32 for numerical stability (matching CUDA behavior) + orig_dtype = output_grad.dtype + output_grad_f32 = output_grad.float() + softmax_output_f32 = softmax_output.float() + + grad_softmax = softmax_output_f32 * ( + output_grad_f32 - (softmax_output_f32 * output_grad_f32).sum(dim=-1, keepdim=True) + ) + + return (grad_softmax * scale).to(orig_dtype) + + +def scaled_masked_softmax_forward_torch( + input: torch.Tensor, + mask: torch.Tensor, + scale: float, +) -> torch.Tensor: + """Reference forward matching TE CUDA `scaled_masked_softmax_warp_forward`. + + Integer/bool mask (same as uint8 kernel contract): + - **Exactly** ``mask == 1`` means **masked** (logit set to ``-10000``, not ``input*scale`` offset). + - Any other value (typically 0) means **unmasked** (logit is ``input * scale``). + + Floating mask: treated as **additive** bias in logit space (already scaled), added after + ``input * scale``. + + Common pitfalls this avoids vs the old implementation: + 1) ``input * scale + (-10000)`` on masked positions ≠ CUDA's plain ``-10000``. + 2) Non-uint8 masks (bool, int) were used as direct addends → wrong (0/1 added to logits). + 3) ``mask.bool()`` masks any nonzero byte; CUDA only masks when ``mask == 1``. + """ + if mask.dim() == 4 and mask.size(1) == 1 and input.dim() == 4: + mask = mask.expand_as(input) + + scaled = input * scale + + if mask.is_floating_point(): + scaled = scaled + mask.to(dtype=scaled.dtype) + return F.softmax(scaled, dim=-1) + + # Integer / bool: align with CUDA (masked iff value == 1) + scaled = scaled.masked_fill(mask == 1, -10000.0) + # CUDA zeros output row when every position in the softmax dim is masked (max == -10000) + all_masked = (mask == 1).all(dim=-1, keepdim=True) + out = F.softmax(scaled, dim=-1) + return out.masked_fill(all_masked, 0.0) + + +def scaled_masked_softmax_backward_torch( + output_grad: torch.Tensor, + softmax_output: torch.Tensor, + scale: float, +) -> torch.Tensor: + # Compute in float32 for numerical stability (matching CUDA behavior) + orig_dtype = output_grad.dtype + output_grad_f32 = output_grad.float() + softmax_output_f32 = softmax_output.float() + + grad_softmax = softmax_output_f32 * ( + output_grad_f32 - (softmax_output_f32 * output_grad_f32).sum(dim=-1, keepdim=True) + ) + + return (grad_softmax * scale).to(orig_dtype) + + +def scaled_upper_triang_masked_softmax_forward_torch( + input: torch.Tensor, + scale: float, +) -> torch.Tensor: + seq_len = input.size(-1) + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=input.device, dtype=input.dtype), + diagonal=1, + ) + + scaled_input = input * scale + causal_mask + + return F.softmax(scaled_input, dim=-1) + + +def scaled_upper_triang_masked_softmax_backward_torch( + output_grad: torch.Tensor, + softmax_output: torch.Tensor, + scale: float, +) -> torch.Tensor: + # Compute in float32 for numerical stability (matching CUDA behavior) + orig_dtype = output_grad.dtype + output_grad_f32 = output_grad.float() + softmax_output_f32 = softmax_output.float() + + grad_softmax = softmax_output_f32 * ( + output_grad_f32 - (softmax_output_f32 * output_grad_f32).sum(dim=-1, keepdim=True) + ) + + return (grad_softmax * scale).to(orig_dtype) + + +def scaled_aligned_causal_masked_softmax_forward_torch( + input: torch.Tensor, + scale: float, +) -> torch.Tensor: + return scaled_upper_triang_masked_softmax_forward_torch(input, scale) + + +def scaled_aligned_causal_masked_softmax_backward_torch( + output_grad: torch.Tensor, + softmax_output: torch.Tensor, + scale: float, +) -> torch.Tensor: + # Compute in float32 for numerical stability (matching CUDA behavior) + orig_dtype = output_grad.dtype + output_grad_f32 = output_grad.float() + softmax_output_f32 = softmax_output.float() + + grad_softmax = softmax_output_f32 * ( + output_grad_f32 - (softmax_output_f32 * output_grad_f32).sum(dim=-1, keepdim=True) + ) + + return (grad_softmax * scale).to(orig_dtype) diff --git a/transformer_engine/plugin/core/backends/reference/reference.py b/transformer_engine/plugin/core/backends/reference/reference.py new file mode 100644 index 0000000000..7f52c41677 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/reference.py @@ -0,0 +1,608 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +from typing import Any, List, Optional, Tuple +import torch +from ...ops import * + +from .impl import ( + general_gemm_torch, + rmsnorm_fwd_torch, + rmsnorm_bwd_torch, + layernorm_fwd_torch, + layernorm_bwd_torch, + gelu_torch, + geglu_torch, + qgelu_torch, + qgeglu_torch, + relu_torch, + reglu_torch, + srelu_torch, + sreglu_torch, + silu_torch, + swiglu_torch, + clamped_swiglu_torch, + dgelu_torch, + dgeglu_torch, + dqgelu_torch, + dqgeglu_torch, + drelu_torch, + dreglu_torch, + dsrelu_torch, + dsreglu_torch, + dsilu_torch, + dswiglu_torch, + clamped_dswiglu_torch, + dbias_dgelu_torch, + dbias_dsilu_torch, + dbias_drelu_torch, + dbias_dqgelu_torch, + dbias_dsrelu_torch, + scaled_softmax_forward_torch, + scaled_softmax_backward_torch, + scaled_masked_softmax_forward_torch, + scaled_masked_softmax_backward_torch, + scaled_upper_triang_masked_softmax_forward_torch, + scaled_upper_triang_masked_softmax_backward_torch, + scaled_aligned_causal_masked_softmax_forward_torch, + scaled_aligned_causal_masked_softmax_backward_torch, + dropout_fwd_torch, + dropout_bwd_torch, + multi_tensor_scale_torch, + multi_tensor_l2norm_torch, + multi_tensor_adam_torch, + multi_tensor_adam_param_remainder_torch, + multi_tensor_sgd_torch, +) + + +class ReferenceBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return True + + def is_available(self) -> bool: + return True + + def get_attention_backend(self, _attention_params=None): + from packaging.version import Version as PkgVersion + from ...logger_manager import get_logger + + logger = get_logger() + + # Read environment variables to determine which backends to enable + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + + # Log disabled backends + if not use_flash_attention: + logger.info_once("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.info_once("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.info_once("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + flash_attention_backend = PkgVersion("2.6.0") if use_flash_attention else None + fused_attention_backend = NVTE_Fused_Attn_Backend.NVTE_No_Backend + + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + return ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + return general_gemm_torch( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GELU and variants + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + return gelu_torch(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + return geglu_torch(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + return qgelu_torch(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + return qgeglu_torch(input, quantizer) + + # ReLU and variants + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + return relu_torch(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + return reglu_torch(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + return srelu_torch(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + return sreglu_torch(input, quantizer) + + # SwiGLU and variants + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + return silu_torch(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + return swiglu_torch(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + return clamped_swiglu_torch(input, quantizer, limit, alpha) + + # Backward of GELU and variants + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dgelu_torch(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dgeglu_torch(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dqgelu_torch(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dqgeglu_torch(grad, fwd_input, quantizer) + + # Backward of ReLU and variants + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return drelu_torch(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dreglu_torch(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dsrelu_torch(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dsreglu_torch(grad, fwd_input, quantizer) + + # Backward of SiLU and variants + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dsilu_torch(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + return dswiglu_torch(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + return clamped_dswiglu_torch(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions + def dbias_dgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + return dbias_dgelu_torch(grad, fwd_input, quantizer) + + def dbias_dsilu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + return dbias_dsilu_torch(grad, fwd_input, quantizer) + + def dbias_drelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Tuple[torch.Tensor, Any]: + return dbias_drelu_torch(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + return dbias_dqgelu_torch(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + return dbias_dsrelu_torch(grad, fwd_input, quantizer) + + # LayerNorm + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return layernorm_fwd_torch( + input=input, + weight=weight, + bias=bias, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + odtype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return layernorm_bwd_torch( + dy=dz, + x=x, + mu=mu, + rsigma=rsigma, + gamma=gamma, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + # RMSNorm + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return rmsnorm_fwd_torch( + input=input, + weight=weight, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + odtype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + return rmsnorm_bwd_torch( + dy=dz, + x=x, + rsigma=rsigma, + gamma=gamma, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + ) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + return scaled_softmax_forward_torch(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_softmax_backward_torch(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_masked_softmax_forward_torch(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_masked_softmax_backward_torch(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_upper_triang_masked_softmax_forward_torch(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_upper_triang_masked_softmax_backward_torch( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_aligned_causal_masked_softmax_forward_torch(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + return scaled_aligned_causal_masked_softmax_backward_torch( + output_grad_, softmax_results_, scale_factor + ) + + # Fused attention backend + def get_fused_attn_backend( + self, + _is_training: bool, + _q_dtype: DType, + _kv_dtype: DType, + _qkv_layout: NVTE_QKV_Layout, + _bias_type: NVTE_Bias_Type, + _attn_mask_type: NVTE_Mask_Type, + _softmax_type: NVTE_Softmax_Type, + _p_dropout: float, + _num_attn_heads: int, + _num_gqa_groups: int, + _max_seqlen_q: int, + _max_seqlen_kv: int, + _head_dim_qk: int, + _head_dim_v: int, + _window_size_left: int, + _window_size_right: int, + _return_max_logit: bool, + _cuda_graph: bool = False, + _deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + return NVTE_Fused_Attn_Backend.NVTE_No_Backend + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + return dropout_fwd_torch(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + return dropout_bwd_torch(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + return 0 + + def get_cudnn_version(self) -> int: + return 0 + + def get_num_cublas_streams(self) -> int: + return 4 # keep consistent with transformer_engine/common/util/multi_stream.cpp, get_num_compute_streams() + + # Multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + return multi_tensor_scale_torch(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return multi_tensor_l2norm_torch(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if noop_flag.item() != 0: + device = tensor_lists[0][0].device if tensor_lists and tensor_lists[0] else "cpu" + return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device) + + # Multiply by inv_scale + unscaled_tensors = [] + for tensor in tensor_lists[0]: + unscaled_tensors.append(tensor * inv_scale.item()) + + return multi_tensor_l2norm_torch(chunk_size, noop_flag, [unscaled_tensors], per_tensor) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + return multi_tensor_adam_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + return multi_tensor_adam_param_remainder_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + return multi_tensor_sgd_torch( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionTorch + + return FlashAttentionTorch diff --git a/transformer_engine/plugin/core/backends/reference/register_ops.py b/transformer_engine/plugin/core/backends/reference/register_ops.py new file mode 100644 index 0000000000..0151ec00f9 --- /dev/null +++ b/transformer_engine/plugin/core/backends/reference/register_ops.py @@ -0,0 +1,491 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Reference backend operator registrations. + +This module registers all REFERENCE (PyTorch) implementations. +""" + +from __future__ import annotations + +import functools + +from ...types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all PyTorch (REFERENCE) operator implementations. + + Args: + registry: Registry to register into + """ + from .reference import ReferenceBackend + + # Create a backend instance to access the methods + backend = ReferenceBackend() + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor=None, + priority=50, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor=None, + priority=50, + ), + # Activations - Forward + OpImpl( + op_name="gelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.gelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="geglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.geglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="qgelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="qgeglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="relu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.relu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="reglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.reglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="srelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.srelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="sreglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="silu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.silu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="swiglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor=None, + priority=50, + ), + # Activations - Backward + OpImpl( + op_name="dgelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dgeglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dqgelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dqgeglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="drelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.drelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dreglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dsrelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dsreglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dsilu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dswiglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor=None, + priority=50, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor=None, + priority=50, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor=None, + priority=50, + ), + # Fused attention backend getter + OpImpl( + op_name="get_fused_attn_backend", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor=None, + priority=50, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor=None, + priority=50, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor=None, + priority=50, + ), + # Multi-tensor optimizer operations + OpImpl( + op_name="multi_tensor_scale", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor=None, + priority=50, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor=None, + priority=50, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor=None, + priority=50, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/__init__.py b/transformer_engine/plugin/core/backends/vendor/__init__.py new file mode 100644 index 0000000000..f94a17b393 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Vendor-specific backend implementations. + +This package contains hardware vendor-specific backend implementations +for TransformerEngine-FL. Each vendor subdirectory should contain its +own backend implementation. +""" + +from __future__ import annotations + +import os + +_vendor_loading_errors = [] + +try: + from ..._build_config import SKIP_CUDA_BUILD as _SKIP_CUDA_BUILD_CONFIG +except ImportError: + _SKIP_CUDA_BUILD_CONFIG = bool(int(os.environ.get("TE_FL_SKIP_CUDA", "0"))) + print(f"Build config not found, using env var: SKIP_CUDA_BUILD={_SKIP_CUDA_BUILD_CONFIG}") + +if os.environ.get("TE_FL_SKIP_CUDA"): + _SKIP_CUDA_BUILD = bool(int(os.environ.get("TE_FL_SKIP_CUDA", "0"))) +else: + _SKIP_CUDA_BUILD = _SKIP_CUDA_BUILD_CONFIG + +if not _SKIP_CUDA_BUILD: + try: + from .cuda import CUDABackend + except ImportError as e: + _vendor_loading_errors.append(("cuda", "ImportError", str(e))) + print(f"Failed to import CUDA vendor backend: {e}") + except Exception as e: + _vendor_loading_errors.append(("cuda", type(e).__name__, str(e))) + print(f"Error loading CUDA vendor backend: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() +else: + print("CUDA vendor backend skipped (CUDA build was disabled at build time)") + _vendor_loading_errors.append(("cuda", "Skipped", "CUDA build was disabled at build time")) + + +def get_vendor_loading_errors(): + """Get errors that occurred during vendor backend loading.""" + return _vendor_loading_errors.copy() + + +__all__ = ["get_vendor_loading_errors"] diff --git a/transformer_engine/plugin/core/backends/vendor/cuda/__init__.py b/transformer_engine/plugin/core/backends/vendor/cuda/__init__.py new file mode 100644 index 0000000000..8b8b610b6b --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/cuda/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .cuda import CUDABackend + +__all__ = ["CUDABackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py b/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py new file mode 100644 index 0000000000..69c64ec0e2 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/cuda/cuda.py @@ -0,0 +1,2006 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. +import os +import sys +from typing import Any, Dict, List, Optional, Tuple, Union +import torch +from ....ops import * + + +def _load_cuda_libs(): + import ctypes + import os + import subprocess + from pathlib import Path + import importlib.util + import platform + import glob as glob_module + + def get_ext(): + system = platform.system() + return ".so" if system == "Linux" else ".dylib" if system == "Darwin" else ".dll" + + ext = get_ext() + + def try_load_lib(name, search_patterns): + for env_var in [f"{name.upper()}_HOME", f"{name.upper()}_PATH"]: + path = os.environ.get(env_var) + if path: + libs = glob_module.glob(f"{path}/**/lib{name}{ext}*", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + for pattern in search_patterns: + libs = glob_module.glob(f"{cuda_home}/**/{pattern}", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + result = subprocess.check_output(f"ldconfig -p | grep 'lib{name}{ext}'", shell=True) + for line in result.decode().split("\n"): + if f"lib{name}" in line and "=>" in line: + so_path = line.split(">")[1].strip() + if so_path: + return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + return ctypes.CDLL(f"lib{name}{ext}", mode=ctypes.RTLD_GLOBAL) + except: + return None + + try: + try_load_lib("cudnn", [f"libcudnn{ext}*"]) + try_load_lib("nvrtc", [f"libnvrtc{ext}*"]) + try_load_lib("curand", [f"libcurand{ext}*"]) + + te_path_override = os.environ.get("TE_LIB_PATH") + if te_path_override: + te_path = Path(te_path_override) + else: + te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent + for search_dir in [te_path, te_path / "transformer_engine"]: + if search_dir.exists(): + matches = list(search_dir.glob(f"libtransformer_engine{ext}*")) + if matches: + ctypes.CDLL(str(matches[0]), mode=ctypes.RTLD_GLOBAL) + return True + return False + except Exception as e: + return False + + +_cuda_libs_loaded = False + + +def _ensure_cuda_libs(): + global _cuda_libs_loaded + if not _cuda_libs_loaded: + _cuda_libs_loaded = _load_cuda_libs() + if _cuda_libs_loaded: + print(f"[CUDA] Successfully loaded CUDA libs") + return _cuda_libs_loaded + + +def _check_cuda_available() -> bool: + if not torch.cuda.is_available(): + return False + + import os + + try: + from ...._build_config import SKIP_CUDA_BUILD + + if SKIP_CUDA_BUILD: + print("[CUDA] Disabled: CUDA was skipped at build time") + return False + except ImportError: + if bool(int(os.environ.get("TE_FL_SKIP_CUDA", "0"))): + print("[CUDA] Disabled: TE_FL_SKIP_CUDA=1") + return False + + try: + if not _ensure_cuda_libs(): + return False + import transformer_engine_torch_nv + + return True + except (ImportError, OSError) as e: + print(f"[CUDA] Import failed: {e}") + return False + + +def _get_tex(): + _ensure_cuda_libs() + import transformer_engine_torch_nv + + return transformer_engine_torch_nv + + +class CUDABackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_cuda_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_cuda_available() + + def get_attention_backend(self, attention_params=None): + """ + CUDA backend uses the default attention backend selection logic. + This allows hardware-specific checks and optimizations for CUDA devices. + Returns: + Tuple of (use_flash_attention, flash_attention_backend, use_fused_attention, + fused_attention_backend, use_unfused_attention, available_backends) + """ + # Import the original get_attention_backend function + from transformer_engine.pytorch.attention.dot_product_attention import ( + utils as dpa_utils, + ) + + return dpa_utils._original_get_attention_backend(attention_params) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.group_quantize(tensor, quantizer, num_tensors, first_dims) + + def bgrad_group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.bgrad_group_quantize(tensor, quantizer, num_tensors, first_dims) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GLU # + def glu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.glu(input, quantizer) + + # GELU and variants # + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GLU # + def dglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dglu(grad, fwd_input, quantizer) + + # Backward of GELU and variants # + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, + weight, + bias, + eps, + ln_out, + quantizer, + otype, + sm_margin, + zero_centered_gamma, + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list, disable_bulk_allocation) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def te_general_grouped_gemm_for_grouped_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_grouped_tensor(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_in(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_in(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_out(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_out(*args, **kwargs) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def nvfp4_data_transpose( + self, + input: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.nvfp4_data_transpose(input, out=out) + + def swizzle_scales_for_gemm_(self, tensor: torch.Tensor) -> None: + tex = self._get_tex() + return tex.swizzle_scales_for_gemm_(tensor) + + def grouped_swizzle_for_gemm( + self, + tensor: Any, + rowwise: bool, + columnwise: bool, + ) -> None: + tex = self._get_tex() + return tex.grouped_swizzle_for_gemm(tensor, rowwise, columnwise) + + def convert_host_pointers_to_tensor( + self, + tensor_lists: List[List[torch.Tensor]], + ) -> Any: + tex = self._get_tex() + return tex.convert_host_pointers_to_tensor(tensor_lists) + + def get_device_pointer_for_data_and_scales( + self, + data_tensors: List[torch.Tensor], + scale_tensors: List[torch.Tensor], + swizzle: bool = False, + rowwise: bool = True, + data_dtype: Any = None, + ) -> Any: + tex = self._get_tex() + return tex.get_device_pointer_for_data_and_scales( + data_tensors, scale_tensors, swizzle, rowwise, data_dtype + ) + + def splits_to_offsets( + self, + first_dims: List[int], + logical_last_dim: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.splits_to_offsets(first_dims, logical_last_dim) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, + amax_histories, + scales, + amax_compute_algo, + fp8_dtype, + margin, + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + # MXFP8 scaling + def mxfp8_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.mxfp8_scaling_compute_partial_amax(tensor, amax, h, w, start_offset, block_len) + + def mxfp8_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.mxfp8_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + # NVFP4 2D + def nvfp4_2d_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int = 16, + ) -> None: + tex = self._get_tex() + return tex.nvfp4_2d_compute_partial_amax(tensor, amax, h, w, start_offset, block_len) + + def nvfp4_multi_tensor_compute_partial_amax( + self, + master_weight_list: List[torch.Tensor], + partial_amax_list: List[torch.Tensor], + global_amax_list: List[torch.Tensor], + h_list: List[int], + w_list: List[int], + start_offset_list: List[int], + block_len: int = 16, + ) -> None: + tex = self._get_tex() + return tex.nvfp4_multi_tensor_compute_partial_amax( + master_weight_list, + partial_amax_list, + global_amax_list, + h_list, + w_list, + start_offset_list, + block_len, + ) + + def nvfp4_compute_global_scale( + self, + global_amaxes: torch.Tensor, + global_scale_tensor: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) + + def nvfp4_compute_per_block_scale(self, *args, **kwargs) -> None: + tex = self._get_tex() + return tex.nvfp4_compute_per_block_scale(*args, **kwargs) + + def nvfp4_expand_scale_to_fp8(self, *args, **kwargs) -> None: + tex = self._get_tex() + return tex.nvfp4_expand_scale_to_fp8(*args, **kwargs) + + def nvfp4_fused_scale(self, *args, **kwargs) -> None: + tex = self._get_tex() + return tex.nvfp4_fused_scale(*args, **kwargs) + + def nvfp4_multi_tensor_fused_scale( + self, + block_amax_list: List[torch.Tensor], + global_amax_list: List[torch.Tensor], + per_block_scale_list: List[torch.Tensor], + target_scale_list: List[torch.Tensor], + target_amax_list: List[torch.Tensor], + tile_rows_list: List[int], + tile_cols_list: List[int], + rows_padded_list: List[int], + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.nvfp4_multi_tensor_fused_scale( + block_amax_list, + global_amax_list, + per_block_scale_list, + target_scale_list, + target_amax_list, + tile_rows_list, + tile_cols_list, + rows_padded_list, + block_len, + ) + + def nvfp4_2d_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + global_scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int = 16, + ) -> None: + tex = self._get_tex() + return tex.nvfp4_2d_partial_cast( + inp, out, scale, global_scale, h, w, start_offset, block_len + ) + + def nvfp4_multi_tensor_2d_partial_cast(self, inp_list, *args, **kwargs) -> None: + tex = self._get_tex() + return tex.nvfp4_multi_tensor_2d_partial_cast(inp_list, *args, **kwargs) + + def nvfp4_2d_multi_tensor_transpose( + self, + rowwise_data_list: List[torch.Tensor], + columnwise_data_list: List[torch.Tensor], + rowwise_scale_inv_list: List[torch.Tensor], + columnwise_scale_inv_list: List[torch.Tensor], + M_list: List[int], + K_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.nvfp4_2d_multi_tensor_transpose( + rowwise_data_list, + columnwise_data_list, + rowwise_scale_inv_list, + columnwise_scale_inv_list, + M_list, + K_list, + ) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + cuda_graph, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + cuda_graph, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, + out_per_step, + lse, + lse_per_step, + cu_seqlens, + only_second_half, + lse_packed, + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_scale_tensor( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale_tensor(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + def multi_tensor_compute_scale_inv_e8m0( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_inv_e8m0( + chunk_size, noop_flag, tensor_lists, block_len + ) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionCUDA + + return FlashAttentionCUDA + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/cuda/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/cuda/flash_attention.py new file mode 100644 index 0000000000..23295e51a5 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/cuda/flash_attention.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionCUDA(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + # Store initialization parameters for lazy loading + self._init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx or nullcontext, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + self._native_flash_attn = None + + def _ensure_native_flash_attn(self): + """Lazy initialization of native FlashAttention.""" + if self._native_flash_attn is not None: + return + + try: + # Import here to avoid circular dependency issues + # transformer_engine_torch must be registered before this import + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + FlashAttention as FlashAttentionNative, + ) + + if FlashAttentionNative is None: + raise RuntimeError( + "FlashAttention class is None - flash-attn may not be installed correctly" + ) + + self._native_flash_attn = FlashAttentionNative(**self._init_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import native FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize native FlashAttention: {e}. Init params: {self._init_params}" + ) + + @property + def backend_name(self) -> str: + return "cuda" + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + # Ensure native flash attention is initialized + self._ensure_native_flash_attn() + + return self._native_flash_attn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/cuda/register_ops.py b/transformer_engine/plugin/core/backends/vendor/cuda/register_ops.py new file mode 100644 index 0000000000..5fac3e34c4 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/cuda/register_ops.py @@ -0,0 +1,1173 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +CUDA vendor backend operator registrations. + +This module registers all VENDOR (CUDA) implementations from transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all CUDA (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + # Import CUDA backend to get all the wrapped tex functions + from .cuda import CUDABackend + + # Create a backend instance to access the methods + backend = CUDABackend() + + # Check if CUDA is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_grouped_tensor", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_grouped_tensor, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_in", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_in, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_out", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_out, is_avail), + vendor="CUDA", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="group_quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.group_quantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="bgrad_group_quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_group_quantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="CUDA", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="glu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.glu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="gelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="CUDA", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dgelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="CUDA", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="CUDA", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="CUDA", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + # Fused attention + OpImpl( + op_name="get_fused_attn_backend", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_attn_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_attn_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="CUDA", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="CUDA", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="CUDA", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="CUDA", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_data_transpose", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_data_transpose, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="swizzle_scales_for_gemm_", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swizzle_scales_for_gemm_, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="grouped_swizzle_for_gemm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.grouped_swizzle_for_gemm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="convert_host_pointers_to_tensor", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_host_pointers_to_tensor, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="get_device_pointer_for_data_and_scales", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_device_pointer_for_data_and_scales, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="splits_to_offsets", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.splits_to_offsets, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="CUDA", + priority=100, + ), + # MXFP8 scaling + OpImpl( + op_name="mxfp8_scaling_compute_partial_amax", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_compute_partial_amax, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_partial_cast", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_partial_cast, is_avail), + vendor="CUDA", + priority=100, + ), + # NVFP4 2D + OpImpl( + op_name="nvfp4_2d_compute_partial_amax", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_compute_partial_amax, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_compute_partial_amax", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_compute_partial_amax, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_global_scale", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_global_scale, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_per_block_scale", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_per_block_scale, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_expand_scale_to_fp8", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_expand_scale_to_fp8, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_fused_scale", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_fused_scale, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_fused_scale", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_fused_scale, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_partial_cast", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_partial_cast, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_2d_partial_cast", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_2d_partial_cast, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_multi_tensor_transpose", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_multi_tensor_transpose, is_avail), + vendor="CUDA", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="CUDA", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="CUDA", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="CUDA", + priority=100, + ), + # NVSHMEM operations + OpImpl( + op_name="init_nvshmem_backend", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.init_nvshmem_backend, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="create_nvshmem_tensor", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_nvshmem_tensor, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvshmem_send_on_current_stream", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_send_on_current_stream, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvshmem_wait_on_current_stream", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_wait_on_current_stream, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="nvshmem_finalize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_finalize, is_avail), + vendor="CUDA", + priority=100, + ), + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale_tensor", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale_tensor, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor="CUDA", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="CUDA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="CUDA", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="CUDA", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="CUDA", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/hygon/__init__.py b/transformer_engine/plugin/core/backends/vendor/hygon/__init__.py new file mode 100644 index 0000000000..a48a5c650f --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/hygon/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .hygon import HygonBackend + +__all__ = ["HygonBackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/hygon/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/hygon/flash_attention.py new file mode 100644 index 0000000000..eb2fbd4584 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/hygon/flash_attention.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionHYGON(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + # Store initialization parameters for lazy loading + self._init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx or nullcontext, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + self._native_flash_attn = None + + def _ensure_native_flash_attn(self): + """Lazy initialization of native FlashAttention.""" + if self._native_flash_attn is not None: + return + + try: + # Import here to avoid circular dependency issues + # transformer_engine_torch must be registered before this import + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + FlashAttention as FlashAttentionNative, + ) + + if FlashAttentionNative is None: + raise RuntimeError( + "FlashAttention class is None - flash-attn may not be installed correctly" + ) + + self._native_flash_attn = FlashAttentionNative(**self._init_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import native FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize native FlashAttention: {e}. Init params: {self._init_params}" + ) + + @property + def backend_name(self) -> str: + return "hygon" + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + # Ensure native flash attention is initialized + self._ensure_native_flash_attn() + + return self._native_flash_attn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/hygon/hygon.py b/transformer_engine/plugin/core/backends/vendor/hygon/hygon.py new file mode 100644 index 0000000000..c9ea27197c --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/hygon/hygon.py @@ -0,0 +1,1819 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +from typing import Any, Dict, List, Optional, Tuple, Union +import torch +from ....ops import * + + +def _load_hygon_libs(): + import ctypes + from pathlib import Path + import importlib + import platform + + common_prefix = "libtransformer_engine" + csrc_prefix = "transformer_engine_torch_hygon" + common_files = [] + csrc_files = [] + + def _get_sys_extension() -> str: + system = platform.system() + if system == "Linux": + return ".so" + if system == "Darwin": + return ".dylib" + if system == "Windows": + return ".dll" + raise RuntimeError(f"Unsupported operating system ({system})") + + try: + if bool(int(os.environ.get("TE_FL_SKIP_HYGON", "0"))): + return False + ext = _get_sys_extension() + hygon_spec = importlib.util.find_spec("transformer_engine_hygon") + if hygon_spec is None: + return False + hygon_path = Path(hygon_spec.origin).parent + for file_path in hygon_path.iterdir(): + if file_path.name.startswith(common_prefix) and file_path.suffix == ext: + common_files.append(file_path) + if file_path.name.startswith(csrc_prefix) and file_path.suffix == ext: + csrc_files.append(file_path) + if len(common_files) == 0: + return False + if len(csrc_files) == 0: + return False + ctypes.CDLL(str(common_files[0]), mode=ctypes.RTLD_GLOBAL) + spec = importlib.util.spec_from_file_location(csrc_prefix, csrc_files[0]) + solib = importlib.util.module_from_spec(spec) + sys.modules[csrc_prefix] = solib + spec.loader.exec_module(solib) + return True + except Exception as e: + return False + + +_hygon_libs_loaded = False + + +def _ensure_hygon_libs(): + global _hygon_libs_loaded + if not _hygon_libs_loaded: + _hygon_libs_loaded = _load_hygon_libs() + if _hygon_libs_loaded: + print(f"[HYGON] Successfully loaded HYGON libs") + return _hygon_libs_loaded + + +def _check_hygon_available() -> bool: + try: + if not _ensure_hygon_libs(): + return False + import transformer_engine_torch_hygon + + return True + except (ImportError, OSError) as e: + print(f"[HYGON] Import failed: {e}") + return False + + +def _get_tex(): + _ensure_hygon_libs() + import transformer_engine_torch_hygon + + return transformer_engine_torch_hygon + + +class HygonBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_hygon_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_hygon_available() + + def get_attention_backend(self, attention_params=None): + from packaging.version import Version as PkgVersion + from ....logger_manager import get_logger + + logger = get_logger() + + # Read environment variables to determine which backends to enable + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + + # Log disabled backends + if not use_flash_attention: + logger.info_once("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.info_once("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.info_once("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + flash_attention_backend = PkgVersion("2.6.0") if use_flash_attention else None + fused_attention_backend = NVTE_Fused_Attn_Backend.NVTE_No_Backend + + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + return ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.group_quantize(input, quantizer) + + def bgrad_group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_group_quantize(input, quantizer) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GLU # + def glu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.glu(input, quantizer) + + # GELU and variants # + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GLU # + def dglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dglu(grad, fwd_input, quantizer) + + # Backward of GELU and variants # + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list, disable_bulk_allocation) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def te_general_grouped_gemm_for_grouped_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_grouped_tensor(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_in(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_in(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_out(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_out(*args, **kwargs) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def nvfp4_data_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_data_transpose(*args, **kwargs) + + def swizzle_scales_for_gemm_(self, *args, **kwargs): + tex = self._get_tex() + return tex.swizzle_scales_for_gemm_(*args, **kwargs) + + def grouped_swizzle_for_gemm(self, *args, **kwargs): + tex = self._get_tex() + return tex.grouped_swizzle_for_gemm(*args, **kwargs) + + def convert_host_pointers_to_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.convert_host_pointers_to_tensor(*args, **kwargs) + + def get_device_pointer_for_data_and_scales(self, *args, **kwargs): + tex = self._get_tex() + return tex.get_device_pointer_for_data_and_scales(*args, **kwargs) + + def splits_to_offsets(self, *args, **kwargs): + tex = self._get_tex() + return tex.splits_to_offsets(*args, **kwargs) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + def mxfp8_scaling_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_compute_partial_amax(*args, **kwargs) + + def mxfp8_scaling_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_partial_cast(*args, **kwargs) + + def nvfp4_2d_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_compute_partial_amax(*args, **kwargs) + + def nvfp4_multi_tensor_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_compute_partial_amax(*args, **kwargs) + + def nvfp4_compute_global_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_global_scale(*args, **kwargs) + + def nvfp4_compute_per_block_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_per_block_scale(*args, **kwargs) + + def nvfp4_expand_scale_to_fp8(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_expand_scale_to_fp8(*args, **kwargs) + + def nvfp4_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_fused_scale(*args, **kwargs) + + def nvfp4_multi_tensor_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_fused_scale(*args, **kwargs) + + def nvfp4_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_partial_cast(*args, **kwargs) + + def nvfp4_multi_tensor_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_2d_partial_cast(*args, **kwargs) + + def nvfp4_2d_multi_tensor_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_multi_tensor_transpose(*args, **kwargs) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + cuda_graph, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + cuda_graph, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, out_per_step, lse, lse_per_step, cu_seqlens, only_second_half, lse_packed + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_scale_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_scale_tensor(*args, **kwargs) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + def multi_tensor_compute_scale_inv_e8m0(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_compute_scale_inv_e8m0(*args, **kwargs) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionHYGON + + return FlashAttentionHYGON + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/hygon/register_ops.py b/transformer_engine/plugin/core/backends/vendor/hygon/register_ops.py new file mode 100644 index 0000000000..2b0bbc8aa0 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/hygon/register_ops.py @@ -0,0 +1,1107 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Hygon vendor backend operator registrations. + +This module registers all VENDOR (Hygon) implementations from transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Hygon (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + # Import Hygon backend to get all the wrapped tex functions + from .hygon import HygonBackend + + # Create a backend instance to access the methods + backend = HygonBackend() + + # Check if Hygon is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_grouped_tensor", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_grouped_tensor, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_in", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_in, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_out", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_out, is_avail), + vendor="HYGON", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="group_quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.group_quantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="bgrad_group_quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_group_quantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="HYGON", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="glu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.glu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="gelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="HYGON", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dgelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="HYGON", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="HYGON", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="HYGON", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + # Fused attention + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="HYGON", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="HYGON", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="HYGON", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="HYGON", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_data_transpose", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_data_transpose, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="swizzle_scales_for_gemm_", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swizzle_scales_for_gemm_, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="grouped_swizzle_for_gemm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.grouped_swizzle_for_gemm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="convert_host_pointers_to_tensor", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_host_pointers_to_tensor, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="get_device_pointer_for_data_and_scales", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_device_pointer_for_data_and_scales, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="splits_to_offsets", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.splits_to_offsets, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_compute_partial_amax", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_compute_partial_amax, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_partial_cast", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_partial_cast, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_compute_partial_amax", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_compute_partial_amax, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_compute_partial_amax", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_compute_partial_amax, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_global_scale", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_global_scale, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_per_block_scale", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_per_block_scale, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_expand_scale_to_fp8", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_expand_scale_to_fp8, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_fused_scale", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_fused_scale, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_fused_scale", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_fused_scale, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_partial_cast", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_partial_cast, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_2d_partial_cast", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_2d_partial_cast, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_multi_tensor_transpose", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_multi_tensor_transpose, is_avail), + vendor="HYGON", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="HYGON", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="HYGON", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="HYGON", + priority=100, + ), + # NVSHMEM operations + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale_tensor", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale_tensor, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor="HYGON", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="HYGON", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="HYGON", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="HYGON", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.hygon", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="HYGON", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/iluvatar/__init__.py b/transformer_engine/plugin/core/backends/vendor/iluvatar/__init__.py new file mode 100644 index 0000000000..740c8d44d6 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/iluvatar/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .iluvatar import IluvatarBackend + +__all__ = ["IluvatarBackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/iluvatar/iluvatar.py b/transformer_engine/plugin/core/backends/vendor/iluvatar/iluvatar.py new file mode 100644 index 0000000000..bc0dc390c3 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/iluvatar/iluvatar.py @@ -0,0 +1,1841 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import math +import torch + +from ....ops import * + + +def _load_iluvatar_libs(): + import ctypes + import os + import subprocess + from pathlib import Path + import importlib.util + import sysconfig + import platform + import glob as glob_module + + def get_ext(): + system = platform.system() + return ".so" if system == "Linux" else ".dylib" if system == "Darwin" else ".dll" + + ext = get_ext() + + def try_load_lib(name, search_patterns): + for env_var in [f"{name.upper()}_HOME", f"{name.upper()}_PATH"]: + path = os.environ.get(env_var) + if path: + libs = glob_module.glob(f"{path}/**/lib{name}{ext}*", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + cuda_home = os.environ.get("IX_HOME") or os.environ.get("IX_PATH") or "/usr/local/corex" + for pattern in search_patterns: + libs = glob_module.glob(f"{cuda_home}/**/{pattern}", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + result = subprocess.check_output(f"ldconfig -p | grep 'lib{name}{ext}'", shell=True) + for line in result.decode().split("\n"): + if f"lib{name}" in line and "=>" in line: + so_path = line.split(">")[1].strip() + if so_path: + return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + return ctypes.CDLL(f"lib{name}{ext}", mode=ctypes.RTLD_GLOBAL) + except: + return None + + try: + try_load_lib("cudnn", [f"libcudnn{ext}*"]) + try_load_lib("nvrtc", [f"libnvrtc{ext}*"]) + try_load_lib("curand", [f"libcurand{ext}*"]) + + te_path = Path(importlib.util.find_spec("transformer_engine_iluvatar").origin).parent.parent + for search_dir in [te_path, te_path / "transformer_engine_iluvatar/libs"]: + if search_dir.exists(): + matches = list(search_dir.glob(f"libixte_common{ext}*")) + if matches: + ctypes.CDLL(str(matches[0]), mode=ctypes.RTLD_GLOBAL) + return True + return False + except Exception as e: + return False + + +_iluvatar_libs_loaded = False + + +def _ensure_iluvatar_libs(): + global _iluvatar_libs_loaded + if not _iluvatar_libs_loaded: + _iluvatar_libs_loaded = _load_iluvatar_libs() + if _iluvatar_libs_loaded: + print(f"[ILUVATAR] Successfully loaded ILUVATAR libs") + return _iluvatar_libs_loaded + + +def _check_iluvatar_available() -> bool: + if not torch.cuda.is_available(): + return False + import os + + try: + if not _ensure_iluvatar_libs(): + return False + import transformer_engine_iluvatar + + return True + except (ImportError, OSError) as e: + print(f"[ILUVATAR] Import failed: {e}") + return False + + +def _get_tex(): + import transformer_engine_iluvatar.pytorch.ixte_torch + + return transformer_engine_iluvatar.pytorch.ixte_torch + + +class IluvatarBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_iluvatar_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_iluvatar_available() + + def get_attention_backend(self, attention_params=None): + from packaging.version import Version as PkgVersion + from ....logger_manager import get_logger + + logger = get_logger() + + # Read environment variables to determine which backends to enable + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + + # Log disabled backends + if not use_flash_attention: + logger.info_once("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.info_once("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.info_once("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + flash_attention_backend = PkgVersion("2.6.0") if use_flash_attention else None + fused_attention_backend = NVTE_Fused_Attn_Backend.NVTE_No_Backend + + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + return ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.group_quantize(input, quantizer) + + def bgrad_group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_group_quantize(input, quantizer) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GELU and variants # + def glu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.glu(input, quantizer) + + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GELU and variants # + def dglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dglu(grad, fwd_input, quantizer) + + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list, disable_bulk_allocation) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def te_general_grouped_gemm_for_grouped_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_grouped_tensor(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_in(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_in(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_out(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_out(*args, **kwargs) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def nvfp4_data_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_data_transpose(*args, **kwargs) + + def swizzle_scales_for_gemm_(self, *args, **kwargs): + tex = self._get_tex() + return tex.swizzle_scales_for_gemm_(*args, **kwargs) + + def grouped_swizzle_for_gemm(self, *args, **kwargs): + tex = self._get_tex() + return tex.grouped_swizzle_for_gemm(*args, **kwargs) + + def convert_host_pointers_to_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.convert_host_pointers_to_tensor(*args, **kwargs) + + def get_device_pointer_for_data_and_scales(self, *args, **kwargs): + tex = self._get_tex() + return tex.get_device_pointer_for_data_and_scales(*args, **kwargs) + + def splits_to_offsets(self, *args, **kwargs): + tex = self._get_tex() + return tex.splits_to_offsets(*args, **kwargs) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + def mxfp8_scaling_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_compute_partial_amax(*args, **kwargs) + + def mxfp8_scaling_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_partial_cast(*args, **kwargs) + + def nvfp4_2d_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_compute_partial_amax(*args, **kwargs) + + def nvfp4_multi_tensor_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_compute_partial_amax(*args, **kwargs) + + def nvfp4_compute_global_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_global_scale(*args, **kwargs) + + def nvfp4_compute_per_block_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_per_block_scale(*args, **kwargs) + + def nvfp4_expand_scale_to_fp8(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_expand_scale_to_fp8(*args, **kwargs) + + def nvfp4_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_fused_scale(*args, **kwargs) + + def nvfp4_multi_tensor_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_fused_scale(*args, **kwargs) + + def nvfp4_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_partial_cast(*args, **kwargs) + + def nvfp4_multi_tensor_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_2d_partial_cast(*args, **kwargs) + + def nvfp4_2d_multi_tensor_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_multi_tensor_transpose(*args, **kwargs) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + cuda_graph, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + cuda_graph, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, out_per_step, lse, lse_per_step, cu_seqlens, only_second_half, lse_packed + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_scale_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_scale_tensor(*args, **kwargs) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + def multi_tensor_compute_scale_inv_e8m0(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_compute_scale_inv_e8m0(*args, **kwargs) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + raise NotImplementedError("get_flash_attention_class - not implemented in iluvatar backend") + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/iluvatar/register_ops.py b/transformer_engine/plugin/core/backends/vendor/iluvatar/register_ops.py new file mode 100644 index 0000000000..001f6129d8 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/iluvatar/register_ops.py @@ -0,0 +1,1173 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Iluvatar vendor backend operator registrations. + +This module registers all VENDOR (Iluvatar) implementations from transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Iluvatar (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + # Import Iluvatar backend to get all the wrapped tex functions + from .iluvatar import IluvatarBackend + + # Create a backend instance to access the methods + backend = IluvatarBackend() + + # Check if Iluvatar is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_grouped_tensor", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_grouped_tensor, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_in", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_in, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_out", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_out, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="group_quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.group_quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="bgrad_group_quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_group_quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="glu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.glu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="gelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dgelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Fused attention + OpImpl( + op_name="get_fused_attn_backend", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_attn_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_attn_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="Iluvatar", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="Iluvatar", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_data_transpose", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_data_transpose, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="swizzle_scales_for_gemm_", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swizzle_scales_for_gemm_, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="grouped_swizzle_for_gemm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.grouped_swizzle_for_gemm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="convert_host_pointers_to_tensor", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_host_pointers_to_tensor, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="get_device_pointer_for_data_and_scales", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_device_pointer_for_data_and_scales, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="splits_to_offsets", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.splits_to_offsets, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="Iluvatar", + priority=100, + ), + # MXFP8 scaling operations + OpImpl( + op_name="mxfp8_scaling_compute_partial_amax", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_compute_partial_amax, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_partial_cast", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_partial_cast, is_avail), + vendor="Iluvatar", + priority=100, + ), + # NVFP4 operations + OpImpl( + op_name="nvfp4_2d_compute_partial_amax", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_compute_partial_amax, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_compute_partial_amax", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_compute_partial_amax, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_global_scale", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_global_scale, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_per_block_scale", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_per_block_scale, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_expand_scale_to_fp8", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_expand_scale_to_fp8, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_fused_scale", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_fused_scale, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_fused_scale", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_fused_scale, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_partial_cast", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_partial_cast, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_2d_partial_cast", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_2d_partial_cast, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_multi_tensor_transpose", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_multi_tensor_transpose, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="Iluvatar", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="Iluvatar", + priority=100, + ), + # NVSHMEM operations + OpImpl( + op_name="init_nvshmem_backend", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.init_nvshmem_backend, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="create_nvshmem_tensor", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_nvshmem_tensor, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvshmem_send_on_current_stream", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_send_on_current_stream, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvshmem_wait_on_current_stream", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_wait_on_current_stream, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="nvshmem_finalize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_finalize, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale_tensor", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale_tensor, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="Iluvatar", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="Iluvatar", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="Iluvatar", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.iluvatar", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="Iluvatar", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/kunlunxin/__init__.py b/transformer_engine/plugin/core/backends/vendor/kunlunxin/__init__.py new file mode 100644 index 0000000000..aa2198ee35 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/kunlunxin/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .kunlunxin import KunLunXinBackend + +__all__ = ["KunLunXinBackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/kunlunxin/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/kunlunxin/flash_attention.py new file mode 100644 index 0000000000..9beb5403ed --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/kunlunxin/flash_attention.py @@ -0,0 +1,389 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionTorch(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + @property + def backend_name(self) -> str: + return "torch_sdpa" + + def _convert_layout_to_bhsd( + self, + tensor: torch.Tensor, + layout: str, + ) -> torch.Tensor: + """Convert tensor from various layouts to [batch, heads, seq, dim] format.""" + layout = layout.lower() + + # Handle combined layouts like "sbhd_sbhd_sbhd" - extract the first part + if "_" in layout: + layout = layout.split("_")[0] + + if layout in ("sbhd", "sbh3d", "sb3hd"): + return tensor.permute(1, 2, 0, 3) + elif layout in ("bshd", "bsh3d", "bs3hd"): + return tensor.permute(0, 2, 1, 3) + elif layout in ("bhsd",): + return tensor + elif layout in ("thd",): + # thd is packed format, should not reach here for 4D tensors + raise ValueError(f"thd layout requires 3D tensor, got {tensor.dim()}D") + else: + raise ValueError(f"Unsupported qkv_layout: {layout}") + + def _convert_bhsd_to_layout( + self, + tensor: torch.Tensor, + layout: str, + ) -> torch.Tensor: + """Convert tensor from [batch, heads, seq, dim] back to original layout.""" + layout = layout.lower() + + # Handle combined layouts like "sbhd_sbhd_sbhd" - extract the first part + if "_" in layout: + layout = layout.split("_")[0] + + if layout in ("sbhd", "sbh3d", "sb3hd"): + return tensor.permute(2, 0, 1, 3) + elif layout in ("bshd", "bsh3d", "bs3hd"): + return tensor.permute(0, 2, 1, 3) + elif layout in ("bhsd",): + return tensor + elif layout in ("thd",): + raise ValueError(f"thd layout requires 3D tensor, got {tensor.dim()}D") + else: + raise ValueError(f"Unsupported qkv_layout: {layout}") + + def _create_sliding_window_mask( + self, + seq_len_q: int, + seq_len_kv: int, + window_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Create a sliding window attention mask.""" + left_window, right_window = window_size + + if left_window == -1 and right_window == -1: + return torch.zeros(seq_len_q, seq_len_kv, dtype=dtype, device=device) + + q_idx = torch.arange(seq_len_q, device=device).unsqueeze(1) + kv_idx = torch.arange(seq_len_kv, device=device).unsqueeze(0) + + mask_bool = torch.zeros(seq_len_q, seq_len_kv, dtype=torch.bool, device=device) + + if left_window >= 0: + mask_bool = mask_bool | (kv_idx < q_idx - left_window) + + if right_window >= 0: + mask_bool = mask_bool | (kv_idx > q_idx + right_window) + + mask = torch.zeros(seq_len_q, seq_len_kv, dtype=dtype, device=device) + mask.masked_fill_(mask_bool, float("-inf")) + + return mask + + def _unpack_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert packed tensor to padded tensor format.""" + batch_size = cu_seqlens.shape[0] - 1 + device = tensor.device + original_shape = tensor.shape + + if tensor.dim() == 4: + if tensor.shape[1] == 1: + tensor = tensor.squeeze(1) + else: + raise ValueError( + f"Unexpected 4D tensor shape {original_shape}. " + "Expected [total_tokens, 1, num_heads, head_dim]" + ) + + if tensor.dim() != 3: + raise ValueError( + f"Expected tensor to be 3D or 4D after processing, got shape {original_shape}" + ) + + total_tokens, num_heads, head_dim = tensor.shape + + expected_total = cu_seqlens[-1].item() + if total_tokens != expected_total: + raise ValueError( + f"Tensor has {total_tokens} tokens but cu_seqlens indicates {expected_total} tokens" + ) + + padded_tensor = torch.zeros( + batch_size, num_heads, max_seqlen, head_dim, dtype=tensor.dtype, device=device + ) + + padding_mask = torch.ones(batch_size, max_seqlen, dtype=torch.bool, device=device) + + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + seq_len = end - start + + seq_data = tensor[start:end].permute(1, 0, 2) + padded_tensor[i, :, :seq_len, :] = seq_data + padding_mask[i, :seq_len] = False + + return padded_tensor, padding_mask + + def _pack_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + """Convert padded tensor back to packed tensor format.""" + batch_size = tensor.shape[0] + num_heads = tensor.shape[1] + head_dim = tensor.shape[3] + total_tokens = cu_seqlens[-1].item() + device = tensor.device + + packed_tensor = torch.zeros( + total_tokens, num_heads, head_dim, dtype=tensor.dtype, device=device + ) + + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + seq_len = end - start + + seq_data = tensor[i, :, :seq_len, :].permute(1, 0, 2) + packed_tensor[start:end, :, :] = seq_data + + return packed_tensor + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + """Flash Attention implementation using PyTorch's scaled_dot_product_attention.""" + if fp8: + raise NotImplementedError("FP8 is not supported in PyTorch SDPA backend") + if cp_group is not None: + raise NotImplementedError( + "Context parallelism is not supported in PyTorch SDPA backend" + ) + if alibi_slopes is not None: + raise NotImplementedError("ALiBi slopes are not supported in PyTorch SDPA backend") + + query_original_shape = query_layer.shape + + # Check if input is in standard 4D format - same as flagos backend + # If tensor is 4D, treat it as standard format and just do layout conversion + # Only use unpack logic for true packed format (3D tensors with thd layout) + is_standard_4d = query_layer.dim() == 4 + + if is_standard_4d: + # Standard 4D tensor format - just convert layout like flagos does + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + use_packed_format = False + padding_mask_q = None + padding_mask_kv = None + else: + # True packed format (thd layout, 3D tensor) - use unpack logic + use_packed_format = cu_seqlens_q is not None or cu_seqlens_kv is not None + padding_mask_q = None + padding_mask_kv = None + + if use_packed_format: + if cu_seqlens_q is not None: + query, padding_mask_q = self._unpack_tensor( + query_layer, cu_seqlens_q, max_seqlen_q + ) + else: + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + + if cu_seqlens_kv is not None: + key, padding_mask_kv = self._unpack_tensor( + key_layer, cu_seqlens_kv, max_seqlen_kv + ) + value, _ = self._unpack_tensor(value_layer, cu_seqlens_kv, max_seqlen_kv) + else: + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + else: + query = self._convert_layout_to_bhsd(query_layer, qkv_layout) + key = self._convert_layout_to_bhsd(key_layer, qkv_layout) + value = self._convert_layout_to_bhsd(value_layer, qkv_layout) + + batch_size, num_heads_q, seq_len_q, head_dim = query.shape + num_heads_kv = key.shape[1] + seq_len_kv = key.shape[2] + + if num_heads_q != num_heads_kv: + num_groups = num_heads_q // num_heads_kv + if num_heads_q % num_heads_kv != 0: + raise ValueError( + f"num_heads_q ({num_heads_q}) must be divisible by num_heads_kv" + f" ({num_heads_kv})" + ) + key = key.repeat_interleave(num_groups, dim=1) + value = value.repeat_interleave(num_groups, dim=1) + + attn_mask = None + is_causal = False + + if use_packed_format and padding_mask_kv is not None: + attn_mask = torch.zeros( + batch_size, seq_len_q, seq_len_kv, dtype=query.dtype, device=query.device + ) + padding_broadcast = padding_mask_kv.unsqueeze(1) + attn_mask.masked_fill_(padding_broadcast, float("-inf")) + + if attn_mask_type == "causal": + is_causal = True + attn_mask = None + # if window_size is None and not use_packed_format: + # is_causal = True + # else: + # causal_mask = torch.zeros( + # seq_len_q, seq_len_kv, + # dtype=query.dtype, device=query.device + # ) + # causal_mask.masked_fill_( + # torch.triu(torch.ones(seq_len_q, seq_len_kv, device=query.device, dtype=torch.bool), diagonal=1), + # float('-inf') + # ) + + # if attn_mask is not None: + # if attn_mask.dim() == 2: + # attn_mask = attn_mask + causal_mask + # else: + # attn_mask = attn_mask + causal_mask.unsqueeze(0) + # else: + # attn_mask = causal_mask + + if window_size is not None and not is_causal: + window_mask = self._create_sliding_window_mask( + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + window_size=window_size, + device=query.device, + dtype=query.dtype, + ) + + if attn_mask is not None: + attn_mask = attn_mask + window_mask.unsqueeze(0) + else: + attn_mask = window_mask + + if attention_mask is not None and attn_mask_type != "causal": + if isinstance(attention_mask, tuple): + explicit_mask = attention_mask[0] + else: + explicit_mask = attention_mask + + if explicit_mask.dtype == torch.bool: + float_mask = torch.zeros_like(explicit_mask, dtype=query.dtype) + float_mask.masked_fill_(~explicit_mask, float("-inf")) + explicit_mask = float_mask + + if explicit_mask.dim() == 2: + explicit_mask = explicit_mask.unsqueeze(0).unsqueeze(0) + elif explicit_mask.dim() == 3: + explicit_mask = explicit_mask.unsqueeze(1) + + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) + elif attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask + explicit_mask + else: + attn_mask = explicit_mask + elif attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) + elif attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(1) + + with self.attention_dropout_ctx(): + dropout_p = self.attention_dropout if self.training else 0.0 + + output = F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=self.softmax_scale, + ) + + if use_packed_format and padding_mask_q is not None: + mask_expanded = padding_mask_q.unsqueeze(1).unsqueeze(3) + output = output.masked_fill(mask_expanded, 0.0) + + if use_packed_format and cu_seqlens_q is not None: + output = self._pack_tensor(output, cu_seqlens_q) + + if len(query_original_shape) == 4: + total_tokens = output.shape[0] + hidden_size = output.shape[1] * output.shape[2] + output = output.contiguous().view(total_tokens, 1, hidden_size) + else: + output = self._convert_bhsd_to_layout(output, qkv_layout) + # Flatten the last two dimensions (heads, dim) -> (heads * dim) + # to match the output format of other backends + output = output.contiguous().view(*output.shape[:-2], -1) + + return output diff --git a/transformer_engine/plugin/core/backends/vendor/kunlunxin/kunlunxin.py b/transformer_engine/plugin/core/backends/vendor/kunlunxin/kunlunxin.py new file mode 100644 index 0000000000..6dbab926b2 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/kunlunxin/kunlunxin.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from typing import Any, Dict, List, Optional, Tuple, Union +import torch +from ....ops import * + +_kunlunxin_available = False + + +def _ensure_kunlunxin_available(): + global _kunlunxin_available + if not _kunlunxin_available: + try: + result = subprocess.run(["xpu-smi"], capture_output=True, timeout=10, text=True) + + if result.returncode == 0: + _kunlunxin_available = True + else: + _kunlunxin_available = False + + except subprocess.TimeoutExpired: + _kunlunxin_available = False + except FileNotFoundError: + _kunlunxin_available = False + except OSError as e: + _kunlunxin_available = False + except Exception as e: + _kunlunxin_available = False + + return _kunlunxin_available + + +def _check_kunlunxin_available() -> bool: + """Check if xpu-smi command can be executed successfully.""" + if _ensure_kunlunxin_available(): + return True + else: + return False + + +class KunLunXinBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_kunlunxin_available() + + def is_available(self) -> bool: + return _check_kunlunxin_available() + + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionTorch + + return FlashAttentionTorch diff --git a/transformer_engine/plugin/core/backends/vendor/kunlunxin/register_ops.py b/transformer_engine/plugin/core/backends/vendor/kunlunxin/register_ops.py new file mode 100644 index 0000000000..fa014833b1 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/kunlunxin/register_ops.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +KunLunXin backend operator registrations. + +This module registers all KunLunXin PyTorch implementations. +""" + +from __future__ import annotations + +import functools + +from transformer_engine.plugin.core.types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all KunLunXin PyTorch operator implementations. + + Args: + registry: Registry to register into + """ + from .kunlunxin import KunLunXinBackend + + # Create a backend instance to access the methods + backend = KunLunXinBackend() + + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.kunlunxin", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="KUNLUNXIN", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/metax/__init__.py b/transformer_engine/plugin/core/backends/vendor/metax/__init__.py new file mode 100644 index 0000000000..b663a97695 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/metax/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .metax import MetaxBackend + +__all__ = ["MetaxBackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/metax/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/metax/flash_attention.py new file mode 100644 index 0000000000..30d6c488ae --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/metax/flash_attention.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionMETAX(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + # Store initialization parameters for lazy loading + self._init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx or nullcontext, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + self._metax_flash_attn = None + + def _ensure_metax_flash_attn(self): + """Lazy initialization of metax FlashAttention.""" + if self._metax_flash_attn is not None: + return + + try: + # Import here to avoid circular dependency issues + # transformer_engine_torch must be registered before this import + from transformer_engine_metax.pytorch.attention.dot_product_attention.backends import ( + FlashAttention as FlashAttentionMetax, + ) + + if FlashAttentionMetax is None: + raise RuntimeError( + "FlashAttention class is None - flash-attn may not be installed correctly" + ) + + self._metax_flash_attn = FlashAttentionMetax(**self._init_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import metax FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize metax FlashAttention: {e}. Init params: {self._init_params}" + ) + + @property + def backend_name(self) -> str: + return "metax" + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + # Ensure metax flash attention is initialized + self._ensure_metax_flash_attn() + + return self._metax_flash_attn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/metax/metax.py b/transformer_engine/plugin/core/backends/vendor/metax/metax.py new file mode 100644 index 0000000000..c61d981826 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/metax/metax.py @@ -0,0 +1,1800 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import ctypes +from pathlib import Path +import importlib.util +import platform +import os +import functools +import inspect + +import torch + +from ....ops import * + + +def _load_metax_libs(): + + def get_ext(): + system = platform.system() + return ".so" if system == "Linux" else ".dylib" if system == "Darwin" else ".dll" + + ext = get_ext() + + try: + import transformer_engine_metax + + te_path = Path(importlib.util.find_spec("transformer_engine_metax").origin).parent.parent + for search_dir in [te_path, te_path / "transformer_engine_metax"]: + if search_dir.exists(): + matches = list(search_dir.glob(f"libtransformer_engine{ext}*")) + if matches: + ctypes.CDLL(str(matches[0]), mode=ctypes.RTLD_GLOBAL) + return True + return False + except Exception as e: + return False + + +_metax_libs_loaded = False + + +def _ensure_metax_libs(): + global _metax_libs_loaded + if not _metax_libs_loaded: + _metax_libs_loaded = _load_metax_libs() + if _metax_libs_loaded: + print(f"[Metax] Successfully loaded Metax libs") + return _metax_libs_loaded + + +def _check_metax_available() -> bool: + if not torch.cuda.is_available(): + return False + + try: + from ...._build_config import SKIP_METAX_BUILD + + if SKIP_METAX_BUILD: + print("[Metax] Disabled: Metax was skipped at build time") + return False + except ImportError: + if bool(int(os.environ.get("TE_FL_SKIP_METAX", "0"))): + print("[Metax] Disabled: TE_FL_SKIP_METAX=1") + return False + + try: + if not _ensure_metax_libs(): + return False + import transformer_engine_torch_metax + + return True + except (ImportError, OSError) as e: + print(f"[Metax] Import failed: {e}") + return False + + +def _get_tex(): + _ensure_metax_libs() + import transformer_engine_torch_metax + + return transformer_engine_torch_metax + + +class MetaxBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_metax_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_metax_available() + + def get_attention_backend(self, attention_params=None): + # Import the metax get_attention_backend function + try: + from transformer_engine_metax.pytorch.attention.dot_product_attention import utils + + return utils.get_attention_backend(attention_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import metax FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_metax is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to get_attention_backend: {e}. Attention_params: {self.attention_params}" + ) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.group_quantize(input, quantizer) + + def bgrad_group_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_group_quantize(input, quantizer) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GELU and variants # + def glu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.glu(input, quantizer) + + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GELU and variants # + def dglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dglu(grad, fwd_input, quantizer) + + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list, disable_bulk_allocation) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def te_general_grouped_gemm_for_grouped_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_grouped_tensor(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_in(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_in(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_out(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_out(*args, **kwargs) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def nvfp4_data_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_data_transpose(*args, **kwargs) + + def swizzle_scales_for_gemm_(self, *args, **kwargs): + tex = self._get_tex() + return tex.swizzle_scales_for_gemm_(*args, **kwargs) + + def grouped_swizzle_for_gemm(self, *args, **kwargs): + tex = self._get_tex() + return tex.grouped_swizzle_for_gemm(*args, **kwargs) + + def convert_host_pointers_to_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.convert_host_pointers_to_tensor(*args, **kwargs) + + def get_device_pointer_for_data_and_scales(self, *args, **kwargs): + tex = self._get_tex() + return tex.get_device_pointer_for_data_and_scales(*args, **kwargs) + + def splits_to_offsets(self, *args, **kwargs): + tex = self._get_tex() + return tex.splits_to_offsets(*args, **kwargs) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + # MXFP8 ops + def mxfp8_scaling_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_compute_partial_amax(*args, **kwargs) + + def mxfp8_scaling_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_partial_cast(*args, **kwargs) + + # NVFP4 ops + def nvfp4_2d_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_compute_partial_amax(*args, **kwargs) + + def nvfp4_multi_tensor_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_compute_partial_amax(*args, **kwargs) + + def nvfp4_compute_global_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_global_scale(*args, **kwargs) + + def nvfp4_compute_per_block_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_per_block_scale(*args, **kwargs) + + def nvfp4_expand_scale_to_fp8(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_expand_scale_to_fp8(*args, **kwargs) + + def nvfp4_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_fused_scale(*args, **kwargs) + + def nvfp4_multi_tensor_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_fused_scale(*args, **kwargs) + + def nvfp4_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_partial_cast(*args, **kwargs) + + def nvfp4_multi_tensor_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_2d_partial_cast(*args, **kwargs) + + def nvfp4_2d_multi_tensor_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_multi_tensor_transpose(*args, **kwargs) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + cuda_graph, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + cuda_graph, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, out_per_step, lse, lse_per_step, cu_seqlens, only_second_half, lse_packed + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_scale_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_scale_tensor(*args, **kwargs) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + def multi_tensor_compute_scale_inv_e8m0(self, *args, **kwargs): + tex = self._get_tex() + return tex.multi_tensor_compute_scale_inv_e8m0(*args, **kwargs) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionMETAX + + return FlashAttentionMETAX + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/metax/register_ops.py b/transformer_engine/plugin/core/backends/vendor/metax/register_ops.py new file mode 100644 index 0000000000..cfe3a175ff --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/metax/register_ops.py @@ -0,0 +1,1173 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Metax vendor backend operator registrations. + +This module registers all VENDOR (Metax) implementations from transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Metax (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + # Import Metax backend to get all the wrapped tex functions + from .metax import MetaxBackend + + # Create a backend instance to access the methods + backend = MetaxBackend() + + # Check if Metax is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="METAX", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_grouped_tensor", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_grouped_tensor, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_in", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_in, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_out", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_out, is_avail), + vendor="METAX", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="group_quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.group_quantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="bgrad_group_quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_group_quantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="METAX", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="glu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.glu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="gelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="METAX", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dgelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="METAX", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="METAX", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="METAX", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="METAX", + priority=100, + ), + # Fused attention + OpImpl( + op_name="get_fused_attn_backend", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_attn_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_attn_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_bwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="METAX", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="METAX", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="METAX", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="METAX", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="METAX", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="METAX", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_data_transpose", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_data_transpose, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="swizzle_scales_for_gemm_", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swizzle_scales_for_gemm_, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="grouped_swizzle_for_gemm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.grouped_swizzle_for_gemm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="convert_host_pointers_to_tensor", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_host_pointers_to_tensor, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="get_device_pointer_for_data_and_scales", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_device_pointer_for_data_and_scales, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="splits_to_offsets", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.splits_to_offsets, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="METAX", + priority=100, + ), + # MXFP8 ops + OpImpl( + op_name="mxfp8_scaling_compute_partial_amax", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_compute_partial_amax, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_partial_cast", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_partial_cast, is_avail), + vendor="METAX", + priority=100, + ), + # NVFP4 ops + OpImpl( + op_name="nvfp4_2d_compute_partial_amax", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_compute_partial_amax, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_compute_partial_amax", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_compute_partial_amax, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_global_scale", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_global_scale, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_per_block_scale", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_per_block_scale, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_expand_scale_to_fp8", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_expand_scale_to_fp8, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_fused_scale", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_fused_scale, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_fused_scale", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_fused_scale, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_partial_cast", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_partial_cast, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_2d_partial_cast", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_2d_partial_cast, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_multi_tensor_transpose", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_multi_tensor_transpose, is_avail), + vendor="METAX", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="METAX", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="METAX", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="METAX", + priority=100, + ), + # NVSHMEM operations + OpImpl( + op_name="init_nvshmem_backend", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.init_nvshmem_backend, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="create_nvshmem_tensor", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_nvshmem_tensor, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvshmem_send_on_current_stream", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_send_on_current_stream, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvshmem_wait_on_current_stream", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_wait_on_current_stream, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="nvshmem_finalize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_finalize, is_avail), + vendor="METAX", + priority=100, + ), + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale_tensor", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale_tensor, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor="METAX", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="METAX", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="METAX", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="METAX", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="METAX", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/backends/vendor/musa/__init__.py b/transformer_engine/plugin/core/backends/vendor/musa/__init__.py new file mode 100644 index 0000000000..a76d0b41fd --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/musa/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .musa import MUSABackend + +__all__ = ["MUSABackend"] diff --git a/transformer_engine/plugin/core/backends/vendor/musa/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/musa/flash_attention.py new file mode 100644 index 0000000000..cd03e82414 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/musa/flash_attention.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionMUSA(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + # Store initialization parameters for lazy loading + self._init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx or nullcontext, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + self._musa_flash_attn = None + + def _ensure_musa_flash_attn(self): + """Lazy initialization of musa FlashAttention.""" + if self._musa_flash_attn is not None: + return + + try: + # Import here to avoid circular dependency issues + # transformer_engine_torch must be registered before this import + from transformer_engine_musa.pytorch.attention import ( + FlashAttention as FlashAttentionMusa, + ) + + if FlashAttentionMusa is None: + raise RuntimeError( + "FlashAttention class is None - flash-attn may not be installed correctly" + ) + + self._musa_flash_attn = FlashAttentionMusa(**self._init_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import musa FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize musa FlashAttention: {e}. Init params: {self._init_params}" + ) + + @property + def backend_name(self) -> str: + return "musa" + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.musa.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + # Ensure musa flash attention is initialized + self._ensure_musa_flash_attn() + + return self._musa_flash_attn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/musa/musa.py b/transformer_engine/plugin/core/backends/vendor/musa/musa.py new file mode 100644 index 0000000000..0de9350928 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/musa/musa.py @@ -0,0 +1,1829 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. +import os +import sys +from typing import Any, Dict, List, Optional, Tuple, Union +import torch +from ....ops import * + + +def _load_musa_libs(): + import ctypes + import os + import subprocess + from pathlib import Path + import importlib.util + import sysconfig + import platform + import glob as glob_module + + def get_ext(): + system = platform.system() + return ".so" if system == "Linux" else ".dylib" if system == "Darwin" else ".dll" + + ext = get_ext() + + def try_load_lib(name, search_patterns): + for env_var in [f"{name.upper()}_HOME", f"{name.upper()}_PATH"]: + path = os.environ.get(env_var) + if path: + libs = glob_module.glob(f"{path}/**/lib{name}{ext}*", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + musa_home = os.environ.get("MUSA_HOME") or os.environ.get("MUSA_PATH") or "/usr/local/musa" + for pattern in search_patterns: + libs = glob_module.glob(f"{musa_home}/**/{pattern}", recursive=True) + if libs: + libs.sort(reverse=True, key=os.path.basename) + try: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + result = subprocess.check_output(f"ldconfig -p | grep 'lib{name}{ext}'", shell=True) + for line in result.decode().split("\n"): + if f"lib{name}" in line and "=>" in line: + so_path = line.split(">")[1].strip() + if so_path: + return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + except: + pass + + try: + return ctypes.CDLL(f"lib{name}{ext}", mode=ctypes.RTLD_GLOBAL) + except: + return None + + try: + import transformer_engine_musa + + return True + except Exception as e: + return False + + +_musa_libs_loaded = False + + +def _ensure_musa_libs(): + global _musa_libs_loaded + if not _musa_libs_loaded: + _musa_libs_loaded = _load_musa_libs() + if _musa_libs_loaded: + print(f"[MUSA] Successfully loaded MUSA libs") + return _musa_libs_loaded + + +def _check_musa_available() -> bool: + try: + if not torch.musa.is_available(): + return False + else: + return True + except Exception as e: + return False + + +def _get_tex(): + _ensure_musa_libs() + import transformer_engine_musa + import transformer_engine_musa_torch + + return transformer_engine_musa_torch + + +class MUSABackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_musa_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_musa_available() + + def get_attention_backend(self, attention_params=None): + """ + MUSA backend uses the default attention backend selection logic. + This allows hardware-specific checks and optimizations for MUSA devices. + Returns: + Tuple of (use_flash_attention, flash_attention_backend, use_fused_attention, + fused_attention_backend, use_unfused_attention, available_backends) + """ + # Import the original get_attention_backend function + from transformer_engine_musa.pytorch.attention import ( + get_attention_backend as _original_get_attention_backend, + ) + + return _original_get_attention_backend(attention_params) + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.group_quantize(tensor, quantizer, num_tensors, first_dims) + + def bgrad_group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.bgrad_group_quantize(tensor, quantizer, num_tensors, first_dims) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GLU # + def glu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.glu(input, quantizer) + + # GELU and variants # + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GLU # + def dglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dglu(grad, fwd_input, quantizer) + + # Backward of GELU and variants # + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list, disable_bulk_allocation) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def te_general_grouped_gemm_for_grouped_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_grouped_tensor(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_in(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_in(*args, **kwargs) + + def te_general_grouped_gemm_for_discrete_out(self, *args, **kwargs): + tex = self._get_tex() + return tex.te_general_grouped_gemm_for_discrete_out(*args, **kwargs) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def nvfp4_data_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_data_transpose(*args, **kwargs) + + def swizzle_scales_for_gemm_(self, tensor: torch.Tensor) -> None: + tex = self._get_tex() + return tex.swizzle_scales_for_gemm_(tensor) + + def grouped_swizzle_for_gemm(self, *args, **kwargs): + tex = self._get_tex() + return tex.grouped_swizzle_for_gemm(*args, **kwargs) + + def convert_host_pointers_to_tensor(self, *args, **kwargs): + tex = self._get_tex() + return tex.convert_host_pointers_to_tensor(*args, **kwargs) + + def get_device_pointer_for_data_and_scales(self, *args, **kwargs): + tex = self._get_tex() + return tex.get_device_pointer_for_data_and_scales(*args, **kwargs) + + def splits_to_offsets(self, first_dims, logical_last_dim): + tex = self._get_tex() + return tex.splits_to_offsets(first_dims, logical_last_dim) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + # MXFP8 scaling + def mxfp8_scaling_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.mxfp8_scaling_compute_partial_amax(*args, **kwargs) + + def mxfp8_scaling_partial_cast(self, inp, out, scale, h, w, start_offset, block_len, out_dtype): + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.mxfp8_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + # NVFP4 2D + def nvfp4_2d_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_compute_partial_amax(*args, **kwargs) + + def nvfp4_multi_tensor_compute_partial_amax(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_compute_partial_amax(*args, **kwargs) + + def nvfp4_compute_global_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_global_scale(*args, **kwargs) + + def nvfp4_compute_per_block_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_compute_per_block_scale(*args, **kwargs) + + def nvfp4_expand_scale_to_fp8(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_expand_scale_to_fp8(*args, **kwargs) + + def nvfp4_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_fused_scale(*args, **kwargs) + + def nvfp4_multi_tensor_fused_scale(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_fused_scale(*args, **kwargs) + + def nvfp4_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_partial_cast(*args, **kwargs) + + def nvfp4_multi_tensor_2d_partial_cast(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_multi_tensor_2d_partial_cast(*args, **kwargs) + + def nvfp4_2d_multi_tensor_transpose(self, *args, **kwargs): + tex = self._get_tex() + return tex.nvfp4_2d_multi_tensor_transpose(*args, **kwargs) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + cuda_graph, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + cuda_graph, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, + freqs, + start_positions, + qkv_format, + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, out_per_step, lse, lse_per_step, cu_seqlens, only_second_half, lse_packed + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_scale_tensor( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale_tensor(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + def multi_tensor_compute_scale_inv_e8m0( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_inv_e8m0( + chunk_size, noop_flag, tensor_lists, block_len + ) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + from .flash_attention import FlashAttentionMusa + + return FlashAttentionMusa + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/musa/patches.py b/transformer_engine/plugin/core/backends/vendor/musa/patches.py new file mode 100644 index 0000000000..8073864d2b --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/musa/patches.py @@ -0,0 +1,73 @@ +"""Python-side compatibility patches for the MUSA vendor backend.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + + +def _noop(*args, **kwargs): + return None + + +# Patches: (parent_object, attribute_name, replacement_callable) +_PATCH_CALLS: list[tuple[object, str, Callable[..., object]]] = [ + # We do not recommend replace is_available, due to its device-related behavior. + # (torch.cuda, "is_available", torch.musa.is_available), + (torch.cuda, "get_device_properties", torch.musa.get_device_properties), + (torch.cuda, "device", torch.musa.device), + (torch.cuda, "current_device", torch.musa.current_device), + (torch.cuda, "synchronize", torch.musa.synchronize), + (torch.cuda, "is_current_stream_capturing", torch.musa.is_current_stream_capturing), + # TODO: Add NVTX patches for MUSA. + # NVTX is CUDA-specific; make it a no-op on MUSA. + (torch.cuda.nvtx, "range_push", _noop), + (torch.cuda.nvtx, "range_pop", _noop), + # TODO: Add other patches for MUSA. +] + + +def apply_patch() -> None: + """Apply MUSA Python-side patches (idempotent, best-effort).""" + try: + from .musa import MUSABackend + + if not MUSABackend().is_available(): + return + except Exception as e: + print(f"[TE-FL] MUSA backend not available: {e}") + # If backend availability can't be determined, don't patch. + return + + # Mark TE global device type for Python-side callers. + # IMPORTANT: do not import `transformer_engine` here, because TE's `__init__.py` + # imports this module to run patches and that would cause a circular import. + try: + import transformer_engine + + transformer_engine.TE_DEVICE_TYPE = "musa" + transformer_engine.TE_PLATFORM = torch.musa + except Exception as e: + print(f"[TE-FL Musa Patches] Error setting TE device type or platform: {e}") + # Best-effort: don't fail patching if we can't set the global. + pass + + # Only patch when torch.musa exists and is usable. + if not hasattr(torch, "musa"): + return + try: + if not torch.musa.is_available(): + return + except Exception: + return + + for parent, attr, replacement in _PATCH_CALLS: + if not hasattr(parent, attr): + continue + try: + setattr(parent, attr, replacement) + except Exception: + # Best-effort: patching should never crash import/initialization. + continue + print(f"[TE-FL] MUSA backend patches applied") diff --git a/transformer_engine/plugin/core/backends/vendor/musa/register_ops.py b/transformer_engine/plugin/core/backends/vendor/musa/register_ops.py new file mode 100644 index 0000000000..cb3e3b7d29 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/musa/register_ops.py @@ -0,0 +1,1171 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +MUSA vendor backend operator registrations. + +This module registers all VENDOR (MUSA) implementations from transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all MUSA (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + # Import MUSA backend to get all the wrapped tex functions + from .musa import MUSABackend + + # Create a backend instance to access the methods + backend = MUSABackend() + + # Check if MUSA is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_grouped_tensor", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_grouped_tensor, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_in", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_in, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm_for_discrete_out", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm_for_discrete_out, is_avail), + vendor="MUSA", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="group_quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.group_quantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="bgrad_group_quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_group_quantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="MUSA", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="glu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.glu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="gelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.sreglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="MUSA", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dgelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="MUSA", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="MUSA", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="MUSA", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + # Fused attention + OpImpl( + op_name="get_fused_attn_backend", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_attn_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_attn_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="MUSA", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="MUSA", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="MUSA", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="MUSA", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_data_transpose", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_data_transpose, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="swizzle_scales_for_gemm_", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swizzle_scales_for_gemm_, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="grouped_swizzle_for_gemm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.grouped_swizzle_for_gemm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="convert_host_pointers_to_tensor", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_host_pointers_to_tensor, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="get_device_pointer_for_data_and_scales", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_device_pointer_for_data_and_scales, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="splits_to_offsets", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.splits_to_offsets, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_compute_partial_amax", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_compute_partial_amax, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="mxfp8_scaling_partial_cast", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.mxfp8_scaling_partial_cast, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_compute_partial_amax", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_compute_partial_amax, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_compute_partial_amax", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_compute_partial_amax, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_global_scale", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_global_scale, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_compute_per_block_scale", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_compute_per_block_scale, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_expand_scale_to_fp8", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_expand_scale_to_fp8, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_fused_scale", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_fused_scale, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_fused_scale", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_fused_scale, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_partial_cast", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_partial_cast, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_multi_tensor_2d_partial_cast", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_multi_tensor_2d_partial_cast, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvfp4_2d_multi_tensor_transpose", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvfp4_2d_multi_tensor_transpose, is_avail), + vendor="MUSA", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="MUSA", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="MUSA", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="MUSA", + priority=100, + ), + # NVSHMEM operations + OpImpl( + op_name="init_nvshmem_backend", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.init_nvshmem_backend, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="create_nvshmem_tensor", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_nvshmem_tensor, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvshmem_send_on_current_stream", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_send_on_current_stream, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvshmem_wait_on_current_stream", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_wait_on_current_stream, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="nvshmem_finalize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_finalize, is_avail), + vendor="MUSA", + priority=100, + ), + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale_tensor", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale_tensor, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor="MUSA", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="MUSA", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="MUSA", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="MUSA", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="MUSA", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/builtin_ops.py b/transformer_engine/plugin/core/builtin_ops.py new file mode 100644 index 0000000000..c991d4fc51 --- /dev/null +++ b/transformer_engine/plugin/core/builtin_ops.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Built-in operator implementations registration. + +This module registers DEFAULT (FlagOS) and REFERENCE (PyTorch) implementations +for all supported operators by calling register_builtins from each backend. +""" + +from __future__ import annotations + +from .registry import OpRegistry + + +def register_builtins(registry: OpRegistry) -> None: + """ + Register all built-in operator implementations. + + This function registers: + - DEFAULT implementations (FlagOS/flag_gems) + - REFERENCE implementations (PyTorch) + - VENDOR implementations (CUDA, if available) + + Args: + registry: Registry to register into + """ + # Register FlagOS (DEFAULT) implementations + try: + from .backends.flagos.register_ops import register_builtins as register_flagos + + register_flagos(registry) + except Exception as e: + print(f"[WARNING] Failed to register FlagOS operators: {e}") + + # Register PyTorch (REFERENCE) implementations + try: + from .backends.reference.register_ops import register_builtins as register_reference + + register_reference(registry) + except Exception as e: + print(f"[WARNING] Failed to register Reference operators: {e}") + + # Register CUDA (VENDOR) implementations + try: + from .backends.vendor.cuda.register_ops import register_builtins as register_cuda + + register_cuda(registry) + except Exception as e: + # CUDA may not be available, this is expected + pass + + # Register HYGON (VENDOR) implementations + try: + from .backends.vendor.hygon.register_ops import register_builtins as register_hygon + + register_hygon(registry) + except Exception as e: + # HYGON may not be available, this is expected + pass + + # Register Metax (VENDOR) implementations + try: + from .backends.vendor.metax.register_ops import register_builtins as register_metax + + register_metax(registry) + except Exception as e: + # Metax may not be available, this is expected + pass + + # Register KUNLUNXIN (VENDOR) implementations + try: + from .backends.vendor.kunlunxin.register_ops import register_builtins as register_kunlunxin + + register_kunlunxin(registry) + except Exception as e: + # KunLunXin may not be available, this is expected + pass + + # Register Iluvatar (VENDOR) implementations + try: + from .backends.vendor.iluvatar.register_ops import register_builtins as register_iluvatar + + register_iluvatar(registry) + except Exception as e: + # Iluvatar may not be available, this is expected + pass + + # Register MUSA (VENDOR) implementations + try: + from .backends.vendor.musa.register_ops import register_builtins as register_musa + + register_musa(registry) + except Exception as e: + # MUSA may not be available, this is expected + pass diff --git a/transformer_engine/plugin/core/discovery.py b/transformer_engine/plugin/core/discovery.py new file mode 100644 index 0000000000..cfde3f4774 --- /dev/null +++ b/transformer_engine/plugin/core/discovery.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import importlib +import os +import sys +from typing import Any, Callable, List, Optional, Tuple + +from .logger_manager import get_logger + +PLUGIN_GROUP = "te_fl.plugin" + +PLUGIN_MODULES_ENV = "TE_FL_PLUGIN_MODULES" + +logger = get_logger() + +_discovered_plugin: List[Tuple[str, str, bool]] = [] + + +def _log_debug(msg: str) -> None: + logger.debug(msg) + + +def _log_info(msg: str) -> None: + logger.info(msg) + + +def _log_warning(msg: str) -> None: + logger.warning(msg) + + +def _log_error(msg: str) -> None: + logger.error(msg) + + +def _get_entry_points(): + try: + from importlib.metadata import entry_points + except ImportError: + try: + from importlib_metadata import entry_points + except ImportError: + _log_debug("importlib.metadata not available, skipping entry points discovery") + return [] + + try: + eps = entry_points() + + if hasattr(eps, "select"): + return list(eps.select(group=PLUGIN_GROUP)) + + if isinstance(eps, dict): + return eps.get(PLUGIN_GROUP, []) + + if hasattr(eps, "get"): + return eps.get(PLUGIN_GROUP, []) + + return [] + + except Exception as e: + _log_warning(f"Error accessing entry points: {e}") + return [] + + +def _call_register_function( + obj: Any, + registry_module: Any, + source_name: str, +) -> bool: + if callable(obj) and not isinstance(obj, type): + try: + obj(registry_module) + _log_info(f"Registered plugin from {source_name} (direct callable)") + return True + except Exception as e: + _log_error(f"Error calling plugin {source_name}: {e}") + return False + + register_fn = getattr(obj, "te_fl_register", None) or getattr(obj, "register", None) + + if callable(register_fn): + try: + register_fn(registry_module) + _log_info(f"Registered plugin from {source_name}") + return True + except Exception as e: + _log_error(f"Error calling register function in {source_name}: {e}") + return False + + _log_debug(f"No register function found in {source_name}") + return False + + +def discover_from_entry_points(registry_module: Any) -> int: + loaded = 0 + entry_points_list = _get_entry_points() + + if not entry_points_list: + _log_debug("No entry points found for group: " + PLUGIN_GROUP) + return 0 + + _log_debug(f"Found {len(entry_points_list)} entry points") + + for ep in entry_points_list: + ep_name = getattr(ep, "name", str(ep)) + try: + _log_debug(f"Loading entry point: {ep_name}") + obj = ep.load() + + if _call_register_function(obj, registry_module, f"entry_point:{ep_name}"): + _discovered_plugin.append((ep_name, "entry_point", True)) + loaded += 1 + else: + _discovered_plugin.append((ep_name, "entry_point", False)) + + except Exception as e: + _log_error(f"Failed to load entry point {ep_name}: {e}") + _discovered_plugin.append((ep_name, "entry_point", False)) + + return loaded + + +def discover_from_env_modules(registry_module: Any) -> int: + modules_str = os.environ.get(PLUGIN_MODULES_ENV, "").strip() + + if not modules_str: + return 0 + + loaded = 0 + module_names = [m.strip() for m in modules_str.split(",") if m.strip()] + + _log_debug(f"Loading plugin from env var: {module_names}") + + for mod_name in module_names: + try: + _log_debug(f"Importing module: {mod_name}") + mod = importlib.import_module(mod_name) + + if _call_register_function(mod, registry_module, f"env_module:{mod_name}"): + _discovered_plugin.append((mod_name, "env_module", True)) + loaded += 1 + else: + _discovered_plugin.append((mod_name, "env_module", False)) + + except ImportError as e: + _log_error(f"Failed to import plugin module {mod_name}: {e}") + _discovered_plugin.append((mod_name, "env_module", False)) + except Exception as e: + _log_error(f"Error loading plugin module {mod_name}: {e}") + _discovered_plugin.append((mod_name, "env_module", False)) + + return loaded + + +def discover_plugin(registry_module: Any) -> int: + """ + Main plugin discovery function. + + Discovers and registers plugin from: + 1. Entry points (group: 'te_fl.plugin') + 2. Environment variable modules (TE_FL_PLUGIN_MODULES) + + Args: + registry_module: OpRegistry instance to register plugin to + + Returns: + Number of successfully loaded plugin + """ + if registry_module is None: + _log_warning("Registry module is None, skipping plugin discovery") + return 0 + + _log_debug("Starting plugin discovery...") + + total = 0 + + total += discover_from_entry_points(registry_module) + + total += discover_from_env_modules(registry_module) + + _log_debug(f"Plugin discovery complete. Loaded {total} plugin.") + + return total + + +# Alias for compatibility with different naming conventions +discover_op_plugin = discover_plugin + + +def get_discovered_plugin() -> List[Tuple[str, str, bool]]: + """Get list of discovered plugin (name, source, success)""" + return _discovered_plugin.copy() + + +def clear_discovered_plugin() -> None: + """Clear the discovered plugin list (for testing)""" + _discovered_plugin.clear() diff --git a/transformer_engine/plugin/core/logger_manager.py b/transformer_engine/plugin/core/logger_manager.py new file mode 100644 index 0000000000..899d067e3e --- /dev/null +++ b/transformer_engine/plugin/core/logger_manager.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import logging +import sys +import os +import threading + + +class Logger: + def __init__(self, name, level=logging.INFO): + self.logger = logging.getLogger(name) + self.logger.setLevel(level) + self.logger.propagate = False + for handler in self.logger.handlers[:]: + self.logger.removeHandler(handler) + + formatter = logging.Formatter( + "[%(asctime)s %(name)s %(filename)s:%(lineno)d %(levelname)s] %(message)s" + ) + + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + + self.logger.addHandler(stream_handler) + self._printed_once = set() + + def info(self, message): + self.logger.info(message, stacklevel=2) + + def warning(self, message): + self.logger.warning(message, stacklevel=2) + + def error(self, message): + self.logger.error(message, stacklevel=2) + + def critical(self, message): + self.logger.critical(message, stacklevel=2) + + def debug(self, message): + self.logger.debug(message, stacklevel=2) + + def info_once(self, message): + if message not in self._printed_once: + self._printed_once.add(message) + self.logger.info(message, stacklevel=2) + + def warning_once(self, message): + if message not in self._printed_once: + self._printed_once.add(message) + self.logger.warning(message, stacklevel=2) + + def error_once(self, message): + if message not in self._printed_once: + self._printed_once.add(message) + self.logger.error(message, stacklevel=2) + + def debug_once(self, message): + if message not in self._printed_once: + self._printed_once.add(message) + self.logger.debug(message, stacklevel=2) + + +class LoggerManager: + _instance = None + _lock = threading.Lock() + + def __init__(self): + if hasattr(self, "_global_logger"): + return + + self._global_logger = None + self._global_printed_once = set() + self._printed_once_lock = threading.Lock() + + @classmethod + def get_instance(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.__init__() + return cls._instance + + def get_logger(self): + if self._global_logger is None: + with self._lock: + if self._global_logger is None: + level = os.getenv("TEFL_LOG_LEVEL", "INFO").upper() + self._global_logger = Logger("TE-FL", level) + return self._global_logger + + def print_once(self, message): + with self._printed_once_lock: + if message not in self._global_printed_once: + self._global_printed_once.add(message) + print(message) + + def debug_print_once(self, func_name: str, backend_name: str = "Backend", *args, **kwargs): + key = f"{backend_name}.{func_name}" + + with self._printed_once_lock: + if key not in self._global_printed_once: + self._global_printed_once.add(key) + print(f"[{backend_name}] Calling {func_name}") + if args: + print(f" args: {[type(a).__name__ for a in args[:5]]}...") + if kwargs: + print(f" kwargs: {list(kwargs.keys())[:5]}...") + print(f"[{backend_name}] {func_name} completed successfully") + + def reset(self): + with self._lock: + with self._printed_once_lock: + self._global_logger = None + self._global_printed_once.clear() + + +def get_logger(): + return LoggerManager.get_instance().get_logger() + + +def print_once(message): + LoggerManager.get_instance().print_once(message) + + +def debug_print_once(func_name: str, backend_name: str = "Backend", *args, **kwargs): + LoggerManager.get_instance().debug_print_once(func_name, backend_name, *args, **kwargs) diff --git a/transformer_engine/plugin/core/manager.py b/transformer_engine/plugin/core/manager.py new file mode 100644 index 0000000000..0a53c11f31 --- /dev/null +++ b/transformer_engine/plugin/core/manager.py @@ -0,0 +1,634 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import os +import threading +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple, Any + +from .discovery import discover_plugin +from .registry import OpRegistry +from .policy import SelectionPolicy, get_policy +from .types import OpImpl, BackendImplKind, match_token +from .logger_manager import get_logger + +logger = get_logger() + + +@dataclass +class _OpManagerState: + """Internal state for OpManager""" + + init_pid: int = -1 + initialized: bool = False + policy_epoch: int = 0 + + +class OpManager: + """ + Main manager for operator dispatching and selection. + + Responsibilities: + - Lazy initialization and plugin discovery + - Multi-process safety (PID detection + at_fork) + - Policy-based operator selection + - Dispatch caching with invalidation + """ + + def __init__(self, registry: Optional[OpRegistry] = None) -> None: + self._lock = threading.RLock() + self._registry = registry or OpRegistry() + self._state = _OpManagerState() + self._dispatch_cache: Dict[Tuple[str, str, int], Callable] = {} + self._impl_cache: Dict[str, OpImpl] = {} + self._impl_cache_meta: Dict[str, Tuple[str, int]] = {} + + # Register at_fork handler for multi-process safety + try: + os.register_at_fork(after_in_child=self._reset_after_fork) + except AttributeError: + # os.register_at_fork not available (Windows) + pass + + @property + def registry(self) -> OpRegistry: + """Get the underlying operator registry""" + return self._registry + + def _reset_after_fork(self) -> None: + """Reset state after process fork""" + with self._lock: + self._state.initialized = False + self._state.init_pid = -1 + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + self._impl_cache.clear() + self._impl_cache_meta.clear() + logger.debug("OpManager reset after fork") + + def bump_policy_epoch(self) -> None: + """ + Increment policy epoch to invalidate dispatch cache. + + Call this when policy changes at runtime. + """ + with self._lock: + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + logger.debug(f"Policy epoch bumped to {self._state.policy_epoch}") + + def ensure_initialized(self) -> None: + """ + Ensure the manager is initialized in the current process. + + Performs: + 1. PID check (multi-process safety) + 2. Register built-in operator implementations + 3. Discover and register plugin + """ + with self._lock: + pid = os.getpid() + + # Check if already initialized in this process + if self._state.initialized and self._state.init_pid == pid: + return + + logger.debug(f"Initializing OpManager in PID {pid}") + + # Mark as initialized + self._state.initialized = True + self._state.init_pid = pid + + # Register built-in operators + from . import builtin_ops + + builtin_ops.register_builtins(self._registry) + + # Discover and register plugin + discover_plugin(self._registry) + + # Invalidate cache + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + + # Print initialization summary + snap = self._registry.snapshot() + total_ops = len(snap.impls_by_op) + total_impls = sum(len(impls) for impls in snap.impls_by_op.values()) + + logger.info( + f"OpManager initialized: {total_ops} ops with {total_impls} implementations" + ) + + # Group implementations by kind for summary + vendor_count = sum( + 1 + for impls in snap.impls_by_op.values() + for impl in impls + if impl.kind == BackendImplKind.VENDOR + ) + reference_count = sum( + 1 + for impls in snap.impls_by_op.values() + for impl in impls + if impl.kind == BackendImplKind.REFERENCE + ) + default_count = sum( + 1 + for impls in snap.impls_by_op.values() + for impl in impls + if impl.kind == BackendImplKind.DEFAULT + ) + + logger.debug( + f" Vendor: {vendor_count}, Default: {default_count}, Reference: {reference_count}" + ) + + # List all registered impl_ids + if logger.logger.isEnabledFor(logger.logger.level): + impl_ids = sorted( + set(impl.impl_id for impls in snap.impls_by_op.values() for impl in impls) + ) + logger.info(f"Registered impl_ids: {impl_ids}") + + def _matches_vendor_filters(self, impl: OpImpl, policy: SelectionPolicy) -> bool: + """Check if implementation matches policy vendor filters""" + if impl.kind != BackendImplKind.VENDOR: + return True + + if impl.vendor is None: + return False + + # Check deny list + if impl.vendor in policy.deny_vendors: + return False + + # Check allow list (if specified) + if policy.allow_vendors is not None and impl.vendor not in policy.allow_vendors: + return False + + return True + + def _default_order(self, policy: SelectionPolicy) -> list[str]: + """Get default selection order based on policy""" + return policy.get_default_order() + + def resolve(self, op_name: str) -> Callable: + """ + Resolve and return the best implementation for an operator. + + Selection process: + 1. Check dispatch cache + 2. Get all registered implementations + 3. Filter by policy (vendor allow/deny) + 4. Filter by availability (is_available()) + 5. Select best match using per-op order or default order + 6. Cache the result + + Args: + op_name: Name of the operator to resolve + + Returns: + Callable implementation function + + Raises: + RuntimeError: If no implementation found + """ + self.ensure_initialized() + + policy = get_policy() + policy_fp = policy.fingerprint() + epoch = self._state.policy_epoch + + # Check cache + cache_key = (op_name, policy_fp, epoch) + cached = self._dispatch_cache.get(cache_key) + if cached is not None: + return cached + + # Get all implementations for this operator + snap = self._registry.snapshot() + candidates = list(snap.impls_by_op.get(op_name, [])) + + # Filter by vendor policy + candidates = [c for c in candidates if self._matches_vendor_filters(c, policy)] + + # Filter by availability + available: list[OpImpl] = [] + for c in candidates: + try: + if c.is_available(): + available.append(c) + else: + logger.debug(f"Implementation {c.impl_id} not available for op={op_name}") + except Exception as e: + logger.warning(f"Error checking availability of {c.impl_id}: {e}") + continue + + candidates = available + + if not candidates: + raise RuntimeError( + f"No available implementation for op='{op_name}'. " + f"Registered: {[impl.impl_id for impl in snap.impls_by_op.get(op_name, [])]}" + ) + + # Get selection order (per-op or default) + order = policy.per_op_order_dict.get(op_name) or self._default_order(policy) + + # Select best implementation + chosen: Optional[OpImpl] = None + for token in order: + matches = [c for c in candidates if match_token(c, token)] + if not matches: + continue + + # Sort by priority (higher first), then by impl_id for stability + matches.sort(key=lambda x: (x.priority, x.impl_id), reverse=True) + chosen = matches[0] + break + + if chosen is None: + if policy.strict: + raise RuntimeError( + f"No implementation available for op='{op_name}' under strict policy. " + f"Candidates: {[c.impl_id for c in candidates]}" + ) + raise RuntimeError( + f"No implementation selected for op='{op_name}'. " + f"Candidates: {[c.impl_id for c in candidates]}, Order: {order}" + ) + + # Cache the result + self._dispatch_cache[cache_key] = chosen.fn + return chosen.fn + + def resolve_candidates(self, op_name: str) -> list[OpImpl]: + """ + Resolve and return all available implementations for an operator, + sorted by priority (highest first). + + This is similar to resolve() but returns all viable candidates + instead of just the best one. Useful for fallback mechanisms. + + Args: + op_name: Name of the operator to resolve + + Returns: + List of OpImpl sorted by priority (highest first) + + Raises: + RuntimeError: If no implementation found + """ + self.ensure_initialized() + + policy = get_policy() + + # Get all implementations for this operator + snap = self._registry.snapshot() + candidates = list(snap.impls_by_op.get(op_name, [])) + + # Filter by vendor policy + candidates = [c for c in candidates if self._matches_vendor_filters(c, policy)] + + # Filter by availability + available: list[OpImpl] = [] + for c in candidates: + try: + if c.is_available(): + available.append(c) + else: + logger.debug(f"Implementation {c.impl_id} not available for op={op_name}") + except Exception as e: + logger.warning(f"Error checking availability of {c.impl_id}: {e}") + continue + + candidates = available + + if not candidates: + raise RuntimeError( + f"No available implementation for op='{op_name}'. " + f"Registered: {[impl.impl_id for impl in snap.impls_by_op.get(op_name, [])]}" + ) + + # Get selection order (per-op or default) + order = policy.per_op_order_dict.get(op_name) or self._default_order(policy) + + # Sort candidates by order tokens, then by priority + sorted_candidates: list[OpImpl] = [] + for token in order: + matches = [c for c in candidates if match_token(c, token)] + if matches: + # Sort by priority (higher first), then by impl_id for stability + matches.sort(key=lambda x: (x.priority, x.impl_id), reverse=True) + sorted_candidates.extend(matches) + + # Remove duplicates while preserving order + seen = set() + unique_candidates = [] + for c in sorted_candidates: + if c.impl_id not in seen: + seen.add(c.impl_id) + unique_candidates.append(c) + + if not unique_candidates: + raise RuntimeError( + f"No implementation selected for op='{op_name}'. " + f"Candidates: {[c.impl_id for c in candidates]}, Order: {order}" + ) + + return unique_candidates + + def _is_cache_valid(self, op_name: str) -> bool: + """Check if cached impl is still valid for current policy""" + meta = self._impl_cache_meta.get(op_name) + if meta is None: + return False + cached_fp, cached_epoch = meta + policy = get_policy() + return cached_fp == policy.fingerprint() and cached_epoch == self._state.policy_epoch + + def _update_cache(self, op_name: str, impl: OpImpl) -> None: + """Update cache with new impl""" + policy = get_policy() + self._impl_cache[op_name] = impl + self._impl_cache_meta[op_name] = (policy.fingerprint(), self._state.policy_epoch) + + def _invalidate_cache(self, op_name: str) -> None: + """Invalidate cache for an op""" + self._impl_cache.pop(op_name, None) + self._impl_cache_meta.pop(op_name, None) + + def _get_last_impl_id(self, op_name: str) -> Optional[str]: + """Get last used impl_id (even if cache is stale)""" + impl = self._impl_cache.get(op_name) + return impl.impl_id if impl else None + + def call(self, op_name: str, *args, **kwargs): + """ + Resolve and call an operator implementation with optional fallback support. + + Logs on first call or when the implementation changes. + + Args: + op_name: Name of the operator + *args, **kwargs: Arguments passed to the implementation + + Returns: + Result from the implementation + + Raises: + RuntimeError: If all implementations fail + """ + enable_fallback = os.getenv("TE_FL_STRICT", "1") != "0" + + cached_impl = self._impl_cache.get(op_name) + cache_valid = self._is_cache_valid(op_name) + + if cache_valid and cached_impl is not None: + try: + return cached_impl.fn(*args, **kwargs) + except Exception as e: + if enable_fallback: + logger.warning_once( + f"Cached implementation '{cached_impl.impl_id}' failed for op" + f" '{op_name}': {e}" + ) + self._invalidate_cache(op_name) + else: + raise + + last_impl_id = self._get_last_impl_id(op_name) + + if not enable_fallback: + fn = self.resolve(op_name) + + snap = self._registry.snapshot() + for candidate in snap.impls_by_op.get(op_name, []): + if candidate.fn is fn: + self._update_cache(op_name, candidate) + + if last_impl_id is None: + logger.info_once( + f"Op '{op_name}' using '{candidate.impl_id}' " + f"(kind={candidate.kind.value}, vendor={candidate.vendor})" + ) + elif last_impl_id != candidate.impl_id: + logger.info_once( + f"Op '{op_name}' switched from '{last_impl_id}' to" + f" '{candidate.impl_id}' (kind={candidate.kind.value}," + f" vendor={candidate.vendor})" + ) + break + + return fn(*args, **kwargs) + + candidates = self.resolve_candidates(op_name) + last_error = None + + for idx, impl in enumerate(candidates): + try: + result = impl.fn(*args, **kwargs) + + self._update_cache(op_name, impl) + + if last_impl_id is None: + logger.info_once( + f"Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + elif last_impl_id != impl.impl_id: + if idx == 0: + logger.info_once( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info_once( + f"Op '{op_name}' fallback to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + + return result + + except Exception as e: + last_error = e + if idx < len(candidates) - 1: + logger.warning_once( + f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + else: + logger.error( + f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + + raise RuntimeError( + f"All {len(candidates)} implementation(s) failed for op='{op_name}'. " + f"Last error: {last_error}" + ) from last_error + + def call_with_custom_impl( + self, + op_name: str, + current_impl_class: type, + call_impl_fn: Callable[[type], Any], + ): + """ + Call an operator with custom implementation class support (for FlashAttention). + + Args: + op_name: Name of the operator + current_impl_class: The current implementation class + call_impl_fn: Function that takes impl_class and calls it + + Returns: + Result from the implementation + """ + enable_fallback = os.getenv("TE_FL_STRICT", "1") != "0" + + cached_impl = self._impl_cache.get(op_name) + cache_valid = self._is_cache_valid(op_name) + + if cache_valid and cached_impl is not None: + try: + cached_class = cached_impl.fn() + return call_impl_fn(cached_class) + except Exception as e: + if enable_fallback: + logger.warning_once( + f"Cached implementation '{cached_impl.impl_id}' failed for op" + f" '{op_name}': {e}" + ) + self._invalidate_cache(op_name) + else: + raise + + last_impl_id = self._get_last_impl_id(op_name) + + if not enable_fallback: + snap = self._registry.snapshot() + for impl in snap.impls_by_op.get(op_name, []): + try: + impl_class = impl.fn() + if impl_class == current_impl_class: + result = call_impl_fn(impl_class) + + self._update_cache(op_name, impl) + + if last_impl_id is None: + logger.info_once( + f"Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + elif last_impl_id != impl.impl_id: + logger.info_once( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}'" + f" (kind={impl.kind.value}, vendor={impl.vendor})" + ) + return result + except Exception: + continue + + return call_impl_fn(current_impl_class) + + candidates = self.resolve_candidates(op_name) + last_error = None + current_impl_id = None + + for impl in candidates: + try: + if impl.fn() == current_impl_class: + current_impl_id = impl.impl_id + break + except: + continue + + for idx, impl in enumerate(candidates): + try: + impl_class = impl.fn() + result = call_impl_fn(impl_class) + + self._update_cache(op_name, impl) + + if last_impl_id is None: + logger.info_once( + f"Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + elif last_impl_id != impl.impl_id: + if impl.impl_id == current_impl_id or idx == 0: + logger.info_once( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info_once( + f"Op '{op_name}' fallback to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + + return result + + except Exception as e: + last_error = e + if idx < len(candidates) - 1: + logger.warning_once( + f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + else: + logger.error( + f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + + raise RuntimeError( + f"All {len(candidates)} implementation(s) failed for op='{op_name}'. " + f"Last error: {last_error}" + ) from last_error + + def get_selected_impl_id(self, op_name: str) -> str: + """ + Get the impl_id of the currently selected implementation. + + Args: + op_name: Name of the operator + + Returns: + Implementation ID string + """ + fn = self.resolve(op_name) + + # Try to find the impl by function identity + snap = self._registry.snapshot() + for impl in snap.impls_by_op.get(op_name, []): + if impl.fn is fn: + return impl.impl_id + + return "unknown" + + +# Global default instance +_default_manager: Optional[OpManager] = None +_manager_lock = threading.RLock() + + +def get_default_manager() -> OpManager: + """Get or create the global default OpManager instance""" + global _default_manager + + if _default_manager is None: + with _manager_lock: + if _default_manager is None: + _default_manager = OpManager() + + return _default_manager + + +def reset_default_manager() -> None: + """Reset the global default OpManager (useful for testing)""" + global _default_manager + + with _manager_lock: + _default_manager = None diff --git a/transformer_engine/plugin/core/ops.py b/transformer_engine/plugin/core/ops.py new file mode 100644 index 0000000000..48292c7b22 --- /dev/null +++ b/transformer_engine/plugin/core/ops.py @@ -0,0 +1,2036 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type +from enum import IntEnum +from contextlib import nullcontext +import torch + +from .logger_manager import get_logger + +logger = get_logger() + + +################### Enums ################### +class DType(IntEnum): + kByte = 0 + kInt16 = 1 + kInt32 = 2 + kInt64 = 3 + kFloat32 = 4 + kFloat16 = 5 + kBFloat16 = 6 + kFloat8E4M3 = 7 + kFloat8E5M2 = 8 + kFloat8E8M0 = 9 + kFloat4E2M1 = 10 + kNumTypes = 11 + + +class Float8BlockScaleTensorFormat(IntEnum): + GEMM_READY = 0 + COMPACT = 1 + + +class NVTE_Activation_Type(IntEnum): + GELU = 0 + GEGLU = 1 + SILU = 2 + SWIGLU = 3 + RELU = 4 + REGLU = 5 + QGELU = 6 + QGEGLU = 7 + SRELU = 8 + SREGLU = 9 + CLAMPED_SWIGLU = 10 + + +class NVTE_Softmax_Type(IntEnum): + NVTE_VANILLA_SOFTMAX = 0 + NVTE_OFF_BY_ONE_SOFTMAX = 1 + NVTE_LEARNABLE_SOFTMAX = 2 + + +class CommGemmOverlapRole(IntEnum): + INPUT = 0 + OUTPUT = 1 + + +class FP8FwdTensors(IntEnum): + GEMM1_INPUT = 0 + GEMM1_WEIGHT = 1 + GEMM1_OUTPUT = 2 + GEMM2_INPUT = 3 + GEMM2_WEIGHT = 4 + GEMM2_OUTPUT = 5 + GEMM3_INPUT = 6 + GEMM3_WEIGHT = 7 + GEMM3_OUTPUT = 8 + + +class FP8BwdTensors(IntEnum): + GRAD_OUTPUT1 = 0 + GRAD_INPUT1 = 1 + GRAD_OUTPUT2 = 2 + GRAD_INPUT2 = 3 + GRAD_OUTPUT3 = 4 + GRAD_INPUT3 = 5 + + +class NVTE_Bias_Type(IntEnum): + NVTE_NO_BIAS = 0 + NVTE_PRE_SCALE_BIAS = 1 + NVTE_POST_SCALE_BIAS = 2 + NVTE_ALIBI = 3 + + +class NVTE_Mask_Type(IntEnum): + NVTE_NO_MASK = 0 + NVTE_PADDING_MASK = 1 + NVTE_CAUSAL_MASK = 2 + NVTE_PADDING_CAUSAL_MASK = 3 + NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4 + NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5 + + +class NVTE_Fused_Attn_Backend(IntEnum): + NVTE_No_Backend = -1 + NVTE_F16_max512_seqlen = 0 + NVTE_F16_arbitrary_seqlen = 1 + NVTE_FP8 = 2 + + +class NVTE_QKV_Format(IntEnum): + NVTE_SBHD = 0 + NVTE_BSHD = 1 + NVTE_THD = 2 + NVTE_BSHD_2SBHD = 3 + NVTE_SBHD_2BSHD = 4 + NVTE_THD_2BSHD = 5 + NVTE_THD_2SBHD = 6 + + +class NVTE_QKV_Layout(IntEnum): + NVTE_SB3HD = 0 + NVTE_SBH3D = 1 + NVTE_SBHD_SB2HD = 2 + NVTE_SBHD_SBH2D = 3 + NVTE_SBHD_SBHD_SBHD = 4 + NVTE_BS3HD = 5 + NVTE_BSH3D = 6 + NVTE_BSHD_BS2HD = 7 + NVTE_BSHD_BSH2D = 8 + NVTE_BSHD_BSHD_BSHD = 9 + NVTE_T3HD = 10 + NVTE_TH3D = 11 + NVTE_THD_T2HD = 12 + NVTE_THD_TH2D = 13 + NVTE_THD_THD_THD = 14 + NVTE_SBHD_BSHD_BSHD = 15 + NVTE_BSHD_SBHD_SBHD = 16 + NVTE_THD_BSHD_BSHD = 17 + NVTE_THD_SBHD_SBHD = 18 + NVTE_Paged_KV_BSHD_BSHD_BSHD = 19 + NVTE_Paged_KV_BSHD_SBHD_SBHD = 20 + NVTE_Paged_KV_SBHD_BSHD_BSHD = 21 + NVTE_Paged_KV_SBHD_SBHD_SBHD = 22 + NVTE_Paged_KV_THD_BSHD_BSHD = 23 + NVTE_Paged_KV_THD_SBHD_SBHD = 24 + + +class CommOverlapType(IntEnum): + RS = 0 + AG = 1 + + +class CommOverlapAlgo(IntEnum): + BULK_OVERLAP_AG = 0 + BULK_OVERLAP_RS = 1 + SPLIT_PIPELINED_AG_P2P = 2 + SPLIT_PIPELINED_RS = 3 + SPLIT_PIPELINED_RS_P2P = 4 + ATOMIC_GEMM_RS = 5 + ATOMIC_GEMM_AG_P2P = 6 + ATOMIC_GEMM_RS_P2P = 7 + EXTERNAL_BULK_OVERLAP_AG = 8 + + +############ Class ################# + + +class FP8TensorMeta: + """ + FP8TensorMeta wrapper that routes to the appropriate backend implementation. + """ + + def __new__(cls, *args, **kwargs): + from .manager import get_default_manager + + return get_default_manager().call("create_fp8_tensor_meta", *args, **kwargs) + + +class CommOverlapHelper: + """ + CommOverlapHelper wrapper that routes to the appropriate backend implementation. + """ + + def __new__(cls, *args, **kwargs): + from .manager import get_default_manager + + return get_default_manager().call("create_comm_overlap_helper", *args, **kwargs) + + +class CommOverlap: + """ + CommOverlap wrapper that routes to the appropriate backend implementation. + """ + + def __new__(cls, *args, **kwargs): + from .manager import get_default_manager + + return get_default_manager().call("create_comm_overlap", *args, **kwargs) + + +class CommOverlapP2P: + """ + CommOverlapP2P wrapper that routes to the appropriate backend implementation. + """ + + def __new__(cls, *args, **kwargs): + from .manager import get_default_manager + + return get_default_manager().call("create_comm_overlap_p2p", *args, **kwargs) + + +class FlashAttentionBase(torch.nn.Module, ABC): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__() + + self.softmax_scale = softmax_scale + self.attention_dropout = attention_dropout + self.attention_dropout_ctx = attention_dropout_ctx or nullcontext + self.attention_type = attention_type + self.layer_number = 1 if layer_number is None else layer_number + self.deterministic = deterministic + + # For fallback support + self._manager = None + self._init_params = None + + @abstractmethod + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + """ + Actual forward implementation - subclasses must implement this. + + This method contains the backend-specific logic for flash attention. + """ + raise NotImplementedError("Subclasses must implement _forward_impl()") + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + num_splits: Optional[int] = 1, + ) -> torch.Tensor: + """ + Forward pass with automatic fallback support and caching. + Delegates to OpManager.call_with_custom_impl for unified dispatch. + """ + if self._manager is None: + return self._forward_impl( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) + + def call_impl_fn(impl_class): + if impl_class == self.__class__: + return self._forward_impl( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) + else: + fallback_instance = impl_class(**self._init_params) + fallback_instance._manager = self._manager + fallback_instance._init_params = self._init_params + return fallback_instance._forward_impl( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, + num_splits=num_splits, + ) + + return self._manager.call_with_custom_impl( + op_name="get_flash_attention_class", + current_impl_class=self.__class__, + call_impl_fn=call_impl_fn, + ) + + @property + def backend_name(self) -> str: + return self.__class__.__name__ + + +############ Base ################### +class TEFLBackendBase(ABC): + @abstractmethod + def is_available(self) -> bool: + raise NotImplementedError + + def get_attention_backend(self, attention_params=None): + raise NotImplementedError + + ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp ##### + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[Any] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + raise NotImplementedError + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + raise NotImplementedError + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + raise NotImplementedError + + # GLU # + def glu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # GELU and variants # + def gelu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def geglu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def qgelu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def qgeglu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # ReLU and variants # + def relu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def reglu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def srelu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def sreglu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # SwiGLU and variants # + def silu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def swiglu( + self, + input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + raise NotImplementedError + + # Backward of GLU # + def dglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # Backward of GELU and variants # + def dgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dgeglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dqgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dqgeglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # Backward of ReLU and variants # + def drelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dreglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dsrelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dsreglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + # Backward of SiLU and variants # + def dsilu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> Any: + raise NotImplementedError + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + raise NotImplementedError + + # DBias + DAct fusions # + def dbias_dgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + def dbias_dsilu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + def dbias_drelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + def dbias_dqgelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + def dbias_dsrelu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + raise NotImplementedError + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + raise NotImplementedError + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + raise NotImplementedError + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + raise NotImplementedError + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + raise NotImplementedError + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + raise NotImplementedError + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + raise NotImplementedError + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + raise NotImplementedError + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + raise NotImplementedError + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + raise NotImplementedError + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + raise NotImplementedError + + def group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + raise NotImplementedError + + def bgrad_group_quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + num_tensors: int, + first_dims: List[int], + ) -> Any: + raise NotImplementedError + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + disable_bulk_allocation: bool = False, + ) -> List[Any]: + raise NotImplementedError + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + raise NotImplementedError + + def te_general_grouped_gemm_for_grouped_tensor( + self, + *args, + **kwargs, + ) -> Optional[List[torch.Tensor]]: + raise NotImplementedError + + def te_general_grouped_gemm_for_discrete_in( + self, + *args, + **kwargs, + ) -> Optional[List[torch.Tensor]]: + raise NotImplementedError + + def te_general_grouped_gemm_for_discrete_out( + self, + *args, + **kwargs, + ) -> Optional[List[torch.Tensor]]: + raise NotImplementedError + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + raise NotImplementedError + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + raise NotImplementedError + + def nvfp4_data_transpose( + self, + input: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def swizzle_scales_for_gemm_( + self, + tensor: torch.Tensor, + ) -> None: + raise NotImplementedError + + def grouped_swizzle_for_gemm( + self, + tensor: Any, + rowwise: bool, + columnwise: bool, + ) -> None: + raise NotImplementedError + + def convert_host_pointers_to_tensor( + self, + tensor_lists: List[List[torch.Tensor]], + ) -> Any: + raise NotImplementedError + + def get_device_pointer_for_data_and_scales( + self, + data_tensors: List[torch.Tensor], + scale_tensors: List[torch.Tensor], + swizzle: bool = False, + rowwise: bool = True, + data_dtype: Any = None, + ) -> Any: + raise NotImplementedError + + def splits_to_offsets( + self, + first_dims: List[int], + logical_last_dim: int, + ) -> torch.Tensor: + raise NotImplementedError + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + cuda_graph: bool = False, + deterministic: bool = False, + ) -> NVTE_Fused_Attn_Backend: + raise NotImplementedError + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + raise NotImplementedError + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + raise NotImplementedError + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + raise NotImplementedError + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + raise NotImplementedError + + # MXFP8 scaling + def mxfp8_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + raise NotImplementedError + + def mxfp8_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + raise NotImplementedError + + # NVFP4 2D + def nvfp4_2d_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int = 16, + ) -> None: + raise NotImplementedError + + def nvfp4_multi_tensor_compute_partial_amax( + self, + master_weight_list: List[torch.Tensor], + partial_amax_list: List[torch.Tensor], + global_amax_list: List[torch.Tensor], + h_list: List[int], + w_list: List[int], + start_offset_list: List[int], + block_len: int = 16, + ) -> None: + raise NotImplementedError + + def nvfp4_compute_global_scale( + self, + global_amaxes: torch.Tensor, + global_scale_tensor: torch.Tensor, + ) -> None: + raise NotImplementedError + + def nvfp4_compute_per_block_scale( + self, + *args, + **kwargs, + ) -> None: + raise NotImplementedError + + def nvfp4_expand_scale_to_fp8( + self, + *args, + **kwargs, + ) -> None: + raise NotImplementedError + + def nvfp4_fused_scale( + self, + *args, + **kwargs, + ) -> None: + raise NotImplementedError + + def nvfp4_multi_tensor_fused_scale( + self, + block_amax_list: List[torch.Tensor], + global_amax_list: List[torch.Tensor], + per_block_scale_list: List[torch.Tensor], + target_scale_list: List[torch.Tensor], + target_amax_list: List[torch.Tensor], + tile_rows_list: List[int], + tile_cols_list: List[int], + rows_padded_list: List[int], + block_len: int, + ) -> None: + raise NotImplementedError + + def nvfp4_2d_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + global_scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int = 16, + ) -> None: + raise NotImplementedError + + def nvfp4_multi_tensor_2d_partial_cast( + self, + inp_list: List[torch.Tensor], + *args, + **kwargs, + ) -> None: + raise NotImplementedError + + def nvfp4_2d_multi_tensor_transpose( + self, + rowwise_data_list: List[torch.Tensor], + columnwise_data_list: List[torch.Tensor], + rowwise_scale_inv_list: List[torch.Tensor], + columnwise_scale_inv_list: List[torch.Tensor], + M_list: List[int], + K_list: List[int], + ) -> None: + raise NotImplementedError + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + raise NotImplementedError + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + raise NotImplementedError + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + cuda_graph: bool = False, + ) -> List[Any]: + raise NotImplementedError + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + bottom_right_diagonal: Optional[bool], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + cuda_graph: bool = False, + ) -> List[Any]: + raise NotImplementedError + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + raise NotImplementedError + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + raise NotImplementedError + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + raise NotImplementedError + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + raise NotImplementedError + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def dropout_bwd( + self, + grad_output: torch.Tensor, + mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + raise NotImplementedError + + # Misc + def get_cublasLt_version(self) -> int: + raise NotImplementedError + + def get_cudnn_version(self) -> int: + raise NotImplementedError + + def get_num_cublas_streams(self) -> int: + raise NotImplementedError + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + raise NotImplementedError + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + raise NotImplementedError + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + raise NotImplementedError + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + raise NotImplementedError + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + raise NotImplementedError + + def thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + raise NotImplementedError + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + raise NotImplementedError + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + raise NotImplementedError + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + raise NotImplementedError + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + raise NotImplementedError + + def nvshmem_finalize(self) -> None: + raise NotImplementedError + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + raise NotImplementedError + + def multi_tensor_scale_tensor( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: torch.Tensor, + ) -> None: + raise NotImplementedError + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + raise NotImplementedError + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + raise NotImplementedError + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + raise NotImplementedError + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + raise NotImplementedError + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + raise NotImplementedError + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + raise NotImplementedError + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + raise NotImplementedError + + def multi_tensor_compute_scale_inv_e8m0( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + block_len: int, + ) -> None: + raise NotImplementedError + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + raise NotImplementedError + + ############## class func ################################# + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + """Create FP8TensorMeta instance.""" + raise NotImplementedError + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + """ + Internal method to create CommOverlapHelper. + Users should use CommOverlapHelper(...) directly. + """ + raise NotImplementedError + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + """ + Internal method to create CommOverlap. + Users should use CommOverlap(...) directly. + """ + raise NotImplementedError + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + """ + Internal method to create CommOverlapP2P. + Users should use CommOverlapP2P(...) directly. + """ + raise NotImplementedError + + def get_flash_attention_class(self) -> Type["FlashAttentionBase"]: + raise NotImplementedError + + +############ Wapper ################# +class TEFLModule: + def __init__(self, manager=None): + """ + Initialize TEFLModule. + + Args: + manager: OpManager instance for operator dispatch. + If None, will use the global default OpManager. + """ + # Import here to avoid circular dependency + from .manager import get_default_manager + + self._manager = manager if manager is not None else get_default_manager() + # emum + self.DType = DType + self.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat + self.FP8FwdTensors = FP8FwdTensors + self.FP8BwdTensors = FP8BwdTensors + self.NVTE_Activation_Type = NVTE_Activation_Type + self.NVTE_Bias_Type = NVTE_Bias_Type + self.NVTE_Mask_Type = NVTE_Mask_Type + self.NVTE_Softmax_Type = NVTE_Softmax_Type + self.NVTE_Fused_Attn_Backend = NVTE_Fused_Attn_Backend + self.NVTE_QKV_Format = NVTE_QKV_Format + self.NVTE_QKV_Layout = NVTE_QKV_Layout + self.CommOverlapType = CommOverlapType + self.CommOverlapAlgo = CommOverlapAlgo + self.CommGemmOverlapRole = CommGemmOverlapRole + # class + self.FP8TensorMeta = FP8TensorMeta + self.CommOverlapHelper = CommOverlapHelper + self.CommOverlap = CommOverlap + self.CommOverlapP2P = CommOverlapP2P + + def __getattr__(self, name: str) -> Any: + """ + Dynamically resolve operators through OpManager. + """ + if name.startswith("_"): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # Verify the operator exists before returning the bound call method + try: + self._manager.ensure_initialized() + available_ops = self._manager.registry.list_operators() + if name not in available_ops: + raise AttributeError( + f"Operator '{name}' not found. Available operators: {available_ops}" + ) + except RuntimeError as e: + # Re-raise as AttributeError for better error messages + raise AttributeError(f"Error accessing operator '{name}': {e}") from e + + # Return a bound call method for this operator + import functools + + return functools.partial(self._manager.call, name) + + def __dir__(self): + module_attrs = [ + "DType", + "Float8BlockScaleTensorFormat", + "FP8FwdTensors", + "FP8BwdTensors", + "FP8TensorMeta", + "NVTE_Activation_Type", + "NVTE_Bias_Type", + "NVTE_Mask_Type", + "NVTE_Softmax_Type", + "NVTE_Fused_Attn_Backend", + "NVTE_QKV_Format", + "NVTE_QKV_Layout", + "CommOverlapType", + "CommOverlapAlgo", + "CommGemmOverlapRole", + "CommOverlapHelper", + "CommOverlap", + "CommOverlapP2P", + ] + + # Add operator names from OpManager's registry + op_attrs = self._manager.registry.list_operators() + + return list(set(module_attrs + op_attrs)) + + def __getitem__(self, key: str): + return self.__getattr__(key) + + @property + def __all__(self): + return self.__dir__() + + def flash_attention( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> "FlashAttentionBase": + """ + Get FlashAttention implementation through OpManager. + """ + # Get the flash attention class getter through OpManager.call + # This provides the same fallback support and logging as other operators + flash_attn_class = self._manager.call("get_flash_attention_class") + + # Prepare initialization parameters + init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + + # Instantiate the FlashAttention + instance = flash_attn_class(**init_params) + + # Set manager and init_params for fallback support + instance._manager = self._manager + instance._init_params = init_params + + return instance + + def __repr__(self) -> str: + op_count = len(self._manager.registry.list_operators()) + return f"TEFLModule(operators={op_count}, manager={self._manager.__class__.__name__})" + + +# Global singleton instance +_global_tefl_module: Optional[TEFLModule] = None +_tefl_module_lock = None + + +def get_tefl_module() -> TEFLModule: + """ + Get or create the global TEFLModule instance. + + This function returns a singleton TEFLModule that uses the default OpManager. + The instance is created lazily on first access. + + Returns: + The global TEFLModule instance + + Example: + >>> import core as te_fl + >>> # Or explicitly: + >>> from core.base import get_tefl_module + >>> te_fl = get_tefl_module() + >>> result = te_fl.rmsnorm_fwd(input, weight, eps=1e-5) + """ + global _global_tefl_module, _tefl_module_lock + + if _global_tefl_module is None: + # Import here to avoid issues at module load time + import threading + + if _tefl_module_lock is None: + _tefl_module_lock = threading.RLock() + + with _tefl_module_lock: + if _global_tefl_module is None: + _global_tefl_module = TEFLModule() + + return _global_tefl_module + + +def reset_tefl_module() -> None: + """ + Reset the global TEFLModule instance. + + This is primarily useful for testing. After calling this function, + the next call to get_tefl_module() will create a fresh instance. + + Warning: + This function is not thread-safe and should only be used in + single-threaded test environments. + """ + global _global_tefl_module, _tefl_module_lock + + if _tefl_module_lock is None: + import threading + + _tefl_module_lock = threading.RLock() + + with _tefl_module_lock: + _global_tefl_module = None + + +# Backward compatibility functions +def get_registry(): + """ + Get the global OpRegistry instance (via OpManager). + + DEPRECATED: Use get_default_manager().registry instead. + + This function is kept for backward compatibility with code that + expects the old API. + + Returns: + The OpRegistry instance from the default OpManager + + Example: + >>> from core.base import get_registry + >>> registry = get_registry() + >>> ops = registry.list_operators() + """ + from .manager import get_default_manager + + return get_default_manager().registry + + +def get_manager(): + """ + Get the global OpManager instance. + + This is the recommended way to access the OpManager. + + Returns: + The default OpManager instance + + Example: + >>> from core.base import get_manager + >>> manager = get_manager() + >>> impl_fn = manager.resolve("rmsnorm_fwd") + """ + from .manager import get_default_manager + + return get_default_manager() + + +def reset_registry() -> None: + """ + Reset the global OpManager and OpRegistry. + + DEPRECATED: Use reset_default_manager() instead. + + This function is kept for backward compatibility. + """ + from .manager import reset_default_manager + + reset_default_manager() + # Also reset the TEFLModule singleton since it depends on OpManager + reset_tefl_module() diff --git a/transformer_engine/plugin/core/policy.py b/transformer_engine/plugin/core/policy.py new file mode 100644 index 0000000000..ce1ac9d7e0 --- /dev/null +++ b/transformer_engine/plugin/core/policy.py @@ -0,0 +1,397 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import contextvars +import os +import threading +from dataclasses import dataclass, field +from typing import Dict, FrozenSet, List, Optional, Set, Tuple + +from .types import BackendImplKind + + +# Valid preference values for TE_FL_PREFER +PREFER_DEFAULT = "flagos" +PREFER_VENDOR = "vendor" +PREFER_REFERENCE = "reference" + +VALID_PREFER_VALUES = frozenset({PREFER_DEFAULT, PREFER_VENDOR, PREFER_REFERENCE}) + + +@dataclass(frozen=True) +class SelectionPolicy: + """ + Policy for selecting operator implementations. + + Attributes: + prefer: Which implementation kind to prefer. One of: + - "flagos": Prefer DEFAULT (FlagOS) implementations + - "vendor": Prefer VENDOR (CUDA) implementations + - "reference": Prefer REFERENCE (PyTorch) implementations + strict: If True, raise error when primary implementation fails + per_op_order: Per-operator custom selection order + deny_vendors: Set of vendor names to deny + allow_vendors: Set of vendor names to allow (whitelist) + """ + + prefer: str = PREFER_DEFAULT + strict: bool = False + per_op_order: Tuple[Tuple[str, Tuple[str, ...]], ...] = field(default_factory=tuple) + + deny_vendors: FrozenSet[str] = field(default_factory=frozenset) + allow_vendors: Optional[FrozenSet[str]] = None + + def __post_init__(self): + if self.prefer not in VALID_PREFER_VALUES: + raise ValueError( + f"Invalid prefer value: '{self.prefer}'. " + f"Must be one of: {', '.join(sorted(VALID_PREFER_VALUES))}" + ) + + @classmethod + def from_dict( + cls, + prefer: str = PREFER_DEFAULT, + strict: bool = False, + per_op_order: Optional[Dict[str, List[str]]] = None, + deny_vendors: Optional[Set[str]] = None, + allow_vendors: Optional[Set[str]] = None, + ) -> "SelectionPolicy": + per_op_tuple = tuple() + if per_op_order: + per_op_tuple = tuple((k, tuple(v)) for k, v in sorted(per_op_order.items())) + + return cls( + prefer=prefer.lower(), + strict=strict, + per_op_order=per_op_tuple, + deny_vendors=frozenset(deny_vendors) if deny_vendors else frozenset(), + allow_vendors=frozenset(allow_vendors) if allow_vendors else None, + ) + + @property + def per_op_order_dict(self) -> Dict[str, List[str]]: + """Get per_op_order as a mutable dict for easier access""" + return {k: list(v) for k, v in self.per_op_order} + + def get_per_op_order(self, op_name: str) -> Optional[List[str]]: + """Get order for a specific operator""" + for name, order in self.per_op_order: + if name == op_name: + return list(order) + return None + + def get_default_order(self) -> List[str]: + """Get the default selection order based on preference setting.""" + if self.prefer == PREFER_REFERENCE: + return ["reference", "flagos", "vendor"] + elif self.prefer == PREFER_VENDOR: + return ["vendor", "flagos", "reference"] + else: # PREFER_DEFAULT + return ["flagos", "vendor", "reference"] + + def is_vendor_allowed(self, vendor_name: str) -> bool: + if vendor_name in self.deny_vendors: + return False + if self.allow_vendors is not None and vendor_name not in self.allow_vendors: + return False + return True + + def fingerprint(self) -> str: + parts = [ + f"prefer={self.prefer}", + f"st={int(self.strict)}", + ] + + if self.allow_vendors: + parts.append(f"allow={','.join(sorted(self.allow_vendors))}") + + if self.deny_vendors: + parts.append(f"deny={','.join(sorted(self.deny_vendors))}") + + if self.per_op_order: + per_op_str = ";".join(f"{k}={'|'.join(v)}" for k, v in self.per_op_order) + parts.append(f"per={per_op_str}") + + return ";".join(parts) + + def __hash__(self) -> int: + return hash( + ( + self.prefer, + self.strict, + self.per_op_order, + self.deny_vendors, + self.allow_vendors, + ) + ) + + +class PolicyManager: + _instance = None + _lock = threading.Lock() + + def __init__(self): + if hasattr(self, "_policy_epoch"): + return + + self._policy_epoch = 0 + self._policy_epoch_lock = threading.Lock() + self._global_policy = None + self._global_policy_lock = threading.Lock() + + self._policy_var = contextvars.ContextVar( + "te_fl_selection_policy", + default=None, + ) + + @classmethod + def get_instance(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.__init__() + return cls._instance + + def get_policy_epoch(self) -> int: + return self._policy_epoch + + def bump_policy_epoch(self) -> int: + with self._policy_epoch_lock: + self._policy_epoch += 1 + return self._policy_epoch + + def get_policy(self) -> SelectionPolicy: + ctx_policy = self._policy_var.get() + if ctx_policy is not None: + return ctx_policy + + if self._global_policy is None: + with self._global_policy_lock: + if self._global_policy is None: + self._global_policy = self._policy_from_env() + return self._global_policy + + def set_global_policy(self, policy: SelectionPolicy) -> SelectionPolicy: + with self._global_policy_lock: + old_policy = self._global_policy + self._global_policy = policy + self.bump_policy_epoch() + return old_policy if old_policy else self._policy_from_env() + + def reset_global_policy(self) -> None: + with self._global_policy_lock: + self._global_policy = None + self.bump_policy_epoch() + + def create_policy_context(self, policy: SelectionPolicy): + return _PolicyContext(self, policy) + + def _get_policy_var(self): + return self._policy_var + + @staticmethod + def _parse_csv_set(value: str) -> Set[str]: + if not value: + return set() + return {x.strip() for x in value.split(",") if x.strip()} + + @staticmethod + def _parse_per_op(value: str) -> Dict[str, List[str]]: + if not value: + return {} + + result: Dict[str, List[str]] = {} + parts = [p.strip() for p in value.split(";") if p.strip()] + + for part in parts: + if "=" not in part: + continue + op_name, order_str = part.split("=", 1) + op_name = op_name.strip() + order = [x.strip() for x in order_str.split("|") if x.strip()] + if op_name and order: + result[op_name] = order + + return result + + def _policy_from_env(self) -> SelectionPolicy: + # Priority: TE_FL_PREFER (highest) > TE_FL_PREFER_VENDOR (legacy) + # + # TE_FL_PREFER: Explicit preference by name (flagos, vendor, reference) + # TE_FL_PREFER_VENDOR: Legacy boolean flag (1=vendor, 0=flagos) + + prefer_str = None + + # 1. Check TE_FL_PREFER first (highest priority) + te_fl_prefer = os.environ.get("TE_FL_PREFER", "").strip().lower() + if te_fl_prefer: + if te_fl_prefer in VALID_PREFER_VALUES: + prefer_str = te_fl_prefer + else: + print( + f"[WARNING] Invalid TE_FL_PREFER value: '{te_fl_prefer}'. " + f"Valid values: {', '.join(sorted(VALID_PREFER_VALUES))}" + ) + + # 2. Fall back to TE_FL_PREFER_VENDOR (legacy) + if prefer_str is None: + prefer_vendor = os.environ.get("TE_FL_PREFER_VENDOR", "").strip() + if prefer_vendor == "1": + prefer_str = PREFER_VENDOR + elif prefer_vendor == "0": + prefer_str = PREFER_DEFAULT + else: + # Default behavior: prefer default (FlagOS) + prefer_str = PREFER_DEFAULT + + strict = os.environ.get("TE_FL_STRICT", "0").strip() == "1" + + deny_str = os.environ.get("TE_FL_DENY_VENDORS", "").strip() + deny_vendors = self._parse_csv_set(deny_str) if deny_str else None + + allow_str = os.environ.get("TE_FL_ALLOW_VENDORS", "").strip() + allow_vendors = self._parse_csv_set(allow_str) if allow_str else None + + per_op_str = os.environ.get("TE_FL_PER_OP", "").strip() + per_op_order = self._parse_per_op(per_op_str) if per_op_str else None + + return SelectionPolicy.from_dict( + prefer=prefer_str, + strict=strict, + per_op_order=per_op_order, + deny_vendors=deny_vendors, + allow_vendors=allow_vendors, + ) + + +class _PolicyContext: + + def __init__(self, manager: PolicyManager, policy: SelectionPolicy): + self._manager = manager + self._policy = policy + self._token: Optional[contextvars.Token] = None + + def __enter__(self) -> "_PolicyContext": + policy_var = self._manager._get_policy_var() + self._token = policy_var.set(self._policy) + self._manager.bump_policy_epoch() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self._token is not None: + policy_var = self._manager._get_policy_var() + policy_var.reset(self._token) + self._manager.bump_policy_epoch() + + +# Convenience functions for easier access +def get_policy_epoch() -> int: + """Get the current policy epoch""" + return PolicyManager.get_instance().get_policy_epoch() + + +def bump_policy_epoch() -> int: + """Bump the policy epoch and return the new value""" + return PolicyManager.get_instance().bump_policy_epoch() + + +def get_policy() -> SelectionPolicy: + """Get the current effective policy (context or global)""" + return PolicyManager.get_instance().get_policy() + + +def set_global_policy(policy: SelectionPolicy) -> SelectionPolicy: + """Set the global policy and return the old policy""" + return PolicyManager.get_instance().set_global_policy(policy) + + +def reset_global_policy() -> None: + """Reset the global policy to environment defaults""" + PolicyManager.get_instance().reset_global_policy() + + +def policy_from_env() -> SelectionPolicy: + """Create a SelectionPolicy from environment variables""" + return PolicyManager.get_instance()._policy_from_env() + + +def policy_context(policy: SelectionPolicy) -> _PolicyContext: + """ + Create a context manager to temporarily override the policy. + + Example: + >>> with policy_context(my_policy): + ... # Use my_policy in this context + ... result = manager.resolve("op_name") + """ + return _PolicyContext(PolicyManager.get_instance(), policy) + + +# Convenience context managers +def with_strict_mode() -> _PolicyContext: + """Context manager to enable strict mode""" + current = get_policy() + strict_policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=True, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(strict_policy) + + +def with_preference(prefer: str) -> _PolicyContext: + """ + Context manager to set implementation preference. + + Args: + prefer: One of "flagos", "vendor", or "reference" + + Example: + >>> with with_preference("vendor"): + ... # Prefer vendor implementations in this context + ... result = manager.resolve("op_name") + """ + current = get_policy() + policy = SelectionPolicy.from_dict( + prefer=prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(policy) + + +def with_allowed_vendors(*vendors: str) -> _PolicyContext: + """Context manager to set allowed vendors whitelist""" + current = get_policy() + policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(vendors), + ) + return policy_context(policy) + + +def with_denied_vendors(*vendors: str) -> _PolicyContext: + """Context manager to add denied vendors to blacklist""" + current = get_policy() + denied = set(current.deny_vendors) + denied.update(vendors) + policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=denied, + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(policy) diff --git a/transformer_engine/plugin/core/registry.py b/transformer_engine/plugin/core/registry.py new file mode 100644 index 0000000000..1a4099936d --- /dev/null +++ b/transformer_engine/plugin/core/registry.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import threading +from dataclasses import dataclass +from typing import Dict, List, Sequence + +from .types import OpImpl + + +@dataclass +class OpRegistrySnapshot: + """Immutable snapshot of operator registry state""" + + impls_by_op: Dict[str, List[OpImpl]] + + +class OpRegistry: + """ + Thread-safe registry for operator implementations. + + This registry stores operator implementations indexed by op_name and impl_id. + Each operator can have multiple implementations from different backends/vendors. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + # Structure: {op_name: {impl_id: OpImpl}} + self._impls_by_op: Dict[str, Dict[str, OpImpl]] = {} + + def register_impl(self, impl: OpImpl) -> None: + """ + Register a single operator implementation. + + Args: + impl: OpImpl instance to register + + Raises: + ValueError: If impl_id is already registered for this op_name + """ + with self._lock: + by_id = self._impls_by_op.setdefault(impl.op_name, {}) + if impl.impl_id in by_id: + raise ValueError( + f"Duplicate impl_id '{impl.impl_id}' for op='{impl.op_name}'. " + f"Existing: {by_id[impl.impl_id]}, New: {impl}" + ) + by_id[impl.impl_id] = impl + + def register_many(self, impls: Sequence[OpImpl]) -> None: + """ + Register multiple operator implementations. + + Args: + impls: Sequence of OpImpl instances to register + """ + for impl in impls: + self.register_impl(impl) + + def snapshot(self) -> OpRegistrySnapshot: + """ + Create an immutable snapshot of current registry state. + + Returns: + OpRegistrySnapshot with all registered implementations + """ + with self._lock: + impls_by_op = {op: list(by_id.values()) for op, by_id in self._impls_by_op.items()} + return OpRegistrySnapshot(impls_by_op=impls_by_op) + + def get_implementations(self, op_name: str) -> List[OpImpl]: + """ + Get all implementations for a specific operator. + + Args: + op_name: Name of the operator + + Returns: + List of OpImpl for the operator (empty if not found) + """ + with self._lock: + by_id = self._impls_by_op.get(op_name, {}) + return list(by_id.values()) + + def get_implementation(self, op_name: str, impl_id: str) -> OpImpl | None: + """ + Get a specific implementation by op_name and impl_id. + + Args: + op_name: Name of the operator + impl_id: Implementation ID + + Returns: + OpImpl if found, None otherwise + """ + with self._lock: + by_id = self._impls_by_op.get(op_name, {}) + return by_id.get(impl_id) + + def list_operators(self) -> List[str]: + """ + List all registered operator names. + + Returns: + List of operator names + """ + with self._lock: + return list(self._impls_by_op.keys()) + + def clear(self) -> None: + """Clear all registered implementations""" + with self._lock: + self._impls_by_op.clear() diff --git a/transformer_engine/plugin/core/types.py b/transformer_engine/plugin/core/types.py new file mode 100644 index 0000000000..e5508320f2 --- /dev/null +++ b/transformer_engine/plugin/core/types.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Optional, Set + + +class BackendImplKind(str, Enum): + DEFAULT = "flagos" + REFERENCE = "reference" + VENDOR = "vendor" + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True) +class OpImpl: + op_name: str + impl_id: str + kind: BackendImplKind + fn: Callable[..., Any] + vendor: Optional[str] = None + priority: int = 0 + supported_dtypes: Optional[Set[str]] = None + min_arch: Optional[str] = None + + def __post_init__(self): + if self.kind == BackendImplKind.VENDOR and not self.vendor: + raise ValueError(f"OpImpl with kind=VENDOR must specify vendor name: {self.impl_id}") + + def is_available(self) -> bool: + avail_fn = getattr(self.fn, "_is_available", None) + if callable(avail_fn): + try: + return bool(avail_fn()) + except Exception: + return False + return True + + +TOKEN_PATTERNS = { + "flagos": lambda impl: impl.kind == BackendImplKind.DEFAULT, + "reference": lambda impl: impl.kind == BackendImplKind.REFERENCE, + "vendor": lambda impl: impl.kind == BackendImplKind.VENDOR, +} + + +def match_token(impl: OpImpl, token: str) -> bool: + if token in TOKEN_PATTERNS: + return TOKEN_PATTERNS[token](impl) + + if token.startswith("vendor:"): + vendor_name = token.split(":", 1)[1] + return impl.kind == BackendImplKind.VENDOR and impl.vendor == vendor_name + + if token.startswith("impl:"): + impl_id = token.split(":", 1)[1] + return impl.impl_id == impl_id + + return False diff --git a/transformer_engine/plugin/examples/README.md b/transformer_engine/plugin/examples/README.md new file mode 100644 index 0000000000..318de59487 --- /dev/null +++ b/transformer_engine/plugin/examples/README.md @@ -0,0 +1,181 @@ +# TE-FL Custom Backend Examples + +This directory contains examples demonstrating two ways to add custom backends. + +## Two Approaches + +| Approach | Use Case | Example File | +|----------|----------|--------------| +| **In-tree** | Open source contribution, direct integration | `example_intree.py` | +| **Out-of-tree** | Closed-source / third-party plugin, standalone package | `example_outtree.py` | + +## Quick Start + +```bash +cd transformer_engine/plugin/examples + +# In-tree approach +python example_intree.py + +# Out-of-tree approach +python example_outtree.py +``` + +## In-tree Approach (3 Steps) + +```python +from transformer_engine.plugin.core import ( + OpRegistry, OpManager, OpImpl, BackendImplKind +) + +# 1. Define your operator implementation +def my_rmsnorm(input, weight, eps=1e-5, **kwargs): + variance = input.pow(2).mean(-1, keepdim=True) + return input * torch.rsqrt(variance + eps) * weight, torch.rsqrt(variance + eps) + +# 2. Register to Registry +registry = OpRegistry() +registry.register_impl(OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.mybackend", + kind=BackendImplKind.VENDOR, + vendor="mybackend", + fn=my_rmsnorm, + priority=200, +)) + +# 3. Call via Manager +manager = OpManager(registry) +output, rsigma = manager.call("rmsnorm_fwd", input, weight) +``` + +## Out-of-tree Approach (Plugin Package) + +### Plugin Package Structure + +``` +my_vendor_plugin/ +├── __init__.py # Contains register(registry) function +└── setup.py # or pyproject.toml +``` + +### \_\_init\_\_.py + +```python +from transformer_engine.plugin.core import OpImpl, BackendImplKind + +def my_rmsnorm(input, weight, eps=1e-5, **kwargs): + # Your implementation + ... + +def register(registry): + """Called automatically by TE-FL""" + registry.register_impl(OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.myvendor", + kind=BackendImplKind.VENDOR, + vendor="myvendor", + fn=my_rmsnorm, + priority=200, + )) +``` + +### Loading Methods + +```bash +# Method 1: Environment variable +export TE_FL_PLUGIN_MODULES=my_vendor_plugin +python your_script.py + +# Method 2: pip install (requires entry_points configuration) +pip install my-vendor-plugin +python your_script.py +``` + +## Environment Variables + +### Backend Selection + +| Variable | Description | Values | Default | +|----------|-------------|--------|---------| +| `TE_FL_PREFER` | Preferred backend type (highest priority) | `flagos` / `vendor` / `reference` | `flagos` | +| `TE_FL_PREFER_VENDOR` | Prefer vendor backend (legacy, lower priority than `TE_FL_PREFER`) | `1` = prefer vendor, `0` = prefer flagos | `0` | +| `TE_FL_STRICT` | Strict mode - raise error if preferred implementation fails instead of fallback | `1` = strict, `0` = allow fallback | `0` | + +### Vendor Filtering + +| Variable | Description | Example | +|----------|-------------|---------| +| `TE_FL_ALLOW_VENDORS` | Whitelist of allowed vendors (comma-separated) | `nvidia,amd` | +| `TE_FL_DENY_VENDORS` | Blacklist of denied vendors (comma-separated) | `vendor_a,vendor_b` | + +### Per-Operator Configuration + +| Variable | Description | Example | +|----------|-------------|---------| +| `TE_FL_PER_OP` | Per-operator backend ordering | `rmsnorm_fwd=vendor:acme\|flagos;rope_fwd=flagos\|reference` | + +Format: `op_name=backend1|backend2;op_name2=backend3|backend4` + +### Plugin Discovery + +| Variable | Description | Example | +|----------|-------------|---------| +| `TE_FL_PLUGIN_MODULES` | Plugin modules to load (comma-separated) | `my_plugin,another_plugin` | + +### Build Configuration + +| Variable | Description | Values | Default | +|----------|-------------|--------|---------| +| `TE_FL_SKIP_CUDA` | Skip CUDA backend (both build-time and runtime) | `1` = skip, `0` = enable | `0` | +| `CUDA_HOME` | CUDA installation path | `/usr/local/cuda` | Auto-detected | +| `CUDA_PATH` | Alternative CUDA path variable | `/usr/local/cuda` | Auto-detected | + +### Logging + +| Variable | Description | Values | Default | +|----------|-------------|--------|---------| +| `TEFL_LOG_LEVEL` | Log level for TE-FL | `DEBUG` / `INFO` / `WARNING` / `ERROR` | `INFO` | + +## Examples + +### Prefer vendor backend +```bash +export TE_FL_PREFER=vendor +python your_script.py +``` + +### Only allow specific vendors +```bash +export TE_FL_ALLOW_VENDORS=nvidia,acme +python your_script.py +``` + +### Custom per-operator ordering +```bash +# Use acme vendor for rmsnorm, flagos for others +export TE_FL_PER_OP="rmsnorm_fwd=vendor:acme|flagos" +python your_script.py +``` + +### Skip CUDA and use FlagOS only +```bash +export TE_FL_SKIP_CUDA=1 +export TE_FL_PREFER=flagos +python your_script.py +``` + +### Enable debug logging +```bash +export TEFL_LOG_LEVEL=DEBUG +python your_script.py +``` + +## Expected Output + +When running, you should see logs like: + +``` +[TE-FL manager.py:133 INFO] Registered impl_ids: ['default.flagos', 'reference.torch', 'vendor.mybackend'] +[TE-FL manager.py:390 INFO] Op 'rmsnorm_fwd' using 'vendor.mybackend' (kind=vendor, vendor=mybackend) +``` diff --git a/transformer_engine/plugin/examples/example_intree.py b/transformer_engine/plugin/examples/example_intree.py new file mode 100644 index 0000000000..c4badb0ccc --- /dev/null +++ b/transformer_engine/plugin/examples/example_intree.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Example: In-tree Backend Registration + +Use case: Add implementation directly to the codebase (open source contribution) + +Run: + python example_intree.py +""" + +import torch +from transformer_engine.plugin.core import ( + OpRegistry, + OpManager, + OpImpl, + BackendImplKind, + SelectionPolicy, + set_global_policy, +) + + +# ============================================================ +# Step 1: Define your operator implementation +# ============================================================ +def my_rmsnorm_fwd(input, weight, eps=1e-5, **kwargs): + """Custom RMSNorm implementation""" + print(" >>> [MyBackend] my_rmsnorm_fwd called!") + variance = input.pow(2).mean(-1, keepdim=True) + output = input * torch.rsqrt(variance + eps) * weight + rsigma = torch.rsqrt(variance + eps) + return output, rsigma + + +# Optional: Define availability check function +my_rmsnorm_fwd._is_available = lambda: True + + +# ============================================================ +# Step 2: Register to Registry +# ============================================================ +registry = OpRegistry() + +registry.register_impl( + OpImpl( + op_name="rmsnorm_fwd", # Operator name + impl_id="vendor.mybackend", # Implementation ID (unique identifier) + kind=BackendImplKind.VENDOR, # Type: VENDOR / DEFAULT / REFERENCE + vendor="mybackend", # Vendor name + fn=my_rmsnorm_fwd, # Implementation function + priority=200, # Priority (higher = preferred) + ) +) + + +# ============================================================ +# Step 3: Create Manager and call operator +# ============================================================ +manager = OpManager(registry) + +# Set policy: prefer vendor backend +set_global_policy(SelectionPolicy(prefer="vendor")) + +# Prepare test data +input_tensor = torch.randn(2, 4, 8) +weight = torch.ones(8) + +# Call operator - will automatically select highest priority implementation +print("\nCalling rmsnorm_fwd:") +output, rsigma = manager.call("rmsnorm_fwd", input_tensor, weight, eps=1e-5) + +print(f"\nInput shape: {input_tensor.shape}") +print(f"Output shape: {output.shape}") +print("\nSuccess! Your custom backend was used.") diff --git a/transformer_engine/plugin/examples/example_outtree.py b/transformer_engine/plugin/examples/example_outtree.py new file mode 100644 index 0000000000..e85339307f --- /dev/null +++ b/transformer_engine/plugin/examples/example_outtree.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Example: Out-of-tree Backend Registration + +Use case: Standalone plugin package (closed-source / third-party) + +Run: + # Method 1: Load plugin module via environment variable + TE_FL_PLUGIN_MODULES=my_vendor_plugin python example_outtree.py + + # Method 2: Install plugin package with entry_points via pip + pip install my-vendor-plugin + python example_outtree.py +""" + +import sys +import types +import torch + + +# ============================================================ +# Step 1: Create plugin module (simulates a pip-installed package) +# ============================================================ +def create_plugin_module(): + """ + Simulate a standalone plugin module. + + In practice, this code would be in a separate pip package, e.g.: + - my_vendor_plugin/__init__.py + """ + + # Create module + plugin_module = types.ModuleType("my_vendor_plugin") + + # Define operator implementation + def my_rmsnorm_fwd(input, weight, eps=1e-5, **kwargs): + """Custom RMSNorm implementation""" + print(" >>> [MyVendorPlugin] my_rmsnorm_fwd called!") + variance = input.pow(2).mean(-1, keepdim=True) + output = input * torch.rsqrt(variance + eps) * weight + rsigma = torch.rsqrt(variance + eps) + return output, rsigma + + my_rmsnorm_fwd._is_available = lambda: True + + # Define register function (must have 'register' or 'te_fl_register' function) + def register(registry): + """ + Plugin registration function - called automatically by TE-FL. + + Args: + registry: OpRegistry instance + """ + from transformer_engine.plugin.core import ( + OpImpl, + BackendImplKind, + ) + + print("[MyVendorPlugin] Registering operator implementations...") + + registry.register_impl( + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.myvendor", + kind=BackendImplKind.VENDOR, + vendor="myvendor", + fn=my_rmsnorm_fwd, + priority=200, + ) + ) + + print("[MyVendorPlugin] Registration complete!") + + # Add register function to module + plugin_module.register = register + + return plugin_module + + +# ============================================================ +# Step 2: Register plugin module to sys.modules (simulates pip install) +# ============================================================ +plugin = create_plugin_module() +sys.modules["my_vendor_plugin"] = plugin + + +# ============================================================ +# Step 3: Set environment variables for TE-FL auto-discovery +# ============================================================ +import os + +os.environ["TE_FL_PLUGIN_MODULES"] = "my_vendor_plugin" +os.environ["TE_FL_PREFER"] = "vendor" # Prefer vendor backend + + +# ============================================================ +# Step 4: Import TE-FL (will auto-discover and load plugin) +# ============================================================ +from transformer_engine.plugin.core import ( + get_manager, + reset_default_manager, +) + +# Reset manager to trigger plugin discovery +reset_default_manager() +manager = get_manager() + + +# ============================================================ +# Step 5: Call operator +# ============================================================ +input_tensor = torch.randn(2, 4, 8) +weight = torch.ones(8) + +print("\nCalling rmsnorm_fwd:") +output, rsigma = manager.call("rmsnorm_fwd", input_tensor, weight, eps=1e-5) + +print(f"\nInput shape: {input_tensor.shape}") +print(f"Output shape: {output.shape}") +print("\nSuccess! Your out-of-tree plugin was loaded and used.") diff --git a/transformer_engine/plugin/test_utils.py b/transformer_engine/plugin/test_utils.py new file mode 100644 index 0000000000..c1462c84d2 --- /dev/null +++ b/transformer_engine/plugin/test_utils.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import torch +import numpy as np +from typing import List, Dict, Callable, Any, Optional + + +def get_available_backends() -> List[str]: + """ + Get list of available backends by extracting unique impl_ids from OpRegistry. + + Returns impl_id prefixes (e.g., "default.flagos" -> "flagos") + """ + try: + from transformer_engine.plugin.core import get_registry + + registry = get_registry() + all_impls = [] + for op_name in registry.list_operators(): + all_impls.extend(registry.get_implementations(op_name)) + + # Extract unique impl_id prefixes (e.g., "default.flagos" -> "flagos") + impl_ids = set() + for impl in all_impls: + # impl_id format: "kind.name" (e.g., "default.flagos", "vendor.cuda") + parts = impl.impl_id.split(".", 1) + if len(parts) == 2: + impl_ids.add(parts[1]) # Get the "name" part + else: + impl_ids.add(impl.impl_id) + + return sorted(impl_ids) + except Exception as e: + print(f"Warning: Could not load backends: {e}") + import traceback + + traceback.print_exc() + return [] + + +def get_backend(name: str): + """ + Get a backend-like object that dispatches to a specific implementation. + + Args: + name: Backend name (e.g., "cuda", "flagos", "torch") + + Returns: + A wrapper object that calls the specific backend implementation + """ + from transformer_engine.plugin.core import get_registry + from transformer_engine.plugin.core.logger_manager import get_logger + import functools + + logger = get_logger() + + class BackendWrapper: + """Wrapper that calls specific backend implementations""" + + def __init__(self, backend_name: str): + self.backend_name = backend_name + self.registry = get_registry() + self._called_ops = set() # Track which ops have been called (for logging) + + def _find_impl(self, op_name: str): + """Find implementation matching the backend name""" + impls = self.registry.get_implementations(op_name) + + # Try to find implementation matching backend_name + # Match against impl_id suffix (e.g., "vendor.cuda" matches "cuda") + for impl in impls: + if ( + impl.impl_id.endswith(f".{self.backend_name}") + or impl.impl_id == self.backend_name + ): + if impl.is_available(): + return impl + else: + raise RuntimeError( + f"Implementation '{impl.impl_id}' for op '{op_name}' is not available" + ) + + raise NotImplementedError( + f"No implementation found for op '{op_name}' with backend '{self.backend_name}'" + ) + + def __getattr__(self, op_name: str): + """Dynamically resolve operator to specific backend implementation""" + impl = self._find_impl(op_name) + + # Log on first call to this op for this backend + if op_name not in self._called_ops: + self._called_ops.add(op_name) + logger.info( + f"[Test] Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + + return impl.fn + + return BackendWrapper(name) + + +def allclose(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-5, atol: float = 1e-8) -> bool: + return torch.allclose(a, b, rtol=rtol, atol=atol) + + +def compute_relative_error(output: torch.Tensor, reference: torch.Tensor) -> float: + diff = (output - reference).abs() + relative_error = (diff / (reference.abs() + 1e-10)).mean().item() + return relative_error + + +def compute_max_error(output: torch.Tensor, reference: torch.Tensor) -> float: + return (output - reference).abs().max().item() + + +class TestCase: + def __init__(self, name: str, description: str = ""): + self.name = name + self.description = description + self.passed = 0 + self.failed = 0 + self.skipped = 0 + self.errors: List[str] = [] + + def setup(self): + pass + + def teardown(self): + pass + + def assert_close( + self, + output: torch.Tensor, + reference: torch.Tensor, + rtol: float = 1e-5, + atol: float = 1e-8, + msg: str = "", + ): + if not allclose(output, reference, rtol, atol): + max_err = compute_max_error(output, reference) + rel_err = compute_relative_error(output, reference) + error_msg = f"{msg}\n Max error: {max_err:.6e}, Relative error: {rel_err:.6e}" + self.errors.append(error_msg) + self.failed += 1 + raise AssertionError(error_msg) + self.passed += 1 + + def report(self): + total = self.passed + self.failed + self.skipped + print(f"\n{'='*60}") + print(f"Test: {self.name}") + if self.description: + print(f"Description: {self.description}") + print(f"{'='*60}") + print( + f"Total: {total}, Passed: {self.passed}, Failed: {self.failed}, Skipped: {self.skipped}" + ) + if self.errors: + print(f"\nErrors:") + for i, error in enumerate(self.errors, 1): + print(f" {i}. {error}") + print(f"{'='*60}") + return self.failed == 0 + + +def generate_random_tensor( + shape: tuple, + dtype: torch.dtype = torch.float32, + device: str = "cpu", + requires_grad: bool = False, +) -> torch.Tensor: + if dtype in (torch.bfloat16, torch.float16): + tensor = torch.randn(shape, dtype=torch.float32, device=device) + tensor = tensor.to(dtype=dtype) + if requires_grad: + tensor.requires_grad_(True) + else: + tensor = torch.randn(shape, dtype=dtype, device=device, requires_grad=requires_grad) + return tensor + + +def generate_test_shapes() -> List[tuple]: + return [ + (2, 4), + (8, 16), + (32, 64), + (2, 4, 8), + (4, 8, 16), + (2, 4, 8, 16), + ] + + +def run_test_on_backends( + test_func: Callable, + backends: Optional[List[str]] = None, + reference_backend: str = "reference", +) -> Dict[str, bool]: + if backends is None: + backends = get_available_backends() + + results = {} + for backend_name in backends: + try: + test_func(backend_name) + results[backend_name] = True + print(f" ✓ {backend_name}") + except Exception as e: + results[backend_name] = False + print(f" ✗ {backend_name}: {e}") + + return results + + +def skip_if_backend_unavailable(backend_name: str) -> bool: + available = get_available_backends() + return backend_name not in available diff --git a/transformer_engine/plugin/tests/__init__.py b/transformer_engine/plugin/tests/__init__.py new file mode 100644 index 0000000000..caaec47482 --- /dev/null +++ b/transformer_engine/plugin/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +__all__ = [] diff --git a/transformer_engine/plugin/tests/run_all_tests.py b/transformer_engine/plugin/tests/run_all_tests.py new file mode 100644 index 0000000000..bfc2dee59d --- /dev/null +++ b/transformer_engine/plugin/tests/run_all_tests.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +import sys +import torch + +from test_activations import ActivationTests +from test_normalization import NormalizationTests +from test_operations import OperationsTests +from test_softmax import SoftmaxTests +from test_optimizer import OptimizerTests +from test_flash_attention import FlashAttentionTests + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("\n" + "=" * 70) + print(" " * 15 + "TEX Interface Backend Tests") + print("=" * 70) + print(f"Using device: {device}\n") + + test_suites = [ + ActivationTests(device=device), + NormalizationTests(device=device), + OperationsTests(device=device), + SoftmaxTests(device=device), + OptimizerTests(device=device), + FlashAttentionTests(device=device), + ] + + results = [] + for suite in test_suites: + success = suite.run_all_tests() + results.append((suite.name, success)) + + print("\n" + "=" * 70) + print(" " * 25 + "Test Summary") + print("=" * 70) + + total_passed = sum(1 for _, success in results if success) + total_tests = len(results) + + for name, success in results: + status = "✓ PASSED" if success else "✗ FAILED" + print(f" {name:40s} {status}") + + print("=" * 70) + print(f"Total: {total_passed}/{total_tests} test suites passed") + print("=" * 70) + + return 0 if all(success for _, success in results) else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_activations.py b/transformer_engine/plugin/tests/test_activations.py new file mode 100644 index 0000000000..e73851ac50 --- /dev/null +++ b/transformer_engine/plugin/tests/test_activations.py @@ -0,0 +1,642 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +import torch +import torch.nn.functional as F +import sys + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, + generate_test_shapes, +) + + +class ActivationTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Activation Functions", "Test correctness of all activation functions across backends" + ) + self.backends = get_available_backends() + self.reference_backend = "reference" + self.device = device + + # ==================== Reference implementations ==================== + def _get_reference_gelu(self, x): + return F.gelu(x, approximate="tanh") + + def _get_reference_geglu(self, x): + a, b = x.chunk(2, dim=-1) + return F.gelu(a, approximate="tanh") * b + + def _get_reference_qgelu(self, x): + return x * torch.sigmoid(1.702 * x) + + def _get_reference_qgeglu(self, x): + a, b = x.chunk(2, dim=-1) + return a * torch.sigmoid(1.702 * a) * b + + def _get_reference_relu(self, x): + return F.relu(x) + + def _get_reference_reglu(self, x): + a, b = x.chunk(2, dim=-1) + return F.relu(a) * b + + def _get_reference_srelu(self, x): + return torch.square(F.relu(x)) + + def _get_reference_sreglu(self, x): + a, b = x.chunk(2, dim=-1) + return torch.square(F.relu(a)) * b + + def _get_reference_silu(self, x): + return F.silu(x) + + def _get_reference_swiglu(self, x): + a, b = x.chunk(2, dim=-1) + return F.silu(a) * b + + def _get_reference_clamped_swiglu(self, x, limit=7.0, alpha=1.702): + """Reference implementation matching CUDA clamped_swiglu. + + CUDA implementation: + - a (activation): clamp to upper bound only: min(a, limit) + - b (gate): clamp to [-limit, limit], then add 1 + - output = (a_clamped * sigmoid(alpha * a_clamped)) * b_clamped + """ + a, b = x.chunk(2, dim=-1) + # CUDA only clamps a to upper bound + a_clamped = torch.clamp(a, max=limit) + # CUDA clamps b to [-limit, limit] and adds 1 + b_clamped = torch.clamp(b, -limit, limit) + 1 + return a_clamped * torch.sigmoid(alpha * a_clamped) * b_clamped + + # ==================== Forward tests ==================== + def test_gelu_forward(self, shape=(4, 8)): + print(f"\n Testing GELU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_gelu(x) + self._test_activation_forward("gelu", x, reference) + + def test_geglu_forward(self, shape=(4, 16)): + print(f"\n Testing GEGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_geglu(x) + self._test_activation_forward("geglu", x, reference) + + def test_qgelu_forward(self, shape=(4, 8)): + print(f"\n Testing QGELU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_qgelu(x) + self._test_activation_forward("qgelu", x, reference) + + def test_qgeglu_forward(self, shape=(4, 16)): + print(f"\n Testing QGEGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_qgeglu(x) + self._test_activation_forward("qgeglu", x, reference) + + def test_relu_forward(self, shape=(4, 8)): + print(f"\n Testing ReLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_relu(x) + self._test_activation_forward("relu", x, reference, rtol=1e-6, atol=1e-8) + + def test_reglu_forward(self, shape=(4, 16)): + print(f"\n Testing ReGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_reglu(x) + self._test_activation_forward("reglu", x, reference, rtol=1e-6, atol=1e-8) + + def test_srelu_forward(self, shape=(4, 8)): + print(f"\n Testing SReLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_srelu(x) + self._test_activation_forward("srelu", x, reference) + + def test_sreglu_forward(self, shape=(4, 16)): + print(f"\n Testing SReGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_sreglu(x) + self._test_activation_forward("sreglu", x, reference) + + def test_silu_forward(self, shape=(4, 8)): + print(f"\n Testing SiLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_silu(x) + self._test_activation_forward("silu", x, reference) + + def test_swiglu_forward(self, shape=(4, 16)): + print(f"\n Testing SwiGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_swiglu(x) + self._test_activation_forward("swiglu", x, reference) + + def test_clamped_swiglu_forward(self, shape=(4, 16)): + print(f"\n Testing Clamped SwiGLU forward with shape {shape}") + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + reference = self._get_reference_clamped_swiglu(x) + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.clamped_swiglu(x, None, 7.0, 1.702) + self.assert_close( + output, + reference, + rtol=1e-4, + atol=1e-6, + msg=f"clamped_swiglu forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def _test_activation_forward(self, op_name, x, reference, rtol=1e-4, atol=1e-6): + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + op_fn = getattr(backend, op_name) + output = op_fn(x, None) + self.assert_close( + output, + reference, + rtol=rtol, + atol=atol, + msg=f"{op_name} forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + # ==================== Backward tests ==================== + def test_gelu_backward(self, shape=(4, 8)): + print(f"\n Testing GELU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + y = self._get_reference_gelu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dgelu", x, grad_output, reference_grad) + + def test_geglu_backward(self, shape=(4, 16)): + print(f"\n Testing GEGLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor( + (shape[0], shape[1] // 2) if len(shape) == 2 else (*shape[:-1], shape[-1] // 2), + dtype=torch.float32, + device=self.device, + ) + y = self._get_reference_geglu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dgeglu", x, grad_output, reference_grad) + + def test_qgelu_backward(self, shape=(4, 8)): + print(f"\n Testing QGELU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + y = self._get_reference_qgelu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dqgelu", x, grad_output, reference_grad) + + def test_qgeglu_backward(self, shape=(4, 16)): + print(f"\n Testing QGEGLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor( + (shape[0], shape[1] // 2) if len(shape) == 2 else (*shape[:-1], shape[-1] // 2), + dtype=torch.float32, + device=self.device, + ) + y = self._get_reference_qgeglu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dqgeglu", x, grad_output, reference_grad) + + def test_relu_backward(self, shape=(4, 8)): + print(f"\n Testing ReLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + y = self._get_reference_relu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("drelu", x, grad_output, reference_grad) + + def test_reglu_backward(self, shape=(4, 16)): + print(f"\n Testing ReGLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor( + (shape[0], shape[1] // 2) if len(shape) == 2 else (*shape[:-1], shape[-1] // 2), + dtype=torch.float32, + device=self.device, + ) + y = self._get_reference_reglu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dreglu", x, grad_output, reference_grad) + + def test_srelu_backward(self, shape=(4, 8)): + print(f"\n Testing SReLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + y = self._get_reference_srelu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dsrelu", x, grad_output, reference_grad) + + def test_sreglu_backward(self, shape=(4, 16)): + print(f"\n Testing SReGLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor( + (shape[0], shape[1] // 2) if len(shape) == 2 else (*shape[:-1], shape[-1] // 2), + dtype=torch.float32, + device=self.device, + ) + y = self._get_reference_sreglu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dsreglu", x, grad_output, reference_grad) + + def test_silu_backward(self, shape=(4, 8)): + print(f"\n Testing SiLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + y = self._get_reference_silu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dsilu", x, grad_output, reference_grad) + + def test_swiglu_backward(self, shape=(4, 16)): + print(f"\n Testing SwiGLU backward with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor( + (shape[0], shape[1] // 2) if len(shape) == 2 else (*shape[:-1], shape[-1] // 2), + dtype=torch.float32, + device=self.device, + ) + y = self._get_reference_swiglu(x) + y.backward(grad_output) + reference_grad = x.grad.clone() + x.grad = None + self._test_activation_backward("dswiglu", x, grad_output, reference_grad) + + def _test_activation_backward( + self, op_name, x, grad_output, reference_grad, rtol=1e-4, atol=1e-6 + ): + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + op_fn = getattr(backend, op_name) + grad_input = op_fn(grad_output, x.detach(), None) + self.assert_close( + grad_input, + reference_grad, + rtol=rtol, + atol=atol, + msg=f"{op_name} backward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + # ==================== Bias + backward tests ==================== + def test_dbias_dgelu(self, shape=(4, 8)): + print(f"\n Testing dbias_dgelu with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + # Reference: compute dgelu and sum for bias grad + y = self._get_reference_gelu(x) + y.backward(grad_output) + ref_grad_input = x.grad.clone() + ref_grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + x.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + grad_input, grad_bias = backend.dbias_dgelu(grad_output, x.detach(), None) + self.assert_close( + grad_input, + ref_grad_input, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dgelu grad_input mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dgelu grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except RuntimeError as e: + # CUDA requires a valid quantizer for dbias_d* fused ops + if "NoneQuantizer does not support" in str(e): + self.skipped += 1 + print(f" ⊘ {backend_name} (requires FP8 quantizer for fused op)") + else: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_dbias_dsilu(self, shape=(4, 8)): + print(f"\n Testing dbias_dsilu with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + y = self._get_reference_silu(x) + y.backward(grad_output) + ref_grad_input = x.grad.clone() + ref_grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + x.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + grad_input, grad_bias = backend.dbias_dsilu(grad_output, x.detach(), None) + self.assert_close( + grad_input, + ref_grad_input, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dsilu grad_input mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dsilu grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except RuntimeError as e: + # CUDA requires a valid quantizer for dbias_d* fused ops + if "NoneQuantizer does not support" in str(e): + self.skipped += 1 + print(f" ⊘ {backend_name} (requires FP8 quantizer for fused op)") + else: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_dbias_drelu(self, shape=(4, 8)): + print(f"\n Testing dbias_drelu with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + y = self._get_reference_relu(x) + y.backward(grad_output) + ref_grad_input = x.grad.clone() + ref_grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + x.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + grad_input, grad_bias = backend.dbias_drelu(grad_output, x.detach(), None) + self.assert_close( + grad_input, + ref_grad_input, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_drelu grad_input mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_drelu grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except RuntimeError as e: + # CUDA requires a valid quantizer for dbias_d* fused ops + if "NoneQuantizer does not support" in str(e): + self.skipped += 1 + print(f" ⊘ {backend_name} (requires FP8 quantizer for fused op)") + else: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_dbias_dqgelu(self, shape=(4, 8)): + print(f"\n Testing dbias_dqgelu with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + y = self._get_reference_qgelu(x) + y.backward(grad_output) + ref_grad_input = x.grad.clone() + ref_grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + x.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + grad_input, grad_bias = backend.dbias_dqgelu(grad_output, x.detach(), None) + self.assert_close( + grad_input, + ref_grad_input, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dqgelu grad_input mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dqgelu grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except RuntimeError as e: + # CUDA requires a valid quantizer for dbias_d* fused ops + if "NoneQuantizer does not support" in str(e): + self.skipped += 1 + print(f" ⊘ {backend_name} (requires FP8 quantizer for fused op)") + else: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_dbias_dsrelu(self, shape=(4, 8)): + print(f"\n Testing dbias_dsrelu with shape {shape}") + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + y = self._get_reference_srelu(x) + y.backward(grad_output) + ref_grad_input = x.grad.clone() + ref_grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + x.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + grad_input, grad_bias = backend.dbias_dsrelu(grad_output, x.detach(), None) + self.assert_close( + grad_input, + ref_grad_input, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dsrelu grad_input mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-6, + msg=f"dbias_dsrelu grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except RuntimeError as e: + # CUDA requires a valid quantizer for dbias_d* fused ops + if "NoneQuantizer does not support" in str(e): + self.skipped += 1 + print(f" ⊘ {backend_name} (requires FP8 quantizer for fused op)") + else: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Activation Functions") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + shapes = [(4, 8), (8, 16), (2, 4, 8)] + glu_shapes = [(4, 16), (8, 32), (2, 4, 16)] + + # Forward tests - non-gated activations + for shape in shapes: + self.test_gelu_forward(shape) + self.test_qgelu_forward(shape) + self.test_relu_forward(shape) + self.test_srelu_forward(shape) + self.test_silu_forward(shape) + + # Forward tests - gated activations + for shape in glu_shapes: + self.test_geglu_forward(shape) + self.test_qgeglu_forward(shape) + self.test_reglu_forward(shape) + self.test_sreglu_forward(shape) + self.test_swiglu_forward(shape) + self.test_clamped_swiglu_forward(shape) + + # Backward tests - non-gated activations + for shape in shapes: + self.test_gelu_backward(shape) + self.test_qgelu_backward(shape) + self.test_relu_backward(shape) + self.test_srelu_backward(shape) + self.test_silu_backward(shape) + + # Backward tests - gated activations + for shape in glu_shapes: + self.test_geglu_backward(shape) + self.test_qgeglu_backward(shape) + self.test_reglu_backward(shape) + self.test_sreglu_backward(shape) + self.test_swiglu_backward(shape) + + # Note: dbias_d* tests are skipped because CUDA requires FP8 quantizer + # for these fused ops. These will be tested separately with FP8 quantizer. + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = ActivationTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_flash_attention.py b/transformer_engine/plugin/tests/test_flash_attention.py new file mode 100644 index 0000000000..3a3f3be24f --- /dev/null +++ b/transformer_engine/plugin/tests/test_flash_attention.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import math +import torch +import torch.nn.functional as F + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) + + +class FlashAttentionTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Flash Attention", "Test correctness of Flash Attention implementation across backends" + ) + self.backends = get_available_backends() + self.device = device + + def _reference_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + ): + """Reference implementation of scaled dot-product attention + Input format: sbhd [seq, batch, heads, dim] + """ + # Convert sbhd to bhsd for computation + q = query.permute(1, 2, 0, 3) # [batch, heads, seq, dim] + k = key.permute(1, 2, 0, 3) + v = value.permute(1, 2, 0, 3) + + L, S = q.size(-2), k.size(-2) + if scale is None: + scale_factor = 1 / math.sqrt(q.size(-1)) + else: + scale_factor = scale + + attn_weight = q @ k.transpose(-2, -1) * scale_factor + + if is_causal: + causal_mask = torch.triu( + torch.full((L, S), float("-inf"), dtype=q.dtype, device=q.device), diagonal=1 + ) + attn_weight = attn_weight + causal_mask + + if attn_mask is not None: + attn_weight = attn_weight + attn_mask + + attn_weight = F.softmax(attn_weight, dim=-1) + + if dropout_p > 0.0: + attn_weight = F.dropout(attn_weight, p=dropout_p, training=True) + + out = attn_weight @ v + # Convert bhsd back to sbhd + return out.permute(2, 0, 1, 3) # [seq, batch, heads, dim] + + def test_flash_attention_forward_basic( + self, seq_len=16, batch_size=2, num_heads=4, head_dim=32 + ): + """Test basic flash attention forward pass with sbhd layout and bf16""" + print( + f"\n Testing Flash Attention forward sbhd bf16 (seq={seq_len}, batch={batch_size}," + f" heads={num_heads}, dim={head_dim})" + ) + + # Shape: (seq_len, batch, num_heads, head_dim) - sbhd layout + query = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + key = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + value = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + + scale = 1.0 / math.sqrt(head_dim) + + # Reference attention (compute in float32 for accuracy) + reference = self._reference_attention( + query.float(), key.float(), value.float(), scale=scale, is_causal=False + ).to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + FlashAttentionClass = backend.get_flash_attention_class() + flash_attn = FlashAttentionClass( + softmax_scale=scale, + attention_dropout=0.0, + attention_type="self", + deterministic=True, + ) + + # Run forward pass with sbhd layout + output = flash_attn( + query_layer=query, + key_layer=key, + value_layer=value, + attention_mask=None, + qkv_layout="sb3hd", + attn_mask_type="no_mask", + window_size=(-1, -1), # Required by flash_attn 2.7+ + ) + + # Output shape: sbhd -> view to sb(h*d) + expected_shape = (seq_len, batch_size, num_heads * head_dim) + if output.shape != expected_shape: + # Try to reshape reference for comparison + reference_flat = reference.contiguous().reshape(seq_len, batch_size, -1) + self.assert_close( + output.float(), + reference_flat.float(), + rtol=1e-2, + atol=1e-2, + msg=f"Flash Attention forward mismatch for {backend_name}", + ) + else: + reference_flat = reference.contiguous().reshape(seq_len, batch_size, -1) + self.assert_close( + output.float(), + reference_flat.float(), + rtol=1e-2, + atol=1e-2, + msg=f"Flash Attention forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + import traceback + + traceback.print_exc() + + def test_flash_attention_forward_causal( + self, seq_len=16, batch_size=2, num_heads=4, head_dim=32 + ): + """Test flash attention forward pass with causal mask""" + print( + f"\n Testing Flash Attention forward causal sbhd bf16 (seq={seq_len}," + f" batch={batch_size}, heads={num_heads}, dim={head_dim})" + ) + + query = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + key = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + value = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), dtype=torch.bfloat16, device=self.device + ) + + scale = 1.0 / math.sqrt(head_dim) + + # Reference attention with causal mask + reference = self._reference_attention( + query.float(), key.float(), value.float(), scale=scale, is_causal=True + ).to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + FlashAttentionClass = backend.get_flash_attention_class() + flash_attn = FlashAttentionClass( + softmax_scale=scale, + attention_dropout=0.0, + attention_type="self", + deterministic=True, + ) + + output = flash_attn( + query_layer=query, + key_layer=key, + value_layer=value, + attention_mask=None, + qkv_layout="sb3hd", + attn_mask_type="causal", + window_size=(-1, -1), # Required by flash_attn 2.7+ + ) + + reference_flat = reference.contiguous().reshape(seq_len, batch_size, -1) + self.assert_close( + output.float(), + reference_flat.float(), + rtol=1e-2, + atol=1e-2, + msg=f"Flash Attention forward causal mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + import traceback + + traceback.print_exc() + + def test_flash_attention_backward(self, seq_len=16, batch_size=2, num_heads=4, head_dim=32): + """Test flash attention backward pass with sbhd layout, bf16, and causal mask. + + Note: FlagGems backward currently only supports causal attention. + """ + print( + f"\n Testing Flash Attention backward causal sbhd bf16 (seq={seq_len}," + f" batch={batch_size}, heads={num_heads}, dim={head_dim})" + ) + + query = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), + dtype=torch.bfloat16, + device=self.device, + requires_grad=True, + ) + key = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), + dtype=torch.bfloat16, + device=self.device, + requires_grad=True, + ) + value = generate_random_tensor( + (seq_len, batch_size, num_heads, head_dim), + dtype=torch.bfloat16, + device=self.device, + requires_grad=True, + ) + # grad_output shape matches output: sb(h*d) + grad_output = generate_random_tensor( + (seq_len, batch_size, num_heads * head_dim), dtype=torch.bfloat16, device=self.device + ) + + scale = 1.0 / math.sqrt(head_dim) + + # Reference backward (compute in float32 for accuracy) + # Note: FlagGems backward only supports causal attention + query_f32 = query.float().detach().requires_grad_(True) + key_f32 = key.float().detach().requires_grad_(True) + value_f32 = value.float().detach().requires_grad_(True) + + ref_output = self._reference_attention( + query_f32, key_f32, value_f32, scale=scale, is_causal=True + ) + ref_output_flat = ref_output.contiguous().reshape(seq_len, batch_size, -1) + ref_output_flat.backward(grad_output.float()) + ref_grad_q = query_f32.grad.clone().to(torch.bfloat16) + ref_grad_k = key_f32.grad.clone().to(torch.bfloat16) + ref_grad_v = value_f32.grad.clone().to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + FlashAttentionClass = backend.get_flash_attention_class() + flash_attn = FlashAttentionClass( + softmax_scale=scale, + attention_dropout=0.0, + attention_type="self", + deterministic=True, + ) + + # Forward pass + q_copy = query.detach().requires_grad_(True) + k_copy = key.detach().requires_grad_(True) + v_copy = value.detach().requires_grad_(True) + + output = flash_attn( + query_layer=q_copy, + key_layer=k_copy, + value_layer=v_copy, + attention_mask=None, + qkv_layout="sb3hd", + attn_mask_type="causal", + window_size=(-1, -1), # Required by flash_attn 2.7+ + ) + + # Backward pass + output.backward(grad_output) + + # bf16 backward has higher numerical error due to accumulated precision loss + self.assert_close( + q_copy.grad.float(), + ref_grad_q.float(), + rtol=2e-2, + atol=2e-2, + msg=f"Flash Attention backward grad_q mismatch for {backend_name}", + ) + self.assert_close( + k_copy.grad.float(), + ref_grad_k.float(), + rtol=2e-2, + atol=2e-2, + msg=f"Flash Attention backward grad_k mismatch for {backend_name}", + ) + self.assert_close( + v_copy.grad.float(), + ref_grad_v.float(), + rtol=2e-2, + atol=2e-2, + msg=f"Flash Attention backward grad_v mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + import traceback + + traceback.print_exc() + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Flash Attention") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + # Basic forward tests with sbhd layout and bf16 + self.test_flash_attention_forward_basic(seq_len=16, batch_size=2, num_heads=4, head_dim=32) + self.test_flash_attention_forward_basic(seq_len=32, batch_size=4, num_heads=8, head_dim=64) + + # Causal mask tests + self.test_flash_attention_forward_causal(seq_len=16, batch_size=2, num_heads=4, head_dim=32) + + # Backward tests + self.test_flash_attention_backward(seq_len=16, batch_size=2, num_heads=4, head_dim=32) + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + if device != "cuda": + print("Warning: Flash Attention tests require CUDA. Skipping.") + return 0 + test_suite = FlashAttentionTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_normalization.py b/transformer_engine/plugin/tests/test_normalization.py new file mode 100644 index 0000000000..eb2dea35cc --- /dev/null +++ b/transformer_engine/plugin/tests/test_normalization.py @@ -0,0 +1,272 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +import torch +import torch.nn.functional as F +import sys + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) +from transformer_engine.plugin.core.ops import DType + + +class NormalizationTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Normalization Functions", "Test correctness of LayerNorm and RMSNorm across backends" + ) + self.backends = get_available_backends() + self.eps = 1e-5 + self.device = device + + def _reference_layernorm_forward(self, x, weight, bias, eps): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + rsigma = torch.rsqrt(var + eps) + normalized = (x - mean) * rsigma + output = normalized * weight + bias + return output, mean.squeeze(-1), rsigma.squeeze(-1) + + def _reference_rmsnorm_forward(self, x, weight, eps): + var = (x**2).mean(dim=-1, keepdim=True) + rsigma = torch.rsqrt(var + eps) + normalized = x * rsigma + output = normalized * weight + return output, None, rsigma.squeeze(-1) + + def test_layernorm_forward(self, shape=(2, 4, 8)): + print(f"\n Testing LayerNorm forward with shape {shape}") + + hidden_size = shape[-1] + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + weight = torch.ones(hidden_size, dtype=torch.float32, device=self.device) + bias = torch.zeros(hidden_size, dtype=torch.float32, device=self.device) + + ref_output, ref_mean, ref_rsigma = self._reference_layernorm_forward( + x, weight, bias, self.eps + ) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output, mean, rsigma = backend.layernorm_fwd( + x, weight, bias, self.eps, None, None, DType.kFloat32, 0, False + ) + self.assert_close( + output, + ref_output, + rtol=1e-5, + atol=1e-7, + msg=f"LayerNorm forward output mismatch for {backend_name}", + ) + self.assert_close( + mean, + ref_mean, + rtol=1e-5, + atol=1e-7, + msg=f"LayerNorm forward mean mismatch for {backend_name}", + ) + self.assert_close( + rsigma, + ref_rsigma, + rtol=1e-4, + atol=1e-6, + msg=f"LayerNorm forward rsigma mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_layernorm_backward(self, shape=(2, 4, 8)): + print(f"\n Testing LayerNorm backward with shape {shape}") + + hidden_size = shape[-1] + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + weight = torch.ones( + hidden_size, dtype=torch.float32, device=self.device, requires_grad=True + ) + bias = torch.zeros(hidden_size, dtype=torch.float32, device=self.device, requires_grad=True) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + output, mean, rsigma = self._reference_layernorm_forward(x, weight, bias, self.eps) + output.backward(grad_output) + ref_grad_x = x.grad.clone() + ref_grad_weight = weight.grad.clone() + ref_grad_bias = bias.grad.clone() + + x.grad = None + weight.grad = None + bias.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + x_copy = x.detach() + weight_copy = weight.detach() + + grad_x, grad_weight, grad_bias = backend.layernorm_bwd( + grad_output, x_copy, mean.detach(), rsigma.detach(), weight_copy, 0, False + ) + + self.assert_close( + grad_x, + ref_grad_x, + rtol=1e-4, + atol=1e-6, + msg=f"LayerNorm backward grad_x mismatch for {backend_name}", + ) + self.assert_close( + grad_weight, + ref_grad_weight, + rtol=1e-4, + atol=1e-6, + msg=f"LayerNorm backward grad_weight mismatch for {backend_name}", + ) + self.assert_close( + grad_bias, + ref_grad_bias, + rtol=1e-4, + atol=1e-5, + msg=f"LayerNorm backward grad_bias mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_rmsnorm_forward(self, shape=(2, 4, 8)): + print(f"\n Testing RMSNorm forward with shape {shape}") + + hidden_size = shape[-1] + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + weight = torch.ones(hidden_size, dtype=torch.float32, device=self.device) + + ref_output, _, ref_rsigma = self._reference_rmsnorm_forward(x, weight, self.eps) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output, _, rsigma = backend.rmsnorm_fwd( + x, weight, self.eps, None, None, DType.kFloat32, 0, False + ) + self.assert_close( + output, + ref_output, + rtol=1e-5, + atol=1e-7, + msg=f"RMSNorm forward output mismatch for {backend_name}", + ) + self.assert_close( + rsigma, + ref_rsigma, + rtol=1e-4, + atol=1e-6, + msg=f"RMSNorm forward rsigma mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_rmsnorm_backward(self, shape=(2, 4, 8)): + print(f"\n Testing RMSNorm backward with shape {shape}") + + hidden_size = shape[-1] + x = generate_random_tensor( + shape, dtype=torch.float32, device=self.device, requires_grad=True + ) + weight = torch.ones( + hidden_size, dtype=torch.float32, device=self.device, requires_grad=True + ) + grad_output = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + + output, _, rsigma = self._reference_rmsnorm_forward(x, weight, self.eps) + output.backward(grad_output) + ref_grad_x = x.grad.clone() + ref_grad_weight = weight.grad.clone() + + x.grad = None + weight.grad = None + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + x_copy = x.detach() + weight_copy = weight.detach() + + grad_x, grad_weight = backend.rmsnorm_bwd( + grad_output, x_copy, rsigma.detach(), weight_copy, 0, False + ) + + self.assert_close( + grad_x, + ref_grad_x, + rtol=1e-4, + atol=1e-6, + msg=f"RMSNorm backward grad_x mismatch for {backend_name}", + ) + self.assert_close( + grad_weight, + ref_grad_weight, + rtol=1e-4, + atol=1e-6, + msg=f"RMSNorm backward grad_weight mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Normalization Functions") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + shapes = [ + (8, 16), + (32, 64), + (64, 128), + (16, 256), + ] + + for shape in shapes: + self.test_layernorm_forward(shape) + self.test_layernorm_backward(shape) + self.test_rmsnorm_forward(shape) + self.test_rmsnorm_backward(shape) + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = NormalizationTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_operations.py b/transformer_engine/plugin/tests/test_operations.py new file mode 100644 index 0000000000..1e03dc4692 --- /dev/null +++ b/transformer_engine/plugin/tests/test_operations.py @@ -0,0 +1,315 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import os +import torch +import torch.nn.functional as F +import sys + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) +from transformer_engine.plugin.core.ops import DType + + +class OperationsTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Operations (GEMM, Softmax, Dropout)", + "Test correctness of GEMM, Softmax, and Dropout operations", + ) + self.backends = get_available_backends() + self.device = device + + def test_gemm_basic(self, M=32, N=64, K=48): + print(f"\n Testing GEMM ({M}x{K}) @ ({K}x{N})") + + A = generate_random_tensor((K, N), dtype=torch.float32, device=self.device) + B = generate_random_tensor((M, K), dtype=torch.float32, device=self.device) + reference = B @ A + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + D = torch.empty((M, N), dtype=torch.float32, device=self.device) + workspace = torch.empty(1024, dtype=torch.uint8, device=self.device) + + output, _, _, _ = backend.generic_gemm( + A, + False, + B, + False, + D, + None, + DType.kFloat32, + None, + DType.kFloat32, + False, + None, + False, + workspace, + 1024, + False, + False, + ) + + self.assert_close( + output, + reference, + rtol=5e-2, + atol=1e-2, + msg=f"GEMM output mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_gemm_transpose_a(self, M=32, N=64, K=48): + print(f"\n Testing GEMM transpose A ({N}x{K}).T @ ({M}x{K})") + + A = generate_random_tensor((N, K), dtype=torch.float32, device=self.device) + B = generate_random_tensor((M, K), dtype=torch.float32, device=self.device) + reference = B @ A.T + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + D = torch.empty((M, N), dtype=torch.float32, device=self.device) + workspace = torch.empty(1024, dtype=torch.uint8, device=self.device) + + output, _, _, _ = backend.generic_gemm( + A, + True, + B, + False, + D, + None, + DType.kFloat32, + None, + DType.kFloat32, + False, + None, + False, + workspace, + 1024, + False, + False, + ) + + self.assert_close( + output, + reference, + rtol=5e-2, + atol=1e-2, + msg=f"GEMM transpose A mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_gemm_3d(self, B=2, M=16, N=32, K=24): + print(f"\n Testing 3D GEMM ({B}x{M}x{K}) @ ({K}x{N})") + + A = generate_random_tensor((B, M, K), dtype=torch.float32, device=self.device) + B_mat = generate_random_tensor((K, N), dtype=torch.float32, device=self.device) + reference = torch.matmul(A, B_mat) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + D = torch.empty((B, M, N), dtype=torch.float32, device=self.device) + workspace = torch.empty(1024, dtype=torch.uint8, device=self.device) + + output, _, _, _ = backend.generic_gemm( + B_mat, + False, + A, + False, + D, + None, + DType.kFloat32, + None, + DType.kFloat32, + False, + None, + False, + workspace, + 1024, + False, + False, + ) + + self.assert_close( + output, + reference, + rtol=5e-2, + atol=1e-2, + msg=f"3D GEMM mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_softmax(self, shape=(2, 4, 8, 16)): + print(f"\n Testing scaled softmax with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + scale = 0.125 + reference = F.softmax(x.float() * scale, dim=-1).to(x.dtype) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_softmax_forward(x, scale) + self.assert_close( + output, + reference, + rtol=1e-2, + atol=1e-3, + msg=f"Scaled softmax mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_causal_masked_softmax(self, shape=(8, 16, 16)): + print(f"\n Testing causal masked softmax with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + scale = 0.125 + seq_len = shape[-1] + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=x.dtype, device=self.device), + diagonal=1, + ) + reference = F.softmax(x.float() * scale + causal_mask.float(), dim=-1).to(x.dtype) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_upper_triang_masked_softmax_forward(x, scale) + self.assert_close( + output, + reference, + rtol=1e-2, + atol=1e-3, + msg=f"Causal masked softmax mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_dropout(self, shape=(4, 8, 16)): + print(f"\n Testing dropout with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + dropout_prob = 0.1 + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output, mask = backend.dropout_fwd(x, dropout_prob, None) + + num_nonzero = (output != 0).sum().item() + total_elements = output.numel() + nonzero_ratio = num_nonzero / total_elements + expected_ratio = 1.0 - dropout_prob + + assert abs(nonzero_ratio - expected_ratio) < 0.2, ( + f"Dropout ratio mismatch for {backend_name}: {nonzero_ratio:.3f} vs" + f" {expected_ratio:.3f}" + ) + + assert torch.all( + output[output == 0] == 0 + ), f"Dropped elements should be zero for {backend_name}" + + expected_scale = 1.0 / (1.0 - dropout_prob) + non_zero_output = output[output != 0] + non_zero_input = x[output != 0] + + if len(non_zero_output) > 0: + self.assert_close( + non_zero_output, + non_zero_input * expected_scale, + rtol=1e-2, + atol=1e-3, + msg=f"Dropout scaling mismatch for {backend_name}", + ) + + grad_output = generate_random_tensor( + shape, dtype=torch.bfloat16, device=self.device + ) + grad_input = backend.dropout_bwd(grad_output, mask, dropout_prob, None) + + grad_nonzero_mask = grad_input != 0 + output_nonzero_mask = output != 0 + assert torch.all( + grad_nonzero_mask == output_nonzero_mask + ), f"Dropout backward sparsity mismatch for {backend_name}" + + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Operations (GEMM, Softmax, Dropout)") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + self.test_gemm_basic(M=32, N=64, K=48) + self.test_gemm_basic(M=64, N=128, K=96) + self.test_gemm_transpose_a(M=32, N=64, K=48) + self.test_gemm_3d(B=2, M=16, N=32, K=24) + + self.test_scaled_softmax((4, 8, 16, 16)) + self.test_scaled_softmax((2, 4, 32, 32)) + self.test_causal_masked_softmax((16, 32, 32)) + self.test_causal_masked_softmax((8, 64, 64)) + + self.test_dropout((4, 8, 16)) + self.test_dropout((8, 16, 32)) + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = OperationsTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_optimizer.py b/transformer_engine/plugin/tests/test_optimizer.py new file mode 100644 index 0000000000..75c072e308 --- /dev/null +++ b/transformer_engine/plugin/tests/test_optimizer.py @@ -0,0 +1,543 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import torch +import math + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) + + +class OptimizerTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Optimizer Operations", + "Test correctness of multi_tensor optimizer operations across backends", + ) + self.backends = get_available_backends() + self.device = device + + def _reference_multi_tensor_l2norm(self, tensors, per_tensor=False): + """Reference implementation for multi_tensor_l2norm""" + if per_tensor: + return [torch.norm(t.float(), p=2) for t in tensors] + else: + total_norm_sq = sum(torch.norm(t.float(), p=2) ** 2 for t in tensors) + return torch.sqrt(total_norm_sq) + + def test_multi_tensor_scale(self, num_tensors=4, shape=(64, 128)): + print(f"\n Testing multi_tensor_scale with {num_tensors} tensors of shape {shape}") + + scale = 0.5 + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Create input tensors + input_tensors = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + # Create output tensors (will be filled by the function) + output_tensors = [torch.empty_like(t) for t in input_tensors] + # Create reference tensors + ref_tensors = [t.clone() * scale for t in input_tensors] + + # Apply backend scaling: tensor_lists = [input_tensors, output_tensors] + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + backend.multi_tensor_scale( + chunk_size=2048, + noop_flag=noop_flag, + tensor_lists=[input_tensors, output_tensors], + scale=scale, + ) + + # Compare results + for i, (output, reference) in enumerate(zip(output_tensors, ref_tensors)): + self.assert_close( + output, + reference, + rtol=1e-5, + atol=1e-7, + msg=f"multi_tensor_scale tensor {i} mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_multi_tensor_l2norm(self, num_tensors=4, shape=(64, 128)): + print(f"\n Testing multi_tensor_l2norm with {num_tensors} tensors of shape {shape}") + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + tensors = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + + # Reference computation + ref_norm = self._reference_multi_tensor_l2norm(tensors, per_tensor=False) + + # Backend computation + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + output_norm = backend.multi_tensor_l2norm( + chunk_size=2048, noop_flag=noop_flag, tensor_lists=[tensors], per_tensor=False + ) + + # CUDA backend returns tuple (norm, per_tensor_norms), extract the first element + if isinstance(output_norm, tuple): + output_norm = output_norm[0] + + self.assert_close( + output_norm, + ref_norm, + rtol=1e-4, + atol=1e-6, + msg=f"multi_tensor_l2norm total norm mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_multi_tensor_l2norm_per_tensor(self, num_tensors=4, shape=(64, 128)): + print( + f"\n Testing multi_tensor_l2norm per_tensor with {num_tensors} tensors of shape" + f" {shape}" + ) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + tensors = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + + # Reference computation + ref_norms = self._reference_multi_tensor_l2norm(tensors, per_tensor=True) + + # Backend computation + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + output_norms = backend.multi_tensor_l2norm( + chunk_size=2048, noop_flag=noop_flag, tensor_lists=[tensors], per_tensor=True + ) + + # CUDA backend returns tuple (total_norm, per_tensor_norms), extract second element + if isinstance(output_norms, tuple): + output_norms = output_norms[1] + + for i, (output, reference) in enumerate(zip(output_norms, ref_norms)): + self.assert_close( + output, + reference, + rtol=1e-4, + atol=1e-6, + msg=f"multi_tensor_l2norm per_tensor {i} mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_multi_tensor_adam(self, num_tensors=3, shape=(32, 64)): + print(f"\n Testing multi_tensor_adam with {num_tensors} tensors of shape {shape}") + + lr = 0.001 + beta1 = 0.9 + beta2 = 0.999 + eps = 1e-8 + step = 1 + weight_decay = 0.01 + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Create tensors for backend test + params = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + grads = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + exp_avgs = [torch.zeros_like(p) for p in params] + exp_avg_sqs = [torch.zeros_like(p) for p in params] + + # Create reference tensors with same values + ref_params = [p.clone() for p in params] + ref_grads = [g.clone() for g in grads] + ref_exp_avgs = [torch.zeros_like(p) for p in params] + ref_exp_avg_sqs = [torch.zeros_like(p) for p in params] + + # Apply reference Adam step (matching the torch implementation) + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + for p, g, m, v in zip(ref_params, ref_grads, ref_exp_avgs, ref_exp_avg_sqs): + # AdamW style: weight decay applied to param first + p.mul_(1 - lr * weight_decay) + + # Update biased first moment estimate + m.mul_(beta1).add_(g, alpha=1 - beta1) + # Update biased second raw moment estimate + v.mul_(beta2).addcmul_(g, g, value=1 - beta2) + + # Compute bias-corrected estimates + corrected_m = m / bias_correction1 + corrected_v = v / bias_correction2 + + # Update parameters + denom = corrected_v.sqrt().add_(eps) + p.addcdiv_(corrected_m, denom, value=-lr) + + # Apply backend Adam step + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + backend.multi_tensor_adam( + chunk_size=2048, + noop_flag=noop_flag, + tensor_lists=[grads, params, exp_avgs, exp_avg_sqs], + lr=lr, + beta1=beta1, + beta2=beta2, + epsilon=eps, + step=step, + mode=1, # AdamW mode + bias_correction=1, + weight_decay=weight_decay, + ) + + # Compare results with relaxed tolerance + for i, (output, reference) in enumerate(zip(params, ref_params)): + self.assert_close( + output, + reference, + rtol=1e-3, + atol=1e-5, + msg=f"multi_tensor_adam param {i} mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def _fp32_to_param_remainder(self, fp32_tensor): + """Split FP32 tensor into int16 param (high 16 bits) + int16 remainder (low 16 bits). + + Matches the CUDA split convention: + 1. Extract high 16 bits as param, low 16 bits as remainder. + 2. If remainder < 0, increment param (round up). + """ + int32 = fp32_tensor.view(torch.int32) + rem = (int32 & 0xFFFF).to(torch.int16) + high = ((int32 >> 16) & 0xFFFF).to(torch.int16) + high = torch.where(rem < 0, high + 1, high) + # param is stored as bf16 (same bits as high int16) + param = high.view(torch.bfloat16) + return param, rem + + def _param_remainder_to_fp32(self, param, remainder): + """Reconstruct FP32 from int16 param (high bits) + int16 remainder (low bits). + + Matches the CUDA reconstruct convention: + 1. If remainder < 0, decrement param (undo rounding). + 2. Combine high and low 16 bits into FP32. + """ + local_p = param.view(torch.int16).clone() + local_rem = remainder.clone() + local_p = torch.where(local_rem < 0, local_p - 1, local_p) + high = local_p.to(torch.int32) << 16 + low = local_rem.to(torch.int32) & 0xFFFF + return (high | low).view(torch.float32) + + def _reference_adam_param_remainder( + self, + grads, + params, + exp_avgs, + exp_avg_sqs, + param_remainders, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ): + """Pure-PyTorch reference for multi_tensor_adam_param_remainder.""" + bc1 = 1 - beta1**step if bias_correction else 1.0 + bc2 = 1 - beta2**step if bias_correction else 1.0 + is_adamw = mode == 1 + + for g, p, m, v, p_rem in zip(grads, params, exp_avgs, exp_avg_sqs, param_remainders): + g_float = g.float() + param_master = self._param_remainder_to_fp32(p, p_rem) + + if not is_adamw and weight_decay != 0: + g_float = g_float + weight_decay * param_master + + m.mul_(beta1).add_(g_float, alpha=1 - beta1) + v.mul_(beta2).addcmul_(g_float, g_float, value=1 - beta2) + + m_corr = m / bc1 + v_corr = v / bc2 + denom = torch.sqrt(v_corr) + epsilon + update = m_corr / denom + + if is_adamw and weight_decay != 0: + update = update + weight_decay * param_master + + param_master = param_master - lr * update + + new_p, new_rem = self._fp32_to_param_remainder(param_master) + p.view(torch.int16).copy_(new_p.view(torch.int16)) + p_rem.copy_(new_rem) + + def test_multi_tensor_adam_param_remainder(self, num_tensors=3, shape=(32, 64)): + print( + f"\n Testing multi_tensor_adam_param_remainder with {num_tensors} tensors of shape" + f" {shape}" + ) + + lr = 0.001 + beta1 = 0.9 + beta2 = 0.999 + eps = 1e-8 + step = 1 + weight_decay = 0.01 + mode = 1 # AdamW + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Create FP32 master weights, then split into param + remainder + master_weights = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + grads = [ + generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + for _ in range(num_tensors) + ] + + params = [] + remainders = [] + for mw in master_weights: + p, r = self._fp32_to_param_remainder(mw) + params.append(p.clone()) + remainders.append(r.clone()) + + exp_avgs = [ + torch.zeros(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + exp_avg_sqs = [ + torch.zeros(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + + # Clone for reference + ref_params = [p.clone() for p in params] + ref_remainders = [r.clone() for r in remainders] + ref_exp_avgs = [torch.zeros_like(m) for m in exp_avgs] + ref_exp_avg_sqs = [torch.zeros_like(v) for v in exp_avg_sqs] + ref_grads = [g.clone() for g in grads] + + # Reference step + self._reference_adam_param_remainder( + ref_grads, + ref_params, + ref_exp_avgs, + ref_exp_avg_sqs, + ref_remainders, + lr, + beta1, + beta2, + eps, + step, + mode, + 1, + weight_decay, + ) + + # Backend step + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + backend.multi_tensor_adam_param_remainder( + chunk_size=2048, + noop_flag=noop_flag, + tensor_lists=[grads, params, exp_avgs, exp_avg_sqs, remainders], + lr=lr, + beta1=beta1, + beta2=beta2, + epsilon=eps, + step=step, + mode=mode, + bias_correction=1, + weight_decay=weight_decay, + ) + + # Compare reconstructed FP32 master weights + for i in range(num_tensors): + out_fp32 = self._param_remainder_to_fp32(params[i], remainders[i]) + ref_fp32 = self._param_remainder_to_fp32(ref_params[i], ref_remainders[i]) + self.assert_close( + out_fp32, + ref_fp32, + rtol=1e-5, + atol=1e-7, + msg=( + f"multi_tensor_adam_param_remainder param {i} mismatch for" + f" {backend_name}" + ), + ) + self.assert_close( + exp_avgs[i], + ref_exp_avgs[i], + rtol=1e-5, + atol=1e-7, + msg=( + f"multi_tensor_adam_param_remainder exp_avg {i} mismatch for" + f" {backend_name}" + ), + ) + self.assert_close( + exp_avg_sqs[i], + ref_exp_avg_sqs[i], + rtol=1e-5, + atol=1e-7, + msg=( + f"multi_tensor_adam_param_remainder exp_avg_sq {i} mismatch for" + f" {backend_name}" + ), + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def _reference_multi_tensor_unscale_l2norm(self, tensors, inv_scale, per_tensor=False): + """Reference implementation for multi_tensor_unscale_l2norm. + + Computes L2 norm of tensors after unscaling. + Note: scale parameter is actually inv_scale (1/loss_scale). + Unscaling means multiplying by inv_scale (= dividing by loss_scale). + """ + inv_scale_value = inv_scale.item() if isinstance(inv_scale, torch.Tensor) else inv_scale + # Unscale (multiply by inv_scale) and compute L2 norm + if per_tensor: + return [torch.norm(t.float() * inv_scale_value, p=2) for t in tensors] + else: + total_norm_sq = sum(torch.norm(t.float() * inv_scale_value, p=2) ** 2 for t in tensors) + return torch.sqrt(total_norm_sq) + + def test_multi_tensor_unscale_l2norm(self, num_tensors=4, shape=(64, 128)): + print( + f"\n Testing multi_tensor_unscale_l2norm with {num_tensors} tensors of shape {shape}" + ) + + # Note: scale parameter is actually inv_scale (1/loss_scale) + # For AMP with loss_scale=1024, inv_scale would be 1/1024 + inv_scale_value = 0.5 # equivalent to loss_scale = 2.0 + tensors = [ + generate_random_tensor(shape, dtype=torch.float32, device=self.device) + for _ in range(num_tensors) + ] + noop_flag = torch.tensor([0], dtype=torch.int32, device=self.device) + inv_scale = torch.tensor([inv_scale_value], dtype=torch.float32, device=self.device) + + # Compute mathematical reference + reference_norm = self._reference_multi_tensor_unscale_l2norm( + tensors, inv_scale, per_tensor=False + ) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output_norm = backend.multi_tensor_unscale_l2norm( + chunk_size=2048, + noop_flag=noop_flag, + tensor_lists=[tensors], + inv_scale=inv_scale, + per_tensor=False, + ) + + # CUDA backend returns tuple (norm, per_tensor_norms), extract the first element + if isinstance(output_norm, tuple): + output_norm = output_norm[0] + + self.assert_close( + output_norm, + reference_norm, + rtol=1e-4, + atol=1e-6, + msg=f"multi_tensor_unscale_l2norm mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Optimizer Operations") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + # multi_tensor_scale tests + self.test_multi_tensor_scale(num_tensors=4, shape=(64, 128)) + self.test_multi_tensor_scale(num_tensors=8, shape=(128, 256)) + + # multi_tensor_l2norm tests + self.test_multi_tensor_l2norm(num_tensors=4, shape=(64, 128)) + self.test_multi_tensor_l2norm_per_tensor(num_tensors=4, shape=(64, 128)) + + # multi_tensor_unscale_l2norm tests + self.test_multi_tensor_unscale_l2norm(num_tensors=4, shape=(64, 128)) + + # multi_tensor_adam tests + self.test_multi_tensor_adam(num_tensors=3, shape=(32, 64)) + + # multi_tensor_adam_param_remainder tests + self.test_multi_tensor_adam_param_remainder(num_tensors=3, shape=(32, 64)) + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = OptimizerTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_policy.py b/transformer_engine/plugin/tests/test_policy.py new file mode 100644 index 0000000000..35b102a104 --- /dev/null +++ b/transformer_engine/plugin/tests/test_policy.py @@ -0,0 +1,760 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Test suite for TE-FL scheduling policy system. + +This module tests: +1. SelectionPolicy creation and configuration +2. Environment variable parsing +3. Policy context managers +4. Vendor filtering (allow/deny) +5. Per-operator custom ordering +6. PolicyManager singleton and thread safety +7. Integration with OpManager +""" + +import os +import sys +import threading +import unittest +from unittest.mock import patch +from typing import List, Dict + + +class TestSelectionPolicy(unittest.TestCase): + """Test SelectionPolicy dataclass and methods""" + + def setUp(self): + """Import policy module fresh for each test""" + from transformer_engine.plugin.core.policy import ( + SelectionPolicy, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, + ) + + self.SelectionPolicy = SelectionPolicy + self.PREFER_DEFAULT = PREFER_DEFAULT + self.PREFER_VENDOR = PREFER_VENDOR + self.PREFER_REFERENCE = PREFER_REFERENCE + + def test_default_policy_creation(self): + """Test creating policy with default values""" + policy = self.SelectionPolicy.from_dict() + + self.assertEqual(policy.prefer, self.PREFER_DEFAULT) + self.assertFalse(policy.strict) + self.assertEqual(policy.per_op_order, ()) + self.assertEqual(policy.deny_vendors, frozenset()) + self.assertIsNone(policy.allow_vendors) + print(" [PASS] Default policy creation") + + def test_policy_with_prefer_vendor(self): + """Test creating policy with vendor preference""" + policy = self.SelectionPolicy.from_dict(prefer="vendor") + + self.assertEqual(policy.prefer, "vendor") + self.assertEqual(policy.get_default_order(), ["vendor", "flagos", "reference"]) + print(" [PASS] Policy with vendor preference") + + def test_policy_with_prefer_reference(self): + """Test creating policy with reference preference""" + policy = self.SelectionPolicy.from_dict(prefer="reference") + + self.assertEqual(policy.prefer, "reference") + self.assertEqual(policy.get_default_order(), ["reference", "flagos", "vendor"]) + print(" [PASS] Policy with reference preference") + + def test_policy_with_prefer_flagos(self): + """Test creating policy with flagos preference (default)""" + policy = self.SelectionPolicy.from_dict(prefer="flagos") + + self.assertEqual(policy.prefer, "flagos") + self.assertEqual(policy.get_default_order(), ["flagos", "vendor", "reference"]) + print(" [PASS] Policy with flagos preference") + + def test_invalid_prefer_value(self): + """Test that invalid prefer value raises error""" + with self.assertRaises(ValueError) as context: + self.SelectionPolicy.from_dict(prefer="invalid") + + self.assertIn("Invalid prefer value", str(context.exception)) + print(" [PASS] Invalid prefer value raises error") + + def test_strict_mode(self): + """Test strict mode setting""" + policy = self.SelectionPolicy.from_dict(strict=True) + + self.assertTrue(policy.strict) + print(" [PASS] Strict mode setting") + + def test_deny_vendors(self): + """Test deny vendors configuration""" + policy = self.SelectionPolicy.from_dict(deny_vendors={"rocm", "dcu"}) + + self.assertEqual(policy.deny_vendors, frozenset({"rocm", "dcu"})) + self.assertFalse(policy.is_vendor_allowed("rocm")) + self.assertFalse(policy.is_vendor_allowed("dcu")) + self.assertTrue(policy.is_vendor_allowed("cuda")) + print(" [PASS] Deny vendors configuration") + + def test_allow_vendors(self): + """Test allow vendors whitelist""" + policy = self.SelectionPolicy.from_dict(allow_vendors={"cuda"}) + + self.assertEqual(policy.allow_vendors, frozenset({"cuda"})) + self.assertTrue(policy.is_vendor_allowed("cuda")) + self.assertFalse(policy.is_vendor_allowed("rocm")) + print(" [PASS] Allow vendors whitelist") + + def test_deny_overrides_allow(self): + """Test that deny takes precedence over allow""" + policy = self.SelectionPolicy.from_dict( + allow_vendors={"cuda", "rocm"}, + deny_vendors={"rocm"}, + ) + + self.assertTrue(policy.is_vendor_allowed("cuda")) + self.assertFalse(policy.is_vendor_allowed("rocm")) + print(" [PASS] Deny overrides allow") + + def test_per_op_order(self): + """Test per-operator custom ordering""" + policy = self.SelectionPolicy.from_dict( + per_op_order={ + "layernorm_fwd": ["vendor", "flagos"], + "rmsnorm_fwd": ["flagos", "reference"], + } + ) + + self.assertEqual(policy.get_per_op_order("layernorm_fwd"), ["vendor", "flagos"]) + self.assertEqual(policy.get_per_op_order("rmsnorm_fwd"), ["flagos", "reference"]) + self.assertIsNone(policy.get_per_op_order("unknown_op")) + print(" [PASS] Per-operator custom ordering") + + def test_policy_fingerprint(self): + """Test policy fingerprint generation""" + policy1 = self.SelectionPolicy.from_dict(prefer="vendor", strict=True) + policy2 = self.SelectionPolicy.from_dict(prefer="vendor", strict=True) + policy3 = self.SelectionPolicy.from_dict(prefer="flagos", strict=True) + + self.assertEqual(policy1.fingerprint(), policy2.fingerprint()) + self.assertNotEqual(policy1.fingerprint(), policy3.fingerprint()) + print(" [PASS] Policy fingerprint generation") + + def test_policy_immutability(self): + """Test that SelectionPolicy is immutable (frozen dataclass)""" + policy = self.SelectionPolicy.from_dict(prefer="vendor") + + with self.assertRaises(AttributeError): + policy.prefer = "flagos" # Should fail - frozen dataclass + print(" [PASS] Policy immutability") + + def test_policy_hashable(self): + """Test that SelectionPolicy is hashable (can be used in sets/dicts)""" + policy1 = self.SelectionPolicy.from_dict(prefer="vendor") + policy2 = self.SelectionPolicy.from_dict(prefer="vendor") + + policy_set = {policy1, policy2} + self.assertEqual(len(policy_set), 1) # Same policy, should dedupe + print(" [PASS] Policy hashable") + + +class TestPolicyManager(unittest.TestCase): + """Test PolicyManager singleton and state management""" + + def setUp(self): + """Reset policy manager state before each test""" + from transformer_engine.plugin.core.policy import ( + PolicyManager, + reset_global_policy, + ) + + reset_global_policy() + self.PolicyManager = PolicyManager + + def tearDown(self): + """Clean up after each test""" + from transformer_engine.plugin.core.policy import reset_global_policy + + reset_global_policy() + # Clear any test environment variables + for key in [ + "TE_FL_PREFER", + "TE_FL_PREFER_VENDOR", + "TE_FL_STRICT", + "TE_FL_DENY_VENDORS", + "TE_FL_ALLOW_VENDORS", + "TE_FL_PER_OP", + ]: + os.environ.pop(key, None) + + def test_singleton_pattern(self): + """Test PolicyManager is a singleton""" + manager1 = self.PolicyManager.get_instance() + manager2 = self.PolicyManager.get_instance() + + self.assertIs(manager1, manager2) + print(" [PASS] PolicyManager singleton pattern") + + def test_policy_epoch(self): + """Test policy epoch tracking""" + from transformer_engine.plugin.core.policy import ( + get_policy_epoch, + bump_policy_epoch, + ) + + initial_epoch = get_policy_epoch() + new_epoch = bump_policy_epoch() + + self.assertEqual(new_epoch, initial_epoch + 1) + self.assertEqual(get_policy_epoch(), new_epoch) + print(" [PASS] Policy epoch tracking") + + def test_global_policy_set_and_get(self): + """Test setting and getting global policy""" + from transformer_engine.plugin.core.policy import ( + SelectionPolicy, + set_global_policy, + get_policy, + ) + + custom_policy = SelectionPolicy.from_dict(prefer="vendor", strict=True) + old_policy = set_global_policy(custom_policy) + + current = get_policy() + self.assertEqual(current.prefer, "vendor") + self.assertTrue(current.strict) + print(" [PASS] Global policy set and get") + + def test_reset_global_policy(self): + """Test resetting global policy to env defaults""" + from transformer_engine.plugin.core.policy import ( + SelectionPolicy, + set_global_policy, + reset_global_policy, + get_policy, + ) + + # Set custom policy + custom_policy = SelectionPolicy.from_dict(prefer="vendor") + set_global_policy(custom_policy) + + # Reset to defaults + reset_global_policy() + + current = get_policy() + self.assertEqual(current.prefer, "flagos") # Default + print(" [PASS] Reset global policy") + + +class TestEnvironmentVariables(unittest.TestCase): + """Test environment variable parsing""" + + def setUp(self): + """Clear environment and reset policy""" + from transformer_engine.plugin.core.policy import reset_global_policy + + reset_global_policy() + # Clear all test env vars + for key in [ + "TE_FL_PREFER", + "TE_FL_PREFER_VENDOR", + "TE_FL_STRICT", + "TE_FL_DENY_VENDORS", + "TE_FL_ALLOW_VENDORS", + "TE_FL_PER_OP", + ]: + os.environ.pop(key, None) + + def tearDown(self): + """Clean up environment""" + for key in [ + "TE_FL_PREFER", + "TE_FL_PREFER_VENDOR", + "TE_FL_STRICT", + "TE_FL_DENY_VENDORS", + "TE_FL_ALLOW_VENDORS", + "TE_FL_PER_OP", + ]: + os.environ.pop(key, None) + from transformer_engine.plugin.core.policy import reset_global_policy + + reset_global_policy() + + def test_te_fl_prefer_flagos(self): + """Test TE_FL_PREFER=flagos""" + os.environ["TE_FL_PREFER"] = "flagos" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.prefer, "flagos") + print(" [PASS] TE_FL_PREFER=flagos") + + def test_te_fl_prefer_vendor(self): + """Test TE_FL_PREFER=vendor""" + os.environ["TE_FL_PREFER"] = "vendor" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.prefer, "vendor") + print(" [PASS] TE_FL_PREFER=vendor") + + def test_te_fl_prefer_reference(self): + """Test TE_FL_PREFER=reference""" + os.environ["TE_FL_PREFER"] = "reference" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.prefer, "reference") + print(" [PASS] TE_FL_PREFER=reference") + + def test_te_fl_prefer_vendor_legacy(self): + """Test legacy TE_FL_PREFER_VENDOR=1""" + os.environ["TE_FL_PREFER_VENDOR"] = "1" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.prefer, "vendor") + print(" [PASS] TE_FL_PREFER_VENDOR=1 (legacy)") + + def test_te_fl_prefer_overrides_legacy(self): + """Test that TE_FL_PREFER takes precedence over TE_FL_PREFER_VENDOR""" + os.environ["TE_FL_PREFER"] = "reference" + os.environ["TE_FL_PREFER_VENDOR"] = "1" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.prefer, "reference") # TE_FL_PREFER wins + print(" [PASS] TE_FL_PREFER overrides TE_FL_PREFER_VENDOR") + + def test_te_fl_strict(self): + """Test TE_FL_STRICT=1""" + os.environ["TE_FL_STRICT"] = "1" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertTrue(policy.strict) + print(" [PASS] TE_FL_STRICT=1") + + def test_te_fl_deny_vendors(self): + """Test TE_FL_DENY_VENDORS parsing""" + os.environ["TE_FL_DENY_VENDORS"] = "rocm,dcu,intel" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.deny_vendors, frozenset({"rocm", "dcu", "intel"})) + print(" [PASS] TE_FL_DENY_VENDORS parsing") + + def test_te_fl_allow_vendors(self): + """Test TE_FL_ALLOW_VENDORS parsing""" + os.environ["TE_FL_ALLOW_VENDORS"] = "cuda,rocm" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.allow_vendors, frozenset({"cuda", "rocm"})) + print(" [PASS] TE_FL_ALLOW_VENDORS parsing") + + def test_te_fl_per_op(self): + """Test TE_FL_PER_OP parsing""" + os.environ["TE_FL_PER_OP"] = "layernorm_fwd=vendor|flagos;rmsnorm_fwd=flagos|reference" + + from transformer_engine.plugin.core.policy import policy_from_env + + policy = policy_from_env() + + self.assertEqual(policy.get_per_op_order("layernorm_fwd"), ["vendor", "flagos"]) + self.assertEqual(policy.get_per_op_order("rmsnorm_fwd"), ["flagos", "reference"]) + print(" [PASS] TE_FL_PER_OP parsing") + + +class TestContextManagers(unittest.TestCase): + """Test policy context managers""" + + def setUp(self): + """Reset policy before each test""" + from transformer_engine.plugin.core.policy import reset_global_policy + + reset_global_policy() + + def tearDown(self): + """Clean up after test""" + from transformer_engine.plugin.core.policy import reset_global_policy + + reset_global_policy() + + def test_policy_context(self): + """Test basic policy_context usage""" + from transformer_engine.plugin.core.policy import ( + SelectionPolicy, + policy_context, + get_policy, + ) + + original = get_policy() + custom = SelectionPolicy.from_dict(prefer="vendor", strict=True) + + with policy_context(custom): + inside = get_policy() + self.assertEqual(inside.prefer, "vendor") + self.assertTrue(inside.strict) + + after = get_policy() + self.assertEqual(after.prefer, original.prefer) + print(" [PASS] policy_context usage") + + def test_with_preference(self): + """Test with_preference context manager""" + from transformer_engine.plugin.core.policy import ( + with_preference, + get_policy, + ) + + original = get_policy() + + with with_preference("vendor"): + self.assertEqual(get_policy().prefer, "vendor") + + with with_preference("reference"): + self.assertEqual(get_policy().prefer, "reference") + + self.assertEqual(get_policy().prefer, original.prefer) + print(" [PASS] with_preference context manager") + + def test_with_strict_mode(self): + """Test with_strict_mode context manager""" + from transformer_engine.plugin.core.policy import ( + with_strict_mode, + get_policy, + ) + + original = get_policy() + + with with_strict_mode(): + self.assertTrue(get_policy().strict) + + self.assertEqual(get_policy().strict, original.strict) + print(" [PASS] with_strict_mode context manager") + + def test_with_allowed_vendors(self): + """Test with_allowed_vendors context manager""" + from transformer_engine.plugin.core.policy import ( + with_allowed_vendors, + get_policy, + ) + + with with_allowed_vendors("cuda", "rocm"): + policy = get_policy() + self.assertEqual(policy.allow_vendors, frozenset({"cuda", "rocm"})) + + self.assertIsNone(get_policy().allow_vendors) + print(" [PASS] with_allowed_vendors context manager") + + def test_with_denied_vendors(self): + """Test with_denied_vendors context manager""" + from transformer_engine.plugin.core.policy import ( + with_denied_vendors, + get_policy, + ) + + with with_denied_vendors("rocm", "dcu"): + policy = get_policy() + self.assertIn("rocm", policy.deny_vendors) + self.assertIn("dcu", policy.deny_vendors) + + self.assertEqual(get_policy().deny_vendors, frozenset()) + print(" [PASS] with_denied_vendors context manager") + + def test_nested_contexts(self): + """Test nested context managers""" + from transformer_engine.plugin.core.policy import ( + with_preference, + with_strict_mode, + get_policy, + ) + + with with_preference("vendor"): + self.assertEqual(get_policy().prefer, "vendor") + + with with_strict_mode(): + policy = get_policy() + # Note: with_strict_mode creates new policy with current prefer + self.assertTrue(policy.strict) + + # Back to vendor preference, not strict + self.assertEqual(get_policy().prefer, "vendor") + + # Back to default + self.assertEqual(get_policy().prefer, "flagos") + print(" [PASS] Nested context managers") + + +class TestTokenMatching(unittest.TestCase): + """Test token matching for implementation selection""" + + def test_match_flagos_token(self): + """Test matching 'flagos' token""" + from transformer_engine.plugin.core.types import OpImpl, BackendImplKind, match_token + + impl = OpImpl( + op_name="test_op", + impl_id="test.flagos", + kind=BackendImplKind.DEFAULT, + fn=lambda: None, + ) + + self.assertTrue(match_token(impl, "flagos")) + self.assertFalse(match_token(impl, "vendor")) + self.assertFalse(match_token(impl, "reference")) + print(" [PASS] Match flagos token") + + def test_match_vendor_token(self): + """Test matching 'vendor' token""" + from transformer_engine.plugin.core.types import OpImpl, BackendImplKind, match_token + + impl = OpImpl( + op_name="test_op", + impl_id="test.cuda", + kind=BackendImplKind.VENDOR, + fn=lambda: None, + vendor="cuda", + ) + + self.assertTrue(match_token(impl, "vendor")) + self.assertFalse(match_token(impl, "flagos")) + print(" [PASS] Match vendor token") + + def test_match_specific_vendor_token(self): + """Test matching 'vendor:' token""" + from transformer_engine.plugin.core.types import OpImpl, BackendImplKind, match_token + + impl = OpImpl( + op_name="test_op", + impl_id="test.cuda", + kind=BackendImplKind.VENDOR, + fn=lambda: None, + vendor="cuda", + ) + + self.assertTrue(match_token(impl, "vendor:cuda")) + self.assertFalse(match_token(impl, "vendor:rocm")) + print(" [PASS] Match specific vendor token") + + def test_match_impl_token(self): + """Test matching 'impl:' token""" + from transformer_engine.plugin.core.types import OpImpl, BackendImplKind, match_token + + impl = OpImpl( + op_name="test_op", + impl_id="layernorm_cuda_v2", + kind=BackendImplKind.VENDOR, + fn=lambda: None, + vendor="cuda", + ) + + self.assertTrue(match_token(impl, "impl:layernorm_cuda_v2")) + self.assertFalse(match_token(impl, "impl:other_impl")) + print(" [PASS] Match impl token") + + def test_match_reference_token(self): + """Test matching 'reference' token""" + from transformer_engine.plugin.core.types import OpImpl, BackendImplKind, match_token + + impl = OpImpl( + op_name="test_op", + impl_id="test.reference", + kind=BackendImplKind.REFERENCE, + fn=lambda: None, + ) + + self.assertTrue(match_token(impl, "reference")) + self.assertFalse(match_token(impl, "flagos")) + self.assertFalse(match_token(impl, "vendor")) + print(" [PASS] Match reference token") + + +class TestThreadSafety(unittest.TestCase): + """Test thread safety of PolicyManager""" + + def test_concurrent_policy_access(self): + """Test concurrent access to policy""" + from transformer_engine.plugin.core.policy import ( + SelectionPolicy, + set_global_policy, + get_policy, + reset_global_policy, + ) + + reset_global_policy() + errors = [] + results = [] + + def worker(prefer_value: str, worker_id: int): + try: + for _ in range(100): + policy = SelectionPolicy.from_dict(prefer=prefer_value) + set_global_policy(policy) + current = get_policy() + # Policy should be one of the valid values + if current.prefer not in ["flagos", "vendor", "reference"]: + errors.append(f"Worker {worker_id}: Invalid prefer value {current.prefer}") + results.append(worker_id) + except Exception as e: + errors.append(f"Worker {worker_id}: {e}") + + threads = [ + threading.Thread(target=worker, args=("flagos", 0)), + threading.Thread(target=worker, args=("vendor", 1)), + threading.Thread(target=worker, args=("reference", 2)), + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(errors), 0, f"Errors: {errors}") + self.assertEqual(len(results), 3) + print(" [PASS] Concurrent policy access") + + def test_policy_epoch_increment(self): + """Test that policy epoch increments correctly under contention""" + from transformer_engine.plugin.core.policy import ( + get_policy_epoch, + bump_policy_epoch, + ) + + initial_epoch = get_policy_epoch() + increments = 100 + threads_count = 4 + + def bump_epochs(): + for _ in range(increments): + bump_policy_epoch() + + threads = [threading.Thread(target=bump_epochs) for _ in range(threads_count)] + + for t in threads: + t.start() + for t in threads: + t.join() + + final_epoch = get_policy_epoch() + expected = initial_epoch + (increments * threads_count) + + self.assertEqual(final_epoch, expected) + print(" [PASS] Policy epoch increment under contention") + + +class TestDefaultOrder(unittest.TestCase): + """Test default selection order based on preference""" + + def test_flagos_preference_order(self): + """Test selection order with flagos preference""" + from transformer_engine.plugin.core.policy import SelectionPolicy + + policy = SelectionPolicy.from_dict(prefer="flagos") + order = policy.get_default_order() + + self.assertEqual(order, ["flagos", "vendor", "reference"]) + print(" [PASS] Flagos preference order") + + def test_vendor_preference_order(self): + """Test selection order with vendor preference""" + from transformer_engine.plugin.core.policy import SelectionPolicy + + policy = SelectionPolicy.from_dict(prefer="vendor") + order = policy.get_default_order() + + self.assertEqual(order, ["vendor", "flagos", "reference"]) + print(" [PASS] Vendor preference order") + + def test_reference_preference_order(self): + """Test selection order with reference preference""" + from transformer_engine.plugin.core.policy import SelectionPolicy + + policy = SelectionPolicy.from_dict(prefer="reference") + order = policy.get_default_order() + + self.assertEqual(order, ["reference", "flagos", "vendor"]) + print(" [PASS] Reference preference order") + + +def run_all_tests(): + """Run all policy tests""" + print("\n" + "=" * 60) + print("TE-FL Scheduling Policy Test Suite") + print("=" * 60) + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestSelectionPolicy, + TestPolicyManager, + TestEnvironmentVariables, + TestContextManagers, + TestTokenMatching, + TestThreadSafety, + TestDefaultOrder, + ] + + for test_class in test_classes: + print(f"\n[Testing {test_class.__name__}]") + tests = loader.loadTestsFromTestCase(test_class) + for test in tests: + result = unittest.TestResult() + test.run(result) + if result.wasSuccessful(): + pass # Print statements are in individual tests + else: + for failure in result.failures + result.errors: + print(f" [FAIL] {test}: {failure[1]}") + suite.addTests(tests) + + # Run the full suite for final summary + print("\n" + "=" * 60) + print("Final Summary") + print("=" * 60) + + runner = unittest.TextTestRunner(verbosity=0) + result = runner.run(suite) + + total = result.testsRun + failures = len(result.failures) + errors = len(result.errors) + passed = total - failures - errors + + print(f"\nTotal: {total}, Passed: {passed}, Failed: {failures}, Errors: {errors}") + + return failures == 0 and errors == 0 + + +def main(): + """Main entry point""" + success = run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_softmax.py b/transformer_engine/plugin/tests/test_softmax.py new file mode 100644 index 0000000000..8bdf29dcc3 --- /dev/null +++ b/transformer_engine/plugin/tests/test_softmax.py @@ -0,0 +1,387 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +import torch +import torch.nn.functional as F + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) + + +class SoftmaxTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Softmax Operations", "Test correctness of all softmax operations across backends" + ) + self.backends = get_available_backends() + self.device = device + + def test_scaled_softmax_forward(self, shape=(2, 4, 8, 16)): + print(f"\n Testing scaled softmax forward with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + scale = 0.125 + reference = F.softmax(x.float() * scale, dim=-1).to(x.dtype) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_softmax_forward(x, scale) + self.assert_close( + output, + reference, + rtol=1e-2, + atol=1e-3, + msg=f"Scaled softmax forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_softmax_backward(self, shape=(2, 4, 8, 16)): + print(f"\n Testing scaled softmax backward with shape {shape}") + + # Use bf16 for all computation to match backend precision + x = generate_random_tensor( + shape, dtype=torch.bfloat16, device=self.device, requires_grad=True + ) + scale = 0.125 + grad_output = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + + # Compute reference gradient using autograd (in float32 for precision, then convert) + x_f32 = x.float().detach().requires_grad_(True) + softmax_output_f32 = F.softmax(x_f32 * scale, dim=-1) + loss = (softmax_output_f32 * grad_output.float()).sum() + loss.backward() + reference_grad = x_f32.grad.clone() + + # Get softmax output in bf16 for backend + softmax_out_test = softmax_output_f32.detach().to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Clone inputs as some backends may modify them in-place + grad_input = backend.scaled_softmax_backward( + grad_output.clone(), softmax_out_test.clone(), scale + ) + self.assert_close( + grad_input.float(), + reference_grad, + rtol=1e-2, + atol=1e-2, + msg=f"Scaled softmax backward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_masked_softmax_forward(self, shape=(2, 4, 8, 16)): + print(f"\n Testing scaled masked softmax forward with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.float32, device=self.device) + scale = 0.125 + + # Create boolean mask and corresponding masks + batch = shape[0] + seq_q, seq_k = shape[-2], shape[-1] + bool_mask = torch.rand((batch, 1, seq_q, seq_k), device=self.device) > 0.5 + + # CUDA uses uint8 mask (1=masked, 0=unmasked) + uint8_mask = bool_mask.to(torch.uint8) + + # Additive mask for reference computation + additive_mask = torch.zeros((batch, 1, seq_q, seq_k), dtype=x.dtype, device=self.device) + additive_mask = additive_mask.masked_fill(bool_mask, float("-inf")) + additive_mask_expanded = additive_mask.expand(shape) + + # Reference: F.softmax(x * scale + additive_mask, dim=-1) + reference = F.softmax(x * scale + additive_mask_expanded, dim=-1) + + # Use bf16 for all backends + x_test = x.to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_masked_softmax_forward(x_test, uint8_mask, scale) + self.assert_close( + output.float(), + reference.float(), + rtol=1e-2, + atol=1e-3, + msg=f"Scaled masked softmax forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_masked_softmax_backward(self, shape=(2, 4, 8, 16)): + print(f"\n Testing scaled masked softmax backward with shape {shape}") + + # Use bf16 for all computation + x = generate_random_tensor( + shape, dtype=torch.bfloat16, device=self.device, requires_grad=True + ) + scale = 0.125 + grad_output = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + + # Compute reference gradient using autograd (in float32 for precision) + x_f32 = x.float().detach().requires_grad_(True) + softmax_output_f32 = F.softmax(x_f32 * scale, dim=-1) + loss = (softmax_output_f32 * grad_output.float()).sum() + loss.backward() + reference_grad = x_f32.grad.clone() + + # Get softmax output in bf16 for backend + softmax_out_test = softmax_output_f32.detach().to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Clone inputs as some backends may modify them in-place + grad_input = backend.scaled_masked_softmax_backward( + grad_output.clone(), softmax_out_test.clone(), scale + ) + self.assert_close( + grad_input.float(), + reference_grad, + rtol=1e-2, + atol=1e-2, + msg=f"Scaled masked softmax backward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_upper_triang_masked_softmax_forward(self, shape=(8, 16, 16)): + print(f"\n Testing scaled upper triang masked softmax forward with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + scale = 0.125 + seq_len = shape[-1] + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=x.dtype, device=self.device), + diagonal=1, + ) + reference = F.softmax(x.float() * scale + causal_mask.float(), dim=-1).to(x.dtype) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_upper_triang_masked_softmax_forward(x, scale) + self.assert_close( + output, + reference, + rtol=1e-2, + atol=1e-3, + msg=f"Scaled upper triang masked softmax forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_upper_triang_masked_softmax_backward(self, shape=(8, 16, 16)): + print(f"\n Testing scaled upper triang masked softmax backward with shape {shape}") + + # Use bf16 for all computation + x = generate_random_tensor( + shape, dtype=torch.bfloat16, device=self.device, requires_grad=True + ) + scale = 0.125 + seq_len = shape[-1] + grad_output = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=torch.float32, device=self.device), + diagonal=1, + ) + + # Compute reference gradient using autograd (in float32 for precision) + x_f32 = x.float().detach().requires_grad_(True) + softmax_output_f32 = F.softmax(x_f32 * scale + causal_mask, dim=-1) + loss = (softmax_output_f32 * grad_output.float()).sum() + loss.backward() + reference_grad = x_f32.grad.clone() + + # Get softmax output in bf16 for backend + softmax_out_test = softmax_output_f32.detach().to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Clone inputs as some backends may modify them in-place + grad_input = backend.scaled_upper_triang_masked_softmax_backward( + grad_output.clone(), softmax_out_test.clone(), scale + ) + self.assert_close( + grad_input.float(), + reference_grad, + rtol=1e-2, + atol=1e-2, + msg=f"Scaled upper triang masked softmax backward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_aligned_causal_masked_softmax_forward(self, shape=(2, 4, 16, 16)): + """Test scaled aligned causal masked softmax forward. + + Note: CUDA backend requires 4D tensor (batch, heads, seq, seq). + """ + print(f"\n Testing scaled aligned causal masked softmax forward with shape {shape}") + + x = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + scale = 0.125 + seq_len = shape[-1] + + # Aligned causal mask (lower triangular) + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=x.dtype, device=self.device), + diagonal=1, + ) + reference = F.softmax(x.float() * scale + causal_mask.float(), dim=-1).to(x.dtype) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + output = backend.scaled_aligned_causal_masked_softmax_forward(x, scale) + self.assert_close( + output, + reference, + rtol=1e-2, + atol=1e-3, + msg=f"Scaled aligned causal masked softmax forward mismatch for {backend_name}", + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def test_scaled_aligned_causal_masked_softmax_backward(self, shape=(2, 4, 16, 16)): + """Test scaled aligned causal masked softmax backward. + + Note: All backends use bf16 for consistency. + """ + print(f"\n Testing scaled aligned causal masked softmax backward with shape {shape}") + + # Use bf16 for all computation + x = generate_random_tensor( + shape, dtype=torch.bfloat16, device=self.device, requires_grad=True + ) + scale = 0.125 + seq_len = shape[-1] + grad_output = generate_random_tensor(shape, dtype=torch.bfloat16, device=self.device) + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=torch.float32, device=self.device), + diagonal=1, + ) + + # Compute reference gradient using autograd (in float32 for precision) + x_f32 = x.float().detach().requires_grad_(True) + softmax_output_f32 = F.softmax(x_f32 * scale + causal_mask, dim=-1) + loss = (softmax_output_f32 * grad_output.float()).sum() + loss.backward() + reference_grad = x_f32.grad.clone() + + # Get softmax output in bf16 for backend + softmax_out_test = softmax_output_f32.detach().to(torch.bfloat16) + + for backend_name in self.backends: + backend = get_backend(backend_name) + try: + # Clone inputs as some backends may modify them in-place + grad_input = backend.scaled_aligned_causal_masked_softmax_backward( + grad_output.clone(), softmax_out_test.clone(), scale + ) + self.assert_close( + grad_input.float(), + reference_grad, + rtol=1e-2, + atol=1e-2, + msg=( + f"Scaled aligned causal masked softmax backward mismatch for {backend_name}" + ), + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ {backend_name}: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("Testing Softmax Operations") + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + # Scaled softmax tests + self.test_scaled_softmax_forward((4, 8, 16, 16)) + self.test_scaled_softmax_forward((2, 4, 32, 32)) + self.test_scaled_softmax_backward((4, 8, 16, 16)) + self.test_scaled_softmax_backward((2, 4, 32, 32)) + + # Masked softmax tests + self.test_scaled_masked_softmax_forward((4, 8, 16, 16)) + self.test_scaled_masked_softmax_backward((4, 8, 16, 16)) + + # Upper triangular (causal) masked softmax tests + self.test_scaled_upper_triang_masked_softmax_forward((16, 32, 32)) + self.test_scaled_upper_triang_masked_softmax_forward((8, 64, 64)) + self.test_scaled_upper_triang_masked_softmax_backward((16, 32, 32)) + self.test_scaled_upper_triang_masked_softmax_backward((8, 64, 64)) + + # Aligned causal masked softmax tests (4D tensor required by CUDA) + self.test_scaled_aligned_causal_masked_softmax_forward((2, 4, 32, 32)) + self.test_scaled_aligned_causal_masked_softmax_backward((2, 4, 32, 32)) + + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = SoftmaxTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/plugin/tests/test_te_general_grouped.py b/transformer_engine/plugin/tests/test_te_general_grouped.py new file mode 100644 index 0000000000..1bc815cc8b --- /dev/null +++ b/transformer_engine/plugin/tests/test_te_general_grouped.py @@ -0,0 +1,169 @@ +import torch + +from transformer_engine.plugin.test_utils import ( + get_available_backends, + get_backend, + TestCase, + generate_random_tensor, +) + + +class grouped_gemmTests(TestCase): + def __init__(self, device="cpu"): + super().__init__( + "Moe permute Operations", + "Test correctness of all moe permute operations across backends", + ) + self.backends = get_available_backends() + self.device = device + + def test_grouped_gemm_equivalence(self, grad, has_bias, has_pre_gelu, single_output): + print( + "\n test te_general_grouped_gemm" + f" grad:{grad} has_bias:{has_bias},has_pre_gelu:{has_pre_gelu},single_output:{single_output}" + ) + import transformer_engine_torch_nv as tex + + num_gemms = 2 + m, k, n = 128, 32, 64 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 + + if dtype == torch.float16: + te_dtype = tex.DType.kFloat16 + elif dtype == torch.float32: + te_dtype = tex.DType.kFloat32 + elif dtype == torch.bfloat16: + te_dtype = tex.DType.kBFloat16 + else: + raise ValueError(f"不支持的 dtype: {torch_dtype}") + + torch.manual_seed(42) + + A_list = [torch.randn((k, n), device=device, dtype=dtype) for _ in range(num_gemms)] + B_list = [torch.randn((m, k), device=device, dtype=dtype) for _ in range(num_gemms)] + + bias_list_py_bias = [ + ( + torch.randn(n, device=device, dtype=dtype) + if has_bias + else torch.empty(0, device=device, dtype=dtype) + ) + for _ in range(num_gemms) + ] + bias_list_te = [b.clone() for b in bias_list_py_bias] + + pre_gelu_list_py = [ + ( + torch.randn(m, n, device=device, dtype=dtype) + if has_pre_gelu + else torch.empty(0, device=device, dtype=dtype) + ) + for _ in range(num_gemms) + ] + pre_gelu_list_te = [p.clone() for p in pre_gelu_list_py] + + if single_output: + D_list_py = [torch.empty(m * num_gemms, n, device=device, dtype=dtype)] + D_list_te = [torch.empty(m * num_gemms, n, device=device, dtype=dtype)] + else: + D_list_py = [torch.empty(m, n, device=device, dtype=dtype) for _ in range(num_gemms)] + D_list_te = [torch.empty(m, n, device=device, dtype=dtype) for _ in range(num_gemms)] + workspace_py = [torch.empty(1024 * 1024, device=device, dtype=torch.uint8)] + workspace_te = [torch.empty(1024 * 1024, device=device, dtype=torch.uint8)] + + tex.te_general_grouped_gemm( + A_list, + False, + B_list, + False, + D_list_te, + te_dtype, + [], + bias_list_te, + te_dtype, + single_output, + pre_gelu_list_te, + grad, + workspace_te, + 1024 * 1024, + False, + False, + 0, + ) + + for backend_name in self.backends: + backend = get_backend(backend_name) + print("backend:", backend) + try: + bias_list_py = [b.clone() for b in bias_list_py_bias] + backend.te_general_grouped_gemm( + A_list, + False, + B_list, + False, + D_list_py, + te_dtype, + [], + bias_list_py, + te_dtype, + single_output, + pre_gelu_list_py, + grad, + workspace_py, + 1024 * 1024, + False, + False, + 0, + ) + + for py_d, te_d in zip(D_list_py, D_list_te): + self.assert_close( + py_d, te_d, rtol=1e-3, atol=1e-3, msg="Output D tensors mismatch!" + ) + + if not grad and has_pre_gelu: + for py_p, te_p in zip(pre_gelu_list_py, pre_gelu_list_te): + self.assert_close( + py_p, te_p, rtol=1e-3, atol=1e-3, msg="Pre-GELU out tensors mismatch!" + ) + + if grad or has_bias: + for py_b, te_b in zip(bias_list_py, bias_list_te): + self.assert_close( + py_b, te_b, rtol=1e-3, atol=1e-3, msg="Bias gradient tensors mismatch!" + ) + print(f" ✓ {backend_name}") + except NotImplementedError: + self.skipped += 1 + print(f" ⊘ {backend_name} (not implemented)") + except Exception as e: + self.failed += 1 + print(f" ✗ Test failed: {e}") + + def run_all_tests(self): + print("\n" + "=" * 60) + print("=" * 60) + print(f"Available backends: {', '.join(self.backends)}") + + # gemm tests + self.test_grouped_gemm_equivalence(False, False, False, False) + self.test_grouped_gemm_equivalence(False, True, False, False) + self.test_grouped_gemm_equivalence(False, False, True, False) + + self.test_grouped_gemm_equivalence(False, False, False, True) + self.test_grouped_gemm_equivalence(False, True, False, True) + self.test_grouped_gemm_equivalence(False, False, True, True) + return self.report() + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + test_suite = grouped_gemmTests(device=device) + success = test_suite.run_all_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index cd18ca75ad..df83faf9ac 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -15,7 +15,8 @@ assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." -load_framework_extension("torch") + +load_framework_extension("torch_nv") from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..c0eac9a88d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -15,6 +15,7 @@ import torch import torch.nn.functional as F import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.utils import ( get_device_compute_capability, split_tensor_along_dim, @@ -456,10 +457,10 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=S_quantizer.dtype, device="cuda" + fp8_dtype=S_quantizer.dtype, device=te_device_type() ) dP_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=dP_quantizer.dtype, device="cuda" + fp8_dtype=dP_quantizer.dtype, device=te_device_type() ) if "2" in qkv_layout or "3" in qkv_layout: @@ -750,8 +751,10 @@ def forward( for x in [query_layer, key_layer, value_layer] ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), "FlashAttention currently only supports CUDA tensors." + query_layer.device.type == te_device_type() + and key_layer.device.type == te_device_type() + and value_layer.device.type == te_device_type() + ), f"FlashAttention currently only supports {te_device_type()} tensors." assert ( qkv_layout in QKVLayouts ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" @@ -1821,8 +1824,10 @@ def forward( for x in [query_layer, key_layer, value_layer] ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), "FusedAttention only supports CUDA tensors." + query_layer.device.type == te_device_type() + and key_layer.device.type == te_device_type() + and value_layer.device.type == te_device_type() + ), f"FusedAttention only supports {te_device_type()} tensors." assert ( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 030b1d9cdc..7db4e54530 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1191,6 +1191,9 @@ def cp_p2p_bwd_flash_attn( dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) + # Fix: flash-attn 2.3.x ~ 2.6.x also needs rng_state for dropout + if not use_flash_attn_3 and rng_states is not None: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - step - 1] elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..2218fc7ba2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -14,6 +14,7 @@ from torch.nn.parameter import Parameter import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.common.recipe import ( Format, Recipe, @@ -61,6 +62,18 @@ FlashAttention, ) +######################################################################### +# Save reference to native FlashAttention for fallback +_FlashAttentionNative = FlashAttention +# Use plugin system's flash_attention if available, otherwise use native +FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative) +# Save the original get_attention_backend for backends that want to use default logic +# CUDA backend can access this via dpa_utils._original_get_attention_backend +dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend +# Replace dpa_utils.get_attention_backend with tex.get_attention_backend +# This allows each backend (FlagOS, CUDA, Reference) to control its own backend selection +dpa_utils.get_attention_backend = tex.get_attention_backend +######################################################################### # Setup Attention Logging attn_log.setup_logging() @@ -434,12 +447,14 @@ def __init__( self.softmax_offset = None if self.softmax_type == "off-by-one": self.softmax_offset = torch.zeros( - self.num_attention_heads // self.tp_size, device="cuda" + self.num_attention_heads // self.tp_size, device=te_device_type() ) if self.softmax_type == "learnable": self.register_parameter( "softmax_offset", - Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")), + Parameter( + torch.zeros(self.num_attention_heads // self.tp_size, device=te_device_type()) + ), get_rng_state_tracker=get_rng_state_tracker, ) @@ -1057,8 +1072,10 @@ def forward( # checks for q/k/v shapes assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), "DotProductAttention only supports CUDA tensors." + query_layer.device.type == te_device_type() + and key_layer.device.type == te_device_type() + and value_layer.device.type == te_device_type() + ), f"DotProductAttention only supports {te_device_type()} tensors." assert ( query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype ), "Queries, keys and values must have the same data type!" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py index 74d9583ce5..5ccc63cad5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py @@ -8,6 +8,7 @@ import torch from torch import nn import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -24,7 +25,7 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: def _get_mask(): diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 return torch.triu( - torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset + torch.ones(sq, sk, dtype=torch.bool, device=te_device_type()), diagonal=diagonal_offset ) if is_in_onnx_export_mode(): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..1ac7319b39 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -21,6 +21,7 @@ import torch.nn.functional as F import transformer_engine_torch as tex import transformer_engine as te +from transformer_engine import te_device_type from transformer_engine.pytorch.cpp_extensions.fused_attn import ( QKVLayout, AttnBiasType, @@ -539,6 +540,9 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_logit") + if use_fused_attention and qkv_format == "thd": + use_fused_attention = False + logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd") if fp8 and fp8_meta["recipe"].fp8_dpa: use_flash_attention = False use_fused_attention = False @@ -1103,6 +1107,10 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None + if is_training and device_compute_capability >= (10, 0): + logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") + use_fused_attention = False + fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 @@ -1243,13 +1251,13 @@ def get_padding_mask( ], dim=0, ) - attention_mask_q = attention_mask_q.to(device="cuda") + attention_mask_q = attention_mask_q.to(device=te_device_type()) if attention_type == "self": attention_mask = attention_mask_q else: attention_mask = ( attention_mask_q, - attention_mask_kv.to(device="cuda"), + attention_mask_kv.to(device=te_device_type()), ) return attention_mask @@ -1368,9 +1376,11 @@ def get_full_mask( actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) # apply SWA mask - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device=te_device_type()).view( 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=te_device_type()).view( + 1, 1, 1, max_seqlen_kv + ) swa_left = None swa_right = None if attn_mask_type == "causal_bottom_right" or ( @@ -1466,7 +1476,7 @@ def get_alibi( m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) m = torch.cat([m, m_hat]) - _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") + _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device=te_device_type()) _alibi_cache["_num_heads"] = num_heads _alibi_cache["_alibi_slopes_require_update"] = False @@ -1479,9 +1489,9 @@ def get_alibi( else: raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") - bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device=te_device_type()).view( 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=te_device_type()).view( 1, 1, 1, max_seqlen_kv ) if actual_seqlens_q is None and actual_seqlens_kv is None: @@ -1501,7 +1511,9 @@ def get_alibi( _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment bias_dtype = torch.float32 if bias_dtype is None else bias_dtype - _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") + _alibi_cache["_alibi_bias"] = bias.contiguous().to( + dtype=bias_dtype, device=te_device_type() + ) _alibi_cache["_alibi_bias_require_update"] = False return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"] @@ -1516,7 +1528,7 @@ def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: mask = mask.squeeze(1).squeeze(1) reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") + zero = torch.zeros(1, dtype=torch.int32, device=te_device_type()) cu_seqlens = torch.cat((zero, cu_seqlens)) return cu_seqlens @@ -1534,7 +1546,7 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch. reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") + zero = torch.zeros(1, dtype=torch.int32, device=te_device_type()) cu_seqlens = torch.cat((zero, cu_seqlens)) mask = mask.reshape(-1) @@ -1559,7 +1571,12 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: bs = len(cu_seqlens) - 1 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] - indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") + indices = ( + torch.Tensor(indices) + .unsqueeze(1) + .unsqueeze(1) + .to(dtype=torch.int64, device=te_device_type()) + ) num_nonzeros = indices.shape[0] pad_amount = bs * max_seqlen - num_nonzeros diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 08e50aad8b..c97280dbce 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -11,6 +11,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat __all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"] @@ -626,7 +627,7 @@ def __init__( self.allocated_pages = defaultdict(list) # page table, [batch_size, max_pages_per_seq] self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device=te_device_type() ) def reset(self): diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..5864f7eff0 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,6 +8,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch +from transformer_engine import te_device_type from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -277,7 +278,7 @@ def __init__( ub_bulk_wgrad: bool = False, bias: bool = True, normalization: str = "LayerNorm", - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = te_device_type(), qkv_format: str = "sbhd", name: str = None, qk_norm_type: Optional[str] = None, diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 77ad57ed8f..3aabfeaeff 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -9,6 +9,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat @@ -76,7 +77,7 @@ def forward(self, max_seq_len: int, offset: int = 0): offset: int, default = 0 Fixed offset for frequencies. """ - with torch.autocast(enabled=False, device_type="cuda"): + with torch.autocast(enabled=False, device_type=te_device_type()): seq = ( torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + offset diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 115569ccba..1c1e17737c 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -10,6 +10,9 @@ import functools import torch import transformer_engine_torch as tex + +from transformer_engine import te_device_type + from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor @@ -219,7 +222,8 @@ def general_grouped_gemm( if grad and use_bias: grad_bias = [ - torch.empty(B[i].size(1), dtype=out[0].dtype, device="cuda") for i in range(num_gemms) + torch.empty(B[i].size(1), dtype=out[0].dtype, device=te_device_type()) + for i in range(num_gemms) ] else: grad_bias = empty_tensors diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index d0b314a64f..f1d1de64b1 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,6 +16,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState import transformer_engine.pytorch as te import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path +from transformer_engine import te_device_type from .quantized_tensor import ( restore_from_saved, prepare_for_saving, @@ -345,7 +346,7 @@ def start_reload(self): # cannot move tensors from pool of one stream to another without # calling cudaFree and cudaMalloc again. - reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) + reloaded_tensor = torch.empty_like(tensor, device=torch.device(te_device_type())) self.offload_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.offload_stream): @@ -461,7 +462,7 @@ def _check_if_offload(self, t: torch.Tensor) -> bool: not isinstance(t, torch.nn.Parameter) and not getattr(t, "_TE_do_not_offload", False) and not isinstance(t, torch._subclasses.FakeTensor) - and t.device.type == "cuda" + and t.device.type == te_device_type() ): if not t.is_contiguous() and not getattr(t, "offload_base_tensor", False): warnings.warn( diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 8580cf4a33..c11c0e34fa 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -10,6 +10,7 @@ import torch +from transformer_engine import te_device_type from transformer_engine.pytorch.custom_recipes import quantization from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer @@ -498,7 +499,7 @@ def make_empty( # Canonicalize tensor attributes if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) # Allocate quantized data qx = torch.empty(shape, dtype=self.dtype, device=device) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index b80e58fe20..6e236e4d5e 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -29,6 +29,8 @@ import transformer_engine_torch as tex +from transformer_engine import te_device_type + from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from .torch_version import torch_version from .utils import ( @@ -96,7 +98,7 @@ def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool def _get_cuda_rng_state( - device: Union[int, str, torch.device] = "cuda", + device: Union[int, str, torch.device] = te_device_type(), clone: bool = False, graph_safe: bool = True, ) -> torch.Tensor: @@ -106,7 +108,7 @@ def _get_cuda_rng_state( if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device("cuda", device) + device = torch.device(te_device_type(), device) idx = device.index if idx is None: idx = torch.cuda.current_device() @@ -128,11 +130,11 @@ def _set_cuda_rng_state( """Sets the random number generator state of the current GPU.""" if device == -1: - device = torch.device("cuda") + device = torch.device(te_device_type()) elif isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device("cuda", device) + device = torch.device(te_device_type(), device) def cb() -> None: idx = device.index @@ -294,10 +296,10 @@ def _get_active_autocast_contexts(): autocast_cached = torch.is_autocast_cache_enabled() if torch_version() >= (2, 4, 0): - gpu_autocast_enabled = torch.is_autocast_enabled("cuda") - gpu_autocast_dtype = torch.get_autocast_dtype("cuda") + gpu_autocast_enabled = torch.is_autocast_enabled(te_device_type()) + gpu_autocast_dtype = torch.get_autocast_dtype(te_device_type()) gpu_autocast_ctx = torch.amp.autocast( - "cuda", + te_device_type(), enabled=gpu_autocast_enabled, dtype=gpu_autocast_dtype, cache_enabled=autocast_cached, @@ -1021,7 +1023,7 @@ def _all_gather_fp8( out: Float8TensorStorage if quantizer is not None: dtype = torch.float32 - device = "cuda" + device = te_device_type() if isinstance(inp, Float8Tensor): dtype = inp.dtype device = inp.device diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 1b93b8254c..0cca36e0db 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -3,11 +3,15 @@ # See LICENSE for license information. """NVFuser functions and JIT utilities""" + +# pylint: disable=ungrouped-imports + import os from functools import wraps from typing import Callable, Optional, Tuple import torch +from transformer_engine import te_device_type from .torch_version import torch_version from .export import is_in_onnx_export_mode from .utils import gpu_autocast_ctx @@ -295,9 +299,13 @@ def warmup_jit_bias_dropout_add( # Save cuda RNG state to ensure warmup does not affect reproducibility. rng_state = torch.cuda.get_rng_state() - inp = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda") - residual = torch.rand((seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda") - bias = torch.rand((hidden_size), dtype=dtype, device="cuda") + inp = torch.rand( + (seq_length, micro_batch_size, hidden_size), dtype=dtype, device=te_device_type() + ) + residual = torch.rand( + (seq_length, micro_batch_size, hidden_size), dtype=dtype, device=te_device_type() + ) + bias = torch.rand((hidden_size), dtype=dtype, device=te_device_type()) dropout_rate = 0.1 # Warmup JIT fusions with the input grad_enable state of both forward # prop and recomputation @@ -332,11 +340,11 @@ def warmup_jit_bias_gelu( # Save cuda RNG state to ensure warmup does not affect reproducibility. rng_state = torch.cuda.get_rng_state() - bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda") + bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device=te_device_type()) inp = torch.rand( (seq_length * micro_batch_size, ffn_hidden_size_per_partition), dtype=dtype, - device="cuda", + device=te_device_type(), ) # Warmup JIT fusions with the input grad_enable state of both forward # prop and recomputation @@ -370,7 +378,7 @@ def warmup_jit_l2normalization( inp = torch.rand( (seq_length * micro_batch_size, hidden_size), dtype=dtype, - device="cuda", + device=te_device_type(), ) eps = 1e-6 # Warmup JIT fusions with the input grad_enable state of both forward diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a96a87bf89..a8ef74542b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -19,6 +19,9 @@ from torch.distributed.tensor import DTensor import transformer_engine_torch as tex +from transformer_engine import te_device_type, te_platform +from transformer_engine.common.recipe import Recipe + from ._common import _ParameterInitMeta, noop_cat from ..quantization import ( @@ -87,7 +90,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor _dummy_wgrads[key] = torch.empty( shape, dtype=dtype, - device="cuda", + device=te_device_type(), requires_grad=False, ) if zero: @@ -640,8 +643,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self, name: Optional[str] = None) -> None: super().__init__() - if not torch.cuda.is_available(): - raise RuntimeError("TransformerEngine needs CUDA.") + assert te_platform().is_available(), f"TransformerEngine needs {te_device_type()}." self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False @@ -923,7 +925,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: elif isinstance(state, io.BytesIO): # Deprecated format with io.BytesIO state.seek(0) - state = torch.load(state, map_location="cuda") + state = torch.load(state, map_location=te_device_type()) else: raise RuntimeError("Unsupported checkpoint format.") @@ -1087,9 +1089,9 @@ def prepare_forward( delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: - if not inp.is_cuda: + if inp.device.type != te_device_type(): raise RuntimeError( - f"TransformerEngine needs CUDA. Got input on device: {inp.device}" + f"TransformerEngine needs {te_device_type()}. Got input on device: {inp.device}" ) if self.tp_size > 1: @@ -1297,7 +1299,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: param = param._local_tensor if is_dtensor else param # Ensure parameter is on a real device if param.device == torch.device("meta"): - param = torch.empty_like(param, device="cuda") + param = torch.empty_like(param, device=te_device_type()) # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index aecdf5fe27..82e56995e2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -12,8 +12,11 @@ import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor + + from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -626,7 +629,7 @@ def __init__( return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = te_device_type(), ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d775dc3e8e..973a4a69e2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -78,6 +78,7 @@ general_gemm, ) + __all__ = ["LayerNormLinear"] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 037fb6c858..8362fb4b13 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -15,9 +15,12 @@ import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.tensor.utils import is_custom + + from .base import ( fill_userbuffers_buffer_for_all_gather, _ub_communicators, @@ -1376,7 +1379,7 @@ def fc2_wgrad_gemm( reduce_scatter_out = None if ctx.ub_overlap_rs_dgrad: reduce_scatter_out = torch.empty( - fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + fc1_dgrad_shape, dtype=ctx.activation_dtype, device=te_device_type() ) if ctx.ub_bulk_wgrad: gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) @@ -1458,7 +1461,7 @@ def fc2_wgrad_gemm( reduce_scatter_out = None if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): reduce_scatter_out = torch.empty( - fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + fc1_dgrad_shape, dtype=ctx.activation_dtype, device=te_device_type() ) # Arguments to include in wgrad GEMM closure diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1e3eadc405..eb3a4c3240 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -75,6 +75,7 @@ ) from ...debug.pytorch.debug_state import TEDebugState + __all__ = ["Linear"] diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 13cb519c19..9b01d158d6 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,12 +11,16 @@ import torch import transformer_engine_torch as tex + +from transformer_engine import te_device_type + from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize + __all__ = [ "GELU", "GEGLU", @@ -91,7 +95,7 @@ def op_forward( # Compute dtype dtype: torch.dtype if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = input_.dtype if dtype not in (torch.float32, torch.float16, torch.bfloat16): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 48376a297f..94911e7ea6 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,6 +12,8 @@ import torch +from transformer_engine import te_device_type + from ...cpp_extensions import general_gemm from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( @@ -971,7 +973,7 @@ def op_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = self.weight.dtype diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index d580f84866..ebf94ce631 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,9 @@ import torch import transformer_engine_torch as tex + +from transformer_engine import te_device_type + from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -94,7 +97,7 @@ def reset_parameters(self) -> None: # Make sure parameter is initialized bias = self.bias - if bias.device.type != "cuda": + if bias.device.type != te_device_type(): bias = torch.empty_like(bias, device=self.device) else: bias = bias.to(device=self.device) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index f26a337a4d..0b67dca03b 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -12,8 +12,9 @@ from typing import Any, Optional import torch - import transformer_engine_torch as tex + +from transformer_engine import te_device_type from ...cpp_extensions import general_grouped_gemm from ...distributed import CudaRNGStatesTracker from ...module._common import WeightGradStore @@ -690,7 +691,7 @@ def fuser_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = weight_param.dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 1d8d8be971..f233c8be36 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -224,7 +224,6 @@ def op_backward( dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) - # Compute RMSNorm backward pass dx, dw = rmsnorm_bwd( dy, x, diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index b4427df41a..c06c2d4c85 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -9,8 +9,9 @@ from typing import Any, Optional import torch - import transformer_engine_torch as tex + +from transformer_engine import te_device_type from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data @@ -90,7 +91,7 @@ def op_forward( # Compute dtype dtype: torch.dtype if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = input_.dtype if dtype not in (torch.float32, torch.float16, torch.bfloat16): @@ -242,7 +243,7 @@ def op_forward( # Compute dtype dtype: torch.dtype if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = input_.dtype if dtype not in (torch.float32, torch.float16, torch.bfloat16): @@ -400,7 +401,7 @@ def fuser_forward( # Determine compute dtype if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) elif isinstance(input_, torch.Tensor): dtype = input_.dtype else: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index c5ce2b148d..8f5a53bf2c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -12,8 +12,9 @@ from typing import Any, Optional import torch - import transformer_engine_torch as tex + +from transformer_engine import te_device_type from ...quantization import Recipe from ...tensor import Quantizer from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor @@ -157,7 +158,7 @@ def fuser_forward( fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 device = fc1_weight_param.device if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = fc1_weight_param.dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..08093c179d 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,6 +10,8 @@ import torch +from transformer_engine import te_device_type + from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer @@ -95,7 +97,7 @@ def fuser_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = linear_op.weight.dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..ece2add19a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,6 +10,10 @@ import torch + +from transformer_engine import te_device_type + + from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer @@ -89,7 +93,7 @@ def fuser_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = linear_op.weight.dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..e16906f30a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,6 +10,8 @@ import torch +from transformer_engine import te_device_type + from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer @@ -71,7 +73,7 @@ def fuser_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = linear_op.weight.dtype diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index fbaf69d75d..06ef799ee1 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -11,6 +11,9 @@ import torch from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm + +from transformer_engine import te_device_type + from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...module.base import ( @@ -175,7 +178,7 @@ def _functional_backward( else: device = grad_output.device device = canonicalize_device(device) - if device.type != "cuda": + if device.type != te_device_type(): raise ValueError(f"Only CUDA devices are supported (got {device})") # Check datatype diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..29647ab281 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -11,6 +11,7 @@ import torch from transformer_engine_torch import CommOverlapType +from transformer_engine import te_device_type from ...cpp_extensions import general_gemm from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size @@ -155,7 +156,7 @@ def _functional_forward( """ # Check device - if device.type != "cuda": + if device.type != te_device_type(): raise ValueError(f"Only CUDA devices are supported (got {device})") # Check datatype @@ -320,7 +321,7 @@ def fuser_forward( # Get autocast dtype if needed if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") + dtype = torch.get_autocast_dtype(te_device_type()) else: dtype = linear_op.weight.dtype diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 437dfa829e..64f717deef 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -13,6 +13,7 @@ import torch from torch.distributed._tensor import DTensor import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier @@ -185,7 +186,7 @@ def __init__( self._step_supports_amp_scaling = True # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=te_device_type()) self.multi_tensor_adam = tex.multi_tensor_adam self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index ca59a0ebf8..b103fc6992 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -8,6 +8,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine import te_device_type import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.quantized_tensor import QuantizedTensor @@ -15,6 +16,7 @@ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + __all__ = [ "moe_permute", "moe_unpermute", @@ -42,10 +44,14 @@ def forward( return inp, torch.tensor([], device=inp.device) # Device check - if not inp.is_cuda: - raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") - if not index.is_cuda: - raise ValueError(f"index must be a CUDA tensor, but got tensor on {index.device}.") + if inp.device.type != te_device_type(): + raise ValueError( + f"inp must be a {te_device_type()} tensor, but got tensor on {inp.device}." + ) + if index.device.type != te_device_type(): + raise ValueError( + f"index must be a {te_device_type()} tensor, but got tensor on {index.device}." + ) # Shape check if inp.size(0) != index.size(0): raise ValueError( @@ -125,8 +131,10 @@ def forward( # None probs check if probs is not None: - if not probs.is_cuda: - raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") + if probs.device.type != te_device_type(): + raise ValueError( + f"probs must be a {te_device_type()} tensor, but got tensor on {probs.device}." + ) if probs.dtype != torch.float32: warnings.warn( @@ -143,11 +151,14 @@ def forward( probs = torch.empty(0) # Device check - if not inp.is_cuda: - raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") - if not row_id_map.is_cuda: + if inp.device.type != te_device_type(): raise ValueError( - f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + f"inp must be a {te_device_type()} tensor, but got tensor on {inp.device}." + ) + if row_id_map.device.type != te_device_type(): + raise ValueError( + f"row_id_map must be a {te_device_type()} tensor, but got tensor on" + f" {row_id_map.device}." ) # Data type check @@ -209,19 +220,25 @@ def forward( ctx.probs = probs return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) - if not inp.is_cuda: - raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") - if not routing_map.is_cuda: + if inp.device.type != te_device_type(): raise ValueError( - f"routing_map must be a CUDA tensor, but got tensor on {routing_map.device}." + f"inp must be a {te_device_type()} tensor, but got tensor on {inp.device}." + ) + if routing_map.device.type != te_device_type(): + raise ValueError( + f"routing_map must be a {te_device_type()} tensor, but got tensor on" + f" {routing_map.device}." ) if probs is not None: - if not probs.is_cuda: - raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") + if probs.device.type != te_device_type(): + raise ValueError( + f"probs must be a {te_device_type()} tensor, but got tensor on {probs.device}." + ) if pad_offsets is not None: - if not pad_offsets.is_cuda: + if pad_offsets.device.type != te_device_type(): raise ValueError( - f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + f"pad_offsets must be a {te_device_type()} tensor, but got tensor on" + f" {pad_offsets.device}." ) if inp.size(0) != routing_map.size(0): @@ -396,23 +413,28 @@ def forward( with_probs = merging_probs is not None if with_probs: - if not merging_probs.is_cuda: + if merging_probs.device.type != te_device_type(): raise ValueError( - "merging_probs must be a CUDA tensor, but got tensor on " - f"{merging_probs.device}." + "merging_probs must be a " + + te_device_type() + + f" tensor, but got tensor on {merging_probs.device}." ) # Device check - if not inp.is_cuda: - raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") - if not row_id_map.is_cuda: + if inp.device.type != te_device_type(): + raise ValueError( + f"inp must be a {te_device_type()} tensor, but got tensor on {inp.device}." + ) + if row_id_map.device.type != te_device_type(): raise ValueError( - f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + f"row_id_map must be a {te_device_type()} tensor, but got tensor on" + f" {row_id_map.device}." ) if pad_offsets is not None: - if not pad_offsets.is_cuda: + if pad_offsets.device.type != te_device_type(): raise ValueError( - f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + f"pad_offsets must be a {te_device_type()} tensor, but got tensor on" + f" {pad_offsets.device}." ) if isinstance(inp, QuantizedTensor): @@ -777,19 +799,25 @@ def forward( if not inp.numel(): return inp, probs - if not inp.is_cuda: - raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") - if not split_sizes.is_cuda: + if inp.device.type != te_device_type(): raise ValueError( - f"split_sizes must be a CUDA tensor, but got tensor on {split_sizes.device}." + f"inp must be a {te_device_type()} tensor, but got tensor on {inp.device}." ) - if not sorted_idxs.is_cuda: + if split_sizes.device.type != te_device_type(): raise ValueError( - f"sorted_idxs must be a CUDA tensor, but got tensor on {sorted_idxs.device}." + f"split_sizes must be a {te_device_type()} tensor, but got tensor on" + f" {split_sizes.device}." + ) + if sorted_idxs.device.type != te_device_type(): + raise ValueError( + f"sorted_idxs must be a {te_device_type()} tensor, but got tensor on" + f" {sorted_idxs.device}." ) if probs is not None: - if not probs.is_cuda: - raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") + if probs.device.type != te_device_type(): + raise ValueError( + f"probs must be a {te_device_type()} tensor, but got tensor on {probs.device}." + ) num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 47e6d5c8dc..5033946a80 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -16,6 +16,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine import te_device_type from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -26,6 +27,7 @@ NVFP4BlockScaling, CustomRecipe, ) + from .constants import dist_group_type from .utils import get_device_compute_capability @@ -290,7 +292,9 @@ def reset(cls) -> None: def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: """`skip_fp8_weight_update_tensor` inplace setter.""" if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + cls.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device=te_device_type() + ) cls.skip_fp8_weight_update_tensor.fill_(skip) @classmethod @@ -1078,7 +1082,7 @@ def __init__( # Allocate buffers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) self.amax_history = torch.zeros( recipe.amax_history_len, @@ -1124,7 +1128,7 @@ def __init__( # Allocate buffers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) self.device = device def make_quantizers(self) -> list: @@ -1164,7 +1168,7 @@ def __init__( # Allocate buffers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. @@ -1203,7 +1207,7 @@ def __init__( # Allocate buffers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) self.device = device def make_quantizers(self) -> list: @@ -1304,7 +1308,7 @@ def __init__( # Allocate buffers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) def make_quantizers(self) -> list: from .tensor.nvfp4_tensor import NVFP4Quantizer @@ -1374,7 +1378,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) self.device = device if getattr(recipe, "qfactory", None) is None: diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 99f6a99efa..acff4fd829 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -24,7 +24,6 @@ FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE" FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE" PACKAGE_NAME = "transformer_engine_torch" BASE_WHEEL_URL = ( "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ffa2d5fa05..f584a6c2db 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -13,6 +13,8 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType + +from transformer_engine import te_device_type from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer @@ -213,9 +215,11 @@ def make_empty( pin_memory: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" + if device is None: + device = torch.device(te_device_type()) tensor_kwargs = { - "device": torch.device("cuda") if device is None else device, + "device": torch.device(te_device_type()) if device is None else device, "pin_memory": pin_memory, } @@ -462,7 +466,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): qt._rowwise_scale_inv, qt._columnwise_scale_inv, ): - if t is not None and t.is_cuda: + if t is not None and t.device.type == te_device_type(): t.record_stream(stream) return None @@ -552,7 +556,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ # Tensor device - new_device = tensor.device if tensor.is_cuda else self.device + new_device = tensor.device if tensor.device.type == te_device_type() else self.device def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._rowwise_data = src._rowwise_data diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 168b03134e..54041c4353 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -11,6 +11,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +from transformer_engine import te_device_type from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, @@ -124,7 +125,7 @@ def make_empty( # Canonicalize tensor attributes if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) # Allocate FP8 data data = None @@ -353,7 +354,7 @@ def make_empty( # Canonicalize tensor attributes if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) # Allocate FP8 data data = None @@ -1034,7 +1035,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ # Tensor device - new_device = tensor.device if tensor.is_cuda else self.device + new_device = tensor.device if tensor.device.type == te_device_type() else self.device if not devices_match(new_device, tensor.device): tensor = tensor.to(device=new_device) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index ab0c7484fc..a02e8b9754 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -10,6 +10,7 @@ import torch from torch.utils._pytree import tree_map +from transformer_engine import te_device_type from ..quantized_tensor import QuantizedTensorStorage, Quantizer from .storage.grouped_tensor_storage import GroupedTensorStorage @@ -128,7 +129,7 @@ def __new__( device = maybe_tensor.device break if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) # Match QuantizedTensor __new__: accept externally-computed stride to # avoid Python-side stride computation overhead for C++ construction. diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index debba0cd0b..59c3b34e22 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +from transformer_engine import te_device_type from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple @@ -108,7 +109,7 @@ def make_empty( # Canonicalize tensor attributes if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) assert ( shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 @@ -816,7 +817,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ # Tensor device - new_device = tensor.device if tensor.is_cuda else self.device + new_device = tensor.device if tensor.device.type == te_device_type() else self.device if not devices_match(new_device, tensor.device): tensor = tensor.to(device=new_device) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index eb514d3a9e..f68f1d1dea 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +from transformer_engine import te_device_type from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type from ..utils import ( @@ -100,7 +101,7 @@ def get_rht_matrix(with_random_sign_mask: bool, device: int) -> torch.Tensor: signs = get_no_random_sign_vector(device=device) sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device=device) rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension, device=device) - return rht_matrix.to(dtype=torch.bfloat16) + return rht_matrix.to(dtype=torch.bfloat16).to(te_device_type()) @functools.lru_cache(maxsize=None) @@ -301,7 +302,7 @@ def make_empty( # Canonicalize tensor attributes if device is None: - device = torch.device("cuda") + device = torch.device(te_device_type()) assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" @@ -681,7 +682,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ # Tensor device - new_device = tensor.device if tensor.is_cuda else self.device + new_device = tensor.device if tensor.device.type == te_device_type() else self.device if not devices_match(new_device, tensor.device): tensor = tensor.to(device=new_device) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index ff1c78f695..893f0066bc 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -8,6 +8,8 @@ import math import torch + +from transformer_engine import te_device_type from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -563,7 +565,7 @@ def make_grouped_tensor( # TODO(ksivaman): Single kernel + remove the host offset calculation. tensor_offsets = GroupedTensorStorage.make_tensor_offsets(first_dims, logical_last_dim) if ( - first_dims.device.type == "cuda" + first_dims.device.type == te_device_type() and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing() ): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4b96ccf739..98ea4c75ec 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -10,6 +10,7 @@ import torch +from transformer_engine import te_device_type from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention @@ -343,7 +344,7 @@ def __init__( activation: str = "gelu", activation_params: Optional[dict] = None, normalization: str = "LayerNorm", - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = te_device_type(), attn_input_format: str = "sbhd", name: str = None, qk_norm_type: Optional[str] = None, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 4902bc686c..554879236c 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -9,6 +9,8 @@ import torch import triton +from transformer_engine import te_device_type + from transformer_engine.common.triton.permutation import ( _row_id_map_pass_1_kernel, _row_id_map_pass_2_kernel, @@ -49,10 +51,12 @@ def make_row_id_map( The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding to the first n_routed row indices above. """ - row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda") + row_id_map = torch.empty( + (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=te_device_type() + ) block_size = 1024 grid = (num_experts, triton.cdiv(num_tokens, block_size)) - workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda") + workspace_tensor = torch.empty(grid, dtype=torch.int32, device=te_device_type()) # supposing num_tokens == 5, num_experts == 3, block_size == 3 # and we have a routing_map like this: @@ -160,12 +164,14 @@ def permute_with_mask_map( # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed. # The kernel writes only to valid positions, leaving padding positions at zero. alloc = torch.zeros if pad_offsets is not None else torch.empty - output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device=te_device_type()) permuted_probs = ( - alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None + alloc((num_out_tokens,), dtype=probs.dtype, device=te_device_type()) + if probs is not None + else None ) permuted_scale = ( - alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") + alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device=te_device_type()) if scale is not None else None ) @@ -243,10 +249,10 @@ def unpermute_with_mask_map( hidden_size : int Hidden size of the permuted tensor. """ - output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=te_device_type()) if permuted_probs is not None: unpermuted_probs = torch.empty( - (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda" + (num_tokens, num_experts), dtype=permuted_probs.dtype, device=te_device_type() ) else: unpermuted_probs = None @@ -325,9 +331,11 @@ def unpermute_with_mask_map_bwd_with_merging_probs( # by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros # out the padding slots. alloc = torch.zeros if pad_offsets is not None else torch.empty - act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") + act_grad = alloc( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device=te_device_type() + ) merging_probs_grad = torch.empty( - (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" + (num_tokens, num_experts), dtype=merging_probs.dtype, device=te_device_type() ) grid = (num_tokens,) _unpermute_bwd_with_merging_probs_kernel[grid]( @@ -378,7 +386,7 @@ def make_chunk_sort_map( num_splits : int Number of splits of split_sizes and sorted_indices. """ - row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda") + row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=te_device_type()) grid = (num_tokens,) _make_chunk_sort_map_kernel[grid]( split_sizes, @@ -416,9 +424,9 @@ def sort_chunks_by_map( is_forward : bool Whether the sort is for forward or backward. """ - output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=te_device_type()) if probs is not None: - permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=te_device_type()) else: permuted_probs = None # pylint: disable=unnecessary-lambda-assignment diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a76f205acc..eecf14d0e1 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -12,6 +12,8 @@ import numpy as np import torch +from transformer_engine import te_device_type + from .torch_version import torch_version from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -43,7 +45,8 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" - return torch.Tensor().cuda() + + return torch.Tensor().to(device=te_device_type()) def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: @@ -549,12 +552,12 @@ def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: if device is None: # Use default CUDA device device = torch.get_default_device() - if device.type != "cuda": - device = torch.device("cuda", torch.cuda.current_device()) + if device.type != te_device_type(): + device = torch.device(te_device_type(), torch.cuda.current_device()) elif not isinstance(device, torch.device): device = torch.device(device) - if device.type == "cuda" and device.index is None: - device = torch.device("cuda", torch.cuda.current_device()) + if device.type == te_device_type() and device.index is None: + device = torch.device(te_device_type(), torch.cuda.current_device()) return device @@ -576,7 +579,7 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: device2 = torch.device(device2) if device1.type != device2.type: return False - if device1.type == "cuda": + if device1.type == te_device_type(): index1 = device1.index index2 = device2.index if index1 == index2: @@ -708,12 +711,12 @@ def canonicalize_process_group( def torch_get_autocast_gpu_dtype() -> torch.dtype: """Get PyTorch autocast GPU dtype.""" if torch_version() >= (2, 4, 0): - return torch.get_autocast_dtype("cuda") + return torch.get_autocast_dtype(te_device_type()) return torch.get_autocast_gpu_dtype() if torch_version() >= (2, 4, 0): - gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda") + gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type=te_device_type()) else: gpu_autocast_ctx = torch.cuda.amp.autocast @@ -812,7 +815,7 @@ def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torc if isinstance(x, torch.Tensor): return ( convert_to_torch_tensor(_WeakRefTensor(x.data_ptr(), x.dtype, x.shape)) - if x.is_cuda + if x.device.type == te_device_type() else x ) if isinstance(x, tuple):