From 9c4a975b73654783ebcb57cdb595a8d5d4657bba Mon Sep 17 00:00:00 2001
From: Gao Deng
Date: Wed, 15 Apr 2026 11:55:04 -0700
Subject: [PATCH] Use sharded_state_dict_default in MLP.sharded_state_dict

Carry over the main-applicable part of PR #4325. The TEGroupedMLP fused
activation_func checks from the dev PR target
_can_use_fused_impl/_make_fused_ops code that is not present on main.
Original dev commit: 4794aabd99487907c4e8d727309fd0763580bb05.
---
 megatron/core/transformer/mlp.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py
index ce10f650fa7..ed288c3e1f4 100644
--- a/megatron/core/transformer/mlp.py
+++ b/megatron/core/transformer/mlp.py
@@ -25,7 +25,7 @@
 from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import cat_with_oom_fallback
+from megatron.core.transformer.utils import cat_with_oom_fallback, sharded_state_dict_default
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
     get_tensor_model_parallel_group_if_none,
@@ -348,7 +348,9 @@ def sharded_state_dict(
         sharded_state_dict = {}
         singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
         for name, module in self._modules.items():
-            sub_sd = module.sharded_state_dict(f"{prefix}{name}.", sharded_offsets, metadata)
+            sub_sd = sharded_state_dict_default(
+                module, f"{prefix}{name}.", sharded_offsets, metadata
+            )
             if self.config.gated_linear_unit and name == "linear_fc1":
                 for k, v in sub_sd.items():
                     if k in (f"{prefix}{name}.weight", f"{prefix}{name}.bias"):
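
Note (not part of the patch): for reviewers unfamiliar with the helper, below
is a minimal sketch of the behavior this change relies on. It assumes
sharded_state_dict_default dispatches to a submodule's own sharded_state_dict()
when one is defined and otherwise falls back to wrapping the plain
state_dict() as sharded tensors; the exact body on main may differ, and the
fallback call shown is an assumption beyond what the diff itself confirms.

# Sketch only: assumed behavior of sharded_state_dict_default. It lets
# MLP.sharded_state_dict iterate over all submodules uniformly, including
# ones that do not implement sharded_state_dict themselves.
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

def sharded_state_dict_default(module, prefix='', sharded_offsets=(), metadata=None):
    # Prefer the module's own sharded implementation when it provides one.
    if hasattr(module, 'sharded_state_dict'):
        return module.sharded_state_dict(
            prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
        )
    # Fallback: wrap the tensors from the regular state dict as sharded
    # tensors so they participate in distributed checkpointing.
    module_sd = module.state_dict(prefix='', keep_vars=True)
    return make_sharded_tensors_for_checkpoint(
        module_sd, prefix, sharded_offsets=sharded_offsets
    )

With this helper, the loop in MLP.sharded_state_dict no longer assumes every
entry of self._modules exposes a sharded_state_dict method, which is the point
of the one-line call-site change above.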