
[model, ckpt] fix: align GPT-OSS BF16 down_proj orientation on import (r0.4.0)#3753

Merged
cuichenx merged 2 commits into r0.4.0 from chcui/gpt-oss-down-proj-import-fix-r0.4.0 on May 8, 2026

Conversation

@cuichenx cuichenx commented May 8, 2026

Summary

Backports the GPT-OSS BF16 down_proj orientation fix from #3743 to the r0.4.0 release branch.

This is a silent inference-correctness bug: BF16 GPT-OSS checkpoints (e.g. unsloth/gpt-oss-20b-BF16, or anything transformers.GptOssForCausalLM produces at init) imported through the bridge ran inference with down_proj weights stored in the wrong orientation. Forward-pass cosine similarity vs HF dropped to ~0.54 on a saved/reloaded BF16-imported Megatron checkpoint, even though the in-memory roundtrip looked clean (import and export were symmetrically broken on the BF16 path).

Root cause

Per-expert down_proj is square for GPT-OSS-20B/120B (hidden == intermediate), so the bridge cannot auto-detect orientation from shape alone:

  • BF16 checkpoints store it as [E, intermediate, hidden], mirroring gate_up_proj's [E, hidden, 2*intermediate] convention.
  • MXFP4-dequantized weights come out as [E, hidden, intermediate].
  • Megatron's TE RowParallelGroupedLinear expects per-expert (hidden, intermediate).

gate_up_proj is non-square, so _align_expert_weight_to_shape already auto-detects and transposes; down_proj was passed straight through with no orientation alignment, so MXFP4 happened to land correct (matching the GEMM layout) and BF16 silently landed transposed.
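The shape-vs-config detection described above can be sketched as follows. This is an illustrative helper under assumed shapes, not the bridge's actual _align_expert_weight_to_shape implementation:

```python
import torch

def align_expert_weight(w: torch.Tensor, expected: tuple[int, int]) -> torch.Tensor:
    """Orient a per-expert weight [E, a, b] so its trailing dims match
    `expected`, transposing if they arrive reversed. Detection is only
    unambiguous when the weight is non-square (a != b)."""
    _, a, b = w.shape
    if (a, b) == expected:
        return w
    if (b, a) == expected:
        return w.transpose(-2, -1).contiguous()
    raise ValueError(f"cannot align shape {(a, b)} to {expected}")

# Non-square case: the layout is uniquely identified by shape.
w = torch.zeros(4, 8, 16)                  # [E, intermediate, hidden]
aligned = align_expert_weight(w, (16, 8))  # want (hidden, intermediate)
assert aligned.shape == (4, 16, 8)
```

For square weights (gpt-oss-20b/120b down_proj), both orientations match `expected`, which is exactly why the bridge has to fall back to a per-source-format convention instead.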

Why prior PRs did not fully address this

None of the prior fixes caught that Megatron's TE GroupedLinear expects per-expert (hidden, intermediate), the standard PyTorch nn.Linear convention, not (intermediate, hidden). With a BF16 source, the import has to transpose and the export has to transpose back. With an MXFP4 source, dequantization already emits the right layout. A forward-pass cosine-similarity check against HF would have caught any of these regressions; it was not run on either prior fix's verification path.
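The forward-pass check referred to here can be as small as a flattened cosine similarity between HF and Megatron logits. A minimal sketch, assuming you already have the two logit tensors in hand:

```python
import torch
import torch.nn.functional as F

def forward_cos_sim(logits_a: torch.Tensor, logits_b: torch.Tensor) -> float:
    """Compare the direction of two logit tensors, ignoring magnitude."""
    return F.cosine_similarity(
        logits_a.float().flatten(), logits_b.float().flatten(), dim=0
    ).item()

# Identical outputs score ~1.0; a transposed down_proj would drag the
# score well below that (the PR reports ~0.54 for the broken path).
x = torch.randn(2, 16, 32)
assert abs(forward_cos_sim(x, x) - 1.0) < 1e-5
```

A check of this shape on the saved/reloaded checkpoint is what distinguishes a real fix from a symmetrically broken import/export pair, which an in-memory roundtrip cannot.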

Fix

  • Import side — in GPTOSSBridge.maybe_modify_loaded_hf_weight, transpose down_proj when loading from a non-quantized BF16 checkpoint. Disambiguation: when the per-expert shape is non-square, shape-vs-config uniquely identifies layout; when square (gpt-oss-20b/120b), default to the transformers.GptOssForCausalLM init layout [E, intermediate, hidden].
  • Export side — in GPTOSSMLPDownProjMapping.megatron_to_hf, transpose the last two dims of each ndim>=2 weight tensor on the way out so the grouped-export stack reassembles in HF's [E, intermediate, hidden] layout. Under EP, gather_from_ep_ranks may have already concatenated per-rank experts into a 3-D (ep_size, hidden, intermediate) tensor, so the transpose runs unconditionally on the trailing two dims rather than only on 2-D inputs. Bias mappings are passed through untouched.
  • r0.4.0-only enabler — restore one line in MegatronModelBridge.build_conversion_tasks that stashes self.hf_pretrained = hf_pretrained. On main this assignment is already present; on r0.4.0 it was dropped by the decentralized-PG refactor in [model, ckpt, docs] fix: support HF→Megatron conversion under decentralized PGs (r0.4.0) #3674. The shape-detection in maybe_modify_loaded_hf_weight reads self.hf_pretrained.config, so we restore the stash to keep that hook self-contained. No behavioral change beyond making the attribute reachable again.

The MXFP4 dequant branch is left as-is (already produces the GEMM-correct layout).
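The coordinated import/export transpose pair reduces to a trailing-two-dim transpose in each direction. The function names below are illustrative, not the bridge's actual API:

```python
import torch

def import_down_proj(hf_w: torch.Tensor) -> torch.Tensor:
    """BF16 HF layout [E, intermediate, hidden] -> per-expert
    (hidden, intermediate), as TE RowParallelGroupedLinear expects."""
    return hf_w.transpose(-2, -1).contiguous()

def export_down_proj(mg_w: torch.Tensor) -> torch.Tensor:
    """Inverse transpose back to HF's [E, intermediate, hidden].
    Operating on the trailing two dims also covers the EP-gathered
    3-D (ep_size, hidden, intermediate) case."""
    return mg_w.transpose(-2, -1).contiguous()

hf = torch.randn(4, 8, 16)   # [E, intermediate, hidden]
mg = import_down_proj(hf)    # [E, hidden, intermediate]
assert mg.shape == (4, 16, 8)
assert torch.equal(export_down_proj(mg), hf)  # exact roundtrip
```

Because the transpose is its own inverse, a checkpoint that goes HF → Megatron → HF reproduces the source layout bit-for-bit.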

Test changes

  • tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py is rewritten to build two faithful toys from the same underlying weights:
    • BF16 toy with the unsloth-style [E, hidden, 2*intermediate] / [E, intermediate, hidden] layout.
    • MXFP4 toy with *_blocks/*_scales whose _dequantize_mxfp4 output equals the BF16 toy transposed per expert, matching openai/gpt-oss-20b's shipping layout.
  • Parametrized over source ∈ {bf16, mxfp4} × {PP=2, EP=2}. MXFP4 runs as a two-step convert_checkpoints_multi_gpu.py import followed by hf_megatron_roundtrip_multi_gpu.py --megatron-load-path against the BF16 toy reference, since the verification table cannot resolve down_proj/gate_up_proj keys in a quantized state dict.
  • hidden_size != intermediate_size is intentional so any wrong-direction transpose surfaces as a shape mismatch — the previous toy used the MXFP4-dequant orientation for down_proj, hiding the bug behind a symmetric pass-through.
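The non-square-toy rationale can be demonstrated in a few lines; the dimensions here are made up for illustration:

```python
import torch

# With hidden != intermediate, a missing or wrong-direction transpose
# changes the trailing shape, so the toy test fails loudly instead of
# silently producing wrong values as square real-model shapes would.
E, intermediate, hidden = 2, 8, 16
w = torch.randn(E, intermediate, hidden)   # BF16 HF layout

correct = w.transpose(-2, -1)              # (E, hidden, intermediate)
assert correct.shape == (E, hidden, intermediate)

# The old pass-through (the bug) leaves the HF layout in place, which
# here is a visible shape mismatch rather than a value-only corruption.
assert w.shape != (E, hidden, intermediate)
```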

Verification

Real model on this branch (HF reference: unsloth/gpt-oss-20b-BF16, TP=1):

Check                                                  Result
BF16 import → forward cos sim vs HF (PP=8)             0.999973 ✅
BF16 import → forward cos sim vs HF (EP=8)             0.999975 ✅
MXFP4 import → forward cos sim vs HF (PP=8)            0.999973 ✅
MXFP4 import → forward cos sim vs HF (EP=8)            0.999975 ✅
BF16 import → reload → roundtrip vs BF16 HF (PP=8)     411/411 ✅
BF16 import → reload → roundtrip vs BF16 HF (EP=8)     411/411 ✅
MXFP4 import → reload → roundtrip vs BF16 HF (PP=8)    411/411 ✅
MXFP4 import → reload → roundtrip vs BF16 HF (EP=8)    411/411 ✅

Toy tests on this branch:

Source   Parallelism   Status
BF16     PP=2          ✅ (one-shot roundtrip)
BF16     EP=2          ✅ (one-shot roundtrip)
MXFP4    PP=2          ✅ (two-step)
MXFP4    EP=2          ✅ (two-step)

Test plan

  • examples/conversion/compare_hf_and_megatron/compare.py cos sim vs HF for both BF16 and MXFP4 imports under PP=8 and EP=8
  • examples/conversion/hf_megatron_roundtrip_multi_gpu.py with --megatron-load-path from both BF16 and MXFP4 imports compared against unsloth BF16 under PP=8 and EP=8
  • tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py — all 4 parametrizations green

cuichenx added 2 commits May 8, 2026 11:11
… (r0.4.0)

Per-expert ``down_proj`` is square for GPT-OSS-20B/120B (hidden ==
intermediate), so the bridge cannot auto-detect orientation from shape
alone. BF16 checkpoints (e.g. unsloth/gpt-oss-20b-BF16, and what
transformers.GptOssForCausalLM produces at init) store it as
[E, intermediate, hidden]; MXFP4-dequantized weights come out as
[E, hidden, intermediate]. Megatron's TE RowParallelGroupedLinear
expects per-expert (hidden, intermediate), so the BF16 path needs a
transpose on import while the MXFP4 path is already aligned.

Without the import transpose, BF16 imports silently store down_proj
in the wrong orientation: roundtrip vs the same BF16 source still
matches (import and export are symmetrically broken), but inference
is broken — forward-pass cosine similarity vs HF drops to ~0.54 for
gpt-oss-20b on a saved/reloaded BF16-imported Megatron checkpoint.

Fix the import side in ``maybe_modify_loaded_hf_weight``, and add a
coordinated per-expert transpose in
``GPTOSSMLPDownProjMapping.megatron_to_hf`` so the grouped-export stack
returns to HF's [E, intermediate, hidden] layout.

The shape-detection in ``maybe_modify_loaded_hf_weight`` reads
``self.hf_pretrained.config``. On main this is already populated by
``MegatronModelBridge.build_conversion_tasks``; on r0.4.0 the
decentralized-PG refactor (#3674) dropped that assignment, so this
backport restores the one-line stash inside ``build_conversion_tasks``
to keep ``self.hf_pretrained`` available to subclass hooks. (No
behavioral change beyond making the attribute reachable again.)

Verification on r0.4.0 with TP=1 PP=8 EP=1:
- BF16 import → forward cos sim vs HF: 0.999973
- MXFP4 import → forward cos sim vs HF: 0.999973
- BF16 import → reload → roundtrip vs BF16 HF: 411/411 ✅
- MXFP4 import → reload → roundtrip vs BF16 HF: 411/411 ✅

Signed-off-by: Chen Cui <chcui@nvidia.com>
Builds on the BF16 import-side transpose by extending the GPT-OSS
``down_proj`` export to handle the EP-aggregated path, and rewrites the
toy conversion test to faithfully model both real checkpoint layouts.

Bridge change (``gpt_oss_bridge.py``)
- ``GPTOSSMLPDownProjMapping.megatron_to_hf`` now transposes the last two
  dims of any ndim>=2 weight tensor, not only 2-D ones. Under EP the
  parent ``gather_from_ep_ranks`` may concatenate the per-rank experts
  before the per-expert export hook runs, producing a 3-D
  ``(ep_size, hidden, intermediate)`` tensor that the previous 2-D-only
  guard skipped. Bias mappings (``hf_param`` ending in ``_bias``) are
  passed through unchanged so per-expert biases that arrive 2-D under EP
  are not flipped.
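The ndim>=2 handling described above works because a trailing-two-dim transpose leaves any leading expert dimension alone, so one code path covers both layouts. A small sketch with assumed shapes:

```python
import torch

# Per-expert 2-D slice vs. EP-gathered 3-D tensor: transpose(-2, -1)
# flips only the trailing dims, leaving the leading ep_size dim intact.
per_expert = torch.randn(16, 8)      # (hidden, intermediate)
gathered = torch.randn(2, 16, 8)     # (ep_size, hidden, intermediate)

assert per_expert.transpose(-2, -1).shape == (8, 16)
assert gathered.transpose(-2, -1).shape == (2, 8, 16)
```

A guard on `ndim == 2` would have skipped the gathered case, which is exactly the bug the extended export hook closes.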

Toy test rewrite (``test_gpt_oss_conversion.py``)
- New fixture builds two toys from the same underlying weights:
  * BF16 toy: faithful unsloth-style layout
    (``gate_up_proj`` ``[E, hidden, 2*intermediate]``, ``down_proj``
    ``[E, intermediate, hidden]``).
  * MXFP4 toy: ``*_blocks``/``*_scales`` whose ``_dequantize_mxfp4``
    output equals the BF16 toy transposed per expert, matching the
    ``openai/gpt-oss-20b`` shipping layout.
- Test parametrizes over ``source ∈ {bf16, mxfp4}`` × ``{PP=2, EP=2}``.
  BF16 runs the existing one-shot roundtrip; MXFP4 runs as a two-step
  ``convert_checkpoints_multi_gpu.py import`` then
  ``hf_megatron_roundtrip_multi_gpu.py --megatron-load-path`` against
  the BF16 toy as the reference, since the verification table cannot
  resolve ``down_proj``/``gate_up_proj`` keys in a quantized state
  dict.
- ``hidden_size`` and ``intermediate_size`` are intentionally unequal so
  that any wrong-direction transpose surfaces as a shape mismatch
  (square real-model shapes silently mask layout bugs as wrong values).

Verification on this branch
- All 4 toy parametrizations pass:
  ``bf16-PP``, ``bf16-EP``, ``mxfp4-PP``, ``mxfp4-EP``.
- Real model (``unsloth/gpt-oss-20b-BF16`` HF reference, TP=1):
  * BF16 import → forward cos sim vs HF: PP=8 0.999973, EP=8 0.999975.
  * MXFP4 import → forward cos sim vs HF: PP=8 0.999973, EP=8 0.999975.
  * Reload-roundtrip vs BF16 HF: 411/411 ✅ for all four
    (BF16/MXFP4) × (PP=8/EP=8) combinations.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@cuichenx cuichenx added the needs-more-tests Requires additional L0 and L1 test coverage before merge label May 8, 2026
@cuichenx
Contributor Author

cuichenx commented May 8, 2026

/ok to test 25294b5

@cuichenx cuichenx merged commit b2e213c into r0.4.0 May 8, 2026
69 of 70 checks passed
@cuichenx cuichenx deleted the chcui/gpt-oss-down-proj-import-fix-r0.4.0 branch May 8, 2026 21:49