6 changes: 4 additions & 2 deletions megatron/core/transformer/mlp.py
@@ -26,7 +26,7 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import cat_with_oom_fallback
+from megatron.core.transformer.utils import cat_with_oom_fallback, sharded_state_dict_default
 from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
     get_tensor_model_parallel_group_if_none,
@@ -349,7 +349,9 @@ def sharded_state_dict(
         sharded_state_dict = {}
         singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
         for name, module in self._modules.items():
-            sub_sd = module.sharded_state_dict(f"{prefix}{name}.", sharded_offsets, metadata)
+            sub_sd = sharded_state_dict_default(
+                module, f"{prefix}{name}.", sharded_offsets, metadata
+            )
             if self.config.gated_linear_unit and name == "linear_fc1":
                 for k, v in sub_sd.items():
                     if k in (f"{prefix}{name}.weight", f"{prefix}{name}.bias"):
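
The change swaps the direct module.sharded_state_dict(...) call for the sharded_state_dict_default helper, so a child module that is a plain torch.nn.Module (and therefore has no sharded_state_dict method) no longer breaks MLP.sharded_state_dict. The sketch below is a hedged reconstruction of what such a helper typically does in Megatron-Core, not the implementation added by this PR; the fallback path through make_sharded_tensors_for_checkpoint is an assumption.

# Minimal sketch (assumed behaviour, not the library source) of a
# sharded_state_dict_default-style helper with the call signature used above.
from typing import Optional

import torch

from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint


def sharded_state_dict_default_sketch(
    module: torch.nn.Module,
    prefix: str = '',
    sharded_offsets: tuple = (),
    metadata: Optional[dict] = None,
) -> dict:
    """Approximate the helper: defer to the module's own method when present."""
    if hasattr(module, 'sharded_state_dict'):
        # MegatronModule subclasses define their own sharded mapping.
        return module.sharded_state_dict(
            prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
        )
    # Fallback for vanilla torch.nn.Module children: wrap every tensor in the
    # plain state_dict as a sharded tensor under the given prefix.
    module_sd = module.state_dict(prefix='', keep_vars=True)
    return make_sharded_tensors_for_checkpoint(
        module_sd, prefix, sharded_offsets=sharded_offsets
    )

With a helper like this at the call site, a submodule such as a vanilla torch.nn.LayerNorm placed in self._modules still yields checkpoint entries under the expected f"{prefix}{name}." keys instead of raising AttributeError.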