diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index bcfa4c886e0..06676b07dec 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -259,6 +259,38 @@ def reset(self): self.per_param_grad_ready_counts = {} self.is_last_microbatch = True + def _post_param_sync(self): + """Run post-processing after param all-gather completes.""" + if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag: + for bucket in self.buckets: + is_bf16_weight_bucket = False + for param in bucket.params: + # Skip copying since bf16 weights in the mxfp8 model + # are already mapped to param.data. + if not is_float8tensor(param): + is_bf16_weight_bucket = True + break + param_start, param_end = bucket.param_to_index[param] + param_slice = bucket.param_data.view(-1)[param_start:param_end] + param.data.copy_(param_slice.view(param.data.shape)) + if is_bf16_weight_bucket: + continue + # All-gathered params are not needed after being copied to param.data. + # Zero out the param buffer (shared with grad buffer) for gradient accumulation. + # We cannot zero out the entire grad buffer because one grad buffer may + # correspond to multiple param buffers. If we zero out the entire grad buffer, + # it would clear the data of those param buffers that have not yet completed AG. + bucket.param_data.zero_() + return + + quantized_params = [] + for bucket in self.buckets: + for param in bucket.params: + if is_float8tensor(param) or is_nvfp4tensor(param): + quantized_params.append(param) + if len(quantized_params) > 0: + post_all_gather_processing(quantized_params) + def check_grads(self, check_for_nan_or_inf, check_for_large): """ Make sure norm of grads in bucket are not NaN prior to data-parallel @@ -317,6 +349,7 @@ def start_param_sync(self, force_sync: bool = False): if self.param_gather_handle is not None: self.param_gather_handle.wait() self.param_gather_handle = None + self._post_param_sync() return else: assert self.param_gather_handle is None @@ -335,6 +368,8 @@ def start_param_sync(self, force_sync: bool = False): dp_size = self.intra_distributed_optimizer_instance_size if dp_size == 1: # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed. + if force_sync and self.ddp_config.overlap_param_gather: + self._post_param_sync() self.param_gather_dispatched = True return local_rank = self.intra_distributed_optimizer_instance_rank @@ -428,6 +463,8 @@ def start_param_sync(self, force_sync: bool = False): # (async_op=False) is used, `cm` is not None. Manually set to None for # consistency with prior code. self.param_gather_handle = None + if force_sync and self.ddp_config.overlap_param_gather: + self._post_param_sync() self.param_gather_dispatched = True def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): @@ -467,30 +504,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): else: self.next_param_gather_bucket_group.start_param_sync() - # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True", - # we need to copy the param_data from the shared_param/grad_buffer to param.data - # after the param all-gather. - if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag: - for bucket in self.buckets: - is_bf16_weight_bucket = False - for param in bucket.params: - # Skip copying since bf16 weights in the mxfp8 model - # are already mapped to param.data. - if not is_float8tensor(param): - is_bf16_weight_bucket = True - break - param_start, param_end = bucket.param_to_index[param] - param_slice = bucket.param_data.view(-1)[param_start:param_end] - param.data.copy_(param_slice.view(param.data.shape)) - if is_bf16_weight_bucket: - continue - # All-gathered params are not needed after being copied to param.data. - # Zero out the param buffer (shared with grad buffer) for gradient accumulation. - # We cannot zero out the entire grad buffer because one grad buffer may - # correspond to multiple param buffers. If we zero out the entire grad buffer, - # it would clear the data of those param buffers that have not yet completed AG. - bucket.param_data.zero_() - elif not self.ddp_config.use_distributed_optimizer: + if not self.ddp_config.use_distributed_optimizer: for bucket in self.buckets: if bucket.layerwise_gather_list is None: continue @@ -508,14 +522,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): for updated_p, model_p in zip(updated_params, params): model_p.data.copy_(updated_p) bucket.layerwise_gather_list = None - else: - fp8_params = [] - for bucket in self.buckets: - for param in bucket.params: - if is_float8tensor(param): - fp8_params.append(param) - if len(fp8_params) > 0: - post_all_gather_processing(fp8_params) + self._post_param_sync() def start_grad_sync(self, force_all_reduce: Optional[bool] = False): """ diff --git a/tests/unit_tests/test_fp4_param.py b/tests/unit_tests/test_fp4_param.py index f01d6592d23..1276748ad45 100644 --- a/tests/unit_tests/test_fp4_param.py +++ b/tests/unit_tests/test_fp4_param.py @@ -162,8 +162,37 @@ def get_batch(self, seq_length, micro_batch_size): loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() return input_ids, labels, position_ids, attention_mask, loss_mask + def run_eval_transition(self, args, model_chunks, batch): + input_ids, labels, position_ids, attention_mask, loss_mask = batch + + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model_chunks, param_sync=True) + + model_chunks[0].eval() + model_chunks[0].set_is_first_microbatch() + with torch.no_grad(): + eval_output = model_chunks[0].forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + eval_loss = eval_output.mean() + model_chunks[0].train() + + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model_chunks) + + return eval_loss.item() + def _run_test_helper( - self, tp_size, inference: bool = False, fp4_param_gather: bool = True, **kwargs + self, + tp_size, + inference: bool = False, + fp4_param_gather: bool = True, + eval_transition: bool = False, + **kwargs, ): """Test fp4_param with gpt_model.""" args = self.create_test_args( @@ -206,6 +235,7 @@ def _run_test_helper( assert num_fp4_params == 4 * fp4_layers loss_list = [] + eval_loss_list = [] # CUDA graph setup (transformer_engine implementation) cuda_graph_helper = None @@ -267,6 +297,17 @@ def _run_test_helper( loss_list.append(loss.item()) + if eval_transition: + eval_loss_list.append( + self.run_eval_transition( + args, + gpt_model, + (input_ids, labels, position_ids, attention_mask, loss_mask), + ) + ) + + if eval_transition: + return torch.tensor(loss_list), torch.tensor(eval_loss_list) return torch.tensor(loss_list) def run_test(self, tp_size, inference: bool = False, **kwargs): @@ -282,6 +323,18 @@ def run_test(self, tp_size, inference: bool = False, **kwargs): torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2) + def run_test_with_eval_transition(self, tp_size, **kwargs): + """Test fp4_param eval transition with gpt_model.""" + loss_list, eval_loss_list = self._run_test_helper( + tp_size, fp4_param_gather=True, eval_transition=True, **kwargs + ) + loss_list_ref, eval_loss_list_ref = self._run_test_helper( + tp_size, fp4_param_gather=False, eval_transition=True, **kwargs + ) + + torch.testing.assert_close(loss_list, loss_list_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(eval_loss_list, eval_loss_list_ref, atol=1e-2, rtol=1e-2) + @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4) @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required") @pytest.mark.parametrize("tp_size", [2]) @@ -294,6 +347,13 @@ def test_nvfp4(self, tp_size, dp_overlap): kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]} self.run_test(tp_size=tp_size, inference=False, **kwargs) + @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4) + @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required") + @pytest.mark.parametrize("tp_size", [2]) + def test_nvfp4_eval_transition(self, tp_size): + kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True} + self.run_test_with_eval_transition(tp_size=tp_size, **kwargs) + @pytest.mark.skipif(not is_nvfp4_available, reason=reason_for_no_nvfp4) @pytest.mark.skipif(not is_te_min_version("2.7.0.dev0"), reason="TE 2.7.0.dev0 is required") @pytest.mark.parametrize("tp_size", [2]) diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index e0a71526297..9cc77a2c397 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -15,6 +15,7 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.utils import is_te_min_version from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args @@ -171,6 +172,43 @@ def get_batch(self, seq_length, micro_batch_size): loss_mask = torch.ones(seq_length).repeat((micro_batch_size, 1)).cuda() return input_ids, labels, position_ids, attention_mask, loss_mask + def copy_main_params_to_param_buffer(self, model_chunks, optimizer): + # Mirrors MBridge's pre-eval fix: disable_forward_pre_hook(param_sync=True) + # force-syncs params before eval callbacks run, so MXFP8 must repopulate + # the shared param/grad buffer before disabling forward hooks. + for model_chunk in model_chunks: + model_chunk.zero_grad_buffer() + for optim_instance in optimizer.chained_optimizers: + if isinstance(optim_instance, DistributedOptimizer): + optim_instance._copy_main_params_to_param_buffer() + + def run_eval_transition(self, args, model_chunks, optimizer, batch): + input_ids, labels, position_ids, attention_mask, loss_mask = batch + + if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather: + self.copy_main_params_to_param_buffer(model_chunks, optimizer) + + if should_disable_forward_pre_hook(args): + disable_forward_pre_hook(model_chunks, param_sync=True) + + model_chunks[0].eval() + model_chunks[0].set_is_first_microbatch() + with torch.no_grad(): + eval_output = model_chunks[0].forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + eval_loss = eval_output.mean() + model_chunks[0].train() + + if should_disable_forward_pre_hook(args): + enable_forward_pre_hook(model_chunks) + + return eval_loss.item() + def _run_test_helper( self, tp_size, @@ -178,6 +216,7 @@ def _run_test_helper( inference: bool = False, fp8_param_gather: bool = True, use_cuda_graph: bool = False, + eval_transition: bool = False, **kwargs, ): """Test fp8_param with gpt_model.""" @@ -269,6 +308,7 @@ def _run_test_helper( ) loss_list = [] + eval_loss_list = [] for i in range(100): if not inference: @@ -290,9 +330,7 @@ def _run_test_helper( # we need to call the _copy_main_params_to_param_buffer() after the grad buffer # is zeroed by zero_grad_buffer() because param and grad buffer are shared. if args.reuse_grad_buf_for_mxfp8_param_ag and args.overlap_param_gather: - for optim_instance in optimizer.chained_optimizers: - if hasattr(optim_instance, "_copy_main_params_to_param_buffer"): - optim_instance._copy_main_params_to_param_buffer() + self.copy_main_params_to_param_buffer(gpt_model, optimizer) gpt_model[0].set_is_first_microbatch() output = gpt_model[0].forward( @@ -325,10 +363,22 @@ def _run_test_helper( loss_list.append(loss.item()) + if eval_transition: + eval_loss_list.append( + self.run_eval_transition( + args, + gpt_model, + optimizer, + (input_ids, labels, position_ids, attention_mask, loss_mask), + ) + ) + if self.cuda_graph_helper is not None and self.cuda_graph_helper.graphs_created(): self.cuda_graph_helper.delete_cuda_graphs() self.cuda_graph_helper = None + if eval_transition: + return torch.tensor(loss_list), torch.tensor(eval_loss_list) return torch.tensor(loss_list) def run_test(self, tp_size, recipe, inference: bool = False, **kwargs): @@ -356,6 +406,16 @@ def run_test_with_cuda_graph(self, tp_size, recipe, **kwargs): ) torch.testing.assert_close(loss, loss_ref, atol=0, rtol=0) + def run_test_with_eval_transition(self, tp_size, recipe, **kwargs): + loss, eval_loss = self._run_test_helper( + tp_size, recipe, fp8_param_gather=True, eval_transition=True, **kwargs + ) + loss_ref, eval_loss_ref = self._run_test_helper( + tp_size, recipe, fp8_param_gather=False, eval_transition=True, **kwargs + ) + torch.testing.assert_close(loss, loss_ref, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(eval_loss, eval_loss_ref, atol=1e-4, rtol=1e-4) + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("dp_overlap", [(True, True)]) @@ -442,6 +502,16 @@ def test_mxfp8(self, tp_size, dp_overlap): kwargs = {"overlap_param_gather": dp_overlap[0], "overlap_grad_reduce": dp_overlap[1]} self.run_test(tp_size=tp_size, recipe="mxfp8", **kwargs) + @pytest.mark.skipif( + get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture" + ) + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.skipif(not is_te_min_version("2.3.0.dev0"), reason="TE 2.3.0.dev0 is required") + @pytest.mark.parametrize("tp_size", [2]) + def test_mxfp8_eval_transition(self, tp_size): + kwargs = {"overlap_param_gather": True, "overlap_grad_reduce": True} + self.run_test_with_eval_transition(tp_size=tp_size, recipe="mxfp8", **kwargs) + @pytest.mark.skipif( get_device_arch_version() < 10, reason="MXFP8 is supported since Blackwell architecture" )