diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
index fe87a880390..9c9773a7123 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
@@ -822,6 +822,10 @@ def get_mcore_tensor_parallel_partition_dim(param: torch.Tensor) -> Optional[int
         return 0
     elif param._tensor_parallel_mode == "row":
         return 1
+    if getattr(param, "tensor_model_parallel", False):
+        partition_dim = getattr(param, "partition_dim", None)
+        if partition_dim is not None and partition_dim >= 0:
+            return int(partition_dim)
     return None
 
 
diff --git a/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py b/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
index 20da61c90ef..e591d7cd2d8 100644
--- a/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
+++ b/tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py
@@ -33,6 +33,15 @@
 import pytest
 import torch
 
+from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
+    get_mcore_tensor_parallel_partition_dim,
+    is_mcore_tensor_model_parallel,
+    is_mcore_tensor_parallel_duplicated,
+)
+from megatron.core.tensor_parallel.layers import (
+    copy_tensor_model_parallel_attributes,
+    set_tensor_model_parallel_attributes,
+)
 from megatron.core.transformer.fsdp_dtensor_checkpoint import (
     flatten_state_dict,
     get_expert_index_from_key,
@@ -515,6 +524,35 @@ def test_tp_affects_split_sizes(self):
         assert r1[0][2] == 128 and r2[0][2] == 64
 
 
+class TestGDNFSDPTensorParallelMetadata:
+    """Regression coverage for GDN conv1d FSDP checkpoint splitting.
+
+    GDN conv1d parameters are manually annotated with the legacy Megatron
+    tensor-parallel attrs. The GDN FSDP splitter copies those attrs to meta
+    tensors before calling make_fsdp_dtensor, so the FSDP utility must still
+    recognize them as tensor-parallel.
+    """
+
+    def test_legacy_tp_attrs_are_recognized_after_copy(self):
+        source = torch.nn.Parameter(torch.empty(8, 1, 4))
+        set_tensor_model_parallel_attributes(source, True, 0, 1)
+
+        meta = torch.empty(4, 1, 4, device="meta")
+        copy_tensor_model_parallel_attributes(meta, source)
+
+        assert get_mcore_tensor_parallel_partition_dim(meta) == 0
+        assert is_mcore_tensor_model_parallel(meta)
+        assert not is_mcore_tensor_parallel_duplicated(meta)
+
+    def test_replicated_legacy_attrs_are_not_tp_sharded(self):
+        source = torch.nn.Parameter(torch.empty(8))
+        set_tensor_model_parallel_attributes(source, False, -1, 1)
+
+        assert get_mcore_tensor_parallel_partition_dim(source) is None
+        assert not is_mcore_tensor_model_parallel(source)
+        assert is_mcore_tensor_parallel_duplicated(source)
+
+
 # ============================================================================
 # Test GDN state dict key generation after splitting
 # ============================================================================