megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py (4 additions, 0 deletions)
@@ -822,6 +822,10 @@ def get_mcore_tensor_parallel_partition_dim(param: torch.Tensor) -> Optional[int]:
         return 0
     elif param._tensor_parallel_mode == "row":
         return 1
+    if getattr(param, "tensor_model_parallel", False):
+        partition_dim = getattr(param, "partition_dim", None)
+        if partition_dim is not None and partition_dim >= 0:
+            return int(partition_dim)
     return None


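For context, the four added lines give get_mcore_tensor_parallel_partition_dim a fallback to the legacy Megatron attribute convention (tensor_model_parallel plus partition_dim), which GDN conv1d parameters still carry. A minimal, self-contained sketch of that lookup follows; the attribute names come from the diff, while the helper name and example tensors are illustrative only:

import torch

# Illustrative re-statement of the fallback added above; not the actual
# megatron_fsdp.utils implementation.
def legacy_partition_dim(param):
    if getattr(param, "tensor_model_parallel", False):
        partition_dim = getattr(param, "partition_dim", None)
        if partition_dim is not None and partition_dim >= 0:
            return int(partition_dim)
    return None

# Column-parallel-style parameter: marked TP and sharded along dim 0.
sharded = torch.nn.Parameter(torch.empty(8, 4))
sharded.tensor_model_parallel = True
sharded.partition_dim = 0
assert legacy_partition_dim(sharded) == 0

# Replicated parameter: partition_dim == -1 means "not partitioned".
replicated = torch.nn.Parameter(torch.empty(8))
replicated.tensor_model_parallel = False
replicated.partition_dim = -1
assert legacy_partition_dim(replicated) is None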
tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py (38 additions, 0 deletions)
@@ -33,6 +33,15 @@
 import pytest
 import torch
 
+from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
+    get_mcore_tensor_parallel_partition_dim,
+    is_mcore_tensor_model_parallel,
+    is_mcore_tensor_parallel_duplicated,
+)
+from megatron.core.tensor_parallel.layers import (
+    copy_tensor_model_parallel_attributes,
+    set_tensor_model_parallel_attributes,
+)
 from megatron.core.transformer.fsdp_dtensor_checkpoint import (
     flatten_state_dict,
     get_expert_index_from_key,
@@ -515,6 +524,35 @@ def test_tp_affects_split_sizes(self):
         assert r1[0][2] == 128 and r2[0][2] == 64


+class TestGDNFSDPTensorParallelMetadata:
+    """Regression coverage for GDN conv1d FSDP checkpoint splitting.
+
+    GDN conv1d parameters are manually annotated with the legacy Megatron
+    tensor-parallel attrs. The GDN FSDP splitter copies those attrs to meta
+    tensors before calling make_fsdp_dtensor, so the FSDP utility must still
+    recognize them as tensor-parallel.
+    """
+
+    def test_legacy_tp_attrs_are_recognized_after_copy(self):
+        source = torch.nn.Parameter(torch.empty(8, 1, 4))
+        set_tensor_model_parallel_attributes(source, True, 0, 1)
+
+        meta = torch.empty(4, 1, 4, device="meta")
+        copy_tensor_model_parallel_attributes(meta, source)
+
+        assert get_mcore_tensor_parallel_partition_dim(meta) == 0
+        assert is_mcore_tensor_model_parallel(meta)
+        assert not is_mcore_tensor_parallel_duplicated(meta)
+
+    def test_replicated_legacy_attrs_are_not_tp_sharded(self):
+        source = torch.nn.Parameter(torch.empty(8))
+        set_tensor_model_parallel_attributes(source, False, -1, 1)
+
+        assert get_mcore_tensor_parallel_partition_dim(source) is None
+        assert not is_mcore_tensor_model_parallel(source)
+        assert is_mcore_tensor_parallel_duplicated(source)

 # ============================================================================
 # Test GDN state dict key generation after splitting
 # ============================================================================
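The new test class can be run in isolation with a standard pytest invocation (assuming the repository's usual test environment; no special runner is implied by the diff):

pytest tests/unit_tests/transformer/test_fsdp_dtensor_checkpoint.py::TestGDNFSDPTensorParallelMetadata -v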