71 changes: 39 additions & 32 deletions megatron/core/distributed/param_and_grad_buffer.py
@@ -257,6 +257,38 @@ def reset(self):
        self.per_param_grad_ready_counts = {}
        self.is_last_microbatch = True

    def _post_param_sync(self):
        """Run post-processing after param all-gather completes."""
        if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
            for bucket in self.buckets:
                is_bf16_weight_bucket = False
                for param in bucket.params:
                    # Skip copying since bf16 weights in the mxfp8 model
                    # are already mapped to param.data.
                    if not is_float8tensor(param):
                        is_bf16_weight_bucket = True
                        break
                    param_start, param_end = bucket.param_to_index[param]
                    param_slice = bucket.param_data.view(-1)[param_start:param_end]
                    param.data.copy_(param_slice.view(param.data.shape))
                if is_bf16_weight_bucket:
                    continue
                # All-gathered params are not needed after being copied to param.data.
                # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
                # We cannot zero out the entire grad buffer because one grad buffer may
                # correspond to multiple param buffers. If we zero out the entire grad buffer,
                # it would clear the data of those param buffers that have not yet completed AG.
                bucket.param_data.zero_()
            return

        quantized_params = []
        for bucket in self.buckets:
            for param in bucket.params:
                if is_float8tensor(param) or is_nvfp4tensor(param):
                    quantized_params.append(param)
        if len(quantized_params) > 0:
            post_all_gather_processing(quantized_params)

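The copy-out-and-zero pattern above is the crux of reusing the grad buffer for the MXFP8 param all-gather. Below is a minimal sketch of the same pattern with plain tensors; the buffer sizes, the single-param bucket, and the inlined `param_to_index` offsets are illustrative assumptions, not the real bucket API.

```python
import torch

# Illustrative stand-in: one grad buffer backing two param buckets.
grad_buffer = torch.zeros(20)
bucket_param_data = grad_buffer[0:12]  # this bucket's view into the shared buffer

# Assume the param all-gather just wrote gathered weights into the bucket view.
bucket_param_data.copy_(torch.arange(12, dtype=torch.float32))

# Copy each param's slice out of the shared view into its own storage.
param = torch.empty(2, 3)
param_start, param_end = 0, 6  # stand-in for bucket.param_to_index[param]
param.copy_(bucket_param_data[param_start:param_end].view(param.shape))

# Zero only this bucket's view, never the whole grad buffer: other buckets
# backed by the same grad buffer may not have finished their all-gathers yet.
bucket_param_data.zero_()

assert param.flatten().tolist() == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
assert grad_buffer[0:12].sum().item() == 0.0
```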
    def check_grads(self, check_for_nan_or_inf, check_for_large):
        """
        Make sure norm of grads in bucket are not NaN prior to data-parallel
@@ -315,6 +347,7 @@ def start_param_sync(self, force_sync: bool = False):
            if self.param_gather_handle is not None:
                self.param_gather_handle.wait()
                self.param_gather_handle = None
                self._post_param_sync()
                return
        else:
            assert self.param_gather_handle is None
@@ -333,6 +366,8 @@
        dp_size = self.intra_distributed_optimizer_instance_size
        if dp_size == 1:
            # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed.
            if force_sync and self.ddp_config.overlap_param_gather:
                self._post_param_sync()
            self.param_gather_dispatched = True
            return
        local_rank = self.intra_distributed_optimizer_instance_rank
@@ -430,6 +465,8 @@ def start_param_sync(self, force_sync: bool = False):
            # (async_op=False) is used, `cm` is not None. Manually set to None for
            # consistency with prior code.
            self.param_gather_handle = None
        if force_sync and self.ddp_config.overlap_param_gather:
            self._post_param_sync()
        self.param_gather_dispatched = True

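Taken together, the two new call sites mean `_post_param_sync` now runs on every path that completes a gather: immediately on the blocking `force_sync` paths, and from `finish_param_sync` on the overlapped path. A condensed, hypothetical rendering of that control flow with the communication code stubbed out; only the `_post_param_sync` call sites mirror the diff.

```python
class BucketGroupSketch:
    """Stub illustrating where _post_param_sync fires after this change."""

    def __init__(self, overlap_param_gather):
        self.param_gather_handle = None
        self.overlap_param_gather = overlap_param_gather
        self.param_gather_dispatched = False

    def _post_param_sync(self):
        print("post-processing all-gathered params")

    def _dispatch_all_gather(self):
        pass  # stand-in for the elided NCCL all-gather dispatch

    def start_param_sync(self, force_sync=False):
        if force_sync and self.param_gather_handle is not None:
            self.param_gather_handle.wait()  # block until the gather lands
            self.param_gather_handle = None
            self._post_param_sync()          # new: post-process immediately
            return
        self._dispatch_all_gather()
        if force_sync and self.overlap_param_gather:
            self._post_param_sync()          # new: no finish_param_sync follows
        self.param_gather_dispatched = True

    def finish_param_sync(self):
        # Wait on the async handle, maybe dispatch the next bucket (elided)...
        self._post_param_sync()              # overlapped gathers end here

BucketGroupSketch(overlap_param_gather=True).start_param_sync(force_sync=True)
```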
    def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
@@ -469,30 +506,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
            else:
                self.next_param_gather_bucket_group.start_param_sync()

        # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
        # we need to copy the param_data from the shared param/grad buffer to param.data
        # after the param all-gather.
        if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
            for bucket in self.buckets:
                is_bf16_weight_bucket = False
                for param in bucket.params:
                    # Skip copying since bf16 weights in the mxfp8 model
                    # are already mapped to param.data.
                    if not is_float8tensor(param):
                        is_bf16_weight_bucket = True
                        break
                    param_start, param_end = bucket.param_to_index[param]
                    param_slice = bucket.param_data.view(-1)[param_start:param_end]
                    param.data.copy_(param_slice.view(param.data.shape))
                if is_bf16_weight_bucket:
                    continue
                # All-gathered params are not needed after being copied to param.data.
                # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
                # We cannot zero out the entire grad buffer because one grad buffer may
                # correspond to multiple param buffers. If we zero out the entire grad buffer,
                # it would clear the data of those param buffers that have not yet completed AG.
                bucket.param_data.zero_()
        elif not self.ddp_config.use_distributed_optimizer:
        if not self.ddp_config.use_distributed_optimizer:
            for bucket in self.buckets:
                if bucket.layerwise_gather_list is None:
                    continue
@@ -515,14 +529,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
                # (a view into grad_data) would start from the result of the
                # latest parameter all-gather instead of zero.
                bucket.grad_data.zero_()
        else:
            fp8_params = []
            for bucket in self.buckets:
                for param in bucket.params:
                    if is_float8tensor(param):
                        fp8_params.append(param)
            if len(fp8_params) > 0:
                post_all_gather_processing(fp8_params)
        self._post_param_sync()

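The zeroing discipline running through both branches exists because gradients accumulate in place into views of the shared storage, so stale all-gather bytes would corrupt the very first accumulation. A toy demonstration of the hazard with plain tensors (assumed layout, no real buckets):

```python
import torch

# Shared storage reused for both the param all-gather and grad accumulation.
shared = torch.empty(6)
main_grad = shared[0:6]  # the view a param's gradient hook accumulates into

# After an all-gather, the storage holds gathered parameter values...
shared.fill_(7.0)

# ...so accumulating a gradient without zeroing first yields garbage:
main_grad += torch.ones(6)
assert main_grad[0].item() == 8.0  # 7 (stale param bytes) + 1 (actual grad)

# Zeroing the view first restores correct accumulation.
main_grad.zero_()
main_grad += torch.ones(6)
assert main_grad[0].item() == 1.0
```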
    def start_grad_sync(self, force_all_reduce: Optional[bool] = False):
        """
62 changes: 61 additions & 1 deletion tests/unit_tests/test_fp4_param.py
@@ -162,8 +162,37 @@ def get_batch(self, seq_length, micro_batch_size):
        loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda()
        return input_ids, labels, position_ids, attention_mask, loss_mask

    def run_eval_transition(self, args, model_chunks, batch):
        input_ids, labels, position_ids, attention_mask, loss_mask = batch

        if should_disable_forward_pre_hook(args):
            disable_forward_pre_hook(model_chunks, param_sync=True)

        model_chunks[0].eval()
        model_chunks[0].set_is_first_microbatch()
        with torch.no_grad():
            eval_output = model_chunks[0].forward(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                labels=labels,
                loss_mask=loss_mask,
            )
            eval_loss = eval_output.mean()
        model_chunks[0].train()

        if should_disable_forward_pre_hook(args):
            enable_forward_pre_hook(model_chunks)

        return eval_loss.item()

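The `set_is_first_microbatch()` call before the eval forward matters for quantized params: Megatron uses the first-microbatch flag to make Transformer Engine refresh its cached quantized weight copies rather than reuse ones computed before the latest weight update. A simplified, self-contained sketch of that caching behavior; the class, `fake_quantize`, and the manual flag flip are all illustrative, not the TE implementation.

```python
def fake_quantize(w):
    # Stand-in for an fp8/fp4 cast; rounding emulates the precision loss.
    return [round(x, 1) for x in w]

class QuantizedLinearSketch:
    """Toy module that caches its quantized weight across microbatches."""

    def __init__(self, weight):
        self.weight = weight
        self.is_first_microbatch = True
        self._cached_qweight = None

    def forward(self):
        # Re-quantize only when the first-microbatch flag is set, mirroring
        # how set_is_first_microbatch() forces a refresh after weight updates.
        if self.is_first_microbatch or self._cached_qweight is None:
            self._cached_qweight = fake_quantize(self.weight)
            self.is_first_microbatch = False
        return self._cached_qweight

m = QuantizedLinearSketch([0.123, 0.456])
q1 = m.forward()               # quantizes and caches
m.weight = [0.789, 0.101]      # pretend an optimizer step updated the weight
q2 = m.forward()               # stale cache is reused...
m.is_first_microbatch = True   # ...until the set_is_first_microbatch() analogue
q3 = m.forward()
assert q1 == q2 and q3 == [0.8, 0.1]
```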
    def _run_test_helper(
        self, tp_size, inference: bool = False, fp4_param_gather: bool = True, **kwargs
        self,
        tp_size,
        inference: bool = False,
        fp4_param_gather: bool = True,
        eval_transition: bool = False,
        **kwargs,
    ):
        """Test fp4_param with gpt_model."""
        args = self.create_test_args(
@@ -206,6 +235,7 @@ def _run_test_helper(
        assert num_fp4_params == 4 * fp4_layers

        loss_list = []
        eval_loss_list = []

        # CUDA graph setup (transformer_engine implementation)
        cuda_graph_helper = None
@@ -267,6 +297,17 @@ def _run_test_helper(

            loss_list.append(loss.item())

            if eval_transition:
                eval_loss_list.append(
                    self.run_eval_transition(
                        args,
                        gpt_model,
                        (input_ids, labels, position_ids, attention_mask, loss_mask),
                    )
                )

        if eval_transition:
            return torch.tensor(loss_list), torch.tensor(eval_loss_list)
        return torch.tensor(loss_list)

    def run_test(self, tp_size, inference: bool = False, **kwargs):
@@ -282,6 +323,18 @@ def run_test(self, tp_size, inference: bool = False, **kwargs):

        torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2)

    def run_test_with_eval_transition(self, tp_size, **kwargs):
        """Test fp4_param eval transition with gpt_model."""
        loss_list, eval_loss_list = self._run_test_helper(
            tp_size, fp4_param_gather=True, eval_transition=True, **kwargs
        )
        loss_list_ref, eval_loss_list_ref = self._run_test_helper(
            tp_size, fp4_param_gather=False, eval_transition=True, **kwargs
        )

        torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(eval_loss_list, eval_loss_list_ref, atol=1e-2, rtol=1e-2)

    @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
    @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
    @pytest.mark.parametrize("tp_size", [2])
@@ -294,6 +347,13 @@ def test_nvfp4(self, tp_size, dp_overlap):
        kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]}
        self.run_test(tp_size=tp_size, inference=False, **kwargs)

    @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
    @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
    @pytest.mark.parametrize("tp_size", [2])
    def test_nvfp4_eval_transition(self, tp_size):
        kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True}
        self.run_test_with_eval_transition(tp_size=tp_size, **kwargs)

    @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
    @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
    @pytest.mark.parametrize("tp_size", [2])
76 changes: 73 additions & 3 deletions tests/unit_tests/test_fp8_param.py
@@ -15,6 +15,7 @@
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.utils import is_te_min_version
from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args
@@ -171,13 +172,51 @@ def get_batch(self, seq_length, micro_batch_size):
        loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda()
        return input_ids, labels, position_ids, attention_mask, loss_mask

    def copy_main_params_to_param_buffer(self, model_chunks, optimizer):
        # Mirrors MBridge's pre-eval fix: disable_forward_pre_hook(param_sync=True)
        # force-syncs params before eval callbacks run, so MXFP8 must repopulate
        # the shared param/grad buffer before disabling forward hooks.
        for model_chunk in model_chunks:
            model_chunk.zero_grad_buffer()
        for optim_instance in optimizer.chained_optimizers:
            if isinstance(optim_instance, DistributedOptimizer):
                optim_instance._copy_main_params_to_param_buffer()

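The order inside this helper is load-bearing: `zero_grad_buffer()` clears the shared param/grad storage, which wipes the parameter bytes living there, so `_copy_main_params_to_param_buffer()` has to run afterwards to repopulate them before the forced param sync. A single-tensor sketch of the constraint (the real buffers are sharded per DP rank; the names below are stand-ins):

```python
import torch

shared = torch.tensor([0.5, 0.5, 0.5, 0.5])  # params currently live here

# zero_grad_buffer() analogue: grads must start from zero, but because the
# storage is shared, the parameter bytes are wiped along with them.
shared.zero_()

# _copy_main_params_to_param_buffer() analogue: repopulate from the
# optimizer's main fp32 params so the forced param sync gathers real weights.
main_params = torch.tensor([0.5, 0.5, 0.5, 0.5])
shared.copy_(main_params)

assert torch.equal(shared, main_params)
```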
    def run_eval_transition(self, args, model_chunks, optimizer, batch):
        input_ids, labels, position_ids, attention_mask, loss_mask = batch

        if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
            self.copy_main_params_to_param_buffer(model_chunks, optimizer)

        if should_disable_forward_pre_hook(args):
            disable_forward_pre_hook(model_chunks, param_sync=True)

        model_chunks[0].eval()
        model_chunks[0].set_is_first_microbatch()
        with torch.no_grad():
            eval_output = model_chunks[0].forward(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                labels=labels,
                loss_mask=loss_mask,
            )
            eval_loss = eval_output.mean()
        model_chunks[0].train()

        if should_disable_forward_pre_hook(args):
            enable_forward_pre_hook(model_chunks)

        return eval_loss.item()

    def _run_test_helper(
        self,
        tp_size,
        recipe,
        inference: bool = False,
        fp8_param_gather: bool = True,
        use_cuda_graph: bool = False,
        eval_transition: bool = False,
        **kwargs,
    ):
        """Test fp8_param with gpt_model."""
@@ -269,6 +308,7 @@ def _run_test_helper(
        )

        loss_list = []
        eval_loss_list = []

        for i in range(100):
            if not inference:
@@ -290,9 +330,7 @@
            # we need to call the _copy_main_params_to_param_buffer() after the grad buffer
            # is zeroed by zero_grad_buffer() because param and grad buffer are shared.
            if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
                for optim_instance in optimizer.chained_optimizers:
                    if hasattr(optim_instance, "_copy_main_params_to_param_buffer"):
                        optim_instance._copy_main_params_to_param_buffer()
                self.copy_main_params_to_param_buffer(gpt_model, optimizer)

            gpt_model[0].set_is_first_microbatch()
            output = gpt_model[0].forward(
@@ -325,10 +363,22 @@

            loss_list.append(loss.item())

            if eval_transition:
                eval_loss_list.append(
                    self.run_eval_transition(
                        args,
                        gpt_model,
                        optimizer,
                        (input_ids, labels, position_ids, attention_mask, loss_mask),
                    )
                )

        if self.cuda_graph_helper is not None and self.cuda_graph_helper.graphs_created():
            self.cuda_graph_helper.delete_cuda_graphs()
            self.cuda_graph_helper = None

        if eval_transition:
            return torch.tensor(loss_list), torch.tensor(eval_loss_list)
        return torch.tensor(loss_list)

    def run_test(self, tp_size, recipe, inference: bool = False, **kwargs):
@@ -356,6 +406,16 @@ def run_test_with_cuda_graph(self, tp_size, recipe, **kwargs):
        )
        torch.testing.assert_close(loss, loss_ref, atol=0, rtol=0)

    def run_test_with_eval_transition(self, tp_size, recipe, **kwargs):
        loss, eval_loss = self._run_test_helper(
            tp_size, recipe, fp8_param_gather=True, eval_transition=True, **kwargs
        )
        loss_ref, eval_loss_ref = self._run_test_helper(
            tp_size, recipe, fp8_param_gather=False, eval_transition=True, **kwargs
        )
        torch.testing.assert_close(loss, loss_ref, atol=1e-4, rtol=1e-4)
        torch.testing.assert_close(eval_loss, eval_loss_ref, atol=1e-4, rtol=1e-4)

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("tp_size", [2])
    @pytest.mark.parametrize("dp_overlap", [(True, True)])
@@ -442,6 +502,16 @@ def test_mxfp8(self, tp_size, dp_overlap):
        kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]}
        self.run_test(tp_size=tp_size, recipe="mxfp8", **kwargs)

    @pytest.mark.skipif(
        get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture"
    )
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.skipif(not is_te_min_version("2.3.0.dev0"), reason="TE 2.3.0.dev0 is required")
    @pytest.mark.parametrize("tp_size", [2])
    def test_mxfp8_eval_transition(self, tp_size):
        kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True}
        self.run_test_with_eval_transition(tp_size=tp_size, recipe="mxfp8", **kwargs)

    @pytest.mark.skipif(
        get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture"
    )