diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 00000000000..8749647f9ac --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,14 @@ +{ + "hooks": { + "UserPromptSubmit": [ + { + "hooks": [ + { + "type": "command", + "command": "printf '{\"hookSpecificOutput\":{\"hookEventName\":\"UserPromptSubmit\",\"additionalContext\":\"MANDATORY WORKFLOW — never skip or reorder: (1) Read the artifact first (commit, file, error, PR). (2) Identify and invoke the relevant skill via the Skill tool BEFORE forming any answer or plan — even when the answer seems obvious. (3) Only then answer using the skill context. Skipping step 2 is not allowed.\"}}'" + } + ] + } + ] + } +} diff --git a/.github/actions/action.yml b/.github/actions/action.yml index c44034b48f9..3491cade580 100644 --- a/.github/actions/action.yml +++ b/.github/actions/action.yml @@ -61,6 +61,10 @@ inputs: description: "Platform to run tests on (e.g. dgx_h100, dgx_gb200)" required: false default: "dgx_h100" + cadence: + description: "Trigger cadence for cadence filter (pr|nightly|mergegroup). Empty disables filter." + required: false + default: "" runs: using: "composite" steps: @@ -136,6 +140,9 @@ runs: if [ "${{ inputs.lightweight }}" == "true" ]; then ARGS+=(--enable-lightweight-mode) fi + if [ -n "${{ inputs.cadence }}" ]; then + ARGS+=(--cadence ${{ inputs.cadence }}) + fi export PYTHONPATH=$(pwd) export NEMORUN_HOME=$(pwd) diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml index 618c7c4c9c9..c4f81fbcb7d 100644 --- a/.github/copy-pr-bot.yaml +++ b/.github/copy-pr-bot.yaml @@ -1,4 +1,4 @@ enabled: true auto_sync_draft: false auto_sync_ready: true -trustees_override: ["AAnoosheh", "ArEsKay3", "Autumn1998", "BestJuly", "BoxiangW", "CarlosGomes98", "ChenhanYu", "Connor-XY", "FDecaYed", "HaochenYuan", "ISEEKYAN", "JRD971000", "Mellonta", "Phlip79", "QiZhangNV", "RPrenger", "ShriyaRishab", "Victarry", "WanZzzzzz", "Wohox", "YangFei1990", "ZhiyuLi-Nvidia", "ahmadki", "aklife97", "ananthsub", "aroshanghias-nvd", "asolergi-nv", "buptzyb", "chtruong814", "cjld", "cspades", "cuichenx", "deepakn94", "dimapihtar", "dingqingy-nv", "duncanriach", "erhoo82", "ericharper", "fanshiqing", "faradawn", "fitsumreda", "frsun-nvda", "gautham-kollu", "gdengk", "guihong-nv", "guyueh1", "hexinw-nvidia", "huvunvidia", "hxbai", "ilml", "jalbericiola", "janEbert", "jaredcasper", "jenchen13", "jiemingz", "jingqiny-99", "jkamalu", "jon-barker", "jstjohn", "kajalj22", "kanz-nv", "keshavb96", "kevalmorabia97", "ko3n1g", "ksivaman", "kunlunl", "kvareddy", "kwyss-nvidia", "layalir", "lhb8125", "lmcafee-nvidia", "maanug-nv", "mathemakitten", "matthieule", "mchrzanowski", "mehraakash", "minitu", "mkhona-nvidia", "nanz-nv", "parthmannan", "prajwal1210", "pthombre", "rhewett-nv", "rogerwaleffe", "sajadn", "sanandaraj5597", "sancha", "santhnm2", "sbak5", "shanmugamr1992", "sharathts", "sheliang-nv", "shengf-nv", "shifangx", "shjwudp", "sidsingh-nvidia", "skyw", "sraman-rgb", "sudhakarsingh27", "tdene", "theothermike", "thomasdhc", "tomlifu", "trintamaki", "tylerpoon", "wdykas", "wplf", "wujingyue", "xiaoyao0115", "xuwchen", "yanring", "yaox12", "yaoyu-33", "yashaswikarnati", "yeyu-nvidia", "yobibyte", "youngeunkwon0405", "yueshen2016", "yuzhongw-nvidia", "zhongbozhu"] +trustees_override: ["AAnoosheh", "ArEsKay3", "Autumn1998", "BestJuly", "BoxiangW", "CarlosGomes98", "ChenhanYu", "Connor-XY", "FDecaYed", "HaochenYuan", "ISEEKYAN", "JRD971000", "Mellonta", "Phlip79", "QiZhangNV", "RPrenger", "ShriyaRishab", "Victarry", 
"WanZzzzzz", "Wohox", "YangFei1990", "ZhiyuLi-Nvidia", "ahmadki", "aklife97", "ananthsub", "aroshanghias-nvd", "asolergi-nv", "balasaajay", "buptzyb", "chtruong814", "cjld", "cspades", "cuichenx", "deepakn94", "dimapihtar", "dingqingy-nv", "duncanriach", "erhoo82", "ericharper", "fanshiqing", "faradawn", "fitsumreda", "frsun-nvda", "gautham-kollu", "gdengk", "guihong-nv", "guyueh1", "hexinw-nvidia", "huvunvidia", "hxbai", "ilml", "jalbericiola", "janEbert", "jaredcasper", "jenchen13", "jiemingz", "jingqiny-99", "jkamalu", "jon-barker", "jstjohn", "kajalj22", "kanz-nv", "kevalmorabia97", "ko3n1g", "ksivaman", "kunlunl", "kvareddy", "kwyss-nvidia", "layalir", "lhb8125", "lmcafee-nvidia", "maanug-nv", "mathemakitten", "matthieule", "mchrzanowski", "mehraakash", "minitu", "mkhona-nvidia", "nanz-nv", "ntajbakhsh", "parthmannan", "prajwal1210", "pthombre", "rhewett-nv", "rogerwaleffe", "sajadn", "sanandaraj5597", "sancha", "santhnm2", "sbak5", "shanmugamr1992", "sharathts", "sheliang-nv", "shengf-nv", "shifangx", "shjwudp", "sidsingh-nvidia", "skyw", "sraman-rgb", "sudhakarsingh27", "tdene", "theothermike", "thomasdhc", "tomlifu", "trintamaki", "tylerpoon", "wdykas", "wplf", "wujingyue", "xiaoyao0115", "xuantengh", "xuwchen", "yanring", "yaox12", "yaoyu-33", "yashaswikarnati", "yeyu-nvidia", "yobibyte", "youngeunkwon0405", "yueshen2016", "yuzhongw-nvidia", "zhongbozhu"] diff --git a/.github/oncall_schedule.json b/.github/oncall_schedule.json index 86b1bc73cfb..08db6e7bf44 100644 --- a/.github/oncall_schedule.json +++ b/.github/oncall_schedule.json @@ -1,12 +1,4 @@ [ - { - "user": "asolergi-nv", - "date": "2026-04-22" - }, - { - "user": "maanug-nv", - "date": "2026-04-29" - }, { "user": "dimapihtar", "date": "2026-05-06" @@ -46,5 +38,13 @@ { "user": "wujingyue", "date": "2026-07-08" + }, + { + "user": "Connor-XY", + "date": "2026-07-15" + }, + { + "user": "Phlip79", + "date": "2026-07-22" } ] diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index d2825f9c34b..8f319e66f87 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -3,6 +3,15 @@ :warning: For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall. +## Issue tracking + +For PRs from open-source community contributors: + +- **New features**: a linked issue is **required**. Please open a [feature request](https://github.com/NVIDIA/Megatron-LM/issues/new?template=feature_request.md) and reference it here before submitting the PR. +- **Small updates (bug fixes, minor improvements)**: a linked issue is **recommended** and will accelerate the PR review process. 
+ +Linked issue: + ## Contribution process ### Pre-checks diff --git a/.github/workflows/cicd-approve-test-queue.yml b/.github/workflows/cicd-approve-test-queue.yml index cfd94f02a7d..32b82a66e19 100644 --- a/.github/workflows/cicd-approve-test-queue.yml +++ b/.github/workflows/cicd-approve-test-queue.yml @@ -65,6 +65,7 @@ jobs: import json import requests import re + import time # GitHub API configuration GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] @@ -88,21 +89,38 @@ jobs: "X-GitHub-Api-Version": "2022-11-28", } - def make_request(endpoint, method="GET", data=None): - """Make a request to the GitHub API with error handling.""" + def make_request(endpoint, method="GET", data=None, max_retries=5): + """Make a request to the GitHub API with retry on transient errors.""" url = f"{API_BASE}/{endpoint}" - try: - if method == "GET": - response = requests.get(url, headers=headers) - else: - response = requests.post(url, headers=headers, json=data) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"Error making request to {endpoint}: {str(e)}") - if hasattr(e.response, 'text'): - print(f"Response: {e.response.text}") - return None + for attempt in range(max_retries): + try: + if method == "GET": + response = requests.get(url, headers=headers, timeout=30) + else: + response = requests.post(url, headers=headers, json=data, timeout=30) + if response.status_code == 429: + retry_after = int(response.headers.get("Retry-After", 2 ** attempt)) + print(f"Rate limited on {endpoint}, retrying in {retry_after}s (attempt {attempt + 1}/{max_retries})") + time.sleep(retry_after) + continue + if response.status_code >= 500: + delay = 2 ** attempt + print(f"Server error {response.status_code} on {endpoint}, retrying in {delay}s (attempt {attempt + 1}/{max_retries})") + time.sleep(delay) + continue + response.raise_for_status() + return response.json() + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + delay = 2 ** attempt + print(f"Transient error on {endpoint}: {e}, retrying in {delay}s (attempt {attempt + 1}/{max_retries})") + time.sleep(delay) + except requests.exceptions.RequestException as e: + print(f"Error making request to {endpoint}: {str(e)}") + if hasattr(e, 'response') and e.response is not None: + print(f"Response: {e.response.text}") + return None + print(f"Max retries ({max_retries}) exceeded for {endpoint}") + return None def is_internal_contributor(pr_info): """Return True if the PR author is a member of NVIDIA or NVIDIA-NeMo org (is_org_member).""" @@ -166,8 +184,16 @@ jobs: # Get current running and queued workflows print("Fetching workflow runs...") - queued_workflow_runs = make_request("actions/runs?status=queued").get("workflow_runs", []) - in_progress_workflow_runs = make_request("actions/runs?status=in_progress").get("workflow_runs", []) + queued_resp = make_request("actions/runs?status=queued") + if queued_resp is None: + print("Failed to fetch queued workflow runs after retries, exiting") + exit(1) + queued_workflow_runs = queued_resp.get("workflow_runs", []) + in_progress_resp = make_request("actions/runs?status=in_progress") + if in_progress_resp is None: + print("Failed to fetch in-progress workflow runs after retries, exiting") + exit(1) + in_progress_workflow_runs = in_progress_resp.get("workflow_runs", []) # For external contributors, enforce a single global concurrency limit across ALL branches. # For internal contributors, enforce per-branch limits as before. 
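The retry logic introduced in the hunk above follows a standard pattern: honor `Retry-After` on HTTP 429, back off exponentially on 5xx responses and network errors, and give up after a bounded number of attempts. A minimal standalone sketch of that pattern follows; `API_BASE` and `get_with_retry` are illustrative placeholders, not code from the workflow.

```python
# Minimal sketch of the retry-with-backoff pattern added above.
# API_BASE and get_with_retry are illustrative names, not part of the workflow.
import time
import requests

API_BASE = "https://api.github.com/repos/OWNER/REPO"  # placeholder repository


def get_with_retry(endpoint, headers=None, max_retries=5):
    """GET a GitHub API endpoint, retrying transient failures with backoff."""
    url = f"{API_BASE}/{endpoint}"
    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=headers, timeout=30)
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            time.sleep(2 ** attempt)  # network hiccup: back off and retry
            continue
        if response.status_code == 429:
            # Rate limited: honor the server-provided Retry-After when present.
            time.sleep(int(response.headers.get("Retry-After", 2 ** attempt)))
            continue
        if response.status_code >= 500:
            time.sleep(2 ** attempt)  # transient server error
            continue
        response.raise_for_status()  # other 4xx responses are not retried
        return response.json()
    return None  # all retries exhausted; callers must handle None


```

Returning `None` once the retries are exhausted matches the workflow's convention, which is why the later hunks treat a `None` result as a hard failure rather than as an empty queue.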
@@ -199,7 +225,11 @@ jobs: # Get waiting CI workflows for test environment print("Fetching deployments...") - pending_workflows = make_request("actions/runs?status=waiting").get("workflow_runs", []) + waiting_resp = make_request("actions/runs?status=waiting") + if waiting_resp is None: + print("Failed to fetch waiting workflow runs after retries, exiting") + exit(1) + pending_workflows = waiting_resp.get("workflow_runs", []) print("Pending workflows:", len(pending_workflows)) pending_workflows = [run for run in pending_workflows if run["name"] == "CICD Megatron-LM" and matches_queue(run, "${{ matrix.branch }}", CONTRIBUTOR_TYPE)] @@ -220,7 +250,11 @@ jobs: print(f"Approving workflow {workflow_name} with Run Id: {workflow_id}") deployment_url = f"actions/runs/{workflow_id}/pending_deployments" - deployment = make_request(deployment_url)[0] + deployments = make_request(deployment_url) + if not deployments: + print(f"Failed to fetch pending deployments for run {workflow_id}") + exit(1) + deployment = deployments[0] environment_id = deployment["environment"]["id"] # Approve the deployment diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 35eb570296d..1fdc5029cd2 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -77,9 +77,10 @@ jobs: IS_MAIN_BRANCH: ${{ github.ref == 'refs/heads/main' }} IS_MERGE_GROUP: ${{ github.event_name == 'merge_group' }} SCHEDULED_JOB: ${{ github.event_name == 'schedule' }} + IS_WORKFLOW_DISPATCH: ${{ github.event_name == 'workflow_dispatch' }} run: | - # Skip SSO check for scheduled jobs, main branch, dev branch, or merge groups - if [ "${{ env.SCHEDULED_JOB }}" == "true" ] || [ "${IS_MAIN_BRANCH}" == "true" ] || [ "${IS_DEV_BRANCH}" == "true" ] || [ "${IS_MERGE_GROUP}" == "true" ]; then + # Skip SSO check for scheduled jobs, main branch, merge groups, or manual dispatches + if [ "${{ env.SCHEDULED_JOB }}" == "true" ] || [ "${IS_MAIN_BRANCH}" == "true" ] || [ "${IS_MERGE_GROUP}" == "true" ] || [ "${IS_WORKFLOW_DISPATCH}" == "true" ]; then echo "is_maintainer=true" | tee -a $GITHUB_OUTPUT exit 0 fi @@ -139,12 +140,14 @@ jobs: needs: [pre-flight] if: github.repository == 'NVIDIA/Megatron-LM' outputs: - scope: ${{ steps.configure.outputs.scope }} - n_repeat: ${{ steps.configure.outputs.n_repeat }} - lightweight: ${{ steps.configure.outputs.lightweight }} - lts: ${{ steps.configure.outputs.lts }} - mbridge_suite: ${{ steps.configure.outputs.mbridge_suite }} - dev: ${{ steps.configure.outputs.dev }} + scope: ${{ steps.configure.outputs.scope }} + n_repeat: ${{ steps.configure.outputs.n_repeat }} + lightweight: ${{ steps.configure.outputs.lightweight }} + lts: ${{ steps.configure.outputs.lts }} + mbridge_suite: ${{ steps.configure.outputs.mbridge_suite }} + dev: ${{ steps.configure.outputs.dev }} + cadence: ${{ steps.configure.outputs.cadence }} + cadence_bypass: ${{ steps.configure.outputs.cadence_bypass }} steps: - name: Get PR info id: get-pr-info @@ -158,6 +161,7 @@ jobs: GH_TOKEN: ${{ secrets.PAT }} IS_CI_WORKLOAD: ${{ needs.pre-flight.outputs.is_ci_workload }} IS_MERGE_GROUP: ${{ needs.pre-flight.outputs.is_merge_group }} + EVENT_NAME: ${{ github.event_name }} run: | PR_NUMBER=${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').number }} @@ -169,15 +173,20 @@ jobs: HAS_LTS=$(echo "$LABELS" | jq 'any(. == "container::lts")') HAS_MBRIDGE=$(echo "$LABELS" | jq 'any(. 
== "Run MBridge tests")') - # Scheduled/CI workloads have no PR — treat as "Run functional tests" - [ "$IS_CI_WORKLOAD" == "true" ] && HAS_RUN_FUNCTIONAL=true - if [ "$IS_MERGE_GROUP" == "true" ]; then SCOPE=mr-github; N_REPEAT=1; LIGHTWEIGHT=false elif [ "$HAS_RUN_TESTS" == "true" ]; then SCOPE=mr-github; N_REPEAT=1; LIGHTWEIGHT=true elif [ "$HAS_RUN_FUNCTIONAL" == "true" ]; then SCOPE=mr-github; N_REPEAT=5; LIGHTWEIGHT=false + elif [ "$IS_CI_WORKLOAD" == "true" ] || [ "$EVENT_NAME" == "workflow_dispatch" ]; then + # Scheduled / dispatch / release have no PR labels; default to the + # full functional scope so cadence (set below) is the discriminator. + # `workflow_dispatch` is forced into this branch because upstream + # pre-flight reports is_ci_workload=false when dispatched from a + # `pull-request/*` branch, which would otherwise drop us into the + # slim scope. + SCOPE=mr-github; N_REPEAT=5; LIGHTWEIGHT=false else SCOPE=mr-github-slim; N_REPEAT=5; LIGHTWEIGHT=false fi @@ -188,22 +197,48 @@ jobs: MBRIDGE_SUITE="unit-only" fi + # Cadence: trigger-driven test selection axis (see filter_by_cadence + # in tests/test_utils/python_scripts/recipe_parser.py). PR labels + # `Run tests` and `Run functional tests` bypass the cadence filter so + # contributors retain a manual override. + if [ "$IS_MERGE_GROUP" == "true" ]; then + CADENCE=mergegroup + elif [ "$EVENT_NAME" == "schedule" ] || [ "$EVENT_NAME" == "workflow_dispatch" ]; then + CADENCE=nightly + else + CADENCE=pr + fi + + if [ "$HAS_RUN_TESTS" == "true" ] || [ "$HAS_RUN_FUNCTIONAL" == "true" ]; then + CADENCE_BYPASS=true + CADENCE_OUTPUT="" + else + CADENCE_BYPASS=false + CADENCE_OUTPUT="$CADENCE" + fi + DEV=true - echo "scope=$SCOPE" | tee -a $GITHUB_OUTPUT - echo "n_repeat=$N_REPEAT" | tee -a $GITHUB_OUTPUT - echo "lightweight=$LIGHTWEIGHT" | tee -a $GITHUB_OUTPUT - echo "lts=$HAS_LTS" | tee -a $GITHUB_OUTPUT - echo "mbridge_suite=$MBRIDGE_SUITE" | tee -a $GITHUB_OUTPUT - echo "dev=$DEV" | tee -a $GITHUB_OUTPUT + echo "scope=$SCOPE" | tee -a $GITHUB_OUTPUT + echo "n_repeat=$N_REPEAT" | tee -a $GITHUB_OUTPUT + echo "lightweight=$LIGHTWEIGHT" | tee -a $GITHUB_OUTPUT + echo "lts=$HAS_LTS" | tee -a $GITHUB_OUTPUT + echo "mbridge_suite=$MBRIDGE_SUITE" | tee -a $GITHUB_OUTPUT + echo "dev=$DEV" | tee -a $GITHUB_OUTPUT + echo "cadence=$CADENCE_OUTPUT" | tee -a $GITHUB_OUTPUT + echo "cadence_bypass=$CADENCE_BYPASS" | tee -a $GITHUB_OUTPUT # Pre-compute active row markers for the decision tree _MG=$( [ "$IS_MERGE_GROUP" == "true" ] && echo "**→**" || echo "" ) _RT=$( [ "$IS_MERGE_GROUP" != "true" ] && [ "$HAS_RUN_TESTS" == "true" ] && echo "**→**" || echo "" ) _RF=$( [ "$IS_MERGE_GROUP" != "true" ] && [ "$HAS_RUN_TESTS" != "true" ] && [ "$HAS_RUN_FUNCTIONAL" == "true" ] && echo "**→**" || echo "" ) + _CI=$( [ "$IS_MERGE_GROUP" != "true" ] && [ "$HAS_RUN_TESTS" != "true" ] && [ "$HAS_RUN_FUNCTIONAL" != "true" ] && [ "$IS_CI_WORKLOAD" == "true" ] && echo "**→**" || echo "" ) _DF=$( [ "$SCOPE" == "mr-github-slim" ] && echo "**→**" || echo "" ) _LTS=$( [ "$HAS_LTS" == "true" ] && echo "**→**" || echo "" ) _DEV=$( [ "$HAS_LTS" != "true" ] && echo "**→**" || echo "" ) + _CMG=$( [ "$CADENCE" == "mergegroup" ] && echo "**→**" || echo "" ) + _CN=$( [ "$CADENCE" == "nightly" ] && echo "**→**" || echo "" ) + _CPR=$( [ "$CADENCE" == "pr" ] && echo "**→**" || echo "" ) cat <> $GITHUB_STEP_SUMMARY Beep boop 🤖 I have consulted the labels and decided to run **$SCOPE** $( [ "$LIGHTWEIGHT" == "true" ] && echo "in lightweight mode " || echo "" )against the **$( [ 
"$HAS_LTS" == "true" ] && echo "lts" || echo "dev" )** container with **$N_REPEAT** repetition(s). You are welcome. @@ -216,6 +251,8 @@ jobs: | \`lts\` | \`$HAS_LTS\` | | \`dev\` | \`$DEV\` | | \`mbridge_suite\` | \`$MBRIDGE_SUITE\` | + | \`cadence\` | \`$CADENCE\` | + | \`cadence_bypass\` | \`$CADENCE_BYPASS\` | ### Decision tree @@ -225,9 +262,18 @@ jobs: |---|---|---|---|---| | $_MG | Merge group | \`mr-github\` | \`1\` | \`false\` | | $_RT | Label: _Run tests_ | \`mr-github\` | \`1\` | \`true\` | - | $_RF | Label: _Run functional tests_ / CI workload | \`mr-github\` | \`5\` | \`false\` | + | $_RF | Label: _Run functional tests_ | \`mr-github\` | \`5\` | \`false\` | + | $_CI | Schedule / dispatch (CI workload) | \`mr-github\` | \`5\` | \`false\` | | $_DF | _(default)_ | \`mr-github-slim\` | \`5\` | \`false\` | + **Cadence** _(filter bypassed when \`Run tests\` or \`Run functional tests\` label is set)_ + + | | Trigger | \`cadence\` | + |---|---|---| + | $_CMG | Merge group | \`mergegroup\` | + | $_CN | Schedule / dispatch | \`nightly\` | + | $_CPR | PR push (default) | \`pr\` | + **Container image** | | Trigger | \`image\` | @@ -239,6 +285,7 @@ jobs: - **\`lightweight\`**: trains for 4 steps instead of 100 and skips comparison against golden values — faster feedback, no correctness guarantees - **\`lts\`**: uses the Long Term Support container base image instead of the latest dev image - **\`dev\`**: uses the latest development container base image (default) + - **\`cadence\`**: per-test trigger filter (recipe \`cadence:\` field). Recipes default to \`[pr, nightly, mergegroup]\`. SUMMARY linting: @@ -502,15 +549,6 @@ jobs: python tests/test_utils/python_scripts/download_unit_tests_dataset.py --assets-dir ./assets echo "::endgroup::" - - name: Install GH CLI - shell: bash - run: | - for i in 1 2 3; do - apt-get update && apt-get install -y gh && break - echo "apt attempt $i failed, retrying..." - sleep 10 - done - - name: Get last merged PR id: cache_from env: @@ -679,11 +717,15 @@ jobs: env: SCOPE: ${{ needs.configure.outputs.scope }} LIGHTWEIGHT: ${{ needs.configure.outputs.lightweight }} + CADENCE: ${{ needs.configure.outputs.cadence }} run: | export PYTHONPATH=$(pwd) ARGS=(--scope $SCOPE) [ "$LIGHTWEIGHT" == "true" ] && ARGS+=(--enable-lightweight-mode) + # CADENCE is empty when label-based bypass is active; pass through + # only when set so generate_jet_trigger_job sees None and skips the filter. + [ -n "$CADENCE" ] && ARGS+=(--cadence "$CADENCE") python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ --n-repeat 5 \ @@ -755,6 +797,7 @@ jobs: scope: ${{ needs.configure.outputs.scope }} n_repeat: ${{ needs.configure.outputs.n_repeat }} lightweight: ${{ needs.configure.outputs.lightweight }} + cadence: ${{ needs.configure.outputs.cadence }} cicd-parse-integration-tests-gb200: runs-on: ubuntu-latest @@ -790,11 +833,15 @@ jobs: env: SCOPE: ${{ needs.configure.outputs.scope }} LIGHTWEIGHT: ${{ needs.configure.outputs.lightweight }} + CADENCE: ${{ needs.configure.outputs.cadence }} run: | export PYTHONPATH=$(pwd) ARGS=(--scope $SCOPE) [ "$LIGHTWEIGHT" == "true" ] && ARGS+=(--enable-lightweight-mode) + # CADENCE is empty when label-based bypass is active; pass through + # only when set so generate_jet_trigger_job sees None and skips the filter. 
+ [ -n "$CADENCE" ] && ARGS+=(--cadence "$CADENCE") python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ --n-repeat 5 \ @@ -868,6 +915,7 @@ jobs: n_repeat: ${{ needs.configure.outputs.n_repeat }} lightweight: ${{ needs.configure.outputs.lightweight }} platform: dgx_gb200 + cadence: ${{ needs.configure.outputs.cadence }} Nemo_CICD_Test: needs: @@ -972,7 +1020,6 @@ jobs: ( needs.pre-flight.outputs.docs_only == 'true' || needs.pre-flight.outputs.is_deployment_workflow == 'true' - || github.event == 'merge_group' ) && needs.pre-flight.outputs.is_ci_workload == 'false' && !cancelled() @@ -998,6 +1045,7 @@ jobs: if: | ( (needs.pre-flight.outputs.is_ci_workload == 'true' && !failure()) + || (needs.pre-flight.outputs.is_merge_group == 'true' && !failure()) || success() ) && !cancelled() @@ -1006,6 +1054,11 @@ jobs: matrix: flag: [unit-test] steps: + - name: Get PR info + id: get-pr-info + if: startsWith(github.ref, 'refs/heads/pull-request/') && github.event_name == 'push' + uses: nv-gha-runners/get-pr-info@main + - name: Checkout uses: actions/checkout@v6 @@ -1036,6 +1089,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true flags: ${{ matrix.flag }} + base_sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').base.sha }} - name: Upload artifacts uses: actions/upload-artifact@v6 diff --git a/.github/workflows/claude_review.yml b/.github/workflows/claude_review.yml index 99c6cdf4ac8..b7d5f1217c0 100644 --- a/.github/workflows/claude_review.yml +++ b/.github/workflows/claude_review.yml @@ -50,12 +50,20 @@ jobs: trigger_phrase: "/claude review" show_full_output: true claude_args: | - --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr review:*)" + --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr review:*),Read" --model "claude-opus-4-6" prompt: | REPO: ${{ env.REPO }} PR NUMBER: ${{ env.PR_NUMBER }} + Mandatory workflow — never skip or reorder: + 1. Read the PR diff first (gh pr diff). + 2. Based on the changed files and areas, identify relevant skills from skills//SKILL.md. + Common skill names: build-and-dependency, testing, cicd, linting-and-formatting, run-on-slurm, + nightly-sync, create-issue, respond-to-issue, split-pr, onboard-gb200-1node-tests. + 3. Read the SKILL.md files for all relevant areas using the Read tool. + 4. Only then perform the review using the skill context. + You are doing a light code review. Keep it concise and actionable. Focus ONLY on: @@ -134,13 +142,21 @@ jobs: trigger_phrase: "/claude strict-review" show_full_output: true claude_args: | - --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr review:*),Bash(git diff:*),Bash(git show:*),Bash(git log:*)" + --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr review:*),Bash(git diff:*),Bash(git show:*),Bash(git log:*),Read" --model "claude-opus-4-6" prompt: | REPO: ${{ env.REPO }} PR NUMBER: ${{ env.PR_NUMBER }} BASE REF: origin/${{ steps.pr-info.outputs.base_ref }} + Mandatory workflow — never skip or reorder: + 1. Read the PR diff first (gh pr diff). + 2. Based on the changed files and areas, identify relevant skills from skills//SKILL.md. 
+ Common skill names: build-and-dependency, testing, cicd, linting-and-formatting, run-on-slurm, + nightly-sync, create-issue, respond-to-issue, split-pr, onboard-gb200-1node-tests. + 3. Read the SKILL.md files for all relevant areas using the Read tool. + 4. Only then perform the review using the skill context. + You are performing a strict, comprehensive code review on a **Megatron-LM** Pull Request. Megatron-LM is NVIDIA's large-scale distributed training framework for LLMs. Review the diff with a focus on **implementation correctness**, **training performance**, and **backward compatibility**. diff --git a/.github/workflows/nightly-sync-main-to-dev.yml b/.github/workflows/nightly-sync-main-to-dev.yml new file mode 100644 index 00000000000..d7c9e46811d --- /dev/null +++ b/.github/workflows/nightly-sync-main-to-dev.yml @@ -0,0 +1,217 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Nightly Sync Main to Dev + +on: + workflow_dispatch: + schedule: + # 21:00 UTC = 2 PM PDT (1 PM PST during winter — GitHub Actions cron + # is UTC-only and does not follow DST). + - cron: '0 21 * * *' + +concurrency: + group: nightly-sync-main-to-dev + cancel-in-progress: false + +permissions: + contents: write + pull-requests: write + issues: write + id-token: write + +jobs: + # Re-dispatch scheduled runs as workflow_dispatch via a PAT so the heavy + # job runs with a real User-type actor. On `schedule` events GitHub sets + # `github.actor` to `github-merge-queue` (no Users-API entry), which + # crashes anthropics/claude-code-action@v1 in `checkHumanActor` with a + # 404 before `allowed_bots` is ever consulted. Upstream fix PR + # https://github.com/anthropics/claude-code-action/pull/1212 is closed + # and unmerged; see issue + # https://github.com/anthropics/claude-code-action/issues/1284 for the + # same class of bug. The dispatch carries the PAT owner as the actor. 
+ cron-redispatch: + if: github.event_name == 'schedule' && github.repository == 'NVIDIA/Megatron-LM' + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ secrets.PAT }} + steps: + - name: Dispatch sync workflow via PAT + run: | + gh workflow run nightly-sync-main-to-dev.yml \ + --repo "${{ github.repository }}" \ + --ref main + + sync-main-to-dev: + if: github.event_name == 'workflow_dispatch' && github.repository == 'NVIDIA/Megatron-LM' + runs-on: ubuntu-latest + timeout-minutes: 360 + env: + GH_TOKEN: ${{ secrets.PAT }} + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + token: ${{ secrets.PAT }} + + - name: Configure Git + run: | + git config user.name "svcnvidia-nemo-ci" + git config user.email "svcnvidia-nemo-ci@nvidia.com" + + - name: Compute branch name + id: vars + run: | + DATE=$(date -u +%d_%m_%Y) + BRANCH="main2dev/${DATE}" + echo "branch=$BRANCH" >> "$GITHUB_OUTPUT" + echo "date=$DATE" >> "$GITHUB_OUTPUT" + + - name: Close previous unmerged sync PRs + run: | + OPEN_PRS=$(gh pr list \ + --repo "${{ github.repository }}" \ + --base dev \ + --state open \ + --json number,headRefName \ + --jq '.[] | select(.headRefName | startswith("main2dev/")) | .number') + + for PR_NUM in $OPEN_PRS; do + echo "Closing stale sync PR #${PR_NUM}" + gh pr close "$PR_NUM" \ + --repo "${{ github.repository }}" \ + --comment "Superseded by today's nightly sync." + done + + - name: Check if sync is needed + id: check-sync + run: | + git fetch origin main dev + AHEAD_COUNT=$(git rev-list --count origin/dev..origin/main) + echo "main is $AHEAD_COUNT commit(s) ahead of dev" + if [ "$AHEAD_COUNT" -eq 0 ]; then + echo "skip=true" >> "$GITHUB_OUTPUT" + echo "No changes to sync." + else + echo "skip=false" >> "$GITHUB_OUTPUT" + fi + + - name: Run Claude Code to merge, fix, and iterate + if: steps.check-sync.outputs.skip != 'true' + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + github_token: ${{ secrets.PAT }} + prompt: | + You are an automated sync bot. Merge `main` into `dev`, create a + PR, ensure CI passes (fixing failures), and mark the PR ready. + There are 4 phases. You are NOT done until Phase 4 completes. + + REPO: ${{ github.repository }} + BRANCH: ${{ steps.vars.outputs.branch }} + DATE: ${{ steps.vars.outputs.date }} + + Read `.claude/skills/nightly-sync/SKILL.md` for the detailed + merge strategy, CI architecture, failure investigation procedures, + and known issues. Also read `.claude/skills/build-and-test/SKILL.md` + and `CLAUDE.md` for general CI and contribution guidelines. + + ## Hard Constraints + + **Exit condition:** You MUST run `gh pr ready ` before + exiting. That command is Phase 4. Do NOT exit after Phase 1, 2, + or 3 — not even if CI is "still running" or "stuck in queue." + Keep polling until it resolves, then act. + + **NO background tasks. Ever.** + You are running inside a single GitHub Actions step. The step + process owns your shell. When you stop issuing tool calls, the + step ends and the runner container is DESTROYED — every + background process dies with it and cannot resume. There is no + "future session" to wake up into. 
+ + The following are strictly forbidden: + - `Bash` with `run_in_background: true` + - `Agent` with `run_in_background: true` + - `ScheduleWakeup` (nothing will ever wake up) + - Any shell command ending in `&`, or using `nohup`, `disown`, + or `setsid` to detach a process + - `tail -f` on a log produced by a backgrounded task + + Required shape for every long wait: ONE foreground Bash tool + call containing an inline `while true; do ... sleep ; done` + or `until ...; do sleep ; done` loop that BLOCKS inside + that single tool call and only returns when the wait is + resolved (success, failure, or a clearly-classified terminal + state). Do NOT break a long wait into many short polls with + conversation in between — that wastes `--max-turns` and + creates windows where the agent could forget the loop. + + **Source of truth for CI status:** + `gh pr view --repo $REPO --json statusCheckRollup` + This lists every required check — GitHub Actions jobs AND + external contexts (GitLab CI, `copy-pr-bot`, etc.). The + `gh api .../actions/runs//jobs` endpoint alone is + NOT sufficient — it misses external contexts. + + **Pre-existing failures:** MUST verify against recent dev CI + before classifying any failure as pre-existing. Run + `gh pr checks` on a recently merged dev PR. If the test passes + on dev, the failure is sync-caused and you must fix it. A + check that has never completed on your PR cannot be + pre-existing — wait for it to finish first. + + **Phase 4 gate — strict "all terminal, all green":** + Do NOT run `gh pr ready` until every non-exempt required check + in `statusCheckRollup` satisfies BOTH: + - `status == "COMPLETED"` (NOT `QUEUED`, `IN_PROGRESS`, + `PENDING`, `WAITING`, or `REQUESTED`), AND + - `conclusion` ∈ {`SUCCESS`, `SKIPPED`, `NEUTRAL`}. + A check stuck in a runner queue is NOT complete. Never + classify queued/in-progress jobs as "infrastructure-blocked" + and ship anyway — wait for them to reach a terminal + conclusion, then act on that result. When a check fails, + loop: diagnose → fix → commit → push → `/ok to test ` → + poll. Only exit the loop when the gate is satisfied on the + LATEST CI run against the current HEAD SHA. + + **Exempt checks (may be ignored for the Phase 4 gate):** + These categories are pre-merge policy signals, not + correctness signals, so their failure must not block the + sync bot from marking the PR ready for human review. + + - Approval / code-review: `codeowners-approval`, + `check-approval`, `multi-approval-bot-summary`, + `is-not-external-contributor`, any check whose name + contains `review` or `approval`. + - Code coverage: `Coverage (unit-test)`, `Coverage_Fake`, + any check whose name contains `codecov` or `coverage` + (case-insensitive). + - Docs: `build-docs / Build docs`, `build-docs-summary`, + any check whose name contains `build-docs`, `doc-build`, + `readthedocs`, or `sphinx`. + + Everything else — unit tests (`tests/unit_tests/...`), + integration tests (`gpt/...`, `moe/...`, etc.), `linting`, + `cicd-container-build`, `cicd-mbridge-testing`, + `Nemo_CICD_Test`, `copyright-check`, `pre-flight`, wheel + builds, etc. — is NOT exempt and must reach a terminal + green conclusion. 
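Read literally, the Phase 4 gate described in the prompt above reduces to a predicate over the `statusCheckRollup` entries. The sketch below illustrates that reading under a few assumptions: field names follow the prompt's wording (`name`, `status`, `conclusion`), entries from external status contexts may expose different fields and would need extra handling, and only the pattern-based exemptions are shown (the explicitly named exempt checks would go into an allow-list). It is not part of the workflow itself.

```python
# Sketch of the "all terminal, all green" gate described above; not part of
# the workflow. Field names follow the prompt; external CI contexts may differ.
import json
import re
import subprocess

EXEMPT = re.compile(
    r"review|approval|codecov|coverage|build-docs|doc-build|readthedocs|sphinx",
    re.IGNORECASE,
)
GREEN = {"SUCCESS", "SKIPPED", "NEUTRAL"}


def pr_is_ready(pr_number: int, repo: str) -> bool:
    """True only if every non-exempt check is COMPLETED with a green conclusion."""
    out = subprocess.run(
        ["gh", "pr", "view", str(pr_number), "--repo", repo,
         "--json", "statusCheckRollup"],
        check=True, capture_output=True, text=True,
    ).stdout
    for check in json.loads(out)["statusCheckRollup"]:
        name = check.get("name") or check.get("context", "")
        if EXEMPT.search(name):
            continue  # approval / coverage / docs checks never block readiness
        if check.get("status") != "COMPLETED" or check.get("conclusion") not in GREEN:
            return False  # still queued or running, or finished non-green
    return True


```

Requiring both `COMPLETED` status and a green conclusion is what keeps queued or still-running checks from being waved through as "infrastructure-blocked."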
+ show_full_output: true + claude_args: | + --allowedTools "Bash,Read,Edit,Write,Grep,Glob,Agent" + --model "opus[1m]" + --effort max + --max-turns 1500 diff --git a/AGENTS.md b/AGENTS.md index 24d5d846238..70e8152cbf4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,12 +1,30 @@ # Repository Guidelines +## Skills + +The `skills/` directory contains structured guides for common tasks (running +tests, building containers, managing dependencies, submitting SLURM jobs, etc.). +**Always read the relevant `SKILL.md` before starting any task it covers — +skills are mandatory context, not optional background reading.** + +**Workflow — mandatory order for every task:** +1. **Pull information first.** Read the commit, PR, error log, file, or + whatever artifact the task is about. Do not reason about it yet. +2. **Select and invoke the skill.** Based on what you just read, identify + the relevant skill and invoke it before forming any answer or plan. +3. **Answer or implement.** Only after the skill is loaded, use its context + to reason, diagnose, or write code. + +Never skip or reorder these steps. Do not wait for the user to name the right +skill keyword — infer it from the artifact you read. + ## Contributing ### Pull Requests - All PRs must be created as **drafts**. Use `gh pr create --draft` or the GitHub UI draft option. - Never push branches directly to `https://github.com/NVIDIA/Megatron-LM`. You must push your branch to a personal fork (e.g. `https://github.com//Megatron-LM`), then open a PR from the fork's branch against `NVIDIA/Megatron-LM`. -- Read [docs/developer/contribute.md](docs/developer/contribute.md) for the full contribution policy, including code style, commit message conventions, and issue guidelines. +- Read @docs/developer/contribute.md for the full contribution policy, including code style, commit message conventions, and issue guidelines. ### Code Quality diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000000..728cdb4a1d2 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,25 @@ +## Security + +NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. + +If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.** If a potential security issue is inadvertently reported via a public issue or pull request, NVIDIA maintainers may limit public discussion and redirect the reporter to the appropriate private disclosure channels. + +## Reporting Potential Security Vulnerability in an NVIDIA Product + +To report a potential security vulnerability in any NVIDIA product: + +- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) +- E-Mail: psirt@nvidia.com + - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) + - Please include the following information: + - Product/Driver name and version/branch that contains the vulnerability + - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 
- Instructions to reproduce the vulnerability + - Proof-of-concept or exploit code + - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability + +While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. + +## NVIDIA Product Security + +For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security diff --git a/docs/api-guide/core/transformer.md b/docs/api-guide/core/transformer.md index d35144fda4f..03bc0f501f4 100644 --- a/docs/api-guide/core/transformer.md +++ b/docs/api-guide/core/transformer.md @@ -15,5 +15,5 @@ of a transformer stack, from entire layers down to individual linear layers, can be customized by swapping in different PyTorch modules using the "spec" parameters. The configuration of the transformer (hidden size, number of layers, -number of attention heads, etc.) is provided via a `TransformerConfig` +number of attention heads) is provided using a `TransformerConfig` object. diff --git a/docs/conf.py b/docs/conf.py index 26b618b1eac..5606eaa5809 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -106,4 +106,10 @@ html_extra_path = ["project.json", "versions1.json"] # Github links are now getting rate limited from the Github Actions -linkcheck_ignore = [".*github\\.com.*", ".*githubusercontent\\.com.*"] +linkcheck_ignore = [".*github\\.com.*", ".*githubusercontent\\.com.*", "http://localhost.*"] + +# PyTorch docs use a JS-rendered frontend; anchor IDs are injected at runtime, +# are not present in the static HTML that linkcheck fetches, and change between +# stable versions. Verify that each page loads but skip anchor validation to +# avoid spurious failures on redirects. +linkcheck_anchors_ignore_for_url = [r"https://docs\.pytorch\.org/.*"] diff --git a/docs/developer/contribute.md b/docs/developer/contribute.md index aeb785f915d..30a39e1cbc0 100644 --- a/docs/developer/contribute.md +++ b/docs/developer/contribute.md @@ -13,7 +13,7 @@ This document outlines the processes and policies for issues and pull requests b Everyone is welcome to contribute to the project! We recently migrated from using an internal repo to doing all development directly from the GitHub repository. -When contributing it is important to ensure that changes are in line with the project direction. Small changes to fix bugs are welcomed and appreciated. If proposing large architectural changes or changes for stylistic reasons open an issue first so we can discuss it. +When contributing it is important to ensure that changes are in line with the project direction. Small changes to fix bugs are welcomed and appreciated. **If proposing large architectural changes or changes for stylistic reasons, open an issue first so we can discuss it.** ## Issue policy diff --git a/docs/developer/oncall.md b/docs/developer/oncall.md index 0e5b38e2708..18d76f1436a 100644 --- a/docs/developer/oncall.md +++ b/docs/developer/oncall.md @@ -50,9 +50,10 @@ Below is the checklist that the oncall needs to go through for each PR. ## Issues and Discussion Questions -If you do not know the answer to an issue or discussion question: that's ok!
**Delegate to someone who does.** +If you do not know the answer to an issue or discussion question, that's ok, **Delegate to someone who does.** On a daily basis, track the following: -- [new issues](https://github.com/NVIDIA/Megatron-LM/issues): check to see if there are any new issues before they become out of SLA! -- [out of SLA issues](https://github.com/orgs/NVIDIA-NeMo/projects/20/views/4?sliceBy%5Bvalue%5D=NVIDIA%2FMegatron-LM): useful dashboard that tracks all out of SLA issues +- [Dashboard for out of SLA issues](https://github.com/NVIDIA/Megatron-LM/issues?q=is%3Aissue%20state%3Aopen%20label%3Awaiting-on-maintainers). + + diff --git a/docs/discussions/README.md b/docs/discussions/README.md index e791ed57cd8..aab65fc65ca 100644 --- a/docs/discussions/README.md +++ b/docs/discussions/README.md @@ -19,13 +19,9 @@ This directory contains in-depth guides, tutorials, and discussions about optimi ### Training Guides -- **[Megatron-FSDP User Guide](megatron-fsdp-user-guide/megatron-fsdp-user-guide.md)** - - A practical guide to enable Megatron-FSDP training, including a quick-start example for DeepSeek-V3, required and recommended configurations, and instructions for checkpoint conversion from torch_dist to fsdp_dtensor. - ## Contributing -If you'd like to contribute a guide or tutorial, please follow this structure: +To contribute a guide or tutorial, follow this structure: 1. Create a new directory: `docs/discussions/your-guide-name/` 2. Add your main guide: `docs/discussions/your-guide-name/your-guide-name.md` diff --git a/docs/discussions/megatron-fsdp-user-guide/megatron-fsdp-user-guide.md b/docs/discussions/megatron-fsdp-user-guide/megatron-fsdp-user-guide.md deleted file mode 100644 index b5de090ab46..00000000000 --- a/docs/discussions/megatron-fsdp-user-guide/megatron-fsdp-user-guide.md +++ /dev/null @@ -1,129 +0,0 @@ ---- -orphan: true ---- - - - -# Megatron-FSDP User Guide - -## Table of Contents - -- [Megatron-FSDP Quick Start](#megatron-fsdp-quick-start) -- [Checkpoint Conversion from 3D-Parallel to Megatron-FSDP](#checkpoint-conversion-from-3d-parallel-to-megatron-fsdp) - -## Megatron-FSDP Quick Start - -We recommend using the latest [NVIDIA NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags), which provides a tested software stack and optimized performance. - -For your reference, we provide an example launch script for DeepSeek-V3: [`sbatch_mfsdp_deepseek_v3.sh`](./example-scripts/sbatch_mfsdp_deepseek_v3.sh). - -### Required Configurations - -To enable Megatron-FSDP, add the following required flags to your training script: - -```bash ---use-megatron-fsdp ---data-parallel-sharding-strategy optim_grads_params ---no-gradient-accumulation-fusion ---use-distributed-optimizer ---ckpt-format fsdp_dtensor -``` - -### Recommended Configurations - -We also recommend adding the following configurations to further improve performance: - -```bash -unset CUDA_DEVICE_MAX_CONNECTIONS -``` -```bash ---calculate-per-token-loss ---init-model-with-meta-device ---grad-reduce-in-bf16 ---fsdp-double-buffer ---use-nccl-ub -``` - -💡 **Detailed explanations of these configurations are provided below.** - -#### 1. Disable `CUDA_DEVICE_MAX_CONNECTIONS` - -To ensure full parallelization of FSDP communication and computation, disable the CUDA_DEVICE_MAX_CONNECTIONS environment variable. This step avoids potential bubbles in the CUDA stream. (But it may slow down TP and CP to some extent.) - -#### 2. 
Add `--calculate-per-token-loss` - -For gradients sharding mode optimization, include the `--calculate-per-token-loss` flag in your training script. This improves performance by reducing the frequency of gradient scaling, which is also a sizable drain on SM resources. - -#### 3. Add `--init-model-with-meta-device` - -Allows model initialization using meta device, followed by layer-by-layer initialization of distributed model weight buffers via the `Module.reset_parameters` API, facilitating the initialization of extremely large models. - -#### 4. Add `--grad-reduce-in-bf16` - -Enables gradient reduction in BF16 precision instead of FP32, reducing communication volume and accelerating the backward pass. - -#### 5. Add `--fsdp-double-buffer` - -Uses persistently allocated double buffers for temporarily-defined memory needed in `MegatronFSDP` communications. While having persistent double buffers may increase peak VRAM utilization, it is necessary to register NCCL user buffers (`nccl_ub=True`) for `MegatronFSDP`. Currently, this is supported only for simple repetitive model structures such as GPT. - -- **Only effective when using Megatron-LM.** -- Defaults to `False`. Automatically overridden to `True` when `nccl_ub` is enabled. - -#### 6. Add `--use-nccl-ub` - -Allocates and [registers NCCL user buffers](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#) for param and grad buffers. This option enables an SM-efficient NCCL algorithm that could improve the performance of overlapped computations. This flag will be much more effective when used together with [SHARP](https://docs.nvidia.com/networking/display/sharpv3130) if the FSDP communication includes both NVL and IB domains. Enabling this option will cause additional memory overhead due to the requirement to enable the `fsdp_double_buffer` option. - -- **Only effective when using Megatron-LM.** -- Defaults to `False`. -- By default we try to use NCCL window (symmetric) registration if it is available. If not it falls back to conventional local registration. -- **Incompatible with PyTorch's segmentable allocator:** Do not set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` when using `--use-nccl-ub`, as this will cause a runtime error due to compatibility issues with the `torch.cuda.MemPool` API. - -## Checkpoint Conversion from 3D-Parallel to Megatron-FSDP - -Megatron-FSDP introduces `fsdp_dtensor`, a DTensor-based distributed checkpoint format that serves as its standard. To help you smoothly transition from 3D-Parallel to Megatron-FSDP, we provide a script for converting checkpoints from the `torch_dist` format to the `fsdp_dtensor` format. Using DeepSeek-V3 as an example, the detailed conversion process is described below. - -### Step 1: Generate 3D-Parallel Checkpoint with `param_to_param_group_map` - -Run your 3D-parallel + EP training script to generate a `torch_dist` checkpoint along with a directory containing `param_to_param_group_map` files. Add the following flag to your training script: - -```bash ---dump-param-to-param-group-map /path/to/param_to_param_group_map -``` - -If you already have a `torch_dist` checkpoint, simply specify the `--dump-param-to-param-group-map /path/to/param_to_param_group_map` flag and run a very short experiment-this will create the `param_to_param_group_map` you need without full pretraining. 
- -### Step 2: Export `param_to_param_group_map` to a JSON File - -Convert the `param_to_param_group_map` into a JSON file for easier processing by running: - -```bash -python tools/checkpoint/checkpoint_inspector.py print-torch-dcp-in-json /path/to/param_to_param_group_map -``` - -This will create a `param_to_param_group_map.json` file in the `/path/to/param_to_param_group_map` directory. - -### Step 3: Convert Checkpoint from `torch_dist` to `fsdp_dtensor` - -Convert your `torch_dist` checkpoint to the `fsdp_dtensor` format using the parameter to `param_to_param_group_map` JSON file: - -```bash -torchrun --nproc_per_node=8 --nnodes=1 \ - tools/checkpoint/checkpoint_inspector.py \ - convert-torch-dist-to-fsdp-dtensor --swiglu \ - /path/to/input_torch_dist_checkpoint \ - /path/to/output_fsdp_dtensor_checkpoint \ - --param-to-param-group-map-json /path/to/param_to_param_group_map.json -``` - -**Note:** For multi-node conversion tasks, please refer to the example script: [`sbatch_checkpoint_convert.sh`](./example-scripts/sbatch_checkpoint_convert.sh). - -### Step 4: Launch Megatron-FSDP Training - -Start your Megatron-FSDP training job using the converted `fsdp_dtensor` checkpoint. \ No newline at end of file diff --git a/docs/images/custom_fsdp/FSDP_workflow.png b/docs/images/custom_fsdp/FSDP_workflow.png deleted file mode 100644 index 588b6f220a3..00000000000 Binary files a/docs/images/custom_fsdp/FSDP_workflow.png and /dev/null differ diff --git a/docs/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png b/docs/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png deleted file mode 100644 index f9603079b92..00000000000 Binary files a/docs/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png and /dev/null differ diff --git a/docs/images/megatron_fsdp/DDP_vs_FSDP.png b/docs/images/megatron_fsdp/DDP_vs_FSDP.png new file mode 100644 index 00000000000..627821439e2 Binary files /dev/null and b/docs/images/megatron_fsdp/DDP_vs_FSDP.png differ diff --git a/docs/images/custom_fsdp/FSDP_Allreduce.png b/docs/images/megatron_fsdp/FSDP_Allreduce.png similarity index 100% rename from docs/images/custom_fsdp/FSDP_Allreduce.png rename to docs/images/megatron_fsdp/FSDP_Allreduce.png diff --git a/docs/images/megatron_fsdp/fsdp_double_buffer.png b/docs/images/megatron_fsdp/fsdp_double_buffer.png new file mode 100644 index 00000000000..fbfbcef9b28 Binary files /dev/null and b/docs/images/megatron_fsdp/fsdp_double_buffer.png differ diff --git a/docs/images/megatron_fsdp/fsdp_streams.png b/docs/images/megatron_fsdp/fsdp_streams.png new file mode 100644 index 00000000000..6b8840783c8 Binary files /dev/null and b/docs/images/megatron_fsdp/fsdp_streams.png differ diff --git a/docs/images/megatron_fsdp/fsdp_v_hfsdp_streams.png b/docs/images/megatron_fsdp/fsdp_v_hfsdp_streams.png new file mode 100644 index 00000000000..6f6e61dfb21 Binary files /dev/null and b/docs/images/megatron_fsdp/fsdp_v_hfsdp_streams.png differ diff --git a/docs/images/megatron_fsdp/hfsdp.png b/docs/images/megatron_fsdp/hfsdp.png new file mode 100644 index 00000000000..3c056d20689 Binary files /dev/null and b/docs/images/megatron_fsdp/hfsdp.png differ diff --git a/docs/images/megatron_fsdp/lcm_dim0_shard.png b/docs/images/megatron_fsdp/lcm_dim0_shard.png new file mode 100644 index 00000000000..910add676f1 Binary files /dev/null and b/docs/images/megatron_fsdp/lcm_dim0_shard.png differ diff --git a/docs/images/megatron_fsdp/mixed_sharding.png b/docs/images/megatron_fsdp/mixed_sharding.png new file mode 100644 index 00000000000..81cbc153f8a 
Binary files /dev/null and b/docs/images/megatron_fsdp/mixed_sharding.png differ diff --git a/docs/images/megatron_fsdp/quantized_param_gather.png b/docs/images/megatron_fsdp/quantized_param_gather.png new file mode 100644 index 00000000000..e1908e66ad7 Binary files /dev/null and b/docs/images/megatron_fsdp/quantized_param_gather.png differ diff --git a/docs/images/megatron_fsdp/sharded_quantization.png b/docs/images/megatron_fsdp/sharded_quantization.png new file mode 100644 index 00000000000..c65bab5305a Binary files /dev/null and b/docs/images/megatron_fsdp/sharded_quantization.png differ diff --git a/docs/images/megatron_fsdp/uneven_sharding.png b/docs/images/megatron_fsdp/uneven_sharding.png new file mode 100644 index 00000000000..0c34a51b026 Binary files /dev/null and b/docs/images/megatron_fsdp/uneven_sharding.png differ diff --git a/docs/images/megatron_fsdp/zero3_model_state.png b/docs/images/megatron_fsdp/zero3_model_state.png new file mode 100644 index 00000000000..84ad33ff779 Binary files /dev/null and b/docs/images/megatron_fsdp/zero3_model_state.png differ diff --git a/docs/index.md b/docs/index.md index 0dbf7d2e3b7..11337315588 100644 --- a/docs/index.md +++ b/docs/index.md @@ -67,7 +67,7 @@ models/index user-guide/features/moe user-guide/features/context_parallel -user-guide/features/custom_fsdp +user-guide/features/megatron_fsdp user-guide/features/dist_optimizer user-guide/features/optimizer_cpu_offload user-guide/features/pipeline_parallel_layout diff --git a/docs/llama_mistral.md b/docs/llama_mistral.md index 95568adce78..2754405610c 100644 --- a/docs/llama_mistral.md +++ b/docs/llama_mistral.md @@ -22,13 +22,11 @@ Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatr - [Llama, Mistral and other Llama-like model support in Megatron-LM](#llama-mistral-and-other-llama-like-model-support-in-megatron-lm) - [Contents](#contents) - [Llama-2](#llama-2) - - [Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints) + - [Download Huggingface checkpoints](#download-huggingface-checkpoints) - [Convert checkpoint format](#convert-checkpoint-format) - - [Meta format](#meta-format) - [Huggingface format](#huggingface-format) - [Launch model](#launch-model) - [Launch Megatron](#launch-megatron) - - [Launch Meta](#launch-meta) - [Launch Huggingface](#launch-huggingface) - [Benchmark results](#benchmark-results) - [Big Bench](#big-bench) @@ -48,72 +46,35 @@ Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatr - [Launch model](#launch-model) - [Other Llama-like model support](#other-llama-like-model-support) - [Known numerical differences](#known-numerical-differences) -- [Using legacy model format](#using-legacy-model-format) # Llama-2 Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps: 1. Get access to download the checkpoints. -2. Convert the checkpoints from Meta/Huggingface format to Megatron format. +2. Convert the checkpoints from Huggingface format to Megatron format. 3. Setup arguments for launching the model. The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints. 
-## Download Meta or Huggingface checkpoints +## Download Huggingface checkpoints -Users must first apply for access to download the Llama-2 checkpoints either directly [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next. +Users must first apply for access to download the Llama-2 checkpoints directly from [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in HF format (available only from HF), which can be converted to Megatron format as detailed next. ## Convert checkpoint format We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. -### Meta format - -The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16: - -``` -python tools/checkpoint/convert.py \ -> --model-type GPT \ -> --loader llama_mistral \ -> --load-dir ${META_FORMAT_DIR} \ -> --model-size ${MODEL_SIZE} \ -> --checkpoint-type meta \ -> --tokenizer-model ${TOKENIZER_MODEL} \ -> --saver core \ -> --save-dir ${MEGATRON_FORMAT_DIR} \ -> --target-tensor-parallel-size ${TP} \ -> --target-pipeline-parallel-size ${PP} \ -> --bf16 -``` - -Valid values for `--model-size` are `llama2-7B`, `llama2-13B`, and `llama2-70B` (for pretrained-only models), and `llama2-7Bf`, `llama2-13Bf`, and `llama2-70Bf` (for chat-finetuned models). - ### Huggingface format -The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: - -| Model size | Tensor parallel size (`TP`) | -| ---------- | --------------------------- | -| 7B | 1 | -| 13B | 2 | -| 70B | 8 | - -Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: +The HF checkpoints can be converted to Megatron format by using Megatron-Bridge's checkpoint converter for HF format ([see script](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/examples/conversion/convert_checkpoints.py)).
``` -python tools/checkpoint/convert.py \ -> --model-type GPT \ -> --loader llama_mistral \ -> --load-dir ${HF_FORMAT_DIR} \ -> --model-size ${MODEL_SIZE} \ -> --checkpoint-type hf \ -> --tokenizer-model ${TOKENIZER_MODEL} \ -> --saver core \ -> --save-dir ${MEGATRON_FORMAT_DIR} \ -> --target-tensor-parallel-size ${TP} \ -> --target-pipeline-parallel-size ${PP} \ -> --bf16 +python Megatron-Bridge/examples/conversion/convert_checkpoints.py import \ + --hf-model meta-llama/Llama-2-7B \ + --megatron-path ./checkpoints/llama2_7b \ + --torch-dtype bfloat16 \ + --device-map auto ``` After this conversion, we are ready to load the checkpoints into a Megatron GPT model. @@ -144,12 +105,6 @@ If loading for either inference or finetuning, use the following arguments: --attention-softmax-in-fp32 ``` -**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format). - -### Launch Meta - -Meta checkpoints can be launched with: - ### Launch Huggingface Huggingface checkpoints can be launched with: @@ -243,29 +198,14 @@ We recommend passing `--dtype bf16` for training or finetuning. Inference can be ### Huggingface format -The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3.x checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: - -| Model size | Tensor parallel size (`TP`) | -| ---------- | --------------------------- | -| 1B | 1 | -| 3B | 1 | -| 8B | 1 | -| 70B | 8 | - -Using these values for `TP`, along with the path to the Llama-3.x tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: +The HF checkpoints can be converted to Megatron format by using Megatron-Bridge's checkpoint converter for HF format [see script](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/examples/conversion/convert_checkpoints.py). ``` -$>: python tools/checkpoint/convert.py \ - > --bf16 \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver core \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf \ - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} \ - > --model-size llama3 \ +python Megatron-Bridge/examples/conversion/convert_checkpoints.py import \ + --hf-model meta-llama/Llama-3.2-1B \ + --megatron-path ./checkpoints/llama3_2_1b \ + --torch-dtype bfloat16 \ + --device-map auto ``` After this conversion, we are ready to load the checkpoints into a Megatron GPT model. @@ -345,8 +285,6 @@ For Llama3.1 please use the following arguments: --bf16 \ ``` -**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format). - # Mistral-7b Megatron currently supports loading the v0.3 release of Mistral-7b (which does not use sliding window attention and offers a larger 32768 vocabulary) for inference and finetuning. 
Loading these checkpoints consists of several steps: @@ -364,25 +302,17 @@ Users must first apply for access to download the Mistral-7b checkpoints through ## Convert checkpoint format -The HF checkpoints can be converted to Megatron format by using Megatron's own Mistral checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). - -Using the path to the Mistral tokenizer model (downloaded alongside the HF checkpoint), run the following command from the root of your Megatron source code to convert from HF format to the Megatron core format: +The HF checkpoints can be converted to Megatron format by using Megatron-Bridge's checkpoint converter for HF format [see script](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/examples/conversion/convert_checkpoints.py). ``` -$>: python tools/checkpoint/convert.py \ - > --bf16 \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver core \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf \ - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} \ - > --model-size mistral \ +python Megatron-Bridge/examples/conversion/convert_checkpoints.py import \ + --hf-model mistralai/Mistral-7B-Instruct-v0.3 \ + --megatron-path ./checkpoints/mistral_7b \ + --torch-dtype bfloat16 \ + --device-map auto ``` -After this conversion, we are ready to load the checkpoints into a Megatron core GPT model. +After this conversion, we are ready to load the checkpoints into a Megatron GPT model. ## (Optional) Validate checkpoints @@ -424,8 +354,6 @@ If loading for either inference or finetuning, use the following arguments: --num-attention-heads 32 ``` -**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format). - # Other Llama-like model support *Note: Experimental* @@ -438,15 +366,3 @@ It is not expected that the megatron and Huggingface implementations of llama3.x 1. TransformerEngine (TE) uses the model params_dtype inside RMSNorm whereas the Huggingface implementation uses fp32. See for details: 2. Huggingface `transformers` implements the q, k and v projections in self-attention as separate GEMMs whereas Megatron core combines them into a single GEMM for efficiency. This leads to small numerical differences. - -# Using legacy model format - -In all the checkpoint conversion examples used in this document, the saver format `--saver core` is used, signifying that the newer (and recommended) Megatron GPT model class will be used. I.e.: - -- old class: `megatron.legacy.model.gpt_model.GPTModel` -- new class: `megatron.core.models.gpt.gpt_model.GPTModel` - -Using this new format is the recommended approach. However, if your use case requires using the older class (i.e., convert using `--saver legacy`), then when launching training or finetuning, the following args must be added: - -- `--use-legacy-models`: use the older model class -- `--ckpt-format torch`: use the `torch` checkpoint format, which is the only checkpoint format that is compatible with the legacy model format diff --git a/docs/models/multimodal.md b/docs/models/multimodal.md index dce977e261d..07ff76d8d9a 100644 --- a/docs/models/multimodal.md +++ b/docs/models/multimodal.md @@ -18,7 +18,7 @@ Megatron Core supports multimodal models that combine language with vision, audi > **Note**: MIMO is experimental and under active development. The API may change in future releases. 
**Key Features:** -- Arbitrary modality combinations (vision, audio, text, etc.) +- Arbitrary modality combinations (vision, audio, text) - Flexible encoder architecture for different input modalities - Unified embedding space across modalities - Support for both vision-language and audio-vision-language models @@ -42,7 +42,8 @@ See [examples/mimo](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mim ## Diffusion Models -For multimodal diffusion models (image generation, text-to-image, etc.), see [NeMo Diffusion Models](https://github.com/NVIDIA-NeMo/NeMo/tree/main/nemo/collections/diffusion). NeMo provides production-ready implementations of: +For multimodal diffusion models (image generation, text-to-image), refer to [NVIDIA Diffusion Models](https://github.com/NVIDIA-NeMo/DFM/). The NVIDIA Developer Program, NIM, and NeMo offer production-ready implementations of: + - Stable Diffusion variants - Text-to-image generation - Image-to-image translation diff --git a/docs/user-guide/data-loading.md b/docs/user-guide/data-loading.md index 1f0d544317c..b60cd685cf2 100644 --- a/docs/user-guide/data-loading.md +++ b/docs/user-guide/data-loading.md @@ -72,6 +72,8 @@ If your later training job does not set `--global-batch-size`, or you are prepar This keeps the prepared cache aligned with the sample counts expected by training. +> **Unsupported configurations:** `tools/prepare_cache.py` does not support `--mock-data`, `--sft`, `--fim-data`, or `--step-batch-size-schedule`. Using any of these will cause the script to exit with an error. + ### Step 3: Optionally pre-build per-dataset metadata When blending many datasets, generate the `--per-dataset-sequences-path` JSON ahead of time to avoid one metadata read per file prefix at startup: diff --git a/docs/user-guide/features/custom_fsdp.md b/docs/user-guide/features/custom_fsdp.md deleted file mode 100644 index 3c774e5f493..00000000000 --- a/docs/user-guide/features/custom_fsdp.md +++ /dev/null @@ -1,196 +0,0 @@ - - -# Megatron FSDP - -**Note: In M-Core 0.14, the custom FSDP refactored its checkpoint implementation to use DTensor-based PyTorch distributed checkpointing. The custom FSDP was also renamed Megatron FSDP. The relevant sections of this document are no longer applicable.** - -## How to Use Megatron FSDP - -Add these flags to enable MCore custom FSDP. - -```bash ---use-megatron-fsdp ---data-parallel-sharding-strategy optim_grads_params ---no-gradient-accumulation-fusion ---use-distributed-optimizer -``` - -For a practical guide covering required configurations, checkpoint conversion, and example scripts, refer to the [Megatron-FSDP User Guide](../../discussions/megatron-fsdp-user-guide/megatron-fsdp-user-guide.md). - -## Key Features - -- **Sharding Strategy**: Shards optimizer states, gradients, and parameters to reduce memory consumption. -- **Communication and Computation Overlap**: Overlaps communication with computation where possible during training. -- **Supports automatic mixed precision training**: Compatible with BF16 O1/O2/O3 recipes, as well as FP8 compute with FP32 parameters and FP8 parameter training, with several precision configuration options. -- **Tensor Parallelism (TP), Expert Parallelism (EP), and Context Parallelism (CP)**: Compatible with TP, EP, and CP configurations for scaling large language models.
-- **Distributed Model Initialization with Meta Device**: Allows model initialization using the meta device, then layer-by-layer initialization of distributed model weight buffers through the `Module.reset_parameters` API, which supports very large models. - -## Configuration Recommendations - -### Disable `CUDA_DEVICE_MAX_CONNECTIONS` - -To ensure full parallelization of FSDP communication and computation, disable the CUDA_DEVICE_MAX_CONNECTIONS environment variable. This step avoids potential bubble in CUDA stream. (But it may slow down TP and CP to some extent.) - -```bash -unset CUDA_DEVICE_MAX_CONNECTIONS -``` - -### Add `--calculate-per-token-loss` - -For gradients sharding mode optimization, include the `--calculate-per-token-loss` flag in your training script. This improves performance by reducing the frequency of gradient scaling, which is also a sizable drain on SM resources. - -## Design of Custom FSDP - -### Overview - -The custom Fully Sharded Data Parallelism (FSDP) implementation in Megatron Core targets memory use and throughput for large language models. The core design principles include: - - - **Optimized for Large Language Models**: This custom FSDP implementation scales with models containing billions of parameters and supports training at that scale. - - **Efficient Memory Consumption**: By sharding optimizer states, gradients, and model parameters, the custom FSDP cuts memory use so models that would not fit on device with plain DDP can train. - - **Efficient Workflow and Overlapping Communication and Computation**: The implementation reduces communication steps during training and overlaps communication with computation where possible to reduce idle time. - - **Support for MCore's Efficient Training Methods**: The custom FSDP integrates with advanced parallelism in Megatron Core, including tensor parallelism, expert parallelism, and context parallelism. It also supports automatic mixed precision training. - -The design of Custom FSDP draws inspiration from PyTorch FSDP [Zhao, Yanli, et al.](https://arxiv.org/pdf/2304.11277) and the MCore distributed optimizer. The following background on PyTorch FSDP clarifies the concepts behind the custom FSDP design. - -> In DistributedDataParallel, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks. - -> When training with FSDP, the GPU memory footprint is smaller than when training with DDP across all workers. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. This comes with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like overlapping communication and computation. - -![Diagram of FSDP workflow showing all-gather, forward, discard, backward, and reduce-scatter across ranks](../../images/custom_fsdp/FSDP_workflow.png) - -The unit processed in the workflow here is the "FSDP instance 1: N layers", where an FSDP instance is the smallest FSDP processing unit (also a PyTorch module). You can release this module's weights after its forward or backward because no other computation depends on those weights. That behavior supports FSDP's layer-by-layer execution and memory-saving strategy. 
An FSDP instance is also called an **FSDP Unit**. - -An FSDP instance can correspond to multiple FSDP parameter groups. These groups are separated by Data Parallel (DP) communication groups and the data type of the parameter or gradient. Consequently, an FSDP instance may require several parameter-gather tasks before execution (forward or backward). Each **FSDP parameter group** corresponds to one **Data Parallel Buffer** in custom FSDP. - -At a high level, FSDP works as follows: - -In constructor: - - Shard model parameters and each rank only keeps its own shard - -In forward path: - - Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit - - Run forward computation - - Discard parameter shards it has just collected - -In backward path: - - Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit - - Run backward computation - - Run reduce_scatter to sync gradients - - Discard parameters - -One way to view FSDP’s sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards. - -![Diagram comparing DDP all-reduce with FSDP reduce-scatter and all-gather](../../images/custom_fsdp/FSDP_Allreduce.png) - -### Custom FSDP Underlying Data Structure - -To implement the FSDP functionality described above, the custom FSDP is designed with the following Python classes and data structure: - -![Class diagram of Megatron Core custom FSDP Python types](../../images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png) - -### The Custom FSDP Interface: FullyShardedDataParallel - -The custom FSDP provides the same programming interface as PyTorch's DistributedDataParallel (DDP) as FullyShardedDataParallel (FSDP). For example, you can apply FSDP to models as follows: - -```python -# Initialize model and optimizer -ddp_config.use_megatron_fsdp = True -ddp_config.data_parallel_sharding_strategy = "optim_grads_params" -model = GPTModel(transformer_config) -model = FullyShardedDataParallel( - transformer_config, - model, - ddp_config, - fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding], -) -optimizer = torch.optim.AdamW(model.parameters(), lr=lr) -optimizer = DistributedOptimizer(optimizer, [model], [model.param_and_grad_buffer]) - -# Training loop -def train_step(inputs, labels): - optimizer.zero_grad() - for mbs_input, mbs_label in zip(inputs, labels): - outputs = model(mbs_input) - loss = loss_fn(outputs, mbs_label) - loss.backward() - optimizer.step() - -# Save and load model and optimizer state dict -def model_and_optimizer_state_dict(): - state_dict = { - "model": model.sharded_state_dict(), - "optimizer": optimizer.sharded_state_dict(), - } - return state_dict - -def load_model_and_optimizer_state_dict(state_dict): - model.load_state_dict(state_dict["model"]) - optimizer.load_state_dict(state_dict["optimizer"]) -``` - -Key notes: - - - You can configure which modules should be treated as FSDP units through the `fsdp_unit_modules` argument. This configuration is mandatory. - - The custom FSDP must be used with a distributed optimizer since it provides distributed checkpointing. 
- - The data-parallel communication group for parameters is not explicitly shown. Custom FSDP configures these groups as either DP (data-parallel) or EDP (expert data-parallel) based on parameter markings. - -#### Initializing Models on the Meta Device - -For training particularly large models with FSDP, you can initialize the model on the meta device. Using PyTorch's `reset_parameters` API, you can initialize model weights layer by layer during the construction of the `ParamAndGradBuffer`. Most PyTorch native modules and TransformerEngine modules support this API (for example, [PyTorch Linear](https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/modules/linear.py#L114), [TE LayerNormLinear](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.0/transformer_engine/pytorch/module/layernorm_linear.py#L1107)). - -```python -# Initialize model on meta device -with torch.device("meta"): - model = GPTModel(config) - -model = FullyShardedDataParallel( - transformer_config, - model, - ddp_config, - fsdp_unit_modules=[TransformerLayer, LanguageModelEmbedding], -) -``` - -**Important Considerations:** -- *Custom Modules*: If your model contains custom modules, ensure they implement the `reset_parameters` API. Otherwise, you may need to force parameter initialization on a CUDA or CPU device. -- *Tensor Initialization*: Be cautious of tensors created during model initialization without a specified device; they default to the meta device. To avoid issues, explicitly specify the device for these tensors to ensure compatibility with this function. - -### Interaction Between Custom FSDP and Model Forward/Backward Propagation - -Custom FSDP implements Fully Sharded Data Parallelism (FSDP) through a series of module hooks, gradient hooks, or by adding functions between modules. This involves inserting communications and manipulating parameters and gradients during PyTorch's module forward or backward propagation. - -Module hooks summary: -- Module pre-forward hook(`module.register_forward_pre_hook`): This hook unshards model weights before the forward pass. In the case of an FSDP Unit Module, add a RegisterFSDPBackwardFunction function that will reshard model weights and reduce gradients after module backward propagation. -- Module post-forward hook(`module.register_forward_hook`): This hook is used to reshard model weights after the forward pass. -- Root module pre-backward hook(`root_module.register_full_backward_pre_hook`): This hook checks that all model parameters are resharded, in order to avoid unnecessary memory spikes. It also marks all modules as being in the `TrainingState.PRE_BACKWARD` state. -- Module pre-backward hook(`module.register_full_backward_pre_hook`): This hook is used to unshard the model weights before the backward pass. -- Root module post-backward hook(`torch.autograd.Variable._execution_engine.queue_callback`): This hook is used to make sure all gradients in the backprop are properly handled / available. - -The gradient reduction pipeline maintains a map of gradients to FSDP parameter groups. If all gradients in an FSDP parameter group are ready, it launches a gradient reduction. Note that this assumes that the model's gradients are always generated in a certain order (reverse of `module.parameters()`), as otherwise, FSDP would maintain too many parameter group grad buffers, leading to excessive memory usage. 
- -#### Optimized for Activation Recompute - -Using activation recomputation runs the same module forward first and then its backward during backprop, which can unshard and reshard model weights twice. If the runtime can treat that as one forward-plus-backward region, it can unshard once and reshard once. - -To make this determination, the implementation tracks the model state with `training_state`: `FORWARD`, `PRE_BACKWARD`, `POST_BACKWARD`, `IDLE`. It is worth noting that the pre-backward hook runs before the pre-forward hook: the pre-backward hook performs the model weight unshard, then marks the model as `PRE_BACKWARD`, and when the pre-forward hook observes that mark it skips unshard. Similarly, for duplicate reshard logic, the post-forward hook runs before the post-backward path, and checking for the `PRE_BACKWARD` flag in the post-forward hook can cancel unshard. - -### Memory Mechanisms and Features of Custom FSDP - -FSDP can distribute model parameters, gradients, optimizer states, and (for mixed-precision training) high-precision main weights. That covers most memory outside activations, but FSDP can still hit allocator and spike issues. - -FSDP frequently unshards and reshards model weights, which can lead to busy memory allocation and deallocation. This results in untimely tensor releases, causing memory spikes (or even out-of-memory errors), crashes of the PyTorch memory allocator cache, and many `cudaMalloc` and `cudaFree` calls. These issues can slow the system noticeably. - -You can often address untimely tensor release with the `tensor._typed_storage()._resize_(0)` API, which deallocates storage immediately. Custom FSDP exposes hooks in `AllGatherPipeline` and `GradReducePipeline` to swap the temporary buffer allocator used for parameter gathering and gradient reduction with `StorageResizeBasedBucketAllocator`, using that `_resize_(0)` path for releases. - -The PyTorch memory allocator cache can fail when real usage nears GPU capacity, which hurts performance. Mitigation is limited; avoiding repeated pressure at the memory limit helps. A self-managed allocator such as `RotaryBucketAllocator` is another option, though it is not yet mature. - -## References - -- [Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) diff --git a/docs/user-guide/features/index.md b/docs/user-guide/features/index.md index 514568afac9..461e1e400b6 100644 --- a/docs/user-guide/features/index.md +++ b/docs/user-guide/features/index.md @@ -17,7 +17,7 @@ Guides for Megatron Core training features. fine_grained_activation_offloading moe context_parallel -custom_fsdp +megatron_fsdp dist_optimizer optimizer_cpu_offload paged_stash diff --git a/docs/user-guide/features/megatron_fsdp.md b/docs/user-guide/features/megatron_fsdp.md new file mode 100644 index 00000000000..36fcc68893c --- /dev/null +++ b/docs/user-guide/features/megatron_fsdp.md @@ -0,0 +1,608 @@ + + +# Megatron-FSDP + +## ✨ Overview + +**Megatron-FSDP** is an NVIDIA-developed distributed parallelism library written in native PyTorch that provides a high-performance implementation of **Fully Sharded Data Parallelism (FSDP)**. It offers seamless cross-compatibility with various deep learning frameworks and parallelism libraries such as Megatron-Core, and is performance-optimized to support training and inference of extremely large PyTorch models at data-center scale on NVIDIA GPUs. 
+ +- PyPI: https://pypi.org/project/megatron-fsdp/ +- Source Code: https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/distributed/fsdp/src + +### 🧩 Compatibility + +- PyTorch **[DeviceMesh](https://docs.pytorch.org/docs/2.11/distributed.html#torch.distributed.device_mesh.DeviceMesh)**, **[DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html)**, and **[Distributed Checkpoint (DCP)](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html)** +- **[Megatron Core](https://github.com/NVIDIA/Megatron-LM)** +- **[TransformerEngine](https://github.com/NVIDIA/TransformerEngine)** +- **[NVIDIA NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)** + +### 💡 Features + +- **Performant & Scalable**: Optimized for NVIDIA CUDA with efficient memory management and performance. Sports near-linear scaling up from single compute nodes to entire data-centers. +- **Multiple Algorithms in One**: Supports sharding your choice of optimizer states, gradients, and model parameters (FSDP), including hierarchical data parallelism strategies such as **Hybrid-Sharded Data Parallelism (HSDP)** and **Hybrid-FSDP (HFSDP / Fully-Sharded Optimizer State)** for optimizing intra-node and inter-node memory, communication, and performance. +- **"Bring Your Own Parallelism"**: Works seamlessly with PyTorch, Megatron-LM, Megatron-Bridge, and TransformerEngine, and can be plugged into other frameworks such as HuggingFace Transformers and TorchTitan. +- **Simple & Powerful**: Similar to PyTorch FSDP, the `fully_shard` API doesn't depend on any complex training framework or distributed environment. + +### ⏱️ Optimizations + +- **[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) Mixed-Precision & Fused Kernels**: Native performance- and memory-optimal _compatibility with MXFP8, NVFP4, and various other quantization recipes and fused kernels provided by TransformerEngine_. +- **Advanced Bucketing**: `dtype`-customizable and precision-aware bucketing system to _tune the memory overhead, numerical accuracy, and latency of collectives_. Avoids redundant `COPY` operations before and after collectives, while remaining compatible with **[DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html)** features such as **[Torch Distributed Checkpoint (DCP)](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html)**. +- **Buffer Management**: Efficient use of storage and [NCCL User Buffer Registration](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#user-buffer-registration) enable _direct communication into NCCL-managed memory_, achieving true zero-`COPY` data movement. Introduced in NCCL `v2.27`, **NCCL Symmetric Memory** communications employ _symmetric kernels_ that drastically reduce SM utilization and include networking optimizations such as high-precision (`FP32`) reduction over-the-wire. +- **Optimized Communication & SM Utilization via SHARP**: Leverages [**SHARP** (Scalable Hierarchical Aggregation and Reduction Protocol)](https://docs.nvidia.com/networking/display/sharpv3130) to _offload FSDP collectives to network switches (InfiniBand or NVLink-Switch)_ and significantly reduce utilization of GPU streaming multi-processors (SM) from 16-32 to 1-6 for **Multi-Node NVLink (MNNVL)** systems (Grace-Blackwell, Vera-Rubin, etc.), which lowers communication latency in large scaled-out workloads and frees up GPU-hosted processors for overlapped compute (GEMM) kernels. 
When FSDP sharding domains span both NVLink and InfiniBand, **hierarchical SHARP collectives** (NVL-SHARP and IB-SHARP) _optimize communication paths across the entire system topology_. +- [**Hybrid-FSDP (HFSDP)**](#understanding-hybrid-fsdp-hfsdp), a variation of _Hybrid-Sharded Data Parallelism (HSDP)_ that further shards the optimizer state across intra- and inter-node data-parallel ranks, _bridges the memory-communication trade-off between HSDP and FSDP_, unlocking memory efficiency at minimal cost to performance. + +## 🚀 Quick Start + +### 📦 Installation + +#### NeMo Framework Container + +Megatron-FSDP is pre-installed with Megatron-Core in the [NVIDIA NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). + +#### Megatron-Core + +Megatron-FSDP is bundled with Megatron-Core, which can be installed via `pip`: + +``` +# Install via PyPI +pip install --no-build-isolation megatron-core[mlm,dev] + +# Install from Source +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +pip install --no-build-isolation .[mlm,dev] +``` + +To import Megatron-FSDP in Python: +```python +import megatron.core.distributed.fsdp.src.megatron_fsdp +``` + +#### PyPI + +To install Megatron-FSDP as a standalone package to use the `fully_shard` API: + +``` +pip install megatron-fsdp +``` + +To import Megatron-FSDP in Python: + +```python +import megatron_fsdp +``` + +### 🎛️ Megatron-FSDP `fully_shard` + +Megatron-FSDP supports a simple `fully_shard` API that seamlessly enables FSDP with very few lines of code. + +```python +import torch +from megatron_fsdp import ( + fully_shard_model, + fully_shard_optimizer, +) + +# Initialize Torch Distributed. +torch.distributed.init_process_group() +torch.cuda.set_device(torch.distributed.get_rank()) + +# Fully-shard the model. +model = torch.nn.Transformer() +fsdp_model = fully_shard_model( + module=model, + fsdp_unit_modules=[ + torch.nn.TransformerEncoder, + torch.nn.TransformerDecoder + ] +) + +# Fully-shard the optimizer. +toy_adam = torch.optim.AdamW(params=fsdp_model.parameters(), lr=0.01) +optimizer = fully_shard_optimizer(optimizer=toy_adam) + +# Forward pass. +inp = torch.randn(1, 512, 512).to("cuda") +tgt = torch.randn(1, 512, 512).to("cuda") +output = fsdp_model(inp, inp) + +# Backward pass. +torch.nn.functional.mse_loss(output, tgt).backward() + +# Optimizer step. +optimizer.step() +optimizer.zero_grad() + +# Checkpoint the model and optimizer. +torch.distributed.checkpoint.save({ + "model": fsdp_model.state_dict(), + "optimizer": optimizer.state_dict(), +}, checkpoint_id="ckpt/") + +# Load the saved checkpoint. +ckpt = { + "model": fsdp_model.state_dict(), + "optimizer": optimizer.state_dict(), +} +torch.distributed.checkpoint.load(state_dict=ckpt, checkpoint_id="ckpt/") +fsdp_model.load_state_dict(ckpt["model"], strict=False) +optimizer.load_state_dict(ckpt["optimizer"]) +``` + +> ℹ️ `fully_shard` is an _**experimental**_ API. Please check back for updates as we fine-tune our user experience! For more examples using `fully_shard` for Megatron-FSDP, refer to our suite of unit tests: [`tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py`](../../../tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py) + +### 🤖 Megatron-LM + +Megatron-FSDP is deeply integrated into Megatron-Core. To enable FSDP (where optimizer states, gradients, and compute parameters are sharded) in Megatron, use the following arguments: + +``` +# Train models in Megatron-LM using Megatron-FSDP. 
+--use-megatron-fsdp +--data-parallel-sharding-strategy {no_shard, optim, optim_grads, optim_grads_params} +--ckpt-format fsdp_dtensor +``` + +Complete Llama-8B and DeepSeek-V3 training scripts using Megatron-FSDP with recommended settings can be found in [Megatron-LM/examples/megatron_fsdp](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/megatron_fsdp). + +#### Recommended Configuration for Megatron-LM + +Frequently-used options for Megatron-FSDP include: + +```bash +# Un-set CUDA_DEVICE_MAX_CONNECTIONS to ensure stream independence / full-parallelization of FSDP computation and communication. May slightly affect TP and CP performance though. +unset CUDA_DEVICE_MAX_CONNECTIONS + +# Meta-Device Initialization - Load large model onto CUDA devices in shards to avoid OOM. +--init-model-with-meta-device + +# Per-Token Loss / No Gradient Scaling - Deactivate DP scaling during gradient reduction, which can be a drain on SM resources. +--calculate-per-token-loss + +# Decrease gradient reduction and accumulation precision to recommended data-types based on the precision of the model parameters, usually BF16. Reduces communication volume during the backwards pass. Can be further customized with `--megatron-fsdp-main-grads-dtype` and `--megatron-fsdp-grad-comm-dtype`, which are enabled by this argument. +--grad-reduce-in-bf16 + +# Register NCCL user buffers and Megatron-FSDP double buffers to enable zero-copy symmetric kernels and low-SM utilization via SHARP. Improves overall performance but increases memory overhead due to double-buffering and is NOT compatible with `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. +--use-nccl-ub +--fsdp-double-buffer +--fsdp-manual-registration +``` + +### 🤖 Megatron-Core + +Megatron-FSDP has a lower-level `FullyShardedDataParallel` class API that can be used with a simplified version of Megatron-LM's training loop. + +```python +# Initialize model and optimizer. +ddp_config.use_megatron_fsdp = True +# Megatron-FSDP Base Sharding Strategies: +# no_shard, optim, optim_grads, optim_grads_params +ddp_config.data_parallel_sharding_strategy = "optim_grads_params" +model = GPTModel(transformer_config) +model = FullyShardedDataParallel( + transformer_config, + model, + ddp_config, + fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding], +) +optimizer = torch.optim.AdamW(model.parameters(), lr=lr) +optimizer = DistributedOptimizer(optimizer, [model], [model.param_and_grad_buffer]) + +# Training loop +def train_step(inputs, labels): + optimizer.zero_grad() + for mbs_input, mbs_label in zip(inputs, labels): + outputs = model(mbs_input) + loss = loss_fn(outputs, mbs_label) + loss.backward() + optimizer.step() + +# Save and load model and optimizer state dict +def model_and_optimizer_state_dict(): + state_dict = { + "model": model.sharded_state_dict(), + "optimizer": optimizer.sharded_state_dict(), + } + return state_dict + +def load_model_and_optimizer_state_dict(state_dict): + model.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optimizer"]) +``` + +### 🔁 Checkpoint Conversion + +Megatron-FSDP checkpointing supports [PyTorch Distributed Checkpoint (DCP)](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html). In Megatron-LM, this is the `--ckpt-format fsdp_dtensor` checkpointing format.
+ +#### Converting Torch DCP to Torch Save (Non-Distributed) Checkpoints + +PyTorch has utilities to convert Torch DCP checkpoints to and from regular Torch checkpoints: +```shell +python -m torch.distributed.checkpoint.format_utils --help +usage: format_utils.py [-h] {torch_to_dcp,dcp_to_torch} src dst + +positional arguments: + {torch_to_dcp,dcp_to_torch} + Conversion mode + src Path to the source model + dst Path to the destination model + +options: + -h, --help show this help message and exit +``` +For example: +```shell +python -m torch.distributed.checkpoint.format_utils dcp_to_torch dcp_ckpt/ torch_ckpt.pt +``` +or: +```python +from torch.distributed.checkpoint.format_utils import ( + dcp_to_torch_save, + torch_save_to_dcp, +) + +# Convert DCP model checkpoint to torch.save format. +dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_PATH) + +# Convert torch.save model checkpoint back to DCP format. +torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_PATH, f"{CHECKPOINT_DIR}_new") +``` +Torch Save checkpoints can then be converted into HuggingFace SafeTensors or other checkpoint formats for distribution. + +> ℹ️ Megatron-FSDP checkpoints have a `module.` prefix pre-pended to all model parameter names in the state dictionary, and converting a Torch Save checkpoint to a Megatron-FSDP Torch DCP checkpoint requires testing. Work-in-progress! + +#### Converting N-D Parallel (`torch_dist`) to Megatron-FSDP (`fsdp_dtensor`) Checkpoints + +As a pre-requisite for checkpoint conversion, dump the parameter group mapping when training with 3D-parallel (DDP, TP, PP) and/or EP: + +```bash +--dump-param-to-param-group-map /path/to/param_to_param_group_map +``` + +and convert the map to a `param_to_param_group_map.json` JSON file in the `/path/to/param_to_param_group_map` directory: + +```bash +python tools/checkpoint/checkpoint_inspector.py print-torch-dcp-in-json /path/to/param_to_param_group_map +``` + +> ℹ️ If you already have a `torch_dist` checkpoint, simply specify the `--dump-param-to-param-group-map /path/to/param_to_param_group_map` flag and run a trivial training or checkpointing experiment to create the `param_to_param_group_map` you need without full pretraining. + +Finally, convert your `torch_dist` checkpoint to the `fsdp_dtensor` format using the `param_to_param_group_map.json`: + +```bash +torchrun --nproc_per_node=8 --nnodes=1 \ + tools/checkpoint/checkpoint_inspector.py \ + convert-torch-dist-to-fsdp-dtensor (--swiglu) \ # --swiglu for specific models. + /path/to/input_torch_dist_checkpoint/ \ + /path/to/output_fsdp_dtensor_checkpoint/ \ + --param-to-param-group-map-json /path/to/param_to_param_group_map.json +``` + +> ℹ️ For multi-node conversion tasks, please refer to the DeepSeek-V3 example script (`sbatch_checkpoint_convert.sh`) in [Megatron-LM/examples/megatron_fsdp](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/megatron_fsdp). + +## Megatron-FSDP Feature Guide & API + +| Optimization | Description | `Megatron-Core` Config | `fully_shard` Config | +|--------------|-------------|----------------------|----------------------| +| **Megatron-FSDP** | Use Megatron-FSDP in Megatron-LM. | `--use-megatron-fsdp` | `fully_shard_model(module)` | +| **Megatron-FSDP Checkpointing** | Save and load un-even DTensor checkpoints using [Torch Distributed Checkpoint (DCP)](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html). 
| `--ckpt-format fsdp_dtensor` | `preproc_state_dict_for_dcp_ckpt=True` | +| **Meta Device Initialization** | Megatron-FSDP initializes a meta-initialized model to the CUDA device in shards to avoid OOM on large models. Requires implementation of `Module.reset_parameters()` for per-Module sharded initialization. | `--init-model-with-meta-device` | `init_model_with_meta_device=True` | +| **Distributed Optimizer** | Megatron-FSDP uses Megatron-Core's `DistributedOptimizer`. Automatically set when using Megatron-FSDP. | `--use-distributed-optimizer` | `fully_shard_optimizer(optimizer)` | + +### FSDP Fundamentals + +```{figure} ../../images/megatron_fsdp/DDP_vs_FSDP.png +:alt: FSDP Pipeline +:align: center + +Comparison between Distributed Data Parallelism (DDP) and Fully-Sharded Data Parallelism (FSDP). While gradients are all-reduced in DDP, they are sharded and reduce-scattered with FSDP. + +Source: Meta AI, Ott, Myle, et al. “Fully Sharded Data Parallel: Faster AI Training with Fewer GPUs.” _Facebook Engineering_, 15 July 2021, https://engineering.fb.com/2021/07/15/open-source/fsdp/. +``` + +**Fully Sharded Data Parallelism (FSDP)** is a type of distributed data parallelism (DDP) that shards optimizer state, weight gradients (`wgrad`), and model weights across devices that ingest data-parallel samples for data-parallel training or inference. Activations (`fprop`) and data gradients (`dgrad`) are not sharded or distributed, and are preserved for the backward pass, but can be recomputed during the backward pass, offloaded to CPU, or sharded / routed using other parallelisms such as tensor parallelism (TP), context parallelism (CP), or expert parallelism (EP). + +```{figure} ../../images/megatron_fsdp/zero3_model_state.png +:alt: ZeRO-3 Model State +:align: center + +Sharded memory profiles for ZeRO-1 (optimizer state), ZeRO-2 (optimizer state and gradients), and ZeRO-3 (optimizer state, gradients, and parameters). + +Source: Zero-Redundancy Optimizer Model State Partition Diagram. From _The Ultra-Scale Playbook: Training LLMs on GPU Clusters_ by Tazi, Nouamane, et al. HuggingFace, 2025, https://huggingface.co/spaces/nanotron/ultrascale-playbook. +``` + +The core principles of FSDP are: + +- Only a small depth-wise fraction of the model state can exist un-sharded at any point in time. +- Communication should overlap computation. + +From these core principles, software requirements can be derived: + +0. Model states sharded by FSDP are directly initialized across devices in shards. +1. Model parameters are all-gathered (AG) in pre-designated groups or modules pre-forward and pre-backward to un-shard a small fraction of the model state at any point in time during training or inference. After `fprop` and `dgrad` computation, the un-sharded weights are immediately de-allocated. +2. `wgrad` are reduce-scattered (RS) and accumulated in pre-designated groups or modules immediately post-backward to limit the amount of un-sharded gradients at any point in time during training or inference. +3. Distributed optimizers, optimizers that are initialized with respect to a sharded model state and support distributed mechanics, update the sharded model state using the reduced gradient shard to implement data parallelism (DP). +4. Computation and communication are overlapped across multiple CUDA streams, expending multiple streaming multi-processors (SM). 
Weights from subsequent groups or modules are pre-fetched, which ideally hides the communication latency required for FSDP behind model computation kernels (GEMM). + +FSDP can also be visualized as a decomposition of the all-reduce collective used in DDP into a gradient reduce-scatter, distributed optimization step, and parameter all-gather. + +```{figure} ../../images/megatron_fsdp/FSDP_Allreduce.png +:alt: FSDP RS & AG +:align: center + +Source: Feng, Wei, Will Constable, and Yifan Mao. “Getting Started with Fully Sharded Data Parallel (FSDP2).” _PyTorch Tutorials_, 17 Mar. 2022, https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html. +``` + +### FSDP Unit Modules + +| Optimization | Description | `Megatron-Core` Config | `fully_shard` Config | +|--------------|-------------|----------------------|----------------------| +| **FSDP Unit Modules** | A list of `str` or `class` import paths for `torch.nn.Module`(s) that are considered FSDP unit modules and sharded by Megatron-FSDP. Parameters and sub-modules that are not members of an FSDP unit are not sharded. | Defaults to supported Megatron-Core modules (`TransformerLayer`, etc.) in Megatron-LM. | `fsdp_unit_modules=[...]` | +| **FSDP Double Buffer Allocator** | Megatron-FSDP uses the double-buffer allocator, which persistently allocates a buffer pair assigned to alternating FSDP units that temporarily stores parameters and gradients. Automatically used with NCCL user buffer registration. | `--fsdp-double-buffer` | `fsdp_double_buffer=True` | +| **Param All-Gather Overlap** | Whether to overlap parameter all-gather with compute. Automatically activated for the ZeRO-3 sharding strategy. | `--overlap-param-gather` | `overlap_param_gather=True` | +| **Gradient Reduce-Scatter Overlap** | Whether to overlap gradient reduce-scatter or all-reduce with compute. Automatically activated for ZeRO-2 and ZeRO-3 sharding strategies. | `--overlap-grad-reduce` | `overlap_grad_reduce=True` | +| **FSDP Communication Size** | Customize the size (in `numel()` elements) of AG and RS communications in Megatron-FSDP, by limiting how many elements are concurrently pre-fetched or reduced for AG and RS. Effectively suggests how many FSDP units are processed concurrently, which may launch collectives earlier and improve performance. Optionally, tune this value depending on system memory and performance requirements. | `--suggested-communication-unit-size ` | N/A (Megatron-Core Only) | + +> Only a small depth-wise fraction of the model state can exist un-sharded at any point in time. + +**FSDP Unit Modules** represent fractions of the model state that are computed and communicated as a (coalesced) group, un-sharded when needed for computation, and re-sharded after computation to release memory for subsequent model states. Implicitly, an FSDP unit module is also a **_modeling contract_**, requiring that FSDP-managed unit module parameters are not accessed or modified beyond the scope of the forward pass, backward pass, or optimization step. + +Megatron-FSDP accepts a list of `str` or `class` paths representing FSDP unit modules via the `fsdp_unit_modules` argument, which is currently hard-coded to supported model classes (like `TransformerLayer`) in Megatron-Core. It performs a depth-first traversal of the model (via `torch.nn.Module.named_modules()`) and groups the parameters of each matching module for sharding and coalesced communication. 
Nested units are resolved by precedence: if a module matches an FSDP unit class but is already a sub-module of a previously registered FSDP unit, it is skipped, so the outermost (and necessarily largest) FSDP unit class in any module sub-tree becomes the effective FSDP unit module. + +> Communication should overlap computation. + +Once a model is partitioned into unit modules, computation is overlapped with communication based on the granularity of the FSDP unit module. Depending on the size of the compute and communication kernels, fine-tuning the unit module size and grouping configuration can impact performance and elicit trade-offs between overlap and memory when using FSDP. + +```{figure} ../../images/megatron_fsdp/fsdp_streams.png +:alt: FSDP Streams +:align: center + +Each color-coded block in the compute and communication streams, merged and categorized in the simplified (and worst-case) scenario where SM resources are under contention, correspond to a _single_ FSDP unit module. +``` + +Compute-communication overlaps are orchestrated using **CUDA streams** that capture and parallelize serial operations. All collectives associated with all combinations of `{DP-Inner, DP-Outer}` and `{AG, RS}` are scheduled and tracked with separate streams and communicators / `ProcessGroup`(s). + +- Parameters are un-sharded prior to `fprop` and `dgrad` computation. To overlap the pre-fetch all-gather with computation, at least two FSDP units worth of un-sharded weight memory is required at any point in time. +- Gradients are reduced and sharded after `wgrad` computation. To overlap gradient reduce-scatter with `wgrad` computation, at least two FSDP units worth of un-sharded gradient memory is required at any point in time. + +#### FSDP Module Hooks + +To implement these "unit-periodic" mechanics, Megatron-FSDP uses `Module` hooks to install a variety of (pre- and post-) forward and backward operations: + +- **Pre-Forward** + - Un-shards the model parameters of the current and (via pre-fetching) forward-subsequent FSDP unit modules. + - When `MegatronFSDP.forward()` is invoked, Megatron-FSDP will swap all parameter references to point to the un-sharded `Tensor` compute weights for the forward and backward pass. +- **Post-Forward** + - Re-shards model weights after the forward pass, if the module is an FSDP unit. Non-unit modules remain persistently un-sharded. + - When using activation recomputation during the backwards pass, computing both `fprop` and `dgrad` requires these parameters, so parameters are resharded during **Post-Backward**. + - Releases the transpose cache of quantized parameters (in FSDP / ZeRO-3) for specific quantization recipes in `TransformerEngine`. +- **Pre-Backward** + - Un-shards the model parameters of the current and (via pre-fetching) backward-subsequent FSDP unit modules. + - Implemented as a `torch.autograd.graph.register_multi_grad_hook` triggered by the output `dgrad`, and installed via a `Module` _post-forward_ hook. +- **Post-Backward** + - Re-shards model weights after the backward pass, if the module is an FSDP unit. Non-unit modules remain persistently un-sharded. + - Implemented by injecting an Autograd function (`RegisterFSDPBackwardFunction`) that is installed during a `Module` _pre-forward_ hook. + - Reduces gradients after the backward pass. 
+ - Implemented using a `Tensor.register_post_accumulate_grad_hook` triggered by `param.grad`, as well as a root-level post-backward hook installed during **Pre-Backward** (`torch.autograd.Variable._execution_engine.queue_callback`). +- **State Dictionary** + - When `module.state_dict()` (for any module managed by Megatron-FSDP) is invoked, Megatron-FSDP will swap all parameter references to point to sharded `DTensor` main weights for distributed optimization and checkpointing. + - When `MegatronFSDP.load_state_dict()` is invoked, both the main and compute weights are updated. When using quantized model compute, the main weights are quantized and sharded. + +#### Double Buffering + +Megatron-FSDP uses a `Tensor._typed_storage()._resize_(bytes)`-based allocator to instantly allocate and de-allocate memory without depending on the `CUDACachingAllocator` for un-sharded parameters and gradients by default. (Cache fragmentation and garbage collection can procrastinate large quantities of `cudaMalloc` and `cudaFree` operations that can block programs and spike memory, particularly when memory utilization is maxed out.) However, modifying the underlying storage of a buffer is not compatible with NCCL symmetric registration or CUDA graphability, which require a persistent state during runtime. + +To support these optimizations, Megatron-FSDP uses **double-buffering**, which assigns 2 persistently-allocated buffers to FSDP units in an alternating pattern, hard-limiting the memory overhead for parameter and gradient buffer allocation and ensuring that no more than 2 FSDP units are computed or communicated concurrently. + +```{figure} ../../images/megatron_fsdp/fsdp_double_buffer.png +:alt: FSDP Double Buffering +:align: center + +Visualization of double buffering in Megatron-FSDP. Even- and odd-indexed FSDP units share the same un-sharded parameter and gradient buffers, overwriting incumbent data as needed during runtime. Megatron-FSDP ensures that no more than two FSDP units are un-sharded at any point during runtime. +``` + +With double-buffering, Megatron-FSDP does not need to allocate memory after initialization, which can reduce memory fragmentation and improve performance. However, double-buffering requires _depth-wise model symmetry_, where even- and odd-indexed FSDP units have identical size during runtime. If double-buffering is utilized, Megatron-FSDP computes the **_mode_** of FSDP unit sizes as the symmetrical double-buffer size, and any FSDP units not symmetrical to the computed size will default to the `_resize_(bytes)`-based allocator (or persistently allocated for extremely large and asymmetrical layers that affect performance significantly like `torch.nn.Embedding` when the low-level argument `fsdp_db_use_persist_buf_on_alloc_fail` is set). + +### Data-Parallel Sharding Strategies + +| Optimization | Description | `Megatron-Core` Config | `fully_shard` Config | +|--------------|-------------|----------------------|----------------------| +| **Data Parallel Sharding Strategy** | Primary data-parallel sharding strategy for FSDP, which supports DDP, ZeRO-1 (optimizer), ZeRO-2 (optimizer and gradients), and ZeRO-3 (optimizer, gradients, and parameters). Typically uses intra-node communications, i.e. "inner" or "intra" DP. 
| `--data-parallel-sharding-strategy {no_shard, optim, optim_grads, optim_grads_params}` | `zero_dp_strategy={no_shard, optim, optim_grads, optim_grads_params, 0, 1, 2, 3}` | +| **DP-Outer Sharding Strategy** | Secondary data-parallel sharding strategy for HSDP, which supports Hybrid-Sharded Data Parallel (HSDP / `no_shard`) and Hybrid-FSDP (HFSDP / `optim`). Typically uses inter-node communications, i.e. "outer" or "inter" DP. | `--outer-dp-sharding-strategy {no_shard, optim}` | `outer_dp_sharding_strategy={no_shard, optim, 0, 1}` | +| **Hybrid Data Parallelism Size** | Specify the DP-Outer / Inter-DP parallel size. DP-Inner / Intra-DP sizes will be deduced from the sizes of other parallelisms and `torch.distributed.get_world_size()`. | `--num-distributed-optimizer-instances ` | `dp_outer_dim=` (Cumulative DP groups `hybrid_fsdp_group` / `hybrid_fsdp_expt_group` are required for HFSDP.) | + +Megatron-FSDP supports a range of sharding strategies over a variety of distributed topologies: + +- **Distributed Data Parallelism (DDP)** + - Model state is replicated across DP ranks. + - Gradient all-reduce is overlapped with backward compute and launched during the last backward pass before the optimization step. +- **ZeRO-1** + - Optimizer state is sharded across DP ranks. + - Gradient reduce-scatter is overlapped with backward compute and launched during the last backward pass before the optimization step. (Reduce-scatter is used in lieu of all-reduce for performance, because only a shard of the gradient is needed for optimization.) +- **ZeRO-2** + - Optimizer state and gradients are sharded across DP ranks. + - Gradient reduce-scatter is overlapped with backward compute and accumulated during every backward pass. +- **Fully-Sharded Data Parallelism (FSDP / ZeRO-3)** + - Optimizer state, gradients, and parameters are sharded across DP ranks. + - Gradient reduce-scatter is overlapped with backward compute and accumulated during every backward pass. +- **Hybrid-Sharded Data Parallelism (HSDP)** + - Optimizer state, gradients, and parameters are sharded across the "inner" or "intra" DP ranks. + - Model state is replicated across "outer" / "inter" DP ranks, and outer data-parallel gradients are all-reduced during the last backward pass before the optimization step. +- **Hybrid-FSDP (HFSDP)** + - Optimizer state, gradients, and parameters are sharded across the "inner" or "intra" DP ranks. + - Optimizer state is _further_ sharded across "outer" / "inter" DP ranks. + - Outer data-parallel gradients are reduce-scattered during the last backward pass before the optimization step. + - Outer data-parallel parameters are all-gathered during the first forward pass after the optimization step. + - FSDP primary sharding (`optim_grads_params`) is required for HFSDP secondary sharding (`optim`). + - Requires passing cumulative data-parallel groups (`hybrid_fsdp_group` / `hybrid_fsdp_expt_group`), which include ALL data-parallel ranks, to Megatron-FSDP. + - To create these using `DeviceMesh`, create a data-parallel `DeviceMesh` for the cumulative DP group and use `DeviceMesh._unflatten(dp_dim, mesh_sizes=(dp_outer_size, dp_inner_size), mesh_dim_names=("dp_outer_dim", "dp_shard_dim"))` to construct a `DeviceMesh` with DP-Inner and DP-Outer mesh dimensions for Hybrid-FSDP (see the sketch below).
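+
+The following is a minimal sketch of the `DeviceMesh` construction described in the last bullet above, assuming 16 data-parallel ranks split into an outer size of 2 and an inner size of 8. The mesh dimension names mirror this section, and `DeviceMesh._unflatten` is a private PyTorch API whose exact signature may vary across versions; treat this as an illustration rather than a drop-in recipe.
+
+```python
+import torch
+from torch.distributed.device_mesh import init_device_mesh
+
+# Assumption: 16 total ranks, all belonging to the cumulative (hybrid) DP group.
+torch.distributed.init_process_group()
+torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
+
+# Flat data-parallel mesh over every DP rank (the cumulative "hybrid_fsdp_group").
+dp_mesh = init_device_mesh("cuda", mesh_shape=(16,), mesh_dim_names=("dp",))
+
+# Split the flat DP dimension into an inter-node (ZeRO-1) and an intra-node (ZeRO-3) dimension.
+hfsdp_mesh = dp_mesh._unflatten(
+    0,  # index of the flat DP dimension being split
+    mesh_sizes=(2, 8),
+    mesh_dim_names=("dp_outer_dim", "dp_shard_dim"),
+)
+```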
+ +#### Understanding Hybrid-FSDP (HFSDP) + +```{figure} ../../images/megatron_fsdp/hfsdp.png +:alt: Hybrid-FSDP Topology +:align: center + +Hybrid-FSDP (HFSDP) is a variation of HSDP where the optimizer state in particular is sharded across both DP-Inner and DP-Outer, i.e. all data-parallel ranks, which further reduces memory utilization. In other words, intra-node sharding and communication uses ZeRO-3, while inter-node sharding and communication uses ZeRO-1. Parameters and gradients are converted from and to the fully-sharded optimizer state during optimization steps only, reducing the frequency of inter-node communications. + +Inspired by the artistry in the DHEN (Zhang, Luo, Liu, Meta, et al., 2022) paper: https://arxiv.org/abs/2203.11014 +``` + +**Hybrid-Fully Sharded Data Parallelism (HFSDP)** is a slight modification to HSDP that fully-shards the optimizer state across all data-parallel ranks and introduces outer-level all-gather and reduce-scatter collectives to map fully-sharded parameters and gradients into partially-sharded parameters and gradients. + +The memory profile of HFSDP is a "hybrid" of FSDP (optimizer state) and HSDP (gradients and model weights). Another elegant way to understand HFSDP functionality is ZeRO-1 composed with ZeRO-3. + +$$\text{Hybrid-FSDP Memory Profile} = \frac{\text{Optimizer State}}{\text{DP-Inner} \ \times \ \text{DP-Outer}} + \frac{\text{Gradient} + \text{Weight}}{\text{DP-Inner}}$$ + +The modified algorithm has the following characteristics: + +- Megatron-FSDP maintains a view of the model parameters sharded across all data-parallel ranks. + - Distributed checkpoints save and load the fully-sharded model parameters. + - Distributed optimizer state is initialized on the fully-sharded model parameters. +- During the first forward pass after checkpointing or optimization, fully-sharded model weights are all-gathered into partially-sharded model weights. +- During the last backward pass before optimization, partially-sharded model gradients are reduce-scattered into fully-sharded model gradients. +- Otherwise, FSDP is performed on the partially-sharded model weights and accumulated gradients. Because model weights and gradients are only updated and ingested once per optimization cycle, we can skip or postpone all expensive inter-node / DP-outer collectives until an optimization step.​ + +In addition to improved memory utilization, HFSDP communications are split in communication size (bytes communicated), communication topology (DP-Inner and DP-Outer groups), and communication domain (NVLink and InfiniBand) across two sharding stages. + +```{figure} ../../images/megatron_fsdp/fsdp_v_hfsdp_streams.png +:alt: Hybrid-FSDP Streams +:align: center + +Inter-node communications can also be parallelized with intra-node communications using separate CUDA streams. +``` + +#### Mixing FSDP & Model Parallelism + +Megatron-FSDP is also compatible with a variety of model parallelisms that shard the model state, such as **Tensor Parallelism (TP)** and **Expert Parallelism (EP)**. When sharding model states across multiple dimensions in the device topology, _**FSDP sharding is always performed last**_, because FSDP collectives un-shard and re-shard parameters and gradients immediately before and after computation. Thus, FSDP sharding mechanics are implemented over tensor and expert parallel (strided) shards. 
+ +```{figure} ../../images/megatron_fsdp/mixed_sharding.png +:alt: Mixed Model Parallelism +:align: center + +Whenever FSDP is composed with other model parallelisms, FSDP sharding is always exercised last to seamlessly integrate with existing model shards. +``` + +Megatron-FSDP uses `torch.distributed.DeviceMesh` to describe and configure communications across devices in data-parallel group(s). Because heterogeneous models that have mixed layers, such as [Hybrid Mamba-Transformer](https://arxiv.org/abs/2504.03624) or [Mixture-of-Experts (MoE)](https://arxiv.org/abs/1701.06538) models, require different parallelism configurations, multiple `DeviceMesh`(s) may be required for specific layers that require distinct distributed topologies for optimal memory efficiency and performance. + +Currently, Megatron-FSDP supports two `DeviceMesh`(s), one for dense / non-expert `Module`(s) and another for Megatron-Core MoE sparse / expert `Module`(s). (Expert modules and parameters in Megatron-Core are automatically detected.) + +- Dense modules typically have a `DeviceMesh` with data parallel, tensor parallel, and context parallel dimensions, where the data parallel dimension is used for FSDP. Typically, both data-parallel and context-parallel ranks are used for sharding in FSDP. +- Mixture-of-experts modules typically have a `DeviceMesh` with data parallel, tensor parallel, and expert parallel dimensions, where the data parallel dimension is used for FSDP. + +For more information about Mixture-of-Experts in Megatron-Core, refer to the [Megatron-Core User Guide - MoE](https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/features/moe.html). + +#### Non-Uniform / Un-Even Model Sharding + +While `torch.distributed.tensor.DTensor` defaults to per-parameter sharding, where Tensors are split evenly on `dim=0` across the data-parallel domain, Megatron-FSDP uses **non-uniform or un-even `DTensor` shards** of a (flattened) group of parameters associated with an FSDP unit. + +```{figure} ../../images/megatron_fsdp/uneven_sharding.png +:alt: Non-Uniform Sharding +:align: center + +Comparison of FSDP2 per-parameter sharding and Megatron-FSDP per-unit or per-module sharding. FSDP2 requires `COPY` operations to move parameters and gradients in and out of communication buffers to reduce the frequency of NCCL collective calls, while Megatron-FSDP assigns sliced views of contiguous communication buffers to parameters associated with an FSDP unit. +``` + +While complex and less user-intuitive, an un-evenly sharded data structure enables a few performance benefits without introducing expensive `COPY` operations to set up communication and computation buffers: + +- **Fewer NCCL calls**, reducing kernel launch and synchronization overhead. Only parameters in FSDP units that have different communication-related properties, such as their `dtype` or distributed topology, are coalesced into separate NCCL calls. +- Flat communication and computation buffers are **contiguous-by-design**, supporting optimized CUDA kernels that require buffers backed by contiguous memory, such as grouped GEMMs used in MoE. + +Effectively, this implies that the same `DTensor`-sharded model parameters may have completely different shapes on different ranks, and if entire parameters are assigned to other ranks, the local `Tensor` will be empty.
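+
+To make this per-rank asymmetry concrete, here is a small, hypothetical inspection loop. It assumes `fsdp_model` was built with `fully_shard_model` as in the Quick Start above, and only prints the local shard shapes exposed through the `DTensor`-based state dictionary; it is not part of the Megatron-FSDP API.
+
+```python
+import torch
+from torch.distributed.tensor import DTensor
+
+# Assumption: fsdp_model is a module already wrapped by fully_shard_model (see Quick Start).
+rank = torch.distributed.get_rank()
+for name, tensor in fsdp_model.state_dict().items():
+    if isinstance(tensor, DTensor):
+        local = tensor.to_local()
+        # The global (logical) shape is identical on every rank, while the local shard
+        # shape can differ per rank and may even be empty for un-evenly sharded units.
+        print(f"[rank {rank}] {name}: global={tuple(tensor.shape)} local={tuple(local.shape)}")
+```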
+
+> ℹ️ Megatron-FSDP has a handy library ([`megatron_fsdp.uneven_dtensor`](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py)) for manipulating un-evenly sharded `DTensor`(s), focused on per-parameter operations like un-sharding or reducing parameters that have different shapes across ranks. While the parameter group is evenly-sharded for FSDP collectives, per-parameter collectives (which assume a symmetrical amount of bytes is communicated between devices) will hang waiting on bytes that never arrive for un-evenly sharded `DTensor`(s).
+
+Contiguous memory, however, is only half the requirement for high-performance CUDA kernels. The other requirement is **locality**, which FSDP can violate, introducing compatibility issues when combining FSDP with present and future optimizations. For example, block-wise quantization (scaling factor / `absmax` calculations for MXFP8, NVFP4, etc.) requires DP communication and custom max-reduce kernels if the block is sharded by FSDP.
+
+Megatron-FSDP supports `dim=0` sharding, which computes the _**least-common multiple (LCM) of `p.shape[1:]` for all parameters `p` in an FSDP unit**_ and _**pads the un-sharded buffer to the closest multiple of `DP x LCM(p.shape[1:])`**_, forming a "DP-LCM" partition with `LCM`-length parts to ensure that DP-sharding boundaries do not bisect chunks of data belonging to a single coordinate of `dim=0`.
+
+```{figure} ../../images/megatron_fsdp/lcm_dim0_shard.png
+:alt: Flat Buffer Sharding Algorithm
+:align: center
+
+Visualization of how parameters are assigned un-evenly to the flat per-unit buffer sharded across DP ranks. With the LCM algorithm, no slice of `dim=0` is ever bisected by FSDP. Algorithms and compute kernels can leverage this locality and contiguity.
+```
+
+1. When a parameter is _divisible by the LCM_, it can be inserted at any free index that is a multiple of the LCM in the buffer. Each `dim=0` chunk `p[i]` of this parameter has a size that divides the LCM by construction, and thus aligns with the DP-LCM sharding grid.
+2. When a parameter _is larger than but not divisible by the LCM_, the remainder `r` populates a fraction of another LCM part, so a "conjugate" parameter, one that also exceeds the LCM and whose remainder `r'` is less than or equal to `LCM - r`, is installed to fill the remaining space and align with the DP-LCM sharding grid.
+3. When a parameter _is smaller than but not divisible by the LCM_, a post-assignment sweep over the leftover space in the flat buffer is run, and all LCM-aligned gaps that are large enough to hold the entire parameter are utilized. Once all gaps are filled, the final parameters are assigned to the tail of the buffer, respecting the DP-LCM sharding grid.
+
+> ℹ️ Generalized support for contiguity and locality in Megatron-FSDP is a **_work-in-progress_** and will evolve with contributions from the OSS community and PyTorch. For more information about how kernel buffer requirements affect the design of FSDP data structures, refer to the [veScale: Consistent and Efficient Tensor Programming with Eager-Mode SPMD (Li, Youjie, ByteDance Seed, et al.)](https://arxiv.org/abs/2509.07003) paper, which comprehensively analyzes these requirements.
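+
+As a small worked example of the DP-LCM rule above (the shapes are hypothetical and the placement heuristics from the numbered list are omitted), padding the flat buffer to a multiple of `DP x LCM` guarantees that every DP shard boundary lands on an LCM-aligned offset:
+
+```python
+# Minimal sketch of the DP-LCM padding rule, assuming made-up parameter shapes.
+import math
+
+dp_size = 4
+# Each entry is (dim0, product of p.shape[1:]) for a parameter in one FSDP unit.
+param_shapes = [(128, 12), (64, 8), (32, 24)]
+
+# LCM of the per-row sizes of all parameters in the unit.
+lcm = math.lcm(*(row for _, row in param_shapes))      # lcm(12, 8, 24) = 24
+
+# Pad the un-sharded flat buffer up to a multiple of DP x LCM.
+total = sum(d0 * row for d0, row in param_shapes)      # 2816 elements
+grid = dp_size * lcm                                   # 96
+padded = math.ceil(total / grid) * grid                # 2880
+shard = padded // dp_size                              # 720 elements per DP rank
+
+# Every shard boundary is LCM-aligned, so an LCM-aligned dim=0 row is never bisected.
+assert shard % lcm == 0
+```
+
+Because each rank's shard size is itself a multiple of the LCM, any parameter placed at an LCM-aligned offset keeps whole `dim=0` rows on a single rank, which is exactly the locality property the placement rules above preserve.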
+
+### Mixed-Precision & Quantization
+
+| Optimization | Description | `Megatron-Core` Config | `fully_shard` Config |
+|--------------|-------------|----------------------|----------------------|
+| **Quantized Parameters** | Megatron-FSDP will shard and all-gather TransformerEngine-quantized parameters for computation. Quantized parameters are updated every optimization step, and both row-wise (FWD) and column-wise (BWD) data are managed for non-transposable 1-D quantization recipes like MXFP8. Otherwise, only activations are quantized. | `--fp8-param-gather` | TransformerEngine `quantized_model_init()` |
+| **Main Parameter (Optimization / Checkpoint) Data-Type** | Data-type for optimization and checkpointing parameters. If set to `auto`, the model compute weight data-type is used instead. Required for `--fp8-param-gather`. Defaults to FP32. | `--megatron-fsdp-main-params-dtype {fp32, bf16, fp16, auto}` | `MixedPrecisionPolicy(main_params_dtype=...)` |
+| **Main Gradient (Accumulation) Data-Type** | Data-type for gradient accumulation. If set to `auto`, the main gradient precision is derived from the model parameter precision. Defaults to `auto`. | `--megatron-fsdp-main-grads-dtype {fp32, bf16, fp16, auto}` | `MixedPrecisionPolicy(main_grads_dtype=...)` |
+| **Gradient Communication (Reduction) Data-Type** | Data-type for gradient communication and reduction. If set to `auto`, the main gradient precision will be used for communication. (When using NCCL symmetric registration, low-precision gradients are reduced in FP32 over-the-wire.) Defaults to `auto`. | `--megatron-fsdp-grad-comm-dtype {fp32, bf16, fp16, auto}` | `MixedPrecisionPolicy(grad_comm_dtype=...)` |
+| **Weight Gradient Accumulation Fusion** | When using TransformerEngine modules, Megatron-FSDP implements `get_main_grad`, which TransformerEngine calls to obtain un-sharded gradient buffers, avoiding a `COPY` of the gradient into Megatron-FSDP communication buffers. Used by default and can be deactivated with `--no-gradient-accumulation-fusion`. | `--no-gradient-accumulation-fusion` | N/A (Megatron-Core Only) |
+| **Precision-Aware Optimizer** | Use the TransformerEngine `FusedAdam` optimizer, and Megatron-FSDP will install the gradient in a temporary attribute `Parameter.decoupled_grad` which is consumed by `FusedAdam`. Megatron-FSDP manages the main parameters, but the optimizer state precision can be customized with `--exp-avg-dtype` and `--exp-avg-sq-dtype`, which both support `fp8` optimization state. | `--use-precision-aware-optimizer` | `use_decoupled_grad=True` |
+
+#### Quantization
+
+Quantization is an extremely important feature for Megatron-FSDP as it reduces memory utilization and communication size for both activations and parameters, which directly affects the viability and performance of FSDP.
+
+```{figure} ../../images/megatron_fsdp/quantized_param_gather.png
+:alt: Quantized Model Parameters & FSDP
+:align: center
+
+Visualization of Megatron-FSDP's training loop when using quantized weights from TransformerEngine. Every optimization step updates the quantized representation of sharded model weights, which have reduced communication size.
+```
+
+While TransformerEngine handles activation quantization, Megatron-FSDP shards quantized weights for all-gather (AG).
+
+0. _**Quantized Model Initialization**_ - Model is initialized with quantized weights, e.g. MXFP8 or NVFP4. If using `meta` device initialization, Megatron-FSDP will call `reset_parameters()` to initialize quantized weights layer-by-layer.
If row-wise and column-wise data are not transposable, Megatron-FSDP will shard and buffer both. Additionally, high-precision main weights are retrieved and sharded for distributed optimization, checkpointing, and quantization.
+0. _**Forward / Backward Pass**_ - Quantized weights are un-sharded for both the forward and backward pass. If row-wise and column-wise data aren't transposable, the row-wise weights are gathered for the forward pass, and the column-wise weights are gathered for the backward pass.
+0. _**Distributed Optimization Step**_ - Non-quantized accumulated gradient shards from quantized GEMMs are applied to high-precision main weight shards.
+0. _**Sharded Quantization**_ - Sharded main weights are quantized to update the quantized compute weights for subsequent training steps.
+
+```{figure} ../../images/megatron_fsdp/sharded_quantization.png
+:alt: Sharded Quantization
+:align: center
+
+Sharded quantization involves reducing maxima to compute a global set of scaling factors for local / sharded quantization.
+```
+
+In particular, _sharded quantization_ minimizes communication size and memory utilization by communicating scaling factors instead of main weights.
+
+1. _**Local Abs-Max**_ - For a group of parameters in an FSDP unit, compute local tensor-wise or block-wise maxima across the global un-sharded shape, with zero padding for non-local data.
+1. _**Global Abs-Max**_ - Globally all-reduce the maxima and derive scaling factors from them.
+1. _**Local Quantization**_ - Locally quantize sharded main weights and install them into the compute weight buffers.
+
+#### Mixed-Precision
+
+Megatron-FSDP sharding and communication buffers support mixed-precision, such that users can customize the `dtype` used for main weights, gradient communication (reduction), and gradient accumulation in addition to the native or quantized `dtype` used for model computation. These options are wrapped in a `MixedPrecisionPolicy` dataclass.
+
+- _**Main Weight Precision**_ - Controls the data-type for parameters responsible for distributed optimization, distributed checkpointing, and quantization. If set to `auto` (`None`), the native model compute parameter data-type will be utilized. Required for parameter quantization with `--fp8-param-gather`. Defaults to `torch.float32`.
+- _**Main Gradient Precision**_ - Controls the data-type for `wgrad` accumulation and distributed optimization. Defaults to `auto` (`None`), in which case the native model gradient data-type is used. Because `main_grads_dtype` controls the data-type for gradient accumulation, `torch.float32` (or higher) is recommended for accuracy at scale, but `auto` is more flexible and falls back to pre-determined parameter gradient logic in mixed-precision scenarios, such as `BF16` gradients for `FP8`/`FP4` parameters quantized via TransformerEngine.
+- _**Gradient Communication Precision**_ - Controls the data-type for gradient communications when reducing gradients. Lower precision improves (communication) performance. Defaults to `auto` (`None`), in which case the main gradient data-type is used. If using `no_shard`, `optim`, HSDP, or HFSDP, allocating `dtype`-custom gradient communication buffers may increase per-unit memory overhead, so users should consider the performance-memory trade-off when using this feature.
+  - If using NCCL symmetric registration (`v2.27+`), gradient reduction may be performed in high precision depending on the network domain (NVLink or IB), which enables mixed-precision communication and accumulation: for example, setting `grad_comm_dtype` to `BF16` can still support `FP32` reduction even though the input and output communication buffers are `BF16`. Otherwise, gradients are reduced and accumulated in the communication and accumulation precisions as usual.
+
+### NCCL
+
+| Optimization | Description | `Megatron-Core` Config | `fully_shard` Config |
+|--------------|-------------|----------------------|----------------------|
+| **NCCL User Buffers** | Allocate and register Megatron-FSDP communication buffers with NCCL, which enables zero-`COPY`, high-precision reduction, copy-engine collectives, and symmetric kernels. Uses double buffering. | `--use-nccl-ub` | `nccl_ub=True` |
+| **NCCL Manual Registration** | Instead of registering NCCL user buffers on first allocation, batch-register all communication buffers at the end of the initial training step. Reduces registration latency. | `--fsdp-manual-registration` | N/A (Megatron-Core Only) |
+| **Disable Symmetric Registration** | Disable symmetric registration with NCCL. Optional, as symmetric registration failure defaults to normal registration. | `--disable-symmetric-registration` | `disable_symmetric_registration=True` |
+
+[NVIDIA Collective Communications Library (NCCL)](https://developer.nvidia.com/nccl) implements multi-device and multi-node communication primitives optimized for NVIDIA CUDA devices and networking. Megatron-FSDP communications are registered and deeply integrated with NCCL, which enables a variety of hardware-level networking optimizations such as copy-engine AG, high-precision RS, SHARP reduction offloading, and symmetric kernels.
+
+To leverage NCCL networking optimizations, **NCCL user buffer registration (UBR)** is required to inform NCCL of PyTorch Tensors ("user buffers") that act directly as the input and target of NCCL collectives for PyTorch `ProcessGroup`(s). Because registered communication buffers are known to NCCL, `COPY` operations that send collective inputs to NCCL buffers and collective outputs to PyTorch buffers are no longer required, which enables Megatron-FSDP to be zero-`COPY` end-to-end.
+
+NCCL (`v2.27+`) supports symmetric allocation or registration for communicators over the NVLink domain, which allows buffers that share identical virtual addresses across devices to benefit from optimized collectives:
+
+- **Symmetric Kernels** - On the NVLink domain, symmetric kernels operating on symmetric memory reduce the SM utilization of a single communication kernel to 1.
+- **NVSwitch SHARP Offloading** - To further minimize SM utilization for AG and RS collectives, NCCL SHARP offloads reduction and aggregation work to NVLink and IB Switch hardware, using 1-6 SMs depending on the domain: NVL, IB, or NVL + IB.
+- **Copy-Engine (CE) Collectives** - Instead of using SMs (or CTAs) for common non-computational collectives like AG in Megatron-FSDP, copy engines are used to perform all-gather collectives, dedicating SM resources to compute and reduction during FSDP. Requires NCCL `v2.28+`.
+- **High-Precision Reduction** - When training large models, high-precision gradient reduction and accumulation are desired for accuracy and convergence, but communicating FP32 gradients is expensive. With symmetric registration, FP32 accumulators enable gradients to be reduced in FP32 but communicated in BF16, which decreases gradient RS communication latency while maintaining high accuracy during training.
Megatron-FSDP supports FP32 main gradient accumulation but BF16 gradient communication, customizable through `megatron_fsdp.MixedPrecisionPolicy`. + +These optimizations significantly reduce SM resource contention for overlapped compute and communication kernels in FSDP. Symmetric registration, allocation, and pooling is also supported in PyTorch: [`torch.distributed._symmetric_memory`](https://docs.pytorch.org/docs/stable/symmetric_memory.html). diff --git a/docs/user-guide/features/tokenizers.md b/docs/user-guide/features/tokenizers.md index 672f0f0cd98..1455d6e617e 100644 --- a/docs/user-guide/features/tokenizers.md +++ b/docs/user-guide/features/tokenizers.md @@ -149,7 +149,24 @@ tokenizer = MegatronTokenizer.from_pretrained( ### Null Tokenizer -Use a null tokenizer for testing or non-text models: +The Null tokenizer is a lightweight, zero-I/O tokenizer that requires no model files. +It is useful in three scenarios: + +1. **Performance benchmarking** with `--mock-data` where real tokenization is unnecessary. +2. **Testing** in functional tests and CI pipelines where tokenizer model files may not + be available. The Null tokenizer removes the dependency on external files, making + tests self-contained and portable. +3. **Pretraining with pretokenized data** where all data is already tokenized into + `.bin`/`.idx` files. In this case the tokenizer is only needed for metadata + (`vocab_size`, `eod`, `pad`) — not for actual tokenization. Using the Null tokenizer + avoids redundant filesystem access at scale, which is particularly beneficial on + shared filesystems like Lustre where thousands of ranks would otherwise all load the + same tokenizer files. + +Properties derived from `--vocab-size N`: +- `vocab_size` = `N` (the exact value passed) +- `eod` = `N - 1` (last token in the vocabulary) +- `pad` = `0` ```python tokenizer = MegatronTokenizer.from_pretrained( @@ -165,10 +182,20 @@ tokenizer = MegatronTokenizer.from_pretrained( The tokenizer system works with Megatron-LM training scripts: ```bash -# Null tokenizer for testing +# Null tokenizer for benchmarking with mock data torchrun --nproc_per_node=8 pretrain_gpt.py \ --tokenizer-type NullTokenizer \ --vocab-size 131072 \ + --mock-data \ + ... +``` + +```bash +# Null tokenizer for pretraining with pretokenized data (no tokenizer files needed) +torchrun --nproc_per_node=8 pretrain_gpt.py \ + --tokenizer-type NullTokenizer \ + --vocab-size 128256 \ + --data-path /path/to/pretokenized_data \ ... ``` @@ -195,7 +222,7 @@ The following table lists supported tokenizer backends: | **SentencePiece** | Google's tokenizer | GPT-style models, custom vocabularies | | **TikToken** | OpenAI's tokenizer | GPT-3.5/GPT-4 style tokenization | | **Megatron** | Built-in tokenizers | Legacy GPT-2 BPE | -| **Null** | No-op tokenizer | Testing, non-text modalities | +| **Null** | Zero-I/O tokenizer | Benchmarking, pretokenized data | ## Common Tokenizer Types diff --git a/docs/user-guide/parallelism-guide.md b/docs/user-guide/parallelism-guide.md index e09848f8800..2540ca0a827 100644 --- a/docs/user-guide/parallelism-guide.md +++ b/docs/user-guide/parallelism-guide.md @@ -15,19 +15,20 @@ Megatron Core supports multiple parallelism strategies that can be combined to e The following table summarizes supported parallelism strategies. 
-| Strategy | What it parallelizes | Best for | +| Strategy | Parallelism Objective | Best For | |----------|---------------------|----------| -| **Data Parallelism (DP)** | Batch dimension | Standard training, most common | -| **Tensor Parallelism (TP)** | Individual layers | Large layers, GPU memory constraints | -| **Pipeline Parallelism (PP)** | Model depth | Very deep models | -| **Context Parallelism (CP)** | Sequence length | Long sequences (8K+ tokens) | -| **Expert Parallelism (EP)** | MoE experts | Mixture-of-Experts models | +| **Data Parallelism (DP)** | Batch Dimension | Data Scalability, Standard Training | +| **Tensor Parallelism (TP)** | Individual Layers | Large Layers & Activation, GPU Memory Constraints | +| **Pipeline Parallelism (PP)** | Model Depth | Very Deep Models | +| **Context Parallelism (CP)** | Sequence Length | Long Sequences (8K+ Tokens) | +| **Expert Parallelism (EP)** | MoE Experts | Mixture-of-Experts Models | +| **Fully-Sharded Data Parallelism (Megatron-FSDP)** | Model State | Extremely Large Models & DP Interchangeability | ## Data Parallelism (DP) -Replicate the model across GPUs and split the batch. +### Standard Distributed Data Parallel (DDP) -### Standard Data Parallel (DDP) +Replicate the model across GPUs and split the batch. ```bash torchrun --nproc_per_node=8 pretrain_gpt.py \ @@ -36,22 +37,40 @@ torchrun --nproc_per_node=8 pretrain_gpt.py \ Each GPU has a full copy of the model and processes a portion of the batch. -### Fully Sharded Data Parallel (FSDP) +### Megatron Fully-Sharded Data Parallel (Megatron-FSDP) -Shard model parameters, gradients, and optimizer states to reduce memory: +Shard model parameters, gradients, and optimizer states across GPUs to reduce memory utilization. -```bash -# Megatron FSDP (~15% faster than PyTorch FSDP2) ---use-megatron-fsdp \ +``` +--use-megatron-fsdp --data-parallel-sharding-strategy optim_grads_params +--ckpt-format fsdp_dtensor +--init-model-with-meta-device ``` -**Sharding strategies** +**Sharding Strategies** + +`--data-parallel-sharding-strategy` supports the following options: - `optim` - Shard optimizer states only (ZeRO-1) - `optim_grads` - Shard gradients + optimizer (ZeRO-2) - `optim_grads_params` - Shard parameters + gradients + optimizer (ZeRO-3) +If `--num-distributed-optimizer-instances` is > 1, then hierarchical data parallelism is enabled. + +`--outer-dp-sharding-strategy` supports the following options: + +- `no_shard` (**Hybrid-Sharded Data Parallelism**) - Replicate the model state across outer data parallel ranks. +- `optim` (**Hybrid-FSDP**) - Shard the optimizer state across the outer data parallel ranks. + - Requires `--data-parallel-sharding-strategy optim_grads_params`. + +**When to Use** + +- Large models with large or fused compute kernels to hide communications under. +- Integrated with TP, CP, EP, and easily composable with heterogeneous parallelisms. +- With SM-reducing optimizations from NCCL and activation offloading from TransformerEngine. +- Using `fully_shard` without depending on Megatron-LM. + ## Tensor Parallelism (TP) Split individual model layers across GPUs. Recommended for large hidden dimensions. @@ -61,7 +80,7 @@ Split individual model layers across GPUs. Recommended for large hidden dimensio --sequence-parallel # Enable sequence parallelism (recommended) ``` -**When to use** +**When to Use** - Model layers do not fit on a single GPU - Large hidden dimensions (4096+) @@ -76,7 +95,7 @@ Split model layers across GPUs vertically (by depth). 
--num-layers-per-virtual-pipeline-stage 4 # Virtual pipeline for load balancing ``` -**When to use** +**When to Use** - Very deep models (50+ layers) - Combine with TP for large models @@ -91,7 +110,7 @@ Split long sequences across GPUs for efficient long-context training. --cp-comm-type p2p # Communication type ``` -**When to use** +**When to Use** - Long sequences (8K+ tokens) - Reduces activation memory @@ -119,7 +138,7 @@ Distribute experts across GPUs in Mixture-of-Experts models. ## Parallelism Selection Guide -Recommended configurations based on [NVIDIA NeMo production setups](https://github.com/NVIDIA/NeMo/tree/main/scripts/performance/recommended_model_configs): +For a list of supported configurations, refer to [Megatron Bridge Supported Models](https://github.com/NVIDIA-NeMo/Megatron-Bridge#supported-models). ### Language Models @@ -208,24 +227,24 @@ Reduces activation memory by sharding sequence dimension in LayerNorm and Dropou ## Choosing the Right Strategy ### Start Simple -1. Begin with **Data Parallelism** (DP) only -2. Add **Tensor Parallelism** (TP) if the model does not fit -3. Add **Pipeline Parallelism** (PP) for very large models -4. Add **Context Parallelism** (CP) for long sequences +1. Begin with **Data Parallelism** (DP) only. +2. Add **Tensor Parallelism** (TP) if the model does not fit. +3. Add **Pipeline Parallelism** (PP) for very large models. +4. Add **Context Parallelism** (CP) for long sequences. ### Memory Constraints -- Use **FSDP** to reduce memory per GPU -- Use **TP** to split large layers -- Use **PP** to split model depth -- Enable **activation checkpointing** for extreme cases +- Use **FSDP** to split model state per GPU. +- Use **TP** to split large layers. +- Use **PP** to split model depth. +- Enable **activation checkpointing or offloading** for extreme cases. ### Communication Bottlenecks -- Reduce **TP** degree (increases memory per GPU) -- Increase **PP** degree (may reduce efficiency) -- Use **CP** instead of larger TP for long sequences +- Reduce **TP** degree (increases memory per GPU). +- Increase **PP** degree (may reduce efficiency). +- Use **CP** instead of larger TP for long sequences. ## Next Steps - **API Reference**: Refer to [Tensor Parallel](../api-guide/core/tensor_parallel.md) and [Pipeline Parallel](../api-guide/core/pipeline_parallel.md) in the API documentation -- **Advanced Features**: Refer to [Megatron FSDP](features/custom_fsdp.md) and [Distributed Optimizer](features/dist_optimizer.md) +- **Advanced Features**: Refer to [Megatron-FSDP](features/megatron_fsdp.md), [MoE](features/moe.md), and [Distributed Optimizer](features/dist_optimizer.md) - **Performance Tuning**: Refer to the [NVIDIA NeMo Performance Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/performance/performance-guide.html) diff --git a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py deleted file mode 100644 index c3a9f69caef..00000000000 --- a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py +++ /dev/null @@ -1,159 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
- - -"""Fine-tune GPT""" - -import torch -from functools import partial -import os -import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir, os.path.pardir))) -from megatron.training import get_args -from megatron.training import get_timers -from megatron.training import get_tokenizer -from megatron.training import print_rank_0 -from megatron.core import mpu -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig -from megatron.core.datasets.gpt_dataset import GPTDataset -from megatron.core.datasets.utils import get_blend_from_list -from megatron.legacy.model import GPTModel -from megatron.core.enums import ModelType -from megatron.training import pretrain -from megatron.training.utils import get_ltor_masks_and_position_ids -from megatron.training.utils import average_losses_across_data_parallel_group - -def model_provider(pre_process=True, post_process=True): - """Build the model.""" - - print_rank_0('building GPT model ...') - model = GPTModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process - ) - return model - - -def get_batch(data_iterator): - """Generate a batch""" - args = get_args() - tokenizer = get_tokenizer() - - # Items and their type. - keys = ['text'] - datatype = torch.int64 - - # Broadcast data. - if data_iterator is not None: - data = next(data_iterator) - else: - data = None - data_b = mpu.broadcast_data(keys, data, datatype) - - # Unpack. - tokens_ = data_b['text'].long() - labels = tokens_[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() - - # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, - tokenizer.eod, - args.reset_position_ids, - args.reset_attention_mask, - args.eod_mask_loss) - - return tokens, labels, loss_mask, attention_mask, position_ids - -def loss_func(loss_mask, output_tensor): - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {'lm loss': averaged_loss[0]} - - -def forward_step(data_iterator, model): - """Forward step.""" - args = get_args() - timers = get_timers() - - # Get the batch. 
- timers('batch-generator').start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator) - timers('batch-generator').stop() - - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) - - return output_tensor, partial(loss_func, loss_mask) - - -def train_valid_test_datasets_provider(train_val_test_num_samples): - """Build train, valid, and test datasets.""" - args = get_args() - - print_rank_0('> building train, validation, and test datasets ' - 'for GPT ...') - train_ds, _, test_ds = BlendedMegatronDatasetBuilder( - GPTDataset, - train_val_test_num_samples, - lambda: True, - GPTDatasetConfig( - blend=get_blend_from_list(args.data_path), - split=args.split, - random_seed=args.seed, - sequence_length=args.seq_length, - path_to_cache=args.data_cache_path, - return_document_ids=False, - mid_level_dataset_surplus=args.mid_level_dataset_surplus, - ) - ).build() - print_rank_0("> finished creating finetuning GPT datasets ...") - - _, valid_ds, _ = BlendedMegatronDatasetBuilder( - GPTDataset, - train_val_test_num_samples, - lambda: True, - GPTDatasetConfig( - blend=get_blend_from_list(args.data_path2), - split="98,2,0", - random_seed=1234, - sequence_length=2048, - path_to_cache=args.data_cache_path, - return_document_ids=False, - mid_level_dataset_surplus=args.mid_level_dataset_surplus, - ) - ).build() - print_rank_0("> finished creating pretrained GPT datasets ...") - - return train_ds, valid_ds, test_ds - - -def add_validation_args(parser): - """Text generation arguments.""" - group = parser.add_argument_group(title='validation set') - group.add_argument('--data-path2', nargs='*', default=None, - help='Path to the validation dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--eval-ppl', action='store_true', default=False) - group.add_argument('--stored_params', type=dict, default=dict()) - return parser - - -if __name__ == "__main__": - - pretrain(train_valid_test_datasets_provider, model_provider, - ModelType.encoder_or_decoder, - forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - extra_args_provider=add_validation_args,) diff --git a/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py b/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py index a8b72bb39ae..a5afb7e5c60 100644 --- a/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py +++ b/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py @@ -14,7 +14,6 @@ import torch -import megatron.legacy.model from megatron.core import mpu from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import ( @@ -23,73 +22,59 @@ ) from megatron.core.transformer.spec_utils import import_module from megatron.inference.text_generation import generate_and_post_process -from megatron.legacy.model import GPTModel from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args from megatron.training.checkpointing import load_checkpoint from megatron.training.initialize import initialize_megatron -def model_provider( - pre_process=True, post_process=True -) -> Union[GPTModel, megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, post_process=True) -> GPTModel: """Builds the model. 
- If you set the use_legacy_models to True, it will return the legacy GPT model and if not the core GPT model. - Args: pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + GPTModel: The returned model """ args = get_args() print_rank_0('building GPT model ...') config = core_transformer_config_from_args(args) - if args.use_legacy_models: - model = megatron.legacy.model.GPTModel( - config, - num_tokentypes=0, - parallel_output=False, - pre_process=pre_process, - post_process=post_process, - ) - else: - if args.spec is None: - if args.transformer_impl == 'local': - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm - ) - elif args.transformer_impl == 'transformer_engine': - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm - ) - else: - raise ValueError(f"Invalid transformer_impl {args.transformer_impl}") - elif args.spec[0] == 'local': + if args.spec is None: + if args.transformer_impl == 'local': transformer_layer_spec = get_gpt_layer_local_spec( num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm ) + elif args.transformer_impl == 'transformer_engine': + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm + ) else: - transformer_layer_spec = import_module(args.spec) - - model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=False, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, + raise ValueError(f"Invalid transformer_impl {args.transformer_impl}") + elif args.spec[0] == 'local': + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=args.num_experts, moe_grouped_gemm=args.moe_grouped_gemm ) + else: + transformer_layer_spec = import_module(args.spec) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=False, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + ) return model diff --git a/examples/gpt3/gpt_config.yaml b/examples/gpt3/gpt_config.yaml deleted file mode 100644 index 3c387eafe5a..00000000000 --- a/examples/gpt3/gpt_config.yaml +++ /dev/null @@ -1,297 +0,0 @@ -# WARNING: Yaml configs is currently an experimental feature -language_model: - # model architecture - num_layers: 24 - hidden_size: 1024 - num_attention_heads: 16 - num_query_groups: null - - ffn_hidden_size: null - kv_channels: null - hidden_dropout: 0.0 - attention_dropout: 0.0 - fp32_residual_connection: False - - apply_residual_connection_post_layernorm: False - 
layernorm_epsilon: 1.e-5 - layernorm_zero_centered_gamma: True - add_bias_linear: False - bias_activation_fusion: False - add_qkv_bias: False - gated_linear_unit: False - activation_func: swiglu - num_moe_experts: null - rotary_interleaved: False - window_size: null - - # initialization - init_method: null - init_method_std: 0.02 - output_layer_init_method: null - - # mixed-precision - apply_query_key_layer_scaling: False - attention_softmax_in_fp32: False - - # fusion - bias_swiglu_fusion: True - masked_softmax_fusion: True - persist_layer_norm: False - memory_efficient_layer_norm: False - bias_dropout_fusion: True - apply_rope_fusion: True - - # activation recomputation - recompute_granularity: null - recompute_method: null - recompute_num_layers: null - distribute_saved_activations: null - - # fp8 related - fp8: null - fp8_margin: 0 - fp8_interval: 1 - fp8_amax_history_len: 1 - fp8_amax_compute_algo: "most_recent" - fp8_wgrad: True - - # miscellaneous - clone_scatter_output_in_embedding: True - - normalization: "LayerNorm" # alt value supported by TE: "RMSNorm" - - # MoE related - moe_router_load_balancing_type: "aux_loss" - moe_router_topk: 2 - moe_router_group_topk: null - moe_router_num_groups: null - moe_grouped_gemm: False - moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss. - moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss - moe_input_jitter_eps: null - moe_token_dropping: False - -model_parallel: - # Model parallelism - tensor_model_parallel_size: 1 - context_parallel_size: 1 - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null - sequence_parallel: True - expert_model_parallel_size: 1 - - # Initialization - perform_initialization: True - use_cpu_initialization: null - - # Training - fp16: False - bf16: True - params_dtype: null # Set from above arguments for core - timers: null - - # Optimizations - gradient_accumulation_fusion: True - tp_comm_overlap: False - - # Debug Options - tp_comm_split_ag: True - tp_comm_atomic_ag: True - tp_comm_split_rs: True - tp_comm_atomic_rs: True - tp_comm_bulk_wgrad: True - tp_comm_bulk_dgrad: True - - # Parallelism - finalize_model_grads_func: null - - # Pipeline Parallel - pipeline_dtype: null - grad_scale_func: null - enable_autocast: False - autocast_dtype: null - variable_seq_lengths: False - num_microbatches_with_partial_activation_checkpoints: null - overlap_p2p_comm: False - batch_p2p_comm: True - batch_p2p_sync: True - use_ring_exchange_p2p: False - deallocate_pipeline_outputs: False - no_sync_func: null - grad_sync_func: null - param_sync_func: null - - # CPU Offloading - cpu_offloading: False - cpu_offloading_num_layers: 0 - _cpu_offloading_context: null - cpu_offloading_weights: False - cpu_offloading_activations: True - - # Timing - barrier_with_L1_time: True - -# training: -use_legacy_models: False -spec: null -micro_batch_size: 2 -global_batch_size: 128 -step_batch_size_schedule: "0:32 90B:64 180B:96 270B:128" -check_for_nan_in_loss_and_grad: True -num_layers_per_virtual_pipeline_stage: null - -encoder_num_layers: null -decoder_num_layers: null -rotary_seq_len_interpolation_factor: null -add_position_embedding: False -make_vocab_size_divisible_by: 128 -group_query_attention: False - - -exit_signal_handler: False -exit_duration_in_mins: null -exit_interval: null - -untie_embeddings_and_output_weights: True -position_embedding_type: rope -rotary_percent: 0.5 -openai_gelu: False -squared_relu: False -swiglu: True -onnx_safe: null -bert_binary_head: True 
-max_position_embeddings: 4096 - -transformer_impl: local -use_flash_attn: False -seed: 1234 -data_parallel_random_init: False - -# Optimizer -optimizer: adam -lr: 2.5e-4 -lr_decay_style: cosine -lr_decay_iters: null -lr_decay_samples: 255126953 -lr_warmup_fraction: null -lr_warmup_iters: 0 -lr_warmup_samples: 81381 -lr_warmup_init: 0.0 -min_lr: 2.5e-5 -weight_decay: 0.1 -start_weight_decay: null -end_weight_decay: null -weight_decay_incr_style: constant -clip_grad: 1.0 -adam_beta1: 0.9 -adam_beta2: 0.95 -adam_eps: 1.e-08 -sgd_momentum: 0.9 -override_opt_param_scheduler: False -use_checkpoint_opt_param_scheduler: False - -# checkpointing arguments -save: null -save_interval: 20000 -no_save_optim: null -no_save_rng: null -load: null -no_load_optim: null -no_load_rng: null -finetune: False -use_checkpoint_args: False -exit_on_missing_checkpoint: False - -# loss arguments -loss_scale: null -initial_loss_scale: 4294967296 -min_loss_scale: 1.0 -loss_scale_window: 1000 -hysteresis: 2 -accumulate_allreduce_grads_in_fp32: False -fp16_lm_cross_entropy: False - -# distributed arguments -distributed_backend: nccl -distributed_timeout_minutes: 10 -overlap_grad_reduce: False -align_grad_reduce: True -overlap_param_gather: False -align_param_gather: False -local_rank: null -lazy_mpu_init: null -empty_unused_memory_level: 0 -standalone_embedding_stage: False -use_distributed_optimizer: False -nccl_communicator_config_path: null - -train_iters: null -eval_iters: 32 -eval_interval: 2000 -skip_train: False - -adlr_autoresume: False -adlr_autoresume_interval: 1000 - -# garbage collection -manual_gc: False -manual_gc_interval: 0 -manual_gc_eval: True - -tp_comm_overlap_cfg: null - -#data -data_path: null -split: '99,1,0' -train_data_path: null -valid_data_path: null -test_data_path: null -data_cache_path: null -mock_data: False -vocab_size: null -vocab_file: null -merge_file: null -vocab_extra_ids: 0 -seq_length: 4096 -encoder_seq_length: null -decoder_seq_length: null -sample_rate: 1.0 -mask_prob: 0.15 -short_seq_prob: 0.1 -num_workers: 2 -tokenizer_type: GPTSentencePieceTokenizer -tokenizer_model: null -reset_position_ids: False -reset_attention_mask: False -eod_mask_loss: False -train_samples: 268554688 -dataloader_type: null - -#profile: -profile: False -profile_ranks: [0] -profile_step_end: 12 -profile_step_start: 10 - -#logging: -log_params_norm: True -log_num_zeros_in_grad: True -log_throughput: False -log_progress: False -timing_log_level: 0 -timing_log_option: minmax -tensorboard_log_interval: 1 -tensorboard_queue_size: 1000 -log_timers_to_tensorboard: False -log_validation_ppl_to_tensorboard: False -log_memory_to_tensorboard: False -log_world_size_to_tensorboard: False -log_loss_scale_to_tensorboard: True -wandb_project: '' -wandb_exp_name: '' -wandb_save_dir: '' -enable_one_logger: True -one_logger_project: megatron-lm -one_logger_run_name: null -log_interval: 100 -tensorboard_dir: null diff --git a/examples/inference/README.md b/examples/inference/README.md index 7bba32868f7..3259bf7f943 100644 --- a/examples/inference/README.md +++ b/examples/inference/README.md @@ -142,7 +142,6 @@ NOTE: Other parameters which can be customized for inference: --num-tokens-to-generate (Number of tokens to generate for each prompt) --inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use microbatched pipelining.') --use-dist-ckpt (If using dist checkpoint format for the model) ---use-legacy-models (If using legacy models 
instead of MCore models) ``` diff --git a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py index 31c2b3529de..b786fcd1d92 100644 --- a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py +++ b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py @@ -18,13 +18,13 @@ from megatron.core.inference.inference_client import InferenceClient from megatron.core.inference.inference_request import DynamicInferenceRequestRecord from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.utils import configure_nvtx_profiling from megatron.inference.utils import ( add_inference_args, get_dynamic_inference_engine, get_model_for_inference, ) from megatron.training import get_args, get_tokenizer, initialize_megatron -from megatron.training.arguments import parse_and_validate_args # pylint: disable=line-too-long @@ -218,6 +218,7 @@ async def main( args_defaults={'no_load_rng': True, 'no_load_optim': True}, ) initialize_megatron() + configure_nvtx_profiling(True) tokenizer = get_tokenizer() diff --git a/examples/mamba/run_text_gen_server_8b.sh b/examples/mamba/run_text_gen_server_8b.sh index d228e0c0edb..f183dea4ad1 100755 --- a/examples/mamba/run_text_gen_server_8b.sh +++ b/examples/mamba/run_text_gen_server_8b.sh @@ -22,7 +22,7 @@ export NCCL_IB_QPS_PER_CONNECTION=4 export TRITON_CACHE_DIR="./triton-cache/" export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" -torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ +torchrun $DISTRIBUTED_ARGS ../../tools/run_hybrid_text_generation_server.py \ --tensor-model-parallel-size 1 \ --pipeline-model-parallel-size 1 \ --untie-embeddings-and-output-weights \ @@ -46,5 +46,5 @@ torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ --bf16 \ --micro-batch-size 1 \ --use-mcore-models \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec \ --seed 42 diff --git a/examples/mamba/train.sh b/examples/mamba/train.sh index ba83f0d4e33..f971242ff0b 100755 --- a/examples/mamba/train.sh +++ b/examples/mamba/train.sh @@ -96,8 +96,8 @@ options=" \ --eval-iters 32 \ --bf16 \ --use-mcore-models \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec \ --no-create-attention-mask-in-dataloader \ --tensorboard-dir ${TENSORBOARD_DIR}" -torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options} +torchrun --nproc_per_node 8 ../../pretrain_hybrid.py ${options} diff --git a/examples/megatron_fsdp/README.md b/examples/megatron_fsdp/README.md new file mode 100644 index 00000000000..eaf5eca1364 --- /dev/null +++ b/examples/megatron_fsdp/README.md @@ -0,0 +1,157 @@ +# Megatron-FSDP Examples + +Example scripts for training and checkpoint conversion using [Megatron-FSDP](../../docs/user-guide/features/megatron_fsdp.md). These demonstrate recommended configurations for Llama 3 8B and DeepSeek-V3 671B models, as well as checkpoint format conversion between `torch_dist` (N-D parallel) and `fsdp_dtensor` formats. + +## Scripts + +### `train_llama3_8b_fsdp_h100_fp8.sh` + +Single-node training script for **Llama 3 8B** using Megatron-FSDP with FP8 precision on H100 GPUs. Uses `torchrun` for local distributed training and supports both mock data (for benchmarking) and real datasets. 
+ +#### Usage + +Run from the root of the Megatron-LM repository: + +```bash +# With mock data (default, for benchmarking) +bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh + +# With real data +bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh \ + checkpoints/llama3_8b_fsdp_fp8 \ + tensorboard_logs/llama3_8b_fsdp_fp8 \ + /path/to/tokenizer \ + /path/to/data_prefix +``` + +| Positional Argument | Default | Description | +|---------------------|---------|-------------| +| `$1` — Checkpoint path | `checkpoints/llama3_8b_fsdp_fp8` | Directory for saving and loading checkpoints. | +| `$2` — TensorBoard path | `tensorboard_logs/llama3_8b_fsdp_fp8` | Directory for TensorBoard logs. | +| `$3` — Tokenizer | `MOCK` | Path to a tokenizer model, or `MOCK` for `NullTokenizer`. | +| `$4` — Data path | `MOCK` | Data prefix for training data, or `MOCK` for mock data. | + +#### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `USE_MEGATRON_FSDP` | `1` | Set to `1` to enable Megatron-FSDP. Set to `0` to train with standard DDP. | +| `SHARDING_STRATEGY` | `optim_grads_params` | FSDP sharding strategy (ZeRO-3). Options: `no_shard`, `optim`, `optim_grads`, `optim_grads_params`. | +| `OUTER_SHARDING_STRATEGY` | `no_shard` | DP-Outer sharding strategy for HSDP/HFSDP. Options: `no_shard`, `optim`. | +| `MASTER_ADDR` | `localhost` | Master node address for distributed training. | +| `MASTER_PORT` | `6000` | Master node port. | +| `NODE_RANK` | `0` | Rank of the current node. | + +#### Configuration Summary + +- **Model**: Llama 3 8B (GQA with 32 heads / 8 KV groups, RoPE, SwiGLU, RMSNorm) +- **Parallelism**: TP=1, CP=1, PP=1, 8 GPUs per node, FSDP ZeRO-3 +- **Precision**: FP8 (hybrid format) with BF16 training and BF16 gradient reduction +- **Batch size**: micro-batch=1, global-batch=128, sequence length=8192 +- **Optimizations**: NCCL user buffers, FSDP double buffering, manual registration, meta-device initialization, per-token loss, overlapped grad-reduce and param-gather + +--- + +### `sbatch_mfsdp_deepseek_v3.sh` + +Multi-node SLURM training script for **DeepSeek-V3** (671B MoE) using Megatron-FSDP. Submits an `sbatch` job with containerized execution via `srun`. + +#### Usage + +Set the required configuration variables and submit: + +```bash +export MEGATRON_PATH=/path/to/Megatron-LM +export CONTAINER_IMAGE=/path/to/container.sqsh # or docker image URL +export OUTPUT_PATH=/path/to/output +export DATA_PATH=/path/to/training/data + +bash examples/megatron_fsdp/sbatch_mfsdp_deepseek_v3.sh +``` + +Before running, update the `#SBATCH` directives and `--container-mounts` in the script to match your cluster configuration. + +#### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `MEGATRON_PATH` | *(required)* | Path to the Megatron-LM repository. | +| `CONTAINER_IMAGE` | *(required)* | Container image (`.sqsh` file or Docker URL). | +| `OUTPUT_PATH` | *(required)* | Base directory for checkpoints, TensorBoard logs, SLURM logs, and Nsight profiles. | +| `DATA_PATH` | *(required)* | Training data prefix path. | +| `USE_MEGATRON_FSDP` | `1` | Enable Megatron-FSDP. Set to `0` for standard DDP. | +| `SHARDING_STRATEGY` | `optim_grads_params` | FSDP sharding strategy (ZeRO-3). | +| `TP` | `1` | Tensor parallel size. | +| `EP` | `8` | Expert parallel size. | +| `MBS` | `4` | Micro-batch size. | +| `GBS` | `2048` | Global batch size. 
| +| `PROFILE` | `0` | Set to `1` to enable Nsight Systems profiling (steps 10–12). | +| `WANDB` | `1` | Set to `1` to enable Weights & Biases logging. Requires `WANDB_API_KEY`. | +| `COMMENT` | N/A | Tag appended to W&B experiment names and Nsight profile filenames. | + +#### Configuration Summary + +- **Model**: DeepSeek-V3 (61 layers, 256 routed experts, top-8 routing, Multi-Latent Attention, MTP) +- **Parallelism**: TP=1, EP=8, CP=1, FSDP ZeRO-3 +- **Precision**: BF16 +- **MoE**: Flex dispatcher with HybridEP backend, grouped GEMM, sigmoid routing with expert bias, auxiliary sequence loss +- **Recomputation**: Selective recomputation of `mlp`, `moe`, `mla_up_proj`, and `layernorm` modules +- **Optimizations**: NCCL user buffers, FSDP double buffering, meta-device initialization, per-token loss, overlapped grad-reduce and param-gather +- **Tokenizer**: `deepseek-ai/DeepSeek-V3` via HuggingFace + +--- + +### `sbatch_checkpoint_convert.sh` + +SLURM batch script for converting checkpoints from **`torch_dist`** (N-D parallel) format to **`fsdp_dtensor`** (Megatron-FSDP) format. This enables resuming training under Megatron-FSDP from checkpoints originally saved with tensor/pipeline/expert parallelism. + +#### Prerequisites + +Before converting, you need a `param_to_param_group_map.json` file. Generate it by running a `torch_dist` training job with the `--dump-param-to-param-group-map` flag, then converting the output: + +```bash +# 1. Run a training job (or trivial experiment) with the dump flag +--dump-param-to-param-group-map /path/to/param_to_param_group_map + +# 2. Convert the dumped map to JSON +python tools/checkpoint/checkpoint_inspector.py \ + print-torch-dcp-in-json /path/to/param_to_param_group_map +``` + +See the [Checkpoint Conversion](../../docs/user-guide/features/megatron_fsdp.md#checkpoint-conversion) section in the Megatron-FSDP docs for details. + +#### Usage + +Set the required configuration variables, update the checkpoint paths in `RUN_CMD`, and submit: + +```bash +export MEGATRON_PATH=/path/to/Megatron-LM +export CONTAINER_IMAGE=/path/to/container.sqsh +export OUTPUT_PATH=/path/to/output + +bash examples/megatron_fsdp/sbatch_checkpoint_convert.sh +``` + +Before running, you must edit the script to fill in: +- The input `torch_dist` checkpoint path +- The output `fsdp_dtensor` checkpoint path +- The path to `param_to_param_group_map.json` +- The `#SBATCH` directives and `--container-mounts` for your cluster + +#### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `MEGATRON_PATH` | *(required)* | Path to the Megatron-LM repository. | +| `CONTAINER_IMAGE` | *(required)* | Container image (`.sqsh` file or Docker URL). | +| `OUTPUT_PATH` | *(required)* | Base directory for SLURM logs. | + +#### Conversion Command + +The script runs `checkpoint_inspector.py convert-torch-dist-to-fsdp-dtensor` with the `--swiglu` flag (for models using SwiGLU activations). Remove `--swiglu` if converting a non-SwiGLU model. + +## Further Reading + +- [Megatron-FSDP User Guide](../../docs/user-guide/features/megatron_fsdp.md) — full feature guide, API reference, and sharding strategy documentation. +- [Megatron-FSDP on PyPI](https://pypi.org/project/megatron-fsdp/) — standalone `fully_shard` API. +- [Megatron-FSDP Source](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/distributed/fsdp/src) — implementation source code. 
diff --git a/docs/discussions/megatron-fsdp-user-guide/example-scripts/sbatch_checkpoint_convert.sh b/examples/megatron_fsdp/sbatch_checkpoint_convert.sh similarity index 100% rename from docs/discussions/megatron-fsdp-user-guide/example-scripts/sbatch_checkpoint_convert.sh rename to examples/megatron_fsdp/sbatch_checkpoint_convert.sh diff --git a/docs/discussions/megatron-fsdp-user-guide/example-scripts/sbatch_mfsdp_deepseek_v3.sh b/examples/megatron_fsdp/sbatch_mfsdp_deepseek_v3.sh similarity index 99% rename from docs/discussions/megatron-fsdp-user-guide/example-scripts/sbatch_mfsdp_deepseek_v3.sh rename to examples/megatron_fsdp/sbatch_mfsdp_deepseek_v3.sh index 7b93d25d943..22a8f22f68c 100644 --- a/docs/discussions/megatron-fsdp-user-guide/example-scripts/sbatch_mfsdp_deepseek_v3.sh +++ b/examples/megatron_fsdp/sbatch_mfsdp_deepseek_v3.sh @@ -23,7 +23,7 @@ TP=${TP:-1} EP=${EP:-8} MBS=${MBS:-4} GBS=${GBS:-2048} -COMMENT=${COMMENT:-"hybridep-selective-recompute"} +COMMENT=${COMMENT:-""} PRETRAIN_ARGS=( --distributed-timeout-minutes 60 diff --git a/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh b/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh new file mode 100644 index 00000000000..ddd3f160fa7 --- /dev/null +++ b/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +CHECKPOINT_PATH=${1:-"checkpoints/llama3_8b_fsdp_fp8"} +TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama3_8b_fsdp_fp8"} +TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK" +DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK" + +# Create directories if they don't exist +mkdir -p "$(dirname "$CHECKPOINT_PATH")" +mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")" + +# Distributed training setup +GPUS_PER_NODE=8 +NUM_NODES=1 +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-6000} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +# Path to the pretrain_gpt.py script, assuming this script +# is run from the root of the Megatron-LM repository. 
+PRETRAIN_SCRIPT_PATH="pretrain_gpt.py" + +# Model & Training Parameters +USE_MEGATRON_FSDP=${USE_MEGATRON_FSDP:-1} +SHARDING_STRATEGY=${SHARDING_STRATEGY:-"optim_grads_params"} +OUTER_SHARDING_STRATEGY=${OUTER_SHARDING_STRATEGY:-"no_shard"} +TP_SIZE=1 +CP_SIZE=1 +PP_SIZE=1 +MICRO_BATCH_SIZE=1 +GLOBAL_BATCH_SIZE=128 +NUM_LAYERS=32 +DTYPE="fp8" +SEQ_LENGTH=8192 +MAX_POSITION_EMBEDDINGS=8192 + +# Data cache path (useful for both mock and real data) +DATA_CACHE_PATH="${PWD}/benchmark_cache_llama3_8b_fsdp_fp8" +mkdir -p "$DATA_CACHE_PATH" + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +MODEL_ARGS=( + --use-mcore-models + --num-layers $NUM_LAYERS + --hidden-size 4096 + --ffn-hidden-size 14336 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 8 + --kv-channels 128 + --seq-length $SEQ_LENGTH + --max-position-embeddings $MAX_POSITION_EMBEDDINGS + --position-embedding-type rope + --rotary-base 1000000 + --rotary-percent 1.0 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --swiglu + --normalization RMSNorm + --init-method-std 0.0134 + --attention-backend fused + --apply-layernorm-1p + --untie-embeddings-and-output-weights + --disable-bias-linear +) + +TRAINING_ARGS=( + --micro-batch-size $MICRO_BATCH_SIZE + --global-batch-size $GLOBAL_BATCH_SIZE + --train-samples 1953125000 + --lr-decay-samples 1949218748 + --lr-warmup-samples 3906252 + --lr 0.00015 + --min-lr 0.00001 + --decoupled-lr 5.0e-4 + --decoupled-min-lr 4.5e-5 + --lr-decay-style cosine + --clip-grad 1.0 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 + --bf16 + --cross-entropy-loss-fusion + --manual-gc + --empty-unused-memory-level 1 + --exit-duration-in-mins 235 +) + +if [ "${USE_MEGATRON_FSDP}" = 1 ]; then + unset CUDA_DEVICE_MAX_CONNECTIONS + TRAINING_ARGS=( + "${TRAINING_ARGS[@]}" + --use-megatron-fsdp + --data-parallel-sharding-strategy ${SHARDING_STRATEGY} + --no-gradient-accumulation-fusion + --calculate-per-token-loss + --init-model-with-meta-device + --ckpt-format fsdp_dtensor + --grad-reduce-in-bf16 + --use-nccl-ub + --fsdp-double-buffer + --fsdp-manual-registration + # To enable HFSDP, DP full-sharding of the optimizer state with + # hierarchical data parallelism (DP-Outer=2, DP-Inner=DP//2)... + # --num-distributed-optimizer-instances 2 + # --outer-dp-sharding-strategy ${OUTER_SHARDING_STRATEGY} + # To further customize Megatron-FSDP data precision... 
+ # --megatron-fsdp-main-params-dtype fp32 + # --megatron-fsdp-main-grads-dtype auto + # --megatron-fsdp-grad-comm-dtype auto + ) +fi + +# Conditional arguments based on DTYPE (FP8) +DTYPE_ARGS=() +if [[ "$DTYPE" == "fp8" ]]; then + DTYPE_ARGS+=( + "--fp8-format hybrid" + "--fp8-amax-history-len 1024" + "--fp8-amax-compute-algo max" + "--fp8-param-gather" + ) +fi + +# Model parallelism arguments +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size $TP_SIZE + --context-parallel-size $CP_SIZE + --sequence-parallel +) + +# Distributed Data Parallel (DDP) arguments +# From original script's ddp_args +DDP_ARGS=( + --use-distributed-optimizer + --overlap-grad-reduce + --overlap-param-gather +) +TRAINING_ARGS+=("${DDP_ARGS[@]}") + + +# Data arguments (conditional for mock vs real data) +DATA_ARGS_LIST=() +if [[ "$TOKENIZER_ARG" == "MOCK" ]] || [[ "$DATA_ARG" == "MOCK" ]] || [[ -z "$TOKENIZER_ARG" ]]; then + DATA_ARGS_LIST+=( + "--mock-data" + "--tokenizer-type NullTokenizer" + "--vocab-size 128256" + "--data-cache-path ${DATA_CACHE_PATH}" + "--tiktoken-pattern v2" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + ) +else + # Settings for real data + DATA_ARGS_LIST+=( + "--data-path $DATA_ARG" + "--tokenizer-type HuggingFaceTokenizer" + "--tokenizer-model $TOKENIZER_ARG" + "--data-cache-path ${DATA_CACHE_PATH}" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + # Note: --vocab-size might be inferred by HuggingFaceTokenizer or might need to be explicit. + "--vocab-size 128256" + ) +fi + +EVAL_AND_LOGGING_ARGS=( + --log-interval 1 + --eval-iters 32 + --eval-interval 100 + --save-interval 1000 + --log-throughput + --profile + --profile-step-start 4 + --profile-step-end 6 + --distributed-timeout-minutes 60 + --save "$CHECKPOINT_PATH" + --load "$CHECKPOINT_PATH" + --tensorboard-dir "$TENSORBOARD_LOGS_PATH" +) + +# Ensure pretrain_gpt.py is found +if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then + echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH" + echo "Please ensure you are running this script from the root of the Megatron-LM repository, and pretrain_gpt.py is present." 
+ exit 1 +fi + +# Run the training command +torchrun ${DISTRIBUTED_ARGS[@]} \ + "$PRETRAIN_SCRIPT_PATH" \ + ${MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${DTYPE_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS_LIST[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} + +set +x \ No newline at end of file diff --git a/examples/mimo/train.py b/examples/mimo/train.py index 05eb4f2ab0c..52be3f7ec58 100644 --- a/examples/mimo/train.py +++ b/examples/mimo/train.py @@ -19,6 +19,7 @@ get_tensor_model_parallel_src_rank, ) from megatron.training import get_args, pretrain, print_rank_0 +from megatron.training.argument_utils import pretrain_cfg_container_from_args from megatron.training.arguments import parse_and_validate_args sys.path.append( @@ -297,8 +298,10 @@ def model_provider( if __name__ == "__main__": train_valid_test_datasets_provider.is_distributed = True - parse_and_validate_args(args_defaults={}, extra_args_provider=add_mimo_args) + args = parse_and_validate_args(args_defaults={}, extra_args_provider=add_mimo_args) + full_config = pretrain_cfg_container_from_args(args) pretrain( + full_config, train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py index ad24850b631..c51fb69f496 100644 --- a/examples/multimodal/layer_specs.py +++ b/examples/multimodal/layer_specs.py @@ -1,8 +1,9 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. import torch +from megatron.core.extensions.transformer_engine import HAVE_TE from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.models.hybrid.hybrid_block import HybridStack, HybridStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules from megatron.core.ssm.mlp_layer import MLPLayer @@ -15,7 +16,6 @@ from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.typed_torch import not_none -from megatron.core.extensions.transformer_engine import HAVE_TE if HAVE_TE: from megatron.core.extensions.transformer_engine import ( @@ -125,15 +125,15 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec: ) -def get_mamba_layer_spec_te(padding=False) -> ModuleSpec: +def get_hybrid_layer_spec_te(padding=False) -> ModuleSpec: attn_mask_type = AttnMaskType.causal # Padding mask is needed for e.g. Context Parallel. if padding: attn_mask_type = AttnMaskType.padding_causal return ModuleSpec( - module=MambaStack, - submodules=MambaStackSubmodules( + module=HybridStack, + submodules=HybridStackSubmodules( mamba_layer=ModuleSpec( module=MambaLayer, submodules=MambaLayerSubmodules( diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index 494a854099e..fff3bceaa99 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -1,23 +1,34 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import warnings +# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. 
import logging +import warnings from copy import deepcopy import torch from config import get_language_model_config, get_vision_model_config, get_vision_projection_config -from layer_specs import (get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te, - get_mamba_layer_spec_te) +from layer_specs import ( + get_hybrid_layer_spec_te, + get_layer_spec, + get_layer_spec_te, + get_mlp_module_spec, + get_norm_mlp_module_spec_te, +) from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.core.utils import log_single_rank from megatron.training import get_args, get_tokenizer, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args -from megatron.core.utils import log_single_rank def model_provider( - pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True, - vp_stage=None, config=None, pg_collection=None + pre_process=True, + post_process=True, + add_encoder=True, + add_decoder=True, + parallel_output=True, + vp_stage=None, + config=None, + pg_collection=None, ) -> LLaVAModel: """Builds the model. @@ -51,7 +62,7 @@ def model_provider( args.pixel_shuffle, args.use_tile_tags, args.max_num_tiles, - args.tokenizer_prompt_format + args.tokenizer_prompt_format, ) old_seq_length = args.seq_length args.seq_length = args.encoder_seq_length = num_image_embeddings @@ -59,10 +70,12 @@ def model_provider( log_single_rank( logging.getLogger(__name__), logging.WARNING, - f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" + f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})", ) - max_num_image_embeddings = max((args.max_num_tiles + int(args.use_thumbnail)), args.num_frames) * num_image_embeddings + max_num_image_embeddings = ( + max((args.max_num_tiles + int(args.use_thumbnail)), args.num_frames) * num_image_embeddings + ) assert ( args.decoder_seq_length is not None @@ -88,10 +101,16 @@ def model_provider( language_config = get_language_model_config(language_config) if language_model_type.startswith("hf://"): - assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1" - assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1" + assert ( + args.tensor_model_parallel_size == 1 + ), "Huggingface models do not support --tensor-model-parallel-size > 1" + assert ( + args.pipeline_model_parallel_size < 2 + ), "Huggingface models do not support --pipeline-model-parallel-size > 1" assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" - assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert ( + args.context_parallel_size < 2 + ), "Huggingface models do not support --context-parallel-size > 1" if language_model_type.startswith("hf://"): language_transformer_layer_spec = None @@ -99,7 +118,7 @@ def model_provider( # Padding mask needed for SP/CP. 
padding = args.context_parallel_size > 1 and args.sequence_parallel if args.language_model_type.startswith('nemotron5-hybrid'): - language_transformer_layer_spec = get_mamba_layer_spec_te(padding=padding) + language_transformer_layer_spec = get_hybrid_layer_spec_te(padding=padding) else: language_transformer_layer_spec = get_layer_spec_te( is_vit=False, padding=padding @@ -115,7 +134,9 @@ def model_provider( ) if vision_model_type.startswith("hf://"): assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" - assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert ( + args.context_parallel_size < 2 + ), "Huggingface models do not support --context-parallel-size > 1" if vision_model_type in ["clip", "siglip", "radio", "cradio-g"]: if use_te: @@ -129,17 +150,23 @@ def model_provider( elif vision_model_type == "radio-g": if use_te: from radio.radio_g import get_radio_g_layer_spec_te - vision_transformer_layer_spec = get_radio_g_layer_spec_te() # TENorm detects LayerNorm/RMS automatically. + + vision_transformer_layer_spec = ( + get_radio_g_layer_spec_te() + ) # TENorm detects LayerNorm/RMS automatically. else: from radio.radio_g import get_radio_g_layer_spec + vision_transformer_layer_spec = get_radio_g_layer_spec( normalization=vision_config.normalization ) elif vision_model_type == "internvit": from nvlm.internvit import get_internvit_layer_spec + vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te) elif vision_model_type == "internvit300M": from nvlm.internvit import get_internvit300M_layer_spec + vision_transformer_layer_spec = get_internvit300M_layer_spec(use_te=use_te) elif vision_model_type.startswith("hf://"): vision_transformer_layer_spec = None @@ -154,7 +181,9 @@ def model_provider( # Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size. vision_config.pipeline_model_parallel_size = 1 - vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size + vision_projection_config.pipeline_model_parallel_size = ( + vision_config.pipeline_model_parallel_size + ) # Make sure the vision model does not inherit first and last pipeline num layers from the language model. vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None @@ -166,7 +195,10 @@ def model_provider( # Toggle --recompute* for the vision and language model separately. if args.recompute_vision: - if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None: + if ( + vision_config.recompute_method is not None + and vision_config.recompute_granularity is not None + ): vision_config.recompute_num_layers = vision_config.num_layers else: vision_config.recompute_granularity = None @@ -188,7 +220,9 @@ def model_provider( tokenizer = get_tokenizer() image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg." + assert ( + image_token_index is not None + ), f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg." 
tile_tags = _get_tile_tags(args, tokenizer) @@ -247,16 +281,25 @@ def _get_tile_tags(args, tokenizer): thumbnail_tag_text = "" if args.tokenizer_prompt_format.startswith("nemotron"): - tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] + tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [ + thumbnail_tag_text + ] else: - tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] + tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [ + thumbnail_tag_text + ] elif args.max_num_tiles <= 12: thumbnail_tag_text = "" if args.tokenizer_prompt_format == "nvlm-yi-34b": thumbnail_tag_text = "" - elif args.tokenizer_prompt_format.startswith("nemotron") or args.tokenizer_prompt_format == "llama3p1": + elif ( + args.tokenizer_prompt_format.startswith("nemotron") + or args.tokenizer_prompt_format == "llama3p1" + ): thumbnail_tag_text = "" - tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] + tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [ + thumbnail_tag_text + ] else: raise ValueError("We only support max_num_tiles <= 12 when using nvlm image_tag_type") diff --git a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh index 1fa00889e99..36698852936 100644 --- a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.sh @@ -11,7 +11,7 @@ MODEL_ARGS=" \ --trust-remote-code \ --save-interval 100000 \ --micro-batch-size 1 \ - --moe-token-dispatcher-type allgather \ + --moe-token-dispatcher-type alltoall \ --enable-experimental \ --moe-permute-fusion \ --use-fused-weighted-squared-relu \ @@ -51,5 +51,5 @@ MODEL_ARGS=" \ --bf16 \ --seq-length 8192 \ --max-position-embeddings 8192 \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " diff --git a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh index 977be033df0..f38a316632a 100644 --- a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh @@ -28,7 +28,7 @@ MODEL_ARGS=" \ --moe-router-dtype fp32 \ --moe-router-load-balancing-type seq_aux_loss \ --moe-shared-expert-intermediate-size 5376 \ - --moe-token-dispatcher-type allgather \ + --moe-token-dispatcher-type alltoall \ --moe-latent-size 1024 \ \ --attention-backend flash \ @@ -58,5 +58,5 @@ MODEL_ARGS=" \ --bf16 \ --seq-length 8192 \ --max-position-embeddings 8192 \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " diff --git a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh index 83867430a97..51aff10a22a 100644 --- a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-Nano-9B-v2.sh @@ -35,6 +35,6 @@ MODEL_ARGS=" \ --tokenizer-type HuggingFaceTokenizer \ --make-vocab-size-divisible-by 1 \ --use-mcore-models \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ --padded-vocab-size 131072 \ " diff --git 
a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh index 901e607f298..e2da6a3c33d 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-47B-Reasoning-128K.sh @@ -33,5 +33,5 @@ MODEL_ARGS=" \ --max-position-embeddings 8192 \ --tokenizer-type HuggingFaceTokenizer \ --use-mcore-models \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh index 084db49e0eb..523f7d521b0 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-4B-Instruct.sh @@ -38,5 +38,5 @@ MODEL_ARGS=" \ --make-vocab-size-divisible-by 1 \ --use-mcore-models \ --rotary-base 10000 \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh index 645a159d075..be80d8a9a19 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-56B-Base-8K.sh @@ -35,5 +35,5 @@ MODEL_ARGS=" \ --max-position-embeddings 8192 \ --tokenizer-type HuggingFaceTokenizer \ --bf16 \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " diff --git a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh index 66f3ad368b4..36b242e36dd 100644 --- a/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh +++ b/examples/post_training/modelopt/conf/nvidia/Nemotron-H-8B-Base-8K.sh @@ -37,6 +37,6 @@ MODEL_ARGS=" \ --use-mcore-models \ --rotary-percent 0.5 \ --rotary-base 500000 \ - --export-model-type MambaModel \ + --export-model-type HybridModel \ " # --rotary-base 10000 \ diff --git a/examples/post_training/modelopt/convert_model.py b/examples/post_training/modelopt/convert_model.py index eaec9789e1e..cc34e1e3e3c 100644 --- a/examples/post_training/modelopt/convert_model.py +++ b/examples/post_training/modelopt/convert_model.py @@ -19,12 +19,10 @@ from megatron.core.parallel_state import destroy_model_parallel from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder -from megatron.post_training.utils import ( - report_current_memory_info, - to_empty_if_meta, -) +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder +from megatron.post_training.utils import report_current_memory_info, to_empty_if_meta from megatron.training import get_args +from megatron.training.arguments import parse_and_validate_args from megatron.training.checkpointing import save_checkpoint from megatron.training.initialize import initialize_megatron from megatron.training.utils import print_rank_0, unwrap_model @@ -102,7 +100,7 @@ def check_arguments(): if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_convert_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -110,6 +108,7 @@ def check_arguments(): 'no_load_optim': True, 
}, ) + initialize_megatron() check_arguments() args = get_args() @@ -129,7 +128,7 @@ def check_arguments(): ) model = get_model( - functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False ) report_current_memory_info() @@ -142,10 +141,7 @@ def check_arguments(): print_rank_0( "Import model from Hugging Face checkpoint in dtype {}.".format(str(import_dtype)) ) - import_kwargs = { - "dtype": import_dtype, - "moe_router_dtype": args.moe_router_dtype, - } + import_kwargs = {"dtype": import_dtype, "moe_router_dtype": args.moe_router_dtype} if "trust_remote_code" in inspect.signature(import_mcore_gpt_from_hf).parameters: import_kwargs.update({"trust_remote_code": args.trust_remote_code}) import_mcore_gpt_from_hf( diff --git a/examples/post_training/modelopt/distillation.md b/examples/post_training/modelopt/distillation.md index 49f73c4edde..9946723364e 100644 --- a/examples/post_training/modelopt/distillation.md +++ b/examples/post_training/modelopt/distillation.md @@ -53,7 +53,7 @@ Without this configuration file, the default logits-only distillation with scale ### Training -Distillation is triggered by calling `pretrain_gpt.py` or `pretrain_mamba.py` with the following arguments: +Distillation is triggered by calling `pretrain_gpt.py` or `pretrain_hybrid.py` with the following arguments: ```bash --export-kd-teacher-load diff --git a/examples/post_training/modelopt/export.py b/examples/post_training/modelopt/export.py index 5e3b2a1716e..c080115d729 100755 --- a/examples/post_training/modelopt/export.py +++ b/examples/post_training/modelopt/export.py @@ -15,8 +15,9 @@ from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.training import get_args, get_model +from megatron.training.arguments import parse_and_validate_args from megatron.training.initialize import initialize_megatron from megatron.training.utils import unwrap_model from model_provider import model_provider @@ -49,7 +50,7 @@ def add_modelopt_export_args(parser): if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_modelopt_export_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -57,6 +58,7 @@ def add_modelopt_export_args(parser): 'no_load_optim': True, }, ) + initialize_megatron() args = get_args() @@ -74,7 +76,7 @@ def add_modelopt_export_args(parser): ) model = get_model( - functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False ) # Materialize the model from meta device to cpu before loading the checkpoint. @@ -86,7 +88,6 @@ def add_modelopt_export_args(parser): else: raise ValueError(f"Invalid load checkpoint directory: {args.load}") - # Decide whether we are exporting only the extra_modules (e.g. EAGLE3). # Only the last pp stage may have extra_modules, hence broadcast from the last rank. 
export_extra_modules = hasattr(unwrapped_model, "eagle_module") or hasattr( @@ -102,8 +103,11 @@ def add_modelopt_export_args(parser): "export_dir": args.export_dir, "moe_router_dtype": unwrapped_model.config.moe_router_dtype, } - if "trust_remote_code" in inspect.signature(mtex.export_mcore_gpt_to_hf).parameters: + export_fn = ( + mtex.export_mcore_gpt_to_hf_vllm_fq if args.export_vllm_fq else mtex.export_mcore_gpt_to_hf + ) + + if "trust_remote_code" in inspect.signature(export_fn).parameters: export_kwargs.update({"trust_remote_code": args.trust_remote_code}) - - export_fn = mtex.export_mcore_gpt_to_hf_vllm_fq if args.export_vllm_fq else mtex.export_mcore_gpt_to_hf + export_fn(unwrapped_model, args.pretrained_model_name, **export_kwargs) diff --git a/examples/post_training/modelopt/finetune.py b/examples/post_training/modelopt/finetune.py index f7f7c24f970..adff7421fd3 100755 --- a/examples/post_training/modelopt/finetune.py +++ b/examples/post_training/modelopt/finetune.py @@ -13,13 +13,14 @@ import datasets import torch import transformers +from utils import get_hf_tokenizer from megatron.core import mpu, tensor_parallel from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.loss_func import loss_func -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.post_training.non_loss_data_func import report_draft_acceptance_length from megatron.training import get_args, get_timers, pretrain from megatron.training.utils import ( @@ -27,7 +28,6 @@ get_ltor_masks_and_position_ids, print_rank_0, ) -from utils import get_hf_tokenizer from model_provider import model_provider REMOVE_THINK_CHAT_TEMPLATE = ( @@ -38,12 +38,16 @@ def add_finetune_args(parser): """Add additional arguments for finetune.""" group = parser.add_argument_group(title='Finetune') - group.add_argument("--offline-distillation-data", type=str, help="Path to the offline dataset directory with base model features.") - + group.add_argument( + "--offline-distillation-data", + type=str, + help="Path to the offline dataset directory with base model features.", + ) add_modelopt_args(parser) return parser + def get_eos_id(): """Return the eos token id. 
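# [Editor's note: the following sketch is an illustration added for this document and is
# not part of the patch.] The new --offline-distillation-data argument introduced above is
# assumed to point at the directory layout produced by offline_feature_extract.py further
# down in this diff: train/ (first ~98% of samples), valid/ (next ~1%) and test/ (final ~1%)
# subdirectories, each holding files 00000000.pt, 00000001.pt, ... written with torch.save().
# A minimal reader mirroring what OfflineDataset does; the path construction here is an
# assumption for illustration only:
import os
import torch

def peek_offline_sample(root, split="train", idx=0):
    # Each file holds the features saved by offline_feature_extract.py, i.e. the output of
    # model(input_ids, return_eagle_inputs=True).
    path = os.path.join(root, split, "{:08d}.pt".format(idx))
    return torch.load(path)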
@@ -83,6 +87,7 @@ def __getitem__(self, idx): sample = torch.load(file_path) return sample + class SFTDataset(torch.utils.data.Dataset): hf_dataset_to_kwargs = { @@ -105,7 +110,7 @@ class SFTDataset(torch.utils.data.Dataset): } hf_dataset_to_prompt_template = { - "Open-Orca/OpenOrca": "{{ messages['question'] + ' ' + messages['response'] + ' ' }}", + "Open-Orca/OpenOrca": "{{ messages['question'] + ' ' + messages['response'] + ' ' }}" } @classmethod @@ -161,13 +166,11 @@ def __init__( REMOVE_THINK_CHAT_TEMPLATE, "" ) - hf_dataset_kwargs = SFTDataset.hf_dataset_to_kwargs.get( - self.hf_dataset, {"split": "train"} - ) - self._raw_samples = datasets.load_dataset(self.hf_dataset, token=os.environ.get("HF_TOKEN", None), **hf_dataset_kwargs) - self._raw_samples = self._raw_samples.shard( - num_shards=self.num_shards, index=shard_index + hf_dataset_kwargs = SFTDataset.hf_dataset_to_kwargs.get(self.hf_dataset, {"split": "train"}) + self._raw_samples = datasets.load_dataset( + self.hf_dataset, token=os.environ.get("HF_TOKEN", None), **hf_dataset_kwargs ) + self._raw_samples = self._raw_samples.shard(num_shards=self.num_shards, index=shard_index) print( "Rank {:3}/{:3} creates SFT data shard {:3}/{:3} with {:10} raw samples".format( @@ -349,9 +352,15 @@ def train_valid_test_sft_datasets_provider(train_val_test_num_samples): raise ValueError("SFTDataloader only supports micro_batch_size=1.") if args.export_offline_model: - train_ds = OfflineDataset(os.path.join(args.offline_distillation_data, "train"), train_val_test_num_samples[0]) - valid_ds = OfflineDataset(os.path.join(args.offline_distillation_data, "valid"), train_val_test_num_samples[1]) - test_ds = OfflineDataset(os.path.join(args.offline_distillation_data, "test"), train_val_test_num_samples[2]) + train_ds = OfflineDataset( + os.path.join(args.offline_distillation_data, "train"), train_val_test_num_samples[0] + ) + valid_ds = OfflineDataset( + os.path.join(args.offline_distillation_data, "valid"), train_val_test_num_samples[1] + ) + test_ds = OfflineDataset( + os.path.join(args.offline_distillation_data, "test"), train_val_test_num_samples[2] + ) print_rank_0("> finished creating offline SFT datasets ...") else: @@ -398,14 +407,15 @@ def get_batch(data_iterator): datatype = torch.int64 data_b = tensor_parallel.broadcast_data(keys, data, datatype) data_b["loss_mask"] = torch.ones_like(data_b["input_ids"]) - data_b["loss_mask"][data_b["loss_mask"]==get_eos_id()] = 0 - data_b["loss_mask"] = torch.cat([data_b["loss_mask"], torch.zeros(1,1).to(torch.cuda.current_device())], dim=-1) + data_b["loss_mask"][data_b["loss_mask"] == get_eos_id()] = 0 + data_b["loss_mask"] = torch.cat( + [data_b["loss_mask"], torch.zeros(1, 1).to(torch.cuda.current_device())], dim=-1 + ) keys = ["aux_hidden_states", "hidden_states"] datatype = torch.bfloat16 feature_b = tensor_parallel.broadcast_data(keys, data, datatype) - # Unpack the data received. tokens_ = data_b["input_ids"] tokens = tokens_[:, 0 : 0 + args.seq_length].contiguous() @@ -414,11 +424,16 @@ def get_batch(data_iterator): # Get the masks and postition ids. 
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, get_eos_id(), get_eos_id(), args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss, False + tokens, + get_eos_id(), + get_eos_id(), + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + False, ) loss_mask = loss_mask * answer_only_loss_mask.to(dtype=loss_mask.dtype) - labels = labels.contiguous() loss_mask = loss_mask.contiguous() @@ -431,8 +446,10 @@ def get_batch(data_iterator): } if args.export_offline_model: - batch["aux_hidden_states"] = feature_b["aux_hidden_states"].transpose(0, 1)[:args.seq_length] - batch["hidden_states"] = feature_b["hidden_states"].transpose(0, 1)[:args.seq_length] + batch["aux_hidden_states"] = feature_b["aux_hidden_states"].transpose(0, 1)[ + : args.seq_length + ] + batch["hidden_states"] = feature_b["hidden_states"].transpose(0, 1)[: args.seq_length] # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) @@ -450,7 +467,6 @@ def non_loss_data_func(model: GPTModel): print(e) - def forward_step(data_iterator, model: GPTModel): """Forward training step. @@ -476,7 +492,14 @@ def forward_step(data_iterator, model: GPTModel): timers("batch-generator").stop() if args.export_offline_model: - output_tensor = model(tokens, position_ids, attention_mask, labels=labels, aux_hidden_states=aux_hidden_states, hidden_states=hidden_states,) + output_tensor = model( + tokens, + position_ids, + attention_mask, + labels=labels, + aux_hidden_states=aux_hidden_states, + hidden_states=hidden_states, + ) else: output_tensor = model(tokens, position_ids, attention_mask, labels=labels) @@ -484,12 +507,16 @@ def forward_step(data_iterator, model: GPTModel): if __name__ == "__main__": + from megatron.training.arguments import parse_and_validate_args + + parse_and_validate_args( + extra_args_provider=add_finetune_args, + args_defaults={"tokenizer_type": "HuggingFaceTokenizer"}, + ) pretrain( train_valid_test_sft_datasets_provider, - partial(model_provider, modelopt_gpt_mamba_builder), + partial(model_provider, modelopt_gpt_hybrid_builder), ModelType.encoder_or_decoder, forward_step, - extra_args_provider=add_finetune_args, - args_defaults={"tokenizer_type": "HuggingFaceTokenizer"}, non_loss_data_func=non_loss_data_func, ) diff --git a/examples/post_training/modelopt/generate.py b/examples/post_training/modelopt/generate.py index 3d3f6571b34..75807cac7f7 100644 --- a/examples/post_training/modelopt/generate.py +++ b/examples/post_training/modelopt/generate.py @@ -8,21 +8,21 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +import modelopt.torch.quantization as mtq import torch from datasets import load_dataset +from utils import get_hf_tokenizer from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint from megatron.post_training.generate import simple_generate -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.post_training.utils import report_current_memory_info, to_empty_if_meta from megatron.training import get_args, get_model, initialize_megatron -from utils import get_hf_tokenizer +from megatron.training.arguments import parse_and_validate_args from megatron.training.utils import print_rank_0, unwrap_model from model_provider import model_provider -import 
modelopt.torch.quantization as mtq - warnings.filterwarnings('once') @@ -73,7 +73,7 @@ def get_conversations(example): if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_generate_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -81,6 +81,7 @@ def get_conversations(example): 'no_load_optim': True, }, ) + initialize_megatron() check_arguments() @@ -100,7 +101,9 @@ def get_conversations(example): UserWarning, ) - model = get_model(functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False) + model = get_model( + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False + ) report_current_memory_info() unwrapped_model = unwrap_model(model)[0] @@ -124,7 +127,6 @@ def get_conversations(example): tokenizer = get_hf_tokenizer() - if args.load is not None: load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights) print_rank_0("Done loading checkpoint") @@ -157,12 +159,19 @@ def get_conversations(example): ) ) ) - input_ids = tokenizer.apply_chat_template( - new_conversations, return_tensors="pt", add_generation_prompt=True + encoding = tokenizer.apply_chat_template( + new_conversations, + return_tensors="pt", + add_generation_prompt=True, + return_dict=True, ) + input_ids = encoding["input_ids"] with torch.no_grad(): output_ids = simple_generate( - unwrapped_model, input_ids.cuda(), osl=args.osl, disable_tqdm=args.disable_tqdm + unwrapped_model, + input_ids.cuda(), + osl=args.osl, + disable_tqdm=args.disable_tqdm, ) output_texts = tokenizer.batch_decode(output_ids)[0] print_rank_0("{}".format(output_texts)) diff --git a/examples/post_training/modelopt/mmlu.py b/examples/post_training/modelopt/mmlu.py index 5aa5d1c24c7..3ff4b51f957 100644 --- a/examples/post_training/modelopt/mmlu.py +++ b/examples/post_training/modelopt/mmlu.py @@ -2,42 +2,54 @@ """Sample Generate GPT.""" import functools +import logging import os import sys import warnings + import datasets -import logging import torch.distributed as dist sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +import modelopt.torch.quantization as mtq import torch from diskcache import Cache +from utils import get_hf_tokenizer from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint from megatron.post_training.generate import simple_generate -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.post_training.utils import report_current_memory_info from megatron.training import get_args, get_model, initialize_megatron -from utils import get_hf_tokenizer +from megatron.training.arguments import parse_and_validate_args from megatron.training.utils import print_rank_0, unwrap_model -import modelopt.torch.quantization as mtq from model_provider import model_provider logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) # set to debug if you need more logging +logger.setLevel(logging.INFO) # set to debug if you need more logging warnings.filterwarnings('ignore') + def add_mmlu_args(parser): """Add additional arguments for ModelOpt text generation PTQ.""" group = parser.add_argument_group(title='ModelOpt text generation ptq') group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.") group.add_argument("--fraction", type=float, default=1.0, help="Fraction 
of dataset to use.") group.add_argument("--lower-bound", type=float, default=None) - group.add_argument("--no-subject-prompt", action="store_true", help="Use empty prompt instead of subject-based prompt.") - group.add_argument("--mmlu-dataset", type=str, default="cais/mmlu", help="The default dataset to use is cais/mmlu from the HG hub.") + group.add_argument( + "--no-subject-prompt", + action="store_true", + help="Use empty prompt instead of subject-based prompt.", + ) + group.add_argument( + "--mmlu-dataset", + type=str, + default="cais/mmlu", + help="The default dataset to use is cais/mmlu from the HG hub.", + ) group.add_argument("--cache-dir", type=str, default=None) add_modelopt_args(parser) return parser @@ -133,7 +145,7 @@ def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=F if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_mmlu_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -141,6 +153,7 @@ def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=F 'no_load_optim': True, }, ) + initialize_megatron() args = get_args() cache = Cache(args.cache_dir) @@ -158,7 +171,9 @@ def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=F UserWarning, ) - model = get_model(functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False) + model = get_model( + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False + ) report_current_memory_info() # Materialize the model from meta device to gpu before loading the checkpoint. @@ -181,7 +196,10 @@ def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=F # [TODO]: fold_weight does not support TEGroupedMLP (QuantTEColumnParallelGroupedLinear) # which stores per-expert weights as weight0, weight1, etc. instead of a single weight. 
has_grouped_mlp = any("TEGroupedMLP" in type(m).__name__ for m in unwrapped_model.modules()) - if not getattr(unwrapped_model, "share_embeddings_and_output_weights", False) and not has_grouped_mlp: + if ( + not getattr(unwrapped_model, "share_embeddings_and_output_weights", False) + and not has_grouped_mlp + ): mtq.fold_weight(unwrapped_model) all_subjects = get_all_subjects() @@ -197,8 +215,10 @@ def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=F if idx > args.fraction * len(test_data): break label = ["A", "B", "C", "D"][test_example["answer"]] - prompt = generate_prompt(test_example, dev_data, few_shots=0, no_subject_prompt=args.no_subject_prompt) - cache_key = f"{args.load}_{subject}_{prompt}" # model name, subject, prompt + prompt = generate_prompt( + test_example, dev_data, few_shots=0, no_subject_prompt=args.no_subject_prompt + ) + cache_key = f"{args.load}_{subject}_{prompt}" # model name, subject, prompt if cache_key in cache: predict = cache[cache_key] diff --git a/examples/post_training/modelopt/offline_feature_extract.py b/examples/post_training/modelopt/offline_feature_extract.py index 80207faf2b2..6ea44943f74 100644 --- a/examples/post_training/modelopt/offline_feature_extract.py +++ b/examples/post_training/modelopt/offline_feature_extract.py @@ -14,7 +14,7 @@ from megatron.core import mpu from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron from megatron.training.utils import print_rank_0, unwrap_model from model_provider import model_provider @@ -29,20 +29,33 @@ def add_extract_args(parser): add_modelopt_args(parser) return parser + def extract_feature(dataset, model, output_dir, idx_start, idx_end): os.makedirs(output_dir, exist_ok=True) - for i in range(idx_start + mpu.get_expert_data_parallel_rank(), idx_end, mpu.get_expert_data_parallel_world_size()): + for i in range( + idx_start + mpu.get_expert_data_parallel_rank(), + idx_end, + mpu.get_expert_data_parallel_world_size(), + ): file_name = "{:08d}.pt".format(i - idx_start) file_path = os.path.join(output_dir, file_name) if not os.path.exists(file_path): - input_ids = dataset[i]["input_ids"][:dataset.seq_length].unsqueeze(0).to(torch.cuda.current_device()) + input_ids = ( + dataset[i]["input_ids"][: dataset.seq_length] + .unsqueeze(0) + .to(torch.cuda.current_device()) + ) output = model(input_ids, return_eagle_inputs=True) - if mpu.get_tensor_model_parallel_rank() == 0 and mpu.get_expert_model_parallel_rank() == 0: + if ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_expert_model_parallel_rank() == 0 + ): torch.save(output, file_path) torch.distributed.barrier() + if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_extract_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -50,10 +63,13 @@ def extract_feature(dataset, model, output_dir, idx_start, idx_end): 'no_load_optim': True, }, ) + initialize_megatron() args = get_args() tokenizer = get_tokenizer() - model = get_model(functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False) + model = get_model( + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False + ) load_modelopt_checkpoint(model, 
strict=not args.untie_embeddings_and_output_weights) print_rank_0("Done loading checkpoint") @@ -71,8 +87,24 @@ def extract_feature(dataset, model, output_dir, idx_start, idx_end): } sft_dataset = SFTDataset(args.num_samples, None, **kwargs) - extract_feature(sft_dataset, unwrapped_model, os.path.join(args.output_dir, "train"), 0, int(args.num_samples * 0.98)) - extract_feature(sft_dataset, unwrapped_model, os.path.join(args.output_dir, "valid"), int(args.num_samples * 0.98), int(args.num_samples * 0.99)) - extract_feature(sft_dataset, unwrapped_model, os.path.join(args.output_dir, "test"), int(args.num_samples * 0.99), args.num_samples) - - + extract_feature( + sft_dataset, + unwrapped_model, + os.path.join(args.output_dir, "train"), + 0, + int(args.num_samples * 0.98), + ) + extract_feature( + sft_dataset, + unwrapped_model, + os.path.join(args.output_dir, "valid"), + int(args.num_samples * 0.98), + int(args.num_samples * 0.99), + ) + extract_feature( + sft_dataset, + unwrapped_model, + os.path.join(args.output_dir, "test"), + int(args.num_samples * 0.99), + args.num_samples, + ) diff --git a/examples/post_training/modelopt/prune.py b/examples/post_training/modelopt/prune.py index 56bbffa0cd0..0e3e4e41dda 100644 --- a/examples/post_training/modelopt/prune.py +++ b/examples/post_training/modelopt/prune.py @@ -20,6 +20,7 @@ import modelopt.torch.prune as mtp from modelopt.torch.export import import_mcore_gpt_from_hf from modelopt.torch.prune.plugins.mcore_minitron import SUPPORTED_HPARAMS +from utils import get_hf_tokenizer from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, @@ -28,12 +29,10 @@ from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint from megatron.post_training.generate import simple_generate -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder -from megatron.post_training.utils import ( - report_current_memory_info, -) +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder +from megatron.post_training.utils import report_current_memory_info from megatron.training import get_args, get_model, initialize_megatron -from utils import get_hf_tokenizer +from megatron.training.arguments import parse_and_validate_args from megatron.training.checkpointing import save_checkpoint from megatron.training.utils import print_rank_0, unwrap_model from model_provider import model_provider @@ -149,7 +148,7 @@ def get_params(model): if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_prune_args, args_defaults={ "tokenizer_type": "HuggingFaceTokenizer", @@ -157,13 +156,14 @@ def get_params(model): "no_load_optim": True, }, ) + initialize_megatron() args = get_args() check_arguments(args) tokenizer = get_hf_tokenizer() model = get_model( - functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False ) unwrapped_model = unwrap_model(model)[0] diff --git a/examples/post_training/modelopt/quantize.py b/examples/post_training/modelopt/quantize.py index dc4947038e5..f9935d48a5d 100644 --- a/examples/post_training/modelopt/quantize.py +++ b/examples/post_training/modelopt/quantize.py @@ -1,8 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
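# [Editor's note: explanatory aside for this document, not part of the patch.] The change that
# recurs across these modelopt example scripts (convert_model.py, export.py, generate.py,
# mmlu.py, prune.py, validate.py, and quantize.py below) is the same two-step initialization:
#     initialize_megatron(extra_args_provider=..., args_defaults={...})
# is split into
#     parse_and_validate_args(extra_args_provider=..., args_defaults={...})
#     initialize_megatron()
# so arguments are parsed and validated before Megatron is initialized; examples/mimo/train.py
# additionally turns the parsed args into the config container passed to pretrain() via
# pretrain_cfg_container_from_args(args).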
-"""Sample Generate GPT.""" +"""Script for quantizing a HuggingFace or Megatron-LM checkpoint using ModelOpt.""" -import copy import functools import inspect import json @@ -19,6 +18,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) import modelopt.torch.quantization as mtq +from modelopt.recipe import ModelOptPTQRecipe, load_recipe from modelopt.torch.export import import_mcore_gpt_from_hf from modelopt.torch.utils.dataset_utils import get_dataset_dataloader @@ -35,17 +35,16 @@ mtq_luts = None warnings.warn("luts is not installed. LUTs quantization configs will not be available.") +from utils import get_hf_tokenizer + from megatron.core.utils import get_batch_on_this_cp_rank from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint from megatron.post_training.generate import simple_generate -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder -from megatron.post_training.utils import ( - print_distributed_quant_summary, - report_current_memory_info, -) +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder +from megatron.post_training.utils import print_distributed_quant_summary, report_current_memory_info from megatron.training import get_args, get_model, initialize_megatron -from utils import get_hf_tokenizer +from megatron.training.arguments import parse_and_validate_args from megatron.training.checkpointing import save_checkpoint from megatron.training.utils import print_rank_0, unwrap_model from model_provider import model_provider @@ -129,21 +128,21 @@ def add_text_generate_ptq_args(parser): ) group.add_argument("--weight-only", action="store_true", help="Disable input quantization.") group.add_argument( - "--force-all-expert-routing", - action="store_true", - help="Forcing all experts to be routed during the calibration.", - ) - group.add_argument( - "--num-first-layers-to-skip-quant", - type=int, + "--recipe", + type=str, default=None, - help="Number of first layers to skip quantization.", + help=( + "PTQ recipe YAML file or name without suffix (e.g. " + "'general/ptq/nvfp4_default-fp8_kv', " + "'models/Nemotron-3-Super-120B-A12B/super-nvfp4'). " + "When set, --export-quant-cfg / --export-kv-cache-quant are ignored; " + "the recipe is authoritative for quant_cfg, algorithm, and KV cache config." + ), ) group.add_argument( - "--num-last-layers-to-skip-quant", - type=int, - default=None, - help="Number of last layers to skip quantization.", + "--sync-expert-weight-amax", + action="store_true", + help="Synchronize expert weight amax across experts.", ) add_modelopt_args(parser) return parser @@ -161,91 +160,59 @@ def check_arguments(): args.moe_grouped_gemm = False -def _is_first_layers(name: str, num_layers: int = 1, num_layers_to_disable: int = 1) -> bool: - if "layers." not in name: - return False - try: - layer_idx = int(name.split("layers.")[-1].split(".")[0]) - except ValueError: - return False - return layer_idx < num_layers_to_disable - - -def _is_last_layers(name: str, num_layers: int = 1, num_layers_to_disable: int = 1) -> bool: - if "layers." 
not in name: - return False - try: - layer_idx = int(name.split("layers.")[-1].split(".")[0]) - except ValueError: - return False - return layer_idx >= num_layers - num_layers_to_disable - - -def get_first_layers_disabled_config(config, num_layers: int = 1, num_layers_to_disable: int = 1): - """Get a config for `mtq.quantize` with first & last `num_layers_to_disable` layers disabled. - - The layers to disable are the first & last `num_layers_to_disable` layers. - """ - config = copy.deepcopy(config) - quant_cfg = config.get("quant_cfg", {}) - quant_cfg.update( - { - functools.partial( - _is_first_layers, num_layers=num_layers, num_layers_to_disable=num_layers_to_disable - ): {"enable": False} - } - ) - config["quant_cfg"] = quant_cfg - return config - - -def get_last_layers_disabled_config(config, num_layers: int = 1, num_layers_to_disable: int = 1): - """Get a config for `mtq.quantize` with last `num_layers_to_disable` layers disabled. - - The layers to disable are the last `num_layers_to_disable` layers. - """ - config = copy.deepcopy(config) - quant_cfg = config.get("quant_cfg", {}) - quant_cfg.update( - { - functools.partial( - _is_last_layers, num_layers=num_layers, num_layers_to_disable=num_layers_to_disable - ): {"enable": False} - } - ) - config["quant_cfg"] = quant_cfg - return config - - def get_modelopt_torch_quantization_config(): """Return a quantization config.""" args = get_args() + + if args.recipe is not None: + # YAML recipe is authoritative: skip predefined-config customizations and KV + # cache override; the recipe encodes quant_cfg + algorithm + KV cache directly. + print_rank_0(f"Use recipe {args.recipe} for quantization") + recipe = load_recipe(args.recipe) + if not isinstance(recipe, ModelOptPTQRecipe): + raise TypeError( + f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" + ) + if args.export_kv_cache_quant != "none": + print_rank_0( + f"Ignoring --export-kv-cache-quant={args.export_kv_cache_quant} since you passed in a YAML recipe." 
+ ) + return recipe.quantize.model_dump() + if args.export_quant_cfg not in QUANT_CFG_CHOICES: raise ValueError(f"Unsupported quantization config {args.export_quant_cfg}.") mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg] - fp8_config = {"enable": True, "num_bits": (4, 3), "axis": None} + if isinstance(mtq_config["quant_cfg"], dict): + # Normalize old dict format to new list format + mtq_config["quant_cfg"] = mtq.normalize_quant_cfg_list(mtq_config["quant_cfg"]) + + fp8_config = {"enable": True, "cfg": {"num_bits": (4, 3), "axis": None}} fp4_config = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, "enable": True, + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, } if args.export_quant_cfg == "FP8_DEFAULT_CFG": # Enable Medusa heads and kv-cache quantization - mtq_config["quant_cfg"]["*medusa_heads**"] = fp8_config + mtq_config["quant_cfg"].append({"quantizer_name": "*medusa_heads**", **fp8_config}) if "FP4" in args.export_quant_cfg: # Enable Medusa heads and kv-cache quantization - mtq_config["quant_cfg"]["*medusa_heads**"] = fp4_config + mtq_config["quant_cfg"].append({"quantizer_name": "*medusa_heads**", **fp4_config}) if "AWQ" in args.export_quant_cfg: - weight_quantizer = mtq_config["quant_cfg"]["*weight_quantizer"] # type: ignore - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = 128 - + try: + weight_quantizer = mtq.find_quant_cfg_entry_by_path( + mtq_config["quant_cfg"], "*weight_quantizer" + ) + weight_quantizer["block_sizes"][-1] = 128 + except KeyError: + weight_quantizer = None # Customization if args.disable_qkv_quant: - mtq_config["quant_cfg"]["*self_attention*"] = {"enable": False} + mtq_config["quant_cfg"].append({"quantizer_name": "*self_attention*", "enable": False}) # KV Cache Quantization enable_quant_kv_cache = args.export_kv_cache_quant != "none" @@ -257,20 +224,7 @@ def get_modelopt_torch_quantization_config(): # Weight Only Quantization if args.weight_only: - mtq_config["quant_cfg"]["*input_quantizer"] = {"enable": False} - if args.num_first_layers_to_skip_quant is not None: - mtq_config = get_first_layers_disabled_config( - mtq_config, - num_layers=args.num_layers, - num_layers_to_disable=args.num_first_layers_to_skip_quant, - ) - if args.num_last_layers_to_skip_quant is not None: - mtq_config = get_last_layers_disabled_config( - mtq_config, - num_layers=args.num_layers, - num_layers_to_disable=args.num_last_layers_to_skip_quant, - ) - + mtq_config["quant_cfg"].append({"quantizer_name": "*input_quantizer", "enable": False}) return mtq_config @@ -294,6 +248,8 @@ def get_calib_dataloader( for i, line in enumerate(f): if len(all_texts) == calib_size: break + if not line.strip(): + continue sample = json.loads(line) # Extract text field from various possible keys @@ -305,7 +261,9 @@ def get_calib_dataloader( elif isinstance(sample, dict) and "messages" in sample: conversations = sample["messages"] assert "role" in conversations[0] and "content" in conversations[0] - full_text = "".join([f"{msg['role']}: {msg['content']}" for msg in conversations]) + full_text = "".join( + [f"{msg['role']}: {msg['content']}" for msg in conversations] + ) elif isinstance(sample, list) and isinstance(sample[0], dict): assert "role" in sample[0] and "content" in sample[0] full_text = "".join([f"{msg['role']}: {msg['content']}" for msg in sample]) @@ -313,23 +271,36 @@ def 
get_calib_dataloader( raise ValueError(f"Sample {i} has unexpected format") # Slice text - max_text_length = int(max_sequence_length / 0.75) # tokenized text is roughtly ~75% length of original + max_text_length = int( + max_sequence_length / 0.75 + ) # tokenized text is roughtly ~75% length of original start_idx = 0 if use_random_offset and len(full_text) > max_text_length: start_idx = random.randint(0, len(full_text) - max_text_length) text = full_text[start_idx : start_idx + max_text_length] all_texts.append(text) - print_rank_0(f"Loaded calibration dataset ({dataset_path_or_name}) with {len(all_texts)} samples") + print_rank_0( + f"Loaded calibration dataset ({dataset_path_or_name}) with {len(all_texts)} samples" + ) print_rank_0(f"Actual num samples: {len(all_texts)}, max seq length: {max_sequence_length}") - print_rank_0(f"Sampling Strategy: {'Random Index' if use_random_offset else 'From Beginning'}") + print_rank_0( + f"Sampling Strategy: {'Random Index' if use_random_offset else 'From Beginning'}" + ) # Tokenize all texts at once and move to device tokens = tokenizer( - all_texts, return_tensors="pt", padding="max_length", max_length=max_sequence_length, truncation=True + all_texts, + return_tensors="pt", + padding="max_length", + max_length=max_sequence_length, + truncation=True, ) all_input_ids = tokens.input_ids.cuda() - return [{"input_ids": all_input_ids[i:i+batch_size]} for i in range(0, len(all_input_ids), batch_size)] + return [ + {"input_ids": all_input_ids[i : i + batch_size]} + for i in range(0, len(all_input_ids), batch_size) + ] else: # HuggingFace dataset if use_random_offset: @@ -346,7 +317,7 @@ def get_calib_dataloader( if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_text_generate_ptq_args, args_defaults={ "tokenizer_type": "HuggingFaceTokenizer", @@ -354,6 +325,7 @@ def get_calib_dataloader( "no_load_optim": True, }, ) + initialize_megatron() check_arguments() @@ -362,7 +334,7 @@ def get_calib_dataloader( tokenizer = get_hf_tokenizer() model = get_model( - functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False ) report_current_memory_info() @@ -415,12 +387,7 @@ def _dataset_forward_loop_func(model): unwrapped_model = unwrap_model(model)[0] - if args.force_all_expert_routing: - warnings.warn( - "--force-all-expert-routing will be deprecated in the next release and is no longer needed." - ) - - if args.export_quant_cfg is not None: + if args.export_quant_cfg is not None or args.recipe is not None: print_rank_0("Quantizing the model...") mtq_config = get_modelopt_torch_quantization_config() @@ -448,6 +415,7 @@ def _dataset_forward_loop_func(model): # Free calibration/quantization memory before generate import gc + gc.collect() torch.cuda.empty_cache() diff --git a/examples/post_training/modelopt/quantize.sh b/examples/post_training/modelopt/quantize.sh index 9119ff4ae76..e96b224f3c1 100755 --- a/examples/post_training/modelopt/quantize.sh +++ b/examples/post_training/modelopt/quantize.sh @@ -20,6 +20,18 @@ if [ -z ${QUANT_CFG} ]; then printf "${MLM_WARNING} Variable ${PURPLE}QUANT_CFG${WHITE} is not set (default: ${QUANT_CFG})!\n" fi +# If the 2nd positional arg looks like a recipe path (contains '/' or ends in +# '.yaml'/'.yml') pass it via --recipe; otherwise treat it as a built-in +# config name and pass it via --export-quant-cfg. 
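# [Editor's note: illustrative examples for this document, not part of the patch.] With the
# dispatch below, for instance:
#   QUANT_CFG=general/ptq/nvfp4_default-fp8_kv  -> forwarded as --recipe (contains '/')
#   QUANT_CFG=my_recipe.yaml                    -> forwarded as --recipe (YAML suffix; hypothetical file name)
#   QUANT_CFG=FP8_DEFAULT_CFG                   -> forwarded as --export-quant-cfg (built-in config name)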
+case "${QUANT_CFG}" in + */*|*.yaml|*.yml) + QUANT_CFG_ARGS=(--recipe "${QUANT_CFG}") + ;; + *) + QUANT_CFG_ARGS=(--export-quant-cfg "${QUANT_CFG}") + ;; +esac + if [ -z ${MLM_MODEL_SAVE} ]; then MLM_MODEL_SAVE=${MLM_WORK_DIR}/${MLM_MODEL_CFG}_quant printf "${MLM_WARNING} Variable ${PURPLE}MLM_MODEL_SAVE${WHITE} is not set (default: ${MLM_MODEL_SAVE})!\n" @@ -41,7 +53,7 @@ if [ -z ${MLM_MODEL_CKPT} ]; then --tokenizer-model ${TOKENIZER_MODEL} \ --pretrained-model-path ${HF_MODEL_CKPT} \ --save ${MLM_MODEL_SAVE} \ - --export-quant-cfg ${QUANT_CFG} \ + "${QUANT_CFG_ARGS[@]}" \ --references "${MLM_REF_LABEL}" \ "${EXTRA_ARGS[@]}" else @@ -55,7 +67,7 @@ else --tokenizer-model ${TOKENIZER_MODEL} \ --load ${MLM_MODEL_CKPT} \ --save ${MLM_MODEL_SAVE} \ - --export-quant-cfg ${QUANT_CFG} \ + "${QUANT_CFG_ARGS[@]}" \ --references "${MLM_REF_LABEL}" \ "${EXTRA_ARGS[@]}" fi diff --git a/examples/post_training/modelopt/train.sh b/examples/post_training/modelopt/train.sh index 1ebb8bf3d76..3afcd4f5be7 100755 --- a/examples/post_training/modelopt/train.sh +++ b/examples/post_training/modelopt/train.sh @@ -69,8 +69,8 @@ fi export HF_TOKEN=${HF_TOKEN} -if [[ ${MODEL_ARGS} == *"MambaModel"* ]]; then - PRETRAIN_EXE=${SCRIPT_DIR}/../../../pretrain_mamba.py +if [[ ${MODEL_ARGS} == *"HybridModel"* ]] || [[ ${MODEL_ARGS} == *"MambaModel"* ]]; then + PRETRAIN_EXE=${SCRIPT_DIR}/../../../pretrain_hybrid.py else PRETRAIN_EXE=${SCRIPT_DIR}/../../../pretrain_gpt.py fi diff --git a/examples/post_training/modelopt/validate.py b/examples/post_training/modelopt/validate.py index 8b8f1ffc9dd..3b3f855c393 100644 --- a/examples/post_training/modelopt/validate.py +++ b/examples/post_training/modelopt/validate.py @@ -11,26 +11,24 @@ import torch from modelopt.torch.speculative.plugins.megatron_eagle import MegatronARValidation +from utils import get_hf_tokenizer from megatron.post_training.arguments import add_modelopt_args from megatron.post_training.checkpointing import load_modelopt_checkpoint -from megatron.post_training.model_builder import modelopt_gpt_mamba_builder +from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder from megatron.post_training.utils import get_mtbench_chat_data from megatron.training import get_args, get_model, initialize_megatron -from utils import get_hf_tokenizer +from megatron.training.arguments import parse_and_validate_args from megatron.training.utils import print_rank_0, unwrap_model from model_provider import model_provider warnings.filterwarnings('ignore') - def add_ar_validation_args(parser): """Add additional arguments for ModelOpt acceptance rate validation.""" group = parser.add_argument_group(title='ModelOpt ar validation') - group.add_argument( - "--osl", type=int, default=64, help="Output sequence length." - ) + group.add_argument("--osl", type=int, default=64, help="Output sequence length.") parser.add_argument( "--prompts-path", type=str, @@ -38,14 +36,9 @@ def add_ar_validation_args(parser): help="Path to the prompts json file. If not provided, MTBench will be used.", ) parser.add_argument( - "--ground-truth-path", - type=str, - default=None, - help="Path to the ground truth pt file.", - ) - parser.add_argument( - "--steps", type=int, default=1, help="Only used in EAGLE." + "--ground-truth-path", type=str, default=None, help="Path to the ground truth pt file." 
) + parser.add_argument("--steps", type=int, default=1, help="Only used in EAGLE.") parser.add_argument( "--save-ground-truth-path", type=str, @@ -86,10 +79,8 @@ def report_current_memory_info(): torch.distributed.barrier() - - if __name__ == "__main__": - initialize_megatron( + parse_and_validate_args( extra_args_provider=add_ar_validation_args, args_defaults={ 'tokenizer_type': 'HuggingFaceTokenizer', @@ -97,6 +88,7 @@ def report_current_memory_info(): 'no_load_optim': True, }, ) + initialize_megatron() check_arguments() @@ -116,7 +108,9 @@ def report_current_memory_info(): ground_truth = [None for _ in range(len(prompts))] tokenizer = get_hf_tokenizer() - model = get_model(functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False) + model = get_model( + functools.partial(model_provider, modelopt_gpt_hybrid_builder), wrap_with_ddp=False + ) report_current_memory_info() @@ -124,7 +118,6 @@ def report_current_memory_info(): load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights) print_rank_0("Done loading checkpoint") - unwrapped_model = unwrap_model(model)[0] unwrapped_model.eval() @@ -136,7 +129,7 @@ def report_current_memory_info(): gt.append(output[0]) ar.append(output[1]) print_rank_0("Acceptance Rate: " + str(ar)) - print_rank_0("Average: " + str(sum(ar)/len(ar))) + print_rank_0("Average: " + str(sum(ar) / len(ar))) if args.save_ground_truth_path is not None: torch.save(gt, args.save_ground_truth_path) diff --git a/examples/rl/model_configs/llama3p1_8b_instruct.sh b/examples/rl/model_configs/llama3p1_8b_instruct.sh index ff3b5327710..325c1d80617 100644 --- a/examples/rl/model_configs/llama3p1_8b_instruct.sh +++ b/examples/rl/model_configs/llama3p1_8b_instruct.sh @@ -101,8 +101,6 @@ MODEL_OPTIONS="\ --max-position-embeddings 131072 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model unsloth/Meta-Llama-3.1-8B-Instruct \ - --tokenizer-hf-use-fast \ - --tokenizer-hf-include-special-tokens \ --lr 3e-7 \ --make-vocab-size-divisible-by 128 \ --clip-grad 1.0 \ diff --git a/examples/rl/model_configs/nemotron5_56b.sh b/examples/rl/model_configs/nemotron5_56b.sh index 23b9f99a72a..b4fcee17a8e 100644 --- a/examples/rl/model_configs/nemotron5_56b.sh +++ b/examples/rl/model_configs/nemotron5_56b.sh @@ -69,7 +69,7 @@ MODEL_OPTIONS="\ \ --fp8-recipe tensorwise \ --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec \ --mamba-state-dim 256 \ --per-split-data-args-path ${BLEND_PATH} \ --tiktoken-pattern v2 \ diff --git a/examples/rl/model_configs/nemotron5_8b.sh b/examples/rl/model_configs/nemotron5_8b.sh index c18149f03d6..198efd2a163 100644 --- a/examples/rl/model_configs/nemotron5_8b.sh +++ b/examples/rl/model_configs/nemotron5_8b.sh @@ -61,7 +61,7 @@ MODEL_OPTIONS="\ --inference-max-requests $MAX_INFERENCE_BS \ --pretrained-checkpoint $CHECKPOINT \ --hybrid-layer-pattern M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec \ --tiktoken-pattern v2 \ --distributed-timeout-minutes 60 \ --use-mcore-models \ diff --git a/examples/rl/model_configs/nemotron5p5_12b_H.sh b/examples/rl/model_configs/nemotron5p5_12b_H.sh index 1826d57e913..bfb4c7e4727 100644 --- 
a/examples/rl/model_configs/nemotron5p5_12b_H.sh +++ b/examples/rl/model_configs/nemotron5p5_12b_H.sh @@ -76,7 +76,7 @@ MODEL_OPTIONS="\ --disable-gloo-process-groups \ --mamba-head-dim 80 \ --hybrid-layer-pattern M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec \ --tiktoken-pattern v2 \ --distributed-timeout-minutes 10 \ --use-mcore-models \ diff --git a/examples/rl/model_configs/nemotron6_3b_moe.sh b/examples/rl/model_configs/nemotron6_3b_moe.sh index 85de0c6be0a..a807f270a01 100644 --- a/examples/rl/model_configs/nemotron6_3b_moe.sh +++ b/examples/rl/model_configs/nemotron6_3b_moe.sh @@ -104,7 +104,6 @@ MODEL_OPTIONS="\ --tiktoken-pattern v2 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model ${TOKENIZER_MODEL} \ - --tokenizer-hf-include-special-tokens \ --dist-ckpt-strictness log_unexpected \ --ckpt-format torch_dist \ --ckpt-fully-parallel-save \ diff --git a/examples/rl/model_configs/qwen3_30b_a3b_moe.sh b/examples/rl/model_configs/qwen3_30b_a3b_moe.sh index 637b431280f..eb55ba35cc6 100644 --- a/examples/rl/model_configs/qwen3_30b_a3b_moe.sh +++ b/examples/rl/model_configs/qwen3_30b_a3b_moe.sh @@ -51,7 +51,6 @@ MODEL_OPTIONS=" --te-rng-tracker \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model Qwen/Qwen3-30B-A3B \ ---tokenizer-hf-include-special-tokens \ --untie-embeddings-and-output-weights \ --num-layers 48 \ --hidden-size 2048 \ diff --git a/examples/rl/model_configs/qwen3_32b.sh b/examples/rl/model_configs/qwen3_32b.sh index fcadb0c4021..c06c5f55b53 100644 --- a/examples/rl/model_configs/qwen3_32b.sh +++ b/examples/rl/model_configs/qwen3_32b.sh @@ -64,7 +64,6 @@ MODEL_OPTIONS="\ --attention-softmax-in-fp32 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model Qwen/Qwen3-4B \ - --tokenizer-hf-include-special-tokens \ --vocab-size 151936 \ --make-vocab-size-divisible-by 128 \ --optimizer adam \ diff --git a/examples/rl/model_configs/qwen_2p5_32b.sh b/examples/rl/model_configs/qwen_2p5_32b.sh index 0bfe19ba1bb..2a2a9ae2420 100644 --- a/examples/rl/model_configs/qwen_2p5_32b.sh +++ b/examples/rl/model_configs/qwen_2p5_32b.sh @@ -85,7 +85,6 @@ MODEL_OPTIONS="\ --max-position-embeddings 131072 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model unsloth/Qwen2.5-32B \ - --tokenizer-hf-include-special-tokens \ --lr 1e-6 \ --lr-warmup-samples 0 \ --make-vocab-size-divisible-by 128 \ diff --git a/examples/rl/model_configs/qwen_2p5_3b.sh b/examples/rl/model_configs/qwen_2p5_3b.sh index 4880272d4a6..647023d3050 100644 --- a/examples/rl/model_configs/qwen_2p5_3b.sh +++ b/examples/rl/model_configs/qwen_2p5_3b.sh @@ -87,7 +87,6 @@ MODEL_OPTIONS="\ --max-position-embeddings 32768 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model unsloth/Qwen2.5-3B \ - --tokenizer-hf-include-special-tokens \ --lr 0.000001 \ --lr-warmup-samples 0 \ --make-vocab-size-divisible-by 64 \ diff --git a/examples/rl/model_configs/qwen_2p5_math_7b.sh b/examples/rl/model_configs/qwen_2p5_math_7b.sh index b00077bc07a..b598bb127bd 100644 --- a/examples/rl/model_configs/qwen_2p5_math_7b.sh +++ b/examples/rl/model_configs/qwen_2p5_math_7b.sh @@ -84,7 +84,6 @@ MODEL_OPTIONS="\ --max-position-embeddings 4096 \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model "unsloth/Qwen2.5-Math-7B" \ - --tokenizer-hf-include-special-tokens \ --lr 0.000001 \ --lr-warmup-samples 0 \ --make-vocab-size-divisible-by 128 \ diff --git 
a/gpt_builders.py b/gpt_builders.py index 4f3f983bc5c..fca96eb7d06 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -1,17 +1,17 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, + get_transformer_layer_with_experimental_attention_variant_spec, +) from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_decoder_block_spec, + get_gpt_decoder_layer_specs, get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_inference_spec, + get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec, - get_gpt_decoder_layer_specs, -) -from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( - get_transformer_block_with_experimental_attention_variant_spec, - get_transformer_layer_with_experimental_attention_variant_spec, ) from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( get_gpt_heterogeneous_layer_spec, @@ -92,10 +92,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ mtp_transformer_layer_spec = decoder_layer_specs[-1] # Use spec of the last layer in decoder block as spec of the transformer layer in MTP mtp_block_spec = get_gpt_mtp_block_spec( - config, - mtp_transformer_layer_spec, - use_transformer_engine=use_te, - vp_stage=vp_stage, + config, mtp_transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage ) model = GPTModel( @@ -149,9 +146,7 @@ def _get_transformer_layer_spec(use_te, config): ) elif config.transformer_impl == "inference_optimized": return get_gpt_layer_with_inference_spec( - config.qk_layernorm, - config.multi_latent_attention, - qk_l2_norm=config.qk_l2_norm, + config.qk_layernorm, config.multi_latent_attention, qk_l2_norm=config.qk_l2_norm ) else: return get_gpt_layer_local_spec( diff --git a/hybrid_builders.py b/hybrid_builders.py new file mode 100644 index 00000000000..05b219277ef --- /dev/null +++ b/hybrid_builders.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. 
+ +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_inference_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.spec_utils import import_module +from megatron.training import print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args +from model_provider import count_parameters_in_layer + + +def hybrid_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): + print_rank_0('building Hybrid model ...') + if config is None: + config = core_transformer_config_from_args(args, TransformerConfig) + + if config.transformer_impl == "inference_optimized": + hybrid_stack_spec = hybrid_inference_stack_spec + assert ( + not config.inference_fuse_tp_communication + ), "inference_fuse_tp_communication is not supported for HybridModel" + elif args.spec is not None: + hybrid_stack_spec = import_module(args.spec) + else: + raise ValueError("You must provide a valid hybrid layer spec via --spec") + + model = HybridModel( + config=config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + hybrid_layer_pattern=args.hybrid_layer_pattern, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + for l in range(model.decoder.num_layers_per_pipeline_rank): + layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') + print_rank_0(f" == params layer {l}: {layer_params}") + + return model + + +# Backward-compatible alias +mamba_builder = hybrid_builder diff --git a/mamba_builders.py b/mamba_builders.py index 650ea4a719f..f824fce9be3 100644 --- a/mamba_builders.py +++ b/mamba_builders.py @@ -1,50 +1,15 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. +"""Backward-compatible re-export of hybrid_builders. -from model_provider import count_parameters_in_layer -from megatron.core.models.mamba import MambaModel -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.spec_utils import import_module -from megatron.training import print_rank_0 -from megatron.training.arguments import core_transformer_config_from_args -from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec +Deprecated. Use hybrid_builders instead. +""" +import warnings +warnings.warn( + "mamba_builders has been deprecated. Use hybrid_builders instead.", + DeprecationWarning, + stacklevel=2, +) -def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): - print_rank_0('building MAMBA model ...') - if config is None: - config = core_transformer_config_from_args(args, TransformerConfig) - assert args.use_legacy_models is False, "Mamba only supported in Mcore!" 
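# --- Illustrative sketch (not part of this patch) ----------------------------
# How the new hybrid_builder resolves its layer spec, and how existing callers
# keep working through the deprecated mamba_builders alias. Assumes the
# Megatron-LM repo root is on PYTHONPATH; the spec below mirrors the two-token
# `--spec <module> <symbol>` form used by the shell configs in this patch and is
# only an example.
from megatron.core.transformer.spec_utils import import_module
from hybrid_builders import hybrid_builder
from mamba_builders import mamba_builder  # warns: "mamba_builders has been deprecated"

# `--spec megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec`
# arrives as a two-element list and is resolved to the spec object:
hybrid_stack_spec = import_module(
    ["megatron.core.models.hybrid.hybrid_layer_specs", "hybrid_stack_spec"]
)

# The old entry point is now just an alias for the new builder.
assert mamba_builder is hybrid_builder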
- - if config.transformer_impl == "inference_optimized": - mamba_stack_spec = mamba_inference_stack_spec - assert ( - not config.inference_fuse_tp_communication - ), "inference_fuse_tp_communication is not supported for Mamba" - elif args.spec is not None: - mamba_stack_spec = import_module(args.spec) - else: - raise ValueError("You must provide a valid Mamba layer spec via --spec") - - model = MambaModel( - config=config, - mamba_stack_spec=mamba_stack_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - hybrid_layer_pattern=args.hybrid_layer_pattern, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - rotary_base=args.rotary_base, - pg_collection=pg_collection, - vp_stage=vp_stage, - ) - - for l in range(model.decoder.num_layers_per_pipeline_rank): - layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') - print_rank_0(f" == params layer {l}: {layer_params}") - - return model +from hybrid_builders import * # noqa: F401,F403 +from hybrid_builders import hybrid_builder as mamba_builder # noqa: F401 diff --git a/megatron/core/__init__.py b/megatron/core/__init__.py index b9668b2ce66..4c6dbbad5ea 100644 --- a/megatron/core/__init__.py +++ b/megatron/core/__init__.py @@ -1,5 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch + import megatron.core.tensor_parallel import megatron.core.utils from megatron.core import parallel_state @@ -46,7 +48,11 @@ "__version__", ] -from .safe_globals import register_safe_globals +from .safe_globals import register_safe_globals, safe_load_from_bytes if is_torch_min_version("2.6a0"): register_safe_globals() + +# Avoid direct usage of unsafe `torch.storage._load_from_bytes` (weights_only=False) +# Use safe implementation with weights_only=True +torch.storage._load_from_bytes = safe_load_from_bytes diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py index 802a9770506..9b642ee1ff3 100644 --- a/megatron/core/datasets/blended_dataset.py +++ b/megatron/core/datasets/blended_dataset.py @@ -150,7 +150,10 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: else: cache_hit = False - if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): + if not path_to_cache or ( + not cache_hit + and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0) + ): log_single_rank( logger, logging.INFO, f"Build and save the {type(self).__name__} indices" ) diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index a61c623d960..afbed9046a3 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -254,12 +254,25 @@ Utility functions consumed by the schedulers above: | `broadcast_tensor()` | Broadcast a single tensor within a process group. | | `create_data_iterator()` | Wrap packed sample lists into a data iterator; handles VPP stage splitting. | + +## Offline cache preparation + +For GPT-style training, the dataset caches described above can be prepared ahead of time with `tools/prepare_cache.py` instead of waiting for rank 0 to build them during training startup. 
+ +The script reuses the normal dataset construction path used by `pretrain_gpt.py` and `pretrain_mamba.py`, including `GPTDataset`, `BlendedDataset`, and `BlendedMegatronDatasetBuilder`. It accepts the usual dataset arguments, supports blends and per-split dataset definitions, and requires `--data-cache-path` so the generated cache can later be reused by training. + +This is especially useful for large blends or many file prefixes, where building the document, sample, and shuffle indices can take several minutes and leave all GPUs idle while rank 0 performs CPU-only work. + +If the later training job does not specify `--global-batch-size` (which is needed to determine the dataset size and splits), you should specify `--prepare-cache-world-size` to explicitly set the world size used during cache preparation. + +`tools/prepare_cache.py` does not support `--mock-data`, `--sft`, `--fim-data`, or `--step-batch-size-schedule`. + ## Fast DataLoader initialization -Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags": +Especially for large-scale runs, DataLoader initialization can take several minutes, since it involves opening and memory-mapping multiple files and can significantly stress the filesystem. To speed up this process, we have developed the following three optimizations, controlled by configuration flags: - `--dataloader-fast-cache-load`: This option assumes that the dataset cache already exists in the specified `--data-cache-path`. When enabled, it speeds up the creation process by removing synchronization points and file check assertions. - `--dataloader-defer-npy-index-mmap`: This option also assumes that the dataset cache already exists in the specified `--data-cache-path`. When enabled, it defers the memory mapping of the dataset indexes (.npy files) until their first access. We recommend using this configuration together with `--num-workers` > 0 so that the DataLoader prefetches the next batches of data, thereby hiding the cost of index memory mapping. - - `--per-dataset-sequences-path`: With this configuration, we specify the JSON file generated by the `tools/build_sequences_per_dataset.py` script. This script generates a single file containing the required metadata from all the specified file prefixes. This configuration is especially useful when dealing with hundreds to thousands of file prefixes, since it requires only a single `open` operation instead of one per file prefix. \ No newline at end of file + - `--per-dataset-sequences-path`: With this configuration, we specify the JSON file generated by the `tools/build_sequences_per_dataset.py` script. This script generates a single file containing the required metadata from all the specified file prefixes. This configuration is especially useful when dealing with hundreds to thousands of file prefixes, since it requires only a single `open` operation instead of one per file prefix. 
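To make the reuse concrete, the sketch below shows roughly the construction path `tools/prepare_cache.py` goes through. It is a hedged illustration rather than the script itself: the blend prefix, sample counts, and some `GPTDatasetConfig` field names are assumptions and may differ between Megatron-LM versions.

```python
# Hypothetical sketch: building GPT datasets so that the document/sample/shuffle
# indices are written under path_to_cache (--data-cache-path) for later reuse.
from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig

config = GPTDatasetConfig(
    random_seed=1234,
    sequence_length=4096,
    blend=(["/data/my_corpus_text_document"], None),  # hypothetical prefix, no weights
    split="99,1,0",
    path_to_cache="/data/index_cache",  # must match --data-cache-path at training time
    tokenizer=tokenizer,  # assumed: a MegatronTokenizer instance built elsewhere
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=False,
)

# Requested (train, valid, test) sample counts; in training these follow from
# --train-samples (or --train-iters x --global-batch-size) and the split string.
sizes = [100_000, 1_000, 0]
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
    GPTDataset, sizes, lambda: True, config
).build()
```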
diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py index 2ee7970f143..1d42a03c0c5 100644 --- a/megatron/core/dist_checkpointing/serialization.py +++ b/megatron/core/dist_checkpointing/serialization.py @@ -36,8 +36,10 @@ StrictHandling, determine_global_metadata, parse_strict_flag, + save_integrity_manifest, validate_integrity_and_strict_load, verify_checkpoint, + verify_integrity_manifest, ) logger = logging.getLogger(__name__) @@ -56,6 +58,7 @@ def load( common_strategy: None = None, validate_access_integrity: bool = True, strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED, + verify_integrity: bool = False, ) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]: """Loading entrypoint. @@ -89,6 +92,10 @@ def load( incur any performance overhead. Other recommended values are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys or `StrictHandling.RETURN_ALL` which returns all mismatch keys. + verify_integrity (bool, optional): if True, re-hashes every checkpoint file + and compares against the SHA-256 manifest. Raises `CheckpointingException` on any + mismatch. Requires that the checkpoint was previously saved with + `verify_integrity=True`. Returns: StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only @@ -97,6 +104,8 @@ def load( assert common_strategy is None verify_checkpoint(checkpoint_dir) + if verify_integrity: + verify_integrity_manifest(checkpoint_dir) if sharded_strategy is None: sharded_strategy = TorchDistLoadShardedStrategy() @@ -139,7 +148,12 @@ def load( ckpt_sharded_metadata, ) - async_strategy = getattr(common_state_dict.get("args"), "async_strategy", "nvrx") + ckpt_args = common_state_dict.get("args") + async_strategy = ( + getattr(ckpt_args, "async_strategy", "mcore") + if getattr(ckpt_args, "async_save", False) + else "mcore" + ) loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir, async_strategy) merge(common_state_dict, loaded_state_dict) @@ -295,6 +309,7 @@ def save( ] = None, content_metadata: Optional[dict] = None, async_strategy: Optional[str] = "nvrx", + verify_integrity: bool = False, ) -> Optional[AsyncRequest]: """Saving entrypoint. @@ -340,6 +355,11 @@ def save( modify the original state dict content_metadata (dict, optional): metadata to identify the checkpoint content. Useful for framework specific versioning. + verify_integrity (bool, optional): if True, compute SHA-256 hashes for every + file in the checkpoint directory after all data has been written. This manifest can + later be verified on load with `load(..., verify_integrity=True)`. + Adds I/O overhead proportional to the total checkpoint size (one extra + read pass over all files on rank 0). 
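# --- Illustrative usage (not part of this patch) ------------------------------
# Hedged sketch of the new verify_integrity flag on the save/load entrypoints.
# Assumes an already-initialized torch.distributed job and a Megatron module
# `model` exposing sharded_state_dict(); the checkpoint path is a placeholder.
from megatron.core import dist_checkpointing

ckpt_dir = "/checkpoints/iter_0001000"  # hypothetical directory

# After all shards are flushed, rank 0 hashes every file and writes integrity.json.
dist_checkpointing.save(model.sharded_state_dict(), ckpt_dir, verify_integrity=True)

# On load, every file is re-hashed against the manifest before tensors are read;
# any mismatch raises CheckpointingException.
state_dict = dist_checkpointing.load(
    model.sharded_state_dict(), ckpt_dir, verify_integrity=True
)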
Returns: AsyncRequest (optional): if `async_sharded_save` is True, returns @@ -386,13 +406,22 @@ def metadata_finalize_fn(): ) torch.distributed.barrier() + def integrity_finalize_fn(): + if torch.distributed.get_rank() == 0: + save_integrity_manifest(checkpoint_dir) + torch.distributed.barrier() + if not async_sharded_save: sharded_strategy.save(sharded_state_dict, checkpoint_dir) metadata_finalize_fn() + if verify_integrity: + integrity_finalize_fn() return None async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir, async_strategy) async_request.finalize_fns.append(metadata_finalize_fn) + if verify_integrity: + async_request.finalize_fns.append(integrity_finalize_fn) return async_request diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index 0ae800e46f8..3fdab41b4b0 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -42,9 +42,9 @@ def load_common(checkpoint_dir: str): try: if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu', weights_only=False) + return msc.torch.load(load_path, map_location='cpu') else: - return torch.load(load_path, map_location='cpu', weights_only=False) + return torch.load(load_path, map_location='cpu') except FileNotFoundError as e: err_msg = f'Common file {load_path} does not exist' if MultiStorageClientFeature.is_enabled(): diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py index 6638f215cd4..db3c8ee6cae 100644 --- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -189,7 +189,7 @@ def load( self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path, - async_strategy: str = "nvrx", + async_strategy: str = "mcore", ) -> StateDict: """Distributes the load and calls underlying strategy only for parts of the state dict. diff --git a/megatron/core/dist_checkpointing/strategies/nvrx.py b/megatron/core/dist_checkpointing/strategies/nvrx.py new file mode 100644 index 00000000000..1df26f00ff4 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/nvrx.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +"""Helpers for interacting with the experimental nvidia-resiliency-ext API.""" + +from importlib import import_module +from typing import Any, Callable, Dict + + +def has_nvrx_async_support() -> bool: + """Checks whether the NVRx async checkpointing symbols Megatron uses are importable.""" + try: + core = import_module("nvidia_resiliency_ext.checkpointing.async_ckpt.core") + cached_metadata_reader = import_module( + "nvidia_resiliency_ext.checkpointing.async_ckpt.cached_metadata_filesystem_reader" + ) + filesystem_async = import_module( + "nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async" + ) + state_dict_saver = import_module( + "nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver" + ) + except (ImportError, ModuleNotFoundError): + return False + + required_symbols = ( + getattr(core, "AsyncCallsQueue", None), + getattr(core, "AsyncRequest", None), + getattr(cached_metadata_reader, "CachedMetadataFileSystemReader", None), + getattr(filesystem_async, "FileSystemWriterAsync", None), + getattr(filesystem_async, "get_write_results_queue", None), + getattr(state_dict_saver, "CheckpointMetadataCache", None), + getattr(state_dict_saver, "save_state_dict_async_finalize", None), + getattr(state_dict_saver, "save_state_dict_async_plan", None), + ) + return all(symbol is not None for symbol in required_symbols) and hasattr( + filesystem_async, "_results_queue" + ) + + +def make_nvrx_async_request( + async_request_cls: type, + async_fn: Callable[..., Any], + async_fn_args: Any, + finalize_fns: list[Callable[..., Any]], + async_fn_kwargs: Dict[str, Any] | None = None, + preload_fn: Callable[..., Any] | None = None, +): + """Builds an AsyncRequest using the expected NVRx API.""" + return async_request_cls( + async_fn, + async_fn_args, + finalize_fns, + async_fn_kwargs=async_fn_kwargs or {}, + preload_fn=preload_fn, + ) diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index 3e467c1f9dd..7943561700f 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -6,13 +6,12 @@ import os import pickle import warnings -from abc import ABC from collections import defaultdict from contextlib import contextmanager from itertools import product from logging import getLogger from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast import torch from packaging.version import Version as PkgVersion @@ -51,19 +50,18 @@ ) from .async_utils import AsyncRequest from .checkpointable import CheckpointableShardedTensor, LocalShardsContainer +from .nvrx import has_nvrx_async_support, make_nvrx_async_request -try: +if TYPE_CHECKING: from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncRequest as NVRxAsyncRequest from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import ( CheckpointMetadataCache, ) +else: + CheckpointMetadataCache = Any + NVRxAsyncRequest = Any - HAVE_NVRX = True -except (ImportError, ModuleNotFoundError): - CheckpointMetadataCache = ABC - NVRxAsyncRequest = ABC - - HAVE_NVRX = False +HAVE_NVRX = has_nvrx_async_support() try: if not torch.cuda.is_available(): @@ -103,6 +101,7 @@ class MCoreSavePlan: logger = getLogger(__name__) +_logged_mcore_async_deprecation = False def flatten_state_dict( @@ -651,9 +650,8 @@ def __init__( self.validated_loaded_metadata_reuse = False def 
save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - """Each async strategy can be trivially used as a sync strategy.""" - strategy = "nvrx" if HAVE_NVRX else "mcore" - async_request = self.async_save(sharded_state_dict, checkpoint_dir, async_strategy=strategy) + """Sync save always uses the built-in implementation.""" + async_request = self.async_save(sharded_state_dict, checkpoint_dir, async_strategy="mcore") async_request.execute_sync() del async_request @@ -671,11 +669,14 @@ def async_save( Returns: None """ + global _logged_mcore_async_deprecation if async_strategy == "mcore": - logger.warning( - "MCore's async save is deprecated and will be removed in the future releases. " - "Please, use NVRx async solution by setting `async_strategy` to `nvrx`." - ) + if not _logged_mcore_async_deprecation: + logger.warning( + "MCore's async save is deprecated and will be removed in the future releases. " + "Please, use NVRx async solution by setting `async_strategy` to `nvrx`." + ) + _logged_mcore_async_deprecation = True # Translate the state dict (sharded_state_dict, flat_mapping, rename_mapping) = ( @@ -701,7 +702,9 @@ def async_save( if async_strategy == "nvrx": if self._metadata_cache is None: self._metadata_cache = checkpointable_metadata_cache() - if self.cached_global_metadata is not None: + if self.cached_global_metadata is not None and hasattr( + self._metadata_cache, "set_cached_global_metadata" + ): self._metadata_cache.set_cached_global_metadata(self.cached_global_metadata) # Define additional arguments async_writer_kwargs["use_cached_data_structure"] = self.use_cached_ckpt_structure @@ -818,11 +821,13 @@ def _get_save_and_finalize_callbacks( def finalize_fn(): save_state_dict_async_finalize(*save_state_dict_ret) - return async_request(save_fn, save_args, [finalize_fn], preload_fn=preload_fn) + return make_nvrx_async_request( + async_request, save_fn, save_args, [finalize_fn], preload_fn=preload_fn + ) def _get_filesystem_reader( - checkpoint_dir: Union[str, Path], cache_metadata: bool = False, async_strategy: str = "nvrx" + checkpoint_dir: Union[str, Path], cache_metadata: bool = False, async_strategy: str = "mcore" ) -> FileSystemReader: if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() @@ -846,7 +851,7 @@ def load( self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path, - async_strategy: str = "nvrx", + async_strategy: str = "mcore", ) -> StateDict: """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. @@ -1062,9 +1067,8 @@ def get_async_strategy(async_strategy: str = "nvrx", module: str = None) -> tupl async_strategy = "nvrx" except (ImportError, ModuleNotFoundError): raise ModuleNotFoundError( - "nvidia-resiliency-ext package is not installed. " - "Please, install nvidia-resiliency-ext package or set `async_strategy` to `mcore` " - "to enable async save strategy." + "A compatible `nvidia-resiliency-ext` installation is required for " + '`async_strategy="nvrx"`. Please install it or set `async_strategy` to `mcore`.' ) elif async_strategy == "mcore": # do mcore async imports diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index 89ecba1a968..5c9e6ba3caa 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -1,10 +1,13 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
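# --- Illustrative usage (not part of this patch) ------------------------------
# Minimal single-process demo of the manifest helpers added below in this file.
# Paths and file contents are placeholders; during training the manifest is
# produced by dist_checkpointing.save(..., verify_integrity=True) instead.
import os
import tempfile

from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.validation import (
    save_integrity_manifest,
    verify_integrity_manifest,
)

ckpt_dir = tempfile.mkdtemp()
with open(os.path.join(ckpt_dir, "shard_0.distcp"), "wb") as f:
    f.write(b"stand-in checkpoint bytes")

save_integrity_manifest(ckpt_dir)    # writes integrity.json with SHA-256 digests
verify_integrity_manifest(ckpt_dir)  # passes: nothing has changed

with open(os.path.join(ckpt_dir, "shard_0.distcp"), "ab") as f:
    f.write(b"\x00")                 # simulate on-disk corruption
try:
    verify_integrity_manifest(ckpt_dir)
except CheckpointingException as exc:
    print(exc)                       # names the file whose hash no longer matches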
+import hashlib +import json import logging +import os from collections import Counter, defaultdict from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -22,6 +25,7 @@ ShardedStateDict, is_main_replica, ) +from megatron.core.msc_utils import MultiStorageClientFeature if TYPE_CHECKING: from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata @@ -34,6 +38,10 @@ # list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank) _GlobalMetadata = List[_LocalMetadata] +INTEGRITY_FNAME = 'integrity.json' +_HASH_ALGORITHM = 'sha256' +_READ_CHUNK_SIZE = 1 << 20 # 1 MiB + class StrictHandling(Enum): """Determines handling of load mismatch (non-empty "unexpected" or "missing" keys). @@ -483,3 +491,148 @@ def determine_global_metadata( global_metadata = [None] * torch.distributed.get_world_size() torch.distributed.all_gather_object(global_metadata, local_metadata) return local_metadata, global_metadata # type: ignore[return-value] + + +def _compute_file_hash(file_path: str) -> str: + """Return the SHA-256 hex digest of `file_path`, read in streaming chunks. + Args: + file_path: absolute path to the file to hash. + Returns: + Lowercase hex-encoded SHA-256 digest string. + """ + h = hashlib.sha256() + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + with msc.open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(_READ_CHUNK_SIZE), b''): + h.update(chunk) + else: + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(_READ_CHUNK_SIZE), b''): + h.update(chunk) + return h.hexdigest() + + +def save_integrity_manifest(checkpoint_dir: str) -> None: + """Hash every file in `heckpoint_dir` and write an integrity manifest. + The manifest lists each filename (relative to `checkpoint_dir`) + together with its SHA-256 digest. The manifest file itself is excluded + from the listing. + Args: + checkpoint_dir: directory that contains the checkpoint files. + """ + manifest: Dict[str, str] = {} + + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + ckpt_path = msc.Path(checkpoint_dir) + for entry in sorted(ckpt_path.iterdir(), key=lambda p: str(p)): + if entry.name != INTEGRITY_FNAME: + manifest[entry.name] = _compute_file_hash(str(entry)) + else: + ckpt_path = Path(checkpoint_dir) + for entry in sorted(ckpt_path.iterdir()): + if entry.is_file() and entry.name != INTEGRITY_FNAME: + manifest[entry.name] = _compute_file_hash(str(entry)) + + integrity_path = os.path.join(checkpoint_dir, INTEGRITY_FNAME) + payload = {'algorithm': _HASH_ALGORITHM, 'files': manifest} + + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + with msc.open(integrity_path, 'w') as f: + json.dump(payload, f, indent=2) + else: + with open(integrity_path, 'w') as f: + json.dump(payload, f, indent=2) + + logger.info("Saved integrity manifest with %d file(s) to %s", len(manifest), integrity_path) + + +def _verify_integrity_manifest_impl(checkpoint_dir: str) -> None: + """Single-process implementation of integrity verification. + Reads ``integrity.json``, recomputes each file's hash, and raises + `megatron.core.dist_checkpointing.core.CheckpointingException` + on any mismatch or missing file. + Args: + checkpoint_dir: checkpoint directory to verify. 
+ Raises: + CheckpointingException: if the manifest is absent, uses an unsupported + algorithm, or any file's hash does not match. + """ + integrity_path = os.path.join(checkpoint_dir, INTEGRITY_FNAME) + + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + if not msc.os.path.exists(integrity_path): + raise CheckpointingException( + f'Integrity manifest not found at {integrity_path}. ' + 'The checkpoint must be saved with integrity verification enabled ' + '(save_integrity=True) before it can be verified on load.' + ) + with msc.open(integrity_path) as f: + manifest_data = json.load(f) + else: + if not os.path.exists(integrity_path): + raise CheckpointingException( + f'Integrity manifest not found at {integrity_path}. ' + 'The checkpoint must be saved with integrity verification enabled ' + '(save_integrity=True) before it can be verified on load.' + ) + with open(integrity_path) as f: + manifest_data = json.load(f) + + algorithm = manifest_data.get('algorithm', _HASH_ALGORITHM) + if algorithm != _HASH_ALGORITHM: + raise CheckpointingException( + f'Unsupported hash algorithm in integrity manifest: {algorithm!r}. ' + f'Expected: {_HASH_ALGORITHM!r}.' + ) + + manifest: Dict[str, str] = manifest_data['files'] + mismatches = [] + + for filename, expected_hash in manifest.items(): + full_path = os.path.join(checkpoint_dir, filename) + try: + actual_hash = _compute_file_hash(full_path) + except (FileNotFoundError, OSError) as exc: + mismatches.append(f' {filename}: file missing or unreadable ({exc})') + continue + if actual_hash != expected_hash: + mismatches.append( + f' {filename}: hash mismatch ' + f'(expected {expected_hash[:16]}..., got {actual_hash[:16]}...)' + ) + + if mismatches: + raise CheckpointingException( + f'Checkpoint integrity verification failed for {len(mismatches)} ' + f'file(s) in {checkpoint_dir}:\n' + '\n'.join(mismatches) + ) + + logger.info("Checkpoint integrity verified: %d file(s) OK in %s", len(manifest), checkpoint_dir) + + +def verify_integrity_manifest(checkpoint_dir: str) -> None: + """Verify checkpoint files against their recorded SHA-256 hashes. + Args: + checkpoint_dir: checkpoint directory to verify. + Raises: + CheckpointingException: if ``integrity.json`` is absent or any file's + hash no longer matches the stored value. + """ + import torch + + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + error_payload = [None] + if torch.distributed.get_rank() == 0: + try: + _verify_integrity_manifest_impl(checkpoint_dir) + except CheckpointingException as exc: + error_payload = [str(exc)] + torch.distributed.broadcast_object_list(error_payload, src=0) + if error_payload[0] is not None: + raise CheckpointingException(error_payload[0]) + else: + _verify_integrity_manifest_impl(checkpoint_dir) diff --git a/megatron/core/distributed/README.md b/megatron/core/distributed/README.md index c4a75284414..489e381f9e0 100644 --- a/megatron/core/distributed/README.md +++ b/megatron/core/distributed/README.md @@ -1,11 +1,27 @@ -## How to use pytorch FSDP2? +# Distributed Data Parallelism -Add these flag to enable Torch FSDP2. +This module contains algorithms, data structures, and utilities used for different types of distributed data parallelism, such as DDP and FSDP. + +## Distributed Data Parallelism + +This is the default data parallelism used with all parallelism topologies in Megatron-LM. 
+ +## Megatron-FSDP + +To use Megatron-FSDP in Megatron-LM, enable the following arguments: + +``` +--use-megatron-fsdp +--ckpt-format fsdp_dtensor +--init-model-with-meta-device +``` + +## FSDP2 + +To use FSDP2 in Megatron-LM, enable the following arguments: ``` --use-torch-fsdp2 --no-gradient-accumulation-fusion --ckpt-format torch_dist ``` - -It is worth noting that CUDA_MAX_CONNECTIONS=1 should not be enabled to ensure that the communication of FSDP and the computation on the primary stream can be fully parallelized. diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 35325d70ce9..a711f1405d1 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -7,15 +7,15 @@ import torch from ..config_logger import has_config_logger_enabled, log_config_to_disk -from ..fp4_utils import is_nvfp4tensor from ..fp8_utils import is_float8tensor, post_all_gather_processing +from ..optimizer.param_layout import FullParamLayout from ..process_groups_config import ProcessGroupCollection from ..transformer.cuda_graphs import is_graph_capturing from ..transformer.transformer_config import TransformerConfig from ..utils import log_single_rank from .data_parallel_base import _BaseDataParallel from .distributed_data_parallel_config import DistributedDataParallelConfig -from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets +from .param_and_grad_buffer import _ParamAndGradBuffer, group_params_for_buffers, partition_buckets logger = logging.getLogger(__name__) @@ -36,6 +36,9 @@ class DistributedDataParallel(_BaseDataParallel): use standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. pg_collection: Optional unified process group for distributed training. + full_param_layout: Optional FullParamLayout providing pre-computed layouts for all + dtype groups. When provided, each buffer uses the corresponding PerBufferParamLayout + instead of computing a default one. """ @@ -46,6 +49,7 @@ def __init__( module: torch.nn.Module, disable_bucketing: bool = False, pg_collection: Optional[ProcessGroupCollection] = None, + full_param_layout: Optional[FullParamLayout] = None, ): super().__init__(config=config, module=module) if has_config_logger_enabled(config): @@ -104,11 +108,10 @@ def __init__( self.param_to_bucket_group = {} - # Group parameters by their gradient type. + # Collect all trainable parameters. param_to_name = {} - dense_params = [] - expert_parallel_params = [] self.params_with_grad = [] + all_params = [] for name, param in self.module.named_parameters(): if not param.requires_grad: continue @@ -119,142 +122,50 @@ def __init__( param.grad_added_to_main_grad = False param_to_name[param] = name + all_params.append(param) + + # Group parameters by (param_dtype, grad_dtype, is_expert_parallel). + buffer_groups = group_params_for_buffers(all_params, self.ddp_config.grad_reduce_in_fp32) + + # Auto-compute layouts when using distributed optimizer but no layout was provided. + # This maintains backward compatibility for callers that create DDP directly + # without pre-computing layouts (e.g., tests, external code). + if full_param_layout is None and self.ddp_config.use_distributed_optimizer: + log_single_rank( + logger, + logging.WARNING, + "DistributedDataParallel: full_param_layout not provided with " + "use_distributed_optimizer=True. 
Auto-computing layout inside DDP. " + "Callers should pre-compute layouts via " + "DistributedOptimizer.compute_full_param_layout() and pass them in.", + ) + from ..optimizer.distrib_optimizer import DistributedOptimizer + + full_param_layout = DistributedOptimizer.compute_full_param_layout( + all_params, + self.bucket_size, + self.intra_dp_cp_group.size(), + self.ddp_config, + expert_data_parallel_world_size=self.intra_expt_dp_group.size(), + ) - if getattr(param, 'allreduce', True): - dense_params.append((param, name)) - else: - expert_parallel_params.append((param, name)) - - def _allocate_buffers_for_parameters( - input_params, data_parallel_group, gradient_scaling_factor - ): - param_and_grad_dtype_to_params = {} - param_and_grad_dtype_to_offsets = {} - param_and_grad_dtype_to_indices = {} - - # Group parameters by their gradient type. - for param, param_name in input_params: - assert param.requires_grad - - param_dtype = param.dtype - if is_float8tensor(param) or is_nvfp4tensor(param): - # Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake" - # dtype (usually a higher precision dtype such as bfloat16), but its actual - # data is stored in the form of a torch uint8 tensor within the Float8Tensor's - # ".data" attribute. Therefore, when creating the param buffer for fp8/fp4 - # params,it is necessary to use torch.uint8, not the "fake" dtype got from - # "param.dtype". - param_dtype = torch.uint8 - grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype - - params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), []) - params.append((param, param_name)) - param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params - - # Get the index of each param among the params with same dtype, if a param is fp8, - # use its "fake" high precision dtype to find which params have same dtype with it. - # For example: - # Case 1: - # params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)] - # param_and_grad_dtype_to_indices = { - # (torch.bfloat16, torch.float32): [0, 1, 2, 3], - # } - # Case 2: - # params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)] - # param_and_grad_dtype_to_indices = { - # (torch.bfloat16, torch.float32): [0, 3], - # (torch.uint8, torch.float32): [1, 2], - # } - # We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode. - offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0) - param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1 - indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), []) - indices.append(offset) - param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices - - if not config.calculate_per_token_loss: - target_gradient_scaling_factor = 1.0 / self.dp_cp_group.size() - if self.ddp_config.average_in_collective: - if self.ddp_config.num_distributed_optimizer_instances == 1: - # Collective is averaging gradients in collective with data_parallel_group. - assert ( - gradient_scaling_factor / data_parallel_group.size() - == target_gradient_scaling_factor - ) - else: - # For non-expert parameters, gradient_scaling_factor is 1. - # For expert parameters, gradient_scaling_factor is edp_size/dp_size. - assert (gradient_scaling_factor == 1) or ( - gradient_scaling_factor - == (self.expt_dp_group.size() / self.dp_cp_group.size()) - ) - else: - assert gradient_scaling_factor == target_gradient_scaling_factor - - # Allocate the grad buffers and map the grads. 
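# --- Illustrative sketch (not part of this patch) -----------------------------
# The per-dtype bookkeeping deleted here is replaced by group_params_for_buffers,
# which groups trainable params by (param_dtype, grad_dtype, is_expert_parallel).
# A toy, hedged inspection; treating params tagged with .allreduce = False as
# expert-parallel is an assumption carried over from the previous grouping logic.
import torch
from megatron.core.distributed.param_and_grad_buffer import group_params_for_buffers

dense = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
expert = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
expert.allreduce = False  # marks the param as expert-parallel in Megatron

groups = group_params_for_buffers([dense, expert], True)  # grad_reduce_in_fp32=True
for key, (params, param_indices) in groups.items():
    print(key.param_dtype, key.grad_dtype, key.is_expert_parallel, param_indices)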
- buffers = [] - pg_collection = ProcessGroupCollection() - pg_collection.tp = self.tp_group - pg_collection.dp_cp = self.dp_cp_group - for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): - buffers.append( - _ParamAndGradBuffer( - self.ddp_config, - param_dtype, - grad_dtype, - params, - data_parallel_group, - self.bucket_size, - param_to_name, - gradient_scaling_factor, - param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)], - self.ddp_config.nccl_ub, - pg_collection, - ) - ) - - # In some scenarios, we want to put buckets from different buffers into a group so that - # their communication can be aggregated. For example, when there are both fp8 buffers - # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8 - # bucket and a bf16 bucket, which doubles the number of communication kernels, and - # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back - # communications will prevent the overlap of the communication kernels with computation - # kernels. - # If bucketing is explicitly disabled, then put all buckets in a buffer into a single - # bucket group. - bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing) - - if self.ddp_config.num_distributed_optimizer_instances > 1: + # When a full_param_layout is provided, verify that the grouping is consistent + # with the layout (same buffer keys, same params per key, same param_indices). + if full_param_layout is not None: + assert set(buffer_groups.keys()) == set(full_param_layout.layouts.keys()), ( + f"Buffer keys from param grouping {set(buffer_groups.keys())} do not match " + f"full_param_layout keys {set(full_param_layout.layouts.keys())}" + ) + for buffer_key, (params, param_indices) in buffer_groups.items(): + layout = full_param_layout.layouts[buffer_key] + assert set(params) == set( + layout.param_index_map.keys() + ), f"Params for {buffer_key} do not match between grouping and layout" assert ( - self.ddp_config.use_distributed_optimizer - ), 'Partial DistOpt cannot be used without DistOpt' - communication_stream = torch.cuda.Stream(device=torch.cuda.current_device()) - for bucket_group in bucket_groups: - bucket_group.inter_distributed_optimizer_instance_group = ( - self.inter_dist_opt_group - ) - bucket_group.communication_stream = communication_stream - - # Set `next_param_gather_bucket_group` for different bucket groups by iterating through - # buckets in reverse order (since all-gathers happen in reverse order of buckets). - # Note: overlap_param_gather covers both the distributed optimizer and the - # layer-wise optimizer cases; the latter sets overlap_param_gather=True - # without use_distributed_optimizer. - if self.ddp_config.overlap_param_gather: - num_bucket_groups = len(bucket_groups) - for i in range(1, num_bucket_groups): - bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( - bucket_groups[num_bucket_groups - i - 1] - ) - - # Create map from param to bucket group, used in pre_hook. - for bucket_group in bucket_groups: - for bucket in bucket_group.buckets: - for param in bucket.params_list: - self.param_to_bucket_group[param] = bucket_group - - return buffers, bucket_groups + param_indices == layout.param_indices + ), f"param_indices for {buffer_key} do not match between grouping and layout" + # Compute gradient scaling factors. 
if config.calculate_per_token_loss: assert ( not self.ddp_config.average_in_collective @@ -291,20 +202,115 @@ def _allocate_buffers_for_parameters( gradient_scaling_factor = 1.0 / data_parallel_world_size expert_gradient_scaling_factor = 1.0 / data_parallel_world_size - # Allocate the param+grad buffers for dense params' grads. - self.buffers, self.bucket_groups = _allocate_buffers_for_parameters( - dense_params, self.intra_dp_cp_group, gradient_scaling_factor=gradient_scaling_factor - ) + # Allocate buffers for each group. + self.buffers = [] + self.expert_parallel_buffers = [] + pg_collection = ProcessGroupCollection(tp=self.tp_group, dp_cp=self.dp_cp_group) + for buffer_key, (params, param_indices) in buffer_groups.items(): + if buffer_key.is_expert_parallel: + data_parallel_group = self.intra_expt_dp_group + scaling_factor = expert_gradient_scaling_factor + else: + data_parallel_group = self.intra_dp_cp_group + scaling_factor = gradient_scaling_factor - # Allocate separate param+grad buffers for expert parallel params' grads. - self.expert_parallel_buffers, self.expert_parallel_bucket_groups = ( - _allocate_buffers_for_parameters( - expert_parallel_params, - self.intra_expt_dp_group, - gradient_scaling_factor=expert_gradient_scaling_factor, + if not config.calculate_per_token_loss: + target_gradient_scaling_factor = 1.0 / self.dp_cp_group.size() + if self.ddp_config.average_in_collective: + if self.ddp_config.num_distributed_optimizer_instances == 1: + # Collective is averaging gradients in collective with data_parallel_group. + assert ( + scaling_factor / data_parallel_group.size() + == target_gradient_scaling_factor + ) + else: + # For non-expert parameters, gradient_scaling_factor is 1. + # For expert parameters, gradient_scaling_factor is edp_size/dp_size. + assert (scaling_factor == 1) or ( + scaling_factor == (self.expt_dp_group.size() / self.dp_cp_group.size()) + ) + else: + assert scaling_factor == target_gradient_scaling_factor + + param_layout = ( + full_param_layout.layouts.get(buffer_key) if full_param_layout is not None else None ) + params_with_names = [(p, param_to_name[p]) for p in params] + buffer = _ParamAndGradBuffer( + self.ddp_config, + buffer_key.param_dtype, + buffer_key.grad_dtype, + params_with_names, + data_parallel_group, + self.bucket_size, + param_to_name, + scaling_factor, + param_indices, + self.ddp_config.nccl_ub, + pg_collection, + param_layout=param_layout, + ) + if buffer_key.is_expert_parallel: + self.expert_parallel_buffers.append(buffer) + else: + self.buffers.append(buffer) + + # In some scenarios, we want to put buckets from different buffers into a group so that + # their communication can be aggregated. For example, when there are both fp8 buffers + # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8 + # bucket and a bf16 bucket, which doubles the number of communication kernels, and + # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back + # communications will prevent the overlap of the communication kernels with computation + # kernels. + # If bucketing is explicitly disabled, then put all buckets in a buffer into a single + # bucket group. 
+ self.bucket_groups = partition_buckets( + self.buffers, + force_single_bucket_group=disable_bucketing, + reduce_scatter_with_fp32_accumulation=( + self.ddp_config.reduce_scatter_with_fp32_accumulation + ), + ) + self.expert_parallel_bucket_groups = partition_buckets( + self.expert_parallel_buffers, + force_single_bucket_group=disable_bucketing, + reduce_scatter_with_fp32_accumulation=( + self.ddp_config.reduce_scatter_with_fp32_accumulation + ), ) + if self.ddp_config.num_distributed_optimizer_instances > 1: + assert ( + self.ddp_config.use_distributed_optimizer + ), 'Partial DistOpt cannot be used without DistOpt' + for bucket_groups in [self.bucket_groups, self.expert_parallel_bucket_groups]: + communication_stream = torch.cuda.Stream(device=torch.cuda.current_device()) + for bucket_group in bucket_groups: + bucket_group.inter_distributed_optimizer_instance_group = ( + self.inter_dist_opt_group + ) + bucket_group.communication_stream = communication_stream + + # Set `next_param_gather_bucket_group` for different bucket groups by iterating through + # buckets in reverse order (since all-gathers happen in reverse order of buckets). + # Note: overlap_param_gather covers both the distributed optimizer and the + # layer-wise optimizer cases; the latter sets overlap_param_gather=True + # without use_distributed_optimizer. + if self.ddp_config.overlap_param_gather: + for bucket_groups in [self.bucket_groups, self.expert_parallel_bucket_groups]: + num_bucket_groups = len(bucket_groups) + for i in range(1, num_bucket_groups): + bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( + bucket_groups[num_bucket_groups - i - 1] + ) + + # Create map from param to bucket group, used in pre_hook. + for bucket_groups in [self.bucket_groups, self.expert_parallel_bucket_groups]: + for bucket_group in bucket_groups: + for bucket in bucket_group.buckets: + for param in bucket.params_list: + self.param_to_bucket_group[param] = bucket_group + # Delete references to weight_tensor if they exist since we don't want two parameter copies # if we re-mapped parameters (which happens when we use the distributed optimizer). # This is a temporary workaround around a TE bug that is fixed with diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index 540dbbd51c5..cc1d0d80e7b 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. from functools import partial -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -275,6 +275,44 @@ def _allreduce_position_embedding_grads( ) +def _allreduce_router_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce router grads. + + Reduce grads across all the pp stages to ensure that parameters of the router stay in sync. + """ + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + grads_dict: Dict[str, List[torch.Tensor]] = {} + for model_chunk in model: + for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): + if param.requires_grad and getattr(param, 'flextron_router_pp_sync', False): + grad = param.main_grad + if name in grads_dict: + # Add all the virtual PP rank's gradients to + # the first local virtual PP rank. 
+ grads_dict[name][0].add_(grad) + # Append to the end for later update after cross-rank reduce. + grads_dict[name].append(grad) + else: + grads_dict[name] = [grad] + + if grads_dict: + # All-reduce the gradient on the first VPP rank. + grads = [param_grad[0] for _, param_grad in grads_dict.items()] + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_pipeline_model_parallel_group() + ) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + # Update the gradients on other VPP ranks. + for grads in grads_dict.values(): + for grad in grads[1:]: + grad.copy_(grads[0]) + + def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.nn.Module]): """ Reset the temporary tensors of the model. @@ -465,6 +503,9 @@ def finalize_model_grads( if config.timers is not None: config.timers('conditional-embedder-grads-all-reduce').stop() + if getattr(config, 'flextron', False): + _allreduce_router_grads(model, config) + # All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules. if config.timers is not None: config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start( diff --git a/megatron/core/distributed/fsdp/src/README.md b/megatron/core/distributed/fsdp/src/README.md index 98c01a759eb..d3422d03abb 100644 --- a/megatron/core/distributed/fsdp/src/README.md +++ b/megatron/core/distributed/fsdp/src/README.md @@ -1,6 +1,6 @@
-# 🚀 Megatron-FSDP +# Megatron-FSDP
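
The `_allreduce_router_grads` helper added to `finalize_model_grads.py` above keeps pipeline-parallel router copies in sync by packing the gradients into one flat buffer, issuing a single all-reduce over the pipeline-parallel group, and scattering the reduced values back. A minimal, standalone sketch of that coalesced-reduction pattern (editorial illustration only, not part of the patch; assumes `torch.distributed` is already initialized and every rank passes its tensors in the same order):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def coalesced_all_reduce(grads: list, group=None) -> None:
    """Reduce a list of gradient tensors with a single collective call."""
    # Pack all gradients into one contiguous buffer.
    coalesced = _flatten_dense_tensors(grads)
    # One all-reduce for the whole buffer instead of one per tensor.
    dist.all_reduce(coalesced, group=group)
    # Unpack the reduced values back into the original gradient tensors.
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```
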
@@ -12,38 +12,16 @@ ## ✨ What is Megatron-FSDP? -**Megatron-FSDP** is an NVIDIA-developed PyTorch extension that provides a high-performance implementation of Fully Sharded Data Parallelism (FSDP). It offers seamless cross-compatibility with major deep learning frameworks and parallelism libraries, making it easy to scale your PyTorch models across multiple GPUs and nodes. +**Megatron-FSDP** is an NVIDIA-developed distributed parallelism library written in native PyTorch that provides a high-performance implementation of **Fully Sharded Data Parallelism (FSDP)**. It offers seamless cross-compatibility with various deep learning frameworks and parallelism libraries such as Megatron-Core, and is performance-optimized to support training and inference of extremely large PyTorch models at data-center scale on NVIDIA GPUs. -Megatron-FSDP can provide up to 25% speed up and 23% memory savings compared to FSDP2. +For comprehensive information about Megatron-FSDP, refer to: [Megatron-FSDP | Megatron-Core Developer Guide](https://docs.nvidia.com/megatron-core/developer-guide/latest/) -### Compatibility +### 🧩 Compatibility -- **[PyTorch DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html)** +- PyTorch **[DeviceMesh](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh)**, **[DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html)**, and **[Distributed Checkpoint (DCP)](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html)** - **[Megatron Core](https://github.com/NVIDIA/Megatron-LM)** - **[TransformerEngine](https://github.com/NVIDIA/TransformerEngine)** - -## ✨ Features - -- **Easy Integration**: Simple `fully_shard` function for quick model parallelization -- **High Performance**: Optimized for NVIDIA GPUs with efficient memory management -- **Cross-Framework**: Works seamlessly with PyTorch, Huggingface Transformers, Megatron-LM, Megatron Bridge and TransformerEngine -- **Scalable**: Supports both single-node multi-GPU and multi-node distributed training -- **Flexible Configuration**: Configurable sharding strategies and process groups - -## ⚡ Optimizations - -- **Advanced Bucketing**: Data-type aware bucketing system to minimize the overhead of collective operations -- **Buffer Management**: Zero copy communication is achieved by reorganizing the storage of parameters and main grad with `ParamAndGradBuffer` class -- **Communication Overlapping**: Improved communication overlap of paramter all-gather and gradient reduce-scatter -- **FP8 Mixed Precision with Transformer Engine**: Compatibility with Transformer Engine enables efficient FP8 mixed precision training -- **Gradient accumulate fusion support with Transformer Engine**: Remove the explicit gradient copy to the communication buffer in backwards pass - -### Advanced Collective Communication -- **SM Usage Reduction with SHARP**: FSDP's `All-Gather` (AG) and `Reduce-Scatter` (RS) collectives are designed to overlap with compute kernels. However, standard NCCL communication kernels can consume a significant number of GPU SMs (e.g., 16-32 SMs), "stealing" resources from compute (GEMM) kernels and reducing overall TFLOPS. -- **In-Switch Processing**: We leverage **SHARP** (Scalable Hierarchical Aggregation and Reduction Protocol) to offload these collective operations. SHARP performs aggregation and reduction computations directly on the network switches (InfiniBand or NVLink Switch) instead of on the GPU SMs. 
This dramatically reduces the SM consumption for communication to **1-6 SM** freeing up GPU resources for compute. It also provides lower communication latency, especially in large, scaled-out workloads. -- **Symmetric Optimizations for MNNVL**: We support **symmetric-based optimizations**, introduced in NCCL v2.27, which enable switch offloading for **Multi-Node NVLink (MNNVL)** systems such as GB200/GB300. This allows the same SM-saving benefits over the high-bandwidth NVLink fabric itself. -- **Hierarchical Collectives**: When an FSDP sharding domain spans both NVLink and InfiniBand, the library utilizes **hierarchical SHARP collectives** (e.g., NVL-SHARP + IB-SHARP) to optimize the communication path across the entire system topology. - +- **[NVIDIA NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)** ## 📦 Installation @@ -56,226 +34,57 @@ pip install megatron-fsdp ## 🚀 Quick Start -### Basic Usage - -Transform your PyTorch model to use Fully Sharded Data Parallelism with just a few lines: - -```python -import torch -from megatron_fsdp import ( - fully_shard_model, - fully_shard_optimizer, -) - -""" -Enable FSDP with Megatron-FSDP via the `fully_shard_*` API. -""" -# Shard your model. -model = fully_shard_model( - model, - fsdp_unit_modules=[ - YourModelLayerClass, - "import.path.to.model.class.YourModelLayerClass", - ], - ... -) -# Shard your optimizer. -optimizer = fully_shard_optimizer( - torch.optim.Adam(model.parameters(), lr=1e-3) -) - -# Your model is now ready for distributed training! -``` - -### Comparison with FSDP-2 - -`fully_shard` / `fully_shard_model` / `fully_shard_optimizer` are simple entrypoints into `MegatronFSDP`. - -- No need to call `fully_shard` on all the sub-modules, just pass your sub-module classes or import paths to `fully_shard`! -- Seamlessly preserves the identity of your training loop with only a few lines of code and multiple options for initialization: - - `fully_shard_*` is a two-line change when sharding the model and optimizer separately. - - `fully_shard` is a one-line change for previously-initialized models and optimizers. - -Compare this with FSDP2: - -```python -import torch -from torch.distributed.fsdp import fully_shard - -# Your existing model and optimizer. -model = YourModel() -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# Enable FSDP with FSDP2. -for module in model.modules(): - # Sub-Modules to shard. - if isinstance(module, YourModelLayerClass): - fully_shard(module) -fully_shard(model) - -# Your model is now ready for distributed training! -``` - -### `torch.compile` Compatibility - -Megatron-FSDP is compatible with `torch.compile`, but this feature is still experimental and may introduce performance regressions in some workloads. - -## 📖 Megatron-FSDP Comprehensive Walkthrough - -### Import `megatron_fsdp`. - ```python import torch from megatron_fsdp import ( fully_shard_model, fully_shard_optimizer, - MixedPrecisionPolicy, ) -``` - -### Set up a distributed environment using `DeviceMesh`. - -`DeviceMesh` simplifies the construction of complex arrangements of devices -to support various parallelisms. - -```python -from torch.distributed.device_mesh import DeviceMesh - -# Initialize DeviceMesh. 
-device_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - mesh_shape=(dp_outer_size, dp_shard_size, cp_size, tp_size), - mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"), -) -# Only relevant when using HSDP, where we also need the full DP group for data parallelism, -# This sub-mesh can be provided to distributed samplers or dataloaders. -device_mesh[("dp_outer", "dp_shard")]._flatten("dp") -# Only required if using CP. Otherwise, just pass dp_shard to FSDP. -device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp") -# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group. -device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp") -hsdp_group = device_mesh["hsdp"].get_group() - -# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP. -expert_device_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - mesh_shape=(dp_outer_size, expt_dp_shard_size, expt_tp_size), - mesh_dim_names=("dp_outer", "dp_shard_cp", "tp"), -) -expert_device_mesh[("dp_outer", "dp_shard_cp")].flatten("hsdp") -hsdp_expt_group = expert_device_mesh["hsdp"].get_group() -``` - -### Convert models into fully-sharded `MegatronFSDP` models with `fully_shard_model`. - -This wraps the model in a MegatronFSDP class that schedules the sharding -lifecycle of the model parameters and gradients during training and inference. - -```python -model = fully_shard_model( - # PyTorch (Root) Module - model, - # Sharded Modules - fsdp_unit_modules=[...], - # Device Mesh - device_mesh=device_mesh - # Always required for FSDP or HSDP. - dp_shard_dim="dp_shard_cp", - # Set this required argument to use HSDP instead of FSDP. Otherwise, set this to None. - dp_outer_dim="dp_outer", - # Only required for TP-sensitive models (i.e. Megatron-LM / TransformerEngine) - # or when using DTensor-based TP. Otherwise, set this to None. - tp_dim="tp", - # Only required when using HSDP. Otherwise, set this to None. - hybrid_fsdp_group=hsdp_group, - # Only required when using HSDP + EP. Otherwise, set this to None. - hybrid_fsdp_expt_group=hsdp_expt_group, - # Only required for FSDP + EP. Otherwise, set this to None. - expt_device_mesh=expt_device_mesh, - # FSDP Sharding Strategy: no_shard (0) / optim (1) / optim_grads (2) / optim_grads_params (3) - zero_dp_strategy=3, - outer_dp_sharding_strategy=1, - # Initialize the model on devices in shards to avoid OOM. Requires device("meta")-init for model. - init_model_with_meta_device=True, - # Mixed-Precision Policy for controlling compute and communication precision in Megatron-FSDP. - mixed_precision_policy=MixedPrecisionPolicy(), - # Sync parameters and gradients each step. Allows for gradient transformations after backward pass, - # and synchronizes parameters and gradients across HSDP groups, but deactivates compute-communication - # overlap going into the subsequent training step. - sync_model_each_microbatch=True, - # Preprocess state dict for DCP checkpointing. Required for Torch Distributed Checkpoint. - preproc_state_dict_for_dcp_ckpt=True, -) -``` - -The original `torch.nn.Module` can be accessed at `MegatronFSDP.module`. - -### Initialize and fully-shard your optimizer on the `MegatronFSDP` model. -Initialize your optimizer on the Megatron-FSDP model distributed `Parameter`(s). 
-If your optimizer has already been initialized, either use the `fully_shard` -entrypoint, or use `optimizer.add_param_group({"params": model.parameters()})` -after resetting your optimizer state via `optimizer.param_groups.clear()` -and `optimizer.state.clear()`. +# Initialize Torch Distributed. +torch.distributed.init_process_group() +torch.cuda.set_device(torch.distributed.get_rank()) -```python -optimizer = torch.optim.Optimizer(model.parameters()) -``` - -`fully_shard_optimizer` modifies your `optimizer.step()`, `optimizer.zero_grad()`, -and distributed optimizer parameters to punctually trigger scheduled FSDP operations -for Megatron-FSDP. - -```python -fully_shard_optimizer( - # PyTorch Optimizer - optimizer, - # Preprocess state dict for DCP checkpointing. - # Required for Torch Distributed Checkpoint. - preproc_state_dict_for_dcp_ckpt=True, -) -``` - -Extended arguments to `step()` and `zero_grad()` control these FSDP operations: - -```python - optimizer.step( - ..., - # Sync all gradients before the optimizer step. Alternatively enabled using - # `sync_model_each_microbatch=True` in MegatronFSDP. - sync_grad_before_optimizer_step=True, - # After `optimizer.step()`, install optimized weights into MegatronFSDP's buffers. - install_optimized_model_weights=True, - ) - - optimizer.zero_grad( - ..., - # Also zero out MegatronFSDP's gradient accumulation buffers. - zero_grad_buffer=True - ) -``` - -### `MegatronFSDP` Distributed Checkpointing - -Distributed checkpoints can be saved and loaded using Torch DCP. Alternatively, -you can load non-distributed checkpoints before fully-sharding your model with -any existing checkpoint utility compatible with PyTorch Modules. - -```python -# Save model and optimizer state. -torch.distributed.checkpoint.save( - {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, - checkpoint_id=str(CKPT_DIR) +# Fully-shard the model. +model = torch.nn.Transformer() +fsdp_model = fully_shard_model( + module=model, + fsdp_unit_modules=[ + torch.nn.TransformerEncoder, + torch.nn.TransformerDecoder + ] ) -# Load model and optimizer state. -ckpt_state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} -torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=str(CKPT_DIR)) -# `model.load_state_dict(strict=False)` is only necessary to ignore TE FP8 extra state -# that is missing from the DCP checkpoint but present in TEBaseModule. -# Megatron-FSDP does not support TE FP8 extra state checkpointing with DCP. -model.load_state_dict(ckpt_state_dict["model"], strict=False) -optimizer.load_state_dict(ckpt_state_dict["optimizer"]) +# Fully-shard the optimizer. +toy_adam = torch.optim.AdamW(params=fsdp_model.parameters(), lr=0.01) +optimizer = fully_shard_optimizer(optimizer=toy_adam) + +# Forward pass. +inp = torch.randn(1, 512, 512).to("cuda") +tgt = torch.randn(1, 512, 512).to("cuda") +output = fsdp_model(inp, inp) + +# Backward pass. +torch.nn.functional.mse_loss(output, tgt).backward() + +# Optimizer step. +optimizer.step() +optimizer.zero_grad() + +# Checkpoint the model and optimizer. +torch.distributed.checkpoint.save({ + "model": fsdp_model.state_dict(), + "optimizer": optimizer.state_dict(), +}, checkpoint_id="ckpt/") + +# Load the saved checkpoint. 
+ckpt = { + "model": fsdp_model.state_dict(), + "optimizer": optimizer.state_dict(), +} +torch.distributed.checkpoint.load(state_dict=ckpt, checkpoint_id="ckpt/") +fsdp_model.load_state_dict(ckpt["model"], strict=False) +optimizer.load_state_dict(ckpt["optimizer"]) ``` ## ⚙️ `fully_shard` / `MegatronFSDP` API - Advanced Features @@ -305,17 +114,17 @@ Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fin - Defaults to `False`. - Note that the `device` argument which installs your model on a specific device or rank will be deactivated when `init_model_with_meta_device=True`. - `mixed_precision_policy` takes a `megatron_fsdp.MixedPrecisionPolicy` that configures mixed-precision compute and communication for Megatron-FSDP. Configuration options include: - - `main_params_dtype` controls the data-type for parameters used in distributed optimization or quantization. + - `main_params_dtype` controls the data-type for parameters responsible for distributed checkpointing, distributed optimization, and quantization. - Defaults to `torch.float32`. - If set to `None`, the native model compute parameter data-type will be utilized. - - Requires specification (cannot be `None`) when using `FP8` parameters with Megatron-FSDP. + - Requires specification (cannot be `None`) when using quantized parameters with Megatron-FSDP. - `main_grads_dtype` controls the data-type for gradients used in distributed optimization. - - Defaults to `None`, the model native gradient data-type will be utilized. + - Defaults to `None`, in which the model native gradient data-type will be utilized. - While `torch.float32` (or higher) is recommended for accuracy at scale, as `main_grads_dtype` controls the data-type for gradient accumulation, `None` is more flexible and uses pre-determined parameter gradient logic in mixed-precision scenarios, such as `BF16` for `FP8`/`FP4` parameters quantized via TransformerEngine. - - `grad_comm_dtype` controls the data-type for gradient communications (RS / AR) when reducing gradients. Lower precision `grad_comm_dtype` improves (communication) performance, but may increase memory utilization or sacrifice gradient precision in certain cases. - - Defaults to `None`, the `main_grads_dtype` data-type will be utilized, and no additional memory is allocated when `grad_comm_dtype == main_grads_dtype`. - - If using HSDP (either DP-Replicate or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, `optim`, or a `FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient communication buffers (per FSDP group) adds memory overhead of up to 10% or more, and users should consider the performance-memory trade-off when using this feature. - - If using NCCL UBR v2.27+ (`nccl_ub=True`), gradient reduction may be performed in high-precision depending on the network domain (NVLink or IB), and can enable mixed-precision communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support `FP32` reduction even though we have `BF16` input and output communication buffers. Otherwise, gradients will be reduced in `grad_comm_dtype` (and accumulated in `main_grads_dtype`) as usual. + - `grad_comm_dtype` controls the data-type for gradient communications when reducing gradients. Lower precision `grad_comm_dtype` improves (communication) performance, but may increase memory utilization or sacrifice gradient precision in certain cases. + - Defaults to `None`, in which the `main_grads_dtype` data-type will be utilized. 
No additional memory is allocated when `grad_comm_dtype == main_grads_dtype`. + - If using HSDP (either DP-Replicate or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, or `optim`, allocating `dtype`-custom gradient communication buffers may increase per-unit memory overhead, so users should consider the performance-memory trade-off when using this feature. + - If using NCCL user buffer registration `v2.27+`, gradient reduction may be performed in high-precision depending on the network domain (NVLink or IB), and can enable mixed-precision communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support `FP32` reduction even though we have `BF16` input and output communication buffers. Otherwise, gradients will be reduced in `grad_comm_dtype` (and accumulated in `main_grads_dtype`) as usual. - `overlap_grad_reduce` and `overlap_param_gather` will overlap gradient [`reduce-scatter`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) and parameter [`all-gather`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) group communications with backward and forward compute with asynchronous calls and pre-fetching. (In the case of `no_shard`, parameters are not gathered but gradient [`all-reduce`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) is overlapped.) - Both default to `True`. - `sync_model_each_microbatch` will trigger a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients will be all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary within and between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation) when the communications of the previous training iteration can be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`. @@ -362,7 +171,7 @@ Megatron-FSDP natively supports mixed-precision activations and parameter shardi - Within the [`transformer_engine.pytorch.autocast(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.autocast) context, model activations are converted based on the recipe. - Within the [`transformer_engine.pytorch.quantized_model_init(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.quantized_model_init) context, TransformerEngine native modules (e.g. [`transformer_engine.pytorch.TransformerLayer`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.TransformerLayer)) have their parameters converted based on the recipe. - - Requires FP8 model activations, i.e. `transformer_engine.pytorch.autocast`. + - Requires quantized model activations, i.e. `transformer_engine.pytorch.autocast`. 
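
A minimal sketch of the `mixed_precision_policy` options described above, assuming the documented option names map directly to `MixedPrecisionPolicy` keyword arguments and that the process group is already initialized as in the Quick Start (editorial illustration, not part of the README):

```python
import torch
from megatron_fsdp import MixedPrecisionPolicy, fully_shard_model

# Assumed keyword arguments, taken from the documented option names above.
policy = MixedPrecisionPolicy(
    main_params_dtype=torch.float32,  # precision of distributed-optimizer / checkpoint params
    main_grads_dtype=torch.float32,   # precision of gradient accumulation
    grad_comm_dtype=None,             # None: reuse main_grads_dtype for gradient reduction
)

fsdp_model = fully_shard_model(
    module=torch.nn.Transformer(),
    fsdp_unit_modules=[torch.nn.TransformerEncoder, torch.nn.TransformerDecoder],
    mixed_precision_policy=policy,
)
```
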
```python # FP8 Recipe @@ -397,4 +206,4 @@ with transformer_engine.pytorch.autocast(recipe=fp8_recipe): mfsdp_model(x).sum().backward() ``` -ℹ️ `TransformerEngine` kernels have a fair bit of configuration constraints when using FP8-quantized parameters, such as using fused QKV parameters or defining activations and parameters with shapes compatible to FP8 CuBLAS kernels on supported hardware from NVIDIA. To properly initialize `TransformerLayer`, you can refer to the toy model used in our FP8 unit tests: `Megatron-LM/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py::TestMegatronFsdpFullyShard::test_fully_shard_te_quantized`. \ No newline at end of file +ℹ️ `TransformerEngine` kernels have various constraints related to quantized Tensors, such as using fused QKV parameters or defining activations and parameters with shapes compatible to CuBLAS kernels on supported hardware from NVIDIA. To properly initialize `TransformerLayer`, you can refer to the example model used in our unit tests: `Megatron-LM/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py::TestMegatronFsdpFullyShard::test_fully_shard_te_quantized`. \ No newline at end of file diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 06676b07dec..26d5ba81b3d 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -27,6 +27,7 @@ modify_underlying_storage, post_all_gather_processing, ) +from ..optimizer.param_layout import pad_bucket_end, pad_param_start from ..utils import is_torch_min_version, log_on_each_pipeline_stage from .distributed_data_parallel_config import DistributedDataParallelConfig from .reduce_scatter_with_fp32_accumulation import reduce_scatter_with_fp32_accumulation @@ -433,6 +434,11 @@ def start_param_sync(self, force_sync: bool = False): for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) bucket.layerwise_gather_list = None + # Zero out grad_data since it was reused as the all-gather + # receive buffer. Without this, accumulation into main_grad + # (a view into grad_data) would start from the result of the + # latest parameter all-gather instead of zero. + bucket.grad_data.zero_() self.param_gather_handle = None else: @@ -522,6 +528,11 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) bucket.layerwise_gather_list = None + # Zero out grad_data since it was reused as the all-gather + # receive buffer. Without this, accumulation into main_grad + # (a view into grad_data) would start from the result of the + # latest parameter all-gather instead of zero. + bucket.grad_data.zero_() self._post_param_sync() def start_grad_sync(self, force_all_reduce: Optional[bool] = False): @@ -764,6 +775,114 @@ def register_grad_ready( self.start_grad_sync(force_all_reduce=force_all_reduce) +def group_params_for_buffers( + params: List[torch.nn.Parameter], grad_reduce_in_fp32: bool +) -> Dict['BufferKey', Tuple[List[torch.nn.Parameter], List[int]]]: + """Group parameters by buffer identity for buffer allocation. + + Each distinct buffer is identified by a BufferKey with three dimensions: + - param_dtype: storage dtype (torch.uint8 for FP8/NVFP4 parameters, else param.dtype). + - grad_dtype: gradient reduction dtype (torch.float if grad_reduce_in_fp32, else param.dtype). 
+ - is_expert_parallel: whether the parameter is expert-parallel (param.allreduce == False), + which requires a separate buffer with a different data-parallel group. + + The param_indices track each parameter's position among same-dtype params (using + the "fake" high-precision dtype for FP8/NVFP4 params), needed for loading non-native-fp8 + checkpoints in native-fp8 mode. + + Args: + params: List of parameters to group. + grad_reduce_in_fp32: Whether gradients are reduced in FP32. + + Returns: + Dict mapping BufferKey to (params_list, param_indices). + """ + from ..optimizer.param_layout import BufferKey + + key_to_params = {} + dtype_to_offsets = {} + key_to_indices = {} + + for param in params: + assert param.requires_grad + + param_dtype = param.dtype + if is_float8tensor(param) or is_nvfp4tensor(param): + param_dtype = torch.uint8 + grad_dtype = torch.float if grad_reduce_in_fp32 else param.dtype + is_expert_parallel = not getattr(param, 'allreduce', True) + + key = BufferKey(param_dtype, grad_dtype, is_expert_parallel) + param_list = key_to_params.get(key, []) + param_list.append(param) + key_to_params[key] = param_list + + # Use param.dtype (not param_dtype) so FP8/NVFP4 params share offsets with their + # logical high-precision dtype, needed for checkpoint compatibility. + offset_key = BufferKey(param.dtype, grad_dtype, is_expert_parallel) + offset = dtype_to_offsets.get(offset_key, 0) + dtype_to_offsets[offset_key] = offset + 1 + indices = key_to_indices.get(key, []) + indices.append(offset) + key_to_indices[key] = indices + + result = {} + for key, param_list in key_to_params.items(): + result[key] = (param_list, key_to_indices[key]) + return result + + +def _compute_default_per_buffer_param_layout( + params: List[torch.nn.Parameter], bucket_size: Optional[int] +) -> 'PerBufferParamLayout': + """Compute parameter layout for the non-distributed-optimizer case. + + No padding is applied. Parameters are iterated in reverse order (backprop order) + and grouped into buckets of approximately `bucket_size` elements. + + Args: + params: List of parameters to lay out. + bucket_size: Approximate number of elements per bucket, or None for a single bucket. + + Returns: + PerBufferParamLayout with the computed mapping. 
+ """ + from ..optimizer.param_layout import PerBufferParamLayout + + param_index_map = {} + bucket_indices = [] + per_bucket_numel_unpadded = [] + + param_start_index = 0 + bucket_start_index = 0 + bucket_params = set() + bucket_id = 0 + + for param in params[::-1]: + this_numel = param.data.nelement() + param_end_index = param_start_index + this_numel + param_index_map[param] = (param_start_index, param_end_index, bucket_id) + bucket_params.add(param) + + if bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size: + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + bucket_indices.append((bucket_start_index, param_end_index)) + bucket_start_index = param_end_index + bucket_params = set() + bucket_id += 1 + param_start_index = param_end_index + + if len(bucket_params) > 0: + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + bucket_indices.append((bucket_start_index, param_end_index)) + + return PerBufferParamLayout( + param_index_map=param_index_map, + bucket_indices=bucket_indices, + per_bucket_numel_unpadded=per_bucket_numel_unpadded, + ) + + class _ParamAndGradBuffer: """ Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into @@ -799,6 +918,7 @@ def __init__( param_indices: List[int], nccl_ub: bool, pg_collection: Optional[ProcessGroupCollection] = None, + param_layout: Optional['PerBufferParamLayout'] = None, ): if pg_collection is None: @@ -833,121 +953,13 @@ def __init__( # Data structures to store underlying buckets and relevant indexing data. self.buckets = [] self.param_to_bucket = {} # Param -> bucket mapping. - self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer). - - def _pad(number_to_be_padded: int, divisor: int) -> int: - return int(math.ceil(number_to_be_padded / divisor) * divisor) - - def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int: - """ - Pads end index of bucket if using distributed optimizer (to ensure uniform sharding). - """ - if self.ddp_config.use_distributed_optimizer: - # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. - # This also helps cuBLAS pick more efficient algorithms for GEMMs. - # We now ensure that all buckets start at a memory address that is 256-byte - # aligned (128 values since params and grads use >= 16-bit precision). - if self.ddp_config.pad_buckets_for_high_nccl_busbw: - # Make sure the bucket size is divisible by a large power of 2 (2^16) to - # ensure NCCL collectives have high bus bandwidth at large DP counts, - # since NCCL message size (which for ring algorithms is bucket_size / - # dp_size) apparently needs to be divisible by a power of 2 for high busbw. - bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16) - else: - bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128) - return _pad(bucket_end_index, bucket_size_divisor) - return bucket_end_index - - def _pad_start_of_param_if_needed(param_start_index: int) -> int: - """ - Pads start index of param if using distributed optimizer (to ensure "good" alignment). - """ - if self.ddp_config.use_distributed_optimizer: - # Ensure that params start at 128-byte aligned addresses (64 values - # since params are >= 16-bit precision). - return _pad(param_start_index, 64) - return param_start_index - # First, figure out how many elements should be in the underlying buffer storage. 
- # Note that if we need to split the buffer into smaller buckets, each of these - # might need to be padded as well (if using the distributed optimizer). - param_start_index = 0 - bucket_start_index = param_start_index - bucket_params = set() - self.bucket_indices = [] - per_bucket_numel_unpadded = [] - bucket_id = 0 - - def _update_bucket_metadata( - param_end_index: int, - bucket_start_index: int, - bucket_indices: list, - numel_unpadded_list: list, - ) -> int: - """ - Record metadata for a bucket. Returns the bucket's (padded) end_index. - - Args: - param_end_index: End index of the last param in this bucket (unpadded). - bucket_start_index: Start index of this bucket. - bucket_indices: List to append (start, end) bucket boundaries to. - numel_unpadded_list: List to append unpadded bucket numel to. - - Returns: - The bucket's end index, padded if using distributed optimizer. - """ - numel_unpadded_list.append(param_end_index - bucket_start_index) - bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) - bucket_indices.append((bucket_start_index, bucket_end_index)) - return bucket_end_index - - def _finalize_bucket_all_index_spaces( - param_end_index: int, - bucket_start_index: int, - nvfp4_packed_param_end_index: int = None, - nvfp4_packed_bucket_start_index: int = None, - ) -> tuple: - """ - Record metadata for the current bucket across both main and (if applicable) - NVFP4 packed index spaces. Also resets bucket_params and increments bucket_id. - - Args: - param_end_index: End index of the last param in the bucket (full numel). - bucket_start_index: Start index of the bucket (full numel). - nvfp4_packed_param_end_index: End index in packed space (NVFP4 only). - nvfp4_packed_bucket_start_index: Bucket start in packed space (NVFP4 only). - - Returns: - Tuple of (bucket_end_index, nvfp4_packed_bucket_end_index). - """ - nonlocal bucket_params, bucket_id - bucket_end_index = _update_bucket_metadata( - param_end_index, bucket_start_index, self.bucket_indices, per_bucket_numel_unpadded - ) - nvfp4_packed_bucket_end_index = None - if self.has_nvfp4_params: - nvfp4_packed_bucket_end_index = _update_bucket_metadata( - nvfp4_packed_param_end_index, - nvfp4_packed_bucket_start_index, - self.nvfp4_packed_bucket_indices, - nvfp4_packed_per_bucket_numel_unpadded, - ) - bucket_params = set() - bucket_id += 1 - return bucket_end_index, nvfp4_packed_bucket_end_index - - def _does_param_require_new_bucket(param): - """ - Split shared embedding parameters into separate bucket if using distributed - optimizer that makes use of reduce-scatters instead of all-reduces. - This ensures that the first and last pipeline stage partition optimizer state - for the shared embedding parameters the same way across DP replicas, allowing - the DP reduce-scatter to be before the embedding all-reduce. - """ - return ( - getattr(param, "shared_embedding", False) - and self.ddp_config.use_distributed_optimizer - ) + # Use the provided layout if given, otherwise compute the default (no-padding) layout. + if param_layout is None: + param_layout = _compute_default_per_buffer_param_layout(self.params, bucket_size) + self.param_index_map = param_layout.param_index_map + self.bucket_indices = param_layout.bucket_indices + per_bucket_numel_unpadded = param_layout.per_bucket_numel_unpadded # Check if this buffer contains NVFP4 params. # @@ -964,95 +976,17 @@ def _does_param_require_new_bucket(param): # Grad buffer: [g0, g1, g2, g3, ...] 
numel = N # # We therefore maintain two index maps: - # - param_index_map: offsets using full numel. + # - param_index_map: offsets using full numel (from pre-computed layout). # - nvfp4_packed_param_index_map: offsets into the packed param buffer (numel // 2). # + # The packed index map is derived from param_index_map by iterating through + # the already-computed layout and halving numel for NVFP4 tensors. + # self.has_nvfp4_params = any(is_nvfp4tensor(p) for p in self.params) - # Secondary (packed) index map, counters, and bucket tracking for NVFP4. - self.nvfp4_packed_param_index_map = {} if self.has_nvfp4_params else None - nvfp4_packed_param_start_index = 0 if self.has_nvfp4_params else None - nvfp4_packed_param_end_index = None - nvfp4_packed_bucket_start_index = 0 if self.has_nvfp4_params else None - self.nvfp4_packed_bucket_indices = [] if self.has_nvfp4_params else None - nvfp4_packed_per_bucket_numel_unpadded = [] if self.has_nvfp4_params else None - - for param, _ in params_with_names[::-1]: - # Iterate through parameters in reverse order to roughly follow backprop order. - - param_start_index = _pad_start_of_param_if_needed(param_start_index) - if self.has_nvfp4_params: - nvfp4_packed_param_start_index = _pad_start_of_param_if_needed( - nvfp4_packed_param_start_index - ) - - # Create bucket with collected parameters if current param needs its own bucket. - if _does_param_require_new_bucket(param) and len(bucket_params) > 0: - # Finalize the current bucket and update start indices for the next bucket. - bucket_start_index, nvfp4_packed_bucket_start_index = ( - _finalize_bucket_all_index_spaces( - param_start_index, - bucket_start_index, - nvfp4_packed_param_start_index, - nvfp4_packed_bucket_start_index, - ) - ) - param_start_index = bucket_start_index - if self.has_nvfp4_params: - nvfp4_packed_param_start_index = nvfp4_packed_bucket_start_index - - # Primary index computation: always uses full param numel. - param_numel = param.data.nelement() - param_end_index = param_start_index + param_numel - self.param_index_map[param] = (param_start_index, param_end_index, bucket_id) - - # Secondary (packed) index computation for NVFP4. - if self.has_nvfp4_params: - if is_nvfp4tensor(param): - assert ( - param_numel % 2 == 0 - ), f"NVFP4 requires even numel for packing, got {param_numel}" - # NVFP4 packs two FP4 values into one byte, so packed numel is half. - nvfp4_packed_param_end_index = nvfp4_packed_param_start_index + param_numel // 2 - else: - nvfp4_packed_param_end_index = nvfp4_packed_param_start_index + param_numel - self.nvfp4_packed_param_index_map[param] = ( - nvfp4_packed_param_start_index, - nvfp4_packed_param_end_index, - bucket_id, - ) - - bucket_params.add(param) - - # If we have enough elements already or the current param is part of the shared - # embedding layer and needs a separate bucket, form a new bucket. - if ( - bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size - ) or _does_param_require_new_bucket(param): - # Finalize the current bucket and update start indices for the next bucket. 
- bucket_start_index, nvfp4_packed_bucket_start_index = ( - _finalize_bucket_all_index_spaces( - param_end_index, - bucket_start_index, - nvfp4_packed_param_end_index, - nvfp4_packed_bucket_start_index, - ) - ) - param_start_index = bucket_start_index - if self.has_nvfp4_params: - nvfp4_packed_param_start_index = nvfp4_packed_bucket_start_index - else: - param_start_index = param_end_index - if self.has_nvfp4_params: - nvfp4_packed_param_start_index = nvfp4_packed_param_end_index - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - _finalize_bucket_all_index_spaces( - param_end_index, - bucket_start_index, - nvfp4_packed_param_end_index, - nvfp4_packed_bucket_start_index, - ) + self.nvfp4_packed_param_index_map = None + self.nvfp4_packed_bucket_indices = None + if self.has_nvfp4_params: + self._compute_nvfp4_packed_layout(params_with_names) # Next, create underlying storage for buffer (with numel elements that includes # padding as necessary). @@ -1060,13 +994,14 @@ def _does_param_require_new_bucket(param): self.numel_unpadded = sum(per_bucket_numel_unpadded) if self.has_nvfp4_params: self.nvfp4_packed_numel = self.nvfp4_packed_bucket_indices[-1][1] - self.nvfp4_packed_numel_unpadded = sum(nvfp4_packed_per_bucket_numel_unpadded) + # nvfp4_packed_numel_unpadded is already set by _compute_nvfp4_packed_layout. assert self.numel_unpadded <= self.numel + if self.has_nvfp4_params: + assert self.nvfp4_packed_numel_unpadded <= self.nvfp4_packed_numel if self.ddp_config.use_distributed_optimizer: assert self.numel % self.data_parallel_world_size == 0 if self.has_nvfp4_params: - assert self.nvfp4_packed_numel_unpadded <= self.nvfp4_packed_numel assert self.nvfp4_packed_numel % self.data_parallel_world_size == 0 else: assert self.numel == self.numel_unpadded @@ -1297,6 +1232,93 @@ def _create_bucket(bucket_id, bucket_params, bucket_params_with_extra_main_grads dp_cp_group=self.dp_cp_group, ) + def _compute_nvfp4_packed_layout(self, params_with_names): + """Derive packed NVFP4 index map and bucket indices from the primary layout. + + The primary layout (self.param_index_map, self.bucket_indices) uses full numel + for all params. NVFP4 tensors pack two FP4 values into one byte, so the param + buffer needs a separate "packed" index map where NVFP4 params occupy half the + space. Non-NVFP4 params keep their full numel in the packed space. + + The same padding rules used by the primary layout are applied here: + - 64-element alignment at the start of each param. + - Bucket-end padding for DP-divisibility (when using distributed optimizer). 
+ + Sets: + self.nvfp4_packed_param_index_map: param -> (start, end, bucket_id) + self.nvfp4_packed_bucket_indices: list of (start, end) per bucket + self.nvfp4_packed_numel_unpadded: total unpadded elements across all buckets + """ + + def _pad_start_of_param(param_start_index: int) -> int: + if self.ddp_config.use_distributed_optimizer: + return pad_param_start(param_start_index) + return param_start_index + + def _pad_end_of_bucket(bucket_end_index: int) -> int: + if self.ddp_config.use_distributed_optimizer: + return pad_bucket_end( + bucket_end_index, + self.data_parallel_world_size, + self.ddp_config.pad_buckets_for_high_nccl_busbw, + ) + return bucket_end_index + + self.nvfp4_packed_param_index_map = {} + self.nvfp4_packed_bucket_indices = [] + nvfp4_packed_per_bucket_numel_unpadded = [] + + packed_param_start = 0 + packed_bucket_start = 0 + cur_bucket_id = 0 + + for param, _ in params_with_names[::-1]: + _, _, bucket_id = self.param_index_map[param] + param_numel = param.data.nelement() + + packed_param_start = _pad_start_of_param(packed_param_start) + + # Finalize previous bucket if we've moved to a new one. + if bucket_id != cur_bucket_id: + # Record unpadded numel, then pad the bucket end. + nvfp4_packed_per_bucket_numel_unpadded.append( + packed_param_start - packed_bucket_start + ) + packed_bucket_end = _pad_end_of_bucket(packed_param_start) + self.nvfp4_packed_bucket_indices.append((packed_bucket_start, packed_bucket_end)) + packed_bucket_start = packed_bucket_end + packed_param_start = packed_bucket_start + cur_bucket_id = bucket_id + + # NVFP4 tensors use half the numel in the packed param buffer. + if is_nvfp4tensor(param): + assert ( + param_numel % 2 == 0 + ), f"NVFP4 requires even numel for packing, got {param_numel}" + packed_numel = param_numel // 2 + else: + packed_numel = param_numel + + packed_param_end = packed_param_start + packed_numel + self.nvfp4_packed_param_index_map[param] = ( + packed_param_start, + packed_param_end, + bucket_id, + ) + packed_param_start = packed_param_end + + # Finalize last bucket. + if packed_param_start > packed_bucket_start: + nvfp4_packed_per_bucket_numel_unpadded.append(packed_param_start - packed_bucket_start) + packed_bucket_end = _pad_end_of_bucket(packed_param_start) + self.nvfp4_packed_bucket_indices.append((packed_bucket_start, packed_bucket_end)) + + assert len(self.nvfp4_packed_bucket_indices) == len(self.bucket_indices), ( + f"Packed bucket count ({len(self.nvfp4_packed_bucket_indices)}) != " + f"primary bucket count ({len(self.bucket_indices)})" + ) + self.nvfp4_packed_numel_unpadded = sum(nvfp4_packed_per_bucket_numel_unpadded) + def scale_gradients(self, scaling_factor: float) -> None: """Scale the gradient data by `scaling_factor`.""" self.grad_data *= scaling_factor diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 2a82a1e1cf2..c6a8653c4ea 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2002,11 +2002,13 @@ def _encode_extra_state(self, state): return state_serialized def _decode_extra_state(self, state): + from megatron.core.safe_globals import SafeUnpickler + if isinstance(state, torch.Tensor): # No FP8 is indicated by an empty tensor we don't need to unpickle. 
if state.numel() == 0: return - return pickle.loads(state.detach().cpu().numpy().tobytes()) + return SafeUnpickler(io.BytesIO(state.detach().cpu().numpy().tobytes())).load() elif isinstance(state, io.BytesIO): state.seek(0) return torch.load(state, map_location="cuda") diff --git a/megatron/core/fault_injector.py b/megatron/core/fault_injector.py new file mode 100644 index 00000000000..68e0464fad7 --- /dev/null +++ b/megatron/core/fault_injector.py @@ -0,0 +1,233 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +import datetime +import logging +import math +import random +from dataclasses import dataclass +from typing import Optional, Protocol, Sequence, TypeVar, cast + +import torch +import torch.distributed as dist + +try: + from nvidia_resiliency_ext.shared_utils.inject_fault import ( # type: ignore[import-untyped] + Fault, + clear_workload_exception, + dispatch_fault_injection, + maybe_raise_workload_exception, + ) + + has_nvidia_resiliency_ext = True +except ModuleNotFoundError: + has_nvidia_resiliency_ext = False + + def maybe_raise_workload_exception(): # pylint: disable=missing-function-docstring + raise ModuleNotFoundError( + "nvidia_resiliency_ext is required for fault injection. " + "Please install it or disable fault injection." + ) + + +__all__ = ["FaultInjectorConfig", "setup_fault_injection", "maybe_raise_workload_exception"] + + +def _require_nvidia_resiliency_ext(): + if not has_nvidia_resiliency_ext: + raise ModuleNotFoundError( + "nvidia_resiliency_ext is required for fault injection. " + "Please install it or disable fault injection." + ) + + +logger = logging.getLogger(__name__) + +_T = TypeVar("_T") + + +@dataclass(kw_only=True) +class FaultInjectorConfig: + """Configuration for fault injection testing via nvidia_resiliency_ext.""" + + fault_injector_ranks: Optional[str] = None + """Comma-separated list of ranks to inject faults on.""" + + fault_injector_num_ranks: Optional[int] = None + """Number of ranks to inject faults on (random selection).""" + + fault_injector_fault_types: Optional[str] = None + """Comma-separated list of fault types to inject (e.g. 'hang,crash').""" + + fault_injector_fault_probabilities: Optional[str] = None + """Comma-separated list of fault probabilities (normalized at runtime).""" + + fault_injector_fault_delay: Optional[float] = None + """Force a specific fault delay in seconds from training start or delay_start_iteration.""" + + fault_injector_delay_start_iteration: Optional[int] = None + """Start the fault delay timer after iteration N completes. + If unset, fault delay timing starts from the beginning of training.""" + + fault_injector_mtti_seconds: Optional[float] = None + """Mean time to inject (MTTI) in seconds; used when fault_delay is None.""" + + fault_injector_offset_seconds: Optional[float] = None + """Offset seconds added to the sampled fault delay.""" + + fault_injector_seed: Optional[int] = None + """RNG seed for the fault injector.""" + + +class _FaultInjectorRNG(Protocol): + """Minimal RNG interface used by fault injector helper functions.""" + + def sample(self, population: Sequence[int], k: int) -> list[int]: + """Return ``k`` sampled items from the given population.""" + ... + + def choices(self, population: Sequence[_T], weights: Sequence[float], k: int) -> list[_T]: + """Return ``k`` weighted samples from the given population.""" + ... + + def random(self) -> float: + """Return a floating-point value in the half-open interval [0.0, 1.0).""" + ... 
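
The MTTI-based scheduling implemented by `get_fault_delay` below draws the fault delay from an exponential distribution via inverse-CDF sampling, so the mean delay equals `fault_injector_mtti_seconds` (plus any configured offset). A standalone sketch with a hypothetical helper name (editorial illustration, not part of the module):

```python
import math
import random


def sample_fault_delay(mtti_seconds: float, offset_seconds: float = 0.0, seed=None) -> float:
    """Sample an exponentially distributed delay whose mean is `mtti_seconds`."""
    rng = random.Random(seed)
    u = rng.random()               # U ~ Uniform[0, 1)
    rate = 1.0 / mtti_seconds      # lambda of the exponential distribution
    return offset_seconds + (-math.log(1.0 - u) / rate)
```
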
+ + +rng: _FaultInjectorRNG | None = None + + +def _require_rng() -> _FaultInjectorRNG: + assert rng is not None, "fault injector rng must be initialized" + return rng + + +def get_fault_ranks(config: FaultInjectorConfig): + """Return list of ranks to inject faults on, from explicit list or random sample.""" + global rng + + force_ranks = config.fault_injector_ranks + world_size = dist.get_world_size() + + if force_ranks is not None: + assert ( + config.fault_injector_num_ranks is None + ), "Cannot specify both force_ranks and num_ranks" + if ',' in force_ranks: + fault_ranks = [int(r) for r in force_ranks.split(",")] + else: + fault_ranks = [int(force_ranks)] + assert all( + 0 <= r < world_size for r in fault_ranks + ), f"Fault ranks must be between 0 and {world_size - 1}" + assert len(fault_ranks) > 0, "Must specify at least one fault rank" + else: + assert ( + config.fault_injector_num_ranks is not None + ), "Must specify either force_ranks or num_ranks" + fault_ranks = _require_rng().sample(range(1, world_size), k=config.fault_injector_num_ranks) + + return fault_ranks + + +def get_fault(config: FaultInjectorConfig): + """Sample a fault type according to the configured types and probabilities.""" + _require_nvidia_resiliency_ext() + global rng + + fault_types_config = config.fault_injector_fault_types + fault_probabilities_config = config.fault_injector_fault_probabilities + assert fault_types_config is not None, "fault_injector_fault_types must be specified" + + if ',' in fault_types_config: + fault_types = [Fault[t.upper()] for t in fault_types_config.split(",")] + else: + fault_types = [Fault[fault_types_config.upper()]] + + if fault_probabilities_config is not None: + if ',' in fault_probabilities_config: + fault_probabilities = [float(p) for p in fault_probabilities_config.split(",")] + else: + fault_probabilities = [float(fault_probabilities_config)] + fault_probabilities = [p / sum(fault_probabilities) for p in fault_probabilities] + else: + fault_probabilities = [1 / len(fault_types) for _ in fault_types] + + assert len(fault_types) > 0, "Must specify at least one fault type" + assert len(fault_types) == len( + fault_probabilities + ), "Number of fault types and fault probabilities must match" + + return _require_rng().choices(fault_types, fault_probabilities, k=1)[0] + + +def should_setup_fault_injection_at_start(config: FaultInjectorConfig): + """Return True when fault timing is anchored to training start.""" + return config.fault_injector_delay_start_iteration is None + + +def should_setup_fault_injection_at_iteration(config: FaultInjectorConfig, iteration): + """Return True when fault timing should start from the given iteration.""" + delay_start_iteration = config.fault_injector_delay_start_iteration + return delay_start_iteration is not None and delay_start_iteration == iteration + + +def get_fault_delay(config: FaultInjectorConfig): + """Return fault delay in seconds from the configured scheduling anchor.""" + global rng + + fault_delay = config.fault_injector_fault_delay + assert ( + fault_delay is not None or config.fault_injector_mtti_seconds is not None + ), "fault_injector_fault_delay or fault_injector_mtti_seconds must be specified" + if fault_delay is None: + mtti_seconds = config.fault_injector_mtti_seconds + assert mtti_seconds is not None, "fault_injector_mtti_seconds must be specified" + offset_seconds = config.fault_injector_offset_seconds or 0.0 + lambda_inj = 1.0 / mtti_seconds + fault_delay = offset_seconds + (-math.log(1.0 - _require_rng().random()) / 
lambda_inj) + + return fault_delay + + +def setup_fault_injection(config: FaultInjectorConfig): + """Broadcast fault plan across ranks and dispatch injection on target ranks.""" + _require_nvidia_resiliency_ext() + global rng + + my_rank = dist.get_rank() + world_size = dist.get_world_size() + + device = torch.device("cuda", torch.cuda.current_device()) + plan_tensor = torch.full((world_size + 1,), float("nan"), dtype=torch.float64, device=device) + + clear_workload_exception() + + if my_rank == 0: + if rng is None: + rng = cast(_FaultInjectorRNG, random.Random(config.fault_injector_seed)) + + fault_ranks = get_fault_ranks(config) + fault = get_fault(config) + fault_delay = get_fault_delay(config) + + for rank in fault_ranks: + plan_tensor[rank] = float(fault.value) + plan_tensor[world_size] = fault_delay + + dist.broadcast(plan_tensor, src=0) + + planned_fault = float(plan_tensor[my_rank].item()) + is_target_rank = not math.isnan(planned_fault) + + if is_target_rank: + fault = Fault(int(planned_fault)) + fault_delay = float(plan_tensor[world_size].item()) + current_time = datetime.datetime.now() + fault_time = current_time + datetime.timedelta(seconds=fault_delay) + timestamp = current_time.strftime("%Y-%m-%d %H:%M:%S.%f") + fault_timestamp = fault_time.strftime("%Y-%m-%d %H:%M:%S.%f") + logger.warning( + f"[{timestamp}] FAULT INJECTION: Rank {my_rank} will inject fault " + f"{fault.name} at {fault_timestamp}" + ) + dispatch_fault_injection(fault=fault, delay=fault_delay, callback=None) diff --git a/megatron/core/fp4_utils.py b/megatron/core/fp4_utils.py index 245a04eb39e..45e57285a8d 100644 --- a/megatron/core/fp4_utils.py +++ b/megatron/core/fp4_utils.py @@ -81,7 +81,8 @@ def modify_nvfp4_rowwise_storage(fp4_tensor: torch.Tensor, new_rowwise_data: tor ), "Rowwise NVFP4 storage must be uint8" # Preserve existing values and then swap storage new_rowwise_data.detach().copy_(old_rowwise) - setattr(fp4_tensor, "_rowwise_data", new_rowwise_data) + fp4_tensor._rowwise_data = new_rowwise_data + del old_rowwise def quantize_nvfp4_param_shard( diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 4e23151c533..b9f62e59547 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -14,7 +14,7 @@ import torch -from megatron.core.utils import get_pg_size +from megatron.core.utils import get_pg_size, round_up_to_nearest_multiple @dataclass(order=True, frozen=True) @@ -85,6 +85,10 @@ def is_valid( Returns: True if the config is valid, False otherwise """ + # A dimension with no tokens serves no requests. + if self.token_count <= 0: + return False + # Check if total requests exceed maximum if self.prefill_req_count + self.decode_req_count > max_requests: return False @@ -138,91 +142,64 @@ def req_count(self) -> int: @staticmethod def adjust_batch_dims_for_expert_parallelism( local_batch_dims, - strict: bool, - decode_only_cuda_graphs: bool, - smallest_non_decode_cuda_graph_size: int, ep_group: Optional[torch.distributed.ProcessGroup] = None, - num_speculative_tokens: int = 0, + ep_zmq_communicator=None, ) -> Optional["InferenceBatchDimensions"]: - """Adjusted cuda graph batch dimensions for expert parallelism. - We take the max token count across expert model parallel group. + """Adjust CUDA graph batch dimensions for expert parallelism. + + All-reduce-max the token count and non-decode flag across the EP group. 
+ If any rank has a prefill (non-decode) step, all ranks fall back to eager + mode (return None) — the non-CG path handles variable token counts via + use_allgather_v. Otherwise return adjusted dims with the max token count. Args: local_batch_dims: The local batch dimensions to adjust. - strict: Whether to use strict matching for batch dimensions. - decode_only_cuda_graphs: Whether CUDA graphs are only used for decode steps. ep_group: Optional expert parallel process group. If None, uses global parallel state. When using different EP sizes for inference vs training, pass the inference EP group explicitly. + ep_zmq_communicator: Optional AsyncZMQCommunicator over the EP group. When + provided, the cross-rank MAX reduction runs on the CPU via ZMQ + (no GPU kernel, no H2D/D2H), avoiding a per-step NCCL AllReduce + on the compute stream. When absent, falls back to + torch.distributed.all_reduce on a GPU tensor. - Return: - (InferenceBatchDimensions) A new InferenceBatchDimensions object with - adjusted dimensions, or None if eager mode should be used. + Returns: + InferenceBatchDimensions with max token count, or None for eager mode. """ ep_size = get_pg_size(ep_group) if ep_size <= 1: return local_batch_dims - # all reduce local work across expert model parallel group is_non_decode = local_batch_dims.prefill_req_count > 0 - sync_tensor = torch.tensor( - [ - local_batch_dims.token_count, - int(is_non_decode), - local_batch_dims.prefill_req_count, - local_batch_dims.decode_req_count, - ], - dtype=torch.int32, - device=torch.cuda.current_device(), - ) + if ep_zmq_communicator is not None: + # CPU-only sync via ZMQ: avoids a NCCL AllReduce kernel on the + # compute stream plus the H2D/D2H pair that sandwiches it. + (max_token_count, max_is_non_decode) = ep_zmq_communicator.sync_all_reduce_max( + local_batch_dims.token_count, int(is_non_decode) + ) + else: + sync_tensor = torch.tensor( + [local_batch_dims.token_count, int(is_non_decode)], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + torch.distributed.all_reduce( + sync_tensor, op=torch.distributed.ReduceOp.MAX, group=ep_group + ) + sync_tensor = sync_tensor.cpu() + max_token_count = int(sync_tensor[0].item()) + max_is_non_decode = int(sync_tensor[1].item()) - torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.MAX, group=ep_group) - - sync_tensor = sync_tensor.cpu() - is_any_ep_rank_in_non_decode = sync_tensor[1].item() == 1 - - # We force eager mode for scenarios where some ranks will run with CUDA graphs - # while others will not. Without this check, communication in the - # expert routing layer would pad up to the maximum capacity only for the ranks that - # are using CUDA graphs in this step, leading to a hang. - # This can happen if we only allow decode CUDA graphs but some ranks are running - # non-decode batches. - if is_any_ep_rank_in_non_decode and decode_only_cuda_graphs: - return None # indicate no match, run in eager mode - - adjusted_token_count = int(sync_tensor[0].item()) - - # Sync request counts across EP ranks when strict matching is enabled - # or when speculative tokens are used. With speculative tokens, - # decode-only graphs have token counts of decode_req_count * (spec+1) - # which creates a different granularity than mixed graphs (raw sizes). - # Without syncing, decode-only ranks and prefill ranks search different - # graph pools and may pick graphs with different token counts. 
- sync_request_counts = strict or ( - is_any_ep_rank_in_non_decode and num_speculative_tokens > 0 - ) - adjusted_prefill_req_count = ( - int(sync_tensor[2].item()) - if sync_request_counts - else local_batch_dims.prefill_req_count - ) - adjusted_decode_req_count = ( - int(sync_tensor[3].item()) if sync_request_counts else local_batch_dims.decode_req_count - ) + is_any_ep_rank_in_non_decode = max_is_non_decode == 1 - # When any EP rank has prefill requests (non-strict mode), elevate - # the token count to be >= the smallest prefill/mixed cuda graph. - # This ensures decode-only ranks don't match a fine-grained decode - # graph while prefill ranks match a coarser mixed graph, which would - # produce inconsistent token counts across EP ranks. - if is_any_ep_rank_in_non_decode and not strict: - adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size) + if is_any_ep_rank_in_non_decode: + return None # any rank has prefill → eager mode adjusted_batch_dim = InferenceBatchDimensions( - token_count=adjusted_token_count, - prefill_req_count=adjusted_prefill_req_count, - decode_req_count=adjusted_decode_req_count, + token_count=max_token_count, + prefill_req_count=local_batch_dims.prefill_req_count, + decode_req_count=local_batch_dims.decode_req_count, ) return adjusted_batch_dim @@ -269,7 +246,9 @@ def _calculate_cuda_graph_token_counts( ) # Align each entry to TP size cuda_graph_token_counts = list( - dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts) + dict.fromkeys( + round_up_to_nearest_multiple(s, tp_size) for s in cuda_graph_token_counts + ) ) # Clamp to max tokens cuda_graph_token_counts = [ @@ -291,7 +270,7 @@ def _calculate_cuda_graph_token_counts( math.ceil(int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER) ) # Make sure divisible by TP size - cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size + cuda_graph_step_size = round_up_to_nearest_multiple(cuda_graph_step_size, tp_size) # round down cuda graph max tokens to be multiple of TP size cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size @@ -506,11 +485,10 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int def match_graph_config( real_batch_dim: InferenceBatchDimensions, cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], - smallest_non_decode_cuda_graph_size: int, strict: bool = False, - decode_only_cuda_graphs: bool = False, ep_group: Optional[torch.distributed.ProcessGroup] = None, - num_speculative_tokens: int = 0, + ep_zmq_communicator=None, + match_ep_token_counts: bool = True, ) -> Optional[InferenceBatchDimensions]: """ Matches the best CUDA graph batch dimension for the given real batch dimension. @@ -526,6 +504,14 @@ def match_graph_config( ep_group: Optional expert parallel process group. If None, uses global parallel state. When using different EP sizes for inference vs training, pass the inference EP group explicitly. + ep_zmq_communicator: Optional AsyncZMQCommunicator over the EP group. When + provided, batch-dimension MAX reduction uses a CPU-only ZMQ sync + instead of a GPU NCCL AllReduce. Forwarded to + adjust_batch_dims_for_expert_parallelism. + match_ep_token_counts: If True (default), token counts are synced across EP ranks via + all-reduce-max so all ranks select the same CUDA graph. Set to False when the + dispatcher handles per-rank token variation internally (e.g. AGV/RSV in the NVLS + path) and external EP sync is not needed. 
Returns: The best matching CUDA graph batch dimension, or None if no applicable match is found """ @@ -534,20 +520,20 @@ def match_graph_config( # no need to match if no cuda graph batch dimensions are provided return None - adjusted_batch_dim = InferenceBatchDimensions.adjust_batch_dims_for_expert_parallelism( - real_batch_dim, - strict=strict, - decode_only_cuda_graphs=decode_only_cuda_graphs, - ep_group=ep_group, - smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size, - num_speculative_tokens=num_speculative_tokens, - ) + if match_ep_token_counts: + # NCCL dispatcher: all EP ranks must select the same CUDA graph. Sync batch dims + # across the EP group so graph selection is consistent. + adjusted_batch_dim = InferenceBatchDimensions.adjust_batch_dims_for_expert_parallelism( + real_batch_dim, ep_group=ep_group, ep_zmq_communicator=ep_zmq_communicator + ) - if adjusted_batch_dim is None: - # we hit this scenario if decode_only_cuda_graphs is true, - # and one of the EP ranks is running a non-decode step - # in that case, all ranks have to run in eager mode - return None + if adjusted_batch_dim is None: + # we hit this scenario if decode_only_cuda_graphs is true, + # and one of the EP ranks is running a non-decode step + # in that case, all ranks have to run in eager mode + return None + else: + adjusted_batch_dim = real_batch_dim # first filter out batch dimensions with smaller token count, prefill req count, # or decode req count, as they are not applicable diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index 967dc8329f1..75da02eaf4b 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -3,3 +3,8 @@ from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag from .utils import are_tensors_nvls_eligible, is_device_nvls_capable +from .variable_collectives import ( + multimem_all_gather_v, + multimem_all_gatherv_3tensor, + multimem_reduce_scatter_v, +) diff --git a/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py index 859b9010aea..eace10ff167 100644 --- a/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py +++ b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py @@ -211,6 +211,182 @@ def add_v8_bf16_from_u32( ) +@triton.jit +def ld_64(ptr, mask): + """ + Loads 64 bits from local global memory into two 32-bit registers. + + Uses `ld.global.v2.u32`. Mirrors the non-multicast path of ld_128. + + Args: + ptr: source pointer typed as uint64 (8-byte aligned). + mask: boolean predicate — if False, the load is skipped. + + Returns: + (x, y): two tl.uint32 registers containing 64 bits of loaded data. + """ + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $3, 1; + @%p0 bra end; + ld.global.v2.u32 {$0, $1}, [$2]; + end: + } + """, + "=r,=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def st_64(ptr, x, y, mask, multicast_op: tl.constexpr): + """ + Stores 64 bits (two 32-bit registers) to memory. + + Mirrors st_128 but operates on 64-bit (v2) quantities. + + 1. 
**Standard Store (`multicast_op=False`)**: + - `st.global.v2.f32` — writes 64 bits to local global memory. + + 2. **Multicast Store (`multicast_op=True`)**: + - `multimem.st.relaxed.sys.global.v2.f32` — broadcasts 64 bits to all + peers in the multicast group simultaneously. + + Args: + ptr: destination pointer typed as uint64 (8-byte aligned). + x, y: two tl.uint32 registers containing the data to store. + mask: boolean predicate — if False, the store is skipped. + multicast_op (tl.constexpr): False = local store, True = multicast broadcast. + """ + if multicast_op: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $4, 1; + @%p0 bra end; + multimem.st.relaxed.sys.global.v2.f32 [$1], {$2, $3}; + end: + } + """, + "=r,l,r,r,r", + args=[ptr, x, y, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) + else: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $4, 1; + @%p0 bra end; + st.global.v2.f32 [$1], {$2, $3}; + end: + } + """, + "=r,l,r,r,r", + args=[ptr, x, y, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) + + +@triton.jit +def ld_32(ptr, mask): + """ + Loads 32 bits from local global memory into one 32-bit register. + + Uses `ld.global.u32`. Scalar version of ld_64/ld_128. + + Args: + ptr: source pointer typed as uint32 (4-byte aligned). + mask: boolean predicate — if False, the load is skipped. + + Returns: + x: one tl.uint32 register containing 32 bits of loaded data. + """ + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $2, 1; + @%p0 bra end; + ld.global.u32 $0, [$1]; + end: + } + """, + "=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32,), + is_pure=True, + pack=1, + ) + + +@triton.jit +def st_32(ptr, x, mask, multicast_op: tl.constexpr): + """ + Stores 32 bits (one 32-bit register) to memory. + + Scalar version of st_64/st_128. + + 1. **Standard Store (`multicast_op=False`)**: + - `st.global.f32` — writes 32 bits to local global memory. + + 2. **Multicast Store (`multicast_op=True`)**: + - `multimem.st.relaxed.sys.global.f32` — broadcasts 32 bits to all + peers in the multicast group simultaneously. + + Args: + ptr: destination pointer typed as uint32 (4-byte aligned). + x: one tl.uint32 register containing the data to store. + mask: boolean predicate — if False, the store is skipped. + multicast_op (tl.constexpr): False = local store, True = multicast broadcast. + """ + if multicast_op: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $3, 1; + @%p0 bra end; + multimem.st.relaxed.sys.global.f32 [$1], $2; + end: + } + """, + "=r,l,r,r", + args=[ptr, x, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) + else: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $3, 1; + @%p0 bra end; + st.global.f32 [$1], $2; + end: + } + """, + "=r,l,r,r", + args=[ptr, x, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) + + @triton.jit def asm_rsqrt(x, eps): """ diff --git a/megatron/core/inference/communication/torch_symm_triton/variable_collectives.py b/megatron/core/inference/communication/torch_symm_triton/variable_collectives.py new file mode 100644 index 00000000000..a32b20b9a14 --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/variable_collectives.py @@ -0,0 +1,776 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +"""Variable-count NVLS collectives (AllGatherV / ReduceScatterV). + +Unlike the uniform collectives in collectives.py, each rank may contribute +a different number of tokens. The caller provides: + - rank_token_offset: prefix sum of token counts for all lower-ranked ranks. + - local_tokens: this rank's token count. + +One CTA processes one token; the outer loop is persistent over local_tokens. +""" + +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + HAVE_TRITON = False + +try: + from torch._C._distributed_c10d import _SymmetricMemory +except ImportError: + _SymmetricMemory = MagicMock() + +from .barrier import symm_mem_sync +from .multimem_asm import ld_64, ld_128, st_64, st_128 +from .utils import is_device_nvls_capable, sync_threads + + +@triton.jit +def _multimem_all_gather_v_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + local_tokens, + rank_token_offset_ptr, + ep_max_tokens_ptr, + output_byte_offset, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + BITS: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Variable-count multicast all-gather kernel. One CTA processes one token. + + Each rank contributes local_tokens tokens starting at rank_token_offset in + the global output. Ranks may have different local_tokens values. + + Args: + local_ptr: pointer to this rank's local input, shape [local_tokens, hidden_size]. + multicast_ptr: multicast pointer to the output symmetric memory buffer. + signal_pad_ptrs: signal pads for barrier synchronization. + local_tokens: number of tokens this rank contributes. + rank_token_offset_ptr: pointer to a scalar int32 CUDA tensor holding the index + of the first token this rank writes in the global output (prefix sum of + local_tokens for all lower-ranked ranks). Fixed address; value set each step. + ep_max_tokens_ptr: pointer to a scalar int32 CUDA tensor holding the + maximum local_tokens across all EP ranks for this iteration. Fixed address; + value set each step. CTAs with pid >= this value exit immediately. Safe + because the value is identical on all ranks, so paired CTAs on every rank + exit together — the barrier for those CTAs is never entered on any rank. + output_byte_offset: byte offset of this tensor within the symmetric memory buffer. + HIDDEN_SIZE: hidden dimension, i.e. number of elements per token row (constexpr). + BLOCK_SIZE: threads per block (constexpr, >= numel_per_token). + NUMEL_PER_THREAD: elements per thread per load/store, i.e. BITS / element_bits (constexpr). + BITS: width of each load/store in bits — 128 for activations (bf16) and expert + indices (int64, always 16-byte aligned for any topk); 64 for routing probs + (fp32 with topk=6 or topk=22 yields 24/88-byte rows, not 16-byte aligned + but 8-byte aligned) (constexpr). + RANK: this rank's index (constexpr). + WORLD_SIZE: total number of ranks (constexpr). + """ + pid = tl.program_id(axis=0) + + # Exit before the barrier if this CTA's pid exceeds the iteration maximum. + # ep_max_tokens is the max over all EP ranks, so all ranks agree on + # which CTAs exit — the barrier slots for those CTAs are never touched on any rank. 
+ ep_max_tokens = tl.load(ep_max_tokens_ptr) + if pid >= ep_max_tokens: + return + + tid = tl.arange(0, BLOCK_SIZE) + rank_token_offset = tl.load(rank_token_offset_ptr) + + numel_per_token = tl.cdiv(HIDDEN_SIZE, NUMEL_PER_THREAD) + local_numel = local_tokens * numel_per_token + # BLOCK_SIZE is the next power of 2 >= numel_per_token, so it may be larger. + # channel_mask deactivates the extra padding threads (tid >= numel_per_token). + channel_mask = tid < numel_per_token + + for token_offset in range(pid, local_tokens, tl.num_programs(axis=0)): + for channel_offset in range(0, numel_per_token, BLOCK_SIZE): + local_offsets = token_offset * numel_per_token + channel_offset + tid + # Two independent masks in orthogonal dimensions: + # channel_mask — deactivates power-of-2 padding threads (tid >= numel_per_token). + # token_mask — deactivates overflow threads in the last inner-loop chunk + # when numel_per_token > BLOCK_SIZE and the window + # [channel_offset, channel_offset+BLOCK_SIZE) extends past + # the final token row. + token_mask = local_offsets < local_numel + mask = token_mask & channel_mask + + # This rank's tokens start at rank_token_offset in the global output. + global_offsets = rank_token_offset * numel_per_token + local_offsets + + if BITS == 128: + # Each 128-bit pack occupies 2 uint64 units; output_byte_offset // 8 converts + # the tensor's byte offset within the symm-mem buffer to uint64 units. + # The global offset is multiplied by 2 to convert from 128-bit + # units to uint64 units. + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + output_byte_offset // 8 + + global_offsets * 2 + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + local_offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + else: + # Each 64-bit pack is exactly 1 uint64, so offsets index directly (no * 2 stride). + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + output_byte_offset // 8 + + global_offsets + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + local_offsets + (x, y) = ld_64(local_ptrs, mask=mask) + st_64(multicast_ptrs, x, y, mask=mask, multicast_op=True) + + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + + +@triton.jit +def _multimem_reduce_scatter_v_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + local_tokens, + rank_token_offset_ptr, + ep_max_tokens_ptr, + input_byte_offset, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, + REDUCE_F32: tl.constexpr = False, +): + """Variable-count multicast reduce-scatter kernel. One CTA processes one token. + + Reads this rank's token shard from the symmetric buffer via multimem.ld_reduce + (which atomically sums contributions from all EP ranks) and writes the result + to local memory. + + The barrier runs first — it waits for all ranks to have written their expert + GEMM outputs into the symmetric buffer before any rank starts reading. + + Args: + local_ptr: output pointer to this rank's local buffer, shape [local_tokens, hidden_size]. + multicast_ptr: multicast pointer to the symmetric memory buffer holding all expert outputs. + signal_pad_ptrs: signal pads for barrier synchronization. + local_tokens: number of tokens this rank owns. 
+ rank_token_offset_ptr: pointer to a scalar int32 CUDA tensor holding the index of the + first token this rank owns in the global token sequence. Fixed address; set each step. + ep_max_tokens_ptr: pointer to a scalar int32 CUDA tensor holding the maximum local_tokens + across all EP ranks. Fixed address; set each step. CTAs with pid >= this value exit + immediately — safe because the value is identical on all ranks. + input_byte_offset: byte offset of the input tensor within the symmetric memory buffer. + HIDDEN_SIZE: number of elements per token row (constexpr). + BLOCK_SIZE: threads per block (constexpr, >= numel_per_token). + NUMEL_PER_THREAD: elements per thread per load/store, i.e. 128 / element_bits (constexpr). + RANK: this rank's index (constexpr). + WORLD_SIZE: total number of ranks (constexpr). + """ + pid = tl.program_id(axis=0) + + # Exit before the barrier if this CTA's pid exceeds the iteration maximum. + # ep_max_tokens is the max over all EP ranks, so all ranks agree on which + # CTAs exit — the barrier slots for those CTAs are never touched on any rank. + ep_max_tokens = tl.load(ep_max_tokens_ptr) + if pid >= ep_max_tokens: + return + + # Wait for all ranks to have written their expert GEMM outputs to symm_mem + # before any rank starts the reduce-load. + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=False, + hasSubsequentMemAccess=False, + ) + sync_threads() + + tid = tl.arange(0, BLOCK_SIZE) + rank_token_offset = tl.load(rank_token_offset_ptr) + + numel_per_token = tl.cdiv(HIDDEN_SIZE, NUMEL_PER_THREAD) + local_numel = local_tokens * numel_per_token + # channel_mask: deactivates power-of-2 padding threads (tid >= numel_per_token). + channel_mask = tid < numel_per_token + + for token_offset in range(pid, local_tokens, tl.num_programs(axis=0)): + program_offset = token_offset * numel_per_token + + for channel_offset in range(0, numel_per_token, BLOCK_SIZE): + local_offsets = program_offset + channel_offset + tid + # Two independent masks in orthogonal dimensions: + # channel_mask — deactivates power-of-2 padding threads (tid >= numel_per_token). + # token_mask — deactivates overflow threads in the last inner-loop chunk + # when numel_per_token > BLOCK_SIZE and the window + # [channel_offset, channel_offset+BLOCK_SIZE) extends past + # the final token row. + token_mask = local_offsets < local_numel + mask = token_mask & channel_mask + + # This rank's tokens start at rank_token_offset in the global input. + global_offsets = rank_token_offset * numel_per_token + local_offsets + + # Each 128-bit pack occupies 2 uint64 units; input_byte_offset // 8 converts + # the tensor's byte offset within the symm-mem buffer to uint64 units. + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + input_byte_offset // 8 + + global_offsets * 2 + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + local_offsets * 2 + + (x, y, z, w) = ld_128( + multicast_ptrs, mask=mask, multicast_op=True, reduce_f32=REDUCE_F32 + ) + st_128(local_ptrs, x, y, z, w, mask=mask, multicast_op=False) + + +def multimem_reduce_scatter_v( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + rank_token_offset: torch.Tensor, + ep_max_tokens: torch.Tensor, + per_rank_max_tokens: int, + input_byte_offset: int = 0, + **kwargs, +) -> torch.Tensor: + """Variable-count multicast reduce-scatter for a single 2-D tensor. + + Reduces expert GEMM outputs across all EP ranks. 
Each rank reads its owned + token shard [rank_token_offset : rank_token_offset + local_tokens] from the + symmetric buffer using multimem.ld_reduce (which atomically sums all ranks' + contributions), and writes the result to output_tensor. + + Both tensors must be 2-D and 16-byte row-aligned (128-bit path only). + hidden_size is inferred from output_tensor.shape[1]. + + Args: + output_tensor: local output, shape [local_tokens, hidden_size]. + input_tensor: symmetric memory buffer holding all expert outputs, + shape [global_tokens, hidden_size]. + symm_mem_hdl: symmetric memory handle for input_tensor. + rank_token_offset: pre-allocated scalar int32 CUDA tensor. The dispatcher + writes this rank's token offset into it each step before kernel launch. + ep_max_tokens: pre-allocated scalar int32 CUDA tensor. The dispatcher writes + the maximum local_tokens across all EP ranks each step. CTAs with + pid >= ep_max_tokens exit immediately without entering the barrier. + per_rank_max_tokens: static int set at model init. Determines the CTA grid size + as min(per_rank_max_tokens, MAX_NUM_BLOCKS). + input_byte_offset: byte offset of input_tensor within the symmetric memory + buffer (for packing multiple tensors into one buffer; 0 otherwise). + + Returns: + output_tensor populated with this rank's reduced token outputs. + """ + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter-v." + assert ( + output_tensor.ndim == 2 and input_tensor.ndim == 2 + ), "output_tensor and input_tensor must be 2-D [tokens, hidden_size]." + assert is_device_nvls_capable( + output_tensor.device + ), "multimem_reduce_scatter_v requires a Hopper+ GPU with NVLink (SM >= 9)." + assert ( + rank_token_offset.numel() == 1 + and rank_token_offset.dtype == torch.int32 + and rank_token_offset.is_cuda + ), "rank_token_offset must be a scalar int32 CUDA tensor." + assert output_tensor.dtype in ( + torch.bfloat16, + torch.float32, + ), f"Only bfloat16 and float32 are supported, got {output_tensor.dtype}" + assert ( + output_tensor.dtype == input_tensor.dtype + ), f"output and input dtype mismatch: {output_tensor.dtype} vs {input_tensor.dtype}" + + hidden_size = output_tensor.shape[1] + assert ( + input_tensor.shape[1] == hidden_size + ), f"input and output hidden_size mismatch: {input_tensor.shape[1]} vs {hidden_size}" + row_bytes = hidden_size * output_tensor.element_size() + assert row_bytes % 16 == 0, ( + f"Row size ({hidden_size} elements × {output_tensor.element_size()} bytes) = " + f"{row_bytes} bytes is not 16-byte aligned; RSV requires 128-bit alignment." 
+ ) + + MAX_NUM_BLOCKS = kwargs.get("max_num_blocks", 128) + MAX_BLOCK_SIZE = 1024 + WARP_SIZE = 32 + + local_tokens = output_tensor.shape[0] + numel_per_thread = 128 // (output_tensor.element_size() * 8) + numel_per_token = (hidden_size + numel_per_thread - 1) // numel_per_thread + + block_size = min(triton.next_power_of_2(numel_per_token), MAX_BLOCK_SIZE) + num_warps = max(1, block_size // WARP_SIZE) + num_blocks = min(per_rank_max_tokens, MAX_NUM_BLOCKS) + + reduce_f32 = output_tensor.dtype == torch.float32 + _multimem_reduce_scatter_v_kernel[(num_blocks, 1, 1)]( + output_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + local_tokens=local_tokens, + rank_token_offset_ptr=rank_token_offset, + ep_max_tokens_ptr=ep_max_tokens, + input_byte_offset=input_byte_offset, + HIDDEN_SIZE=hidden_size, + BLOCK_SIZE=block_size, + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + REDUCE_F32=reduce_f32, + num_warps=num_warps, + ) + + return output_tensor + + +@triton.jit +def _multimem_all_gatherv_3tensor_kernel( + local_ptr_0, + multicast_ptr_0, + output_byte_offset_0, + local_ptr_1, + multicast_ptr_1, + output_byte_offset_1, + local_ptr_2, + multicast_ptr_2, + output_byte_offset_2, + signal_pad_ptrs, + local_tokens, + rank_token_offset_ptr, + ep_max_tokens_ptr, + HIDDEN_SIZE_0: tl.constexpr, + HIDDEN_SIZE_1: tl.constexpr, + HIDDEN_SIZE_2: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD_0: tl.constexpr, + NUMEL_PER_THREAD_1: tl.constexpr, + NUMEL_PER_THREAD_2: tl.constexpr, + BITS_0: tl.constexpr, + BITS_1: tl.constexpr, + BITS_2: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Variable-count multicast all-gather for three tensors in a single kernel. + + Identical semantics to _multimem_all_gather_v_kernel but processes three + tensors per CTA iteration, sharing a single barrier. This avoids launching + three separate kernels (and three separate barriers) for the common case + of gathering hidden states, routing probabilities, and expert indices together. + + The outer token loop is shared across all three tensors; each tensor has its + own inner channel loop with independent masking. BLOCK_SIZE is the maximum + of the three per-tensor block sizes — smaller tensors mask out the extra threads + via channel_mask. + + signal_pad_ptrs from the first output buffer's symmetric memory handle are used + for the single end-of-kernel barrier. Since all three writes complete before the + barrier, a single sync suffices for all three tensors. + + Args: + local_ptr_0/1/2: pointers to each rank's local input for tensors 0/1/2. + multicast_ptr_0/1/2: multicast pointers to the output symmetric memory buffers. + output_byte_offset_0/1/2: byte offsets of each tensor within its symmetric + memory buffer (0 when the buffer holds only that tensor). + signal_pad_ptrs: signal pads from symm_mem_hdl_0, used for the single barrier. + local_tokens: number of tokens this rank contributes (shared across tensors). + rank_token_offset_ptr: pointer to a scalar int32 CUDA tensor holding this rank's + write offset in the global output (prefix sum over lower-ranked EP ranks). + ep_max_tokens_ptr: pointer to a scalar int32 CUDA tensor holding the maximum + local_tokens across all EP ranks. CTAs with pid >= this value exit immediately. + HIDDEN_SIZE_0/1/2: hidden dimension (elements per token row) for each tensor (constexpr). + BLOCK_SIZE: threads per block — max of the three per-tensor block sizes (constexpr). 
+ NUMEL_PER_THREAD_0/1/2: elements per thread per load/store for each tensor (constexpr). + BITS_0/1/2: load/store width in bits (128 or 64) for each tensor (constexpr). + RANK: this rank's index (constexpr). + WORLD_SIZE: total number of ranks (constexpr). + """ + pid = tl.program_id(axis=0) + + ep_max_tokens = tl.load(ep_max_tokens_ptr) + if pid >= ep_max_tokens: + return + + tid = tl.arange(0, BLOCK_SIZE) + rank_token_offset = tl.load(rank_token_offset_ptr) + + numel_per_token_0 = tl.cdiv(HIDDEN_SIZE_0, NUMEL_PER_THREAD_0) + numel_per_token_1 = tl.cdiv(HIDDEN_SIZE_1, NUMEL_PER_THREAD_1) + numel_per_token_2 = tl.cdiv(HIDDEN_SIZE_2, NUMEL_PER_THREAD_2) + + local_numel_0 = local_tokens * numel_per_token_0 + local_numel_1 = local_tokens * numel_per_token_1 + local_numel_2 = local_tokens * numel_per_token_2 + + # channel_mask: deactivates threads beyond each tensor's numel_per_token (power-of-2 padding). + channel_mask_0 = tid < numel_per_token_0 + channel_mask_1 = tid < numel_per_token_1 + channel_mask_2 = tid < numel_per_token_2 + + for token_offset in range(pid, local_tokens, tl.num_programs(axis=0)): + # --- Tensor 0 --- + for channel_offset in range(0, numel_per_token_0, BLOCK_SIZE): + local_offsets = token_offset * numel_per_token_0 + channel_offset + tid + token_mask = local_offsets < local_numel_0 + mask = token_mask & channel_mask_0 + global_offsets = rank_token_offset * numel_per_token_0 + local_offsets + if BITS_0 == 128: + multicast_ptrs = ( + multicast_ptr_0.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_0 // 8 + + global_offsets * 2 + ) + local_ptrs = local_ptr_0.to(tl.pointer_type(tl.uint64)) + local_offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + else: + multicast_ptrs = ( + multicast_ptr_0.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_0 // 8 + + global_offsets + ) + local_ptrs = local_ptr_0.to(tl.pointer_type(tl.uint64)) + local_offsets + (x, y) = ld_64(local_ptrs, mask=mask) + st_64(multicast_ptrs, x, y, mask=mask, multicast_op=True) + + # --- Tensor 1 --- + for channel_offset in range(0, numel_per_token_1, BLOCK_SIZE): + local_offsets = token_offset * numel_per_token_1 + channel_offset + tid + token_mask = local_offsets < local_numel_1 + mask = token_mask & channel_mask_1 + global_offsets = rank_token_offset * numel_per_token_1 + local_offsets + if BITS_1 == 128: + multicast_ptrs = ( + multicast_ptr_1.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_1 // 8 + + global_offsets * 2 + ) + local_ptrs = local_ptr_1.to(tl.pointer_type(tl.uint64)) + local_offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + else: + multicast_ptrs = ( + multicast_ptr_1.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_1 // 8 + + global_offsets + ) + local_ptrs = local_ptr_1.to(tl.pointer_type(tl.uint64)) + local_offsets + (x, y) = ld_64(local_ptrs, mask=mask) + st_64(multicast_ptrs, x, y, mask=mask, multicast_op=True) + + # --- Tensor 2 --- + for channel_offset in range(0, numel_per_token_2, BLOCK_SIZE): + local_offsets = token_offset * numel_per_token_2 + channel_offset + tid + token_mask = local_offsets < local_numel_2 + mask = token_mask & channel_mask_2 + global_offsets = rank_token_offset * numel_per_token_2 + local_offsets + if BITS_2 == 128: + multicast_ptrs = ( + multicast_ptr_2.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_2 // 8 + + global_offsets * 2 + ) + 
local_ptrs = local_ptr_2.to(tl.pointer_type(tl.uint64)) + local_offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + else: + multicast_ptrs = ( + multicast_ptr_2.to(tl.pointer_type(tl.uint64)) + + output_byte_offset_2 // 8 + + global_offsets + ) + local_ptrs = local_ptr_2.to(tl.pointer_type(tl.uint64)) + local_offsets + (x, y) = ld_64(local_ptrs, mask=mask) + st_64(multicast_ptrs, x, y, mask=mask, multicast_op=True) + + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + + +def multimem_all_gather_v( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + rank_token_offset: torch.Tensor, + ep_max_tokens: torch.Tensor, + per_rank_max_tokens: int, + output_byte_offset: int = 0, + **kwargs, +) -> torch.Tensor: + """Variable-count multicast all-gather for a single 2-D tensor. + + Gathers [local_tokens, hidden_size] from each EP rank into a shared + output_tensor of shape [global_tokens, hidden_size], where global_tokens is + the sum of all ranks' local_tokens. Each rank writes its slice starting at + rank_token_offset in the output. + + Both tensors must be 2-D; hidden_size is inferred from input_tensor.shape[1]. + The 128-bit or 64-bit NVLS path is selected automatically based on row alignment. + + Args: + output_tensor: symmetric memory buffer, shape [global_tokens, hidden_size]. + input_tensor: this rank's local input, shape [local_tokens, hidden_size]. + symm_mem_hdl: symmetric memory handle for output_tensor. + rank_token_offset: pre-allocated scalar int32 CUDA tensor. The dispatcher + writes this rank's token offset (prefix sum over lower-ranked EP ranks) + into it each step before kernel launch. + ep_max_tokens: pre-allocated scalar int32 CUDA tensor. The dispatcher writes + the maximum local_tokens across all EP ranks into it each step. CTAs with + pid >= ep_max_tokens exit immediately — safe because all ranks agree on + this value, so the corresponding CTAs exit on every rank simultaneously. + per_rank_max_tokens: static int set at model init. Determines the CTA grid size + as min(per_rank_max_tokens, MAX_NUM_BLOCKS). Typically > MAX_NUM_BLOCKS so + we always launch MAX_NUM_BLOCKS CTAs. + output_byte_offset: byte offset of this tensor within the symmetric memory buffer + (for packing multiple tensors into one buffer; 0 if the buffer holds only + this tensor). + + Returns: + output_tensor with all ranks' data written. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather-v." + assert input_tensor.ndim == 2 and output_tensor.ndim == 2, ( + f"input_tensor and output_tensor must be 2-D [tokens, hidden_size], " + f"got input_tensor.shape={input_tensor.shape}, output_tensor.shape={output_tensor.shape}." + ) + assert is_device_nvls_capable( + input_tensor.device + ), "multimem_all_gather_v requires a Hopper+ GPU with NVLink (SM >= 9)." + assert ( + rank_token_offset.numel() == 1 + and rank_token_offset.dtype == torch.int32 + and rank_token_offset.is_cuda + ), "rank_token_offset must be a scalar int32 CUDA tensor." 
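    # The width selection below mirrors the kernel's BITS constexpr: 16-byte-aligned rows
    # take the 128-bit ld/st path, while rows that are only 8-byte aligned (e.g. fp32
    # routing probs with topk=6 -> 24-byte rows) fall back to the 64-bit path.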
+ + hidden_size = input_tensor.shape[1] + assert ( + input_tensor.shape[1] == output_tensor.shape[1] + ), f"input and output hidden_size mismatch: {input_tensor.shape[1]} vs {output_tensor.shape[1]}" + + row_bytes = hidden_size * input_tensor.element_size() + assert row_bytes % 8 == 0, ( + f"Row size ({hidden_size} elements × {input_tensor.element_size()} bytes) = " + f"{row_bytes} bytes is not 8-byte aligned; cannot use NVLS." + ) + bits = 128 if row_bytes % 16 == 0 else 64 + + MAX_NUM_BLOCKS = kwargs.get("max_num_blocks", 128) + MAX_BLOCK_SIZE = 1024 + WARP_SIZE = 32 + + local_tokens = input_tensor.shape[0] + numel_per_thread = bits // (input_tensor.element_size() * 8) + numel_per_token = (hidden_size + numel_per_thread - 1) // numel_per_thread + + # BLOCK_SIZE must be a constexpr and >= numel_per_token; round up to next power of 2. + block_size = min(triton.next_power_of_2(numel_per_token), MAX_BLOCK_SIZE) + num_warps = max(1, block_size // WARP_SIZE) + + # All ranks launch the same fixed number of CTAs. CTAs with + # pid >= ep_max_tokens exit immediately at kernel entry. + num_blocks = min(per_rank_max_tokens, MAX_NUM_BLOCKS) + + _multimem_all_gather_v_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + local_tokens=local_tokens, + rank_token_offset_ptr=rank_token_offset, + ep_max_tokens_ptr=ep_max_tokens, + output_byte_offset=output_byte_offset, + HIDDEN_SIZE=hidden_size, + BLOCK_SIZE=block_size, + NUMEL_PER_THREAD=numel_per_thread, + BITS=bits, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=num_warps, + ) + + return output_tensor + + +def multimem_all_gatherv_3tensor( + output_tensor_0: torch.Tensor, + output_tensor_1: torch.Tensor, + output_tensor_2: torch.Tensor, + input_tensor_0: torch.Tensor, + input_tensor_1: torch.Tensor, + input_tensor_2: torch.Tensor, + symm_mem_hdl_0: _SymmetricMemory, + symm_mem_hdl_1: _SymmetricMemory, + symm_mem_hdl_2: _SymmetricMemory, + rank_token_offset: torch.Tensor, + ep_max_tokens: torch.Tensor, + per_rank_max_tokens: int, + output_byte_offset_0: int = 0, + output_byte_offset_1: int = 0, + output_byte_offset_2: int = 0, + **kwargs, +) -> tuple: + """Variable-count multicast all-gather for three tensors in a single kernel launch. + + Gathers three independent [local_tokens, hidden_size_i] tensors from every EP rank + into their respective output symmetric memory buffers in one fused kernel, sharing a + single end-of-kernel barrier. This is more efficient than calling multimem_all_gather_v + three times because the barrier cost (one per kernel) is paid only once. + + All three input tensors must share the same local_tokens dimension (i.e. the same + number of token rows per rank). Each tensor may have a different hidden_size and dtype. + The 128-bit or 64-bit NVLS path is selected independently per tensor based on row + alignment. + + The barrier at the end of the kernel uses signal_pad_ptrs from symm_mem_hdl_0. Since + all three multicast stores complete before the barrier, a single sync covers all three + tensors. All three handles must belong to the same EP group (identical rank/world_size). + + Args: + output_tensor_0/1/2: symmetric memory buffers for each tensor, + shape [global_tokens, hidden_size_i]. + input_tensor_0/1/2: this rank's local inputs, shape [local_tokens, hidden_size_i]. + symm_mem_hdl_0/1/2: symmetric memory handles for each output buffer. + signal_pad_ptrs from hdl_0 are used for the single end-of-kernel barrier. 
+ rank_token_offset: pre-allocated scalar int32 CUDA tensor. The dispatcher writes + this rank's token offset (prefix sum over lower-ranked EP ranks) each step. + ep_max_tokens: pre-allocated scalar int32 CUDA tensor. The dispatcher writes the + maximum local_tokens across all EP ranks each step. CTAs with + pid >= ep_max_tokens exit immediately — safe because all ranks agree. + per_rank_max_tokens: static int set at model init. Determines the CTA grid size as + min(per_rank_max_tokens, MAX_NUM_BLOCKS). + output_byte_offset_0/1/2: byte offset of each tensor within its symmetric memory + buffer (for packing multiple tensors into one buffer; 0 otherwise). + + Returns: + Tuple of (output_tensor_0, output_tensor_1, output_tensor_2) with all ranks' + data written. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather-v3." + for i, (inp, out) in enumerate( + zip( + (input_tensor_0, input_tensor_1, input_tensor_2), + (output_tensor_0, output_tensor_1, output_tensor_2), + ) + ): + assert inp.ndim == 2 and out.ndim == 2, ( + f"input_tensor_{i} and output_tensor_{i} must be 2-D [tokens, hidden_size], " + f"got input_tensor_{i}.shape={inp.shape}, output_tensor_{i}.shape={out.shape}." + ) + assert inp.shape[1] == out.shape[1], ( + f"input_tensor_{i} and output_tensor_{i} hidden_size mismatch: " + f"{inp.shape[1]} vs {out.shape[1]}." + ) + assert ( + input_tensor_0.shape[0] == input_tensor_1.shape[0] == input_tensor_2.shape[0] + ), "All three input tensors must have the same local_tokens (first dimension)." + assert is_device_nvls_capable( + input_tensor_0.device + ), "multimem_all_gatherv_3tensor requires a Hopper+ GPU with NVLink (SM >= 9)." + assert ( + rank_token_offset.numel() == 1 + and rank_token_offset.dtype == torch.int32 + and rank_token_offset.is_cuda + ), "rank_token_offset must be a scalar int32 CUDA tensor." + assert ( + symm_mem_hdl_0.rank == symm_mem_hdl_1.rank == symm_mem_hdl_2.rank + ), "All three symmetric memory handles must belong to the same EP group (rank mismatch)." + assert ( + symm_mem_hdl_0.world_size == symm_mem_hdl_1.world_size == symm_mem_hdl_2.world_size + ), "All three symmetric memory handles must belong to the same EP group (world_size mismatch)." + + MAX_NUM_BLOCKS = kwargs.get("max_num_blocks", 128) + MAX_BLOCK_SIZE = 1024 + WARP_SIZE = 32 + + local_tokens = input_tensor_0.shape[0] + + def _tensor_params(inp): + hidden_size = inp.shape[1] + row_bytes = hidden_size * inp.element_size() + assert row_bytes % 8 == 0, ( + f"Row size ({hidden_size} elements × {inp.element_size()} bytes) = " + f"{row_bytes} bytes is not 8-byte aligned; cannot use NVLS." + ) + bits = 128 if row_bytes % 16 == 0 else 64 + numel_per_thread = bits // (inp.element_size() * 8) + numel_per_token = (hidden_size + numel_per_thread - 1) // numel_per_thread + block_size = min(triton.next_power_of_2(numel_per_token), MAX_BLOCK_SIZE) + return hidden_size, bits, numel_per_thread, block_size + + hidden_size_0, bits_0, numel_per_thread_0, block_size_0 = _tensor_params(input_tensor_0) + hidden_size_1, bits_1, numel_per_thread_1, block_size_1 = _tensor_params(input_tensor_1) + hidden_size_2, bits_2, numel_per_thread_2, block_size_2 = _tensor_params(input_tensor_2) + + # Use the largest block size so all threads are occupied for at least one tensor; + # smaller tensors mask out excess threads via channel_mask inside the kernel. 
+ block_size = max(block_size_0, block_size_1, block_size_2) + num_warps = max(1, block_size // WARP_SIZE) + num_blocks = min(per_rank_max_tokens, MAX_NUM_BLOCKS) + + _multimem_all_gatherv_3tensor_kernel[(num_blocks, 1, 1)]( + input_tensor_0.data_ptr(), + symm_mem_hdl_0.multicast_ptr, + output_byte_offset_0, + input_tensor_1.data_ptr(), + symm_mem_hdl_1.multicast_ptr, + output_byte_offset_1, + input_tensor_2.data_ptr(), + symm_mem_hdl_2.multicast_ptr, + output_byte_offset_2, + symm_mem_hdl_0.signal_pad_ptrs_dev, + local_tokens=local_tokens, + rank_token_offset_ptr=rank_token_offset, + ep_max_tokens_ptr=ep_max_tokens, + HIDDEN_SIZE_0=hidden_size_0, + HIDDEN_SIZE_1=hidden_size_1, + HIDDEN_SIZE_2=hidden_size_2, + BLOCK_SIZE=block_size, + NUMEL_PER_THREAD_0=numel_per_thread_0, + NUMEL_PER_THREAD_1=numel_per_thread_1, + NUMEL_PER_THREAD_2=numel_per_thread_2, + BITS_0=bits_0, + BITS_1=bits_1, + BITS_2=bits_2, + RANK=symm_mem_hdl_0.rank, + WORLD_SIZE=symm_mem_hdl_0.world_size, + num_warps=num_warps, + ) + + return output_tensor_0, output_tensor_1, output_tensor_2 diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index e1a36ff1563..df8e36c7bac 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -2,7 +2,7 @@ from dataclasses import InitVar, dataclass from enum import Enum -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple import torch @@ -297,10 +297,13 @@ class InferenceConfig: Defaults to 0, which means no logging. """ - request_metadata_types: Optional[List[Tuple[str, torch.dtype, bool]]] = None + sampling_backend: Literal['torch', 'flashinfer'] = 'torch' + """Which sampling kernels to use during inference.""" + + request_metadata_types: Optional[List[Tuple[str, torch.dtype]]] = None """ A list of the per-request metadata types to track. Each entry is a tuple - consisting of the string label, the target dtype, and whether to store the data on GPU. + consisting of the string label and the target dtype. """ use_synchronous_zmq_collectives: bool = False @@ -309,6 +312,16 @@ class InferenceConfig: performance variability for MoEs. """ + disable_ep_consensus: bool = False + """If True, the engine skips the EP-group consensus all-reduce in + `run_engine_with_coordinator` and decides whether to step based on local + state alone. The rank still calls `controller.dummy_forward()` whenever + `local_pending == 0`, so EP collectives (NCCL all-to-all, etc.) stay in + sync — without this, a peer running a real forward would deadlock waiting + on this rank's all-to-all participation. Trades off the consensus + all-reduce CPU cost for unconditional dummy_forwards on idle ranks. + """ + verbose: InitVar[bool] = False """Whether to log detailed context configuration at initialization. This is an InitVar and is not stored as a field on the config.""" @@ -320,3 +333,12 @@ def __post_init__(self, verbose: bool): f"prefix_caching_routing_alpha must be in [0, 1], " f"got {self.prefix_caching_routing_alpha}" ) + + if self.sampling_backend == 'flashinfer': + try: + import flashinfer # noqa: F401 + except ImportError as e: + raise ImportError( + "sampling_backend='flashinfer' requires the flashinfer package; " + "install it or set sampling_backend='torch'." 
+ ) from e diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 19091d35bfb..ff6423be16b 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -35,14 +35,15 @@ def __init__( # Maximum possible chunks across all batch configurations self.max_chunks = max_tokens // mamba_chunk_size + max_requests - # Map from requests to slots in the static Mamba state buffer + # Map from requests to slots in the static Mamba state buffer (CPU for bookkeeping). self.request_to_mamba_state_idx = torch.full( - (self.max_requests,), -1, dtype=torch.int32, device=torch.cuda.current_device() + (self.max_requests,), -1, dtype=torch.int32, device='cpu' ) - # Map from requests to slots in the static Mamba state buffer for active decode requests + # Map from requests to slots in the static Mamba state buffer for active decode requests. + # int64 so selective_state_update can index directly without a per-layer upcast kernel; self._batch_indices_decode_buffer = torch.full( - (self.max_requests,), -1, dtype=torch.int32, device=self.device + (self.max_requests,), -1, dtype=torch.int64, device=self.device ) # Map from requests to slots in the static Mamba state buffer for active prefill requests @@ -84,9 +85,9 @@ def __init__( self._conv_seq_idx_buffer = torch.zeros(max_tokens, dtype=torch.int32, device=self.device) self._conv_seq_start_buffer = torch.zeros(max_tokens, dtype=torch.int32, device=self.device) - # Allocator for Mamba state slots + # Allocator for Mamba state slots (CPU for bookkeeping). self.mamba_state_free_slots = torch.arange( - self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() + self.max_requests, dtype=torch.int32, device='cpu' ) self.mamba_state_free_slot_count = self.max_requests @@ -107,8 +108,31 @@ def __init__( else: self.conv_gather_offsets = None + # Coalesced production path: pinned CPU views + shared GPU views bound + # by DynamicInferenceContext so that the per-step Mamba metadata fields + # ride along with the single coalesced H2D in transfer_bookkeeping_to_gpu. + # The legacy update() path above keeps using the standalone _*_buffer + # tensors (exercised only by unit tests that construct MambaMetadata + # without a context). + self._cpu_bufs = None + self._gpu_view = None + self.reset_varlen_metadata() + def bind_cpu_buffers(self, bufs: dict) -> None: + """Attach pinned CPU views from DynamicInferenceContext._cpu_bookkeeping_buf. + + ``bufs`` maps field names to 1D (or (1, max_tokens) for ``seq_idx``) + pinned CPU views that compute_cpu_metadata writes into. The matching + GPU views on the other side of the H2D are exposed via + :meth:`bind_gpu_buffers`. + """ + self._cpu_bufs = bufs + + def bind_gpu_buffers(self, gpu_view) -> None: + """Attach shared GPU views from the context's :class:`ContextGPUView`.""" + self._gpu_view = gpu_view + def reset(self) -> None: """ Resets all Mamba states and frees all allocated slots. 
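# A minimal wiring sketch for the bind_cpu_buffers / bind_gpu_buffers hooks above,
# assuming the context owns one pinned CPU bookkeeping buffer and one ContextGPUView.
# The helper `pinned` and the variables `mamba_metadata` / `context_gpu_view` are
# hypothetical and shown for illustration only.
import torch

max_requests, max_tokens = 64, 4096

def pinned(*shape, dtype=torch.int32):
    return torch.empty(*shape, dtype=dtype, pin_memory=True)

cpu_views = {
    'batch_indices_decode': pinned(max_requests),
    'batch_indices_prefill': pinned(max_requests),
    'seq_idx': pinned(1, max_tokens),
    'cu_seqlens': pinned(max_requests + 1),
    # ... plus cu_chunk_seqlens, last_chunk_indices, seq_idx_for_varlen,
    #     conv_seq_idx and conv_seq_start, as written by compute_cpu_metadata().
}
mamba_metadata.bind_cpu_buffers(cpu_views)         # CPU side: filled by compute_cpu_metadata()
mamba_metadata.bind_gpu_buffers(context_gpu_view)  # GPU side: sliced by load_from_cpu() after the H2D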
@@ -119,7 +143,7 @@ def reset(self) -> None: # Re-initialize the free slot pool self.mamba_state_free_slots = torch.arange( - self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() + self.max_requests, dtype=torch.int32, device='cpu' ) self.mamba_state_free_slot_count = self.max_requests @@ -339,6 +363,7 @@ def _update_intermediate_metadata( intermediate_offsets_gpu: Optional[torch.Tensor], intermediate_counts_gpu: Optional[torch.Tensor], real_prefill_count: int, + cu_seqlens_gpu: Optional[torch.Tensor] = None, ) -> None: """Precompute intermediate extraction metadata for CUDA graph compatibility. @@ -352,18 +377,32 @@ def _update_intermediate_metadata( intermediate_counts_gpu: [real_prefill_count] int32 GPU tensor of per-request offset counts (0-3), or None. real_prefill_count: Number of real (non-padding) prefill requests. + cu_seqlens_gpu: GPU cu_seqlens tensor to read from. Defaults to + the legacy standalone ``_cu_seqlens_buffer`` used by + :meth:`update`; the coalesced production path passes the + shared ``ContextGPUView.mamba_cu_seqlens`` view. """ chunk_size = self.mamba_chunk_size max_count = self.max_intermediate_count + if cu_seqlens_gpu is None: + cu_seqlens_gpu = self._cu_seqlens_buffer if intermediate_offsets_gpu is not None and real_prefill_count > 0: - # Transfer counts to CPU (single sync) for per_request_counts and total check + # counts_list is CPU-cheap (source is already CPU from MambaSlotAllocator). counts_list = intermediate_counts_gpu.tolist() total = sum(counts_list) + # Ensure GPU copies for vectorized GPU ops below. + if not intermediate_offsets_gpu.is_cuda: + intermediate_offsets_gpu = intermediate_offsets_gpu.to( + self.device, non_blocking=True + ) + if not intermediate_counts_gpu.is_cuda: + intermediate_counts_gpu = intermediate_counts_gpu.to(self.device, non_blocking=True) + if total > 0: # Compute cumulative chunk counts from cu_seqlens (already on GPU) - cu = self._cu_seqlens_buffer[: real_prefill_count + 1] + cu = cu_seqlens_gpu[: real_prefill_count + 1] seq_lens = (cu[1 : real_prefill_count + 1] - cu[:real_prefill_count]).to( torch.int64 ) @@ -429,6 +468,223 @@ def _update_intermediate_metadata( self.intermediate_chunk_indices = self._intermediate_chunk_indices_buffer[:max_count] self.intermediate_abs_positions = self._intermediate_abs_positions_buffer[:max_count] + def compute_cpu_metadata( + self, + active_mamba_indices: torch.Tensor, + token_to_request_idx: torch.Tensor, + cpu_cu_query: torch.Tensor, + batch_dimensions: InferenceBatchDimensions, + padded_batch_dimensions: InferenceBatchDimensions, + enable_chunked_prefill: bool, + intermediate_offsets_gpu: Optional[torch.Tensor] = None, + intermediate_counts_gpu: Optional[torch.Tensor] = None, + ) -> dict: + """Compute all Mamba metadata on CPU, writing directly into the bound + pinned CPU views. + + The values written here are transferred to GPU by the single coalesced + H2D in :meth:`DynamicInferenceContext.transfer_bookkeeping_to_gpu`. + The returned dict contains only Python scalars + the intermediate GPU + tensors, which :meth:`load_from_cpu` consumes after the H2D. + + Args: + active_mamba_indices: CPU tensor of Mamba slot indices for active requests. + token_to_request_idx: CPU tensor mapping tokens to request indices. + cpu_cu_query: CPU cumulative query lengths from MHA metadata computation. + batch_dimensions: Dimensions of the current batch. + padded_batch_dimensions: Dimensions of the padded batch. + enable_chunked_prefill: Whether chunked prefill is enabled. 
+ intermediate_offsets_gpu: GPU tensor of per-request intermediate offsets, or None. + intermediate_counts_gpu: GPU tensor of per-request intermediate counts, or None. + """ + assert self._cpu_bufs is not None, "bind_cpu_buffers() must be called first" + bufs = self._cpu_bufs + + real_decode_count = batch_dimensions.decode_req_count + real_prefill_count = batch_dimensions.prefill_req_count + padded_decode_count = padded_batch_dimensions.decode_req_count + padded_prefill_count = padded_batch_dimensions.prefill_req_count + padded_token_count = padded_batch_dimensions.token_count + chunk_size = self.mamba_chunk_size + + result = { + "padded_decode_count": padded_decode_count, + "padded_prefill_count": padded_prefill_count, + "padded_token_count": padded_token_count, + "real_decode_count": real_decode_count, + "real_prefill_count": real_prefill_count, + } + + # Decode batch indices (write into pinned view; padded slots = -1). + if padded_decode_count > 0: + bufs['batch_indices_decode'][:real_decode_count] = active_mamba_indices[ + :real_decode_count + ] + if padded_decode_count > real_decode_count: + bufs['batch_indices_decode'][real_decode_count:padded_decode_count] = -1 + + # Prefill batch indices, seq_idx, cu_seqlens, chunk/conv metadata. + if padded_prefill_count > 0: + if real_prefill_count > 0: + start = real_decode_count + bufs['batch_indices_prefill'][:real_prefill_count] = active_mamba_indices[ + start : start + real_prefill_count + ] + if padded_prefill_count > real_prefill_count: + bufs['batch_indices_prefill'][real_prefill_count:padded_prefill_count] = -1 + + # seq_idx: normalized token-to-request mapping for prefill tokens. + prefill_start_req = real_decode_count + end_prefill_req = real_decode_count + real_prefill_count + start_token = cpu_cu_query[prefill_start_req].item() + end_token = cpu_cu_query[end_prefill_req].item() + seq_len = end_token - start_token + + if seq_len > 0: + raw = token_to_request_idx[start_token:end_token] + bufs['seq_idx'][0, :seq_len] = raw - raw[0] + if padded_token_count > seq_len: + bufs['seq_idx'][0, seq_len:padded_token_count] = -1 + result["seq_len"] = seq_len + + # cu_seqlens for prefill. + cu_seqlens_view = bufs['cu_seqlens'] + cu_seqlens_view[0] = 0 + if real_prefill_count > 0: + cu_seqlens_view[1 : real_prefill_count + 1] = ( + cpu_cu_query[prefill_start_req + 1 : end_prefill_req + 1] + - cpu_cu_query[prefill_start_req] + ) + if real_prefill_count < padded_prefill_count: + last_val = cu_seqlens_view[real_prefill_count].item() + cu_seqlens_view[real_prefill_count + 1 : padded_prefill_count + 1] = last_val + + cu_seqlens_list = cu_seqlens_view[: real_prefill_count + 1].tolist() + real_prefill_tokens = ( + cu_seqlens_list[real_prefill_count] if real_prefill_count > 0 else 0 + ) + result["cu_seqlens_list"] = cu_seqlens_list + result["real_prefill_token_count"] = real_prefill_tokens + + # Chunk metadata (Python loop, pure CPU). 
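            # Each padded prefill request [start, end) is split into
            # ceil(seq_len / chunk_size) chunks (at least one). chunk_boundaries collects
            # the cumulative chunk end offsets and chunk_to_seq_list maps each chunk back
            # to its request index; these feed cu_chunk_seqlens and seq_idx_for_varlen below.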
+ cu_seqlens_all = cu_seqlens_view[: padded_prefill_count + 1].tolist() + chunk_boundaries = [0] + last_chunk_idx_list = [] + chunk_to_seq_list = [] + + for i in range(padded_prefill_count): + start = cu_seqlens_all[i] + end = cu_seqlens_all[i + 1] + s_len = end - start + n_chunks = max(1, (s_len + chunk_size - 1) // chunk_size) + boundaries = [min(start + (k + 1) * chunk_size, end) for k in range(n_chunks)] + chunk_boundaries.extend(boundaries) + chunk_to_seq_list.extend([i] * n_chunks) + last_chunk_idx_list.append(len(chunk_boundaries) - 2) + + padded_max_chunks = padded_token_count // chunk_size + padded_prefill_count + last_boundary = chunk_boundaries[-1] + pad_b = padded_max_chunks + 1 - len(chunk_boundaries) + if pad_b > 0: + chunk_boundaries.extend([last_boundary] * pad_b) + pad_s = padded_max_chunks - len(chunk_to_seq_list) + if pad_s > 0: + chunk_to_seq_list.extend([0] * pad_s) + + n_cu = padded_max_chunks + 1 + bufs['cu_chunk_seqlens'][:n_cu] = torch.tensor( + chunk_boundaries[:n_cu], dtype=torch.int32 + ) + bufs['last_chunk_indices'][:padded_prefill_count] = torch.tensor( + last_chunk_idx_list, dtype=torch.int32 + ) + bufs['seq_idx_for_varlen'][:padded_max_chunks] = torch.tensor( + chunk_to_seq_list[:padded_max_chunks], dtype=torch.int32 + ) + result["padded_max_chunks"] = padded_max_chunks + + # Conv1d per-token metadata (CPU repeat_interleave). + conv_seq_idx_view = bufs['conv_seq_idx'] + conv_seq_start_view = bufs['conv_seq_start'] + if real_prefill_tokens > 0: + cu_t = cu_seqlens_view[: real_prefill_count + 1] + lengths = (cu_t[1:] - cu_t[:-1]).to(torch.int64) + seq_indices = torch.arange(real_prefill_count, dtype=torch.int32) + seq_starts = cu_t[:real_prefill_count].to(torch.int32) + conv_seq_idx_view[:real_prefill_tokens] = torch.repeat_interleave( + seq_indices, lengths + ) + conv_seq_start_view[:real_prefill_tokens] = torch.repeat_interleave( + seq_starts, lengths + ) + if padded_token_count > real_prefill_tokens: + conv_seq_idx_view[real_prefill_tokens:padded_token_count] = 0 + conv_seq_start_view[real_prefill_tokens:padded_token_count] = 0 + + # Intermediate metadata still requires GPU data: defer to load_from_cpu. + result["intermediate_offsets_gpu"] = intermediate_offsets_gpu + result["intermediate_counts_gpu"] = intermediate_counts_gpu + + # device_decode_prefill scalars. + if padded_decode_count > 0 and padded_prefill_count > 0: + result["decode_prefill_0"] = cpu_cu_query[real_decode_count].item() + result["decode_prefill_1"] = ( + cpu_cu_query[real_decode_count + real_prefill_count].item() + - cpu_cu_query[real_decode_count].item() + ) + + return result + + def load_from_cpu(self, d: dict) -> None: + """Point state attributes at the freshly-transferred shared GPU views. + + No H2D copies happen here: the Mamba metadata fields were transferred + as part of the coalesced bookkeeping H2D. This method just slices the + bound GPU views to the per-step sizes and runs the intermediate + metadata computation (which reads from the now-valid GPU cu_seqlens). + + Args: + d: Dict returned by compute_cpu_metadata(). 
+ """ + assert self._gpu_view is not None, "bind_gpu_buffers() must be called first" + v = self._gpu_view + + padded_decode_count = d["padded_decode_count"] + padded_prefill_count = d["padded_prefill_count"] + padded_token_count = d["padded_token_count"] + real_prefill_count = d["real_prefill_count"] + + if padded_decode_count > 0: + self.batch_indices_decode = v.mamba_batch_indices_decode[:padded_decode_count] + + if padded_prefill_count > 0: + self.batch_indices_prefill = v.mamba_batch_indices_prefill[:padded_prefill_count] + self.seq_idx = v.mamba_seq_idx[:, :padded_token_count] + self.cu_seqlens = v.mamba_cu_seqlens[: padded_prefill_count + 1] + self.cu_seqlens_list = d["cu_seqlens_list"] + self.real_prefill_token_count = d["real_prefill_token_count"] + + padded_max_chunks = d["padded_max_chunks"] + self.cu_chunk_seqlens = v.mamba_cu_chunk_seqlens[: padded_max_chunks + 1] + self.last_chunk_indices = v.mamba_last_chunk_indices[:padded_prefill_count] + self.seq_idx_for_varlen = v.mamba_seq_idx_for_varlen[:padded_max_chunks] + self.conv_seq_idx = v.mamba_conv_seq_idx[:padded_token_count] + self.conv_seq_start = v.mamba_conv_seq_start[:padded_token_count] + + # Intermediate metadata reads from the just-transferred cu_seqlens + # to compute chunk indices & absolute positions for state extraction. + self._update_intermediate_metadata( + d["intermediate_offsets_gpu"], + d["intermediate_counts_gpu"], + real_prefill_count, + cu_seqlens_gpu=v.mamba_cu_seqlens, + ) + + if padded_decode_count > 0 and padded_prefill_count > 0: + self._device_decode_prefill_buffer[0] = d["decode_prefill_0"] + self._device_decode_prefill_buffer[1] = d["decode_prefill_1"] + self.device_decode_prefill = self._device_decode_prefill_buffer + def allocate_slot(self) -> Optional[int]: """ Allocates a new slot for a request in the Mamba state buffers. diff --git a/megatron/core/inference/contexts/attention_context/mha_metadata.py b/megatron/core/inference/contexts/attention_context/mha_metadata.py index 07f8a349b51..a71da895ea5 100644 --- a/megatron/core/inference/contexts/attention_context/mha_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mha_metadata.py @@ -1,215 +1,84 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import torch -from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions - from .metadata_base import MetadataBase class MHAMetadata(MetadataBase): """ Metadata for MHA layer using flash-attention. + + GPU storage for the per-step fields (``query_lengths``, + ``cu_query_seq_lengths``, ``kv_seq_lengths``, ``cu_kv_seq_lengths``, + ``block_table``) lives inside the context's :class:`ContextGPUView` + unified buffer. Both :class:`GraphedMHAMetadata` and + :class:`NonGraphedMHAMetadata` bind to the same GPU views (only one is + active per step), so the single coalesced H2D in + :meth:`DynamicInferenceContext.transfer_bookkeeping_to_gpu` covers the + MHA fields along with the rest of the bookkeeping state. 
""" def __init__( self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen ): super().__init__() - device = torch.cuda.current_device() - self.device = device + self.device = torch.cuda.current_device() self.max_blocks = block_count_total self.max_kv_blocks = max_kv_block_count self.max_bs = max_requests self.max_seqlen = max_seqlen - self._query_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device) - self._cu_query_seq_lengths_buf = torch.zeros( - self.max_bs + 1, dtype=torch.int32, device=device - ) - self._cu_kv_seq_lengths_buf = torch.zeros(self.max_bs + 1, dtype=torch.int32, device=device) - self._kv_seq_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device) - self._block_table_buf = torch.zeros( - (self.max_bs, self.max_kv_blocks), dtype=torch.int32, device=device - ) self._max_seqlen_q = 0 self._max_seqlen_k = 0 self.state_data = {} + # Set by bind_gpu_buffers(); references shared views in ContextGPUView._buf. + self._gpu_view = None - def update( - self, - request_query_lengths: torch.Tensor, - request_kv_length_offsets: torch.Tensor, - request_to_kv_block_ids: torch.Tensor, - batch_dimensions: InferenceBatchDimensions, - padded_batch_dimensions: InferenceBatchDimensions, - num_speculative_tokens: int = 0, - ): - """ - Args: - request_query_lengths: (>real_batch_size,) - request_kv_length_offsets: (>real_batch_size,) - request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) - batch_dimensions: Configuration object containing real batch settings - padded_batch_dimensions: Configuration object containing padded batch settings - num_speculative_tokens: Number of speculative tokens - """ - # Extract values from configs - real_batch_size = batch_dimensions.req_count - padded_active_token_count = padded_batch_dimensions.token_count - padded_active_request_count = padded_batch_dimensions.req_count - - assert real_batch_size <= padded_active_request_count <= self.max_bs - assert request_query_lengths.shape[0] == real_batch_size - assert request_kv_length_offsets.shape[0] == real_batch_size - assert request_to_kv_block_ids.shape[0] == real_batch_size + def bind_gpu_buffers(self, gpu_view) -> None: + """Attach shared GPU buffer views from the context's ContextGPUView. 
- self.tensor_copy_and_pad( - self._query_lengths_buf, - request_query_lengths, - real_batch_size, - padded_active_request_count, - ) - self._cu_query_seq_lengths_buf[0] = 0 - self.tensor_copy_and_pad( - self._cu_query_seq_lengths_buf[1:], - torch.cumsum(request_query_lengths, dim=0), - real_batch_size, - padded_active_request_count, - is_cumulative_tensor=True, - ) - self.tensor_copy_and_pad( - self._kv_seq_lengths_buf, - request_kv_length_offsets + request_query_lengths, - real_batch_size, - padded_active_request_count, - ) - self.tensor_copy_and_pad( - self._block_table_buf, - request_to_kv_block_ids, - real_batch_size, - padded_active_request_count, - pad_value=torch.tensor(self.max_kv_blocks, dtype=torch.int32, device=self.device).fill_( - -1 - ), - ) - self._cu_kv_seq_lengths_buf[0] = 0 - self.tensor_copy_and_pad( - self._cu_kv_seq_lengths_buf[1:], - torch.cumsum(self._kv_seq_lengths_buf, dim=0), - real_batch_size, - padded_active_request_count, - is_cumulative_tensor=True, - ) - - if padded_batch_dimensions.prefill_req_count == 0: - self._max_seqlen_q = num_speculative_tokens + 1 - else: - # Make sure we will launch the prefill kernel for prefill graphs - self._max_seqlen_q = max(2, padded_batch_dimensions.token_count) + Called by :class:`DynamicInferenceContext` after ``self.gpu_view`` is + constructed. Both graphed and non-graphed MHA metadata bind to the + same views; only one is active per step, so sharing storage is safe. + """ + self._gpu_view = gpu_view - self._max_seqlen_k = self.max_seqlen + def set_state_data( + self, padded_active_request_count: int, max_seqlen_q: int, max_seqlen_k: int + ) -> None: + """Build ``state_data`` slices into the bound GPU buffers. + Called once per step from ``transfer_bookkeeping_to_gpu`` after the + coalesced H2D copy. No ``.copy_()`` calls, no kernel launches. + """ + assert self._gpu_view is not None, "bind_gpu_buffers() must be called first" + n = padded_active_request_count + v = self._gpu_view + self._max_seqlen_q = max_seqlen_q + self._max_seqlen_k = max_seqlen_k self.state_data = { - "query_lengths": self._query_lengths_buf[:padded_active_request_count], - "cu_query_seq_lengths": self._cu_query_seq_lengths_buf[ - : padded_active_request_count + 1 - ], - "cu_kv_seq_lengths": self._cu_kv_seq_lengths_buf[: padded_active_request_count + 1], - "kv_seq_lengths": self._kv_seq_lengths_buf[:padded_active_request_count], - "block_table": self._block_table_buf[0:padded_active_request_count, :], - "max_seqlen_q": self._max_seqlen_q, - "max_seqlen_k": self._max_seqlen_k, + "query_lengths": v.mha_query_lengths[:n], + "cu_query_seq_lengths": v.mha_cu_query_seq_lengths[: n + 1], + "cu_kv_seq_lengths": v.mha_cu_kv_seq_lengths[: n + 1], + "kv_seq_lengths": v.mha_kv_seq_lengths[:n], + "block_table": v.mha_block_table[:n, :], + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, } def reset(self): + """Reset the metadata for the next batch. + + The GPU buffers live in the context's unified buffer and are fully + overwritten by the next H2D copy; clearing them here would launch + redundant CUDA kernels with no correctness benefit. """ - Reset the metadata for the next batch. - """ - self._query_lengths_buf.fill_(0) - self._cu_query_seq_lengths_buf.fill_(0) - self._cu_kv_seq_lengths_buf.fill_(0) - self._kv_seq_lengths_buf.fill_(0) - self._block_table_buf.fill_(0) self._max_seqlen_q = 0 self._max_seqlen_k = 0 class GraphedMHAMetadata(MHAMetadata): - """ - Metadata for MHA layer using flash-attention with CUDA graphs. 
- """ - - def __init__( - self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen - ): - super().__init__( - block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen - ) - - def update( - self, - request_query_lengths: torch.Tensor, - request_kv_length_offsets: torch.Tensor, - request_to_kv_block_ids: torch.Tensor, - batch_dimensions: InferenceBatchDimensions, - padded_batch_dimensions: InferenceBatchDimensions, - num_speculative_tokens: int = 0, - ): - """ - Args: - request_query_lengths: (>real_batch_size,) - request_kv_length_offsets: (>real_batch_size,) - request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) - batch_dimensions: Configuration object containing real batch settings - padded_batch_dimensions: Configuration object containing padded batch settings - num_speculative_tokens: Number of speculative tokens - """ - super().update( - request_query_lengths, - request_kv_length_offsets, - request_to_kv_block_ids, - batch_dimensions, - padded_batch_dimensions, - num_speculative_tokens, - ) - - def reset(self): - super().reset() + """MHA metadata for CUDA-graphed execution.""" class NonGraphedMHAMetadata(MHAMetadata): - """ - Metadata for MHA layer using flash-attention without CUDA graphs. - """ - - def update( - self, - request_query_lengths: torch.Tensor, - request_kv_length_offsets: torch.Tensor, - request_to_kv_block_ids: torch.Tensor, - batch_dimensions: InferenceBatchDimensions, - padded_batch_dimensions: InferenceBatchDimensions, - num_speculative_tokens: int = 0, - ): - """ - Args: - request_query_lengths: (>real_batch_size,) - request_kv_length_offsets: (>real_batch_size,) - request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) - batch_dimensions: Configuration object containing real batch settings - padded_batch_dimensions: Configuration object containing padded batch settings - num_speculative_tokens: Number of speculative tokens - """ - super().update( - request_query_lengths, - request_kv_length_offsets, - request_to_kv_block_ids, - batch_dimensions, - padded_batch_dimensions, - num_speculative_tokens, - ) - if len(self.state_data["query_lengths"]) > 0: - self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item() - self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item() - else: - self.state_data["max_seqlen_q"] = num_speculative_tokens + 1 - self.state_data["max_seqlen_k"] = 1 + """MHA metadata for non-graphed (eager) execution.""" diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index fe053522c62..d10dcbd833b 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -35,6 +35,10 @@ ) from megatron.core.package_info import __version__ as mcore_version from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from megatron.core.transformer.moe.token_dispatcher_inference import ( + NCCLAllGatherDispatcher, + NVLSAllGatherVDispatcher, +) from megatron.core.utils import deprecate_args from megatron.core.utils import divide as core_divide from megatron.core.utils import get_pg_size, internal_api @@ -42,6 +46,7 @@ from .attention_context.mamba_metadata import MambaMetadata from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata from .base_context import BaseInferenceContext +from .gpu_view import ContextGPUView from .kv_block_allocator import KVBlockAllocator from 
.mamba_slot_allocator import MambaSlotAllocator from .routing_metadata import RoutingMetadata @@ -323,6 +328,12 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC else: self.expert_model_parallel_group = None + # Optional CPU-side collective for EP batch-dimension sync. Populated by + # the engine via set_ep_zmq_communicator() when available. When set, + # match_graph_config() uses this to perform the MAX reduction on the + # CPU, avoiding a per-step NCCL AllReduce kernel on the compute stream. + self._ep_zmq_communicator = None + # Mamba states. mamba_inference_state_config = inference_config.mamba_inference_state_config self.is_hybrid_model = mamba_inference_state_config is not None @@ -589,10 +600,30 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ), "Router recording/replay requested but no MoE experts specified!" self.moe_routing_metadata = RoutingMetadata(self, model_config.moe_router_topk) - # CUDA graph config list + # are we using the inference_optimized nccl ep dispatcher for MoEs? + self._nccl_ep_dispatcher = ( + get_pg_size(self.expert_model_parallel_group) > 1 + and model_config.inference_moe_token_dispatcher_type == 'nccl' + ) + + # are we using the training a2a dispatcher for MoEs? + # Note that this is not optimal for speed. + self._training_ep_dispatcher = ( + get_pg_size(self.expert_model_parallel_group) > 1 + and model_config.transformer_impl == "transformer_engine" + ) + + # We only allow non-decode cuda graphs for the nvls dispatcher + force_disable_non_decode_cuda_graphs = ( + self._nccl_ep_dispatcher or self._training_ep_dispatcher + ) + self.use_cuda_graphs_for_non_decode_steps = ( inference_config.use_cuda_graphs_for_non_decode_steps + and not (force_disable_non_decode_cuda_graphs) ) + + # CUDA graph config list. self.cuda_graph_batch_dimensions_list, self.cuda_graph_token_counts = ( CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=tp_size, @@ -607,9 +638,21 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.smallest_non_decode_cuda_graph_size = min( - inference_config.cuda_graph_mixed_prefill_count, self.max_requests - ) + # Allocate per-step dispatcher buffers upfront so update_metadata never + # triggers an allocation inside a captured CUDA graph. + if get_pg_size(self.expert_model_parallel_group) > 1: + if self._nccl_ep_dispatcher: + NCCLAllGatherDispatcher.allocate_buffers() + else: + # Use moe_latent_size if set (latent MoE: SuperV3, UltraV3), else hidden_size. + moe_hidden_size = model_config.moe_latent_size or model_config.hidden_size + NVLSAllGatherVDispatcher.allocate_buffers( + per_rank_worst_case_token_count=self.round_up_tokens(self.max_tokens) + // tp_size, + topk=model_config.moe_router_topk, + hidden_size=moe_hidden_size, + ep_group=self.expert_model_parallel_group, + ) # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -620,6 +663,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC elif inference_config.use_flashinfer_fused_rope is None: inference_config.use_flashinfer_fused_rope = HAVE_FLASHINFER self.use_flashinfer_fused_rope = inference_config.use_flashinfer_fused_rope + self.inference_grouped_gemm_backend = model_config.inference_grouped_gemm_backend # Allocate GPU state. 
self.is_tensor_state_allocated = False @@ -739,8 +783,26 @@ def _allocate_mamba_states(self): self.mamba_metadata = MambaMetadata( max_requests=self.max_requests, max_tokens=self.max_tokens, + mamba_chunk_size=self.mamba_chunk_size, d_conv=self.mamba_conv_states_shape[-1], ) + # Bind the unified CPU/GPU buffers so the per-step Mamba metadata + # fields ride along with the single coalesced H2D in + # transfer_bookkeeping_to_gpu(). + self.mamba_metadata.bind_cpu_buffers( + { + "batch_indices_decode": self._cpu_mamba_batch_indices_decode, + "batch_indices_prefill": self._cpu_mamba_batch_indices_prefill, + "seq_idx": self._cpu_mamba_seq_idx, + "cu_seqlens": self._cpu_mamba_cu_seqlens, + "cu_chunk_seqlens": self._cpu_mamba_cu_chunk_seqlens, + "last_chunk_indices": self._cpu_mamba_last_chunk_indices, + "seq_idx_for_varlen": self._cpu_mamba_seq_idx_for_varlen, + "conv_seq_idx": self._cpu_mamba_conv_seq_idx, + "conv_seq_start": self._cpu_mamba_conv_seq_start, + } + ) + self.mamba_metadata.bind_gpu_buffers(self.gpu_view) self.mamba_conv_states = torch.empty( (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, dtype=self.mamba_conv_states_dtype, @@ -816,58 +878,319 @@ def initialize_all_tensors(self) -> None: f"Please move tensor '{key}'." ) - # Per-request state. + # Per-request state (CPU, pinned memory for fast H2D transfer). self.request_ids = torch.full( - (self.max_requests,), -1, dtype=torch.int32, device=torch.cuda.current_device() + (self.max_requests,), -1, dtype=torch.int32, device='cpu', pin_memory=True ) # request_query_lengths is the input prompt tokens length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) - self.request_query_lengths = torch.empty_like(self.request_ids) + self.request_query_lengths = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) # True only for a new request , then after a forward pass it is set to False - self.request_in_prefill_status_tensor = torch.empty_like(self.request_ids) + self.request_in_prefill_status_tensor = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) # request_output_lengths is len(input_prompt_tokens) + num_tokens_to_generate - self.request_output_lengths = torch.empty_like(self.request_ids) + self.request_output_lengths = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) # request_kv_length_offsets is the same as query length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) - self.request_kv_length_offsets = torch.empty_like(self.request_ids) - self.request_kv_block_counts = torch.empty_like(self.request_ids) - self.request_last_kv_block_id = torch.empty_like(self.request_ids) + self.request_kv_length_offsets = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) + self.request_kv_block_counts = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) + self.request_last_kv_block_id = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) # request_last_kv_block_offset represents number of tokens in the last kv block - self.request_last_kv_block_offset = torch.empty_like(self.request_ids) + self.request_last_kv_block_offset = torch.empty( + self.max_requests, dtype=torch.int32, device='cpu', pin_memory=True + ) self.request_to_kv_block_ids = torch.full( (self.max_requests, self.max_kv_block_count), -1, dtype=torch.int, - 
device=torch.cuda.current_device(), + device='cpu', + pin_memory=True, ) - # Track request metadata. + # Track request metadata. Backed by pinned CPU memory: bookkeeping is + # CPU-resident; GPU consumers read from the active-slice mirror in + # `active_request_metadata` (also CPU pinned, refreshed each step). self.request_metadata = { - label: torch.empty( - (self.max_requests,), dtype=dtype, device=torch.cuda.current_device() - ) - for label, dtype, _ in self.request_metadata_types + label: torch.empty((self.max_requests,), dtype=dtype, device='cpu', pin_memory=True) + for label, dtype in self.request_metadata_types } - # Per-token state. - self.token_to_input_ids = torch.full( - (self.max_tokens,), 0, dtype=torch.long, device=torch.cuda.current_device() - ) - self.token_to_pos_ids = torch.full_like(self.token_to_input_ids, 0) - self.token_to_request_idx = torch.empty_like(self.token_to_input_ids) - self.token_to_block_idx = torch.empty_like(self.token_to_input_ids) + # Static tensor addresses of active slices to enable fast inference + # kernels. Pinned CPU mirrors of `request_metadata`, refreshed each + # step by `build_active_slices()` from the active subrange. + self.active_request_metadata = { + label: torch.empty_like(tensor, pin_memory=True) + for label, tensor in self.request_metadata.items() + } + + # Coalesced pinned CPU buffer for the bookkeeping fields that get + # transferred to GPU each step via transfer_bookkeeping_to_gpu(). + # Layout matches ContextGPUView._buf so a single cudaMemcpyAsync + # suffices. Int64 token fields come first (8-byte aligned automatically), + # then int32 token fields, then int32/float32 request-staging fields. + # token_to_input_ids (int64, max_tokens) + # token_to_pos_ids (int64, max_tokens) + # token_to_block_idx (int32, max_tokens) + # token_to_local_position_within_kv_block (int32, max_tokens) + # token_to_request_idx (int32, max_tokens) + # token_to_position_in_request (int32, max_tokens) + # request_in_prefill_status (staging) (int32, max_requests) + # request_query_lengths (staging) (int32, max_requests) + # request_kv_length_offsets (staging) (int32, max_requests) + # temperature (staging) (float32, max_requests) + # top_k (staging) (int32, max_requests) + # top_p (staging) (float32, max_requests) + # active_request_last_token_idxs (alias) (int32, max_requests) + # + # Token fields are aliased with the source-of-truth attributes + # (`self.token_to_input_ids`, etc.) because the forward pass reads + # `gpu_view.token_to_input_ids[:n_tok]` which matches the CPU slot + # layout `[0, n_tok)`. Request fields, however, are read on GPU at + # `[:n_active]` but on CPU at `[paused_count:total_count)` — so the + # staging slots here are refreshed each step by copying the active + # slice from the persistent `request_*` tensors above. + _tok_int64_bytes = self.max_tokens * 8 + _tok_int32_bytes = self.max_tokens * 4 + # Request-level fields are all 4 bytes wide (5 int32 + 2 float32 = 7 fields). + _req_4byte_bytes = self.max_requests * 4 + # MHA section: 5 fields (int32) shared between GraphedMHAMetadata and + # NonGraphedMHAMetadata. max_bs == max_requests. + _mha_query_lengths_bytes = self.max_requests * 4 + _mha_cu_query_seq_lengths_bytes = (self.max_requests + 1) * 4 + _mha_kv_seq_lengths_bytes = self.max_requests * 4 + _mha_cu_kv_seq_lengths_bytes = (self.max_requests + 1) * 4 + _mha_block_table_bytes = self.max_requests * self.max_kv_block_count * 4 + # Mamba section: 9 int32 fields (hybrid models only). 
Must match the + # MambaMetadata shapes (mirrors the layout documented in ContextGPUView). + if self.is_hybrid_model: + self._max_mamba_chunks = self.max_tokens // self.mamba_chunk_size + self.max_requests + _mamba_batch_indices_decode_bytes = self.max_requests * 4 + _mamba_batch_indices_prefill_bytes = self.max_requests * 4 + _mamba_seq_idx_bytes = self.max_tokens * 4 + _mamba_cu_seqlens_bytes = (self.max_requests + 1) * 4 + _mamba_cu_chunk_seqlens_bytes = (self._max_mamba_chunks + 1) * 4 + _mamba_last_chunk_indices_bytes = self.max_requests * 4 + _mamba_seq_idx_for_varlen_bytes = self._max_mamba_chunks * 4 + _mamba_conv_seq_idx_bytes = self.max_tokens * 4 + _mamba_conv_seq_start_bytes = self.max_tokens * 4 + else: + self._max_mamba_chunks = 0 + _mamba_batch_indices_decode_bytes = 0 + _mamba_batch_indices_prefill_bytes = 0 + _mamba_seq_idx_bytes = 0 + _mamba_cu_seqlens_bytes = 0 + _mamba_cu_chunk_seqlens_bytes = 0 + _mamba_last_chunk_indices_bytes = 0 + _mamba_seq_idx_for_varlen_bytes = 0 + _mamba_conv_seq_idx_bytes = 0 + _mamba_conv_seq_start_bytes = 0 + _total_bytes = ( + 2 * _tok_int64_bytes + + 4 * _tok_int32_bytes + + 7 * _req_4byte_bytes + + _mha_query_lengths_bytes + + _mha_cu_query_seq_lengths_bytes + + _mha_kv_seq_lengths_bytes + + _mha_cu_kv_seq_lengths_bytes + + _mha_block_table_bytes + + _mamba_batch_indices_decode_bytes + + _mamba_batch_indices_prefill_bytes + + _mamba_seq_idx_bytes + + _mamba_cu_seqlens_bytes + + _mamba_cu_chunk_seqlens_bytes + + _mamba_last_chunk_indices_bytes + + _mamba_seq_idx_for_varlen_bytes + + _mamba_conv_seq_idx_bytes + + _mamba_conv_seq_start_bytes + ) + self._cpu_bookkeeping_buf = torch.empty( + _total_bytes, dtype=torch.uint8, device='cpu', pin_memory=True + ) + # token_to_input_ids and token_to_pos_ids were previously torch.full(0); + # zero the whole buffer so their views start at 0 too, and so the + # request staging slots start with a deterministic value. + self._cpu_bookkeeping_buf.fill_(0) + + _off = 0 + # Per-token state (source-of-truth lives in the coalesced buffer since + # the CPU-side bookkeeping and the GPU forward pass use the same + # `[:n_tok]` slice). + self.token_to_input_ids = self._cpu_bookkeeping_buf[_off : _off + _tok_int64_bytes].view( + torch.long + ) + _off += _tok_int64_bytes + self.token_to_pos_ids = self._cpu_bookkeeping_buf[_off : _off + _tok_int64_bytes].view( + torch.long + ) + _off += _tok_int64_bytes + self.token_to_block_idx = self._cpu_bookkeeping_buf[_off : _off + _tok_int32_bytes].view( + torch.int32 + ) + _off += _tok_int32_bytes # i.e For a set of tokens A B C D E F .. and block_size 4: # token_to_position_in_request is [0, 1, 2, 3, 4, 5] # token_to_local_position_within_kv_block is [0 , 1, 2, 3, 0, 1, 2] - self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) - self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) - - # NOTE: Need to build this outside the UVM / TMS context to avoid IMA. + self.token_to_local_position_within_kv_block = self._cpu_bookkeeping_buf[ + _off : _off + _tok_int32_bytes + ].view(torch.int32) + _off += _tok_int32_bytes + self.token_to_request_idx = self._cpu_bookkeeping_buf[_off : _off + _tok_int32_bytes].view( + torch.int32 + ) + _off += _tok_int32_bytes + self.token_to_position_in_request = self._cpu_bookkeeping_buf[ + _off : _off + _tok_int32_bytes + ].view(torch.int32) + _off += _tok_int32_bytes + + # Request-level staging views into the coalesced buffer. 
Write-only on + # CPU (refreshed from persistent tensors in transfer_bookkeeping_to_gpu); + # read-only on GPU via matching slots in ContextGPUView._buf. + self._staging_request_in_prefill_status = self._cpu_bookkeeping_buf[ + _off : _off + _req_4byte_bytes + ].view(torch.int32) + _off += _req_4byte_bytes + self._staging_request_query_lengths = self._cpu_bookkeeping_buf[ + _off : _off + _req_4byte_bytes + ].view(torch.int32) + _off += _req_4byte_bytes + self._staging_request_kv_length_offsets = self._cpu_bookkeeping_buf[ + _off : _off + _req_4byte_bytes + ].view(torch.int32) + _off += _req_4byte_bytes + + # Sampling-parameter staging slots, refreshed from `active_request_metadata` + # in transfer_bookkeeping_to_gpu(). FlashInfer reads these via + # `gpu_view.{temperature, top_k, top_p}`. + self._staging_temperature = self._cpu_bookkeeping_buf[_off : _off + _req_4byte_bytes].view( + torch.float32 + ) + _off += _req_4byte_bytes + self._staging_top_k = self._cpu_bookkeeping_buf[_off : _off + _req_4byte_bytes].view( + torch.int32 + ) + _off += _req_4byte_bytes + self._staging_top_p = self._cpu_bookkeeping_buf[_off : _off + _req_4byte_bytes].view( + torch.float32 + ) + _off += _req_4byte_bytes + + # Per-request last-token row indices. Aliased with the matching gpu_view slot: + # build_active_slices/pad_active_slices populate this CPU view. + self.active_request_last_token_idxs = self._cpu_bookkeeping_buf[ + _off : _off + _req_4byte_bytes + ].view(torch.int32) + _off += _req_4byte_bytes + + # Static tensor addresses to make `last_token_logits` graphable with speculative decoding. + max_logit_idxs = self.max_requests * (self.num_speculative_tokens + 1) + self.active_logit_idxs = torch.zeros( + max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device() + ) + self._decode_logit_idxs = torch.arange( + max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device() + ) + + # MHA flash-attention metadata views (write-only on CPU, read-only on + # GPU via the matching region of ContextGPUView._buf). Populated per + # step by initialize_attention_state(); transferred as part of the + # single coalesced H2D in transfer_bookkeeping_to_gpu(). + self._cpu_mha_query_lengths = self._cpu_bookkeeping_buf[ + _off : _off + _mha_query_lengths_bytes + ].view(torch.int32) + _off += _mha_query_lengths_bytes + self._cpu_mha_cu_query_seq_lengths = self._cpu_bookkeeping_buf[ + _off : _off + _mha_cu_query_seq_lengths_bytes + ].view(torch.int32) + _off += _mha_cu_query_seq_lengths_bytes + self._cpu_mha_kv_seq_lengths = self._cpu_bookkeeping_buf[ + _off : _off + _mha_kv_seq_lengths_bytes + ].view(torch.int32) + _off += _mha_kv_seq_lengths_bytes + self._cpu_mha_cu_kv_seq_lengths = self._cpu_bookkeeping_buf[ + _off : _off + _mha_cu_kv_seq_lengths_bytes + ].view(torch.int32) + _off += _mha_cu_kv_seq_lengths_bytes + self._cpu_mha_block_table = ( + self._cpu_bookkeeping_buf[_off : _off + _mha_block_table_bytes] + .view(torch.int32) + .view(self.max_requests, self.max_kv_block_count) + ) + _off += _mha_block_table_bytes + + # Mamba varlen metadata views (hybrid models only). Populated per step + # by MambaMetadata.compute_cpu_metadata(); transferred as part of the + # single coalesced H2D in transfer_bookkeeping_to_gpu(). 
if self.is_hybrid_model: - self.mamba_metadata = MambaMetadata( - max_requests=self.max_requests, - max_tokens=self.max_tokens, - mamba_chunk_size=self.mamba_chunk_size, - d_conv=self.mamba_conv_states_shape[-1], + self._cpu_mamba_batch_indices_decode = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_batch_indices_decode_bytes + ].view(torch.int32) + _off += _mamba_batch_indices_decode_bytes + self._cpu_mamba_batch_indices_prefill = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_batch_indices_prefill_bytes + ].view(torch.int32) + _off += _mamba_batch_indices_prefill_bytes + self._cpu_mamba_seq_idx = ( + self._cpu_bookkeeping_buf[_off : _off + _mamba_seq_idx_bytes] + .view(torch.int32) + .view(1, self.max_tokens) ) + _off += _mamba_seq_idx_bytes + self._cpu_mamba_cu_seqlens = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_cu_seqlens_bytes + ].view(torch.int32) + _off += _mamba_cu_seqlens_bytes + self._cpu_mamba_cu_chunk_seqlens = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_cu_chunk_seqlens_bytes + ].view(torch.int32) + _off += _mamba_cu_chunk_seqlens_bytes + self._cpu_mamba_last_chunk_indices = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_last_chunk_indices_bytes + ].view(torch.int32) + _off += _mamba_last_chunk_indices_bytes + self._cpu_mamba_seq_idx_for_varlen = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_seq_idx_for_varlen_bytes + ].view(torch.int32) + _off += _mamba_seq_idx_for_varlen_bytes + self._cpu_mamba_conv_seq_idx = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_conv_seq_idx_bytes + ].view(torch.int32) + _off += _mamba_conv_seq_idx_bytes + self._cpu_mamba_conv_seq_start = self._cpu_bookkeeping_buf[ + _off : _off + _mamba_conv_seq_start_bytes + ].view(torch.int32) + _off += _mamba_conv_seq_start_bytes + + assert _off == _total_bytes, f"layout bug: wrote {_off} of {_total_bytes} bytes" + + # GPU view: the single interface for GPU code to read context state. + # Populated per-step by transfer_bookkeeping_to_gpu(). + self.gpu_view = ContextGPUView( + max_requests=self.max_requests, + max_tokens=self.max_tokens, + max_kv_blocks=self.max_kv_block_count, + device=torch.cuda.current_device(), + max_mamba_chunks=self._max_mamba_chunks, + ) + + # Bind the shared MHA GPU views to both graph and non-graph metadata; + # only one is active per step, so sharing storage is safe. + self.graph_attn_metadata["mha_metadata"].bind_gpu_buffers(self.gpu_view) + self.non_graph_attn_metadata["mha_metadata"].bind_gpu_buffers(self.gpu_view) + + # Deferred Mamba GPU operations. Populated by add_request() / + # update_requests() (CPU phase), executed by transfer_bookkeeping_to_gpu(). + self._pending_mamba_zeros: list = [] + self._pending_mamba_restores: list = [] # Allocate large non-graphed buffers. need_static_addr = ( @@ -1062,6 +1385,61 @@ def get_active_request_count(self): """Returns the current number of active requests.""" return self.total_request_count - self.paused_request_count + def build_active_slices(self, batch_size: int): + """Build the active slices of specific tensors. This is run on every forward step. + + If the context is reordered to active -> paused -> finished, this can be graphed. + """ + padded_slice = slice(self.paused_request_count, self.paused_request_count + batch_size) + + # Request metadata all needs to be sliced. 
+ for label in self.request_metadata: + self.active_request_metadata[label][:batch_size].copy_( + self.request_metadata[label][padded_slice], non_blocking=True + ) + + torch.cumsum( + self.request_query_lengths[padded_slice], + dim=0, + out=self.active_request_last_token_idxs[:batch_size], + ) + self.active_request_last_token_idxs[:batch_size].sub_(1) + + def pad_active_slices(self): + """Pad the active slices of specific tensors.""" + active_request_count = self.total_request_count - self.paused_request_count + active_decode_count = self.num_decode_requests + active_prefill_count = active_request_count - active_decode_count + active_decode_token_count = active_decode_count * (self.num_speculative_tokens + 1) + + # Decode prefix: positions [0, 1, ..., active_decode_token_count - 1]. + self.active_logit_idxs[:active_decode_token_count].copy_( + self._decode_logit_idxs[:active_decode_token_count] + ) + + # Prefill last-token positions: cumsum the prefill query lengths in place, + # then shift by (active_decode_token_count - 1) to get absolute positions. + prefill_dst = self.active_logit_idxs[ + active_decode_token_count : active_decode_token_count + active_prefill_count + ] + prefill_idxs = self.paused_request_count + active_decode_count + prefill_lengths = self.request_query_lengths[prefill_idxs : self.total_request_count] + if active_prefill_count > 0: + prefill_cumsum = torch.cumsum(prefill_lengths, dim=0, dtype=torch.int32) + prefill_cumsum.add_(active_decode_token_count - 1) + prefill_dst.copy_(prefill_cumsum, non_blocking=True) + + self.active_logit_idxs[active_decode_token_count + active_prefill_count :].zero_() + + padding_request_slice = slice(active_request_count, self.padded_active_request_count) + + # Sampling metadata: pad with neutral defaults, so that the kernel early-exits. + self.active_request_metadata["temperature"][padding_request_slice].fill_(1.0) + self.active_request_metadata["top_k"][padding_request_slice].fill_(0) + self.active_request_metadata["top_p"][padding_request_slice].fill_(0.0) + # Padded gather indices fan in to row 0 harmlessly when used by FlashInfer. + self.active_request_last_token_idxs[padding_request_slice].fill_(0) + def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None: """Append to KV cache. 
@@ -1080,12 +1458,12 @@ def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) value=value, memory_buffer=self.memory_buffer, padded_active_token_count=self.padded_active_token_count, - token_to_block_idx=self.token_to_block_idx, - token_to_local_position_within_kv_block=self.token_to_local_position_within_kv_block, + token_to_block_idx=self.gpu_view.token_to_block_idx, + token_to_local_position_within_kv_block=self.gpu_view.token_to_local_position_within_kv_block, ) - block_idx = self.token_to_block_idx[: self.padded_active_token_count] - local_kv_seq_idx = self.token_to_local_position_within_kv_block[ + block_idx = self.gpu_view.token_to_block_idx[: self.padded_active_token_count] + local_kv_seq_idx = self.gpu_view.token_to_local_position_within_kv_block[ : self.padded_active_token_count ] @@ -1221,7 +1599,7 @@ def apply_fused_qk_rotary_emb( # use .view instead of .reshape to avoid extra transpose operations query_rope, key_rope = flashinfer.rope.apply_rope_with_cos_sin_cache( - positions=self.token_to_pos_ids[:n], + positions=self.gpu_view.token_to_pos_ids[:n], query=query[:n].reshape(n, num_q_heads * head_size), key=key[:n].reshape(n, num_k_heads * head_size), head_size=head_size, @@ -1254,7 +1632,7 @@ def apply_rotary_emb_query( (Tensor) Query tensor after applying rotary embeddings. """ n = self.padded_active_token_count - query_seq_idx = self.token_to_pos_ids[:n] + query_seq_idx = self.gpu_view.token_to_pos_ids[:n] query_emb = query_emb[query_seq_idx] query[:n] = apply_rotary_pos_emb( t=query[:n], @@ -1287,7 +1665,7 @@ def apply_rotary_emb_key( (Tensor) Key tensor after applying rotary embeddings. """ n = self.padded_active_token_count - key_seq_idx = self.token_to_position_in_request[:n] + key_seq_idx = self.gpu_view.token_to_position_in_request[:n] key_emb = key_emb[key_seq_idx] if self.is_decode_only(): if key.shape[0] != n: @@ -1316,6 +1694,20 @@ def apply_rotary_emb_key( ) return key + def set_ep_zmq_communicator(self, communicator) -> None: + """Attach an EP-group ZMQ communicator for CPU-side sync collectives. + + When set, match_graph_config() uses this communicator's + sync_all_reduce_max() to perform the EP batch-dimension MAX reduction on + the CPU instead of launching a NCCL AllReduce kernel on the compute + stream. Expected to be called once by the inference engine after both + the context and the communicator have been created. + + Args: + communicator: AsyncZMQCommunicator over the EP process group. 
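+
+        Example (illustrative only; the engine-side variable name is hypothetical):
+            context.set_ep_zmq_communicator(ep_zmq_communicator)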
+ """ + self._ep_zmq_communicator = communicator + def reset_attention_state(self) -> None: """Reset state used within attention, after each step.""" # Attention metadata reset is now handled by MHAMetadata.reset() @@ -1402,9 +1794,9 @@ def add_dummy_requests_parallel( self.request_output_lengths[request_slice] = lengths_tensor + tokens_to_generate_tensor self.request_kv_length_offsets[request_slice] = 0 self.request_kv_block_counts[request_slice] = block_counts - for i, (label, dtype, _) in enumerate(self.request_metadata_types): + for i, (label, dtype) in enumerate(self.request_metadata_types): self.request_metadata[label][request_slice] = torch.tensor( - metadata_cols[i], dtype=dtype, device=torch.cuda.current_device() + metadata_cols[i], dtype=dtype, device='cpu' ) dummy_block_idx = self.kv_block_allocator.dummy_block_idx @@ -1463,8 +1855,7 @@ def add_dummy_requests_parallel( raise ContextOverflowError( requests[logical_idx].request_id, "No Mamba slots available" ) - self.mamba_conv_states[:, mamba_idx] = 0.0 - self.mamba_ssm_states[:, mamba_idx] = 0.0 + self._pending_mamba_zeros.append(mamba_idx) self.mamba_metadata.request_to_mamba_state_idx[request_idx] = mamba_idx self.active_token_count = token_end @@ -1486,7 +1877,7 @@ def add_dummy_requests_for_cudagraph_capture( # Pre-construct shared objects (safe due to deep copy in DynamicInferenceRequest.__post_init__) shared_sampling_params = SamplingParams(num_tokens_to_generate=1, termination_id=-1) shared_decode_tokens = torch.zeros( - self.num_speculative_tokens + 1, dtype=torch.long, device=torch.cuda.current_device() + self.num_speculative_tokens + 1, dtype=torch.long, device='cpu' ) decode_requests = [ @@ -1516,9 +1907,7 @@ def add_dummy_requests_for_cudagraph_capture( assert per_prefill_tokens > 0 # Create a single large tensor and slice from it for each prefill request max_prefill_tokens = per_prefill_tokens + (1 if rem_prefill_tokens > 0 else 0) - shared_prefill_tokens = torch.zeros( - max_prefill_tokens, dtype=torch.long, device=torch.cuda.current_device() - ) + shared_prefill_tokens = torch.zeros(max_prefill_tokens, dtype=torch.long, device='cpu') prefill_requests = [ DynamicInferenceRequest( @@ -1564,7 +1953,7 @@ def add_dummy_requests_for_expert_parallel_step( self.active_token_count = T self.num_prefill_requests = N_prefill - # 2. Per-request state consumed by mha_metadata.update(). + # 2. Per-request state consumed by initialize_attention_state(). # Decode requests come first, followed by prefill requests. self.request_query_lengths[0:N_decode].fill_(tokens_per_decode_request) if N_prefill > 0: @@ -1623,6 +2012,10 @@ def initialize_attention_state( Return: None. """ + # Launch deferred Mamba GPU ops first (state zeroing/restore) so they + # overlap with the CPU work below. These are non-blocking GPU kernels. + self._execute_pending_mamba_ops() + self.is_creating_cuda_graphs = construct_graph_dimensions is not None assert not ( self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step @@ -1632,53 +2025,36 @@ def initialize_attention_state( # EP dummy requests are added AFTER the EP sync below. if self.is_creating_cuda_graphs: self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - - if is_expert_parallel_dummy_cuda_graph_step: - # No real requests on this EP rank. Pass empty dimensions so the EP - # all-reduce in match_graph_config picks up the real ranks' values. 
- batch_dimensions = InferenceBatchDimensions( - token_count=0, prefill_req_count=0, decode_req_count=0 - ) - else: - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, + elif is_expert_parallel_dummy_cuda_graph_step: + self.add_dummy_requests_for_expert_parallel_step( + InferenceBatchDimensions( + token_count=self.num_speculative_tokens + 1, + prefill_req_count=0, + decode_req_count=1, + ) ) + batch_dimensions = InferenceBatchDimensions( + token_count=self.active_token_count, + prefill_req_count=self.num_prefill_requests, + decode_req_count=self.num_decode_requests, + ) + self.batch_dimensions = batch_dimensions best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, - smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size, strict=self.is_hybrid_model, - decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), ep_group=self.expert_model_parallel_group, - num_speculative_tokens=self.num_speculative_tokens, + match_ep_token_counts=self._nccl_ep_dispatcher or self._training_ep_dispatcher, + ep_zmq_communicator=self._ep_zmq_communicator, ) self._using_cuda_graph_this_step = best_graph is not None if construct_graph_dimensions is not None: assert self._using_cuda_graph_this_step - if is_expert_parallel_dummy_cuda_graph_step and not self.using_cuda_graph_this_step(): - # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config - # could not find a compatible cuda graph for the dummy forward step. - # Now, we need not do the remaining setup. The controller - # will directly call the model forward pass with a single token. - return - - # Add dummy requests AFTER the EP sync so they match the resolved graph. - if is_expert_parallel_dummy_cuda_graph_step: - self.add_dummy_requests_for_expert_parallel_step(best_graph) - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, - ) - self.batch_dimensions = batch_dimensions - if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph else: @@ -1713,6 +2089,11 @@ def initialize_attention_state( self.padded_active_request_count = self.padded_batch_dimensions.req_count self.padding_slice = slice(self.active_token_count, self.padded_active_token_count) + self.build_active_slices( + min(self.padded_active_request_count, self.max_requests - self.paused_request_count) + ) + self.pad_active_slices() + # Update token position indexes. self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = ( self.kv_block_allocator.dummy_block_idx @@ -1750,31 +2131,98 @@ def initialize_attention_state( ) assert self.active_attn_metadata is not None - self.active_attn_metadata["mha_metadata"].update( - request_query_lengths=query_lengths_view, - request_kv_length_offsets=request_kv_length_offsets_view, - request_to_kv_block_ids=request_to_kv_block_ids_view, - batch_dimensions=attn_dimensions, - padded_batch_dimensions=self.padded_batch_dimensions, - num_speculative_tokens=self.num_speculative_tokens, + + # Compute MHA metadata directly into the pinned CPU section of + # _cpu_bookkeeping_buf. 
The single coalesced H2D in + # transfer_bookkeeping_to_gpu() covers these fields along with the rest + # of the bookkeeping state, so no ephemeral tensors and no per-field + # cudaMemcpyAsyncs. + real_bs = attn_dimensions.req_count + padded_bs = self.padded_batch_dimensions.req_count + mha = self.active_attn_metadata["mha_metadata"] + + # Query lengths: [0:real_bs] real data, [real_bs:padded_bs] zero pad. + self._cpu_mha_query_lengths[:real_bs] = query_lengths_view[:real_bs] + if real_bs < padded_bs: + self._cpu_mha_query_lengths[real_bs:padded_bs] = 0 + + # Cumulative query lengths (padded slots repeat cu[real_bs]). + self._cpu_mha_cu_query_seq_lengths[0] = 0 + if real_bs > 0: + self._cpu_mha_cu_query_seq_lengths[1 : real_bs + 1] = torch.cumsum( + query_lengths_view[:real_bs], dim=0 + ) + if real_bs < padded_bs: + self._cpu_mha_cu_query_seq_lengths[real_bs + 1 : padded_bs + 1] = ( + self._cpu_mha_cu_query_seq_lengths[real_bs] + ) + + # KV sequence lengths: [0:real_bs] = kv_offsets + query_lengths. + self._cpu_mha_kv_seq_lengths[:real_bs] = ( + request_kv_length_offsets_view[:real_bs] + query_lengths_view[:real_bs] + ) + if real_bs < padded_bs: + self._cpu_mha_kv_seq_lengths[real_bs:padded_bs] = 0 + + # Cumulative KV lengths. + self._cpu_mha_cu_kv_seq_lengths[0] = 0 + if real_bs > 0: + self._cpu_mha_cu_kv_seq_lengths[1 : real_bs + 1] = torch.cumsum( + self._cpu_mha_kv_seq_lengths[:real_bs], dim=0 + ) + if real_bs < padded_bs: + self._cpu_mha_cu_kv_seq_lengths[real_bs + 1 : padded_bs + 1] = ( + self._cpu_mha_cu_kv_seq_lengths[real_bs] + ) + + # Block table: [0:real_bs] real, [real_bs:padded_bs] = -1 sentinel. + self._cpu_mha_block_table[:real_bs] = request_to_kv_block_ids_view[:real_bs] + if real_bs < padded_bs: + self._cpu_mha_block_table[real_bs:padded_bs] = -1 + + # Max sequence lengths (Python scalars; consumed as kernel launch args). + if not self.using_cuda_graph_this_step() and real_bs > 0: + # NonGraphedMHAMetadata: use actual max values. + max_seqlen_q = self._cpu_mha_query_lengths[:real_bs].max().item() + max_seqlen_k = self._cpu_mha_kv_seq_lengths[:real_bs].max().item() + else: + # GraphedMHAMetadata: use conservative bounds. + if self.padded_batch_dimensions.prefill_req_count == 0: + max_seqlen_q = self.num_speculative_tokens + 1 + else: + max_seqlen_q = max(2, self.padded_batch_dimensions.token_count) + max_seqlen_k = mha.max_seqlen + if not self.using_cuda_graph_this_step() and real_bs == 0: + max_seqlen_q = self.num_speculative_tokens + 1 + max_seqlen_k = 1 + + # Bind state_data to GPU views now. set_state_data() only creates Python + # slice references into the GPU buffer (no GPU reads), so it's safe to + # call before the H2D in transfer_bookkeeping_to_gpu(). This guarantees + # that callers reading state_data["block_table"] etc. between + # initialize_attention_state() and transfer_bookkeeping_to_gpu() see + # populated entries (the actual data fill happens at the H2D). + mha.set_state_data( + padded_active_request_count=padded_bs, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, ) if self.is_hybrid_model: - active_mamba_indices_view = self.mamba_metadata.request_to_mamba_state_idx[active_slice] - token_to_request_idx_view = self.token_to_request_idx[: self.active_token_count] - cu_seqlens = self.active_attn_metadata["mha_metadata"].state_data[ - "cu_query_seq_lengths" - ] + # Mamba metadata update is deferred to transfer_bookkeeping_to_gpu() + # because it writes to GPU buffers. Store the parameters here. 
+ # intermediate_offsets_gpu / intermediate_counts_gpu get the CPU-side + # slices here; H2D transfer happens in transfer_bookkeeping_to_gpu(). intermediate_offsets_gpu = None intermediate_counts_gpu = None if self.mamba_slot_allocator is not None: intermediate_offsets_gpu, intermediate_counts_gpu = ( - self.mamba_slot_allocator.get_intermediate_gpu_data() + self.mamba_slot_allocator.get_intermediate_cpu_data() ) - self.mamba_metadata.update( - active_mamba_indices_view, - token_to_request_idx_view, - cu_seqlens, + self._pending_mamba_transfer = self.mamba_metadata.compute_cpu_metadata( + active_mamba_indices=self.mamba_metadata.request_to_mamba_state_idx[active_slice], + token_to_request_idx=self.token_to_request_idx[: self.active_token_count], + cpu_cu_query=self._cpu_mha_cu_query_seq_lengths, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, enable_chunked_prefill=self.is_chunked_prefill_enabled(), @@ -1788,8 +2236,111 @@ def initialize_attention_state( else: self.moe_routing_metadata.disable_static_buffer_recording() + # Flip NCCLAllGather dispatcher's path selector to not use allgathers. + # _nccl_ep_dispatcher already implies ep_size > 1, so no extra EP guard. + if self._nccl_ep_dispatcher: + NCCLAllGatherDispatcher._use_allgather_v = not self.using_cuda_graph_this_step() + + # Flush any Mamba ops queued by add_dummy_requests_for_cudagraph_capture + # (warmup) or add_dummy_requests_for_expert_parallel_step (EP dummy step). + # The earlier call at the top drained ops queued by add_request() before + # this function ran; this call covers ops queued during the function. + # No-op when the queue is already empty (regular non-warmup steps). + self._execute_pending_mamba_ops() + + # Run the H2D transfer here so callers that bypass the controller + # (e.g. unit tests that call `model.forward()` directly after + # `initialize_attention_state()`) see populated GPU bookkeeping. The + # text-generation controller still calls `transfer_bookkeeping_to_gpu` + # explicitly; that second call is a cheap idempotent re-copy. + self.transfer_bookkeeping_to_gpu() + + def _execute_pending_mamba_ops(self) -> None: + """Execute Mamba GPU operations deferred from add_request() / update_requests(). + + This runs at the start of initialize_attention_state() so that all GPU + Mamba state is correct before the forward pass. + """ + if not (self._pending_mamba_restores or self._pending_mamba_zeros): + return + + # Restore cached Mamba state to live buffers. On failure, fall back to zeroing. + for request_idx, block_id, mamba_idx in self._pending_mamba_restores: + restored = self.mamba_slot_allocator.restore_to_live(request_idx, block_id) + if not restored: + self._pending_mamba_zeros.append(mamba_idx) + self._pending_mamba_restores.clear() + + # Batch-zero newly allocated Mamba slots. + if self._pending_mamba_zeros: + device = self.mamba_conv_states.device + indices = torch.tensor(self._pending_mamba_zeros, dtype=torch.long, device=device) + self.mamba_conv_states[:, indices] = 0.0 + self.mamba_ssm_states[:, indices] = 0.0 + self._pending_mamba_zeros.clear() + + def transfer_bookkeeping_to_gpu(self) -> None: + """Batch transfer CPU bookkeeping state to GPU staging buffers. + + Called after initialize_attention_state() and before the forward pass. + All copies use non_blocking=True with pinned CPU memory. CUDA stream + ordering guarantees the forward pass sees completed transfers. 
+ + The bookkeeping fields are backed by one contiguous pinned CPU buffer + and one contiguous GPU buffer; a single cudaMemcpyAsync suffices. + Request-level staging slots are refreshed from the persistent CPU + tensors immediately before the H2D (GPU reads them at `[:n_active]` + while CPU bookkeeping keeps them at `[paused_count:total_count)`). + """ + n_active = self.total_request_count - self.paused_request_count + active_slice = slice(self.paused_request_count, self.total_request_count) + padded_active = max(n_active, self.padded_active_request_count) + + # Refresh request-level staging slots from the persistent CPU source. + # CPU-to-CPU slice assignment on pinned memory (~15 KB total for 6 + # 4-byte fields at max_requests=624). Negligible vs. the launch overhead + # we save by merging the H2D memcpys into 1. + self._staging_request_in_prefill_status[:n_active] = self.request_in_prefill_status_tensor[ + active_slice + ] + self._staging_request_query_lengths[:n_active] = self.request_query_lengths[active_slice] + self._staging_request_kv_length_offsets[:n_active] = self.request_kv_length_offsets[ + active_slice + ] + # Sampling-parameter staging slots: read from `active_request_metadata`, + # which `build_active_slices` + `pad_active_slices` already populated for + # `[:padded_active]` (active values + neutral padding defaults). + self._staging_temperature[:padded_active] = self.active_request_metadata["temperature"][ + :padded_active + ] + self._staging_top_k[:padded_active] = self.active_request_metadata["top_k"][:padded_active] + self._staging_top_p[:padded_active] = self.active_request_metadata["top_p"][:padded_active] + + # Full-iteration CUDA graphs may have captured GPU consumers with the + # padded graph request count. Keep those padded staging rows bounded so + # graph replay never builds indices from stale request lengths. + if n_active < padded_active: + self._staging_request_in_prefill_status[n_active:padded_active] = 0 + self._staging_request_query_lengths[n_active:padded_active] = 0 + self._staging_request_kv_length_offsets[n_active:padded_active] = 0 + + # Coalesced H2D: one cudaMemcpyAsync for the entire bookkeeping buffer. + # Copying the whole (max_tokens + max_requests)-sized buffer including + # unused slots is cheap (~71 KB total, ~3-5 us on PCIe Gen4) and saves + # 8 redundant launch overheads vs. the prior per-field copies. + self.gpu_view._buf.copy_(self._cpu_bookkeeping_buf, non_blocking=True) + + # MHA metadata GPU views were already bound to state_data in + # initialize_attention_state(); the H2D above populates the underlying + # bytes. Nothing else to do here for MHA. + + # Mamba metadata: copy pre-computed CPU tensors to GPU buffers. + if hasattr(self, '_pending_mamba_transfer') and self._pending_mamba_transfer is not None: + self.mamba_metadata.load_from_cpu(self._pending_mamba_transfer) + self._pending_mamba_transfer = None + def reset_tensors(self) -> None: - """Fill all GPU tensors with sentinel values.""" + """Fill all bookkeeping tensors with sentinel values.""" # Reset request indexes. 
self.request_ids.fill_(-1) @@ -1893,35 +2444,47 @@ def current_input_and_position_ids( self.num_speculative_tokens + 1 ) return ( - self.token_to_input_ids[:num_tokens].unsqueeze(0), - self.token_to_pos_ids[:num_tokens].unsqueeze(0), + self.gpu_view.token_to_input_ids[:num_tokens].unsqueeze(0), + self.gpu_view.token_to_pos_ids[:num_tokens].unsqueeze(0), ) - def speculative_required_logit_indices(self, device: torch.device) -> Tensor: + def speculative_required_logit_indices(self) -> Tensor: """Token-level indices needed for speculative decode verification. Returns all decode token positions (base + speculative) concatenated with the last token position of each prefill request. - Args: - device (torch.device): Device on which to create the index tensor. - Return: (Tensor) 1-D indices into the packed token sequence, length - ``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``. + ``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests`` + in eager, or the equivalent padded count under non-eager. """ - paused = self.paused_request_count - total = self.total_request_count - query_lengths = self.request_query_lengths[paused:total] - num_decode = self.num_decode_requests - - decode_token_count = num_decode * (self.num_speculative_tokens + 1) - decode_indices = torch.arange(decode_token_count, device=device) + return self.active_logit_idxs[: self.num_last_token_logits] - cumsum = torch.cumsum(query_lengths, dim=0) - prefill_last_indices = cumsum[num_decode:] - 1 + @property + def num_last_token_logits(self) -> int: + """Number of rows produced by `last_token_logits` for the current step. - return torch.cat([decode_indices, prefill_last_indices]) + Single source of truth for the bound: one row per request, with + `(num_speculative_tokens + 1)` rows per decode request when MTP is active. + """ + if self.num_speculative_tokens > 0: + if self._using_cuda_graph_this_step: + return ( + self.padded_batch_dimensions.decode_req_count + * (self.num_speculative_tokens + 1) + + self.padded_batch_dimensions.prefill_req_count + ) + else: + return ( + self.num_decode_requests * (self.num_speculative_tokens + 1) + + self.num_prefill_requests + ) + else: + if self._using_cuda_graph_this_step: + return self.padded_active_request_count + else: + return self.total_request_count - self.paused_request_count def last_token_logits(self, logits: Tensor) -> Tensor: """Select the logit positions needed for token generation. @@ -1936,7 +2499,7 @@ def last_token_logits(self, logits: Tensor) -> Tensor: logits (Tensor): Output logits of forward pass, shape [1, S, H]. Return: - (Tensor) Selected logits, shape [N, H]. + (Tensor) Selected logits, shape [N, H], where N == num_last_token_logits. """ # todo: @lmcafee, remove these asserts? assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1" @@ -1944,17 +2507,7 @@ def last_token_logits(self, logits: Tensor) -> Tensor: f"logits.size(1) ({tuple(logits.shape)}) != " f"padded_active_token_count ({self.padded_active_token_count})." 
) - logits_2d = logits.squeeze(0) - - if self.num_speculative_tokens > 0: - selected = self.speculative_required_logit_indices(logits.device) - return logits_2d[selected, :] - - paused = self.paused_request_count - total = self.total_request_count - query_lengths = self.request_query_lengths[paused:total] - last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1 - return logits_2d[last_token_idxs, :] + return logits.squeeze(0)[self.active_logit_idxs[: self.num_last_token_logits], :] def _compute_prefix_match( self, req: DynamicInferenceRequest, prefill_chunk_length: int @@ -2167,9 +2720,7 @@ def add_request( # Increment ref counts and update timestamps for matched (shared) blocks if num_matched_blocks > 0: - matched_tensor = torch.tensor( - matched_block_ids, dtype=torch.int32, device=torch.cuda.current_device() - ) + matched_tensor = torch.tensor(matched_block_ids, dtype=torch.int32, device='cpu') self.kv_block_allocator.block_ref_counts[matched_tensor] += 1 if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: self.kv_block_allocator.update_timestamps(matched_tensor) @@ -2194,7 +2745,7 @@ def add_request( metadata = req.tracked_metadata metadata_types = req.get_metadata_types() for m, m_type in zip(metadata, metadata_types): - label, _, _ = m_type + label, _ = m_type if not isinstance(m, torch.Tensor): m = torch.as_tensor( m, @@ -2293,17 +2844,18 @@ def _register_range(start: int, end: int): # Restore Mamba state from the block corresponding to prefix_skip_tokens restore_block_count = prefix_skip_tokens // self.block_size_tokens - restored = False if restore_block_count > 0 and self.mamba_slot_allocator is not None: restore_block_id = matched_block_ids[restore_block_count - 1] - restored = self.mamba_slot_allocator.restore_to_live( - self.total_request_count, restore_block_id + self._pending_mamba_restores.append( + (self.total_request_count, restore_block_id, mamba_idx) ) - if not restored: - self.mamba_conv_states[:, mamba_idx] = 0.0 - self.mamba_ssm_states[:, mamba_idx] = 0.0 + else: + self._pending_mamba_zeros.append(mamba_idx) - # Compute intermediate offsets for state extraction during forward pass + # compute_and_store_offsets sets both CPU state (hash_to_block_id, + # _eos_cache_block_id_gpu) and GPU staging buffers. Runs immediately + # because commit_intermediate_states() reads the CPU state after the + # forward pass. if self.mamba_slot_allocator is not None: self.mamba_slot_allocator.compute_and_store_offsets( req, @@ -2427,13 +2979,13 @@ def release_memory_blocks_from_request_indexes(self, request_indexes) -> None: if self.is_hybrid_model: self.mamba_metadata.free_slots(request_indexes) - # Clear intermediate offset entries for released requests + # Clear intermediate offset entries for released requests (CPU writes). 
if self.mamba_slot_allocator is not None: sa = self.mamba_slot_allocator - sa._intermediate_counts_gpu[request_indexes] = 0 - sa._intermediate_offsets_gpu[request_indexes] = 0 - sa._intermediate_block_ids_gpu[request_indexes] = -1 - sa._eos_cache_block_id_gpu[request_indexes] = -1 + sa._intermediate_counts_cpu[request_indexes] = 0 + sa._intermediate_offsets_cpu[request_indexes] = 0 + sa._intermediate_block_ids_cpu[request_indexes] = -1 + sa._eos_cache_block_id_cpu[request_indexes] = -1 def resume_paused_requests( self, active_request_count: int, newly_paused_request_ids: torch.Tensor @@ -2567,7 +3119,7 @@ def evict_overflow_paused_requests( -1, -1, dtype=paused_block_counts_cumsum.dtype, - device=torch.cuda.current_device(), + device='cpu', ) net_block_counts = paused_block_counts_cumsum - remaining_paused_request_counts evict_request_count = torch.nonzero(net_block_counts >= 0)[0].item() + 1 @@ -2575,9 +3127,7 @@ def evict_overflow_paused_requests( # Eviction index range. evict_start_idx = self.paused_request_count - evict_request_count evict_end_idx = self.paused_request_count - evict_request_idxs = torch.arange( - evict_start_idx, evict_end_idx, device=torch.cuda.current_device() - ) + evict_request_idxs = torch.arange(evict_start_idx, evict_end_idx, device='cpu') # Clone needed: subsequent release_memory_blocks_from_request_indexes and # _swap_book_keeping_tensors calls mutate self.request_ids in place. evict_request_ids = self.request_ids[evict_start_idx:evict_end_idx].clone() @@ -2592,24 +3142,24 @@ def evict_overflow_paused_requests( src_idxs = torch.arange( self.paused_request_count - evict_request_count, self.paused_request_count, - device=torch.cuda.current_device(), + device='cpu', ) dst_idxs = torch.arange( self.total_request_count - evict_request_count, self.total_request_count, - device=torch.cuda.current_device(), + device='cpu', ) else: # Swap all active requests with left-most evicted requests. src_idxs = torch.arange( self.paused_request_count - evict_request_count, self.paused_request_count - evict_request_count + active_request_count, - device=torch.cuda.current_device(), + device='cpu', ) dst_idxs = torch.arange( self.paused_request_count, self.paused_request_count + active_request_count, - device=torch.cuda.current_device(), + device='cpu', ) # Swap evicted and active requests. @@ -2685,6 +3235,14 @@ def update_requests( # active_request_count -> This corresponds to requests that have not reached EOD or max length # finished_request_count are requests that have reached the termination criterion + # Ensure all inputs are on CPU for bookkeeping operations. + if active_requests_mask.is_cuda: + active_requests_mask = active_requests_mask.cpu() + if new_tokens.is_cuda: + new_tokens = new_tokens.cpu() + if new_speculative_tokens is not None and new_speculative_tokens.is_cuda: + new_speculative_tokens = new_speculative_tokens.cpu() + self.num_prefill_requests = 0 # all turns to decode # All request that were in prefill become decode requests. # For the chunked prefill request we will overwrite this the next time add_request @@ -2989,14 +3547,14 @@ def update_requests( self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ self.paused_request_count : self.total_request_count ].repeat_interleave(num_generated_tokens) + torch.arange( - num_generated_tokens, device=torch.cuda.current_device() + num_generated_tokens, device='cpu' ).repeat( active_request_count ) # # Token to request idx : [0, 0, 0, 1, 1, 1, 2, 2, 2 ...] 
self.token_to_request_idx[: self.active_token_count] = torch.arange( - self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() + self.paused_request_count, self.total_request_count, device='cpu' ).repeat_interleave(num_generated_tokens) self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[ @@ -3018,7 +3576,7 @@ def update_requests( raw_positions = ( old_offsets[:, None] + 1 # Offset by 1 because old_offsets points to the LAST token - + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] + + torch.arange(num_generated_tokens, device='cpu')[None, :] ) # # A token crosses to the next block if its raw_position >= block_size @@ -3134,10 +3692,9 @@ def calculate_log_probs( # # active_token_ids[new_token_idx] = new_tokens # : [ 52 | 12 | 16 3 | 12 72 24 88 86 ] - active_token_ids = self.token_to_input_ids[: self.active_token_count].roll(-1, 0) - active_query_lengths = self.request_query_lengths[ - self.paused_request_count : self.total_request_count - ] + n_active = self.total_request_count - self.paused_request_count + active_token_ids = self.gpu_view.token_to_input_ids[: self.active_token_count].roll(-1, 0) + active_query_lengths = self.gpu_view.request_query_lengths[:n_active] new_token_idx = active_query_lengths.cumsum(0) - 1 active_token_ids[new_token_idx] = new_tokens diff --git a/megatron/core/inference/contexts/gpu_view.py b/megatron/core/inference/contexts/gpu_view.py new file mode 100644 index 00000000000..65c401163b0 --- /dev/null +++ b/megatron/core/inference/contexts/gpu_view.py @@ -0,0 +1,228 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import torch + + +class ContextGPUView: + """GPU-resident snapshot of context bookkeeping data for the forward pass. + + This is the ONLY interface GPU code (attention kernels, KV append, RoPE, + sampling, log-probs, speculative verification) uses to read context state. + CPU bookkeeping code accesses context tensors directly. + + Populated once per step by ``DynamicInferenceContext.transfer_bookkeeping_to_gpu()``. + All tensors have fixed addresses for CUDA graph compatibility. + + Convention: + ``context.foo`` -> CPU (source of truth, used by bookkeeping) + ``context.gpu_view.foo`` -> GPU (snapshot, used by forward pass) + + Layout note: the bookkeeping fields are backed by a single contiguous + ``uint8`` buffer (``self._buf``). Each field is a ``view(dtype)`` onto a + slice of that buffer. This matches the pinned-CPU-buffer layout in + :class:`DynamicInferenceContext` so that the per-step H2D transfer is a + single ``cudaMemcpyAsync`` instead of one per field. + """ + + def __init__( + self, + max_requests: int, + max_tokens: int, + max_kv_blocks: int, + device: torch.device, + max_mamba_chunks: int = 0, + ): + # Field layout (must match DynamicInferenceContext's CPU buffer layout): + # int64 token fields first (auto 8-byte alignment), then int32 token + # fields, then int32 request fields, then int32 MHA fields, then + # int32 Mamba fields (hybrid models only; omitted when + # max_mamba_chunks == 0). + tok_int64_bytes = max_tokens * 8 # 2 fields of int64 = 8 bytes/elem + tok_int32_bytes = max_tokens * 4 # 4 fields of int32 = 4 bytes/elem + # Request-level fields are all 4 bytes wide. 3 int32 (in_prefill_status, + # query_lengths, kv_length_offsets) + 1 int32 (top_k) + 2 float32 + # (temperature, top_p) + 1 int32 (active_request_last_token_idxs) = 7 fields. 
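# Illustrative aside (not part of the patch): the flat-buffer layout trick
# described in the class docstring, on made-up sizes. Each field is a
# torch view(dtype) over a slice of one uint8 buffer; int64 fields come
# first so every slice's byte offset stays divisible by its element size,
# which view(dtype) requires.
import torch

max_tokens = 4
buf = torch.zeros(max_tokens * 8 + max_tokens * 4, dtype=torch.uint8)

token_ids = buf[: max_tokens * 8].view(torch.int64)   # byte offset 0
block_idx = buf[max_tokens * 8 :].view(torch.int32)   # byte offset 32

token_ids[0] = 42
block_idx[-1] = 7

# Both fields alias the same storage, so one copy of `buf` moves everything.
assert buf.data_ptr() == token_ids.data_ptr()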
+ req_4byte_bytes = max_requests * 4 + + # MHA section: 5 fields shared by both graphed and non-graphed MHAMetadata + # (only one is active per step, so sharing storage is fine). + # mha_query_lengths int32 (max_bs,) = max_bs * 4 + # mha_cu_query_seq_lengths int32 (max_bs + 1,) = (max_bs+1) * 4 + # mha_kv_seq_lengths int32 (max_bs,) = max_bs * 4 + # mha_cu_kv_seq_lengths int32 (max_bs + 1,) = (max_bs+1) * 4 + # mha_block_table int32 (max_bs, max_kv_blocks) + # max_bs == max_requests in DynamicInferenceContext. + max_bs = max_requests + mha_query_lengths_bytes = max_bs * 4 + mha_cu_query_seq_lengths_bytes = (max_bs + 1) * 4 + mha_kv_seq_lengths_bytes = max_bs * 4 + mha_cu_kv_seq_lengths_bytes = (max_bs + 1) * 4 + mha_block_table_bytes = max_bs * max_kv_blocks * 4 + + # Mamba section: 9 int32 fields, only present for hybrid models. + # mamba_batch_indices_decode int32 (max_bs,) + # mamba_batch_indices_prefill int32 (max_bs,) + # mamba_seq_idx int32 (1, max_tokens) + # mamba_cu_seqlens int32 (max_bs + 1,) + # mamba_cu_chunk_seqlens int32 (max_mamba_chunks + 1,) + # mamba_last_chunk_indices int32 (max_bs,) + # mamba_seq_idx_for_varlen int32 (max_mamba_chunks,) + # mamba_conv_seq_idx int32 (max_tokens,) + # mamba_conv_seq_start int32 (max_tokens,) + if max_mamba_chunks > 0: + mamba_batch_indices_decode_bytes = max_bs * 4 + mamba_batch_indices_prefill_bytes = max_bs * 4 + mamba_seq_idx_bytes = max_tokens * 4 + mamba_cu_seqlens_bytes = (max_bs + 1) * 4 + mamba_cu_chunk_seqlens_bytes = (max_mamba_chunks + 1) * 4 + mamba_last_chunk_indices_bytes = max_bs * 4 + mamba_seq_idx_for_varlen_bytes = max_mamba_chunks * 4 + mamba_conv_seq_idx_bytes = max_tokens * 4 + mamba_conv_seq_start_bytes = max_tokens * 4 + else: + mamba_batch_indices_decode_bytes = 0 + mamba_batch_indices_prefill_bytes = 0 + mamba_seq_idx_bytes = 0 + mamba_cu_seqlens_bytes = 0 + mamba_cu_chunk_seqlens_bytes = 0 + mamba_last_chunk_indices_bytes = 0 + mamba_seq_idx_for_varlen_bytes = 0 + mamba_conv_seq_idx_bytes = 0 + mamba_conv_seq_start_bytes = 0 + + total_bytes = ( + 2 * tok_int64_bytes + + 4 * tok_int32_bytes + + 7 * req_4byte_bytes + + mha_query_lengths_bytes + + mha_cu_query_seq_lengths_bytes + + mha_kv_seq_lengths_bytes + + mha_cu_kv_seq_lengths_bytes + + mha_block_table_bytes + + mamba_batch_indices_decode_bytes + + mamba_batch_indices_prefill_bytes + + mamba_seq_idx_bytes + + mamba_cu_seqlens_bytes + + mamba_cu_chunk_seqlens_bytes + + mamba_last_chunk_indices_bytes + + mamba_seq_idx_for_varlen_bytes + + mamba_conv_seq_idx_bytes + + mamba_conv_seq_start_bytes + ) + + # Zero-initialized so pre-transfer reads see zeros (matches prior semantics). + self._buf = torch.zeros(total_bytes, dtype=torch.uint8, device=device) + + # Token-level tensors (consumed by embedding, RoPE, KV append, Mamba). 
+ off = 0 + self.token_to_input_ids = self._buf[off : off + tok_int64_bytes].view(torch.long) + off += tok_int64_bytes + self.token_to_pos_ids = self._buf[off : off + tok_int64_bytes].view(torch.long) + off += tok_int64_bytes + self.token_to_block_idx = self._buf[off : off + tok_int32_bytes].view(torch.int32) + off += tok_int32_bytes + self.token_to_local_position_within_kv_block = self._buf[off : off + tok_int32_bytes].view( + torch.int32 + ) + off += tok_int32_bytes + self.token_to_request_idx = self._buf[off : off + tok_int32_bytes].view(torch.int32) + off += tok_int32_bytes + self.token_to_position_in_request = self._buf[off : off + tok_int32_bytes].view(torch.int32) + off += tok_int32_bytes + + # Request-level tensors (consumed by sampling, log-probs, speculative verification, MTP). + self.request_in_prefill_status = self._buf[off : off + req_4byte_bytes].view(torch.int32) + off += req_4byte_bytes + self.request_query_lengths = self._buf[off : off + req_4byte_bytes].view(torch.int32) + off += req_4byte_bytes + self.request_kv_length_offsets = self._buf[off : off + req_4byte_bytes].view(torch.int32) + off += req_4byte_bytes + # Sampling parameters (consumed by FlashInfer sampling). + # Mirror the active slice of `active_request_metadata[{label}]`; + # padded slots get neutral defaults from `pad_active_slices` (T=1.0, top_k=0, top_p=0.0). + self.temperature = self._buf[off : off + req_4byte_bytes].view(torch.float32) + off += req_4byte_bytes + self.top_k = self._buf[off : off + req_4byte_bytes].view(torch.int32) + off += req_4byte_bytes + self.top_p = self._buf[off : off + req_4byte_bytes].view(torch.float32) + off += req_4byte_bytes + # Per-request last-token row indices (consumed by sampling kernels as `gather_indices`). + # The CPU side of this slot IS `context.active_request_last_token_idxs`, + # populated by `build_active_slices` and `pad_active_slices`. + self.active_request_last_token_idxs = self._buf[off : off + req_4byte_bytes].view( + torch.int32 + ) + off += req_4byte_bytes + + # MHA flash-attention metadata (shared between GraphedMHAMetadata and + # NonGraphedMHAMetadata — only one is active per step). + self.mha_query_lengths = self._buf[off : off + mha_query_lengths_bytes].view(torch.int32) + off += mha_query_lengths_bytes + self.mha_cu_query_seq_lengths = self._buf[off : off + mha_cu_query_seq_lengths_bytes].view( + torch.int32 + ) + off += mha_cu_query_seq_lengths_bytes + self.mha_kv_seq_lengths = self._buf[off : off + mha_kv_seq_lengths_bytes].view(torch.int32) + off += mha_kv_seq_lengths_bytes + self.mha_cu_kv_seq_lengths = self._buf[off : off + mha_cu_kv_seq_lengths_bytes].view( + torch.int32 + ) + off += mha_cu_kv_seq_lengths_bytes + self.mha_block_table = ( + self._buf[off : off + mha_block_table_bytes] + .view(torch.int32) + .view(max_bs, max_kv_blocks) + ) + off += mha_block_table_bytes + + # Mamba varlen metadata (hybrid models only). Each GPU view matches a + # pinned CPU view in DynamicInferenceContext._cpu_bookkeeping_buf; the + # per-step coalesced H2D copy covers both MHA and Mamba alongside the + # token/request bookkeeping. 
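# Illustrative aside (not part of the patch): the coalesced H2D transfer
# this comment describes, reduced to its core. A pinned CPU uint8 buffer
# mirrors the GPU buffer's layout, so one non_blocking copy_ publishes
# every bookkeeping field in a single cudaMemcpyAsync. Sizes are made up;
# the CUDA part is guarded so the sketch runs anywhere.
import torch

if torch.cuda.is_available():
    nbytes = 4096
    cpu_buf = torch.zeros(nbytes, dtype=torch.uint8, pin_memory=True)
    gpu_buf = torch.zeros(nbytes, dtype=torch.uint8, device='cuda')

    # CPU bookkeeping writes land in typed views over cpu_buf ...
    cpu_buf[:8].view(torch.int64)[0] = 123

    # ... and one async copy snapshots the whole buffer onto the GPU.
    gpu_buf.copy_(cpu_buf, non_blocking=True)
    torch.cuda.synchronize()
    assert gpu_buf[:8].view(torch.int64).item() == 123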
+ if max_mamba_chunks > 0: + self.mamba_batch_indices_decode = self._buf[ + off : off + mamba_batch_indices_decode_bytes + ].view(torch.int32) + off += mamba_batch_indices_decode_bytes + self.mamba_batch_indices_prefill = self._buf[ + off : off + mamba_batch_indices_prefill_bytes + ].view(torch.int32) + off += mamba_batch_indices_prefill_bytes + self.mamba_seq_idx = ( + self._buf[off : off + mamba_seq_idx_bytes].view(torch.int32).view(1, max_tokens) + ) + off += mamba_seq_idx_bytes + self.mamba_cu_seqlens = self._buf[off : off + mamba_cu_seqlens_bytes].view(torch.int32) + off += mamba_cu_seqlens_bytes + self.mamba_cu_chunk_seqlens = self._buf[off : off + mamba_cu_chunk_seqlens_bytes].view( + torch.int32 + ) + off += mamba_cu_chunk_seqlens_bytes + self.mamba_last_chunk_indices = self._buf[ + off : off + mamba_last_chunk_indices_bytes + ].view(torch.int32) + off += mamba_last_chunk_indices_bytes + self.mamba_seq_idx_for_varlen = self._buf[ + off : off + mamba_seq_idx_for_varlen_bytes + ].view(torch.int32) + off += mamba_seq_idx_for_varlen_bytes + self.mamba_conv_seq_idx = self._buf[off : off + mamba_conv_seq_idx_bytes].view( + torch.int32 + ) + off += mamba_conv_seq_idx_bytes + self.mamba_conv_seq_start = self._buf[off : off + mamba_conv_seq_start_bytes].view( + torch.int32 + ) + off += mamba_conv_seq_start_bytes + else: + self.mamba_batch_indices_decode = None + self.mamba_batch_indices_prefill = None + self.mamba_seq_idx = None + self.mamba_cu_seqlens = None + self.mamba_cu_chunk_seqlens = None + self.mamba_last_chunk_indices = None + self.mamba_seq_idx_for_varlen = None + self.mamba_conv_seq_idx = None + self.mamba_conv_seq_start = None + + assert off == total_bytes, f"layout bug: wrote {off} of {total_bytes} bytes" diff --git a/megatron/core/inference/contexts/kv_block_allocator.py b/megatron/core/inference/contexts/kv_block_allocator.py index 87039835c7f..d555c925c93 100644 --- a/megatron/core/inference/contexts/kv_block_allocator.py +++ b/megatron/core/inference/contexts/kv_block_allocator.py @@ -3,6 +3,7 @@ from collections import deque from typing import Callable, Dict, Optional +import numpy as np import torch from torch import Tensor @@ -47,32 +48,31 @@ def __init__( assert self.active_count >= 1 # ensures paused_count < total_count - 1 self.dummy_block_idx = self.total_count - 1 - # Initialize block pool as a "stack" data structure - self.block_bag = torch.arange( - self.total_count, dtype=torch.int32, device=torch.cuda.current_device() - ) + # Initialize block pool as a "stack" data structure (CPU for bookkeeping). 
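# Illustrative aside (not part of the patch): the "stack" block pool named
# in the comment above, in miniature. The free pool is an arange of block
# ids; allocation pops from the top, release pushes back. The real
# allocator tracks far more state (dummy block, prefix-cache ref counts),
# so this only sketches the data structure.
import torch

total_count = 8
block_bag = torch.arange(total_count, dtype=torch.int32)
total_avail = total_count

def allocate(n: int) -> torch.Tensor:
    global total_avail
    assert n <= total_avail, "out of blocks"
    blocks = block_bag[total_avail - n : total_avail].clone()
    total_avail -= n
    return blocks

def release(blocks: torch.Tensor) -> None:
    global total_avail
    block_bag[total_avail : total_avail + len(blocks)] = blocks
    total_avail += len(blocks)

got = allocate(3)   # e.g. tensor([5, 6, 7], dtype=torch.int32)
release(got)
assert total_avail == total_count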
+ self.block_bag = torch.arange(self.total_count, dtype=torch.int32, device='cpu') if self.enable_prefix_caching: # Block hash tracking for prefix caching: -1 = uncomputed, positive = valid hash - self.block_hashes = torch.full( - (self.total_count,), -1, dtype=torch.int64, device=torch.cuda.current_device() - ) + self.block_hashes = torch.full((self.total_count,), -1, dtype=torch.int64, device='cpu') # Hash-to-block mapping for O(1) prefix lookup self.kv_hash_to_block_id: Dict[int, int] = {} # Reference count per block: 0 = cached (evictable), >0 = actively used self.block_ref_counts = torch.zeros( - (self.total_count,), dtype=torch.int32, device=torch.cuda.current_device() + (self.total_count,), dtype=torch.int32, device='cpu' ) # LRU timestamps for eviction ordering (higher = more recently used) # Only needed in LRU mode; RZ mode evicts immediately on ref_count==0 if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: self.block_timestamps = torch.zeros( - (self.total_count,), dtype=torch.int64, device=torch.cuda.current_device() + (self.total_count,), dtype=torch.int64, device='cpu' ) + # Per-block MoE routing storage (populated when routing replay is enabled) + self.block_routing: Dict[int, np.ndarray] = {} + def __str__(self): return ( f"using: total {self.get_total_used()}/{self.total_count - 1}" @@ -183,6 +183,10 @@ def allocate_memory_blocks(self, num_blocks: int) -> Optional[Tensor]: if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: self.update_timestamps(block_ids) + # Clear stale routing data for re-allocated blocks + for bid in block_ids.tolist(): + self.block_routing.pop(bid, None) + return block_ids def release_memory_blocks(self, blocks: Tensor) -> None: @@ -239,9 +243,7 @@ def reset(self) -> None: # Without resetting the block bag, context request memory will clash and # requests will point to each other's memory blocks, resulting in faulty # generations. - self.block_bag = torch.arange( - self.total_count, dtype=torch.int32, device=torch.cuda.current_device() - ) + self.block_bag = torch.arange(self.total_count, dtype=torch.int32, device='cpu') self.total_avail = self.total_count - 1 @@ -255,6 +257,9 @@ def reset(self) -> None: if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: self.block_timestamps.fill_(0) + # Clear per-block routing storage + self.block_routing.clear() + # ========================================================================= # Prefix caching methods # ========================================================================= @@ -358,3 +363,123 @@ def evict_lru_blocks(self, num_blocks_needed: int) -> bool: self._deregister_blocks(blocks_to_evict) return True + + # ========================================================================= + # Per-block routing storage methods (for MoE routing replay) + # ========================================================================= + + def store_routing_per_block(self, flat_routing: Optional[np.ndarray]) -> None: + """Scatter flat routing indices into per-block storage. + + Uses the context's token-to-block mapping to distribute each token's + routing data into the appropriate block. Matched (prefix-cached) blocks + already have routing from the original request and are not overwritten + here since their tokens are not in the active token layout. + + Args: + flat_routing: ndarray of shape [active_token_count, num_layers, topk] + aligned with the context's active-token layout, or None. 
+ """ + if flat_routing is None: + return + + context = self.context + token_count = context.active_token_count + if token_count == 0: + return + + assert ( + flat_routing.shape[0] == token_count + ), f"Routing token count {flat_routing.shape[0]} != active token count {token_count}" + + # Token-to-block mapping for all active tokens + block_ids_np = context.token_to_block_idx[:token_count].cpu().numpy() + positions_np = context.token_to_local_position_within_kv_block[:token_count].cpu().numpy() + + dummy = self.dummy_block_idx + + # Group tokens by block_id using sort for efficient scatter + unique_blocks, inverse, counts = np.unique( + block_ids_np, return_inverse=True, return_counts=True + ) + sorted_indices = np.argsort(inverse, kind='stable') + sorted_positions = positions_np[sorted_indices] + sorted_routing = flat_routing[sorted_indices] + + offset = 0 + for bid, count in zip(unique_blocks, counts): + bid = int(bid) + count = int(count) + if bid == dummy: + offset += count + continue + block_pos = sorted_positions[offset : offset + count] + block_rout = sorted_routing[offset : offset + count] + self.store_block_routing(bid, block_pos, block_rout) + offset += count + + def reconstruct_routing_from_blocks( + self, block_ids: list[int], total_routing_tokens: int + ) -> Optional[np.ndarray]: + """Reconstruct routing indices from per-block storage. + + Concatenates per-block routing ndarrays in block order, trimming the + last block to exactly ``total_routing_tokens`` entries. + + Args: + block_ids: Ordered list of block IDs for the request. + total_routing_tokens: Expected number of routing tokens + (total_tokens - 1, since the last generated token has no + forward-pass routing). + + Returns: + ndarray [total_routing_tokens, num_layers, topk] or None if any + block is missing routing data. + """ + block_size = self.context.block_size_tokens + routing_parts = [] + tokens_collected = 0 + + for bid in block_ids: + routing = self.get_block_routing(bid) + if routing is None: + return None # Missing routing data for this block + remaining = total_routing_tokens - tokens_collected + if remaining <= 0: + break + take = min(block_size, remaining) + routing_parts.append(routing[:take]) + tokens_collected += take + + if not routing_parts or tokens_collected != total_routing_tokens: + return None + + return np.concatenate(routing_parts, axis=0) + + def store_block_routing( + self, block_id: int, positions: np.ndarray, routing: np.ndarray + ) -> None: + """Store routing indices for specific token positions in a block. + + Args: + block_id: The block ID. + positions: ndarray of token positions within the block (1D, int). + routing: ndarray of routing data [num_positions, num_layers, topk]. + """ + if block_id not in self.block_routing: + self.block_routing[block_id] = np.zeros( + (self.context.block_size_tokens, routing.shape[-2], routing.shape[-1]), + dtype=routing.dtype, + ) + self.block_routing[block_id][positions] = routing + + def get_block_routing(self, block_id: int) -> Optional[np.ndarray]: + """Get routing indices for a block. + + Args: + block_id: The block ID. + + Returns: + ndarray [block_size_tokens, num_layers, topk] or None if not stored. 
+ """ + return self.block_routing.get(block_id) diff --git a/megatron/core/inference/contexts/mamba_slot_allocator.py b/megatron/core/inference/contexts/mamba_slot_allocator.py index d7c57046c8a..60c8dd3416b 100644 --- a/megatron/core/inference/contexts/mamba_slot_allocator.py +++ b/megatron/core/inference/contexts/mamba_slot_allocator.py @@ -47,59 +47,70 @@ def __init__( self.max_slots = max_slots self.num_mamba_layers = num_mamba_layers - device = torch.cuda.current_device() + gpu_device = torch.cuda.current_device() num_blocks = context.kv_block_allocator.total_count - # Block <-> slot mappings - self.block_to_slot = torch.full((num_blocks,), -1, dtype=torch.int32, device=device) - self.slot_to_block = torch.full((max_slots,), -1, dtype=torch.int32, device=device) + # Block <-> slot mappings (CPU for bookkeeping). + self.block_to_slot = torch.full((num_blocks,), -1, dtype=torch.int32, device='cpu') + self.slot_to_block = torch.full((max_slots,), -1, dtype=torch.int32, device='cpu') - # Free slot pool (stack) - self.free_slots = torch.arange(max_slots, dtype=torch.int32, device=device) + # Free slot pool (stack, CPU). + self.free_slots = torch.arange(max_slots, dtype=torch.int32, device='cpu') self.free_count = max_slots - # State tensors + # State tensors (GPU - accessed by Mamba CUDA kernels). self.conv_states = torch.zeros( (num_mamba_layers, max_slots) + conv_states_shape, dtype=conv_states_dtype, - device=device, + device=gpu_device, ) self.ssm_states = torch.zeros( - (num_mamba_layers, max_slots) + ssm_states_shape, dtype=ssm_states_dtype, device=device + (num_mamba_layers, max_slots) + ssm_states_shape, + dtype=ssm_states_dtype, + device=gpu_device, ) # Hash-to-block mapping: only blocks with cached Mamba state self.hash_to_block_id: Dict[int, int] = {} - # Per-request intermediate state storage (GPU tensors, fixed-size per request) - # 0 = no offset, -1 = no block + # Per-request intermediate state storage. + # offsets_cpu and counts_cpu: CPU source of truth. GPU copies are + # populated by transfer_bookkeeping_to_gpu() since Triton kernels read them. + # block_ids and eos_cache_block_id: CPU only (consumed by CPU code). k = MAX_INTERMEDIATE_OFFSETS_PER_REQUEST - self._intermediate_offsets_gpu = torch.zeros( - (context.max_requests, k), dtype=torch.int32, device=device + self._intermediate_offsets_cpu = torch.zeros( + (context.max_requests, k), dtype=torch.int32, device='cpu' ) - self._intermediate_block_ids_gpu = torch.full( - (context.max_requests, k), -1, dtype=torch.int32, device=device + self._intermediate_counts_cpu = torch.zeros( + context.max_requests, dtype=torch.int32, device='cpu' + ) + self._intermediate_offsets_gpu = torch.zeros( + (context.max_requests, k), dtype=torch.int32, device=gpu_device ) self._intermediate_counts_gpu = torch.zeros( - context.max_requests, dtype=torch.int32, device=device + context.max_requests, dtype=torch.int32, device=gpu_device ) - self._eos_cache_block_id_gpu = torch.full( - (context.max_requests,), -1, dtype=torch.int32, device=device + # CPU-only: consumed by _collect_commit_data() which needs .tolist() anyway. 
+ self._intermediate_block_ids_cpu = torch.full( + (context.max_requests, k), -1, dtype=torch.int32, device='cpu' + ) + self._eos_cache_block_id_cpu = torch.full( + (context.max_requests,), -1, dtype=torch.int32, device='cpu' ) # CPU flag to skip GPU sync when no intermediates exist self._has_intermediates = False - # Pre-allocated output buffers for CUDA graph compatible extraction + # Pre-allocated output buffers for CUDA graph compatible extraction (GPU). self.max_intermediate_count = MAX_INTERMEDIATE_OFFSETS_PER_REQUEST * context.max_requests self.intermediate_ssm_out = torch.zeros( (num_mamba_layers, self.max_intermediate_count) + ssm_states_shape, dtype=ssm_states_dtype, - device=device, + device=gpu_device, ) self.intermediate_conv_out = torch.zeros( (num_mamba_layers, self.max_intermediate_count) + conv_states_shape, dtype=conv_states_dtype, - device=device, + device=gpu_device, ) # ========================================================================= @@ -320,9 +331,11 @@ def store_from_live_batch(self, slots: list, request_indices: list) -> None: return device = self.conv_states.device slot_tensor = torch.tensor(slots, dtype=torch.int64, device=device) - req_tensor = torch.tensor(request_indices, dtype=torch.int64, device=device) - # Batch lookup mamba state indices (1 GPU sync) - mamba_indices = self.context.mamba_metadata.request_to_mamba_state_idx[req_tensor].tolist() + # Lookup mamba indices from CPU bookkeeping, then move to GPU for state copy. + req_tensor_cpu = torch.tensor(request_indices, dtype=torch.int64) + mamba_indices = self.context.mamba_metadata.request_to_mamba_state_idx[ + req_tensor_cpu + ].tolist() mamba_idx_tensor = torch.tensor(mamba_indices, dtype=torch.int64, device=device) # Fancy-indexed copy (2 kernel launches instead of 2E) self.conv_states[:, slot_tensor] = self.context.mamba_conv_states[:, mamba_idx_tensor] @@ -413,42 +426,39 @@ def compute_and_store_offsets( offsets = sorted(offsets_set) count = len(offsets) - # Vectorized block ID lookup: GPU gather avoids per-block .item() syncs + # CPU bookkeeping writes (no GPU kernel launches). 
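# Illustrative aside (not part of the patch): the offset-to-block-id
# arithmetic used just below. An absolute token count that lands exactly
# on a block boundary maps to the last *complete* block, hence the
# `// block_size - 1`. Numbers are invented.
import torch

block_size_tokens = 4
request_to_kv_block_ids = torch.tensor([12, 40, 7, 93])    # blocks owned by one request

abs_tokens = torch.tensor([4, 8, 12], dtype=torch.int64)   # block-aligned offsets
block_indices = abs_tokens // block_size_tokens - 1        # -> [0, 1, 2]
assert request_to_kv_block_ids[block_indices].tolist() == [12, 40, 7]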
if count > 0: - device = self._intermediate_offsets_gpu.device - abs_tokens = torch.tensor( - [skip_tokens + o for o in offsets], dtype=torch.int64, device=device - ) - block_indices = abs_tokens // ctx.block_size_tokens - 1 - bids = ctx.request_to_kv_block_ids[current_id][block_indices] + abs_tokens_cpu = torch.tensor([skip_tokens + o for o in offsets], dtype=torch.int64) + block_indices_cpu = abs_tokens_cpu // ctx.block_size_tokens - 1 + bids_cpu = ctx.request_to_kv_block_ids[current_id][block_indices_cpu] - self._intermediate_offsets_gpu[current_id, :count] = torch.tensor( - offsets, dtype=torch.int32, device=device + self._intermediate_offsets_cpu[current_id, :count] = torch.tensor( + offsets, dtype=torch.int32 ) - self._intermediate_block_ids_gpu[current_id, :count] = bids.to(torch.int32) + self._intermediate_block_ids_cpu[current_id, :count] = bids_cpu.to(torch.int32) self._has_intermediates = True - self._intermediate_counts_gpu[current_id] = count + self._intermediate_counts_cpu[current_id] = count # Block-aligned EOS: prompt_len is exactly block-aligned if last_aligned_abs == prompt_len and prompt_len > 0: last_block_idx = prompt_len // ctx.block_size_tokens - 1 if last_block_idx >= 0: - self._eos_cache_block_id_gpu[current_id] = ctx.request_to_kv_block_ids[current_id][ + self._eos_cache_block_id_cpu[current_id] = ctx.request_to_kv_block_ids[current_id][ last_block_idx ] self._has_intermediates = True else: - self._eos_cache_block_id_gpu[current_id] = -1 + self._eos_cache_block_id_cpu[current_id] = -1 else: - self._eos_cache_block_id_gpu[current_id] = -1 + self._eos_cache_block_id_cpu[current_id] = -1 - def get_intermediate_gpu_data(self): - """Get intermediate offsets and counts as GPU tensor slices for current prefill batch. + def get_intermediate_cpu_data(self): + """Get intermediate offsets and counts as CPU tensor slices for current prefill batch. Returns: - Tuple of (offsets_gpu, counts_gpu) where: - offsets_gpu: [prefill_count, 3] int32 GPU tensor - counts_gpu: [prefill_count] int32 GPU tensor + Tuple of (offsets_cpu, counts_cpu) where: + offsets_cpu: [prefill_count, 3] int32 CPU tensor + counts_cpu: [prefill_count] int32 CPU tensor Returns (None, None) if no prefill requests or no intermediates. """ if not self._has_intermediates: @@ -463,10 +473,25 @@ def get_intermediate_gpu_data(self): decode_count = ctx.batch_dimensions.decode_req_count prefill_start = active_start + decode_count - offsets = self._intermediate_offsets_gpu[prefill_start : prefill_start + prefill_count] - counts = self._intermediate_counts_gpu[prefill_start : prefill_start + prefill_count] + offsets = self._intermediate_offsets_cpu[prefill_start : prefill_start + prefill_count] + counts = self._intermediate_counts_cpu[prefill_start : prefill_start + prefill_count] return offsets, counts + def transfer_intermediate_to_gpu(self, prefill_start: int, prefill_count: int): + """Copy intermediate offsets/counts slice from CPU to GPU for Mamba kernels. + + Returns the GPU tensor views for the forward-pass kernels to consume. 
+ """ + if prefill_count == 0: + return None, None + offsets_cpu = self._intermediate_offsets_cpu[prefill_start : prefill_start + prefill_count] + counts_cpu = self._intermediate_counts_cpu[prefill_start : prefill_start + prefill_count] + offsets_gpu = self._intermediate_offsets_gpu[prefill_start : prefill_start + prefill_count] + counts_gpu = self._intermediate_counts_gpu[prefill_start : prefill_start + prefill_count] + offsets_gpu.copy_(offsets_cpu, non_blocking=True) + counts_gpu.copy_(counts_cpu, non_blocking=True) + return offsets_gpu, counts_gpu + # ========================================================================= # Intermediate state commit # ========================================================================= @@ -517,14 +542,14 @@ def _collect_commit_data(self): decode_count = ctx.batch_dimensions.decode_req_count prefill_start = active_start + decode_count - # Batch-transfer block IDs and EOS block IDs from GPU (2 GPU syncs) + # Block IDs and EOS block IDs live on CPU (no GPU sync needed). intermediate_count = metadata.intermediate_count per_request_counts = metadata.per_request_intermediate_counts - all_block_ids_cpu = self._intermediate_block_ids_gpu[ + all_block_ids_cpu = self._intermediate_block_ids_cpu[ prefill_start : prefill_start + prefill_count ].tolist() - eos_bids_cpu = self._eos_cache_block_id_gpu[ + eos_bids_cpu = self._eos_cache_block_id_cpu[ prefill_start : prefill_start + prefill_count ].tolist() @@ -586,10 +611,10 @@ def _clear_intermediate_state(self) -> None: decode_count = ctx.batch_dimensions.decode_req_count prefill_start = active_start + decode_count end = prefill_start + prefill_count - self._intermediate_counts_gpu[prefill_start:end].fill_(0) - self._intermediate_offsets_gpu[prefill_start:end].fill_(0) - self._intermediate_block_ids_gpu[prefill_start:end].fill_(-1) - self._eos_cache_block_id_gpu[prefill_start:end].fill_(-1) + self._intermediate_counts_cpu[prefill_start:end].fill_(0) + self._intermediate_offsets_cpu[prefill_start:end].fill_(0) + self._intermediate_block_ids_cpu[prefill_start:end].fill_(-1) + self._eos_cache_block_id_cpu[prefill_start:end].fill_(-1) self._has_intermediates = False # ========================================================================= @@ -600,15 +625,13 @@ def reset(self) -> None: """Reset all state (mappings, free pool, cache, intermediate tracking).""" self.block_to_slot.fill_(-1) self.slot_to_block.fill_(-1) - self.free_slots = torch.arange( - self.max_slots, dtype=torch.int32, device=torch.cuda.current_device() - ) + self.free_slots = torch.arange(self.max_slots, dtype=torch.int32, device='cpu') self.free_count = self.max_slots self.hash_to_block_id.clear() self.intermediate_ssm_out.zero_() self.intermediate_conv_out.zero_() - self._intermediate_offsets_gpu.fill_(0) - self._intermediate_block_ids_gpu.fill_(-1) - self._intermediate_counts_gpu.fill_(0) - self._eos_cache_block_id_gpu.fill_(-1) + self._intermediate_offsets_cpu.fill_(0) + self._intermediate_counts_cpu.fill_(0) + self._intermediate_block_ids_cpu.fill_(-1) + self._eos_cache_block_id_cpu.fill_(-1) self._has_intermediates = False diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index 52570845d61..aa13f659d40 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -131,6 +131,45 @@ async def all_reduce_max(self, *local_vals: int, async_op=True) -> int | tuple[i except zmq.Again: await 
asyncio.sleep(0.001) + def sync_all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]: + """Synchronous (non-asyncio) variant of all_reduce_max. + + Uses blocking ZMQ sends/recvs so it can be called from synchronous + call sites that need a CPU-only MAX reduction across the process + group. Intended for tiny payloads (e.g. a few integers) that would + otherwise force a NCCL AllReduce kernel on the compute stream. + + Note: when called from inside a running asyncio event loop, the + blocking recv will pause other coroutines on this rank until all + peers respond. This is acceptable here because every rank reaches + the call simultaneously and the message size is trivial. + + Returns a single int when called with one argument, otherwise a tuple. + """ + n = len(local_vals) + if n == 0: + raise ValueError("sync_all_reduce_max requires at least one value") + + if self.world_size <= 1: + return local_vals[0] if n == 1 else local_vals + + fmt = f'!{n}i' + payload = struct.pack(fmt, *local_vals) + + if self.is_leader: + rows = [local_vals] + while len(rows) < self.world_size: + msg = self.gather_sock.recv() + rows.append(struct.unpack(fmt, msg)) + maxes = tuple(max(row[i] for row in rows) for i in range(n)) + self.bcast_sock.send(struct.pack(fmt, *maxes)) + return maxes[0] if n == 1 else maxes + else: + self.gather_sock.send(payload) + msg = self.bcast_sock.recv() + result = struct.unpack(fmt, msg) + return result[0] if n == 1 else result + def close(self): """ Close the ZMQ sockets. diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index ec2bc5b630d..e862c8dacd2 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -6,7 +6,6 @@ import math import multiprocessing import socket -import struct import time import warnings from collections import deque @@ -43,12 +42,7 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.inference.utils import ( - Counter, - await_process_call, - set_inference_cuda_graphed_iteration_for_ep_inference, - unset_inference_cuda_graphed_iteration_for_ep_inference, -) +from megatron.core.inference.utils import Counter, await_process_call from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.enums import CudaGraphScope @@ -63,7 +57,9 @@ internal_api, nvtx_range_pop, nvtx_range_push, + round_up_to_nearest_multiple, trace_async_exceptions, + unwrap_model, ) from .async_zmq_communicator import AsyncZMQCommunicator @@ -225,6 +221,7 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.logging_step_interval = inference_config.logging_step_interval self.unified_memory_level = inference_config.unified_memory_level self.use_synchronous_zmq_collectives = inference_config.use_synchronous_zmq_collectives + self.disable_ep_consensus = inference_config.disable_ep_consensus self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = model_config.cuda_graph_scope # Initialize engine. 
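The leader-side reduction in sync_all_reduce_max above amounts to a fixed-width struct exchange: every rank packs its integers with the same format string, the leader unpacks one row per peer, takes an elementwise max, and broadcasts the packed result. A socket-free sketch of that wire format follows; the rank payloads are invented for illustration.

import struct

n = 2                                    # number of values being reduced
fmt = f'!{n}i'                           # network byte order, n signed ints

rank_values = [(3, 7), (5, 2), (4, 9)]   # what three ranks would send
wire = [struct.pack(fmt, *vals) for vals in rank_values]

rows = [struct.unpack(fmt, msg) for msg in wire]
maxes = tuple(max(row[i] for row in rows) for i in range(n))
reply = struct.pack(fmt, *maxes)         # what the leader would broadcast

assert struct.unpack(fmt, reply) == (5, 9)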
@@ -357,13 +354,21 @@ def create_cuda_graphs(self, reset_context: bool = True): # Enable inference dispatcher for EP during graph capture model_config = controller.inference_wrapped_model.model.config - is_inference_optimized_ep = ( - model_config.transformer_impl == "inference_optimized" - and model_config.expert_model_parallel_size > 1 + + # MTP warmup preparation: capture MTP CUDA graphs alongside the + # decoder graphs within the same loop rather than in a separate pass. + unwrapped = unwrap_model(controller.inference_wrapped_model.model) + mtp_warmup_enabled = ( + controller.num_mtp_heads > 0 + and (controller.num_speculative_tokens or 0) > 0 + and hasattr(unwrapped, 'mtp') ) - if is_inference_optimized_ep: - unwrapped_model = controller.inference_wrapped_model.model - set_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) + if mtp_warmup_enabled: + tp_size = get_pg_size(controller.inference_wrapped_model.tp_group) + sp_enabled = model_config.sequence_parallel and tp_size > 1 + mtp_pass_depth = not unwrapped.mtp.mtp_use_repeated_layer + mtp_warmup_depths = range(controller._num_mtp_depths) if mtp_pass_depth else [None] + mtp_seen_batch_sizes = set() tbar = enumerate(context.cuda_graph_batch_dimensions_list) if HAVE_TQDM: @@ -383,18 +388,48 @@ def create_cuda_graphs(self, reset_context: bool = True): # Enable routing recording during warmup if routing replay is enabled. # This ensures the record_indices copy operation is captured in the CUDA graph. - model_config = controller.inference_wrapped_model.model.config if model_config.moe_enable_routing_replay: RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) # Forward pass -> logits. - controller._dynamic_step_forward_logits(input_ids, position_ids) + with torch.inference_mode(): + controller._dynamic_step_forward_logits(input_ids, position_ids) + + if controller._sampling_backend == "flashinfer": + if controller.num_speculative_tokens > 0: + controller._dynamic_step_sample_logits_and_verify_tokens(input_ids) + else: + controller._dynamic_step_sample_logits() - context.reset() + # MTP CUDA graph warmup for this batch dimension. + if mtp_warmup_enabled: + n = cuda_graph_batch_dimension.req_count + # pylint: disable-next=possibly-used-before-assignment + if sp_enabled: + n = round_up_to_nearest_multiple(n, tp_size) + # pylint: disable-next=possibly-used-before-assignment + if n > 0 and n not in mtp_seen_batch_sizes: + mtp_seen_batch_sizes.add(n) + device = torch.cuda.current_device() + batch_dim = n // tp_size if sp_enabled else n + # Use zeros (not empty) — garbage token IDs cause OOB embedding lookups during graph capture/replay. + for depth in mtp_warmup_depths: + unwrapped.compute_mtp_single_step( + hidden_states=torch.zeros( + (batch_dim, 1, model_config.hidden_size), + device=device, + dtype=model_config.params_dtype, + ), + next_token_ids=torch.zeros((1, n), device=device, dtype=torch.long), + position_ids=torch.zeros((1, n), device=device, dtype=torch.int64), + depth=depth, + cache_key=("mtp", n, depth), + ) - # Disable inference dispatcher after graph capture - if is_inference_optimized_ep: - unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) + context.reset() + + if mtp_warmup_enabled and mtp_seen_batch_sizes: + logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_seen_batch_sizes)) # Memory usage. 
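# Illustrative aside (not part of the patch): the sequence-parallel branch
# above pads the warmup request count to a multiple of the tensor-parallel
# size. The imported round_up_to_nearest_multiple helper is not shown in
# this diff; the usual formula for such a helper is:
def _round_up_to_multiple(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

assert _round_up_to_multiple(5, 4) == 8
assert _round_up_to_multiple(8, 4) == 8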
time_end = time.time() @@ -549,20 +584,16 @@ async def start_listening_to_data_parallel_coordinator( mp_req_sock.bind_to_random_port(f"tcp://{local_ip}") mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT) - mp_len_sock = self.zmq_context.socket(zmq.PUB) - mp_len_sock.bind_to_random_port(f"tcp://{local_ip}") - mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT) else: mp_req_addr = None - mp_len_addr = None # Broadcast addresses to respective ranks. bcast = [dp_addr] torch.distributed.broadcast_object_list(bcast, src=dp_src, group=dp_group) [dp_addr] = bcast - bcast = [mp_req_addr, mp_len_addr] + bcast = [mp_req_addr] torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group) - [mp_req_addr, mp_len_addr] = bcast + [mp_req_addr] = bcast identity = f'mp-coord-{dp_rank}' if self.is_mp_coordinator: @@ -579,37 +610,32 @@ async def start_listening_to_data_parallel_coordinator( # 2. Create a publisher socket. This is used to publish or broadcast # requests within the model parallel group self.model_parallel_publisher_socket = mp_req_sock - - # 3. Create another publisher socket to broadcast the number of messages to receive. - self.model_parallel_num_msgs_publisher_socket = mp_len_sock self.zmq_sockets += [ self.socket_for_receiving_requests, - self.model_parallel_num_msgs_publisher_socket, self.model_parallel_publisher_socket, ] - # All MP ranks subscribe to the two publisher sockets + # All MP ranks subscribe to the publisher socket self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB) self.model_parallel_subscriber_socket.connect(mp_req_addr) self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "") - self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB) - self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr) - self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "") - - self.zmq_sockets += [ - self.model_parallel_subscriber_socket, - self.model_parallel_num_msgs_subscriber_socket, - ] + self.zmq_sockets += [self.model_parallel_subscriber_socket] torch.distributed.barrier(mp_group) # initialize zmq-based EP communicator self.ep_rank = get_pg_rank(self.pg_collection.ep) self.ep_world_size = get_pg_size(self.pg_collection.ep) + self._ep_consensus_loop_counter = 0 + self._last_ep_consensus: tuple[int, bool] = (0, False) if self.ep_world_size > 1: self.expert_parallel_zmq_communicator = AsyncZMQCommunicator( self.zmq_context, process_group=self.pg_collection.ep, hostname=hostname ) + # Give the context a CPU-side MAX-reduction primitive so + # match_graph_config() can avoid a per-step NCCL AllReduce kernel. + if hasattr(self.context, "set_ep_zmq_communicator"): + self.context.set_ep_zmq_communicator(self.expert_parallel_zmq_communicator) # initialize zmq-based world communicator for consensus barriers total_world_size = torch.distributed.get_world_size() @@ -1035,9 +1061,9 @@ def post_process_requests( accepted_tokens: torch.Tensor, log_probs: torch.Tensor, top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None, - routing_indices_per_request: Optional[Dict[int, torch.Tensor]] = None, pre_fwd_active_token_count: Optional[int] = None, pre_fwd_step_count: Optional[int] = None, + finished_routing_block_ids: Optional[Dict[int, list[int]]] = None, ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]: """ Handles post-processing for requests after a step. 
@@ -1052,9 +1078,9 @@ def post_process_requests( log_probs: (List): Log probs for each request top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to list of (top_n_logprobs, top_n_indices) tuples. - routing_indices_per_request: (Dict[int, Tensor]): MoE routing indices - pre-mapped by request_id. Each value is a tensor of shape - [num_tokens_this_step, num_layers, topk]. + finished_routing_block_ids: (Dict[int, List[int]]): Block IDs for + finished requests, saved before update_requests released them. + Used for per-block routing reconstruction. Returns: A list of active requests and completed requests as `DynamicInferenceRequest` objects @@ -1160,10 +1186,15 @@ def post_process_requests( request.ttft = ( first_token_event.timestamp - request.event_add_engine.timestamp ) - if request.tpot is None: - request.tpot = [] - per_token_step_time = step_time / len(tokens) - request.tpot.extend([per_token_step_time] * len(tokens)) + # TPOT is observability-only. step_time is 0.0 on + # non-logging steps (async_forward skips the event sync), + # so gate the update to keep the metric a truthful sparse + # sample instead of polluting it with zeros. + if step_time > 0: + if request.tpot is None: + request.tpot = [] + per_token_step_time = step_time / len(tokens) + request.tpot.extend([per_token_step_time] * len(tokens)) # Check for stop words (after token is appended). # With speculative decoding, a stop word may end before the last @@ -1183,6 +1214,20 @@ def post_process_requests( self._spec_tokens_accepted += actual_accepted if request_id in finished_request_ids: + # Reconstruct routing from per-block storage before popping. + if ( + finished_routing_block_ids + and request_id in finished_routing_block_ids + and len(self.requests[request_id].record.requests) == 1 + ): + block_ids = finished_routing_block_ids[request_id] + total_tokens = len(request.prompt_tokens) + len(request.generated_tokens) + request.routing_indices = ( + self.context.kv_block_allocator.reconstruct_routing_from_blocks( + block_ids, total_tokens - 1 + ) + ) + # Request finished by normal means (termination_id, max_length, or stop word from previous step) request.generated_length = len(request.generated_tokens) request.status = Status.COMPLETED @@ -1292,23 +1337,6 @@ def post_process_requests( else: request.generated_top_n_logprobs.append(logit_dict) - # Process routing indices if available (keyed by request_id) - # Each step's routing is a tensor of shape [num_tokens_this_step, num_layers, topk] - # We concatenate along dim=0 to accumulate: [total_tokens, num_layers, topk] - if ( - routing_indices_per_request is not None - and request_id in routing_indices_per_request - ): - step_routing = routing_indices_per_request[ - request_id - ] # [num_tokens, num_layers, topk] - if request.routing_indices is None: - request.routing_indices = step_routing.clone() - else: - request.routing_indices = torch.cat( - [request.routing_indices, step_routing], dim=0 - ) - # Handle evicted requests. if evict_request_ids is not None and evict_request_ids.numel() > 0: @@ -1645,56 +1673,74 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]: # schedule requests self.schedule_waiting_requests() - # Saving pre-step state, for printing output below. + # The print block (async_bookkeep) and metrics block both fire on this + # condition after step_count is incremented. Predict it up-front so we + # can skip the GPU-timing sync and the context_state dict builds that + # only exist to feed those logging/metrics blocks. 
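# Illustrative aside (not part of the patch): the gated-timing pattern this
# comment introduces. CUDA events are recorded and synchronized only on
# steps that will be logged, so ordinary steps pay no host/device sync for
# timing. Interval and step numbers are invented; the CUDA part is guarded.
import torch

logging_step_interval = 10
step_count = 19  # the step about to run

will_log_this_step = (
    logging_step_interval > 0 and (step_count + 1) % logging_step_interval == 0
)

step_time = 0.0
if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    if will_log_this_step:
        start.record()
    # ... the forward pass would run here ...
    if will_log_this_step:
        end.record()
        end.synchronize()
        step_time = start.elapsed_time(end) / 1e3  # seconds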
+ will_log_this_step = ( + self.logging_step_interval > 0 + and (self.context.step_count + 1) % self.logging_step_interval == 0 + ) + is_decode_only = self.context.is_decode_only() - pre_step_context_state = { - "is_decode_only": is_decode_only, - "max_requests": self.context.max_requests, - "total_request_count": self.context.total_request_count, - "paused_request_count": self.context.paused_request_count, - "active_token_count": self.context.active_token_count, - "step_count": self.context.step_count, - } + if will_log_this_step: + pre_step_context_state = { + "is_decode_only": is_decode_only, + "max_requests": self.context.max_requests, + "total_request_count": self.context.total_request_count, + "paused_request_count": self.context.paused_request_count, + "active_token_count": self.context.active_token_count, + "step_count": self.context.step_count, + } + else: + # active_token_count and step_count are still consumed by + # post_process_requests' pre_fwd_* args (for add_event_generated_token); + # the other four fields are only read in the gated print block. + pre_step_context_state = { + "active_token_count": self.context.active_token_count, + "step_count": self.context.step_count, + } # Generate tokens. nvtx_range_push("Prefill" if not is_decode_only else "Decode") # TODO @TDE: Account for this line when overlapping forward and bookkeep. self.is_decode_only = is_decode_only - self.step_start_event.record() + if will_log_this_step: + self.step_start_event.record() result = await self.controller.async_generate_output_tokens_dynamic_batch() - self.step_end_event.record() - self.step_end_event.synchronize() - step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3 + if will_log_this_step: + self.step_end_event.record() + self.step_end_event.synchronize() + step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3 + else: + step_time = 0.0 self.context.step_count += 1 self.context.prefix_cache_lru_clock += 1 nvtx_range_pop("Prefill" if not is_decode_only else "Decode") - if ( - self.logging_step_interval > 0 - and self.context.step_count > 0 - and self.context.step_count % self.logging_step_interval == 0 - and self.metrics_writer is not None - ): - kvcache_util_stats = self.context.get_kvcache_utilization_stats() + if will_log_this_step: + kvcache_util_stats = ( + self.context.get_kvcache_utilization_stats() + if self.metrics_writer is not None + else None + ) + post_step_context_state = { + "waiting_request_count": len(self.waiting_request_ids), + "finished_request_count": self.finished_request_count, + "evicted_request_count": self.evicted_request_count, + "kv_stats": kvcache_util_stats, + "total_active_block_count": self.context.kv_block_allocator.active_count, + "total_paused_block_count": self.context.kv_block_allocator.paused_count, + "total_active_used_blocks": self.context.kv_block_allocator.get_active_used(), + "total_paused_used_blocks": self.context.kv_block_allocator.get_paused_used(), + } + context_state = {**pre_step_context_state, **post_step_context_state} else: - kvcache_util_stats = None - - post_step_context_state = { - "waiting_request_count": len(self.waiting_request_ids), - "finished_request_count": self.finished_request_count, - "evicted_request_count": self.evicted_request_count, - "kv_stats": kvcache_util_stats, - "padded_active_token_count": self.context.padded_active_token_count, - "using_cuda_graph_this_step": self.context.using_cuda_graph_this_step(), - "total_active_block_count": self.context.kv_block_allocator.active_count, 
- "total_paused_block_count": self.context.kv_block_allocator.paused_count, - "total_active_used_blocks": self.context.kv_block_allocator.get_active_used(), - "total_paused_used_blocks": self.context.kv_block_allocator.get_paused_used(), - } - - context_state = {**pre_step_context_state, **post_step_context_state} + # Keep kv_stats=None so the metrics-block gate at `async_bookkeep` + # (`if context_state["kv_stats"] is not None`) remains well-typed. + context_state = {**pre_step_context_state, "kv_stats": None} return result, context_state, step_time @@ -1728,7 +1774,7 @@ async def async_bookkeep( accepted_tokens = step_result["accepted_tokens"] log_probs = step_result["log_probs"] top_n_logprobs = step_result.get("top_n_logprobs", None) - routing_indices_per_request = step_result.get("routing_indices_per_request", None) + finished_routing_block_ids = step_result.get("finished_routing_block_ids", None) cuda_graph_request_count = step_result["cuda_graph_request_count"] # Add paused events. @@ -1746,9 +1792,9 @@ async def async_bookkeep( accepted_tokens, log_probs, top_n_logprobs, - routing_indices_per_request, pre_fwd_active_token_count=context_state.get("active_token_count"), pre_fwd_step_count=context_state.get("step_count"), + finished_routing_block_ids=finished_routing_block_ids, ) else: @@ -1811,6 +1857,7 @@ async def async_bookkeep( self.context.prefix_cache_blocks_matched = 0 # Log KV cache utilization stats to W&B + nvtx_range_push("wandb_logging") if context_state["kv_stats"] is not None: # Prepare metrics dictionary with all stats # Use 'inference/' prefix for all metrics to separate from training metrics @@ -1850,13 +1897,17 @@ async def async_bookkeep( self.metrics_writer.log(metrics, commit=True) else: raise ValueError(f"Unsupported metrics writer type: {type(self.metrics_writer)}") + nvtx_range_pop("wandb_logging") # Print context state. + nvtx_range_push("console_logging") if ( self.logging_step_interval > 0 and self.context.step_count % self.logging_step_interval == 0 ): + nvtx_range_push("cuda_memory_stats") mem = torch.cuda.memory_stats() + nvtx_range_pop("cuda_memory_stats") step_type = "decode" if context_state["is_decode_only"] else "non-decode" output_str = ( "* rank %d | step %d | %s ... time: %.3f ms%s ... " @@ -1923,6 +1974,8 @@ async def async_bookkeep( self._prefix_cache_hits = 0 self._prefix_cache_blocks_matched = 0 + nvtx_range_pop("console_logging") + return { "active_request_ids": active_request_ids, "finished_request_records": finished_request_records, @@ -2051,28 +2104,12 @@ def schedule_requests(self) -> int: except zmq.Again: # This exception is hit as soon as the socket is empty. break - messages_to_dequeue = len(all_messages) - # First publish the number of messages to dequeue. - # This is important because we want all tensor parallel ranks - # to dequeue the same number of messages. - self.model_parallel_num_msgs_publisher_socket.send( - struct.pack('!i', messages_to_dequeue) + self.model_parallel_publisher_socket.send_multipart( + [bytes([Headers.TP_BROADCAST.value])] + all_messages ) - # Now publish the actual messages to all model parallel ranks - if messages_to_dequeue > 0: - self.model_parallel_publisher_socket.send_multipart(all_messages) else: - # First, receive the number of messages to dequeue from mp-rank 0 - messages_to_dequeue = struct.unpack( - '!i', self.model_parallel_num_msgs_subscriber_socket.recv() - )[0] - # Now, dequeue the same number of messages from the subscriber socket. 
- # Note that these receives are blocking, because the messages - # are guaranteed to be available after the tp-rank 0 has sent them. - if messages_to_dequeue > 0: - all_messages = self.model_parallel_subscriber_socket.recv_multipart() - else: - all_messages = [] + frames = self.model_parallel_subscriber_socket.recv_multipart() + all_messages = frames[1:] nvtx_range_pop("drain_zmq_socket") @@ -2299,9 +2336,50 @@ async def run_engine_with_coordinator( local_pending = self.context.get_active_request_count() + len( self.waiting_request_ids ) - global_work, all_pausing = await self._ep_establish_consensus( - local_pending, signal_consensus=(self.state == EngineState.PAUSING) - ) + if self.disable_ep_consensus: + # Skip the EP consensus all-reduce; act on local state only. + # NOTE: even with no consensus we must still participate in EP + # collectives (NCCL all-to-all, etc.) every iteration. A peer with + # real work will block at its all-to-all kernel waiting for this + # rank, so when there is no local work we run dummy_forward() + # rather than sleeping. Sleeping here would deadlock EP > 1. + if self.state == EngineState.PAUSING: + await self._world_barrier() + self.state = EngineState.PAUSED + self._state_events[EngineState.PAUSED].set() + elif local_pending > 0: + await self.async_step() + else: + self.step_start_event.record() + nvtx_range_push("EP-dummy-forward") + self.controller.dummy_forward() + self.step_end_event.record() + self.step_end_event.synchronize() + nvtx_range_pop("EP-dummy-forward") + self.context.step_count += 1 + self.context.prefix_cache_lru_clock += 1 + # The consensus path yields via _ep_establish_consensus; + # without it we must still let other coroutines (signal + # delivery, request scheduling) run between steps. + await asyncio.sleep(0) + continue + global_work_from_last_consensus, _ = self._last_ep_consensus + if ( + global_work_from_last_consensus == 0 + or self._ep_consensus_loop_counter % 20 == 0 + ): + # selectively enter ep_establish_consensus if + # 1. there is no global work -> engine is idle. At any step in the future + # one of the ranks can receive work. So we should be eagerly checking for that + # 2. it has been 20 steps since we last established consensus, and that consensus + # had some work. + # In the worst case, this delays pausing by 20 steps which is around + # 200-400 milliseconds. + self._last_ep_consensus = await self._ep_establish_consensus( + local_pending, signal_consensus=(self.state == EngineState.PAUSING) + ) + global_work, all_pausing = self._last_ep_consensus + self._ep_consensus_loop_counter += 1 if all_pausing: # All EP peers are PAUSING: pause immediately. @@ -2315,9 +2393,11 @@ async def run_engine_with_coordinator( else: # Dummy forward to participate in the EP collective. self.step_start_event.record() + nvtx_range_push("EP-dummy-forward") self.controller.dummy_forward() self.step_end_event.record() self.step_end_event.synchronize() + nvtx_range_pop("EP-dummy-forward") self.context.step_count += 1 self.context.prefix_cache_lru_clock += 1 else: @@ -2332,6 +2412,10 @@ async def run_engine_with_coordinator( self.state = EngineState.RUNNING self._state_events[EngineState.PAUSED].clear() self._state_events[EngineState.RUNNING].set() + # The cache from the PAUSING phase still has all_pausing=True; + # without this reset the next RUNNING iteration would skip + # consensus, read the stale flag, and immediately re-pause. 
+ self._last_ep_consensus = (0, False) elif self.state == EngineState.SUSPENDING: await self._world_barrier() diff --git a/megatron/core/inference/headers.py b/megatron/core/inference/headers.py index aa2f0568975..8ad1913e6b1 100644 --- a/megatron/core/inference/headers.py +++ b/megatron/core/inference/headers.py @@ -20,6 +20,7 @@ class Headers(Enum): STOP = auto() DISCONNECT = auto() SHUTDOWN = auto() + TP_BROADCAST = auto() class UnknownHeaderError(Exception): diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index 4c93b9024d5..33fbcdf6518 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -1,13 +1,14 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import copy +import hashlib import time import warnings from dataclasses import asdict, dataclass, field from enum import Enum, auto -from itertools import accumulate from typing import Any, Dict, List, Optional, Tuple +import numpy as np import torch from megatron.core.inference.sampling_params import SamplingParams @@ -46,6 +47,16 @@ def deserialize_tensor(tensor_as_list: List) -> torch.Tensor: return tensor +def serialize_ndarray(arr: np.ndarray) -> dict: + """Serialize numpy array to a JSON-compatible dict.""" + return {"data": arr.tolist(), "dtype": str(arr.dtype)} + + +def deserialize_ndarray(obj: dict) -> np.ndarray: + """Deserialize numpy array from dict.""" + return np.array(obj["data"], dtype=np.dtype(obj["dtype"])) + + def unwrap_serialized_tensors(serialized_request: dict) -> dict: """Unwrap ("tensor", [...]) tuples produced by serialize() into plain lists. @@ -76,53 +87,44 @@ class Status(Enum): # Hash computation for prefix caching # ========================================================================= -# Constants for hash computation -# Using 2^61 - 1 (Mersenne prime) for ~10^18 hash space, reducing collision probability -# from ~10^-9 to ~10^-18 compared to the previous prime (1000000007). -HASH_PRIME = 2305843009213693951 -HASH_BASE = 31 - -_hash_powers: Optional[torch.Tensor] = None - def compute_block_hashes_batched(prompt_tokens: torch.Tensor, block_size: int) -> List[int]: - """Compute hashes for all complete blocks in a prompt in one batched operation. + """Compute SHA-256 based hashes for all complete blocks in a prompt. - Reshapes prompt tokens into [num_blocks, block_size], computes all per-block - token hashes via a single GPU matmul, transfers results with one .tolist() call, - and chains parent hashes on CPU. + Each block hash is computed as SHA-256(parent_digest || block_bytes), where + parent_digest chains from the previous block (starting from a zero digest). + This provides cryptographic collision resistance with no exploitable algebraic + structure. Args: prompt_tokens: All prompt token IDs, shape [seq_len]. block_size: Number of tokens per block. Returns: - List of positive integer hash values (1 to HASH_PRIME), one per complete block. + List of positive integer hash values in [1, 2^63-1], one per complete block. 
""" num_complete_blocks = len(prompt_tokens) // block_size if num_complete_blocks == 0: return [] - global _hash_powers - if _hash_powers is None or _hash_powers.shape[0] != block_size: - positions = torch.arange(block_size, device=prompt_tokens.device, dtype=torch.int64) - _hash_powers = torch.pow(HASH_BASE, positions).to(torch.int64) % HASH_PRIME + # Single GPU->CPU transfer, get contiguous bytes + tokens_cpu = prompt_tokens[: num_complete_blocks * block_size].to(torch.int64).cpu() + tokens_bytes = tokens_cpu.numpy().tobytes() + block_byte_size = block_size * tokens_cpu.element_size() # 8 bytes per int64 - # Reshape to [num_blocks, block_size] (zero-copy view) and compute all token hashes - blocks = prompt_tokens[: num_complete_blocks * block_size].view(num_complete_blocks, block_size) - token_hashes = (blocks.to(torch.int64) * _hash_powers).sum(dim=1) % HASH_PRIME + hashes = [] + parent_digest = b'\x00' * 32 # SHA-256 digest size - # Single GPU→CPU transfer - token_hashes_list = token_hashes.tolist() + for i in range(num_complete_blocks): + block_bytes = tokens_bytes[i * block_byte_size : (i + 1) * block_byte_size] + digest = hashlib.sha256(parent_digest + block_bytes).digest() - # Chain parent hashes on CPU (C-level accumulate, no Python loop) - hashes = list( - accumulate( - token_hashes_list, - lambda parent, th: (parent * HASH_BASE + th) % HASH_PRIME + 1, - initial=0, - ) - )[1:] + # Map to positive int64 range [1, 2^63-1], avoiding sentinels -1 and 0 + raw = int.from_bytes(digest[:8], byteorder='little', signed=False) + hash_val = (raw % (2**63 - 1)) + 1 + + hashes.append(hash_val) + parent_digest = digest # Full 32-byte digest chains into next block return hashes @@ -180,9 +182,13 @@ def serialize(self) -> dict: self.inference_parameters.serialize() if self.inference_parameters else None ) - # Serialize tensors. + # Serialize tensors and numpy arrays. obj = { - k: (("tensor", serialize_tensor(v)) if isinstance(v, torch.Tensor) else v) + k: ( + ("tensor", serialize_tensor(v)) + if isinstance(v, torch.Tensor) + else ("ndarray", serialize_ndarray(v)) if isinstance(v, np.ndarray) else v + ) for k, v in obj.items() } return obj @@ -221,10 +227,12 @@ def _post_deserialize(self, obj: dict): else SamplingParams.deserialize(obj["inference_parameters"]) ) - # Deserialize tensors and sampling params. + # Deserialize tensors, numpy arrays, and sampling params. for k, v in obj.items(): if isinstance(v, list) and len(v) == 2 and v[0] == "tensor": setattr(self, k, deserialize_tensor(v[1])) + elif isinstance(v, list) and len(v) == 2 and v[0] == "ndarray": + setattr(self, k, deserialize_ndarray(v[1])) class DynamicInferenceEventType(Enum): @@ -361,9 +369,8 @@ class DynamicInferenceRequest(InferenceRequest): policy_epoch: Optional[list[tuple[int, int]]] = None kv_cache_epoch: Optional[list[tuple[int, int]]] = None latency: Optional[float] = None - # routing_indices stores MoE routing decisions for all tokens generated so far. - # Shape: [total_tokens, num_layers, topk] - accumulated across all generation steps - routing_indices: Optional[torch.Tensor] = None + # routing_indices is reconstructed from per-block storage when a request finishes. 
+ routing_indices: Optional[np.ndarray] = None finished_chunk_token_count: int = 0 stop_word_ids: Optional[List[List[int]]] = None # Tokenized stop words (populated internally) @@ -434,7 +441,7 @@ def serialize(self): obj["events"] = [e.serialize() for e in self.events] obj.pop("event_add_engine", None) - # Sanity check routing_indices: Tensor [total_tokens - 1, num_layers, topk] + # Sanity check routing_indices: ndarray [total_tokens - 1, num_layers, topk] if self.routing_indices is not None: total_tokens = len(self.prompt_tokens) + len(self.generated_tokens) # the last generated token does not undergo a forward pass @@ -469,26 +476,25 @@ def tracked_metadata(self) -> List[Any]: "in its sampling_params. Defaulting to -1." ) sp.termination_id = -1 - return [getattr(sp, field) for field, _, _ in self.get_metadata_types()] + return [getattr(sp, field) for field, _ in self.get_metadata_types()] @staticmethod - def get_metadata_types() -> List[Tuple[str, torch.dtype, bool]]: - """Keeps track of all request metadata names, dtypes, and target device. + def get_metadata_types() -> List[Tuple[str, torch.dtype]]: + """Keeps track of all request metadata names and dtypes. Returns: - List[Tuple[str, torch.dtype, bool]]: Mapping from metadata name to: + List[Tuple[str, torch.dtype]]: Mapping from metadata name to: name (str) - The name of the metadata field. dtype (torch.dtype) - The datatype of the metadata. - on_device (bool) - Whether the metadata lives on GPU (True) or CPU (False). """ return [ - ("temperature", torch.float32, False), # CPU for torch sampling - ("top_k", torch.int32, False), # CPU for torch sampling - ("top_p", torch.float32, False), # CPU for torch sampling - ("termination_id", torch.int64, True), - ("return_log_probs", torch.bool, False), # CPU for non-selective logprobs - ("skip_prompt_log_probs", torch.bool, False), # CPU for non-selective logprobs - ("top_n_logprobs", torch.int32, False), # CPU for torch sampling + ("temperature", torch.float32), + ("top_k", torch.int32), + ("top_p", torch.float32), + ("termination_id", torch.int64), + ("return_log_probs", torch.bool), + ("skip_prompt_log_probs", torch.bool), + ("top_n_logprobs", torch.int32), ] def add_event( @@ -695,8 +701,9 @@ def merge_lists(key): prompt_tokens = self.requests[0].prompt_tokens prompt_text = self.requests[0].prompt routing_indices = None - if self.requests[0].routing_indices is not None: - routing_indices = torch.cat([r.routing_indices for r in self.requests]) + routing_parts = [r.routing_indices for r in self.requests if r.routing_indices is not None] + if routing_parts: + routing_indices = np.concatenate(routing_parts) generated_tokens = merge_lists("generated_tokens") try: generated_text = "".join(r.generated_text for r in self.requests) diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 55efb24cb08..5fbbcc376f3 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -126,34 +126,6 @@ def _forward(self, inference_input): runtime_gather_output=True, # Inference should always gather the logits ) - @torch.inference_mode() - def dummy_forward(self): - """Run a dummy forward pass through the model, with a single token. - Use-case: Used in EP on ranks which do not have any work, but are needed - for the all-to-all communication. 
- Runs under inference_mode so that transformer layers can distinguish this eager - dummy_forward from training/validation passes and skip matching on CUDA graphs.""" - - # we use num_dummy_tokens equal to tensor model parallel size - # so that the dummy forward pass will work with sequence parallel - num_dummy_tokens = self.tp_size - tokens = torch.zeros( - (1, num_dummy_tokens), dtype=torch.long, device=torch.cuda.current_device() - ) - position_ids = torch.zeros( - (1, num_dummy_tokens), dtype=torch.long, device=torch.cuda.current_device() - ) - attention_mask = None - # Always skip MTP during dummy forwards. When num_speculative_tokens > 0 - # the serial MTP path handles MTP separately (with its own dummy forward). - # When num_speculative_tokens == 0 MTP is not needed at all. In both - # cases, running MTP here would issue MoE all-to-all collectives that the - # real EP ranks do not execute, causing a hang. - is_spec_decode = ( - self.inference_context.is_dynamic_batching() and self.config.mtp_num_layers is not None - ) - return self.model(tokens, position_ids, attention_mask, is_spec_decode=is_spec_decode) - def _get_batch_size_and_seq_len( self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None ): diff --git a/megatron/core/inference/moe/__init__.py b/megatron/core/inference/moe/__init__.py index dbbb24f07bf..cc64fb65110 100644 --- a/megatron/core/inference/moe/__init__.py +++ b/megatron/core/inference/moe/__init__.py @@ -2,55 +2,17 @@ import enum -import torch - from .fused_moe import ActivationType, mcore_fused_moe +from .vllm_fused_moe import vllm_fused_moe class InferenceGroupedGemmBackend(enum.Enum): - """Resolved backend for grouped GEMM operations during inference.""" + """Backend for grouped GEMM operations during inference. + + The string value matches the inference_grouped_gemm_backend config field so + TransformerConfig.__post_init__ can convert via InferenceGroupedGemmBackend(str). + """ FLASHINFER = "flashinfer" TORCH = "torch" - TE = "te" - - -def resolve_inference_grouped_gemm_backend( - backend: str, is_cuda_graphed: bool, is_mxfp8: bool = False -) -> InferenceGroupedGemmBackend: - """Resolve the grouped GEMM backend to use for the current iteration. - - Prerequisites are validated at init time in MoELayer; this function - simply maps (backend, is_cuda_graphed) to the concrete backend enum. - - Args: - backend: One of 'auto', 'torch', 'te'. - is_cuda_graphed: Whether this is a CUDA-graphed iteration. - is_mxfp8: Whether the model is using MXFP8 quantization (affects auto backend choice). - Returns: - An InferenceGroupedGemmBackend enum value. - """ - if backend == 'auto': - if is_mxfp8: - assert hasattr(torch.nn.functional, 'scaled_grouped_mm'), ( - "Auto backend selection for MXFP8 requires " - "torch.nn.functional.scaled_grouped_mm. " - "Please install PyTorch 2.10+." - ) - return InferenceGroupedGemmBackend.TORCH - if is_cuda_graphed: - return InferenceGroupedGemmBackend.FLASHINFER - else: - if hasattr(torch.nn.functional, 'grouped_mm'): - return InferenceGroupedGemmBackend.TORCH - else: - return InferenceGroupedGemmBackend.TE - elif backend == 'torch': - return InferenceGroupedGemmBackend.TORCH - elif backend == 'te': - return InferenceGroupedGemmBackend.TE - else: - raise ValueError( - f"Unknown inference_grouped_gemm_backend: '{backend}'. " - "Must be 'auto', 'torch', or 'te'." 
- ) + VLLM = "vllm" diff --git a/megatron/core/inference/moe/activations.py b/megatron/core/inference/moe/activations.py index 169d8499116..ae5e4560ce3 100644 --- a/megatron/core/inference/moe/activations.py +++ b/megatron/core/inference/moe/activations.py @@ -30,25 +30,53 @@ def _ceil_div(a, b): @triton.jit -def _squared_relu_kernel(input_ptr, output_ptr, src_idx_ptr, M, N, BLOCK_N: tl.constexpr): - """Squared ReLU that skips padding rows (permutation_map == -1).""" - row = tl.program_id(0) - if tl.load(src_idx_ptr + row) < 0: - return - for n in tl.range(0, N, BLOCK_N): - o = n + tl.arange(0, BLOCK_N) - m = o < N - x = tl.load(input_ptr + row * N + o, mask=m).to(tl.float32) - r = tl.maximum(x, 0.0) - tl.store(output_ptr + row * N + o, (r * r).to(tl.bfloat16), mask=m) +def _squared_relu_kernel( + input_ptr, + output_ptr, + src_idx_ptr, + n_used_ptr, + N, + max_rows, # output_size (fixed for CG) + BLOCK_N: tl.constexpr, + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) +): + """Squared ReLU that skips rows beyond n_used and alignment-padding rows (perm_map == -1). + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. + n_used_ptr gates how many rows are processed — required for CUDA graph compatibility. + """ + pid = tl.program_id(0) + n_used = tl.load(n_used_ptr) + if pid >= n_used: + return + for row in tl.range(pid, max_rows, NUM_BLOCKS): + if row < n_used: + if tl.load(src_idx_ptr + row) >= 0: + for n in tl.range(0, N, BLOCK_N): + o = n + tl.arange(0, BLOCK_N) + m = o < N + x = tl.load(input_ptr + row * N + o, mask=m).to(tl.float32) + r = tl.maximum(x, 0.0) + tl.store(output_ptr + row * N + o, (r * r).to(tl.bfloat16), mask=m) + + +def padded_squared_relu( + x: torch.Tensor, permutation_map: torch.Tensor, n_used: torch.Tensor +) -> torch.Tensor: + """Squared ReLU activation that skips rows beyond n_used and alignment-padding rows. -def padded_squared_relu(x: torch.Tensor, permutation_map: torch.Tensor) -> torch.Tensor: - """Squared ReLU activation that skips padding rows.""" + Args: + x: [output_size, ffn_hidden] BF16 FC1 output. + permutation_map: [output_size] int32, original token index or -1 for padding. + n_used: scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. + """ M, N = x.shape - out = torch.zeros(M, N, dtype=x.dtype, device=x.device) + out = torch.empty(M, N, dtype=x.dtype, device=x.device) BLOCK_N = min(triton.next_power_of_2(N), 1024) - _squared_relu_kernel[(M,)](x, out, permutation_map, M, N, BLOCK_N=BLOCK_N) + NUM_BLOCKS = min(M, 512) + _squared_relu_kernel[(NUM_BLOCKS,)]( + x, out, permutation_map, n_used, N, M, BLOCK_N=BLOCK_N, NUM_BLOCKS=NUM_BLOCKS + ) return out @@ -58,68 +86,71 @@ def _squared_relu_quantize_kernel( out_fp8_ptr, out_scale_ptr, src_idx_ptr, + n_used_ptr, # pointer to inclusive_expert_offsets[-1]: number of used rows this iteration K, n_col_blocks, - skip_padding: tl.constexpr, + max_rows, # output_size (fixed for CG) REAL_GROUPS: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_GROUPS: tl.constexpr, + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) ): """Fused squared ReLU + MXFP8 quantize + swizzle in one kernel. - Grid: (M,) — one program per row. - Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8, - writes FP8 data + swizzled scales in place. + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. + Rows beyond n_used and alignment-padding rows (perm_map == -1) are skipped. 
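For reference, the row gating these kernels implement can be written eagerly in PyTorch. The sketch below mirrors the semantics of padded_squared_relu (rows at or past n_used and alignment-padding rows are left untouched, everything else gets relu(x)**2); it is only an illustration, not the Triton path:

import torch

def padded_squared_relu_reference(x, permutation_map, n_used):
    out = torch.empty_like(x)                       # untouched rows stay uninitialized, as in the kernel
    n = int(n_used)                                 # read on the device inside the real kernel
    keep = permutation_map[:n] >= 0                 # real rows within the used prefix
    rows = x[:n][keep].float()
    out[:n][keep] = torch.clamp(rows, min=0.0).pow(2).to(x.dtype)
    return out

x = torch.randn(6, 4).to(torch.bfloat16)
pmap = torch.tensor([0, 1, -1, 2, -1, -1], dtype=torch.int32)
y = padded_squared_relu_reference(x, pmap, torch.tensor(4, dtype=torch.int32))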
""" - row = tl.program_id(0) - if skip_padding: - if tl.load(src_idx_ptr + row) < 0: - return - - offs = tl.arange(0, BLOCK_K) - mask = offs < K - - # Load and apply squared ReLU - x = tl.load(input_ptr + row * K + offs, mask=mask, other=0.0).to(tl.float32) - relu = tl.maximum(x, 0.0) - activated = relu * relu - - # Per-group-of-32 quantization - x_grouped = tl.reshape(activated, [BLOCK_GROUPS, 32]) - abs_grouped = tl.abs(x_grouped) - max_vals = tl.max(abs_grouped, axis=1) - - dequant_scale = max_vals / 448.0 - dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 - dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) - quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) - - quantized = x_grouped * quant_scale[:, None] - quantized_flat = tl.reshape(quantized, [BLOCK_K]) - out_fp8 = quantized_flat.to(tl.float8e4nv) - - # Store FP8 data - tl.store(out_fp8_ptr + row * K + offs, out_fp8, mask=mask) - - # Store swizzled scales - scale_exp = (dequant_exp >> 23).to(tl.uint8) - col_offs = tl.arange(0, BLOCK_GROUPS) - col_mask = col_offs < REAL_GROUPS - - macro_row_block = row // 128 - macro_col_block = col_offs // 4 - local_row = row % 128 - local_col = col_offs % 4 - group = local_row // 32 - sub_row = local_row % 32 - tile_idx = macro_row_block * n_col_blocks + macro_col_block - swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col - - tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) + pid = tl.program_id(0) + n_used = tl.load(n_used_ptr) + if pid >= n_used: + return + for row in tl.range(pid, max_rows, NUM_BLOCKS): + if row < n_used: + if tl.load(src_idx_ptr + row) >= 0: + offs = tl.arange(0, BLOCK_K) + mask = offs < K + + # Load and apply squared ReLU + x = tl.load(input_ptr + row * K + offs, mask=mask, other=0.0).to(tl.float32) + relu = tl.maximum(x, 0.0) + activated = relu * relu + + # Per-group-of-32 quantization + x_grouped = tl.reshape(activated, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + dequant_scale = max_vals / 448.0 + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data + tl.store(out_fp8_ptr + row * K + offs, out_fp8, mask=mask) + + # Store swizzled scales + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + macro_row_block = row // 128 + macro_col_block = col_offs // 4 + local_row = row % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) def squared_relu_and_quantize_mxfp8( - x: torch.Tensor, permutation_map: torch.Tensor, skip_padding: bool = True + x: torch.Tensor, permutation_map: torch.Tensor, n_used: torch.Tensor ): """Fused squared ReLU + MXFP8 quantize + swizzle. @@ -127,12 +158,13 @@ def squared_relu_and_quantize_mxfp8( swizzled scales. Single kernel replaces padded_squared_relu + mxfp8_quantize. Args: - x: [M, K] BF16 FC1 output. 
- permutation_map: [M] int32, original token index or -1 for padding. - skip_padding: if True, skip rows where permutation_map == -1. + x: [output_size, K] BF16 FC1 output. + permutation_map: [output_size] int32, original token index or -1 for padding. + n_used: scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. Rows beyond + this are skipped before even checking the permutation_map. Returns: - MXFP8Tensor with .data [M, K] float8_e4m3fn and .scale (swizzled e8m0). + MXFP8Tensor with .data [output_size, K] float8_e4m3fn and .scale (swizzled e8m0). """ from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor @@ -149,18 +181,21 @@ def squared_relu_and_quantize_mxfp8( BLOCK_K = triton.next_power_of_2(K) BLOCK_GROUPS = BLOCK_K // 32 + NUM_BLOCKS = min(M, 512) - _squared_relu_quantize_kernel[(M,)]( + _squared_relu_quantize_kernel[(NUM_BLOCKS,)]( x, out_fp8, out_scale, permutation_map, + n_used, K, n_col_blocks, - skip_padding, + M, REAL_GROUPS=scale_cols, BLOCK_K=BLOCK_K, BLOCK_GROUPS=BLOCK_GROUPS, + NUM_BLOCKS=NUM_BLOCKS, ) return MXFP8Tensor(data=out_fp8, scale=out_scale.view(torch.float8_e8m0fnu), backend="triton") diff --git a/megatron/core/inference/moe/fused_moe.py b/megatron/core/inference/moe/fused_moe.py index 39382eee079..f6c0af4e94e 100644 --- a/megatron/core/inference/moe/fused_moe.py +++ b/megatron/core/inference/moe/fused_moe.py @@ -6,7 +6,7 @@ """ from enum import Enum -from typing import Callable, Optional +from typing import Callable import torch @@ -14,7 +14,6 @@ padded_squared_relu, squared_relu_and_quantize_mxfp8, ) -from megatron.core.inference.moe.pad import pad_to_alignment, unpad_from_alignment from megatron.core.inference.moe.permute import ( permute_and_quantize_mxfp8, permute_tokens, @@ -27,7 +26,9 @@ HAVE_GROUPED_MM = True except ImportError: - HAVE_GROUPED_MM = False + # Fallback to the private symbol for torch versions < 2.10. + grouped_mm = getattr(torch, "_grouped_mm", None) + HAVE_GROUPED_MM = grouped_mm is not None try: from torch.nn.functional import ScalingType, SwizzleType, scaled_grouped_mm @@ -86,49 +87,46 @@ def mcore_fused_moe( activation_type: ActivationType, num_local_experts: int, local_expert_start: int, - routing_map: Optional[torch.Tensor] = None, - tokens_per_expert: Optional[torch.Tensor] = None, - skip_permute: bool = False, + valid_tokens: torch.Tensor, + routing_map: torch.Tensor, disable_fused_quant_kernels: bool = False, + out: torch.Tensor = None, ) -> torch.Tensor: - """Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute]. - - Two modes: - - skip_permute=False (default): tokens are unpermuted. Requires routing_map. - Performs full permute -> compute -> unpermute. - - skip_permute=True: tokens are already permuted by the dispatcher. Requires - tokens_per_expert. Pads to alignment, computes, then unpads. Probs are - applied during unpad. + """Fused MoE: permute -> pad -> FC1 -> activation -> FC2 -> unpad -> unpermute. Unless disable_fused_quant_kernels=True, when weights are MXFP8, uses fused kernels that combine permute/activation with MXFP8 quantization into single kernel launches. Args: - hidden_states: [num_tokens, hidden_size] BF16 input. - probs: routing probabilities. Shape is [num_tokens, topk] when - skip_permute=False, or [num_tokens] (already gathered) when - skip_permute=True. + hidden_states: [max_tokens, hidden_size] BF16 input. max_tokens = + max_local_tokens * ep_size; only the first valid_tokens rows are valid. + probs: [max_tokens, topk] routing probabilities. 
fc1_weight: stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8). fc2_weight: stacked weight for FC2 (same type as fc1_weight). activation_type: ActivationType enum (SQUARED_RELU). num_local_experts: number of experts on this rank. local_expert_start: first global expert index on this rank. - routing_map: [num_tokens, topk] int expert assignments. Required when skip_permute=False. - tokens_per_expert: [num_local_experts] int32 token counts. Required when skip_permute=True. - skip_permute: if True, skip permute/unpermute (tokens already in expert order). + valid_tokens: scalar int32 CUDA tensor holding the number of valid tokens this + iteration. Kernels use this to ignore rows beyond the valid prefix — required + for CUDA graph compatibility since hidden_states is always max-sized. + routing_map: [max_tokens, topk] int expert assignments. disable_fused_quant_kernels: if True, disable fused permute+quantize and activation+quantize kernels for MXFP8, using separate launches instead. Useful for debugging. Ignored when weights are BF16. + out: optional pre-allocated output buffer. If provided, unpermute writes + directly into this tensor (e.g. the RSV symmetric buffer), avoiding a + separate copy before reduce-scatter. Returns: - [num_tokens, hidden_size] BF16 output. + [max_tokens, hidden_size] BF16 output. Only the first valid_tokens rows are + meaningful; rows beyond that are undefined. """ assert ( hidden_states.dtype == torch.bfloat16 ), f"mcore_fused_moe requires bf16 input, got {hidden_states.dtype}" - num_tokens = hidden_states.shape[0] + max_tokens = hidden_states.shape[0] use_mxfp8 = isinstance(fc1_weight, MXFP8Tensor) # Fused quant kernels only apply to MXFP8 path use_fused_quant = use_mxfp8 and not disable_fused_quant_kernels @@ -151,54 +149,47 @@ def mcore_fused_moe( activation_func = _get_activation_func(activation_type, fused_quant=use_fused_quant) - # --- Pre-processing: permute or pad --- - if skip_permute: - assert tokens_per_expert is not None, "tokens_per_expert is required when skip_permute=True" - tokens_per_expert = tokens_per_expert.cuda().int() - assert routing_map is None, "routing_map must be None when skip_permute=True" - hidden_states, permutation_map, offs = pad_to_alignment( - hidden_states, tokens_per_expert, expert_alignment + # --- Pre-processing: permute --- + if use_fused_quant: + # Fused permute + MXFP8 quantize: single kernel produces MXFP8Tensor + hidden_states, permuted_probs, permutation_map, offs = permute_and_quantize_mxfp8( + hidden_states, + probs, + routing_map, + local_expert_start, + num_local_experts, + valid_tokens, + alignment=expert_alignment, ) - permuted_probs = None - else: - assert routing_map is not None, "routing_map is required when skip_permute=False" - if use_fused_quant: - # Fused permute + MXFP8 quantize: single kernel produces MXFP8Tensor - hidden_states, permuted_probs, permutation_map, offs = permute_and_quantize_mxfp8( - hidden_states, - probs, - routing_map, - local_expert_start, - num_local_experts, - alignment=expert_alignment, - ) - else: - hidden_states, permuted_probs, permutation_map, offs = permute_tokens( - hidden_states, - probs, - routing_map, - local_expert_start, - num_local_experts, - alignment=expert_alignment, - ) + hidden_states, permuted_probs, permutation_map, offs = permute_tokens( + hidden_states, + probs, + routing_map, + local_expert_start, + num_local_experts, + valid_tokens, + alignment=expert_alignment, + ) # --- FC1 -> activation -> FC2 --- # Quantize if MXFP8 path and hidden_states 
not already quantized (fused permute+quant - # produces MXFP8Tensor directly; skip_permute path always needs separate quant). - needs_quant = use_mxfp8 and not isinstance(hidden_states, MXFP8Tensor) - if needs_quant: + # produces MXFP8Tensor directly). + if use_mxfp8 and not isinstance(hidden_states, MXFP8Tensor): hidden_states = MXFP8Tensor.from_bf16(hidden_states, backend="triton") fc1_output = mm_fn(hidden_states, fc1_weight, offs) - activation_out = activation_func(fc1_output, permutation_map) + # offs[-1:] is a 1-element view pointing to inclusive_expert_offsets[-1] — the total + # number of rows actually used by experts this iteration (valid tokens + alignment + # padding within expert blocks). Passed to activation and unpermute to skip unused rows. + n_used = offs[-1:] + activation_out = activation_func(fc1_output, permutation_map, n_used) # Fused activation+quant returns MXFP8Tensor; otherwise quantize separately. if use_mxfp8 and not isinstance(activation_out, MXFP8Tensor): activation_out = MXFP8Tensor.from_bf16(activation_out, backend="triton") fc2_output = mm_fn(activation_out, fc2_weight, offs) - # --- Post-processing: unpermute or unpad --- - if skip_permute: - probs_1d = probs.squeeze(-1) if probs.dim() > 1 else probs - return unpad_from_alignment(fc2_output, permutation_map, num_tokens, probs=probs_1d) - else: - return unpermute_tokens(fc2_output, permuted_probs, permutation_map, num_tokens) + + # --- Post-processing: unpermute --- + return unpermute_tokens( + fc2_output, permuted_probs, permutation_map, max_tokens, n_used, valid_tokens, out=out + ) diff --git a/megatron/core/inference/moe/metadata.py b/megatron/core/inference/moe/metadata.py new file mode 100644 index 00000000000..8658ab9b42a --- /dev/null +++ b/megatron/core/inference/moe/metadata.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Fused NVLS metadata update kernel for MoE expert parallelism. + +Replaces the multi-kernel sequence: + dist.all_gather_into_tensor(...) # NCCL + local_tokens_per_rank.sum() # kernel + local_tokens_per_rank[:rank].sum() # kernel + local_tokens_per_rank.max() # kernel + _step_metadata.copy_(...) # kernel + +with a single Triton kernel that: + 1. Multicast-stores this rank's local_tokens to the symmetric memory buffer. + 2. Barrier (all ranks have written). + 3. Reads all ranks' counts, computes sum / prefix-sum / max. + 4. Writes the 3-element step_metadata tensor in-place. +""" + +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + HAVE_TRITON = False + +try: + from torch._C._distributed_c10d import _SymmetricMemory +except ImportError: + _SymmetricMemory = MagicMock() + +from megatron.core.inference.communication.torch_symm_triton.barrier import symm_mem_sync +from megatron.core.inference.communication.torch_symm_triton.multimem_asm import st_32 +from megatron.core.inference.communication.torch_symm_triton.utils import sync_threads + + +@triton.jit +def _fused_metadata_kernel( + local_tokens, + local_buf_ptr, + multicast_ptr, + signal_pad_ptrs, + step_metadata_ptr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Fused allgather + reduce kernel for MoE step metadata. + + Single CTA. 
Writes this rank's local_tokens to the symmetric buffer + via multicast store, barriers, then reads all ranks' values from the + local buffer and computes [valid_tokens, rank_token_offset, ep_max_tokens]. + + Args: + local_tokens: scalar int32, this rank's token count. + local_buf_ptr: pointer to the local symmetric memory buffer (for reads). + multicast_ptr: multicast pointer to the symmetric memory buffer (for writes). + signal_pad_ptrs: signal pads for barrier synchronization. + step_metadata_ptr: pointer to the 3-element int32 output tensor. + RANK: this rank's index (constexpr). + WORLD_SIZE: total number of ranks (constexpr). + """ + + tid = tl.program_id(0) + if tid > 0: + return + + # 1. Multicast-store local_tokens to buffer[RANK]. + mc_ptr = multicast_ptr.to(tl.pointer_type(tl.uint32)) + RANK + mask = tl.full([], 1, dtype=tl.int1) + val = tl.full([], local_tokens, dtype=tl.uint32) + st_32(mc_ptr, val, mask, multicast_op=True) + + # 2. Barrier — wait for all ranks to have written. + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + + # 3. Load all ranks' values, reduce, and write metadata. + offsets = tl.arange(0, WORLD_SIZE) + vals = tl.load(local_buf_ptr + offsets) + + total = tl.sum(vals) + prefix = tl.sum(tl.where(offsets < RANK, vals, tl.zeros_like(vals))) + max_val = tl.max(vals) + + tl.store(step_metadata_ptr, total) + tl.store(step_metadata_ptr + 1, prefix) + tl.store(step_metadata_ptr + 2, max_val) + + +def fused_metadata_update( + local_tokens: int, + local_buf: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + step_metadata: torch.Tensor, +) -> None: + """Fused NVLS allgather + reduce for MoE step metadata. + + Args: + local_tokens: number of tokens on this rank this step. + local_buf: the local symmetric memory buffer tensor ([WORLD_SIZE] int32). + Used for reads after the barrier. + symm_mem_hdl: symmetric memory handle for the metadata buffer. + Provides the multicast pointer for writes and signal pads for barrier. + step_metadata: [3] int32 CUDA tensor to write + [valid_tokens, rank_token_offset, ep_max_tokens] into. + """ + assert HAVE_TRITON, "Triton is required for fused_metadata_update." + + _fused_metadata_kernel[(1, 1, 1)]( + local_tokens, + local_buf, + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + step_metadata, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=min(max(1, (symm_mem_hdl.world_size + 31) // 32), 8), + ) diff --git a/megatron/core/inference/moe/pad.py b/megatron/core/inference/moe/pad.py deleted file mode 100644 index 656953b691c..00000000000 --- a/megatron/core/inference/moe/pad.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -"""Pad / unpad utilities for already-permuted expert tokens. - -When the token dispatcher has already permuted tokens into expert-grouped -order, these functions insert/remove alignment padding so that each expert's -token block satisfies the alignment requirements of grouped_mm / -scaled_grouped_mm. 
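The three values the fused metadata kernel above writes can be stated in plain Python. This reference (hypothetical names, no symmetric memory) only shows what replaces the all_gather / sum / prefix-sum / max sequence listed in the module docstring:

def step_metadata_reference(tokens_per_rank, rank):
    valid_tokens = sum(tokens_per_rank)               # total tokens across all EP ranks
    rank_token_offset = sum(tokens_per_rank[:rank])   # start of this rank's slice in the gathered order
    ep_max_tokens = max(tokens_per_rank)              # largest per-rank count this step
    return [valid_tokens, rank_token_offset, ep_max_tokens]

assert step_metadata_reference([5, 0, 7, 3], rank=2) == [15, 5, 7]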
-""" - -from unittest.mock import MagicMock - -import torch -from packaging import version - -from megatron.core.utils import null_decorator - -try: - import triton - import triton.language as tl - - if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available(): - HAVE_TRITON = False - else: - HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0")) -except ImportError: - HAVE_TRITON = False - -if not HAVE_TRITON: - triton = MagicMock() - triton.jit = null_decorator - tl = MagicMock() - -from megatron.core.inference.moe.permute import compute_expert_offsets - - -@triton.jit -def _pad_tokens_kernel( - src_ptr, - dst_ptr, - perm_map_ptr, - tpe_ptr, # tokens_per_expert [num_experts] - hidden_dim, - num_experts: tl.constexpr, - alignment: tl.constexpr, - BLOCK_H: tl.constexpr, -): - """Copy one input row into the padded output buffer. - - Computes unpadded and padded cumulative offsets inline from - tokens_per_expert, avoiding a separate cumsum kernel launch. - """ - row = tl.program_id(0) - - # Walk tokens_per_expert to find which expert this row belongs to - # and compute both unpadded and padded start offsets on the fly. - unpadded_start = tl.zeros([], dtype=tl.int32) - padded_start = tl.zeros([], dtype=tl.int32) - expert_id = -1 - for e in tl.static_range(0, num_experts): - count = tl.load(tpe_ptr + e).to(tl.int32) - if expert_id < 0 and row < unpadded_start + count: - expert_id = e - if expert_id < 0: - unpadded_start += count - aligned = tl.where( - count > 0, - ((count + alignment - 1) // alignment) * alignment, - tl.zeros([], dtype=tl.int32), - ) - padded_start += aligned - - if expert_id < 0: - return - - local_idx = row - unpadded_start - dst_row = padded_start + local_idx - - # Write permutation_map: padded row → original unpadded row - tl.store(perm_map_ptr + dst_row, row) - - # Copy hidden state - for h in tl.range(0, hidden_dim, BLOCK_H): - o = h + tl.arange(0, BLOCK_H) - m = o < hidden_dim - tl.store( - dst_ptr + dst_row * hidden_dim + o, - tl.load(src_ptr + row * hidden_dim + o, mask=m), - mask=m, - ) - - -def pad_to_alignment( - hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor, alignment: int -) -> tuple: - """Pad already-permuted tokens so each expert's block is aligned. - - Args: - hidden_states: [total_tokens, hidden_size] already permuted by dispatcher. - tokens_per_expert: [num_local_experts] int32 token counts. - alignment: per-expert alignment. - - Returns: - (padded_hidden, permutation_map, inclusive_offsets) - - padded_hidden: [padded_total, hidden_size] - - permutation_map: [padded_total] int32, original row index or -1 for padding. - - inclusive_offsets: [num_local_experts] int32 cumulative aligned offsets for grouped_mm. 
- """ - num_experts = tokens_per_expert.shape[0] - total_tokens = hidden_states.shape[0] - hidden_dim = hidden_states.shape[1] - - # We still need padded_inc for the return value (used as offs by grouped_mm) - _, padded_inc = compute_expert_offsets(tokens_per_expert, alignment=alignment) - padded_total = int(padded_inc[-1].item()) - - padded_hidden = torch.zeros( - padded_total, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device - ) - permutation_map = torch.full( - (padded_total,), -1, dtype=torch.int32, device=hidden_states.device - ) - - if total_tokens > 0: - BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) - _pad_tokens_kernel[(total_tokens,)]( - hidden_states, - padded_hidden, - permutation_map, - tokens_per_expert, - hidden_dim, - num_experts, - alignment, - BLOCK_H=BLOCK_H, - ) - - return padded_hidden, permutation_map, padded_inc - - -@triton.jit -def _unpad_tokens_kernel( - src_ptr, - dst_ptr, - perm_map_ptr, - probs_ptr, - hidden_dim, - has_probs: tl.constexpr, - BLOCK_H: tl.constexpr, -): - """Copy one real (non-padding) row from padded to unpadded layout. - - Optionally multiplies each row by its routing probability. - """ - row = tl.program_id(0) - dst_row = tl.load(perm_map_ptr + row) - if dst_row < 0: - return - if has_probs: - prob = tl.load(probs_ptr + dst_row) - for h in tl.range(0, hidden_dim, BLOCK_H): - o = h + tl.arange(0, BLOCK_H) - m = o < hidden_dim - v = tl.load(src_ptr + row * hidden_dim + o, mask=m) - if has_probs: - v = v * prob - tl.store(dst_ptr + dst_row * hidden_dim + o, v, mask=m) - - -def unpad_from_alignment( - padded_output: torch.Tensor, - permutation_map: torch.Tensor, - original_size: int, - probs: torch.Tensor = None, -) -> torch.Tensor: - """Remove alignment padding, scattering results back to original positions. - - Args: - padded_output: [padded_total, hidden_size] output from expert computation. - permutation_map: [padded_total] int32, original row index or -1 for padding. - original_size: number of rows in the unpadded output. - probs: optional [original_size] routing probabilities to multiply during unpad. - - Returns: - [original_size, hidden_size] unpadded output. 
- """ - hidden_dim = padded_output.shape[1] - output = torch.zeros( - original_size, hidden_dim, dtype=padded_output.dtype, device=padded_output.device - ) - has_probs = probs is not None - if padded_output.shape[0] > 0: - BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) - _unpad_tokens_kernel[(padded_output.shape[0],)]( - padded_output, - output, - permutation_map, - probs if has_probs else padded_output, # dummy pointer when no probs - hidden_dim, - has_probs, - BLOCK_H=BLOCK_H, - ) - return output diff --git a/megatron/core/inference/moe/permute.py b/megatron/core/inference/moe/permute.py index b14d0b3dbd0..6906c877061 100644 --- a/megatron/core/inference/moe/permute.py +++ b/megatron/core/inference/moe/permute.py @@ -8,6 +8,7 @@ - Unpermute expert outputs back to original token order """ +from typing import Optional from unittest.mock import MagicMock import torch @@ -28,15 +29,26 @@ tl = MagicMock() +_NUM_SMS: Optional[int] = None + + +def _get_num_sms(device: torch.device) -> int: + global _NUM_SMS + if _NUM_SMS is None: + _NUM_SMS = torch.cuda.get_device_properties(device).multi_processor_count + return _NUM_SMS + + def _ceil_div(a, b): return (a + b - 1) // b @triton.jit def _count_local_tokens_kernel( - routing_map_ptr, # [num_tokens * topk] flattened expert assignments + routing_map_ptr, # [max_tokens, topk] flattened expert assignments tokens_per_expert_ptr, # [num_local_experts] output counters (zeroed by caller) - total_pairs, # num_tokens * topk — total (token, topk) pairs + valid_tokens_ptr, # scalar int32 CUDA tensor: number of valid tokens this iteration + topk, # number of expert choices per token local_expert_start, # first global expert index owned by this rank num_local_experts: tl.constexpr, # number of experts on this rank BLOCK_SIZE: tl.constexpr, # number of pairs processed per program @@ -45,33 +57,102 @@ def _count_local_tokens_kernel( Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to experts outside [local_expert_start, local_expert_start + num_local_experts) - are silently skipped. + or beyond valid_tokens are silently skipped. + + Grid is launched at max size (max_tokens * topk); valid_tokens gates which + pairs are actually processed — required for CUDA graph compatibility. """ pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + valid_pairs = valid_tokens * topk offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < total_pairs + mask = offsets < valid_pairs expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) - # Map global expert IDs to local indices; non-local experts become negative local_ids = expert_ids - local_expert_start is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) +@triton.jit +def _count_local_tokens_kernel_persistent( + routing_map_ptr, # [max_tokens, topk] flattened expert assignments + tokens_per_expert_ptr, # [num_local_experts] output counters (zeroed by caller) + valid_tokens_ptr, # scalar int32 CUDA tensor: number of valid tokens this iteration + topk, # number of expert choices per token + local_expert_start, # first global expert index owned by this rank + num_local_experts: tl.constexpr, # number of experts on this rank + num_sms, # number of SMs (grid size for persistent kernel) + BLOCK_SIZE: tl.constexpr, # number of pairs processed per iteration +): + """Count tokens routed to local experts using a persistent grid. + + Launches num_sms CTAs. 
Each CTA loops over its share of BLOCK_SIZE-sized + chunks, with total work determined device-side from valid_tokens. + """ + pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + valid_pairs = valid_tokens * topk + + total_blocks = tl.cdiv(valid_pairs, BLOCK_SIZE) + blocks_per_cta = tl.cdiv(total_blocks, num_sms) + block_start = pid * blocks_per_cta + + if block_start < total_blocks: + block_end = tl.minimum(block_start + blocks_per_cta, total_blocks) + + for block_id in tl.range(block_start, block_end): + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < valid_pairs + expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) + local_ids = expert_ids - local_expert_start + is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask + tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) + + def compute_local_tokens_per_expert( - routing_map: torch.Tensor, local_expert_start: int, num_local_experts: int + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + valid_tokens: torch.Tensor, + persistent: bool = False, ) -> torch.Tensor: - """Count tokens routed to each local expert.""" - total_pairs = routing_map.numel() + """Count tokens routed to each local expert. + + Args: + routing_map: [max_tokens, topk] expert assignments. Only the first + valid_tokens rows are processed; the rest are ignored. + local_expert_start: first global expert index on this rank. + num_local_experts: number of experts on this rank. + valid_tokens: scalar int32 CUDA tensor with the number of valid tokens + this iteration. Fixed address; value updated each step before graph replay. + persistent: use persistent-grid kernel variant (fewer CTAs, looped). + """ + max_pairs = routing_map.numel() + topk = routing_map.shape[1] tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) - BLOCK = 256 - _count_local_tokens_kernel[(_ceil_div(total_pairs, BLOCK),)]( - routing_map, - tokens_per_expert, - total_pairs, - local_expert_start, - num_local_experts, - BLOCK_SIZE=BLOCK, - ) + BLOCK = 1024 + if persistent: + num_sms = _get_num_sms(routing_map.device) + _count_local_tokens_kernel_persistent[(num_sms,)]( + routing_map, + tokens_per_expert, + valid_tokens, + topk, + local_expert_start, + num_local_experts, + num_sms, + BLOCK_SIZE=BLOCK, + ) + else: + _count_local_tokens_kernel[(_ceil_div(max_pairs, BLOCK),)]( + routing_map, + tokens_per_expert, + valid_tokens, + topk, + local_expert_start, + num_local_experts, + BLOCK_SIZE=BLOCK, + ) return tokens_per_expert @@ -101,6 +182,39 @@ def _prefix_sum_kernel( tl.store(inclusive_offsets_ptr + r, inc, mask=mask) +@triton.jit +def _init_permutation_map_kernel( + perm_map_ptr, + n_used_ptr, # pointer to inclusive_expert_offsets[-1]: total used rows this iteration + BLOCK_SIZE: tl.constexpr, +): + """Initialize permutation_map entries to -1 up to n_used rows. + + Grid is launched at max size; entries beyond n_used are left untouched — + the activation and unpermute kernels are gated by the same n_used pointer + so they never read those entries. + """ + pid = tl.program_id(0) + n_used = tl.load(n_used_ptr) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_used + tl.store(perm_map_ptr + offsets, tl.full([BLOCK_SIZE], -1, tl.int32), mask=mask) + + +def init_permutation_map(permutation_map: torch.Tensor, n_used: torch.Tensor) -> None: + """Fill permutation_map[0:n_used] with -1. 
+ + Args: + permutation_map: [output_size] int32 buffer (pre-allocated at max size). + n_used: scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. + """ + output_size = permutation_map.shape[0] + BLOCK_SIZE = 1024 + _init_permutation_map_kernel[(_ceil_div(output_size, BLOCK_SIZE),)]( + permutation_map, n_used, BLOCK_SIZE=BLOCK_SIZE + ) + + def compute_expert_offsets(tokens_per_expert: torch.Tensor, alignment: int = 1) -> tuple: """Compute exclusive and inclusive prefix sums of aligned token counts.""" n = tokens_per_expert.shape[0] @@ -119,52 +233,55 @@ def compute_expert_offsets(tokens_per_expert: torch.Tensor, alignment: int = 1) @triton.jit def _permute_tokens_kernel( - hidden_ptr, # [num_tokens, hidden_dim] input hidden states - probs_ptr, # [num_tokens, topk] routing probabilities - routing_map_ptr, # [num_tokens, topk] expert assignments (global IDs) + hidden_ptr, # [max_tokens, hidden_dim] input hidden states + probs_ptr, # [max_tokens, topk] routing probabilities + routing_map_ptr, # [max_tokens, topk] expert assignments (global IDs) out_hidden_ptr, # [output_size, hidden_dim] output: permuted hidden states out_probs_ptr, # [output_size] output: permuted probabilities out_src_idx_ptr, # [output_size] output: permutation_map (original token index, -1 for padding) - counters_ptr, # [num_local_experts] exclusive offsets, - # atomically incremented to assign positions - num_tokens, # number of input tokens + counters_ptr, # [num_local_experts] exclusive offsets, atomically incremented + valid_tokens_ptr, # scalar int32 CUDA tensor: number of valid tokens this iteration hidden_dim, # hidden dimension + max_pairs, # max_tokens * topk (fixed for CG) topk: tl.constexpr, # number of expert choices per token local_expert_start, # first global expert index on this rank num_local_experts: tl.constexpr, # number of experts on this rank BLOCK_H: tl.constexpr, # tile size for copying hidden_dim + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) ): """Permute tokens into expert-grouped order. - Grid: one program per (token, topk) pair. Each program looks up the assigned - expert, skips non-local experts, then atomically claims a position within - that expert's block and copies the hidden state + prob + source index. + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. + valid_tokens gates which pairs are actually processed — required for CUDA graph + compatibility since the grid size never changes across steps. 
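The fixed-grid, device-gated pattern described in these docstrings is shared by the permute, activation, and unpermute kernels. The toy kernel below is only an illustration of the bare pattern: the launch grid and buffer shapes never vary, and a scalar tensor loaded inside the kernel decides how many rows are actually touched on each CUDA graph replay.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_valid_rows(x_ptr, valid_ptr, max_rows, BLOCK: tl.constexpr, NUM_BLOCKS: tl.constexpr):
    pid = tl.program_id(0)
    valid = tl.load(valid_ptr)                      # re-read on every replay
    if pid >= valid:
        return
    for row in tl.range(pid, max_rows, NUM_BLOCKS): # grid-stride loop over the fixed max size
        if row < valid:
            offs = row * BLOCK + tl.arange(0, BLOCK)
            tl.store(x_ptr + offs, tl.load(x_ptr + offs) * 2.0)

x = torch.ones(64, 8, device="cuda")
valid = torch.tensor(10, dtype=torch.int32, device="cuda")  # value updated in place between replays
_scale_valid_rows[(32,)](x, valid, 64, BLOCK=8, NUM_BLOCKS=32)
assert x[:10].eq(2).all() and x[10:].eq(1).all()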
""" - # Each program handles one (token, topk) pair - pair = tl.program_id(0) - tok = pair // topk - k = pair % topk - if tok >= num_tokens: - return - eid = tl.load(routing_map_ptr + tok * topk + k) - lid = eid - local_expert_start - # Skip tokens routed to non-local experts - if lid < 0 or lid >= num_local_experts: + pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + valid_pairs = valid_tokens * topk + if pid >= valid_pairs: return - # Atomically claim a position within this expert's aligned block - pos = tl.atomic_add(counters_ptr + lid, 1) - # Copy hidden state row - for h in tl.range(0, hidden_dim, BLOCK_H): - o = h + tl.arange(0, BLOCK_H) - m = o < hidden_dim - tl.store( - out_hidden_ptr + pos * hidden_dim + o, - tl.load(hidden_ptr + tok * hidden_dim + o, mask=m), - mask=m, - ) - tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) - # Record source token index for unpermute - tl.store(out_src_idx_ptr + pos, tok) + for pair in tl.range(pid, max_pairs, NUM_BLOCKS): + tok = pair // topk + if tok < valid_tokens: + k = pair % topk + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + # Skip tokens routed to non-local experts + if lid >= 0 and lid < num_local_experts: + # Atomically claim a position within this expert's aligned block + pos = tl.atomic_add(counters_ptr + lid, 1) + # Copy hidden state row + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + tl.store( + out_hidden_ptr + pos * hidden_dim + o, + tl.load(hidden_ptr + tok * hidden_dim + o, mask=m), + mask=m, + ) + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + # Record source token index for unpermute + tl.store(out_src_idx_ptr + pos, tok) def permute_tokens( @@ -173,6 +290,7 @@ def permute_tokens( routing_map: torch.Tensor, local_expert_start: int, num_local_experts: int, + valid_tokens: torch.Tensor, alignment: int = 1, ) -> tuple: """Permute tokens into expert-grouped order. @@ -181,11 +299,14 @@ def permute_tokens( permutation in a single call. Args: - hidden_states: [num_tokens, hidden_size] input. - probs: [num_tokens, topk] routing probabilities. - routing_map: [num_tokens, topk] expert assignments. + hidden_states: [max_tokens, hidden_size] input. Only the first valid_tokens + rows are valid; the rest are ignored. + probs: [max_tokens, topk] routing probabilities. + routing_map: [max_tokens, topk] expert assignments. local_expert_start: first global expert index on this rank. num_local_experts: number of experts on this rank. + valid_tokens: scalar int32 CUDA tensor with the number of valid tokens this + iteration. Fixed address; value updated each step before graph replay. alignment: per-expert token alignment (default 1). Returns: @@ -197,13 +318,13 @@ def permute_tokens( outputs back and by activation kernels to skip padding rows (-1). - inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm """ - num_tokens, hidden_dim = hidden_states.shape + max_tokens, hidden_dim = hidden_states.shape topk = probs.shape[1] # Count how many (token, topk) pairs are routed to each local expert. - # Non-local experts are ignored. Result is [num_local_experts] int32. + # Non-local experts and rows beyond valid_tokens are ignored. tokens_per_expert = compute_local_tokens_per_expert( - routing_map, local_expert_start, num_local_experts + routing_map, local_expert_start, num_local_experts, valid_tokens ) # exclusive_expert_offsets[i] = start of expert i's block in the padded output. 
@@ -213,15 +334,21 @@ def permute_tokens( exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( tokens_per_expert, alignment=alignment ) - output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + # Output sized at max to keep allocations fixed across steps (CUDA graph compatible). + output_size = max_tokens * min(topk, num_local_experts) + alignment * num_local_experts permuted_hidden = torch.empty( output_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device ) permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) - permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + permutation_map = torch.empty(output_size, dtype=torch.int32, device=probs.device) + # Only initialize [0, n_used) to -1; activation and unpermute kernels are gated + # by the same inclusive_expert_offsets[-1] pointer so they never read beyond n_used. + init_permutation_map(permutation_map, inclusive_expert_offsets[-1:]) BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) - _permute_tokens_kernel[(num_tokens * topk,)]( + max_pairs = max_tokens * topk + NUM_BLOCKS = min(max_pairs, 512) + _permute_tokens_kernel[(NUM_BLOCKS,)]( hidden_states, probs, routing_map, @@ -229,43 +356,80 @@ def permute_tokens( permuted_probs, permutation_map, exclusive_expert_offsets, - num_tokens, + valid_tokens, hidden_dim, + max_pairs, topk, local_expert_start, num_local_experts, BLOCK_H=BLOCK_H, + NUM_BLOCKS=NUM_BLOCKS, ) return permuted_hidden, permuted_probs, permutation_map, inclusive_expert_offsets +@triton.jit +def _zero_output_rows_kernel( + output_ptr, # [num_tokens, hidden_dim] fp32 buffer to partially zero + valid_tokens_ptr, # scalar int32 CUDA tensor: number of rows to zero + hidden_dim, # hidden dimension + num_tokens, # max token count (fixed for CG) + BLOCK_H: tl.constexpr, + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) +): + """Zero rows [0, valid_tokens) of the fp32 output buffer. + + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. + valid_tokens gates which rows are zeroed — required for CUDA graph compatibility. + """ + pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + if pid >= valid_tokens: + return + zero = tl.zeros([BLOCK_H], dtype=tl.float32) + for row in tl.range(pid, num_tokens, NUM_BLOCKS): + if row < valid_tokens: + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + tl.store(output_ptr + row * hidden_dim + o, zero, mask=m) + + @triton.jit def _unpermute_tokens_kernel( expert_out_ptr, # [output_size, hidden_dim] expert outputs in permuted order probs_ptr, # [output_size] fp32 routing probabilities (permuted) src_idx_ptr, # [output_size] permutation_map: original token index, or -1 for padding - output_ptr, # [num_tokens, hidden_dim] fp32 output buffer (zeroed by caller) + output_ptr, # [max_tokens, hidden_dim] fp32 output buffer (zeroed by caller) + n_used_ptr, # pointer to inclusive_expert_offsets[-1]: number of used rows this iteration hidden_dim, # hidden dimension + max_rows, # output_size (fixed for CG) BLOCK_H: tl.constexpr, # tile size for processing hidden_dim + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) ): """Scatter weighted expert outputs back to original token positions. - Grid: one program per row of expert_out. Padding rows (src_idx == -1) are - skipped. Multiple topk selections for the same token are accumulated via - atomic adds. 
All arithmetic is in fp32 to avoid precision loss. + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. + Rows beyond n_used and alignment-padding rows (src_idx == -1) are skipped. + Multiple topk selections for the same token are accumulated via atomic adds. + All arithmetic is in fp32 to avoid precision loss. """ - row = tl.program_id(0) - source_idx = tl.load(src_idx_ptr + row) - # Skip padding rows - if source_idx < 0: + pid = tl.program_id(0) + n_used = tl.load(n_used_ptr) + if pid >= n_used: return - prob = tl.load(probs_ptr + row) # fp32 - for h in tl.range(0, hidden_dim, BLOCK_H): - offsets = h + tl.arange(0, BLOCK_H) - m = offsets < hidden_dim - # Upcast bf16 expert output to fp32 before multiply + accumulate - v = tl.load(expert_out_ptr + row * hidden_dim + offsets, mask=m).to(tl.float32) - tl.atomic_add(output_ptr + source_idx * hidden_dim + offsets, v * prob, mask=m) + for row in tl.range(pid, max_rows, NUM_BLOCKS): + if row < n_used: + source_idx = tl.load(src_idx_ptr + row) + # Skip alignment-padding rows within the used range + if source_idx >= 0: + prob = tl.load(probs_ptr + row) # fp32 + for h in tl.range(0, hidden_dim, BLOCK_H): + offsets = h + tl.arange(0, BLOCK_H) + m = offsets < hidden_dim + # Upcast bf16 expert output to fp32 before multiply + accumulate + v = tl.load(expert_out_ptr + row * hidden_dim + offsets, mask=m).to(tl.float32) + tl.atomic_add(output_ptr + source_idx * hidden_dim + offsets, v * prob, mask=m) def unpermute_tokens( @@ -273,22 +437,53 @@ def unpermute_tokens( permuted_probs: torch.Tensor, permutation_map: torch.Tensor, num_tokens: int, + n_used: torch.Tensor, + valid_tokens: torch.Tensor, + out: torch.Tensor = None, ) -> torch.Tensor: """Unpermute expert outputs back to original token order. Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. Returns fp32 output. + + Args: + expert_output: [output_size, hidden_dim] expert outputs in permuted order. + permuted_probs: [output_size] fp32 routing probabilities. + permutation_map: [output_size] int32, original token index or -1 for padding. + num_tokens: max token count (output buffer height); always fixed for CG. + n_used: scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. Rows + beyond this are skipped without reading permutation_map. + valid_tokens: scalar int32 CUDA tensor = number of valid input tokens. + Only rows [0, valid_tokens) are zeroed; all atomic_adds target + source_idx < valid_tokens so rows beyond are never written. + out: optional pre-allocated [num_tokens, hidden_dim] fp32 output buffer. + Pass a symmetric memory tensor to scatter directly into it, avoiding + a separate copy before RSV. If None, a local buffer is allocated. 
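Setting aside the n_used / valid_tokens gating (which only exists for CUDA graph replay), the scatter these kernels perform reduces to an fp32 index_add_ over the non-padding rows. The snippet below is a reference for that semantics, not the Triton implementation:

import torch

def unpermute_reference(expert_output, permuted_probs, permutation_map, num_tokens):
    out = torch.zeros(num_tokens, expert_output.shape[1],
                      dtype=torch.float32, device=expert_output.device)
    keep = permutation_map >= 0                                # drop alignment-padding rows
    weighted = expert_output[keep].float() * permuted_probs[keep].unsqueeze(1)
    out.index_add_(0, permutation_map[keep].long(), weighted)  # topk rows for a token accumulate
    return out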
""" assert ( permuted_probs.dtype == torch.float32 ), f"permuted_probs must be fp32, got {permuted_probs.dtype}" output_size, hidden_dim = expert_output.shape - output = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device=expert_output.device) BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) - _unpermute_tokens_kernel[(output_size,)]( - expert_output, permuted_probs, permutation_map, output, hidden_dim, BLOCK_H=BLOCK_H + if out is None: + out = torch.empty(num_tokens, hidden_dim, dtype=torch.float32, device=expert_output.device) + NUM_BLOCKS_ZERO = min(num_tokens, 512) + _zero_output_rows_kernel[(NUM_BLOCKS_ZERO,)]( + out, valid_tokens, hidden_dim, num_tokens, BLOCK_H=BLOCK_H, NUM_BLOCKS=NUM_BLOCKS_ZERO + ) + NUM_BLOCKS = min(output_size, 512) + _unpermute_tokens_kernel[(NUM_BLOCKS,)]( + expert_output, + permuted_probs, + permutation_map, + out, + n_used, + hidden_dim, + output_size, + BLOCK_H=BLOCK_H, + NUM_BLOCKS=NUM_BLOCKS, ) - return output + return out @triton.jit @@ -301,75 +496,80 @@ def _permute_quantize_mxfp8_kernel( out_probs_ptr, out_src_idx_ptr, counters_ptr, - num_tokens, + valid_tokens_ptr, # scalar int32 CUDA tensor: number of valid tokens this iteration K, n_col_blocks, + max_pairs, # max_tokens * topk (fixed for CG) topk: tl.constexpr, local_expert_start, num_local_experts: tl.constexpr, REAL_GROUPS: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_GROUPS: tl.constexpr, + NUM_BLOCKS: tl.constexpr, # grid size (fixed for CG) ): """Fused permute + MXFP8 quantize + swizzle in one kernel. - Grid: (num_tokens * topk,) — one program per (token, k) pair. - Reads BF16 from source token, quantizes to FP8 e4m3, writes FP8 data + - swizzled e8m0 scales to the permuted write position. + Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. + valid_tokens gates which pairs are actually processed — required for CUDA graph + compatibility since the grid size never changes across steps. 
""" - pair = tl.program_id(0) - tok = pair // topk - k = pair % topk - if tok >= num_tokens: - return - eid = tl.load(routing_map_ptr + tok * topk + k) - lid = eid - local_expert_start - if lid < 0 or lid >= num_local_experts: + pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + valid_pairs = valid_tokens * topk + if pid >= valid_pairs: return - pos = tl.atomic_add(counters_ptr + lid, 1) - - # Load full row from source token - offs = tl.arange(0, BLOCK_K) - mask = offs < K - x = tl.load(hidden_ptr + tok * K + offs, mask=mask, other=0.0).to(tl.float32) - - # Per-group-of-32 quantization - x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32]) - abs_grouped = tl.abs(x_grouped) - max_vals = tl.max(abs_grouped, axis=1) - - dequant_scale = max_vals / 448.0 - dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 - dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) - quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) - - quantized = x_grouped * quant_scale[:, None] - quantized_flat = tl.reshape(quantized, [BLOCK_K]) - out_fp8 = quantized_flat.to(tl.float8e4nv) - - # Store FP8 data at permuted position - tl.store(out_fp8_ptr + pos * K + offs, out_fp8, mask=mask) - - # Store swizzled scales at permuted position - scale_exp = (dequant_exp >> 23).to(tl.uint8) - col_offs = tl.arange(0, BLOCK_GROUPS) - col_mask = col_offs < REAL_GROUPS - - macro_row_block = pos // 128 - macro_col_block = col_offs // 4 - local_row = pos % 128 - local_col = col_offs % 4 - group = local_row // 32 - sub_row = local_row % 32 - tile_idx = macro_row_block * n_col_blocks + macro_col_block - swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col - - tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) - - # Store prob and source index - tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) - tl.store(out_src_idx_ptr + pos, tok) + for pair in tl.range(pid, max_pairs, NUM_BLOCKS): + tok = pair // topk + if tok < valid_tokens: + k = pair % topk + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + if lid >= 0 and lid < num_local_experts: + pos = tl.atomic_add(counters_ptr + lid, 1) + + # Load full row from source token + offs = tl.arange(0, BLOCK_K) + mask = offs < K + x = tl.load(hidden_ptr + tok * K + offs, mask=mask, other=0.0).to(tl.float32) + + # Per-group-of-32 quantization + x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + dequant_scale = max_vals / 448.0 + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data at permuted position + tl.store(out_fp8_ptr + pos * K + offs, out_fp8, mask=mask) + + # Store swizzled scales at permuted position + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + macro_row_block = pos // 128 + macro_col_block = col_offs // 4 + local_row = pos % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(out_scale_ptr 
+ swizzled_offs, scale_exp, mask=col_mask) + + # Store prob and source index + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + tl.store(out_src_idx_ptr + pos, tok) def permute_and_quantize_mxfp8( @@ -378,6 +578,7 @@ def permute_and_quantize_mxfp8( routing_map: torch.Tensor, local_expert_start: int, num_local_experts: int, + valid_tokens: torch.Tensor, alignment: int = 128, ) -> tuple: """Fused permute + MXFP8 quantize + swizzle. @@ -387,11 +588,14 @@ def permute_and_quantize_mxfp8( single kernel launch. Args: - hidden_states: [num_tokens, hidden_size] BF16 input. - probs: [num_tokens, topk] routing probabilities. - routing_map: [num_tokens, topk] expert assignments. + hidden_states: [max_tokens, hidden_size] BF16 input. Only the first + valid_tokens rows are valid; the rest are ignored. + probs: [max_tokens, topk] routing probabilities. + routing_map: [max_tokens, topk] expert assignments. local_expert_start: first global expert index on this rank. num_local_experts: number of experts on this rank. + valid_tokens: scalar int32 CUDA tensor with the number of valid tokens this + iteration. Fixed address; value updated each step before graph replay. alignment: per-expert token alignment (default 128, required for MXFP8 swizzle). Returns: @@ -403,13 +607,14 @@ def permute_and_quantize_mxfp8( """ from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor - num_tokens, K = hidden_states.shape + max_tokens, K = hidden_states.shape topk = probs.shape[1] assert K % 32 == 0 # Count how many (token, topk) pairs are routed to each local expert. + # Rows beyond valid_tokens are ignored. tokens_per_expert = compute_local_tokens_per_expert( - routing_map, local_expert_start, num_local_experts + routing_map, local_expert_start, num_local_experts, valid_tokens ) # exclusive_expert_offsets[i] = start of expert i's block in the padded output. @@ -417,7 +622,8 @@ def permute_and_quantize_mxfp8( exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( tokens_per_expert, alignment=alignment ) - output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + # Output sized at max to keep allocations fixed across steps (CUDA graph compatible). 
+ output_size = max_tokens * min(topk, num_local_experts) + alignment * num_local_experts scale_cols = K // 32 n_row_blocks = _ceil_div(output_size, 128) @@ -427,12 +633,14 @@ def permute_and_quantize_mxfp8( out_fp8 = torch.empty(output_size, K, dtype=torch.float8_e4m3fn, device=hidden_states.device) out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=hidden_states.device) permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) - permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + permutation_map = torch.empty(output_size, dtype=torch.int32, device=probs.device) + init_permutation_map(permutation_map, inclusive_expert_offsets[-1:]) BLOCK_K = triton.next_power_of_2(K) BLOCK_GROUPS = BLOCK_K // 32 - - _permute_quantize_mxfp8_kernel[(num_tokens * topk,)]( + max_pairs = max_tokens * topk + NUM_BLOCKS = min(max_pairs, 512) + _permute_quantize_mxfp8_kernel[(NUM_BLOCKS,)]( hidden_states, probs, routing_map, @@ -441,15 +649,17 @@ def permute_and_quantize_mxfp8( permuted_probs, permutation_map, exclusive_expert_offsets, - num_tokens, + valid_tokens, K, n_col_blocks, + max_pairs, topk, local_expert_start, num_local_experts, REAL_GROUPS=scale_cols, BLOCK_K=BLOCK_K, BLOCK_GROUPS=BLOCK_GROUPS, + NUM_BLOCKS=NUM_BLOCKS, ) permuted_mxfp8 = MXFP8Tensor( diff --git a/megatron/core/inference/moe/vllm_fused_moe.py b/megatron/core/inference/moe/vllm_fused_moe.py new file mode 100644 index 00000000000..85a1e52ccef --- /dev/null +++ b/megatron/core/inference/moe/vllm_fused_moe.py @@ -0,0 +1,667 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +# Some of this code was adopted from https://github.com/vllm-project/vllm. +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. +"""vLLM-style Triton fused MoE kernel (BF16) for Megatron inference. + +CUDA-graph compatible: all indirection table construction happens on-device +via Triton kernels with fixed-size buffers and valid_tokens gating. +""" + +from typing import Optional +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + +from megatron.core.inference.moe.fused_moe import ActivationType +from megatron.core.inference.moe.permute import ( + _get_num_sms, + compute_expert_offsets, + compute_local_tokens_per_expert, +) + +# --------------------------------------------------------------------------- +# Triton kernel – BF16 grouped GEMM with indirect token addressing +# --------------------------------------------------------------------------- + + +def _select_block_size_m(max_tokens: int) -> int: + """Select BLOCK_SIZE_M based on the token buffer size. + + Smaller tiles reduce padding waste in the indirection table when each + expert sees few tokens (decode). Larger tiles improve compute density + for large batches (prefill). Minimum is 16 (tl.dot requirement on NVIDIA). + """ + if max_tokens <= 32: + return 16 + if max_tokens <= 96: + return 32 + if max_tokens <= 512: + return 64 + return 128 + + +# BLOCK_SIZE_M is NOT in these configs — it is selected on the Python side by +# _select_block_size_m and passed as a caller-provided constexpr. 
Each unique +# BLOCK_SIZE_M value triggers independent autotuning over these configs. +_AUTOTUNE_CONFIGS = [ + # GROUP_SIZE_M=1: better when each expert has few tokens (decode, sparse activation). + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=2 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 1}, num_warps=4, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=8, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_warps=8, num_stages=4 + ), + # GROUP_SIZE_M=8: better for large batches where experts see many tokens. + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=5 + ), + triton.Config( + {'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=8, num_stages=4 + ), + triton.Config( + {'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=8, num_stages=5 + ), +] + + +@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=['N', 'K']) +@triton.jit +def _fused_moe_kernel( + # Pointers + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Dimensions + N, + K, + num_valid_tokens, + num_sms, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Flags / constexprs + MUL_ROUTED_WEIGHT: tl.constexpr, + FUSE_SQUARED_RELU: tl.constexpr, + top_k: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Persistent fused MoE grouped GEMM with indirect token addressing. + + Launches a fixed grid of num_sms CTAs. Each CTA loops over its share + of tiles, with total tile count determined device-side from + num_tokens_post_padded. This decouples grid size from buffer size, + keeping the kernel CUDA-graph safe while avoiding excess CTA overhead. + """ + pid = tl.program_id(0) + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tiles_per_cta = tl.cdiv(total_tiles, num_sms) + tile_start = pid * tiles_per_cta + tile_end = tl.minimum(tile_start + tiles_per_cta, total_tiles) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + offs_k = tl.arange(0, BLOCK_SIZE_K) + + for tile_id in tl.range(tile_start, tile_end): + # GROUP_SIZE_M swizzle: tile_id → (pid_m, pid_n) + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + b_block_ptr = tl.make_block_ptr( + base=b_ptr + off_experts * stride_be, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(0, 1), + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero") + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0)) + + if FUSE_SQUARED_RELU: + accumulator = tl.maximum(accumulator, 0.0) + accumulator *= accumulator + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator *= moe_weight[:, None] + + accumulator = accumulator.to(tl.bfloat16) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +# --------------------------------------------------------------------------- +# Indirection table construction (CUDA-graph safe, fully on-device) +# --------------------------------------------------------------------------- + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def _init_sorted_ids_kernel( + sorted_token_ids_ptr, + expert_ids_ptr, + max_sorted, + max_blocks, + SENTINEL: tl.constexpr, + BLOCK: tl.constexpr, +): + """Initialize sorted_token_ids to SENTINEL and expert_ids to -1.""" + pid = tl.program_id(0) + block_start = pid * BLOCK + if block_start < max_sorted or block_start < max_blocks: + offs = block_start + tl.arange(0, BLOCK) + tl.store(sorted_token_ids_ptr + offs, SENTINEL, mask=offs < max_sorted) + 
tl.store(expert_ids_ptr + offs, -1, mask=offs < max_blocks) + + +@triton.jit +def _scatter_token_indices_kernel( + routing_map_ptr, + sorted_token_ids_ptr, + counters_ptr, + valid_tokens_ptr, + topk: tl.constexpr, + local_expert_start, + num_local_experts: tl.constexpr, + max_pairs, + BLOCK_SIZE: tl.constexpr, +): + """Scatter local-expert pair indices into the padded indirection table. + + Only local expert pairs are written; non-local pairs are skipped (the + _moe_sum kernel handles them by checking the routing map directly). + """ + pid = tl.program_id(0) + valid_tokens = tl.load(valid_tokens_ptr) + valid_pairs = valid_tokens * topk + if pid * BLOCK_SIZE >= valid_pairs: + return + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < valid_pairs + + eids = tl.load(routing_map_ptr + offs, mask=mask, other=-1) + lids = eids - local_expert_start + is_local = (lids >= 0) & (lids < num_local_experts) & mask + + local_pos = tl.atomic_add(counters_ptr + lids, 1, mask=is_local) + tl.store(sorted_token_ids_ptr + local_pos, offs, mask=is_local) + + +@triton.jit +def _fill_expert_block_ids_kernel( + expert_ids_ptr, + exclusive_offsets_ptr, + inclusive_offsets_ptr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK: tl.constexpr, +): + """Fill expert_ids with expert index for each BLOCK_SIZE_M block. + + Grid: one CTA per expert (parallelised across experts). + Inner loop uses vectorised stores of BLOCK elements at a time. + """ + e = tl.program_id(0) + start_block = tl.load(exclusive_offsets_ptr + e) // BLOCK_SIZE_M + end_block = tl.load(inclusive_offsets_ptr + e) // BLOCK_SIZE_M + num_blocks = end_block - start_block + for off in tl.range(0, num_blocks, BLOCK): + idxs = start_block + off + tl.arange(0, BLOCK) + tl.store(expert_ids_ptr + idxs, e, mask=idxs < end_block) + + +def _moe_align_block_size_cuda_graphable( + routing_map: torch.Tensor, + block_size: int, + num_local_experts: int, + local_expert_start: int, + valid_tokens: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build indirection tables for the vLLM kernel, fully on-device. + + Replaces the original _moe_align_block_size which used .item() calls + and host-side loops. All buffers are allocated at fixed max sizes so + the function is safe for CUDA graph capture. + + Args: + routing_map: [max_tokens, topk] expert assignments. + block_size: BLOCK_SIZE_M for the vLLM kernel. + num_local_experts: experts on this rank. + local_expert_start: first global expert index on this rank. + valid_tokens: scalar int32 CUDA tensor. + + Returns: + sorted_token_ids: [max_sorted] int32 indirection table. + expert_ids: [max_blocks] int32 expert per block. + num_tokens_post_padded: [1] int32 (local expert padded count). 
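+
+    Example (illustrative): with block_size=4, two local experts starting at
+    expert 0, topk=2 and 3 valid tokens where one pair is routed to a
+    non-local expert, the 5 local pairs split as 3 for expert 0 and 2 for
+    expert 1. After alignment the exclusive/inclusive offsets are [0, 4] /
+    [4, 8], so num_tokens_post_padded is 8, expert_ids[:2] == [0, 1], and
+    sorted_token_ids holds the 5 pair indices inside their expert blocks
+    (intra-block order is non-deterministic due to atomics); all remaining
+    slots keep the sentinel value max_tokens * topk.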
+ """ + max_tokens, topk = routing_map.shape + device = routing_map.device + + max_sorted = max_tokens * topk + block_size * (num_local_experts + 1) + max_blocks = _ceil_div(max_sorted, block_size) + sentinel = max_tokens * topk + + sorted_token_ids = torch.empty(max_sorted, dtype=torch.int32, device=device) + expert_ids = torch.empty(max_blocks, dtype=torch.int32, device=device) + + INIT_BLOCK = 1024 + init_grid = _ceil_div(max(max_sorted, max_blocks), INIT_BLOCK) + _init_sorted_ids_kernel[(init_grid,)]( + sorted_token_ids, expert_ids, max_sorted, max_blocks, SENTINEL=sentinel, BLOCK=INIT_BLOCK + ) + + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts, valid_tokens, persistent=True + ) + exclusive_offsets, inclusive_offsets = compute_expert_offsets( + tokens_per_expert, alignment=block_size + ) + + _fill_expert_block_ids_kernel[(num_local_experts,)]( + expert_ids, exclusive_offsets, inclusive_offsets, BLOCK_SIZE_M=block_size, BLOCK=128 + ) + + max_pairs = max_tokens * topk + SCATTER_BLOCK = 256 + scatter_grid = _ceil_div(max_pairs, SCATTER_BLOCK) + _scatter_token_indices_kernel[(scatter_grid,)]( + routing_map, + sorted_token_ids, + exclusive_offsets, + valid_tokens, + topk, + local_expert_start, + num_local_experts, + max_pairs, + BLOCK_SIZE=SCATTER_BLOCK, + ) + + num_tokens_post_padded = inclusive_offsets[-1:] + return sorted_token_ids, expert_ids, num_tokens_post_padded + + +# --------------------------------------------------------------------------- +# Kernel launcher +# --------------------------------------------------------------------------- + + +def _invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + block_size_m: int, + fuse_squared_relu: bool = False, +): + """Launch the persistent Triton fused-MoE kernel for one GEMM pass. + + Uses a fixed grid of NUM_SMS CTAs for CUDA-graph safety. Each CTA + loops over its share of tiles, with actual work determined device-side. + """ + M = A.size(0) + num_tokens = M * top_k + num_sms = _get_num_sms(A.device) + + _fused_moe_kernel[(num_sms,)]( + A, + B, + C, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + B.size(2), + num_tokens, + num_sms, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + MUL_ROUTED_WEIGHT=mul_routed_weight, + FUSE_SQUARED_RELU=fuse_squared_relu, + top_k=top_k, + BLOCK_SIZE_M=block_size_m, + ) + + +# --------------------------------------------------------------------------- +# Fused topk reduction (replaces torch.sum + copy) +# --------------------------------------------------------------------------- + + +@triton.jit +def _moe_sum_kernel( + input_ptr, + output_ptr, + topk_weights_ptr, + valid_tokens_ptr, + routing_map_ptr, + local_expert_start, + num_local_experts: tl.constexpr, + K, + topk: tl.constexpr, + BLOCK_K: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, +): + """Reduce topk dimension with valid_tokens gating and routing weight application. + + input: [max_tokens * topk, K] bf16 + output: [max_tokens, K] — dtype matches the output buffer (fp32 or bf16) + + For token t < valid_tokens: output[t] = sum of input[t*topk+k] * prob[t*topk+k] + over topk slots k where the expert is local. 
Non-local slots are skipped + (their values in `input` are undefined because FC2 only processes + local-expert blocks). + For token t >= valid_tokens: output[t] = 0. + Routing weight multiplication and accumulation in fp32 for numerical accuracy. + """ + token_id = tl.program_id(0).to(tl.int64) + valid_tokens = tl.load(valid_tokens_ptr) + is_valid = token_id < valid_tokens + + for k_idx in range(NUM_K_BLOCKS): + offs_k = k_idx * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = offs_k < K + + acc = tl.zeros([BLOCK_K], dtype=tl.float32) + if is_valid: + base = token_id * topk * K + for t in range(topk): + eid = tl.load(routing_map_ptr + token_id * topk + t) + lid = eid - local_expert_start + if lid >= 0 and lid < num_local_experts: + v = tl.load(input_ptr + base + t * K + offs_k, mask=k_mask, other=0.0) + w = tl.load(topk_weights_ptr + token_id * topk + t) + acc += v.to(tl.float32) * w + + tl.store(output_ptr + token_id * K + offs_k, acc, mask=k_mask) + + +def _moe_sum( + input: torch.Tensor, + topk_weights: torch.Tensor, + max_tokens: int, + topk: int, + K: int, + valid_tokens: torch.Tensor, + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K]. + + Applies routing weights and reduces over topk in a single kernel. + Accumulates in fp32. When `out` is None, allocates and returns an fp32 + buffer. When `out` is provided (e.g. the RSV symmetric memory tensor), + writes directly into it — tl.store handles the cast to the buffer's dtype. + Rows beyond valid_tokens are zeroed. Only accumulates contributions from + local experts; non-local topk slots are skipped (their values in `input` + are undefined). + """ + if out is None: + out = torch.empty(max_tokens, K, dtype=torch.float32, device=input.device) + BLOCK_K = min(triton.next_power_of_2(K), 1024) + NUM_K_BLOCKS = _ceil_div(K, BLOCK_K) + _moe_sum_kernel[(max_tokens,)]( + input, + out, + topk_weights, + valid_tokens, + routing_map, + local_expert_start, + num_local_experts, + K, + topk=topk, + BLOCK_K=BLOCK_K, + NUM_K_BLOCKS=NUM_K_BLOCKS, + ) + return out + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def vllm_fused_moe( + hidden_states: torch.Tensor, + probs: torch.Tensor, + fc1_weight: torch.Tensor, + fc2_weight: torch.Tensor, + activation_type: ActivationType, + num_local_experts: int, + local_expert_start: int, + valid_tokens: torch.Tensor, + routing_map: torch.Tensor, + out: Optional[torch.Tensor] = None, + num_tokens_hint: Optional[int] = None, +) -> torch.Tensor: + """Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16). + + CUDA-graph compatible: indirection tables are built entirely on-device + using fixed-size buffers gated by valid_tokens. + + Args: + hidden_states: [max_tokens, hidden_size] BF16 input. Only the first + valid_tokens rows are valid; the rest are ignored. + probs: [max_tokens, topk] fp32 routing probabilities. + fc1_weight: [num_local_experts, fc1_out, hidden_size] BF16. + fc2_weight: [num_local_experts, hidden_size, fc1_out] BF16. + activation_type: ActivationType enum. + num_local_experts: experts on this rank. + local_expert_start: first global expert index on this rank. + valid_tokens: scalar int32 CUDA tensor with number of valid tokens. + routing_map: [max_tokens, topk] int expert assignments. 
+ out: optional [max_tokens, hidden_size] output buffer (e.g. the RSV + symmetric memory tensor). If None, an fp32 buffer is allocated. + When provided, tl.store casts to the buffer's dtype automatically. + num_tokens_hint: optional host-side int with the expected number of + valid tokens (e.g. batch_size * ep_size). Used to select a better + BLOCK_SIZE_M instead of using the worst-case buffer size. + + Returns: + [max_tokens, hidden_size] output (fp32 when out=None, else out's dtype). + tl.store handles the implicit cast when out is a different dtype. + """ + assert ( + hidden_states.dtype == torch.bfloat16 + ), f"vllm_fused_moe requires bf16 input, got {hidden_states.dtype}" + + max_tokens = hidden_states.size(0) + topk = routing_map.shape[1] + effective_tokens = num_tokens_hint if num_tokens_hint is not None else max_tokens + block_size_m = _select_block_size_m(effective_tokens) + + sorted_token_ids, expert_ids, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, block_size_m, num_local_experts, local_expert_start, valid_tokens + ) + num_valid = max_tokens * topk + + N = fc1_weight.size(1) + K = fc1_weight.size(2) + + topk_weights_flat = probs.reshape(-1).contiguous() + + # FC1 + activation: [max_tokens, K] → [max_tokens*topk, N] + assert activation_type == ActivationType.SQUARED_RELU + intermediate1 = torch.empty( + num_valid, N, dtype=hidden_states.dtype, device=hidden_states.device + ) + _invoke_fused_moe_kernel( + hidden_states, + fc1_weight, + intermediate1, + topk_weights_flat, + sorted_token_ids, + expert_ids, + num_post_padded, + mul_routed_weight=False, + top_k=topk, + block_size_m=block_size_m, + fuse_squared_relu=True, + ) + + # FC2: [max_tokens*topk, N] → [max_tokens*topk, K], without routing weights. + # Routing weights are applied in the reduction kernel to avoid an extra + # bf16 truncation of prob-scaled values before the topk summation. + # Only local-expert blocks are processed; non-local positions are left + # undefined and skipped by _moe_sum (which checks the routing map). + intermediate3 = torch.empty( + num_valid, K, dtype=hidden_states.dtype, device=hidden_states.device + ) + _invoke_fused_moe_kernel( + intermediate1, + fc2_weight, + intermediate3, + topk_weights_flat, + sorted_token_ids, + expert_ids, + num_post_padded, + mul_routed_weight=False, + top_k=1, + block_size_m=block_size_m, + ) + + # Reduce over topk: [max_tokens*topk, K] → [max_tokens, K] + # Applies routing weights and accumulates in fp32, writes directly to + # out (if provided), zeros rows beyond valid_tokens, and skips non-local + # expert slots. + return _moe_sum( + intermediate3, + probs, + max_tokens, + topk, + K, + valid_tokens, + routing_map, + local_expert_start, + num_local_experts, + out=out, + ) diff --git a/megatron/core/inference/sampling/__init__.py b/megatron/core/inference/sampling/__init__.py new file mode 100644 index 00000000000..b2941b33c9e --- /dev/null +++ b/megatron/core/inference/sampling/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
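+"""Sampling backends for Megatron inference.
+
+Illustrative usage (a sketch; the engine wiring and argument values are
+assumed rather than defined here):
+
+    sampler = FlashInferSampling(vocab_size, rng, config=config, enable_cuda_graph=True)
+    new_tokens = sampler.sample_kernel(logits, n, context)
+"""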
+ +from megatron.core.inference.sampling.base import Sampling +from megatron.core.inference.sampling.flashinfer_sampling import FlashInferSampling +from megatron.core.inference.sampling.torch_sampling import TorchSampling + +__all__ = ["Sampling", "TorchSampling", "FlashInferSampling"] diff --git a/megatron/core/inference/sampling/base.py b/megatron/core/inference/sampling/base.py new file mode 100644 index 00000000000..8aa4c416c27 --- /dev/null +++ b/megatron/core/inference/sampling/base.py @@ -0,0 +1,89 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +from torch import Tensor + + +class Sampling(ABC): + """Abstract base for inference sampling backends. + + Subclasses implement `sample_kernel`. CUDA graphs are added via `CudaGraphManager`. + """ + + @abstractmethod + def sample_kernel( + self, + logits: Tensor, + n: int, + context, + *, + gather_indices: Optional[Tensor] = None, + token_to_request_index: Optional[Tensor] = None, + eager: bool = False, + cache_key: Any = None, + ) -> Tensor: + """Sample `n` tokens from `logits` and return them. + + Args: + logits: Logits tensor of shape `[>=n, vocab_size]`. + n: Number of rows to sample. + context: The active DynamicInferenceContext. + gather_indices: If provided, only sample from `logits[gather_indices[:n], :]`. + token_to_request_index: Per-token request mapping; when set, sampling + parameters are gathered per-token instead of per-request. + eager, cache_key: Consumed by `CudaGraphManager` when it wraps this kernel. + + Returns: + Sampled token ids of shape `[n]`. Under CUDA graph replay, this is a static buffer. + """ + ... + + def sample_speculative( + self, + required_logits: Tensor, + num_decode: int, + num_prefill: int, + num_speculative_tokens: int, + context, + *, + gather_indices: Optional[Tensor] = None, + eager: bool = False, + cache_key: Any = None, + ) -> Tensor: + """Sample tokens for the speculative-verify path. + + Decode requests contribute `1 + num_speculative_tokens` rows; prefill requests contribute 1. + Builds the per-token request mapping and dispatches to `sample_kernel`. + The `sample_kernel` is forced eager so its own `CudaGraphManager` wrapper does not fire. + + When `gather_indices` is supplied, the kernel selects via `logits[gather_indices[:n], :]`. + When `gather_indices` is None, `required_logits` is expected to be already pre-gathered to + the layout described above (e.g. when `materialize_only_last_token_logits=True` upstream). + """ + # CudaGraphManager consumes these args, if it exists. + del eager, cache_key + + n_spec = num_speculative_tokens + num_decode_tokens = num_decode * (1 + n_spec) + num_tokens = num_decode_tokens + num_prefill + device = required_logits.device + + token_to_request_index = torch.cat( + [ + torch.arange(num_decode, device=device).repeat_interleave( + 1 + n_spec, output_size=num_decode_tokens + ), + torch.arange(num_decode, num_decode + num_prefill, device=device), + ] + ) + return self.sample_kernel( + required_logits, + num_tokens, + context, + gather_indices=gather_indices, + token_to_request_index=token_to_request_index, + eager=True, + ) diff --git a/megatron/core/inference/sampling/flashinfer_sampling.py b/megatron/core/inference/sampling/flashinfer_sampling.py new file mode 100644 index 00000000000..c89093daeac --- /dev/null +++ b/megatron/core/inference/sampling/flashinfer_sampling.py @@ -0,0 +1,101 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
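+"""FlashInfer-based fused sampling backend.
+
+Note on the speculative path (illustrative numbers): with 2 decode requests,
+1 prefill request, and 2 speculative tokens, Sampling.sample_speculative builds
+token_to_request_index = [0, 0, 0, 1, 1, 1, 2], so temperature/top_k/top_p are
+gathered per token rather than per request in sample_kernel.
+"""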
+ +from typing import Any, Optional + +import torch +from torch import Tensor + +try: + import flashinfer +except ImportError: + flashinfer = None + +from megatron.core.inference.sampling.base import Sampling +from megatron.core.transformer.cuda_graphs import CudaGraphManager + + +class FlashInferSampling(Sampling): + """Fused FlashInfer sampling, with optional CUDA graph capture/replay.""" + + def __init__( + self, vocab_size: int, rng: torch.Generator, config=None, enable_cuda_graph: bool = False + ) -> None: + self._vocab_size = vocab_size + self._rng = rng + if enable_cuda_graph and config is not None and config.cuda_graph_impl == "local": + CudaGraphManager( + config, + self, + function_name="sample_kernel", + need_backward=False, + inline_capture=True, + ) + CudaGraphManager( + config, + self, + function_name="sample_speculative", + need_backward=False, + inline_capture=True, + ) + + def sample_kernel( + self, + logits: Tensor, + n: int, + context, + *, + gather_indices: Optional[Tensor] = None, + token_to_request_index: Optional[Tensor] = None, + eager: bool = False, + cache_key: Any = None, + ) -> Tensor: + """FlashInfer fused top-k / top-p sampling kernel. + + Args: + logits: Logits tensor of shape `[>=n, vocab_size]`. + n: Number of rows to sample. + context: The active DynamicInferenceContext. + gather_indices: When set, sample from `logits[gather_indices[:n], :]`. + token_to_request_index: When set, sampling parameters are gathered per-token + rather than per-request (used by the speculative path). + eager, cache_key: Consumed by `CudaGraphManager` when it wraps this kernel. + + Returns: + Sampled token ids of shape `[n]`. Under CUDA graph replay, this is a static buffer. + """ + # CudaGraphManager consumes these args, if it exists. + del eager, cache_key + + # Read GPU sampling parameters from the per-step gpu_view mirror. The + # CPU source-of-truth (`active_request_metadata`) is pinned but resident + # on CPU, so reading it here would mix devices with `logits`. + gv = context.gpu_view + if token_to_request_index is None: + temperature = gv.temperature[:n] + top_k = gv.top_k[:n] + top_p = gv.top_p[:n] + else: + temperature = gv.temperature[token_to_request_index] + top_k = gv.top_k[token_to_request_index] + top_p = gv.top_p[token_to_request_index] + + # Clamp temperature to avoid division by 0. + temperature = temperature.clamp(min=1e-6) + if gather_indices is None: + scaled = logits[:n] / temperature.unsqueeze(1) + else: + scaled = logits[gather_indices[:n], :] / temperature.unsqueeze(1) + probs = torch.softmax(scaled, dim=-1) + + # Sentinel values disable filtering: + # top_k=vocab_size keeps all tokens, top_p=1.0 keeps the full probability mass. + # TODO: Consider changing the disable flags in the `InferenceRequest`. + top_k_safe = top_k.masked_fill(top_k == 0, self._vocab_size) + top_p_safe = top_p.masked_fill(top_p == 0.0, 1.0) + output = torch.empty(n, device=logits.device, dtype=torch.int64) + output.copy_( + flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, top_k_safe, top_p_safe, generator=self._rng + ) + ) + return output diff --git a/megatron/core/inference/sampling/torch_sampling.py b/megatron/core/inference/sampling/torch_sampling.py new file mode 100644 index 00000000000..79491add5ab --- /dev/null +++ b/megatron/core/inference/sampling/torch_sampling.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
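+"""Torch-native sampling backend.
+
+Bucketing example (illustrative values): three active requests with
+(temperature, top_k, top_p) of (1.0, 0, 0.9), (1.0, 0, 0.9), and (0.7, 50, 0.0)
+form two buckets, so sample_kernel issues two multinomial launches and scatters
+the results back into the original request order.
+"""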
+ +from collections import defaultdict +from typing import Any, List, Optional, Tuple + +import torch +from torch import Tensor + +from megatron.core.inference.sampling.base import Sampling + + +class TorchSampling(Sampling): + """Sampling via bucketed `torch.multinomial`. + + Groups requests into unique buckets by `(temperature, top_k, top_p)` for separate launches. + """ + + def __init__(self, rng: torch.Generator, vocab_size: int) -> None: + self._rng = rng + self._vocab_size = vocab_size + + @staticmethod + def sample_from_logits( + last_token_logits: Tensor, + temperature: float, + top_k: int, + top_p: float, + *, + generator: torch.Generator, + vocab_size: Optional[int] = None, + ) -> Tensor: + """Sample tokens from logits with temperature, top-k, and top-p filtering. + + Shared between dynamic batching and static batching. + + Args: + last_token_logits: Logits of shape `[batch_size, vocab_size]`. + temperature: Temperature scaling factor. + top_k: Top-k filtering value (0 = disabled). + top_p: Top-p (nucleus) filtering value (0.0 = disabled). + generator: RNG used by `torch.multinomial`. + vocab_size: When provided, asserts `top_k < vocab_size` and clamps the + sampled ids to `[0, vocab_size - 1]`. + + Returns: + Sampled token ids of shape `[batch_size]`. + """ + assert isinstance(top_p, float) + assert isinstance(top_k, int) + assert not (top_k > 0 and top_p > 0.0), "Cannot have top-p and top-k both greater than zero" + assert top_p <= 1.0, "top-p should be in (0,1]" + + def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float("-Inf")) + + def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + filter_ = cumulative_probs > top_p + # Clone needed: filter_[:, 1:] and filter_[:, :-1] are overlapping views; + # without clone, each write would corrupt the next read during the shift. + filter_[:, 1:] = filter_[:, :-1].clone() + filter_[..., 0] = 0 + + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float("-Inf")) + + if top_k == 1: + return torch.argmax(last_token_logits, dim=-1) + + # Clone needed: .div_() and masked_fill_() below modify in-place. + last_token_logits = last_token_logits.clone() + if temperature != 1.0: + last_token_logits.div_(temperature) + if top_k > 1: + assert top_k <= last_token_logits.size(1), "top-k is larger than logit size." + if vocab_size: + assert top_k < vocab_size, "top-k is larger than vocab size." + modify_logits_for_top_k_filtering(last_token_logits, top_k) + elif top_p > 0.0: + modify_logits_for_top_p_filtering(last_token_logits, top_p) + + probabilities = last_token_logits.softmax(dim=-1) + sampled = torch.multinomial(probabilities, num_samples=1, generator=generator).view(-1) + + if vocab_size: + sampled = torch.clamp(sampled, min=0, max=(vocab_size - 1)) + + return sampled + + def sample_kernel( + self, + logits: Tensor, + n: int, + context, + *, + gather_indices: Optional[Tensor] = None, + token_to_request_index: Optional[Tensor] = None, + eager: bool = False, + cache_key: Any = None, + ) -> Tensor: + """Bucket active requests by `(temperature, top_k, top_p)` and sample each bucket. + + Args: + logits: Logits tensor of shape `[>=n, vocab_size]`. 
+ n: Number of rows to sample. + context: The active DynamicInferenceContext. + gather_indices: When set, sample from `logits[gather_indices[:n], :]`. + token_to_request_index: When set, the loop dispatches per-token rather than + per-request (used by the speculative path). + eager: Accepted for API symmetry; ignored (TorchSampling has no graph wrapper). + cache_key: Accepted for API symmetry; ignored. + + Returns: + Sampled token ids of shape `[n]`. + """ + # CudaGraphManager consumes these args, if it exists. + del eager, cache_key + + # Group active requests into sampling buckets by (temperature, top_k, top_p). + active_request_count = context.total_request_count - context.paused_request_count + md = context.active_request_metadata + device = torch.cuda.current_device() + + bucket_map: dict = defaultdict(list) + temp = md["temperature"][:active_request_count].tolist() + top_k = md["top_k"][:active_request_count].tolist() + top_p = md["top_p"][:active_request_count].tolist() + for request_index, (t, k, p) in enumerate(zip(temp, top_k, top_p)): + bucket_map[(t, k, p)].append(request_index) + + buckets: List[Tuple] = [(indices, *params) for params, indices in bucket_map.items()] + bucket_index_tensors: List[Tensor] = [ + torch.tensor(indices, device=device, dtype=torch.long) for indices, *_ in buckets + ] + + if gather_indices is not None: + logits = logits[gather_indices[:n], :] + + output = torch.empty(n, device=logits.device, dtype=torch.int64) + token_list = [] + indices_list = [] + for idx_tensor, (_, temp, top_k, top_p) in zip(bucket_index_tensors, buckets): + if token_to_request_index is None: + row_indices = idx_tensor + else: + row_indices = torch.where(torch.isin(token_to_request_index, idx_tensor))[0] + token_list.append( + TorchSampling.sample_from_logits( + logits[row_indices, :], + temp, + top_k, + top_p, + generator=self._rng, + vocab_size=self._vocab_size, + ) + ) + indices_list.append(row_indices) + + sampled_tokens = torch.cat(token_list, dim=0) + sampled_indices = torch.cat(indices_list, dim=0) + output[sampled_indices] = sampled_tokens + return output diff --git a/megatron/core/inference/symmetric_memory.py b/megatron/core/inference/symmetric_memory.py index 254d41ce294..a5269989914 100644 --- a/megatron/core/inference/symmetric_memory.py +++ b/megatron/core/inference/symmetric_memory.py @@ -39,10 +39,13 @@ class SymmetricMemoryBuffer: """ def __init__(self, size_in_mb, process_group): - if not HAVE_TORCH_SYMM_MEM or not HAVE_TRITON: - # This should be hit if the user is running an older - # version of torch, or if they do not have triton - # installed. 
+ self.init_failure_reason: Optional[str] = None + if not HAVE_TORCH_SYMM_MEM: + self.init_failure_reason = "torch.distributed._symmetric_memory not importable" + self.symm_buffer = None + self.symm_mem_hdl = None + elif not HAVE_TRITON: + self.init_failure_reason = "triton not installed" self.symm_buffer = None self.symm_mem_hdl = None else: @@ -52,8 +55,7 @@ def __init__(self, size_in_mb, process_group): self.symm_buffer = symm_mem.empty(numel, dtype=torch.uint8, device='cuda') self.symm_mem_hdl = symm_mem.rendezvous(self.symm_buffer, process_group) except RuntimeError as e: - # If symmetric memory initialization fails, set buffer and handle to None - # This should happen if the process group is not contained within NVlink + self.init_failure_reason = f"{type(e).__name__}: {e}" self.symm_buffer = None self.symm_mem_hdl = None @@ -138,7 +140,7 @@ class SymmetricMemoryManager: """ _buffers: dict[str, SymmetricMemoryBuffer] = {} - _default_size_mb: int = 256 + _default_size_mb: int = 512 @classmethod def get_buffer( diff --git a/megatron/core/inference/text_generation_controllers/mtp_utils_pytorch.py b/megatron/core/inference/text_generation_controllers/mtp_utils_pytorch.py new file mode 100644 index 00000000000..fe5474d0b22 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/mtp_utils_pytorch.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import torch + + +def rewind_kv_cache( + accepted_counts, + prefill_status, + last_kv_block_offset, + kv_length_offsets, + kv_block_counts, + last_kv_block_id, + kv_block_ids, + num_speculative_tokens, + block_size_tokens, + num_active_requests=None, +): + """Update the KV cache bookkeeping for speculative decoding. + + After forward pass with speculative tokens, some tokens may be rejected. + This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. + + When speculative tokens are rejected, we need to: + 1. Update kv_length_offsets (total sequence length) + 2. Update last_kv_block_offset (position within last block) + 3. If rewinding crosses a block boundary: + - Reduce kv_block_counts + - Update last_kv_block_id to point to the previous block + - Clear the entry in kv_block_ids for the released block + + Mutates the input tensors in-place. + + Returns (blocks_to_release, remove_mask). + """ + N = accepted_counts.shape[0] + if num_active_requests is None: + num_active_requests = N + + # Bulk-extract scalars once via .tolist() instead of per-element .item(). + # Avoids N round-trips through the Python/C++ boundary inside the loop. + accepted_list = accepted_counts.tolist() + prefill_list = prefill_status.tolist() + offset_list = last_kv_block_offset.tolist() + length_list = kv_length_offsets.tolist() + block_count_list = kv_block_counts.tolist() + last_block_list = last_kv_block_id.tolist() + kv_block_ids_list = kv_block_ids.tolist() + max_blocks = kv_block_ids.shape[1] + + blocks_to_release = torch.empty_like(last_kv_block_id) + remove_mask = torch.empty(N, device=accepted_counts.device, dtype=torch.bool) + + for i in range(N): + if i >= num_active_requests: + blocks_to_release[i] = 0 + remove_mask[i] = False + continue + + accepted = accepted_list[i] + prefill = prefill_list[i] + last_offset = offset_list[i] + kv_length = length_list[i] + block_count = block_count_list[i] + last_block = last_block_list[i] + + # Number of tokens to rewind (rejected speculative tokens). 
+ # For prefill requests, no speculative tokens were forwarded through the model, + # so there is nothing to rewind. + num_to_rewind = 0 if prefill == 1 else num_speculative_tokens - accepted + + # Save the original offset BEFORE modifying to correctly detect block boundary crossing. + # A request crosses back to a previous block if: original_offset - num_to_rewind < 0 + diff = last_offset - num_to_rewind + remove = diff < 0 + + # Update the offsets + new_offset = diff % block_size_tokens + last_kv_block_offset[i] = new_offset + kv_length_offsets[i] = kv_length - num_to_rewind + + # For requests that crossed back to a previous block, we need to: + # 1. Reduce the block count by 1 + # 2. Get the block ID to release (current last_kv_block_id) + # 3. Update last_kv_block_id to point to the previous block + # 4. Clear the entry in kv_block_ids for the released block + # 5. Release the block back to the allocator + blocks_to_release[i] = last_block + + # Reduce block counts for requests that crossed back + new_block_count = block_count - 1 if remove else block_count + kv_block_counts[i] = new_block_count + + # Update last_kv_block_id to point to the previous block (at index new_count - 1) + prev_idx = max(new_block_count - 1, 0) + prev_block_id = kv_block_ids_list[i][prev_idx] + last_kv_block_id[i] = prev_block_id if remove else last_block + + # Clear the released block entry (at index new_count, which was the old last block) + scatter_idx = min(new_block_count, max_blocks - 1) + if remove: + kv_block_ids[i, scatter_idx] = -1 + + remove_mask[i] = remove + + return blocks_to_release, remove_mask + + +# pylint: disable=line-too-long +def verify_speculative_tokens( + input_tokens, output_tokens, num_decode_requests, num_prefill_requests, num_speculative_tokens +): + """Verify speculative tokens against input tokens and compute acceptance. + + Creates an accepted tokens mask where: + - For prefill requests, the token is always accepted. + - For decode requests, the first token (base token) is always accepted, then we compare + sampled tokens with input tokens and accept consecutive matches. + Then finds the index of the last accepted token per request. + + Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests): + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 + Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] + Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Last one indices [ 1 | 5 | 6 | 9 | 10 ] + + Returns: + tuple: (last_one_indices, accepted_tokens_mask, input_tokens) where + last_one_indices contains the index of the last accepted token per request. 
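+
+    The "consecutive matches" rule is enforced with a cumulative minimum over
+    each request's match mask, e.g. (illustrative): a raw mask
+    [True, True, False, True] becomes [True, True, False, False], so a later
+    match is never accepted across an earlier rejection.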
+ """ + if input_tokens.ndim == 2: + input_tokens = input_tokens.squeeze(0) + + stride = num_speculative_tokens + 1 + active_request_count = num_decode_requests + num_prefill_requests + decode_len = num_decode_requests * stride + + # Initialize mask with False to prevent boundary bleed + accepted_tokens_mask = torch.zeros_like(input_tokens, dtype=torch.bool) + + # Safe decode token verification without cross-batch boundary contamination + decode_mask_2d = None + if num_decode_requests > 0: + decode_inputs = input_tokens[:decode_len].reshape(num_decode_requests, stride) + decode_outputs = output_tokens[:decode_len].reshape(num_decode_requests, stride) + + # Shift outputs right by 1 *within* each request to align sampled tokens with input targets + decode_outputs_shifted = decode_outputs.roll(1, dims=1) + decode_mask_2d = decode_inputs == decode_outputs_shifted + # The first token (base token) is always accepted + decode_mask_2d[:, 0] = True + # Enforce consecutive acceptance: cummin propagates False to the right + decode_mask_2d = decode_mask_2d.cummin(dim=1).values + accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() + + # Make all prefill tokens accepted + if num_prefill_requests > 0: + accepted_tokens_mask[decode_len:] = True + + last_one_indices = torch.full( + (active_request_count,), -1, device=input_tokens.device, dtype=torch.long + ) + + if num_decode_requests > 0: + # Summing the consecutive mask gives the count; subtract 1 for the local index + local_last_indices = decode_mask_2d.sum(dim=1) - 1 + row_offsets = torch.arange(num_decode_requests, device=input_tokens.device) * stride + last_one_indices[:num_decode_requests] = row_offsets + local_last_indices + + if num_prefill_requests > 0: + prefill_valid = torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len + last_one_indices[num_decode_requests:] = prefill_valid + + return last_one_indices, accepted_tokens_mask, input_tokens + + +# pylint: disable=line-too-long +def prepare_next_forward_pass( + num_decode_requests, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_tokens_mask, + input_tokens, + sampled_tokens_buf, + last_accepted_seq_buf, + accepted_tokens_per_request, + accepted_token_counts, + num_speculative_tokens, +): + """Prepare data for the next forward pass after speculative token verification. + + For each active request: + - Store the final sampled tokens for the next forward pass. + - Store the last accepted positions in the packed sequence for serial + MTP computation after verification. + + For decode requests, extract accepted tokens and counts: + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only decode requests (prefill defaults to -1) + Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 + + Writes results into the pre-allocated buffers provided by the caller. + """ + active_request_count = last_one_indices.shape[0] + stride = num_speculative_tokens + 1 + + for pid in range(active_request_count): + idx = last_one_indices[pid].item() + + # Store the final sampled tokens for the next forward pass. + sampled_tokens_buf[pid] = output_tokens[idx] + + # Store the last accepted positions in the packed sequence for serial + # MTP computation after verification. + last_accepted_seq_buf[pid] = required_logit_indices[idx] + + # Extract accepted tokens and counts for decode requests. + # For prefill it is always set to 1. 
For decode, the first token is always accepted, + # then we compare with input tokens and accept the next tokens if its a match. + if pid < num_decode_requests: + base = pid * stride + # Skip the first token of every decode request (i.e a5, b3, c6) + for s in range(num_speculative_tokens): + pos = base + 1 + s + if accepted_tokens_mask[pos]: + accepted_tokens_per_request[pid, s] = input_tokens[pos] + else: + accepted_tokens_per_request[pid, s] = -1 + + count = 0 + for s in range(num_speculative_tokens): + if accepted_tokens_per_request[pid, s].item() != -1: + count += 1 + accepted_token_counts[pid] = count + + +def mamba_state_selective_copy( + intermediate_states, current_states, prefill_status, state_idx, accepted_counts, num_layers +): + """Mamba speculative rewind state update. + + For each decode request, copies + `intermediate[layer, slot, accepted_count, ...]` → + `current[layer, slot, ...]` for every Mamba layer. + """ + N = prefill_status.shape[0] + for i in range(N): + if prefill_status[i].item() == 1: + continue + slot = state_idx[i].item() + accepted = accepted_counts[i].item() + for layer in range(num_layers): + current_states[layer, slot] = intermediate_states[layer, slot, accepted] diff --git a/megatron/core/inference/text_generation_controllers/mtp_utils_triton.py b/megatron/core/inference/text_generation_controllers/mtp_utils_triton.py new file mode 100644 index 00000000000..37ff55c1e99 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/mtp_utils_triton.py @@ -0,0 +1,456 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math + +import torch + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + from unittest.mock import MagicMock + + from megatron.core.utils import null_decorator + + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + HAVE_TRITON = False + + +# --------------------------------------------------------------------------- +# Kernel 1: KV-cache rewind for speculative decoding +# --------------------------------------------------------------------------- +@triton.jit +def _rewind_kv_cache_kernel( + # Per-request input (read-only) + ACCEPTED_COUNTS_PTR, + PREFILL_STATUS_PTR, + # Per-request state (read-write, updated in-place) + LAST_KV_BLOCK_OFFSET_PTR, + KV_LENGTH_OFFSETS_PTR, + KV_BLOCK_COUNTS_PTR, + LAST_KV_BLOCK_ID_PTR, + # 2-D table [N, max_blocks] (read-write) + KV_BLOCK_IDS_PTR, + # Per-request outputs + BLOCKS_TO_RELEASE_PTR, + REMOVE_MASK_PTR, + # Strides / limits + kv_block_ids_stride, + max_blocks_minus_1, + num_active_requests, + # Compile-time constants + NUM_SPEC_TOKENS: tl.constexpr, + BLOCK_SIZE_TOKENS: tl.constexpr, +): + """Rewind KV-cache bookkeeping for one request after speculative verification. + + Grid: may be padded beyond active requests for CUDA-graph compatibility. + Each program handles exactly one request. Programs with + `pid >= num_active_requests` are padding and produce safe no-op outputs. + """ + pid = tl.program_id(0) + + # Padding programs: write safe defaults and skip all state mutation. 
+ if pid >= num_active_requests: + tl.store(BLOCKS_TO_RELEASE_PTR + pid, 0) + tl.store(REMOVE_MASK_PTR + pid, False) + return + + # --- Load per-request scalars --- + accepted = tl.load(ACCEPTED_COUNTS_PTR + pid) + prefill = tl.load(PREFILL_STATUS_PTR + pid) + last_offset = tl.load(LAST_KV_BLOCK_OFFSET_PTR + pid) + kv_length = tl.load(KV_LENGTH_OFFSETS_PTR + pid) + block_count = tl.load(KV_BLOCK_COUNTS_PTR + pid) + last_block_id = tl.load(LAST_KV_BLOCK_ID_PTR + pid) + + # --- Compute rewind (zero for prefill requests) --- + num_to_rewind = tl.where(prefill == 1, 0, NUM_SPEC_TOKENS - accepted) + diff = last_offset - num_to_rewind + remove = diff < 0 + + # Python-style modulo: ((diff % M) + M) % M to handle negative diff + new_offset = ((diff % BLOCK_SIZE_TOKENS) + BLOCK_SIZE_TOKENS) % BLOCK_SIZE_TOKENS + tl.store(LAST_KV_BLOCK_OFFSET_PTR + pid, new_offset) + tl.store(KV_LENGTH_OFFSETS_PTR + pid, kv_length - num_to_rewind) + + # Save current last block id (will be released by caller if remove is True) + tl.store(BLOCKS_TO_RELEASE_PTR + pid, last_block_id) + + # Decrement block count when a block boundary was crossed + new_block_count = tl.where(remove, block_count - 1, block_count) + tl.store(KV_BLOCK_COUNTS_PTR + pid, new_block_count) + + # Gather previous block id from the 2-D table + kv_row_base = pid.to(tl.int64) * kv_block_ids_stride + prev_idx = tl.maximum(new_block_count - 1, 0) + prev_block_id = tl.load(KV_BLOCK_IDS_PTR + kv_row_base + prev_idx) + + # Conditionally update last block id + tl.store(LAST_KV_BLOCK_ID_PTR + pid, tl.where(remove, prev_block_id, last_block_id)) + + # Clear released block entry via scatter + scatter_idx = tl.minimum(new_block_count, max_blocks_minus_1) + current_val = tl.load(KV_BLOCK_IDS_PTR + kv_row_base + scatter_idx) + tl.store(KV_BLOCK_IDS_PTR + kv_row_base + scatter_idx, tl.where(remove, -1, current_val)) + + # Output remove mask for the caller (to release blocks outside this kernel) + tl.store(REMOVE_MASK_PTR + pid, remove) + + +def rewind_kv_cache( + accepted_counts, + prefill_status, + last_kv_block_offset, + kv_length_offsets, + kv_block_counts, + last_kv_block_id, + kv_block_ids, + num_speculative_tokens, + block_size_tokens, + num_active_requests=None, +): + """Launch the KV-cache rewind Triton kernel. + + Args: + num_active_requests: Number of real (non-padding) requests. When the + grid is padded beyond this count, the kernel skips padding + programs so stale data in padding slots cannot corrupt + bookkeeping. Defaults to `accepted_counts.shape[0]` (no + padding). + + Returns: + (blocks_to_release, remove_mask) — same semantics as the original + torch.compile'd `_rewind_kv_cache` (KV-cache portion only; Mamba + state updates are handled separately by the caller). 
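+
+    Illustrative follow-up on the outputs (a sketch; the block-allocator call
+    is assumed and not part of this module):
+
+        blocks, remove = rewind_kv_cache(...)
+        allocator.release(blocks[remove])  # hypothetical allocator API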
+ """ + N = accepted_counts.shape[0] + if num_active_requests is None: + num_active_requests = N + if N == 0: + return ( + torch.empty(0, device=accepted_counts.device, dtype=last_kv_block_id.dtype), + torch.empty(0, device=accepted_counts.device, dtype=torch.bool), + ) + + blocks_to_release = torch.empty_like(last_kv_block_id) + remove_mask = torch.empty(N, device=accepted_counts.device, dtype=torch.bool) + + _rewind_kv_cache_kernel[(N,)]( + accepted_counts, + prefill_status, + last_kv_block_offset, + kv_length_offsets, + kv_block_counts, + last_kv_block_id, + kv_block_ids, + blocks_to_release, + remove_mask, + kv_block_ids_stride=kv_block_ids.stride(0), + max_blocks_minus_1=kv_block_ids.shape[1] - 1, + num_active_requests=num_active_requests, + NUM_SPEC_TOKENS=num_speculative_tokens, + BLOCK_SIZE_TOKENS=block_size_tokens, + ) + return blocks_to_release, remove_mask + + +# --------------------------------------------------------------------------- +# Kernel 2: Verify speculative tokens +# --------------------------------------------------------------------------- +@triton.jit +def _verify_speculative_tokens_kernel( + INPUT_TOKENS_PTR, + OUTPUT_TOKENS_PTR, + # Outputs + ACCEPTED_MASK_PTR, + LAST_ONE_INDICES_PTR, + # Runtime scalars + num_decode_requests, + decode_len, + # Compile-time constants + STRIDE: tl.constexpr, # num_speculative_tokens + 1 + BLOCK_SIZE: tl.constexpr, # next_power_of_2(STRIDE) +): + """Verify speculative tokens for one request. + + Grid: (active_request_count,) + Programs 0..num_decode_requests-1 handle decode requests. + Programs num_decode_requests..end handle prefill requests. + """ + pid = tl.program_id(0) + + if pid < num_decode_requests: + base = pid * STRIDE + offsets = tl.arange(0, BLOCK_SIZE) + valid = offsets < STRIDE + + input_toks = tl.load(INPUT_TOKENS_PTR + base + offsets, mask=valid, other=0) + + # Build shifted output: shifted[i] = output[i-1]. + # Position 0 uses a dummy load (always accepted regardless). + safe_shifted = tl.where(offsets > 0, offsets - 1, 0) + shifted_output = tl.load(OUTPUT_TOKENS_PTR + base + safe_shifted, mask=valid, other=0) + + # First token is always accepted; rest must match shifted output. + match = tl.where(offsets == 0, 1, (input_toks == shifted_output).to(tl.int32)) + match = tl.where(valid, match, 0) + + # Consecutive acceptance via cumulative-sum trick: + # accepted[i] iff cumsum(match)[i] == i + 1 + cumsum = tl.cumsum(match, axis=0) + accepted = (cumsum == (offsets + 1)) & valid + + tl.store(ACCEPTED_MASK_PTR + base + offsets, accepted, mask=valid) + + accepted_count = tl.sum(accepted.to(tl.int32)) + tl.store(LAST_ONE_INDICES_PTR + pid, (base + accepted_count - 1).to(tl.int64)) + else: + # Prefill request — single token, always accepted + prefill_idx = decode_len + (pid - num_decode_requests) + tl.store(ACCEPTED_MASK_PTR + prefill_idx, 1) + tl.store(LAST_ONE_INDICES_PTR + pid, prefill_idx.to(tl.int64)) + + +def verify_speculative_tokens( + input_tokens, output_tokens, num_decode_requests, num_prefill_requests, num_speculative_tokens +): + """Launch the speculative-token verification Triton kernel. + + Returns: + (last_one_indices, accepted_tokens_mask, input_tokens) + matching the original `_verify_speculative_tokens` signature. 
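+
+    Illustrative example (num_speculative_tokens=2): a decode request with input
+    tokens [a5, a6s, a7s] and sampled outputs [a6o, a7o, a8o] always accepts a5,
+    accepts a6s only if a6s == a6o, and accepts a7s only if a6s was accepted and
+    a7s == a7o; acceptance stops at the first mismatch.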
+ """ + if input_tokens.ndim == 2: + input_tokens = input_tokens.squeeze(0) + + device = input_tokens.device + active_request_count = num_decode_requests + num_prefill_requests + stride = num_speculative_tokens + 1 + decode_len = num_decode_requests * stride + + accepted_tokens_mask = torch.zeros_like(input_tokens, dtype=torch.bool) + last_one_indices = torch.full((active_request_count,), -1, device=device, dtype=torch.long) + + if active_request_count > 0: + block_size = triton.next_power_of_2(stride) + _verify_speculative_tokens_kernel[(active_request_count,)]( + input_tokens, + output_tokens, + accepted_tokens_mask, + last_one_indices, + num_decode_requests=num_decode_requests, + decode_len=decode_len, + STRIDE=stride, + BLOCK_SIZE=block_size, + ) + + return last_one_indices, accepted_tokens_mask, input_tokens + + +# --------------------------------------------------------------------------- +# Kernel 3: Prepare speculative tokens for next forward pass +# --------------------------------------------------------------------------- +@triton.jit +def _prepare_next_forward_pass_kernel( + OUTPUT_TOKENS_PTR, + REQUIRED_LOGIT_INDICES_PTR, + LAST_ONE_INDICES_PTR, + INPUT_TOKENS_PTR, + ACCEPTED_MASK_PTR, + # Outputs + SAMPLED_TOKENS_OUT_PTR, + LAST_ACCEPTED_SEQ_OUT_PTR, + ACCEPTED_TOKENS_OUT_PTR, + ACCEPTED_COUNTS_OUT_PTR, + # Strides + accepted_tokens_out_stride, + # Runtime scalars + num_decode_requests, + # Compile-time constants + STRIDE: tl.constexpr, # num_speculative_tokens + 1 + NUM_SPEC_TOKENS: tl.constexpr, + SPEC_BLOCK_SIZE: tl.constexpr, # next_power_of_2(NUM_SPEC_TOKENS) +): + """Gather final tokens and extract accepted speculative tokens per request. + + Grid: (active_request_count,) + """ + pid = tl.program_id(0) + + # --- Gather final sampled token and sequence index for every request --- + idx = tl.load(LAST_ONE_INDICES_PTR + pid) + tl.store(SAMPLED_TOKENS_OUT_PTR + pid, tl.load(OUTPUT_TOKENS_PTR + idx)) + tl.store(LAST_ACCEPTED_SEQ_OUT_PTR + pid, tl.load(REQUIRED_LOGIT_INDICES_PTR + idx)) + + # --- For decode requests: extract accepted tokens and count --- + if pid < num_decode_requests: + base = pid * STRIDE + spec_offsets = tl.arange(0, SPEC_BLOCK_SIZE) + spec_valid = spec_offsets < NUM_SPEC_TOKENS + token_positions = base + 1 + spec_offsets # skip first (base) token + + tokens = tl.load(INPUT_TOKENS_PTR + token_positions, mask=spec_valid, other=0) + mask_val = tl.load(ACCEPTED_MASK_PTR + token_positions, mask=spec_valid, other=0) + accepted = mask_val != 0 + + result = tl.where(accepted & spec_valid, tokens, -1) + + out_base = pid.to(tl.int64) * accepted_tokens_out_stride + tl.store(ACCEPTED_TOKENS_OUT_PTR + out_base + spec_offsets, result, mask=spec_valid) + + count = tl.sum((accepted & spec_valid).to(tl.int64)) + tl.store(ACCEPTED_COUNTS_OUT_PTR + pid, count) + + +def prepare_next_forward_pass( + num_decode_requests, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_tokens_mask, + input_tokens, + sampled_tokens_buf, + last_accepted_seq_buf, + accepted_tokens_per_request, + accepted_token_counts, + num_speculative_tokens, +): + """Launch the prepare-next-forward-pass Triton kernel. + + Writes results into the pre-allocated buffers provided by the caller. 
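+    Exactly `last_one_indices.shape[0]` entries of `sampled_tokens_buf` and
+    `last_accepted_seq_buf` are written (one per active request), and only the
+    first `num_decode_requests` rows of `accepted_tokens_per_request` and
+    `accepted_token_counts`; all other entries are left untouched.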
+ """ + active_request_count = last_one_indices.shape[0] + if active_request_count == 0: + return + + stride = num_speculative_tokens + 1 + spec_block_size = triton.next_power_of_2(num_speculative_tokens) + + _prepare_next_forward_pass_kernel[(active_request_count,)]( + output_tokens, + required_logit_indices, + last_one_indices, + input_tokens, + accepted_tokens_mask, + sampled_tokens_buf, + last_accepted_seq_buf, + accepted_tokens_per_request, + accepted_token_counts, + accepted_tokens_out_stride=accepted_tokens_per_request.stride(0), + num_decode_requests=num_decode_requests, + STRIDE=stride, + NUM_SPEC_TOKENS=num_speculative_tokens, + SPEC_BLOCK_SIZE=spec_block_size, + ) + + +# --------------------------------------------------------------------------- +# Kernel 4: Mamba state selective copy (eliminates temporary allocations) +# --------------------------------------------------------------------------- +@triton.jit +def _mamba_state_selective_copy_kernel( + # Source: intermediate states [L, M, S+1, *state_shape] + SRC_PTR, + # Destination: current states [L, M, *state_shape] + DST_PTR, + # Per-request index arrays + PREFILL_STATUS_PTR, # [N] 0=decode, 1=prefill + STATE_IDX_PTR, # [N] maps request → mamba state slot + ACCEPTED_PTR, # [N] accepted token index per request + # Strides (in elements) + src_stride_layer, + src_stride_slot, + src_stride_spec, + dst_stride_layer, + dst_stride_slot, + # Data size + STATE_SIZE, + # Compile-time + BLOCK_SIZE: tl.constexpr, +): + """Copy intermediate Mamba state to current state for decode requests. + + Grid: (N, L, num_chunks) + - dim 0: active request index + - dim 1: mamba layer index + - dim 2: chunk of the flattened state vector + + No-op for prefill requests. + """ + pid_req = tl.program_id(0) + pid_layer = tl.program_id(1) + pid_chunk = tl.program_id(2) + + # Skip prefill requests immediately. + prefill = tl.load(PREFILL_STATUS_PTR + pid_req) + if prefill == 1: + return + + state_idx = tl.load(STATE_IDX_PTR + pid_req).to(tl.int64) + accepted = tl.load(ACCEPTED_PTR + pid_req).to(tl.int64) + + chunk_start = pid_chunk * BLOCK_SIZE + offsets = tl.arange(0, BLOCK_SIZE) + elem_offsets = chunk_start + offsets + mask = elem_offsets < STATE_SIZE + + src_base = ( + pid_layer.to(tl.int64) * src_stride_layer + + state_idx * src_stride_slot + + accepted * src_stride_spec + ) + dst_base = pid_layer.to(tl.int64) * dst_stride_layer + state_idx * dst_stride_slot + + data = tl.load(SRC_PTR + src_base + elem_offsets, mask=mask) + tl.store(DST_PTR + dst_base + elem_offsets, data, mask=mask) + + +def mamba_state_selective_copy( + intermediate_states, current_states, prefill_status, state_idx, accepted_counts, num_layers +): + """Copy accepted intermediate Mamba states to current states in-place. + + For each decode request, copies + `intermediate[layer, slot, accepted_count, ...]` → + `current[layer, slot, ...]` for every Mamba layer. + + Args: + intermediate_states: `(L, M, S+1, *state_shape)` — intermediate buffer. + current_states: `(L, M, *state_shape)` — current state buffer (updated in-place). + prefill_status: `(N,)` int tensor — 0 for decode, 1 for prefill. + state_idx: `(N,)` int tensor — mamba state slot index per request. + accepted_counts: `(N,)` int tensor — accepted token index per request. + num_layers: number of Mamba layers (first dim of the state tensors). + """ + N = prefill_status.shape[0] + if N == 0: + return + + # The state vector to copy per (layer, request) is the product of all + # trailing dimensions after the speculative-token axis. 
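+    # For example (illustrative shapes), conv states of shape (L, M, S+1, H, D)
+    # give state_size = H * D elements per (layer, slot) copy.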
+ # intermediate shape: (L, M, S+1, *state_shape) → state_size = prod(state_shape) + state_size = math.prod(intermediate_states.shape[3:]) + + BLOCK_SIZE = 1024 + num_chunks = triton.cdiv(state_size, BLOCK_SIZE) + grid = (N, num_layers, num_chunks) + + _mamba_state_selective_copy_kernel[grid]( + intermediate_states, + current_states, + prefill_status, + state_idx, + accepted_counts, + src_stride_layer=intermediate_states.stride(0), + src_stride_slot=intermediate_states.stride(1), + src_stride_spec=intermediate_states.stride(2), + dst_stride_layer=current_states.stride(0), + dst_stride_slot=current_states.stride(1), + STATE_SIZE=state_size, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index abf1bbf585b..87edddea566 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -7,9 +7,11 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union +import numpy as np import torch import torch.nn.functional as F from torch import Tensor +from torch.cuda.nvtx import range_pop, range_push from megatron.core import parallel_state from megatron.core.inference.async_stream import AsyncStream @@ -24,7 +26,11 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding +from megatron.core.inference.utils import ( + get_attention_mask, + set_decode_expert_padding, + set_moe_metadata_sync, +) from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, @@ -39,6 +45,9 @@ get_asyncio_loop, get_model_config, get_pg_size, + nvtx_range_pop, + nvtx_range_push, + round_up_to_nearest_multiple, unwrap_model, ) @@ -51,6 +60,13 @@ HAVE_TE = False from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions +from megatron.core.inference.sampling import FlashInferSampling, Sampling, TorchSampling +from megatron.core.inference.text_generation_controllers.mtp_utils_pytorch import rewind_kv_cache +from megatron.core.inference.text_generation_controllers.mtp_utils_triton import ( + mamba_state_selective_copy, + prepare_next_forward_pass, + verify_speculative_tokens, +) # pylint: disable=line-too-long @@ -127,6 +143,11 @@ def _init_dynamic_sampling_tensors(self): """Initialize tensors needed for dynamic sampling.""" context = self.inference_wrapped_model.inference_context max_requests = context.max_requests + if context.config.materialize_only_last_token_logits: + # Under MTP, each decode request emits (num_speculative_tokens + 1) logit rows + max_logits = max_requests * (self.num_speculative_tokens + 1) + else: + max_logits = context.max_tokens # Callback to get request IDs that should be marked as finished due to stop words self._get_stop_word_finished_ids_callback = None @@ -134,46 +155,79 @@ def _init_dynamic_sampling_tensors(self): device = torch.cuda.current_device() logits_dtype = self.inference_wrapped_model.config.params_dtype - self._sampling_backend = "torch" - self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) - # Speculative tokens tensor will be allocated later when num_speculative_tokens is 
set by the engine - self._accepted_tokens_per_request = None - # MTP tensor will be allocated later when num_speculative_tokens is set by the engine - self._sampled_mtp_tokens_cuda = None - # Last accepted sequence indices for serial MTP computation - self._last_accepted_seq_indices = None + self._sampling_backend = context.config.sampling_backend + self._enable_cuda_graph = self.model_config.cuda_graph_impl == "local" - # Keep track of request metadata. - self._request_metadata: Dict[str, Tensor] = {} - for label, dtype, on_gpu in context.request_metadata_types: - tensor = context.request_metadata[label] - if not on_gpu: - # Create pinned tensors for request metadata that lives on CPU. - # This is metadata which requires D2H copies, such as top_k for torch sampling. - tensor = torch.empty_like(tensor, device="cpu", pin_memory=True) - self._request_metadata[label] = tensor - - # Used for inefficient torch sampling. - if self._sampling_backend == "torch": - self._torch_sampling_buckets: List[Tuple] = [] - - self._init_mtp_sampling_tensor() - - def _init_mtp_sampling_tensor(self): - """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" - if self.num_speculative_tokens is not None and self.num_speculative_tokens > 0: - context = self.inference_wrapped_model.inference_context - max_requests = context.max_requests - device = torch.cuda.current_device() - self._sampled_mtp_tokens_cuda = torch.empty( - [self.num_speculative_tokens, max_requests], dtype=torch.int64, device=device + # Initialize bookkeeping tensors. + if self._enable_cuda_graph: + self._all_logits_cuda = torch.zeros( + (1, max_logits, self.vocab_size), dtype=logits_dtype, device=device ) - self._accepted_tokens_per_request = ( - torch.ones( - [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device - ) - * -1 + else: + self._all_logits_cuda = None + # Speculative path: + # - `self._sampled_tokens_cuda` is pre-allocated by `_init_mtp_sampling_tensors`. + # - The tensor cannot be reused between the Triton kernel and the sampling graph. + # Non-speculative path: + # - `self._sampled_tokens_cuda` is rebound to the output of `sample_kernel`, + # which uses CudaGraphManager syntactic sugar to keep it as a static tensor. + self._sampled_tokens_cuda = None + + # Sampling backend: provides the sampling kernel. + if self._sampling_backend == "flashinfer": + self._sampling: Sampling = FlashInferSampling( + self.vocab_size, + self.sampling_rng, + config=self.model_config, + enable_cuda_graph=self._enable_cuda_graph, + ) + else: + self._sampling: Sampling = TorchSampling(self.sampling_rng, self.vocab_size) + + # Cache values that are constant across inference steps. + self._unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + self._is_last_pp_stage = is_pipeline_last_stage(self.pp_group) + self._tp_size = get_pg_size(self.inference_wrapped_model.tp_group) + self._sp_enabled = self.model_config.sequence_parallel and self._tp_size > 1 + + self._init_mtp_sampling_tensors() + + def _init_mtp_sampling_tensors(self): + """Pre-allocate MTP sampling tensors. + + Addresses must be stable across steps for CUDA graph capture. 
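+        Graph replays reuse the device pointers recorded at capture time, so these
+        buffers are allocated once up front and later steps only write into them
+        (or take views of their leading slices) instead of reallocating.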
+ """ + if not self.num_speculative_tokens: + self._sampled_mtp_tokens_cuda = None + self._accepted_tokens_per_request = None + self._last_accepted_seq_indices = None + return + + context = self.inference_wrapped_model.inference_context + max_requests = context.max_requests + device = torch.cuda.current_device() + self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) + self._sampled_mtp_tokens_cuda = torch.empty( + [self.num_speculative_tokens, max_requests], dtype=torch.int64, device=device + ) + self._accepted_tokens_per_request = ( + torch.ones( + [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device ) + * -1 + ) + self._accepted_token_counts_per_request = torch.zeros( + max_requests, dtype=torch.int64, device=device + ) + self._last_accepted_seq_indices_buf = torch.empty( + max_requests, dtype=torch.int64, device=device + ) + self._last_accepted_seq_indices = None + self._num_mtp_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + self._mtp_token_ids_buf = torch.empty([1, max_requests], dtype=torch.int64, device=device) + self._mtp_position_ids_buf = torch.empty( + [1, max_requests], dtype=torch.int64, device=device + ) @staticmethod def tokenize_prompt(tokenizer, prompt: str, add_BOS: bool = False) -> List[int]: @@ -282,95 +336,6 @@ def detokenize_generations( return text, prompts_plus_generations_segments - def _torch_sampling_func( - self, - last_token_logits: torch.Tensor, - temperature: float, - top_k: int, - top_p: float, - vocab_size: Optional[int] = None, - ): - """Samples the logits to generate outputs - - Given the logits of the last token, this function samples it - according to the parameters defined in sampling_params - and returns the samples. If sampling parameters top_n_logprobs > 0 - at each step it also updates the top_n_logprobs dict. - - Args: - last_token_logits (torch.Tensor): The last token logits. A tensor of - size [batch_size, vocab_size]. - temperature (float): The temperature to use for sampling. - top_k (int): The top-k value to use for sampling. - top_p (float): The top-p value to use for sampling. - vocab_size (int): Obtained from the tokenizer. Defaults to None. - - Returns: - sampled_logits (torch.Tensor): 1D tensor with [batch_size] elements - """ - assert isinstance(top_p, float) - assert isinstance(top_k, int) - assert not (top_k > 0 and top_p > 0.0), "Cannot have top-p and top-k both greater than zero" - assert top_p <= 1.0, "top-p should be in (0,1]" - - def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf.""" - filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(filter_, float("-Inf")) - - def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf.""" - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - - # Filteration based on the cumulative sum. - filter_ = cumulative_probs > top_p - # This shift by 1 is weird and I cannot justify it. This existed - # in the original implementation: - # https://github.com/ari-holtzman/degen/blob/master/gen.py - # and I guess it is needed so keeping it for now. - # Clone needed: filter_[:, 1:] and filter_[:, :-1] are overlapping views; - # without clone, each write would corrupt the next read during the shift. 
- filter_[:, 1:] = filter_[:, :-1].clone() - # Make sure we at least have one token to select from. - filter_[..., 0] = 0 - - # Fill in the filtered part - filter_ = filter_.scatter(1, sorted_indices, filter_) - logits.masked_fill_(filter_, float("-Inf")) - - # Greedy sampling - if top_k == 1: - sampled_logits = torch.argmax(last_token_logits, dim=-1) - else: - # Clone needed: .div_() and masked_fill_() below modify in-place, - # which would mutate the caller's tensor without this clone. - last_token_logits = last_token_logits.clone() - if temperature != 1.0: - last_token_logits.div_(temperature) - if top_k > 1: - assert top_k <= last_token_logits.size(1), "top-k is larger than logit size." - if vocab_size: - assert top_k < vocab_size, "top-k is larger than vocab size." - modify_logits_for_top_k_filtering(last_token_logits, top_k) - - elif top_p > 0.0: - modify_logits_for_top_p_filtering(last_token_logits, top_p) - - # After filtering, we need to recalculate the distribution. - probabilities = last_token_logits.softmax(dim=-1) - - sampled_logits = torch.multinomial( - probabilities, num_samples=1, generator=self.sampling_rng - ).view(-1) - - # If vocab size is provided, make sure the samples are in in the range [0, vocab-size). - if vocab_size: - sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1)) - - return sampled_logits - def sample_from_logits( self, last_token_logits: torch.Tensor, @@ -457,7 +422,14 @@ def sample_from_logits( top_k = sampling_params.top_k temperature = sampling_params.temperature - return self._torch_sampling_func(last_token_logits, temperature, top_k, top_p, vocab_size) + return TorchSampling.sample_from_logits( + last_token_logits, + temperature, + top_k, + top_p, + generator=self.sampling_rng, + vocab_size=vocab_size, + ) def update_generation_status( self, @@ -569,17 +541,37 @@ def _dynamic_step_context_init( position_ids (Tensor): The active position IDs. """ context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) # Remove Float16Module wrapper if it exists unwrapped_model = unwrap_model(self.inference_wrapped_model.model) model_config = get_model_config(unwrapped_model) - # Initialize attention state. + # Initialize attention state (100% CPU computation). + range_push("initialize_attention_state") context.initialize_attention_state( construct_graph_dimensions=construct_graph_dimensions, is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward, ) + range_pop() + + # Single batch CPU-to-GPU transfer of bookkeeping state. + range_push("transfer_bookkeeping_to_gpu") + context.transfer_bookkeeping_to_gpu() + range_pop() + + set_moe_metadata_sync(unwrapped_model) + + # Derive the MTP padded batch size from the existing padded graph dimensions. + # For MoE models this is post EP sync. In eager mode MTP uses locally SP-aligned + # batch size instead. + if context.using_cuda_graph_this_step(): + self._mtp_resolved_padded_count = context.padded_batch_dimensions.req_count + if self._sp_enabled: + self._mtp_resolved_padded_count = round_up_to_nearest_multiple( + self._mtp_resolved_padded_count, self._tp_size + ) + else: + self._mtp_resolved_padded_count = None # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -610,14 +602,6 @@ def _dynamic_step_context_init( # Turn off symmetric all reduces for prefill unwrapped_model.set_symmetric_ar(None) - # Get request metadata for this step. 
- for label, dtype, on_gpu in context.request_metadata_types: - if not on_gpu: - # We need a D2H copy from the context to the pinned memory buffer. - self._request_metadata[label].copy_( - context.request_metadata[label], non_blocking=True - ) - # Get flat tokens, position ids. # If we are running a dummy forward step we want to use the token count agreed upon # by all EP ranks rather than the minimum number of tokens. @@ -628,7 +612,7 @@ def _dynamic_step_context_init( else: return context.current_input_and_position_ids() - def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) -> Tensor: + def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor): """Forward step the model to get logits for dynamic batching. This also handles logits-broadcasting for pipeline parallelism. @@ -638,7 +622,10 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) position_ids (Tensor): The position IDs. """ context = self.inference_wrapped_model.inference_context - active_request_count = context.total_request_count - context.paused_request_count + if context.config.materialize_only_last_token_logits: + logits_seq_len = context.num_last_token_logits + else: + logits_seq_len = context.padded_active_token_count with torch.inference_mode(): logits = self.inference_wrapped_model.run_one_forward_step( @@ -646,6 +633,9 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) ) # logits shape: [1, seq_len, vocab_size] + if not context.config.materialize_only_last_token_logits: + assert logits_seq_len == input_ids.shape[1] + # Note: When speculative decoding is active (num_speculative_tokens > 0), # the model skips MTP computation during the forward pass. MTP logits # will be computed serially after verification to ensure they are @@ -653,13 +643,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) if self.model_is_pipeline_parallel: if context.config.materialize_only_last_token_logits: - if self.num_speculative_tokens > 0: - logits_seq_len = ( - context.num_decode_requests * (self.num_speculative_tokens + 1) - + context.num_prefill_requests - ) - else: - logits_seq_len = active_request_count + logits_seq_len = context.num_last_token_logits else: logits_seq_len = input_ids.shape[1] logits_shape = [1, logits_seq_len, self.vocab_size] @@ -674,140 +658,77 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) pp_group=self.pp_group, ) - return logits - - def _dynamic_step_sample_bookkeeping(self): - """Perform bookkeeping necessary to sample logits for dynamic batching.""" - context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) - - if self._sampling_backend == "torch": - # Bucketize the core sampling parameters. - # Doing so via list comprehension is orders of magnitude faster than via torch. - bucket_map = defaultdict(list) - - # Shorthands for the dictionary comprehension. - temp = self._request_metadata["temperature"][active_request_slice].tolist() - top_k = self._request_metadata["top_k"][active_request_slice].tolist() - top_p = self._request_metadata["top_p"][active_request_slice].tolist() - - for request_index, (t, k, p) in enumerate(zip(temp, top_k, top_p)): - sampling_params = (t, k, p) - bucket_map[sampling_params].append(request_index) - - # Just unpack the key directly! 
- self._torch_sampling_buckets = [ - (indices, *sampling_params) for sampling_params, indices in bucket_map.items() - ] + # Copy logits to contiguous buffer. + if self._enable_cuda_graph: + self._all_logits_cuda[:, :logits_seq_len, :].copy_(logits[:, :logits_seq_len, :]) + else: + self._all_logits_cuda = logits - def _rewind_kv_cache(self): + def _rewind_kv_cache(self) -> tuple: """Update the KV cache bookkeeping for speculative decoding. After forward pass with speculative tokens, some tokens may be rejected. - This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. - - When speculative tokens are rejected, we need to: - 1. Update request_kv_length_offsets (total sequence length) - 2. Update request_last_kv_block_offset (position within last block) - 3. If rewinding crosses a block boundary: - - Reduce request_kv_block_counts - - Update request_last_kv_block_id to point to the previous block - - Clear the entry in request_to_kv_block_ids for the released block - - Release the block back to the allocator + This function "rewinds" the KV cache bookkeeping to reflect only the + accepted tokens. The core bookkeeping rewind runs on CPU (mutating the + CPU source-of-truth tensors in place); the Mamba hybrid-model state + update stays on GPU because it operates on GPU-resident state buffers. + + Returns (blocks_to_release, remove_mask) for the caller to release blocks + back to the allocator outside the compiled graph. """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) - # Get the accepted token counts for each request - # Note: _accepted_token_counts is indexed from 0 to active_request_count-1 - accepted_tokens_per_request = self._accepted_token_counts_per_request[:active_request_count] - - # Number of tokens to rewind (rejected speculative tokens) - num_tokens_to_rewind = self.num_speculative_tokens - accepted_tokens_per_request - - # For prefill requests, no speculative tokens were forwarded through the model, - # so there is nothing to rewind. - request_in_prefill_status = context.request_in_prefill_status_tensor[active_request_slice] - num_tokens_to_rewind[request_in_prefill_status == 1] = 0 - - # Save the original offset BEFORE modifying to correctly detect block boundary crossing - original_offset = context.request_last_kv_block_offset[active_request_slice].clone() - - # Check which requests need to rewind to a previous block BEFORE modifying - # A request crosses back to a previous block if: original_offset - num_tokens_to_rewind < 0 - remove_allocated_blocks_mask = (original_offset - num_tokens_to_rewind) < 0 - - # Update the offsets - context.request_last_kv_block_offset[active_request_slice] = ( - original_offset - num_tokens_to_rewind - ) % context.block_size_tokens - - context.request_kv_length_offsets[active_request_slice] = ( - context.request_kv_length_offsets[active_request_slice] - num_tokens_to_rewind + # accepted_counts is the only GPU input; D2H a small slice so the + # CPU rewind can read its values via .tolist() inside a Python loop. 
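+        # (The copy is just one int64 per active request, but `.cpu()` synchronizes
+        # with the GPU before the CPU-side rewind loop can run.)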
+ accepted_tokens_per_request_cpu = self._accepted_token_counts_per_request[ + :active_request_count + ].cpu() + + blocks_to_release, remove_mask = rewind_kv_cache( + accepted_counts=accepted_tokens_per_request_cpu, + prefill_status=context.request_in_prefill_status_tensor[active_request_slice], + last_kv_block_offset=context.request_last_kv_block_offset[active_request_slice], + kv_length_offsets=context.request_kv_length_offsets[active_request_slice], + kv_block_counts=context.request_kv_block_counts[active_request_slice], + last_kv_block_id=context.request_last_kv_block_id[active_request_slice], + kv_block_ids=context.request_to_kv_block_ids[active_request_slice], + num_speculative_tokens=self.num_speculative_tokens, + block_size_tokens=context.block_size_tokens, + num_active_requests=active_request_count, ) - # No need to update request_query_lengths (It will be set correctly in the next iteration) - - # For requests that crossed back to a previous block, we need to: - # 1. Reduce the block count by 1 - # 2. Get the block ID to release (current request_last_kv_block_id) - # 3. Update request_last_kv_block_id to point to the previous block - # 4. Clear the entry in request_to_kv_block_ids for the released block - # 5. Release the block back to the allocator - if remove_allocated_blocks_mask.any(): - # Get indices of requests that need to release a block (relative to active requests) - requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] - # Convert to absolute indices in the context tensors - absolute_indices = requests_needing_release + context.paused_request_count - - # No clone needed: advanced (fancy) indexing with a tensor already returns - # a copy, not a view. - blocks_to_release = context.request_last_kv_block_id[absolute_indices] - - # Reduce block counts for requests that crossed back - context.request_kv_block_counts[absolute_indices] -= 1 - - # Get the new block counts after decrement - new_block_counts = context.request_kv_block_counts[absolute_indices] - - # Update request_last_kv_block_id to point to the previous block - # and clear the released block entry in request_to_kv_block_ids - # Vectorized implementation using advanced indexing: - # Note: new_block_counts is guaranteed to be > 0 for all requests here, since - # crossing back to a previous block implies the request had at least 2 blocks. - - # Update request_last_kv_block_id to point to the previous block (at index new_count - 1) - context.request_last_kv_block_id[absolute_indices] = context.request_to_kv_block_ids[ - absolute_indices, new_block_counts - 1 - ] - - # Clear the released block entry (at index new_count, which was the old last block) - context.request_to_kv_block_ids[absolute_indices, new_block_counts] = -1 - - # Release the blocks back to the allocator - context.kv_block_allocator.release_memory_blocks(blocks_to_release) - - # Mamba speculative rewind state update + # Mamba speculative rewind stays on GPU because it mutates GPU-resident + # SSM/conv state that the next forward pass reads directly. if context.is_hybrid_model: - active_mamba_indices = context.mamba_metadata.request_to_mamba_state_idx[ + cuda_device = torch.cuda.current_device() + # gpu_view.request_in_prefill_status was uploaded by this step's + # coalesced H2D and mirrors the active-slice CPU values, so we + # don't need to re-upload prefill_status for the Mamba kernels. 
+ prefill_status_gpu = context.gpu_view.request_in_prefill_status[:active_request_count] + accepted_counts_gpu = self._accepted_token_counts_per_request[:active_request_count] + mamba_state_idx = context.mamba_metadata.request_to_mamba_state_idx[ active_request_slice - ] - is_decode_mask = context.request_in_prefill_status_tensor[active_request_slice] == 0 - decode_mamba_indices = active_mamba_indices[is_decode_mask] - accepted_tokens_per_decode_request = accepted_tokens_per_request[is_decode_mask] - - if decode_mamba_indices.numel() > 0: - context.mamba_conv_states[:, decode_mamba_indices] = ( - context.mamba_intermediate_conv_states[ - :, decode_mamba_indices, accepted_tokens_per_decode_request - ] - ) - context.mamba_ssm_states[:, decode_mamba_indices] = ( - context.mamba_intermediate_ssm_states[ - :, decode_mamba_indices, accepted_tokens_per_decode_request - ] - ) + ].to(cuda_device, non_blocking=True) + mamba_state_selective_copy( + intermediate_states=context.mamba_intermediate_conv_states, + current_states=context.mamba_conv_states, + prefill_status=prefill_status_gpu, + state_idx=mamba_state_idx, + accepted_counts=accepted_counts_gpu, + num_layers=context.num_mamba_layers, + ) + mamba_state_selective_copy( + intermediate_states=context.mamba_intermediate_ssm_states, + current_states=context.mamba_ssm_states, + prefill_status=prefill_status_gpu, + state_idx=mamba_state_idx, + accepted_counts=accepted_counts_gpu, + num_layers=context.num_mamba_layers, + ) + + return blocks_to_release, remove_mask def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: """Sample tokens from 2D logits using existing sampling parameters. @@ -818,21 +739,12 @@ def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: Returns: Tensor: Sampled tokens of shape [num_requests]. """ - spec_token_list = [] - indices_list = [] - for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: - request_indices_tensor = torch.tensor( - request_indices, device=logits_2d.device, dtype=torch.long - ) - spec_token_list.append( - self._torch_sampling_func(logits_2d[request_indices_tensor, :], temp, top_k, top_p) - ) - indices_list.append(request_indices_tensor) - - spec_tokens = torch.empty(logits_2d.shape[0], device=logits_2d.device, dtype=torch.int64) - for tokens, indices in zip(spec_token_list, indices_list): - spec_tokens[indices] = tokens - return spec_tokens + return self._sampling.sample_kernel( + logits_2d, + logits_2d.shape[0], + self.inference_wrapped_model.inference_context, + eager=True, + ) def _compute_serial_mtp_and_sample(self): """Compute MTP logits serially after verification and sample speculative tokens. @@ -846,14 +758,15 @@ def _compute_serial_mtp_and_sample(self): (scattered along the first dimension) between MTP depths to avoid a redundant gather + scatter round-trip per depth. """ + nvtx_range_push("mtp-spec-decoding/serial-mtp-init") context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_slice = slice(context.paused_request_count, context.total_request_count) - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + unwrapped_model = self._unwrapped_model # On non-last pipeline stages, the model won't have decoder hidden states. 
- has_mtp = is_pipeline_last_stage(self.pp_group) and hasattr( + has_mtp = self._is_last_pp_stage and hasattr( unwrapped_model, '_decoder_hidden_states_cache' ) @@ -864,7 +777,7 @@ def _compute_serial_mtp_and_sample(self): # When SP is active the decoder output is in scattered format # [S/TP, B, H], but _last_accepted_seq_indices are indices into # the full (gathered) sequence. - if self.model_config.sequence_parallel: + if self._sp_enabled: hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.inference_wrapped_model.tp_group ) @@ -874,344 +787,281 @@ def _compute_serial_mtp_and_sample(self): last_accepted_hidden = None # Compute position IDs for the next tokens. - # After rewind, request_kv_length_offsets has been adjusted. The actual - # KV cache length is: adjusted_offset + processed_tokens. - # The next position to predict starts at that cache length. - adjusted_offsets = context.request_kv_length_offsets[active_slice] - processed_tokens = context.request_query_lengths[active_slice] - base_position = adjusted_offsets + processed_tokens + # After rewind, request_kv_length_offsets has been adjusted. Read from + # CPU context (post-rewind values), NOT gpu_view (stale pre-rewind snapshot). + # The next position to predict is: adjusted_offset + processed_tokens. + cuda_device = torch.cuda.current_device() + adjusted_offsets = context.request_kv_length_offsets[active_slice].to( + cuda_device, non_blocking=True + ) + processed_tokens = context.request_query_lengths[active_slice].to( + cuda_device, non_blocking=True + ) + # Cast to int64 to match CUDA graph capture dtype expectations. + base_position = (adjusted_offsets + processed_tokens).to(torch.int64) # Start with the freshly sampled base token. next_token_ids = self._sampled_tokens_cuda[:active_request_count].clone() current_hidden = last_accepted_hidden if has_mtp else None - # Compute padding needed to make batch a multiple of tp_size for SP compatibility. - tp_size = get_pg_size(self.inference_wrapped_model.tp_group) - sp_enabled = self.model_config.sequence_parallel and tp_size > 1 - if sp_enabled: - pad_count = (tp_size - active_request_count % tp_size) % tp_size - padded_count = active_request_count + pad_count + # Compute padding needed to make batch compatible with SP and CUDA graphs. + if getattr(self, '_mtp_resolved_padded_count', None) is not None: + # CUDA-graph path: use the EP-synced padded count. + padded_count = self._mtp_resolved_padded_count + assert not self._sp_enabled or padded_count % self._tp_size == 0 + elif has_mtp: + # Eager path: pad only for SP alignment. + padded_count = active_request_count + if self._sp_enabled: + padded_count = round_up_to_nearest_multiple(padded_count, self._tp_size) else: - pad_count = 0 + padded_count = active_request_count + pad_count = padded_count - active_request_count - # Pad hidden states to align with the tensor parallel size. - if has_mtp and sp_enabled: - if pad_count > 0: - current_hidden = F.pad(current_hidden, (0, 0, 0, 0, 0, pad_count)) + # Pad hidden states and scatter for sequence parallelism. 
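+        # (Padded rows carry zero hidden states and token id 0; their logits are
+        # discarded when `mtp_logits` is sliced back to `active_request_count`.)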
+ if has_mtp: + current_hidden = F.pad(current_hidden, (0, 0, 0, 0, 0, pad_count)) + if self._sp_enabled: + current_hidden = scatter_to_sequence_parallel_region( + current_hidden, group=self.inference_wrapped_model.tp_group + ) - current_hidden = scatter_to_sequence_parallel_region( - current_hidden, group=self.inference_wrapped_model.tp_group - ) + token_ids_buf = self._mtp_token_ids_buf[:, :padded_count] + position_ids_buf = self._mtp_position_ids_buf[:, :padded_count] + + # Zero-fill padding slots so the embedding layer never sees out-of-range IDs. + token_ids_buf[0, active_request_count:] = 0 + position_ids_buf[0, active_request_count:] = 0 - num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) - for depth in range(num_depths): - position_ids = (base_position + depth).unsqueeze(0) # [1, active_request_count] - token_ids = next_token_ids.unsqueeze(0) # [1, active_request_count] + nvtx_range_pop("mtp-spec-decoding/serial-mtp-init") + for depth in range(self._num_mtp_depths): + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}") + + token_ids_buf[0, :active_request_count] = next_token_ids + position_ids_buf[0, :active_request_count] = base_position + depth mtp_logits_2d = None if has_mtp: - # Pad token_ids and position_ids each iteration (they change per depth). - if pad_count > 0: - token_ids = F.pad(token_ids, (0, pad_count)) - position_ids = F.pad(position_ids, (0, pad_count)) - + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/forward") + mtp_depth = None if unwrapped_model.mtp.mtp_use_repeated_layer else depth current_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( hidden_states=current_hidden, - next_token_ids=token_ids, - position_ids=position_ids, - depth=depth, + next_token_ids=token_ids_buf, + position_ids=position_ids_buf, + depth=mtp_depth, + eager=not context.using_cuda_graph_this_step(), + cache_key=( + ("mtp", padded_count, mtp_depth) + if context.using_cuda_graph_this_step() + else None + ), ) + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/forward") - # Strip padding from logits only. Hidden states stay padded+SP + # Strip padding from logits only. Hidden states stay padded+SP # between depths to avoid redundant gather/scatter round-trips. - if pad_count > 0: - mtp_logits = mtp_logits[:active_request_count] + mtp_logits = mtp_logits[:active_request_count] # mtp_logits: [active_request_count, 1, vocab_size] mtp_logits_2d = mtp_logits.squeeze(1) # [active_request_count, vocab_size] # Broadcast MTP logits across pipeline stages. if self.model_is_pipeline_parallel: + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/pp-broadcast") mtp_logits_2d = broadcast_from_last_pipeline_stage( [active_request_count, self.vocab_size], dtype=self.model_config.params_dtype, tensor=mtp_logits_2d, pp_group=self.pp_group, ) + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/pp-broadcast") # Sample speculative token using the same sampling parameters. + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/sample") spec_tokens = self._sample_from_logits_2d(mtp_logits_2d) self._sampled_mtp_tokens_cuda[depth, :active_request_count] = spec_tokens + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/sample") # Use sampled token as input for the next depth. next_token_ids = spec_tokens + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}") # Clean up cached hidden states. 
if has_mtp: del unwrapped_model._decoder_hidden_states_cache - def _sample_speculative_logits( - self, required_logits: Tensor, request_in_prefill_status_tensor: Tensor - ) -> tuple: - """Sample tokens from logits using sampling buckets. - - For torch sampling buckets: [request_indices, temp, top_k, top_p] - - Example with 5 requests: - token_to_request_idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] - required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] - - Sampling buckets: [[[0,2], temp1, top_k1, top_p1], [[1], temp3, top_k3, top_p3], [[3, 4], temp2, top_k2, top_p2]] - - Final output tokens : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s] # Shape [11] - (Rearranged from sampling bucket order back to input order using token_order) - - Returns: - tuple: (output_tokens, repeats) where output_tokens has shape [total_required_tokens] - """ - repeats = torch.where( - request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1 - ) - token_to_request_index = torch.repeat_interleave( - torch.arange( - len(request_in_prefill_status_tensor), - device=request_in_prefill_status_tensor.device, - ), - repeats, - ) - - output_tokens_jumbled_list = [] - token_order_list = [] - - for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: - request_indices_tensor = torch.tensor( - request_indices, device=token_to_request_index.device - ) - required_indices = torch.where( - torch.isin(token_to_request_index, request_indices_tensor) - )[0] - output_tokens_jumbled_list.append( - self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) - ) - token_order_list.append(required_indices) - - output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) - output_tokens = torch.empty( - len(output_tokens_jumbled), - device=output_tokens_jumbled.device, - dtype=output_tokens_jumbled.dtype, - ) - token_order = torch.cat(token_order_list, dim=0) - # Rearrange output tokens from sampling_bucket request order back to input ids order - output_tokens[token_order] = output_tokens_jumbled - - return output_tokens, repeats - def _verify_speculative_tokens( self, output_tokens: Tensor, input_tokens_required: Tensor, - request_in_prefill_status_tensor: Tensor, - repeats: Tensor, num_decode_requests: int, num_prefill_requests: int, active_request_count: int, ) -> tuple: - """Verify speculative tokens against input tokens and compute acceptance. - - Creates an accepted tokens mask where: - - For prefill requests, the token is always accepted. - - For decode requests, the first token (base token) is always accepted, then we compare - sampled tokens with input tokens and accept consecutive matches. - Then finds the index of the last accepted token per request. - - Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests): - input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 - Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] - Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] - Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - Last one indices [ 1 | 5 | 6 | 9 | 10 ] - - Returns: - tuple: (last_one_indices, accepted_tokens_mask, input_tokens_required) where - last_one_indices contains the index of the last accepted token per request. 
- """ - if input_tokens_required.ndim == 2: - assert ( - input_tokens_required.shape[0] == 1 - ), f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" - input_tokens_required = input_tokens_required.squeeze(0) - - # Initialize mask with False to prevent boundary bleed - accepted_tokens_mask = torch.zeros_like(input_tokens_required, dtype=torch.bool) - - # Make all prefill tokens accepted - token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) - accepted_tokens_mask[token_to_prefill_idx == 1] = True - - # Safe decode token verification without cross-batch boundary contamination - decode_mask_2d = None - if num_decode_requests > 0: - decode_len = num_decode_requests * (self.num_speculative_tokens + 1) - - decode_inputs = input_tokens_required[:decode_len].reshape( - num_decode_requests, self.num_speculative_tokens + 1 - ) - decode_outputs = output_tokens[:decode_len].reshape( - num_decode_requests, self.num_speculative_tokens + 1 - ) - - # Shift outputs right by 1 *within* each request to align sampled tokens with input targets - decode_outputs_shifted = decode_outputs.roll(1, dims=1) - decode_mask_2d = decode_inputs == decode_outputs_shifted - # The first token (base token) is always accepted - decode_mask_2d[:, 0] = True - # Enforce consecutive acceptance: cummin propagates False to the right - decode_mask_2d = decode_mask_2d.cummin(dim=1).values - accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() - - last_one_indices = torch.full( - (active_request_count,), -1, device=input_tokens_required.device + """Verify speculative tokens against input tokens (Triton kernel).""" + return verify_speculative_tokens( + input_tokens=input_tokens_required, + output_tokens=output_tokens, + num_decode_requests=num_decode_requests, + num_prefill_requests=num_prefill_requests, + num_speculative_tokens=self.num_speculative_tokens, ) - if num_decode_requests > 0: - # Summing the consecutive mask gives the count; subtract 1 for the local index - local_last_indices = decode_mask_2d.sum(dim=1) - 1 - row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * ( - self.num_speculative_tokens + 1 - ) - last_one_indices[:num_decode_requests] = row_offsets + local_last_indices - - if num_prefill_requests > 0: - decode_len = num_decode_requests * (self.num_speculative_tokens + 1) - prefill_valid = ( - torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len - ) - last_one_indices[num_decode_requests:] = prefill_valid - - return last_one_indices, accepted_tokens_mask, input_tokens_required - - def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_ids: Tensor): + def _dynamic_step_sample_logits_and_verify_tokens(self, input_ids: Tensor): """ Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ - context.paused_request_count : context.total_request_count - ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] - - num_prefill_requests = request_in_prefill_status_tensor.sum().item() - num_decode_requests = active_request_count - num_prefill_requests - - # Get the logit indices for tokens that need sampling. 
- # These indices are always needed for input_ids slicing and tracking - # accepted sequence positions, even when logits are pre-sliced. - required_logit_indices = context.speculative_required_logit_indices(logits.device) + # Sampling-side request counts: padded when running a captured graph. + # Verify uses the actual counts so the Triton kernels operate on the real workload. + use_graph_for_sampling = ( + self._sampling_backend == "flashinfer" + and self._enable_cuda_graph + and context.using_cuda_graph_this_step() + ) + if use_graph_for_sampling: + sample_num_decode = context.padded_batch_dimensions.decode_req_count + sample_num_prefill = context.padded_batch_dimensions.prefill_req_count + else: + sample_num_decode = context.num_decode_requests + sample_num_prefill = context.num_prefill_requests + + # Logit indices for tokens that need sampling. + # Padded under graph capture so the captured `gather_indices` input has a stable shape. + # Padded slots resolve to row 0; verify and prepare-next read only the actual prefix, + # so the padded-row samples produced by the captured kernel are discarded. + nvtx_range_push("mtp-spec-decoding/verify/logit-indices") + # Use pre-allocated buffer for CUDA graph compatibility. + logits = self._all_logits_cuda + # `speculative_required_logit_indices()` already returns padded indices when + # running a captured graph (`num_last_token_logits` uses the padded counts and + # `pad_active_slices` zero-pads the trailing slots), so the call site does not + # need to re-pad here. + required_logit_indices = context.speculative_required_logit_indices() if context.config.materialize_only_last_token_logits: # last_token_logits already selected exactly the required positions. - required_logits = logits.squeeze(0) + sample_logits = logits.squeeze(0) + sample_gather_indices = None else: - required_logits = logits.squeeze(0)[ - required_logit_indices, : - ] # Shape [num_required, vocab_size] + # Push the gather inside the captured kernel: + # pass the full per-token logits buffer (constant shape) plus the padded indices. + sample_logits = logits.squeeze(0) + sample_gather_indices = required_logit_indices + nvtx_range_pop("mtp-spec-decoding/verify/logit-indices") # Sample tokens from logits - output_tokens, repeats = self._sample_speculative_logits( - required_logits, request_in_prefill_status_tensor + nvtx_range_push("mtp-spec-decoding/verify/sample") + output_tokens = self._sampling.sample_speculative( + sample_logits, + sample_num_decode, + sample_num_prefill, + self.num_speculative_tokens, + context, + gather_indices=sample_gather_indices, + eager=not use_graph_for_sampling, + cache_key=( + ("sample_speculative", sample_num_decode, sample_num_prefill) + if use_graph_for_sampling + else None + ), ) + nvtx_range_pop("mtp-spec-decoding/verify/sample") + + num_prefill_requests = context.num_prefill_requests + num_decode_requests = active_request_count - num_prefill_requests # Verify speculative tokens against input tokens. + nvtx_range_push("mtp-spec-decoding/verify/verify-tokens") input_tokens_required = input_ids[0, required_logit_indices] last_one_indices, accepted_tokens_mask, input_tokens_required = ( self._verify_speculative_tokens( output_tokens, input_tokens_required, - request_in_prefill_status_tensor, - repeats, num_decode_requests, num_prefill_requests, active_request_count, ) ) - - # Store the final sampled tokens for the next forward pass. 
- final_sampled_tokens = output_tokens[last_one_indices] - self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens - - # Store the last accepted positions in the packed sequence for serial - # MTP computation after verification. - self._last_accepted_seq_indices = required_logit_indices[last_one_indices] - - # Extract accepted tokens and counts for decode requests. - # For prefill it is always set to 1. For decode, the first token is always accepted, - # then we compare with input tokens and accept the next tokens if its a match. - # - # Example (continuing from above): - # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only decode requests (prefill defaults to -1) - # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 - input_tokens_required[accepted_tokens_mask == 0] = -1 # Mask out non-accepted tokens - input_tokens_decode_mode = input_tokens_required[ - : num_decode_requests * (self.num_speculative_tokens + 1) - ] - input_tokens_reshaped = input_tokens_decode_mode.reshape( - -1, self.num_speculative_tokens + 1 - ) # shape: [num_decode_requests, num_speculative_tokens + 1] - - # Skip the first token of every decode request (i.e a5, b3, c6) - accepted_tokens = input_tokens_reshaped[:, 1:] - self._accepted_tokens_per_request[: accepted_tokens.shape[0], :] = accepted_tokens - self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum( - dim=1 + nvtx_range_pop("mtp-spec-decoding/verify/verify-tokens") + + nvtx_range_push("mtp-spec-decoding/verify/prepare-next") + self._prepare_speculative_tokens_for_next_forward_pass( + num_decode_requests, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_tokens_mask, + input_tokens_required, ) + nvtx_range_pop("mtp-spec-decoding/verify/prepare-next") - def _dynamic_step_sample_logits(self, logits: Tensor): - """Sample tokens from logits for dynamic batching. + def _prepare_speculative_tokens_for_next_forward_pass( + self, + num_decode_requests: int, + output_tokens: torch.Tensor, + required_logit_indices: torch.Tensor, + last_one_indices: torch.Tensor, + accepted_tokens_mask: torch.Tensor, + input_tokens_required: torch.Tensor, + ): + """Prepare accepted speculative tokens for the next forward pass (Triton kernel). - Args: - logits (Tensor): The logits from the forward pass. + Example: + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] (decode only; prefill → -1) + Accepted token counts [ 1 | 2 | 0 ] (prefill defaults to 0) """ + active_request_count = last_one_indices.shape[0] + prepare_next_forward_pass( + num_decode_requests=num_decode_requests, + output_tokens=output_tokens, + required_logit_indices=required_logit_indices, + last_one_indices=last_one_indices, + accepted_tokens_mask=accepted_tokens_mask, + input_tokens=input_tokens_required, + sampled_tokens_buf=self._sampled_tokens_cuda, + last_accepted_seq_buf=self._last_accepted_seq_indices_buf, + accepted_tokens_per_request=self._accepted_tokens_per_request, + accepted_token_counts=self._accepted_token_counts_per_request, + num_speculative_tokens=self.num_speculative_tokens, + ) + # Expose the active slice so downstream code sees the right length. 
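+        # (The underlying `max_requests`-sized buffer stays allocated, so its
+        # address remains stable for CUDA graph replays; only this view is rebound
+        # each step.)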
+ self._last_accepted_seq_indices = self._last_accepted_seq_indices_buf[:active_request_count] + + def _dynamic_step_sample_logits(self): + """Sample tokens from logits for dynamic batching.""" # TODO(ksanthanam): Evaluate whether it makes more sense to sample on 1 rank # and then broadcast the sampled tokens rather than broadcasting the raw logits. - # Last token logits. context = self.inference_wrapped_model.inference_context - if context.config.materialize_only_last_token_logits: - # When materialize_only_last_token_logits is true, last_token_logits is - # already called in the forward pass of GPT. - required_token_logits = logits.squeeze(0) - else: - # todo : Should do verification here and get approrpiate las token logits - required_token_logits = context.last_token_logits(logits) - - if self._sampling_backend == "torch": - # Concatenate the outputs once to prevent repeated small writes. - token_list = [] - indices_list = [] - - # e.g torch sample buckets will be - # i.e (for all unique comibnation of t, topk, topk what are the associated - # requests indices (based on the active slices) - # [ [req at index 0, req at index 2], t1, topk1, topp1 ]] - # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] - for indices, temp, top_k, top_p in self._torch_sampling_buckets: - token_list.append( - self._torch_sampling_func(required_token_logits[indices, :], temp, top_k, top_p) - ) - indices_list.append(torch.tensor(indices)) - - # Single write to the output tensor. - sampled_tokens = torch.cat(token_list, dim=0) - sampled_indices = torch.cat(indices_list, dim=0) - - self._sampled_tokens_cuda[sampled_indices] = sampled_tokens + active_request_count = context.total_request_count - context.paused_request_count + use_graph = ( + self._sampling_backend == "flashinfer" + and self._enable_cuda_graph + and context.using_cuda_graph_this_step() + ) + # Padded count when running a captured graph (cache key buckets); actual otherwise. + n = context.padded_active_request_count if use_graph else active_request_count + # When `materialize_only_last_token_logits` is true the forward pass already + # selected the right rows. Otherwise we point the kernel at the per-request + # last-token positions via `gather_indices`; padded slots safely fan in to row 0. + gather_indices = ( + None + if context.config.materialize_only_last_token_logits + else context.gpu_view.active_request_last_token_idxs + ) + self._sampled_tokens_cuda = self._sampling.sample_kernel( + self._all_logits_cuda.squeeze(0), + n, + context, + gather_indices=gather_indices, + eager=not use_graph, + cache_key=("sample", n) if use_graph else None, + ) def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: """Perform bookkeeping necessary to compute log probs for dynamic batching. @@ -1220,25 +1070,27 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: return_log_probs (bool): Whether to return the sampled log_probs. 
""" context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) - - return_log_probs = self._request_metadata["return_log_probs"][active_request_slice] - top_n_log_probs = self._request_metadata["top_n_logprobs"][active_request_slice] > 0 + active_request_count = context.total_request_count - context.paused_request_count - return return_log_probs.any(), top_n_log_probs.any() + return ( + (context.active_request_metadata["return_log_probs"][:active_request_count]).any(), + (context.active_request_metadata["top_n_logprobs"][:active_request_count] > 0).any(), + ) - def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: - """Collect and map routing indices per request for MoE router recording. + def _router_record_bookkeeping(self) -> Optional[np.ndarray]: + """Collect flat routing indices for MoE router recording. - This method retrieves recorded routing decisions and maps them to individual - requests using the context's request_ids and query_lengths. Uses the context's - routing_metadata when available (which handles CUDA graph static buffers automatically). - Must be called while context attributes are still valid (before request transitions). + Retrieves recorded routing decisions via the context's routing_metadata + (which handles CUDA graph static buffers), performs the TP all-gather + when sequence parallelism is active, strips CUDA padding, and returns + a flat CPU numpy array aligned with the context's active-token layout. + Must be called while context attributes are still valid (before request + transitions). Returns: - Optional[Dict[int, Tensor]]: A dictionary mapping request_id to a tensor of - shape [num_tokens, num_layers, topk]. Returns None if routing replay is - disabled or no routing data was recorded. + Optional[np.ndarray]: Flat routing array of shape + [active_token_count, num_layers, topk], or None if routing + replay is disabled or no routing data was recorded. """ config = self.inference_wrapped_model.model.config if not config.moe_enable_routing_replay: @@ -1254,10 +1106,6 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: if stacked_routing is None: return None - # Get active request info from context - active_request_slice = slice(context.paused_request_count, context.total_request_count) - active_request_ids = context.request_ids[active_request_slice].tolist() - active_query_lengths = context.request_query_lengths[active_request_slice].tolist() active_token_count = context.active_token_count # Get TP group for all-gather if using sequence parallelism @@ -1268,39 +1116,45 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: # All-gather across TP group if using sequence parallelism (tp_size > 1) if tp_size > 1 and get_model_config(self.inference_wrapped_model.model).sequence_parallel: + # With SP, the model processes padded_active_token_count tokens total, + # scattered evenly across TP ranks. Each rank routes + # padded_active_token_count // tp_size tokens through MoE layers. + # + # The CUDA-graph static buffer path in get_routing_indices() may return + # a tensor sliced to active_token_count (the global unpadded count), + # which can be larger than the per-rank valid count. Truncate to the + # true per-rank count before the all-gather so we only gather valid + # routing data and reconstruct the full sequence in the correct order. 
+ local_token_count = context.padded_active_token_count // tp_size + + stacked_routing = stacked_routing[:local_token_count] # gather_from_sequence_parallel_region gathers along dim 0 - # [local_token_count, num_layers, topk] -> [global_token_count, num_layers, topk] + # [local_token_count, num_layers, topk] -> [padded_token_count, num_layers, topk] stacked_routing = gather_from_sequence_parallel_region(stacked_routing, group=tp_group) - # Slice to real tokens (remove CUDA padding) - stacked_routing = stacked_routing[:active_token_count] - - # Split by request along token dimension - # stacked_routing has shape [active_token_count, num_layers, topk] - routing_splits = stacked_routing.split(active_query_lengths, dim=0) - - # Map to request IDs - routing_indices_per_request = {} - for req_id, routing_split in zip(active_request_ids, routing_splits): - # routing_split has shape [num_tokens_for_request, num_layers, topk] - routing_indices_per_request[req_id] = routing_split + # Slice to real tokens (remove CUDA padding), move to CPU as numpy with target dtype + _ri_dtype = np.int16 if (config.num_moe_experts or 0) <= 32768 else np.int32 + return stacked_routing[:active_token_count].cpu().numpy().astype(_ri_dtype) - return routing_indices_per_request - - def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: + def _dynamic_step_calculate_log_probs(self) -> Optional[Tensor]: """Calculate log probs from logits.""" context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count + # This code cannot be reached when we are using speculative decode. + assert self.num_speculative_tokens == 0 + logits_seq_len = ( + active_request_count + if context.config.materialize_only_last_token_logits + else context.padded_active_token_count + ) return context.calculate_log_probs( - logits, + self._all_logits_cuda[:, :logits_seq_len, :], self._sampled_tokens_cuda[:active_request_count], only_last_token_logits=context.config.materialize_only_last_token_logits, ) - def _dynamic_step_calculate_log_probs_speculative( - self, logits: Tensor - ) -> Tuple[List[List[float]], Tensor]: + def _dynamic_step_calculate_log_probs_speculative(self) -> Tuple[List[List[float]], Tensor]: """Calculate log probs from logits for speculative decoding. For decode requests, computes log probs for each accepted speculative token @@ -1311,9 +1165,6 @@ def _dynamic_step_calculate_log_probs_speculative( - log_prob(accepted_token[j]) comes from logits at position j - log_prob(newly_sampled_token) comes from logits at position accepted_count - Args: - logits (Tensor): The main model logits [1, seq_len, vocab_size]. - Returns: Tuple of (log_probs_list, log_probs_tensor): log_probs_list: List of lists, one per active request, containing @@ -1323,17 +1174,18 @@ def _dynamic_step_calculate_log_probs_speculative( context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ - context.paused_request_count : context.total_request_count - ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count + # Use gpu_view for data consumed by GPU log-probs operations. 
+ request_in_prefill_status_tensor = context.gpu_view.request_in_prefill_status[ + :active_request_count ] + request_query_lengths = context.gpu_view.request_query_lengths[:active_request_count] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests only_last = context.config.materialize_only_last_token_logits + # Use pre-allocated buffer for CUDA graph compatibility. + logits = self._all_logits_cuda logits_squeezed = logits.squeeze(0).float() if only_last: log_probs_tensor = F.log_softmax(logits_squeezed, dim=-1) @@ -1389,7 +1241,7 @@ def _dynamic_step_calculate_log_probs_speculative( ] log_probs_list_prefill = [[lp.item()] for lp in selected_log_probs] else: - prefill_token_ids = context.token_to_input_ids[ + prefill_token_ids = context.gpu_view.token_to_input_ids[ decode_len : context.active_token_count ].roll(-1, 0) prefill_query_lengths = request_query_lengths[request_in_prefill_status_tensor == 1] @@ -1431,14 +1283,12 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - active_request_slice = slice(context.paused_request_count, context.total_request_count) - request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ - context.paused_request_count : context.total_request_count - ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count + # Use gpu_view for data consumed by GPU top-n operations. + request_in_prefill_status_tensor = context.gpu_view.request_in_prefill_status[ + :active_request_count ] + request_query_lengths = context.gpu_view.request_query_lengths[:active_request_count] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests @@ -1451,7 +1301,7 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( num_decode_requests, self.num_speculative_tokens + 1, -1 ) accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] - top_n_per_request = self._request_metadata["top_n_logprobs"][active_request_slice][ + top_n_per_request = context.active_request_metadata["top_n_logprobs"][ :num_decode_requests ] max_top_n = int(top_n_per_request.max().item()) @@ -1480,8 +1330,8 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( prefill_log_probs = log_probs_tensor[decode_len:] # Batch metadata reads: single CPU transfer for all prefill requests. 
- prefill_top_n = self._request_metadata["top_n_logprobs"][active_request_slice][ - num_decode_requests: + prefill_top_n = context.active_request_metadata["top_n_logprobs"][ + num_decode_requests:active_request_count ].tolist() max_top_n_prefill = int(max(prefill_top_n)) if prefill_top_n else 0 @@ -1508,7 +1358,7 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( prefill_log_probs_per_request = prefill_log_probs.split( prefill_query_lengths.tolist(), dim=0 ) - prefill_skip_prompt = self._request_metadata["skip_prompt_log_probs"][ + prefill_skip_prompt = context.active_request_metadata["skip_prompt_log_probs"][ num_decode_requests:active_request_count ].tolist() @@ -1536,12 +1386,11 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( return top_n_results if top_n_results else None def _dynamic_step_calculate_top_n_logprobs( - self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None + self, log_probs_tensor: Optional[Tensor] = None ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: """Calculate top-n log probs from logits for dynamic batching. Args: - logits (Tensor): The logits to compute top-n log probs from. log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor. If provided, avoids recomputing log_softmax. Should be the tensor returned by calculate_log_probs. @@ -1567,9 +1416,7 @@ def _dynamic_step_calculate_top_n_logprobs( top_n_results = {} for req_idx in range(active_request_count): - top_n = int( - self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() - ) + top_n = int(context.active_request_metadata["top_n_logprobs"][req_idx].item()) if top_n > 0: # Get top-n logprobs and indices for this request (single token) top_n_logits = torch.topk(log_probs[req_idx], k=top_n) @@ -1591,14 +1438,14 @@ def _dynamic_step_calculate_top_n_logprobs( top_n_results = {} for req_idx in range(active_request_count): - top_n = int( - self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() - ) + top_n = int(context.active_request_metadata["top_n_logprobs"][req_idx].item()) if top_n > 0: request_log_probs = log_probs_per_request[ req_idx ] # [num_tokens_for_request, vocab_size] - skip_prompt = bool(self._request_metadata["skip_prompt_log_probs"][req_idx].item()) + skip_prompt = bool( + context.active_request_metadata["skip_prompt_log_probs"][req_idx].item() + ) # If skip_prompt_log_probs is True, only compute for last token if skip_prompt and request_log_probs.size(0) > 1: @@ -1619,42 +1466,23 @@ def _dynamic_step_calculate_top_n_logprobs( return top_n_results if top_n_results else None + @torch.inference_mode() def dummy_forward(self): """Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests. 
It may run in eager mode.""" context = self.inference_wrapped_model.inference_context - # if no cuda graphs, directly use dummy forward - if not context.cuda_graph_batch_dimensions_list: - self.inference_wrapped_model.dummy_forward() - - # Disable MoE padding for MTP computation - if self.model_config.moe_pad_experts_for_cuda_graph_inference: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - set_decode_expert_padding(unwrapped_model, False) - - self._dummy_serial_mtp_forward() - - return # attempt to use cuda-graph if possible input_ids, position_ids = self._dynamic_step_context_init(is_dummy_forward=True) + self._dynamic_step_forward_logits(input_ids, position_ids) - # _dynamic_step_context_init tries to find a cuda-graph that is compatible - # with all EP ranks. It can also return no match, in which case - # we run in eager mode. - - if context.using_cuda_graph_this_step(): - # we found a cuda-graph to run - self._dynamic_step_forward_logits(input_ids, position_ids) - else: - # fallback to eager dummy forward - self.inference_wrapped_model.dummy_forward() - - # Disable MoE padding for MTP computation + # Disable MoE padding for MTP computation, unless CUDA graphs + # are active (the graphs were captured with padding enabled). if self.model_config.moe_pad_experts_for_cuda_graph_inference: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - set_decode_expert_padding(unwrapped_model, False) + if not context.using_cuda_graph_this_step(): + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) # When speculative decoding is active, the real EP ranks perform serial # MTP forward passes after the main forward pass. MTP layers may contain @@ -1686,10 +1514,11 @@ def _dummy_serial_mtp_forward(self): if self.model_config.expert_model_parallel_size <= 1: return - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + unwrapped_model = self._unwrapped_model - is_last_stage = is_pipeline_last_stage(self.pp_group) - has_mtp = is_last_stage and hasattr(unwrapped_model, '_decoder_hidden_states_cache') + has_mtp = self._is_last_pp_stage and hasattr( + unwrapped_model, '_decoder_hidden_states_cache' + ) if not has_mtp and not self.model_is_pipeline_parallel: # No MTP on this rank and no PP broadcast to participate in. return @@ -1697,31 +1526,46 @@ def _dummy_serial_mtp_forward(self): device = torch.cuda.current_device() dtype = self.model_config.params_dtype hidden_size = self.model_config.hidden_size - num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) - # Pad token_ids/position_ids to nearest multiple of tp_size so that the - # embedding can reduce-scatter evenly across TP ranks. - tp_size = get_pg_size(self.inference_wrapped_model.tp_group) - sp_enabled = self.model_config.sequence_parallel and tp_size > 1 - padded_count = tp_size if sp_enabled else 1 + # Use precomputed MTP CUDA graph batch size when available; + # otherwise use minimal SP-compatible size. + if getattr(self, '_mtp_resolved_padded_count', None) is not None: + padded_count = self._mtp_resolved_padded_count + assert not self._sp_enabled or padded_count % self._tp_size == 0 + elif has_mtp: + # Eager path: use TP-aligned minimum size for dummy tensors. 
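+            # (Hypothetical numbers: with tp_size=4 and sequence parallelism enabled,
+            # padded_count=4, so the dummy token/position IDs split evenly when the
+            # embedding reduce-scatters across TP ranks; without SP a single dummy
+            # token suffices.)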
+ padded_count = self._tp_size if self._sp_enabled else 1 dummy_hidden = None if has_mtp: - # Minimal dummy tensors — just enough to drive the MTP layer forward + # Minimal dummy tensors to drive the MTP layer forward # so that the MoE all-to-all collectives are issued. - # Depth 0 uses full-format hidden; subsequent depths use SP format. - dummy_hidden = torch.zeros((1, 1, hidden_size), device=device, dtype=dtype) + dummy_hidden = torch.zeros((padded_count, 1, hidden_size), device=device, dtype=dtype) + if self._sp_enabled: + dummy_hidden = scatter_to_sequence_parallel_region( + dummy_hidden, group=self.inference_wrapped_model.tp_group + ) dummy_token_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long) dummy_position_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long) - for depth in range(num_depths): + context = self.inference_wrapped_model.inference_context + + for depth in range(self._num_mtp_depths): + nvtx_range_push(f"mtp-spec-decoding/dummy-depth-{depth}") mtp_logits_2d = None if has_mtp: + mtp_depth = None if unwrapped_model.mtp.mtp_use_repeated_layer else depth dummy_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( hidden_states=dummy_hidden, next_token_ids=dummy_token_ids, position_ids=dummy_position_ids, - depth=depth, + depth=mtp_depth, + eager=not context.using_cuda_graph_this_step(), + cache_key=( + ("mtp", padded_count, mtp_depth) + if context.using_cuda_graph_this_step() + else None + ), ) mtp_logits_2d = mtp_logits.squeeze(1) # [padded_count, vocab_size] @@ -1733,6 +1577,25 @@ def _dummy_serial_mtp_forward(self): tensor=mtp_logits_2d, pp_group=self.pp_group, ) + nvtx_range_pop(f"mtp-spec-decoding/dummy-depth-{depth}") + + def _transfer_samples_to_cpu(self, active_request_count: int) -> tuple: + """Batch GPU-to-CPU transfer of sampled tokens. + + Called at the boundary between GPU sampling and CPU bookkeeping. + After this returns, all sampled data is on CPU and the remainder + of the step is 100% CPU. + + Returns: + tuple: (sampled_tokens_cpu, sampled_mtp_tokens_cpu) where + sampled_mtp_tokens_cpu is None when speculative decoding is off. + """ + sampled_tokens_cpu = self._sampled_tokens_cuda[:active_request_count].cpu() + if self.num_speculative_tokens > 0: + sampled_mtp_tokens_cpu = self._sampled_mtp_tokens_cuda[:, :active_request_count].cpu() + else: + sampled_mtp_tokens_cpu = None + return sampled_tokens_cpu, sampled_mtp_tokens_cpu def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: """Update the dynamic inference context after sampling. @@ -1753,7 +1616,15 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) - # Active sequence lengths. + # Batch GPU-to-CPU transfer of all sampled tokens. + range_push("transfer_samples_to_cpu") + sampled_tokens_cpu, sampled_mtp_tokens_cpu = self._transfer_samples_to_cpu( + active_request_count + ) + range_pop() + + range_push("active_request_mask") + # Everything below is 100% CPU. active_request_ids = context.request_ids[active_request_slice].long() active_sequence_lengths = context.get_active_sequence_lengths() @@ -1765,13 +1636,15 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: max_sequence_lengths = context.get_max_sequence_lengths() # Request finished if termination_id or length >= max_sequence_length. 
- # Note: termination_id tensor has per-request termination IDs from mixed sampling + # Both operands are CPU: sampled_tokens_cpu was D2H'd above, and + # active_request_metadata is CPU-pinned. active_request_mask = ( - self._sampled_tokens_cuda[:active_request_count] - != self._request_metadata["termination_id"][active_request_slice] + sampled_tokens_cpu + != context.active_request_metadata["termination_id"][:active_request_count] ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() - # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) + # Mark requests as finished if they hit stop words + # (detected in previous step's post_process_requests) if self._get_stop_word_finished_ids_callback is not None: request_ids_list = active_request_ids.tolist() stop_word_finished_ids = self._get_stop_word_finished_ids_callback(request_ids_list) @@ -1785,23 +1658,37 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ) finished_request_ids = context.request_ids[finished_idxs] + # Save block IDs for finished requests before update_requests releases them. + # Needed for per-block routing reconstruction in the engine. + finished_routing_block_ids = {} + if context.kv_block_allocator.block_routing and finished_idxs.numel() > 0: + for fidx in finished_idxs.tolist(): + req_id = int(context.request_ids[fidx].item()) + blocks = context.request_to_kv_block_ids[fidx] + valid = blocks[blocks >= 0].tolist() + if valid: + finished_routing_block_ids[req_id] = valid + # Clone needed: update_requests mutates next_tokens in-place via tensor_swap, - # which would corrupt the reused _sampled_tokens_cuda buffer. - new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone() + # which would corrupt the reused buffer. + new_sample_copy = sampled_tokens_cpu.clone() + range_pop() - # Update requests. - # _sampled_mtp_tokens_cuda has shape [num_speculative_tokens, max_requests] - if self.num_speculative_tokens > 0: - sampled_mtp_tokens_cuda = self._sampled_mtp_tokens_cuda[:, :active_request_count] - else: - sampled_mtp_tokens_cuda = None + range_push("update_requests") update_result = context.update_requests( - active_request_mask, new_sample_copy, sampled_mtp_tokens_cuda + active_request_mask, new_sample_copy, sampled_mtp_tokens_cpu ) + range_pop() return { "active_request_ids": active_request_ids, "finished_request_ids": finished_request_ids, + # Already a CPU tensor (independent of _sampled_tokens_cuda via the + # .cpu() in _transfer_samples_to_cpu; update_requests only mutates + # the separate new_sample_copy). Returning the CPU copy avoids a + # D2H sync when the engine later calls sample.tolist(). + "sample": sampled_tokens_cpu, + "finished_routing_block_ids": finished_routing_block_ids, **(update_result or {}), } @@ -1845,7 +1732,8 @@ async def async_generate_output_tokens_dynamic_batch( # Forward pass produces only base logits. When speculative decoding is # active, MTP logits are computed serially after verification. - logits = self._dynamic_step_forward_logits(input_ids, position_ids) + range_push("forward_pass") + self._dynamic_step_forward_logits(input_ids, position_ids) # Commit Mamba intermediate states before update_requests, which # may swap request indices. 
The Python lists tracking EOS block IDs @@ -1854,8 +1742,11 @@ async def async_generate_output_tokens_dynamic_batch( if context.is_hybrid_model and context.mamba_slot_allocator is not None: context.mamba_slot_allocator.commit_intermediate_states() - # Collect routing indices per request (must be done before context transitions) - routing_indices_per_request = self._router_record_bookkeeping() + # Collect flat routing indices and scatter them into per-block storage. + # Must be done before update_requests while token-to-block mappings are valid. + # Reconstruction happens from blocks at request completion. + context.kv_block_allocator.store_routing_per_block(self._router_record_bookkeeping()) + range_pop() # This is the best place to yield control back to event loop. # At this point we have enqueued FW pass GPU kernels asynchronously. @@ -1867,52 +1758,68 @@ async def async_generate_output_tokens_dynamic_batch( await asyncio.sleep(0) with torch.inference_mode(): + range_push("sampling") return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() - self._dynamic_step_sample_bookkeeping() - if self.num_speculative_tokens > 0: # Phase 1: Verify speculative tokens using base logits only. - self._dynamic_step_sample_logits_and_verify_tokens(logits, input_ids) + nvtx_range_push("mtp-spec-decoding/verify") + self._dynamic_step_sample_logits_and_verify_tokens(input_ids) + nvtx_range_pop("mtp-spec-decoding/verify") # Phase 2: Rewind KV cache for rejected tokens. - self._rewind_kv_cache() + nvtx_range_push("mtp-spec-decoding/rewind-kv-cache") + blocks_to_release, remove_mask = self._rewind_kv_cache() + nvtx_range_pop("mtp-spec-decoding/rewind-kv-cache") - # Disable MoE padding for MTP computation + # Disable MoE padding for MTP computation, unless CUDA graphs + # are active (the graphs were captured with padding enabled). if self.model_config.moe_pad_experts_for_cuda_graph_inference: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - set_decode_expert_padding(unwrapped_model, False) + if not context.using_cuda_graph_this_step(): + set_decode_expert_padding(self._unwrapped_model, False) # Phase 3: Compute MTP serially with correct (verified) inputs. + nvtx_range_push("mtp-spec-decoding/serial-mtp") self._compute_serial_mtp_and_sample() + nvtx_range_pop("mtp-spec-decoding/serial-mtp") + + # Phase 4: Release freed blocks. Deferred from Phase 2 so the + # data-dependent boolean-mask sync overlaps with MTP GPU work. + context.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) else: - self._dynamic_step_sample_logits(logits) + self._dynamic_step_sample_logits() log_probs = None top_n_logprobs = None if return_log_probs or return_top_n_logprobs: if self.num_speculative_tokens > 0: log_probs, log_probs_tensor = ( - self._dynamic_step_calculate_log_probs_speculative(logits) + self._dynamic_step_calculate_log_probs_speculative() ) if return_top_n_logprobs: top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs_speculative( log_probs_tensor ) else: - log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs() if return_top_n_logprobs: top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor + log_probs_tensor ) + range_pop() if skip_bookkeeping: - request_bookkeeping = {} + # _transfer_samples_to_cpu wasn't invoked on this path, so do + # a one-shot D2H here to keep "sample" as a CPU tensor for + # downstream consumers. 
+ request_bookkeeping = { + "sample": self._sampled_tokens_cuda[:active_request_count].cpu() + } else: + # request_bookkeeping supplies "sample" as the already-CPU + # tensor produced by _transfer_samples_to_cpu. request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { - # Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step. - "sample": self._sampled_tokens_cuda[:active_request_count].clone(), "accepted_tokens": ( # Clone needed: .fill_(-1) on line 1480 would corrupt the returned value. self._accepted_tokens_per_request.clone() @@ -1921,7 +1828,6 @@ async def async_generate_output_tokens_dynamic_batch( ), "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, - "routing_indices_per_request": routing_indices_per_request, "cuda_graph_request_count": cuda_graph_request_count, } if self.num_speculative_tokens > 0: diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py index d2279b0d07d..6f57a863c1c 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py @@ -3,6 +3,7 @@ import asyncio import logging import time +import uuid from megatron.core.inference.inference_request import unwrap_serialized_tensors from megatron.core.inference.sampling_params import SamplingParams @@ -92,6 +93,8 @@ async def completions(): if isinstance(stop, str): stop = [stop] + ignore_eos = bool(req.get("ignore_eos", False)) + sampling_params = SamplingParams( temperature=temperature, top_k=top_k, @@ -101,6 +104,7 @@ async def completions(): skip_prompt_log_probs=skip_prompt_log_probs, num_tokens_to_generate=int(req.get("max_tokens", 16)), stop_words=stop, + termination_id=-1 if ignore_eos else None, ) except ValueError as e: return f"Invalid sampling parameter: {e}", 400 @@ -117,6 +121,7 @@ async def completions(): skip_prompt_log_probs=sampling_params.skip_prompt_log_probs, num_tokens_to_generate=sampling_params.num_tokens_to_generate, stop_words=sampling_params.stop_words, + termination_id=sampling_params.termination_id, ) tasks.append(client.add_request(prompt_tokens, per_req_params)) @@ -160,6 +165,8 @@ async def completions(): # --- 5. 
Format Response (matching old_completions.py) --- choices = [] + total_completion_tokens = 0 + prompt_tokens_counts = [] request_idx = 0 for completed_request in batch_results: @@ -167,6 +174,17 @@ async def completions(): full_text = result["generated_text"] or "" text_output = (prompts_as_strings[request_idx] + full_text) if echo else full_text + generated_tokens = result.get("generated_tokens") or [] + prompt_tokens_list = result.get("prompt_tokens") or [] + total_completion_tokens += len(generated_tokens) + prompt_tokens_counts.append(len(prompt_tokens_list)) + + finish_reason = "length" + sampling_params_result = result.get("sampling_params") or {} + num_tokens_requested = sampling_params_result.get("num_tokens_to_generate") + if num_tokens_requested is None or len(generated_tokens) < num_tokens_requested: + finish_reason = "stop" + logprobs_data = None if sampling_params.return_log_probs: # Get prompt tokens and logprobs @@ -230,20 +248,49 @@ async def completions(): "top_logprobs": top_logprobs, } - choices.append({"index": request_idx, "text": text_output, "logprobs": logprobs_data}) + choice_data = { + "index": request_idx, + "text": text_output, + "logprobs": logprobs_data, + "finish_reason": finish_reason, + "prompt_token_ids": result["prompt_tokens"], + "generation_token_ids": result["generated_tokens"], + "generation_log_probs": result.get("generated_log_probs", []), + } + choice_data["policy_epoch"] = result["policy_epoch"] + choice_data["kv_cache_epoch"] = result["kv_cache_epoch"] + choice_data["num_evictions"] = sum( + 1 for e in result["events"] if e.get("type") == "EVICT" + ) + if result["routing_indices"] is not None: - choices[-1]["moe_topk_indices"] = result["routing_indices"] + choice_data["moe_topk_indices"] = result["routing_indices"] prompt_length = ( len(result["prompt_tokens"]) if result["prompt_tokens"] is not None else 0 ) if prompt_length: - choices[-1]["prompt_moe_topk_indices"] = result["routing_indices"][ + choice_data["prompt_moe_topk_indices"] = result["routing_indices"][ :prompt_length ] + choices.append(choice_data) request_idx += 1 - return jsonify({"choices": choices}) + prompt_token_count = max(prompt_tokens_counts) if prompt_tokens_counts else 0 + return jsonify( + { + "id": str(uuid.uuid4()), + "object": "text_completion", # as per the openAI spec + "created": int(time.time()), + "model": "EMPTY", + "choices": choices, + "usage": { + "prompt_tokens": prompt_token_count, + "completion_tokens": total_completion_tokens, + "total_tokens": prompt_token_count + total_completion_tokens, + }, + } + ) except ImportError as e: logger.warning(f"Could not import quart: {e}") diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 0914b81f005..11973551aa2 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -73,6 +73,7 @@ def get_attention_mask(seq_length: int) -> torch.Tensor: # Initialize cache for sequence parallel modules moe_layer_cache = None +_moe_metadata_sync_initialized = False def _init_moe_expert_cache(model): @@ -100,6 +101,25 @@ def walk(module): walk(model) +def set_moe_metadata_sync(model) -> None: + """Set _runs_metadata_sync on inference dispatchers. + + Exactly one dispatcher per model — the first MoE layer — fires update_metadata + each step. All subsequent layers skip it to avoid redundant collective calls. + Must be called once after the model is built and put into eval mode. 
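+
+    Minimal usage sketch (hypothetical call site):
+
+        model.eval()
+        set_moe_metadata_sync(model)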
+ """ + global moe_layer_cache, _moe_metadata_sync_initialized + if _moe_metadata_sync_initialized: + return + if moe_layer_cache is None: + _init_moe_expert_cache(model) + for i, moe_layer in enumerate(moe_layer_cache): + dispatcher = getattr(moe_layer, '_inference_token_dispatcher', None) + if dispatcher is not None: + dispatcher._runs_metadata_sync = i == 0 + _moe_metadata_sync_initialized = True + + def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int = None): """ Toggle MoE drop-and-pad for decode. @@ -201,34 +221,6 @@ def check_flashinfer_jit_cache_installed(log_version: bool = False): ) -def set_inference_cuda_graphed_iteration_for_ep_inference(model): - """Enable CUDA graph compatibility for expert parallel inference. - - Sets a flag in all MoELayers indicating the current iteration is being - captured/executed in a CUDA graph. This allows the dispatcher to adjust - its behavior for CUDA graph compatibility. - """ - global moe_layer_cache - if moe_layer_cache is None: - _init_moe_expert_cache(model) - - for moe_layer in moe_layer_cache: - moe_layer.set_inference_cuda_graphed_iteration() - - -def unset_inference_cuda_graphed_iteration_for_ep_inference(model): - """Disable CUDA graph compatibility for expert parallel inference. - - Clears the flag in all MoELayers, restoring standard dispatcher behavior. - """ - global moe_layer_cache - if moe_layer_cache is None: - _init_moe_expert_cache(model) - - for moe_layer in moe_layer_cache: - moe_layer.unset_inference_cuda_graphed_iteration() - - def tensor_swap(x, src_idxs, dst_idxs): """ Swap x[src_idxs] and x[dst_idxs] diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index 10261686eef..be7cb29ddb6 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -8,6 +8,7 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.transformer.cuda_graphs import CudaGraphManager try: from megatron.core.extensions.transformer_engine import te_parallel_cross_entropy @@ -63,6 +64,20 @@ def __init__( self.vp_stage = None self.vp_size = self.config.virtual_pipeline_model_parallel_size + def _setup_mtp_cuda_graphs(self): + """Wrap `compute_mtp_single_step` with a CudaGraphManager. + + Must be called by subclasses after `self.mtp` is created. + """ + if self.config.cuda_graph_impl == "local": + self._mtp_cudagraph_manager = CudaGraphManager( + self.config, + base_module=self, + function_name="compute_mtp_single_step", + need_backward=False, + inline_capture=True, + ) + def _is_in_embd_group(self): if self.embd_group is None: return False @@ -325,7 +340,13 @@ def shared_embedding_or_output_weight(self) -> Tensor: @torch.inference_mode() def compute_mtp_single_step( - self, hidden_states: Tensor, next_token_ids: Tensor, position_ids: Tensor, depth: int + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + depth: Optional[int] = None, + eager: bool = False, + cache_key=None, ) -> tuple: """Compute a single MTP depth for speculative decoding. @@ -336,13 +357,20 @@ def compute_mtp_single_step( hidden_states (Tensor): Hidden states at last accepted positions. next_token_ids (Tensor): Correct next token IDs [1, N]. position_ids (Tensor): Position IDs for the next tokens [1, N]. - depth (int): MTP depth index (0-indexed). 
+            depth (int, optional): MTP depth index. Only needed when `mtp_use_repeated_layer` is + False (each depth uses a distinct layer). Omit for repeated-layer models so that a + single CUDA graph can serve all depths. + eager, cache_key: The `CudaGraphManager` works by monkey-patching these arguments onto the + function signature. Explicitly including them removes the need for a monkey-patch, + and makes it straightforward to call the same method with and without eager mode. + These arguments are consumed by `CudaGraphManager`, if it exists. Returns: tuple: (new_hidden_states, logits [N, 1, vocab_size]). """ - layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth - + # CudaGraphManager consumes these args, if it exists + del eager, cache_key + layer_idx = 0 if depth is None else depth mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( hidden_states=hidden_states, next_token_ids=next_token_ids, diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 92d561c3412..3d44b4c783e 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -797,15 +797,22 @@ def get_gpt_mtp_block_spec_for_backend( mtp_model_layer_spec=transformer_layer_spec, backend=backend ) mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0 - mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers + if config.mtp_use_repeated_layer: + mtp_layer_specs = [mtp_layer_spec] + else: + mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers + + if not config.mtp_use_repeated_layer: + offset = get_mtp_layer_offset(config, vp_stage=vp_stage) + # Split the MTP layer specs to only include the layers that are built in this + # pipeline stage. + mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build] + if len(mtp_layer_specs) > 0: + assert ( + len(mtp_layer_specs) == config.mtp_num_layers + ), "All MTP layers must reside in the same pipeline stage" - offset = get_mtp_layer_offset(config, vp_stage=vp_stage) - # split the mtp layer specs to only include the layers that are built in this pipeline stage. - mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build] if len(mtp_layer_specs) > 0: - assert ( - len(mtp_layer_specs) == config.mtp_num_layers - ), f"currently all of the mtp layers must stage in the same pipeline stage." mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs) else: mtp_block_spec = None diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index a1a9e62b9c1..ace7310accf 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -230,6 +230,8 @@ def __init__( pg_collection=self.pg_collection, ) + self._setup_mtp_cuda_graphs() + # Output if self.post_process: @@ -506,7 +508,6 @@ def forward( inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[Tensor] = None, padding_mask: Optional[Tensor] = None, - is_spec_decode: Optional[bool] = None, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoder and finally into the post @@ -520,9 +521,6 @@ def forward( padding_mask (Tensor, optional): Padding mask for MoE routing. Shape [bsz, seq_length]. True = padding (exclude), False = valid (include). Only used for MoE layers to exclude padding tokens from routing computations. - is_spec_decode (bool, optional): Explicitly override whether speculative
When ``None`` (default) the flag is inferred from - ``inference_context.num_speculative_tokens``. """ if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() @@ -590,7 +588,6 @@ def forward( runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, - is_spec_decode=is_spec_decode, ) def _postprocess( @@ -612,7 +609,6 @@ def _postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, - is_spec_decode=None, ): """Postprocesses decoder hidden states to generate logits or compute loss. @@ -626,12 +622,11 @@ def _postprocess( # Check if speculative decoding is active. When it is, MTP must be # computed *after* verification so that it is conditioned on verified # tokens rather than stale speculative tokens from the previous step. - if is_spec_decode is None: - is_spec_decode = ( - in_inference_mode - and inference_context.is_dynamic_batching() - and inference_context.num_speculative_tokens > 0 - ) + is_spec_decode = ( + in_inference_mode + and inference_context.is_dynamic_batching() + and inference_context.num_speculative_tokens > 0 + ) # logits and loss output_weight = None diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 5494d531e52..6d20bcdd6e5 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -25,7 +25,7 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule +from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import sharded_state_dict_default @@ -47,7 +47,7 @@ class HybridStackSubmodules: mtp_block_spec: Optional[ModuleSpec] = None -class HybridStack(GraphableMegatronModule, MegatronModule): +class HybridStack(MegatronModule): """ Constructor for the HybridStack class. @@ -206,39 +206,6 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int return layer.mamba_state_shapes_per_request() return None - def _should_call_local_cudagraph(self, *args, **kwargs): - """ - Check if we should call the local cudagraph path. 
- """ - if ( - not self.training - and hasattr(self, 'cudagraph_manager') - and kwargs['attention_mask'] is None - and ( - kwargs.get('inference_context') is not None - or kwargs.get('inference_params') is not None - ) - and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope - ): - if kwargs['inference_context'].is_static_batching(): - using_cuda_graph = kwargs['inference_context'].is_decode_only() - else: - using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step() - - if using_cuda_graph: - return True - return False - - def __call__(self, *args, **kwargs): - if self._should_call_local_cudagraph(*args, **kwargs): - kwargs['hidden_states'] = ( - kwargs['hidden_states'].unwrap() - if isinstance(kwargs['hidden_states'], WrappedTensor) - else kwargs['hidden_states'] - ) - return super().__call__(*args, **kwargs)[0] - return super().__call__(*args, **kwargs) - def forward( self, hidden_states: Union[Tensor, WrappedTensor], diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index a9a6d1e106b..d82cfa8f568 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -10,6 +10,7 @@ from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( @@ -19,7 +20,8 @@ from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.enums import CudaGraphScope, ModelType +from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, mtp_on_this_rank, @@ -36,7 +38,7 @@ logger = logging.getLogger(__name__) -class HybridModel(LanguageModule): +class HybridModel(LanguageModule, GraphableMegatronModule): """Hybrid language model. Args: @@ -71,7 +73,7 @@ class HybridModel(LanguageModule): parallel ranks. Defaults to True. share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (Literal[learned_absolute,rope,none], optional): Position + position_embedding_type (Literal[learned_absolute,rope,yarn,none], optional): Position embedding type. Defaults to 'none'. rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. 
@@ -100,7 +102,7 @@ def __init__( parallel_output: bool = True, share_embeddings_and_output_weights: bool = False, # Mamba with no attention has no need for position embeddings, so none is default - position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none', + position_embedding_type: Literal['learned_absolute', 'rope', 'yarn', 'none'] = 'none', rotary_percent: float = 1.0, rotary_base: int = 10000, scatter_embedding_sequence_parallel: bool = True, @@ -228,7 +230,26 @@ def __init__( use_cpu_initialization=self.config.use_cpu_initialization, cp_group=self.pg_collection.cp, ) - + elif self.position_embedding_type == 'yarn': + self.rotary_pos_emb = YarnRotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"), + original_max_position_embeddings=getattr( + self.config, "yarn_original_max_position_embeddings" + ), + beta_fast=getattr(self.config, "yarn_beta_fast"), + beta_slow=getattr(self.config, "yarn_beta_slow"), + mscale=getattr(self.config, "yarn_mscale"), + mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"), + correction_range_round_to_int=getattr( + self.config, "yarn_correction_range_round_to_int" + ), + use_cpu_initialization=self.config.use_cpu_initialization, + cp_group=self.pg_collection.cp, + ) self.decoder = build_module( hybrid_stack_spec, self.config, @@ -258,6 +279,7 @@ def __init__( mtp_num_depths=self.mtp_num_depths, hybrid_submodules=hybrid_submodules, ) + self._setup_mtp_cuda_graphs() # Output if post_process or self.mtp_process: @@ -324,6 +346,42 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offload(param) self.disable_param_offloading = False + def _should_call_local_cudagraph(self, *args, **kwargs): + """ + Check if we should call the local cudagraph path. + """ + if ( + not self.training + and hasattr(self, 'cudagraph_manager') + and ( + kwargs.get('inference_context') is not None + or kwargs.get('inference_params') is not None + ) + and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope + ): + if kwargs['inference_context'].is_static_batching(): + using_cuda_graph = kwargs['inference_context'].is_decode_only() + else: + using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step() + + if using_cuda_graph: + return True + return False + + def __call__(self, *args, **kwargs): + if self._should_call_local_cudagraph(*args, **kwargs): + return super().__call__(*args, **kwargs)[0] + return super().__call__(*args, **kwargs) + + def create_mcore_cudagraph_manager(self, config): + """ + Create the cudagraph manager for the full iteration inference scope + """ + if CudaGraphScope.full_iteration_inference in config.cuda_graph_scope: + from megatron.core.transformer.cuda_graphs import CudaGraphManager + + self.cudagraph_manager = CudaGraphManager(config) + def forward( self, input_ids: Tensor, @@ -338,7 +396,6 @@ def forward( loss_mask: Optional[Tensor] = None, packed_seq_params: Optional[PackedSeqParams] = None, padding_mask: Optional[Tensor] = None, - is_spec_decode: Optional[bool] = None, ) -> Tensor: """Forward function of the Hybrid model. 
This function passes the input tensors through the embedding layer, and then the decoder and finally into the post @@ -387,6 +444,15 @@ def forward( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + elif self.position_embedding_type == 'yarn': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, self.decoder, decoder_input, self.config, packed_seq_params + ) + # YarnRotaryEmbedding.forward returns (emb, mscale); discard mscale here + rotary_pos_emb, _ = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + ) # Wrap decoder_input to allow the decoder (HybridStack) to delete the # reference held by this caller function, enabling early garbage collection @@ -421,12 +487,11 @@ def forward( # Check if speculative decoding is active. When it is, MTP must be # computed *after* verification so that it is conditioned on verified # tokens rather than stale speculative tokens from the previous step. - if is_spec_decode is None: - is_spec_decode = ( - in_inference_mode - and inference_context.is_dynamic_batching() - and inference_context.num_speculative_tokens > 0 - ) + is_spec_decode = ( + in_inference_mode + and inference_context.is_dynamic_batching() + and inference_context.num_speculative_tokens > 0 + ) mtp_forward_ran = self.mtp_process and not (in_inference_mode or is_spec_decode) if mtp_forward_ran: diff --git a/megatron/core/models/mimo/comm/__init__.py b/megatron/core/models/mimo/comm/__init__.py new file mode 100644 index 00000000000..26496bfed70 --- /dev/null +++ b/megatron/core/models/mimo/comm/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py new file mode 100644 index 00000000000..4c43dcdf3cd --- /dev/null +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -0,0 +1,325 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid + + +@dataclass +class SliceInfo: + """Batch dimension slice information for a rank's data partition.""" + + start: int + size: int + + +class BridgeDirection(str, Enum): + """Which side of the bridge scales up, if any. + + ``FAN_IN`` — src has more DP replicas than dest; forward all-gathers + src outputs along the batch dim, backward narrows the sibling dest + gradient down to this src rank's slot. + + ``FAN_OUT`` — dest has more DP replicas; forward narrows, backward + all-gathers across the sibling dest DP ranks (the adjoint of narrow + is not zero-pad-and-scatter because every dest rank consumes a + different slice of the same src activation). + + ``EQUAL`` — matching DP; the bridge is a pure passthrough. + """ + + FAN_IN = "fan_in" + FAN_OUT = "fan_out" + EQUAL = "equal" + + +class ColocatedBridgeCommunicator: + """Bridges tensors between colocated modules with different TP/DP layouts. + + Default ``dim_mapping`` assumes 3D ``(b, s, h)``. Callers bridging + ``MimoModel``'s pre-flattened ``(s*b, h)`` encoder output should pass + ``dim_mapping={'b': 0, 'h': 1}``; this relies on a uniform token count per + sample so dim 0 divides evenly by the DP scale. 
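+
+    Minimal usage sketch (hypothetical grids and tensor; assumes both grids cover the
+    same ranks and expose 'tp' and 'dp' dims):
+
+        bridge = ColocatedBridgeCommunicator(src_grid, dest_grid,
+                                             dim_mapping={'b': 0, 'h': 1})
+        bridged = bridge.communicate(encoder_output)  # src layout -> dest layout
+        bridge.destroy()  # release the NCCL subgroup when finished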
+ + Precondition: the input must be TP-replicated across the src TP group — + i.e. all TP ranks inside a src DP replica hold the same tensor on the + batch dim. The bridge never gathers along TP; violating this silently + produces wrong results. + """ + + def __init__( + self, + src_grid: HyperCommGrid, + dest_grid: HyperCommGrid, + src_module_name: str = "src", + dest_module_name: str = "dest", + dim_mapping: Optional[Dict[str, int]] = None, + ): + self.src_grid = src_grid + self.dest_grid = dest_grid + self.src_module_name = src_module_name + self.dest_module_name = dest_module_name + self.dim_mapping = dim_mapping or {'b': 0, 's': 1, 'h': 2} + self.current_rank = dist.get_rank() + + self._validate_grids() + self._extract_parallelism_info() + self._build_rank_mappings() + + # At most one direction is active; fan-in and fan-out are mutually + # exclusive (one of ``src_dp / dest_dp`` is >1, the other is 1). + # Equal DP uses no collective at all. Unify behind a single + # ``gather_pg`` + ``direction`` + ``scale`` rather than a fan-in + # and fan-out pair of attributes. + self.gather_pg: Optional[dist.ProcessGroup] = None + self.gather_group_ranks: List[List[int]] = [] + + if self.src_dp_size > self.dest_dp_size: + self.direction = BridgeDirection.FAN_IN + self.scale = self.src_dp_size // self.dest_dp_size + self.gather_group_ranks = self._build_gather_groups( + iter_size=self.dest_dp_size, + sibling_tp_size=self.src_tp_size, + scale=self.scale, + rank_to_pos=self.rank_to_src_pos, + ) + self.gather_pg, _ = dist.new_subgroups_by_enumeration( + self.gather_group_ranks, backend='nccl' + ) + elif self.dest_dp_size > self.src_dp_size: + self.direction = BridgeDirection.FAN_OUT + self.scale = self.dest_dp_size // self.src_dp_size + self.gather_group_ranks = self._build_gather_groups( + iter_size=self.src_dp_size, + sibling_tp_size=self.dest_tp_size, + scale=self.scale, + rank_to_pos=self.rank_to_dest_pos, + ) + self.gather_pg, _ = dist.new_subgroups_by_enumeration( + self.gather_group_ranks, backend='nccl' + ) + else: + self.direction = BridgeDirection.EQUAL + self.scale = 1 + + logging.info( + f"[Rank {self.current_rank}] ColocatedBridgeCommunicator: " + f"{src_module_name}({self.src_tp_size}TP/{self.src_dp_size}DP) -> " + f"{dest_module_name}({self.dest_tp_size}TP/{self.dest_dp_size}DP), " + f"direction={self.direction.value}, scale={self.scale}" + ) + + def _validate_grids(self): + if self.src_grid.size != self.dest_grid.size: + raise ValueError( + f"Grids must span same number of ranks: " + f"src={self.src_grid.size}, dest={self.dest_grid.size}" + ) + + if self.src_grid.rank_offset != self.dest_grid.rank_offset: + raise ValueError( + f"Grids must have same rank offset: " + f"src={self.src_grid.rank_offset}, dest={self.dest_grid.rank_offset}" + ) + + # Per-grid dim checks: tp/dp required; pp and cp (if present) must be 1. + # CP>1 also corrupts dp_idx when iterating get_rank_enum(['tp']) groups. 
+ for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: + for required in ('tp', 'dp'): + if required not in grid.dim_names: + raise ValueError( + f"{name} grid must have '{required}' dimension, " + f"got dim_names={grid.dim_names}" + ) + for singleton in ('pp', 'cp'): + if singleton in grid.dim_names: + size = grid.shape[grid.dim_names.index(singleton)] + if size != 1: + raise ValueError( + f"{name} {singleton.upper()} must be 1 for " + f"ColocatedBridgeCommunicator, got {size}" + ) + + src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')] + dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] + if src_dp % dest_dp != 0 and dest_dp % src_dp != 0: + raise ValueError( + f"DP sizes must be evenly divisible: src_dp={src_dp}, dest_dp={dest_dp}" + ) + + def _extract_parallelism_info(self): + self.src_tp_size = self.src_grid.shape[self.src_grid.dim_names.index('tp')] + self.src_dp_size = self.src_grid.shape[self.src_grid.dim_names.index('dp')] + self.dest_tp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('tp')] + self.dest_dp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] + + def _build_rank_mappings(self): + self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {} + self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {} + + src_tp_groups = self.src_grid.get_rank_enum(['tp']) + for dp_idx, tp_group in enumerate(src_tp_groups): + for tp_idx, rank in enumerate(tp_group): + self.rank_to_src_pos[rank] = (dp_idx, tp_idx) + + dest_tp_groups = self.dest_grid.get_rank_enum(['tp']) + for dp_idx, tp_group in enumerate(dest_tp_groups): + for tp_idx, rank in enumerate(tp_group): + self.rank_to_dest_pos[rank] = (dp_idx, tp_idx) + + @staticmethod + def _build_gather_groups( + iter_size: int, sibling_tp_size: int, scale: int, rank_to_pos: Dict[int, Tuple[int, int]] + ) -> List[List[int]]: + """Build ``iter_size * sibling_tp_size`` gather groups of ``scale`` ranks. + + For each slot on the "iterating" side and each TP shard on the + sibling side, collect the ``scale`` sibling ranks whose DP indices + map into that slot. Append order equals group-local-rank order, + which ``all_gather_into_tensor`` uses to concatenate outputs — do + not sort. + """ + groups: List[List[int]] = [] + for iter_idx in range(iter_size): + sibling_dp_indices = range(iter_idx * scale, (iter_idx + 1) * scale) + for sibling_tp_idx in range(sibling_tp_size): + group_ranks = [] + for sibling_dp_idx in sibling_dp_indices: + for rank, (dp, tp) in rank_to_pos.items(): + if dp == sibling_dp_idx and tp == sibling_tp_idx: + group_ranks.append(rank) + break + groups.append(group_ranks) + return groups + + def is_fan_in(self) -> bool: + """True if src DP > dest DP (forward all-gathers).""" + return self.direction is BridgeDirection.FAN_IN + + def is_fan_out(self) -> bool: + """True if src DP < dest DP (forward narrows).""" + return self.direction is BridgeDirection.FAN_OUT + + def get_slice_info(self, batch_size: int) -> SliceInfo: + """Compute this rank's slice of ``batch_size`` on the narrowing side. + + For FAN_OUT this is the forward narrow; for FAN_IN it is the + backward narrow against the post-gather batch. EQUAL returns the + identity slice. + + Raises ``ValueError`` if ``batch_size`` is not divisible by ``scale``. 
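+
+        Worked example (illustrative numbers): with ``scale=2`` and
+        ``batch_size=8``, the rank whose DP index lands in slot 0 receives
+        ``SliceInfo(start=0, size=4)`` and the slot-1 rank receives
+        ``SliceInfo(start=4, size=4)``.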
+ """ + if self.direction is BridgeDirection.EQUAL: + return SliceInfo(start=0, size=batch_size) + self._check_divisible(batch_size) + if self.direction is BridgeDirection.FAN_OUT: + dp_idx = self.rank_to_dest_pos[self.current_rank][0] + else: # FAN_IN + dp_idx = self.rank_to_src_pos[self.current_rank][0] + slot = dp_idx % self.scale + slice_size = batch_size // self.scale + return SliceInfo(start=slot * slice_size, size=slice_size) + + def _check_divisible(self, batch_size: int) -> None: + if batch_size % self.scale != 0: + raise ValueError( + f"ColocatedBridgeCommunicator: batch dim size {batch_size} is " + f"not divisible by {self.direction.value} scale={self.scale}." + ) + + def communicate(self, tensor: torch.Tensor) -> torch.Tensor: + """Transform ``tensor`` from src TP/DP layout to dest TP/DP layout. + + Raises ``ValueError`` when FAN_OUT and the batch dim is not + divisible by ``scale``; FAN_IN only slices on the backward pass + and re-checks via ``get_slice_info`` there. + """ + if self.direction is BridgeDirection.FAN_OUT: + self._check_divisible(tensor.shape[self.dim_mapping['b']]) + return _ColocatedCommunicate.apply(tensor, self) + + def destroy(self) -> None: + """Release the NCCL subgroup created by this communicator. + + NCCL caps concurrent communicators; long-lived or repeated + construction leaks PGs without this call. + """ + if self.gather_pg is not None: + dist.destroy_process_group(self.gather_pg) + self.gather_pg = None + + +class _ColocatedCommunicate(torch.autograd.Function): + """Autograd function for colocated communication with correct backward pass.""" + + @staticmethod + def forward(ctx, tensor: torch.Tensor, comm: ColocatedBridgeCommunicator) -> torch.Tensor: + """Reshape the batch dim across the bridge: narrow on fan-out, all-gather on fan-in.""" + ctx.comm = comm + ctx.batch_dim = comm.dim_mapping['b'] + + if comm.direction is BridgeDirection.FAN_OUT: + # Narrow this rank's slice out of the full src batch. + slice_info = comm.get_slice_info(tensor.shape[ctx.batch_dim]) + return tensor.narrow(ctx.batch_dim, slice_info.start, slice_info.size).contiguous() + + if comm.direction is BridgeDirection.FAN_IN: + # All-gather sibling src outputs into a single full-batch tensor. + return _all_gather_along_batch_dim(tensor, comm.gather_pg, ctx.batch_dim) + + # EQUAL: pure passthrough. + return tensor.contiguous() + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Adjoint of forward: narrow for fan-in, all-gather for fan-out. + + Fan-out's forward is ``narrow``, whose naive adjoint is zero-pad. + That would leave each src rank with only its own dest rank's + slice of the gradient, missing the contributions from every + other dest rank that consumed a different slice of the same src + activation. Instead we all-gather across the fan-out sibling + group, reconstructing the full src-batch gradient (symmetric + with the fan-in forward's all-gather). 
+ """ + comm = ctx.comm + batch_dim = ctx.batch_dim + + if comm.direction is BridgeDirection.FAN_OUT: + return _all_gather_along_batch_dim(grad_output, comm.gather_pg, batch_dim), None + + if comm.direction is BridgeDirection.FAN_IN: + slice_info = comm.get_slice_info(grad_output.shape[batch_dim]) + return ( + grad_output.narrow(batch_dim, slice_info.start, slice_info.size).contiguous(), + None, + ) + + return grad_output.contiguous(), None + + +def _all_gather_along_batch_dim( + tensor: torch.Tensor, group: dist.ProcessGroup, batch_dim: int +) -> torch.Tensor: + """All-gather ``tensor`` along an arbitrary batch dim into a single tensor. + + ``all_gather_into_tensor`` concatenates along dim 0, so when the + batch dim is not 0 we move it, gather, then restore. + """ + world_size = dist.get_world_size(group) + src = tensor.contiguous() + if batch_dim != 0: + src = src.movedim(batch_dim, 0).contiguous() + out_shape = list(src.shape) + out_shape[0] *= world_size + out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) + dist.all_gather_into_tensor(out, src, group=group) + if batch_dim != 0: + out = out.movedim(0, batch_dim).contiguous() + return out diff --git a/megatron/core/models/mimo/config/base_configs.py b/megatron/core/models/mimo/config/base_configs.py index a92484a5a48..0eda09465e0 100644 --- a/megatron/core/models/mimo/config/base_configs.py +++ b/megatron/core/models/mimo/config/base_configs.py @@ -23,9 +23,11 @@ class MimoModelConfig: in the input_ids to insert the modality embeddings at the correct positions. module_to_grid_map (Optional[Dict[str, HyperCommGrid]]): Dictionary mapping module keys (e.g., "vision", "language") to their - corresponding HyperCommGrid configurations for non-colocated pipeline - parallelism. The language model must use the key MIMO_LANGUAGE_MODULE_KEY. - When None, all modules are assumed to be colocated on the same ranks. + corresponding HyperCommGrid configurations. The language model must use + the key MIMO_LANGUAGE_MODULE_KEY. + When grids span the same ranks → colocated (same or different TP/DP). + When grids span disjoint ranks → non-colocated (pipeline parallel). + When None → colocated with legacy global parallel_state. kv_format (str): Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim). Default is "sbhd". @@ -43,3 +45,18 @@ class MimoModelConfig: special_token_ids: Dict[str, int] = field(default_factory=dict) module_to_grid_map: Optional[Dict[str, HyperCommGrid]] = None kv_format: str = "sbhd" + + def __post_init__(self): + if not self.module_to_grid_map: + return + # Local import avoids circular imports at dataclass-module import time. + from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + + expected_keys = set(self.modality_submodules_spec.keys()) | {MIMO_LANGUAGE_MODULE_KEY} + grid_keys = set(self.module_to_grid_map.keys()) + if grid_keys != expected_keys: + raise ValueError( + f"module_to_grid_map keys must match modality module names + " + f"'{MIMO_LANGUAGE_MODULE_KEY}'. 
Missing: {expected_keys - grid_keys}, " + f"Extra: {grid_keys - expected_keys}" + ) diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 77c2512e8e6..411791f1e5c 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List +from typing import Dict, List, Optional import torch.distributed as dist @@ -24,22 +24,17 @@ class ModuleLayout(Enum): Determines how modules are distributed across ranks and which forward path is used. - UNIFIED: No module_to_grid_map. All modules share same ranks and - parallelism. Uses the unified forward path (_forward_all_modules). + COLOCATED: All modules share the same ranks. Covers both legacy + (no grid map, global parallel_state) and heterogeneous TP/DP + (grid map with overlapping ranks). Uses _forward_all_modules. NON_COLOCATED: module_to_grid_map is set with non-overlapping rank ranges. Each rank runs EITHER encoder(s) OR the language model. Uses role-based dispatch with separate forward paths. - - COLOCATED: (future) module_to_grid_map is set with overlapping rank - ranges. Encoder(s) and language model share ranks but have - different parallelism configs. Uses role-based dispatch but - allows both module types on the same rank. """ - UNIFIED = "unified" - NON_COLOCATED = "non_colocated" COLOCATED = "colocated" + NON_COLOCATED = "non_colocated" @dataclass @@ -70,50 +65,50 @@ class RankRole: """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) - mode: ModuleLayout = ModuleLayout.UNIFIED + mode: ModuleLayout = ModuleLayout.COLOCATED + + @classmethod + def build( + cls, + modality_module_names: List[str], + module_to_grid_map: Optional[Dict[str, 'HyperCommGrid']] = None, + ) -> 'RankRole': + """Build a RankRole, dispatching by whether grids share ranks. + + No grid map or all grids span the same ranks → COLOCATED. + Grids differ → NON_COLOCATED with PP-stage info per module. + """ + if module_to_grid_map is None or cls._all_grids_colocated(module_to_grid_map): + return cls._colocated(modality_module_names) + return cls._from_grid_map(module_to_grid_map) + + @staticmethod + def _all_grids_colocated(module_to_grid_map: Dict[str, 'HyperCommGrid']) -> bool: + grids = list(module_to_grid_map.values()) + first = grids[0] + return all(g.rank_offset == first.rank_offset and g.size == first.size for g in grids[1:]) @classmethod - def unified(cls, module_names: List[str]) -> 'RankRole': - """Create a role for the unified case: every module, first+last stage.""" + def _colocated(cls, modality_module_names: List[str]) -> 'RankRole': + """Colocated layout: every module on every rank, PP=1.""" + all_module_names = list(modality_module_names) + [MIMO_LANGUAGE_MODULE_KEY] return cls( modules={ name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) - for name in module_names + for name in all_module_names }, - mode=ModuleLayout.UNIFIED, + mode=ModuleLayout.COLOCATED, ) @classmethod - def from_grid_map( - cls, module_to_grid_map: Dict[str, HyperCommGrid], modality_module_names: List[str] - ) -> 'RankRole': - """Create a role from a module-to-grid mapping for non-colocated PP. - - Determines which modules the current rank participates in and its - pipeline stage position within each module. 
+ def _from_grid_map(cls, module_to_grid_map: Dict[str, HyperCommGrid]) -> 'RankRole': + """Non-colocated role for this rank from a module-to-grid mapping. - Args: - module_to_grid_map: Dict mapping module names to HyperCommGrid objects. - Must contain keys matching modality_module_names + MIMO_LANGUAGE_MODULE_KEY. - modality_module_names: List of modality module names (e.g., ["images", "audio"]). - - Returns: - RankRole for the current rank. + Grid map keys are validated by ``MimoModelConfig.__post_init__``. Raises: - ValueError: If grid map keys don't match expected module names. RuntimeError: If current rank is not in any module grid. """ - # Validate keys - expected_keys = set(modality_module_names) | {MIMO_LANGUAGE_MODULE_KEY} - grid_keys = set(module_to_grid_map.keys()) - if grid_keys != expected_keys: - raise ValueError( - f"module_to_grid_map keys must match modality module names + " - f"'{MIMO_LANGUAGE_MODULE_KEY}'. Missing: {expected_keys - grid_keys}, " - f"Extra: {grid_keys - expected_keys}" - ) - current_rank = dist.get_rank() modules = {} @@ -131,7 +126,7 @@ def from_grid_map( is_first = pp_rank == 0 is_last = pp_rank == pp_size - 1 logger.info( - f"[RankRole.from_grid_map] Rank {current_rank}: module={module_name}, " + f"[RankRole._from_grid_map] Rank {current_rank}: module={module_name}, " f"pp_rank={pp_rank}/{pp_size}, is_first_stage={is_first}, is_last_stage={is_last}" ) modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index b1c12f521c3..bdfe4289dd0 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -7,6 +7,7 @@ import torch from megatron.core.distributed import DistributedDataParallel +from megatron.core.models.mimo.comm.colocated_communicator import ColocatedBridgeCommunicator from megatron.core.models.mimo.config import MimoModelConfig from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig @@ -59,10 +60,12 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - self.mimo_config = mimo_config modality_names = list(mimo_config.modality_submodules_spec.keys()) - if mimo_config.module_to_grid_map: - self.role = RankRole.from_grid_map(mimo_config.module_to_grid_map, modality_names) - else: - self.role = RankRole.unified(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) + self.colocated_comms = {} + self.role = RankRole.build(modality_names, mimo_config.module_to_grid_map) + if self.role.mode is ModuleLayout.COLOCATED and mimo_config.module_to_grid_map: + # Per-encoder bridge needed iff modules share ranks but may differ + # in TP/DP within those ranks. 
+ self._build_colocated_communicators() # Use special token IDs from the config self.special_token_ids = ( @@ -358,7 +361,7 @@ def forward( # Get any tensors passed via set_input_tensor input_tensors = getattr(self, 'input_tensors', None) - if self.role.mode == ModuleLayout.UNIFIED: + if self.role.mode == ModuleLayout.COLOCATED: return self._forward_all_modules( input_ids, position_ids, @@ -491,6 +494,47 @@ def _forward_language_module( return lm_output + def _build_colocated_communicators(self): + grid_map = self.mimo_config.module_to_grid_map + if any( + 'tp' not in grid.dim_names or 'dp' not in grid.dim_names for grid in grid_map.values() + ): + logger.info( + "Skipping colocated communicator setup because module_to_grid_map " + "does not define TP/DP topology for every module." + ) + return + + lang_key = MIMO_LANGUAGE_MODULE_KEY + lang_grid = grid_map[lang_key] + for mod_name in self.mimo_config.modality_submodules_spec: + if mod_name == lang_key: + continue + self.colocated_comms[(mod_name, lang_key)] = ColocatedBridgeCommunicator( + src_grid=grid_map[mod_name], + dest_grid=lang_grid, + src_module_name=mod_name, + dest_module_name=lang_key, + dim_mapping={'b': 0, 'h': 1}, + ) + + def destroy(self) -> None: + """Release process groups owned by this MimoModel.""" + for comm in self.colocated_comms.values(): + comm.destroy() + self.colocated_comms.clear() + + def _apply_colocated_comms(self, modality_embeddings): + """Transform encoder embeddings from encoder TP/DP to LLM TP/DP layout.""" + lang_key = MIMO_LANGUAGE_MODULE_KEY + for modality_name in list(modality_embeddings.keys()): + comm = self.colocated_comms.get((modality_name, lang_key)) + if comm is not None: + modality_embeddings[modality_name] = comm.communicate( + modality_embeddings[modality_name] + ) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -533,6 +577,10 @@ def _forward_all_modules( f"Generated embeddings for {modality_name} with shape {embeddings.shape}" ) + # Apply colocated communication if configured (no-op when colocated_comms is empty) + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) logger.debug(f"Generated text embeddings with shape {text_embeddings.shape}") diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 7bd3aad7d1e..a9768d5a49b 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -47,7 +47,11 @@ ShardedTensorFactory, ) from ..dist_checkpointing.utils import extract_sharded_tensors_and_factories -from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets +from ..distributed.param_and_grad_buffer import ( + _ParamAndGradBuffer, + group_params_for_buffers, + partition_buckets, +) from ..fp4_utils import is_nvfp4tensor, quantize_nvfp4_param_shard from ..fp8_utils import dequantize_fp8_tensor, is_float8tensor, quantize_param_shard from ..transformer.fsdp_dtensor_checkpoint import handle_experts_in_state_dict @@ -56,6 +60,7 @@ from .grad_scaler import MegatronGradScaler from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys from .optimizer_config import OptimizerConfig +from .param_layout import FullParamLayout, PerBufferParamLayout, pad_bucket_end, pad_param_start logger = getLogger(__name__) @@ -474,6 +479,133 @@ 
def _build_model_and_main_param_groups( shard_fp32_from_float16_groups, ) + @staticmethod + def _compute_per_buffer_param_layout( + params: List[torch.nn.Parameter], + bucket_size: Optional[int], + data_parallel_world_size: int, + ddp_config, + param_indices: Optional[List[int]] = None, + ) -> 'PerBufferParamLayout': + """Compute how parameters should be laid out in the contiguous buffer. + + Iterates params in reverse order (backprop order), applies 64-byte param + alignment, bucket-end padding for DP divisibility, and shared-embedding + bucket splitting. + + Args: + params: List of parameters to lay out. + bucket_size: Approximate number of elements per bucket, or None for single bucket. + data_parallel_world_size: Size of the data-parallel group. + ddp_config: DistributedDataParallel config object. + param_indices: Optional indices for each param among same-dtype params. + + Returns: + PerBufferParamLayout with the computed mapping. + """ + + def _does_param_require_new_bucket(param): + return getattr(param, "shared_embedding", False) + + param_index_map = {} + bucket_indices = [] + per_bucket_numel_unpadded = [] + + param_start_index = 0 + bucket_start_index = 0 + bucket_params = set() + bucket_id = 0 + + def _finalize_bucket(param_end_index, bucket_start_index, bucket_id): + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + bucket_end_index = pad_bucket_end( + param_end_index, + data_parallel_world_size, + ddp_config.pad_buckets_for_high_nccl_busbw, + ) + bucket_indices.append((bucket_start_index, bucket_end_index)) + return bucket_end_index, bucket_id + 1 + + for param in params[::-1]: + param_start_index = pad_param_start(param_start_index) + + # Split shared embedding params into separate bucket. + if _does_param_require_new_bucket(param) and len(bucket_params) > 0: + bucket_start_index, bucket_id = _finalize_bucket( + param_start_index, bucket_start_index, bucket_id + ) + bucket_params = set() + param_start_index = bucket_start_index + + param_numel = param.data.nelement() + param_end_index = param_start_index + param_numel + param_index_map[param] = (param_start_index, param_end_index, bucket_id) + bucket_params.add(param) + + if ( + bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size + ) or _does_param_require_new_bucket(param): + bucket_start_index, bucket_id = _finalize_bucket( + param_end_index, bucket_start_index, bucket_id + ) + bucket_params = set() + param_start_index = bucket_start_index + else: + param_start_index = param_end_index + + if len(bucket_params) > 0: + _finalize_bucket(param_end_index, bucket_start_index, bucket_id) + + return PerBufferParamLayout( + param_index_map=param_index_map, + bucket_indices=bucket_indices, + per_bucket_numel_unpadded=per_bucket_numel_unpadded, + param_indices=param_indices if param_indices is not None else [], + ) + + @staticmethod + def compute_full_param_layout( + params: List[torch.nn.Parameter], + bucket_size: Optional[int], + data_parallel_world_size: int, + ddp_config, + expert_data_parallel_world_size: Optional[int] = None, + ) -> 'FullParamLayout': + """Compute parameter layouts for all buffer groups. + + Groups parameters by (param_dtype, grad_dtype, is_expert_parallel), then + computes a padded PerBufferParamLayout for each group. Expert-parallel groups use + expert_data_parallel_world_size for padding alignment. + + Args: + params: List of all parameters to lay out. + bucket_size: Approximate number of elements per bucket, or None for single bucket. 
+ data_parallel_world_size: Size of the data-parallel group for dense params. + ddp_config: DistributedDataParallel config object. + expert_data_parallel_world_size: Size of the expert data-parallel group. + Required if any expert-parallel params are present. Defaults to + data_parallel_world_size if not provided. + + Returns: + FullParamLayout with a PerBufferParamLayout per buffer group. + """ + buffer_groups = group_params_for_buffers(params, ddp_config.grad_reduce_in_fp32) + layouts = {} + for buffer_key, (group_params, param_indices) in buffer_groups.items(): + if buffer_key.is_expert_parallel: + dp_world_size = ( + expert_data_parallel_world_size + if expert_data_parallel_world_size is not None + else data_parallel_world_size + ) + else: + dp_world_size = data_parallel_world_size + layout = DistributedOptimizer._compute_per_buffer_param_layout( + group_params, bucket_size, dp_world_size, ddp_config, param_indices + ) + layouts[buffer_key] = layout + return FullParamLayout(layouts=layouts) + def __init__( self, optimizer: torch.optim.Optimizer, diff --git a/megatron/core/optimizer/param_layout.py b/megatron/core/optimizer/param_layout.py new file mode 100644 index 00000000000..6ebcc348f84 --- /dev/null +++ b/megatron/core/optimizer/param_layout.py @@ -0,0 +1,92 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Parameter layout dataclasses for optimizer-driven buffer layout. + +These dataclasses describe how parameters are laid out in contiguous buffers. +Each distributed optimizer implementation (e.g., DistributedOptimizer) is +responsible for computing these layouts via a _compute_per_buffer_param_layout method, +applying its own padding, alignment, and bucket splitting rules. DDP and +buffers consume the resulting layouts without any optimizer-specific knowledge. +""" + +import math +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import torch + + +def pad_to_divisor(value: int, divisor: int) -> int: + """Round up ``value`` to the nearest multiple of ``divisor``.""" + return int(math.ceil(value / divisor) * divisor) + + +def pad_param_start(param_start_index: int) -> int: + """Align parameter start index to a 64-element boundary.""" + return pad_to_divisor(param_start_index, 64) + + +def pad_bucket_end( + bucket_end_index: int, data_parallel_world_size: int, pad_for_high_nccl_busbw: bool +) -> int: + """Pad bucket end for DP-divisibility (and optionally high NCCL bus bandwidth).""" + if pad_for_high_nccl_busbw: + divisor = math.lcm(data_parallel_world_size, 128, 2**16) + else: + divisor = math.lcm(data_parallel_world_size, 128) + return pad_to_divisor(bucket_end_index, divisor) + + +@dataclass(frozen=True) +class BufferKey: + """Identifies a distinct parameter buffer. + + Each unique combination of these fields corresponds to a separate contiguous + buffer in DDP. Parameters are grouped into buffers by these dimensions. + + Attributes: + param_dtype: Storage dtype (torch.uint8 for FP8/NVFP4 parameters, else param.dtype). + grad_dtype: Gradient reduction dtype. + is_expert_parallel: Whether the buffer holds expert-parallel parameters, + which use a separate data-parallel group. + """ + + param_dtype: torch.dtype + grad_dtype: torch.dtype + is_expert_parallel: bool + + +@dataclass +class PerBufferParamLayout: + """Layout for parameters within a single contiguous buffer. + + Describes how parameters should be laid out in the contiguous buffer. 
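+
+    Offsets already reflect the padding applied while the layout was computed:
+    parameter starts are 64-element aligned (``pad_param_start``) and bucket
+    ends are rounded up for DP divisibility (``pad_bucket_end``), e.g. a
+    1000-element bucket with ``dp_world_size=8`` pads to
+    ``ceil(1000 / lcm(8, 128)) * lcm(8, 128) = 1024`` elements.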
+ + Attributes: + param_index_map: Mapping from parameter to (start_index, end_index, bucket_id) in buffer. + bucket_indices: List of (start_index, end_index) for each bucket. + per_bucket_numel_unpadded: Number of unpadded elements per bucket. + param_indices: The index of each param among same-dtype params (using the "fake" + high-precision dtype for FP8/NVFP4 params). Needed for loading non-native-fp8 + checkpoints in native-fp8 mode. Order matches param_index_map iteration order. + """ + + param_index_map: Dict[torch.nn.Parameter, Tuple[int, int, int]] = field(default_factory=dict) + bucket_indices: List[Tuple[int, int]] = field(default_factory=list) + per_bucket_numel_unpadded: List[int] = field(default_factory=list) + param_indices: List[int] = field(default_factory=list) + + +@dataclass +class FullParamLayout: + """Layout for all parameters across all buffer groups in a model chunk. + + Maps BufferKey to per-buffer PerBufferParamLayout objects. Each PerBufferParamLayout has its + own independent index space since different buffer groups are physically + separate buffers. + + Attributes: + layouts: Mapping from BufferKey to PerBufferParamLayout. + """ + + layouts: Dict[BufferKey, PerBufferParamLayout] = field(default_factory=dict) diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py index ffb9aa0e3e2..ec689e8fe7f 100644 --- a/megatron/core/pipeline_parallel/combined_1f1b.py +++ b/megatron/core/pipeline_parallel/combined_1f1b.py @@ -231,10 +231,15 @@ def combined_1f1b_schedule_for_interleaved_pipelining(): if f_model_chunk_id is not None: forward_step_helper_postprocess(f_model_chunk_id, output_tensor, num_tokens) # backward post process - if b_model_chunk_id: + if b_model_chunk_id is not None: # The same as the backward_step_helper backward_step_helper_postprocess(b_virtual_microbatch_id) - if input_tensor is not None: + # Verify backward grad: if backward microbatch received activation from upstream + # (b_input_tensor is not None), input_tensor_grad must be produced. + # Note: the original assert used forward's input_tensor, which is incorrect when + # forward and backward are on different VP stages (backward has chunk reversal: + # model_chunk_id = num_chunks - id - 1), causing false failures in interleaved PP. + if b_input_tensor is not None: assert input_tensor_grad is not None return output_tensor, input_tensor_grad diff --git a/megatron/core/rerun_state_machine.py b/megatron/core/rerun_state_machine.py index 77a973401d5..c928adef7e8 100644 --- a/megatron/core/rerun_state_machine.py +++ b/megatron/core/rerun_state_machine.py @@ -774,8 +774,13 @@ def state_dict( data_iterator: the data iterator that needs to be checkpointed (or None if this checkpoint is not requested by the rerun state machine). ckpt_format: the checkpoint format to use. + force: if True, emit the full state dict even when the machine is + disabled or no rerun is pending (used on the load path to + build a matching template). Returns: - A state dict representing the rerun state machine. + A state dict representing the rerun state machine, or None if + rerun checkpointing is not applicable (mode is DISABLED, or the + checkpoint format does not support ShardedObject). Example usage: @@ -790,11 +795,18 @@ def save_my_model_checkpoint(data_iterator, ...): return checkpoint """ - # Only save a checkpoint if a step needs to be rerun. + # Short-circuits only apply on the save path. 
On the load path + # (``force=True``) we build a template that mirrors whatever the + # checkpoint happens to contain, regardless of the current mode or + # ckpt_format -- this preserves the behavior the load path in + # ``checkpointing.py`` has always relied on. if not force: - if self.state == RerunState.NOT_RUNNING_YET: + # Disabled mode never triggers a rerun workflow, so there's + # nothing to persist across a restart; keep returning None to + # avoid bloating the checkpoint. + if self.mode == RerunMode.DISABLED: return None - + # ShardedObject is only supported by the torch_dist format. if ckpt_format != "torch_dist": log_single_rank( logger, @@ -804,11 +816,28 @@ def save_my_model_checkpoint(data_iterator, ...): ) return None - data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator) + # Data-iterator buffers only need to be checkpointed when a rerun is + # actually pending. In steady state we skip sanitization so the + # caller isn't required to have wrapped its iterator in + # ``RerunDataIterator`` -- that requirement only applies when a + # fault is mid-flight. + if self.state == RerunState.NOT_RUNNING_YET and not force: + data_iterator_checkpoints = None + else: + data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator) + data_iterator_checkpoints = ( + [d.state_dict() for d in data_iterators] if data_iterators else None + ) # When saving a step to re-run, the RerunStateMachine state is different across all ranks. # We keep the common state in the non-sharded (common) checkpoint and move the rank-level # state to a sharded object. + # In NOT_RUNNING_YET this is all zero/None defaults (a sentinel); + # in WILL_RERUN_FROM_CHECKPOINT it carries the real fault context. The + # ShardedObject key/shape/offset are identical in both cases. This keeps the + # checkpoint's sharded structure constant across saves (a + # requirement of ``--ckpt-assume-constant-structure``) + # For details, see GitHub issue NVIDIA/Megatron-LM#4378. sharded_dict = { "rerun_requested": self.rerun_requested, "checkpoint_requested": self.checkpoint_requested, @@ -822,9 +851,7 @@ def save_my_model_checkpoint(data_iterator, ...): "suspicious_node": self.suspicious_node, "suspicious_device": self.suspicious_device, # No need to save saved_state (RNG state already captured in checkpoint). - "data_iterator_checkpoints": ( - [d.state_dict() for d in data_iterators] if data_iterators else None - ), + "data_iterator_checkpoints": data_iterator_checkpoints, "large_value_counts": self.large_value_counts, "max_values": self.max_values, # No need to save saved_results and stats (resets when job resumes). diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index adccc9f2356..e1b75bb0a70 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -9,7 +9,7 @@ from .copy_services.base import CopyService from .transforms import ReshardTransform, _ensure_sendable -from .utils import ReshardPlan +from .utils import ReshardPlan, named_refit_tensors logger = logging.getLogger(__name__) @@ -48,13 +48,15 @@ def execute_reshard_plan( of the default slice-and-copy logic. """ - # Extract parameters from models if present + # Extract parameters and persistent buffers from models if present. + # Persistent buffers carry training state (e.g. MoE router expert_bias) + # and must be refit alongside parameters. 
src_params = {} dst_params = {} if src_module is not None: - src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} + src_params = {name: p for name, p in named_refit_tensors(src_module)} if dst_module is not None: - dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} + dst_params = {name: p for name, p in named_refit_tensors(dst_module)} # Cache dequantized BF16 views of MXFP8 source params so that multiple # send ops for the same param reuse one dequant instead of repeating it. diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index 1c7c63a72a9..242450bd835 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -15,6 +15,7 @@ _build_layer_module_prefix_map, _get_rank_in_group, extract_param_metadata, + named_refit_tensors, select_src_metadata_balanced, ) @@ -376,7 +377,12 @@ def build_centralized_reshard_plan( _rank_list_cache: dict = {} def _extract_metadata(module, rank_offset): - """Extract per-parameter metadata from a module, or [] if module is None.""" + """Extract per-parameter metadata from a module, or [] if module is None. + + Includes both ``nn.Parameter`` instances and persistent buffers — the + latter so that buffers carrying training state (e.g. MoE router + ``expert_bias``) travel with the weights during refit. + """ if module is None: return [] pg = getattr(module, "pg_collection", None) @@ -394,7 +400,7 @@ def _extract_metadata(module, rank_offset): rank_offset=rank_offset, _rank_list_cache=_rank_list_cache, ) - for name, p in module.named_parameters(recurse=True) + for name, p in named_refit_tensors(module) ] my_src_metadata = _extract_metadata(src_module, src_rank_offset) diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index c995e7ecd0a..2cb6ba4479f 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -27,6 +27,7 @@ from .copy_services.nccl_copy_service import NCCLCopyService from .copy_services.nvshmem_copy_service import NVSHMEMCopyService from .transforms import MXFP8ReshardTransform, ReshardTransform +from .utils import named_persistent_buffers # Supported refit backend names RefitBackendName = Literal["nccl", "gloo", "nvshmem"] @@ -376,6 +377,59 @@ def swap_model_weights( ) +def _harmonize_buffer_dtypes(src_core, tgt_core, group=None): + """Bring destination persistent-buffer dtypes into agreement with source. + + Some buffers (notably the MoE router ``expert_bias``) are upcast to fp32 + inside the trainer on first forward by ``_maintain_float32_expert_bias``, + while the freshly-built inference model still holds them in bf16 from the + ``Float16Module`` wrap. The reshard send/recv path is dtype-strict — + sending fp32 bytes into a bf16 receive buffer corrupts the data — so dst's + buffer must match src's dtype before the transfer. + + Works for both collocated and non-collocated transfers: every rank reports + its source-side persistent-buffer dtypes via a single + ``all_gather_object`` on ``group``. Destination-side ranks then look up + each of their own buffers in the gathered map and replace the tensor with + one in src's dtype. Source-only and idle ranks contribute empty dicts and + skip the apply step, but still participate in the collective so it is + well-formed across every rank. + + Buffer matching is by raw module path (e.g. 
``decoder.layers.0.…``); the + planner's PP-aware ``resolved_name`` is intentionally not used here because + we only need the dtype, which is uniform for a given buffer kind across + layers in practice. + """ + # Build local map of source-side persistent buffer dtypes. + local_src_dtypes: dict[str, torch.dtype] = {} + if src_core is not None: + for full_name, _sub, _buf_name, buf in named_persistent_buffers(src_core): + local_src_dtypes[full_name] = buf.dtype + + world_size = group.size() if group is not None else torch.distributed.get_world_size() + gathered: list = [None] * world_size + torch.distributed.all_gather_object(gathered, local_src_dtypes, group=group) + + canonical: dict[str, torch.dtype] = {} + for d in gathered: + if not d: + continue + for name, dtype in d.items(): + # Replicated buffers agree across ranks; first writer wins. + canonical.setdefault(name, dtype) + + if tgt_core is None: + return + + for full_name, sub, buf_name, dst_buf in named_persistent_buffers(tgt_core): + expected = canonical.get(full_name) + if expected is not None and dst_buf.dtype != expected: + # Replace the tensor in-place on the parent module so subsequent + # recvs write the right number of bytes and the in-model lookup + # (``self.expert_bias``) sees the new storage. + sub._buffers[buf_name] = dst_buf.to(expected) + + def reshard_model_weights( src_model: LanguageModule, target_model: LanguageModule, @@ -400,6 +454,7 @@ def reshard_model_weights( transform: Optional ReshardTransform for custom format conversion. """ src_core, tgt_core, num_experts = _unwrap_model_cores(src_model, target_model) + _harmonize_buffer_dtypes(src_core, tgt_core, group=group) plan = _build_or_get_plan( src_core, tgt_core, num_experts, group, src_rank_offset, dst_rank_offset ) diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index 4e11fedee68..1c748f5ff98 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -197,6 +197,38 @@ def assign_resolved_name_inplace( assign_ep_resolved_name_inplace(meta, base_name=name) +def named_persistent_buffers(module: torch.nn.Module): + """Yield ``(full_name, parent_module, buf_name, tensor)`` for every + persistent buffer in ``module``. Skips ``_non_persistent_buffers_set``. + + Persistent buffers (those saved in ``state_dict``) carry training state that + must travel with the weights during refit/resharding — e.g. the MoE + router's ``expert_bias``, which is updated each step by aux-loss-free load + balancing. Non-persistent buffers are excluded since they hold ephemeral + state (e.g. accumulators reset at the next train step). + """ + for module_prefix, sub_module in module.named_modules(): + non_persistent = sub_module._non_persistent_buffers_set + for buf_name, buf in sub_module._buffers.items(): + if buf is None or buf_name in non_persistent: + continue + full_name = f"{module_prefix}.{buf_name}" if module_prefix else buf_name + yield full_name, sub_module, buf_name, buf + + +def named_refit_tensors(module: torch.nn.Module): + """Yield ``(name, tensor)`` pairs for every parameter and persistent buffer. + + Used by the refit planner and executor to enumerate which tensors should + travel during resharding. Persistent buffers are included alongside + parameters because they may carry training state (see + ``named_persistent_buffers``). 
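+
+    Sketch of typical use (``model`` is any ``torch.nn.Module``; the variable
+    names are illustrative)::
+
+        refit_tensors = dict(named_refit_tensors(model))
+        # Every parameter plus persistent buffers such as a MoE router's
+        # expert_bias; non-persistent buffers are excluded.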
+ """ + yield from module.named_parameters(recurse=True) + for full_name, _sub, _buf_name, buf in named_persistent_buffers(module): + yield full_name, buf + + def _build_layer_module_prefix_map(module: torch.nn.Module) -> dict[str, str]: """Build a mapping local_module_prefix -> global_module_prefix for PP layer modules. diff --git a/megatron/core/safe_globals.py b/megatron/core/safe_globals.py index 790050749cd..9241405876a 100755 --- a/megatron/core/safe_globals.py +++ b/megatron/core/safe_globals.py @@ -1,5 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import io +import pickle from argparse import Namespace from io import BytesIO from pathlib import PosixPath @@ -41,3 +43,57 @@ def register_safe_globals(): """Register megatron-core safe classes with torch serialization.""" for cls in SAFE_GLOBALS: torch.serialization.add_safe_globals([cls]) + + +def safe_load_from_bytes(b): + """Safe version (weights_only=True) of `torch.storage._load_from_bytes`.""" + return torch.load(io.BytesIO(b), weights_only=True) + + +class SafeUnpickler(pickle.Unpickler): + """Restricted unpickler for FP8 extra-state checkpoints. + Only allows the narrow set of types that ``_encode_extra_state`` can + produce: plain Python containers, numeric scalars, and the PyTorch + tensor/storage primitives used by `pickle.dumps(tensor)`. Any attempt + to instantiate a class outside this allowlist raises + `pickle.UnpicklingError`, preventing arbitrary code execution via a + crafted checkpoint. + """ + + _SAFE_CLASSES: frozenset = frozenset( + { + ("builtins", "dict"), + ("builtins", "list"), + ("builtins", "tuple"), + ("builtins", "int"), + ("builtins", "float"), + ("builtins", "bool"), + ("builtins", "bytes"), + ("builtins", "str"), + ("collections", "OrderedDict"), + ("torch", "Size"), + ("torch._utils", "_rebuild_tensor_v2"), + ("torch._tensor", "_rebuild_from_type_v2"), + ("torch.storage", "UntypedStorage"), + ("torch.storage", "_load_from_bytes"), + ("transformer_engine.common.recipe", "DelayedScaling"), + ("transformer_engine.common.recipe", "Float8CurrentScaling"), + ("transformer_engine.common.recipe", "Float8BlockScaling"), + ("transformer_engine.common.recipe", "MXFP8BlockScaling"), + ("transformer_engine.common.recipe", "NVFP4BlockScaling"), + ("transformer_engine.common.recipe", "Format"), + ("transformer_engine.common.recipe", "_FormatHelper"), + ("transformer_engine.common.recipe", "MMParams"), + ("transformer_engine.common.recipe", "QParams"), + ("megatron.core.extensions.transformer_engine", "TEDelayedScaling"), + ("megatron.core.safe_globals", "safe_load_from_bytes"), + } + ) + + def find_class(self, module: str, name: str): + if (module, name) not in self._SAFE_CLASSES: + raise pickle.UnpicklingError( + f"Refusing to unpickle disallowed class '{module}.{name}' " + "in FP8 extra-state checkpoint." 
+ ) + return super().find_class(module, name) diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 8df4df1e562..f9b923632f5 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -33,6 +33,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.utils import ( + cat_with_oom_fallback, ensure_metadata_has_dp_cp_group, make_sharded_tensors_for_checkpoint, sharded_state_dict_default, @@ -747,12 +748,12 @@ def sh_ten_build_fn( ) return chunk_sh_tens - @torch.no_grad() - def sh_ten_merge_fn(sub_state_dict): - return torch.cat(sub_state_dict) - return ShardedTensorFactory( - orig_sh_ten.key, orig_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn, orig_sh_ten.replica_id + orig_sh_ten.key, + orig_sh_ten.data, + sh_ten_build_fn, + cat_with_oom_fallback, + orig_sh_ten.replica_id, ) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index e9ee2dd8deb..707c7d7690f 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -32,6 +32,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.utils import ( + cat_with_oom_fallback, ensure_metadata_has_dp_cp_group, make_sharded_tensors_for_checkpoint, sharded_state_dict_default, @@ -355,6 +356,14 @@ def __init__( self.A_log = nn.Parameter(A_log) setattr(self.A_log, "tensor_model_parallel", True) setattr(self.A_log, "partition_dim", 0) + # Persistent inference cache for -exp(A_log.float()). Allocated + # here (outside any later CUDA-graph capture) so its address + # lives in the default memory pool and stays valid across every + # graph capture and replay, including across RL train/eval + # cycles. Never freed -- the memory cost is ``nheads * 4B`` per + # layer (a few KB across a full model). + self._A_neg_exp_cache = torch.empty_like(A_log, dtype=torch.float32) + self._A_neg_exp_cache_stale = True # D "skip" parameter self.D = nn.Parameter( torch.ones( @@ -1029,6 +1038,32 @@ def _ssm_prefill( return y + def _get_decode_A_neg_exp(self) -> torch.Tensor: + """Cached ``-exp(A_log.float())`` pre-expanded to ``(nheads, headdim, dstate)``. + + A_log is frozen during inference; recomputing it per token otherwise + launches three small elementwise kernels (float cast, exp, neg) that + rival ``selective_state_update`` itself in the decode profile. The + stride-0 expand view also triggers the kernel's TIE_HDIM fast path. + """ + if self.training or torch.is_grad_enabled(): + base = -torch.exp(self.A_log.float()) + return base.view(-1, 1, 1).expand(-1, self.headdim, self.d_state) + # Inference path. Refill when stale + if self._A_neg_exp_cache_stale: + with torch.no_grad(): + self._A_neg_exp_cache.copy_(-torch.exp(self.A_log.float())) + self._A_neg_exp_cache_stale = False + return self._A_neg_exp_cache.view(-1, 1, 1).expand(-1, self.headdim, self.d_state) + + def train(self, mode: bool = True): + """Mark the decode cache stale; weights may have updated.""" + if mode: + # only mark stale when switching to training mode. + # otherwise retain the staleness state. 
+ self._A_neg_exp_cache_stale = True + return super().train(mode) + def _ssm_decode( self, zxBCdt: torch.Tensor, @@ -1107,10 +1142,10 @@ def _ssm_decode( ], dim=-1, ) - A = -torch.exp(self.A_log.float()) - # SSM step if selective_state_update is None: + # Fallback uses 1D A; the decode cache is pre-expanded for Triton. + A = -torch.exp(self.A_log.float()) # TODO(ksanthanam): Consider deprecating this path assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" @@ -1168,7 +1203,7 @@ def _ssm_decode( y = y.unsqueeze(1) # Restore seq dimension else: - A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + A = self._get_decode_A_neg_exp() # Incorporate sequence dimension in einops rearrengements dt = repeat(dt, "b s h -> b s h p", p=self.headdim) @@ -1180,11 +1215,6 @@ def _ssm_decode( if not self.rmsnorm: z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim) - # Upcast the batch_indices to prevent integer overflow errors in the case of - # large max request counts. - if batch_indices is not None: - batch_indices = batch_indices.to(torch.int64) - y = selective_state_update( ssm_state, x_reshaped, @@ -1392,12 +1422,12 @@ def sh_ten_build_fn( ) return chunk_sh_tens - @torch.no_grad() - def sh_ten_merge_fn(sub_state_dict): - return torch.cat(sub_state_dict) - return ShardedTensorFactory( - orig_sh_ten.key, orig_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn, orig_sh_ten.replica_id + orig_sh_ten.key, + orig_sh_ten.data, + sh_ten_build_fn, + cat_with_oom_fallback, + orig_sh_ten.replica_id, ) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index f57f5d94cea..4ff7ddcf933 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -139,33 +139,71 @@ def causal_conv1d_update_kernel( conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask ).to(tl.float32) - # Shift the linear state buffer left by 1 - i = 0 - while i < state_len - 1: - val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) - tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) - i += 1 - # Process the single token for the current sequence step x_val = tl.load(x_ptrs, mask=mask) + # Shift the linear state buffer left by 1. When state_len == WIDTH (the + # common case: conv_state dim == conv kernel width) the shifted values + # are already resident in the x_val_* registers from the loads above, so + # we can write them back without a second HBM read per position. For + # state_len > WIDTH the leading positions are untouched by compute; fall + # through to the explicit load+store shift. 
+ if state_len == WIDTH: + out_dtype = conv_state_ptrs.dtype.element_ty + if WIDTH >= 2: + tl.store( + conv_state_ptrs + 0 * conv_state_l_stride, x_val_0.to(out_dtype), mask=mask + ) + if WIDTH >= 3: + tl.store( + conv_state_ptrs + 1 * conv_state_l_stride, x_val_1.to(out_dtype), mask=mask + ) + if WIDTH >= 4: + tl.store( + conv_state_ptrs + 2 * conv_state_l_stride, x_val_2.to(out_dtype), mask=mask + ) + else: + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 + # Store the new token at the end of the linear state buffer tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) - # Write out to the intermediate state buffer if requested + # Write out to the intermediate state buffer if requested. Reuse the + # register values from the shift above (and x_val for the new tail + # position) when state_len == WIDTH, instead of re-reading from HBM. if HAS_INT_STATE: - i = 0 - while i < state_len: - val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) - int_ptr = ( + if state_len == WIDTH: + int_base = ( int_state_ptr + state_batch_coord * int_state_b_stride + s * int_state_s_stride + channel_offsets * int_state_c_stride - + i * int_state_l_stride ) - tl.store(int_ptr, val, mask=mask) - i += 1 + out_dtype = int_base.dtype.element_ty + if WIDTH >= 2: + tl.store(int_base + 0 * int_state_l_stride, x_val_0.to(out_dtype), mask=mask) + if WIDTH >= 3: + tl.store(int_base + 1 * int_state_l_stride, x_val_1.to(out_dtype), mask=mask) + if WIDTH >= 4: + tl.store(int_base + 2 * int_state_l_stride, x_val_2.to(out_dtype), mask=mask) + tl.store(int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask) + else: + i = 0 + while i < state_len: + val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) + int_ptr = ( + int_state_ptr + + state_batch_coord * int_state_b_stride + + s * int_state_s_stride + + channel_offsets * int_state_c_stride + + i * int_state_l_stride + ) + tl.store(int_ptr, val, mask=mask) + i += 1 # Advance registers for calculation x_val_f32 = x_val.to(tl.float32) diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py index 4e079da8a31..672b00cf7dc 100644 --- a/megatron/core/ssm/ops/mamba_ssm.py +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -2,9 +2,11 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. # Some of this code was adopted from https://github.com/state-spaces/mamba/ +# and https://github.com/vllm-project/vllm. # This source code is licensed under the Apache license found in the # LICENSE file in the root directory of this source tree. + import torch from packaging import version @@ -41,6 +43,15 @@ def softplus(dt): return tl.math.log1p(tl.exp(dt)) +@triton.jit +def fast_exp(x): + """ + Fast calculation of exponent via exponent of 2. 
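+    Uses the identity ``exp(x) == exp2(x * log2(e))``; ``LOG2E`` below is
+    ``log2(e) ≈ 1.4426950408889634``.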
+ """ + LOG2E = tl.constexpr(1.4426950408889634) + return tl.math.exp2(LOG2E * x) + + @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @@ -49,7 +60,7 @@ def softplus(dt): ) @triton.heuristics({"HAS_INT_STATE": lambda args: args["int_state_ptr"] is not None}) @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) -@triton.jit +@triton.jit(do_not_specialize=["batch"]) def _selective_scan_update_kernel( # Pointers to matrices state_ptr, @@ -223,14 +234,14 @@ def _selective_scan_update_kernel( dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) - dA = tl.exp(A * dt[:, None]) + dA = fast_exp(A * dt[:, None]) else: dt = tl.load(dt_ptr + s * stride_dt_seq).to(tl.float32) if HAS_DT_BIAS: dt += tl.load(dt_bias_ptr).to(tl.float32) if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) - dA = tl.exp(A * dt) + dA = fast_exp(A * dt) # Load B and C B = tl.load(B_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) @@ -368,15 +379,23 @@ def selective_state_update( (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0) ) - BLOCK_SIZE_M, num_warps = ( - (32, 4) - if dstate <= 16 - else ( - (16, 4) - if dstate <= 32 - else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) - ) - ) + is_blackwell = torch.cuda.get_device_capability(x.device)[0] >= 10 + + # Default + BLOCK_SIZE_M, num_warps = 4, 8 + if dstate <= 16: + BLOCK_SIZE_M, num_warps = 32, 4 + elif dstate <= 32: + BLOCK_SIZE_M, num_warps = 16, 4 + elif dstate <= 64: + BLOCK_SIZE_M, num_warps = 8, 4 + else: + # dstate > 64 + if is_blackwell: + # Optimized for B200 with dstate>64 + BLOCK_SIZE_M, num_warps = 32, 8 + elif dstate <= 128: + BLOCK_SIZE_M, num_warps = 4, 4 tie_hdim = ( A.stride(-1) == 0 diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 80aa754dd50..14ac28fbefa 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -20,6 +20,10 @@ from megatron.core.inference.quantization.utils import mm_mxfp8 from megatron.core.inference.symmetric_memory import SymmetricMemoryManager from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.tensor_parallel.mappings import ( + gather_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -473,3 +477,75 @@ def forward( else: x = self._matmul_reduce_scatter(x) return x, None + + +def inference_all_gather_from_tensor_model_parallel_region( + x: torch.Tensor, tp_group: torch.distributed.ProcessGroup, config: TransformerConfig +) -> torch.Tensor: + """NVLS-optimized all-gather along the last dimension, with NCCL fallback. + + Replaces `gather_from_tensor_model_parallel_region` in inference paths + where autograd is not needed and NVLS symmetric-memory is available. + + The NVLS path performs a flat all-gather into symmetric memory (concatenating + along dim-0), then rearranges the result to the last dimension — the same + semantics as `_gather_along_last_dim` but using hardware multicast when + possible. 
+ """ + tp_size = dist.get_world_size(tp_group) + if tp_size == 1: + return x + + triton_nvls_kernels_allowed = not getattr( + config, 'inference_disable_triton_nvls_kernels', False + ) + + if triton_nvls_kernels_allowed and SymmetricMemoryManager.is_initialized("tp"): + ag_buffer_dims = list(x.size()) + ag_buffer_dims[0] *= tp_size + buf = SymmetricMemoryManager.get_buffer("tp", process_group=tp_group) + symm_mem_buffer = buf.maybe_get_tensor(ag_buffer_dims, dtype=x.dtype) + + if are_tensors_nvls_eligible(x) and symm_mem_buffer["handle"] is not None: + multimem_all_gather(symm_mem_buffer["tensor"], x, symm_mem_buffer["handle"]) + tensor_list = symm_mem_buffer["tensor"].chunk(tp_size, dim=0) + return torch.cat(tensor_list, dim=-1).contiguous() + + return gather_from_tensor_model_parallel_region(x, group=tp_group) + + +def inference_reduce_scatter_to_sequence_parallel_region( + x: torch.Tensor, tp_group: torch.distributed.ProcessGroup, config: TransformerConfig +) -> torch.Tensor: + """NVLS-optimized reduce-scatter along the first dimension, with NCCL fallback. + + Replaces `reduce_scatter_to_sequence_parallel_region` in inference paths + where autograd is not needed and NVLS symmetric-memory is available. + """ + # TODO(ksanthanam): Refactor InferenceRowParallelLinear._matmul_reduce_scatter + # to use this function for its non-fused NVLS reduce-scatter path. + tp_size = dist.get_world_size(tp_group) + if tp_size == 1: + return x + + triton_nvls_kernels_allowed = not getattr( + config, 'inference_disable_triton_nvls_kernels', False + ) + + if triton_nvls_kernels_allowed and SymmetricMemoryManager.is_initialized("tp"): + buf = SymmetricMemoryManager.get_buffer("tp", process_group=tp_group) + symm_mem_buffer = buf.maybe_get_tensor(list(x.size()), dtype=x.dtype) + + if ( + x.dtype == torch.bfloat16 + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None + ): + symm_mem_buffer["tensor"].copy_(x) + output_dims = list(x.size()) + output_dims[0] = x.size(0) // tp_size + output = torch.empty(output_dims, dtype=x.dtype, device=x.device) + multimem_reduce_scatter(output, symm_mem_buffer["tensor"], symm_mem_buffer["handle"]) + return output + + return reduce_scatter_to_sequence_parallel_region(x, group=tp_group) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 610700f0a95..662373064cf 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -235,6 +235,11 @@ def __init__( ) self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index self.deterministic_mode = config.deterministic_mode + self.config = config + + self.use_inference_optimized_reduce_scatter = ( + getattr(config, 'transformer_impl', None) == 'inference_optimized' + ) # Allocate weights and initialize. if config.use_cpu_initialization: @@ -302,12 +307,22 @@ def forward(self, input_): if self.reduce_scatter_embeddings: # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. output_parallel = output_parallel.transpose(0, 1).contiguous() - output = reduce_scatter_to_sequence_parallel_region( - output_parallel, group=self.tp_group - ) - else: + if self.use_inference_optimized_reduce_scatter and not self.training: + # Deferred to avoid circular import: inference_layers → TE → layers. 
+ from .inference_layers import inference_reduce_scatter_to_sequence_parallel_region + + output = inference_reduce_scatter_to_sequence_parallel_region( + output_parallel, self.tp_group, self.config + ) + else: + output = reduce_scatter_to_sequence_parallel_region( + output_parallel, group=self.tp_group + ) + elif self.tp_group.size() > 1: # Reduce across all the model parallel GPUs. output = reduce_from_tensor_model_parallel_region(output_parallel, group=self.tp_group) + else: + output = output_parallel return output def sharded_state_dict( @@ -921,6 +936,10 @@ def __init__( else: self.register_parameter("bias", None) + self.use_inference_optimized_all_gather = ( + getattr(config, 'transformer_impl', None) == 'inference_optimized' + ) + self.sequence_parallel = config.sequence_parallel if self.sequence_parallel and world_size <= 1: warnings.warn( @@ -1056,7 +1075,17 @@ def forward( if gather_output: # All-gather across the partitions. - output = gather_from_tensor_model_parallel_region(output_parallel, group=self.tp_group) + if self.use_inference_optimized_all_gather and not self.training: + # Deferred to avoid circular import: inference_layers → TE → layers. + from .inference_layers import inference_all_gather_from_tensor_model_parallel_region + + output = inference_all_gather_from_tensor_model_parallel_region( + output_parallel, self.tp_group, self.config + ) + else: + output = gather_from_tensor_model_parallel_region( + output_parallel, group=self.tp_group + ) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -1092,7 +1121,7 @@ def get_extra_state(self) -> None: def extra_repr(self) -> str: """Extra context to add to the module's string representation.""" tp = self.output_size // self.output_size_per_partition - use_bias = self.bias is not None and self.bias is True + use_bias = self.bias is not None return ( f"in_features={self.input_size}, " f"out_features={self.output_size}, " @@ -1354,7 +1383,7 @@ def get_extra_state(self) -> None: def extra_repr(self) -> str: """Extra context to add to the module's string representation.""" tp = self.input_size // self.input_size_per_partition - use_bias = self.bias is not None and self.bias is True + use_bias = self.bias is not None return ( f"in_features={self.input_size}, " f"out_features={self.output_size}, " diff --git a/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py b/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py index 4e3387b125c..66146cc32dc 100644 --- a/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py @@ -252,9 +252,7 @@ def text_to_ids(self, text: str) -> List[int]: def ids_to_text(self, ids: List[int], remove_special_tokens: Optional[bool] = None) -> str: """Converts list of ids to text. - When remove_special_tokens is None, uses not self.include_special_tokens so that - --tokenizer-hf-include-special-tokens keeps EOS (and other special tokens) in - detokenized output (e.g. for RL trajectory consistency). + When remove_special_tokens is None, uses not self.include_special_tokens. 
""" if remove_special_tokens is None: remove_special_tokens = not self.include_special_tokens diff --git a/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py b/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py index 849d3f0de0e..63110444aa5 100644 --- a/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py @@ -95,7 +95,7 @@ def __init__( ) vocab_file = self._get_vocab_file(tokenizer_name, vocab_file) - merges_file = self._get_merges_file(tokenizer_name, vocab_file) + merges_file = self._get_merges_file(tokenizer_name, merges_file) tokenizer_path = MEGATRON_CONFIG_MAP[tokenizer_name]["tokenizer_name"] super().__init__(tokenizer_path, vocab_file, merges_file, **kwargs) diff --git a/megatron/core/tokenizers/text/libraries/null_tokenizer.py b/megatron/core/tokenizers/text/libraries/null_tokenizer.py index 4ddf77fc774..96a0d3afd57 100644 --- a/megatron/core/tokenizers/text/libraries/null_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/null_tokenizer.py @@ -9,12 +9,15 @@ class NullTokenizer: Args: vocab_size: vocabulary size for embedding + eod_id: id of the end-of-document token. Defaults to ``vocab_size - 1``. + pad_id: id of the padding token. Defaults to ``-1`` (no pad token). """ - def __init__(self, vocab_size): + def __init__(self, vocab_size, eod_id=None, pad_id=-1, **kwargs): """ """ - self._vocab_size_without_eod = int(vocab_size) - self._eod_id = self._vocab_size_without_eod + self._vocab_size = int(vocab_size) + self._eod_id = int(eod_id) if eod_id is not None else self._vocab_size - 1 + self._pad_id = int(pad_id) def text_to_ids(self, text): """Converts text to ids.""" @@ -44,12 +47,19 @@ def offsets(self, ids: list[int], text: str) -> list[int]: @property def unique_identifiers(self) -> OrderedDict: """Property required for use with megatron-core datasets.""" - return OrderedDict({"class": f"{type(self).__module__}.{type(self).__qualname__}"}) + return OrderedDict( + { + "class": f"{type(self).__module__}.{type(self).__qualname__}", + "vocab_size": self._vocab_size, + "eod_id": self._eod_id, + "pad_id": self._pad_id, + } + ) @property def vocab_size(self): """Returns vocab size.""" - return self._vocab_size_without_eod + 1 + return self._vocab_size @property def vocab(self): @@ -81,6 +91,11 @@ def eod(self): """Returns eod token.""" return self._eod_id + @property + def pad_id(self): + """Returns pad token.""" + return self._pad_id + @property def additional_special_tokens_ids(self): """ """ diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py index 8a418f2dd7f..bc29ddfab68 100644 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py @@ -108,6 +108,25 @@ def __init__(self, tokenizer_path: str, prompt_format: str): self._prompt_format = prompt_format + @staticmethod + def _extract_token_ids(result) -> np.ndarray: + if isinstance(result, dict) or hasattr(result, "input_ids"): + ids = result["input_ids"] + # Convert to 1D if it's a 2D batch [1, seq_len] + return np.array(ids[0] if len(np.shape(ids)) > 1 else ids) + + # Handle the "Single Sequence" Encoding object (Fast Tokenizer [0] output) + if hasattr(result, "ids"): + return np.array(result.ids) + + if isinstance(result, list): + return np.array(result) + + # Handles raw ndarray returned by transformers v4 apply_chat_template with + # return_tensors="np": shape is 
(1, seq_len), so squeeze the batch dimension. + arr = np.asarray(result) + return arr[0] if arr.ndim == 2 and arr.shape[0] == 1 else arr + def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ): @@ -128,14 +147,16 @@ def tokenize_conversation( if not self._prompt_config.has_system_role and conversation[0]["role"] == "system": conversation = conversation[1:] - tokens = self._tokenizer.apply_chat_template( - conversation, - tokenize=True, - add_generation_prompt=add_generation_prompt, - return_assistant_token_mask=False, - return_tensors="np", - chat_template=self._prompt_config.custom_chat_template, - )[0] + tokens = self._extract_token_ids( + self._tokenizer.apply_chat_template( + conversation, + tokenize=True, + add_generation_prompt=add_generation_prompt, + return_assistant_token_mask=False, + return_tensors="np", + chat_template=self._prompt_config.custom_chat_template, + ) + ) if not return_target: return tokens @@ -144,7 +165,7 @@ def tokenize_conversation( # When using the default prompt format, we do not replace any tokens with IGNORE_INDEX. # Instead, all token losses will be used for simplicity. - if self._prompt_format == "default": + if self._prompt_format in ["default", "identity"]: return tokens, target # Mask system and user tokens in the target. @@ -156,8 +177,10 @@ def tokenize_conversation( if turn["role"].lower() == "assistant": assert conversation[turn_idx - 1]["role"].lower() in ("user", "tool") - turn_tokens = self._tokenizer.apply_chat_template( - [turn], tokenize=True, chat_template=self._prompt_config.custom_chat_template + turn_tokens = self._extract_token_ids( + self._tokenizer.apply_chat_template( + [turn], tokenize=True, chat_template=self._prompt_config.custom_chat_template + ) ) # There should be only one BOS at the very beginning. 
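As a quick illustration of the normalization the new `_extract_token_ids` helper performs, the standalone sketch below re-implements its logic outside the tokenizer class and exercises it on synthetic stand-ins for the shapes `apply_chat_template` can return across transformers versions (plain list, `(1, seq_len)` ndarray, dict/`BatchEncoding` with `input_ids`); the `tokenizers.Encoding` case is omitted here because it needs a fast-tokenizer instance.

```python
# Standalone sketch mirroring SFTTokenizer._extract_token_ids; the inputs are
# synthetic stand-ins for the different apply_chat_template return types.
import numpy as np


def extract_token_ids(result) -> np.ndarray:
    # dict / BatchEncoding output: pull "input_ids" and squeeze a [1, seq_len] batch.
    if isinstance(result, dict) or hasattr(result, "input_ids"):
        ids = result["input_ids"]
        return np.array(ids[0] if len(np.shape(ids)) > 1 else ids)
    # tokenizers.Encoding object returned by a fast tokenizer.
    if hasattr(result, "ids"):
        return np.array(result.ids)
    # Plain Python list of ids.
    if isinstance(result, list):
        return np.array(result)
    # Raw ndarray, possibly with a leading batch dimension of 1.
    arr = np.asarray(result)
    return arr[0] if arr.ndim == 2 and arr.shape[0] == 1 else arr


for raw in ([1, 2, 3], np.array([[1, 2, 3]]), {"input_ids": [[1, 2, 3]]}):
    assert extract_token_ids(raw).tolist() == [1, 2, 3]
```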
diff --git a/megatron/core/tokenizers/utils/build_tokenizer.py b/megatron/core/tokenizers/utils/build_tokenizer.py index bf02451ae6c..50d89da1cd9 100644 --- a/megatron/core/tokenizers/utils/build_tokenizer.py +++ b/megatron/core/tokenizers/utils/build_tokenizer.py @@ -20,9 +20,7 @@ def build_tokenizer(args, **kwargs): if args.tokenizer_type in MEGATRON_TOKENIZERS: tokenizer_library = 'megatron' tokenizer_path = args.tokenizer_type - kwargs['additional_special_tokens'] = ( - args.tokenizer_special_tokens if args.tokenizer_special_tokens else [] - ) + kwargs['additional_special_tokens'] = args.special_tokens if args.special_tokens else [] if tokenizer_path == 'BertWordPieceCase': special_tokens = {} special_tokens['additional_special_tokens'] = [f'' for i in range(100)] @@ -36,7 +34,7 @@ def build_tokenizer(args, **kwargs): tokenizer_library = 'sentencepiece' tokenizer_path = args.tokenizer_model kwargs['legacy'] = args.tokenizer_sentencepiece_legacy - kwargs['special_tokens'] = args.tokenizer_special_tokens + kwargs['special_tokens'] = args.special_tokens elif args.tokenizer_type == 'TikTokenizer': tokenizer_library = 'tiktoken' tokenizer_path = args.tokenizer_model @@ -45,15 +43,13 @@ def build_tokenizer(args, **kwargs): if args.vocab_size: kwargs['vocab_size'] = args.vocab_size kwargs['num_special_tokens'] = args.tiktoken_num_special_tokens - kwargs['special_tokens'] = args.tokenizer_special_tokens + kwargs['special_tokens'] = args.special_tokens elif args.tokenizer_type == 'HuggingFaceTokenizer': tokenizer_library = 'huggingface' tokenizer_path = args.tokenizer_model kwargs['vocab_file'] = args.vocab_file kwargs['merges_file'] = args.merge_file - kwargs['additional_special_tokens'] = ( - args.tokenizer_special_tokens if args.tokenizer_special_tokens else [] - ) + kwargs['additional_special_tokens'] = args.special_tokens if args.special_tokens else [] kwargs['use_fast'] = not args.tokenizer_hf_no_use_fast kwargs['trust_remote_code'] = args.trust_remote_code kwargs['include_special_tokens'] = not args.tokenizer_hf_no_include_special_tokens @@ -74,6 +70,13 @@ def build_tokenizer(args, **kwargs): metadata = {'library': tokenizer_library} if args.vocab_size: kwargs['vocab_size'] = args.vocab_size + if args.tokenizer_type == 'NullTokenizer': + null_eod_id = getattr(args, 'null_tokenizer_eod_id', None) + if null_eod_id is not None: + kwargs['eod_id'] = null_eod_id + null_pad_id = getattr(args, 'null_tokenizer_pad_id', None) + if null_pad_id is not None: + kwargs['pad_id'] = null_pad_id tokenizer = MegatronTokenizer.from_pretrained(metadata_path=metadata, **kwargs) # Add vocab size (if not already set from a checkpoint). 
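The `NullTokenizer` change above makes the eod and pad ids configurable and redefines `vocab_size` as the full size (previously the eod id was an extra token equal to the constructor's `vocab_size`, and `vocab_size` reported one more than the argument). A minimal usage sketch, assuming only what the constructor and properties in this diff define; the CLI flag names behind `args.null_tokenizer_eod_id` / `args.null_tokenizer_pad_id` are inferred from the wiring above and may differ.

```python
# Usage sketch for the updated NullTokenizer defaults and overrides.
from megatron.core.tokenizers.text.libraries.null_tokenizer import NullTokenizer

tok = NullTokenizer(vocab_size=1024)
assert tok.vocab_size == 1024   # full size, no implicit +1 for eod anymore
assert tok.eod == 1023          # defaults to vocab_size - 1
assert tok.pad_id == -1         # no pad token by default

# Explicit ids, as build_tokenizer passes them when the corresponding args are set.
tok = NullTokenizer(vocab_size=1024, eod_id=0, pad_id=1)
assert (tok.eod, tok.pad_id) == (0, 1)
```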
@@ -81,8 +84,8 @@ def build_tokenizer(args, **kwargs): return tokenizer - if args.tokenizer_metadata: - metadata = args.tokenizer_metadata + if args.metadata_path: + metadata = args.metadata_path else: metadata = {'library': tokenizer_library} tokenizer = MegatronTokenizer.from_pretrained( diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 48bd34ce499..62a91aa3342 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -707,7 +707,10 @@ def __init__( self.backward_retain_grad = False self.fp8_enabled = False self.fp4_enabled = False + self.fp8_runtime_enabled = None + self.fp4_runtime_enabled = None self.deallocate_pipeline_outputs = False + self.num_warmup_steps = 0 self.grad_enabled = need_backward and torch.is_grad_enabled() self.func = super(MegatronModule, self.base_module).__call__ if func is None else func @@ -796,8 +799,18 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True): self.kwargs = kwargs self.outputs = outputs - # save grads and other variables that may be affected by graph warmup + # Save buffers, grads, and other variables that may be affected by graph warmup. + # For example, megatron/core/transformer/moe/router.py's expert_bias is a persistent + # buffer updated each forward pass by '_apply_expert_bias()'. So we need to ensure + # graph capture's forward passes do not corrupt its value. Inference is not affected + # (no known buffer mutators) and would add new buffers (lazy MoE _fc1_weight/ + # _fc2_weight) that misalign the positional restore. + if self.training and torch.is_grad_enabled(): + buffer_backup = [] + for buf in self.base_module.buffers(): + buffer_backup.append(buf.clone()) + grad_backup = [] for param in self.base_module.parameters(): grad_backup.append(param.main_grad.clone() if hasattr(param, "main_grad") else None) @@ -841,7 +854,6 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True): def _resolve_input_buffer(ten): if not isinstance(ten, ArgMetadata): return ten - # the input tensor is resued from another cudagraph's input or output if ( hasattr(ten, "cg_buffer_metadata") @@ -912,7 +924,7 @@ def _resolve_input_buffer(ten): def clone_ten(ten): if not torch.is_tensor(ten): return ten - return torch.zeros_like(ten).requires_grad_(ten.requires_grad) + return torch.clone(ten).detach().requires_grad_(ten.requires_grad) warmup_args = tree_map(clone_ten, self.fwd_graph_input_args) warmup_kwargs = tree_map(clone_ten, self.fwd_graph_input_kwargs) @@ -985,17 +997,6 @@ def clone_ten(ten): o.cg_buffer_metadata.fwd_cudagraph_buffer = fwd_graph_out fwd_buffer_reuse_ref_count += 1 - # if an input buffer requires a copy, and does not have metadata attached to it at this - # point, it will not be reused after this forward pass, so return it to the pool - for buf in self.fwd_graph_input_surface: - if ( - hasattr(buf, "can_skip_replay_copy") - and not buf.can_skip_replay_copy - and not hasattr(buf, "cg_buffer_metadata") - ): - assert _CudagraphGlobalRecord.tensor_reuse_pool.owns(buf) - _CudagraphGlobalRecord.tensor_reuse_pool.insert(buf) - if self.training and torch.is_grad_enabled(): assert ( len(self.fwd_graph_output_surface) > 0 @@ -1015,6 +1016,10 @@ def clone_ten(ten): if main_grad_copy is not None: param.main_grad.copy_(main_grad_copy) + # restore cached buffers + for buf_copy, buf in zip(buffer_backup, self.base_module.buffers()): + buf.copy_(buf_copy) + if is_moe: for name, cached_values in cached_aux_losses.items(): 
assert ( @@ -1417,6 +1422,8 @@ def __init__( function_name=None, need_backward=True, pg_collection=None, + inline_capture=False, + num_warmup_steps=None, ): super().__init__() """Creates a CudaGraphManager to manage CUDA graphs for a Megatron module. @@ -1424,7 +1431,13 @@ def __init__( Args: config: TransformerConfig object containing CUDA graph settings for memory pooling, graph retention, gradient accumulation, FP8/FP4, and warmup steps. + inline_capture: Normally, whether the inline capture path is taken depends on whether + `inference_context` is present in the kwargs of the forward call. + Setting this argument to True always forces the inline capture path to be taken. + num_warmup_steps: If set, overrides the per-runner warmup step count. """ + self._inline_capture = inline_capture + self._num_warmup_steps = num_warmup_steps if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection @@ -1434,8 +1447,13 @@ def __init__( if function_name is not None: func = getattr(base_module, function_name) - def wrapped_func(*args, **kwargs): - out = self(base_module, args, kwargs) + def wrapped_func(*args, eager=False, cache_key=None, **kwargs): + if eager: + return func(*args, **kwargs) + out = self(base_module, args, kwargs, cache_key=cache_key) + # Unwrap single-element tuple to match the original function's return type. + if isinstance(out, tuple) and len(out) == 1: + return out[0] return out setattr(base_module, function_name, wrapped_func) @@ -1467,7 +1485,7 @@ def wrapped_func(*args, **kwargs): ) self.cudagraph_runners: list[_CudaGraphRunner] = [] - self.inference_cudagraphs_lookup_table: dict = defaultdict(lambda: None) + self.custom_cudagraphs_lookup_table: dict = defaultdict(lambda: None) self.is_first_microbatch = False # Without pipeline parallelism, microbatches execute one at a time. @@ -1495,7 +1513,7 @@ def call_ddp_preforward_hook(self, module): # Only hooks from Mcore DDP, which take no args, should be called at this point. hook(module) - def get_cudagraph_runner(self, megatron_module, args, kwargs, reuse_cudagraphs): + def get_cudagraph_runner(self, megatron_module, args, kwargs, reuse_cudagraphs, cache_key=None): '''Returns a valid cudagraph runner for the current forward call. The cudagraph corresponding to this call is the first element of 'self.cudagraph_runners'. We iterate through the list by 1 for each call, and the number of calls is equal to the @@ -1503,16 +1521,8 @@ def get_cudagraph_runner(self, megatron_module, args, kwargs, reuse_cudagraphs): Otherwise, we assign a mempool per microbatch, which allows cudagraphs to be reused over different microbatches by tracking their respective fwd and bwd passes.''' if reuse_cudagraphs: - is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] - if is_inference_mode: - is_static_batching = kwargs['inference_context'].is_static_batching() - if is_static_batching: - batch_size = kwargs['hidden_states'].shape[0] - is_decode_only = kwargs["inference_context"].is_decode_only() - runner = self.inference_cudagraphs_lookup_table[(batch_size, is_decode_only)] - else: - padded_batch_dimensions = kwargs['inference_context'].padded_batch_dimensions - runner = self.inference_cudagraphs_lookup_table[padded_batch_dimensions] + if cache_key is not None: + runner = self.custom_cudagraphs_lookup_table[cache_key] else: # Todo: For training, we could also cache runners based on input shape. 
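The `cache_key` mechanism above generalizes the old inference-only lookup: any hashable key indexes `custom_cudagraphs_lookup_table`, a `defaultdict(lambda: None)`, so a miss simply yields `None` and a fresh runner is captured and cached under that key. A minimal sketch of the lookup pattern, with `_FakeRunner` standing in for `_CudaGraphRunner` purely for illustration:

```python
# Sketch of the cache_key-based runner reuse in get_cudagraph_runner:
# defaultdict(lambda: None) gives O(1) lookup with a None sentinel on a miss.
from collections import defaultdict


class _FakeRunner:  # stand-in for _CudaGraphRunner
    def __init__(self, key):
        self.key = key


runners = defaultdict(lambda: None)


def get_runner(cache_key):
    runner = runners[cache_key]
    if runner is None:
        runner = _FakeRunner(cache_key)  # a real manager would capture a new graph here
        runners[cache_key] = runner
    return runner


decode_runner = get_runner((8, True))            # e.g. (batch_size, is_decode_only)
assert get_runner((8, True)) is decode_runner    # same key → replay the existing graph
assert get_runner((8, False)) is not decode_runner
```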
# If autograd is currently disabled, it doesnt matter if a runner was created @@ -1546,15 +1556,11 @@ def is_valid(r): self.func, self.need_backward, ) + if self._num_warmup_steps is not None: + runner.num_warmup_steps = self._num_warmup_steps self.cudagraph_runners.append(runner) - if is_inference_mode: - # Cache the newly created runner in the inference lookup table. - if is_static_batching: - self.inference_cudagraphs_lookup_table[(batch_size, is_decode_only)] = ( - runner - ) - else: - self.inference_cudagraphs_lookup_table[padded_batch_dimensions] = runner + if cache_key is not None: + self.custom_cudagraphs_lookup_table[cache_key] = runner else: # Create cudagraphs for every microbatch if _CudagraphGlobalRecord.cudagraph_created: @@ -1574,7 +1580,7 @@ def is_valid(r): return runner - def __call__(self, megatron_module, args, kwargs): + def __call__(self, megatron_module, args, kwargs, cache_key=None): """Calls the forward pass of the cudagraphed module. Args: @@ -1583,8 +1589,18 @@ def __call__(self, megatron_module, args, kwargs): args (tuple): The positional args to be passed to the module. kwargs (dict): The keyword args to be passed to the module. + + cache_key: Optional hashable key for O(1) runner lookup. + If `inference_context` is provided, this gets set to the correct value. """ is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] + if cache_key is None and is_inference_mode: + inference_context = kwargs['inference_context'] + if inference_context.is_static_batching(): + batch_size = kwargs['hidden_states'].shape[0] + cache_key = (batch_size, inference_context.is_decode_only()) + else: + cache_key = inference_context.padded_batch_dimensions is_in_checkpoint_fwd = is_checkpointing() if HAVE_TE_GRAPHS: is_in_checkpoint_fwd = is_in_checkpoint_fwd or is_fp8_activation_recompute_enabled() @@ -1596,24 +1612,32 @@ def __call__(self, megatron_module, args, kwargs): for module in megatron_module.modules(): self.call_ddp_preforward_hook(module) - runner = self.get_cudagraph_runner(megatron_module, args, kwargs, self.reuse_cudagraphs) + runner = self.get_cudagraph_runner( + megatron_module, args, kwargs, self.reuse_cudagraphs, cache_key=cache_key + ) out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs) else: - if is_inference_mode: + if is_inference_mode or self._inline_capture: # Inference generation mode creates graphs immediately - runner = self.get_cudagraph_runner(megatron_module, args, kwargs, True) + runner = self.get_cudagraph_runner( + megatron_module, args, kwargs, True, cache_key=cache_key + ) if not runner.fwd_graph_recorded: # Reuse graph input-output buffers for inference local_args, local_kwargs = args, kwargs if not runner.is_first_layer: - # Find previous layer's runner in the global record + # Find previous layer's runner in the global record. + # Method-wrapped managers (e.g. the MTP wrapper around + # `compute_mtp_single_step`) have a base_module without + # `layer_number`; `getattr(..., None)` makes those rows + # harmlessly skipped by the predicate. 
try: previous_runner = next( r for r in _CudagraphGlobalRecord.cudagraph_inference_record if ( - r[0].base_module.layer_number + getattr(r[0].base_module, 'layer_number', None) == runner.base_module.layer_number - 1 and r[0].fwd_graph is not None and ArgMetadata(r[3]['hidden_states']) @@ -1645,10 +1669,6 @@ def __call__(self, megatron_module, args, kwargs): runner = self.get_cudagraph_runner( megatron_module, args, kwargs, self.reuse_cudagraphs ) - # check if a layer is frozen during training. - if not torch.is_grad_enabled(): - # If the layer is frozen, we need to set the runner to eval mode. - runner.eval() out = runner.record_graph_capture(args, kwargs) else: # No cudagraphs were found in training mode with grad disabled, so fallback to diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 99d8fd97dd9..ed288c3e1f4 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -1,8 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations -import gc -import logging import warnings from collections.abc import Callable from dataclasses import dataclass @@ -27,7 +25,7 @@ from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.transformer.utils import cat_with_oom_fallback, sharded_state_dict_default from megatron.core.typed_torch import apply_module, not_none from megatron.core.utils import ( get_tensor_model_parallel_group_if_none, @@ -43,9 +41,6 @@ HAVE_TE = False -logger = logging.getLogger(__name__) - - class LinearFc1Interface(Protocol): """Interface for linear_fc1 module in MLP.""" @@ -430,25 +425,11 @@ def sh_ten_build_fn( ), ] - def sh_ten_merge_fn(sub_state_dict): - with torch.no_grad(): - try: - return torch.cat(sub_state_dict) - except (RuntimeError, torch.cuda.OutOfMemoryError) as e: - logger.warning( - f"CUDA OutOfMemoryError encountered during tensors merging." - f" Switching to CPU merge. 
(Error: {e})" - ) - merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict]) - gc.collect() - torch.cuda.empty_cache() - return merged_sub_state_dict - return ShardedTensorFactory( original_sh_ten.key, original_sh_ten.data, sh_ten_build_fn, - sh_ten_merge_fn, + cat_with_oom_fallback, original_sh_ten.replica_id, flattened_range=original_sh_ten.flattened_range, ) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 384a26c0deb..355b84e8150 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -49,6 +49,10 @@ paged_stash_group_commit, paged_stash_group_start, ) +from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceAllGatherDispatcherBase, + NVLSAllGatherVDispatcher, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( ensure_metadata_has_dp_cp_group, @@ -74,11 +78,7 @@ HAVE_FLASHINFER = False from megatron.core.inference.moe import ActivationType as McoreActivationType -from megatron.core.inference.moe import ( - InferenceGroupedGemmBackend, - mcore_fused_moe, - resolve_inference_grouped_gemm_backend, -) +from megatron.core.inference.moe import InferenceGroupedGemmBackend, mcore_fused_moe, vllm_fused_moe logger = logging.getLogger(__name__) @@ -1373,13 +1373,12 @@ def __init__( # checkpoint loading has already populated the per-expert parameters. self._concatenated_weights_built = False - self.is_inference_cuda_graphed_iteration = False - if HAVE_FLASHINFER: self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() self._mcore_activation_type = self._resolve_mcore_activation_type() self.inference_grouped_gemm_backend = config.inference_grouped_gemm_backend + self._nvls_dispatcher = config.inference_moe_token_dispatcher_type == 'nvls' def _resolve_flashinfer_activation_type(self): """Map megatron activation config to FlashInfer ActivationType.""" @@ -1404,14 +1403,6 @@ def _resolve_mcore_activation_type(self): return McoreActivationType.SQUARED_RELU raise ValueError(f"No mcore_fused_moe ActivationType mapping for activation_func={func}") - def set_inference_cuda_graphed_iteration(self): - """Enable CUDA-graphed iteration mode.""" - self.is_inference_cuda_graphed_iteration = True - - def unset_inference_cuda_graphed_iteration(self): - """Disable CUDA-graphed iteration mode.""" - self.is_inference_cuda_graphed_iteration = False - def _build_concatenated_mxfp8_weights(self): """Build stacked MXFP8 weight tensors from per-expert MXFP8Tensor attributes. 
@@ -1521,12 +1512,11 @@ def _flashinfer_forward(self, hidden_states, routing_map, probs): activation_type=self._flashinfer_activation_type, ep_size=self.ep_group.size(), ep_rank=self.ep_group.rank(), + output=NVLSAllGatherVDispatcher._get_rsv_tensor() if self._nvls_dispatcher else None, )[0] return output, None - def _mcore_fused_moe_forward( - self, hidden_states, probs, routing_map=None, tokens_per_expert=None, skip_permute=False - ): + def _mcore_fused_moe_forward(self, hidden_states, probs, routing_map): """Torch grouped_mm fused MoE forward via mcore_fused_moe.""" local_expert_start = self.ep_group.rank() * self.num_local_experts output = mcore_fused_moe( @@ -1537,10 +1527,28 @@ def _mcore_fused_moe_forward( activation_type=self._mcore_activation_type, num_local_experts=self.num_local_experts, local_expert_start=local_expert_start, + valid_tokens=InferenceAllGatherDispatcherBase._valid_tokens(), routing_map=routing_map, - tokens_per_expert=tokens_per_expert, - skip_permute=skip_permute, disable_fused_quant_kernels=self.config.inference_moe_disable_fused_quant_kernels, + out=NVLSAllGatherVDispatcher._get_rsv_tensor() if self._nvls_dispatcher else None, + ) + return output, None + + def _vllm_forward(self, hidden_states, probs, routing_map): + """vLLM Triton fused MoE kernel forward (BF16, CUDA-graph safe).""" + local_expert_start = self.ep_group.rank() * self.num_local_experts + output = vllm_fused_moe( + hidden_states, + probs, + self._fc1_weight, + self._fc2_weight, + activation_type=self._mcore_activation_type, + num_local_experts=self.num_local_experts, + local_expert_start=local_expert_start, + valid_tokens=InferenceAllGatherDispatcherBase._valid_tokens(), + routing_map=routing_map, + out=NVLSAllGatherVDispatcher._get_rsv_tensor() if self._nvls_dispatcher else None, + num_tokens_hint=InferenceAllGatherDispatcherBase._get_host_valid_tokens_estimate(), ) return output, None @@ -1585,30 +1593,20 @@ def forward( self._build_concatenated_weights() self._concatenated_weights_built = True - resolved_backend = resolve_inference_grouped_gemm_backend( - self.inference_grouped_gemm_backend, - self.is_inference_cuda_graphed_iteration, - is_mxfp8=self.config.fp8_recipe == "mxfp8", - ) - - if resolved_backend == InferenceGroupedGemmBackend.FLASHINFER: + if self.inference_grouped_gemm_backend == InferenceGroupedGemmBackend.FLASHINFER: assert routing_map is not None, "routing_map is required for FlashInfer forward pass." - assert ( - self.is_inference_cuda_graphed_iteration - ), "FlashInfer forward path is only used in CUDA-graphed inference iterations." + assert not self.training, "FlashInfer forward path is only used in inference mode." 
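The `valid_tokens=InferenceAllGatherDispatcherBase._valid_tokens()` argument threaded into the fused MoE calls above refers to a device-resident counter that the dispatchers rewrite in place each step, so captured kernels read the current count from a stable address at graph replay. A minimal sketch of that pattern, assuming an illustrative `gate_to_valid_prefix` consumer (not a function in this PR) and synthetic shapes:

```python
# Sketch of the stable-address metadata pattern: allocate the counter once
# outside capture, update it in place per step, and build masks from it on
# device so no host sync is needed inside the graph.
import torch


def gate_to_valid_prefix(expert_out: torch.Tensor, valid_tokens: torch.Tensor) -> torch.Tensor:
    # Zero rows beyond the valid prefix using the device-resident counter.
    token_ids = torch.arange(expert_out.shape[0], device=expert_out.device)
    mask = (token_ids < valid_tokens).unsqueeze(-1)
    return expert_out * mask


if torch.cuda.is_available():
    valid_tokens = torch.zeros(1, dtype=torch.int32, device="cuda")  # allocated once
    out = torch.randn(16, 8, device="cuda")
    valid_tokens.fill_(4)                 # per-step in-place update, graph-safe
    gated = gate_to_valid_prefix(out, valid_tokens)
    assert torch.equal(gated[4:], torch.zeros_like(gated[4:]))
```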
return self._flashinfer_forward( permuted_local_hidden_states, routing_map, permuted_probs ) - elif resolved_backend == InferenceGroupedGemmBackend.TORCH: + elif self.inference_grouped_gemm_backend == InferenceGroupedGemmBackend.TORCH: return self._mcore_fused_moe_forward( - permuted_local_hidden_states, - permuted_probs, - routing_map=routing_map, - tokens_per_expert=tokens_per_expert, - skip_permute=(not self.is_inference_cuda_graphed_iteration), + permuted_local_hidden_states, permuted_probs, routing_map=routing_map + ) + elif self.inference_grouped_gemm_backend == InferenceGroupedGemmBackend.VLLM: + return self._vllm_forward( + permuted_local_hidden_states, permuted_probs, routing_map=routing_map ) - elif resolved_backend == InferenceGroupedGemmBackend.TE: - return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) class SequentialMLP(MegatronModule): diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 11a4bd1a8b2..ea030037762 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -29,7 +29,8 @@ MoETokenDispatcher, ) from megatron.core.transformer.moe.token_dispatcher_inference import ( - InferenceCUDAGraphTokenDispatcher, + NCCLAllGatherDispatcher, + NVLSAllGatherVDispatcher, ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module, not_none @@ -51,6 +52,13 @@ except ImportError: HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = False +try: + import triton # pylint: disable=unused-import + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + if HAVE_TE: from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint else: @@ -342,9 +350,16 @@ def __init__( check_flashinfer_jit_cache_installed() elif config.inference_grouped_gemm_backend == 'torch': - assert hasattr(torch.nn.functional, 'grouped_mm'), ( + assert hasattr(torch.nn.functional, 'grouped_mm') or hasattr( + torch, '_grouped_mm' + ), ( "inference_grouped_gemm_backend='torch' requires " - "torch.nn.functional.grouped_mm (available since PyTorch 2.10)." + "torch.nn.functional.grouped_mm (> torch 2.10) or torch._grouped_mm (<= 2.10)." + ) + elif config.inference_grouped_gemm_backend == 'vllm': + assert HAVE_TRITON, ( + "inference_grouped_gemm_backend='vllm' requires Triton. " + "Install triton (pip install triton)." ) self._setup_inference_mode(pg_collection) @@ -361,25 +376,51 @@ def __init__( self.setup_delayed_wgrad_for_dispatch_backward_overlap() def _setup_inference_mode(self, pg_collection): - """Set up inference-optimized token dispatcher and state. + """Set up inference-optimized token dispatcher. Called from __init__ when config.transformer_impl == "inference_optimized". - Creates an InferenceCUDAGraphTokenDispatcher alongside the standard dispatcher, - which is swapped in during CUDA-graphed forward passes. + Stores the training dispatcher and creates the inference dispatcher selected + by config.inference_moe_token_dispatcher_type ('nccl' or 'nvls'). + The active dispatcher is swapped automatically via the train() override: + eval mode → inference dispatcher, train mode → standard dispatcher. 
""" - - assert self.config.moe_token_dispatcher_type == "alltoall", ( - f"Inference-optimized MoE requires 'alltoall' dispatcher, " - f"got '{self.config.moe_token_dispatcher_type}'" + dispatcher_type = self.config.inference_moe_token_dispatcher_type + dispatcher_cls = ( + NVLSAllGatherVDispatcher if dispatcher_type == 'nvls' else NCCLAllGatherDispatcher ) - self.is_inference_cuda_graphed_iteration = False - self._inference_token_dispatcher = InferenceCUDAGraphTokenDispatcher( + + self._training_token_dispatcher = self.token_dispatcher + self._inference_token_dispatcher = dispatcher_cls( self.num_local_experts, self.local_expert_indices, config=self.config, pg_collection=pg_collection, ) + # Wire shared-expert overlap into the inference dispatcher (NVLS only). + # The dispatcher launches the shared-expert forward on SharedExpertMLP.stream + # concurrently with AGV+experts+RSV and adds it back in combine_postprocess. + if ( + dispatcher_type == 'nvls' + and self.use_shared_expert + and self.config.moe_shared_expert_overlap + ): + self._inference_token_dispatcher.set_shared_experts(self.shared_experts) + + def train(self, mode: bool = True): + """Swap token dispatcher when switching between train and eval modes.""" + super().train(mode) + if hasattr(self, "_inference_token_dispatcher"): + if mode: + self.token_dispatcher = self._training_token_dispatcher + self.shared_expert_overlap = self.config.moe_shared_expert_overlap + else: + self.token_dispatcher = self._inference_token_dispatcher + self.shared_expert_overlap = ( + self._inference_token_dispatcher.shared_experts is not None + ) + return self + def setup_delayed_wgrad_for_dispatch_backward_overlap(self): """Initializes CUDA events and streams for overlapping expert weight gradient computation with dispatch backward. @@ -390,39 +431,6 @@ def setup_delayed_wgrad_for_dispatch_backward_overlap(self): self._delayed_wgrad_event = torch.cuda.Event() self._delayed_wgrad_stream = torch.cuda.Stream(device="cuda") - def set_inference_cuda_graphed_iteration(self): - """Enable CUDA-graphed iteration mode on this layer, its router, and its experts. - - Swaps in the inference-optimized token dispatcher and disables - shared expert overlap. - """ - self.is_inference_cuda_graphed_iteration = True - if hasattr(self.router, "set_inference_cuda_graphed_iteration"): - self.router.set_inference_cuda_graphed_iteration() - if hasattr(self.experts, "set_inference_cuda_graphed_iteration"): - self.experts.set_inference_cuda_graphed_iteration() - - if self._inference_token_dispatcher is not None: - self._saved_token_dispatcher = self.token_dispatcher - self.token_dispatcher = self._inference_token_dispatcher - self._saved_shared_expert_overlap = self.shared_expert_overlap - self.shared_expert_overlap = False - - def unset_inference_cuda_graphed_iteration(self): - """Disable CUDA-graphed iteration mode on this layer, its router, and its experts. - - Restores the standard token dispatcher and shared expert overlap setting. 
- """ - self.is_inference_cuda_graphed_iteration = False - if hasattr(self.router, "unset_inference_cuda_graphed_iteration"): - self.router.unset_inference_cuda_graphed_iteration() - if hasattr(self.experts, "unset_inference_cuda_graphed_iteration"): - self.experts.unset_inference_cuda_graphed_iteration() - - if hasattr(self, "_saved_token_dispatcher"): - self.token_dispatcher = self._saved_token_dispatcher - self.shared_expert_overlap = self._saved_shared_expert_overlap - @maybe_skip_or_early_return_by_cudagraph("route") def route( self, @@ -560,6 +568,7 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso if ( hasattr(self, "_inference_token_dispatcher") and self.is_inference_cuda_graphed_iteration + and not self.training ): routing_map = self.token_dispatcher.routing_map expert_output, mlp_bias = apply_module(self.experts)( diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index d316d23de10..ae0d77546d1 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -775,7 +775,8 @@ def _compute_topk( group_topk=group_topk, ) else: - return torch.topk(scores, k=topk, dim=1) + # Sorting top-k turned off during inference + return torch.topk(scores, k=topk, dim=1, sorted=torch.is_grad_enabled()) def compute_topk(scores, topk, num_groups=None, group_topk=None): # Default behavior if no replay is active diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index cdf968c6b12..131f58b0fd2 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -826,6 +826,7 @@ def __init__( config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, is_mtp_layer: bool = False, + layer_number: Optional[int] = None, ) -> None: """Initialize the specialized inference top-k router. @@ -843,17 +844,12 @@ def __init__( f"['sigmoid', 'softmax'], got '{config.moe_router_score_function}'" ) - super().__init__(config=config, pg_collection=pg_collection) - - self.is_inference_cuda_graphed_iteration = False - - def set_inference_cuda_graphed_iteration(self): - """Enable CUDA graph-compatible operations for the router.""" - self.is_inference_cuda_graphed_iteration = True - - def unset_inference_cuda_graphed_iteration(self): - """Disable CUDA graph-compatible operations for the router.""" - self.is_inference_cuda_graphed_iteration = False + super().__init__( + config=config, + pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, + layer_number=layer_number, + ) @staticmethod @torch.compile @@ -915,7 +911,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No - top_indices: Selected expert indices [num_tokens, topk] """ - if self.training or not self.is_inference_cuda_graphed_iteration: + if self.training: return super().forward(input, padding_mask) return self._forward(input, padding_mask) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index b3b7b06b1b8..fee92bbb6d3 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -1,237 +1,190 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. """ -CUDA-graph-compatible token dispatcher for inference. +Inference token dispatchers for MoE expert parallelism. 
-This dispatcher is only used during CUDA-graphed inference iterations. It replaces -AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata -GPU-resident to avoid host synchronizations that would break CUDA graph capture. +Two dispatchers are provided, selected via config.inference_moe_token_dispatcher_type: -Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) -on Hopper+ GPUs with BF16, with automatic fallback to NCCL. + NCCLAllGatherDispatcher ('nccl', default) + Standard NCCL AllGather/ReduceScatter. All EP ranks must contribute the same + token count per step; decode-only CUDA graphs are forced automatically. + + NVLSAllGatherVDispatcher ('nvls') + Variable-count NVLS AllGather-V/ReduceScatter-V via multimem kernels. Supports + different token counts per rank per step. Requires Hopper+ GPUs with NVLink and + symmetric memory. Opt-in. + +InferenceAllGatherDispatcherBase is a minimal base used solely for isinstance checks +and to hold _valid_tokens_tensor — the shared interface that mcore_fused_moe reads to +gate kernel work to the valid token prefix. Each dispatcher defines its own +update_metadata method, invoked from the first instance's token_dispatch so the +per-step metadata kernel is captured inside the CUDA graph. """ +import operator +from functools import reduce from typing import List, Optional import torch +import torch.distributed as dist from megatron.core.inference.communication.torch_symm_triton import ( - are_tensors_nvls_eligible, - multimem_all_gather_fused, - multimem_reduce_scatter, + multimem_all_gatherv_3tensor, + multimem_reduce_scatter_v, ) +from megatron.core.inference.moe import InferenceGroupedGemmBackend +from megatron.core.inference.moe.metadata import fused_metadata_update from megatron.core.inference.symmetric_memory import SymmetricMemoryManager from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import ( gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, ) +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.moe.token_dispatcher import MoEAllGatherTokenDispatcher from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.typed_torch import apply_module +from megatron.core.utils import get_pg_rank, get_pg_size -class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): +class InferenceAllGatherDispatcherBase(MoEAllGatherTokenDispatcher): + """Minimal base for inference AllGather token dispatchers. + + Exists for isinstance checks and to expose _valid_tokens_tensor — the single + class-level value that mcore_fused_moe reads (via experts.py) to gate kernel + work to the valid token prefix. Each concrete subclass owns its own metadata + and defines update_metadata independently. """ - CUDA-graph-compatible AllGather token dispatcher for inference. - Only used during CUDA-graphed inference iterations. Swapped in by - MoELayer.set_inference_cuda_graphed_iteration() before graph capture - and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after. + # [1] int32: total valid tokens across all EP ranks this step. + # Written in-place each step so CUDA graph replay sees a stable address. + # NVLSAllGatherVDispatcher points this at _step_metadata[0:1] on first init + # so that experts.py can always call _valid_tokens() on this base class. 
+ _valid_tokens_tensor: Optional[torch.Tensor] = None + + # Host-side estimate of the total valid token count across all EP ranks. + # Computed as local_tokens * ep_size to avoid a device-to-host sync (which + # would break CUDA graph capture). This may differ from _valid_tokens_tensor + # when ranks have unequal token counts. + _host_valid_tokens_estimate: Optional[int] = None + + def __init__(self, *args, runs_metadata_sync: bool = True, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._runs_metadata_sync = runs_metadata_sync + + @classmethod + def _valid_tokens(cls) -> torch.Tensor: + return cls._valid_tokens_tensor + + @classmethod + def _get_host_valid_tokens_estimate(cls) -> Optional[int]: + return cls._host_valid_tokens_estimate + + def update_metadata(self, local_tokens: int) -> None: + """Per-step metadata refresh fired from the first instance's token_dispatch. - Key features: - - AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility - - GPU-resident metadata (no host synchronization) - - NVLS collectives on Hopper+ with automatic NCCL fallback + Must be idempotent across a step (only called once) and safe to capture + into a CUDA graph on the decode path. + """ + raise NotImplementedError + + +class NCCLAllGatherDispatcher(InferenceAllGatherDispatcherBase): + """AllGather token dispatcher for inference using NCCL. + + Two modes, selected by _use_allgather_v (set from the context each step): + + CG path (use_allgather_v=False): all EP ranks contribute the same token count, + guaranteed by decode-only CUDA graphs. Standard AllGather/ReduceScatter. + + Non-CG path (use_allgather_v=True): ranks may have different token counts + (prefill). Each rank pads its tensors to max_tokens, runs a standard AllGather, + then compacts by stripping per-rank padding. Combine is the reverse: expand + compact output to padded layout, ReduceScatter, truncate to local token count. """ + _use_allgather_v: bool = False + _local_tokens_per_rank: Optional[List[int]] = None + def __init__( self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, + runs_metadata_sync: bool = True, ) -> None: - """ - Initialize the InferenceCUDAGraphTokenDispatcher. - - Args: - num_local_experts: Number of experts on this rank. - local_expert_indices: Global indices of experts on this rank. - config: Transformer configuration. - pg_collection: Process group collection for distributed ops. - """ super().__init__( num_local_experts=num_local_experts, local_expert_indices=local_expert_indices, config=config, pg_collection=pg_collection, + runs_metadata_sync=runs_metadata_sync, ) self.topk = config.moe_router_topk - self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels - - def _maybe_allocate_ag_buffers( - self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor - ) -> dict: - """Allocate a single symmetric memory output buffer for fused all-gather. + @classmethod + def allocate_buffers(cls) -> None: + """Allocate the per-step valid-tokens tensor read by mcore_fused_moe. - Creates one contiguous symmetric memory buffer sized for the gathered - (global) routing_map, probs, and hidden_states, then returns sliced views - into it. This allows a single fused NVLS all-gather kernel to write all - three outputs in one launch. - - Args: - routing_map (torch.Tensor): Local routing map, shape [local_tokens, topk]. 
- Boolean or integer tensor mapping each token to its selected experts. - probs (torch.Tensor): Local routing probabilities, shape [local_tokens, topk]. - Normalized weights for each token's selected experts. - hidden_states (torch.Tensor): Local hidden states, shape [local_tokens, hidden_dim]. - - Returns: - dict: A dictionary with the following keys: - - "handle": Symmetric memory handle for NVLS ops, or None if - symmetric memory is unavailable. - - "routing_map": Raw byte view for the gathered routing map output. - - "routing_map_offset": Byte offset of routing_map within the buffer. - - "probs": Raw byte view for the gathered probs output. - - "probs_offset": Byte offset of probs within the buffer. - - "hidden_states": Raw byte view for the gathered hidden states output. - - "hidden_states_offset": Byte offset of hidden_states within the buffer. - When allocation fails, all tensor views are None and offsets are 0. + Called once at model init from the dynamic context. Must run outside any + CUDA graph capture so update_metadata can write to a stable address during + replay without triggering allocations inside the graph. """ - _NONE = { - "handle": None, - "routing_map": None, - "routing_map_offset": 0, - "probs": None, - "probs_offset": 0, - "hidden_states": None, - "hidden_states_offset": 0, - } - - local_tokens = probs.size(0) - global_tokens = local_tokens * self.ep_size - topk = probs.size(-1) - hidden_dim = hidden_states.size(-1) - - result = SymmetricMemoryManager.get_buffer( - "ep", process_group=self.ep_group - ).maybe_get_tensors( - [ - (global_tokens * topk, routing_map.dtype), - (global_tokens * topk, probs.dtype), - (global_tokens * hidden_dim, hidden_states.dtype), - ] + device = torch.cuda.current_device() + InferenceAllGatherDispatcherBase._valid_tokens_tensor = torch.zeros( + 1, dtype=torch.int32, device=device ) - if result["handle"] is None: - return _NONE - - (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] - return { - "handle": result["handle"], - "routing_map": rm_buf, - "routing_map_offset": rm_off, - "probs": p_buf, - "probs_offset": p_off, - "hidden_states": hs_buf, - "hidden_states_offset": hs_off, - } + def update_metadata(self, local_tokens: int) -> None: + """Per-step metadata update; invoked from the first instance's token_dispatch. - def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: - """Allocate a symmetric memory buffer for reduce-scatter input. + CG path (_use_allgather_v=False): ranks have equal counts by construction, so + we only refresh _valid_tokens_tensor — a single .fill_ that is safe to capture. - The buffer has the same shape and dtype as x so that x can be copied - into it before the NVLS reduce-scatter kernel. - - Args: - x (torch.Tensor): The global hidden states to be reduce-scattered, - shape [global_tokens, hidden_dim]. - - Returns: - dict: A dictionary with keys "handle" (symmetric memory handle, or - None if unavailable) and "tensor" (the allocated buffer, or None). + Non-CG path (_use_allgather_v=True): ranks may differ, so we all-gather the + per-rank counts and host-sync via .tolist() for the pad/compact logic below. + This path never runs under graph capture. 
""" - symm_mem_buffer = SymmetricMemoryManager.get_buffer( - "ep", process_group=self.ep_group - ).maybe_get_tensor(list(x.size()), dtype=x.dtype) - return symm_mem_buffer + cls = NCCLAllGatherDispatcher + ep_size = self.ep_size + device = torch.cuda.current_device() + + if cls._use_allgather_v: + local_count = torch.tensor([local_tokens], dtype=torch.int32, device=device) + local_tokens_per_rank = torch.empty(ep_size, dtype=torch.int32, device=device) + dist.all_gather_into_tensor(local_tokens_per_rank, local_count, group=self.ep_group) + cls._local_tokens_per_rank = local_tokens_per_rank.tolist() + total = local_tokens_per_rank.sum() + InferenceAllGatherDispatcherBase._valid_tokens_tensor.copy_(total) + InferenceAllGatherDispatcherBase._host_valid_tokens_estimate = int(total.item()) + else: + total = ep_size * local_tokens + InferenceAllGatherDispatcherBase._valid_tokens_tensor.fill_(total) + InferenceAllGatherDispatcherBase._host_valid_tokens_estimate = total def token_dispatch(self, hidden_states, probs): - """Gathers tokens from all EP ranks using AllGather. + """Gather hidden_states, probs, and routing_map from all EP ranks. - Performs all-gather on routing_map (stored in self.routing_map), probs, - and hidden_states so that every rank holds the full global view. - Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs - with BF16 when symmetric memory is available. Falls back to NCCL otherwise. + CG path: standard AllGather (equal token counts guaranteed). + Non-CG path: pad to max_tokens, AllGather, compact (strip per-rank padding). Args: - hidden_states (torch.Tensor): Local hidden states, - shape [local_tokens, hidden_dim]. - probs (torch.Tensor): Local routing probabilities, - shape [local_tokens, topk]. Normalized weights for each token's - selected experts. + hidden_states: [local_tokens, hidden_dim] local input. + probs: [local_tokens, topk] local routing probabilities. Returns: - tuple: (hidden_states, probs) gathered across all EP ranks. - - hidden_states (torch.Tensor): Shape [global_tokens, hidden_dim]. - - probs (torch.Tensor): Shape [global_tokens, topk]. - Also updates self.routing_map in-place to the gathered - shape [global_tokens, topk]. + (hidden_states, probs) gathered to [total_tokens, *] shape. + Also updates self.routing_map to [total_tokens, topk]. """ if self.ep_size == 1: return hidden_states, probs - # 1. Check inputs only: if inputs are 16-byte divisible, - # outputs (world_size * input) are too. - nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible( - hidden_states, probs, self.routing_map - ) - ag_buffers = None - - if nvls_eligible: - # 2. Now attempt to allocate symmetric memory buffers for - # all-gather outputs. If allocation fails, fallback to NCCL. - ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) - - # 3. 
Can use NVLS if eligible and buffers allocated successfully (handle is not None) - can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None - - if can_use_nvls: - # Capture shapes for reshaping after all-gather - # Output shape: [local_tokens * ep_size, dim] - local_tokens = probs.size(0) - global_tokens = local_tokens * self.ep_size - topk = probs.size(1) - hidden_dim = hidden_states.size(1) - routing_map_dtype = self.routing_map.dtype - probs_dtype = probs.dtype - hidden_dtype = hidden_states.dtype - - # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors - multimem_all_gather_fused( - ag_buffers["routing_map"].view( - torch.bfloat16 - ), # .view does not change the underlying data - self.routing_map.view(torch.bfloat16), - ag_buffers["routing_map_offset"], - ag_buffers["probs"].view(torch.bfloat16), - probs.view(torch.bfloat16), - ag_buffers["probs_offset"], - ag_buffers["hidden_states"].view(torch.bfloat16), - hidden_states.view(torch.bfloat16), - ag_buffers["hidden_states_offset"], - ag_buffers["handle"], - ) - self.routing_map = ( - ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) - ) - probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) - hidden_states = ( - ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) - ) - else: - # Fallback to NCCL for all tensors + if self._runs_metadata_sync: + self.update_metadata(hidden_states.shape[0]) + + if not self.__class__._use_allgather_v: + # CG path: equal token counts, standard gather. with torch.no_grad(): self.routing_map = gather_from_sequence_parallel_region( self.routing_map, group=self.tp_ep_group @@ -240,93 +193,398 @@ def token_dispatch(self, hidden_states, probs): hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.tp_ep_group ) + return hidden_states, probs + + # Non-CG path: pad → AllGather → compact. + tokens_per_rank = self.__class__._local_tokens_per_rank + max_tokens = max(tokens_per_rank) + + def pad_to_max(tensor): + deficit = max_tokens - tensor.shape[0] + if deficit == 0: + return tensor + return torch.cat([tensor, tensor.new_empty((deficit,) + tensor.shape[1:])], dim=0) + + def allgather(padded_tensor): + gathered = padded_tensor.new_empty( + (self.ep_size * max_tokens,) + padded_tensor.shape[1:] + ) + dist.all_gather_into_tensor(gathered, padded_tensor, group=self.ep_group) + return gathered + + hidden_gathered = allgather(pad_to_max(hidden_states)) + probs_gathered = allgather(pad_to_max(probs)) + with torch.no_grad(): + routing_gathered = allgather(pad_to_max(self.routing_map)) + + def compact(gathered_tensor): + return torch.cat( + [ + gathered_tensor[src_rank * max_tokens : src_rank * max_tokens + n_tokens] + for src_rank, n_tokens in enumerate(tokens_per_rank) + ], + dim=0, + ) + hidden_states = compact(hidden_gathered) + probs = compact(probs_gathered) + self.routing_map = compact(routing_gathered) return hidden_states, probs def dispatch_postprocess(self, hidden_states, probs): - """Pass-through: returns inputs directly without permutation. + """Pass-through: mcore_fused_moe operates directly on the gathered tensors.""" + return hidden_states, None, probs - Unlike the training dispatcher, this does not permute tokens or compute - tokens_per_expert. The downstream InferenceGroupedMLP (FlashInfer / - CUTLASS fused MoE kernel) operates directly on the routing map stored - in self.routing_map. 
+ def combine_preprocess(self, expert_output): + """Pass-through: unpermute is handled inside mcore_fused_moe.""" + return expert_output + + def token_combine(self, hidden_states): + """Scatter-reduce expert outputs back to each EP rank. + + CG path: standard ReduceScatter (equal token counts guaranteed). + Non-CG path: expand compact output to padded layout, ReduceScatter, truncate. Args: - hidden_states (torch.Tensor): Gathered hidden states, - shape [global_tokens, hidden_dim]. - probs (torch.Tensor): Gathered routing probabilities, - shape [global_tokens, topk]. + hidden_states: [total_tokens, hidden_dim] expert outputs. Returns: - tuple: (hidden_states, tokens_per_expert, probs) where - tokens_per_expert is always None. + [local_tokens, hidden_dim] bf16 local token outputs. """ - return hidden_states, None, probs + if self.ep_size == 1: + return hidden_states.to(torch.bfloat16) + + if not self.__class__._use_allgather_v: + # CG path: equal token counts, standard reduce-scatter. + hidden_states = reduce_scatter_to_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + return hidden_states.to(torch.bfloat16) + + # Non-CG path: expand compact → padded, ReduceScatter, truncate. + tokens_per_rank = self.__class__._local_tokens_per_rank + max_tokens = max(tokens_per_rank) + ep_rank = get_pg_rank(self.ep_group) + + # Expand [total_tokens, H] → [ep_size * max_tokens, H], zeros in padding slots. + padded_output = hidden_states.new_zeros(self.ep_size * max_tokens, hidden_states.shape[1]) + offset = 0 + for dst_rank, n_tokens in enumerate(tokens_per_rank): + padded_output[dst_rank * max_tokens : dst_rank * max_tokens + n_tokens] = hidden_states[ + offset : offset + n_tokens + ] + offset += n_tokens + + # ReduceScatter: [ep_size * max_tokens, H] → [max_tokens, H]. + scattered = padded_output.new_empty(max_tokens, hidden_states.shape[1]) + dist.reduce_scatter_tensor(scattered, padded_output, group=self.ep_group) + + # Truncate padding and cast. + return scattered[: tokens_per_rank[ep_rank]].to(torch.bfloat16) - def combine_preprocess(self, expert_output): - """Pass-through: InferenceGroupedMLP already produces unpermuted output. - No unpermutation is needed because dispatch_postprocess did not permute - the tokens in the first place. +class NVLSAllGatherVDispatcher(InferenceAllGatherDispatcherBase): + """Variable-count AllGather-V / ReduceScatter-V dispatcher for inference CUDA graphs. + + Replaces the fixed AllGather/ReduceScatter of NCCLAllGatherDispatcher with + variable-count NVLS collectives so ranks can hold different token counts per step. + All metadata lives on-device; no host sync is needed between steps. + + Requires Hopper+ GPUs with NVLink and symmetric memory. + """ + + # ── Class-level NVLS step metadata ─────────────────────────────────────────── + # Packed [3] int32: [valid_tokens, rank_token_offset, ep_max_tokens]. + # Written in-place each step for stable CUDA graph addresses. + # _valid_tokens_tensor on the base is pointed at _step_metadata[0:1] on first + # init, so experts.py can read valid_tokens via the base class interface. + _step_metadata: Optional[torch.Tensor] = None # [3] int32 + _per_rank_worst_case_token_count: int = 2048 # round_up_tokens(max_tokens) // tp_size + + # ── Class-level symmetric buffer handles (allocated once at model init) ─────── + # Dtypes: hidden=bf16, routing=int64, probs=fp32, rsv=fp32. 
+ _symm_agv_hidden: Optional[dict] = None # {"tensor": ..., "handle": ...} + _symm_agv_routing: Optional[dict] = None + _symm_agv_probs: Optional[dict] = None + _symm_rsv: Optional[dict] = None + + @classmethod + def _get_rsv_tensor(cls) -> Optional[torch.Tensor]: + """Return the RSV symmetric buffer tensor so mcore_fused_moe can write + unpermute output directly into it, avoiding a copy before RSV.""" + return cls._symm_rsv["tensor"] if cls._symm_rsv is not None else None + + @classmethod + def _rank_token_offset(cls) -> torch.Tensor: + return cls._step_metadata[1:2] + + @classmethod + def _ep_max_tokens(cls) -> torch.Tensor: + return cls._step_metadata[2:3] + + @classmethod + def _delete_buffers(cls): + # needed by CI. + cls._step_metadata = None + cls._symm_agv_hidden = None + cls._symm_agv_routing = None + cls._symm_agv_probs = None + cls._symm_rsv = None + cls._symm_metadata = None + + @classmethod + def allocate_buffers( + cls, + per_rank_worst_case_token_count: int, + topk: int, + hidden_size: int, + ep_group: torch.distributed.ProcessGroup, + ) -> None: + """Allocate all symmetric buffers and initialize class-level metadata. + + Called once at model init. Allocates fixed-size AGV and RSV symmetric + memory buffers so dispatch/combine can proceed without any allocation on + the hot path. Args: - expert_output (torch.Tensor): Output from InferenceGroupedMLP, - shape [global_tokens, hidden_dim]. + per_rank_worst_case_token_count: Max tokens this rank can contribute, + computed by the context as round_up_tokens(max_tokens) // tp_size. + topk: MoE router top-k value. + hidden_size: Model hidden dimension. + ep_group: Expert parallel process group. + """ + ep_size = get_pg_size(ep_group) + cls._per_rank_worst_case_token_count = per_rank_worst_case_token_count + global_max = per_rank_worst_case_token_count * ep_size + device = torch.cuda.current_device() + + # Each buffer self-sizes from its exact tensor footprint so non-default + # max_tokens / hidden_size / ep_size combinations don't silently overflow + # the symmetric-memory cap. + _MB = 1024 * 1024 + + def _size_mb(shape, dtype) -> int: + nbytes = reduce(operator.mul, shape, 1) * torch.tensor([], dtype=dtype).element_size() + return max(1, (nbytes + _MB - 1) // _MB) + + agv_h_shape = [global_max, hidden_size] + agv_r_shape = [global_max, topk] + agv_p_shape = [global_max, topk] + rsv_shape = [global_max, hidden_size] + meta_shape = [ep_size] + + cls._symm_agv_hidden = SymmetricMemoryManager.get_buffer( + "ep_agv_h", process_group=ep_group, size_mb=_size_mb(agv_h_shape, torch.bfloat16) + ).maybe_get_tensor(agv_h_shape, dtype=torch.bfloat16) + + cls._symm_agv_routing = SymmetricMemoryManager.get_buffer( + "ep_agv_r", process_group=ep_group, size_mb=_size_mb(agv_r_shape, torch.int64) + ).maybe_get_tensor(agv_r_shape, dtype=torch.int64) + + cls._symm_agv_probs = SymmetricMemoryManager.get_buffer( + "ep_agv_p", process_group=ep_group, size_mb=_size_mb(agv_p_shape, torch.float32) + ).maybe_get_tensor(agv_p_shape, dtype=torch.float32) + + cls._symm_rsv = SymmetricMemoryManager.get_buffer( + "ep_rsv", process_group=ep_group, size_mb=_size_mb(rsv_shape, torch.float32) + ).maybe_get_tensor(rsv_shape, dtype=torch.float32) + + # Small scratch buffer for fused metadata allgather (WORLD_SIZE int32s). 
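A rough walk-through of the self-sizing arithmetic used by `_size_mb` above: each symmetric buffer requests just enough whole megabytes to hold its exact tensor footprint. The numbers below (2048 tokens per rank, EP=4, hidden=4096) are illustrative, not values from the patch.

```python
import operator
from functools import reduce

import torch

_MB = 1024 * 1024

def size_mb(shape, dtype):
    nbytes = reduce(operator.mul, shape, 1) * torch.tensor([], dtype=dtype).element_size()
    return max(1, (nbytes + _MB - 1) // _MB)  # ceil-divide to whole MB, with a 1 MB floor

global_max = 2048 * 4                                 # per_rank_worst_case_token_count * ep_size
print(size_mb([global_max, 4096], torch.bfloat16))    # 64 MB for the bf16 hidden-states buffer
print(size_mb([4], torch.int32))                      # 1 MB floor for the tiny metadata buffer
```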
+ cls._symm_metadata = SymmetricMemoryManager.get_buffer( + "ep_meta", process_group=ep_group, size_mb=_size_mb(meta_shape, torch.int32) + ).maybe_get_tensor(meta_shape, dtype=torch.int32) + + failed = [ + (name, SymmetricMemoryManager.get_buffer(name).init_failure_reason) + for name, buf in ( + ("ep_agv_h", cls._symm_agv_hidden), + ("ep_agv_r", cls._symm_agv_routing), + ("ep_agv_p", cls._symm_agv_probs), + ("ep_rsv", cls._symm_rsv), + ("ep_meta", cls._symm_metadata), + ) + if buf["handle"] is None + ] + if failed: + details = "; ".join(f"{name}: {reason or 'unknown'}" for name, reason in failed) + raise RuntimeError( + f"NVLSAllGatherVDispatcher: symmetric memory init failed [{details}]. " + f"This dispatcher requires Hopper+ GPUs fully connected via NVLink, and torch built" + f"with torch.distributed._symmetric_memory plus triton installed. " + f"Use inference_moe_token_dispatcher_type='nccl' on non-NVLS systems." + ) - Returns: - torch.Tensor: The input tensor unchanged. + # Initialise step-metadata tensor and wire base class valid_tokens pointer. + cls._step_metadata = torch.zeros(3, dtype=torch.int32, device=device) + InferenceAllGatherDispatcherBase._valid_tokens_tensor = cls._step_metadata[0:1] + + def update_metadata(self, local_tokens: int) -> None: + """Per-step metadata update; invoked from the first instance's token_dispatch. + + Fires the fused NVLS allgather+reduce to publish + [valid_tokens, rank_token_offset, ep_max_tokens] into _step_metadata, then + (for FlashInfer) pre-masks the routing buffer with -1 so rows beyond + valid_tokens are ignored by the GEMM; the AGV below overwrites + [0, valid_tokens) in-place. """ - return expert_output + cls = NVLSAllGatherVDispatcher + fused_metadata_update( + local_tokens=local_tokens, + local_buf=cls._symm_metadata["tensor"], + symm_mem_hdl=cls._symm_metadata["handle"], + step_metadata=cls._step_metadata, + ) + InferenceAllGatherDispatcherBase._host_valid_tokens_estimate = local_tokens * self.ep_size + if self.config.inference_grouped_gemm_backend == InferenceGroupedGemmBackend.FLASHINFER: + cls._symm_agv_routing["tensor"].fill_(-1) - def token_combine(self, hidden_states): - """Combines expert outputs across EP ranks using Reduce-Scatter. + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + runs_metadata_sync: bool = True, + ) -> None: + super().__init__( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=pg_collection, + runs_metadata_sync=runs_metadata_sync, + ) + self.topk = config.moe_router_topk + # Set in dispatch_preprocess; consumed by token_dispatch and token_combine. + self._local_tokens: int = 0 + # When shared_expert_overlap is enabled, the shared expert forward is launched + # on SharedExpertMLP.stream in dispatch_preprocess and joined in combine_postprocess. + self._shared_expert_output: Optional[torch.Tensor] = None + + # ── Dispatch path ───────────────────────────────────────────────────────────── - Reduces the global expert output (summing contributions from each rank) - and scatters the result so each rank receives its local token slice. - Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs - with BF16 when symmetric memory is available. Falls back to NCCL otherwise. + def dispatch_preprocess(self, hidden_states, routing_map, probs): + """Store routing map and local token count; no inter-rank communication. 
+ + If shared_expert_overlap is enabled (set_shared_experts has been called), + launch the entire shared-expert forward on SharedExpertMLP.stream so it + runs concurrently with AGV dispatch, expert GEMMs, and RSV combine. + """ + self.hidden_shape = hidden_states.shape + if self.shared_experts is not None: + stream = SharedExpertMLP.stream + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + self._shared_expert_output = apply_module(self.shared_experts)(hidden_states) + # [S/TP, B, H] -> [S*B/TP, H] + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + self._local_tokens = hidden_states.shape[0] + self.routing_map = routing_map + return hidden_states, probs + + def token_dispatch(self, hidden_states, probs): + """AllGather-V: gather hidden_states, probs, and routing_map from all EP ranks. Args: - hidden_states (torch.Tensor): Combined expert output after routing - weights have been applied, shape [global_tokens, hidden_dim]. + hidden_states: [local_tokens, hidden_size] bf16 local input. + probs: [local_tokens, topk] fp32 local routing probabilities. Returns: - torch.Tensor: Local slice of the reduced output, - shape [local_tokens, hidden_dim] where - local_tokens = global_tokens // ep_size. + (hidden_states, probs) gathered to [global_max, *] shape. + Also updates self.routing_map to [global_max, topk] int64. """ if self.ep_size == 1: - return hidden_states - - # Compute output shape first — check NVLS eligibility on the output, - # since if the smaller output is 16-byte divisible, the input is too. - output_shape = list(hidden_states.size()) - output_shape[0] = hidden_states.size(0) // self.ep_size - output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) - - # Check output only: if output is 16-byte divisible, input (world_size * output) is too. - nvls_eligible = ( - self.triton_nvls_kernels_allowed - and output.dtype in (torch.bfloat16, torch.float32) - and are_tensors_nvls_eligible(output) + return hidden_states, probs + + if self._runs_metadata_sync: + self.update_metadata(hidden_states.shape[0]) + + agv_h = self.__class__._symm_agv_hidden + agv_r = self.__class__._symm_agv_routing + agv_p = self.__class__._symm_agv_probs + + per_rank_max = self._per_rank_worst_case_token_count + global_max = per_rank_max * self.ep_size + rank_token_offset = self._rank_token_offset() + ep_max_tokens = self._ep_max_tokens() + + # Cap AGV CTAs when overlapping the shared expert so the AGV does not + # starve the shared-expert GEMMs running on the side stream. 
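A minimal sketch of the side-stream overlap pattern used for the shared expert: fork onto a dedicated stream before dispatch, run the overlapped work there, and join before adding its result. It requires a CUDA device, and the function and stream names are illustrative stand-ins rather than the module's API.

```python
import torch

side_stream = torch.cuda.Stream()

def overlapped_forward(dense_fn, overlap_fn, x):
    # Fork: the side stream must observe all work already queued on the main stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        overlapped_out = overlap_fn(x)      # e.g. the shared-expert MLP
    main_out = dense_fn(x)                  # e.g. AGV dispatch + expert GEMMs + RSV combine
    # Join: the main stream waits for the side stream before consuming its output.
    torch.cuda.current_stream().wait_stream(side_stream)
    return main_out + overlapped_out
```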
+ agv_kwargs = {"max_num_blocks": 16} if self.shared_experts is not None else {} + multimem_all_gatherv_3tensor( + agv_h["tensor"], + agv_r["tensor"], + agv_p["tensor"], + hidden_states, + self.routing_map, + probs, + agv_h["handle"], + agv_r["handle"], + agv_p["handle"], + rank_token_offset=rank_token_offset, + ep_max_tokens=ep_max_tokens, + per_rank_max_tokens=per_rank_max, + **agv_kwargs, ) - rs_buffer = None - if nvls_eligible: - rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) + topk = probs.shape[1] + hidden_dim = hidden_states.shape[1] + self.routing_map = agv_r["tensor"].view(global_max, topk) + probs = agv_p["tensor"].view(global_max, topk) + hidden_states = agv_h["tensor"].view(global_max, hidden_dim) + return hidden_states, probs - can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None + def dispatch_postprocess(self, hidden_states, probs): + """Pass-through: mcore_fused_moe operates directly on the gathered tensors.""" + return hidden_states, None, probs - if can_use_nvls: - # Copy input to symmetric memory for reduce-scatter - rs_buffer["tensor"].copy_(hidden_states) + # ── Combine path ────────────────────────────────────────────────────────────── - # Use latency-optimized NVLS reduce-scatter - multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) - return output.to(torch.bfloat16) - else: - # Fallback to NCCL - hidden_states = reduce_scatter_to_sequence_parallel_region( - hidden_states, group=self.tp_ep_group - ) + def combine_preprocess(self, expert_output): + """Pass-through: unpermute is handled inside mcore_fused_moe.""" + return expert_output + + def token_combine(self, hidden_states): + """ReduceScatter-V: sum expert outputs across EP ranks, scatter to local tokens. + + Args: + hidden_states: [global_max, hidden_size] expert outputs (fp32 when + written directly to the RSV buffer, bf16 otherwise). + + Returns: + [local_tokens, hidden_size] bf16 local token outputs. + """ + if self.ep_size == 1: return hidden_states.to(torch.bfloat16) + + rsv = self.__class__._symm_rsv + + if hidden_states is not rsv["tensor"]: + rsv["tensor"].copy_(hidden_states) + output = torch.empty( + self._local_tokens, + hidden_states.shape[1], + dtype=rsv["tensor"].dtype, + device=hidden_states.device, + ) + multimem_reduce_scatter_v( + output, + rsv["tensor"], + rsv["handle"], + rank_token_offset=self._rank_token_offset(), + ep_max_tokens=self._ep_max_tokens(), + per_rank_max_tokens=self._per_rank_worst_case_token_count, + ) + return output.to(torch.bfloat16) + + def combine_postprocess(self, hidden_states): + """Restore original input shape (e.g. [S/TP, B, H] from [S*B/TP, H]). + + If shared_expert_overlap is enabled, join SharedExpertMLP.stream and add + the shared-expert output produced concurrently during dispatch+combine. 
+ """ + output = hidden_states.view(self.hidden_shape) + if self._shared_expert_output is not None: + torch.cuda.current_stream().wait_stream(SharedExpertMLP.stream) + output = output + self._shared_expert_output + self._shared_expert_output = None + return output diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 60afdc03ee3..9f43e3565c5 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -22,6 +22,9 @@ gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region, ) +from megatron.core.tensor_parallel.inference_layers import ( + inference_all_gather_from_tensor_model_parallel_region, +) from megatron.core.transformer.enums import AttnMaskType, LayerType from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -931,11 +934,15 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T hidden_states = torch.cat((decoder_input, hidden_states), -1) hidden_states, _ = self.eh_proj(hidden_states) # For tensor parallel we need to gather the tensor across the model-parallel - # ranks after the linear projection. This used to call - # `all_gather_last_dim_from_tensor_parallel_region`, but that utility reduces - # the gradient in backward pass and was therefore incorrect in this context. - # It has been replaced with the correct `gather_from_tensor_model_parallel_region`. - hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group) + # ranks after the linear projection. + if not self.training: + hidden_states = inference_all_gather_from_tensor_model_parallel_region( + hidden_states, self.tp_group, self.config + ) + else: + hidden_states = gather_from_tensor_model_parallel_region( + hidden_states, group=self.tp_group + ) # For sequence parallel, scatter after linear_fc and before transformer layer. if self.sequence_parallel: hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group) @@ -1034,7 +1041,6 @@ def forward_single_position( rotary_pos_emb: Optional[Tensor] = None, rotary_pos_cos: Optional[Tensor] = None, rotary_pos_sin: Optional[Tensor] = None, - inference_params=None, packed_seq_params: Optional[PackedSeqParams] = None, sequence_len_offset: Optional[Tensor] = None, ) -> Tensor: @@ -1065,7 +1071,6 @@ def forward_single_position( rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, - inference_params=inference_params, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 4224b0cfe4c..fb9bd752371 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1079,6 +1079,18 @@ class TransformerConfig(ModelParallelConfig): fp8_recipe='mxfp8'. Set to True to disable fusion and use separate kernel launches (useful for debugging).""" + inference_moe_token_dispatcher_type: Literal['nccl', 'nvls'] = 'nvls' + """Token dispatcher to use for MoE expert parallelism during inference. + - 'nccl': AllGather/ReduceScatter via NCCL. Fixed token counts per rank; requires + decode-only CUDA graphs (forced automatically). + - 'nvls': Variable-count AllGather-V/ReduceScatter-V via NVLS multimem kernels. 
+ Requires Hopper+ GPUs with NVLink and symmetric memory. Default. + Only applies when transformer_impl='inference_optimized' and EP > 1.""" + + mlp_chunks_for_training: int = 1 + """The number of chunks along the sequence dimension to use for MLP computation + during training.""" + mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. """ diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 271744b57a3..eff4c18db4f 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -786,6 +786,12 @@ def _forward_mlp( and not isinstance(self.mlp, IdentityOp) and not self.config.transformer_impl == "inference_optimized" ) + should_chunk_mlp_for_training = ( + self.config.mlp_chunks_for_training > 1 + and inference_context is None + and self.training + and not isinstance(self.mlp, IdentityOp) + ) using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication @@ -815,9 +821,17 @@ def _forward_mlp( False, pre_mlp_layernorm_output, ) - elif should_chunk_mlp_for_prefill: + elif should_chunk_mlp_for_prefill or should_chunk_mlp_for_training: # Chunk input along sequence dimension - num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0]) + num_chunks = min( + ( + self.config.mlp_chunks_for_prefill + if should_chunk_mlp_for_prefill + else self.config.mlp_chunks_for_training + ), + pre_mlp_layernorm_output.shape[0], + ) + chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0) # Compute outputs for each chunk @@ -826,7 +840,8 @@ def _forward_mlp( # Aggregate chunk outputs mlp_output = torch.cat([out for out, _ in outputs], dim=0) bias_chunks = [bias for _, bias in outputs if bias is not None] - bias_output = torch.stack(bias_chunks, dim=0).sum(dim=0) if bias_chunks else None + # elements in bias_chunks are the same for all chunks, so we can just use the first one + bias_output = bias_chunks[0] if bias_chunks else None mlp_output_with_bias = (mlp_output, bias_output) else: if using_fused_tp_inference_kernel: @@ -2045,6 +2060,18 @@ def create_mcore_cudagraph_manager(self, config): ): self.transition_cudagraph_scope('partial') + def _resolve_token_dispatcher_attr(self, attr_name: str) -> tuple[Any, str]: + parent_attr_name, _, leaf_attr_name = attr_name.rpartition('.') + obj = self.mlp.token_dispatcher + for parent_name in parent_attr_name.split('.') if parent_attr_name else (): + obj = getattr(obj, parent_name) + return obj, leaf_attr_name or attr_name + + def _restore_token_dispatcher_attrs(self): + for attr_name, attr in self.token_dispatcher_attrs.items(): + obj, name = self._resolve_token_dispatcher_attr(attr_name) + setattr(obj, name, attr) + def _forward_mlp_router(self, hidden_states, padding_mask=None, input_ids=None): """ Executes the router phase of the MoE block. 
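A small sketch of the dotted-path resolution performed by `_resolve_token_dispatcher_attr` above: everything before the last `.` is a chain of parent lookups, and the remainder is the leaf attribute to get or set. The toy namespace below is a hypothetical stand-in for the token dispatcher.

```python
from types import SimpleNamespace

def resolve(root, attr_name):
    parent_path, _, leaf = attr_name.rpartition('.')
    obj = root
    for name in parent_path.split('.') if parent_path else ():
        obj = getattr(obj, name)
    return obj, leaf or attr_name

dispatcher = SimpleNamespace(_comm_manager=SimpleNamespace(token_probs=None))
obj, leaf = resolve(dispatcher, '_comm_manager.token_probs')
setattr(obj, leaf, 'restored')  # writes dispatcher._comm_manager.token_probs
assert dispatcher._comm_manager.token_probs == 'restored'
```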
@@ -2077,10 +2104,12 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None, input_ids=None): ) for attr_name in self.mlp.token_dispatcher.cudagraph_attrs: - attr = self.mlp.token_dispatcher.get_cudagraph_attr(attr_name) + obj, name = self._resolve_token_dispatcher_attr(attr_name) + attr = getattr(obj, name) if torch.is_tensor(attr): - if attr_name in self.token_dispatcher_attrs: - self.token_dispatcher_attrs[attr_name].copy_(attr) + cached_attr = self.token_dispatcher_attrs.get(attr_name) + if torch.is_tensor(cached_attr) and not cached_attr.requires_grad: + cached_attr.copy_(attr) else: self.token_dispatcher_attrs[attr_name] = attr.detach() @@ -2095,8 +2124,12 @@ def _forward_mlp_expert_compute(self, hidden_states, probs): step runs eagerly between the router and postprocess graph replays. """ - for name, attr in self.token_dispatcher_attrs.items(): - self.mlp.token_dispatcher.set_cudagraph_attr(name, attr) + # During partial CUDA graph replay, use the probs returned from the graph in order + # to retain the router autograd edge. Rebinding it to the live router output ensures + # the backward DDP hook of router.weight is properly triggered. + if '_comm_manager.token_probs' in self.token_dispatcher_attrs: + self.token_dispatcher_attrs['_comm_manager.token_probs'] = probs + self._restore_token_dispatcher_attrs() self.mlp.fwd_execution_map = "expert_compute" return self.mlp(None, intermediate_tensors=(hidden_states, probs)) @@ -2114,8 +2147,7 @@ def _forward_mlp_postprocess(self, residual, output, shared_expert_output, mlp_b # Restore token dispatcher attributes. During graph warmup, the router capture leaves these # attrs pointing into cudagraph pool memory; restoring them here ensures the postprocess # graph captures with valid pointers. - for name, attr in self.token_dispatcher_attrs.items(): - setattr(self.mlp.token_dispatcher, name, attr) + self._restore_token_dispatcher_attrs() self.mlp.fwd_execution_map = "postprocess" output = self.mlp(None, intermediate_tensors=(output, shared_expert_output)) @@ -2151,8 +2183,6 @@ def _forward_mlp_partial_cudagraphs( # graph and wait on it, so we block only until the router's D2H copies complete. self._router_dtoh_event.record() self._router_dtoh_event.synchronize() - for name, attr in self.token_dispatcher_attrs.items(): - setattr(self.mlp.token_dispatcher, name, attr) expert_output, mlp_bias = self._forward_mlp_expert_compute(hidden_states, probs) return self._forward_mlp_postprocess( diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py index d1df4898d18..2249c79a2bd 100644 --- a/megatron/core/transformer/utils.py +++ b/megatron/core/transformer/utils.py @@ -1,6 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
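A minimal sketch of the attribute-caching rule in the `_forward_mlp_router` hunk above: a tensor captured from the router graph is copied in place into the existing cache entry (keeping the storage address stable for later graph replays) unless that entry carries autograd state, in which case a detached copy is cached instead. Names here are illustrative.

```python
import torch

def cache_attr(cache, name, attr):
    if torch.is_tensor(attr):
        cached = cache.get(name)
        if torch.is_tensor(cached) and not cached.requires_grad:
            cached.copy_(attr)        # reuse storage: the cached pointer stays valid for replay
        else:
            cache[name] = attr.detach()
    else:
        cache[name] = attr

cache = {}
cache_attr(cache, 'routing_map', torch.ones(4))
first_storage = cache['routing_map'].data_ptr()
cache_attr(cache, 'routing_map', torch.zeros(4))
assert cache['routing_map'].data_ptr() == first_storage  # same buffer, new values
```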
"""Utilities for transformer layers.""" +import gc +import logging from operator import itemgetter from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Union @@ -19,6 +21,8 @@ if TYPE_CHECKING: from megatron.core.transformer import TransformerConfig +logger = logging.getLogger(__name__) + def get_linear_layer(rows, columns, init_method, perform_initialization=True): """Simple linear layer with weight initialization.""" @@ -72,6 +76,22 @@ def erf_gelu(x): ) +@torch.no_grad() +def cat_with_oom_fallback(sub_state_dict): + """Merge sharded tensor pieces, falling back to CPU if device-side cat OOMs.""" + try: + return torch.cat(sub_state_dict) + except (RuntimeError, torch.cuda.OutOfMemoryError) as e: + logger.warning( + f"CUDA OutOfMemoryError encountered during tensors merging." + f" Switching to CPU merge. (Error: {e})" + ) + merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict]) + gc.collect() + torch.cuda.empty_cache() + return merged_sub_state_dict + + def make_sharded_tensors_for_checkpoint( state_dict: StateDict, prefix: str, diff --git a/megatron/core/utils.py b/megatron/core/utils.py index b7d82d260d9..2c1f5c8c75f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -52,13 +52,6 @@ except ImportError: HAVE_PACKAGING = False -try: - import nvtx - - HAVE_NVTX = True -except ImportError: - HAVE_NVTX = False - logger = logging.getLogger(__name__) try: @@ -513,6 +506,11 @@ def divide(numerator, denominator): return numerator // denominator +def round_up_to_nearest_multiple(value: int, multiple: int) -> int: + """Round *value* up to the nearest multiple of *multiple*.""" + return math.ceil(value / multiple) * multiple + + def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True): """Issue a deprecation warning if tp_group is None and return the default tp group.""" # TODO(zijiey): remove this function later. @@ -2084,6 +2082,10 @@ def get_thd_batch_on_this_cp_rank( _nvtx_enabled: bool = False # Whether NVTX range profiling is enabled _nvtx_range_messages: list[str] = [] # Messages associated with active NVTX ranges +# Permanently pin the string object representing the name of each NVTX range. +# These string objects may be created during CUDA graph capture. +# If they are not pinned, the NVTX range names will be garbage-collected and nsys profile crashes. +_nvtx_range_msg_pool: dict[str, str] = {} def configure_nvtx_profiling(enabled: bool) -> None: @@ -2125,6 +2127,11 @@ def nvtx_range_push(msg=None, suffix=None) -> None: if suffix is not None: msg = f"{msg}.{suffix}" + # If we have entered this range before, do not use the newly-created "msg" object. + # But instead point to the original, first-created, "msg" object. + # They may hold identical data, but they are different addresses; matters when CUDA-graphed. + msg = _nvtx_range_msg_pool.setdefault(msg, msg) + # Track messages to ensure consistency when popping _nvtx_range_messages.append(msg) @@ -2177,14 +2184,16 @@ def _nvtx_decorator_get_func_path(func): return f"{module.__name__}.{caller_func}" -def nvtx_decorator( - message: Optional[str] = None, color: Optional[str] = None -) -> Callable[[_Wrapped], _Wrapped]: +def nvtx_decorator(message: Optional[str] = None) -> Callable[[_Wrapped], _Wrapped]: """Decorator to add NVTX range to a function. 
+ The ``_nvtx_enabled`` flag is checked at **call time** inside + ``nvtx_range_push`` / ``nvtx_range_pop``, so the decorator works + correctly even when applied before ``configure_nvtx_profiling()`` + is called (e.g. at module-import time). + Args: message (str, optional): Custom message for the NVTX range. If None, uses function path - color (str, optional): Color for the NVTX range. Defaults to None Returns: Callable: Decorated function with NVTX profiling if enabled @@ -2194,17 +2203,23 @@ def nvtx_decorator( def my_function(): pass - @nvtx_decorator(message="Custom Range", color="blue") + @nvtx_decorator(message="Custom Range") def another_function(): pass """ def decorator(func: _Wrapped) -> _Wrapped: - if _nvtx_enabled and HAVE_NVTX: - return nvtx.annotate( - message=message or _nvtx_decorator_get_func_path(func), color=color - )(func) - return func + msg = message or _nvtx_decorator_get_func_path(func) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + nvtx_range_push(msg) + try: + return func(*args, **kwargs) + finally: + nvtx_range_pop(msg) + + return wrapper # type: ignore[return-value] return decorator diff --git a/megatron/elastification/__init__.py b/megatron/elastification/__init__.py new file mode 100644 index 00000000000..26496bfed70 --- /dev/null +++ b/megatron/elastification/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/elastification/arguments.py b/megatron/elastification/arguments.py new file mode 100644 index 00000000000..328e4da02dd --- /dev/null +++ b/megatron/elastification/arguments.py @@ -0,0 +1,380 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import math + + +def convert_per_lists_to_int_lists(config): + """Convert all *_per_list attributes to *_int_list using model dimensions. + + Called once after model dimensions are known so downstream code can always + use the int-list path without branching on which list type is active. + After this call every *_per_list is None and every *_int_list is set. + """ + conversions = [ + ('emb_per_list', 'emb_int_list', config.hidden_size), + ('mlp_per_list', 'mlp_int_list', config.ffn_hidden_size), + ('mamba_per_list', 'mamba_int_list', config.mamba_num_heads), + ('moe_expert_per_list', 'moe_expert_int_list', config.num_moe_experts), + ] + for per_attr, int_attr, ref_dim in conversions: + per_val = getattr(config, per_attr, None) + if per_val is not None: + setattr(config, int_attr, [math.floor(x * ref_dim) for x in per_val]) + setattr(config, per_attr, None) + + +def sort_budget_list_descending(args): + """Sort ``budget_list`` descending and permute ``budget_probs`` to match. + + The Flextron router's interpolation branch and the elasticity hooks + implicitly assume budget_list is descending (largest first). Sort once + here so downstream code can rely on that invariant regardless of the + order the user passed on the CLI. Idempotent: a list already in + descending order is unchanged. 
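A toy walk-through, with made-up dimensions, of the two conversions described above: ratio lists are floored against the model's reference dimension, and `budget_list` is sorted descending with `budget_probs` permuted to match.

```python
import math

hidden_size = 4096
mlp_per_list = [1.0, 0.75, 0.5]
mlp_int_list = [math.floor(x * hidden_size) for x in mlp_per_list]  # [4096, 3072, 2048]

budget_list = [0.5, 1.0, 0.75]
budget_probs = [0.2, 0.5, 0.3]
order = sorted(range(len(budget_list)), key=lambda i: budget_list[i], reverse=True)
budget_list = [budget_list[i] for i in order]    # [1.0, 0.75, 0.5]
budget_probs = [budget_probs[i] for i in order]  # [0.5, 0.3, 0.2]
```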
+ """ + bl = getattr(args, 'budget_list', None) + if bl is None or len(bl) <= 1: + return + + bp = getattr(args, 'budget_probs', None) + if bp is not None: + assert len(bp) == len( + bl + ), f'budget_probs length {len(bp)} does not match budget_list length {len(bl)}' + + order = sorted(range(len(bl)), key=lambda i: bl[i], reverse=True) + args.budget_list = [bl[i] for i in order] + if bp is not None: + args.budget_probs = [bp[i] for i in order] + + +def validate_flextron_per_int_lists(args): + """ + Enforce mutual exclusion between ratio per-lists and integer choice lists. + + For each module, at most one of (*_per_list, *_int_list) may be set. If neither + is set, *_per_list defaults to [1.0]. Skips when flextron-related args were not + registered on the parser. + """ + pairs = ( + ('mamba', 'mamba_per_list', 'mamba_int_list'), + ('mlp', 'mlp_per_list', 'mlp_int_list'), + ('emb', 'emb_per_list', 'emb_int_list'), + ('moe-expert', 'moe_expert_per_list', 'moe_expert_int_list'), + ) + for cli_name, per_attr, int_attr in pairs: + # Default to None when the attribute is missing - happens when + # flextron args weren't registered on the parser (the docstring + # promises we skip in that case). + per_val = getattr(args, per_attr, None) + int_val = getattr(args, int_attr, None) + per_set = per_val is not None + int_set = int_val is not None + if per_set: + for x in per_val: + assert 0.0 <= x <= 1.0, f'--{cli_name}-per-list values must be in [0, 1], got {x}.' + assert not ( + per_set and int_set + ), f'Use either --{cli_name}-per-list or --{cli_name}-int-list for {cli_name}, not both.' + if not per_set and not int_set: + setattr(args, per_attr, [1.0]) + + +def add_flextron_args(parser): + group = parser.add_argument_group(title='flextron') + # Distillation flags + group.add_argument('--distillation', action='store_true', help='Enable self-distillation.') + group.add_argument('--distill-coeff', type=float, default=0.0, help='Distillation coefficient.') + group.add_argument('--distill-only', action='store_true', help='Distillation only.') + # Basic Flextron flags + group.add_argument('--flextron', action='store_true', help='Enable Flextron.') + group.add_argument('--binary-mask', action='store_true', help='Use binary mask in Flextron.') + group.add_argument('--slice', action='store_true', help='Use slice in Flextron.') + group.add_argument('--enable-router', action='store_true', help='Enable router in Flextron.') + group.add_argument( + '--add-skipping', action='store_true', help='Add layer skipping in Flextron.' + ) + group.add_argument('--no-attn-skip', action='store_true', help='No attn skip in Flextron.') + group.add_argument( + '--lr-mult-router', + type=float, + default=1.0, + help='Learning rate multiplier for router in Flextron.', + ) + group.add_argument('--flex-strict', action='store_true', help='Strict loading of Flextron.') + group.add_argument('--is-flex-eval', action='store_true', help='Is Flextron evaluation.') + group.add_argument('--freeze-router', action='store_true', help='Freeze router in Flextron.') + group.add_argument('--freeze-model', action='store_true', help='Freeze model in Flextron.') + group.add_argument( + '--flex-hetero-ffn', action='store_true', help='Use flex hetero FFN in Flextron.' + ) + group.add_argument( + '--flex-hetero-mamba', action='store_true', help='Use flex hetero Mamba in Flextron.' 
+ ) + group.add_argument( + '--flex-hetero-moe-expert', + action='store_true', + help='Use flex hetero MoE expert in Flextron.', + ) + group.add_argument( + '--router-std', type=float, default=0.1, help='Router init std for Flextron.' + ) + group.add_argument( + '--normalize-router-logits', + action='store_true', + help='Normalize router logits in Flextron.', + ) + group.add_argument('--soft-mask', action='store_true', help='Soft mask in Flextron.') + + # Flextron hyperparameters + group.add_argument( + '--budget-probs', + nargs='+', + type=float, + default=None, + help='List of budget probabilities for Flextron.', + ) + group.add_argument( + '--prefill-chunk-size', type=int, default=16384, help='Prefill chunk size for Flextron.' + ) + group.add_argument( + '--mem-infer-seq-len', + type=int, + default=131072, + help='Memory infer sequence length for Flextron.', + ) + group.add_argument( + '--mem-batch-size', type=int, default=1, help='Memory batch size for Flextron.' + ) + group.add_argument( + '--original-model-sample-prob', + type=float, + default=0.33, + help='Probability of sampling the original model in Flextron.', + ) + group.add_argument( + '--force-router-skip', + nargs='+', + type=int, + default=None, + help='Force router skip for Flextron router.', + ) + group.add_argument( + '--force-mlp', nargs='+', type=float, default=None, help='Force MLP for Flextron router.' + ) + group.add_argument( + '--force-mamba', + nargs='+', + type=float, + default=None, + help='Force Mamba for Flextron router.', + ) + group.add_argument( + '--force-emb', + nargs='+', + type=float, + default=None, + help='Force Embedding for Flextron router.', + ) + group.add_argument( + '--skip-num-attn-layer-constraint', + type=int, + default=None, + help='Skip number of attention layer constraint for Flextron router.', + ) + group.add_argument( + '--skip-total-layer-constraint', + type=int, + default=None, + help='Skip total layer constraint for Flextron router.', + ) + group.add_argument( + '--disable-budget', action='store_true', help='Disable budget for Flextron router.' + ) + group.add_argument( + '--curr-iteration', type=int, default=None, help='Current iteration for Flextron router.' + ) + group.add_argument( + '--hard-sample-th', + type=float, + default=0.996, + help='Hard sample threshold for Flextron router.', + ) + group.add_argument( + '--router-beta', type=float, default=1.0, help='Beta value for Flextron router.' + ) + group.add_argument( + '--loss-alpha', type=float, default=1.0, help='Alpha coefficient for Flextron loss.' + ) + group.add_argument('--tau-init', type=float, default=1.0, help='Tau init for Flextron router.') + group.add_argument( + '--tau-decay', type=float, default=0.9999, help='Tau decay for Flextron router.' 
+ ) + group.add_argument( + '--router-inter-dim', + type=int, + default=128, + help='Intermediate dimension for Flextron router.', + ) + group.add_argument( + '--linear-scaler-start', + type=float, + default=1.0, + help='Linear scaler start for Flextron router.', + ) + group.add_argument( + '--linear-scaler-end', + type=float, + default=10.0, + help='Linear scaler end for Flextron router.', + ) + group.add_argument( + '--override-selected-budget', + nargs='+', + type=float, + default=None, + help='Override selected budget for Flextron router.', + ) + group.add_argument('--router-gbs', type=int, default=32, help='Router gbs for Flextron router.') + # Model configuration lists + group.add_argument( + '--budget-list', + nargs='+', + type=float, + default=[1.0], + help='List of budget values for Flextron.', + ) + group.add_argument( + '--mamba-per-list', + nargs='+', + type=float, + default=None, + help='List of Mamba percentage values for Flextron (mutually exclusive with --mamba-int-list).', + ) + group.add_argument( + '--mlp-per-list', + nargs='+', + type=float, + default=None, + help='List of MLP percentage values for Flextron (mutually exclusive with --mlp-int-list).', + ) + group.add_argument( + '--emb-per-list', + nargs='+', + type=float, + default=None, + help='List of embedding percentage values for Flextron (mutually exclusive with --emb-int-list).', + ) + group.add_argument( + '--moe-expert-per-list', + nargs='+', + type=float, + default=None, + help='List of MoE expert percentage values for Flextron (mutually exclusive with --moe-expert-int-list).', + ) + group.add_argument( + '--mamba-int-list', + nargs='+', + type=int, + default=None, + help='List of Mamba integer router choices for Flextron (mutually exclusive with --mamba-per-list).', + ) + group.add_argument( + '--mlp-int-list', + nargs='+', + type=int, + default=None, + help='List of MLP integer router choices for Flextron (mutually exclusive with --mlp-per-list).', + ) + group.add_argument( + '--emb-int-list', + nargs='+', + type=int, + default=None, + help='List of embedding integer router choices for Flextron (mutually exclusive with --emb-per-list).', + ) + group.add_argument( + '--moe-expert-int-list', + nargs='+', + type=int, + default=None, + help='List of MoE expert integer router choices for Flextron (mutually exclusive with --moe-expert-per-list).', + ) + group.add_argument( + '--budget-type', type=str, default='param', choices=['param', 'mem'], help='Type of budget.' + ) + # Memory quantization profile + group.add_argument( + '--memory-profile', + type=str, + default='bf16', + help='Named memory quantization preset from memory_profiles.yaml ' + '(e.g. bf16, fp8_kv, fp8_all, int8). ' + 'Individual --bpe-* overrides take priority.', + ) + group.add_argument( + '--memory-profile-path', + type=str, + default=None, + help='Path to a custom memory_profiles.yaml. ' + 'Defaults to the bundled megatron/elastification/memory_profiles.yaml.', + ) + group.add_argument( + '--bpe-params', + type=float, + default=None, + help='Override bytes-per-element for model parameters ' '(2=BF16, 1=FP8/INT8, 0.5625=FP4).', + ) + group.add_argument( + '--bpe-kv-cache', type=float, default=None, help='Override bytes-per-element for KV cache.' 
+ ) + group.add_argument( + '--bpe-ssm-cache', + type=float, + default=None, + help='Override bytes-per-element for Mamba SSM state cache.', + ) + group.add_argument( + '--bpe-max-buffer', + type=float, + default=None, + help='Override bytes-per-element for MoE dispatch buffer.', + ) + group.add_argument( + '--param-budget-target', + type=str, + default=None, + choices=['active', 'total'], + help='Whether param budget loss supervises on active params ' + '(top-k experts only) or total params. ' + 'Overrides the preset value from --memory-profile.', + ) + group.add_argument( + '--layer-ranking-list', nargs='+', type=int, default=None, help='List of layer ranking.' + ) + group.add_argument( + '--log-budgets', + nargs='+', + type=str, + default=["all"], + help='Budget values to log distillation loss for (space-separated list or "all").', + ) + # Additional parameters + + group.add_argument( + '--basemodel-type', + type=str, + default='nemotronh_8b', + choices=['nemotronh_8b'], + help='Base model type for parameter loss calculation.', + ) + + # Budget configuration + group.add_argument( + '--flextron-config-file', + type=str, + default=None, + help='Configuration file for Flextron budget settings.', + ) + + return parser diff --git a/megatron/elastification/flextron_config.py b/megatron/elastification/flextron_config.py new file mode 100644 index 00000000000..620e7ef2557 --- /dev/null +++ b/megatron/elastification/flextron_config.py @@ -0,0 +1,113 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +FlextronConfig — all Flextron/elastification config fields in one place. + +Previously these lived as fields on TransformerConfig (megatron/core). +They are now injected onto the model config at runtime via inject_flextron_config +so that megatron/core stays clean. 
+""" + +import dataclasses +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class FlextronConfig: + # ── Core flags ──────────────────────────────────────────────────────────── + flextron: bool = False + binary_mask: bool = False + add_skipping: bool = False + no_attn_skip: bool = False + slice: bool = False + soft_mask: bool = False + + # ── Router ──────────────────────────────────────────────────────────────── + enable_router: bool = False + router_inter_dim: int = 128 + hard_sample_th: float = 0.996 + router_beta: float = 1.0 + loss_alpha: float = 1.0 + tau_init: float = 1.0 + tau_decay: float = 0.9999 + router_std: float = 0.1 + router_gbs: int = 32 + normalize_router_logits: bool = False + linear_scaler_start: Optional[float] = None + linear_scaler_end: Optional[float] = None + + # ── Budget ──────────────────────────────────────────────────────────────── + budget_probs: Optional[List[float]] = None + budget_list: Optional[List[float]] = None + budget_type: str = 'param' + disable_budget: bool = False + + # ── Training / eval control ─────────────────────────────────────────────── + basemodel_type: str = 'nemotronh_8b' + is_flex_eval: bool = False + freeze_router: bool = False + freeze_model: bool = False + curr_iteration: Optional[int] = None + original_model_sample_prob: float = 0.33 + override_selected_budget: Optional[List[float]] = None + + # ── Layer-skip constraints ──────────────────────────────────────────────── + skip_num_attn_layer_constraint: Optional[int] = None + skip_total_layer_constraint: Optional[int] = None + layer_ranking_list: Optional[List[int]] = None + + # ── Force overrides (eval / frozen-router mode) ─────────────────────────── + force_router_skip: Optional[List[int]] = None + force_mlp: Optional[List[float]] = None + force_mamba: Optional[List[float]] = None + force_emb: Optional[List[float]] = None + + # ── Choice lists (converted to int at model-setup time) ─────────────────── + mamba_per_list: Optional[List[float]] = None + mlp_per_list: Optional[List[float]] = None + emb_per_list: Optional[List[float]] = None + moe_expert_per_list: Optional[List[float]] = None + mamba_int_list: Optional[List[int]] = None + mlp_int_list: Optional[List[int]] = None + emb_int_list: Optional[List[int]] = None + moe_expert_int_list: Optional[List[int]] = None + + # ── Heterogeneous per-layer routing ─────────────────────────────────────── + flex_hetero_ffn: bool = False + flex_hetero_mamba: bool = False + flex_hetero_moe_expert: bool = False + + # ── Memory / inference sizing ───────────────────────────────────────────── + prefill_chunk_size: int = 16384 + mem_infer_seq_len: int = 131072 + mem_batch_size: int = 1 + + # ── Distillation ────────────────────────────────────────────────────────── + distillation: bool = False + distill_coeff: float = 0.0 + distill_only: bool = False + + +def inject_flextron_config(args, config) -> None: + """Copy all FlextronConfig fields from parsed args onto an existing config object. + + Safe to call even when flextron args were not registered on the parser — + falls back to FlextronConfig defaults via getattr. After this call every + FlextronConfig field is accessible directly as config.. + """ + # Validate per-list/int-list mutual exclusion, apply default fallbacks, + # and sort the budget list descending before copying onto config so + # downstream code sees the resolved + ordered state. 
+ from megatron.elastification.arguments import ( + sort_budget_list_descending, + validate_flextron_per_int_lists, + ) + + validate_flextron_per_int_lists(args) + sort_budget_list_descending(args) + + defaults = FlextronConfig() + for f in dataclasses.fields(defaults): + value = getattr(args, f.name, getattr(defaults, f.name)) + setattr(config, f.name, value) diff --git a/megatron/elastification/flextron_elasticity_hooks.py b/megatron/elastification/flextron_elasticity_hooks.py new file mode 100644 index 00000000000..da815043865 --- /dev/null +++ b/megatron/elastification/flextron_elasticity_hooks.py @@ -0,0 +1,1832 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Flextron Elasticity Hooks + +Applies elasticity masking through PyTorch hooks without modifying original +modules. One manager class per module type (MambaMixer, SelfAttention, +TransformerLayer, MoELayer, TopKRouter, TEGroupedMLP, HybridStack). +""" + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from megatron.core import parallel_state +from megatron.core.tensor_parallel.utils import split_tensor_along_last_dim +from megatron.core.transformer.moe.moe_utils import get_capacity, group_limited_topk + + +class FlextronMambaElasticityManager: + """ + Manages elasticity for MambaMixer using pure PyTorch hooks. + Based on the exact implementation from original flextron_os MambaMixer. + """ + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters - store the full router outputs + self.current_router_emb = None + self.current_router_mamba = None + + # Hook handles for cleanup + self.hook_handles = [] + + def _init_embedding_masks(self): + """Initialize embedding dimension masks.""" + mask_list = [] + for emb_int in self.config.emb_int_list: + assert ( + 0 <= emb_int <= self.config.hidden_size + ), f'emb_int_list entries must be in [0, hidden_size={self.config.hidden_size}], got {emb_int}.' 
+ mask = torch.zeros(self.config.hidden_size, dtype=torch.bool) + mask[:emb_int] = True + mask_list.append(mask) + self.emb_masks_lookup = { + emb_int: idx for idx, emb_int in enumerate(self.config.emb_int_list) + } + self.emb_masks = torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + + def _init_mamba_masks(self): + """Initialize Mamba-specific masks for different layers.""" + in_proj_mask_list = [] + conv1d_mask_list = [] + + world_size = parallel_state.get_tensor_model_parallel_world_size() + + in_proj_z_shard = [i for i in range(self.mamba_mixer.d_inner_local_tp)] + in_proj_x_shard = [ + i + for i in range(self.mamba_mixer.d_inner_local_tp, 2 * self.mamba_mixer.d_inner_local_tp) + ] + in_proj_B_shard = [ + i + for i in range( + 2 * self.mamba_mixer.d_inner_local_tp, + 2 * self.mamba_mixer.d_inner_local_tp + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + ) + ] + in_proj_C_shard = [ + i + for i in range( + 2 * self.mamba_mixer.d_inner_local_tp + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + 2 * self.mamba_mixer.d_inner_local_tp + + 2 * self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + ) + ] + in_proj_dt_shard = [ + i + for i in range( + 2 * self.mamba_mixer.d_inner_local_tp + + 2 * self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + 2 * self.mamba_mixer.d_inner_local_tp + + 2 * self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state + + self.mamba_mixer.nheads_local_tp, + ) + ] + + conv1d_x_shard = [i for i in range(self.mamba_mixer.d_inner_local_tp)] + conv1d_B_shard = [ + i + for i in range( + self.mamba_mixer.d_inner_local_tp, + self.mamba_mixer.d_inner_local_tp + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + ) + ] + conv1d_C_shard = [ + i + for i in range( + self.mamba_mixer.d_inner_local_tp + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + self.mamba_mixer.d_inner_local_tp + + 2 * self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state, + ) + ] + + out_proj_x_shard = [i for i in range(self.mamba_mixer.d_inner_local_tp)] + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + for mamba_int in self.config.mamba_int_list: + assert ( + 0 <= mamba_int <= self.mamba_mixer.nheads + ), f"mamba_int_list entries must be in [0, nheads={self.mamba_mixer.nheads}], got {mamba_int}." + assert ( + mamba_int % tp_size == 0 + ), f"mamba_int_list entries must be evenly divisible by tp_size={tp_size}, got {mamba_int}." 
+ mamba_nhead_idx = mamba_int // tp_size + + in_proj_mask = torch.zeros( + self.mamba_mixer.d_inner_local_tp * 2 + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state * 2 + + self.mamba_mixer.nheads_local_tp, + dtype=torch.bool, + ) + in_proj_mask[in_proj_z_shard[: mamba_nhead_idx * self.mamba_mixer.headdim]] = True + in_proj_mask[in_proj_x_shard[: mamba_nhead_idx * self.mamba_mixer.headdim]] = True + in_proj_mask[in_proj_B_shard] = True + in_proj_mask[in_proj_C_shard] = True + in_proj_mask[in_proj_dt_shard[:mamba_nhead_idx]] = True + in_proj_mask_list.append(in_proj_mask) + + conv1d_mask = torch.zeros( + self.mamba_mixer.d_inner_local_tp + + self.mamba_mixer.ngroups_local_tp * self.mamba_mixer.d_state * 2, + dtype=torch.bool, + ) + conv1d_mask[conv1d_x_shard[: mamba_nhead_idx * self.mamba_mixer.headdim]] = True + conv1d_mask[conv1d_B_shard] = True + conv1d_mask[conv1d_C_shard] = True + conv1d_mask_list.append(conv1d_mask) + + self.mamba_masks_lookup = { + mamba_int: idx for idx, mamba_int in enumerate(self.config.mamba_int_list) + } + + in_proj_mask_list = [item.to(in_proj_mask_list[0].device) for item in in_proj_mask_list] + self.in_proj_mask_list = ( + torch.stack(in_proj_mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + ) + + conv1d_mask_list = [item.to(conv1d_mask_list[0].device) for item in conv1d_mask_list] + self.conv1d_mask_list = ( + torch.stack(conv1d_mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + ) + + def attach_hooks(self, mamba_mixer): + """Attach hooks to MambaMixer following the original flextron_os pattern.""" + if not self.enabled: + return + + self.mamba_mixer = mamba_mixer + + emb_effective_per_list = [x / self.config.hidden_size for x in self.config.emb_int_list] + mamba_effective_per_list = [x / mamba_mixer.nheads for x in self.config.mamba_int_list] + + # Setup hook - runs first to initialize masks for this forward pass + def setup_masks_hook(module, input): + if self.config.flextron: + self._init_embedding_masks() + self._init_mamba_masks() + return input + + # Cleanup hook - runs last to remove masks after forward pass + def cleanup_masks_hook(module, input, output): + if self.config.flextron: + self.emb_masks = None + self.in_proj_mask_list = None + self.conv1d_mask_list = None + self.emb_masks_lookup = {} + self.mamba_masks_lookup = {} + return output + + # Hook 1: Input masking and router_emb processing + def input_mask_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + hidden_states = input[0] + + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_input = hidden_states * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_input = hidden_states * mask[None, None, :] + masked_input = masked_input * router_emb_logits + + return tuple([masked_input] + list(input[1:])) + + return input + + # Hook 2: in_proj pre-hook for eps modification + def in_proj_pre_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + + # Set eps to the pruned value + if self.config.soft_mask: + soft_eps = 0 + for emb_per, per_logit in zip( + emb_effective_per_list, 
self.current_router_emb[0] + ): + soft_eps += self.config.layernorm_epsilon * emb_per * per_logit + module.eps = soft_eps.float().detach().item() + else: + emb_choice = self.current_router_emb[1] + emb_effective_per = emb_choice / self.config.hidden_size + module.eps = self.config.layernorm_epsilon * emb_effective_per + return input + + # Hook 3: in_proj post-hook for router scaling + def in_proj_post_hook(module, input, output): + if self.config.flextron and self.current_router_mamba is not None: + # Apply router_emb scaling to in_proj output + xz, bias = output + + if self.config.soft_mask: + # Soft scaling with embedding router + soft_xz = torch.zeros_like(xz) + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_xz.add_(xz * per_logit * (emb_per**0.5)) + xz = soft_xz + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + emb_effective_per = emb_choice / self.config.hidden_size + xz = xz * router_emb_logits * (emb_effective_per**0.5) + + # Apply mamba router logic (hard mask only) + if not self.config.soft_mask: + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + router_mamba_logits = torch.max(self.current_router_mamba[0][mamba_idx]) + mamba_per = self.current_router_mamba[1][mamba_idx] + else: + router_mamba_logits, mamba_per = ( + torch.max(self.current_router_mamba[0]), + self.current_router_mamba[1], + ) + + # Apply mamba masking + if self.config.soft_mask: + soft_in_proj_mask = torch.zeros_like(self.in_proj_mask_list[0]) + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + for mask, per_logit in zip( + self.in_proj_mask_list, self.current_router_mamba[0][mamba_idx] + ): + soft_in_proj_mask.add_(mask * per_logit) + else: + for mask, per_logit in zip( + self.in_proj_mask_list, self.current_router_mamba[0] + ): + soft_in_proj_mask.add_(mask * per_logit) + in_proj_mask = soft_in_proj_mask + else: + in_proj_mask = self.in_proj_mask_list[self.mamba_masks_lookup[mamba_per]] + + xz = xz * in_proj_mask.to(device=xz.device)[None, None, :] + + if not self.config.soft_mask: + xz = xz * router_mamba_logits + + # Reset eps to original + module.eps = self.config.layernorm_epsilon + + return (xz, bias) + return output + + # Hook 4: conv1d output masking + def conv1d_mask_hook(module, input, output): + if self.config.flextron and self.current_router_mamba is not None: + if not self.config.soft_mask: + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + router_mamba_logits = torch.max(self.current_router_mamba[0][mamba_idx]) + mamba_per = self.current_router_mamba[1][mamba_idx] + else: + router_mamba_logits, mamba_per = ( + torch.max(self.current_router_mamba[0]), + self.current_router_mamba[1], + ) + + # Apply conv1d masking + if self.config.soft_mask: + soft_conv1d_mask = torch.zeros_like(self.conv1d_mask_list[0]) + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + for mask, per_logit in zip( + self.conv1d_mask_list, self.current_router_mamba[0][mamba_idx] + ): + soft_conv1d_mask.add_(mask * per_logit) + else: + for mask, per_logit in zip( + self.conv1d_mask_list, self.current_router_mamba[0] + ): + soft_conv1d_mask.add_(mask * per_logit) + conv1d_mask = soft_conv1d_mask + else: + 
conv1d_mask = self.conv1d_mask_list[self.mamba_masks_lookup[mamba_per]] + masked_output = output * conv1d_mask.to(device=output.device)[None, :, None] + + if not self.config.soft_mask: + masked_output = masked_output * router_mamba_logits + + return masked_output + return output + + # Hook 5a: RMSNorm pre-hook for eps modification + def norm_pre_hook(module, input): + if self.config.flextron and self.current_router_mamba is not None: + if self.config.soft_mask: + soft_eps = 0 + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + for mamba_per, per_logit in zip( + mamba_effective_per_list, self.current_router_mamba[0][mamba_idx] + ): + soft_eps += self.config.layernorm_epsilon * mamba_per * per_logit + else: + for mamba_per, per_logit in zip( + mamba_effective_per_list, self.current_router_mamba[0] + ): + soft_eps += self.config.layernorm_epsilon * mamba_per * per_logit + module.eps = soft_eps.float().detach().item() + else: + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + mamba_per = self.current_router_mamba[1][mamba_idx] + else: + mamba_per = self.current_router_mamba[1] + mamba_effective_per = mamba_per / self.mamba_mixer.nheads + module.eps = self.config.layernorm_epsilon * mamba_effective_per + + return input + + # Hook 5b: RMSNorm post-hook for scaling and eps restoration + def norm_post_hook(module, input, output): + if self.config.flextron and self.current_router_mamba is not None: + # Restore original eps + module.eps = self.config.layernorm_epsilon + + if self.config.soft_mask: + soft_scaled_output = torch.zeros_like(output) + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + for mamba_per, per_logit in zip( + mamba_effective_per_list, self.current_router_mamba[0][mamba_idx] + ): + soft_scaled_output.add_(output * (mamba_per**0.5) * per_logit) + else: + for mamba_per, per_logit in zip( + mamba_effective_per_list, self.current_router_mamba[0] + ): + soft_scaled_output.add_(output * (mamba_per**0.5) * per_logit) + return soft_scaled_output + else: + if self.config.flex_hetero_mamba: + mamba_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('M') - 1 + ) + router_mamba_logits = torch.max(self.current_router_mamba[0][mamba_idx]) + mamba_per = self.current_router_mamba[1][mamba_idx] + else: + router_mamba_logits, mamba_per = ( + torch.max(self.current_router_mamba[0]), + self.current_router_mamba[1], + ) + mamba_effective_per = mamba_per / self.mamba_mixer.nheads + return output * (mamba_effective_per**0.5) * router_mamba_logits + + return output + + # Hook 6: Final output masking + def output_mask_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + out, out_bias = output + + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_out = out * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_out = out * mask[None, None, :] + masked_out = masked_out * router_emb_logits + + return (masked_out, 
out_bias) + return output + + # IMPORTANT: Register setup hook FIRST + setup_handle = mamba_mixer.register_forward_pre_hook(setup_masks_hook) + self.hook_handles.append(setup_handle) + + # Attach main input hook + main_handle = mamba_mixer.register_forward_pre_hook(input_mask_hook) + self.hook_handles.append(main_handle) + + # Attach in_proj hooks + in_proj_pre_handle = mamba_mixer.in_proj.register_forward_pre_hook(in_proj_pre_hook) + in_proj_post_handle = mamba_mixer.in_proj.register_forward_hook(in_proj_post_hook) + self.hook_handles.append(in_proj_pre_handle) + self.hook_handles.append(in_proj_post_handle) + + # Attach conv1d hook (this will handle the standard conv1d path) + conv_handle = mamba_mixer.conv1d.register_forward_hook(conv1d_mask_hook) + self.hook_handles.append(conv_handle) + + # Attach RMSNorm hooks if rmsnorm is enabled + norm_pre_handle = mamba_mixer.norm.register_forward_pre_hook(norm_pre_hook) + norm_post_handle = mamba_mixer.norm.register_forward_hook(norm_post_hook) + self.hook_handles.append(norm_pre_handle) + self.hook_handles.append(norm_post_handle) + + # Final output hook + output_handle = mamba_mixer.register_forward_hook(output_mask_hook) + self.hook_handles.append(output_handle) + + # Cleanup hook - runs last to remove masks after forward pass + cleanup_handle = mamba_mixer.register_forward_hook(cleanup_masks_hook) + self.hook_handles.append(cleanup_handle) + + def set_elasticity_params(self, router_emb=None, router_mamba=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + + if router_mamba is not None: + self.current_router_mamba = router_mamba + + def detach_hooks(self): + """Remove all hooks.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +class FlextronTransformerLayerElasticityManager: + """ + Manages elasticity for TransformerLayer using pure PyTorch hooks. + Handles input/pre-MLP layernorm eps modification and MLP routing. 
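+
+    Illustrative usage (a hedged sketch, assuming ``config``, a built
+    ``transformer_layer``, and a router output ``router_emb`` already exist; the
+    tuple layout is inferred from how the hooks below consume it):
+
+        manager = FlextronTransformerLayerElasticityManager(config, layer_idx=3)
+        manager.attach_hooks(transformer_layer)
+        # router_emb[0]: probabilities over config.emb_int_list,
+        # router_emb[1]: the chosen embedding width (a value from emb_int_list).
+        manager.set_elasticity_params(router_emb=router_emb)
+        out = transformer_layer(hidden_states, attention_mask)  # hooks fire here
+        manager.detach_hooks()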
+ """ + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters - store the full router outputs + self.current_router_emb = None + + # Hook handles for cleanup + self.hook_handles = [] + + def _init_embedding_masks(self): + """Initialize embedding dimension masks.""" + mask_list = [] + for emb_int in self.config.emb_int_list: + mask = torch.zeros(self.config.hidden_size, dtype=torch.bool) + mask[:emb_int] = True + mask_list.append(mask) + self.emb_masks_lookup = { + emb_int: idx for idx, emb_int in enumerate(self.config.emb_int_list) + } + self.emb_masks = torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + + def initialize_masks(self, transformer_layer): + """Initialize masks based on the MoE module configuration.""" + if not self.enabled: + return + + self.transformer_layer = transformer_layer + self._init_embedding_masks() + + def attach_hooks(self, transformer_layer): + """Attach hooks to MLP/MoE layer for layer skipping only.""" + if not self.enabled: + return + + self.initialize_masks(transformer_layer) + + emb_effective_per_list = [x / self.config.hidden_size for x in self.config.emb_int_list] + + # Hook 2: Pre-MLP layernorm pre-hook for eps modification + def pre_mlp_layernorm_pre_hook(module, input): + + if self.config.flextron and self.current_router_emb is not None: + hidden_states = input[0] + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_input = hidden_states * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_input = hidden_states * mask[None, None, :] + masked_input = masked_input * router_emb_logits + + # Modify eps for this forward pass + if self.config.soft_mask: + soft_eps = 0 + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_eps += self.config.layernorm_epsilon * emb_per * per_logit + module.eps = soft_eps.float().detach().item() + else: + emb_choice = self.current_router_emb[1] + emb_effective_per = emb_choice / self.config.hidden_size + module.eps = self.config.layernorm_epsilon * emb_effective_per + + return tuple([masked_input] + list(input[1:])) + + return input + + # Hook 3: Pre-MLP layernorm post-hook for scaling and eps restoration + def pre_mlp_layernorm_post_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + + # Restore original eps + module.eps = self.config.layernorm_epsilon + + # Apply scaling + if self.config.soft_mask: + soft_scaled_output = torch.zeros_like(output) + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_scaled_output.add_(output * (emb_per**0.5) * per_logit) + scaled_output = soft_scaled_output + else: + emb_choice = self.current_emb_choice + emb_effective_per = emb_choice / self.config.hidden_size + router_emb_logits = torch.max(self.current_router_emb[0]) + scaled_output = output * (emb_effective_per**0.5) * router_emb_logits + return scaled_output + + return output + + # Attach the pre-MLP layernorm hooks + 
pre_mlp_ln_pre_handle = transformer_layer.pre_mlp_layernorm.register_forward_pre_hook( + pre_mlp_layernorm_pre_hook + ) + pre_mlp_ln_post_handle = transformer_layer.pre_mlp_layernorm.register_forward_hook( + pre_mlp_layernorm_post_hook + ) + self.hook_handles.append(pre_mlp_ln_pre_handle) + self.hook_handles.append(pre_mlp_ln_post_handle) + + def set_elasticity_params(self, router_emb=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + self.current_emb_choice = router_emb[1] + + def detach_hooks(self): + """Remove all hooks.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +def topk_softmax_with_capacity( + logits: torch.Tensor, + topk: int, + capacity_factor: Optional[float] = None, + pad_to_capacity: bool = False, + drop_policy: str = "probs", + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + deterministic_mode: bool = False, + score_function: str = "softmax", + expert_bias: Optional[torch.Tensor] = None, + current_router_moe_expert_0: Optional[torch.Tensor] = None, + current_router_moe_expert_per: Optional[List[float]] = None, + num_experts: int = 0, +): + """Apply capacity and padding to the top-k selection. + Args: + logits (torch.Tensor): Logits tensor. + topk (int): The number of experts to select for each token. + capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number + of tokens exceeds the capacity. + pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded + tokens will be 0. + drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". + If "prob", the tokens with the lowest probabilities will be dropped. + If "position", tokens at the end of each batch will be dropped. + use_pre_softmax (bool): Whether to apply softmax or sigmoid before top-k selection. + num_groups (int): Number of groups for routed experts. + group_topk (int): Number of selected groups for each token. + scaling_factor (float): Scaling factor of routing score in top-k selection. + deterministic_mode (bool): Deprecated. + score_function (str): The score function to use. Can be either "softmax" or "sigmoid". + expert_bias (torch.Tensor): The bias added to logits for expert routing. + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing + the routing probabilities for each token to each expert. + - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] + indicating which experts were selected for each token. True values represent + the selected experts. + - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing + the number of local tokens assigned to each expert before dropping and padding. + """ + assert score_function == "sigmoid", "Only sigmoid score function is supported for now." + assert expert_bias is not None, "Expert bias is required for sigmoid score function." + assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." 
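+    # Worked example of the elastic-expert blending below (hypothetical numbers,
+    # not taken from any real config): with num_experts = 8,
+    # current_router_moe_expert_per = [1.0, 0.5] and routing weights [0.7, 0.3],
+    # the first pass keeps all 8 experts and contributes 0.7 * sigmoid(logits);
+    # the second pass sets logits[:, 4:] to -inf (threshold = floor(0.5 * 8) = 4)
+    # and contributes 0.3 * sigmoid(masked logits). Experts 4..7 therefore only
+    # receive score mass from the full-width branch, softly shrinking the
+    # effective expert count.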
+ num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + if score_function == "sigmoid": + scores_for_routing = 0 + scores_for_topk = 0 + for router_moe_expert_logits, router_moe_expert_per in zip( + current_router_moe_expert_0, current_router_moe_expert_per + ): + expert_threshold = math.floor(router_moe_expert_per * num_experts) + + logits_current = logits.clone() + logits_current[:, expert_threshold:] = float('-inf') + expert_bias_current = expert_bias.clone() + expert_bias_current[expert_threshold:] = 0 + expert_bias_current = expert_bias_current * router_moe_expert_logits + + scores = ( + torch.sigmoid(logits_current.float()).type_as(logits_current) + * router_moe_expert_logits + ) + + scores_for_topk += scores + scores_for_routing += scores + expert_bias_current + + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores_for_topk, dim=1, index=top_indices).type_as(logits) + + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + # TODO Try using element-wise operations instead of scatter? + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) + topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + tokens_per_expert = topk_map.sum(dim=0) + + if capacity_factor is None: + # TopK without capacity + return topk_masked_gates, topk_map, tokens_per_expert + else: + # TopK with capacity + expert_capacity = get_capacity( + num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor + ) + + # Maskout exceeded tokens + if drop_policy == "probs": + _, capacity_indices = torch.topk( + topk_masked_gates, k=expert_capacity, dim=0, sorted=False + ) + capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool() + elif drop_policy == "position": + _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False) + capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool() + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + if pad_to_capacity: + final_map = capacity_mask + final_probs = topk_masked_gates * final_map + else: + # Get exceed mask and maskout exceeded probs and indices + final_map = torch.logical_and(topk_map, capacity_mask) + final_probs = topk_masked_gates * final_map + return final_probs, final_map, tokens_per_expert + + +class FlextronTopKRouterElasticityManager: + """ + Manages elasticity for MoE Router using pure PyTorch hooks. + Handles expert masking in the routing logits before topk selection. 
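+
+    Illustrative usage (a hedged sketch, assuming ``router`` is a built
+    ``TopKRouter`` and ``router_moe_expert`` comes from the Flextron router; the
+    tuple layout is inferred from how ``wrapped_routing`` consumes it):
+
+        manager = FlextronTopKRouterElasticityManager(config, layer_idx=5)
+        manager.attach_hooks(router)    # wraps router.routing in place
+        # router_moe_expert[0]: per-option routing weights,
+        # router_moe_expert[1]: the selected expert count (an integer).
+        manager.set_elasticity_params(router_moe_expert=router_moe_expert)
+        ...                             # MoE layer forward uses the wrapped routing
+        manager.detach_hooks()          # restores the original routing method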
+ """ + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters + self.current_router_moe_expert = None + + # Hook handles for cleanup + self.hook_handles = [] + + def attach_hooks(self, router): + """Attach hooks to TopKRouter for expert masking.""" + if not self.enabled: + return + + # Store original method for restoration + original_routing = router.routing + + def wrapped_routing(logits, **kwargs): + + # Apply expert masking before calling original routing + if self.config.flextron and self.current_router_moe_expert is not None: + + if self.config.soft_mask: + if self.config.flex_hetero_moe_expert: + moe_expert_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('E') - 1 + ) + current_router_moe_expert_0 = self.current_router_moe_expert[0][ + moe_expert_idx + ] + else: + current_router_moe_expert_0 = self.current_router_moe_expert[0] + + seq_length, bsz = logits.shape[:2] + logits = logits.view(-1, self.config.num_moe_experts) + + # Apply Z-Loss + logits = router.apply_z_loss(logits) + assert self.config.moe_router_load_balancing_type == "none" + + scores, routing_map, _ = topk_softmax_with_capacity( + logits, + self.config.moe_router_topk, + capacity_factor=self.config.moe_expert_capacity_factor, + pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, + drop_policy=self.config.moe_token_drop_policy, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, + scaling_factor=self.config.moe_router_topk_scaling_factor, + deterministic_mode=self.config.deterministic_mode, + score_function=self.config.moe_router_score_function, + expert_bias=router.expert_bias, + current_router_moe_expert_0=current_router_moe_expert_0, + current_router_moe_expert_per=[ + x / self.config.num_moe_experts for x in self.config.moe_expert_int_list + ], + num_experts=self.config.num_moe_experts, + ) + + if self.config.moe_router_enable_expert_bias and torch.is_grad_enabled(): + with torch.no_grad(): + router.local_tokens_per_expert += routing_map.sum(dim=0) + + return scores, routing_map + else: + if self.config.flex_hetero_moe_expert: + moe_expert_idx = ( + self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('E') - 1 + ) + router_moe_expert_logits = torch.max( + self.current_router_moe_expert[0][moe_expert_idx] + ) + router_moe_expert_per = self.current_router_moe_expert[1][moe_expert_idx] + else: + router_moe_expert_logits, router_moe_expert_per = ( + torch.max(self.current_router_moe_expert[0]), + self.current_router_moe_expert[1], + ) + + expert_threshold = ( + router_moe_expert_per # always an integer count after conversion + ) + + # Apply the same logic as the commented lines + logits = logits.clone() + logits[:, expert_threshold:] = float('-inf') + logits = logits * router_moe_expert_logits + + # Mask the expert_bias on a temporary tensor and restore + # the original after the call so subsequent forwards + # (including full-model passes that bypass this branch) + # don't see a truncated bias. 
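+                    # Illustration (hypothetical numbers): with num_moe_experts = 64
+                    # and a router choice of 16 experts, expert_threshold = 16, so
+                    # logits[:, 16:] become -inf and masked_bias[16:] become 0;
+                    # top-k can then only select experts 0..15.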
+ if hasattr(router, 'expert_bias') and router.expert_bias is not None: + original_expert_bias = router.expert_bias + masked_bias = original_expert_bias.clone() + masked_bias[expert_threshold:] = 0 + router.expert_bias = masked_bias + try: + return original_routing(logits, **kwargs) + finally: + router.expert_bias = original_expert_bias + return original_routing(logits, **kwargs) + + else: + return original_routing(logits, **kwargs) + + router.routing = wrapped_routing + # Store reference to restore later + router._original_routing = original_routing + self.hook_handles.append(('method_replacement', router, 'routing')) + + def set_elasticity_params(self, router_moe_expert=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_moe_expert is not None: + self.current_router_moe_expert = router_moe_expert + + def detach_hooks(self): + """Remove all hooks and restore original methods.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + if isinstance(handle, tuple) and handle[0] == 'method_replacement': + # Restore original method + _, router, method_name = handle + if hasattr(router, '_original_routing'): + router.routing = router._original_routing + delattr(router, '_original_routing') + else: + # Regular hook handle + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +class FlextronMoEElasticityManager: + """ + Manages elasticity for MLP/MoE layers using pure PyTorch hooks. + Now supports both traditional MLP ('-') and MoE ('E') layers with layer skipping. + """ + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters - store the full router outputs + self.current_router_emb = None + # Hook handles for cleanup + self.hook_handles = [] + + def _init_embedding_masks(self): + """Initialize embedding dimension masks.""" + mask_list = [] + for emb_int in self.config.emb_int_list: + mask = torch.zeros(self.config.hidden_size, dtype=torch.bool) + mask[:emb_int] = True + mask_list.append(mask) + self.emb_masks_lookup = { + emb_int: idx for idx, emb_int in enumerate(self.config.emb_int_list) + } + self.emb_masks = torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + + def initialize_masks(self, moe_module): + """Initialize masks based on the MoE module configuration.""" + if not self.enabled: + return + + self.moe_module = moe_module + self._init_embedding_masks() + + def attach_hooks(self, moe_module): + """Attach hooks to MLP/MoE layer for layer skipping only.""" + if not self.enabled: + return + + self.initialize_masks(moe_module) + + def output_mask_hook(module, input, output): + + if self.config.flextron and self.current_router_emb is not None: + out, out_bias = output + + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_out = out * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_out = out * mask[None, None, :] + masked_out = masked_out * router_emb_logits + + 
return (masked_out, out_bias) + return output + + # Attach the output hook + + output_handle = moe_module.register_forward_hook(output_mask_hook) + self.hook_handles.append(output_handle) + + def set_elasticity_params(self, router_emb=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + + def detach_hooks(self): + """Remove all hooks.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +class FlextronGroupedMLPElasticityManager: + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + self.mlp_idx = self.config.hybrid_layer_pattern[: self.layer_idx + 1].count('E') - 1 + + if not self.enabled: + return + + self.current_router_mlp = None + self.current_router_emb = None + + self.hook_handles = [] + + def _init_embedding_masks(self): + """Initialize embedding dimension masks.""" + mask_list = [] + for emb_int in self.config.emb_int_list: + mask = torch.zeros(self.config.hidden_size, dtype=torch.bool) + mask[:emb_int] = True + mask_list.append(mask) + self.emb_masks_lookup = { + emb_int: idx for idx, emb_int in enumerate(self.config.emb_int_list) + } + self.emb_masks = torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + + def _init_mlp_masks(self): + """Initialize MLP-specific masks.""" + mask_list = [] + list_mlp_mask = list(set(self.config.mlp_int_list)) + list_mlp_mask.sort(reverse=True) + for mlp_int in list_mlp_mask: + mask_temp = torch.zeros(self.config.ffn_hidden_size, dtype=torch.bool) + mask_temp[:mlp_int] = True + mask_list.append(mask_temp) + mask_list = [item.to(mask_list[0].device) for item in mask_list] + self.mlp_intermediate_masks = ( + torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + ) + self.mlp_intermediate_masks_lookup = { + mlp_int: idx for idx, mlp_int in enumerate(list_mlp_mask) + } + + def initialize_masks(self, mlp_module): + """Initialize masks based on the MLP configuration.""" + if not self.enabled: + return + + self.mlp_module = mlp_module + self._init_embedding_masks() + self._init_mlp_masks() + + def attach_hooks(self, mlp_module): + """Attach hooks to MLP following the original flextron_os pattern.""" + if not self.enabled: + return + + self.mlp_module = mlp_module + + emb_effective_per_list = [x / self.config.hidden_size for x in self.config.emb_int_list] + + # Setup hook - runs first to initialize masks for this forward pass + def setup_masks_hook(module, input): + if self.config.flextron: + self._init_embedding_masks() + self._init_mlp_masks() + return input + + # Cleanup hook - runs last to remove masks after forward pass + def cleanup_masks_hook(module, input, output): + if self.config.flextron: + self.emb_masks = None + self.mlp_intermediate_masks = None + self.emb_masks_lookup = {} + self.mlp_intermediate_masks_lookup = {} + return output + + # IMPORTANT: Register setup hook FIRST + setup_handle = mlp_module.register_forward_pre_hook(setup_masks_hook) + self.hook_handles.append(setup_handle) + + # Hook 1: Input masking and router_emb processing + def input_mask_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + hidden_states = input[0] + + # Apply embedding mask + if self.config.soft_mask: + 
soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_input = hidden_states * mask[None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_input = hidden_states * mask[None, :] + masked_input = masked_input * router_emb_logits + + # Process router_mlp logic here. Both attributes are read + # by fc1_post_hook only when current_router_mlp is not None, + # so we only assign them in that branch (avoids + # UnboundLocalError on mlp_per when the router produced no + # MLP output this step). + if self.current_router_mlp is not None: + if self.config.flex_hetero_ffn: + router_weights = torch.max(self.current_router_mlp[0][self.mlp_idx]) + mlp_per = self.current_router_mlp[1][self.mlp_idx] + else: + router_weights, mlp_per = ( + torch.max(self.current_router_mlp[0]), + self.current_router_mlp[1], + ) + module._flextron_router_weights = router_weights + module._flextron_mlp_per = mlp_per + + return tuple([masked_input] + list(input[1:])) + return input + + # Hook 2: Linear FC1 post-hook for router scaling and masking + def fc1_post_hook(module, input, output): + + # Apply router_emb scaling and MLP masking. Both router_emb and + # router_mlp must be set: the body reads current_router_emb[0] + # for the emb scaling and current_router_mlp[0] for the mask. + # Today they're always set together by update_hook_elasticity_params, + # but guarding both makes the precondition explicit. + if ( + self.config.flextron + and self.current_router_mlp is not None + and self.current_router_emb is not None + ): + intermediate_parallel, bias_parallel = output + if self.config.soft_mask: + soft_intermediate_parallel = torch.zeros_like(intermediate_parallel) + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_intermediate_parallel.add_(intermediate_parallel * per_logit) + intermediate_parallel = soft_intermediate_parallel + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + intermediate_parallel = intermediate_parallel * router_emb_logits + + # # Apply MLP masking and router weights + + mlp_per = mlp_module._flextron_mlp_per + router_weights = getattr(mlp_module, '_flextron_router_weights', None) + + # Apply masking + if self.config.soft_mask: + soft_mask = torch.zeros( + self.mlp_intermediate_masks[0].shape, + dtype=torch.bfloat16, + device=self.mlp_intermediate_masks[0].device, + ) + if self.config.flex_hetero_ffn: + for mask, per_logit in zip( + self.mlp_intermediate_masks, self.current_router_mlp[0][self.mlp_idx] + ): + soft_mask.add_(mask * per_logit) + else: + for mask, per_logit in zip( + self.mlp_intermediate_masks, self.current_router_mlp[0] + ): + soft_mask.add_(mask * per_logit) + mask = soft_mask + else: + mask = self.mlp_intermediate_masks[self.mlp_intermediate_masks_lookup[mlp_per]] + + world_size = parallel_state.get_expert_tensor_parallel_world_size() + mask_list = split_tensor_along_last_dim(mask, world_size) + rank = parallel_state.get_expert_tensor_parallel_rank() + + mask = mask_list[rank].contiguous() + + intermediate_parallel = ( + intermediate_parallel * mask.to(device=intermediate_parallel.device)[None, :] + ) + if router_weights is not None and not 
self.config.soft_mask: + intermediate_parallel = intermediate_parallel * router_weights + + module.eps = self.config.layernorm_epsilon + + return (intermediate_parallel, bias_parallel) + return output + + # Hook 3: Final output masking + def output_mask_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + out, out_bias = output + + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_out = out * mask[None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_out = out * mask[None, :] + masked_out = masked_out * router_emb_logits + + return (masked_out, out_bias) + return output + + # Hook 1: Input masking and router_emb processing + main_handle = mlp_module.register_forward_pre_hook(input_mask_hook) + self.hook_handles.append(main_handle) + + # Hook 2: Linear FC1 pre-hook for eps modification + fc1_post_handle = mlp_module.linear_fc1.register_forward_hook(fc1_post_hook) + self.hook_handles.append(fc1_post_handle) + + # Hook 3: Final output masking + output_handle = mlp_module.register_forward_hook(output_mask_hook) + self.hook_handles.append(output_handle) + + # Cleanup hook - runs last to remove masks after forward pass + cleanup_handle = mlp_module.register_forward_hook(cleanup_masks_hook) + self.hook_handles.append(cleanup_handle) + + def set_elasticity_params(self, router_emb=None, router_mlp=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + + if router_mlp is not None: + self.current_router_mlp = router_mlp + + def detach_hooks(self): + """Remove all hooks.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +class FlextronAttentionElasticityManager: + """ + Manages elasticity for Attention using pure PyTorch hooks. + Based on the exact implementation from original flextron_os Attention. 
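+
+    Illustrative usage (a hedged sketch; ``attention_module`` is assumed to
+    expose a ``linear_qkv`` submodule, which ``attach_hooks`` requires):
+
+        manager = add_flextron_attention_elasticity(attention_module, config, layer_idx=2)
+        manager.set_elasticity_params(router_emb=router_emb)
+        ...                          # attention forward; hooks mask inputs/outputs
+        manager.detach_hooks()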
+ """ + + def __init__(self, config, layer_idx=0): + self.config = config + self.layer_idx = layer_idx + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters - store the full router outputs + self.current_router_emb = None + + # Hook handles for cleanup + self.hook_handles = [] + + def _init_embedding_masks(self): + """Initialize embedding dimension masks.""" + mask_list = [] + for emb_int in self.config.emb_int_list: + mask = torch.zeros(self.config.hidden_size, dtype=torch.bool) + mask[:emb_int] = True + mask_list.append(mask) + self.emb_masks_lookup = { + emb_int: idx for idx, emb_int in enumerate(self.config.emb_int_list) + } + self.emb_masks = torch.stack(mask_list, dim=0).to(device='cuda').to(dtype=torch.bfloat16) + + def attach_hooks(self, attention_module): + """Attach hooks to Attention following the original flextron_os pattern.""" + if not self.enabled: + return + + self.attention_module = attention_module + + emb_effective_per_list = [x / self.config.hidden_size for x in self.config.emb_int_list] + + # Setup hook - runs first to initialize masks for this forward pass + def setup_masks_hook(module, input): + if self.config.flextron: + self._init_embedding_masks() + return input + + # Cleanup hook - runs last to remove masks after forward pass + def cleanup_masks_hook(module, input, output): + if self.config.flextron: + self.emb_masks = None + self.emb_masks_lookup = {} + return output + + # IMPORTANT: Register setup hook FIRST + setup_handle = attention_module.register_forward_pre_hook(setup_masks_hook) + self.hook_handles.append(setup_handle) + + # Hook 1: Input masking and router_emb processing + def input_mask_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + hidden_states = input[0] + + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_input = hidden_states * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_input = ( + hidden_states * mask.to(device=hidden_states.device)[None, None, :] + ) + masked_input = masked_input * router_emb_logits + + return tuple([masked_input] + list(input[1:])) + return input + + # Hook 2: Linear QKV pre-hook for eps modification + def linear_qkv_pre_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + # Set eps on linear_qkv (fused layernorm) + if self.config.soft_mask: + soft_eps = 0 + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_eps += self.config.layernorm_epsilon * emb_per * per_logit + module.eps = soft_eps.float().detach().item() + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + emb_effective_per = emb_choice / self.config.hidden_size + module.eps = self.config.layernorm_epsilon * emb_effective_per + + return input + + # Hook 3: Linear QKV post-hook for scaling + def linear_qkv_post_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + query_key_value, bias = output + if self.config.soft_mask: + soft_query_key_value = 
torch.zeros_like(query_key_value) + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_query_key_value.add_(query_key_value * (emb_per**0.5) * per_logit) + scaled_output = soft_query_key_value + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + emb_effective_per = emb_choice / self.config.hidden_size + scaled_output = query_key_value * router_emb_logits * (emb_effective_per**0.5) + module.eps = self.config.layernorm_epsilon + return (scaled_output, bias) + return output + + # Hook 5: Final output masking + def output_mask_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + out, out_bias = output + + # Apply embedding mask + if self.config.soft_mask: + soft_mask = torch.zeros( + self.emb_masks[0].shape, + dtype=torch.bfloat16, + device=self.emb_masks[0].device, + ) + for mask, per_logit in zip(self.emb_masks, self.current_router_emb[0]): + soft_mask.add_(mask * per_logit) + mask = soft_mask + masked_out = out * mask[None, None, :] + else: + router_emb_logits, emb_choice = ( + torch.max(self.current_router_emb[0]), + self.current_router_emb[1], + ) + mask = self.emb_masks[self.emb_masks_lookup[emb_choice]] + masked_out = out * mask[None, None, :] + masked_out = masked_out * router_emb_logits + + return (masked_out, out_bias) + return output + + # Hook 1: Input masking and router_emb processing + main_handle = attention_module.register_forward_pre_hook(input_mask_hook) + self.hook_handles.append(main_handle) + + # Hook 2&3: Linear QKV pre-hook for eps modification + qkv_pre_handle = attention_module.linear_qkv.register_forward_pre_hook(linear_qkv_pre_hook) + qkv_post_handle = attention_module.linear_qkv.register_forward_hook(linear_qkv_post_hook) + self.hook_handles.append(qkv_pre_handle) + self.hook_handles.append(qkv_post_handle) + + # Final output masking + output_handle = attention_module.register_forward_hook(output_mask_hook) + self.hook_handles.append(output_handle) + + # Cleanup hook - runs last to remove masks after forward pass + cleanup_handle = attention_module.register_forward_hook(cleanup_masks_hook) + self.hook_handles.append(cleanup_handle) + + def set_elasticity_params(self, router_emb=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + + def detach_hooks(self): + """Remove all hooks.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +class FlextronStackElasticityManager: + """ + Manages elasticity for HybridStack using pure PyTorch hooks. + Handles input masking and final norm scaling. 
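+
+    Illustrative usage (a hedged sketch; the stack is assumed to expose a
+    ``final_norm`` module, the only submodule ``attach_hooks`` touches):
+
+        manager = add_flextron_stack_elasticity(model.decoder, config)
+        manager.set_elasticity_params(router_emb=router_emb)
+        ...                          # decoder forward; final_norm eps/scale adapt
+        manager.detach_hooks()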
+ """ + + def __init__(self, config): + self.config = config + self.enabled = getattr(config, 'flextron', False) + + if not self.enabled: + return + + # Current elasticity parameters + self.current_emb_choice = self.config.hidden_size + self.current_router_emb = None + + # Hook handles for cleanup + self.hook_handles = [] + + # Pre-computed masks + self.emb_masks = None + self.emb_masks_lookup = {} + + def initialize_masks(self, stack): + """Initialize masks based on the stack configuration.""" + if not self.enabled: + return + + self.stack = stack + + def attach_hooks(self, stack): + """Attach hooks to HybridStack.""" + if not self.enabled: + return + + self.initialize_masks(stack) + + emb_effective_per_list = [x / self.config.hidden_size for x in self.config.emb_int_list] + + # Hook 1: Final norm pre-hook for eps modification + def final_norm_pre_hook(module, input): + if self.config.flextron and self.current_router_emb is not None: + # Modify eps for this forward pass + if self.config.soft_mask: + soft_eps = 0 + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_eps += self.config.layernorm_epsilon * emb_per * per_logit + module.eps = soft_eps.float().detach().item() + else: + emb_choice = self.current_emb_choice + emb_effective_per = emb_choice / self.config.hidden_size + module.eps = self.config.layernorm_epsilon * emb_effective_per + + return input + + # Hook 2: Final norm post-hook for scaling and eps restoration + def final_norm_post_hook(module, input, output): + if self.config.flextron and self.current_router_emb is not None: + # Restore original eps + module.eps = self.config.layernorm_epsilon + + # Apply scaling + if self.config.soft_mask: + soft_scaled_output = torch.zeros_like(output) + for emb_per, per_logit in zip( + emb_effective_per_list, self.current_router_emb[0] + ): + soft_scaled_output.add_(output * (emb_per**0.5) * per_logit) + scaled_output = soft_scaled_output + else: + emb_choice = self.current_emb_choice + emb_effective_per = emb_choice / self.config.hidden_size + router_emb_logits = torch.max(self.current_router_emb[0]) + scaled_output = output * (emb_effective_per**0.5) * router_emb_logits + return scaled_output + + return output + + # Hooks for final norm if it exists + final_norm_pre_handle = stack.final_norm.register_forward_pre_hook(final_norm_pre_hook) + final_norm_post_handle = stack.final_norm.register_forward_hook(final_norm_post_hook) + self.hook_handles.append(final_norm_pre_handle) + self.hook_handles.append(final_norm_post_handle) + + def set_elasticity_params(self, router_emb=None, **kwargs): + """Set current elasticity parameters that will be used by hooks.""" + if router_emb is not None: + self.current_router_emb = router_emb + self.current_emb_choice = router_emb[1] + + def detach_hooks(self): + """Remove all hooks and restore original forward method.""" + if not hasattr(self, 'hook_handles'): + return + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def __del__(self): + """Cleanup hooks when manager is destroyed.""" + self.detach_hooks() + + +def add_flextron_mamba_elasticity(mamba_mixer, config, layer_idx=0): + """ + Add elasticity to a MambaMixer using hooks. 
+ + Args: + mamba_mixer: The MambaMixer instance to add elasticity to + config: Configuration object with flextron settings + layer_idx: Index of this layer in the hybrid pattern + + Returns: + FlextronMambaElasticityManager: Manager object to control elasticity + """ + if hasattr(mamba_mixer, '_flextron_manager'): + return mamba_mixer._flextron_manager + manager = FlextronMambaElasticityManager(config, layer_idx) + manager.attach_hooks(mamba_mixer) + + # Store manager reference on the mixer for easy access + mamba_mixer._flextron_manager = manager + + return manager + + +def add_flextron_transformer_layer_elasticity(transformer_layer, config, layer_idx=0): + """ + Add elasticity to a TransformerLayer using hooks. + + Args: + transformer_layer: The TransformerLayer instance to add elasticity to + config: Configuration object with flextron settings + layer_idx: Index of this layer in the hybrid pattern + + Returns: + FlextronTransformerLayerElasticityManager: Manager object to control elasticity + """ + if hasattr(transformer_layer, '_flextron_layer_manager'): + return transformer_layer._flextron_layer_manager + manager = FlextronTransformerLayerElasticityManager(config, layer_idx) + manager.attach_hooks(transformer_layer) + + # Store manager reference on the layer for easy access + transformer_layer._flextron_layer_manager = manager + + return manager + + +def add_flextron_topk_router_elasticity(router, config, layer_idx=0): + """ + Add elasticity to a TopKRouter using hooks. + + Args: + router: The TopKRouter instance to add elasticity to + config: Configuration object with flextron settings + layer_idx: Index of this layer in the hybrid pattern + + Returns: + FlextronTopKRouterElasticityManager: Manager object to control elasticity + """ + if hasattr(router, '_flextron_router_manager'): + return router._flextron_router_manager + manager = FlextronTopKRouterElasticityManager(config, layer_idx) + manager.attach_hooks(router) + + # Store manager reference on the router for easy access + router._flextron_router_manager = manager + + return manager + + +def add_flextron_moe_elasticity(moe_module, config, layer_idx=0): + """ + Add elasticity to a MoE using hooks. + + Args: + moe_module: The MoE instance to add elasticity to + config: Configuration object with flextron settings + layer_idx: Index of this layer in the hybrid pattern + + Returns: + FlextronMoEElasticityManager: Manager object to control elasticity + """ + if hasattr(moe_module, '_flextron_manager'): + return moe_module._flextron_manager + manager = FlextronMoEElasticityManager(config, layer_idx) + manager.attach_hooks(moe_module) + + # Store manager reference on the module for easy access + moe_module._flextron_manager = manager + + return manager + + +def add_flextron_grouped_mlp_elasticity(grouped_mlp_module, config, layer_idx=0): + """ + Add elasticity to a GroupedMLP using hooks. + """ + if hasattr(grouped_mlp_module, '_flextron_manager'): + return grouped_mlp_module._flextron_manager + manager = FlextronGroupedMLPElasticityManager(config, layer_idx) + manager.attach_hooks(grouped_mlp_module) + + # Store manager reference on the module for easy access + grouped_mlp_module._flextron_manager = manager + + return manager + + +def add_flextron_attention_elasticity(attention_module, config, layer_idx=0): + """ + Add elasticity to an Attention module using hooks. 
+ + Args: + attention_module: The Attention instance to add elasticity to + config: Configuration object with flextron settings + layer_idx: Index of this layer in the hybrid pattern + + Returns: + FlextronAttentionElasticityManager: Manager object to control elasticity + """ + if hasattr(attention_module, '_flextron_manager'): + return attention_module._flextron_manager + manager = FlextronAttentionElasticityManager(config, layer_idx) + manager.attach_hooks(attention_module) + + # Store manager reference on the module for easy access + attention_module._flextron_manager = manager + + return manager + + +def add_flextron_stack_elasticity(stack, config): + """ + Add elasticity to a HybridStack using hooks. + + Args: + stack: The HybridStack instance to add elasticity to + config: Configuration object with flextron settings + + Returns: + FlextronStackElasticityManager: Manager object to control elasticity + """ + if hasattr(stack, '_flextron_manager'): + return stack._flextron_manager + manager = FlextronStackElasticityManager(config) + manager.attach_hooks(stack) + + # Store manager reference on the stack for easy access + stack._flextron_manager = manager + + return manager + + +# Convenience function to apply elasticity to all modules in a model +def apply_flextron_elasticity_to_model(model, config): + """Apply elasticity to all MambaMixer, MLP/MoE, and Attention instances in a model based on hybrid pattern.""" + managers = [] + + if not hasattr(config, 'hybrid_layer_pattern') or not config.hybrid_layer_pattern: + # No hybrid pattern, skip elasticity setup + return managers + + hybrid_pattern = config.hybrid_layer_pattern + + # Find decoder layers + decoder = getattr(model, 'decoder', None) + layers = getattr(decoder, 'layers', None) + + if decoder is None or layers is None: + return managers + + # Apply elasticity per layer based on hybrid pattern + for layer_idx, layer_char in enumerate(hybrid_pattern): + if layer_idx >= len(layers): + break + + layer = layers[layer_idx] + + if layer_char == 'E': # MoE layer (treated as MLP replacement) + if ( + 'MoETransformerLayer' == layer.__class__.__name__ + or 'TransformerLayer' == layer.__class__.__name__ + ): + layer_manager = add_flextron_transformer_layer_elasticity(layer, config, layer_idx) + managers.append(layer_manager) + + # Find MoELayer module in this layer + moe_module = None + for name, module in layer.named_modules(): + if 'MoELayer' == module.__class__.__name__: + moe_module = module + break + if moe_module is not None: + manager = add_flextron_moe_elasticity(moe_module, config, layer_idx) + managers.append(manager) + + # Also add router elasticity to the MoE router + router_module = None + for name, module in moe_module.named_modules(): + if 'TopKRouter' == module.__class__.__name__: + router_module = module + break + if router_module is not None: + router_manager = add_flextron_topk_router_elasticity( + router_module, config, layer_idx + ) + managers.append(router_manager) + + # Find TEGroupedMLP module in this layer + moe_module = None + for name, module in layer.named_modules(): + if 'TEGroupedMLP' == module.__class__.__name__: + moe_module = module + break + if moe_module is not None: + manager = add_flextron_grouped_mlp_elasticity(moe_module, config, layer_idx) + managers.append(manager) + + elif layer_char == 'M': # Mamba layer + mamba_module = None + for name, module in layer.named_modules(): + if 'MambaMixer' == module.__class__.__name__: + mamba_module = module + break + if mamba_module is not None: + manager = 
add_flextron_mamba_elasticity(mamba_module, config, layer_idx) + managers.append(manager) + + elif layer_char == '*': # Attention layer (TransformerLayer) + attention_module = None + for name, module in layer.named_modules(): + if 'SelfAttention' == module.__class__.__name__: + attention_module = module + break + if attention_module is not None: + manager = add_flextron_attention_elasticity(attention_module, config, layer_idx) + managers.append(manager) + + # Also add hooks to HybridStack if present + if hasattr(model, 'decoder') and hasattr(model.decoder, 'final_norm'): + stack_manager = add_flextron_stack_elasticity(model.decoder, config) + managers.append(stack_manager) + + # Store all managers on the model + model._flextron_managers = managers + return managers diff --git a/megatron/elastification/flextron_utils.py b/megatron/elastification/flextron_utils.py new file mode 100644 index 00000000000..d165bf678cb --- /dev/null +++ b/megatron/elastification/flextron_utils.py @@ -0,0 +1,471 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Flextron Utilities + +Provides setup and configuration functions for Flextron elasticity. +Extracted from HybridModel to keep the core model clean. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from megatron.core import mpu, parallel_state +from megatron.elastification.arguments import convert_per_lists_to_int_lists +from megatron.elastification.flextron_config import inject_flextron_config +from megatron.elastification.flextron_elasticity_hooks import apply_flextron_elasticity_to_model +from megatron.elastification.memory_config import MemoryConfig, load_memory_config +from megatron.elastification.router.flex_budget_utils import ( + get_memory_footprint, + get_num_parameters, +) +from megatron.elastification.router.hybrid_flex_router import FlextronRouter +from megatron.training import get_args + + +class FlextronModelManager: + """ + Manages Flextron functionality for a model. + Handles router, budget calculations, and loss functions. 
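+
+    Illustrative usage (a hedged sketch; ``model`` is assumed to be a built
+    HybridModel and ``budget`` a fraction of the full parent budget, where 1.0
+    denotes the unreduced model):
+
+        manager = setup_flextron_model(model)
+        flextron_kwargs, loss_func = manager.process_router_output(budget)
+        manager.update_hook_elasticity_params(flextron_kwargs)
+        budget_loss = loss_func(flextron_kwargs, budget) if loss_func else None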
+ """ + + def __init__(self, model, config): + self.model = model + self.config = config + inject_flextron_config(get_args(), config) + convert_per_lists_to_int_lists(config) + config.hybrid_layer_pattern = getattr(model, 'hybrid_layer_pattern', '') + self.router = None + self.budget_type = getattr(config, 'budget_type', 'param') + + # Load memory quantization profile from args + args = get_args() + self.memory_config = load_memory_config(args) + + # Budget calculation attributes + self.all_param = None + self.total_memory = None + + # Hook managers + self.hook_managers = [] + + def setup_router(self): + """Initialize the Flextron router if enabled.""" + if getattr(self.config, 'enable_router', False): # and self.model.pre_process: + self.router = FlextronRouter(config=self.config) + + # Make router name pipeline-stage-aware to avoid naming conflicts in PP>1 + pp_rank = mpu.get_pipeline_model_parallel_rank() + router_name = f"router_pp{pp_rank}" + + # Set the router with pipeline-specific name + setattr(self.model, router_name, self.router) + self.model.router = self.router + else: + self.model.router = None + + def setup_budget_functions(self): + """Setup budget calculation functions based on budget type.""" + self._setup_param_loss_func() + + if self.budget_type == 'mem': + self._setup_memory_loss_func() + + def setup_hooks(self): + """Setup elasticity hooks on the model.""" + if getattr(self.config, 'flextron', False): + self.hook_managers = apply_flextron_elasticity_to_model(self.model, self.config) + + def _setup_param_loss_func(self): + """Setup parameter counting for budget calculations.""" + + self.all_param, self.active_param = torch.tensor( + get_num_parameters( + hybrid_pattern=self.model.hybrid_layer_pattern, + mamba_num_heads=self.config.mamba_num_heads, + mamba_d_head=self.config.mamba_head_dim, + mamba_d_state=self.config.mamba_state_dim, + num_attention_heads=self.config.num_attention_heads, + num_query_groups=self.config.num_query_groups, + ffn_hidden_size=self.config.ffn_hidden_size, + hidden_size=self.config.hidden_size, + kv_channels=self.config.kv_channels, + vocab_size=self.model.vocab_size, + tied_vocab=self.model.share_embeddings_and_output_weights, + num_experts=self.config.num_moe_experts, + shared_expert_intermediate_size=self.config.moe_shared_expert_intermediate_size, + moe_router_topk=self.config.moe_router_topk, + ), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + + def _setup_memory_loss_func(self): + """Setup memory loss function by calculating the baseline memory footprint.""" + self.total_memory = ( + get_memory_footprint( + hybrid_pattern=self.model.hybrid_layer_pattern, + mamba_num_heads=self.config.mamba_num_heads, + mamba_d_head=self.config.mamba_head_dim, + mamba_d_state=self.config.mamba_state_dim, + num_attention_heads=self.config.num_attention_heads, + num_query_groups=self.config.num_query_groups, + ffn_hidden_size=self.config.ffn_hidden_size, + hidden_size=self.config.hidden_size, + kv_channels=self.config.kv_channels, + vocab_size=self.model.vocab_size, + tied_vocab=self.model.share_embeddings_and_output_weights, + mem_infer_seq_len=self.config.mem_infer_seq_len, + mem_batch_size=self.config.mem_batch_size, + prefill_chunk_size=self.config.prefill_chunk_size, + moe_num_experts=self.config.num_moe_experts, + shared_expert_intermediate_size=self.config.moe_shared_expert_intermediate_size, + moe_router_topk=self.config.moe_router_topk, + memory_config=self.memory_config, + ) + .float() + .to(torch.cuda.current_device()) + ) + + 
print( + f"Total baseline memory footprint: {self.total_memory.item():.4f} GB " + f"(profile={getattr(get_args(), 'memory_profile', 'bf16')}, " + f"param_target={self.memory_config.param_budget_target})" + ) + + def budget_loss_func(self, flextron_kwargs, budget_item=0): + """Calculate budget-based loss exactly as in the original implementation.""" + dtype, device = ( + flextron_kwargs['router_mlp'][0].dtype, + flextron_kwargs['router_mlp'][0].device, + ) + + flex_mamba_num_head = flextron_kwargs['router_mamba'][0] @ torch.tensor( + self.config.mamba_int_list, dtype=dtype, device=device + ) + flex_hidden_size = flextron_kwargs['router_emb'][0] @ torch.tensor( + self.config.emb_int_list, dtype=dtype, device=device + ) + flex_ffn_hidden_size = flextron_kwargs['router_mlp'][0] @ torch.tensor( + self.config.mlp_int_list, dtype=dtype, device=device + ) + flex_moe_expert = flextron_kwargs['router_moe_expert'][0] @ torch.tensor( + self.config.moe_expert_int_list, dtype=dtype, device=device + ) + # Attention heads are not router-controlled; pass the parent value through. + num_attention_heads = self.config.num_attention_heads + + if self.config.add_skipping: + logit_skip_selected = torch.cumsum(flextron_kwargs['router_skip'][0], 0)[:-1] + logit_skip_all = torch.ones(self.config.num_layers).to(dtype=dtype, device=device) + logit_skip_all[self.config.layer_ranking_list] = logit_skip_selected + + mamba_idxs = [ + i for i, char in enumerate(self.model.hybrid_layer_pattern) if char == 'M' + ] + mamba_idxs = torch.tensor(mamba_idxs, dtype=torch.long) + flex_mamba_num_head = flex_mamba_num_head * logit_skip_all[mamba_idxs] + flex_mamba_num_head = flex_mamba_num_head.unsqueeze(-1) + + head_idxs = [i for i, char in enumerate(self.model.hybrid_layer_pattern) if char == '*'] + head_idxs = torch.tensor(head_idxs, dtype=torch.long) + num_attention_heads = num_attention_heads * logit_skip_all[head_idxs] + num_attention_heads = num_attention_heads.unsqueeze(-1) + + moe_idxs = [i for i, char in enumerate(self.model.hybrid_layer_pattern) if char == 'E'] + moe_idxs = torch.tensor(moe_idxs, dtype=torch.long) + flex_ffn_hidden_size = flex_ffn_hidden_size * logit_skip_all[moe_idxs] + flex_ffn_hidden_size = flex_ffn_hidden_size.unsqueeze(-1) + + flex_moe_expert = flex_moe_expert * logit_skip_all[moe_idxs] + flex_moe_expert = flex_moe_expert.unsqueeze(-1) + + if not self.config.flex_hetero_ffn and not self.config.add_skipping: + flex_ffn_hidden_size = flex_ffn_hidden_size.unsqueeze(-1) + if not self.config.flex_hetero_mamba and not self.config.add_skipping: + flex_mamba_num_head = flex_mamba_num_head.unsqueeze(-1) + if not self.config.flex_hetero_moe_expert and not self.config.add_skipping: + flex_moe_expert = flex_moe_expert.unsqueeze(-1) + + current_param_all, current_param_active = get_num_parameters( + hybrid_pattern=self.model.hybrid_layer_pattern, + mamba_num_heads=flex_mamba_num_head.float(), + mamba_d_head=self.config.mamba_head_dim, + mamba_d_state=self.config.mamba_state_dim, + num_attention_heads=( + num_attention_heads.float() + if isinstance(num_attention_heads, torch.Tensor) + else num_attention_heads + ), + num_query_groups=self.config.num_query_groups, + ffn_hidden_size=flex_ffn_hidden_size.float(), + hidden_size=flex_hidden_size.unsqueeze(-1).float(), + kv_channels=self.config.kv_channels, + vocab_size=self.model.vocab_size, + tied_vocab=self.model.share_embeddings_and_output_weights, + num_experts=flex_moe_expert.float(), + 
shared_expert_intermediate_size=self.config.moe_shared_expert_intermediate_size, + moe_router_topk=self.config.moe_router_topk, + ) + + if self.config.budget_type == 'param': + if self.memory_config.param_budget_target == 'active': + diff = abs(current_param_active / (budget_item * self.active_param) - 1) + else: + diff = abs(current_param_all / (budget_item * self.all_param) - 1) + elif self.config.budget_type == 'mem': + current_mem = get_memory_footprint( + hybrid_pattern=self.model.hybrid_layer_pattern, + mamba_num_heads=flex_mamba_num_head.float(), + mamba_d_head=self.config.mamba_head_dim, + mamba_d_state=self.config.mamba_state_dim, + num_attention_heads=( + num_attention_heads.float() + if isinstance(num_attention_heads, torch.Tensor) + else num_attention_heads + ), + num_query_groups=self.config.num_query_groups, + ffn_hidden_size=flex_ffn_hidden_size.float(), + hidden_size=flex_hidden_size.unsqueeze(-1).float(), + kv_channels=self.config.kv_channels, + vocab_size=self.model.vocab_size, + tied_vocab=self.model.share_embeddings_and_output_weights, + mem_infer_seq_len=self.config.mem_infer_seq_len, + mem_batch_size=self.config.mem_batch_size, + prefill_chunk_size=self.config.prefill_chunk_size, + moe_num_experts=flex_moe_expert.float(), + shared_expert_intermediate_size=self.config.moe_shared_expert_intermediate_size, + moe_router_topk=self.config.moe_router_topk, + memory_config=self.memory_config, + ).float() + diff = abs(current_mem / budget_item - 1) + else: + raise ValueError(f"Invalid budget type: {self.config.budget_type}") + + # return current_param, {} + if budget_item != 1.0 and diff < 0.05: + diff = diff * 0.0 + + # if getattr(self.config, 'disable_budget', False): + # diff = diff * 0.0 + + if budget_item == 1.0: + if self.config.flex_hetero_moe_expert: + label_moe_expert = torch.zeros_like(flextron_kwargs['router_moe_expert'][0]) + label_moe_expert[:, 0] = 1.0 + mse_loss_moe_expert = F.mse_loss( + flextron_kwargs['router_moe_expert'][0], label_moe_expert + ) + else: + label_moe_expert = torch.zeros_like(flextron_kwargs['router_moe_expert'][0]) + label_moe_expert[0] = 1.0 + mse_loss_moe_expert = F.mse_loss( + flextron_kwargs['router_moe_expert'][0], label_moe_expert + ) + + if self.config.flex_hetero_mamba: + label_mamba = torch.zeros_like(flextron_kwargs['router_mamba'][0]) + label_mamba[:, 0] = 1.0 + mse_loss_mamba = F.mse_loss(flextron_kwargs['router_mamba'][0], label_mamba) + else: + label_mamba = torch.zeros_like(flextron_kwargs['router_mamba'][0]) + label_mamba[0] = 1.0 + mse_loss_mamba = F.mse_loss(flextron_kwargs['router_mamba'][0], label_mamba) + + if self.config.flex_hetero_ffn: + label_mlp = torch.zeros_like(flextron_kwargs['router_mlp'][0]) + label_mlp[:, 0] = 1.0 + mse_loss_mlp = F.mse_loss(flextron_kwargs['router_mlp'][0], label_mlp) + else: + label_mlp = torch.zeros_like(flextron_kwargs['router_mlp'][0]) + label_mlp[0] = 1.0 + mse_loss_mlp = F.mse_loss(flextron_kwargs['router_mlp'][0], label_mlp) + + if self.config.add_skipping: + label_skip = torch.zeros_like(flextron_kwargs['router_skip'][0]) + label_skip[0] = 1.0 + mse_loss_skip = F.mse_loss(flextron_kwargs['router_skip'][0], label_skip) + else: + mse_loss_skip = 0.0 + + label_emb = torch.zeros_like(flextron_kwargs['router_emb'][0]) + label_emb[0] = 1.0 + mse_loss_emb = F.mse_loss(flextron_kwargs['router_emb'][0], label_emb) + + diff += 10 * ( + mse_loss_mamba + mse_loss_mlp + mse_loss_moe_expert + mse_loss_skip + mse_loss_emb + ) + + return diff.bfloat16(), {} + + def get_loss_func(self): + """Get the 
budget loss function.""" + return self.budget_loss_func + + def process_router_output(self, budget_item): + """Process router output and return flextron_kwargs.""" + if self.router is None: + return {}, None + + (router_mlp, router_skip, router_emb, router_mamba, router_moe_expert) = self.router( + budget_item + ) + + flextron_kwargs = { + 'router_mlp': router_mlp, + 'router_skip': router_skip, + 'router_emb': router_emb, + 'router_mamba': router_mamba, + 'router_moe_expert': router_moe_expert, + } + + return flextron_kwargs, self.get_loss_func() + + def update_hook_elasticity_params(self, flextron_kwargs): + """Update elasticity parameters in all hook managers.""" + if not self.hook_managers: + return + + # Extract elasticity parameters from router outputs + router_emb = flextron_kwargs.get('router_emb') + router_mamba = flextron_kwargs.get('router_mamba') + router_mlp = flextron_kwargs.get('router_mlp') + router_moe_expert = flextron_kwargs.get('router_moe_expert') + router_skip = flextron_kwargs.get('router_skip') # General layer skipping + + # Update all hook managers with router outputs directly + for manager in self.hook_managers: + if hasattr(manager, 'set_elasticity_params'): + manager.set_elasticity_params( + router_emb=router_emb, + router_mamba=router_mamba, + router_mlp=router_mlp, + router_moe_expert=router_moe_expert, + router_skip=router_skip, + ) + + +def setup_flextron_model(model): + """ + Setup Flextron functionality for a model after creation. + + Args: + model: The HybridModel instance + + Returns: + FlextronModelManager: Manager object to handle Flextron operations + """ + manager = FlextronModelManager(model, model.config) + + # Setup all Flextron components + manager.setup_router() + manager.setup_budget_functions() + manager.setup_hooks() + + # Store manager on model for easy access + model._flextron_manager = manager + + return manager + + +def inject_flextron_forward_logic(model): + """ + Inject Flextron-specific forward pass logic into the model. + This replaces the router logic that was previously in HybridModel.forward(). + """ + original_forward = model.forward + + def flextron_forward( + self, + input_ids, + position_ids, + attention_mask, + decoder_input=None, + labels=None, + inference_context=None, + runtime_gather_output=None, + *, + inference_params=None, + **flextron_kwargs, + ): + + # Handle override budget settings + if getattr(self.config, 'override_selected_budget', None) is not None: + assert ( + self.config.is_flex_eval + ), "Override selected budget should only be set in flex eval mode" + # Both branches must populate the 'budget' key — downstream code + # at line 422 reads it unconditionally. Setting budget=1.0 routes + # the override-1.0 case through the regular router-forward path, + # which after training produces near-identity router outputs that + # mask down to the full model. + flextron_kwargs = {'budget': self.config.override_selected_budget[0]} + + # Initialize budget_loss + budget_loss = None + + # Handle router logic if enabled and model has Flextron manager + if ( + hasattr(self, '_flextron_manager') + and self._flextron_manager is not None + and self._flextron_manager.router is not None + ): + # Every step is router-driven, including budget=1.0 (which now + # propagates the identity-MSE regularization that the old + # ``original_model`` kill-switch silently zeroed out). Use + # ``freeze_router`` if you need to train without router gradients. 
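For intuition on the `budget_loss` value threaded through this wrapper: `budget_loss_func` above returns the relative deviation of the implied parameter (or memory) count from the scaled baseline, and sub-unit budgets are forgiven inside a 5% band. A toy numeric sketch (counts are made up):

```python
import torch

baseline_params = 8.0e9                      # hypothetical full-model parameter count
current_params = torch.tensor(4.1e9)         # count implied by the router's soft choice
budget_item = 0.5                            # target: half of the baseline

diff = (current_params / (budget_item * baseline_params) - 1).abs()
if budget_item != 1.0 and diff < 0.05:       # small misses on sub-budgets are forgiven
    diff = diff * 0.0
print(float(diff))                           # 1.025 - 1 = 0.025 < 0.05, so clamped to 0.0
```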
+ budget_item = flextron_kwargs['budget'] + + # Get router output and loss function + flextron_kwargs, loss_func = self._flextron_manager.process_router_output(budget_item) + + # Calculate loss + if loss_func: + budget_loss = loss_func(flextron_kwargs, budget_item) + + # Push router outputs into the elasticity hook managers so masks + # fire on this forward. + self._flextron_manager.update_hook_elasticity_params(flextron_kwargs) + else: + # If no Flextron manager, clear flextron_kwargs to avoid passing unknown args + flextron_kwargs = {} + + # Call original forward with processed flextron_kwargs + result = original_forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_context=inference_context, + runtime_gather_output=runtime_gather_output, + inference_params=inference_params, + ) + + # Handle return values based on training mode and flextron settings + if labels is not None: + loss = result if not isinstance(result, tuple) else result[0] + + if ( + hasattr(self, '_flextron_manager') + and self._flextron_manager is not None + and self._flextron_manager.router is not None + and getattr(self.config, 'flextron', False) + and not getattr(self.config, 'is_flex_eval', False) + ): + if mpu.is_pipeline_last_stage(): + return loss, budget_loss + else: + return loss + else: + # Evaluation mode or non-flextron, return loss only + return loss + else: + # No labels, return logits + return result + + # Replace the forward method + model.forward = flextron_forward.__get__(model, model.__class__) diff --git a/megatron/elastification/loss_func.py b/megatron/elastification/loss_func.py new file mode 100644 index 00000000000..cd0be6df67c --- /dev/null +++ b/megatron/elastification/loss_func.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Flextron loss function(s). + +Combines lm loss with the router's budget loss, optional KD distillation +loss, and per-budget reporting (full-model vs sub-budget breakdown). +""" + +import torch + +from megatron.core import parallel_state +from megatron.core.models.gpt import GPTModel +from megatron.training import get_args +from megatron.training.utils import unwrap_model + + +def _mask_loss(output_tensor, loss_mask): + """Apply mask to the unreduced loss tensor.""" + args = get_args() + if isinstance(output_tensor, tuple) and len(output_tensor) == 2: + (output_tensor, (param_loss, extra_reporting_dict)) = output_tensor + tp_reduce, is_sequence_parallel = False, False + elif isinstance(output_tensor, tuple): + # Special distillation flags indicating whether to perform additional tensor-parallel adjustments. + output_tensor, tp_reduce, is_sequence_parallel = output_tensor + param_loss = None + else: + tp_reduce, is_sequence_parallel = False, False + param_loss = None + + num_tokens = loss_mask.sum().float() + + if param_loss is not None: + if param_loss > 0: + pass + else: + param_loss = -args.router_beta * param_loss + + if is_sequence_parallel: + # Sequence-parallel tensor derived from intermediate activation - need to split loss mask. 
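A standalone toy version of the sequence-dimension split that the next two lines perform (tensor-parallel size and rank are illustrative):

```python
import torch

tp_size, tp_rank = 4, 1                      # illustrative TP size and rank
loss_mask = torch.ones(2, 16)                # [batch, seq]
local_mask = torch.tensor_split(loss_mask, tp_size, dim=1)[tp_rank]
print(local_mask.shape)                      # torch.Size([2, 4])
```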
+ idx = parallel_state.get_tensor_model_parallel_rank() + loss_mask = torch.tensor_split(loss_mask, args.tensor_model_parallel_size, dim=1)[idx] + + losses = output_tensor.view(-1).float() + loss_mask = loss_mask.reshape(-1).float() + loss = torch.sum(losses * loss_mask) + + alpha = args.loss_alpha + if not args.freeze_router and param_loss is not None: + param_loss_item = param_loss[0] * num_tokens * alpha + # add param loss to lm loss + loss += param_loss_item + else: + param_loss_item = None + + if tp_reduce or is_sequence_parallel: + # Losses on parallel tensors require extra all-reduce to sync across MP ranks. + torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group()) + + if param_loss_item is not None: + return loss, param_loss_item + else: + return loss + + +def loss_func( + loss_mask: torch.Tensor, + output_tensor: torch.Tensor, + model: GPTModel, + selected_budget: float = None, +): + """Loss function (with KD Loss support). + + Args: + loss_mask (Tensor): Used to mask out some portions of the loss + output_tensor (Tensor): The tensor with the losses + model (GPTModel): The model (can be wrapped) + selected_budget (float): The budget value used for this forward pass + """ + args = get_args() + + # Unwrap for both Distillation and LANA + model = unwrap_model(model) + + # Standard lm loss + out_mask_loss = _mask_loss(output_tensor, loss_mask) + + if isinstance(out_mask_loss, tuple): + loss_lm, param_loss_item = out_mask_loss + else: + # assert args.freeze_router, "Param loss None is not supported without freezing router" + loss_lm = out_mask_loss + param_loss_item = torch.tensor(0.0, device=loss_lm.device, dtype=loss_lm.dtype) + + loss = loss_lm + num_tokens = loss_mask.sum().clone().detach().to(torch.int) + # Protect against division by zero when all tokens are masked. + num_tokens = torch.clamp(num_tokens, min=1) + # Report (value, num_tokens) as local-rank values; the training loop performs the + # DP+CP all-reduce on report-dict tuples (training.py: token-weighted reduction). + report = { + 'lm loss': ((loss_lm.detach() - param_loss_item.detach()).view(1), num_tokens), + 'param loss item': (param_loss_item.detach().view(1), num_tokens), + } + + # Add per-model LM loss breakdown for logging only when KD is NOT active + kd_active = model.training and args.export_kd_teacher_load + if not kd_active: + try: + is_full_model = (param_loss_item is None) or (param_loss_item.detach().abs() == 0) + except Exception: + is_full_model = False + zero_num = torch.zeros_like(report['lm loss'][0]) + zero_den = torch.zeros_like(num_tokens) + if is_full_model: + report['lm loss (full)'] = report['lm loss'] + report['lm loss (budget)'] = (zero_num, zero_den) + else: + report['lm loss (budget)'] = report['lm loss'] + report['lm loss (full)'] = (zero_num, zero_den) + + if model.training and args.export_kd_teacher_load: + # [ModelOpt]: Handle knowledge distillation. + # The installed balancer with skip_lm_loss=True drops student_loss (param_loss) from + # the total. Add loss_lm back manually to restore the router gradient signal. + losses = model.compute_kd_loss( + student_loss=loss_lm, loss_reduction_fn=lambda x: _mask_loss(x, loss_mask) + ) + loss = losses["kd_loss"] + param_loss_item + # All-gather logits_loss across DP ranks so we can mask by selected_budget below. 
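Before the gather itself, a non-distributed toy of the select-and-average it enables; only ranks that ran the selected budget contribute (values are made up):

```python
import torch

selected_budget = 0.5
budgets_gathered = torch.tensor([1.0, 0.5, 0.5, 0.75])     # budget each DP rank ran (toy)
logits_loss_gathered = torch.tensor([2.1, 3.0, 2.8, 2.5])  # per-rank logits losses (toy)

budget_mask = (budgets_gathered - selected_budget).abs() < 1e-6
per_budget = logits_loss_gathered[budget_mask].sum() / budget_mask.sum()
print(float(per_budget))                                   # (3.0 + 2.8) / 2 = 2.9
```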
+ logits_loss = losses["logits_loss"].detach() + dp_world_size = torch.distributed.get_world_size( + group=parallel_state.get_data_parallel_group() + ) + logits_loss_gathered = [torch.zeros_like(logits_loss) for _ in range(dp_world_size)] + torch.distributed.all_gather( + logits_loss_gathered, logits_loss, group=parallel_state.get_data_parallel_group() + ) + logits_loss_gathered = torch.stack(logits_loss_gathered) + + total_loss_report = losses["kd_loss"].detach() + param_loss_item.detach() + report["total loss"] = (total_loss_report, num_tokens) + + # Log KD loss split into full vs budget similar to LM loss breakdown. + try: + is_full_model_kd = (param_loss_item is None) or (param_loss_item.detach().abs() == 0) + except Exception: + is_full_model_kd = False + zero_num_kd = torch.zeros_like(total_loss_report) + zero_den_kd = torch.zeros_like(num_tokens) + if is_full_model_kd: + report["kd loss (full)"] = (total_loss_report, num_tokens) + report["kd loss (budget)"] = (zero_num_kd, zero_den_kd) + else: + report["kd loss (budget)"] = (total_loss_report, num_tokens) + report["kd loss (full)"] = (zero_num_kd, zero_den_kd) + report["logits distillation loss"] = (losses["logits_loss"].detach(), num_tokens) + report["intermediate distillation loss"] = ( + losses["intermediate_loss"].detach(), + num_tokens, + ) + + local_budget = torch.tensor( + [selected_budget], dtype=torch.float32, device=logits_loss.device + ) + budgets_gathered = [torch.zeros_like(local_budget) for _ in range(dp_world_size)] + torch.distributed.all_gather( + budgets_gathered, local_budget, group=parallel_state.get_data_parallel_group() + ) + budgets_gathered = torch.cat(budgets_gathered) + + # Create a binary mask where gathered budgets are equal to selected_budget (with 1e-6 tolerance) + budget_mask = (budgets_gathered - selected_budget).abs() < 1e-6 + logits_loss_gathered_selected = logits_loss_gathered[budget_mask].sum() / budget_mask.sum() + budget_num_tokens = ( + num_tokens.float() * budget_mask.sum() / budget_mask.shape[0] / budget_mask.sum() + ) + + corrected_budget_list = list(set(args.budget_list)) + + for temp_budget in corrected_budget_list: + report[f"logits distillation loss {temp_budget:.3f}"] = ( + torch.tensor(0.0, device=logits_loss.device, dtype=torch.float32), + torch.tensor(0.0, device=logits_loss.device, dtype=torch.float32), + ) + index_of_selected_budget = corrected_budget_list.index(selected_budget) + all_budget_logit = torch.zeros( + len(corrected_budget_list), device=logits_loss.device, dtype=logits_loss.dtype + ) + all_budget_tokens = torch.zeros( + len(corrected_budget_list), device=logits_loss.device, dtype=logits_loss.dtype + ) + + all_budget_logit[index_of_selected_budget] = logits_loss_gathered_selected + all_budget_tokens[index_of_selected_budget] = budget_num_tokens + + for i in range(len(corrected_budget_list)): + report[f"logits distillation loss {corrected_budget_list[i]:.3f}"] = ( + all_budget_logit[i], + all_budget_tokens[i], + ) + + # Convert all items in report dict to a single (value, num_tokens) tensor. 
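Each report entry is a (loss-sum, token-count) pair; assuming the training loop reduces them by summing across ranks and then dividing, as the reporting comment above indicates, the combination looks like this toy sketch:

```python
import torch

# per-rank [loss_sum, num_tokens] pairs, as produced by this loss function (toy values)
pairs = [torch.tensor([2300.0, 1000.0]), torch.tensor([1200.0, 500.0])]
summed = torch.stack(pairs).sum(dim=0)
print(float(summed[0] / summed[1]))          # 3500 / 1500 ≈ 2.33 average per-token loss
```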
+ for key, val in report.items(): + assert isinstance(val, tuple), "Value is not a tuple" + report[key] = torch.tensor( + [val[0], val[1].view(1)], device=loss_lm.device, dtype=loss_lm.dtype + ) + + return loss, num_tokens, report diff --git a/megatron/elastification/memory_config.py b/megatron/elastification/memory_config.py new file mode 100644 index 00000000000..3ac5b7132e6 --- /dev/null +++ b/megatron/elastification/memory_config.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +memory_config.py — MemoryConfig dataclass and loader for Flextron budget calculations. + +Usage +----- +From CLI args (in training/eval scripts): + cfg = load_memory_config(args) + total_gb = get_memory_footprint(..., memory_config=cfg) + +Directly (in tests or notebooks): + cfg = MemoryConfig(bpe_kv_cache=1, param_budget_target='active') +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +import yaml + +# Path to the bundled presets file (same directory as this module). +_DEFAULT_PROFILES_PATH = os.path.join(os.path.dirname(__file__), "memory_profiles.yaml") + + +@dataclass +class MemoryConfig: + """ + Bytes-per-element for each memory component and param budget supervision target. + + Attributes + ---------- + bpe_params : float + Bytes per weight parameter element (2 = BF16, 1 = FP8/INT8, 0.5625 = FP4). + bpe_kv_cache : float + Bytes per KV-cache element. + bpe_ssm_cache : float + Bytes per SSM-state element (covers both conv_state and ssm_state). + bpe_max_buffer : float + Bytes per MoE dispatch buffer element. + param_budget_target : str + Whether the param-budget loss supervises on ``'active'`` (top-k experts only) + or ``'total'`` (all parameters including non-active experts) parameter count. + """ + + bpe_params: float = 2.0 + bpe_kv_cache: float = 2.0 + bpe_ssm_cache: float = 2.0 + bpe_max_buffer: float = 2.0 + param_budget_target: str = "active" # "active" | "total" + + def __post_init__(self): + valid_targets = {"active", "total"} + if self.param_budget_target not in valid_targets: + raise ValueError( + f"param_budget_target must be one of {valid_targets}, " + f"got '{self.param_budget_target}'" + ) + + +def load_memory_config(args) -> MemoryConfig: + """ + Build a MemoryConfig from parsed CLI args. + + Resolution order (highest wins): + 1. Individual override args (--bpe-params, --bpe-kv-cache, …) + 2. Named preset from YAML (--memory-profile ) + 3. Built-in defaults (BF16 everywhere, active param target) + + Parameters + ---------- + args : argparse.Namespace + Parsed arguments. 
Relevant attributes (all optional): + memory_profile str — preset name (default: 'bf16') + memory_profile_path str — path to YAML profiles file + bpe_params float — override + bpe_kv_cache float — override + bpe_ssm_cache float — override + bpe_max_buffer float — override + param_budget_target str — override ('active' | 'total') + """ + cfg = MemoryConfig() + + # ── Load preset from YAML ────────────────────────────────────────────── + profile_name = getattr(args, "memory_profile", "bf16") or "bf16" + profile_path = getattr(args, "memory_profile_path", None) or _DEFAULT_PROFILES_PATH + print(f"[memory_config] profile='{profile_name}' path={profile_path}") + + if not os.path.isfile(profile_path): + raise FileNotFoundError(f"Memory profiles file not found: {profile_path}") + + with open(profile_path) as f: + profiles = yaml.safe_load(f) + + presets = profiles.get("presets", {}) + if profile_name not in presets: + available = list(presets.keys()) + raise ValueError( + f"Memory profile '{profile_name}' not found in {profile_path}. " + f"Available: {available}" + ) + + preset = presets[profile_name] + cfg.bpe_params = float(preset.get("params", cfg.bpe_params)) + cfg.bpe_kv_cache = float(preset.get("kv_cache", cfg.bpe_kv_cache)) + cfg.bpe_ssm_cache = float(preset.get("ssm_cache", cfg.bpe_ssm_cache)) + cfg.bpe_max_buffer = float(preset.get("max_buffer", cfg.bpe_max_buffer)) + cfg.param_budget_target = preset.get("param_budget_target", cfg.param_budget_target) + print( + f"[memory_config] after preset : bpe_params={cfg.bpe_params} bpe_kv_cache={cfg.bpe_kv_cache} " + f"bpe_ssm_cache={cfg.bpe_ssm_cache} bpe_max_buffer={cfg.bpe_max_buffer} " + f"param_budget_target={cfg.param_budget_target}" + ) + + # ── Apply individual CLI overrides (take priority over preset) ───────── + if getattr(args, "bpe_params", None) is not None: + print(f"[memory_config] override bpe_params: {cfg.bpe_params} -> {args.bpe_params}") + cfg.bpe_params = float(args.bpe_params) + if getattr(args, "bpe_kv_cache", None) is not None: + print(f"[memory_config] override bpe_kv_cache: {cfg.bpe_kv_cache} -> {args.bpe_kv_cache}") + cfg.bpe_kv_cache = float(args.bpe_kv_cache) + if getattr(args, "bpe_ssm_cache", None) is not None: + print( + f"[memory_config] override bpe_ssm_cache: {cfg.bpe_ssm_cache} -> {args.bpe_ssm_cache}" + ) + cfg.bpe_ssm_cache = float(args.bpe_ssm_cache) + if getattr(args, "bpe_max_buffer", None) is not None: + print( + f"[memory_config] override bpe_max_buffer: {cfg.bpe_max_buffer} -> {args.bpe_max_buffer}" + ) + cfg.bpe_max_buffer = float(args.bpe_max_buffer) + if getattr(args, "param_budget_target", None) is not None: + print( + f"[memory_config] override param_budget_target: {cfg.param_budget_target} -> {args.param_budget_target}" + ) + cfg.param_budget_target = args.param_budget_target + + print(f"[memory_config] final : {cfg}") + return cfg diff --git a/megatron/elastification/memory_profiles.yaml b/megatron/elastification/memory_profiles.yaml new file mode 100644 index 00000000000..f060ce5d0cb --- /dev/null +++ b/megatron/elastification/memory_profiles.yaml @@ -0,0 +1,59 @@ +# memory_profiles.yaml +# +# Named memory quantization profiles for Flextron budget calculations. +# Each preset specifies bytes-per-element (bpe) for each memory component +# and whether the param budget loss targets total or active parameters. 
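For clarity on how these presets interact with `load_memory_config` in memory_config.py: the preset fills in the defaults and any explicit per-field flag wins. A toy sketch with hypothetical values:

```python
from argparse import Namespace

preset = {"params": 2, "kv_cache": 1}                      # e.g. the fp8_kv preset below
args = Namespace(bpe_params=None, bpe_kv_cache=0.5625)     # hypothetical CLI overrides

bpe_params = float(args.bpe_params) if args.bpe_params is not None else float(preset["params"])
bpe_kv_cache = float(args.bpe_kv_cache) if args.bpe_kv_cache is not None else float(preset["kv_cache"])
print(bpe_params, bpe_kv_cache)                            # 2.0 0.5625 -> explicit flag wins
```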
+# +# Select a preset via: --memory-profile +# Override a single value via: --bpe-kv-cache 1 (takes priority over preset) +# Use a custom file via: --memory-profile-path /path/to/file.yaml + +presets: + + # ── Standard BF16 inference (current implementation) ────────────────────── + bf16: + params: 2 # BF16 weights + kv_cache: 2 # BF16 KV cache + ssm_cache: 2 # BF16 SSM state (conv + ssm) + max_buffer: 2 # BF16 MoE dispatch buffer + param_budget_target: active + + # ── FP8 KV cache, BF16 weights (common serving optimisation) ────────────── + fp8_kv: + params: 2 # BF16 weights + kv_cache: 1 # FP8 KV cache + ssm_cache: 2 # BF16 SSM state + max_buffer: 2 # BF16 MoE dispatch buffer + param_budget_target: active + + # ── FP8 KV + SSM cache ──────────────────────────────────────────────────── + fp8_kv_ssm: + params: 2 # BF16 weights + kv_cache: 1 # FP8 KV cache + ssm_cache: 1 # FP8/INT8 SSM state + max_buffer: 2 # BF16 MoE dispatch buffer + param_budget_target: active + + # ── Fully quantised FP8 inference ───────────────────────────────────────── + fp8_all: + params: 1 # FP8 weights + kv_cache: 1 # FP8 KV cache + ssm_cache: 1 # FP8 SSM state + max_buffer: 1 # FP8 MoE dispatch buffer + param_budget_target: active + + # ── INT8 weights + INT8 caches ──────────────────────────────────────────── + int8: + params: 1 # INT8 weights + kv_cache: 1 # INT8 KV cache + ssm_cache: 1 # INT8 SSM state + max_buffer: 2 # BF16 MoE dispatch buffer (activation, usually not quantised) + param_budget_target: active + + # ── FP4 (speculative / future) ──────────────────────────────────────────── + fp4: + params: 0.5625 # FP4 weights + kv_cache: 0.5625 # FP4 KV cache + ssm_cache: 1 # INT8 SSM state + max_buffer: 2 # BF16 MoE dispatch buffer + param_budget_target: active diff --git a/megatron/elastification/pretrain_hybrid_flex.py b/megatron/elastification/pretrain_hybrid_flex.py new file mode 100644 index 00000000000..08df15333e9 --- /dev/null +++ b/megatron/elastification/pretrain_hybrid_flex.py @@ -0,0 +1,572 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+"""Pretrain and SFT Mamba.""" + +import os +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch + +from megatron.core import mpu, parallel_state +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.enums import ModelType +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel +from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_micro_batch_size, +) +from megatron.core.parallel_state import ( + get_context_parallel_rank, + get_context_parallel_world_size, + get_data_parallel_rank, + get_data_parallel_world_size, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.spec_utils import import_module +from megatron.core.utils import StragglerDetector +from megatron.elastification.arguments import add_flextron_args +from megatron.training import ( + get_args, + get_timers, + get_tokenizer, + inprocess_restart, + pretrain, + print_rank_0, +) +from megatron.training.argument_utils import pretrain_cfg_container_from_args +from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args +from megatron.training.datasets.sft_dataset import SFTDataset +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, +) + +# modelopt distillation +try: + from megatron.elastification.loss_func import loss_func as loss_func_modelopt + from megatron.post_training.arguments import add_modelopt_args + from megatron.post_training.model_builder import ( + modelopt_gpt_mamba_builder as model_provider_modelopt, + ) + + has_nvidia_modelopt = True +except ImportError: + print_rank_0("ModelOpt is not installed. Please install it using `pip install nvidia-modelopt`") + has_nvidia_modelopt = False +print_rank_0("has_nvidia_modelopt is {}".format(has_nvidia_modelopt)) +import numpy as np + +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None + +from megatron.core.utils import is_te_min_version + +_global_choice_counter = 0 +_logged_params_norm = False + +stimer = StragglerDetector() + + +def count_parameters_in_layer(model, layer_name): + num_params = 0 + for name, param in model.named_parameters(): + if layer_name in name: + num_params += param.numel() + print_rank_0(f" - {name}: {param.numel()}") + return num_params + + +def model_provider( + pre_process=True, + post_process=True, + vp_stage: Optional[int] = None, + config=None, + pg_collection=None, +) -> HybridModel: + """Builds the model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. 
+ + + Returns: + HybridModel: The returned model + """ + args = get_args() + if has_nvidia_modelopt: + + model = model_provider_modelopt( + args, + pre_process, + post_process, + vp_stage=vp_stage, + config=config, + pg_collection=pg_collection, + ) + from megatron.elastification.flextron_utils import ( + inject_flextron_forward_logic, + setup_flextron_model, + ) + + setup_flextron_model(model) + inject_flextron_forward_logic(model) + + if args.freeze_model: + for name, param in model.named_parameters(): + if 'gate' not in name: + param.requires_grad = False + + if args.freeze_router: + for name, param in model.named_parameters(): + if 'gate' in name: + param.requires_grad = False + + return model + + print_rank_0('building Mamba model ...') + config = core_transformer_config_from_args(args, TransformerConfig) + + assert args.use_legacy_models == False, "Mamba only supported in Mcore!" + + if args.spec is not None: + hybrid_stack_spec = import_module(args.spec) + else: + raise ValueError("You must provide a valid Mamba layer spec!") + + model = HybridModel( + config=config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + hybrid_layer_pattern=args.hybrid_layer_pattern, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + vp_stage=vp_stage, + ) + from megatron.elastification.flextron_utils import ( + inject_flextron_forward_logic, + setup_flextron_model, + ) + + setup_flextron_model(model) + inject_flextron_forward_logic(model) + + for l in range(model.decoder.num_layers_per_pipeline_rank): + layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') + print_rank_0(f" == params layer {l}: {layer_params}") + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + cu_seqlens = batch['cu_seqlens'] + if cu_seqlens is None: + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore + else: # Packed THD format + assert ( + cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 + ), "micro-batch-size must be 1 for packing" + cu_seqlens = cu_seqlens[0] + batch['cu_seqlens'] = cu_seqlens + + max_seqlen = batch['max_seqlen'] + assert max_seqlen.dim() == 1 + # TODO(duncan): can this be kept as a 0-D tensor? 
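For readers less familiar with the packed (THD) layout handled here, `cu_seqlens` is simply the vector of cumulative sequence boundaries and `max_seqlen` the longest packed sequence; a toy sketch:

```python
import torch

seq_lens = torch.tensor([5, 3, 8])                         # three packed sequences (toy)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])
max_seqlen = int(seq_lens.max())
print(cu_seqlens.tolist(), max_seqlen)                     # [0, 5, 8, 16] 8
```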
+ batch['max_seqlen'] = int(max_seqlen[0].item()) + + cp_size = get_context_parallel_world_size() + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + cp_rank = get_context_parallel_rank() + index = tex.thd_get_partitioned_indices( + cu_seqlens, batch['tokens'].size(1), cp_size, cp_rank + ) + for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + continue + batch[key] = data.index_select(1, index) + + return ( + batch.get('tokens'), + batch.get('labels'), + batch.get('loss_mask'), + batch.get('attention_mask'), + batch.get('position_ids'), + batch.get('cu_seqlens'), + batch.get('max_seqlen'), + ) + + +# define spiky loss as a loss that's 10x the max loss observed +SPIKY_LOSS_FACTOR = 10 + + +def loss_func( + loss_mask: torch.Tensor, + output_tensor: torch.Tensor, + model: Optional[HybridModel] = None, + selected_budget=None, +): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + if has_nvidia_modelopt: + return loss_func_modelopt( + loss_mask, output_tensor, model=model, selected_budget=selected_budget + ) + + alpha = args.loss_alpha + + (output_tensor, (param_loss, extra_reporting_dict)) = output_tensor + + if param_loss is not None: + if param_loss > 0: + param_loss_report = param_loss.detach().clone() + else: + param_loss_report = param_loss.detach().clone() + param_loss = -args.router_beta * param_loss + + losses = output_tensor.view(-1).float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses * loss_mask) + + # Check individual rank losses are not NaN prior to DP all-reduce. + rerun_state_machine = get_rerun_state_machine() + if args.check_for_nan_in_loss_and_grad: + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isnan, + message="found NaN in local forward loss calculation", + tolerance=0.0, # forward pass calculations are deterministic + fatal=True, + ) + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isinf, + message="found Inf in local forward loss calculation", + tolerance=0.0, # forward pass calculations are deterministic + fatal=True, + ) + # Check for spiky loss + if args.check_for_spiky_loss: + rerun_state_machine.validate_result( + result=loss, + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, + threshold=SPIKY_LOSS_FACTOR, + context="loss", + ), + message="Spiky loss", + tolerance=0.0, # forward pass calculations are deterministic + fatal=False, + ) + + num_tokens = loss_mask.sum().clone().detach().to(torch.int) + + if param_loss is not None: + param_loss *= num_tokens * alpha + if param_loss < 0: + param_loss = -args.router_beta * param_loss + + param_loss_report = torch.cat([param_loss.clone().detach().view(1), num_tokens.view(1)]) + lm_loss_report = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)]) + loss += param_loss[0] + + # Protect against division by zero when all tokens are masked. 
+ num_tokens = torch.clamp(num_tokens, min=1) + reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)]) + + if param_loss is not None: + return ( + loss, + num_tokens, + { + 'lm loss': lm_loss_report, + 'param loss': param_loss_report, + 'total loss': reporting_loss, + }, + ) + else: + return (loss, num_tokens, {'lm loss': reporting_loss}) + + +def get_grad_acc_based_random_choice(args, choices=None, prob=None, base_seed=42): + + dp_size = get_data_parallel_world_size() + grad_accumulation_steps = get_current_global_batch_size() // (get_micro_batch_size() * dp_size) + global _global_choice_counter + + # DP-specific seeding + rng = np.random.RandomState( + base_seed + _global_choice_counter + grad_accumulation_steps * args.curr_iteration * 10 + ) + if choices is None: + choice = rng.uniform(0, 1) + else: + if prob is None: + choice = rng.choice(choices) + else: + choice = rng.choice(choices, p=prob) + _global_choice_counter += 1 + _global_choice_counter %= grad_accumulation_steps + return choice + + +def forward_step(data_iterator, model: HybridModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (HybridModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # One-time per-component params-norm breakdown (mirrors calc_params_l2_norm). + global _logged_params_norm + if not _logged_params_norm: + _logged_params_norm = True + from collections import defaultdict + + groups = defaultdict(float) + trainable_sq = frozen_sq = total_sq = 0.0 + for name, param in model.named_parameters(): + # Use fp32 main_param when distributed optimizer is active (matches WandB metric). + # Note: main_param can be None for some params under DistOpt; getattr's default + # only applies when the attr is missing, so handle the None case explicitly. + main = getattr(param, 'main_param', None) + p = (main if main is not None else param.detach()).float() + norm_sq = p.norm(2).item() ** 2 + # Strip DDP 'module.' wrappers to get the logical top-level name. + clean = name + while clean.startswith('module.'): + clean = clean[len('module.') :] + top = clean.split('.')[0] + groups[top] += norm_sq + total_sq += norm_sq + if param.requires_grad: + trainable_sq += norm_sq + else: + frozen_sq += norm_sq + print_rank_0( + f"[PARAMS_NORM] total={total_sq**0.5:.2f} " + f"trainable={trainable_sq**0.5:.2f} frozen={frozen_sq**0.5:.2f}" + ) + for grp, sq in sorted(groups.items()): + print_rank_0(f"[PARAMS_NORM] {grp}: {sq**0.5:.2f}") + + # Get the batch. + timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + (tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, max_seqlen) = ( + get_batch(data_iterator) + ) + timers('batch-generator').stop() + + if get_grad_acc_based_random_choice(args=args) < args.original_model_sample_prob: + # Funnel "full-model sample" through the regular router-driven path + # with budget=1.0. flextron_forward unconditionally reads + # flextron_kwargs['budget'], so an empty dict would KeyError here. 
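The budgets themselves come from `get_grad_acc_based_random_choice` above, which seeds NumPy from a per-microbatch counter and the current iteration so the draws are reproducible; a standalone toy version of that seeding scheme (budget values and probabilities are illustrative):

```python
import numpy as np

budget_list = [0.5, 0.75, 1.0]                             # hypothetical budgets
budget_probs = [0.4, 0.4, 0.2]

def sample_budget(counter, iteration, grad_acc_steps, base_seed=42):
    # Same seed recipe as the helper above: counter + grad_acc_steps * iteration * 10.
    rng = np.random.RandomState(base_seed + counter + grad_acc_steps * iteration * 10)
    return rng.choice(budget_list, p=budget_probs)

print([sample_budget(c, iteration=3, grad_acc_steps=8) for c in range(8)])
```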
+ flextron_kwargs = {'budget': 1.0} + selected_budget = 1.0 + else: + if args.budget_probs is None: + budget_probs = [1.0 for _ in args.budget_list] + else: + budget_probs = args.budget_probs + + assert len(args.budget_list) == len( + budget_probs + ), "budget_list and budget_probs must have the same length" + budget_probs = [float(p) for p in budget_probs] + budget_probs = [p / sum(budget_probs) for p in budget_probs] + selected_budget = get_grad_acc_based_random_choice( + args=args, choices=args.budget_list, prob=budget_probs + ) + flextron_kwargs = {'budget': selected_budget} + + with stimer: + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels, **flextron_kwargs + ) + + # [ModelOpt]: model is needed to access ModelOpt distillation losses + return output_tensor, partial( + loss_func, loss_mask, model=model, selected_budget=selected_budget + ) + + +def is_dataset_built_on_rank(vp_stage=None): + ignore_virtual = True + if vp_stage is not None: + ignore_virtual = False + return ( + mpu.is_pipeline_first_stage(ignore_virtual=ignore_virtual, vp_stage=vp_stage) + or mpu.is_pipeline_last_stage(ignore_virtual=ignore_virtual, vp_stage=vp_stage) + ) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_gpt_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + # Sometimes --data-path is too long, instead we parse it from a file. + blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + object_storage_cache_path=args.object_storage_cache_path, + mid_level_dataset_surplus=args.mid_level_dataset_surplus, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. 
+ """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.sft: + dataset_type = SFTDataset + else: + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + partial(is_dataset_built_on_rank, vp_stage=vp_stage), + config, + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +def mamba_flex_extra_args_provider(parser): + """Add Flextron CLI if not already registered by ``add_megatron_arguments``, then ModelOpt.""" + if not any(getattr(action, "dest", None) == "flextron" for action in parser._actions): + parser = add_flextron_args(parser) + if has_nvidia_modelopt: + parser = add_modelopt_args(parser) + return parser + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + # Optionally enable inprocess restart on pretrain + pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + # Restore router LR multiplier (Bug 4 fix): monkey-patch get_megatron_optimizer_config + # to inject a per-parameter LR override for router params via config_overrides. + # Main branch removed the scale_lr_cond parameter from pretrain(); this achieves the same. + import megatron.training.training as _mtt + from megatron.core.optimizer.optimizer_config import ParamKey, ParamWithNamePredicate + from megatron.core.optimizer_param_scheduler import ParamGroupOverride + + _orig_get_opt_cfg = _mtt.get_megatron_optimizer_config + + def _patched_get_opt_cfg(args): + config, config_overrides = _orig_get_opt_cfg(args) + lr_mult = getattr(args, 'lr_mult_router', 1.0) + if lr_mult != 1.0: + router_key = ParamKey( + with_name_predicate=ParamWithNamePredicate( + name="router_pp", fn=lambda p, name: 'router_pp' in name + ) + ) + router_override = ParamGroupOverride( + max_lr=args.lr * lr_mult, min_lr=args.min_lr * lr_mult + ) + config_overrides = {**(config_overrides or {}), router_key: router_override} + return config, config_overrides + + _mtt.get_megatron_optimizer_config = _patched_get_opt_cfg + + # `pretrain()` no longer accepts extra_args_provider / args_defaults; parse + # args up-front instead (see pretrain_mamba.py for the same pattern). + args = parse_and_validate_args( + extra_args_provider=mamba_flex_extra_args_provider, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + ) + + full_config = pretrain_cfg_container_from_args(args) + pretrain( + full_config, + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) diff --git a/megatron/elastification/router/__init__.py b/megatron/elastification/router/__init__.py new file mode 100644 index 00000000000..26496bfed70 --- /dev/null +++ b/megatron/elastification/router/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/elastification/router/flex_budget_utils.py b/megatron/elastification/router/flex_budget_utils.py new file mode 100644 index 00000000000..ba6537196f1 --- /dev/null +++ b/megatron/elastification/router/flex_budget_utils.py @@ -0,0 +1,393 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +from typing import Dict, List, Optional, Tuple, Union + +import torch + + +def get_num_parameters( + hybrid_pattern: str = None, + mamba_num_heads: int = 0, + mamba_d_head: int = 0, + mamba_d_state: int = 0, + num_attention_heads: int = 0, + num_query_groups: int = 0, + ffn_hidden_size: int = 0, + hidden_size: int = 0, + kv_channels: int = 0, + vocab_size: int = 0, + tied_vocab: bool = False, + num_experts: int = 0, + shared_expert_intermediate_size: int = 0, + moe_router_topk: int = 0, +) -> int: + + norm_multiplier = 1 + + embedding = vocab_size * hidden_size + final_layernorm = hidden_size * 1 + output_layer = 0 if tied_vocab else (vocab_size * hidden_size) + if isinstance(ffn_hidden_size, int): + flex_hetero_ffn = False + else: + flex_hetero_ffn = ffn_hidden_size.shape[0] != 1 + + if isinstance(mamba_num_heads, int): + flex_hetero_mamba = False + else: + flex_hetero_mamba = mamba_num_heads.shape[0] != 1 + + # Per-layer attention head counts arise only from layer skipping; head + # elasticity itself is no longer supported. + if isinstance(num_attention_heads, int): + per_layer_attn_heads = False + else: + per_layer_attn_heads = num_attention_heads.shape[0] != 1 + + if isinstance(num_experts, int): + flex_hetero_moe_expert = False + else: + flex_hetero_moe_expert = num_experts.shape[0] != 1 + + # MOE + + if flex_hetero_ffn or flex_hetero_moe_expert: + if flex_hetero_ffn and not flex_hetero_moe_expert: + num_experts = [num_experts] * ffn_hidden_size.shape[0] + if flex_hetero_moe_expert and not flex_hetero_ffn: + ffn_hidden_size = [ffn_hidden_size] * num_experts.shape[0] + + moe_all = [] + moe_active = [] + for i in range(len(num_experts)): + pre_moe_ln = norm_multiplier * hidden_size + linear_fc1 = ffn_hidden_size[i] * ( + hidden_size * num_experts[i] + shared_expert_intermediate_size + ) + linear_fc2 = ffn_hidden_size[i] * ( + hidden_size * num_experts[i] + shared_expert_intermediate_size + ) + linear_fc1_active = ffn_hidden_size[i] * ( + hidden_size * moe_router_topk + shared_expert_intermediate_size + ) + linear_fc2_active = ffn_hidden_size[i] * ( + hidden_size * moe_router_topk + shared_expert_intermediate_size + ) + moe_all.append(pre_moe_ln + linear_fc1 + linear_fc2) + moe_active.append(pre_moe_ln + linear_fc1_active + linear_fc2_active) + else: + pre_mlp_ln = norm_multiplier * hidden_size + linear_fc1 = ffn_hidden_size * (hidden_size * num_experts + shared_expert_intermediate_size) + linear_fc2 = ffn_hidden_size * (hidden_size * num_experts + shared_expert_intermediate_size) + linear_fc1_active = ffn_hidden_size * ( + hidden_size * moe_router_topk + shared_expert_intermediate_size + ) + linear_fc2_active = ffn_hidden_size * ( + hidden_size * moe_router_topk + shared_expert_intermediate_size + ) + moe_all = pre_mlp_ln + linear_fc1 + linear_fc2 + moe_active = pre_mlp_ln + linear_fc1_active + linear_fc2_active + + # ATT + if per_layer_attn_heads: + att = [] + for i in range(num_attention_heads.shape[0]): + input_ln = norm_multiplier * hidden_size + linear_proj = num_attention_heads[i] * kv_channels * hidden_size + linear_qkv = (num_attention_heads[i] + 2 * num_query_groups) * kv_channels * hidden_size + att.append(input_ln + linear_proj + linear_qkv) + else: + input_ln = norm_multiplier * hidden_size + linear_proj = num_attention_heads * kv_channels * hidden_size + linear_qkv = (num_attention_heads + 2 * num_query_groups) * kv_channels * hidden_size + att = input_ln + linear_proj + linear_qkv + + # Mamba + def mamba_params(mamba_nheads): + d_inner = mamba_nheads * mamba_d_head + 
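+        # NOTE: ngroups below is hard-coded; it should match the model's Mamba group count for this parameter estimate to be accurate.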
ngroups = 8 + + def get_conv_params(kernel_size, stride): + cdim = d_inner + 2 * ngroups * mamba_d_state + cbias = cdim + cweight = cdim * stride * kernel_size + return cbias + cweight + + mamba_dt_bias = mamba_nheads + mamba_A_log = mamba_nheads + # self.d_inner_local if self.D_has_hdim else self.nheads_local, + mamba_D = mamba_nheads + mamba_input_ln = norm_multiplier * hidden_size + mamba_in_proj = hidden_size * (d_inner * 2 + 2 * ngroups * mamba_d_state + mamba_nheads) + mamba_conv = get_conv_params(4, 1) + mamba_norm = d_inner + mamba_out_proj = d_inner * hidden_size + return ( + mamba_dt_bias + + mamba_A_log + + mamba_D + + mamba_input_ln + + mamba_in_proj + + mamba_conv + + mamba_norm + + mamba_out_proj + ) + + all_params = 0 + active_params = 0 + for i, c in enumerate(hybrid_pattern): + + if c == 'M': + if flex_hetero_mamba: + mamba_idx = hybrid_pattern[: i + 1].count('M') - 1 + all_params += mamba_params(mamba_num_heads[mamba_idx]) + active_params += mamba_params(mamba_num_heads[mamba_idx]) + else: + all_params += mamba_params(mamba_num_heads) + active_params += mamba_params(mamba_num_heads) + elif c == '*': + if per_layer_attn_heads: + head_idx = hybrid_pattern[: i + 1].count('*') - 1 + all_params += att[head_idx] + active_params += att[head_idx] + else: + all_params += att + active_params += att + elif c == 'E': + if flex_hetero_ffn or flex_hetero_moe_expert: + # Count how many 'E' characters appear before and including layer i + moe_idx = hybrid_pattern[: i + 1].count('E') - 1 + all_params += moe_all[moe_idx] + active_params += moe_active[moe_idx] + else: + all_params += moe_all + active_params += moe_active + elif c == '|': + pass + else: + raise RuntimeError(f'Unknown layer type: {c}') + + return ( + embedding + all_params + final_layernorm + output_layer, + embedding + active_params + final_layernorm + output_layer, + ) + + +def get_kv_cache_size( + hybrid_pattern: str = None, + num_attention_heads=None, + num_query_groups=None, + kv_channels=None, + mem_infer_seq_len: int = 0, + mem_batch_size: int = 0, +) -> Union[int, torch.Tensor]: + + # Per-layer attention head counts arise only from layer skipping; head + # elasticity itself is no longer supported. 
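Numerically, the uniform-heads branch below reduces to the usual GQA cache count 2·B·S·G·d_head·L_attn; the attention-head factor cancels against the divider and survives only as a gradient path. A toy arithmetic check (sizes are illustrative):

```python
batch, seq_len, num_query_groups, kv_channels, attn_layers = 1, 8192, 8, 128, 4
kv_elems = 2 * batch * seq_len * num_query_groups * kv_channels * attn_layers  # K and V planes
print(kv_elems)        # 67,108,864 elements; multiply by bytes-per-element for bytes
```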
+ if isinstance(num_attention_heads, int): + per_layer_attn_heads = False + else: + per_layer_attn_heads = num_attention_heads.shape[0] != 1 + + if per_layer_attn_heads: + kv_cache_size = 0 + head_idx = 0 + + for c in hybrid_pattern: + if c == '*': + current_heads = num_attention_heads[head_idx] + + kv_cache_size_per_layer = ( + 2.0 + * mem_batch_size + * mem_infer_seq_len + * num_query_groups + * current_heads + * kv_channels + / current_heads.detach().item() + ) + kv_cache_size += kv_cache_size_per_layer + head_idx += 1 + + else: + num_attention_layers = hybrid_pattern.count('*') + divider = ( + num_attention_heads.detach().item() + if isinstance(num_attention_heads, torch.Tensor) + else num_attention_heads + ) + kv_cache_size = ( + 2.0 + * mem_batch_size + * mem_infer_seq_len + * num_query_groups + * num_attention_heads + * kv_channels + * num_attention_layers + / divider + ) + + return kv_cache_size + + +def get_mamba_ssm_cache_size( + hybrid_pattern: str = None, + mamba_num_heads: int = 0, + mamba_d_head: int = 0, + mamba_d_state: int = 0, + mem_batch_size: int = 0, +) -> int: + + if isinstance(mamba_num_heads, int): + flex_hetero_mamba = False + else: + flex_hetero_mamba = mamba_num_heads.shape[0] != 1 + + if flex_hetero_mamba: + ssm_cache_size = 0 + mamba_idx = 0 + for c in hybrid_pattern: + if c == 'M': + current_mamba_num_heads = mamba_num_heads[mamba_idx] + ssm_cache_size += ( + mem_batch_size * current_mamba_num_heads * mamba_d_head * mamba_d_state + ) + mamba_idx += 1 + + else: + num_mamba_layers = hybrid_pattern.count('M') + ssm_cache_size = ( + mem_batch_size * mamba_num_heads * mamba_d_head * mamba_d_state * num_mamba_layers + ) + + return ssm_cache_size + + +def get_max_buffer_size( + hybrid_pattern: str = None, + moe_num_experts: int = 0, + shared_expert_intermediate_size: int = 0, + ffn_hidden_size: int = 0, + moe_router_topk: int = 0, + mem_batch_size: int = 0, + prefill_chunk_size: int = 0, +) -> int: + + if isinstance(moe_num_experts, int) or moe_num_experts.shape[0] == 1: + moe_num_experts = ( + torch.tensor([moe_num_experts] * hybrid_pattern.count('E')) + .to(torch.cuda.current_device()) + .float() + ) + + if isinstance(ffn_hidden_size, int) or ffn_hidden_size.shape[0] == 1: + ffn_hidden_size = ( + torch.tensor([ffn_hidden_size] * hybrid_pattern.count('E')) + .to(torch.cuda.current_device()) + .float() + ) + + max_buffer_list = [] + moe_idx = 0 + for char in hybrid_pattern: + if char == 'E': + current_moe_num_experts = moe_num_experts[moe_idx] + current_ffn_hidden_size = ffn_hidden_size[moe_idx] + max_buffer_list.append( + shared_expert_intermediate_size + current_ffn_hidden_size * moe_router_topk + ) + moe_idx += 1 + + max_buffer = torch.stack(max_buffer_list) + max_buffer_softmax = torch.nn.functional.softmax(max_buffer, dim=0) + max_buffer = (max_buffer_softmax * max_buffer).sum().unsqueeze(0) + max_buffer *= mem_batch_size * prefill_chunk_size + + return max_buffer + + +def get_memory_footprint( + hybrid_pattern: str = None, + mamba_num_heads: int = 0, + mamba_d_head: int = 80, + mamba_d_state: int = 128, + num_attention_heads: int = 0, + num_query_groups: int = 8, + ffn_hidden_size: int = 0, + hidden_size: int = 0, + kv_channels: int = 128, + vocab_size: int = 131072, + tied_vocab: bool = False, + mem_infer_seq_len: int = 131072, + mem_batch_size: int = 1, + prefill_chunk_size: int = 16384, + moe_num_experts: int = 0, + shared_expert_intermediate_size: int = 0, + moe_router_topk: int = 0, + memory_config=None, +): + """ + Returns total inference memory footprint 
in GB. + + Parameters + ---------- + memory_config : MemoryConfig, optional + Bytes-per-element values and param budget target. When None, defaults + to BF16 for all components (bpe=2). Pass a MemoryConfig built via + ``load_memory_config(args)`` to select a quantisation profile. + """ + from megatron.elastification.memory_config import MemoryConfig + + if memory_config is None: + memory_config = MemoryConfig() # BF16 defaults + + # Select all-param or active-param count based on param_budget_target + param_idx = 1 if memory_config.param_budget_target == "active" else 0 + + mem_params = ( + memory_config.bpe_params + * get_num_parameters( + hybrid_pattern=hybrid_pattern, + mamba_num_heads=mamba_num_heads, + mamba_d_head=mamba_d_head, + mamba_d_state=mamba_d_state, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + hidden_size=hidden_size, + kv_channels=kv_channels, + vocab_size=vocab_size, + tied_vocab=tied_vocab, + num_experts=moe_num_experts, + shared_expert_intermediate_size=shared_expert_intermediate_size, + moe_router_topk=moe_router_topk, + )[param_idx] + ) + + mem_kv_cache = memory_config.bpe_kv_cache * get_kv_cache_size( + hybrid_pattern=hybrid_pattern, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + kv_channels=kv_channels, + mem_infer_seq_len=mem_infer_seq_len, + mem_batch_size=mem_batch_size, + ) + + mem_max_buffer = memory_config.bpe_max_buffer * get_max_buffer_size( + hybrid_pattern=hybrid_pattern, + moe_num_experts=moe_num_experts, + shared_expert_intermediate_size=shared_expert_intermediate_size, + ffn_hidden_size=ffn_hidden_size, + moe_router_topk=moe_router_topk, + mem_batch_size=mem_batch_size, + prefill_chunk_size=prefill_chunk_size, + ) + + mem_mamba_ssm_cache = memory_config.bpe_ssm_cache * get_mamba_ssm_cache_size( + hybrid_pattern=hybrid_pattern, + mamba_num_heads=mamba_num_heads, + mamba_d_head=mamba_d_head, + mamba_d_state=mamba_d_state, + mem_batch_size=mem_batch_size, + ) + return (mem_params + mem_kv_cache + mem_max_buffer + mem_mamba_ssm_cache) / 1024 / 1024 / 1024 diff --git a/megatron/elastification/router/hybrid_flex_router.py b/megatron/elastification/router/hybrid_flex_router.py new file mode 100644 index 00000000000..d9920360fc3 --- /dev/null +++ b/megatron/elastification/router/hybrid_flex_router.py @@ -0,0 +1,619 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
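For orientation on the router defined in this file: its gate input dimension equals the number of candidate budgets, and a small two-layer gate followed by Gumbel-softmax produces a soft choice over the candidate sizes, which is later contracted against the size list. A minimal framework-free sketch of that idea, assuming a one-hot budget encoding (size lists and dimensions are illustrative; the real implementation below uses TE/tensor-parallel linears and DP-aware seeding):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

budget_list = [0.5, 0.75, 1.0]               # illustrative budgets
mlp_int_list = [2048, 4096, 8192]            # illustrative candidate FFN widths

gate = nn.Sequential(
    nn.Linear(len(budget_list), 64, bias=False),
    nn.LeakyReLU(0.1),
    nn.Linear(64, len(mlp_int_list), bias=False),
)

budget_onehot = F.one_hot(torch.tensor(1), num_classes=len(budget_list)).float()  # budget 0.75
probs = F.gumbel_softmax(gate(budget_onehot), tau=1.0, hard=False)
effective_ffn = probs @ torch.tensor(mlp_int_list, dtype=probs.dtype)
print(probs.tolist(), float(effective_ffn))
```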
+ +import random + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from megatron.core import parallel_state +from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_micro_batch_size, +) +from megatron.core.parallel_state import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) + +# Remove top-level import to avoid circular imports +# from megatron.training import get_args, print_rank_0 +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import init_method_normal + +# Import TE parallel linear layers +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + # Fallback to regular tensor parallel layers + from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear + + +# Router implementation for pre-gating router. +# Use router to determine #heads, MLP sizes, and layers to skip (Router_v2) +# Only takes the budget as input +class FlextronRouter(MegatronModule): + def __init__(self, config: TransformerConfig): + super().__init__(config=config) + + self.config = config + self.input_dim = len(self.config.budget_list) + self.n_dim = self.config.router_inter_dim + self.budget_map = { + item: torch.tensor(idx) for idx, item in enumerate(self.config.budget_list) + } + + # Initialize DP-aware Gumbel softmax + self._init_dp_gumbel_softmax() + + # Create init method for router layers + self.init_method = init_method_normal(self.config.router_std) + + self.add_router_for_mlp() + self.add_router_for_emb() + self.add_router_for_mamba() + self.add_router_for_moe_expert() + if self.config.add_skipping: + self.add_router_for_skipping() + + # Synchronize router weights across all pipeline parallel ranks + self._sync_router_weights() + self._mark_router_params_for_pp_sync() + self.hard_sample_th = config.hard_sample_th + + self.add_scaler_schedule() + + self.dp_size = get_data_parallel_world_size() + self.grad_accumulation_steps = get_current_global_batch_size() // ( + get_micro_batch_size() * self.dp_size + ) + self.fwd_pass_count = 0 + + def _init_dp_gumbel_softmax(self, base_seed=42): + """Initialize DP-aware Gumbel softmax functionality""" + self.dp_rank = get_data_parallel_rank() + self.gumbel_base_seed = base_seed + + def _sync_router_weights(self): + """ + Synchronize router weights across all pipeline parallel groups by broadcasting + from global rank 0 to all other ranks. + """ + if not torch.distributed.is_initialized(): + return + + # Get global rank 0 as the source + source_rank = 0 + + # Broadcast all router parameters from rank 0 + for name, param in self.named_parameters(): + if param is not None: + torch.distributed.broadcast(param.data, src=source_rank) + + def _mark_router_params_for_pp_sync(self): + """ + Mark all router parameters to be synchronized across pipeline parallel ranks. + This ensures they get handled by the main gradient synchronization system. 
+ """ + for param in self.parameters(): + if param.requires_grad: + # Mark parameter for pipeline parallel synchronization + setattr(param, 'flextron_router_pp_sync', True) + + def _dp_gumbel_softmax(self, logits, tau=1.0, hard=False, curr_iteration=0): + """DP-aware Gumbel softmax that uses different random seeds per DP rank and iteration""" + # Create unique seed for this iteration and DP rank + + seed = ( + self.gumbel_base_seed + + (self.dp_rank + self.fwd_pass_count * self.dp_size) % self.config.router_gbs + + curr_iteration * 1000 + ) + # torch.manual_seed seeds both CPU and CUDA generators globally, so we + # must save/restore both - otherwise the CUDA RNG leaks the deterministic + # state we set here into other CUDA random ops elsewhere in the model. + cpu_state = torch.get_rng_state() + cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None + torch.manual_seed(seed) + + try: + return F.gumbel_softmax(logits, tau=tau, hard=hard) + finally: + torch.set_rng_state(cpu_state) + if cuda_state is not None: + torch.cuda.set_rng_state_all(cuda_state) + + def _create_linear_layer(self, input_size, output_size, bias=False, is_first_layer=True): + """Helper method to create appropriate linear layer (TE or fallback)""" + if HAVE_TE: + if is_first_layer: + # First layer: TEColumnParallelLinear + return TEColumnParallelLinear( + input_size=input_size, + output_size=output_size, + config=self.config, + init_method=self.init_method, + gather_output=False, + bias=bias, + skip_bias_add=False, + is_expert=False, + ) + else: + # Second layer: TERowParallelLinear + return TERowParallelLinear( + input_size=input_size, + output_size=output_size, + config=self.config, + init_method=self.init_method, + bias=bias, + input_is_parallel=True, + skip_bias_add=False, + is_expert=False, + ) + else: + # Fallback to regular tensor parallel layers + if is_first_layer: + return ColumnParallelLinear( + input_size=input_size, + output_size=output_size, + config=self.config, + init_method=self.init_method, + gather_output=False, + bias=bias, + skip_bias_add=False, + is_expert=False, + ) + else: + return RowParallelLinear( + input_size=input_size, + output_size=output_size, + config=self.config, + init_method=self.init_method, + bias=bias, + input_is_parallel=True, + skip_bias_add=False, + is_expert=False, + ) + + def add_router_for_mlp(self): + mlp_list = self.config.mlp_int_list + if self.config.flex_hetero_ffn: + num_mlp = self.config.hybrid_layer_pattern.count("E") + gate_mlp_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, len(mlp_list) * num_mlp, bias=False, is_first_layer=False + ), + ] + # Set bias for the last layer + if ( + hasattr(gate_mlp_layer_list[-1], 'bias') + and gate_mlp_layer_list[-1].bias is not None + ): + last_layer_bias = [0.00 for _ in range(len(mlp_list))] + last_layer_bias[-1] = 1.00 + gate_mlp_layer_list[-1].bias.data = torch.tensor( + last_layer_bias, + dtype=gate_mlp_layer_list[-1].weight.dtype, + device=gate_mlp_layer_list[-1].weight.device, + ).repeat(num_mlp) + else: + gate_mlp_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, len(mlp_list), bias=False, is_first_layer=False + ), + ] + self.gate_mlp = nn.Sequential(*gate_mlp_layer_list) + + def add_router_for_moe_expert(self): + moe_expert_list = 
self.config.moe_expert_int_list + if self.config.flex_hetero_moe_expert: + num_moe_expert = self.config.hybrid_layer_pattern.count("E") + gate_moe_expert_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, + len(moe_expert_list) * num_moe_expert, + bias=False, + is_first_layer=False, + ), + ] + else: + gate_moe_expert_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, len(moe_expert_list), bias=False, is_first_layer=False + ), + ] + self.gate_moe_expert = nn.Sequential(*gate_moe_expert_layer_list) + + def add_router_for_emb(self): + emb_list = self.config.emb_int_list + gate_emb_layer_list = [ + self._create_linear_layer(self.input_dim, self.n_dim, bias=False, is_first_layer=True), + nn.LeakyReLU(0.1), + self._create_linear_layer(self.n_dim, len(emb_list), bias=False, is_first_layer=False), + ] + self.gate_emb = nn.Sequential(*gate_emb_layer_list) + + def add_router_for_skipping(self): + + self.output_dim = int(len(self.config.layer_ranking_list) + 1) + + gate_skip_mlp_layer_list = [ + self._create_linear_layer(self.input_dim, self.n_dim, bias=False, is_first_layer=True), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, self.output_dim, bias=False, is_first_layer=False + ), + ] + + self.gate_skip_layer = nn.Sequential(*gate_skip_mlp_layer_list) + + def add_router_for_mamba(self): + mamba_list = self.config.mamba_int_list + if self.config.flex_hetero_mamba: + num_mamba = self.config.hybrid_layer_pattern.count("M") + gate_mamba_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, len(mamba_list) * num_mamba, bias=False, is_first_layer=False + ), + ] + # Set bias for the last layer + if ( + hasattr(gate_mamba_layer_list[-1], 'bias') + and gate_mamba_layer_list[-1].bias is not None + ): + last_layer_bias = [0.00 for _ in range(len(mamba_list))] + last_layer_bias[-1] = 1.00 + gate_mamba_layer_list[-1].bias.data = torch.tensor(last_layer_bias).repeat( + num_mamba + ) + else: + gate_mamba_layer_list = [ + self._create_linear_layer( + self.input_dim, self.n_dim, bias=False, is_first_layer=True + ), + nn.LeakyReLU(0.1), + self._create_linear_layer( + self.n_dim, len(mamba_list), bias=False, is_first_layer=False + ), + ] + self.gate_mamba = nn.Sequential(*gate_mamba_layer_list) + + def mamba_forward(self, args, budget_tensor, device, dtype, tau, hard_sample): + + # TODO @ataghibakhsh: check router out of sync on TP ranks + + router_mamba_logits1 = self.gate_mamba[0](budget_tensor) + router_mamba_logits2 = self.gate_mamba[1](router_mamba_logits1[0]) + router_mamba_logits = self.gate_mamba[2](router_mamba_logits2)[0].flatten() + # torch.distributed.all_reduce(router_mamba_logits, group=get_tensor_model_parallel_group(), op=torch.distributed.ReduceOp.AVG) + if self.scaler is not None: + scale = self.scaler[args.curr_iteration].to(device=device, dtype=dtype) + + if self.config.flex_hetero_mamba: + mamba_n = len(self.config.mamba_int_list) + router_mamba_logits = router_mamba_logits.reshape(-1, mamba_n) + if self.config.normalize_router_logits: + router_mamba_logits = ( + scale + * router_mamba_logits + / router_mamba_logits.std(dim=1, keepdim=True).clamp(min=1e-6) + ) + else: + router_mamba_logits = scale * router_mamba_logits 
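+
+                # Hetero mode: one Gumbel-softmax sample per row (one row per "M"
+                # layer), so each Mamba layer can pick its own size from mamba_int_list.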
+ router_mamba_logits = self._dp_gumbel_softmax( + router_mamba_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_mamba = torch.topk(router_mamba_logits, 1, dim=-1) + return ( + router_mamba_logits, + [self.config.mamba_int_list[i] for i in choices_mamba.flatten().tolist()], + ) + else: + if self.config.normalize_router_logits: + # Std-normalize only when there's actually >1 choice; with a + # single choice the std is 0 and the routing is trivial, so we + # skip both the scale and the normalization (consistent with + # the no-op semantics of a single-choice axis). + if len(self.config.mamba_int_list) > 1: + router_mamba_logits = ( + scale + * router_mamba_logits + / router_mamba_logits.std(dim=0, keepdim=True).clamp(min=1e-6) + ) + else: + router_mamba_logits = scale * router_mamba_logits + router_mamba_logits = self._dp_gumbel_softmax( + router_mamba_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_mamba = torch.topk(router_mamba_logits, 1, dim=-1) + return (router_mamba_logits, self.config.mamba_int_list[choices_mamba.item()]) + + def mlp_forward(self, args, budget_tensor, device, dtype, tau, hard_sample): + + # TODO @ataghibakhsh: check router out of sync on TP ranks + router_mlp_logits1 = self.gate_mlp[0](budget_tensor) + router_mlp_logits2 = self.gate_mlp[1](router_mlp_logits1[0]) + router_mlp_logits = self.gate_mlp[2](router_mlp_logits2)[0].flatten() + # torch.distributed.all_reduce(router_mlp_logits, group=get_tensor_model_parallel_group(), op=torch.distributed.ReduceOp.AVG) + if self.scaler is not None: + scale = self.scaler[args.curr_iteration].to(device=device, dtype=dtype) + if self.config.flex_hetero_ffn: + mlp_n = len(self.config.mlp_int_list) + router_mlp_logits = router_mlp_logits.reshape(-1, mlp_n) + if self.config.normalize_router_logits: + router_mlp_logits = ( + scale + * router_mlp_logits + / router_mlp_logits.std(dim=1, keepdim=True).clamp(min=1e-6) + ) + else: + router_mlp_logits = scale * router_mlp_logits + router_mlp_logits = self._dp_gumbel_softmax( + router_mlp_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_mlp = torch.topk(router_mlp_logits, 1, dim=-1) + return ( + router_mlp_logits, + [self.config.mlp_int_list[i] for i in choices_mlp.flatten().tolist()], + ) + else: + if self.config.normalize_router_logits: + # Std-normalize only when there's actually >1 choice; with a + # single choice the std is 0 and the routing is trivial, so we + # skip both the scale and the normalization (consistent with + # the no-op semantics of a single-choice axis). 
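+                # (A single-element mlp_int_list yields one logit, and the softmax of
+                # a single logit is identically 1.0, so scaling it would change nothing.)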
+ if len(self.config.mlp_int_list) > 1: + router_mlp_logits = ( + scale + * router_mlp_logits + / router_mlp_logits.std(dim=0, keepdim=True).clamp(min=1e-6) + ) + else: + router_mlp_logits = scale * router_mlp_logits + router_mlp_logits = self._dp_gumbel_softmax( + router_mlp_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_mlp = torch.topk(router_mlp_logits, 1, dim=-1) + return (router_mlp_logits, self.config.mlp_int_list[choices_mlp.item()]) + + def moe_expert_forward(self, args, budget_tensor, device, dtype, tau, hard_sample): + router_moe_expert_logits1 = self.gate_moe_expert[0](budget_tensor) + router_moe_expert_logits2 = self.gate_moe_expert[1](router_moe_expert_logits1[0]) + router_moe_expert_logits = self.gate_moe_expert[2](router_moe_expert_logits2)[0].flatten() + # torch.distributed.all_reduce(router_moe_expert_logits, group=get_tensor_model_parallel_group(), op=torch.distributed.ReduceOp.AVG) + if self.scaler is not None: + scale = self.scaler[args.curr_iteration].to(device=device, dtype=dtype) + if self.config.flex_hetero_moe_expert: + moe_expert_n = len(self.config.moe_expert_int_list) + router_moe_expert_logits = router_moe_expert_logits.reshape(-1, moe_expert_n) + if self.config.normalize_router_logits: + router_moe_expert_logits = ( + scale + * router_moe_expert_logits + / router_moe_expert_logits.std(dim=1, keepdim=True).clamp(min=1e-6) + ) + else: + router_moe_expert_logits = scale * router_moe_expert_logits + router_moe_expert_logits = self._dp_gumbel_softmax( + router_moe_expert_logits, + tau=tau, + hard=hard_sample, + curr_iteration=args.curr_iteration, + ) + _, choices_moe_expert = torch.topk(router_moe_expert_logits, 1, dim=-1) + return ( + router_moe_expert_logits, + [self.config.moe_expert_int_list[i] for i in choices_moe_expert.flatten().tolist()], + ) + else: + if self.config.normalize_router_logits: + # Std-normalize only when there's actually >1 choice; with a + # single choice the std is 0 and the routing is trivial, so we + # skip both the scale and the normalization (consistent with + # the no-op semantics of a single-choice axis). 
+ if len(self.config.moe_expert_int_list) > 1: + router_moe_expert_logits = ( + scale + * router_moe_expert_logits + / router_moe_expert_logits.std(dim=0, keepdim=True).clamp(min=1e-6) + ) + else: + router_moe_expert_logits = scale * router_moe_expert_logits + router_moe_expert_logits = self._dp_gumbel_softmax( + router_moe_expert_logits, + tau=tau, + hard=hard_sample, + curr_iteration=args.curr_iteration, + ) + _, choices_moe_expert = torch.topk(router_moe_expert_logits, 1, dim=-1) + return ( + router_moe_expert_logits, + self.config.moe_expert_int_list[choices_moe_expert.item()], + ) + + def emb_forward(self, args, budget_tensor, device, dtype, tau, hard_sample): + + router_emb_logits1 = self.gate_emb[0](budget_tensor) + router_emb_logits2 = self.gate_emb[1](router_emb_logits1[0]) + router_emb_logits = self.gate_emb[2](router_emb_logits2)[0].flatten() + # torch.distributed.all_reduce(router_emb_logits, group=get_tensor_model_parallel_group(), op=torch.distributed.ReduceOp.AVG) + if self.scaler is not None: + scale = self.scaler[args.curr_iteration].to(device=device, dtype=dtype) + router_emb_logits = scale * router_emb_logits + + # router_emb_logits = F.gumbel_softmax(router_emb_logits, tau=tau, hard=hard_sample) + router_emb_logits = self._dp_gumbel_softmax( + router_emb_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_emb = torch.topk(router_emb_logits, 1, dim=-1) + + return (router_emb_logits, self.config.emb_int_list[choices_emb.item()]) + + def skipping_forward(self, args, budget_tensor, device, dtype, tau, hard_sample): + + # for layer skipping, skipping MLP layers + router_skip_layer_logits1 = self.gate_skip_layer[0](budget_tensor) + router_skip_layer_logits2 = self.gate_skip_layer[1](router_skip_layer_logits1[0]) + router_skip_layer_logits = self.gate_skip_layer[2](router_skip_layer_logits2)[0].flatten() + # torch.distributed.all_reduce(router_skip_layer_logits, group=get_tensor_model_parallel_group(), op=torch.distributed.ReduceOp.AVG) + router_skip_layer_logits = torch.repeat_interleave( + router_skip_layer_logits, repeats=1, dim=0 + ) + if self.scaler is not None: + router_skip_layer_logits = router_skip_layer_logits * self.scaler[ + args.curr_iteration + ].to(device=device, dtype=dtype) + + # router_skip_layer_logits = F.gumbel_softmax(router_skip_layer_logits, tau=tau, hard=hard_sample) + router_skip_layer_logits = self._dp_gumbel_softmax( + router_skip_layer_logits, tau=tau, hard=hard_sample, curr_iteration=args.curr_iteration + ) + _, choices_skip_layer = torch.topk(router_skip_layer_logits, 1, dim=-1) + if choices_skip_layer.item() != 0: + selected_to_drop = self.config.layer_ranking_list[: choices_skip_layer.item()] + choices_skip_layer = torch.zeros(self.config.num_layers).to(device=device, dtype=dtype) + choices_skip_layer[selected_to_drop] = 1 + else: + choices_skip_layer = torch.zeros(self.config.num_layers).to(device=device, dtype=dtype) + return (router_skip_layer_logits, choices_skip_layer) + + def get_curr_tau(self, curr_iteration): + tau = self.config.tau_init * torch.pow(torch.tensor(self.config.tau_decay), curr_iteration) + return tau + + def add_scaler_schedule(self): + + if ( + self.config.linear_scaler_start is not None + and self.config.linear_scaler_end is not None + ): + from megatron.training import get_args + + args = get_args() + self.scaler = torch.linspace( + start=self.config.linear_scaler_start, + end=self.config.linear_scaler_end, + steps=( + args.train_iters + if args.train_iters is not None + else 
(args.train_samples // args.global_batch_size) + ), + ) + else: + self.scaler = None + + def forward(self, budget): + + from megatron.training import get_args + + args = get_args() + + hard_sample = random.random() > self.hard_sample_th + + tau = self.get_curr_tau(args.curr_iteration) + + device, dtype = next(self.parameters()).device, next(self.parameters()).dtype + + if budget in self.budget_map.keys(): + budget_tensor = torch.nn.functional.one_hot( + self.budget_map[budget], len(self.config.budget_list) + ).to(device=device, dtype=dtype) + elif budget == 1.0: + # Requested full model but 1.0 isn't a trained budget — fall back + # to the largest configured budget. Using max() instead of [0] + # makes this independent of budget_list ordering. + budget_tensor = torch.nn.functional.one_hot( + self.budget_map[max(self.budget_map.keys())], len(self.config.budget_list) + ).to(device=device, dtype=dtype) + else: + # budget_list is enforced descending by sort_budget_list_descending + # at config-injection time. We re-sort ascending locally for + # bucketize, then flip(0) below to land back in the descending + # one-hot coordinate system the router was trained against. + budget_values = torch.tensor( + sorted(self.config.budget_list), device=device, dtype=dtype + ) + budget_t = torch.as_tensor(budget, device=device, dtype=dtype) + + # idx2 = first index where budget_values[idx] > budget (right=False gives >= behavior with floats) + idx2 = torch.bucketize(budget_t, budget_values, right=False) + # Clamp to valid interior so we always have a left neighbor + idx2 = idx2.clamp(min=1, max=len(self.config.budget_list) - 1) + idx1 = idx2 - 1 + + b1 = budget_values.index_select(0, idx1.to(torch.long)) + b2 = budget_values.index_select(0, idx2.to(torch.long)) + denom = b2 - b1 # .clamp_min(1e-12) + weight = (budget_t - b1) / denom # in [0,1] when budget is between b1 and b2 + + num_classes = len(self.config.budget_list) + one_hot_1 = torch.nn.functional.one_hot( + idx1.to(torch.long), num_classes=num_classes + ).to(device=device, dtype=dtype) + one_hot_2 = torch.nn.functional.one_hot( + idx2.to(torch.long), num_classes=num_classes + ).to(device=device, dtype=dtype) + + # If weight is scalar, broadcasting works; if vector, it blends per-sample + budget_tensor = (1 - weight).unsqueeze(-1) * one_hot_1 + weight.unsqueeze( + -1 + ) * one_hot_2 + budget_tensor = budget_tensor.squeeze(0).flip(0) + + budget_tensor = budget_tensor.unsqueeze(0) + mlp_forward_outputs = self.mlp_forward(args, budget_tensor, device, dtype, tau, hard_sample) + mamba_forward_outputs = self.mamba_forward( + args, budget_tensor, device, dtype, tau, hard_sample + ) + moe_expert_forward_outputs = self.moe_expert_forward( + args, budget_tensor, device, dtype, tau, hard_sample + ) + + if self.config.add_skipping: + skipping_forward_outputs = self.skipping_forward( + args, budget_tensor, device, dtype, tau, hard_sample + ) + else: + skipping_forward_outputs = None + + emb_forward_outputs = self.emb_forward(args, budget_tensor, device, dtype, tau, hard_sample) + self.fwd_pass_count += 1 + return ( + mlp_forward_outputs, + skipping_forward_outputs, + emb_forward_outputs, + mamba_forward_outputs, + moe_expert_forward_outputs, + ) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 374a4a1b05f..e144e17c3ae 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -8,7 +8,7 @@ import torch from gpt_builders import gpt_builder -from mamba_builders import mamba_builder +from hybrid_builders import 
hybrid_builder from megatron.core.inference.config import ( InferenceConfig, KVCacheManagementMode, @@ -44,8 +44,16 @@ def get_model_for_inference() -> MegatronModule: if args.model_provider == "gpt": model_builder = gpt_builder - elif args.model_provider == "mamba": - model_builder = mamba_builder + elif args.model_provider in ("hybrid", "mamba"): + if args.model_provider == "mamba": + import warnings + + warnings.warn( + '--model-provider "mamba" is deprecated. Use --model-provider "hybrid" instead.', + DeprecationWarning, + stacklevel=2, + ) + model_builder = hybrid_builder else: raise ValueError(f"Invalid model provider {args.model_provider}") @@ -158,7 +166,11 @@ def add_inference_args(parser: ArgumentParser) -> ArgumentParser: "total number of requests. Set to -1 to add all requests together.", ) group.add_argument( - "--model-provider", choices=["mamba", "gpt"], default="gpt", help="Model provider" + "--model-provider", + choices=["hybrid", "mamba", "gpt"], + default="gpt", + help='Model provider. Use "hybrid" for HybridModel (formerly MambaModel). ' + '"mamba" is accepted for backward compatibility but deprecated.', ) group.add_argument( "--skip-prompt-log-probs", action='store_true', default=False, help='Skip prompt log probs.' @@ -370,6 +382,8 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): logging_step_interval=args.inference_logging_step_interval, num_speculative_tokens=args.num_speculative_tokens, use_synchronous_zmq_collectives=args.inference_use_synchronous_zmq_collectives, + disable_ep_consensus=args.inference_disable_ep_consensus, + sampling_backend=args.inference_dynamic_batching_sampling_backend, ) diff --git a/megatron/legacy/model/__init__.py b/megatron/legacy/model/__init__.py index 979d93892b4..556c9077b19 100644 --- a/megatron/legacy/model/__init__.py +++ b/megatron/legacy/model/__init__.py @@ -2,5 +2,3 @@ from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm from .rms_norm import RMSNorm -from .gpt_model import GPTModel -from .language_model import get_language_model diff --git a/megatron/legacy/model/gpt_model.py b/megatron/legacy/model/gpt_model.py deleted file mode 100644 index 66fd0979c46..00000000000 --- a/megatron/legacy/model/gpt_model.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""GPT-2 model.""" - -import torch -from typing import Optional - -from megatron.training import get_args -from megatron.core import tensor_parallel -from megatron.core.utils import deprecate_inference_params - -from .enums import AttnMaskType -from .language_model import parallel_lm_logits -from .language_model import get_language_model -from .module import MegatronModule - - -def post_language_model_processing(lm_output, labels, logit_weights, - parallel_output, - fp16_lm_cross_entropy): - - # Output. 
Format [s b h] - output = parallel_lm_logits( - lm_output, - logit_weights, - parallel_output) - - if labels is None: - # [s b h] => [b s h] - return output.transpose(0,1).contiguous() - else: - # [b s] => [s b] - labels = labels.transpose(0,1).contiguous() - if fp16_lm_cross_entropy: - assert output.dtype == torch.half - loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) - else: - loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) - - # [s b] => [b, s] - loss = loss.transpose(0,1).contiguous() - return loss - - -class GPTModel(MegatronModule): - """GPT-2 Language model.""" - - def __init__(self, - config, - num_tokentypes=0, - parallel_output=True, - pre_process=True, - post_process=True): - args = get_args() - super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) - - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, - pre_process=self.pre_process, - post_process=self.post_process) - - if not args.untie_embeddings_and_output_weights: - self.initialize_word_embeddings() - - def set_input_tensor(self, input_tensor): - """See megatron.legacy.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, input_ids, position_ids, attention_mask, - labels=None, tokentype_ids=None, inference_context=None, *, inference_params=None): - - inference_context = deprecate_inference_params(inference_context, inference_params) - - lm_output = self.language_model( - input_ids, - position_ids, - attention_mask, - inference_context=inference_context) - - if self.post_process: - return post_language_model_processing( - lm_output, labels, - self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), - self.parallel_output, - self.fp16_lm_cross_entropy) - else: - return lm_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - # Save word_embeddings. - if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Load word_embeddings. - if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) - if self._language_model_key in state_dict: - state_dict = state_dict[self._language_model_key] - self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron/legacy/model/language_model.py b/megatron/legacy/model/language_model.py deleted file mode 100644 index 383230edb7f..00000000000 --- a/megatron/legacy/model/language_model.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved. - -"""Transformer based language model.""" - -import torch -import torch.nn.functional as F -from typing import Optional - -from megatron.core import mpu, tensor_parallel -from megatron.core.enums import ModelType -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.utils import deprecate_inference_params -from megatron.training import get_args - -from .enums import AttnMaskType, LayerType -from .module import MegatronModule -from .transformer import ParallelTransformer -from .utils import get_linear_layer, init_method_normal, scaled_init_method_normal - - -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): - """LM logits using word embedding weights.""" - args = get_args() - # Parallel logits. - model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 - if model_parallel or args.sequence_parallel: - input_parallel = input_ - allreduce_dgrad = model_parallel and not args.sequence_parallel - else: - input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) - allreduce_dgrad = False - - # Matrix multiply. - logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce( - input=input_parallel, - weight=word_embeddings_weight, - bias=bias, - gradient_accumulation_fusion=args.gradient_accumulation_fusion, - sequence_parallel=args.sequence_parallel, - grad_output_buffer=None, - allreduce_dgrad=allreduce_dgrad, - ) - # Gather if needed. - - if parallel_output: - return logits_parallel - - return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) - - -def get_language_model( - config, - num_tokentypes, - add_pooler, - encoder_attn_mask_type, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - pre_process=True, - post_process=True, -): - """Build language model and return along with the key to save.""" - args = get_args() - if config.init_method is None: - config.init_method = init_method_normal(config.init_method_std) - - if config.output_layer_init_method is None: - config.output_layer_init_method = scaled_init_method_normal( - config.init_method_std, config.num_layers - ) - - # Language model. - language_model = TransformerLanguageModel( - config, - encoder_attn_mask_type, - num_tokentypes=num_tokentypes, - add_encoder=add_encoder, - add_decoder=add_decoder, - decoder_attn_mask_type=decoder_attn_mask_type, - add_pooler=add_pooler, - pre_process=pre_process, - post_process=post_process, - ) - # key used for checkpoints. - language_model_key = 'language_model' - - return language_model, language_model_key - - -class Pooler(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. - - Args: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. - """ - - def __init__(self, hidden_size, init_method): - super(Pooler, self).__init__() - args = get_args() - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.sequence_parallel = args.sequence_parallel - - def forward(self, hidden_states, sequence_index=0): - # hidden_states: [s, b, h] - # sequence_index: index of the token to pool. 
- - # gather data along sequence dimensions - # same pooler is run on all tensor parallel nodes - if self.sequence_parallel: - hidden_states = tensor_parallel.gather_from_sequence_parallel_region( - hidden_states, tensor_parallel_output_grad=False - ) - - pooled = hidden_states[sequence_index, :, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled - - -class Embedding(MegatronModule): - """Language model embeddings. - - Args: - hidden_size: hidden size - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - init_method: weight initialization method - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - config, - num_tokentypes=0, - ): - super(Embedding, self).__init__() - - self.hidden_size = hidden_size - self.init_method = config.init_method - self.num_tokentypes = num_tokentypes - - args = get_args() - - # Word embeddings (parallel). - self.params_dtype = args.params_dtype - self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - vocab_size, self.hidden_size, config=config, init_method=config.init_method - ) - self._word_embeddings_key = 'word_embeddings' - - # Position embedding (serial). - self.add_position_embedding = args.position_embedding_type == 'learned_absolute' - if self.add_position_embedding: - self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' - # Initialize the position embeddings. - if args.perform_initialization: - self.init_method(self.position_embeddings.weight) - - # Token type embedding. - # Add this as an optional field that can be added through - # method call so we can load a pretrain model without - # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' - if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) - # Initialize the token-type embeddings. - if args.perform_initialization: - self.init_method(self.tokentype_embeddings.weight) - else: - self.tokentype_embeddings = None - - self.fp32_residual_connection = args.fp32_residual_connection - self.sequence_parallel = args.sequence_parallel - self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding - # Embeddings dropout - self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) - - def zero_parameters(self): - """Zero out all parameters in embedding.""" - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - if self.add_position_embedding: - self.position_embeddings.weight.data.fill_(0) - self.position_embeddings.weight.shared = True - if self.num_tokentypes > 0: - self.tokentype_embeddings.weight.data.fill_(0) - self.tokentype_embeddings.weight.shared = True - - def add_tokentype_embeddings(self, num_tokentypes): - """Add token-type embedding. This function is provided so we can add - token-type embeddings in case the pretrained model does not have it. - This allows us to load the model normally and then add this embedding. 
- """ - if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') - if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) - self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) - # Initialize the token-type embeddings. - args = get_args() - self.init_method(self.tokentype_embeddings.weight) - - def forward(self, input_ids, position_ids, tokentype_ids=None): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - if self.add_position_embedding: - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings - else: - embeddings = words_embeddings - - if tokentype_ids is not None: - assert self.tokentype_embeddings is not None - embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) - else: - assert self.tokentype_embeddings is None - - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.sequence_parallel: - embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) - # `scatter_to_sequence_parallel_region` returns a view, which prevents - # the original tensor from being garbage collected. Clone to facilitate GC. - # Has a small runtime cost (~0.5%). - if self.clone_scatter_output_in_embedding: - embeddings = embeddings.clone() - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load.""" - - state_dict_ = {} - state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict( - prefix=prefix, keep_vars=keep_vars - ) - if self.add_position_embedding: - state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict( - prefix=prefix, keep_vars=keep_vars - ) - if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( - prefix=prefix, keep_vars=keep_vars - ) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Word embedding. - if self._word_embeddings_key in state_dict: - state_dict_ = state_dict[self._word_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] = state_dict[key] - self.word_embeddings.load_state_dict(state_dict_, strict=strict) - - # Position embedding. - if self.add_position_embedding: - if self._position_embeddings_key in state_dict: - state_dict_ = state_dict[self._position_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] = state_dict[key] - self.position_embeddings.load_state_dict(state_dict_, strict=strict) - - # Tokentype embedding. - if self.num_tokentypes > 0: - state_dict_ = {} - if self._tokentype_embeddings_key in state_dict: - state_dict_ = state_dict[self._tokentype_embeddings_key] - else: - # for backward compatibility. 
- for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key] - if len(state_dict_.keys()) > 0: - self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) - else: - print( - '***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True, - ) - - -class TransformerLanguageModel(MegatronModule): - """Transformer language model. - - Args: - transformer_hparams: transformer hyperparameters - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - config, - encoder_attn_mask_type, - num_tokentypes=0, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - add_pooler=False, - pre_process=True, - post_process=True, - ): - args = get_args() - # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. - if args.untie_embeddings_and_output_weights: - assert not add_decoder - super(TransformerLanguageModel, self).__init__( - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights - ) - - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = config.hidden_size - self.num_tokentypes = num_tokentypes - self.init_method = config.init_method - self.add_encoder = add_encoder - self.encoder_attn_mask_type = encoder_attn_mask_type - self.add_decoder = add_decoder - self.decoder_attn_mask_type = decoder_attn_mask_type - self.add_pooler = add_pooler - self.encoder_hidden_state = None - self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - - # Embeddings. - if self.pre_process: - self.embedding = Embedding( - self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - config, - self.num_tokentypes, - ) - self._embedding_key = 'embedding' - - # Rotary positional embeddings - self.use_rotary_position_embeddings = args.position_embedding_type == 'rope' - if self.use_rotary_position_embeddings: - self.seq_length = args.seq_length - rotary_dim = ( - args.hidden_size // args.num_attention_heads - if args.kv_channels is None - else args.kv_channels - ) - - # partial rotary embeddings, which is better than full rotary - # Wang and Komatsuzaki et al - # https://github.com/kingoflolz/mesh-transformer-jax/ - self.rotary_pos_emb = RotaryEmbedding( - kv_channels=rotary_dim, - rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, - ) - - # Encoder (usually set to True, False if part of an encoder-decoder - # architecture and in encoder-only stage). - if self.add_encoder: - self.encoder = ParallelTransformer( - config, - model_type=args.model_type, - self_attn_mask_type=self.encoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._encoder_key = 'encoder' - else: - self.encoder = None - - # Decoder (usually set to False, True if part of an encoder-decoder - # architecture and in decoder-only stage). 
- if self.add_decoder: - self.decoder = ParallelTransformer( - config, - model_type=args.model_type, - layer_type=LayerType.decoder, - self_attn_mask_type=self.decoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._decoder_key = 'decoder' - else: - self.decoder = None - - if self.post_process: - # Pooler. - if self.add_pooler: - self.pooler = Pooler(self.hidden_size, self.init_method) - self._pooler_key = 'pooler' - - if self.untie_embeddings_and_output_weights: - self.output_layer = tensor_parallel.ColumnParallelLinear( - args.hidden_size, - args.padded_vocab_size, - config=config, - init_method=self.init_method, - bias=False, - ) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias. - self._output_layer_key = 'output_layer' - - def set_input_tensor(self, input_tensor): - """See megatron.legacy.model.transformer.set_input_tensor()""" - - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - if self.add_encoder and self.add_decoder: - assert ( - len(input_tensor) == 1 - ), 'input_tensor should only be length 1 for stage with both encoder and decoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_encoder: - assert ( - len(input_tensor) == 1 - ), 'input_tensor should only be length 1 for stage with only encoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_decoder: - if len(input_tensor) == 2: - self.decoder.set_input_tensor(input_tensor[0]) - self.encoder_hidden_state = input_tensor[1] - elif len(input_tensor) == 1: - self.decoder.set_input_tensor(None) - self.encoder_hidden_state = input_tensor[0] - else: - raise Exception('input_tensor must have either length 1 or 2') - else: - raise Exception('Stage must have at least either encoder or decoder') - - def forward( - self, - enc_input_ids, - enc_position_ids, - enc_attn_mask, - dec_input_ids=None, - dec_position_ids=None, - dec_attn_mask=None, - enc_dec_attn_mask=None, - tokentype_ids=None, - inference_context=None, - pooling_sequence_index=0, - enc_hidden_states=None, - output_enc_hidden=False, - *, - inference_params: Optional[BaseInferenceContext] = None, - ): - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding( - enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids - ) - else: - encoder_input = None - - # Rotary positional embeddings - rotary_pos_emb = None - if self.use_rotary_position_embeddings: - if inference_context is not None: - rotary_pos_emb = self.rotary_pos_emb(inference_context.max_sequence_length) - else: - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - - # Run encoder. - if enc_hidden_states is None: - if self.encoder is not None: - encoder_output = self.encoder( - encoder_input, - enc_attn_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - ) - else: - encoder_output = self.encoder_hidden_state - else: - encoder_output = enc_hidden_states.to(encoder_input.dtype) - - if self.post_process: - if self.add_pooler: - pooled_output = self.pooler(encoder_output, pooling_sequence_index) - - # output_enc_hidden refers to when we just need the encoder's - # output. 
For example, it is helpful to compute - # similarity between two sequences by average pooling - if not self.add_decoder or output_enc_hidden: - if self.add_pooler and self.post_process: - return encoder_output, pooled_output - else: - return encoder_output - - # Decoder embedding. - if self.pre_process: - decoder_input = self.embedding(dec_input_ids, dec_position_ids) - else: - decoder_input = None - - # Run decoder. - decoder_output = self.decoder( - decoder_input, - dec_attn_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - ) - - if self.add_pooler and self.post_process: - return decoder_output, encoder_output, pooled_output - else: - return decoder_output, encoder_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load.""" - - state_dict_ = {} - if self.pre_process: - state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - if self.add_encoder: - state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - if self.post_process: - if self.add_pooler: - state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - if self.untie_embeddings_and_output_weights: - state_dict_[self._output_layer_key] = self.output_layer.state_dict( - prefix=prefix, keep_vars=keep_vars - ) - - if self.add_decoder: - state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Embedding. - if self.pre_process: - if self._embedding_key in state_dict: - state_dict_ = state_dict[self._embedding_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if '_embeddings' in key: - state_dict_[key] = state_dict[key] - self.embedding.load_state_dict(state_dict_, strict=strict) - - # Encoder. - if self.add_encoder: - if self._encoder_key in state_dict: - state_dict_ = state_dict[self._encoder_key] - # For backward compatibility. - elif 'transformer' in state_dict: - state_dict_ = state_dict['transformer'] - else: - # For backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] - - # For backward compatibility. - state_dict_self_attention = {} - for key in state_dict_.keys(): - if '.attention.' in key: - state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = ( - state_dict_[key] - ) - else: - state_dict_self_attention[key] = state_dict_[key] - state_dict_ = state_dict_self_attention - - self.encoder.load_state_dict(state_dict_, strict=strict) - - # Pooler. - if self.post_process: - if self.add_pooler: - assert 'pooler' in state_dict, 'could not find data for pooler in the checkpoint' - self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict) - if self.untie_embeddings_and_output_weights: - assert ( - 'output_layer' in state_dict - ), 'could not find data for output_layer in the checkpoint' - self.output_layer.load_state_dict(state_dict[self._output_layer_key], strict=strict) - # Decoder. 
- if self.add_decoder: - assert 'decoder' in state_dict, 'could not find data for pooler in the checkpoint' - self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) diff --git a/megatron/legacy/model/module.py b/megatron/legacy/model/module.py deleted file mode 100644 index 3a9651df99e..00000000000 --- a/megatron/legacy/model/module.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Megatron Module""" - -import torch -from torch.autograd import Variable -from torch.nn.parameter import Parameter - -from megatron.training import get_args -from megatron.core import mpu, tensor_parallel - - -_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) -_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) -_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) - - -class MegatronModule(torch.nn.Module): - """Megatron specific extensions of torch Module with support - for pipelining.""" - - def __init__(self, config=None, share_embeddings_and_output_weights=True): - super(MegatronModule, self).__init__() - self.config = config - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """Use this function to override the state dict for - saving checkpoints.""" - return self.state_dict(prefix=prefix, keep_vars=keep_vars) - - def shared_embedding_or_output_weight(self): - if self.pre_process: - return self.language_model.embedding.word_embeddings.weight - else: - if not self.share_embeddings_and_output_weights: - raise Exception( - 'shared_embedding_or_output_weight() called for last ' - 'stage, but share_embeddings_and_output_weights is false' - ) - return self.word_embeddings.weight - - def initialize_word_embeddings(self): - args = get_args() - if not self.share_embeddings_and_output_weights: - raise Exception( - 'initialize_word_embeddings() was called but ' - 'share_embeddings_and_output_weights is false' - ) - - # This function just initializes the word embeddings in the final stage - # when we are using pipeline parallelism. Nothing to do if we aren't - # using pipeline parallelism. - if args.pipeline_model_parallel_size == 1: - # Zero out wgrad if sharing embeddings between two layers on same - # pipeline stage to make sure grad accumulation into main_grad is - # correct and does not include garbage values (e.g., from torch.empty). - self.shared_embedding_or_output_weight().zero_out_wgrad = True - return - - if ( - mpu.is_pipeline_first_stage(ignore_virtual=False) - and self.pre_process - and not self.post_process - ): - self.shared_embedding_or_output_weight().shared_embedding = True - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. 
- if mpu.is_pipeline_last_stage(ignore_virtual=False) and not self.pre_process: - assert not mpu.is_pipeline_first_stage(ignore_virtual=False) - self._word_embeddings_for_head_key = 'word_embeddings_for_head' - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - args.padded_vocab_size, - self.config.hidden_size, - config=self.config, - init_method=self.config.init_method, - ) - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - self.word_embeddings.weight.shared_embedding = True - - # Zero out initial weights for decoder embedding. - # NOTE: We don't currently support T5 with the interleaved schedule. - if not mpu.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process: - self.language_model.embedding.zero_parameters() - - if not torch.distributed.is_initialized(): - if not getattr(MegatronModule, "embedding_warning_printed", False): - print( - "WARNING! Distributed processes aren't initialized, so " - "word embeddings in the last layer are not initialized. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong." - ) - MegatronModule.embedding_warning_printed = True - return - - # Ensure that first and last stages have the same initial parameter - # values. - if mpu.is_rank_in_embedding_group(ignore_virtual=False): - self.shared_embedding_or_output_weight().data = ( - self.shared_embedding_or_output_weight().data.cuda() - ) - torch.distributed.all_reduce( - self.shared_embedding_or_output_weight().data, group=mpu.get_embedding_group() - ) - - -def conversion_helper(val, conversion): - """Apply conversion to val. Recursively apply conversion if `val` - #is a nested tuple/list structure.""" - if not isinstance(val, (tuple, list)): - return conversion(val) - rtn = [conversion_helper(v, conversion) for v in val] - if isinstance(val, tuple): - rtn = tuple(rtn) - return rtn - - -def fp32_to_float16(val, float16_convertor): - """Convert fp32 `val` to fp16/bf16""" - - def half_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, _FLOAT_TYPES): - val = float16_convertor(val) - return val - - return conversion_helper(val, half_conversion) - - -def float16_to_fp32(val): - """Convert fp16/bf16 `val` to fp32""" - - def float_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): - val = val.float() - return val - - return conversion_helper(val, float_conversion) diff --git a/megatron/post_training/arguments.py b/megatron/post_training/arguments.py index dc98c6d28e4..dc459586df3 100644 --- a/megatron/post_training/arguments.py +++ b/megatron/post_training/arguments.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. def add_modelopt_args(parser): @@ -10,8 +10,9 @@ def add_modelopt_args(parser): "--export-model-type", type=str, default="GPTModel", - choices=["GPTModel", "MambaModel"], - help="Model type to use in model_provider.", + choices=["GPTModel", "HybridModel", "MambaModel"], + help='Model type to use in model_provider. Use "HybridModel" for hybrid models ' + '(formerly MambaModel). 
"MambaModel" is accepted for backward compatibility but deprecated.', ) group.add_argument( "--export-legacy-megatron", @@ -90,7 +91,10 @@ def add_modelopt_args(parser): "--finetune-hf-dataset", type=str, default=None, help="HF dataset used for finetuning." ) group.add_argument( - "--finetune-data-split", type=str, default="train", help="HF dataset split used for finetuning." + "--finetune-data-split", + type=str, + default="train", + help="HF dataset split used for finetuning.", ) # Special model architecture option @@ -124,7 +128,7 @@ def add_modelopt_args(parser): '--enable-gpt-oss', action="store_true", help='Enable GPT-OSS mode with YaRN RoPE configuration. When enabled, automatically ' - 'configures all YaRN parameters with GPT-OSS defaults.', + 'configures all YaRN parameters with GPT-OSS defaults.', ) return parser diff --git a/megatron/post_training/model_builder.py b/megatron/post_training/model_builder.py index 085d188e811..eac6500c1a2 100644 --- a/megatron/post_training/model_builder.py +++ b/megatron/post_training/model_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. """ModelOpt GPT model provider.""" @@ -16,17 +16,16 @@ from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( get_gpt_heterogeneous_layer_spec, ) -from megatron.core.models.mamba import MambaModel as MCoreMambaModel +from megatron.core.models.hybrid.hybrid_model import HybridModel as MCoreHybridModel from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec from megatron.core.post_training.modelopt.gpt.state_dict_hooks import ( mcore_gpt_load_te_state_dict_pre_hook, ) -from megatron.post_training.checkpointing import load_modelopt_checkpoint, load_modelopt_state +from megatron.post_training.checkpointing import load_modelopt_state +from megatron.post_training.utils import print_distributed_quant_summary from megatron.training import get_args, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args -from megatron.post_training.utils import print_distributed_quant_summary - def count_parameters_in_layer(model, layer_name): num_params = 0 @@ -38,8 +37,7 @@ def count_parameters_in_layer(model, layer_name): def _add_load_convert_hooks(model: MCoreGPTModel): - """Register some load_state_dict prehooks to handle some known state_dict key mismatch. 
- """ + """Register some load_state_dict prehooks to handle some known state_dict key mismatch.""" args = get_args() if args.export_te_mcore_model: model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook) @@ -104,17 +102,28 @@ def _load_teacher_model_config(checkpoint_path: str) -> Namespace: args_dict = vars(get_args()).copy() del args_dict["kv_channels"] # not recalculated if present + # Setting teacher Flextron fields to false if training with Flextron, can be overridden + if "flextron" in args_dict: + config["flextron"] = False + if "enable_router" in args_dict: + config["enable_router"] = False + if "freeze_model" in args_dict: + config["freeze_model"] = False args_dict.update(config) # Backward compat: old checkpoints have hybrid_override_pattern but not hybrid_layer_pattern - if (args_dict.get('hybrid_override_pattern') is not None - and args_dict.get('hybrid_layer_pattern') is None): + if ( + args_dict.get('hybrid_override_pattern') is not None + and args_dict.get('hybrid_layer_pattern') is None + ): args_dict['hybrid_layer_pattern'] = args_dict['hybrid_override_pattern'] return Namespace(**args_dict) -def _load_teacher_model(config, config_raw: Namespace, model_kwargs: Dict[str, Any]) -> MCoreGPTModel: +def _build_teacher_model( + config, config_raw: Namespace, model_kwargs: Dict[str, Any] +) -> MCoreGPTModel: """Teacher model creator.""" args = get_args() @@ -124,48 +133,35 @@ def _load_teacher_model(config, config_raw: Namespace, model_kwargs: Dict[str, A # _load_teacher_model_config, so config_raw.hybrid_layer_pattern is always set here. model_kwargs["hybrid_layer_pattern"] = config_raw.hybrid_layer_pattern - teacher = MCoreMambaModel(config=config, **model_kwargs) + teacher = MCoreHybridModel(config=config, **model_kwargs) else: # GPT layer spec needs re-creation since it depends on number of model layers. if config.heterogeneous_block_specs: model_kwargs["transformer_layer_spec"] = get_gpt_heterogeneous_layer_spec( - config=config, - use_te=(args.transformer_impl == "transformer_engine"), + config=config, use_te=(args.transformer_impl == "transformer_engine") ) else: model_kwargs["transformer_layer_spec"] = get_gpt_modelopt_spec( config=config, - local_core_attention=False if config.context_parallel_size > 1 else args.export_force_local_attention, + local_core_attention=( + False if config.context_parallel_size > 1 else args.export_force_local_attention + ), remap_te_layernorm=args.export_te_mcore_model, real_quant_cfg=args.export_real_quant_cfg, use_arbitrary_attention_mask=False, ) teacher = MCoreGPTModel(config=config, **model_kwargs) + _add_load_convert_hooks(teacher) - print_rank_0(f"Loading teacher as {type(teacher).__name__} from {args.export_kd_teacher_load} ...") - # [WAR]: load checkpoint will check checkpoint's saved args and rng state if not finetune. - # To avoid error out on loading teacher's checkpoint, we temporarily set args.finetune to - # True while loading the teacher checkpoint. - original_args_finetune, original_ckpt_format = args.finetune, args.ckpt_format - args.finetune = True - if args.export_kd_teacher_ckpt_format is not None: - args.ckpt_format = args.export_kd_teacher_ckpt_format - load_modelopt_checkpoint([teacher], load_arg='export_kd_teacher_load') - args.finetune, args.ckpt_format = original_args_finetune, original_ckpt_format - print_rank_0("...teacher loaded successfully.") + # NOTE: Checkpoint loading now handled in `megatron/training/checkpointing.py`. 
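+    # (Hence the rename from _load_teacher_model: this function only builds the module.)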
return teacher -def modelopt_gpt_mamba_builder( - args, - pre_process, - post_process, - vp_stage=None, - config=None, - pg_collection=None, -) -> MCoreGPTModel | MCoreMambaModel: +def modelopt_gpt_hybrid_builder( + args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None +) -> MCoreGPTModel | MCoreHybridModel: """Builds the model. Args: @@ -179,7 +175,7 @@ def modelopt_gpt_mamba_builder( attached to the returned model for downstream routing/resharding utilities. Returns: - MCoreGPTModel | MCoreMambaModel: The returned model + MCoreGPTModel | MCoreHybridModel: The returned model """ print_rank_0("building GPT model ...") @@ -202,10 +198,8 @@ def modelopt_gpt_mamba_builder( config.yarn_correction_range_round_to_int = False if vp_stage is not None: - raise ValueError("ModelOpt integration does not currently support virtual pipeline parallel.") - if args.use_legacy_models: raise ValueError( - "ModelOpt integration only support MCore models. Use --use-mcore-modules instead." + "ModelOpt integration does not currently support virtual pipeline parallel." ) if args.spec is not None: raise ValueError("ModelOpt integration does not support custom args.spec.") @@ -225,15 +219,14 @@ def modelopt_gpt_mamba_builder( config.sequence_parallel = False if config.heterogeneous_block_specs: transformer_layer_spec = get_gpt_heterogeneous_layer_spec( - config=config, - use_te=args.transformer_impl == "transformer_engine", + config=config, use_te=args.transformer_impl == "transformer_engine" ) else: if config.context_parallel_size > 1: print_rank_0("context_parallel_size > 1! Force using TEDotProductAttention!") - local_core_attention=False + local_core_attention = False else: - local_core_attention=args.export_force_local_attention + local_core_attention = args.export_force_local_attention transformer_layer_spec = get_gpt_modelopt_spec( config=config, @@ -259,8 +252,22 @@ def modelopt_gpt_mamba_builder( "pg_collection": pg_collection, } model = MCoreGPTModel(config=config, **model_kwargs) - elif args.export_model_type == "MambaModel" or getattr(args, 'hybrid_layer_pattern', None) is not None: - from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec + elif ( + args.export_model_type in ("HybridModel", "MambaModel") + or getattr(args, 'hybrid_layer_pattern', None) is not None + ): + if args.export_model_type == "MambaModel": + import warnings + + warnings.warn( + '--export-model-type "MambaModel" is deprecated. 
' + 'Use --export-model-type "HybridModel" instead.', + DeprecationWarning, + stacklevel=2, + ) + from megatron.core.post_training.modelopt.hybrid.model_specs import ( + get_hybrid_stack_modelopt_spec, + ) if args.export_default_te_spec and args.export_te_mcore_model: logging.getLogger(__name__).warning( @@ -269,12 +276,12 @@ def modelopt_gpt_mamba_builder( ) args.export_te_mcore_model = False - mamba_stack_spec = get_mamba_stack_modelopt_spec( + hybrid_stack_spec = get_hybrid_stack_modelopt_spec( remap_te_layernorm=args.export_te_mcore_model, use_default_te_spec=args.export_default_te_spec, ) model_kwargs = { - "mamba_stack_spec": mamba_stack_spec, + "hybrid_stack_spec": hybrid_stack_spec, "vocab_size": args.padded_vocab_size, "max_sequence_length": args.max_position_embeddings, "hybrid_layer_pattern": args.hybrid_layer_pattern, @@ -289,7 +296,7 @@ def modelopt_gpt_mamba_builder( "pg_collection": pg_collection, } - model = MCoreMambaModel(config=config, **model_kwargs) + model = MCoreHybridModel(config=config, **model_kwargs) for l in range(model.decoder.num_layers_per_pipeline_rank): layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') @@ -332,13 +339,15 @@ def modelopt_gpt_mamba_builder( ), "ModelOpt Distillation currently incompatible with interleaved pipeline schedule." teacher_config_raw = _load_teacher_model_config(args.export_kd_teacher_load) - teacher_config = core_transformer_config_from_args(teacher_config_raw) # convert to TransformerConfig + teacher_config = core_transformer_config_from_args( + teacher_config_raw + ) # convert to TransformerConfig distill_cfg = mtd_mcore.setup_distillation_config( args.export_kd_cfg, student_cfg=config, teacher_cfg=teacher_config ) kd_config = { - "teacher_model": _load_teacher_model(teacher_config, teacher_config_raw, model_kwargs), + "teacher_model": _build_teacher_model(teacher_config, teacher_config_raw, model_kwargs), "criterion": distill_cfg.criterion, "loss_balancer": distill_cfg.loss_balancer, } @@ -348,7 +357,13 @@ def modelopt_gpt_mamba_builder( # (accounts for sharded state, pipeline parallel, and potentially skipping LM loss) mtd_mcore.adjust_distillation_model_for_mcore(model, distill_cfg) # Also remove KD mode state to prevent issues with re-conversion after restore. - mto.ModeloptStateManager(model).state_dict().pop() # TODO(aanoosheh): remove once fixed in ModelOpt - + mto.ModeloptStateManager( + model + ).state_dict().pop() # TODO(aanoosheh): remove once fixed in ModelOpt + print_distributed_quant_summary(model) return model + + +# Backward-compatible alias +modelopt_gpt_mamba_builder = modelopt_gpt_hybrid_builder diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index ff3622bac2d..4d83a8f8954 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -2030,8 +2030,8 @@ def megatron_rl_inference_mode( logger.debug(f"[{dist.get_rank()}] Entering inference mode") - # Change cudagraph scope for inference (empty list = full-layer capture) - model[0].config.cuda_graph_scope = [] + # Set cudagraph scope for inference. + model[0].config.cuda_graph_scope = args.cuda_graph_scope model[0].config.cuda_graph_impl = "local" # If we get a lower precision wrapper, we go one object deeper. @@ -2091,7 +2091,8 @@ def megatron_rl_inference_mode( # Reset drop_and_pad leaked from inference decode set_decode_expert_padding(unwrap_model(model[0]), set_to=False) - # Restore partial capture cudagraph scope for training if this is MoE + # Restore cudagraph scope for training. 
+ # MoE partial capture requires specific scopes that aren't user-facing. if args.num_experts is not None: model[0].config.cuda_graph_scope = [ CudaGraphScope.mamba, @@ -2099,6 +2100,10 @@ def megatron_rl_inference_mode( CudaGraphScope.moe_router, CudaGraphScope.moe_preprocess, ] + else: + model[0].config.cuda_graph_scope = [ + s for s in args.cuda_graph_scope if s != CudaGraphScope.full_iteration_inference + ] # Switch MoE layers to partial CUDA graph capture for training if args.rl_training_cuda_graphs and args.num_experts is not None: diff --git a/megatron/training/activation_logging.py b/megatron/training/activation_logging.py index 97f69e20b09..167712aadea 100644 --- a/megatron/training/activation_logging.py +++ b/megatron/training/activation_logging.py @@ -80,6 +80,25 @@ def _discover_te_types(): ) +def _parse_tpe_module_name(module_name: str) -> Tuple[str, int | None, int] | None: + """Parse a TPE-eligible module name into ``(block, mtp_idx, layer)``. + + Returns ``None`` if *module_name* matches neither the decoder nor the MTP pattern. + + Examples:: + + decoder.layers.3.mlp.experts.linear_fc1 -> ("decoder", None, 3) + mtp.layers.0.mtp_model_layer.layers.1.mlp.experts.linear_fc1 -> ("mtp", 0, 1) + """ + if m := re.fullmatch(r'decoder\.layers\.(\d+)\.mlp\.experts\.linear_fc1', module_name): + return "decoder", None, int(m.group(1)) + if m := re.fullmatch( + r'mtp\.layers\.(\d+)\.mtp_model_layer\.layers\.(\d+)\.mlp\.experts\.linear_fc1', module_name + ): + return "mtp", int(m.group(1)), int(m.group(2)) + return None + + def _register_hooks(model, module_types, hook_factory, *, name_filter=None): """Walk *model* and register a forward hook on every module matching *module_types*. @@ -126,8 +145,10 @@ def __init__(self, save_dir: str): self._activations_state_dict: defaultdict = defaultdict(dict) self._activation_hooks: List[torch.utils.hooks.RemovableHook] = [] - # Tokens-per-expert state: layer -> list of per-microbatch token counts. - self._tpe_records: dict[str, list] = defaultdict(list) + # Tokens-per-expert state: per-microbatch token counts. Decoder entries + # are keyed by ``layer``; MTP entries by ``(mtp_idx, inner_layer)``. + self._decoder_tpe_records: dict[int, list[list[int]]] = defaultdict(list) + self._mtp_tpe_records: dict[Tuple[int, int], list[list[int]]] = defaultdict(list) self._tpe_hooks: List[torch.utils.hooks.RemovableHook] = [] # ------------------------------------------------------------------ @@ -179,28 +200,34 @@ def save_activations(self, iteration: int): # Tokens-per-expert hooks # ------------------------------------------------------------------ - def _make_tpe_hook(self, _model_chunk_name: str, module_name: str) -> Callable: + def _make_tpe_hook(self, _model_chunk_name: str, module_name: str) -> Callable | None: """Forward hook that captures only the non-Tensor ``input1`` (tokens_per_expert). - The layer number is extracted from *module_name* - (e.g. ``decoder.layers.3.mlp.experts.linear_fc1`` → ``3``). + Attaches to main decoder MoE layers + (``decoder.layers..mlp.experts.linear_fc1``) and MTP MoE layers + (``mtp.layers..mtp_model_layer.layers..mlp.experts.linear_fc1``). + Returns ``None`` (and logs a warning) for any other module name. 
""" - m = re.search(r'\.layers\.(\d+)\.', module_name) - if not m: + parsed = _parse_tpe_module_name(module_name) + if parsed is None: logger.warning( "Cannot extract layer number from module name: %r — " "skipping tokens-per-expert hook for this module", module_name, ) return None - layer = m.group(1) + block, mtp_idx, layer = parsed + if block == "decoder": + records, key = self._decoder_tpe_records, layer + else: + records, key = self._mtp_tpe_records, (mtp_idx, layer) def hook(_, args, kwargs, output): input_tuple = args if isinstance(args, tuple) else (args,) if len(input_tuple) > 1 and input_tuple[1] is not None: inp = input_tuple[1] if not isinstance(inp, torch.Tensor): - self._tpe_records[layer].append(list(inp)) + records[key].append(list(inp)) return hook @@ -222,23 +249,45 @@ def save_tpe(self, iteration: int): """Append captured tokens-per-expert records as JSON Lines. Each rank writes to its own file under ``{save_dir}/tokens_per_expert/``, - e.g. ``rank0.jsonl``, ``rank1.jsonl``. Each line is a JSON object:: + e.g. ``rank0.jsonl``, ``rank1.jsonl``. Each line is a JSON object; the + ``mtp_idx`` field is present only for MTP entries:: - {"iter": 100, "layer": 3, "tpe": [[128, 64], [96, 80]]} + {"iter": 100, "block": "decoder", "layer": 3, "tpe": [[128, 64], [96, 80]]} + {"iter": 100, "block": "mtp", "mtp_idx": 0, "layer": 1, "tpe": [[50, 50]]} """ - if not self._tpe_records: + if not self._decoder_tpe_records and not self._mtp_tpe_records: return rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 tpe_dir = os.path.join(self._save_dir, "tokens_per_expert") os.makedirs(tpe_dir, exist_ok=True) filepath = os.path.join(tpe_dir, f"rank{rank}.jsonl") - lines = "".join( - json.dumps({"iter": iteration, "layer": int(layer), "tpe": microbatches}) + "\n" - for layer, microbatches in sorted(self._tpe_records.items()) - ) + + lines = [] + for layer, microbatches in sorted(self._decoder_tpe_records.items()): + lines.append( + json.dumps( + {"iter": iteration, "block": "decoder", "layer": layer, "tpe": microbatches} + ) + + "\n" + ) + for (mtp_idx, layer), microbatches in sorted(self._mtp_tpe_records.items()): + lines.append( + json.dumps( + { + "iter": iteration, + "block": "mtp", + "mtp_idx": mtp_idx, + "layer": layer, + "tpe": microbatches, + } + ) + + "\n" + ) + with open(filepath, "a") as f: - f.write(lines) - self._tpe_records.clear() + f.writelines(lines) + self._decoder_tpe_records.clear() + self._mtp_tpe_records.clear() _LOGGER: ActivationLogger | None = None diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index b9f7c7b22d1..611047aed9f 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -1,30 +1,48 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+import ast +import builtins import dataclasses -import typing -import types -from typing import Any, Optional -from argparse import ArgumentParser, _ArgumentGroup +import enum import inspect import itertools -import builtins -import ast -import enum +import types +import typing +from argparse import ArgumentParser, Namespace, _ArgumentGroup from dataclasses import Field, fields +from typing import Any, Callable, Optional + +from megatron.training.config import ( + CheckpointConfig, + DistributedInitConfig, + LoggerConfig, + PretrainConfigContainer, + ProfilingConfig, + RerunStateMachineConfig, + RNGConfig, + SchedulerConfig, + StragglerDetectionConfig, + TokenizerConfig, + TrainingConfig, + ValidationConfig, +) # TODO: support arg renames + class TypeInferenceError(Exception): """Custom exception type to be conditionally handled by ArgumentGroupFactory.""" + pass + class ArgumentGroupFactory: """Utility that adds an argument group to an ArgumentParser based on the attributes of a dataclass. This utility uses dataclass metadata including type annotations and docstrings to automatically infer the type, default, and other argparse keyword arguments. - You can override or supplement the automatically inferred argparse kwargs for any + You can override or supplement the automatically inferred argparse kwargs for any dataclass field by providing an "argparse_meta" key in the field's metadata dict. The value should be a dict of kwargs that will be passed to ArgumentParser.add_argument(). These metadata kwargs take precedence over the automatically inferred values. @@ -53,13 +71,13 @@ class YourConfig: that require some customized or additional handling. Args: - src_cfg_class: The source dataclass type (not instance) whose fields will be - converted into command-line arguments. Each field's type annotation determines - the argument type, default values become argument defaults, and field-level + src_cfg_class: The source dataclass type (not instance) whose fields will be + converted into command-line arguments. Each field's type annotation determines + the argument type, default values become argument defaults, and field-level docstrings are extracted to populate argument help text. - exclude: Optional list of attribute names from `src_cfg_class` to exclude from + exclude: Optional list of attribute names from `src_cfg_class` to exclude from argument generation. Useful for omitting internal fields, computed properties, - or attributes that should be configured through other means. If None, all + or attributes that should be configured through other means. If None, all dataclass fields will be converted to command-line arguments. Default: None. """ @@ -73,7 +91,7 @@ def _format_arg_name(self, config_attr_name: str, prefix: Optional[str] = None) Args: config_attr_name: dataclass attribute name - prefix: prefix string to add to the dataclass attribute name. e.g. 'no' for bool + prefix: prefix string to add to the dataclass attribute name. e.g. 'no' for bool settings that are default True. A hyphen is added after the prefix. Default: None """ arg_name = config_attr_name @@ -88,6 +106,7 @@ def _get_enum_kwargs(self, config_type: enum.EnumMeta) -> dict[str, Any]: With these settings, the user must provide a valid enum value, e.g. 'flash', for `AttnBackend.flash`. 
""" + def enum_type_handler(cli_arg): return config_type[cli_arg] @@ -111,7 +130,9 @@ def _extract_type(self, config_type: type) -> dict[str, Any]: if origin in [types.UnionType, typing.Union]: # Handle Optional and Union - if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[] + if type_tuple[1] == type( + None + ): # Optional type. First element is value inside Optional[] return self._extract_type(type_tuple[0]) else: raise TypeInferenceError(f"Unions not supported by argparse: {config_type}") @@ -122,17 +143,20 @@ def _extract_type(self, config_type: type) -> dict[str, Any]: kwargs["nargs"] = "+" return kwargs else: - raise TypeInferenceError(f"Multi-type lists not supported by argparse: {config_type}") + raise TypeInferenceError( + f"Multi-type lists not supported by argparse: {config_type}" + ) elif origin is typing.Literal: choices_types = [type(choice) for choice in type_tuple] - assert all([t == choices_types[0] for t in choices_types]), "Type of each choice in a Literal type should all be the same." + assert all( + [t == choices_types[0] for t in choices_types] + ), "Type of each choice in a Literal type should all be the same." kwargs = {"type": choices_types[0], "choices": type_tuple} return kwargs else: raise TypeInferenceError(f"Unsupported type: {config_type}") - def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: """Assemble kwargs for add_argument(). @@ -142,7 +166,9 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: argparse_kwargs = {} argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)] argparse_kwargs["dest"] = attribute.name - argparse_kwargs["help"] = self.field_docstrings[attribute.name] if attribute.name in self.field_docstrings else "" + argparse_kwargs["help"] = ( + self.field_docstrings[attribute.name] if attribute.name in self.field_docstrings else "" + ) # dataclasses specifies that both should not be set if isinstance(attribute.default, type(dataclasses.MISSING)): @@ -156,7 +182,6 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: # save metadata here, but update at the end so the metadata has highest precedence attr_argparse_meta = attribute.metadata["argparse_meta"] - # if we cannot infer the argparse type, all of this logic may fail. 
we try to defer # to the developer-specified metadata if present try: @@ -164,12 +189,17 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' if argparse_kwargs["type"] == bool: - argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" + argparse_kwargs["action"] = ( + "store_true" if attribute.default == False else "store_false" + ) argparse_kwargs.pop("type") # add '--no-*' and '--disable-*' prefix if this is a store_false argument if argparse_kwargs["action"] == "store_false": - argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")] + argparse_kwargs["arg_names"] = [ + self._format_arg_name(attribute.name, prefix="no"), + self._format_arg_name(attribute.name, prefix="disable"), + ] except TypeInferenceError as e: if attr_argparse_meta is not None: print( @@ -181,7 +211,7 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: else: raise e - # metadata provided by field takes precedence + # metadata provided by field takes precedence if attr_argparse_meta is not None: argparse_kwargs.update(attr_argparse_meta) @@ -231,8 +261,12 @@ def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: if a_cond and b_cond: # These should be guaranteed by typechecks above, but assert just in case - assert isinstance(a.target.id, str), "Dataclass attribute not in the expected format. Name is not a string." - assert isinstance(b.value.value, str), "Dataclass attribute docstring is not a string." + assert isinstance( + a.target.id, str + ), "Dataclass attribute not in the expected format. Name is not a string." + assert isinstance( + b.value.value, str + ), "Dataclass attribute docstring is not a string." # Formatting docstring = inspect.cleandoc(b.value.value) @@ -248,3 +282,60 @@ def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: field_docstrings.update(self._get_field_docstrings(base_classes[0])) return field_docstrings + + +def _default_config_from_args(cls: type, args: Namespace, return_instance: bool = True) -> Any: + """Create a config dataclass from the appropriate values in the `args` Namespace. + + This is generic, i.e. it will work if dataclass attribute names map 1-to-1 with + names in `args`. Some classes might require additional logic. 
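+ For example, ``_default_config_from_args(RNGConfig, args)`` copies every ``args`` attribute whose name matches an ``RNGConfig`` field.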
+ """ + kwargs = {} + for f in fields(cls): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + + if return_instance: + return cls(**kwargs) + else: + return kwargs + + +def pretrain_cfg_container_from_args(args: Namespace) -> PretrainConfigContainer: + """Build a PretrainConfigContainer from the argparse arguments.""" + from megatron.training.training import get_megatron_ddp_config, get_megatron_optimizer_config + + ckpt_kwargs = _default_config_from_args(CheckpointConfig, args, return_instance=False) + ckpt_kwargs["save_optim"] = not args.no_save_optim + ckpt_kwargs["save_rng"] = not args.no_save_rng + ckpt_kwargs["load_optim"] = not args.no_load_optim + ckpt_kwargs["load_rng"] = not args.no_load_rng + ckpt_kwargs["fully_parallel_save"] = args.ckpt_fully_parallel_save + ckpt_kwargs["fully_parallel_load"] = args.ckpt_fully_parallel_load + + prof_kwargs = _default_config_from_args(ProfilingConfig, args, return_instance=False) + prof_kwargs["use_nsys_profiler"] = args.profile + + rerunsm_kwargs = _default_config_from_args(RerunStateMachineConfig, args, return_instance=False) + rerunsm_kwargs["check_for_nan_in_loss"] = args.check_for_nan_in_loss_and_grad + + optim_cfg, _ = get_megatron_optimizer_config(args) + ddp_config = get_megatron_ddp_config(args) + + cfg = PretrainConfigContainer( + train=_default_config_from_args(TrainingConfig, args), + validation=_default_config_from_args(ValidationConfig, args), + optimizer=optim_cfg, + scheduler=_default_config_from_args(SchedulerConfig, args), + ddp=ddp_config, + dist=_default_config_from_args(DistributedInitConfig, args), + rng=_default_config_from_args(RNGConfig, args), + logger=_default_config_from_args(LoggerConfig, args), + checkpoint=CheckpointConfig(**ckpt_kwargs), + profiling=ProfilingConfig(**prof_kwargs), + tokenizer=_default_config_from_args(TokenizerConfig, args), + rerun_state_machine=RerunStateMachineConfig(**rerunsm_kwargs), + straggler=_default_config_from_args(StragglerDetectionConfig, args), + ) + + return cfg diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 6a108a0d6d0..7c531b290fb 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1773,6 +1773,19 @@ def validate_args(args, defaults={}): ) args.async_save = False + if not args.async_save: + args.async_strategy = "mcore" + + if args.inference_dynamic_batching_sampling_backend == 'flashinfer': + try: + import flashinfer # noqa: F401 + except ImportError as e: + raise ImportError( + "--inference-dynamic-batching-sampling-backend=flashinfer requires " + "the flashinfer package; install it or pass " + "--inference-dynamic-batching-sampling-backend=torch." + ) from e + # Inference args if args.inference_batch_times_seqlen_threshold > -1: assert ( @@ -2370,6 +2383,15 @@ def _add_inference_args(parser): default=16, help='Number of mixed prefill requests to capture in a cuda graph.', ) + group.add_argument( + '--inference-dynamic-batching-sampling-backend', + type=str, + default='torch', + choices=['torch', 'flashinfer'], + help='Which sampling kernels to use during inference. 
' + 'The "flashinfer" backend requires the flashinfer ' + 'package to be installed.', + ) group.add_argument( '--inference-logging-step-interval', type=int, @@ -2396,6 +2418,16 @@ type=int, help="This port will be used to setup the inference coordinator on node-0", ) + group.add_argument( + '--inference-disable-ep-consensus', + action=argparse.BooleanOptionalAction, + required=False, + default=False, + help='Skip the EP-group consensus all-reduce in the inference engine control loop and ' + 'step on local state only. ' + 'Pause/unpause take effect as soon as the signal is delivered to a rank. ' + 'Only safe when EP coordination is not required (e.g. ep_world_size == 1).', + ) group.add_argument( '--mamba-inference-conv-states-dtype', type=str, @@ -4599,13 +4631,13 @@ def _add_mla_args(parser): '--o-groups', type=int, default=8, - help="Number of groups for grouped output (wo_a). 0 = single linear." + help="Number of groups for grouped output (wo_a). 0 = single linear.", ) group.add_argument( '--o-lora-rank', type=int, default=1024, - help="Low-rank dimension per group for grouped output (wo_a). Used when o-groups > 0." + help="Low-rank dimension per group for grouped output (wo_a). Used when o-groups > 0.", ) group.add_argument( '--cache-mla-latents', @@ -4645,11 +4677,11 @@ def _add_experimental_attention_variant_args(parser): type=compress_ratios_type, default=None, help='Per-layer compress ratios for compressed sparse attention. ' - 'Accepts a string containing a Python list expression, e.g.: ' - '"[0,0,4,128,4,128]" or "([0]+[4,128]*2)*3". ' - 'Each value is the compression ratio for the corresponding ' - 'transformer layer (valid values: 0, 4, 128). ' - 'The list length must equal num_layers.' + 'Accepts a string containing a Python list expression, e.g.: ' + '"[0,0,4,128,4,128]" or "([0]+[4,128]*2)*3". ' + 'Each value is the compression ratio for the corresponding ' + 'transformer layer (valid values: 0, 4, 128). 
' + 'The list length must equal num_layers.', ) return parser diff --git a/megatron/training/async_utils.py b/megatron/training/async_utils.py index c1f75934aa5..091bb1f93cf 100644 --- a/megatron/training/async_utils.py +++ b/megatron/training/async_utils.py @@ -8,14 +8,20 @@ import logging import time from abc import ABC +from typing import TYPE_CHECKING, Any from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest +from megatron.core.dist_checkpointing.strategies.nvrx import make_nvrx_async_request from megatron.core.dist_checkpointing.strategies.torch import get_async_strategy from megatron.training import get_args from megatron.training.utils import print_rank_0 -try: +if TYPE_CHECKING: from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncRequest as NVRxAsyncRequest +else: + NVRxAsyncRequest = Any + +try: from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import _results_queue from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import ( save_state_dict_async_finalize, @@ -26,8 +32,6 @@ save_state_dict_async_finalize, ) - NVRxAsyncRequest = ABC - logger = logging.getLogger(__name__) # Singleton manager of async calls @@ -70,13 +74,13 @@ def init_persistent_async_worker(rank: int, mp_mode: str = 'spawn'): ), ) # initialize the persistent caller with QoS priorities from args - kwargs = {} + warmup_kwargs = {} if async_strategy == "mcore": # Note: nvidia-resiliency-ext uses is_daemon instead of mp_mode (always spawns) - kwargs["mp_mode"] = mp_mode + warmup_kwargs["mp_mode"] = mp_mode elif async_strategy == "nvrx": if "cpu_shm_mode" in inspect.signature(AsyncCallsQueue.warmup_persistent_caller).parameters: - kwargs["cpu_shm_mode"] = args.async_ckpt_use_cpu_shm + warmup_kwargs["cpu_shm_mode"] = args.async_ckpt_use_cpu_shm elif args.async_ckpt_use_cpu_shm: raise AssertionError( "Installed nvidia-resiliency-ext does not support cpu_shm_mode. " @@ -86,10 +90,16 @@ def init_persistent_async_worker(rank: int, mp_mode: str = 'spawn'): rank, cpu_priority=args.async_ckpt_cpu_priority, io_priority=args.async_ckpt_io_priority, - **kwargs, + **warmup_kwargs, ) # initialize ckpt write results queue - get_write_results_queue('fork') + if async_strategy == "nvrx": + if "mp_mode" not in inspect.signature(get_write_results_queue).parameters: + raise AssertionError( + "Installed nvidia-resiliency-ext does not support " + "get_write_results_queue(mp_mode=...). Update nvidia-resiliency-ext." 
+ ) + get_write_results_queue(mp_mode="fork") if rank == 0: print( f"init_persistent_async_worker: rank {rank}, Async Caller Started in {time.time() - time_start} seconds", @@ -161,14 +171,24 @@ def reset_persistent_async_worker(async_strategy): module.clear_metadata_cache() -def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> NVRxAsyncRequest: +def get_save_and_finalize_callbacks( + writer, save_state_dict_ret, async_strategy: str = "nvrx" +) -> AsyncRequest | NVRxAsyncRequest: """Creates an async save request for fsdp_dtensor & torch_dcp with a finalize function.""" save_fn, preload_fn, save_args = writer.get_save_function_and_args() + _, async_modules = get_async_strategy(async_strategy) + async_request_cls = async_modules["AsyncRequest"] + save_state_dict_async_finalize = async_modules["save_state_dict_async_finalize"] def finalize_fn(): """Finalizes async checkpointing and synchronizes processes.""" save_state_dict_async_finalize(*save_state_dict_ret) - return NVRxAsyncRequest( - save_fn, save_args, [finalize_fn], async_fn_kwargs={}, preload_fn=preload_fn + return make_nvrx_async_request( + async_request_cls, + save_fn, + save_args, + [finalize_fn], + async_fn_kwargs={}, + preload_fn=preload_fn, ) diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index ed27e83051e..2b4a3829286 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -33,12 +33,13 @@ from megatron.core.dist_checkpointing.strategies.torch import ( TorchDistLoadShardedStrategy, TorchDistSaveShardedStrategy, + get_async_strategy, ) from megatron.core.msc_utils import MultiStorageClientFeature, open_file from megatron.core.num_microbatches_calculator import update_num_microbatches from megatron.core.optimizer import DistributedOptimizer from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.utils import get_pg_rank, get_pg_size, get_torch_version, is_torch_min_version +from megatron.core.utils import get_pg_rank, get_pg_size from ..core.dist_checkpointing.utils import _clean_metadata_for_serialization from . 
import ft_integration, wandb_utils @@ -74,19 +75,6 @@ has_nvidia_modelopt = False -try: - from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import ( - FileSystemWriterAsync, - ) - from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import ( - save_state_dict_async_plan, - ) - - HAVE_NVRX = True -except (ImportError, ModuleNotFoundError): - - HAVE_NVRX = False - _CHECKPOINT_VERSION = None _LOADED_ITERATION = None @@ -793,6 +781,7 @@ def save_checkpoint( preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn, content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata), async_strategy=args.async_strategy, + verify_integrity=args.verify_integrity, ) # [ModelOpt]: save sharded modelopt_state if has_nvidia_modelopt: @@ -808,6 +797,9 @@ def save_checkpoint( if args.async_save: planner = torch.distributed.checkpoint.DefaultSavePlanner() coordinator_rank = 0 + _, async_modules = get_async_strategy(args.async_strategy) + FileSystemWriterAsync = async_modules["FileSystemWriterAsync"] + save_state_dict_async_plan = async_modules["save_state_dict_async_plan"] _cpu_shm = getattr(args, 'async_ckpt_use_cpu_shm', False) _writer_kwargs = {} if _cpu_shm: @@ -840,6 +832,9 @@ def save_checkpoint( async_save_request = get_save_and_finalize_callbacks( fs_storage_writer, save_state_dict_ret ) + async_save_request = get_save_and_finalize_callbacks( + fs_storage_writer, save_state_dict_ret, args.async_strategy + ) else: fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name) torch.distributed.checkpoint.save( @@ -1464,7 +1459,12 @@ def _load_global_dist_base_checkpoint( if checkpointing_context is not None: checkpointing_context["load_strategy"] = load_strategy state_dict = dist_checkpointing.load( - sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness + sharded_state_dict, + checkpoint_name, + load_strategy, + validate_access_integrity=args.ckpt_load_validate_sharding_integrity, + strict=args.dist_ckpt_strictness, + verify_integrity=args.verify_integrity, ) return state_dict, checkpoint_name, release, CheckpointType.GLOBAL @@ -2309,9 +2309,6 @@ def load_model_state_dict(module, state_dict, strict: bool): f'at iteration {iteration}' ) - if has_nvidia_modelopt: - print_distributed_quant_summary(model, msg="After loading checkpoint") - # Additional callback for wandb (last rank) if not torch.distributed.is_initialized() or is_last_rank(): wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir) @@ -2335,6 +2332,30 @@ def load_model_state_dict(module, state_dict, strict: bool): ) log_printed = True + if has_nvidia_modelopt: + print_distributed_quant_summary(model, msg="After loading checkpoint") + + # Load teacher model in Distillation mode. + if getattr(args, "export_kd_teacher_load", None): + from megatron.post_training.checkpointing import load_modelopt_checkpoint + + unwrapped_model = unwrap_model(model)[0] + # Note: load_modelopt_checkpoint may call this function so we prevent infinite recursion. + if hasattr(unwrapped_model, 'teacher_model'): + teacher = unwrapped_model.teacher_model + print_rank_0( + f"Loading teacher as {type(teacher).__name__} from {args.export_kd_teacher_load} ..." + ) + # [WAR]: To avoid error out on loading teacher's checkpoint, we temporarily + # set args.finetune to True while loading the teacher checkpoint. 
+ original_args_finetune, original_ckpt_format = args.finetune, args.ckpt_format + args.finetune = True + if args.export_kd_teacher_ckpt_format is not None: + args.ckpt_format = args.export_kd_teacher_ckpt_format + load_modelopt_checkpoint([teacher], load_arg='export_kd_teacher_load') + args.finetune, args.ckpt_format = original_args_finetune, original_ckpt_format + print_rank_0("... teacher loaded successfully.") + return iteration, num_floating_point_operations_so_far diff --git a/megatron/training/config/__init__.py b/megatron/training/config/__init__.py index 3d346ddd8fe..46da0025362 100644 --- a/megatron/training/config/__init__.py +++ b/megatron/training/config/__init__.py @@ -1,18 +1,18 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from megatron.training.config.common_config import ( - RNGConfig, - ProfilingConfig, - DistributedInitConfig, +from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig +from megatron.training.config.container import PretrainConfigContainer +from megatron.training.config.instantiate_utils import TargetAllowlist, target_allowlist +from megatron.training.config.resilience_config import ( + FaultInjectorConfig, + RerunStateMachineConfig, + StragglerDetectionConfig, ) from megatron.training.config.training_config import ( + CheckpointConfig, + LoggerConfig, + SchedulerConfig, + TokenizerConfig, TrainingConfig, ValidationConfig, - SchedulerConfig, - LoggerConfig, - CheckpointConfig, -) -from megatron.training.config.resilience_config import ( - RerunStateMachineConfig, - StragglerDetectionConfig, ) diff --git a/megatron/training/config/container.py b/megatron/training/config/container.py new file mode 100644 index 00000000000..efd96302425 --- /dev/null +++ b/megatron/training/config/container.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +import copy +import os +from dataclasses import dataclass, field +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass +from typing import Any, Type, TypeVar + +import yaml +from omegaconf import OmegaConf + +from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig +from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.core.optimizer import OptimizerConfig +from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig +from megatron.training.config.instantiate_utils import InstantiationMode, instantiate +from megatron.training.config.resilience_config import ( + RerunStateMachineConfig, + StragglerDetectionConfig, +) +from megatron.training.config.training_config import ( + CheckpointConfig, + LoggerConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.training.config.utils import sanitize_dataclass_config +from megatron.training.config.yaml_utils import safe_yaml_representers + +T = TypeVar("T", bound="ConfigContainerBase") + + +@dataclass(kw_only=True) +class ConfigContainerBase: + """ + Configuration container base class for Megatron configurations. + + Provides YAML/Dict serialization and deserialization. + """ + + @classmethod + def from_dict( + cls: Type[T], + config_dict: dict[str, Any], + mode: InstantiationMode = InstantiationMode.STRICT, + ) -> T: + """ + Create a config container from a dictionary. 
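+ The dictionary must contain a ``_target_`` key identifying the container class to instantiate.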
+ + Args: + config_dict: Dictionary containing configuration + mode: Serialization mode (strict or lenient) + + Returns: + A new instance of this class initialized with the dictionary values + """ + # Make a copy to avoid modifying the input + config_dict = copy.deepcopy(config_dict) + + assert "_target_" in config_dict + + # Apply backward compatibility: remove init=False fields that may have been + # serialized by older versions (these are computed in __post_init__) + config_dict = sanitize_dataclass_config(config_dict) + + # Check for extra keys in strict mode + expected_fields = {f.name for f in dataclass_fields(cls) if not f.name.startswith("_")} + expected_fields.add("_target_") # Add _target_ as a valid field + extra_keys = set(config_dict.keys()) - expected_fields + + if extra_keys: + if mode == InstantiationMode.STRICT: + raise ValueError( + f"Dictionary contains extra keys not in {cls.__qualname__}: {extra_keys}" + ) + else: + # In lenient mode, remove extra keys + for key in extra_keys: + config_dict.pop(key) + + # Use instantiate to create the object + instance = instantiate(config_dict, mode=mode) + + return instance + + @classmethod + def from_yaml( + cls: Type[T], yaml_path: str, mode: InstantiationMode = InstantiationMode.LENIENT + ) -> T: + """ + Create a config container from a YAML file. + + Args: + yaml_path: Path to the YAML file + mode: Serialization mode (strict or lenient) + + Returns: + A new instance of this class initialized with the YAML file values + """ + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + yaml_path_exists = msc.os.path.exists(yaml_path) + else: + yaml_path_exists = os.path.exists(yaml_path) + + if not yaml_path_exists: + raise FileNotFoundError(f"YAML file not found: {yaml_path}") + + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + with msc.open(yaml_path, "r") as f: + config_dict = yaml.safe_load(f) + else: + with open(yaml_path, "r") as f: + config_dict = yaml.safe_load(f) + + # Convert to OmegaConf first for better compatibility with instantiate + conf = OmegaConf.create(config_dict) + + return cls.from_dict(OmegaConf.to_container(conf, resolve=True), mode=mode) + + def to_dict(self) -> dict[str, Any]: + """ + Convert the config container to a dictionary. + + Also converts any nested dataclasses (both ConfigContainer and regular dataclasses) + to dictionaries recursively. + + Returns: + Dictionary representation of this config + """ + result = {} + result["_target_"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + + for f in dataclass_fields(self): + if f.name.startswith("_"): + continue + + value = getattr(self, f.name) + result[f.name] = self._convert_value_to_dict(value) + + return result + + @classmethod + def _convert_value_to_dict(cls, value: Any) -> Any: + """ + Recursively convert a value to a dictionary representation. 
+ + Handles: + - ConfigContainer instances (using to_dict) + - Serializable instances (using as_dict) + - Classes which implement a to_cfg_dict method + - Regular dataclasses (converting each non-private field) + - Lists and tuples (converting each element) + - Dictionaries (converting each value) + - Other types (kept as-is) + + Args: + value: The value to convert + + Returns: + The converted value + """ + if isinstance(value, ConfigContainerBase): + return value.to_dict() + # elif isinstance(value, Serializable): # TODO (@maanug): re-enable after upstreaming ModelConfig+Serializable + # return value.as_dict() + elif hasattr(value, "to_cfg_dict"): + # Allow non-Container classes to implement own custom method + return value.to_cfg_dict() + elif is_dataclass(value) and not isinstance(value, type): + # Handle regular dataclasses + result = {} + + # Add _target_ field for instantiation + result["_target_"] = f"{value.__class__.__module__}.{value.__class__.__qualname__}" + + # Convert each field, handling nested dataclasses properly + for field in dataclass_fields(value): + if field.name.startswith("_"): + continue + + field_value = getattr(value, field.name) + result[field.name] = cls._convert_value_to_dict(field_value) + + return result + elif isinstance(value, (list, tuple)): + return [cls._convert_value_to_dict(item) for item in value] + elif isinstance(value, dict): + return {k: cls._convert_value_to_dict(v) for k, v in value.items()} + else: + return value + + def to_yaml(self, yaml_path: str) -> None: + """ + Save the config container to a YAML file. + + Args: + yaml_path: Path where to save the YAML file. + """ + config_dict = self.to_dict() + + with safe_yaml_representers(): + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + with msc.open(yaml_path, "w") as f: + yaml.safe_dump(config_dict, f, default_flow_style=False) + else: + with open(yaml_path, "w") as f: + yaml.safe_dump(config_dict, f, default_flow_style=False) + + def print_yaml(self) -> None: + """ + Print the config container to the console in YAML format. 
+ """ + config_dict = self.to_dict() + with safe_yaml_representers(): + print(yaml.safe_dump(config_dict, default_flow_style=False)) + + def __deepcopy__(self, memo): + """Support for deep copying.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + + for f in dataclass_fields(self): + setattr(result, f.name, copy.deepcopy(getattr(self, f.name), memo)) + + return result + + +@dataclass(kw_only=True) +class PretrainConfigContainer(ConfigContainerBase): + """Top-level container holding all configuration objects.""" + + train: TrainingConfig + validation: ValidationConfig = field(default_factory=ValidationConfig) + # model: GPTModelConfig | MambaModelConfig # TODO (@maanug): add support + optimizer: OptimizerConfig + scheduler: SchedulerConfig + # dataset: GPTDatasetConfig # TODO (@maanug): add support + ddp: DistributedDataParallelConfig = field(default_factory=DistributedDataParallelConfig) + dist: DistributedInitConfig = field(default_factory=DistributedInitConfig) + rng: RNGConfig = field(default_factory=RNGConfig) + logger: LoggerConfig + checkpoint: CheckpointConfig + profiling: ProfilingConfig = field(default_factory=ProfilingConfig) + tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) + + rerun_state_machine: RerunStateMachineConfig = field(default_factory=RerunStateMachineConfig) + straggler: StragglerDetectionConfig | None = None diff --git a/megatron/training/config/instantiate_utils.py b/megatron/training/config/instantiate_utils.py new file mode 100644 index 00000000000..0a519e2021e --- /dev/null +++ b/megatron/training/config/instantiate_utils.py @@ -0,0 +1,550 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +import copy +import functools +import inspect +import logging +from enum import Enum +from textwrap import dedent +from typing import Any, Callable, Sequence + +from omegaconf import OmegaConf +from omegaconf._utils import is_structured_config + + +class InstantiationException(Exception): + """Custom exception type for instantiation errors.""" + + ... + + +class InstantiationMode(Enum): + """Enum for instantiation modes.""" + + STRICT = "strict" + LENIENT = "lenient" + + +class _Keys(str, Enum): + """Special keys in configs used by instantiate.""" + + TARGET = "_target_" + PARTIAL = "_partial_" + CALL = "_call_" + ARGS = "_args_" + NAME = "_name_" + + +_DEFAULT_ALLOWED_PREFIXES: tuple[str, ...] = ( + "megatron.training.", + "megatron.core.", + "torch.", + "transformers.", + "signal.", +) + +_DEFAULT_ALLOWED_EXACT: frozenset[str] = frozenset({"functools.partial"}) + + +class TargetAllowlist: + """Controls which ``_target_`` strings are permitted for instantiation. + + Security: prevents arbitrary code execution from untrusted YAML configs + by gating which module paths can be imported and called. 
+ """ + + def __init__(self) -> None: + self._allowed_prefixes: list[str] = list(_DEFAULT_ALLOWED_PREFIXES) + self._allowed_exact: set[str] = set(_DEFAULT_ALLOWED_EXACT) + self._enabled: bool = True + + def is_allowed(self, target: str) -> bool: + """Check whether *target* is permitted by the allowlist.""" + if not self._enabled: + return True + if target in self._allowed_exact: + return True + return any(target.startswith(prefix) for prefix in self._allowed_prefixes) + + def add_prefix(self, prefix: str) -> None: + """Add an allowed module prefix (must end with ``'.'``).""" + if not prefix.endswith("."): + raise ValueError(f"Prefix must end with '.': got '{prefix}'") + if prefix not in self._allowed_prefixes: + self._allowed_prefixes.append(prefix) + + def remove_prefix(self, prefix: str) -> None: + """Remove an allowed module prefix.""" + self._allowed_prefixes.remove(prefix) + + def add_exact(self, target: str) -> None: + """Add an exact target string to the allowlist.""" + self._allowed_exact.add(target) + + def remove_exact(self, target: str) -> None: + """Remove an exact target string from the allowlist.""" + self._allowed_exact.discard(target) + + def disable(self) -> None: + """Disable the allowlist check (allows all targets).""" + logging.warning( + "Target allowlist has been disabled. " "Arbitrary _target_ values will be permitted." + ) + self._enabled = False + + def enable(self) -> None: + """Re-enable the allowlist check.""" + self._enabled = True + + @property + def enabled(self) -> bool: + return self._enabled + + @property + def allowed_prefixes(self) -> tuple[str, ...]: + return tuple(self._allowed_prefixes) + + @property + def allowed_exact(self) -> frozenset[str]: + return frozenset(self._allowed_exact) + + +target_allowlist = TargetAllowlist() + + +def instantiate( + config: Any, *args: Any, mode: InstantiationMode = InstantiationMode.LENIENT, **kwargs: Any +) -> Any: + """Instantiate an object or callable from a config object. + + This function takes a configuration object (dictionary, list, OmegaConf config, + or Structured Config instance) and instantiates the target specified within it. + + The config object must contain: + _target_ (str): The fully qualified name of the class or callable to instantiate. + + The config object may also contain: + _args_ (list): Positional arguments for the target. + _partial_ (bool): If True, return a functools.partial object instead of calling + the target. Defaults to False. + _call_ (bool): If False, simply resolves and returns the target without calling it. + Defaults to True. + Additional keyword arguments to pass to the target. + + Args: + config: The configuration object describing the target and its parameters. + *args: Optional positional arguments that will override _args_ in the config + if provided. + mode: Instantiation mode (STRICT or LENIENT). Controls how config keys that + do not match the target's signature are handled: LENIENT (default) + drops them with a warning, STRICT raises ``InstantiationException``. + Errors resolving a ``_target_`` propagate in both modes. + **kwargs: Optional keyword arguments that will override parameters in the config. + Note: Dataclass instances in kwargs are treated as nested configs. + + Returns: + The instantiated object or the return value of the callable. + If config._partial_ is True, returns a functools.partial object. + If config._call_ is False, returns the resolved target callable/class itself. + Returns None if the input config is None. 
+ + Raises: + InstantiationException: If the config is invalid, the target cannot be resolved, + or instantiation fails in STRICT mode. + TypeError: If the _partial_ flag is not a boolean. + """ + + # Return None if config is None + if config is None: + return None + + if isinstance(config, (dict, list)): + config = _prepare_input_dict_or_list(config) + + kwargs = _prepare_input_dict_or_list(kwargs) + + # Structured Config always converted first to OmegaConf + if is_structured_config(config) or isinstance(config, (dict, list)): + config = OmegaConf.structured(config, flags={"allow_objects": True}) + + if OmegaConf.is_dict(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + if kwargs: + config = OmegaConf.merge(config, kwargs) + + OmegaConf.resolve(config) + + _partial_ = config.pop(_Keys.PARTIAL, False) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + elif OmegaConf.is_list(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + OmegaConf.resolve(config) + + _partial_ = kwargs.pop(_Keys.PARTIAL, False) + + if _partial_: + raise InstantiationException( + "The _partial_ keyword is not compatible with top-level list instantiation" + ) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + else: + raise InstantiationException( + dedent( + f"""\ + Cannot instantiate config of type {type(config).__name__}. + Top level config must be an OmegaConf DictConfig/ListConfig object, + a plain dict/list, or a Structured Config class or instance.""" + ) + ) + + +def instantiate_node( + node: Any, + *args: Any, + partial: bool = False, + mode: InstantiationMode = InstantiationMode.LENIENT, +) -> Any: + """Recursively instantiates a node within a configuration structure. + + This function handles the instantiation of individual nodes (dictionaries, + lists, or primitive values) within a larger configuration tree, typically + managed by OmegaConf. + + If the node is a dictionary containing a `_target_` key, it resolves and + instantiates the target callable/class using the other items in the + dictionary as keyword arguments. Nested nodes are recursively instantiated. + + If the node is a list, it recursively instantiates each item in the list. + + If the node is not an OmegaConf config node (e.g., a primitive type), it's + returned directly. + + Args: + node: The configuration node to instantiate (can be DictConfig, ListConfig, + or a primitive type). + *args: Positional arguments passed down from the top-level `instantiate` call, + used primarily for the final target call if the node is a dictionary + with `_target_`. + partial: Boolean flag indicating whether to return a `functools.partial` object + instead of calling the target. This can be overridden by a + `_partial_` key within the node itself. + mode: Instantiation mode (STRICT or LENIENT). Determines error handling + behavior for nested instantiations. + + Returns: + The instantiated object, list, or the original node if it wasn't a config. + Returns None if the input node is None or represents a None value in OmegaConf. 
+ + Raises: + InstantiationException: If instantiation fails in STRICT mode, or if there are + issues like incompatible arguments or non-callable targets. + TypeError: If a `_partial_` flag within the config is not a boolean. + """ + # Return None if config is None + if node is None or (OmegaConf.is_config(node) and node._is_none()): + return None + + if not OmegaConf.is_config(node): + return node + + if OmegaConf.is_dict(node): + partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial + + full_key = node._get_full_key(None) + + if not isinstance(partial, bool): + msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}" + if node and full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) + + if OmegaConf.is_list(node): + items = [instantiate_node(item, mode=mode) for item in node._iter_ex(resolve=True)] + + return items + elif OmegaConf.is_dict(node): + exclude_keys = set(item.value for item in _Keys if item != _Keys.ARGS) + if _is_target(node): + should_call_target = node.get(_Keys.CALL, True) + _target_ = _resolve_target( + node.get(_Keys.TARGET), full_key, check_callable=should_call_target + ) + kwargs = {} + is_partial = node.get(_Keys.PARTIAL, False) or partial + + if not should_call_target: + if len(set(node.keys()) - {_Keys.TARGET, _Keys.CALL}) != 0: + extra_keys = set(node.keys()) - {_Keys.TARGET, _Keys.CALL} + raise InstantiationException( + f"_call_ was set to False for target {_convert_target_to_string(_target_)}," + f" but extra keys were found: {extra_keys}" + ) + else: + return _target_ + + for key in node.keys(): + if key not in exclude_keys: + if OmegaConf.is_missing(node, key) and is_partial: + continue + value = node[key] + value = instantiate_node(value, mode=mode) + kwargs[key] = _convert_node(value) + + assert callable(_target_) + # Drop unexpected kwargs in lenient mode or raise in strict mode + kwargs = _filter_kwargs_for_target(_target_, kwargs, full_key, mode) + return _call_target(_target_, partial, args, kwargs, full_key) + else: + dict_items = {} + for key, value in node.items(): + dict_items[key] = instantiate_node(value, mode=mode) + return dict_items + + else: + raise InstantiationException(f"Unexpected config type: {type(node).__name__}") + + +def _locate(path: str) -> Any: + """ + Locate an object by name or dotted path, importing as necessary. + This function attempts to import modules starting from the most specific path + (back to front), making it possible to import objects where the final component + could be either a module or an attribute of the previous module. + """ + if path == "": + raise ImportError("Empty path") + from importlib import import_module + + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + assert len(parts) > 0 + + # Try importing from the most specific path first (back to front) + for i in range(len(parts), 0, -1): + module_path = ".".join(parts[:i]) + try: + obj = import_module(module_path) + + # If this isn't the full path, get the remaining attributes + remaining_parts = parts[i:] + for part in remaining_parts: + try: + obj = getattr(obj, part) + except AttributeError as exc_attr: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{module_path}'?" 
+ ) from exc_attr + + # Successfully found the object + return obj + + except ModuleNotFoundError: + # Module not found, try a less specific path + continue + except Exception as exc_import: + # If we hit a different exception, it's likely an issue with the module itself + raise ImportError(f"Error loading '{path}':\n{repr(exc_import)}") from exc_import + + # If we've tried all paths and nothing worked, report failure with the base module + raise ImportError( + f"Error loading '{path}': Unable to import any module in the path. " + f"Are you sure that module '{parts[0]}' is installed?" + ) + + +def _is_target(x: Any) -> bool: + if isinstance(x, dict): + return _Keys.TARGET in x + if OmegaConf.is_dict(x): + return _Keys.TARGET in x + return False + + +def _call_target( + _target_: Callable[..., Any], + _partial_: bool, + args: tuple[Any, ...], + kwargs: dict[str, Any], + full_key: str, +) -> Any: + """Call target (type) with args and kwargs.""" + args, kwargs = _extract_pos_args(args, kwargs) + if _partial_: + try: + return functools.partial(_target_, *args, **kwargs) + except Exception as e: + msg = ( + f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + else: + try: + return _target_(*args, **kwargs) + except Exception as e: + msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + + +def _convert_target_to_string(t: Any) -> Any: + if callable(t): + return f"{t.__module__}.{t.__qualname__}" + else: + return t + + +def _filter_kwargs_for_target( + target: Callable[..., Any] | type, + kwargs: dict[str, Any], + full_key: str, + mode: InstantiationMode, +) -> dict[str, Any]: + """Drop unexpected keyword arguments for a target and warn. + + If the target accepts ``**kwargs`` we forward everything. Otherwise we + inspect the signature and remove keys not present as keyword-capable + parameters, emitting a warning with the dropped keys. + """ + try: + signature = inspect.signature(target) + except (TypeError, ValueError): + # Some builtins or C-extensions may not have an inspectable signature. 
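+ # With no signature to check against, forward the kwargs unchanged rather than guessing.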
+ return kwargs + + parameters = signature.parameters + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()): + return kwargs + + allowed_keys = { + name + for name, param in parameters.items() + if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + } + + unexpected = set(kwargs.keys()) - allowed_keys + if _Keys.ARGS in unexpected: + unexpected.remove(_Keys.ARGS) + + if not unexpected: + return kwargs + + target_str = _convert_target_to_string(target) + if mode == InstantiationMode.LENIENT: + # Warn and drop the unexpected keys + warning_msg = ( + f"Dropping unexpected config keys for target '{target_str}': {sorted(unexpected)}" + ) + if full_key: + warning_msg += f"\nfull_key: {full_key}" + logging.warning(warning_msg) + filtered = {k: v for k, v in kwargs.items() if k in allowed_keys} + if _Keys.ARGS in kwargs: + filtered[_Keys.ARGS] = kwargs[_Keys.ARGS] + return filtered + else: + msg = f"Unexpected config keys for target '{target_str}': {sorted(unexpected)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + + +def _prepare_input_dict_or_list(d: dict[Any, Any] | list[Any]) -> Any: + res: Any + if isinstance(d, dict): + res = {} + for k, v in d.items(): + if k == _Keys.TARGET: + v = _convert_target_to_string(d[_Keys.TARGET]) + elif isinstance(v, (dict, list)): + v = _prepare_input_dict_or_list(v) + res[k] = v + elif isinstance(d, list): + res = [] + for v in d: + if isinstance(v, (list, dict)): + v = _prepare_input_dict_or_list(v) + res.append(v) + else: + raise InstantiationException(f"Expected a dict or list, got {type(d).__name__}") + return res + + +def _resolve_target( + target: str | type | Callable[..., Any], full_key: str, check_callable: bool = True +) -> type | Callable[..., Any] | object: + """Resolve target string, type or callable into type or callable.""" + if isinstance(target, str): + # Security: check allowlist BEFORE importing to prevent + # arbitrary code execution from untrusted _target_ strings. + if not target_allowlist.is_allowed(target): + msg = ( + f"Target '{target}' is not in the allowlist for _target_ instantiation.\n" + f"Allowed module prefixes: {', '.join(target_allowlist.allowed_prefixes)}\n" + f"Allowed exact targets: {', '.join(sorted(target_allowlist.allowed_exact))}\n" + f"To allow this target, call:\n" + f" target_allowlist.add_prefix('{target.rsplit('.', 1)[0] + '.'}')\n" + f" or: target_allowlist.add_exact('{target}')" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + try: + target = _locate(target) + except Exception as e: + msg = f"Error locating target '{target}'." + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + if check_callable and not callable(target): + msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + return target + + +def _extract_pos_args(input_args: Any, kwargs: Any) -> tuple[Any, Any]: + config_args = kwargs.pop(_Keys.ARGS, ()) + output_args = config_args + + if isinstance(config_args, Sequence): + if len(input_args) > 0: + output_args = input_args + else: + raise InstantiationException( + f"Unsupported _args_ type: '{type(config_args).__name__}'. 
value: '{config_args}'" + ) + + return output_args, kwargs + + +def _convert_node(node: Any) -> Any: + if OmegaConf.is_config(node): + node = OmegaConf.to_container(node, resolve=True) + + return node diff --git a/megatron/training/config/resilience_config.py b/megatron/training/config/resilience_config.py index dd0bd716521..be2a56b3bfb 100644 --- a/megatron/training/config/resilience_config.py +++ b/megatron/training/config/resilience_config.py @@ -2,6 +2,9 @@ from dataclasses import dataclass from typing import Literal +from megatron.core.fault_injector import FaultInjectorConfig + + @dataclass(kw_only=True) class RerunStateMachineConfig: """Configuration for the rerun state machine used for result validation or stats.""" @@ -10,7 +13,9 @@ class RerunStateMachineConfig: """Rate at which to inject unexpected results, e.g. 1000 means once every 1000 result validations""" - error_injection_type: Literal["correct_result", "transient_error", "persistent_error"] = "transient_error" + error_injection_type: Literal["correct_result", "transient_error", "persistent_error"] = ( + "transient_error" + ) """Type of error to inject. """ rerun_mode: Literal["disabled", "validate_results", "report_stats"] = "validate_results" @@ -39,4 +44,3 @@ class StragglerDetectionConfig: disable_straggler_on_startup: bool = False """If set, StragglerDetector is disabled on startup.""" - diff --git a/megatron/training/config/training_config.py b/megatron/training/config/training_config.py index c2e30d01f6f..cf45845c222 100644 --- a/megatron/training/config/training_config.py +++ b/megatron/training/config/training_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import signal from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import List, Literal, Optional @dataclass(kw_only=True) @@ -145,10 +145,17 @@ class ValidationConfig: multiple_validation_sets: bool = False """If set, multiple datasets listed in the validation split are evaluated independently with a - separate loss for each dataset in the list. This argument requires that no weights are + separate loss for each dataset in the list. This argument requires that no weights are included in the list. """ + validation_set_names: Optional[List[str]] = None + """Optional list of names for multiple validation sets. When provided with + --multiple-validation-sets, these names are used instead of numeric indices + (e.g. 'validation-wiki' instead of 'validation-0'). The number of names must + match the number of validation datasets. + """ + @dataclass(kw_only=True) class SchedulerConfig: @@ -574,6 +581,11 @@ class CheckpointConfig: ckpt_assume_constant_structure: bool = False """Assume the checkpoint structure is constant across saves to enable optimizations.""" + ckpt_load_validate_sharding_integrity: bool = True + """Whether to validate sharding access integrity when loading a distributed checkpoint. + When True (default), each tensor shard is checked to be accessed exactly once as main + replica by some rank. Disabling skips this validation""" + strict_fsdp_dtensor_load: bool = True """Whether to enforce strict loading for FSDP DTensor checkpoints. 
When False, allows partial loading.""" @@ -619,16 +631,110 @@ class CheckpointConfig: replication_factor: int = 2 """Number of machines storing the replica of a given rank's data.""" + verify_integrity: bool = False + """Whether to hash checkpointing files during save and validate their integrity during load.""" + def __post_init__(self): - from megatron.training.utils import has_nvrx_installed + from megatron.training.utils import has_nvrx_checkpointing_async_support assert self.async_strategy in [ "nvrx", "mcore", ], f"async_strategy {self.async_strategy} is not supported. Available strategies: nvrx, mcore." - if self.async_save and self.ckpt_format in ["torch_dcp", "fsdp_dtensor"]: - assert has_nvrx_installed(), ( - "nvidia-resiliency-ext is not installed. " - "Please, install nvidia-resiliency-ext to enable async save." + if not self.async_save: + self.async_strategy = "mcore" + + if ( + self.async_save + and self.async_strategy == "nvrx" + and self.ckpt_format in ["torch_dcp", "fsdp_dtensor"] + ): + assert has_nvrx_checkpointing_async_support(), ( + "A compatible nvidia-resiliency-ext installation is required to enable " + "async save with async_strategy='nvrx'." ) + + if self.verify_integrity: + assert ( + self.ckpt_format == "torch_dist" + ), f"`verify_integrity` is only supported with torch_dist checkpoint format." + + +@dataclass(kw_only=True) +class TokenizerConfig: + """Configuration settings for the tokenizers.""" + + vocab_size: int = None + """Size of vocab before EOD or padding.""" + + padded_vocab_size: int = None + """Vocabulary size of the model (padded to be divisible by tensor model parallel size). + If not provided, it will be automatically calculated from vocab-size.""" + + vocab_file: str = None + """Path to the vocab file.""" + + merge_file: str = None + """Path to the BPE merge file.""" + + vocab_extra_ids: int = 0 + """Number of additional vocabulary tokens. They are used for span masking in the T5 model.""" + + tokenizer_type: Literal[ + "BertWordPieceLowerCase", + "BertWordPieceCase", + "GPT2BPETokenizer", + "SentencePieceTokenizer", + "GPTSentencePieceTokenizer", + "HuggingFaceTokenizer", + "Llama2Tokenizer", + "TikTokenizer", + "MultimodalTokenizer", + "NullTokenizer", + "NullMultimodalTokenizer", + "SFTTokenizer", + ] = None + """What type of tokenizer to use.""" + + tokenizer_model: str = None + """Path to the tokenizer model.""" + + metadata_path: str | None = field( + default=None, metadata={"argparse_meta": {"arg_names": ["--tokenizer-metadata"]}} + ) + """Path to the tokenizer metadata file in json format.""" + + special_tokens: Optional[list[str]] = field( + default=None, metadata={"argparse_meta": {"arg_names": ["--tokenizer-special-tokens"]}} + ) + """List of special tokens. For TikTokenizer needs to have + ["", "", "", "", "", "", ""]""" + + tiktoken_pattern: Literal["v1", "v2"] = None + """Which tiktoken pattern to use. Options: [v1, v2]""" + + tiktoken_num_special_tokens: int = 1000 + """Number of special tokens in tiktoken tokenizer.""" + + tokenizer_sentencepiece_legacy: bool = False + """SentencePiece tokenizer wrapper legacy behavior. 
Allows special tokens usage.""" + + tokenizer_hf_no_use_fast: bool = False + """Whether to use fast HuggingFace tokenizer.""" + + tokenizer_hf_no_include_special_tokens: bool = False + """Converting text to ids will not include special for HuggingFace tokenizer.""" + + trust_remote_code: bool = False + """Whether or not to allow PreTrainedTokenizer to execute remote code.""" + + null_tokenizer_eod_id: int = None + """EOD token id for NullTokenizer. Defaults to `vocab_size - 1`.""" + + null_tokenizer_pad_id: int = -1 + """Pad token id for NullTokenizer. Defaults to -1 (no pad token). + Set to a value outside the dataset to avoid masking real tokens.""" + + chat_template: Optional[str] = None + """Custom chat template in jinja format for conversation formatting.""" diff --git a/megatron/training/config/utils.py b/megatron/training/config/utils.py new file mode 100644 index 00000000000..28234534230 --- /dev/null +++ b/megatron/training/config/utils.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import importlib +import logging +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass +from functools import lru_cache +from typing import Any + +logger = logging.getLogger(__name__) + + +def sanitize_dataclass_config( + config: dict[str, Any], _visited: set | None = None +) -> dict[str, Any]: + """Remove init=False fields from a dataclass config dict for backward compatibility. + + This function automatically detects fields with init=False by inspecting the + target class specified in the config's _target_ field. This handles cases where + older checkpoints serialized computed fields that should not be passed to __init__. + + The function recursively processes nested dicts that may also be dataclass configs. + + Args: + config: A configuration dictionary, potentially with a _target_ field. + _visited: Internal set to track visited objects and prevent infinite recursion. + + Returns: + The sanitized configuration with init=False fields removed. + """ + if not isinstance(config, dict): + return config + + if _visited is None: + _visited = set() + config_id = id(config) + if config_id in _visited: + return config + _visited.add(config_id) + + target = config.get("_target_") + init_false_fields: frozenset[str] = frozenset() + + if isinstance(target, str): + target_class = _resolve_target_class(target) + if target_class is not None: + init_false_fields = _get_init_false_fields(target_class) + + # Process all values, filtering init=False fields and recursing into nested dicts + sanitized = {} + for key, value in config.items(): + if key in init_false_fields: + if target_class is not None: + logger.debug( + f"Removing init=False field '{key}' from {target_class.__name__} config for backward compatibility" + ) + continue + + # Recursively sanitize nested dicts (which may be nested dataclass configs) + if isinstance(value, dict): + value = sanitize_dataclass_config(value, _visited) + elif isinstance(value, list): + value = [ + sanitize_dataclass_config(item, _visited) if isinstance(item, dict) else item + for item in value + ] + + sanitized[key] = value + + return sanitized + + +def _resolve_target_class(target: str) -> type | None: + """Resolve a _target_ string to a class. + + Args: + target: A fully qualified class path (e.g., "module.submodule.ClassName"). + + Returns: + The resolved class, or None if resolution fails. 
+ """ + from megatron.training.config.instantiate_utils import target_allowlist + + if not target_allowlist.is_allowed(target): + logger.warning(f"Target '{target}' is not in the allowlist. Skipping resolution.") + return None + + try: + module_path, class_name = target.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name, None) + except (ValueError, ImportError, AttributeError) as e: + logger.warning(f"Could not resolve target '{target}': {e}") + return None + + +@lru_cache(maxsize=128) +def _get_init_false_fields(target_class: type) -> frozenset[str]: + """Get the set of field names with init=False for a dataclass. + + Args: + target_class: A dataclass type to inspect. + + Returns: + A frozenset of field names that have init=False. + """ + if not is_dataclass(target_class): + return frozenset() + + return frozenset(f.name for f in dataclass_fields(target_class) if not f.init) diff --git a/megatron/training/config/yaml_utils.py b/megatron/training/config/yaml_utils.py new file mode 100644 index 00000000000..8af2e17d388 --- /dev/null +++ b/megatron/training/config/yaml_utils.py @@ -0,0 +1,179 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import enum +import functools +import inspect +from contextlib import contextmanager +from typing import Generator + +import yaml + + +@contextmanager +def safe_yaml_representers() -> Generator[None, None, None]: + """ + Context manager for safely adding and removing custom YAML representers. + + Temporarily adds custom representers for functions, classes, and other objects + to the YAML SafeDumper, and restores the original representers when exiting + the context. + + Usage: + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(my_complex_object) + """ + # Save original representers + original_representers = yaml.SafeDumper.yaml_representers.copy() + original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() + + try: + # Register custom representers + + # Partial representer + yaml.SafeDumper.add_representer(functools.partial, _partial_representer) + + # Enum representer + yaml.SafeDumper.add_multi_representer(enum.Enum, _enum_representer) + + # Function representer + yaml.SafeDumper.add_representer(type(lambda: ...), _function_representer) + yaml.SafeDumper.add_representer(type(object), _function_representer) + + # Try to add torch dtype representer if available + try: + import torch + + yaml.SafeDumper.add_representer(torch.dtype, _torch_dtype_representer) + except ModuleNotFoundError: + pass + + # Try to add GenerationConfig representer if available + try: + from transformers import GenerationConfig + + yaml.SafeDumper.add_representer(GenerationConfig, _generation_config_representer) + except ModuleNotFoundError: + pass + + # Try to add PretrainedConfig representer if available (generic for HF configs) + try: + from transformers import PretrainedConfig + + # Use multi-representer so subclasses of PretrainedConfig are also handled + yaml.SafeDumper.add_multi_representer(PretrainedConfig, _pretrained_config_representer) + except ModuleNotFoundError: + pass + + # General object representer + yaml.SafeDumper.add_multi_representer(object, _safe_object_representer) + + yield + finally: + # Restore original representers + yaml.SafeDumper.yaml_representers = original_representers + yaml.SafeDumper.yaml_multi_representers = original_multi_representers + + +def _function_representer(dumper, data): + """Represent functions in YAML.""" + value = { + "_target_": 
f"{inspect.getmodule(data).__name__}.{data.__qualname__}", # type: ignore + "_call_": False, + } + return dumper.represent_data(value) + + +def _torch_dtype_representer(dumper, data): + """Represent torch dtypes in YAML.""" + value = {"_target_": str(data), "_call_": False} + return dumper.represent_data(value) + + +def _safe_object_representer(dumper, data): + """ + General object representer for YAML. + + This function is a fallback for objects that don't have specific representers. + If the object has __qualname__ attr, + the _target_ is set to f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}". + If the object does not have a __qualname__ attr, the _target_ is set from its __class__ attr. + The _call_ key is used to indicate whether the target should be called to create an instance. + + Args: + dumper (yaml.Dumper): The YAML dumper to use for serialization. + data (Any): The data to serialize. + + Returns: + The YAML representation of the data. + """ + try: + obj = data + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = False + except AttributeError: + obj = data.__class__ + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = True + + value = {"_target_": target, "_call_": call} # type: ignore + return dumper.represent_data(value) + + +def _partial_representer(dumper, data): + """Represent functools.partial objects in YAML.""" + # Get the underlying function + func = data.func + + # Create a dictionary representation + value = { + "_target_": f"{inspect.getmodule(func).__name__}.{func.__qualname__}", + "_partial_": True, + "_args_": list(data.args) if data.args else [], + } + + # Add keyword arguments if any exist + if data.keywords: + for k, v in data.keywords.items(): + value[k] = v + + return dumper.represent_data(value) + + +def _enum_representer(dumper, data): + """Represent enums in YAML.""" + # Create a dictionary representation + enum_class = data.__class__ + value = { + "_target_": f"{inspect.getmodule(enum_class).__name__}.{enum_class.__qualname__}", + "_call_": True, + "_args_": [data.value], + "_name_": data.name, + } + + return dumper.represent_data(value) + + +def _generation_config_representer(dumper, data): + """Represent transformers GenerationConfig objects in YAML.""" + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + + return dumper.represent_data(value) + + +def _pretrained_config_representer(dumper, data): + """Represent transformers PretrainedConfig objects in YAML generically. + + Uses the class's from_dict/to_dict methods to ensure full round-trip of all fields. 
+ """ + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + return dumper.represent_data(value) diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index ac16e1db01b..e08667bdd91 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -110,7 +110,7 @@ def _graceful_shutdown(signum, frame): # synchronize all ranks before exiting try: # avoid deadlock if ranks don't all reach here - torch.distributed.barrier(timeout=timedelta(seconds=5)) + torch.distributed.barrier() except Exception: pass diff --git a/megatron/training/inprocess_restart.py b/megatron/training/inprocess_restart.py index fdfc7fd3cbe..46e51d0e421 100644 --- a/megatron/training/inprocess_restart.py +++ b/megatron/training/inprocess_restart.py @@ -1,15 +1,10 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import importlib import os import socket -from datetime import timedelta - -try: - import nvidia_resiliency_ext.inprocess as inprocess -except ImportError: - inprocess = None - import warnings +from datetime import timedelta import torch @@ -20,6 +15,13 @@ from . import arguments +def _get_inprocess_module(): + try: + return importlib.import_module("nvidia_resiliency_ext.inprocess") + except ImportError: + return None + + def destroy_state(): from . import training @@ -28,6 +30,7 @@ def destroy_state(): def inprocess_restart(train, args): + inprocess = _get_inprocess_module() if inprocess is None: warnings.warn('In-process restart is not available') return train diff --git a/megatron/training/training.py b/megatron/training/training.py index 8b4247be3b0..a388d8543d2 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1786,6 +1786,31 @@ def get_optimizer_param_scheduler(optimizer): return opt_param_scheduler +def get_megatron_ddp_config(args: Any) -> DistributedDataParallelConfig: + """Return an MCore DDPConfig from the argparse arguments.""" + + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs["grad_reduce_in_fp32"] = args.accumulate_allreduce_grads_in_fp32 + kwargs["check_for_nan_in_grad"] = args.check_for_nan_in_loss_and_grad + kwargs["check_for_large_grads"] = args.check_for_large_grads + kwargs["pad_buckets_for_high_nccl_busbw"] = args.ddp_pad_buckets_for_high_nccl_busbw + kwargs["reduce_scatter_with_fp32_accumulation"] = ( + args.ddp_reduce_scatter_with_fp32_accumulation + ) + kwargs["param_name_patterns_for_fp32_local_accumulation"] = tuple( + args.ddp_param_name_patterns_for_fp32_local_accumulation + ) + kwargs["average_in_collective"] = args.ddp_average_in_collective + kwargs["megatron_fsdp_main_params_dtype"] = args.megatron_fsdp_main_params_dtype + kwargs["megatron_fsdp_main_grads_dtype"] = args.megatron_fsdp_main_grads_dtype + kwargs["megatron_fsdp_grad_comm_dtype"] = args.megatron_fsdp_grad_comm_dtype + kwargs["megatron_fsdp_use_decoupled_grad"] = args.use_precision_aware_optimizer + return DistributedDataParallelConfig(**kwargs) + + def get_megatron_optimizer_config(args: Any) -> OptimizerConfig: """Return a Megatron optimizer config object from Megatron's arguments.""" diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 6581193a067..7dd589cb870 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -772,3 +772,10 @@ def has_nvrx_installed(): 
return True except (ImportError, ModuleNotFoundError): return False + + +def has_nvrx_checkpointing_async_support(): + """Whether the installed NVRx package exposes the async checkpointing API Megatron uses.""" + from megatron.core.dist_checkpointing.strategies.nvrx import has_nvrx_async_support + + return has_nvrx_async_support() diff --git a/model_provider.py b/model_provider.py index 0c80c54dfdb..919f55af71a 100644 --- a/model_provider.py +++ b/model_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. """Common functions used in train_*.py and pretrain_*.py scripts.""" @@ -7,34 +7,36 @@ import torch from megatron.core.models.gpt import GPTModel -from megatron.core.models.mamba import MambaModel +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.training import get_args, print_rank_0 try: - from megatron.post_training.model_builder import modelopt_gpt_mamba_builder + from megatron.post_training.model_builder import modelopt_gpt_hybrid_builder + has_nvidia_modelopt = True except ImportError: has_nvidia_modelopt = False -import megatron.legacy.model # isort: skip - -# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import - def model_provider( - model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None, config=None, pg_collection=None, -) -> Union[GPTModel, megatron.legacy.model.GPTModel, MambaModel]: + model_builder: Callable, + pre_process=True, + post_process=True, + vp_stage: Optional[int] = None, + config=None, + pg_collection=None, +) -> Union[GPTModel, HybridModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. Args: - model_builder: A callable that builds the actual model, its signature is the same as model_provider's with an exception of the first argument which is a builder itself. In addition might take a config passed from outside to skip its own config loading. See gpt_builder or mamba_builder for an example, see _gpt_model_builder in train_rl.py to see how to augment a default gpt builder and pass the config from outside + model_builder: A callable that builds the actual model, its signature is the same as model_provider's with an exception of the first argument which is a builder itself. In addition might take a config passed from outside to skip its own config loading. See gpt_builder or hybrid_builder for an example, see _gpt_model_builder in train_rl.py to see how to augment a default gpt builder and pass the config from outside pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. post_process (bool, optional): Set to true if you need to compute output logits/loss. Defaults to True. 
Returns: - Union[GPTModel, megatron.legacy.model.GPTModel, MambaModel]: The returned model + Union[GPTModel, HybridModel]: The returned model """ args = get_args() @@ -58,9 +60,11 @@ def oom_observer(device, alloc, device_alloc, device_free): if has_nvidia_modelopt and getattr(args, 'modelopt_enabled', False): # [ModelOpt]: Use custom builder + spec when modelopt is enabled - model_builder = modelopt_gpt_mamba_builder + model_builder = modelopt_gpt_hybrid_builder - return model_builder(args, pre_process, post_process, vp_stage, config=config, pg_collection=pg_collection) + return model_builder( + args, pre_process, post_process, vp_stage, config=config, pg_collection=pg_collection + ) def count_parameters_in_layer(model, layer_name): diff --git a/pretrain_bert.py b/pretrain_bert.py index 65b267c46e1..4dd6160f795 100644 --- a/pretrain_bert.py +++ b/pretrain_bert.py @@ -23,6 +23,7 @@ from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer from megatron.core.transformer.spec_utils import import_module from megatron.training import get_args, get_timers, pretrain, print_rank_0 +from megatron.training.argument_utils import pretrain_cfg_container_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args from megatron.training.utils import average_losses_across_data_parallel_group @@ -178,8 +179,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None # Temporary for transition to core datasets train_valid_test_datasets_provider.is_distributed = True - parse_and_validate_args(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + args = parse_and_validate_args(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + full_config = pretrain_cfg_container_from_args(args) pretrain( + full_config, train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, diff --git a/pretrain_hybrid.py b/pretrain_hybrid.py new file mode 100644 index 00000000000..908a1d419cc --- /dev/null +++ b/pretrain_hybrid.py @@ -0,0 +1,367 @@ +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. +"""Pretrain and SFT Hybrid.""" + +# Capture the true program start time BEFORE any heavy imports. +import time + +_PROGRAM_START_TIME = time.time() + +import json + +# Suppress warnings on all ranks but rank 0. 
+import os +import warnings + +rank = int(os.environ.get('RANK', 0)) +if rank != 0: + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=FutureWarning) + +from functools import partial +from typing import Any, List, Optional, Tuple + +import torch + +from hybrid_builders import hybrid_builder +from megatron.core import mpu +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.enums import ModelType +from megatron.core.models.hybrid.hybrid_model import HybridModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer +from megatron.core.utils import StragglerDetector, get_attr_wrapped_model, is_te_min_version +from megatron.training import ( + get_args, + get_timers, + inprocess_restart, + pretrain, + print_rank_0, + set_startup_timestamps, +) +from megatron.training.argument_utils import pretrain_cfg_container_from_args +from megatron.training.arguments import parse_and_validate_args +from megatron.training.datasets.sft_dataset import SFTDataset +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, + is_first_or_last_pipeline_stage, +) +from model_provider import model_provider + +try: + from megatron.post_training.arguments import add_modelopt_args + from megatron.post_training.loss_func import loss_func as loss_func_modelopt + + has_nvidia_modelopt = True +except ImportError: + has_nvidia_modelopt = False + +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None + +stimer = StragglerDetector() + + +def get_batch(data_iterator, vp_stage=None): + """Generate a batch.""" + + empty_batch = { + 'tokens': None, + 'labels': None, + 'loss_mask': None, + 'attention_mask': None, + 'position_ids': None, + 'cu_seqlens': None, + 'max_seqlen': None, + } + + # TODO(duncan): Is there a more efficient way to access is_packed_sequence here? + is_packed_sequence = get_args().sft # SFT always uses packed sequence + if not is_first_or_last_pipeline_stage(vp_stage) and not is_packed_sequence: + return empty_batch.values() + + batch = get_batch_on_this_tp_rank(data_iterator) + + cu_seqlens = batch['cu_seqlens'] + # Unused at the moment + cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) + # Support for Hybrid Context Parallel (Unused in this script) + local_cp_size = batch.pop('local_cp_size', None) + + if cu_seqlens is not None: + assert ( + cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 + ), "micro-batch-size must be 1 for packing" + cu_seqlens = cu_seqlens[0] + batch['cu_seqlens'] = cu_seqlens + + max_seqlen = batch['max_seqlen'] + assert max_seqlen.dim() == 1 + # TODO(duncan): can this be kept as a 0-D tensor? 
+ batch['max_seqlen'] = int(max_seqlen[0].item()) + + if mpu.is_pipeline_first_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): + total_tokens = batch['tokens'].size(1) + elif mpu.is_pipeline_last_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): + total_tokens = batch['labels'].size(1) + else: # packed sequence + empty_batch['cu_seqlens'] = cu_seqlens + empty_batch['max_seqlen'] = max_seqlen + return empty_batch.values() + + if cu_seqlens is None: + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore + else: # Packed THD format + cp_size = get_context_parallel_world_size() + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + cp_rank = get_context_parallel_rank() + index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) + for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: + continue + if data is not None: + # On first PP rank, labels and loss_mask can be None. + # On last PP rank, tokens and position_ids can be None. + batch[key] = data.index_select(1, index) + + return batch.values() + + +# define spiky loss as a loss that's 10x the max loss observed +SPIKY_LOSS_FACTOR = 10 + + +def loss_func( + loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[HybridModel] = None +): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + if has_nvidia_modelopt and getattr(args, 'modelopt_enabled', False): # [ModelOpt] + loss, num_tokens, report = loss_func_modelopt(loss_mask, output_tensor, model=model) + else: + losses = output_tensor.view(-1).float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses * loss_mask) + + num_tokens = loss_mask.sum().clone().detach().to(torch.int) + report = {'lm loss': torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])} + + # Check individual rank losses are not NaN prior to DP all-reduce. + rerun_state_machine = get_rerun_state_machine() + if args.check_for_nan_in_loss_and_grad: + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isnan, + message="found NaN in local forward loss calculation", + tolerance=0.0, # forward pass calculations are deterministic + fatal=True, + ) + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isinf, + message="found Inf in local forward loss calculation", + tolerance=0.0, # forward pass calculations are deterministic + fatal=True, + ) + # Check for spiky loss + if args.check_for_spiky_loss: + rerun_state_machine.validate_result( + result=loss, + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, + threshold=SPIKY_LOSS_FACTOR, + context="loss", + ), + message="Spiky loss", + tolerance=0.0, # forward pass calculations are deterministic + fatal=False, + ) + + return loss, num_tokens, report + + +def forward_step(data_iterator, model: HybridModel): + """Forward training step. 
+ + Args: + data_iterator : Input data iterator + model (HybridModel): The Model + """ + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + + global stimer + + with stimer(bdata=True): + vp_stage = get_attr_wrapped_model(model, "vp_stage") + (tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, max_seqlen) = ( + get_batch(data_iterator, vp_stage) + ) + + if cu_seqlens is None: + packed_seq_params = None + else: + total_tokens = tokens.size(1) if tokens is not None else labels.size(1) + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + total_tokens=total_tokens, + ) + + timers('batch-generator').stop() + + with stimer: + output_tensor = model( + tokens, + position_ids, + attention_mask, + labels=labels, + packed_seq_params=packed_seq_params, + loss_mask=loss_mask, + ) + + # [ModelOpt]: model is needed to access ModelOpt distillation losses + return output_tensor, partial(loss_func, loss_mask, model=model) + + +def is_dataset_built_on_rank(vp_stage=None, is_packed_sequence=False): + if mpu.get_tensor_model_parallel_rank() != 0: + return False + elif is_packed_sequence: + return True + else: + return is_first_or_last_pipeline_stage(vp_stage) + + +def core_gpt_dataset_config_from_args(args: Any) -> GPTDatasetConfig: + tokenizer = build_tokenizer(args) + + # Sometimes --data-path is too long, instead we parse it from a file. + blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + sequences_per_dataset = None + if args.per_dataset_sequences_path is not None: + with open(args.per_dataset_sequences_path, "r") as f: + sequences_per_dataset = json.load(f) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split=args.split, + multiple_validation_sets=args.multiple_validation_sets, + full_validation=args.full_validation, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + object_storage_cache_path=args.object_storage_cache_path, + mid_level_dataset_surplus=args.mid_level_dataset_surplus, + allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, + fast_cache_load=args.dataloader_fast_cache_load, + sequences_per_dataset=sequences_per_dataset, + defer_npy_index_mmap=args.dataloader_defer_npy_index_mmap, + context_parallel_size=args.context_parallel_size, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. 
+ """ + args = get_args() + config = core_gpt_dataset_config_from_args(args) + + is_packed_sequence = False + if args.sft: + dataset_type = SFTDataset + is_packed_sequence = True # SFT always uses packed sequence + else: + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + partial(is_dataset_built_on_rank, vp_stage=vp_stage, is_packed_sequence=is_packed_sequence), + config, + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + # Timestamp right after entering __main__ block (after all imports/library setup) + _MAIN_ENTRY_TIME = time.time() + + # Register startup timestamps for timing report in pretrain() + set_startup_timestamps(program_start=_PROGRAM_START_TIME, main_entry=_MAIN_ENTRY_TIME) + + # Temporary for transition to core datasets + setattr(train_valid_test_datasets_provider, "is_distributed", True) + + # Optionally enable inprocess restart on pretrain + pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + args = parse_and_validate_args( + extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + ) + full_config = pretrain_cfg_container_from_args(args) + pretrain( + full_config, + train_valid_test_datasets_provider, + partial(model_provider, hybrid_builder), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) diff --git a/pretrain_mamba.py b/pretrain_mamba.py index df709883a66..7eb7f461cab 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -1,362 +1,19 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -"""Pretrain and SFT Mamba.""" +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. +"""Backward-compatible wrapper for pretrain_hybrid.py. -# Capture the true program start time BEFORE any heavy imports. -import time - -_PROGRAM_START_TIME = time.time() - -import json - -# Suppress warnings on all ranks but rank 0. +Deprecated. Use pretrain_hybrid.py instead. +""" import os +import runpy import warnings -rank = int(os.environ.get('RANK', 0)) -if rank != 0: - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - -from functools import partial -from typing import List, Optional, Tuple - -import torch - -from mamba_builders import mamba_builder -from megatron.core import mpu -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset -from megatron.core.enums import ModelType -from megatron.core.models.mamba import MambaModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size -from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer -from megatron.core.utils import StragglerDetector, get_attr_wrapped_model, is_te_min_version -from megatron.training import ( - get_args, - get_timers, - inprocess_restart, - pretrain, - print_rank_0, - set_startup_timestamps, +warnings.warn( + "pretrain_mamba.py has been deprecated. 
Use pretrain_hybrid.py instead.", + DeprecationWarning, + stacklevel=2, ) -from megatron.training.arguments import parse_and_validate_args -from megatron.training.datasets.sft_dataset import SFTDataset -from megatron.training.utils import ( - get_batch_on_this_cp_rank, - get_batch_on_this_tp_rank, - get_blend_and_blend_per_split, - is_first_or_last_pipeline_stage, -) -from model_provider import model_provider - -try: - from megatron.post_training.arguments import add_modelopt_args - from megatron.post_training.loss_func import loss_func as loss_func_modelopt - - has_nvidia_modelopt = True -except ImportError: - has_nvidia_modelopt = False - -try: - # Register the TE CUDA kernels - import transformer_engine # pylint: disable=unused-import - - # Alias the PyTorch wrapper so we can call tex.* APIs - import transformer_engine_torch as tex -except ImportError: - # TE isn’t installed or the torch wrapper is missing - tex = None - -stimer = StragglerDetector() - - -def get_batch(data_iterator, vp_stage=None): - """Generate a batch.""" - - empty_batch = { - 'tokens': None, - 'labels': None, - 'loss_mask': None, - 'attention_mask': None, - 'position_ids': None, - 'cu_seqlens': None, - 'max_seqlen': None, - } - - # TODO(duncan): Is there a more efficient way to access is_packed_sequence here? - is_packed_sequence = get_args().sft # SFT always uses packed sequence - if not is_first_or_last_pipeline_stage(vp_stage) and not is_packed_sequence: - return empty_batch.values() - - batch = get_batch_on_this_tp_rank(data_iterator) - - cu_seqlens = batch['cu_seqlens'] - # Unused at the moment - cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) - # Support for Dynamic Context Parallel (Unused in this script) - local_cp_size = batch.pop('local_cp_size', None) - - if cu_seqlens is not None: - assert ( - cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1 - ), "micro-batch-size must be 1 for packing" - cu_seqlens = cu_seqlens[0] - batch['cu_seqlens'] = cu_seqlens - - max_seqlen = batch['max_seqlen'] - assert max_seqlen.dim() == 1 - # TODO(duncan): can this be kept as a 0-D tensor? - batch['max_seqlen'] = int(max_seqlen[0].item()) - - if mpu.is_pipeline_first_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): - total_tokens = batch['tokens'].size(1) - elif mpu.is_pipeline_last_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): - total_tokens = batch['labels'].size(1) - else: # packed sequence - empty_batch['cu_seqlens'] = cu_seqlens - empty_batch['max_seqlen'] = max_seqlen - return empty_batch.values() - - if cu_seqlens is None: - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore - else: # Packed THD format - cp_size = get_context_parallel_world_size() - if cp_size > 1: # slice batch along sequence dimension for context parallelism - assert tex is not None and is_te_min_version("1.10.0"), ( - "Please update Transformer Engine to >= 1.10 to use " - "Context Parallel with THD format data" - ) - cp_rank = get_context_parallel_rank() - index = tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, cp_rank) - for key, data in batch.items(): - if key in {'attention_mask', 'cu_seqlens', 'max_seqlen'}: - continue - if data is not None: - # On first PP rank, labels and loss_mask can be None. - # On last PP rank, tokens and position_ids can be None. 
- batch[key] = data.index_select(1, index) - - return batch.values() - - -# define spiky loss as a loss that's 10x the max loss observed -SPIKY_LOSS_FACTOR = 10 - - -def loss_func( - loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[MambaModel] = None -): - """Loss function. - - Args: - loss_mask (torch.Tensor): Used to mask out some portions of the loss - output_tensor (torch.Tensor): The tensor with the losses - - Returns: - the loss scalar for this micro-batch - the number of non-padded tokens in this microbatch - a dict containing reporting metrics on the loss and number of tokens across - the data parallel ranks - """ - args = get_args() - if has_nvidia_modelopt and getattr(args, 'modelopt_enabled', False): # [ModelOpt] - loss, num_tokens, report = loss_func_modelopt(loss_mask, output_tensor, model=model) - else: - losses = output_tensor.view(-1).float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses * loss_mask) - - num_tokens = loss_mask.sum().clone().detach().to(torch.int) - report = {'lm loss': torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])} - - # Check individual rank losses are not NaN prior to DP all-reduce. - rerun_state_machine = get_rerun_state_machine() - if args.check_for_nan_in_loss_and_grad: - rerun_state_machine.validate_result( - result=loss, - rejection_func=torch.isnan, - message="found NaN in local forward loss calculation", - tolerance=0.0, # forward pass calculations are deterministic - fatal=True, - ) - rerun_state_machine.validate_result( - result=loss, - rejection_func=torch.isinf, - message="found Inf in local forward loss calculation", - tolerance=0.0, # forward pass calculations are deterministic - fatal=True, - ) - # Check for spiky loss - if args.check_for_spiky_loss: - rerun_state_machine.validate_result( - result=loss, - rejection_func=partial( - rerun_state_machine.is_unexpectedly_large, - threshold=SPIKY_LOSS_FACTOR, - context="loss", - ), - message="Spiky loss", - tolerance=0.0, # forward pass calculations are deterministic - fatal=False, - ) - - return loss, num_tokens, report - - -def forward_step(data_iterator, model: MambaModel): - """Forward training step. - - Args: - data_iterator : Input data iterator - model (MambaModel): The GPT Model - """ - timers = get_timers() - - # Get the batch. 
- timers('batch-generator', log_level=2).start() - - global stimer - - with stimer(bdata=True): - vp_stage = get_attr_wrapped_model(model, "vp_stage") - (tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, max_seqlen) = ( - get_batch(data_iterator, vp_stage) - ) - - if cu_seqlens is None: - packed_seq_params = None - else: - total_tokens = tokens.size(1) if tokens is not None else labels.size(1) - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - total_tokens=total_tokens, - ) - - timers('batch-generator').stop() - - with stimer: - output_tensor = model( - tokens, - position_ids, - attention_mask, - labels=labels, - packed_seq_params=packed_seq_params, - loss_mask=loss_mask, - ) - - # [ModelOpt]: model is needed to access ModelOpt distillation losses - return output_tensor, partial(loss_func, loss_mask, model=model) - - -def is_dataset_built_on_rank(vp_stage=None, is_packed_sequence=False): - if mpu.get_tensor_model_parallel_rank() != 0: - return False - elif is_packed_sequence: - return True - else: - return is_first_or_last_pipeline_stage(vp_stage) - - -def core_gpt_dataset_config_from_args(args): - tokenizer = build_tokenizer(args) - - # Sometimes --data-path is too long, instead we parse it from a file. - blend: Optional[Tuple[List[str], Optional[List[float]]]] - blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] - blend, blend_per_split = get_blend_and_blend_per_split(args) - - sequences_per_dataset = None - if args.per_dataset_sequences_path is not None: - with open(args.per_dataset_sequences_path, "r") as f: - sequences_per_dataset = json.load(f) - - return GPTDatasetConfig( - random_seed=args.seed, - sequence_length=args.seq_length, - blend=blend, - blend_per_split=blend_per_split, - split=args.split, - num_dataset_builder_threads=args.num_dataset_builder_threads, - path_to_cache=args.data_cache_path, - mmap_bin_files=args.mmap_bin_files, - tokenizer=tokenizer, - reset_position_ids=args.reset_position_ids, - reset_attention_mask=args.reset_attention_mask, - eod_mask_loss=args.eod_mask_loss, - create_attention_mask=args.create_attention_mask_in_dataloader, - object_storage_cache_path=args.object_storage_cache_path, - mid_level_dataset_surplus=args.mid_level_dataset_surplus, - allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, - fast_cache_load=args.dataloader_fast_cache_load, - sequences_per_dataset=sequences_per_dataset, - defer_npy_index_mmap=args.dataloader_defer_npy_index_mmap, - context_parallel_size=args.context_parallel_size, - ) - - -def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): - """Build the train test and validation datasets. - - Args: - train_val_test_num_samples : A list containing the number of samples in train test and validation. 
- """ - args = get_args() - config = core_gpt_dataset_config_from_args(args) - - is_packed_sequence = False - if args.sft: - dataset_type = SFTDataset - is_packed_sequence = True # SFT always uses packed sequence - else: - if args.mock_data: - dataset_type = MockGPTDataset - else: - dataset_type = GPTDataset - - print_rank_0("> building train, validation, and test datasets for GPT ...") - - train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( - dataset_type, - train_val_test_num_samples, - partial(is_dataset_built_on_rank, vp_stage=vp_stage, is_packed_sequence=is_packed_sequence), - config, - ).build() - - print_rank_0("> finished creating GPT datasets ...") - - return train_ds, valid_ds, test_ds - if __name__ == "__main__": - # Timestamp right after entering __main__ block (after all imports/library setup) - _MAIN_ENTRY_TIME = time.time() - - # Register startup timestamps for timing report in pretrain() - set_startup_timestamps(program_start=_PROGRAM_START_TIME, main_entry=_MAIN_ENTRY_TIME) - - # Temporary for transition to core datasets - train_valid_test_datasets_provider.is_distributed = True - - # Optionally enable inprocess restart on pretrain - pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) - - args = parse_and_validate_args( - extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - ) - pretrain( - train_valid_test_datasets_provider, - partial(model_provider, mamba_builder), - ModelType.encoder_or_decoder, - forward_step, - store=store, - ) + # Execute pretrain_hybrid.py as if it were invoked directly. + _this_dir = os.path.dirname(os.path.abspath(__file__)) + runpy.run_path(os.path.join(_this_dir, "pretrain_hybrid.py"), run_name="__main__") diff --git a/pretrain_t5.py b/pretrain_t5.py index 59918930a78..fe928de78c7 100644 --- a/pretrain_t5.py +++ b/pretrain_t5.py @@ -26,6 +26,7 @@ ) from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer from megatron.training import get_args, get_timers, pretrain, print_rank_0 +from megatron.training.argument_utils import pretrain_cfg_container_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args from pretrain_gpt import loss_func @@ -269,8 +270,10 @@ def t5_position_embedding_ranks(pp_ranks): # Temporary for transition to core datasets train_valid_test_datasets_provider.is_distributed = True - parse_and_validate_args(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + args = parse_and_validate_args(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + full_config = pretrain_cfg_container_from_args(args) pretrain( + full_config, train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 720094b5ec7..dc735c25517 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -24,6 +24,7 @@ from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.spec_utils import import_module from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 +from megatron.training.argument_utils import pretrain_cfg_container_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args from pretrain_gpt import loss_func @@ -470,10 +471,12 @@ def llava_position_embedding_ranks(pp_ranks): if __name__ == "__main__": train_valid_test_datasets_provider.is_distributed = True - parse_and_validate_args( 
+ args = parse_and_validate_args( extra_args_provider=add_vlm_extra_args, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'} ) + full_config = pretrain_cfg_container_from_args(args) pretrain( + full_config, train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, diff --git a/skills/build-and-dependency/SKILL.md b/skills/build-and-dependency/SKILL.md index c7dc89ec1ad..af262884a41 100644 --- a/skills/build-and-dependency/SKILL.md +++ b/skills/build-and-dependency/SKILL.md @@ -1,8 +1,7 @@ --- name: build-and-dependency -description: Container-based dev environment setup and dependency management for Megatron-LM. Covers acquiring and launching the CI container, uv package management, updating uv.lock, and linting. -TRIGGER when: user asks to add, remove, or update a dependency; user edits or asks about pyproject.toml or uv.lock; uv.lock has a merge conflict; user asks to set up a dev environment or pull/build the CI container; user hits a container build error or uv error; user asks to run linting or autoformat. -DO NOT TRIGGER when: user is only running tests, investigating CI failures, or opening a PR (use testsystem instead). +description: Container-based dev environment setup and dependency management for Megatron-LM. Covers acquiring and launching the CI container, uv package management, and updating uv.lock. +when_to_use: Adding, removing, or updating a dependency; editing pyproject.toml or uv.lock; uv.lock merge conflict; setting up a dev environment; pulling or building the CI container; container build errors; uv errors; 'how do I install', 'uv sync fails', 'ModuleNotFoundError'. --- # Build & Dependency Guide @@ -28,12 +27,30 @@ dependency. --- +## dev vs lts + +Two image variants exist, controlled by the `IMAGE_TYPE` build arg and the +`container::lts` PR label: + +| Variant | Base image pin | uv group | When used | +|---------|---------------|----------|-----------| +| **`dev`** | `docker/.ngc_version.dev` | `dev` | Default — CI, local development, most PRs | +| **`lts`** | `docker/.ngc_version.lts` | `lts` | Stability testing; excludes ModelOpt and other bleeding-edge extras | + +**Use `dev` for everything unless you have a specific reason to test `lts`.** +CI runs `dev` by default; attach `container::lts` to a PR only when verifying +compatibility with the stable stack (e.g. a dependency upgrade that must not +break LTS users). The `@pytest.mark.flaky_in_dev` marker skips tests in the +`dev` environment; `@pytest.mark.flaky` skips them in `lts`. + +--- + ## Step 1 — Acquire an Image **Option A — NVIDIA-internal: pull a CI-built image** > ⚠️ Requires access to the internal GitLab instance. -> See `tools/trigger_internal_ci.md` for setup (adding the git remote, obtaining a token). +> See @tools/trigger_internal_ci.md for setup (adding the git remote, obtaining a token). The internal GitLab CI publishes images to its container registry. Derive the registry host from your configured `gitlab` remote — the same @@ -180,25 +197,6 @@ uv lock # re-resolve on top of your pyproject.toml --- -## Linting - -Run before opening a PR: - -```bash -# Check mode (no changes applied) -BASE_REF=main CHECK_ONLY=true SKIP_DOCS=false bash tools/autoformat.sh - -# Fix mode -BASE_REF=main CHECK_ONLY=false bash tools/autoformat.sh -``` - -Tools invoked: `black`, `isort`, `pylint`, `ruff`, `mypy`. - -After editing imports in any Python files, always run `uv run isort` on those -files before committing (repo CLAUDE.md requirement). 
- ---- - ## Common Pitfalls | Problem | Cause | Fix | @@ -207,5 +205,5 @@ files before committing (repo CLAUDE.md requirement). | `ModuleNotFoundError` after pip install | pip installed outside the uv-managed venv | Use `uv add` and `uv sync`, never bare `pip install` | | `uv: command not found` inside container | Wrong container image | Use the `megatron-lm` image built from `Dockerfile.ci.dev` | | `No space left on device` during uv ops | Cache fills container's `/root/.cache/` | Mount a host cache dir via `-v $HOME/.cache/uv:/root/.cache/uv` | -| Pre-commit fails with linting errors | Code style violations | Run `BASE_REF=main CHECK_ONLY=false bash tools/autoformat.sh` | | `docker build` fails with secret-related error | `Dockerfile.ci.dev` has a `jet` stage that requires an internal secret | Add `--target main` to stop before the `jet` stage | +| `access forbidden` when pulling | Registry URL includes an explicit port (e.g. `:5005`) | Use `${GITLAB_HOST}/adlr/...` with no port — the sed extracts the hostname only | diff --git a/skills/cicd/SKILL.md b/skills/cicd/SKILL.md new file mode 100644 index 00000000000..002bef8ce58 --- /dev/null +++ b/skills/cicd/SKILL.md @@ -0,0 +1,170 @@ +--- +name: cicd +description: CI/CD reference for Megatron-LM. Covers CI pipeline structure, PR scope labels, triggering internal GitLab CI, and CI failure investigation. +when_to_use: Investigating a CI failure; understanding the pipeline structure; which CI label to attach; triggering internal GitLab CI; 'CI is red', 'how do I trigger CI', 'PR labels', 'where are the logs', 'pull-request branch'. +--- + +# CI/CD Guide + +--- + +## CI Pipeline Structure + +The main workflow is `.github/workflows/cicd-main.yml`. It triggers on pushes +to branches matching `pull-request/[0-9]+` and `deploy-release/*`, on merge +groups, on a daily schedule, and on manual dispatch. + +```text +is-not-external-contributor + └─ pre-flight + └─ configure # determines scope, container tag, n_repeat + ├─ linting + ├─ cicd-container-build + │ ├─ cicd-parse-unit-tests → cicd-unit-tests-latest + │ ├─ cicd-parse-integration-tests-h100 → cicd-integration-tests-latest-h100 + │ └─ cicd-parse-integration-tests-gb200 → cicd-integration-tests-latest-gb200 (maintainers only) + └─ Nemo_CICD_Test # final pass/fail gate +``` + +Images are pushed to: + +- AWS ECR: `766267172432.dkr.ecr.us-east-1.amazonaws.com/…` +- GCP Artifact Registry: `us-east4-docker.pkg.dev/nv-projdgxchipp-20260113193621/megatron-lm/…` + +--- + +## CI Test Scope Labels + +The CI pipeline reads PR labels to decide test scope, n_repeat, and container image. 
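+
+Labels can be inspected and attached from the CLI as well. A minimal sketch using
+standard `gh` commands (the PR number is a placeholder; label names must match the
+tables below exactly):
+
+```bash
+# Show the labels currently attached to a PR
+gh pr view 1234 --repo NVIDIA/Megatron-LM --json labels --jq '.labels[].name'
+
+# Attach a scope label before the next CI run
+gh pr edit 1234 --repo NVIDIA/Megatron-LM --add-label "Run functional tests"
+```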
+
+**Decision tree (first match wins):**
+
+| Condition | `scope` | `n_repeat` | `lightweight` | Notes |
+|-----------|---------|-----------|---------------|-------|
+| Merge group | `mr-github` | 1 | false | Automatic, no label needed |
+| Label: **`Run tests`** | `mr-github` | 1 | **true** | Trains 4 steps, no golden-value compare |
+| Label: **`Run functional tests`** | `mr-github` | 5 | **false** | Trains 100 steps, golden-value compare |
+| _(no label)_ | `mr-github-slim` | 5 | false | Slim subset only |
+
+**Orthogonal image label:**
+
+| Label | Effect |
+|-------|--------|
+| **`container::lts`** | Use the LTS base image instead of `dev` (combinable with any scope label) |
+| **`Run MBridge tests`** | Also triggers the MBridge L1 test suite |
+
+### Which label to attach when opening a PR
+
+| Changed paths / nature of change | Label to attach |
+|----------------------------------|-----------------|
+| Docs only (`docs/`, `*.md`, docstrings) | _(none)_ |
+| CI/tooling only (`.github/`, `tools/`, `Makefile`) | _(none)_ |
+| Test files only (`tests/`) — existing tests, no new golden values | `Run tests` |
+| **New test cases added** (no golden values exist yet) | `Run functional tests` |
+| **Re-enabling a disabled test** (scope `-broken` → active) | `Run functional tests` |
+| Non-numerical library code (logging, error handling, CLI flags, refactors) | `Run tests` |
+| Could affect training numerics (model arch, attention, optimizer, distributed, MoE routing) | `Run functional tests` |
+| Container or dependency changes (`docker/`, `pyproject.toml`, `uv.lock`) | `Run tests` + `container::lts` |
+| Touches MBridge integration | add `Run MBridge tests` |
+
+**Rule of thumb:** default to `Run tests`. Always use `Run functional tests` when the PR adds new test cases (golden values must be generated) or when the change could plausibly shift loss curves.
+
+---
+
+## Triggering Internal CI
+
+Use `tools/trigger_internal_ci.py` to push the current branch to the internal
+GitLab remote and trigger a pipeline — without touching the GitLab UI.
+Full setup and usage details: @tools/trigger_internal_ci.md.
+
+**Prerequisites** (one-time):
+
+```bash
+# 1. Add the internal GitLab remote
+git remote add gitlab git@<internal-gitlab-host>:ADLR/Megatron-LM.git
+
+# 2. Create a personal access token with 'api' scope on your GitLab profile,
+#    then store it:
+export GITLAB_TOKEN=glpat-<your-token>
+```
+
+**Usage:**
+
+```bash
+python tools/trigger_internal_ci.py \
+    --gitlab-origin gitlab \
+    [--functional-test-scope mr] \
+    [--functional-test-repeat 5] \
+    [--functional-test-cases all] \
+    [--dry-run]
+```
+
+The script force-pushes the current branch as `pull-request/<pr-number>` and
+prints the resulting pipeline URL.
+
+---
+
+## CI Failure Investigation
+
+CI branches always follow the pattern `pull-request/<pr-number>`.
+
+### Locating the PR from a CI Branch
+
+```bash
+# Extract PR number from the current branch
+PR_NUMBER=$(git rev-parse --abbrev-ref HEAD | grep -oP '(?<=pull-request/)\d+')
+
+# Fetch the PR metadata (title, labels, author, base branch)
+gh pr view "$PR_NUMBER" --repo NVIDIA/Megatron-LM
+
+# Show the changeset for that PR
+gh pr diff "$PR_NUMBER" --repo NVIDIA/Megatron-LM
+```
+
+### Reading CI Job Logs
+
+```bash
+# List recent workflow runs for the PR
+gh run list --repo NVIDIA/Megatron-LM --branch "pull-request/$PR_NUMBER"
+
+# Stream failing job output
+gh run view <run-id> --repo NVIDIA/Megatron-LM --log-failed
+```
+
+Full per-rank logs are **not** in the runner stdout. They are uploaded as
+GitHub artifacts named `logs-<...>-<...>-<...>`.
+ +```bash +# 1. Find artifact name +gh run view --repo NVIDIA/Megatron-LM --json artifacts \ + --jq '.artifacts[].name' + +# 2. Download the artifact zip +gh run download --repo NVIDIA/Megatron-LM \ + --name "logs-" -D ./ci-logs + +# 3. Locate which rank logs contain errors +grep -r -l "ERROR\|Traceback\|FAILED\|fatal" ./ci-logs/ + +# 4. Log files can exceed 10 000 lines — never read a full log at once. +wc -l ./ci-logs///attempt_0//stderr.log +sed -n '1,200p' ./ci-logs/.../stderr.log # read in chunks +``` + +### Identifying Failure Root Cause + +1. **Linting failure** — re-run `tools/autoformat.sh` locally; the diff shows exactly what needs to change. +2. **Container build failure** — inspect the `cicd-container-build` job log. +3. **Unit test failure** — the failing bucket is in the `cicd-unit-tests-latest` job matrix. +4. **Functional test failure** — look at the `cicd-integration-tests-*` job. Start with `stdout.log` for rank 0. +5. **Flaky test** — the runner retries automatically up to 3 times. If all retries exhausted and the pattern matches a known transient (NCCL, ECC, segfault), it is infrastructure noise. + +### Correlating a Failure with the PR Changeset + +```bash +# Find unit tests that cover a changed source file +grep -r "from megatron.core.transformer.attention" tests/unit_tests/ -l + +# Check CODEOWNERS for reviewer assignment +cat .github/CODEOWNERS | grep "" +``` diff --git a/skills/create-issue/SKILL.md b/skills/create-issue/SKILL.md index d974fcb438e..449c69bf559 100644 --- a/skills/create-issue/SKILL.md +++ b/skills/create-issue/SKILL.md @@ -1,6 +1,7 @@ --- name: create-issue -description: Investigate a failing GitHub Actions run or job and create a GitHub issue for the failure. Use when the user shares a GitHub Actions URL and wants to file a bug report for the CI failure. +description: Investigate a failing GitHub Actions run or job and create a GitHub issue for the failure. +when_to_use: User shares a GitHub Actions URL and wants to file a bug report; 'create an issue for this failure', 'file a bug for this CI run', 'triage this GitHub Actions failure'. user_invocable: true argument: "" --- diff --git a/skills/linting-and-formatting/SKILL.md b/skills/linting-and-formatting/SKILL.md new file mode 100644 index 00000000000..00ab01a1342 --- /dev/null +++ b/skills/linting-and-formatting/SKILL.md @@ -0,0 +1,57 @@ +--- +name: linting-and-formatting +description: Linting and formatting for Megatron-LM. Covers running autoformat.sh, tools (ruff, black, isort, pylint, mypy), and code style rules. +when_to_use: Running linting or autoformat; fixing style violations before a PR; 'pre-commit fails', 'ruff error', 'isort', 'mypy', 'style violation', 'how do I format', 'autoformat.sh'. +--- + +# Linting and Formatting + +--- + +## Running the Formatter + +Run before opening a PR: + +```bash +# Check mode (no changes applied) +BASE_REF=main CHECK_ONLY=true SKIP_DOCS=false bash tools/autoformat.sh + +# Fix mode +BASE_REF=main CHECK_ONLY=false bash tools/autoformat.sh +``` + +Tools invoked: `black`, `isort`, `pylint`, `ruff`, `mypy`. + +--- + +## Import Ordering + +After editing imports in any Python files, always run `uv run isort` on those +files before committing: + +```bash +uv run isort .py .py +``` + +--- + +## Setting Up the Linting Group + +Inside the container: + +```bash +uv sync --locked --only-group linting +``` + +This installs `ruff`, `black`, `isort`, `pylint` — the same tools used by +`tools/autoformat.sh` and CI's `linting` job. 
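For quick iteration on a few files, the same tools can be invoked directly through `uv run` — a minimal sketch, assuming the linting group has been synced as above; `tools/autoformat.sh` remains the authoritative pre-PR check, and the exact flags it passes may differ:

```bash
# Hypothetical file list — substitute the files you actually changed.
FILES="megatron/core/transformer/attention.py"

uv run black --config pyproject.toml $FILES   # formatting
uv run isort $FILES                           # import ordering
uv run ruff check $FILES                      # lint rules
uv run pylint $FILES                          # docstrings, line length, etc.
```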
+ +--- + +## Code Style Rules + +- **Type hints**: required on all public API functions. Use `X | None`, not `Optional[X]`. +- **Docstrings**: Google-style on all public classes and functions. +- **Naming**: follow Python conventions — `snake_case` for functions and variables, `PascalCase` for classes. +- **Line length**: 119 characters (configured in `pyproject.toml`). +- **No bare `except`**: always catch specific exception types. diff --git a/skills/nightly-sync/SKILL.md b/skills/nightly-sync/SKILL.md new file mode 100644 index 00000000000..5c803d65b7e --- /dev/null +++ b/skills/nightly-sync/SKILL.md @@ -0,0 +1,608 @@ +--- +name: nightly-sync +description: Domain knowledge for the nightly main-to-dev sync workflow. Covers merge strategy, CI architecture, failure investigation, and known issues. +when_to_use: Working on the nightly sync PR; investigating a nightly sync failure; resolving merge conflicts between main and dev; 'nightly sync failed', 'main-to-dev merge', 'sync bot'. +--- + +# Nightly Sync: Main to Dev + +This skill is read by the automated sync bot during the nightly-sync-main-to-dev +workflow. It contains all domain knowledge for merging main into dev, resolving +conflicts, iterating on CI, and shipping the PR. + +--- + +## Phase 1: Create the Sync Branch and Merge + +### Branch Setup + +1. Create branch `$BRANCH` from `origin/dev` +2. Merge: `git merge origin/main -X theirs --no-edit` +3. If conflicts remain (e.g. add/add), resolve by favoring main + +### Preserving Dev-Only Additions + +Do NOT blanket-override all shared files with main's version. Dev has features +not yet in main (new classes, new modules, new tests). The merge preserves both +sides' non-conflicting additions — only intervene where there is an actual +conflict. + +### Squash-Merge Chain Detection + +Dev often develops features as a chain of PRs (PR1 → PR2 → PR3) where each +builds on the last. When PR1 is squash-merged to main, git sees main's squashed +version and dev's original commits as unrelated changes. `-X theirs` will pick +main's PR1 code and silently discard PR2/PR3's improvements on dev. + +After the merge, check for this pattern: + +1. For each file where `-X theirs` resolved a conflict, run + `git log --oneline origin/dev -- ` to see if dev has commits that + came AFTER the code main is bringing in. +2. If dev has follow-up commits (bug fixes, refactors, extensions), **favor + dev's version** for those sections. +3. If the conflict is just main bringing in a clean copy of what dev already + has (no follow-ups), main's version is fine. + +Practical check: run `git diff origin/dev -- ` on conflicted files. If +dev's code was removed or reverted, investigate whether dev's version is the +more evolved one. + +Real examples from PR #4291: +- `emerging_optimizers.py`: Main's version was MORE complete — it squash-merged + dev's PRs plus added more. `-X theirs` was correct. +- `distrib_optimizer.py`: Main overwrote dev's `GroupedQuantizedTensor` support. + Had to restore `_is_distopt_quantized_param` and the expanded + `_expand_quantized_param_shard_for_cast` loop while keeping main's NVFP4 + additions. This required a surgical merge combining sections from both. + +Key insight: squash-merge chains can go in EITHER direction. Sometimes main +is ahead (it squash-merged dev's work + more), sometimes dev is ahead (it has +follow-up PRs). Always diff both ways before deciding which version to favor. 
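A minimal sketch of the "diff both ways" check on one conflicted file — the path is only an example taken from the PR #4291 notes; run this for every file where `-X theirs` resolved a conflict:

```bash
# Example path — substitute each conflicted file in turn.
FILE="megatron/core/optimizer/distrib_optimizer.py"

# Does dev have follow-up commits AFTER the code main is bringing in?
git log --oneline origin/dev  -- "$FILE" | head -n 10
git log --oneline origin/main -- "$FILE" | head -n 10

# Diff both ways: what the merged tree drops relative to dev, and relative to main.
git diff origin/dev  -- "$FILE"
git diff origin/main -- "$FILE"
```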
+ +### Files to Override from Main + +These files have known semantic conflicts where dev's versions reference args +or APIs that main removed or renamed. Take main's version with +`git checkout origin/main -- `: + +- `megatron/training/training.py` — references dev-only args +- `megatron/training/initialize.py` — references dev-only args +- `megatron/training/utils.py` — references dev-only args +- `megatron/training/datasets/data_samplers.py` — references dev-only args +- `megatron/core/optimizer/layer_wise_optimizer.py` — constructor signature + +**Caveat for ALL overrides:** After taking main's version of any file, you +MUST run the API Mismatch Detection procedure (see below) on that file. +Taking main's caller code while keeping dev's callee implementations is the +#1 source of sync bugs. + +**IMPORTANT: Do NOT take main's `pyproject.toml`, `uv.lock`, or +`docker/Dockerfile.ci.dev`.** These three files are a tightly coupled +triple — the Dockerfile's `uv sync` command must match the dependency +groups in `pyproject.toml`, and `uv.lock` must be consistent with both. +Main's versions are missing dev-only dependencies (e.g. +`fast-hadamard-transform`, correct TransformerEngine revision) and the +`--group no_pypi_wheels` flag needed to install them. Keep dev's versions +of all three files. + +**IMPORTANT: `.github/CODEOWNERS` must NEVER be modified by the sync +bot under any circumstances.** Dev's CODEOWNERS is intentionally +different from main's — do not take main's version, do not merge them, +do not touch the file. If the merge produces a conflict or a non-zero +diff against `origin/dev` on this path, restore dev's version verbatim: + +``` +git checkout origin/dev -- .github/CODEOWNERS +``` + +Then verify with `git diff origin/dev -- .github/CODEOWNERS` — output +must be empty. Modifying CODEOWNERS triggers spurious reviewer +requests and conflicts with the dev team's governance; rolling back a +CODEOWNERS change after the PR lands is painful. + +**NEVER manually edit `uv.lock`.** It is a machine-generated lockfile. If +it needs to change, it must be regenerated with `uv lock` inside a CUDA +container (see `.claude/skills/build-and-test/SKILL.md`). + +### Git Source Reconciliation (pyproject.toml) + +After keeping dev's `pyproject.toml`, check whether main has added NEW git +sources to `[tool.uv.sources]` that don't exist in dev's version. Main's +merged code may import from packages only available at specific git revisions. + +1. Diff the `[tool.uv.sources]` sections: + `git show origin/main:pyproject.toml` vs `git show origin/dev:pyproject.toml` +2. For each git source in main but not dev, add it to dev's `pyproject.toml` +3. For sources in both but at different revisions, check whether dev's revision + works. If dev's revision is broken (TOML parse errors, missing classes main's + code imports), take main's revision instead. + +Real examples from PR #4291: +- `nvidia-resiliency-ext`: Main's `torch.py` imports `get_write_results_queue` + which only existed in main's pinned git revision, not on PyPI. Had to add + main's git source to dev's pyproject.toml. +- `nemo-run`: Dev's pinned revision had a TOML parse error with uv 0.7.2. + Had to swap to main's revision. 
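A sketch of step 1 above — extracting and diffing just the `[tool.uv.sources]` sections from both branches. It assumes each TOML table ends at the next `[...]` header, which is how the file is normally laid out; fall back to eyeballing the full `git show` output if the layout differs:

```bash
# Print the [tool.uv.sources] table from a given ref's pyproject.toml.
extract_sources() {
  git show "$1:pyproject.toml" \
    | awk '/^\[tool\.uv\.sources\]/{f=1; print; next} /^\[/{f=0} f'
}

# Entries present on one side but not the other show up as diff lines.
diff <(extract_sources origin/main) <(extract_sources origin/dev)
```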
+ +After any changes to `pyproject.toml`, regenerate `uv.lock` inside a CUDA +container: +```bash +docker run --rm -v $(pwd):/workspace nvcr.io/nvidia/pytorch:26.02-py3 \ + bash -c "pip install uv==0.7.2 && cd /workspace && \ + uv venv .venv --system-site-packages && uv sync --only-group build && uv lock" +# Clean up root-owned .venv: +docker run --rm -v $(pwd):/workspace nvcr.io/nvidia/pytorch:26.02-py3 \ + bash -c "rm -rf /workspace/.venv" +``` + +### API Mismatch Detection (Post-Merge Audit) + +The merge can create "Frankenstein" code where main's callers use dev's +implementations (or vice versa) with different method signatures. This +compiles fine but fails at runtime. + +After the merge, audit cross-boundary call sites: + +1. Identify files where main's version was taken (`-X theirs` or explicit + `git checkout origin/main`) +2. For each, find all external call sites: classes it instantiates, methods + it calls on imported objects, functions from other modules it invokes +3. Verify method names, parameter counts, and signatures match between the + caller and the implementation in the merged tree +4. Pay special attention to "interface" modules (files defining base classes) + — if main and dev evolved the interface differently, every caller and + implementer must agree + +Real examples from PR #4291: +- `multi_latent_attention.py` (main) called `off_interface.group_commit()` + but dev's interface only had `group_offload()` — method renamed +- `mamba_model.py` (main) called `init_chunk_handler(3 params)` but dev's + interface required 6 params — signature expanded on dev +- `mamba_model.py` called `mark_not_offloadable()` but dev had + `mark_not_offload()` — method renamed +- `bulk_offload()` did `.remove()` after `bulk_offload_group()` already + `.pop()`d the same item — double-removal from a list + +Practical detection: +```bash +# For each file taken from main, find what it imports and calls +grep -rn "from import\|\." megatron/ +# Cross-reference with the actual implementations in the merged tree +``` + +### File-Specific Merge Lessons + +These lessons were learned from PR #4291. They may recur if the same files +continue to diverge: + +- `gated_delta_net.py`: If the merge creates code calling non-existent helper + methods (e.g. `_resolve_cu_seqlens`), take dev's version wholesale. +- `model_chunk_schedule_plan.py`: Watch for missing imports (e.g. + `CudaGraphScope`) silently dropped during conflict resolution. +- `fine_grained_activation_offload.py`: Critical interface file used by many + callers. If main and dev have divergent method names/signatures, prefer + dev's implementation and patch main-originated callers to match. +- `distrib_optimizer.py`: Dev may have broader type abstractions (e.g. + `_is_distopt_quantized_param` covering both FP8 and GroupedQuantizedTensor). + Main may simplify to explicit type checks. Restore dev's abstractions. + +### Special Handling: data_schedule.py + +Main and dev have completely different classes in this file: +- Main: `HybridCPDataLoaderWrapper` (imported by main's `training.py`) +- Dev: `BasePackingScheduler`, `DpBalancedScheduler`, + `DefaultDynamicCPScheduler`, `wrap_data_iterator`, + `get_batch_on_this_rank_for_sequence_packing` (imported by `pretrain_gpt.py` + and tests) + +**Do NOT take either version wholesale.** Keep dev's file and append main's +`HybridCPDataLoaderWrapper` class (plus any missing imports like +`BalancedCPScheduler`, `Any`, `List`) at the end. 
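A rough sketch of that resolution — the path is hypothetical (locate the real file with `git ls-files | grep data_schedule`), and the actual append of `HybridCPDataLoaderWrapper` plus its imports is a manual, reviewed edit rather than something to script blindly:

```bash
# Hypothetical path — verify before use.
FILE="megatron/core/datasets/data_schedule.py"

# Start from dev's version, which pretrain_gpt.py and the tests import from.
git checkout origin/dev -- "$FILE"

# Pull main's copy aside; hand-copy HybridCPDataLoaderWrapper (and any imports
# it needs, e.g. BalancedCPScheduler, Any, List) to the end of dev's file.
git show "origin/main:$FILE" > /tmp/data_schedule_main.py
```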
+ +### Restore Deleted Files + +Compare `git ls-tree` between `origin/main` and HEAD to find files in main +that are missing from the merged tree. For each: +- **Restore** if main's code imports/references it and would break without it + (e.g. `hybrid_cp_schedule.py` if `data_schedule.py` imports from it) +- **Do NOT restore** if dev intentionally deleted it — check + `git log origin/dev -- ` for the deletion commit to understand intent +- When in doubt, check whether any file in the merged tree imports from the + missing file. If nothing imports it, skip it. + +### Formatting + +Run on ALL changed Python files (relative to `origin/dev`), in this order: + +1. `black` (version 24, `--config pyproject.toml`) +2. `isort` +3. Order matters: black first, then isort — reverse order can undo isort's work +4. `pylint` on changed `megatron/core/` files — fix missing-docstring and + line-too-long violations before pushing + +### Pre-push invariant checks + +Before every `git push` in this workflow (the initial push in Phase 1 +AND every fix-push in Phase 3), run these bash checks. If any fails, +fix the condition and re-check before pushing: + +```bash +# 1. CODEOWNERS must be identical to dev's. +if ! git diff --quiet origin/dev -- .github/CODEOWNERS; then + echo "ABORT: .github/CODEOWNERS differs from origin/dev. Restore with:" + echo " git checkout origin/dev -- .github/CODEOWNERS" + exit 1 +fi + +# 2. Dependency-management triple must be identical to dev's. +for f in pyproject.toml uv.lock docker/Dockerfile.ci.dev; do + if ! git diff --quiet origin/dev -- "$f"; then + # pyproject.toml is allowed to differ ONLY for git source reconciliation + # (new [tool.uv.sources] entries from main). If you intentionally edited + # it for that reason, bypass this check by re-running with $f skipped. + echo "WARNING: $f differs from origin/dev" + fi +done +``` + +The CODEOWNERS check is a HARD abort — never push if it fails. + +### Commit and Push + +Phase 1 produces a single commit on the sync branch. The merge itself +creates the merge commit; fold any post-merge work (formatting, +conflict surgery, restored files, regenerated `uv.lock`) into it +rather than stacking a second commit: + +```bash +git add -A +git commit --amend --no-edit # rewrites the merge commit's tree; + # parents are preserved. +git push -u origin "$BRANCH" # only non-force push of the run. +``` + +Once pushed, this commit is immutable for the rest of the run. +Phase 3 fixes go into a separate rolling fix commit on top (see +Phase 3 step 4 and the two-commit policy in Rules). + +--- + +## Phase 2: Create the Draft PR + +- Title: `chore: nightly sync main into dev ($DATE)` +- Create as **draft**: `gh pr create --draft` +- Body should include: + 1. Summary of what was synced (number of commits from main) + 2. **Python-only line-change stats**, so reviewers can gauge the real + code surface (excluding golden-value JSON, uv.lock, etc.). Compute + with: + + ```bash + git diff --numstat origin/dev...HEAD -- '*.py' \ + | awk 'BEGIN{a=0;d=0} {a+=$1; d+=$2} END{ + printf "Python lines: +%d / -%d across %d files\n", a, d, NR + }' + ``` + + Include the exact line (e.g. `Python lines: +1234 / -567 across 42 files`) + in the PR body so reviewers see it at a glance. + 3. List of files where main's version was taken over the merge + 4. List of files that were deleted in dev but restored (and why) + 5. The remerge-diff output (`git show --remerge-diff HEAD` on the merge + commit) so reviewers can inspect ONLY the conflict resolutions. 
If the + output is very long, summarize conflicts by file and put the full diff + in a collapsed `<details>
` block. If git is too old for `--remerge-diff`, + note the git version and describe the merge strategy used instead. +- Save the PR number for later phases +- **Add the `Run functional tests` and `Run MBridge tests` labels** to the + PR immediately after creation. The `Run functional tests` label ensures + `/ok to test` triggers the full CI suite (unit tests + functional/ + integration tests with 100-step training and golden value comparison). + The `Run MBridge tests` label triggers the MBridge test suite. Without + these labels, only a lightweight subset runs. + ```bash + gh pr edit --repo $REPO \ + --add-label "Run functional tests" \ + --add-label "Run MBridge tests" + ``` + +--- + +## Phase 3: CI Iteration + +### CI Architecture + +- **`Nemo_CICD_Test`** is a downstream gate job aggregating unit test, + integration test, and other results. If it fails, investigate the upstream + jobs it depends on — do NOT debug the gate itself. +- **Integration tests** (H100, GB200) may be skipped for non-maintainer PRs. + This is expected; the `Nemo_CICD_Test` gate will fail as a result. +- **`tests/unit_tests/conftest.py`** imports from `megatron.training.training`, + so a broken import in `training.py` (or anything it transitively imports) + cascades to fail ALL test suites. If every test job fails with ImportError, + check the training.py import chain first. + +### Execution model: one step, no background + +You run inside ONE GitHub Actions step. The moment you stop emitting +tool calls, the step ends and the runner container is destroyed. Any +background process you started dies with it. There is NO persistent +session and NO future wakeup. See the workflow prompt's "NO background +tasks" block for the full ban list. + +Practical rule: every wait for CI to resolve is a SINGLE foreground Bash +tool call that blocks inline until the wait is resolved. + +### The Fix-Then-Retrigger Loop + +Two nested loops. Do NOT conflate them: + +- The **outer loop** is YOUR sequence of tool calls (each iteration: one + `/ok to test`, one blocking poll, maybe one fix-and-push). It is NOT a + Bash loop. It advances because you make new tool calls. +- The **inner loop** is a single blocking Bash tool call using + `while true; do ... sleep 120; done`. It runs during one iteration of + the outer loop and ends when CI reaches a terminal state for that + iteration. + +The outer loop terminates ONLY when Phase 4's gate is satisfied. + +**Source of truth:** `gh pr view --repo $REPO --json statusCheckRollup`. +This lists every required check, including external status contexts +(GitLab CI, `copy-pr-bot`, etc.) that `gh api .../actions/runs/.../jobs` +does NOT show. + +**Outer-loop iteration (each iteration is a few tool calls):** + +1. `latest_sha=$(git rev-parse HEAD)` (one Bash call). +2. Post `/ok to test $latest_sha` on the PR: + `gh pr comment --repo $REPO --body "/ok to test $latest_sha"` +3. ONE blocking Bash tool call. This is the inner loop. Copy this + template verbatim, only changing `REPO` and `PR`: + + ```bash + REPO='NVIDIA/Megatron-LM' + PR='' + # Names matched case-insensitively, anchored to the START of the name. + EXEMPT='copy-pr-bot|is-not-external-contributor|greptile|coderabbit|codeowners|.*review|.*approval|codecov|coverage|build-docs|doc-build|readthedocs|sphinx' + # Sentinel check that tells us CI has fully run. Update this if the + # aggregate gate job is renamed. 
+ SENTINEL='Nemo_CICD_Test' + + while true; do + # Normalize both CheckRun (.status / .conclusion) and StatusContext + # (.state) entries into the same {name, status, conclusion} shape. + rollup=$(gh pr view "$PR" --repo "$REPO" --json statusCheckRollup --jq ' + .statusCheckRollup[] | [ + (.name // .context // "?"), + (if .__typename == "StatusContext" then + (if (.state == "PENDING" or .state == "EXPECTED") then "IN_PROGRESS" + else "COMPLETED" end) + else (.status // "UNKNOWN") end), + (if .__typename == "StatusContext" then + (if .state == "SUCCESS" then "SUCCESS" + elif (.state == "FAILURE" or .state == "ERROR") then "FAILURE" + else "NEUTRAL" end) + else (.conclusion // "UNKNOWN") end) + ] | @tsv') + + # Sentinel: do NOT declare green until the CI aggregate gate has + # reached a terminal state. Before /ok to test triggers the run, + # the sentinel is absent; while CI is running, it's IN_PROGRESS. + sentinel_line=$(printf '%s\n' "$rollup" | awk -F'\t' -v s="$SENTINEL" '$1 == s') + sentinel_status=$(printf '%s\n' "$sentinel_line" | awk -F'\t' 'NR==1 {print $2}') + if [ "$sentinel_status" != "COMPLETED" ]; then + echo "=== $(date -u) waiting for $SENTINEL (status: ${sentinel_status:-absent}) ===" + sleep 120 + continue + fi + + # Classify non-exempt checks (exempt list applied to the NAME only). + non_exempt=$(printf '%s\n' "$rollup" | awk -F'\t' -v p="^($EXEMPT)" 'tolower($1) !~ tolower(p)') + failed=$(printf '%s\n' "$non_exempt" | awk -F'\t' '$2 == "COMPLETED" && $3 !~ /^(SUCCESS|SKIPPED|NEUTRAL)$/') + pending=$(printf '%s\n' "$non_exempt" | awk -F'\t' '$2 != "COMPLETED"') + + if [ -n "$failed" ]; then + echo "=== NON-EXEMPT FAILURES ===" + printf '%s\n' "$failed" + echo "RESULT=FAILURE" + exit 0 + fi + if [ -n "$pending" ]; then + # Sentinel is COMPLETED but a non-exempt check is still pending — + # rare but possible. Keep waiting; do NOT ship. + echo "=== $(date -u) sentinel done but non-exempt checks still pending ===" + printf '%s\n' "$pending" + sleep 120 + continue + fi + + echo "=== ALL NON-EXEMPT CHECKS COMPLETED GREEN ===" + printf '%s\n' "$non_exempt" + echo "RESULT=GREEN" + exit 0 + done + ``` + + This Bash call blocks for as long as CI takes (minutes to hours). Do + NOT split it into many short polls interleaved with other tool calls + — that wastes `--max-turns` and creates windows where you could lose + track of the loop state. + +4. Read the tool output: + - If `RESULT=FAILURE`: diagnose via + `gh api repos/$REPO/actions/jobs//logs` (or the + external-context equivalent) and fix the code. The Phase 1 + commit is immutable; fixes accumulate in a single rolling fix + commit on top of it: + ```bash + git add -A + if git rev-parse --verify HEAD^2 >/dev/null 2>&1; then + # HEAD has two parents → still the Phase 1 merge commit. + # First failure of this run: create the fix commit. + git commit -m "fix: post-CI corrections" + git push origin "$BRANCH" + else + # HEAD is the existing fix commit → amend it. + git commit --amend --no-edit + git push --force-with-lease origin "$BRANCH" + fi + ``` + `--force-with-lease` (not `--force`): if a human pushed onto the + branch since the bot last fetched, the lease aborts the push + instead of clobbering them — fetch and decide what to do. + Start a new outer-loop iteration at step 1 with the new HEAD SHA. + - If `RESULT=GREEN`: outer loop is done. Proceed to Phase 4. + +**Why not wait-for-run-to-register first?** `gh pr comment` with +`/ok to test ` is handled by `copy-pr-bot`, which takes a few +seconds to trigger the CI run. 
The `statusCheckRollup` poll in step 3 +will initially show checks in `PENDING` / `QUEUED`; that's fine — the +inner loop treats those as "keep waiting" and will see them advance as +CI progresses. No separate registration poll needed. + +### Anti-Patterns (what went wrong on run 24800621116) + +- **Do NOT classify a queued/in-progress job as "infrastructure- + blocked" and ship.** A stuck queue drains eventually — wait. If the + job eventually passes, great; if it fails, go fix it. +- **Do NOT mark ready while any required check is `PENDING` / + `QUEUED` / `IN_PROGRESS` on the HEAD SHA.** A push is not a pass; + only a `COMPLETED` + green status is. +- **Do NOT declare an untested job "pre-existing."** Pre-existing + means the test ran to completion and failed the same way on recent + dev CI. A job that never ran on your PR cannot be pre-existing. +- **Do NOT use `gh api .../actions/runs/.../jobs` alone** as the gate + signal. External status contexts (GitLab CI pipelines, copy-pr-bot + status, etc.) do NOT appear there. Use `statusCheckRollup`. +- **Do NOT start any background process.** No `&`, no `nohup`, no + `run_in_background: true`, no `ScheduleWakeup`. The GitHub Actions + step owns your shell; when the step ends, every background process + is killed and cannot resume. +- **Do NOT push directly to `pull-request/` branches.** + The community bot manages those branches when it processes + `/ok to test`. Pushing to them directly breaks the CI trigger + mechanism. Always push to your own sync branch (e.g. + `main2dev/`) instead. +- **Do NOT forget the `Run functional tests` and `Run MBridge tests` + labels.** Without `Run functional tests`, the internal GitLab + functional tests do not run; without `Run MBridge tests`, the + MBridge test suite does not run. + +### Failure Investigation + +1. Fetch logs: `gh api repos/$REPO/actions/jobs//logs` +2. Grep for: `ImportError`, `ModuleNotFoundError`, `FAILED`, + `would reformat`, `line-too-long`, `Traceback` +3. Read the error, understand root cause, fix the code + +### Common Issues + +- **ImportError for a class/module:** Dev test imports a class from a file + where we took main's version. Restore only the missing class/function — + not the entire file. If a file's classes are completely different between + main and dev, keep both sets of code. +- **Formatting failures (black/pylint):** Run `black --config pyproject.toml` + on offending files. For pylint long-line or missing-docstring, edit directly. +- **Circular imports:** `isort` can reorder imports in a way that introduces + circular dependencies (e.g. `megatron/legacy/model/__init__.py`). Check + `git diff` on `__init__.py` files to see if import order changed. +- **Dependency version mismatches:** Taking main's `pyproject.toml`/`uv.lock` + can change library versions in the CI container. Dev-only code may depend on + newer versions (e.g. TransformerEngine's `single_grouped_weight`). If failures + trace to missing kwargs or changed APIs in third-party libs, this is the cause. +- **API mismatch (AttributeError / TypeError at runtime):** Main's callers + reference methods that don't exist (or have different signatures) in dev's + implementations. See "API Mismatch Detection" in Phase 1. Fix by adding + shims, renaming methods, or adjusting call signatures. +- **Infrastructure / network failures (apt-get, pip download):** Errors like + `archive.ubuntu.com unreachable` or `Connection timed out` during package + installation are transient CI infrastructure issues, not code problems. 
+ Retry CI with the same SHA. Do not investigate as code failures. + +### Pre-Existing Failure Verification + +**You MUST empirically verify before classifying any failure as pre-existing.** + +1. `gh pr list --repo $REPO --base dev --state merged --limit 3` +2. `gh pr checks --repo $REPO` on a recently merged dev PR +3. If the same test bucket **passes on recent dev CI** → the failure is + sync-caused. You must fix it. +4. Only if the test **also fails on recent dev CI** can you classify it as + pre-existing. Document with the dev PR number and CI run as evidence. + +### Internal GitLab Functional Tests + +GitHub CI covers unit tests and some integration tests. Internal GitLab +(`gitlab-master.nvidia.com`) runs additional functional tests on +H100/GB200 hardware that may reveal issues GitHub CI does not catch. +These surface in `statusCheckRollup` as external status contexts (the +bash template already handles them via the `__typename == "StatusContext"` +branch). + +- Fine-grained activation offloading failures, for example, only showed + up in GitLab functional tests during PR #4291 +- If GitHub CI passes but a reviewer reports GitLab failures, + investigate with the same rigor as GitHub CI failures +- The sync PR should ideally pass both GitHub and GitLab CI before + merge, but GitHub CI passing (i.e. the Phase 4 gate above) is the + minimum before `gh pr ready` + +--- + +## Phase 4: Mark PR Ready — Strict Gate + +Run `gh pr ready` ONLY when every non-exempt required check on the latest +CI run (against the current HEAD SHA) satisfies BOTH: + +1. `status == "completed"` — NOT `queued`, `in_progress`, `pending`, + `waiting`, or `requested`. +2. `conclusion ∈ {"success", "skipped", "neutral"}`. + +If a non-exempt check is pending/queued/in-progress: keep polling; do not +run `gh pr ready`. If it fails: go back to Phase 3's loop. + +The exempt list (approval/coverage/docs) is defined in Phase 3; only those +checks may be ignored. + +A pre-existing failure (same test failing identically on recent dev CI) +may be accepted, but ONLY after it has fully run, been empirically +verified against dev, and documented in the PR body with evidence (dev PR +number + CI run URL). + +``` +gh pr ready --repo $REPO +``` + +Then comment on the PR confirming it is ready for human review. The +comment should include: +- Which non-exempt checks passed (summary from the bash template's + final `ALL NON-EXEMPT CHECKS COMPLETED GREEN` output) +- Any documented pre-existing failures with evidence (dev PR number + + CI run URL showing the same failure on recent dev CI) +- Which files were taken from main vs. merged manually +- Any API mismatches detected and fixed +- Any `pyproject.toml` git source reconciliation performed +- Links to the CI runs that validated the fixes + +--- + +## Rules + +- Prioritize main over dev on genuine conflicts. Preserve dev-only additions + that do not conflict. +- **Two-commit policy:** the PR contains at most two bot-authored + commits — the Phase 1 merge commit (immutable once pushed) and a + single rolling fix commit on top. The fix commit is created on + the first Phase 3 failure (normal push) and amended on every + subsequent failure (`git commit --amend --no-edit` + + `git push --force-with-lease`). Never modify the Phase 1 commit + after pushing it; never let the fix-commit count exceed one. 
+- CI triggers via comment: `/ok to test ` +- CI runs appear on branch `pull-request/` +- Git committer identity: `svcnvidia-nemo-ci` +- After editing imports, run `isort` on those files +- **Push directly to NVIDIA/Megatron-LM** (not a fork). The bot uses a PAT + with write access. CLAUDE.md says "never push directly" but that rule is + for human contributors — the sync bot is an exception. diff --git a/skills/onboard-gb200-1node-tests/SKILL.md b/skills/onboard-gb200-1node-tests/SKILL.md index ef5642f73fb..b24143fa808 100644 --- a/skills/onboard-gb200-1node-tests/SKILL.md +++ b/skills/onboard-gb200-1node-tests/SKILL.md @@ -1,10 +1,9 @@ --- name: onboard-gb200-1node-tests -description: Onboard 1-node GitHub MR functional tests for GB200 from existing mr-scoped 2-node tests. Use when the user asks to add GB200 github-mr tests, create single-node variants of existing tests, or expand CI coverage for GB200. +description: Onboard 1-node GitHub MR functional tests for GB200 from existing mr-scoped 2-node tests. +when_to_use: Adding GB200 github-mr tests; creating single-node variants of existing tests; expanding CI coverage for GB200; 'add GB200 MR tests', 'onboard GB200 1-node', 'create single-node variant'. user_invocable: true argument: "[model-yaml] # optional: gpt, moe, or both (default: both)" -TRIGGER when: user asks to add GB200 mr-github tests, create 1-node variants of functional tests, or onboard tests for GitHub CI on GB200. -DO NOT TRIGGER when: user is asking about H100 tests or making unrelated changes to existing test configs. --- # Onboard GB200 1-Node GitHub MR Tests diff --git a/skills/respond-to-issue/SKILL.md b/skills/respond-to-issue/SKILL.md index 1d513da2e63..e815c93f9d2 100644 --- a/skills/respond-to-issue/SKILL.md +++ b/skills/respond-to-issue/SKILL.md @@ -1,6 +1,7 @@ --- name: respond-to-issue -description: Research and draft a response to a GitHub issue or question from an external contributor. Use when the user shares a GitHub issue URL or asks to respond to a community question. +description: Research and draft a response to a GitHub issue or question from an external contributor. +when_to_use: User shares a GitHub issue URL or asks to respond to a community question; 'respond to this issue', 'draft a reply', 'answer this GitHub question'. user_invocable: true argument: "" --- diff --git a/skills/run-on-slurm/SKILL.md b/skills/run-on-slurm/SKILL.md new file mode 100644 index 00000000000..4c9793f06f2 --- /dev/null +++ b/skills/run-on-slurm/SKILL.md @@ -0,0 +1,114 @@ +--- +name: run-on-slurm +description: How to launch distributed Megatron-LM training jobs on a SLURM cluster. Covers a minimal sbatch skeleton, environment-variable setup for torch.distributed.run, CUDA_DEVICE_MAX_CONNECTIONS rules across hardware and parallelism modes, container conventions, monitoring, and per-rank failure diagnosis. +when_to_use: Submitting a SLURM job; writing or debugging an sbatch script; configuring multi-node distributed training; setting MASTER_ADDR / MASTER_PORT / WORLD_SIZE; diagnosing a SLURM job failure; 'how do I run on the cluster', 'sbatch', 'multi-node training'. +--- + +# Run Megatron-LM on SLURM + +## Prerequisites + +- A SLURM cluster login with submission rights to a GPU partition. +- Megatron-LM checked out on a filesystem visible to all nodes in the allocation (NFS, Lustre, or similar). All nodes must reach the same paths for code, data, checkpoints, and output. 
+- `uv` installed; run `uv sync --extra training --extra dev` (or `--extra lts`) on the worktree once before submission so the `.venv` is materialized and visible to every node. + +## Minimal sbatch script + +Save as `run_megatron.slurm` in the worktree: + +```bash +#!/bin/bash +#SBATCH --job-name=megatron +#SBATCH --account= +#SBATCH --partition= +#SBATCH --nodes= +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node= +#SBATCH --time= +#SBATCH --output=logs/%x-%j.out +#SBATCH --error=logs/%x-%j.err + +set -euo pipefail +cd + +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1) +export MASTER_PORT=${MASTER_PORT:-29500} +export NNODES=${SLURM_NNODES} +export GPUS_PER_NODE= +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Set CUDA_DEVICE_MAX_CONNECTIONS only when your configuration requires it +# (see the section below). Example for pre-Blackwell with TP>1 or CP>1 +# (non-FSDP): +# export CUDA_DEVICE_MAX_CONNECTIONS=1 + +srun --ntasks=${NNODES} --ntasks-per-node=1 bash -c ' + # NODE_RANK comes from SLURM_NODEID with one task per node. + NODE_RANK=${SLURM_NODEID} + uv run python -m torch.distributed.run \ + --nnodes='"${NNODES}"' \ + --nproc-per-node='"${GPUS_PER_NODE}"' \ + --node-rank=${NODE_RANK} \ + --master-addr='"${MASTER_ADDR}"' \ + --master-port='"${MASTER_PORT}"' \ + pretrain_gpt.py \ + +' +``` + +Submit: + +```bash +mkdir -p logs && JOB_ID=$(sbatch --parsable run_megatron.slurm) +echo "Submitted ${JOB_ID}" +``` + +## Multi-node rules + +- Submit from the worktree you intend to run, or `cd` to it in the script. All nodes must reach the same path on a shared filesystem (NFS, Lustre, or similar) — node-local paths will not be visible to peer ranks. +- Use one `torchrun` worker group across all nodes; do not start independent single-node jobs. +- `--nproc-per-node` should equal the number of visible GPUs per node. +- Write checkpoints, tensorboard data, and structured logs to shared storage. + +## CUDA_DEVICE_MAX_CONNECTIONS + +The right value depends on your hardware and parallelism mode. Do not export it unconditionally: + +- **Pre-Blackwell (Hopper, Ampere) with TP>1 or CP>1, non-FSDP:** set to `1`. The relevant code path asserts on this — you will get an assertion error if it is not `1`, not a silent deadlock. +- **Blackwell:** not required; setting it has no effect. +- **Torch-FSDP2 or Megatron-FSDP:** must NOT be `1`. Leave the env var unset, or set it to a value greater than `1`. +- **`overlap_moe_expert_parallel_comm` enabled:** set to `32`. + +Set it explicitly in the sbatch script when your configuration calls for it. + +## Containers + +Many sites run Megatron-LM inside a container (enroot/pyxis on some clusters, singularity on others). If you do, the uv-managed `.venv` must live on a path that is visible from inside the container, and the container image must provide the CUDA / NCCL / torch versions the repo expects (see `docker/.ngc_version.dev` and `.ngc_version.lts`). The skeleton above stays the same; wrap the `srun` invocation with your scheduler's container flags (`--container-image=…`, `--container-mounts=…`, etc.). + +## Monitor and collect + +```bash +squeue -j "$JOB_ID" -o "%.10i %.8T %.10M %.6D %R" +sacct -j "$JOB_ID" --format=JobID,State,ExitCode,Elapsed +scancel "$JOB_ID" +``` + +If your training script writes a result artifact (a JSON metrics file from rank 0, a final checkpoint, etc.), poll for the artifact rather than waiting only on `squeue` state. 
Useful output usually appears before SLURM marks the job complete, and polling on the artifact lets you cancel the job as soon as it lands instead of holding the allocation until the timeout. + +## Failure diagnosis + +Scan stderr from every rank, not just rank 0. The earliest non-NCCL Python traceback is usually the root cause; later NCCL timeouts on other ranks are downstream symptoms of the first crash. + +Classify quickly: + +- **OOM**: record rank, phase (forward / backward / optimizer), batch size, sequence length, parallelism (TP/DP/CP/PP), and peak memory before adjusting. +- **Shape / divisibility error**: check `WORLD_SIZE = TP × DP × CP × PP` and head-count divisibility (`num_attention_heads % TP == 0`). +- **Import error**: wrong worktree, missing `uv sync`, or stale `PYTHONPATH`. Confirm `cd ` before launch. +- **NCCL failure** with no Python traceback: verify allocation, port reachability, `MASTER_ADDR` resolution, and command consistency across ranks. + +## Common pitfalls + +- Forgetting `uv sync` before the first submission. If the venv is missing, every job rebuilds it from inside `srun`, costing minutes per job. +- Writing logs to a node-local path that disappears at job exit. Always write to the shared filesystem. +- Setting `CUDA_DEVICE_MAX_CONNECTIONS=1` blindly. The right value depends on hardware and parallelism mode (see the dedicated section above). Setting it to `1` with FSDP causes a different problem; on Blackwell it has no effect; on pre-Blackwell with TP>1 or CP>1 (non-FSDP) the code asserts, it does not deadlock. +- Running bare `torchrun` instead of `uv run python -m torch.distributed.run`. Bare `torchrun` may dispatch through a python interpreter that does not see venv packages, depending on how the venv is set up. diff --git a/skills/split-pr/SKILL.md b/skills/split-pr/SKILL.md index 52c0de2d47b..1a6f8966e49 100644 --- a/skills/split-pr/SKILL.md +++ b/skills/split-pr/SKILL.md @@ -1,6 +1,7 @@ --- name: split-pr -description: Split a PR into multiple PRs to reduce the number of required CODEOWNERS reviewer groups. Use when the user asks to split a PR, reduce reviewer groups, or break up a large PR. +description: Split a PR into multiple PRs to reduce the number of required CODEOWNERS reviewer groups. +when_to_use: User asks to split a PR, reduce reviewer groups, or break up a large PR; 'too many CODEOWNERS', 'split this PR', 'break up PR', 'reduce reviewers needed'. user_invocable: true argument: "" --- diff --git a/skills/testing/SKILL.md b/skills/testing/SKILL.md new file mode 100644 index 00000000000..13c1afcb533 --- /dev/null +++ b/skills/testing/SKILL.md @@ -0,0 +1,193 @@ +--- +name: testing +description: Test system for Megatron-LM. Covers test layout, recipe YAML structure, adding and running unit and functional tests, golden values, marker filters, and CI parity. +when_to_use: Adding or running a unit or functional test; understanding the test layout; writing a recipe YAML; downloading or updating golden values; reproducing a test failure locally; 'how do I add a test', 'run unit tests', 'pytest fails', 'test layout', 'golden values', 'recipe YAML', 'marker filter'. 
+--- + +# Testing Guide + +--- + +## Test Layout + +```text +tests/ +├── unit_tests/ # pytest, 1 node × 8 GPUs, torch.distributed runner +├── functional_tests/ # end-to-end shell + training scripts +│ └── test_cases/ +│ └── {model}/{test_case}/ +│ ├── model_config.yaml # training args +│ └── golden_values_{env}_{platform}.json +└── test_utils/ + ├── recipes/ + │ ├── h100/ # YAML recipes for H100 jobs + │ └── gb200/ # YAML recipes for GB200 jobs + └── python_scripts/ # helpers (recipe_parser, golden-value download, …) +``` + +--- + +## How Tests Execute + +The GitHub Actions runner invokes `launch_nemo_run_workload.py`, which uses +**nemo-run** to launch a `DockerExecutor` container. The repo is bind-mounted +at `/opt/megatron-lm`; training data is mounted at `/mnt/artifacts`. + +**Unit tests** are dispatched through `torch.distributed.run`: + +- Ranks 0 and 3 are tee-d to stdout; all other ranks write only to log files. +- Per-rank log files land at `{assets_dir}/logs/1/` and are uploaded as a + GitHub artifact after the run. + +**Functional tests** are driven by +`tests/functional_tests/shell_test_utils/run_ci_test.sh`. Only rank 0 runs the +pytest validation step; training output from all ranks is uploaded as an artifact. + +**Flaky-failure auto-retry**: `launch_nemo_run_workload.py` retries up to +**3 times** for known transient patterns (NCCL timeout, ECC error, segfault, +HuggingFace connectivity, …) before declaring a genuine failure. + +--- + +## Recipe YAML Structure + +Recipes live in `tests/test_utils/recipes/` and are parsed by +`tests/test_utils/python_scripts/recipe_parser.py`. Each file expands a +cartesian `products` block into individual workload specs: + +```yaml +type: basic +format_version: 1 +maintainers: [mcore] +loggers: [stdout] +spec: + name: "{test_case}_{environment}_{platforms}" + model: gpt # maps to tests/functional_tests/test_cases/{model}/ + build: mcore-pyt-{environment} + nodes: 1 + gpus: 8 + n_repeat: 5 + platforms: dgx_h100 + time_limit: 1800 + script_setup: | + ... + script: |- + bash tests/functional_tests/shell_test_utils/run_ci_test.sh ... +products: + - test_case: [my_test] + products: + - environment: [dev, lts] + scope: [mr-github] + platforms: [dgx_h100] +``` + +Key runtime placeholders: `{assets_dir}`, `{artifacts_dir}`, `{test_case}`, +`{environment}`, `{platforms}`, `{n_repeat}`. 
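For orientation, the `products` block above expands to one workload per combination — with the `{test_case}_{environment}_{platforms}` name template that is `my_test_dev_dgx_h100` and `my_test_lts_dgx_h100`. To see every recipe entry a given test case participates in (a sketch; `my_test` is the placeholder name from the example):

```bash
# Show each recipe that references the test case, with a little surrounding context.
grep -rn -A3 "my_test" tests/test_utils/recipes/
```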
+ +### Disabling a Test Without Deleting It + +To temporarily disable a test case in a recipe YAML, suffix its `scope` value +with `-broken` — **do not delete the entry**: + +```yaml +# before (test runs in CI) +scope: [mr-github] + +# after (test is skipped; entry preserved for easy re-enable) +scope: [mr-github-broken] +``` + +--- + +## Running Unit Tests Locally + +All unit tests initialize a `torch.distributed` group, so every invocation +requires GPU access and must go through `torch.distributed.run`: + +```bash +# Full suite +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests + +# Single file +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests/models/test_gpt_model.py + +# Single test +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests/models/test_gpt_model.py::TestGPTModel::test_constructor + +# Filter by name substring +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests -k optimizer +``` + +### Marker filters + +```bash +# Exclude flaky tests during development +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests -m "not flaky and not flaky_in_dev" + +# Include experimental tests +uv run python -m torch.distributed.run --nproc-per-node 8 -m pytest -q \ + tests/unit_tests --experimental +``` + +### CI parity + +Use `tests/unit_tests/run_ci_test.sh` to reproduce a CI bucket failure exactly. +For ad-hoc runs, prefer the direct `torch.distributed.run` invocations above. + +### Gotchas + +- `pyproject.toml` sets `addopts = --durations=15 -s -rA` — stdout is not + captured (`-s`), so ranks interleave during multi-rank runs. Override with + `--capture=fd` when debugging a specific rank. +- `tests/unit_tests/conftest.py` looks for test data under `/opt/data` and + attempts a download if missing. Supply it manually or skip data-dependent + tests when running outside the canonical container. + +--- + +## Adding a Unit Test + +1. Create `tests/unit_tests//test_.py`. +2. Use fixtures from `tests/unit_tests/conftest.py`. +3. Apply markers as needed: + - `@pytest.mark.internal` — skipped on `legacy` tag + - `@pytest.mark.flaky_in_dev` — skipped in `dev` environment (CI default; use this to disable a flaky test without blocking the standard pipeline) + - `@pytest.mark.flaky` — skipped in `lts` environment + - `@pytest.mark.experimental` — `latest` tag only +4. Verify locally (see Running Unit Tests Locally above). +5. If the test needs a dedicated CI bucket, add an entry to + `tests/test_utils/recipes/h100/unit-tests.yaml`. + +--- + +## Adding a Functional / Integration Test + +1. Create `tests/functional_tests/test_cases///`. +2. Write `model_config.yaml` with `MODEL_ARGS`, `ENV_VARS`, and `TEST_TYPE`. +3. Add a YAML recipe under `tests/test_utils/recipes/h100/` (and `gb200/` if + needed). Required fields: `scope`, `environment`, `platform`, `n_repeat`, + `time_limit`. +4. Push the PR, add the label **"Run functional tests"** to trigger a full run. +5. After a successful run, download golden values: + + ```bash + python tests/test_utils/python_scripts/download_golden_values.py \ + --source github --pipeline-id + ``` + +6. Commit the downloaded golden values. 
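For step 6, the golden-value files land in the test case directory shown in the layout above — a minimal sketch of staging and committing them (model and test-case names are placeholders):

```bash
# Placeholders — substitute the actual model and test case directory names.
MODEL=gpt
TEST_CASE=my_test

git add "tests/functional_tests/test_cases/${MODEL}/${TEST_CASE}/"golden_values_*.json
git commit -m "test: add golden values for ${TEST_CASE}"
```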
+ +--- + +## Common Pitfalls + +| Problem | Cause | Fix | +|---------|-------|-----| +| Test passes locally but fails in CI | Different environment or data path | Check `DATA_PATH`, `DATA_CACHE_PATH`, and the `environment` tag (`dev` vs `lts`) | +| Golden value mismatch after a code change | Numerical regression | Download new golden values via `download_golden_values.py` after a clean run | +| `cicd-integration-tests-gb200` not triggered | GB200 jobs require maintainer status | Ask a maintainer to trigger, or add the `Run functional tests` label | diff --git a/skills/testsystem/SKILL.md b/skills/testsystem/SKILL.md deleted file mode 100644 index 3b3d6175915..00000000000 --- a/skills/testsystem/SKILL.md +++ /dev/null @@ -1,371 +0,0 @@ ---- -name: testsystem -description: Test system, CI pipeline, and CI failure investigation for Megatron-LM. Covers test layout, recipe YAML structure, adding unit and functional tests, CI scope labels, triggering internal GitLab CI, pipeline structure, and debugging CI failures. -TRIGGER when: user asks to run tests, add a test, investigate a CI failure, understand the CI pipeline, or work with test recipes; user opens or pushes to a PR and needs to know which CI label to attach; user wants to trigger the internal GitLab CI pipeline; user asks to download golden values or references a pipeline/run ID in the context of golden values. -DO NOT TRIGGER when: user is only setting up the dev environment or managing dependencies (use build-and-dependency instead). ---- - -# Test System & CI Guide - ---- - -## Test Layout - -```text -tests/ -├── unit_tests/ # pytest, 1 node × 8 GPUs, torch.distributed runner -├── functional_tests/ # end-to-end shell + training scripts -│ └── test_cases/ -│ └── {model}/{test_case}/ -│ ├── model_config.yaml # training args -│ └── golden_values_{env}_{platform}.json -└── test_utils/ - ├── recipes/ - │ ├── h100/ # YAML recipes for H100 jobs - │ └── gb200/ # YAML recipes for GB200 jobs - └── python_scripts/ # helpers (recipe_parser, golden-value download, …) -``` - ---- - -## How Tests Execute - -The GitHub Actions runner invokes `launch_nemo_run_workload.py`, which uses -**nemo-run** to launch a `DockerExecutor` container. The repo is bind-mounted -at `/opt/megatron-lm`; training data is mounted at `/mnt/artifacts`. - -**Unit tests** are dispatched through `torch.distributed.run`: - -- Ranks 0 and 3 are tee-d to stdout; all other ranks write only to log files. -- Per-rank log files land at `{assets_dir}/logs/1/` and are uploaded as a - GitHub artifact after the run. - -**Functional tests** are driven by -`tests/functional_tests/shell_test_utils/run_ci_test.sh`. Only rank 0 runs the -pytest validation step; training output from all ranks is uploaded as an artifact. - -**Flaky-failure auto-retry**: `launch_nemo_run_workload.py` retries up to -**3 times** for known transient patterns (NCCL timeout, ECC error, segfault, -HuggingFace connectivity, …) before declaring a genuine failure. - ---- - -## Recipe YAML Structure - -Recipes live in `tests/test_utils/recipes/` and are parsed by -`tests/test_utils/python_scripts/recipe_parser.py`. 
Each file expands a -cartesian `products` block into individual workload specs: - -```yaml -type: basic -format_version: 1 -maintainers: [mcore] -loggers: [stdout] -spec: - name: "{test_case}_{environment}_{platforms}" - model: gpt # maps to tests/functional_tests/test_cases/{model}/ - build: mcore-pyt-{environment} - nodes: 1 - gpus: 8 - n_repeat: 5 - platforms: dgx_h100 - time_limit: 1800 - script_setup: | - ... - script: |- - bash tests/functional_tests/shell_test_utils/run_ci_test.sh ... -products: - - test_case: [my_test] - products: - - environment: [dev, lts] - scope: [mr-github] - platforms: [dgx_h100] -``` - -Key runtime placeholders: `{assets_dir}`, `{artifacts_dir}`, `{test_case}`, -`{environment}`, `{platforms}`, `{n_repeat}`. - -### CI Test Scope Labels - -The CI pipeline reads PR labels to decide test scope, n_repeat, and container image. - -**Decision tree (first match wins):** - -| Condition | `scope` | `n_repeat` | `lightweight` | Notes | -|-----------|---------|-----------|---------------|-------| -| Merge group | `mr-github` | 1 | false | Automatic, no label needed | -| Label: **`Run tests`** | `mr-github` | 1 | **true** | Trains 4 steps, no golden-value compare | -| Label: **`Run functional tests`** | `mr-github` | 5 | **false** | Trains 100 steps, golden-value compare | -| _(no label)_ | `mr-github-slim` | 5 | false | Slim subset only | - -**Orthogonal image label:** - -| Label | Effect | -|-------|--------| -| **`container::lts`** | Use the LTS base image instead of `dev` (combinable with any scope label) | -| **`Run MBridge tests`** | Also triggers the MBridge L1 test suite | - -### Disabling a Test Without Deleting It - -To temporarily disable a test case in a recipe YAML, suffix its `scope` value -with `-broken` — **do not delete the entry**: - -```yaml -# before (test runs in CI) -scope: [mr-github] - -# after (test is skipped; entry preserved for easy re-enable) -scope: [mr-github-broken] -``` - -This applies to any scope token (`mr-github`, `mr-github-slim`, `mr-gitlab`, -etc.). Deleting the entry entirely would require recreating the test case -definition when the fix lands. - -### Which label to attach when opening a PR - -Apply this logic based on what the PR changes: - -| Changed paths / nature of change | Label to attach | -|----------------------------------|-----------------| -| Docs only (`docs/`, `*.md`, docstrings) | _(none)_ | -| CI/tooling only (`.github/`, `tools/`, `Makefile`) | _(none)_ | -| Test files only (`tests/`) — existing tests, no new golden values | `Run tests` | -| **New test cases added** (no golden values exist yet) | `Run functional tests` | -| Non-numerical library code (logging, error handling, CLI flags, refactors) | `Run tests` | -| Could affect training numerics (model arch, attention, optimizer, distributed, MoE routing) | `Run functional tests` | -| Container or dependency changes (`docker/`, `pyproject.toml`, `uv.lock`) | `Run tests` + `container::lts` | -| Touches MBridge integration | add `Run MBridge tests` | - -**Rule of thumb:** default to `Run tests`. Always use `Run functional tests` when the PR adds new test cases (golden values must be generated) or when the change could plausibly shift loss curves. - ---- - -## Adding a Unit Test - -1. Create `tests/unit_tests//test_.py`. -2. Use fixtures from `tests/unit_tests/conftest.py`. -3. Apply markers as needed: - - `@pytest.mark.internal` — skipped on `legacy` tag - - `@pytest.mark.flaky` — skipped in `lts` environment - - `@pytest.mark.experimental` — `latest` tag only -4. 
Verify locally inside the container: - - ```bash - pytest -xvs tests/unit_tests//test_.py - ``` - -5. If the test needs a dedicated CI bucket, add an entry to - `tests/test_utils/recipes/h100/unit-tests.yaml`. - ---- - -## Adding a Functional / Integration Test - -1. Create `tests/functional_tests/test_cases///`. -2. Write `model_config.yaml` with `MODEL_ARGS`, `ENV_VARS`, and `TEST_TYPE`. -3. Add a YAML recipe under `tests/test_utils/recipes/h100/` (and `gb200/` if - needed). Required fields: `scope`, `environment`, `platform`, `n_repeat`, - `time_limit`. -4. Push the PR, add the label **"Run functional tests"** to trigger a full run. -5. After a successful run, download golden values: - - ```bash - python tests/test_utils/python_scripts/download_golden_values.py \ - --source github --pipeline-id - ``` - -6. Commit the downloaded golden values. - ---- - -## Triggering Internal CI - -Use `tools/trigger_internal_ci.py` to push the current branch to the internal -GitLab remote and trigger a pipeline — without touching the GitLab UI. -Full setup and usage details: `tools/trigger_internal_ci.md`. - -**Prerequisites** (one-time): - -```bash -# 1. Add the internal GitLab remote -git remote add gitlab git@:ADLR/Megatron-LM.git - -# 2. Create a personal access token with 'api' scope on your GitLab profile, -# then store it: -export GITLAB_TOKEN=glpat- -``` - -**Usage:** - -```bash -python tools/trigger_internal_ci.py \ - --gitlab-origin gitlab \ - [--functional-test-scope mr] \ - [--functional-test-repeat 5] \ - [--functional-test-cases all] \ - [--dry-run] -``` - -The script force-pushes the current branch as `pull-request/` and -prints the resulting pipeline URL. - ---- - -## CI Pipeline - -The main workflow is `.github/workflows/cicd-main.yml`. It triggers on pushes -to branches matching `pull-request/[0-9]+` and `deploy-release/*`, on merge -groups, on a daily schedule, and on manual dispatch. - -### Pipeline Structure - -```text -is-not-external-contributor - └─ pre-flight - └─ configure # determines scope, container tag, n_repeat - ├─ linting - ├─ cicd-container-build - │ ├─ cicd-parse-unit-tests → cicd-unit-tests-latest - │ ├─ cicd-parse-integration-tests-h100 → cicd-integration-tests-latest-h100 - │ └─ cicd-parse-integration-tests-gb200 → cicd-integration-tests-latest-gb200 (maintainers only) - └─ Nemo_CICD_Test # final pass/fail gate -``` - -Images are pushed to: - -- AWS ECR: `766267172432.dkr.ecr.us-east-1.amazonaws.com/…` -- GCP Artifact Registry: `us-east4-docker.pkg.dev/nv-projdgxchipp-20260113193621/megatron-lm/…` - ---- - -## CI Failure Investigation - -CI branches always follow the pattern `pull-request/`. - -### Locating the PR from a CI Branch - -```bash -# Extract PR number from the current branch -PR_NUMBER=$(git rev-parse --abbrev-ref HEAD | grep -oP '(?<=pull-request/)\d+') - -# Fetch the PR metadata (title, labels, author, base branch) -gh pr view "$PR_NUMBER" --repo NVIDIA/Megatron-LM - -# Show the changeset for that PR -gh pr diff "$PR_NUMBER" --repo NVIDIA/Megatron-LM - -# List the files changed in the PR -gh pr view "$PR_NUMBER" --repo NVIDIA/Megatron-LM --json files --jq '.files[].path' -``` - -If the branch name contains a non-numeric suffix (e.g. 
`pull-request/my-branch`), -search by branch name instead: - -```bash -gh pr list --repo NVIDIA/Megatron-LM --head "pull-request/my-branch" -``` - -### Reading CI Job Logs - -```bash -# List recent workflow runs for the PR -gh run list --repo NVIDIA/Megatron-LM --branch "pull-request/$PR_NUMBER" - -# Show summary of a specific run -gh run view --repo NVIDIA/Megatron-LM - -# Stream the GitHub Actions runner output (stdout of ranks 0 and 3 only) -gh run view --repo NVIDIA/Megatron-LM --log-failed -``` - -Full per-rank logs are **not** in the runner stdout. They are uploaded as -GitHub artifacts named `logs---`. - -If the runner output does not show a clear error, download the full artifact -and crawl all rank logs: - -```bash -# 1. Find the artifact name for the failing run -gh run view --repo NVIDIA/Megatron-LM --json artifacts \ - --jq '.artifacts[].name' - -# 2. Download the artifact zip -gh run download --repo NVIDIA/Megatron-LM \ - --name "logs-" -D ./ci-logs - -# 3. Locate which rank logs contain errors (file list only, no content yet) -grep -r -l "ERROR\|Traceback\|FAILED\|fatal" ./ci-logs/ - -# 4. Log files can exceed 10 000 lines — never read a full log at once. -# Check size first, then read in chunks of ~200 lines: -wc -l ./ci-logs///attempt_0//stderr.log -sed -n '1,200p' ./ci-logs/.../stderr.log # chunk 1 -sed -n '201,400p' ./ci-logs/.../stderr.log # chunk 2 -# … continue until the traceback / error is found, then stop. -``` - -Inside the artifact the log tree mirrors the container's `assets_dir`: - -```text -ci-logs/ -└── / - └── / - └── attempt_0/ - └── / - ├── stdout.log - └── stderr.log -``` - -### Identifying Failure Root Cause - -1. **Linting failure** — re-run `tools/autoformat.sh` locally; the diff shows - exactly what needs to change. -2. **Container build failure** — inspect the `cicd-container-build` job log. - Common causes: new dependency with conflicting pins in `uv.lock`, or a broken - git-sourced package revision. -3. **Unit test failure** — the failing bucket is identified in the - `cicd-unit-tests-latest` job matrix. Ranks 0 and 3 appear in runner stdout; - for failures on other ranks download the artifact and check per-rank logs. - Re-run the specific bucket locally inside the container: - - ```bash - bash tests/unit_tests/run_ci_test.sh \ - --tag latest --environment dev \ - --bucket "" \ - --log-dir ./logs - ``` - -4. **Functional test failure** — look at the `cicd-integration-tests-*` job. - Failures in lightweight mode indicate a crash; failures with golden-value - mismatch indicate a numerical regression. Only rank 0 runs pytest validation, - so start with `stdout.log` for rank 0 in the artifact. -5. **Flaky test** — the runner retries automatically up to 3 times for known - transient patterns. If the job exhausted all retries and the failure matches - one of those patterns it is infrastructure noise, not a code regression. - Mark genuinely non-deterministic tests with `@pytest.mark.flaky` and open a - follow-up issue. - -### Correlating a Failure with the PR Changeset - -```bash -# Find unit tests that cover a changed source file -grep -r "from megatron.core.transformer.attention" tests/unit_tests/ -l - -# Check CODEOWNERS for reviewer assignment -cat .github/CODEOWNERS | grep "" -``` - -Use this mapping to determine whether the failure is directly caused by the -PR's changes or is a pre-existing issue on `main`. 
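For PRs that touch many files, the same mapping can be scripted. Below is a minimal sketch, not an existing repo tool: it assumes `PR_NUMBER` is already set (see "Locating the PR from a CI Branch" above), derives a dotted module path from each changed `megatron/**/*.py` file, and greps the unit-test tree for references to it. The path-to-module conversion and the `megatron/*.py` filter are assumptions for illustration.

```bash
# Hypothetical helper (not part of the repository): list unit tests that
# reference each Python source file changed in the PR.
for f in $(gh pr view "$PR_NUMBER" --repo NVIDIA/Megatron-LM --json files --jq '.files[].path'); do
  case "$f" in
    megatron/*.py)
      # megatron/core/transformer/attention.py -> megatron.core.transformer.attention
      module=$(echo "$f" | sed -e 's/\.py$//' -e 's#/#.#g')
      echo "== $f"
      grep -rl "$module" tests/unit_tests/ || echo "   (no unit tests reference $module)"
      ;;
  esac
done
```

A changed file whose module appears in no unit test is a hint that the failure, if related to the PR at all, will surface in the functional suite rather than in the unit-test buckets.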
- ---- - -## Common Pitfalls - -| Problem | Cause | Fix | -|---------|-------|-----| -| Port collision on multi-GPU runs | torchrun binding conflicts | Use `torch.distributed.run` via the container entry point | -| Test passes locally but fails in CI | Different environment or data path | Check `DATA_PATH`, `DATA_CACHE_PATH`, and the `environment` tag (`dev` vs `lts`) | -| Golden value mismatch after a code change | Numerical regression | Download new golden values via `download_golden_values.py` after a clean run | -| `cicd-integration-tests-gb200` not triggered | GB200 jobs require maintainer status | Ask a maintainer to trigger, or add the `Run functional tests` label | diff --git a/tests/functional_tests/python_test_utils/common.py b/tests/functional_tests/python_test_utils/common.py index 21a003d39b1..94105942651 100644 --- a/tests/functional_tests/python_test_utils/common.py +++ b/tests/functional_tests/python_test_utils/common.py @@ -216,7 +216,9 @@ def pipeline( ] total_steps_evaluated = 1 else: - total_steps_evaluated = golden_value.end_step / golden_value.step_interval + 1 + total_steps_evaluated = ( + golden_value.end_step - golden_value.start_step + ) / golden_value.step_interval + 1 actual_value_list = [np.inf if type(v) is str else v for v in actual_value_list] golden_value_list = [np.inf if type(v) is str else v for v in golden_value_list] diff --git a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py index 165e17b102b..a212ed417d6 100644 --- a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py +++ b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py @@ -6,6 +6,7 @@ import os from statistics import median +import numpy as np import pytest import yaml @@ -189,10 +190,16 @@ def test_inference_pipeline( if "routing_indices" in groundtruth_results and "routing_indices" in metrics: at_least_one_test_loop = True - routing_indices_groundtruth = groundtruth_results["routing_indices"] - routing_indices_current = current_results["routing_indices"] - assert ( - routing_indices_groundtruth == routing_indices_current + token_indices = groundtruth_results.get("routing_indices_token_indices") + current_routing = np.array(current_results["routing_indices"]) + assert token_indices is not None + current_routing = current_routing[token_indices] + routing_indices_groundtruth = np.sort( + np.array(groundtruth_results["routing_indices"]), axis=-1 + ) + routing_indices_current = np.sort(current_routing, axis=-1) + assert np.array_equal( + routing_indices_groundtruth, routing_indices_current ), f"Routing indices mismatch:\nGround truth: {routing_indices_groundtruth}\nCurrent: {routing_indices_current}" if not at_least_one_test_loop: diff --git a/tests/functional_tests/shell_test_utils/_run_training.sh b/tests/functional_tests/shell_test_utils/_run_training.sh index cd22e9b1104..d3f166c88eb 100644 --- a/tests/functional_tests/shell_test_utils/_run_training.sh +++ b/tests/functional_tests/shell_test_utils/_run_training.sh @@ -187,10 +187,14 @@ DISTRIBUTED_ARGS=( --redirects "3" ) +FT_LAUNCHER_ARGS=( + --max-restarts=3 +) + # Start training if [[ "$IS_NEMO_TEST" == "true" ]]; then if [[ "$LAUNCHER" == "ft_launcher" ]]; then - ft_launcher ${DISTRIBUTED_ARGS[@]} \ + ft_launcher ${DISTRIBUTED_ARGS[@]} ${FT_LAUNCHER_ARGS[@]} \ --no-python /opt/venv/bin/$TRAINING_SCRIPT_PATH "${PARAMS[@]}" && EXIT_CODE=0 || EXIT_CODE=$? 
else uv run --no-sync python -m torch.distributed.run ${DISTRIBUTED_ARGS[@]} \ @@ -198,7 +202,7 @@ if [[ "$IS_NEMO_TEST" == "true" ]]; then fi else if [[ "$LAUNCHER" == "ft_launcher" ]]; then - ft_launcher ${DISTRIBUTED_ARGS[@]} \ + ft_launcher ${DISTRIBUTED_ARGS[@]} ${FT_LAUNCHER_ARGS[@]} \ $TRAINING_SCRIPT_PATH "${PARAMS[@]}" && EXIT_CODE=0 || EXIT_CODE=$? else uv run --no-sync python -m torch.distributed.run ${DISTRIBUTED_ARGS[@]} \ diff --git a/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py b/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py index 543ddd36a6d..5e876dcf23d 100644 --- a/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py +++ b/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py @@ -579,7 +579,7 @@ def get_model_argv(self): "--tokenizer-type", "NullTokenizer", "--vocab-size", - "127", # ... NullTokenizer adds +1 EOD token. + "128", "--make-vocab-size-divisible-by", "1", ] diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml index 2baa92999e6..ad3d2c46d4c 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml @@ -22,7 +22,6 @@ MODEL_ARGS: --sequence-parallel: true --disable-bias-linear: true --micro-batch-size: 4 - --global-batch-size: 1152 --step-batch-size-schedule: "0:384 200B:768 400B:1152" --train-samples: 19531250 --manual-gc: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml index d9b44500eca..e56dd5b7d13 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml @@ -24,7 +24,6 @@ MODEL_ARGS: --sequence-parallel: true --disable-bias-linear: true --micro-batch-size: 4 - --global-batch-size: 1152 --step-batch-size-schedule: "0:384 200B:768 400B:1152" --train-samples: 19531250 --manual-gc: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml index 5c70d072482..5bfeca07694 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml @@ -22,7 +22,6 @@ MODEL_ARGS: --sequence-parallel: true --disable-bias-linear: true --micro-batch-size: 4 - --global-batch-size: 1152 --step-batch-size-schedule: "0:384 200B:768 400B:1152" --train-samples: 4882812 --manual-gc: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml index 5f51a40ecc9..577258ab3f5 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml @@ -24,7 +24,6 @@ MODEL_ARGS: --sequence-parallel: true --disable-bias-linear: true --micro-batch-size: 4 - --global-batch-size: 1152 --step-batch-size-schedule: "0:384 200B:768 400B:1152" --train-samples: 19531250 --manual-gc: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml 
b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml index e91d1105e83..eb253b243f1 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml @@ -28,7 +28,7 @@ MODEL_ARGS: --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} --tokenizer-type: NullTokenizer - --vocab-size: 131072 + --vocab-size: 131073 --mock-data: true --split: 949,50,1 --distributed-backend: nccl diff --git a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/golden_values_dev_dgx_gb200.json b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/golden_values_dev_dgx_gb200.json index 1be5704733f..b1c389f8532 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/golden_values_dev_dgx_gb200.json +++ b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/golden_values_dev_dgx_gb200.json @@ -6,29 +6,29 @@ "values": { "1": 12.61164, "2": 12.60596, - "3": 12.60284, - "4": 12.59692, - "5": 12.59561, - "6": 12.59765, - "7": 12.58049, - "8": 12.53856, - "9": 12.51228, - "10": 12.49859, - "11": 12.3236, - "12": 12.29422, - "13": 12.23138, - "14": 12.22825, - "15": 11.82222, - "16": 11.80417, - "17": 11.76129, - "18": 11.7371, - "19": 11.61308, - "20": 11.50145, - "21": 11.26476, - "22": 11.37641, - "23": 11.28395, + "3": 12.60283, + "4": 12.59697, + "5": 12.59555, + "6": 12.59771, + "7": 12.58043, + "8": 12.53852, + "9": 12.51223, + "10": 12.49867, + "11": 12.32368, + "12": 12.29429, + "13": 12.23146, + "14": 12.22821, + "15": 11.82219, + "16": 11.80416, + "17": 11.76127, + "18": 11.73709, + "19": 11.61313, + "20": 11.50155, + "21": 11.26472, + "22": 11.37638, + "23": 11.28391, "24": 11.15655, - "25": 10.99869 + "25": 10.99872 } }, "num-zeros": { @@ -36,31 +36,31 @@ "end_step": 25, "step_interval": 1, "values": { - "1": 523050272.0, - "2": 523678944.0, - "3": 522945696.0, - "4": 523243072.0, - "5": 523021376.0, - "6": 523374272.0, - "7": 523438432.0, - "8": 523085472.0, - "9": 523469312.0, - "10": 523196096.0, - "11": 524297536.0, - "12": 523455968.0, - "13": 523498528.0, - "14": 524478016.0, - "15": 523634528.0, - "16": 523464768.0, - "17": 523079328.0, - "18": 523361472.0, - "19": 523209888.0, - "20": 523229312.0, - "21": 524937184.0, - "22": 523658208.0, - "23": 523417760.0, - "24": 523486464.0, - "25": 525637376.0 + "1": 523049312.0, + "2": 523676640.0, + "3": 522947296.0, + "4": 523241568.0, + "5": 523021536.0, + "6": 523375648.0, + "7": 523434944.0, + "8": 523086432.0, + "9": 523468448.0, + "10": 523196352.0, + "11": 524296800.0, + "12": 523454400.0, + "13": 523497696.0, + "14": 524480800.0, + "15": 523636416.0, + "16": 523466048.0, + "17": 523080416.0, + "18": 523359776.0, + "19": 523209440.0, + "20": 523228640.0, + "21": 524937728.0, + "22": 523660096.0, + "23": 523415296.0, + "24": 523486336.0, + "25": 525638688.0 } }, "mem-allocated-bytes": { @@ -133,29 +133,29 @@ "step_interval": 1, "values": { "1": "nan", - "2": 5.08847, + "2": 5.59096, "3": "nan", - "4": 0.8967, + "4": 0.89026, "5": "nan", - "6": 0.88539, + "6": 0.88523, "7": "nan", - "8": 0.90974, + "8": 0.88482, "9": "nan", - "10": 0.89722, + "10": 0.88377, "11": "nan", - "12": 0.88709, + "12": 0.90678, "13": "nan", - "14": 0.88279, + "14": 0.96674, "15": "nan", - "16": 0.88529, + "16": 0.88644, "17": "nan", - "18": 0.8909, + "18": 0.88775, "19": "nan", - "20": 0.88451, + "20": 0.88634, "21": "nan", - "22": 0.88706, + "22": 
0.88696, "23": "nan", - "24": 0.88544, + "24": 0.88377, "25": "nan" } } diff --git a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml index 96bc016ca97..ae067d246c9 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml @@ -28,7 +28,7 @@ MODEL_ARGS: --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} --tokenizer-type: NullTokenizer - --vocab-size: 131072 + --vocab-size: 131073 --mock-data: true --split: 949,50,1 --distributed-backend: nccl diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_cp2/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_cp2/model_config.yaml index 56df7d07a0d..fc6d56fab55 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_cp2/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_cp2/model_config.yaml @@ -20,8 +20,8 @@ MODEL_ARGS: --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} --data-path: ${DATA_PATH}/text/common_pile/v01_filtered_data/my-gpt3_00_text_document - --vocab-file: ${DATA_PATH}/text/common_pile/v01_filtered_data/bpe/vocab.json - --merge-file: ${DATA_PATH}/text/common_pile/v01_filtered_data/bpe/merges.txt + --tokenizer-type: NullTokenizer + --vocab-size: 50257 --split: 949,50,1 --distributed-backend: nccl --lr: 0.00015 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist/model_config.yaml index 9436fa2a5e6..14e40a430d5 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist/model_config.yaml @@ -51,4 +51,5 @@ MODEL_ARGS: --log-memory-to-tensorboard: true --async-save: true --use-persistent-ckpt-worker: true + --verify-integrity: true TEST_TYPE: ckpt-resume diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic/model_config.yaml index 2fe3d590c1b..29a9bbef0c1 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic/model_config.yaml @@ -53,4 +53,5 @@ MODEL_ARGS: --log-memory-to-tensorboard: true --async-save: true --use-persistent-ckpt-worker: true + --verify-integrity: true TEST_TYPE: ckpt-resume diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_reshard_1x4xNone/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_reshard_1x4xNone/model_config.yaml index a413b7ebb0c..ec29ea58ca6 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_reshard_1x4xNone/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp2_resume_torch_dist_reshard_1x4xNone/model_config.yaml @@ -49,4 +49,5 @@ MODEL_ARGS: --log-memory-to-tensorboard: true --async-save: true --use-persistent-ckpt-worker: true + --verify-integrity: true TEST_TYPE: 
ckpt-resume diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_tp1_pp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_tp1_pp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather/model_config.yaml index 823312a21df..6588160cc67 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_tp1_pp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_tp1_pp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather/model_config.yaml @@ -55,4 +55,5 @@ MODEL_ARGS: --log-memory-to-tensorboard: true --async-save: true --use-persistent-ckpt-worker: true + --verify-integrity: true TEST_TYPE: ckpt-resume diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..95d578e2614 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/golden_values_dev_dgx_h100.json @@ -0,0 +1,226 @@ +{ + "0": { + "input_prompt": "Artificial intelligence has transformed numerous industries over the past decade. From healthcare to finance, manufacturing to education, AI systems are now capable of performing tasks that once required significant human expertise. Machine learning models can diagnose diseases from medical images, detect fraud in financial transactions, optimize supply chains, and personalize educational content for individual students. Large language models in particular have demonstrated remarkable capabilities in understanding and generating human language, enabling applications such as code generation, document summarization, question answering, and creative writing assistance. As these systems continue to improve, researchers and practitioners are working to address challenges around reliability, fairness, and interpretability. The next generation of AI systems will likely be even more capable, but also require careful consideration of their societal implications. In summary, the field of artificial intelligence is", + "generated_text": " advancing rapidly, with new breakthroughs emerging regularly. 
While AI systems have already demonstrated impressive capabilities, there is still much work to be done to ensure these", + "generated_tokens": [ + 56332, + 18747, + 1044, + 1454, + 1875, + 70639, + 1115, + 30455, + 27632, + 1046, + 9076, + 26554, + 6467, + 1736, + 5314, + 10597, + 42703, + 28946, + 1044, + 2156, + 1395, + 3637, + 3315, + 2196, + 1317, + 1402, + 5595, + 1317, + 11811, + 2576 + ], + "latency": 0.5657311892136931, + "logprobs": [ + -1.9631825685501099, + -1.4970390796661377, + -2.6500136852264404, + -3.3695499897003174, + -2.849766731262207, + -0.6258476972579956, + -4.9532599449157715, + -0.11778059601783752, + -0.30011042952537537, + -0.7112784385681152, + -1.7553068399429321, + -2.0854082107543945, + -1.093907117843628, + -0.20775330066680908, + -0.3064386248588562, + -0.11147496849298477, + -6.308845520019531, + -0.08563490211963654, + -2.5601937770843506, + -0.012138190679252148, + -0.16679446399211884, + -3.9401566982269287, + -0.6840294003486633, + -0.9575856924057007, + -2.4038002490997314, + -0.0004950728034600616, + -0.5693215131759644, + -0.5039942264556885, + -0.1864156872034073, + -0.8986333012580872, + -0.028368404135107994, + -0.03446716070175171, + -0.04009075462818146, + -2.1998469829559326, + -0.23252013325691223, + -4.589982032775879, + -0.008946098387241364, + -1.7037732601165771, + -1.0597485303878784, + -4.041484355926514, + -0.36329641938209534, + -2.1041674613952637, + -0.0678618922829628, + -0.22670072317123413, + -0.5720489025115967, + -2.714329242706299, + -0.5115664601325989, + -0.1820123940706253, + -0.4119052290916443, + -0.012396075762808323, + -0.020758748054504395, + -0.6604316234588623, + -0.43596622347831726, + -0.1802099645137787, + -0.18521766364574432, + -0.05841841548681259, + -1.4234997034072876, + -0.004653695039451122, + -1.0471034049987793, + -0.3272952437400818, + -0.7283356189727783, + -0.6264514923095703, + -0.6953774690628052, + -0.6589646339416504, + -6.88499116897583, + -0.06728250533342361, + -0.002077446784824133, + -7.879212379455566, + -0.5264835357666016, + -0.10031754523515701, + -0.577612042427063, + -0.3896879255771637, + -0.39892667531967163, + -0.08144717663526535, + -1.0562388896942139, + -0.07140254229307175, + -0.0015612567076459527, + -0.010199331678450108, + -1.2312755584716797, + -0.15363134443759918, + -1.2347755432128906, + -1.346854329109192, + -2.38948392868042, + -3.504691630951129e-05, + -4.086829662322998, + -0.14582525193691254, + -0.038434501737356186, + -4.89980411529541, + -0.6460154056549072, + -0.0009527434594929218, + -0.0008908117306418717, + -7.490355014801025, + -0.11345300078392029, + -0.008508029393851757, + -0.015247219242155552, + -0.394756555557251, + -0.09369947016239166, + -1.036057949066162, + -1.1025354862213135, + -2.6586172580718994, + -0.43926483392715454, + -0.5900647640228271, + -2.593330144882202, + -0.002643188228830695, + -1.5542654991149902, + -0.4676969051361084, + -4.994030475616455, + -1.3037813901901245, + -2.0418052673339844, + -0.4943341016769409, + -3.5061726570129395, + -0.08762064576148987, + -2.472935199737549, + -2.0049428939819336, + -2.8443005084991455, + -1.2312332391738892, + -0.00997180212289095, + -1.4323421716690063, + -0.0009956170106306672, + -0.9328322410583496, + -2.180025100708008, + -4.3987260141875595e-05, + -1.3120726346969604, + -1.139200210571289, + -3.656489610671997, + -0.8193762302398682, + -0.00391182117164135, + -0.034046005457639694, + -0.8485873341560364, + -1.7457873821258545, + -0.2671572268009186, + -1.998857021331787, + 
-1.886608362197876, + -0.00030858523678034544, + -0.604221522808075, + -0.5272942781448364, + -1.4377481937408447, + -1.3155831098556519, + -1.7748348712921143, + -1.4071087837219238, + -0.33158257603645325, + -0.012675435282289982, + -0.7177118062973022, + -0.8072052001953125, + -2.185786485671997, + -1.7145639657974243, + -3.546031951904297, + -6.747668743133545, + -0.05591462552547455, + -2.4653594493865967, + -2.1707944869995117, + -0.3219299912452698, + -0.610948383808136, + -0.0008507922757416964, + -1.2661765813827515, + -0.9660887718200684, + -0.21183912456035614, + -0.2528935670852661, + -0.42069458961486816, + -1.207878589630127, + -0.2546793222427368, + -0.0005551227368414402, + -1.2553633451461792, + -0.7063015699386597, + -0.8560804128646851, + -1.5543453693389893, + -1.3915338516235352, + -1.1071290969848633, + -1.2220503091812134, + -0.7909353971481323, + -1.240220308303833, + -0.3716951906681061, + -0.25508493185043335, + -0.6280136108398438, + -1.11504328250885, + -0.863906741142273, + -0.01451974455267191, + -0.32788148522377014, + -0.36335280537605286, + -0.0031544233206659555, + -0.011296008713543415, + -0.00030357998912222683, + -0.6195628643035889, + -0.6293846368789673, + -0.43223705887794495 + ] + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml new file mode 100644 index 00000000000..1ad4cc3d576 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml @@ -0,0 +1,67 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 + TRITON_CACHE_AUTOTUNING: 0 + MAMBA_DETERMINISTIC: 1 +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --log-memory-to-tensorboard: true + --timing-log-level: 0 + --load: ${CHECKPOINT_LOAD_PATH}/model/nemotron6/3b_hybrid_moe/checkpoints/phase2_lc_reinit_emb/ + --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/nemotron6/tokenizers/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --distributed-backend: nccl + --log-interval: 1 + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --expert-model-parallel-size: 8 + --expert-tensor-parallel-size: 1 + --use-mcore-models: true + --model-provider: mamba + --use-checkpoint-args: true + --no-use-tokenizer-model-from-checkpoint-args: true + --dist-ckpt-strictness: log_unexpected + --ckpt-format: torch_dist + --ckpt-fully-parallel-save: true + --ckpt-fully-parallel-load: true + --ckpt-assume-constant-structure: true + --no-load-optim: true + --moe-router-score-function: sigmoid + --moe-router-enable-expert-bias: true + --moe-router-topk-scaling-factor: 2.5 + --moe-router-dtype: fp32 + --bf16: true + --attention-backend: flash + --no-create-attention-mask-in-dataloader: true + --num-workers: 8 + --micro-batch-size: 1 + --deterministic-mode: true + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --max-tokens-to-oom: 3600000 + --inference-dynamic-batching-max-tokens: 64 # kept small to trigger chunked prefill. 
+ --inference-dynamic-batching-max-requests: 32 + --inference-max-seq-length: 4096 + --enable-chunked-prefill: true + --cuda-graph-scope: full_iteration_inference + --cuda-graph-impl: local + --inference-dynamic-batching-num-cuda-graphs: -1 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: 'Artificial intelligence has transformed numerous industries over the past decade. From healthcare to finance, manufacturing to education, AI systems are now capable of performing tasks that once required significant human expertise. Machine learning models can diagnose diseases from medical images, detect fraud in financial transactions, optimize supply chains, and personalize educational content for individual students. Large language models in particular have demonstrated remarkable capabilities in understanding and generating human language, enabling applications such as code generation, document summarization, question answering, and creative writing assistance. As these systems continue to improve, researchers and practitioners are working to address challenges around reliability, fairness, and interpretability. The next generation of AI systems will likely be even more capable, but also require careful consideration of their societal implications. In summary, the field of artificial intelligence is' + --incoming-requests-per-step: 32 + --no-record-throughput: true + --mamba-inference-conv-states-dtype: fp32 + --mamba-inference-ssm-states-dtype: fp32 + --transformer-impl: inference_optimized + --inference-moe-token-dispatcher-type: nvls + --inference-logging-step-interval: 1 +METRICS: diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml index f5de6eaac72..4b258afe0d6 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml @@ -24,7 +24,7 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --model-provider: mamba + --model-provider: hybrid --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true @@ -35,7 +35,7 @@ MODEL_ARGS: --num-attention-heads: 16 --kv-channels: 128 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- - --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec --normalization: RMSNorm --swiglu: true --attention-dropout: 0.0 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml index b10698d521f..bd86d2faa44 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml @@ -24,7 +24,7 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --model-provider: mamba + --model-provider: hybrid --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true @@ -35,7 +35,7 @@ MODEL_ARGS: 
--num-attention-heads: 16 --kv-channels: 128 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- - --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec --normalization: RMSNorm --swiglu: true --attention-dropout: 0.0 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..956754f44c1 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/golden_values_dev_dgx_h100.json @@ -0,0 +1,286 @@ +{ + "0": { + "input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.", + "generated_text": " You are not alone. You are not alone. You are not alone. You are not alone. You are not alone. You are not alone.", + "generated_tokens": [ + 3213, + 1584, + 1605, + 9412, + 1046, + 3213, + 1584, + 1605, + 9412, + 1046, + 3213, + 1584, + 1605, + 9412, + 1046, + 3213, + 1584, + 1605, + 9412, + 1046, + 3213, + 1584, + 1605, + 9412, + 1046, + 3213, + 1584, + 1605, + 9412, + 1046 + ], + "latency": 1.878054141998291, + "ttft": 0.07786321640014648, + "cuda_graph_request_count_map": null, + "step_count": 30, + "top_n_logprobs": null, + "prompt_top_n_logprobs": null, + "prompt_logprobs": [ + -9.498085021972656, + -3.787536859512329, + -3.0404648780822754, + -1.7445809841156006, + -0.29672086238861084, + -1.3661342859268188, + -2.3458175659179688, + -1.83931303024292, + -1.4894113540649414, + -6.440437316894531, + -0.8176816701889038, + -1.790361762046814, + -3.662419557571411, + -3.7036173343658447, + -1.6009434461593628, + -1.5501081943511963, + -2.846059799194336, + -6.732302665710449, + -0.06605878472328186, + -1.334327220916748, + -6.083745002746582, + -9.440131187438965, + -10.473882675170898, + -1.5964821577072144, + -4.702763557434082, + -0.7514524459838867, + -2.1461901664733887, + -0.012340382672846317, + -0.03605639934539795, + -3.0907557010650635, + -8.744739532470703, + -1.5410845279693604, + -5.84979772567749, + -3.0918972492218018, + -3.9814329147338867, + -3.78017520904541, + -2.5227086544036865, + -2.258594036102295, + -0.4719255566596985, + -1.0329649448394775, + -5.3284382820129395, + -8.25335693359375, + -0.015789249911904335, + -2.854100227355957, + -1.2236379384994507, + -3.905193328857422, + -0.9268187284469604, + -0.0030202509369701147, + -3.224249839782715, + -11.11172103881836, + -3.8121743202209473, + -2.3400487899780273, + -4.672845363616943, + -0.09729652851819992, + -0.06232408434152603, + -1.336004614830017, + -2.054157257080078, + -4.390933036804199, + -0.44248226284980774, + -3.9417736530303955, + -0.5888474583625793, + -0.26697415113449097, + -2.9271092414855957, + -13.515066146850586, + -0.10294333100318909, + -3.5007452964782715, + -0.8535972237586975, + -5.173652648925781, + -0.330394983291626, + -2.304553508758545, + -0.5418462753295898, + -1.300589919090271, + -4.9136152267456055, + -15.558022499084473, + -4.918652534484863, + -0.22206512093544006, + 
-6.589188575744629, + -0.9015690684318542, + -2.2228457927703857, + -1.8689247369766235, + -0.2006368339061737, + -5.918689727783203, + -0.006355076562613249, + -7.532094955444336, + -3.2708187103271484, + -3.743263006210327, + -2.011824131011963 + ], + "generated_logprobs": [ + -3.0331737995147705, + -1.9080564975738525, + -2.52506947517395, + -2.325258493423462, + -1.180279016494751, + -1.1824196577072144, + -0.39788734912872314, + -1.110222578048706, + -1.5034958124160767, + -0.9765141606330872, + -0.9300433397293091, + -0.15196305513381958, + -0.2200203537940979, + -0.06051275506615639, + -0.6840062737464905, + -1.0964292287826538, + -0.17654964327812195, + -0.18547140061855316, + -0.06710249185562134, + -0.4758152365684509, + -0.6657928228378296, + -0.10342729091644287, + -0.10059614479541779, + -0.046978313475847244, + -0.410809725522995, + -0.428723007440567, + -0.06053968518972397, + -0.06518109142780304, + -0.030038274824619293, + -0.3271780014038086 + ], + "logprobs": [ + -9.498085021972656, + -3.787536859512329, + -3.0404648780822754, + -1.7445809841156006, + -0.29672086238861084, + -1.3661342859268188, + -2.3458175659179688, + -1.83931303024292, + -1.4894113540649414, + -6.440437316894531, + -0.8176816701889038, + -1.790361762046814, + -3.662419557571411, + -3.7036173343658447, + -1.6009434461593628, + -1.5501081943511963, + -2.846059799194336, + -6.732302665710449, + -0.06605878472328186, + -1.334327220916748, + -6.083745002746582, + -9.440131187438965, + -10.473882675170898, + -1.5964821577072144, + -4.702763557434082, + -0.7514524459838867, + -2.1461901664733887, + -0.012340382672846317, + -0.03605639934539795, + -3.0907557010650635, + -8.744739532470703, + -1.5410845279693604, + -5.84979772567749, + -3.0918972492218018, + -3.9814329147338867, + -3.78017520904541, + -2.5227086544036865, + -2.258594036102295, + -0.4719255566596985, + -1.0329649448394775, + -5.3284382820129395, + -8.25335693359375, + -0.015789249911904335, + -2.854100227355957, + -1.2236379384994507, + -3.905193328857422, + -0.9268187284469604, + -0.0030202509369701147, + -3.224249839782715, + -11.11172103881836, + -3.8121743202209473, + -2.3400487899780273, + -4.672845363616943, + -0.09729652851819992, + -0.06232408434152603, + -1.336004614830017, + -2.054157257080078, + -4.390933036804199, + -0.44248226284980774, + -3.9417736530303955, + -0.5888474583625793, + -0.26697415113449097, + -2.9271092414855957, + -13.515066146850586, + -0.10294333100318909, + -3.5007452964782715, + -0.8535972237586975, + -5.173652648925781, + -0.330394983291626, + -2.304553508758545, + -0.5418462753295898, + -1.300589919090271, + -4.9136152267456055, + -15.558022499084473, + -4.918652534484863, + -0.22206512093544006, + -6.589188575744629, + -0.9015690684318542, + -2.2228457927703857, + -1.8689247369766235, + -0.2006368339061737, + -5.918689727783203, + -0.006355076562613249, + -7.532094955444336, + -3.2708187103271484, + -3.743263006210327, + -2.011824131011963, + -3.0331737995147705, + -1.9080564975738525, + -2.52506947517395, + -2.325258493423462, + -1.180279016494751, + -1.1824196577072144, + -0.39788734912872314, + -1.110222578048706, + -1.5034958124160767, + -0.9765141606330872, + -0.9300433397293091, + -0.15196305513381958, + -0.2200203537940979, + -0.06051275506615639, + -0.6840062737464905, + -1.0964292287826538, + -0.17654964327812195, + -0.18547140061855316, + -0.06710249185562134, + -0.4758152365684509, + -0.6657928228378296, + -0.10342729091644287, + -0.10059614479541779, + -0.046978313475847244, + -0.410809725522995, 
+ -0.428723007440567, + -0.06053968518972397, + -0.06518109142780304, + -0.030038274824619293, + -0.3271780014038086 + ] + }, + "mem-max-allocated-bytes": 53350692864, + "lifetime_prefill_token_count": 88 +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/model_config.yaml new file mode 100644 index 00000000000..e989be22f7e --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer/model_config.yaml @@ -0,0 +1,76 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 + TRITON_CACHE_AUTOTUNING: 0 + MAMBA_DETERMINISTIC: 1 +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --log-memory-to-tensorboard: true + --timing-log-level: 0 + --load: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/checkpoint + --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --distributed-backend: nccl + --log-interval: 1 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --expert-model-parallel-size: 1 + --use-mcore-models: true + --model-provider: hybrid + --init-method-std: 0.0198 + --untie-embeddings-and-output-weights: true + --disable-bias-linear: true + --init-method-std: 0.014 + --position-embedding-type: none + --hidden-size: 2048 + --ffn-hidden-size: 11264 + --num-attention-heads: 16 + --kv-channels: 128 + --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- + --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec + --normalization: RMSNorm + --swiglu: true + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + --seq-length: 4096 + --max-position-embeddings: 4096 + --micro-batch-size: 1 + --ckpt-format: torch_dist + --ckpt-fully-parallel-save: true + --ckpt-fully-parallel-load: true + --ckpt-assume-constant-structure: true + --dist-ckpt-strictness: log_unexpected + --bf16: true + --attention-backend: flash + --no-create-attention-mask-in-dataloader: true + --num-workers: 8 + --use-checkpoint-args: true + --no-use-tokenizer-model-from-checkpoint-args: true + --no-load-optim: true + --deterministic-mode: true + --save-interval: 2000 + --temperature: 1.0 + --top_k: 1 + --inference-dynamic-batching-sampling-backend: flashinfer + --return-log-probs: true + --num-tokens-to-generate: 30 + --max-tokens-to-oom: 3600000 + --inference-max-seq-length: 4096 + --output-path: ${INFERENCE_OUTPUT_PATH} + --prompts: "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies." 
+ --incoming-requests-per-step: 32 + --inference-repeat-n: 3 + --no-record-throughput: true + --mamba-inference-conv-states-dtype: fp32 + --mamba-inference-ssm-states-dtype: fp32 +METRICS: + - "generated_tokens" + - "logprobs" diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..3550f734923 --- /dev/null +++ b/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/golden_values_dev_dgx_h100.json @@ -0,0 +1,137 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 20, + "step_interval": 1, + "values": { + "1": 10.96772, + "2": 10.94878, + "3": 10.13223, + "4": 10.10273, + "5": 9.77116, + "6": 9.74409, + "7": 9.67572, + "8": 9.54443, + "9": 9.47498, + "10": 9.46493, + "11": 9.0426, + "12": 9.29951, + "13": 9.30226, + "14": 9.27311, + "15": 9.10353, + "16": 8.79531, + "17": 8.51664, + "18": 8.77952, + "19": 8.71592, + "20": 8.71692 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 20, + "step_interval": 1, + "values": { + "1": 47889940.0, + "2": 49012016.0, + "3": 48644404.0, + "4": 48106204.0, + "5": 58246720.0, + "6": 63126404.0, + "7": 77135048.0, + "8": 67132312.0, + "9": 67750168.0, + "10": 64553460.0, + "11": 77738728.0, + "12": 64587012.0, + "13": 64661168.0, + "14": 61613476.0, + "15": 61997576.0, + "16": 71691368.0, + "17": 69936240.0, + "18": 70742744.0, + "19": 67630016.0, + "20": 67347280.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 20, + "step_interval": 1, + "values": { + "1": 1872031232.0, + "2": 1872031232.0, + "3": 1872031232.0, + "4": 1872031232.0, + "5": 1872031232.0, + "6": 1872031232.0, + "7": 1872031232.0, + "8": 1872031232.0, + "9": 1872031232.0, + "10": 1872031232.0, + "11": 1872031232.0, + "12": 1872031232.0, + "13": 1872031232.0, + "14": 1872031232.0, + "15": 1872031232.0, + "16": 1872031232.0, + "17": 1872031232.0, + "18": 1872031232.0, + "19": 1872031232.0, + "20": 1872031232.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 20, + "step_interval": 1, + "values": { + "1": 2476529664.0, + "2": 2827837440.0, + "3": 2827837440.0, + "4": 2847291904.0, + "5": 2886460416.0, + "6": 2886460416.0, + "7": 2938775552.0, + "8": 2938775552.0, + "9": 2938775552.0, + "10": 2938775552.0, + "11": 2938775552.0, + "12": 2938775552.0, + "13": 2938775552.0, + "14": 2938775552.0, + "15": 2938775552.0, + "16": 2938775552.0, + "17": 2938775552.0, + "18": 2938775552.0, + "19": 2938775552.0, + "20": 2938775552.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 20, + "step_interval": 1, + "values": { + "1": "nan", + "2": 43.50173, + "3": 1.91918, + "4": 0.33128, + "5": 0.32613, + "6": 0.30839, + "7": 0.30715, + "8": 0.29241, + "9": 0.28352, + "10": 0.2695, + "11": 0.27092, + "12": 0.27945, + "13": 0.26134, + "14": 0.28396, + "15": 0.2654, + "16": 0.2759, + "17": 0.26684, + "18": 0.25266, + "19": 0.25121, + "20": 0.25109 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/model_config.yaml new file mode 100644 index 00000000000..493aaa31b16 --- /dev/null +++ 
b/tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G/model_config.yaml @@ -0,0 +1,127 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 + TORCH_COMPILE_DISABLE: 1 + TORCHDYNAMO_DISABLE: 1 + TORCH_INDUCTOR_DISABLE: 1 +MODEL_ARGS: + # ── Hybrid Mamba / Attention / MoE backbone (scaled down from train_flextron.sh) ── + --num-layers: 12 + --hidden-size: 1024 + --ffn-hidden-size: 512 + --num-attention-heads: 32 + --group-query-attention: true + --num-query-groups: 2 + --kv-channels: 128 + --mamba-num-heads: 32 + --mamba-head-dim: 64 + --hybrid-override-pattern: MEM*EMEM*EME* + --is-hybrid-model: true + --position-embedding-type: none + --normalization: RMSNorm + --squared-relu: true + --use-fused-weighted-squared-relu: true + --untie-embeddings-and-output-weights: true + --disable-bias-linear: true + --init-method-std: 0.0173 + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + # ── MoE ─────────────────────────────────────────────────────────────────── + --num-experts: 32 + --moe-router-topk: 2 + --moe-router-score-function: sigmoid + --moe-router-enable-expert-bias: true + --moe-router-topk-scaling-factor: 2.5 + --moe-router-dtype: fp32 + --moe-router-load-balancing-type: none + --moe-aux-loss-coeff: 1.0e-4 + --moe-shared-expert-intermediate-size: 512 + --moe-shared-expert-overlap: true + --moe-token-dispatcher-type: alltoall + --moe-grouped-gemm: true + --moe-permute-fusion: true + --cross-entropy-loss-fusion: true + --cross-entropy-fusion-impl: native + # ── Parallelism (8 GPUs total: TP=2, EP=2, PP=1, CP=1 → DP=4) ───────────── + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --expert-model-parallel-size: 2 + --expert-tensor-parallel-size: 1 + --context-parallel-size: 1 + --sequence-parallel: true + --attention-backend: flash + # ── Data / tokenizer (common_pile + GPT2 BPE, matches hybrid tests) ─────── + --seq-length: 2048 + --max-position-embeddings: 2048 + --micro-batch-size: 1 + --global-batch-size: 8 + --train-iters: 20 + --data-path: ${DATA_PATH}/text/common_pile/v01_filtered_data/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/text/common_pile/v01_filtered_data/bpe/vocab.json + --merge-file: ${DATA_PATH}/text/common_pile/v01_filtered_data/bpe/merges.txt + --split: 949,50,1 + --data-cache-path: ${DATA_CACHE_PATH} + --no-mmap-bin-files: true + --no-create-attention-mask-in-dataloader: true + # ── Optimizer / schedule ────────────────────────────────────────────────── + --lr: 1.0e-4 + --min-lr: 1.0e-5 + --lr-decay-style: cosine + --lr-warmup-fraction: 0.01 + --weight-decay: 0.0 + --clip-grad: 1.0 + --adam-beta1: 0.9 + --adam-beta2: 0.98 + --use-distributed-optimizer: true + --bf16: true + # ── Logging / checkpointing ─────────────────────────────────────────────── + --log-interval: 1 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --log-memory-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --save-interval: 10000 + --save: ${CHECKPOINT_SAVE_PATH} + --ckpt-format: torch_dist + --ckpt-fully-parallel-save: true + --ckpt-fully-parallel-load: true + --dist-ckpt-strictness: ignore_all + --eval-interval: 1000 + --eval-iters: 4 + # ── Core model plumbing ─────────────────────────────────────────────────── + --use-mcore-models: true + --transformer-impl: transformer_engine + --export-default-te-spec: true + --export-model-type: 
HybridModel + --distributed-backend: nccl + --distributed-timeout-minutes: 20 + # ── Flextron (from train_flextron.sh flex_options, int-lists with ÷8) ───── + --flextron: true + --enable-router: true + --binary-mask: true + --soft-mask: true + --hard-sample-th: 0.996 + --router-beta: 1.0 + --original-model-sample-prob: 0.0 + --tau-init: 1.0 + --tau-decay: 0.9997 + --loss-alpha: 1.0 + --lr-mult-router: 100 + --router-gbs: 2 + --router-inter-dim: 256 + --budget-list: "[1.0 0.697]" + --budget-probs: "[1.0 1.0]" + --budget-type: param + --emb-int-list: "[1024 768 512]" + --mlp-int-list: "[512 384 256]" + --mamba-int-list: "[32 24 16]" + --moe-expert-int-list: "[32 24 16]" + --linear-scaler-start: 1.0 + --linear-scaler-end: 10.0 + --slice: true + --router-std: 0.1 +TEST_TYPE: regular diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml index 6d40098499d..9add53f8a49 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp1_cp1_dgx_a100_1N8G/model_config.yaml @@ -9,7 +9,7 @@ MODEL_ARGS: --group-query-attention: true --num-query-groups: 8 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- - --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --spec: "[megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec]" --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml index 51492f98c6e..25df6aa0359 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp2_vpp2_cp1_dgx_a100_1N8G/model_config.yaml @@ -9,7 +9,7 @@ MODEL_ARGS: --group-query-attention: true --num-query-groups: 8 --hybrid-layer-pattern: M-M-M-M*-M-|M-M-M*-M-M-|M-M*-M-M-M-|M*-M-M-M-M- - --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --spec: "[megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec]" --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml index 6eff846884a..fe4f9e63714 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp1_pp4_cp1_dgx_a100_1N8G/model_config.yaml @@ -9,7 +9,7 @@ MODEL_ARGS: --group-query-attention: true --num-query-groups: 8 --hybrid-layer-pattern: M-M-M-M*-M-|M-M-M*-M-M-|M-M*-M-M-M-|M*-M-M-M-M- - --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --spec: "[megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec]" --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git 
a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml index 8c655bc135c..2339f7a7ce9 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp1_dgx_a100_1N8G/model_config.yaml @@ -9,7 +9,7 @@ MODEL_ARGS: --group-query-attention: true --num-query-groups: 8 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- - --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --spec: "[megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec]" --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml index 44b588ee140..3efc155949f 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_mr_mcore_te_tp2_pp1_cp4_dgx_a100_1N8G/model_config.yaml @@ -9,7 +9,7 @@ MODEL_ARGS: --group-query-attention: true --num-query-groups: 8 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M- - --spec: "[megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec]" + --spec: "[megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec]" --log-params-norm: true --log-num-zeros-in-grad: true --log-validation-ppl-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml index 26708b32a60..02c5cc3055c 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml @@ -22,7 +22,7 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --model-provider: mamba + --model-provider: hybrid --init-method-std: 0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true @@ -33,7 +33,7 @@ MODEL_ARGS: --num-attention-heads: 16 --kv-channels: 128 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- - --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec --normalization: RMSNorm --swiglu: true --attention-dropout: 0.0 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml index 3964bcb8ecb..2543f59e668 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml @@ -22,7 +22,7 @@ MODEL_ARGS: --pipeline-model-parallel-size: 1 --expert-model-parallel-size: 1 --use-mcore-models: true - --model-provider: mamba + --model-provider: hybrid --init-method-std: 
0.0198 --untie-embeddings-and-output-weights: true --disable-bias-linear: true @@ -33,7 +33,7 @@ MODEL_ARGS: --num-attention-heads: 16 --kv-channels: 128 --hybrid-layer-pattern: M-M-M-M*-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M- - --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec + --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec --normalization: RMSNorm --swiglu: true --attention-dropout: 0.0 diff --git a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp1_pp2_resume_torch_dist_reshard_2x1x4_te_8experts2parallel_dist_optimizer/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp1_pp2_resume_torch_dist_reshard_2x1x4_te_8experts2parallel_dist_optimizer/model_config.yaml index e885bc255aa..1c0875aa0d0 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp1_pp2_resume_torch_dist_reshard_2x1x4_te_8experts2parallel_dist_optimizer/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp1_pp2_resume_torch_dist_reshard_2x1x4_te_8experts2parallel_dist_optimizer/model_config.yaml @@ -56,5 +56,6 @@ MODEL_ARGS: --disable-bias-linear: true --no-bias-gelu-fusion: true --log-memory-to-tensorboard: true + --verify-integrity: true TEST_TYPE: ckpt-resume LAUNCHER: ft_launcher diff --git a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer/model_config.yaml index 764c576645e..2a8a2a5d72b 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer/model_config.yaml @@ -62,4 +62,5 @@ MODEL_ARGS: --log-memory-to-tensorboard: true --async-save: true --use-persistent-ckpt-worker: true + --verify-integrity: true TEST_TYPE: ckpt-resume diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_resume_torch_dist_dist_optimizer/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_resume_torch_dist_dist_optimizer/model_config.yaml index faf717c7821..07fffe2cfd3 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_resume_torch_dist_dist_optimizer/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_resume_torch_dist_dist_optimizer/model_config.yaml @@ -67,5 +67,6 @@ MODEL_ARGS: --async-save: true --use-persistent-ckpt-worker: true --async-ckpt-use-cpu-shm: true + --log-moe-overload-factor: true TEST_TYPE: ckpt-resume LAUNCHER: ft_launcher diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/golden_values_dev_dgx_h100.json index f0bbe4685d3..b553507edfb 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/golden_values_dev_dgx_h100.json @@ -157,25510 +157,2196 @@ "routing_indices": [ [ [ + 48, 33, - 50, - 36, - 4, - 25, - 63 + 24, + 52, + 38, + 61 ], [ - 0, - 16, 3, - 
26, - 9, - 54 + 40, + 45, + 19, + 58, + 35 ], [ - 62, + 46, + 39, + 15, 60, - 8, - 16, - 58, - 52 + 31, + 25 ], [ - 58, - 10, - 6, + 49, + 62, + 61, + 48, + 55, + 46 + ], + [ + 41, + 35, + 53, + 52, + 13, + 10 + ], + [ + 52, 45, - 16, - 32 + 15, + 63, + 46, + 10 ], [ - 43, + 14, + 15, + 35, + 60, 49, - 18, - 54, - 55, - 13 + 31 ], [ - 27, - 19, - 26, 44, - 12, - 28 + 3, + 6, + 8, + 23, + 48 ], [ - 53, - 42, - 3, + 32, + 15, + 44, 27, - 26, - 19 + 4, + 41 ], [ - 6, - 47, - 1, - 19, - 8, - 22 + 51, + 25, + 62, + 53, + 48, + 10 ], [ - 51, - 27, - 1, - 38, + 8, + 50, + 5, + 19, 16, - 62 + 22 ], [ - 41, + 36, 49, - 21, - 57, - 16, - 24 + 60, + 44, + 15, + 33 ], [ 13, - 28, - 38, + 27, + 53, + 30, + 56, + 43 + ], + [ + 26, 22, - 49, - 48 + 14, + 7, + 32, + 17 ], [ - 52, + 26, + 60, + 2, 58, - 6, - 25, - 29, - 17 + 54, + 10 ], [ - 49, 50, - 25, - 41, - 54, - 58 + 55, + 17, + 51, + 47, + 14 + ], + [ + 0, + 12, + 16, + 4, + 23, + 9 ], [ + 10, + 23, 27, - 41, - 3, - 26, - 1, - 29 + 46, + 56, + 55 ], [ - 41, - 18, - 34, - 45, - 1, - 52 + 3, + 37, + 4, + 60, + 16, + 59 ], [ - 17, - 26, - 59, - 22, - 19, - 4 + 38, + 3, + 29, + 40, + 25, + 50 ], [ - 1, - 43, + 31, + 16, 62, - 57, - 61, - 29 + 54, + 42, + 63 ], [ + 47, + 35, 11, + 37, 46, - 15, - 28, - 61, - 50 - ], - [ - 14, - 30, - 58, - 26, - 38, - 53 + 32 ], [ - 59, + 51, + 27, 20, - 63, - 54, - 47, - 61 + 50, + 16, + 55 ], [ - 9, - 5, - 43, - 33, - 15, - 46 + 63, + 55, + 46, + 27, + 48, + 12 ], [ - 16, - 52, - 33, - 61, - 49, - 11 - ], - [ - 52, - 35, - 40, - 43, - 29, - 36 - ], - [ - 57, - 34, - 38, - 44, - 20, - 18 + 53, + 50, + 30, + 2, + 39, + 20 ], [ + 53, + 24, 44, + 8, 51, - 2, - 63, - 7, - 17 - ], - [ - 32, - 47, - 58, - 9, - 54, - 5 + 14 ], [ - 12, - 59, - 54, - 33, - 50, - 6 + 19, + 49, + 37, + 14, + 44, + 0 ] ], [ [ - 49, - 43, - 18, - 28, - 23, - 25 - ], - [ + 21, + 50, + 29, + 41, 34, - 24, - 23, - 60, - 2, - 18 + 60 ], [ + 28, + 51, + 60, 33, - 40, - 30, - 3, - 59, - 48 - ], - [ - 16, - 58, - 17, - 48, - 6, + 14, 45 ], [ - 23, - 58, + 31, + 2, 46, - 37, - 34, - 48 + 33, + 24, + 49 ], [ - 28, - 26, - 35, - 33, - 43, - 22 + 42, + 3, + 18, + 62, + 39, + 49 ], [ - 6, - 31, - 13, - 46, - 41, - 37 + 25, + 15, + 62, + 27, + 12, + 11 ], [ - 60, - 29, 44, + 50, 36, - 39, - 15 + 57, + 55, + 41 ], [ - 51, - 18, - 62, - 5, - 27, - 35 + 41, + 37, + 22, + 15, + 2, + 40 ], [ + 36, 62, - 45, - 32, - 56, - 25, - 3 + 53, + 30, + 14, + 57 ], [ - 6, - 20, - 3, 16, - 49, - 41 + 58, + 2, + 29, + 4, + 3 ], [ - 36, 41, - 50, + 38, + 26, + 16, 45, - 35, - 48 + 46 ], [ - 39, + 45, 46, - 48, - 25, + 32, + 41, + 56, + 26 + ], + [ + 17, + 53, 21, - 33 + 11, + 36, + 35 ], [ - 30, - 38, 11, + 16, + 28, + 14, + 51, + 61 + ], + [ + 9, + 35, + 33, 22, - 15, - 9 + 52, + 62 ], [ - 24, - 2, - 32, - 56, + 58, + 50, 63, - 3 + 30, + 7, + 27 ], [ - 36, 55, - 35, - 32, - 17, - 44 + 3, + 8, + 41, + 63, + 37 ], [ - 59, + 3, + 51, 46, 32, - 44, - 24, - 14 - ], - [ - 51, 15, - 61, - 43, - 30, - 22 + 6 ], [ 32, - 2, + 12, + 10, + 25, 5, + 49 + ], + [ + 34, + 2, + 37, + 61, 39, - 11, - 50 + 63 ], [ - 28, 42, - 6, - 30, - 57, - 37 + 22, + 27, + 53, + 11, + 56 ], [ - 28, - 55, - 45, + 53, + 12, 0, - 7, - 5 + 47, + 61, + 1 ], [ + 39, + 45, + 53, + 17, 48, - 40, - 34, - 3, - 22, - 49 + 14 ], [ - 25, 6, - 62, - 18, - 50, - 53 + 0, + 4, + 53, + 25, + 11 ], [ - 13, - 22, - 20, - 28, - 25, - 59 + 51, + 11, + 1, + 63, + 54, + 45 ], [ - 58, - 39, - 36, - 47, - 29, - 37 + 40, + 56, + 37, + 53, + 5, + 35 ], [ - 4, - 33, + 59, + 28, 41, - 12, - 3, - 17 + 10, + 1, + 45 ], [ - 22, - 13, - 20, - 
52, + 27, + 30, + 28, 24, - 62 + 32, + 57 ] ], [ [ - 17, - 10, - 57, - 54, + 24, + 56, 6, - 15 + 0, + 19, + 45 ], [ - 33, - 43, - 13, - 1, - 16, - 62 + 11, + 57, + 59, + 25, + 46, + 30 ], [ - 63, - 1, - 35, - 43, - 27, - 10 + 11, + 26, + 37, + 29, + 14, + 52 ], [ - 47, - 4, + 3, + 32, + 7, 38, - 50, - 51, - 0 + 36, + 24 ], [ - 11, - 51, - 57, - 23, + 61, + 2, + 24, 14, - 34 + 51, + 44 ], [ - 10, - 43, - 35, - 33, 20, - 22 + 47, + 0, + 63, + 30, + 58 ], [ + 4, 36, - 48, - 35, - 19, - 21, - 28 + 29, + 58, + 16, + 3 ], [ + 20, + 0, + 45, 14, - 8, - 7, - 46, - 35, - 13 + 28, + 44 ], [ - 18, - 44, - 63, - 6, - 4, - 37 + 29, + 56, + 47, + 35, + 16, + 4 ], [ - 62, - 29, - 15, - 38, - 39, - 34 + 33, + 61, + 55, + 41, + 51, + 38 ], [ + 58, 1, - 6, - 16, - 46, - 22, - 13 + 38, + 14, + 4, + 19 ], [ + 0, 36, - 8, - 16, - 37, - 10, - 14 + 14, + 18, + 52, + 42 ], [ + 29, + 36, + 45, + 25, 8, - 0, - 32, - 3, - 43, - 10 + 6 ], [ - 25, - 0, - 22, - 30, - 60, - 57 + 6, + 57, + 50, + 40, + 58, + 61 ], [ - 60, + 44, 58, - 55, - 32, - 2, - 7 + 29, + 19, + 61, + 56 ], [ + 23, + 18, + 28, 55, - 11, - 19, 5, - 24, - 43 + 37 ], [ + 13, 9, - 47, - 36, - 39, - 5, - 42 + 19, + 43, + 37, + 3 ], [ - 31, - 20, - 9, - 43, - 18, + 32, + 22, + 63, + 14, + 57, 41 ], [ - 12, - 5, - 32, - 50, + 10, + 61, 3, - 31 + 1, + 19, + 32 ], [ - 40, - 18, - 30, - 63, - 7, - 25 + 16, + 55, + 10, + 41, + 59, + 22 ], [ - 14, + 53, 7, - 3, + 29, 38, - 59, - 54 + 27, + 46 ], [ - 31, - 20, - 1, - 22, + 24, 47, - 6 + 18, + 53, + 39, + 30 ], [ + 0, + 33, + 19, + 5, 51, - 15, - 18, - 53, - 40, - 52 + 17 ], [ - 27, - 54, + 51, 8, - 38, - 3, - 59 + 11, + 45, + 44, + 41 ], [ + 40, + 4, + 23, 11, - 16, - 39, - 59, - 9, - 23 + 27, + 19 ], [ - 7, - 37, - 51, - 30, + 16, 18, - 48 + 3, + 48, + 51, + 21 ], [ - 0, - 36, - 33, - 40, + 43, 46, - 48 + 60, + 19, + 53, + 12 ] ], [ [ - 48, - 33, 24, - 52, - 38, - 61 - ], - [ - 3, - 40, - 45, + 56, + 6, + 0, 19, - 58, - 35 - ], - [ - 46, - 39, - 15, - 60, - 31, - 25 - ], - [ - 49, - 62, - 61, - 48, - 55, - 46 - ], - [ - 41, - 35, - 53, - 52, - 13, - 10 + 45 ], [ - 52, - 45, - 15, - 63, + 11, + 57, + 59, 46, - 10 + 25, + 30 ], [ + 26, + 11, + 37, 14, - 15, - 35, - 60, - 49, - 31 + 29, + 49 ], [ - 44, + 38, + 36, 3, - 6, - 8, - 23, - 48 - ], - [ - 32, - 15, - 44, - 27, - 4, - 41 + 24, + 18, + 20 ], [ + 61, 51, - 25, - 62, - 53, - 48, - 10 + 14, + 2, + 24, + 1 ], [ + 20, + 0, + 47, + 30, 8, - 50, - 5, - 19, - 16, - 22 + 35 ], [ + 4, + 58, 36, - 49, - 60, + 54, + 29, + 12 + ], + [ + 20, + 58, 44, - 15, - 33 + 28, + 45, + 9 ], [ - 13, - 27, - 53, - 30, 56, - 43 + 47, + 10, + 29, + 35, + 27 ], [ - 26, - 22, - 14, - 7, - 32, - 17 + 61, + 33, + 55, + 54, + 4, + 36 ], [ - 26, - 60, - 2, 58, - 54, - 10 + 1, + 14, + 4, + 38, + 52 ], [ - 50, - 55, - 17, - 51, - 47, - 14 + 14, + 0, + 36, + 63, + 15, + 52 ], [ - 0, - 12, + 29, + 36, + 44, + 8, 16, - 4, - 23, - 9 + 2 ], [ - 10, - 23, + 6, + 40, 27, - 46, - 56, - 55 + 57, + 50, + 42 ], [ - 3, + 19, + 44, + 58, + 61, 37, - 4, - 60, - 16, - 59 + 38 ], [ - 38, - 3, + 23, + 18, + 17, + 57, + 13, + 40 + ], + [ + 13, + 9, + 19, + 37, + 50, + 15 + ], + [ + 32, + 14, + 57, + 58, 29, - 40, - 25, - 50 + 22 ], [ - 31, + 61, + 10, + 1, + 3, + 14, + 59 + ], + [ + 55, 16, - 62, - 54, - 42, - 63 + 34, + 18, + 22, + 49 + ], + [ + 53, + 27, + 38, + 28, + 23, + 44 ], [ + 24, 47, - 35, - 11, - 37, - 46, - 32 + 18, + 62, + 41, + 30 ], [ + 33, 51, - 27, - 20, - 50, - 16, - 55 + 19, + 5, + 0, + 31 ], [ - 63, - 55, - 46, + 51, + 8, + 25, + 53, 27, - 48, - 12 + 16 ], [ - 53, - 50, - 30, - 2, - 39, - 
20 + 40, + 11, + 27, + 4, + 23, + 19 ], [ - 53, - 24, - 44, - 8, + 16, + 18, 51, - 14 + 48, + 3, + 47 ], [ - 19, - 49, - 37, - 14, - 44, - 0 + 46, + 43, + 36, + 9, + 5, + 12 ] ], [ [ 0, - 52, - 16, - 12, - 54, - 7 + 10, + 49, + 23, + 62, + 44 ], [ - 42, - 25, - 51, - 61, - 35, - 58 + 28, + 0, + 36, + 26, + 47, + 52 ], [ - 51, - 42, - 19, - 57, - 28, - 8 + 30, + 4, + 16, + 48, + 40, + 10 ], [ - 49, - 62, - 5, - 2, - 46, - 21 + 61, + 32, + 26, + 16, + 33, + 62 ], [ - 3, - 41, - 53, - 25, + 30, 39, - 37 + 53, + 5, + 57, + 20 ], [ - 45, - 15, + 5, 37, - 48, - 19, - 60 + 61, + 15, + 25, + 6 ], [ - 14, 15, - 47, 17, 24, - 35 + 60, + 49, + 62 + ], + [ + 34, + 39, + 61, + 0, + 58, + 40 ], [ - 3, - 52, - 63, 16, - 28, - 47 + 39, + 36, + 51, + 2, + 29 ], [ - 32, - 27, - 15, + 9, + 11, + 41, + 31, + 56, + 52 + ], + [ + 10, + 48, 24, + 45, 62, - 23 + 51 ], [ - 62, - 25, - 51, - 53, - 0, - 20 + 11, + 38, + 36, + 37, + 6, + 42 ], [ - 8, - 19, + 51, 50, - 16, - 32, - 22 - ], - [ - 13, - 36, - 42, - 49, - 60, - 44 - ], - [ - 30, - 27, - 53, - 56, - 0, - 10 - ], - [ - 22, - 7, - 14, - 26, - 32, - 17 - ], - [ - 60, - 7, - 26, - 2, - 58, - 13 - ], - [ - 55, - 51, - 17, - 60, - 62, - 47 - ], - [ - 53, - 16, - 18, - 4, - 50, - 5 - ], - [ - 28, - 39, - 41, - 1, - 55, - 18 - ], - [ - 59, - 49, - 16, - 23, - 42, - 11 - ], - [ - 17, - 30, - 46, - 55, - 25, - 0 - ], - [ - 54, - 9, - 45, - 0, - 6, - 19 - ], - [ - 51, - 9, - 22, - 23, - 16, - 25 - ], - [ - 62, - 50, - 43, - 51, - 55, - 27 - ], - [ - 13, - 43, - 27, - 1, - 14, - 52 - ], - [ - 61, - 26, - 1, - 17, - 32, - 63 - ], - [ - 37, - 46, - 63, - 20, - 24, - 4 - ], - [ - 63, - 11, - 12, - 61, - 31, - 22 - ] - ], - [ - [ - 49, - 54, - 56, - 11, - 38, - 40 - ], - [ - 53, - 46, - 49, - 38, - 57, - 17 - ], - [ - 48, - 9, - 31, - 12, - 6, - 56 - ], - [ - 49, - 62, - 23, - 5, - 12, - 63 - ], - [ - 36, - 26, - 38, - 7, - 20, - 23 - ], - [ - 37, - 33, - 41, - 58, - 57, - 32 - ], - [ - 10, - 14, - 15, - 31, - 8, - 43 - ], - [ - 16, - 52, - 3, - 2, - 34, - 14 - ], - [ - 15, - 32, - 35, - 27, - 62, - 54 - ], - [ - 51, - 62, - 25, - 53, - 10, - 20 - ], - [ - 19, - 8, - 22, - 50, - 1, - 5 - ], - [ - 13, - 49, - 36, - 60, - 42, - 20 - ], - [ - 27, - 30, - 53, - 54, - 26, - 43 - ], - [ - 14, - 32, - 26, - 22, - 7, - 17 - ], - [ - 26, - 60, - 7, - 2, - 52, - 54 - ], - [ - 55, - 17, - 51, - 47, - 62, - 26 - ], - [ - 16, - 18, - 44, - 4, - 53, - 50 - ], - [ - 1, - 47, - 39, - 45, - 28, - 56 - ], - [ - 23, - 16, - 55, - 49, - 8, - 32 - ], - [ - 17, - 55, - 30, - 62, - 31, - 23 - ], - [ - 9, - 54, - 0, - 44, - 14, - 56 - ], - [ - 23, - 9, - 51, - 22, - 31, - 50 - ], - [ - 50, - 27, - 14, - 51, - 18, - 8 - ], - [ - 12, - 13, - 33, - 1, - 5, - 43 - ], - [ - 26, - 32, - 1, - 50, - 37, - 57 - ], - [ - 37, - 47, - 63, - 46, - 5, - 4 - ], - [ - 63, - 11, - 12, - 19, - 33, - 61 - ] - ], - [ - [ - 49, - 54, - 40, - 56, - 11, - 3 - ], - [ - 37, - 15, - 12, - 33, - 59, - 17 - ], - [ - 38, - 49, - 14, - 46, - 35, - 59 - ], - [ - 25, - 20, - 39, - 62, - 49, - 32 - ], - [ - 26, - 51, - 16, - 36, - 8, - 7 - ], - [ - 37, - 41, - 51, - 33, - 32, - 60 - ], - [ - 10, - 14, - 59, - 8, - 15, - 40 - ], - [ - 16, - 52, - 19, - 61, - 32, - 3 - ], - [ - 32, - 15, - 27, - 24, - 62, - 35 - ], - [ - 51, - 25, - 62, - 0, - 49, - 54 - ], - [ - 8, - 50, - 19, - 48, - 16, - 54 - ], - [ - 13, - 49, - 36, - 42, - 11, - 60 - ], - [ - 53, - 27, - 7, - 26, - 30, - 21 - ], - [ - 14, - 22, - 26, - 37, - 32, - 7 - ], - [ - 26, - 40, - 60, - 2, - 52, - 7 - ], - [ - 55, - 51, - 17, - 46, - 13, - 62 - ], - [ - 38, - 16, - 
53, - 44, - 4, - 18 - ], - [ - 39, - 1, - 4, - 14, - 56, - 57 - ], - [ - 55, - 23, - 32, - 14, - 13, - 16 - ], - [ - 55, - 17, - 2, - 30, - 62, - 12 - ], - [ - 13, - 54, - 0, - 62, - 61, - 25 - ], - [ - 9, - 51, - 5, - 22, - 19, - 23 - ], - [ - 50, - 51, - 7, - 19, - 48, - 53 - ], - [ - 12, - 1, - 5, - 43, - 61, - 13 - ], - [ - 26, - 32, - 30, - 37, - 34, - 57 - ], - [ - 37, - 63, - 46, - 4, - 47, - 8 - ], - [ - 63, - 12, - 11, - 31, - 33, - 61 - ] - ], - [ - [ - 47, - 34, - 30, - 25, - 31, - 3 - ], - [ - 15, - 24, - 46, - 21, - 8, - 6 - ], - [ - 34, - 21, - 18, - 62, - 28, - 55 - ], - [ - 35, - 32, - 20, - 39, - 59, - 54 - ], - [ - 26, - 27, - 15, - 48, - 60, - 47 - ], - [ - 37, - 8, - 50, - 18, - 54, - 61 - ], - [ - 35, - 31, - 8, - 24, - 14, - 15 - ], - [ - 16, - 52, - 34, - 29, - 48, - 36 - ], - [ - 11, - 32, - 62, - 27, - 46, - 26 - ], - [ - 61, - 62, - 25, - 56, - 46, - 53 - ], - [ - 56, - 50, - 63, - 3, - 45, - 28 - ], - [ - 11, - 36, - 5, - 60, - 35, - 50 - ], - [ - 21, - 26, - 41, - 51, - 46, - 53 - ], - [ - 14, - 22, - 33, - 19, - 41, - 16 - ], - [ - 2, - 52, - 34, - 60, - 21, - 49 - ], - [ - 59, - 55, - 29, - 8, - 61, - 22 - ], - [ - 51, - 44, - 2, - 59, - 47, - 53 - ], - [ - 39, - 25, - 18, - 12, - 51, - 56 - ], - [ - 34, - 53, - 32, - 12, - 9, - 38 - ], - [ - 30, - 53, - 56, - 7, - 40, - 62 - ], - [ - 40, - 49, - 28, - 14, - 23, - 55 - ], - [ - 15, - 48, - 40, - 47, - 9, - 1 - ], - [ - 50, - 41, - 25, - 53, - 18, - 0 - ], - [ - 22, - 1, - 59, - 3, - 55, - 8 - ], - [ - 1, - 53, - 32, - 26, - 47, - 3 - ], - [ - 4, - 33, - 28, - 37, - 55, - 54 - ], - [ - 30, - 22, - 57, - 12, - 33, - 63 - ] - ], - [ - [ - 16, - 11, - 0, - 31, - 46, - 22 - ], - [ - 49, - 13, - 5, - 11, - 31, - 14 - ], - [ - 36, - 13, - 56, - 27, - 46, - 3 - ], - [ - 24, - 44, - 62, - 29, - 15, - 13 - ], - [ - 17, - 2, - 50, - 8, - 45, - 1 - ], - [ - 8, - 7, - 49, - 0, - 62, - 13 - ], - [ - 35, - 61, - 58, - 23, - 36, - 0 - ], - [ - 16, - 48, - 42, - 4, - 32, - 29 - ], - [ - 63, - 18, - 32, - 45, - 4, - 34 - ], - [ - 57, - 62, - 54, - 27, - 25, - 53 - ], - [ - 1, - 59, - 60, - 29, - 22, - 14 - ], - [ - 31, - 36, - 11, - 14, - 20, - 10 - ], - [ - 34, - 2, - 19, - 14, - 8, - 37 - ], - [ - 57, - 22, - 40, - 14, - 62, - 48 - ], - [ - 44, - 60, - 7, - 14, - 45, - 2 - ], - [ - 8, - 55, - 6, - 25, - 50, - 59 - ], - [ - 13, - 47, - 42, - 23, - 61, - 39 - ], - [ - 23, - 25, - 4, - 14, - 46, - 60 - ], - [ - 8, - 45, - 32, - 53, - 10, - 54 - ], - [ - 15, - 38, - 53, - 55, - 30, - 7 - ], - [ - 41, - 14, - 28, - 5, - 58, - 27 - ], - [ - 11, - 41, - 57, - 1, - 10, - 47 - ], - [ - 50, - 0, - 51, - 53, - 34, - 45 - ], - [ - 1, - 14, - 55, - 8, - 25, - 3 - ], - [ - 11, - 49, - 1, - 9, - 0, - 3 - ], - [ - 50, - 51, - 6, - 42, - 4, - 54 - ], - [ - 17, - 37, - 31, - 5, - 40, - 36 - ] - ], - [ - [ - 22, - 53, - 47, - 6, - 57, - 21 - ], - [ - 27, - 11, - 14, - 6, - 57, - 16 - ], - [ - 1, - 11, - 29, - 26, - 41, - 17 - ], - [ - 14, - 22, - 38, - 31, - 29, - 36 - ], - [ - 14, - 59, - 29, - 61, - 45, - 52 - ], - [ - 30, - 8, - 0, - 21, - 47, - 58 - ], - [ - 58, - 35, - 4, - 61, - 23, - 36 - ], - [ - 42, - 20, - 48, - 16, - 9, - 4 - ], - [ - 47, - 29, - 4, - 18, - 63, - 32 - ], - [ - 19, - 54, - 62, - 53, - 57, - 29 - ], - [ - 1, - 60, - 14, - 59, - 4, - 29 - ], - [ - 36, - 0, - 47, - 3, - 31, - 8 - ], - [ - 2, - 19, - 36, - 8, - 20, - 37 - ], - [ - 57, - 22, - 40, - 31, - 49, - 14 - ], - [ - 44, - 37, - 2, - 5, - 60, - 21 - ], - [ - 6, - 43, - 24, - 5, - 2, - 59 - ], - [ - 13, - 19, - 61, - 47, - 50, - 39 - ], - [ - 58, - 14, - 28, - 4, - 
11, - 22 - ], - [ - 35, - 32, - 46, - 10, - 31, - 45 - ], - [ - 15, - 13, - 55, - 45, - 18, - 63 - ], - [ - 15, - 27, - 28, - 14, - 5, - 60 - ], - [ - 57, - 41, - 47, - 19, - 36, - 10 - ], - [ - 34, - 10, - 53, - 55, - 22, - 19 - ], - [ - 38, - 55, - 39, - 27, - 3, - 25 - ], - [ - 11, - 39, - 0, - 9, - 3, - 49 - ], - [ - 51, - 6, - 43, - 18, - 50, - 53 - ], - [ - 55, - 43, - 9, - 36, - 40, - 5 - ] - ], - [ - [ - 18, - 9, - 1, - 36, - 61, - 44 - ], - [ - 56, - 34, - 19, - 42, - 3, - 5 - ], - [ - 39, - 20, - 15, - 60, - 46, - 32 - ], - [ - 60, - 22, - 31, - 27, - 14, - 19 - ], - [ - 59, - 58, - 10, - 7, - 46, - 18 - ], - [ - 43, - 2, - 57, - 62, - 11, - 30 - ], - [ - 54, - 19, - 9, - 21, - 48, - 56 - ], - [ - 46, - 24, - 7, - 14, - 3, - 8 - ], - [ - 47, - 0, - 4, - 18, - 31, - 29 - ], - [ - 54, - 62, - 47, - 38, - 4, - 32 - ], - [ - 1, - 14, - 15, - 22, - 59, - 38 - ], - [ - 16, - 36, - 42, - 55, - 15, - 18 - ], - [ - 49, - 8, - 20, - 14, - 0, - 33 - ], - [ - 18, - 39, - 25, - 2, - 62, - 22 - ], - [ - 62, - 5, - 58, - 37, - 7, - 32 - ], - [ - 43, - 5, - 42, - 63, - 55, - 37 - ], - [ - 47, - 33, - 15, - 63, - 50, - 12 - ], - [ - 60, - 0, - 7, - 16, - 32, - 13 - ], - [ - 12, - 39, - 32, - 61, - 16, - 45 - ], - [ - 52, - 34, - 15, - 62, - 18, - 30 - ], - [ - 28, - 26, - 46, - 40, - 6, - 14 - ], - [ - 1, - 19, - 17, - 20, - 4, - 21 - ], - [ - 41, - 40, - 4, - 53, - 55, - 19 - ], - [ - 25, - 38, - 27, - 34, - 52, - 46 - ], - [ - 11, - 29, - 52, - 44, - 53, - 13 - ], - [ - 50, - 51, - 41, - 16, - 4, - 15 - ], - [ - 19, - 6, - 23, - 36, - 60, - 0 - ] - ], - [ - [ - 17, - 10, - 57, - 27, - 5, - 54 - ], - [ - 33, - 9, - 43, - 40, - 56, - 11 - ], - [ - 63, - 1, - 35, - 43, - 10, - 27 - ], - [ - 51, - 47, - 20, - 21, - 28, - 61 - ], - [ - 25, - 11, - 58, - 23, - 55, - 46 - ], - [ - 43, - 10, - 12, - 2, - 62, - 30 - ], - [ - 48, - 19, - 21, - 8, - 7, - 54 - ], - [ - 14, - 7, - 24, - 8, - 46, - 2 - ], - [ - 4, - 47, - 37, - 0, - 44, - 27 - ], - [ - 54, - 38, - 62, - 47, - 15, - 14 - ], - [ - 1, - 46, - 15, - 22, - 51, - 38 - ], - [ - 36, - 16, - 42, - 55, - 24, - 37 - ], - [ - 49, - 10, - 0, - 3, - 43, - 8 - ], - [ - 39, - 58, - 0, - 62, - 22, - 25 - ], - [ - 58, - 38, - 7, - 55, - 62, - 56 - ], - [ - 19, - 42, - 55, - 43, - 11, - 37 - ], - [ - 9, - 47, - 43, - 52, - 18, - 50 - ], - [ - 31, - 41, - 32, - 25, - 20, - 13 - ], - [ - 12, - 32, - 61, - 3, - 21, - 43 - ], - [ - 36, - 13, - 40, - 7, - 62, - 16 - ], - [ - 14, - 53, - 50, - 47, - 51, - 1 - ], - [ - 1, - 38, - 19, - 18, - 30, - 16 - ], - [ - 0, - 19, - 51, - 18, - 52, - 15 - ], - [ - 8, - 52, - 27, - 34, - 38, - 3 - ], - [ - 27, - 53, - 59, - 9, - 40, - 4 - ], - [ - 37, - 3, - 26, - 48, - 8, - 16 - ], - [ - 46, - 18, - 11, - 40, - 33, - 44 - ] - ], - [ - [ - 48, - 62, - 61, - 50, - 26, - 59 - ], - [ - 3, - 45, - 40, - 35, - 29, - 54 - ], - [ - 56, - 31, - 23, - 28, - 2, - 53 - ], - [ - 62, - 49, - 20, - 61, - 6, - 41 - ], - [ - 18, - 25, - 50, - 0, - 14, - 57 - ], - [ - 58, - 4, - 10, - 43, - 56, - 20 - ], - [ - 35, - 15, - 25, - 24, - 3, - 7 - ], - [ - 14, - 23, - 8, - 12, - 57, - 24 - ], - [ - 29, - 17, - 35, - 44, - 24, - 27 - ], - [ - 62, - 15, - 38, - 20, - 58, - 21 - ], - [ - 19, - 46, - 1, - 26, - 63, - 22 - ], - [ - 36, - 60, - 16, - 42, - 55, - 11 - ], - [ - 17, - 7, - 14, - 26, - 16, - 49 - ], - [ - 45, - 47, - 22, - 0, - 62, - 58 - ], - [ - 58, - 38, - 48, - 49, - 63, - 2 - ], - [ - 55, - 0, - 1, - 37, - 30, - 10 - ], - [ - 12, - 43, - 21, - 9, - 47, - 23 - ], - [ - 32, - 57, - 42, - 25, - 43, - 63 - ], - [ - 3, - 32, - 49, - 61, - 21, 
- 12 - ], - [ - 5, - 36, - 22, - 16, - 62, - 42 - ], - [ - 53, - 7, - 46, - 61, - 14, - 52 - ], - [ - 55, - 30, - 3, - 5, - 53, - 31 - ], - [ - 0, - 44, - 15, - 18, - 19, - 28 - ], - [ - 8, - 52, - 51, - 11, - 4, - 29 - ], - [ - 27, - 40, - 4, - 9, - 35, - 39 - ], - [ - 14, - 26, - 3, - 48, - 16, - 21 - ], - [ - 60, - 54, - 35, - 20, - 53, - 12 - ] - ], - [ - [ - 21, - 50, - 29, - 41, - 34, - 60 - ], - [ - 28, - 51, - 60, - 33, - 14, - 45 - ], - [ - 31, - 2, - 46, - 33, - 24, - 49 - ], - [ - 42, - 3, - 18, - 62, - 39, - 49 - ], - [ - 25, - 15, - 62, - 27, - 12, - 11 - ], - [ - 44, - 50, - 36, - 57, - 55, - 41 - ], - [ - 41, - 37, - 22, - 15, - 2, - 40 - ], - [ - 36, - 62, - 53, - 30, - 14, - 57 - ], - [ - 16, - 58, - 2, - 29, - 4, - 3 - ], - [ - 41, - 38, - 26, - 16, - 45, - 46 - ], - [ - 45, - 46, - 32, - 41, - 56, - 26 - ], - [ - 17, - 53, - 21, - 11, - 36, - 35 - ], - [ - 11, - 16, - 28, - 14, - 51, - 61 - ], - [ - 9, - 35, - 33, - 22, - 52, - 62 - ], - [ - 58, - 50, - 63, - 30, - 7, - 27 - ], - [ - 55, - 3, - 8, - 41, - 63, - 37 - ], - [ - 3, - 51, - 46, - 32, - 15, - 6 - ], - [ - 32, - 12, - 10, - 25, - 5, - 49 - ], - [ - 34, - 2, - 37, - 61, - 39, - 63 - ], - [ - 42, - 22, - 27, - 53, - 11, - 56 - ], - [ - 53, - 12, - 0, - 47, - 61, - 1 - ], - [ - 39, - 45, - 53, - 17, - 48, - 14 - ], - [ - 6, - 0, - 4, - 53, - 25, - 11 - ], - [ - 51, - 11, - 1, - 63, - 54, - 45 - ], - [ - 40, - 56, - 37, - 53, - 5, - 35 - ], - [ - 59, - 28, - 41, - 10, - 1, - 45 - ], - [ - 27, - 30, - 28, - 24, - 32, - 57 - ] - ], - [ - [ - 24, - 56, - 6, - 0, - 19, - 45 - ], - [ - 11, - 57, - 59, - 25, - 46, - 30 - ], - [ - 11, - 26, - 37, - 29, - 14, - 52 - ], - [ - 3, - 32, - 7, - 38, - 36, - 24 - ], - [ - 61, - 2, - 24, - 14, - 51, - 44 - ], - [ - 20, - 47, - 0, - 63, - 30, - 58 - ], - [ - 4, - 36, - 29, - 58, - 16, - 3 - ], - [ - 20, - 0, - 45, - 14, - 28, - 44 - ], - [ - 29, - 56, - 47, - 35, - 16, - 4 - ], - [ - 33, - 61, - 55, - 41, - 51, - 38 - ], - [ - 58, - 1, - 38, - 14, - 4, - 19 - ], - [ - 0, - 36, - 14, - 18, - 52, - 42 - ], - [ - 29, - 36, - 45, - 25, - 8, - 6 - ], - [ - 6, - 57, - 50, - 40, - 58, - 61 - ], - [ - 44, - 58, - 29, - 19, - 61, - 56 - ], - [ - 23, - 18, - 28, - 55, - 5, - 37 - ], - [ - 13, - 9, - 19, - 43, - 37, - 3 - ], - [ - 32, - 22, - 63, - 14, - 57, - 41 - ], - [ - 10, - 61, - 3, - 1, - 19, - 32 - ], - [ - 16, - 55, - 10, - 41, - 59, - 22 - ], - [ - 53, - 7, - 29, - 38, - 27, - 46 - ], - [ - 24, - 47, - 18, - 53, - 39, - 30 - ], - [ - 0, - 33, - 19, - 5, - 51, - 17 - ], - [ - 51, - 8, - 11, - 45, - 44, - 41 - ], - [ - 40, - 4, - 23, - 11, - 27, - 19 - ], - [ - 16, - 18, - 3, - 48, - 51, - 21 - ], - [ - 43, - 46, - 60, - 19, - 53, - 12 - ] - ], - [ - [ - 48, - 62, - 61, - 30, - 50, - 52 - ], - [ - 45, - 3, - 35, - 29, - 54, - 2 - ], - [ - 56, - 31, - 53, - 23, - 49, - 28 - ], - [ - 60, - 57, - 14, - 46, - 41, - 48 - ], - [ - 18, - 61, - 59, - 14, - 44, - 32 - ], - [ - 45, - 58, - 47, - 20, - 4, - 30 - ], - [ - 54, - 13, - 25, - 36, - 26, - 47 - ], - [ - 20, - 12, - 0, - 47, - 30, - 45 - ], - [ - 56, - 29, - 47, - 17, - 35, - 16 - ], - [ - 33, - 61, - 55, - 11, - 38, - 48 - ], - [ - 58, - 19, - 14, - 1, - 38, - 36 - ], - [ - 14, - 36, - 0, - 60, - 11, - 52 - ], - [ - 29, - 44, - 7, - 36, - 16, - 45 - ], - [ - 6, - 47, - 50, - 33, - 42, - 62 - ], - [ - 44, - 58, - 61, - 38, - 29, - 56 - ], - [ - 23, - 55, - 18, - 0, - 57, - 37 - ], - [ - 9, - 12, - 43, - 19, - 13, - 6 - ], - [ - 32, - 22, - 63, - 57, - 42, - 29 - ], - [ - 3, - 61, - 1, - 10, - 49, - 32 - ], - [ - 5, - 16, - 55, - 36, - 
22, - 59 - ], - [ - 53, - 7, - 46, - 29, - 9, - 14 - ], - [ - 24, - 30, - 18, - 39, - 55, - 53 - ], - [ - 33, - 0, - 19, - 44, - 51, - 5 - ], - [ - 51, - 8, - 53, - 41, - 4, - 11 - ], - [ - 40, - 4, - 27, - 19, - 23, - 16 - ], - [ - 16, - 14, - 48, - 3, - 21, - 26 - ], - [ - 60, - 54, - 35, - 53, - 12, - 43 - ] - ], - [ - [ - 19, - 41, - 8, - 7, - 13, - 2 - ], - [ - 48, - 46, - 62, - 29, - 5, - 41 - ], - [ - 12, - 5, - 59, - 3, - 58, - 49 - ], - [ - 60, - 3, - 42, - 39, - 14, - 18 - ], - [ - 42, - 12, - 27, - 11, - 25, - 19 - ], - [ - 50, - 36, - 44, - 26, - 33, - 37 - ], - [ - 41, - 54, - 22, - 52, - 37, - 35 - ], - [ - 62, - 30, - 36, - 53, - 10, - 14 - ], - [ - 2, - 16, - 58, - 29, - 7, - 41 - ], - [ - 18, - 32, - 45, - 16, - 22, - 38 - ], - [ - 45, - 56, - 41, - 10, - 3, - 46 - ], - [ - 50, - 21, - 36, - 35, - 53, - 12 - ], - [ - 11, - 28, - 16, - 41, - 39, - 46 - ], - [ - 16, - 9, - 33, - 38, - 28, - 19 - ], - [ - 58, - 50, - 63, - 62, - 27, - 52 - ], - [ - 55, - 59, - 13, - 8, - 43, - 3 - ], - [ - 3, - 51, - 15, - 46, - 47, - 57 - ], - [ - 32, - 30, - 12, - 10, - 25, - 18 - ], - [ - 34, - 2, - 61, - 27, - 53, - 59 - ], - [ - 42, - 22, - 56, - 53, - 44, - 34 - ], - [ - 53, - 12, - 49, - 41, - 44, - 8 - ], - [ - 45, - 1, - 48, - 47, - 16, - 17 - ], - [ - 4, - 0, - 53, - 25, - 24, - 11 - ], - [ - 1, - 11, - 44, - 45, - 34, - 51 - ], - [ - 40, - 5, - 53, - 6, - 22, - 18 - ], - [ - 28, - 10, - 1, - 3, - 15, - 41 - ], - [ - 30, - 27, - 24, - 57, - 32, - 16 - ] - ], - [ - [ - 24, - 56, - 6, - 0, - 19, - 45 - ], - [ - 11, - 57, - 59, - 46, - 25, - 30 - ], - [ - 26, - 11, - 37, - 14, - 29, - 49 - ], - [ - 38, - 36, - 3, - 24, - 18, - 20 - ], - [ - 61, - 51, - 14, - 2, - 24, - 1 - ], - [ - 20, - 0, - 47, - 30, - 8, - 35 - ], - [ - 4, - 58, - 36, - 54, - 29, - 12 - ], - [ - 20, - 58, - 44, - 28, - 45, - 9 - ], - [ - 56, - 47, - 10, - 29, - 35, - 27 - ], - [ - 61, - 33, - 55, - 54, - 4, - 36 - ], - [ - 58, - 1, - 14, - 4, - 38, - 52 - ], - [ - 14, - 0, - 36, - 63, - 15, - 52 - ], - [ - 29, - 36, - 44, - 8, - 16, - 2 - ], - [ - 6, - 40, - 27, - 57, - 50, - 42 - ], - [ - 19, - 44, - 58, - 61, - 37, - 38 - ], - [ - 23, - 18, - 17, - 57, - 13, - 40 - ], - [ - 13, - 9, - 19, - 37, - 50, - 15 - ], - [ - 32, - 14, - 57, - 58, - 29, - 22 - ], - [ - 61, - 10, - 1, - 3, - 14, - 59 - ], - [ - 55, - 16, - 34, - 18, - 22, - 49 - ], - [ - 53, - 27, - 38, - 28, - 23, - 44 - ], - [ - 24, - 47, - 18, - 62, - 41, - 30 - ], - [ - 33, - 51, - 19, - 5, - 0, - 31 - ], - [ - 51, - 8, - 25, - 53, - 27, - 16 - ], - [ - 40, - 11, - 27, - 4, - 23, - 19 - ], - [ - 16, - 18, - 51, - 48, - 3, - 47 - ], - [ - 46, - 43, - 36, - 9, - 5, - 12 - ] - ], - [ - [ - 37, - 10, - 46, - 60, - 61, - 59 - ], - [ - 35, - 53, - 34, - 43, - 19, - 57 - ], - [ - 49, - 56, - 45, - 30, - 6, - 12 - ], - [ - 60, - 27, - 14, - 48, - 46, - 57 - ], - [ - 61, - 59, - 14, - 41, - 16, - 1 - ], - [ - 45, - 4, - 3, - 58, - 24, - 47 - ], - [ - 54, - 13, - 9, - 43, - 16, - 26 - ], - [ - 47, - 23, - 12, - 20, - 63, - 30 - ], - [ - 23, - 44, - 56, - 29, - 47, - 17 - ], - [ - 33, - 60, - 61, - 48, - 41, - 14 - ], - [ - 58, - 63, - 19, - 11, - 9, - 38 - ], - [ - 60, - 63, - 0, - 36, - 15, - 9 - ], - [ - 29, - 36, - 30, - 59, - 11, - 27 - ], - [ - 6, - 7, - 47, - 62, - 50, - 57 - ], - [ - 27, - 58, - 19, - 46, - 29, - 56 - ], - [ - 29, - 60, - 56, - 55, - 23, - 26 - ], - [ - 53, - 59, - 6, - 9, - 16, - 43 - ], - [ - 41, - 32, - 57, - 63, - 18, - 37 - ], - [ - 42, - 61, - 3, - 10, - 34, - 59 - ], - [ - 4, - 43, - 17, - 16, - 52, - 60 - ], - [ - 45, - 53, - 61, - 
56, - 16, - 7 - ], - [ - 55, - 9, - 18, - 61, - 45, - 3 - ], - [ - 60, - 47, - 53, - 33, - 12, - 27 - ], - [ - 43, - 51, - 11, - 8, - 45, - 63 - ], - [ - 6, - 40, - 15, - 27, - 26, - 23 - ], - [ - 49, - 14, - 9, - 21, - 58, - 12 - ], - [ - 63, - 24, - 60, - 31, - 12, - 34 - ] - ], - [ - [ - 16, - 13, - 4, - 44, - 23, - 46 - ], - [ - 16, - 50, - 9, - 13, - 23, - 36 - ], - [ - 11, - 35, - 21, - 7, - 59, - 8 - ], - [ - 1, - 3, - 25, - 15, - 60, - 39 - ], - [ - 54, - 61, - 31, - 35, - 55, - 1 - ], - [ - 51, - 52, - 46, - 15, - 4, - 45 - ], - [ - 60, - 54, - 59, - 44, - 10, - 7 - ], - [ - 12, - 22, - 14, - 47, - 0, - 30 - ], - [ - 42, - 29, - 23, - 56, - 47, - 33 - ], - [ - 33, - 61, - 20, - 60, - 0, - 49 - ], - [ - 35, - 58, - 63, - 14, - 51, - 24 - ], - [ - 29, - 33, - 36, - 60, - 0, - 49 - ], - [ - 29, - 17, - 30, - 12, - 31, - 36 - ], - [ - 6, - 0, - 61, - 50, - 48, - 3 - ], - [ - 8, - 6, - 58, - 37, - 29, - 19 - ], - [ - 60, - 39, - 27, - 19, - 1, - 57 - ], - [ - 56, - 9, - 30, - 6, - 10, - 43 - ], - [ - 32, - 20, - 13, - 57, - 63, - 49 - ], - [ - 8, - 42, - 21, - 61, - 37, - 4 - ], - [ - 45, - 49, - 16, - 13, - 2, - 58 - ], - [ - 53, - 29, - 14, - 50, - 61, - 3 - ], - [ - 37, - 57, - 27, - 54, - 46, - 9 - ], - [ - 52, - 19, - 22, - 0, - 18, - 5 - ], - [ - 14, - 49, - 30, - 33, - 53, - 34 - ], - [ - 13, - 9, - 4, - 40, - 23, - 39 - ], - [ - 27, - 43, - 47, - 36, - 49, - 3 - ], - [ - 59, - 43, - 40, - 28, - 0, - 33 - ] - ], - [ - [ - 48, - 42, - 63, - 50, - 34, - 38 - ], - [ - 3, - 40, - 61, - 62, - 6, - 2 - ], - [ - 39, - 7, - 36, - 6, - 45, - 40 - ], - [ - 41, - 35, - 46, - 13, - 63, - 56 - ], - [ - 6, - 1, - 54, - 37, - 38, - 34 - ], - [ - 59, - 46, - 51, - 31, - 4, - 52 - ], - [ - 60, - 44, - 11, - 54, - 4, - 24 - ], - [ - 12, - 0, - 2, - 63, - 50, - 47 - ], - [ - 33, - 42, - 29, - 23, - 16, - 56 - ], - [ - 20, - 61, - 33, - 60, - 53, - 0 - ], - [ - 35, - 58, - 63, - 9, - 8, - 19 - ], - [ - 59, - 48, - 36, - 60, - 10, - 14 - ], - [ - 29, - 44, - 7, - 17, - 36, - 12 - ], - [ - 47, - 27, - 6, - 62, - 42, - 48 - ], - [ - 18, - 58, - 49, - 46, - 42, - 44 - ], - [ - 60, - 34, - 27, - 18, - 23, - 55 - ], - [ - 43, - 40, - 9, - 50, - 18, - 45 - ], - [ - 32, - 57, - 48, - 42, - 29, - 39 - ], - [ - 42, - 61, - 49, - 3, - 32, - 1 - ], - [ - 23, - 37, - 1, - 16, - 36, - 39 - ], - [ - 53, - 21, - 7, - 61, - 50, - 31 - ], - [ - 8, - 60, - 18, - 24, - 9, - 30 - ], - [ - 51, - 33, - 28, - 5, - 44, - 8 - ], - [ - 51, - 52, - 8, - 4, - 41, - 45 - ], - [ - 40, - 4, - 27, - 9, - 60, - 19 - ], - [ - 3, - 61, - 16, - 26, - 48, - 12 - ], - [ - 54, - 61, - 35, - 1, - 53, - 43 - ] - ], - [ - [ - 62, - 28, - 1, - 42, - 8, - 55 - ], - [ - 18, - 12, - 8, - 41, - 40, - 31 - ], - [ - 12, - 6, - 50, - 4, - 23, - 45 - ], - [ - 43, - 35, - 8, - 20, - 42, - 46 - ], - [ - 39, - 41, - 29, - 22, - 3, - 56 - ], - [ - 61, - 45, - 46, - 48, - 28, - 51 - ], - [ - 44, - 4, - 11, - 25, - 54, - 59 - ], - [ - 12, - 33, - 56, - 52, - 30, - 17 - ], - [ - 55, - 29, - 17, - 42, - 23, - 14 - ], - [ - 60, - 12, - 18, - 61, - 33, - 28 - ], - [ - 35, - 58, - 37, - 63, - 6, - 27 - ], - [ - 48, - 59, - 10, - 36, - 58, - 60 - ], - [ - 17, - 7, - 28, - 31, - 29, - 27 - ], - [ - 47, - 42, - 50, - 6, - 8, - 14 - ], - [ - 39, - 58, - 56, - 37, - 18, - 59 - ], - [ - 60, - 18, - 57, - 9, - 55, - 23 - ], - [ - 43, - 63, - 18, - 60, - 19, - 22 - ], - [ - 1, - 32, - 42, - 57, - 35, - 63 - ], - [ - 42, - 61, - 3, - 32, - 1, - 50 - ], - [ - 37, - 36, - 10, - 23, - 16, - 57 - ], - [ - 53, - 61, - 7, - 57, - 21, - 23 - ], - [ - 9, - 39, - 30, - 18, - 
14, - 17 - ], - [ - 33, - 44, - 8, - 5, - 0, - 19 - ], - [ - 51, - 53, - 49, - 4, - 52, - 41 - ], - [ - 40, - 4, - 27, - 6, - 9, - 16 - ], - [ - 3, - 16, - 48, - 26, - 12, - 4 - ], - [ - 61, - 14, - 12, - 54, - 35, - 53 - ] - ], - [ - [ - 47, - 23, - 63, - 11, - 61, - 55 - ], - [ - 17, - 44, - 28, - 39, - 47, - 27 - ], - [ - 34, - 53, - 50, - 38, - 29, - 5 - ], - [ - 11, - 10, - 17, - 52, - 47, - 42 - ], - [ - 15, - 41, - 27, - 20, - 12, - 6 - ], - [ - 34, - 44, - 50, - 39, - 36, - 61 - ], - [ - 37, - 41, - 52, - 29, - 46, - 47 - ], - [ - 62, - 36, - 34, - 30, - 39, - 22 - ], - [ - 62, - 16, - 58, - 5, - 2, - 8 - ], - [ - 32, - 41, - 56, - 12, - 46, - 8 - ], - [ - 10, - 35, - 45, - 41, - 3, - 56 - ], - [ - 50, - 48, - 35, - 53, - 36, - 12 - ], - [ - 39, - 11, - 46, - 7, - 23, - 51 - ], - [ - 9, - 47, - 19, - 22, - 52, - 34 - ], - [ - 35, - 18, - 56, - 50, - 3, - 23 - ], - [ - 3, - 60, - 38, - 36, - 9, - 35 - ], - [ - 46, - 28, - 32, - 5, - 43, - 56 - ], - [ - 30, - 32, - 57, - 42, - 52, - 19 - ], - [ - 2, - 32, - 34, - 61, - 14, - 42 - ], - [ - 42, - 37, - 20, - 50, - 9, - 48 - ], - [ - 53, - 7, - 56, - 25, - 60, - 13 - ], - [ - 17, - 39, - 14, - 53, - 30, - 25 - ], - [ - 5, - 40, - 6, - 33, - 29, - 25 - ], - [ - 51, - 4, - 11, - 58, - 57, - 28 - ], - [ - 40, - 37, - 4, - 44, - 8, - 48 - ], - [ - 3, - 48, - 26, - 9, - 12, - 41 - ], - [ - 47, - 61, - 26, - 24, - 20, - 53 - ] - ], - [ - [ - 31, - 43, - 41, - 47, - 11, - 25 - ], - [ - 50, - 25, - 31, - 40, - 24, - 46 - ], - [ - 23, - 9, - 62, - 15, - 20, - 53 - ], - [ - 4, - 47, - 44, - 58, - 48, - 25 - ], - [ - 2, - 19, - 12, - 52, - 0, - 40 - ], - [ - 49, - 15, - 24, - 34, - 60, - 42 - ], - [ - 12, - 46, - 17, - 29, - 41, - 3 - ], - [ - 39, - 60, - 44, - 41, - 33, - 36 - ], - [ - 21, - 60, - 16, - 44, - 51, - 57 - ], - [ - 24, - 41, - 12, - 33, - 13, - 21 - ], - [ - 43, - 62, - 3, - 12, - 28, - 45 - ], - [ - 58, - 19, - 39, - 17, - 49, - 42 - ], - [ - 25, - 54, - 4, - 7, - 11, - 39 - ], - [ - 34, - 35, - 4, - 42, - 62, - 19 - ], - [ - 43, - 41, - 42, - 35, - 40, - 32 - ], - [ - 21, - 63, - 3, - 17, - 20, - 50 - ], - [ - 58, - 46, - 44, - 1, - 25, - 20 - ], - [ - 26, - 32, - 16, - 25, - 46, - 41 - ], - [ - 37, - 63, - 61, - 28, - 24, - 56 - ], - [ - 33, - 42, - 40, - 37, - 48, - 50 - ], - [ - 11, - 53, - 25, - 39, - 4, - 61 - ], - [ - 12, - 54, - 4, - 27, - 50, - 14 - ], - [ - 47, - 19, - 42, - 17, - 35, - 40 - ], - [ - 54, - 40, - 60, - 63, - 45, - 57 - ], - [ - 44, - 56, - 40, - 62, - 37, - 3 - ], - [ - 59, - 41, - 57, - 34, - 48, - 22 - ], - [ - 56, - 13, - 59, - 51, - 26, - 58 - ] - ], - [ - [ - 35, - 32, - 8, - 40, - 51, - 52 - ], - [ - 52, - 5, - 22, - 21, - 6, - 33 - ], - [ - 22, - 58, - 11, - 25, - 3, - 51 - ], - [ - 63, - 2, - 56, - 4, - 23, - 54 - ], - [ - 39, - 12, - 23, - 32, - 30, - 46 - ], - [ - 50, - 34, - 36, - 58, - 26, - 28 - ], - [ - 46, - 41, - 3, - 2, - 22, - 16 - ], - [ - 60, - 36, - 53, - 30, - 54, - 39 - ], - [ - 16, - 51, - 3, - 39, - 2, - 26 - ], - [ - 45, - 26, - 18, - 41, - 32, - 46 - ], - [ - 45, - 3, - 10, - 56, - 36, - 35 - ], - [ - 21, - 36, - 35, - 50, - 11, - 19 - ], - [ - 28, - 11, - 46, - 59, - 41, - 15 - ], - [ - 23, - 16, - 38, - 19, - 15, - 22 - ], - [ - 27, - 7, - 34, - 58, - 3, - 42 - ], - [ - 9, - 22, - 36, - 46, - 26, - 41 - ], - [ - 3, - 51, - 40, - 56, - 46, - 8 - ], - [ - 12, - 25, - 21, - 50, - 17, - 62 - ], - [ - 27, - 34, - 61, - 13, - 60, - 11 - ], - [ - 53, - 12, - 56, - 0, - 42, - 33 - ], - [ - 53, - 37, - 12, - 24, - 25, - 63 - ], - [ - 45, - 55, - 18, - 26, - 17, - 43 - ], - [ - 4, - 25, 
- 32, - 1, - 48, - 53 - ], - [ - 17, - 27, - 63, - 4, - 62, - 31 - ], - [ - 6, - 52, - 62, - 40, - 46, - 23 - ], - [ - 10, - 42, - 28, - 49, - 3, - 53 - ], - [ - 45, - 27, - 41, - 21, - 16, - 47 - ] - ], - [ - [ - 44, - 24, - 33, - 56, - 15, - 41 - ], - [ - 38, - 26, - 24, - 29, - 19, - 53 - ], - [ - 12, - 15, - 29, - 9, - 1, - 63 - ], - [ - 58, - 38, - 50, - 0, - 43, - 61 - ], - [ - 24, - 51, - 31, - 34, - 60, - 7 - ], - [ - 0, - 7, - 22, - 43, - 35, - 1 - ], - [ - 63, - 36, - 11, - 1, - 16, - 4 - ], - [ - 8, - 50, - 56, - 4, - 30, - 55 - ], - [ - 43, - 16, - 42, - 29, - 60, - 35 - ], - [ - 34, - 0, - 9, - 22, - 18, - 26 - ], - [ - 54, - 51, - 45, - 35, - 2, - 36 - ], - [ - 37, - 36, - 43, - 60, - 11, - 59 - ], - [ - 56, - 38, - 10, - 28, - 14, - 43 - ], - [ - 30, - 0, - 58, - 62, - 22, - 19 - ], - [ - 7, - 55, - 42, - 58, - 30, - 38 - ], - [ - 11, - 33, - 1, - 39, - 19, - 16 - ], - [ - 55, - 20, - 40, - 9, - 18, - 30 - ], - [ - 18, - 20, - 57, - 32, - 45, - 1 - ], - [ - 43, - 61, - 12, - 32, - 31, - 30 - ], - [ - 23, - 25, - 7, - 28, - 40, - 19 - ], - [ - 14, - 51, - 48, - 58, - 53, - 25 - ], - [ - 18, - 30, - 1, - 49, - 41, - 9 - ], - [ - 2, - 51, - 22, - 0, - 52, - 5 - ], - [ - 53, - 4, - 47, - 52, - 51, - 40 - ], - [ - 40, - 16, - 9, - 47, - 23, - 11 - ], - [ - 47, - 3, - 43, - 46, - 26, - 53 - ], - [ - 8, - 40, - 18, - 46, - 33, - 63 - ] - ], - [ - [ - 48, - 38, - 50, - 42, - 63, - 36 - ], - [ - 3, - 10, - 26, - 2, - 6, - 61 - ], - [ - 39, - 44, - 45, - 40, - 6, - 7 - ], - [ - 41, - 5, - 20, - 49, - 56, - 13 - ], - [ - 6, - 1, - 30, - 37, - 28, - 38 - ], - [ - 59, - 46, - 22, - 35, - 61, - 0 - ], - [ - 1, - 63, - 35, - 3, - 60, - 49 - ], - [ - 8, - 12, - 2, - 50, - 5, - 55 - ], - [ - 42, - 33, - 43, - 16, - 32, - 29 - ], - [ - 9, - 34, - 0, - 20, - 41, - 31 - ], - [ - 51, - 54, - 8, - 19, - 63, - 9 - ], - [ - 37, - 56, - 36, - 11, - 59, - 43 - ], - [ - 38, - 10, - 28, - 17, - 56, - 63 - ], - [ - 27, - 30, - 42, - 19, - 0, - 22 - ], - [ - 7, - 55, - 49, - 42, - 58, - 38 - ], - [ - 29, - 34, - 39, - 33, - 47, - 11 - ], - [ - 55, - 40, - 20, - 18, - 7, - 5 - ], - [ - 43, - 57, - 39, - 54, - 48, - 28 - ], - [ - 12, - 43, - 61, - 42, - 32, - 49 - ], - [ - 23, - 36, - 1, - 7, - 59, - 28 - ], - [ - 14, - 53, - 21, - 7, - 57, - 37 - ], - [ - 18, - 1, - 24, - 60, - 30, - 9 - ], - [ - 51, - 0, - 33, - 2, - 44, - 5 - ], - [ - 52, - 29, - 4, - 41, - 54, - 58 - ], - [ - 40, - 19, - 16, - 9, - 46, - 47 - ], - [ - 61, - 3, - 47, - 22, - 21, - 53 - ], - [ - 35, - 60, - 54, - 1, - 5, - 40 - ] - ], - [ - [ - 17, - 18, - 8, - 53, - 25, - 43 - ], - [ - 9, - 38, - 24, - 47, - 25, - 63 - ], - [ - 20, - 24, - 5, - 12, - 54, - 28 - ], - [ - 43, - 10, - 20, - 42, - 11, - 8 - ], - [ - 53, - 61, - 30, - 39, - 29, - 18 - ], - [ - 61, - 56, - 25, - 40, - 5, - 22 - ], - [ - 62, - 17, - 24, - 1, - 47, - 33 - ], - [ - 41, - 16, - 34, - 39, - 29, - 8 - ], - [ - 39, - 16, - 36, - 42, - 29, - 23 - ], - [ - 9, - 11, - 41, - 63, - 56, - 31 - ], - [ - 48, - 51, - 10, - 62, - 63, - 45 - ], - [ - 36, - 11, - 37, - 42, - 58, - 46 - ], - [ - 51, - 38, - 25, - 63, - 29, - 44 - ], - [ - 4, - 56, - 44, - 62, - 58, - 30 - ], - [ - 3, - 7, - 46, - 42, - 33, - 35 - ], - [ - 39, - 9, - 33, - 58, - 60, - 29 - ], - [ - 40, - 37, - 20, - 16, - 55, - 25 - ], - [ - 54, - 19, - 11, - 57, - 0, - 39 - ], - [ - 12, - 43, - 61, - 25, - 49, - 32 - ], - [ - 4, - 23, - 54, - 36, - 7, - 28 - ], - [ - 40, - 25, - 26, - 14, - 2, - 58 - ], - [ - 18, - 58, - 24, - 1, - 22, - 46 - ], - [ - 2, - 63, - 22, - 6, - 44, - 56 - ], - [ - 52, - 29, - 51, - 
4, - 40, - 32 - ], - [ - 40, - 17, - 15, - 16, - 46, - 57 - ], - [ - 9, - 61, - 3, - 47, - 24, - 11 - ], - [ - 2, - 39, - 24, - 42, - 0, - 44 - ] - ], - [ - [ - 0, - 10, - 49, - 23, - 62, - 44 - ], - [ - 28, - 0, - 36, - 26, - 47, - 52 - ], - [ - 30, - 4, - 16, - 48, - 40, - 10 - ], - [ - 61, - 32, - 26, - 16, - 33, - 62 - ], - [ - 30, - 39, - 53, - 5, - 57, - 20 - ], - [ - 5, - 37, - 61, - 15, - 25, - 6 - ], - [ - 15, - 17, - 24, - 60, - 49, - 62 - ], - [ - 34, - 39, - 61, - 0, - 58, - 40 - ], - [ - 16, - 39, - 36, - 51, - 2, - 29 - ], - [ - 9, - 11, - 41, - 31, - 56, - 52 - ], - [ - 10, - 48, - 24, - 45, - 62, - 51 - ], - [ - 11, - 38, - 36, - 37, - 6, - 42 - ], - [ - 51, - 50, - 15, - 30, - 25, - 38 - ], - [ - 4, - 19, - 24, - 35, - 31, - 48 - ], - [ - 7, - 46, - 3, - 58, - 30, - 41 - ], - [ - 58, - 9, - 39, - 32, - 29, - 40 - ], - [ - 40, - 37, - 20, - 8, - 25, - 55 - ], - [ - 19, - 0, - 54, - 52, - 17, - 39 - ], - [ - 25, - 43, - 12, - 61, - 14, - 11 - ], - [ - 23, - 4, - 54, - 36, - 28, - 33 - ], - [ - 40, - 2, - 25, - 58, - 36, - 53 - ], - [ - 18, - 46, - 35, - 22, - 53, - 16 - ], - [ - 2, - 6, - 63, - 14, - 42, - 11 - ], - [ - 35, - 7, - 52, - 40, - 29, - 57 - ], - [ - 40, - 15, - 19, - 57, - 17, - 23 - ], - [ - 9, - 11, - 47, - 22, - 49, - 1 - ], - [ - 24, - 39, - 42, - 2, - 16, - 0 - ] - ], - [ - [ - 55, - 39, - 9, - 43, - 21, - 46 - ], - [ - 56, - 0, - 63, - 39, - 30, - 41 - ], - [ - 20, - 1, - 26, - 58, - 34, - 19 - ], - [ - 54, - 24, - 32, - 51, - 26, - 44 - ], - [ - 30, - 53, - 56, - 39, - 34, - 40 - ], - [ - 5, - 37, - 25, - 50, - 6, - 61 - ], - [ - 24, - 49, - 37, - 15, - 6, - 29 - ], - [ - 34, - 16, - 30, - 61, - 10, - 36 - ], - [ - 16, - 29, - 2, - 5, - 51, - 26 - ], - [ - 9, - 56, - 11, - 31, - 46, - 45 - ], - [ - 10, - 45, - 56, - 62, - 25, - 36 - ], - [ - 11, - 6, - 35, - 36, - 1, - 52 - ], - [ - 51, - 50, - 41, - 46, - 38, - 4 - ], - [ - 19, - 33, - 41, - 16, - 31, - 52 - ], - [ - 34, - 7, - 17, - 47, - 63, - 3 - ], - [ - 58, - 9, - 22, - 61, - 59, - 8 - ], - [ - 40, - 37, - 3, - 51, - 22, - 57 - ], - [ - 12, - 52, - 21, - 54, - 25, - 19 - ], - [ - 34, - 53, - 27, - 43, - 14, - 13 - ], - [ - 56, - 44, - 53, - 24, - 60, - 43 - ], - [ - 12, - 53, - 40, - 49, - 2, - 62 - ], - [ - 18, - 39, - 44, - 61, - 26, - 23 - ], - [ - 0, - 4, - 53, - 25, - 41, - 21 - ], - [ - 1, - 7, - 25, - 10, - 40, - 56 - ], - [ - 40, - 22, - 6, - 29, - 19, - 48 - ], - [ - 28, - 10, - 47, - 55, - 42, - 44 - ], - [ - 30, - 27, - 57, - 16, - 50, - 59 - ] - ], - [ - [ - 45, - 37, - 48, - 29, - 30, - 3 - ], - [ - 8, - 60, - 10, - 59, - 43, - 6 - ], - [ - 51, - 45, - 28, - 59, - 63, - 34 - ], - [ - 4, - 16, - 20, - 58, - 44, - 28 - ], - [ - 50, - 31, - 57, - 24, - 51, - 53 - ], - [ - 58, - 9, - 0, - 61, - 35, - 41 - ], - [ - 16, - 63, - 11, - 61, - 23, - 36 - ], - [ - 4, - 47, - 42, - 53, - 8, - 30 - ], - [ - 44, - 14, - 16, - 33, - 3, - 20 - ], - [ - 34, - 28, - 26, - 57, - 22, - 18 - ], - [ - 20, - 35, - 19, - 59, - 2, - 38 - ], - [ - 12, - 60, - 43, - 63, - 32, - 62 - ], - [ - 28, - 12, - 29, - 11, - 14, - 50 - ], - [ - 23, - 29, - 33, - 22, - 11, - 19 - ], - [ - 23, - 60, - 51, - 50, - 7, - 22 - ], - [ - 44, - 46, - 49, - 7, - 1, - 12 - ], - [ - 2, - 54, - 27, - 61, - 18, - 5 - ], - [ - 17, - 50, - 51, - 32, - 33, - 34 - ], - [ - 5, - 19, - 61, - 27, - 32, - 11 - ], - [ - 6, - 0, - 5, - 13, - 41, - 57 - ], - [ - 27, - 33, - 53, - 45, - 38, - 32 - ], - [ - 26, - 36, - 55, - 59, - 61, - 18 - ], - [ - 47, - 46, - 3, - 37, - 57, - 49 - ], - [ - 20, - 22, - 4, - 16, - 51, - 11 - ], - [ - 62, - 11, - 
21, - 34, - 4, - 1 - ], - [ - 34, - 18, - 7, - 60, - 33, - 32 - ], - [ - 45, - 52, - 4, - 36, - 21, - 9 - ] - ], - [ - [ - 18, - 8, - 20, - 49, - 30, - 23 - ], - [ - 1, - 27, - 26, - 22, - 59, - 36 - ], - [ - 43, - 26, - 15, - 58, - 0, - 46 - ], - [ - 55, - 1, - 35, - 28, - 16, - 32 - ], - [ - 59, - 9, - 10, - 53, - 12, - 58 - ], - [ - 9, - 2, - 27, - 11, - 61, - 43 - ], - [ - 16, - 57, - 63, - 23, - 19, - 12 - ], - [ - 46, - 45, - 26, - 4, - 30, - 37 - ], - [ - 43, - 44, - 20, - 16, - 14, - 9 - ], - [ - 34, - 47, - 42, - 43, - 26, - 51 - ], - [ - 42, - 2, - 38, - 45, - 20, - 36 - ], - [ - 18, - 7, - 12, - 2, - 43, - 60 - ], - [ - 1, - 28, - 12, - 3, - 29, - 33 - ], - [ - 25, - 13, - 0, - 63, - 2, - 62 - ], - [ - 18, - 36, - 6, - 29, - 19, - 15 - ], - [ - 1, - 42, - 63, - 41, - 57, - 19 - ], - [ - 57, - 54, - 5, - 27, - 18, - 31 - ], - [ - 50, - 6, - 13, - 32, - 17, - 20 - ], - [ - 17, - 5, - 27, - 32, - 1, - 55 - ], - [ - 49, - 0, - 61, - 5, - 10, - 30 - ], - [ - 29, - 53, - 51, - 13, - 33, - 46 - ], - [ - 29, - 17, - 21, - 30, - 14, - 40 - ], - [ - 5, - 17, - 33, - 32, - 18, - 28 - ], - [ - 51, - 4, - 20, - 54, - 58, - 41 - ], - [ - 47, - 4, - 27, - 48, - 37, - 60 - ], - [ - 3, - 26, - 12, - 59, - 2, - 48 - ], - [ - 46, - 43, - 18, - 20, - 9, - 53 - ] - ], - [ - [ - 45, - 6, - 57, - 43, - 40, - 58 - ], - [ - 38, - 63, - 36, - 27, - 54, - 33 - ], - [ - 37, - 14, - 19, - 41, - 58, - 63 - ], - [ - 9, - 12, - 2, - 55, - 28, - 23 - ], - [ - 39, - 59, - 7, - 13, - 33, - 43 - ], - [ - 45, - 9, - 63, - 27, - 32, - 58 - ], - [ - 16, - 57, - 10, - 63, - 11, - 23 - ], - [ - 51, - 45, - 25, - 4, - 21, - 30 - ], - [ - 21, - 44, - 14, - 16, - 33, - 39 - ], - [ - 42, - 44, - 43, - 5, - 37, - 34 - ], - [ - 42, - 19, - 20, - 2, - 38, - 61 - ], - [ - 4, - 12, - 2, - 62, - 63, - 36 - ], - [ - 32, - 55, - 0, - 11, - 47, - 28 - ], - [ - 43, - 13, - 2, - 44, - 26, - 50 - ], - [ - 49, - 33, - 15, - 28, - 29, - 35 - ], - [ - 44, - 41, - 7, - 2, - 22, - 63 - ], - [ - 48, - 6, - 54, - 20, - 2, - 27 - ], - [ - 50, - 51, - 32, - 3, - 17, - 36 - ], - [ - 5, - 61, - 57, - 48, - 19, - 32 - ], - [ - 21, - 0, - 6, - 31, - 29, - 47 - ], - [ - 33, - 9, - 53, - 27, - 17, - 36 - ], - [ - 29, - 26, - 55, - 19, - 17, - 62 - ], - [ - 12, - 46, - 5, - 37, - 57, - 3 - ], - [ - 20, - 51, - 4, - 22, - 16, - 41 - ], - [ - 21, - 11, - 62, - 46, - 23, - 48 - ], - [ - 32, - 60, - 37, - 18, - 3, - 7 - ], - [ - 9, - 11, - 36, - 48, - 0, - 45 - ] - ], - [ - [ - 49, - 42, - 28, - 23, - 33, - 61 - ], - [ - 4, - 2, - 12, - 6, - 8, - 55 - ], - [ - 12, - 0, - 26, - 41, - 6, - 27 - ], - [ - 9, - 57, - 6, - 23, - 51, - 28 - ], - [ - 40, - 7, - 20, - 16, - 15, - 33 - ], - [ - 45, - 59, - 63, - 62, - 32, - 3 - ], - [ - 10, - 39, - 57, - 13, - 16, - 19 - ], - [ - 45, - 23, - 51, - 33, - 25, - 46 - ], - [ - 28, - 21, - 44, - 11, - 16, - 59 - ], - [ - 5, - 42, - 44, - 24, - 43, - 47 - ], - [ - 42, - 53, - 30, - 18, - 2, - 27 - ], - [ - 2, - 62, - 4, - 43, - 10, - 36 - ], - [ - 0, - 56, - 55, - 47, - 32, - 49 - ], - [ - 43, - 25, - 2, - 5, - 3, - 49 - ], - [ - 4, - 28, - 15, - 8, - 49, - 58 - ], - [ - 2, - 42, - 44, - 41, - 7, - 63 - ], - [ - 48, - 27, - 54, - 20, - 2, - 18 - ], - [ - 50, - 51, - 0, - 36, - 3, - 32 - ], - [ - 57, - 5, - 61, - 19, - 32, - 38 - ], - [ - 21, - 0, - 6, - 63, - 23, - 51 - ], - [ - 33, - 53, - 27, - 36, - 9, - 38 - ], - [ - 29, - 26, - 55, - 62, - 18, - 31 - ], - [ - 46, - 56, - 12, - 53, - 29, - 0 - ], - [ - 20, - 16, - 22, - 4, - 51, - 17 - ], - [ - 21, - 62, - 11, - 31, - 46, - 33 - ], - [ - 37, - 60, - 18, - 7, - 32, - 
44 - ], - [ - 11, - 9, - 36, - 0, - 48, - 63 - ] - ], - [ - [ - 41, - 32, - 49, - 39, - 61, - 44 - ], - [ - 47, - 26, - 16, - 21, - 36, - 22 - ], - [ - 4, - 30, - 37, - 42, - 60, - 54 - ], - [ - 9, - 57, - 26, - 32, - 50, - 20 - ], - [ - 56, - 3, - 40, - 33, - 36, - 54 - ], - [ - 11, - 38, - 2, - 32, - 61, - 30 - ], - [ - 39, - 57, - 19, - 10, - 16, - 42 - ], - [ - 46, - 21, - 35, - 39, - 45, - 25 - ], - [ - 21, - 37, - 12, - 20, - 11, - 28 - ], - [ - 5, - 47, - 44, - 10, - 42, - 23 - ], - [ - 18, - 42, - 61, - 2, - 38, - 31 - ], - [ - 54, - 4, - 2, - 7, - 22, - 16 - ], - [ - 5, - 3, - 17, - 56, - 32, - 55 - ], - [ - 55, - 0, - 2, - 25, - 43, - 5 - ], - [ - 22, - 28, - 15, - 6, - 5, - 49 - ], - [ - 2, - 57, - 19, - 54, - 41, - 30 - ], - [ - 7, - 48, - 20, - 54, - 27, - 0 - ], - [ - 3, - 56, - 13, - 37, - 43, - 59 - ], - [ - 45, - 55, - 57, - 61, - 48, - 52 - ], - [ - 21, - 5, - 0, - 16, - 27, - 23 - ], - [ - 25, - 42, - 17, - 54, - 23, - 14 - ], - [ - 21, - 44, - 15, - 20, - 42, - 18 - ], - [ - 35, - 12, - 25, - 53, - 61, - 2 - ], - [ - 38, - 54, - 48, - 53, - 21, - 36 - ], - [ - 13, - 31, - 48, - 33, - 18, - 55 - ], - [ - 38, - 27, - 19, - 6, - 44, - 3 - ], - [ - 29, - 62, - 43, - 59, - 46, - 5 - ] - ], - [ - [ - 57, - 9, - 19, - 51, - 18, - 41 - ], - [ - 28, - 57, - 36, - 8, - 48, - 60 - ], - [ - 2, - 51, - 59, - 5, - 34, - 9 - ], - [ - 9, - 55, - 59, - 26, - 4, - 2 - ], - [ - 49, - 56, - 35, - 42, - 30, - 23 - ], - [ - 18, - 30, - 22, - 29, - 19, - 52 - ], - [ - 39, - 34, - 33, - 51, - 56, - 3 - ], - [ - 32, - 21, - 1, - 7, - 46, - 49 - ], - [ - 33, - 54, - 23, - 21, - 12, - 11 - ], - [ - 5, - 30, - 60, - 47, - 15, - 18 - ], - [ - 4, - 18, - 46, - 27, - 20, - 22 - ], - [ - 22, - 59, - 54, - 48, - 19, - 4 - ], - [ - 17, - 5, - 56, - 31, - 49, - 4 - ], - [ - 29, - 47, - 55, - 2, - 53, - 60 - ], - [ - 8, - 22, - 11, - 44, - 36, - 15 - ], - [ - 60, - 44, - 30, - 57, - 54, - 39 - ], - [ - 7, - 44, - 27, - 20, - 2, - 61 - ], - [ - 48, - 17, - 21, - 37, - 32, - 57 - ], - [ - 48, - 32, - 46, - 6, - 61, - 42 - ], - [ - 4, - 57, - 1, - 36, - 0, - 30 - ], - [ - 7, - 17, - 61, - 53, - 21, - 63 - ], - [ - 60, - 14, - 53, - 35, - 18, - 55 - ], - [ - 10, - 15, - 33, - 51, - 36, - 5 - ], - [ - 11, - 4, - 19, - 51, - 21, - 52 - ], - [ - 47, - 19, - 43, - 48, - 58, - 4 - ], - [ - 3, - 33, - 26, - 21, - 52, - 19 - ], - [ - 24, - 45, - 60, - 35, - 49, - 1 - ] - ], - [ - [ - 16, - 4, - 44, - 23, - 22, - 43 - ], - [ - 16, - 23, - 50, - 9, - 32, - 13 - ], - [ - 11, - 35, - 21, - 7, - 48, - 59 - ], - [ - 55, - 15, - 1, - 11, - 8, - 40 - ], - [ - 35, - 61, - 30, - 59, - 31, - 62 - ], - [ - 51, - 29, - 15, - 52, - 38, - 61 - ], - [ - 60, - 0, - 55, - 34, - 59, - 33 - ], - [ - 12, - 22, - 56, - 63, - 54, - 55 - ], - [ - 42, - 54, - 23, - 33, - 27, - 47 - ], - [ - 30, - 60, - 20, - 5, - 4, - 22 - ], - [ - 4, - 35, - 22, - 46, - 23, - 19 - ], - [ - 29, - 22, - 59, - 49, - 24, - 28 - ], - [ - 17, - 31, - 5, - 56, - 4, - 9 - ], - [ - 61, - 29, - 0, - 48, - 59, - 50 - ], - [ - 8, - 6, - 22, - 60, - 55, - 31 - ], - [ - 60, - 39, - 19, - 57, - 53, - 1 - ], - [ - 56, - 30, - 22, - 10, - 5, - 18 - ], - [ - 20, - 31, - 1, - 26, - 61, - 37 - ], - [ - 8, - 23, - 7, - 46, - 48, - 4 - ], - [ - 45, - 23, - 51, - 13, - 17, - 4 - ], - [ - 17, - 13, - 61, - 29, - 14, - 55 - ], - [ - 14, - 27, - 21, - 43, - 57, - 56 - ], - [ - 52, - 51, - 15, - 58, - 8, - 5 - ], - [ - 4, - 51, - 49, - 14, - 21, - 34 - ], - [ - 4, - 9, - 47, - 13, - 8, - 61 - ], - [ - 27, - 3, - 16, - 43, - 31, - 47 - ], - [ - 59, - 43, - 29, - 61, - 0, - 40 - ] - 
], - [ - [ - 48, - 21, - 18, - 49, - 41, - 23 - ], - [ - 36, - 4, - 60, - 8, - 49, - 44 - ], - [ - 20, - 39, - 30, - 59, - 45, - 55 - ], - [ - 35, - 46, - 15, - 48, - 33, - 2 - ], - [ - 61, - 37, - 8, - 15, - 54, - 10 - ], - [ - 46, - 6, - 51, - 29, - 58, - 4 - ], - [ - 28, - 11, - 44, - 60, - 0, - 1 - ], - [ - 63, - 12, - 13, - 27, - 10, - 0 - ], - [ - 33, - 42, - 54, - 44, - 23, - 14 - ], - [ - 18, - 60, - 30, - 22, - 40, - 14 - ], - [ - 35, - 4, - 61, - 9, - 18, - 33 - ], - [ - 59, - 45, - 48, - 28, - 62, - 22 - ], - [ - 17, - 56, - 7, - 5, - 53, - 36 - ], - [ - 8, - 47, - 29, - 59, - 1, - 6 - ], - [ - 9, - 8, - 18, - 22, - 60, - 15 - ], - [ - 46, - 60, - 22, - 44, - 30, - 57 - ], - [ - 61, - 2, - 27, - 34, - 7, - 60 - ], - [ - 48, - 21, - 37, - 17, - 50, - 57 - ], - [ - 44, - 42, - 5, - 2, - 48, - 61 - ], - [ - 37, - 57, - 4, - 36, - 17, - 59 - ], - [ - 7, - 17, - 44, - 61, - 53, - 33 - ], - [ - 60, - 45, - 14, - 42, - 18, - 9 - ], - [ - 10, - 8, - 36, - 33, - 15, - 58 - ], - [ - 19, - 20, - 11, - 4, - 49, - 51 - ], - [ - 61, - 37, - 47, - 23, - 12, - 3 - ], - [ - 17, - 33, - 3, - 40, - 19, - 26 - ], - [ - 61, - 45, - 49, - 1, - 14, - 63 - ] - ], - [ - [ - 13, - 40, - 55, - 63, - 26, - 41 - ], - [ - 5, - 35, - 49, - 40, - 17, - 46 - ], - [ - 38, - 17, - 59, - 49, - 2, - 58 - ], - [ - 40, - 8, - 1, - 16, - 0, - 11 - ], - [ - 37, - 62, - 51, - 10, - 8, - 38 - ], - [ - 9, - 42, - 61, - 29, - 35, - 33 - ], - [ - 63, - 53, - 11, - 16, - 33, - 60 - ], - [ - 63, - 37, - 5, - 13, - 17, - 39 - ], - [ - 44, - 20, - 31, - 54, - 38, - 21 - ], - [ - 43, - 21, - 30, - 34, - 18, - 49 - ], - [ - 2, - 11, - 19, - 35, - 4, - 9 - ], - [ - 22, - 60, - 43, - 2, - 4, - 49 - ], - [ - 29, - 5, - 17, - 22, - 24, - 55 - ], - [ - 2, - 59, - 29, - 5, - 55, - 6 - ], - [ - 23, - 8, - 36, - 22, - 15, - 30 - ], - [ - 12, - 44, - 41, - 5, - 45, - 22 - ], - [ - 54, - 7, - 41, - 11, - 1, - 53 - ], - [ - 6, - 50, - 2, - 9, - 21, - 37 - ], - [ - 13, - 19, - 5, - 10, - 48, - 8 - ], - [ - 41, - 6, - 32, - 21, - 0, - 47 - ], - [ - 38, - 33, - 36, - 53, - 31, - 61 - ], - [ - 3, - 26, - 7, - 62, - 18, - 59 - ], - [ - 56, - 57, - 46, - 12, - 35, - 3 - ], - [ - 20, - 16, - 22, - 24, - 27, - 42 - ], - [ - 36, - 21, - 46, - 34, - 3, - 11 - ], - [ - 33, - 34, - 60, - 45, - 7, - 32 - ], - [ - 56, - 34, - 52, - 58, - 26, - 48 - ] - ], - [ - [ - 54, - 23, - 53, - 11, - 58, - 3 - ], - [ - 11, - 30, - 59, - 58, - 63, - 4 - ], - [ - 20, - 29, - 58, - 17, - 42, - 4 - ], - [ - 1, - 35, - 40, - 45, - 53, - 21 - ], - [ - 40, - 55, - 33, - 21, - 38, - 49 - ], - [ - 45, - 29, - 61, - 27, - 63, - 62 - ], - [ - 33, - 57, - 11, - 28, - 53, - 34 - ], - [ - 11, - 63, - 39, - 10, - 45, - 14 - ], - [ - 30, - 54, - 57, - 59, - 33, - 26 - ], - [ - 43, - 23, - 5, - 18, - 21, - 42 - ], - [ - 11, - 18, - 2, - 9, - 34, - 6 - ], - [ - 22, - 60, - 28, - 2, - 63, - 17 - ], - [ - 5, - 41, - 6, - 17, - 56, - 29 - ], - [ - 55, - 6, - 5, - 2, - 48, - 59 - ], - [ - 23, - 19, - 22, - 62, - 11, - 9 - ], - [ - 12, - 45, - 41, - 8, - 27, - 42 - ], - [ - 11, - 53, - 41, - 44, - 51, - 7 - ], - [ - 21, - 2, - 6, - 36, - 50, - 56 - ], - [ - 13, - 10, - 48, - 53, - 61, - 19 - ], - [ - 47, - 21, - 56, - 44, - 6, - 31 - ], - [ - 44, - 12, - 3, - 55, - 41, - 53 - ], - [ - 44, - 47, - 28, - 43, - 45, - 63 - ], - [ - 1, - 25, - 53, - 11, - 39, - 19 - ], - [ - 1, - 59, - 38, - 3, - 37, - 42 - ], - [ - 45, - 3, - 0, - 21, - 22, - 33 - ], - [ - 10, - 28, - 42, - 49, - 11, - 3 - ], - [ - 30, - 57, - 15, - 16, - 56, - 41 - ] - ], - [ - [ - 53, - 15, - 34, - 0, - 46, - 33 - ], - [ - 
8, - 12, - 41, - 19, - 39, - 32 - ], - [ - 56, - 31, - 36, - 13, - 23, - 9 - ], - [ - 36, - 51, - 30, - 21, - 1, - 11 - ], - [ - 13, - 58, - 50, - 2, - 53, - 54 - ], - [ - 49, - 52, - 32, - 7, - 23, - 47 - ], - [ - 61, - 38, - 23, - 39, - 0, - 35 - ], - [ - 42, - 27, - 9, - 20, - 17, - 57 - ], - [ - 34, - 1, - 29, - 4, - 35, - 45 - ], - [ - 54, - 57, - 27, - 19, - 38, - 62 - ], - [ - 59, - 1, - 60, - 26, - 38, - 22 - ], - [ - 25, - 31, - 51, - 36, - 32, - 8 - ], - [ - 14, - 62, - 2, - 19, - 37, - 11 - ], - [ - 57, - 40, - 13, - 22, - 37, - 46 - ], - [ - 45, - 34, - 58, - 44, - 42, - 16 - ], - [ - 50, - 16, - 6, - 5, - 33, - 43 - ], - [ - 42, - 39, - 61, - 13, - 5, - 15 - ], - [ - 46, - 23, - 27, - 4, - 28, - 63 - ], - [ - 62, - 31, - 10, - 45, - 35, - 56 - ], - [ - 15, - 13, - 38, - 63, - 4, - 31 - ], - [ - 34, - 15, - 38, - 57, - 27, - 19 - ], - [ - 41, - 62, - 36, - 57, - 19, - 47 - ], - [ - 34, - 22, - 53, - 10, - 46, - 33 - ], - [ - 24, - 51, - 4, - 47, - 39, - 10 - ], - [ - 11, - 57, - 51, - 50, - 54, - 48 - ], - [ - 51, - 7, - 11, - 43, - 50, - 18 - ], - [ - 39, - 37, - 9, - 42, - 40, - 44 - ] - ], - [ - [ - 57, - 17, - 62, - 42, - 23, - 60 - ], - [ - 18, - 7, - 53, - 43, - 26, - 60 - ], - [ - 60, - 5, - 3, - 53, - 23, - 57 - ], - [ - 17, - 10, - 22, - 19, - 11, - 31 - ], - [ - 10, - 15, - 12, - 27, - 17, - 4 - ], - [ - 2, - 44, - 39, - 36, - 25, - 54 - ], - [ - 52, - 62, - 37, - 21, - 41, - 42 - ], - [ - 62, - 7, - 46, - 30, - 36, - 14 - ], - [ - 50, - 4, - 10, - 58, - 0, - 16 - ], - [ - 32, - 47, - 38, - 54, - 8, - 41 - ], - [ - 41, - 32, - 3, - 18, - 1, - 22 - ], - [ - 16, - 50, - 53, - 7, - 44, - 21 - ], - [ - 11, - 39, - 3, - 35, - 25, - 46 - ], - [ - 18, - 9, - 15, - 13, - 63, - 28 - ], - [ - 62, - 58, - 13, - 5, - 17, - 3 - ], - [ - 3, - 31, - 43, - 53, - 35, - 57 - ], - [ - 24, - 51, - 15, - 46, - 32, - 5 - ], - [ - 13, - 30, - 0, - 32, - 5, - 59 - ], - [ - 2, - 39, - 32, - 38, - 34, - 22 - ], - [ - 42, - 26, - 34, - 28, - 37, - 54 - ], - [ - 28, - 43, - 53, - 41, - 13, - 23 - ], - [ - 14, - 15, - 34, - 1, - 48, - 40 - ], - [ - 5, - 25, - 4, - 33, - 39, - 6 - ], - [ - 58, - 4, - 57, - 17, - 51, - 11 - ], - [ - 37, - 47, - 35, - 31, - 63, - 29 - ], - [ - 15, - 3, - 28, - 33, - 23, - 9 - ], - [ - 23, - 6, - 58, - 47, - 56, - 30 - ] - ], - [ - [ - 47, - 29, - 14, - 6, - 51, - 43 - ], - [ - 30, - 29, - 39, - 7, - 52, - 12 - ], - [ - 63, - 34, - 41, - 2, - 47, - 7 - ], - [ - 4, - 28, - 54, - 45, - 52, - 58 - ], - [ - 29, - 7, - 12, - 15, - 41, - 6 - ], - [ - 34, - 29, - 48, - 3, - 43, - 40 - ], - [ - 30, - 29, - 16, - 47, - 42, - 45 - ], - [ - 33, - 39, - 25, - 60, - 41, - 3 - ], - [ - 50, - 26, - 4, - 25, - 17, - 13 - ], - [ - 5, - 54, - 43, - 16, - 12, - 53 - ], - [ - 18, - 6, - 35, - 3, - 21, - 1 - ], - [ - 56, - 46, - 48, - 10, - 16, - 44 - ], - [ - 9, - 35, - 7, - 24, - 47, - 57 - ], - [ - 53, - 42, - 15, - 56, - 59, - 47 - ], - [ - 39, - 11, - 36, - 32, - 35, - 18 - ], - [ - 46, - 20, - 53, - 38, - 56, - 26 - ], - [ - 58, - 29, - 14, - 26, - 17, - 49 - ], - [ - 24, - 25, - 39, - 16, - 1, - 57 - ], - [ - 24, - 41, - 4, - 1, - 63, - 28 - ], - [ - 42, - 37, - 48, - 34, - 26, - 41 - ], - [ - 11, - 28, - 16, - 32, - 7, - 56 - ], - [ - 14, - 42, - 6, - 16, - 22, - 15 - ], - [ - 33, - 56, - 42, - 8, - 25, - 38 - ], - [ - 4, - 58, - 48, - 33, - 11, - 28 - ], - [ - 37, - 47, - 29, - 48, - 30, - 53 - ], - [ - 12, - 41, - 3, - 4, - 48, - 46 - ], - [ - 14, - 13, - 61, - 6, - 62, - 1 - ] - ], - [ - [ - 45, - 10, - 44, - 43, - 53, - 33 - ], - [ - 32, - 63, - 22, - 27, - 30, - 29 - ], - [ 
-  [ … large deleted hunk elided: nested JSON arrays of six-element integer rows (values 0–63) from a removed data file … ]
- 18, - 38 - ], - [ - 4, - 55, - 43, - 36, - 46, - 5 - ] - ], - [ - [ - 48, - 38, - 63, - 47, - 7, - 42 - ], - [ - 3, - 10, - 26, - 6, - 2, - 38 - ], - [ - 39, - 44, - 6, - 7, - 45, - 8 - ], - [ - 60, - 31, - 22, - 0, - 54, - 27 - ], - [ - 6, - 59, - 16, - 14, - 37, - 44 - ], - [ - 59, - 8, - 30, - 47, - 48, - 3 - ], - [ - 44, - 54, - 4, - 56, - 15, - 13 - ], - [ - 12, - 24, - 20, - 58, - 31, - 61 - ], - [ - 47, - 33, - 10, - 8, - 4, - 36 - ], - [ - 54, - 40, - 55, - 33, - 11, - 51 - ], - [ - 8, - 14, - 23, - 29, - 22, - 31 - ], - [ - 34, - 59, - 14, - 5, - 57, - 50 - ], - [ - 44, - 36, - 10, - 59, - 52, - 9 - ], - [ - 59, - 44, - 13, - 9, - 54, - 1 - ], - [ - 24, - 2, - 46, - 38, - 44, - 47 - ], - [ - 34, - 53, - 6, - 2, - 57, - 40 - ], - [ - 17, - 29, - 47, - 5, - 50, - 34 - ], - [ - 48, - 63, - 57, - 38, - 60, - 32 - ], - [ - 42, - 7, - 49, - 32, - 46, - 1 - ], - [ - 1, - 59, - 23, - 14, - 16, - 10 - ], - [ - 21, - 35, - 7, - 53, - 6, - 17 - ], - [ - 8, - 24, - 60, - 14, - 13, - 55 - ], - [ - 33, - 5, - 51, - 15, - 23, - 3 - ], - [ - 9, - 4, - 51, - 26, - 41, - 8 - ], - [ - 4, - 60, - 19, - 47, - 40, - 9 - ], - [ - 19, - 3, - 26, - 21, - 17, - 8 - ], - [ - 1, - 35, - 53, - 54, - 60, - 20 - ] - ], - [ - [ - 12, - 41, - 14, - 62, - 24, - 10 - ], - [ - 10, - 53, - 39, - 35, - 41, - 58 - ], - [ - 33, - 32, - 50, - 31, - 3, - 34 - ], - [ - 43, - 10, - 42, - 11, - 17, - 47 - ], - [ - 42, - 12, - 11, - 19, - 58, - 54 - ], - [ - 36, - 50, - 55, - 61, - 25, - 56 - ], - [ - 41, - 22, - 16, - 52, - 2, - 15 - ], - [ - 26, - 36, - 62, - 53, - 15, - 51 - ], - [ - 58, - 16, - 5, - 53, - 3, - 49 - ], - [ - 32, - 46, - 26, - 45, - 16, - 62 - ], - [ - 41, - 45, - 56, - 49, - 11, - 3 - ], - [ - 52, - 34, - 35, - 50, - 21, - 53 - ], - [ - 59, - 53, - 46, - 30, - 39, - 37 - ], - [ - 20, - 9, - 52, - 2, - 7, - 33 - ], - [ - 20, - 50, - 24, - 29, - 23, - 2 - ], - [ - 53, - 2, - 44, - 41, - 9, - 13 - ], - [ - 47, - 27, - 1, - 5, - 45, - 46 - ], - [ - 61, - 7, - 51, - 30, - 35, - 9 - ], - [ - 19, - 34, - 32, - 17, - 2, - 14 - ], - [ - 15, - 32, - 6, - 45, - 9, - 11 - ], - [ - 35, - 27, - 53, - 6, - 13, - 60 - ], - [ - 26, - 62, - 16, - 28, - 41, - 3 - ], - [ - 5, - 20, - 46, - 37, - 11, - 55 - ], - [ - 55, - 47, - 4, - 16, - 14, - 27 - ], - [ - 36, - 11, - 27, - 62, - 33, - 7 - ], - [ - 33, - 7, - 54, - 3, - 32, - 12 - ], - [ - 58, - 52, - 27, - 26, - 48, - 38 - ] - ], - [ - [ - 6, - 52, - 19, - 63, - 46, - 38 - ], - [ - 8, - 42, - 4, - 47, - 57, - 56 - ], - [ - 31, - 46, - 32, - 4, - 14, - 10 - ], - [ - 28, - 27, - 4, - 37, - 58, - 20 - ], - [ - 57, - 59, - 60, - 22, - 62, - 14 - ], - [ - 30, - 2, - 9, - 57, - 11, - 13 - ], - [ - 20, - 19, - 57, - 42, - 51, - 27 - ], - [ - 46, - 35, - 47, - 32, - 7, - 0 - ], - [ - 12, - 4, - 35, - 10, - 50, - 47 - ], - [ - 15, - 47, - 54, - 25, - 38, - 51 - ], - [ - 18, - 42, - 5, - 15, - 38, - 61 - ], - [ - 18, - 22, - 8, - 16, - 5, - 7 - ], - [ - 5, - 3, - 24, - 35, - 4, - 30 - ], - [ - 17, - 13, - 20, - 55, - 36, - 22 - ], - [ - 22, - 10, - 21, - 54, - 47, - 6 - ], - [ - 30, - 53, - 6, - 19, - 2, - 54 - ], - [ - 22, - 26, - 5, - 7, - 47, - 21 - ], - [ - 4, - 41, - 13, - 46, - 55, - 43 - ], - [ - 17, - 56, - 32, - 45, - 14, - 6 - ], - [ - 5, - 49, - 53, - 28, - 34, - 60 - ], - [ - 42, - 55, - 57, - 17, - 28, - 22 - ], - [ - 40, - 23, - 28, - 57, - 16, - 21 - ], - [ - 51, - 21, - 35, - 24, - 44, - 10 - ], - [ - 38, - 25, - 4, - 14, - 62, - 31 - ], - [ - 31, - 50, - 13, - 56, - 39, - 53 - ], - [ - 6, - 51, - 55, - 8, - 0, - 21 - ], - [ - 29, - 46, - 18, - 55, - 37, - 50 - ] - 
], - [ - [ - 31, - 46, - 49, - 59, - 35, - 14 - ], - [ - 13, - 16, - 31, - 50, - 33, - 19 - ], - [ - 13, - 36, - 27, - 52, - 3, - 19 - ], - [ - 9, - 24, - 29, - 12, - 55, - 61 - ], - [ - 35, - 17, - 2, - 10, - 41, - 48 - ], - [ - 51, - 23, - 30, - 57, - 52, - 17 - ], - [ - 34, - 48, - 0, - 38, - 27, - 51 - ], - [ - 47, - 32, - 2, - 35, - 58, - 55 - ], - [ - 15, - 12, - 35, - 45, - 52, - 4 - ], - [ - 36, - 15, - 4, - 59, - 11, - 54 - ], - [ - 60, - 5, - 59, - 50, - 29, - 18 - ], - [ - 18, - 8, - 31, - 23, - 24, - 3 - ], - [ - 34, - 4, - 5, - 58, - 14, - 30 - ], - [ - 17, - 57, - 36, - 55, - 51, - 6 - ], - [ - 10, - 45, - 8, - 54, - 22, - 47 - ], - [ - 30, - 50, - 14, - 6, - 37, - 16 - ], - [ - 26, - 0, - 22, - 5, - 13, - 4 - ], - [ - 4, - 46, - 23, - 43, - 60, - 34 - ], - [ - 45, - 8, - 56, - 62, - 17, - 51 - ], - [ - 4, - 38, - 53, - 15, - 58, - 5 - ], - [ - 57, - 24, - 34, - 15, - 55, - 42 - ], - [ - 23, - 37, - 40, - 41, - 54, - 11 - ], - [ - 48, - 29, - 51, - 31, - 34, - 9 - ], - [ - 62, - 14, - 18, - 4, - 31, - 59 - ], - [ - 50, - 31, - 49, - 45, - 56, - 57 - ], - [ - 51, - 50, - 56, - 8, - 52, - 12 - ], - [ - 17, - 37, - 5, - 39, - 44, - 20 - ] - ], - [ - [ - 45, - 13, - 63, - 37, - 38, - 56 - ], - [ - 63, - 6, - 12, - 18, - 27, - 51 - ], - [ - 3, - 21, - 4, - 48, - 17, - 27 - ], - [ - 14, - 55, - 9, - 37, - 29, - 26 - ], - [ - 35, - 57, - 2, - 13, - 41, - 10 - ], - [ - 39, - 30, - 57, - 13, - 53, - 23 - ], - [ - 34, - 48, - 38, - 27, - 56, - 51 - ], - [ - 47, - 35, - 46, - 2, - 32, - 7 - ], - [ - 12, - 56, - 4, - 35, - 52, - 50 - ], - [ - 15, - 36, - 14, - 30, - 47, - 48 - ], - [ - 5, - 60, - 59, - 22, - 15, - 46 - ], - [ - 8, - 22, - 23, - 47, - 5, - 26 - ], - [ - 4, - 5, - 30, - 58, - 26, - 8 - ], - [ - 17, - 29, - 13, - 50, - 14, - 34 - ], - [ - 54, - 10, - 41, - 14, - 21, - 5 - ], - [ - 50, - 6, - 54, - 30, - 56, - 16 - ], - [ - 0, - 47, - 7, - 4, - 5, - 34 - ], - [ - 4, - 46, - 23, - 41, - 59, - 55 - ], - [ - 62, - 1, - 45, - 47, - 57, - 32 - ], - [ - 4, - 53, - 15, - 50, - 60, - 1 - ], - [ - 34, - 57, - 24, - 31, - 15, - 53 - ], - [ - 23, - 41, - 45, - 12, - 57, - 14 - ], - [ - 51, - 29, - 49, - 48, - 18, - 33 - ], - [ - 14, - 4, - 62, - 26, - 18, - 28 - ], - [ - 50, - 49, - 7, - 24, - 9, - 48 - ], - [ - 56, - 20, - 51, - 3, - 14, - 26 - ], - [ - 55, - 37, - 50, - 14, - 42, - 20 - ] - ], - [ - [ - 51, - 43, - 27, - 30, - 5, - 55 - ], - [ - 24, - 16, - 48, - 15, - 7, - 30 - ], - [ - 26, - 21, - 50, - 52, - 56, - 4 - ], - [ - 19, - 17, - 2, - 14, - 57, - 22 - ], - [ - 3, - 35, - 37, - 45, - 1, - 6 - ], - [ - 13, - 30, - 23, - 39, - 57, - 60 - ], - [ - 9, - 34, - 48, - 17, - 27, - 26 - ], - [ - 47, - 32, - 2, - 59, - 28, - 57 - ], - [ - 18, - 12, - 31, - 20, - 52, - 4 - ], - [ - 14, - 15, - 36, - 30, - 58, - 31 - ], - [ - 5, - 4, - 26, - 19, - 22, - 59 - ], - [ - 22, - 8, - 34, - 26, - 42, - 52 - ], - [ - 20, - 5, - 52, - 4, - 30, - 33 - ], - [ - 34, - 17, - 2, - 5, - 39, - 20 - ], - [ - 54, - 61, - 30, - 14, - 20, - 25 - ], - [ - 56, - 41, - 2, - 5, - 54, - 37 - ], - [ - 12, - 34, - 47, - 7, - 54, - 59 - ], - [ - 9, - 11, - 2, - 43, - 33, - 50 - ], - [ - 8, - 10, - 19, - 20, - 32, - 18 - ], - [ - 17, - 38, - 6, - 29, - 49, - 41 - ], - [ - 38, - 52, - 45, - 57, - 63, - 27 - ], - [ - 26, - 28, - 3, - 40, - 5, - 47 - ], - [ - 60, - 49, - 35, - 34, - 38, - 3 - ], - [ - 61, - 24, - 4, - 55, - 45, - 16 - ], - [ - 46, - 11, - 27, - 0, - 56, - 48 - ], - [ - 60, - 45, - 44, - 25, - 32, - 3 - ], - [ - 11, - 38, - 52, - 48, - 9, - 21 - ] - ], - [ - [ - 22, - 19, - 46, - 31, - 3, - 23 - ], - 
[ - 32, - 62, - 15, - 54, - 10, - 55 - ], - [ - 47, - 30, - 38, - 5, - 7, - 60 - ], - [ - 15, - 13, - 1, - 8, - 2, - 25 - ], - [ - 13, - 59, - 5, - 6, - 62, - 52 - ], - [ - 27, - 63, - 62, - 45, - 12, - 56 - ], - [ - 50, - 9, - 8, - 51, - 48, - 18 - ], - [ - 59, - 57, - 28, - 2, - 61, - 6 - ], - [ - 18, - 59, - 6, - 52, - 39, - 57 - ], - [ - 14, - 23, - 11, - 36, - 15, - 32 - ], - [ - 26, - 5, - 42, - 25, - 22, - 23 - ], - [ - 33, - 22, - 55, - 28, - 24, - 16 - ], - [ - 6, - 20, - 33, - 14, - 52, - 15 - ], - [ - 17, - 6, - 5, - 39, - 2, - 34 - ], - [ - 54, - 62, + 15, + 30, 25, - 61, - 21, - 14 - ], - [ - 41, - 45, - 14, - 5, - 2, - 54 - ], - [ - 47, - 34, - 10, - 31, - 5, - 41 - ], - [ - 9, - 2, - 36, - 6, - 43, 38 ], [ - 57, - 20, - 41, - 10, - 32, - 18 - ], - [ - 47, - 6, - 49, - 15, - 7, - 34 - ], - [ - 3, - 57, - 44, - 38, - 50, - 53 - ], - [ - 47, - 28, - 57, - 26, - 19, - 22 - ], - [ - 11, - 1, - 18, - 5, - 46, - 22 - ], - [ - 9, - 55, 4, - 3, - 25, - 10 - ], - [ - 45, - 11, - 0, - 63, - 48, - 57 - ], - [ - 42, - 10, - 3, - 43, - 49, - 0 - ], - [ - 38, - 15, - 17, - 41, - 10, - 3 - ] - ], - [ - [ 19, - 18, - 51, - 25, - 60, - 55 - ], - [ - 56, - 27, - 61, - 42, - 55, - 23 + 24, + 35, + 31, + 48 ], [ - 32, - 39, - 37, + 7, 46, - 20, - 52 - ], - [ - 41, - 21, - 37, - 13, - 57, - 2 - ], - [ - 10, - 9, 3, - 46, 58, - 32 - ], - [ - 62, - 57, - 27, - 43, - 2, - 53 - ], - [ - 50, - 21, - 19, - 48, - 15, - 6 - ], - [ - 14, - 21, - 54, - 8, - 57, - 28 + 30, + 41 ], [ - 6, - 18, - 4, - 52, + 58, + 9, 39, - 24 - ], - [ 32, - 38, - 23, - 53, - 25, - 17 - ], - [ - 26, - 22, - 25, - 5, - 42, - 33 - ], - [ - 5, - 22, - 16, - 9, - 61, - 55 + 29, + 40 ], [ - 49, - 33, - 30, - 25, + 40, + 37, 20, - 22 - ], - [ - 18, - 17, + 8, 25, - 63, - 39, - 11 - ], - [ - 62, - 54, - 10, - 5, - 58, - 37 - ], - [ - 42, - 41, - 5, - 24, - 37, - 54 - ], - [ - 47, - 34, - 27, - 10, - 22, 55 ], [ - 60, - 43, - 44, - 7, - 52, - 37 - ], - [ - 12, - 32, - 39, - 38, - 1, - 20 + 19, + 0, + 54, + 52, + 17, + 39 ], [ - 34, - 52, - 49, - 15, - 28, - 40 + 25, + 43, + 12, + 61, + 14, + 11 ], [ + 23, + 4, + 54, + 36, 28, - 39, - 3, - 26, - 30, - 22 + 33 ], [ - 29, - 1, 40, - 22, - 19, - 63 + 2, + 25, + 58, + 36, + 53 ], [ - 41, - 5, + 18, + 46, + 35, + 22, 53, - 33, - 26, - 39 + 16 ], [ - 25, - 38, - 4, - 34, - 49, - 51 + 2, + 6, + 63, + 14, + 42, + 11 ], [ - 29, + 35, + 7, 52, - 48, - 47, - 20, - 33 + 40, + 29, + 57 ], [ - 50, - 3, - 16, - 38, - 53, - 15 + 40, + 15, + 19, + 57, + 17, + 23 ], [ - 19, - 62, - 6, - 23, - 10, - 36 + 9, + 11, + 47, + 22, + 49, + 1 + ], + [ + 24, + 39, + 42, + 2, + 16, + 0 ] ], [ [ - 5, - 14, - 17, - 57, - 10, - 27 + 18, + 8, + 20, + 49, + 30, + 23 ], [ - 43, - 9, - 56, 1, - 14, - 33 + 27, + 26, + 22, + 59, + 36 ], [ - 63, - 35, 43, - 1, - 10, - 27 - ], - [ - 51, - 50, - 57, - 41, - 20, - 54 + 26, + 15, + 58, + 0, + 46 ], [ - 11, - 43, - 40, - 54, - 30, - 33 + 55, + 1, + 35, + 28, + 16, + 32 ], [ - 27, + 59, + 9, + 10, 53, - 63, - 30, - 15, + 12, 58 ], [ - 51, - 48, - 21, - 57, 9, - 50 + 2, + 27, + 11, + 61, + 43 ], [ - 21, - 14, - 59, - 52, - 28, - 7 + 16, + 57, + 63, + 23, + 19, + 12 ], [ - 6, - 18, + 46, + 45, + 26, 4, - 0, - 12, - 59 + 30, + 37 ], [ + 43, + 44, + 20, + 16, 14, - 53, - 30, - 42, - 4, - 36 + 9 ], [ - 26, + 34, + 47, 42, - 5, - 22, - 58, - 19 + 43, + 26, + 51 ], [ - 22, - 5, - 34, + 42, + 2, 38, - 52, - 29 - ], - [ - 49, - 33, - 0, + 45, 20, - 26, - 8 + 36 ], [ - 17, + 18, + 7, + 12, 2, - 39, - 13, - 44, - 63 + 43, + 60 ], [ - 10, - 54, - 20, - 37, - 5, - 59 + 1, + 28, + 12, + 3, + 
29, + 33 + ], + [ + 25, + 13, + 0, + 63, + 2, + 62 + ], + [ + 18, + 36, + 6, + 29, + 19, + 15 ], [ + 1, + 42, + 63, 41, - 43, - 24, - 5, - 53, - 14 + 57, + 19 ], [ - 12, - 47, - 34, 57, + 54, + 5, 27, - 5 + 18, + 31 ], [ - 9, - 33, - 7, - 38, - 43, - 31 + 50, + 6, + 13, + 32, + 17, + 20 ], [ - 19, - 10, - 20, + 17, + 5, + 27, 32, - 18, - 1 + 1, + 55 ], [ - 36, - 63, - 15, 49, - 50, - 40 + 0, + 61, + 5, + 10, + 30 ], [ - 39, - 3, - 38, - 27, - 36, - 35 + 29, + 53, + 51, + 13, + 33, + 46 ], [ 29, - 19, - 40, - 47, - 62, - 14 + 17, + 21, + 30, + 14, + 40 ], [ 5, - 46, + 17, 33, - 53, - 49, - 21 + 32, + 18, + 28 ], [ - 55, - 4, - 27, - 61, 51, - 34 + 4, + 20, + 54, + 58, + 41 ], [ - 11, - 0, - 59, + 47, + 4, + 27, 48, - 14, - 27 + 37, + 60 ], [ - 7, - 32, 3, - 51, - 18, - 14 + 26, + 12, + 59, + 2, + 48 ], [ - 11, + 46, + 43, + 18, + 20, 9, - 36, - 48, - 0, - 46 + 53 ] ], [ [ - 63, - 62, - 60, + 57, + 9, 19, - 23, - 56 + 51, + 18, + 41 ], [ + 28, + 57, + 36, + 8, 48, - 32, - 1, - 35, - 5, - 21 + 60 ], [ - 22, - 24, - 46, - 58, + 2, + 51, 59, - 60 + 5, + 34, + 9 ], [ - 27, - 37, - 50, - 28, - 61, - 6 + 9, + 55, + 59, + 26, + 4, + 2 ], [ - 10, - 12, - 15, - 58, + 49, + 56, 35, + 42, + 30, 23 ], [ - 2, - 43, - 57, - 36, + 18, 30, - 20 - ], - [ - 21, - 14, + 22, + 29, 19, - 63, - 41, - 42 - ], - [ - 46, - 7, - 35, - 43, - 21, - 36 + 52 ], [ - 6, - 18, - 0, - 35, - 40, - 4 + 39, + 34, + 33, + 51, + 56, + 3 ], [ - 47, - 53, 32, - 38, + 21, 1, - 58 - ], - [ - 26, - 18, - 22, - 15, - 5, - 46 + 7, + 46, + 49 ], [ - 7, - 22, - 24, - 63, - 53, - 5 + 33, + 54, + 23, + 21, + 12, + 11 ], [ - 11, - 3, - 10, - 37, - 18, - 24 + 5, + 30, + 60, + 47, + 15, + 18 ], [ - 63, + 4, 18, - 13, - 39, - 0, - 17 + 46, + 27, + 20, + 22 ], [ - 62, + 22, + 59, 54, - 5, - 18, - 6, - 14 + 48, + 19, + 4 ], [ - 19, - 43, + 17, + 5, + 56, 31, - 30, - 57, - 42 + 49, + 4 ], [ - 24, - 5, - 51, - 34, + 29, 47, - 21 + 55, + 2, + 53, + 60 ], [ - 59, - 13, - 33, - 44, + 8, + 22, 11, - 41 + 44, + 36, + 15 ], [ - 58, - 32, - 1, - 7, - 39, - 38 + 60, + 44, + 30, + 57, + 54, + 39 ], [ - 49, - 13, - 28, - 9, - 34, - 7 + 7, + 44, + 27, + 20, + 2, + 61 ], [ + 48, + 17, + 21, + 37, 32, - 51, - 3, - 53, + 57 + ], + [ 48, - 13 + 32, + 46, + 6, + 61, + 42 ], [ - 40, - 15, - 37, + 4, 57, 1, - 23 + 36, + 0, + 30 ], [ - 5, + 7, + 17, + 61, 53, - 51, - 35, - 22, - 38 + 21, + 63 ], [ - 4, + 60, 14, - 44, - 41, + 53, + 35, + 18, + 55 + ], + [ + 10, + 15, + 33, + 51, + 36, + 5 + ], + [ 11, - 51 + 4, + 19, + 51, + 21, + 52 ], [ 47, + 19, + 43, 48, - 8, - 61, - 13, + 58, 4 ], [ - 23, 3, - 27, - 43, - 0, - 1 + 33, + 26, + 21, + 52, + 19 ], [ - 46, - 23, - 59, - 62, - 18, - 50 + 24, + 45, + 60, + 35, + 49, + 1 ] ], [ [ - 62, - 0, - 9, - 61, + 23, 54, - 32 + 53, + 58, + 11, + 8 ], [ - 45, + 11, + 30, + 15, + 59, + 63, + 55 + ], + [ + 20, + 58, 29, - 7, - 35, - 22, - 62 + 17, + 42, + 30 ], [ - 56, - 31, - 23, - 53, - 28, - 2 + 18, + 1, + 43, + 15, + 8, + 3 ], [ - 36, - 2, - 5, - 4, - 48, - 41 + 59, + 55, + 13, + 28, + 26, + 63 ], [ - 18, - 0, - 15, - 23, - 16, - 11 + 33, + 45, + 27, + 53, + 63, + 19 ], [ - 4, - 36, + 28, 57, - 56, - 14, - 5 + 51, + 54, + 34, + 53 ], [ - 59, - 56, - 21, - 8, - 7, - 33 + 24, + 13, + 3, + 47, + 45, + 50 ], [ - 12, - 23, - 7, - 2, - 28, - 34 + 47, + 59, + 33, + 57, + 37, + 35 ], [ - 17, - 6, 23, - 19, - 62, - 27 - ], - [ - 53, - 32, - 51, - 38, + 42, + 33, 41, - 58 + 48, + 30 ], [ - 18, + 42, 63, - 19, - 22, - 26, - 5 + 23, + 25, + 17, + 34 ], [ - 22, - 7, - 60, - 11, - 12, - 19 + 14, + 62, + 2, + 19, + 45, + 43 ], [ - 24, - 7, - 
11, - 17, + 12, + 6, + 20, 30, - 37 + 29, + 17 ], [ - 47, - 13, - 63, - 50, - 39, - 45 + 6, + 3, + 17, + 8, + 27, + 31 ], [ - 62, + 19, + 59, + 2, + 10, 54, - 18, - 14, - 29, - 56 + 55 ], [ - 56, - 30, - 0, - 55, - 51, - 10 + 45, + 61, + 19, + 57, + 42, + 17 ], [ - 12, - 5, - 21, + 55, + 60, + 41, 34, - 29, - 63 + 35, + 53 ], [ - 42, + 36, + 2, 8, + 31, + 14, + 5 + ], + [ + 36, + 45, + 16, 38, + 51, + 21 + ], + [ 57, - 41, - 44 + 39, + 48, + 6, + 19, + 30 ], [ - 32, - 54, - 1, + 39, 37, - 49, + 59, + 48, + 42, 40 ], [ - 5, - 57, - 49, + 52, 28, - 34, - 10 - ], - [ + 33, 46, - 32, - 13, - 7, - 61, - 48 + 18, + 55 ], [ - 40, - 3, - 30, - 60, - 39, - 14 + 1, + 15, + 28, + 13, + 52, + 18 ], [ - 5, - 44, - 33, - 36, - 28, + 10, + 21, + 4, + 2, + 23, 31 ], [ - 4, - 11, - 28, + 45, + 9, 41, - 51, - 5 + 18, + 54, + 3 ], [ - 48, - 47, + 42, 61, - 28, - 60, + 36, + 3, + 19, 27 ], [ - 14, + 41, + 46, + 33, 3, - 26, - 12, - 53, - 61 - ], - [ - 54, - 60, - 49, - 35, 10, - 62 + 15 ] ], [ [ - 27, - 62, - 63, - 23, - 47, - 56 + 19, + 1, + 31, + 52, + 49, + 63 ], [ 7, - 4, - 2, - 10, - 35, - 36 + 47, + 5, + 60, + 22, + 46 ], [ + 59, + 30, 3, + 11, 0, - 27, - 62, - 50, - 60 + 19 ], [ - 36, - 3, + 43, 42, - 18, - 2, - 48 + 19, + 62, + 8, + 56 ], [ 61, - 12, - 27, - 10, 15, - 14 + 25, + 18, + 39, + 27 ], [ + 61, + 50, 36, + 45, 33, - 50, - 6, - 29, - 16 + 44 ], [ - 2, - 41, 8, - 43, - 40, - 59 + 37, + 52, + 1, + 2, + 41 ], [ 36, + 30, 53, - 15, - 7, - 37, - 2 + 11, + 16, + 29 ], [ - 3, + 14, 58, - 4, - 19, - 5, - 27 + 46, + 49, + 3, + 26 ], [ - 22, - 26, - 53, + 62, + 43, + 0, 45, - 25, - 38 + 22, + 46 ], [ + 6, + 56, + 45, 18, - 41, - 22, - 3, - 34, - 45 + 10, + 41 ], [ + 60, 21, - 53, - 5, - 22, - 28, - 9 + 50, + 47, + 30, + 35 ], [ 11, - 24, + 53, 28, - 16, - 8, - 35 + 56, + 41, + 39 ], [ - 38, + 23, 9, - 21, - 52, - 41, - 39 + 33, + 28, + 22, + 26 ], [ + 23, + 56, + 34, 27, - 54, - 14, - 5, - 8, + 2, 63 ], [ - 56, - 55, - 43, - 30, - 35, - 51 + 22, + 9, + 44, + 41, + 37, + 47 ], [ - 51, + 1, + 11, + 46, + 27, 3, - 14, - 60, - 5, - 26 + 52 ], [ - 62, - 30, - 8, + 51, + 37, + 17, 21, - 44, - 35 + 61, + 30 ], [ - 38, - 34, - 27, + 13, + 19, 32, - 22, - 8 + 5, + 2, + 9 ], [ - 22, - 53, - 28, - 60, - 49, - 12 + 41, + 32, + 6, + 47, + 29, + 56 ], [ - 49, - 32, - 41, 53, - 12, - 33 + 33, + 9, + 35, + 38, + 12 ], [ 40, - 45, - 22, - 44, - 1, - 34 + 19, + 51, + 7, + 26, + 22 ], [ + 37, 5, - 4, - 53, - 25, - 63, - 41 - ], - [ - 4, - 11, 25, - 43, - 31, - 14 - ], - [ - 31, - 52, - 5, - 58, - 43, - 48 - ], - [ - 23, - 1, - 28, - 3, - 0, - 55 + 46, + 34, + 53 ], [ - 27, - 45, - 32, - 4, - 30, - 6 - ] - ], - [ - [ - 41, - 2, - 42, 16, - 50, - 32 - ], - [ - 51, - 5, - 41, - 40, + 55, + 20, + 24, 44, - 21 + 53 ], [ - 43, - 1, - 29, - 55, 21, - 35 - ], - [ + 46, + 11, + 0, 36, - 58, - 25, - 3, - 18, - 54 + 7 ], [ - 31, + 18, 60, - 24, - 12, - 61, - 41 + 32, + 3, + 34, + 28 ], [ - 4, - 10, - 35, + 52, + 9, 36, - 0, - 43 + 48, + 11, + 41 + ] + ], + [ + [ + 17, + 37, + 31, + 32, + 63, + 50 ], [ - 45, - 43, - 63, - 35, - 36, - 48 + 12, + 2, + 9, + 32, + 47, + 17 ], [ 3, - 9, - 8, - 7, - 43, - 27 + 57, + 56, + 50, + 33, + 38 ], [ - 0, 43, - 4, - 40, - 18, - 44 + 42, + 19, + 52, + 8, + 17 ], [ - 34, + 61, + 39, + 27, + 12, 15, - 50, - 38, - 53, - 25 + 57 + ], + [ + 6, + 33, + 36, + 44, + 29, + 61 ], [ - 24, - 46, - 26, - 13, 2, - 18 + 41, + 42, + 15, + 52, + 5 ], [ - 57, + 36, 37, + 15, + 53, + 18, + 62 + ], + [ + 50, + 58, + 3, 5, + 16, + 4 + ], + [ + 16, 22, - 7, - 53 + 55, + 32, + 41, + 26 ], [ 3, - 9, + 41, + 56, + 45, + 
34, + 10 + ], + [ + 21, + 53, + 50, + 34, 38, - 26, - 37, - 23 + 35 ], [ - 60, - 26, - 0, - 39, - 13, - 55 + 35, + 11, + 37, + 16, + 53, + 42 ], [ + 9, + 15, 54, - 21, - 58, - 6, - 55, - 14 + 38, + 12, + 52 ], [ - 11, - 19, - 43, - 57, 56, - 37 - ], - [ - 10, + 2, + 23, 24, - 60, - 21, - 5, - 17 + 9, + 5 ], [ - 31, - 20, - 63, + 2, 44, - 55, - 33 + 9, + 53, + 35, + 54 ], [ - 7, - 25, - 40, - 32, - 54, - 58 + 27, + 46, + 1, + 5, + 60, + 50 ], [ - 49, - 30, - 13, + 51, + 61, 50, - 55, - 28 + 54, + 33, + 44 ], [ - 30, - 3, - 51, - 61, - 0, - 16 + 2, + 38, + 19, + 9, + 5, + 32 ], [ - 57, - 46, - 43, - 40, - 30, - 22 + 32, + 28, + 6, + 15, + 0, + 33 ], [ - 52, - 33, - 51, - 19, + 26, + 53, 27, - 29 + 13, + 5, + 4 ], [ - 14, - 4, + 34, 40, + 45, + 55, + 62, + 53 + ], + [ 41, - 0, - 34 + 5, + 46, + 4, + 45, + 3 ], [ + 20, 4, - 13, - 8, - 61, + 22, 47, - 9 + 59, + 58 ], [ - 31, - 27, - 43, - 1, + 62, + 36, + 25, + 28, + 53, + 33 + ], + [ + 22, + 33, + 7, 3, + 54, 0 ], [ - 8, - 4, - 46, - 9, - 3, - 50 + 58, + 27, + 52, + 48, + 45, + 17 ] ] + ], + "routing_indices_token_indices": [ + 3, + 13, + 14, + 17, + 28, + 31, + 35, + 81, + 86, + 94 ] } } \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/golden_values_dev_dgx_h100.json index 032eaf98f95..e277ccea027 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/golden_values_dev_dgx_h100.json @@ -157,25510 +157,2196 @@ "routing_indices": [ [ [ + 48, 33, - 50, - 36, - 4, - 25, - 63 + 24, + 52, + 38, + 61 ], [ - 0, - 16, 3, - 26, - 9, - 54 + 40, + 45, + 19, + 58, + 35 ], [ - 62, + 46, + 39, + 15, 60, - 8, - 16, - 58, - 52 + 31, + 25 ], [ - 58, - 10, - 6, + 49, + 62, + 61, + 48, + 55, + 46 + ], + [ + 41, + 35, + 53, + 52, + 13, + 10 + ], + [ + 52, 45, - 16, - 32 + 15, + 63, + 46, + 10 ], [ - 43, + 14, + 15, + 35, + 60, 49, - 18, - 54, - 55, - 13 + 31 ], [ - 27, - 19, - 26, 44, - 12, - 28 + 3, + 6, + 8, + 23, + 48 ], [ - 53, - 42, - 3, + 32, + 15, + 44, 27, - 26, - 19 + 4, + 41 ], [ - 6, - 47, - 1, - 19, - 8, - 22 + 51, + 25, + 62, + 53, + 48, + 10 ], [ - 51, - 27, - 1, - 38, + 8, + 50, + 5, + 19, 16, - 62 + 22 ], [ - 41, + 36, 49, - 21, - 57, - 16, - 24 + 60, + 44, + 15, + 33 ], [ 13, - 28, - 38, + 27, + 53, + 30, + 56, + 43 + ], + [ + 26, 22, - 49, - 48 + 14, + 7, + 32, + 17 ], [ - 52, + 26, + 60, + 2, 58, - 6, - 25, - 29, - 17 + 54, + 10 ], [ - 49, 50, - 25, - 41, - 54, - 58 + 55, + 17, + 51, + 47, + 14 + ], + [ + 0, + 12, + 16, + 4, + 23, + 9 ], [ + 10, + 23, 27, - 41, - 3, - 26, - 1, - 29 + 46, + 56, + 55 ], [ - 41, - 18, - 34, - 45, - 1, - 52 + 3, + 37, + 4, + 60, + 16, + 59 ], [ - 17, - 26, - 59, - 22, - 19, - 4 + 38, + 3, + 29, + 40, + 25, + 50 ], [ - 1, - 43, + 31, + 16, 62, - 57, - 61, - 29 + 54, + 42, + 63 ], [ + 47, + 35, 11, + 37, 46, - 15, - 28, - 61, - 50 - ], - [ - 14, - 30, - 58, - 26, - 38, - 53 + 32 ], [ - 59, + 51, + 27, 20, - 63, - 54, - 47, - 61 + 50, + 16, + 55 ], [ - 9, - 5, - 43, - 33, - 15, - 46 + 63, + 55, + 46, + 27, + 48, + 12 ], [ - 16, - 52, - 33, - 61, - 49, - 11 - ], - [ - 52, - 35, - 40, - 43, - 29, - 36 - ], - [ - 57, - 34, - 38, - 44, - 20, - 18 + 53, + 50, + 30, + 2, + 39, + 20 ], [ + 53, + 24, 44, + 8, 51, - 2, - 63, 
- 7, - 17 - ], - [ - 32, - 47, - 58, - 9, - 54, - 5 + 14 ], [ - 12, - 59, - 54, - 33, - 50, - 6 + 19, + 49, + 37, + 14, + 44, + 0 ] ], [ [ - 49, - 43, - 18, - 28, - 23, - 25 - ], - [ + 21, + 50, + 29, + 41, 34, - 24, - 23, - 60, - 2, - 18 + 60 ], [ + 28, + 51, + 60, 33, - 40, - 30, - 3, - 59, - 48 - ], - [ - 16, - 58, - 17, - 48, - 6, + 14, 45 ], [ - 23, - 58, + 31, + 2, 46, - 37, - 34, - 48 + 33, + 24, + 49 ], [ - 28, - 26, - 35, - 33, - 43, - 22 + 42, + 3, + 18, + 62, + 39, + 49 ], [ - 6, - 31, - 13, - 46, - 41, - 37 + 25, + 15, + 62, + 27, + 12, + 11 ], [ - 60, - 29, 44, + 50, 36, - 39, - 15 + 57, + 55, + 41 ], [ - 51, - 18, - 62, - 5, - 27, - 35 + 41, + 37, + 22, + 15, + 2, + 40 ], [ + 36, 62, - 45, - 32, - 56, - 25, - 3 + 53, + 30, + 14, + 57 ], [ - 6, - 20, - 3, 16, - 49, - 41 + 58, + 2, + 29, + 4, + 3 ], [ - 36, 41, - 50, + 38, + 26, + 16, 45, - 35, - 48 + 46 ], [ - 39, + 45, 46, - 48, - 25, + 32, + 41, + 56, + 26 + ], + [ + 17, + 53, 21, - 33 + 11, + 36, + 35 ], [ - 30, - 38, 11, + 16, + 28, + 14, + 51, + 61 + ], + [ + 9, + 35, + 33, 22, - 15, - 9 + 52, + 62 ], [ - 24, - 2, - 32, - 56, + 58, + 50, 63, - 3 + 30, + 7, + 27 ], [ - 36, 55, - 35, - 32, - 17, - 44 + 3, + 8, + 41, + 63, + 37 ], [ - 59, + 3, + 51, 46, 32, - 44, - 24, - 14 - ], - [ - 51, 15, - 61, - 43, - 30, - 22 + 6 ], [ 32, - 2, + 12, + 10, + 25, 5, + 49 + ], + [ + 34, + 2, + 37, + 61, 39, - 11, - 50 + 63 ], [ - 28, 42, - 6, - 30, - 57, - 37 + 22, + 27, + 53, + 11, + 56 ], [ - 28, - 55, - 45, + 53, + 12, 0, - 7, - 5 + 47, + 61, + 1 ], [ + 39, + 45, + 53, + 17, 48, - 40, - 34, - 3, - 22, - 49 + 14 ], [ - 25, 6, - 62, - 18, - 50, - 53 + 0, + 4, + 53, + 25, + 11 ], [ - 13, - 22, - 20, - 28, - 25, - 59 + 51, + 11, + 1, + 63, + 54, + 45 ], [ - 58, - 39, - 36, - 47, - 29, - 37 + 40, + 56, + 37, + 53, + 5, + 35 ], [ - 4, - 33, + 59, + 28, 41, - 12, - 3, - 17 + 10, + 1, + 45 ], [ - 22, - 13, - 20, - 52, + 27, + 30, + 28, 24, - 62 + 32, + 57 ] ], [ [ - 17, - 10, - 57, - 54, + 24, + 56, 6, - 15 + 0, + 19, + 45 ], [ - 33, - 43, - 13, - 1, - 16, - 62 + 11, + 57, + 59, + 25, + 46, + 30 ], [ - 63, - 1, - 35, - 43, - 27, - 10 + 11, + 26, + 37, + 29, + 14, + 52 ], [ - 47, - 4, + 3, + 32, + 7, 38, - 50, - 51, - 0 + 36, + 24 ], [ - 11, - 51, - 57, - 23, + 61, + 2, + 24, 14, - 34 + 51, + 44 ], [ - 10, - 43, - 35, - 33, 20, - 22 + 47, + 0, + 63, + 30, + 58 ], [ + 4, 36, - 48, - 35, - 19, - 21, - 28 + 29, + 58, + 16, + 3 ], [ + 20, + 0, + 45, 14, - 8, - 7, - 46, - 35, - 13 + 28, + 44 ], [ - 18, - 44, - 63, - 6, - 4, - 37 + 29, + 56, + 47, + 35, + 16, + 4 ], [ - 62, - 29, - 15, - 38, - 39, - 34 + 33, + 61, + 55, + 41, + 51, + 38 ], [ + 58, 1, - 6, - 16, - 46, - 22, - 13 + 38, + 14, + 4, + 19 ], [ + 0, 36, - 8, - 16, - 37, - 10, - 14 + 14, + 18, + 52, + 42 ], [ + 29, + 36, + 45, + 25, 8, - 0, - 32, - 3, - 43, - 10 + 6 ], [ - 25, - 0, - 22, - 30, - 60, - 57 + 6, + 57, + 50, + 40, + 58, + 61 ], [ - 60, + 44, 58, - 55, - 32, - 2, - 7 + 29, + 19, + 61, + 56 ], [ + 23, + 18, + 28, 55, - 11, - 19, 5, - 24, - 43 + 37 ], [ + 13, 9, - 47, - 36, - 39, - 5, - 42 + 19, + 43, + 37, + 3 ], [ - 31, - 20, - 9, - 43, - 18, + 32, + 22, + 63, + 14, + 57, 41 ], [ - 12, - 5, - 32, - 50, + 10, + 61, 3, - 31 + 1, + 19, + 32 ], [ - 40, - 18, - 30, - 63, - 7, - 25 + 16, + 55, + 10, + 41, + 59, + 22 ], [ - 14, + 53, 7, - 3, + 29, 38, - 59, - 54 + 27, + 46 ], [ - 31, - 20, - 1, - 22, + 24, 47, - 6 + 18, + 53, + 39, + 30 ], [ + 0, + 33, + 19, + 5, 51, - 15, - 18, - 53, - 40, - 52 + 17 ], [ - 27, - 54, + 51, 8, - 38, - 3, - 59 + 11, + 45, + 44, + 41 ], [ + 40, + 4, + 23, 
11, - 16, - 39, - 59, - 9, - 23 + 27, + 19 ], [ - 7, - 37, - 51, - 30, + 16, 18, - 48 + 3, + 48, + 51, + 21 ], [ - 0, - 36, - 33, - 40, + 43, 46, - 48 + 60, + 19, + 53, + 12 ] ], [ [ - 48, - 33, 24, - 52, - 38, - 61 - ], - [ - 3, - 40, - 45, + 56, + 6, + 0, 19, - 58, - 35 - ], - [ - 46, - 39, - 15, - 60, - 31, - 25 - ], - [ - 49, - 62, - 61, - 48, - 55, - 46 - ], - [ - 41, - 35, - 53, - 52, - 13, - 10 + 45 ], [ - 52, - 45, - 15, - 63, + 11, + 57, + 59, 46, - 10 + 25, + 30 ], [ + 26, + 11, + 37, 14, - 15, - 35, - 60, - 49, - 31 + 29, + 49 ], [ - 44, + 38, + 36, 3, - 6, - 8, - 23, - 48 - ], - [ - 32, - 15, - 44, - 27, - 4, - 41 + 24, + 18, + 20 ], [ + 61, 51, - 25, - 62, - 53, - 48, - 10 + 14, + 2, + 24, + 1 ], [ + 20, + 0, + 47, + 30, 8, - 50, - 5, - 19, - 16, - 22 + 35 ], [ + 4, + 58, 36, - 49, - 60, + 54, + 29, + 12 + ], + [ + 20, + 58, 44, - 15, - 33 + 28, + 45, + 9 ], [ - 13, - 27, - 53, - 30, 56, - 43 + 47, + 10, + 29, + 35, + 27 ], [ - 26, - 22, - 14, - 7, - 32, - 17 + 61, + 33, + 55, + 54, + 4, + 36 ], [ - 26, - 60, - 2, 58, - 54, - 10 + 1, + 14, + 4, + 38, + 52 ], [ - 50, - 55, - 17, - 51, - 47, - 14 + 14, + 0, + 36, + 63, + 15, + 52 ], [ - 0, - 12, + 29, + 36, + 44, + 8, 16, - 4, - 23, - 9 + 2 ], [ - 10, - 23, + 6, + 40, 27, - 46, - 56, - 55 + 57, + 50, + 42 ], [ - 3, + 19, + 44, + 58, + 61, 37, - 4, - 60, - 16, - 59 + 38 ], [ - 38, - 3, + 23, + 18, + 17, + 57, + 13, + 40 + ], + [ + 13, + 9, + 19, + 37, + 50, + 15 + ], + [ + 32, + 14, + 57, + 58, 29, - 40, - 25, - 50 + 22 ], [ - 31, + 61, + 10, + 1, + 3, + 14, + 59 + ], + [ + 55, 16, - 62, - 54, - 42, - 63 + 34, + 18, + 22, + 49 + ], + [ + 53, + 27, + 38, + 28, + 23, + 44 ], [ + 24, 47, - 35, - 11, - 37, - 46, - 32 + 18, + 62, + 41, + 30 ], [ + 33, 51, - 27, - 20, - 50, - 16, - 55 + 19, + 5, + 0, + 31 ], [ - 63, - 55, - 46, + 51, + 8, + 25, + 53, 27, - 48, - 12 + 16 ], [ - 53, - 50, - 30, - 2, - 39, - 20 + 40, + 11, + 27, + 4, + 23, + 19 ], [ - 53, - 24, - 44, - 8, + 16, + 18, 51, - 14 + 48, + 3, + 47 ], [ - 19, - 49, - 37, - 14, - 44, - 0 + 46, + 43, + 36, + 9, + 5, + 12 ] ], [ [ 0, - 52, - 16, - 12, - 54, - 7 + 10, + 49, + 23, + 62, + 44 ], [ - 42, - 25, - 51, - 61, - 35, - 58 + 28, + 0, + 36, + 26, + 47, + 52 ], [ - 51, - 42, - 19, - 57, - 28, - 8 + 30, + 4, + 16, + 48, + 40, + 10 ], [ - 49, - 62, - 5, - 2, - 46, - 21 + 61, + 32, + 26, + 16, + 33, + 62 ], [ - 3, - 41, - 53, - 25, + 30, 39, - 37 + 53, + 5, + 57, + 20 ], [ - 45, - 15, + 5, 37, - 48, - 19, - 60 + 61, + 15, + 25, + 6 ], [ - 14, 15, - 47, 17, 24, - 35 + 60, + 49, + 62 + ], + [ + 34, + 39, + 61, + 0, + 58, + 40 ], [ - 3, - 52, - 63, 16, - 28, - 47 + 39, + 36, + 51, + 2, + 29 ], [ - 32, - 27, - 15, + 9, + 11, + 41, + 31, + 56, + 52 + ], + [ + 10, + 48, 24, + 45, 62, - 23 + 51 ], [ - 62, - 25, - 51, - 53, - 0, - 20 + 11, + 38, + 36, + 37, + 6, + 42 ], [ - 8, - 19, + 51, 50, - 16, - 32, - 22 - ], - [ - 13, - 36, - 42, - 49, - 60, - 44 - ], - [ - 30, - 27, - 53, - 56, - 0, - 10 - ], - [ - 22, - 7, - 14, - 26, - 32, - 17 - ], - [ - 60, - 7, - 26, - 2, - 58, - 13 - ], - [ - 55, - 51, - 17, - 60, - 62, - 47 - ], - [ - 53, - 16, - 18, - 4, - 50, - 5 - ], - [ - 28, - 39, - 41, - 1, - 55, - 18 - ], - [ - 59, - 49, - 16, - 23, - 42, - 11 - ], - [ - 17, - 30, - 46, - 55, - 25, - 0 - ], - [ - 54, - 9, - 45, - 0, - 6, - 19 - ], - [ - 51, - 9, - 22, - 23, - 16, - 25 - ], - [ - 62, - 50, - 43, - 51, - 55, - 27 - ], - [ - 13, - 43, - 27, - 1, - 14, - 52 - ], - [ - 61, - 26, - 1, - 17, - 32, - 63 - ], - [ - 37, - 46, - 63, - 20, - 24, - 4 - ], - [ - 63, - 11, - 12, - 61, - 31, - 22 - ] 
- ], - [ - [ - 49, - 54, - 56, - 11, - 38, - 40 - ], - [ - 53, - 46, - 49, - 38, - 57, - 17 - ], - [ - 48, - 9, - 31, - 12, - 6, - 56 - ], - [ - 49, - 62, - 23, - 5, - 12, - 63 - ], - [ - 36, - 26, - 38, - 7, - 20, - 23 - ], - [ - 37, - 33, - 41, - 58, - 57, - 32 - ], - [ - 10, - 14, - 15, - 31, - 8, - 43 - ], - [ - 16, - 52, - 3, - 2, - 34, - 14 - ], - [ - 15, - 32, - 35, - 27, - 62, - 54 - ], - [ - 51, - 62, - 25, - 53, - 10, - 20 - ], - [ - 19, - 8, - 22, - 50, - 1, - 5 - ], - [ - 13, - 49, - 36, - 60, - 42, - 20 - ], - [ - 27, - 30, - 53, - 54, - 26, - 43 - ], - [ - 14, - 32, - 26, - 22, - 7, - 17 - ], - [ - 26, - 60, - 7, - 2, - 52, - 54 - ], - [ - 55, - 17, - 51, - 47, - 62, - 26 - ], - [ - 16, - 18, - 44, - 4, - 53, - 50 - ], - [ - 1, - 47, - 39, - 45, - 28, - 56 - ], - [ - 23, - 16, - 55, - 49, - 8, - 32 - ], - [ - 17, - 55, - 30, - 62, - 31, - 23 - ], - [ - 9, - 54, - 0, - 44, - 14, - 56 - ], - [ - 23, - 9, - 51, - 22, - 31, - 50 - ], - [ - 50, - 27, - 14, - 51, - 18, - 8 - ], - [ - 12, - 13, - 33, - 1, - 5, - 43 - ], - [ - 26, - 32, - 1, - 50, - 37, - 57 - ], - [ - 37, - 47, - 63, - 46, - 5, - 4 - ], - [ - 63, - 11, - 12, - 19, - 33, - 61 - ] - ], - [ - [ - 49, - 54, - 40, - 56, - 11, - 3 - ], - [ - 37, - 15, - 12, - 33, - 59, - 17 - ], - [ - 38, - 49, - 14, - 46, - 35, - 59 - ], - [ - 25, - 20, - 39, - 62, - 49, - 32 - ], - [ - 26, - 51, - 16, - 36, - 8, - 7 - ], - [ - 37, - 41, - 51, - 33, - 32, - 60 - ], - [ - 10, - 14, - 59, - 8, - 15, - 40 - ], - [ - 16, - 52, - 19, - 61, - 32, - 3 - ], - [ - 32, - 15, - 27, - 24, - 62, - 35 - ], - [ - 51, - 25, - 62, - 0, - 49, - 54 - ], - [ - 8, - 50, - 19, - 48, - 16, - 54 - ], - [ - 13, - 49, - 36, - 42, - 11, - 60 - ], - [ - 53, - 27, - 7, - 26, - 30, - 21 - ], - [ - 14, - 22, - 26, - 37, - 32, - 7 - ], - [ - 26, - 40, - 60, - 2, - 52, - 7 - ], - [ - 55, - 51, - 17, - 46, - 13, - 62 - ], - [ - 38, - 16, - 53, - 44, - 4, - 18 - ], - [ - 39, - 1, - 4, - 14, - 56, - 57 - ], - [ - 55, - 23, - 32, - 14, - 13, - 16 - ], - [ - 55, - 17, - 2, - 30, - 62, - 12 - ], - [ - 13, - 54, - 0, - 62, - 61, - 25 - ], - [ - 9, - 51, - 5, - 22, - 19, - 23 - ], - [ - 50, - 51, - 7, - 19, - 48, - 53 - ], - [ - 12, - 1, - 5, - 43, - 61, - 13 - ], - [ - 26, - 32, - 30, - 37, - 34, - 57 - ], - [ - 37, - 63, - 46, - 4, - 47, - 8 - ], - [ - 63, - 12, - 11, - 31, - 33, - 61 - ] - ], - [ - [ - 47, - 34, - 30, - 25, - 31, - 3 - ], - [ - 15, - 24, - 46, - 21, - 8, - 6 - ], - [ - 34, - 21, - 18, - 62, - 28, - 55 - ], - [ - 35, - 32, - 20, - 39, - 59, - 54 - ], - [ - 26, - 27, - 15, - 48, - 60, - 47 - ], - [ - 37, - 8, - 50, - 18, - 54, - 61 - ], - [ - 35, - 31, - 8, - 24, - 14, - 15 - ], - [ - 16, - 52, - 34, - 29, - 48, - 36 - ], - [ - 11, - 32, - 62, - 27, - 46, - 26 - ], - [ - 61, - 62, - 25, - 56, - 46, - 53 - ], - [ - 56, - 50, - 63, - 3, - 45, - 28 - ], - [ - 11, - 36, - 5, - 60, - 35, - 50 - ], - [ - 21, - 26, - 41, - 51, - 46, - 53 - ], - [ - 14, - 22, - 33, - 19, - 41, - 16 - ], - [ - 2, - 52, - 34, - 60, - 21, - 49 - ], - [ - 59, - 55, - 29, - 8, - 61, - 22 - ], - [ - 51, - 44, - 2, - 59, - 47, - 53 - ], - [ - 39, - 25, - 18, - 12, - 51, - 56 - ], - [ - 34, - 53, - 32, - 12, - 9, - 38 - ], - [ - 30, - 53, - 56, - 7, - 40, - 62 - ], - [ - 40, - 49, - 28, - 14, - 23, - 55 - ], - [ - 15, - 48, - 40, - 47, - 9, - 1 - ], - [ - 50, - 41, - 25, - 53, - 18, - 0 - ], - [ - 22, - 1, - 59, - 3, - 55, - 8 - ], - [ - 1, - 53, - 32, - 26, - 47, - 3 - ], - [ - 4, - 33, - 28, - 37, - 55, - 54 - ], - [ - 30, - 22, - 57, - 12, - 33, - 63 - ] - ], - [ - [ - 16, - 11, - 0, - 31, 
- 46, - 22 - ], - [ - 49, - 13, - 5, - 11, - 31, - 14 - ], - [ - 36, - 13, - 56, - 27, - 46, - 3 - ], - [ - 24, - 44, - 62, - 29, - 15, - 13 - ], - [ - 17, - 2, - 50, - 8, - 45, - 1 - ], - [ - 8, - 7, - 49, - 0, - 62, - 13 - ], - [ - 35, - 61, - 58, - 23, - 36, - 0 - ], - [ - 16, - 48, - 42, - 4, - 32, - 29 - ], - [ - 63, - 18, - 32, - 45, - 4, - 34 - ], - [ - 57, - 62, - 54, - 27, - 25, - 53 - ], - [ - 1, - 59, - 60, - 29, - 22, - 14 - ], - [ - 31, - 36, - 11, - 14, - 20, - 10 - ], - [ - 34, - 2, - 19, - 14, - 8, - 37 - ], - [ - 57, - 22, - 40, - 14, - 62, - 48 - ], - [ - 44, - 60, - 7, - 14, - 45, - 2 - ], - [ - 8, - 55, - 6, - 25, - 50, - 59 - ], - [ - 13, - 47, - 42, - 23, - 61, - 39 - ], - [ - 23, - 25, - 4, - 14, - 46, - 60 - ], - [ - 8, - 45, - 32, - 53, - 10, - 54 - ], - [ - 15, - 38, - 53, - 55, - 30, - 7 - ], - [ - 41, - 14, - 28, - 5, - 58, - 27 - ], - [ - 11, - 41, - 57, - 1, - 10, - 47 - ], - [ - 50, - 0, - 51, - 53, - 34, - 45 - ], - [ - 1, - 14, - 55, - 8, - 25, - 3 - ], - [ - 11, - 49, - 1, - 9, - 0, - 3 - ], - [ - 50, - 51, - 6, - 42, - 4, - 54 - ], - [ - 17, - 37, - 31, - 5, - 40, - 36 - ] - ], - [ - [ - 22, - 53, - 47, - 6, - 57, - 21 - ], - [ - 27, - 11, - 14, - 6, - 57, - 16 - ], - [ - 1, - 11, - 29, - 26, - 41, - 17 - ], - [ - 14, - 22, - 38, - 31, - 29, - 36 - ], - [ - 14, - 59, - 29, - 61, - 45, - 52 - ], - [ - 30, - 8, - 0, - 21, - 47, - 58 - ], - [ - 58, - 35, - 4, - 61, - 23, - 36 - ], - [ - 42, - 20, - 48, - 16, - 9, - 4 - ], - [ - 47, - 29, - 4, - 18, - 63, - 32 - ], - [ - 19, - 54, - 62, - 53, - 57, - 29 - ], - [ - 1, - 60, - 14, - 59, - 4, - 29 - ], - [ - 36, - 0, - 47, - 3, - 31, - 8 - ], - [ - 2, - 19, - 36, - 8, - 20, - 37 - ], - [ - 57, - 22, - 40, - 31, - 49, - 14 - ], - [ - 44, - 37, - 2, - 5, - 60, - 21 - ], - [ - 6, - 43, - 24, - 5, - 2, - 59 - ], - [ - 13, - 19, - 61, - 47, - 50, - 39 - ], - [ - 58, - 14, - 28, - 4, - 11, - 22 - ], - [ - 35, - 32, - 46, - 10, - 31, - 45 - ], - [ - 15, - 13, - 55, - 45, - 18, - 63 - ], - [ - 15, - 27, - 28, - 14, - 5, - 60 - ], - [ - 57, - 41, - 47, - 19, - 36, - 10 - ], - [ - 34, - 10, - 53, - 55, - 22, - 19 - ], - [ - 38, - 55, - 39, - 27, - 3, - 25 - ], - [ - 11, - 39, - 0, - 9, - 3, - 49 - ], - [ - 51, - 6, - 43, - 18, - 50, - 53 - ], - [ - 55, - 43, - 9, - 36, - 40, - 5 - ] - ], - [ - [ - 18, - 9, - 1, - 36, - 61, - 44 - ], - [ - 56, - 34, - 19, - 42, - 3, - 5 - ], - [ - 39, - 20, - 15, - 60, - 46, - 32 - ], - [ - 60, - 22, - 31, - 27, - 14, - 19 - ], - [ - 59, - 58, - 10, - 7, - 46, - 18 - ], - [ - 43, - 2, - 57, - 62, - 11, - 30 - ], - [ - 54, - 19, - 9, - 21, - 48, - 56 - ], - [ - 46, - 24, - 7, - 14, - 3, - 8 - ], - [ - 47, - 0, - 4, - 18, - 31, - 29 - ], - [ - 54, - 62, - 47, - 38, - 4, - 32 - ], - [ - 1, - 14, - 15, - 22, - 59, - 38 - ], - [ - 16, - 36, - 42, - 55, - 15, - 18 - ], - [ - 49, - 8, - 20, - 14, - 0, - 33 - ], - [ - 18, - 39, - 25, - 2, - 62, - 22 - ], - [ - 62, - 5, - 58, - 37, - 7, - 32 - ], - [ - 43, - 5, - 42, - 63, - 55, - 37 - ], - [ - 47, - 33, - 15, - 63, - 50, - 12 - ], - [ - 60, - 0, - 7, - 16, - 32, - 13 - ], - [ - 12, - 39, - 32, - 61, - 16, - 45 - ], - [ - 52, - 34, - 15, - 62, - 18, - 30 - ], - [ - 28, - 26, - 46, - 40, - 6, - 14 - ], - [ - 1, - 19, - 17, - 20, - 4, - 21 - ], - [ - 41, - 40, - 4, - 53, - 55, - 19 - ], - [ - 25, - 38, - 27, - 34, - 52, - 46 - ], - [ - 11, - 29, - 52, - 44, - 53, - 13 - ], - [ - 50, - 51, - 41, - 16, - 4, - 15 - ], - [ - 19, - 6, - 23, - 36, - 60, - 0 - ] - ], - [ - [ - 17, - 10, - 57, - 27, - 5, - 54 - ], - [ - 33, - 9, - 43, - 40, - 56, - 11 - ], - 
[ - 63, - 1, - 35, - 43, - 10, - 27 - ], - [ - 51, - 47, - 20, - 21, - 28, - 61 - ], - [ - 25, - 11, - 58, - 23, - 55, - 46 - ], - [ - 43, - 10, - 12, - 2, - 62, - 30 - ], - [ - 48, - 19, - 21, - 8, - 7, - 54 - ], - [ - 14, - 7, - 24, - 8, - 46, - 2 - ], - [ - 4, - 47, - 37, - 0, - 44, - 27 - ], - [ - 54, - 38, - 62, - 47, - 15, - 14 - ], - [ - 1, - 46, - 15, - 22, - 51, - 38 - ], - [ - 36, - 16, - 42, - 55, - 24, - 37 - ], - [ - 49, - 10, - 0, - 3, - 43, - 8 - ], - [ - 39, - 58, - 0, - 62, - 22, - 25 - ], - [ - 58, - 38, - 7, - 55, - 62, - 56 - ], - [ - 19, - 42, - 55, - 43, - 11, - 37 - ], - [ - 9, - 47, - 43, - 52, - 18, - 50 - ], - [ - 31, - 41, - 32, - 25, - 20, - 13 - ], - [ - 12, - 32, - 61, - 3, - 21, - 43 - ], - [ - 36, - 13, - 40, - 7, - 62, - 16 - ], - [ - 14, - 53, - 50, - 47, - 51, - 1 - ], - [ - 1, - 38, - 19, - 18, - 30, - 16 - ], - [ - 0, - 19, - 51, - 18, - 52, - 15 - ], - [ - 8, - 52, - 27, - 34, - 38, - 3 - ], - [ - 27, - 53, - 59, - 9, - 40, - 4 - ], - [ - 37, - 3, - 26, - 48, - 8, - 16 - ], - [ - 46, - 18, - 11, - 40, - 33, - 44 - ] - ], - [ - [ - 48, - 62, - 61, - 50, - 26, - 59 - ], - [ - 3, - 45, - 40, - 35, - 29, - 54 - ], - [ - 56, - 31, - 23, - 28, - 2, - 53 - ], - [ - 62, - 49, - 20, - 61, - 6, - 41 - ], - [ - 18, - 25, - 50, - 0, - 14, - 57 - ], - [ - 58, - 4, - 10, - 43, - 56, - 20 - ], - [ - 35, - 15, - 25, - 24, - 3, - 7 - ], - [ - 14, - 23, - 8, - 12, - 57, - 24 - ], - [ - 29, - 17, - 35, - 44, - 24, - 27 - ], - [ - 62, - 15, - 38, - 20, - 58, - 21 - ], - [ - 19, - 46, - 1, - 26, - 63, - 22 - ], - [ - 36, - 60, - 16, - 42, - 55, - 11 - ], - [ - 17, - 7, - 14, - 26, - 16, - 49 - ], - [ - 45, - 47, - 22, - 0, - 62, - 58 - ], - [ - 58, - 38, - 48, - 49, - 63, - 2 - ], - [ - 55, - 0, - 1, - 37, - 30, - 10 - ], - [ - 12, - 43, - 21, - 9, - 47, - 23 - ], - [ - 32, - 57, - 42, - 25, - 43, - 63 - ], - [ - 3, - 32, - 49, - 61, - 21, - 12 - ], - [ - 5, - 36, - 22, - 16, - 62, - 42 - ], - [ - 53, - 7, - 46, - 61, - 14, - 52 - ], - [ - 55, - 30, - 3, - 5, - 53, - 31 - ], - [ - 0, - 44, - 15, - 18, - 19, - 28 - ], - [ - 8, - 52, - 51, - 11, - 4, - 29 - ], - [ - 27, - 40, - 4, - 9, - 35, - 39 - ], - [ - 14, - 26, - 3, - 48, - 16, - 21 - ], - [ - 60, - 54, - 35, - 20, - 53, - 12 - ] - ], - [ - [ - 21, - 50, - 29, - 41, - 34, - 60 - ], - [ - 28, - 51, - 60, - 33, - 14, - 45 - ], - [ - 31, - 2, - 46, - 33, - 24, - 49 - ], - [ - 42, - 3, - 18, - 62, - 39, - 49 - ], - [ - 25, - 15, - 62, - 27, - 12, - 11 - ], - [ - 44, - 50, - 36, - 57, - 55, - 41 - ], - [ - 41, - 37, - 22, - 15, - 2, - 40 - ], - [ - 36, - 62, - 53, - 30, - 14, - 57 - ], - [ - 16, - 58, - 2, - 29, - 4, - 3 - ], - [ - 41, - 38, - 26, - 16, - 45, - 46 - ], - [ - 45, - 46, - 32, - 41, - 56, - 26 - ], - [ - 17, - 53, - 21, - 11, - 36, - 35 - ], - [ - 11, - 16, - 28, - 14, - 51, - 61 - ], - [ - 9, - 35, - 33, - 22, - 52, - 62 - ], - [ - 58, - 50, - 63, - 30, - 7, - 27 - ], - [ - 55, - 3, - 8, - 41, - 63, - 37 - ], - [ - 3, - 51, - 46, - 32, - 15, - 6 - ], - [ - 32, - 12, - 10, - 25, - 5, - 49 - ], - [ - 34, - 2, - 37, - 61, - 39, - 63 - ], - [ - 42, - 22, - 27, - 53, - 11, - 56 - ], - [ - 53, - 12, - 0, - 47, - 61, - 1 - ], - [ - 39, - 45, - 53, - 17, - 48, - 14 - ], - [ - 6, - 0, - 4, - 53, - 25, - 11 - ], - [ - 51, - 11, - 1, - 63, - 54, - 45 - ], - [ - 40, - 56, - 37, - 53, - 5, - 35 - ], - [ - 59, - 28, - 41, - 10, - 1, - 45 - ], - [ - 27, - 30, - 28, - 24, - 32, - 57 - ] - ], - [ - [ - 24, - 56, - 6, - 0, - 19, - 45 - ], - [ - 11, - 57, - 59, - 25, - 46, - 30 - ], - [ - 11, - 26, - 37, - 29, - 14, - 52 - ], - 
[ - 3, - 32, - 7, - 38, - 36, - 24 - ], - [ - 61, - 2, - 24, - 14, - 51, - 44 - ], - [ - 20, - 47, - 0, - 63, - 30, - 58 - ], - [ - 4, - 36, - 29, - 58, - 16, - 3 - ], - [ - 20, - 0, - 45, - 14, - 28, - 44 - ], - [ - 29, - 56, - 47, - 35, - 16, - 4 - ], - [ - 33, - 61, - 55, - 41, - 51, - 38 - ], - [ - 58, - 1, - 38, - 14, - 4, - 19 - ], - [ - 0, - 36, - 14, - 18, - 52, - 42 - ], - [ - 29, - 36, - 45, - 25, - 8, - 6 - ], - [ - 6, - 57, - 50, - 40, - 58, - 61 - ], - [ - 44, - 58, - 29, - 19, - 61, - 56 - ], - [ - 23, - 18, - 28, - 55, - 5, - 37 - ], - [ - 13, - 9, - 19, - 43, - 37, - 3 - ], - [ - 32, - 22, - 63, - 14, - 57, - 41 - ], - [ - 10, - 61, - 3, - 1, - 19, - 32 - ], - [ - 16, - 55, - 10, - 41, - 59, - 22 - ], - [ - 53, - 7, - 29, - 38, - 27, - 46 - ], - [ - 24, - 47, - 18, - 53, - 39, - 30 - ], - [ - 0, - 33, - 19, - 5, - 51, - 17 - ], - [ - 51, - 8, - 11, - 45, - 44, - 41 - ], - [ - 40, - 4, - 23, - 11, - 27, - 19 - ], - [ - 16, - 18, - 3, - 48, - 51, - 21 - ], - [ - 43, - 46, - 60, - 19, - 53, - 12 - ] - ], - [ - [ - 48, - 62, - 61, - 30, - 50, - 52 - ], - [ - 45, - 3, - 35, - 29, - 54, - 2 - ], - [ - 56, - 31, - 53, - 23, - 49, - 28 - ], - [ - 60, - 57, - 14, - 46, - 41, - 48 - ], - [ - 18, - 61, - 59, - 14, - 44, - 32 - ], - [ - 45, - 58, - 47, - 20, - 4, - 30 - ], - [ - 54, - 13, - 25, - 36, - 26, - 47 - ], - [ - 20, - 12, - 0, - 47, - 30, - 45 - ], - [ - 56, - 29, - 47, - 17, - 35, - 16 - ], - [ - 33, - 61, - 55, - 11, - 38, - 48 - ], - [ - 58, - 19, - 14, - 1, - 38, - 36 - ], - [ - 14, - 36, - 0, - 60, - 11, - 52 - ], - [ - 29, - 44, - 7, - 36, - 16, - 45 - ], - [ - 6, - 47, - 50, - 33, - 42, - 62 - ], - [ - 44, - 58, - 61, - 38, - 29, - 56 - ], - [ - 23, - 55, - 18, - 0, - 57, - 37 - ], - [ - 9, - 12, - 43, - 19, - 13, - 6 - ], - [ - 32, - 22, - 63, - 57, - 42, - 29 - ], - [ - 3, - 61, - 1, - 10, - 49, - 32 - ], - [ - 5, - 16, - 55, - 36, - 22, - 59 - ], - [ - 53, - 7, - 46, - 29, - 9, - 14 - ], - [ - 24, - 30, - 18, - 39, - 55, - 53 - ], - [ - 33, - 0, - 19, - 44, - 51, - 5 - ], - [ - 51, - 8, - 53, - 41, - 4, - 11 - ], - [ - 40, - 4, - 27, - 19, - 23, - 16 - ], - [ - 16, - 14, - 48, - 3, - 21, - 26 - ], - [ - 60, - 54, - 35, - 53, - 12, - 43 - ] - ], - [ - [ - 19, - 41, - 8, - 7, - 13, - 2 - ], - [ - 48, - 46, - 62, - 29, - 5, - 41 - ], - [ - 12, - 5, - 59, - 3, - 58, - 49 - ], - [ - 60, - 3, - 42, - 39, - 14, - 18 - ], - [ - 42, - 12, - 27, - 11, - 25, - 19 - ], - [ - 50, - 36, - 44, - 26, - 33, - 37 - ], - [ - 41, - 54, - 22, - 52, - 37, - 35 - ], - [ - 62, - 30, - 36, - 53, - 10, - 14 - ], - [ - 2, - 16, - 58, - 29, - 7, - 41 - ], - [ - 18, - 32, - 45, - 16, - 22, - 38 - ], - [ - 45, - 56, - 41, - 10, - 3, - 46 - ], - [ - 50, - 21, - 36, - 35, - 53, - 12 - ], - [ - 11, - 28, - 16, - 41, - 39, - 46 - ], - [ - 16, - 9, - 33, - 38, - 28, - 19 - ], - [ - 58, - 50, - 63, - 62, - 27, - 52 - ], - [ - 55, - 59, - 13, - 8, - 43, - 3 - ], - [ - 3, - 51, - 15, - 46, - 47, - 57 - ], - [ - 32, - 30, - 12, - 10, - 25, - 18 - ], - [ - 34, - 2, - 61, - 27, - 53, - 59 - ], - [ - 42, - 22, - 56, - 53, - 44, - 34 - ], - [ - 53, - 12, - 49, - 41, - 44, - 8 - ], - [ - 45, - 1, - 48, - 47, - 16, - 17 - ], - [ - 4, - 0, - 53, - 25, - 24, - 11 - ], - [ - 1, - 11, - 44, - 45, - 34, - 51 - ], - [ - 40, - 5, - 53, - 6, - 22, - 18 - ], - [ - 28, - 10, - 1, - 3, - 15, - 41 - ], - [ - 30, - 27, - 24, - 57, - 32, - 16 - ] - ], - [ - [ - 24, - 56, - 6, - 0, - 19, - 45 - ], - [ - 11, - 57, - 59, - 46, - 25, - 30 - ], - [ - 26, - 11, - 37, - 14, - 29, - 49 - ], - [ - 38, - 36, - 3, - 24, - 18, - 20 - 
], - [ - 61, - 51, - 14, - 2, - 24, - 1 - ], - [ - 20, - 0, - 47, - 30, - 8, - 35 - ], - [ - 4, - 58, - 36, - 54, - 29, - 12 - ], - [ - 20, - 58, - 44, - 28, - 45, - 9 - ], - [ - 56, - 47, - 10, - 29, - 35, - 27 - ], - [ - 61, - 33, - 55, - 54, - 4, - 36 - ], - [ - 58, - 1, - 14, - 4, - 38, - 52 - ], - [ - 14, - 0, - 36, - 63, - 15, - 52 - ], - [ - 29, - 36, - 44, - 8, - 16, - 2 - ], - [ - 6, - 40, - 27, - 57, - 50, - 42 - ], - [ - 19, - 44, - 58, - 61, - 37, - 38 - ], - [ - 23, - 18, - 17, - 57, - 13, - 40 - ], - [ - 13, - 9, - 19, - 37, - 50, - 15 - ], - [ - 32, - 14, - 57, - 58, - 29, - 22 - ], - [ - 61, - 10, - 1, - 3, - 14, - 59 - ], - [ - 55, - 16, - 34, - 18, - 22, - 49 - ], - [ - 53, - 27, - 38, - 28, - 23, - 44 - ], - [ - 24, - 47, - 18, - 62, - 41, - 30 - ], - [ - 33, - 51, - 19, - 5, - 0, - 31 - ], - [ - 51, - 8, - 25, - 53, - 27, - 16 - ], - [ - 40, - 11, - 27, - 4, - 23, - 19 - ], - [ - 16, - 18, - 51, - 48, - 3, - 47 - ], - [ - 46, - 43, - 36, - 9, - 5, - 12 - ] - ], - [ - [ - 37, - 10, - 46, - 60, - 61, - 59 - ], - [ - 35, - 53, - 34, - 43, - 19, - 57 - ], - [ - 49, - 56, - 45, - 30, - 6, - 12 - ], - [ - 60, - 27, - 14, - 48, - 46, - 57 - ], - [ - 61, - 59, - 14, - 41, - 16, - 1 - ], - [ - 45, - 4, - 3, - 58, - 24, - 47 - ], - [ - 54, - 13, - 9, - 43, - 16, - 26 - ], - [ - 47, - 23, - 12, - 20, - 63, - 30 - ], - [ - 23, - 44, - 56, - 29, - 47, - 17 - ], - [ - 33, - 60, - 61, - 48, - 41, - 14 - ], - [ - 58, - 63, - 19, - 11, - 9, - 38 - ], - [ - 60, - 63, - 0, - 36, - 15, - 9 - ], - [ - 29, - 36, - 30, - 59, - 11, - 27 - ], - [ - 6, - 7, - 47, - 62, - 50, - 57 - ], - [ - 27, - 58, - 19, - 46, - 29, - 56 - ], - [ - 29, - 60, - 56, - 55, - 23, - 26 - ], - [ - 53, - 59, - 6, - 9, - 16, - 43 - ], - [ - 41, - 32, - 57, - 63, - 18, - 37 - ], - [ - 42, - 61, - 3, - 10, - 34, - 59 - ], - [ - 4, - 43, - 17, - 16, - 52, - 60 - ], - [ - 45, - 53, - 61, - 56, - 16, - 7 - ], - [ - 55, - 9, - 18, - 61, - 45, - 3 - ], - [ - 60, - 47, - 53, - 33, - 12, - 27 - ], - [ - 43, - 51, - 11, - 8, - 45, - 63 - ], - [ - 6, - 40, - 15, - 27, - 26, - 23 - ], - [ - 49, - 14, - 9, - 21, - 58, - 12 - ], - [ - 63, - 24, - 60, - 31, - 12, - 34 - ] - ], - [ - [ - 16, - 13, - 4, - 44, - 23, - 46 - ], - [ - 16, - 50, - 9, - 13, - 23, - 36 - ], - [ - 11, - 35, - 21, - 7, - 59, - 8 - ], - [ - 1, - 3, - 25, - 15, - 60, - 39 - ], - [ - 54, - 61, - 31, - 35, - 55, - 1 - ], - [ - 51, - 52, - 46, - 15, - 4, - 45 - ], - [ - 60, - 54, - 59, - 44, - 10, - 7 - ], - [ - 12, - 22, - 14, - 47, - 0, - 30 - ], - [ - 42, - 29, - 23, - 56, - 47, - 33 - ], - [ - 33, - 61, - 20, - 60, - 0, - 49 - ], - [ - 35, - 58, - 63, - 14, - 51, - 24 - ], - [ - 29, - 33, - 36, - 60, - 0, - 49 - ], - [ - 29, - 17, - 30, - 12, - 31, - 36 - ], - [ - 6, - 0, - 61, - 50, - 48, - 3 - ], - [ - 8, - 6, - 58, - 37, - 29, - 19 - ], - [ - 60, - 39, - 27, - 19, - 1, - 57 - ], - [ - 56, - 9, - 30, - 6, - 10, - 43 - ], - [ - 32, - 20, - 13, - 57, - 63, - 49 - ], - [ - 8, - 42, - 21, - 61, - 37, - 4 - ], - [ - 45, - 49, - 16, - 13, - 2, - 58 - ], - [ - 53, - 29, - 14, - 50, - 61, - 3 - ], - [ - 37, - 57, - 27, - 54, - 46, - 9 - ], - [ - 52, - 19, - 22, - 0, - 18, - 5 - ], - [ - 14, - 49, - 30, - 33, - 53, - 34 - ], - [ - 13, - 9, - 4, - 40, - 23, - 39 - ], - [ - 27, - 43, - 47, - 36, - 49, - 3 - ], - [ - 59, - 43, - 40, - 28, - 0, - 33 - ] - ], - [ - [ - 48, - 42, - 63, - 50, - 34, - 38 - ], - [ - 3, - 40, - 61, - 62, - 6, - 2 - ], - [ - 39, - 7, - 36, - 6, - 45, - 40 - ], - [ - 41, - 35, - 46, - 13, - 63, - 56 - ], - [ - 6, - 1, - 54, - 37, - 38, - 34 - ], 
- [ - 59, - 46, - 51, - 31, - 4, - 52 - ], - [ - 60, - 44, - 11, - 54, - 4, - 24 - ], - [ - 12, - 0, - 2, - 63, - 50, - 47 - ], - [ - 33, - 42, - 29, - 23, - 16, - 56 - ], - [ - 20, - 61, - 33, - 60, - 53, - 0 - ], - [ - 35, - 58, - 63, - 9, - 8, - 19 - ], - [ - 59, - 48, - 36, - 60, - 10, - 14 - ], - [ - 29, - 44, - 7, - 17, - 36, - 12 - ], - [ - 47, - 27, - 6, - 62, - 42, - 48 - ], - [ - 18, - 58, - 49, - 46, - 42, - 44 - ], - [ - 60, - 34, - 27, - 18, - 23, - 55 - ], - [ - 43, - 40, - 9, - 50, - 18, - 45 - ], - [ - 32, - 57, - 48, - 42, - 29, - 39 - ], - [ - 42, - 61, - 49, - 3, - 32, - 1 - ], - [ - 23, - 37, - 1, - 16, - 36, - 39 - ], - [ - 53, - 21, - 7, - 61, - 50, - 31 - ], - [ - 8, - 60, - 18, - 24, - 9, - 30 - ], - [ - 51, - 33, - 28, - 5, - 44, - 8 - ], - [ - 51, - 52, - 8, - 4, - 41, - 45 - ], - [ - 40, - 4, - 27, - 9, - 60, - 19 - ], - [ - 3, - 61, - 16, - 26, - 48, - 12 - ], - [ - 54, - 61, - 35, - 1, - 53, - 43 - ] - ], - [ - [ - 62, - 28, - 1, - 42, - 8, - 55 - ], - [ - 18, - 12, - 8, - 41, - 40, - 31 - ], - [ - 12, - 6, - 50, - 4, - 23, - 45 - ], - [ - 43, - 35, - 8, - 20, - 42, - 46 - ], - [ - 39, - 41, - 29, - 22, - 3, - 56 - ], - [ - 61, - 45, - 46, - 48, - 28, - 51 - ], - [ - 44, - 4, - 11, - 25, - 54, - 59 - ], - [ - 12, - 33, - 56, - 52, - 30, - 17 - ], - [ - 55, - 29, - 17, - 42, - 23, - 14 - ], - [ - 60, - 12, - 18, - 61, - 33, - 28 - ], - [ - 35, - 58, - 37, - 63, - 6, - 27 - ], - [ - 48, - 59, - 10, - 36, - 58, - 60 - ], - [ - 17, - 7, - 28, - 31, - 29, - 27 - ], - [ - 47, - 42, - 50, - 6, - 8, - 14 - ], - [ - 39, - 58, - 56, - 37, - 18, - 59 - ], - [ - 60, - 18, - 57, - 9, - 55, - 23 - ], - [ - 43, - 63, - 18, - 60, - 19, - 22 - ], - [ - 1, - 32, - 42, - 57, - 35, - 63 - ], - [ - 42, - 61, - 3, - 32, - 1, - 50 - ], - [ - 37, - 36, - 10, - 23, - 16, - 57 - ], - [ - 53, - 61, - 7, - 57, - 21, - 23 - ], - [ - 9, - 39, - 30, - 18, - 14, - 17 - ], - [ - 33, - 44, - 8, - 5, - 0, - 19 - ], - [ - 51, - 53, - 49, - 4, - 52, - 41 - ], - [ - 40, - 4, - 27, - 6, - 9, - 16 - ], - [ - 3, - 16, - 48, - 26, - 12, - 4 - ], - [ - 61, - 14, - 12, - 54, - 35, - 53 - ] - ], - [ - [ - 47, - 23, - 63, - 11, - 61, - 55 - ], - [ - 17, - 44, - 28, - 39, - 47, - 27 - ], - [ - 34, - 53, - 50, - 38, - 29, - 5 - ], - [ - 11, - 10, - 17, - 52, - 47, - 42 - ], - [ - 15, - 41, - 27, - 20, - 12, - 6 - ], - [ - 34, - 44, - 50, - 39, - 36, - 61 - ], - [ - 37, - 41, - 52, - 29, - 46, - 47 - ], - [ - 62, - 36, - 34, - 30, - 39, - 22 - ], - [ - 62, - 16, - 58, - 5, - 2, - 8 - ], - [ - 32, - 41, - 56, - 12, - 46, - 8 - ], - [ - 10, - 35, - 45, - 41, - 3, - 56 - ], - [ - 50, - 48, - 35, - 53, - 36, - 12 - ], - [ - 39, - 11, - 46, - 7, - 23, - 51 - ], - [ - 9, - 47, - 19, - 22, - 52, - 34 - ], - [ - 35, - 18, - 56, - 50, - 3, - 23 - ], - [ - 3, - 60, - 38, - 36, - 9, - 35 - ], - [ - 46, - 28, - 32, - 5, - 43, - 56 - ], - [ - 30, - 32, - 57, - 42, - 52, - 19 - ], - [ - 2, - 32, - 34, - 61, - 14, - 42 - ], - [ - 42, - 37, - 20, - 50, - 9, - 48 - ], - [ - 53, - 7, - 56, - 25, - 60, - 13 - ], - [ - 17, - 39, - 14, - 53, - 30, - 25 - ], - [ - 5, - 40, - 6, - 33, - 29, - 25 - ], - [ - 51, - 4, - 11, - 58, - 57, - 28 - ], - [ - 40, - 37, - 4, - 44, - 8, - 48 - ], - [ - 3, - 48, - 26, - 9, - 12, - 41 - ], - [ - 47, - 61, - 26, - 24, - 20, - 53 - ] - ], - [ - [ - 31, - 43, - 41, - 47, - 11, - 25 - ], - [ - 50, - 25, - 31, - 40, - 24, - 46 - ], - [ - 23, - 9, - 62, - 15, - 20, - 53 - ], - [ - 4, - 47, - 44, - 58, - 48, - 25 - ], - [ - 2, - 19, - 12, - 52, - 0, - 40 - ], - [ - 49, - 15, - 24, - 34, - 60, - 42 - ], 
- [ - 12, - 46, - 17, - 29, - 41, - 3 - ], - [ - 39, - 60, - 44, - 41, - 33, - 36 - ], - [ - 21, - 60, - 16, - 44, - 51, - 57 - ], - [ - 24, - 41, - 12, - 33, - 13, - 21 - ], - [ - 43, - 62, - 3, - 12, - 28, - 45 - ], - [ - 58, - 19, - 39, - 17, - 49, - 42 - ], - [ - 25, - 54, - 4, - 7, - 11, - 39 - ], - [ - 34, - 35, - 4, - 42, - 62, - 19 - ], - [ - 43, - 41, - 42, - 35, - 40, - 32 - ], - [ - 21, - 63, - 3, - 17, - 20, - 50 - ], - [ - 58, - 46, - 44, - 1, - 25, - 20 - ], - [ - 26, - 32, - 16, - 25, - 46, - 41 - ], - [ - 37, - 63, - 61, - 28, - 24, - 56 - ], - [ - 33, - 42, - 40, - 37, - 48, - 50 - ], - [ - 11, - 53, - 25, - 39, - 4, - 61 - ], - [ - 12, - 54, - 4, - 27, - 50, - 14 - ], - [ - 47, - 19, - 42, - 17, - 35, - 40 - ], - [ - 54, - 40, - 60, - 63, - 45, - 57 - ], - [ - 44, - 56, - 40, - 62, - 37, - 3 - ], - [ - 59, - 41, - 57, - 34, - 48, - 22 - ], - [ - 56, - 13, - 59, - 51, - 26, - 58 - ] - ], - [ - [ - 35, - 32, - 8, - 40, - 51, - 52 - ], - [ - 52, - 5, - 22, - 21, - 6, - 33 - ], - [ - 22, - 58, - 11, - 25, - 3, - 51 - ], - [ - 63, - 2, - 56, - 4, - 23, - 54 - ], - [ - 39, - 12, - 23, - 32, - 30, - 46 - ], - [ - 50, - 34, - 36, - 58, - 26, - 28 - ], - [ - 46, - 41, - 3, - 2, - 22, - 16 - ], - [ - 60, - 36, - 53, - 30, - 54, - 39 - ], - [ - 16, - 51, - 3, - 39, - 2, - 26 - ], - [ - 45, - 26, - 18, - 41, - 32, - 46 - ], - [ - 45, - 3, - 10, - 56, - 36, - 35 - ], - [ - 21, - 36, - 35, - 50, - 11, - 19 - ], - [ - 28, - 11, - 46, - 59, - 41, - 15 - ], - [ - 23, - 16, - 38, - 19, - 15, - 22 - ], - [ - 27, - 7, - 34, - 58, - 3, - 42 - ], - [ - 9, - 22, - 36, - 46, - 26, - 41 - ], - [ - 3, - 51, - 40, - 56, - 46, - 8 - ], - [ - 12, - 25, - 21, - 50, - 17, - 62 - ], - [ - 27, - 34, - 61, - 13, - 60, - 11 - ], - [ - 53, - 12, - 56, - 0, - 42, - 33 - ], - [ - 53, - 37, - 12, - 24, - 25, - 63 - ], - [ - 45, - 55, - 18, - 26, - 17, - 43 - ], - [ - 4, - 25, - 32, - 1, - 48, - 53 - ], - [ - 17, - 27, - 63, - 4, - 62, - 31 - ], - [ - 6, - 52, - 62, - 40, - 46, - 23 - ], - [ - 10, - 42, - 28, - 49, - 3, - 53 - ], - [ - 45, - 27, - 41, - 21, - 16, - 47 - ] - ], - [ - [ - 44, - 24, - 33, - 56, - 15, - 41 - ], - [ - 38, - 26, - 24, - 29, - 19, - 53 - ], - [ - 12, - 15, - 29, - 9, - 1, - 63 - ], - [ - 58, - 38, - 50, - 0, - 43, - 61 - ], - [ - 24, - 51, - 31, - 34, - 60, - 7 - ], - [ - 0, - 7, - 22, - 43, - 35, - 1 - ], - [ - 63, - 36, - 11, - 1, - 16, - 4 - ], - [ - 8, - 50, - 56, - 4, - 30, - 55 - ], - [ - 43, - 16, - 42, - 29, - 60, - 35 - ], - [ - 34, - 0, - 9, - 22, - 18, - 26 - ], - [ - 54, - 51, - 45, - 35, - 2, - 36 - ], - [ - 37, - 36, - 43, - 60, - 11, - 59 - ], - [ - 56, - 38, - 10, - 28, - 14, - 43 - ], - [ - 30, - 0, - 58, - 62, - 22, - 19 - ], - [ - 7, - 55, - 42, - 58, - 30, - 38 - ], - [ - 11, - 33, - 1, - 39, - 19, - 16 - ], - [ - 55, - 20, - 40, - 9, - 18, - 30 - ], - [ - 18, - 20, - 57, - 32, - 45, - 1 - ], - [ - 43, - 61, - 12, - 32, - 31, - 30 - ], - [ - 23, - 25, - 7, - 28, - 40, - 19 - ], - [ - 14, - 51, - 48, - 58, - 53, - 25 - ], - [ - 18, - 30, - 1, - 49, - 41, - 9 - ], - [ - 2, - 51, - 22, - 0, - 52, - 5 - ], - [ - 53, - 4, - 47, - 52, - 51, - 40 - ], - [ - 40, - 16, - 9, - 47, - 23, - 11 - ], - [ - 47, - 3, - 43, - 46, - 26, - 53 - ], - [ - 8, - 40, - 18, - 46, - 33, - 63 - ] - ], - [ - [ - 48, - 38, - 50, - 42, - 63, - 36 - ], - [ - 3, - 10, - 26, - 2, - 6, - 61 - ], - [ - 39, - 44, - 45, - 40, - 6, - 7 - ], - [ - 41, - 5, - 20, - 49, - 56, - 13 - ], - [ - 6, - 1, - 30, - 37, - 28, - 38 - ], - [ - 59, - 46, - 22, - 35, - 61, - 0 - ], - [ - 1, - 63, - 35, - 3, - 60, - 49 
-  ... (remainder of the removed fixture: nested arrays of six integers each, values 0-63) ...
55 - ], - [ - 56, - 13, - 31, - 36, - 23, - 47 - ], - [ - 36, - 51, - 30, - 7, - 26, - 54 - ], - [ - 58, - 13, - 50, - 2, - 53, - 34 - ], - [ - 49, - 52, - 23, - 32, - 7, - 26 - ], - [ - 61, - 38, - 23, - 0, - 4, - 28 - ], - [ - 42, - 27, - 17, - 9, - 18, - 20 - ], - [ - 1, - 34, - 45, - 4, - 12, - 7 - ], - [ - 54, - 27, - 57, - 38, - 44, - 19 - ], - [ - 59, - 40, - 1, - 48, - 30, - 60 - ], - [ - 25, - 31, - 32, - 51, - 62, - 8 - ], - [ - 62, - 19, - 14, - 2, - 37, - 26 - ], - [ - 40, - 57, - 37, - 35, - 22, - 61 - ], - [ - 45, - 16, - 34, - 42, - 37, - 48 - ], - [ - 50, - 16, - 62, - 33, - 25, - 37 - ], - [ - 42, - 13, - 39, - 47, - 3, - 63 - ], - [ - 46, - 23, - 28, - 27, - 4, - 15 - ], - [ - 56, - 62, - 31, - 35, - 59, - 45 - ], - [ - 15, - 38, - 13, - 63, - 4, - 48 - ], - [ - 34, - 15, - 57, - 38, - 13, - 24 - ], - [ - 62, - 36, - 41, - 54, - 46, - 29 - ], - [ - 22, - 53, - 46, - 34, - 30, - 23 - ], - [ - 24, - 10, - 4, - 47, - 18, - 36 - ], - [ - 50, - 57, - 51, - 11, - 49, - 3 - ], - [ - 51, - 7, - 18, - 0, - 11, - 44 - ], - [ - 39, - 37, - 9, - 42, - 40, - 44 - ] - ], - [ - [ - 36, - 25, - 57, - 55, - 47, - 63 - ], - [ - 0, - 2, - 46, - 3, - 51, - 34 - ], - [ - 24, - 2, - 46, - 15, - 33, - 43 - ], - [ - 22, - 31, - 17, - 19, - 10, - 55 - ], - [ - 58, - 59, - 3, - 9, - 40, - 57 - ], - [ - 23, - 31, - 43, - 2, - 57, - 38 - ], - [ - 9, - 62, - 13, - 42, - 52, - 47 - ], - [ - 24, - 7, - 14, - 10, - 46, - 59 - ], - [ - 47, - 18, - 4, - 37, - 0, - 13 - ], - [ - 54, - 4, - 25, - 47, - 36, - 38 - ], - [ - 15, - 47, - 1, - 24, - 58, - 14 - ], - [ - 54, - 5, - 16, - 63, - 14, - 7 - ], - [ - 49, - 3, - 33, - 13, - 46, - 10 - ], - [ - 18, - 10, - 11, - 13, - 63, - 39 - ], - [ - 19, - 62, - 32, - 58, - 10, - 43 - ], - [ - 43, - 48, - 63, - 5, - 55, - 53 - ], - [ - 24, - 51, - 47, - 15, - 59, - 32 - ], - [ - 44, - 0, - 34, - 43, - 3, - 6 - ], - [ - 58, - 38, - 54, - 47, - 11, - 4 - ], - [ - 55, - 27, - 15, - 9, - 42, - 31 - ], - [ - 43, - 41, - 1, - 51, - 5, - 29 - ], - [ - 49, - 27, - 20, - 6, - 4, - 13 - ], - [ - 11, - 25, - 2, - 54, - 27, - 50 - ], - [ - 38, - 44, - 40, - 54, - 33, - 14 - ], - [ - 13, - 63, - 52, - 2, - 29, - 8 - ], - [ - 23, - 41, - 59, - 57, - 38, - 15 - ], - [ - 23, - 6, - 62, - 50, - 51, - 34 - ] - ], - [ - [ - 41, - 2, - 42, - 16, - 50, - 23 - ], - [ - 51, - 41, - 5, - 15, - 40, - 21 - ], - [ - 43, - 1, - 29, - 55, - 21, - 35 - ], - [ - 24, - 53, - 25, - 51, - 32, - 29 - ], - [ - 41, - 31, - 49, - 57, - 60, - 34 - ], - [ - 17, - 4, - 35, - 30, - 10, - 38 - ], - [ - 34, - 7, - 21, - 9, - 48, - 31 - ], - [ - 14, - 24, - 7, - 46, - 25, - 27 - ], - [ - 47, - 0, - 12, - 6, - 37, - 60 - ], - [ - 30, - 4, - 25, - 47, - 36, - 54 - ], - [ - 24, - 61, - 15, - 47, - 46, - 1 - ], - [ - 5, - 14, - 24, - 16, - 57, - 63 - ], - [ - 49, - 3, - 17, - 26, - 36, - 44 - ], - [ - 13, - 10, - 61, - 0, - 11, - 22 - ], - [ - 10, - 6, - 2, - 49, - 58, - 46 - ], - [ - 19, - 11, - 2, - 25, - 54, - 18 - ], - [ - 10, - 5, - 52, - 24, - 18, - 17 - ], - [ - 44, - 13, - 55, - 31, - 63, - 38 - ], - [ - 58, - 7, - 25, - 32, - 38, - 14 - ], - [ - 49, - 55, - 34, - 13, - 16, - 40 - ], - [ - 51, - 23, - 59, - 35, - 5, - 4 - ], - [ - 57, - 15, - 46, - 27, - 42, - 32 - ], - [ - 33, - 15, - 23, - 52, - 24, - 27 - ], - [ - 4, - 41, - 33, - 10, - 26, - 40 - ], - [ - 8, - 13, - 59, - 4, - 9, - 39 - ], - [ - 19, - 3, - 31, - 27, - 43, - 2 - ], - [ - 46, - 61, - 25, - 8, - 29, - 50 - ] - ], - [ - [ - 48, - 38, - 42, - 50, - 62, - 63 - ], - [ - 3, - 2, - 10, - 26, - 17, - 6 - ], - [ - 39, - 44, - 6, - 45, - 7, - 40 - 
], - [ - 5, - 62, - 53, - 50, - 41, - 3 - ], - [ - 6, - 41, - 49, - 37, - 30, - 23 - ], - [ - 59, - 60, - 4, - 46, - 53, - 29 - ], - [ - 44, - 34, - 7, - 15, - 13, - 43 - ], - [ - 12, - 24, - 14, - 25, - 58, - 7 - ], - [ - 33, - 35, - 4, - 37, - 8, - 55 - ], - [ - 40, - 30, - 36, - 25, - 20, - 54 - ], - [ - 9, - 8, - 24, - 25, - 63, - 5 - ], - [ - 59, - 34, - 5, - 24, - 57, - 6 - ], - [ - 17, - 49, - 44, - 26, - 55, - 7 - ], - [ - 47, - 13, - 59, - 27, - 22, - 61 - ], - [ - 49, - 2, - 44, - 10, - 46, - 40 - ], - [ - 34, - 2, - 54, - 57, - 53, - 55 - ], - [ - 5, - 17, - 29, - 31, - 43, - 52 - ], - [ - 48, - 57, - 38, - 39, - 63, - 43 - ], - [ - 42, - 38, - 49, - 32, - 7, - 40 - ], - [ - 59, - 1, - 16, - 23, - 60, - 10 - ], - [ - 21, - 7, - 35, - 53, - 48, - 31 - ], - [ - 8, - 60, - 24, - 42, - 14, - 35 - ], - [ - 33, - 51, - 15, - 28, - 23, - 5 - ], - [ - 4, - 41, - 9, - 11, - 8, - 51 - ], - [ - 60, - 24, - 19, - 48, - 9, - 4 - ], - [ - 19, - 3, - 26, - 58, - 12, - 61 - ], - [ - 35, - 54, - 1, - 60, - 53, - 49 - ] - ], - [ - [ - 21, - 7, - 53, - 56, - 63, - 33 - ], - [ - 3, - 34, - 57, - 16, - 20, - 51 - ], - [ - 55, - 11, - 16, - 60, - 0, - 13 - ], - [ - 62, - 43, - 5, - 50, - 8, - 42 - ], - [ - 35, - 52, - 9, - 43, - 0, - 27 - ], - [ - 16, - 60, - 29, - 61, - 28, - 58 - ], - [ - 31, - 34, - 43, - 30, - 2, - 18 - ], - [ - 36, - 25, - 53, - 24, - 35, - 48 - ], - [ - 48, - 4, - 35, - 32, - 8, - 60 - ], - [ - 25, - 54, - 30, - 10, - 4, - 41 - ], - [ - 9, - 3, - 19, - 20, - 61, - 24 - ], - [ - 30, - 5, - 17, - 45, - 18, - 49 - ], - [ - 49, - 21, - 17, - 26, - 42, - 37 - ], - [ - 38, - 25, - 10, - 51, - 54, - 13 - ], - [ - 2, - 27, - 26, - 10, - 58, - 30 - ], - [ - 9, - 2, - 54, - 21, - 25, - 13 - ], - [ - 22, - 23, - 33, - 27, - 51, - 52 - ], - [ - 30, - 21, - 35, - 55, - 5, - 17 - ], - [ - 38, - 27, - 24, - 56, - 21, - 35 - ], - [ - 53, - 45, - 22, - 28, - 0, - 32 - ], - [ - 40, - 42, - 48, - 37, - 8, - 52 - ], - [ - 46, - 45, - 57, - 36, - 51, - 32 - ], - [ - 62, - 10, - 38, - 42, - 41, - 54 - ], - [ - 60, - 1, - 10, - 31, - 44, - 36 - ], - [ - 29, - 6, - 9, - 8, - 56, - 7 - ], - [ - 9, - 59, - 22, - 0, - 35, - 57 - ], - [ - 45, - 27, - 62, - 47, - 3, - 28 - ] - ], - [ - [ - 27, - 13, - 18, - 8, - 63, - 55 - ], - [ - 36, - 21, - 57, - 8, - 46, - 55 - ], - [ - 43, - 61, - 10, - 13, - 41, - 37 - ], - [ - 43, - 16, - 24, - 6, - 26, - 61 - ], - [ - 60, - 29, - 35, - 31, - 16, - 23 - ], - [ - 9, - 58, - 60, - 17, - 0, - 38 - ], - [ - 63, - 16, - 7, - 13, - 31, - 18 - ], - [ - 16, - 25, - 24, - 2, - 47, - 58 - ], - [ - 8, - 35, - 6, - 14, - 48, - 4 - ], - [ - 30, - 25, - 54, - 4, - 10, - 34 - ], - [ - 20, - 23, - 19, - 7, - 38, - 61 - ], - [ - 5, - 3, - 24, - 32, - 12, - 42 - ], - [ - 49, - 17, - 10, - 26, - 32, - 60 - ], - [ - 10, - 13, - 29, - 22, - 58, - 54 - ], - [ - 44, - 8, - 2, - 58, - 6, - 5 - ], - [ - 25, - 6, - 2, - 54, - 19, - 53 - ], - [ - 5, - 61, - 30, - 17, - 1, - 27 - ], - [ - 22, - 55, - 63, - 57, - 19, - 33 - ], - [ - 35, - 46, - 6, - 32, - 14, - 7 - ], - [ - 45, - 13, - 51, - 14, - 7, - 5 - ], - [ - 15, - 23, - 50, - 51, - 13, - 59 - ], - [ - 57, - 36, - 53, - 61, - 6, - 14 - ], - [ - 10, - 23, - 15, - 33, - 27, - 38 - ], - [ - 4, - 10, - 33, - 41, - 26, - 36 - ], - [ - 8, - 4, - 33, - 9, - 47, - 20 - ], - [ - 3, - 43, - 6, - 19, - 30, - 2 - ], - [ - 55, - 4, - 46, - 40, - 18, - 20 - ] - ], - [ - [ - 48, - 38, - 63, - 42, - 47, - 50 - ], - [ - 3, - 10, - 26, - 6, - 35, - 2 - ], - [ - 39, - 44, - 7, - 45, - 6, - 42 - ], - [ - 56, - 23, - 61, - 16, - 33, - 41 - ], - [ - 6, - 37, - 
44, - 16, - 30, - 19 - ], - [ - 59, - 17, - 60, - 19, - 5, - 52 - ], - [ - 44, - 55, - 33, - 63, - 13, - 31 - ], - [ - 12, - 25, - 63, - 2, - 24, - 54 - ], - [ - 33, - 8, - 42, - 4, - 35, - 41 - ], - [ - 30, - 40, - 25, - 21, - 48, - 51 - ], - [ - 8, - 23, - 34, - 5, - 19, - 22 - ], - [ - 34, - 59, - 42, - 11, - 5, - 3 - ], - [ - 10, - 26, - 50, - 55, - 53, - 59 - ], - [ - 59, - 54, - 13, - 10, - 47, - 25 - ], - [ - 24, - 2, - 49, - 46, - 38, - 8 - ], - [ - 34, - 2, - 53, - 54, - 40, - 57 - ], - [ - 17, - 29, - 5, - 27, - 18, - 43 - ], - [ - 48, - 57, - 63, - 39, - 38, - 32 - ], - [ - 42, - 49, - 7, - 32, - 38, - 61 - ], - [ - 1, - 59, - 23, - 14, - 16, - 10 - ], - [ - 21, - 7, - 35, - 53, - 13, - 11 - ], - [ - 8, - 60, - 24, - 25, - 55, - 42 - ], - [ - 33, - 5, - 15, - 51, - 28, - 44 - ], - [ - 4, - 9, - 51, - 8, - 41, - 39 - ], - [ - 60, - 4, - 19, - 48, - 9, - 24 - ], - [ - 19, - 3, - 26, - 12, - 53, - 21 - ], - [ - 53, - 35, - 1, - 54, - 60, - 20 - ] - ], - [ - [ - 17, - 37, - 31, - 32, - 63, - 50 - ], - [ - 12, - 2, - 9, - 32, - 47, - 17 - ], - [ - 3, - 57, - 56, - 50, - 33, - 38 - ], - [ - 43, - 42, - 19, - 52, - 8, - 17 - ], - [ - 61, - 39, - 27, - 12, - 15, - 57 - ], - [ - 6, - 33, - 36, - 44, - 29, - 61 - ], - [ - 2, - 41, - 42, - 15, - 52, - 5 - ], - [ - 36, - 37, - 15, - 53, - 18, - 62 - ], - [ - 50, - 58, - 3, - 5, - 16, - 4 - ], - [ - 16, - 22, - 55, - 32, - 41, - 26 - ], - [ - 3, - 41, - 56, - 45, - 38, - 34 - ], - [ - 21, - 53, - 50, - 34, - 38, - 35 - ], - [ - 35, - 11, - 37, - 16, - 53, - 42 - ], - [ - 9, - 15, - 54, - 38, - 12, - 52 - ], - [ - 56, - 2, - 24, - 23, - 9, - 5 - ], - [ - 2, - 44, - 9, - 53, - 35, - 54 - ], - [ - 27, - 46, - 1, - 5, - 60, - 50 - ], - [ - 51, - 61, - 50, - 54, - 33, - 44 - ], - [ - 2, - 38, - 9, - 19, - 5, - 32 - ], - [ - 32, - 28, - 6, - 15, - 0, - 33 - ], - [ - 26, - 53, - 27, - 13, - 5, - 4 - ], - [ - 34, - 40, - 45, - 55, - 62, - 53 - ], - [ - 41, - 5, - 46, - 4, - 45, - 3 - ], - [ - 20, - 4, - 22, - 47, - 59, - 58 - ], - [ - 62, - 36, - 25, - 28, - 53, - 33 - ], - [ - 22, - 33, - 7, - 3, - 54, - 0 - ], - [ - 58, - 27, - 52, - 48, - 45, - 17 - ] - ], - [ - [ - 3, - 43, - 17, - 42, - 35, - 55 - ], - [ - 0, - 42, - 36, - 34, - 32, - 40 - ], - [ - 43, - 12, - 19, - 20, - 2, - 4 - ], - [ - 4, - 28, - 54, - 22, - 58, - 23 - ], - [ - 12, - 60, - 52, - 57, - 46, - 7 - ], - [ - 43, - 2, - 1, - 63, - 20, - 11 - ], - [ - 42, - 19, - 57, - 31, - 15, - 21 - ], - [ - 35, - 6, - 46, - 25, - 36, - 33 - ], - [ - 11, - 4, - 35, - 50, - 32, - 52 - ], - [ - 21, - 25, - 38, - 47, - 31, - 58 - ], - [ - 3, - 2, - 22, - 45, - 42, - 15 - ], - [ - 2, - 21, - 5, - 7, - 43, - 38 - ], - [ - 33, - 10, - 41, - 42, - 60, - 16 - ], - [ - 44, - 56, - 30, - 54, - 9, - 43 - ], - [ - 9, - 5, - 33, - 28, - 58, - 49 - ], - [ - 8, - 53, - 54, - 59, - 31, - 2 - ], - [ - 47, - 0, - 48, - 27, - 24, - 51 - ], - [ - 0, - 60, - 21, - 12, - 56, - 28 - ], - [ - 38, - 13, - 53, - 51, - 9, - 7 - ], - [ - 19, - 28, - 56, - 17, - 21, - 26 - ], - [ - 12, - 43, - 26, - 4, - 0, - 5 - ], - [ - 43, - 1, - 55, - 12, - 4, - 21 - ], - [ - 41, - 5, - 21, - 45, - 25, - 12 - ], - [ - 4, - 20, - 1, - 22, - 30, - 42 - ], - [ - 5, - 62, - 25, - 54, - 48, - 47 - ], - [ - 10, - 28, - 55, - 23, - 0, - 53 - ], - [ - 30, - 57, - 8, - 29, - 17, - 5 - ] - ], - [ - [ - 11, - 16, - 31, - 0, - 46, - 35 - ], - [ - 13, - 49, - 31, - 50, - 16, - 34 - ], - [ - 36, - 13, - 53, - 27, - 4, - 18 - ], - [ - 24, - 7, - 29, - 26, - 12, - 32 - ], - [ - 17, - 35, - 2, - 44, - 10, - 48 - ], - [ - 8, - 7, - 23, - 5, - 51, - 26 - 
], - [ - 58, - 15, - 61, - 29, - 38, - 62 - ], - [ - 20, - 9, - 42, - 35, - 3, - 6 - ], - [ - 4, - 47, - 25, - 11, - 1, - 52 - ], - [ - 54, - 25, - 55, - 38, - 27, - 21 - ], - [ - 1, - 60, - 14, - 59, - 30, - 22 - ], - [ - 51, - 31, - 5, - 25, - 14, - 52 - ], - [ - 34, - 2, - 10, - 26, - 52, - 47 - ], - [ - 40, - 57, - 13, - 54, - 9, - 6 - ], - [ - 8, - 44, - 58, - 5, - 16, - 1 - ], - [ - 8, - 53, - 59, - 25, - 52, - 24 - ], - [ - 13, - 47, - 45, - 0, - 42, - 8 - ], - [ - 23, - 44, - 55, - 33, - 38, - 7 - ], - [ - 53, - 38, - 11, - 1, - 8, - 24 - ], - [ - 15, - 38, - 14, - 28, - 19, - 0 - ], - [ - 41, - 27, - 8, - 42, - 40, - 57 - ], - [ - 11, - 62, - 55, - 16, - 10, - 41 - ], - [ - 31, - 20, - 46, - 34, - 37, - 41 - ], - [ - 47, - 4, - 55, - 33, - 49, - 22 - ], - [ - 11, - 45, - 47, - 48, - 54, - 36 - ], - [ - 50, - 10, - 23, - 51, - 3, - 18 - ], - [ - 17, - 31, - 5, - 36, - 4, - 20 - ] - ], - [ - [ - 22, - 6, - 39, - 57, - 29, - 47 - ], - [ - 27, - 6, - 14, - 17, - 51, - 55 - ], - [ - 1, - 11, - 29, - 26, - 47, - 4 - ], - [ - 14, - 38, - 31, - 22, - 29, - 6 - ], - [ - 14, - 59, - 61, - 16, - 1, - 19 - ], - [ - 30, - 8, - 23, - 21, - 47, - 1 - ], - [ - 58, - 4, - 15, - 61, - 27, - 31 - ], - [ - 20, - 42, - 3, - 9, - 35, - 6 - ], - [ - 47, - 4, - 25, - 8, - 36, - 0 - ], - [ - 54, - 55, - 21, - 19, - 33, - 25 - ], - [ - 14, - 4, - 60, - 20, - 40, - 24 - ], - [ - 51, - 0, - 5, - 32, - 52, - 3 - ], - [ - 2, - 36, - 10, - 52, - 26, - 32 - ], - [ - 40, - 13, - 54, - 36, - 57, - 46 - ], - [ - 44, - 8, - 5, - 37, - 58, - 2 - ], - [ - 6, - 53, - 24, - 2, - 54, - 37 - ], - [ - 13, - 47, - 61, - 5, - 19, - 17 - ], - [ - 58, - 55, - 44, - 38, - 63, - 6 - ], - [ - 35, - 46, - 31, - 1, - 19, - 32 - ], - [ - 15, - 13, - 63, - 45, - 9, - 55 - ], - [ - 27, - 15, - 23, - 6, - 35, - 63 - ], - [ - 36, - 62, - 57, - 41, - 10, - 16 - ], - [ - 10, - 33, - 41, - 20, - 5, - 46 - ], - [ - 47, - 4, - 26, - 55, - 10, - 49 - ], - [ - 11, - 4, - 8, - 48, - 36, - 33 - ], - [ - 18, - 51, - 43, - 33, - 50, - 6 - ], - [ - 4, - 55, - 9, - 36, - 5, - 43 - ] - ], - [ - [ - 48, - 38, - 63, - 42, - 47, - 7 - ], - [ - 3, - 26, - 10, - 6, - 42, - 2 - ], - [ - 39, - 44, - 6, - 7, - 45, - 8 - ], - [ - 60, - 31, - 22, - 0, - 54, - 45 - ], - [ - 6, - 59, - 14, - 16, - 37, - 44 - ], - [ - 59, - 30, - 8, - 47, - 60, - 17 - ], - [ - 44, - 56, - 4, - 13, - 15, - 9 - ], - [ - 12, - 24, - 20, - 58, - 61, - 28 - ], - [ - 47, - 33, - 8, - 4, - 36, - 42 - ], - [ - 40, - 54, - 55, - 33, - 21, - 51 - ], - [ - 8, - 14, - 22, - 39, - 31, - 7 - ], - [ - 34, - 59, - 14, - 5, - 52, - 0 - ], - [ - 36, - 10, - 52, - 26, - 44, - 59 - ], - [ - 59, - 44, - 13, - 1, - 36, - 22 - ], - [ - 24, - 5, - 46, - 2, - 44, - 38 - ], - [ - 34, - 53, - 6, - 40, - 30, - 2 - ], - [ - 17, - 29, - 50, - 5, - 47, - 27 - ], - [ - 48, - 38, - 57, - 63, - 32, - 60 - ], - [ - 42, - 7, - 49, - 46, - 32, - 1 - ], - [ - 59, - 1, - 23, - 14, - 10, - 16 - ], - [ - 21, - 35, - 7, - 53, - 6, - 17 - ], - [ - 8, - 24, - 60, - 13, - 55, - 14 - ], - [ - 33, - 5, - 51, - 15, - 3, - 23 - ], - [ - 4, - 9, - 51, - 26, - 41, - 28 - ], - [ - 60, - 4, - 19, - 47, - 40, - 48 - ], - [ - 19, - 3, - 26, - 21, - 53, - 8 - ], - [ - 53, - 1, - 35, - 54, - 60, - 20 - ] - ], - [ - [ - 37, - 46, - 39, - 54, - 27, - 55 - ], - [ - 34, - 5, - 16, - 47, - 6, - 42 - ], - [ - 32, - 38, - 16, - 42, - 3, - 20 - ], - [ - 43, - 19, - 60, - 42, - 52, - 11 - ], - [ - 7, - 61, - 39, - 57, - 12, - 46 - ], - [ - 50, - 36, - 44, - 24, - 28, - 31 - ], - [ - 15, - 42, - 41, - 2, - 40, - 32 - ], - [ - 36, - 35, - 38, - 
53, - 58, - 51 - ], - [ - 50, - 4, - 16, - 3, - 36, - 58 - ], - [ - 59, - 16, - 32, - 38, - 26, - 56 - ], - [ - 3, - 45, - 37, - 41, - 21, - 33 - ], - [ - 21, - 53, - 34, - 35, - 5, - 39 - ], - [ - 35, - 42, - 63, - 16, - 11, - 46 - ], - [ - 23, - 9, - 51, - 54, - 15, - 38 - ], - [ - 2, - 5, - 20, - 43, - 24, - 29 - ], - [ - 53, - 44, - 9, - 2, - 54, - 20 - ], - [ - 27, - 46, - 1, - 47, - 50, - 5 - ], - [ - 61, - 54, - 57, - 44, - 51, - 43 - ], - [ - 19, - 9, - 33, - 38, - 61, - 37 - ], - [ - 0, - 33, - 15, - 32, - 6, - 9 - ], - [ - 26, - 27, - 53, - 5, - 47, - 54 - ], - [ - 55, - 34, - 12, - 62, - 3, - 4 - ], - [ - 41, - 5, - 46, - 40, - 4, - 32 - ], - [ - 55, - 20, - 44, - 26, - 4, - 40 - ], - [ - 62, - 12, - 28, - 34, - 23, - 33 - ], - [ - 18, - 7, - 22, - 3, - 54, - 14 - ], - [ - 36, - 9, - 27, - 52, - 48, - 11 - ] - ], - [ - [ - 46, - 37, - 61, - 18, - 36, - 63 - ], - [ - 22, - 34, - 28, - 59, - 24, - 56 - ], - [ - 32, - 15, - 17, - 60, - 38, - 20 - ], - [ - 28, - 4, - 58, - 16, - 30, - 35 - ], - [ - 7, - 36, - 9, - 57, - 33, - 23 - ], - [ - 43, - 63, - 2, - 30, - 11, - 19 - ], - [ - 19, - 42, - 57, - 15, - 3, - 22 - ], - [ - 35, - 46, - 47, - 6, - 58, - 8 - ], - [ - 32, - 4, - 37, - 36, - 35, - 57 - ], - [ - 21, - 47, - 55, - 58, - 38, - 54 - ], - [ - 3, - 42, - 2, - 38, - 5, - 17 - ], - [ - 43, - 18, - 2, - 21, - 5, - 10 - ], - [ - 10, - 49, - 24, - 32, - 25, - 4 - ], - [ - 30, - 21, - 63, - 49, - 16, - 51 - ], - [ - 62, - 2, - 13, - 5, - 29, - 40 - ], - [ - 53, - 2, - 9, - 63, - 46, - 54 - ], - [ - 27, - 0, - 63, - 47, - 5, - 14 - ], - [ - 60, - 57, - 43, - 44, - 0, - 50 - ], - [ - 38, - 9, - 61, - 33, - 60, - 13 - ], - [ - 25, - 0, - 15, - 34, - 35, - 11 - ], - [ - 26, - 54, - 28, - 47, - 53, - 37 - ], - [ - 55, - 1, - 34, - 4, - 21, - 22 - ], - [ - 41, - 5, - 40, - 21, - 17, - 23 - ], - [ - 25, - 43, - 52, - 26, - 4, - 55 - ], - [ - 52, - 29, - 35, - 17, - 45, - 60 - ], - [ - 38, - 10, - 15, - 7, - 50, - 3 - ], - [ - 23, - 6, - 19, - 56, - 41, - 15 - ] - ], - [ - [ - 44, - 14, - 20, - 47, - 19, - 56 - ], - [ - 28, - 34, - 2, - 56, - 0, - 11 - ], - [ - 46, - 15, - 61, - 14, - 22, - 60 - ], - [ - 15, - 50, - 36, - 47, - 25, - 21 - ], - [ - 25, - 12, - 13, - 36, - 23, - 57 - ], - [ - 22, - 43, - 1, - 37, - 36, - 30 - ], - [ - 28, - 42, - 19, - 31, - 14, - 21 - ], - [ - 35, - 51, - 47, - 33, - 7, - 46 - ], - [ - 32, - 4, - 16, - 11, - 19, - 35 - ], - [ - 21, - 38, - 16, - 47, - 62, - 15 - ], - [ - 3, - 45, - 39, - 50, - 21, - 2 - ], - [ - 21, - 5, - 43, - 56, - 53, - 51 - ], - [ - 10, - 41, - 33, - 11, - 63, - 37 - ], - [ - 56, - 16, - 9, - 33, - 63, - 54 - ], - [ - 58, - 52, - 2, - 5, - 30, - 56 - ], - [ - 53, - 59, - 8, - 55, - 2, - 15 - ], - [ - 0, - 47, - 51, - 45, - 14, - 37 - ], - [ - 0, - 43, - 21, - 12, - 60, - 53 - ], - [ - 38, - 53, - 60, - 34, - 9, - 36 - ], - [ - 27, - 9, - 56, - 53, - 0, - 11 - ], - [ - 28, - 12, - 43, - 54, - 5, - 62 - ], - [ - 50, - 55, - 34, - 16, - 4, - 21 - ], - [ - 4, - 21, - 5, - 40, - 32, - 54 - ], - [ - 25, - 57, - 49, - 1, - 44, - 43 - ], - [ - 5, - 35, - 42, - 25, - 2, - 22 - ], - [ - 15, - 28, - 38, - 55, - 35, - 37 - ], - [ - 19, - 30, - 6, - 33, - 57, - 39 - ] - ], - [ - [ - 11, - 31, - 46, - 49, - 0, - 16 - ], - [ - 13, - 49, - 50, - 16, - 31, - 19 - ], - [ - 36, - 13, - 27, - 34, - 4, - 53 - ], - [ - 24, - 32, - 7, - 59, - 13, - 15 - ], - [ - 17, - 35, - 2, - 44, - 10, - 63 - ], - [ - 8, - 7, - 23, - 26, - 56, - 42 - ], - [ - 58, - 61, - 29, - 38, - 62, - 50 - ], - [ - 20, - 42, - 3, - 35, - 61, - 47 - ], - [ - 4, - 10, - 7, - 47, - 25, - 
34 - ], - [ - 54, - 55, - 18, - 38, - 27, - 28 - ], - [ - 60, - 1, - 14, - 59, - 3, - 30 - ], - [ - 51, - 31, - 5, - 21, - 25, - 52 - ], - [ - 34, - 10, - 2, - 26, - 11, - 47 - ], - [ - 16, - 40, - 23, - 34, - 9, - 33 - ], - [ - 34, - 8, - 1, - 31, - 5, - 32 - ], - [ - 25, - 59, - 8, - 52, - 53, - 27 - ], - [ - 13, - 47, - 48, - 45, - 42, - 0 - ], - [ - 4, - 23, - 53, - 10, - 3, - 25 - ], - [ - 53, - 8, - 38, - 11, - 24, - 63 - ], - [ - 38, - 15, - 0, - 11, - 21, - 8 - ], - [ - 27, - 23, - 62, - 8, - 41, - 42 - ], - [ - 55, - 11, - 10, - 41, - 16, - 34 - ], - [ - 31, - 21, - 54, - 34, - 41, - 46 - ], - [ - 55, - 25, - 22, - 33, - 44, - 47 - ], - [ - 47, - 45, - 35, - 20, - 12, - 56 - ], - [ - 50, - 10, - 47, - 35, - 57, - 53 - ], - [ - 17, - 31, - 5, - 37, - 36, - 20 - ] - ], - [ - [ - 22, - 6, - 29, - 39, - 57, - 44 - ], - [ - 27, - 6, - 14, - 17, - 55, - 57 - ], - [ - 1, - 11, - 29, - 26, - 47, - 4 - ], - [ - 14, - 38, - 31, - 22, - 36, - 29 - ], - [ - 14, - 61, - 59, - 16, - 44, - 63 - ], - [ - 30, - 8, - 47, - 23, - 21, - 0 - ], - [ - 58, - 4, - 54, - 61, - 15, - 62 - ], - [ - 20, - 42, - 58, - 3, - 35, - 0 - ], - [ - 10, - 47, - 4, - 8, - 41, - 19 - ], - [ - 54, - 55, - 33, - 11, - 38, - 21 - ], - [ - 14, - 1, - 4, - 60, - 20, - 40 - ], - [ - 51, - 5, - 14, - 0, - 32, - 62 - ], - [ - 2, - 10, - 36, - 49, - 32, - 52 - ], - [ - 13, - 40, - 54, - 36, - 22, - 11 - ], - [ - 2, - 44, - 37, - 58, - 5, - 8 - ], - [ - 6, - 24, - 2, - 53, - 19, - 52 - ], - [ - 13, - 47, - 61, - 5, - 50, - 30 - ], - [ - 55, - 58, - 4, - 63, - 22, - 33 - ], - [ - 35, - 46, - 32, - 14, - 7, - 58 - ], - [ - 15, - 13, - 45, - 0, - 51, - 60 - ], - [ - 15, - 27, - 23, - 50, - 6, - 51 - ], - [ - 57, - 36, - 10, - 55, - 14, - 16 - ], - [ - 10, - 23, - 33, - 41, - 30, - 54 - ], - [ - 33, - 26, - 4, - 10, - 49, - 55 - ], - [ - 4, - 8, - 33, - 9, - 20, - 48 - ], - [ - 43, - 6, - 3, - 50, - 18, - 38 - ], - [ - 4, - 55, - 36, - 43, - 46, - 5 - ] - ], - [ - [ - 48, - 38, - 63, - 47, - 7, - 42 - ], - [ - 3, - 10, - 26, - 6, - 2, - 38 - ], - [ - 39, - 44, - 6, - 7, - 45, - 8 - ], - [ - 60, - 31, - 22, - 0, - 54, - 27 - ], - [ - 6, - 59, - 16, - 14, - 37, - 44 - ], - [ - 59, - 8, - 30, - 47, - 48, - 3 - ], - [ - 44, - 54, - 4, - 56, - 15, - 13 - ], - [ - 12, - 24, - 20, - 58, - 31, - 61 - ], - [ - 47, - 33, - 10, - 8, - 4, - 36 - ], - [ - 54, - 40, - 55, - 33, - 11, - 51 - ], - [ - 8, - 14, - 23, - 29, - 22, - 31 - ], - [ - 34, - 59, - 14, - 5, - 57, - 50 - ], - [ - 44, - 36, - 10, - 59, - 52, - 9 - ], - [ - 59, - 44, - 13, - 9, - 54, - 1 - ], - [ - 24, - 2, - 38, - 46, - 44, - 47 - ], - [ - 34, - 53, - 6, - 2, - 57, - 40 - ], - [ - 17, - 29, - 47, - 5, - 50, - 34 - ], - [ - 48, - 63, - 57, - 38, - 60, - 32 - ], - [ - 42, - 7, - 49, - 32, - 46, - 1 - ], - [ - 1, - 59, - 23, - 14, - 16, - 10 - ], - [ - 21, - 35, - 7, - 53, - 6, - 17 - ], - [ - 8, - 24, - 60, - 14, - 13, - 55 - ], - [ - 33, - 5, - 51, - 15, - 23, - 3 - ], - [ - 9, - 4, - 51, - 26, - 41, - 8 - ], - [ - 4, - 60, - 19, - 47, - 40, - 9 - ], - [ - 19, - 3, - 26, - 21, - 17, - 8 - ], - [ - 1, - 35, - 53, - 54, - 60, - 20 - ] - ], - [ - [ - 12, - 41, - 14, - 62, - 24, - 10 - ], - [ - 10, - 53, - 39, - 35, - 41, - 58 - ], - [ - 33, - 32, - 50, - 31, - 3, - 34 - ], - [ - 43, - 10, - 42, - 11, - 17, - 47 - ], - [ - 42, - 12, - 11, - 19, - 58, - 54 - ], - [ - 36, - 50, - 55, - 61, - 25, - 56 - ], - [ - 41, - 22, - 16, - 52, - 2, - 15 - ], - [ - 26, - 36, - 62, - 53, - 15, - 51 - ], - [ - 58, - 16, - 5, - 53, - 3, - 49 - ], - [ - 32, - 46, - 26, - 45, - 16, - 62 - ], - [ - 
41, - 45, - 56, - 49, - 11, - 3 - ], - [ - 52, - 34, - 35, - 50, - 21, - 53 - ], - [ - 59, - 53, - 46, - 30, - 39, - 37 - ], - [ - 20, - 9, - 52, - 2, - 7, - 33 - ], - [ - 20, - 50, - 24, - 29, - 23, - 2 - ], - [ - 53, - 2, - 44, - 41, - 9, - 13 - ], - [ - 47, - 27, - 1, - 5, - 45, - 46 - ], - [ - 61, - 7, - 51, - 30, - 35, - 9 - ], - [ - 19, - 34, - 32, - 17, - 2, - 14 - ], - [ - 15, - 32, - 6, - 45, - 9, - 11 - ], - [ - 35, - 27, - 6, - 53, - 13, - 60 - ], - [ - 26, - 62, - 16, - 28, - 41, - 3 - ], - [ - 5, - 20, - 46, - 37, - 11, - 55 - ], - [ - 55, - 47, - 4, - 16, - 14, - 27 - ], - [ - 36, - 11, - 27, - 62, - 33, - 7 - ], - [ - 33, - 7, - 54, - 3, - 32, - 12 - ], - [ - 58, - 52, - 27, - 26, - 48, - 38 - ] - ], - [ - [ - 6, - 52, - 19, - 63, - 46, - 38 - ], - [ - 8, - 42, - 4, - 47, - 57, - 56 - ], - [ - 31, - 46, - 32, - 4, - 14, - 10 - ], - [ - 28, - 27, - 4, - 37, - 58, - 20 - ], - [ - 57, - 59, - 60, - 62, - 22, - 14 - ], - [ - 30, - 2, - 9, - 57, - 11, - 13 - ], - [ - 20, - 19, - 57, - 42, - 51, - 27 - ], - [ - 46, - 35, - 47, - 32, - 7, - 0 - ], - [ - 12, - 4, - 35, - 10, - 50, - 47 - ], - [ - 15, - 47, - 54, - 25, - 38, - 51 - ], - [ - 18, - 42, - 5, - 15, - 38, - 61 - ], - [ - 18, - 22, - 8, - 16, - 5, - 7 - ], - [ - 5, - 3, - 24, - 35, - 4, - 30 - ], - [ - 17, - 13, - 20, - 55, - 36, - 22 - ], - [ - 22, - 10, - 21, - 54, - 6, - 47 - ], - [ - 30, - 53, - 6, - 19, - 2, - 54 - ], - [ - 22, - 26, - 5, - 7, - 47, - 21 - ], - [ - 4, - 41, - 13, - 46, - 55, - 43 - ], - [ - 17, - 56, - 32, - 45, - 14, - 6 - ], - [ - 5, - 49, - 53, - 28, - 34, - 60 - ], - [ - 42, - 55, - 57, - 17, - 28, - 22 - ], - [ - 40, - 23, - 28, - 57, - 21, - 16 - ], - [ - 51, - 21, - 35, - 24, - 44, - 10 - ], - [ - 38, - 25, - 4, - 14, - 62, - 31 - ], - [ - 31, - 50, - 13, - 56, - 39, - 8 - ], - [ - 6, - 51, - 55, - 8, - 0, - 21 - ], - [ - 29, - 46, - 18, - 55, - 37, - 50 - ] - ], - [ - [ - 31, - 46, - 49, - 59, - 35, - 14 - ], - [ - 13, - 16, - 31, - 50, - 33, - 19 - ], - [ - 13, - 36, - 27, - 52, - 3, - 19 - ], - [ - 9, - 24, - 29, - 12, - 55, - 61 - ], - [ - 35, - 17, - 2, - 10, - 41, - 48 - ], - [ - 51, - 23, - 30, - 57, - 52, - 17 - ], - [ - 34, - 48, - 0, - 38, - 27, - 51 - ], - [ - 47, - 32, - 2, - 35, - 58, - 55 - ], - [ - 15, - 12, - 35, - 45, - 52, - 4 - ], - [ - 36, - 15, - 4, - 59, - 11, - 54 - ], - [ - 60, - 5, - 59, - 50, - 29, - 18 - ], - [ - 18, - 8, - 31, - 23, - 24, - 3 - ], - [ - 34, - 4, - 5, - 58, - 14, - 30 - ], - [ - 17, - 57, - 36, - 55, - 51, - 6 - ], - [ - 10, - 45, - 8, - 54, - 22, - 47 - ], - [ - 30, - 50, - 14, - 37, - 6, - 16 - ], - [ - 26, - 0, - 22, - 5, - 13, - 4 - ], - [ - 4, - 46, - 23, - 43, - 60, - 34 - ], - [ - 45, - 8, - 56, - 62, - 17, - 51 - ], - [ - 4, - 38, - 53, - 15, - 58, - 5 - ], - [ - 57, - 24, - 34, - 15, - 55, - 42 - ], - [ - 23, - 37, - 40, - 41, - 54, - 11 - ], - [ - 48, - 29, - 51, - 31, - 34, - 9 - ], - [ - 62, - 14, - 18, - 4, - 31, - 59 - ], - [ - 50, - 31, - 49, - 45, - 56, - 57 - ], - [ - 51, - 50, - 56, - 8, - 12, - 52 - ], - [ - 17, - 37, - 5, - 39, - 44, - 20 - ] - ], - [ - [ - 45, - 13, - 63, - 37, - 38, - 56 - ], - [ - 63, - 6, - 12, - 18, - 27, - 51 - ], - [ - 3, - 21, - 4, - 48, - 17, - 27 - ], - [ - 14, - 55, - 9, - 37, - 29, - 26 - ], - [ - 35, - 57, - 2, - 13, - 41, - 10 - ], - [ - 39, - 30, - 57, - 13, - 53, - 23 - ], - [ - 34, - 48, - 38, - 27, - 56, - 51 - ], - [ - 47, - 35, - 46, - 2, - 32, - 7 - ], - [ - 12, - 56, - 4, - 35, - 52, - 50 - ], - [ - 15, - 36, - 14, - 30, - 47, - 48 - ], - [ - 5, - 60, - 59, - 22, - 15, - 18 - ], - [ - 8, - 22, 
- 23, - 47, - 5, - 26 - ], - [ - 4, - 5, - 30, - 58, - 26, - 8 - ], - [ - 17, - 29, - 13, - 50, - 14, - 34 - ], - [ - 54, - 10, - 41, - 14, - 21, - 5 - ], - [ - 50, - 6, - 54, - 30, - 16, - 56 - ], - [ - 0, - 47, - 7, - 4, - 5, - 34 - ], - [ - 4, - 46, - 23, - 41, - 59, - 55 - ], - [ - 62, - 45, - 1, - 57, - 47, - 32 - ], - [ - 4, - 53, - 15, - 50, - 60, - 1 - ], - [ - 34, - 57, - 24, - 31, - 15, - 53 - ], - [ - 23, - 41, - 45, - 12, - 57, - 14 - ], - [ - 51, - 29, - 49, - 48, - 18, - 52 - ], - [ - 14, - 4, - 62, - 18, - 26, - 28 - ], - [ - 50, - 49, - 7, - 24, - 9, - 3 - ], - [ - 56, - 20, - 51, - 3, - 26, - 14 - ], - [ - 55, - 37, - 50, - 14, - 42, - 20 - ] - ], - [ - [ - 51, - 43, - 27, - 30, - 5, - 55 - ], - [ - 24, - 16, - 48, - 15, - 7, - 30 - ], - [ - 26, - 21, - 50, - 52, - 4, - 56 - ], - [ - 19, - 17, - 2, - 14, - 57, - 22 - ], - [ - 3, - 35, - 37, - 45, - 1, - 6 - ], - [ - 13, - 30, - 23, - 39, - 57, - 51 - ], - [ - 9, - 34, - 48, - 17, - 27, - 26 - ], - [ - 47, - 32, - 59, - 2, - 28, - 57 - ], - [ - 18, - 12, - 31, - 20, - 52, - 36 - ], - [ - 14, - 15, - 36, - 30, - 58, - 31 - ], - [ - 5, - 4, - 26, - 19, - 22, - 59 - ], - [ - 22, - 8, - 34, - 26, - 42, - 52 - ], - [ - 20, - 5, - 52, - 4, - 30, - 33 - ], - [ - 34, - 17, - 2, - 5, - 39, - 20 - ], - [ - 54, - 61, - 30, - 14, - 20, - 25 - ], - [ - 56, - 41, - 2, - 5, - 54, - 44 - ], - [ - 12, - 34, - 47, - 7, - 54, - 59 - ], - [ - 9, - 11, - 2, - 43, - 33, - 50 - ], - [ - 8, - 10, - 19, - 20, - 32, - 18 - ], - [ - 17, - 38, - 6, - 29, - 49, - 41 - ], - [ - 38, - 45, - 52, - 57, - 63, - 27 - ], - [ - 26, - 28, - 3, - 40, - 5, - 47 - ], - [ - 60, - 49, - 35, - 34, - 38, - 55 - ], - [ - 61, - 24, - 4, - 55, - 45, - 16 - ], - [ - 46, - 11, - 27, - 0, - 56, - 48 - ], - [ - 60, - 45, - 44, - 25, - 32, - 3 - ], - [ - 11, - 38, - 52, - 48, - 9, - 21 - ] - ], - [ - [ - 22, - 19, - 46, - 31, - 3, - 23 - ], - [ - 32, - 62, - 15, - 54, - 10, - 55 - ], - [ - 47, - 30, - 38, - 5, - 7, - 60 - ], - [ - 15, - 13, - 1, - 8, - 25, - 2 - ], - [ - 13, - 59, - 5, - 6, - 62, - 52 - ], - [ - 27, - 63, - 62, - 45, - 12, - 56 - ], - [ - 50, - 9, - 8, - 51, - 48, - 18 - ], - [ - 59, - 57, - 28, - 2, - 61, - 6 - ], - [ - 18, - 59, - 6, - 52, - 39, - 57 - ], - [ - 14, - 23, - 11, - 36, - 15, - 32 - ], - [ - 26, - 5, - 42, - 25, - 22, - 23 - ], - [ - 33, - 22, - 55, - 28, - 16, - 24 - ], - [ - 6, - 20, - 33, - 14, - 52, - 15 - ], - [ - 17, - 6, - 5, - 39, - 2, - 34 - ], - [ - 54, - 62, + 15, + 30, 25, - 61, - 21, - 14 - ], - [ - 41, - 45, - 14, - 5, - 2, - 54 - ], - [ - 47, - 34, - 10, - 31, - 5, - 41 - ], - [ - 9, - 2, - 36, - 6, - 43, 38 ], [ - 57, - 20, - 41, - 10, - 32, - 29 - ], - [ - 47, - 6, - 49, - 15, - 13, - 34 - ], - [ - 3, - 57, - 44, - 38, - 50, - 53 - ], - [ - 47, - 57, - 28, - 26, - 19, - 22 - ], - [ - 11, - 1, - 18, - 5, - 46, - 53 - ], - [ - 9, - 55, 4, - 3, - 25, - 10 - ], - [ - 45, - 11, - 0, - 63, - 48, - 57 - ], - [ - 42, - 10, - 43, - 3, - 49, - 39 - ], - [ - 38, - 15, - 17, - 10, - 41, - 8 - ] - ], - [ - [ 19, - 18, - 51, - 25, - 60, - 55 - ], - [ - 56, - 27, - 61, - 42, - 55, - 23 + 24, + 35, + 31, + 48 ], [ - 32, - 39, - 37, + 7, 46, - 20, - 52 - ], - [ - 41, - 21, - 37, - 13, - 57, - 2 - ], - [ - 10, - 9, 3, - 46, 58, - 32 - ], - [ - 62, - 27, - 57, - 43, - 2, - 53 - ], - [ - 50, - 21, - 19, - 48, - 15, - 6 - ], - [ - 14, - 21, - 54, - 8, - 57, - 28 + 30, + 41 ], [ - 6, - 18, - 4, - 52, + 58, + 9, 39, - 24 - ], - [ 32, - 38, - 23, - 53, - 25, - 17 - ], - [ - 26, - 22, - 25, - 5, - 42, - 33 - ], - [ - 5, - 22, - 16, - 9, - 61, - 55 + 
29, + 40 ], [ - 49, - 33, - 30, - 25, + 40, + 37, 20, - 0 - ], - [ - 18, - 17, + 8, 25, - 63, - 39, - 11 - ], - [ - 62, - 54, - 10, - 5, - 58, - 37 - ], - [ - 42, - 41, - 5, - 24, - 37, - 54 - ], - [ - 47, - 34, - 27, - 10, - 22, 55 ], [ - 60, - 43, - 44, - 7, - 52, - 37 - ], - [ - 12, - 32, - 39, - 38, - 1, - 20 + 19, + 0, + 54, + 52, + 17, + 39 ], [ - 34, - 52, - 49, - 15, - 63, - 28 + 25, + 43, + 12, + 61, + 14, + 11 ], [ + 23, + 4, + 54, + 36, 28, - 39, - 3, - 26, - 22, - 30 + 33 ], [ - 29, - 1, 40, - 22, - 19, - 63 + 2, + 25, + 58, + 36, + 53 ], [ - 41, - 5, + 18, + 46, + 35, + 22, 53, - 33, - 26, - 39 + 16 ], [ - 25, - 38, - 4, - 34, - 49, - 51 + 2, + 6, + 63, + 14, + 42, + 11 ], [ - 29, + 35, + 7, 52, - 48, - 47, - 20, - 35 + 40, + 29, + 57 ], [ - 50, - 3, - 38, - 16, - 53, - 15 + 40, + 15, + 19, + 57, + 17, + 23 ], [ - 19, - 62, - 6, - 23, - 10, - 36 + 9, + 11, + 47, + 22, + 49, + 1 + ], + [ + 24, + 39, + 42, + 2, + 16, + 0 ] ], [ [ - 5, - 14, - 17, - 57, - 10, - 27 + 18, + 8, + 20, + 49, + 30, + 23 ], [ - 43, - 9, - 56, 1, - 14, - 33 + 27, + 26, + 22, + 59, + 36 ], [ - 63, - 35, 43, - 1, - 10, - 27 - ], - [ - 51, - 50, - 57, - 41, - 20, - 54 + 26, + 15, + 58, + 0, + 46 ], [ - 11, - 43, - 40, - 54, - 30, - 33 + 55, + 1, + 35, + 28, + 16, + 32 ], [ - 27, + 59, + 9, + 10, 53, - 63, - 30, - 15, + 12, 58 ], [ - 51, - 48, - 21, - 57, 9, - 50 + 2, + 27, + 11, + 61, + 43 ], [ - 21, - 14, - 59, - 52, - 28, - 8 + 16, + 57, + 63, + 23, + 19, + 12 ], [ - 6, - 18, + 46, + 45, + 26, 4, - 0, - 12, - 59 + 30, + 37 ], [ + 43, + 44, + 20, + 16, 14, - 53, - 30, - 4, - 42, - 36 + 9 ], [ - 26, + 34, + 47, 42, - 5, - 22, - 58, - 19 + 43, + 26, + 51 ], [ - 22, - 5, - 34, + 42, + 2, 38, - 52, - 29 - ], - [ - 49, - 33, - 0, + 45, 20, - 26, - 8 + 36 ], [ - 17, + 18, + 7, + 12, 2, - 39, - 13, - 44, - 63 + 43, + 60 ], [ - 10, - 54, - 20, - 37, - 5, - 59 + 1, + 28, + 12, + 3, + 29, + 33 + ], + [ + 25, + 13, + 0, + 63, + 2, + 62 + ], + [ + 18, + 36, + 6, + 29, + 19, + 15 ], [ + 1, + 42, + 63, 41, - 43, - 24, - 5, - 53, - 37 + 57, + 19 ], [ - 12, - 47, - 34, 57, + 54, + 5, 27, - 5 + 18, + 31 ], [ - 9, - 33, - 7, - 38, - 43, - 31 + 50, + 6, + 13, + 32, + 17, + 20 ], [ - 19, - 10, - 20, + 17, + 5, + 27, 32, - 18, - 1 + 1, + 55 ], [ - 36, - 63, - 15, 49, - 40, - 50 + 0, + 61, + 5, + 10, + 30 ], [ - 39, - 3, - 38, - 27, - 36, - 35 + 29, + 53, + 51, + 13, + 33, + 46 ], [ 29, - 19, - 40, - 47, - 62, - 14 + 17, + 21, + 30, + 14, + 40 ], [ 5, - 46, + 17, 33, - 53, - 49, - 21 + 32, + 18, + 28 ], [ - 55, - 4, - 27, - 61, 51, - 34 + 4, + 20, + 54, + 58, + 41 ], [ - 11, - 0, - 59, + 47, + 4, + 27, 48, - 14, - 27 + 37, + 60 ], [ - 7, - 32, 3, - 51, - 18, - 14 + 26, + 12, + 59, + 2, + 48 ], [ - 11, + 46, + 43, + 18, + 20, 9, - 36, - 48, - 0, - 46 + 53 ] ], [ [ - 63, - 62, - 60, + 57, + 9, 19, - 23, - 56 + 51, + 18, + 41 ], [ + 28, + 57, + 36, + 8, 48, - 32, - 1, - 35, - 5, - 21 + 60 ], [ - 22, - 24, - 46, - 58, + 2, + 51, 59, - 60 + 5, + 34, + 9 ], [ - 27, - 37, - 50, - 28, - 61, - 6 + 9, + 55, + 59, + 26, + 4, + 2 ], [ - 10, - 12, - 15, - 58, + 49, + 56, 35, + 42, + 30, 23 ], [ - 2, - 43, - 36, - 57, + 18, 30, - 20 - ], - [ - 21, - 14, + 22, + 29, 19, - 41, - 63, - 42 - ], - [ - 46, - 7, - 35, - 43, - 21, - 36 + 52 ], [ - 6, - 18, - 0, - 35, - 40, - 4 + 39, + 34, + 33, + 51, + 56, + 3 ], [ - 47, - 53, 32, - 38, + 21, 1, - 58 - ], - [ - 26, - 18, - 22, - 15, - 5, - 46 + 7, + 46, + 49 ], [ - 7, - 22, - 24, - 63, - 53, - 5 + 33, + 54, + 23, + 21, + 12, + 11 ], [ - 11, - 3, - 10, - 18, - 37, - 24 + 5, + 30, + 60, + 47, + 15, + 
18 ], [ - 63, + 4, 18, - 13, - 39, - 17, - 0 + 46, + 27, + 20, + 22 ], [ - 62, + 22, + 59, 54, - 5, - 18, - 6, - 14 + 48, + 19, + 4 ], [ - 19, - 43, + 17, + 5, + 56, 31, - 30, - 57, - 42 + 49, + 4 ], [ - 24, - 5, - 51, - 34, + 29, 47, - 55 + 55, + 2, + 53, + 60 ], [ - 59, - 13, - 33, - 44, + 8, + 22, 11, - 41 + 44, + 36, + 15 ], [ - 58, - 32, - 1, - 7, - 39, - 38 + 60, + 44, + 30, + 57, + 54, + 39 ], [ - 49, - 13, - 28, - 9, - 34, - 7 + 7, + 44, + 27, + 20, + 2, + 61 ], [ + 48, + 17, + 21, + 37, 32, - 51, - 3, - 53, + 57 + ], + [ 48, - 13 + 32, + 46, + 6, + 61, + 42 ], [ - 40, - 15, - 37, + 4, 57, 1, - 23 + 36, + 0, + 30 ], [ - 5, + 7, + 17, + 61, + 53, + 21, + 63 + ], + [ + 60, + 14, 53, 35, + 18, + 55 + ], + [ + 10, + 15, + 33, 51, - 22, - 38 + 36, + 5 ], [ - 4, - 44, - 14, - 41, 11, - 38 + 4, + 19, + 51, + 21, + 52 ], [ 47, + 19, + 43, 48, - 8, - 61, - 13, - 9 + 58, + 4 ], [ - 23, 3, - 27, - 43, - 0, - 61 + 33, + 26, + 21, + 52, + 19 ], [ - 46, - 23, - 59, - 62, - 18, - 43 + 24, + 45, + 60, + 35, + 49, + 1 ] ], [ [ - 62, - 0, - 9, - 61, + 23, 54, - 32 + 53, + 58, + 11, + 8 ], [ - 45, + 11, + 30, + 15, + 59, + 63, + 55 + ], + [ + 20, + 58, 29, - 7, - 35, - 22, - 62 + 17, + 42, + 30 ], [ - 56, - 31, - 23, - 53, - 28, - 2 + 18, + 1, + 43, + 15, + 8, + 3 ], [ - 36, - 2, - 5, - 4, - 48, - 41 + 59, + 55, + 13, + 28, + 26, + 63 ], [ - 18, - 0, - 15, - 23, - 16, - 11 + 33, + 45, + 27, + 53, + 63, + 19 ], [ - 4, - 36, + 28, 57, - 56, - 14, - 5 + 51, + 54, + 34, + 53 ], [ - 59, - 56, - 21, - 8, - 7, - 33 + 24, + 13, + 3, + 47, + 45, + 50 ], [ - 12, - 23, - 7, - 2, - 28, - 34 + 47, + 59, + 33, + 57, + 37, + 35 ], [ - 17, - 6, 23, - 19, - 62, - 27 - ], - [ - 53, - 32, - 51, - 38, + 42, + 33, 41, - 58 + 48, + 30 ], [ - 18, + 42, 63, - 19, - 22, - 26, - 5 + 23, + 25, + 17, + 34 ], [ - 22, - 7, - 60, - 11, - 12, - 19 + 14, + 62, + 2, + 19, + 45, + 43 ], [ - 24, - 7, - 11, - 17, + 12, + 6, + 20, 30, - 37 + 29, + 17 ], [ - 47, - 13, - 63, - 50, - 39, - 45 + 6, + 3, + 17, + 8, + 27, + 31 ], [ - 62, + 19, + 59, + 2, + 10, 54, - 18, - 14, - 29, - 56 + 55 ], [ - 56, - 30, - 0, - 55, - 51, - 10 + 45, + 61, + 19, + 57, + 42, + 17 ], [ - 12, - 5, - 21, + 55, + 60, + 41, 34, - 63, - 29 + 35, + 53 ], [ - 42, + 36, + 2, 8, + 31, + 14, + 5 + ], + [ + 36, + 45, + 16, 38, + 51, + 21 + ], + [ 57, - 41, - 44 + 39, + 48, + 6, + 19, + 30 ], [ - 32, - 54, - 1, + 39, 37, - 49, + 59, + 48, + 42, 40 ], [ - 5, - 57, - 49, + 52, 28, - 34, - 10 - ], - [ + 33, 46, - 32, - 13, - 7, - 61, - 48 + 18, + 55 ], [ - 40, - 3, - 30, - 60, - 39, - 1 + 1, + 15, + 28, + 13, + 52, + 18 ], [ - 5, - 44, - 33, - 36, - 28, + 10, + 21, + 4, + 2, + 23, 31 ], [ - 4, - 11, - 28, + 45, + 9, 41, - 51, - 5 + 18, + 54, + 3 ], [ - 48, - 47, + 42, 61, - 28, - 60, + 36, + 3, + 19, 27 ], [ - 14, + 41, + 46, + 33, 3, - 26, - 12, - 53, - 61 - ], - [ - 54, - 60, - 49, - 35, 10, - 62 + 15 ] ], [ [ - 27, - 62, - 63, - 23, - 47, - 56 + 19, + 1, + 31, + 52, + 49, + 63 ], [ 7, - 4, - 2, - 10, - 35, - 36 + 47, + 5, + 60, + 22, + 46 ], [ + 59, + 30, 3, + 11, 0, - 27, - 62, - 50, - 60 + 19 ], [ - 36, - 3, + 43, 42, - 18, - 2, - 48 + 19, + 62, + 8, + 56 ], [ 61, - 12, - 27, - 10, 15, - 14 + 25, + 18, + 39, + 27 ], [ + 61, + 50, 36, + 45, 33, - 50, - 6, - 29, - 16 + 44 ], [ - 2, - 41, 8, - 43, - 40, - 59 + 37, + 52, + 1, + 2, + 41 ], [ 36, + 30, 53, - 15, - 7, - 37, - 2 + 11, + 16, + 29 ], [ - 3, + 14, 58, - 4, - 19, - 5, - 27 + 46, + 49, + 3, + 26 ], [ - 22, - 26, - 53, + 62, + 43, + 0, 45, - 25, - 38 + 22, + 46 ], [ + 6, + 56, + 45, 18, - 41, - 22, - 3, - 34, - 45 + 
10, + 41 ], [ + 60, 21, - 53, - 5, - 22, - 28, - 9 + 50, + 47, + 30, + 35 ], [ 11, - 24, + 53, 28, - 16, - 8, - 35 + 56, + 41, + 39 ], [ - 38, + 23, 9, - 21, - 52, - 41, - 39 + 33, + 28, + 22, + 26 ], [ + 23, + 56, + 34, 27, - 54, - 14, - 5, - 8, + 2, 63 ], [ - 56, - 55, - 43, - 30, - 35, - 51 + 22, + 9, + 44, + 41, + 37, + 47 ], [ - 51, + 1, + 11, + 46, + 27, 3, - 14, - 60, - 5, - 26 + 52 ], [ - 62, - 30, - 8, + 51, + 37, + 17, 21, - 44, - 35 + 61, + 30 ], [ - 38, - 34, - 27, + 13, + 19, 32, - 22, - 8 + 5, + 2, + 9 ], [ - 22, - 53, - 28, - 60, - 12, - 49 + 41, + 32, + 6, + 47, + 29, + 56 ], [ - 49, - 32, - 41, 53, - 12, - 33 + 33, + 9, + 35, + 38, + 12 ], [ 40, - 45, - 22, - 44, - 1, - 34 + 19, + 51, + 7, + 26, + 22 ], [ + 37, 5, - 4, - 53, - 25, - 41, - 63 - ], - [ - 4, - 11, 25, - 43, - 31, - 14 - ], - [ - 31, - 52, - 5, - 58, - 48, - 43 - ], - [ - 23, - 1, - 28, - 3, - 0, - 10 + 46, + 34, + 53 ], [ - 27, - 45, - 32, - 4, - 30, - 6 - ] - ], - [ - [ - 41, - 2, - 42, 16, - 50, - 32 - ], - [ - 51, - 5, - 41, - 40, + 55, + 20, + 24, 44, - 21 + 53 ], [ - 43, - 1, - 29, - 55, 21, - 35 - ], - [ + 46, + 11, + 0, 36, - 58, - 25, - 3, - 18, - 54 + 7 ], [ - 31, + 18, 60, - 24, - 12, - 61, - 41 + 32, + 3, + 34, + 28 ], [ - 4, - 10, - 35, + 52, + 9, 36, - 0, - 43 + 48, + 11, + 41 + ] + ], + [ + [ + 17, + 37, + 31, + 32, + 63, + 50 ], [ - 45, - 43, - 63, - 35, - 36, - 48 + 12, + 2, + 9, + 32, + 47, + 17 ], [ 3, - 9, - 8, - 7, - 43, - 27 + 57, + 56, + 50, + 33, + 38 ], [ - 0, 43, - 4, - 40, - 18, - 44 + 42, + 19, + 52, + 8, + 17 ], [ - 34, + 61, + 39, + 27, + 12, 15, - 50, - 38, - 53, - 25 + 57 + ], + [ + 6, + 33, + 36, + 44, + 29, + 61 ], [ - 24, - 46, - 26, - 13, 2, - 18 + 41, + 42, + 15, + 52, + 5 ], [ - 57, + 36, 37, + 15, + 53, + 18, + 62 + ], + [ + 50, + 58, + 3, 5, + 16, + 4 + ], + [ + 16, 22, - 7, - 53 + 55, + 32, + 41, + 26 ], [ 3, - 9, + 41, + 56, + 45, 38, - 26, - 37, - 23 - ], - [ - 60, - 26, - 0, - 39, - 13, - 55 + 34 ], [ - 54, 21, - 58, - 6, - 55, - 14 + 53, + 50, + 34, + 38, + 35 ], [ + 35, 11, - 19, - 43, - 57, - 56, - 37 + 37, + 16, + 53, + 42 ], [ - 10, + 9, + 15, + 54, + 38, + 12, + 52 + ], + [ + 56, + 2, 24, - 60, - 21, - 5, - 58 + 23, + 9, + 5 ], [ - 31, - 20, - 63, + 2, 44, - 55, - 33 + 9, + 53, + 35, + 54 ], [ - 7, - 25, - 40, - 32, - 54, - 58 + 27, + 46, + 1, + 5, + 60, + 50 ], [ - 49, - 30, + 51, + 61, 50, - 13, - 28, - 55 + 54, + 33, + 44 ], [ - 30, - 3, - 51, - 61, - 0, - 16 + 2, + 38, + 9, + 19, + 5, + 32 ], [ - 57, - 46, - 43, - 40, - 30, - 22 + 32, + 28, + 6, + 15, + 0, + 33 ], [ - 52, - 33, - 51, - 19, + 26, + 53, 27, - 29 + 13, + 5, + 4 ], [ - 14, - 4, + 34, 40, + 45, + 55, + 62, + 53 + ], + [ 41, - 0, - 34 + 5, + 46, + 4, + 45, + 3 ], [ + 20, 4, - 13, - 8, - 61, + 22, 47, - 9 + 59, + 58 ], [ - 31, - 27, - 43, - 1, + 62, + 36, + 25, + 28, + 53, + 33 + ], + [ + 22, + 33, + 7, 3, + 54, 0 ], [ - 8, - 4, - 46, - 9, - 3, - 50 + 58, + 27, + 52, + 48, + 45, + 17 ] ] + ], + "routing_indices_token_indices": [ + 3, + 13, + 14, + 17, + 28, + 31, + 35, + 81, + 86, + 94 ] } } \ No newline at end of file diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml index 2898070f957..377371b2370 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml @@ -25,7 
diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml
index 2898070f957..377371b2370 100644
--- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml
+++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml
@@ -25,7 +25,7 @@ MODEL_ARGS:
   --load: ${CHECKPOINT_LOAD_PATH}
   --split: 949,50,1
   --tokenizer-type: NullTokenizer
-  --vocab-size: 8192
+  --vocab-size: 8193
   --distributed-backend: nccl
   --lr: 0.00015
   --lr-decay-style: cosine
diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml
index 23bdaac5010..9332c934613 100644
--- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml
+++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml
@@ -27,7 +27,7 @@ MODEL_ARGS:
   --load: ${CHECKPOINT_LOAD_PATH}
   --split: 949,50,1
   --tokenizer-type: NullTokenizer
-  --vocab-size: 8192
+  --vocab-size: 8193
   --distributed-backend: nccl
   --lr: 0.00015
   --lr-decay-style: cosine
diff --git a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_g200/model_config.yaml b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_g200/model_config.yaml
index 1147dda6118..9c5f1807c2d 100644
--- a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_g200/model_config.yaml
+++ b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_g200/model_config.yaml
@@ -42,7 +42,7 @@ MODEL_ARGS:
 
   # Network size args
   --use-mcore-models: true
-  --spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec
+  --spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec
   --is-hybrid-model: true
   --mamba-num-heads: 128
   --num-layers: 88
@@ -90,7 +90,7 @@ MODEL_ARGS:
   --moe-shared-expert-compute-before-router: true
 
   # MTP args
-  --mtp-spec: megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec
+  --mtp-spec: megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec
   --mtp-num-layers: 2
   --mtp-hybrid-override-pattern: \"*E\"
   --calculate-per-token-loss: true
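The two spec values above are "module path + symbol" pairs that Megatron resolves dynamically at startup; the exact loader lives in Megatron's spec utilities and is not part of this diff, so the helper below is only a hypothetical sketch of the idea.

# Hypothetical illustration of resolving a "<module> <attribute>" spec string.
# resolve_spec is a made-up name; Megatron-LM has its own loader for this.
import importlib

def resolve_spec(spec: str):
    module_path, _, attr_name = spec.strip().partition(" ")
    module = importlib.import_module(module_path)  # e.g. megatron.core.models.hybrid.hybrid_layer_specs
    return getattr(module, attr_name)              # e.g. hybrid_stack_spec

# resolve_spec("megatron.core.models.hybrid.hybrid_layer_specs hybrid_stack_spec")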
diff --git a/tests/test_utils/python_scripts/generate_jet_trigger_job.py b/tests/test_utils/python_scripts/generate_jet_trigger_job.py
index 50d8598ae66..1df66175949 100644
--- a/tests/test_utils/python_scripts/generate_jet_trigger_job.py
+++ b/tests/test_utils/python_scripts/generate_jet_trigger_job.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import pathlib
 
 from typing import Optional
@@ -58,6 +60,16 @@
     type=bool,
     help="Run one job as dependency to others as to warm up cache",
 )
+@click.option(
+    "--cadence",
+    required=False,
+    type=str,
+    default=None,
+    help=(
+        "Trigger cadence to filter tests by (pr|nightly|mergegroup). "
+        "Empty/unset disables the cadence filter."
+    ),
+)
 def main(
     scope: str,
     environment: str,
@@ -78,7 +90,12 @@ def main(
     wandb_experiment: Optional[str] = None,
     enable_lightweight_mode: bool = False,
     enable_warmup: Optional[bool] = None,
+    cadence: Optional[str] = None,
 ):
+    # Treat empty string as "no cadence filter" so callers can wire shell
+    # variables in directly without conditional flag emission.
+    cadence_arg = cadence or None
+
     list_of_test_cases = [
         test_case
         for test_case in recipe_parser.load_workloads(
@@ -88,6 +105,7 @@
             test_cases=test_cases,
             platform=platform,
             tag=tag,
+            cadence=cadence_arg,
         )
         if test_case.type != "build"
     ]
diff --git a/tests/test_utils/python_scripts/launch_nemo_run_workload.py b/tests/test_utils/python_scripts/launch_nemo_run_workload.py
index 6569110ea02..0597a3189da 100644
--- a/tests/test_utils/python_scripts/launch_nemo_run_workload.py
+++ b/tests/test_utils/python_scripts/launch_nemo_run_workload.py
@@ -67,6 +67,16 @@ def is_flaky_failure(concat_allranks_logs: str) -> bool:
     default=False,
     help="To enable lightweight mode",
 )
+@click.option(
+    "--cadence",
+    required=False,
+    type=str,
+    default=None,
+    help=(
+        "Trigger cadence to filter tests by (pr|nightly|mergegroup). "
+        "Empty/unset disables the cadence filter."
+    ),
+)
 def main(
     scope,
     model,
@@ -79,7 +89,10 @@ def main(
     hf_home: Optional[str] = None,
     tag: Optional[str] = None,
     enable_lightweight_mode: Optional[bool] = False,
+    cadence: Optional[str] = None,
 ):
+    cadence_arg = cadence or None
+
     workloads = recipe_parser.load_workloads(
         container_image="none",
         scope=scope,
@@ -89,6 +102,7 @@ def main(
         container_tag="none",
         platform=platform,
         tag=tag,
+        cadence=cadence_arg,
     )
 
     workloads = [workload for workload in workloads if workload.type != "build"]
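Both launchers normalize the incoming flag with "cadence or None", so an empty CI variable interpolated into --cadence behaves exactly like omitting the flag. A tiny standalone restatement of that contract (normalize_cadence is a made-up name, not part of the change set):

# Illustration only: the "empty string disables the cadence filter" contract.
def normalize_cadence(cadence):
    # "" and None both mean "do not filter"; any other string passes through.
    return cadence or None

assert normalize_cadence("") is None
assert normalize_cadence(None) is None
assert normalize_cadence("nightly") == "nightly"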
diff --git a/tests/test_utils/python_scripts/recipe_parser.py b/tests/test_utils/python_scripts/recipe_parser.py
index a04340407e3..2e2e9f45d79 100644
--- a/tests/test_utils/python_scripts/recipe_parser.py
+++ b/tests/test_utils/python_scripts/recipe_parser.py
@@ -12,6 +12,22 @@
 logger = logging.getLogger(__name__)
 
+DEFAULT_CADENCE = ["pr", "nightly", "mergegroup"]
+ALLOWED_CADENCE_VALUES = set(DEFAULT_CADENCE)
+
+
+def _validate_cadence(cadence: List[str], test_case: str) -> None:
+    if not isinstance(cadence, list):
+        raise ValueError(
+            f"cadence for test_case {test_case} must be a list, got {type(cadence).__name__}"
+        )
+    invalid = [c for c in cadence if c not in ALLOWED_CADENCE_VALUES]
+    if invalid:
+        raise ValueError(
+            f"Invalid cadence value(s) {invalid} for test_case {test_case}. "
+            f"Allowed: {sorted(ALLOWED_CADENCE_VALUES)}"
+        )
+
 
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
 
@@ -47,14 +63,35 @@ def flatten_products(workload_manifest: dotdict) -> dotdict:
             continue
 
         test_case = product["test_case"][0]
 
+        # Outer-level cadence (next to test_case) acts as a default for every
+        # inner products block under this test_case. Inner cadence wins when
+        # both are present.
+        outer_cadence = product.get("cadence")
+        if outer_cadence is not None:
+            _validate_cadence(outer_cadence, test_case)
+
         for param_dict in product["products"]:
-            # Generate all combinations of parameter values
-            param_combinations = itertools.product(*param_dict.values())
+            # cadence is a list-valued attribute, not a cartesian dimension.
+            # Pull it out of the cartesian product before expansion so a list
+            # like ["pr", "nightly"] doesn't multiply the workload count.
+            inner_cadence = param_dict.get("cadence")
+            if inner_cadence is not None:
+                _validate_cadence(inner_cadence, test_case)
+            cartesian_keys = [k for k in param_dict.keys() if k != "cadence"]
+            cartesian_values = [param_dict[k] for k in cartesian_keys]
+
+            # Resolve effective cadence: inner overrides outer, default loose.
+            effective_cadence = inner_cadence if inner_cadence is not None else outer_cadence
+            if effective_cadence is None:
+                effective_cadence = list(DEFAULT_CADENCE)
+
+            param_combinations = itertools.product(*cartesian_values)
             for value_combination in param_combinations:
                 # Map parameter names to their values
-                flattened = dict(zip(param_dict.keys(), value_combination))
+                flattened = dict(zip(cartesian_keys, value_combination))
                 flattened["test_case"] = test_case
+                flattened["cadence"] = effective_cadence
                 flattened_products.append(flattened)
 
     workload_manifest.products = flattened_products
@@ -137,6 +174,31 @@ def filter_by_scope(workload_manifests: List[dotdict], scope: str) -> List[dotdict]:
     return workload_manifests
 
 
+def filter_by_cadence(workload_manifests: List[dotdict], cadence: Optional[str]) -> List[dotdict]:
+    """Returns workloads whose cadence list includes the requested cadence value.
+
+    A cadence of None disables the filter (used for the label-based bypass path).
+    Workloads missing a cadence field default to all triggers (loose default).
+    """
+    if cadence is None:
+        return workload_manifests
+
+    if cadence not in ALLOWED_CADENCE_VALUES:
+        raise ValueError(f"Invalid cadence {cadence!r}. Allowed: {sorted(ALLOWED_CADENCE_VALUES)}")
+
+    filtered = list(
+        workload_manifest
+        for workload_manifest in workload_manifests
+        if cadence in workload_manifest.spec.get("cadence", DEFAULT_CADENCE)
+    )
+
+    if len(filtered) == 0:
+        logger.info("No test_case found for cadence %s!", cadence)
+        return []
+
+    return filtered
+
+
 def filter_by_environment(workload_manifests: List[dotdict], environment: str) -> List[dotdict]:
 
     workload_manifests_copy = list(
@@ -232,6 +294,7 @@ def load_workloads(
     test_case: Optional[str] = None,
     container_image: Optional[str] = None,
     record_checkpoints: Optional[str] = None,
+    cadence: Optional[str] = None,
 ) -> List[dotdict]:
     """Return all workloads from disk that match scope and platform."""
     recipes_dir = BASE_PATH / ".." / "recipes"
@@ -246,6 +309,10 @@ def load_workloads(
 
     if scope:
         workloads = filter_by_scope(workload_manifests=workloads, scope=scope)
+
+    if workloads and cadence:
+        workloads = filter_by_cadence(workload_manifests=workloads, cadence=cadence)
+
     if workloads and environment:
         workloads = filter_by_environment(workload_manifests=workloads, environment=environment)
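Taken together, flatten_products and filter_by_cadence give each expanded workload a cadence list (a products-level entry overrides the test_case-level default, and a missing field falls back to all three triggers), and the launcher then keeps only workloads whose list contains the requested trigger. A compact, self-contained restatement of that resolution order; resolve_cadence and keep_for_trigger are illustrative names that mirror the diff rather than import it:

# Standalone restatement of the cadence rules introduced above (illustration only).
DEFAULT_CADENCE = ["pr", "nightly", "mergegroup"]

def resolve_cadence(outer, inner):
    # Inner (per products block) cadence overrides the outer (per test_case) default;
    # when neither is given, the workload runs on every trigger.
    if inner is not None:
        return inner
    if outer is not None:
        return outer
    return list(DEFAULT_CADENCE)

def keep_for_trigger(workloads, cadence):
    # cadence=None (or an empty CLI value normalized to None) disables the filter.
    if cadence is None:
        return workloads
    return [w for w in workloads if cadence in w.get("cadence", DEFAULT_CADENCE)]

workloads = [
    {"test_case": "a", "cadence": resolve_cadence(["nightly"], None)},    # -> ["nightly"]
    {"test_case": "b", "cadence": resolve_cadence(["nightly"], ["pr"])},  # -> ["pr"]
    {"test_case": "c"},                                                   # missing -> all triggers
]
assert [w["test_case"] for w in keep_for_trigger(workloads, "pr")] == ["b", "c"]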
diff --git a/tests/test_utils/recipes/h100/flextron.yaml b/tests/test_utils/recipes/h100/flextron.yaml
new file mode 100644
index 00000000000..9bd40eaa493
--- /dev/null
+++ b/tests/test_utils/recipes/h100/flextron.yaml
@@ -0,0 +1,62 @@
+type: basic
+format_version: 1
+maintainers: [mcore]
+loggers: [stdout]
+spec:
+  name: "{test_case}_{environment}_{platforms}"
+  model: hybrid
+  build: mcore-pyt-{environment}
+  nodes: 1
+  gpus: 8
+  n_repeat: 1
+  platforms: dgx_h100
+  script_setup: |
+    unset https_proxy
+    echo "machine gitlab-master.nvidia.com login okoenig password $RO_API_TOKEN" | tee -a /root/.netrc
+
+    # Checkout latest
+    cd /opt
+    rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm
+    git init
+    git remote add origin $MCORE_REPO
+    git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*'
+    git fetch origin $MCORE_MR_COMMIT
+    git checkout $MCORE_MR_COMMIT
+    git rev-parse HEAD
+
+    # Checkout backwards-ref
+    cd /opt
+    rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy
+    git init
+    git remote add origin $MCORE_REPO
+    git fetch origin $MCORE_BACKWARDS_COMMIT
+    git checkout $MCORE_BACKWARDS_COMMIT
+    git rev-parse HEAD
+    rm -rf megatron; cp -a /opt/megatron-lm/megatron ./
+  script: |-
+    ls
+    cd /opt/megatron-lm
+
+    ARGUMENTS=(
+      "DATA_PATH=/mnt/artifacts"
+      "DATA_CACHE_PATH=/workspace/data/cache"
+      "OUTPUT_PATH={assets_dir}"
+      "TENSORBOARD_PATH={assets_dir}/tensorboard"
+      "CHECKPOINT_SAVE_PATH={artifacts_dir}/checkpoints"
+      "CHECKPOINT_LOAD_PATH=/mnt/artifacts/model/{name}"
+      "TRAINING_SCRIPT_PATH=megatron/elastification/pretrain_hybrid_flex.py"
+      "TRAINING_PARAMS_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/model_config.yaml"
+      "GOLDEN_VALUES_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/golden_values_{environment}_{platforms}.json"
+      "N_REPEAT={n_repeat}"
+      "ENABLE_LIGHTWEIGHT_MODE=${{ENABLE_LIGHTWEIGHT_MODE}}"
+      "RECORD_CHECKPOINTS=${{RECORD_CHECKPOINTS}}"
+    )
+
+    bash ./tests/functional_tests/shell_test_utils/run_ci_test.sh ${{ARGUMENTS[@]}}
+
+products:
+  - test_case: [hybrid_flextron_nightly_tp2_pp1_ep2_dgx_h100_1N8G]
+    products:
+      - environment: [dev]
+        scope: [nightly]
+        platforms: [dgx_h100]
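Each recipe's products block is expanded by flatten_products into one workload per combination of its list-valued attributes (with cadence excluded from the expansion, as noted earlier). For the single-entry lists above the expansion is trivial; the sketch below uses a made-up but representative param_dict with a two-value scope to show the mechanism:

# Sketch of the products expansion performed by flatten_products (illustrative only).
import itertools

param_dict = {"environment": ["dev"], "scope": ["mr", "mr-github"], "platforms": ["dgx_h100"]}

keys = list(param_dict.keys())
workloads = [dict(zip(keys, combo)) for combo in itertools.product(*param_dict.values())]
# -> [{'environment': 'dev', 'scope': 'mr',        'platforms': 'dgx_h100'},
#     {'environment': 'dev', 'scope': 'mr-github', 'platforms': 'dgx_h100'}]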
$MCORE_MR_COMMIT + git rev-parse HEAD + # Checkout backwards-ref + cd /opt + rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy + git init + git remote add origin $MCORE_REPO + git fetch origin $MCORE_BACKWARDS_COMMIT + git checkout $MCORE_BACKWARDS_COMMIT + git rev-parse HEAD + rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ + script: |- + ls + cd /opt/megatron-lm + + ARGUMENTS=( + "CHECKPOINT_LOAD_PATH=/mnt/artifacts" + "CHECKPOINT_SAVE_PATH=/tmp/checkpoints" + "DATA_PATH=null" + "DATA_CACHE_PATH=/workspace/data/cache" + "TRAINING_SCRIPT_PATH=examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py" + "TRAINING_PARAMS_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/model_config.yaml" + "GOLDEN_VALUES_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/golden_values_{environment}_{platforms}.json" + "OUTPUT_PATH={assets_dir}" + "TENSORBOARD_PATH={assets_dir}/tensorboard" + "INFERENCE_OUTPUT_PATH={assets_dir}/golden_values_{environment}_{platforms}.json" + "N_REPEAT={n_repeat}" + "ENABLE_LIGHTWEIGHT_MODE=${{ENABLE_LIGHTWEIGHT_MODE}}" + "RECORD_CHECKPOINTS=${{RECORD_CHECKPOINTS}}" + ) + + bash ./tests/functional_tests/shell_test_utils/run_ci_test.sh ${{ARGUMENTS[@]}} + +products: + - test_case: [hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] diff --git a/tests/test_utils/recipes/h100/mamba-dynamic-inference.yaml b/tests/test_utils/recipes/h100/mamba-dynamic-inference.yaml index 495a4d130b0..aa78cdb8316 100644 --- a/tests/test_utils/recipes/h100/mamba-dynamic-inference.yaml +++ b/tests/test_utils/recipes/h100/mamba-dynamic-inference.yaml @@ -65,3 +65,8 @@ products: - environment: [dev] scope: [mr, mr-github] platforms: [dgx_h100] + - test_case: [hybrid_dynamic_inference_tp1_pp1_dp8_583m_flashinfer] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] diff --git a/tests/test_utils/recipes/h100/mamba.yaml b/tests/test_utils/recipes/h100/mamba.yaml index 703fb53160f..72b44495617 100644 --- a/tests/test_utils/recipes/h100/mamba.yaml +++ b/tests/test_utils/recipes/h100/mamba.yaml @@ -44,7 +44,7 @@ spec: "TENSORBOARD_PATH={assets_dir}/tensorboard" "CHECKPOINT_SAVE_PATH={artifacts_dir}/checkpoints" "CHECKPOINT_LOAD_PATH=/mnt/artifacts/model/{name}" - "TRAINING_SCRIPT_PATH=pretrain_mamba.py" + "TRAINING_SCRIPT_PATH=pretrain_hybrid.py" "TRAINING_PARAMS_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/model_config.yaml" "GOLDEN_VALUES_PATH=./tests/functional_tests/test_cases/{model}/{test_case}/golden_values_{environment}_{platforms}.json" "N_REPEAT={n_repeat}" diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 10aec38b15c..ef3d87c7c6d 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -44,7 +44,7 @@ def cleanup(): yield if torch.distributed.is_initialized(): try: - torch.distributed.barrier(timeout=timedelta(seconds=300)) + torch.distributed.barrier() except Exception: return torch.distributed.destroy_process_group() diff --git a/tests/unit_tests/data/test_prepare_cache.py b/tests/unit_tests/data/test_prepare_cache.py new file mode 100644 index 00000000000..1e69bc48de7 --- /dev/null +++ b/tests/unit_tests/data/test_prepare_cache.py @@ -0,0 +1,288 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import os +import random +from argparse import Namespace + +import pytest +import torch + +from megatron.core.datasets.blended_dataset import BlendedDataset +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDataset +from megatron.core.datasets.indexed_dataset import DType, IndexedDatasetBuilder +from megatron.core.datasets.utils import compile_helpers +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer +from tests.unit_tests.dist_checkpointing import TempNamedDir +from tests.unit_tests.test_utilities import Utils +from tools.prepare_cache import ( + _normalize_prepare_cache_args, + build_dataset_caches, + core_gpt_dataset_config_from_args, +) + + +def _build_null_tokenizer(vocab_size: int = 2048): + return build_tokenizer( + Namespace( + vocab_size=vocab_size, + tokenizer_type="NullTokenizer", + padded_vocab_size=None, + rank=0, + make_vocab_size_divisible_by=128, + tensor_model_parallel_size=1, + ) + ) + + +def _initialize_test_environment() -> None: + if torch.distributed.is_available(): + Utils.initialize_distributed() + if torch.distributed.get_rank() == 0: + compile_helpers() + torch.distributed.barrier() + else: + compile_helpers() + + +def _create_file_prefixes(tokenizer, dataset_dir, number_of_files: int = 4) -> list[str]: + os.makedirs(dataset_dir, exist_ok=True) + + file_prefixes = [] + for i in range(number_of_files): + file_prefix = os.path.join(dataset_dir, f"file_{i}") + builder = IndexedDatasetBuilder( + file_prefix + ".bin", dtype=DType.optimal_dtype(tokenizer.vocab_size) + ) + + for j in range(32): + tokens = [int((i * 97 + j * 13 + k) % tokenizer.vocab_size) for k in range(64)] + builder.add_document(tokens, [len(tokens)]) + + builder.finalize(file_prefix + ".idx") + file_prefixes.append(file_prefix) + + return file_prefixes + + +def _create_shared_file_prefixes(tokenizer, dataset_dir, number_of_files: int = 4) -> list[str]: + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + file_prefixes = _create_file_prefixes(tokenizer, dataset_dir, number_of_files) + else: + file_prefixes = [os.path.join(dataset_dir, f"file_{i}") for i in range(number_of_files)] + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + random.seed(1234) # NOTE(asolergi-nv): re-sync random state across all ranks + + return file_prefixes + + +def _build_prepare_cache_args(file_prefixes, data_cache_path, **overrides): + args = dict( + seed=1234, + seq_length=16, + split="70,20,10", + data_path=file_prefixes, + train_data_path=None, + valid_data_path=None, + test_data_path=None, + per_split_data_args_path=None, + data_args_path=None, + per_dataset_sequences_path=None, + data_cache_path=str(data_cache_path), + mmap_bin_files=True, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + create_attention_mask_in_dataloader=False, + object_storage_cache_path=None, + mid_level_dataset_surplus=0.005, + allow_ambiguous_pad_tokens=False, + dataloader_fast_cache_load=True, + dataloader_defer_npy_index_mmap=True, + context_parallel_size=1, + data_parallel_size=4, + tensor_model_parallel_size=1, + sequence_parallel=False, + dynamic_context_parallel=False, + multiple_validation_sets=False, + full_validation=False, + num_dataset_builder_threads=1, + tokenizer_type="NullTokenizer", + vocab_size=2048, + padded_vocab_size=None, + make_vocab_size_divisible_by=128, + rank=0, + world_size=4, + train_samples=None, + 
train_iters=4, + skip_train=False, + eval_iters=2, + eval_interval=2, + start_eval_at_iter=None, + global_batch_size=8, + phase_transition_iterations=None, + iteration=0, + mock_data=False, + sft=False, + fim_data=False, + step_batch_size_schedule=None, + ) + args.update(overrides) + return Namespace(**args) + + +def test_prepare_cache_builds_blended_dataset_cache(tmp_path_dist_ckpt): + _initialize_test_environment() + + tokenizer = _build_null_tokenizer() + + with TempNamedDir( + tmp_path_dist_ckpt / "test_prepare_cache_builds_blended_dataset_cache", sync=True + ) as temp_dir: + file_prefixes = _create_shared_file_prefixes(tokenizer, os.path.join(temp_dir, "dataset")) + args = _build_prepare_cache_args(file_prefixes, temp_dir / "cache") + + summary = build_dataset_caches(args) + + assert args.dataloader_fast_cache_load is False + assert args.dataloader_defer_npy_index_mmap is False + assert summary["train_valid_test_num_samples"] == (32, 48, 16) + assert summary["train_dataset_length"] == 32 + assert summary["valid_dataset_length"] == 48 + assert summary["test_dataset_length"] == 16 + assert list((temp_dir / "cache").glob("*document_index.npy")) + assert list((temp_dir / "cache").glob("*dataset_index.npy")) + + +def test_prepare_cache_world_size_override(): + args = Namespace(rank=11, world_size=1, prepare_cache_world_size=8) + + _normalize_prepare_cache_args(args) + + assert args.rank == 0 + assert args.world_size == 8 + + +def test_prepare_cache_builds_and_hits_per_split_dataset_cache(tmp_path_dist_ckpt): + _initialize_test_environment() + + tokenizer = _build_null_tokenizer() + + with TempNamedDir( + tmp_path_dist_ckpt / "test_prepare_cache_builds_and_hits_per_split_dataset_cache", sync=True + ) as temp_dir: + file_prefixes = _create_shared_file_prefixes(tokenizer, os.path.join(temp_dir, "dataset")) + args = _build_prepare_cache_args( + None, + temp_dir / "cache", + split=None, + data_path=None, + train_data_path=[50, file_prefixes[0], 50, file_prefixes[1]], + valid_data_path=[file_prefixes[2]], + test_data_path=[file_prefixes[3]], + ) + + summary = build_dataset_caches(args) + + assert summary["train_valid_test_num_samples"] == (32, 48, 16) + assert list((temp_dir / "cache").glob("*description.txt")) + + slow_args = _build_prepare_cache_args( + None, + temp_dir / "cache", + split=None, + data_path=None, + train_data_path=[50, file_prefixes[0], 50, file_prefixes[1]], + valid_data_path=[file_prefixes[2]], + test_data_path=[file_prefixes[3]], + dataloader_fast_cache_load=False, + dataloader_defer_npy_index_mmap=False, + ) + slow_config = core_gpt_dataset_config_from_args(slow_args) + train_slow, valid_slow, test_slow = BlendedMegatronDatasetBuilder( + GPTDataset, list(summary["train_valid_test_num_samples"]), lambda: True, slow_config + ).build() + + fast_args = _build_prepare_cache_args( + None, + temp_dir / "cache", + split=None, + data_path=None, + train_data_path=[50, file_prefixes[0], 50, file_prefixes[1]], + valid_data_path=[file_prefixes[2]], + test_data_path=[file_prefixes[3]], + dataloader_fast_cache_load=True, + dataloader_defer_npy_index_mmap=True, + ) + fast_config = core_gpt_dataset_config_from_args(fast_args) + train_fast, valid_fast, test_fast = BlendedMegatronDatasetBuilder( + GPTDataset, list(summary["train_valid_test_num_samples"]), lambda: True, fast_config + ).build() + + assert isinstance(train_fast, BlendedDataset) + assert train_fast.dataset_index is None + assert train_fast.dataset_sample_index is None + assert isinstance(valid_fast, GPTDataset) + assert 
valid_fast.document_index is None + assert valid_fast.sample_index is None + assert valid_fast.shuffle_index is None + assert isinstance(test_fast, GPTDataset) + assert test_fast.document_index is None + assert test_fast.sample_index is None + assert test_fast.shuffle_index is None + + assert summary["train_dataset_length"] == len(train_slow) == len(train_fast) == 32 + assert summary["valid_dataset_length"] == len(valid_slow) == len(valid_fast) + assert summary["test_dataset_length"] == len(test_slow) == len(test_fast) + assert summary["valid_dataset_length"] >= summary["train_valid_test_num_samples"][1] + assert summary["test_dataset_length"] >= summary["train_valid_test_num_samples"][2] + assert torch.all(train_slow[0]["tokens"] == train_fast[0]["tokens"]) + assert torch.all(valid_slow[0]["tokens"] == valid_fast[0]["tokens"]) + assert torch.all(test_slow[0]["tokens"] == test_fast[0]["tokens"]) + + assert train_fast.dataset_index is not None + assert train_fast.dataset_sample_index is not None + assert valid_fast.document_index is not None + assert valid_fast.sample_index is not None + assert valid_fast.shuffle_index is not None + assert test_fast.document_index is not None + assert test_fast.sample_index is not None + assert test_fast.shuffle_index is not None + + +@pytest.mark.parametrize( + ("flag_name", "flag_value", "message"), + [ + ("mock_data", True, "--mock-data"), + ("sft", True, "--sft"), + ("fim_data", True, "--fim-data"), + ("step_batch_size_schedule", [(0, 8)], "--step-batch-size-schedule"), + ], +) +def test_prepare_cache_rejects_unsupported_modes(tmp_path, flag_name, flag_value, message): + args = _build_prepare_cache_args([], tmp_path / "cache", **{flag_name: flag_value}) + + with pytest.raises(ValueError, match=message): + build_dataset_caches(args) + + +def test_prepare_cache_builds_with_train_samples(tmp_path_dist_ckpt): + _initialize_test_environment() + + tokenizer = _build_null_tokenizer() + + with TempNamedDir( + tmp_path_dist_ckpt / "test_prepare_cache_builds_with_train_samples", sync=True + ) as temp_dir: + file_prefixes = _create_shared_file_prefixes(tokenizer, os.path.join(temp_dir, "dataset")) + args = _build_prepare_cache_args( + file_prefixes, temp_dir / "cache", train_iters=None, train_samples=32 + ) + + summary = build_dataset_caches(args) + + assert args.train_iters == 32 // args.global_batch_size + assert summary["train_valid_test_num_samples"][0] == 32 diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py index cbb0b3f79b7..1575e01e0d0 100644 --- a/tests/unit_tests/dist_checkpointing/test_async_save.py +++ b/tests/unit_tests/dist_checkpointing/test_async_save.py @@ -107,3 +107,13 @@ def test_get_async_strategy_no_nvrx_installed(self, async_strategy): assert strategy == "mcore" assert module == MCoreAsyncRequest + + def test_get_async_strategy_missing_nvrx_cached_metadata_reader(self): + with mock.patch.dict( + 'sys.modules', + { + 'nvidia_resiliency_ext.checkpointing.async_ckpt.cached_metadata_filesystem_reader': None + }, + ): + with pytest.raises(ModuleNotFoundError): + get_async_strategy("nvrx", module="CachedMetadataFileSystemReader") diff --git a/tests/unit_tests/dist_checkpointing/test_integrity.py b/tests/unit_tests/dist_checkpointing/test_integrity.py new file mode 100644 index 00000000000..e87af62af93 --- /dev/null +++ b/tests/unit_tests/dist_checkpointing/test_integrity.py @@ -0,0 +1,106 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +import json +import os +from pathlib import Path + +import pytest +import torch + +from megatron.core.dist_checkpointing import ShardedTensor, load, save +from megatron.core.dist_checkpointing.core import CheckpointingException +from megatron.core.dist_checkpointing.validation import ( + save_integrity_manifest, + verify_integrity_manifest, +) +from tests.unit_tests.dist_checkpointing import TempNamedDir +from tests.unit_tests.test_utilities import Utils + + +@pytest.fixture +def init_model_parallel(): + """Init torch distributed.""" + Utils.initialize_model_parallel(1, 1) + yield # Run the actual test. + Utils.destroy_model_parallel() + + +class TestIntegrity: + def test_save_verify_integrity_manifest_with_ckpt(self, tmp_path_dist_ckpt): + Utils.initialize_model_parallel(1, 1) + state_dict = { + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(1, 1), replica_id=Utils.rank + ), + 'rank': 0, + } + load_state_dict = { + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.empty(1, 1), replica_id=Utils.rank + ) + } + + with TempNamedDir( + tmp_path_dist_ckpt / 'test_save_integrity_manifest', sync=True + ) as ckpt_dir: + save(state_dict, ckpt_dir, verify_integrity=True) + + integrity_file = Path(ckpt_dir / "integrity.json") + assert integrity_file.is_file(), "integrity.json doesn't exist." + + with open(integrity_file, "r") as f: + data = json.load(f) + files = list(data["files"].keys()) + + assert "__0_0.distcp" in files + assert len(data["files"]["common.pt"]) == 64 + + loaded_state_dict = load(load_state_dict, ckpt_dir, verify_integrity=True) + + Utils.destroy_model_parallel() + + def test_save_verify_integrity_manifest_directly(self, init_model_parallel, tmp_path_dist_ckpt): + with TempNamedDir( + tmp_path_dist_ckpt / 'test_save_integrity_manifest_directly', sync=True + ) as ckpt_dir: + metadata_file = Path(ckpt_dir / "metadata.json") + with open(metadata_file, "w") as f: + data = {"test_metadata": 1} + json.dump(data, f) + + if torch.distributed.get_rank() == 0: + save_integrity_manifest(ckpt_dir) + torch.distributed.barrier() + integrity_file = Path(ckpt_dir / "integrity.json") + assert integrity_file.is_file(), "integrity.json doesn't exist." 
+ + with open(integrity_file, "r") as f: + data = json.load(f) + files = list(data["files"].keys()) + + assert len(files) == 1 + assert len(data["files"]["metadata.json"]) == 64 + + verify_integrity_manifest(ckpt_dir) + + def test_save_verify_integrity_manifest_error(self, init_model_parallel, tmp_path_dist_ckpt): + with TempNamedDir( + tmp_path_dist_ckpt / 'test_save_integrity_manifest_error', sync=True + ) as ckpt_dir: + metadata_file = Path(ckpt_dir / "metadata.json") + + with open(metadata_file, "w") as f: + data = {"test_metadata": 1} + json.dump(data, f) + + if torch.distributed.get_rank() == 0: + save_integrity_manifest(ckpt_dir) + torch.distributed.barrier() + + with open(metadata_file, "w") as f: + data = {"test_metadata": 11} + json.dump(data, f) + + # CheckpointingException, hash mismatch + with pytest.raises(CheckpointingException): + verify_integrity_manifest(ckpt_dir) diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py index 1db844dd9a8..149323707de 100644 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py @@ -589,7 +589,7 @@ def test_bucket_space_optimizer_save_load( ) as ckpt_dir_B, ): # Init model and optimizer with "src" bucket padding - with patch('megatron.core.distributed.param_and_grad_buffer.math.lcm') as lcm_mock: + with patch('megatron.core.optimizer.param_layout.math.lcm') as lcm_mock: lcm_mock.return_value = src_bucket_pad_divisor model_A, optimizer_A = setup_model_and_optimizer( @@ -615,7 +615,7 @@ def test_bucket_space_optimizer_save_load( parallel_state.get_model_parallel_group() ) # Init model and optimizer with "dest" bucket padding - with patch('megatron.core.distributed.param_and_grad_buffer.math.lcm') as lcm_mock: + with patch('megatron.core.optimizer.param_layout.math.lcm') as lcm_mock: lcm_mock.return_value = dest_bucket_pad_divisor model_B, optimizer_B = setup_model_and_optimizer( diff --git a/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py b/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py index 5f9c617893c..3e493539b1b 100644 --- a/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py +++ b/tests/unit_tests/dist_checkpointing/test_pipeline_parallel_layout.py @@ -140,6 +140,7 @@ def create_args(): args.vocab_file = None args.add_position_embedding = False args.ckpt_assume_constant_structure = True + args.ckpt_load_validate_sharding_integrity = True args.dist_ckpt_strictness = "assume_ok_unexpected" args.fp16 = False args.bf16 = True @@ -152,6 +153,7 @@ def create_args(): args.dist_ckpt_optim_fully_reshardable = False args.distrib_optim_fully_reshardable_mem_efficient = False args.phase_transition_iterations = None + args.verify_integrity = False yield args diff --git a/tests/unit_tests/dist_checkpointing/test_rerun_state_machine_ckpt.py b/tests/unit_tests/dist_checkpointing/test_rerun_state_machine_ckpt.py new file mode 100644 index 00000000000..07a571e7d1b --- /dev/null +++ b/tests/unit_tests/dist_checkpointing/test_rerun_state_machine_ckpt.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +import torch + +from megatron.core.dist_checkpointing.mapping import ShardedObject +from megatron.core.rerun_state_machine import RerunMode, RerunState, RerunStateMachine +from tests.unit_tests.test_utilities import Utils + + +class TestRerunStateMachineCheckpointContract: + """Unit tests for ``RerunStateMachine.state_dict()``. 
+ + These are fast, single-process tests that directly exercise the + ``state_dict()`` contract. They do not require a full save/load round + trip because the bug lives entirely in the structure returned by + ``state_dict()`` on the *first* save, which is what gets cached by + ``TorchDistSaveShardedStrategy`` when + ``--ckpt-assume-constant-structure`` is set. + """ + + def setup_method(self, method): + Utils.initialize_distributed() + + def teardown_method(self, method): + pass + + def _assert_sharded_object_shape(self, sh_obj): + """The ShardedObject must be uniquely keyed per rank with a shape + equal to the world size; this is what lets torch DCP reconstruct + the global object on load.""" + assert isinstance(sh_obj, ShardedObject) + assert sh_obj.key == "rerun_state_machine_state" + assert sh_obj.global_shape == (torch.distributed.get_world_size(),) + assert sh_obj.global_offset == (torch.distributed.get_rank(),) + + def test_steady_state_emits_sharded_object(self): + """Regression test for issue #4378. + + In the steady state (no pending rerun), ``state_dict()`` used to + return ``None`` which left the cached SavePlan with no entry for + ``rerun_state_machine_state``. After the fix, ``state_dict()`` must + emit a ShardedObject sentinel so the plan always includes the + rerun shard from the first save onwards. + """ + machine = RerunStateMachine(mode=RerunMode.VALIDATE_RESULTS) + assert machine.state == RerunState.NOT_RUNNING_YET + + sd = machine.state_dict(data_iterator=None, ckpt_format="torch_dist") + + assert sd is not None, ( + "state_dict() must not return None in the steady state when rerun is" + " enabled; otherwise --ckpt-assume-constant-structure caches a plan" + " that is missing the rerun_state_machine_state shard and the fault" + " save silently drops it (issue #4378)." + ) + assert set(sd.keys()) >= {"mode", "state", "current_iteration", "sharded"} + assert sd["state"] == RerunState.NOT_RUNNING_YET + self._assert_sharded_object_shape(sd["sharded"]) + + def test_fault_state_emits_sharded_object(self): + """When a fault is in flight, ``state_dict()`` continues to emit + the same ShardedObject structure, now carrying the real fault + payload instead of the sentinel values.""" + machine = RerunStateMachine(mode=RerunMode.VALIDATE_RESULTS) + machine.state = RerunState.WILL_RERUN_FROM_CHECKPOINT + machine.rerun_requested = True + machine.checkpoint_requested = True + + sd = machine.state_dict(data_iterator=None, ckpt_format="torch_dist") + + assert sd is not None + assert sd["state"] == RerunState.WILL_RERUN_FROM_CHECKPOINT + self._assert_sharded_object_shape(sd["sharded"]) + + def test_structure_constant_across_rerun_transition(self): + """This is the core invariant the Option-2 fix establishes: the + ShardedObject's key / global_shape / global_offset are identical + across the steady-state save and a subsequent fault save, so the + cached SavePlan built on the first save remains valid when a fault + triggers the second save.""" + machine = RerunStateMachine(mode=RerunMode.VALIDATE_RESULTS) + + # Save #1: steady state (mirrors every normal checkpoint during a + # healthy run). + sd_steady = machine.state_dict(data_iterator=None, ckpt_format="torch_dist") + + # Simulate the transition performed by should_run_forward_backward + # when a mismatching rerun result demands a fault checkpoint. 
+ machine.state = RerunState.WILL_RERUN_FROM_CHECKPOINT + machine.rerun_requested = True + machine.checkpoint_requested = True + + # Save #2: fault save, performed in the same process with the same + # (already-cached) TorchDistSaveShardedStrategy. + sd_fault = machine.state_dict(data_iterator=None, ckpt_format="torch_dist") + + steady_obj = sd_steady["sharded"] + fault_obj = sd_fault["sharded"] + assert steady_obj.key == fault_obj.key + assert steady_obj.global_shape == fault_obj.global_shape + assert steady_obj.global_offset == fault_obj.global_offset + # Sanity: the payloads do differ between the two saves (sentinel vs + # real fault context). torch DCP caches structure, not contents, so + # this is fine. + assert steady_obj.data["rerun_requested"] is False + assert fault_obj.data["rerun_requested"] is True + + def test_steady_state_does_not_require_wrapped_data_iterator(self): + """In the steady state we skip ``_sanitize_data_iterators`` so the + caller isn't forced to wrap its training iterator in + ``RerunDataIterator`` just to satisfy checkpointing. The + requirement to wrap only kicks in once a rerun is pending.""" + machine = RerunStateMachine(mode=RerunMode.VALIDATE_RESULTS) + + # An unwrapped iterator would assert inside + # _sanitize_data_iterators. In steady state it must be accepted. + sd = machine.state_dict(data_iterator=iter([1, 2, 3]), ckpt_format="torch_dist") + + assert sd is not None + assert sd["sharded"].data["data_iterator_checkpoints"] is None + + def test_disabled_mode_returns_none(self): + """When the rerun state machine is disabled, ``state_dict()`` + returns ``None`` so disabled jobs don't pay any checkpoint + overhead.""" + machine = RerunStateMachine(mode=RerunMode.DISABLED) + + sd = machine.state_dict(data_iterator=None, ckpt_format="torch_dist") + + assert sd is None + + def test_non_torch_dist_format_returns_none(self): + """``ShardedObject`` is only supported by the ``torch_dist`` + format; for other formats ``state_dict()`` returns ``None`` + regardless of machine state.""" + machine = RerunStateMachine(mode=RerunMode.VALIDATE_RESULTS) + assert machine.state_dict(data_iterator=None, ckpt_format="torch") is None + + machine.state = RerunState.WILL_RERUN_FROM_CHECKPOINT + assert machine.state_dict(data_iterator=None, ckpt_format="torch") is None + + def test_force_overrides_short_circuits(self): + """``force=True`` is used on the load path to build a template + that mirrors whatever the checkpoint happens to contain. It + bypasses both the DISABLED short-circuit and the ckpt_format + short-circuit, matching the pre-existing load-side contract in + ``megatron/training/checkpointing.py``.""" + machine = RerunStateMachine(mode=RerunMode.DISABLED) + + sd = machine.state_dict(data_iterator=None, ckpt_format="torch_dist", force=True) + assert sd is not None + self._assert_sharded_object_shape(sd["sharded"]) + + # For legacy formats used on the load path (e.g. fsdp_dtensor), + # force=True still produces the template. + sd = machine.state_dict(data_iterator=None, ckpt_format="fsdp_dtensor", force=True) + assert sd is not None + self._assert_sharded_object_shape(sd["sharded"]) diff --git a/tests/unit_tests/dist_checkpointing/test_safe_globals.py b/tests/unit_tests/dist_checkpointing/test_safe_globals.py index dc09b2a292e..6034648b600 100755 --- a/tests/unit_tests/dist_checkpointing/test_safe_globals.py +++ b/tests/unit_tests/dist_checkpointing/test_safe_globals.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-import os +import io +import pickle from argparse import Namespace from collections import OrderedDict from pickle import UnpicklingError @@ -8,6 +9,7 @@ import pytest import torch +from megatron.core.safe_globals import SafeUnpickler from megatron.core.utils import is_torch_min_version @@ -48,3 +50,16 @@ def test_unsafe_globals(self, tmp_path_dist_ckpt): # add class to safe globals torch.serialization.add_safe_globals([UnsafeClass]) torch.load(ckpt_path) + + +class TestSafeUnpickler: + def test_safe_types(self): + data = {"key": [1, 2.0, True, "s"], "od": OrderedDict(a=1)} + raw = pickle.dumps(data) + result = SafeUnpickler(io.BytesIO(raw)).load() + assert result == data + + def test_unsafe_types(self): + raw = pickle.dumps(UnsafeClass(123)) + with pytest.raises(pickle.UnpicklingError, match="Refusing to unpickle"): + SafeUnpickler(io.BytesIO(raw)).load() diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index 0aadaee3b29..8a9df54ddc8 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -150,6 +150,7 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False): args.no_save_optim = False args.no_save_rng = False args.ckpt_assume_constant_structure = False + args.ckpt_load_validate_sharding_integrity = True args.log_progress = False args.auto_detect_ckpt_format = False args.exit_on_missing_checkpoint = False diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index 223383bba64..ac1fdfe2ed6 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -11,10 +11,39 @@ from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from megatron.core.transformer import TransformerConfig from tests.unit_tests.test_utilities import TestModel, Utils +class TestModelWithExperts(torch.nn.Module): + """Model with both dense and expert-parallel parameters. + + Dense layers have the default allreduce=True. Expert layers have + allreduce=False on their parameters, which routes them to a separate + buffer with a different data-parallel group. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_dense_layers: int, + num_expert_layers: int, + bias: bool, + ): + super().__init__() + self.dense_layers = torch.nn.ModuleList( + [torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_dense_layers)] + ) + self.expert_layers = torch.nn.ModuleList( + [torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_expert_layers)] + ) + for layer in self.expert_layers: + for param in layer.parameters(): + param.allreduce = False + + def get_model_and_buffers( input_dim: int, output_dim: int, @@ -49,8 +78,18 @@ def get_model_and_buffers( # Wrap with DistributedDataParallel, and get underlying buffer. # Use dummy TransformerConfig with mostly default values. Avoid divide-by-zero # errors for num_attention_heads and num_layers. + # Pre-compute parameter layouts for the distributed optimizer. 
+ full_param_layout = None + if use_distributed_optimizer: + all_params = [p for p in model.parameters() if p.requires_grad] + full_param_layout = DistributedOptimizer.compute_full_param_layout( + all_params, bucket_size, parallel_state.get_data_parallel_world_size(), ddp_config + ) model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config=ddp_config, module=model + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config=ddp_config, + module=model, + full_param_layout=full_param_layout, ) assert len(model.buffers) == 1 param_and_grad_buffer = model.buffers[0] @@ -733,6 +772,14 @@ def mock_packed_shape(shape): average_in_collective=False, ) + # Pre-compute layout for distributed optimizer (with padding); + # otherwise use default (no padding). + param_layout = None + if use_distributed_optimizer: + param_layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size, dp_world_size, ddp_config, list(range(len(params))) + ) + with ( mock.patch( 'megatron.core.distributed.param_and_grad_buffer.is_nvfp4tensor', @@ -760,6 +807,7 @@ def mock_packed_shape(shape): param_indices=list(range(len(params))), nccl_ub=False, pg_collection=mock_pg, + param_layout=param_layout, ) return buffer, params @@ -884,3 +932,101 @@ def test_nvfp4_varied_param_sizes(self): assert buffer.param_index_map[params[1]] == (large_unpacked_start, large_unpacked_end, 0) assert buffer.param_index_map[params[0]] == (small_unpacked_start, small_unpacked_end, 0) + + +@pytest.mark.parametrize("use_distributed_optimizer", [False, True]) +def test_expert_parallel_params_get_separate_buffers(use_distributed_optimizer: bool): + """Verify that expert-parallel params (allreduce=False) land in separate buffers + with correctly scoped layouts and independent param_index_maps.""" + Utils.initialize_model_parallel() + + input_dim = 95 + output_dim = 95 + num_dense_layers = 3 + num_expert_layers = 2 + bucket_size = None # Single bucket per buffer. + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=True, + use_distributed_optimizer=use_distributed_optimizer, + overlap_grad_reduce=True, + bucket_size=bucket_size, + average_in_collective=False, + ) + model = TestModelWithExperts( + input_dim=input_dim, + output_dim=output_dim, + num_dense_layers=num_dense_layers, + num_expert_layers=num_expert_layers, + bias=True, + ).bfloat16() + + full_param_layout = None + if use_distributed_optimizer: + all_params = [p for p in model.parameters() if p.requires_grad] + full_param_layout = DistributedOptimizer.compute_full_param_layout( + all_params, bucket_size, parallel_state.get_data_parallel_world_size(), ddp_config + ) + + ddp_model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config=ddp_config, + module=model, + full_param_layout=full_param_layout, + ) + + # Should have exactly one dense buffer and one expert buffer. + assert len(ddp_model.buffers) == 1, f"Expected 1 dense buffer, got {len(ddp_model.buffers)}" + assert ( + len(ddp_model.expert_parallel_buffers) == 1 + ), f"Expected 1 expert buffer, got {len(ddp_model.expert_parallel_buffers)}" + + dense_buffer = ddp_model.buffers[0] + expert_buffer = ddp_model.expert_parallel_buffers[0] + + # Collect expected params for each buffer. 
+ expected_dense_params = set() + expected_expert_params = set() + for param in model.parameters(): + if not param.requires_grad: + continue + if getattr(param, 'allreduce', True): + expected_dense_params.add(param) + else: + expected_expert_params.add(param) + + # Verify each buffer contains exactly the right params. + dense_buffer_params = set() + for bucket in dense_buffer.buckets: + dense_buffer_params.update(bucket.params) + assert ( + dense_buffer_params == expected_dense_params + ), "Dense buffer should contain exactly the dense params" + + expert_buffer_params = set() + for bucket in expert_buffer.buckets: + expert_buffer_params.update(bucket.params) + assert ( + expert_buffer_params == expected_expert_params + ), "Expert buffer should contain exactly the expert-parallel params" + + # Verify param_index_maps are scoped to their own buffer (no cross-contamination). + assert set(dense_buffer.param_index_map.keys()) == expected_dense_params + assert set(expert_buffer.param_index_map.keys()) == expected_expert_params + + # Verify both buffers have indices starting from 0 (independent index spaces). + dense_starts = [s for s, _, _ in dense_buffer.param_index_map.values()] + expert_starts = [s for s, _, _ in expert_buffer.param_index_map.values()] + assert min(dense_starts) == 0, "Dense buffer indices should start at 0" + assert min(expert_starts) == 0, "Expert buffer indices should start at 0" + + # Verify DP divisibility for distributed optimizer. + if use_distributed_optimizer: + dp_world_size = parallel_state.get_data_parallel_world_size() + for buffer_name, buffer in [("dense", dense_buffer), ("expert", expert_buffer)]: + assert buffer.numel % dp_world_size == 0, ( + f"{buffer_name} buffer numel ({buffer.numel}) should be " + f"divisible by dp_world_size ({dp_world_size})" + ) + + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/distributed/test_param_layout.py b/tests/unit_tests/distributed/test_param_layout.py new file mode 100644 index 00000000000..99922e37970 --- /dev/null +++ b/tests/unit_tests/distributed/test_param_layout.py @@ -0,0 +1,478 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Tests for parameter layout computation functions. 
+ +These tests verify the pure-computation layout functions without requiring +GPU or distributed setup: +- pad_to_divisor, pad_param_start, pad_bucket_end (shared padding utilities) +- group_params_for_buffers (parameter grouping by dtype/expert) +- _compute_default_per_buffer_param_layout (no-padding layout) +- DistributedOptimizer._compute_per_buffer_param_layout (padded layout) +- DistributedOptimizer.compute_full_param_layout (end-to-end layout) +""" + +import math +from unittest import mock + +import pytest +import torch + +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed.param_and_grad_buffer import ( + _compute_default_per_buffer_param_layout, + group_params_for_buffers, +) +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from megatron.core.optimizer.param_layout import ( + BufferKey, + pad_bucket_end, + pad_param_start, + pad_to_divisor, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_params(*shapes, dtype=torch.bfloat16): + """Create a list of nn.Parameters with the given shapes.""" + return [torch.nn.Parameter(torch.randn(s, dtype=dtype)) for s in shapes] + + +def _make_param_with_attrs(shape, dtype=torch.bfloat16, **attrs): + """Create an nn.Parameter with extra attributes (e.g. allreduce, shared_embedding).""" + param = torch.nn.Parameter(torch.randn(shape, dtype=dtype)) + for k, v in attrs.items(): + setattr(param, k, v) + return param + + +# --------------------------------------------------------------------------- +# Tests for shared padding utilities +# --------------------------------------------------------------------------- + + +class TestPaddingUtilities: + + def test_pad_to_divisor_exact_multiple(self): + assert pad_to_divisor(128, 64) == 128 + + def test_pad_to_divisor_rounds_up(self): + assert pad_to_divisor(65, 64) == 128 + + def test_pad_to_divisor_zero(self): + assert pad_to_divisor(0, 64) == 0 + + def test_pad_param_start(self): + assert pad_param_start(0) == 0 + assert pad_param_start(1) == 64 + assert pad_param_start(63) == 64 + assert pad_param_start(64) == 64 + assert pad_param_start(65) == 128 + + def test_pad_bucket_end_basic(self): + dp_size = 4 + divisor = math.lcm(dp_size, 128) + result = pad_bucket_end(1, dp_size, pad_for_high_nccl_busbw=False) + assert result == divisor + assert result % dp_size == 0 + assert result % 128 == 0 + + def test_pad_bucket_end_high_busbw(self): + dp_size = 4 + divisor = math.lcm(dp_size, 128, 2**16) + result = pad_bucket_end(1, dp_size, pad_for_high_nccl_busbw=True) + assert result == divisor + assert result % (2**16) == 0 + + def test_pad_bucket_end_already_aligned(self): + dp_size = 2 + divisor = math.lcm(dp_size, 128) + result = pad_bucket_end(divisor, dp_size, pad_for_high_nccl_busbw=False) + assert result == divisor + + +# --------------------------------------------------------------------------- +# Tests for group_params_for_buffers +# --------------------------------------------------------------------------- + + +class TestGroupParamsForBuffers: + + def test_single_dtype_no_experts(self): + """All bf16 params with no expert-parallel should go in one group.""" + params = _make_params((100, 100), (50, 50)) + result = group_params_for_buffers(params, grad_reduce_in_fp32=True) + + assert len(result) == 1 + key = list(result.keys())[0] + assert key == BufferKey(torch.bfloat16, torch.float, False) + 
group_params, indices = result[key] + assert group_params == params + assert indices == [0, 1] + + def test_grad_reduce_not_fp32(self): + """When grad_reduce_in_fp32=False, grad_dtype matches param dtype.""" + params = _make_params((100,)) + result = group_params_for_buffers(params, grad_reduce_in_fp32=False) + + key = list(result.keys())[0] + assert key.grad_dtype == torch.bfloat16 + + def test_expert_parallel_separation(self): + """Params with allreduce=False should be in a separate group.""" + dense = _make_param_with_attrs((100,)) + expert = _make_param_with_attrs((100,), allreduce=False) + result = group_params_for_buffers([dense, expert], grad_reduce_in_fp32=True) + + assert len(result) == 2 + dense_key = BufferKey(torch.bfloat16, torch.float, False) + expert_key = BufferKey(torch.bfloat16, torch.float, True) + assert dense_key in result + assert expert_key in result + assert result[dense_key][0] == [dense] + assert result[expert_key][0] == [expert] + + def test_param_indices_independent_per_group(self): + """Expert and dense groups should have independent param_indices starting at 0.""" + dense_params = _make_params((100,), (200,)) + expert = _make_param_with_attrs((100,), allreduce=False) + result = group_params_for_buffers( + [dense_params[0], expert, dense_params[1]], grad_reduce_in_fp32=True + ) + + dense_key = BufferKey(torch.bfloat16, torch.float, False) + expert_key = BufferKey(torch.bfloat16, torch.float, True) + _, dense_indices = result[dense_key] + _, expert_indices = result[expert_key] + assert dense_indices == [0, 1] + assert expert_indices == [0] + + def test_mixed_dtypes(self): + """Params with different dtypes go in separate groups.""" + bf16_param = _make_params((100,), dtype=torch.bfloat16)[0] + fp32_param = _make_params((100,), dtype=torch.float32)[0] + result = group_params_for_buffers([bf16_param, fp32_param], grad_reduce_in_fp32=True) + + assert len(result) == 2 + bf16_key = BufferKey(torch.bfloat16, torch.float, False) + fp32_key = BufferKey(torch.float32, torch.float, False) + assert bf16_key in result + assert fp32_key in result + + +# --------------------------------------------------------------------------- +# Tests for _compute_default_per_buffer_param_layout +# --------------------------------------------------------------------------- + + +class TestDefaultParamLayout: + + def test_single_bucket_no_padding(self): + """With bucket_size=None, all params go in one bucket with no padding.""" + params = _make_params((100, 100), (50, 50)) + layout = _compute_default_per_buffer_param_layout(params, bucket_size=None) + + # Params iterated in reverse: params[1] first (2500 elems), params[0] second (10000). + assert layout.param_index_map[params[1]] == (0, 2500, 0) + assert layout.param_index_map[params[0]] == (2500, 12500, 0) + assert layout.bucket_indices == [(0, 12500)] + assert layout.per_bucket_numel_unpadded == [12500] + + def test_multiple_buckets(self): + """Params should split into buckets when exceeding bucket_size.""" + params = _make_params((100, 100), (100, 100), (100, 100)) + layout = _compute_default_per_buffer_param_layout(params, bucket_size=15000) + + # Reverse order: params[2], params[1], params[0]. Each is 10000 elems. + # After params[2]: 10000 < 15000, continue. + # After params[1]: 20000 >= 15000, finalize bucket. + # After params[0]: 10000 < 15000, finalize at end. 
+ assert len(layout.bucket_indices) == 2 + assert layout.per_bucket_numel_unpadded[0] == 20000 + assert layout.per_bucket_numel_unpadded[1] == 10000 + + def test_no_padding_applied(self): + """Default layout should never add padding.""" + params = _make_params((97,), (103,)) + layout = _compute_default_per_buffer_param_layout(params, bucket_size=None) + + total_numel = 97 + 103 + assert layout.bucket_indices == [(0, total_numel)] + assert layout.per_bucket_numel_unpadded == [total_numel] + + def test_single_param(self): + params = _make_params((256,)) + layout = _compute_default_per_buffer_param_layout(params, bucket_size=None) + + assert layout.param_index_map[params[0]] == (0, 256, 0) + assert layout.bucket_indices == [(0, 256)] + + +# --------------------------------------------------------------------------- +# Tests for DistributedOptimizer._compute_per_buffer_param_layout +# --------------------------------------------------------------------------- + + +class TestDistOptParamLayout: + + @staticmethod + def _make_ddp_config(**overrides): + defaults = dict( + grad_reduce_in_fp32=False, + use_distributed_optimizer=True, + overlap_grad_reduce=True, + bucket_size=None, + average_in_collective=False, + ) + defaults.update(overrides) + return DistributedDataParallelConfig(**defaults) + + def test_param_start_64_alignment(self): + """Each param's start index should be 64-aligned.""" + # 97 is not a multiple of 64, so second param must be padded. + params = _make_params((97,), (103,)) + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + for param in params: + start, end, _ = layout.param_index_map[param] + assert start % 64 == 0, f"Start {start} should be 64-aligned" + assert end - start == param.numel() + + def test_bucket_end_dp_divisible(self): + """Each bucket end should be divisible by lcm(dp_size, 128).""" + params = _make_params((1000,), (1000,)) + dp_size = 4 + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=dp_size, ddp_config=ddp_config + ) + + divisor = math.lcm(dp_size, 128) + for start, end in layout.bucket_indices: + assert end % divisor == 0, f"Bucket end {end} should be divisible by {divisor}" + + def test_bucket_end_high_busbw_padding(self): + """With pad_buckets_for_high_nccl_busbw, bucket end should be divisible by 2^16.""" + params = _make_params((1000,)) + dp_size = 2 + ddp_config = self._make_ddp_config(pad_buckets_for_high_nccl_busbw=True) + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=dp_size, ddp_config=ddp_config + ) + + divisor = math.lcm(dp_size, 128, 2**16) + for _, end in layout.bucket_indices: + assert end % divisor == 0 + + def test_shared_embedding_gets_separate_bucket(self): + """Params with shared_embedding=True should be placed in their own bucket.""" + regular = _make_param_with_attrs((1000,)) + shared = _make_param_with_attrs((1000,), shared_embedding=True) + # Reverse order: shared first (since it's last in list), then regular. + params = [regular, shared] + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + # shared_embedding param should be in its own bucket. 
+ _, _, shared_bucket = layout.param_index_map[shared] + _, _, regular_bucket = layout.param_index_map[regular] + assert shared_bucket != regular_bucket + + def test_shared_embedding_as_first_reversed_param_no_extra_bucket(self): + """If shared_embedding param is the first in reversed order (last in list), + it should not create an empty extra bucket before it.""" + shared = _make_param_with_attrs((1000,), shared_embedding=True) + regular = _make_param_with_attrs((1000,)) + # Reverse order: regular first, then shared. + params = [shared, regular] + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + # Both params should be in their own buckets (shared splits after regular). + assert len(layout.bucket_indices) == 2 + + def test_multiple_buckets_with_bucket_size(self): + """Verify bucket splitting with an explicit bucket_size.""" + params = _make_params((5000,), (5000,), (5000,)) + dp_size = 2 + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=8000, data_parallel_world_size=dp_size, ddp_config=ddp_config + ) + + # Each param is 5000 elems. With 64-alignment: + # Bucket 0: params[2] starts at 0, ends at 5000 (< 8000); params[1] starts at 5000 + # (already 64-aligned since 5000 rounds to 5056), ends at 10056 >= 8000 → finalize. + # Bucket 1: params[0]. + assert len(layout.bucket_indices) == 2 + assert len(layout.per_bucket_numel_unpadded) == 2 + + def test_numel_unpadded_vs_padded(self): + """per_bucket_numel_unpadded should be <= padded bucket size.""" + params = _make_params((1000,)) + dp_size = 8 + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=dp_size, ddp_config=ddp_config + ) + + for i, (start, end) in enumerate(layout.bucket_indices): + padded_numel = end - start + assert layout.per_bucket_numel_unpadded[i] <= padded_numel + + def test_shared_embedding_with_bucket_size(self): + """Shared embedding that hits bucket_size threshold should still get its own bucket.""" + # 3 regular params + 1 shared embedding, each 5000 elements. + # With bucket_size=8000, regular params would form 2-param buckets, + # but shared embedding must always be isolated. + regulars = [_make_param_with_attrs((5000,)) for _ in range(3)] + shared = _make_param_with_attrs((5000,), shared_embedding=True) + # Order: regulars first, shared last → reversed: shared first. + params = regulars + [shared] + dp_size = 2 + ddp_config = self._make_ddp_config() + layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=8000, data_parallel_world_size=dp_size, ddp_config=ddp_config + ) + + # Shared embedding must be alone in its bucket. + _, _, shared_bucket = layout.param_index_map[shared] + shared_bucket_params = [ + p for p, (_, _, bid) in layout.param_index_map.items() if bid == shared_bucket + ] + assert len(shared_bucket_params) == 1 + assert shared_bucket_params[0] is shared + + def test_layout_matches_default_when_no_padding_needed(self): + """When params are 64-aligned and bucket_size=None, distributed optimizer layout + should produce the same param_index_map ordering as the default layout + (only bucket end padding differs).""" + # Use param sizes that are multiples of 64 to avoid start-of-param padding. 
+ params = _make_params((64 * 10,), (64 * 20,), (64 * 15,)) + ddp_config = self._make_ddp_config() + dist_layout = DistributedOptimizer._compute_per_buffer_param_layout( + params, bucket_size=None, data_parallel_world_size=1, ddp_config=ddp_config + ) + default_layout = _compute_default_per_buffer_param_layout(params, bucket_size=None) + + # With dp_size=1, lcm(1, 128) = 128, so bucket end padding may differ. + # But param ordering and unpadded numel should match. + for param in params: + dist_start, dist_end, dist_bid = dist_layout.param_index_map[param] + def_start, def_end, def_bid = default_layout.param_index_map[param] + assert dist_start == def_start, f"Start mismatch for param: {dist_start} vs {def_start}" + assert dist_end == def_end, f"End mismatch for param: {dist_end} vs {def_end}" + assert dist_bid == def_bid, f"Bucket ID mismatch for param: {dist_bid} vs {def_bid}" + assert dist_layout.per_bucket_numel_unpadded == default_layout.per_bucket_numel_unpadded + + +# --------------------------------------------------------------------------- +# Tests for DistributedOptimizer.compute_full_param_layout +# --------------------------------------------------------------------------- + + +class TestComputeFullParamLayout: + + @staticmethod + def _make_ddp_config(**overrides): + defaults = dict( + grad_reduce_in_fp32=False, + use_distributed_optimizer=True, + overlap_grad_reduce=True, + bucket_size=None, + average_in_collective=False, + ) + defaults.update(overrides) + return DistributedDataParallelConfig(**defaults) + + def test_dense_only(self): + """With only dense params, should produce a single layout.""" + params = _make_params((100, 100), (50, 50)) + ddp_config = self._make_ddp_config() + full_layout = DistributedOptimizer.compute_full_param_layout( + params, bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + assert len(full_layout.layouts) == 1 + key = list(full_layout.layouts.keys())[0] + assert key.is_expert_parallel is False + layout = full_layout.layouts[key] + assert set(layout.param_index_map.keys()) == set(params) + + def test_dense_and_expert_separate_layouts(self): + """Dense and expert-parallel params should get independent layouts.""" + dense = _make_param_with_attrs((100, 100)) + expert = _make_param_with_attrs((100, 100), allreduce=False) + ddp_config = self._make_ddp_config() + full_layout = DistributedOptimizer.compute_full_param_layout( + [dense, expert], bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + assert len(full_layout.layouts) == 2 + dense_key = BufferKey(torch.bfloat16, torch.bfloat16, False) + expert_key = BufferKey(torch.bfloat16, torch.bfloat16, True) + assert dense_key in full_layout.layouts + assert expert_key in full_layout.layouts + + # Each layout should only contain its own params. + assert set(full_layout.layouts[dense_key].param_index_map.keys()) == {dense} + assert set(full_layout.layouts[expert_key].param_index_map.keys()) == {expert} + + # Both should start at index 0 (independent index spaces). 
+ dense_starts = [s for s, _, _ in full_layout.layouts[dense_key].param_index_map.values()] + expert_starts = [s for s, _, _ in full_layout.layouts[expert_key].param_index_map.values()] + assert min(dense_starts) == 0 + assert min(expert_starts) == 0 + + def test_expert_uses_expert_dp_world_size(self): + """Expert-parallel layout should use expert_data_parallel_world_size for padding.""" + dense = _make_param_with_attrs((1000,)) + expert = _make_param_with_attrs((1000,), allreduce=False) + ddp_config = self._make_ddp_config() + + # Dense dp_size=3, expert dp_size=256. + # lcm(3, 128) = 384, lcm(256, 128) = 256 — different divisors. + full_layout = DistributedOptimizer.compute_full_param_layout( + [dense, expert], + bucket_size=None, + data_parallel_world_size=3, + ddp_config=ddp_config, + expert_data_parallel_world_size=256, + ) + + dense_key = BufferKey(torch.bfloat16, torch.bfloat16, False) + expert_key = BufferKey(torch.bfloat16, torch.bfloat16, True) + + # Expert bucket end should be divisible by lcm(256, 128) = 256. + expert_divisor = math.lcm(256, 128) + assert expert_divisor == 256 + for _, end in full_layout.layouts[expert_key].bucket_indices: + assert end % expert_divisor == 0 + + # Dense bucket end should be divisible by lcm(3, 128) = 384. + dense_divisor = math.lcm(3, 128) + assert dense_divisor == 384 + for _, end in full_layout.layouts[dense_key].bucket_indices: + assert end % dense_divisor == 0 + + def test_param_indices_populated(self): + """compute_full_param_layout should populate param_indices on each layout.""" + params = _make_params((100,), (200,), (300,)) + ddp_config = self._make_ddp_config() + full_layout = DistributedOptimizer.compute_full_param_layout( + params, bucket_size=None, data_parallel_world_size=2, ddp_config=ddp_config + ) + + layout = list(full_layout.layouts.values())[0] + assert len(layout.param_indices) == 3 + assert min(layout.param_indices) == 0 + assert max(layout.param_indices) == 2 diff --git a/tests/unit_tests/elastification/__init__.py b/tests/unit_tests/elastification/__init__.py new file mode 100644 index 00000000000..26496bfed70 --- /dev/null +++ b/tests/unit_tests/elastification/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/tests/unit_tests/elastification/test_apply_flextron_elasticity_to_model.py b/tests/unit_tests/elastification/test_apply_flextron_elasticity_to_model.py new file mode 100644 index 00000000000..7f0edd8f6de --- /dev/null +++ b/tests/unit_tests/elastification/test_apply_flextron_elasticity_to_model.py @@ -0,0 +1,236 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for ``apply_flextron_elasticity_to_model``. + +These tests focus on the layer-class-name-based routing logic (which manager +gets attached to which layer type). They use stub nn.Modules so the tests are +pure-Python and run without a GPU or distributed setup. The individual manager +classes are exercised via GPU-backed tests elsewhere. 
+""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from megatron.elastification import flextron_elasticity_hooks as hooks_module +from megatron.elastification.flextron_elasticity_hooks import apply_flextron_elasticity_to_model + + +def _make_submod(class_name): + """Build a bare nn.Module with a given __class__.__name__.""" + mod = nn.Module() + mod.__class__ = type(class_name, (nn.Module,), {}) + return mod + + +def _mamba_layer(): + layer = nn.Module() + layer.__class__ = type("MambaLayer", (nn.Module,), {}) + layer.add_module("mixer", _make_submod("MambaMixer")) + return layer + + +def _moe_layer(*, cls="MoETransformerLayer"): + layer = nn.Module() + layer.__class__ = type(cls, (nn.Module,), {}) + layer.add_module("pre_mlp_layernorm", _make_submod("RMSNorm")) + mlp = _make_submod("MoELayer") + mlp.add_module("router", _make_submod("TopKRouter")) + mlp.add_module("experts", _make_submod("TEGroupedMLP")) + layer.add_module("mlp", mlp) + return layer + + +def _attention_layer(): + layer = nn.Module() + layer.__class__ = type("TransformerLayer", (nn.Module,), {}) + attn = _make_submod("SelfAttention") + layer.add_module("self_attention", attn) + return layer + + +class _StubModel(nn.Module): + """Minimal model exposing .decoder.layers and (optionally) .decoder.final_norm.""" + + def __init__(self, layers, with_final_norm=False): + super().__init__() + decoder = nn.Module() + decoder.layers = nn.ModuleList(layers) + if with_final_norm: + decoder.add_module("final_norm", _make_submod("RMSNorm")) + self.decoder = decoder + + +def _make_config(pattern="MEM*", flextron=True): + return SimpleNamespace(hybrid_layer_pattern=pattern, flextron=flextron) + + +@pytest.fixture(autouse=True) +def stub_managers(monkeypatch): + """Stub every add_flextron_* entry point to a call-recorder. + + The real managers attach PyTorch hooks to real submodules; testing the + routing logic does not need that machinery. Each stub returns a Sentinel + whose ``.target`` field points at the module it would have hooked. 
+ """ + calls = { + "transformer_layer": [], + "moe": [], + "topk_router": [], + "grouped_mlp": [], + "mamba": [], + "attention": [], + "stack": [], + } + + def _record(bucket): + def _stub(module, config, layer_idx=None): + entry = SimpleNamespace(target=module, layer_idx=layer_idx, config=config) + calls[bucket].append(entry) + return entry + + return _stub + + def _record_stack(module, config): + entry = SimpleNamespace(target=module, config=config) + calls["stack"].append(entry) + return entry + + monkeypatch.setattr( + hooks_module, "add_flextron_transformer_layer_elasticity", _record("transformer_layer") + ) + monkeypatch.setattr(hooks_module, "add_flextron_moe_elasticity", _record("moe")) + monkeypatch.setattr(hooks_module, "add_flextron_topk_router_elasticity", _record("topk_router")) + monkeypatch.setattr(hooks_module, "add_flextron_grouped_mlp_elasticity", _record("grouped_mlp")) + monkeypatch.setattr(hooks_module, "add_flextron_mamba_elasticity", _record("mamba")) + monkeypatch.setattr(hooks_module, "add_flextron_attention_elasticity", _record("attention")) + monkeypatch.setattr(hooks_module, "add_flextron_stack_elasticity", _record_stack) + + return calls + + +class TestEarlyReturns: + def test_missing_hybrid_pattern_returns_empty(self): + model = _StubModel([_mamba_layer()]) + config = SimpleNamespace() # no hybrid_layer_pattern + assert apply_flextron_elasticity_to_model(model, config) == [] + + def test_empty_hybrid_pattern_returns_empty(self): + model = _StubModel([_mamba_layer()]) + config = _make_config(pattern="") + assert apply_flextron_elasticity_to_model(model, config) == [] + + def test_missing_decoder_returns_empty(self): + model = nn.Module() # no .decoder + config = _make_config() + assert apply_flextron_elasticity_to_model(model, config) == [] + + +class TestLayerRouting: + def test_m_layer_registers_mamba_only(self, stub_managers): + model = _StubModel([_mamba_layer()]) + config = _make_config(pattern="M") + apply_flextron_elasticity_to_model(model, config) + assert len(stub_managers["mamba"]) == 1 + assert stub_managers["mamba"][0].layer_idx == 0 + assert stub_managers["mamba"][0].target.__class__.__name__ == "MambaMixer" + for key in ("transformer_layer", "moe", "topk_router", "grouped_mlp", "attention"): + assert stub_managers[key] == [] + + def test_star_layer_registers_attention_only(self, stub_managers): + model = _StubModel([_attention_layer()]) + config = _make_config(pattern="*") + apply_flextron_elasticity_to_model(model, config) + assert len(stub_managers["attention"]) == 1 + assert stub_managers["attention"][0].target.__class__.__name__ == "SelfAttention" + for key in ("transformer_layer", "moe", "topk_router", "grouped_mlp", "mamba"): + assert stub_managers[key] == [] + + def test_e_layer_registers_all_four_moe_managers(self, stub_managers): + model = _StubModel([_moe_layer()]) + config = _make_config(pattern="E") + apply_flextron_elasticity_to_model(model, config) + assert len(stub_managers["transformer_layer"]) == 1 + assert len(stub_managers["moe"]) == 1 + assert len(stub_managers["topk_router"]) == 1 + assert len(stub_managers["grouped_mlp"]) == 1 + + def test_e_layer_accepts_both_class_names(self, stub_managers): + """Regression: the E-layer hook should fire whether the layer class is + TransformerLayer (modelopt spec) or MoETransformerLayer (default spec).""" + model = _StubModel( + [_moe_layer(cls="TransformerLayer"), _moe_layer(cls="MoETransformerLayer")] + ) + config = _make_config(pattern="EE") + apply_flextron_elasticity_to_model(model, 
config) + # Both E-layers should have TransformerLayer elasticity attached. + assert len(stub_managers["transformer_layer"]) == 2 + + def test_hybrid_pattern_routes_each_layer(self, stub_managers): + layers = [_mamba_layer(), _moe_layer(), _mamba_layer(), _attention_layer()] + model = _StubModel(layers, with_final_norm=True) + config = _make_config(pattern="MEM*") + apply_flextron_elasticity_to_model(model, config) + + # One mamba manager per M, one attention per *, and all four moe managers per E. + assert len(stub_managers["mamba"]) == 2 + assert len(stub_managers["attention"]) == 1 + assert len(stub_managers["transformer_layer"]) == 1 + assert len(stub_managers["moe"]) == 1 + assert len(stub_managers["topk_router"]) == 1 + assert len(stub_managers["grouped_mlp"]) == 1 + # And a single stack-level manager for the final norm. + assert len(stub_managers["stack"]) == 1 + + +class TestStackManager: + def test_stack_manager_registered_when_final_norm_present(self, stub_managers): + model = _StubModel([_mamba_layer()], with_final_norm=True) + config = _make_config(pattern="M") + apply_flextron_elasticity_to_model(model, config) + assert len(stub_managers["stack"]) == 1 + + def test_stack_manager_skipped_when_no_final_norm(self, stub_managers): + model = _StubModel([_mamba_layer()], with_final_norm=False) + config = _make_config(pattern="M") + apply_flextron_elasticity_to_model(model, config) + assert stub_managers["stack"] == [] + + +class TestMissingSubmodules: + def test_mamba_layer_without_mixer_is_skipped(self, stub_managers): + """M-layer without a MambaMixer submodule should not crash.""" + layer = nn.Module() + layer.__class__ = type("MambaLayer", (nn.Module,), {}) + # intentionally no 'mixer' submodule + model = _StubModel([layer]) + config = _make_config(pattern="M") + apply_flextron_elasticity_to_model(model, config) + assert stub_managers["mamba"] == [] + + def test_attention_layer_without_self_attention_is_skipped(self, stub_managers): + layer = nn.Module() + layer.__class__ = type("TransformerLayer", (nn.Module,), {}) + # no SelfAttention submodule + model = _StubModel([layer]) + config = _make_config(pattern="*") + apply_flextron_elasticity_to_model(model, config) + assert stub_managers["attention"] == [] + + +class TestManagersStoredOnModel: + def test_model_gets_flextron_managers_attribute(self, stub_managers): + model = _StubModel([_mamba_layer()]) + config = _make_config(pattern="M") + returned = apply_flextron_elasticity_to_model(model, config) + assert model._flextron_managers is returned + assert len(returned) == len(stub_managers["mamba"]) + + def test_pattern_shorter_than_layers_only_uses_pattern_length(self, stub_managers): + layers = [_mamba_layer(), _mamba_layer(), _mamba_layer()] + model = _StubModel(layers) + config = _make_config(pattern="M") # only first layer is covered + apply_flextron_elasticity_to_model(model, config) + assert len(stub_managers["mamba"]) == 1 diff --git a/tests/unit_tests/elastification/test_arguments.py b/tests/unit_tests/elastification/test_arguments.py new file mode 100644 index 00000000000..15a726d5498 --- /dev/null +++ b/tests/unit_tests/elastification/test_arguments.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +"""Unit tests for megatron.elastification.arguments.""" + +from argparse import Namespace + +import pytest + +from megatron.elastification.arguments import ( + convert_per_lists_to_int_lists, + sort_budget_list_descending, + validate_flextron_per_int_lists, +) + + +def _make_config(**overrides): + defaults = dict( + hidden_size=1920, + ffn_hidden_size=960, + num_attention_heads=32, + mamba_num_heads=64, + num_moe_experts=128, + emb_per_list=None, + mlp_per_list=None, + mamba_per_list=None, + moe_expert_per_list=None, + emb_int_list=None, + mlp_int_list=None, + mamba_int_list=None, + moe_expert_int_list=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +class TestConvertPerListsToIntLists: + def test_ratio_one_maps_to_full_dim(self): + cfg = _make_config(emb_per_list=[1.0, 0.5]) + convert_per_lists_to_int_lists(cfg) + assert cfg.emb_int_list == [1920, 960] + # After conversion the per-list is cleared. + assert cfg.emb_per_list is None + + def test_floor_rounding(self): + # 0.71429 * 1920 = 1371.4368 -> floor -> 1371 + cfg = _make_config(emb_per_list=[0.71429, 0.51725]) + convert_per_lists_to_int_lists(cfg) + assert cfg.emb_int_list == [1371, 993] + + def test_all_axes_converted_with_correct_ref_dim(self): + cfg = _make_config( + emb_per_list=[1.0], mlp_per_list=[0.5], mamba_per_list=[0.75], moe_expert_per_list=[0.5] + ) + convert_per_lists_to_int_lists(cfg) + assert cfg.emb_int_list == [1920] + assert cfg.mlp_int_list == [480] # 0.5 * 960 + assert cfg.mamba_int_list == [48] # 0.75 * 64 + assert cfg.moe_expert_int_list == [64] # 0.5 * 128 + + def test_axis_with_no_per_list_is_untouched(self): + cfg = _make_config(emb_per_list=[1.0]) # only emb + convert_per_lists_to_int_lists(cfg) + # mlp / mamba / moe_expert should not gain int_list from nothing. + assert cfg.mlp_int_list is None + assert cfg.mamba_int_list is None + assert cfg.moe_expert_int_list is None + + +class TestValidateFlextronPerIntLists: + def _make_args(self, **overrides): + defaults = dict( + emb_per_list=None, + emb_int_list=None, + mlp_per_list=None, + mlp_int_list=None, + mamba_per_list=None, + mamba_int_list=None, + moe_expert_per_list=None, + moe_expert_int_list=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + def test_unset_axis_defaults_to_full(self): + args = self._make_args() + validate_flextron_per_int_lists(args) + # Each axis defaults to [1.0] on the per-list side. + assert args.emb_per_list == [1.0] + assert args.mlp_per_list == [1.0] + assert args.mamba_per_list == [1.0] + assert args.moe_expert_per_list == [1.0] + + def test_per_list_preserved_when_set(self): + args = self._make_args(emb_per_list=[1.0, 0.5]) + validate_flextron_per_int_lists(args) + assert args.emb_per_list == [1.0, 0.5] + + def test_int_list_preserved_when_set(self): + args = self._make_args(emb_int_list=[1920, 960]) + validate_flextron_per_int_lists(args) + # int_list was explicitly set: per_list stays None (not defaulted to [1.0]). 
+ assert args.emb_per_list is None + assert args.emb_int_list == [1920, 960] + + def test_both_set_raises(self): + args = self._make_args(emb_per_list=[1.0], emb_int_list=[1920]) + with pytest.raises(AssertionError, match="not both"): + validate_flextron_per_int_lists(args) + + def test_per_list_out_of_range_raises(self): + args = self._make_args(emb_per_list=[1.5]) + with pytest.raises(AssertionError, match=r"\[0, 1\]"): + validate_flextron_per_int_lists(args) + + def test_per_list_negative_raises(self): + args = self._make_args(emb_per_list=[-0.1]) + with pytest.raises(AssertionError, match=r"\[0, 1\]"): + validate_flextron_per_int_lists(args) + + +class TestSortBudgetListDescending: + def test_ascending_input_gets_reversed(self): + args = Namespace(budget_list=[0.5, 0.7, 1.0], budget_probs=[0.1, 0.4, 0.5]) + sort_budget_list_descending(args) + assert args.budget_list == [1.0, 0.7, 0.5] + assert args.budget_probs == [0.5, 0.4, 0.1] + + def test_descending_input_unchanged(self): + args = Namespace(budget_list=[1.0, 0.5], budget_probs=[0.7, 0.3]) + sort_budget_list_descending(args) + assert args.budget_list == [1.0, 0.5] + assert args.budget_probs == [0.7, 0.3] + + def test_unsorted_input_paired_correctly(self): + # Verify probs follow the same permutation as budgets. + args = Namespace(budget_list=[0.5, 1.0, 0.7], budget_probs=[0.1, 0.5, 0.4]) + sort_budget_list_descending(args) + assert args.budget_list == [1.0, 0.7, 0.5] + assert args.budget_probs == [0.5, 0.4, 0.1] + + def test_no_probs_only_sorts_budgets(self): + args = Namespace(budget_list=[0.5, 1.0], budget_probs=None) + sort_budget_list_descending(args) + assert args.budget_list == [1.0, 0.5] + assert args.budget_probs is None + + def test_single_element_unchanged(self): + args = Namespace(budget_list=[1.0], budget_probs=[1.0]) + sort_budget_list_descending(args) + assert args.budget_list == [1.0] + assert args.budget_probs == [1.0] + + def test_none_budget_list_skipped(self): + args = Namespace(budget_list=None, budget_probs=None) + sort_budget_list_descending(args) # must not raise + assert args.budget_list is None + + def test_length_mismatch_raises(self): + args = Namespace(budget_list=[1.0, 0.5], budget_probs=[1.0]) + with pytest.raises(AssertionError, match="length"): + sort_budget_list_descending(args) + + def test_idempotent(self): + args = Namespace(budget_list=[0.5, 0.7, 1.0], budget_probs=[0.1, 0.4, 0.5]) + sort_budget_list_descending(args) + sort_budget_list_descending(args) # second call must be a no-op + assert args.budget_list == [1.0, 0.7, 0.5] + assert args.budget_probs == [0.5, 0.4, 0.1] diff --git a/tests/unit_tests/elastification/test_flex_budget_utils.py b/tests/unit_tests/elastification/test_flex_budget_utils.py new file mode 100644 index 00000000000..662b0176c3f --- /dev/null +++ b/tests/unit_tests/elastification/test_flex_budget_utils.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for megatron.elastification.router.flex_budget_utils.""" + +import pytest + +from megatron.elastification.router.flex_budget_utils import get_num_parameters + +# Reference dimensions used by most tests. Small enough to compute by hand. 
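+# As a worked example with these dims, a single attention layer costs
+# input_ln (4) + linear_proj (2*2*4 = 16) + linear_qkv ((2 + 2*1)*2*4 = 32) = 52
+# parameters, matching the _att_cost() helper below.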
+_DIMS = dict( + mamba_num_heads=4, + mamba_d_head=2, + mamba_d_state=2, + num_attention_heads=2, + num_query_groups=1, + ffn_hidden_size=8, + hidden_size=4, + kv_channels=2, + vocab_size=10, + num_experts=2, + shared_expert_intermediate_size=0, + moe_router_topk=1, +) + +_EMBED_PLUS_LN = (_DIMS["vocab_size"] * _DIMS["hidden_size"]) + _DIMS["hidden_size"] +_OUTPUT_LAYER = _DIMS["vocab_size"] * _DIMS["hidden_size"] + + +def _att_cost(): + h, k, q = (_DIMS["hidden_size"], _DIMS["kv_channels"], _DIMS["num_query_groups"]) + n_heads = _DIMS["num_attention_heads"] + input_ln = h + linear_proj = n_heads * k * h + linear_qkv = (n_heads + 2 * q) * k * h + return input_ln + linear_proj + linear_qkv + + +def _moe_cost_all(): + pre_mlp_ln = _DIMS["hidden_size"] + n_experts = _DIMS["num_experts"] + ffn = _DIMS["ffn_hidden_size"] + shared = _DIMS["shared_expert_intermediate_size"] + h = _DIMS["hidden_size"] + linear_fc1 = ffn * (h * n_experts + shared) + linear_fc2 = ffn * (h * n_experts + shared) + return pre_mlp_ln + linear_fc1 + linear_fc2 + + +def _moe_cost_active(): + pre_mlp_ln = _DIMS["hidden_size"] + topk = _DIMS["moe_router_topk"] + ffn = _DIMS["ffn_hidden_size"] + shared = _DIMS["shared_expert_intermediate_size"] + h = _DIMS["hidden_size"] + linear_fc1 = ffn * (h * topk + shared) + linear_fc2 = ffn * (h * topk + shared) + return pre_mlp_ln + linear_fc1 + linear_fc2 + + +def _mamba_cost(): + h = _DIMS["hidden_size"] + nheads = _DIMS["mamba_num_heads"] + d_head = _DIMS["mamba_d_head"] + d_state = _DIMS["mamba_d_state"] + d_inner = nheads * d_head + ngroups = 8 # hard-coded in the implementation + cdim = d_inner + 2 * ngroups * d_state + mamba_conv = cdim + cdim * 1 * 4 # bias + weight, kernel=4, stride=1 + mamba_input_ln = h + mamba_in_proj = h * (d_inner * 2 + 2 * ngroups * d_state + nheads) + mamba_norm = d_inner + mamba_out_proj = d_inner * h + scalars = nheads + nheads + nheads # dt_bias + A_log + D + return scalars + mamba_input_ln + mamba_in_proj + mamba_conv + mamba_norm + mamba_out_proj + + +class TestGetNumParameters: + def test_single_moe_layer_matches_manual(self): + total, active = get_num_parameters(hybrid_pattern="E", tied_vocab=False, **_DIMS) + expected_total = _EMBED_PLUS_LN + _OUTPUT_LAYER + _moe_cost_all() + expected_active = _EMBED_PLUS_LN + _OUTPUT_LAYER + _moe_cost_active() + assert total == expected_total + assert active == expected_active + + def test_single_attention_layer(self): + total, active = get_num_parameters(hybrid_pattern="*", tied_vocab=False, **_DIMS) + expected = _EMBED_PLUS_LN + _OUTPUT_LAYER + _att_cost() + assert total == expected + # Attention has no active/total split. 
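+ # (unlike MoE, where _moe_cost_active counts only the top-k experts)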
+ assert active == expected + + def test_single_mamba_layer(self): + total, active = get_num_parameters(hybrid_pattern="M", tied_vocab=False, **_DIMS) + expected = _EMBED_PLUS_LN + _OUTPUT_LAYER + _mamba_cost() + assert total == expected + assert active == expected + + def test_hybrid_pattern_is_sum_of_per_layer_costs(self): + pattern = "MEM*E" + total, active = get_num_parameters(hybrid_pattern=pattern, tied_vocab=False, **_DIMS) + expected_total = ( + _EMBED_PLUS_LN + _OUTPUT_LAYER + 2 * _mamba_cost() + 2 * _moe_cost_all() + _att_cost() + ) + expected_active = ( + _EMBED_PLUS_LN + + _OUTPUT_LAYER + + 2 * _mamba_cost() + + 2 * _moe_cost_active() + + _att_cost() + ) + assert total == expected_total + assert active == expected_active + + def test_tied_vocab_zeros_output_layer(self): + total_tied, _ = get_num_parameters(hybrid_pattern="M", tied_vocab=True, **_DIMS) + total_untied, _ = get_num_parameters(hybrid_pattern="M", tied_vocab=False, **_DIMS) + # Untied adds one more vocab*hidden block. + assert total_untied - total_tied == _DIMS["vocab_size"] * _DIMS["hidden_size"] + + def test_pipe_character_ignored(self): + # The '|' marker (pipeline split) should not contribute any params. + base = get_num_parameters(hybrid_pattern="ME", tied_vocab=False, **_DIMS) + with_pipe = get_num_parameters(hybrid_pattern="M|E", tied_vocab=False, **_DIMS) + assert base == with_pipe + + def test_unknown_layer_char_raises(self): + with pytest.raises(RuntimeError, match="Unknown layer type"): + get_num_parameters(hybrid_pattern="Z", tied_vocab=False, **_DIMS) + + def test_moe_active_less_than_or_equal_total(self): + # topk < num_experts, so active < total; topk == num_experts, active == total. + total_tk1, active_tk1 = get_num_parameters( + hybrid_pattern="E", tied_vocab=False, **{**_DIMS, "moe_router_topk": 1} + ) + total_tkN, active_tkN = get_num_parameters( + hybrid_pattern="E", + tied_vocab=False, + **{**_DIMS, "moe_router_topk": _DIMS["num_experts"]}, + ) + assert active_tk1 < total_tk1 + assert active_tkN == total_tkN + + def test_topk_zero_active_excludes_experts(self): + # With topk=0 the active cost per expert's linear_fc1/fc2 contribution + # collapses to 0 (shared_expert_intermediate_size=0 in our fixture). + _, active = get_num_parameters( + hybrid_pattern="E", tied_vocab=False, **{**_DIMS, "moe_router_topk": 0} + ) + # active == embed + output + pre_mlp_ln (no fc1/fc2 contribution) + assert active == _EMBED_PLUS_LN + _OUTPUT_LAYER + _DIMS["hidden_size"] + + def test_empty_pattern_only_embeddings_and_final_norm(self): + total, active = get_num_parameters(hybrid_pattern="", tied_vocab=False, **_DIMS) + assert total == _EMBED_PLUS_LN + _OUTPUT_LAYER + assert active == _EMBED_PLUS_LN + _OUTPUT_LAYER diff --git a/tests/unit_tests/elastification/test_flextron_config.py b/tests/unit_tests/elastification/test_flextron_config.py new file mode 100644 index 00000000000..34bf2e8a7f3 --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_config.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +"""Unit tests for megatron.elastification.flextron_config.""" + +import dataclasses +from argparse import Namespace +from types import SimpleNamespace + +from megatron.elastification.flextron_config import FlextronConfig, inject_flextron_config + + +class TestFlextronConfigDefaults: + def test_default_values(self): + cfg = FlextronConfig() + assert cfg.flextron is False + assert cfg.enable_router is False + assert cfg.router_inter_dim == 128 + assert cfg.hard_sample_th == 0.996 + assert cfg.tau_init == 1.0 + assert cfg.tau_decay == 0.9999 + assert cfg.router_std == 0.1 + assert cfg.budget_type == 'param' + assert cfg.original_model_sample_prob == 0.33 + + def test_all_fields_accessible_after_construction(self): + cfg = FlextronConfig() + for f in dataclasses.fields(FlextronConfig): + # Every declared field should be readable. + getattr(cfg, f.name) + + +class TestInjectFlextronConfig: + def test_copies_all_fields_from_args(self): + args = Namespace( + flextron=True, + enable_router=True, + router_inter_dim=256, + hard_sample_th=0.5, + tau_init=2.0, + tau_decay=0.9, + router_std=0.01, + budget_type='mem', + budget_list=[1.0, 0.5], + original_model_sample_prob=0.0, + ) + target = SimpleNamespace() + inject_flextron_config(args, target) + assert target.flextron is True + assert target.enable_router is True + assert target.router_inter_dim == 256 + assert target.hard_sample_th == 0.5 + assert target.tau_init == 2.0 + assert target.tau_decay == 0.9 + assert target.router_std == 0.01 + assert target.budget_type == 'mem' + assert target.budget_list == [1.0, 0.5] + assert target.original_model_sample_prob == 0.0 + + def test_missing_arg_falls_back_to_default(self): + # args has only a subset of FlextronConfig fields. + args = Namespace(flextron=True) + target = SimpleNamespace() + inject_flextron_config(args, target) + # Present-on-args field is copied. + assert target.flextron is True + # Absent-on-args field gets FlextronConfig default. + assert target.router_inter_dim == 128 + assert target.hard_sample_th == 0.996 + assert target.tau_init == 1.0 + + def test_preserves_unrelated_config_attributes(self): + args = Namespace(flextron=True) + target = SimpleNamespace(hidden_size=1920, num_layers=52) + inject_flextron_config(args, target) + # Fields that are not FlextronConfig fields stay untouched. + assert target.hidden_size == 1920 + assert target.num_layers == 52 + + def test_every_flextron_field_is_set_on_target(self): + args = Namespace() # totally empty + target = SimpleNamespace() + inject_flextron_config(args, target) + for f in dataclasses.fields(FlextronConfig): + assert hasattr(target, f.name), f"field {f.name!r} not injected onto target" + + def test_returns_none(self): + # inject_flextron_config mutates in place and should not return a value. + args = Namespace(flextron=True) + target = SimpleNamespace() + result = inject_flextron_config(args, target) + assert result is None diff --git a/tests/unit_tests/elastification/test_flextron_grouped_mlp_elasticity_manager.py b/tests/unit_tests/elastification/test_flextron_grouped_mlp_elasticity_manager.py new file mode 100644 index 00000000000..fb8593f7b56 --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_grouped_mlp_elasticity_manager.py @@ -0,0 +1,211 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""GPU-backed tests for FlextronGroupedMLPElasticityManager. 
+ +Covers the multi-hook MLP masking pipeline: setup-mask init, the +input/fc1-post/output hook trio that applies emb + intermediate masking, +and detach. The fc1_post_hook calls into expert-tensor-parallel state, +so we initialize MPU at world_size=1 (mask split is the whole mask). + +Run with: + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_flextron_grouped_mlp_elasticity_manager.py +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from megatron.elastification.flextron_elasticity_hooks import ( + FlextronGroupedMLPElasticityManager, + add_flextron_grouped_mlp_elasticity, +) +from tests.unit_tests.test_utilities import Utils + + +def _config(hidden_size=64, ffn_hidden_size=128, soft_mask=True): + return SimpleNamespace( + flextron=True, + soft_mask=soft_mask, + flex_hetero_ffn=False, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + emb_int_list=[hidden_size, hidden_size // 2], + mlp_int_list=[ffn_hidden_size, ffn_hidden_size // 2], + hybrid_layer_pattern="E", + layernorm_epsilon=1e-5, + ) + + +class _StubGroupedMLP(nn.Module): + """Minimal module exposing the surface attach_hooks needs: + - register_forward_*_hook (inherited from nn.Module) + - a ``linear_fc1`` child (for fc1_post_hook) + + Forward chain mimics a real GroupedMLP: hidden -> fc1 -> ffn-sized + "intermediate" -> projected back to hidden-sized output. Both stages + return ``(tensor, None)`` so the hooks see the (out, bias) tuple shape + they expect.""" + + def __init__(self, hidden_size, ffn_hidden_size): + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + # Stash the captured intermediate so tests can inspect what fc1_post_hook + # produced before the output projection runs. + self._captured_intermediate = None + + class _FC1(nn.Module): + def forward(_self, x): + inter = x.new_zeros(*x.shape[:-1], ffn_hidden_size) + inter[..., : x.shape[-1]] = x # plant the input into the lower channels + return inter, None + + self.linear_fc1 = _FC1() + + def forward(self, hidden_states): + intermediate, _ = self.linear_fc1(hidden_states) + # fc1_post_hook may have masked the intermediate before we get here. + self._captured_intermediate = intermediate.detach().clone() + # Project back to hidden dim: take the lower hidden_size channels. + out = intermediate[..., : self.hidden_size].contiguous() + return (out, None) + + +@pytest.mark.internal +class TestFlextronGroupedMLPElasticityManager: + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_tensor_parallel_size=1, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _make_module(self, cfg): + return _StubGroupedMLP(cfg.hidden_size, cfg.ffn_hidden_size).cuda().to(torch.bfloat16) + + def test_attach_registers_expected_hook_count(self): + cfg = _config() + mod = self._make_module(cfg) + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr.attach_hooks(mod) + # setup + input_mask + fc1_post + output_mask + cleanup = 5 + assert len(mgr.hook_handles) == 5 + mgr.detach_hooks() + + def test_init_emb_masks_match_choice_list(self): + cfg = _config(hidden_size=64) + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr._init_embedding_masks() + # One mask per emb_int_list entry, each shape == [hidden_size]. + assert mgr.emb_masks.shape == (len(cfg.emb_int_list), cfg.hidden_size) + # Mask 0 (full): all-ones over the full hidden dim. 
+ torch.testing.assert_close( + mgr.emb_masks[0], torch.ones(cfg.hidden_size, dtype=torch.bfloat16, device="cuda") + ) + # Mask 1 (half): ones on lower half, zeros on upper half. + expected = torch.zeros(cfg.hidden_size, dtype=torch.bfloat16, device="cuda") + expected[: cfg.hidden_size // 2] = 1.0 + torch.testing.assert_close(mgr.emb_masks[1], expected) + + def test_init_mlp_masks_dedupe_and_sort(self): + """``_init_mlp_masks`` dedupes via set() and sorts descending. Verify + the lookup maps each unique value to the right index.""" + cfg = _config(ffn_hidden_size=128) + cfg.mlp_int_list = [128, 128, 64] # duplicate to exercise dedupe + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr._init_mlp_masks() + # Two unique values, sorted descending: [128, 64]. + assert mgr.mlp_intermediate_masks.shape[0] == 2 + assert mgr.mlp_intermediate_masks_lookup == {128: 0, 64: 1} + + def test_no_router_emb_is_passthrough(self): + """With current_router_emb None, no hook should mutate output.""" + cfg = _config() + mod = self._make_module(cfg) + x = torch.randn(2, cfg.hidden_size, dtype=torch.bfloat16, device="cuda") + baseline_out, baseline_bias = mod(x) + + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr.attach_hooks(mod) + # current_router_emb is None — the input/fc1/output hooks all early-out. + out, bias = mod(x) + torch.testing.assert_close(out, baseline_out) + mgr.detach_hooks() + + def test_soft_mask_zeros_upper_intermediate_at_half_budget(self): + """fc1_post_hook applies the mlp_intermediate_mask. Soft-mask weighted + sum on a one-hot at the half-budget choice should leave the upper + ffn channels zeroed.""" + cfg = _config(hidden_size=64, ffn_hidden_size=128, soft_mask=True) + mod = self._make_module(cfg) + + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr.attach_hooks(mod) + # One-hot router_emb on full-emb, one-hot router_mlp on half-ffn (index 1). + emb_logits = torch.tensor([1.0, 0.0], dtype=torch.bfloat16, device="cuda") + mlp_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params( + router_emb=(emb_logits, cfg.hidden_size), + router_mlp=(mlp_logits, cfg.ffn_hidden_size // 2), + ) + + x = torch.ones(2, cfg.hidden_size, dtype=torch.bfloat16, device="cuda") + mod(x) + intermediate = mod._captured_intermediate + # mlp_int_list sorted-desc dedupe = [128, 64]; one-hot on index 1 -> 64. + # Lower 64 channels active, upper 64 zeroed by the intermediate mask. + assert (intermediate[..., 64:] == 0).all() + assert not (intermediate[..., :64] == 0).all() + mgr.detach_hooks() + + def test_set_elasticity_params_only_updates_provided_axes(self): + """Calling set_elasticity_params with only one kwarg must not clear + the other (regression-guard for the ``if x is not None`` pattern).""" + cfg = _config() + mgr = FlextronGroupedMLPElasticityManager(cfg) + sentinel_emb = (torch.tensor([1.0, 0.0]), cfg.hidden_size) + sentinel_mlp = (torch.tensor([0.0, 1.0]), cfg.ffn_hidden_size // 2) + mgr.set_elasticity_params(router_emb=sentinel_emb, router_mlp=sentinel_mlp) + + # Update only emb; mlp must still be the prior value. 
+ new_emb = (torch.tensor([0.0, 1.0]), cfg.hidden_size // 2) + mgr.set_elasticity_params(router_emb=new_emb) + + assert mgr.current_router_emb is new_emb + assert mgr.current_router_mlp is sentinel_mlp + + def test_detach_clears_hook_handles(self): + cfg = _config() + mod = self._make_module(cfg) + mgr = FlextronGroupedMLPElasticityManager(cfg) + mgr.attach_hooks(mod) + assert len(mgr.hook_handles) == 5 + mgr.detach_hooks() + assert mgr.hook_handles == [] + + +@pytest.mark.internal +class TestAddFlextronGroupedMLPElasticity: + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_tensor_parallel_size=1, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_factory_returns_manager_with_layer_idx(self): + cfg = _config() + mod = _StubGroupedMLP(cfg.hidden_size, cfg.ffn_hidden_size).cuda().to(torch.bfloat16) + mgr = add_flextron_grouped_mlp_elasticity(mod, cfg, layer_idx=0) + assert isinstance(mgr, FlextronGroupedMLPElasticityManager) + assert mgr.layer_idx == 0 + assert len(mgr.hook_handles) == 5 + mgr.detach_hooks() diff --git a/tests/unit_tests/elastification/test_flextron_mamba_elasticity_manager.py b/tests/unit_tests/elastification/test_flextron_mamba_elasticity_manager.py new file mode 100644 index 00000000000..81ec4839dbe --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_mamba_elasticity_manager.py @@ -0,0 +1,206 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""GPU-backed tests for FlextronMambaElasticityManager. + +Builds a real MambaMixer (mirroring ``tests/unit_tests/ssm/test_mamba_mixer.py``) +and verifies that the elasticity hooks attach, behave as no-ops without +elasticity params, and produce different activations once params are set. 
+ +Run with: + + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_flextron_mamba_elasticity_manager.py +""" + +import pytest +import torch + +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_mixer import MambaMixer +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig +from megatron.elastification.flextron_elasticity_hooks import ( + FlextronMambaElasticityManager, + add_flextron_mamba_elasticity, +) +from tests.unit_tests.test_utilities import Utils + + +def _flextron_fields(hidden_size, num_heads): + """Return dict of flextron attrs to copy onto a TransformerConfig.""" + return dict( + flextron=True, + soft_mask=True, + flex_hetero_mamba=False, + flex_hetero_ffn=False, + flex_hetero_moe_expert=False, + hybrid_layer_pattern="M", + emb_int_list=[hidden_size, hidden_size // 2], + mamba_int_list=[num_heads, num_heads // 2], + ) + + +@pytest.mark.internal +class TestFlextronMambaElasticityManager: + + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def _build_mixer_and_config(self, hidden_size=256, num_heads=8): + """Construct a bf16 MambaMixer on CUDA + a flextron-enabled config.""" + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + config = TransformerConfig( + hidden_size=hidden_size, + num_layers=1, + num_attention_heads=1, + use_cpu_initialization=True, + use_mamba_mem_eff_path=True, + ) + # Inject the flextron fields directly (bypassing inject_flextron_config + # to avoid pulling in the whole args-parser stack). + for k, v in _flextron_fields(hidden_size, num_heads).items(): + setattr(config, k, v) + + pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'cp']) + mixer = MambaMixer( + config, + hybrid_stack_spec.submodules.mamba_layer.submodules.mixer.submodules, + config.hidden_size, + layer_number=1, + pg_collection=pg_collection, + ) + mixer.cuda() + return mixer, config + + def test_attach_produces_hook_handles(self): + mixer, config = self._build_mixer_and_config() + mgr = FlextronMambaElasticityManager(config) + mgr.attach_hooks(mixer) + # There are several hooks: setup + input_mask + in_proj(pre,post) + + # conv1d + norm(pre,post) + output + cleanup. Exact count can change; + # we just verify > 1. + assert len(mgr.hook_handles) > 1 + mgr.detach_hooks() + + def test_current_router_none_preserves_output(self): + """Without elasticity params set, a forward with hooks attached must + produce (approximately) the same output as a forward without hooks.""" + mixer, config = self._build_mixer_and_config() + + seq_len, micro_batch = 16, 2 + x = torch.ones((seq_len, micro_batch, config.hidden_size), device="cuda") + + # Baseline (no elasticity). + baseline_out, _ = mixer(x) + + mgr = FlextronMambaElasticityManager(config) + mgr.attach_hooks(mixer) + # current_router_emb/mamba both None -> all hooks except setup/cleanup + # should no-op. 
+ hooked_out, _ = mixer(x) + torch.testing.assert_close(hooked_out, baseline_out, atol=1e-2, rtol=1e-2) + mgr.detach_hooks() + + def test_full_budget_one_hot_approx_matches_baseline(self): + """One-hot on index 0 (full emb + full mamba heads) should approximately + reproduce the baseline output (up to bf16 / eps drift).""" + mixer, config = self._build_mixer_and_config() + + seq_len, micro_batch = 16, 2 + x = torch.ones((seq_len, micro_batch, config.hidden_size), device="cuda") + baseline_out, _ = mixer(x) + + mgr = FlextronMambaElasticityManager(config) + mgr.attach_hooks(mixer) + emb_logits = torch.tensor([1.0, 0.0], dtype=torch.bfloat16, device="cuda") + mamba_logits = torch.tensor([1.0, 0.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params( + router_emb=(emb_logits, config.hidden_size), + router_mamba=(mamba_logits, config.mamba_int_list[0]), + ) + full_out, _ = mixer(x) + # Full-budget one-hot should not materially change the output. + torch.testing.assert_close(full_out, baseline_out, atol=5e-2, rtol=5e-2) + mgr.detach_hooks() + + def test_small_budget_one_hot_changes_output(self): + """One-hot on a smaller choice should change the output norm.""" + mixer, config = self._build_mixer_and_config() + + seq_len, micro_batch = 16, 2 + x = torch.randn((seq_len, micro_batch, config.hidden_size), device="cuda") + baseline_out, _ = mixer(x) + + mgr = FlextronMambaElasticityManager(config) + mgr.attach_hooks(mixer) + emb_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mamba_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params( + router_emb=(emb_logits, config.emb_int_list[1]), + router_mamba=(mamba_logits, config.mamba_int_list[1]), + ) + small_out, _ = mixer(x) + # The small-budget output should measurably differ from the baseline. 
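+ # The loose atol means ordinary bf16 rounding noise alone should not trip this check.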
+ assert not torch.allclose(small_out, baseline_out, atol=1e-2) + mgr.detach_hooks() + + def test_detach_restores_baseline(self): + mixer, config = self._build_mixer_and_config() + seq_len, micro_batch = 16, 2 + x = torch.ones((seq_len, micro_batch, config.hidden_size), device="cuda") + baseline_out, _ = mixer(x) + + mgr = FlextronMambaElasticityManager(config) + mgr.attach_hooks(mixer) + emb_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mamba_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params( + router_emb=(emb_logits, config.emb_int_list[1]), + router_mamba=(mamba_logits, config.mamba_int_list[1]), + ) + _ = mixer(x) + mgr.detach_hooks() + + detached_out, _ = mixer(x) + torch.testing.assert_close(detached_out, baseline_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.internal +class TestAddFlextronMambaElasticity: + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_factory_returns_manager(self): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + config = TransformerConfig( + hidden_size=256, + num_layers=1, + num_attention_heads=1, + use_cpu_initialization=True, + use_mamba_mem_eff_path=True, + ) + for k, v in _flextron_fields(256, 8).items(): + setattr(config, k, v) + + pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'cp']) + mixer = MambaMixer( + config, + hybrid_stack_spec.submodules.mamba_layer.submodules.mixer.submodules, + config.hidden_size, + layer_number=1, + pg_collection=pg_collection, + ).cuda() + + mgr = add_flextron_mamba_elasticity(mixer, config, layer_idx=0) + assert isinstance(mgr, FlextronMambaElasticityManager) + assert mgr.layer_idx == 0 + mgr.detach_hooks() diff --git a/tests/unit_tests/elastification/test_flextron_stack_elasticity_manager.py b/tests/unit_tests/elastification/test_flextron_stack_elasticity_manager.py new file mode 100644 index 00000000000..c276b519a8f --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_stack_elasticity_manager.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""GPU-backed tests for FlextronStackElasticityManager. + +Tests the final-norm hooks that apply eps modification and sqrt(emb_per) +scaling when the router supplies an embedding choice. Run with: + + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_flextron_stack_elasticity_manager.py +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from megatron.elastification.flextron_elasticity_hooks import ( + FlextronStackElasticityManager, + add_flextron_stack_elasticity, +) + + +def _stack_config(emb_int_list=(256, 128), soft_mask=True, layernorm_epsilon=1e-5): + """Minimal SimpleNamespace exposing every attr the stack manager reads.""" + return SimpleNamespace( + flextron=True, + soft_mask=soft_mask, + hidden_size=256, + emb_int_list=list(emb_int_list), + layernorm_epsilon=layernorm_epsilon, + ) + + +def _stack_with_final_norm(hidden_size=256, eps=1e-5): + """A stand-in HybridStack: only `final_norm` is hooked, nothing else required.""" + stack = nn.Module() + stack.final_norm = nn.LayerNorm(hidden_size, eps=eps).cuda().to(torch.bfloat16) + return stack + + +@pytest.mark.internal +class TestFlextronStackElasticityManager: + def teardown_method(self, method): + # No parallel state was initialized; nothing to tear down. 
+ pass + + def test_disabled_manager_is_noop(self): + config = _stack_config() + config.flextron = False + mgr = FlextronStackElasticityManager(config) + stack = _stack_with_final_norm() + mgr.attach_hooks(stack) # Should silently skip. + assert mgr.hook_handles == [] if hasattr(mgr, "hook_handles") else True + + def test_attach_registers_two_hooks(self): + config = _stack_config() + mgr = FlextronStackElasticityManager(config) + stack = _stack_with_final_norm() + mgr.attach_hooks(stack) + # One pre-hook + one post-hook on final_norm. + assert len(mgr.hook_handles) == 2 + + def test_current_router_emb_none_is_noop(self): + """Without elasticity params set, hooks must pass through unchanged.""" + config = _stack_config() + mgr = FlextronStackElasticityManager(config) + stack = _stack_with_final_norm() + mgr.attach_hooks(stack) + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + + expected = stack.final_norm(x) # direct call — hooks do run but should no-op + # Hooks were attached in-place, so call again to capture the hooked output. + out = stack.final_norm(x) + torch.testing.assert_close(out, expected) + + def test_soft_mask_scales_output_by_sqrt_emb_per(self): + """With soft_mask and a one-hot router distribution, output should scale by + sqrt(emb_per) of the selected choice.""" + config = _stack_config(emb_int_list=[256, 128], soft_mask=True) + mgr = FlextronStackElasticityManager(config) + stack = _stack_with_final_norm() + mgr.attach_hooks(stack) + + # One-hot on index 1 (emb_int=128 -> per=0.5) + per_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params(router_emb=(per_logits, 128)) + + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + # Baseline without elasticity: detach hooks first. + mgr.detach_hooks() + baseline = stack.final_norm(x) + + # Re-attach and run with elasticity. + mgr.attach_hooks(stack) + mgr.set_elasticity_params(router_emb=(per_logits, 128)) + scaled = stack.final_norm(x) + + # Expected: baseline * sqrt(0.5) (since per_logit is 1.0 on idx 1) + expected_scale = (128 / 256) ** 0.5 + torch.testing.assert_close(scaled, baseline * expected_scale, atol=1e-2, rtol=1e-2) + + def test_detach_removes_all_hooks(self): + config = _stack_config() + mgr = FlextronStackElasticityManager(config) + stack = _stack_with_final_norm() + mgr.attach_hooks(stack) + assert len(mgr.hook_handles) == 2 + mgr.detach_hooks() + assert mgr.hook_handles == [] + + +@pytest.mark.internal +class TestAddFlextronStackElasticity: + def test_factory_returns_manager_with_hooks_attached(self): + config = _stack_config() + stack = _stack_with_final_norm() + mgr = add_flextron_stack_elasticity(stack, config) + assert isinstance(mgr, FlextronStackElasticityManager) + assert len(mgr.hook_handles) == 2 + mgr.detach_hooks() diff --git a/tests/unit_tests/elastification/test_flextron_topk_router_elasticity_manager.py b/tests/unit_tests/elastification/test_flextron_topk_router_elasticity_manager.py new file mode 100644 index 00000000000..f21a78a2b98 --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_topk_router_elasticity_manager.py @@ -0,0 +1,212 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for FlextronTopKRouterElasticityManager. 
+ +Focuses on the hard-mask path: it replaces the router's ``routing`` method +with a wrapper that masks the upper expert indices before delegating, and +must save/restore ``router.expert_bias`` around the call (regression +guard for the prior permanent-mutation bug). + +The soft-mask path goes through ``topk_softmax_with_capacity``, which +requires real MoE plumbing — covered by integration tests, not here. + +Run with: + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_flextron_topk_router_elasticity_manager.py +""" + +from types import SimpleNamespace + +import pytest +import torch + +from megatron.elastification.flextron_elasticity_hooks import ( + FlextronTopKRouterElasticityManager, + add_flextron_topk_router_elasticity, +) + + +def _config(num_moe_experts=8, soft_mask=False, flex_hetero_moe_expert=False): + return SimpleNamespace( + flextron=True, + soft_mask=soft_mask, + flex_hetero_moe_expert=flex_hetero_moe_expert, + num_moe_experts=num_moe_experts, + moe_expert_int_list=[num_moe_experts, num_moe_experts // 2], + hybrid_layer_pattern="E", + ) + + +class _StubRouter: + """Minimal router: holds an ``expert_bias`` tensor and an ``original_routing`` + method that records its inputs so we can assert what was passed.""" + + def __init__(self, expert_bias): + self.expert_bias = expert_bias + # Record (logits_clone, expert_bias_clone) at call-time so we can verify + # what the inner call observed. + self.calls = [] + + def routing(logits, **kwargs): + self.calls.append( + { + "logits": logits.detach().clone(), + "expert_bias": self.expert_bias.detach().clone(), + "kwargs": kwargs, + } + ) + return logits, kwargs + + self.routing = routing + + +@pytest.mark.internal +class TestFlextronTopKRouterElasticityManager: + def test_attach_replaces_routing_method(self): + cfg = _config() + router = _StubRouter(expert_bias=torch.zeros(8)) + original = router.routing + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + assert router.routing is not original + # The handle list records the method-replacement entry for detach. + assert len(mgr.hook_handles) == 1 + assert mgr.hook_handles[0][0] == "method_replacement" + + def test_no_elasticity_params_delegates_to_original(self): + """When current_router_moe_expert is None, wrapped_routing must + forward (logits, kwargs) unchanged to the original method.""" + cfg = _config() + router = _StubRouter(expert_bias=torch.zeros(8)) + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + + logits = torch.randn(4, 8) + out_logits, out_kwargs = router.routing(logits, foo="bar") + + assert len(router.calls) == 1 + # Logits passed through untouched. + torch.testing.assert_close(router.calls[0]["logits"], logits) + assert router.calls[0]["kwargs"] == {"foo": "bar"} + torch.testing.assert_close(out_logits, logits) + + def test_hard_mask_truncates_upper_logits(self): + """With expert_int=4 (half), logits[:, 4:] should be -inf when the + original routing sees them, and logits[:, :4] should equal the input + scaled by the router_moe_expert logit (max of one-hot).""" + cfg = _config(num_moe_experts=8, soft_mask=False) + router = _StubRouter(expert_bias=torch.zeros(8)) + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + + # One-hot on the half-experts choice (index 1 of moe_expert_int_list = 4 experts). 
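+ # max(per_logits) = 1.0 here, so the surviving lower-half logits keep their scale.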
+ per_logits = torch.tensor([0.0, 1.0]) + mgr.set_elasticity_params(router_moe_expert=(per_logits, 4)) + + logits = torch.ones(2, 8) + router.routing(logits) + + seen = router.calls[0]["logits"] + # Lower 4 columns: scaled by router_moe_expert_logits = max(per_logits) = 1.0. + torch.testing.assert_close(seen[:, :4], torch.ones(2, 4)) + # Upper 4 columns: -inf. + assert torch.isinf(seen[:, 4:]).all() and (seen[:, 4:] < 0).all() + + def test_hard_mask_preserves_expert_bias_after_call(self): + """Regression: the wrapper must save and restore router.expert_bias. + Previously it left a truncated clone bound to router.expert_bias, + leaking into subsequent forwards.""" + cfg = _config(num_moe_experts=8, soft_mask=False) + original_bias = torch.arange(8, dtype=torch.float32) + 1.0 + router = _StubRouter(expert_bias=original_bias.clone()) + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + mgr.set_elasticity_params(router_moe_expert=(torch.tensor([0.0, 1.0]), 4)) + + bias_before = router.expert_bias.clone() + router.routing(torch.zeros(2, 8)) + bias_after = router.expert_bias + + # The bias seen *during* the call should have indices 4: zeroed. + seen_bias = router.calls[0]["expert_bias"] + assert (seen_bias[:4] == bias_before[:4]).all() + assert (seen_bias[4:] == 0).all() + # But the bias on the router after the call must be the original. + torch.testing.assert_close(bias_after, original_bias) + # Same Python object, not just equal values. + assert bias_after is not seen_bias + + def test_hard_mask_bias_restored_even_if_inner_raises(self): + """``try/finally`` must restore expert_bias when original_routing raises.""" + cfg = _config(num_moe_experts=8, soft_mask=False) + original_bias = torch.arange(8, dtype=torch.float32) + 1.0 + router = _StubRouter(expert_bias=original_bias.clone()) + + def boom(logits, **kwargs): + raise RuntimeError("simulated downstream failure") + + router.routing = boom + + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + mgr.set_elasticity_params(router_moe_expert=(torch.tensor([0.0, 1.0]), 4)) + + with pytest.raises(RuntimeError, match="simulated"): + router.routing(torch.zeros(2, 8)) + + torch.testing.assert_close(router.expert_bias, original_bias) + + def test_hard_mask_with_no_expert_bias(self): + """When router has no expert_bias, the save/restore branch must skip + cleanly and the inner call must still see the masked logits.""" + cfg = _config(num_moe_experts=8, soft_mask=False) + + # Minimal router: no expert_bias attribute, original_routing records + # only the logits it saw (avoids the StubRouter's bias.detach()). + class _RouterNoBias: + def __init__(self): + self.expert_bias = None + self.seen = None + + def routing(logits, **kwargs): + self.seen = logits.detach().clone() + return logits, kwargs + + self.routing = routing + + router = _RouterNoBias() + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + mgr.set_elasticity_params(router_moe_expert=(torch.tensor([0.0, 1.0]), 4)) + + router.routing(torch.ones(2, 8)) + assert torch.isinf(router.seen[:, 4:]).all() + # The bias attribute must remain None — no accidental clone-binding. 
+ assert router.expert_bias is None + + def test_detach_restores_original_routing(self): + cfg = _config() + router = _StubRouter(expert_bias=torch.zeros(8)) + original_callable = router.routing # capture pre-attach reference + mgr = FlextronTopKRouterElasticityManager(cfg) + mgr.attach_hooks(router) + wrapped = router.routing + assert wrapped is not original_callable + + mgr.detach_hooks() + + # After detach: routing is back, the helper attribute is gone, the + # handle list is empty. + assert router.routing is original_callable + assert not hasattr(router, "_original_routing") + assert mgr.hook_handles == [] + + +@pytest.mark.internal +class TestAddFlextronTopKRouterElasticity: + def test_factory_returns_manager(self): + cfg = _config() + router = _StubRouter(expert_bias=torch.zeros(8)) + mgr = add_flextron_topk_router_elasticity(router, cfg, layer_idx=0) + assert isinstance(mgr, FlextronTopKRouterElasticityManager) + assert len(mgr.hook_handles) == 1 + mgr.detach_hooks() diff --git a/tests/unit_tests/elastification/test_flextron_transformer_layer_elasticity_manager.py b/tests/unit_tests/elastification/test_flextron_transformer_layer_elasticity_manager.py new file mode 100644 index 00000000000..c14e187cfb1 --- /dev/null +++ b/tests/unit_tests/elastification/test_flextron_transformer_layer_elasticity_manager.py @@ -0,0 +1,141 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""GPU-backed tests for FlextronTransformerLayerElasticityManager. + +Tests the pre_mlp_layernorm pre/post hooks for E-layers. Run with: + + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_flextron_transformer_layer_elasticity_manager.py +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from megatron.elastification.flextron_elasticity_hooks import ( + FlextronTransformerLayerElasticityManager, + add_flextron_transformer_layer_elasticity, +) + + +def _tl_config(emb_int_list=(256, 128), soft_mask=True, layernorm_epsilon=1e-5): + return SimpleNamespace( + flextron=True, + soft_mask=soft_mask, + hidden_size=256, + emb_int_list=list(emb_int_list), + layernorm_epsilon=layernorm_epsilon, + ) + + +def _fake_transformer_layer(hidden_size=256, eps=1e-5): + """Minimal module exposing .pre_mlp_layernorm (the only submodule hooked).""" + layer = nn.Module() + layer.pre_mlp_layernorm = nn.LayerNorm(hidden_size, eps=eps).cuda().to(torch.bfloat16) + return layer + + +@pytest.mark.internal +class TestFlextronTransformerLayerElasticityManager: + def teardown_method(self, method): + pass + + def test_attach_registers_two_hooks(self): + config = _tl_config() + mgr = FlextronTransformerLayerElasticityManager(config) + layer = _fake_transformer_layer() + mgr.attach_hooks(layer) + assert len(mgr.hook_handles) == 2 + + def test_current_router_emb_none_is_noop(self): + """With current_router_emb unset, hook behavior must match no-hook forward.""" + config = _tl_config() + layer = _fake_transformer_layer() + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + expected = layer.pre_mlp_layernorm(x) + + mgr = FlextronTransformerLayerElasticityManager(config) + mgr.attach_hooks(layer) + # current_router_emb is None — no masking, no scaling. + out = layer.pre_mlp_layernorm(x) + torch.testing.assert_close(out, expected) + + def test_soft_mask_scales_output(self): + """With soft_mask one-hot on a smaller choice, the hook should: + (1) zero input channels beyond the chosen emb_int, then + (2) scale the LN output by sqrt(emb_per). 
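+ With emb_int_list=[256, 128] and a one-hot on index 1, emb_per = 128/256 = 0.5,
+ so the expected post-LN scale is sqrt(0.5).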
+ Reproduce the expected output by applying that mask+scale manually + against plain LN and assert equality (within bf16 tolerance).""" + config = _tl_config(emb_int_list=[256, 128], soft_mask=True) + layer = _fake_transformer_layer() + + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + + # Build expected output: mask upper half, LN, scale by sqrt(emb_per). + mask = torch.zeros(256, dtype=torch.bfloat16, device="cuda") + mask[:128] = 1.0 + expected = layer.pre_mlp_layernorm(x * mask[None, None, :]) * (128 / 256) ** 0.5 + + mgr = FlextronTransformerLayerElasticityManager(config) + mgr.attach_hooks(layer) + # One-hot on index 1 (emb_int=128 -> per=0.5) + per_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params(router_emb=(per_logits, 128)) + out = layer.pre_mlp_layernorm(x) + + # Tolerance accommodates the tiny eps drift (5e-6 vs 1e-5) inside LN. + torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2) + + def test_full_budget_one_hot_preserves_magnitude_order(self): + """When router is one-hot on full budget (index 0 = 100% emb), the + pre-hook masks nothing and post-hook scales by sqrt(1.0)=1.0.""" + config = _tl_config(emb_int_list=[256, 128], soft_mask=True) + layer = _fake_transformer_layer() + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + baseline = layer.pre_mlp_layernorm(x) + + mgr = FlextronTransformerLayerElasticityManager(config) + mgr.attach_hooks(layer) + per_logits = torch.tensor([1.0, 0.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params(router_emb=(per_logits, 256)) + + out = layer.pre_mlp_layernorm(x) + # Full-budget path: input mask is all-ones, scale is sqrt(1.0). Output + # should equal baseline within bf16 tolerance (eps adjustment may add + # tiny drift). + torch.testing.assert_close(out, baseline, atol=5e-2, rtol=5e-2) + + def test_detach_restores_forward(self): + config = _tl_config() + layer = _fake_transformer_layer() + x = torch.randn(4, 2, 256, dtype=torch.bfloat16, device="cuda") + + mgr = FlextronTransformerLayerElasticityManager(config) + mgr.attach_hooks(layer) + per_logits = torch.tensor([0.0, 1.0], dtype=torch.bfloat16, device="cuda") + mgr.set_elasticity_params(router_emb=(per_logits, 128)) + + masked_out = layer.pre_mlp_layernorm(x) + mgr.detach_hooks() + detached_out = layer.pre_mlp_layernorm(x) + + # After detach, the output should match the un-hooked LN output. + expected = nn.LayerNorm(256, eps=layer.pre_mlp_layernorm.eps).cuda().to(torch.bfloat16) + expected.weight.data.copy_(layer.pre_mlp_layernorm.weight.data) + expected.bias.data.copy_(layer.pre_mlp_layernorm.bias.data) + torch.testing.assert_close(detached_out, expected(x), atol=1e-2, rtol=1e-2) + # The masked output from before detach should differ from the detached one. 
+ assert not torch.allclose(masked_out, detached_out, atol=1e-2) + + +@pytest.mark.internal +class TestAddFlextronTransformerLayerElasticity: + def test_factory_returns_manager(self): + config = _tl_config() + layer = _fake_transformer_layer() + mgr = add_flextron_transformer_layer_elasticity(layer, config, layer_idx=3) + assert isinstance(mgr, FlextronTransformerLayerElasticityManager) + assert mgr.layer_idx == 3 + assert len(mgr.hook_handles) == 2 + mgr.detach_hooks() diff --git a/tests/unit_tests/elastification/test_hybrid_flex_router.py b/tests/unit_tests/elastification/test_hybrid_flex_router.py new file mode 100644 index 00000000000..a5720e92dc4 --- /dev/null +++ b/tests/unit_tests/elastification/test_hybrid_flex_router.py @@ -0,0 +1,212 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""GPU-backed tests for FlextronRouter. + +Covers construction, forward-pass shape/structure for each axis, and +DP-aware Gumbel determinism (same seed + iteration => identical output). + +Run with: + + torchrun --nproc_per_node=1 -m pytest tests/unit_tests/elastification/test_hybrid_flex_router.py +""" + +from argparse import Namespace + +import pytest +import torch + +import megatron.elastification.router.hybrid_flex_router as _router_module +import megatron.training as _megatron_training +from megatron.core.transformer import TransformerConfig +from megatron.elastification.router.hybrid_flex_router import FlextronRouter +from tests.unit_tests.test_utilities import Utils + + +def _router_config( + hidden_size=256, ffn_hidden_size=128, num_heads=8, mamba_num_heads=8, num_moe_experts=8 +): + """Build a TransformerConfig with every attr FlextronRouter reads.""" + config = TransformerConfig( + hidden_size=hidden_size, + num_layers=2, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, + use_cpu_initialization=True, + ) + flex_fields = dict( + flextron=True, + soft_mask=True, + add_skipping=False, + flex_hetero_ffn=False, + flex_hetero_mamba=False, + flex_hetero_moe_expert=False, + hybrid_layer_pattern="ME", + normalize_router_logits=False, + router_inter_dim=32, + router_std=0.1, + router_gbs=2, + router_beta=1.0, + loss_alpha=1.0, + tau_init=1.0, + tau_decay=0.9999, + hard_sample_th=0.996, + # Enable the scaler with a constant 1.0 so `scale` is defined inside + # the axis forwards (they use it unconditionally) but its value is a + # no-op. The get_args stub in setup_method supplies train_iters so + # add_scaler_schedule can construct the linspace. + linear_scaler_start=1.0, + linear_scaler_end=1.0, + budget_list=[1.0, 0.5], + budget_probs=[1.0, 1.0], + budget_type="param", + original_model_sample_prob=0.0, + curr_iteration=0, + mamba_num_heads=mamba_num_heads, + emb_int_list=[hidden_size, hidden_size // 2], + mlp_int_list=[ffn_hidden_size, ffn_hidden_size // 2], + mamba_int_list=[mamba_num_heads, mamba_num_heads // 2], + moe_expert_int_list=[num_moe_experts, num_moe_experts // 2], + override_selected_budget=None, + ) + for k, v in flex_fields.items(): + setattr(config, k, v) + return config + + +@pytest.mark.internal +class TestFlextronRouter: + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + # FlextronRouter calls _sync_router_weights in __init__, which does an + # NCCL broadcast on CPU params before we get a chance to call .cuda(). + # NCCL has no CPU backend, so we stub the sync out — it's a no-op at + # world_size=1 anyway. 
Restored in teardown. + self._orig_sync = FlextronRouter._sync_router_weights + FlextronRouter._sync_router_weights = lambda self: None + # forward() pulls `args.curr_iteration` via megatron.training.get_args() + # which fails outside a full Megatron initialize. Install a minimal + # stub that returns the attrs the router needs. + self._orig_get_args = _megatron_training.get_args + # train_iters needs to be >= max curr_iteration used in any test: the + # scaler is a linspace of length train_iters and axis forwards index + # it with curr_iteration. Since start=end=1.0 the values are all 1.0 + # regardless of length, so overshooting is free. + _megatron_training.get_args = lambda: Namespace( + curr_iteration=0, train_iters=1000, train_samples=1000, global_batch_size=1 + ) + # The router's __init__ also reads global microbatch state via + # get_current_global_batch_size / get_micro_batch_size. These return + # None outside a full Megatron initialize — stub them at the module + # level where the router imported them. + self._orig_gbs = _router_module.get_current_global_batch_size + self._orig_mbs = _router_module.get_micro_batch_size + _router_module.get_current_global_batch_size = lambda: 1 + _router_module.get_micro_batch_size = lambda: 1 + + def teardown_method(self, method): + FlextronRouter._sync_router_weights = self._orig_sync + _megatron_training.get_args = self._orig_get_args + _router_module.get_current_global_batch_size = self._orig_gbs + _router_module.get_micro_batch_size = self._orig_mbs + Utils.destroy_model_parallel() + + def test_construction(self): + config = _router_config() + router = FlextronRouter(config).cuda() + # Each gate is a Sequential of two linear layers + activation. + assert hasattr(router, "gate_mlp") + assert hasattr(router, "gate_emb") + assert hasattr(router, "gate_mamba") + assert hasattr(router, "gate_moe_expert") + # Attention head elasticity is not supported. + assert not hasattr(router, "gate_head") + # Skipping was disabled in the config. + assert not hasattr(router, "gate_skip_layer") + + def test_router_params_marked_for_pp_sync(self): + config = _router_config() + router = FlextronRouter(config).cuda() + for p in router.parameters(): + # _mark_router_params_for_pp_sync adds this attribute to every + # trainable parameter so the PP gradient sync picks them up. + assert getattr(p, "flextron_router_pp_sync", False) is True + + def test_forward_returns_five_axis_outputs(self): + config = _router_config() + router = FlextronRouter(config).cuda() + out = router(1.0) + assert len(out) == 5 + # Order (per hybrid_flex_router.forward): + # (mlp, skipping, emb, mamba, moe_expert) + mlp, skipping, emb, mamba, moe_expert = out + # Skipping is None when add_skipping=False. + assert skipping is None + # Each axis output is a (logits, choice) tuple. + for axis in (mlp, emb, mamba, moe_expert): + assert isinstance(axis, tuple) and len(axis) == 2 + + def test_emb_output_shape_matches_choice_count(self): + config = _router_config() + router = FlextronRouter(config).cuda() + _, _, emb, _, _ = router(1.0) + logits, choice = emb + # Logits have one entry per emb_int_list choice. 
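+        # `choice` is the concrete embedding width the router selected for this forward pass.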
+ assert logits.numel() == len(config.emb_int_list) + assert choice in config.emb_int_list + + def test_gumbel_determinism(self): + """Two routers at the same config + iteration + fwd_pass_count should + produce identical Gumbel-softmax samples.""" + config = _router_config() + config.curr_iteration = 0 + + router_a = FlextronRouter(config).cuda() + router_b = FlextronRouter(config).cuda() + # Copy weights so both routers are in the same parameter state; the + # determinism check is about the Gumbel RNG, not init noise. + router_b.load_state_dict(router_a.state_dict()) + + out_a = router_a(1.0) + out_b = router_b(1.0) + for axis_a, axis_b in zip(out_a, out_b): + if axis_a is None: + assert axis_b is None + continue + logits_a, choice_a = axis_a + logits_b, choice_b = axis_b + torch.testing.assert_close(logits_a, logits_b, atol=0, rtol=0) + assert choice_a == choice_b + + def test_fwd_pass_count_increments(self): + config = _router_config() + router = FlextronRouter(config).cuda() + assert router.fwd_pass_count == 0 + router(1.0) + assert router.fwd_pass_count == 1 + router(1.0) + assert router.fwd_pass_count == 2 + + def test_different_iterations_give_different_samples(self): + """Bumping curr_iteration changes the Gumbel seed; logits should differ.""" + config = _router_config() + router = FlextronRouter(config).cuda() + + # Iteration 0 via the default setup-method stub. + out_iter_0 = router(1.0) + + # Swap the stub to return iteration 100, reset fwd_pass_count so + # that is the only thing that varies. train_iters must stay >= + # curr_iteration (matches setup-method stub length). + _megatron_training.get_args = lambda: Namespace( + curr_iteration=100, train_iters=1000, train_samples=1000, global_batch_size=1 + ) + router.fwd_pass_count = 0 + out_iter_100 = router(1.0) + + # Emb-axis logits should differ between iterations. + _, _, emb_0, _, _ = out_iter_0 + _, _, emb_100, _, _ = out_iter_100 + assert not torch.allclose(emb_0[0], emb_100[0]) diff --git a/tests/unit_tests/elastification/test_inject_flextron_forward_logic.py b/tests/unit_tests/elastification/test_inject_flextron_forward_logic.py new file mode 100644 index 00000000000..c802cdb1b9a --- /dev/null +++ b/tests/unit_tests/elastification/test_inject_flextron_forward_logic.py @@ -0,0 +1,217 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for ``inject_flextron_forward_logic``. + +These tests pin two invariants the dev-branch divergence hunt uncovered: + +1. When a Flextron manager is attached and a budget is passed in kwargs, + ``update_hook_elasticity_params`` must be called *before* the original + forward runs — otherwise all hooks fire with ``current_router_emb=None`` + and silently no-op. +2. When no Flextron manager is present, ``flextron_kwargs`` is cleared + before the original forward so no unexpected keyword args leak through. 
+""" + +from types import SimpleNamespace + +import pytest + +from megatron.elastification.flextron_utils import inject_flextron_forward_logic + + +class _CallLog: + """Record the order in which stub methods are invoked.""" + + def __init__(self): + self.events = [] + + def note(self, name, **payload): + self.events.append((name, payload)) + + +def _make_original_forward(log): + def _fwd( + input_ids=None, + position_ids=None, + attention_mask=None, + decoder_input=None, + labels=None, + inference_context=None, + runtime_gather_output=None, + inference_params=None, + ): + log.note( + "original_forward", + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + ) + return "forward-result" + + return _fwd + + +def _make_manager(log, *, router_present=True, router_output_kwargs=None, loss_func=None): + router_output_kwargs = ( + router_output_kwargs + if router_output_kwargs is not None + else {"router_emb": (object(), object())} + ) + + def process_router_output(budget_item): + log.note("process_router_output", budget_item=budget_item) + return router_output_kwargs, loss_func + + def update_hook_elasticity_params(flextron_kwargs): + log.note("update_hook_elasticity_params", flextron_kwargs=flextron_kwargs) + + return SimpleNamespace( + router=object() if router_present else None, + process_router_output=process_router_output, + update_hook_elasticity_params=update_hook_elasticity_params, + ) + + +def _attach_forward(model_cls, original): + """Return a model whose .forward is `original` and install the wrapper. + + `original` is stored as an instance attribute so Python does not auto-bind + it as a method — we want the stub called as a plain function by + ``flextron_forward``. + """ + model = model_cls() + model.forward = original + inject_flextron_forward_logic(model) + return model + + +class _StubModelNoManager: + """Minimal model with no ._flextron_manager attribute at all.""" + + def __init__(self): + self.config = SimpleNamespace() + + +class _StubModelWithManager: + def __init__(self): + self.config = SimpleNamespace(flextron=True, is_flex_eval=False) + self._flextron_manager = None # filled in by test + + +class TestForwardReplacement: + def test_forward_is_replaced(self): + log = _CallLog() + original = _make_original_forward(log) + model = _StubModelNoManager() + model.forward = original + before = model.forward + inject_flextron_forward_logic(model) + # model.forward is now a bound method wrapper, not the raw stub. + assert model.forward is not before + + +class TestNoManager: + def test_without_manager_original_is_called_directly(self): + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelNoManager, original) + + result = model.forward( + input_ids="ids", + position_ids="pos", + attention_mask="mask", + budget=0.697, # should be swallowed, not leaked through + ) + + assert result == "forward-result" + # Only original_forward was called; no router / hook-update step. + names = [e[0] for e in log.events] + assert names == ["original_forward"] + # budget kwarg was cleared before reaching original_forward. 
+ assert "budget" not in log.events[0][1] + + def test_manager_with_no_router_skips_router_logic(self): + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelWithManager, original) + model._flextron_manager = _make_manager(log, router_present=False) + + model.forward(input_ids="ids", position_ids="pos", attention_mask="mask", budget=0.5) + + names = [e[0] for e in log.events] + assert names == ["original_forward"] + + +class TestManagerOrdering: + def test_budget_kwarg_triggers_router_and_hooks_before_forward(self): + """Core invariant: update_hook_elasticity_params runs *before* original_forward.""" + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelWithManager, original) + model._flextron_manager = _make_manager(log) + + model.forward(input_ids="ids", position_ids="pos", attention_mask="mask", budget=0.697) + + names = [e[0] for e in log.events] + # Exact expected sequence. + assert names == [ + "process_router_output", + "update_hook_elasticity_params", + "original_forward", + ] + # Sanity: the budget actually used is the one passed in kwargs. + assert log.events[0][1]["budget_item"] == 0.697 + + def test_loss_func_invoked_when_returned(self): + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelWithManager, original) + + def _loss_func(kwargs, budget_item): + log.note("loss_func", budget_item=budget_item) + return "budget-loss" + + model._flextron_manager = _make_manager(log, loss_func=_loss_func) + + model.forward(input_ids="ids", position_ids="pos", attention_mask="mask", budget=0.697) + + names = [e[0] for e in log.events] + # loss_func must run after router output and before the hook update. + assert names == [ + "process_router_output", + "loss_func", + "update_hook_elasticity_params", + "original_forward", + ] + + +class TestOverrideSelectedBudget: + def test_override_non_one_sets_budget_from_override(self): + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelWithManager, original) + model._flextron_manager = _make_manager(log) + model.config.is_flex_eval = True + model.config.override_selected_budget = [0.577] + + # No budget kwarg on the caller side — override should supply it. + model.forward(input_ids="ids", position_ids="pos", attention_mask="mask") + + names = [e[0] for e in log.events] + assert names == [ + "process_router_output", + "update_hook_elasticity_params", + "original_forward", + ] + assert log.events[0][1]["budget_item"] == 0.577 + + def test_override_without_flex_eval_raises(self): + log = _CallLog() + original = _make_original_forward(log) + model = _attach_forward(_StubModelWithManager, original) + model._flextron_manager = _make_manager(log) + model.config.is_flex_eval = False + model.config.override_selected_budget = [0.577] + + with pytest.raises(AssertionError): + model.forward(input_ids="ids", position_ids="pos", attention_mask="mask") diff --git a/tests/unit_tests/elastification/test_loss_func.py b/tests/unit_tests/elastification/test_loss_func.py new file mode 100644 index 00000000000..d5f27aa1ba8 --- /dev/null +++ b/tests/unit_tests/elastification/test_loss_func.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for megatron.elastification.loss_func. + +Covers the non-distributed paths of ``_mask_loss`` and ``loss_func``: +no tensor-parallel reductions, no sequence parallel, no KD. 
The +distributed/KD branches require a multi-rank torch.distributed init and +are left for a follow-up. +""" + +from argparse import Namespace +from unittest.mock import MagicMock + +import pytest +import torch + +from megatron.elastification import loss_func as loss_func_module +from megatron.elastification.loss_func import _mask_loss, loss_func + + +def _stub_args(**overrides): + defaults = dict( + router_beta=1.0, + loss_alpha=1.0, + freeze_router=False, + tensor_model_parallel_size=1, + export_kd_teacher_load=False, + budget_list=[1.0, 0.5], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +@pytest.fixture +def patch_get_args(monkeypatch): + """Replace get_args in the loss_func module with a stub returning the args we set.""" + holder = {"args": _stub_args()} + + def _set_args(**overrides): + holder["args"] = _stub_args(**overrides) + + monkeypatch.setattr(loss_func_module, "get_args", lambda: holder["args"]) + return _set_args + + +def _flat_loss_tensor(values): + """Build the (B, S) loss tensor that ``_mask_loss`` expects.""" + return torch.tensor(values, dtype=torch.float32).reshape(1, -1) + + +class TestMaskLossPlainTensor: + """When output_tensor is a plain Tensor, no param_loss is reported.""" + + def test_returns_scalar_loss_tensor(self, patch_get_args): + patch_get_args() + out = torch.tensor([[1.0, 2.0, 4.0]]) + mask = torch.tensor([[1.0, 1.0, 0.0]]) + result = _mask_loss(out, mask) + assert isinstance(result, torch.Tensor) + assert result.item() == pytest.approx(3.0) # 1*1 + 2*1 + 4*0 + + +class TestMaskLossWithParamLossTuple: + """output_tensor is (output, (param_loss, extra_dict)).""" + + def test_positive_param_loss_added_to_lm(self, patch_get_args): + patch_get_args(loss_alpha=1.0) + out = torch.tensor([[1.0, 2.0]]) + mask = torch.tensor([[1.0, 1.0]]) + param_loss = torch.tensor([0.5]) + loss, param_item = _mask_loss((out, (param_loss, {})), mask) + # lm = 1+2 = 3; param contribution = 0.5 * num_tokens(2) * alpha(1) = 1.0 + assert loss.item() == pytest.approx(4.0) + assert param_item.item() == pytest.approx(1.0) + + def test_negative_param_loss_scaled_by_router_beta(self, patch_get_args): + # router_beta flips and scales negative param losses. + patch_get_args(router_beta=2.0, loss_alpha=1.0) + out = torch.tensor([[1.0, 1.0]]) + mask = torch.tensor([[1.0, 1.0]]) + param_loss = torch.tensor([-0.25]) + loss, param_item = _mask_loss((out, (param_loss, {})), mask) + # param_loss negated and scaled: -2.0 * (-0.25) = 0.5 + # param_item = 0.5 * num_tokens(2) * alpha(1) = 1.0 + # lm contribution = 2; total = 3 + assert param_item.item() == pytest.approx(1.0) + assert loss.item() == pytest.approx(3.0) + + def test_freeze_router_drops_param_contribution(self, patch_get_args): + patch_get_args(freeze_router=True) + out = torch.tensor([[1.0, 1.0]]) + mask = torch.tensor([[1.0, 1.0]]) + param_loss = torch.tensor([0.5]) + result = _mask_loss((out, (param_loss, {})), mask) + # When router is frozen, param_loss isn't added — bare scalar returned. 
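+        # (contrast with the unfrozen path above, which returns a (loss, param_item) tuple)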
+ assert isinstance(result, torch.Tensor) + assert result.item() == pytest.approx(2.0) + + def test_loss_alpha_scales_param_contribution(self, patch_get_args): + patch_get_args(loss_alpha=10.0) + out = torch.tensor([[1.0, 1.0]]) + mask = torch.tensor([[1.0, 1.0]]) + param_loss = torch.tensor([0.5]) + loss, param_item = _mask_loss((out, (param_loss, {})), mask) + # param_item = 0.5 * 2 tokens * 10 = 10.0 + assert param_item.item() == pytest.approx(10.0) + assert loss.item() == pytest.approx(12.0) # 2 (lm) + 10 (param) + + +class TestLossFuncReportingNoKD: + """Top-level loss_func paths that don't enter the KD branch.""" + + def _model(self, training=True): + m = MagicMock() + m.training = training + return m + + def test_full_model_step_routes_to_lm_loss_full(self, patch_get_args): + patch_get_args() + out = torch.tensor([[1.0, 2.0]]) + mask = torch.tensor([[1.0, 1.0]]) + # param_loss = 0 → recognized as full-model step + zero_param = torch.tensor([0.0]) + loss, num_tokens, report = loss_func( + mask, (out, (zero_param, {})), self._model(training=True) + ) + # The report dict must contain both keys, but only "(full)" carries data. + assert "lm loss (full)" in report and "lm loss (budget)" in report + full_val, full_den = report["lm loss (full)"][0], report["lm loss (full)"][1] + budget_val, budget_den = report["lm loss (budget)"][0], report["lm loss (budget)"][1] + assert budget_val.item() == 0.0 and budget_den.item() == 0.0 + # full_val gets lm loss minus param contribution. param_loss=0 → just lm. + assert full_val.item() == pytest.approx(3.0) + assert num_tokens.item() == 2 + + def test_sub_budget_step_routes_to_lm_loss_budget(self, patch_get_args): + patch_get_args() + out = torch.tensor([[1.0, 2.0]]) + mask = torch.tensor([[1.0, 1.0]]) + nonzero_param = torch.tensor([0.5]) # signals sub-budget step + loss, num_tokens, report = loss_func( + mask, (out, (nonzero_param, {})), self._model(training=True) + ) + full_val = report["lm loss (full)"][0] + budget_val = report["lm loss (budget)"][0] + assert full_val.item() == 0.0 + # budget side carries the lm loss (3.0 = 1+2) + assert budget_val.item() == pytest.approx(3.0) + + def test_num_tokens_clamped_when_all_masked(self, patch_get_args): + patch_get_args() + out = torch.tensor([[5.0, 5.0]]) + mask = torch.tensor([[0.0, 0.0]]) + zero_param = torch.tensor([0.0]) + _, num_tokens, _ = loss_func(mask, (out, (zero_param, {})), self._model(training=True)) + # Guard at line 94 clamps to min=1 to avoid divide-by-zero downstream. + assert num_tokens.item() == 1 + + def test_report_values_are_packed_pairs(self, patch_get_args): + """Every report entry is converted to a (value, num_tokens) tensor pair.""" + patch_get_args() + out = torch.tensor([[1.0, 1.0]]) + mask = torch.tensor([[1.0, 1.0]]) + zero_param = torch.tensor([0.0]) + _, _, report = loss_func(mask, (out, (zero_param, {})), self._model(training=True)) + for key, val in report.items(): + assert isinstance(val, torch.Tensor), f"report[{key}] not packed" + assert val.shape == (2,), f"report[{key}] not (value, num_tokens)" diff --git a/tests/unit_tests/elastification/test_memory_config.py b/tests/unit_tests/elastification/test_memory_config.py new file mode 100644 index 00000000000..70ea488108c --- /dev/null +++ b/tests/unit_tests/elastification/test_memory_config.py @@ -0,0 +1,128 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +"""Unit tests for megatron.elastification.memory_config.""" + +from argparse import Namespace + +import pytest +import yaml + +from megatron.elastification.memory_config import MemoryConfig, load_memory_config + + +@pytest.fixture +def profiles_file(tmp_path): + """Write a minimal memory-profiles YAML and return its path.""" + data = { + "presets": { + "bf16": { + "params": 2, + "kv_cache": 2, + "ssm_cache": 2, + "max_buffer": 2, + "param_budget_target": "active", + }, + "fp8_kv": { + "params": 2, + "kv_cache": 1, + "ssm_cache": 2, + "max_buffer": 2, + "param_budget_target": "active", + }, + "total_target": { + "params": 2, + "kv_cache": 2, + "ssm_cache": 2, + "max_buffer": 2, + "param_budget_target": "total", + }, + } + } + path = tmp_path / "memory_profiles.yaml" + path.write_text(yaml.safe_dump(data)) + return str(path) + + +def _make_args(profile="bf16", profiles_path=None, **overrides): + defaults = dict( + memory_profile=profile, + memory_profile_path=profiles_path, + bpe_params=None, + bpe_kv_cache=None, + bpe_ssm_cache=None, + bpe_max_buffer=None, + param_budget_target=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +class TestMemoryConfigDataclass: + def test_default_values(self): + cfg = MemoryConfig() + assert cfg.bpe_params == 2.0 + assert cfg.bpe_kv_cache == 2.0 + assert cfg.bpe_ssm_cache == 2.0 + assert cfg.bpe_max_buffer == 2.0 + assert cfg.param_budget_target == "active" + + def test_invalid_param_budget_target_rejected(self): + with pytest.raises(ValueError, match="param_budget_target"): + MemoryConfig(param_budget_target="bogus") + + def test_valid_param_budget_target_accepted(self): + MemoryConfig(param_budget_target="active") + MemoryConfig(param_budget_target="total") + + +class TestLoadMemoryConfig: + def test_preset_applied(self, profiles_file): + args = _make_args(profile="fp8_kv", profiles_path=profiles_file) + cfg = load_memory_config(args) + assert cfg.bpe_params == 2.0 + assert cfg.bpe_kv_cache == 1.0 # FP8 + assert cfg.bpe_ssm_cache == 2.0 + assert cfg.bpe_max_buffer == 2.0 + + def test_preset_param_budget_target(self, profiles_file): + args = _make_args(profile="total_target", profiles_path=profiles_file) + cfg = load_memory_config(args) + assert cfg.param_budget_target == "total" + + def test_cli_override_takes_priority_over_preset(self, profiles_file): + args = _make_args( + profile="bf16", profiles_path=profiles_file, bpe_kv_cache=0.5625 # override + ) + cfg = load_memory_config(args) + assert cfg.bpe_kv_cache == 0.5625 # override wins + assert cfg.bpe_params == 2.0 # preset preserved + + def test_param_budget_target_override(self, profiles_file): + args = _make_args(profile="bf16", profiles_path=profiles_file, param_budget_target="total") + cfg = load_memory_config(args) + assert cfg.param_budget_target == "total" + + def test_unknown_profile_raises(self, profiles_file): + args = _make_args(profile="nonexistent", profiles_path=profiles_file) + with pytest.raises(ValueError, match="not found"): + load_memory_config(args) + + def test_missing_profile_file_raises(self, tmp_path): + args = _make_args(profile="bf16", profiles_path=str(tmp_path / "missing.yaml")) + with pytest.raises(FileNotFoundError): + load_memory_config(args) + + def test_none_profile_name_defaults_to_bf16(self, profiles_file): + args = _make_args(profile=None, profiles_path=profiles_file) + cfg = load_memory_config(args) + # bf16 defaults in the fixture. 
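+        # (a None profile name falls back to the "bf16" preset instead of raising)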
+ assert cfg.bpe_params == 2.0 + assert cfg.bpe_kv_cache == 2.0 + + def test_default_profiles_path_loads_bundled_yaml(self): + # When profiles_path is None, the loader falls back to the bundled + # megatron/elastification/memory_profiles.yaml. + args = _make_args(profile="bf16", profiles_path=None) + cfg = load_memory_config(args) + assert cfg.bpe_params == 2.0 + assert cfg.bpe_kv_cache == 2.0 diff --git a/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py b/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py index 762195b5d7f..968e07e9201 100644 --- a/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py +++ b/tests/unit_tests/fusions/test_mla_yarn_rope_apply.py @@ -271,13 +271,10 @@ def _test_fused_mla_rope_kv_split(input_format, remove_interleaving=False): @pytest.mark.skipif(not is_torch_min_version("2.5.0"), reason="Requires PyTorch >= 2.5.0") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("input_format", ["sbhd", "thd"]) -class TestFusedMLARope: - @pytest.mark.parametrize("inverse", [False, True]) - @pytest.mark.parametrize("remove_interleaving", [False, True]) - def test_inplace_forward_backward(self, input_format, inverse, remove_interleaving): - _test_fused_mla_rope_inplace( - input_format, inverse=inverse, remove_interleaving=remove_interleaving - ) +class TestFusedApplyMLARope: + @pytest.mark.flaky_in_dev + def test_forward_backward_for_q(self, input_format): + _test_fused_apply_mla_rope_for_q(input_format) @pytest.mark.parametrize("remove_interleaving", [False, True]) def test_kv_split_forward_backward(self, input_format, remove_interleaving): diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 06acdcfec9f..0e307f600f1 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -16,7 +16,7 @@ ) from megatron.core.inference.inference_request import DynamicInferenceRequest from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols +from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils @@ -221,7 +221,7 @@ def test_request_overflow(self, is_hybrid_model: bool): dynamic_context.add_request( DynamicInferenceRequest( request_id=i, - prompt_tokens=torch.zeros(10, device='cuda'), + prompt_tokens=torch.zeros(10, device='cpu'), sampling_params=SamplingParams( num_tokens_to_generate=dynamic_context.max_tokens - 10 ), @@ -249,7 +249,7 @@ def test_token_overflow_error(self, is_hybrid_model: bool): dynamic_context.add_request( DynamicInferenceRequest( request_id=1, - prompt_tokens=torch.arange(0, 225, device='cuda'), + prompt_tokens=torch.arange(0, 225, device='cpu'), sampling_params=SamplingParams( num_tokens_to_generate=dynamic_context.max_tokens - 25 ), @@ -279,7 +279,7 @@ def test_reset(self, is_hybrid_model: bool): dynamic_context.paused_request_count = 5 dynamic_context.padded_active_token_count = 10 dynamic_context.padded_active_request_count = 5 - dynamic_context.paused_tokens = torch.tensor([1, 2, 3], device='cuda') + dynamic_context.paused_tokens = torch.tensor([1, 2, 3], device='cpu') dynamic_context.request_ids.fill_(1) 
dynamic_context.request_query_lengths.fill_(1) dynamic_context.request_kv_length_offsets.fill_(1) @@ -363,7 +363,7 @@ def test_allocate_and_release_memory_blocks(self, is_hybrid_model): ) assert dynamic_context.kv_block_allocator.total_avail == expected_block_count_avail dynamic_context.kv_block_allocator.release_memory_blocks( - torch.tensor(expected_memory_blocks[-2:], device='cuda') + torch.tensor(expected_memory_blocks[-2:], device='cpu') ) assert dynamic_context.kv_block_allocator.total_avail == expected_block_count_avail + 2 assert ( @@ -400,7 +400,7 @@ def test_add_request(self, is_hybrid_model: bool): dynamic_context.add_request( DynamicInferenceRequest( request_id=0, - prompt_tokens=torch.arange(0, context_length, dtype=torch.long, device='cuda'), + prompt_tokens=torch.arange(0, context_length, dtype=torch.long, device='cpu'), sampling_params=SamplingParams( num_tokens_to_generate=dynamic_context.max_tokens - context_length ), @@ -419,15 +419,15 @@ def test_add_request(self, is_hybrid_model: bool): assert dynamic_context.request_last_kv_block_offset[0].item() == 15 assert torch.all( dynamic_context.token_to_pos_ids[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') + == torch.arange(0, context_length, dtype=torch.long, device='cpu') ) assert torch.all( dynamic_context.token_to_input_ids[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') + == torch.arange(0, context_length, dtype=torch.long, device='cpu') ) assert torch.all( dynamic_context.token_to_position_in_request[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') + == torch.arange(0, context_length, dtype=torch.long, device='cpu') ) # Verify token_to_block_idx and token_to_local_position_within_kv_block based on assigned blocks @@ -448,7 +448,7 @@ def test_add_request(self, is_hybrid_model: bool): ) assert torch.all( dynamic_context.token_to_local_position_within_kv_block[0:context_length] - == torch.arange(0, context_length, dtype=torch.long, device='cuda') + == torch.arange(0, context_length, dtype=torch.long, device='cpu') % dynamic_context.block_size_tokens ) @@ -470,12 +470,12 @@ def test_add_dummy_requests_parallel_populates_state(self): requests = [ DynamicInferenceRequest( request_id=100, - prompt_tokens=torch.arange(0, 3, device='cuda'), + prompt_tokens=torch.arange(0, 3, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=2, termination_id=7), ), DynamicInferenceRequest( request_id=101, - prompt_tokens=torch.arange(3, 9, device='cuda'), + prompt_tokens=torch.arange(3, 9, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=1, termination_id=8), ), ] @@ -492,12 +492,12 @@ def test_add_dummy_requests_parallel_populates_state(self): assert dynamic_context.kv_block_allocator.total_avail == block_avail_before expected_tokens = torch.cat( - [torch.arange(0, 3, device='cuda'), torch.arange(3, 9, device='cuda')] + [torch.arange(0, 3, device='cpu'), torch.arange(3, 9, device='cpu')] ) assert torch.equal(dynamic_context.token_to_input_ids[:total_tokens], expected_tokens) expected_positions = torch.tensor( - [0, 1, 2, 0, 1, 2, 3, 4, 5], device='cuda', dtype=torch.long + [0, 1, 2, 0, 1, 2, 3, 4, 5], device='cpu', dtype=torch.long ) assert torch.equal( dynamic_context.token_to_position_in_request[:total_tokens], expected_positions @@ -505,7 +505,7 @@ def test_add_dummy_requests_parallel_populates_state(self): assert torch.equal(dynamic_context.token_to_pos_ids[:total_tokens], 
expected_positions) expected_request_indices = torch.tensor( - [0, 0, 0, 1, 1, 1, 1, 1, 1], device='cuda', dtype=torch.long + [0, 0, 0, 1, 1, 1, 1, 1, 1], device='cpu', dtype=torch.long ) assert torch.equal( dynamic_context.token_to_request_idx[:total_tokens], expected_request_indices @@ -521,15 +521,15 @@ def test_add_dummy_requests_parallel_populates_state(self): assert torch.equal( dynamic_context.request_query_lengths[: len(requests)], - torch.tensor(lengths, device='cuda', dtype=torch.int32), + torch.tensor(lengths, device='cpu', dtype=torch.int32), ) assert torch.equal( dynamic_context.request_output_lengths[: len(requests)], - torch.tensor([5, 7], device='cuda', dtype=torch.int32), + torch.tensor([5, 7], device='cpu', dtype=torch.int32), ) assert torch.equal( dynamic_context.request_kv_block_counts[: len(requests)], - torch.tensor([1, 2], device='cuda', dtype=torch.int32), + torch.tensor([1, 2], device='cpu', dtype=torch.int32), ) assert torch.all( dynamic_context.request_to_kv_block_ids[0, :1] == dummy_block_idx @@ -542,12 +542,12 @@ def test_add_dummy_requests_parallel_populates_state(self): assert torch.all(dynamic_context.request_last_kv_block_id[:2] == dummy_block_idx) assert torch.equal( dynamic_context.request_last_kv_block_offset[:2], - torch.tensor([2, 1], device='cuda', dtype=torch.int32), + torch.tensor([2, 1], device='cpu', dtype=torch.int32), ) assert torch.equal( dynamic_context.request_metadata["termination_id"][:2], - torch.tensor([7.0, 8.0], device='cuda'), + torch.tensor([7.0, 8.0], device='cpu'), ) @pytest.mark.internal @@ -569,7 +569,7 @@ def test_add_dummy_requests_parallel_hybrid_allocates_mamba(self): request = DynamicInferenceRequest( request_id=55, - prompt_tokens=torch.arange(0, 5, device='cuda'), + prompt_tokens=torch.arange(0, 5, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=9), ) @@ -577,6 +577,10 @@ def test_add_dummy_requests_parallel_hybrid_allocates_mamba(self): mamba_idx = dynamic_context.mamba_metadata.request_to_mamba_state_idx[0].item() assert mamba_idx >= 0 + + # Mamba state zeroing is deferred until transfer_bookkeeping_to_gpu(). + dynamic_context.initialize_attention_state() + dynamic_context.transfer_bookkeeping_to_gpu() assert torch.all(dynamic_context.mamba_conv_states[:, mamba_idx] == 0) assert torch.all(dynamic_context.mamba_ssm_states[:, mamba_idx] == 0) @@ -597,7 +601,7 @@ def test_add_dummy_requests_parallel_decode_does_not_count_as_prefill(self): request = DynamicInferenceRequest( request_id=5, - prompt_tokens=torch.arange(0, 1, device='cuda'), + prompt_tokens=torch.arange(0, 1, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=1, termination_id=2), ) @@ -663,10 +667,10 @@ def test_update_request(self, is_hybrid_model: bool): is_hybrid_model=is_hybrid_model, ) - active_requests_mask = torch.Tensor([1, 0, 1, 1, 1, 0, 0, 1]).cuda().int() - next_tokens = torch.arange(2, 10, device='cuda').int() + active_requests_mask = torch.Tensor([1, 0, 1, 1, 1, 0, 0, 1]).int() + next_tokens = torch.arange(2, 10, device='cpu').int() dynamic_context.paused_request_count = 2 - dynamic_context.paused_tokens = torch.Tensor([0, 1]).cuda().int() + dynamic_context.paused_tokens = torch.Tensor([0, 1]).int() dynamic_context.total_request_count = 5 # Total req count should be equal to paused + num elements in active request mask. 
@@ -723,7 +727,7 @@ def test_update_request(self, is_hybrid_model: bool): # Then set up the test data dynamic_context.request_ids[0:10] = torch.tensor( - [0, 1, 5, 6, 4, 2, 9, 7, 8, 9], device=torch.cuda.current_device() + [0, 1, 5, 6, 4, 2, 9, 7, 8, 9], device='cpu' ) # Now verify the values @@ -850,12 +854,12 @@ def test_release_memory_blocks_for_finished_requests(self, is_hybrid_model): # Create an active_requests_mask where requests 0, 2, and 4 are finished (0), # and requests 1 and 3 are still active (1) - active_requests_mask = torch.tensor([0, 1, 0, 1, 0], device=torch.cuda.current_device()) + active_requests_mask = torch.tensor([0, 1, 0, 1, 0], device='cpu') # Call update_requests with these parameters dynamic_context.update_requests( active_requests_mask=active_requests_mask, - new_tokens=torch.tensor([10, 11, 12, 13, 14], device=torch.cuda.current_device()), + new_tokens=torch.tensor([10, 11, 12, 13, 14], device='cpu'), ) # After the update, we should have released 3 blocks (for requests 0, 2, and 4) @@ -939,12 +943,12 @@ def test_finished_requests_with_multiple_blocks(self, is_hybrid_model): dynamic_context.mamba_metadata.mamba_state_free_slot_count -= 1 # Create an active_requests_mask where all requests are finished - active_requests_mask = torch.tensor([0, 0, 0], device=torch.cuda.current_device()) + active_requests_mask = torch.tensor([0, 0, 0], device='cpu') # Call update_requests with these parameters dynamic_context.update_requests( active_requests_mask=active_requests_mask, - new_tokens=torch.tensor([10, 11, 12], device=torch.cuda.current_device()), + new_tokens=torch.tensor([10, 11, 12], device='cpu'), ) # After the update, we should have released all 6 blocks and have 0 active requests @@ -994,7 +998,7 @@ def test_mamba_states_cache(self, is_hybrid_model: bool): dynamic_context.add_request( DynamicInferenceRequest( request_id=0, - prompt_tokens=torch.arange(0, context_length, dtype=torch.long, device='cuda'), + prompt_tokens=torch.arange(0, context_length, dtype=torch.long, device='cpu'), sampling_params=SamplingParams( num_tokens_to_generate=dynamic_context.max_tokens - 10 ), @@ -1047,17 +1051,17 @@ def test_calculate_and_store_log_probs(self): # Add a few requests to the context request_data = { 1001: { - "tokens": torch.randint(0, 100, (10,), device='cuda'), + "tokens": torch.randint(0, 100, (10,), device='cpu'), "prefill_len": 10, "initial_token_offset": 0, }, 1002: { - "tokens": torch.randint(0, 100, (5,), device='cuda'), + "tokens": torch.randint(0, 100, (5,), device='cpu'), "prefill_len": 5, "initial_token_offset": 10, }, 1003: { - "tokens": torch.randint(0, 100, (7,), device='cuda'), + "tokens": torch.randint(0, 100, (7,), device='cpu'), "prefill_len": 7, "initial_token_offset": 15, }, @@ -1081,7 +1085,12 @@ def test_calculate_and_store_log_probs(self): # Simulate prefill step total_active_tokens = dynamic_context.active_token_count vocab_size = 50000 - # logits will have shape [1, total_active_tokens, vocab_size] + + # Populate gpu_view for calculate_log_probs (which reads from gpu_view). + dynamic_context.initialize_attention_state() + dynamic_context.transfer_bookkeeping_to_gpu() + + # logits and new_tokens must be on GPU (calculate_log_probs uses gpu_view). 
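+        # The bookkeeping tensors above were built on CPU; transfer_bookkeeping_to_gpu() copies
+        # them into the GPU-side view that calculate_log_probs reads alongside these logits.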
prefill_logits = torch.randn( 1, total_active_tokens, vocab_size, device='cuda', dtype=torch.float32 ) @@ -1122,12 +1131,16 @@ def test_calculate_and_store_log_probs(self): # Simulate decode step # All requests are active, so the mask will be all ones for the current active requests - active_requests_mask = torch.ones(dynamic_context.total_request_count, device='cuda').int() + active_requests_mask = torch.ones(dynamic_context.total_request_count, device='cpu').int() dynamic_context.update_requests( active_requests_mask=active_requests_mask, new_tokens=prefill_new_tokens ) + # Populate gpu_view again after update_requests modified bookkeeping state. + dynamic_context.initialize_attention_state() + dynamic_context.transfer_bookkeeping_to_gpu() + # Generate new logits for the decode step. Now each request contributes 1 token. decode_logits = torch.randn( 1, num_active_requests, vocab_size, device='cuda', dtype=torch.float32 @@ -1153,7 +1166,7 @@ def test_calculate_and_store_log_probs(self): # Add a new prefill request to the existing context new_request_id = 1004 - new_request_tokens = torch.randint(0, 100, (12,), device='cuda').long() + new_request_tokens = torch.randint(0, 100, (12,), device='cpu').long() new_request_prefill_len = new_request_tokens.shape[0] initial_token_offset_new_request = dynamic_context.active_token_count dynamic_context.add_request( @@ -1175,6 +1188,7 @@ def test_calculate_and_store_log_probs(self): # This step will involve both prefill (for the new request) and decode (for existing requests). dynamic_context.initialize_attention_state() + dynamic_context.transfer_bookkeeping_to_gpu() total_active_tokens_mixed_step = dynamic_context.active_token_count mixed_step_logits = torch.randn( @@ -1299,7 +1313,7 @@ def test_pipeline_parallel_uneven_layers(self): ), ) - # Collect the total block counts on each rank + # Collect the total block counts on each rank (CUDA needed for NCCL all_gather) local_total_blocks = torch.tensor( [context.kv_block_allocator.total_count], device='cuda', dtype=torch.long ) @@ -1654,14 +1668,14 @@ def test_chunked_prefill_state_preserved_across_decode_completions(self): dynamic_context.add_request( DynamicInferenceRequest( request_id=10, - prompt_tokens=torch.arange(0, 2, device='cuda'), + prompt_tokens=torch.arange(0, 2, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ) dynamic_context.add_request( DynamicInferenceRequest( request_id=11, - prompt_tokens=torch.arange(0, 2, device='cuda'), + prompt_tokens=torch.arange(0, 2, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ) @@ -1669,7 +1683,7 @@ def test_chunked_prefill_state_preserved_across_decode_completions(self): # Add Chunk 1 of the chunked prefill request req_999 = DynamicInferenceRequest( request_id=999, - prompt_tokens=torch.arange(0, 8, device='cuda'), + prompt_tokens=torch.arange(0, 8, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) dynamic_context.add_request(req_999, prefill_chunk_length=4) @@ -1683,8 +1697,8 @@ def test_chunked_prefill_state_preserved_across_decode_completions(self): assert kv_block_before != -1 # Step 1: Forward pass for all 3 requests - active_requests_mask = torch.tensor([1, 1, 1], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([100, 101, 102], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([1, 1, 1], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([100, 101, 102], dtype=torch.int32, device='cpu') 
dynamic_context.update_requests(active_requests_mask, new_tokens) # At this point, req 999 is hidden at index 2. total_request_count is 2 (req 10, 11). @@ -1692,8 +1706,8 @@ def test_chunked_prefill_state_preserved_across_decode_completions(self): assert dynamic_context.request_ids[2].item() == 999 # Step 2: Forward pass where req 10 finishes, req 11 continues. Req 999 is NOT scheduled. - active_requests_mask = torch.tensor([0, 1], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([103, 104], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([0, 1], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([103, 104], dtype=torch.int32, device='cpu') dynamic_context.update_requests(active_requests_mask, new_tokens) # At this point, req 10 is evicted. Req 11 shifts to index 0. total_request_count becomes 1. @@ -1756,14 +1770,14 @@ def test_chunked_prefill_all_active_requests_finish_while_hidden(self): dynamic_context.add_request( DynamicInferenceRequest( request_id=10, - prompt_tokens=torch.arange(0, 2, device='cuda'), + prompt_tokens=torch.arange(0, 2, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ) dynamic_context.add_request( DynamicInferenceRequest( request_id=11, - prompt_tokens=torch.arange(0, 2, device='cuda'), + prompt_tokens=torch.arange(0, 2, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ) @@ -1771,7 +1785,7 @@ def test_chunked_prefill_all_active_requests_finish_while_hidden(self): # Add Chunk 1 of a chunked prefill request req_999 = DynamicInferenceRequest( request_id=999, - prompt_tokens=torch.arange(0, 8, device='cuda'), + prompt_tokens=torch.arange(0, 8, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) dynamic_context.add_request(req_999, prefill_chunk_length=4) @@ -1781,8 +1795,8 @@ def test_chunked_prefill_all_active_requests_finish_while_hidden(self): assert kv_block_before != -1 # Step 1: All 3 requests are active, process forward pass - active_requests_mask = torch.tensor([1, 1, 1], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([100, 101, 102], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([1, 1, 1], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([100, 101, 102], dtype=torch.int32, device='cpu') dynamic_context.update_requests(active_requests_mask, new_tokens) # Chunked prefill is now hidden at position 2, total_request_count = 2 @@ -1791,8 +1805,8 @@ def test_chunked_prefill_all_active_requests_finish_while_hidden(self): # Step 2: Both decode requests finish, chunked prefill NOT scheduled this step. # This must NOT crash even though active_request_count becomes 0. 
- active_requests_mask = torch.tensor([0, 0], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([103, 104], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([0, 0], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([103, 104], dtype=torch.int32, device='cpu') dynamic_context.update_requests(active_requests_mask, new_tokens) # total_request_count should be 0 (both finished, chunked prefill hidden) @@ -1839,10 +1853,10 @@ def test_update_requests_speculative(self): ctx.request_to_kv_block_ids[:2, 0] = torch.tensor([0, 1]) ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) - active_requests_mask = torch.tensor([1, 1], device='cuda') - new_tokens = torch.tensor([99, 100], device='cuda') # Sampled tokens + active_requests_mask = torch.tensor([1, 1], device='cpu') + new_tokens = torch.tensor([99, 100], device='cpu') # Sampled tokens new_speculative_tokens = torch.tensor( - [[991, 1001], [992, 1002]], device='cuda' + [[991, 1001], [992, 1002]], device='cpu' ) # Spec tokens ctx.update_requests( @@ -1854,15 +1868,14 @@ def test_update_requests_speculative(self): # Each request generates 1 (sampled) + 2 (speculative) = 3 tokens. assert ctx.active_token_count == 6 assert torch.equal( - ctx.request_query_lengths[:2], torch.tensor([3, 3], dtype=torch.int32, device='cuda') + ctx.request_query_lengths[:2], torch.tensor([3, 3], dtype=torch.int32, device='cpu') ) assert torch.equal( - ctx.request_kv_length_offsets[:2], - torch.tensor([6, 9], dtype=torch.int32, device='cuda'), + ctx.request_kv_length_offsets[:2], torch.tensor([6, 9], dtype=torch.int32, device='cpu') ) # Check interleaving: [sampled_1, spec1_1, spec2_1, sampled_2, spec1_2, spec2_2] - expected_tokens = torch.tensor([99, 991, 992, 100, 1001, 1002], device='cuda') + expected_tokens = torch.tensor([99, 991, 992, 100, 1001, 1002], device='cpu') assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) @pytest.mark.internal @@ -1903,9 +1916,9 @@ def test_speculative_boundary_crossing(self): ctx.request_to_kv_block_ids[0, 0] = first_block ctx.request_last_kv_block_id[0] = first_block - active_requests_mask = torch.tensor([1], device='cuda') - new_tokens = torch.tensor([50], device='cuda') - new_speculative_tokens = torch.tensor([[51], [52]], device='cuda') + active_requests_mask = torch.tensor([1], device='cpu') + new_tokens = torch.tensor([50], device='cpu') + new_speculative_tokens = torch.tensor([[51], [52]], device='cpu') # Run update_requests natively. It will automatically: # 1. Detect the boundary crossing and pause the request. 
@@ -1929,7 +1942,7 @@ def test_speculative_boundary_crossing(self): # Token 1 (offset 3) -> first_block # Token 2 (offset 4) -> second_block expected_blocks = torch.tensor( - [first_block, first_block, second_block], dtype=torch.int, device='cuda' + [first_block, first_block, second_block], dtype=torch.int, device='cpu' ) assert torch.equal(ctx.token_to_block_idx[:3], expected_blocks) @@ -1979,10 +1992,10 @@ def test_paused_speculative_tokens_tracking(self): ctx.kv_block_allocator.total_avail = 0 ctx.kv_block_allocator.paused_count = 100 # Ensure it doesn't get completely evicted either - active_requests_mask = torch.tensor([1, 1], device='cuda') - new_tokens = torch.tensor([99, 100], device='cuda') # Sampled + active_requests_mask = torch.tensor([1, 1], device='cpu') + new_tokens = torch.tensor([99, 100], device='cpu') # Sampled new_speculative_tokens = torch.tensor( - [[991, 1001], [992, 1002]], device='cuda' + [[991, 1001], [992, 1002]], device='cpu' ) # Speculative # In update_requests, request 0 will be paused to allocate a new block. @@ -2004,7 +2017,7 @@ def test_paused_speculative_tokens_tracking(self): assert ctx.paused_tokens[0].item() == 99 assert torch.equal( - ctx.paused_speculative_tokens[:, 0], torch.tensor([991, 992], device='cuda') + ctx.paused_speculative_tokens[:, 0], torch.tensor([991, 992], device='cpu') ) @pytest.mark.internal @@ -2043,8 +2056,8 @@ def test_swap_book_keeping_tensors_with_speculative_tokens(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) ctx.request_ids[:2] = torch.tensor([10, 11]) - next_tokens = torch.tensor([99, 100], device='cuda') - new_speculative_tokens = torch.tensor([[991, 1001], [992, 1002]], device='cuda') + next_tokens = torch.tensor([99, 100], device='cpu') + new_speculative_tokens = torch.tensor([[991, 1001], [992, 1002]], device='cpu') ctx._swap_book_keeping_tensors( src_idxs=torch.tensor([0]), @@ -2053,10 +2066,10 @@ def test_swap_book_keeping_tensors_with_speculative_tokens(self): new_speculative_tokens=new_speculative_tokens, ) - assert torch.equal(ctx.request_ids[:2], torch.tensor([11, 10], device='cuda')) - assert torch.equal(next_tokens[:2], torch.tensor([100, 99], device='cuda')) + assert torch.equal(ctx.request_ids[:2], torch.tensor([11, 10], device='cpu')) + assert torch.equal(next_tokens[:2], torch.tensor([100, 99], device='cpu')) assert torch.equal( - new_speculative_tokens[:, :2], torch.tensor([[1001, 991], [1002, 992]], device='cuda') + new_speculative_tokens[:, :2], torch.tensor([[1001, 991], [1002, 992]], device='cpu') ) @pytest.mark.internal @@ -2087,9 +2100,9 @@ def test_update_requests_with_finished_requests_and_speculative_tokens(self): ctx.request_last_kv_block_id[:3] = torch.tensor([0, 1, 2]) ctx.request_kv_block_counts[:3] = 1 - active_requests_mask = torch.tensor([1, 0, 1], device='cuda') - new_tokens = torch.tensor([99, 100, 101], device='cuda') - new_speculative_tokens = torch.tensor([[991, 1001, 1011], [992, 1002, 1012]], device='cuda') + active_requests_mask = torch.tensor([1, 0, 1], device='cpu') + new_tokens = torch.tensor([99, 100, 101], device='cpu') + new_speculative_tokens = torch.tensor([[991, 1001, 1011], [992, 1002, 1012]], device='cpu') ctx.update_requests( active_requests_mask=active_requests_mask, @@ -2100,13 +2113,13 @@ def test_update_requests_with_finished_requests_and_speculative_tokens(self): # req1 is finished. req2 moves to req1's position. 
assert ctx.total_request_count == 2 assert torch.equal( - ctx.request_ids[:2], torch.tensor([10, 12], device='cuda', dtype=torch.int32) + ctx.request_ids[:2], torch.tensor([10, 12], device='cpu', dtype=torch.int32) ) # Check interleaving for req0 and req2 # req0: [99, 991, 992] # req2: [101, 1011, 1012] - expected_tokens = torch.tensor([99, 991, 992, 101, 1011, 1012], device='cuda') + expected_tokens = torch.tensor([99, 991, 992, 101, 1011, 1012], device='cpu') assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) @pytest.mark.internal @@ -2137,7 +2150,7 @@ def test_chunked_prefill_hidden_state_prevents_token_bloat(self): # 1. Add a standard decode request req_decode = DynamicInferenceRequest( request_id=10, - prompt_tokens=torch.arange(0, 10, device='cuda'), + prompt_tokens=torch.arange(0, 10, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ctx.add_request(req_decode) @@ -2145,7 +2158,7 @@ def test_chunked_prefill_hidden_state_prevents_token_bloat(self): # 2. Add chunk 1 of a chunked prefill request req_chunked = DynamicInferenceRequest( request_id=42, - prompt_tokens=torch.arange(0, 100, device='cuda'), + prompt_tokens=torch.arange(0, 100, device='cpu'), sampling_params=SamplingParams(num_tokens_to_generate=10), ) ctx.chunked_prefill_request_id = 42 @@ -2155,10 +2168,10 @@ def test_chunked_prefill_hidden_state_prevents_token_bloat(self): assert ctx.active_token_count == 60 # 3. Call update_requests - active_requests_mask = torch.tensor([1, 1], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([99, 199], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([1, 1], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([99, 199], dtype=torch.int32, device='cpu') new_spec = torch.tensor( - [[100, 200], [101, 201], [102, 202]], dtype=torch.int32, device='cuda' + [[100, 200], [101, 201], [102, 202]], dtype=torch.int32, device='cpu' ) ctx.update_requests( @@ -2223,13 +2236,13 @@ def test_chunked_prefill_swap_with_speculative_tokens(self): ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) ctx.request_kv_block_counts[:2] = 1 - active_requests_mask = torch.tensor([1, 1], device='cuda') + active_requests_mask = torch.tensor([1, 1], device='cpu') # New base tokens: [100 (for prefill), 200 (for decode)] - new_tokens = torch.tensor([100, 200], device='cuda') + new_tokens = torch.tensor([100, 200], device='cpu') # New spec tokens: Col 0 for prefill (dummy), Col 1 for decode (real draft tokens) - new_speculative_tokens = torch.tensor([[101, 201], [102, 202]], device='cuda') + new_speculative_tokens = torch.tensor([[101, 201], [102, 202]], device='cpu') # Trigger update_requests. # It must detect ID 42 is at index 0, and swap it with index 1. @@ -2241,7 +2254,7 @@ def test_chunked_prefill_swap_with_speculative_tokens(self): # 1. Verify the IDs were swapped successfully assert torch.equal( - ctx.request_ids[:2], torch.tensor([99, 42], dtype=torch.int32, device='cuda') + ctx.request_ids[:2], torch.tensor([99, 42], dtype=torch.int32, device='cpu') ) # 2. Verify the Decode request (now at Index 0) correctly flattened its @@ -2249,7 +2262,7 @@ def test_chunked_prefill_swap_with_speculative_tokens(self): # 3. Verify the Prefill request (now at Index 1) is hidden and does NOT # flatten its dummy tokens. 
expected_flattened_tokens = torch.tensor( - [200, 201, 202], device='cuda' # Decode request (ID 99) + [200, 201, 202], device='cpu' # Decode request (ID 99) ) assert ctx.active_token_count == 3 @@ -2259,7 +2272,7 @@ def test_chunked_prefill_swap_with_speculative_tokens(self): # 4. Verify that the new_speculative_tokens tensor itself was swapped so that # the hidden state perfectly preserves the alignment for subsequent steps. - expected_swapped_spec_tokens = torch.tensor([[201, 101], [202, 102]], device='cuda') + expected_swapped_spec_tokens = torch.tensor([[201, 101], [202, 102]], device='cpu') assert torch.equal( new_speculative_tokens, expected_swapped_spec_tokens ), "new_speculative_tokens was not swapped in-place alongside the request metadata!" @@ -2287,7 +2300,7 @@ def test_speculative_with_prefix_caching_shared_blocks(self): # This avoids the single-token-chunk clamp (effective_prefill >= 2) and # verifies that the prefix skip actually works. tail = 5 - prompt = torch.arange(bs * 3 + tail, device='cuda') + prompt = torch.arange(bs * 3 + tail, device='cpu') # First request registers blocks. req1 = DynamicInferenceRequest( @@ -2349,7 +2362,7 @@ def test_speculative_with_prefix_caching_kv_offset(self): # Use bs * 2 + 5 tokens so the prompt extends past the last full block, # avoiding the single-token-chunk clamp while still testing the skip. tail = 5 - prompt = torch.arange(bs * 2 + tail, device='cuda') + prompt = torch.arange(bs * 2 + tail, device='cpu') # First request. req1 = DynamicInferenceRequest( @@ -2397,7 +2410,7 @@ def test_speculative_update_then_release_with_prefix_caching(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) bs = ctx.block_size_tokens - prompt = torch.arange(bs * 2, device='cuda') + prompt = torch.arange(bs * 2, device='cpu') # Two requests sharing the same prefix. req1 = DynamicInferenceRequest( @@ -2454,7 +2467,7 @@ def test_speculative_boundary_crossing_with_prefix_caching(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) bs = ctx.block_size_tokens - prompt = torch.arange(bs * 2, device='cuda') + prompt = torch.arange(bs * 2, device='cpu') # Request 1: adds prefix blocks. req1 = DynamicInferenceRequest( @@ -2491,9 +2504,9 @@ def test_speculative_boundary_crossing_with_prefix_caching(self): ctx.request_in_prefill_status_tensor[0] = 0 ctx.active_token_count = 2 - active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) - new_tokens = torch.tensor([50, 50], device='cuda') - new_spec = torch.tensor([[51, 51], [52, 52]], device='cuda') + active_mask = torch.tensor([1, 1], device='cpu', dtype=torch.int32) + new_tokens = torch.tensor([50, 50], device='cpu') + new_spec = torch.tensor([[51, 51], [52, 52]], device='cpu') ctx.update_requests( active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec @@ -2535,7 +2548,7 @@ def test_chunked_prefill_prefix_caching_from_hidden_state(self): bs = ctx.block_size_tokens # First request: register prefix blocks (bs * 3 tokens = 3 complete blocks). 
- first_prompt = torch.arange(bs * 3, device='cuda') + first_prompt = torch.arange(bs * 3, device='cpu') req_first = DynamicInferenceRequest( request_id=1, prompt_tokens=first_prompt.clone(), @@ -2560,9 +2573,9 @@ def test_chunked_prefill_prefix_caching_from_hidden_state(self): ctx.add_request(req2, prefill_chunk_length=bs) # Call update_requests to move req2 to the hidden state - active_requests_mask = torch.tensor([1, 1], dtype=torch.int32, device='cuda') - new_tokens = torch.tensor([99, 199], dtype=torch.int32, device='cuda') - new_spec = torch.tensor([[100, 200], [101, 201]], dtype=torch.int32, device='cuda') + active_requests_mask = torch.tensor([1, 1], dtype=torch.int32, device='cpu') + new_tokens = torch.tensor([99, 199], dtype=torch.int32, device='cpu') + new_spec = torch.tensor([[100, 200], [101, 201]], dtype=torch.int32, device='cpu') ctx.update_requests(active_requests_mask, new_tokens, new_speculative_tokens=new_spec) # Capture active tokens before chunk 2 (which should just be the 3 tokens of req_first) @@ -2604,7 +2617,7 @@ def test_prefix_caching_check_availability_with_speculative(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) bs = ctx.block_size_tokens - prompt = torch.arange(bs * 2, device='cuda') + prompt = torch.arange(bs * 2, device='cpu') # First request registers blocks. req1 = DynamicInferenceRequest( @@ -2654,7 +2667,7 @@ def test_prefix_match_exact_block_boundary(self): bs = ctx.block_size_tokens # req1: 32 tokens (exactly 2 complete blocks) - prompt1 = torch.arange(bs * 2, device='cuda') + prompt1 = torch.arange(bs * 2, device='cpu') req1 = DynamicInferenceRequest( request_id=1, prompt_tokens=prompt1, @@ -2665,7 +2678,7 @@ def test_prefix_match_exact_block_boundary(self): ctx.add_request(req1) # req2: 35 tokens (first 32 tokens match req1) - prompt2 = torch.arange(bs * 2 + 3, device='cuda') + prompt2 = torch.arange(bs * 2 + 3, device='cpu') req2 = DynamicInferenceRequest( request_id=2, prompt_tokens=prompt2, @@ -2712,7 +2725,7 @@ def test_eviction_with_shared_prefix_blocks(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) bs = ctx.block_size_tokens - prompt = torch.arange(bs * 2, device='cuda') + prompt = torch.arange(bs * 2, device='cpu') # Add req1 and req2 with identical prompts req1 = DynamicInferenceRequest( @@ -2752,7 +2765,7 @@ def test_eviction_with_shared_prefix_blocks(self): # Trigger the eviction logic # next_tokens must be sized to total_request_count (1 paused + 1 active = 2) - next_tokens = torch.tensor([50, 51], device='cuda') + next_tokens = torch.tensor([50, 51], device='cpu') evicted_ids = ctx.evict_overflow_paused_requests( active_request_count=1, next_tokens=next_tokens ) @@ -2792,17 +2805,17 @@ def test_oom_during_speculative_boundary_crossing(self): ctx.paused_request_count = 0 ctx.active_token_count = 2 - ctx.request_ids[:2] = torch.tensor([10, 11], device='cuda') + ctx.request_ids[:2] = torch.tensor([10, 11], device='cpu') ctx.request_query_lengths[:2] = 1 ctx.request_kv_block_counts[:2] = 1 # Request 0 offset is 15. Adding 1 sampled + 2 spec = 3 tokens crosses the boundary (16). # Request 1 offset is 5. Adding 3 tokens = 8 (does not cross). 
ctx.request_kv_length_offsets[:2] = torch.tensor( - [bs - 1, 5], device='cuda', dtype=torch.int32 + [bs - 1, 5], device='cpu', dtype=torch.int32 ) ctx.request_last_kv_block_offset[:2] = torch.tensor( - [bs - 1, 5], device='cuda', dtype=torch.int32 + [bs - 1, 5], device='cpu', dtype=torch.int32 ) blocks = ctx.kv_block_allocator.allocate_memory_blocks(2) @@ -2814,9 +2827,9 @@ def test_oom_during_speculative_boundary_crossing(self): ctx.kv_block_allocator.total_avail = 0 ctx.kv_block_allocator.paused_count = 100 # Prevent immediate eviction out of the system - active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) - new_tokens = torch.tensor([99, 88], device='cuda') - new_spec = torch.tensor([[100, 200], [101, 201]], device='cuda') + active_mask = torch.tensor([1, 1], device='cpu', dtype=torch.int32) + new_tokens = torch.tensor([99, 88], device='cpu') + new_spec = torch.tensor([[100, 200], [101, 201]], device='cpu') # Run update requests ctx.update_requests( @@ -2880,9 +2893,9 @@ def test_speculative_boundary_crossing_at_max_kv_block_count(self): ctx.request_to_kv_block_ids[0, 1] = blocks[1] ctx.request_last_kv_block_id[0] = blocks[1] - active_requests_mask = torch.tensor([1], device='cuda') - new_tokens = torch.tensor([50], device='cuda') - new_speculative_tokens = torch.tensor([[51], [52]], device='cuda') + active_requests_mask = torch.tensor([1], device='cpu') + new_tokens = torch.tensor([50], device='cpu') + new_speculative_tokens = torch.tensor([[51], [52]], device='cpu') # This will pause the request (offset 13 >= 13), then resume it by # allocating a 3rd block at col_idx=2. Without the fix, this raises @@ -2921,7 +2934,7 @@ def test_chunked_prefill_meets_prefix_caching(self): ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) bs = ctx.block_size_tokens - prompt = torch.arange(128, device='cuda') + prompt = torch.arange(128, device='cpu') # Cache req1 (fully processed) req1 = DynamicInferenceRequest( @@ -2970,3 +2983,243 @@ def test_chunked_prefill_meets_prefix_caching(self): # Verify block references updated appropriately assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[2]].item() == 2 assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[3]].item() == 2 + + # ------------------------------------------------------------------ # + # Tests for active_logit_idxs / last_token_logits / pad_active_slices + # ------------------------------------------------------------------ # + + def _build_speculative_ctx(self, num_speculative_tokens=2, block_size=256): + """Build a context configured for speculative decoding.""" + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.05, + block_size_tokens=block_size, + num_speculative_tokens=num_speculative_tokens, + unified_memory_level=0, + ) + return DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + def _add_and_step_decode_requests(self, ctx, num_requests, prompt_length=10): + """Add prefill requests, then step them into decode state with speculative tokens. + + Returns the context in a state with ``num_requests`` decode requests whose + query_lengths equal ``num_speculative_tokens + 1``. 
+ """ + for i in range(num_requests): + req = DynamicInferenceRequest( + request_id=i, + prompt_tokens=torch.arange(0, prompt_length, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=100), + ) + ctx.add_request(req) + + ctx.initialize_attention_state() + + active_mask = torch.ones(num_requests, device='cuda', dtype=torch.int32) + new_tokens = torch.arange(num_requests, device='cuda') + num_spec = ctx.num_speculative_tokens + new_spec = torch.arange(num_spec * num_requests, device='cuda').reshape( + num_spec, num_requests + ) + ctx.update_requests( + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec + ) + return ctx + + @pytest.mark.internal + @rounder_override(64) + def test_pad_active_slices_speculative_decode_only(self): + """Verify active_logit_idxs for a decode-only batch with speculative tokens.""" + num_decode = 3 + num_spec = 2 + ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec) + self._add_and_step_decode_requests(ctx, num_decode) + + assert ctx.num_prefill_requests == 0 + assert ctx.num_decode_requests == num_decode + tokens_per_decode = num_spec + 1 + + ctx.initialize_attention_state() + + decode_token_count = num_decode * tokens_per_decode + expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda') + actual = ctx.active_logit_idxs[:decode_token_count] + assert torch.equal( + actual, expected_decode + ), f"decode indices mismatch: {actual.tolist()} vs {expected_decode.tolist()}" + + assert ctx.num_last_token_logits == decode_token_count + assert ctx.active_logit_idxs[decode_token_count:].sum().item() == 0 + + @pytest.mark.internal + @rounder_override(64) + def test_pad_active_slices_speculative_mixed_batch(self): + """Verify active_logit_idxs for a mixed decode+prefill batch with speculative tokens.""" + num_decode = 2 + num_spec = 2 + ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec) + self._add_and_step_decode_requests(ctx, num_decode) + + prefill_lengths = [15, 20] + for i, pl in enumerate(prefill_lengths): + req = DynamicInferenceRequest( + request_id=100 + i, + prompt_tokens=torch.arange(0, pl, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req) + + assert ctx.num_decode_requests == num_decode + assert ctx.num_prefill_requests == len(prefill_lengths) + tokens_per_decode = num_spec + 1 + + ctx.initialize_attention_state() + + decode_token_count = num_decode * tokens_per_decode + expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda') + actual_decode = ctx.active_logit_idxs[:decode_token_count] + assert torch.equal(actual_decode, expected_decode) + + cumulative = 0 + for i, pl in enumerate(prefill_lengths): + cumulative += pl + expected_prefill_idx = decode_token_count + cumulative - 1 + actual_prefill_idx = ctx.active_logit_idxs[decode_token_count + i].item() + assert ( + actual_prefill_idx == expected_prefill_idx + ), f"prefill request {i}: expected idx {expected_prefill_idx}, got {actual_prefill_idx}" + + expected_num_logits = decode_token_count + len(prefill_lengths) + assert ctx.num_last_token_logits == expected_num_logits + + @pytest.mark.internal + @rounder_override(64) + def test_pad_active_slices_speculative_all_prefill(self): + """Verify active_logit_idxs with only prefill requests (no decode) and speculative tokens.""" + num_spec = 2 + ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec) + + prefill_lengths = [12, 8, 25] + for i, pl in 
enumerate(prefill_lengths): + req = DynamicInferenceRequest( + request_id=i, + prompt_tokens=torch.arange(0, pl, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req) + + assert ctx.num_decode_requests == 0 + assert ctx.num_prefill_requests == len(prefill_lengths) + + ctx.initialize_attention_state() + + cumulative = 0 + for i, pl in enumerate(prefill_lengths): + cumulative += pl + expected_idx = cumulative - 1 + actual_idx = ctx.active_logit_idxs[i].item() + assert ( + actual_idx == expected_idx + ), f"prefill request {i}: expected idx {expected_idx}, got {actual_idx}" + + expected_num_logits = len(prefill_lengths) + assert ctx.num_last_token_logits == expected_num_logits + + @pytest.mark.internal + @rounder_override(64) + def test_pad_active_slices_no_speculative_tokens(self): + """Verify active_logit_idxs without speculative tokens matches cumsum - 1.""" + ctx = self._build_speculative_ctx(num_speculative_tokens=0) + + req0 = DynamicInferenceRequest( + request_id=0, + prompt_tokens=torch.arange(0, 10, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req0) + ctx.initialize_attention_state() + active_mask = torch.ones(1, device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([42], device='cuda') + ctx.update_requests(active_requests_mask=active_mask, new_tokens=new_tokens) + + prefill_lengths = [20, 30] + for i, pl in enumerate(prefill_lengths): + req = DynamicInferenceRequest( + request_id=10 + i, + prompt_tokens=torch.arange(0, pl, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req) + + assert ctx.num_decode_requests == 1 + assert ctx.num_prefill_requests == 2 + + ctx.initialize_attention_state() + + all_query_lengths = ctx.request_query_lengths[ + ctx.paused_request_count : ctx.total_request_count + ] + expected_idxs = torch.cumsum(all_query_lengths, dim=0) - 1 + num_logits = ctx.num_last_token_logits + actual_idxs = ctx.active_logit_idxs[:num_logits] + assert torch.equal( + actual_idxs, expected_idxs.to(device=actual_idxs.device, dtype=torch.int32) + ), f"non-speculative mismatch: {actual_idxs.tolist()} vs {expected_idxs.tolist()}" + + @pytest.mark.internal + @rounder_override(64) + def test_last_token_logits_selects_correct_values_speculative(self): + """Verify last_token_logits returns logits at the correct token positions.""" + num_decode = 2 + num_spec = 2 + ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec) + self._add_and_step_decode_requests(ctx, num_decode) + + prefill_lengths = [10, 15] + for i, pl in enumerate(prefill_lengths): + req = DynamicInferenceRequest( + request_id=100 + i, + prompt_tokens=torch.arange(0, pl, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req) + + ctx.initialize_attention_state() + + vocab_size = 32 + logits = torch.arange( + ctx.padded_active_token_count * vocab_size, dtype=torch.float32, device='cuda' + ).reshape(1, ctx.padded_active_token_count, vocab_size) + + result = ctx.last_token_logits(logits) + expected_num_logits = ctx.num_last_token_logits + assert result.shape == (expected_num_logits, vocab_size) + + idxs = ctx.active_logit_idxs[:expected_num_logits].long() + expected = logits.squeeze(0)[idxs, :] + assert torch.equal(result, expected) + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_required_logit_indices_matches_active_logit_idxs(self): + """speculative_required_logit_indices returns a 
slice of active_logit_idxs.""" + num_decode = 2 + num_spec = 2 + ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec) + self._add_and_step_decode_requests(ctx, num_decode) + + req = DynamicInferenceRequest( + request_id=100, + prompt_tokens=torch.arange(0, 20, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=50), + ) + ctx.add_request(req) + ctx.initialize_attention_state() + + indices = ctx.speculative_required_logit_indices() + expected_len = ctx.num_last_token_logits + assert indices.numel() == expected_len + assert indices.data_ptr() == ctx.active_logit_idxs.data_ptr() diff --git a/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py b/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py index 049c0e5e040..84898db60d8 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py @@ -3,6 +3,7 @@ import asyncio from collections import deque +import numpy as np import pytest import torch @@ -10,7 +11,6 @@ from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine from megatron.core.inference.inference_request import ( - HASH_PRIME, DynamicInferenceRequest, DynamicInferenceRequestRecord, Status, @@ -141,7 +141,7 @@ def test_hash_computation(self): tokens = self._prompt(32) h1 = compute_block_hashes_batched(tokens, 32) h2 = compute_block_hashes_batched(tokens, 32) - assert h1 == h2 and len(h1) == 1 and 1 <= h1[0] <= HASH_PRIME + assert h1 == h2 and len(h1) == 1 and h1[0] >= 1 assert compute_block_hashes_batched(self._prompt(32, offset=1), 32)[0] != h1[0] # parent chaining: 4 blocks of all-zero tokens produce distinct hashes @@ -160,6 +160,47 @@ def test_hash_computation(self): ) assert len(long_h) == 120 and all(v > 0 for v in long_h) + @pytest.mark.internal + def test_hash_collision_resistance(self): + """Regression tests: old polynomial collision attacks must fail with SHA-256.""" + bs = 32 + + # V2 regression: algebraic attack (token[j] += 31, token[j+1] -= 1) + # This was a zero-delta exploit against the old polynomial hash. 
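To make the "zero-delta" wording concrete, here is a toy reproduction of the collision against a base-31 positional hash (a hypothetical stand-in, not the removed implementation; the modulus is a placeholder):

    MOD = (1 << 61) - 1

    def toy_poly_hash(tokens):
        # Position i weighted by 31**i, which is the structure the V2 attack exploits.
        return sum(int(t) * pow(31, i, MOD) for i, t in enumerate(tokens)) % MOD

    a = [7, 8, 9, 10]
    b = [7 + 31, 8 - 1, 9, 10]                    # token[j] += 31, token[j+1] -= 1 at j = 0
    assert toy_poly_hash(a) == toy_poly_hash(b)   # net delta: 31 * 31**0 - 1 * 31**1 == 0

A content-addressed hash such as SHA-256 over the block's bytes has no such linear structure, which is why these regression checks now expect the hashes to differ.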
+ tokens = self._prompt(bs) + collision = tokens.clone() + collision[0] += 31 + collision[1] -= 1 + h_orig = compute_block_hashes_batched(tokens, bs) + h_coll = compute_block_hashes_batched(collision, bs) + assert h_orig != h_coll, "V2 algebraic collision: token[j]+=31, token[j+1]-=1" + + # V2 at different positions within the block + for j in range(bs - 1): + c = tokens.clone() + c[j] += 31 + c[j + 1] -= 1 + assert compute_block_hashes_batched(c, bs) != h_orig, f"V2 at position {j}" + + # V2 across multiple blocks: modify one block, verify all downstream hashes change + tokens_multi = self._prompt(bs * 4) + h_multi = compute_block_hashes_batched(tokens_multi, bs) + modified = tokens_multi.clone() + modified[0] += 31 + modified[1] -= 1 + h_mod = compute_block_hashes_batched(modified, bs) + assert h_mod[0] != h_multi[0], "modified block hash must differ" + # Parent chaining: all subsequent blocks must also differ + for i in range(1, 4): + assert h_mod[i] != h_multi[i], f"parent chain: block {i} must differ" + + # V2 generalized: arbitrary linear combinations (token[j] += k*31, token[j+1] -= k) + for k in [1, 2, 5, 100]: + c = tokens.clone() + c[0] += k * 31 + c[1] -= k + assert compute_block_hashes_batched(c, bs) != h_orig, f"V2 generalized k={k}" + @pytest.mark.internal def test_registration_and_discovery(self): ctx = self._ctx() @@ -762,12 +803,12 @@ def test_mamba_intermediate_offsets(self): overall, ) # Penultimate block offset (block 2 boundary) is a valid intermediate - count = msa._intermediate_counts_gpu[1].item() + count = msa._intermediate_counts_cpu[1].item() if count > 0: - offsets = msa._intermediate_offsets_gpu[1, :count].tolist() + offsets = msa._intermediate_offsets_cpu[1, :count].tolist() for o in offsets: assert o > 0 and o % 128 == 0 - assert msa._eos_cache_block_id_gpu[1].item() >= 0 + assert msa._eos_cache_block_id_cpu[1].item() >= 0 # non-aligned prompt produces last_aligned intermediate offset ctx2 = self._mctx(block_size_tokens=bs) @@ -779,12 +820,12 @@ def test_mamba_intermediate_offsets(self): req2b = self._req(ctx2, p2.clone(), request_id=2) req2b._mamba_num_matched_blocks = 2 ctx2.add_request(req2b) - count2 = msa2._intermediate_counts_gpu[1].item() + count2 = msa2._intermediate_counts_cpu[1].item() if count2 > 0: - offsets = msa2._intermediate_offsets_gpu[1, :count2].tolist() + offsets = msa2._intermediate_offsets_cpu[1, :count2].tolist() for o in offsets: assert o > 0 and o % 128 == 0 - assert msa2._eos_cache_block_id_gpu[1].item() < 0 + assert msa2._eos_cache_block_id_cpu[1].item() < 0 # block-aligned prompts set EOS cache block ID ctx3 = self._mctx(block_size_tokens=bs) @@ -793,7 +834,10 @@ def test_mamba_intermediate_offsets(self): req3 = self._req(ctx3, p3.clone(), request_id=2) req3._mamba_num_matched_blocks = 0 ctx3.add_request(req3) - assert ctx3.mamba_slot_allocator._eos_cache_block_id_gpu[1].item() >= 0 + # Deferred Mamba ops execute during transfer. 
+ ctx3.initialize_attention_state() + ctx3.transfer_bookkeeping_to_gpu() + assert ctx3.mamba_slot_allocator._eos_cache_block_id_cpu[1].item() >= 0 # intermediate output buffers are pre-allocated ctx4 = self._mctx() @@ -901,6 +945,7 @@ def test_mixed_batch(self, model_type): # last_token_logits ctx.initialize_attention_state() + ctx.transfer_bookkeeping_to_gpu() logits = torch.randn( 1, ctx.padded_active_token_count, vocab_size, device=torch.cuda.current_device() ) @@ -1026,9 +1071,9 @@ def test_commit_intermediate_states_batched(self): # Set up intermediate offsets: 1 intermediate at src_offset=0 bid0 = ctx.request_to_kv_block_ids[ctx_idx][0].item() - msa._intermediate_block_ids_gpu[ctx_idx, 0] = bid0 - msa._intermediate_offsets_gpu[ctx_idx, 0] = 128 - msa._intermediate_counts_gpu[ctx_idx] = 1 + msa._intermediate_block_ids_cpu[ctx_idx, 0] = bid0 + msa._intermediate_offsets_cpu[ctx_idx, 0] = 128 + msa._intermediate_counts_cpu[ctx_idx] = 1 msa._has_intermediates = True # Set metadata fields that would normally be set by _update_intermediate_offsets @@ -1037,7 +1082,7 @@ def test_commit_intermediate_states_batched(self): # Set up EOS block (block-aligned prompt) eos_bid = ctx.request_to_kv_block_ids[ctx_idx][2].item() - msa._eos_cache_block_id_gpu[ctx_idx] = eos_bid + msa._eos_cache_block_id_cpu[ctx_idx] = eos_bid # Write known patterns to live mamba state for EOS copy mamba_idx = metadata.request_to_mamba_state_idx[ctx_idx].item() @@ -1084,3 +1129,213 @@ def test_commit_intermediate_states_batched(self): # Verify _has_intermediates cleared assert not msa._has_intermediates + + +class TestPerBlockRouting(PrefixCachingTestBase): + """Tests for per-block routing storage and reconstruction.""" + + @pytest.mark.internal + def test_store_and_get_block_routing(self): + """Verify store_block_routing / get_block_routing round-trip.""" + ctx = self._ctx() + alloc = ctx.kv_block_allocator + bs = ctx.block_size_tokens + num_layers, topk = 4, 2 + + # Allocate a block + block_ids = alloc.allocate_memory_blocks(1) + bid = block_ids[0].item() + + # Store routing for some positions + positions = np.array([0, 1, 2]) + routing = np.random.randint(-100, 100, size=(3, num_layers, topk), dtype=np.int16) + alloc.store_block_routing(bid, positions, routing) + + # Retrieve and verify + stored = alloc.get_block_routing(bid) + assert stored is not None + assert isinstance(stored, np.ndarray) + assert stored.shape == (bs, num_layers, topk) + assert np.allclose(stored[:3], routing) + # Remaining positions should be zero + assert (stored[3:] == 0).all() + + @pytest.mark.internal + def test_routing_cleared_on_allocate(self): + """Routing data is cleared when a block is re-allocated.""" + ctx = self._ctx(enable_prefix_caching=False) + alloc = ctx.kv_block_allocator + + # Allocate, store routing, release, re-allocate + block_ids = alloc.allocate_memory_blocks(1) + bid = block_ids[0].item() + positions = np.array([0]) + routing = np.random.randint(-100, 100, size=(1, 4, 2), dtype=np.int16) + alloc.store_block_routing(bid, positions, routing) + assert alloc.get_block_routing(bid) is not None + + alloc.release_memory_blocks(block_ids) + # After release, routing still present (persists until re-alloc) + assert alloc.get_block_routing(bid) is not None + + # Re-allocate the same block + new_ids = alloc.allocate_memory_blocks(1) + new_bid = new_ids[0].item() + # The re-allocated block should have routing cleared + assert alloc.get_block_routing(new_bid) is None + + @pytest.mark.internal + def test_routing_cleared_on_reset(self): + 
"""Routing data is cleared on allocator reset.""" + ctx = self._ctx() + alloc = ctx.kv_block_allocator + + block_ids = alloc.allocate_memory_blocks(1) + bid = block_ids[0].item() + alloc.store_block_routing( + bid, np.array([0]), np.random.randint(-100, 100, size=(1, 4, 2), dtype=np.int16) + ) + assert alloc.get_block_routing(bid) is not None + + alloc.reset() + assert alloc.get_block_routing(bid) is None + assert len(alloc.block_routing) == 0 + + @pytest.mark.internal + def test_routing_persists_through_deregister(self): + """Routing data persists through block deregister (needed for reconstruction).""" + ctx = self._ctx(prefix_caching_eviction_policy=PrefixCachingEvictionPolicy.REF_ZERO) + alloc = ctx.kv_block_allocator + bs = ctx.block_size_tokens + + # Add a request so blocks get allocated and registered + prompt = self._prompt(bs * 2) + req = self._req(ctx, prompt) + ctx.add_request(req) + b0, b1 = self._block_ids(ctx, 0, 2) + + # Store routing for both blocks + for bid in [b0, b1]: + alloc.store_block_routing( + bid, np.arange(bs), np.random.randint(-100, 100, size=(bs, 4, 2), dtype=np.int16) + ) + + # Release blocks (REF_ZERO deregisters immediately) + blocks = ctx.request_to_kv_block_ids[0] + valid_blocks = blocks[blocks >= 0] + alloc.release_memory_blocks(valid_blocks) + + # Routing data should still be present + assert alloc.get_block_routing(b0) is not None + assert alloc.get_block_routing(b1) is not None + + @pytest.mark.internal + def test_reconstruct_routing_from_blocks(self): + """Test reconstruction of routing indices from per-block storage.""" + ctx = self._ctx() + alloc = ctx.kv_block_allocator + bs = ctx.block_size_tokens + num_layers, topk = 4, 2 + + # Allocate 3 blocks + block_ids = alloc.allocate_memory_blocks(3) + bids = block_ids.tolist() + + # Store routing for all positions in first two blocks (full) + for bid in bids[:2]: + alloc.store_block_routing( + bid, + np.arange(bs), + np.arange(bs * num_layers * topk, dtype=np.int16).reshape(bs, num_layers, topk) + + bid, + ) + + # Store routing for partial last block (e.g., 5 tokens) + partial = 5 + alloc.store_block_routing( + bids[2], + np.arange(partial), + np.arange(partial * num_layers * topk, dtype=np.int16).reshape( + partial, num_layers, topk + ) + + bids[2], + ) + + # total_routing_tokens = 2 full blocks + 5 partial = 2*bs + 5 + total_routing_tokens = 2 * bs + partial + + result = alloc.reconstruct_routing_from_blocks(bids, total_routing_tokens) + + assert result is not None + assert isinstance(result, np.ndarray) + assert result.shape == (total_routing_tokens, num_layers, topk) + + # Verify content: first block + expected_b0 = ( + np.arange(bs * num_layers * topk, dtype=np.int16).reshape(bs, num_layers, topk) + + bids[0] + ) + assert np.allclose(result[:bs], expected_b0) + + # Verify content: partial last block + expected_partial = ( + np.arange(partial * num_layers * topk, dtype=np.int16).reshape( + partial, num_layers, topk + ) + + bids[2] + ) + assert np.allclose(result[2 * bs :], expected_partial) + + @pytest.mark.internal + def test_reconstruct_returns_none_for_missing_block(self): + """Reconstruction returns None if a block has no routing data.""" + ctx = self._ctx() + alloc = ctx.kv_block_allocator + bs = ctx.block_size_tokens + + block_ids = alloc.allocate_memory_blocks(2) + bids = block_ids.tolist() + + # Only store routing for the first block + alloc.store_block_routing( + bids[0], np.arange(bs), np.random.randint(-100, 100, size=(bs, 4, 2), dtype=np.int16) + ) + + result = 
alloc.reconstruct_routing_from_blocks(bids, 2 * bs) + assert result is None + + @pytest.mark.internal + def test_routing_survives_prefix_match_lru(self): + """In LRU mode, matched blocks' routing persists for the new request.""" + ctx = self._ctx(prefix_caching_eviction_policy=PrefixCachingEvictionPolicy.LRU) + alloc = ctx.kv_block_allocator + bs = ctx.block_size_tokens + + # First request: 2 full blocks + prompt = self._prompt(bs * 2) + req1 = self._req(ctx, prompt, request_id=1) + ctx.add_request(req1) + b0, b1 = self._block_ids(ctx, 0, 2) + + # Store routing for both blocks + routing_b0 = np.random.randint(-100, 100, size=(bs, 4, 2), dtype=np.int16) + routing_b1 = np.random.randint(-100, 100, size=(bs, 4, 2), dtype=np.int16) + alloc.store_block_routing(b0, np.arange(bs), routing_b0) + alloc.store_block_routing(b1, np.arange(bs), routing_b1) + + # Release first request's blocks (LRU: blocks stay cached) + blocks = ctx.request_to_kv_block_ids[0] + valid_blocks = blocks[blocks >= 0] + active_mask = torch.zeros(1, device=torch.cuda.current_device(), dtype=torch.int32) + new_tokens = torch.tensor([100], device=torch.cuda.current_device()) + ctx.update_requests(active_mask, new_tokens) + + # Second request with same prefix should match + req2 = self._req(ctx, prompt.clone(), request_id=2) + ctx.add_request(req2) + + # The matched blocks should still have routing data + assert alloc.get_block_routing(b0) is not None + assert np.allclose(alloc.get_block_routing(b0), routing_b0) + assert alloc.get_block_routing(b1) is not None + assert np.allclose(alloc.get_block_routing(b1), routing_b1) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index fe2b8fc5802..7bcf21882c1 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import asyncio import gc @@ -46,15 +46,15 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord +from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_fa_min_version, is_te_min_version -from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.test_utilities import Utils, clear_nvte_env_vars try: from torch_memory_saver import torch_memory_saver # noqa: F401 @@ -65,7 +65,7 @@ def skip_if_mamba_sequence_packing_not_available(model_provider: str): - if model_provider == "mamba": + if model_provider == "hybrid": sequence_packing_available, reason_for_no_sequence_packing = ( _check_mamba_sequence_packing_support() ) @@ -133,6 +133,7 @@ class DynamicEngineTestConfig: ) force_build_cuda_graphs: bool = False transformer_impl: str = "local" + inference_moe_token_dispatcher_type: str = "nccl" # If False, do not build cuda graphs in the tests, even if # num_cuda_graphs is set. # For tests concerning cuda-graph warmups, we set this to False @@ -145,6 +146,7 @@ class DynamicEngineTestConfig: track_generated_token_events: bool = False num_speculative_tokens: int = 0 position_embedding_type: str = "learned_absolute" + sampling_backend: str = 'torch' def __post_init__(self): @@ -178,7 +180,7 @@ class DynamicEngineTestEnv: ) -class TestDynamicInferenceEngine: +class DynamicInferenceEngineTestBase: @classmethod def _build_requests(cls, test_config: DynamicEngineTestConfig) -> List[DynamicInferenceRequest]: @@ -273,6 +275,7 @@ def _build_inference_context( unified_memory_level=0, # unit tests currently broken with UVM track_generated_token_events=test_config.track_generated_token_events, num_speculative_tokens=test_config.num_speculative_tokens, + sampling_backend=test_config.sampling_backend, ), ) @@ -281,11 +284,7 @@ def _build_inference_context( @classmethod @torch.inference_mode() def _build_test_env(cls, test_config): - Utils.initialize_model_parallel( - tensor_model_parallel_size=test_config.tensor_model_parallel_size, - pipeline_model_parallel_size=test_config.pipeline_model_parallel_size, - ) - + clear_nvte_env_vars() set_rounder(4) # Random state. @@ -334,6 +333,9 @@ def _build_test_env(cls, test_config): inference_sampling_seed=test_config.random_seed, cuda_graph_scope=test_config.cuda_graph_scope, transformer_impl=test_config.transformer_impl, + inference_moe_token_dispatcher_type=( + test_config.inference_moe_token_dispatcher_type + ), normalization=( "RMSNorm" if test_config.transformer_impl == "inference_optimized" @@ -368,7 +370,7 @@ def _build_test_env(cls, test_config): mtp_block_spec=mtp_block_spec, position_embedding_type=test_config.position_embedding_type, ).cuda() - elif test_config.model_provider == "mamba": + elif test_config.model_provider == "hybrid": pp_size = test_config.pipeline_model_parallel_size # Transformer config. 
transformer_config = TransformerConfig( @@ -398,16 +400,25 @@ def _build_test_env(cls, test_config): ), sequence_parallel=test_config.sequence_parallel, pipeline_dtype=torch.bfloat16, - add_bias_linear=test_config.expert_model_parallel_size == 1, + add_bias_linear=test_config.expert_model_parallel_size == 1 + and not (test_config.transformer_impl == "inference_optimized"), fp8="hybrid" if test_config.fp8 else None, fp8_recipe="tensorwise" if test_config.fp8 else None, inference_sampling_seed=test_config.random_seed, cuda_graph_scope=test_config.cuda_graph_scope, transformer_impl=test_config.transformer_impl, + inference_moe_token_dispatcher_type=( + test_config.inference_moe_token_dispatcher_type + ), + normalization=( + "RMSNorm" + if test_config.transformer_impl == "inference_optimized" + else "LayerNorm" + ), is_hybrid_model=True, # Needs to be set for correct out_proj init ) - # Mamba model. + # Hybrid model. # When speculative tokens are configured, append MTP depth sections # to the hybrid layer pattern so the model creates MTP blocks. mtp_suffix = "/M" * test_config.num_speculative_tokens @@ -415,9 +426,9 @@ def _build_test_env(cls, test_config): mamba_pattern = "M*-" + mtp_suffix else: mamba_pattern = "M*-|M*-" + mtp_suffix - model = MambaModel( + model = HybridModel( config=transformer_config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=test_config.vocab_size, max_sequence_length=test_config.max_sequence_length, parallel_output=True, @@ -459,10 +470,7 @@ def _build_test_env(cls, test_config): ), ) - # Reset global cuda graph state. - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - CudaGraphManager.global_mempool = None + delete_cuda_graphs() # Inference engine. 
engine = DynamicInferenceEngine(text_generation_controller, inference_context) @@ -565,8 +573,21 @@ def _run_test(cls, **test_config_kwargs): return env + +class TestDynamicInferenceEngine(DynamicInferenceEngineTestBase): + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=1, + expert_tensor_parallel_size=1, + ) + @classmethod def teardown_class(cls): + delete_cuda_graphs() set_rounder(64) Utils.destroy_model_parallel() @@ -574,7 +595,7 @@ def teardown_class(cls): @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) @pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4, -1]) @pytest.mark.parametrize("cuda_graph_scope", [[], [CudaGraphScope.full_iteration_inference]]) def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None: @@ -600,8 +621,11 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None assert env.engine.context.cuda_graph_batch_dimensions_list model = env.engine.controller.inference_wrapped_model.model if cuda_graph_scope == [CudaGraphScope.full_iteration_inference]: - # check if cudagraph runners are created at the decoder level - assert model.decoder.cudagraph_manager.cudagraph_runners + # hybrid models attach cudagraph_manager to the model; others attach to the decoder + if model_provider == "hybrid": + assert model.cudagraph_manager.cudagraph_runners + else: + assert model.decoder.cudagraph_manager.cudagraph_runners else: # check if cudagraph runners are created at the layer level for layer in model.decoder.layers: @@ -632,7 +656,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None if model_provider == "gpt": expected_generated_tokens_list = gpt_expected_generated_tokens - elif model_provider == "mamba": + elif model_provider == "hybrid": expected_generated_tokens_list = mamba_expected_generated_tokens else: raise ValueError(f"Invalid model_provider {model_provider}") @@ -693,7 +717,7 @@ def test_token_overflow_nontransient(self) -> None: @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) def test_block_overflow(self, model_provider: str) -> None: """Test block overflow.""" skip_if_mamba_sequence_packing_not_available(model_provider) @@ -739,7 +763,7 @@ def test_block_overflow_insufficient_kv_cache(self) -> None: @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) def test_multi_add(self, model_provider: str) -> None: """Test adding multiple requests simultaneously.""" skip_if_mamba_sequence_packing_not_available(model_provider) @@ -749,7 +773,7 @@ def test_multi_add(self, model_provider: str) -> None: @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) def test_fixed_output_lengths(self, model_provider: str) -> None: """Test generating a fixed number of 
output tokens.""" skip_if_mamba_sequence_packing_not_available(model_provider) @@ -792,7 +816,7 @@ def test_cuda_graph_token_counts(self) -> None: @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) @torch.inference_mode() def test_generate_function(self, model_provider: str) -> None: """Test the generate function that processes multiple prompts at once.""" @@ -886,7 +910,7 @@ async def test_run_engine(self): not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @pytest.mark.skipif(not is_te_min_version("2.2.0"), reason="TE 2.2.0 is required") - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) def test_fp8_inference(self, model_provider: str): skip_if_mamba_sequence_packing_not_available(model_provider) @@ -1083,88 +1107,6 @@ def test_log_probs_token_correspondence(self): assert not math.isnan(log_prob) and not math.isinf(log_prob) assert -100.0 <= log_prob <= 0.0 - @pytest.mark.internal - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True]) - @pytest.mark.parametrize("sequence_parallel", [False, True]) - @pytest.mark.parametrize("ep_size", [1, 2]) - @pytest.mark.parametrize("pp_size", [1, 2]) - @pytest.mark.parametrize("tp_size", [1, 2]) - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) - @pytest.mark.parametrize("transformer_impl", ["local", "inference_optimized"]) - @torch.inference_mode() - def test_parallel_inference( - self, - model_provider, - tp_size, - pp_size, - ep_size, - sequence_parallel, - materialize_only_last_token_logits, - transformer_impl, - ): - skip_if_mamba_sequence_packing_not_available(model_provider) - - if tp_size == 1 and pp_size == 1 and ep_size == 1: - pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") - elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") - world_size = torch.distributed.get_world_size() - min_world_size = tp_size * pp_size * ep_size - if world_size < min_world_size: - pytest.skip(f"Test requires at least {min_world_size} GPUs") - elif tp_size == 1 and sequence_parallel: - pytest.skip(reason="Sequence parallelism requires tp_size > 1") - elif tp_size > 1 and ep_size > 1 and not sequence_parallel: - pytest.skip(reason="Sequence parallelism must be used with tp_size > 1 and ep_size > 1") - elif transformer_impl == "inference_optimized": - if ep_size > 1: - pytest.skip( - reason="MoE models are not supported with the inference optimized transformer." - ) - if tp_size > 1 and not sequence_parallel: - pytest.skip( - reason=( - "The inference optimized transformer requires sequence parallelism " - "when tp_size > 1." - ) - ) - if model_provider == "mamba": - pytest.skip( - reason="Mamba model is not supported with the inference optimized transformer." 
- ) - - env = self._run_test( - model_provider=model_provider, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - expert_model_parallel_size=ep_size, - sequence_parallel=sequence_parallel, - materialize_only_last_token_logits=materialize_only_last_token_logits, - transformer_impl=transformer_impl, - ) - - @pytest.mark.internal - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True]) - def test_sequence_parallel_fp8_inference(self, materialize_only_last_token_logits: bool): - fp8_available, reason_for_no_fp8 = check_fp8_support() - if not fp8_available: - pytest.skip(reason_for_no_fp8) - - self._run_test( - min_prompt_length=19, - max_prompt_length=19, - tensor_model_parallel_size=4, - sequence_parallel=True, - materialize_only_last_token_logits=True, - fp8=True, - ) - @pytest.mark.internal @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" @@ -1299,11 +1241,11 @@ def test_mamba_chunked_prefill(self): """ Test chunked prefill with a Mamba model. """ - skip_if_mamba_sequence_packing_not_available("mamba") + skip_if_mamba_sequence_packing_not_available("hybrid") # Context max tokens = 50. test_config = DynamicEngineTestConfig( - model_provider="mamba", + model_provider="hybrid", num_requests=0, num_tokens_to_generate=None, num_tokens_total=200, @@ -2321,9 +2263,15 @@ def set_epoch(epoch): not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @torch.inference_mode() - def test_speculative_decoding_with_early_termination(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_decoding_with_early_termination( + self, materialize_only_last_token_logits, sampling_backend + ): """Test that speculative decoding handles premature request termination safely (e.g. 
hitting max_sequence_length mid-speculative-batch).""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") # Set max_sequence_length tight so it terminates during a speculative step test_config = DynamicEngineTestConfig( @@ -2334,7 +2282,8 @@ def test_speculative_decoding_with_early_termination(self): max_sequence_length=7, # Will force termination after 3 tokens model_provider="gpt", num_speculative_tokens=3, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2359,9 +2308,13 @@ def mock_mtp_forward(*args, **kwargs): unwrapped_model._decoder_hidden_states_cache = torch.zeros( tokens.size(1), 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) logits = torch.zeros( n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 @@ -2393,12 +2346,18 @@ def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, de @pytest.mark.internal @torch.inference_mode() - def test_speculative_block_boundary_crossing(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_block_boundary_crossing( + self, materialize_only_last_token_logits, sampling_backend + ): """Test to verify KV cache block boundary crossing logic. When a request fills exactly one block and speculative decoding generates multiple tokens, the first new token shouldn't incorrectly overwrite the old block. 
""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") test_config = DynamicEngineTestConfig( num_requests=1, min_prompt_length=256, @@ -2408,8 +2367,9 @@ def test_speculative_block_boundary_crossing(self): context_block_size_tokens=256, # Exactly matches prompt length context_max_requests=16, model_provider="gpt", - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, use_fixed_output_lengths=True, + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2449,9 +2409,13 @@ def test_speculative_block_boundary_crossing(self): not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @torch.inference_mode() - def test_speculative_stop_word_hit(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_stop_word_hit(self, materialize_only_last_token_logits, sampling_backend): """Test that if an accepted speculative token completes a stop word, the request correctly triggers the stop logic without crashing.""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") test_config = DynamicEngineTestConfig( num_requests=0, # We will manually add our request cleanly @@ -2459,8 +2423,9 @@ def test_speculative_stop_word_hit(self): max_prompt_length=4, num_tokens_to_generate=10, num_speculative_tokens=2, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, model_provider="gpt", + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2482,9 +2447,13 @@ def mock_deterministic_forward(*args, **kwargs): unwrapped_model._decoder_hidden_states_cache = torch.zeros( s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) # Predict next_token_ids + 1 (continuing the ascending sequence) pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) @@ -2533,9 +2502,15 @@ def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, de not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @torch.inference_mode() - def test_speculative_long_stop_word_hit(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_long_stop_word_hit( + self, materialize_only_last_token_logits, sampling_backend + ): """Test that if an accepted speculative token completes a long stop word (length > num_speculative_tokens), it is correctly detected.""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") test_config = DynamicEngineTestConfig( num_requests=0, @@ -2543,8 +2518,9 @@ def test_speculative_long_stop_word_hit(self): max_prompt_length=4, num_tokens_to_generate=10, num_speculative_tokens=2, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, model_provider="gpt", + sampling_backend=sampling_backend, ) env 
= self._build_test_env(test_config) @@ -2566,9 +2542,13 @@ def mock_deterministic_forward(*args, **kwargs): unwrapped_model._decoder_hidden_states_cache = torch.zeros( s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) # Predict next_token_ids + 1 (continuing the ascending sequence) pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) @@ -2613,7 +2593,11 @@ def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, de not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @torch.inference_mode() - def test_speculative_stop_word_truncates_trailing_tokens(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_stop_word_truncates_trailing_tokens( + self, materialize_only_last_token_logits, sampling_backend + ): """Test that when a stop word lands in the middle of speculative tokens, the extra tokens generated after the stop word are removed. @@ -2621,6 +2605,8 @@ def test_speculative_stop_word_truncates_trailing_tokens(self): (1 base + 2 speculative). If the stop word is [6] and the engine generates [5, 6, 7] in one step, token 7 must be truncated so the output ends with the stop word [6].""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") test_config = DynamicEngineTestConfig( num_requests=0, @@ -2628,8 +2614,9 @@ def test_speculative_stop_word_truncates_trailing_tokens(self): max_prompt_length=4, num_tokens_to_generate=10, num_speculative_tokens=2, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, model_provider="gpt", + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2651,9 +2638,13 @@ def mock_deterministic_forward(*args, **kwargs): unwrapped_model._decoder_hidden_states_cache = torch.zeros( s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) # Predict next_token_ids + 1 (continuing the ascending sequence) pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) @@ -2729,9 +2720,16 @@ def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, de "non_divisible_boundary", ], ) + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) @torch.inference_mode() def test_speculative_tokens_exceed_max_sequence_length( - self, prompt_length, num_tokens_to_generate, num_speculative_tokens + self, + prompt_length, + num_tokens_to_generate, + num_speculative_tokens, + materialize_only_last_token_logits, + sampling_backend, ): """Test that speculative decoding 
correctly trims output when speculative tokens would push the sequence beyond max_sequence_length. @@ -2741,6 +2739,8 @@ def test_speculative_tokens_exceed_max_sequence_length( speculative tokens are accepted and the boundary trimming logic is actually exercised. """ + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") max_sequence_length = prompt_length + num_tokens_to_generate test_config = DynamicEngineTestConfig( @@ -2750,11 +2750,12 @@ def test_speculative_tokens_exceed_max_sequence_length( num_tokens_to_generate=num_tokens_to_generate, max_sequence_length=max_sequence_length, num_speculative_tokens=num_speculative_tokens, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, model_provider="gpt", # Disable positional embeddings so speculative position IDs # beyond max_sequence_length don't cause out-of-bounds lookups. position_embedding_type="none", + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2776,8 +2777,12 @@ def deterministic_forward(*args, **kwargs): # Wrap the real MTP step similarly. real_mtp = unwrapped_model.compute_mtp_single_step - def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def deterministic_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) logits.zero_() logits[..., 0] = 100.0 return hidden_states, logits @@ -2853,8 +2858,12 @@ def test_detokenize_stop_sequence_flag(self, detokenize_stop_sequence): @pytest.mark.parametrize( "acceptance_mode", ["all_rejected", "all_accepted"], ids=["all_rejected", "all_accepted"] ) + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) @torch.inference_mode() - def test_speculative_sequence_length_double_counting(self, acceptance_mode): + def test_speculative_sequence_length_double_counting( + self, acceptance_mode, materialize_only_last_token_logits, sampling_backend + ): """Test to verify active_sequence_lengths is not double-counted. If active sequence length is double-counted during speculative decoding, @@ -2866,6 +2875,8 @@ def test_speculative_sequence_length_double_counting(self, acceptance_mode): a faulty formula that adds accepted_tokens on top of the KV length will over-count by 2 per step, finishing the request after only 4 of 6 tokens. 
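    As a rough worked illustration of that failure mode (numbers assumed from this
    description, not taken from engine internals):

        num_spec = 2
        per_step = 1 + num_spec                 # tokens actually generated each step
        faulty_per_step = per_step + num_spec   # buggy formula re-adds the accepted tokens
        for step in (1, 2):
            assert step * faulty_per_step - step * per_step == 2 * step
        # A stop check driven by the faulty counter therefore trips before all
        # num_tokens_to_generate tokens have really been produced.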
""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") test_config = DynamicEngineTestConfig( num_requests=0, min_prompt_length=4, @@ -2875,10 +2886,11 @@ def test_speculative_sequence_length_double_counting(self, acceptance_mode): context_max_requests=16, num_speculative_tokens=2, model_provider="gpt", - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, use_fixed_output_lengths=False, context_max_tokens=512, position_embedding_type="none", + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -2902,6 +2914,8 @@ def mock_mtp_forward(*args, **kwargs): model._decoder_hidden_states_cache = torch.zeros( s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits def mock_compute_mtp(*args_mtp, **kwargs_mtp): @@ -2930,8 +2944,17 @@ def deterministic_forward(*args, **kwargs): real_mtp = model.compute_mtp_single_step - def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def deterministic_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, + next_token_ids, + position_ids, + depth, + eager=eager, + cache_key=cache_key, + ) logits.zero_() logits[..., 0] = 100.0 return hidden_states, logits @@ -2964,12 +2987,18 @@ def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) @torch.inference_mode() - def test_speculative_decoding_with_eviction_and_swapping(self): + @pytest.mark.parametrize("sampling_backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_decoding_with_eviction_and_swapping( + self, materialize_only_last_token_logits, sampling_backend + ): """Test that speculative decoding works correctly when requests are paused and evicted. This exercises the `_swap_book_keeping_tensors` logic with the 2D `new_speculative_tokens` tensor, ensuring no dimensional mismatch or index errors occur during tensor swapping. 
""" + if sampling_backend == "flashinfer": + pytest.importorskip("flashinfer") # Very constrained memory environment to force pausing and eviction test_config = DynamicEngineTestConfig( num_requests=3, @@ -2981,8 +3010,9 @@ def test_speculative_decoding_with_eviction_and_swapping(self): context_buffer_size_gb=0.00064, # 640 KB context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction model_provider="gpt", - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, use_fixed_output_lengths=True, + sampling_backend=sampling_backend, ) env = self._build_test_env(test_config) @@ -3005,9 +3035,13 @@ def mock_safe_forward(*args, **kwargs): unwrapped_model._decoder_hidden_states_cache = torch.zeros( s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) + if test_config.materialize_only_last_token_logits: + base_logits = env.engine.context.last_token_logits(base_logits).unsqueeze(0) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) logits = torch.zeros( n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 @@ -3220,7 +3254,9 @@ def mock_deterministic_forward(*args, **kwargs): ) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) logits = torch.randn( n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 @@ -3340,7 +3376,9 @@ def mock_deterministic_forward(*args, **kwargs): ) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) logits = torch.randn( n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 @@ -3470,7 +3508,9 @@ def mock_deterministic_forward(*args, **kwargs): ) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) logits = torch.randn( n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 @@ -3723,7 +3763,9 @@ def mock_deterministic_forward(*args, **kwargs): ) return base_logits - def mock_compute_mtp_wrong(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_wrong( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): n = hidden_states.size(0) wrong_toks = (next_token_ids + 5).clamp(max=test_config.vocab_size - 1) logits = torch.zeros( @@ -3819,7 +3861,9 @@ def mock_deterministic_forward(*args, **kwargs): ) return base_logits - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): n = hidden_states.size(0) pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) logits = torch.zeros( @@ -3931,8 +3975,12 @@ def deterministic_forward(*args, **kwargs): 
real_mtp = unwrapped_model.compute_mtp_single_step - def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def deterministic_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) logits.zero_() logits[..., 0] = 100.0 return hidden_states, logits @@ -4040,8 +4088,12 @@ def deterministic_forward(*args, **kwargs): # Deterministic MTP: also predict token 0 → all speculative tokens accepted. real_mtp = unwrapped_model.compute_mtp_single_step - def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def deterministic_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) logits.zero_() logits[..., 0] = 100.0 return hidden_states, logits @@ -4112,8 +4164,12 @@ def deterministic_forward(*args, **kwargs): # During prefill, no MTP runs, so request 2 is unaffected. real_mtp = unwrapped_model.compute_mtp_single_step - def heterogeneous_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def heterogeneous_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) n = logits.size(0) logits.zero_() if n >= 2: @@ -4211,8 +4267,12 @@ def deterministic_forward(*args, **kwargs): real_mtp = unwrapped_model.compute_mtp_single_step - def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + def deterministic_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) logits.zero_() logits[..., 0] = 100.0 return hidden_states, logits @@ -4257,40 +4317,6 @@ def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth): assert isinstance(lp, float) assert -0.1 < lp <= 0.0, f"Token {j}: expected log prob near 0.0, got {lp}" - @pytest.mark.internal - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @torch.inference_mode() - def test_speculative_decoding_pipeline_parallel(self): - """Test speculative decoding with pipeline parallelism (pp_size=2). - - Verifies that MTP logit broadcasts across pipeline stages don't hang - or produce incorrect results. Each PP stage must participate in the - same number of MTP broadcast rounds. 
- """ - if not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") - world_size = torch.distributed.get_world_size() - pp_size = 2 - if world_size < pp_size: - pytest.skip(f"Test requires at least {pp_size} GPUs") - - env = self._run_test( - model_provider="gpt", - pipeline_model_parallel_size=pp_size, - num_speculative_tokens=2, - num_tokens_to_generate=6, - materialize_only_last_token_logits=False, - ) - - for request in env.requests: - assert ( - request.status == Status.COMPLETED - ), f"Request {request.request_id}: status={request.status}" - num_expected = request.sampling_params.num_tokens_to_generate - assert len(request.generated_tokens) <= num_expected - @pytest.mark.internal @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" @@ -4319,7 +4345,7 @@ def test_speculative_decoding_mamba_hybrid(self, rejection_mode): Two requests run simultaneously to exercise batched rewind indexing where mamba_metadata.request_to_mamba_state_idx differs per request. """ - skip_if_mamba_sequence_packing_not_available("mamba") + skip_if_mamba_sequence_packing_not_available("hybrid") num_tokens_to_generate = 8 test_config = DynamicEngineTestConfig( @@ -4329,7 +4355,7 @@ def test_speculative_decoding_mamba_hybrid(self, rejection_mode): num_tokens_to_generate=num_tokens_to_generate, num_speculative_tokens=2, materialize_only_last_token_logits=False, - model_provider="mamba", + model_provider="hybrid", ) env = self._build_test_env(test_config) @@ -4348,9 +4374,13 @@ def deterministic_forward(*args, **kwargs): real_mtp = unwrapped_model.compute_mtp_single_step - def mtp_with_rejection(hidden_states, next_token_ids, position_ids, depth): + def mtp_with_rejection( + hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None + ): # Run real MTP to exercise Mamba intermediate state saving. - hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth) + hidden_states, logits = real_mtp( + hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key + ) logits.zero_() if rejection_mode == "all_accepted": # Predict token 0 (same as base) → accepted. @@ -4410,6 +4440,137 @@ def mtp_with_rejection(hidden_states, next_token_ids, position_ids, depth): assert env.engine.context.total_request_count == 0 +class TestDynamicInferenceEngineParallel(DynamicInferenceEngineTestBase): + """Tests that require non-default parallel configs (tp>1, pp>1, or ep>1). + + Each test initializes its own parallel state and tears it down afterward, + so these are separated from TestDynamicInferenceEngine to avoid accumulating + NCCL communicator memory from repeated init/destroy cycles. 
+ """ + + def teardown_method(self, method): + delete_cuda_graphs() + Utils.destroy_model_parallel() + + @classmethod + @torch.inference_mode() + def _build_test_env(cls, test_config): + Utils.initialize_model_parallel( + tensor_model_parallel_size=test_config.tensor_model_parallel_size, + pipeline_model_parallel_size=test_config.pipeline_model_parallel_size, + expert_model_parallel_size=test_config.expert_model_parallel_size, + expert_tensor_parallel_size=1, + ) + return super()._build_test_env(test_config) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True]) + @pytest.mark.parametrize("sequence_parallel", [False, True]) + @pytest.mark.parametrize("ep_size", [1, 2]) + @pytest.mark.parametrize("pp_size", [1, 2]) + @pytest.mark.parametrize("tp_size", [1, 2]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) + @pytest.mark.parametrize("transformer_impl", ["local", "inference_optimized"]) + @torch.inference_mode() + def test_parallel_inference( + self, + model_provider, + tp_size, + pp_size, + ep_size, + sequence_parallel, + materialize_only_last_token_logits, + transformer_impl, + ): + skip_if_mamba_sequence_packing_not_available(model_provider) + + if tp_size == 1 and pp_size == 1 and ep_size == 1: + pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") + elif not torch.distributed.is_initialized(): + pytest.skip("Distributed not initialized") + world_size = torch.distributed.get_world_size() + min_world_size = tp_size * pp_size * ep_size + if world_size < min_world_size: + pytest.skip(f"Test requires at least {min_world_size} GPUs") + elif tp_size == 1 and sequence_parallel: + pytest.skip(reason="Sequence parallelism requires tp_size > 1") + elif tp_size > 1 and ep_size > 1 and not sequence_parallel: + pytest.skip(reason="Sequence parallelism must be used with tp_size > 1 and ep_size > 1") + elif transformer_impl == "inference_optimized": + if ep_size > 1: + pytest.skip( + reason="MoE models are not supported with the inference optimized transformer." + ) + if tp_size > 1 and not sequence_parallel: + pytest.skip( + reason=( + "The inference optimized transformer requires sequence parallelism " + "when tp_size > 1." 
+ ) + ) + + env = self._run_test( + model_provider=model_provider, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + sequence_parallel=sequence_parallel, + materialize_only_last_token_logits=materialize_only_last_token_logits, + transformer_impl=transformer_impl, + ) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True]) + def test_sequence_parallel_fp8_inference(self, materialize_only_last_token_logits: bool): + fp8_available, reason_for_no_fp8 = check_fp8_support() + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + self._run_test( + min_prompt_length=19, + max_prompt_length=19, + tensor_model_parallel_size=4, + sequence_parallel=True, + materialize_only_last_token_logits=True, + fp8=True, + ) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_pipeline_parallel(self): + """Test speculative decoding with pipeline parallelism (pp_size=2).""" + if not torch.distributed.is_initialized(): + pytest.skip("Distributed not initialized") + world_size = torch.distributed.get_world_size() + pp_size = 2 + if world_size < pp_size: + pytest.skip(f"Test requires at least {pp_size} GPUs") + + env = self._run_test( + model_provider="gpt", + pipeline_model_parallel_size=pp_size, + num_speculative_tokens=2, + num_tokens_to_generate=6, + materialize_only_last_token_logits=False, + ) + + for request in env.requests: + assert ( + request.status == Status.COMPLETED + ), f"Request {request.request_id}: status={request.status}" + num_expected = request.sampling_params.num_tokens_to_generate + assert len(request.generated_tokens) <= num_expected + + CHUNKED_CG_BLOCK_SIZE = 256 CHUNKED_CG_VOCAB_SIZE = 10000 CHUNKED_CG_MAX_SEQ_LEN = 2048 @@ -4430,6 +4591,7 @@ def setup_class(cls): @classmethod def teardown_class(cls): + delete_cuda_graphs() set_rounder(64) Utils.destroy_model_parallel() @@ -4460,7 +4622,7 @@ def _create_model(self, model_provider, num_cuda_graphs): pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), ).cuda() - elif model_provider == "mamba": + elif model_provider == "hybrid": config = TransformerConfig( params_dtype=torch.bfloat16, num_layers=3, @@ -4476,9 +4638,9 @@ def _create_model(self, model_provider, num_cuda_graphs): add_bias_linear=True, is_hybrid_model=True, ) - model = MambaModel( + model = HybridModel( config=config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=CHUNKED_CG_VOCAB_SIZE, max_sequence_length=CHUNKED_CG_MAX_SEQ_LEN, parallel_output=True, @@ -4494,17 +4656,6 @@ def _create_model(self, model_provider, num_cuda_graphs): model.eval() return model - def _reset_cuda_graph_state(self, model): - """Reset all CUDA graph global and per-module state.""" - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] - CudaGraphManager.global_mempool = None - for module in model.modules(): - if isinstance(module, CudaGraphManager): - module.cudagraph_runners.clear() - module.inference_cudagraphs_lookup_table.clear() - def _build_engine(self, model, enable_chunked_prefill, num_cuda_graphs, context_max_tokens): """Build an 
engine with the given chunked prefill / CUDA graph config.""" set_rounder(4) @@ -4521,6 +4672,7 @@ def _build_engine(self, model, enable_chunked_prefill, num_cuda_graphs, context_ enable_chunked_prefill=enable_chunked_prefill, max_tokens=context_max_tokens, max_requests=128, + sampling_backend='torch', ) if mamba_config is not None: inference_config_kwargs.update(mamba_inference_state_config=mamba_config) @@ -4537,7 +4689,7 @@ def _build_engine(self, model, enable_chunked_prefill, num_cuda_graphs, context_ vocab_size=CHUNKED_CG_VOCAB_SIZE, detokenize=lambda tokens: "tokenized_prompt" ), ) - self._reset_cuda_graph_state(model) + delete_cuda_graphs() return DynamicInferenceEngine(controller, context) def _run_to_completion(self, engine, prompts, num_tokens_to_generate): @@ -4564,7 +4716,7 @@ def _run_to_completion(self, engine, prompts, num_tokens_to_generate): return finished, step_count - @pytest.mark.parametrize("model_provider", ["gpt", "mamba"]) + @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) @pytest.mark.parametrize("chunked_prefill", [False, True]) @pytest.mark.parametrize("num_cuda_graphs", [None, 2]) @torch.inference_mode() @@ -4572,10 +4724,7 @@ def test_chunked_prefill_cuda_graphs(self, model_provider, chunked_prefill, num_ """Verify generated tokens match across chunked prefill and CUDA graph configs.""" skip_if_mamba_sequence_packing_not_available(model_provider) - # Clear NVTE env vars set by conftest set_env fixture. - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) + clear_nvte_env_vars() random.seed(123) torch.manual_seed(123) diff --git a/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py b/tests/unit_tests/inference/engines/test_hybrid_prefix_caching_e2e.py similarity index 98% rename from tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py rename to tests/unit_tests/inference/engines/test_hybrid_prefix_caching_e2e.py index ce21c775b73..92890b22b53 100644 --- a/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py +++ b/tests/unit_tests/inference/engines/test_hybrid_prefix_caching_e2e.py @@ -54,8 +54,8 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord @@ -131,9 +131,9 @@ def _create_model(self, num_cuda_graphs=None): add_bias_linear=True, is_hybrid_model=True, ) - model = MambaModel( + model = HybridModel( config=transformer_config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=VOCAB_SIZE, max_sequence_length=MAX_SEQ_LEN, parallel_output=True, @@ -200,6 +200,7 @@ def _build_engine( enable_prefix_caching=enable_prefix_caching, unified_memory_level=0, num_cuda_graphs=num_cuda_graphs, + sampling_backend='torch', ) if enable_prefix_caching: inference_config_kwargs.update( @@ -226,7 +227,7 @@ def _build_engine( for module in model.modules(): if isinstance(module, CudaGraphManager): 
module.cudagraph_runners.clear() - module.inference_cudagraphs_lookup_table.clear() + module.custom_cudagraphs_lookup_table.clear() return DynamicInferenceEngine(controller, context) def _make_request(self, req_id, prompt, enable_pc, num_tokens=NUM_TOKENS_TO_GENERATE): diff --git a/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py b/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py index 52a05f7f80f..c89209699b3 100644 --- a/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py +++ b/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py @@ -37,8 +37,8 @@ ) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord @@ -121,9 +121,9 @@ def _create_model(self, model_type, num_cuda_graphs=None): add_bias_linear=True, is_hybrid_model=True, ) - model = MambaModel( + model = HybridModel( config=config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=VOCAB_SIZE, max_sequence_length=MAX_SEQ_LEN, parallel_output=True, @@ -147,7 +147,7 @@ def _reset_cuda_graph_state(self, model): for module in model.modules(): if isinstance(module, CudaGraphManager): module.cudagraph_runners.clear() - module.inference_cudagraphs_lookup_table.clear() + module.custom_cudagraphs_lookup_table.clear() def _build_engine(self, model, mamba_config, num_cuda_graphs): """Build an engine with prefix caching and optional CUDA graphs.""" @@ -343,9 +343,9 @@ def _create_hybrid_model(self, num_cuda_graphs=None): add_bias_linear=True, is_hybrid_model=True, ) - model = MambaModel( + model = HybridModel( config=config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=VOCAB_SIZE, max_sequence_length=MAX_SEQ_LEN, parallel_output=True, @@ -367,7 +367,7 @@ def _reset_cuda_graph_state(self, model): for module in model.modules(): if isinstance(module, CudaGraphManager): module.cudagraph_runners.clear() - module.inference_cudagraphs_lookup_table.clear() + module.custom_cudagraphs_lookup_table.clear() def _build_engine( self, diff --git a/tests/unit_tests/inference/engines/test_static_engine.py b/tests/unit_tests/inference/engines/test_static_engine.py index 483a21d13bd..0067ff6e9bc 100644 --- a/tests/unit_tests/inference/engines/test_static_engine.py +++ b/tests/unit_tests/inference/engines/test_static_engine.py @@ -27,9 +27,10 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_fa_min_version -from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.test_utilities import Utils, clear_nvte_env_vars class StaticInferenceEngineTestHarness: @@ 
-45,11 +46,7 @@ def setup_engine( buffer_size_gb=10, inference_config_params_dtype=torch.float, ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ) - + clear_nvte_env_vars() model_parallel_cuda_manual_seed(123) self.batch_size = 4 self.hidden_size = 32 @@ -111,11 +108,23 @@ def setup_engine( buffer_size_gb=buffer_size_gb, ) + +class TestStaticInferenceEngine(StaticInferenceEngineTestHarness): + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + def teardown_method(self, method): - Utils.destroy_model_parallel() + delete_cuda_graphs() + @classmethod + def teardown_class(cls): + delete_cuda_graphs() + Utils.destroy_model_parallel() -class TestStaticInferenceEngine(StaticInferenceEngineTestHarness): @pytest.mark.parametrize( "batch_size,num_trials,empty_prompt", [(4, 1, False), (4, 1, True), (4, 3, False), (2, 1, False), (8, 1, False)], @@ -294,6 +303,47 @@ async def collect_stream(stream_generator, num_tokens_to_generate): f"final_streamed_token.generated_log_probs={final_streamed_token.generated_log_probs}" ) + +class TestStaticInferenceEngineParallel(StaticInferenceEngineTestHarness): + """Tests that require non-default parallel configs (varying tp/pp/ep). + + Each test initializes its own parallel state and tears it down afterward, + so these are separated from TestStaticInferenceEngine to avoid + accumulating NCCL communicator memory from repeated init/destroy cycles. + """ + + def teardown_method(self, method): + delete_cuda_graphs() + Utils.destroy_model_parallel() + + def setup_engine( + self, + engine_max_batch_size=None, + vocab_size=100, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=1, + sequence_parallel=False, + legacy=False, + buffer_size_gb=10, + inference_config_params_dtype=torch.float, + ): + Utils.initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + super().setup_engine( + engine_max_batch_size=engine_max_batch_size, + vocab_size=vocab_size, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + sequence_parallel=sequence_parallel, + legacy=legacy, + buffer_size_gb=buffer_size_gb, + inference_config_params_dtype=inference_config_params_dtype, + ) + @pytest.mark.parametrize("sequence_parallel", [False, True]) @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("pp_size", [1, 2]) diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index f520c2441d7..ce9bb579ee6 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -40,15 +40,13 @@ def _generate_graphs(num_cuda_graphs, use_non_decode=True): return graph_list -def _match(real, graph_list, ep_group, strict=False, decode_only=False, num_speculative_tokens=0): +def _match(real, graph_list, ep_group, strict=False): return CUDAGraphBatchDimensionBuilder.match_graph_config( real_batch_dim=real, cuda_graph_batch_dimensions_list=graph_list, strict=strict, - decode_only_cuda_graphs=decode_only, ep_group=ep_group, - smallest_non_decode_cuda_graph_size=min(MIXED_PREFILL_COUNT, MAX_REQUESTS), - 
num_speculative_tokens=num_speculative_tokens, + match_ep_token_counts=True, ) @@ -122,6 +120,7 @@ class TestMatchGraphConfigWithEP: Uses the world group as the EP group (all 8 GPUs form one EP group). """ + @classmethod def setup_class(cls): Utils.initialize_model_parallel( tensor_model_parallel_size=1, @@ -129,6 +128,7 @@ def setup_class(cls): expert_model_parallel_size=Utils.world_size, ) + @classmethod def teardown_class(cls): Utils.destroy_model_parallel() @@ -173,13 +173,14 @@ def test_varying_decode_token_counts(self, num_cuda_graphs): assert result is not None # ------------------------------------------------------------------ # - # 3. decode_only_cuda_graphs=True, some ranks have prefill → all None + # 3. Any rank has prefill → all ranks fall back to eager (None) # ------------------------------------------------------------------ # @pytest.mark.internal @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) - def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): - """When decode_only_cuda_graphs=True and at least one EP rank has a - prefill request, ALL ranks should get None (eager mode).""" + def test_any_prefill_rank_forces_eager(self, num_cuda_graphs): + """When at least one EP rank has a prefill request, + adjust_batch_dims_for_expert_parallelism returns None and ALL ranks + get None from match_graph_config (eager mode).""" ep_group = self._get_ep_group() graph_list = _generate_graphs(num_cuda_graphs) rank = dist.get_rank() @@ -190,11 +191,9 @@ def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): else: real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + result = _match(real, graph_list, ep_group=ep_group) _assert_consistent_across_ranks(result, ep_group) - assert ( - result is None - ), "All ranks should run eager when decode_only=True and some rank has prefill" + assert result is None, "All ranks should run eager when any rank has prefill" # ------------------------------------------------------------------ # # 4. Mixed prefill graphs with strict matching @@ -284,11 +283,13 @@ def test_mixed_decode_and_prefill_ranks_non_strict(self, num_cuda_graphs): # ------------------------------------------------------------------ # # 9. All ranks decode-only with decode_only_cuda_graphs → should match # ------------------------------------------------------------------ # + # 9. 
All ranks decode-only → EP max-reduce finds a matching graph + # ------------------------------------------------------------------ # @pytest.mark.internal @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) - def test_decode_only_graphs_all_decode(self, num_cuda_graphs): - """When all EP ranks are decode-only and decode_only_cuda_graphs=True, - a match should be found.""" + def test_all_decode_ranks_match(self, num_cuda_graphs): + """When all EP ranks are decode-only, the all-reduce max lifts token + counts to the largest rank's value and a matching graph is found.""" ep_group = self._get_ep_group() graph_list = _generate_graphs(num_cuda_graphs) rank = dist.get_rank() @@ -296,9 +297,9 @@ def test_decode_only_graphs_all_decode(self, num_cuda_graphs): token_count = (rank + 1) * 4 real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) - result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + result = _match(real, graph_list, ep_group=ep_group) _assert_consistent_across_ranks(result, ep_group) - assert result is not None, "All-decode batch with decode_only_cuda_graphs should match" + assert result is not None, "All-decode batch should match a graph" # ------------------------------------------------------------------ # # 10. Real batch exceeds all graphs → None on all ranks @@ -352,14 +353,16 @@ def test_one_rank_oversized_forces_no_match(self, num_cuda_graphs): class TestSpeculativeDecodingBatchDimensions: """Tests for batch dimensions specifically handling speculative decoding.""" - def setup_method(self, method): + @classmethod + def setup_class(cls): Utils.initialize_model_parallel( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, expert_model_parallel_size=Utils.world_size, ) - def teardown_method(self, method): + @classmethod + def teardown_class(cls): Utils.destroy_model_parallel() @staticmethod @@ -439,9 +442,7 @@ def test_ep_sync_with_speculative_tokens(self, num_cuda_graphs): token_count = decode_reqs * (num_speculative_tokens + 1) real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=decode_reqs) - result = _match( - real, graph_list, ep_group=ep_group, num_speculative_tokens=num_speculative_tokens - ) + result = _match(real, graph_list, ep_group=ep_group) # All ranks should end up syncing to the maximum requirement and picking the same graph _assert_consistent_across_ranks(result, ep_group) @@ -454,14 +455,9 @@ def test_ep_sync_with_speculative_tokens(self, num_cuda_graphs): def test_ep_mixed_decode_prefill_with_speculative_tokens(self, num_cuda_graphs): """Verify EP sync when ranks have different request states with speculative tokens. - Even ranks have decode-only requests (with speculative token multiplier). - Odd ranks have mixed prefill+decode requests. After the EP all-reduce, - all ranks must agree on a graph that accommodates both states. - - This tests the scenario where the CUDA graph config changes after EP - sync — a decode-only rank may be forced into a mixed graph, and the - speculative token multiplier must be preserved correctly through the - transition. + Even ranks have decode-only requests; odd ranks have mixed prefill+decode. + Since any prefill rank causes all ranks to fall back to eager (None), + the test verifies that all ranks consistently get None. 
""" ep_group = self._get_ep_group() num_speculative_tokens = 2 @@ -501,30 +497,18 @@ def test_ep_mixed_decode_prefill_with_speculative_tokens(self, num_cuda_graphs): decode_req_count=decode_reqs, ) - result = _match( - real, graph_list, ep_group=ep_group, num_speculative_tokens=num_speculative_tokens - ) + result = _match(real, graph_list, ep_group=ep_group) - # All ranks must agree on a graph after EP sync. + # Any rank has prefill → all ranks get None (eager mode). _assert_consistent_across_ranks(result, ep_group) - - # The matched graph must have enough tokens for the largest rank's needs. - if result is not None: - max_token_count = max( - 4 * (num_speculative_tokens + 1), 2 * (num_speculative_tokens + 1) + 8 - ) - assert result.token_count >= max_token_count, ( - f"Matched graph token_count {result.token_count} < " f"required {max_token_count}" - ) + assert result is None, "Any prefill rank should force all ranks to eager mode" @pytest.mark.internal def test_ep_speculative_decode_to_mixed_graph_transition(self): - """Verify that a decode-only rank can use a mixed graph after EP sync. + """Verify EP consistency when ranks have mixed prefill/decode states. - When one EP rank is decode-only and another has prefill, the sync - may force the decode-only rank into a mixed graph. The decode request - count and speculative token multiplier must still be valid in the - selected graph. + When one EP rank is decode-only and another has prefill, the NCCL + EP sync detects the prefill and returns None for all ranks (eager mode). """ ep_group = self._get_ep_group() num_speculative_tokens = 3 @@ -556,13 +540,8 @@ def test_ep_speculative_decode_to_mixed_graph_transition(self): # Prefill-only: forces even ranks out of decode-only graph. real = BD(token_count=32, prefill_req_count=2, decode_req_count=0) - result = _match( - real, graph_list, ep_group=ep_group, num_speculative_tokens=num_speculative_tokens - ) + result = _match(real, graph_list, ep_group=ep_group) + # Odd ranks have prefill → any prefill rank forces all to eager (None). _assert_consistent_across_ranks(result, ep_group) - # After EP sync, a graph must be found (mixed graphs accommodate both). - assert result is not None, ( - f"Rank {rank}: no graph matched after EP sync. " - f"Decode-only ranks should transition to mixed graphs." - ) + assert result is None, "Any prefill rank should force all ranks to eager mode" diff --git a/tests/unit_tests/inference/test_communication_utils.py b/tests/unit_tests/inference/test_communication_utils.py index 95de6c70560..e0c5a9f734d 100644 --- a/tests/unit_tests/inference/test_communication_utils.py +++ b/tests/unit_tests/inference/test_communication_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ import pytest import torch import torch.distributed as dist @@ -22,6 +24,9 @@ def setup(self): self.size = [16, 8] self.dtype = torch.float32 + def teardown_method(self, method): + Utils.destroy_model_parallel() + @pytest.mark.skipif( not is_torch_min_version("2.4.0"), reason="torch.distributed.init_device_mesh requires torch >= 2.4.0", @@ -65,7 +70,8 @@ def test_broadcast_comparison(self, tp_size, pp_size): assert torch.allclose( tensor_received_global, tensor_received_custom ), "broadcast_from_last_pipeline_stage should be the same with or without custom pp_group" - Utils.destroy_model_parallel() + + grid.destroy() @pytest.mark.skipif( not is_torch_min_version("2.4.0"), @@ -126,4 +132,5 @@ def test_send_recv(self, tp_size, pp_size): assert torch.allclose( local_recv_buffer_global, local_recv_buffer_custom ), "Custom and global recv buffers should be the same." - Utils.destroy_model_parallel() + + grid.destroy() diff --git a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py index b2e94bc54f9..5f1aeca5b13 100644 --- a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py +++ b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py @@ -118,6 +118,7 @@ def __init__(self): self.use_coordinator = False self.ep_world_size = 1 + self.disable_ep_consensus = False self.step_start_event = unittest.mock.MagicMock() self.step_end_event = unittest.mock.MagicMock() @@ -406,6 +407,69 @@ async def test_parallel_configs( finally: await cleanup_engine(engine, client) + @pytest.mark.internal + @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initialize_model_parallel", + [pytest.param((1, 1, 1), id="tp1-pp1-ep1")], + indirect=["initialize_model_parallel"], + ) + async def test_disable_ep_consensus( + self, initialize_model_parallel, coordinator, test_case_communicator + ): + """With disable_ep_consensus=True, the control loop must call + controller.dummy_forward() on iterations where local_pending == 0 + instead of sleeping, so EP collectives stay in sync. Sleeping here + would deadlock peers running real forwards on EP > 1.""" + dp_addr = coordinator + port = int(dp_addr.rsplit(":", 1)[-1]) + requests = self.build_requests(num_requests=2) + engine = DummyEngine() + engine.disable_ep_consensus = True + engine.controller.dummy_forward = unittest.mock.MagicMock( + wraps=engine.controller.dummy_forward + ) + rank = torch.distributed.get_rank() + client = None + + try: + await engine.start_listening_to_data_parallel_coordinator( + inference_coordinator_port=port, launch_inference_coordinator=False + ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) + + if rank == 0: + client = InferenceClient(dp_addr) + client.start() + await asyncio.wait_for(engine.wait_until(EngineState.RUNNING), timeout=5.0) + + # Idle window: with no work, the loop must spin on dummy_forward, + # not sleep. Several iterations should fire within 0.2s. + idle_baseline = engine.controller.dummy_forward.call_count + await asyncio.sleep(0.2) + idle_calls = engine.controller.dummy_forward.call_count - idle_baseline + assert idle_calls > 0, ( + "disable_ep_consensus must call dummy_forward on idle iterations " + f"to keep EP collectives in sync (call_count={idle_calls})" + ) + + # Submit and complete requests to confirm the step path still works. 
+ futures = [client.add_request(prompt=p, sampling_params=s) for p, s in requests] + results = await asyncio.wait_for(asyncio.gather(*futures), timeout=5.0) + for result in results: + assert result["status"] == Status.COMPLETED.name + + # Pause/unpause must still drive state transitions correctly. + client.pause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.PAUSED), timeout=5.0) + client.unpause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.RUNNING), timeout=5.0) + + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) + finally: + await cleanup_engine(engine, client) + @pytest.mark.internal @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") @pytest.mark.asyncio diff --git a/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py b/tests/unit_tests/inference/test_hybrid_moe.py similarity index 58% rename from tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py rename to tests/unit_tests/inference/test_hybrid_moe.py index 6fb3df43ad5..c3820521ba2 100644 --- a/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py +++ b/tests/unit_tests/inference/test_hybrid_moe.py @@ -20,20 +20,23 @@ import pytest import torch -import torch.distributed as dist from megatron.core import parallel_state from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions from megatron.core.inference.config import InferenceConfig, MambaInferenceStateConfig from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext -from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.inference.symmetric_memory import SymmetricMemoryManager +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_inference_stack_spec from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord, delete_cuda_graphs -from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.moe.token_dispatcher_inference import NVLSAllGatherVDispatcher from megatron.core.utils import is_fa_min_version +from tests.unit_tests.inference.test_moe_dispatching_and_routing import ( + NANOV3_BASE, + _make_base_config, +) from tests.unit_tests.test_utilities import Utils # Request state constants for parametrized tests. @@ -41,8 +44,13 @@ DECODE = "decode" # >0 decode, 0 prefill PREFILL = "prefill" # 0 decode, >0 prefill MIXED = "mixed" # >0 decode, >0 prefill - -ALL_STATES = [NONE, DECODE, PREFILL, MIXED] +PREFILL_AT_MAX_TOKENS = "prefill_max_tokens" +DECODE_AT_MAX_REQUESTS = "decode_max_requests" +MIXED_GIANT_PREFILL = ( + "mixed_giant_prefill" # (max_requests-1) decode + 1 prefill with tokens > max_requests +) +_NO_CUDA_GRAPH_STATES = {PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL} +ALL_STATES = [NONE, DECODE, PREFILL, MIXED, PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL] # Fixed expert-parallel size. When world_size > _EP_SIZE the remaining # ranks form data-parallel replicas, each running the same EP combo @@ -53,9 +61,10 @@ # across the EP ranks. 
Since rank assignment is symmetric (shuffling ranks # with the same multiset of states is not a distinct configuration), we use # combinations_with_replacement rather than the full Cartesian product. -# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases. +# For _EP_SIZE=4 this gives C(6+4-1, 4) = 126 test cases. _STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE)) + # Batch dimensions used to set up each non-dummy state via # add_dummy_requests_for_cudagraph_capture. These are intentionally small # to keep the tests fast while still exercising the EP padding logic. @@ -66,63 +75,81 @@ PREFILL: InferenceBatchDimensions(token_count=32, prefill_req_count=2, decode_req_count=0), # 4 decode (4 tokens) + 2 prefill (60 tokens) = 64 tokens MIXED: InferenceBatchDimensions(token_count=64, prefill_req_count=2, decode_req_count=4), + PREFILL_AT_MAX_TOKENS: InferenceBatchDimensions( + token_count=512, prefill_req_count=1, decode_req_count=0 + ), + DECODE_AT_MAX_REQUESTS: InferenceBatchDimensions( + token_count=64, prefill_req_count=0, decode_req_count=64 + ), + # 63 decode (1 token each) + 1 prefill (65 tokens) = 128 tokens; prefill tokens > max_requests=64 + MIXED_GIANT_PREFILL: InferenceBatchDimensions( + token_count=128, prefill_req_count=1, decode_req_count=63 + ), } -@pytest.mark.internal -class TestDynamicInference: - """Verify full HybridModel output shapes under EP strict matching scenarios.""" +def setup_module(module): + available, reason = _check_mamba_sequence_packing_support(for_inference_not_training=True) + if not available: + pytest.skip(reason, allow_module_level=True) + if not is_fa_min_version("2.7.3"): + pytest.skip("need flash-attn >= 2.7.3 for dynamic batching", allow_module_level=True) + if Utils.world_size < _EP_SIZE: + pytest.skip(f"EP test requires at least {_EP_SIZE} GPUs", allow_module_level=True) + if Utils.world_size % _EP_SIZE != 0: + pytest.skip( + f"world_size ({Utils.world_size}) must be divisible by EP size ({_EP_SIZE})", + allow_module_level=True, + ) + try: + from megatron.core.inference.utils import check_flashinfer_jit_cache_installed + + check_flashinfer_jit_cache_installed() + except RuntimeError as exc: + pytest.skip(str(exc), allow_module_level=True) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=_EP_SIZE, + ) - HIDDEN_SIZE = 256 - NUM_ATTN_HEADS = 4 - MAX_SEQ_LEN = 512 - VOCAB_SIZE = 128 - def setup_method(self, method): - available, reason = _check_mamba_sequence_packing_support(for_inference_not_training=True) - if not available: - pytest.skip(reason, allow_module_level=True) - if not is_fa_min_version("2.7.3"): - pytest.skip("need flash-attn >= 2.7.3 for dynamic batching", allow_module_level=True) - if Utils.world_size < _EP_SIZE: - pytest.skip(f"EP test requires at least {_EP_SIZE} GPUs", allow_module_level=True) - if Utils.world_size % _EP_SIZE != 0: - pytest.skip( - f"world_size ({Utils.world_size}) must be divisible by EP size ({_EP_SIZE})", - allow_module_level=True, - ) +def teardown_module(module): + NVLSAllGatherVDispatcher._delete_buffers() + SymmetricMemoryManager.destroy() + Utils.destroy_model_parallel() - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=_EP_SIZE, - ) + +class _TestDynamicInferenceBase: + """Shared helpers for NVLS and NCCL inference test classes. 
+ + Model-parallel is initialized once for the entire module (see setup_module / + teardown_module above) so that NVLS sym-mem handles are never alive across a + destroy/reinit cycle of the EP process group. + """ + + MAX_SEQ_LEN = 512 + VOCAB_SIZE = 128 def teardown_method(self, method): + # CUDA-graph replay is asynchronous at the CPU level. Synchronize device + # then barrier so no rank races into the next test's collectives while + # another rank is still executing the previous step on the GPU. + torch.cuda.synchronize() + torch.distributed.barrier() delete_cuda_graphs() - Utils.destroy_model_parallel() - def _build_model(self): + def _build_model(self, inference_moe_token_dispatcher_type='nvls'): model_parallel_cuda_manual_seed(123, inference_rng_tracker=True, force_reset_rng=True) - config = TransformerConfig( - num_layers=3, - mtp_hybrid_override_pattern="ME*", - hidden_size=self.HIDDEN_SIZE, - num_attention_heads=self.NUM_ATTN_HEADS, - use_cpu_initialization=True, - params_dtype=torch.bfloat16, - bf16=True, - attention_backend=AttnBackend.fused, - num_moe_experts=2, - moe_token_dispatcher_type="alltoall", - cuda_graph_impl="local", + config = _make_base_config( + num_layers=3, inference_moe_token_dispatcher_type=inference_moe_token_dispatcher_type ) model = HybridModel( config=config, - hybrid_stack_spec=hybrid_stack_spec, + hybrid_stack_spec=hybrid_inference_stack_spec, vocab_size=self.VOCAB_SIZE, max_sequence_length=self.MAX_SEQ_LEN, - hybrid_layer_pattern="M*", + hybrid_layer_pattern="ME*", ) model.cuda() model.eval() @@ -135,6 +162,7 @@ def _build_context( num_cuda_graphs=16, use_cuda_graphs_for_non_decode_steps=True, max_requests=None, + max_tokens=None, ): mamba_config = MambaInferenceStateConfig.from_model(model) return DynamicInferenceContext( @@ -148,15 +176,16 @@ def _build_context( num_cuda_graphs=num_cuda_graphs, use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps, max_requests=max_requests, + max_tokens=max_tokens, ), ) + @torch.inference_mode() def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label): - """Run model.forward and assert the logits shape matches - padded_batch_dimensions.token_count.""" + """Run model and assert the logits shape matches padded_batch_dimensions.token_count.""" padded = ctx.padded_batch_dimensions input_ids = torch.randint(0, self.VOCAB_SIZE, (1, padded.token_count), device="cuda") - out = model.forward( + out = model( input_ids=input_ids, position_ids=None, attention_mask=None, @@ -169,6 +198,33 @@ def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label): f"got {tuple(out.shape)}" ) + @torch.inference_mode() + def _capture_all_cuda_graphs(self, model, ctx): + """Pre-capture all cuda graphs in lockstep across EP ranks. + + Mirrors DynamicInferenceEngine.create_cuda_graphs(): iterates every + shape in cuda_graph_batch_dimensions_list and runs a forward pass + with all EP ranks at the matching shape. After this, every rank's + model.forward goes through the replay path (no warmup loop), so + capture-mode and eager-mode ranks emit the same number of EP + collectives per call. Without this, the first forward triggers + cuda_graphs.create_fwd_graph's warmup loop on capture-mode ranks + only, deadlocking against eager-mode peers in mixed combos. 
+ """ + for graph_dim in ctx.cuda_graph_batch_dimensions_list: + ctx.reset() + ctx.initialize_attention_state(construct_graph_dimensions=graph_dim) + padded = ctx.padded_batch_dimensions + input_ids = torch.randint(0, self.VOCAB_SIZE, (1, padded.token_count), device="cuda") + model( + input_ids=input_ids, + position_ids=None, + attention_mask=None, + inference_context=ctx, + runtime_gather_output=True, + ) + ctx.reset() + @staticmethod def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label): """Assert that CUDA graphs were (or were not) recorded and replayed @@ -197,19 +253,10 @@ def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label): f"but cudagraph_inference_record has {len(record)} entries" ) - def _assert_dummy_forward_shape(self, model, rank): - """Run model.forward with a single dummy token (no inference context), - mirroring the real engine's dummy_forward fallback, and verify the - logits shape.""" - tp_size = parallel_state.get_tensor_model_parallel_world_size() - dummy_tokens = torch.zeros(1, tp_size, dtype=torch.long, device="cuda") - position_ids = torch.zeros(1, tp_size, dtype=torch.long, device="cuda") - out = model.forward(input_ids=dummy_tokens, position_ids=position_ids, attention_mask=None) - expected = (1, tp_size, self.VOCAB_SIZE) - assert out.shape == expected, ( - f"Rank {rank} (dummy bail-out): expected out shape " - f"{expected}, got {tuple(out.shape)}" - ) + +@pytest.mark.internal +class TestDynamicInferenceNVLS(_TestDynamicInferenceBase): + """NVLS dispatcher: combinatorial sweep of EP request states.""" # ------------------------------------------------------------------ # test_ep_state_cross_product: combinatorial sweep with mixed CUDA graphs @@ -218,14 +265,16 @@ def _assert_dummy_forward_shape(self, model, rank): @pytest.mark.parametrize("rank_states", _STATE_COMBOS, ids=[",".join(s) for s in _STATE_COMBOS]) @pytest.mark.internal @torch.inference_mode() - def test_ep_state_cross_product(self, rank_states): + def test_nvls_ep_state_cross_product(self, rank_states): """Test all combinatorial (unordered, with repetition) assignments of - the four request states across EP ranks. + the request states across EP ranks. - The context is built with use_cuda_graphs_for_non_decode_steps=True, - so the CUDA graph list contains decode-only, mixed, and prefill-only - graphs. After the EP all-reduce in match_graph_config, every rank - (including dummy ranks) should always find a matching graph. + The NVLS dispatcher (match_ep_token_counts=False) does per-rank + independent graph matching. Each rank finds a matching graph for its + own state unless its token count exceeds the cuda-graph range + (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), in which case that + rank falls back to eager. The AllGather-V dispatcher handles per-rank + size variation, so mixed graph/eager combos work. State setup uses add_dummy_requests_for_cudagraph_capture to populate the context directly with the desired request configuration. @@ -235,49 +284,56 @@ def test_ep_state_cross_product(self, rank_states): is_dummy = my_state == NONE model = self._build_model() - ctx = self._build_context(model) + ctx = self._build_context(model, max_requests=64, max_tokens=512) + + # Pre-capture every cuda graph in lockstep across EP ranks (mirrors + # DynamicInferenceEngine.create_cuda_graphs in production). 
Without + # this, the first per-rank forward triggers create_fwd_graph's + # warmup loop on capture-mode ranks only, which fires extra EP + # collectives that an eager peer cannot match — deadlock. + self._capture_all_cuda_graphs(model, ctx) # Phase 1: Set up each rank's request state directly. if not is_dummy: ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state]) - # Phase 2: Initialize attention state (EP collective). + # Phase 2: Initialize attention state (no EP collective with NVLS). if is_dummy: ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=True) else: ctx.initialize_attention_state() # Phase 3: Verify. - # With mixed CUDA graphs available, every rank — including dummy - # ranks whose EP-adjusted dimensions inherit prefill/decode counts - # from peers — must find a matching graph. - assert ctx.using_cuda_graph_this_step(), ( - f"EP rank {ep_rank} (state={my_state}): expected a CUDA graph match " - f"with use_cuda_graphs_for_non_decode_steps=True " - f"(rank_states={rank_states})" - ) - - # All EP ranks must agree on padded token count. - padded = ctx.padded_batch_dimensions - ep_group = parallel_state.get_expert_model_parallel_group() - tc = torch.tensor([padded.token_count], dtype=torch.int32, device="cuda") - tc_max = tc.clone() - tc_min = tc.clone() - dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) - dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) - assert tc_max.item() == tc_min.item(), ( - f"Padded token count mismatch across EP ranks: " - f"min={tc_min.item()}, max={tc_max.item()} " - f"(rank_states={rank_states})" - ) + # With the NVLS dispatcher each rank matches independently, so every + # rank must find a graph for its own state. The exceptions are the + # _NO_CUDA_GRAPH_STATES (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), + # whose token counts exceed the max cuda-graph size and therefore + # fall back to eager. + if my_state in _NO_CUDA_GRAPH_STATES: + assert not ctx.using_cuda_graph_this_step(), ( + f"EP rank {ep_rank} (state={my_state}): expected no CUDA graph match " + f"(token_count exceeds cuda-graph range) " + f"(rank_states={rank_states})" + ) + else: + assert ctx.using_cuda_graph_this_step(), ( + f"EP rank {ep_rank} (state={my_state}): expected a CUDA graph match " + f"with use_cuda_graphs_for_non_decode_steps=True " + f"(rank_states={rank_states})" + ) self._assert_dynamic_inference_shape(model, ctx, ep_rank, my_state) - self._assert_cuda_graphs_were_replayed( - True, ep_rank, f"state={my_state}, rank_states={rank_states}" - ) + + if my_state not in _NO_CUDA_GRAPH_STATES: + self._assert_cuda_graphs_were_replayed( + True, ep_rank, f"state={my_state}, rank_states={rank_states}" + ) + + +@pytest.mark.internal +class TestDynamicInferenceNCCL(_TestDynamicInferenceBase): + """NCCL dispatcher: dummy-rank bail-out and eager-fallback tests.""" # ------------------------------------------------------------------ - # test_dummy_bailout_with_decode_only_cuda_graphs: dedicated bail-out + # CUDA-graph bail-out tests for the NCCLAllGatherDispatcher # ------------------------------------------------------------------ @pytest.mark.parametrize( ) @pytest.mark.internal @torch.inference_mode() - def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state): - """Verify the dummy-rank bail-out path when only decode CUDA graphs - are available. + def test_nccl_dummy_bailout_with_prefill_peer(self, peer_state): + """Verify the dummy-rank bail-out path with the NCCL dispatcher.
- With use_cuda_graphs_for_non_decode_steps=False, the CUDA graph list - contains only decode-only graphs. When any EP rank has prefill - requests, adjust_batch_dims_for_expert_parallelism returns None - (forcing eager mode), and match_graph_config returns None for all - ranks. A dummy rank then bails out of initialize_attention_state - early (padded_batch_dimensions is not set). + With the NCCL dispatcher (match_ep_token_counts=True), when any EP + rank has prefill requests, adjust_batch_dims_for_expert_parallelism + returns None (forcing eager mode) for ALL ranks. A dummy rank then + bails out of initialize_attention_state early (padded_batch_dimensions + is not set). This test verifies that: - The dummy rank correctly falls back to model.forward (the real @@ -308,7 +362,7 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state): ep_rank = parallel_state.get_expert_model_parallel_rank() is_even = ep_rank % 2 == 0 - model = self._build_model() + model = self._build_model(inference_moe_token_dispatcher_type='nccl') ctx = self._build_context(model, use_cuda_graphs_for_non_decode_steps=False) # Set up request state. @@ -329,8 +383,9 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state): ) if is_even: - # Dummy rank bailed out — exercise the eager fallback. - self._assert_dummy_forward_shape(model, ep_rank) + # Dummy rank: context has one dummy decode request. Mimic the real + # engine's dummy_forward, which always uses inference_context. + self._assert_dynamic_inference_shape(model, ctx, ep_rank, "dummy") else: # Non-dummy rank: padded_batch_dimensions is set via the # non-graph fallback path in initialize_attention_state. @@ -348,19 +403,14 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state): ) @pytest.mark.internal @torch.inference_mode() - def test_mixed_cuda_graphs_tokens_exceed_max_requests(self, peer_state): - """Verify eager fallback when mixed CUDA graphs are allowed but - a rank's token count exceeds the CUDA graph capacity. - - With use_cuda_graphs_for_non_decode_steps=True, the CUDA graph - list includes mixed and prefill-only graphs. However, the - maximum CUDA graph token capacity is bounded by max_requests - (specifically, max_requests * (num_speculative_tokens + 1)). - - When one EP rank has a token count exceeding this capacity, no - CUDA graph can accommodate the EP-adjusted dimensions. - match_graph_config returns None for all ranks, forcing eager - mode globally. This test verifies that: + def test_nccl_eager_fallback_when_tokens_exceed_capacity(self, peer_state): + """Verify eager fallback with the NCCL dispatcher when a rank's token + count exceeds the CUDA graph capacity. + + With the NCCL dispatcher (match_ep_token_counts=True), the EP all-reduce + propagates the oversized token count to all ranks. Since no CUDA graph + can accommodate it, match_graph_config returns None for all ranks, + forcing eager mode globally. This test verifies that: - No rank matches a CUDA graph (eager mode is forced). - Dummy ranks bail out and produce correct shapes via the eager dummy_forward path. 
@@ -370,7 +420,7 @@ def test_mixed_cuda_graphs_tokens_exceed_max_requests(self, peer_state): ep_rank = parallel_state.get_expert_model_parallel_rank() is_even = ep_rank % 2 == 0 - model = self._build_model() + model = self._build_model(inference_moe_token_dispatcher_type='nccl') # Use a small max_requests so that the CUDA graph capacity # (max_requests tokens with no speculative decoding) is easily @@ -410,8 +460,9 @@ def test_mixed_cuda_graphs_tokens_exceed_max_requests(self, peer_state): ) if is_even: - # Dummy rank bailed out — exercise the eager fallback. - self._assert_dummy_forward_shape(model, ep_rank) + # Dummy rank: context has one dummy decode request. Mimic the real + # engine's dummy_forward, which always uses inference_context. + self._assert_dynamic_inference_shape(model, ctx, ep_rank, "dummy") else: # Non-dummy rank: padded_batch_dimensions is set via the # eager fallback path. Verify shape correctness. diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_dispatching_and_routing.py similarity index 56% rename from tests/unit_tests/inference/test_moe_inference.py rename to tests/unit_tests/inference/test_moe_dispatching_and_routing.py index b762b5e638c..a9caa12f178 100644 --- a/tests/unit_tests/inference/test_moe_inference.py +++ b/tests/unit_tests/inference/test_moe_dispatching_and_routing.py @@ -10,11 +10,14 @@ - shared experts """ +import gc + import pytest import torch from megatron.core.activations import squared_relu from megatron.core.inference.communication.torch_symm_triton import are_tensors_nvls_eligible +from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_te_min_version, is_torch_min_version from megatron.training.initialize import _set_random_seed @@ -34,7 +37,7 @@ # ────────────────────────────────────────────────────────────────────── NANOV3_BASE = dict( - num_layers=1, + num_layers=4, hidden_size=128, ffn_hidden_size=128, num_attention_heads=4, @@ -57,6 +60,16 @@ bf16=True, params_dtype=torch.bfloat16, transformer_impl="inference_optimized", + expert_tensor_parallel_size=1, + use_cpu_initialization=True, + attention_backend=AttnBackend.local, + cuda_graph_impl="local", + cuda_graph_scope="full_iteration_inference", + moe_pad_experts_for_cuda_graph_inference=False, + mamba_state_dim=128, + mamba_head_dim=64, + mamba_num_groups=8, + mamba_num_heads=64, ) @@ -124,17 +137,6 @@ def test_init_accepts_valid_score_function(self, score_fn): ) assert router is not None - def test_set_unset_inference_mode(self): - """Toggle is_inference_cuda_graphed_iteration flag.""" - router = self._make_router() - assert not router.is_inference_cuda_graphed_iteration - - router.set_inference_cuda_graphed_iteration() - assert router.is_inference_cuda_graphed_iteration - - router.unset_inference_cuda_graphed_iteration() - assert not router.is_inference_cuda_graphed_iteration - def test_training_mode_forward_returns_sparse(self): """In training mode, forward delegates to parent and returns sparse tensors.""" router = self._make_router() @@ -172,7 +174,6 @@ def test_inference_vs_training_selects_same_experts(self): # Inference mode: get top_indices (dense) router.eval() - router.set_inference_cuda_graphed_iteration() _, top_indices = router(input_tensor.clone()) inference_experts = set() @@ -185,12 +186,12 @@ def test_inference_vs_training_selects_same_experts(self): # 
────────────────────────────────────────────────────────────────────── -# InferenceCUDAGraphTokenDispatcher +# NCCLAllGatherDispatcher # ────────────────────────────────────────────────────────────────────── @pytest.mark.internal -class TestInferenceCUDAGraphTokenDispatcher: +class TestNCCLAllGatherDispatcher: @classmethod def setup_class(cls): @@ -199,28 +200,28 @@ def setup_class(cls): @classmethod def teardown_class(cls): - from megatron.core.inference.symmetric_memory import SymmetricMemoryManager - - SymmetricMemoryManager.destroy() Utils.destroy_model_parallel() + def teardown_method(self, method): + gc.collect() + torch.cuda.empty_cache() + def _make_dispatcher(self, **config_overrides): from megatron.core.transformer.moe.moe_utils import get_default_pg_collection - from megatron.core.transformer.moe.token_dispatcher_inference import ( - InferenceCUDAGraphTokenDispatcher, - ) + from megatron.core.transformer.moe.token_dispatcher_inference import NCCLAllGatherDispatcher + NCCLAllGatherDispatcher.allocate_buffers() config_overrides.setdefault("expert_model_parallel_size", Utils.world_size) config = _make_base_config(**config_overrides) num_local_experts = config.num_moe_experts // Utils.world_size ep_rank = torch.distributed.get_rank() if Utils.world_size > 1 else 0 local_expert_indices = [ep_rank * num_local_experts + i for i in range(num_local_experts)] - - return InferenceCUDAGraphTokenDispatcher( + return NCCLAllGatherDispatcher( num_local_experts=num_local_experts, local_expert_indices=local_expert_indices, config=config, pg_collection=get_default_pg_collection(), + runs_metadata_sync=True, ) def test_init(self): @@ -229,85 +230,145 @@ def test_init(self): assert dispatcher.topk == NANOV3_BASE["moe_router_topk"] assert dispatcher.ep_size == Utils.world_size - def test_symmetric_memory_buffer_initialized(self): - """EP symmetric memory buffer is lazily created via SymmetricMemoryManager.""" + @pytest.mark.parametrize("use_allgather_v", [False, True]) + def test_dispatch_combine(self, use_allgather_v): + """Dispatch+combine correctness for both CG (equal-count) and prefill (variable-count) paths. + + All ranks share the same global reference tensors (broadcast from rank 0). 
+ Each rank contributes its own slice, then we verify: + - dispatch gathers all slices back to the full global tensor + - combine reduce-scatters the gathered data, giving each rank ep_size * its own slice + """ + from megatron.core.transformer.moe.token_dispatcher_inference import NCCLAllGatherDispatcher + + if use_allgather_v and Utils.world_size == 1: + pytest.skip("Variable-token prefill path requires EP > 1") + + dispatcher = self._make_dispatcher() + ep_size = dispatcher.ep_size + rank = torch.distributed.get_rank() if ep_size > 1 else 0 + hidden_size = NANOV3_BASE["hidden_size"] + topk = NANOV3_BASE["moe_router_topk"] + num_experts = NANOV3_BASE["num_moe_experts"] + + if use_allgather_v: + # Variable token counts: rank r contributes (r+1)*8 tokens + tokens_per_rank = [(r + 1) * 8 for r in range(ep_size)] + else: + tokens_per_rank = [16] * ep_size + + local_tokens = tokens_per_rank[rank] + total_tokens = sum(tokens_per_rank) + + NCCLAllGatherDispatcher._use_allgather_v = use_allgather_v + + global_hidden = torch.randn(total_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + global_probs = torch.randn(total_tokens, topk, device="cuda", dtype=torch.float32) + global_routing_map = torch.randint(0, num_experts, (total_tokens, topk), device="cuda") + if ep_size > 1: + torch.distributed.broadcast(global_hidden, src=0) + torch.distributed.broadcast(global_probs, src=0) + torch.distributed.broadcast(global_routing_map, src=0) + + offset = sum(tokens_per_rank[:rank]) + hidden = global_hidden[offset : offset + local_tokens].contiguous() + probs = global_probs[offset : offset + local_tokens].contiguous() + dispatcher.routing_map = global_routing_map[offset : offset + local_tokens].contiguous() + + if ep_size == 1: + d_hidden, d_probs = dispatcher.token_dispatch(hidden, probs) + assert d_hidden is hidden + assert d_probs is probs + return + + d_hidden, d_probs = dispatcher.token_dispatch(hidden, probs) + + assert d_hidden.shape == (total_tokens, hidden_size) + assert d_probs.shape == (total_tokens, topk) + torch.testing.assert_close(d_hidden, global_hidden, atol=0, rtol=0) + torch.testing.assert_close(d_probs, global_probs, atol=0, rtol=0) + torch.testing.assert_close(dispatcher.routing_map, global_routing_map, atol=0, rtol=0) + + # All ranks have identical gathered data, so rank r's reduce-scatter output + # is ep_size * its slice of global_hidden. 
+ combined = dispatcher.token_combine(d_hidden) + assert combined.shape == (local_tokens, hidden_size) + expected = (global_hidden[offset : offset + local_tokens].float() * ep_size).bfloat16() + torch.testing.assert_close(combined, expected) + + +# ────────────────────────────────────────────────────────────────────── +# NVLSAllGatherVDispatcher +# ────────────────────────────────────────────────────────────────────── + +_NVLS_ENGINE_MAX_TOKENS = 512 + + +@pytest.mark.internal +class TestNVLSAllGatherVDispatcher: + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel(1, 1, expert_model_parallel_size=Utils.world_size) + _set_random_seed(seed_=123, data_parallel_random_init=False) + + @classmethod + def teardown_class(cls): from megatron.core.inference.symmetric_memory import SymmetricMemoryManager - # Buffer should not exist yet (lazy init) - assert not SymmetricMemoryManager.is_initialized("ep") + SymmetricMemoryManager.destroy() + Utils.destroy_model_parallel() - # Create it explicitly and verify + def _make_dispatcher(self): from megatron.core.parallel_state import get_expert_model_parallel_group + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + from megatron.core.transformer.moe.token_dispatcher_inference import ( + NVLSAllGatherVDispatcher, + ) + + config = _make_base_config(expert_model_parallel_size=Utils.world_size) + num_local_experts = config.num_moe_experts // Utils.world_size + ep_rank = torch.distributed.get_rank() if Utils.world_size > 1 else 0 + local_expert_indices = [ep_rank * num_local_experts + i for i in range(num_local_experts)] + ep_group = get_expert_model_parallel_group() - buf = SymmetricMemoryManager.get_buffer( - "ep", process_group=get_expert_model_parallel_group() + NVLSAllGatherVDispatcher.allocate_buffers( + per_rank_worst_case_token_count=_NVLS_ENGINE_MAX_TOKENS, + topk=NANOV3_BASE["moe_router_topk"], + hidden_size=NANOV3_BASE["hidden_size"], + ep_group=ep_group, ) - assert buf is not None - assert SymmetricMemoryManager.is_initialized("ep") + + return NVLSAllGatherVDispatcher( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=get_default_pg_collection(), + runs_metadata_sync=True, + ) + + def test_init(self): + """Dispatcher can be constructed with nanov3-like config and EP=world_size.""" + dispatcher = self._make_dispatcher() + assert dispatcher.topk == NANOV3_BASE["moe_router_topk"] + assert dispatcher.ep_size == Utils.world_size @pytest.mark.parametrize("seed", [42, 123, 7]) @pytest.mark.parametrize( - "num_local_tokens", - [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 40, - 48, - 56, - 64, - 72, - 80, - 88, - 96, - 104, - 112, - 120, - 128, - 136, - 144, - 152, - 160, - 168, - 176, - 184, - 192, - 200, - 208, - 216, - 224, - 232, - 240, - 248, - 256, - 272, - 288, - 304, - 320, - 336, - 352, - 368, - 384, - 400, - 416, - 432, - 448, - 464, - 480, - 496, - 512, - ], + "max_rank_tokens", + # Covers: small, unaligned, power-of-2, and large up to engine_max + [1, 7, 16, 24, 64, 128, 256, 512], ) - def test_cuda_graph_dispatch_combine(self, num_local_tokens, seed): + def test_cuda_graph_dispatch_combine(self, max_rank_tokens, seed): """Dispatch+combine can be captured in a CUDA graph and replayed. - Creates global buffers, shards per rank, and verifies: - - NVLS AllGather output matches the full globalwol buffer - - NVLS ReduceScatter output matches fp32-accumulated reference - All tensor byte sizes are 128-bit aligned for NVLS eligibility. 
+ + Uses uneven token counts across EP ranks (rank r gets + max(1, max_rank_tokens + r - (ep_size - 1)) tokens) to exercise the + AllGatherV variable-length path. Verifies: + - AllGatherV output matches the global reference (valid prefix only) + - ReduceScatterV output matches fp32-accumulated reference + Exact match (atol=0) is possible because the NVLS triton kernels + accumulate in fp32 before writing bf16 output. """ torch.manual_seed(seed) @@ -319,41 +380,38 @@ def test_cuda_graph_dispatch_combine(self, num_local_tokens, seed): topk = NANOV3_BASE["moe_router_topk"] num_experts = NANOV3_BASE["num_moe_experts"] rank = torch.distributed.get_rank() if ep_size > 1 else 0 - num_global_tokens = num_local_tokens * ep_size - # Create global buffers on rank 0 and broadcast to all ranks - global_hidden = torch.randn( - num_global_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 - ) - global_probs = torch.randn(num_global_tokens, topk, device="cuda", dtype=torch.float32) - global_routing_map = torch.randint(0, num_experts, (num_global_tokens, topk), device="cuda") - torch.distributed.broadcast(global_hidden, src=0) - torch.distributed.broadcast(global_probs, src=0) - torch.distributed.broadcast(global_routing_map, src=0) - - # Each rank grabs their own shard - start = rank * num_local_tokens - end = start + num_local_tokens + # Uneven token counts: rank r gets max(1, max_rank_tokens + r - (ep_size-1)) + # so the last rank always has max_rank_tokens (≤ engine_max) and earlier + # ranks have fewer, exercising the variable-length AllGatherV path. + tokens_per_rank = [max(1, max_rank_tokens + r - (ep_size - 1)) for r in range(ep_size)] + local_tokens = tokens_per_rank[rank] + total_tokens = sum(tokens_per_rank) + global_max = _NVLS_ENGINE_MAX_TOKENS * ep_size + + global_hidden = torch.randn(total_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + global_probs = torch.randn(total_tokens, topk, device="cuda", dtype=torch.float32) + global_routing_map = torch.randint(0, num_experts, (total_tokens, topk), device="cuda") + if ep_size > 1: + torch.distributed.broadcast(global_hidden, src=0) + torch.distributed.broadcast(global_probs, src=0) + torch.distributed.broadcast(global_routing_map, src=0) + + start = sum(tokens_per_rank[:rank]) + end = start + local_tokens static_hidden = global_hidden[start:end].contiguous() static_probs = global_probs[start:end].contiguous() static_routing_map = global_routing_map[start:end].contiguous() - if not are_tensors_nvls_eligible(static_hidden, static_probs, static_routing_map): - pytest.skip( - "Tensors are not NVLS-eligible (need SM>=9 and each tensor's memory to be a multiple of 16 bytes)" - ) - - # 3 warmup iterations on a side stream + # Warmup on a side stream with torch.no_grad(): s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): dispatcher.routing_map = static_routing_map + dispatcher._local_tokens = local_tokens d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs) - d_hidden = d_hidden.clone() - d_probs = d_probs.clone() - dispatcher.routing_map = dispatcher.routing_map.clone() dispatcher.token_combine(d_hidden.clone()) torch.cuda.current_stream().wait_stream(s) @@ -361,28 +419,23 @@ def test_cuda_graph_dispatch_combine(self, num_local_tokens, seed): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): dispatcher.routing_map = static_routing_map + dispatcher._local_tokens = local_tokens d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, 
static_probs) - graph_hidden = d_hidden.clone() - graph_probs = d_probs.clone() - graph_routing_map = dispatcher.routing_map.clone() + graph_hidden = d_hidden[:total_tokens].clone() + graph_probs = d_probs[:total_tokens].clone() + graph_routing_map = dispatcher.routing_map[:total_tokens].clone() graph_combined = dispatcher.token_combine(d_hidden.clone()) - # Verify shapes: dispatch expands by ep_size, combine shrinks back - assert graph_hidden.shape == (num_global_tokens, hidden_size) - assert graph_probs.shape == (num_global_tokens, topk) - assert graph_combined.shape == (num_local_tokens, hidden_size) + # dispatch output is (global_max, *); only first total_tokens are valid + assert d_hidden.shape == (global_max, hidden_size) + assert d_probs.shape == (global_max, topk) + assert graph_combined.shape == (local_tokens, hidden_size) - # Replay graph.replay() - # Verify AllGather: all gathered tensors should match global buffers torch.testing.assert_close(graph_hidden, global_hidden, atol=0, rtol=0) torch.testing.assert_close(graph_probs, global_probs, atol=0, rtol=0) torch.testing.assert_close(graph_routing_map, global_routing_map, atol=0, rtol=0) - # Verify ReduceScatter: all ranks have the same all-gathered data, so - # rank r gets ep_size * chunk_r. Compute reference in fp32 then downcast. - # Exact match (atol=0, rtol=0) is possible because the NVLS triton kernels - # accumulate in fp32 before writing bf16 output. expected_combined = (global_hidden[start:end].float() * ep_size).bfloat16() torch.testing.assert_close(graph_combined, expected_combined, atol=0, rtol=0) diff --git a/tests/unit_tests/inference/test_moe_permute.py b/tests/unit_tests/inference/test_moe_permute.py index 4664d0fa2cd..6bddf515b14 100644 --- a/tests/unit_tests/inference/test_moe_permute.py +++ b/tests/unit_tests/inference/test_moe_permute.py @@ -35,6 +35,11 @@ def _ref_expert_offsets(tokens_per_expert, alignment): return exc.to(torch.int32), inc.to(torch.int32) +def _vt(n): + """Create a valid_tokens scalar int32 CUDA tensor.""" + return torch.tensor(n, dtype=torch.int32, device="cuda") + + def _make_inputs(num_tokens, hidden_dim, topk, num_experts, seed=42): """Create random hidden states, probs, and routing_map.""" torch.manual_seed(seed) @@ -47,6 +52,7 @@ def _make_inputs(num_tokens, hidden_dim, topk, num_experts, seed=42): @pytest.mark.internal class TestComputeLocalTokensPerExpert: + @pytest.mark.parametrize("persistent", [False, True]) @pytest.mark.parametrize("num_tokens", [1, 4, 16, 64, 128, 256, 512]) @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) @pytest.mark.parametrize( @@ -64,34 +70,41 @@ class TestComputeLocalTokensPerExpert: (128, 32, 96), # 128 experts, last 32 local (EP=4, rank 3) ], ) - def test_matches_reference(self, num_tokens, topk, num_experts, num_local, start): + def test_matches_reference(self, num_tokens, topk, num_experts, num_local, start, persistent): from megatron.core.inference.moe.permute import compute_local_tokens_per_expert routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") - result = compute_local_tokens_per_expert(routing_map, start, num_local) + result = compute_local_tokens_per_expert( + routing_map, start, num_local, _vt(num_tokens), persistent=persistent + ) expected = _ref_tokens_per_expert(routing_map, start, num_local) torch.testing.assert_close(result, expected, atol=0, rtol=0) - def test_no_local_tokens(self): + @pytest.mark.parametrize("persistent", [False, True]) + def test_no_local_tokens(self, persistent): """All tokens routed to 
non-local experts -> all zeros.""" from megatron.core.inference.moe.permute import compute_local_tokens_per_expert routing_map = torch.full((16, 4), 99, dtype=torch.int64, device="cuda") - result = compute_local_tokens_per_expert(routing_map, 0, 8) + result = compute_local_tokens_per_expert(routing_map, 0, 8, _vt(16), persistent=persistent) assert result.sum().item() == 0 - def test_single_expert_all_tokens(self): + @pytest.mark.parametrize("persistent", [False, True]) + def test_single_expert_all_tokens(self, persistent): """All token-topk pairs route to a single local expert.""" from megatron.core.inference.moe.permute import compute_local_tokens_per_expert num_tokens, topk, num_local = 32, 4, 8 routing_map = torch.full((num_tokens, topk), 3, dtype=torch.int64, device="cuda") - result = compute_local_tokens_per_expert(routing_map, 0, num_local) + result = compute_local_tokens_per_expert( + routing_map, 0, num_local, _vt(num_tokens), persistent=persistent + ) assert result[3].item() == num_tokens * topk assert result.sum().item() == num_tokens * topk + @pytest.mark.parametrize("persistent", [False, True]) @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) - def test_total_count_equals_local_pairs(self, seed): + def test_total_count_equals_local_pairs(self, seed, persistent): """Sum of tokens_per_expert equals total routing pairs hitting local experts.""" from megatron.core.inference.moe.permute import compute_local_tokens_per_expert @@ -99,7 +112,9 @@ def test_total_count_equals_local_pairs(self, seed): num_tokens, topk, num_experts = 64, 6, 16 local_start, num_local = 4, 4 routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") - result = compute_local_tokens_per_expert(routing_map, local_start, num_local) + result = compute_local_tokens_per_expert( + routing_map, local_start, num_local, _vt(num_tokens), persistent=persistent + ) local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) assert result.sum().item() == local_mask.sum().item() @@ -193,11 +208,12 @@ def test_data_integrity(self, num_tokens, hidden_dim, topk, num_experts): hidden, probs, routing_map = _make_inputs(num_tokens, hidden_dim, topk, num_experts) perm_h, perm_p, perm_map, offs = permute_tokens( - hidden, probs, routing_map, 0, num_experts, alignment=1 + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=1 ) - # Check every non-padding row - for i in range(perm_map.shape[0]): + # Only rows [0, n_used) are initialized; the rest are uninitialized padding. 
+ n_used = offs[-1].item() + for i in range(n_used): src = perm_map[i].item() if src < 0: continue @@ -213,7 +229,7 @@ def test_offsets_are_aligned(self, alignment, num_tokens, topk, num_experts): hidden, probs, routing_map = _make_inputs(num_tokens, 128, topk, num_experts) _, _, _, offs = permute_tokens( - hidden, probs, routing_map, 0, num_experts, alignment=alignment + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=alignment ) if alignment > 1: for i in range(offs.shape[0]): @@ -230,11 +246,13 @@ def test_padding_rows_have_neg1(self, num_tokens, topk, num_experts, alignment): from megatron.core.inference.moe.permute import permute_tokens hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) - _, _, perm_map, _ = permute_tokens( - hidden, probs, routing_map, 0, num_experts, alignment=alignment + _, _, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=alignment ) - padding_mask = perm_map == -1 - real_mask = perm_map >= 0 + # Only [0, n_used) is initialized; beyond that is uninitialized padding. + perm_map_used = perm_map[: offs[-1].item()] + padding_mask = perm_map_used == -1 + real_mask = perm_map_used >= 0 assert padding_mask.sum() > 0, "Expected some padding rows with large alignment" assert real_mask.sum() > 0, "Expected some real rows" @@ -247,10 +265,10 @@ def test_total_real_rows_equals_routed_pairs(self, num_tokens, topk, num_experts from megatron.core.inference.moe.permute import permute_tokens hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) - _, _, perm_map, _ = permute_tokens( - hidden, probs, routing_map, 0, num_experts, alignment=alignment + _, _, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=alignment ) - real_count = (perm_map >= 0).sum().item() + real_count = (perm_map[: offs[-1].item()] >= 0).sum().item() # All experts are local, so every pair should appear assert real_count == num_tokens * topk @@ -270,10 +288,10 @@ def test_expert_subset(self, num_tokens, topk, num_experts, local_start, num_loc from megatron.core.inference.moe.permute import permute_tokens hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) - _, _, perm_map, _ = permute_tokens( - hidden, probs, routing_map, local_start, num_local, alignment=1 + _, _, perm_map, offs = permute_tokens( + hidden, probs, routing_map, local_start, num_local, _vt(num_tokens), alignment=1 ) - real_count = (perm_map >= 0).sum().item() + real_count = (perm_map[: offs[-1].item()] >= 0).sum().item() local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) expected_count = local_mask.sum().item() assert real_count == expected_count @@ -284,9 +302,11 @@ def test_various_hidden_dims(self, hidden_dim): from megatron.core.inference.moe.permute import permute_tokens hidden, probs, routing_map = _make_inputs(32, hidden_dim, 4, 8) - perm_h, _, perm_map, _ = permute_tokens(hidden, probs, routing_map, 0, 8, alignment=1) - # Spot-check first real row - for i in range(perm_map.shape[0]): + perm_h, _, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, 8, _vt(32), alignment=1 + ) + # Spot-check first real row within the initialized range + for i in range(offs[-1].item()): src = perm_map[i].item() if src >= 0: torch.testing.assert_close(perm_h[i], hidden[src]) @@ -306,7 +326,9 @@ def test_weighted_scatter(self): permuted_probs = torch.tensor([0.5, 0.3, 0.7], device="cuda", 
dtype=torch.float32) perm_map = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda") - result = unpermute_tokens(expert_output, permuted_probs, perm_map, num_tokens) + result = unpermute_tokens( + expert_output, permuted_probs, perm_map, num_tokens, _vt(3), _vt(num_tokens) + ) assert result.dtype == torch.float32 # Token 0: 0.5 * 1.0 + 0.3 * 1.0 = 0.8 @@ -328,7 +350,7 @@ def test_padding_rows_ignored(self): permuted_probs = torch.ones(4, device="cuda", dtype=torch.float32) perm_map = torch.tensor([0, -1, -1, 1], dtype=torch.int32, device="cuda") - result = unpermute_tokens(expert_output, permuted_probs, perm_map, 3) + result = unpermute_tokens(expert_output, permuted_probs, perm_map, 3, _vt(4), _vt(3)) # Only tokens 0 and 1 get values assert result[0].sum().item() != 0 assert result[1].sum().item() != 0 @@ -344,7 +366,9 @@ def test_various_hidden_dims(self, hidden_dim): permuted_probs = torch.tensor([1.0, 1.0, 1.0, 1.0], device="cuda", dtype=torch.float32) perm_map = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda") - result = unpermute_tokens(expert_output, permuted_probs, perm_map, num_tokens) + result = unpermute_tokens( + expert_output, permuted_probs, perm_map, num_tokens, _vt(4), _vt(num_tokens) + ) assert result.shape == (num_tokens, hidden_dim) # First 4 tokens should have values, rest should be zero for t in range(4): @@ -363,7 +387,7 @@ def test_multiple_topk_accumulation(self, topk): probs = torch.full((topk,), 0.1, device="cuda", dtype=torch.float32) perm_map = torch.zeros(topk, dtype=torch.int32, device="cuda") - result = unpermute_tokens(expert_output, probs, perm_map, 1) + result = unpermute_tokens(expert_output, probs, perm_map, 1, _vt(topk), _vt(1)) expected_val = 0.1 * topk torch.testing.assert_close( result[0], torch.full((hidden_dim,), expected_val, device="cuda"), atol=1e-4, rtol=1e-4 @@ -400,11 +424,11 @@ def test_roundtrip_identity(self, num_tokens, hidden_dim, topk, num_experts, ali probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") - perm_h, perm_p, perm_map, _ = permute_tokens( - hidden, probs, routing_map, 0, num_experts, alignment=alignment + perm_h, perm_p, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=alignment ) # Pass permuted hidden directly through (identity expert) - result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens, offs[-1:], _vt(num_tokens)) # Build reference: for each token, sum prob[k] * hidden[token] over topk ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) @@ -428,10 +452,10 @@ def test_roundtrip_with_expert_subset(self, local_start, num_local, num_experts) probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") - perm_h, perm_p, perm_map, _ = permute_tokens( - hidden, probs, routing_map, local_start, num_local, alignment=32 + perm_h, perm_p, perm_map, offs = permute_tokens( + hidden, probs, routing_map, local_start, num_local, _vt(num_tokens), alignment=32 ) - result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens, offs[-1:], _vt(num_tokens)) # Reference: only accumulate probs for local experts ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) diff --git 
a/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py b/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py new file mode 100644 index 00000000000..8fd1f4a1154 --- /dev/null +++ b/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py @@ -0,0 +1,1100 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Tests for CUDA-graphed MTP (Multi-Token Prediction) inference. + +Verifies that: +1. CUDA graph replay produces the same output as eager execution (no extra + padding in the CUDA graphed case). +2. CUDA graphs work correctly with sequence parallelism (padding is applied + to make batch sizes divisible by TP). +3. CUDA graphs work correctly with expert parallelism and dummy ranks. + +Uses DynamicInferenceEngine for CUDA graph warmup so MTP graph capture +logic matches production code exactly. +""" + +import itertools +from unittest import mock + +import pytest +import torch +import torch.distributed as dist + +from megatron.core import parallel_state +from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions +from megatron.core.inference.config import InferenceConfig +from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext +from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine +from megatron.core.inference.inference_request import DynamicInferenceRequest +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_mtp_block_spec, +) +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.cuda_graphs import delete_cuda_graphs +from megatron.core.transformer.enums import AttnBackend +from megatron.core.utils import unwrap_model +from tests.unit_tests.test_utilities import Utils + +# --------------------------------------------------------------------------- # +# TestMTPCudaGraphInference (TP = 2) +# --------------------------------------------------------------------------- # + + +class TestMTPCudaGraphInference: + """Tests for MTP CUDA-graphed inference with tensor parallelism. + + All tests require at least 2 GPUs (TP = 2). Uses DynamicInferenceEngine + for CUDA graph warmup so MTP graph capture matches production code. 
+ """ + + HIDDEN_SIZE = 32 + VOCAB_SIZE = 100 + MAX_SEQ_LEN = 64 + NUM_LAYERS = 4 + NUM_ATTN_HEADS = 4 + TP_SIZE = 2 + + @classmethod + def setup_class(cls): + if Utils.world_size < cls.TP_SIZE: + pytest.skip(f"Need at least {cls.TP_SIZE} GPUs") + Utils.initialize_model_parallel( + tensor_model_parallel_size=cls.TP_SIZE, pipeline_model_parallel_size=1 + ) + + @classmethod + def teardown_class(cls): + delete_cuda_graphs() + Utils.destroy_model_parallel() + + def teardown_method(self): + delete_cuda_graphs() + + # ---- helpers ---------------------------------------------------------- # + + def _build_model( + self, *, sequence_parallel=False, mtp_num_layers=2, mtp_use_repeated_layer=False + ): + """Build a GPT model with MTP layers and local CUDA graph support.""" + model_parallel_cuda_manual_seed(123, inference_rng_tracker=True, force_reset_rng=True) + config = TransformerConfig( + num_layers=self.NUM_LAYERS, + hidden_size=self.HIDDEN_SIZE, + num_attention_heads=self.NUM_ATTN_HEADS, + use_cpu_initialization=True, + attention_backend=AttnBackend.local, + params_dtype=torch.bfloat16, + tensor_model_parallel_size=self.TP_SIZE, + pipeline_model_parallel_size=1, + pipeline_dtype=torch.bfloat16, + mtp_num_layers=mtp_num_layers, + mtp_use_repeated_layer=mtp_use_repeated_layer, + sequence_parallel=sequence_parallel, + cuda_graph_impl="local", + ) + layer_spec = get_gpt_layer_local_spec() + mtp_block_spec = get_gpt_mtp_block_spec( + config=config, spec=layer_spec, use_transformer_engine=False + ) + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=self.VOCAB_SIZE, + max_sequence_length=self.MAX_SEQ_LEN, + parallel_output=True, + pre_process=True, + post_process=True, + mtp_block_spec=mtp_block_spec, + ).cuda() + for param in model.parameters(): + param.data = param.data.to(config.params_dtype) + model.eval() + return model + + def _build_engine( + self, + *, + sequence_parallel=False, + mtp_num_layers=2, + mtp_use_repeated_layer=False, + num_speculative_tokens=2, + max_requests=16, + ): + """Build a DynamicInferenceEngine with automatic MTP CUDA graph warmup. + + The engine's `__init__` calls `create_cuda_graphs()` which captures + both decoder and MTP CUDA graphs, matching production warmup exactly. + """ + delete_cuda_graphs() + model = self._build_model( + sequence_parallel=sequence_parallel, + mtp_num_layers=mtp_num_layers, + mtp_use_repeated_layer=mtp_use_repeated_layer, + ) + config = model.config + context = DynamicInferenceContext( + model_config=config, + inference_config=InferenceConfig( + max_sequence_length=self.MAX_SEQ_LEN, + buffer_size_gb=0.5, + materialize_only_last_token_logits=False, + num_speculative_tokens=num_speculative_tokens, + block_size_tokens=256, + max_requests=max_requests, + num_cuda_graphs=-1, + sampling_backend='torch', + ), + ) + wrapped = GPTInferenceWrapper(model, context) + wrapped.model_is_pipeline_parallel = False + mock_tokenizer = mock.Mock() + ctrl = TextGenerationController(inference_wrapped_model=wrapped, tokenizer=mock_tokenizer) + engine = DynamicInferenceEngine(ctrl, context) + return engine + + @staticmethod + def _get_mtp_warmed_batch_sizes(engine): + """Return the MTP batch sizes (padded req_counts) warmed by the engine. + + These are the `n` values for which MTP CUDA graphs were captured. + Hidden states shape is `[n // tp, 1, H]` with SP, `[n, 1, H]` without. + Token/position IDs are always `[1, n]`. 
+ """ + context = engine.context + model_config = engine.controller.inference_wrapped_model.model.config + tp_size = parallel_state.get_tensor_model_parallel_world_size() + sp_enabled = model_config.sequence_parallel and tp_size > 1 + sizes = set() + for dim in context.cuda_graph_batch_dimensions_list: + n = dim.req_count + if sp_enabled: + n += (tp_size - n % tp_size) % tp_size + if n > 0: + sizes.add(n) + return sorted(sizes) + + @staticmethod + def _mtp_kwargs(use_graph, batch_size, mtp_depth): + """Construct call-site kwargs that route `compute_mtp_single_step` to + either CUDA graph replay or eager execution. + + The wrapped `compute_mtp_single_step` honors `eager=True` to bypass the + manager and `cache_key=...` for O(1) runner lookup. + """ + if use_graph: + return {"cache_key": ("mtp", batch_size, mtp_depth)} + return {"eager": True} + + @staticmethod + def _assert_mtp_cuda_graphs_were_replayed(model, expect_replayed): + """Assert that MTP CUDA graphs were (or were not) replayed. + + MTP runners are stored in the CudaGraphManager's lookup table + rather than the global inference record. A runner with + `fwd_graph_recorded=True` confirms the graph was captured and + replayed. + """ + unwrapped = unwrap_model(model) + manager = getattr(unwrapped, '_mtp_cudagraph_manager', None) + if manager is None: + assert not expect_replayed, "No MTP CudaGraphManager found on the model" + return + table = manager.custom_cudagraphs_lookup_table + mtp_runners = [v for k, v in table.items() if isinstance(k, tuple) and k[0] == 'mtp'] + if expect_replayed: + assert ( + len(mtp_runners) > 0 + ), "Expected MTP CUDA graphs to be replayed, but no MTP runners found" + for runner in mtp_runners: + assert runner.fwd_graph_recorded, ( + "Expected MTP CUDA graph to be recorded and replayed, " + f"but runner for {runner.base_module.__class__.__name__} " + "has fwd_graph_recorded=False" + ) + else: + recorded = [r for r in mtp_runners if r.fwd_graph_recorded] + assert len(recorded) == 0, ( + f"Expected no MTP CUDA graph replay, but {len(recorded)} " + "runners have fwd_graph_recorded=True" + ) + + # ---- Test 1: graph output matches eager (no additional padding) ------- # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_cuda_graph_output_matches_eager(self, mtp_use_repeated_layer): + """CUDA graph replay produces the same output as eager execution. + + The batch sizes exactly match warmed-up graphs (from the engine's + CUDA graph warmup), so there is no additional padding. Both paths + must produce identical hidden states and logits. 
+ """ + engine = self._build_engine(mtp_use_repeated_layer=mtp_use_repeated_layer) + model = engine.controller.inference_wrapped_model.model + unwrapped = unwrap_model(model) + batch_sizes = self._get_mtp_warmed_batch_sizes(engine) + assert len(batch_sizes) > 0, "Engine did not warm up any MTP CUDA graphs" + + mtp_depth = None if unwrapped.mtp.mtp_use_repeated_layer else 0 + + for batch_size in batch_sizes[:3]: + hidden = torch.randn( + batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + dist.broadcast(hidden, src=0) + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, batch_size), device='cuda') + dist.broadcast(token_ids, src=0) + position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + h_graph, logits_graph = unwrapped.compute_mtp_single_step( + hidden_states=hidden.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth), + ) + h_graph = h_graph.clone() + logits_graph = logits_graph.clone() + + h_eager, logits_eager = unwrapped.compute_mtp_single_step( + hidden_states=hidden.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + **self._mtp_kwargs(use_graph=False, batch_size=batch_size, mtp_depth=mtp_depth), + ) + + torch.testing.assert_close( + h_graph, h_eager, msg=f"Hidden mismatch at batch_size={batch_size}" + ) + torch.testing.assert_close( + logits_graph, logits_eager, msg=f"Logits mismatch at batch_size={batch_size}" + ) + + self._assert_mtp_cuda_graphs_were_replayed(model, True) + + # ---- Test 2: graph matches eager with sequence parallelism ------------ # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_cuda_graph_output_matches_eager_with_sp(self, mtp_use_repeated_layer): + """CUDA graph replay matches eager with sequence parallelism. + + Hidden states are in scattered SP format `[batch_size/TP, 1, H]`. + Token/position IDs remain at full `[1, batch_size]`. Both paths + must produce identical outputs. 
+ """ + engine = self._build_engine( + sequence_parallel=True, mtp_use_repeated_layer=mtp_use_repeated_layer + ) + model = engine.controller.inference_wrapped_model.model + unwrapped = unwrap_model(model) + tp_group = parallel_state.get_tensor_model_parallel_group() + batch_sizes = self._get_mtp_warmed_batch_sizes(engine) + assert len(batch_sizes) > 0, "Engine did not warm up any MTP CUDA graphs" + + mtp_depth = None if unwrapped.mtp.mtp_use_repeated_layer else 0 + + for batch_size in batch_sizes[:3]: + hidden = torch.randn( + batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + dist.broadcast(hidden, src=0) + hidden_sp = scatter_to_sequence_parallel_region(hidden, group=tp_group) + + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, batch_size), device='cuda') + dist.broadcast(token_ids, src=0) + position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + h_graph, logits_graph = unwrapped.compute_mtp_single_step( + hidden_states=hidden_sp.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth), + ) + h_graph = h_graph.clone() + logits_graph = logits_graph.clone() + + h_eager, logits_eager = unwrapped.compute_mtp_single_step( + hidden_states=hidden_sp.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + **self._mtp_kwargs(use_graph=False, batch_size=batch_size, mtp_depth=mtp_depth), + ) + + torch.testing.assert_close( + h_graph, h_eager, msg=f"Hidden mismatch at batch_size={batch_size}" + ) + torch.testing.assert_close( + logits_graph, logits_eager, msg=f"Logits mismatch at batch_size={batch_size}" + ) + + self._assert_mtp_cuda_graphs_were_replayed(model, True) + + # ---- Test 3: end-to-end _compute_serial_mtp_and_sample with SP ------- # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_cuda_graph_sp_padding_end_to_end(self, mtp_use_repeated_layer): + """Full `_compute_serial_mtp_and_sample` with CUDA graphs and SP. + + Active request counts that are not multiples of TP are padded. + The engine's CUDA graph warmup pre-captures MTP graphs for the + padded batch sizes. Verifies that padding, SP scatter/gather, and + MTP forward all work correctly through the CUDA graph path. + """ + tp_size = self.TP_SIZE + num_spec = 2 + max_requests = 16 + engine = self._build_engine( + sequence_parallel=True, + mtp_num_layers=num_spec, + mtp_use_repeated_layer=mtp_use_repeated_layer, + num_speculative_tokens=num_spec, + max_requests=max_requests, + ) + ctrl = engine.controller + context = engine.context + model = ctrl.inference_wrapped_model.model + unwrapped = unwrap_model(model) + + mtp_sizes = self._get_mtp_warmed_batch_sizes(engine) + + # Find active_request_counts whose TP-padded values match warmed MTP sizes. 
+ active_counts = [] + for n in mtp_sizes: + for active in range(n, 0, -1): + padded = active + (tp_size - active % tp_size) % tp_size + if padded == n and active <= max_requests: + active_counts.append(active) + break + assert len(active_counts) > 0, "No valid active request counts found" + + for active_request_count in active_counts[:4]: + padded_count = ( + active_request_count + (tp_size - active_request_count % tp_size) % tp_size + ) + + context.reset() + context.total_request_count = active_request_count + context.paused_request_count = 0 + context.request_kv_length_offsets[:active_request_count] = torch.arange( + active_request_count, dtype=torch.int32, device='cuda' + ) + context.request_query_lengths[:active_request_count] = torch.ones( + active_request_count, dtype=torch.int32, device='cuda' + ) + + ctrl.num_speculative_tokens = num_spec + ctrl.num_mtp_heads = num_spec + ctrl._init_mtp_sampling_tensors() + ctrl._mtp_token_ids_buf.zero_() + ctrl._mtp_position_ids_buf.zero_() + ctrl._sampled_tokens_cuda[:active_request_count] = torch.remainder( + torch.arange(active_request_count, device='cuda'), self.VOCAB_SIZE + ) + + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + torch.manual_seed(42) + full_hidden = torch.randn( + padded_count, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + dist.broadcast(full_hidden, src=0) + local_hidden = full_hidden.chunk(tp_size)[tp_rank].contiguous() + unwrapped._decoder_hidden_states_cache = local_hidden + + ctrl._last_accepted_seq_indices = torch.arange(active_request_count, device='cuda') + ctrl._mtp_resolved_padded_count = padded_count + context._using_cuda_graph_this_step = True + + # Greedy sampling for all active requests. + context.active_request_metadata["temperature"][:active_request_count] = 1.0 + context.active_request_metadata["top_k"][:active_request_count] = 1 + context.active_request_metadata["top_p"][:active_request_count] = 0.0 + + ctrl._compute_serial_mtp_and_sample() + + for depth in range(num_spec): + sampled = ctrl._sampled_mtp_tokens_cuda[depth, :active_request_count] + assert sampled.shape == ( + active_request_count, + ), f"active={active_request_count}, depth={depth}" + assert sampled.dtype == torch.int64 + assert torch.all(sampled >= 0) and torch.all(sampled < self.VOCAB_SIZE) + + assert not hasattr(unwrapped, '_decoder_hidden_states_cache') + + self._assert_mtp_cuda_graphs_were_replayed(model, True) + + # ---- Test 4: SP padding graph vs eager produces same MTP tokens ------- # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_cuda_graph_sp_padding_matches_eager(self, mtp_use_repeated_layer): + """With SP padding, CUDA graph path produces the same MTP tokens as eager. + + Uses a single engine (shared model weights) and toggles the CUDA + graph flag between runs. Both paths receive identical inputs and + must produce the same sampled MTP tokens. + """ + tp_size = self.TP_SIZE + num_spec = 2 + max_requests = 16 + engine = self._build_engine( + sequence_parallel=True, + mtp_num_layers=num_spec, + mtp_use_repeated_layer=mtp_use_repeated_layer, + num_speculative_tokens=num_spec, + max_requests=max_requests, + ) + ctrl = engine.controller + context = engine.context + model = ctrl.inference_wrapped_model.model + + mtp_sizes = self._get_mtp_warmed_batch_sizes(engine) + + # Find active counts that require TP padding (active % tp != 0). 
+ active_counts = [] + for n in mtp_sizes: + for active in range(n, 0, -1): + padded = active + (tp_size - active % tp_size) % tp_size + if padded == n and active % tp_size != 0 and active <= max_requests: + active_counts.append(active) + break + assert len(active_counts) > 0, "No active counts with TP padding found" + + for active_request_count in active_counts[:2]: + padded_count = ( + active_request_count + (tp_size - active_request_count % tp_size) % tp_size + ) + + def _run_mtp(use_cuda_graph): + """Set up state and run MTP, returning sampled tokens.""" + unwrapped = unwrap_model(model) + context.reset() + context.total_request_count = active_request_count + context.paused_request_count = 0 + context.request_kv_length_offsets[:active_request_count] = torch.arange( + active_request_count, dtype=torch.int32, device='cuda' + ) + context.request_query_lengths[:active_request_count] = torch.ones( + active_request_count, dtype=torch.int32, device='cuda' + ) + + ctrl.num_speculative_tokens = num_spec + ctrl.num_mtp_heads = num_spec + ctrl._init_mtp_sampling_tensors() + ctrl._mtp_token_ids_buf.zero_() + ctrl._mtp_position_ids_buf.zero_() + ctrl._sampled_tokens_cuda[:active_request_count] = torch.remainder( + torch.arange(active_request_count, device='cuda'), self.VOCAB_SIZE + ) + + if use_cuda_graph: + ctrl._mtp_resolved_padded_count = padded_count + context._using_cuda_graph_this_step = True + else: + ctrl._mtp_resolved_padded_count = None + context._using_cuda_graph_this_step = False + + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + torch.manual_seed(42) + full_hidden = torch.randn( + padded_count, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + dist.broadcast(full_hidden, src=0) + local_hidden = full_hidden.chunk(tp_size)[tp_rank].contiguous() + unwrapped._decoder_hidden_states_cache = local_hidden + + ctrl._last_accepted_seq_indices = torch.arange(active_request_count, device='cuda') + # Greedy sampling for all active requests. + context.active_request_metadata["temperature"][:active_request_count] = 1.0 + context.active_request_metadata["top_k"][:active_request_count] = 1 + context.active_request_metadata["top_p"][:active_request_count] = 0.0 + + ctrl._compute_serial_mtp_and_sample() + # CUDA graph replay is asynchronous, and this test reuses the controller's + # staging buffers immediately for the eager comparison. + torch.cuda.synchronize() + + return [ + ctrl._sampled_mtp_tokens_cuda[d, :active_request_count].clone() + for d in range(num_spec) + ] + + graph_tokens = _run_mtp(use_cuda_graph=True) + self._assert_mtp_cuda_graphs_were_replayed(model, True) + eager_tokens = _run_mtp(use_cuda_graph=False) + + for depth in range(num_spec): + assert torch.equal(graph_tokens[depth], eager_tokens[depth]), ( + f"active={active_request_count}, depth={depth}: " + f"graph tokens {graph_tokens[depth].tolist()} != " + f"eager tokens {eager_tokens[depth].tolist()}" + ) + + # ---- Test 5: multiple MTP depths with CUDA graphs --------------------- # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_cuda_graph_multi_depth(self, mtp_use_repeated_layer): + """Run multiple MTP depths with CUDA graphs enabled. + + Verifies that the hidden output from one depth feeds correctly into + the next depth through the same CUDA graph, producing valid outputs + at every depth. 
+ """ + num_depths = 2 + engine = self._build_engine( + mtp_num_layers=num_depths, mtp_use_repeated_layer=mtp_use_repeated_layer + ) + model = engine.controller.inference_wrapped_model.model + unwrapped = unwrap_model(model) + batch_sizes = self._get_mtp_warmed_batch_sizes(engine) + assert len(batch_sizes) > 0, "Engine did not warm up any MTP CUDA graphs" + + use_repeated = unwrapped.mtp.mtp_use_repeated_layer + + batch_size = batch_sizes[0] + + hidden = torch.randn(batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16) + dist.broadcast(hidden, src=0) + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, batch_size), device='cuda') + dist.broadcast(token_ids, src=0) + position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + current_hidden = hidden.clone() + for depth in range(num_depths): + mtp_depth = None if use_repeated else depth + current_hidden, logits = unwrapped.compute_mtp_single_step( + hidden_states=current_hidden, + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth), + ) + current_hidden = current_hidden.clone() + + assert current_hidden.shape == (batch_size, 1, self.HIDDEN_SIZE), ( + f"Depth {depth}: expected hidden shape ({batch_size}, 1, {self.HIDDEN_SIZE}), " + f"got {current_hidden.shape}" + ) + assert logits.shape == (batch_size, 1, self.VOCAB_SIZE), ( + f"Depth {depth}: expected logits shape ({batch_size}, 1, {self.VOCAB_SIZE}), " + f"got {logits.shape}" + ) + assert torch.all( + torch.isfinite(logits) + ), f"Depth {depth}: logits contain non-finite values" + + self._assert_mtp_cuda_graphs_were_replayed(model, True) + + # ---- Test 6: caller-driven eager bypass for non-warmed shapes --------- # + + @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True]) + @torch.inference_mode() + def test_eager_bypass_for_non_warmed_shape(self, mtp_use_repeated_layer): + """Passing `eager=True` runs `compute_mtp_single_step` outside the + CudaGraphManager wrapper. This is the canonical caller-side fallback + for a shape that warmup did not capture. + """ + engine = self._build_engine(mtp_use_repeated_layer=mtp_use_repeated_layer) + model = engine.controller.inference_wrapped_model.model + unwrapped = unwrap_model(model) + warmed_sizes = set(self._get_mtp_warmed_batch_sizes(engine)) + + # Find a batch size with no matching CUDA graph. 
+ fallback_size = None + for candidate in range(1, 32): + if candidate not in warmed_sizes: + fallback_size = candidate + break + assert fallback_size is not None, "Could not find a non-warmed batch size" + + mtp_depth = None if unwrapped.mtp.mtp_use_repeated_layer else 0 + + hidden = torch.randn( + fallback_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + dist.broadcast(hidden, src=0) + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, fallback_size), device='cuda') + dist.broadcast(token_ids, src=0) + position_ids = torch.arange(fallback_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + h_out, logits = unwrapped.compute_mtp_single_step( + hidden_states=hidden.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=mtp_depth, + eager=True, + ) + + assert h_out.shape == (fallback_size, 1, self.HIDDEN_SIZE) + assert logits.shape == (fallback_size, 1, self.VOCAB_SIZE) + assert torch.all(torch.isfinite(logits)) + + # ---- Test 7: delete_cuda_graphs resets MTP runners -------------------- # + + @torch.inference_mode() + def test_delete_cuda_graphs_resets_mtp_runners(self): + """`delete_cuda_graphs()` resets MTP CUDA graph runners. + + MTP runners join the standard `cudagraph_inference_record`, so the + standard cleanup loop resets their `fwd_graph_recorded` flag. + """ + engine = self._build_engine() + model = engine.controller.inference_wrapped_model.model + + self._assert_mtp_cuda_graphs_were_replayed(model, True) + + unwrapped = unwrap_model(model) + manager = getattr(unwrapped, '_mtp_cudagraph_manager', None) + assert manager is not None + assert len(manager.cudagraph_runners) > 0 + assert all(r.fwd_graph_recorded for r in manager.cudagraph_runners) + + delete_cuda_graphs() + + assert all(not r.fwd_graph_recorded for r in manager.cudagraph_runners) + assert all(r.fwd_graph is None for r in manager.cudagraph_runners) + + # ---- Test 8: last_token_logits under CUDA graph padding ---------------- # + + @torch.inference_mode() + def test_last_token_logits_cuda_graph_padding(self): + """num_last_token_logits returns padded count and last_token_logits + produces the correct shape under CUDA graph padding. + + Uses add_request + update_requests to build real decode batches, then + verifies that under CUDA graph matching: + 1. num_last_token_logits uses the padded decode count from the matched graph + 2. last_token_logits returns the padded number of rows + 3. The real (unpadded) index positions are sequential 0..N-1 + """ + num_spec = 2 + max_requests = 16 + engine = self._build_engine(num_speculative_tokens=num_spec, max_requests=max_requests) + context = engine.context + tokens_per_decode = num_spec + 1 + + # Collect decode-only graph sizes to pick active counts that will match. + decode_graph_sizes = sorted( + { + dim.decode_req_count + for dim in context.cuda_graph_batch_dimensions_list + if dim.prefill_req_count == 0 and dim.decode_req_count > 1 + } + ) + assert len(decode_graph_sizes) > 0, "No decode-only graph dims found" + + # Use active counts 1 less than some graph sizes to guarantee padding. + active_counts = [s - 1 for s in decode_graph_sizes if s >= 2][:3] + assert len(active_counts) > 0, "No sub-capacity decode graph dims found" + + for active_decode_count in active_counts: + context.reset() + + # Add prefill requests, then step them into decode state. 
+ prompt_length = 10 + for i in range(active_decode_count): + req = DynamicInferenceRequest( + request_id=i, + prompt_tokens=torch.arange(prompt_length, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=100), + ) + context.add_request(req) + + context.initialize_attention_state() + + active_mask = torch.ones(active_decode_count, device='cuda', dtype=torch.int32) + new_tokens = torch.arange(active_decode_count, device='cuda') + new_spec = torch.arange(num_spec * active_decode_count, device='cuda').reshape( + num_spec, active_decode_count + ) + context.update_requests( + active_requests_mask=active_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_spec, + ) + + # Now all requests are decode. initialize_attention_state should match a graph. + context.initialize_attention_state() + + assert ( + context.using_cuda_graph_this_step() + ), f"Expected CUDA graph for active={active_decode_count}" + + # Read the actually matched graph dimensions. + matched = context.padded_batch_dimensions + padded_decode = matched.decode_req_count + padded_token_count = matched.token_count + assert padded_decode >= active_decode_count + + expected_padded_logits = padded_decode * tokens_per_decode + assert context.num_last_token_logits == expected_padded_logits, ( + f"active={active_decode_count}, padded={padded_decode}: " + f"num_last_token_logits expected {expected_padded_logits}, " + f"got {context.num_last_token_logits}" + ) + + # Verify the real decode indices are [0, 1, ..., real_token_count - 1]. + real_token_count = active_decode_count * tokens_per_decode + real_slice = context.active_logit_idxs[:real_token_count] + expected_real = torch.arange(real_token_count, dtype=torch.int32, device='cuda') + assert torch.equal( + real_slice, expected_real + ), f"real decode indices: {real_slice.tolist()} vs {expected_real.tolist()}" + + # Padding indices should be zero (indexing into logits[0]). + padding_count = expected_padded_logits - real_token_count + if padding_count > 0: + padding_slice = context.active_logit_idxs[real_token_count:expected_padded_logits] + assert ( + padding_slice.sum().item() == 0 + ), f"padding indices should be zero, got {padding_slice.tolist()}" + + # Verify last_token_logits produces a tensor with the padded row count. + vocab_size = 64 + fake_logits = torch.randn( + 1, padded_token_count, vocab_size, device='cuda', dtype=torch.float32 + ) + result = context.last_token_logits(fake_logits) + assert result.shape == (expected_padded_logits, vocab_size), ( + f"last_token_logits shape: expected ({expected_padded_logits}, {vocab_size}), " + f"got {result.shape}" + ) + + +# --------------------------------------------------------------------------- # +# TestMTPCudaGraphExpertParallel (EP = 2) +# --------------------------------------------------------------------------- # + +_EP_SIZE = 2 + +# Request state constants for parametrized tests. +NONE = "none" +DECODE = "decode" +PREFILL = "prefill" +MIXED = "mixed" + +ALL_STATES = [NONE, DECODE, PREFILL, MIXED] + +# Combinatorial sweep: C(4+2-1, 2) = 10 test cases. +_STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE)) + +# Batch dimensions for each non-dummy state. 
+_STATE_DIMS = { + DECODE: InferenceBatchDimensions(token_count=2, prefill_req_count=0, decode_req_count=2), + PREFILL: InferenceBatchDimensions(token_count=16, prefill_req_count=2, decode_req_count=0), + MIXED: InferenceBatchDimensions(token_count=32, prefill_req_count=1, decode_req_count=2), +} + + +@pytest.mark.internal +class TestMTPCudaGraphExpertParallel: + """Tests for MTP CUDA-graphed inference with expert parallelism. + + Follows the test pattern from `test_mamba_model_expert_parallel_inference.py`. + All tests require at least `_EP_SIZE` GPUs. + """ + + HIDDEN_SIZE = 32 + VOCAB_SIZE = 100 + MAX_SEQ_LEN = 128 + NUM_LAYERS = 2 + NUM_ATTN_HEADS = 4 + NUM_MOE_EXPERTS = 2 + + @classmethod + def setup_class(cls): + if Utils.world_size < _EP_SIZE: + pytest.skip(f"EP test requires at least {_EP_SIZE} GPUs") + if Utils.world_size % _EP_SIZE != 0: + pytest.skip( + f"world_size ({Utils.world_size}) must be divisible by EP size ({_EP_SIZE})" + ) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=_EP_SIZE, + ) + + @classmethod + def teardown_class(cls): + delete_cuda_graphs() + Utils.destroy_model_parallel() + + def teardown_method(self): + delete_cuda_graphs() + + # ---- helpers ---------------------------------------------------------- # + + def _build_model(self, inference_moe_token_dispatcher_type='nccl'): + """Build a GPT model with MTP + MoE + local CUDA graphs.""" + model_parallel_cuda_manual_seed(123, inference_rng_tracker=True, force_reset_rng=True) + config = TransformerConfig( + num_layers=self.NUM_LAYERS, + hidden_size=self.HIDDEN_SIZE, + num_attention_heads=self.NUM_ATTN_HEADS, + use_cpu_initialization=True, + attention_backend=AttnBackend.local, + params_dtype=torch.bfloat16, + expert_model_parallel_size=_EP_SIZE, + num_moe_experts=self.NUM_MOE_EXPERTS, + moe_token_dispatcher_type="alltoall", + add_bias_linear=False, + mtp_num_layers=2, + cuda_graph_impl="local", + moe_pad_experts_for_cuda_graph_inference=True, + inference_moe_token_dispatcher_type=inference_moe_token_dispatcher_type, + ) + layer_spec = get_gpt_layer_local_spec(num_experts=self.NUM_MOE_EXPERTS) + mtp_block_spec = get_gpt_mtp_block_spec( + config=config, spec=layer_spec, use_transformer_engine=False + ) + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=self.VOCAB_SIZE, + max_sequence_length=self.MAX_SEQ_LEN, + parallel_output=True, + pre_process=True, + post_process=True, + mtp_block_spec=mtp_block_spec, + ).cuda() + for param in model.parameters(): + param.data = param.data.to(config.params_dtype) + model.eval() + return model + + def _build_context( + self, + model, + *, + num_cuda_graphs=16, + use_cuda_graphs_for_non_decode_steps=True, + max_requests=None, + ): + """Build a DynamicInferenceContext for the model.""" + return DynamicInferenceContext( + model_config=model.config, + inference_config=InferenceConfig( + max_sequence_length=self.MAX_SEQ_LEN, + buffer_size_gb=0.5, + block_size_tokens=256, + materialize_only_last_token_logits=False, + num_cuda_graphs=num_cuda_graphs, + use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps, + max_requests=max_requests, + sampling_backend='torch', + ), + ) + + # ---- Test 1: all EP ranks run MTP eager forward ----------------------- # + + @pytest.mark.parametrize("batch_size", [2, 4, 8]) + @pytest.mark.internal + @torch.inference_mode() + def test_ep_mtp_eager_forward(self, batch_size): + """All EP ranks can run MTP forward in eager 
mode. + + The MoE all-to-all collectives must match across EP ranks. Verifies + that all ranks complete without hanging and produce valid shapes. + """ + model = self._build_model() + unwrapped = unwrap_model(model) + + # Broadcast identical inputs so all EP ranks see the same data. + hidden = torch.randn(batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16) + dist.broadcast(hidden, src=0) + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, batch_size), device='cuda') + dist.broadcast(token_ids, src=0) + position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + h_out, logits = unwrapped.compute_mtp_single_step( + hidden_states=hidden.clone(), + next_token_ids=token_ids.clone(), + position_ids=position_ids.clone(), + depth=0, + eager=True, + ) + + assert h_out.shape == (batch_size, 1, self.HIDDEN_SIZE) + assert logits.shape == (batch_size, 1, self.VOCAB_SIZE) + assert torch.all(torch.isfinite(logits)) + + # ---- Test 2: dummy ranks + real ranks in eager mode ------------------- # + + @pytest.mark.internal + @torch.inference_mode() + def test_ep_mtp_eager_dummy_and_real_ranks(self): + """Even EP ranks run as dummy (with zeros), odd ranks run with real data. + + Both must issue matching MoE all-to-all collectives via the + MTP eager forward to avoid hangs. + """ + batch_size = 4 + model = self._build_model() + unwrapped = unwrap_model(model) + + ep_rank = parallel_state.get_expert_model_parallel_rank() + is_dummy = ep_rank % 2 == 0 + + if is_dummy: + hidden = torch.zeros( + batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + token_ids = torch.zeros(1, batch_size, device='cuda', dtype=torch.long) + else: + hidden = torch.randn( + batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 + ) + token_ids = torch.randint(0, self.VOCAB_SIZE, (1, batch_size), device='cuda') + position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0) + + # All ranks must complete without hanging. + h_out, logits = unwrapped.compute_mtp_single_step( + hidden_states=hidden, + next_token_ids=token_ids, + position_ids=position_ids, + depth=0, + eager=True, + ) + + assert h_out.shape == (batch_size, 1, self.HIDDEN_SIZE) + assert logits.shape == (batch_size, 1, self.VOCAB_SIZE) + + # ---- Test 3: EP state cross product with DynamicInferenceContext ------- # + + @pytest.mark.parametrize("rank_states", _STATE_COMBOS, ids=[",".join(s) for s in _STATE_COMBOS]) + @pytest.mark.internal + @torch.inference_mode() + def test_ep_state_cross_product(self, rank_states): + """Test combinatorial assignments of request states across EP ranks. + + Verifies that: + - All EP ranks agree on CUDA graph usage (on or off). + - When CUDA graphs are used, all ranks agree on the padded batch size + (which would be used as the MTP batch dimension). + """ + ep_rank = parallel_state.get_expert_model_parallel_rank() + my_state = rank_states[ep_rank] + is_dummy = my_state == NONE + + model = self._build_model() + ctx = self._build_context(model) + + # Phase 1: Set up each rank's request state. + if not is_dummy: + ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state]) + + # Phase 2: Initialize attention state (EP collective). + if is_dummy: + ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=True) + else: + ctx.initialize_attention_state() + + # Phase 3: Verify EP agreement on CUDA graph usage. 
+ uses_graph = ctx.using_cuda_graph_this_step() + ep_group = parallel_state.get_expert_model_parallel_group() + uses_graph_t = torch.tensor([int(uses_graph)], device='cuda', dtype=torch.int32) + graph_min = uses_graph_t.clone() + graph_max = uses_graph_t.clone() + dist.all_reduce(graph_min, op=dist.ReduceOp.MIN, group=ep_group) + dist.all_reduce(graph_max, op=dist.ReduceOp.MAX, group=ep_group) + assert graph_min.item() == graph_max.item(), ( + f"CUDA graph usage disagrees across EP ranks: " + f"min={graph_min.item()}, max={graph_max.item()} " + f"(rank_states={rank_states})" + ) + + if not uses_graph: + return + + # Phase 4: Derive MTP padded batch size from EP-synced dimensions. + mtp_padded = ctx.padded_batch_dimensions.req_count + + # Verify MTP padded count agrees across EP ranks. + padded_t = torch.tensor([mtp_padded], dtype=torch.int32, device='cuda') + padded_max = padded_t.clone() + padded_min = padded_t.clone() + dist.all_reduce(padded_max, op=dist.ReduceOp.MAX, group=ep_group) + dist.all_reduce(padded_min, op=dist.ReduceOp.MIN, group=ep_group) + assert padded_max.item() == padded_min.item(), ( + f"MTP padded batch size mismatch across EP ranks: " + f"min={padded_min.item()}, max={padded_max.item()} " + f"(rank_states={rank_states})" + ) + + # ---- Test 4: dummy EP rank bail-out with decode-only CUDA graphs ------ # + + @pytest.mark.parametrize( + "peer_state", [PREFILL, MIXED], ids=[f"peer={s}" for s in [PREFILL, MIXED]] + ) + @pytest.mark.internal + @torch.inference_mode() + def test_nccl_ep_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state): + """Verify the dummy-rank bail-out path when only decode CUDA graphs + are available. + + With `use_cuda_graphs_for_non_decode_steps=False`, only decode-only + graphs exist. When any EP rank has prefill requests, no graph matches + and all ranks fall back to eager mode. The MTP forward for the dummy + rank must use eager execution without hanging. + """ + ep_rank = parallel_state.get_expert_model_parallel_rank() + is_even = ep_rank % 2 == 0 + + model = self._build_model(inference_moe_token_dispatcher_type='nccl') + ctx = self._build_context(model, use_cuda_graphs_for_non_decode_steps=False) + + # Even ranks are dummy; odd ranks have the peer_state. + if not is_even: + ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[peer_state]) + + if is_even: + ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=True) + else: + ctx.initialize_attention_state() + + # No rank should match a CUDA graph. + assert not ctx.using_cuda_graph_this_step(), ( + f"EP rank {ep_rank}: expected no CUDA graph match with " + f"decode-only graphs and peer_state={peer_state}" + ) + + # MTP eager forward should still work on all ranks. 
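+        # The dummy rank feeds zero-valued placeholder inputs (tp_size tokens)
+        # of the right shape so the eager MTP forward below still participates
+        # in the MoE collectives, per the bail-out contract in the docstring.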
+ unwrapped = unwrap_model(model) + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + dummy_hidden = torch.zeros( + (tp_size, 1, self.HIDDEN_SIZE), device='cuda', dtype=torch.bfloat16 + ) + dummy_tokens = torch.zeros((1, tp_size), device='cuda', dtype=torch.long) + dummy_positions = torch.zeros((1, tp_size), device='cuda', dtype=torch.long) + + h_out, logits = unwrapped.compute_mtp_single_step( + hidden_states=dummy_hidden, + next_token_ids=dummy_tokens, + position_ids=dummy_positions, + depth=0, + eager=True, + ) + + assert h_out.shape == (tp_size, 1, self.HIDDEN_SIZE) + assert logits.shape == (tp_size, 1, self.VOCAB_SIZE) diff --git a/tests/unit_tests/inference/test_mxfp8_utils.py b/tests/unit_tests/inference/test_mxfp8_utils.py index a137dfbc820..1a6b4163191 100644 --- a/tests/unit_tests/inference/test_mxfp8_utils.py +++ b/tests/unit_tests/inference/test_mxfp8_utils.py @@ -398,6 +398,10 @@ def _make_permutation_map(M, num_padding=0): return torch.cat([real, pad]) +def _vt(n): + return torch.tensor(n, dtype=torch.int32, device="cuda") + + # ────────────────────────────────────────────────────────────────────── # squared_relu_and_quantize_mxfp8 vs PyTorch reference # ────────────────────────────────────────────────────────────────────── @@ -440,7 +444,7 @@ def test_data_matches_pytorch_ref(self, M, K): _, ref_data = ref_to_mxfp(activated_ref) # Fused kernel - fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map, _vt(M)) torch.testing.assert_close( fused_result.data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 @@ -461,7 +465,7 @@ def test_scales_match_pytorch_ref(self, M, K): ref_swizzled = ref_swizzle(ref_scales_2d) # Fused kernel - fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map, _vt(M)) torch.testing.assert_close( fused_result.scale.view(torch.uint8), ref_swizzled.view(torch.uint8), atol=0, rtol=0 @@ -485,7 +489,7 @@ def test_real_rows_match_pytorch_ref_with_padding(self, M, K, num_padding): _, ref_data = ref_to_mxfp(activated_ref) # Fused kernel - fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map, _vt(M)) torch.testing.assert_close( fused_result.data[:real_rows].view(torch.uint8), @@ -533,12 +537,12 @@ def test_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) - fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( - hidden, probs, routing_map, 0, num_experts, alignment=128 + fused_mxfp8, _, fused_perm_map, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=128 ) # For each real row, quantize the source token with PyTorch ref and compare - for i in range(fused_perm_map.shape[0]): + for i in range(offs[-1].item()): src = fused_perm_map[i].item() if src < 0: continue @@ -560,11 +564,11 @@ def test_batch_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) - fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( - hidden, probs, routing_map, 0, num_experts, alignment=128 + fused_mxfp8, _, fused_perm_map, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=128 ) - real_mask = fused_perm_map >= 0 + real_mask = fused_perm_map[: 
offs[-1].item()] >= 0 real_indices = real_mask.nonzero(as_tuple=True)[0] if len(real_indices) == 0: return @@ -590,11 +594,11 @@ def test_correct_token_count(self, num_tokens, K, topk, num_experts): hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) - _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( - hidden, probs, routing_map, 0, num_experts, alignment=128 + _, _, fused_perm_map, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, _vt(num_tokens), alignment=128 ) - real_count = (fused_perm_map >= 0).sum().item() + real_count = (fused_perm_map[: offs[-1].item()] >= 0).sum().item() # All experts are local, so every pair should appear assert real_count == num_tokens * topk @@ -608,11 +612,11 @@ def test_expert_subset(self, num_tokens, K, topk, num_experts, local_start, num_ hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) - _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( - hidden, probs, routing_map, local_start, num_local, alignment=128 + _, _, fused_perm_map, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, local_start, num_local, _vt(num_tokens), alignment=128 ) - real_count = (fused_perm_map >= 0).sum().item() + real_count = (fused_perm_map[: offs[-1].item()] >= 0).sum().item() local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) expected_count = local_mask.sum().item() assert real_count == expected_count @@ -624,7 +628,7 @@ def test_returns_mxfp8_tensor(self): hidden, probs, routing_map = self._make_inputs(16, 128, 2, 4) result, _, _, _ = permute_and_quantize_mxfp8( - hidden, probs, routing_map, 0, 4, alignment=128 + hidden, probs, routing_map, 0, 4, _vt(16), alignment=128 ) assert isinstance(result, MXFP8Tensor) assert result.backend == "triton" @@ -637,7 +641,7 @@ def test_offsets_aligned(self, alignment): hidden, probs, routing_map = self._make_inputs(64, 128, 4, 8) _, _, _, offs = permute_and_quantize_mxfp8( - hidden, probs, routing_map, 0, 8, alignment=alignment + hidden, probs, routing_map, 0, 8, _vt(64), alignment=alignment ) for i in range(offs.shape[0]): assert ( diff --git a/tests/unit_tests/inference/test_vllm_fused_moe.py b/tests/unit_tests/inference/test_vllm_fused_moe.py new file mode 100644 index 00000000000..c3d0c35057c --- /dev/null +++ b/tests/unit_tests/inference/test_vllm_fused_moe.py @@ -0,0 +1,1024 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for megatron.core.inference.moe.vllm_fused_moe. + +Tests cover: +- _select_block_size_m: BLOCK_SIZE_M selection based on token count +- _moe_align_block_size_cuda_graphable: indirection table construction +- _moe_sum: fused topk reduction with routing weight application +- vllm_fused_moe: end-to-end correctness vs sequential reference +""" + +import os +import tempfile + +# Redirect Triton cache to /tmp BEFORE triton is imported (at module level) so +# compiled kernels don't accumulate in ~/.triton/ across test runs. +os.environ.setdefault("TRITON_CACHE_DIR", os.path.join(tempfile.gettempdir(), "triton_test_cache")) + +import pytest +import torch + + +@pytest.fixture(autouse=True, scope="session") +def _single_autotune_config(): + """Replace the 25-entry autotune config list with a single config. + + Each unique (N, K, BLOCK_SIZE_M) combo triggers a full autotune pass that + compiles ALL configs. Tests only need correctness, not peak throughput, so + one config is sufficient and cuts compiled-kernel count by ~25x. 
+ """ + from megatron.core.inference.moe.vllm_fused_moe import _fused_moe_kernel + + orig = list(_fused_moe_kernel.configs) + _fused_moe_kernel.configs = [orig[0]] + yield + _fused_moe_kernel.configs = orig + + +def _vt(n): + """Create a valid_tokens scalar int32 CUDA tensor.""" + return torch.tensor(n, dtype=torch.int32, device="cuda") + + +def _ref_sequential_moe( + hidden_states, + probs, + fc1_weight, + fc2_weight, + routing_map, + num_local_experts, + local_expert_start, + valid_tokens, +): + """PyTorch reference: sequential per-token MoE computation. + + For each valid token, for each topk slot routed to a local expert: + intermediate = squared_relu(hidden @ fc1_weight[expert].T) + output += prob * (intermediate @ fc2_weight[expert].T) + """ + vt = valid_tokens if isinstance(valid_tokens, int) else valid_tokens.item() + max_tokens, topk = routing_map.shape + hidden_size = hidden_states.shape[1] + + out = torch.zeros(max_tokens, hidden_size, device="cuda", dtype=torch.float32) + + for t in range(vt): + acc = torch.zeros(hidden_size, device="cuda", dtype=torch.float32) + for k in range(topk): + eid = routing_map[t, k].item() + lid = eid - local_expert_start + if 0 <= lid < num_local_experts: + h = hidden_states[t].float() + fc1_out = h @ fc1_weight[lid].float().T + activated = torch.clamp(fc1_out, min=0.0) ** 2 + fc2_out = activated @ fc2_weight[lid].float().T + acc += probs[t, k].item() * fc2_out + out[t] = acc + + return out + + +def _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts, valid_tokens=None, seed=42 +): + """Create random inputs for fused MoE testing.""" + torch.manual_seed(seed) + if valid_tokens is None: + valid_tokens = max_tokens + hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + fc1_weight = ( + torch.randn(num_experts, ffn_hidden, hidden_size, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + fc2_weight = ( + torch.randn(num_experts, hidden_size, ffn_hidden, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + return hidden, probs, routing_map, fc1_weight, fc2_weight + + +# ────────────────────────────────────────────────────────────────────── +# _select_block_size_m +# ────────────────────────────────────────────────────────────────────── + + +class TestSelectBlockSizeM: + + @pytest.mark.parametrize( + "max_tokens,expected", + [ + (1, 16), + (16, 16), + (32, 16), + (33, 32), + (64, 32), + (96, 32), + (97, 64), + (256, 64), + (512, 64), + (513, 128), + (1024, 128), + (4096, 128), + ], + ) + def test_returns_expected(self, max_tokens, expected): + from megatron.core.inference.moe.vllm_fused_moe import _select_block_size_m + + assert _select_block_size_m(max_tokens) == expected + + def test_minimum_is_16(self): + from megatron.core.inference.moe.vllm_fused_moe import _select_block_size_m + + assert _select_block_size_m(1) >= 16 + + def test_monotonically_nondecreasing(self): + from megatron.core.inference.moe.vllm_fused_moe import _select_block_size_m + + prev = _select_block_size_m(1) + for n in range(2, 2048): + cur = _select_block_size_m(n) + assert cur >= prev, f"Decreased at n={n}: {prev} -> {cur}" + prev = cur + + +# ────────────────────────────────────────────────────────────────────── +# _moe_align_block_size_cuda_graphable +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class 
TestMoeAlignBlockSize: + + @pytest.mark.parametrize( + "max_tokens,topk,num_experts,block_size", + [ + (4, 1, 4, 16), + (8, 2, 4, 16), + (16, 2, 8, 32), + (32, 4, 8, 64), + (64, 6, 8, 128), + (128, 8, 16, 64), + (1, 1, 4, 16), + ], + ) + def test_output_shapes(self, max_tokens, topk, num_experts, block_size): + from megatron.core.inference.moe.vllm_fused_moe import ( + _ceil_div, + _moe_align_block_size_cuda_graphable, + ) + + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + sorted_ids, expert_ids, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, block_size, num_experts, 0, _vt(max_tokens) + ) + + max_sorted = max_tokens * topk + block_size * (num_experts + 1) + max_blocks = _ceil_div(max_sorted, block_size) + + assert sorted_ids.shape == (max_sorted,) + assert sorted_ids.dtype == torch.int32 + assert expert_ids.shape == (max_blocks,) + assert expert_ids.dtype == torch.int32 + assert num_post_padded.shape == (1,) + assert num_post_padded.dtype == torch.int32 + + @pytest.mark.parametrize( + "max_tokens,topk,num_experts", [(16, 2, 4), (32, 4, 8), (64, 6, 8), (128, 8, 16)] + ) + @pytest.mark.parametrize("block_size", [16, 32, 64, 128]) + def test_num_post_padded_is_aligned(self, max_tokens, topk, num_experts, block_size): + from megatron.core.inference.moe.vllm_fused_moe import _moe_align_block_size_cuda_graphable + + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + _, _, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, block_size, num_experts, 0, _vt(max_tokens) + ) + npp = num_post_padded.item() + assert npp % block_size == 0, f"num_post_padded={npp} not aligned to {block_size}" + + @pytest.mark.parametrize("max_tokens,topk,num_experts", [(8, 2, 4), (16, 4, 8), (32, 6, 8)]) + def test_all_local_tokens_present(self, max_tokens, topk, num_experts, block_size=16): + """Every (token, topk) pair routed to a local expert appears in sorted_token_ids.""" + from megatron.core.inference.moe.vllm_fused_moe import _moe_align_block_size_cuda_graphable + + torch.manual_seed(42) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + sentinel = max_tokens * topk + sorted_ids, _, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, block_size, num_experts, 0, _vt(max_tokens) + ) + + valid_sorted = sorted_ids[: num_post_padded.item()] + real_ids = valid_sorted[valid_sorted < sentinel] + + expected_pairs = set() + for t in range(max_tokens): + for k in range(topk): + expected_pairs.add(t * topk + k) + actual_pairs = set(real_ids.cpu().tolist()) + assert actual_pairs == expected_pairs + + @pytest.mark.parametrize( + "max_tokens,topk,num_experts,local_start,num_local", + [(32, 4, 8, 0, 4), (32, 4, 8, 4, 4), (64, 6, 16, 4, 8), (64, 6, 16, 0, 1)], + ) + def test_expert_subset_only_local_tokens( + self, max_tokens, topk, num_experts, local_start, num_local + ): + """Only tokens routed to local experts appear in sorted_token_ids.""" + from megatron.core.inference.moe.vllm_fused_moe import _moe_align_block_size_cuda_graphable + + torch.manual_seed(42) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + sentinel = max_tokens * topk + sorted_ids, _, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, 16, num_local, local_start, _vt(max_tokens) + ) + + valid_sorted = sorted_ids[: num_post_padded.item()] + real_ids = valid_sorted[valid_sorted < sentinel] + + expected_pairs = set() + rm_flat = routing_map.cpu() + for t 
in range(max_tokens): + for k in range(topk): + eid = rm_flat[t, k].item() + if local_start <= eid < local_start + num_local: + expected_pairs.add(t * topk + k) + + actual_pairs = set(real_ids.cpu().tolist()) + assert actual_pairs == expected_pairs + + @pytest.mark.parametrize("block_size", [16, 32, 64, 128]) + def test_expert_ids_cover_all_blocks(self, block_size): + """expert_ids has a valid expert index for every block in [0, num_post_padded/block_size).""" + from megatron.core.inference.moe.vllm_fused_moe import _moe_align_block_size_cuda_graphable + + max_tokens, topk, num_experts = 32, 4, 8 + torch.manual_seed(42) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + _, expert_ids, num_post_padded = _moe_align_block_size_cuda_graphable( + routing_map, block_size, num_experts, 0, _vt(max_tokens) + ) + + n_blocks = num_post_padded.item() // block_size + active_eids = expert_ids[:n_blocks].cpu() + assert (active_eids >= 0).all(), "Found negative expert_id in active range" + assert (active_eids < num_experts).all(), "Found expert_id >= num_experts" + + +# ────────────────────────────────────────────────────────────────────── +# _moe_sum +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestMoeSum: + + def _ref_moe_sum( + self, + input, + topk_weights, + max_tokens, + topk, + K, + routing_map, + local_expert_start, + num_local_experts, + ): + """PyTorch reference for _moe_sum: reduce topk with local-expert filtering.""" + out = torch.zeros(max_tokens, K, device="cuda", dtype=torch.float32) + for t in range(max_tokens): + acc = torch.zeros(K, device="cuda", dtype=torch.float32) + for k in range(topk): + eid = routing_map[t, k].item() + lid = eid - local_expert_start + if 0 <= lid < num_local_experts: + v = input[t * topk + k].float() + w = topk_weights[t, k].item() + acc += v * w + out[t] = acc + return out + + @pytest.mark.parametrize( + "max_tokens,topk,K,num_experts", + [(4, 1, 64, 4), (8, 2, 64, 4), (16, 4, 128, 8), (32, 6, 128, 8), (64, 8, 256, 16)], + ) + def test_matches_reference_all_local(self, max_tokens, topk, K, num_experts): + from megatron.core.inference.moe.vllm_fused_moe import _moe_sum + + torch.manual_seed(42) + input = torch.randn(max_tokens * topk, K, device="cuda", dtype=torch.bfloat16) + topk_weights = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + + result = _moe_sum( + input, topk_weights, max_tokens, topk, K, _vt(max_tokens), routing_map, 0, num_experts + ) + expected = self._ref_moe_sum( + input, topk_weights, max_tokens, topk, K, routing_map, 0, num_experts + ) + + torch.testing.assert_close(result, expected, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize( + "local_start,num_local,num_experts", [(0, 4, 8), (4, 4, 8), (0, 1, 8), (2, 3, 8)] + ) + def test_matches_reference_expert_subset(self, local_start, num_local, num_experts): + from megatron.core.inference.moe.vllm_fused_moe import _moe_sum + + max_tokens, topk, K = 32, 4, 128 + torch.manual_seed(42) + input = torch.randn(max_tokens * topk, K, device="cuda", dtype=torch.bfloat16) + topk_weights = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + + result = _moe_sum( + input, + topk_weights, + max_tokens, + topk, + K, + _vt(max_tokens), + routing_map, + local_start, + num_local, + ) + expected = self._ref_moe_sum( + input, 
topk_weights, max_tokens, topk, K, routing_map, local_start, num_local + ) + + torch.testing.assert_close(result, expected, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize("valid_tokens", [0, 1, 8, 15]) + def test_partial_valid_tokens(self, valid_tokens): + """Rows beyond valid_tokens are zeroed.""" + from megatron.core.inference.moe.vllm_fused_moe import _moe_sum + + max_tokens, topk, K, num_experts = 16, 2, 64, 4 + torch.manual_seed(42) + input = torch.randn(max_tokens * topk, K, device="cuda", dtype=torch.bfloat16) + topk_weights = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + + result = _moe_sum( + input, topk_weights, max_tokens, topk, K, _vt(valid_tokens), routing_map, 0, num_experts + ) + + if valid_tokens < max_tokens: + zeros = result[valid_tokens:] + assert (zeros == 0).all(), "Rows beyond valid_tokens should be zero" + + def test_writes_to_provided_output_buffer(self): + from megatron.core.inference.moe.vllm_fused_moe import _moe_sum + + max_tokens, topk, K, num_experts = 8, 2, 64, 4 + torch.manual_seed(42) + input = torch.randn(max_tokens * topk, K, device="cuda", dtype=torch.bfloat16) + topk_weights = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + + out_buf = torch.empty(max_tokens, K, dtype=torch.float32, device="cuda") + result = _moe_sum( + input, + topk_weights, + max_tokens, + topk, + K, + _vt(max_tokens), + routing_map, + 0, + num_experts, + out=out_buf, + ) + + assert result.data_ptr() == out_buf.data_ptr() + + @pytest.mark.parametrize("K", [32, 64, 128, 256, 512, 1024, 2688]) + def test_various_hidden_dims(self, K): + from megatron.core.inference.moe.vllm_fused_moe import _moe_sum + + max_tokens, topk, num_experts = 8, 4, 4 + torch.manual_seed(42) + input = torch.randn(max_tokens * topk, K, device="cuda", dtype=torch.bfloat16) + topk_weights = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + + result = _moe_sum( + input, topk_weights, max_tokens, topk, K, _vt(max_tokens), routing_map, 0, num_experts + ) + expected = self._ref_moe_sum( + input, topk_weights, max_tokens, topk, K, routing_map, 0, num_experts + ) + + torch.testing.assert_close(result, expected, atol=1e-3, rtol=1e-3) + + +# ────────────────────────────────────────────────────────────────────── +# vllm_fused_moe (end-to-end) +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestVllmFusedMoe: + + @pytest.mark.parametrize( + "max_tokens,hidden_size,ffn_hidden,topk,num_experts", + [ + (1, 64, 64, 1, 4), + (4, 64, 64, 2, 4), + (8, 128, 128, 2, 8), + (16, 128, 128, 4, 8), + (32, 128, 128, 6, 8), + (64, 128, 256, 4, 8), + (128, 64, 128, 8, 16), + ], + ) + def test_matches_reference(self, max_tokens, hidden_size, ffn_hidden, topk, num_experts): + """vllm_fused_moe output matches sequential per-token reference.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + + result = vllm_fused_moe( + hidden, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(max_tokens), + routing_map, + ) + expected = 
_ref_sequential_moe( + hidden, probs, fc1_weight, fc2_weight, routing_map, num_experts, 0, max_tokens + ) + + assert result.shape == (max_tokens, hidden_size) + assert result.dtype == torch.float32 + torch.testing.assert_close(result, expected, atol=5e-2, rtol=5e-2) + + @pytest.mark.parametrize( + "local_start,num_local,num_experts", [(0, 4, 8), (4, 4, 8), (0, 2, 8), (6, 2, 8)] + ) + def test_expert_subset(self, local_start, num_local, num_experts): + """Correct output when only a subset of experts are local.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk = 32, 128, 128, 4 + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + fc1_local = fc1_weight[local_start : local_start + num_local].contiguous() + fc2_local = fc2_weight[local_start : local_start + num_local].contiguous() + + result = vllm_fused_moe( + hidden, + probs, + fc1_local, + fc2_local, + ActivationType.SQUARED_RELU, + num_local, + local_start, + _vt(max_tokens), + routing_map, + ) + expected = _ref_sequential_moe( + hidden, probs, fc1_weight, fc2_weight, routing_map, num_local, local_start, max_tokens + ) + + torch.testing.assert_close(result, expected, atol=5e-2, rtol=5e-2) + + @pytest.mark.parametrize("valid_tokens", [1, 4, 8, 15]) + def test_partial_valid_tokens(self, valid_tokens): + """Only the first valid_tokens rows have meaningful output.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 16, 128, 128, 4, 8 + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + + result = vllm_fused_moe( + hidden, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(valid_tokens), + routing_map, + ) + expected = _ref_sequential_moe( + hidden, probs, fc1_weight, fc2_weight, routing_map, num_experts, 0, valid_tokens + ) + + torch.testing.assert_close( + result[:valid_tokens], expected[:valid_tokens], atol=5e-2, rtol=5e-2 + ) + + def test_output_buffer_reuse(self): + """Output is written to provided buffer when out= is specified.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 8, 128, 128, 2, 4 + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + + out_buf = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + result = vllm_fused_moe( + hidden, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(max_tokens), + routing_map, + out=out_buf, + ) + + assert result.data_ptr() == out_buf.data_ptr() + + def test_rejects_non_bf16_input(self): + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 4, 64, 64, 2, 4 + _, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + hidden_fp32 = torch.randn(max_tokens, hidden_size, device="cuda", 
dtype=torch.float32) + + with pytest.raises(AssertionError, match="bf16"): + vllm_fused_moe( + hidden_fp32, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(max_tokens), + routing_map, + ) + + @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) + def test_deterministic_across_seeds(self, seed): + """Same inputs produce the same output regardless of seed.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 16, 128, 128, 4, 8 + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts, seed=seed + ) + + result = vllm_fused_moe( + hidden, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(max_tokens), + routing_map, + ) + expected = _ref_sequential_moe( + hidden, probs, fc1_weight, fc2_weight, routing_map, num_experts, 0, max_tokens + ) + + torch.testing.assert_close(result, expected, atol=5e-2, rtol=5e-2) + + def test_num_tokens_hint(self): + """num_tokens_hint selects a potentially different block size but result is still correct.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 16, 128, 128, 4, 8 + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + + result = vllm_fused_moe( + hidden, + probs, + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(max_tokens), + routing_map, + num_tokens_hint=4, + ) + expected = _ref_sequential_moe( + hidden, probs, fc1_weight, fc2_weight, routing_map, num_experts, 0, max_tokens + ) + + torch.testing.assert_close(result, expected, atol=5e-2, rtol=5e-2) + + +# ────────────────────────────────────────────────────────────────────── +# CUDA graph capture + replay +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestVllmFusedMoeCudaGraph: + + @pytest.mark.parametrize( + "max_tokens,hidden_size,ffn_hidden,topk,num_experts", + [(8, 128, 128, 2, 4), (16, 128, 128, 4, 8), (32, 128, 128, 6, 8), (64, 128, 256, 4, 8)], + ) + def test_capture_and_replay(self, max_tokens, hidden_size, ffn_hidden, topk, num_experts): + """vllm_fused_moe can be captured in a CUDA graph and replayed correctly.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + torch.manual_seed(42) + static_hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + static_probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + static_routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + static_vt = _vt(max_tokens) + fc1 = ( + torch.randn(num_experts, ffn_hidden, hidden_size, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + fc2 = ( + torch.randn(num_experts, hidden_size, ffn_hidden, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + static_out = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + + # Warmup to trigger Triton autotuning + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.no_grad(), torch.cuda.stream(s): + for _ in range(3): + vllm_fused_moe( + 
static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + torch.cuda.current_stream().wait_stream(s) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + + graph.replay() + + expected = _ref_sequential_moe( + static_hidden, static_probs, fc1, fc2, static_routing_map, num_experts, 0, max_tokens + ) + torch.testing.assert_close(static_out, expected, atol=5e-2, rtol=5e-2) + + @pytest.mark.parametrize( + "max_tokens,valid_tokens_list", + [(16, [16, 8, 1, 4, 16]), (32, [32, 1, 16, 8, 32]), (64, [64, 32, 4, 1, 48])], + ) + def test_replay_with_varying_valid_tokens(self, max_tokens, valid_tokens_list): + """Replaying with different valid_tokens produces correct results each time. + + This is the core decode use case: the buffer is max-sized but only a + varying prefix is valid each iteration. The graph is captured once and + replayed with updated valid_tokens. + """ + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + hidden_size, ffn_hidden, topk, num_experts = 128, 128, 4, 8 + + torch.manual_seed(42) + static_hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + static_probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + static_routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + static_vt = _vt(max_tokens) + fc1 = ( + torch.randn(num_experts, ffn_hidden, hidden_size, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + fc2 = ( + torch.randn(num_experts, hidden_size, ffn_hidden, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + static_out = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.no_grad(), torch.cuda.stream(s): + for _ in range(3): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + torch.cuda.current_stream().wait_stream(s) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + + for vt in valid_tokens_list: + static_vt.fill_(vt) + graph.replay() + + expected = _ref_sequential_moe( + static_hidden, static_probs, fc1, fc2, static_routing_map, num_experts, 0, vt + ) + torch.testing.assert_close( + static_out[:vt], + expected[:vt], + atol=5e-2, + rtol=5e-2, + msg=f"Mismatch with valid_tokens={vt}", + ) + + def test_replay_with_new_inputs(self): + """Replaying after mutating hidden/probs/routing produces correct results. + + After graph capture, we overwrite the static input buffers with entirely + new data and replay. The graph re-reads from the same device pointers, + so it picks up the new values. 
+ """ + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 16, 128, 128, 4, 8 + + torch.manual_seed(42) + static_hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + static_probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + static_routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + static_vt = _vt(max_tokens) + fc1 = ( + torch.randn(num_experts, ffn_hidden, hidden_size, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + fc2 = ( + torch.randn(num_experts, hidden_size, ffn_hidden, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + static_out = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.no_grad(), torch.cuda.stream(s): + for _ in range(3): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + torch.cuda.current_stream().wait_stream(s) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + + for seed in [7, 123, 999]: + torch.manual_seed(seed) + static_hidden.copy_( + torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + ) + static_probs.copy_(torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32)) + static_routing_map.copy_( + torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + ) + + graph.replay() + + expected = _ref_sequential_moe( + static_hidden, + static_probs, + fc1, + fc2, + static_routing_map, + num_experts, + 0, + max_tokens, + ) + torch.testing.assert_close( + static_out, + expected, + atol=5e-2, + rtol=5e-2, + msg=f"Mismatch after replaying with seed={seed}", + ) + + def test_replay_matches_eager(self): + """Graph replay produces the same output as eager execution on identical inputs.""" + from megatron.core.inference.moe.fused_moe import ActivationType + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + max_tokens, hidden_size, ffn_hidden, topk, num_experts = 32, 128, 128, 4, 8 + + torch.manual_seed(42) + static_hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + static_probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + static_routing_map = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + static_vt = _vt(max_tokens) + fc1 = ( + torch.randn(num_experts, ffn_hidden, hidden_size, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + fc2 = ( + torch.randn(num_experts, hidden_size, ffn_hidden, device="cuda", dtype=torch.bfloat16) + * 0.01 + ) + static_out = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.no_grad(), torch.cuda.stream(s): + for _ in range(3): + vllm_fused_moe( + static_hidden, + static_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + torch.cuda.current_stream().wait_stream(s) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + vllm_fused_moe( + static_hidden, + static_probs, + 
fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + static_vt, + static_routing_map, + out=static_out, + ) + + for seed in [0, 7, 42]: + torch.manual_seed(seed) + new_hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + new_probs = torch.rand(max_tokens, topk, device="cuda", dtype=torch.float32) + new_routing = torch.randint(0, num_experts, (max_tokens, topk), device="cuda") + new_vt = torch.randint(1, max_tokens + 1, (1,)).item() + + static_hidden.copy_(new_hidden) + static_probs.copy_(new_probs) + static_routing_map.copy_(new_routing) + static_vt.fill_(new_vt) + + graph.replay() + graph_result = static_out[:new_vt].clone() + + eager_out = torch.empty(max_tokens, hidden_size, dtype=torch.float32, device="cuda") + vllm_fused_moe( + new_hidden, + new_probs, + fc1, + fc2, + ActivationType.SQUARED_RELU, + num_experts, + 0, + _vt(new_vt), + new_routing, + out=eager_out, + ) + eager_result = eager_out[:new_vt] + + torch.testing.assert_close( + graph_result, + eager_result, + atol=0, + rtol=0, + msg=f"Graph/eager mismatch at seed={seed}, valid_tokens={new_vt}", + ) + + +# ────────────────────────────────────────────────────────────────────── +# Cross-backend: vllm vs mcore_fused_moe (BF16) +# ────────────────────────────────────────────────────────────────────── + + +requires_torch_grouped_mm = pytest.mark.skipif( + not hasattr(torch.nn.functional, "grouped_mm"), + reason="Requires torch.nn.functional.grouped_mm (PyTorch 2.10+)", +) + + +@pytest.mark.internal +@requires_torch_grouped_mm +class TestVllmVsMcoreFusedMoe: + + @pytest.mark.parametrize( + "max_tokens,hidden_size,ffn_hidden,topk,num_experts", + [(4, 64, 64, 2, 4), (16, 128, 128, 4, 8), (32, 128, 128, 6, 8), (64, 128, 256, 4, 8)], + ) + def test_vllm_matches_mcore(self, max_tokens, hidden_size, ffn_hidden, topk, num_experts): + """vllm_fused_moe and mcore_fused_moe produce equivalent results on BF16.""" + from megatron.core.inference.moe.fused_moe import ActivationType, mcore_fused_moe + from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe + + hidden, probs, routing_map, fc1_weight, fc2_weight = _make_moe_inputs( + max_tokens, hidden_size, ffn_hidden, topk, num_experts + ) + vt = _vt(max_tokens) + + vllm_out = vllm_fused_moe( + hidden.clone(), + probs.clone(), + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + vt, + routing_map.clone(), + ) + mcore_out = mcore_fused_moe( + hidden.clone(), + probs.clone(), + fc1_weight, + fc2_weight, + ActivationType.SQUARED_RELU, + num_experts, + 0, + vt, + routing_map.clone(), + ) + + torch.testing.assert_close(vllm_out, mcore_out, atol=5e-2, rtol=5e-2) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_mtp_utils.py b/tests/unit_tests/inference/text_generation_controllers/test_mtp_utils.py new file mode 100644 index 00000000000..16d9d901624 --- /dev/null +++ b/tests/unit_tests/inference/text_generation_controllers/test_mtp_utils.py @@ -0,0 +1,694 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Unit tests for MTP Triton kernels. + +Each test runs both the pure-PyTorch reference (from mtp_utils_pytorch) and +the Triton kernel (from mtp_utils_triton) on the same inputs, and asserts +that the outputs match exactly. 
+""" + +import pytest +import torch + +from megatron.core.inference.text_generation_controllers.mtp_utils_pytorch import ( + mamba_state_selective_copy as mamba_state_selective_copy_pytorch, +) +from megatron.core.inference.text_generation_controllers.mtp_utils_pytorch import ( + prepare_next_forward_pass as prepare_next_forward_pass_pytorch, +) +from megatron.core.inference.text_generation_controllers.mtp_utils_pytorch import ( + rewind_kv_cache as rewind_kv_cache_pytorch, +) +from megatron.core.inference.text_generation_controllers.mtp_utils_pytorch import ( + verify_speculative_tokens as verify_speculative_tokens_pytorch, +) +from megatron.core.inference.text_generation_controllers.mtp_utils_triton import ( + mamba_state_selective_copy, + prepare_next_forward_pass, + rewind_kv_cache, + verify_speculative_tokens, +) + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + +DEVICE = "cuda" + + +def _clone_tensors(*tensors): + """Return a tuple of cloned tensors (for running reference vs kernel on the same data).""" + return tuple(t.clone() for t in tensors) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRewindKvCache: + """Tests for the rewind_kv_cache Triton kernel.""" + + @pytest.mark.parametrize("num_requests", [1, 4, 16]) + @pytest.mark.parametrize("num_speculative_tokens", [1, 2, 4]) + @pytest.mark.parametrize("block_size_tokens", [8, 16, 64]) + def test_basic(self, num_requests, num_speculative_tokens, block_size_tokens): + N = num_requests + max_blocks = 8 + + accepted_counts = torch.randint(0, num_speculative_tokens + 1, (N,), device=DEVICE) + prefill_status = torch.zeros(N, dtype=torch.int32, device=DEVICE) + + last_kv_block_offset = torch.randint(0, block_size_tokens, (N,), device=DEVICE) + kv_length_offsets = torch.randint( + block_size_tokens, block_size_tokens * 4, (N,), device=DEVICE + ) + kv_block_counts = torch.randint(2, max_blocks, (N,), device=DEVICE) + last_kv_block_id = torch.randint(0, 100, (N,), device=DEVICE) + kv_block_ids = torch.randint(0, 100, (N, max_blocks), device=DEVICE) + + ref_offset, ref_kv_len, ref_block_counts, ref_last_block, ref_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + tri_offset, tri_kv_len, tri_block_counts, tri_last_block, tri_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + + ref_release, ref_mask = rewind_kv_cache_pytorch( + accepted_counts.clone(), + prefill_status.clone(), + ref_offset, + ref_kv_len, + ref_block_counts, + ref_last_block, + ref_block_ids, + num_speculative_tokens, + block_size_tokens, + ) + + tri_release, tri_mask = rewind_kv_cache( + accepted_counts.clone(), + prefill_status.clone(), + tri_offset, + tri_kv_len, + tri_block_counts, + tri_last_block, + tri_block_ids, + num_speculative_tokens, + block_size_tokens, + ) + + torch.testing.assert_close(tri_offset, ref_offset) + torch.testing.assert_close(tri_kv_len, ref_kv_len) + torch.testing.assert_close(tri_block_counts, ref_block_counts) + torch.testing.assert_close(tri_last_block, ref_last_block) + torch.testing.assert_close(tri_block_ids, ref_block_ids) + torch.testing.assert_close(tri_release, ref_release) + torch.testing.assert_close(tri_mask, ref_mask) + + def 
test_prefill_requests_skip_rewind(self): + N = 4 + num_spec = 3 + block_size = 16 + + accepted_counts = torch.tensor([1, 0, 2, 0], device=DEVICE) + prefill_status = torch.tensor([0, 1, 0, 1], dtype=torch.int32, device=DEVICE) + last_kv_block_offset = torch.tensor([5, 10, 2, 7], device=DEVICE) + kv_length_offsets = torch.tensor([100, 200, 300, 400], device=DEVICE) + kv_block_counts = torch.tensor([3, 4, 2, 5], device=DEVICE) + last_kv_block_id = torch.tensor([10, 20, 30, 40], device=DEVICE) + kv_block_ids = torch.randint(0, 50, (N, 8), device=DEVICE) + + ref_offset, ref_kv_len, ref_block_counts, ref_last_block, ref_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + tri_offset, tri_kv_len, tri_block_counts, tri_last_block, tri_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + + ref_release, ref_mask = rewind_kv_cache_pytorch( + accepted_counts.clone(), + prefill_status.clone(), + ref_offset, + ref_kv_len, + ref_block_counts, + ref_last_block, + ref_block_ids, + num_spec, + block_size, + ) + tri_release, tri_mask = rewind_kv_cache( + accepted_counts.clone(), + prefill_status.clone(), + tri_offset, + tri_kv_len, + tri_block_counts, + tri_last_block, + tri_block_ids, + num_spec, + block_size, + ) + + # Prefill requests (indices 1, 3) should be unchanged. + for idx in [1, 3]: + assert ref_kv_len[idx] == kv_length_offsets[idx] + assert ref_offset[idx] == last_kv_block_offset[idx] + + torch.testing.assert_close(tri_offset, ref_offset) + torch.testing.assert_close(tri_kv_len, ref_kv_len) + torch.testing.assert_close(tri_block_counts, ref_block_counts) + torch.testing.assert_close(tri_last_block, ref_last_block) + torch.testing.assert_close(tri_block_ids, ref_block_ids) + torch.testing.assert_close(tri_mask, ref_mask) + + def test_block_boundary_crossing(self): + """When offset - rewind < 0, a block boundary is crossed.""" + N = 2 + num_spec = 3 + block_size = 16 + + accepted_counts = torch.tensor([0, 0], device=DEVICE) + prefill_status = torch.zeros(N, dtype=torch.int32, device=DEVICE) + last_kv_block_offset = torch.tensor([1, 10], device=DEVICE) + kv_length_offsets = torch.tensor([100, 200], device=DEVICE) + kv_block_counts = torch.tensor([3, 4], device=DEVICE) + last_kv_block_id = torch.tensor([50, 60], device=DEVICE) + kv_block_ids = torch.tensor( + [[10, 20, 50, -1, -1, -1, -1, -1], [15, 25, 35, 60, -1, -1, -1, -1]], device=DEVICE + ) + + ref_offset, ref_kv_len, ref_block_counts, ref_last_block, ref_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + tri_offset, tri_kv_len, tri_block_counts, tri_last_block, tri_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + + rewind_kv_cache_pytorch( + accepted_counts.clone(), + prefill_status.clone(), + ref_offset, + ref_kv_len, + ref_block_counts, + ref_last_block, + ref_block_ids, + num_spec, + block_size, + ) + rewind_kv_cache( + accepted_counts.clone(), + prefill_status.clone(), + tri_offset, + tri_kv_len, + tri_block_counts, + tri_last_block, + tri_block_ids, + num_spec, + block_size, + ) + + # Request 0: offset 1 - 3 = -2 → crosses boundary. + assert ref_block_counts[0] == 2 + assert tri_block_counts[0] == 2 + assert ref_last_block[0] == 20 # previous block + assert tri_last_block[0] == 20 + + # Request 1: offset 10 - 3 = 7 → no crossing. 
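+        # Without a boundary crossing, the block count keeps its original value.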
+ assert ref_block_counts[1] == 4 + assert tri_block_counts[1] == 4 + + torch.testing.assert_close(tri_offset, ref_offset) + torch.testing.assert_close(tri_kv_len, ref_kv_len) + torch.testing.assert_close(tri_block_ids, ref_block_ids) + + def test_padding_programs(self): + """Padding slots (pid >= num_active_requests) must produce safe no-ops.""" + N = 8 # grid size + active = 3 + num_spec = 2 + block_size = 16 + + accepted_counts = torch.randint(0, num_spec + 1, (N,), device=DEVICE) + prefill_status = torch.zeros(N, dtype=torch.int32, device=DEVICE) + last_kv_block_offset = torch.randint(0, block_size, (N,), device=DEVICE) + kv_length_offsets = torch.randint(block_size, block_size * 4, (N,), device=DEVICE) + kv_block_counts = torch.randint(2, 6, (N,), device=DEVICE) + last_kv_block_id = torch.randint(0, 100, (N,), device=DEVICE) + kv_block_ids = torch.randint(0, 100, (N, 8), device=DEVICE) + + ref_offset, ref_kv_len, ref_block_counts, ref_last_block, ref_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + tri_offset, tri_kv_len, tri_block_counts, tri_last_block, tri_block_ids = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + + rewind_kv_cache_pytorch( + accepted_counts.clone(), + prefill_status.clone(), + ref_offset, + ref_kv_len, + ref_block_counts, + ref_last_block, + ref_block_ids, + num_spec, + block_size, + num_active_requests=active, + ) + tri_release, tri_mask = rewind_kv_cache( + accepted_counts.clone(), + prefill_status.clone(), + tri_offset, + tri_kv_len, + tri_block_counts, + tri_last_block, + tri_block_ids, + num_spec, + block_size, + num_active_requests=active, + ) + + # Active slots should match. + torch.testing.assert_close(tri_offset[:active], ref_offset[:active]) + torch.testing.assert_close(tri_kv_len[:active], ref_kv_len[:active]) + torch.testing.assert_close(tri_block_counts[:active], ref_block_counts[:active]) + torch.testing.assert_close(tri_last_block[:active], ref_last_block[:active]) + torch.testing.assert_close(tri_block_ids[:active], ref_block_ids[:active]) + + # Padding slots: release=0, mask=False. + assert (tri_release[active:] == 0).all() + assert (~tri_mask[active:]).all() + + def test_empty(self): + N = 0 + blocks_to_release, remove_mask = rewind_kv_cache( + torch.empty(0, device=DEVICE, dtype=torch.int64), + torch.empty(0, device=DEVICE, dtype=torch.int32), + torch.empty(0, device=DEVICE, dtype=torch.int64), + torch.empty(0, device=DEVICE, dtype=torch.int64), + torch.empty(0, device=DEVICE, dtype=torch.int64), + torch.empty(0, device=DEVICE, dtype=torch.int64), + torch.empty(0, 8, device=DEVICE, dtype=torch.int64), + num_speculative_tokens=2, + block_size_tokens=16, + ) + assert blocks_to_release.shape[0] == 0 + assert remove_mask.shape[0] == 0 + + +class TestVerifySpeculativeTokens: + """Tests for the verify_speculative_tokens Triton kernel.""" + + def _make_scenario(self, num_decode, num_prefill, num_spec, *, match_pattern=None): + """Build input/output token tensors for testing. + + Args: + match_pattern: list of ints per decode request indicating how many + speculative tokens should match (0 means only base accepted). + If None, generates random matches. 
+ """ + stride = num_spec + 1 + decode_len = num_decode * stride + total_len = decode_len + num_prefill + + input_tokens = torch.randint(1, 1000, (total_len,), device=DEVICE) + output_tokens = torch.randint(1, 1000, (total_len,), device=DEVICE) + + if match_pattern is not None: + assert len(match_pattern) == num_decode + for req_idx, num_match in enumerate(match_pattern): + base = req_idx * stride + for s in range(num_match): + output_tokens[base + s] = input_tokens[base + s + 1] + + return input_tokens, output_tokens + + @pytest.mark.parametrize( + "num_decode,num_prefill,num_spec", [(1, 0, 2), (3, 0, 2), (3, 2, 2), (0, 3, 2), (5, 3, 4)] + ) + def test_basic(self, num_decode, num_prefill, num_spec): + input_tokens, output_tokens = self._make_scenario(num_decode, num_prefill, num_spec) + + ref_last, ref_mask, ref_input = verify_speculative_tokens_pytorch( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, tri_input = verify_speculative_tokens( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) + + def test_all_accepted(self): + """All speculative tokens match → all accepted.""" + num_decode, num_prefill, num_spec = 3, 0, 3 + input_tokens, output_tokens = self._make_scenario( + num_decode, num_prefill, num_spec, match_pattern=[3, 3, 3] + ) + + ref_last, ref_mask, _ = verify_speculative_tokens_pytorch( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, _ = verify_speculative_tokens( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + assert ref_mask.all() + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) + + def test_none_accepted(self): + """No speculative tokens match → only base tokens accepted.""" + num_decode, num_prefill, num_spec = 3, 0, 3 + input_tokens, output_tokens = self._make_scenario( + num_decode, num_prefill, num_spec, match_pattern=[0, 0, 0] + ) + + ref_last, ref_mask, _ = verify_speculative_tokens_pytorch( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, _ = verify_speculative_tokens( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + stride = num_spec + 1 + for req in range(num_decode): + base = req * stride + assert ref_mask[base].item() is True + assert not ref_mask[base + 1 : base + stride].any() + + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) + + def test_mixed_match_pattern(self): + """Different acceptance counts per request.""" + num_decode, num_prefill, num_spec = 3, 1, 3 + input_tokens, output_tokens = self._make_scenario( + num_decode, num_prefill, num_spec, match_pattern=[1, 3, 0] + ) + + ref_last, ref_mask, _ = verify_speculative_tokens_pytorch( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, _ = verify_speculative_tokens( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) + + def test_2d_input(self): + """Input tokens with shape [1, total_len] should be squeezed.""" + num_decode, num_prefill, num_spec = 2, 1, 2 + input_tokens, output_tokens = self._make_scenario(num_decode, num_prefill, num_spec) + 
input_2d = input_tokens.unsqueeze(0) + + ref_last, ref_mask, _ = verify_speculative_tokens_pytorch( + input_2d.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, _ = verify_speculative_tokens( + input_2d.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) + + +class TestPrepareNextForwardPass: + """Tests for the prepare_next_forward_pass Triton kernel.""" + + def _setup(self, num_decode, num_prefill, num_spec): + stride = num_spec + 1 + active = num_decode + num_prefill + decode_len = num_decode * stride + total_len = decode_len + num_prefill + + output_tokens = torch.randint(1, 1000, (total_len,), device=DEVICE, dtype=torch.int64) + required_logit_indices = torch.arange(total_len, device=DEVICE, dtype=torch.int64) + input_tokens = torch.randint(1, 1000, (total_len,), device=DEVICE, dtype=torch.int64) + + accepted_mask = torch.zeros(total_len, device=DEVICE, dtype=torch.bool) + last_one_indices = torch.empty(active, device=DEVICE, dtype=torch.int64) + + for req in range(num_decode): + base = req * stride + num_match = torch.randint(0, num_spec + 1, (1,)).item() + for j in range(stride): + if j <= num_match: + accepted_mask[base + j] = True + last_one_indices[req] = base + num_match + + for p in range(num_prefill): + idx = decode_len + p + accepted_mask[idx] = True + last_one_indices[num_decode + p] = idx + + return output_tokens, required_logit_indices, input_tokens, accepted_mask, last_one_indices + + @pytest.mark.parametrize( + "num_decode,num_prefill,num_spec", [(1, 0, 2), (3, 0, 2), (3, 2, 2), (0, 3, 2), (5, 3, 4)] + ) + def test_basic(self, num_decode, num_prefill, num_spec): + (output_tokens, required_logit_indices, input_tokens, accepted_mask, last_one_indices) = ( + self._setup(num_decode, num_prefill, num_spec) + ) + + active = num_decode + num_prefill + + ref_sampled = torch.zeros(active, device=DEVICE, dtype=torch.int64) + ref_last_seq = torch.zeros(active, device=DEVICE, dtype=torch.int64) + ref_accepted = torch.full((num_decode, num_spec), -1, device=DEVICE, dtype=torch.int64) + ref_counts = torch.zeros(num_decode, device=DEVICE, dtype=torch.int64) + + tri_sampled = torch.zeros(active, device=DEVICE, dtype=torch.int64) + tri_last_seq = torch.zeros(active, device=DEVICE, dtype=torch.int64) + tri_accepted = torch.full( + (max(num_decode, 1), num_spec), -1, device=DEVICE, dtype=torch.int64 + ) + tri_counts = torch.zeros(max(num_decode, 1), device=DEVICE, dtype=torch.int64) + + prepare_next_forward_pass_pytorch( + num_decode, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_mask, + input_tokens, + ref_sampled, + ref_last_seq, + ref_accepted, + ref_counts, + num_spec, + ) + + prepare_next_forward_pass( + num_decode, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_mask, + input_tokens, + tri_sampled, + tri_last_seq, + tri_accepted, + tri_counts, + num_spec, + ) + + torch.testing.assert_close(tri_sampled, ref_sampled) + torch.testing.assert_close(tri_last_seq, ref_last_seq) + if num_decode > 0: + torch.testing.assert_close(tri_accepted[:num_decode], ref_accepted[:num_decode]) + torch.testing.assert_close(tri_counts[:num_decode], ref_counts[:num_decode]) + + def test_empty(self): + """Zero active requests should be a no-op.""" + last_one_indices = torch.empty(0, device=DEVICE, dtype=torch.int64) + prepare_next_forward_pass( + num_decode_requests=0, + output_tokens=torch.empty(0, 
device=DEVICE, dtype=torch.int64), + required_logit_indices=torch.empty(0, device=DEVICE, dtype=torch.int64), + last_one_indices=last_one_indices, + accepted_tokens_mask=torch.empty(0, device=DEVICE, dtype=torch.bool), + input_tokens=torch.empty(0, device=DEVICE, dtype=torch.int64), + sampled_tokens_buf=torch.empty(0, device=DEVICE, dtype=torch.int64), + last_accepted_seq_buf=torch.empty(0, device=DEVICE, dtype=torch.int64), + accepted_tokens_per_request=torch.empty(0, 2, device=DEVICE, dtype=torch.int64), + accepted_token_counts=torch.empty(0, device=DEVICE, dtype=torch.int64), + num_speculative_tokens=2, + ) + + +class TestMambaStateSelectiveCopy: + """Tests for the mamba_state_selective_copy Triton kernel.""" + + @pytest.mark.parametrize("num_requests", [1, 4, 8]) + @pytest.mark.parametrize("num_layers", [1, 3]) + def test_basic(self, num_requests, num_layers): + N = num_requests + M = N # 1:1 request-to-slot mapping for simplicity + S = 4 # speculative tokens + 1 + state_shape = (16, 32) # arbitrary state dimensions + + intermediate = torch.randn(num_layers, M, S, *state_shape, device=DEVICE) + current_ref = torch.randn(num_layers, M, *state_shape, device=DEVICE) + current_tri = current_ref.clone() + + prefill_status = torch.zeros(N, dtype=torch.int32, device=DEVICE) + state_idx = torch.arange(N, device=DEVICE, dtype=torch.int64) + accepted_counts = torch.randint(0, S, (N,), device=DEVICE, dtype=torch.int64) + + mamba_state_selective_copy_pytorch( + intermediate, current_ref, prefill_status, state_idx, accepted_counts, num_layers + ) + mamba_state_selective_copy( + intermediate, current_tri, prefill_status, state_idx, accepted_counts, num_layers + ) + + torch.testing.assert_close(current_tri, current_ref) + + def test_prefill_skipped(self): + N = 4 + num_layers = 2 + M = N + S = 3 + state_shape = (8,) + + intermediate = torch.randn(num_layers, M, S, *state_shape, device=DEVICE) + current_ref = torch.randn(num_layers, M, *state_shape, device=DEVICE) + current_tri = current_ref.clone() + current_orig = current_ref.clone() + + prefill_status = torch.tensor([0, 1, 0, 1], dtype=torch.int32, device=DEVICE) + state_idx = torch.arange(N, device=DEVICE, dtype=torch.int64) + accepted_counts = torch.tensor([1, 0, 2, 0], device=DEVICE, dtype=torch.int64) + + mamba_state_selective_copy_pytorch( + intermediate, current_ref, prefill_status, state_idx, accepted_counts, num_layers + ) + mamba_state_selective_copy( + intermediate, current_tri, prefill_status, state_idx, accepted_counts, num_layers + ) + + # Prefill slots should be unchanged from original. 
+ for layer in range(num_layers): + for slot in [1, 3]: + torch.testing.assert_close(current_ref[layer, slot], current_orig[layer, slot]) + torch.testing.assert_close(current_tri[layer, slot], current_orig[layer, slot]) + + torch.testing.assert_close(current_tri, current_ref) + + def test_noncontiguous_state_idx(self): + """state_idx does not have to be a simple arange.""" + N = 3 + num_layers = 2 + M = 6 # more slots than requests + S = 3 + state_shape = (8, 4) + + intermediate = torch.randn(num_layers, M, S, *state_shape, device=DEVICE) + current_ref = torch.randn(num_layers, M, *state_shape, device=DEVICE) + current_tri = current_ref.clone() + + prefill_status = torch.zeros(N, dtype=torch.int32, device=DEVICE) + state_idx = torch.tensor([1, 4, 0], device=DEVICE, dtype=torch.int64) + accepted_counts = torch.tensor([2, 0, 1], device=DEVICE, dtype=torch.int64) + + mamba_state_selective_copy_pytorch( + intermediate, current_ref, prefill_status, state_idx, accepted_counts, num_layers + ) + mamba_state_selective_copy( + intermediate, current_tri, prefill_status, state_idx, accepted_counts, num_layers + ) + + torch.testing.assert_close(current_tri, current_ref) + + def test_empty(self): + """Zero requests should be a no-op.""" + num_layers = 2 + state_shape = (8,) + intermediate = torch.randn(num_layers, 4, 3, *state_shape, device=DEVICE) + current = torch.randn(num_layers, 4, *state_shape, device=DEVICE) + current_before = current.clone() + + mamba_state_selective_copy( + intermediate, + current, + torch.empty(0, dtype=torch.int32, device=DEVICE), + torch.empty(0, dtype=torch.int64, device=DEVICE), + torch.empty(0, dtype=torch.int64, device=DEVICE), + num_layers, + ) + + torch.testing.assert_close(current, current_before) + + +class TestStressRandom: + """Randomized stress tests running all four kernels with varied inputs.""" + + @pytest.mark.parametrize("trial", range(5)) + def test_rewind_random(self, trial): + torch.manual_seed(42 + trial) + N = torch.randint(1, 32, (1,)).item() + num_spec = torch.randint(1, 6, (1,)).item() + block_size = 2 ** torch.randint(3, 7, (1,)).item() + max_blocks = torch.randint(4, 16, (1,)).item() + + accepted_counts = torch.randint(0, num_spec + 1, (N,), device=DEVICE) + prefill_status = (torch.rand(N, device=DEVICE) > 0.7).to(torch.int32) + last_kv_block_offset = torch.randint(0, block_size, (N,), device=DEVICE) + kv_length_offsets = torch.randint(block_size, block_size * 4, (N,), device=DEVICE) + kv_block_counts = torch.randint(2, max_blocks, (N,), device=DEVICE) + last_kv_block_id = torch.randint(0, 200, (N,), device=DEVICE) + kv_block_ids = torch.randint(0, 200, (N, max_blocks), device=DEVICE) + + ref_args = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + tri_args = _clone_tensors( + last_kv_block_offset, kv_length_offsets, kv_block_counts, last_kv_block_id, kv_block_ids + ) + + ref_release, ref_mask = rewind_kv_cache_pytorch( + accepted_counts.clone(), prefill_status.clone(), *ref_args, num_spec, block_size + ) + tri_release, tri_mask = rewind_kv_cache( + accepted_counts.clone(), prefill_status.clone(), *tri_args, num_spec, block_size + ) + + for r, t in zip(ref_args, tri_args): + torch.testing.assert_close(t, r) + torch.testing.assert_close(tri_release, ref_release) + torch.testing.assert_close(tri_mask, ref_mask) + + @pytest.mark.parametrize("trial", range(5)) + def test_verify_random(self, trial): + torch.manual_seed(42 + trial) + num_decode = torch.randint(0, 16, (1,)).item() + num_prefill 
= torch.randint(0, 8, (1,)).item() + if num_decode == 0 and num_prefill == 0: + num_prefill = 1 + num_spec = torch.randint(1, 6, (1,)).item() + + stride = num_spec + 1 + total_len = num_decode * stride + num_prefill + + input_tokens = torch.randint(1, 500, (total_len,), device=DEVICE) + output_tokens = torch.randint(1, 500, (total_len,), device=DEVICE) + + # Randomly make some speculative tokens match. + for req in range(num_decode): + base = req * stride + num_match = torch.randint(0, num_spec + 1, (1,)).item() + for s in range(num_match): + output_tokens[base + s] = input_tokens[base + s + 1] + + ref_last, ref_mask, _ = verify_speculative_tokens_pytorch( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + tri_last, tri_mask, _ = verify_speculative_tokens( + input_tokens.clone(), output_tokens.clone(), num_decode, num_prefill, num_spec + ) + + torch.testing.assert_close(tri_mask, ref_mask) + torch.testing.assert_close(tri_last, ref_last) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index 9c6564f6989..5c96836e9ba 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.fp8 import check_fp8_support from megatron.core import parallel_state -from megatron.core.inference.config import InferenceConfig +from megatron.core.inference.config import InferenceConfig, MambaInferenceStateConfig from megatron.core.inference.contexts import DynamicInferenceContext, StaticInferenceContext from megatron.core.inference.contexts.dynamic_context import MaxSequenceLengthOverflowError from megatron.core.inference.inference_request import ( @@ -34,6 +34,8 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module @@ -43,7 +45,7 @@ from tests.unit_tests.test_utilities import Utils -class TestTextGenerationController: +class TextGenerationControllerTestBase: def setup_model( self, @@ -64,11 +66,10 @@ def setup_model( sequence_parallel: bool = False, expert_model_parallel_size: int = 1, num_moe_experts: int = None, + hybrid_layer_pattern: str = None, + sampling_backend: str = 'torch', + cuda_graph_impl: str = 'none', ): - Utils.initialize_model_parallel( - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ) if use_training_random_init: # This is necessary to induce the training behavior which permutes the random seed # for every rank; otherwise, every rank will have the same seed. 
@@ -98,31 +99,52 @@ def setup_model( expert_model_parallel_size=expert_model_parallel_size, num_moe_experts=num_moe_experts, add_bias_linear=num_moe_experts is None, + **( + dict(is_hybrid_model=True, mamba_num_heads=2, mamba_head_dim=16, mamba_num_groups=2) + if hybrid_layer_pattern + else {} + ), + cuda_graph_impl=cuda_graph_impl, ) if dtype == torch.bfloat16: transformer_config.bf16 = True - layer_spec = get_gpt_layer_local_spec() + mamba_inference_state_config = None + if hybrid_layer_pattern: + model = HybridModel( + config=transformer_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + hybrid_layer_pattern=hybrid_layer_pattern, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ).cuda() + mamba_inference_state_config = MambaInferenceStateConfig.from_model(model) + else: + layer_spec = get_gpt_layer_local_spec() - mtp_block_spec = None - if mtp_num_layers > 0: - mtp_block_spec = get_gpt_mtp_block_spec( - config=transformer_config, spec=layer_spec, use_transformer_engine=False - ) + mtp_block_spec = None + if mtp_num_layers > 0: + mtp_block_spec = get_gpt_mtp_block_spec( + config=transformer_config, spec=layer_spec, use_transformer_engine=False + ) + + model = GPTModel( + config=transformer_config, + transformer_layer_spec=layer_spec, + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + mtp_block_spec=mtp_block_spec, + ).cuda() - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - mtp_block_spec=mtp_block_spec, - ).cuda() - gpt_model.eval() + model.eval() if dtype == torch.bfloat16: - gpt_model = Float16Module(gpt_model.config, gpt_model) + model = Float16Module(model.config, model) if static: inference_context = StaticInferenceContext( @@ -142,10 +164,12 @@ def setup_model( block_size_tokens=block_size_tokens, enable_prefix_caching=enable_prefix_caching, max_requests=max_requests, + mamba_inference_state_config=mamba_inference_state_config, + sampling_backend=sampling_backend, ), ) - inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_context) + inference_wrapped_model = GPTInferenceWrapper(model, inference_context) inference_wrapped_model.model_is_pipeline_parallel = not ( parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() @@ -157,6 +181,15 @@ def setup_model( inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer ) + +class TestTextGenerationController(TextGenerationControllerTestBase): + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=1 + ) + @classmethod def teardown_class(cls): Utils.destroy_model_parallel() @@ -249,28 +282,34 @@ def detokenize(self, inp, skip_special_tokens=False): sampled_logits >= expected_min_value ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" - @pytest.mark.parametrize("backend", ["torch"]) + @pytest.mark.parametrize("backend", ["torch", "flashinfer"]) 
@pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) def test_sample_from_dynamic_logits( self, backend: str, materialize_only_last_token_logits: bool ): - batch_size = 12 + if backend == "flashinfer": + pytest.importorskip("flashinfer") + batch_size = 15 self.setup_model( torch.float32, batch_size=batch_size, static=False, materialize_only_last_token_logits=materialize_only_last_token_logits, + sampling_backend=backend, ) self.mock_tokenizer.eod = self.vocab_size context = self.text_generation_controller.inference_wrapped_model.inference_context # Prepare sampling params in human-readable format, to aid with test maintenance. + # The temperature=0 / top_k=1 bucket exercises the greedy path: torch short-circuits + # to argmax, flashinfer relies on its temperature clamp to avoid divide-by-zero. sampling_test_cases: List[Tuple[SamplingParams, List[int]]] = [ (SamplingParams(temperature=0.1, top_p=0.01), [9, 6, 10]), (SamplingParams(temperature=5.0, top_k=15), [0, 3, 2]), (SamplingParams(top_p=0.8), [4, 1, 7]), (SamplingParams(temperature=10.0, top_k=5), [11, 5, 8]), + (SamplingParams(temperature=0.0, top_k=1), [12, 13, 14]), ] # For non-torch backends, test simultaneous top_k and top_p sampling. if backend != "torch": @@ -286,25 +325,21 @@ def test_sample_from_dynamic_logits( temp_values = torch.Tensor([s.temperature for s in rev_sampling_dict]) top_k_values = torch.Tensor([s.top_k for s in rev_sampling_dict]).to(torch.int32) top_p_values = torch.Tensor([s.top_p for s in rev_sampling_dict]) - request_metadata = { - "temperature": temp_values, - "top_k": top_k_values, - "top_p": top_p_values, - } - self.text_generation_controller._request_metadata = request_metadata - self.text_generation_controller._sampling_backend = backend + context.active_request_metadata["temperature"][:batch_size].copy_(temp_values) + context.active_request_metadata["top_k"][:batch_size].copy_(top_k_values) + context.active_request_metadata["top_p"][:batch_size].copy_(top_p_values) context.padded_active_token_count = batch_size - context.request_query_lengths = torch.ones(batch_size, dtype=torch.int32) + context.request_query_lengths = torch.ones(batch_size, dtype=torch.int32, device='cuda') context.paused_request_count = 0 context.total_request_count = batch_size - - # Bookkeeping. - self.text_generation_controller._dynamic_step_sample_bookkeeping() + context.num_prefill_requests = 0 + context.pad_active_slices() # Sampling. logits = torch.arange(0, self.vocab_size).repeat(batch_size, 1).unsqueeze(0).float().cuda() - self.text_generation_controller._dynamic_step_sample_logits(logits) + self.text_generation_controller._all_logits_cuda = logits + self.text_generation_controller._dynamic_step_sample_logits() sampled_logits = self.text_generation_controller._sampled_tokens_cuda[:batch_size] vocab_indices = torch.arange(self.vocab_size).cuda() @@ -324,13 +359,13 @@ def test_sample_from_dynamic_logits( sampled_l.masked_fill_(top_k_mask, 0.0) top_p_mask = sampled_l.cumsum(dim=-1) > top_p_values.unsqueeze(1) + # When `top_p` is enabled, but the cumulative probs don't actually filter anything, + # our constraint reduces to top_k alone. 
+ start_idx = torch.clamp(self.vocab_size - top_k_values, min=0).long() first_excluded = torch.where( - top_p_mask.any(dim=-1), - top_p_mask.float().argmax(dim=-1), - torch.full((batch_size,), self.vocab_size, device=top_p_mask.device), + top_p_mask.any(dim=-1), top_p_mask.float().argmax(dim=-1), start_idx + 1 ) last_included = torch.clamp(first_excluded - 1, min=0) - start_idx = torch.clamp(self.vocab_size - top_k_values, min=0).long() last_included = torch.max(last_included, start_idx) expected_min_values = l.gather(1, last_included.unsqueeze(1)).squeeze(1) assert torch.all( @@ -828,15 +863,10 @@ def test_dynamic_top_n_logprobs_calculation( # Prepare sampling params top_n = 5 - request_metadata = { - "top_n_logprobs": torch.full((batch_size,), top_n, dtype=torch.int32).cuda(), - "skip_prompt_log_probs": torch.full( - (batch_size,), float(skip_prompt_log_probs), dtype=torch.float32 - ).cuda(), - } - self.text_generation_controller._request_metadata = request_metadata - self.text_generation_controller._active_request_count = batch_size - self.text_generation_controller._active_request_slice = slice(0, batch_size) + context.active_request_metadata["top_n_logprobs"][:batch_size].fill_(top_n) + context.active_request_metadata["skip_prompt_log_probs"][:batch_size].fill_( + skip_prompt_log_probs + ) if materialize_only_last_token_logits: # Decode mode: logits for last tokens only @@ -852,7 +882,7 @@ def test_dynamic_top_n_logprobs_calculation( # Calculate top-n logprobs top_n_results = self.text_generation_controller._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor + log_probs_tensor ) # Validate results @@ -897,7 +927,7 @@ def test_dynamic_top_n_logprobs_calculation( # Calculate top-n logprobs top_n_results = self.text_generation_controller._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor + log_probs_tensor ) # Validate results @@ -931,164 +961,24 @@ def test_dynamic_top_n_logprobs_calculation( top_n_indices.shape[0] == top_n ), f"Request {req_idx}, token {token_idx}: expected {top_n} indices" - @pytest.mark.parametrize("static", [True, False]) - @pytest.mark.parametrize("tp_size", [1, 2]) - @pytest.mark.parametrize("pp_size", [1, 2]) - def test_sampled_tokens_match_with_parallelism(self, static, tp_size, pp_size): - """ - Verify that sampled tokens match across all parallel ranks. - """ - if tp_size == 1 and pp_size == 1: - pytest.skip(reason="Test requires model parallel size > 1.") - - if not static and not is_fa_min_version("2.7.3"): - pytest.skip(reason="Need latest flash attn for dynamic batching") - - # Ensure that we are using the training setup for random seed initialization - # so that every rank has a different seed - self.setup_model( - dtype=torch.bfloat16, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - static=static, - use_training_random_init=True, - ) - - self.mock_tokenizer.vocab_size = self.vocab_size - self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.side_effect = lambda x, skip_special_tokens=False: ' '.join( - [ - ''.join(random.choices(string.ascii_letters, k=random.randint(4, 10))) - for _ in range(len(x)) - ] - ) - self.mock_tokenizer.offsets.side_effect = lambda _, s: [ - i for i, c in enumerate(s) if c == ' ' - ] + [len(s)] - - # Prepare requests. 
- active_requests: Dict[str, InferenceRequest] = OrderedDict() - for i in range(self.batch_size): - prompt = "sample" * (i + 1) - prompt_tokens = torch.randint( - low=0, high=self.vocab_size - 1, size=(len(prompt),) - ).tolist() - request_id = str(i) - inference_request = InferenceRequest( - request_id=request_id, - prompt=prompt, - sampling_params=SamplingParams( - top_k=10, num_tokens_to_generate=25, return_log_probs=True - ), - arrival_time=time.time(), - prompt_tokens=prompt_tokens, - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, - ) - active_requests[request_id] = inference_request - - # Generate tokens. - if static: - requests = self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - all_generated_tokens = [req.generated_tokens.tolist() for req in requests.values()] - else: - all_generated_tokens = [[] for _ in range(len(active_requests))] - context = self.text_generation_controller.inference_wrapped_model.inference_context - for request_id, request in active_requests.items(): - context.add_request( - DynamicInferenceRequest( - request_id=int(request_id), - prompt_tokens=torch.tensor( - request.prompt_tokens, - dtype=torch.long, - device=torch.cuda.current_device(), - ), - sampling_params=SamplingParams( - top_k=10, return_log_probs=True, num_tokens_to_generate=25 - ), - ) - ) - expected_active_requests = set(int(x) for x in active_requests.keys()) - while context.has_unfinished_requests(): - result = self.text_generation_controller.generate_output_tokens_dynamic_batch() - new_tokens = result["sample"] - active_ids = result["active_request_ids"].tolist() - finished_ids = result["finished_request_ids"].tolist() - assert len(new_tokens) == len(expected_active_requests) - assert set(active_ids) == expected_active_requests - expected_active_requests -= set(finished_ids) - for i, token in enumerate(new_tokens.tolist()): - all_generated_tokens[i].append(token) - - # Wait for all communication to complete before proceeding. - torch.distributed.barrier() - - # Collect all the generated tokens for each request from each rank in the - # model parallel group. - mp_group = parallel_state.get_model_parallel_group() - mp_ranks = torch.distributed.get_process_group_ranks(mp_group) - local_rank = torch.distributed.get_rank() - tokens_per_rank = {} - tokens_per_rank[local_rank] = all_generated_tokens - - for i in mp_ranks: - # Start by communicating the batch size so each rank knows how many requests to expect. - if i == local_rank: - batch_size = torch.tensor( - len(tokens_per_rank[local_rank]), - dtype=torch.long, - device=torch.cuda.current_device(), - ) - else: - tokens_per_rank[i] = [] - batch_size = torch.empty(1, dtype=torch.long, device=torch.cuda.current_device()) - torch.distributed.broadcast(batch_size, group=mp_group, src=i) - - for j in range(batch_size.item()): - # For each request, communicate the sequence length followed by the actual tokens. 
- if i == local_rank: - sequence_length = torch.tensor( - len(tokens_per_rank[local_rank][j]), - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - else: - sequence_length = torch.empty( - 1, dtype=torch.int32, device=torch.cuda.current_device() - ) - torch.distributed.broadcast(sequence_length, group=mp_group, src=i) - - if i == local_rank: - generated_tokens = torch.tensor( - tokens_per_rank[local_rank][j], - dtype=torch.long, - device=torch.cuda.current_device(), - ) - else: - generated_tokens = torch.empty( - sequence_length.item(), dtype=torch.long, device=torch.cuda.current_device() - ) - torch.distributed.broadcast(generated_tokens, group=mp_group, src=i) - - if i != local_rank: - tokens_per_rank[i].append(generated_tokens.tolist()) - - # Ensure that every rank in the model parallel group produced the same tokens. - for i in mp_ranks: - if i == local_rank: - continue - for j, (expected, actual) in enumerate( - zip(tokens_per_rank[local_rank], tokens_per_rank[i]) - ): - assert ( - expected == actual - ), f"Rank {i} tokens differ from rank {local_rank} tokens for request {j}" - @pytest.mark.internal - def test_speculative_verify_tokens(self): + @pytest.mark.parametrize("backend", ["torch", "flashinfer"]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_speculative_verify_tokens( + self, backend: str, materialize_only_last_token_logits: bool + ): """Test consecutive token acceptance logic for speculative decoding.""" - self.setup_model(torch.float32, static=False, num_speculative_tokens=2, max_requests=2) + if backend == "flashinfer": + pytest.importorskip("flashinfer") + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=2, + max_requests=2, + mtp_num_layers=2, + sampling_backend=backend, + materialize_only_last_token_logits=materialize_only_last_token_logits, + ) # Enable speculative decoding self.text_generation_controller.num_speculative_tokens = 2 @@ -1101,9 +991,11 @@ def test_speculative_verify_tokens(self): ctx.request_query_lengths = torch.tensor( [3, 3], dtype=torch.int32, device='cuda' ) # 1 sampled + 2 spec + ctx.num_prefill_requests = 0 + ctx.pad_active_slices() # Init accepted tokens tensors - self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._init_mtp_sampling_tensors() # Mock inputs: [Req 1 sampled, Req 1 spec1, Req 1 spec2, Req 2 sampled, Req 2 spec1, Req 2 spec2] # Target tokens (what the model was fed): [T0, T1, T2, T3, T4, T5] @@ -1122,18 +1014,16 @@ def mock_sampling_func(logits, *args, **kwargs): # The verification logic only uses base tokens, so we can return zeros here. return torch.zeros((12,), dtype=torch.long, device='cuda') - # Override sampling to return our predictable mock outputs - self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 1, 0.0)] - self.text_generation_controller._torch_sampling_func = mock.MagicMock( + # Override sampling to return our predictable mock outputs. 
+ self.text_generation_controller._sampling.sample_kernel = mock.MagicMock( side_effect=mock_sampling_func ) # Mock logits matching input shape logits = torch.randn(1, 6, self.vocab_size, device='cuda') + self.text_generation_controller._all_logits_cuda = logits - self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( - logits, input_ids - ) + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens(input_ids) # Verify acceptance counts accepted_counts = self.text_generation_controller._accepted_token_counts_per_request[:2] @@ -1156,58 +1046,66 @@ def test_rewind_kv_cache(self, is_hybrid_model): num_speculative_tokens=3, block_size_tokens=4, max_requests=16, + hybrid_layer_pattern="***M" if is_hybrid_model else None, ) self.text_generation_controller.num_speculative_tokens = 3 ctx = self.text_generation_controller.inference_wrapped_model.inference_context + context_device = ctx.request_kv_length_offsets.device ctx.total_request_count = 2 ctx.paused_request_count = 0 - ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + ctx.request_in_prefill_status_tensor[:2] = torch.tensor( + [0, 0], dtype=torch.int32, device=context_device + ) # Initialize allocator and states ctx.kv_block_allocator.total_avail = 100 - ctx.request_kv_length_offsets[:2] = torch.tensor([10, 15], device='cuda') - ctx.request_kv_block_counts[:2] = torch.tensor([3, 4], device='cuda') + ctx.request_kv_length_offsets[:2] = torch.tensor([10, 15], device=context_device) + ctx.request_kv_block_counts[:2] = torch.tensor([3, 4], device=context_device) # Req 0: offset 2. Rewinding 2 tokens -> offset 0. No block released. # Req 1: offset 1. Rewinding 3 tokens -> offset 2 (prev block). 1 block released. - ctx.request_last_kv_block_offset[:2] = torch.tensor([2, 1], device='cuda') - ctx.request_last_kv_block_id[:2] = torch.tensor([50, 60], device='cuda') + ctx.request_last_kv_block_offset[:2] = torch.tensor([2, 1], device=context_device) + ctx.request_last_kv_block_id[:2] = torch.tensor([50, 60], device=context_device) ctx.request_to_kv_block_ids[:2, :4] = torch.tensor( - [[48, 49, 50, -1], [57, 58, 59, 60]], dtype=torch.int, device='cuda' + [[48, 49, 50, -1], [57, 58, 59, 60]], dtype=torch.int, device=context_device ) if is_hybrid_model: - ctx.is_hybrid_model = True - ctx.mamba_metadata = mock.MagicMock() - ctx.mamba_metadata.request_to_mamba_state_idx = torch.tensor([0, 1], device='cuda') - ctx.mamba_ssm_states = torch.zeros((1, 2, 16), device='cuda') - ctx.mamba_intermediate_ssm_states = torch.ones((1, 2, 4, 16), device='cuda') * 99 - ctx.mamba_conv_states = torch.zeros((1, 2, 8), device='cuda') - ctx.mamba_intermediate_conv_states = torch.ones((1, 2, 4, 8), device='cuda') * 77 + ctx.mamba_metadata.request_to_mamba_state_idx[:2] = torch.tensor( + [0, 1], dtype=torch.int32, device=context_device + ) + ctx.mamba_ssm_states.zero_() + ctx.mamba_intermediate_ssm_states.fill_(99) + ctx.mamba_conv_states.zero_() + ctx.mamba_intermediate_conv_states.fill_(77) # Mock accepted token counts: Req 0 accepts 1 (rejects 2), Req 1 accepts 0 (rejects 3) - self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._init_mtp_sampling_tensors() self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( [1, 0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + 
ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Assert offsets updated assert torch.equal( ctx.request_last_kv_block_offset[:2], - torch.tensor([0, 2], dtype=torch.int, device='cuda'), + torch.tensor([0, 2], dtype=torch.int, device=context_device), ) assert torch.equal( - ctx.request_kv_length_offsets[:2], torch.tensor([8, 12], dtype=torch.int, device='cuda') + ctx.request_kv_length_offsets[:2], + torch.tensor([8, 12], dtype=torch.int, device=context_device), ) # Assert block counts and IDs updated for boundary crossing assert torch.equal( - ctx.request_kv_block_counts[:2], torch.tensor([3, 3], dtype=torch.int, device='cuda') + ctx.request_kv_block_counts[:2], + torch.tensor([3, 3], dtype=torch.int, device=context_device), ) assert torch.equal( - ctx.request_last_kv_block_id[:2], torch.tensor([50, 59], dtype=torch.int, device='cuda') + ctx.request_last_kv_block_id[:2], + torch.tensor([50, 59], dtype=torch.int, device=context_device), ) # Assert released block is cleared @@ -1221,6 +1119,108 @@ def test_rewind_kv_cache(self, is_hybrid_model): assert torch.all(ctx.mamba_conv_states[:, 0] == 77) # Req 0 accepted 1, loaded index 1 assert torch.all(ctx.mamba_conv_states[:, 1] == 77) # Req 1 accepted 0, loaded index 0 + @pytest.mark.internal + def test_rewind_kv_cache_stale_padding_is_safe(self): + """Padding slots with stale data must not corrupt active requests or + release junk blocks when the rewind kernel grid is padded beyond the + active request count. + + Without the num_active_requests guard in the kernel, padding slots + whose stale request_last_kv_block_offset < num_speculative_tokens + would produce remove_mask=True, causing the block allocator to free + block IDs that belong to other active requests. + """ + from megatron.core.inference.text_generation_controllers.mtp_utils_triton import ( + rewind_kv_cache, + ) + + num_spec = 3 + block_size = 4 + active = 2 + padded = 4 + max_blocks = 10 + dev = 'cuda' + + # --- Active requests (slots 0-1): identical to test_rewind_kv_cache --- + # Req 0: accepted 1, last_offset 2 → rewind 2 → offset 0, no release + # Req 1: accepted 0, last_offset 1 → rewind 3 → crosses block, release block 60 + accepted = torch.zeros(padded, device=dev, dtype=torch.int64) + accepted[0] = 1 + accepted[1] = 0 + + prefill = torch.zeros(padded, device=dev, dtype=torch.int64) + + last_offset = torch.zeros(padded, device=dev, dtype=torch.int64) + last_offset[0] = 2 + last_offset[1] = 1 + + kv_length = torch.zeros(padded, device=dev, dtype=torch.int64) + kv_length[0] = 10 + kv_length[1] = 15 + + block_counts = torch.zeros(padded, device=dev, dtype=torch.int64) + block_counts[0] = 3 + block_counts[1] = 4 + + last_block_id = torch.zeros(padded, device=dev, dtype=torch.int64) + last_block_id[0] = 50 + last_block_id[1] = 60 + + block_ids = torch.full((padded, max_blocks), -1, device=dev, dtype=torch.int64) + block_ids[0, :3] = torch.tensor([48, 49, 50]) + block_ids[1, :4] = torch.tensor([57, 58, 59, 60]) + + # --- Padding slots (2-3): stale data from completed requests --- + # Crucially, last_offset values < num_spec would trigger remove=True + # without the kernel guard, releasing stale block IDs. 
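        # Reviewer's illustrative sketch (not part of this diff): the guard pattern this
        # test relies on. In a Triton kernel launched on a padded grid, each program must
        # exit for slots at or beyond the active request count before reading or writing
        # any per-request state. The kernel name and signature below are hypothetical,
        # not the actual mtp_utils_triton implementation.
        #
        #     import triton
        #     import triton.language as tl
        #
        #     @triton.jit
        #     def _padded_grid_guard_sketch(last_offset_ptr, out_ptr, num_active_requests):
        #         pid = tl.program_id(axis=0)
        #         if pid >= num_active_requests:
        #             return  # padding slot with stale data: touch nothing
        #         offset = tl.load(last_offset_ptr + pid)  # only active slots reach here
        #         tl.store(out_ptr + pid, offset)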
+ last_offset[2] = 1 + last_offset[3] = 2 + kv_length[2] = 9999 + kv_length[3] = 9999 + block_counts[2] = 5 + block_counts[3] = 7 + last_block_id[2] = 777 + last_block_id[3] = 888 + block_ids[2, :5] = torch.arange(100, 105, device=dev) + block_ids[3, :5] = torch.arange(200, 205, device=dev) + + blocks_to_release, remove_mask = rewind_kv_cache( + accepted_counts=accepted, + prefill_status=prefill, + last_kv_block_offset=last_offset, + kv_length_offsets=kv_length, + kv_block_counts=block_counts, + last_kv_block_id=last_block_id, + kv_block_ids=block_ids, + num_speculative_tokens=num_spec, + block_size_tokens=block_size, + num_active_requests=active, + ) + + # --- Active request 0: rewind 2, no block release --- + assert remove_mask[0].item() is False + assert last_offset[0].item() == 0 + assert kv_length[0].item() == 8 + assert block_counts[0].item() == 3 + assert last_block_id[0].item() == 50 + + # --- Active request 1: rewind 3, crosses block boundary --- + assert remove_mask[1].item() is True + assert last_offset[1].item() == 2 # (1 - 3) % 4 = 2 + assert kv_length[1].item() == 12 + assert block_counts[1].item() == 3 + assert last_block_id[1].item() == 59 + assert blocks_to_release[1].item() == 60 + + # --- Padding slots 2-3: must be no-ops, no blocks released --- + assert remove_mask[2].item() is False + assert remove_mask[3].item() is False + # Stale state must be untouched (kernel skipped these programs). + assert kv_length[2].item() == 9999 + assert kv_length[3].item() == 9999 + assert block_counts[2].item() == 5 + assert block_counts[3].item() == 7 + @pytest.mark.internal def test_speculative_multinomial_sampling(self): """Test that speculative decoding can successfully use non-greedy sampling @@ -1240,6 +1240,8 @@ def test_speculative_multinomial_sampling(self): ) # Decode requests # query lengths for decode with spec tokens is (1 + num_spec) = 4 ctx.request_query_lengths = torch.tensor([4, 4], dtype=torch.int32, device='cuda') + ctx.num_prefill_requests = 0 + ctx.pad_active_slices() # Setup inputs input_ids = torch.randint(0, self.vocab_size, (1, 8), device='cuda') @@ -1248,17 +1250,15 @@ def test_speculative_multinomial_sampling(self): # Base logits shape: [1, 8, vocab_size] logits = torch.randn(1, 8, self.vocab_size, device='cuda') - # Set up a bucket that forces multinomial sampling (top_p = 0.9, top_k = 0) - # _torch_sampling_buckets format: (indices, temp, top_k, top_p) - self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 0, 0.9)] - - # Since we are actually testing the internal math of `_torch_sampling_func` handling the shapes, - # we DO NOT mock `_torch_sampling_func` here. We want it to run natively to prove it doesn't crash. + # Drive sampling onto the multinomial path (top_p > 0, top_k == 0) via metadata. + # We do NOT mock the sampling kernel: we want it to run natively to prove it doesn't crash. 
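        # Reviewer's illustrative note (not part of this diff): torch.multinomial only
        # accepts 1-D or 2-D probability tensors, so MTP logits of shape [1, tokens, vocab]
        # must be flattened to [tokens, vocab] before sampling; an unflattened call raises
        # "prob_dist must be 1 or 2 dim", the failure this test guards against. Shapes
        # below are arbitrary.
        #
        #     probs = torch.softmax(torch.randn(1, 8, 32), dim=-1)  # [1, T, V]
        #     flat = probs.view(-1, probs.size(-1))                 # [T, V] -- accepted
        #     tokens = torch.multinomial(flat, num_samples=1)       # OK
        #     # torch.multinomial(probs, num_samples=1) would raise the error above.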
+ ctx.active_request_metadata["temperature"][:2] = 1.0 + ctx.active_request_metadata["top_k"][:2] = 0 + ctx.active_request_metadata["top_p"][:2] = 0.9 + self.text_generation_controller._all_logits_cuda = logits try: - self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( - logits, input_ids - ) + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens(input_ids) except RuntimeError as e: if "prob_dist must be 1 or 2 dim" in str(e): pytest.fail("MTP logits were not flattened before calling multinomial sampling.") @@ -1289,18 +1289,21 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context + context_device = ctx.request_kv_length_offsets.device ctx.total_request_count = 2 ctx.paused_request_count = 0 - ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + ctx.request_in_prefill_status_tensor[:2] = torch.tensor( + [0, 0], dtype=torch.int32, device=context_device + ) # Req 0: 3 blocks, offset 1 in last block. Rewinding 1 token -> no block release. # Req 1: 3 blocks, offset 0 in last block. Rewinding 2 tokens -> crosses back, release block. - ctx.request_kv_length_offsets[:2] = torch.tensor([9, 9], device='cuda') - ctx.request_kv_block_counts[:2] = torch.tensor([3, 3], device='cuda') - ctx.request_last_kv_block_offset[:2] = torch.tensor([1, 0], device='cuda') - ctx.request_last_kv_block_id[:2] = torch.tensor([10, 20], device='cuda') + ctx.request_kv_length_offsets[:2] = torch.tensor([9, 9], device=context_device) + ctx.request_kv_block_counts[:2] = torch.tensor([3, 3], device=context_device) + ctx.request_last_kv_block_offset[:2] = torch.tensor([1, 0], device=context_device) + ctx.request_last_kv_block_id[:2] = torch.tensor([10, 20], device=context_device) ctx.request_to_kv_block_ids[:2, :3] = torch.tensor( - [[8, 9, 10], [18, 19, 20]], dtype=torch.int, device='cuda' + [[8, 9, 10], [18, 19, 20]], dtype=torch.int, device=context_device ) # Set ref counts: block 20 is shared (ref=2), block 10 is exclusive (ref=1). @@ -1310,12 +1313,13 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): initial_avail = ctx.kv_block_allocator.total_avail # Req 0 accepts 1 (rewinds 1), Req 1 accepts 0 (rewinds 2, crosses boundary). - self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._init_mtp_sampling_tensors() self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( [1, 0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Req 1 should have released block 20 (ref count decremented). assert ctx.kv_block_allocator.block_ref_counts[20].item() == 1 @@ -1334,28 +1338,32 @@ def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context + context_device = ctx.request_kv_length_offsets.device ctx.total_request_count = 1 ctx.paused_request_count = 0 - ctx.request_in_prefill_status_tensor = torch.tensor([0], device='cuda') + ctx.request_in_prefill_status_tensor[:1] = torch.tensor( + [0], dtype=torch.int32, device=context_device + ) # 4 blocks. Offset 2 in last block. Rewinding 3 crosses into previous block. 
- ctx.request_kv_length_offsets[:1] = torch.tensor([14], device='cuda') - ctx.request_kv_block_counts[:1] = torch.tensor([4], device='cuda') - ctx.request_last_kv_block_offset[:1] = torch.tensor([2], device='cuda') - ctx.request_last_kv_block_id[:1] = torch.tensor([40], device='cuda') + ctx.request_kv_length_offsets[:1] = torch.tensor([14], device=context_device) + ctx.request_kv_block_counts[:1] = torch.tensor([4], device=context_device) + ctx.request_last_kv_block_offset[:1] = torch.tensor([2], device=context_device) + ctx.request_last_kv_block_id[:1] = torch.tensor([40], device=context_device) ctx.request_to_kv_block_ids[0, :4] = torch.tensor( - [10, 20, 30, 40], dtype=torch.int, device='cuda' + [10, 20, 30, 40], dtype=torch.int, device=context_device ) # Blocks 10, 20 are shared prefix blocks. Block 30, 40 are exclusive. ctx.kv_block_allocator.total_avail = 50 - self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._init_mtp_sampling_tensors() self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( [0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Only block 40 should be released, not blocks 10, 20, or 30. assert ctx.request_kv_block_counts[0].item() == 3 @@ -1372,7 +1380,9 @@ def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): def test_speculative_mtp_position_ids_with_prefill(self): """Test that _compute_serial_mtp_and_sample uses the correct position IDs for a mixed batch of prefill and decode requests.""" - self.setup_model(torch.float32, static=False, num_speculative_tokens=2, max_requests=2) + self.setup_model( + torch.float32, static=False, num_speculative_tokens=2, max_requests=2, mtp_num_layers=2 + ) self.text_generation_controller.num_speculative_tokens = 2 self.text_generation_controller.num_mtp_heads = 2 @@ -1388,7 +1398,7 @@ def test_speculative_mtp_position_ids_with_prefill(self): ctx.request_kv_length_offsets[:2] = torch.tensor([10, 0], dtype=torch.int32, device='cuda') ctx.request_query_lengths[:2] = torch.tensor([3, 15], dtype=torch.int32, device='cuda') - self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._init_mtp_sampling_tensors() # Mock base token sampling (the first tokens fed into MTP) self.text_generation_controller._sampled_tokens_cuda[:2] = torch.tensor( [100, 200], device='cuda' @@ -1403,7 +1413,9 @@ def test_speculative_mtp_position_ids_with_prefill(self): captured_position_ids = [] - def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth): + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None + ): captured_position_ids.append(position_ids.clone()) return hidden_states, torch.randn(2, 1, self.vocab_size, device='cuda') @@ -1463,7 +1475,7 @@ def test_mtp_sp_padding_real_ranks(self, active_request_count): active_request_count, dtype=torch.int32, device='cuda' ) - ctrl._init_mtp_sampling_tensor() + ctrl._init_mtp_sampling_tensors() ctrl._sampled_tokens_cuda[:active_request_count] = torch.remainder( torch.arange(active_request_count, device='cuda'), self.vocab_size ) @@ -1487,7 +1499,9 @@ def test_mtp_sp_padding_real_ranks(self, active_request_count): ctrl._last_accepted_seq_indices = torch.arange(active_request_count, device='cuda') # Greedy sampling: 
top_k=1 selects the argmax token deterministically. - ctrl._torch_sampling_buckets = [(list(range(active_request_count)), 1.0, 1, 0.0)] + ctx.active_request_metadata["temperature"][:active_request_count] = 1.0 + ctx.active_request_metadata["top_k"][:active_request_count] = 1 + ctx.active_request_metadata["top_p"][:active_request_count] = 0.0 # Run the MTP forward pass ctrl._compute_serial_mtp_and_sample() @@ -1539,10 +1553,7 @@ def test_mtp_sp_padding_dummy_ranks(self): dummy_positions = torch.zeros((1, tp_size), device='cuda', dtype=torch.long) hidden_out, logits_out = unwrapped_model.compute_mtp_single_step( - hidden_states=dummy_hidden, - next_token_ids=dummy_tokens, - position_ids=dummy_positions, - depth=0, + hidden_states=dummy_hidden, next_token_ids=dummy_tokens, position_ids=dummy_positions ) # Hidden output is in SP format: [padded_count/tp_size, 1, H] = [1, 1, H]. @@ -1585,7 +1596,6 @@ def test_mtp_sp_dummy_hidden_uses_full_seq_len(self): hidden_states=current_hidden, next_token_ids=dummy_tokens, position_ids=dummy_positions, - depth=depth, ) # Hidden stays in SP format across all depths. @@ -1598,3 +1608,209 @@ def test_mtp_sp_dummy_hidden_uses_full_seq_len(self): f"Depth {depth}: expected logits shape ({tp_size}, 1, {self.vocab_size}), " f"got {logits.shape}" ) + + +class TestTextGenerationControllerParallel(TextGenerationControllerTestBase): + """Tests that require non-default parallel configs (varying tp/pp). + + Each test initializes its own parallel state and tears it down afterward, + so these are separated from TestTextGenerationController to avoid + accumulating NCCL communicator memory from repeated init/destroy cycles. + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def setup_model( + self, + dtype, + symmetric_ar_type=None, + fp8: bool = False, + tensor_model_parallel_size: int = 2, + pipeline_model_parallel_size: int = 1, + batch_size: int = 4, + static: bool = True, + use_training_random_init: bool = False, + materialize_only_last_token_logits: bool = False, + num_speculative_tokens: int = 0, + block_size_tokens: int = 256, + enable_prefix_caching: bool = False, + max_requests: int = None, + mtp_num_layers: int = 0, + sequence_parallel: bool = False, + expert_model_parallel_size: int = 1, + num_moe_experts: int = None, + hybrid_layer_pattern: str = None, + ): + Utils.initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + super().setup_model( + dtype, + symmetric_ar_type=symmetric_ar_type, + fp8=fp8, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + batch_size=batch_size, + static=static, + use_training_random_init=use_training_random_init, + materialize_only_last_token_logits=materialize_only_last_token_logits, + num_speculative_tokens=num_speculative_tokens, + block_size_tokens=block_size_tokens, + enable_prefix_caching=enable_prefix_caching, + max_requests=max_requests, + mtp_num_layers=mtp_num_layers, + sequence_parallel=sequence_parallel, + expert_model_parallel_size=expert_model_parallel_size, + num_moe_experts=num_moe_experts, + hybrid_layer_pattern=hybrid_layer_pattern, + ) + + @pytest.mark.parametrize("static", [True, False]) + @pytest.mark.parametrize("tp_size", [1, 2]) + @pytest.mark.parametrize("pp_size", [1, 2]) + def test_sampled_tokens_match_with_parallelism(self, static, tp_size, pp_size): + """Verify that sampled tokens match across all parallel 
ranks.""" + if tp_size == 1 and pp_size == 1: + pytest.skip(reason="Test requires model parallel size > 1.") + + if not static and not is_fa_min_version("2.7.3"): + pytest.skip(reason="Need latest flash attn for dynamic batching") + + self.setup_model( + dtype=torch.bfloat16, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + static=static, + use_training_random_init=True, + ) + + self.mock_tokenizer.vocab_size = self.vocab_size + self.mock_tokenizer.eod = self.vocab_size - 1 + self.mock_tokenizer.detokenize.side_effect = lambda x, skip_special_tokens=False: ' '.join( + [ + ''.join(random.choices(string.ascii_letters, k=random.randint(4, 10))) + for _ in range(len(x)) + ] + ) + self.mock_tokenizer.offsets.side_effect = lambda _, s: [ + i for i, c in enumerate(s) if c == ' ' + ] + [len(s)] + + # Prepare requests. + active_requests: Dict[str, InferenceRequest] = OrderedDict() + for i in range(self.batch_size): + prompt = "sample" * (i + 1) + prompt_tokens = torch.randint( + low=0, high=self.vocab_size - 1, size=(len(prompt),) + ).tolist() + request_id = str(i) + inference_request = InferenceRequest( + request_id=request_id, + prompt=prompt, + sampling_params=SamplingParams( + top_k=10, num_tokens_to_generate=25, return_log_probs=True + ), + arrival_time=time.time(), + prompt_tokens=prompt_tokens, + status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, + ) + active_requests[request_id] = inference_request + + # Generate tokens. + if static: + requests = self.text_generation_controller.generate_all_output_tokens_static_batch( + active_requests + ) + all_generated_tokens = [req.generated_tokens.tolist() for req in requests.values()] + else: + all_generated_tokens = [[] for _ in range(len(active_requests))] + context = self.text_generation_controller.inference_wrapped_model.inference_context + for request_id, request in active_requests.items(): + context.add_request( + DynamicInferenceRequest( + request_id=int(request_id), + prompt_tokens=torch.tensor( + request.prompt_tokens, + dtype=torch.long, + device=torch.cuda.current_device(), + ), + sampling_params=SamplingParams( + top_k=10, return_log_probs=True, num_tokens_to_generate=25 + ), + ) + ) + expected_active_requests = set(int(x) for x in active_requests.keys()) + while context.has_unfinished_requests(): + result = self.text_generation_controller.generate_output_tokens_dynamic_batch() + new_tokens = result["sample"] + active_ids = result["active_request_ids"].tolist() + finished_ids = result["finished_request_ids"].tolist() + assert len(new_tokens) == len(expected_active_requests) + assert set(active_ids) == expected_active_requests + expected_active_requests -= set(finished_ids) + for i, token in enumerate(new_tokens.tolist()): + all_generated_tokens[i].append(token) + + # Wait for all communication to complete before proceeding. + torch.distributed.barrier() + + # Collect all the generated tokens for each request from each rank in the + # model parallel group. 
+ mp_group = parallel_state.get_model_parallel_group() + mp_ranks = torch.distributed.get_process_group_ranks(mp_group) + local_rank = torch.distributed.get_rank() + tokens_per_rank = {} + tokens_per_rank[local_rank] = all_generated_tokens + + for i in mp_ranks: + if i == local_rank: + batch_size = torch.tensor( + len(tokens_per_rank[local_rank]), + dtype=torch.long, + device=torch.cuda.current_device(), + ) + else: + tokens_per_rank[i] = [] + batch_size = torch.empty(1, dtype=torch.long, device=torch.cuda.current_device()) + torch.distributed.broadcast(batch_size, group=mp_group, src=i) + + for j in range(batch_size.item()): + if i == local_rank: + sequence_length = torch.tensor( + len(tokens_per_rank[local_rank][j]), + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + else: + sequence_length = torch.empty( + 1, dtype=torch.int32, device=torch.cuda.current_device() + ) + torch.distributed.broadcast(sequence_length, group=mp_group, src=i) + + if i == local_rank: + generated_tokens = torch.tensor( + tokens_per_rank[local_rank][j], + dtype=torch.long, + device=torch.cuda.current_device(), + ) + else: + generated_tokens = torch.empty( + sequence_length.item(), dtype=torch.long, device=torch.cuda.current_device() + ) + torch.distributed.broadcast(generated_tokens, group=mp_group, src=i) + + if i != local_rank: + tokens_per_rank[i].append(generated_tokens.tolist()) + + # Ensure that every rank in the model parallel group produced the same tokens. + for i in mp_ranks: + if i == local_rank: + continue + for j, (expected, actual) in enumerate( + zip(tokens_per_rank[local_rank], tokens_per_rank[i]) + ): + assert ( + expected == actual + ), f"Rank {i} tokens differ from rank {local_rank} tokens for request {j}" diff --git a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py index 96b782fad85..229af268a79 100644 --- a/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py +++ b/tests/unit_tests/models/test_dsa_gpt_mamba_equivalence.py @@ -1,6 +1,6 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """ -Equivalence tests: GPTModel with DSA vs MambaModel with DSA pattern. +Equivalence tests: GPTModel with DSA vs HybridModel with DSA pattern. 
A small DeepSeek-V3.2 proxy model (4 GPT layers / 8 Mamba layers) is built, weights are remapped GPT→Mamba, and logprobs are compared to verify they are @@ -9,8 +9,8 @@ Architecture equivalence ------------------------ GPTModel layer N (combined attention + MLP in one TransformerLayer) - ≡ MambaModel layer 2N (D, DSA TransformerLayer: input_layernorm + MLASelfAttention) - + MambaModel layer 2N+1 (-, MLPLayer: fused-norm MLP) + ≡ HybridModel layer 2N (D, DSA TransformerLayer: input_layernorm + MLASelfAttention) + + HybridModel layer 2N+1 (-, MLPLayer: fused-norm MLP) Run with:: @@ -35,10 +35,10 @@ get_transformer_block_with_experimental_attention_variant_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec -from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.hybrid.hybrid_layer_allocation import validate_segment_layers +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.ssm.mamba_hybrid_layer_allocation import validate_segment_layers from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import MLATransformerConfig from megatron.rl.rl_utils import selective_log_softmax @@ -193,15 +193,15 @@ def _build_mamba_model( layer_pattern: str, pre_process: bool = True, post_process: bool = True, -) -> MambaModel: - """Build a MambaModel with the given hybrid layer pattern.""" +) -> HybridModel: + """Build a HybridModel with the given hybrid layer pattern.""" layer_type_list = validate_segment_layers(layer_pattern) mamba_config = copy.deepcopy(config) mamba_config.num_layers = len(layer_type_list) assert mamba_config.num_layers == _NUM_GPT_LAYERS * 2 - model = MambaModel( + model = HybridModel( config=mamba_config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=_VOCAB_SIZE, max_sequence_length=_MAX_SEQ_LEN, pre_process=pre_process, @@ -221,14 +221,14 @@ def _build_mamba_model( def _remap_gpt_to_mamba_state_dict( gpt_sd: Dict[str, torch.Tensor], num_local_gpt_layers: int ) -> Dict[str, torch.Tensor]: - """Remap a GPTModel state_dict to a MambaModel state_dict. + """Remap a GPTModel state_dict to a HybridModel state_dict. GPTModel layer N (combined attention + MLP) maps to: - * MambaModel layer 2N – DSA attention (input_layernorm + self_attention) - * MambaModel layer 2N+1 – MLP (mlp.*) + * HybridModel layer 2N – DSA attention (input_layernorm + self_attention) + * HybridModel layer 2N+1 – MLP (mlp.*) Additionally, ``decoder.final_layernorm.*`` (TransformerBlock naming) is - remapped to ``decoder.final_norm.*`` (MambaStack naming). + remapped to ``decoder.final_norm.*`` (HybridStack naming). All other keys (embedding, output_layer, rotary_pos_emb, …) are unchanged. @@ -238,7 +238,7 @@ def _remap_gpt_to_mamba_state_dict( pipeline stage (i.e. ``len(gpt_model.decoder.layers)``). Returns: - Remapped state dict ready for MambaModel.load_state_dict(strict=True). + Remapped state dict ready for HybridModel.load_state_dict(strict=True). """ mamba_sd: Dict[str, torch.Tensor] = {} layer_prefix = "decoder.layers." 
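    # Reviewer's illustrative sketch (hypothetical helper, not used by this test): the
    # index arithmetic behind the remap described in the docstring above -- GPT layer N
    # feeds hybrid layer 2N (attention) and 2N+1 (MLP), so e.g. the MLP weights of GPT
    # layer 3 land under hybrid layer index 7.
    #
    #     def _remapped_layer_index(gpt_layer: int, is_mlp: bool) -> int:
    #         return 2 * gpt_layer + (1 if is_mlp else 0)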
@@ -380,12 +380,12 @@ def _compare_against_golden_values( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("tp,pp", [(1, 1), (2, 1), (1, 2)]) class TestDSAGPTMambaEquivalence: - """Verify logprob equivalence between GPTModel+DSA and MambaModel+DSA. + """Verify logprob equivalence between GPTModel+DSA and HybridModel+DSA. For each distributed configuration (TP, PP), the test: 1. Builds a GPTModel with 4 DSA layers. - 2. Builds a MambaModel with pattern "D-D-D-D-" (8 layers). - 3. Remaps and loads GPT weights into MambaModel (strict=True). + 2. Builds a HybridModel with pattern "D-D-D-D-" (8 layers). + 3. Remaps and loads GPT weights into HybridModel (strict=True). 4. Runs the same random tokens through both models. 5. Asserts logprob tensors are numerically close. """ @@ -416,7 +416,7 @@ def test_dsa_logprobs_match(self, tp: int, pp: int) -> None: num_local_gpt_layers = len(gpt_model.decoder.layers) gpt_sd = gpt_model.state_dict() - # ---- Build MambaModel ---- + # ---- Build HybridModel ---- mamba_model = _build_mamba_model( gpt_config, _MAMBA_PATTERN, pre_process=pre_process, post_process=post_process ) @@ -481,7 +481,7 @@ def test_weight_loading_strict(self, tp: int, pp: int) -> None: assert not unexpected, f"Unexpected keys: {unexpected}" def test_record_and_compare_golden_values(self, tp: int, pp: int) -> None: - """Record GPTModel logprobs as golden values, then compare MambaModel against them. + """Record GPTModel logprobs as golden values, then compare HybridModel against them. Golden values are written to the functional test directory so they can be committed and used by the CI inference golden-value tests. @@ -508,7 +508,7 @@ def test_record_and_compare_golden_values(self, tp: int, pp: int) -> None: gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) - # Verify MambaModel matches golden values + # Verify HybridModel matches golden values _compare_against_golden_values(mamba_logprobs, gpt_logprobs, abs_tol=1e-3) @@ -520,7 +520,7 @@ def test_record_and_compare_golden_values(self, tp: int, pp: int) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("tp,pp", [(1, 1), (2, 1), (1, 2)]) class TestDSAMoEGPTMambaEquivalence: - """Verify logprob equivalence between GPTModel+DSA+MoE and MambaModel+DSA+MoE. + """Verify logprob equivalence between GPTModel+DSA+MoE and HybridModel+DSA+MoE. 
Architecture: 4 GPT layers with moe_layer_freq=[0,0,1,1] (first 2 dense, last 2 MoE) maps to 8 Mamba layers with pattern "D-D-DEDE": @@ -556,7 +556,7 @@ def test_dsa_moe_logprobs_match(self, tp: int, pp: int) -> None: num_local_gpt_layers = len(gpt_model.decoder.layers) gpt_sd = gpt_model.state_dict() - # ---- Build MambaModel with MoE pattern ---- + # ---- Build HybridModel with MoE pattern ---- mamba_model = _build_mamba_model( gpt_config, _MOE_MAMBA_PATTERN, pre_process=pre_process, post_process=post_process ) @@ -618,7 +618,7 @@ def test_moe_weight_loading_strict(self, tp: int, pp: int) -> None: assert not unexpected, f"Unexpected keys: {unexpected}" def test_moe_record_and_compare_golden_values(self, tp: int, pp: int) -> None: - """Record GPTModel+MoE logprobs as golden values, then compare MambaModel+MoE.""" + """Record GPTModel+MoE logprobs as golden values, then compare HybridModel+MoE.""" self._skip_if_insufficient_gpus(tp, pp) if tp != 1 or pp != 1: pytest.skip("Golden-value recording only runs for tp=1, pp=1") @@ -640,5 +640,5 @@ def test_moe_record_and_compare_golden_values(self, tp: int, pp: int) -> None: gpt_logprobs = _forward_logprobs_pp1(gpt_model, tokens) mamba_logprobs = _forward_logprobs_pp1(mamba_model, tokens) - # Verify MambaModel matches golden values + # Verify HybridModel matches golden values _compare_against_golden_values(mamba_logprobs, gpt_logprobs, abs_tol=1e-3) diff --git a/tests/unit_tests/models/test_hybrid_model.py b/tests/unit_tests/models/test_hybrid_model.py index c4bdc147621..98a53da0314 100644 --- a/tests/unit_tests/models/test_hybrid_model.py +++ b/tests/unit_tests/models/test_hybrid_model.py @@ -15,6 +15,7 @@ from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext from megatron.core.inference.inference_request import DynamicInferenceRequest from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.models.hybrid.hybrid_model import HybridModel from megatron.core.packed_seq_params import PackedSeqParams @@ -523,3 +524,106 @@ def test_dynamic_inference_padding_with_fp8(self): # Assert that all padding logits are zero. assert torch.all(padding_logits == 0.0), "Logits for padding tokens are not all zero." + + +def _make_yarn_config(**kwargs): + """Build a TransformerConfig with yarn positional embedding attributes.""" + cfg = TransformerConfig( + num_layers=3, # 1 Mamba layer, 1 attention layer, 1 MLP layer + hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + **kwargs, + ) + # Yarn-specific attributes are set dynamically on the config (not TransformerConfig fields). 
+ cfg.yarn_rotary_scaling_factor = 2.0 + cfg.yarn_original_max_position_embeddings = 4 + cfg.yarn_beta_fast = 32.0 + cfg.yarn_beta_slow = 1.0 + cfg.yarn_mscale = 1.0 + cfg.yarn_mscale_all_dim = 0.0 + cfg.yarn_correction_range_round_to_int = True + return cfg + + +class TestHybridModelWithYarn: + """Tests for HybridModel with YaRN positional embeddings.""" + + def setup_method(self, method): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + model_config = _make_yarn_config() + self.model = HybridModel( + config=model_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=100, + max_sequence_length=4, + hybrid_layer_pattern="M*-", # 1 Mamba, 1 attention, 1 MLP + position_embedding_type='yarn', + rotary_base=10000, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + def test_constructor(self): + assert isinstance(self.model, HybridModel) + assert self.model.max_sequence_length == 4 + assert self.model.position_embedding_type == 'yarn' + # YaRN creates a YarnRotaryEmbedding rather than a plain RotaryEmbedding. + assert isinstance(self.model.rotary_pos_emb, YarnRotaryEmbedding) + + def test_forward(self): + sequence_length = self.model.max_sequence_length + micro_batch_size = 2 + + self.model.cuda() + + data = list(range(sequence_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + + logits = self.model.forward( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask + ) + + assert logits.shape[0] == micro_batch_size + assert logits.shape[1] == sequence_length + assert logits.shape[2] == self.model.vocab_size + + def test_inference(self): + micro_batch_size = 2 + inference_context: BaseInferenceContext = StaticInferenceContext( + max_batch_size=micro_batch_size, max_sequence_length=self.model.max_sequence_length + ) + prompt_length = self.model.max_sequence_length - 1 + + self.model.cuda() + + # load-context/first-output-token, step/generate + for offset in (0, prompt_length): + sequence_length = prompt_length if offset == 0 else 1 + inference_context.sequence_len_offset = offset + + data = list(range(sequence_length)) + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + position_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool + ).cuda() + + logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inference_context=inference_context, + ) + + assert logits.shape[0] == micro_batch_size + assert logits.shape[1] == sequence_length + assert logits.shape[2] == self.model.vocab_size diff --git a/tests/unit_tests/models/test_hybrid_moe_model.py b/tests/unit_tests/models/test_hybrid_moe_model.py index cb2e3c68c4a..fb270e5cbea 100644 --- a/tests/unit_tests/models/test_hybrid_moe_model.py +++ b/tests/unit_tests/models/test_hybrid_moe_model.py @@ -166,6 +166,7 @@ "mhc_sinkhorn_iterations": 20, "microbatch_group_size_per_vp_stage": 1, "mlp_chunks_for_prefill": 1, + "mlp_chunks_for_training": 1, "moe_apply_probs_on_input": False, "moe_aux_loss_coeff": 0.0, "moe_deepep_num_sms": 20, @@ -318,20 +319,23 @@ "max_seqlen_per_dp_cp_rank": None, 
"fallback_to_eager_attn": False, "inference_disable_triton_nvls_kernels": False, + "moe_router_force_biased": None, "inference_grouped_gemm_backend": "auto", "inference_moe_disable_fused_quant_kernels": False, - "linear_attention_type": None, - "moe_mlp_glu_interleave_size": None, - "moe_router_force_biased": None, - "sequence_packing_scheduler": None, - "use_transformer_engine_op_fuser": False, - "moe_single_grouped_weight": False, - "moe_single_grouped_bias": False, + "inference_moe_token_dispatcher_type": "nvls", } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set() # Fields that are allowed to appear in the live config even if not yet in the golden. -ALLOW_ADDED_FIELDS = set() +ALLOW_ADDED_FIELDS = { + "linear_attention_type", + "moe_hybridep_num_sms_preprocessing", + "moe_mlp_glu_interleave_size", + "moe_single_grouped_bias", + "moe_single_grouped_weight", + "sequence_packing_scheduler", + "use_transformer_engine_op_fuser", +} def serialize_config(cfg: Any) -> Dict[str, Any]: diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 44be0c7911e..836382b21cc 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -60,6 +60,27 @@ _embedding_pg_cache: dict = {} +def build_no_sync_func(mimo_model): + """Build a no_sync_func that stacks DDP no_sync over each sub-module. + + Shared by 1F1B pipeline tests and colocated-correctness tests — both need + DDP's gradient sync disabled during microbatches and resumed via the + schedule's finalize_grads_func. + """ + + @contextmanager + def no_sync_func(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + stack.enter_context(submodule.no_sync()) + yield + + return no_sync_func + + def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): """Create a HyperCommGrid with specified parallelism.""" grid = HyperCommGrid( @@ -183,13 +204,40 @@ def is_rank_in_grid(grid): def get_language_model_spec( - num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection + num_layers, + hidden_size, + num_attention_heads, + vocab_size, + seq_len, + pg_collection, + bf16=True, + bias=True, + dropout=True, + per_token_loss=False, ): - """Get the language model spec.""" + """Get the language model spec. + + ``bf16=False`` switches pipeline dtype and autocast to fp32. Correctness + tests also pass ``bias=False, dropout=False`` to remove bias-update and + stochastic noise from the cross-config diff signal. + + ``per_token_loss=True`` sets ``calculate_per_token_loss=True`` on the + TransformerConfig, which pins DDP's gradient_scaling_factor to 1.0 + (pure SUM reduction). Callers that flip this must supply a 3-tuple + loss_func and drive the external divide in their finalize hook. 
+ """ pp_rank = dist.get_rank(pg_collection.pp) pp_size = dist.get_world_size(pg_collection.pp) tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 + pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 + extra_kwargs = {} + if not bias: + extra_kwargs['add_bias_linear'] = False + if not dropout: + extra_kwargs['attention_dropout'] = 0.0 + extra_kwargs['hidden_dropout'] = 0.0 + lm_config = TransformerConfig( num_layers=num_layers, hidden_size=hidden_size, @@ -199,10 +247,12 @@ def get_language_model_spec( moe_token_dispatcher_type='alltoall', tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size, - pipeline_dtype=torch.bfloat16, - bf16=True, + pipeline_dtype=pipeline_dtype, + bf16=bf16, cross_entropy_loss_fusion=True, cross_entropy_fusion_impl='te', + calculate_per_token_loss=per_token_loss, + **extra_kwargs, ) return ModuleSpec( module=GPTModel, @@ -218,12 +268,12 @@ def get_language_model_spec( ) -def get_projection_config(hidden_size): +def get_projection_config(hidden_size, bias=True): """Return a TransformerConfig for the vision projection MLP.""" cfg = TransformerConfig(num_layers=1, hidden_size=hidden_size, num_attention_heads=1) cfg.ffn_hidden_size = hidden_size - cfg.bias_activation_fusion = True - cfg.add_bias_linear = True + cfg.bias_activation_fusion = bool(bias) + cfg.add_bias_linear = bool(bias) cfg.activation_func = torch.nn.functional.gelu return cfg @@ -239,15 +289,38 @@ def get_projection_layer_spec(): def get_vision_submodules_spec( - num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection + num_layers, + hidden_size, + num_attention_heads, + language_hidden_size, + pg_collection, + bf16=True, + bias=True, + dropout=True, + per_token_loss=False, ): - """Get the submodule spec for the vision modality.""" + """Get the submodule spec for the vision modality. + + ``bias=False`` / ``dropout=False`` mirror the LM-spec kwargs for + correctness tests. ``per_token_loss=True`` sets + ``calculate_per_token_loss=True`` on the encoder's TransformerConfig so + the encoder DDP also pure-SUMs across DP (needed for the heterogeneous-DP + colocated path). 
+ """ from megatron.core.transformer.transformer_block import TransformerBlock tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 pp_rank = dist.get_rank(pg_collection.pp) + pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 + extra_kwargs = {} + if not bias: + extra_kwargs['add_bias_linear'] = False + if not dropout: + extra_kwargs['attention_dropout'] = 0.0 + extra_kwargs['hidden_dropout'] = 0.0 + vision_config = TransformerConfig( num_layers=num_layers, hidden_size=hidden_size, @@ -257,8 +330,10 @@ def get_vision_submodules_spec( moe_token_dispatcher_type='alltoall', tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size, - pipeline_dtype=torch.bfloat16, - bf16=True, + pipeline_dtype=pipeline_dtype, + bf16=bf16, + calculate_per_token_loss=per_token_loss, + **extra_kwargs, ) vision_encoder_spec = ModuleSpec( module=TransformerBlock, @@ -274,7 +349,7 @@ def get_vision_submodules_spec( vision_projection_spec = ModuleSpec( module=MultimodalProjector, params={ - "config": get_projection_config(hidden_size=language_hidden_size), + "config": get_projection_config(hidden_size=language_hidden_size, bias=bias), "submodules": get_projection_layer_spec().submodules, "projector_type": "mlp", "input_size": vision_config.hidden_size, @@ -293,9 +368,38 @@ def get_vision_submodules_spec( def get_mimo_model( - encoder_name, encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seq_len + encoder_name, + encoder_grid, + llm_grid, + hidden_size, + num_layers, + vocab_size, + seq_len, + ddp_config=None, + bf16=True, + bias=True, + dropout=True, + per_token_loss=False, ): - """Create MIMO model with TransformerBlock encoder and GPTModel LLM.""" + """Create MIMO model with TransformerBlock encoder and GPTModel LLM. + + Args: + ddp_config: Optional override for the Megatron DDP config. Default + matches the 1F1B schedule tests' config. + bf16: If True (default) build the model in bf16; if False build in + fp32 end-to-end for deterministic numerics in correctness tests. + bias: If False, disable ``add_bias_linear`` in LM/vision configs and + the projection MLP — removes bias-update noise from diffs. + dropout: If False, force attention/hidden dropout to 0.0. + per_token_loss: If True, set ``calculate_per_token_loss=True`` on + both sub-model configs. This pins the encoder and LLM DDP + gradient_scaling_factor to 1.0 (pure SUM across DP). The caller + MUST supply a 3-tuple loss_func ``(sum_loss, num_tokens, + log_dict)`` and a custom ``finalize_model_grads_func`` that + divides grads by the correct global divisor on both sides; + hetero-DP callers use this to land ``1/B_full`` on both encoder + and LLM without relying on the per-DDP built-in scaling. 
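+
+    A minimal sketch of the loss_func shape that ``per_token_loss=True``
+    expects (illustrative; the colocated-correctness tests define the real
+    one)::
+
+        def per_token_loss_func(loss_mask, output_tensor):
+            # Raw local sum plus local valid-token count; no division here;
+            # the finalize hook divides by the global token count later.
+            local_sum = (output_tensor.float() * loss_mask.float()).sum()
+            local_num_tokens = loss_mask.float().sum().to(torch.int)
+            log = {'loss_reduced': local_sum.detach().item()}
+            return local_sum, local_num_tokens, log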
+ """ language_pg = get_pg_collection_with_embedding_groups(llm_grid, is_language_model=True) vision_pg = get_pg_collection_with_embedding_groups(encoder_grid, is_language_model=False) @@ -306,6 +410,10 @@ def get_mimo_model( vocab_size=vocab_size, seq_len=seq_len, pg_collection=language_pg, + bf16=bf16, + bias=bias, + dropout=dropout, + per_token_loss=per_token_loss, ) vision_submodule_spec = get_vision_submodules_spec( num_layers=num_layers, @@ -313,6 +421,10 @@ def get_mimo_model( num_attention_heads=8, language_hidden_size=hidden_size, pg_collection=vision_pg, + bf16=bf16, + bias=bias, + dropout=dropout, + per_token_loss=per_token_loss, ) module_to_grid_map = {encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid} @@ -326,12 +438,15 @@ def get_mimo_model( ) mimo_model = MimoModel(mimo_config) - mimo_model.to(torch.device("cuda")).to(torch.bfloat16) - - # Wrap with DDP - ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True - ) + mimo_model.to(torch.device("cuda")) + if bf16: + mimo_model.to(torch.bfloat16) + + # Wrap with DDP (caller may override e.g. for heterogeneous-DP scaling). + if ddp_config is None: + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True + ) if mimo_model.language_model is not None: mimo_model.language_model = DistributedDataParallel( @@ -485,16 +600,7 @@ def run_mimo_1f1b_test( seq_len=seq_length, ) - # Build schedule functions using pre-created pg_collections (no leaks) - @contextmanager - def no_sync_func(): - with ExitStack() as stack: - if mimo_model.language_model is not None: - stack.enter_context(mimo_model.language_model.no_sync()) - for submodule in mimo_model.modality_submodules.values(): - if submodule is not None: - stack.enter_context(submodule.no_sync()) - yield + no_sync_func = build_no_sync_func(mimo_model) def finalize_grads_func(*args, **kwargs): if mimo_model.language_model is not None: @@ -595,30 +701,35 @@ def loss_func(loss_mask, output_tensor): optimizer.zero_grad() - losses = schedule.forward_backward_pipelining_without_interleaving( - forward_step_func=step_func, - data_iterator=data_iterator, - model=[mimo_model], - num_microbatches=num_microbatches, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - forward_only=False, - p2p_communicator=communicator, - pg_collection=pg_collection, - ) - - # Optimizer step with global gradient clipping - success, grad_norm, num_zeros = optimizer.step() - assert success, "Optimizer step failed" - assert grad_norm is not None and grad_norm > 0, f"Expected positive grad norm, got {grad_norm}" - - # Verify results on last LLM stage - if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): - assert len(losses) > 0, "Expected losses on last LLM stage" - for loss_dict in losses: - assert 'loss_reduced' in loss_dict + try: + losses = schedule.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=data_iterator, + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=pg_collection, + ) - return losses + # Optimizer step with global gradient clipping + success, grad_norm, num_zeros = optimizer.step() + assert success, "Optimizer step failed" + assert ( + grad_norm is not None and grad_norm > 0 + ), f"Expected positive grad norm, got {grad_norm}" + + # Verify 
results on last LLM stage + if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): + assert len(losses) > 0, "Expected losses on last LLM stage" + for loss_dict in losses: + assert 'loss_reduced' in loss_dict + + return losses + finally: + mimo_model.destroy() # ============================================================================ diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py new file mode 100644 index 00000000000..67cee551a0f --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -0,0 +1,543 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +import logging +import os +import sys + +import pytest +import torch +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.mimo.comm.colocated_communicator import ColocatedBridgeCommunicator + +logging.basicConfig(level=logging.DEBUG, stream=sys.stderr) + +_active_grids: list = [] +_active_comms: list = [] + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + grid = HyperCommGrid( + shape=[tp, cp, pp, dp], + dim_names=["tp", "cp", "pp", "dp"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + _active_grids.append(grid) + return grid + + +def make_comm(*args, **kwargs): + comm = ColocatedBridgeCommunicator(*args, **kwargs) + _active_comms.append(comm) + return comm + + +def destroy_all_grids(): + # Destroy communicators first so their NCCL subgroups are freed before we + # tear down the parent grids. NCCL caps concurrent communicators at ~500; + # leaked PGs from per-test fixtures blow that budget quickly. + for comm in _active_comms: + comm.destroy() + _active_comms.clear() + for grid in _active_grids: + grid.destroy() + _active_grids.clear() + + +# ── Test 1: Rank mappings ────────────────────────────────────────────────────── + + +class TestRankMappings: + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + @pytest.mark.parametrize( + "src_tp, src_dp, dest_tp, dest_dp, expected_src_pos, expected_dest_pos", + [ + # Fan-in: TP2/DP4 → TP4/DP2 + ( + 2, + 4, + 4, + 2, + { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + 4: (2, 0), + 5: (2, 1), + 6: (3, 0), + 7: (3, 1), + }, + { + 0: (0, 0), + 1: (0, 1), + 2: (0, 2), + 3: (0, 3), + 4: (1, 0), + 5: (1, 1), + 6: (1, 2), + 7: (1, 3), + }, + ), + # Fan-out: TP4/DP2 → TP2/DP4 + ( + 4, + 2, + 2, + 4, + { + 0: (0, 0), + 1: (0, 1), + 2: (0, 2), + 3: (0, 3), + 4: (1, 0), + 5: (1, 1), + 6: (1, 2), + 7: (1, 3), + }, + { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + 4: (2, 0), + 5: (2, 1), + 6: (3, 0), + 7: (3, 1), + }, + ), + ], + ids=["fan_in", "fan_out"], + ) + def test_rank_mappings( + self, src_tp, src_dp, dest_tp, dest_dp, expected_src_pos, expected_dest_pos + ): + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid) + + assert comm.rank_to_src_pos == expected_src_pos + assert comm.rank_to_dest_pos == expected_dest_pos + + def test_rank_mappings_with_rank_offset(self): + # 4-rank grids at offset=4 (covering ranks 4-7). 
Exercises the + # rank_offset propagation that previously only ran with offset=0. + if dist.get_world_size() < 8: + pytest.skip("requires at least 8 ranks") + src_grid = create_hypercomm_grid(offset=4, tp=2, dp=2) + dest_grid = create_hypercomm_grid(offset=4, tp=1, dp=4) + comm = make_comm(src_grid, dest_grid) + + assert comm.rank_to_src_pos == {4: (0, 0), 5: (0, 1), 6: (1, 0), 7: (1, 1)} + assert comm.rank_to_dest_pos == {4: (0, 0), 5: (1, 0), 6: (2, 0), 7: (3, 0)} + + +# ── Test 2: All-gather groups ────────────────────────────────────────────────── + + +class TestAllGatherGroups: + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + def test_fan_in_all_gather_groups(self): + # Fan-in TP2/DP4 → TP4/DP2. Groups are keyed (dest_dp_idx, src_tp_idx) + # and members must appear in src_dp_idx order so all_gather_into_tensor + # concatenates in slot order on the backward path. + src_grid = create_hypercomm_grid(tp=2, dp=4) + dest_grid = create_hypercomm_grid(tp=4, dp=2) + comm = make_comm(src_grid, dest_grid) + + assert comm.gather_group_ranks == [[0, 2], [1, 3], [4, 6], [5, 7]] + assert comm.gather_pg is not None + + def test_fan_out_gather_groups(self): + # Fan-out TP4/DP2 → TP2/DP4. Groups are keyed (src_dp_idx, dest_tp_idx); + # membership order must track dest_dp_idx so the backward all-gather + # reconstructs the full-batch gradient in the correct layout. + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, dp=4) + comm = make_comm(src_grid, dest_grid) + + assert comm.gather_group_ranks == [[0, 2], [1, 3], [4, 6], [5, 7]] + assert comm.gather_pg is not None + + +# ── Test 3b: _validate_grids negative tests ─────────────────────────────────── + + +class TestValidateGrids: + """One negative test per raise path in ColocatedBridgeCommunicator._validate_grids. + + Each case builds a pair of grids that violates exactly one invariant and + asserts that the constructor raises ValueError. + """ + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + def _grid_missing_tp(self, offset=0, dp=1): + # Build a grid without a 'tp' dim to exercise the "missing 'tp'" raise. 
+ grid = HyperCommGrid(shape=[dp], dim_names=["dp"], rank_offset=offset, backend="nccl") + grid.create_pg(["dp"]) + _active_grids.append(grid) + return grid + + def test_missing_tp_dim(self): + src_grid = self._grid_missing_tp(dp=8) + dest_grid = create_hypercomm_grid(tp=4, dp=2) + with pytest.raises(ValueError, match="must have 'tp' dimension"): + make_comm(src_grid, dest_grid) + + def test_size_mismatch(self): + src_grid = create_hypercomm_grid(tp=2, dp=4) # 8 ranks + dest_grid = create_hypercomm_grid(offset=4, tp=2, dp=2) # 4 ranks + with pytest.raises(ValueError, match="span same number of ranks"): + make_comm(src_grid, dest_grid) + + def test_rank_offset_mismatch(self): + src_grid = create_hypercomm_grid(offset=0, tp=2, dp=2) + dest_grid = create_hypercomm_grid(offset=4, tp=2, dp=2) + with pytest.raises(ValueError, match="same rank offset"): + make_comm(src_grid, dest_grid) + + @pytest.mark.parametrize( + "side,dim,expected", + [ + ("src", "pp", "src PP must be 1"), + ("dest", "pp", "dest PP must be 1"), + ("src", "cp", "CP must be 1"), + ], + ) + def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): + bad = {dim: 2, "tp": 2, "dp": 2} + good = {"tp": 4, "dp": 2} + if side == "src": + src_grid = create_hypercomm_grid(**bad) + dest_grid = create_hypercomm_grid(**good) + else: + src_grid = create_hypercomm_grid(**good) + dest_grid = create_hypercomm_grid(**bad) + with pytest.raises(ValueError, match=expected): + make_comm(src_grid, dest_grid) + + def test_dp_not_divisible(self): + # 6-rank grids with DP sizes (3 vs 2) that neither divides the other. + # Fits inside an 8-rank world (HyperCommGrid enforces size <= world - offset). + if dist.get_world_size() < 6: + pytest.skip("requires at least 6 ranks") + src_grid = HyperCommGrid( + shape=[2, 1, 1, 3], dim_names=["tp", "cp", "pp", "dp"], backend="nccl" + ) + dest_grid = HyperCommGrid( + shape=[3, 1, 1, 2], dim_names=["tp", "cp", "pp", "dp"], backend="nccl" + ) + for g in (src_grid, dest_grid): + _active_grids.append(g) + with pytest.raises(ValueError, match="evenly divisible"): + make_comm(src_grid, dest_grid) + + +# ── Test 3c: communicate() runtime preconditions ────────────────────────────── + + +class TestCommunicatePreconditions: + """Runtime-input checks enforced by ``communicate()``.""" + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + def test_non_divisible_batch_raises_fan_out(self): + # Fan-out: dest_dp=4, src_dp=2 → scale=2. Pass a batch dim of size 3 + # so 3 % 2 != 0 and the forward communicate() raises before slicing. + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, dp=4) + comm = make_comm(src_grid, dest_grid, dim_mapping={'b': 0, 'h': 1}) + tensor = torch.zeros(3, 8, device='cuda') + with pytest.raises(ValueError, match="not divisible by fan_out"): + comm.communicate(tensor) + + def test_non_divisible_batch_raises_fan_in_backward_narrow(self): + # Fan-in forward all-gathers (no slice), so the forward path never + # divides. The backward path narrows the post-gather output via + # get_slice_info, which raises on a non-divisible size. Call + # get_slice_info directly with an odd size to exercise that path. 
+ src_grid = create_hypercomm_grid(tp=2, dp=4) + dest_grid = create_hypercomm_grid(tp=4, dp=2) + comm = make_comm(src_grid, dest_grid) + with pytest.raises(ValueError, match="not divisible by fan_in"): + comm.get_slice_info(batch_size=3) + + +# ── Test 3d: destroy() releases PGs ────────────────────────────────────────── + + +class TestDestroy: + """``destroy()`` must null out both PG attributes.""" + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + def test_destroy_releases_fan_in_pg(self): + src_grid = create_hypercomm_grid(tp=2, dp=4) + dest_grid = create_hypercomm_grid(tp=4, dp=2) + # Don't track via make_comm — destroy() is exactly what we're testing. + comm = ColocatedBridgeCommunicator(src_grid, dest_grid) + assert comm.gather_pg is not None + comm.destroy() + assert comm.gather_pg is None + + def test_destroy_releases_fan_out_pg(self): + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, dp=4) + comm = ColocatedBridgeCommunicator(src_grid, dest_grid) + assert comm.gather_pg is not None + comm.destroy() + assert comm.gather_pg is None + + def test_destroy_is_idempotent(self): + # Calling destroy twice must not raise — leftover test fixtures often + # double-destroy during exception cleanup. + src_grid = create_hypercomm_grid(tp=2, dp=4) + dest_grid = create_hypercomm_grid(tp=4, dp=2) + comm = ColocatedBridgeCommunicator(src_grid, dest_grid) + comm.destroy() + comm.destroy() + + +# ── Test 3e: Bridge gradient correctness (bitwise exact) ───────────────────── + + +def _shape_for_dim_mapping(dim_mapping, B, S, H): + s = [0, 0, 0] + s[dim_mapping['b']] = B + s[dim_mapping['s']] = S + s[dim_mapping['h']] = H + return s + + +# Parametrize dim_mapping for the fan-in tests (tests 1 & 2 per AXIOM spec). +_DIM_MAPPINGS = [{'s': 0, 'b': 1, 'h': 2}, {'b': 0, 's': 1, 'h': 2}] +_DIM_MAPPING_IDS = ["sbh", "bsh"] + + +class TestBridgeGradients: + """Bitwise-exact gradient tests for ``ColocatedBridgeCommunicator``. + + This class is **intentionally distinct** from the model-level correctness + tests in ``test_mimo_colocated_correctness.py`` (see PR review comment 19). + The bridge forward and backward are pure data + movement (``narrow`` / ``all_gather_into_tensor``) with no FP compute, so + the mathematical adjoint relationship can — and should — be asserted at + ``rtol=0, atol=0``: + + * fan-in forward == ``torch.cat`` of sibling inputs in slot order + * fan-in backward == ``grad_output.narrow`` at this rank's slot + * fan-out forward == ``input.narrow`` at this rank's slot + * fan-out backward == ``cat`` of every sibling's grad (catches + zero-pad-without-gather, wrong slot order, double-counting, + missing siblings — the four failure modes of the adjoint) + * equal-DP is a pure identity (forward + backward) + + The MimoModel-level tests validate the full training stack including GEMM + reduction order and DDP scaling, and can only assert approximate FP32 + closeness. These tests localise the bridge's own invariants and fail + first when one of them regresses. 
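+
+    Stated as the adjoint pair for fan-in with scale ``k`` (fan-out mirrors
+    it): the forward is ``y = cat(x_0, ..., x_{k-1})`` along the batch dim,
+    so the backward for slot ``i`` must be exactly
+    ``dL/dx_i = dL/dy.narrow(batch_dim, i * b_local, b_local)``; any
+    deviation is visible bitwise because no floating-point arithmetic is
+    involved.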
+ """ + + S = 8 + B_PER_RANK = 2 + H = 128 + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + # ── Test 1: fan-in forward = torch.cat of sibling inputs ───────────────── + @pytest.mark.parametrize("src_tp,src_dp,dest_tp,dest_dp", [(2, 4, 4, 2)], ids=["2x_fan_in"]) + @pytest.mark.parametrize("dim_mapping", _DIM_MAPPINGS, ids=_DIM_MAPPING_IDS) + def test_fan_in_forward_equals_torch_cat(self, src_tp, src_dp, dest_tp, dest_dp, dim_mapping): + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + rank = dist.get_rank() + shape = _shape_for_dim_mapping(dim_mapping, self.B_PER_RANK, self.S, self.H) + + # Distinct inputs per rank so the cat reveals ordering bugs. + torch.manual_seed(1000 + rank) + local_input = torch.randn(*shape, device='cuda') + + actual = comm.communicate(local_input) + + # Expected: manual all_gather over the communicator's fan-in group, + # then cat along batch_dim. all_gather preserves group-local-rank + # order, which is the same order the communicator uses. + group = comm.gather_pg + gathered = [torch.empty_like(local_input) for _ in range(dist.get_world_size(group))] + dist.all_gather(gathered, local_input, group=group) + expected = torch.cat(gathered, dim=dim_mapping['b']) + + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + # ── Test 2: fan-in backward = grad_output.narrow for this rank's slot ──── + @pytest.mark.parametrize("src_tp,src_dp,dest_tp,dest_dp", [(2, 4, 4, 2)], ids=["2x_fan_in"]) + @pytest.mark.parametrize("dim_mapping", _DIM_MAPPINGS, ids=_DIM_MAPPING_IDS) + def test_fan_in_backward_equals_narrow(self, src_tp, src_dp, dest_tp, dest_dp, dim_mapping): + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + rank = dist.get_rank() + batch_dim = dim_mapping['b'] + b_local = self.B_PER_RANK + shape = _shape_for_dim_mapping(dim_mapping, b_local, self.S, self.H) + + torch.manual_seed(1000 + rank) + local_input = torch.randn(*shape, device='cuda', requires_grad=True) + out = comm.communicate(local_input) + + # grad_output is TP-replicated within the dest DP group: seed the same + # on every rank so every rank in the fan-in group backward-narrows the + # same upstream gradient. out shape is identical across group members, + # so seeded randn produces the same tensor on each. 
+ torch.manual_seed(42) + grad_output = torch.randn_like(out) + out.backward(grad_output) + + slot = comm.rank_to_src_pos[rank][0] % comm.scale + expected = grad_output.narrow(batch_dim, slot * b_local, b_local).contiguous() + torch.testing.assert_close(local_input.grad, expected, rtol=0, atol=0) + + # ── Test 3: fan-out forward = input.narrow for this rank's slot ───────── + @pytest.mark.parametrize("src_tp,src_dp,dest_tp,dest_dp", [(4, 2, 2, 4)], ids=["2x_fan_out"]) + def test_fan_out_forward_equals_narrow(self, src_tp, src_dp, dest_tp, dest_dp): + dim_mapping = {'b': 0, 's': 1, 'h': 2} + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + rank = dist.get_rank() + batch_dim = dim_mapping['b'] + b_per_dest = self.B_PER_RANK + b_full = b_per_dest * comm.scale + shape = _shape_for_dim_mapping(dim_mapping, b_full, self.S, self.H) + + # Input is TP-replicated on the batch dim (bridge contract). Seed + # identically across all ranks to satisfy it. + torch.manual_seed(42) + input_tensor = torch.randn(*shape, device='cuda') + + actual = comm.communicate(input_tensor) + + slot = comm.rank_to_dest_pos[rank][0] % comm.scale + expected = input_tensor.narrow(batch_dim, slot * b_per_dest, b_per_dest).contiguous() + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + # ── Test 4 (CRITICAL): fan-out backward = concat of all sibling grads ── + @pytest.mark.parametrize("src_tp,src_dp,dest_tp,dest_dp", [(4, 2, 2, 4)], ids=["2x_fan_out"]) + def test_fan_out_backward_equals_concat_of_sibling_grads( + self, src_tp, src_dp, dest_tp, dest_dp + ): + """Fan-out backward must all-gather sibling grads in slot order. + + Catches four distinct regressions with a single assertion: + * zero-pad-without-gather (other slots would be zero), + * wrong slot order (values would be scrambled), + * double-counting (values would be multiplied), + * missing siblings (shape or zeros would diverge). + """ + dim_mapping = {'b': 0, 's': 1, 'h': 2} + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + rank = dist.get_rank() + batch_dim = dim_mapping['b'] + scale = comm.scale + b_per_dest = self.B_PER_RANK + b_full = b_per_dest * scale + shape = _shape_for_dim_mapping(dim_mapping, b_full, self.S, self.H) + + torch.manual_seed(42) # identical input across ranks (TP-replicated) + input_tensor = torch.randn(*shape, device='cuda', requires_grad=True) + out = comm.communicate(input_tensor) # narrowed to (b_per_dest, S, H) + + # Distinct grad per slot so the cat reveals both membership and order. 
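+        # e.g. with scale=2: slot-0 ranks push an all-ones grad, slot-1 ranks
+        # push all-twos, so the gathered input grad must read [1...1, 2...2]
+        # along the batch dim (the expected cat below spells this out).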
+ slot = comm.rank_to_dest_pos[rank][0] % scale + grad_output = (slot + 1) * torch.ones_like(out) + out.backward(grad_output) + + slot_shape = _shape_for_dim_mapping(dim_mapping, b_per_dest, self.S, self.H) + expected = torch.cat( + [(i + 1) * torch.ones(*slot_shape, device='cuda') for i in range(scale)], dim=batch_dim + ) + torch.testing.assert_close(input_tensor.grad, expected, rtol=0, atol=0) + + # ── Test 5: equal DP is a pure identity forward and backward ──────────── + @pytest.mark.parametrize("src_tp,src_dp,dest_tp,dest_dp", [(4, 2, 4, 2)], ids=["tp4_dp2"]) + def test_equal_dp_is_bitwise_identity_fwd_and_bwd(self, src_tp, src_dp, dest_tp, dest_dp): + dim_mapping = {'b': 0, 's': 1, 'h': 2} + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + shape = _shape_for_dim_mapping(dim_mapping, self.B_PER_RANK, self.S, self.H) + torch.manual_seed(1000 + dist.get_rank()) + x = torch.randn(*shape, device='cuda', requires_grad=True) + + out = comm.communicate(x) + torch.testing.assert_close(out, x, rtol=0, atol=0) + + grad_output = torch.randn_like(x) + out.backward(grad_output) + torch.testing.assert_close(x.grad, grad_output, rtol=0, atol=0) diff --git a/tests/unit_tests/models/test_mimo_colocated_correctness.py b/tests/unit_tests/models/test_mimo_colocated_correctness.py new file mode 100644 index 00000000000..e2d91bdf83e --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_correctness.py @@ -0,0 +1,1183 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Gradient-scaling correctness for colocated MimoModel under heterogeneous DP. + +Verifies that a heterogeneous-DP MimoModel produces the same post-step +encoder weights as an **equal-DP** reference built on the SAME encoder +TP/DP layout as the dist model (so the bridge is the identity +passthrough — ``BridgeDirection.EQUAL`` in +``ColocatedBridgeCommunicator``). Under correct grad scaling, both +configs yield the DP=1 gradient on every encoder shard, so the Adam +update lands on identical values and the sharded post-step weights +compare directly. + +Why an equal-DP reference is the right oracle: + * Encoder sharding matches exactly — ref and dist both use + ``enc_tp=dist_enc_tp, enc_dp=dist_enc_dp``. Shards line up 1:1, + so there is no gather-and-slice in the weight comparison and no + TP=1-vs-TP>1 accumulation-order drift to contend with. + * ``enc_dp == llm_dp`` on the ref side → the bridge is identity and + every encoder rank feeds its colocated LLM rank with no + redistribution collective. + * Both sides set ``calculate_per_token_loss=True`` on their + TransformerConfigs, which pins DDP's ``gradient_scaling_factor=1.0`` + — pure SUM across DP. The custom + ``finalize_grads_func`` in ``_wire_training_hooks`` all-reduces + ``total_num_tokens`` over the LLM DP group, then calls + ``scale_gradients(1/N_global)`` on both encoder and LLM. This lands + the true global per-token mean on every shard without touching + ``DistributedDataParallel``. + +LLM TP differs between ref (``llm_tp=dist_enc_tp``) and dist +(``llm_tp=dist_llm_tp``), so ref's LLM weights are copied into dist via +all-gather-across-ref-TP + slice-for-dist-TP. The LLM forward then +diverges numerically by fp32 TP accumulation order, but the aggregate +gradient that flows back into the encoder remains the DP=1 gradient in +both models, which is what the post-step encoder weight oracle checks. 
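+
+Concretely, with ``calculate_per_token_loss=True`` the target gradient on
+every encoder shard is the sum of per-token gradients over all non-masked
+tokens in the global batch divided by N_global, the count of those tokens.
+That value does not depend on how DP partitions the batch, which is why the
+equal-DP reference and the heterogeneous-DP model must land on the same
+post-step encoder weights.
+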
+The test runs in fp32 with ``add_bias_linear=False`` and dropout +disabled to minimize non-bridge numerical noise — this keeps the +post-bridge hidden states bit-exact and surfaces only TP-shape drift +in the logits oracle. + +If the heterogeneous-DP scaling is wrong (e.g. dividing by encoder_dp +when it should be 1, or letting either DDP apply its default ``1/dp_size`` +on top of the per-token mean already delivered by the finalize hook), +the dist encoder's post-step weights diverge from the ref encoder's +weights — a single Adam step is enough to detect. + +Run with:: + + uv run python -m torch.distributed.run --nproc_per_node=8 \\ + -m pytest tests/unit_tests/models/test_mimo_colocated_correctness.py -v -s +""" + +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.transformer.enums import ModelType +from megatron.core.utils import unwrap_model +from tests.unit_tests.models.test_mimo_1f1b_schedule import ( + build_no_sync_func, + create_all_embedding_groups, + create_hypercomm_grid, + destroy_all_grids, + get_mimo_model, +) +from tests.unit_tests.test_utilities import Utils + + +def loss_func(loss_mask, output_tensor): + """Per-token-loss 3-tuple: raw local sum + local valid-token count. + + Returns ``(local_sum, local_num_tokens, log_dict)`` — the contract the + schedule expects when ``calculate_per_token_loss=True`` is set on the + TransformerConfig. No ``1/num_tokens`` or ``1/num_microbatches`` + division is applied here; the schedule skips the per-microbatch + division (see ``schedules.py:270-274``) and aggregates ``num_tokens`` + across microbatches for the finalize step. + + Paired with ``mimo_finalize_grads_func`` below, which all-reduces + ``total_num_tokens`` over the LLM DP group to obtain ``N_global`` and + then divides both encoder and LLM grads by ``1/N_global`` directly via + ``scale_gradients`` — landing the true global per-token mean on every + shard without touching DDP. + + ``output_tensor`` is per-token CE from + ``GPTModel.compute_language_model_loss`` with shape ``[b, s]``. + """ + if output_tensor is None: + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} + + +def forward_step(data_iterator, model, encoder_grid, llm_grid, encoder_name): + """Forward step with per-rank data slicing for heterogeneous DP.""" + batch = next(data_iterator) if data_iterator is not None else {'input_ids': None} + + if batch.get('input_ids') is None: + output_tensor, loss_mask = model(**batch) + return output_tensor, partial(loss_func, loss_mask) + + encoder_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + + if encoder_dp > llm_dp: + # Fan-in: input was pre-sliced to LLM-DP (larger per-rank batch). 
+ # Narrow modality_inputs to the encoder's smaller per-rank slice. + scale = encoder_dp // llm_dp + encoder_dp_idx = encoder_grid.get_pg("dp").rank() + slot = encoder_dp_idx % scale + + if 'modality_inputs' in batch and batch['modality_inputs'] is not None: + for mod_name, mod_data in batch['modality_inputs'].items(): + for enc_name, enc_data in mod_data.items(): + for key, tensor in enc_data.items(): + if tensor is not None and isinstance(tensor, torch.Tensor): + batch_size = tensor.shape[1] # [seq, batch, hidden] + slice_size = batch_size // scale + start = slot * slice_size + enc_data[key] = tensor[:, start : start + slice_size, :].contiguous() + + elif llm_dp > encoder_dp: + # Fan-out: input was pre-sliced to encoder-DP (larger per-rank batch). + # Narrow the LLM-side tensors to this LLM-DP rank's slice. + scale = llm_dp // encoder_dp + llm_dp_idx = llm_grid.get_pg("dp").rank() + slot = llm_dp_idx % scale + + batch_size = batch['input_ids'].shape[0] + slice_size = batch_size // scale + start = slot * slice_size + + for key in ['input_ids', 'labels', 'loss_mask', 'position_ids']: + if key in batch and batch[key] is not None: + batch[key] = batch[key][start : start + slice_size].contiguous() + + output_tensor, loss_mask = model(**batch) + return output_tensor, partial(loss_func, loss_mask) + + +def _set_deterministic_env(): + for k, v in { + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8", + }.items(): + os.environ[k] = v + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + + +def _wire_training_hooks(mimo_model, language_pg, vision_pg): + """Attach no_sync / finalize_grads / grad_scale hooks to a MimoModel. + + The finalize hook implements the heterogeneous-DP grad-scaling story + without touching ``DistributedDataParallel``. Both sub-model configs + set ``calculate_per_token_loss=True``, so both DDPs pure-SUM across + their own DP group (``gradient_scaling_factor=1.0``). After backward + and DDP reduce, every rank's ``main_grad`` holds the un-normalized + full-batch sum of per-token gradients. + + This hook then: + 1. all-reduces the schedule's ``total_num_tokens`` across the LLM + DP group to obtain ``N_global`` (total valid tokens in the global + batch). Since ranks are colocated, every rank now knows + ``N_global``. + 2. Calls ``finalize_model_grads(num_tokens=None)`` per side — runs + the usual DDP grad finish + layernorm/embedding AR work without + letting the built-in divisor path fire. + 3. Calls ``scale_gradients(1/N_global)`` on each side — lands the + true global per-token mean uniformly on encoder and LLM grads. + + Note: encoder has no loss_func (so nothing emits a per-encoder-DP + ``num_tokens`` to feed ``finalize_model_grads``' internal all-reduce). + Doing the all-reduce once ourselves and calling ``scale_gradients`` + directly avoids engineering a fictitious per-encoder-rank count whose + sum happens to equal ``N_global``. + """ + + no_sync_func = build_no_sync_func(mimo_model) + + def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs): + # Schedule passes the per-rank sum-across-microbatches of what the + # loss_func returned. Because loss_func runs only on the LLM side, + # this is the LLM-local token count. + assert num_tokens is not None, ( + "finalize_grads_func expects calculate_per_token_loss=True on the " + "TransformerConfig so the schedule forwards total_num_tokens; got None." 
+ ) + + # Phase 1: lift the all-reduce. After this, every rank (including + # encoder-only replicas) has N_global = total non-padded tokens in + # the global batch. + llm_dp_pg = language_pg.dp_cp if language_pg.dp_cp is not None else language_pg.dp + dist.all_reduce(num_tokens, group=llm_dp_pg, op=dist.ReduceOp.SUM) + n_global = num_tokens.item() + + # Phase 2: per-side DDP finish without built-in num_tokens scaling. + # Forward ``force_all_reduce`` so PP grad-sync semantics (if ever + # exercised here) aren't silently dropped. + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], + num_tokens=None, + pg_collection=language_pg, + force_all_reduce=force_all_reduce, + ) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + finalize_model_grads( + [submodule], + num_tokens=None, + pg_collection=vision_pg, + force_all_reduce=force_all_reduce, + ) + + # Phase 3: uniform divide by N_global. Guard div-by-zero for the + # degenerate fully-masked batch. + if n_global > 0: + inv = 1.0 / n_global + if mimo_model.language_model is not None: + mimo_model.language_model.scale_gradients(inv) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + submodule.scale_gradients(inv) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + +def _generate_and_broadcast_global_batches( + global_mbs, + seq_length, + hidden_size, + vocab_size, + encoder_name, + num_batches, + image_token_id=50257, + mask_pattern="uniform", +): + """Generate global batches on rank 0 and broadcast so every rank sees + identical data. Dist pre-slices per rank; ref consumes the full batch. + + ``mask_pattern``: + * ``"uniform"`` — every sample has the same valid-token count (image + tokens masked, text tokens all valid). Local/global denominators + coincide up to DP-rank partitioning. + * ``"asymmetric"`` — each sample zeros out an additional sample- + dependent number of trailing text tokens, so different samples + (and therefore different DP-rank slices) carry different valid- + token counts. This exercises the num+den global-mean CE path + where the old local-mean recipe would be only approximately + correct. 
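+
+    For instance (hypothetical counts): if one DP slice carries 6 valid
+    tokens with loss sum 3.0 and another carries 2 valid tokens with loss
+    sum 4.0, the true global per-token mean is (3.0 + 4.0) / (6 + 2) = 0.875,
+    while averaging the two local means gives (0.5 + 2.0) / 2 = 1.25; the
+    asymmetric pattern makes that discrepancy observable.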
+ """ + if mask_pattern not in ("uniform", "asymmetric"): + raise ValueError(f"Unknown mask_pattern: {mask_pattern!r}") + + rank = dist.get_rank() + image_seq_length = seq_length // 2 + batches = [] + + for batch_idx in range(num_batches): + if rank == 0: + encoder_hidden_states = torch.randn( + image_seq_length, global_mbs, hidden_size, device='cuda', dtype=torch.float32 + ) + image_tokens = torch.full( + (global_mbs, image_seq_length), image_token_id, dtype=torch.long, device='cuda' + ) + text_tokens = torch.randint( + 1, vocab_size, (global_mbs, seq_length - image_seq_length), device='cuda' + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + else: + encoder_hidden_states = torch.empty( + image_seq_length, global_mbs, hidden_size, device='cuda', dtype=torch.float32 + ) + input_ids = torch.empty(global_mbs, seq_length, dtype=torch.long, device='cuda') + + dist.broadcast(encoder_hidden_states, src=0) + dist.broadcast(input_ids, src=0) + + labels = input_ids.clone() + labels[input_ids == image_token_id] = -100 + loss_mask = torch.ones(global_mbs, seq_length, device='cuda', dtype=torch.float32) + loss_mask[input_ids == image_token_id] = 0.0 + + if mask_pattern == "asymmetric": + # Zero out a sample-dependent trailing run of text tokens so + # each sample ends up with a different valid-token count. + # Counts are deterministic given (batch_idx, sample_idx) so the + # broadcast-on-rank-0 pattern is reproducible on every rank. + text_len = seq_length - image_seq_length + for sample_idx in range(global_mbs): + n_drop = ((batch_idx * 7 + sample_idx * 3) % (text_len - 1)) + 1 + loss_mask[sample_idx, seq_length - n_drop :] = 0.0 + labels[sample_idx, seq_length - n_drop :] = -100 + position_ids = ( + torch.arange(seq_length, device='cuda').unsqueeze(0).expand(global_mbs, -1).clone() + ) + + batches.append( + { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "modality_inputs": { + encoder_name: { + "clip_encoder": { + 'hidden_states': encoder_hidden_states, + 'attention_mask': None, + } + } + }, + } + ) + + return batches + + +def _slice_batch(global_batch, split_dp, split_rank): + """Return the ``split_rank``-th of ``split_dp`` slices along the batch dim.""" + batch_dim = global_batch['input_ids'].shape[0] + slice_size = batch_dim // split_dp + start = split_rank * slice_size + end = start + slice_size + + per_rank = {} + for key in ['input_ids', 'labels', 'loss_mask', 'position_ids']: + per_rank[key] = global_batch[key][start:end].contiguous() + + mod_inputs_new = {} + for mod_name, mod_data in global_batch['modality_inputs'].items(): + mod_inputs_new[mod_name] = {} + for enc_name, enc_data in mod_data.items(): + mod_inputs_new[mod_name][enc_name] = {} + for key, tensor in enc_data.items(): + if tensor is not None and isinstance(tensor, torch.Tensor): + # modality hidden_states is [seq, batch, hidden] — slice dim 1 + mod_inputs_new[mod_name][enc_name][key] = tensor[:, start:end, :].contiguous() + else: + mod_inputs_new[mod_name][enc_name][key] = tensor + per_rank['modality_inputs'] = mod_inputs_new + return per_rank + + +def _slice_global_batch_for_dist(global_batch, encoder_grid, llm_grid): + """Pre-slice a global batch to the per-rank batch that ``forward_step`` expects. + + ``forward_step`` assumes each rank already has its LLM-DP slice + (fan-in) or encoder-DP slice (fan-out); this helper performs that + slicing so both models can consume the same underlying global batch. 
+ When ``enc_dp == llm_dp`` there is no fan-in/fan-out to pre-slice for + (``forward_step`` also skips slicing), and the full batch is returned. + """ + enc_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + + if enc_dp > llm_dp: + return _slice_batch(global_batch, llm_dp, llm_grid.get_pg("dp").rank()) + if llm_dp > enc_dp: + return _slice_batch(global_batch, enc_dp, encoder_grid.get_pg("dp").rank()) + return global_batch + + +def _slice_global_batch_by_dp(global_batch, dp_pg): + """Slice a global batch along the batch dim by ``dp_pg`` rank. + + For the equal-DP reference (``enc_dp == llm_dp``, bridge is identity), + each rank consumes 1/``dp_size`` of the global batch directly. + ``_slice_global_batch_for_dist`` returns the full batch in that case, + so this helper does the DP-rank split explicitly. + """ + dp_size = dist.get_world_size(dp_pg) + if dp_size <= 1: + return global_batch + return _slice_batch(global_batch, dp_size, dist.get_rank(dp_pg)) + + +def _copy_ref_params_to_dist(ref_module, dist_module, ref_tp_group, dist_tp_group): + """Copy ref params into dist, handling differing TP shardings. + + When ref and dist params have the same shape (same TP size and layout + at offset=0), shards align 1:1 and we copy directly. When shapes differ + (different TP sizes), we all-gather ref's shards across ``ref_tp_group`` + to reconstruct the full weight, then slice by the dist ``partition_dim`` + for this rank's dist TP shard. + + Must be called **before** constructing the distributed optimizer, which + clones current param data into fp32 master weights at __init__. + """ + ref_tp_size = dist.get_world_size(ref_tp_group) + dist_tp_rank = dist.get_rank(dist_tp_group) + dist_tp_size = dist.get_world_size(dist_tp_group) + ref_params = dict(ref_module.named_parameters()) + + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + assert name in ref_params, f"Param '{name}' in dist but not in ref" + ref_param = ref_params[name] + partition_dim = getattr(dist_param, 'partition_dim', -1) + + if ref_param.shape == dist_param.shape: + # Same shard size (same TP layout or both replicated). + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + continue + + assert partition_dim >= 0, ( + f"Param '{name}': shapes differ " + f"(ref={tuple(ref_param.shape)}, dist={tuple(dist_param.shape)}) " + f"but partition_dim<0 — cannot reshard a replicated param." + ) + + # Different TP sizes: gather ref shards, then slice for dist. + shards = [torch.empty_like(ref_param.data) for _ in range(ref_tp_size)] + dist.all_gather(shards, ref_param.data.contiguous(), group=ref_tp_group) + full_weight = torch.cat(shards, dim=partition_dim) + dist_slice = torch.tensor_split(full_weight, dist_tp_size, dim=partition_dim)[ + dist_tp_rank + ] + + assert dist_slice.shape == dist_param.shape, ( + f"Param '{name}': sliced.shape={tuple(dist_slice.shape)} != " + f"dist.shape={tuple(dist_param.shape)} " + f"(ref_tp={ref_tp_size}, dist_tp={dist_tp_size}, " + f"partition_dim={partition_dim})" + ) + dist_param.data.copy_(dist_slice.to(dist_param.dtype)) + + +def _global_abs_diff_stats(a, b, pg=None): + """Absolute-diff stats plus reference-tensor magnitude stats, across ``pg``. + + Reports both the abs-diff distribution AND the magnitude of ``b`` (the + reference tensor) so the caller can judge scale: a max abs-diff of 1.0 + is catastrophic for values of O(1), but fine for values of O(100). 
The + relative-diff column (``rel_max = max_diff / ref_max``) gives a quick + percentage read. + + Useful when the per-rank tensors cover different shards — all-reducing + MAX/MIN (and MAX of per-rank p95/p99 as a conservative worst-case) lets + rank 0 print a global view of drift across every shard in ``pg``. Mean + is SUM/world_size, which is the true global mean when every rank holds + the same number of elements (true here — shards have the same shape). + """ + diff = (a.float() - b.float()).abs().flatten() + ref = b.float().abs().flatten() + n = diff.numel() + + if n == 0: + zero = torch.tensor(0.0, device='cuda') + local_min = local_max = local_mean = local_p50 = local_p95 = local_p99 = zero + local_ref_max = local_ref_p95 = local_ref_mean = zero + else: + local_min = diff.min() + local_max = diff.max() + local_mean = diff.mean() + local_p50 = diff.quantile(0.50) + local_p95 = diff.quantile(0.95) + local_p99 = diff.quantile(0.99) + local_ref_max = ref.max() + local_ref_p95 = ref.quantile(0.95) + local_ref_mean = ref.mean() + + world = dist.get_world_size(pg) if dist.is_initialized() else 1 + if world > 1: + g_min = local_min.clone() + g_max = local_max.clone() + g_mean = local_mean.clone() + g_p50 = local_p50.clone() + g_p95 = local_p95.clone() + g_p99 = local_p99.clone() + g_ref_max = local_ref_max.clone() + g_ref_p95 = local_ref_p95.clone() + g_ref_mean = local_ref_mean.clone() + dist.all_reduce(g_min, op=dist.ReduceOp.MIN, group=pg) + dist.all_reduce(g_max, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_mean, op=dist.ReduceOp.SUM, group=pg) + dist.all_reduce(g_p50, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_p95, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_p99, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_ref_max, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_ref_p95, op=dist.ReduceOp.MAX, group=pg) + dist.all_reduce(g_ref_mean, op=dist.ReduceOp.SUM, group=pg) + g_mean = g_mean / world + g_ref_mean = g_ref_mean / world + return { + 'min': g_min.item(), + 'max': g_max.item(), + 'mean': g_mean.item(), + 'p50_worst': g_p50.item(), + 'p95_worst': g_p95.item(), + 'p99_worst': g_p99.item(), + 'ref_max': g_ref_max.item(), + 'ref_p95': g_ref_p95.item(), + 'ref_mean': g_ref_mean.item(), + 'numel_per_rank': n, + 'ranks': world, + } + return { + 'min': local_min.item(), + 'max': local_max.item(), + 'mean': local_mean.item(), + 'p50_worst': local_p50.item(), + 'p95_worst': local_p95.item(), + 'p99_worst': local_p99.item(), + 'ref_max': local_ref_max.item(), + 'ref_p95': local_ref_p95.item(), + 'ref_mean': local_ref_mean.item(), + 'numel_per_rank': n, + 'ranks': 1, + } + + +def _fmt_diff_stats(s): + ref_max = s.get('ref_max', 0.0) + rel_max = (s['max'] / ref_max) if ref_max > 0 else float('inf') + return ( + f"min={s['min']:.2e} p50={s['p50_worst']:.2e} mean={s['mean']:.2e} " + f"p95={s['p95_worst']:.2e} p99={s['p99_worst']:.2e} " + f"max={s['max']:.2e} | ref_max={ref_max:.2e} ref_p95={s.get('ref_p95', 0.0):.2e} " + f"ref_mean={s.get('ref_mean', 0.0):.2e} rel_max={rel_max:.1%} " + f"(n_per_rank={s['numel_per_rank']}, ranks={s['ranks']})" + ) + + +def _print_from_rank0(msg): + if not dist.is_initialized() or dist.get_rank() == 0: + print(msg, flush=True) + + +def _register_logits_capture(mimo_model): + """Forward hook on the LLM ``output_layer``; captures per-microbatch logits. + + The hook runs on every microbatch forward. 
``output`` from + ``ColumnParallelLinear`` is ``(logits, bias)`` with logits shape + ``[s, b, v/tp]`` — this rank's per-DP-slot, per-TP-vocab-shard slice + of the global logits tensor. Cloning so backward doesn't mutate. + + Returns ``(captures, handle)``; caller must ``handle.remove()`` after + the schedule completes. + """ + gpt = unwrap_model(mimo_model.language_model) + captures = [] + + def hook(_module, _inputs, output): + logits = output[0] if isinstance(output, tuple) else output + captures.append(logits.detach().clone()) + + handle = gpt.output_layer.register_forward_hook(hook) + return captures, handle + + +def _register_llm_input_capture(mimo_model): + """Forward pre-hook on the GPT ``decoder``; captures post-bridge hidden states. + + This is the activation entering the transformer block AFTER embedding + (skipped when MIMO passes ``decoder_input``) AND after the bridge has + moved the encoder output into the LLM's TP/DP layout. Shape is + ``[s, b_local, h_full]`` — hidden dim is NOT TP-sharded at this point. + + Comparing dist vs ref at this capture isolates "does the bridge deliver + mathematically equivalent inputs to the LLM?" from downstream LLM TP + forward drift. If this oracle passes but ``llm_logits`` fails, the + divergence is inside the LLM TP forward; if this fails, the bridge + (fan_in/fan_out vs equal) is not equivalent. + """ + gpt = unwrap_model(mimo_model.language_model) + captures = [] + + def pre_hook(_module, args, kwargs): + hidden = kwargs.get('hidden_states', None) + if hidden is None and args: + hidden = args[0] + if hidden is not None: + captures.append(hidden.detach().clone()) + + handle = gpt.decoder.register_forward_pre_hook(pre_hook, with_kwargs=True) + return captures, handle + + +def _gather_bs_dp(local_tensor, llm_dp_pg): + """All-gather ``[s, b, h]`` across LLM DP along the batch dim.""" + dp_size = dist.get_world_size(llm_dp_pg) + if dp_size <= 1: + return local_tensor.contiguous() + contig = local_tensor.contiguous() + shards = [torch.empty_like(contig) for _ in range(dp_size)] + dist.all_gather(shards, contig, group=llm_dp_pg) + return torch.cat(shards, dim=1) + + +def _assert_llm_input_match( + ref_captures, dist_captures, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 +): + """Post-bridge oracle: hidden states entering the LLM decoder match. + + Hidden dim is not TP-sharded at the decoder input, so only DP-gather + across the LLM DP group is needed to reconstruct the full-batch tensor. 
+ """ + assert len(ref_captures) == len(dist_captures), ( + f"Microbatch count mismatch: ref={len(ref_captures)}, " f"dist={len(dist_captures)}" + ) + ref_dp_pg = ref_llm_grid.get_pg("dp") + dist_dp_pg = dist_llm_grid.get_pg("dp") + + mismatches = [] + for mbs_idx, (ref_local, dist_local) in enumerate(zip(ref_captures, dist_captures)): + ref_full = _gather_bs_dp(ref_local, ref_dp_pg) + dist_full = _gather_bs_dp(dist_local, dist_dp_pg) + assert ref_full.shape == dist_full.shape, ( + f"mbs[{mbs_idx}]: gathered llm-input shape mismatch — " + f"ref={tuple(ref_full.shape)}, dist={tuple(dist_full.shape)}" + ) + stats = _global_abs_diff_stats(dist_full, ref_full, pg=dist.group.WORLD) + _print_from_rank0( + f"[llm-input-diff] mbs[{mbs_idx}] shape={tuple(ref_full.shape)} " + f"{_fmt_diff_stats(stats)}" + ) + try: + torch.testing.assert_close(dist_full, ref_full, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((mbs_idx, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" mbs[{i}]: {msg}" for i, msg in mismatches) + raise AssertionError( + f"Rank {rank}: llm-input diverged on {len(mismatches)} microbatch(es):\n" f"{details}" + ) + + +def _gather_logits_full_batch(local_logits, llm_tp_pg, llm_dp_pg): + """All-gather ``[s, b, v/tp]`` across LLM TP (vocab dim) then DP (batch dim). + + Returns ``[s, b * dp_size, v]`` — the full global-batch logits, + identical on every rank of the LLM grid. Used to compare dist vs ref + on the same global slots regardless of how TP/DP slices them. + """ + tp_size = dist.get_world_size(llm_tp_pg) + dp_size = dist.get_world_size(llm_dp_pg) + + vocab_full = local_logits.contiguous() + if tp_size > 1: + shards = [torch.empty_like(vocab_full) for _ in range(tp_size)] + dist.all_gather(shards, vocab_full, group=llm_tp_pg) + vocab_full = torch.cat(shards, dim=-1) + + batch_full = vocab_full.contiguous() + if dp_size > 1: + shards = [torch.empty_like(batch_full) for _ in range(dp_size)] + dist.all_gather(shards, batch_full, group=llm_dp_pg) + batch_full = torch.cat(shards, dim=1) + + return batch_full + + +def _assert_llm_logits_match( + ref_captures, dist_captures, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 +): + """Logits oracle: TP+DP-gathered full-batch logits match microbatch-by-microbatch. + + Dist and ref share the same global batch on every rank (broadcast from + rank 0), and with the HyperCommGrid layout both reconstruct global + batch rows 0..N in the same order after TP+DP all-gather (see + ``_slice_global_batch_*`` helpers for how the slicing lines up). + The only numerical difference between the two gathered logits is + fp32 accumulation order across a different LLM TP shape — hence the + loose ``rtol=atol=1e-2`` default. 
+ """ + assert len(ref_captures) == len(dist_captures), ( + f"Microbatch count mismatch: ref={len(ref_captures)}, " f"dist={len(dist_captures)}" + ) + ref_tp_pg = ref_llm_grid.get_pg("tp") + ref_dp_pg = ref_llm_grid.get_pg("dp") + dist_tp_pg = dist_llm_grid.get_pg("tp") + dist_dp_pg = dist_llm_grid.get_pg("dp") + + mismatches = [] + for mbs_idx, (ref_local, dist_local) in enumerate(zip(ref_captures, dist_captures)): + ref_full = _gather_logits_full_batch(ref_local, ref_tp_pg, ref_dp_pg) + dist_full = _gather_logits_full_batch(dist_local, dist_tp_pg, dist_dp_pg) + assert ref_full.shape == dist_full.shape, ( + f"mbs[{mbs_idx}]: gathered logits shape mismatch — " + f"ref={tuple(ref_full.shape)}, dist={tuple(dist_full.shape)}" + ) + # Gathered full-batch logits are identical on every LLM-grid rank, + # so stats at rank 0 represent the tensor globally — no reduction + # needed across other ranks. + stats = _global_abs_diff_stats(dist_full, ref_full, pg=dist.group.WORLD) + _print_from_rank0( + f"[logits-diff] mbs[{mbs_idx}] shape={tuple(ref_full.shape)} " + f"{_fmt_diff_stats(stats)}" + ) + try: + torch.testing.assert_close(dist_full, ref_full, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((mbs_idx, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" mbs[{i}]: {msg}" for i, msg in mismatches) + raise AssertionError( + f"Rank {rank}: logits diverged on {len(mismatches)} microbatch(es):\n" f"{details}" + ) + + +def _snapshot_first_layer_encoder_grads(mimo_model, encoder_name): + """Clone ``param.main_grad`` for every ``.layers.0.`` encoder param. + + ``main_grad`` holds the post-DDP-reduction gradient (reduced across + encoder DP), populated by the backward pass and consumed by + ``optimizer.step()``. Snapshot between backward and step so the values + aren't yet zeroed. + """ + encoder = mimo_model.modality_submodules[encoder_name].module + snap = {} + for name, param in encoder.named_parameters(): + if '.layers.0.' not in name: + continue + grad = getattr(param, 'main_grad', None) + if grad is None: + continue + snap[name] = grad.detach().clone() + return snap + + +def _assert_first_layer_grads_match(ref_snap, dist_snap, rtol=1e-3, atol=1e-3): + """First-layer encoder grad oracle: shard-to-shard match between ref and dist. + + Ref and dist use identical encoder TP/DP layout, so for every + ``layers.0.*`` encoder parameter their local shards line up 1:1. + Under correct grad scaling both main_grads equal the DP=1 gradient, + so the per-shard values must match within fp32 precision. Tighter + tolerances than the logits oracle are possible because the encoder + forward is identical on both sides — only the LLM TP layout differs, + and that noise enters via the gradient flowing back into the encoder. + """ + assert set(ref_snap.keys()) == set(dist_snap.keys()), ( + f"First-layer param name mismatch — " + f"ref-only: {set(ref_snap) - set(dist_snap)}, " + f"dist-only: {set(dist_snap) - set(ref_snap)}" + ) + mismatches = [] + for name in sorted(ref_snap): + ref_g = ref_snap[name] + dist_g = dist_snap[name] + assert ref_g.shape == dist_g.shape, ( + f"Param '{name}': grad shape {tuple(ref_g.shape)} != " + f"{tuple(dist_g.shape)} — caller must match encoder TP." + ) + # Every rank holds its own TP shard of this param; all-reduce + # across the full world so rank 0 prints the worst-case drift + # across all shards. 
+ stats = _global_abs_diff_stats(dist_g, ref_g, pg=dist.group.WORLD) + _print_from_rank0( + f"[grad-diff] {name} shape={tuple(ref_g.shape)} " f"{_fmt_diff_stats(stats)}" + ) + try: + torch.testing.assert_close(dist_g, ref_g, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n}: {msg}" for n, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} first-layer encoder grad(s) " + f"diverged between dist and ref:\n{details}" + ) + + +def _assert_encoder_weights_match(ref_module, dist_module, rtol=1e-3, atol=1e-3): + """Assert every dist encoder shard matches the ref encoder shard. + + Caller is responsible for ensuring ref and dist have the same encoder TP + layout (same ``enc_tp`` and ``enc_dp``), so each rank's shards line up + 1:1 and can be compared directly. Under correct grad scaling and + identical initial state, one Adam step yields shard-wise equal post-step + weights — modulo fp32 TP accumulation-order drift from the LLM TP + layout differing between the two models. + """ + ref_params = dict(ref_module.named_parameters()) + + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + ref_param = ref_params[name] + assert ref_param.shape == dist_param.shape, ( + f"Param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — caller must match encoder TP." + ) + stats = _global_abs_diff_stats(dist_param.data, ref_param.data, pg=dist.group.WORLD) + _print_from_rank0( + f"[weight-diff] {name} shape={tuple(ref_param.shape)} " f"{_fmt_diff_stats(stats)}" + ) + try: + torch.testing.assert_close(dist_param.data, ref_param.data, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n}: {msg}" for n, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} encoder param(s) diverged between " + f"heterogeneous-DP dist model and equal-DP reference:\n{details}" + ) + + +class _BatchIterator: + """Minimal iterator over a pre-generated list of batches.""" + + def __init__(self, batches): + self.batches = batches + self.idx = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.idx >= len(self.batches): + raise StopIteration + b = self.batches[self.idx] + self.idx += 1 + return b + + +def _run_forward_backward( + mimo_model, + batches, + enc_grid, + llm_grid, + encoder_name, + language_pg, + micro_batch_size, + seq_length, + num_microbatches, +): + """One forward/backward pass through the mimo schedule.""" + return schedule.forward_backward_no_pipelining( + forward_step_func=partial( + forward_step, encoder_grid=enc_grid, llm_grid=llm_grid, encoder_name=encoder_name + ), + data_iterator=_BatchIterator(batches), + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=language_pg, + ) + + +class TestColocatedGradientScalingCorrectness: + """Verify heterogeneous-DP encoder grad scaling against an equal-DP reference. + + The critical invariant: with ``calculate_per_token_loss=True`` on both + sub-model configs, DDP's ``gradient_scaling_factor`` is pinned to + 1.0 and each side's DDP reduction is a pure SUM. 
The custom + ``finalize_grads_func`` then divides both encoder and LLM grads by + ``1/N_global`` (true global valid-token count), so the aggregate + gradient on every encoder shard equals the DP=1 per-token-mean + gradient. The reference uses the same encoder TP/DP as dist but with + ``enc_tp == llm_tp`` and ``enc_dp == llm_dp`` (identity bridge), so + after one Adam step the dist model's sharded weights match the ref + model's sharded weights within fp32 precision. + + If the scaling were wrong (e.g., if either DDP applied its default + ``1/dp_size`` on top of the per-token mean, or if the custom finalize + used the encoder DP group's sum-of-local-counts instead of the + globally lifted ``N_global``), the encoder's reduced grad would be + skewed and post-step weights would diverge — a single optimizer step + is sufficient to detect. + """ + + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def setup_method(self): + # Track MimoModels built by the test so teardown can release any + # ColocatedBridgeCommunicator subgroups before destroy_all_grids. + self._mimo_models = [] + + def teardown_method(self): + torch.use_deterministic_algorithms(False) + for model in self._mimo_models: + model.destroy() + self._mimo_models.clear() + destroy_all_grids() + + @pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("2.3.0"), reason="Requires PyTorch 2.3+" + ) + @pytest.mark.parametrize( + "enc_tp,enc_dp,llm_tp,llm_dp", [(2, 4, 4, 2), (4, 2, 2, 4)], ids=["fan_in", "fan_out"] + ) + @pytest.mark.parametrize( + "mask_pattern", ["uniform", "asymmetric"], ids=["uniform", "asymmetric"] + ) + @pytest.mark.parametrize("num_microbatches", [1, 4], ids=["mbs1", "mbs4"]) + def test_dist_matches_dp1_reference_post_step_weights( + self, enc_tp, enc_dp, llm_tp, llm_dp, mask_pattern, num_microbatches + ): + """Heterogeneous-DP dist post-step encoder weights match equal-DP reference. + + Builds two MimoModels on every rank: + + * Dist: the heterogeneous TP/DP config under test, with + ``calculate_per_token_loss=True`` + custom finalize hook that + pure-SUMs DDP and externally divides by ``N_global``. + * Ref: equal-DP uniform with ``enc_tp=dist_enc_tp``, + ``enc_dp=dist_enc_dp``, ``llm_tp=dist_enc_tp``, + ``llm_dp=dist_enc_dp`` — bridge is + ``BridgeDirection.EQUAL`` (identity passthrough), and the + encoder TP sharding matches dist's exactly so shards line up + 1:1 for comparison. + + Both models run the same finalize wiring; both DDPs pure-SUM + across their own DP group, then divide uniformly by ``N_global``. + LLM TP differs between the two models, which introduces fp32 TP + accumulation-order drift in the gradient flowing back to the + encoder but does not change the per-token-mean invariant that the + post-step encoder oracle checks. + + Reference weights are copied into the distributed model so both + start from identical state. One Adam step later, the dist shards + should match the ref shards within fp32 precision. 
+ """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + _set_deterministic_env() + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + encoder_name = "images" + hidden_size, seq_length, vocab_size = 256, 64, 1000 + micro_batch_size = 2 + + # Global batch spans the larger DP side; dist pre-slices per rank + # before forward_step (which further slices encoder/LLM side). + global_batch_size = micro_batch_size * max(enc_dp, llm_dp) + + # Grids: dist is heterogeneous; ref is equal-DP uniform matching + # dist's encoder so the bridge is identity and encoder shards + # align 1:1 for direct comparison. + dist_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + dist_llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=1, dp=llm_dp) + ref_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + ref_llm_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + create_all_embedding_groups([dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid]) + + # Both sub-model TransformerConfigs set calculate_per_token_loss=True + # (via per_token_loss=True on get_mimo_model), which pins DDP's + # gradient_scaling_factor to 1.0 — pure SUM across DP on both sides. + # Under the 3-tuple loss_func + custom finalize_grads_func in + # _wire_training_hooks, grads are divided uniformly by N_global, + # which is the true global per-token mean on every shard. + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True + ) + + # Build dist first (heterogeneous TP/DP). + torch.manual_seed(12345) + dist_mimo, _, _, dist_language_pg, dist_vision_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + hidden_size=hidden_size, + num_layers=2, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, + ) + dist_mimo.model_type = ModelType.encoder_or_decoder + self._mimo_models.append(dist_mimo) + + # Reference with equal-DP uniform (enc_tp == llm_tp, enc_dp == llm_dp). + torch.manual_seed(12345) + ref_mimo, _, _, ref_language_pg, ref_vision_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=ref_enc_grid, + llm_grid=ref_llm_grid, + hidden_size=hidden_size, + num_layers=2, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, + ) + ref_mimo.model_type = ModelType.encoder_or_decoder + self._mimo_models.append(ref_mimo) + + # Force identical initial state: encoder shards already match + # (same TP layout), so the helper copies shard-to-shard. LLM + # shards don't match (ref_llm_tp=enc_tp, dist_llm_tp=llm_tp), so + # the helper all-gathers ref's shards across ref's TP group and + # re-slices for dist's TP group. 
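# [Editor's note — illustrative sketch, not part of the patch.]
# The comment above describes the re-sharding _copy_ref_params_to_dist has to
# perform for the LLM weights (its body is not shown in this hunk). A minimal
# version of that step, reusing this module's torch / torch.distributed-as-dist
# imports and assuming the weight is sharded along dim 0 and evenly divisible
# by both TP sizes (the helper name, the dim-0 assumption, and the signature
# are illustrative only):
def _reshard_tp_param_sketch(ref_shard, ref_tp_pg, dist_tp_pg, dim=0):
    # 1) Reassemble the full weight from the reference model's TP shards.
    ref_tp = dist.get_world_size(ref_tp_pg)
    gathered = [torch.empty_like(ref_shard) for _ in range(ref_tp)]
    dist.all_gather(gathered, ref_shard.contiguous(), group=ref_tp_pg)
    full = torch.cat(gathered, dim=dim)
    # 2) Slice out the shard this rank owns under the dist model's TP layout.
    dist_tp = dist.get_world_size(dist_tp_pg)
    tp_rank = dist.get_rank(dist_tp_pg)
    return full.chunk(dist_tp, dim=dim)[tp_rank].clone()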
+ _copy_ref_params_to_dist( + ref_mimo.modality_submodules[encoder_name].module, + dist_mimo.modality_submodules[encoder_name].module, + ref_enc_grid.get_pg("tp"), + dist_enc_grid.get_pg("tp"), + ) + _copy_ref_params_to_dist( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + ) + + _wire_training_hooks(dist_mimo, dist_language_pg, dist_vision_pg) + _wire_training_hooks(ref_mimo, ref_language_pg, ref_vision_pg) + + # Distributed optimizers snapshot current param.data into fp32 master + # weights at __init__, so both must be built AFTER the ref-to-dist + # param copy above. + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=False, + use_distributed_optimizer=True, + ) + dist_optimizer = get_mimo_optimizer(dist_mimo, opt_config) + ref_optimizer = get_mimo_optimizer(ref_mimo, opt_config) + + # Data: one deterministic global batch, identical on every rank. + torch.manual_seed(99999) + global_batches = _generate_and_broadcast_global_batches( + global_mbs=global_batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + vocab_size=vocab_size, + encoder_name=encoder_name, + num_batches=num_microbatches, + mask_pattern=mask_pattern, + ) + dist_batches = [ + _slice_global_batch_for_dist(b, dist_enc_grid, dist_llm_grid) for b in global_batches + ] + # Ref is uniform (enc_dp == llm_dp), so _slice_global_batch_for_dist + # returns the full batch; slice explicitly by enc_dp so each rank + # sees the same per-rank batch size as dist's encoder does. + ref_batches = [ + _slice_global_batch_by_dp(b, ref_enc_grid.get_pg("dp")) for b in global_batches + ] + ref_per_rank_batch_size = global_batch_size // enc_dp + + # Logits capture: hook fires on every microbatch forward. + # Registered before forward/backward, removed right after so the + # hook doesn't leak across the second model's run. + dist_logits, dist_logits_hook = _register_logits_capture(dist_mimo) + ref_logits, ref_logits_hook = _register_logits_capture(ref_mimo) + dist_llm_input, dist_input_hook = _register_llm_input_capture(dist_mimo) + ref_llm_input, ref_input_hook = _register_llm_input_capture(ref_mimo) + + try: + # One optimizer step on dist (heterogeneous forward_step slicing). + dist_optimizer.zero_grad() + _run_forward_backward( + mimo_model=dist_mimo, + batches=dist_batches, + enc_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + encoder_name=encoder_name, + language_pg=dist_language_pg, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + num_microbatches=num_microbatches, + ) + # Snapshot encoder first-layer grads AFTER backward and BEFORE + # optimizer.step() consumes/zeros the grad buffer. + dist_first_layer_grads = _snapshot_first_layer_encoder_grads(dist_mimo, encoder_name) + dist_success, dist_grad_norm, _ = dist_optimizer.step() + assert dist_success, "Dist optimizer step failed" + assert dist_grad_norm is not None and dist_grad_norm > 0, ( + f"Dist grad_norm={dist_grad_norm} — encoder grads may have been " + "silently zeroed by wrong scaling" + ) + + # One optimizer step on ref (enc_dp == llm_dp → forward_step skips slicing). 
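# [Editor's note — illustrative sketch, not part of the patch; the ref
# optimizer step announced by the comment above continues right after this note.]
# The ref_batches above come from _slice_global_batch_by_dp, whose body is not
# shown in this hunk. Assuming the global batch is a dict of tensors stacked
# along dim 0 and identical on every rank, the per-DP-rank slice it has to
# return could look like this (helper name and dict layout are assumptions):
def _slice_batch_by_dp_sketch(global_batch, dp_pg):
    dp_size = dist.get_world_size(dp_pg)
    dp_rank = dist.get_rank(dp_pg)
    per_rank = next(iter(global_batch.values())).shape[0] // dp_size
    start = dp_rank * per_rank
    return {k: v[start:start + per_rank] for k, v in global_batch.items()}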
+ ref_optimizer.zero_grad() + _run_forward_backward( + mimo_model=ref_mimo, + batches=ref_batches, + enc_grid=ref_enc_grid, + llm_grid=ref_llm_grid, + encoder_name=encoder_name, + language_pg=ref_language_pg, + micro_batch_size=ref_per_rank_batch_size, + seq_length=seq_length, + num_microbatches=num_microbatches, + ) + ref_first_layer_grads = _snapshot_first_layer_encoder_grads(ref_mimo, encoder_name) + ref_success, ref_grad_norm, _ = ref_optimizer.step() + assert ref_success, "Ref optimizer step failed" + assert ref_grad_norm is not None and ref_grad_norm > 0, f"Ref grad_norm={ref_grad_norm}" + finally: + dist_logits_hook.remove() + ref_logits_hook.remove() + dist_input_hook.remove() + ref_input_hook.remove() + + # Run all three oracles regardless of individual failures so the + # diff-stats print covers every layer. Order: encoder weights / + # first-layer grads first (tightest — same encoder TP/DP layout + # → shards align 1:1), then LLM logits last (loosest — different + # LLM TP layout drives fp32 accumulation drift). Each oracle + # printed its own min/mean/p95/p99/max before its assertion ran, + # so the user sees the full drift distribution for every test. + failures = [] + + try: + _assert_encoder_weights_match( + ref_mimo.modality_submodules[encoder_name].module, + dist_mimo.modality_submodules[encoder_name].module, + rtol=1e-3, + atol=1e-3, + ) + except AssertionError as e: + failures.append(('encoder_weights', str(e))) + + try: + _assert_first_layer_grads_match( + ref_first_layer_grads, dist_first_layer_grads, rtol=1e-3, atol=1e-3 + ) + except AssertionError as e: + failures.append(('first_layer_grads', str(e))) + + try: + _assert_llm_input_match( + ref_llm_input, dist_llm_input, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 + ) + except AssertionError as e: + failures.append(('llm_input', str(e))) + + try: + _assert_llm_logits_match( + ref_logits, dist_logits, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 + ) + except AssertionError as e: + failures.append(('llm_logits', str(e))) + + if failures: + summary = "\n\n".join(f"== {oracle} ==\n{msg}" for oracle, msg in failures) + raise AssertionError(f"{len(failures)} oracle(s) failed:\n{summary}") diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index e1c4b6e89bf..0ef62ff570f 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -528,15 +528,13 @@ def test_grid_validation_rejects_mismatched_keys(self): self.hidden_size, self.img_h, self.img_w, self.patch_dim ) - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec}, - special_token_ids={"images": 50257}, - module_to_grid_map={MIMO_LANGUAGE_MODULE_KEY: MockGrid()}, - ) - with pytest.raises(ValueError, match="module_to_grid_map keys must match"): - MimoModel(mimo_config) + MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={MIMO_LANGUAGE_MODULE_KEY: MockGrid()}, + ) def test_role_determination(self): """Test role correctly identifies modules and stage positions.""" @@ -550,7 +548,7 @@ def test_role_determination(self): self.patch_dim, {"images": 50257}, ) - assert model_no_grid.role.mode == ModuleLayout.UNIFIED + assert model_no_grid.role.mode == ModuleLayout.COLOCATED assert model_no_grid.role.has_language_module is True assert 
model_no_grid.role.has_modality_modules is True @@ -564,12 +562,15 @@ def test_role_determination(self): assert model_language.role.has_modality_modules is False assert model_language.role.has_language_module is True - # Stage info with PP + # Stage info with PP. language_in_grid=False so encoder and language + # grids have distinct rank_offsets and role.build dispatches to + # _from_grid_map (rather than collapsing to the COLOCATED path). model_pp = MimoModel( - self._make_config(encoder_in_grid=True, language_in_grid=True, pp_rank=1, pp_size=3) + self._make_config(encoder_in_grid=True, language_in_grid=False, pp_rank=1, pp_size=3) ) assert model_pp.role.is_first_stage("images") is False assert model_pp.role.is_last_stage("images") is False + assert model_pp.colocated_comms == {} def test_selective_init_encoder_only(self): """Test encoder-only rank initializes encoder but not language model.""" diff --git a/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py b/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py index a6afabe8817..9441996ec49 100644 --- a/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py +++ b/tests/unit_tests/pipeline_parallel/test_pipeline_layout.py @@ -135,6 +135,7 @@ def create_args(): args.vocab_file = None args.add_position_embedding = False args.ckpt_assume_constant_structure = False + args.ckpt_load_validate_sharding_integrity = True args.dist_ckpt_strictness = "assume_ok_unexpected" args.fp16 = False args.bf16 = True @@ -148,6 +149,7 @@ def create_args(): args.distrib_optim_fully_reshardable_mem_efficient = False args.phase_transition_iterations = None args.async_strategy = "nvrx" + args.verify_integrity = False yield args diff --git a/tests/unit_tests/post_training/test_modelopt_model_builder.py b/tests/unit_tests/post_training/test_modelopt_model_builder.py index b489d659ec4..2ab8ebfe947 100644 --- a/tests/unit_tests/post_training/test_modelopt_model_builder.py +++ b/tests/unit_tests/post_training/test_modelopt_model_builder.py @@ -39,7 +39,7 @@ def test_model_provider_switches_to_modelopt_builder(monkeypatch): monkeypatch.setattr(mp, "has_nvidia_modelopt", True) monkeypatch.setattr(mp, "get_args", lambda: args) monkeypatch.setattr( - mp, "modelopt_gpt_mamba_builder", _sentinel_builder(modelopt_result, modelopt_calls) + mp, "modelopt_gpt_hybrid_builder", _sentinel_builder(modelopt_result, modelopt_calls) ) # original_builder should be ignored when ModelOpt is enabled. diff --git a/tests/unit_tests/resharding/test_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py index 70d81d97829..105e5e3af50 100644 --- a/tests/unit_tests/resharding/test_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved. 
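# [Editor's note — illustrative sketch, not part of the patch.]
# The regression tests added below in this file check that a persistent
# *buffer* (router.expert_bias) travels with the weights during refit. The
# failure mode they guard against is easy to see in isolation: a transfer loop
# that only walks named_parameters() never visits registered buffers. Minimal
# standalone sketch (module and names are made up for illustration):
import torch
from torch import nn

class _TinyRouterSketch(nn.Module):
    def __init__(self, num_experts=4):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(num_experts, 8))
        # Persistent buffer: lives in state_dict, but is absent from named_parameters().
        self.register_buffer("expert_bias", torch.zeros(num_experts))

_router = _TinyRouterSketch()
assert "expert_bias" not in dict(_router.named_parameters())
assert "expert_bias" in dict(_router.named_buffers())  # a refit loop must cover these too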
import copy import gc import os @@ -37,8 +37,8 @@ try: import mamba_ssm # noqa: F401 - from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec - from megatron.core.models.mamba.mamba_model import MambaModel + from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec + from megatron.core.models.hybrid.hybrid_model import HybridModel has_mamba_deps = True except Exception: @@ -203,9 +203,9 @@ def _build_mamba( parallel_output: bool = True, ): pre_process, post_process = _pp_flags(pg_collection) - model = MambaModel( + model = HybridModel( config=config, - mamba_stack_spec=mamba_stack_spec, + hybrid_stack_spec=hybrid_stack_spec, vocab_size=vocab_size, max_sequence_length=seq_len, hybrid_layer_pattern=hybrid_layer_pattern, @@ -445,6 +445,337 @@ def test_swap_gpt_parametrized( torch.cuda.empty_cache() +@pytest.mark.parametrize( + "refit_backend", + [ + pytest.param( + "nvshmem", + marks=pytest.mark.skipif( + not has_nvshmem, + reason="nvshmem.core is not available (NVSHMEM Python bindings not installed)", + ), + ), + "nccl", + "gloo", + ], +) +@pytest.mark.parametrize( + "src_tp,src_ep,dst_tp,dst_ep", + [ + (2, 2, 1, 1), # TP2,EP2 -> TP1,EP1 (cross-cluster shape) + (1, 1, 2, 2), # TP1,EP1 -> TP2,EP2 (reverse) + (1, 2, 2, 2), # TP=1->TP=2 with EP unchanged + ], +) +def test_router_expert_bias_refit( + refit_backend: str, src_tp: int, src_ep: int, dst_tp: int, dst_ep: int +): + """Regression test: MoE router ``expert_bias`` (a *persistent buffer*, not a + Parameter) must travel with weights during refit/resharding. + + This was the root cause of stale routing on the inference model when + refit was used to re-shard a Nemotron-style MoE+Mamba checkpoint across + different TP/EP layouts: the router buffer carried aux-loss-free load + balancing state on the trainer but stayed at zero on the inference model + because ``swap_model_weights`` only enumerated ``named_parameters``. + """ + Utils.initialize_model_parallel( + tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=1 + ) + world = dist.get_world_size() + if (world % (src_tp * src_ep) != 0) or (world % (dst_tp * dst_ep) != 0): + Utils.destroy_model_parallel() + pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_ep and dst_tp*dst_ep") + + try: + import transformer_engine + except Exception: + Utils.destroy_model_parallel() + pytest.skip("Transformer Engine not available") + + model_parallel_cuda_manual_seed(1234) + torch.manual_seed(1234) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + num_experts = 4 + cfg = TransformerConfig( + num_layers=2, + hidden_size=32, + num_attention_heads=8, + num_query_groups=4, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + hidden_dropout=0.0, + attention_dropout=0.0, + num_moe_experts=num_experts, + moe_ffn_hidden_size=64, + moe_grouped_gemm=True, + add_bias_linear=False, + moe_router_dtype="fp64", + moe_token_dispatcher_type="alltoall", + # The flag this regression covers: routers register a persistent + # ``expert_bias`` buffer used for aux-loss-free load balancing. 
+ moe_router_enable_expert_bias=True, + moe_router_score_function="sigmoid", + ) + src_cfg = copy.deepcopy(cfg) + dst_cfg = copy.deepcopy(cfg) + src_cfg.expert_model_parallel_size = src_ep + dst_cfg.expert_model_parallel_size = dst_ep + + src_pgs = _build_pg_collection(tp_size=src_tp, pp_size=1, ep_size=src_ep) + dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=1, ep_size=dst_ep) + + src_model = ( + _build_gpt( + src_cfg, + vocab_size=128, + seq_len=8, + pg_collection=src_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) + dst_model = ( + _build_gpt( + dst_cfg, + vocab_size=128, + seq_len=8, + pg_collection=dst_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) + + # Stamp a recognizable pattern into every router.expert_bias on src, + # promoting it to fp32 to mirror what _maintain_float32_expert_bias does on + # the trainer's first forward. dst is left at its bf16/init state so the + # refit must transfer the value AND harmonize the dtype. + test_pattern = torch.arange(num_experts, dtype=torch.float32, device=device) + 0.25 + src_buffers: dict[str, torch.Tensor] = {} + for name, mod in src_model.named_modules(): + bias = getattr(mod, "expert_bias", None) + if isinstance(bias, torch.Tensor): + with torch.no_grad(): + if bias.dtype != torch.float32: + fp32_bias = bias.detach().to(torch.float32) + fp32_bias.copy_(test_pattern) + mod._buffers["expert_bias"] = fp32_bias + else: + bias.copy_(test_pattern) + src_buffers[f"{name}.expert_bias"] = mod._buffers["expert_bias"] + + # Sanity: dst's buffers should NOT yet match src (they're zero-init). + pre_swap_match = all( + torch.allclose( + dict(dst_model.named_buffers()).get(n, torch.zeros_like(b)).float(), + b.float(), + atol=1e-5, + ) + for n, b in src_buffers.items() + ) + assert not pre_swap_match, "test setup wrong: dst already matches src before refit" + + swap_model_weights([src_model], [dst_model], refit_method=refit_backend) + torch.cuda.synchronize() + + # Verify each router.expert_bias on dst now matches src's stamped pattern. + dst_named_buffers = dict(dst_model.named_buffers()) + mismatches = [] + for name, src_buf in src_buffers.items(): + dst_buf = dst_named_buffers.get(name) + assert dst_buf is not None, f"dst missing buffer {name}" + if not torch.allclose(dst_buf.float(), src_buf.float(), atol=1e-5): + mismatches.append((name, (dst_buf - src_buf).abs().max().item())) + assert not mismatches, ( + f"router.expert_bias not transferred during refit " + f"(src_tp={src_tp}, src_ep={src_ep} -> dst_tp={dst_tp}, dst_ep={dst_ep}, " + f"backend={refit_backend}): {mismatches}" + ) + dist.barrier() + + del src_model, dst_model + clear_all_caches() + _destroy_pg_collection(src_pgs) + _destroy_pg_collection(dst_pgs) + Utils.destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + +@pytest.mark.parametrize( + "refit_backend", + [ + pytest.param( + "nvshmem", + marks=pytest.mark.skipif( + not has_nvshmem, + reason="nvshmem.core is not available (NVSHMEM Python bindings not installed)", + ), + ), + "nccl", + "gloo", + ], +) +def test_router_expert_bias_refit_non_collocated(refit_backend: str): + """Non-collocated counterpart of ``test_router_expert_bias_refit``. + + Splits the world into disjoint src and dst rank sets so dst-only ranks + have no local view of the src model. 
Exercises the ``all_gather_object``- + based dtype harmonization path: src ranks hold ``expert_bias`` in fp32 and + dst ranks in bf16, and the only way dst can learn the expected dtype is + via the gathered map. + """ + Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + world = dist.get_world_size() + src_tp, src_ep, dst_tp, dst_ep = 2, 1, 2, 1 + src_world = src_tp * src_ep + dst_world = dst_tp * dst_ep + if world < src_world + dst_world: + Utils.destroy_model_parallel() + pytest.skip(f"Non-collocated test requires WORLD_SIZE >= {src_world + dst_world}") + + try: + import transformer_engine # noqa: F401 + except Exception: + Utils.destroy_model_parallel() + pytest.skip("Transformer Engine not available") + + from megatron.rl.parallel_utils import build_inference_pg_collection + + rank = dist.get_rank() + is_src = rank < src_world + is_dst = src_world <= rank < src_world + dst_world + + model_parallel_cuda_manual_seed(1234) + torch.manual_seed(1234) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + num_experts = 4 + cfg = TransformerConfig( + num_layers=2, + hidden_size=32, + num_attention_heads=8, + num_query_groups=4, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + hidden_dropout=0.0, + attention_dropout=0.0, + num_moe_experts=num_experts, + moe_ffn_hidden_size=64, + moe_grouped_gemm=True, + add_bias_linear=False, + moe_router_dtype="fp64", + moe_token_dispatcher_type="alltoall", + moe_router_enable_expert_bias=True, + moe_router_score_function="sigmoid", + ) + src_cfg = copy.deepcopy(cfg) + dst_cfg = copy.deepcopy(cfg) + src_cfg.expert_model_parallel_size = src_ep + dst_cfg.expert_model_parallel_size = dst_ep + + # Both pg collections are built collectively on every rank (dist.new_group + # requires it) but each one's groups only contain the ranks for that side. + src_pgs = build_inference_pg_collection( + world_size=src_world, tp_size=src_tp, ep_size=src_ep, rank_offset=0 + ) + dst_pgs = build_inference_pg_collection( + world_size=dst_world, tp_size=dst_tp, ep_size=dst_ep, rank_offset=src_world + ) + + src_model = None + dst_model = None + if is_src: + src_model = ( + _build_gpt( + src_cfg, + vocab_size=128, + seq_len=8, + pg_collection=src_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) + elif is_dst: + dst_model = ( + _build_gpt( + dst_cfg, + vocab_size=128, + seq_len=8, + pg_collection=dst_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) + + test_pattern = torch.arange(num_experts, dtype=torch.float32, device=device) + 0.25 + if is_src and src_model is not None: + for name, mod in src_model.named_modules(): + bias = getattr(mod, "expert_bias", None) + if isinstance(bias, torch.Tensor): + with torch.no_grad(): + # Promote to fp32 to mirror what _maintain_float32_expert_bias + # does on the trainer's first forward, while dst remains at + # its bf16/fp32-from-init state. This forces the dtype + # harmonization path to do work for non-collocated transfer + # (dst-only ranks have no local view of src's dtype). 
+ if bias.dtype != torch.float32: + fp32_bias = bias.detach().to(torch.float32) + fp32_bias.copy_(test_pattern) + mod._buffers["expert_bias"] = fp32_bias + else: + bias.copy_(test_pattern) + + dist.barrier() + + swap_model_weights( + [src_model] if src_model is not None else None, + [dst_model] if dst_model is not None else None, + refit_method=refit_backend, + ) + torch.cuda.synchronize() + dist.barrier() + + if is_dst and dst_model is not None: + dst_named_buffers = dict(dst_model.named_buffers()) + mismatches = [] + for name, dst_buf in dst_named_buffers.items(): + if not name.endswith("expert_bias"): + continue + # Replicated buffer: expected value is the stamped test_pattern. + # dst_buf should also be fp32 thanks to dtype harmonization. + if dst_buf.dtype != torch.float32: + mismatches.append((name, f"dtype not harmonized: {dst_buf.dtype}")) + continue + if not torch.allclose(dst_buf, test_pattern, atol=1e-5): + mismatches.append((name, (dst_buf - test_pattern).abs().max().item())) + assert not mismatches, ( + f"Non-collocated refit did not transfer router.expert_bias correctly " + f"(backend={refit_backend}): {mismatches}" + ) + + dist.barrier() + if src_model is not None: + del src_model + if dst_model is not None: + del dst_model + clear_all_caches() + _destroy_pg_collection(src_pgs) + _destroy_pg_collection(dst_pgs) + Utils.destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + @pytest.mark.parametrize( "refit_backend", [ diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py index 3015f5ed989..624cd8c048b 100644 --- a/tests/unit_tests/ssm/test_causal_conv1d_triton.py +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -221,6 +221,50 @@ def test_intermediate_state(self, width): conv_state_ref[:, :, -1] = x[:, s, :] torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5) + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_state_len_eq_width_fast_path(self, width): + """Cover the ``state_len == WIDTH`` fast path (the common Mamba + configuration where d_conv == width). + + The other tests use ``state_len = 8`` so they always fall through to + the explicit shift loop. Here ``state_len = width`` exercises the + register-resident shift and the matching ``HAS_INT_STATE`` branch. + """ + torch.manual_seed(42) + B, seq_len, D = 2, 4, 64 + state_len = width + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_initial = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32) + + conv_state_triton = conv_state_initial.clone() + conv_state_ref = conv_state_initial.clone() + + result = causal_conv1d_update( + x, + conv_state_triton, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + intermediate_conv_states=int_states, + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=None, silu_activation=False + ) + + # Output and final state match the reference. + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5) + + # Per-step intermediate states match a manual replay. 
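# [Editor's note — illustrative sketch, not part of the patch; the manual
# replay loop announced by the comment above continues right after this note.]
# Conceptually, the single-step update these tests exercise is a causal
# depthwise conv evaluated one token at a time against a rolling state buffer
# of length state_len >= width. A naive float32 version without bias or
# activation (name and exact tensor layout are illustrative):
def _naive_conv_step_sketch(x_t, conv_state, weight):
    # x_t: [B, D], conv_state: [B, D, state_len] (mutated in place), weight: [D, width]
    conv_state[:, :, :-1] = conv_state[:, :, 1:].clone()  # shift the window left by one step
    conv_state[:, :, -1] = x_t                            # append the new token
    width = weight.shape[1]
    # depthwise dot product over the last `width` entries of the state
    return (conv_state[:, :, -width:] * weight.unsqueeze(0)).sum(dim=-1)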
+ replay_state = conv_state_initial.clone() + for s in range(seq_len): + replay_state[:, :, :-1] = replay_state[:, :, 1:].clone() + replay_state[:, :, -1] = x[:, s, :] + torch.testing.assert_close(int_states[:, s, :, :], replay_state, atol=1e-5, rtol=1e-5) + def test_intermediate_state_with_indices(self): """Test intermediate states work correctly with conv_state_indices mapping.""" torch.manual_seed(42) diff --git a/tests/unit_tests/ssm/test_split_tensor_factory.py b/tests/unit_tests/ssm/test_split_tensor_factory.py new file mode 100644 index 00000000000..abb668e16a8 --- /dev/null +++ b/tests/unit_tests/ssm/test_split_tensor_factory.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import logging +from unittest import mock + +import pytest +import torch + +from megatron.core.dist_checkpointing import ShardedTensor +from megatron.core.ssm.gated_delta_net import ( + _split_tensor_factory as gated_delta_split_tensor_factory, +) +from megatron.core.ssm.mamba_mixer import _split_tensor_factory as mamba_split_tensor_factory + + +@pytest.mark.parametrize( + "factory_fn", + [gated_delta_split_tensor_factory, mamba_split_tensor_factory], + ids=["gated_delta_net", "mamba_mixer"], +) +@pytest.mark.internal +def test_ssm_split_tensor_factory_oom_is_handled(factory_fn, caplog): + original_sh_ten = ShardedTensor.from_rank_offsets( + 'a', torch.arange(12, dtype=torch.float32).reshape(6, 2), (0, 0, 1) + ) + factory = factory_fn(original_sh_ten, [2, 4], ['x', 'B'], 0) + sub_state_dict = [torch.ones((2, 2), dtype=torch.float32), torch.full((4, 2), 2.0)] + + real_cat = torch.cat + call_count = 0 + + def fake_cat(tensors, *args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise torch.cuda.OutOfMemoryError('mock oom') + return real_cat(tensors, *args, **kwargs) + + with ( + mock.patch('torch.cat', side_effect=fake_cat), + mock.patch('gc.collect') as collect_mock, + mock.patch('torch.cuda.empty_cache') as empty_cache_mock, + caplog.at_level(logging.WARNING), + ): + merged = factory.merge_fn(sub_state_dict) + + assert torch.equal(merged, real_cat(sub_state_dict)) + assert merged.device.type == 'cpu' + assert call_count == 2 + collect_mock.assert_called_once() + empty_cache_mock.assert_called_once() + assert "CUDA OutOfMemoryError encountered during tensors merging" in caplog.text diff --git a/tests/unit_tests/test_activation_logging.py b/tests/unit_tests/test_activation_logging.py index cc7634d6558..afd6bc2f795 100644 --- a/tests/unit_tests/test_activation_logging.py +++ b/tests/unit_tests/test_activation_logging.py @@ -25,12 +25,21 @@ def simple_model(): class TestMakeTpeHook: """Tests for _make_tpe_hook regex layer extraction.""" - def test_extracts_layer_number(self, logger): + def test_extracts_decoder_layer_number(self, logger): hook = logger._make_tpe_hook("chunk0", "decoder.layers.3.mlp.experts.linear_fc1") assert hook is not None fake_tpe = [128, 64, 96, 80] hook(None, (torch.zeros(1), fake_tpe), {}, torch.zeros(1)) - assert logger._tpe_records["3"] == [fake_tpe] + assert logger._decoder_tpe_records[3] == [fake_tpe] + + def test_extracts_mtp_layer_number(self, logger): + hook = logger._make_tpe_hook( + "chunk0", "mtp.layers.0.mtp_model_layer.layers.1.mlp.experts.linear_fc1" + ) + assert hook is not None + fake_tpe = [50, 50] + hook(None, (torch.zeros(1), fake_tpe), {}, torch.zeros(1)) + assert logger._mtp_tpe_records[(0, 1)] == [fake_tpe] def test_returns_none_for_non_matching_name(self, logger, caplog): with 
caplog.at_level(logging.WARNING): @@ -43,9 +52,10 @@ class TestSaveTpe: """Tests for save_tpe JSONL output.""" def test_creates_jsonl(self, tmp_path, logger): - logger._tpe_records["3"].append([128, 64, 96, 80]) - logger._tpe_records["3"].append([100, 90, 110, 70]) - logger._tpe_records["7"].append([200, 200]) + logger._decoder_tpe_records[3].append([10, 20]) + logger._decoder_tpe_records[3].append([30, 40]) + logger._decoder_tpe_records[7].append([50, 60]) + logger._mtp_tpe_records[(0, 1)].append([70, 80]) logger.save_tpe(iteration=100) @@ -55,15 +65,16 @@ def test_creates_jsonl(self, tmp_path, logger): records = [json.loads(line) for line in filepath.read_text().strip().split("\n")] assert records == [ - {"iter": 100, "layer": 3, "tpe": [[128, 64, 96, 80], [100, 90, 110, 70]]}, - {"iter": 100, "layer": 7, "tpe": [[200, 200]]}, + {"iter": 100, "block": "decoder", "layer": 3, "tpe": [[10, 20], [30, 40]]}, + {"iter": 100, "block": "decoder", "layer": 7, "tpe": [[50, 60]]}, + {"iter": 100, "block": "mtp", "mtp_idx": 0, "layer": 1, "tpe": [[70, 80]]}, ] def test_appends_across_calls(self, tmp_path, logger): - logger._tpe_records["0"].append([10, 20]) + logger._decoder_tpe_records[0].append([10, 20]) logger.save_tpe(iteration=100) - logger._tpe_records["0"].append([30, 40]) + logger._decoder_tpe_records[0].append([30, 40]) logger.save_tpe(iteration=200) rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py index e5744c3b074..f698e879951 100644 --- a/tests/unit_tests/test_argument_utils.py +++ b/tests/unit_tests/test_argument_utils.py @@ -1,13 +1,21 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import signal -from argparse import ArgumentError, ArgumentParser +from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import dataclass, field from typing import Callable, Literal, Optional, Union +from unittest.mock import MagicMock, patch import pytest -from megatron.training.argument_utils import ArgumentGroupFactory, TypeInferenceError +from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from megatron.training.argument_utils import ( + ArgumentGroupFactory, + TypeInferenceError, + pretrain_cfg_container_from_args, +) +from megatron.training.config import PretrainConfigContainer @dataclass @@ -641,3 +649,223 @@ def test_handled_unsupported_unions(self): with pytest.raises(ArgumentError, match="invalid choice"): args = parser.parse_args(['--unsupported-with-metadata', 'baz']) + + +# --------------------------------------------------------------------------- +# Tests for pretrain_cfg_container_from_args +# --------------------------------------------------------------------------- + + +def _make_args(**overrides): + """Build the minimum Namespace required by pretrain_cfg_container_from_args. + + These fields have non-trivial CLI→config name mappings or boolean inversions + that pretrain_cfg_container_from_args handles explicitly. 
+ """ + defaults = { + # CheckpointConfig boolean inversions (legacy-style --no-* flags) + "no_save_optim": False, + "no_save_rng": False, + "no_load_optim": False, + "no_load_rng": False, + # CheckpointConfig custom argparse dest names + "ckpt_fully_parallel_save": True, + "ckpt_fully_parallel_load": False, + # ProfilingConfig: use_nsys_profiler is exposed as --profile on the CLI + "profile": False, + # RerunStateMachineConfig: field is check_for_nan_in_loss, CLI flag is check_for_nan_in_loss_and_grad + "check_for_nan_in_loss_and_grad": True, + } + defaults.update(overrides) + return Namespace(**defaults) + + +@pytest.fixture +def mock_optimizer_config(): + return MagicMock(spec=OptimizerConfig) + + +@pytest.fixture +def mock_ddp_config(): + return MagicMock(spec=DistributedDataParallelConfig) + + +@pytest.fixture +def patch_training_helpers(mock_optimizer_config, mock_ddp_config): + """Patch the two helper functions called by pretrain_cfg_container_from_args.""" + with ( + patch( + "megatron.training.training.get_megatron_optimizer_config", + return_value=(mock_optimizer_config, {}), + ), + patch("megatron.training.training.get_megatron_ddp_config", return_value=mock_ddp_config), + ): + yield + + +class TestPretrainContainerFromArgsStructure: + """Test the top-level structure of the object returned by pretrain_cfg_container_from_args.""" + + def test_returns_pretrain_config_container(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args()) + assert isinstance(result, PretrainConfigContainer) + + @patch("megatron.training.training.get_megatron_ddp_config") + @patch("megatron.training.training.get_megatron_optimizer_config") + def test_optimizer_config_comes_from_helper(self, mock_opt, mock_ddp): + """Test that optimizer config comes from get_megatron_optimizer_config.""" + mock_optimizer = MagicMock(spec=OptimizerConfig) + mock_opt.return_value = (mock_optimizer, {}) + mock_ddp.return_value = MagicMock(spec=DistributedDataParallelConfig) + args = _make_args() + result = pretrain_cfg_container_from_args(args) + mock_opt.assert_called_once_with(args) + assert result.optimizer is mock_optimizer + + @patch("megatron.training.training.get_megatron_ddp_config") + @patch("megatron.training.training.get_megatron_optimizer_config") + def test_ddp_config_comes_from_helper(self, mock_opt, mock_ddp): + """Test that ddp config comes from get_megatron_ddp_config.""" + mock_ddp_instance = MagicMock(spec=DistributedDataParallelConfig) + mock_opt.return_value = (MagicMock(spec=OptimizerConfig), {}) + mock_ddp.return_value = mock_ddp_instance + args = _make_args() + result = pretrain_cfg_container_from_args(args) + mock_ddp.assert_called_once_with(args) + assert result.ddp is mock_ddp_instance + + +class TestCheckpointConfigMapping: + """Test the boolean inversions and custom dest mappings for CheckpointConfig.""" + + def test_no_save_optim_false_means_save_optim_true(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_save_optim=False)) + assert result.checkpoint.save_optim is True + + def test_no_save_optim_true_means_save_optim_false(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_save_optim=True)) + assert result.checkpoint.save_optim is False + + def test_no_save_rng_false_means_save_rng_true(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_save_rng=False)) + assert result.checkpoint.save_rng is True + + def 
test_no_save_rng_true_means_save_rng_false(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_save_rng=True)) + assert result.checkpoint.save_rng is False + + def test_no_load_optim_false_means_load_optim_true(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_load_optim=False)) + assert result.checkpoint.load_optim is True + + def test_no_load_optim_true_means_load_optim_false(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_load_optim=True)) + assert result.checkpoint.load_optim is False + + def test_no_load_rng_false_means_load_rng_true(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_load_rng=False)) + assert result.checkpoint.load_rng is True + + def test_no_load_rng_true_means_load_rng_false(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(no_load_rng=True)) + assert result.checkpoint.load_rng is False + + def test_ckpt_fully_parallel_save_mapping(self, patch_training_helpers): + """ckpt_fully_parallel_save in args maps to fully_parallel_save in CheckpointConfig.""" + assert ( + pretrain_cfg_container_from_args( + _make_args(ckpt_fully_parallel_save=True) + ).checkpoint.fully_parallel_save + is True + ) + assert ( + pretrain_cfg_container_from_args( + _make_args(ckpt_fully_parallel_save=False) + ).checkpoint.fully_parallel_save + is False + ) + + def test_ckpt_fully_parallel_load_mapping(self, patch_training_helpers): + """ckpt_fully_parallel_load in args maps to fully_parallel_load in CheckpointConfig.""" + assert ( + pretrain_cfg_container_from_args( + _make_args(ckpt_fully_parallel_load=True) + ).checkpoint.fully_parallel_load + is True + ) + assert ( + pretrain_cfg_container_from_args( + _make_args(ckpt_fully_parallel_load=False) + ).checkpoint.fully_parallel_load + is False + ) + + def test_direct_checkpoint_fields_from_args(self, patch_training_helpers): + """Checkpoint fields with 1-to-1 name mapping are pulled directly from args.""" + args = _make_args(save="/path/to/save", load="/path/to/load", save_interval=500) + result = pretrain_cfg_container_from_args(args) + assert result.checkpoint.save == "/path/to/save" + assert result.checkpoint.load == "/path/to/load" + assert result.checkpoint.save_interval == 500 + + +class TestProfilingConfigMapping: + """Test the --profile → use_nsys_profiler mapping for ProfilingConfig.""" + + def test_profile_false_disables_nsys_profiler(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(profile=False)) + assert result.profiling.use_nsys_profiler is False + + def test_profile_true_enables_nsys_profiler(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(profile=True)) + assert result.profiling.use_nsys_profiler is True + + def test_direct_profiling_fields_from_args(self, patch_training_helpers): + """Profiling fields with 1-to-1 name mapping are pulled directly from args.""" + args = _make_args(profile_step_start=5, profile_step_end=15) + result = pretrain_cfg_container_from_args(args) + assert result.profiling.profile_step_start == 5 + assert result.profiling.profile_step_end == 15 + + +class TestRerunStateMachineConfigMapping: + """Test the check_for_nan_in_loss_and_grad → check_for_nan_in_loss mapping.""" + + def test_check_for_nan_true(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(check_for_nan_in_loss_and_grad=True)) + assert 
result.rerun_state_machine.check_for_nan_in_loss is True + + def test_check_for_nan_false(self, patch_training_helpers): + result = pretrain_cfg_container_from_args(_make_args(check_for_nan_in_loss_and_grad=False)) + assert result.rerun_state_machine.check_for_nan_in_loss is False + + def test_direct_rerun_state_machine_fields_from_args(self, patch_training_helpers): + """RerunStateMachineConfig fields with 1-to-1 name mapping are pulled directly from args.""" + args = _make_args(error_injection_rate=500, rerun_mode="report_stats") + result = pretrain_cfg_container_from_args(args) + assert result.rerun_state_machine.error_injection_rate == 500 + assert result.rerun_state_machine.rerun_mode == "report_stats" + + +class TestTrainingConfigMapping: + """Test that fields are pulled from args for configs that use _default_config_from_args() directly. + + TrainingConfig is used as the representative case. The same pass-through logic applies to + ValidationConfig, SchedulerConfig, RNGConfig, DistributedInitConfig, LoggerConfig, and + StragglerDetectionConfig — dedicated test classes for those are only warranted if + pretrain_cfg_container_from_args() adds special handling for them. + """ + + def test_training_fields_from_args(self, patch_training_helpers): + args = _make_args(train_iters=1000, micro_batch_size=4, global_batch_size=64) + result = pretrain_cfg_container_from_args(args) + assert result.train.train_iters == 1000 + assert result.train.micro_batch_size == 4 + assert result.train.global_batch_size == 64 + + def test_training_uses_defaults_when_fields_absent(self, patch_training_helpers): + """When training fields are absent from args, TrainingConfig uses its defaults.""" + result = pretrain_cfg_container_from_args(_make_args()) + assert result.train.train_iters is None + assert result.train.micro_batch_size is None + assert result.train.global_batch_size is None diff --git a/tests/unit_tests/test_checkpointing.py b/tests/unit_tests/test_checkpointing.py index 16bac10566d..c778d58fb7f 100644 --- a/tests/unit_tests/test_checkpointing.py +++ b/tests/unit_tests/test_checkpointing.py @@ -118,6 +118,7 @@ def create_args(): args.ckpt_step = None args.swiglu = True args.num_experts = 1 + args.verify_integrity = False yield args @@ -140,6 +141,7 @@ def create_ckpt_load_args(create_args): args.ckpt_assume_constant_structure = False args.ckpt_fully_parallel_save = False args.ckpt_fully_parallel_load = False + args.ckpt_load_validate_sharding_integrity = True args.dist_ckpt_strictness = 'assume_ok_unexpected' args.use_megatron_fsdp = False args.strict_fsdp_dtensor_load = True diff --git a/tests/unit_tests/test_fault_injector.py b/tests/unit_tests/test_fault_injector.py new file mode 100644 index 00000000000..bc582c77196 --- /dev/null +++ b/tests/unit_tests/test_fault_injector.py @@ -0,0 +1,174 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
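# [Editor's note — illustrative sketch, not part of the patch.]
# test_get_fault_delay_samples_mtti_with_offset in the new file below pins
# rng.random() to 0.5 and expects offset + ln(2) * mtti. That is consistent
# with inverse-CDF sampling of an exponential inter-failure time with mean
# `mtti` (a common mean-time-to-interrupt model); a standalone version of that
# sampling, under the assumption that this is the formula the helper uses:
import math
import random

def _sample_fault_delay_sketch(mtti_seconds, offset_seconds, rng=random):
    u = rng.random()                          # uniform in [0, 1)
    wait = -math.log(1.0 - u) * mtti_seconds  # Exp(rate = 1/mtti) via inverse CDF
    return offset_seconds + wait

# With u == 0.5 this evaluates to offset + math.log(2.0) * mtti, which matches
# the value the test asserts.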
+ +import enum +import math +from argparse import ArgumentParser +from types import SimpleNamespace + +import pytest +import torch + +import megatron.core.fault_injector as fault_injector +from megatron.training.argument_utils import ArgumentGroupFactory +from megatron.training.config import FaultInjectorConfig + + +def create_test_config(**overrides): + config = FaultInjectorConfig(**overrides) + return config + + +class TestFaultInjectorConfig: + def test_fault_injector_config_has_delay_start_iteration(self): + config = FaultInjectorConfig() + assert config.fault_injector_delay_start_iteration is None + + def test_cli_arg_generated(self): + parser = ArgumentParser() + factory = ArgumentGroupFactory(FaultInjectorConfig) + factory.build_group(parser, "fault injector") + args = parser.parse_args(["--fault-injector-delay-start-iteration", "100"]) + assert args.fault_injector_delay_start_iteration == 100 + + def test_old_cli_arg_removed(self): + parser = ArgumentParser() + factory = ArgumentGroupFactory(FaultInjectorConfig) + factory.build_group(parser, "fault injector") + with pytest.raises(SystemExit): + parser.parse_args(["--fault-injector-fault-iteration", "100"]) + + +class TestFaultInjectorScheduling: + def test_get_fault_ranks_parses_explicit_rank_list(self, monkeypatch): + config = create_test_config(fault_injector_ranks="0,3,7", fault_injector_num_ranks=None) + monkeypatch.setattr(fault_injector.dist, "get_world_size", lambda: 8) + + assert fault_injector.get_fault_ranks(config) == [0, 3, 7] + + def test_get_fault_ranks_samples_requested_num_ranks(self, monkeypatch): + sampled = [] + config = create_test_config(fault_injector_ranks=None, fault_injector_num_ranks=2) + + def fake_sample(population, k): + sampled.append((list(population), k)) + return [2, 5] + + monkeypatch.setattr(fault_injector.dist, "get_world_size", lambda: 8) + monkeypatch.setattr(fault_injector, "rng", SimpleNamespace(sample=fake_sample)) + + assert fault_injector.get_fault_ranks(config) == [2, 5] + assert sampled == [([1, 2, 3, 4, 5, 6, 7], 2)] + + def test_get_fault_requires_fault_types(self, monkeypatch): + config = create_test_config() + monkeypatch.setattr(fault_injector, "has_nvidia_resiliency_ext", True) + + with pytest.raises(AssertionError, match="fault_injector_fault_types must be specified"): + fault_injector.get_fault(config) + + def test_get_fault_parses_types_and_normalizes_probabilities(self, monkeypatch): + class FakeFault(enum.IntEnum): + HANG = 1 + CRASH = 2 + + captured = [] + config = create_test_config( + fault_injector_fault_types="hang,crash", fault_injector_fault_probabilities="2,1" + ) + + def fake_choices(fault_types, fault_probabilities, k): + captured.append((fault_types, fault_probabilities, k)) + return [fault_types[1]] + + monkeypatch.setattr(fault_injector, "has_nvidia_resiliency_ext", True) + monkeypatch.setattr(fault_injector, "Fault", FakeFault, raising=False) + monkeypatch.setattr(fault_injector, "rng", SimpleNamespace(choices=fake_choices)) + + fault = fault_injector.get_fault(config) + + assert fault == FakeFault.CRASH + assert captured[0][0] == [FakeFault.HANG, FakeFault.CRASH] + assert captured[0][2] == 1 + assert captured[0][1] == pytest.approx([2 / 3, 1 / 3]) + + def test_should_setup_fault_injection_at_start_without_anchor(self): + config = create_test_config() + assert fault_injector.should_setup_fault_injection_at_start(config) + assert not fault_injector.should_setup_fault_injection_at_iteration(config, 0) + + def 
test_should_setup_fault_injection_at_matching_iteration(self): + config = create_test_config(fault_injector_delay_start_iteration=12) + assert not fault_injector.should_setup_fault_injection_at_start(config) + assert not fault_injector.should_setup_fault_injection_at_iteration(config, 11) + assert fault_injector.should_setup_fault_injection_at_iteration(config, 12) + + def test_get_fault_delay_returns_explicit_delay(self): + config = create_test_config( + fault_injector_fault_delay=7.5, fault_injector_delay_start_iteration=100 + ) + assert fault_injector.get_fault_delay(config) == 7.5 + + def test_get_fault_delay_samples_mtti_with_offset(self, monkeypatch): + config = create_test_config( + fault_injector_delay_start_iteration=50, + fault_injector_mtti_seconds=10.0, + fault_injector_offset_seconds=2.0, + ) + monkeypatch.setattr(fault_injector, "rng", SimpleNamespace(random=lambda: 0.5)) + + fault_delay = fault_injector.get_fault_delay(config) + + assert math.isclose(fault_delay, 2.0 + (math.log(2.0) * 10.0)) + + def test_get_fault_delay_requires_time_based_configuration(self): + config = create_test_config(fault_injector_delay_start_iteration=25) + with pytest.raises( + AssertionError, + match="fault_injector_fault_delay or fault_injector_mtti_seconds must be specified", + ): + fault_injector.get_fault_delay(config) + + +class TestFaultInjectorSetup: + def test_setup_fault_injection_uses_single_plan_broadcast_for_zero_valued_fault( + self, monkeypatch + ): + class ZeroFault(enum.IntEnum): + ZERO_FAULT = 0 + + broadcasts = [] + dispatched = [] + config = create_test_config(fault_injector_seed=123) + fake_torch = SimpleNamespace( + device=lambda *_args, **_kwargs: "cpu", + full=torch.full, + float64=torch.float64, + cuda=SimpleNamespace(current_device=lambda: 0), + ) + + monkeypatch.setattr(fault_injector, "has_nvidia_resiliency_ext", True) + monkeypatch.setattr(fault_injector, "Fault", ZeroFault, raising=False) + monkeypatch.setattr(fault_injector, "torch", fake_torch) + monkeypatch.setattr(fault_injector.dist, "get_rank", lambda: 0) + monkeypatch.setattr(fault_injector.dist, "get_world_size", lambda: 4) + monkeypatch.setattr( + fault_injector.dist, + "broadcast", + lambda tensor, src: broadcasts.append((tensor.clone(), src)), + ) + monkeypatch.setattr(fault_injector, "clear_workload_exception", lambda: None, raising=False) + monkeypatch.setattr(fault_injector, "get_fault_ranks", lambda _config: [0]) + monkeypatch.setattr(fault_injector, "get_fault", lambda _config: ZeroFault.ZERO_FAULT) + monkeypatch.setattr(fault_injector, "get_fault_delay", lambda _config: 3.5) + monkeypatch.setattr( + fault_injector, + "dispatch_fault_injection", + lambda fault, delay, callback: dispatched.append((fault, delay, callback)), + raising=False, + ) + monkeypatch.setattr(fault_injector, "rng", None) + + fault_injector.setup_fault_injection(config) + + assert len(broadcasts) == 1 + assert dispatched == [(ZeroFault.ZERO_FAULT, 3.5, None)] diff --git a/tests/unit_tests/test_inference.py b/tests/unit_tests/test_inference.py index 9474ac0475a..528cc71eefb 100644 --- a/tests/unit_tests/test_inference.py +++ b/tests/unit_tests/test_inference.py @@ -28,6 +28,7 @@ def gpt2_tiktoken_tokenizer(): @pytest.fixture(scope="module") def static_inference_engine(gpt2_tiktoken_tokenizer): + Utils.initialize_model_parallel() engine_wrapper = StaticInferenceEngineTestHarness() engine_wrapper.setup_engine(vocab_size=gpt2_tiktoken_tokenizer.vocab_size, legacy=True) @@ -63,8 +64,6 @@ def client(app): 
'megatron.core.inference.text_generation_server.text_generation_server.send_do_generate' ) def test_generations_endpoint(mock_send_do_generate, client, gpt2_tiktoken_tokenizer): - Utils.initialize_distributed() - prompts = ["twinkle twinkle little star, how I wonder what you are"] request_data = {"prompts": prompts, "tokens_to_generate": 10, "logprobs": True} @@ -90,8 +89,6 @@ def test_generations_endpoint(mock_send_do_generate, client, gpt2_tiktoken_token "megatron.core.inference.text_generation_server.endpoints.completions.send_do_generate" ) def test_completions_endpoint(mock_send_do_generate, client, gpt2_tiktoken_tokenizer): - Utils.initialize_distributed() - twinkle = ("twinkle twinkle little star,", " how I wonder what you are") request_data = {"prompt": twinkle[0] + twinkle[1], "max_tokens": 0, "logprobs": 5, "echo": True} diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index e43f2ab7d88..36e5fe11b67 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -821,6 +821,13 @@ def test_overlap_param_gather_multi_iteration(self): optimizer.step() model.start_param_sync(force_sync=True) + # Verify grad_data is zeroed after synchronous param_sync. + for bucket_group in model.bucket_groups: + for bucket in bucket_group.buckets: + assert torch.all( + bucket.grad_data == 0 + ), f"grad_data not zeroed after param sync at iteration {iteration}" + # Sync path: step (includes allgather) ref_optimizer.step() @@ -860,6 +867,13 @@ def test_overlap_param_gather_async_dispatch_and_finish(self): for bucket_group in model.bucket_groups: bucket_group.finish_param_sync(skip_next_bucket_dispatch=True) + # Verify grad_data is zeroed after asynchronous param_sync. + for bucket_group in model.bucket_groups: + for bucket in bucket_group.buckets: + assert torch.all( + bucket.grad_data == 0 + ), "grad_data not zeroed after finish_param_sync" + # Sync path: step (includes allgather) ref_optimizer.step() diff --git a/tests/unit_tests/test_training.py b/tests/unit_tests/test_training.py index a893734bd89..838b963778c 100644 --- a/tests/unit_tests/test_training.py +++ b/tests/unit_tests/test_training.py @@ -18,6 +18,24 @@ def mock_train_valid_test_datasets_provider(train_val_test_num_samples): return iter([1]), iter([2]), iter([3]) +class _LenDataloader: + """Fake dataloader with __len__ (required by the full_validation path) + and __iter__ (consumed via cyclic_iter).""" + + def __init__(self, data): + self._data = list(data) + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._data) + + +def mock_multi_valid_full_datasets_provider(train_val_test_num_samples): + return (iter([1]), [_LenDataloader([2, 2]), _LenDataloader([20, 20, 20])], iter([3])) + + def create_test_args(): # Set dummy values for the args. 
args = SimpleNamespace() @@ -55,6 +73,24 @@ def test_build_train_valid_test_data_iterators(self): test_data = next(test_iter) assert (train_data, valid_data, test_data) == (1, 2, 3) + def test_build_train_valid_test_data_iterators_multi_full_validation(self): + """multiple_validation_sets + full_validation builds a list of iterators + (one per validation set) and sets args.eval_iters to the per-loader + lengths MAX-reduced across DP ranks.""" + args = create_test_args() + args.multiple_validation_sets = True + args.full_validation = True + set_args(args) + _, valid_iters, _ = build_train_valid_test_data_iterators( + mock_multi_valid_full_datasets_provider + ) + assert isinstance(valid_iters, list) + assert len(valid_iters) == 2 + assert next(valid_iters[0]) == 2 + assert next(valid_iters[1]) == 20 + # data_parallel_size=1, so MAX across DP ranks equals the local lengths + assert args.eval_iters == [2, 3] + def test_closed_formula_vocab_size_with_padding(self): def old_round_impl(after, multiple): while (after % multiple) != 0: diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py index f8fad3325f5..5856d8666a1 100644 --- a/tests/unit_tests/test_utilities.py +++ b/tests/unit_tests/test_utilities.py @@ -26,6 +26,13 @@ def __init__( self.layers[-1].weight.shared_embedding = True +def clear_nvte_env_vars(): + """Clear NVTE env vars set by conftest set_env fixture.""" + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + + class Utils: world_size = int(os.environ.get('WORLD_SIZE', '1')) @@ -96,7 +103,7 @@ def destroy_model_parallel(): # Flush pending CUDA work before the barrier so slow ranks don't # time out while fast ranks tear down process groups. torch.cuda.synchronize() - torch.distributed.barrier(timeout=timedelta(seconds=300)) + torch.distributed.barrier() except Exception: Utils.inited = False return diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index 95756101e74..d7ec4042ca9 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -163,17 +163,19 @@ def test_nvtx_decorator(): # Track function execution execution_tracker = {'decorated': False, 'decorated_with_message': False} - # Create decorated functions + # Decorate while NVTX is disabled (the common import-time scenario). + # The _nvtx_enabled flag must be checked at call time, not decoration time. 
+ util.configure_nvtx_profiling(False) + @util.nvtx_decorator() def nvtx_decorated_function(): execution_tracker['decorated'] = True - @util.nvtx_decorator(message="test_nvtx_decorator", color="red") + @util.nvtx_decorator(message="test_nvtx_decorator") def nvtx_decorated_function_with_message(): execution_tracker['decorated_with_message'] = True - # Test with NVTX disabled - util.configure_nvtx_profiling(False) + # Call with NVTX disabled — should still execute the wrapped function nvtx_decorated_function() nvtx_decorated_function_with_message() assert all(execution_tracker.values()) @@ -181,12 +183,21 @@ def nvtx_decorated_function_with_message(): # Reset tracker execution_tracker = {'decorated': False, 'decorated_with_message': False} - # Test with NVTX enabled + # Enable NVTX *after* decoration — should pick up the new flag value util.configure_nvtx_profiling(True) nvtx_decorated_function() nvtx_decorated_function_with_message() assert all(execution_tracker.values()) + # Reset tracker + execution_tracker = {'decorated': False, 'decorated_with_message': False} + + # Disable NVTX again — should respect the toggled flag + util.configure_nvtx_profiling(False) + nvtx_decorated_function() + nvtx_decorated_function_with_message() + assert all(execution_tracker.values()) + @pytest.mark.flaky @pytest.mark.flaky_in_dev diff --git a/tests/unit_tests/tokenizers/test_tokenizer.py b/tests/unit_tests/tokenizers/test_tokenizer.py index a376b1336fc..336d4cc97ac 100755 --- a/tests/unit_tests/tokenizers/test_tokenizer.py +++ b/tests/unit_tests/tokenizers/test_tokenizer.py @@ -1,5 +1,8 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +from unittest.mock import MagicMock + +import numpy as np import pytest import torch from packaging import version @@ -121,7 +124,6 @@ def test_hf_tokenizer(): assert tokenizer.vocab_size == 128258 -# HuggingFaceTokenizer.ids_to_text and include_special_tokens (--tokenizer-hf-include-special-tokens). # Uses same local path as test_hf_tokenizer; tests EOS stripping vs keeping in detokenized output (e.g. RL). 
LOCAL_HF_TOKENIZER_PATH = "/opt/data/tokenizers/huggingface" @@ -275,7 +277,9 @@ def test_null_tokenizer(): ids = tokenizer.tokenize("11 325 97") assert ids == [11, 325, 97] - assert tokenizer.vocab_size == 131073 + assert tokenizer.vocab_size == 131072 + assert tokenizer.eod == 131071 + assert tokenizer.pad == -1 @pytest.mark.parametrize("skip_special_tokens", [True, False]) @@ -460,3 +464,61 @@ def test_sft_tokenizer(): assert len(conv_tokens) > 0 and len(conv_tokens) == len( target_tokens ), "failed to tokenize conversation and return target tokens" + + +# --------------------------------------------------------------------------- +# Unit tests for SFTTokenizer._extract_token_ids (no GPU / real tokenizer needed) +# --------------------------------------------------------------------------- + +try: + from megatron.core.tokenizers.text.libraries.sft_tokenizer import SFTTokenizer + + HAVE_SFT_TOKENIZER = True +except Exception: + HAVE_SFT_TOKENIZER = False + +_IDS = [1, 2, 3, 4, 5] + + +@pytest.mark.skipif(not HAVE_SFT_TOKENIZER, reason="SFTTokenizer not importable") +class TestExtractTokenIds: + """Covers every return-type branch of SFTTokenizer._extract_token_ids.""" + + def _check(self, result): + arr = SFTTokenizer._extract_token_ids(result) + assert isinstance(arr, np.ndarray), "result must be ndarray" + assert arr.ndim == 1, f"expected 1D, got shape {arr.shape}" + assert arr.tolist() == _IDS + + # --- dict with 1D ids (plain list inside dict) --- + def test_dict_1d_list(self): + self._check({"input_ids": _IDS}) + + # --- dict with 2D ids (transformers return_tensors="np" wrapped in dict) --- + def test_dict_2d_ndarray(self): + self._check({"input_ids": np.array([_IDS])}) # shape (1, 5) + + # --- BatchEncoding-like object with input_ids attribute, 2D --- + def test_object_with_input_ids_attr_2d(self): + obj = MagicMock() + obj.__getitem__ = lambda self, k: np.array([_IDS]) if k == "input_ids" else None + obj.input_ids = np.array([_IDS]) + self._check(obj) + + # --- Fast-tokenizer Encoding object with .ids attribute --- + def test_object_with_ids_attr(self): + obj = MagicMock(spec=["ids"]) # no input_ids, no dict behaviour + obj.ids = _IDS + self._check(obj) + + # --- plain list (transformers default / return_dict=False) --- + def test_plain_list(self): + self._check(list(_IDS)) + + # --- 1D raw ndarray --- + def test_1d_ndarray(self): + self._check(np.array(_IDS)) + + # --- 2D raw ndarray (1, seq_len) — the bug fixed in this PR --- + def test_2d_ndarray_batch1(self): + self._check(np.array([_IDS])) # shape (1, 5) diff --git a/tests/unit_tests/tools/__init__.py b/tests/unit_tests/tools/__init__.py new file mode 100644 index 00000000000..b5dff7b5663 --- /dev/null +++ b/tests/unit_tests/tools/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tests/unit_tests/tools/checkpoint/__init__.py b/tests/unit_tests/tools/checkpoint/__init__.py new file mode 100644 index 00000000000..b5dff7b5663 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py new file mode 100644 index 00000000000..4e080f8f0b3 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py @@ -0,0 +1,328 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +""" +Multi-rank distributed round-trip test for gpt_hybrid_conversion. + +Each rank participates in a multi-rank DCP save of a synthetic GPT (or MoE +GPT) state dict; rank 0 then runs the converter and verifies the GPT->Hybrid-> +GPT round-trip exactly. + +This test is meant to be launched under SLURM/srun (or torchrun) with +WORLD_SIZE = TP * PP * FSDP * EP. The (tp, pp, fsdp, ep) values are passed +as flags purely as labels — the converter sees only the DCP-stored +``global_shape`` per tensor and is agnostic to *which* dimension(s) the +source was sharded along. The test value is in: + + 1. Coordinating a real multi-rank ``dcp.save`` (cross-node networking, + collective barriers, shared-filesystem write). + 2. Verifying the converter loads a multi-rank-written checkpoint and + round-trips it through both directions. + +Usage (under srun): + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + export WORLD_SIZE=$SLURM_NTASKS + export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -1) + export MASTER_PORT=29500 + python test_distributed_round_trip.py \\ + --tp 2 --pp 2 --fsdp 2 --ep 2 --label TP2-PP2-FSDP2-EP2 \\ + --output-root /lustre/.../scratch/dist_test +""" + +import argparse +import copy +import os +import shutil +import sys +import time +from collections import OrderedDict +from types import SimpleNamespace + +import torch +import torch.distributed as dist + +# Make the conversion tool and helpers importable. +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + + +def _log(msg, rank, label=""): + prefix = f"[rank={rank}{(' ' + label) if label else ''}]" + print(f"{prefix} {msg}", flush=True) + + +def _build_state_dict(num_layers, hidden_size, num_moe_experts, vocab_size, dtype): + """Build a deterministic GPT(MoE) state dict identical on every rank. + + Determinism (via fixed seed) lets every rank produce the same tensors so + DCP's de-duplication writes a single coherent checkpoint. After load, we + re-derive the same tensors on rank 0 to verify round-trip. + """ + torch.manual_seed(0xC0FFEE) + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' + sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + + if num_moe_experts is None: + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + else: + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' 
+ sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd + + +def _build_ckpt_args(num_layers, hidden_size, num_moe_experts): + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=4, + ffn_hidden_size=hidden_size * 4, + seq_length=256, + max_position_embeddings=256, + iteration=100, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + ) + + +def _init_process_group(init_file): + """Initialize via file:// rendezvous on a shared filesystem. + + RANK / WORLD_SIZE come from env (set by srun). MASTER_ADDR / MASTER_PORT + are not needed — every rank just opens the same shared file. This avoids + the SLURM CLI tools (e.g. scontrol) which are not always present inside + container images. + + The init file MUST NOT pre-exist; rank 0 cleans up any stale leftover. + """ + if dist.is_initialized(): + return + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + if rank == 0 and os.path.exists(init_file): + os.remove(init_file) + # Brief settle so other ranks don't race ahead of the cleanup. + time.sleep(1) + dist.init_process_group( + backend='gloo', # CPU-only synthetic save; no NCCL needed + init_method=f'file://{init_file}', + rank=rank, + world_size=world_size, + ) + + +def _verify_round_trip(original, recovered, label): + missing, mismatch = [], [] + for k, v in original.items(): + if k not in recovered: + missing.append(k) + continue + if not torch.equal(v, recovered[k].to(v.dtype)): + mismatch.append(k) + + leaked_ssm = [k for k in recovered if 'mixer.' in k] + + if missing or mismatch or leaked_ssm: + print(f"FAIL [{label}]:") + for k in missing[:5]: + print(f" MISSING: {k}") + for k in mismatch[:5]: + print(f" MISMATCH: {k}") + for k in leaked_ssm[:5]: + print(f" SSM LEAKED: {k}") + raise AssertionError( + f"{label}: {len(missing)} missing, {len(mismatch)} mismatched, " + f"{len(leaked_ssm)} SSM keys leaked" + ) + + print(f"PASS [{label}]: {len(original)} keys round-tripped exactly") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--pp', type=int, default=1) + parser.add_argument('--fsdp', type=int, default=1) + parser.add_argument('--ep', type=int, default=1) + parser.add_argument('--label', type=str, required=True) + parser.add_argument( + '--output-root', + type=str, + required=True, + help='Shared-filesystem path where this scenario writes its ' + 'checkpoints (must be visible from every rank).', + ) + parser.add_argument('--num-layers', type=int, default=3) + parser.add_argument('--hidden-size', type=int, default=64) + parser.add_argument('--vocab-size', type=int, default=512) + parser.add_argument( + '--num-moe-experts', + type=int, + default=None, + help='If set, use the MoE GPT state-dict layout ' + '(mlp.router + mlp.experts.local_experts.*).', + ) + parser.add_argument( + '--pattern', + type=str, + default=None, + help='Hybrid layer pattern. 
Defaults to "M*-M*-M*-" for ' + 'dense or "M*EM*EM*E" when --num-moe-experts is set.', + ) + args = parser.parse_args() + + expected_world = args.tp * args.pp * args.fsdp * args.ep + pattern = args.pattern or ('M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-') + + # Shared init file lives on the same shared FS we use for checkpoints, so + # all ranks on all nodes see the same path. + init_file = os.path.join(args.output_root, f'_pg_init_{args.label}') + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.output_root, exist_ok=True) + _init_process_group(init_file) + rank = dist.get_rank() + world = dist.get_world_size() + if world != expected_world: + if rank == 0: + print(f"FAIL [{args.label}]: world={world} but tp*pp*fsdp*ep={expected_world}") + sys.exit(2) + + # Lazy imports after sys.path is set. + from dist_checkpoint_io import ( + load_dist_checkpoint_full, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + from gpt_hybrid_conversion import main as conversion_main + + if rank == 0: + _log( + f"label={args.label} tp={args.tp} pp={args.pp} fsdp={args.fsdp} " + f"ep={args.ep} world={world} pattern={pattern} " + f"num_moe_experts={args.num_moe_experts}", + rank, + args.label, + ) + + # Each rank builds the same full state dict — DCP de-duplicates writes + # across ranks via its writer planner. + state_dict = _build_state_dict( + args.num_layers, args.hidden_size, args.num_moe_experts, args.vocab_size, torch.float32 + ) + ckpt_args = _build_ckpt_args(args.num_layers, args.hidden_size, args.num_moe_experts) + + scratch = os.path.join(args.output_root, args.label) + src_dir = os.path.join(scratch, 'gpt_src') + mid_dir = os.path.join(scratch, 'hybrid_mid') + dst_dir = os.path.join(scratch, 'gpt_dst') + iter_subdir = os.path.join(src_dir, 'iter_0000100') + + if rank == 0: + if os.path.isdir(scratch): + shutil.rmtree(scratch, ignore_errors=True) + os.makedirs(iter_subdir, exist_ok=True) + dist.barrier() + + # --- Multi-rank DCP write of the source GPT checkpoint --- + # dcp.save / dcp.load are both COLLECTIVE in the active process group, so + # every rank in this PG must participate in every save and every load. + # That includes the two conversion_main calls below, which internally call + # load_dist_checkpoint_full + save_dist_checkpoint_full once each. + # If a rank exits early its gloo socket closes and rank 0's reduce_scatter + # dies with "Connection closed by peer". + common_state = {'args': copy.deepcopy(ckpt_args), 'checkpoint_version': 3.0, 'iteration': 100} + save_dist_checkpoint_full( + state_dict, common_state, iter_subdir, model_prefix='model.', backend='torch_dist' + ) + if rank == 0: + write_latest_iteration_marker(iter_subdir, 100) + dist.barrier() + + # --- Convert GPT -> hybrid -> GPT (every rank participates collectively). + common_kwargs = dict( + hybrid_layer_pattern=pattern, + d_model=args.hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format='torch_dist', + ) + + # Silence non-rank-0 stdout to keep logs readable; collective behavior + # is unaffected. 
+ if rank != 0: + sys.stdout = open(os.devnull, 'w') + + t0 = time.time() + conversion_main( + argparse.Namespace( + direction='gpt-to-hybrid', load_dir=src_dir, save_dir=mid_dir, **common_kwargs + ) + ) + dist.barrier() + conversion_main( + argparse.Namespace( + direction='hybrid-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs + ) + ) + dist.barrier() + dt = time.time() - t0 + + # Restore stdout for rank 0's verify message. + if rank != 0: + sys.stdout = sys.__stdout__ + + # --- Load final + (rank 0 only) verify --- + recovered, _, _, _, _ = load_dist_checkpoint_full(dst_dir) + dist.barrier() + + if rank == 0: + _log(f"conversion time: {dt:.2f}s", rank, args.label) + _verify_round_trip(state_dict, recovered, args.label) + shutil.rmtree(scratch, ignore_errors=True) + if os.path.exists(init_file): + os.remove(init_file) + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py new file mode 100644 index 00000000000..79ec3e5ab36 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py @@ -0,0 +1,809 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for the GPT <-> Hybrid checkpoint conversion tool. + +These tests validate: +- Hybrid layer pattern parsing +- Layer index mapping (GPT <-> Hybrid) +- State dict key renaming (final_layernorm <-> final_norm) +- Shared parameter copying (embeddings, output_layer) +- SSM parameter initialization shapes and dtypes +- Round-trip conversion: GPT -> Hybrid -> GPT preserves attention and MLP weights +- TP split dimension lookup +""" + +import argparse +import math +import os +import sys +import tempfile +from collections import OrderedDict + +import pytest +import torch + +# Add the tools/checkpoint directory to the path so we can import the module +sys.path.insert( + 0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint') +) + +from gpt_hybrid_conversion import ( + build_layer_index_mapping, + convert_gpt_to_hybrid, + convert_hybrid_to_gpt, + get_layer_num_from_key, + initialize_ssm_layer_params, + is_attention_param, + is_mlp_param, + is_ssm_param, + parse_hybrid_layer_pattern, + replace_layer_num, + validate_pattern_gpt_compatible, + validate_source_args_gpt_compatible, +) + +# --------------------------------------------------------------------------- +# Pattern parsing tests +# --------------------------------------------------------------------------- + + +class TestPatternParsing: + def test_simple_pattern(self): + result = parse_hybrid_layer_pattern("M*-M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_all_mamba(self): + result = parse_hybrid_layer_pattern("MMMM") + assert result == ['M', 'M', 'M', 'M'] + + def test_all_attention(self): + result = parse_hybrid_layer_pattern("****") + assert result == ['*', '*', '*', '*'] + + def test_with_mtp_separator(self): + # Should strip MTP patterns (only main pattern) + result = parse_hybrid_layer_pattern("M*-M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_pipe_separator(self): + # Should strip pipeline stage separators + result = parse_hybrid_layer_pattern("M*-|M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_both_separators(self): + result = parse_hybrid_layer_pattern("M*-|M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def 
test_mixed_layers(self): + result = parse_hybrid_layer_pattern("M*-EG") + assert result == ['M', '*', '-', 'E', 'G'] + + def test_invalid_symbol(self): + with pytest.raises(ValueError, match="Invalid layer symbol"): + parse_hybrid_layer_pattern("M*X") + + +# --------------------------------------------------------------------------- +# Layer index mapping tests +# --------------------------------------------------------------------------- + + +class TestLayerIndexMapping: + def test_gpt_to_hybrid_basic(self): + # Pattern: M*-M*- (2 attn at pos 1,4; 2 MLP at pos 2,5) + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + # 2 GPT layers -> attn at [1,4], MLP at [2,5] + assert attn_map == {0: 1, 1: 4} + assert mlp_map == {0: 2, 1: 5} + assert ssm_indices == [0, 3] + + def test_hybrid_to_gpt_basic(self): + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'hybrid-to-gpt') + # attn at mamba layer 1 -> GPT layer 0, attn at 4 -> GPT layer 1 + assert attn_map == {1: 0, 4: 1} + assert mlp_map == {2: 0, 5: 1} + assert ssm_indices == [0, 3] + + def test_alternating_pattern(self): + layer_types = ['*', '-', '*', '-', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + assert attn_map == {0: 0, 1: 2, 2: 4} + assert mlp_map == {0: 1, 1: 3, 2: 5} + assert ssm_indices == [] + + def test_mismatched_attn_mlp_count(self): + # 2 attn but 1 MLP -> should raise + layer_types = ['*', '*', '-', 'M'] + with pytest.raises(ValueError, match="must equal"): + build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + + def test_unknown_direction(self): + with pytest.raises(ValueError, match="Unknown direction"): + build_layer_index_mapping(['*', '-'], 'invalid') + + +# --------------------------------------------------------------------------- +# Key helper tests +# --------------------------------------------------------------------------- + + +class TestKeyHelpers: + def test_get_layer_num(self): + assert get_layer_num_from_key('decoder.layers.5.mlp.linear_fc1.weight') == 5 + assert get_layer_num_from_key('decoder.layers.0.self_attention.linear_qkv.weight') == 0 + assert get_layer_num_from_key('decoder.layers.99.mixer.A_log') == 99 + assert get_layer_num_from_key('embedding.word_embeddings.weight') is None + + def test_replace_layer_num(self): + key = 'decoder.layers.3.mlp.linear_fc1.weight' + assert replace_layer_num(key, 3, 7) == 'decoder.layers.7.mlp.linear_fc1.weight' + + def test_is_attention_param(self): + assert is_attention_param('decoder.layers.0.self_attention.linear_qkv.weight') + assert is_attention_param('decoder.layers.0.input_layernorm.weight') + assert not is_attention_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_attention_param('decoder.layers.0.mixer.A_log') + + def test_is_mlp_param(self): + assert is_mlp_param('decoder.layers.0.mlp.linear_fc1.weight') + assert is_mlp_param('decoder.layers.0.pre_mlp_layernorm.weight') + assert not is_mlp_param('decoder.layers.0.self_attention.linear_qkv.weight') + + def test_is_ssm_param(self): + assert is_ssm_param('decoder.layers.0.mixer.A_log') + assert is_ssm_param('decoder.layers.0.mixer.in_proj.weight') + assert is_ssm_param('decoder.layers.0.mixer.conv1d.weight') + assert is_ssm_param('decoder.layers.0.mixer.D') + assert is_ssm_param('decoder.layers.0.mixer.dt_bias') + assert is_ssm_param('decoder.layers.0.mixer.norm.weight') + assert 
is_ssm_param('decoder.layers.0.mixer.out_proj.weight') + assert not is_ssm_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_ssm_param('decoder.layers.0.self_attention.linear_qkv.weight') + + +# --------------------------------------------------------------------------- +# SSM initialization tests +# --------------------------------------------------------------------------- + + +class TestSSMInitialization: + def test_shapes(self): + d_model = 256 + d_inner = 512 # 2 * d_model + d_state = 64 + n_groups = 4 + head_dim = 32 + n_heads = d_inner // head_dim # 16 + d_conv = 4 + conv_dim = d_inner + 2 * n_groups * d_state + + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=d_model, + mamba_d_inner=d_inner, + mamba_d_state=d_state, + mamba2_n_groups=n_groups, + mamba2_n_heads=n_heads, + mamba_head_dim=head_dim, + d_conv=d_conv, + dtype=torch.float32, + ) + + prefix = 'decoder.layers.0.mixer.' + + # in_proj: [2*d_inner + 2*n_groups*d_state + n_heads, d_model] + in_proj_out = 2 * d_inner + 2 * n_groups * d_state + n_heads + assert params[prefix + 'in_proj.weight'].shape == (in_proj_out, d_model) + + # in_proj layer norm weight + assert params[prefix + 'in_proj.layer_norm_weight'].shape == (d_model,) + + # conv1d: [conv_dim, 1, d_conv] + assert params[prefix + 'conv1d.weight'].shape == (conv_dim, 1, d_conv) + assert params[prefix + 'conv1d.bias'].shape == (conv_dim,) + + # A_log: [n_heads] + assert params[prefix + 'A_log'].shape == (n_heads,) + assert params[prefix + 'A_log'].dtype == torch.float32 + + # D: [n_heads] + assert params[prefix + 'D'].shape == (n_heads,) + assert params[prefix + 'D'].dtype == torch.float32 + + # dt_bias: [n_heads] + assert params[prefix + 'dt_bias'].shape == (n_heads,) + + # norm: [d_inner] + assert params[prefix + 'norm.weight'].shape == (d_inner,) + + # out_proj: [d_model, d_inner] + assert params[prefix + 'out_proj.weight'].shape == (d_model, d_inner) + + def test_A_log_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + A_log = params['decoder.layers.0.mixer.A_log'] + # A was uniform in (1, 16), so A_log should be in (log(1), log(16)) = (0, 2.77) + assert (A_log >= 0).all() + assert (A_log <= math.log(16) + 0.01).all() + + def test_D_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + D = params['decoder.layers.0.mixer.D'] + assert torch.allclose(D, torch.ones_like(D)) + + def test_conv1d_bias_zeros(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + bias = params['decoder.layers.0.mixer.conv1d.bias'] + assert torch.allclose(bias, torch.zeros_like(bias)) + + def test_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + norm = params['decoder.layers.0.mixer.norm.weight'] + assert torch.allclose(norm, torch.ones_like(norm)) + + def test_layer_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + ln = 
params['decoder.layers.0.mixer.in_proj.layer_norm_weight'] + assert torch.allclose(ln, torch.ones_like(ln)) + + def test_different_layer_idx(self): + params = initialize_ssm_layer_params( + layer_idx=7, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + assert 'decoder.layers.7.mixer.A_log' in params + assert 'decoder.layers.0.mixer.A_log' not in params + + +# --------------------------------------------------------------------------- +# Synthetic GPT checkpoint builder +# --------------------------------------------------------------------------- + + +def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): + """Create a minimal synthetic GPT state dict for testing.""" + state_dict = OrderedDict() + + # Embeddings + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, d_model, dtype=dtype) + + # Transformer layers + for i in range(num_layers): + prefix = f'decoder.layers.{i}.' + # Attention + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + d_model, d_model, dtype=dtype + ) + # MLP + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + d_model, 4 * d_model, dtype=dtype + ) + + # Final norm + state_dict['decoder.final_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + + # Output layer + state_dict['output_layer.weight'] = torch.randn(1000, d_model, dtype=dtype) + + return state_dict + + +# --------------------------------------------------------------------------- +# Full conversion tests +# --------------------------------------------------------------------------- + + +class TestGPTToHybridConversion: + def setup_method(self): + self.d_model = 64 + self.num_gpt_layers = 2 + self.pattern = "M*-M*-" # 6 total: 2 SSM, 2 attn, 2 MLP + self.gpt_state = make_synthetic_gpt_checkpoint(self.num_gpt_layers, self.d_model) + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def test_shared_params_preserved(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # Embeddings should be identical + assert torch.equal( + result['embedding.word_embeddings.weight'], + self.gpt_state['embedding.word_embeddings.weight'], + ) + # Output layer + assert torch.equal(result['output_layer.weight'], self.gpt_state['output_layer.weight']) + + def test_final_norm_renamed(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + assert 'decoder.final_norm.weight' in result + assert 'decoder.final_layernorm.weight' not in result + assert torch.equal( + result['decoder.final_norm.weight'], self.gpt_state['decoder.final_layernorm.weight'] + ) + + def test_attention_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # GPT layer 
0 attn -> Mamba layer 1 (first '*' in M*-M*-) + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.0.self_attention.linear_qkv.weight'], + ) + # GPT layer 1 attn -> Mamba layer 4 (second '*') + assert torch.equal( + result['decoder.layers.4.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # GPT layer 0 MLP -> Mamba layer 2 (first '-') + assert torch.equal( + result['decoder.layers.2.mlp.linear_fc1.weight'], + self.gpt_state['decoder.layers.0.mlp.linear_fc1.weight'], + ) + # GPT layer 1 MLP -> Mamba layer 5 (second '-') + assert torch.equal( + result['decoder.layers.5.mlp.linear_fc2.weight'], + self.gpt_state['decoder.layers.1.mlp.linear_fc2.weight'], + ) + + def test_ssm_layers_initialized(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # SSM layers at index 0 and 3 + for idx in [0, 3]: + prefix = f'decoder.layers.{idx}.mixer.' + assert prefix + 'A_log' in result + assert prefix + 'D' in result + assert prefix + 'dt_bias' in result + assert prefix + 'conv1d.weight' in result + assert prefix + 'conv1d.bias' in result + assert prefix + 'in_proj.weight' in result + assert prefix + 'norm.weight' in result + assert prefix + 'out_proj.weight' in result + + def test_layer_count_mismatch_raises(self): + # Pattern with 3 attn but only 2 GPT layers + layer_types = parse_hybrid_layer_pattern("M*-*-*-") + with pytest.raises(ValueError, match="layers"): + convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + +class TestHybridToGPTConversion: + def setup_method(self): + self.d_model = 64 + self.pattern = "M*-M*-" + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def _make_mamba_state(self): + """Build a synthetic Mamba state dict matching pattern M*-M*-.""" + state_dict = OrderedDict() + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, self.d_model) + state_dict['output_layer.weight'] = torch.randn(1000, self.d_model) + state_dict['decoder.final_norm.weight'] = torch.randn(self.d_model) + + layer_types = parse_hybrid_layer_pattern(self.pattern) + d_inner = self.d_model * 2 + n_heads = self.args.mamba2_n_heads + n_groups = self.args.mamba2_n_groups + d_state = self.args.mamba_d_state + + for i, lt in enumerate(layer_types): + prefix = f'decoder.layers.{i}.' 
+ if lt == 'M': + # SSM params + ssm = initialize_ssm_layer_params( + i, self.d_model, d_inner, d_state, n_groups, n_heads, self.args.mamba2_head_dim + ) + state_dict.update(ssm) + elif lt == '*': + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * self.d_model, self.d_model + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + self.d_model, self.d_model + ) + elif lt == '-': + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * self.d_model, self.d_model + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + self.d_model, 4 * self.d_model + ) + + return state_dict + + def test_final_norm_renamed_back(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + assert 'decoder.final_layernorm.weight' in result + assert 'decoder.final_norm.weight' not in result + + def test_ssm_params_discarded(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # No SSM keys should remain + for key in result: + assert 'mixer.' not in key, f"SSM key not discarded: {key}" + + def test_attention_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 1 (first *) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + # Mamba layer 4 (second *) -> GPT layer 1 + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.4.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 2 (first -) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.mlp.linear_fc1.weight'], + mamba_state['decoder.layers.2.mlp.linear_fc1.weight'], + ) + + def test_gpt_layer_count(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Should have 2 GPT layers (layers 0 and 1) + layer_nums = set() + for key in result: + lnum = get_layer_num_from_key(key) + if lnum is not None: + layer_nums.add(lnum) + assert layer_nums == {0, 1} + + +# --------------------------------------------------------------------------- +# Round-trip test: GPT -> Hybrid -> GPT; using Mamba as the example below +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + def test_gpt_hybrid_gpt_preserves_weights(self): + """Converting GPT -> Hybrid -> GPT should preserve all attention & MLP weights.""" + d_model = 64 + num_layers = 2 + pattern = "M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + 
original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + # GPT -> Hybrid + mamba_state = convert_gpt_to_hybrid(original_gpt, layer_types, args) + + # Hybrid -> GPT + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) + + # Check all original GPT keys are preserved + for key in original_gpt: + # final_layernorm is renamed in the round trip + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key after round-trip: {key}" + assert torch.equal( + original_gpt[key], recovered_gpt[key] + ), f"Weight mismatch after round-trip for {key}" + + # Check final_layernorm was properly renamed back + assert torch.equal( + original_gpt['decoder.final_layernorm.weight'], + recovered_gpt['decoder.final_layernorm.weight'], + ) + + def test_round_trip_different_pattern(self): + """Test with a pattern that has more SSM layers.""" + d_model = 64 + num_layers = 3 + pattern = "M*-M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + mamba_state = convert_gpt_to_hybrid(original_gpt, layer_types, args) + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) + + for key in original_gpt: + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key: {key}" + assert torch.equal(original_gpt[key], recovered_gpt[key]), f"Mismatch for {key}" + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist tests +# --------------------------------------------------------------------------- + + +class TestPatternWhitelist: + """validate_pattern_gpt_compatible rejects hybrid patterns GPTModel can't express.""" + + def test_accepts_mamba_attn_mlp(self): + # Standard hybrid with equal attn/MLP counts. + layer_types = parse_hybrid_layer_pattern("M*-M*-M*-") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_pure_transformer_pattern(self): + layer_types = parse_hybrid_layer_pattern("*-*-*-") + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + + def test_accepts_pure_ssm_pattern(self): + # Pure-SSM models have no attention/MLP, so trivially GPT-compatible + # in the pattern sense (the GPT side would be empty). + layer_types = parse_hybrid_layer_pattern("MMMM") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_moe_pattern(self): + # MoE layers ('E') round-trip through the converter as long as every + # MLP-bearing position is the same kind. + layer_types = parse_hybrid_layer_pattern("M*EM*EM*E") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_pure_attn_moe_pattern(self): + # No SSM, alternating attn/MoE — i.e. a Mixtral-like GPT. + layer_types = parse_hybrid_layer_pattern("*E*E*E") + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + + def test_rejects_mixed_dense_and_moe(self): + # GPT layers must be uniform: '-' (dense) and 'E' (MoE) cannot both + # appear in the same pattern. 
+ layer_types = parse_hybrid_layer_pattern("M*-M*E") + with pytest.raises(ValueError, match="uniform"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_rejects_gdn_symbol(self): + layer_types = parse_hybrid_layer_pattern("G*-*-") + with pytest.raises(ValueError, match="not GPT-compatible"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_rejects_unequal_attn_mlp(self): + layer_types = parse_hybrid_layer_pattern("M**-") # 2 attn, 1 MLP + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_unequal_attn_moe_also_rejected(self): + # Same uniformity check, but with MoE — 2 attn, 1 MoE. + layer_types = parse_hybrid_layer_pattern("M**E") + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_error_lists_offending_symbols(self): + # 'G' is still rejected; the error message should mention it. + layer_types = parse_hybrid_layer_pattern("M*-G") + with pytest.raises(ValueError) as exc: + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + assert 'G' in str(exc.value) + + +class TestSourceArgsWhitelist: + """validate_source_args_gpt_compatible rejects source checkpoints with + non-GPT-expressible features.""" + + def _ok_args(self, **overrides): + """Build a minimal args namespace that mimics a plain GPT/hybrid + training run. Any GPT-incompatible flags default to their + "off" value.""" + base = dict( + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + experimental_attention_variant=None, + linear_attention_freq=None, + heterogeneous_block_specs=False, + heterogeneous_layers_config_path=None, + heterogeneous_layers_config_encoded_json=None, + multi_latent_attention=False, + mtp_num_layers=None, + ) + base.update(overrides) + return argparse.Namespace(**base) + + def test_accepts_plain_gpt_args(self): + validate_source_args_gpt_compatible(self._ok_args(), 'gpt-to-hybrid') + + def test_none_args_is_noop(self): + # Dist checkpoints sometimes have no cached args blob. + validate_source_args_gpt_compatible(None, 'gpt-to-hybrid') + + def test_accepts_missing_optional_fields(self): + # Older checkpoints may not have every field; the validator should + # silently skip fields it doesn't find. + minimal = argparse.Namespace(num_moe_experts=None) + validate_source_args_gpt_compatible(minimal, 'hybrid-to-gpt') + + def test_accepts_moe_args(self): + # MoE keys live under decoder.layers..mlp.* and round-trip as-is. + validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-hybrid') + + def test_accepts_shared_expert_args(self): + # Shared experts also live under mlp.shared_experts.* and round-trip. + validate_source_args_gpt_compatible( + self._ok_args(num_moe_experts=8, moe_shared_expert_intermediate_size=4096), + 'gpt-to-hybrid', + ) + + def test_rejects_moe_layer_freq_list(self): + # Heterogeneous interleaving (some dense, some MoE) breaks GPT uniformity. + with pytest.raises(ValueError, match="interleaved"): + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-hybrid' + ) + + def test_accepts_moe_layer_freq_1(self): + validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-hybrid') + + def test_accepts_moe_layer_freq_all_ones_list(self): + # An all-1s list is uniform (every layer is the same kind) and accepted. 
+ validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 1, 1, 1]), 'gpt-to-hybrid' + ) + + def test_rejects_experimental_attention(self): + with pytest.raises(ValueError, match="experimental attention"): + validate_source_args_gpt_compatible( + self._ok_args(experimental_attention_variant='gated_delta_net'), 'gpt-to-hybrid' + ) + + def test_rejects_linear_attention(self): + with pytest.raises(ValueError, match="linear attention"): + validate_source_args_gpt_compatible( + self._ok_args(linear_attention_freq=4), 'gpt-to-hybrid' + ) + + def test_rejects_heterogeneous_block_specs(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_block_specs=True), 'hybrid-to-gpt' + ) + + def test_rejects_heterogeneous_config_path(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), 'gpt-to-hybrid' + ) + + def test_rejects_mla(self): + with pytest.raises(ValueError, match="Multi-Latent"): + validate_source_args_gpt_compatible( + self._ok_args(multi_latent_attention=True), 'gpt-to-hybrid' + ) + + def test_rejects_mtp(self): + with pytest.raises(ValueError, match="Multi-Token Prediction"): + validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-hybrid') + + def test_reports_multiple_reasons(self): + # Both heterogeneous moe_layer_freq and MLA set — both should be reported. + with pytest.raises(ValueError) as exc: + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), 'gpt-to-hybrid' + ) + msg = str(exc.value) + assert 'interleaved' in msg + assert 'Multi-Latent' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py new file mode 100644 index 00000000000..8102b7018ae --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py @@ -0,0 +1,418 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Parallelism-matrix integration tests for gpt_hybrid_conversion.py. + +The converter operates on dist (``torch_dist`` / ``fsdp_dtensor``) checkpoints +only — DCP's metadata stores each tensor's ``global_shape``, so the on-disk +TP / PP / FSDP layout is abstracted away from the conversion logic. We +synthesize a DCP checkpoint via a single-rank ``dcp.save`` and round-trip +GPT -> Hybrid -> GPT through the conversion CLI, asserting attention and MLP +weights match exactly. + +Each scenario is run as a distinct test to document the supported matrix and +catch regressions in dispatch logic. Designed to run on a single-GPU node via +SLURM (no torchrun needed). +""" + +import argparse +import copy +import os +import shutil +import sys +import tempfile +from collections import OrderedDict +from types import SimpleNamespace + +import pytest +import torch +import torch.distributed as dist + +# Make the conversion tool importable under both `python ` and `pytest`. 
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + +from gpt_hybrid_conversion import main as conversion_main + + +# These scenarios are SYNTHETIC and single-rank by design: each one writes a +# tiny synthetic DCP checkpoint and round-trips it through the converter on +# rank 0. They share the default torch.distributed process group with whatever +# harness launched pytest. When that default PG is multi-rank (e.g. Megatron's +# CI/CD initialises NCCL with world_size>1 before pytest collection), the +# dcp.save/dcp.load collectives stall: each rank has its own +# tempfile.mkdtemp() path and its own torch.randn() tensors, so the metadata +# coordination across ranks never converges and the NCCL watchdog kills the +# job after 10 minutes (see ProcessGroupNCCL ALLGATHER timeout). +# +# Multi-rank coverage lives in test_distributed_round_trip.py, which uses a +# fresh single-rank gloo subgroup per scenario via SLURM/srun in +# run_slurm_ckpt_convert_tests.sh. Skip these synthetic tests whenever the +# default PG is already multi-rank. +@pytest.fixture(autouse=True) +def _skip_when_multi_rank_pg(): + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + pytest.skip( + "Synthetic single-rank tests skipped under a multi-rank default " + "process group; multi-rank coverage is in " + "test_distributed_round_trip.py." + ) + + +# --------------------------------------------------------------------------- +# Synthetic-checkpoint helpers +# --------------------------------------------------------------------------- + + +def make_checkpoint_args( + num_layers=4, + hidden_size=128, + num_attention_heads=4, + seq_length=256, + max_position_embeddings=256, + iteration=100, + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, +): + """Build a minimal checkpoint 'args' namespace mirroring Megatron's. + + Set ``num_moe_experts`` to make the source/target a MoE GPT; the converter + will then pass the MoE config through unchanged so the round-trip stays + structurally consistent. + """ + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=hidden_size * 4, + seq_length=seq_length, + max_position_embeddings=max_position_embeddings, + iteration=iteration, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + moe_layer_freq=1, + ) + + +def make_gpt_state_dict( + num_layers, + hidden_size, + vocab_size=1024, + dtype=torch.float32, + num_moe_experts=None, + shared_expert_size=None, +): + """Create a minimal GPT state dict with the standard Megatron keys. + + Dense MLP layout (default): ``mlp.linear_fc1`` / ``mlp.linear_fc2``. + MoE layout (``num_moe_experts`` set): ``mlp.router`` plus N experts under + ``mlp.experts.local_experts..linear_fc{1,2}``, optionally a shared + expert under ``mlp.shared_experts.linear_fc{1,2}``. These are exactly the + keys Megatron writes for non-grouped-GEMM MoE — they all live under + ``decoder.layers..mlp.*`` so the converter ferries them through with no + MoE-specific code. 
+ """ + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' + sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + + if num_moe_experts is None: + # Dense MLP + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + else: + # MoE: router + N experts (+ optional shared expert) + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' + sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + if shared_expert_size is not None: + sp = p + 'mlp.shared_experts.' + sd[sp + 'linear_fc1.weight'] = torch.randn( + shared_expert_size, hidden_size, dtype=dtype + ) + sd[sp + 'linear_fc2.weight'] = torch.randn( + hidden_size, shared_expert_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd + + +# --------------------------------------------------------------------------- +# Dist (torch_dist / fsdp_dtensor) fixture builders +# --------------------------------------------------------------------------- + + +def _save_dist_checkpoint( + root_dir, full_sd, ckpt_args, iteration=100, prefix='model.', backend='torch_dist' +): + """Write a full state dict as a single-rank DCP checkpoint. + + From the converter's POV, this is indistinguishable from a multi-rank + TP+PP+FSDP save: DCP stores each tensor's global shape in its metadata + and the read planner reassembles the full tensor regardless of how many + processes wrote it. 
+ """ + from dist_checkpoint_io import ( + ensure_single_rank_process_group, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + + ensure_single_rank_process_group() + + iter_dir = os.path.join(root_dir, f'iter_{iteration:07d}') + common_state = { + 'args': copy.deepcopy(ckpt_args), + 'checkpoint_version': 3.0, + 'iteration': iteration, + } + save_dist_checkpoint_full(full_sd, common_state, iter_dir, model_prefix=prefix, backend=backend) + write_latest_iteration_marker(iter_dir, iteration) + + +def _load_converted_dist(ckpt_dir): + """Read a dist-format converted checkpoint back into a full state dict.""" + from dist_checkpoint_io import load_dist_checkpoint_full + + sd, common, prefix, backend, iteration = load_dist_checkpoint_full(ckpt_dir) + return sd, common.get('args', None) + + +# --------------------------------------------------------------------------- +# Core scenario runner +# --------------------------------------------------------------------------- + + +def _run_scenario( + label, + source_format, + target_format, + num_layers=4, + hidden_size=128, + pattern="M*-M*-M*-M*-", + source_prefix='model.', + num_moe_experts=None, + shared_expert_size=None, +): + """Build a GPT source ckpt, convert GPT->Hybrid->GPT, verify round-trip.""" + print(f"\n=== {label} ===") + print(f" source={source_format} (prefix='{source_prefix}')") + print(f" target={target_format}") + if num_moe_experts is not None: + print(f" MoE: num_experts={num_moe_experts} shared={shared_expert_size}") + + tmpdir = tempfile.mkdtemp(prefix=f'gpt_hybrid_{label.replace(" ", "_")}_') + try: + src_gpt_dir = os.path.join(tmpdir, 'gpt_src') + hybrid_dir = os.path.join(tmpdir, 'hybrid_mid') + dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') + + ckpt_args = make_checkpoint_args( + num_layers=num_layers, + hidden_size=hidden_size, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=shared_expert_size, + ) + gpt_sd = make_gpt_state_dict( + num_layers, + hidden_size, + num_moe_experts=num_moe_experts, + shared_expert_size=shared_expert_size, + ) + + _save_dist_checkpoint( + src_gpt_dir, gpt_sd, ckpt_args, prefix=source_prefix, backend=source_format + ) + + common_kwargs = dict( + hybrid_layer_pattern=pattern, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format=target_format, + ) + + # --- GPT -> Hybrid --- + conversion_main( + argparse.Namespace( + direction='gpt-to-hybrid', + load_dir=src_gpt_dir, + save_dir=hybrid_dir, + **common_kwargs, + ) + ) + + # --- Hybrid -> GPT --- + conversion_main( + argparse.Namespace( + direction='hybrid-to-gpt', + load_dir=hybrid_dir, + save_dir=dst_gpt_dir, + **common_kwargs, + ) + ) + + # --- Verify --- + recovered_sd, _ = _load_converted_dist(dst_gpt_dir) + # The hybrid->gpt step renames decoder.final_norm -> decoder.final_layernorm, + # mirroring the original GPT key. So recovered_sd should have the same + # keys and tensor values as gpt_sd. 
+ + mismatches = [] + for key, original in gpt_sd.items(): + if key not in recovered_sd: + mismatches.append(f"MISSING: {key}") + continue + if not torch.equal(original, recovered_sd[key]): + max_diff = (original - recovered_sd[key]).abs().max().item() + mismatches.append(f"MISMATCH: {key} (max_diff={max_diff})") + + if mismatches: + for m in mismatches[:10]: + print(f" FAIL: {m}") + raise AssertionError(f"{label} failed with {len(mismatches)} weight mismatches") + + # SSM keys must be absent in the final GPT output. + assert not any('mixer.' in k for k in recovered_sd), ( + f"SSM keys leaked into final GPT output: " + f"{[k for k in recovered_sd if 'mixer.' in k][:5]}" + ) + + print(f"PASSED: {label}") + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Test cases — one per (source backend, target backend, pattern) combo +# --------------------------------------------------------------------------- + + +def test_torch_dist_roundtrip(): + _run_scenario("torch_dist roundtrip", 'torch_dist', 'torch_dist') + + +def test_fsdp_dtensor_roundtrip(): + _run_scenario("fsdp_dtensor roundtrip", 'fsdp_dtensor', 'fsdp_dtensor') + + +def test_fsdp_dtensor_prefix(): + """fsdp_dtensor backend uses the 'model.module.' key prefix — verify we + auto-detect and strip it correctly.""" + _run_scenario( + "fsdp_dtensor prefix", 'fsdp_dtensor', 'fsdp_dtensor', source_prefix='model.module.' + ) + + +def test_torch_dist_alternating_pattern(): + """Pure transformer pattern (no SSM) round-trips.""" + _run_scenario("torch_dist alternating", 'torch_dist', 'torch_dist', pattern="*-*-*-*-") + + +def test_torch_dist_dense_ssm_pattern(): + """Dense SSM pattern still round-trips on the attn/MLP layers.""" + _run_scenario("torch_dist dense SSM", 'torch_dist', 'torch_dist', pattern="MM*-MM*-MM*-MM*-") + + +def test_torch_dist_moe_roundtrip(): + """MoE GPT (Mixtral-style) round-trips through an 'E'-bearing pattern. + + Source has num_moe_experts=4 and writes mlp.router / mlp.experts.* keys. + The hybrid pattern 'M*EM*EM*E' has 3 'E' positions, one per source layer. + The converter should ferry the router + every per-expert tensor through + verbatim — no MoE-specific code path involved. + """ + _run_scenario( + "torch_dist MoE roundtrip", + 'torch_dist', + 'torch_dist', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + ) + + +def test_torch_dist_moe_with_shared_experts(): + """MoE + shared experts round-trip together (mlp.shared_experts.* keys).""" + _run_scenario( + "torch_dist MoE+shared", + 'torch_dist', + 'torch_dist', + num_layers=3, + hidden_size=64, + pattern="*E*E*E", + num_moe_experts=4, + shared_expert_size=64 * 2, + ) + + +def test_fsdp_dtensor_moe_roundtrip(): + """MoE round-trips through fsdp_dtensor (covers the 'model.module.' 
prefix + case combined with MoE keys).""" + _run_scenario( + "fsdp_dtensor MoE roundtrip", + 'fsdp_dtensor', + 'fsdp_dtensor', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + source_prefix='model.module.', + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == '__main__': + print("=" * 60) + print("GPT <-> Hybrid Conversion Parallelism Matrix Tests") + print("=" * 60) + + test_torch_dist_roundtrip() + test_fsdp_dtensor_roundtrip() + test_fsdp_dtensor_prefix() + test_torch_dist_alternating_pattern() + test_torch_dist_dense_ssm_pattern() + test_torch_dist_moe_roundtrip() + test_torch_dist_moe_with_shared_experts() + test_fsdp_dtensor_moe_roundtrip() + + print("=" * 60) + print("ALL PARALLELISM MATRIX TESTS PASSED") + print("=" * 60) diff --git a/tests/unit_tests/training/config/test_container_base.py b/tests/unit_tests/training/config/test_container_base.py new file mode 100644 index 00000000000..2b87c69679f --- /dev/null +++ b/tests/unit_tests/training/config/test_container_base.py @@ -0,0 +1,850 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import copy +import functools +import os +import tempfile +from dataclasses import dataclass +from unittest.mock import MagicMock, mock_open, patch + +import pytest +import torch +import yaml + +from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.training.config.container import ConfigContainerBase +from megatron.training.config.instantiate_utils import InstantiationMode + + +@pytest.fixture(autouse=True) +def _disable_allowlist(): + """Temporarily disable allowlist to fully test container logic with local test targets.""" + from megatron.training.config.instantiate_utils import target_allowlist + + target_allowlist.disable() + yield + target_allowlist.enable() + + +def _target_qualname(obj) -> str: + return f"{obj.__module__}.{obj.__qualname__}" + + +# Test functions for callable testing +def activation_function(x): + """Test activation function.""" + return x * 2 + + +def loss_function(pred, target, reduction="mean"): + """Test loss function with parameters.""" + return abs(pred - target) + + +# Test dataclasses for testing +@dataclass +class SimpleDataclass: + """Simple dataclass for testing.""" + + name: str = "test" + value: int = 42 + + +@dataclass +class NestedDataclass: + """Nested dataclass for testing.""" + + simple: SimpleDataclass + description: str = "nested" + + +@dataclass +class CallableDataclass: + """Dataclass with callable and partial fields for testing.""" + + name: str = "callable_test" + activation_func: callable = activation_function + loss_func: callable = functools.partial(loss_function, reduction="sum") + torch_func: callable = torch.nn.functional.relu + lambda_func: callable = lambda x: x + 1 + regular_value: int = 100 + + +@dataclass +class TestConfigContainer(ConfigContainerBase): + """Test configuration container.""" + + name: str = "test_config" + value: int = 100 + description: str = "A test configuration" + + +@dataclass +class ComplexConfigContainer(ConfigContainerBase): + """Complex configuration container for testing.""" + + simple_config: TestConfigContainer + nested_data: NestedDataclass + items: list[str] + metadata: dict[str, int] + + +@dataclass +class CallableConfigContainer(ConfigContainerBase): + """Configuration container with callable fields for testing.""" + + name: str = "callable_config" + callable_data: 
CallableDataclass = None # Will be set in tests + activation: callable = activation_function + partial_loss: callable = functools.partial(loss_function, reduction="none") + torch_activation: callable = torch.nn.functional.gelu + + def __post_init__(self): + """Initialize callable_data if not provided.""" + if self.callable_data is None: + self.callable_data = CallableDataclass() + + +class TestConfigContainer_FromDict: + """Test ConfigContainer.from_dict method.""" + + @patch("megatron.training.config.container.instantiate") + def test_from_dict_basic(self, mock_instantiate): + """Test basic from_dict functionality.""" + config_dict = { + "_target_": _target_qualname(TestConfigContainer), + "name": "from_dict", + "value": 300, + } + + expected_config = TestConfigContainer(name="from_dict", value=300) + mock_instantiate.return_value = expected_config + + result = TestConfigContainer.from_dict(config_dict) + + mock_instantiate.assert_called_once_with(config_dict, mode=InstantiationMode.STRICT) + assert result.name == "from_dict" + assert result.value == 300 + + @patch("megatron.training.config.container.instantiate") + def test_from_dict_with_mode(self, mock_instantiate): + """Test from_dict with different instantiation modes.""" + config_dict = {"_target_": _target_qualname(TestConfigContainer), "name": "lenient"} + + expected_config = TestConfigContainer(name="lenient") + mock_instantiate.return_value = expected_config + + result = TestConfigContainer.from_dict(config_dict, mode=InstantiationMode.LENIENT) + + mock_instantiate.assert_called_once_with(config_dict, mode=InstantiationMode.LENIENT) + assert result.name == "lenient" + + def test_from_dict_missing_target(self): + """Test from_dict raises error when _target_ is missing.""" + config_dict = {"name": "test"} + + with pytest.raises(AssertionError): + TestConfigContainer.from_dict(config_dict) + + def test_from_dict_extra_keys_strict_mode(self): + """Test from_dict raises error for extra keys in strict mode.""" + config_dict = { + "_target_": _target_qualname(TestConfigContainer), + "name": "test", + "extra_key": "should_fail", + } + + with pytest.raises(ValueError, match="Dictionary contains extra keys"): + TestConfigContainer.from_dict(config_dict, mode=InstantiationMode.STRICT) + + @patch("megatron.training.config.container.instantiate") + def test_from_dict_extra_keys_lenient_mode(self, mock_instantiate): + """Test from_dict removes extra keys in lenient mode.""" + config_dict = { + "_target_": _target_qualname(TestConfigContainer), + "name": "test", + "extra_key": "should_be_removed", + } + + expected_config = TestConfigContainer(name="test") + mock_instantiate.return_value = expected_config + + TestConfigContainer.from_dict(config_dict, mode=InstantiationMode.LENIENT) + + # Verify that extra_key was removed from the dict passed to instantiate + called_dict = mock_instantiate.call_args[0][0] + assert "extra_key" not in called_dict + assert called_dict["name"] == "test" + assert called_dict["_target_"] == _target_qualname(TestConfigContainer) + + def test_from_dict_preserves_original(self): + """Test that from_dict doesn't modify the original dictionary.""" + original_dict = { + "_target_": _target_qualname(TestConfigContainer), + "name": "original", + "extra_key": "should_be_preserved_in_original", + } + + original_copy = copy.deepcopy(original_dict) + + with pytest.raises(ValueError): # This will fail in strict mode + TestConfigContainer.from_dict(original_dict, mode=InstantiationMode.STRICT) + + # Original dict should be 
unchanged + assert original_dict == original_copy + + +class TestConfigContainer_FromYaml: + """Test ConfigContainer.from_yaml method.""" + + def test_from_yaml_file_not_found(self): + """Test from_yaml raises FileNotFoundError for missing file.""" + with pytest.raises(FileNotFoundError, match="YAML file not found"): + TestConfigContainer.from_yaml("non_existent_file.yaml") + + @patch("megatron.training.config.container.MultiStorageClientFeature.is_enabled") + @patch("megatron.training.config.container.OmegaConf") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists") + def test_from_yaml_success(self, mock_exists, mock_file, mock_omegaconf, mock_msc): + """Test successful YAML loading.""" + mock_msc.return_value = False + mock_exists.return_value = True + yaml_content = f""" + _target_: {_target_qualname(TestConfigContainer)} + name: yaml_config + value: 500 + """ + mock_file.return_value.read.return_value = yaml_content + + # Mock yaml.safe_load to return parsed content + with patch("yaml.safe_load") as mock_yaml_load: + config_dict = { + "_target_": _target_qualname(TestConfigContainer), + "name": "yaml_config", + "value": 500, + } + mock_yaml_load.return_value = config_dict + + # Mock OmegaConf methods + mock_conf = MagicMock() + mock_omegaconf.create.return_value = mock_conf + mock_omegaconf.to_container.return_value = config_dict + + result = TestConfigContainer.from_yaml("test.yaml") + + mock_exists.assert_called_once_with("test.yaml") + mock_file.assert_called_once_with("test.yaml", "r") + mock_yaml_load.assert_called_once() + mock_omegaconf.create.assert_called_once_with(config_dict) + mock_omegaconf.to_container.assert_called_once_with(mock_conf, resolve=True) + + assert result.name == "yaml_config" + assert result.value == 500 + + @patch("megatron.training.config.container.MultiStorageClientFeature.is_enabled") + @patch("os.path.exists") + def test_from_yaml_with_mode(self, mock_exists, mock_msc): + """Test from_yaml with different instantiation modes.""" + mock_msc.return_value = False + mock_exists.return_value = True + + with patch("builtins.open", mock_open()): + with patch("yaml.safe_load", return_value={}): + with patch("megatron.training.config.container.OmegaConf") as mock_omegaconf: + # Mock OmegaConf methods to return expected values + mock_conf = MagicMock() + mock_omegaconf.create.return_value = mock_conf + mock_omegaconf.to_container.return_value = {} # Return actual empty dict + + with patch.object(TestConfigContainer, "from_dict") as mock_from_dict: + TestConfigContainer.from_yaml("test.yaml", mode=InstantiationMode.STRICT) + mock_from_dict.assert_called_once_with({}, mode=InstantiationMode.STRICT) + + +class TestConfigContainer_ToDict: + """Test ConfigContainer.to_dict method.""" + + def test_to_dict_basic(self): + """Test basic to_dict functionality.""" + config = TestConfigContainer(name="test", value=123, description="test desc") + result = config.to_dict() + + expected = { + "_target_": _target_qualname(TestConfigContainer), + "name": "test", + "value": 123, + "description": "test desc", + } + + assert result == expected + + def test_to_dict_with_nested_config_container(self): + """Test to_dict with nested ConfigContainer.""" + simple_config = TestConfigContainer(name="nested", value=456) + nested_data = NestedDataclass(simple=SimpleDataclass(name="inner", value=789)) + + complex_config = ComplexConfigContainer( + simple_config=simple_config, + nested_data=nested_data, + items=["a", "b", "c"], + metadata={"key1": 1, "key2": 2}, + ) + + 
result = complex_config.to_dict() + + # Check the structure + assert "_target_" in result + assert result["_target_"] == _target_qualname(ComplexConfigContainer) + + # Check nested ConfigContainer + assert result["simple_config"]["_target_"] == _target_qualname(TestConfigContainer) + assert result["simple_config"]["name"] == "nested" + assert result["simple_config"]["value"] == 456 + + # Check nested regular dataclass + assert result["nested_data"]["_target_"] == _target_qualname(NestedDataclass) + assert result["nested_data"]["simple"]["_target_"] == _target_qualname(SimpleDataclass) + assert result["nested_data"]["simple"]["name"] == "inner" + assert result["nested_data"]["simple"]["value"] == 789 + + # Check lists and dicts + assert result["items"] == ["a", "b", "c"] + assert result["metadata"] == {"key1": 1, "key2": 2} + + # TODO (@maanug): reenable after migrating model config+builder + # def test_convert_serializable_nested_in_config(self): + # """Test that a Serializable nested inside a ConfigContainer is serialized via as_dict().""" + + # class NestedSerializable: + # def __init__(self, value): + # self.value = value + + # def as_dict(self) -> dict: + # return {"_target_": "my.module.NestedSerializable", "value": self.value} + + # @classmethod + # def from_dict(cls, data): + # return cls(data["value"]) + + # @dataclass + # class ConfigWithSerializable(ConfigContainerBase): + # name: str = "ser_test" + # nested: object = None + + # def __post_init__(self): + # if self.nested is None: + # self.nested = NestedSerializable(99) + + # config = ConfigWithSerializable() + # result = config.to_dict() + + # assert result["name"] == "ser_test" + # assert result["nested"] == {"_target_": "my.module.NestedSerializable", "value": 99} + + def test_to_dict_excludes_private_fields(self): + """Test that to_dict excludes fields starting with underscore.""" + config = TestConfigContainer() + result = config.to_dict() + + # Should include _target_ but exclude __version__ + assert "_target_" in result + assert "__version__" not in result + + +class TestConfigContainer_ConvertValueToDict: + """Test ConfigContainer._convert_value_to_dict method.""" + + def test_convert_config_container(self): + """Test converting ConfigContainer instance.""" + config = TestConfigContainer(name="convert_test", value=999) + result = TestConfigContainer._convert_value_to_dict(config) + + expected = { + "_target_": _target_qualname(TestConfigContainer), + "name": "convert_test", + "value": 999, + "description": "A test configuration", + } + + assert result == expected + + def test_convert_regular_dataclass(self): + """Test converting regular dataclass.""" + simple = SimpleDataclass(name="simple_test", value=555) + result = TestConfigContainer._convert_value_to_dict(simple) + + expected = { + "_target_": _target_qualname(SimpleDataclass), + "name": "simple_test", + "value": 555, + } + + assert result == expected + + def test_convert_list(self): + """Test converting list with nested dataclasses.""" + items = [SimpleDataclass(name="item1", value=1), "string_item", 42] + result = TestConfigContainer._convert_value_to_dict(items) + + assert len(result) == 3 + assert result[0]["_target_"] == _target_qualname(SimpleDataclass) + assert result[0]["name"] == "item1" + assert result[1] == "string_item" + assert result[2] == 42 + + def test_convert_tuple(self): + """Test converting tuple.""" + items = (SimpleDataclass(name="tuple_item"), "string") + result = TestConfigContainer._convert_value_to_dict(items) + + assert len(result) == 2 + 
assert result[0]["_target_"] == _target_qualname(SimpleDataclass) + assert result[1] == "string" + + def test_convert_dict(self): + """Test converting dictionary with nested dataclasses.""" + data = { + "config": SimpleDataclass(name="dict_config"), + "value": 123, + "nested": {"inner": SimpleDataclass(name="inner_config")}, + } + result = TestConfigContainer._convert_value_to_dict(data) + + assert result["config"]["_target_"] == _target_qualname(SimpleDataclass) + assert result["value"] == 123 + assert result["nested"]["inner"]["_target_"] == _target_qualname(SimpleDataclass) + + # TODO (@maanug): reenable after migrating model config+builder + # def test_convert_serializable(self): + # """Test converting a Serializable instance uses as_dict().""" + + # class MySerializable: + # def as_dict(self) -> dict: + # return {"_target_": "my.module.MySerializable", "x": 42} + + # @classmethod + # def from_dict(cls, data): + # return cls() + + # obj = MySerializable() + # assert isinstance(obj, Serializable) # runtime_checkable sanity check + + # result = TestConfigContainer._convert_value_to_dict(obj) + + # assert result == {"_target_": "my.module.MySerializable", "x": 42} + + def test_convert_primitive_types(self): + """Test converting primitive types.""" + assert TestConfigContainer._convert_value_to_dict(42) == 42 + assert TestConfigContainer._convert_value_to_dict("string") == "string" + assert TestConfigContainer._convert_value_to_dict(True) is True + assert TestConfigContainer._convert_value_to_dict(None) is None + assert TestConfigContainer._convert_value_to_dict(3.14) == 3.14 + + def test_convert_excludes_private_fields_in_dataclass(self): + """Test that private fields are excluded from dataclass conversion.""" + + @dataclass + class DataclassWithPrivate: + public_field: str = "public" + _private_field: str = "private" + + obj = DataclassWithPrivate() + result = TestConfigContainer._convert_value_to_dict(obj) + + assert "public_field" in result + assert "_private_field" not in result + assert "_target_" in result + + +class TestConfigContainer_ToYaml: + """Test ConfigContainer.to_yaml method.""" + + def test_to_yaml_save_to_file(self): + """Test to_yaml writes valid YAML to disk matching to_dict().""" + config = TestConfigContainer(name="file_test", value=888) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "test_output.yaml") + config.to_yaml(tmp_path) + + assert os.path.exists(tmp_path) + with open(tmp_path, "r") as f: + parsed = yaml.safe_load(f) + + assert parsed == config.to_dict() + + def test_to_yaml_with_msc_url(self): + """Test to_yaml with MSC URL.""" + config = TestConfigContainer(name="msc_test", value=999) + + MultiStorageClientFeature.enable() + + # Verify that the file is created in the temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + config.to_yaml(f"msc://default{temp_dir}/test_output.yaml") + assert os.path.exists(f"{temp_dir}/test_output.yaml") + + loaded_config = TestConfigContainer.from_yaml( + f"msc://default{temp_dir}/test_output.yaml" + ) + assert config.to_dict() == loaded_config.to_dict() + + +class TestConfigContainer_PrintYaml: + """Test ConfigContainer.print_yaml method.""" + + def test_print_yaml_basic(self, capsys): + """Test print_yaml outputs valid YAML with the correct field values.""" + config = TestConfigContainer(name="print_test", value=555, description="test print") + + config.print_yaml() + + captured = capsys.readouterr() + parsed = yaml.safe_load(captured.out) + + assert 
parsed["_target_"] == _target_qualname(TestConfigContainer) + assert parsed["name"] == "print_test" + assert parsed["value"] == 555 + assert parsed["description"] == "test print" + + def test_print_yaml_with_complex_config(self, capsys): + """Test print_yaml with complex nested configuration.""" + simple_config = TestConfigContainer(name="nested", value=123) + nested_data = NestedDataclass(simple=SimpleDataclass(name="inner", value=456)) + + complex_config = ComplexConfigContainer( + simple_config=simple_config, + nested_data=nested_data, + items=["a", "b", "c"], + metadata={"key1": 10, "key2": 20}, + ) + + complex_config.print_yaml() + + captured = capsys.readouterr() + parsed = yaml.safe_load(captured.out) + + assert parsed["_target_"] == _target_qualname(ComplexConfigContainer) + assert parsed["simple_config"]["name"] == "nested" + assert parsed["nested_data"]["simple"]["value"] == 456 + assert parsed["items"] == ["a", "b", "c"] + assert parsed["metadata"] == {"key1": 10, "key2": 20} + + def test_print_yaml_output_matches_to_dict(self, capsys): + """Test that the YAML output exactly round-trips through to_dict.""" + config = TestConfigContainer(name="to_dict_test", value=999) + + config.print_yaml() + + captured = capsys.readouterr() + parsed = yaml.safe_load(captured.out) + + assert parsed == config.to_dict() + + +class TestConfigContainer_DeepCopy: + """Test ConfigContainer.__deepcopy__ method.""" + + def test_deepcopy_basic(self): + """Test basic deep copy functionality.""" + config = TestConfigContainer(name="original", value=100) + copied_config = copy.deepcopy(config) + + assert copied_config is not config + assert copied_config.name == config.name + assert copied_config.value == config.value + assert copied_config.description == config.description + + # Modify original to verify they're independent + config.name = "modified" + assert copied_config.name == "original" + + def test_deepcopy_with_nested_structures(self): + """Test deep copy with nested dataclasses and containers.""" + simple_config = TestConfigContainer(name="nested", value=456) + nested_data = NestedDataclass(simple=SimpleDataclass(name="inner", value=789)) + + complex_config = ComplexConfigContainer( + simple_config=simple_config, + nested_data=nested_data, + items=["a", "b", "c"], + metadata={"key1": 1, "key2": 2}, + ) + + copied_config = copy.deepcopy(complex_config) + + # Verify it's a deep copy + assert copied_config is not complex_config + assert copied_config.simple_config is not complex_config.simple_config + assert copied_config.nested_data is not complex_config.nested_data + assert copied_config.items is not complex_config.items + assert copied_config.metadata is not complex_config.metadata + + # Verify values are preserved + assert copied_config.simple_config.name == "nested" + assert copied_config.nested_data.simple.name == "inner" + assert copied_config.items == ["a", "b", "c"] + assert copied_config.metadata == {"key1": 1, "key2": 2} + + # Verify independence + complex_config.simple_config.name = "modified" + complex_config.items.append("d") + + assert copied_config.simple_config.name == "nested" + assert len(copied_config.items) == 3 + + +class TestConfigContainer_Integration: + """Integration tests for ConfigContainer.""" + + def test_roundtrip_dict_conversion(self): + """Test that converting to dict and back preserves data.""" + simple_config = TestConfigContainer(name="roundtrip", value=999) + nested_data = NestedDataclass( + simple=SimpleDataclass(name="nested", value=888), description="roundtrip 
test" + ) + + original_config = ComplexConfigContainer( + simple_config=simple_config, + nested_data=nested_data, + items=["x", "y", "z"], + metadata={"test": 42}, + ) + + config_dict = original_config.to_dict() + + reconstructed_config = ComplexConfigContainer.from_dict(config_dict) + + assert reconstructed_config.simple_config.name == original_config.simple_config.name + assert reconstructed_config.simple_config.value == original_config.simple_config.value + assert ( + reconstructed_config.nested_data.description == original_config.nested_data.description + ) + assert ( + reconstructed_config.nested_data.simple.name == original_config.nested_data.simple.name + ) + assert ( + reconstructed_config.nested_data.simple.value + == original_config.nested_data.simple.value + ) + assert reconstructed_config.items == original_config.items + assert reconstructed_config.metadata == original_config.metadata + + def test_yaml_roundtrip_structure(self): + """Test that converting to YAML and back preserves data.""" + config = TestConfigContainer(name="yaml_roundtrip", value=1234) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "test_config.yaml") + config.to_yaml(tmp_path) + + loaded_config = TestConfigContainer.from_yaml(tmp_path) + + assert loaded_config.name == config.name + assert loaded_config.value == config.value + assert loaded_config.description == config.description + + +class TestConfigContainer_EdgeCases: + """Test edge cases for ConfigContainer.""" + + def test_empty_config_container(self): + """Test ConfigContainer with minimal fields.""" + + @dataclass + class MinimalConfig(ConfigContainerBase): + pass + + config = MinimalConfig() + result = config.to_dict() + + assert "_target_" in result + # The actual path will be generated based on the local class + assert "MinimalConfig" in result["_target_"] + + def test_config_with_none_values(self): + """Test ConfigContainer with None values.""" + + @dataclass + class ConfigWithNone(ConfigContainerBase): + optional_field: str = None + required_field: str = "required" + + config = ConfigWithNone() + result = config.to_dict() + + assert result["optional_field"] is None + assert result["required_field"] == "required" + + def test_config_with_complex_nested_types(self): + """Test ConfigContainer with complex nested types.""" + + @dataclass + class ComplexConfig(ConfigContainerBase): + nested_list: list[dict[str, SimpleDataclass]] + nested_dict: dict[str, list[SimpleDataclass]] + + nested_list = [ + {"item1": SimpleDataclass(name="list_item1", value=1)}, + {"item2": SimpleDataclass(name="list_item2", value=2)}, + ] + + nested_dict = { + "group1": [SimpleDataclass(name="group1_item1", value=10)], + "group2": [SimpleDataclass(name="group2_item1", value=20)], + } + + config = ComplexConfig(nested_list=nested_list, nested_dict=nested_dict) + result = config.to_dict() + + # Verify complex nested structure conversion + assert len(result["nested_list"]) == 2 + assert result["nested_list"][0]["item1"]["_target_"] == _target_qualname(SimpleDataclass) + assert result["nested_dict"]["group1"][0]["name"] == "group1_item1" + + +class TestConfigContainer_CallablesAndPartials: + """Test ConfigContainer handling of callables and partial functions.""" + + def test_dataclass_with_callables_to_dict(self): + """Test converting dataclass with callables to dict.""" + callable_data = CallableDataclass() + result = TestConfigContainer._convert_value_to_dict(callable_data) + + assert result["_target_"] == 
_target_qualname(CallableDataclass) + assert result["name"] == "callable_test" + assert result["regular_value"] == 100 + + # Callables are not dataclasses/lists/dicts, so they pass through as-is + assert result["activation_func"] is activation_function + assert isinstance(result["loss_func"], functools.partial) + assert result["loss_func"].func is loss_function + assert result["loss_func"].keywords == {"reduction": "sum"} + assert result["torch_func"] is torch.nn.functional.relu + assert callable(result["lambda_func"]) + assert result["lambda_func"](5) == 6 + + def test_config_container_with_callables_to_dict(self): + """Test ConfigContainer with callable fields converted to dict.""" + config = CallableConfigContainer() + result = config.to_dict() + + assert result["_target_"] == _target_qualname(CallableConfigContainer) + assert result["name"] == "callable_config" + + # Nested CallableDataclass hits the is_dataclass branch and becomes a dict + assert result["callable_data"]["_target_"] == _target_qualname(CallableDataclass) + assert result["callable_data"]["name"] == "callable_test" + assert result["callable_data"]["regular_value"] == 100 + + # Top-level callable fields pass through as-is + assert result["activation"] is activation_function + assert isinstance(result["partial_loss"], functools.partial) + assert result["partial_loss"].func is loss_function + assert result["partial_loss"].keywords == {"reduction": "none"} + assert result["torch_activation"] is torch.nn.functional.gelu + + def test_partial_function_handling(self): + """Test that partial objects pass through _convert_value_to_dict unchanged.""" + partial_func = functools.partial(loss_function, reduction="sum") + result = TestConfigContainer._convert_value_to_dict(partial_func) + + assert result is partial_func + + def test_various_callable_types(self): + """Test that all callable types pass through _convert_value_to_dict unchanged.""" + # Plain function + assert ( + TestConfigContainer._convert_value_to_dict(activation_function) is activation_function + ) + + # Partial — not a dataclass/list/dict so falls through as-is + partial_func = functools.partial(loss_function, reduction="mean") + assert TestConfigContainer._convert_value_to_dict(partial_func) is partial_func + + # Torch built-in function + assert ( + TestConfigContainer._convert_value_to_dict(torch.nn.functional.relu) + is torch.nn.functional.relu + ) + + # Lambda + fn = lambda x: x * 2 + assert TestConfigContainer._convert_value_to_dict(fn) is fn + + # Callable nn.Module instance — not a dataclass, falls through as-is + relu_instance = torch.nn.ReLU() + assert TestConfigContainer._convert_value_to_dict(relu_instance) is relu_instance + + def test_config_with_callables_roundtrip_behavior(self): + """Test that to_dict/from_dict roundtrip preserves all fields for callable configs.""" + config = CallableConfigContainer(name="roundtrip_test") + config_dict = config.to_dict() + + reconstructed = CallableConfigContainer.from_dict(config_dict) + + assert reconstructed.name == config.name + assert reconstructed.callable_data.name == config.callable_data.name + assert reconstructed.callable_data.regular_value == config.callable_data.regular_value + # Callables pass through as-is in to_dict, so they come back with the same identity + assert reconstructed.activation is config.activation + assert reconstructed.partial_loss.func is config.partial_loss.func + assert reconstructed.partial_loss.keywords == config.partial_loss.keywords + assert reconstructed.torch_activation is 
config.torch_activation
+
+    def test_mixed_container_with_callables_and_regular_data(self):
+        """Test container mixing callable and regular data."""
+
+        @dataclass
+        class MixedConfig(ConfigContainerBase):
+            name: str = "mixed"
+            regular_list: list[str] = None
+            callable_func: callable = activation_function
+            nested_data: SimpleDataclass = None
+
+            def __post_init__(self):
+                if self.regular_list is None:
+                    self.regular_list = ["a", "b", "c"]
+                if self.nested_data is None:
+                    self.nested_data = SimpleDataclass(name="nested", value=999)
+
+        config = MixedConfig()
+        result = config.to_dict()
+
+        # Verify mixed content handling
+        assert result["name"] == "mixed"
+        assert result["regular_list"] == ["a", "b", "c"]
+        assert result["nested_data"]["name"] == "nested"
+        assert result["nested_data"]["value"] == 999
+
+        # Callable fields pass through as-is
+        assert result["callable_func"] is activation_function
+
+    def test_deepcopy_with_callables(self):
+        """Test deep copying ConfigContainer with callable fields."""
+        config = CallableConfigContainer(name="deepcopy_test")
+
+        # Verify original works
+        assert config.name == "deepcopy_test"
+        assert callable(config.activation)
+        assert callable(config.partial_loss)
+
+        # Test deep copy
+        copied_config = copy.deepcopy(config)
+
+        # Verify copy independence
+        assert copied_config is not config
+        assert copied_config.name == "deepcopy_test"
+
+        # Verify callable fields are handled properly
+        assert callable(copied_config.activation)
+        assert callable(copied_config.partial_loss)
+
+        # Test that functions still work
+        assert copied_config.activation(5) == 10  # activation_function multiplies its input by 2
+
+        # Modify original to verify independence
+        config.name = "modified"
+        assert copied_config.name == "deepcopy_test"
diff --git a/tests/unit_tests/training/config/test_instantiate_utils.py b/tests/unit_tests/training/config/test_instantiate_utils.py
new file mode 100644
index 00000000000..a1316b8064b
--- /dev/null
+++ b/tests/unit_tests/training/config/test_instantiate_utils.py
@@ -0,0 +1,580 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
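+
+"""
+Unit tests for the config instantiation utilities (``instantiate`` and its helpers).
+
+A minimal sketch of the pattern exercised below, using a placeholder target
+name ``my_module.MyClass`` (any importable callable works)::
+
+    config = {"_target_": "my_module.MyClass", "arg1": 1}
+    obj = instantiate(config)                         # calls MyClass(arg1=1)
+    fn = instantiate({**config, "_partial_": True})   # returns a functools.partial
+
+Positional arguments (``_args_``) and returning the target uncalled
+(``_call_`` set to False) are covered by the individual test cases.
+"""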
+ +import enum +import functools +import logging +from unittest.mock import MagicMock, patch + +import pytest +from omegaconf import OmegaConf + +from megatron.training.config.instantiate_utils import ( + InstantiationException, + InstantiationMode, + _call_target, + _convert_node, + _convert_target_to_string, + _extract_pos_args, + _is_target, + _Keys, + _locate, + _prepare_input_dict_or_list, + _resolve_target, + instantiate, + instantiate_node, +) + + +@pytest.fixture(autouse=True) +def _disable_allowlist(): + """Temporarily disable allowlist to fully test instantiate logic with local test targets.""" + from megatron.training.config.instantiate_utils import target_allowlist + + target_allowlist.disable() + yield + target_allowlist.enable() + + +def _target_qualname(obj) -> str: + return f"{obj.__module__}.{obj.__qualname__}" + + +# Test classes and functions for instantiation testing +class TestClass: + """Test class for instantiation.""" + + def __init__(self, arg1=None, arg2=None, **kwargs): + self.arg1 = arg1 + self.arg2 = arg2 + self.kwargs = kwargs + + +def test_function(arg1=None, arg2=None, **kwargs): + """Test function for instantiation.""" + return {"arg1": arg1, "arg2": arg2, "kwargs": kwargs} + + +class TestInstantiationException: + """Test InstantiationException class.""" + + def test_instantiation_exception_creation(self): + """Test creating InstantiationException.""" + msg = "Test error message" + exc = InstantiationException(msg) + assert str(exc) == msg + assert isinstance(exc, Exception) + + +class TestInstantiationMode: + """Test InstantiationMode enum.""" + + def test_instantiation_mode_values(self): + """Test InstantiationMode enum values.""" + assert InstantiationMode.STRICT.value == "strict" + assert InstantiationMode.LENIENT.value == "lenient" + + +class TestKeys: + """Test _Keys enum.""" + + def test_keys_values(self): + """Test _Keys enum values.""" + assert _Keys.TARGET == "_target_" + assert _Keys.PARTIAL == "_partial_" + assert _Keys.CALL == "_call_" + assert _Keys.ARGS == "_args_" + assert _Keys.NAME == "_name_" + + +class TestInstantiate: + """Test instantiate function.""" + + def test_instantiate_none(self): + """Test instantiate with None config.""" + result = instantiate(None) + assert result is None + + def test_instantiate_simple_class(self): + """Test instantiating a simple class.""" + config = {"_target_": _target_qualname(TestClass), "arg1": "value1", "arg2": "value2"} + result = instantiate(config) + assert isinstance(result, TestClass) + assert result.arg1 == "value1" + assert result.arg2 == "value2" + + def test_instantiate_function(self): + """Test instantiating a function.""" + config = {"_target_": _target_qualname(test_function), "arg1": "value1", "arg2": "value2"} + result = instantiate(config) + expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}} + assert result == expected + + def test_instantiate_with_args(self): + """Test instantiate with positional args.""" + config = {"_target_": _target_qualname(test_function), "_args_": ["pos1", "pos2"]} + result = instantiate(config) + expected = {"arg1": "pos1", "arg2": "pos2", "kwargs": {}} + assert result == expected + + def test_instantiate_with_partial(self): + """Test instantiate with partial=True.""" + config = {"_target_": _target_qualname(test_function), "_partial_": True, "arg1": "value1"} + result = instantiate(config) + assert isinstance(result, functools.partial) + actual_result = result(arg2="value2") + expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}} + assert 
actual_result == expected + + def test_instantiate_with_call_false(self): + """Test instantiate with _call_=False.""" + config = {"_target_": _target_qualname(test_function), "_call_": False} + result = instantiate(config) + assert callable(result) + assert result == test_function + + def test_instantiate_with_call_false_and_extra_keys(self): + """Test instantiate with _call_=False and extra keys raises error.""" + config = { + "_target_": _target_qualname(test_function), + "_call_": False, + "extra_key": "value", + } + with pytest.raises(InstantiationException, match="_call_ was set to False"): + instantiate(config) + + def test_instantiate_with_kwargs_override(self): + """Test instantiate with kwargs override.""" + config = {"_target_": _target_qualname(test_function), "arg1": "original"} + result = instantiate(config, arg1="override", arg2="new") + expected = {"arg1": "override", "arg2": "new", "kwargs": {}} + assert result == expected + + def test_instantiate_list_config(self): + """Test instantiate with list config.""" + config = [ + {"_target_": _target_qualname(test_function), "arg1": "item1"}, + {"_target_": _target_qualname(test_function), "arg1": "item2"}, + ] + result = instantiate(config) + assert len(result) == 2 + assert result[0] == {"arg1": "item1", "arg2": None, "kwargs": {}} + assert result[1] == {"arg1": "item2", "arg2": None, "kwargs": {}} + + def test_instantiate_list_with_partial_raises_error(self): + """Test instantiate list with _partial_=True raises error.""" + config = ["item1", "item2"] + with pytest.raises(InstantiationException, match="_partial_ keyword is not compatible"): + instantiate(config, _partial_=True) + + def test_instantiate_invalid_config_type(self): + """Test instantiate with invalid config type.""" + with pytest.raises(InstantiationException, match="Cannot instantiate config of type"): + instantiate("invalid_config") + + def test_instantiate_strict_mode_error(self): + """Test instantiate in strict mode with error.""" + config = { + "_target_": _target_qualname(TestClass), + "nested": {"_target_": "non.existent.module.Class"}, + } + with pytest.raises(InstantiationException): + instantiate(config, mode=InstantiationMode.STRICT) + + def test_instantiate_lenient_mode_error(self): + """In lenient mode, nested resolution errors now propagate (no auto-None).""" + config = { + "_target_": _target_qualname(TestClass), + "nested": {"_target_": "non.existent.module.Class"}, + } + with pytest.raises(InstantiationException, match="Error locating target"): + instantiate(config, mode=InstantiationMode.LENIENT) + + def test_instantiate_with_omegaconf_dict(self): + """Test instantiate with OmegaConf DictConfig.""" + config = OmegaConf.create({"_target_": _target_qualname(TestClass), "arg1": "value1"}) + result = instantiate(config) + assert isinstance(result, TestClass) + assert result.arg1 == "value1" + + def test_instantiate_with_omegaconf_list(self): + """Test instantiate with OmegaConf ListConfig.""" + config = OmegaConf.create([{"_target_": _target_qualname(test_function), "arg1": "item1"}]) + result = instantiate(config) + assert len(result) == 1 + assert result[0] == {"arg1": "item1", "arg2": None, "kwargs": {}} + + +class TestInstantiateNode: + """Test instantiate_node function.""" + + def test_instantiate_node_none(self): + """Test instantiate_node with None.""" + result = instantiate_node(None) + assert result is None + + def test_instantiate_node_non_config(self): + """Test instantiate_node with non-config value.""" + result = 
instantiate_node("simple_string") + assert result == "simple_string" + + def test_instantiate_node_dict_without_target(self): + """Test instantiate_node with dict without _target_.""" + config = OmegaConf.create({"key1": "value1", "key2": "value2"}) + result = instantiate_node(config) + assert result == {"key1": "value1", "key2": "value2"} + + def test_instantiate_node_list(self): + """Test instantiate_node with list.""" + config = OmegaConf.create(["item1", "item2"]) + result = instantiate_node(config) + assert result == ["item1", "item2"] + + def test_instantiate_node_partial_not_bool_raises_error(self): + """Test instantiate_node with non-bool partial raises error.""" + config = OmegaConf.create({"_partial_": "not_bool"}) + with pytest.raises(TypeError, match="_partial_ flag must be a bool"): + instantiate_node(config) + + +class TestLocate: + """Test _locate function.""" + + def test_locate_valid_path(self): + """Test _locate with valid path.""" + result = _locate("builtins.str") + assert result == str + + def test_locate_empty_path(self): + """Test _locate with empty path.""" + with pytest.raises(ImportError, match="Empty path"): + _locate("") + + def test_locate_invalid_path(self): + """Test _locate with invalid path.""" + with pytest.raises(ImportError, match="Unable to import any module"): + _locate("non.existent.module") + + def test_locate_invalid_dotstring(self): + """Test _locate with invalid dotstring.""" + with pytest.raises(ValueError, match="invalid dotstring"): + _locate("invalid..path") + + def test_locate_relative_import(self): + """Test _locate with relative import.""" + with pytest.raises(ValueError, match="Relative imports are not supported"): + _locate(".relative.import") + + def test_locate_attribute_error(self): + """Test _locate with attribute that doesn't exist.""" + with pytest.raises(ImportError, match="Are you sure that"): + _locate("builtins.nonexistent_attribute") + + +class TestIsTarget: + """Test _is_target function.""" + + def test_is_target_dict_with_target(self): + """Test _is_target with dict containing _target_.""" + config = {"_target_": "some.target"} + assert _is_target(config) is True + + def test_is_target_dict_without_target(self): + """Test _is_target with dict not containing _target_.""" + config = {"other_key": "value"} + assert _is_target(config) is False + + def test_is_target_omegaconf_with_target(self): + """Test _is_target with OmegaConf containing _target_.""" + config = OmegaConf.create({"_target_": "some.target"}) + assert _is_target(config) is True + + def test_is_target_omegaconf_without_target(self): + """Test _is_target with OmegaConf not containing _target_.""" + config = OmegaConf.create({"other_key": "value"}) + assert _is_target(config) is False + + def test_is_target_non_dict(self): + """Test _is_target with non-dict value.""" + assert _is_target("string") is False + assert _is_target(123) is False + assert _is_target([]) is False + + +class TestCallTarget: + """Test _call_target function.""" + + def test_call_target_normal(self): + """Test _call_target with normal call.""" + result = _call_target(test_function, False, (), {"arg1": "value1"}, "test_key") + expected = {"arg1": "value1", "arg2": None, "kwargs": {}} + assert result == expected + + def test_call_target_partial(self): + """Test _call_target with partial=True.""" + result = _call_target(test_function, True, (), {"arg1": "value1"}, "test_key") + assert isinstance(result, functools.partial) + + def test_call_target_with_args(self): + """Test _call_target with 
positional args.""" + result = _call_target(test_function, False, ("pos1", "pos2"), {}, "test_key") + expected = {"arg1": "pos1", "arg2": "pos2", "kwargs": {}} + assert result == expected + + def test_call_target_error_normal(self): + """Test _call_target with error in normal call.""" + + def failing_function(): + raise ValueError("Test error") + + with pytest.raises(InstantiationException, match="Error in call to target"): + _call_target(failing_function, False, (), {}, "test_key") + + def test_call_target_error_partial(self): + """Test _call_target with error in partial creation.""" + # Create a mock that raises an error when used with functools.partial + mock_target = MagicMock() + mock_target.__module__ = "test_module" + mock_target.__qualname__ = "test_function" + + with patch("functools.partial", side_effect=ValueError("Partial error")): + with pytest.raises(InstantiationException, match="Error in creating partial"): + _call_target(mock_target, True, (), {}, "test_key") + + +class TestConvertTargetToString: + """Test _convert_target_to_string function.""" + + def test_convert_callable_to_string(self): + """Test converting callable to string.""" + result = _convert_target_to_string(test_function) + assert "test_function" in result + + def test_convert_non_callable_to_string(self): + """Test converting non-callable to string.""" + result = _convert_target_to_string("already_string") + assert result == "already_string" + + +class TestPrepareInputDictOrList: + """Test _prepare_input_dict_or_list function.""" + + def test_prepare_dict(self): + """Test preparing input dict.""" + input_dict = {"_target_": test_function, "key1": "value1", "nested": {"key2": "value2"}} + result = _prepare_input_dict_or_list(input_dict) + assert "_target_" in result + assert "test_function" in result["_target_"] + assert result["key1"] == "value1" + assert result["nested"]["key2"] == "value2" + + def test_prepare_list(self): + """Test preparing input list.""" + input_list = [{"_target_": test_function, "key1": "value1"}, ["nested_item"]] + result = _prepare_input_dict_or_list(input_list) + assert len(result) == 2 + assert "test_function" in result[0]["_target_"] + assert result[0]["key1"] == "value1" + assert result[1] == ["nested_item"] + + +class TestResolveTarget: + """Test _resolve_target function.""" + + def test_resolve_string_target(self): + """Test resolving string target.""" + result = _resolve_target("builtins.str", "test_key") + assert result == str + + def test_resolve_callable_target(self): + """Test resolving already callable target.""" + result = _resolve_target(test_function, "test_key") + assert result == test_function + + def test_resolve_invalid_string_target(self): + """Test resolving invalid string target.""" + with pytest.raises(InstantiationException, match="Error locating target"): + _resolve_target("invalid.target", "test_key") + + def test_resolve_non_callable_target(self): + """Test resolving non-callable target with check_callable=True.""" + with pytest.raises(InstantiationException, match="Expected a callable target"): + _resolve_target("builtins.__name__", "test_key", check_callable=True) + + def test_resolve_non_callable_target_no_check(self): + """Test resolving non-callable target with check_callable=False.""" + result = _resolve_target("builtins.__name__", "test_key", check_callable=False) + assert result == "builtins" + + +class TestExtractPosArgs: + """Test _extract_pos_args function.""" + + def test_extract_pos_args_no_input_args(self): + """Test extracting pos args with no 
input args.""" + kwargs = {"_args_": ["arg1", "arg2"], "key1": "value1"} + args, remaining_kwargs = _extract_pos_args((), kwargs) + assert args == ["arg1", "arg2"] + assert remaining_kwargs == {"key1": "value1"} + + def test_extract_pos_args_with_input_args(self): + """Test extracting pos args with input args override.""" + kwargs = {"_args_": ["config_arg1", "config_arg2"], "key1": "value1"} + input_args = ["input_arg1", "input_arg2"] + args, remaining_kwargs = _extract_pos_args(input_args, kwargs) + assert args == input_args + assert remaining_kwargs == {"key1": "value1"} + + def test_extract_pos_args_no_args_key(self): + """Test extracting pos args with no _args_ key.""" + kwargs = {"key1": "value1"} + args, remaining_kwargs = _extract_pos_args((), kwargs) + assert args == () + assert remaining_kwargs == {"key1": "value1"} + + def test_extract_pos_args_invalid_type(self): + """Test extracting pos args with invalid _args_ type.""" + kwargs = {"_args_": 123} # Integer is not a sequence + with pytest.raises(InstantiationException, match="Unsupported _args_ type"): + _extract_pos_args((), kwargs) + + +class TestConvertNode: + """Test _convert_node function.""" + + def test_convert_omegaconf_node(self): + """Test converting OmegaConf node.""" + config = OmegaConf.create({"key1": "value1", "key2": 2}) + result = _convert_node(config) + assert result == {"key1": "value1", "key2": 2} + assert not OmegaConf.is_config(result) + + def test_convert_non_config_node(self): + """Test converting non-config node.""" + value = {"key1": "value1"} + result = _convert_node(value) + assert result == value + + +class TestComplexScenarios: + """Test complex instantiation scenarios.""" + + def test_nested_instantiation(self): + """Test nested instantiation scenario.""" + config = { + "_target_": _target_qualname(TestClass), + "arg1": {"_target_": _target_qualname(test_function), "arg1": "nested_value"}, + "arg2": "simple_value", + } + result = instantiate(config) + assert isinstance(result, TestClass) + assert result.arg1 == {"arg1": "nested_value", "arg2": None, "kwargs": {}} + assert result.arg2 == "simple_value" + + def test_list_with_nested_targets(self): + """Test list with nested target instantiation.""" + config = [ + { + "_target_": _target_qualname(TestClass), + "arg1": {"_target_": _target_qualname(test_function), "arg1": "item1"}, + }, + "simple_item", + ] + result = instantiate(config) + assert len(result) == 2 + assert isinstance(result[0], TestClass) + assert result[0].arg1 == {"arg1": "item1", "arg2": None, "kwargs": {}} + assert result[1] == "simple_item" + + def test_missing_values_with_partial(self): + """Test missing values with partial instantiation.""" + config = OmegaConf.create( + { + "_target_": _target_qualname(test_function), + "_partial_": True, + "arg1": "value1", + "missing_arg": "???", # OmegaConf missing value + } + ) + OmegaConf.set_struct(config, True) + + result = instantiate(config) + assert isinstance(result, functools.partial) + # The missing value should be skipped in partial mode + actual_result = result(arg2="value2") + expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}} + assert actual_result == expected + + +class DummyTarget: + def __init__(self, a: int, b: int = 0) -> None: + self.a = a + self.b = b + + +class KwTarget: + def __init__(self, **kwargs) -> None: + self.kwargs = dict(kwargs) + + +def test_drops_unexpected_kwargs_and_warns(caplog: pytest.LogCaptureFixture) -> None: + config = { + "_target_": _target_qualname(DummyTarget), + "a": 10, + "foo": 123, 
# unexpected key that should be dropped + } + + with caplog.at_level(logging.WARNING): + obj = instantiate(config) + + assert isinstance(obj, DummyTarget) + assert obj.a == 10 + # 'foo' is dropped; 'b' remains default + assert obj.b == 0 + + # Ensure a warning was emitted mentioning the dropped key + warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING] + assert any("Dropping unexpected config keys" in m for m in warnings) + assert any("foo" in m for m in warnings) + + +def test_allows_kwargs_when_target_accepts_var_kwargs(caplog: pytest.LogCaptureFixture) -> None: + config = {"_target_": _target_qualname(KwTarget), "foo": 1, "bar": 2} + + with caplog.at_level(logging.WARNING): + obj = instantiate(config) + + assert isinstance(obj, KwTarget) + assert obj.kwargs == {"foo": 1, "bar": 2} + + # No warning should be emitted for **kwargs targets + warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING] + assert not any("Dropping unexpected config keys" in m for m in warnings) + + +def test_raises_on_unexpected_kwargs_in_strict_mode() -> None: + config = {"_target_": _target_qualname(DummyTarget), "a": 10, "foo": 123} + + with pytest.raises(InstantiationException): + instantiate(config, mode=InstantiationMode.STRICT) + + +class TestEnum(enum.Enum): + A = 1 + B = 2 + + +class TestInstantiateEnum: + """Test instantiation of Enums.""" + + def test_instantiate_enum_with_args(self): + """Test instantiating an Enum with _args_.""" + config = {"_target_": _target_qualname(TestEnum), "_args_": [1]} + result = instantiate(config) + assert result == TestEnum.A + + def test_instantiate_enum_with_args_lenient(self): + """Test instantiating an Enum with _args_ in lenient mode (default).""" + config = {"_target_": _target_qualname(TestEnum), "_args_": [2]} + # This previously failed because _args_ was dropped in lenient mode + result = instantiate(config) + assert result == TestEnum.B diff --git a/tests/unit_tests/training/config/test_target_allowlist.py b/tests/unit_tests/training/config/test_target_allowlist.py new file mode 100644 index 00000000000..3be5f8c87e3 --- /dev/null +++ b/tests/unit_tests/training/config/test_target_allowlist.py @@ -0,0 +1,221 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
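+
+"""
+Unit tests for ``TargetAllowlist``, the allowlist that gates which ``_target_``
+strings the instantiation utilities will resolve.
+
+A minimal sketch of the behaviour exercised below (``custom_lib`` is a
+placeholder package name)::
+
+    al = TargetAllowlist()
+    al.is_allowed("torch.float16")       # True: allowed prefix
+    al.is_allowed("os.system")           # False: not on the allowlist
+    al.add_prefix("custom_lib.")         # prefixes must end with '.'
+    al.is_allowed("custom_lib.MyClass")  # True after adding the prefix
+
+The module-level ``target_allowlist`` instance can be disabled and re-enabled
+to bypass or restore these checks, as the fixtures in the other config tests do.
+"""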
+ +import pytest + +from megatron.training.config.instantiate_utils import ( + InstantiationException, + TargetAllowlist, + _resolve_target, + target_allowlist, +) + + +class TestTargetAllowlistIsAllowed: + """Tests for the TargetAllowlist.is_allowed() method.""" + + def test_allows_megatron_training_targets(self): + al = TargetAllowlist() + assert al.is_allowed("megatron.training.config.training_config.TrainingConfig") + assert al.is_allowed("megatron.training.config.container.PretrainConfigContainer") + + def test_allows_megatron_core_targets(self): + al = TargetAllowlist() + assert al.is_allowed("megatron.core.optimizer.OptimizerConfig") + assert al.is_allowed( + "megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig" + ) + + def test_allows_torch_targets(self): + al = TargetAllowlist() + assert al.is_allowed("torch.float16") + assert al.is_allowed("torch.bfloat16") + assert al.is_allowed("torch.float32") + + def test_allows_transformers_targets(self): + al = TargetAllowlist() + assert al.is_allowed("transformers.GenerationConfig.from_dict") + assert al.is_allowed("transformers.LlamaConfig.from_dict") + + def test_allows_signal_targets(self): + al = TargetAllowlist() + assert al.is_allowed("signal.Signals") + + def test_allows_exact_functools_partial(self): + al = TargetAllowlist() + assert al.is_allowed("functools.partial") + + def test_blocks_os_system(self): + al = TargetAllowlist() + assert not al.is_allowed("os.system") + + def test_blocks_subprocess(self): + al = TargetAllowlist() + assert not al.is_allowed("subprocess.call") + assert not al.is_allowed("subprocess.Popen") + + def test_blocks_builtins(self): + al = TargetAllowlist() + assert not al.is_allowed("builtins.eval") + assert not al.is_allowed("builtins.exec") + assert not al.is_allowed("builtins.__import__") + + def test_blocks_shutil(self): + al = TargetAllowlist() + assert not al.is_allowed("shutil.rmtree") + + def test_blocks_importlib(self): + al = TargetAllowlist() + assert not al.is_allowed("importlib.import_module") + + def test_blocks_empty_string(self): + al = TargetAllowlist() + assert not al.is_allowed("") + + def test_blocks_partial_prefix_match(self): + """Ensure prefix matching doesn't match partial module names.""" + al = TargetAllowlist() + # "torchvision" starts with "torch" but not "torch." 
+ assert not al.is_allowed("torchvision.models.resnet50") + + +class TestTargetAllowlistAddRemove: + """Tests for add/remove prefix and exact.""" + + def test_add_prefix(self): + al = TargetAllowlist() + assert not al.is_allowed("custom_lib.MyClass") + al.add_prefix("custom_lib.") + assert al.is_allowed("custom_lib.MyClass") + + def test_add_prefix_requires_trailing_dot(self): + al = TargetAllowlist() + with pytest.raises(ValueError, match="Prefix must end with '.'"): + al.add_prefix("custom_lib") + + def test_add_prefix_is_idempotent(self): + al = TargetAllowlist() + al.add_prefix("custom_lib.") + al.add_prefix("custom_lib.") + assert al.allowed_prefixes.count("custom_lib.") == 1 + + def test_remove_prefix(self): + al = TargetAllowlist() + al.add_prefix("custom_lib.") + assert al.is_allowed("custom_lib.MyClass") + al.remove_prefix("custom_lib.") + assert not al.is_allowed("custom_lib.MyClass") + + def test_remove_prefix_not_found_raises(self): + al = TargetAllowlist() + with pytest.raises(ValueError): + al.remove_prefix("nonexistent.") + + def test_add_exact(self): + al = TargetAllowlist() + assert not al.is_allowed("os.getcwd") + al.add_exact("os.getcwd") + assert al.is_allowed("os.getcwd") + # Other os.* targets still blocked + assert not al.is_allowed("os.system") + + def test_remove_exact(self): + al = TargetAllowlist() + al.add_exact("os.getcwd") + assert al.is_allowed("os.getcwd") + al.remove_exact("os.getcwd") + assert not al.is_allowed("os.getcwd") + + def test_remove_exact_nonexistent_is_noop(self): + al = TargetAllowlist() + al.remove_exact("nonexistent.target") # Should not raise + + +class TestTargetAllowlistEnableDisable: + """Tests for enable/disable and env var override.""" + + def test_disable_allows_everything(self): + al = TargetAllowlist() + al.disable() + assert al.is_allowed("os.system") + assert al.is_allowed("subprocess.call") + assert not al.enabled + + def test_enable_after_disable(self): + al = TargetAllowlist() + al.disable() + assert al.is_allowed("os.system") + al.enable() + assert not al.is_allowed("os.system") + assert al.enabled + + def test_properties(self): + al = TargetAllowlist() + assert isinstance(al.allowed_prefixes, tuple) + assert isinstance(al.allowed_exact, frozenset) + assert "functools.partial" in al.allowed_exact + assert "megatron.training." 
in al.allowed_prefixes + + +class TestResolveTargetAllowlistEnforcement: + """Tests that _resolve_target() enforces the allowlist.""" + + def test_blocked_string_target_raises(self): + with pytest.raises(InstantiationException, match="not in the allowlist"): + _resolve_target("os.system", "", check_callable=True) + + def test_blocked_target_error_message_contains_target_name(self): + with pytest.raises(InstantiationException, match="os.system"): + _resolve_target("os.system", "", check_callable=True) + + def test_blocked_target_error_message_contains_prefixes(self): + with pytest.raises(InstantiationException, match="megatron.training."): + _resolve_target("os.system", "", check_callable=True) + + def test_blocked_target_error_message_contains_remediation(self): + with pytest.raises(InstantiationException, match="add_prefix"): + _resolve_target("os.system", "", check_callable=True) + + def test_blocked_target_error_message_includes_full_key(self): + with pytest.raises(InstantiationException, match="full_key: my.config.key"): + _resolve_target("os.system", "my.config.key", check_callable=True) + + def test_nonstring_target_bypasses_allowlist(self): + """Already-resolved callables should not be blocked.""" + result = _resolve_target(int, "", check_callable=True) + assert result is int + + def test_allowed_target_resolves(self): + """Allowed targets should be resolved normally.""" + result = _resolve_target("functools.partial", "", check_callable=True) + import functools + + assert result is functools.partial + + +class TestResolveTargetClassAllowlistEnforcement: + """Tests that _resolve_target_class() in utils.py respects the allowlist.""" + + def test_blocked_target_returns_none(self): + from megatron.training.config.utils import _resolve_target_class + + result = _resolve_target_class("os.system") + assert result is None + + def test_allowed_target_resolves(self): + from megatron.training.config.utils import _resolve_target_class + + # This should resolve to the actual class + result = _resolve_target_class("megatron.training.config.instantiate_utils.TargetAllowlist") + assert result is TargetAllowlist + + +class TestModuleLevelSingleton: + """Tests for the module-level target_allowlist singleton.""" + + def test_singleton_is_enabled_by_default(self): + assert target_allowlist.enabled + + def test_singleton_has_default_prefixes(self): + assert "megatron.training." in target_allowlist.allowed_prefixes + assert "megatron.core." in target_allowlist.allowed_prefixes + assert "torch." in target_allowlist.allowed_prefixes diff --git a/tests/unit_tests/training/config/test_utils.py b/tests/unit_tests/training/config/test_utils.py new file mode 100644 index 00000000000..53c476268d2 --- /dev/null +++ b/tests/unit_tests/training/config/test_utils.py @@ -0,0 +1,266 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
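+
+"""Unit tests for config utils helpers: ``_get_init_false_fields``,
+``_resolve_target_class``, and ``sanitize_dataclass_config`` (stripping
+``init=False`` dataclass fields from configs before instantiation)."""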
+ +from dataclasses import dataclass, field + +import pytest + +from megatron.training.config.container import ConfigContainerBase + + +@pytest.fixture(autouse=True) +def _disable_allowlist(): + """Temporarily disable allowlist to fully test utils logic with local test targets.""" + from megatron.training.config.instantiate_utils import target_allowlist + + target_allowlist.disable() + yield + target_allowlist.enable() + + +# Test dataclasses for testing +@dataclass +class SimpleDataclass: + """Simple dataclass for testing.""" + + name: str = "test" + value: int = 42 + + +@dataclass +class DataclassWithInitFalse: + """Dataclass with init=False field for testing backward compatibility.""" + + name: str = "test" + value: int = 42 + computed_field: str = field(init=False, default="computed") + + def __post_init__(self): + self.computed_field = f"computed_{self.name}" + + +@dataclass +class NestedDataclassWithInitFalse: + """Nested dataclass with init=False field.""" + + inner: DataclassWithInitFalse = None + metadata: dict = field(default_factory=dict) + cached_result: list = field(init=False, default_factory=list) + + +class TestBackwardCompatibility: + """Test suite for backward compatibility functions.""" + + def test_get_init_false_fields_with_init_false(self): + """Test _get_init_false_fields correctly identifies init=False fields.""" + from megatron.training.config.utils import _get_init_false_fields + + result = _get_init_false_fields(DataclassWithInitFalse) + assert "computed_field" in result + assert "name" not in result + assert "value" not in result + + def test_get_init_false_fields_no_init_false(self): + """Test _get_init_false_fields returns empty set for normal dataclass.""" + from megatron.training.config.utils import _get_init_false_fields + + result = _get_init_false_fields(SimpleDataclass) + assert result == frozenset() + + def test_get_init_false_fields_non_dataclass(self): + """Test _get_init_false_fields returns empty set for non-dataclass.""" + from megatron.training.config.utils import _get_init_false_fields + + result = _get_init_false_fields(str) + assert result == frozenset() + + def test_resolve_target_class_valid(self): + """Test _resolve_target_class resolves valid class path.""" + from megatron.training.config.utils import _resolve_target_class + + result = _resolve_target_class("megatron.training.config.container.ConfigContainerBase") + assert result is ConfigContainerBase + + def test_resolve_target_class_invalid(self): + """Test _resolve_target_class returns None for invalid path.""" + from megatron.training.config.utils import _resolve_target_class + + result = _resolve_target_class("nonexistent.module.ClassName") + assert result is None + + def test_resolve_target_class_malformed(self): + """Test _resolve_target_class handles malformed paths gracefully.""" + from megatron.training.config.utils import _resolve_target_class + + result = _resolve_target_class("no_dots") + assert result is None + + def test_sanitize_dataclass_config_removes_init_false_fields(self): + """Test sanitize_dataclass_config removes init=False fields.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "test_name", + "value": 123, + "computed_field": "should_be_removed", + } + + result = sanitize_dataclass_config(config) + + assert "name" in result + assert "value" in result + assert "_target_" in result + assert "computed_field" not in result + + def 
test_sanitize_dataclass_config_preserves_normal_fields(self): + """Test sanitize_dataclass_config preserves fields without init=False.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = { + "_target_": "tests.unit_tests.training.config.test_utils.SimpleDataclass", + "name": "preserved", + "value": 999, + } + + result = sanitize_dataclass_config(config) + + assert result["name"] == "preserved" + assert result["value"] == 999 + assert result["_target_"] == config["_target_"] + + def test_sanitize_dataclass_config_handles_nested_configs(self): + """Test sanitize_dataclass_config recursively processes nested configs.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = { + "_target_": "tests.unit_tests.training.config.test_utils.NestedDataclassWithInitFalse", + "inner": { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "inner_test", + "value": 42, + "computed_field": "nested_computed_should_be_removed", + }, + "metadata": {"key": "value"}, + "cached_result": ["should", "be", "removed"], + } + + result = sanitize_dataclass_config(config) + + # Top-level init=False field removed + assert "cached_result" not in result + # Nested init=False field removed + assert "computed_field" not in result["inner"] + # Normal fields preserved + assert result["inner"]["name"] == "inner_test" + assert result["metadata"] == {"key": "value"} + + def test_sanitize_dataclass_config_handles_lists_of_configs(self): + """Test sanitize_dataclass_config processes lists containing configs.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = { + "_target_": "some.module.ListContainer", + "items": [ + { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "item1", + "computed_field": "remove_me", + }, + { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "item2", + "computed_field": "remove_me_too", + }, + ], + } + + result = sanitize_dataclass_config(config) + + assert "computed_field" not in result["items"][0] + assert "computed_field" not in result["items"][1] + assert result["items"][0]["name"] == "item1" + assert result["items"][1]["name"] == "item2" + + def test_sanitize_dataclass_config_no_target(self): + """Test sanitize_dataclass_config handles dicts without _target_.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = {"key": "value", "number": 42} + result = sanitize_dataclass_config(config) + + assert result == config + + def test_sanitize_dataclass_config_non_dict_input(self): + """Test sanitize_dataclass_config handles non-dict input.""" + from megatron.training.config.utils import sanitize_dataclass_config + + assert sanitize_dataclass_config("string") == "string" + assert sanitize_dataclass_config(42) == 42 + assert sanitize_dataclass_config(None) is None + + def test_sanitize_dataclass_config_unresolvable_target(self): + """Test sanitize_dataclass_config handles unresolvable _target_.""" + from megatron.training.config.utils import sanitize_dataclass_config + + config = {"_target_": "nonexistent.module.Class", "field1": "value1", "field2": "value2"} + + result = sanitize_dataclass_config(config) + + # All fields preserved when target can't be resolved + assert result == config + + def test_sanitize_dataclass_config_sanitizes_model(self): + """Test sanitize_dataclass_config sanitizes model section with init=False fields.""" + from 
megatron.training.config.utils import sanitize_dataclass_config + + run_config = { + "model": { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "model_name", + "value": 100, + "computed_field": "should_be_removed", + }, + "training": {"lr": 0.001}, + "tokenizer": {"type": "sentencepiece"}, + } + + result = sanitize_dataclass_config(run_config) + + assert "computed_field" not in result["model"] + assert result["model"]["name"] == "model_name" + assert result["training"] == {"lr": 0.001} + assert result["tokenizer"] == {"type": "sentencepiece"} + + def test_sanitize_dataclass_config_sanitizes_all_sections(self): + """Test sanitize_dataclass_config sanitizes all sections, not just model.""" + from megatron.training.config.utils import sanitize_dataclass_config + + run_config = { + "model": { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "model_name", + "computed_field": "should_be_removed_from_model", + }, + "training": { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "training_config", + "computed_field": "should_be_removed_from_training", + }, + "data": { + "_target_": "tests.unit_tests.training.config.test_utils.DataclassWithInitFalse", + "name": "data_config", + "computed_field": "should_be_removed_from_data", + }, + } + + result = sanitize_dataclass_config(run_config) + + # All sections should have init=False fields removed + assert "computed_field" not in result["model"] + assert "computed_field" not in result["training"] + assert "computed_field" not in result["data"] + + # Regular fields should be preserved + assert result["model"]["name"] == "model_name" + assert result["training"]["name"] == "training_config" + assert result["data"]["name"] == "data_config" diff --git a/tests/unit_tests/training/config/test_yaml_utils.py b/tests/unit_tests/training/config/test_yaml_utils.py new file mode 100644 index 00000000000..12526eff2d6 --- /dev/null +++ b/tests/unit_tests/training/config/test_yaml_utils.py @@ -0,0 +1,336 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
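+
+"""Unit tests for the safe-YAML representers that serialize functions, partials,
+enums, torch dtypes, and ``GenerationConfig`` objects into ``_target_``-style
+mappings via ``safe_yaml_representers``."""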
+ +import enum +import functools +import os +import tempfile +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest +import yaml + +from megatron.training.config.yaml_utils import ( + _enum_representer, + _function_representer, + _generation_config_representer, + _partial_representer, + _safe_object_representer, + _torch_dtype_representer, + safe_yaml_representers, +) + + +class TestEnum(enum.Enum): + """Test enum""" + + VALUE1 = "test_value1" + VALUE2 = "test_value2" + + +@dataclass +class TestDataclass: + """Test dataclass""" + + name: str + value: int + + +class TestClass: + def __init__(self, name: str = "test"): + """Test class""" + self.name = name + + +def test_function(): + """Test function""" + return "test" + + +class TestSafeYamlRepresenters: + """Test the safe_yaml_representers context manager.""" + + def test_context_manager_adds_and_removes_representers(self): + """Test that representers are properly added and removed.""" + # Save original state + original_representers = yaml.SafeDumper.yaml_representers.copy() + original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() + + # Use context manager + with safe_yaml_representers(): + # Check that new representers were added + assert functools.partial in yaml.SafeDumper.yaml_representers + assert enum.Enum in yaml.SafeDumper.yaml_multi_representers + assert type(lambda: ...) in yaml.SafeDumper.yaml_representers + + # Check that original representers were restored + assert yaml.SafeDumper.yaml_representers == original_representers + assert yaml.SafeDumper.yaml_multi_representers == original_multi_representers + + def test_context_manager_handles_exceptions(self): + """Test that representers are restored even if an exception occurs.""" + original_representers = yaml.SafeDumper.yaml_representers.copy() + original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() + + try: + with safe_yaml_representers(): + raise ValueError("Test exception") + except ValueError: + pass + + # Check that original representers were still restored + assert yaml.SafeDumper.yaml_representers == original_representers + assert yaml.SafeDumper.yaml_multi_representers == original_multi_representers + + +class TestFunctionRepresenter: + """Test the _function_representer function.""" + + def test_function_representation(self): + """Test representing a function in YAML.""" + dumper = yaml.SafeDumper("") + result = _function_representer(dumper, test_function) + + # The result should be a MappingNode + assert hasattr(result, "value") + + # Parse the represented data using the context manager + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": test_function})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is False + assert "test_function" in data["test"]["_target_"] + + +class TestPartialRepresenter: + """Test the _partial_representer function.""" + + def test_partial_without_keywords(self): + """Test representing a partial function without keyword arguments.""" + partial_func = functools.partial(test_function) + dumper = yaml.SafeDumper("") + _ = _partial_representer(dumper, partial_func) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": partial_func})) + assert "_target_" in data["test"] + assert "_partial_" in data["test"] + assert data["test"]["_partial_"] is True + assert "_args_" in data["test"] + assert data["test"]["_args_"] == [] + + def 
test_partial_with_args_and_kwargs(self): + """Test representing a partial function with arguments and keyword arguments.""" + + def example_func(a, b, c=None): + return a + b + (c or 0) + + partial_func = functools.partial(example_func, 1, c=10) + dumper = yaml.SafeDumper("") + _ = _partial_representer(dumper, partial_func) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": partial_func})) + assert data["test"]["_args_"] == [1] + assert data["test"]["c"] == 10 + + +class TestEnumRepresenter: + """Test the _enum_representer function.""" + + def test_enum_representation(self): + """Test representing an enum value in YAML.""" + enum_value = TestEnum.VALUE1 + dumper = yaml.SafeDumper("") + _ = _enum_representer(dumper, enum_value) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": enum_value})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is True + assert "_args_" in data["test"] + assert data["test"]["_args_"] == ["test_value1"] + assert "_name_" in data["test"] + assert data["test"]["_name_"] == "VALUE1" + assert "TestEnum" in data["test"]["_target_"] + + +class TestSafeObjectRepresenter: + """Test the _safe_object_representer function.""" + + def test_object_with_qualname(self): + """Test representing an object that has __qualname__ attribute.""" + obj = test_function + dumper = yaml.SafeDumper("") + _ = _safe_object_representer(dumper, obj) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": obj})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is False + + def test_object_without_qualname(self): + """Test representing an object that doesn't have __qualname__ attribute.""" + obj = TestClass("test") + dumper = yaml.SafeDumper("") + _ = _safe_object_representer(dumper, obj) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": obj})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is True + assert "TestClass" in data["test"]["_target_"] + + +class TestTorchDtypeRepresenter: + """Test the _torch_dtype_representer function.""" + + def test_torch_dtype_representation(self): + """Test representing a torch dtype in YAML.""" + import torch + + dtype = torch.float32 + dumper = yaml.SafeDumper("") + _ = _torch_dtype_representer(dumper, dtype) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": dtype})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is False + assert "float32" in data["test"]["_target_"] + + def test_torch_dtype_representer_function(self): + """Test the torch dtype representer function directly.""" + # Create a mock torch dtype + mock_dtype = Mock() + mock_dtype.__str__ = Mock(return_value="torch.float32") + + dumper = yaml.SafeDumper("") + result = _torch_dtype_representer(dumper, mock_dtype) + + # Test the direct result from the representer function + # The result should be a MappingNode + assert hasattr(result, "value") + + # Parse the result from the direct function call + # We need to manually construct the data that would be generated + test_data = {"_target_": "torch.float32", "_call_": False} + yaml_result = yaml.safe_dump(test_data) + data = 
yaml.safe_load(yaml_result) + + assert "_target_" in data + assert "_call_" in data + assert data["_call_"] is False + + +class TestGenerationConfigRepresenter: + """Test the _generation_config_representer function.""" + + def test_generation_config_representation(self): + """Test representing a GenerationConfig object in YAML.""" + try: + from transformers import GenerationConfig + + config = GenerationConfig(max_length=100, temperature=0.8) + dumper = yaml.SafeDumper("") + _ = _generation_config_representer(dumper, config) + + # Parse the represented data + with safe_yaml_representers(): + data = yaml.safe_load(yaml.safe_dump({"test": config})) + assert "_target_" in data["test"] + assert "_call_" in data["test"] + assert data["test"]["_call_"] is True + assert "config_dict" in data["test"] + assert "from_dict" in data["test"]["_target_"] + except ImportError: + pytest.skip("Transformers not available") + + def test_generation_config_representer_function(self): + """Test the generation config representer function directly.""" + # Create a mock GenerationConfig + mock_config = Mock() + mock_config.__class__.__qualname__ = "GenerationConfig" + mock_config.__class__.__module__ = "transformers.generation.configuration_utils" + mock_config.to_dict = Mock(return_value={"max_length": 100, "temperature": 0.8}) + + dumper = yaml.SafeDumper("") + result = _generation_config_representer(dumper, mock_config) + + # Test the direct result from the representer function + # The result should be a MappingNode + assert hasattr(result, "value") + + # Parse the result from the direct function call + # We need to manually construct the data that would be generated + test_data = { + "_target_": "transformers.generation.configuration_utils.GenerationConfig.from_dict", + "_call_": True, + "config_dict": {"max_length": 100, "temperature": 0.8}, + } + yaml_result = yaml.safe_dump(test_data) + data = yaml.safe_load(yaml_result) + + assert "_target_" in data + assert "_call_" in data + assert data["_call_"] is True + assert "config_dict" in data + assert "from_dict" in data["_target_"] + + +class TestIntegration: + """Integration tests for the YAML utils functionality.""" + + def test_complex_object_serialization(self): + """Test serializing a complex object with multiple types.""" + complex_obj = { + "function": test_function, + "enum": TestEnum.VALUE1, + "partial": functools.partial(test_function), + "dataclass": TestDataclass("test", 42), + "regular_data": {"key": "value", "number": 123}, + } + + with safe_yaml_representers(): + result = yaml.safe_dump(complex_obj) + + assert isinstance(result, str) + + # Verify all components are serialized + assert "_target_:" in result + assert "test_function" in result + assert "TestEnum" in result + assert "_partial_:" in result + assert "TestDataclass" in result + assert "_call_:" in result + assert "key: value" in result + + def test_roundtrip_with_simple_objects(self): + """Test that simple objects can be serialized and deserialized.""" + simple_obj = { + "string": "test", + "number": 42, + "list": [1, 2, 3], + "dict": {"nested": "value"}, + } + + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(simple_obj) + + reconstructed = yaml.safe_load(yaml_str) + + assert reconstructed["string"] == "test" + assert reconstructed["number"] == 42 + assert reconstructed["list"] == [1, 2, 3] + assert reconstructed["dict"]["nested"] == "value" diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py index 
01150d65570..493bbd6bc49 100644 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++ b/tests/unit_tests/transformer/test_cuda_graphs.py @@ -35,6 +35,7 @@ _CudagraphGlobalRecord, ) from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.fused_a2a import reset_hybrid_ep_buffer from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig @@ -1145,6 +1146,171 @@ def test_mhc_moe_partial_cudagraph(self, ep_size): Utils.destroy_model_parallel() +class _SimpleModule(MegatronModule): + """Minimal MegatronModule for testing CudaGraphManager with function_name.""" + + def __init__(self, config): + super().__init__(config) + self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size) + + def my_op(self, x): + return self.linear(x) + + +class _SimpleNonModule: + """non-nn.Module base_module for testing the function_name= form of `CudaGraphManager`.""" + + def __init__(self, config): + self.weight = torch.randn(config.hidden_size, config.hidden_size, device="cuda") + + def my_op(self, x): + return x @ self.weight + + +def _make_simple_module(config): + return _SimpleModule(config).cuda().eval() + + +def _make_simple_non_module(config): + return _SimpleNonModule(config) + + +class TestInlineCaptureManager: + """Tests for CudaGraphManager with inline_capture, function_name, eager, and cache_key.""" + + def _make_config(self): + return TransformerConfig( + num_layers=1, + hidden_size=32, + num_attention_heads=1, + use_cpu_initialization=True, + cuda_graph_impl="local", + inference_rng_tracker=True, + ) + + def setup_method(self, method): + Utils.initialize_model_parallel() + model_parallel_cuda_manual_seed( + seed=123, inference_rng_tracker=True, use_cudagraphable_rng=False, force_reset_rng=True + ) + + def teardown_method(self, method): + _CudagraphGlobalRecord.cudagraph_created = False + _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.cudagraph_inference_record = [] + CudaGraphManager.global_mempool = None + Utils.destroy_model_parallel() + + @pytest.mark.parametrize( + "make_module", + [ + pytest.param(_make_simple_module, id="nn_module"), + pytest.param(_make_simple_non_module, id="plain_class"), + ], + ) + @torch.inference_mode() + def test_inline_capture_matches_eager(self, make_module): + """Inline-captured graph output must match eager execution.""" + config = self._make_config() + module = make_module(config) + + # Get eager reference before wrapping + x = torch.randn(4, config.hidden_size, device="cuda") + eager_out = module.my_op(x).clone() + + mgr = CudaGraphManager( + config, + base_module=module, + function_name="my_op", + inline_capture=True, + num_warmup_steps=0, + need_backward=False, + ) + + # First call captures, second replays + graph_out_1 = module.my_op(x) + graph_out_2 = module.my_op(x) + assert torch.equal(eager_out, graph_out_1) + assert torch.equal(eager_out, graph_out_2) + assert len(mgr.cudagraph_runners) == 1 + assert mgr.cudagraph_runners[0].fwd_graph_recorded + + @torch.inference_mode() + def test_eager_bypass(self): + """eager=True must bypass graph capture entirely.""" + config = self._make_config() + module = _SimpleModule(config).cuda().eval() + + mgr = CudaGraphManager( + config, + base_module=module, + function_name="my_op", + inline_capture=True, + num_warmup_steps=0, + need_backward=False, + ) + + x = torch.randn(4, config.hidden_size, 
device="cuda") + _ = module.my_op(x, eager=True) + _ = module.my_op(x, eager=True) + assert len(mgr.cudagraph_runners) == 0, "eager=True should not create runners" + + @torch.inference_mode() + def test_cache_key_routing(self): + """Different cache_keys must create separate runners.""" + config = self._make_config() + module = _SimpleModule(config).cuda().eval() + + mgr = CudaGraphManager( + config, + base_module=module, + function_name="my_op", + inline_capture=True, + num_warmup_steps=0, + need_backward=False, + ) + + x = torch.randn(4, config.hidden_size, device="cuda") + module.my_op(x, cache_key="key_a") + module.my_op(x, cache_key="key_b") + + assert len(mgr.cudagraph_runners) == 2 + assert mgr.custom_cudagraphs_lookup_table["key_a"] is not None + assert mgr.custom_cudagraphs_lookup_table["key_b"] is not None + assert ( + mgr.custom_cudagraphs_lookup_table["key_a"] + is not mgr.custom_cudagraphs_lookup_table["key_b"] + ) + + # Same key reuses the runner + module.my_op(x, cache_key="key_a") + assert len(mgr.cudagraph_runners) == 2 + + @torch.inference_mode() + def test_num_warmup_steps_override(self): + """num_warmup_steps on the manager must override the config value on runners.""" + config = self._make_config() + config.cuda_graph_warmup_steps = 3 + module = _SimpleModule(config).cuda().eval() + + mgr = CudaGraphManager( + config, + base_module=module, + function_name="my_op", + inline_capture=True, + num_warmup_steps=0, + need_backward=False, + ) + + x = torch.randn(4, config.hidden_size, device="cuda") + module.my_op(x, cache_key="test") + + runner = mgr.cudagraph_runners[0] + assert ( + runner.num_warmup_steps == 0 + ), f"Expected 0 warmup steps (manager override), got {runner.num_warmup_steps}" + + if __name__ == "__main__": test = TestParallelTransformerBlockCudagraphs() diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index c80b8f14480..0ca7be43ea7 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -88,73 +88,96 @@ def test_gpu_forward(self): def test_chunked_mlp(self): with torch.no_grad(): - - def test( - num_layers, - hidden_size, - num_attention_heads, - mlp_chunks_for_prefill, - hidden_states, - inference_context, - ): - - transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - mlp_chunks_for_prefill=4, - add_bias_linear=True, - use_cpu_initialization=True, - ) - parallel_transformer_layer = TransformerLayer( - transformer_config, get_gpt_layer_with_transformer_engine_submodules() - ) - - parallel_transformer_layer.cuda() - - hidden_states, context = parallel_transformer_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_context=inference_context, - ) - - return hidden_states, context - num_layers = 2 hidden_size = 12 num_attention_heads = 4 - sequence_length = 32 micro_batch_size = 2 + transformer_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + mlp_chunks_for_prefill=1, + mlp_chunks_for_training=1, + add_bias_linear=True, + use_cpu_initialization=True, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + parallel_transformer_layer = TransformerLayer( + transformer_config, get_gpt_layer_with_transformer_engine_submodules() + ) + parallel_transformer_layer.cuda() + # [sequence length, batch size, hidden size] - input_hidden_states = torch.ones((sequence_length, 
micro_batch_size, hidden_size)) + torch.manual_seed(42) + input_hidden_states = torch.randn((sequence_length, micro_batch_size, hidden_size)) input_hidden_states = input_hidden_states.cuda() attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() + # Test chunked prefill: chunks=1 vs chunks=4 should be identical + parallel_transformer_layer.eval() inference_context = StaticInferenceContext( max_batch_size=micro_batch_size, max_sequence_length=sequence_length ) - outputs = {} - for mlp_chunks_for_prefill in [1, 4]: - hidden_states, context = test( - num_layers, - hidden_size, - num_attention_heads, - mlp_chunks_for_prefill, - input_hidden_states, - inference_context, + transformer_config.mlp_chunks_for_prefill = mlp_chunks_for_prefill + hidden_states, context = parallel_transformer_layer( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, ) assert hidden_states.shape[0] == sequence_length assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == hidden_size - outputs[mlp_chunks_for_prefill] = (hidden_states, context) - assert torch.equal(outputs[1][0], outputs[4][0]) + assert torch.equal(outputs[1][0], outputs[4][0]) + + # Test chunked training: chunks=1 vs chunks=4 should be identical + parallel_transformer_layer.train() + outputs = {} + for mlp_chunks_for_training in [1, 4]: + transformer_config.mlp_chunks_for_training = mlp_chunks_for_training + hidden_states, context = parallel_transformer_layer( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + inference_context=None, + ) + assert hidden_states.shape[0] == sequence_length + assert hidden_states.shape[1] == micro_batch_size + assert hidden_states.shape[2] == hidden_size + outputs[mlp_chunks_for_training] = (hidden_states, context) + + assert torch.equal(outputs[1][0], outputs[4][0]) + + # Test gradient equivalence: chunked vs non-chunked training + parallel_transformer_layer.train() + grads = {} + for mlp_chunks_for_training in [1, 4]: + transformer_config.mlp_chunks_for_training = mlp_chunks_for_training + parallel_transformer_layer.zero_grad() + hidden_states, _ = parallel_transformer_layer( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + inference_context=None, + ) + loss = hidden_states.sum() + loss.backward() + grads[mlp_chunks_for_training] = { + name: param.grad.clone() + for name, param in parallel_transformer_layer.named_parameters() + if param.grad is not None + } + + for name in grads[1]: + assert torch.allclose(grads[1][name], grads[4][name], atol=1e-6), ( + f"Gradient mismatch for {name}: " + f"max diff={torch.max(torch.abs(grads[1][name] - grads[4][name])).item()}" + ) def test_get_layer_offset(self): config = self.parallel_transformer_layer.config diff --git a/tools/checkpoint/checkpoint_inspector.py b/tools/checkpoint/checkpoint_inspector.py index 9033ff96a17..7c946cb5ac3 100644 --- a/tools/checkpoint/checkpoint_inspector.py +++ b/tools/checkpoint/checkpoint_inspector.py @@ -781,8 +781,10 @@ def has_layer_index(key: str) -> bool: @click.option( "--param-to-param-group-map-json", type=str, - default="{}", - help="JSON string representing the param to parameter group map.", + default=None, + help="Path to a JSON file mapping parameter names to optimizer param group ids. " + "Required only if the source checkpoint has multiple optimizer param groups " + "(e.g. different LR/weight-decay per group). 
Leave unset for single-group checkpoints.", ) @click.option( "--rename-mtp-keys", @@ -883,8 +885,11 @@ def oom_observer(device, alloc, device_alloc, device_free): ckpt_path = Path(input_dir) output_dir = Path(output_dir) - with open(param_to_param_group_map_json, "r") as f: - param_to_param_group_map = json.load(f) + if param_to_param_group_map_json: + with open(param_to_param_group_map_json, "r") as f: + param_to_param_group_map = json.load(f) + else: + param_to_param_group_map = {} _swiglu_modules = ( [m.strip() for m in swiglu_modules.split(",") if m.strip()] if swiglu_modules is not None diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index e9eb7e99b60..1f9fab93637 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -2,9 +2,10 @@ import argparse import importlib -import torch.multiprocessing as mp import sys +import torch.multiprocessing as mp + # A loader is a python file with at least two functions # - add_arguments - takes in a parser and adds any arguments needed # - load_checkpoint - takes in the queue and parsed arguments @@ -87,6 +88,7 @@ # } # - "done" + def load_plugin(plugin_type, name): module_name = f"{plugin_type}_{name}" try: @@ -106,37 +108,55 @@ def load_plugin(plugin_type, name): print(f"Loaded {module_name} as the {plugin_type}.") return plugin + def main(): import argparse - parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments", - allow_abbrev=False, conflict_handler='resolve') - - parser.add_argument('--model-type', type=str, required=True, - choices=['GPT', 'BERT'], - help='Type of the model') - parser.add_argument('--loader', type=str, default='megatron', - help='Module name to load checkpoint, should be on python path') - parser.add_argument('--saver', type=str, default='megatron', - help='Module name to save checkpoint, should be on python path') - parser.add_argument('--load-dir', type=str, required=True, - help='Directory to load model checkpoint from') - parser.add_argument('--save-dir', type=str, required=True, - help='Directory to save model checkpoint to') - parser.add_argument('--max-queue-size', type=int, default=50, - help='Maximum number of tensors in the queue') - parser.add_argument('--no-checking', action='store_false', - help='Do not perform checking on the name and ordering of weights', - dest='checking') + + parser = argparse.ArgumentParser( + description="Megatron Checkpoint Converter Arguments", + allow_abbrev=False, + conflict_handler='resolve', + ) + + parser.add_argument( + '--model-type', type=str, required=True, choices=['GPT', 'BERT'], help='Type of the model' + ) + parser.add_argument( + '--loader', + type=str, + default='megatron', + help='Module name to load checkpoint, should be on python path', + ) + parser.add_argument( + '--saver', + type=str, + default='megatron', + help='Module name to save checkpoint, should be on python path', + ) + parser.add_argument( + '--load-dir', type=str, required=True, help='Directory to load model checkpoint from' + ) + parser.add_argument( + '--save-dir', type=str, required=True, help='Directory to save model checkpoint to' + ) + parser.add_argument( + '--max-queue-size', type=int, default=50, help='Maximum number of tensors in the queue' + ) + parser.add_argument( + '--no-checking', + action='store_false', + help='Do not perform checking on the name and ordering of weights', + dest='checking', + ) known_args, _ = parser.parse_known_args() # Handle old arg values. 
def update_loader_saver(key): old_value = getattr(known_args, key) - if old_value == "megatron": - setattr(known_args, key, "legacy") if old_value == "mcore": setattr(known_args, key, "core") + update_loader_saver("loader") update_loader_saver("saver") diff --git a/tools/checkpoint/dist_checkpoint_io.py b/tools/checkpoint/dist_checkpoint_io.py new file mode 100644 index 00000000000..71038059ded --- /dev/null +++ b/tools/checkpoint/dist_checkpoint_io.py @@ -0,0 +1,250 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Distributed checkpoint I/O helpers for structural model-conversion tools. + +Provides format detection, model-free full-tensor loading, and single-rank +saving for Megatron-LM distributed checkpoints (``torch_dist`` and +``fsdp_dtensor`` backends). This lets conversion tools operate on +TP+PP+FSDP-trained checkpoints without needing to instantiate the model. + +The key observation is that PyTorch DCP stores each logical parameter with a +``global_shape`` in its metadata, and the TP / PP / FSDP slicing is just an +on-disk layout detail handled by the read planner. Loading into a plain +``torch.empty(global_shape)`` state dict on rank 0 therefore yields fully +gathered tensors regardless of the parallelism the checkpoint was trained with. +""" + +import os +from collections import OrderedDict +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint import ( + DefaultLoadPlanner, + DefaultSavePlanner, + FileSystemReader, + FileSystemWriter, +) +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata + +from megatron.core.dist_checkpointing.core import ( + CheckpointingConfig, + maybe_load_config, + save_config, +) +from megatron.core.dist_checkpointing.strategies.common import load_common, save_common + +FORMAT_TORCH_DIST = 'torch_dist' +FORMAT_FSDP_DTENSOR = 'fsdp_dtensor' +DIST_FORMATS = (FORMAT_TORCH_DIST, FORMAT_FSDP_DTENSOR) + +# Prefixes under which model weights may be keyed in a dist checkpoint. +_KNOWN_MODEL_PREFIXES = ('model.module.module.', 'model.module.', 'model.', '') +# Well-known bare-key suffixes we probe for when detecting the prefix. +_PROBE_SUFFIXES = ( + 'embedding.word_embeddings.weight', + 'decoder.layers.', + 'decoder.final_norm.', + 'decoder.final_layernorm.', + 'output_layer.weight', +) +# Keys that identify non-model state we drop during architecture conversion. +_NON_MODEL_TOP_LEVEL_PREFIXES = ('optimizer.', 'rng_state', 'rerun_state_machine_state') + + +def resolve_checkpoint_subdir(load_dir): + """Return ``(ckpt_dir, iteration)``. + + Megatron writes checkpoints either flat or under ``iter_XXXXXXX/``. This + picks the right directory and reports the iteration when it can be + determined. + """ + if os.path.exists(os.path.join(load_dir, 'metadata.json')): + return load_dir, None + + latest_iter = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') + if os.path.exists(latest_iter): + with open(latest_iter, 'r') as f: + iteration = int(f.read().strip()) + iter_dir = os.path.join(load_dir, f'iter_{iteration:07d}') + if os.path.isdir(iter_dir): + return iter_dir, iteration + + return load_dir, None + + +def detect_checkpoint_format(load_dir): + """Return one of ``{'torch_dist', 'fsdp_dtensor'}``. + + Raises ``ValueError`` if the directory looks like the legacy + ``mp_rank_XX`` layout (no longer supported) or doesn't match any known + dist-checkpoint metadata. 
+ """ + ckpt_dir, _ = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is not None: + return config.sharded_backend + + if os.path.isdir(ckpt_dir) and any( + name.startswith('mp_rank_') for name in os.listdir(ckpt_dir) + ): + raise ValueError( + f"{load_dir} looks like a legacy mp_rank_XX checkpoint. " + f"Legacy format is no longer supported — convert to torch_dist first." + ) + + raise ValueError(f"Unrecognized checkpoint format at {load_dir}") + + +def ensure_single_rank_process_group(): + """Initialize a 1-rank gloo process group if one isn't already up. + + DCP requires a default process group; this lets the conversion tool run + in a plain ``python`` invocation (no ``torchrun`` needed). + """ + if not dist.is_available(): + raise RuntimeError("torch.distributed is not available.") + if dist.is_initialized(): + return + os.environ.setdefault('MASTER_ADDR', '127.0.0.1') + os.environ.setdefault('MASTER_PORT', '29500') + os.environ.setdefault('RANK', '0') + os.environ.setdefault('WORLD_SIZE', '1') + os.environ.setdefault('LOCAL_RANK', '0') + dist.init_process_group(backend='gloo', rank=0, world_size=1) + + +def detect_model_prefix(keys): + """Return the prefix under which model weights live in ``keys``. + + Looks for a recognizable suffix (``embedding.word_embeddings.weight``, + ``decoder.layers.``, etc.) and returns the matching prefix. Falls back + to ``''`` if nothing obvious is found. + """ + keys = list(keys) + for prefix in _KNOWN_MODEL_PREFIXES: + for suffix in _PROBE_SUFFIXES: + probe = prefix + suffix + for key in keys: + if key.startswith(probe): + return prefix + return '' + + +def _is_non_model_key(bare_key): + if bare_key.startswith(_NON_MODEL_TOP_LEVEL_PREFIXES): + return True + # _extra_state blobs are TE per-module state; they are tied to a specific + # TP/parallelism configuration and aren't meaningful after a structural + # model conversion, so we drop them. + if '_extra_state' in bare_key: + return True + return False + + +def load_dist_checkpoint_full(load_dir): + """Load a dist checkpoint and return fully-gathered model weights. + + Returns: + model_state_dict (OrderedDict[str, torch.Tensor]): bare keys, full + tensors on CPU. Optimizer state, RNG state, and ``_extra_state`` + blobs are filtered out. + common_state (dict): contents of ``common.pt`` (e.g. ``args``). + model_prefix (str): the prefix we stripped (re-apply on save). + backend (str): ``'torch_dist'`` or ``'fsdp_dtensor'``. + iteration (int or None): iteration number if discoverable. + """ + ensure_single_rank_process_group() + + ckpt_dir, iteration = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is None: + raise ValueError(f"{load_dir} is not a distributed checkpoint (no metadata.json)") + backend = config.sharded_backend + + reader = FileSystemReader(ckpt_dir) + metadata = reader.read_metadata() + + model_prefix = detect_model_prefix(metadata.state_dict_metadata.keys()) + + raw_state_dict = {} + for key, md in metadata.state_dict_metadata.items(): + if not isinstance(md, TensorStorageMetadata): + continue + if model_prefix and not key.startswith(model_prefix): + continue + bare_key = key[len(model_prefix) :] if model_prefix else key + if _is_non_model_key(bare_key): + continue + raw_state_dict[key] = torch.empty(md.size, dtype=md.properties.dtype, device='cpu') + + if not raw_state_dict: + raise ValueError( + f"No model tensors found in {ckpt_dir} (detected prefix " + f"'{model_prefix}', backend '{backend}')." 
+ ) + + dcp.load(raw_state_dict, storage_reader=reader, planner=DefaultLoadPlanner()) + + model_state_dict = OrderedDict() + for key, tensor in raw_state_dict.items(): + bare_key = key[len(model_prefix) :] if model_prefix else key + model_state_dict[bare_key] = tensor + + common_state = {} + try: + common_state = load_common(ckpt_dir) + except Exception: + pass + + return model_state_dict, common_state, model_prefix, backend, iteration + + +def save_dist_checkpoint_full( + model_state_dict, common_state, save_dir, model_prefix='model.', backend=FORMAT_TORCH_DIST +): + """Save a fully-gathered state dict as a distributed checkpoint. + + The output is written as a single-rank, fully-replicated DCP checkpoint + plus ``common.pt`` and ``metadata.json``. A downstream Megatron training + job reads it back through ``dist_checkpointing.load()`` with its own + sharded_state_dict template — TP+PP+FSDP resharding happens transparently + on load, since the on-disk tensors carry their full logical shape. + """ + ensure_single_rank_process_group() + + os.makedirs(save_dir, exist_ok=True) + + raw_state_dict = OrderedDict() + for bare_key, tensor in model_state_dict.items(): + full_key = f"{model_prefix}{bare_key}" if model_prefix else bare_key + raw_state_dict[full_key] = ( + tensor.contiguous() if tensor.is_contiguous() else tensor.contiguous() + ) + + writer = FileSystemWriter(save_dir) + dcp.save(state_dict=raw_state_dict, storage_writer=writer, planner=DefaultSavePlanner()) + + if common_state: + save_common(common_state, save_dir) + + if dist.get_rank() == 0: + save_config(CheckpointingConfig(sharded_backend=backend), save_dir) + dist.barrier() + + +def write_latest_iteration_marker(save_dir, iteration): + """Mirror the legacy ``latest_checkpointed_iteration.txt`` convention. + + When ``save_dir`` points at a top-level checkpoint root with an + ``iter_XXXXXXX/`` subdirectory, the tracker file lets Megatron auto-find + the latest iteration on load. + """ + parent = os.path.dirname(save_dir.rstrip('/')) or save_dir + if os.path.basename(save_dir.rstrip('/')).startswith('iter_'): + tracker = os.path.join(parent, 'latest_checkpointed_iteration.txt') + with open(tracker, 'w') as f: + f.write(str(iteration)) diff --git a/tools/checkpoint/gpt_hybrid_conversion.py b/tools/checkpoint/gpt_hybrid_conversion.py new file mode 100644 index 00000000000..6910bc05f9e --- /dev/null +++ b/tools/checkpoint/gpt_hybrid_conversion.py @@ -0,0 +1,924 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +GPT <-> Hybrid Checkpoint Conversion Tool +========================================= + +Directly converts checkpoints between GPTModel (homogeneous Transformer) and +HybridModel (hybrid Mamba+Transformer) without going through HuggingFace as an +intermediary. + +Supported directions: + gpt-to-hybrid : Convert a GPT checkpoint to Hybrid format. + hybrid-to-gpt : Convert a Hybrid checkpoint to GPT format. + +How the hybrid layer pattern maps GPT layers (gpt-to-hybrid): + - Each GPT layer contains both attention and MLP sub-layers. + - The target hybrid model's hybrid_layer_pattern specifies per-layer types: + M = Mamba SSM layer + * = Attention-only layer + - = MLP-only layer (dense) + E = MoE MLP-only layer (router + experts; supports EP) + G = GDN layer (not currently mapped) + - GPT layer i's attention params map to the i-th '*' layer in the pattern. + - GPT layer i's MLP/MoE params map to the i-th MLP-bearing position + ('-' or 'E') in the pattern. 
Dense ('-') and MoE ('E') cannot be mixed: + GPT layers are uniform. + - The number of '*' positions and MLP-bearing positions must each equal + the number of GPT layers. + - Mamba SSM ('M') layers have no GPT equivalent and are initialized from + scratch using standard Mamba initialization. + +How MoE / Expert Parallelism (EP) works through the converter: + - GPTModel can run with MoE (Mixtral-style: every layer has a router and + N local experts). State-dict keys live under + `decoder.layers..mlp.{router,experts,shared_experts}.*`. + - Hybrid 'E' layers use the same key naming, so MoE tensors round-trip + verbatim — no expert collapsing, no router init, no per-expert work. + - EP-sharded checkpoints load through DCP transparently because each + tensor's `global_shape` is in the metadata, regardless of how many + EP / TP / PP / FSDP ranks wrote it. + - Use a pattern like 'M*EM*EM*E' to pair Mamba/Attn/MoE-MLP per stage. + +What happens to SSM parameters: + gpt-to-hybrid: SSM layers (M) are initialized from scratch: + - A_log: log(uniform(1, 16)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones + - out_proj.weight: kaiming_uniform(a=sqrt(5)) + - norm.weight: ones + hybrid-to-gpt: SSM layers are discarded with a warning. + +Supported checkpoint formats: + - torch_dist : Megatron distributed checkpoint (TP + PP + FSDP). + - fsdp_dtensor : FSDP DTensor export (TP + PP + FSDP). + + PyTorch DCP gathers TP/PP/FSDP shards via the checkpoint's global-shape + metadata, so no explicit TP/PP/DP config is needed on input. The input + format is auto-detected; the output format defaults to the input format. + + The legacy ``mp_rank_XX/model_optim_rng.pt`` layout is not supported — + convert old checkpoints to ``torch_dist`` first. + +GPT compatibility whitelist (safeguard): + GPTModel is a strict homogeneous transformer (self-attention + MLP per + layer, standard linear_qkv / linear_fc1 / linear_fc2 state-dict keys). + The converter fails fast if either side uses features that GPTModel + cannot express. + + Rejected pattern symbols: 'G' (GDN), 'D' (DS-attention), 'E' (MoE). + Allowed: 'M' (Mamba SSM), '*' (attention), '-' (MLP). The number of + '*' and '-' layers must be equal. + + Rejected source-args features (checked against the args stored in the + source checkpoint): + - num_moe_experts / moe_shared_expert_intermediate_size / moe_layer_freq + - experimental_attention_variant (gated_delta_net, dsa, ...) + - linear_attention_freq + - heterogeneous_block_specs / heterogeneous_layers_config_path + - multi_latent_attention (MLA) + - mtp_num_layers (Multi-Token Prediction) + + See `validate_pattern_gpt_compatible` and + `validate_source_args_gpt_compatible` for the exact rules. 
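+
+Worked mapping example (illustrative):
+    With the pattern "M*-M*-" and a 2-layer GPT source:
+    - GPT layer 0: attention -> hybrid layer 1 ('*'), MLP -> hybrid layer 2 ('-')
+    - GPT layer 1: attention -> hybrid layer 4 ('*'), MLP -> hybrid layer 5 ('-')
+    - Hybrid layers 0 and 3 ('M') have no GPT counterpart and are initialized
+      from scratch as described above.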
+ +Example commands: + # GPT -> Hybrid (TP+PP+FSDP dist checkpoint) + python tools/checkpoint/gpt_hybrid_conversion.py \\ + --direction gpt-to-hybrid \\ + --load-dir /path/to/gpt-dist-checkpoint \\ + --save-dir /path/to/hybrid-dist-checkpoint \\ + --hybrid-layer-pattern "M*-M*-M*-M*-" \\ + --d-model 4096 \\ + --mamba-d-state 128 \\ + --mamba2-n-groups 8 \\ + --mamba2-head-dim 64 + + # Hybrid -> GPT (dist checkpoint) + python tools/checkpoint/gpt_hybrid_conversion.py \\ + --direction hybrid-to-gpt \\ + --load-dir /path/to/hybrid-dist-checkpoint \\ + --save-dir /path/to/gpt-dist-checkpoint \\ + --hybrid-layer-pattern "M*-M*-M*-M*-" \\ + --d-model 4096 \\ + --mamba-d-state 128 \\ + --mamba2-n-groups 8 \\ + --mamba2-head-dim 64 +""" + +import argparse +import copy +import math +import os +import re +from collections import OrderedDict + +import torch +from dist_checkpoint_io import ( + DIST_FORMATS, + FORMAT_TORCH_DIST, + detect_checkpoint_format, + load_dist_checkpoint_full, + save_dist_checkpoint_full, + write_latest_iteration_marker, +) + +# --------------------------------------------------------------------------- +# Hybrid layer pattern parsing (standalone, no Megatron imports needed) +# --------------------------------------------------------------------------- + +VALID_LAYER_SYMBOLS = {'M', 'G', '*', '-', 'E'} + +# Layer symbols GPTModel can emit or absorb: +# '*' : standard self-attention layer (MHA / GQA / MQA) +# '-' : standard (optionally gated) dense MLP layer +# 'E' : MoE MLP layer. Both sides keep the keys under +# decoder.layers..mlp.{router,experts,shared_experts}.* so MoE +# tensors round-trip verbatim (see convert_gpt_to_hybrid and +# convert_hybrid_to_gpt — `is_mlp_param` already matches `mlp.*`). +# SSM ('M') has no GPT equivalent and is initialized from scratch / +# discarded (see convert_gpt_to_hybrid / convert_hybrid_to_gpt). +# 'G' (GDN) and 'D' (DS-attention) are not currently mapped — they would +# need separate key-naming work. Reject for now. +GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-', 'E'} + + +def parse_hybrid_layer_pattern(pattern): + """Parse a hybrid layer pattern string into a list of layer types. + + Strips MTP separators (/) and pipeline stage separators (|), returning only + the main decoder pattern as a list of single-character layer types. + + Returns: + list[str]: e.g. ['M', '*', '-', 'M', '*', '-'] + """ + # Take only the main pattern (before first '/') + main_pattern = pattern.split('/')[0] + # Remove pipeline stage separators + main_pattern = main_pattern.replace('|', '') + layer_types = list(main_pattern) + for ch in layer_types: + if ch not in VALID_LAYER_SYMBOLS: + raise ValueError( + f"Invalid layer symbol '{ch}' in pattern. " f"Valid symbols: {VALID_LAYER_SYMBOLS}" + ) + return layer_types + + +# Pattern symbols that pair to a GPT-side MLP block. Both dense ('-') and MoE +# ('E') keep their state-dict keys under `decoder.layers..mlp.*`, so they +# round-trip identically. The pattern uniformity check +# (validate_pattern_gpt_compatible) ensures '-' and 'E' don't appear together, +# which would mean GPT layers aren't uniform. +_MLP_BEARING_SYMBOLS = ('-', 'E') + + +def build_layer_index_mapping(layer_types, direction): + """Build mapping between GPT layer indices and hybrid-model layer indices. 
+ + For gpt-to-hybrid: + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[gpt_layer_i] = hybrid_layer_j (j is the index of the i-th '*') + - mlp_map[gpt_layer_i] = hybrid_layer_k (k is the index of the i-th + MLP-bearing position; either '-' or 'E') + + For hybrid-to-gpt: + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[hybrid_attn_idx] = gpt_layer_i + - mlp_map[hybrid_mlp_idx] = gpt_layer_i + """ + attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] + mlp_indices = [i for i, t in enumerate(layer_types) if t in _MLP_BEARING_SYMBOLS] + ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] + + if direction == 'gpt-to-hybrid': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For gpt-to-hybrid, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." + ) + attn_map = {i: attn_indices[i] for i in range(len(attn_indices))} + mlp_map = {i: mlp_indices[i] for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + elif direction == 'hybrid-to-gpt': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For hybrid-to-gpt, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." + ) + attn_map = {attn_indices[i]: i for i in range(len(attn_indices))} + mlp_map = {mlp_indices[i]: i for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + else: + raise ValueError(f"Unknown direction: {direction}") + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist +# --------------------------------------------------------------------------- +# +# GPTModel is a *uniform* transformer: every decoder layer is the same kind. +# It can run with dense MLP or MoE MLP — both keep keys under +# decoder.layers..mlp.* — so MoE checkpoints round-trip through the +# converter as long as both sides share the same kind on every layer. +# The helpers below reject any hybrid layout or source-args combination that +# violates uniformity (and would therefore silently produce a corrupt target). +# +# Pattern-level rules (checked on the parsed hybrid_layer_pattern): +# * only 'M', '*', '-', 'E' are allowed (no 'G' GDN, no 'D' DS-attention) +# * MLP-bearing symbols must be uniform: '-' and 'E' cannot both appear +# (that would imply GPT has both dense and MoE layers — heterogeneous) +# * '*' count must equal '-'+'E' count (one-to-one GPT attn<->MLP pairing) +# +# Args-level rules (checked against the training args stored in the source +# checkpoint): reject anything that makes GPT layers heterogeneous OR uses +# attention variants the converter doesn't currently key-translate: +# * moe_layer_freq != 1 (interleaved dense/MoE layers) +# * experimental_attention_variant (gated_delta_net, dsa, ...) +# * linear_attention_freq (interleaved linear-attention) +# * heterogeneous_block_specs / heterogeneous_layers_config_* +# (Nemotron-NAS per-layer specs) +# * multi_latent_attention (MLA: different QKV key layout) +# * mtp_num_layers (Multi-Token Prediction head) +# +# Notably NOT rejected (they round-trip via mlp.* / self_attention.* keys): +# * num_moe_experts (MoE on every layer) +# * moe_shared_expert_intermediate_size (shared experts on every layer) +# +# All rejected configurations raise ValueError early, before any tensors +# are touched. 
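+#
+# Illustrative sketch of the pattern-level rules (hypothetical patterns,
+# showing what validate_pattern_gpt_compatible below accepts or rejects):
+#   'M*-M*-'     -> accepted: 2 '*' paired with 2 '-' (uniform dense GPT)
+#   'M*EM*EM*E'  -> accepted: 3 '*' paired with 3 'E' (uniform MoE GPT)
+#   'M*-M*E'     -> rejected: mixes '-' and 'E' (GPT layers would not be uniform)
+#   'M*GM*-'     -> rejected: 'G' (GDN) is not key-translated
+#   'M**M*-'     -> rejected: 3 '*' but only 1 MLP-bearing position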
+ +# Source-args field name -> (predicate-that-means-"reject", human reason). +# Predicates are applied with getattr(args, field, None); missing fields +# are treated as "absent" and pass. +_GPT_COMPAT_REJECT_FIELDS = ( + ( + 'moe_layer_freq', + # moe_layer_freq is None or 1 when every layer is the same kind (all + # dense or all MoE). A value > 1 or a list with mixed entries means + # GPT has interleaved dense/MoE layers — heterogeneous, can't pair + # one-to-one with a uniform hybrid pattern. + lambda v: ( + v is not None + and not (isinstance(v, int) and v == 1) + and not (isinstance(v, str) and v.strip() in ('', '1')) + and not (isinstance(v, (list, tuple)) and all(x == 1 for x in v)) + ), + 'interleaved dense/MoE layers (moe_layer_freq)', + ), + ( + 'experimental_attention_variant', + lambda v: v is not None and v != '', + 'experimental attention variant (gated_delta_net / dsa / ...)', + ), + ( + 'linear_attention_freq', + lambda v: v is not None, + 'linear attention layers (linear_attention_freq)', + ), + ('heterogeneous_block_specs', lambda v: bool(v), 'heterogeneous per-layer block specs'), + ( + 'heterogeneous_layers_config_path', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS)', + ), + ( + 'heterogeneous_layers_config_encoded_json', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS, inline JSON)', + ), + ('multi_latent_attention', lambda v: bool(v), 'Multi-Latent Attention (MLA)'), + ( + 'mtp_num_layers', + lambda v: v is not None and v > 0, + 'Multi-Token Prediction head (mtp_num_layers)', + ), +) + + +def validate_pattern_gpt_compatible(layer_types, direction): + """Raise ValueError if the hybrid pattern cannot round-trip with GPTModel. + + Args: + layer_types: list of layer-type chars from parse_hybrid_layer_pattern(). + direction: 'gpt-to-hybrid' or 'hybrid-to-gpt' (for error messages). + + Rules: + * Allowed symbols: 'M', '*', '-', 'E'. 'G' (GDN) and 'D' (DS-attention) + are not currently key-translated. + * MLP-bearing symbols must be uniform: '-' (dense) and 'E' (MoE) cannot + both appear, because that would imply GPT has both dense and MoE + layers — the GPT side must be uniform. + * The number of attention positions must equal the number of + MLP-bearing positions: every GPT layer pairs one attention with one + MLP/MoE. + """ + bad = sorted({c for c in layer_types if c not in GPT_COMPATIBLE_PATTERN_SYMBOLS}) + if bad: + raise ValueError( + f"Hybrid layer pattern contains symbols {bad} that are not " + f"GPT-compatible (allowed: {sorted(GPT_COMPATIBLE_PATTERN_SYMBOLS)}). " + f"'G' (GDN) and 'D' (DS-attention) are not currently key-translated " + f"and cannot be {direction}-converted." + ) + + mlp_kinds_present = {t for t in layer_types if t in _MLP_BEARING_SYMBOLS} + if len(mlp_kinds_present) > 1: + raise ValueError( + f"Hybrid layer pattern mixes '-' (dense MLP) and 'E' (MoE) " + f"positions. GPTModel layers must be uniform — either all GPT " + f"layers are dense MLP, or all are MoE. Use only one of '-' or " + f"'E' in the pattern." + ) + + n_attn = sum(1 for t in layer_types if t == '*') + n_mlp = sum(1 for t in layer_types if t in _MLP_BEARING_SYMBOLS) + if n_attn != n_mlp: + raise ValueError( + f"GPT-compatible hybrid patterns must pair every attention layer " + f"('*') with one MLP/MoE layer ('-' or 'E'). Got {n_attn} '*' " + f"and {n_mlp} MLP-bearing layers in the pattern." 
+ ) + + +def validate_source_args_gpt_compatible(source_args, direction): + """Raise ValueError if the source checkpoint uses features GPTModel can't express. + + Args: + source_args: argparse.Namespace (or any attribute-bag) loaded from the + source checkpoint; may be None, in which case this check is a no-op + (dist checkpoints without a cached args blob). + direction: 'gpt-to-hybrid' or 'hybrid-to-gpt'. + + Rejects MoE, MLA, MTP, linear / experimental attention, and heterogeneous + per-layer specs. See the module header for the full list. + """ + if source_args is None: + return + + rejected = [] + for field, predicate, reason in _GPT_COMPAT_REJECT_FIELDS: + if not hasattr(source_args, field): + continue + value = getattr(source_args, field) + try: + if predicate(value): + rejected.append(f" - {reason}: {field}={value!r}") + except Exception: + # Defensive: never let the validator crash on an unexpected + # value type — treat it as "cannot verify, pass". + continue + + if rejected: + joined = "\n".join(rejected) + raise ValueError( + f"Source checkpoint is not GPT-compatible for {direction} " + f"conversion. The following features have no GPTModel equivalent " + f"and would produce a corrupt target checkpoint:\n{joined}\n" + f"Remove these features from the model (or use a different " + f"conversion tool) before running gpt_hybrid_conversion." + ) + + +# --------------------------------------------------------------------------- +# SSM parameter initialization (for gpt-to-hybrid) +# --------------------------------------------------------------------------- + + +def initialize_ssm_layer_params( + layer_idx, + d_model, + mamba_d_inner, + mamba_d_state, + mamba2_n_groups, + mamba2_n_heads, + mamba_head_dim, + d_conv=4, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + A_init_range=(1, 16), + init_method_std=0.02, + dtype=torch.float32, +): + """Initialize parameters for a single Mamba SSM layer from scratch. + + Follows the initialization logic from MambaMixer.__init__: + - A_log: log(uniform(A_init_range)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones(nheads) + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones(d_model) + - out_proj.weight: kaiming_uniform(a=sqrt(5)) or normal(0, std) + - norm.weight: ones(d_inner) + + Returns: + dict: {param_suffix: tensor} for one SSM layer + """ + prefix = f'decoder.layers.{layer_idx}.mixer.' 
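+
+    # Shape bookkeeping is computed exactly as in the code that follows; the
+    # concrete numbers here are only an illustrative sketch using the defaults
+    # from the example commands in the module docstring
+    # (d_model=4096 -> mamba_d_inner=8192, mamba_d_state=128,
+    # mamba2_n_groups=8, mamba2_head_dim=64 -> mamba2_n_heads=128):
+    #   conv_dim        = d_inner + 2 * n_groups * d_state            = 10240
+    #   in_proj_out_dim = 2 * d_inner + 2 * n_groups * d_state + nheads = 18560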
+ + nheads = mamba2_n_heads + conv_dim = mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + in_proj_out_dim = 2 * mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + nheads + + params = OrderedDict() + + # in_proj (ColumnParallelLinear) + in_proj_weight = torch.empty(in_proj_out_dim, d_model, dtype=dtype) + torch.nn.init.kaiming_uniform_(in_proj_weight, a=math.sqrt(5)) + params[prefix + 'in_proj.weight'] = in_proj_weight + + # in_proj layer norm weight (fused into ColumnParallelLinear in TE) + params[prefix + 'in_proj.layer_norm_weight'] = torch.ones(d_model, dtype=dtype) + + # conv1d + conv_weight = torch.empty(conv_dim, 1, d_conv, dtype=dtype) + torch.nn.init.kaiming_uniform_(conv_weight, a=math.sqrt(5)) + params[prefix + 'conv1d.weight'] = conv_weight + params[prefix + 'conv1d.bias'] = torch.zeros(conv_dim, dtype=dtype) + + # A_log (kept in fp32) + A = torch.empty(nheads, dtype=torch.float32) + A.uniform_(*A_init_range) + params[prefix + 'A_log'] = torch.log(A) + + # D (kept in fp32) + params[prefix + 'D'] = torch.ones(nheads, dtype=torch.float32) + + # dt_bias + dt = torch.exp( + torch.rand(nheads, dtype=dtype) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ).clamp(min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + params[prefix + 'dt_bias'] = inv_dt + + # norm (RMSNorm) + params[prefix + 'norm.weight'] = torch.ones(mamba_d_inner, dtype=dtype) + + # out_proj (RowParallelLinear) + out_proj_weight = torch.empty(d_model, mamba_d_inner, dtype=dtype) + torch.nn.init.kaiming_uniform_(out_proj_weight, a=math.sqrt(5)) + params[prefix + 'out_proj.weight'] = out_proj_weight + + return params + + +# --------------------------------------------------------------------------- +# Key name helpers +# --------------------------------------------------------------------------- + + +def get_layer_num_from_key(key): + """Extract the layer number from a state dict key like 'decoder.layers.5.mlp...'""" + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return int(match.group(1)) + return None + + +def replace_layer_num(key, old_num, new_num): + """Replace the layer number in a state dict key.""" + return key.replace(f'decoder.layers.{old_num}.', f'decoder.layers.{new_num}.', 1) + + +def is_attention_param(key): + """Check if a key belongs to an attention sub-layer.""" + return 'self_attention.' in key or 'input_layernorm.' in key + + +def is_mlp_param(key): + """Check if a key belongs to an MLP sub-layer.""" + return ('mlp.' in key or 'pre_mlp_layernorm.' in key) and 'self_attention' not in key + + +def is_ssm_param(key): + """Check if a key belongs to a Mamba SSM mixer sub-layer.""" + ssm_markers = [ + 'mixer.in_proj', + 'mixer.conv1d', + 'mixer.A_log', + 'mixer.D', + 'mixer.dt_bias', + 'mixer.norm', + 'mixer.out_proj', + 'mixer.x_proj', + 'mixer.dt_proj', + ] + return any(m in key for m in ssm_markers) + + +def is_layer_norm_for_ssm(key): + """Check if a key is the input layer norm for an SSM layer. + + In hybrid models, SSM layers can have their own input_layernorm or the + norm can be fused into in_proj.layer_norm_weight. + """ + return 'in_proj.layer_norm_weight' in key + + +# --------------------------------------------------------------------------- +# Core conversion: GPT -> Hybrid +# --------------------------------------------------------------------------- + + +def convert_gpt_to_hybrid(full_model, layer_types, args): + """Convert a GPT state dict to a Hybrid state dict. 
+ + Args: + full_model: OrderedDict with globally-indexed GPT state dict keys. + layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: Hybrid state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + num_gpt_layers = len(attn_map) + + # Validate GPT layer count + gpt_layer_nums = set() + for key in full_model: + lnum = get_layer_num_from_key(key) + if lnum is not None: + gpt_layer_nums.add(lnum) + + if len(gpt_layer_nums) != num_gpt_layers: + raise ValueError( + f"GPT checkpoint has {len(gpt_layer_nums)} layers, but the pattern " + f"has {num_gpt_layers} attention ('*') and {num_gpt_layers} MLP ('-') " + f"layers. These must match." + ) + + target = OrderedDict() + dtype = None + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if dtype is None and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32): + dtype = tensor.dtype + + if 'decoder.layers.' in key: + continue + + # Rename final_layernorm -> final_norm + if 'decoder.final_layernorm' in key: + new_key = key.replace('decoder.final_layernorm', 'decoder.final_norm') + target[new_key] = tensor + else: + target[key] = tensor + + if dtype is None: + dtype = torch.float32 + + # Map attention and MLP params + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_attention_param(key): + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key): + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + # (any other layer params get copied as-is with their own mapping, + # but for pure GPT there should only be attention + MLP) + + # Initialize SSM layers from scratch + print(f" Initializing {len(ssm_indices)} SSM layers from scratch...") + for layer_idx in ssm_indices: + ssm_params = initialize_ssm_layer_params( + layer_idx=layer_idx, + d_model=args.d_model, + mamba_d_inner=args.mamba_d_inner, + mamba_d_state=args.mamba_d_state, + mamba2_n_groups=args.mamba2_n_groups, + mamba2_n_heads=args.mamba2_n_heads, + mamba_head_dim=args.mamba2_head_dim, + d_conv=getattr(args, 'd_conv', 4), + init_method_std=getattr(args, 'init_method_std', 0.02), + dtype=dtype, + ) + target.update(ssm_params) + + # Sort by layer index for consistent ordering + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Core conversion: Hybrid -> GPT +# --------------------------------------------------------------------------- + + +def convert_hybrid_to_gpt(full_model, layer_types, args): + """Convert a Hybrid state dict to a GPT state dict. + + Args: + full_model: OrderedDict with globally-indexed Hybrid state dict keys. + layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: GPT state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'hybrid-to-gpt') + num_gpt_layers = len(attn_map) + + target = OrderedDict() + discarded_ssm_keys = [] + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if 'decoder.layers.' 
in key: + continue + + # Rename final_norm -> final_layernorm + if 'decoder.final_norm' in key: + new_key = key.replace('decoder.final_norm', 'decoder.final_layernorm') + target[new_key] = tensor + else: + target[key] = tensor + + # Map attention and MLP params, discard SSM + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_ssm_param(key) or is_layer_norm_for_ssm(key): + # Discard SSM params + if lnum in ssm_indices: + discarded_ssm_keys.append(key) + continue + + if is_attention_param(key) and lnum in attn_map: + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key) and lnum in mlp_map: + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif lnum in ssm_indices: + # Any remaining SSM-layer param not caught above + discarded_ssm_keys.append(key) + + if discarded_ssm_keys: + print( + f"\n WARNING: Discarded {len(discarded_ssm_keys)} SSM parameter tensors " + f"from {len(ssm_indices)} SSM layers (no GPT equivalent)." + ) + print(f" First few discarded keys: {discarded_ssm_keys[:5]}") + + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Sorting helper +# --------------------------------------------------------------------------- + + +def _sort_state_dict(state_dict): + """Sort state dict keys so that layer-indexed keys are in order.""" + + def sort_key(item): + key = item[0] + # Extract layer number if present + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return (1, int(match.group(1)), key) + # Non-layer keys: embeddings first, output_layer last + if 'embedding' in key: + return (0, 0, key) + if 'output_layer' in key: + return (2, 0, key) + if 'decoder.final' in key: + return (1, 999999, key) + return (0, 1, key) + + return OrderedDict(sorted(state_dict.items(), key=sort_key)) + + +# --------------------------------------------------------------------------- +# Format-aware save +# --------------------------------------------------------------------------- + + +def _save_dist_full(target_state_dict, common_state, model_prefix, backend, args, iteration): + """Save a fully-gathered state dict in dist-ckpt format. + + The on-disk tensors carry their full logical shape, so downstream Megatron + training reads them back with any TP+PP+FSDP configuration. + """ + if iteration is None: + out_iter = 0 if args.reset_iterations else 0 + iter_dir = args.save_dir + else: + out_iter = 0 if args.reset_iterations else iteration + iter_dir = os.path.join(args.save_dir, f'iter_{out_iter:07d}') + + # Update common state args to reflect target model structure. 
+ common_state = copy.deepcopy(common_state) if common_state else {} + if 'args' in common_state and common_state['args'] is not None: + ckpt_args = common_state['args'] + ckpt_args.num_layers = args.target_num_layers + if hasattr(ckpt_args, 'hybrid_layer_pattern'): + if args.direction == 'gpt-to-hybrid': + ckpt_args.hybrid_layer_pattern = args.hybrid_layer_pattern + else: + ckpt_args.hybrid_layer_pattern = None + if args.reset_iterations: + for attr in ( + 'iteration', + 'consumed_valid_samples', + 'consumed_train_samples', + 'train_iters', + 'train_samples', + ): + if hasattr(ckpt_args, attr): + setattr(ckpt_args, attr, 0) + if args.reset_iterations and 'iteration' in common_state: + common_state['iteration'] = 0 + + print( + f" Writing dist checkpoint to {iter_dir} " + f"(backend={backend}, prefix='{model_prefix}')..." + ) + save_dist_checkpoint_full( + target_state_dict, common_state, iter_dir, model_prefix=model_prefix, backend=backend + ) + write_latest_iteration_marker(iter_dir, out_iter) + + +def main(args): + print("\n====RUNNING GPT <-> Hybrid CHECKPOINT CONVERSION====\n") + print(f" Direction: {args.direction}") + print(f" Source: {args.load_dir}") + print(f" Target: {args.save_dir}") + print(f" Hybrid layer pattern: {args.hybrid_layer_pattern}") + + # Compute derived Mamba dimensions + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + # Parse hybrid layer pattern + layer_types = parse_hybrid_layer_pattern(args.hybrid_layer_pattern) + total_hybrid_layers = len(layer_types) + attn_count = sum(1 for t in layer_types if t == '*') + mlp_count = sum(1 for t in layer_types if t == '-') + ssm_count = sum(1 for t in layer_types if t == 'M') + print( + f"\n Pattern: {len(layer_types)} total layers " + f"({attn_count} attn, {mlp_count} MLP, {ssm_count} SSM, " + f"{len(layer_types) - attn_count - mlp_count - ssm_count} other)" + ) + + # Pattern-level GPT compatibility whitelist (fails fast, pre-load). + validate_pattern_gpt_compatible(layer_types, args.direction) + + # 1. Resolve input format + input_format = getattr(args, 'input_format', 'auto') + if input_format == 'auto': + input_format = detect_checkpoint_format(args.load_dir) + output_format = getattr(args, 'output_format', 'auto') + if output_format == 'auto': + output_format = input_format + print(f"\n Input format: {input_format}") + print(f" Output format: {output_format}") + + if input_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported input format: {input_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) + if output_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported output format: {output_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) + + # 2. Load source checkpoint into a fully-gathered state dict + print("\n[Step 1] Loading source checkpoint...") + full_model, common_state, model_prefix, dist_backend, iteration = load_dist_checkpoint_full( + args.load_dir + ) + print( + f" Source: dist backend={dist_backend}, prefix='{model_prefix}', " + f"iteration={iteration}, params={len(full_model)}" + ) + + # Args-level GPT compatibility whitelist: reject MoE, MLA, MTP, linear / + # experimental attention, heterogeneous block specs, etc. See module header. + source_args = common_state.get('args') if common_state else None + validate_source_args_gpt_compatible(source_args, args.direction) + + # 3. 
Convert + print(f"\n[Step 2] Converting ({args.direction})...") + if args.direction == 'gpt-to-hybrid': + target_state_dict = convert_gpt_to_hybrid(full_model, layer_types, args) + args.target_num_layers = total_hybrid_layers + elif args.direction == 'hybrid-to-gpt': + target_state_dict = convert_hybrid_to_gpt(full_model, layer_types, args) + args.target_num_layers = attn_count + else: + raise ValueError(f"Unknown direction: {args.direction}") + print(f" Target model: {len(target_state_dict)} parameters") + + # 4. Save + print(f"\n[Step 3] Saving to {args.save_dir}...") + _save_dist_full(target_state_dict, common_state, model_prefix, output_format, args, iteration) + + print("\n====CONVERSION COMPLETE====\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert checkpoints between GPTModel and HybridModel formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--direction', + type=str, + required=True, + choices=['gpt-to-hybrid', 'hybrid-to-gpt'], + help='Conversion direction.', + ) + parser.add_argument( + '--load-dir', type=str, required=True, help='Path to source checkpoint directory.' + ) + parser.add_argument( + '--save-dir', type=str, required=True, help='Path to target checkpoint directory.' + ) + parser.add_argument( + '--hybrid-layer-pattern', + type=str, + required=True, + help='Hybrid layer pattern string, e.g. "M*-M*-M*-M*-".', + ) + + parser.add_argument( + '--input-format', + type=str, + default='auto', + choices=('auto',) + DIST_FORMATS, + help='Source checkpoint format. "auto" detects from metadata.json.', + ) + parser.add_argument( + '--output-format', + type=str, + default='auto', + choices=('auto',) + DIST_FORMATS, + help='Target checkpoint format. "auto" matches the input format. ' + 'Dist formats (torch_dist / fsdp_dtensor) transparently support ' + 'TP+PP+FSDP training checkpoints.', + ) + + # Model architecture params + parser.add_argument('--d-model', type=int, default=4096, help='Model hidden dimension.') + parser.add_argument( + '--mamba-version', type=int, default=2, choices=[1, 2], help='Mamba SSM version.' + ) + parser.add_argument('--mamba-d-state', type=int, default=128, help='Mamba state dimension.') + parser.add_argument( + '--mamba2-n-groups', type=int, default=8, help='Number of groups (Mamba v2).' + ) + parser.add_argument( + '--mamba2-head-dim', type=int, default=64, help='Head dimension (Mamba v2).' + ) + parser.add_argument('--d-conv', type=int, default=4, help='Causal convolution kernel size.') + + # Initialization params + parser.add_argument( + '--init-method-std', + type=float, + default=0.02, + help='Std for initializing new Mamba SSM params.', + ) + + # Checkpoint control + parser.add_argument( + '--reset-iterations', action='store_true', help='Zero out the training iteration count.' 
+ ) + + args = parser.parse_args() + main(args) diff --git a/tools/checkpoint/loader_base.py b/tools/checkpoint/loader_base.py index 11e9224f726..de31391d076 100644 --- a/tools/checkpoint/loader_base.py +++ b/tools/checkpoint/loader_base.py @@ -82,7 +82,6 @@ def parse_megatron_args(self): self.queue.put("exit") sys.exit(1) - margs.use_legacy_models = False margs.transformer_impl = self.args.loader_transformer_impl if self.args.loader_transformer_impl == "local" and margs.normalization == "RMSNorm": margs.no_persist_layer_norm = True @@ -450,7 +449,6 @@ def build_checkpoint_metadata(self, true_vocab_size): md.true_vocab_size = true_vocab_size md.make_vocab_size_divisible_by = self.margs.make_vocab_size_divisible_by md.checkpoint_args = self.checkpoint_args - md.use_legacy_models = self.margs.use_legacy_models return md def build_sys_argv(self): diff --git a/tools/checkpoint/loader_legacy.py b/tools/checkpoint/loader_legacy.py deleted file mode 100644 index 0dffa4efff8..00000000000 --- a/tools/checkpoint/loader_legacy.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import json -import os -import sys -import types -from functools import partial - -import torch - -from tools.checkpoint.utils import _ConverterFakeProcessGroup - - -def add_arguments(parser): - group = parser.add_argument_group(title='Megatron loader') - - group.add_argument( - '--true-vocab-size', - type=int, - default=None, - help='original size of vocab, if specified will trim padding from embedding table.', - ) - group.add_argument( - '--vocab-file', - type=str, - default=None, - help='Path to the vocab file. If specified will use this to get vocab size and ' - 'trim padding from the embedding table.', - ) - group.add_argument( - '--megatron-path', type=str, default=None, help='Base directory of Megatron repository' - ) - group.add_argument( - '--position-embedding-type', - type=str, - default='learned_absolute', - choices=['learned_absolute', 'rope'], - help='Position embedding type.', - ) - group.add_argument( - '--loader-transformer-impl', - default='local', - choices=['local', 'transformer_engine'], - help='Which Transformer implementation to use.', - ) - - -def _load_checkpoint(queue, args): - - # Search in directory above this - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) - if args.megatron_path is not None: - sys.path.insert(0, args.megatron_path) - - try: - from megatron.core import mpu - from megatron.core.enums import ModelType - from megatron.legacy.model import module - from megatron.training.arguments import parse_args, validate_args - from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint - from megatron.training.global_vars import set_args, set_global_variables - except ModuleNotFoundError: - print( - "Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting." 
- ) - queue.put("exit") - exit(1) - - # We want all arguments to come from us - sys.argv = [ - 'script.py', - '--no-masked-softmax-fusion', - '--no-bias-gelu-fusion', - '--no-bias-dropout-fusion', - '--use-cpu-initialization', - '--micro-batch-size', - '1', - '--no-load-optim', - '--no-load-rng', - '--no-save-optim', - '--no-save-rng', - '--mock-data', # To pass the "blend data checks" in arguments.py - '--no-initialization', - '--load', - args.load_dir, - '--position-embedding-type', - args.position_embedding_type, - '--exit-on-missing-checkpoint', - '--use-mp-args-from-checkpoint-args', - '--no-one-logger', - ] - - margs = parse_args() - margs, checkpoint_args = load_args_from_checkpoint(margs) - - # Arguments do sanity checks on the world size, but we don't care, - # so trick it into thinking we are plenty of processes - margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size - - # Explicitly copy data types from checkpoint. - margs.fp16 = checkpoint_args.fp16 - margs.bf16 = checkpoint_args.bf16 - - # Validate margs. - margs = validate_args(margs) - - margs.use_legacy_models = True - margs.transformer_impl = args.loader_transformer_impl - - def check_for_arg(arg_name, default=None): - if getattr(margs, arg_name, None) is None: - if default is not None: - setattr(margs, arg_name, default) - else: - print(f"Checkpoint does not specify the argument {arg_name}. Exiting.") - print(f"Arguments: {margs}") - queue.put("exit") - exit(1) - - check_for_arg('tensor_model_parallel_size') - check_for_arg('pipeline_model_parallel_size') - check_for_arg('num_layers') - check_for_arg('hidden_size') - check_for_arg('seq_length') - check_for_arg('num_attention_heads') - check_for_arg('max_position_embeddings') - check_for_arg('position_embedding_type') - check_for_arg('tokenizer_type') - check_for_arg('iteration') - check_for_arg('bert_binary_head') - check_for_arg('disable_bias_linear', False) - check_for_arg('params_dtype') - check_for_arg('swiglu', False) - - # Determine how to make our models - if args.model_type == 'GPT': - from gpt_builders import gpt_builder - from model_provider import model_provider as common_model_provider - - model_provider = partial(common_model_provider, gpt_builder) - margs.model_type = ModelType.encoder_or_decoder - elif args.model_type == 'BERT': - from pretrain_bert import model_provider - - margs.model_type = ModelType.encoder_or_decoder - else: - raise Exception(f'unrecognized model type: {args.model_type}') - - # supress warning about torch.distributed not being initialized - module.MegatronModule.embedding_warning_printed = True - - consumed_train_samples = None - consumed_valid_samples = None - - def get_models(count, dtype): - nonlocal consumed_train_samples - nonlocal consumed_valid_samples - model_array_len = margs.virtual_pipeline_model_parallel_size - if model_array_len is None: - model_array_len = 1 - models = [[] for _ in range(model_array_len)] - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - for rank in range(count): - fake_tp_group = mpu.get_tensor_model_parallel_group() - fake_tp_group.set_rank(rank) - mpu.set_tensor_model_parallel_rank(rank) - if margs.virtual_pipeline_model_parallel_size is not None: - model_ = [] - for i in range(margs.virtual_pipeline_model_parallel_size): - mpu.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. 
- pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - this_model = model_provider( - pre_process=pre_process, post_process=post_process - ).to(dtype) - model_.append(this_model) - else: - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - model_rank = 0 - model_ = [model_provider(pre_process, post_process).to(dtype)] - margs.consumed_train_samples = 0 - margs.consumed_valid_samples = 0 - margs.exit_on_missing_checkpoint = True - load_checkpoint(model_, None, None) - - if consumed_train_samples is not None: - assert margs.consumed_train_samples == consumed_train_samples - else: - consumed_train_samples = margs.consumed_train_samples - if consumed_valid_samples is not None: - assert margs.consumed_valid_samples == consumed_valid_samples - else: - consumed_valid_samples = margs.consumed_valid_samples - for vp_rank in range(model_array_len): - models[vp_rank].append(model_[vp_rank]) - return models - - set_global_variables(margs, build_tokenizer=False) - mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) - mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) - mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) - - # For backward compatibility during local parallel states refactoring - fake_tp_group = _ConverterFakeProcessGroup(size=margs.tensor_model_parallel_size) - mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group - - # Get true (non-padded) vocab size - if args.true_vocab_size is not None: - true_vocab_size = args.true_vocab_size - elif args.vocab_file is not None: - vocab = json.load(open(args.vocab_file)) - true_vocab_size = len(vocab) - if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size: - print( - "Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting." - ) - queue.put("exit") - exit(1) - else: - true_vocab_size = None - - # short aliases - tp_size = margs.tensor_model_parallel_size - pp_size = margs.pipeline_model_parallel_size - vp_size = margs.virtual_pipeline_model_parallel_size - if vp_size is None: - vp_size = 1 - - # Layernorm has bias; RMSNorm does not. 
- if hasattr(checkpoint_args, 'normalization'): - norm_has_bias = checkpoint_args.normalization == "LayerNorm" - else: - # older models only supported LayerNorm - norm_has_bias = True - - # metadata - md = types.SimpleNamespace() - md.model_type = args.model_type - md.num_layers = margs.num_layers - md.hidden_size = margs.hidden_size - md.seq_length = margs.seq_length - md.num_attention_heads = margs.num_attention_heads - md.max_position_embeddings = margs.max_position_embeddings - md.tokenizer_type = margs.tokenizer_type - md.iteration = margs.iteration - md.params_dtype = margs.params_dtype - md.bert_binary_head = margs.bert_binary_head - md.output_layer = margs.untie_embeddings_and_output_weights - md.position_embedding_type = margs.position_embedding_type - md.linear_bias = margs.add_bias_linear - md.qkv_bias = margs.add_qkv_bias - md.norm_has_bias = norm_has_bias - md.swiglu = margs.swiglu - md.previous_tensor_parallel_size = margs.tensor_model_parallel_size - md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size - md.true_vocab_size = true_vocab_size - md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by - md.checkpoint_args = checkpoint_args - - # Get first pipe stage - mpu.set_pipeline_model_parallel_rank(0) - all_models = [get_models(tp_size, md.params_dtype)] - models = all_models[0][0] - - md.consumed_train_samples = consumed_train_samples - md.consumed_valid_samples = consumed_valid_samples - queue.put(md) - - def queue_put(name, msg): - print(f"sending {name}") - msg["name"] = name - queue.put(msg) - - # Send embeddings - message = { - "word embeddings": torch.cat( - [ - models[tp_rank].language_model.embedding.word_embeddings.weight.data - for tp_rank in range(tp_size) - ], - dim=0, - ) - } - if md.position_embedding_type == 'learned_absolute': - message["position embeddings"] = models[ - 0 - ].language_model.embedding.position_embeddings.weight.data - else: - assert not hasattr(models[0].language_model.embedding, 'position_embeddings') - - queue_put("embeddings", message) - - total_layer_num = 0 - for vp_rank in range(vp_size): - mpu.set_virtual_pipeline_model_parallel_rank(vp_rank) - for pp_rank in range(pp_size): - if pp_rank > 0: - mpu.set_pipeline_model_parallel_rank(pp_rank) - if vp_rank == 0: - all_models.append(get_models(tp_size, md.params_dtype)) - models = all_models[pp_rank][vp_rank] - for layer_num in range(len(models[0].language_model.encoder.layers)): - message = {} - - # Get non-parallel tensors from tp_rank 0 - layer = models[0].language_model.encoder.layers[layer_num] - message["input norm weight"] = layer.input_norm.weight.data - if norm_has_bias: - message["input norm bias"] = layer.input_norm.bias.data - message["post norm weight"] = layer.post_attention_norm.weight.data - if norm_has_bias: - message["post norm bias"] = layer.post_attention_norm.bias.data - if md.linear_bias: - message["dense bias"] = layer.self_attention.dense.bias.data - message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data - - # Grab all parallel tensors for this layer - qkv_weight = [] - qkv_bias = [] - dense_weight = [] - mlp_l0_weight = [] - mlp_l0_bias = [] - mlp_l1_weight = [] - for tp_rank, model in enumerate(models): - layer = model.language_model.encoder.layers[layer_num] - qkv_weight.append(layer.self_attention.query_key_value.weight.data) - dense_weight.append(layer.self_attention.dense.weight.data) - mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data) - mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data) - if md.qkv_bias: - 
qkv_bias.append(layer.self_attention.query_key_value.bias.data) - if md.linear_bias: - mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) - - # Handle gated linear units - if md.swiglu: - # concat all the first halves ('W's) and all the second halves ('V's) - for tp_rank in range(tp_size): - mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) - message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) - message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) - else: - message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) - - # simple concat of the rest - message["qkv weight"] = torch.cat(qkv_weight, dim=0) - message["dense weight"] = torch.cat(dense_weight, dim=1) - message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) - if md.qkv_bias: - message["qkv bias"] = torch.cat(qkv_bias, dim=0) - if md.linear_bias: - if md.swiglu: - for tp_rank in range(tp_size): - mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) - message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0) - message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0) - else: - message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) - - queue_put(f"transformer layer {total_layer_num}", message) - - total_layer_num = total_layer_num + 1 - - # Send final norm from tp_rank 0 - message = {"weight": models[0].language_model.encoder.final_norm.weight.data} - if norm_has_bias: - message["bias"] = models[0].language_model.encoder.final_norm.bias.data - queue_put("final norm", message) - - if md.output_layer: - message = { - "weight": torch.cat( - [ - models[tp_rank].language_model.output_layer.weight.data - for tp_rank in range(tp_size) - ], - dim=0, - ) - } - queue_put("output layer", message) - - # Send BERT lm head and binary head if it exists - if md.model_type == 'BERT': - message = { - "weight": models[0].language_model.pooler.dense.weight.data, - "bias": models[0].language_model.pooler.dense.bias.data, - } - queue_put("pooler", message) - - message = { - "dense weight": models[0].lm_head.dense.weight.data, - "dense bias": models[0].lm_head.dense.bias.data, - "norm weight": models[0].lm_head.norm.weight.data, - } - if norm_has_bias: - message["norm bias"] = models[0].lm_head.norm.bias.data - queue_put("lm head", message) - - if md.bert_binary_head: - message = { - "weight": models[0].binary_head.weight.data, - "bias": models[0].binary_head.bias.data, - } - queue_put("binary head", message) - queue.put("done") - - -def load_checkpoint(queue, args): - try: - _load_checkpoint(queue, args) - except Exception: - queue.put("exit") - raise diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py deleted file mode 100644 index 45285bbf29c..00000000000 --- a/tools/checkpoint/loader_llama_mistral.py +++ /dev/null @@ -1,751 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- -import json -import os -import sys - -import torch -from utils import _ConverterFakeProcessGroup - -try: - import transformers -except ImportError: - raise ImportError("The 'transformers' package is not installed.") -import gc -import shutil -import types - -from tqdm import tqdm - - -def add_arguments(parser): - group = parser.add_argument_group(title='Llama/Mistral loader.') - - # TODO(jbarker): Need assertion to make sure *exactly* one of these is used - parser.add_argument( - '--model-size', - type=str, - required=True, - choices=[ - 'llama2-7B', - 'llama2-13B', - 'llama2-70B', - 'llama2-7Bf', - 'llama2-13Bf', - 'llama2-70Bf', - 'llama3', - 'mistral', - 'yi-34B', - 'qwen2.5', - ], - help='Select model size/type', - ) - parser.add_argument( - '--checkpoint-type', - type=str, - required=True, - choices=['meta', 'hf'], - help='Type of checkpoint to convert, options are "meta" or "hf"', - ) - parser.add_argument('--bf16', action='store_true', help='Whether to load weights in bf16.') - parser.add_argument('--fp16', action='store_true', help='Whether to load weights in fp16.') - group.add_argument( - '--true-vocab-size', - type=int, - default=None, - help='original size of vocab, if specified will trim padding from embedding table.', - ) - group.add_argument( - '--vocab-file', - type=str, - default=None, - help='Path to the vocab file. If specified will use this to get vocab size and ' - 'trim padding from the embedding table.', - ) - group.add_argument('--tokenizer-model', required=True, help='Tokenizer model file.') - group.add_argument( - '--megatron-path', type=str, default=None, help='Base directory of Megatron repository' - ) - group.add_argument( - "--make-vocab-size-divisible-by", - type=int, - default=None, - help="Make vocab size divisible by", - ) - group.add_argument( - '--loader-transformer-impl', - default='local', - choices=['local', 'transformer_engine'], - help='Which Transformer implementation to use.', - ) - - -def verify_transformers_version(): - major, minor, patch = map(int, transformers.__version__.split('.')) - assert major >= 4 and minor >= 31 - - -NUM_SHARDS = { - "llama2-7B": 1, - "llama2-7Bf": 1, - "llama2-13B": 2, - "llama2-13Bf": 2, - "llama2-70B": 8, - "llama2-70Bf": 8, -} - - -def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): - return multiple_of * ( - (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of - ) - - -def read_json(path): - with open(path, "r") as f: - return json.load(f) - - -def write_json(text, path): - with open(path, "w") as f: - json.dump(text, f) - - -# This conversion is adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py -def convert_to_hf(model_path, input_base_path, model_size, tokenizer_path): - if "llama2" in model_size: - from transformers import LlamaConfig as ModelConfig - from transformers import LlamaTokenizer, LlamaTokenizerFast - else: - raise NotImplementedError( - f"converting {model_size} is only supported using HuggingFace weights" - ) - - # for backward compatibility, before you needed the repo to be called `my_repo/model_size` - if not os.path.isfile(os.path.join(input_base_path, "params.json")): - input_base_path = os.path.join(input_base_path, model_size) - - os.makedirs(model_path, exist_ok=True) - - params = read_json(os.path.join(input_base_path, "params.json")) - num_shards = NUM_SHARDS[model_size] - params = params.get("model", params) - n_layers = params["n_layers"] - n_heads = 
params["n_heads"] - n_heads_per_shard = n_heads // num_shards - dim = params["dim"] - dims_per_head = dim // n_heads - base = params.get("rope_theta", 10000.0) - inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - if base > 10000.0: - max_position_embeddings = 32768 if "mistral" in model_size else 16384 - else: - max_position_embeddings = 4096 - - if "llama2" in model_size: - tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast - else: - raise AttributeError(f"model_size={model_size} not supported") - - if tokenizer_path is not None: - if "llama2" in model_size: - tokenizer = tokenizer_class(tokenizer_path) - tokenizer.save_pretrained(model_path) - vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 - else: - raise AttributeError(f"model_size={model_size} is not supported") - - if params.get("n_kv_heads", None) is not None: - num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_local_key_value_heads = n_heads_per_shard // num_key_value_heads - key_value_dim = dim // num_key_value_heads - else: # compatibility with other checkpoints - num_key_value_heads = n_heads - num_local_key_value_heads = n_heads_per_shard - key_value_dim = dim - - # permute for sliced rotary - def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): - return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - print(f"Fetching all parameters from the checkpoint at {input_base_path}.") - # Load weights - if num_shards == 1: - # Not sharded - # (The sharded implementation would also work, but this is simpler.) - loaded = torch.load( - os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu" - ) - else: - # Sharded - loaded = [ - torch.load( - os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu" - ) - for i in range(num_shards) - ] - param_count = 0 - index_dict = {"weight_map": {}} - for layer_i in range(n_layers): - filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - q_proj = loaded[f"layers.{layer_i}.attention.wq.weight"] - k_proj = loaded[f"layers.{layer_i}.attention.wk.weight"] - if ("llama2" in model_size) or ("mistral" in model_size): - q_proj = permute(q_proj) - k_proj = permute(k_proj) - state_dict = { - f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj, - f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj, - f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[ - f"layers.{layer_i}.attention.wv.weight" - ], - f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[ - f"layers.{layer_i}.attention.wo.weight" - ], - f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[ - f"layers.{layer_i}.feed_forward.w1.weight" - ], - f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[ - f"layers.{layer_i}.feed_forward.w2.weight" - ], - f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[ - f"layers.{layer_i}.feed_forward.w3.weight" - ], - f"model.layers.{layer_i}.input_layernorm.weight": loaded[ - f"layers.{layer_i}.attention_norm.weight" - ], - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ - f"layers.{layer_i}.ffn_norm.weight" - ], - } - else: - # Sharded - # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share - # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is - # redundant as other weights will be stitched from multiple shards. 
To avoid that, they are cloned. - - state_dict = { - f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wq.weight"].view( - n_heads_per_shard, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(dim, dim) - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim), - num_key_value_heads, - key_value_dim, - dim, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim) - - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], - dim=1, - ) - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], - dim=1, - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], - dim=0, - ) - - state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(model_path, filename)) - - filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - state_dict = { - "model.embed_tokens.weight": loaded["tok_embeddings.weight"], - "model.norm.weight": loaded["norm.weight"], - "lm_head.weight": loaded["output.weight"], - } - else: - d = 0 if "llama3" in model_size else 1 - state_dict = { - "model.norm.weight": loaded[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat( - [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=d - ), - "lm_head.weight": torch.cat( - [loaded[i]["output.weight"] for i in range(num_shards)], dim=0 - ), - } - - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(model_path, filename)) - - # Write configs - index_dict["metadata"] = {"total_size": param_count * 2} - write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json")) - ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 - multiple_of = params["multiple_of"] if "multiple_of" in params else 256 - config = ModelConfig( - hidden_size=dim, - intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), - num_attention_heads=params["n_heads"], - num_hidden_layers=params["n_layers"], - rms_norm_eps=params["norm_eps"], - num_key_value_heads=num_key_value_heads, - vocab_size=vocab_size, - rope_theta=base, - 
max_position_embeddings=max_position_embeddings, - ) - config.save_pretrained(model_path) - - # Make space so we can load the model properly now. - del state_dict - del loaded - gc.collect() - - return model_path - - -def load_args_from_checkpoint(args, model_size): - - # Read Llama args. - model_args_path = os.path.join(args.load, "config.json") - with open(model_args_path) as f: - model_args = json.load(f) - - # Update Megatron args. - args.seq_length = 4096 - if "llama2" in model_size: - # Correct bug in earlier conversion script. - args.max_position_embeddings = 4096 - else: - args.max_position_embeddings = model_args["max_position_embeddings"] - - args.hidden_size = model_args["hidden_size"] - args.num_attention_heads = model_args["num_attention_heads"] - args.num_layers = model_args["num_hidden_layers"] - args.global_batch_size = 1024 - args.norm_epsilon = model_args["rms_norm_eps"] - args.iteration = 1 # '0', 'release' don't work - args.position_embedding_type = "rope" - args.swiglu = True - args.normalization = "RMSNorm" - args.add_bias_linear = False - args.untie_embeddings_and_output_weights = not model_args.get("tie_word_embeddings", False) - args.vocab_size = model_args["vocab_size"] - args.padded_vocab_size = model_args["vocab_size"] - args.ffn_hidden_size = model_args["intermediate_size"] - - if "num_key_value_heads" in model_args: - args.group_query_attention = True - args.num_query_groups = model_args["num_key_value_heads"] - - -def set_preprocess_state(args, model, hf_model): - '''Set embedding params.''' - model.language_model.embedding.word_embeddings.weight.data.copy_( - hf_model.model.embed_tokens.weight - ) - - -def set_postprocess_state(args, model, hf_model): - '''Set output layer & norm params.''' - model.language_model.encoder.final_norm.weight.data.copy_(hf_model.model.norm.weight) - if args.untie_embeddings_and_output_weights: - model.language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - - -def set_attn_state(args, layer, hf_layer): - '''Set self-attention params.''' - - # Get attention layer & state. - attn = layer.self_attention - hf_attn = hf_layer.self_attn - - # Reshape loaded weights. - tp = args.tensor_model_parallel_size - nh = args.num_attention_heads // tp - ng = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) // tp - dim = args.kv_channels - assert nh % ng == 0 - - # Copy weights (re-order dimensions for Megatron). 
- attn.query_key_value.weight.data.copy_( - torch.cat( - [ - hf_attn.q_proj.weight.reshape((ng, dim * nh // ng, -1)), - hf_attn.k_proj.weight.reshape((ng, dim, -1)), - hf_attn.v_proj.weight.reshape((ng, dim, -1)), - ], - dim=1, - ).reshape((-1, args.hidden_size)) - ) - if args.add_qkv_bias: - attn.query_key_value.bias.data.copy_( - torch.cat( - [ - hf_attn.q_proj.bias.reshape((ng, dim * nh // ng)), - hf_attn.k_proj.bias.reshape((ng, dim)), - hf_attn.v_proj.bias.reshape((ng, dim)), - ], - dim=1, - ).reshape(-1) - ) - - attn.dense.weight.data.copy_(hf_attn.o_proj.weight) - - -def set_mlp_state(args, layer, hf_layer): - '''Set MLP params.''' - - mlp = layer.mlp - hf_mlp = hf_layer.mlp - - mlp.dense_h_to_4h.weight.data.copy_( - torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0) - ) - mlp.dense_4h_to_h.weight.data.copy_(hf_mlp.down_proj.weight) - - -def set_layer_state(args, model, hf_model, layer_idx): - '''Set transformer layer params.''' - - layer = model.language_model.encoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - - set_attn_state(args, layer, hf_layer) - set_mlp_state(args, layer, hf_layer) - layer.input_norm.weight.data.copy_(hf_layer.input_layernorm.weight) - layer.post_attention_norm.weight.data.copy_(hf_layer.post_attention_layernorm.weight) - - -def load_checkpoint_to_model(args): - '''Set model params.''' - - from transformers import AutoModelForCausalLM - - from gpt_builders import gpt_builder - from model_provider import model_provider - - # Load Huggingface model. - hf_model = AutoModelForCausalLM.from_pretrained( - args.load, torch_dtype=args.params_dtype, low_cpu_mem_usage=True, device_map="cpu" - ) - - # Init Megatron model. - model = model_provider(gpt_builder, pre_process=True, post_process=True).to(args.params_dtype) - - # Set model state. - set_preprocess_state(args, model, hf_model) - set_postprocess_state(args, model, hf_model) - for layer_idx in tqdm(range(args.num_layers), "set layer states"): - set_layer_state(args, model, hf_model, layer_idx) - - return model - - -def _load_checkpoint(queue, args): - - verify_transformers_version() - - # Search in directory above this. - sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) - ) - if args.megatron_path is not None: - sys.path.insert(0, args.megatron_path) - - # Convert Meta checkpoint to HF format as an intermediate step - if args.checkpoint_type == "meta": - model_tmp_path = convert_to_hf( - model_path=os.path.join(args.save_dir, 'tmp'), - input_base_path=args.load_dir, - model_size=args.model_size, - tokenizer_path=args.tokenizer_model, - ) - args.load_dir = model_tmp_path - args.tokenizer_model = model_tmp_path # point to HF tokenizer model - - try: - from megatron.core import mpu - from megatron.core.enums import ModelType - from megatron.legacy.model import module - from megatron.training.arguments import parse_args, validate_args - from megatron.training.global_vars import set_args, set_global_variables - except ModuleNotFoundError: - print( - "Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting." - ) - queue.put("exit") - exit(1) - - # We want all arguments to come from us. 
- sys.argv = [ - 'script.py', - '--no-masked-softmax-fusion', - '--no-bias-gelu-fusion', - '--no-bias-dropout-fusion', - '--use-cpu-initialization', - '--micro-batch-size', - '1', - '--no-load-optim', - '--no-load-rng', - '--no-save-optim', - '--no-save-rng', - '--mock-data', # To pass the "blend data checks" in arguments.py - '--no-initialization', - '--load', - args.load_dir, - '--no-one-logger', - ] - - if args.make_vocab_size_divisible_by is not None: - sys.argv.extend(["--make-vocab-size-divisible-by", str(args.make_vocab_size_divisible_by)]) - - margs = parse_args() - margs.tokenizer_model = args.tokenizer_model - load_args_from_checkpoint(margs, args.model_size) - - if "llama2" in args.model_size: - margs.tokenizer_type = "Llama2Tokenizer" - elif "yi" in args.model_size: - margs.tokenizer_type = "HuggingFaceTokenizer" - elif "llama3" in args.model_size: - margs.tokenizer_type = "HuggingFaceTokenizer" - elif "mistral" in args.model_size: - margs.tokenizer_type = "HuggingFaceTokenizer" - elif "qwen2.5" in args.model_size: - margs.tokenizer_type = "HuggingFaceTokenizer" - margs.add_qkv_bias = True - - # Arguments do sanity checks on the world size, but we don't care, - # so trick it into thinking we are plenty of processes. - margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size - - margs = validate_args(margs) - - margs.use_legacy_models = True - margs.transformer_impl = args.loader_transformer_impl - - margs.position_embedding_type = "rope" - - def check_for_arg(arg_name, default=None): - if getattr(margs, arg_name, None) is None: - if default is not None: - setattr(margs, arg_name, default) - else: - print(f"Checkpoint does not specify the argument {arg_name}. Exiting.") - print(f"Arguments: {margs}") - queue.put("exit") - exit(1) - - check_for_arg('tensor_model_parallel_size') - check_for_arg('pipeline_model_parallel_size') - check_for_arg('num_layers') - check_for_arg('hidden_size') - check_for_arg('seq_length') - check_for_arg('num_attention_heads') - check_for_arg('max_position_embeddings') - check_for_arg('position_embedding_type') - check_for_arg('iteration') - check_for_arg('bert_binary_head') - check_for_arg('disable_bias_linear', False) - check_for_arg('params_dtype') - check_for_arg('swiglu', False) - - # Determine how to make our models. - assert args.model_type == 'GPT', 'Llama-2, Llama-3 and Mistral are GPT models.' - margs.model_type = ModelType.encoder_or_decoder - margs.params_dtype = ( - torch.bfloat16 if args.bf16 else torch.float16 if args.fp16 else torch.float32 - ) - - # Suppress warning about torch.distributed not being initialized. - module.MegatronModule.embedding_warning_printed = True - - set_global_variables(margs, build_tokenizer=False) - mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) - mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) - mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) - - # For backward compatibility during local parallel states refactoring - fake_tp_group = _ConverterFakeProcessGroup(size=margs.tensor_model_parallel_size) - fake_ep_group = _ConverterFakeProcessGroup(size=margs.expert_model_parallel_size) - mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group - mpu._EXPERT_MODEL_PARALLEL_GROUP = fake_ep_group - - # Short aliases. 
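The metadata namespace assembled just below is the first message the saver process receives; every weight after it arrives as a named dict message, and the stream ends with a "done" sentinel. A minimal sketch of that handshake, with made-up field names and a trivial payload:

import multiprocessing
import types

def toy_loader(queue):
    md = types.SimpleNamespace(num_layers=2, hidden_size=8)      # run metadata goes first
    queue.put(md)
    queue.put({"name": "embeddings", "word embeddings": [0.0] * 8})
    queue.put("done")                                             # sentinel the saver stops on

def toy_saver(queue):
    md = queue.get()
    while True:
        msg = queue.get()
        if msg == "done":
            break
        print("received", msg["name"])

if __name__ == "__main__":
    q = multiprocessing.Queue()
    toy_loader(q)
    toy_saver(q)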
- tp_size = margs.tensor_model_parallel_size - pp_size = margs.pipeline_model_parallel_size - vp_size = margs.virtual_pipeline_model_parallel_size - if vp_size is None: - vp_size = 1 - - # Metadata. - md = types.SimpleNamespace() - md.model_type = args.model_type - md.num_layers = margs.num_layers - md.hidden_size = margs.hidden_size - md.seq_length = margs.seq_length - md.num_attention_heads = margs.num_attention_heads - md.max_position_embeddings = margs.max_position_embeddings - md.tokenizer_type = margs.tokenizer_type - md.iteration = margs.iteration - md.params_dtype = margs.params_dtype - md.bert_binary_head = margs.bert_binary_head - md.output_layer = margs.untie_embeddings_and_output_weights - md.position_embedding_type = margs.position_embedding_type - md.linear_bias = margs.add_bias_linear - md.qkv_bias = margs.add_qkv_bias - md.norm_has_bias = False - md.swiglu = margs.swiglu - md.previous_tensor_parallel_size = margs.tensor_model_parallel_size - md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size - md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by - md.checkpoint_args = margs - md.consumed_train_samples = 0 - md.consumed_valid_samples = 0 - - margs.model_size = args.model_size - - # Get true (non-padded) vocab size - tokenizer = transformers.AutoTokenizer.from_pretrained(margs.tokenizer_model) - md.true_vocab_size = tokenizer._tokenizer.get_vocab_size(with_added_tokens=True) - - # Get first pipe stage. - mpu.set_tensor_model_parallel_rank(0) - mpu.set_pipeline_model_parallel_rank(0) - model = load_checkpoint_to_model(margs) - - queue.put(md) - - def queue_put(name, msg): - print(f"sending {name}") - msg["name"] = name - queue.put(msg) - - # Send embeddings. - message = {"word embeddings": model.language_model.embedding.word_embeddings.weight.data} - if md.position_embedding_type == 'learned_absolute': - message["position embeddings"] = ( - model.language_model.embedding.position_embeddings.weight.data - ) - else: - assert not hasattr(model.language_model.embedding, 'position_embeddings') - - queue_put("embeddings", message) - - for layer_num in range(margs.num_layers): - message = {} - - # Get non-parallel tensors from tp_rank 0. - layer = model.language_model.encoder.layers[layer_num] - message["input norm weight"] = layer.input_norm.weight.data - message["post norm weight"] = layer.post_attention_norm.weight.data - if md.linear_bias: - message["dense bias"] = layer.self_attention.dense.bias.data - message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data - - # Grab all parallel tensors for this layer. - qkv_weight = [] - qkv_bias = [] - dense_weight = [] - mlp_l0_weight = [] - mlp_l0_bias = [] - mlp_l1_weight = [] - layer = model.language_model.encoder.layers[layer_num] - qkv_weight.append(layer.self_attention.query_key_value.weight.data) - dense_weight.append(layer.self_attention.dense.weight.data) - mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data) - mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data) - - if md.qkv_bias: - qkv_bias.append(layer.self_attention.query_key_value.bias.data) - if md.linear_bias: - mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) - - # Handle gated linear units. - if md.swiglu: - # Concat all the first halves ('W's) and all the second halves ('V's). 
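Before the per-rank splitting below, a minimal sketch (illustrative sizes, generic tensor names) of why each tensor-parallel shard has to be chunked into its own W and V halves before the halves are concatenated across ranks:

import torch

ffn, hidden, tp = 8, 4, 2                                   # illustrative sizes only
# Each tensor-parallel rank stores its SwiGLU weight as [W_r; V_r] stacked on dim 0.
per_rank = [torch.randn(2 * ffn // tp, hidden) for _ in range(tp)]

halves = [torch.chunk(w, 2, dim=0) for w in per_rank]       # (W_r, V_r) per rank
full_w = torch.cat([h[0] for h in halves], dim=0)           # all W shards, rank order
full_v = torch.cat([h[1] for h in halves], dim=0)           # all V shards, rank order
assert full_w.shape == full_v.shape == (ffn, hidden)

# Chunking the already-concatenated tensor in half instead would interleave W and V
# shards across ranks and corrupt the gate/up projections.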
- for tp_rank in range(tp_size): - mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) - message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) - message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) - else: - message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) - - # Simple concat of the rest. - message["qkv weight"] = torch.cat(qkv_weight, dim=0) - message["dense weight"] = torch.cat(dense_weight, dim=1) - message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) - if md.qkv_bias: - message["qkv bias"] = torch.cat(qkv_bias, dim=0) - if md.linear_bias: - if md.swiglu: - for tp_rank in range(tp_size): - mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) - message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0) - message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0) - else: - message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) - - queue_put(f"transformer layer {layer_num}", message) - - # Send final norm from tp_rank 0. - message = {"weight": model.language_model.encoder.final_norm.weight.data} - queue_put("final norm", message) - - if md.output_layer: - message = {"weight": model.language_model.output_layer.weight.data} - queue_put("output layer", message) - - queue.put("done") - - if args.checkpoint_type == "meta": - shutil.rmtree(os.path.join(args.load_dir)) - - -def load_checkpoint(queue, args): - try: - _load_checkpoint(queue, args) - except Exception: - queue.put("exit") - raise diff --git a/tools/checkpoint/loader_mixtral_hf.py b/tools/checkpoint/loader_mixtral_hf.py index 8c57a9737c8..3774ac2e5e6 100644 --- a/tools/checkpoint/loader_mixtral_hf.py +++ b/tools/checkpoint/loader_mixtral_hf.py @@ -183,7 +183,7 @@ def _load_checkpoint(queue, args): try: from megatron.core import mpu from megatron.core.enums import ModelType - from megatron.legacy.model import module + from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.training.arguments import parse_args, validate_args from megatron.training.global_vars import set_args, set_global_variables except ModuleNotFoundError: @@ -256,7 +256,7 @@ def check_for_arg(arg_name, default=None): margs.model_type = ModelType.encoder_or_decoder # Suppress warning about torch.distributed not being initialized. - module.MegatronModule.embedding_warning_printed = True + LanguageModule.embedding_warning_printed = True set_global_variables(margs, build_tokenizer=False) mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) diff --git a/tools/checkpoint/remap_gpt_dsa_to_mamba.py b/tools/checkpoint/remap_gpt_dsa_to_mamba.py index e5acbb09673..cbeec57c23d 100644 --- a/tools/checkpoint/remap_gpt_dsa_to_mamba.py +++ b/tools/checkpoint/remap_gpt_dsa_to_mamba.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -"""Convert a GPTModel DSA checkpoint to a MambaModel-compatible checkpoint. +"""Convert a GPTModel DSA checkpoint to a HybridModel-compatible checkpoint. A GPTModel with ``--experimental-attention-variant dsa`` uses one combined -TransformerLayer per model layer (attention + MLP). The equivalent MambaModel +TransformerLayer per model layer (attention + MLP). 
The equivalent HybridModel with pattern ``D-D-...`` stores them as two separate layers: * Layer 2N – DSA attention (TransformerLayer: input_layernorm + MLASelfAttention) * Layer 2N+1 – MLP (MLPLayer: fused-norm MLP) This script loads a GPTModel Distributed Checkpoint (DCP), remaps the state-dict -keys, and saves a new DCP that can be loaded by MambaModel. +keys, and saves a new DCP that can be loaded by HybridModel. Usage ----- @@ -43,14 +43,14 @@ def _remap_key(key: str, num_gpt_layers: int) -> str: - """Return the MambaModel state-dict key corresponding to *key* from GPTModel. + """Return the HybridModel state-dict key corresponding to *key* from GPTModel. Args: key: A key from the GPTModel state dict. num_gpt_layers: Total number of GPT decoder layers (across all PP stages). Returns: - The remapped key for MambaModel. + The remapped key for HybridModel. Raises: ValueError: If an unexpected sub-key is encountered in a decoder layer. @@ -58,7 +58,7 @@ def _remap_key(key: str, num_gpt_layers: int) -> str: layer_prefix = "decoder.layers." final_ln_prefix = "decoder.final_layernorm." - # Final layernorm name differs between TransformerBlock and MambaStack + # Final layernorm name differs between TransformerBlock and HybridStack if key.startswith(final_ln_prefix): return "decoder.final_norm." + key[len(final_ln_prefix) :] @@ -96,11 +96,11 @@ def _remap_state_dict(gpt_sd: Dict, num_gpt_layers: int) -> Dict: def convert(input_path: Path, output_path: Path, num_gpt_layers: int) -> None: - """Load a GPTModel DCP checkpoint, remap keys, and save as MambaModel DCP. + """Load a GPTModel DCP checkpoint, remap keys, and save as HybridModel DCP. Args: input_path: Path to the GPTModel DCP checkpoint directory. - output_path: Destination directory for the MambaModel DCP checkpoint. + output_path: Destination directory for the HybridModel DCP checkpoint. num_gpt_layers: Number of GPT decoder layers in the original model. """ try: @@ -134,7 +134,7 @@ def convert(input_path: Path, output_path: Path, num_gpt_layers: int) -> None: output_path.mkdir(parents=True, exist_ok=True) torch_save_to_dcp(str(tmp_mamba), str(output_path)) - print(f"MambaModel DCP checkpoint saved to: {output_path}") + print(f"HybridModel DCP checkpoint saved to: {output_path}") finally: for tmp in (tmp_flat, output_path.parent / "_tmp_mamba_flat.pt"): @@ -144,7 +144,7 @@ def convert(input_path: Path, output_path: Path, num_gpt_layers: int) -> None: def main() -> None: parser = argparse.ArgumentParser( - description="Convert GPTModel DSA checkpoint to MambaModel-compatible format." + description="Convert GPTModel DSA checkpoint to HybridModel-compatible format." ) parser.add_argument( "--input", @@ -156,7 +156,7 @@ def main() -> None: "--output", required=True, type=Path, - help="Destination path for the MambaModel DCP checkpoint.", + help="Destination path for the HybridModel DCP checkpoint.", ) parser.add_argument( "--num-gpt-layers", diff --git a/tools/checkpoint/saver_base.py b/tools/checkpoint/saver_base.py index b67d75a287a..bfddc5b2f83 100644 --- a/tools/checkpoint/saver_base.py +++ b/tools/checkpoint/saver_base.py @@ -145,7 +145,6 @@ def parse_megatron_args(self): validate_args(margs) # Use M-core models & unset loaded paths. 
- margs.use_legacy_models = False margs.blendable_index_path = None margs.data_path = [] margs.load = None diff --git a/tools/checkpoint/saver_legacy.py b/tools/checkpoint/saver_legacy.py deleted file mode 100644 index e0e79dba3f4..00000000000 --- a/tools/checkpoint/saver_legacy.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import os -import sys -import torch - -from functools import partial - -from tools.checkpoint.utils import _ConverterFakeProcessGroup - - -def add_arguments(parser): - group = parser.add_argument_group(title='Megatron saver') - - group.add_argument('--megatron-path', type=str, default=None, - help='Base directory of Megatron repository') - - group.add_argument('--target-tensor-parallel-size', type=int, - help='Target tensor model parallel size, defaults to the tensor parallel size ' - 'in the input checkpoint if provided by the loader, otherwise to 1') - group.add_argument('--target-pipeline-parallel-size', type=int, - help='Target tensor model parallel size, default to the pipeline parall size ' - 'in the input checkpoint if provided by the loader, otherwise to 1') - group.add_argument('--saver-transformer-impl', default='local', - choices=['local', 'transformer_engine'], - help='Which Transformer implementation to use.') - -def save_checkpoint(queue, args): - # Search in directory above this - sys.path.append(os.path.abspath( - os.path.join(os.path.dirname(__file__), - os.path.pardir, - os.path.pardir))) - if args.megatron_path is not None: - sys.path.insert(0, args.megatron_path) - - try: - from megatron.training.arguments import (parse_args, validate_args) - from megatron.training.checkpointing import save_checkpoint - from megatron.training.global_vars import set_global_variables, get_args - from megatron.core.enums import ModelType - from megatron.core.tokenizers.utils.build_tokenizer import vocab_size_with_padding - from megatron.core import mpu - except ModuleNotFoundError: - print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") - exit(1) - - def queue_get(name=None): - val = queue.get() - if val == "exit": - print("Loader exited, exiting saver") - exit(1) - if name is not None and args.checking and val["name"] != name: - val_name = val["name"] - print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.') - exit(1) - if name is not None: - print(f"received {name}") - return val - - def check_message(msg): - if not args.checking: - return - msg_name = msg.pop("name") - if len(msg.keys()) > 0: - print(f"Unexpected values in {msg_name}:") - for key in msg.keys(): - print(f" {key}") - print(f"Exiting. If you want to ignore this, use the argument --no-checking.") - exit(1) - - md = queue_get() - - if args.target_tensor_parallel_size is None: - if hasattr(md, 'previous_tensor_parallel_size'): - args.target_tensor_parallel_size = md.previous_tensor_parallel_size - else: - print( - "loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. " - "Default to 1.") - args.target_tensor_parallel_size = 1 - - if args.target_pipeline_parallel_size is None: - if hasattr(md, 'previous_pipeline_parallel_size'): - args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size - else: - print( - "loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. 
" - "Default to 1.") - args.target_pipeline_parallel_size = 1 - - # Arguments do sanity checks on the world size, but we don't care, - # so trick it into thinking we are plenty of processes - if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None: - os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}' - - # We want all arguments to come from us - sys.argv = ['script.py', - '--num-layers', str(md.num_layers), - '--hidden-size', str(md.hidden_size), - '--seq-length', str(md.seq_length), - '--num-attention-heads', str(md.num_attention_heads), - '--max-position-embeddings', str(md.max_position_embeddings), - '--position-embedding-type', str(md.position_embedding_type), - '--tokenizer-type', str(md.tokenizer_type), - '--tensor-model-parallel-size', str(args.target_tensor_parallel_size), - '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size), - '--no-masked-softmax-fusion', - '--no-bias-gelu-fusion', - '--no-bias-dropout-fusion', - '--use-cpu-initialization', - '--micro-batch-size', '1', - '--no-load-optim', - '--no-load-rng', - '--no-save-optim', - '--no-save-rng', - '--no-initialization', - '--save-interval', '1', - '--save', args.save_dir, - '--ckpt-format', 'torch', # only 'torch' supported for conversion - '--no-one-logger', - ] - - if md.make_vocab_size_divisible_by is not None: - sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)]) - if md.params_dtype == torch.float16: - sys.argv.append('--fp16') - elif md.params_dtype == torch.bfloat16: - sys.argv.append('--bf16') - - if md.output_layer: - sys.argv.append('--untie-embeddings-and-output-weights') - if not md.linear_bias: - sys.argv.append('--disable-bias-linear') - - if md.model_type == 'BERT' and not md.bert_binary_head: - sys.argv.append('--bert-no-binary-head') - - margs = parse_args() - - if hasattr(md, 'checkpoint_args'): - # These are arguments that we are either changing, or cause problems for validation if they are set - # Note that some of these deal with T5 so will need to be changed if we support T5. - args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype', - 'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size', - 'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion', - 'sequence_parallel', - 'no_load_optim', 'no_load_rng', 'no_save_optim', 'no_save_rng', - 'vocab_file', 'tokenizer_model', - 'save_interval', 'save', - 'perform_initialization', 'use_cpu_initialization', - 'recompute_granularity', 'recompute_num_layers', 'recompute_method', - 'encoder_num_layers', 'encoder_seq_length', - 'distribute_saved_activations', - 'train_iters', 'lr_decay_iters', 'lr_warmup_iters', 'lr_warmup_fraction', - 'start_weight_decay', 'end_weight_decay', 'bf16', 'fp16', - 'ckpt_format', - ] - - for arg, value in vars(md.checkpoint_args).items(): - if arg in args_to_keep: - continue - if not hasattr(margs, arg): - print(f"Checkpoint had argument {arg} but new arguments does not have this.") - continue - if getattr(margs, arg) != value: - print(f"Overwriting default {arg} value {getattr(margs, arg)} with value from checkpoint {value}.") - setattr(margs, arg, value) - - margs.inference_batch_times_seqlen_threshold = -1 - - validate_args(margs) - - # Use MLM models. 
- margs.use_legacy_models = True - margs.transformer_impl = args.saver_transformer_impl - - # Do not instantiate Tensorboard - margs.tensorboard_dir = None - - set_global_variables(margs, build_tokenizer=False) - - # margs = megatron args - margs = get_args() - - if hasattr(md, 'consumed_train_samples'): - margs.consumed_train_samples = md.consumed_train_samples - margs.consumed_valid_samples = md.consumed_valid_samples - print(f"Setting consumed_train_samples to {margs.consumed_train_samples}" - f" and consumed_valid_samples to {margs.consumed_valid_samples}") - else: - print("consumed_train_samples not provided.") - - # Determine how to make our models - if md.model_type == 'GPT': - from model_provider import model_provider as common_model_provider - from gpt_builders import gpt_builder - model_provider = partial(common_model_provider, gpt_builder) - margs.model_type = ModelType.encoder_or_decoder - elif md.model_type == 'BERT': - from pretrain_bert import model_provider - margs.model_type = ModelType.encoder_or_decoder - else: - raise Exception(f'unrecognized model type: {args.model_type}') - - def get_models(count, dtype, pre_process, post_process): - models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)] - return models - - # fake initializing distributed - mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size) - mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size) - mpu.set_tensor_model_parallel_rank(0) - mpu.set_pipeline_model_parallel_rank(0) - - # For backward compatibility during local parallel states refactoring - fake_tp_group = _ConverterFakeProcessGroup(size=args.target_tensor_parallel_size) - mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group - - # Embeddings - # ----------- - embeddings_msg = queue_get("embeddings") - - pos_embed = None - if md.position_embedding_type == 'learned_absolute': - pos_embed = embeddings_msg.pop("position embeddings") - orig_word_embed = embeddings_msg.pop("word embeddings") - check_message(embeddings_msg) - - # Deal with padding - if md.true_vocab_size is not None: - # figure out what our padded vocab size is - orig_vocab_size = orig_word_embed.shape[0] - margs.padded_vocab_size = vocab_size_with_padding(md.true_vocab_size, margs) - - # Cut out extra padding we don't need - if orig_vocab_size > margs.padded_vocab_size: - full_word_embed = orig_word_embed[0:margs.padded_vocab_size, :] - - # Expanding embedding to larger size by replicating final entry - elif orig_vocab_size < margs.padded_vocab_size: - padding_size = margs.padded_vocab_size - orig_vocab_size - - full_word_embed = torch.cat(( - orig_word_embed, - orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1))) - - # Same size! - else: - full_word_embed = orig_word_embed - else: - print("Original vocab size not specified, leaving embedding table as-is. 
" - "If you've changed the tensor parallel size this could cause problems.") - margs.padded_vocab_size = orig_word_embed.shape[0] - full_word_embed = orig_word_embed - - # Split into new tensor model parallel sizes - out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0) - - # Make models for first pipeline stage and fill in embeddings - mpu.set_pipeline_model_parallel_rank(0) - post_process = args.target_pipeline_parallel_size == 1 - models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process) - for tp_rank, model in enumerate(models): - model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) - if pos_embed is not None: - model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed) - else: - assert not hasattr(model.language_model.embedding, "position_embeddings") - - # Transformer layers - # ------------------- - total_layer_num = 0 - for pp_rank in range(args.target_pipeline_parallel_size): - # For later pipeline parallel ranks, make the new models - if pp_rank > 0: - mpu.set_pipeline_model_parallel_rank(pp_rank) - post_process = pp_rank == args.target_pipeline_parallel_size - 1 - models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process) - - for layer in range(len(models[0].language_model.encoder.layers)): - msg = queue_get(f"transformer layer {total_layer_num}") - - # duplicated tensors - input_norm_weight = msg.pop("input norm weight") - if md.norm_has_bias: - input_norm_bias = msg.pop("input norm bias") - post_norm_weight = msg.pop("post norm weight") - if md.norm_has_bias: - post_norm_bias = msg.pop("post norm bias") - if md.linear_bias: - dense_bias = msg.pop("dense bias") - mlp_l1_bias = msg.pop("mlp l1 bias") - - # Split up the parallel tensors - qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0) - dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1) - mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1) - - # Special handling for swiglu - if md.swiglu: - mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0) - mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0) - mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)] - else: - mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0) - - if md.qkv_bias: - qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0) - if md.linear_bias: - if md.swiglu: - mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0) - mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0) - mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)] - else: - mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0) - - # Save them to the model - for tp_rank in range(args.target_tensor_parallel_size): - l = models[tp_rank].language_model.encoder.layers[layer] - l.input_norm.weight.data.copy_(input_norm_weight) - if md.norm_has_bias: - l.input_norm.bias.data.copy_(input_norm_bias) - l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) - l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank]) - 
l.post_attention_norm.weight.data.copy_(post_norm_weight) - if md.norm_has_bias: - l.post_attention_norm.bias.data.copy_(post_norm_bias) - l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) - l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) - if md.qkv_bias: - l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank]) - if md.linear_bias: - l.self_attention.dense.bias.data.copy_(dense_bias) - l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank]) - l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias) - - total_layer_num = total_layer_num + 1 - check_message(msg) - - if post_process: - msg = queue_get("final norm") - final_norm_weight = msg.pop("weight") - if md.norm_has_bias: - final_norm_bias = msg.pop("bias") - for tp_rank in range(args.target_tensor_parallel_size): - models[tp_rank].language_model.encoder.final_norm.weight.data.copy_(final_norm_weight) - if md.norm_has_bias: - models[tp_rank].language_model.encoder.final_norm.bias.data.copy_(final_norm_bias) - if pp_rank != 0 and not md.output_layer: - # Copy word embeddings to final pipeline rank - models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) - del final_norm_weight - if md.norm_has_bias: - del final_norm_bias - check_message(msg) - - if md.output_layer: - msg = queue_get("output layer") - if not hasattr(models[0].language_model, 'output_layer'): - print("ERROR: got an output layer, but model does not have one") - exit(1) - output_layer_weight = torch.chunk(msg.pop("weight"), args.target_tensor_parallel_size, dim=0) - for tp_rank in range(args.target_tensor_parallel_size): - models[tp_rank].language_model.output_layer.weight.data.copy_(output_layer_weight[tp_rank]) - del output_layer_weight - check_message(msg) - - msg = queue_get() - if msg != "done" and msg["name"] == "pooler": - if not hasattr(models[0].language_model, 'pooler'): - print("ERROR: got a pooler, but model does not have one") - exit(1) - print("received pooler") - pooler_weight = msg.pop("weight") - pooler_bias = msg.pop("bias") - for tp_rank in range(args.target_tensor_parallel_size): - models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight) - models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias) - del pooler_weight - del pooler_bias - check_message(msg) - msg = queue_get() - - if msg != "done" and msg["name"] == "lm head": - if not hasattr(models[0], 'lm_head'): - print("ERROR: got an lm head, but model does not have one") - exit(1) - print("received lm head") - lm_head_dense_weight = msg.pop("dense weight") - lm_head_dense_bias = msg.pop("dense bias") - lm_head_norm_weight = msg.pop("norm weight") - if md.norm_has_bias: - lm_head_norm_bias = msg.pop("norm bias") - for tp_rank in range(args.target_tensor_parallel_size): - models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight) - models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias) - models[tp_rank].lm_head.norm.weight.data.copy_(lm_head_norm_weight) - if md.norm_has_bias: - models[tp_rank].lm_head.norm.bias.data.copy_(lm_head_norm_bias) - check_message(msg) - msg = queue_get() - - if msg != "done" and msg["name"] == "binary head": - if not hasattr(models[0], 'binary_head'): - print("ERROR: got a binary head, but model does not have one") - exit(1) - print("received binary head") - binary_head_weight = msg.pop("weight") - binary_head_bias = msg.pop("bias") - for tp_rank in range(args.target_tensor_parallel_size): - models[tp_rank].binary_head.weight.data.copy_(binary_head_weight) - 
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias) - check_message(msg) - msg = queue_get() - - if msg != "done": - print("ERROR: got some more data but was expecting to be done") - - for tp_rank in range(args.target_tensor_parallel_size): - fake_tp_group = mpu.get_tensor_model_parallel_group() - fake_tp_group.set_rank(tp_rank) - mpu.set_tensor_model_parallel_rank(tp_rank) - save_checkpoint(md.iteration, [models[tp_rank]], None, None, - num_floating_point_operations_so_far=0) - print("Done!") diff --git a/tools/prepare_cache.py b/tools/prepare_cache.py new file mode 100644 index 00000000000..a6cb3b7e795 --- /dev/null +++ b/tools/prepare_cache.py @@ -0,0 +1,209 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Prepare GPT dataset caches ahead of training. + +Unsupported configurations: + --mock-data, --sft, --fim-data, --step-batch-size-schedule +""" + +import argparse +import json +from typing import Any, Dict, List, Optional, Tuple + +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig +from megatron.core.datasets.utils import compile_helpers +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer +from megatron.training import get_train_valid_test_num_samples +from megatron.training.arguments import parse_args, validate_args +from megatron.training.global_vars import set_args, unset_global_variables +from megatron.training.training import update_train_iters +from megatron.training.utils import get_blend_and_blend_per_split + +try: + from megatron.post_training.arguments import add_modelopt_args + + has_nvidia_modelopt = True +except ImportError: + has_nvidia_modelopt = False + + +def add_prepare_cache_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add cache-preparation specific arguments.""" + + group = parser.add_argument_group(title="prepare cache") + group.add_argument( + "--prepare-cache-world-size", + type=int, + default=None, + help=( + "Optional override for the effective world size used to derive data-parallel size and " + "dataset sample counts during cache preparation." 
+ ), + ) + return parser + + +def _extra_args_provider(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = add_prepare_cache_args(parser) + if has_nvidia_modelopt: + parser = add_modelopt_args(parser) + return parser + + +def _normalize_prepare_cache_args(args: Any) -> None: + """Apply cache-preparation specific argument normalization.""" + + args.rank = 0 + + if args.prepare_cache_world_size is not None: + if args.prepare_cache_world_size <= 0: + raise ValueError("--prepare-cache-world-size must be positive") + args.world_size = args.prepare_cache_world_size + + +def _validate_prepare_cache_args(args: Any) -> None: + """Validate options that are intentionally unsupported for offline cache prep.""" + + if args.data_cache_path is None: + raise ValueError("--data-cache-path must be provided for cache preparation") + if args.mock_data: + raise ValueError("--mock-data is not supported by tools/prepare_cache.py") + if getattr(args, "sft", False): + raise ValueError("--sft is not supported by tools/prepare_cache.py") + if getattr(args, "fim_data", False): + raise ValueError("--fim-data is not supported by tools/prepare_cache.py") + if getattr(args, "step_batch_size_schedule", None) is not None: + raise ValueError("--step-batch-size-schedule is not supported by tools/prepare_cache.py") + + +def _disable_cache_load_only_flags(args: Any) -> Dict[str, bool]: + """Disable flags that only make sense when consuming an existing cache.""" + + ignored = { + "dataloader_fast_cache_load": bool(args.dataloader_fast_cache_load), + "dataloader_defer_npy_index_mmap": bool(args.dataloader_defer_npy_index_mmap), + } + args.dataloader_fast_cache_load = False + args.dataloader_defer_npy_index_mmap = False + return ignored + + +def _get_dataset_length(dataset: Optional[Any]) -> Optional[Any]: + if dataset is None: + return None + if isinstance(dataset, list): + return [len(ds) if ds is not None else None for ds in dataset] + return len(dataset) + + +def _print_effective_configuration( + args: Any, train_valid_test_num_samples: Any, ignored_flags: Dict[str, bool] +) -> None: + print("> preparing dataset cache with the following effective values:") + print(f" world size: {args.world_size}") + print(f" data parallel size: {args.data_parallel_size}") + print(f" global batch size: {args.global_batch_size}") + print(f" cache path: {args.data_cache_path}") + print(" > datasets target sizes (minimum size):") + print(f" train: {train_valid_test_num_samples[0]}") + print(f" validation: {train_valid_test_num_samples[1]}") + print(f" test: {train_valid_test_num_samples[2]}") + if ignored_flags["dataloader_fast_cache_load"]: + print("> ignoring --dataloader-fast-cache-load during cache preparation") + if ignored_flags["dataloader_defer_npy_index_mmap"]: + print("> ignoring --dataloader-defer-npy-index-mmap during cache preparation") + + +def core_gpt_dataset_config_from_args(args: Any) -> GPTDatasetConfig: + """Build the explicit GPTDatasetConfig used for offline cache preparation.""" + + tokenizer = build_tokenizer(args) + + blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + sequences_per_dataset = None + if args.per_dataset_sequences_path is not None: + with open(args.per_dataset_sequences_path, "r") as f: + sequences_per_dataset = json.load(f) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + 
blend_per_split=blend_per_split, + split=args.split, + multiple_validation_sets=args.multiple_validation_sets, + full_validation=args.full_validation, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + object_storage_cache_path=args.object_storage_cache_path, + mid_level_dataset_surplus=args.mid_level_dataset_surplus, + allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, + fast_cache_load=args.dataloader_fast_cache_load, + sequences_per_dataset=sequences_per_dataset, + defer_npy_index_mmap=args.dataloader_defer_npy_index_mmap, + context_parallel_size=args.context_parallel_size, + data_parallel_size=args.data_parallel_size, + sequence_parallel_size=args.tensor_model_parallel_size * args.sequence_parallel, + dynamic_context_parallel=args.dynamic_context_parallel, + ) + + +def build_dataset_caches(args: Any) -> Dict[str, Any]: + """Build the dataset caches for the plain GPTDataset path.""" + + _validate_prepare_cache_args(args) + ignored_flags = _disable_cache_load_only_flags(args) + + unset_global_variables() + set_args(args) + + try: + # Derive train_iters from --train-samples when needed (pretrain() does the same). + update_train_iters(args) + train_valid_test_num_samples = get_train_valid_test_num_samples() + _print_effective_configuration(args, train_valid_test_num_samples, ignored_flags) + + compile_helpers() + + config = core_gpt_dataset_config_from_args(args) + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + GPTDataset, train_valid_test_num_samples, lambda: True, config + ).build() + + print("> finished preparing dataset cache") + print(f" train dataset length: {_get_dataset_length(train_ds)}") + print(f" validation dataset length: {_get_dataset_length(valid_ds)}") + print(f" test dataset length: {_get_dataset_length(test_ds)}") + + return { + "world_size": args.world_size, + "data_parallel_size": args.data_parallel_size, + "global_batch_size": args.global_batch_size, + "train_valid_test_num_samples": tuple(train_valid_test_num_samples), + "train_dataset_length": _get_dataset_length(train_ds), + "valid_dataset_length": _get_dataset_length(valid_ds), + "test_dataset_length": _get_dataset_length(test_ds), + } + finally: + unset_global_variables() + + +def main() -> Dict[str, Any]: + args = parse_args(extra_args_provider=_extra_args_provider, ignore_unknown_args=False) + _normalize_prepare_cache_args(args) + validate_args(args, defaults={"tokenizer_type": "GPT2BPETokenizer"}) + return build_dataset_caches(args) + + +if __name__ == "__main__": + main() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index f472dd50dbf..2e0b7c96515 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -2,28 +2,31 @@ """Processing large data for pretraining.""" import argparse -import math import json +import math import os import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) -import time -import gzip + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) import glob +import gzip import multiprocessing +import time + import numpy as np + try: import nltk from nltk.tokenize.punkt import PunktLanguageVars + nltk_available = True except ImportError: 
PunktLanguageVars = object # Fallback to the built-in object class nltk_available = False +from megatron.core.datasets import indexed_dataset from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer from megatron.training.arguments import _add_tokenizer_args -from megatron.core.datasets import indexed_dataset # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer @@ -39,6 +42,7 @@ class CustomLanguageVars(PunktLanguageVars): (?P\S+) # <-- Normally you would have \s+ here ))""" + class IdentitySplitter(object): def tokenize(self, *text): return text @@ -56,7 +60,9 @@ def initializer(self): print("NLTK is not available to split sentences.") exit() if os.environ.get("NLTK_DATA"): - library = os.path.join(os.environ.get("NLTK_DATA"), "tokenizers", "punkt", f"{self.args.lang}.pickle") + library = os.path.join( + os.environ.get("NLTK_DATA"), "tokenizers", "punkt", f"{self.args.lang}.pickle" + ) url = f"file:{library}" else: library = os.path.join("tokenizers", "punkt", f"{self.args.lang}.pickle") @@ -65,8 +71,8 @@ def initializer(self): if self.args.keep_newlines: # this prevents punkt from eating newlines after sentences Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( - train_text = splitter._params, - lang_vars = CustomLanguageVars()) + train_text=splitter._params, lang_vars=CustomLanguageVars() + ) else: Encoder.splitter = splitter @@ -79,7 +85,10 @@ def split(self, json_line): for key in self.args.json_keys: text = data[key] max_len = 1000000 - tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)] + tokens_list = [ + Encoder.splitter.tokenize(text[i : i + max_len]) + for i in range(0, len(text), max_len) + ] output[key] = [tokens for partial in tokens_list for tokens in partial] return json.dumps(output), len(json_line) @@ -118,12 +127,14 @@ def print_processing_stats(self, count, proc_start, total_bytes_processed): if count % self.args.log_interval == 0: current = time.time() elapsed = current - proc_start - mbs = total_bytes_processed/elapsed/1024/1024 - print(f"Processed {count} documents", - f"({count/elapsed} docs/s, {mbs} MB/s).", - file=sys.stderr) + mbs = total_bytes_processed / elapsed / 1024 / 1024 + print( + f"Processed {count} documents", + f"({count/elapsed} docs/s, {mbs} MB/s).", + file=sys.stderr, + ) if self.args.find_optimal_num_workers: - self.performance.append(count/elapsed) + self.performance.append(count / elapsed) def split_sentences(self, file_name): input_file_name, output_file_name = file_name @@ -168,10 +179,8 @@ def process_json_file(self, file_name): builders = {} for key in self.args.json_keys: - output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, - key, level) - output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, - key, level) + output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, key, level) builders[key] = indexed_dataset.IndexedDatasetBuilder( output_bin_files[key], dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), @@ -191,7 +200,8 @@ def process_json_file(self, file_name): self.print_processing_stats(i, proc_start, total_bytes_processed) fin.close() - builders[key].finalize(output_idx_files[key]) + for key in self.args.json_keys: + builders[key].finalize(output_idx_files[key]) pool.close() pool.join() @@ -203,47 +213,86 @@ def get_args(): parser = argparse.ArgumentParser() parser = _add_tokenizer_args(parser) group = 
parser.add_argument_group(title='input data') - group.add_argument('--input', type=str, required=True, - help='Path to input JSON') - group.add_argument('--json-keys', nargs='+', default=['text'], - help='space separate listed of keys to extract from json') - group.add_argument('--split-sentences', action='store_true', - help='Split documents into sentences.') - group.add_argument('--keep-newlines', action='store_true', - help='Keep newlines between sentences when splitting.') + group.add_argument('--input', type=str, required=True, help='Path to input JSON') + group.add_argument( + '--json-keys', + nargs='+', + default=['text'], + help='space separate listed of keys to extract from json', + ) + group.add_argument( + '--split-sentences', action='store_true', help='Split documents into sentences.' + ) + group.add_argument( + '--keep-newlines', + action='store_true', + help='Keep newlines between sentences when splitting.', + ) group = parser.add_argument_group(title='tokenization process') - group.add_argument('--append-eod', action='store_true', - help='Append an token to the end of a document.') - group.add_argument('--lang', type=str, default='english', - help='Language to use for NLTK-powered sentence splitting.') + group.add_argument( + '--append-eod', action='store_true', help='Append an token to the end of a document.' + ) + group.add_argument( + '--lang', + type=str, + default='english', + help='Language to use for NLTK-powered sentence splitting.', + ) group = parser.add_argument_group(title='output data') - group.add_argument('--output-prefix', type=str, required=True, - help='Path to binary output file without suffix') + group.add_argument( + '--output-prefix', type=str, required=True, help='Path to binary output file without suffix' + ) group = parser.add_argument_group(title='runtime') - group.add_argument('--workers', type=int, required=True, - help=('Number of worker processes to launch.' - 'A good default for fast pre-processing ' - 'is: (workers * partitions) = available CPU cores.')) - group.add_argument('--find-optimal-num-workers', action='store_true', - help=('Find optimal number of workers.' - 'Script will run few small jobs with ' - 'different number of workers to define ' - 'optimal number of workers in terms of performance.')) - group.add_argument('--workers-to-check', nargs='+', type=int, default=[16, 32, 64], - help=('list of workers to run data processing with ' - 'to find optimal number of workers. ' - 'Works only when --find-optimal-num-workers is enabled. ')) - group.add_argument('--max-documents', type=int, default=100_000, - help=('Maximum number of documents to preprocess ' - 'to find optimal number of workers.' - 'Works only when --find-optimal-num-workers is enabled.')) - group.add_argument('--partitions', type=int, default=1, - help='Number of file partitions') - group.add_argument('--log-interval', type=int, default=1000, - help='Interval between progress updates') - group.add_argument('--keep-sequential-samples', action='store_true', - help='Ensure ordering of samples in .jsonl files is ' - 'preserved when using partitions>1.') + group.add_argument( + '--workers', + type=int, + required=True, + help=( + 'Number of worker processes to launch.' + 'A good default for fast pre-processing ' + 'is: (workers * partitions) = available CPU cores.' + ), + ) + group.add_argument( + '--find-optimal-num-workers', + action='store_true', + help=( + 'Find optimal number of workers.' 
+ 'Script will run few small jobs with ' + 'different number of workers to define ' + 'optimal number of workers in terms of performance.' + ), + ) + group.add_argument( + '--workers-to-check', + nargs='+', + type=int, + default=[16, 32, 64], + help=( + 'list of workers to run data processing with ' + 'to find optimal number of workers. ' + 'Works only when --find-optimal-num-workers is enabled. ' + ), + ) + group.add_argument( + '--max-documents', + type=int, + default=100_000, + help=( + 'Maximum number of documents to preprocess ' + 'to find optimal number of workers.' + 'Works only when --find-optimal-num-workers is enabled.' + ), + ) + group.add_argument('--partitions', type=int, default=1, help='Number of file partitions') + group.add_argument( + '--log-interval', type=int, default=1000, help='Interval between progress updates' + ) + group.add_argument( + '--keep-sequential-samples', + action='store_true', + help='Ensure ordering of samples in .jsonl files is ' 'preserved when using partitions>1.', + ) args = parser.parse_args() args.keep_empty = False @@ -267,7 +316,8 @@ def get_file_name(args, file_id): file_names = { 'partition': input_file_name, 'sentence_split': sentence_split_file, - 'output_prefix': output_prefix} + 'output_prefix': output_prefix, + } return file_names @@ -290,12 +340,12 @@ def find_optimal_num_workers(performance, partitions): # sort by average performance (descending: fastest first) results.sort(key=lambda x: x[1], reverse=True) - + print("\n-----------------------------------") print("Performance results (fastest → slowest):") for i, (workers, avg_perf) in enumerate(results): print(f"{i+1}. {workers * partitions} workers → avg. docs/s: {avg_perf:.4f}") - + best_workers, best_perf = results[0] print("\n-----------------------------------") @@ -317,7 +367,9 @@ def main(): f"because it's not divisible by num_partitions ({args.partitions})" ) workers.remove(num_workers) - assert workers, "Please, provide valid number of workers which is divisible by number of partitions." + assert ( + workers + ), "Please, provide valid number of workers which is divisible by number of partitions." 
if args.find_optimal_num_workers: args.log_interval = 1000 @@ -328,8 +380,7 @@ def main(): if nltk_available: nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA")) else: - raise Exception( - "nltk library required for sentence splitting is not available.") + raise Exception("nltk library required for sentence splitting is not available.") in_ss_out_names = [] if args.partitions == 1: @@ -338,7 +389,8 @@ def main(): file_names = { 'partition': args.input, 'sentence_split': sentence_split_file, - 'output_prefix': args.output_prefix} + 'output_prefix': args.output_prefix, + } in_ss_out_names.append(file_names) else: in_file_names = glob.glob(args.input) @@ -350,7 +402,7 @@ def main(): with open(filename, "r") as fin: for fc, _ in enumerate(fin): pass - total_sample_count += (fc + 1) + total_sample_count += fc + 1 partition_size = math.ceil(total_sample_count / args.partitions) # create .jsonl parition files @@ -362,7 +414,9 @@ def main(): partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions) # check to see if paritions with split sentences already created - split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + split_sentences_present = check_files_exist( + in_ss_out_names, 'sentence_split', args.partitions + ) if not partitions_present and not split_sentences_present: # populate .jsonl partition files from parent files @@ -372,7 +426,8 @@ def main(): partitioned_input_files.append(partitioned_input_file) index = 0 - if args.keep_sequential_samples: line_count = 0 + if args.keep_sequential_samples: + line_count = 0 for in_file_name in in_file_names: # support for gzip files if in_file_name.endswith(".gz"): @@ -387,24 +442,28 @@ def main(): if line_count % partition_size == 0: index += 1 else: - index = (index + 1)%args.partitions + index = (index + 1) % args.partitions fin.close() for idx in range(args.partitions): partitioned_input_files[idx].close() - partition = Partition(args, num_workers//args.partitions) + partition = Partition(args, num_workers // args.partitions) # check to see if paritions with split sentences already created - split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + split_sentences_present = check_files_exist( + in_ss_out_names, 'sentence_split', args.partitions + ) # split sentences in partition files if args.split_sentences and not split_sentences_present: processes = [] for name in in_ss_out_names: - p = multiprocessing.Process(target=partition.split_sentences, - args=((name['partition'], name['sentence_split']),)) + p = multiprocessing.Process( + target=partition.split_sentences, + args=((name['partition'], name['sentence_split']),), + ) p.start() processes.append(p) @@ -415,7 +474,9 @@ def main(): continue def process_json_file(name, q, input_key): - worker_performance = partition.process_json_file((name[input_key], name['output_prefix'])) + worker_performance = partition.process_json_file( + (name[input_key], name['output_prefix']) + ) q.put(worker_performance) # encode partition files in parallel @@ -450,10 +511,8 @@ def process_json_file(name, q, input_key): tokenizer = build_tokenizer(args) for key in args.json_keys: - output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, - key, level) - output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, - key, level) + output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, key, level) + output_idx_files[key] = 
"{}_{}_{}.idx".format(args.output_prefix, key, level) builders[key] = indexed_dataset.IndexedDatasetBuilder( output_bin_files[key], dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), @@ -461,8 +520,7 @@ def process_json_file(name, q, input_key): for name in in_ss_out_names: parition_output_prefix = name['output_prefix'] - full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, - key, level) + full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, key, level) builders[key].add_index(full_partition_output_prefix) builders[key].finalize(output_idx_files[key]) @@ -470,7 +528,7 @@ def process_json_file(name, q, input_key): if args.find_optimal_num_workers: find_optimal_num_workers(performance, args.partitions) + if __name__ == '__main__': main() - diff --git a/tools/preprocess_mmdata.py b/tools/preprocess_mmdata.py index b63c9e99cad..eb4e815e5c7 100755 --- a/tools/preprocess_mmdata.py +++ b/tools/preprocess_mmdata.py @@ -8,20 +8,22 @@ import multiprocessing import os import sys + import numpy as np from torchvision.transforms import ToTensor -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) import time import torch + try: from nltk.tokenize.punkt import PunktLanguageVars except ImportError: PunktLanguageVars = object # Fallback to the built-in object class -from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer from megatron.core.datasets.indexed_dataset import IndexedDatasetBuilder +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer @@ -37,10 +39,12 @@ class CustomLanguageVars(PunktLanguageVars): (?P\S+) # <-- Normally you would have \s+ here ))""" + class IdentitySplitter(object): def tokenize(self, *text): return text + class Encoder(object): def __init__(self, args): self.args = args @@ -59,7 +63,9 @@ def encode(self, input_pair): if len(sentence_ids) > 0 and self.args.append_eod: sentence_ids = sentence_ids[:pad_len] current_length = len(sentence_ids) - sentence_ids.extend([Encoder.tokenizer.eod for _ in range(max(0,pad_len-current_length))]) + sentence_ids.extend( + [Encoder.tokenizer.eod for _ in range(max(0, pad_len - current_length))] + ) with open(img_path, "rb") as tf: xs = bytearray(tf.read()) @@ -67,60 +73,100 @@ def encode(self, input_pair): xs.extend([0 for _ in range(img_pad)]) img_raw = np.frombuffer(xs, dtype=np.int32) img_raw = np.insert(img_raw, 0, img_pad) - + return sentence_ids, img_raw, len(json_line) + def get_args(): parser = argparse.ArgumentParser() group = parser.add_argument_group(title='input data') - group.add_argument('--input', type=str, required=True, - help='Path to input JSON') - group.add_argument('--input-image', type=str, required=True, - help='Path to input image folder') - - group.add_argument('--pad-length', type=int, required=True, - help='Pad length of preprocessed text') - - group.add_argument('--split-sentences', action='store_true', - help='Split documents into sentences.') - group.add_argument('--keep-newlines', action='store_true', - help='Keep newlines between sentences when splitting.') + group.add_argument('--input', type=str, required=True, help='Path to input JSON') + group.add_argument('--input-image', type=str, required=True, help='Path to input image folder') + + group.add_argument( + '--pad-length', type=int, 
required=True, help='Pad length of preprocessed text' + ) + + group.add_argument( + '--split-sentences', action='store_true', help='Split documents into sentences.' + ) + group.add_argument( + '--keep-newlines', + action='store_true', + help='Keep newlines between sentences when splitting.', + ) group = parser.add_argument_group(title='tokenizer') - group.add_argument('--tokenizer-type', type=str, required=True, - choices=['BertWordPieceLowerCase','BertWordPieceCase', - 'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'], - help='What type of tokenizer to use.') - group.add_argument('--vocab-file', type=str, default=None, - help='Path to the vocab file') - group.add_argument('--merge-file', type=str, default=None, - help='Path to the BPE merge file (if necessary).') - group.add_argument('--append-eod', action='store_true', - help='Append an token to the end of a document.') - group.add_argument('--lang', type=str, default='english', - help='Language to use for NLTK-powered sentence splitting.') - group.add_argument('--tokenizer-model', type=str, default=None, - help='sentencepeice tokenizer model.') - group.add_argument('--tokenizer-metadata', type=str, default=None, - help='Path to tokenizer metadata in json format.') - group.add_argument('--tokenizer-special-tokens', type=str, nargs='+', default=None, - help='List of special tokens. For TikTokenizer needs to have ' - '["", "", "", "", "", "", ""]') - group.add_argument('--tokenizer-hf-no-use-fast', action='store_true', default=False, - help='Whether to use fast HuggingFace tokenizer.') - group.add_argument('--tokenizer-hf-no-include-special-tokens', action='store_true', default=False, - help='Converting text to ids will not include special for HuggingFace tokenizer.') - group.add_argument("--trust-remote-code", action="store_true", default=False, - help='Whether or not to allow PreTrainedTokenizer to execute remote code') + group.add_argument( + '--tokenizer-type', + type=str, + required=True, + choices=[ + 'BertWordPieceLowerCase', + 'BertWordPieceCase', + 'GPT2BPETokenizer', + 'SentencePieceTokenizer', + 'GPTSentencePieceTokenizer', + ], + help='What type of tokenizer to use.', + ) + group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') + group.add_argument( + '--merge-file', type=str, default=None, help='Path to the BPE merge file (if necessary).' + ) + group.add_argument( + '--append-eod', action='store_true', help='Append an token to the end of a document.' + ) + group.add_argument( + '--lang', + type=str, + default='english', + help='Language to use for NLTK-powered sentence splitting.', + ) + group.add_argument( + '--tokenizer-model', type=str, default=None, help='sentencepeice tokenizer model.' + ) + group.add_argument( + '--metadata-path', type=str, default=None, help='Path to tokenizer metadata in json format.' + ) + group.add_argument( + '--special-tokens', + type=str, + nargs='+', + default=None, + help='List of special tokens. 
+    group.add_argument(
+        '--lang',
+        type=str,
+        default='english',
+        help='Language to use for NLTK-powered sentence splitting.',
+    )
+    group.add_argument(
+        '--tokenizer-model', type=str, default=None, help='sentencepiece tokenizer model.'
+    )
+    group.add_argument(
+        '--metadata-path', type=str, default=None, help='Path to tokenizer metadata in json format.'
+    )
+    group.add_argument(
+        '--special-tokens',
+        type=str,
+        nargs='+',
+        default=None,
+        help='List of special tokens. For TikTokenizer needs to have '
+        '["", "", "", "", "", "", ""]',
+    )
+    group.add_argument(
+        '--tokenizer-hf-no-use-fast',
+        action='store_true',
+        default=False,
+        help='Whether to use fast HuggingFace tokenizer.',
+    )
+    group.add_argument(
+        '--tokenizer-hf-no-include-special-tokens',
+        action='store_true',
+        default=False,
+        help='Converting text to ids will not include special tokens for HuggingFace tokenizer.',
+    )
+    group.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        default=False,
+        help='Whether or not to allow PreTrainedTokenizer to execute remote code',
+    )

     group = parser.add_argument_group(title='output data')
-    group.add_argument('--output-prefix', type=str, required=True,
-                       help='Path to binary output file without suffix')
+    group.add_argument(
+        '--output-prefix', type=str, required=True, help='Path to binary output file without suffix'
+    )

     group = parser.add_argument_group(title='runtime')
-    group.add_argument('--workers', type=int, default=1,
-                       help='Number of worker processes to launch')
-    group.add_argument('--log-interval', type=int, default=100,
-                       help='Interval between progress updates')
+    group.add_argument(
+        '--workers', type=int, default=1, help='Number of worker processes to launch'
+    )
+    group.add_argument(
+        '--log-interval', type=int, default=100, help='Interval between progress updates'
+    )

     args = parser.parse_args()
     args.keep_empty = False
@@ -133,6 +179,7 @@ def get_args():

     return args

+
 def main():
     args = get_args()
     startup_start = time.time()
@@ -142,13 +189,15 @@ def main():
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)

     fin = open(args.input, 'r', encoding='utf-8')
-    img_paths = [os.path.join(args.input_image, basename) for basename in os.listdir(args.input_image)]
+    img_paths = [
+        os.path.join(args.input_image, basename) for basename in os.listdir(args.input_image)
+    ]
     encoded_docs = pool.imap(encoder.encode, zip(fin, img_paths), 25)

     print(f"Vocab size: {tokenizer.vocab_size}")
     print(f"Output prefix: {args.output_prefix}")
-
+
     output_bin_files = "{}.bin".format(args.output_prefix)
     output_idx_files = "{}.idx".format(args.output_prefix)
@@ -159,7 +208,7 @@ def main():
     total_bytes_processed = 0
     print("Time to startup:", startup_end - startup_start)
-
+
     for i, (sentence, img_raw, bytes_processed) in enumerate(encoded_docs, start=1):
         total_bytes_processed += bytes_processed
         builders.add_item(torch.IntTensor(sentence))
@@ -168,14 +217,11 @@ def main():
         if i % args.log_interval == 0:
             current = time.time()
             elapsed = current - proc_start
-            mbs = total_bytes_processed/elapsed/1024/1024
-            print(f"Processed {i} documents",
-                  f"({i/elapsed} docs/s, {mbs} MB/s).",
-                  file=sys.stderr)
-
+            mbs = total_bytes_processed / elapsed / 1024 / 1024
+            print(f"Processed {i} documents", f"({i/elapsed} docs/s, {mbs} MB/s).", file=sys.stderr)
+
     builders.finalize(output_idx_files)

 if __name__ == '__main__':
     main()
-
diff --git a/tools/run_dynamic_text_generation_server.py b/tools/run_dynamic_text_generation_server.py
index 56edd6a116e..5aef2631595 100644
--- a/tools/run_dynamic_text_generation_server.py
+++ b/tools/run_dynamic_text_generation_server.py
@@ -10,7 +10,7 @@
     start_text_gen_server,
     stop_text_gen_server,
 )
-from megatron.core.utils import trace_async_exceptions
+from megatron.core.utils import configure_nvtx_profiling, trace_async_exceptions
 from megatron.inference.utils import add_inference_args, get_dynamic_inference_engine
 from megatron.post_training.arguments import add_modelopt_args
 from megatron.training import get_args
@@ -89,10 +89,18 @@ async def run_text_generation_server(
     )
     initialize_megatron()

+    args = get_args()
+
+    # Match training's NVTX gating (training.py only flips this when both
+    # --profile and --nvtx-ranges are set). Otherwise the engine-side
+    # nvtx_range_push labels (bookkeeping, Decode, _ep_establish_consensus,
+    # etc.) are no-ops and the inter-step gap is unattributable in nsys.
+    if args.profile and args.nvtx_ranges:
+        configure_nvtx_profiling(True)
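+    # For example, the server can be launched under nsys with both flags set so these
+    # ranges are captured (remaining arguments elided):
+    #   nsys profile python tools/run_dynamic_text_generation_server.py ... --profile --nvtx-ranges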
+
     # Enable return_log_probs to allow prompt logprobs computation for echo=True requests
     # This sets materialize_only_last_token_logits=False in the inference context,
     # which is required for lm-eval compatibility (loglikelihood evaluation tasks)
-    args = get_args()
     args.return_log_probs = True

     engine = get_dynamic_inference_engine()
diff --git a/tools/run_mamba_text_generation_server.py b/tools/run_hybrid_text_generation_server.py
similarity index 89%
rename from tools/run_mamba_text_generation_server.py
rename to tools/run_hybrid_text_generation_server.py
index 33465f1bb4a..e70e5389e88 100644
--- a/tools/run_mamba_text_generation_server.py
+++ b/tools/run_hybrid_text_generation_server.py
@@ -8,4 +8,4 @@
 from run_text_generation_server import main

 if __name__ == "__main__":
-    main(model_type="mamba")
+    main(model_type="hybrid")
diff --git a/tools/run_mamba_text_generation_server_completions.py b/tools/run_hybrid_text_generation_server_completions.py
similarity index 89%
rename from tools/run_mamba_text_generation_server_completions.py
rename to tools/run_hybrid_text_generation_server_completions.py
index 33465f1bb4a..e70e5389e88 100644
--- a/tools/run_mamba_text_generation_server_completions.py
+++ b/tools/run_hybrid_text_generation_server_completions.py
@@ -8,4 +8,4 @@
 from run_text_generation_server import main

 if __name__ == "__main__":
-    main(model_type="mamba")
+    main(model_type="hybrid")
diff --git a/tools/run_inference_performance_test.py b/tools/run_inference_performance_test.py
index 4140740e284..bf9d0015549 100644
--- a/tools/run_inference_performance_test.py
+++ b/tools/run_inference_performance_test.py
@@ -9,7 +9,7 @@
 import torch

 from gpt_builders import gpt_builder
-from mamba_builders import mamba_builder
+from hybrid_builders import hybrid_builder
 from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.engines import DynamicInferenceEngine, StaticInferenceEngine
 from megatron.core.inference.engines.abstract_engine import AbstractEngine
diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py
index b8ddd986fa6..abad4556f99 100644
--- a/tools/run_text_generation_server.py
+++ b/tools/run_text_generation_server.py
@@ -15,7 +15,7 @@
 import torch

 from gpt_builders import gpt_builder
-from mamba_builders import mamba_builder
+from hybrid_builders import hybrid_builder
 from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.engines import AbstractEngine, StaticInferenceEngine
 from megatron.core.inference.engines.abstract_engine import AbstractEngine
@@ -138,8 +138,16 @@ def main(model_type: str = "gpt"):
     # Set up model and load checkpoint
     if model_type == "gpt":
         model_builder = gpt_builder
-    elif model_type == "mamba":
-        model_builder = mamba_builder
+    elif model_type in ("hybrid", "mamba"):
+        if model_type == "mamba":
+            import warnings
+
+            warnings.warn(
+                'model_type="mamba" is deprecated. Use model_type="hybrid" instead.',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        model_builder = hybrid_builder
     else:
         raise ValueError(f"Invalid model provider {model_type}")
     model = get_model(partial(model_provider, model_builder), wrap_with_ddp=False)
diff --git a/train_rl.py b/train_rl.py
index 06457be1245..4637a184813 100644
--- a/train_rl.py
+++ b/train_rl.py
@@ -8,7 +8,7 @@
 import torch

 from gpt_builders import gpt_builder
-from mamba_builders import mamba_builder
+from hybrid_builders import hybrid_builder
 from megatron.core import mpu
 from megatron.core.enums import ModelType
 from megatron.core.models.gpt import GPTModel
@@ -24,6 +24,7 @@
 )
 from megatron.rl.sequence_packing_utils import get_default_packed_seq_params
 from megatron.training import get_args, get_timers, pretrain, print_rank_0
+from megatron.training.argument_utils import pretrain_cfg_container_from_args
 from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args
 from megatron.training.utils import is_hybrid_model
 from model_provider import model_provider
@@ -395,7 +396,7 @@ def _model_builder(
     args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None
 ):
     if is_hybrid_model(args):
-        return mamba_builder(
+        return hybrid_builder(
             args,
             pre_process,
             post_process,
@@ -413,8 +414,10 @@ def _model_builder(
             pg_collection=pg_collection,
         )

-    parse_and_validate_args(extra_args_provider=add_inference_args, args_defaults={})
+    args = parse_and_validate_args(extra_args_provider=add_inference_args, args_defaults={})
+    full_config = pretrain_cfg_container_from_args(args)
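+    # pretrain() now receives the full config container as its first positional
+    # argument; it is built from the parsed args above.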
Use model_type="hybrid" instead.', + DeprecationWarning, + stacklevel=2, + ) + model_builder = hybrid_builder else: raise ValueError(f"Invalid model provider {model_type}") model = get_model(partial(model_provider, model_builder), wrap_with_ddp=False) diff --git a/train_rl.py b/train_rl.py index 06457be1245..4637a184813 100644 --- a/train_rl.py +++ b/train_rl.py @@ -8,7 +8,7 @@ import torch from gpt_builders import gpt_builder -from mamba_builders import mamba_builder +from hybrid_builders import hybrid_builder from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel @@ -24,6 +24,7 @@ ) from megatron.rl.sequence_packing_utils import get_default_packed_seq_params from megatron.training import get_args, get_timers, pretrain, print_rank_0 +from megatron.training.argument_utils import pretrain_cfg_container_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args from megatron.training.utils import is_hybrid_model from model_provider import model_provider @@ -395,7 +396,7 @@ def _model_builder( args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None ): if is_hybrid_model(args): - return mamba_builder( + return hybrid_builder( args, pre_process, post_process, @@ -413,8 +414,10 @@ def _model_builder( pg_collection=pg_collection, ) - parse_and_validate_args(extra_args_provider=add_inference_args, args_defaults={}) + args = parse_and_validate_args(extra_args_provider=add_inference_args, args_defaults={}) + full_config = pretrain_cfg_container_from_args(args) pretrain( + full_config, None, # we don't need to build any datasets for RL training partial(model_provider, _model_builder), ModelType.encoder_or_decoder,