Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions megatron/training/config/container.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
import copy
import os
from dataclasses import dataclass, field, fields as dataclass_fields, is_dataclass
from dataclasses import dataclass, field
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import Any, Type, TypeVar

import yaml
from omegaconf import OmegaConf
from megatron.training.config.common_config import RNGConfig, DistributedInitConfig, ProfilingConfig
from megatron.training.config.training_config import TokenizerConfig, TrainingConfig, ValidationConfig, SchedulerConfig, LoggerConfig, CheckpointConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.msc_utils import MultiStorageClientFeature

from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.training.config.resilience_config import RerunStateMachineConfig, StragglerDetectionConfig
from megatron.training.config.utils import sanitize_dataclass_config
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.optimizer import OptimizerConfig
from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig
from megatron.training.config.instantiate_utils import InstantiationMode, instantiate
from megatron.training.config.resilience_config import (
RerunStateMachineConfig,
StragglerDetectionConfig,
)
from megatron.training.config.training_config import (
CheckpointConfig,
LoggerConfig,
SchedulerConfig,
TokenizerConfig,
TrainingConfig,
ValidationConfig,
)
from megatron.training.config.utils import sanitize_dataclass_config
from megatron.training.config.yaml_utils import safe_yaml_representers

T = TypeVar("T", bound="ConfigContainerBase")
Expand Down Expand Up @@ -80,6 +93,8 @@ def from_yaml(cls: Type[T], yaml_path: str, mode: InstantiationMode = Instantiat
Returns:
A new instance of this class initialized with the YAML file values
"""
from omegaconf import OmegaConf

if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
yaml_path_exists = msc.os.path.exists(yaml_path)
Expand Down
14 changes: 12 additions & 2 deletions megatron/training/config/instantiate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
from textwrap import dedent
from typing import Any, Callable, Sequence

from omegaconf import OmegaConf
from omegaconf._utils import is_structured_config
try:
from omegaconf import OmegaConf
from omegaconf._utils import is_structured_config

HAVE_OMEGACONF = True
except ImportError:
HAVE_OMEGACONF = False


class InstantiationException(Exception):
Expand Down Expand Up @@ -158,6 +163,11 @@ def instantiate(
or instantiation fails in STRICT mode.
TypeError: If the _partial_ flag is not a boolean.
"""
if not HAVE_OMEGACONF:
raise ImportError(
"omegaconf is required for config instantiation. "
"Install via `pip install omegaconf`."
)

# Return None if config is None
if config is None:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ training = [
"wandb",
"transformers",
"accelerate",
"omegaconf",
]

### 'mlm' group is deprecated. please use 'training' instead ###
Expand All @@ -85,6 +86,7 @@ mlm = [
"wandb",
"transformers",
"accelerate",
"omegaconf",
]

dev = [
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/training/config/test_container_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_from_yaml_file_not_found(self):
TestConfigContainer.from_yaml("non_existent_file.yaml")

@patch("megatron.training.config.container.MultiStorageClientFeature.is_enabled")
@patch("megatron.training.config.container.OmegaConf")
@patch("omegaconf.OmegaConf")
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.exists")
def test_from_yaml_success(self, mock_exists, mock_file, mock_omegaconf, mock_msc):
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_from_yaml_with_mode(self, mock_exists, mock_msc):

with patch("builtins.open", mock_open()):
with patch("yaml.safe_load", return_value={}):
with patch("megatron.training.config.container.OmegaConf") as mock_omegaconf:
with patch("omegaconf.OmegaConf") as mock_omegaconf:
# Mock OmegaConf methods to return expected values
mock_conf = MagicMock()
mock_omegaconf.create.return_value = mock_conf
Expand Down
22 changes: 13 additions & 9 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading