From a38ef9b5fd43b1021f84626e3055cf15a7f3e1db Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Thu, 30 Apr 2026 12:13:25 -0700
Subject: [PATCH 1/6] update colwise data after param AG in eval

Signed-off-by: qiyuw
---
 .../core/distributed/param_and_grad_buffer.py | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index dc3014d72d2..e3103f6ec59 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -257,6 +257,16 @@ def reset(self):
         self.per_param_grad_ready_counts = {}
         self.is_last_microbatch = True
 
+    def _post_param_sync(self):
+        """Run post-processing for quantized params after param all-gather completes."""
+        quantized_params = []
+        for bucket in self.buckets:
+            for param in bucket.params:
+                if is_float8tensor(param) or is_nvfp4tensor(param):
+                    quantized_params.append(param)
+        if len(quantized_params) > 0:
+            post_all_gather_processing(quantized_params)
+
     def check_grads(self, check_for_nan_or_inf, check_for_large):
         """
         Make sure norm of grads in bucket are not NaN prior to data-parallel
@@ -315,6 +325,8 @@ def start_param_sync(self, force_sync: bool = False):
             if self.param_gather_handle is not None:
                 self.param_gather_handle.wait()
                 self.param_gather_handle = None
+                if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+                    self._post_param_sync()
             return
         else:
             assert self.param_gather_handle is None
@@ -333,6 +345,8 @@ def start_param_sync(self, force_sync: bool = False):
         dp_size = self.intra_distributed_optimizer_instance_size
         if dp_size == 1:
             # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed.
+            if force_sync and not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+                self._post_param_sync()
             self.param_gather_dispatched = True
             return
         local_rank = self.intra_distributed_optimizer_instance_rank
@@ -425,6 +439,8 @@ def start_param_sync(self, force_sync: bool = False):
             # (async_op=False) is used, `cm` is not None. Manually set to None for
             # consistency with prior code.
             self.param_gather_handle = None
+        if force_sync and not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+            self._post_param_sync()
         self.param_gather_dispatched = True
 
     def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
@@ -505,14 +521,9 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
                 for updated_p, model_p in zip(updated_params, params):
                     model_p.data.copy_(updated_p)
                 bucket.layerwise_gather_list = None
+            self._post_param_sync()
         else:
-            fp8_params = []
-            for bucket in self.buckets:
-                for param in bucket.params:
-                    if is_float8tensor(param):
-                        fp8_params.append(param)
-            if len(fp8_params) > 0:
-                post_all_gather_processing(fp8_params)
+            self._post_param_sync()
 
     def start_grad_sync(self, force_all_reduce: Optional[bool] = False):
         """

From 74be067b48f2bce3cc93c85c70f4eb9d73b97a9f Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 6 May 2026 08:42:14 -0700
Subject: [PATCH 2/6] fix similar issue for mxfp8 param gather

Signed-off-by: qiyuw
---
 .../core/distributed/param_and_grad_buffer.py | 60 +++++++++----------
 1 file changed, 28 insertions(+), 32 deletions(-)

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index e3103f6ec59..06846a5fd45 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -258,7 +258,29 @@ def reset(self):
         self.is_last_microbatch = True
 
     def _post_param_sync(self):
-        """Run post-processing for quantized params after param all-gather completes."""
+        """Run post-processing after param all-gather completes."""
+        if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+            for bucket in self.buckets:
+                is_bf16_weight_bucket = False
+                for param in bucket.params:
+                    # Skip copying since bf16 weights in the mxfp8 model
+                    # are already mapped to param.data.
+                    if not is_float8tensor(param):
+                        is_bf16_weight_bucket = True
+                        break
+                    param_start, param_end = bucket.param_to_index[param]
+                    param_slice = bucket.param_data.view(-1)[param_start:param_end]
+                    param.data.copy_(param_slice.view(param.data.shape))
+                if is_bf16_weight_bucket:
+                    continue
+                # All-gathered params are not needed after being copied to param.data.
+                # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
+                # We cannot zero out the entire grad buffer because one grad buffer may
+                # correspond to multiple param buffers. If we zero out the entire grad buffer,
+                # it would clear the data of those param buffers that have not yet completed AG.
+                bucket.param_data.zero_()
+            return
+
         quantized_params = []
         for bucket in self.buckets:
             for param in bucket.params:
