21 changes: 16 additions & 5 deletions examples/megatron_fsdp/README.md
@@ -19,22 +19,32 @@ bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
# With real data
bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh \
checkpoints/llama3_8b_fsdp_fp8 \
-    tensorboard_logs/llama3_8b_fsdp_fp8 \
+    /path/to/data_prefix \
     /path/to/tokenizer \
-    /path/to/data_prefix
+    nsys_profiles/llama3_8b_fsdp_fp8 \
+    tensorboard_logs/llama3_8b_fsdp_fp8

# With Nsight Systems profiling (steps 8–12 on rank 0)
NSYS_PROFILE=1 bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh

# Without uv (use the ambient `python`)
USE_UV=0 bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
```

| Positional Argument | Default | Description |
|---------------------|---------|-------------|
-| `$1` — Checkpoint path | `checkpoints/llama3_8b_fsdp_fp8` | Directory for saving and loading checkpoints. |
-| `$2` — TensorBoard path | `tensorboard_logs/llama3_8b_fsdp_fp8` | Directory for TensorBoard logs. |
+| `$1` — Checkpoint Path | `checkpoints/llama3_8b_fsdp_fp8` | Directory for saving and loading checkpoints. |
+| `$2` — Data Path | `MOCK` | Data prefix for training data, or `MOCK` for mock data. |
 | `$3` — Tokenizer | `MOCK` | Path to a tokenizer model, or `MOCK` for `NullTokenizer`. |
-| `$4` — Data path | `MOCK` | Data prefix for training data, or `MOCK` for mock data. |
+| `$4` — Nsight Profiling Path | `nsys_profiles/llama3_8b_fsdp_fp8` | Output path (without extension) for the `.nsys-rep` file when `NSYS_PROFILE=1`. |
+| `$5` — TensorBoard Path | `tensorboard_logs/llama3_8b_fsdp_fp8` | Directory for TensorBoard logs. |

#### Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `USE_UV` | `1` | Set to `1` to launch via `uv run` (project venv). Set to `0` to use the ambient `python`. |
| `NSYS_PROFILE` | `0` | Set to `1` to wrap the launch in `nsys profile`. Captures steps 8–12 on rank 0 via `--capture-range=cudaProfilerApi`, with CUDA graph node tracing and CUDA memory usage enabled. Output goes to the path in `$4`. |
| `USE_MEGATRON_FSDP` | `1` | Set to `1` to enable Megatron-FSDP. Set to `0` to train with standard DDP. |
| `SHARDING_STRATEGY` | `optim_grads_params` | FSDP sharding strategy (ZeRO-3). Options: `no_shard`, `optim`, `optim_grads`, `optim_grads_params`. |
| `OUTER_SHARDING_STRATEGY` | `no_shard` | DP-Outer sharding strategy for HSDP/HFSDP. Options: `no_shard`, `optim`. |
@@ -49,6 +59,7 @@ bash examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh \
- **Precision**: FP8 (hybrid format) with BF16 training and BF16 gradient reduction
- **Batch size**: micro-batch=1, global-batch=128, sequence length=8192
- **Optimizations**: NCCL user buffers, FSDP double buffering, manual registration, meta-device initialization, per-token loss, overlapped grad-reduce and param-gather
- **Launch**: `[uv run] [nsys profile ...] python -m torch.distributed.run ... pretrain_gpt.py ...` — the `uv run` and `nsys profile` prefixes are toggled by `USE_UV` and `NSYS_PROFILE` respectively.
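The launch-prefix composition described in this bullet can be sketched in a few lines of Python (a hypothetical helper mirroring the shell logic, not part of the repository):

```python
def build_launch_cmd(use_uv: bool, nsys_profile: bool,
                     profile_path: str = "nsys_profiles/llama3_8b_fsdp_fp8") -> list[str]:
    """Compose the launch prefix: [uv run] [nsys profile ...] python -m torch.distributed.run."""
    cmd = []
    if use_uv:
        # USE_UV=1: run inside the project venv managed by uv.
        cmd += ["uv", "run"]
    if nsys_profile:
        # NSYS_PROFILE=1: wrap the launcher so nsys records the capture window.
        cmd += ["nsys", "profile", "--capture-range=cudaProfilerApi",
                "--capture-range-end=stop", "-o", profile_path]
    return cmd + ["python", "-m", "torch.distributed.run"]

print(build_launch_cmd(True, False))  # → ['uv', 'run', 'python', '-m', 'torch.distributed.run']
```

Note that `nsys` sits between `uv run` and the launcher, so the profiler and its child workers all inherit the venv that `uv run` activates.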

---

55 changes: 46 additions & 9 deletions examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
100644 → 100755
@@ -1,12 +1,14 @@
#!/bin/bash

CHECKPOINT_PATH=${1:-"checkpoints/llama3_8b_fsdp_fp8"}
-TENSORBOARD_LOGS_PATH=${2:-"tensorboard_logs/llama3_8b_fsdp_fp8"}
-TOKENIZER_ARG=${3:-"MOCK"}  # Path to tokenizer model, or "MOCK"
-DATA_ARG=${4:-"MOCK"}  # Data prefix, or "MOCK"
+DATA_ARG=${2:-"MOCK"}  # Data prefix, or "MOCK"
+TOKENIZER_ARG=${3:-"MOCK"}  # Path to tokenizer model, or "MOCK"
+NSYS_PROFILE_PATH=${4:-"nsys_profiles/llama3_8b_fsdp_fp8"}
+TENSORBOARD_LOGS_PATH=${5:-"tensorboard_logs/llama3_8b_fsdp_fp8"}

# Create directories if they don't exist
mkdir -p "$(dirname "$CHECKPOINT_PATH")"
mkdir -p "$(dirname "$NSYS_PROFILE_PATH")"
mkdir -p "$(dirname "$TENSORBOARD_LOGS_PATH")"

# Distributed training setup
@@ -21,6 +23,18 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
# is run from the root of the Megatron-LM repository.
PRETRAIN_SCRIPT_PATH="pretrain_gpt.py"

# Nsight profiling
NSYS_PROFILE=${NSYS_PROFILE:-0}

# Optional `uv run` venv prefix. With uv, nsys (and its child workers) all
# inherit the project venv. Without uv, fall back to the ambient `python`.
USE_UV=${USE_UV:-1}
if [ "${USE_UV}" = 1 ]; then
VENV_PREFIX="uv run"
else
VENV_PREFIX=""
fi

# Model & Training Parameters
USE_MEGATRON_FSDP=${USE_MEGATRON_FSDP:-1}
SHARDING_STRATEGY=${SHARDING_STRATEGY:-"optim_grads_params"}
@@ -89,6 +103,7 @@ TRAINING_ARGS=(
--adam-beta2 0.95
--bf16
--cross-entropy-loss-fusion
--no-check-for-nan-in-loss-and-grad
--manual-gc
--empty-unused-memory-level 1
--exit-duration-in-mins 235
@@ -104,7 +119,7 @@ if [ "${USE_MEGATRON_FSDP}" = 1 ]; then
--calculate-per-token-loss
--init-model-with-meta-device
--ckpt-format fsdp_dtensor
-    --grad-reduce-in-bf16
+    --grad-reduce-in-bf16  # Will be deprecated soon!
--use-nccl-ub
--fsdp-double-buffer
--fsdp-manual-registration
@@ -116,6 +131,11 @@ if [ "${USE_MEGATRON_FSDP}" = 1 ]; then
# --megatron-fsdp-main-params-dtype fp32
# --megatron-fsdp-main-grads-dtype auto
# --megatron-fsdp-grad-comm-dtype auto
# To use decoupled (mixed-precision) gradients...
# --use-precision-aware-optimizer
# To use full-iteration CUDA graphs with Megatron-FSDP...
# --cuda-graph-impl local
# --cuda-graph-scope full_iteration
)
fi

@@ -183,24 +203,41 @@ EVAL_AND_LOGGING_ARGS=(
--eval-interval 100
--save-interval 1000
--log-throughput
-    --profile
-    --profile-step-start 4
-    --profile-step-end 6
--distributed-timeout-minutes 60
--save "$CHECKPOINT_PATH"
--load "$CHECKPOINT_PATH"
--tensorboard-dir "$TENSORBOARD_LOGS_PATH"
)

# Profiling (NSYS_PROFILE=1 bash ...)
if [ "${NSYS_PROFILE}" = 1 ]; then
TRAINING_ARGS+=(
--profile
--profile-step-start 8
--profile-step-end 12
--profile-ranks 0
)
PROFILE_CMD=(
nsys profile
--sample=none --cpuctxsw=none
--trace=cuda,nvtx,cublas,cudnn
--capture-range=cudaProfilerApi --capture-range-end=stop
--cuda-graph-trace=node --cuda-memory-usage=true
-f true -x true -o "$NSYS_PROFILE_PATH"
)
else
PROFILE_CMD=()
fi
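As a sketch of how `--profile-step-start 8` / `--profile-step-end 12` pair with `--capture-range=cudaProfilerApi`: the profiled rank calls the CUDA profiler start hook at the start step and the stop hook at the end step, so nsys records a half-open window (this helper is illustrative, not repository code):

```python
def captured_steps(profile_step_start: int, profile_step_end: int, num_steps: int) -> list[int]:
    """Training steps that land inside the nsys capture window on a profiled rank.
    Assumes cudaProfilerStart() fires at profile_step_start and cudaProfilerStop()
    at profile_step_end, giving the half-open range [start, end)."""
    return [step for step in range(num_steps)
            if profile_step_start <= step < profile_step_end]

print(captured_steps(8, 12, 20))  # → [8, 9, 10, 11]
```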

# Ensure pretrain_gpt.py is found
if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then
echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH"
echo "Please ensure you are running this script from the root of the Megatron-LM repository and that pretrain_gpt.py is present."
exit 1
fi

-# Run the training command
-torchrun ${DISTRIBUTED_ARGS[@]} \
+# Run the training command.
+$VENV_PREFIX "${PROFILE_CMD[@]}" python -m torch.distributed.run ${DISTRIBUTED_ARGS[@]} \
"$PRETRAIN_SCRIPT_PATH" \
${MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
11 changes: 11 additions & 0 deletions megatron/core/distributed/distributed_data_parallel_config.py
@@ -206,6 +206,17 @@ class DistributedDataParallelConfig:
main gradients to parameter dtype for `.grad`.
"""

megatron_fsdp_cuda_graph_mode: bool = False
"""If set to True, Megatron-FSDP uses CUDA graph-safe operations, such as not
dereferencing `param.grad` after the optimizer step, so that references remain
valid for CUDA graph replay. This can increase memory utilization when the
gradient shard is not a view of the Megatron-FSDP sharded gradient buffer. To
avoid casting the gradient to the parameter precision (which creates a casted
copy of the gradient shard that cannot be dereferenced due to replay), use
FusedAdam(use_decoupled_grad=True) with megatron_fsdp_use_decoupled_grad=True,
or set megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype.
"""
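A CUDA-free sketch of the reference-preservation problem this flag addresses: plain Python objects stand in for tensors, and a function stands in for graph replay (all names are hypothetical):

```python
class Param:
    """Stand-in for a torch parameter; `grad` models the attached gradient shard."""
    def __init__(self):
        self.grad = None

grad_buffer = [0.0]        # models the sharded gradient buffer
param = Param()
param.grad = grad_buffer   # reference installed once, outside the graph

def graph_replay():
    # Replay re-runs only the captured kernels, which write into the same
    # buffer; Python-level setattr calls are NOT kernels and never re-run.
    grad_buffer[0] += 1.0

graph_replay()
assert param.grad[0] == 1.0   # reference preserved: optimizer sees the new grad

param.grad = None             # dereferencing, as non-graph zero_grad would do
graph_replay()
assert param.grad is None     # nothing re-attaches the grad during replay: it is lost
```

This is why graph mode keeps `param.grad` bound across iterations instead of dereferencing it after the optimizer step.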

def __post_init__(self):
import os

9 changes: 6 additions & 3 deletions megatron/core/distributed/finalize_model_grads.py
@@ -536,7 +536,10 @@ def finalize_model_grads(

# all-reduce across DP ranks.
 torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
 
+# Clamp to avoid div-by-zero without a host-side branch on a device tensor,
+# which would otherwise cause a sync that is illegal during CUDA graph capture.
+safe_num_tokens = torch.clamp(num_tokens, min=1)
+scaling = 1.0 / safe_num_tokens
 for model_chunk in model:
-    if num_tokens > 0:
-        scaling = 1.0 / num_tokens
-        model_chunk.scale_gradients(scaling)
+    model_chunk.scale_gradients(scaling)
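The clamp trick can be modeled with plain integers, no torch required; `max(n, 1)` stands in for `torch.clamp(num_tokens, min=1)`:

```python
def graph_safe_scaling(num_tokens: int) -> float:
    """Compute the gradient scaling factor without branching on the token count.
    max(num_tokens, 1) replaces `if num_tokens > 0:`, so no device-to-host sync
    is needed. When num_tokens == 0 every gradient is already zero, so scaling
    by 1.0 / 1 is harmless."""
    return 1.0 / max(num_tokens, 1)

assert graph_safe_scaling(0) == 1.0      # no tokens: benign identity scaling
assert graph_safe_scaling(128) == 1.0 / 128
```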
@@ -151,6 +151,17 @@ class DistributedDataParallelConfig:
main gradients to parameter dtype for `.grad`.
"""

megatron_fsdp_cuda_graph_mode: bool = False
"""If set to True, Megatron-FSDP uses CUDA graph-safe operations, such as not
dereferencing `param.grad` after the optimizer step, so that references remain
valid for CUDA graph replay. This can increase memory utilization when the
gradient shard is not a view of the Megatron-FSDP sharded gradient buffer. To
avoid casting the gradient to the parameter precision (which creates a casted
copy of the gradient shard that cannot be dereferenced due to replay), use
FusedAdam(use_decoupled_grad=True) with megatron_fsdp_use_decoupled_grad=True,
or set megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype.
"""

def __post_init__(self):
import os

15 changes: 15 additions & 0 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
@@ -104,6 +104,7 @@ def fully_shard_model(
disable_symmetric_registration: bool = False,
enable_fine_grained_param_gather: bool = False,
use_decoupled_grad: bool = False,
cuda_graph_mode: bool = False,
) -> torch.nn.Module:
"""
Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP
@@ -265,6 +266,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient
If true, reduced gradients are installed into `Parameter.decoupled_grad` instead
of `Parameter.grad`. Defaults to False.

cuda_graph_mode (bool):
    If true, Megatron-FSDP uses CUDA graph-safe operations, such as not
    dereferencing `param.grad` after the optimizer step, so that references
    remain valid for CUDA graph replay. This can increase memory utilization
    when the gradient shard is not a view of the Megatron-FSDP sharded
    gradient buffer. To avoid casting the gradient to the parameter precision
    (which creates a casted copy of the gradient shard that cannot be
    dereferenced due to replay), use `FusedAdam(use_decoupled_grad=True)` with
    `use_decoupled_grad=True`, or set
    `megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype`.
    Defaults to False.

Returns:
model (MegatronFSDP): The wrapped Megatron-FSDP model configured for FSDP.
"""
@@ -360,6 +372,7 @@ class that schedules the sharding lifecycle of the model parameters and gradient
fsdp_db_use_persist_buf_on_alloc_fail=fsdp_db_use_persist_buf_on_alloc_fail,
disable_symmetric_registration=disable_symmetric_registration,
megatron_fsdp_use_decoupled_grad=use_decoupled_grad,
megatron_fsdp_cuda_graph_mode=cuda_graph_mode,
)

# Create FSDPDistributedIndex.
@@ -666,6 +679,7 @@ def fully_shard(
disable_symmetric_registration: bool = False,
enable_fine_grained_param_gather: bool = False,
use_decoupled_grad: bool = False,
cuda_graph_mode: bool = False,
) -> tuple[MegatronFSDP, torch.optim.Optimizer]:
"""
Fully shard the model and the optimizer for Megatron-FSDP.
@@ -717,6 +731,7 @@ disable_symmetric_registration=disable_symmetric_registration,
disable_symmetric_registration=disable_symmetric_registration,
enable_fine_grained_param_gather=enable_fine_grained_param_gather,
use_decoupled_grad=use_decoupled_grad,
cuda_graph_mode=cuda_graph_mode,
)

# Extend optimizer methods to support Megatron-FSDP operations.
@@ -596,6 +596,8 @@ def _grad_acc(param):
# Sharded Gradient Buffer
gbuf = group.hfsdp_helper_gbuf if group.hfsdp_helper_gbuf else group.main_grad_buffer
if gbuf.is_data_distributed:
# If TransformerEngine gradient accumulation is fused, then param.get_main_grad()
# already holds the wgrad and param.grad_added_to_main_grad=True.
if not param.grad_added_to_main_grad:
# Get `main_grad` will allocate bucket, check that the currently
# used main_grad buffer does not exceed the scope of two FSDP Unit
@@ -612,7 +614,6 @@
param.main_grad.copy_(to_local_if_dtensor(param.grad))
del param.grad
else:
-    # Prepare for fused wgrad accumulation.
param.main_grad.zero_()
# Unsharded Gradient Buffer
else:
@@ -2777,14 +2777,31 @@ def zero_grad(self):
"""
Zero out the underlying grad_buffer and reset all buckets in preparation
for the next iteration of training.

Gradient shards are dereferenced to free memory. However, dereferencing is
not compatible with (FWD-BWD / full-iteration) CUDA graphs, because the
reference to the sharded gradient must be preserved across CUDA graph
replay: the `setattr` in `update_main_grads` is not a CUDA kernel, so it
is not re-executed during replay.

If the gradient is decoupled (precision-aware) or is equivalent to the
distributed optimizer parameter precision, the gradient shard is a view of
the Megatron-FSDP sharded gradient buffer. If not, then not dereferencing
this gradient shard will increase memory utilization as this gradient is a
persistent casted-copy of the accumulated gradient.
"""
for name, param in self.optimizer_named_parameters:
-    param.grad = None
-    if hasattr(param, "decoupled_grad"):
-        param.decoupled_grad = None
-    if name in self.dist_main_grad:
-        self.dist_main_grad[name]._local_tensor = None
+    if not self.ddp_config.megatron_fsdp_cuda_graph_mode:
+        # Dereference the sharded gradient to reclaim memory
+        # unless a full-iteration CUDA graph is utilized.
+        param.grad = None
+        if hasattr(param, "decoupled_grad"):
+            param.decoupled_grad = None
+        if name in self.dist_main_grad:
+            self.dist_main_grad[name]._local_tensor = None

# Zero the Megatron-FSDP sharded gradient buffer. If param.grad or param.decoupled_grad
# is a view of this buffer, they will be zero'd as well.
for group in self.parameter_groups:
if group.main_grad_buffer:
group.main_grad_buffer.data.zero_()
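Why zeroing the buffer suffices when `param.grad` is a view of it — a minimal sketch with `memoryview` standing in for a tensor view (no torch required):

```python
grad_buffer = bytearray(b"\x07\x07\x07\x07")  # models the sharded gradient buffer
grad_shard = memoryview(grad_buffer)[1:3]     # models param.grad as a view of the buffer

# Models main_grad_buffer.data.zero_(): writing through the buffer...
for i in range(len(grad_buffer)):
    grad_buffer[i] = 0

# ...is visible through every view, so the attached grad is zeroed too.
assert bytes(grad_shard) == b"\x00\x00"
```

This is exactly why graph mode can skip dereferencing: the preserved `param.grad` reference is safe because it aliases storage the buffer zeroing already clears.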
@@ -2923,11 +2940,6 @@ def update_main_grads(self):
from the main gradient buffer. If the model parameters are sharded,
we only need to update the gradient shard associated with the model
parameter shard, as both are sharded symmetrically.

-        Checks if high-precision main weights are utilized for optimization.
-        Otherwise, falls back to low-precision model weights, and further
-        falls back to the original module parameters not managed by cFSDP
-        in the case of no sharding / cFSDP OFF.
"""
for name, param in self.optimizer_named_parameters:
orig_param = param.orig_param
@@ -2948,11 +2960,11 @@
optimizer_grad = group.main_grad_buffer.get_item(
item_id, only_shard=sharded_optimizer_state
)
-if group.main_weight_buffer is not None:
-    if not self.use_decoupled_grad:
-        # Convert the gradient to the main weight buffer dtype.
-        # TODO(@cspades): Why this is necessary? Casted below.
-        optimizer_grad = optimizer_grad.to(param.dtype)
+if group.main_weight_buffer is not None and not self.use_decoupled_grad:
+    # Convert the gradient to the main weight data-type for optimization.
+    # Not needed for decoupled gradients, because the precision-aware
+    # optimizer can apply gradients to parameters of different precision!
+    optimizer_grad = optimizer_grad.to(param.dtype)

if name not in self.dist_main_grad:
# Register the gradient as a distributed tensor.
@@ -2981,7 +2993,7 @@
setattr(param, "decoupled_grad", grad)
else:
# Attach the gradient to the optimizer parameter.
-setattr(param, "grad", grad.to(param.dtype) if grad is not None else None)
+setattr(param, "grad", grad)
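The decoupled-gradient contrast above can be sketched without torch; here dtype strings are stand-ins for tensors, and every name is hypothetical:

```python
class Param:
    def __init__(self, dtype: str):
        self.dtype = dtype
        self.grad = None            # a fused step requires grad dtype == param dtype
        self.decoupled_grad = None  # may keep the accumulation dtype

def attach_grad(param: Param, grad_dtype: str, use_decoupled_grad: bool) -> str:
    """Return the dtype the attached gradient ends up in."""
    if use_decoupled_grad:
        # No cast, no extra copy: the precision-aware optimizer consumes
        # a gradient whose dtype differs from the parameter's.
        param.decoupled_grad = grad_dtype
        return param.decoupled_grad
    # Coupled path: cast the gradient to the parameter dtype, which implies
    # a casted copy whenever the dtypes differ.
    param.grad = param.dtype
    return param.grad

assert attach_grad(Param("fp32"), "bf16", use_decoupled_grad=True) == "bf16"
assert attach_grad(Param("fp32"), "bf16", use_decoupled_grad=False) == "fp32"
```

The casted copy on the coupled path is the object that cannot be dereferenced under CUDA graph replay, which is why the decoupled path pairs naturally with graph mode.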

@property
def num_buckets(self):
18 changes: 11 additions & 7 deletions megatron/training/arguments.py
@@ -1047,11 +1047,15 @@ def validate_args(args, defaults={}):
assert args.ckpt_format == "fsdp_dtensor", \
"Megatron-FSDP requires the `fsdp_dtensor` checkpointing format."

-    if args.nccl_ub and args.use_megatron_fsdp:
-        # In Megatron-LM, required implementation for manual registration is already provided.
-        # So we enable the manual registration by default when nccl-ub and use_megatron_fsdp is set.
-        args.fsdp_manual_registration = True
-        warn_rank_0('FSDP manual registration is enabled by default when nccl-ub is enabled')
+        if args.nccl_ub:
+            # In Megatron-LM, the required implementation for manual registration is already provided,
+            # so manual registration is enabled by default when nccl-ub and use_megatron_fsdp are set.
+            args.fsdp_manual_registration = True
+            warn_rank_0('FSDP manual registration is enabled by default when --nccl-ub is enabled!')
+
+        assert args.cuda_graph_impl != "transformer_engine", (
+            "Megatron-FSDP doesn't support TE partial CUDA graphs, use cuda_graph_impl=local instead!"
+        )

if args.fsdp_manual_registration:
assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP."
@@ -3269,11 +3273,11 @@ def _add_experimental_args(parser):
group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
help="Data type for the main weight buffer utilized for distributed optimization "
"and quantization with Megatron-FSDP. If 'auto', then the native model parameter "
-                        "data-type will be used for the main weight data-type.")
+                        "data-type will be used for the main weight data-type. Replaces --main-params-dtype.")
group.add_argument('--megatron-fsdp-main-grads-dtype', default='auto', choices=['fp32', 'bf16', 'fp16', 'auto'],
help="Data type for the main gradient buffer utilized for distributed optimization "
"with Megatron-FSDP. If 'auto', then the native model gradient data-type will "
-                        "be used for the main gradient / accumulation data-type.")
+                        "be used for the main gradient / accumulation data-type. Replaces --main-grads-dtype.")
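The `'auto'` fallback these help strings describe can be sketched as a tiny hypothetical helper:

```python
def resolve_dtype(flag_value: str, model_dtype: str) -> str:
    """'auto' falls back to the native model dtype, as the --megatron-fsdp-*-dtype
    help text describes; any explicit choice wins over the fallback."""
    return model_dtype if flag_value == "auto" else flag_value

assert resolve_dtype("auto", "bf16") == "bf16"  # fallback to the model dtype
assert resolve_dtype("fp32", "bf16") == "fp32"  # explicit choice wins
```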
group.add_argument("--megatron-fsdp-grad-comm-dtype", default='auto', choices=['fp32', 'fp16', 'bf16', 'auto'],
help="When using Megatron-FSDP, this controls the data-type used when communicating "
"model gradients during FSDP. If 'auto', then the main gradient data-type will "