From 64bcc85b2b56bcde85f749415d09b2a83d918680 Mon Sep 17 00:00:00 2001 From: conver334 Date: Thu, 7 May 2026 06:43:02 -0700 Subject: [PATCH 1/6] Fix Megatron-FSDP DTensor TP loading Signed-off-by: conver334 --- .../conversion/mfsdp/hf_fsdp_roundtrip.py | 2 + .../hf_to_megatron_fsdp_generate_text.py | 2 + .../bridge/models/conversion/param_mapping.py | 52 ++++++++++++------- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py index 329a3d52aa..88ce56e7aa 100644 --- a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py +++ b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py @@ -73,6 +73,8 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep + if cp > 1: + model_provider.calculate_per_token_loss = True model_provider.finalize() model_provider.initialize_model_parallel(seed=0) diff --git a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py index 51f6b113bd..4554aa9a80 100644 --- a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py +++ b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py @@ -128,6 +128,8 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep + if cp > 1: + model_provider.calculate_per_token_loss = True model_provider.finalize() model_provider.initialize_model_parallel(seed=0) diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 3494e5283e..6fce99bad5 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -553,10 +553,29 @@ def _get_shard_spec( shard_rank = self.tp_rank // replicas return shard_world_size, shard_rank + def _broadcast_tp_split_shape( + self, + splits: Optional[List[torch.Tensor]], + src_rank: int = 0, + ) -> torch.Size: + """Broadcast the actual tensor shard shape from the TP scatter source.""" + shape = tuple(splits[0].shape) if self.tp_rank == src_rank and splits else None + shape_list = [shape] + global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) + + torch.distributed.broadcast_object_list( + shape_list, + src=global_src, + group=self.tp_group, + ) + if shape_list[0] is None: + raise RuntimeError("Failed to broadcast tensor-parallel split shape") + return torch.Size(shape_list[0]) + def scatter_to_tp_ranks( self, splits: Optional[List[torch.Tensor]], - output_shape: torch.Size, + output_shape: Optional[torch.Size], dtype: torch.dtype, device: torch.device, src_rank: int = 0, @@ -578,7 +597,10 @@ def scatter_to_tp_ranks( if self.tp_size == 1: return splits[0].to(device=device, dtype=dtype) if splits else None - output = torch.empty(output_shape, dtype=dtype, device=device) + if output_shape is None: + output_shape = self._broadcast_tp_split_shape(splits, src_rank) + + output = torch.empty(torch.Size(output_shape), dtype=dtype, device=device) global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) scatter_list = None @@ -914,10 +936,7 @@ def hf_to_megatron( else: splits = None - if isinstance(target_param, DTensor): - 
output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]] - else: - output_shape = target_param.shape + output_shape = None if isinstance(target_param, DTensor) else target_param.shape # Scatter to all ranks. Each rank gets its sharded shape from its module. return self.scatter_to_tp_ranks( splits, @@ -1023,10 +1042,7 @@ def hf_to_megatron( else: splits = None - if isinstance(target_param, DTensor) and hf_weights.ndim != 1: - output_shape = [target_param.shape[0], target_param.shape[1] // self.tp_size, *target_param.shape[2:]] - else: - output_shape = target_param.shape + output_shape = None if isinstance(target_param, DTensor) else target_param.shape # Scatter to all ranks. Each rank gets its sharded shape from its module. return self.scatter_to_tp_ranks( splits, @@ -2243,10 +2259,7 @@ def hf_to_megatron( else: splits = None - if isinstance(target_param, DTensor): - output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]] - else: - output_shape = target_param.shape + output_shape = None if isinstance(target_param, DTensor) else target_param.shape # Scatter the concatenated shards to each rank return self.scatter_to_tp_ranks( splits, @@ -2482,11 +2495,12 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - if target_shape[0] % 2 != 0: raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.") - gate_target_shape = (target_shape[0] // 2, target_shape[1]) - # target_shape is the TP-sharded Megatron shape; compute the full (unsharded) shapes - # so that _align_expert_weight_to_shape can correctly match the raw HF weights. - # _gated_mapping.hf_to_megatron is responsible for TP scatter. - gate_full_shape = (gate_target_shape[0] * self.tp_size, target_shape[1]) + if isinstance(target_param, DTensor): + gate_full_shape = (target_shape[0] // 2, target_shape[1]) + else: + # target_shape is the TP-sharded Megatron shape; compute the full + # unsharded shape so raw HF weights can be validated before TP scatter. + gate_full_shape = (target_shape[0] // 2 * self.tp_size, target_shape[1]) gate_up_full_shape = (gate_full_shape[0] * 2, target_shape[1]) if expert_weight.ndim == 3 and expert_weight.shape[0] == 2: From 5d6aeea79d2c70078917a6d037411b1e355725a1 Mon Sep 17 00:00:00 2001 From: conver334 Date: Mon, 11 May 2026 23:29:58 -0700 Subject: [PATCH 2/6] Support ETP in M-FSDP conversion Add expert tensor parallel size plumbing to import and export paths. Validate WORLD_SIZE separately for non-expert tp*cp and expert etp*ep parallel groups, and avoid forcing calculate_per_token_loss for CP conversion. 
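
For reference, the new check reduces to two independent divisibility tests over the
non-expert (tp*cp) and expert (etp*ep) group sizes. A minimal standalone sketch of
the same logic (the helper name is illustrative, not the script's actual function):

    def check_world_size(world_size: int, tp: int, cp: int, ep: int, etp: int) -> None:
        # WORLD_SIZE must be divisible by each model-parallel group size independently,
        # since the non-expert and expert groups are validated separately.
        non_expert_mp_size = tp * cp
        expert_mp_size = etp * ep
        if non_expert_mp_size <= 0 or expert_mp_size <= 0:
            raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}, etp={etp}")
        if world_size % non_expert_mp_size != 0:
            raise ValueError(f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size})")
        if world_size % expert_mp_size != 0:
            raise ValueError(f"WORLD_SIZE ({world_size}) must be divisible by etp*ep ({expert_mp_size})")

    # e.g. 8 ranks: tp=2, cp=2 -> tp*cp=4 divides 8; etp=2, ep=4 -> etp*ep=8 divides 8.
    check_world_size(8, tp=2, cp=2, ep=4, etp=2)
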
Signed-off-by: conver334 --- .../mfsdp/convert_checkpoints_fsdp.py | 57 ++++++++++++++----- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py index ff39a208ba..f4f33a9192 100644 --- a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py +++ b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py @@ -75,27 +75,42 @@ def _check_distributed() -> None: sys.exit(1) -def _check_world_size(tp: int, cp: int, ep: int) -> None: +def _check_world_size(tp: int, cp: int, ep: int, etp: int) -> None: try: world_size = int(os.environ.get("WORLD_SIZE", "1")) except ValueError as err: raise ValueError("Invalid WORLD_SIZE environment variable.") from err - mp_size = tp * cp * ep - if mp_size <= 0: - raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}") - if world_size % mp_size != 0: + non_expert_mp_size = tp * cp + expert_mp_size = etp * ep + if non_expert_mp_size <= 0 or expert_mp_size <= 0: + raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}, etp={etp}") + if world_size % non_expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}, etp={etp}." + ) + if world_size % expert_mp_size != 0: + raise ValueError( + f"WORLD_SIZE ({world_size}) must be divisible by etp*ep ({expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}, etp={etp}." ) -def _build_fsdp_distributed_model(bridge: AutoBridge, tp: int, cp: int, ep: int, dtype: torch.dtype): +def _build_fsdp_distributed_model( + bridge: AutoBridge, + tp: int, + cp: int, + ep: int, + etp: int, + dtype: torch.dtype, +): """Build and return a Megatron-FSDP wrapped model list.""" model_provider = bridge.to_megatron_provider(load_weights=False) model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep + model_provider.expert_tensor_parallel_size = etp model_provider.pipeline_dtype = dtype model_provider.params_dtype = dtype model_provider.gradient_accumulation_fusion = False @@ -125,6 +140,7 @@ def import_hf_to_megatron_fsdp( tp: int = 1, cp: int = 1, ep: int = 1, + etp: int = 1, torch_dtype: str = "bfloat16", trust_remote_code: bool = False, low_memory_save: bool = True, @@ -132,11 +148,13 @@ def import_hf_to_megatron_fsdp( ) -> None: """Import a HuggingFace model and save it as a DTensor checkpoint.""" _check_distributed() - _check_world_size(tp=tp, cp=cp, ep=ep) + _check_world_size(tp=tp, cp=cp, ep=ep, etp=etp) dtype = _parse_dtype(torch_dtype) print_rank_0(f"Importing: {hf_model} -> {megatron_path}") - print_rank_0(f" TP={tp} CP={cp} EP={ep} dtype={torch_dtype} ckpt_format={ckpt_format}") + print_rank_0( + f" TP={tp} CP={cp} EP={ep} ETP={etp} dtype={torch_dtype} ckpt_format={ckpt_format}" + ) bridge = AutoBridge.from_hf_pretrained( hf_model, @@ -144,7 +162,14 @@ def import_hf_to_megatron_fsdp( torch_dtype=dtype, ) - _, _, megatron_model = _build_fsdp_distributed_model(bridge, tp=tp, cp=cp, ep=ep, dtype=dtype) + _, _, megatron_model = _build_fsdp_distributed_model( + bridge, + tp=tp, + cp=cp, + ep=ep, + etp=etp, + dtype=dtype, + ) bridge.load_hf_weights(megatron_model) @@ -180,6 +205,7 @@ def export_megatron_to_hf( tp: int = 1, cp: int = 1, ep: int = 1, + etp: int = 1, torch_dtype: str = "bfloat16", trust_remote_code: bool = False, 
ckpt_format: str = "fsdp_dtensor", @@ -190,11 +216,11 @@ def export_megatron_to_hf( ) -> None: """Export Megatron checkpoint to HuggingFace format.""" _check_distributed() - _check_world_size(tp=tp, cp=cp, ep=ep) + _check_world_size(tp=tp, cp=cp, ep=ep, etp=etp) dtype = _parse_dtype(torch_dtype) print_rank_0(f"Exporting: {megatron_path} -> {hf_path}") - print_rank_0(f" TP={tp} CP={cp} EP={ep} dtype={torch_dtype} ckpt_format={ckpt_format}") + print_rank_0(f" TP={tp} CP={cp} EP={ep} ETP={etp} dtype={torch_dtype} ckpt_format={ckpt_format}") print_rank_0(f" distributed_save={distributed_save} save_every_n_ranks={save_every_n_ranks}") bridge = AutoBridge.from_hf_pretrained( @@ -207,7 +233,7 @@ def export_megatron_to_hf( if ckpt_format == "fsdp_dtensor": # Build an FSDP-wrapped model and load with the training checkpoint loader. model_provider, ddp_config, megatron_model = _build_fsdp_distributed_model( - bridge, tp=tp, cp=cp, ep=ep, dtype=dtype + bridge, tp=tp, cp=cp, ep=ep, etp=etp, dtype=dtype ) state = GlobalState() @@ -241,6 +267,7 @@ def export_megatron_to_hf( "tensor_model_parallel_size": tp, "context_parallel_size": cp, "expert_model_parallel_size": ep, + "expert_tensor_parallel_size": etp, "pipeline_dtype": dtype, } megatron_model = bridge.load_megatron_model( @@ -267,6 +294,7 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size") parser.add_argument("--cp", type=int, default=1, help="Context parallelism size") parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size") + parser.add_argument("--etp", type=int, default=1, help="Expert tensor parallelism size") parser.add_argument( "--torch-dtype", choices=list(DTYPE_MAP), @@ -301,7 +329,6 @@ def main() -> None: action="store_true", help="Disable low-memory save mode (keeps model alive after save)", ) - export_parser = subparsers.add_parser("export", help="Export DTensor checkpoint to HuggingFace format") _add_common_args(export_parser) export_parser.add_argument("--megatron-path", required=True, help="Directory containing the DTensor checkpoint") @@ -334,6 +361,7 @@ def main() -> None: tp=args.tp, cp=args.cp, ep=args.ep, + etp=args.etp, torch_dtype=args.torch_dtype, trust_remote_code=args.trust_remote_code, low_memory_save=not args.no_low_memory_save, @@ -347,6 +375,7 @@ def main() -> None: tp=args.tp, cp=args.cp, ep=args.ep, + etp=args.etp, torch_dtype=args.torch_dtype, trust_remote_code=args.trust_remote_code, ckpt_format=args.ckpt_format, From 42f54d8f150c5c9b78d40261f991ec23b9b19944 Mon Sep 17 00:00:00 2001 From: conver334 Date: Tue, 12 May 2026 00:30:15 -0700 Subject: [PATCH 3/6] Move Qwen VL CP loss default to providers Signed-off-by: conver334 --- examples/conversion/mfsdp/hf_fsdp_roundtrip.py | 2 -- .../mfsdp/hf_to_megatron_fsdp_generate_text.py | 2 -- .../bridge/models/qwen_vl/qwen35_vl_provider.py | 10 ++++++++++ .../bridge/models/qwen_vl/qwen3_vl_provider.py | 7 +++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py index 88ce56e7aa..329a3d52aa 100644 --- a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py +++ b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py @@ -73,8 +73,6 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep - if cp > 1: - 
model_provider.calculate_per_token_loss = True model_provider.finalize() model_provider.initialize_model_parallel(seed=0) diff --git a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py index 4554aa9a80..51f6b113bd 100644 --- a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py +++ b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py @@ -128,8 +128,6 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep - if cp > 1: - model_provider.calculate_per_token_loss = True model_provider.finalize() model_provider.initialize_model_parallel(seed=0) diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py index 5ba879a618..4a3331c5af 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py @@ -181,6 +181,11 @@ def __post_init__(self): self.vision_config = Qwen3_5VisionConfig() super().__post_init__() + def finalize(self) -> None: + if (self.context_parallel_size or 1) > 1: + self.calculate_per_token_loss = True + super().finalize() + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: """Provide a Qwen3.5 VL dense model instance with vision and language components.""" from megatron.bridge.models.gpt_provider import mtp_block_spec @@ -348,6 +353,11 @@ def __post_init__(self): self.vision_config = Qwen3_5MoeVisionConfig() super().__post_init__() + def finalize(self) -> None: + if (self.context_parallel_size or 1) > 1: + self.calculate_per_token_loss = True + super().finalize() + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: """Provide a Qwen3.5 VL model instance with vision and language components. 
diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py index 441f88b14c..b8de59ab19 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py @@ -134,6 +134,11 @@ class Qwen3VLModelProvider(GPTModelProvider): # If None, calculated from num_position_embeddings / spatial_merge_size^2 max_vision_cuda_graph_seq_length: Optional[int] = None + def finalize(self) -> None: + if (self.context_parallel_size or 1) > 1: + self.calculate_per_token_loss = True + super().finalize() + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: """Provide a Qwen3 VL model instance with vision and language components.""" language_transformer_config = self @@ -295,6 +300,8 @@ class Qwen3VLMoEModelProvider(GPTModelProvider): max_vision_cuda_graph_seq_length: Optional[int] = None def finalize(self) -> None: + if (self.context_parallel_size or 1) > 1: + self.calculate_per_token_loss = True if self.tensor_model_parallel_size > 1: self.sequence_parallel = True super().finalize() From 3cba942f0a84f23e665984413b6b216ef3a57082 Mon Sep 17 00:00:00 2001 From: conver334 Date: Wed, 13 May 2026 02:27:47 -0700 Subject: [PATCH 4/6] Fix Megatron-FSDP DTensor TP scatter shapes Signed-off-by: conver334 --- .../bridge/models/conversion/param_mapping.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 6fce99bad5..3b5f46cdb2 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -553,29 +553,10 @@ def _get_shard_spec( shard_rank = self.tp_rank // replicas return shard_world_size, shard_rank - def _broadcast_tp_split_shape( - self, - splits: Optional[List[torch.Tensor]], - src_rank: int = 0, - ) -> torch.Size: - """Broadcast the actual tensor shard shape from the TP scatter source.""" - shape = tuple(splits[0].shape) if self.tp_rank == src_rank and splits else None - shape_list = [shape] - global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) - - torch.distributed.broadcast_object_list( - shape_list, - src=global_src, - group=self.tp_group, - ) - if shape_list[0] is None: - raise RuntimeError("Failed to broadcast tensor-parallel split shape") - return torch.Size(shape_list[0]) - def scatter_to_tp_ranks( self, splits: Optional[List[torch.Tensor]], - output_shape: Optional[torch.Size], + output_shape: torch.Size, dtype: torch.dtype, device: torch.device, src_rank: int = 0, @@ -597,10 +578,7 @@ def scatter_to_tp_ranks( if self.tp_size == 1: return splits[0].to(device=device, dtype=dtype) if splits else None - if output_shape is None: - output_shape = self._broadcast_tp_split_shape(splits, src_rank) - - output = torch.empty(torch.Size(output_shape), dtype=dtype, device=device) + output = torch.empty(output_shape, dtype=dtype, device=device) global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) scatter_list = None @@ -936,7 +914,10 @@ def hf_to_megatron( else: splits = None - output_shape = None if isinstance(target_param, DTensor) else target_param.shape + if isinstance(target_param, DTensor): + output_shape = target_param.orig_param.shape + else: + output_shape = target_param.shape # Scatter to all ranks. Each rank gets its sharded shape from its module. 
return self.scatter_to_tp_ranks( splits, @@ -1042,7 +1023,10 @@ def hf_to_megatron( else: splits = None - output_shape = None if isinstance(target_param, DTensor) else target_param.shape + if isinstance(target_param, DTensor): + output_shape = target_param.orig_param.shape + else: + output_shape = target_param.shape # Scatter to all ranks. Each rank gets its sharded shape from its module. return self.scatter_to_tp_ranks( splits, @@ -2259,7 +2243,10 @@ def hf_to_megatron( else: splits = None - output_shape = None if isinstance(target_param, DTensor) else target_param.shape + if isinstance(target_param, DTensor): + output_shape = target_param.orig_param.shape + else: + output_shape = target_param.shape # Scatter the concatenated shards to each rank return self.scatter_to_tp_ranks( splits, From 970fce768fa8b41bfd99e4be743efaf884a1a989 Mon Sep 17 00:00:00 2001 From: conver334 Date: Wed, 13 May 2026 02:43:14 -0700 Subject: [PATCH 5/6] Update M-FSDP world size validation Signed-off-by: conver334 --- .../mfsdp/convert_checkpoints_fsdp.py | 50 ++++++------------- .../conversion/mfsdp/hf_fsdp_roundtrip.py | 15 ++++-- .../hf_to_megatron_fsdp_generate_text.py | 15 ++++-- 3 files changed, 36 insertions(+), 44 deletions(-) diff --git a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py index f4f33a9192..33e8964f3d 100644 --- a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py +++ b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py @@ -75,42 +75,34 @@ def _check_distributed() -> None: sys.exit(1) -def _check_world_size(tp: int, cp: int, ep: int, etp: int) -> None: +def _check_world_size(tp: int, cp: int, ep: int) -> None: try: world_size = int(os.environ.get("WORLD_SIZE", "1")) except ValueError as err: raise ValueError("Invalid WORLD_SIZE environment variable.") from err non_expert_mp_size = tp * cp - expert_mp_size = etp * ep + expert_mp_size = ep if non_expert_mp_size <= 0 or expert_mp_size <= 0: - raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}, etp={etp}") + raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}") if world_size % non_expert_mp_size != 0: raise ValueError( f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). " - f"Got tp={tp}, cp={cp}, ep={ep}, etp={etp}." + f"Got tp={tp}, cp={cp}, ep={ep}." ) if world_size % expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by etp*ep ({expert_mp_size}). " - f"Got tp={tp}, cp={cp}, ep={ep}, etp={etp}." + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}." 
) -def _build_fsdp_distributed_model( - bridge: AutoBridge, - tp: int, - cp: int, - ep: int, - etp: int, - dtype: torch.dtype, -): +def _build_fsdp_distributed_model(bridge: AutoBridge, tp: int, cp: int, ep: int, dtype: torch.dtype): """Build and return a Megatron-FSDP wrapped model list.""" model_provider = bridge.to_megatron_provider(load_weights=False) model_provider.tensor_model_parallel_size = tp model_provider.context_parallel_size = cp model_provider.expert_model_parallel_size = ep - model_provider.expert_tensor_parallel_size = etp model_provider.pipeline_dtype = dtype model_provider.params_dtype = dtype model_provider.gradient_accumulation_fusion = False @@ -140,7 +132,6 @@ def import_hf_to_megatron_fsdp( tp: int = 1, cp: int = 1, ep: int = 1, - etp: int = 1, torch_dtype: str = "bfloat16", trust_remote_code: bool = False, low_memory_save: bool = True, @@ -148,13 +139,11 @@ def import_hf_to_megatron_fsdp( ) -> None: """Import a HuggingFace model and save it as a DTensor checkpoint.""" _check_distributed() - _check_world_size(tp=tp, cp=cp, ep=ep, etp=etp) + _check_world_size(tp=tp, cp=cp, ep=ep) dtype = _parse_dtype(torch_dtype) print_rank_0(f"Importing: {hf_model} -> {megatron_path}") - print_rank_0( - f" TP={tp} CP={cp} EP={ep} ETP={etp} dtype={torch_dtype} ckpt_format={ckpt_format}" - ) + print_rank_0(f" TP={tp} CP={cp} EP={ep} dtype={torch_dtype} ckpt_format={ckpt_format}") bridge = AutoBridge.from_hf_pretrained( hf_model, @@ -162,14 +151,7 @@ def import_hf_to_megatron_fsdp( torch_dtype=dtype, ) - _, _, megatron_model = _build_fsdp_distributed_model( - bridge, - tp=tp, - cp=cp, - ep=ep, - etp=etp, - dtype=dtype, - ) + _, _, megatron_model = _build_fsdp_distributed_model(bridge, tp=tp, cp=cp, ep=ep, dtype=dtype) bridge.load_hf_weights(megatron_model) @@ -205,7 +187,6 @@ def export_megatron_to_hf( tp: int = 1, cp: int = 1, ep: int = 1, - etp: int = 1, torch_dtype: str = "bfloat16", trust_remote_code: bool = False, ckpt_format: str = "fsdp_dtensor", @@ -216,11 +197,11 @@ def export_megatron_to_hf( ) -> None: """Export Megatron checkpoint to HuggingFace format.""" _check_distributed() - _check_world_size(tp=tp, cp=cp, ep=ep, etp=etp) + _check_world_size(tp=tp, cp=cp, ep=ep) dtype = _parse_dtype(torch_dtype) print_rank_0(f"Exporting: {megatron_path} -> {hf_path}") - print_rank_0(f" TP={tp} CP={cp} EP={ep} ETP={etp} dtype={torch_dtype} ckpt_format={ckpt_format}") + print_rank_0(f" TP={tp} CP={cp} EP={ep} dtype={torch_dtype} ckpt_format={ckpt_format}") print_rank_0(f" distributed_save={distributed_save} save_every_n_ranks={save_every_n_ranks}") bridge = AutoBridge.from_hf_pretrained( @@ -233,7 +214,7 @@ def export_megatron_to_hf( if ckpt_format == "fsdp_dtensor": # Build an FSDP-wrapped model and load with the training checkpoint loader. 
model_provider, ddp_config, megatron_model = _build_fsdp_distributed_model( - bridge, tp=tp, cp=cp, ep=ep, etp=etp, dtype=dtype + bridge, tp=tp, cp=cp, ep=ep, dtype=dtype ) state = GlobalState() @@ -267,7 +248,6 @@ def export_megatron_to_hf( "tensor_model_parallel_size": tp, "context_parallel_size": cp, "expert_model_parallel_size": ep, - "expert_tensor_parallel_size": etp, "pipeline_dtype": dtype, } megatron_model = bridge.load_megatron_model( @@ -294,7 +274,6 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size") parser.add_argument("--cp", type=int, default=1, help="Context parallelism size") parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size") - parser.add_argument("--etp", type=int, default=1, help="Expert tensor parallelism size") parser.add_argument( "--torch-dtype", choices=list(DTYPE_MAP), @@ -329,6 +308,7 @@ def main() -> None: action="store_true", help="Disable low-memory save mode (keeps model alive after save)", ) + export_parser = subparsers.add_parser("export", help="Export DTensor checkpoint to HuggingFace format") _add_common_args(export_parser) export_parser.add_argument("--megatron-path", required=True, help="Directory containing the DTensor checkpoint") @@ -361,7 +341,6 @@ def main() -> None: tp=args.tp, cp=args.cp, ep=args.ep, - etp=args.etp, torch_dtype=args.torch_dtype, trust_remote_code=args.trust_remote_code, low_memory_save=not args.no_low_memory_save, @@ -375,7 +354,6 @@ def main() -> None: tp=args.tp, cp=args.cp, ep=args.ep, - etp=args.etp, torch_dtype=args.torch_dtype, trust_remote_code=args.trust_remote_code, ckpt_format=args.ckpt_format, diff --git a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py index 329a3d52aa..6b71887012 100644 --- a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py +++ b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py @@ -62,12 +62,19 @@ def _get_world_size() -> int: def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None: world_size = _get_world_size() - mp_size = tp * cp * ep - if mp_size <= 0: + non_expert_mp_size = tp * cp + expert_mp_size = ep + if non_expert_mp_size <= 0 or expert_mp_size <= 0: raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}") - if world_size % mp_size != 0: + if world_size % non_expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}." + ) + if world_size % expert_mp_size != 0: + raise ValueError( + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}." 
) model_provider.tensor_model_parallel_size = tp diff --git a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py index 51f6b113bd..82401a2fe1 100644 --- a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py +++ b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py @@ -117,12 +117,19 @@ def _get_world_size() -> int: def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None: world_size = _get_world_size() - mp_size = tp * cp * ep - if mp_size <= 0: + non_expert_mp_size = tp * cp + expert_mp_size = ep + if non_expert_mp_size <= 0 or expert_mp_size <= 0: raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}") - if world_size % mp_size != 0: + if world_size % non_expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}." + ) + if world_size % expert_mp_size != 0: + raise ValueError( + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " + f"Got tp={tp}, cp={cp}, ep={ep}." ) model_provider.tensor_model_parallel_size = tp From 5da14a226ece49fe418f1363fa53c4678eb1e09c Mon Sep 17 00:00:00 2001 From: conver334 Date: Wed, 13 May 2026 08:26:37 -0700 Subject: [PATCH 6/6] Format M-FSDP world size validation Signed-off-by: conver334 --- examples/conversion/mfsdp/convert_checkpoints_fsdp.py | 3 +-- examples/conversion/mfsdp/hf_fsdp_roundtrip.py | 3 +-- examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py index 33e8964f3d..90234362f4 100644 --- a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py +++ b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py @@ -92,8 +92,7 @@ def _check_world_size(tp: int, cp: int, ep: int) -> None: ) if world_size % expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " - f"Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}." ) diff --git a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py index 6b71887012..0c22538cc2 100644 --- a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py +++ b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py @@ -73,8 +73,7 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None ) if world_size % expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " - f"Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}." 
) model_provider.tensor_model_parallel_size = tp diff --git a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py index 82401a2fe1..696af46a2e 100644 --- a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py +++ b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py @@ -128,8 +128,7 @@ def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None ) if world_size % expert_mp_size != 0: raise ValueError( - f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). " - f"Got tp={tp}, cp={cp}, ep={ep}." + f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}." ) model_provider.tensor_model_parallel_size = tp