diff --git a/scripts/performance/configs/llama/llama31_llm_pretrain.py b/scripts/performance/configs/llama/llama31_llm_pretrain.py
index 440584fc84..aefb5b3c72 100644
--- a/scripts/performance/configs/llama/llama31_llm_pretrain.py
+++ b/scripts/performance/configs/llama/llama31_llm_pretrain.py
@@ -43,6 +43,18 @@ def set_llama31_common_configs(cfg: ConfigContainer) -> None:
     cfg.ddp.grad_reduce_in_fp32 = False
 
 
+def disable_param_gather_overlap(cfg: ConfigContainer) -> None:
+    """
+    Disable parameter-gather overlap to reduce peak training memory and avoid OOM.
+    Note: This is a workaround and should be removed once the underlying issue is fixed.
+    See: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/3714
+    """
+    cfg.ddp.overlap_param_gather = False
+    cfg.optimizer.overlap_param_gather = False
+    cfg.comm_overlap.overlap_param_gather = False
+    cfg.comm_overlap.align_param_gather = False
+
+
 def llama31_405b_pretrain_config_gb300(
     precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
 ) -> ConfigContainer:
@@ -108,6 +120,8 @@ def llama31_405b_pretrain_config_gb200(
     cfg.comm_overlap.tp_comm_overlap_cfg = comm_overlap_cfg
     cfg.comm_overlap.tp_comm_overlap = False if precision == "nvfp4" else cfg.comm_overlap.tp_comm_overlap
 
+    if precision == "nvfp4" and config_variant.lower() == "v2":
+        disable_param_gather_overlap(cfg)
     return cfg
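
Reviewer note, not part of the patch: a minimal sketch of how one might smoke-test the new nvfp4 "v2" gating. It assumes llama31_405b_pretrain_config_gb200 accepts the same (precision, mock, config_variant) parameters shown for the gb300 variant above, and that the file is importable from the repo root; both are assumptions, so adjust the import path and arguments to the actual module layout.

    # Hypothetical check; assumes gb200 shares the gb300 signature shown in the diff.
    from scripts.performance.configs.llama.llama31_llm_pretrain import (
        llama31_405b_pretrain_config_gb200,
    )

    cfg = llama31_405b_pretrain_config_gb200(precision="nvfp4", config_variant="v2")

    # disable_param_gather_overlap should have flipped all four overlap knobs off.
    assert cfg.ddp.overlap_param_gather is False
    assert cfg.optimizer.overlap_param_gather is False
    assert cfg.comm_overlap.overlap_param_gather is False
    assert cfg.comm_overlap.align_param_gather is False

For any other precision/variant combination (e.g. bf16, or nvfp4 with "v1"), the helper is never invoked and the overlap settings keep whatever values the base config assigns.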