From 5bd5fec7814c37bc351f7233f54fdaf35e2b6fcb Mon Sep 17 00:00:00 2001
From: Shivanjan Chakravorty
Date: Thu, 7 May 2026 06:11:20 +0800
Subject: [PATCH] fix(fsdp): recognize legacy GDN TP metadata

GDN conv1d parameters use Megatron's legacy tensor_model_parallel and
partition_dim attributes. The FSDP DTensor checkpoint splitter copies
those attributes to meta tensors before calling make_fsdp_dtensor, but
the FSDP TP detection path only handles _tensor_parallel_mode.

Honor the legacy metadata so that split GDN checkpoint tensors keep
their TP placement and validate against the full FSDP + TP mesh. Add
regression coverage for copied legacy TP attributes and for replicated
attributes.

Fixes #4553.

Signed-off-by: Shivanjan Chakravorty
---
 .../fsdp/src/megatron_fsdp/utils.py        |  4 ++
 .../test_fsdp_dtensor_checkpoint.py        | 38 +++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
index fe87a880390..9c9773a7123 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
@@ -822,6 +822,10 @@ def get_mcore_tensor_parallel_partition_dim(param: torch.Tensor) -> Optional[int
         return 0
     elif param._tensor_parallel_mode == "row":
         return 1
+    if getattr(param, "tensor_model_parallel", False):
+        partition_dim = getattr(param, "partition_dim", None)
+        if partition_dim is not None and partition_dim >= 0:
+            return int(partition_dim)
     return None
 
 
diff --git a/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py b/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
index 20da61c90ef..e591d7cd2d8 100644
--- a/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
+++ b/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
@@ -33,6 +33,15 @@
 import pytest
 import torch
 
+from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
+    get_mcore_tensor_parallel_partition_dim,
+    is_mcore_tensor_model_parallel,
+    is_mcore_tensor_parallel_duplicated,
+)
+from megatron.core.tensor_parallel.layers import (
+    copy_tensor_model_parallel_attributes,
+    set_tensor_model_parallel_attributes,
+)
 from megatron.core.transformer.fsdp_dtensor_checkpoint import (
     flatten_state_dict,
     get_expert_index_from_key,
@@ -515,6 +524,35 @@ def test_tp_affects_split_sizes(self):
         assert r1[0][2] == 128 and r2[0][2] == 64
 
 
+class TestGDNFSDPTensorParallelMetadata:
+    """Regression coverage for GDN conv1d FSDP checkpoint splitting.
+
+    GDN conv1d parameters are manually annotated with the legacy Megatron
+    tensor-parallel attrs. The GDN FSDP splitter copies those attrs to meta
+    tensors before calling make_fsdp_dtensor, so the FSDP utility must still
+    recognize them as tensor-parallel.
+    """
+
+    def test_legacy_tp_attrs_are_recognized_after_copy(self):
+        source = torch.nn.Parameter(torch.empty(8, 1, 4))
+        set_tensor_model_parallel_attributes(source, True, 0, 1)
+
+        meta = torch.empty(4, 1, 4, device="meta")
+        copy_tensor_model_parallel_attributes(meta, source)
+
+        assert get_mcore_tensor_parallel_partition_dim(meta) == 0
+        assert is_mcore_tensor_model_parallel(meta)
+        assert not is_mcore_tensor_parallel_duplicated(meta)
+
+    def test_replicated_legacy_attrs_are_not_tp_sharded(self):
+        source = torch.nn.Parameter(torch.empty(8))
+        set_tensor_model_parallel_attributes(source, False, -1, 1)
+
+        assert get_mcore_tensor_parallel_partition_dim(source) is None
+        assert not is_mcore_tensor_model_parallel(source)
+        assert is_mcore_tensor_parallel_duplicated(source)
+
+
 # ============================================================================
 # Test GDN state dict key generation after splitting
 # ============================================================================
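
Usage sketch (illustrative, not part of the commit): assuming a checkout with
this patch applied and megatron.core importable, the snippet below walks the
flow the commit message describes, using only helpers that already appear in
the diff; the tensor shapes are placeholders.

    import torch

    from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
        get_mcore_tensor_parallel_partition_dim,
    )
    from megatron.core.tensor_parallel.layers import (
        copy_tensor_model_parallel_attributes,
        set_tensor_model_parallel_attributes,
    )

    # Legacy Megatron annotation, as carried by GDN conv1d weights: mark the
    # parameter tensor-model-parallel and sharded along dim 0.
    conv_weight = torch.nn.Parameter(torch.empty(8, 1, 4))
    set_tensor_model_parallel_attributes(conv_weight, True, 0, 1)

    # The checkpoint splitter builds the shard on the meta device and copies
    # the TP attributes across before constructing the FSDP DTensor.
    shard = torch.empty(4, 1, 4, device="meta")
    copy_tensor_model_parallel_attributes(shard, conv_weight)

    # Previously the detection helper only consulted _tensor_parallel_mode and
    # returned None here; with the legacy metadata honored it reports dim 0,
    # so the split tensor keeps its TP placement.
    assert get_mcore_tensor_parallel_partition_dim(shard) == 0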