megatron/core/full_cuda_graph.py (45 additions, 2 deletions)

```diff
@@ -10,6 +10,47 @@
 
 logger = logging.getLogger(__name__)
 
+# Process-wide handle so full-iter and optimizer graph captures share one pool and one
+# non-default stream (per-stream alloc segments can inflate memory_reserved; see
+# tools/debug_cuda_graph_pool_memory*.py).
+_shared_graph_pool = None
+_shared_capture_stream = None
+
+
+def get_shared_capture_stream():
+    """Return one `torch.cuda.Stream` for all full-iter and optimizer graph captures.
+
+    Call after the target CUDA device is selected.
+    """
+    global _shared_capture_stream
+    if _shared_capture_stream is None:
+        _shared_capture_stream = torch.cuda.Stream()
+    return _shared_capture_stream
+
+
+def get_shared_graph_pool():
+    """Return a process-wide handle so all call sites share one graph memory pool.
+
+    `torch.cuda.graph_pool_handle()` returns a new pool each time; this lazy singleton
+    ensures e.g. full-iteration and optimizer captures reuse the same pool.
+    """
+    global _shared_graph_pool
+    if _shared_graph_pool is None:
+        _shared_graph_pool = torch.cuda.graph_pool_handle()
+    return _shared_graph_pool
+
+
+def get_graph_pool(use_single_mempool):
+    """Return graph pool handle for full-iter/optimizer graph capture.
+
+    When `use_single_mempool` is True, train/eval and optimizer captures reuse one
+    process-wide pool. Otherwise, each capture call gets a new pool handle.
+    """
+    if use_single_mempool:
+        return get_shared_graph_pool()
+    return torch.cuda.graph_pool_handle()
+
+
 # The below functions traverse through nested data structures (tuples, lists, dicts)
 # present in src and creates a deep copy where all PyTorch tensors are cloned,
 # detached from the computation graph, and moved to CUDA device. Non-tensor objects
@@ -100,10 +141,11 @@ class FullCudaGraphWrapper:
     cuda_graph = {'training': None, 'validation': None}
     result = {'training': None, 'validation': None}
 
-    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1):
+    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, use_single_mempool=False):
         self.forward_backward_func = forward_backward_func
         self.static_loader = StaticBufferLoader()
         self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
+        self.use_single_mempool = use_single_mempool
 
     def data_read(self, data_iterator, model, training, num_microbatches):
         """Read all microbatch inputs from Dataloader and copy to static buffers."""
@@ -170,10 +212,11 @@ def __call__(self, *args, **kwargs):
             for _, state in get_all_rng_states().items():
                 FullCudaGraphWrapper.cuda_graph[training_str].register_generator_state(state)
             torch.cuda.synchronize()
-            capture_stream = torch.cuda.Stream()
+            capture_stream = get_shared_capture_stream()
             with torch.cuda.graph(
                 FullCudaGraphWrapper.cuda_graph[training_str],
                 stream=capture_stream,
+                pool=get_graph_pool(self.use_single_mempool),
                 capture_error_mode="thread_local",
             ):
                 FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(
```
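For context on why the singletons matter: `torch.cuda.graph_pool_handle()` mints a fresh pool handle on every call, so sharing a pool across separate captures requires caching one handle, which is exactly what the lazy getters above do. Below is a minimal runnable sketch of the capture pattern this file now implements; the tensor and arithmetic are stand-ins, not Megatron code.

```python
import torch

pool = torch.cuda.graph_pool_handle()  # cache once; every call mints a new pool
stream = torch.cuda.Stream()           # one non-default stream for all captures

x = torch.ones(1 << 20, device="cuda")

# Warm up the ops outside capture, as PyTorch recommends before graphing.
with torch.cuda.stream(torch.cuda.Stream()):
    (x * 2.0).sum()
torch.cuda.synchronize()

g_iter, g_opt = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g_iter, stream=stream, pool=pool):
    y = x * 2.0       # stands in for the full-iteration fwd/bwd; the output
                      # tensor is allocated from the shared pool
with torch.cuda.graph(g_opt, stream=stream, pool=pool):
    x.copy_(y * 0.5)  # stands in for optimizer.step(); same pool, same stream

g_iter.replay()       # replay in capture order when graphs share a pool
g_opt.replay()
torch.cuda.synchronize()
```

Both graphs draw their allocations from one pool on one stream, instead of each capture creating its own pool plus stream-private allocator segments.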
megatron/core/optimizer/optimizer_cuda_graph.py (10 additions, 3 deletions)

```diff
@@ -6,6 +6,8 @@
 
 import torch
 
+from megatron.core.full_cuda_graph import get_graph_pool, get_shared_capture_stream
+
 logger = logging.getLogger(__name__)
 
 
@@ -16,9 +18,10 @@ class OptimizerCudaGraphWrapper:
     cuda_graph = None
     result = None  # result of the optimizer.step() function
 
-    def __init__(self, optimizer_step_func, cuda_graph_warmup_steps=1):
+    def __init__(self, optimizer_step_func, cuda_graph_warmup_steps=1, use_single_mempool=False):
         self.optimizer_step_func = optimizer_step_func
         self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
+        self.use_single_mempool = use_single_mempool
 
     def __call__(self, *args, **kwargs):
         assert len(args) == 0, 'optimizer.step() does not accept positional args'
@@ -31,8 +34,12 @@ def __call__(self, *args, **kwargs):
             assert OptimizerCudaGraphWrapper.cuda_graph is None
             OptimizerCudaGraphWrapper.cuda_graph = torch.cuda.CUDAGraph()
             torch.cuda.synchronize()
-            capture_stream = torch.cuda.Stream()
-            with torch.cuda.graph(OptimizerCudaGraphWrapper.cuda_graph, stream=capture_stream):
+            capture_stream = get_shared_capture_stream()
+            with torch.cuda.graph(
+                OptimizerCudaGraphWrapper.cuda_graph,
+                stream=capture_stream,
+                pool=get_graph_pool(self.use_single_mempool),
+            ):
                 OptimizerCudaGraphWrapper.result = self.optimizer_step_func()
             torch.cuda.synchronize()
             torch.distributed.barrier()
```
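One caveat worth keeping in mind while reviewing the shared-pool path: per the PyTorch CUDA graphs documentation, graphs that share a memory pool may reuse each other's allocations, so they should be replayed in the same order they were captured. That ordering appears to hold here, since the full-iteration graph runs before optimizer.step() within each training step, but it is a constraint the shared pool introduces that independent per-capture pools did not have.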
megatron/core/transformer/transformer_config.py (5 additions, 6 deletions)

```diff
@@ -849,12 +849,11 @@ class TransformerConfig(ModelParallelConfig):
     CUDA graph (1 CUDA graph for whole iteration excluding optimizer) is enabled. --cuda-graph-scope
     determines the scope of graph capture."""
 
-    cuda_graph_use_single_mempool: bool = False
-    """[For `local` implementation only] When set to true, cudagraphs will be captured inside a
-    single mempool, in which all cudagraphs may only be used once per step. If false, cudagraphs may
-    be reused across microbatches. Enabling may reduce cudagraph memory overheads due to memory
-    fragmentation, however may greatly increase the number of cudagraphs created when the number of
-    microbatches is high."""
+    cuda_graph_use_single_mempool: bool = True
+    """For cuda_graph_impl "local" with cuda_graph_scope "full_iteration" only.
+
+    When True, full-iteration graph replay (training and evaluation) and optimizer graph
+    capture/replay share the same CUDA graph memory pool."""
 
     cuda_graph_retain_backward_graph: bool = False
     """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'
```
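Because the default flips to True here, existing configs now opt out rather than in. A hypothetical minimal construction showing the opt-out; the non-CUDA-graph fields are just the usual required TransformerConfig arguments, and real configs set far more:

```python
from megatron.core.transformer import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=8,
    cuda_graph_impl="local",
    # Opt out of the new shared-pool default, restoring per-capture pools.
    cuda_graph_use_single_mempool=False,
)
```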
megatron/rl/rl_utils.py (3 additions, 1 deletion)

```diff
@@ -1513,7 +1513,9 @@ def prepare_data_for_update(
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
         forward_backward_func = FullCudaGraphWrapper(
-            forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
         )
 
     dtype = (
```
megatron/training/training.py (15 additions, 3 deletions)

```diff
@@ -2943,9 +2943,17 @@ def train(
     # Wrap forward_backward_func for Full iteration CUDA graph
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
-        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        forward_backward_func = FullCudaGraphWrapper(
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
     if args.optimizer_cuda_graph:
-        optimizer.step = OptimizerCudaGraphWrapper(optimizer.step, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        optimizer.step = OptimizerCudaGraphWrapper(
+            optimizer.step,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
 
     def get_e2e_base_metrics():
         """Get base metrics values for one-logger to calculate E2E tracking metrics."""
@@ -3457,7 +3465,11 @@ def evaluate(
     eval_num_microbatches = eval_batch_size // (eval_micro_batch_size * args.data_parallel_size)
     forward_backward_func = get_forward_backward_func()
     if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
-        forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
+        forward_backward_func = FullCudaGraphWrapper(
+            forward_backward_func,
+            cuda_graph_warmup_steps=args.cuda_graph_warmup_steps,
+            use_single_mempool=args.cuda_graph_use_single_mempool,
+        )
 
     if has_nvidia_modelopt:
         # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
```
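A quick way to sanity-check the intended effect at these call sites is to watch allocator state around the first graphed iteration. A rough sketch; the helper below is illustrative, not part of this PR or its debug tooling:

```python
import torch

def log_allocator_state(tag: str) -> None:
    # memory_reserved counts allocator segments held by the process; extra
    # per-capture pools and per-stream segments show up here even when
    # memory_allocated (live tensor bytes) stays flat.
    gib = 2**30
    print(
        f"[{tag}] allocated={torch.cuda.memory_allocated() / gib:.2f} GiB, "
        f"reserved={torch.cuda.memory_reserved() / gib:.2f} GiB"
    )

# Call once before and once after the first captured iteration; with
# use_single_mempool=True the growth in reserved memory should be smaller
# than with a fresh pool and stream per capture.
```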
tests/unit_tests/models/test_hybrid_moe_model.py (1 addition, 1 deletion)

```diff
@@ -70,7 +70,7 @@
     "cuda_graph_impl": "none",
     "cuda_graph_retain_backward_graph": False,
     "cuda_graph_scope": [],
-    "cuda_graph_use_single_mempool": False,
+    "cuda_graph_use_single_mempool": True,
     "cuda_graph_warmup_steps": 3,
     "deallocate_pipeline_outputs": True,
     "defer_embedding_wgrad_compute": False,
```