diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index 874d577ae8..56592e4308 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -110,6 +110,13 @@ from megatron.bridge.models.nemotronh import ( NemotronHBridge, ) +from megatron.bridge.models.olmo2 import ( + Olmo2Bridge, + Olmo2ModelProvider, + Olmo2ModelProvider1B, + Olmo2ModelProvider7B, + Olmo2ModelProvider13B, +) from megatron.bridge.models.olmoe import ( OlMoEBridge, OlMoEModelProvider, @@ -205,6 +212,11 @@ "Ministral3ModelProvider8B", "Ministral3ModelProvider14B", "MiniMaxM2Bridge", + "Olmo2Bridge", + "Olmo2ModelProvider", + "Olmo2ModelProvider1B", + "Olmo2ModelProvider7B", + "Olmo2ModelProvider13B", "OlMoEBridge", "OlMoEModelProvider", "NemotronHBridge", diff --git a/src/megatron/bridge/models/olmo2/__init__.py b/src/megatron/bridge/models/olmo2/__init__.py new file mode 100644 index 0000000000..2153b7d694 --- /dev/null +++ b/src/megatron/bridge/models/olmo2/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bridge support for AllenAI's OLMo-2 dense causal LM family. + +OLMo-2 is the second-generation fully-open language model from the Allen +Institute. Compared to OLMo-1 and OLMoE, OLMo-2 introduces two architectural +changes that motivate this bridge: + +* **Pure post-norm placement.** The decoder block is:: + + x = x + post_attention_layernorm(self_attn(x)) + x = x + post_feedforward_layernorm(mlp(x)) + + with no `input_layernorm` or `pre_feedforward_layernorm` (vs. Llama/Qwen3 + which pre-norm both sub-blocks, or Gemma2 which sandwiches both pre and + post). See `_olmo2_layer_spec` in `olmo2_provider.py`. + +* **QK-RMSNorm** applied to the per-head Q and K projections inside attention. + Megatron-Core supports this via `qk_layernorm=True` plus mappings for the + `q_layernorm` / `k_layernorm` weights — same as Qwen3 and OLMoE. + +Reference: Yang et al., 2024, *2 OLMo 2 Furious* (https://arxiv.org/abs/2501.00656). +""" + +from megatron.bridge.models.olmo2.olmo2_bridge import Olmo2Bridge +from megatron.bridge.models.olmo2.olmo2_provider import ( + Olmo2ModelProvider, + Olmo2ModelProvider1B, + Olmo2ModelProvider7B, + Olmo2ModelProvider13B, + olmo2_layer_spec, +) + + +__all__ = [ + "Olmo2Bridge", + "Olmo2ModelProvider", + "Olmo2ModelProvider1B", + "Olmo2ModelProvider7B", + "Olmo2ModelProvider13B", + "olmo2_layer_spec", +] diff --git a/src/megatron/bridge/models/olmo2/olmo2_bridge.py b/src/megatron/bridge/models/olmo2/olmo2_bridge.py new file mode 100644 index 0000000000..736c51f297 --- /dev/null +++ b/src/megatron/bridge/models/olmo2/olmo2_bridge.py @@ -0,0 +1,145 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bridge for HuggingFace ``Olmo2ForCausalLM`` ↔ Megatron-Core ``GPTModel``.""" + +from megatron.core.models.gpt.gpt_model import GPTModel + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.olmo2.olmo2_provider import olmo2_layer_spec + + +@MegatronModelBridge.register_bridge(source="Olmo2ForCausalLM", target=GPTModel, model_type="olmo2") +class Olmo2Bridge(MegatronModelBridge): + """Bridge for AllenAI's OLMo-2 dense causal LM family. + + Architecture summary (vs. the closest existing bridges): + + +-----------------------+-------------+-------------+-------------+-------------+ + | Property | Llama | Qwen3 | Gemma2 | OLMo-2 | + +=======================+=============+=============+=============+=============+ + | Pre-attn norm | yes | yes | yes | **no** | + | Pre-MLP norm | yes | yes | yes | **no** | + | Post-attn norm | no | no | yes | **yes** | + | Post-MLP norm | no | no | yes | **yes** | + | QK-RMSNorm | no | yes | no | **yes** | + | Logit soft-capping | no | no | yes | no | + | Sliding-window attn | no | no | yes (alt.) | no | + +-----------------------+-------------+-------------+-------------+-------------+ + + The custom layer spec (see :func:`olmo2_layer_spec`) realizes the post-norm + placement. ``mapping_registry`` below names every weight in the HF + state dict and routes it to the corresponding Megatron-Core parameter. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("allenai/OLMo-2-1124-7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF OLMo-2 config to a ``GPTModelProvider`` and apply the OLMo-2 layer spec.""" + provider = super().provider_bridge(hf_pretrained) + hf_config = hf_pretrained.config + + # Pure post-norm: select the OLMo-2 specific layer spec. + provider.transformer_layer_spec = olmo2_layer_spec + + # `head_dim` is not always present in the HF config; derive it when missing. + provider.kv_channels = getattr(hf_config, "head_dim", None) or ( + hf_config.hidden_size // hf_config.num_attention_heads + ) + + # OLMo-2 specifics (all values match `Olmo2Config` defaults / 1B + 7B + 13B configs). 
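+        # Post-norm placement itself comes from `olmo2_layer_spec` above; the assignments
+        # below handle the remaining details (RMSNorm, SwiGLU, RoPE, no biases, dropout,
+        # QK-norm, untied embeddings).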
+ provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.position_embedding_type = "rope" + provider.add_bias_linear = False + provider.add_qkv_bias = False + provider.hidden_dropout = 0.0 + provider.attention_dropout = float(getattr(hf_config, "attention_dropout", 0.0)) + provider.qk_layernorm = True + provider.persist_layer_norm = True + provider.share_embeddings_and_output_weights = bool(getattr(hf_config, "tie_word_embeddings", False)) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + """Weight mappings between HF ``Olmo2ForCausalLM`` and Megatron-Core ``GPTModel``. + + Notable points specific to OLMo-2: + + * ``model.layers.*.post_attention_layernorm.weight`` and + ``model.layers.*.post_feedforward_layernorm.weight`` are *output* + norms — they map to ``linear_proj.post_layernorm`` / + ``linear_fc2.post_layernorm``, not to the standard Llama-style + slot ``linear_qkv.layer_norm_weight``. + * Q-/K-RMSNorm weights live on the per-head projections inside the + attention block — same name pattern as Qwen3 and OLMoE. + """ + # 1:1 renames (Megatron name → HF name). Wildcards expand per layer. + param_mappings = { + # Token embeddings, output projection, final norm + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + # Attention output + post-attention norm (the post-norm folded into linear_proj) + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.self_attention.linear_proj.post_layernorm.weight": ( + "model.layers.*.post_attention_layernorm.weight" + ), + # QK-RMSNorm (per-head Q/K normalization inside attention) + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", + # MLP down projection + post-feedforward norm (the post-norm folded into linear_fc2) + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.mlp.linear_fc2.post_layernorm.weight": ( + "model.layers.*.post_feedforward_layernorm.weight" + ), + } + + mapping_list = [ + AutoMapping(megatron_param=megatron_param, hf_param=hf_param) + for megatron_param, hf_param in param_mappings.items() + ] + + # Fused QKV: HF stores Q/K/V separately; Megatron uses a single packed matrix. + # OLMo-2 has no QKV bias, so only the weight is fused. + mapping_list.append( + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ) + ) + + # Gated SwiGLU MLP: HF stores gate_proj + up_proj separately; + # Megatron concatenates them into linear_fc1. + mapping_list.append( + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ) + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/src/megatron/bridge/models/olmo2/olmo2_provider.py b/src/megatron/bridge/models/olmo2/olmo2_provider.py new file mode 100644 index 0000000000..31f53e213a --- /dev/null +++ b/src/megatron/bridge/models/olmo2/olmo2_provider.py @@ -0,0 +1,200 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provider and layer spec for OLMo-2 dense causal LMs. + +OLMo-2's decoder block applies normalization *only after* each sub-block:: + + x = x + post_attention_layernorm(self_attn(x)) + x = x + post_feedforward_layernorm(mlp(x)) + +This differs from the standard Megatron pre-norm spec (which normalizes the +input of each sub-block) and from Gemma2's sandwich norm (which normalizes +both the input and the output). To realize OLMo-2 in Megatron-Core, the layer +spec built here: + +* uses ``IdentityOp`` for ``input_layernorm`` and ``pre_mlp_layernorm`` so + the pre-block normalizations are no-ops, +* wraps ``linear_proj`` and ``linear_fc2`` in ``TERowParallelLinearPostLN`` + so an RMSNorm is applied to each sub-block's output before the residual, +* keeps Q/K RMSNorm via the standard ``q_layernorm`` / ``k_layernorm`` + submodule slots (enabled by ``qk_layernorm=True`` on the provider). + +The post-LN wrapper is identical in spirit to Gemma2's +``TERowParallelLinearLayerNorm``; we redefine it locally per the project's +"keep model-specific logic in the family directory" guideline. +""" + +from dataclasses import dataclass +from typing import Callable, Union + +import torch +import torch.nn.functional as F +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TENorm, + TERowParallelLinear, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer import ( + ModuleSpec, + TransformerConfig, + TransformerLayer, + TransformerLayerSubmodules, +) +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules + +from megatron.bridge.models.gpt_provider import GPTModelProvider + + +class TERowParallelLinearPostLN(TERowParallelLinear): + """``TERowParallelLinear`` with a trailing RMSNorm applied to its output. + + Used at the output of attention (`linear_proj`) and MLP (`linear_fc2`) + sub-blocks so that OLMo-2's post-norm placement + ``residual + post_norm(sub_block(x))`` can be expressed in the standard + Megatron-Core ``TransformerLayer`` template. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: TransformerConfig, + **kwargs: object, + ) -> None: + super().__init__(input_size, output_size, config=config, **kwargs) + self.post_layernorm = TENorm(config, output_size) + + def forward(self, x: torch.Tensor): # type: ignore[override] + """Forward with a trailing RMSNorm.""" + output, bias = super().forward(x) + return self.post_layernorm(output), bias + + +def olmo2_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Layer spec for OLMo-2 dense models. 
+ + * No pre-norms (``input_layernorm`` / ``pre_mlp_layernorm`` → ``IdentityOp``) + * Post-attention RMSNorm fused into ``linear_proj`` via ``TERowParallelLinearPostLN`` + * Post-feedforward RMSNorm fused into ``linear_fc2`` via ``TERowParallelLinearPostLN`` + * QK-RMSNorm via standard ``q_layernorm`` / ``k_layernorm`` submodule slots, + activated by ``provider.qk_layernorm = True``. + """ + del config # spec is independent of the runtime config; signature kept for symmetry. + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=IdentityOp, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinearPostLN, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinearPostLN, + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + + +@dataclass +class Olmo2ModelProvider(GPTModelProvider): + """Base provider for OLMo-2 dense causal LMs. + + Architectural choices (from `allenai/OLMo-2-1124-7B/config.json` and + `transformers/models/olmo2/modeling_olmo2.py`): + + * RMSNorm with `layernorm_epsilon=1e-6` + * SwiGLU MLP (``activation_func=F.silu`` + ``gated_linear_unit=True``) + * No biases anywhere (``add_bias_linear=False``, ``add_qkv_bias=False``) + * RoPE with ``rotary_base=500000`` + * Tied embeddings = False (untied input/output) + * QK-RMSNorm (``qk_layernorm=True``) + * Pure post-norm via the custom ``olmo2_layer_spec``. + """ + + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTModelProvider"], ModuleSpec]] = olmo2_layer_spec + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + layernorm_epsilon: float = 1e-6 + rotary_base: float = 500000.0 + seq_length: int = 4096 + init_method_std: float = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + share_embeddings_and_output_weights: bool = False + persist_layer_norm: bool = True + autocast_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 + vocab_size: int = 100352 + + +@dataclass +class Olmo2ModelProvider1B(Olmo2ModelProvider): + """OLMo-2 1B (`allenai/OLMo-2-0425-1B`): 16 layers, hidden=2048, MHA=16, ffn=8192.""" + + num_layers: int = 16 + hidden_size: int = 2048 + num_attention_heads: int = 16 + num_query_groups: int = 16 + ffn_hidden_size: int = 8192 + kv_channels: int = 128 + + +@dataclass +class Olmo2ModelProvider7B(Olmo2ModelProvider): + """OLMo-2 7B (`allenai/OLMo-2-1124-7B`): 32 layers, hidden=4096, MHA=32, ffn=11008.""" + + num_layers: int = 32 + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 32 + ffn_hidden_size: int = 11008 + kv_channels: int = 128 + + +@dataclass +class Olmo2ModelProvider13B(Olmo2ModelProvider): + """OLMo-2 13B (`allenai/OLMo-2-1124-13B`): 40 layers, hidden=5120, MHA=40, ffn=13824.""" + + num_layers: int = 40 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 40 + ffn_hidden_size: int = 13824 + kv_channels: int = 128 diff --git 
a/tests/unit_tests/models/olmo2/__init__.py b/tests/unit_tests/models/olmo2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/models/olmo2/test_olmo2_bridge.py b/tests/unit_tests/models/olmo2/test_olmo2_bridge.py new file mode 100644 index 0000000000..6f996b772a --- /dev/null +++ b/tests/unit_tests/models/olmo2/test_olmo2_bridge.py @@ -0,0 +1,431 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the OLMo-2 bridge. + +These tests pin the architectural decisions documented in +``olmo2_bridge.py`` and ``olmo2_provider.py``: + +* QK-RMSNorm and post-norm placement are flagged on the provider. +* The mapping registry routes the OLMo-2-specific + ``post_attention_layernorm`` / ``post_feedforward_layernorm`` weights into + the ``linear_proj.post_layernorm`` / ``linear_fc2.post_layernorm`` slots + (NOT into ``linear_qkv.layer_norm_weight`` like Llama/Qwen3 do). +* Pre-norm slots — ``input_layernorm`` and ``pre_mlp_layernorm`` — are + intentionally absent from the registry (no HF weights map to them). +* QKV / Gated-MLP weights are fused. +* Q-/K-RMSNorm weights map by the standard ``q_norm`` / ``k_norm`` names. +""" + +from unittest.mock import Mock + +import pytest +import torch + + +try: + from transformers import Olmo2Config, Olmo2ForCausalLM + + _HAS_OLMO2 = True +except ImportError: # pragma: no cover - older transformers versions + Olmo2Config = None # type: ignore[assignment] + Olmo2ForCausalLM = None # type: ignore[assignment] + _HAS_OLMO2 = False + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.olmo2.olmo2_bridge import Olmo2Bridge +from megatron.bridge.models.olmo2.olmo2_provider import ( + Olmo2ModelProvider, + Olmo2ModelProvider1B, + Olmo2ModelProvider7B, + Olmo2ModelProvider13B, + TERowParallelLinearPostLN, + olmo2_layer_spec, +) + + +pytestmark = pytest.mark.skipif( + not _HAS_OLMO2, + reason="transformers version does not expose Olmo2Config / Olmo2ForCausalLM", +) + + +@pytest.fixture +def olmo2_7b_config_dict(): + """Mirror of `allenai/OLMo-2-1124-7B/config.json`.""" + return { + "architectures": ["Olmo2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "eos_token_id": 100257, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_type": "olmo2", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 100277, + "rms_norm_eps": 1e-06, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "use_cache": True, + "vocab_size": 100352, + } + + +@pytest.fixture 
+def olmo2_1b_config_dict(): + """Mirror of `allenai/OLMo-2-0425-1B/config.json`.""" + return { + "architectures": ["Olmo2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 2048, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "olmo2", + "num_attention_heads": 16, + "num_hidden_layers": 16, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "vocab_size": 100352, + } + + +@pytest.fixture +def olmo2_7b_config(olmo2_7b_config_dict): + return Olmo2Config(**olmo2_7b_config_dict) + + +@pytest.fixture +def olmo2_1b_config(olmo2_1b_config_dict): + return Olmo2Config(**olmo2_1b_config_dict) + + +@pytest.fixture +def mock_pretrained_olmo2_7b(olmo2_7b_config): + pretrained = Mock(spec=PreTrainedCausalLM) + pretrained.config = olmo2_7b_config + pretrained.model = Mock(spec=Olmo2ForCausalLM) + pretrained.model.dtype = torch.bfloat16 + return pretrained + + +@pytest.fixture +def mock_pretrained_olmo2_1b(olmo2_1b_config): + pretrained = Mock(spec=PreTrainedCausalLM) + pretrained.config = olmo2_1b_config + pretrained.model = Mock(spec=Olmo2ForCausalLM) + pretrained.model.dtype = torch.bfloat16 + return pretrained + + +class TestOlmo2BridgeRegistration: + """Bridge class registration and basic identity.""" + + def test_inherits_megatron_bridge(self): + assert issubclass(Olmo2Bridge, MegatronModelBridge) + + def test_source_class_name(self): + # `source` is registered as a string for environments where the + # transformers version does not export Olmo2ForCausalLM. + assert getattr(Olmo2Bridge, "_source_class_name", "Olmo2ForCausalLM") in { + "Olmo2ForCausalLM", + getattr(Olmo2Bridge, "_source_class_name", "Olmo2ForCausalLM"), + } + + +class TestOlmo2ProviderBridgeArchitecturalFlags: + """Provider config flags that are load-bearing for OLMo-2 correctness.""" + + def test_returns_gpt_provider(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert isinstance(provider, GPTModelProvider) + + def test_qk_layernorm_enabled(self, mock_pretrained_olmo2_7b): + """OLMo-2 normalizes Q and K before the attention dot product.""" + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.qk_layernorm is True + + def test_post_norm_layer_spec_selected(self, mock_pretrained_olmo2_7b): + """The bridge must swap the default pre-norm spec for the OLMo-2 post-norm one.""" + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.transformer_layer_spec is olmo2_layer_spec + + def test_no_biases(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.add_bias_linear is False + assert provider.add_qkv_bias is False + + def test_swiglu_mlp(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.gated_linear_unit is True + + def test_rmsnorm(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.normalization == "RMSNorm" + + def test_rope_position_embedding(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.position_embedding_type == "rope" + + def test_untied_word_embeddings(self, mock_pretrained_olmo2_7b): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + 
assert provider.share_embeddings_and_output_weights is False + + +class TestOlmo2ProviderBridgeShapeFields: + """Numerical config translation from HF config → provider.""" + + def test_7b_dimensions(self, mock_pretrained_olmo2_7b, olmo2_7b_config): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + assert provider.num_layers == olmo2_7b_config.num_hidden_layers == 32 + assert provider.hidden_size == olmo2_7b_config.hidden_size == 4096 + assert provider.num_attention_heads == olmo2_7b_config.num_attention_heads == 32 + assert provider.ffn_hidden_size == olmo2_7b_config.intermediate_size == 11008 + assert provider.vocab_size == olmo2_7b_config.vocab_size == 100352 + + def test_1b_dimensions(self, mock_pretrained_olmo2_1b, olmo2_1b_config): + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_1b) + assert provider.num_layers == olmo2_1b_config.num_hidden_layers == 16 + assert provider.hidden_size == olmo2_1b_config.hidden_size == 2048 + assert provider.num_attention_heads == 16 + assert provider.ffn_hidden_size == 8192 + + def test_kv_channels_derived_when_head_dim_missing(self, mock_pretrained_olmo2_7b): + """OLMo-2 HF configs do not include head_dim; the bridge must derive it.""" + provider = Olmo2Bridge().provider_bridge(mock_pretrained_olmo2_7b) + # 4096 / 32 = 128 + assert provider.kv_channels == 128 + + def test_kv_channels_uses_explicit_head_dim_when_present(self, olmo2_7b_config): + """If a future HF config grows a head_dim field, the bridge must respect it.""" + olmo2_7b_config.head_dim = 96 + pretrained = Mock(spec=PreTrainedCausalLM) + pretrained.config = olmo2_7b_config + pretrained.model = Mock(spec=Olmo2ForCausalLM) + pretrained.model.dtype = torch.bfloat16 + provider = Olmo2Bridge().provider_bridge(pretrained) + assert provider.kv_channels == 96 + + +class TestOlmo2MappingRegistry: + """Weight name routing — the load-bearing public contract of any bridge.""" + + @pytest.fixture + def registry(self): + return Olmo2Bridge().mapping_registry() + + @pytest.fixture + def hf_param_to_megatron(self, registry): + """Build a flat HF→Megatron param-name lookup for AutoMapping entries.""" + out: dict[str, str] = {} + for m in registry.mappings: + if isinstance(m, AutoMapping): + out[m.hf_param] = m.megatron_param + return out + + def test_embedding_routes_correctly(self, hf_param_to_megatron): + assert hf_param_to_megatron["model.embed_tokens.weight"] == "embedding.word_embeddings.weight" + + def test_output_layer_routes_correctly(self, hf_param_to_megatron): + assert hf_param_to_megatron["lm_head.weight"] == "output_layer.weight" + + def test_final_norm_routes_correctly(self, hf_param_to_megatron): + assert hf_param_to_megatron["model.norm.weight"] == "decoder.final_layernorm.weight" + + def test_post_attention_layernorm_routes_to_linear_proj_post_layernorm(self, hf_param_to_megatron): + """ + OLMo-2's ``post_attention_layernorm`` is an *output* norm. + It must NOT be routed into ``linear_qkv.layer_norm_weight`` (the + Llama/Qwen3 pre-MLP slot). It must go into ``linear_proj.post_layernorm.weight``. 
+ """ + target = hf_param_to_megatron["model.layers.*.post_attention_layernorm.weight"] + assert target == "decoder.layers.*.self_attention.linear_proj.post_layernorm.weight" + + def test_post_feedforward_layernorm_routes_to_linear_fc2_post_layernorm(self, hf_param_to_megatron): + target = hf_param_to_megatron["model.layers.*.post_feedforward_layernorm.weight"] + assert target == "decoder.layers.*.mlp.linear_fc2.post_layernorm.weight" + + def test_q_norm_routes_to_q_layernorm(self, hf_param_to_megatron): + target = hf_param_to_megatron["model.layers.*.self_attn.q_norm.weight"] + assert target == "decoder.layers.*.self_attention.q_layernorm.weight" + + def test_k_norm_routes_to_k_layernorm(self, hf_param_to_megatron): + target = hf_param_to_megatron["model.layers.*.self_attn.k_norm.weight"] + assert target == "decoder.layers.*.self_attention.k_layernorm.weight" + + def test_attention_output_projection_routes(self, hf_param_to_megatron): + target = hf_param_to_megatron["model.layers.*.self_attn.o_proj.weight"] + assert target == "decoder.layers.*.self_attention.linear_proj.weight" + + def test_mlp_down_projection_routes(self, hf_param_to_megatron): + target = hf_param_to_megatron["model.layers.*.mlp.down_proj.weight"] + assert target == "decoder.layers.*.mlp.linear_fc2.weight" + + def test_no_pre_attention_layernorm_mapping(self, hf_param_to_megatron): + """ + OLMo-2 has NO pre-attention norm in HF, and Megatron-side it is + IdentityOp via the layer spec. There must be no mapping that tries + to populate ``linear_qkv.layer_norm_weight`` from ``input_layernorm``. + """ + assert "model.layers.*.input_layernorm.weight" not in hf_param_to_megatron + for hf_param, mg_param in hf_param_to_megatron.items(): + assert "linear_qkv.layer_norm_weight" not in mg_param, ( + f"OLMo-2 must not write into linear_qkv.layer_norm_weight; saw {hf_param} -> {mg_param}" + ) + + def test_no_pre_feedforward_layernorm_mapping(self, hf_param_to_megatron): + """Mirror of the above for the pre-MLP slot.""" + for hf_param, mg_param in hf_param_to_megatron.items(): + assert "linear_fc1.layer_norm_weight" not in mg_param, ( + f"OLMo-2 must not write into linear_fc1.layer_norm_weight; saw {hf_param} -> {mg_param}" + ) + + def test_qkv_fused_mapping_present(self, registry): + qkv_maps = [m for m in registry.mappings if isinstance(m, QKVMapping)] + assert len(qkv_maps) == 1 + m = qkv_maps[0] + assert m.megatron_param == "decoder.layers.*.self_attention.linear_qkv.weight" + assert m.q == "model.layers.*.self_attn.q_proj.weight" + assert m.k == "model.layers.*.self_attn.k_proj.weight" + assert m.v == "model.layers.*.self_attn.v_proj.weight" + + def test_gated_mlp_fused_mapping_present(self, registry): + gated_maps = [m for m in registry.mappings if isinstance(m, GatedMLPMapping)] + assert len(gated_maps) == 1 + m = gated_maps[0] + assert m.megatron_param == "decoder.layers.*.mlp.linear_fc1.weight" + assert m.gate == "model.layers.*.mlp.gate_proj.weight" + assert m.up == "model.layers.*.mlp.up_proj.weight" + + def test_no_qkv_bias_mapping(self, hf_param_to_megatron): + """OLMo-2 has ``attention_bias=False``; no bias weight should be mapped.""" + for hf_param, mg_param in hf_param_to_megatron.items(): + assert "self_attn.q_proj.bias" not in hf_param + assert "self_attn.k_proj.bias" not in hf_param + assert "self_attn.v_proj.bias" not in hf_param + assert "linear_qkv.bias" not in mg_param + + +class TestOlmo2LayerSpec: + """Verify the post-norm layer spec is structured correctly.""" + + @pytest.fixture + def spec(self): + return 
olmo2_layer_spec(config=None) + + def test_pre_attention_norm_is_identity(self, spec): + from megatron.core.transformer.identity_op import IdentityOp + + assert spec.submodules.input_layernorm is IdentityOp + + def test_pre_mlp_norm_is_identity(self, spec): + from megatron.core.transformer.identity_op import IdentityOp + + assert spec.submodules.pre_mlp_layernorm is IdentityOp + + def test_attention_uses_post_layernorm_linear_proj(self, spec): + attn = spec.submodules.self_attention + assert attn.submodules.linear_proj is TERowParallelLinearPostLN + + def test_mlp_uses_post_layernorm_fc2(self, spec): + mlp = spec.submodules.mlp + assert mlp.submodules.linear_fc2 is TERowParallelLinearPostLN + + def test_attention_has_qk_layernorm_slots(self, spec): + from megatron.core.extensions.transformer_engine import TENorm + + attn = spec.submodules.self_attention + assert attn.submodules.q_layernorm is TENorm + assert attn.submodules.k_layernorm is TENorm + + def test_attention_uses_plain_column_parallel_qkv(self, spec): + """Pre-attention norm is IdentityOp ⇒ linear_qkv must NOT carry a fused norm.""" + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear + + attn = spec.submodules.self_attention + assert attn.submodules.linear_qkv is TEColumnParallelLinear + + def test_mlp_uses_plain_column_parallel_fc1(self, spec): + """Pre-MLP norm is IdentityOp ⇒ linear_fc1 must NOT carry a fused norm.""" + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear + + mlp = spec.submodules.mlp + assert mlp.submodules.linear_fc1 is TEColumnParallelLinear + + +class TestOlmo2ModelProviderSizeVariants: + """Hardcoded size variants used by recipes — keep in sync with HF defaults.""" + + def test_1b_dimensions(self): + p = Olmo2ModelProvider1B() + assert p.num_layers == 16 + assert p.hidden_size == 2048 + assert p.num_attention_heads == 16 + assert p.ffn_hidden_size == 8192 + + def test_7b_dimensions(self): + p = Olmo2ModelProvider7B() + assert p.num_layers == 32 + assert p.hidden_size == 4096 + assert p.num_attention_heads == 32 + assert p.ffn_hidden_size == 11008 + + def test_13b_dimensions(self): + p = Olmo2ModelProvider13B() + assert p.num_layers == 40 + assert p.hidden_size == 5120 + assert p.num_attention_heads == 40 + assert p.ffn_hidden_size == 13824 + + @pytest.mark.parametrize("cls", [Olmo2ModelProvider1B, Olmo2ModelProvider7B, Olmo2ModelProvider13B]) + def test_all_size_variants_inherit_olmo2_defaults(self, cls): + p = cls() + assert p.qk_layernorm is True + assert p.normalization == "RMSNorm" + assert p.gated_linear_unit is True + assert p.add_bias_linear is False + assert p.add_qkv_bias is False + assert p.layernorm_epsilon == 1e-6 + assert p.rotary_base == 500000.0 + assert p.share_embeddings_and_output_weights is False + assert p.transformer_layer_spec is olmo2_layer_spec + + +class TestOlmo2ProviderBaseDefaults: + """The `Olmo2ModelProvider` base picks up OLMo-2 conventions even before recipe sizing.""" + + def test_base_provider_layer_spec(self): + p = Olmo2ModelProvider() + assert p.transformer_layer_spec is olmo2_layer_spec + + def test_base_provider_persist_layer_norm(self): + # OLMoE sets this; OLMo-2 follows the same convention. + p = Olmo2ModelProvider() + assert p.persist_layer_norm is True
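+
+
+class TestOlmo2MappingRegistryConsistency:
+    """Illustrative consistency sketch, not an exhaustive contract test.
+
+    Assumes only what the tests above already rely on: ``registry.mappings``
+    exposes the ``AutoMapping`` entries, and ``olmo2_layer_spec`` ignores its
+    ``config`` argument.
+    """
+
+    def test_auto_mappings_do_not_collide(self):
+        # Duplicate names on either side of a 1:1 mapping would silently
+        # overwrite weights during conversion, so names must be unique.
+        registry = Olmo2Bridge().mapping_registry()
+        autos = [m for m in registry.mappings if isinstance(m, AutoMapping)]
+        megatron_names = [m.megatron_param for m in autos]
+        hf_names = [m.hf_param for m in autos]
+        assert len(set(megatron_names)) == len(megatron_names)
+        assert len(set(hf_names)) == len(hf_names)
+
+    def test_layer_spec_is_config_independent(self):
+        # `olmo2_layer_spec` deliberately discards its argument (see
+        # olmo2_provider.py), so building it with None and with a concrete
+        # provider must yield the same structure.
+        spec_none = olmo2_layer_spec(config=None)
+        spec_7b = olmo2_layer_spec(Olmo2ModelProvider7B())
+        assert spec_none.module is spec_7b.module
+        assert spec_none.submodules.input_layernorm is spec_7b.submodules.input_layernorm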