
[training] feat: port mcore PR #4687 MoE memory estimator fix #3744

Draft

cuichenx wants to merge 1 commit into main from chcui/port-mcore-4687-memory-estimator

Conversation

Contributor

@cuichenx cuichenx commented May 8, 2026

Summary

Ports the MoE-aware theoretical memory estimator from upstream NVIDIA/Megatron-LM#4687 (which fixes NVIDIA/Megatron-LM#4050) into Bridge's local theoretical_memory_utils.py.

Bridge maintains its own copy of this estimator (it takes a ConfigContainer, not Megatron's argparse args), so the fix doesn't propagate automatically when the submodule is bumped — it has to be downstreamed by hand.

Why

The previous Bridge formula divided all transformer-layer parameters — including routed experts — by tensor_model_parallel_size. Routed experts are actually sharded by expert_tensor_parallel_size * expert_model_parallel_size, with distributed-optimizer state sized by the expert data-parallel domain. The old formula silently under-reports MoE memory whenever ETP * EP != TP, which is the common case for Mixtral / DeepSeek / Qwen3-MoE / GPT-OSS recipes.

The new formula splits parameters into three buckets and applies the correct divisor to each:

Bucket           Examples                                      Divisor
TP-sharded       attention, dense MLP, shared experts          TP
Replicated       layernorms, router gate, shared-expert gate   1
Routed experts   MoE FFN per expert                            ETP × EP
Distributed-optimizer state for routed experts uses the expert DP domain (world_size / (ETP × EP × PP)).
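
For intuition, here is a minimal sketch of the per-bucket accounting in Python, assuming Megatron's usual byte convention (2 bytes of bf16 weight + 4 bytes of fp32 grad on every rank, plus 12 bytes of fp32 master weight and Adam moments that the distributed optimizer shards across the owning DP domain). The function and its parameter-count arguments are illustrative, not the ported Bridge API:

```python
def moe_memory_bytes_per_rank(
    tp_sharded_params: int,     # attention + dense MLP + shared experts
    replicated_params: int,     # layernorms, router gate, shared-expert gate
    routed_expert_params: int,  # routed-expert FFN weights, all experts
    world_size: int,
    tp: int,
    pp: int,
    etp: int,
    ep: int,
    use_distributed_optimizer: bool = True,
) -> float:
    dense_dp = world_size // (tp * pp)         # regular data-parallel size
    expert_dp = world_size // (etp * ep * pp)  # expert data-parallel size

    def bytes_per_param(dp: int) -> float:
        # Weight (bf16, 2 B) + grad (fp32, 4 B) stay on every rank; the
        # 12 B of fp32 master weight + Adam moments shard across the DP
        # domain that owns the parameter when the distributed optimizer
        # is enabled.
        return 6 + 12 / dp if use_distributed_optimizer else 18.0

    return (
        (tp_sharded_params / tp) * bytes_per_param(dense_dp)
        + replicated_params * bytes_per_param(dense_dp)
        + (routed_expert_params / (etp * ep)) * bytes_per_param(expert_dp)
    )
```

For example, with world_size=64, tp=2, pp=1, etp=1, ep=8, routed-expert weights are divided by 8 and their optimizer state is sharded over an expert DP of 8, whereas the old formula divided them by 2 and sharded over the dense DP of 32.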

Also picks up several upstream features the prior Bridge version was missing:

  • MoE layer patterns (int and list forms of moe_layer_freq; expansion sketched after this list)
  • Shared experts with optional gate (moe_shared_expert_gate)
  • Multi-latent attention (DeepSeek-style q_lora_rank / kv_lora_rank / split qk_head_dim + RoPE term)
  • Multi-Token Prediction (MTP) blocks
  • Active-vs-total parameter counts in verbose output
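
The first item is worth illustrating: before anything can be counted, moe_layer_freq must be expanded into a per-layer dense/MoE pattern. A sketch, assuming Megatron-LM's convention that the int form marks every moe_layer_freq-th layer (starting from layer 0) as MoE; the helper name is hypothetical:

```python
def expand_moe_layer_pattern(moe_layer_freq, num_layers: int) -> list[int]:
    """1 = MoE layer, 0 = dense layer, one entry per transformer layer."""
    if isinstance(moe_layer_freq, int):
        # e.g. expand_moe_layer_pattern(2, 6) -> [1, 0, 1, 0, 1, 0]
        return [1 if i % moe_layer_freq == 0 else 0 for i in range(num_layers)]
    if isinstance(moe_layer_freq, list):
        assert len(moe_layer_freq) == num_layers, "pattern must cover every layer"
        return list(moe_layer_freq)
    # Anything else is rejected, matching the TypeError unit test below.
    raise TypeError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
```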

The estimator is informational only — it does not affect training correctness or actual memory allocation. Users sizing jobs from the printed memory estimate will get accurate numbers for MoE recipes after this lands.

Why draft

Upstream PR #4687 is open, not merged. The implementation can still change in review. This PR holds the port until #4687 merges, then we'll rebase to match the final upstream form before un-drafting.

Test plan

  • Pre-commit (ruff + format + trailing-ws) — clean
  • Five new unit tests in tests/unit_tests/training/utils/test_theoretical_memory_utils.py (one invariant sketched after this list):
    • Dense model returns positive memory
    • EP > 1 reduces per-rank routed-expert weight (vs. EP = 1)
    • Distributed optimizer state shrinks as DP grows
    • List form of moe_layer_freq is honored
    • Invalid moe_layer_freq type raises TypeError
  • CI: cicd-unit-tests-core picks up the new module
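
As a taste of the new tests, here is the EP invariant written against the illustrative helper from the earlier sketch rather than the real compute_weight_and_optimizer_memory (which takes a full ConfigContainer); the parameter counts are arbitrary:

```python
def test_ep_shrinks_routed_expert_memory():
    # Identical cluster and parameter counts; only expert parallelism varies.
    common = dict(
        tp_sharded_params=1_000_000,
        replicated_params=10_000,
        routed_expert_params=8_000_000,
        world_size=64,
        tp=2,
        pp=1,
        etp=1,
    )
    ep1 = moe_memory_bytes_per_rank(ep=1, **common)
    ep8 = moe_memory_bytes_per_rank(ep=8, **common)
    # Sharding routed experts across EP=8 must shrink the per-rank estimate.
    assert ep8 < ep1
```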

Checklist

  • Sign-off (DCO)
  • Copyright year 2026 on the production file
  • Function signature unchanged (still compute_weight_and_optimizer_memory(config: ConfigContainer, verbose: bool = False) -> float)
  • No new dependencies
  • Bridge ConfigContainer already exposes every needed field (verified against mcore TransformerConfig + ModelParallelConfig)
  • Re-sync against final form of upstream #4687 before un-drafting

Refs

NVIDIA/Megatron-LM#4687, NVIDIA/Megatron-LM#4050

🤖 Generated with Claude Code

Replaces the simple TP-only formula in `compute_weight_and_optimizer_memory`
with the upstream MoE-aware version from NVIDIA/Megatron-LM PR #4687, mapped
to Bridge's `ConfigContainer` interface.

The previous formula divided routed-expert parameters by
`tensor_model_parallel_size` only, which under-counts memory whenever
`expert_tensor_parallel_size * expert_model_parallel_size != tensor_model_parallel_size`.
The new logic splits parameters into three buckets:

- TP-sharded (attention + dense MLP + shared experts) — divided by TP
- Replicated (layernorms + router + shared expert gate) — counted once per rank
- Routed experts — divided by ETP * EP

Distributed-optimizer state for routed experts is sized by the expert
data-parallel domain (world_size / (ETP * EP * PP)) instead of the regular DP.
Adds support for MoE layer patterns (int + list moe_layer_freq),
shared experts with optional gate, multi-latent attention (DeepSeek-style),
and Multi-Token Prediction (MTP) blocks — all of which the prior Bridge
formula collapsed into the dense-layer count.

This is a draft tracking the upstream PR; land after #4687 merges so the port
matches the final upstream form.

Refs: NVIDIA/Megatron-LM#4687, NVIDIA/Megatron-LM#4050
Signed-off-by: Chen Cui <chcui@nvidia.com>

copy-pr-bot Bot commented May 8, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
Contributor Author


don't modify this

@yaoyu-33 added the area:training (Training loop, callbacks, and runtime integration) label on May 11, 2026

Development

Successfully merging this pull request may close these issues.

[QUESTION] MoE layer theoretical memory calculation needs to account for ETP/EDP.
