diff --git a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py
index ff39a208ba..90234362f4 100644
--- a/examples/conversion/mfsdp/convert_checkpoints_fsdp.py
+++ b/examples/conversion/mfsdp/convert_checkpoints_fsdp.py
@@ -81,12 +81,18 @@ def _check_world_size(tp: int, cp: int, ep: int) -> None:
     except ValueError as err:
         raise ValueError("Invalid WORLD_SIZE environment variable.") from err
 
-    mp_size = tp * cp * ep
-    if mp_size <= 0:
+    non_expert_mp_size = tp * cp
+    expert_mp_size = ep
+    if non_expert_mp_size <= 0 or expert_mp_size <= 0:
         raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}")
-    if world_size % mp_size != 0:
+    if world_size % non_expert_mp_size != 0:
         raise ValueError(
-            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
+            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). "
+            f"Got tp={tp}, cp={cp}, ep={ep}."
+        )
+    if world_size % expert_mp_size != 0:
+        raise ValueError(
+            f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
         )
 
 
diff --git a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py
index 329a3d52aa..0c22538cc2 100644
--- a/examples/conversion/mfsdp/hf_fsdp_roundtrip.py
+++ b/examples/conversion/mfsdp/hf_fsdp_roundtrip.py
@@ -62,12 +62,18 @@ def _get_world_size() -> int:
 
 def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None:
     world_size = _get_world_size()
-    mp_size = tp * cp * ep
-    if mp_size <= 0:
+    non_expert_mp_size = tp * cp
+    expert_mp_size = ep
+    if non_expert_mp_size <= 0 or expert_mp_size <= 0:
         raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}")
-    if world_size % mp_size != 0:
+    if world_size % non_expert_mp_size != 0:
         raise ValueError(
-            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
+            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). "
+            f"Got tp={tp}, cp={cp}, ep={ep}."
+        )
+    if world_size % expert_mp_size != 0:
+        raise ValueError(
+            f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
         )
 
     model_provider.tensor_model_parallel_size = tp
diff --git a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py
index 51f6b113bd..696af46a2e 100644
--- a/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py
+++ b/examples/conversion/mfsdp/hf_to_megatron_fsdp_generate_text.py
@@ -117,12 +117,18 @@ def _get_world_size() -> int:
 
 def _configure_model_provider(model_provider, tp: int, cp: int, ep: int) -> None:
     world_size = _get_world_size()
-    mp_size = tp * cp * ep
-    if mp_size <= 0:
+    non_expert_mp_size = tp * cp
+    expert_mp_size = ep
+    if non_expert_mp_size <= 0 or expert_mp_size <= 0:
         raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}")
-    if world_size % mp_size != 0:
+    if world_size % non_expert_mp_size != 0:
         raise ValueError(
-            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp*ep ({mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
+            f"WORLD_SIZE ({world_size}) must be divisible by tp*cp ({non_expert_mp_size}). "
+            f"Got tp={tp}, cp={cp}, ep={ep}."
+        )
+    if world_size % expert_mp_size != 0:
+        raise ValueError(
+            f"WORLD_SIZE ({world_size}) must be divisible by ep ({expert_mp_size}). Got tp={tp}, cp={cp}, ep={ep}."
         )
 
     model_provider.tensor_model_parallel_size = tp
diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py
index 3494e5283e..3b5f46cdb2 100644
--- a/src/megatron/bridge/models/conversion/param_mapping.py
+++ b/src/megatron/bridge/models/conversion/param_mapping.py
@@ -915,7 +915,7 @@ def hf_to_megatron(
             splits = None
 
         if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
+            output_shape = target_param.orig_param.shape
         else:
             output_shape = target_param.shape
         # Scatter to all ranks. Each rank gets its sharded shape from its module.
@@ -1023,8 +1023,8 @@ def hf_to_megatron(
         else:
             splits = None
 
-        if isinstance(target_param, DTensor) and hf_weights.ndim != 1:
-            output_shape = [target_param.shape[0], target_param.shape[1] // self.tp_size, *target_param.shape[2:]]
+        if isinstance(target_param, DTensor):
+            output_shape = target_param.orig_param.shape
         else:
             output_shape = target_param.shape
         # Scatter to all ranks. Each rank gets its sharded shape from its module.
@@ -2244,7 +2244,7 @@ def hf_to_megatron(
             splits = None
 
         if isinstance(target_param, DTensor):
-            output_shape = [target_param.shape[0] // self.tp_size, *target_param.shape[1:]]
+            output_shape = target_param.orig_param.shape
         else:
             output_shape = target_param.shape
         # Scatter the concatenated shards to each rank
@@ -2482,11 +2482,12 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -
 
         if target_shape[0] % 2 != 0:
             raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.")
-        gate_target_shape = (target_shape[0] // 2, target_shape[1])
-        # target_shape is the TP-sharded Megatron shape; compute the full (unsharded) shapes
-        # so that _align_expert_weight_to_shape can correctly match the raw HF weights.
-        # _gated_mapping.hf_to_megatron is responsible for TP scatter.
-        gate_full_shape = (gate_target_shape[0] * self.tp_size, target_shape[1])
+        if isinstance(target_param, DTensor):
+            gate_full_shape = (target_shape[0] // 2, target_shape[1])
+        else:
+            # target_shape is the TP-sharded Megatron shape; compute the full
+            # unsharded shape so raw HF weights can be validated before TP scatter.
+            gate_full_shape = (target_shape[0] // 2 * self.tp_size, target_shape[1])
         gate_up_full_shape = (gate_full_shape[0] * 2, target_shape[1])
 
         if expert_weight.ndim == 3 and expert_weight.shape[0] == 2:
diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py
index 5ba879a618..4a3331c5af 100644
--- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py
+++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py
@@ -181,6 +181,11 @@ def __post_init__(self):
             self.vision_config = Qwen3_5VisionConfig()
         super().__post_init__()
 
+    def finalize(self) -> None:
+        if (self.context_parallel_size or 1) > 1:
+            self.calculate_per_token_loss = True
+        super().finalize()
+
     def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel:
         """Provide a Qwen3.5 VL dense model instance with vision and language components."""
         from megatron.bridge.models.gpt_provider import mtp_block_spec
@@ -348,6 +353,11 @@ def __post_init__(self):
             self.vision_config = Qwen3_5MoeVisionConfig()
         super().__post_init__()
 
+    def finalize(self) -> None:
+        if (self.context_parallel_size or 1) > 1:
+            self.calculate_per_token_loss = True
+        super().finalize()
+
     def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel:
         """Provide a Qwen3.5 VL model instance with vision and language components.
diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
index 441f88b14c..b8de59ab19 100644
--- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
+++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
@@ -134,6 +134,11 @@ class Qwen3VLModelProvider(GPTModelProvider):
     # If None, calculated from num_position_embeddings / spatial_merge_size^2
     max_vision_cuda_graph_seq_length: Optional[int] = None
 
+    def finalize(self) -> None:
+        if (self.context_parallel_size or 1) > 1:
+            self.calculate_per_token_loss = True
+        super().finalize()
+
     def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel:
         """Provide a Qwen3 VL model instance with vision and language components."""
         language_transformer_config = self
@@ -295,6 +300,8 @@ class Qwen3VLMoEModelProvider(GPTModelProvider):
     max_vision_cuda_graph_seq_length: Optional[int] = None
 
     def finalize(self) -> None:
+        if (self.context_parallel_size or 1) > 1:
+            self.calculate_per_token_loss = True
         if self.tensor_model_parallel_size > 1:
             self.sequence_parallel = True
         super().finalize()
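
Note on the world-size validation: the patch decouples the expert-parallel
grid from the dense tp*cp grid, so WORLD_SIZE must be divisible by tp*cp and
by ep independently rather than by their product tp*cp*ep. A minimal
standalone sketch of the new rule (hypothetical helper name, stdlib only):

    def check_world_size(world_size: int, tp: int, cp: int, ep: int) -> None:
        non_expert_mp_size = tp * cp  # tensor * context parallelism
        expert_mp_size = ep           # expert parallelism forms its own grid
        if non_expert_mp_size <= 0 or expert_mp_size <= 0:
            raise ValueError(f"Invalid parallel sizes: tp={tp}, cp={cp}, ep={ep}")
        # The two grids are validated independently: ep must divide WORLD_SIZE
        # itself, not the quotient WORLD_SIZE // (tp * cp).
        if world_size % non_expert_mp_size != 0:
            raise ValueError(f"WORLD_SIZE ({world_size}) not divisible by tp*cp ({non_expert_mp_size}).")
        if world_size % expert_mp_size != 0:
            raise ValueError(f"WORLD_SIZE ({world_size}) not divisible by ep ({expert_mp_size}).")

    # Example: world_size=8, tp=2, cp=1, ep=8 now passes (8 % 2 == 0, 8 % 8 == 0),
    # whereas the old tp*cp*ep check required divisibility by 16 and rejected it.
    check_world_size(8, tp=2, cp=1, ep=8)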
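
The param_mapping.py hunks replace the hand-derived TP-local shapes
([shape[0] // tp_size, ...] and [shape[0], shape[1] // tp_size, ...]) with
target_param.orig_param.shape for DTensor targets. A toy illustration of the
selection logic, assuming, as the patch does, that Megatron-FSDP DTensors
carry the original parameter on orig_param; FakeDTensor is a stand-in, not
the real class, and the exact shape semantics are Megatron-FSDP internals:

    from types import SimpleNamespace

    class FakeDTensor:
        """Illustrative stand-in: exposes `shape` plus an `orig_param` whose
        `shape` records the parameter's true expected local shape."""
        def __init__(self, shape, orig_shape):
            self.shape = shape
            self.orig_param = SimpleNamespace(shape=orig_shape)

    def output_shape_for(target_param, tp_size):
        if isinstance(target_param, FakeDTensor):
            # New behavior: trust the module's own record of the shape.
            # Dividing shape[0] by tp_size assumed dim 0 was TP-sharded, which
            # does not hold for every Megatron-FSDP sharding layout.
            return target_param.orig_param.shape
        return target_param.shape  # plain tensors keep their reported shape

    p = FakeDTensor(shape=(4096, 1024), orig_shape=(2048, 1024))
    assert output_shape_for(p, tp_size=2) == (2048, 1024)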
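
The gated-MLP hunk makes the full gate/up shape computation conditional on
the target being a DTensor. A hedged sketch of the arithmetic, assuming a
fused gate_up weight of shape (2 * ffn, hidden); function and variable names
are illustrative:

    def gate_shapes(target_shape, tp_size, is_dtensor):
        assert target_shape[0] % 2 == 0, "fused gate_up dim must be even"
        if is_dtensor:
            # Megatron-FSDP DTensor: target_shape is already unsharded.
            gate_full = (target_shape[0] // 2, target_shape[1])
        else:
            # Plain TP shard: scale back to the full shape so the raw HF
            # weights can be matched before the TP scatter happens.
            gate_full = (target_shape[0] // 2 * tp_size, target_shape[1])
        gate_up_full = (gate_full[0] * 2, target_shape[1])
        return gate_full, gate_up_full

    # tp=4, local fused shard (2048, 4096): the gate half of the full weight
    # has 2048 // 2 * 4 = 4096 rows, and the full fused weight has 8192 rows.
    assert gate_shapes((2048, 4096), 4, False) == ((4096, 4096), (8192, 4096))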
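
Finally, the Qwen3/Qwen3.5 VL providers gain a finalize() hook that forces
calculate_per_token_loss whenever context parallelism is enabled, presumably
because each CP rank sees only a slice of the sequence, so the loss must be
reduced per token rather than per rank. A self-contained sketch of the
override pattern (simplified base class standing in for GPTModelProvider):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ProviderBase:
        context_parallel_size: Optional[int] = 1
        calculate_per_token_loss: bool = False

        def finalize(self) -> None:
            pass  # real providers derive/validate further settings here

    @dataclass
    class VLProvider(ProviderBase):
        def finalize(self) -> None:
            # `or 1` guards against context_parallel_size being None.
            if (self.context_parallel_size or 1) > 1:
                self.calculate_per_token_loss = True
            super().finalize()

    p = VLProvider(context_parallel_size=2)
    p.finalize()
    assert p.calculate_per_token_loss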