diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py
index 46979a8ba8f..dfc708bb350 100644
--- a/megatron/core/transformer/mlp.py
+++ b/megatron/core/transformer/mlp.py
@@ -26,7 +26,7 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import cat_with_oom_fallback
+from megatron.core.transformer.utils import cat_with_oom_fallback, sharded_state_dict_default
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
     get_tensor_model_parallel_group_if_none,
@@ -349,7 +349,9 @@ def sharded_state_dict(
         sharded_state_dict = {}
         singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
         for name, module in self._modules.items():
-            sub_sd = module.sharded_state_dict(f"{prefix}{name}.", sharded_offsets, metadata)
+            sub_sd = sharded_state_dict_default(
+                module, f"{prefix}{name}.", sharded_offsets, metadata
+            )
             if self.config.gated_linear_unit and name == "linear_fc1":
                 for k, v in sub_sd.items():
                     if k in (f"{prefix}{name}.weight", f"{prefix}{name}.bias"):