From 1c42640f6180126641bde9380bb60c943a0bd9d3 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 7 May 2026 15:09:49 -0700 Subject: [PATCH 1/5] add to deps Signed-off-by: Maanu Grover --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 05a1843222e..74c7ac7992b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ training = [ "wandb", "transformers", "accelerate", + "omegaconf", ] ### 'mlm' group is deprecated. please use 'training' instead ### @@ -85,6 +86,7 @@ mlm = [ "wandb", "transformers", "accelerate", + "omegaconf", ] dev = [ From 49a91494f2df6d40c9e47ef696ba2fceef195639 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 7 May 2026 15:10:57 -0700 Subject: [PATCH 2/5] guard omegaconf imports Signed-off-by: Maanu Grover --- megatron/training/config/container.py | 3 ++- megatron/training/config/instantiate_utils.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/megatron/training/config/container.py b/megatron/training/config/container.py index 505e0527bf7..8c87e3e8f11 100644 --- a/megatron/training/config/container.py +++ b/megatron/training/config/container.py @@ -4,7 +4,6 @@ from dataclasses import dataclass, field, fields as dataclass_fields, is_dataclass from typing import Any, Type, TypeVar import yaml -from omegaconf import OmegaConf from megatron.training.config.common_config import RNGConfig, DistributedInitConfig, ProfilingConfig from megatron.training.config.training_config import TokenizerConfig, TrainingConfig, ValidationConfig, SchedulerConfig, LoggerConfig, CheckpointConfig from megatron.core.optimizer import OptimizerConfig @@ -80,6 +79,8 @@ def from_yaml(cls: Type[T], yaml_path: str, mode: InstantiationMode = Instantiat Returns: A new instance of this class initialized with the YAML file values """ + from omegaconf import OmegaConf + if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() yaml_path_exists = msc.os.path.exists(yaml_path) diff --git a/megatron/training/config/instantiate_utils.py b/megatron/training/config/instantiate_utils.py index 362d8574690..3039b514c5b 100644 --- a/megatron/training/config/instantiate_utils.py +++ b/megatron/training/config/instantiate_utils.py @@ -7,8 +7,13 @@ from textwrap import dedent from typing import Any, Callable, Sequence -from omegaconf import OmegaConf -from omegaconf._utils import is_structured_config +try: + from omegaconf import OmegaConf + from omegaconf._utils import is_structured_config + + HAVE_OMEGACONF = True +except ImportError: + HAVE_OMEGACONF = False class InstantiationException(Exception): @@ -158,6 +163,11 @@ def instantiate( or instantiation fails in STRICT mode. TypeError: If the _partial_ flag is not a boolean. """ + if not HAVE_OMEGACONF: + raise ImportError( + "omegaconf is required for config instantiation. " + "Install via `pip install omegaconf`." + ) # Return None if config is None if config is None: From 8e36bf30f6c9028ba84dfe2d3a64a2117076b693 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 7 May 2026 15:17:50 -0700 Subject: [PATCH 3/5] formatting Signed-off-by: Maanu Grover --- megatron/training/config/container.py | 28 ++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/megatron/training/config/container.py b/megatron/training/config/container.py index 8c87e3e8f11..2e1ba5a652c 100644 --- a/megatron/training/config/container.py +++ b/megatron/training/config/container.py @@ -1,17 +1,31 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. import copy import os -from dataclasses import dataclass, field, fields as dataclass_fields, is_dataclass +from dataclasses import dataclass, field +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass from typing import Any, Type, TypeVar + import yaml -from megatron.training.config.common_config import RNGConfig, DistributedInitConfig, ProfilingConfig -from megatron.training.config.training_config import TokenizerConfig, TrainingConfig, ValidationConfig, SchedulerConfig, LoggerConfig, CheckpointConfig -from megatron.core.optimizer import OptimizerConfig -from megatron.core.msc_utils import MultiStorageClientFeature + from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig -from megatron.training.config.resilience_config import RerunStateMachineConfig, StragglerDetectionConfig -from megatron.training.config.utils import sanitize_dataclass_config +from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.core.optimizer import OptimizerConfig +from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig from megatron.training.config.instantiate_utils import InstantiationMode, instantiate +from megatron.training.config.resilience_config import ( + RerunStateMachineConfig, + StragglerDetectionConfig, +) +from megatron.training.config.training_config import ( + CheckpointConfig, + LoggerConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.training.config.utils import sanitize_dataclass_config from megatron.training.config.yaml_utils import safe_yaml_representers T = TypeVar("T", bound="ConfigContainerBase") From f07d348766727c05be783671457f36e7878fcae1 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 7 May 2026 15:28:23 -0700 Subject: [PATCH 4/5] update lockfile Signed-off-by: Maanu Grover --- uv.lock | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/uv.lock b/uv.lock index 59527b68746..3051afe6b06 100644 --- a/uv.lock +++ b/uv.lock @@ -2629,6 +2629,7 @@ lts = [ mlm = [ { name = "accelerate" }, { name = "flask-restful" }, + { name = "omegaconf" }, { name = "sentencepiece" }, { name = "tiktoken" }, { name = "transformers" }, @@ -2644,6 +2645,7 @@ te = [ training = [ { name = "accelerate" }, { name = "flask-restful" }, + { name = "omegaconf" }, { name = "sentencepiece" }, { name = "tiktoken" }, { name = "transformers" }, @@ -2732,6 +2734,8 @@ requires-dist = [ { name = "numpy" }, { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin' and extra == 'dev'" }, { name = "nvidia-resiliency-ext", marker = "extra == 'dev'", git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git?rev=b2bb3d728a18795807d9f76c535e005a609a1b01" }, + { name = "omegaconf", marker = "extra == 'mlm'" }, + { name = "omegaconf", marker = "extra == 'training'" }, { name = "onnxscript", marker = "extra == 'dev'" }, { name = "onnxscript", marker = "extra == 'lts'" }, { name = "openai", extras = ["aiohttp"], marker = "extra == 'dev'" }, @@ -3278,16 +3282,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/a9/75/10dd1f8116a8b796cb2c737b674e02d02e80454bda953fa7e65d8c12b016/numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78", size = 18902015, upload-time = "2024-08-26T20:19:40.945Z" } @@ -4313,16 +4317,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } @@ -4371,16 +4375,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] dependencies = [ From ead17605bc809996440f5170a7362124648829fc Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 7 May 2026 18:58:46 -0700 Subject: [PATCH 5/5] fix mock imports Signed-off-by: Maanu Grover --- tests/unit_tests/training/config/test_container_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/training/config/test_container_base.py b/tests/unit_tests/training/config/test_container_base.py index 2b87c69679f..3dbec27856a 100644 --- a/tests/unit_tests/training/config/test_container_base.py +++ b/tests/unit_tests/training/config/test_container_base.py @@ -203,7 +203,7 @@ def test_from_yaml_file_not_found(self): TestConfigContainer.from_yaml("non_existent_file.yaml") @patch("megatron.training.config.container.MultiStorageClientFeature.is_enabled") - @patch("megatron.training.config.container.OmegaConf") + @patch("omegaconf.OmegaConf") @patch("builtins.open", new_callable=mock_open) @patch("os.path.exists") def test_from_yaml_success(self, mock_exists, mock_file, mock_omegaconf, mock_msc): @@ -251,7 +251,7 @@ def test_from_yaml_with_mode(self, mock_exists, mock_msc): with patch("builtins.open", mock_open()): with patch("yaml.safe_load", return_value={}): - with patch("megatron.training.config.container.OmegaConf") as mock_omegaconf: + with patch("omegaconf.OmegaConf") as mock_omegaconf: # Mock OmegaConf methods to return expected values mock_conf = MagicMock() mock_omegaconf.create.return_value = mock_conf