scripts/performance/configs/llama/llama31_llm_pretrain.py (14 additions & 0 deletions)
@@ -43,6 +43,18 @@ def set_llama31_common_configs(cfg: ConfigContainer) -> None:
     cfg.ddp.grad_reduce_in_fp32 = False
 
 
+def disable_param_gather_overlap(cfg: ConfigContainer) -> None:
+    """
+    Disable parameter-gather overlap to reduce training peak memory and avoid OOM.
+    Note: This is a workaround and should be removed once the issue is fixed.
+    See: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/3714
+    """
+    cfg.ddp.overlap_param_gather = False
+    cfg.optimizer.overlap_param_gather = False
+    cfg.comm_overlap.overlap_param_gather = False
+    cfg.comm_overlap.align_param_gather = False
+
+
 def llama31_405b_pretrain_config_gb300(
     precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
 ) -> ConfigContainer:
@@ -108,6 +120,8 @@ def llama31_405b_pretrain_config_gb200(
 
     cfg.comm_overlap.tp_comm_overlap_cfg = comm_overlap_cfg
     cfg.comm_overlap.tp_comm_overlap = False if precision == "nvfp4" else cfg.comm_overlap.tp_comm_overlap
+    if precision == "nvfp4" and config_variant.lower() == "v2":
+        disable_param_gather_overlap(cfg)
Contributor:
@malay-nagda can you comment on why this needs to be disabled?
Contributor Author (malay-nagda), May 11, 2026:
 
     return cfg
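
For reference, a minimal usage sketch of when this workaround takes effect: only the nvfp4 "v2" variant routes through disable_param_gather_overlap, so other precision/variant combinations keep the recipe's default overlap settings. The import path below is inferred from the file location in this PR and may differ in practice.

# Hypothetical usage sketch; the module path is assumed from the file
# location above and is not verified against the repo layout.
from scripts.performance.configs.llama.llama31_llm_pretrain import (
    llama31_405b_pretrain_config_gb200,
)

# nvfp4 + v2: the workaround fires and all four overlap flags are forced off.
cfg = llama31_405b_pretrain_config_gb200(precision="nvfp4", config_variant="v2")
assert cfg.ddp.overlap_param_gather is False
assert cfg.optimizer.overlap_param_gather is False
assert cfg.comm_overlap.overlap_param_gather is False
assert cfg.comm_overlap.align_param_gather is False

# bf16 (or nvfp4 with the "v1" variant) leaves the overlap settings untouched.
cfg_default = llama31_405b_pretrain_config_gb200(precision="bf16")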