13 changes: 13 additions & 0 deletions docs/user-guide/features/fine_grained_activation_offloading.md
@@ -37,6 +37,19 @@ Supported offloading modules are `"attn_norm"`, `"core_attn"`, `"attn_proj"`, `"
--offload-modules expert_fc1
```

## Max inflight offloads

```bash
# Optional: cap the number of in-flight D2H offloads per offload group name at N
# (omit the flag, leaving it None, in most setups). Required as a non-negative
# integer when fine-grained activation offloading is used with local
# full-iteration CUDA graphs (full_iteration in cuda_graph_scope); see below.
--fine-grained-offloading-max-inflight-offloads <N>
```
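
For example, a hypothetical launch combining the cap with local full-iteration CUDA graphs (the first three flag spellings are assumptions inferred from the corresponding config fields; verify them against your version's argument parser):

```bash
--fine-grained-activation-offloading \
--cuda-graph-impl local \
--cuda-graph-scope full_iteration \
--offload-modules expert_fc1 \
--fine-grained-offloading-max-inflight-offloads 1
```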

`TransformerConfig.fine_grained_offloading_max_inflight_offloads` caps, per offload group name (for example `moe_act` or `qkv_linear`), how many D2H copies may be in flight before the main stream issues a `wait_event`. A value of `0` waits after every offload, larger values allow more overlap, and `None` skips these joins entirely.

When fine-grained activation offloading is used with full-iteration CUDA graphs (the local graph implementation, with `full_iteration` in `cuda_graph_scope`), this option must be set to a non-`None` integer: that path does not rely on `record_stream`, so explicit joins are required.
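
The following is a minimal Python sketch of these semantics, illustrative only and not the actual handler (`commit_offload` is a hypothetical helper): each committed offload records a CUDA event on the D2H stream, and once a group name's pending queue exceeds the cap, the main stream waits on the oldest events for that name.

```python
import torch
from collections import defaultdict, deque

CAP = 1  # value of fine_grained_offloading_max_inflight_offloads
pending = defaultdict(deque)  # group name -> FIFO of D2H offload events

def commit_offload(name: str, d2h_stream: torch.cuda.Stream) -> None:
    # Record an event on the D2H stream marking this group's copies as enqueued.
    evt = torch.cuda.Event()
    evt.record(d2h_stream)
    pending[name].append(evt)
    # Once more than CAP events are pending for this name, join the oldest
    # ones on the main stream (CAP = 0 therefore waits after every offload).
    while len(pending[name]) > CAP:
        torch.cuda.current_stream().wait_event(pending[name].popleft())
```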

## Compatible With Fine-Grained Recomputation

- For low-overhead modules such as LayerNorm or `moe_act`, use recomputation to save activation memory.
1 change: 1 addition & 0 deletions megatron/core/models/gpt/gpt_model.py
@@ -464,6 +464,7 @@ def preprocess_for_fine_grained_offloading(self):
vp_size=self.config.virtual_pipeline_model_parallel_size,
vp_stage=self.vp_stage,
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
1 change: 1 addition & 0 deletions megatron/core/models/hybrid/hybrid_model.py
@@ -330,6 +330,7 @@ def preprocess_for_fine_grained_offloading(self):
vp_size=self.config.virtual_pipeline_model_parallel_size,
vp_stage=self.vp_stage,
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
@@ -1,8 +1,8 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from collections import deque
from collections import defaultdict, deque
from contextlib import nullcontext
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

import torch

@@ -608,7 +608,12 @@ def front_backward_chunk(self, name=None):
return None

def init_model_chunk_offload_handler(
self, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024
self,
pp_rank,
vp_size,
vp_stage,
min_offloaded_tensor_size=1024 * 1024,
max_inflight_offloads: Optional[int] = None,
):
"""
Initialize a chunk offload handler for a model chunk (microbatch).
@@ -617,6 +622,9 @@ def init_model_chunk_offload_handler(
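pp_rank: Pipeline parallel rank of this model chunk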
vp_size: Virtual pipeline size
vp_stage: Virtual pipeline stage index (None means stage 0)
min_offloaded_tensor_size: Minimum tensor size (in elements) to offload
max_inflight_offloads: If set, cap the number of pending offloads per group name
before the main stream waits on their D2H events; see
``fine_grained_offloading_max_inflight_offloads`` on ``TransformerConfig``.
"""
if not self._is_warmup:
return
@@ -636,7 +644,11 @@
self.flush()

# Use shared CPU tensor pool for better reuse across chunks
cur_chunk = ChunkOffloadHandler(min_offloaded_tensor_size, self._cpu_tensor_pool)
cur_chunk = ChunkOffloadHandler(
min_offloaded_tensor_size,
self._cpu_tensor_pool,
max_inflight_offloads=max_inflight_offloads,
)
debug_rank(f"init_model_chunk_offload_handler {cur_chunk}")
self._stages[cur_vpp_rank].append(cur_chunk)
# For the last stage, push immediately and flush
@@ -763,7 +775,12 @@ def reload(self, state, non_blocking=None):
self.cpu_tensor_pool.free(cpu_backup)
return gpu_tensor

def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
def __init__(
self,
min_offloaded_tensor_size,
cpu_tensor_pool,
max_inflight_offloads: Optional[int] = None,
):
self.do_offload = True

# Group management for batching offload/reload operations
@@ -786,6 +803,10 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
self.min_offloaded_tensor_size = min_offloaded_tensor_size
self.cpu_tensor_pool = cpu_tensor_pool
self.is_warmup = True
# Max per-group-name inflight offloads not yet joined on the main stream (None = off).
self._max_inflight_offloads = max_inflight_offloads
# group_name -> FIFO of offload events for that name (same cap for every name).
self._offload_pending_by_name: Dict[str, deque] = defaultdict(deque)

def reset(self):
"""Reset the chunk offload handler."""
@@ -794,6 +815,9 @@
self._groups_to_reload = []
self._tensor_count_current_group = 0
self._reloading_group = []
# Clear the pending-event FIFO at iter boundary so we never wait on
# an event recorded in a previous (non-captured) iteration.
self._offload_pending_by_name.clear()

def find_group_with_name(self, name: str, start_index: int = 0):
"""Find the group with the given name starting from the given index."""
@@ -889,6 +913,14 @@ def bulk_offload_group(self):
group_to_offload.record_offload_event(self.d2h_stream)
self._groups_to_offload.pop()
nvtx_range_pop(nvtx_msg)
# Under full-iteration CUDA graph capture the main stream does not rely on
# record_stream to synchronize with D2H copies, so explicit joins are needed.
# When max-inflight is set, each group's offload event is enqueued and the
# main stream waits on older events for the same group name once that name's
# pending count exceeds the cap (each name is tracked separately).
if self._max_inflight_offloads is not None:
gname = group_to_offload._name
self._offload_pending_by_name[gname].append(group_to_offload._offload_event)
self._drain_offload_pending(gname)

def get_max_deduplicated_groups(self):
"""Get the maximum number of deduplicated groups."""
@@ -967,7 +999,19 @@ def bulk_offload(self, forced_released_tensors):
release_tensor.record_stream(cur_stream)
release_tensor.untyped_storage().resize_(0)

def on_group_commit_forward(self, forced_released_tensors):
def _drain_offload_pending(self, group_name: str) -> None:
"""For ``group_name``, have the main stream wait on older D2H events
when that name's pending count exceeds ``_max_inflight_offloads``
(same cap for every name; 0 = wait on each commit for that name)."""
if self._max_inflight_offloads is None:
return
cur = torch.cuda.current_stream()
q = self._offload_pending_by_name[group_name]
while len(q) > self._max_inflight_offloads:
old_evt = q.popleft()
cur.wait_event(old_evt)

def on_group_commit_forward(self, name, forced_released_tensors):
"""Called at the end of a layer group's forward pass to trigger offloading."""
if not self.do_offload:
return
@@ -1229,10 +1273,21 @@ def __exit__(self, *args: Any):
PipelineOffloadManager.get_instance().__exit__()

@staticmethod
def init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size):
def init_chunk_handler(
pp_rank,
vp_size,
vp_stage,
min_offloaded_tensor_size,
max_inflight_offloads: Optional[int] = None,
):
"""Initialize the chunk handler, called at the start of a microbatch forward pass."""
PipelineOffloadManager.get_instance().init_model_chunk_offload_handler(
vp_size, vp_stage, min_offloaded_tensor_size
pp_rank,
vp_size,
vp_stage,
min_offloaded_tensor_size,
max_inflight_offloads=max_inflight_offloads,
)

@staticmethod
30 changes: 30 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -1034,6 +1034,14 @@ class TransformerConfig(ModelParallelConfig):
min_offloaded_tensor_size: int = 1024 * 1024
"""The minimum size of the tensor to be offloaded."""

fine_grained_offloading_max_inflight_offloads: Optional[int] = None
"""Per fine-grained offloading group name, max number of inflight offloads for that name not
yet joined on the main stream (wait_event on D2H). The same cap applies to every name (e.g.,
``moe_act`` and ``qkv_linear`` each have their own pending queue). 0 = wait after every
offload for that name. 1 = at most one not-yet-waited offload per name, etc. None = do not
insert these joins. This feature is particularly useful when using with full-iteration CUDA
graphs"""

def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
@@ -2105,6 +2113,28 @@ def __post_init__(self):
"moe_input_jitter_eps is not supported with graphed moe recomputation."
)

if self.fine_grained_activation_offloading:
assert self.cuda_graph_impl == "transformer_engine" or (
self.cuda_graph_impl == "local"
and self.cuda_graph_scope == [CudaGraphScope.full_iteration]
), (
"fine-grained activation offloading is only supported with "
"transformer_engine CUDA graph implementation or local CUDA graph "
"implementation with full_iteration scope."
)
assert (
CudaGraphScope.moe not in self.cuda_graph_scope
), "Token-drop MoE is temporarily not supported with activation offloading."
assert self.cuda_graph_warmup_steps > 0, (
"cuda_graph_warmup_steps must be greater than 0 when enabling "
"fine-grained activation offloading."
)
if CudaGraphScope.full_iteration in self.cuda_graph_scope:
assert self.fine_grained_offloading_max_inflight_offloads is not None, (
"fine_grained_offloading_max_inflight_offloads must be set when using "
"fine-grained activation offloading with full-iteration CUDA graphs "
)

if self.moe_token_dispatcher_type in ["allgather"]:
if self.variable_seq_lengths is True:
raise ValueError(
1 change: 1 addition & 0 deletions tests/unit_tests/models/test_hybrid_moe_model.py
@@ -286,6 +286,7 @@
"fine_grained_activation_offloading": False,
"min_offloaded_tensor_size": 1024 * 1024,
"offload_modules": [],
"fine_grained_offloading_max_inflight_offloads": None,
"hybrid_context_parallel": False,
"max_seqlen_per_dp_cp_rank": None,
"inference_disable_triton_nvls_kernels": False,