4 changes: 4 additions & 0 deletions src/megatron/bridge/models/conversion/model_bridge.py
@@ -1417,6 +1417,10 @@ def build_conversion_tasks(
         if not isinstance(hf_pretrained, PretrainedConfig) and not has_hf_state:
             raise ValueError("hf_pretrained.state.source is required for weight ordering")
 
+        # Stash for subclass hooks (e.g. ``maybe_modify_loaded_hf_weight``) that need access
+        # to the source HF config to disambiguate ambiguous tensor shapes.
+        self.hf_pretrained = hf_pretrained
+
         hf_keys: Optional[Iterable[str]] = hf_pretrained.state.source.get_all_keys() if has_hf_state else None
 
         mapping_registry = self.mapping_registry()
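For context on how the stashed handle gets used, here is a minimal self-contained sketch of the hook pattern. The DemoBridge class, its constructor wiring, the parameter key, and the orientation check are all illustrative assumptions; only the maybe_modify_loaded_hf_weight hook name and the self.hf_pretrained attribute come from the diff above.

import types

import torch


class DemoBridge:
    """Hypothetical stand-in for a bridge subclass; illustrative only."""

    def __init__(self, hf_pretrained):
        # In the real code path, build_conversion_tasks() stashes this attribute;
        # here we set it directly so the sketch is self-contained.
        self.hf_pretrained = hf_pretrained

    def maybe_modify_loaded_hf_weight(self, hf_param: str, hf_state_dict: dict) -> torch.Tensor:
        # The stashed handle lets a hook consult the source HF config when the
        # tensor shape alone cannot tell which orientation a weight is stored in.
        cfg = self.hf_pretrained.config
        weight = hf_state_dict[hf_param]
        if weight.ndim == 2 and weight.shape == (cfg.intermediate_size, cfg.hidden_size):
            weight = weight.transpose(0, 1).contiguous()
        return weight


# Self-contained usage with a fake config and state dict.
fake = types.SimpleNamespace(config=types.SimpleNamespace(hidden_size=4, intermediate_size=8))
bridge = DemoBridge(fake)
sd = {"layers.0.mlp.down_proj.weight": torch.randn(8, 4)}
out = bridge.maybe_modify_loaded_hf_weight("layers.0.mlp.down_proj.weight", sd)
assert out.shape == (4, 8)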
49 changes: 37 additions & 12 deletions src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
@@ -106,9 +106,15 @@ def maybe_modify_loaded_hf_weight(
     ) -> torch.Tensor:
         """Load weights from HuggingFace state dict with MXFP4 dequantization support.
 
-        down_proj expert weights are stored transposed vs Megatron (HF: [in, out], Megatron: [out, in]).
-        We transpose them once here so that GPTOSSMLPDownProjMapping.hf_to_megatron can treat the
-        per-expert slice as-is, and megatron_to_hf symmetrically transposes back on export.
+        Per-expert ``down_proj`` is square for GPT-OSS-20B/120B (hidden == intermediate), so
+        the bridge cannot auto-detect orientation from shape alone. BF16 checkpoints (e.g.
+        ``unsloth/gpt-oss-20b-BF16``, and what ``transformers.GptOssForCausalLM`` produces at
+        init) store it as ``[E, intermediate, hidden]``, matching ``gate_up_proj``'s
+        ``[E, hidden, 2*intermediate]`` convention. MXFP4-dequantized weights come out as
+        ``[E, hidden, intermediate]``. Megatron's TE ``RowParallelGroupedLinear`` expects
+        per-expert ``(hidden, intermediate)``, so the BF16 path needs a transpose here while
+        the MXFP4 path is already aligned. Without this, BF16 imports silently store down_proj
+        in the wrong orientation and inference is broken.
 
         gate_up_proj is handled directly in GPTOSSMLPGateUpProjMapping.hf_to_megatron via
         _align_expert_weight_to_shape, which auto-detects the orientation difference between
@@ -118,15 +124,27 @@
         if isinstance(hf_param, str):
             if hf_param in hf_state_dict:
                 hf_weights = hf_state_dict[hf_param]
-                if ".mlp.experts.down_proj" in hf_param and hf_weights.ndim == 3:
-                    hf_weights = hf_weights.transpose(-1, -2)
+                if hf_param.endswith(".mlp.experts.down_proj") and hf_weights.ndim == 3:
+                    cfg = self.hf_pretrained.config
+                    hidden = cfg.hidden_size
+                    intermediate = cfg.intermediate_size
+                    last2 = tuple(hf_weights.shape[-2:])
+                    if last2 == (intermediate, hidden) and intermediate != hidden:
+                        # Unambiguous BF16 layout (E, intermediate, hidden); transpose to (E, hidden, intermediate).
+                        hf_weights = hf_weights.transpose(-1, -2).contiguous()
+                    elif last2 == (hidden, intermediate) and intermediate != hidden:
+                        # Already aligned with Megatron — no-op.
+                        pass
+                    elif intermediate == hidden:
+                        # Square: HF GptOssForCausalLM init produces (E, intermediate, hidden), so a plain BF16
+                        # checkpoint is in that layout. Transpose to (E, hidden, intermediate) for Megatron.
+                        hf_weights = hf_weights.transpose(-1, -2).contiguous()
                 return hf_weights
             blocks_key = hf_param + "_blocks"
             scales_key = hf_param + "_scales"
             if blocks_key in hf_state_dict and scales_key in hf_state_dict:
                 hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
-                if ".mlp.experts.down_proj" in hf_param and hf_weights.ndim == 3:
-                    hf_weights = hf_weights.transpose(-1, -2)
+                # MXFP4 dequant already emits [E, hidden, intermediate] for down_proj — leave as-is.
                 return hf_weights
         raise KeyError(
             f"Cannot locate weights for '{hf_param}'. Missing both de-quantized tensor and "
@@ -236,11 +254,18 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -
         return super().hf_to_megatron(hf_weights[global_expert_number], megatron_module)
 
     def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]:
-        if megatron_weights is None:
-            return super().megatron_to_hf(megatron_weights, megatron_module)
-        if len(megatron_weights.shape) == 2:
-            megatron_weights = megatron_weights.transpose(0, 1)
-        return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module)
+        # Megatron stores per-expert weight as (hidden, intermediate); HF down_proj
+        # weight is (E, intermediate, hidden). Transpose the last two dims so the
+        # grouped-export stack assembles in HF's layout. Under EP the parent's gather
+        # may have already cat'd across the EP group, producing a 3D (ep_size, out, in)
+        # tensor — handle that too. The bias has no orientation to align (per-expert
+        # 1-D, stacked to (E, hidden) on export), so leave bias mappings untouched.
+        if megatron_weights is not None:
+            megatron_weights = megatron_weights.contiguous()
+        result = super().megatron_to_hf(megatron_weights, megatron_module)
+        if self.hf_param.endswith("_bias"):
+            return result
+        return {k: v.transpose(-1, -2).contiguous() if v.ndim >= 2 else v for k, v in result.items()}
 
 
 class GPTOSSMLPGateUpProjMapping(AutoMapping):
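Finally, a small self-contained sketch of the export-path shape handling described in the megatron_to_hf comment above; the dimensions are toy placeholders and the explicit stacking stands in for the per-expert assembly the parent mapping actually performs.

import torch

E, hidden, intermediate = 2, 8, 4  # toy dims for illustration

# Per-expert Megatron down_proj weights are (hidden, intermediate). Stacking the
# experts and transposing the last two dims recovers HF's (E, intermediate, hidden).
megatron_per_expert = [torch.randn(hidden, intermediate) for _ in range(E)]
hf_down_proj = torch.stack(megatron_per_expert).transpose(-1, -2).contiguous()
assert hf_down_proj.shape == (E, intermediate, hidden)

# The bias is 1-D per expert and simply stacks to (E, hidden) on export, so it
# has no orientation to fix and the mapping returns it untouched.
megatron_bias = [torch.randn(hidden) for _ in range(E)]
hf_down_proj_bias = torch.stack(megatron_bias)
assert hf_down_proj_bias.shape == (E, hidden)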