diff --git a/megatron/training/config/container.py b/megatron/training/config/container.py index 505e0527bf7..2e1ba5a652c 100644 --- a/megatron/training/config/container.py +++ b/megatron/training/config/container.py @@ -1,18 +1,31 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. import copy import os -from dataclasses import dataclass, field, fields as dataclass_fields, is_dataclass +from dataclasses import dataclass, field +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass from typing import Any, Type, TypeVar + import yaml -from omegaconf import OmegaConf -from megatron.training.config.common_config import RNGConfig, DistributedInitConfig, ProfilingConfig -from megatron.training.config.training_config import TokenizerConfig, TrainingConfig, ValidationConfig, SchedulerConfig, LoggerConfig, CheckpointConfig -from megatron.core.optimizer import OptimizerConfig -from megatron.core.msc_utils import MultiStorageClientFeature + from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig -from megatron.training.config.resilience_config import RerunStateMachineConfig, StragglerDetectionConfig -from megatron.training.config.utils import sanitize_dataclass_config +from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.core.optimizer import OptimizerConfig +from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig from megatron.training.config.instantiate_utils import InstantiationMode, instantiate +from megatron.training.config.resilience_config import ( + RerunStateMachineConfig, + StragglerDetectionConfig, +) +from megatron.training.config.training_config import ( + CheckpointConfig, + LoggerConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.training.config.utils import sanitize_dataclass_config from megatron.training.config.yaml_utils import safe_yaml_representers T = TypeVar("T", bound="ConfigContainerBase") @@ -80,6 +93,8 @@ def from_yaml(cls: Type[T], yaml_path: str, mode: InstantiationMode = Instantiat Returns: A new instance of this class initialized with the YAML file values """ + from omegaconf import OmegaConf + if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() yaml_path_exists = msc.os.path.exists(yaml_path) diff --git a/megatron/training/config/instantiate_utils.py b/megatron/training/config/instantiate_utils.py index 362d8574690..3039b514c5b 100644 --- a/megatron/training/config/instantiate_utils.py +++ b/megatron/training/config/instantiate_utils.py @@ -7,8 +7,13 @@ from textwrap import dedent from typing import Any, Callable, Sequence -from omegaconf import OmegaConf -from omegaconf._utils import is_structured_config +try: + from omegaconf import OmegaConf + from omegaconf._utils import is_structured_config + + HAVE_OMEGACONF = True +except ImportError: + HAVE_OMEGACONF = False class InstantiationException(Exception): @@ -158,6 +163,11 @@ def instantiate( or instantiation fails in STRICT mode. TypeError: If the _partial_ flag is not a boolean. """ + if not HAVE_OMEGACONF: + raise ImportError( + "omegaconf is required for config instantiation. " + "Install via `pip install omegaconf`." + ) # Return None if config is None if config is None: diff --git a/pyproject.toml b/pyproject.toml index 05a1843222e..74c7ac7992b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ training = [ "wandb", "transformers", "accelerate", + "omegaconf", ] ### 'mlm' group is deprecated. please use 'training' instead ### @@ -85,6 +86,7 @@ mlm = [ "wandb", "transformers", "accelerate", + "omegaconf", ] dev = [ diff --git a/tests/unit_tests/training/config/test_container_base.py b/tests/unit_tests/training/config/test_container_base.py index 2b87c69679f..3dbec27856a 100644 --- a/tests/unit_tests/training/config/test_container_base.py +++ b/tests/unit_tests/training/config/test_container_base.py @@ -203,7 +203,7 @@ def test_from_yaml_file_not_found(self): TestConfigContainer.from_yaml("non_existent_file.yaml") @patch("megatron.training.config.container.MultiStorageClientFeature.is_enabled") - @patch("megatron.training.config.container.OmegaConf") + @patch("omegaconf.OmegaConf") @patch("builtins.open", new_callable=mock_open) @patch("os.path.exists") def test_from_yaml_success(self, mock_exists, mock_file, mock_omegaconf, mock_msc): @@ -251,7 +251,7 @@ def test_from_yaml_with_mode(self, mock_exists, mock_msc): with patch("builtins.open", mock_open()): with patch("yaml.safe_load", return_value={}): - with patch("megatron.training.config.container.OmegaConf") as mock_omegaconf: + with patch("omegaconf.OmegaConf") as mock_omegaconf: # Mock OmegaConf methods to return expected values mock_conf = MagicMock() mock_omegaconf.create.return_value = mock_conf diff --git a/uv.lock b/uv.lock index 59527b68746..3051afe6b06 100644 --- a/uv.lock +++ b/uv.lock @@ -2629,6 +2629,7 @@ lts = [ mlm = [ { name = "accelerate" }, { name = "flask-restful" }, + { name = "omegaconf" }, { name = "sentencepiece" }, { name = "tiktoken" }, { name = "transformers" }, @@ -2644,6 +2645,7 @@ te = [ training = [ { name = "accelerate" }, { name = "flask-restful" }, + { name = "omegaconf" }, { name = "sentencepiece" }, { name = "tiktoken" }, { name = "transformers" }, @@ -2732,6 +2734,8 @@ requires-dist = [ { name = "numpy" }, { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin' and extra == 'dev'" }, { name = "nvidia-resiliency-ext", marker = "extra == 'dev'", git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git?rev=b2bb3d728a18795807d9f76c535e005a609a1b01" }, + { name = "omegaconf", marker = "extra == 'mlm'" }, + { name = "omegaconf", marker = "extra == 'training'" }, { name = "onnxscript", marker = "extra == 'dev'" }, { name = "onnxscript", marker = "extra == 'lts'" }, { name = "openai", extras = ["aiohttp"], marker = "extra == 'dev'" }, @@ -3278,16 +3282,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/a9/75/10dd1f8116a8b796cb2c737b674e02d02e80454bda953fa7e65d8c12b016/numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78", size = 18902015, upload-time = "2024-08-26T20:19:40.945Z" } @@ -4313,16 +4317,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } @@ -4371,16 +4375,16 @@ resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] dependencies = [