From 0d4e10271d5e40eeb13d38ebf0f5a9c480c6a477 Mon Sep 17 00:00:00 2001
From: Nan Zheng <80790206+nanz-nv@users.noreply.github.com>
Date: Fri, 8 May 2026 13:19:56 +0800
Subject: [PATCH] Allow optimizer CG to share the same pool as full-iter CG (#4521)

---
 megatron/core/full_cuda_graph.py              | 47 ++++++++++++++++++-
 .../core/optimizer/optimizer_cuda_graph.py    | 13 +++--
 .../core/transformer/transformer_config.py    | 11 ++---
 megatron/rl/rl_utils.py                       |  4 +-
 megatron/training/training.py                 | 18 +++++--
 .../models/test_hybrid_moe_model.py           |  2 +-
 6 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py
index 7d790a07c6c..f84f40f56ee 100644
--- a/megatron/core/full_cuda_graph.py
+++ b/megatron/core/full_cuda_graph.py
@@ -10,6 +10,47 @@
 
 logger = logging.getLogger(__name__)
 
+# Process-wide handle so full-iter and optimizer graph captures share one pool and one
+# non-default stream (per-stream alloc segments can inflate memory_reserved; see
+# tools/debug_cuda_graph_pool_memory*.py).
+_shared_graph_pool = None
+_shared_capture_stream = None
+
+
+def get_shared_capture_stream():
+    """Return one `torch.cuda.Stream` for all full-iter and optimizer graph captures.
+
+    Call after the target CUDA device is selected.
+    """
+    global _shared_capture_stream
+    if _shared_capture_stream is None:
+        _shared_capture_stream = torch.cuda.Stream()
+    return _shared_capture_stream
+
+
+def get_shared_graph_pool():
+    """Return a process-wide handle so all call sites share one graph memory pool.
+
+    `torch.cuda.graph_pool_handle()` returns a new pool each time; this lazy singleton
+    ensures e.g. full-iteration and optimizer captures reuse the same pool.
+    """
+    global _shared_graph_pool
+    if _shared_graph_pool is None:
+        _shared_graph_pool = torch.cuda.graph_pool_handle()
+    return _shared_graph_pool
+
+
+def get_graph_pool(use_single_mempool):
+    """Return graph pool handle for full-iter/optimizer graph capture.
+
+    When `use_single_mempool` is True, train/eval and optimizer captures reuse one
+    process-wide pool. Otherwise, each capture call gets a new pool handle.
+    """
+    if use_single_mempool:
+        return get_shared_graph_pool()
+    return torch.cuda.graph_pool_handle()
+
+
 # The below functions traverse through nested data structures (tuples, lists, dicts)
 # present in src and creates a deep copy where all PyTorch tensors are cloned,
 # detached from the computation graph, and moved to CUDA device. Non-tensor objects
@@ -100,10 +141,11 @@ class FullCudaGraphWrapper:
     cuda_graph = {'training': None, 'validation': None}
     result = {'training': None, 'validation': None}
 
-    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1):
+    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, use_single_mempool=False):
         self.forward_backward_func = forward_backward_func
         self.static_loader = StaticBufferLoader()
         self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
+        self.use_single_mempool = use_single_mempool
 
     def data_read(self, data_iterator, model, training, num_microbatches):
         """Read all microbatch inputs from Dataloader and copy to static buffers."""
@@ -170,10 +212,11 @@ def __call__(self, *args, **kwargs):
                 for _, state in get_all_rng_states().items():
                     FullCudaGraphWrapper.cuda_graph[training_str].register_generator_state(state)
                 torch.cuda.synchronize()
-                capture_stream = torch.cuda.Stream()
+                capture_stream = get_shared_capture_stream()
                 with torch.cuda.graph(
                     FullCudaGraphWrapper.cuda_graph[training_str],
                     stream=capture_stream,
+                    pool=get_graph_pool(self.use_single_mempool),
                     capture_error_mode="thread_local",
                 ):
                     FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(
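Why the singleton matters: `torch.cuda.graph_pool_handle()` mints a distinct handle on every call, so two captures only share a pool if the very same handle is threaded to both. A minimal standalone sketch of the behavior these helpers rely on, using only public PyTorch APIs (the tensors and arithmetic are illustrative, not from this patch):

import torch

# Every call returns a fresh, distinct pool id; sharing requires reusing one handle.
assert torch.cuda.graph_pool_handle() != torch.cuda.graph_pool_handle()

shared_pool = torch.cuda.graph_pool_handle()   # analogous to get_shared_graph_pool()
capture_stream = torch.cuda.Stream()           # analogous to get_shared_capture_stream()
x = torch.zeros(8, device="cuda")

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, stream=capture_stream, pool=shared_pool):
    y = x + 1          # y is allocated out of the shared graph pool
with torch.cuda.graph(g2, stream=capture_stream, pool=shared_pool):
    z = y * 2          # z comes from the same pool as y

g1.replay()            # both graphs draw their capture-time allocations from one pool
g2.replay()
torch.cuda.synchronize()

One known caveat, per PyTorch's documentation on sharing mempools across graphs: graphs that share a pool should be replayed in the order they were captured, which a fixed per-iteration train-step then optimizer-step sequence naturally satisfies.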
diff --git a/megatron/core/optimizer/optimizer_cuda_graph.py b/megatron/core/optimizer/optimizer_cuda_graph.py
index b96ea2678a4..21a32cf69ef 100644
--- a/megatron/core/optimizer/optimizer_cuda_graph.py
+++ b/megatron/core/optimizer/optimizer_cuda_graph.py
@@ -6,6 +6,8 @@
 
 import torch
 
+from megatron.core.full_cuda_graph import get_graph_pool, get_shared_capture_stream
+
 logger = logging.getLogger(__name__)
 
 
@@ -16,9 +18,10 @@ class OptimizerCudaGraphWrapper:
     cuda_graph = None
     result = None  # result of the optimizer.step() function
 
-    def __init__(self, optimizer_step_func, cuda_graph_warmup_steps=1):
+    def __init__(self, optimizer_step_func, cuda_graph_warmup_steps=1, use_single_mempool=False):
         self.optimizer_step_func = optimizer_step_func
         self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
+        self.use_single_mempool = use_single_mempool
 
     def __call__(self, *args, **kwargs):
         assert len(args) == 0, 'optimizer.step() does not accept positional args'
@@ -31,8 +34,12 @@ def __call__(self, *args, **kwargs):
             assert OptimizerCudaGraphWrapper.cuda_graph is None
             OptimizerCudaGraphWrapper.cuda_graph = torch.cuda.CUDAGraph()
             torch.cuda.synchronize()
-            capture_stream = torch.cuda.Stream()
-            with torch.cuda.graph(OptimizerCudaGraphWrapper.cuda_graph, stream=capture_stream):
+            capture_stream = get_shared_capture_stream()
+            with torch.cuda.graph(
+                OptimizerCudaGraphWrapper.cuda_graph,
+                stream=capture_stream,
+                pool=get_graph_pool(self.use_single_mempool),
+            ):
                 OptimizerCudaGraphWrapper.result = self.optimizer_step_func()
             torch.cuda.synchronize()
             torch.distributed.barrier()
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index bb044787b9c..e1289519281 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -849,12 +849,11 @@ class TransformerConfig(ModelParallelConfig):
     CUDA graph (1 CUDA graph for whole iteration excluding optimizer) is enabled.
     --cuda-graph-scope determines the scope of graph capture."""
 
-    cuda_graph_use_single_mempool: bool = False
-    """[For `local` implementation only] When set to true, cudagraphs will be captured inside a
-    single mempool, in which all cudagraphs may only be used once per step. If false, cudagraphs may
-    be reused across microbatches. Enabling may reduce cudagraph memory overheads due to memory
-    fragmentation, however may greatly increase the number of cudagraphs created when the number of
-    microbatches is high."""
+    cuda_graph_use_single_mempool: bool = True
+    """For cuda_graph_impl "local" with cuda_graph_scope "full_iteration" only.
+
+    When True, full-iteration graph replay (training and evaluation) and optimizer graph
+    capture/replay share the same CUDA graph memory pool."""
 
     cuda_graph_retain_backward_graph: bool = False
     """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'
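For readers outside the codebase: both wrappers follow the same lifecycle of eager warmup calls, a one-time capture, then replay. A simplified sketch of that pattern follows; this is not the Megatron implementation, and the hypothetical GraphedStep omits the distributed barrier, RNG-state registration, and static-buffer handling the real wrappers do:

import torch

class GraphedStep:
    """Run eagerly for warmup_steps calls, capture once, then replay forever."""

    def __init__(self, step_fn, warmup_steps=1, use_single_mempool=False,
                 shared_pool=None, capture_stream=None):
        self.step_fn = step_fn
        self.warmup_steps = warmup_steps
        # Mirrors get_graph_pool(): one shared handle when pooling, else a fresh one.
        self.pool = (shared_pool if use_single_mempool
                     else torch.cuda.graph_pool_handle())
        self.stream = capture_stream or torch.cuda.Stream()
        self.graph = None
        self.result = None
        self.calls = 0

    def __call__(self):
        if self.calls < self.warmup_steps:
            self.result = self.step_fn()          # eager warmup iterations
        elif self.graph is None:
            self.graph = torch.cuda.CUDAGraph()
            torch.cuda.synchronize()
            with torch.cuda.graph(self.graph, stream=self.stream, pool=self.pool):
                self.result = self.step_fn()      # one-time capture (kernels recorded)
            torch.cuda.synchronize()
            self.graph.replay()                   # run this step's work for real
        else:
            self.graph.replay()                   # steady state: replay only
        self.calls += 1
        return self.result

Two instances constructed with use_single_mempool=True and the same shared_pool and capture_stream capture into one pool, which is the arrangement this patch wires up between FullCudaGraphWrapper and OptimizerCudaGraphWrapper.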
diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index 4d018217cec..5faadab4a31 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -1513,7 +1513,9 @@ def prepare_data_for_update(
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
         forward_backward_func = FullCudaGraphWrapper(
-            forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
         )
 
     dtype = (
diff --git a/megatron/training/training.py b/megatron/training/training.py
index c6cab8df952..27732c4df5a 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -2943,9 +2943,17 @@ def train(
     # Wrap forward_backward_func for Full iteration CUDA graph
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
-        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        forward_backward_func = FullCudaGraphWrapper(
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
     if args.optimizer_cuda_graph:
-        optimizer.step = OptimizerCudaGraphWrapper(optimizer.step, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        optimizer.step = OptimizerCudaGraphWrapper(
+            optimizer.step,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
 
     def get_e2e_base_metrics():
         """Get base metrics values for one-logger to calculate E2E tracking metrics."""
@@ -3457,7 +3465,11 @@ def evaluate(
     eval_num_microbatches = eval_batch_size // (eval_micro_batch_size * args.data_parallel_size)
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
-        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        forward_backward_func = FullCudaGraphWrapper(
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
 
     if has_nvidia_modelopt:
         # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
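Taken together, all four call sites (RL update, train, evaluate, and the optimizer hook) now thread the same flag. Schematically, inside the training loop; here `args`, `optimizer`, and `forward_backward_func` come from the surrounding Megatron context and the guard conditions are abbreviated:

from megatron.core.full_cuda_graph import FullCudaGraphWrapper
from megatron.core.optimizer.optimizer_cuda_graph import OptimizerCudaGraphWrapper

forward_backward_func = FullCudaGraphWrapper(
    forward_backward_func,
    cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
    use_single_mempool=args.cuda_graph_use_single_mempool,
)
optimizer.step = OptimizerCudaGraphWrapper(
    optimizer.step,
    cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
    use_single_mempool=args.cuda_graph_use_single_mempool,
)
# With the flag True (the new default), both captures call get_graph_pool(True) and
# get_shared_capture_stream(), landing in one pool on one capture stream.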
"cuda_graph_warmup_steps": 3, "deallocate_pipeline_outputs": True, "defer_embedding_wgrad_compute": False,