@@ -325,8 +347,7 @@ def start_param_sync(self, force_sync: bool = False):
             if self.param_gather_handle is not None:
                 self.param_gather_handle.wait()
                 self.param_gather_handle = None
-                if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
-                    self._post_param_sync()
+                self._post_param_sync()
             return
         else:
             assert self.param_gather_handle is None
@@ -345,7 +366,7 @@ def start_param_sync(self, force_sync: bool = False):
         dp_size = self.intra_distributed_optimizer_instance_size
         if dp_size == 1:
             # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed.
-            if force_sync and not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+            if force_sync:
                 self._post_param_sync()
             self.param_gather_dispatched = True
             return
@@ -439,7 +460,7 @@ def start_param_sync(self, force_sync: bool = False):
             # (async_op=False) is used, `cm` is not None. Manually set to None for
             # consistency with prior code.
             self.param_gather_handle = None
-        if force_sync and not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+        if force_sync:
             self._post_param_sync()
         self.param_gather_dispatched = True
 
@@ -480,30 +501,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
             else:
                 self.next_param_gather_bucket_group.start_param_sync()
 
-        # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
-        # we need to copy the param_data from the shared_param/grad_buffer to param.data
-        # after the param all-gather.
-        if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
-            for bucket in self.buckets:
-                is_bf16_weight_bucket = False
-                for param in bucket.params:
-                    # Skip copying since bf16 weights in the mxfp8 model
-                    # are already mapped to param.data.
-                    if not is_float8tensor(param):
-                        is_bf16_weight_bucket = True
-                        break
-                    param_start, param_end = bucket.param_to_index[param]
-                    param_slice = bucket.param_data.view(-1)[param_start:param_end]
-                    param.data.copy_(param_slice.view(param.data.shape))
-                if is_bf16_weight_bucket:
-                    continue
-                # All-gathered params are not needed after being copied to param.data.
-                # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
-                # We cannot zero out the entire grad buffer because one grad buffer may
-                # correspond to multiple param buffers. If we zero out the entire grad buffer,
-                # it would clear the data of those param buffers that have not yet completed AG.
-                bucket.param_data.zero_()
-        elif not self.ddp_config.use_distributed_optimizer:
+        if not self.ddp_config.use_distributed_optimizer:
             for bucket in self.buckets:
                 if bucket.layerwise_gather_list is None:
                     continue
@@ -521,9 +519,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
                 for updated_p, model_p in zip(updated_params, params):
                     model_p.data.copy_(updated_p)
                 bucket.layerwise_gather_list = None
