From 3f3abd0b6f0aa31f4d69d764672736f69e0017f0 Mon Sep 17 00:00:00 2001
From: Gautham Kollu
Date: Wed, 6 May 2026 01:03:17 -0700
Subject: [PATCH] FSDP GB300 DS-V3 Config with fine-grained activation
 offloading

Signed-off-by: Gautham Kollu
---
 .../performance/configs/deepseek/__init__.py       | 10 +++++
 .../deepseek_workload_base_configs.py              | 37 +++++++++++++++++++
 scripts/performance/run_script.py                  |  2 +
 scripts/performance/utils/overrides.py             | 15 ++++++--
 scripts/performance/utils/utils.py                 |  8 ++++
 5 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/scripts/performance/configs/deepseek/__init__.py b/scripts/performance/configs/deepseek/__init__.py
index 647225b99c..ab219a7a84 100644
--- a/scripts/performance/configs/deepseek/__init__.py
+++ b/scripts/performance/configs/deepseek/__init__.py
@@ -43,13 +43,18 @@
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_FP8_MX_V2,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_NVFP4_V1,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_NVFP4_V2,
+    # FSDP
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_FSDP,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_V1,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_V2,
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_FSDP,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_V1,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_V2,
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_LARGE_SCALE,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_V1,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_V2,
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_FSDP,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V1,
     DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V2,
     DEEPSEEK_V3_PRETRAIN_CONFIG_H100_BF16_V1,
@@ -125,6 +130,11 @@
     "DEEPSEEK_V3_PRETRAIN_CONFIG_B300_FP8_MX_LARGE_SCALE",
     "DEEPSEEK_V3_PRETRAIN_CONFIG_B200_FP8_MX_LARGE_SCALE",
     "DEEPSEEK_V3_PRETRAIN_CONFIG_H100_FP8_SC_LARGE_SCALE",
+    # FSDP
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_FSDP",
 ]
 
 if HAVE_MEGATRON_BRIDGE:
diff --git a/scripts/performance/configs/deepseek/deepseek_workload_base_configs.py b/scripts/performance/configs/deepseek/deepseek_workload_base_configs.py
index adc8ba617e..60c4ca0c2f 100644
--- a/scripts/performance/configs/deepseek/deepseek_workload_base_configs.py
+++ b/scripts/performance/configs/deepseek/deepseek_workload_base_configs.py
@@ -19,6 +19,7 @@
 V1: GBS=2048 for Blackwell variants, GBS=8192 for H100
 V2: GBS=4096 for Blackwell variants, GBS=16384 for H100
+FSDP: Megatron FSDP, no PP, GBS=256 for GB300 (64 GPUs)
 
 Use --config_variant to select a variant.
 Use --list_config_variants to see available variants interactively.
 
@@ -259,6 +260,37 @@
 )
 
 
+# =============================================================================
+# DeepSeek V3 Pretrain - FSDP (Megatron FSDP, no PP, GBS=256 for GB300)
+# =============================================================================
+
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FSDP = replace(
+    BASE_DEEPSEEK_V3_CONFIG,
+    num_gpus=64,
+    global_batch_size=256,
+    micro_batch_size=2,
+    pipeline_model_parallel_size=1,
+    expert_model_parallel_size=64,
+    use_megatron_fsdp=True,
+    moe_flex_dispatcher_backend="hybridep",
+    moe_a2a_overlap=False,
+    cuda_graph_scope=[],
+    recompute_modules=["layernorm", "mla_up_proj", "moe_act"],
+    fine_grained_activation_offloading=True,
+    offload_modules=["core_attn", "attn_proj"],
+    fp8_param_gather=True,
+    reuse_grad_buf_for_mxfp8_param_ag=True,
+)
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_FSDP = replace(
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FSDP,
+    fp8_param_gather=None,
+    reuse_grad_buf_for_mxfp8_param_ag=None,
+)
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_FSDP = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FSDP
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FSDP
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_FSDP = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FSDP
+
+
 # =============================================================================
 # DeepSeek V3 Pretrain - Large Scale Proxy
 # =============================================================================
@@ -344,6 +376,11 @@
     "DEEPSEEK_V3_PRETRAIN_CONFIG_VR200_FP8_CS_V2",
     "DEEPSEEK_V3_PRETRAIN_CONFIG_VR200_FP8_MX_V2",
     "DEEPSEEK_V3_PRETRAIN_CONFIG_VR200_NVFP4_V2",
+    # FSDP (Megatron FSDP, GBS=256 for GB300)
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP",
+    "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_FSDP",
     # Large Scale Proxy
     "DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_LARGE_SCALE",
     "DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_FP8_MX_LARGE_SCALE",
diff --git a/scripts/performance/run_script.py b/scripts/performance/run_script.py
index d6b1b4740c..5c797b45af 100644
--- a/scripts/performance/run_script.py
+++ b/scripts/performance/run_script.py
@@ -69,6 +69,8 @@ def main():
     parser = parse_cli_args()
     args, cli_overrides = parser.parse_known_args()
 
+    # Enable the v1 CPU offload path required for fine-grained activation offloading
+    os.environ["NVTE_CPU_OFFLOAD_V1"] = "1"
     if args.dump_env:
         _dump_env_rank0()
 
diff --git a/scripts/performance/utils/overrides.py b/scripts/performance/utils/overrides.py
index 5c64aed563..9ccd48381e 100644
--- a/scripts/performance/utils/overrides.py
+++ b/scripts/performance/utils/overrides.py
@@ -94,9 +94,9 @@ def _set_megatron_fsdp_overrides(recipe: ConfigContainer, use_megatron_fsdp: boo
         logger.warning("Disabling deferring embedding wgrad compute because it cannot work with FSDP together.")
         recipe.comm_overlap.defer_embedding_wgrad_compute = False
 
-    if recipe.optimizer.use_precision_aware_optimizer:
-        recipe.optimizer.use_precision_aware_optimizer = False
-        logger.warning("Disabling precision aware optimizer because it cannot work with FSDP together.")
+    # if recipe.optimizer.use_precision_aware_optimizer:
+    #     recipe.optimizer.use_precision_aware_optimizer = False
+    #     logger.warning("Disabling precision aware optimizer because it cannot work with FSDP together.")
 
     recipe.checkpoint.load = None
     return recipe
@@ -235,6 +235,15 @@ def set_workload_base_configs(cfg: ConfigContainer, settings: WorkloadBaseConfig
         cfg.model.quant_recipe = load_quantization_recipe(settings.te_precision_config_file)
     _set_common_perf_overrides(cfg)
 
+    if settings.fine_grained_activation_offloading is not None:
+        cfg.model.fine_grained_activation_offloading = settings.fine_grained_activation_offloading
+    if settings.offload_modules is not None:
+        cfg.model.offload_modules = settings.offload_modules
+    if settings.fp8_param_gather is not None:
+        cfg.mixed_precision.fp8_param_gather = settings.fp8_param_gather
+    if settings.reuse_grad_buf_for_mxfp8_param_ag is not None:
+        cfg.mixed_precision.reuse_grad_buf_for_mxfp8_param_ag = settings.reuse_grad_buf_for_mxfp8_param_ag
+
     if settings.moe_flex_dispatcher_backend is not None:
         apply_flex_dispatcher_backend(cfg.model, settings.moe_flex_dispatcher_backend)
     elif hasattr(cfg.model, "moe_token_dispatcher_type"):
diff --git a/scripts/performance/utils/utils.py b/scripts/performance/utils/utils.py
index d1ea8e5465..bb1cef27c6 100644
--- a/scripts/performance/utils/utils.py
+++ b/scripts/performance/utils/utils.py
@@ -57,6 +57,14 @@ class WorkloadBaseConfig:
     recompute_num_layers: Optional[int] = None
     recompute_modules: Optional[List[str]] = None
 
+    # Fine-grained activation offloading
+    fine_grained_activation_offloading: Optional[bool] = None
+    offload_modules: Optional[List[str]] = None
+
+    # FP8 parameter gather settings (used with FSDP)
+    fp8_param_gather: Optional[bool] = None
+    reuse_grad_buf_for_mxfp8_param_ag: Optional[bool] = None
+
     # MoE configuration
     moe_flex_dispatcher_backend: Optional[str] = None
     moe_a2a_overlap: Optional[bool] = False
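
Usage note (not part of the patch): the new FSDP variants are built with replace() on top of the shared base config, so a downstream experiment can derive its own tweak the same way the BF16 variant above is derived from the common GB300 FSDP config. Below is a minimal sketch; it assumes replace here is dataclasses.replace, that the config objects are importable from scripts.performance.configs.deepseek as the __init__.py change suggests, and the override values are illustrative only, not recommendations from this patch.

    # Minimal sketch: derive a custom GB300 FSDP variant the same way the patch
    # derives the BF16 variant from the shared config. Assumes `replace` is
    # dataclasses.replace and the package path matches the diff above.
    from dataclasses import replace

    from scripts.performance.configs.deepseek import (
        DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP,
    )

    # Hypothetical override: offload only core attention activations while
    # keeping the rest of the GB300 FSDP recipe (64 GPUs, GBS=256, EP=64,
    # no PP) unchanged.
    MY_GB300_FP8_MX_FSDP_CONFIG = replace(
        DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_FSDP,
        offload_modules=["core_attn"],
    )

Selecting the stock variants at run time goes through the --config_variant / --list_config_variants flags mentioned in the module docstring; the exact launcher command is repository-specific and not shown here.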