diff --git a/examples/megatron_fsdp/README.md b/examples/megatron_fsdp/README.md
index eaf5eca1364..cc37911c12d 100644
--- a/examples/megatron_fsdp/README.md
+++ b/examples/megatron_fsdp/README.md
@@ -19,22 +19,32 @@ bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
 # With real data
 bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh \
   checkpoints/llama3_8b_fsdp_fp8 \
-  tensorboard_logs/llama3_8b_fsdp_fp8 \
+  /path/to/data_prefix \
   /path/to/tokenizer \
-  /path/to/data_prefix
+  nsys_profiles/llama3_8b_fsdp_fp8 \
+  tensorboard_logs/llama3_8b_fsdp_fp8
+
+# With Nsight Systems profiling (steps 8–12 on rank 0)
+NSYS_PROFILE=1 bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
+
+# Without uv (use the ambient `python`)
+USE_UV=0 bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
 ```
 
 | Positional Argument | Default | Description |
 |---------------------|---------|-------------|
-| `$1` — Checkpoint path | `checkpoints/llama3_8b_fsdp_fp8` | Directory for saving and loading checkpoints. |
-| `$2` — TensorBoard path | `tensorboard_logs/llama3_8b_fsdp_fp8` | Directory for TensorBoard logs. |
+| `$1` — Checkpoint Path | `checkpoints/llama3_8b_fsdp_fp8` | Directory for saving and loading checkpoints. |
+| `$2` — Data Path | `MOCK` | Data prefix for training data, or `MOCK` for mock data. |
 | `$3` — Tokenizer | `MOCK` | Path to a tokenizer model, or `MOCK` for `NullTokenizer`. |
-| `$4` — Data path | `MOCK` | Data prefix for training data, or `MOCK` for mock data. |
+| `$4` — Nsight Profiling Path | `nsys_profiles/llama3_8b_fsdp_fp8` | Output path (without extension) for the `.nsys-rep` file when `NSYS_PROFILE=1`. |
+| `$5` — TensorBoard Path | `tensorboard_logs/llama3_8b_fsdp_fp8` | Directory for TensorBoard logs. |
 
 #### Environment Variables
 
 | Variable | Default | Description |
 |----------|---------|-------------|
+| `USE_UV` | `1` | Set to `1` to launch via `uv run` (project venv). Set to `0` to use the ambient `python`. |
+| `NSYS_PROFILE` | `0` | Set to `1` to wrap the launch in `nsys profile`. Captures steps 8–12 on rank 0 via `--capture-range=cudaProfilerApi`, with CUDA graph node tracing and CUDA memory usage enabled. Output goes to the path in `$4`. |
 | `USE_MEGATRON_FSDP` | `1` | Set to `1` to enable Megatron-FSDP. Set to `0` to train with standard DDP. |
 | `SHARDING_STRATEGY` | `optim_grads_params` | FSDP sharding strategy (ZeRO-3). Options: `no_shard`, `optim`, `optim_grads`, `optim_grads_params`. |
 | `OUTER_SHARDING_STRATEGY` | `no_shard` | DP-Outer sharding strategy for HSDP/HFSDP. Options: `no_shard`, `optim`. |
@@ -49,6 +59,7 @@ bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh \
 - **Precision**: FP8 (hybrid format) with BF16 training and BF16 gradient reduction
 - **Batch size**: micro-batch=1, global-batch=128, sequence length=8192
 - **Optimizations**: NCCL user buffers, FSDP double buffering, manual registration, meta-device initialization, per-token loss, overlapped grad-reduce and param-gather
+- **Launch**: `[uv run] [nsys profile ...] python -m torch.distributed.run ... pretrain_gpt.py ...` — the `uv run` and `nsys profile` prefixes are toggled by `USE_UV` and `NSYS_PROFILE` respectively.
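For reference, the `USE_UV` / `NSYS_PROFILE` toggles compose into the launch command as follows. This is a sketch, not script output: `...` stands in for the distributed launcher flags and the full `nsys` flag list, both of which are spelled out in the script diff below.

```bash
# USE_UV=1, NSYS_PROFILE=0 (defaults): run inside the project venv, no profiler.
uv run python -m torch.distributed.run ... pretrain_gpt.py ...

# USE_UV=1, NSYS_PROFILE=1: nsys wraps the launcher. With --capture-range=cudaProfilerApi,
# recording only spans the window opened by --profile-step-start 8 / --profile-step-end 12
# on rank 0, and the report lands at the path in $4 (with .nsys-rep appended).
uv run nsys profile --capture-range=cudaProfilerApi ... -o nsys_profiles/llama3_8b_fsdp_fp8 \
    python -m torch.distributed.run ... pretrain_gpt.py ...

# USE_UV=0, NSYS_PROFILE=0: fall back to the ambient python.
python -m torch.distributed.run ... pretrain_gpt.py ...
```

The resulting report can be opened in the Nsight Systems GUI or summarized on the command line, e.g. `nsys stats nsys_profiles/llama3_8b_fsdp_fp8.nsys-rep`.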
diff --git a/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh b/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
old mode 100644
new mode 100755
index ddd3f160fa7..6e0a2fecd08
--- a/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
+++ b/examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
@@ -1,12 +1,14 @@
 #!/bin/bash
 
 CHECKPOINT_PATH=${1:-"checkpoints/llama3_8b_fsdp_fp8"}
-TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama3_8b_fsdp_fp8"}
-TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK"
-DATA_ARG=${4:-"MOCK"} # Data prefix, or "MOCK"
+DATA_ARG=${2:-"MOCK"} # Data prefix, or "MOCK"
+TOKENIZER_ARG=${3:-"MOCK"} # Path to tokenizer model, or "MOCK"
+NSYS_PROFILE_PATH=${4:-"nsys_profiles/llama3_8b_fsdp_fp8"}
+TENSORBOARD_LOGS_PATH=${5:-"tensorboard_logs/llama3_8b_fsdp_fp8"}
 
 # Create directories if they don't exist
 mkdir -p "$(dirname "$CHECKPOINT_PATH")"
+mkdir -p "$(dirname "$NSYS_PROFILE_PATH")"
 mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")"
 
 # Distributed training setup
@@ -21,6 +23,18 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
 # is run from the root of the Megatron-LM repository.
 PRETRAIN_SCRIPT_PATH="pretrain_gpt.py"
 
+# Nsight Systems profiling
+NSYS_PROFILE=${NSYS_PROFILE:-0}
+
+# Optional `uv run` venv prefix. With uv, nsys (and its child workers) all
+# inherit the project venv. Without uv, fall back to the ambient `python`.
+USE_UV=${USE_UV:-1}
+if [ "${USE_UV}" = 1 ]; then
+    VENV_PREFIX="uv run"
+else
+    VENV_PREFIX=""
+fi
+
 # Model & Training Parameters
 USE_MEGATRON_FSDP=${USE_MEGATRON_FSDP:-1}
 SHARDING_STRATEGY=${SHARDING_STRATEGY:-"optim_grads_params"}
@@ -89,6 +103,7 @@ TRAINING_ARGS=(
     --adam-beta2 0.95
     --bf16
     --cross-entropy-loss-fusion
+    --no-check-for-nan-in-loss-and-grad
     --manual-gc
     --empty-unused-memory-level 1
     --exit-duration-in-mins 235
@@ -104,7 +119,7 @@ if [ "${USE_MEGATRON_FSDP}" = 1 ]; then
     --calculate-per-token-loss
     --init-model-with-meta-device
     --ckpt-format fsdp_dtensor
-    --grad-reduce-in-bf16
+    --grad-reduce-in-bf16 # Will be deprecated soon!
     --use-nccl-ub
     --fsdp-double-buffer
     --fsdp-manual-registration
@@ -116,6 +131,11 @@ if [ "${USE_MEGATRON_FSDP}" = 1 ]; then
     # --megatron-fsdp-main-params-dtype fp32
     # --megatron-fsdp-main-grads-dtype auto
     # --megatron-fsdp-grad-comm-dtype auto
+    # To use decoupled (mixed-precision) gradients...
+    # --use-precision-aware-optimizer
+    # To use full-iteration CUDA graphs with Megatron-FSDP...
+    # --cuda-graph-impl local
+    # --cuda-graph-scope full_iteration
   )
 fi
@@ -183,15 +203,32 @@ EVAL_AND_LOGGING_ARGS=(
     --eval-interval 100
     --save-interval 1000
     --log-throughput
-    --profile
-    --profile-step-start 4
-    --profile-step-end 6
     --distributed-timeout-minutes 60
     --save "$CHECKPOINT_PATH"
     --load "$CHECKPOINT_PATH"
     --tensorboard-dir "$TENSORBOARD_LOGS_PATH"
 )
 
+# Profiling (NSYS_PROFILE=1 bash ...)
+if [ "${NSYS_PROFILE}" = 1 ]; then
+    TRAINING_ARGS+=(
+        --profile
+        --profile-step-start 8
+        --profile-step-end 12
+        --profile-ranks 0
+    )
+    PROFILE_CMD=(
+        nsys profile
+        --sample=none --cpuctxsw=none
+        --trace=cuda,nvtx,cublas,cudnn
+        --capture-range=cudaProfilerApi --capture-range-end=stop
+        --cuda-graph-trace=node --cuda-memory-usage=true
+        -f true -x true -o "$NSYS_PROFILE_PATH"
+    )
+else
+    PROFILE_CMD=()
+fi
+
 # Ensure pretrain_gpt.py is found
 if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then
     echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH"
@@ -199,8 +236,8 @@ if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then
     exit 1
 fi
 
-# Run the training command
-torchrun ${DISTRIBUTED_ARGS[@]} \
+# Run the training command.
+$VENV_PREFIX "${PROFILE_CMD[@]}" python -m torch.distributed.run ${DISTRIBUTED_ARGS[@]} \
     "$PRETRAIN_SCRIPT_PATH" \
     ${MODEL_ARGS[@]} \
     ${TRAINING_ARGS[@]} \
diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py
index 50878e149de..d79fd223fbc 100644
--- a/megatron/core/distributed/distributed_data_parallel_config.py
+++ b/megatron/core/distributed/distributed_data_parallel_config.py
@@ -206,6 +206,17 @@ class DistributedDataParallelConfig:
     main gradients to parameter dtype for `.grad`.
     """
 
+    megatron_fsdp_cuda_graph_mode: bool = False
+    """If set to True, Megatron-FSDP will use CUDA graph-safe operations, such as
+    not dereferencing `param.grad` after the optimizer step, to preserve references
+    for CUDA graph replay. This can increase memory utilization when the gradient
+    shard is not a view of the Megatron-FSDP sharded gradient buffer, so
+    FusedAdam(use_decoupled_grad=True) + megatron_fsdp_use_decoupled_grad=True or
+    setting megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype is
+    recommended to avoid casting the gradient to the parameter precision and creating
+    a cast copy of the gradient shard that cannot be dereferenced between replays.
+    """
+
     def __post_init__(self):
         import os
diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py
index 9c31b280875..e35db5aac5e 100644
--- a/megatron/core/distributed/finalize_model_grads.py
+++ b/megatron/core/distributed/finalize_model_grads.py
@@ -536,7 +536,10 @@ def finalize_model_grads(
         # all-reduce across DP ranks.
         torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
+
+        # Clamp to avoid div-by-zero without a host-side branch on a device tensor,
+        # which would otherwise cause a sync that is illegal during CUDA graph capture.
+        safe_num_tokens = torch.clamp(num_tokens, min=1)
+        scaling = 1.0 / safe_num_tokens
         for model_chunk in model:
-            if num_tokens > 0:
-                scaling = 1.0 / num_tokens
-                model_chunk.scale_gradients(scaling)
+            model_chunk.scale_gradients(scaling)
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py
index a2feb99cb23..d866a47e849 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py
@@ -151,6 +151,17 @@ class DistributedDataParallelConfig:
     main gradients to parameter dtype for `.grad`.
     """
 
+    megatron_fsdp_cuda_graph_mode: bool = False
+    """If set to True, Megatron-FSDP will use CUDA graph-safe operations, such as
+    not dereferencing `param.grad` after the optimizer step, to preserve references
+    for CUDA graph replay. This can increase memory utilization when the gradient
+    shard is not a view of the Megatron-FSDP sharded gradient buffer, so
+    FusedAdam(use_decoupled_grad=True) + megatron_fsdp_use_decoupled_grad=True or
+    setting megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype is
+    recommended to avoid casting the gradient to the parameter precision and creating
+    a cast copy of the gradient shard that cannot be dereferenced between replays.
+ """ + def __post_init__(self): import os diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py index a1f7fabd50a..92e7ef733c0 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py @@ -104,6 +104,7 @@ def fully_shard_model( disable_symmetric_registration: bool = False, enable_fine_grained_param_gather: bool = False, use_decoupled_grad: bool = False, + cuda_graph_mode: bool = False, ) -> torch.nn.Module: """ Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP @@ -265,6 +266,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient If true, reduced gradients are installed into `Parameter.decoupled_grad` instead of `Parameter.grad`. Defaults to False. + cuda_graph_mode (bool): + If true, Megatron-FSDP will practice CUDA graph-safe operations, such as + not dereferencing `param.grad` after the optimizer step to preserve references + for CUDA graph replay. Can affect memory utilization in some cases, such as + when the gradient shard is not a view of the Megatron-FSDP sharded gradient + buffer, so `FusedAdam(use_decoupled_grad=True) + use_decoupled_grad=True` or + setting `megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype` + is recommended to avoid casting the gradient to the parameter precision and + creating a casted-copy of the gradient shard that cannot be dereferenced due + to replay. Defaults to False. + Returns: model (MegatronFSDP): The wrapped Megatron-FSDP model configured for FSDP. """ @@ -360,6 +372,7 @@ class that schedules the sharding lifecycle of the model parameters and gradient fsdp_db_use_persist_buf_on_alloc_fail=fsdp_db_use_persist_buf_on_alloc_fail, disable_symmetric_registration=disable_symmetric_registration, megatron_fsdp_use_decoupled_grad=use_decoupled_grad, + megatron_fsdp_cuda_graph_mode=cuda_graph_mode, ) # Create FSDPDistributedIndex. @@ -666,6 +679,7 @@ def fully_shard( disable_symmetric_registration: bool = False, enable_fine_grained_param_gather: bool = False, use_decoupled_grad: bool = False, + cuda_graph_mode: bool = False, ) -> tuple[MegatronFSDP, torch.optim.Optimizer]: """ Fully shard the model and the optimizer for Megatron-FSDP. @@ -717,6 +731,7 @@ def fully_shard( disable_symmetric_registration=disable_symmetric_registration, enable_fine_grained_param_gather=enable_fine_grained_param_gather, use_decoupled_grad=use_decoupled_grad, + cuda_graph_mode=cuda_graph_mode, ) # Extend optimizer methods to support Megatron-FSDP operations. diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 2a17315611a..ddfb020d4f6 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -596,6 +596,8 @@ def _grad_acc(param): # Sharded Gradient Buffer gbuf = group.hfsdp_helper_gbuf if group.hfsdp_helper_gbuf else group.main_grad_buffer if gbuf.is_data_distributed: + # If TransformerEngine gradient accumulation is fused, then param.get_main_grad() + # already holds the wgrad and param.grad_added_to_main_grad=True. 
             if not param.grad_added_to_main_grad:
                 # Get `main_grad` will allocate bucket, check that the currently
                 # used main_grad buffer does not exceed the scope of two FSDP Unit
@@ -612,7 +614,6 @@ def _grad_acc(param):
                     param.main_grad.copy_(to_local_if_dtensor(param.grad))
                     del param.grad
             else:
-                # Prepare for fused wgrad accumulation.
                 param.main_grad.zero_()
         # Unsharded Gradient Buffer
         else:
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
index 960acc25ef6..8d5ddd2d26f 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
@@ -2777,14 +2777,31 @@ def zero_grad(self):
         """
         Zero out the underlying grad_buffer and reset all buckets in preparation for the
         next iteration of training.
+
+        Gradient shards are dereferenced to free memory. However, dereferencing is
+        not compatible with (FWD-BWD / full-iteration) CUDA graph-ability, because
+        we need to preserve this reference to the sharded gradient generated during
+        CUDA graph replay (the `setattr` in `update_main_grads` is not executed during
+        CUDA graph replay, as it is not a CUDA kernel).
+
+        If the gradient is decoupled (precision-aware) or is equivalent to the
+        distributed optimizer parameter precision, the gradient shard is a view of
+        the Megatron-FSDP sharded gradient buffer. If not, then not dereferencing
+        this gradient shard will increase memory utilization, as this gradient is a
+        persistent cast copy of the accumulated gradient.
         """
         for name, param in self.optimizer_named_parameters:
-            param.grad = None
-            if hasattr(param, "decoupled_grad"):
-                param.decoupled_grad = None
-            if name in self.dist_main_grad:
-                self.dist_main_grad[name]._local_tensor = None
-
+            if not self.ddp_config.megatron_fsdp_cuda_graph_mode:
+                # Dereference the sharded gradient to reclaim memory
+                # unless a full-iteration CUDA graph is utilized.
+                param.grad = None
+                if hasattr(param, "decoupled_grad"):
+                    param.decoupled_grad = None
+                if name in self.dist_main_grad:
+                    self.dist_main_grad[name]._local_tensor = None
+
+        # Zero the Megatron-FSDP sharded gradient buffer. If param.grad or param.decoupled_grad
+        # is a view of this buffer, they will be zeroed as well.
         for group in self.parameter_groups:
             if group.main_grad_buffer:
                 group.main_grad_buffer.data.zero_()
@@ -2923,11 +2940,6 @@ def update_main_grads(self):
         from the main gradient buffer. If the model parameters are sharded, we only
         need to update the gradient shard associated with the model parameter shard,
         as both are sharded symmetrically.
-
-        Checks if high-precision main weights are utilized for optimization.
-        Otherwise, falls back to low-precision model weights, and further
-        falls back to the original module parameters not managed by cFSDP
-        in the case of no sharding / cFSDP OFF.
         """
         for name, param in self.optimizer_named_parameters:
             orig_param = param.orig_param
@@ -2948,11 +2960,11 @@ def update_main_grads(self):
             optimizer_grad = group.main_grad_buffer.get_item(
                 item_id, only_shard=sharded_optimizer_state
             )
-            if group.main_weight_buffer is not None:
-                if not self.use_decoupled_grad:
-                    # Convert the gradient to the main weight buffer dtype.
-                    # TODO(@cspades): Why this is necessary? Casted below.
-                    optimizer_grad = optimizer_grad.to(param.dtype)
+            if group.main_weight_buffer is not None and not self.use_decoupled_grad:
+                # Convert the gradient to the main weight data-type for optimization.
+                # Not needed for decoupled gradients, because the precision-aware
+                # optimizer can apply gradients to parameters of different precision!
+                optimizer_grad = optimizer_grad.to(param.dtype)
 
             if name not in self.dist_main_grad:
                 # Register the gradient as a distributed tensor.
@@ -2981,7 +2993,7 @@ def update_main_grads(self):
                 setattr(param, "decoupled_grad", grad)
             else:
                 # Attach the gradient to the optimizer parameter.
-                setattr(param, "grad", grad.to(param.dtype) if grad is not None else None)
+                setattr(param, "grad", grad)
 
     @property
     def num_buckets(self):
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index be3894999b4..9414b60dba6 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1047,11 +1047,15 @@ def validate_args(args, defaults={}):
         assert args.ckpt_format == "fsdp_dtensor", \
             "Megatron-FSDP requires the `fsdp_dtensor` checkpointing format."
 
-        if args.nccl_ub and args.use_megatron_fsdp:
-            # In Megatron-LM, required implementation for manual registration is already provided.
-            # So we enable the manual registration by default when nccl-ub and use_megatron_fsdp is set.
-            args.fsdp_manual_registration = True
-            warn_rank_0('FSDP manual registration is enabled by default when nccl-ub is enabled')
+        if args.nccl_ub:
+            # In Megatron-LM, the required implementation for manual registration is already
+            # provided, so we enable manual registration by default when --nccl-ub is set
+            # with Megatron-FSDP.
+            args.fsdp_manual_registration = True
+            warn_rank_0('FSDP manual registration is enabled by default when --nccl-ub is enabled!')
+
+        assert args.cuda_graph_impl != "transformer_engine", (
+            "Megatron-FSDP doesn't support TE partial CUDA graphs, use cuda_graph_impl=local instead!"
+        )
 
     if args.fsdp_manual_registration:
         assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP."
@@ -3269,11 +3273,11 @@ def _add_experimental_args(parser):
     group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
                        help="Data type for the main weight buffer utilized for distributed optimization "
                             "and quantization with Megatron-FSDP. If 'auto', then the native model parameter "
-                            "data-type will be used for the main weight data-type.")
+                            "data-type will be used for the main weight data-type. Replaces --main-params-dtype.")
     group.add_argument('--megatron-fsdp-main-grads-dtype', default='auto', choices=['fp32', 'bf16', 'fp16', 'auto'],
                        help="Data type for the main gradient buffer utilized for distributed optimization "
                             "with Megatron-FSDP. If 'auto', then the native model gradient data-type will "
-                            "be used for the main gradient / accumulation data-type.")
+                            "be used for the main gradient / accumulation data-type. Replaces --main-grads-dtype.")
     group.add_argument("--megatron-fsdp-grad-comm-dtype", default='auto', choices=['fp32', 'fp16', 'bf16', 'auto'],
                        help="When using Megatron-FSDP, this controls the data-type used when communicating "
                             "model gradients during FSDP. If 'auto', then the main gradient data-type will "
diff --git a/megatron/training/training.py b/megatron/training/training.py
index c6cab8df952..e0b004455f2 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1635,6 +1635,16 @@ def get_megatron_ddp_config(args: argparse.Namespace) -> DistributedDataParallel
         kwargs["megatron_fsdp_main_grads_dtype"] = args.megatron_fsdp_main_grads_dtype
         kwargs["megatron_fsdp_grad_comm_dtype"] = args.megatron_fsdp_grad_comm_dtype
         kwargs["megatron_fsdp_use_decoupled_grad"] = args.use_precision_aware_optimizer
+        if args.use_megatron_fsdp and args.cuda_graph_impl == "local":
+            # Run Megatron-FSDP in CUDA graph-safe mode. Avoids some graph-unsafe host-side
+            # operations (such as pointer dereferencing) that can break CUDA graph replay.
+            kwargs["megatron_fsdp_cuda_graph_mode"] = True
+            if CudaGraphScope.full_iteration in args.cuda_graph_scope:
+                # When using full-iteration CUDA graphs, Megatron-FSDP should not AG parameters
+                # during start_param_sync(), which is called during the DistOpt.step(). This
+                # causes an error when we wait() on a CUDA kernel launched in a stream beyond
+                # the scope of the full-iter / FWD-BWD CUDA graph capture.
+                kwargs["fsdp_all_gather_in_start_param_sync"] = False
 
     return DistributedDataParallelConfig(**kwargs)
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index 7abb80de14f..19f456734d2 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -567,7 +567,7 @@ def _broadcast(item):
     def _broadcast_cu_seqlens(cu_seqlens):
         dev = torch.cuda.current_device()
         n = 0 if cu_seqlens is None else int(cu_seqlens.numel())
-        n_tensor = torch.tensor(n, dtype=torch.int64, device=dev)
+        n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n)
         _broadcast(n_tensor)
 
         if n == 0:
diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
index 500045871e7..a158300e0b6 100644
--- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
+++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
@@ -13,12 +13,14 @@
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from megatron.core.distributed.fsdp.src.megatron_fsdp.mixed_precision import HAVE_TE_MXFP8TENSOR
+from megatron.core.fp8_utils import HAVE_TE
+from megatron.core.full_cuda_graph import FullCudaGraphWrapper, StaticBufferLoader
 from megatron.core.hyper_comm_grid import HyperCommGrid
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer import TransformerConfig
-from megatron.core.utils import is_torch_min_version
+from megatron.core.utils import is_te_min_version, is_torch_min_version
 from tests.unit_tests.distributed.megatron_fsdp.utils import (
     make_gpt_mock_data_iterator,
     make_moe_args_model_and_optimizer,
@@ -29,8 +31,6 @@
 
 
 # Test model for testing FSDP
-@pytest.mark.flaky
-@pytest.mark.flaky_in_dev
 class TestModel(torch.nn.Module):
     def __init__(self, input_dim, output_dim):
         super().__init__()
@@ -46,8 +46,6 @@ def forward(self, x):
 
 
 # Test model with uniform shaped weights for testing FSDP
-@pytest.mark.flaky
-@pytest.mark.flaky_in_dev
 class TestModelUniform(torch.nn.Module):
     def __init__(self, hidden_dim):
         super().__init__()
@@ -78,8 +76,6 @@ def setup_seed(seed):
     torch.backends.cudnn.benchmark = False  # Disable auto-tuner for reproducibility
 
 
-@pytest.mark.flaky
-@pytest.mark.flaky_in_dev
 class TestFullyShardedDataParallel:
     @classmethod
     def setup_class(cls):
@@ -623,8 +619,8 @@ def train_step(model, optimizer, inputs):
         Utils.destroy_model_parallel()
 
     @pytest.mark.parametrize("num_fsdp_group", [2])
-    @pytest.mark.skipIf(
-        torch.cuda.device_count() % 2 == 0, "This test requires an odd number of GPUs"
+    @pytest.mark.skipif(
+        torch.cuda.device_count() % 2 == 0, reason="This test requires an odd number of GPUs"
     )
     def test_fsdp_with_hybrid_sharding(self, num_fsdp_group):
         """Test that FSDP works correctly with hybrid sharding."""
@@ -897,6 +893,231 @@ def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs)
             ),
         )
 
+    @staticmethod
+    def _reset_full_cuda_graph_static_state():
+        """Reset class-level state on FullCudaGraphWrapper / StaticBufferLoader
+        so a test that uses the wrapper does not see leftovers from a previous
+        test in this process."""
+        FullCudaGraphWrapper.curr_iteration = {'training': 0, 'validation': 0}
+        FullCudaGraphWrapper.cuda_graph = {'training': None, 'validation': None}
+        FullCudaGraphWrapper.result = {'training': None, 'validation': None}
+        StaticBufferLoader.static_buffers = {'training': [], 'validation': []}
+
+    @staticmethod
+    def _reset_cuda_rng_tracker():
+        """Force a fresh CUDA RNG tracker on the next initialize_rng_tracker
+        call. A prior test in this process may have created a non-cudagraphable
+        tracker (states stored as Tensors); without a reset, the cuda-graph
+        path would later feed those Tensors into ``Generator.graphsafe_set_state``
+        and raise ``TypeError: expected a Generator, but got Tensor``."""
+        from megatron.core.tensor_parallel import random as _tp_random
+
+        _tp_random._CUDA_RNG_STATE_TRACKER = None
+        _tp_random._CUDA_RNG_STATE_TRACKER_INITIALIZED = False
+
+    @pytest.mark.skipif(
+        not is_torch_min_version("2.4.0"), reason="Megatron-FSDP requires PyTorch 2.4.0 or later."
+    )
+    @pytest.mark.skipif(
+        not (HAVE_TE and is_te_min_version("1.5.0")),
+        reason=(
+            "TransformerEngine FusedAdam and RNG tracker required for "
+            "full-iteration CUDA graphability with Megatron-FSDP."
+        ),
+    )
+    @pytest.mark.parametrize(
+        "extra_overrides",
+        [
+            pytest.param({}, id="fsdp"),
+            pytest.param(
+                dict(num_distributed_optimizer_instances=2, outer_dp_sharding_strategy="optim"),
+                id="hsdp_optim_outer_dp2",
+            ),
+        ],
+    )
+    def test_full_iteration_cuda_graph_e2e(self, extra_overrides):
+        """
+        End-to-end test for Megatron-FSDP + full-iteration CUDA graph.
+
+        Variants:
+          * ``fsdp``: pure FSDP (single distributed-optimizer instance).
+          * ``hsdp_optim_outer_dp2``: Hybrid FSDP with two outer DP groups,
+            outer-DP sharding strategy ``optim``.
+
+        Asserts:
+          1. ``FullCudaGraphWrapper.cuda_graph['training']`` is populated by
+             the end of training (i.e. capture happened).
+          2. Decoupled gradients are globally present before every
+             ``optimizer.step``.
+          3. Loss decreases across the run.
+ """ + import argparse + import os + from functools import partial + + from torch.optim.optimizer import register_optimizer_step_pre_hook + + import pretrain_gpt as _pretrain_gpt + from megatron.core.enums import ModelType + from megatron.core.rerun_state_machine import destroy_rerun_state_machine + from megatron.core.transformer.enums import CudaGraphScope + from megatron.training import pretrain + from megatron.training.argument_utils import pretrain_cfg_container_from_args + from megatron.training.arguments import add_megatron_arguments, validate_args + from megatron.training.global_vars import set_global_variables, unset_global_variables + + # Because we are using pretrain() to test, destroy the entire global state + # before calling pretrain() for the next test case. + TestMegatronFSDPE2E._reset_full_cuda_graph_static_state() + TestMegatronFSDPE2E._reset_cuda_rng_tracker() + mpu.destroy_model_parallel() + unset_global_variables() + destroy_rerun_state_machine() + for _v in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN"): + os.environ.pop(_v, None) + + # Minimal setup for Megatron-FSDP and CUDA graphs. + arg_overrides = dict( + num_layers=2, + hidden_size=128, + num_attention_heads=4, + seq_length=256, + max_position_embeddings=256, + micro_batch_size=1, + global_batch_size=8, + train_iters=8, + lr=1e-4, + mock_data=True, + tokenizer_type="NullTokenizer", + vocab_size=256, + bf16=True, + use_megatron_fsdp=True, + ckpt_format="fsdp_dtensor", + use_precision_aware_optimizer=True, + cuda_graph_impl="local", + cuda_graph_scope=[CudaGraphScope.full_iteration], + check_for_nan_in_loss_and_grad=False, + eval_iters=0, + eval_interval=8, + **extra_overrides, + ) + + # Test loss and gradients when using full-iter CG with FSDP. + losses: list[torch.Tensor] = [] + grads_present_steps: list[bool] = [] + + orig_forward_step = _pretrain_gpt.forward_step + + def wrapped_forward_step(*args, **kwargs): + output_tensor, loss_func_partial = orig_forward_step(*args, **kwargs) + + def wrapped_loss(*la, **lk): + ret = loss_func_partial(*la, **lk) + try: + if isinstance(ret, tuple) and len(ret) >= 1: + report = ret[-1] if isinstance(ret[-1], dict) else None + if report and "lm loss" in report: + val = report["lm loss"] + if isinstance(val, torch.Tensor): + if val.numel() >= 2: + loss_sum, num_toks = val[0], val[1] + per_token = loss_sum / num_toks.clamp(min=1) + losses.append(per_token.detach().clone()) + else: + losses.append(val.detach().clone()) + except Exception: + pass + return ret + + return output_tensor, wrapped_loss + + # Pre-step hook on every Optimizer.step — verify decoupled grads are + # visible before the optimizer reads them. With Megatron-FSDP + + # precision-aware optimizer the FusedAdam reads from + # ``param.decoupled_grad``. + def pre_step_hook(optimizer, args_, kwargs_): + local_present = any( + getattr(p, "decoupled_grad", None) is not None + and ( + getattr(p, "decoupled_grad")._local_tensor + if hasattr(getattr(p, "decoupled_grad"), "_local_tensor") + else getattr(p, "decoupled_grad") + ) + .count_nonzero() + .item() + > 0 + for group in optimizer.param_groups + for p in group["params"] + ) + gathered = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered, local_present) + grads_present_steps.append(any(gathered)) + + hook_handle = register_optimizer_step_pre_hook(pre_step_hook) + cuda_graph_was_captured = False + + try: + # Setup argument overrides for FSDP <> CG test. 
+            parser = argparse.ArgumentParser(allow_abbrev=False)
+            add_megatron_arguments(parser)
+            parser.set_defaults(**arg_overrides)
+            args = parser.parse_args([])
+            args._is_global_batch_size_explicitly_specified = args.global_batch_size is not None
+            args.rank = int(os.getenv("RANK", "0"))
+            args.world_size = int(os.getenv("WORLD_SIZE", "1"))
+            validate_args(args)
+            set_global_variables(args)
+            cfg = pretrain_cfg_container_from_args(args)
+
+            from gpt_builders import gpt_builder
+            from model_provider import model_provider
+
+            pretrain(
+                cfg,
+                _pretrain_gpt.train_valid_test_datasets_provider,
+                partial(model_provider, gpt_builder),
+                ModelType.encoder_or_decoder,
+                wrapped_forward_step,
+                get_embedding_ranks=_pretrain_gpt.get_embedding_ranks,
+            )
+            # Validate CUDA graph was captured and thus replayed.
+            cuda_graph_was_captured = FullCudaGraphWrapper.cuda_graph.get("training") is not None
+        finally:
+            hook_handle.remove()
+            TestMegatronFSDPE2E._reset_full_cuda_graph_static_state()
+            TestMegatronFSDPE2E._reset_cuda_rng_tracker()
+            mpu.destroy_model_parallel()
+            unset_global_variables()
+            destroy_rerun_state_machine()
+
+        # ---- Assertions ----
+        assert cuda_graph_was_captured, (
+            "FullCudaGraphWrapper did not capture a training CUDA graph "
+            "during pretrain(). Capture either failed silently or the "
+            "wrapper was not engaged."
+        )
+
+        assert len(grads_present_steps) > 0, (
+            "Optimizer pre-step hook never fired — pretrain() did not run "
+            "any optimizer.step calls."
+        )
+        assert all(grads_present_steps), (
+            f"Decoupled gradients were missing on at least one "
+            f"optimizer.step call. Per-step presence trace: "
+            f"{grads_present_steps}"
+        )
+
+        finite_losses = [float(l) for l in losses if torch.isfinite(l).all()]
+        assert len(finite_losses) >= 2, (
+            f"Need at least two finite loss observations to check "
+            f"convergence; got {len(finite_losses)}: {finite_losses}"
+        )
+        assert finite_losses[-1] < finite_losses[0], (
+            f"Loss did not decrease across {len(finite_losses)} steps: "
+            f"first={finite_losses[0]:.6f}, last={finite_losses[-1]:.6f}, "
+            f"trace={finite_losses}"
+        )
+
 
 def compare_losses(loss_a: float, loss_b: float, reference: str = "b"):
     """
diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py
index fb944b3ed76..fa4afafae40 100644
--- a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py
+++ b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py
@@ -14,7 +14,16 @@
 from torch.nn.functional import mse_loss
 from torch.optim import Adam
 
+try:
+    from transformer_engine.pytorch.optimizers import FusedAdam
+
+    HAVE_TE_FUSED_ADAM = True
+except ImportError:
+    HAVE_TE_FUSED_ADAM = False
+
 from megatron.core.distributed.fsdp.src.megatron_fsdp.fully_shard import (
+    MixedPrecisionPolicy,
+    fully_shard,
     fully_shard_model,
     fully_shard_optimizer,
 )
@@ -291,10 +300,6 @@ def test_fully_shard(
         don't add any new parameters unless absolutely necessary, or if some
         combinations can be flattened or simplified.
""" - from megatron.core.distributed.fsdp.src.megatron_fsdp import ( - MixedPrecisionPolicy, - fully_shard, - ) preserve_fp32_weights = common_args["preserve_fp32_weights"] init_model_with_meta_device = common_args["init_model_with_meta_device"] @@ -424,11 +429,6 @@ def test_dcp_checkpoint_save_and_load( """ from torch.distributed.tensor import DTensor - from megatron.core.distributed.fsdp.src.megatron_fsdp import ( - MixedPrecisionPolicy, - fully_shard, - ) - # Skip tests. if outer_shard_strategy == OPTIM and shard_strategy != OPTIM_GRADS_PARAMS: # TODO(@shjwudp, @cspades): Requires various modifications to support. @@ -673,10 +673,6 @@ def test_fully_shard_ez(self, shard_strategy): """ Test fully_shard(device_mesh=None). Represents the easiest entrypoint to Megatron-FSDP. """ - from megatron.core.distributed.fsdp.src.megatron_fsdp import ( - fully_shard_model, - fully_shard_optimizer, - ) # Construct toy model. toy_model, fsdp_unit_modules = build_toy_model(TRANSFORMER, False) @@ -709,6 +705,179 @@ def test_fully_shard_ez(self, shard_strategy): optimizer.step() optimizer.zero_grad() + @pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.4.0'), + reason="Megatron-FSDP requires PyTorch 2.4.0 or later.", + ) + @pytest.mark.skipif( + not HAVE_TE_FUSED_ADAM, + reason="Full-iteration CUDA graph capture requires TransformerEngine FusedAdam.", + ) + # FSDP (no outer-DP collectives) and HFSDP (outer-DP sharded). Both wrap a + # device mesh with DP-Outer=2 / DP-Shard=4 to exercise the full hierarchy. + @pytest.mark.parametrize("dp_outer_strategy", [None, OPTIM]) + def test_full_iteration_cuda_graph(self, dp_outer_strategy): + """ + End-to-end test that a full Megatron-FSDP training iteration (forward + + backward) is CUDA-graphable, and that optimizer.zero_grad / optimizer.step + between graph replays correctly applies gradients produced inside the graph. + + Exercises the conditional grad-dereferencing path in + ``ParamAndGradBuffer.zero_grad``: when ``param.grad`` is a view of an FSDP + sharded gradient buffer, ``zero_grad`` must preserve the view between + replays so that the next replay populates the same tensor — otherwise the + optimizer would see stale gradients on subsequent replays. The companion + wrapper for full-iteration capture in production training is + ``megatron.core.full_cuda_graph.FullCudaGraphWrapper``. + + Uses TransformerEngine ``FusedAdam`` rather than ``torch.optim.Adam``: + the stock Adam unconditionally sets ``param.grad = None`` in + ``zero_grad``, which dereferences the FSDP grad-buffer view that the + captured graph writes into and breaks replay. ``FusedAdam`` honors + ``set_to_none=False`` (zeros the buffer in place) and supports + ``capturable=True`` for graph-safe step math. + """ + # Construct (DP-Outer=2, DP-Inner=4) DeviceMesh. + device_mesh = build_distributed_environment((2, 4, 1, 1)) + + # Construct toy Megatron-FSDP model. + toy_model, fsdp_unit_modules = build_toy_model( + TRANSFORMER, init_model_with_meta_device=False, seed=0 + ) + mfsdp_model = fully_shard_model( + module=toy_model, + device_mesh=device_mesh, + dp_shard_dim=DP_SHARD, + # Pure FSDP or Hybrid-FSDP. 
+            dp_outer_dim=DP_OUTER if dp_outer_strategy is not None else None,
+            tp_dim=TP,
+            hybrid_fsdp_group=(
+                device_mesh[HSDP].get_group() if dp_outer_strategy is not None else None
+            ),
+            fsdp_unit_modules=fsdp_unit_modules,
+            zero_dp_strategy=OPTIM_GRADS_PARAMS,
+            outer_dp_sharding_strategy=(
+                dp_outer_strategy if dp_outer_strategy is not None else NO_SHARD
+            ),
+            sync_model_each_microbatch=True,
+            # When using CUDA graphs, gradient accumulation precision must
+            # align with main parameter precision. Alternatively, use:
+            # FusedAdam(use_decoupled_grad=True) + fully_shard_model(use_decoupled_grad=True)
+            mixed_precision_policy=MixedPrecisionPolicy(
+                main_params_dtype=torch.float32, main_grads_dtype=torch.float32
+            ),
+            # Run Megatron-FSDP in CUDA graph-safe mode.
+            cuda_graph_mode=True,
+        )
+
+        # FusedAdam is REQUIRED for full-iteration CUDA graphs!
+        toy_adam = FusedAdam(params=mfsdp_model.parameters(), lr=0.01, capturable=True)
+        optimizer = fully_shard_optimizer(optimizer=toy_adam)
+
+        # Static input/target buffers reused across capture and replay.
+        static_input = torch.randn(1, DIM_SIZE, DIM_SIZE, device="cuda")
+        static_target = torch.randn(1, DIM_SIZE, DIM_SIZE, device="cuda")
+
+        # CUDA-graphable training loop.
+        def run_step():
+            output = mfsdp_model(static_input, static_input)
+            loss = mse_loss(output, static_target)
+            loss.backward()
+            return loss
+
+        # Side-stream warmup. CUDA graph capture requires that any one-time
+        # allocations and lazy-init state are already populated, so we run
+        # a few eager steps on a non-default stream before capture.
+        warmup_stream = torch.cuda.Stream()
+        warmup_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(warmup_stream):
+            for _ in range(3):
+                # set_to_none=False keeps param.grad as a view of the FSDP
+                # sharded gradient buffer — required so that the next replay's
+                # backward writes into the same tensor the optimizer reads.
+                optimizer.zero_grad(set_to_none=False)
+                run_step()
+                optimizer.step()
+        # Synchronize all streams before capture.
+        torch.cuda.current_stream().wait_stream(warmup_stream)
+        torch.cuda.synchronize()
+
+        # Capture forward + backward into a CUDA graph. Optimizer step
+        # is not captured for this test, but FusedAdam is compatible.
+        # (Megatron-FSDP post-backward grad installation is captured.)
+        optimizer.zero_grad(set_to_none=False)
+        graph = torch.cuda.CUDAGraph()
+        torch.distributed.barrier()
+        torch.cuda.synchronize()
+        capture_stream = torch.cuda.Stream()
+        with torch.cuda.graph(graph, stream=capture_stream, capture_error_mode="thread_local"):
+            static_loss = run_step()
+        torch.cuda.synchronize()
+        torch.distributed.barrier()
+
+        def assert_grads_present(step):
+            local_grads_present = any(
+                getattr(p, grad_attr, None) is not None
+                and (
+                    getattr(p, grad_attr)._local_tensor
+                    if hasattr(getattr(p, grad_attr), "_local_tensor")
+                    else getattr(p, grad_attr)
+                )
+                .count_nonzero()
+                .item()
+                > 0
+                for p in mfsdp_model.parameters()
+                for grad_attr in ("grad", "decoupled_grad")
+            )
+            fsdp_group = mfsdp_model.dist_index.get_fsdp_group()
+            gathered = [None] * fsdp_group.size()
+            torch.distributed.all_gather_object(
+                object_list=gathered, obj=local_grads_present, group=fsdp_group
+            )
+            assert any(gathered), (
+                f"No parameter on any FSDP rank has a non-None, non-zero "
+                f"param.grad / param.decoupled_grad after replay step {step}. "
+                f"The CUDA-graph replay did not deliver gradients to the "
+                f"optimizer."
+            )
+
+        # Replay enough steps that a healthy training loop should clearly drive
+        # the loss down on this fixed (input, target) pair.
+        num_replays = 8
+        replay_losses = []
+        for step in range(num_replays):
+            optimizer.zero_grad(set_to_none=False)
+            graph.replay()
+            torch.cuda.synchronize()
+            # Post-backward, pre-step: the freshly produced gradients must be
+            # visible on the optimizer parameters.
+            assert_grads_present(step=step)
+            # Detach and clone the loss, as this buffer will be reused.
+            replay_losses.append(static_loss.detach().clone())
+            # Perform the optimizer step.
+            optimizer.step()
+
+        # All replays must produce finite losses.
+        for step, loss_value in enumerate(replay_losses):
+            assert torch.isfinite(loss_value).all(), (
+                f"Loss at replay step {step} is not finite under full-iteration "
+                f"CUDA graph: {loss_value.item()}"
+            )
+
+        # Loss must clearly decrease across replays. A broken graph-replay path
+        # (e.g. optimizer applying stale or zero grads) typically manifests as
+        # a flat or oscillating loss.
+        first_loss = replay_losses[0].item()
+        last_loss = replay_losses[-1].item()
+        assert last_loss < first_loss, (
+            f"Loss did not decrease across {num_replays} CUDA-graph replays: "
+            f"first={first_loss:.6f}, last={last_loss:.6f}, "
+            f"trace={[l.item() for l in replay_losses]}"
+        )
+
+        # Required to reset the parallelism environment.
+        destroy_device_mesh(device_mesh)
+
     @pytest.mark.parametrize("init_model_with_meta_device", [True, False])
     @pytest.mark.parametrize(
         "te_recipe",
@@ -722,12 +891,6 @@ def test_fully_shard_te_quantized(self, init_model_with_meta_device, te_recipe):
             # TODO(@cspades, @ko3n1g): Add this test case in.
             pytest.skip(f"[Megatron CI/CD] MXFP8 requires Blackwell nodes to test.")
 
-        from megatron.core.distributed.fsdp.src.megatron_fsdp import (
-            MixedPrecisionPolicy,
-            fully_shard_model,
-            fully_shard_optimizer,
-        )
-
         # Build FP8 recipe.
         te_quant_recipe = None
         if te_recipe == MXFP8_BLOCKWISE_RECIPE:
@@ -901,11 +1064,6 @@ def test_fully_shard_custom_dtype(
         """
         Test custom data-types for gather and reduce communications.
         """
-        from megatron.core.distributed.fsdp.src.megatron_fsdp import (
-            MixedPrecisionPolicy,
-            fully_shard_model,
-            fully_shard_optimizer,
-        )
 
         if dp_outer_strategy == OPTIM and dp_shard_strategy != OPTIM_GRADS_PARAMS:
             pytest.skip(