-            self._post_param_sync()
-        else:
-            self._post_param_sync()
+        self._post_param_sync()
 
     def start_grad_sync(self, force_all_reduce: Optional[bool] = False):
         """

From 3e8fd8a4bc7e46ce88481f17e86ce0cf1628f80a Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 6 May 2026 10:38:11 -0700
Subject: [PATCH 3/6] tighten the edge case

Signed-off-by: qiyuw
---
 megatron/core/distributed/param_and_grad_buffer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index 06846a5fd45..43abbdaa088 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -366,7 +366,7 @@ def start_param_sync(self, force_sync: bool = False):
         dp_size = self.intra_distributed_optimizer_instance_size
         if dp_size == 1:
             # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed.
-            if force_sync:
+            if force_sync and self.ddp_config.overlap_param_gather:
                 self._post_param_sync()
             self.param_gather_dispatched = True
             return
@@ -460,7 +460,7 @@ def start_param_sync(self, force_sync: bool = False):
             # (async_op=False) is used, `cm` is not None. Manually set to None for
             # consistency with prior code.
             self.param_gather_handle = None
-        if force_sync:
+        if force_sync and self.ddp_config.overlap_param_gather:
             self._post_param_sync()
         self.param_gather_dispatched = True

From c7849a5916e3ab8aec2cc47b9166c5521fd6daf6 Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Thu, 7 May 2026 13:29:13 -0700
Subject: [PATCH 4/6] add eval transition unit test in test_fp8_param.py

Signed-off-by: qiyuw
---
 tests/unit_tests/test_fp8_param.py | 84 ++++++++++++++++++++++++++++--
 1 file changed, 81 insertions(+), 3 deletions(-)

diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py
index 34b504e21de..f2bebeb9bf8 100644
--- a/tests/unit_tests/test_fp8_param.py
+++ b/tests/unit_tests/test_fp8_param.py
@@ -15,6 +15,7 @@
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
+from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.utils import is_te_min_version
 from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args
@@ -171,6 +172,43 @@ def get_batch(self, seq_length, micro_batch_size):
         loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda()
         return input_ids, labels, position_ids, attention_mask, loss_mask
 
+    def copy_main_params_to_param_buffer(self, model_chunks, optimizer):
+        # Mirrors MBridge's pre-eval fix: disable_forward_pre_hook(param_sync=True)
+        # force-syncs params before eval callbacks run, so MXFP8 must repopulate
+        # the shared param/grad buffer before disabling forward hooks.
+        for model_chunk in model_chunks:
+            model_chunk.zero_grad_buffer()
+        for optim_instance in optimizer.chained_optimizers:
+            if isinstance(optim_instance, DistributedOptimizer):
+                optim_instance._copy_main_params_to_param_buffer()
+
+    def run_eval_transition(self, args, model_chunks, optimizer, batch):
+        input_ids, labels, position_ids, attention_mask, loss_mask = batch
+
+        if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
+            self.copy_main_params_to_param_buffer(model_chunks, optimizer)
+
+        if should_disable_forward_pre_hook(args):
+            disable_forward_pre_hook(model_chunks, param_sync=True)
+
+        model_chunks[0].eval()
+        model_chunks[0].set_is_first_microbatch()
+        with torch.no_grad():
+            eval_output = model_chunks[0].forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                labels=labels,
+                loss_mask=loss_mask,
+            )
+            eval_loss = eval_output.mean()
+        model_chunks[0].train()
+
+        if should_disable_forward_pre_hook(args):
+            enable_forward_pre_hook(model_chunks)
+
+        return eval_loss.item()
+
     def _run_test_helper(
         self,
         tp_size,
@@ -178,6 +216,7 @@ def _run_test_helper(
         inference: bool = False,
         fp8_param_gather: bool = True,
         use_cuda_graph: bool = False,
+        eval_transition: bool = False,
         **kwargs,
     ):
         """Test fp8_param with gpt_model."""
@@ -269,6 +308,7 @@ def _run_test_helper(
         )
 
         loss_list = []
+        eval_loss_list = []
 
         for i in range(100):
             if not inference:
@@ -290,9 +330,7 @@ def _run_test_helper(
                 # we need to call the _copy_main_params_to_param_buffer() after the grad buffer
                 # is zeroed by zero_grad_buffer() because param and grad buffer are shared.
                 if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather:
-                    for optim_instance in optimizer.chained_optimizers:
-                        if hasattr(optim_instance, "_copy_main_params_to_param_buffer"):
-                            optim_instance._copy_main_params_to_param_buffer()
+                    self.copy_main_params_to_param_buffer(gpt_model, optimizer)
 
             gpt_model[0].set_is_first_microbatch()
             output = gpt_model[0].forward(
@@ -325,10 +363,22 @@ def _run_test_helper(
 
             loss_list.append(loss.item())
 
+            if eval_transition:
+                eval_loss_list.append(
+                    self.run_eval_transition(
+                        args,
+                        gpt_model,
+                        optimizer,
+                        (input_ids, labels, position_ids, attention_mask, loss_mask),
+                    )
+                )
+
         if self.cuda_graph_helper is not None and self.cuda_graph_helper.graphs_created():
             self.cuda_graph_helper.delete_cuda_graphs()
             self.cuda_graph_helper = None
 
+        if eval_transition:
+            return torch.tensor(loss_list), torch.tensor(eval_loss_list)
         return torch.tensor(loss_list)
 
     def run_test(self, tp_size, recipe, inference: bool = False, **kwargs):
@@ -356,6 +406,24 @@ def run_test_with_cuda_graph(self, tp_size, recipe, **kwargs):
         )
         torch.testing.assert_close(loss, loss_ref, atol=0, rtol=0)
 
+    def run_test_with_eval_transition(self, tp_size, recipe, **kwargs):
+        loss, eval_loss = self._run_test_helper(
+            tp_size,
+            recipe,
+            fp8_param_gather=True,
+            eval_transition=True,
+            **kwargs,
+        )
+        loss_ref, eval_loss_ref = self._run_test_helper(
+            tp_size,
+            recipe,
+            fp8_param_gather=False,
+            eval_transition=True,
+            **kwargs,
+        )
+        torch.testing.assert_close(loss, loss_ref, atol=1e-4, rtol=1e-4)
+        torch.testing.assert_close(eval_loss, eval_loss_ref, atol=1e-4, rtol=1e-4)
+
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     @pytest.mark.parametrize("tp_size", [2])
     @pytest.mark.parametrize("dp_overlap", [(True, True)])
@@ -442,6 +510,16 @@ def test_mxfp8(self, tp_size, dp_overlap):
         kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]}
         self.run_test(tp_size=tp_size, recipe="mxfp8", **kwargs)
 
+    @pytest.mark.skipif(
+        get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture"
+    )
+    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+    @pytest.mark.skipif(not is_te_min_version("2.3.0.dev0"), reason="TE 2.3.0.dev0 is required")
+    @pytest.mark.parametrize("tp_size", [2])
+    def test_mxfp8_eval_transition(self, tp_size):
+        kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True}
+        self.run_test_with_eval_transition(tp_size=tp_size, recipe="mxfp8", **kwargs)
+
     @pytest.mark.skipif(
         get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture"
     )

From 3b4a151716112f4a2cecfe94f71c480af26eef7f Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Thu, 7 May 2026 13:34:43 -0700
Subject: [PATCH 5/6] add eval transition unit test in test_fp4_param.py

Signed-off-by: qiyuw
---
 tests/unit_tests/test_fp4_param.py | 68 +++++++++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/test_fp4_param.py b/tests/unit_tests/test_fp4_param.py
index f01d6592d23..d4f5c58bb05 100644
--- a/tests/unit_tests/test_fp4_param.py
+++ b/tests/unit_tests/test_fp4_param.py
@@ -162,8 +162,37 @@ def get_batch(self, seq_length, micro_batch_size):
         loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda()
         return input_ids, labels, position_ids, attention_mask, loss_mask
 
+    def run_eval_transition(self, args, model_chunks, batch):
+        input_ids, labels, position_ids, attention_mask, loss_mask = batch
+
+        if should_disable_forward_pre_hook(args):
+            disable_forward_pre_hook(model_chunks, param_sync=True)
+
+        model_chunks[0].eval()
+        model_chunks[0].set_is_first_microbatch()
+        with torch.no_grad():
+            eval_output = model_chunks[0].forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                labels=labels,
+                loss_mask=loss_mask,
+            )
+            eval_loss = eval_output.mean()
+        model_chunks[0].train()
+
+        if should_disable_forward_pre_hook(args):
+            enable_forward_pre_hook(model_chunks)
+
+        return eval_loss.item()
+
     def _run_test_helper(
-        self, tp_size, inference: bool = False, fp4_param_gather: bool = True, **kwargs
+        self,
+        tp_size,
+        inference: bool = False,
+        fp4_param_gather: bool = True,
+        eval_transition: bool = False,
+        **kwargs,
     ):
         """Test fp4_param with gpt_model."""
         args = self.create_test_args(
@@ -206,6 +235,7 @@ def _run_test_helper(
         assert num_fp4_params == 4 * fp4_layers
 
         loss_list = []
+        eval_loss_list = []
 
         # CUDA graph setup (transformer_engine implementation)
         cuda_graph_helper = None
@@ -267,6 +297,17 @@ def _run_test_helper(
 
             loss_list.append(loss.item())
 
+            if eval_transition:
+                eval_loss_list.append(
+                    self.run_eval_transition(
+                        args,
+                        gpt_model,
+                        (input_ids, labels, position_ids, attention_mask, loss_mask),
+                    )
+                )
+
+        if eval_transition:
+            return torch.tensor(loss_list), torch.tensor(eval_loss_list)
         return torch.tensor(loss_list)
 
     def run_test(self, tp_size, inference: bool = False, **kwargs):
@@ -282,6 +323,24 @@ def run_test(self, tp_size, inference: bool = False, **kwargs):
 
         torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2)
 
+    def run_test_with_eval_transition(self, tp_size, **kwargs):
+        """Test fp4_param eval transition with gpt_model."""
+        loss_list, eval_loss_list = self._run_test_helper(
+            tp_size,
+            fp4_param_gather=True,
+            eval_transition=True,
+            **kwargs,
+        )
+        loss_list_ref, eval_loss_list_ref = self._run_test_helper(
+            tp_size,
+            fp4_param_gather=False,
+            eval_transition=True,
+            **kwargs,
+        )
+
+        torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(eval_loss_list, eval_loss_list_ref, atol=1e-2, rtol=1e-2)
+
     @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
     @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
     @pytest.mark.parametrize("tp_size", [2])
@@ -294,6 +353,13 @@ def test_nvfp4(self, tp_size, dp_overlap):
         kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]}
         self.run_test(tp_size=tp_size, inference=False, **kwargs)
 
+    @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
+    @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
+    @pytest.mark.parametrize("tp_size", [2])
+    def test_nvfp4_eval_transition(self, tp_size):
+        kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True}
+        self.run_test_with_eval_transition(tp_size=tp_size, **kwargs)
+
     @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4)
     @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required")
     @pytest.mark.parametrize("tp_size", [2])

From 3ec63c982d48ee39670d666f7ae4afbfb7069386 Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Fri, 8 May 2026 11:26:13 -0700
Subject: [PATCH 6/6] lint

Signed-off-by: qiyuw
---
 tests/unit_tests/test_fp4_param.py | 10 ++--------
 tests/unit_tests/test_fp8_param.py | 12 ++----------
 2 files changed, 4 insertions(+), 18 deletions(-)

diff --git a/tests/unit_tests/test_fp4_param.py b/tests/unit_tests/test_fp4_param.py
index d4f5c58bb05..1276748ad45 100644
--- a/tests/unit_tests/test_fp4_param.py
+++ b/tests/unit_tests/test_fp4_param.py
@@ -326,16 +326,10 @@ def run_test(self, tp_size, inference: bool = False, **kwargs):
     def run_test_with_eval_transition(self, tp_size, **kwargs):
         """Test fp4_param eval transition with gpt_model."""
         loss_list, eval_loss_list = self._run_test_helper(
-            tp_size,
-            fp4_param_gather=True,
-            eval_transition=True,
-            **kwargs,
+            tp_size, fp4_param_gather=True, eval_transition=True, **kwargs
         )
         loss_list_ref, eval_loss_list_ref = self._run_test_helper(
-            tp_size,
-            fp4_param_gather=False,
-            eval_transition=True,
-            **kwargs,
+            tp_size, fp4_param_gather=False, eval_transition=True, **kwargs
         )
 
         torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2)
diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py
index f2bebeb9bf8..b785a35396d 100644
--- a/tests/unit_tests/test_fp8_param.py
+++ b/tests/unit_tests/test_fp8_param.py
@@ -408,18 +408,10 @@ def run_test_with_cuda_graph(self, tp_size, recipe, **kwargs):
 
     def run_test_with_eval_transition(self, tp_size, recipe, **kwargs):
         loss, eval_loss = self._run_test_helper(
-            tp_size,
-            recipe,
-            fp8_param_gather=True,
-            eval_transition=True,
-            **kwargs,
+            tp_size, recipe, fp8_param_gather=True, eval_transition=True, **kwargs
         )
         loss_ref, eval_loss_ref = self._run_test_helper(
-            tp_size,
-            recipe,
-            fp8_param_gather=False,
-            eval_transition=True,
-            **kwargs,
+            tp_size, recipe, fp8_param_gather=False, eval_transition=True, **kwargs
         )
         torch.testing.assert_close(loss, loss_ref, atol=1e-4, rtol=1e-4)
         torch.testing.assert_close(eval_loss, eval_loss_ref, atol=1e-4, rtol=1e-4)