[ckpt] fix: Use DTensor split shapes for Megatron-FSDP TP loading #3746
```diff
@@ -553,10 +553,29 @@ def _get_shard_spec(
         shard_rank = self.tp_rank // replicas
         return shard_world_size, shard_rank

+    def _broadcast_tp_split_shape(
+        self,
+        splits: Optional[List[torch.Tensor]],
+        src_rank: int = 0,
+    ) -> torch.Size:
+        """Broadcast the actual tensor shard shape from the TP scatter source."""
+        shape = tuple(splits[0].shape) if self.tp_rank == src_rank and splits else None
+        shape_list = [shape]
+        global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)
+
+        torch.distributed.broadcast_object_list(
+            shape_list,
+            src=global_src,
+            group=self.tp_group,
+        )
+        if shape_list[0] is None:
+            raise RuntimeError("Failed to broadcast tensor-parallel split shape")
+        return torch.Size(shape_list[0])
+
     def scatter_to_tp_ranks(
         self,
         splits: Optional[List[torch.Tensor]],
-        output_shape: torch.Size,
+        output_shape: Optional[torch.Size],
         dtype: torch.dtype,
         device: torch.device,
         src_rank: int = 0,
```
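For context on the new `_broadcast_tp_split_shape` helper, below is a minimal sketch of the same `broadcast_object_list` pattern: the scatter source publishes its shard's shape as a Python object and every rank in the group receives it. The `broadcast_shape` name, the single-process gloo setup, and the port are assumptions made so the sketch runs standalone; they are not part of the PR.

```python
# Minimal, self-contained sketch of the shape-broadcast pattern (not the PR's code).
# Assumption: a hypothetical broadcast_shape() helper standing in for
# _broadcast_tp_split_shape; the default process group stands in for a real TP group.
import torch
import torch.distributed as dist

def broadcast_shape(local_shape, src: int = 0, group=None) -> torch.Size:
    # Only the source rank fills in a real value; other ranks pass None.
    obj = [tuple(local_shape) if dist.get_rank(group) == src else None]
    dist.broadcast_object_list(obj, src=src, group=group)
    if obj[0] is None:
        raise RuntimeError("shape broadcast failed")
    return torch.Size(obj[0])

if __name__ == "__main__":
    # Single-process group with the gloo backend so the sketch runs standalone.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512",
                            rank=0, world_size=1)
    print(broadcast_shape((4, 8)))  # torch.Size([4, 8])
    dist.destroy_process_group()
```

Note that `broadcast_object_list` pickles its payload, which is fine for a tiny shape tuple but would be the wrong tool for the tensor data itself; the actual weights still go through the tensor scatter below.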
```diff
@@ -578,7 +597,10 @@ def scatter_to_tp_ranks(
         if self.tp_size == 1:
             return splits[0].to(device=device, dtype=dtype) if splits else None

-        output = torch.empty(output_shape, dtype=dtype, device=device)
+        if output_shape is None:
+            output_shape = self._broadcast_tp_split_shape(splits, src_rank)
+
+        output = torch.empty(torch.Size(output_shape), dtype=dtype, device=device)
         global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)

         scatter_list = None
```
```diff
@@ -914,10 +936,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
```
Contributor (inline comment on the `output_shape` line): Reverting this hunk avoids one broadcast per column-parallel param. The old formula is correct under the same DTensor-shape semantic the new code relies on.
```diff
         # Scatter to all ranks. Each rank gets its sharded shape from its module.
         return self.scatter_to_tp_ranks(
             splits,
```
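Returning `None` here defers the shape decision to `scatter_to_tp_ranks`, which broadcasts the source rank's actual split shape instead of deriving it from the DTensor's shape. One way a derived `dim // tp_size` shape can disagree with the real splits is uneven sharding, illustrated below; whether that is the failing case in this PR is not stated in the excerpt, and `torch.chunk` plus the sizes are stand-ins for however `splits` is actually produced.

```python
# Illustrative only; torch.chunk stands in for whatever produces `splits`,
# and the sizes are made up.
import torch

tp_size = 4
even = torch.zeros(16, 8)
uneven = torch.zeros(18, 8)

print([t.shape[0] for t in torch.chunk(even, tp_size, dim=0)])    # [4, 4, 4, 4]
print(even.shape[0] // tp_size)                                   # 4  -> matches
print([t.shape[0] for t in torch.chunk(uneven, tp_size, dim=0)])  # [5, 5, 5, 3]
print(uneven.shape[0] // tp_size)                                 # 4  -> disagrees
```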
```diff
@@ -1023,10 +1042,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor) and hf_weights.ndim != 1:
-            output_shape = [target_param.shape[0], target_param.shape[1] // self.tp_size, *target_param.shape[2:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
```
Contributor (inline comment on the `output_shape` line): Same as the comment above.
```diff
         # Scatter to all ranks. Each rank gets its sharded shape from its module.
         return self.scatter_to_tp_ranks(
             splits,
```
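For orientation, the removed formula in the hunk at line 1023 sharded along dim 1, versus dim 0 in the hunk above. A toy comparison of the two split directions, with made-up `tp_size` and weight shape:

```python
# Illustrative shapes only: dim-0 vs dim-1 splits before the TP scatter.
import torch

tp_size = 2
w = torch.arange(4 * 6, dtype=torch.float32).reshape(4, 6)

dim0_splits = torch.chunk(w, tp_size, dim=0)  # per-rank shape (2, 6)
dim1_splits = torch.chunk(w, tp_size, dim=1)  # per-rank shape (4, 3)
print([tuple(s.shape) for s in dim0_splits])
print([tuple(s.shape) for s in dim1_splits])
```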
```diff
@@ -2243,10 +2259,7 @@ def hf_to_megatron(
         else:
             splits = None

-        if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
-        else:
-            output_shape = target_param.shape
+        output_shape = None if isinstance(target_param, DTensor) else target_param.shape
         # Scatter the concatenated shards to each rank
         return self.scatter_to_tp_ranks(
             splits,
```
```diff
@@ -2482,11 +2495,12 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -
         if target_shape[0] % 2 != 0:
             raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.")

-        gate_target_shape = (target_shape[0] // 2, target_shape[1])
-        # target_shape is the TP-sharded Megatron shape; compute the full (unsharded) shapes
-        # so that _align_expert_weight_to_shape can correctly match the raw HF weights.
-        # _gated_mapping.hf_to_megatron is responsible for TP scatter.
-        gate_full_shape = (gate_target_shape[0] * self.tp_size, target_shape[1])
+        if isinstance(target_param, DTensor):
+            gate_full_shape = (target_shape[0] // 2, target_shape[1])
+        else:
+            # target_shape is the TP-sharded Megatron shape; compute the full
+            # unsharded shape so raw HF weights can be validated before TP scatter.
+            gate_full_shape = (target_shape[0] // 2 * self.tp_size, target_shape[1])
         gate_up_full_shape = (gate_full_shape[0] * 2, target_shape[1])

         if expert_weight.ndim == 3 and expert_weight.shape[0] == 2:
```
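A worked instance of the shape arithmetic in this hunk, with made-up sizes: only the non-DTensor branch multiplies by `tp_size`, because there `target_shape` is the per-rank shard rather than the global shape.

```python
# Made-up sizes; mirrors the two branches of the gate/up shape computation.
tp_size = 4

# Non-DTensor branch: target_shape[0] is the per-rank fused (gate+up) shard dim.
sharded_fused_dim = 256
gate_full = sharded_fused_dim // 2 * tp_size   # 512 gate rows in the full weight
gate_up_full = gate_full * 2                   # 1024 fused rows in the full weight

# DTensor branch: target_shape[0] is already the global fused dim.
global_fused_dim = 1024
gate_full_dtensor = global_fused_dim // 2      # 512, no tp_size factor
print(gate_full, gate_up_full, gate_full_dtensor)
```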
|
Review comment: This should not be here; it can be added in MBridge's TransformerConfigs finalize. But why is it needed here in conversion / generate?