From 151d3dc4f729cbb78b45592665b3c3f3c1ee1b69 Mon Sep 17 00:00:00 2001
From: Malay Nagda
Date: Wed, 6 May 2026 22:57:42 +0530
Subject: [PATCH 1/2] Disable param gather overlap for Llama3.1 405B GB200 NVFP4

Signed-off-by: Malay Nagda
---
 .../performance/configs/llama/llama31_llm_pretrain.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/scripts/performance/configs/llama/llama31_llm_pretrain.py b/scripts/performance/configs/llama/llama31_llm_pretrain.py
index 440584fc84..b78e645291 100644
--- a/scripts/performance/configs/llama/llama31_llm_pretrain.py
+++ b/scripts/performance/configs/llama/llama31_llm_pretrain.py
@@ -43,6 +43,14 @@ def set_llama31_common_configs(cfg: ConfigContainer) -> None:
     cfg.ddp.grad_reduce_in_fp32 = False
 
 
+def disable_param_gather_overlap(cfg: ConfigContainer) -> None:
+    """Disable parameter-gather overlap to reduce training peak memory."""
+    cfg.ddp.overlap_param_gather = False
+    cfg.optimizer.overlap_param_gather = False
+    cfg.comm_overlap.overlap_param_gather = False
+    cfg.comm_overlap.align_param_gather = False
+
+
 def llama31_405b_pretrain_config_gb300(
     precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
 ) -> ConfigContainer:
@@ -108,6 +116,8 @@ def llama31_405b_pretrain_config_gb200(
 
     cfg.comm_overlap.tp_comm_overlap_cfg = comm_overlap_cfg
     cfg.comm_overlap.tp_comm_overlap = False if precision == "nvfp4" else cfg.comm_overlap.tp_comm_overlap
+    if precision == "nvfp4" and config_variant.lower() == "v2":
+        disable_param_gather_overlap(cfg)
 
     return cfg
 

From a8a6ed4098bd1f75a7c4aa9704361b8b40ee0d8b Mon Sep 17 00:00:00 2001
From: Malay Nagda
Date: Mon, 11 May 2026 14:10:31 +0530
Subject: [PATCH 2/2] doc string

Signed-off-by: Malay Nagda
---
 scripts/performance/configs/llama/llama31_llm_pretrain.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/scripts/performance/configs/llama/llama31_llm_pretrain.py b/scripts/performance/configs/llama/llama31_llm_pretrain.py
index b78e645291..aefb5b3c72 100644
--- a/scripts/performance/configs/llama/llama31_llm_pretrain.py
+++ b/scripts/performance/configs/llama/llama31_llm_pretrain.py
@@ -44,7 +44,11 @@ def set_llama31_common_configs(cfg: ConfigContainer) -> None:
 
 
 def disable_param_gather_overlap(cfg: ConfigContainer) -> None:
-    """Disable parameter-gather overlap to reduce training peak memory."""
+    """
+    Disable parameter-gather overlap to reduce training peak memory and avoid OOM.
+    Note: This is a workaround and should be removed once the issue is fixed.
+    See: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/3714
+    """
     cfg.ddp.overlap_param_gather = False
     cfg.optimizer.overlap_param_gather = False
     cfg.comm_overlap.overlap_param_gather = False
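
Reviewer note (not part of the patches above): a minimal, self-contained sketch of how
the gating introduced in patch 1 behaves. The ConfigContainer here is a hypothetical
stand-in built from SimpleNamespace, not the real Megatron-Bridge class; only the
attribute names mirror the diff, everything else is illustrative.

    from types import SimpleNamespace

    # Hypothetical stand-in for ConfigContainer, for illustration only.
    cfg = SimpleNamespace(
        ddp=SimpleNamespace(overlap_param_gather=True),
        optimizer=SimpleNamespace(overlap_param_gather=True),
        comm_overlap=SimpleNamespace(overlap_param_gather=True, align_param_gather=True),
    )

    def disable_param_gather_overlap(cfg):
        # Same four knobs the patch turns off to lower peak memory.
        cfg.ddp.overlap_param_gather = False
        cfg.optimizer.overlap_param_gather = False
        cfg.comm_overlap.overlap_param_gather = False
        cfg.comm_overlap.align_param_gather = False

    # Gating as in patch 1: only the NVFP4 "v2" GB200 recipe opts out of overlap.
    precision, config_variant = "nvfp4", "v2"
    if precision == "nvfp4" and config_variant.lower() == "v2":
        disable_param_gather_overlap(cfg)

    assert cfg.ddp.overlap_param_gather is False

All other precision/variant combinations leave the overlap settings untouched, so the
workaround is scoped to the one recipe that hits the OOM described in issue 3714.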