
[Megatron-FSDP] Add conditional param.grad dereferencing logic to support CUDA graphability.#4663

Open
cspades wants to merge 4 commits into NVIDIA:main from cspades:cye/mfsdp-cuda-graph

Conversation

@cspades
Member

@cspades cspades commented May 6, 2026

What does this PR do?

  • Fixes the only code blocking Megatron-FSDP from working with CUDA graph replay.
    • Some modifications (such as n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n)) were downstreamed from [Dev] Paged Stashing #2690
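    This is my reading of why that allocation pattern matters for graph capture (a minimal, hedged sketch with a made-up scalar, not code from #2690): building the tensor from a host integer with torch.tensor(n, device=dev) issues a host-to-device copy from pageable memory, which can synchronize the stream and invalidate capture, whereas filling a preallocated device tensor is a single capturable kernel.

    import torch

    dev = torch.device("cuda")
    n = 8  # arbitrary host-side integer for illustration

    # Capturable: one kernel launch with the scalar passed as a kernel argument.
    n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n)

    # Risky inside capture: constructs on the host and copies from pageable
    # memory, which may synchronize and abort the ongoing stream capture.
    # n_tensor = torch.tensor(n, dtype=torch.int64, device=dev)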

Usage

  • Refer to examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh. An assertion guard forces the user to turn on --use-precision-aware-optimizer or to match main parameter / gradient precision when using a full-iteration CUDA graph.

Bug

  • Megatron-FSDP was not compatible with CUDA graph replay, because we dereference one of the pointers required for the optimizer step after graph capture. We need to preserve the param.(decoupled_)grad pointer to access the output gradient shards produced by the CUDA graph replay.
    • finalize_model_grads() -> update_main_grads() -> setattr(param, "[decoupled_]grad", grad) installs the sharded gradient.
    • zero_grad() dereferences param.grad and param.decoupled_grad.
    • During replay, the param.grad pointer is not re-attached to the sharded gradient, and optimizer.step() fails.
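    To illustrate with a minimal sketch (stand-in tensors, not the actual Megatron-FSDP code paths): as long as param.grad stays bound to the sharded buffer, whatever a graph replay writes into that buffer is visible to the optimizer step through the same pointer; once something sets param.grad = None, only eager code could re-attach it, and a fully captured iteration has no eager code to do so.

    import torch

    param = torch.nn.Parameter(torch.zeros(4, device="cuda"))
    grad_buffer = torch.zeros(8, device="cuda")  # stand-in for the sharded grad accumulation buffer

    # update_main_grads() installs the sharded gradient on the parameter.
    param.grad = grad_buffer[:4]

    # A graph replay rewrites grad_buffer in place; the optimizer step then sees
    # the fresh values through the same param.grad pointer.
    grad_buffer.fill_(1.0)
    assert param.grad.data_ptr() == grad_buffer.data_ptr()

    # If zero_grad() did `param.grad = None`, the link would be severed and the
    # captured optimizer.step() would have no gradient to read on the next replay.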

Testing

  • @rapatel has proof of convergence and improved performance on MLPerf benchmarks.
  • Added a Torch-native full-iteration CUDA graph unit test for fully_shard, and an E2E pretrain test for Megatron-FSDP in MLM. (Couldn't get the CUDA graph working with lower-level entrypoints; there is too much conditional and training-loop logic to port into the unit test.)

Details

  • FusedAdam(capturable=True) is required for CUDA graphability, because Adam.zero_grad() and most Torch-native optimizers dereference the gradient, i.e. param.grad = None, while FusedAdam.zero_grad(set_to_none=False) preserves the gradient buffer.
  • Additionally, Megatron-FSDP will attempt to deallocate / dereference param.grad or param.decoupled_grad if the gradient is a copy of the accumulated gradient in Megatron-FSDP's sharded gradient accumulation buffer, since holding two allocations of the sharded gradient would waste memory.
    • This PR adjusts the logic to only dereference when the gradient is not a view of the buffer; otherwise the gradient view does not need to be deallocated and CUDA graph-ability is preserved (see the sketch after this list).
  • Recommended CUDA-graphable paths where Megatron-FSDP does not need to dereference the gradient are:
    • --use-precision-aware-optimizer / megatron_fsdp_use_decoupled_grad=True - param.decoupled_grad is always a view of Megatron-FSDP's sharded main gradient accumulation buffer, which is persistent.
    • megatron_fsdp_main_params_dtype == megatron_fsdp_main_grads_dtype - param.grad does not require .to(param.dtype) casting that allocates unwanted memory overhead.
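    A minimal sketch of the conditional dereference described above (the helper name is hypothetical; the real Megatron-FSDP logic also handles decoupled_grad and DTensor local tensors):

    import torch

    def maybe_release_grad(param: torch.nn.Parameter) -> None:
        # Drop param.grad only when it is a standalone copy, never when it is a
        # view of the persistent sharded gradient buffer.
        grad = param.grad
        if grad is None:
            return
        # A view reports its owning tensor via ._base; a casted copy produced by
        # grad.to(param.dtype) owns its own storage and has no base.
        if grad._base is None:
            # Standalone copy: free it so we never hold two allocations of the
            # sharded gradient. This path is not compatible with a full-iteration
            # CUDA graph, because the pointer must be re-installed eagerly.
            param.grad = None
        # Otherwise keep the view alive so the captured optimizer step can reuse it.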

View-Checking the Gradient Shard

  • When we are not using the precision-aware optimizer and the main weight precision doesn't match the main gradient precision, we end up with a casted copy of the gradient shard. In this situation, we should deallocate / dereference param.grad, and this combination of arguments is not compatible with the full CUDA graph:
(Pdb) optimizer_grad
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', dtype=torch.bfloat16)
(Pdb) param.dtype
torch.float32
(Pdb) id(optimizer_grad)
140718698403376
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2982)update_main_grads()
-> if group.main_weight_buffer is not None and not self.use_decoupled_grad:
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2985)update_main_grads()
-> optimizer_grad = optimizer_grad.to(param.dtype)
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2987)update_main_grads()
-> if name not in self.dist_main_grad:

# CASTED GRADIENT / DIFFERENT ADDRESS
(Pdb) id(optimizer_grad)
140718698869968
# NOT A VIEW
(Pdb) optimizer_grad._base
  • Otherwise, if we are using the precision-aware optimizer, or set --megatron-fsdp-main-params-dtype == --megatron-fsdp-main-grads-dtype, then we have a view of the sharded gradient buffer, do not need to dereference the view, and support CUDA graph-ability!
(Pdb) optimizer_grad
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
(Pdb) param.dtype
torch.float32
(Pdb) id(optimizer_grad)
140718436766800
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2982)update_main_grads()
-> if group.main_weight_buffer is not None and not self.use_decoupled_grad:
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2985)update_main_grads()
-> optimizer_grad = optimizer_grad.to(param.dtype)
(Pdb) n
> /opt/megatron-lm/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py(2987)update_main_grads()
-> if name not in self.dist_main_grad:

# NO-OP / SAME ADDRESS
(Pdb) id(optimizer_grad)
140718436766800
# VIEW OF MFSDP MAIN GRAD BUFFER
(Pdb) optimizer_grad._base
tensor([ 0.0000,  0.0000,  0.0000,  ..., -0.4228, -3.4517, -1.0283],
       device='cuda:0')
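
The same distinction can be reproduced outside of pdb (a standalone sketch with made-up sizes, not the actual Megatron-FSDP buffers):

main_grad_buffer = torch.zeros(8, device="cuda", dtype=torch.bfloat16)

shard_view = main_grad_buffer[:4]            # slice of the buffer
assert shard_view._base is main_grad_buffer  # a view: keep it, CUDA-graphable

casted_copy = shard_view.to(torch.float32)   # dtype cast allocates new storage
assert casted_copy._base is None             # a copy: dereference to reclaim memory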

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@cspades cspades self-assigned this May 6, 2026
@cspades cspades requested review from a team as code owners May 6, 2026 23:55
@cspades cspades added the Expert Review [deprecated] and module: megatron-fsdp labels May 6, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft May 6, 2026 23:56
@github-actions
Contributor

github-actions Bot commented May 6, 2026

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@copy-pr-bot

copy-pr-bot Bot commented May 6, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cspades cspades marked this pull request as ready for review May 6, 2026 23:56
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 6, 2026 23:56
Comment on lines 2951 to +3004
@@ -2981,7 +3001,7 @@ def update_main_grads(self):
setattr(param, "decoupled_grad", grad)
else:
# Attach the gradient to the optimizer parameter.
setattr(param, "grad", grad.to(param.dtype) if grad is not None else None)
setattr(param, "grad", grad)
Member Author

@cspades cspades May 7, 2026

90% sure this lower cast is legacy code. Originally this is where we did the cast, but a past precision-aware optimizer PR added a 2nd cast above. Let's just have 1 cast. 👍🏻

@cspades cspades marked this pull request as draft May 7, 2026 19:49
@cspades cspades marked this pull request as ready for review May 7, 2026 20:48
@cspades cspades changed the title Add conditional param.grad dereferencing logic to support CUDA graphability. [Megatron-FSDP] Add conditional param.grad dereferencing logic to support CUDA graphability. May 7, 2026
Contributor

@shjwudp shjwudp left a comment

I understand that removing type casting is key to ensuring compatibility with full CG. You've already tested the compatibility between decoupled gradients and full CG—correct?

To avoid blocking MLPerf, I’m approving this PR. However, I suggest we clearly document the conditions required for full CG compatibility and avoid introducing unrelated code changes, as they could cause confusion when we revisit this fix later.

and isinstance(param.decoupled_grad, DTensor)
and param.decoupled_grad._local_tensor._base is None
):
# NOTE: Decoupled gradients should always be a view of Megatron-FSDP buffers.
Contributor

The newly added code appears to have implemented stricter checks. What is the purpose of this? Will it help resolve full CG compatibility issues?

Member Author

@cspades cspades May 8, 2026

This is actually completely up for debate during review; it's not a major blocker either, so we can discuss it now. Also, I am still adding more tests and will add documentation once I test full-iteration and partial CG.

TL;DR We have to choose one of the following solutions:

  • Limited CUDA graph support (will not work for param.dtype != grad.dtype), and the exact same memory utilization as before.
  • Full CUDA graph support, but increased sharded memory utilization for certain cases.

The only requirement for full-iteration CUDA graph is that we do not set param.grad = None. However, FSDP dereferences the gradient sometimes:

  • If param.grad is a view of our sharded buffer, there is no need to deallocate / dereference this variable.
  • If param.grad is not a view of our sharded buffer, deallocating / dereferencing this variable can save memory, because it is a casted-copy of our sharded buffer.

The logic for decoupled_grad is exactly the same. The main difference is that we should never need to run this code for the decoupled gradient, because we never have decoupled_grad.to(param.dtype), and FusedAdam(use_decoupled_grad=True) allows for mixed-precision optimization steps.

The complex logic here dereferences this variable when there is memory to reclaim, so we do not persistently hold 2 copies of the sharded gradient that will raise memory overhead.

An alternative is that we can delete all of the param.grad = None code, where:

  • If it is a view of the sharded buffer, we will zero out the buffer below.
  • If it is not a view of the sharded buffer, then grad.to(param.dtype) will overwrite this buffer's values during CUDA graph replay.
    • This is a 2nd allocation of the sharded gradient, which increases memory utilization.

In this case, we can support CUDA graph-ability for all cases, but Megatron-FSDP will use more memory in specific cases.
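
To make the memory trade-off concrete, a rough sketch (sizes and variable names are made up for illustration):

import torch

shard_elems = 1_000_000
# Persistent sharded main-grad accumulation buffer (bf16 in this example).
main_grads = torch.zeros(shard_elems, dtype=torch.bfloat16, device="cuda")

# Option 1 (this PR): cast to the optimizer dtype, step, then dereference.
# Only one transient extra allocation, but setting param.grad = None breaks
# full-iteration CUDA graph capture for this dtype combination.
fp32_grad = main_grads.to(torch.float32)  # ~4 MB while it lives
del fp32_grad

# Option 2 (alternative above): never dereference; the casted copy stays
# resident alongside the bf16 shard (~2 MB bf16 + ~4 MB fp32), but nothing
# needs to re-install param.grad eagerly, so a full-iteration capture works.
persistent_fp32_grad = main_grads.to(torch.float32)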

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review label May 8, 2026
@cspades cspades requested review from a team as code owners May 8, 2026 19:56
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the Final Review label May 8, 2026
@cspades cspades requested a review from nanz-nv May 8, 2026 20:04
@cspades cspades force-pushed the cye/mfsdp-cuda-graph branch from 4d8c976 to 264c850 Compare May 9, 2026 00:06
cspades added 3 commits May 8, 2026 19:35
