2 changes: 2 additions & 0 deletions examples/conversion/mfsdp/hf_fsdp_roundtrip.py
@@ -73,6 +73,8 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None
     model_provider.tensor_model_parallel_size = tp
     model_provider.context_parallel_size = cp
     model_provider.expert_model_parallel_size = ep
+    if cp > 1:
+        model_provider.calculate_per_token_loss = True
Contributor:
This should not be here; it can be added in MBridge's TransformerConfig finalize. But why is it needed here in conversion / generate?
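For context, a minimal sketch of what folding this into a finalize hook could look like. This is illustrative only: the attribute names mirror the example script above, and whether MBridge's TransformerConfig exposes exactly this hook is an assumption, not something this PR confirms.

# Hypothetical sketch: derive the CP-dependent flag inside finalize() instead
# of setting it at every conversion/generation call site. The finalize() hook
# location in MBridge is an assumption.
from dataclasses import dataclass

@dataclass
class ProviderConfig:
    context_parallel_size: int = 1
    calculate_per_token_loss: bool = False

    def finalize(self) -> None:
        # With context parallelism each rank sees only a slice of the sequence,
        # so loss must be accumulated per token and reduced across CP ranks.
        if self.context_parallel_size > 1:
            self.calculate_per_token_loss = True

cfg = ProviderConfig(context_parallel_size=2)
cfg.finalize()
assert cfg.calculate_per_token_loss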

     model_provider.finalize()
     model_provider.initialize_model_parallel(seed=0)

@@ -128,6 +128,8 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None
     model_provider.tensor_model_parallel_size = tp
     model_provider.context_parallel_size = cp
     model_provider.expert_model_parallel_size = ep
+    if cp > 1:
+        model_provider.calculate_per_token_loss = True
     model_provider.finalize()
     model_provider.initialize_model_parallel(seed=0)

52 changes: 33 additions & 19 deletions src/megatron/bridge/models/conversion/param_mapping.py
@@ -553,10 +553,29 @@ def _get_shard_spec(
         shard_rank = self.tp_rank // replicas
         return shard_world_size, shard_rank

+    def _broadcast_tp_split_shape(
+        self,
+        splits: Optional[List[torch.Tensor]],
+        src_rank: int = 0,
+    ) -> torch.Size:
+        """Broadcast the actual tensor shard shape from the TP scatter source."""
+        shape = tuple(splits[0].shape) if self.tp_rank == src_rank and splits else None
+        shape_list = [shape]
+        global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)
+
+        torch.distributed.broadcast_object_list(
+            shape_list,
+            src=global_src,
+            group=self.tp_group,
+        )
+        if shape_list[0] is None:
+            raise RuntimeError("Failed to broadcast tensor-parallel split shape")
+        return torch.Size(shape_list[0])
+
     def scatter_to_tp_ranks(
         self,
         splits: Optional[List[torch.Tensor]],
-        output_shape: torch.Size,
+        output_shape: Optional[torch.Size],
         dtype: torch.dtype,
         device: torch.device,
         src_rank: int = 0,
@@ -578,7 +597,10 @@ def scatter_to_tp_ranks(
         if self.tp_size == 1:
             return splits[0].to(device=device, dtype=dtype) if splits else None

-        output = torch.empty(output_shape, dtype=dtype, device=device)
+        if output_shape is None:
+            output_shape = self._broadcast_tp_split_shape(splits, src_rank)
+
+        output = torch.empty(torch.Size(output_shape), dtype=dtype, device=device)
         global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)

         scatter_list = None
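For reference, a minimal standalone sketch of the torch.distributed.broadcast_object_list pattern the new helper relies on. It uses a single-process "gloo" group purely for illustration, so source and destination are the same rank; in the real code the group is the TP group and non-source ranks receive the shape into a [None] slot.

# Minimal demo of broadcasting a shard shape as a Python object.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

split = torch.empty(8, 4)
shape_list = [tuple(split.shape)]       # source rank fills the slot; others would pass [None]
dist.broadcast_object_list(shape_list, src=0)
assert torch.Size(shape_list[0]) == torch.Size((8, 4))

dist.destroy_process_group()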
@@ -914,10 +936,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
Contributor:
Reverting this hunk avoids one broadcast per column-parallel param. The old formula is correct under the same DTensor-shape semantic the new FusedGatedExpertMapping branch assumes.

Suggested change:
-        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
+        if isinstance(target_param, DTensor):
+            output_shape = (target_param.shape[0] // self.tp_size, *target_param.shape[1:])
+        else:
+            output_shape = target_param.shape
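A quick worked example of the shape arithmetic behind this suggestion, assuming (as the comment does) that a DTensor reports its full global shape and a column-parallel weight is sharded on dim 0. Numbers are illustrative only.

# Each TP rank's local shard keeps all trailing dims and divides dim 0 by
# tp_size, so no broadcast of the shard shape is needed.
import torch

tp_size = 4
global_shape = torch.Size([4096, 1024])          # full column-parallel weight
local_shape = (global_shape[0] // tp_size, *global_shape[1:])
assert local_shape == (1024, 1024)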

         # Scatter to all ranks. Each rank gets its sharded shape from its module.
         return self.scatter_to_tp_ranks(
             splits,
@@ -1023,10 +1042,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor) and hf_weights.ndim != 1:
-            output_shape = [target_param.shape[0], target_param.shape[1] // self.tp_size, *target_param.shape[2:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
Contributor:
Same as the ColumnParallelMapping comment — the prior DTensor branch is correct and avoids a per-param broadcast.

Suggested change:
-        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
+        if isinstance(target_param, DTensor) and hf_weights.ndim != 1:
+            output_shape = (target_param.shape[0], target_param.shape[1] // self.tp_size, *target_param.shape[2:])
+        else:
+            output_shape = target_param.shape
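And the row-parallel counterpart, including the hf_weights.ndim != 1 guard from the suggestion: 2-D weights are sharded on dim 1, while 1-D tensors such as biases stay whole. Illustrative numbers only.

import torch

tp_size = 4
weight_global = torch.Size([1024, 4096])         # row-parallel weight, sharded on dim 1
local_weight = (weight_global[0], weight_global[1] // tp_size, *weight_global[2:])
assert local_weight == (1024, 1024)

bias_global = torch.Size([1024])                 # ndim == 1: replicated, not sharded
assert tuple(bias_global) == (1024,)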

         # Scatter to all ranks. Each rank gets its sharded shape from its module.
         return self.scatter_to_tp_ranks(
             splits,
@@ -2243,10 +2259,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
         # Scatter the concatenated shards to each rank
         return self.scatter_to_tp_ranks(
             splits,
@@ -2482,11 +2495,12 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -
         if target_shape[0] % 2 != 0:
             raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.")

-        gate_target_shape = (target_shape[0] // 2, target_shape[1])
-        # target_shape is the TP-sharded Megatron shape; compute the full (unsharded) shapes
-        # so that _align_expert_weight_to_shape can correctly match the raw HF weights.
-        # _gated_mapping.hf_to_megatron is responsible for TP scatter.
-        gate_full_shape = (gate_target_shape[0] * self.tp_size, target_shape[1])
+        if isinstance(target_param, DTensor):
+            gate_full_shape = (target_shape[0] // 2, target_shape[1])
+        else:
+            # target_shape is the TP-sharded Megatron shape; compute the full
+            # unsharded shape so raw HF weights can be validated before TP scatter.
+            gate_full_shape = (target_shape[0] // 2 * self.tp_size, target_shape[1])
         gate_up_full_shape = (gate_full_shape[0] * 2, target_shape[1])

         if expert_weight.ndim == 3 and expert_weight.shape[0] == 2:
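A worked example of the two branches above, with illustrative numbers: for a DTensor, target_shape is already the full global fused shape, while for a plain tensor it is the per-rank TP shard and must be scaled back up. Both paths should agree on the full gate shape.

# Fused gate+up projection with hidden size 1024 and full fused dim 8192
# (gate 4096 + up 4096), tp_size 4. Numbers are illustrative only.
tp_size = 4

dtensor_target = (8192, 1024)                    # DTensor: global fused shape
gate_full_a = (dtensor_target[0] // 2, dtensor_target[1])

sharded_target = (8192 // tp_size, 1024)         # plain tensor: local fused shard
gate_full_b = (sharded_target[0] // 2 * tp_size, sharded_target[1])

assert gate_full_a == gate_full_b == (4096, 1024)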