Skip to content

[feat] Add SP and PP support for qwen_moe true on policy#1088

Open
maocheng23 wants to merge 1 commit into
feat/true_on_policy_qwen_moefrom
feat/true_on_policy_qwen_moe_sppp
Open

[feat] Add SP and PP support for qwen_moe true on policy#1088
maocheng23 wants to merge 1 commit into
feat/true_on_policy_qwen_moefrom
feat/true_on_policy_qwen_moe_sppp

Conversation

@maocheng23
Copy link
Copy Markdown
Contributor

@maocheng23 maocheng23 commented May 7, 2026

Stacked on top of #1059.

Summary

Adds the Miles/orchestration side of the Qwen3-30B-A3B MoE SP+PP true-on-policy stack.

This PR is one of three coupled PRs that should be reviewed and landed together because they extend the qwen3_moe_true_on_policy_v1 contract to allow Megatron sequence parallel and pipeline parallel training while preserving SGLang rollout parity.

Companion PRs:

Main Changes

  • Add _sync_before_rank_subset_logging barrier in the actor train loop, skipped outside true-on-policy mode or when TP/PP <= 1.
  • Add topology-specific raw torch-dist checkpoint conversion/load paths for Qwen3 MoE TP/PP/EP/ETP runs.
  • Add PP global layer naming and Qwen3 MoE grouped-MLP weight1/weight2 mappings in miles_plugins/megatron_bridge so HF <-> Megatron conversion works under PP and EP.
  • Wire SP/PP options through true-on-policy config/model profile/schema, scripts, checkpoint conversion, and TPP fixtures.
  • Add launch-assembly coverage for the combined single-node SP+PP topology: Megatron TP=2/PP=2/EP=2 with SGLang EP=2.

Test Plan

  • Local syntax check for touched Python paths.
  • git diff --check
  • Direct launch-wiring check for TP=2/PP=2/EP=2/SP with topology-specific raw torch-dist checkpoint path.
  • pytest tests/fast/true_on_policy/test_run_qwen3_30b_a3b.py in the remote/container test environment. Local pyenv pytest collection hung before output.
  • PP update E2E smoke: TP=1/PP=2/EP=4, verify train/train_rollout_logprob_abs_diff=0.0, rollout weight versions advance without mixed-version samples, and final update completes.
  • Combined SP+PP update E2E smoke: TP=2/PP=2/EP=2, verify the same update and exact-logprob criteria.
  • Normal true-on-policy vs matched off-policy comparison: verify finite nonzero grad_norm in the off-policy-comparable band and no material timing regression in rollout, logprob, actor-train, weight-update, or step time.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables Megatron sequence parallel support for true-on-policy training and rollouts, specifically targeting Qwen3 models. Key updates include patching the Megatron Bridge to correctly handle global layer and expert indexing across Pipeline and Expert Parallelism, enhancing weight auditing with detailed layer summaries, and updating checkpoint conversion tools to support topology-specific paths. Feedback identifies a critical runtime error in the bridge plugin due to incorrect process group method calls, a logic error that skips expert globalization when pipeline parallelism is disabled, and an opportunity to improve synchronization logic by accounting for expert parallelism.

Comment on lines +191 to +195
num_experts_per_rank = num_experts // ep_group.size()

def _update_expert_number(param_name: str, param_type: str) -> str:
local_expert_number = int(param_name.split(f".{param_type}")[-1])
global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The torch.distributed.ProcessGroup object does not have .size() or .rank() methods. Accessing them will raise an AttributeError at runtime. Use bridge_model_bridge.get_pg_size(ep_group) and bridge_model_bridge.parallel_state.get_expert_model_parallel_rank() instead.

Suggested change
num_experts_per_rank = num_experts // ep_group.size()
def _update_expert_number(param_name: str, param_type: str) -> str:
local_expert_number = int(param_name.split(f".{param_type}")[-1])
global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number
num_experts_per_rank = num_experts // bridge_model_bridge.get_pg_size(ep_group)
def _update_expert_number(param_name: str, param_type: str) -> str:
local_expert_number = int(param_name.split(f".{param_type}")[-1])
global_expert_number = num_experts_per_rank * bridge_model_bridge.parallel_state.get_expert_model_parallel_rank() + local_expert_number

Comment on lines +341 to +344
if (
self.args.tensor_model_parallel_size <= 1
and self.args.pipeline_model_parallel_size <= 1
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition to skip the synchronization barrier should also account for Expert Parallelism (EP). If expert_model_parallel_size > 1, ranks are split across experts and synchronization is likely required for consistent logging in true-on-policy mode, similar to TP and PP.

Suggested change
if (
self.args.tensor_model_parallel_size <= 1
and self.args.pipeline_model_parallel_size <= 1
):
if (
self.args.tensor_model_parallel_size <= 1
and self.args.pipeline_model_parallel_size <= 1
and getattr(self.args, "expert_model_parallel_size", 1) <= 1
):
References
  1. Model parameters should be retrieved from the model configuration rather than being hardcoded.

Comment on lines +209 to +210
if "decoder.layers." not in param_name or bridge_model_bridge.get_pg_size(pp_group) <= 1:
return original(models, config, param_name, vp_stage)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic returns early and calls the original globalization function if PP size is 1. This effectively skips the _apply_ep_global_expert_number call, meaning expert indices will not be globalized in topologies where PP=1 and EP>1. The patch should ensure both PP and EP globalization are applied independently of each other.

Suggested change
if "decoder.layers." not in param_name or bridge_model_bridge.get_pg_size(pp_group) <= 1:
return original(models, config, param_name, vp_stage)
def patched_megatron_local_name_to_global(models, config, param_name: str, vp_stage=None) -> str:
param_name = original(models, config, param_name, vp_stage)
pp_group = bridge_model_bridge.parallel_state.get_pipeline_model_parallel_group()
if "decoder.layers." in param_name and bridge_model_bridge.get_pg_size(pp_group) > 1:
pp_rank = bridge_model_bridge.parallel_state.get_pipeline_model_parallel_rank()
layer_offset = get_transformer_layer_offset(config, vp_stage=vp_stage, pp_rank=pp_rank)
param_name = _globalize_decoder_layer_name(param_name, layer_offset)
return _apply_ep_global_expert_number(param_name, config)

@maocheng23 maocheng23 force-pushed the feat/true_on_policy_qwen_moe branch from f0e102d to 8a740d7 Compare May 19, 2026 18:20
@maocheng23 maocheng23 requested a review from jybsuper as a code owner May 19, 2026 18:20
@maocheng23 maocheng23 force-pushed the feat/true_on_policy_qwen_moe branch 2 times, most recently from fe23383 to e52a170 Compare May 23, 2026 02:47
@maocheng23 maocheng23 force-pushed the feat/true_on_policy_qwen_moe_sppp branch from f541aa2 to cc32be0 Compare May 23, 2026 22:22
@maocheng23 maocheng23 force-pushed the feat/true_on_policy_qwen_moe_sppp branch 3 times, most recently from a5ef00f to 627eb30 Compare May 23, 2026 23:41
Co-authored-by: zju-stu-lizheng <lizheng.cs@zju.edu.cn>
Co-authored-by: zyxiyy02 <282300612+zyxiyy02@users.noreply.github.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
@maocheng23 maocheng23 force-pushed the feat/true_on_policy_qwen_moe_sppp branch from 627eb30 to 8c9d0b4 Compare May 23, 2026 23:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant