From 3ef485274863e64cd9f8f8d1e11259faf2eefe73 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 7 May 2026 16:27:08 -0700 Subject: [PATCH 1/2] [model, ckpt] fix: align GPT-OSS BF16 down_proj orientation on import (r0.4.0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per-expert ``down_proj`` is square for GPT-OSS-20B/120B (hidden == intermediate), so the bridge cannot auto-detect orientation from shape alone. BF16 checkpoints (e.g. unsloth/gpt-oss-20b-BF16, and what transformers.GptOssForCausalLM produces at init) store it as [E, intermediate, hidden]; MXFP4-dequantized weights come out as [E, hidden, intermediate]. Megatron's TE RowParallelGroupedLinear expects per-expert (hidden, intermediate), so the BF16 path needs a transpose on import while the MXFP4 path is already aligned. Without the import transpose, BF16 imports silently store down_proj in the wrong orientation: roundtrip vs the same BF16 source still matches (import and export are symmetrically broken), but inference is broken — forward-pass cosine similarity vs HF drops to ~0.54 for gpt-oss-20b on a saved/reloaded BF16-imported Megatron checkpoint. Fix the import side in ``maybe_modify_loaded_hf_weight``, and add a coordinated per-expert transpose in ``GPTOSSMLPDownProjMapping.megatron_to_hf`` so the grouped-export stack returns to HF's [E, intermediate, hidden] layout. The shape-detection in ``maybe_modify_loaded_hf_weight`` reads ``self.hf_pretrained.config``. On main this is already populated by ``MegatronModelBridge.build_conversion_tasks``; on r0.4.0 the decentralized-PG refactor (#3674) dropped that assignment, so this backport restores the one-line stash inside ``build_conversion_tasks`` to keep ``self.hf_pretrained`` available to subclass hooks. (No behavioral change beyond making the attribute reachable again.) Verification on r0.4.0 with TP=1 PP=8 EP=1: - BF16 import → forward cos sim vs HF: 0.999973 - MXFP4 import → forward cos sim vs HF: 0.999973 - BF16 import → reload → roundtrip vs BF16 HF: 411/411 ✅ - MXFP4 import → reload → roundtrip vs BF16 HF: 411/411 ✅ Signed-off-by: Chen Cui --- .../bridge/models/conversion/model_bridge.py | 4 ++ .../bridge/models/gpt_oss/gpt_oss_bridge.py | 44 ++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 4afee912c9..036f3e9d4d 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -1417,6 +1417,10 @@ def build_conversion_tasks( if not isinstance(hf_pretrained, PretrainedConfig) and not has_hf_state: raise ValueError("hf_pretrained.state.source is required for weight ordering") + # Stash for subclass hooks (e.g. ``maybe_modify_loaded_hf_weight``) that need access + # to the source HF config to disambiguate ambiguous tensor shapes. 
+ self.hf_pretrained = hf_pretrained + hf_keys: Optional[Iterable[str]] = hf_pretrained.state.source.get_all_keys() if has_hf_state else None mapping_registry = self.mapping_registry() diff --git a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py index f4c1dd8762..f1626987de 100644 --- a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py +++ b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py @@ -106,9 +106,15 @@ def maybe_modify_loaded_hf_weight( ) -> torch.Tensor: """Load weights from HuggingFace state dict with MXFP4 dequantization support. - down_proj expert weights are stored transposed vs Megatron (HF: [in, out], Megatron: [out, in]). - We transpose them once here so that GPTOSSMLPDownProjMapping.hf_to_megatron can treat the - per-expert slice as-is, and megatron_to_hf symmetrically transposes back on export. + Per-expert ``down_proj`` is square for GPT-OSS-20B/120B (hidden == intermediate), so + the bridge cannot auto-detect orientation from shape alone. BF16 checkpoints (e.g. + ``unsloth/gpt-oss-20b-BF16``, and what ``transformers.GptOssForCausalLM`` produces at + init) store it as ``[E, intermediate, hidden]``, matching ``gate_up_proj``'s + ``[E, hidden, 2*intermediate]`` convention. MXFP4-dequantized weights come out as + ``[E, hidden, intermediate]``. Megatron's TE ``RowParallelGroupedLinear`` expects + per-expert ``(hidden, intermediate)``, so the BF16 path needs a transpose here while + the MXFP4 path is already aligned. Without this, BF16 imports silently store down_proj + in the wrong orientation and inference is broken. gate_up_proj is handled directly in GPTOSSMLPGateUpProjMapping.hf_to_megatron via _align_expert_weight_to_shape, which auto-detects the orientation difference between @@ -118,15 +124,27 @@ def maybe_modify_loaded_hf_weight( if isinstance(hf_param, str): if hf_param in hf_state_dict: hf_weights = hf_state_dict[hf_param] - if ".mlp.experts.down_proj" in hf_param and hf_weights.ndim == 3: - hf_weights = hf_weights.transpose(-1, -2) + if hf_param.endswith(".mlp.experts.down_proj") and hf_weights.ndim == 3: + cfg = self.hf_pretrained.config + hidden = cfg.hidden_size + intermediate = cfg.intermediate_size + last2 = tuple(hf_weights.shape[-2:]) + if last2 == (intermediate, hidden) and intermediate != hidden: + # Unambiguous BF16 layout (E, intermediate, hidden); transpose to (E, hidden, intermediate). + hf_weights = hf_weights.transpose(-1, -2).contiguous() + elif last2 == (hidden, intermediate) and intermediate != hidden: + # Already aligned with Megatron — no-op. + pass + elif intermediate == hidden: + # Square: HF GptOssForCausalLM init produces (E, intermediate, hidden), so a plain BF16 + # checkpoint is in that layout. Transpose to (E, hidden, intermediate) for Megatron. + hf_weights = hf_weights.transpose(-1, -2).contiguous() return hf_weights blocks_key = hf_param + "_blocks" scales_key = hf_param + "_scales" if blocks_key in hf_state_dict and scales_key in hf_state_dict: hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key]) - if ".mlp.experts.down_proj" in hf_param and hf_weights.ndim == 3: - hf_weights = hf_weights.transpose(-1, -2) + # MXFP4 dequant already emits [E, hidden, intermediate] for down_proj — leave as-is. return hf_weights raise KeyError( f"Cannot locate weights for '{hf_param}'. 
Missing both de-quantized tensor and " @@ -236,11 +254,13 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - return super().hf_to_megatron(hf_weights[global_expert_number], megatron_module) def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: - if megatron_weights is None: - return super().megatron_to_hf(megatron_weights, megatron_module) - if len(megatron_weights.shape) == 2: - megatron_weights = megatron_weights.transpose(0, 1) - return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module) + # Megatron stores per-expert as (hidden, intermediate); HF down_proj is + # (E, intermediate, hidden). Transpose each per-expert tensor so the + # grouped-export stack assembles in HF's layout. + if megatron_weights is not None: + megatron_weights = megatron_weights.contiguous() + result = super().megatron_to_hf(megatron_weights, megatron_module) + return {k: v.t().contiguous() if v.ndim == 2 else v for k, v in result.items()} class GPTOSSMLPGateUpProjMapping(AutoMapping): From 25294b571e132c3cb0e87cc0ac6899cb42bdd93b Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 7 May 2026 17:39:27 -0700 Subject: [PATCH 2/2] [model, ckpt, test] fix: cover EP export and split toys into BF16/MXFP4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on the BF16 import-side transpose by extending the GPT-OSS ``down_proj`` export to handle the EP-aggregated path, and rewrites the toy conversion test to faithfully model both real checkpoint layouts. Bridge change (``gpt_oss_bridge.py``) - ``GPTOSSMLPDownProjMapping.megatron_to_hf`` now transposes the last two dims of any ndim>=2 weight tensor, not only 2-D ones. Under EP the parent ``gather_from_ep_ranks`` may concatenate the per-rank experts before the per-expert export hook runs, producing a 3-D ``(ep_size, hidden, intermediate)`` tensor that the previous 2-D-only guard skipped. Bias mappings (``hf_param`` ending in ``_bias``) are passed through unchanged so per-expert biases that arrive 2-D under EP are not flipped. Toy test rewrite (``test_gpt_oss_conversion.py``) - New fixture builds two toys from the same underlying weights: * BF16 toy: faithful unsloth-style layout (``gate_up_proj`` ``[E, hidden, 2*intermediate]``, ``down_proj`` ``[E, intermediate, hidden]``). * MXFP4 toy: ``*_blocks``/``*_scales`` whose ``_dequantize_mxfp4`` output equals the BF16 toy transposed per expert, matching the ``openai/gpt-oss-20b`` shipping layout. - Test parametrizes over ``source ∈ {bf16, mxfp4}`` × ``{PP=2, EP=2}``. BF16 runs the existing one-shot roundtrip; MXFP4 runs as a two-step ``convert_checkpoints_multi_gpu.py import`` then ``hf_megatron_roundtrip_multi_gpu.py --megatron-load-path`` against the BF16 toy as the reference, since the verification table cannot resolve ``down_proj``/``gate_up_proj`` keys in a quantized state dict. - ``hidden_size`` and ``intermediate_size`` are intentionally unequal so that any wrong-direction transpose surfaces as a shape mismatch (square real-model shapes silently mask layout bugs as wrong values). Verification on this branch - All 4 toy parametrizations pass: ``bf16-PP``, ``bf16-EP``, ``mxfp4-PP``, ``mxfp4-EP``. - Real model (``unsloth/gpt-oss-20b-BF16`` HF reference, TP=1): * BF16 import → forward cos sim vs HF: PP=8 0.999973, EP=8 0.999975. * MXFP4 import → forward cos sim vs HF: PP=8 0.999973, EP=8 0.999975. 
* Reload-roundtrip vs BF16 HF: 411/411 ✅ for all four (BF16/MXFP4) × (PP=8/EP=8) combinations. Signed-off-by: Chen Cui --- .../bridge/models/gpt_oss/gpt_oss_bridge.py | 13 +- .../models/gpt_oss/test_gpt_oss_conversion.py | 273 ++++++++++++++---- 2 files changed, 220 insertions(+), 66 deletions(-) diff --git a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py index f1626987de..78d8b7aab7 100644 --- a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py +++ b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py @@ -254,13 +254,18 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - return super().hf_to_megatron(hf_weights[global_expert_number], megatron_module) def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: - # Megatron stores per-expert as (hidden, intermediate); HF down_proj is - # (E, intermediate, hidden). Transpose each per-expert tensor so the - # grouped-export stack assembles in HF's layout. + # Megatron stores per-expert weight as (hidden, intermediate); HF down_proj + # weight is (E, intermediate, hidden). Transpose the last two dims so the + # grouped-export stack assembles in HF's layout. Under EP the parent's gather + # may have already cat'd across the EP group, producing a 3D (ep_size, out, in) + # tensor — handle that too. The bias has no orientation to align (per-expert + # 1-D, stacked to (E, hidden) on export), so leave bias mappings untouched. if megatron_weights is not None: megatron_weights = megatron_weights.contiguous() result = super().megatron_to_hf(megatron_weights, megatron_module) - return {k: v.t().contiguous() if v.ndim == 2 else v for k, v in result.items()} + if self.hf_param.endswith("_bias"): + return result + return {k: v.transpose(-1, -2).contiguous() if v.ndim >= 2 else v for k, v in result.items()} class GPTOSSMLPGateUpProjMapping(AutoMapping): diff --git a/tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py b/tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py index 96c2f65724..3388875f99 100644 --- a/tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py +++ b/tests/functional_tests/test_groups/models/gpt_oss/test_gpt_oss_conversion.py @@ -20,7 +20,9 @@ import pytest -# Minimal GPT-OSS config used for building a tiny local HF directory to test conversion. +# Minimal GPT-OSS config used for building tiny local HF directories to test conversion. +# hidden_size and intermediate_size are different (and both divisible by 32) so the per-expert +# down_proj/gate_up_proj are non-square and the bridge can detect orientation from shape. GPT_OSS_TOY_OVERRIDES = { "architectures": ["GptOssForCausalLM"], "hidden_size": 512, @@ -34,59 +36,158 @@ } -class TestGptOssConversion: - """Functional tests for GPT-OSS toy conversion paths.""" +def _build_toy_models(bf16_dir: Path, mxfp4_dir: Path, seed: int = 0): + """Build two faithful toy GPT-OSS checkpoints sharing the same underlying weights. - @pytest.fixture(scope="class") - def gpt_oss_toy_model_path(self, tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp("gptoss_toy_model") - model_dir = tmp_dir / "gpt_oss_toy" + BF16 toy: + Stores per-expert ``gate_up_proj`` as ``[E, hidden, 2*intermediate]`` and + ``down_proj`` as ``[E, intermediate, hidden]`` — matching ``unsloth/gpt-oss-20b-BF16`` + and what ``transformers.GptOssForCausalLM`` produces at init. 
- # Importorskip ensures test is skipped gracefully if transformers lacks GPT-OSS - transformers = pytest.importorskip("transformers") - GptOssForCausalLM = getattr(transformers, "GptOssForCausalLM", None) - GptOssConfig = getattr(transformers, "GptOssConfig", None) - if GptOssForCausalLM is None or GptOssConfig is None: - pytest.skip("transformers installation does not include GPT-OSS classes") + MXFP4 toy: + Stores ``*_blocks`` and ``*_scales`` whose dequantization (via + ``_dequantize_mxfp4``) yields the BF16 toy's per-expert values *transposed* + — i.e. ``[E, 2*intermediate, hidden]`` for gate_up_proj and ``[E, hidden, + intermediate]`` for down_proj — matching how ``openai/gpt-oss-20b`` ships. + + The two toys are built so that BF16 == dequant(MXFP4).t(-1, -2) per expert, + which means a Megatron checkpoint imported from either source must contain + identical per-expert weights, and exporting that Megatron checkpoint back to + HF format must match the BF16 toy on every tensor. + """ + import torch + from safetensors.torch import save_file + from transformers import GptOssConfig, GptOssForCausalLM + + from megatron.bridge.models.gpt_oss.gpt_oss_bridge import _dequantize_mxfp4 + + config = GptOssConfig(**GPT_OSS_TOY_OVERRIDES) + model = GptOssForCausalLM(config).bfloat16() + + e = config.num_local_experts + h = config.hidden_size + i_ = config.intermediate_size + num_layers = config.num_hidden_layers + assert h % 32 == 0 and i_ % 32 == 0, "Both dims must be divisible by MXFP4 block size 32" + + gen = torch.Generator().manual_seed(seed) + + # Per-layer MXFP4 generation, then dequantize to obtain BF16 reference values. + layer_data = [] + for _ in range(num_layers): + gu_blocks = torch.randint(0, 256, (e, 2 * i_, h // 32, 16), dtype=torch.int32, generator=gen).to(torch.uint8) + # UE8M0 scales clustered near 127 so dequantized magnitudes are O(1). + gu_scales = torch.randint(124, 130, (e, 2 * i_, h // 32), dtype=torch.int32, generator=gen).to(torch.uint8) + # Dequant returns (E, 2*intermediate, hidden) = (E, out, in). + gu_dq = _dequantize_mxfp4(gu_blocks, gu_scales) + # BF16 HF layout is (E, in, out) = (E, hidden, 2*intermediate). + gu_bf16 = gu_dq.transpose(-1, -2).contiguous() + + dn_blocks = torch.randint(0, 256, (e, h, i_ // 32, 16), dtype=torch.int32, generator=gen).to(torch.uint8) + dn_scales = torch.randint(124, 130, (e, h, i_ // 32), dtype=torch.int32, generator=gen).to(torch.uint8) + # Dequant returns (E, hidden, intermediate) = (E, out, in). + dn_dq = _dequantize_mxfp4(dn_blocks, dn_scales) + # BF16 HF layout is (E, in, out) = (E, intermediate, hidden). + dn_bf16 = dn_dq.transpose(-1, -2).contiguous() + + layer_data.append( + { + "gu_blocks": gu_blocks, + "gu_scales": gu_scales, + "gu_bf16": gu_bf16, + "dn_blocks": dn_blocks, + "dn_scales": dn_scales, + "dn_bf16": dn_bf16, + } + ) - # Build tiny config and model - config = GptOssConfig(**GPT_OSS_TOY_OVERRIDES) - model = GptOssForCausalLM(config) - if hasattr(model, "bfloat16"): - model = model.bfloat16() + # Inject the dequantized values back into the BF16 model so its on-disk weights match + # exactly what the MXFP4 path will produce after dequantization. 
+ sd = dict(model.state_dict()) + for li in range(num_layers): + sd[f"model.layers.{li}.mlp.experts.gate_up_proj"] = layer_data[li]["gu_bf16"] + sd[f"model.layers.{li}.mlp.experts.down_proj"] = layer_data[li]["dn_bf16"] + model.load_state_dict(sd) - # Save tokenizer (fallback to gpt2 tokenizer if GPT-OSS doesn't ship one) - try: - from transformers import AutoTokenizer + # ---- BF16 toy ---- + bf16_dir.mkdir(parents=True, exist_ok=True) + model.save_pretrained(bf16_dir, safe_serialization=True) + with open(bf16_dir / "config.json", "w") as f: + json.dump(model.config.to_dict(), f, indent=2) - tok = AutoTokenizer.from_pretrained("gpt2") - tok.save_pretrained(model_dir) - except Exception: - pass + # ---- MXFP4 toy ---- + mxfp4_dir.mkdir(parents=True, exist_ok=True) + mxfp4_sd = {} + for n, p in model.state_dict().items(): + if n.endswith(".mlp.experts.gate_up_proj") or n.endswith(".mlp.experts.down_proj"): + continue + mxfp4_sd[n] = p.detach().clone().contiguous() + for li in range(num_layers): + prefix = f"model.layers.{li}.mlp.experts" + mxfp4_sd[f"{prefix}.gate_up_proj_blocks"] = layer_data[li]["gu_blocks"] + mxfp4_sd[f"{prefix}.gate_up_proj_scales"] = layer_data[li]["gu_scales"] + mxfp4_sd[f"{prefix}.down_proj_blocks"] = layer_data[li]["dn_blocks"] + mxfp4_sd[f"{prefix}.down_proj_scales"] = layer_data[li]["dn_scales"] + save_file(mxfp4_sd, str(mxfp4_dir / "model.safetensors")) + with open(mxfp4_dir / "config.json", "w") as f: + json.dump(model.config.to_dict(), f, indent=2) - # Save model and config to directory - model.save_pretrained(model_dir, safe_serialization=True) + # Tokenizer (best-effort; both toys share it) + try: + from transformers import AutoTokenizer - # Also save config.json explicitly to ensure compatibility with correct torch_dtype - config_path = model_dir / "config.json" - with open(config_path, "w") as f: - json.dump(model.config.to_dict(), f, indent=2) + tok = AutoTokenizer.from_pretrained("gpt2") + tok.save_pretrained(bf16_dir) + tok.save_pretrained(mxfp4_dir) + except Exception: + pass - return str(model_dir) + +class TestGptOssConversion: + """Functional tests for GPT-OSS toy conversion paths. + + Two toys are built once per class: + - ``bf16``: faithful BF16 layout (matches ``unsloth/gpt-oss-20b-BF16``). + - ``mxfp4``: faithful MXFP4 layout (matches ``openai/gpt-oss-20b``). + + Each parallelism config is exercised against both sources. The MXFP4 path runs as a + two-step convert→roundtrip because the roundtrip script's verification table cannot + look up ``gate_up_proj``/``down_proj`` directly in a quantized state dict; instead we + import MXFP4 → save Megatron → reload Megatron → export → compare against the BF16 + toy as the reference. 
+ """ + + @pytest.fixture(scope="class") + def gpt_oss_toy_paths(self, tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("gptoss_toys") + bf16_dir = tmp_dir / "gpt_oss_toy_bf16" + mxfp4_dir = tmp_dir / "gpt_oss_toy_mxfp4" + + transformers = pytest.importorskip("transformers") + if not all(hasattr(transformers, n) for n in ("GptOssForCausalLM", "GptOssConfig")): + pytest.skip("transformers installation does not include GPT-OSS classes") + + _build_toy_models(bf16_dir, mxfp4_dir) + return {"bf16": str(bf16_dir), "mxfp4": str(mxfp4_dir)} @pytest.mark.run_only_on("GPU") @pytest.mark.parametrize( - "tp,pp,ep,test_name", + "tp,pp,ep,parallel_name", [ (1, 2, 1, "PP"), (1, 1, 2, "EP"), ], ) - def test_gpt_oss_conversion_parallelism(self, gpt_oss_toy_model_path, tmp_path, tp, pp, ep, test_name): - out_dir = tmp_path / f"gpt_oss_{test_name}" + @pytest.mark.parametrize("source", ["bf16", "mxfp4"]) + def test_gpt_oss_conversion_parallelism(self, gpt_oss_toy_paths, tmp_path, tp, pp, ep, parallel_name, source): + repo_root = Path(__file__).parent.parent.parent.parent.parent.parent + bf16_path = gpt_oss_toy_paths["bf16"] + toy_path = gpt_oss_toy_paths[source] + + out_dir = tmp_path / f"gpt_oss_{source}_{parallel_name}" out_dir.mkdir(exist_ok=True) - cmd = [ + common_dist_args = [ "python", "-m", "torch.distributed.run", @@ -98,40 +199,89 @@ def test_gpt_oss_conversion_parallelism(self, gpt_oss_toy_model_path, tmp_path, "--data-file=/opt/Megatron-Bridge/.coverage", "--source=/opt/Megatron-Bridge/", "--parallel-mode", - "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", - "--hf-model-id", - gpt_oss_toy_model_path, - "--output-dir", - str(out_dir), - "--tp", - str(tp), - "--pp", - str(pp), - "--ep", - str(ep), ] - result = subprocess.run( - cmd, - capture_output=True, - text=True, - cwd=Path(__file__).parent.parent.parent.parent.parent.parent, - ) + if source == "bf16": + # Single-step roundtrip: import + export + compare against the source. + cmd = common_dist_args + [ + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + bf16_path, + "--output-dir", + str(out_dir), + "--tp", + str(tp), + "--pp", + str(pp), + "--ep", + str(ep), + ] + result = subprocess.run(cmd, capture_output=True, text=True, cwd=repo_root) + if result.returncode != 0: + print(f"STDOUT: {result.stdout}") + print(f"STDERR: {result.stderr}") + assert result.returncode == 0, f"GPT-OSS bf16 {parallel_name} roundtrip failed with rc={result.returncode}" + else: + # Two-step: (1) import MXFP4 -> save Megatron, (2) reload Megatron and export, + # comparing exported HF tensors against the BF16 toy reference (which equals + # dequant(MXFP4) by construction). 
+ mcore_dir = tmp_path / f"mcore_mxfp4_{parallel_name}" + mcore_dir.mkdir(exist_ok=True) + import_cmd = common_dist_args + [ + "examples/conversion/convert_checkpoints_multi_gpu.py", + "import", + "--hf-model", + toy_path, + "--megatron-path", + str(mcore_dir), + "--tp", + str(tp), + "--pp", + str(pp), + "--ep", + str(ep), + ] + res = subprocess.run(import_cmd, capture_output=True, text=True, cwd=repo_root) + if res.returncode != 0: + print(f"STDOUT: {res.stdout}") + print(f"STDERR: {res.stderr}") + assert res.returncode == 0, f"GPT-OSS mxfp4 {parallel_name} import failed with rc={res.returncode}" - if result.returncode != 0: - print(f"STDOUT: {result.stdout}") - print(f"STDERR: {result.stderr}") - assert result.returncode == 0, f"GPT-OSS {test_name} conversion failed with {result.returncode}" + roundtrip_cmd = common_dist_args + [ + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + bf16_path, + "--megatron-load-path", + str(mcore_dir), + "--output-dir", + str(out_dir), + "--tp", + str(tp), + "--pp", + str(pp), + "--ep", + str(ep), + "--skip-save", + ] + res = subprocess.run(roundtrip_cmd, capture_output=True, text=True, cwd=repo_root) + if res.returncode != 0: + print(f"STDOUT: {res.stdout}") + print(f"STDERR: {res.stderr}") + assert res.returncode == 0, ( + f"GPT-OSS mxfp4 {parallel_name} roundtrip-vs-bf16 failed with rc={res.returncode}" + ) + # MXFP4 path uses --skip-save, so there's no exported HF directory to inspect; the + # roundtrip script's internal verification table is the assertion of correctness. + return - # Verify output structure - model_name = Path(gpt_oss_toy_model_path).name + # Verify output structure for the BF16 path. + model_name = Path(bf16_path).name converted_dir = out_dir / model_name assert converted_dir.exists() config_file = converted_dir / "config.json" assert config_file.exists() - # weights can be either consolidated or sharded, and in safetensors or bin weights_file_safetensors = converted_dir / "model.safetensors" weights_file_pytorch = converted_dir / "pytorch_model.bin" weights_found = weights_file_safetensors.exists() or weights_file_pytorch.exists() @@ -144,11 +294,10 @@ def test_gpt_oss_conversion_parallelism(self, gpt_oss_toy_model_path, tmp_path, with open(config_file) as f: saved = json.load(f) - # Minimal sanity checks on saved config assert saved["num_hidden_layers"] == GPT_OSS_TOY_OVERRIDES["num_hidden_layers"] assert saved["num_attention_heads"] == GPT_OSS_TOY_OVERRIDES["num_attention_heads"] assert saved.get("num_local_experts", 0) == GPT_OSS_TOY_OVERRIDES["num_local_experts"] assert saved["vocab_size"] == GPT_OSS_TOY_OVERRIDES["vocab_size"] - print(f"SUCCESS: GPT-OSS {test_name} conversion test completed successfully") + print(f"SUCCESS: GPT-OSS {source} {parallel_name} conversion test completed successfully") print(f"Converted model saved at: {converted_dir}")
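As a standalone illustration of the orientation contract the two commits enforce (a minimal sketch using plain torch tensors and made-up helper names, not code from the bridge), with non-square toy dims as in the rewritten test so that a wrong-direction transpose fails on shape instead of silently corrupting values:

    import torch

    E, HIDDEN, INTERMEDIATE = 4, 8, 16   # non-square toy sizes (the real 20B/120B have HIDDEN == INTERMEDIATE)

    # Per-expert down_proj as stored by the two HF sources:
    bf16_down = torch.randn(E, INTERMEDIATE, HIDDEN)          # BF16 checkpoints / GptOssForCausalLM init layout
    mxfp4_down = bf16_down.transpose(-1, -2).contiguous()     # layout that MXFP4 dequantization yields

    def bf16_import(hf_down, hidden, intermediate):
        """Sketch of the BF16 branch in maybe_modify_loaded_hf_weight: align to (E, hidden, intermediate)."""
        last2 = tuple(hf_down.shape[-2:])
        if last2 == (hidden, intermediate) and hidden != intermediate:
            return hf_down                                     # already Megatron-aligned
        # (intermediate, hidden), or square where shape cannot disambiguate: assume the BF16/init layout.
        return hf_down.transpose(-1, -2).contiguous()

    def export_to_hf(megatron_down):
        """Sketch of GPTOSSMLPDownProjMapping.megatron_to_hf: back to HF's (E, intermediate, hidden)."""
        return megatron_down.transpose(-1, -2).contiguous()

    megatron_down = bf16_import(bf16_down, HIDDEN, INTERMEDIATE)
    assert torch.equal(megatron_down, mxfp4_down)              # both sources meet in one Megatron layout
    assert torch.equal(export_to_hf(megatron_down), bf16_down) # export restores the HF BF16 layout

The helper names and toy sizes above are illustrative only; the actual import path also handles the ``*_blocks``/``*_scales`` dequantization and, under EP, the 3-D gathered tensor covered by the second commit.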