Move miles_megatron_plugins into the miles repo as a top-level package #1177

Draft
Shi-Dong wants to merge 2 commits into shi/ci-use-submodule-image from shi/move-miles-megatron-plugins

@Shi-Dong
Contributor

Summary

  • Add miles_megatron_plugins/ at the top level of the miles repo (__init__.py + true_on_policy/ with 13 modules; ~1400 lines).
  • Update setup.py find_packages(include=[...]) to expose the new package.
  • Update pyproject.toml [tool.isort] (known_first_party, src_paths) for consistency with miles / miles_plugins.
  • Bump the third_party/Megatron-LM submodule pointer to c3a94af51 (the matching deletion on radixark/Megatron-LM — see radixark/Megatron-LM#32).
  • Temporarily set .gitmodules branch = miles/move-plugins-to-miles for third_party/Megatron-LM. A follow-up commit flips it back to branch = miles-main once radixark/Megatron-LM#32 merges.

Why

miles_megatron_plugins is semantically miles' code that Megatron-LM imports back (true-on-policy contracts, sglang backend, matmul / attention / norm replacements). Today it lives inside the radixark/Megatron-LM fork and is only on sys.path because production sets PYTHONPATH=/root/Megatron-LM. That PYTHONPATH workaround doesn't survive packaging miles as a wheel for the eventual pip install miles goal.

Moving the package into miles makes it part of miles' regular Python packaging. The from miles_megatron_plugins... imports inside Megatron-LM source (transformer_layer, transformer_block, attention, distributed_data_parallel, tensor_parallel/{layers, matmul_tp_inv}, models/gpt/{gpt_model, gpt_layer_specs}, extensions/sglang) continue to work because miles is always installed before megatron.core is imported, both in production (Dockerfile installs miles last) and in CI (each pr-test job runs Install miles after the container starts).

This is Phase 1 of the broader plan toward pip install miles. Phase 2 (separate PR) will broaden Megatron-LM's packages.find.include so pip install -e Megatron-LM exposes megatron.training / megatron.legacy / megatron.rl, removing the remaining piece of the CPU runner's PYTHONPATH hack.

What does NOT change

  • _run-ci.yml's CPU runner PYTHONPATH still includes ${{ github.workspace }}/third_party/Megatron-LM because that path is still needed for megatron.training / megatron.legacy / megatron.rl (Phase 2's removal target). After this PR, that PYTHONPATH is no longer required for miles_megatron_plugins itself.
  • No Dockerfile changes. The submodule pointer bump auto-fires a rebuild via the path-trigger added in #1169 (Source sglang and Megatron-LM from third_party submodules in docker build).

Test plan

  • Build a test image via gh workflow run docker-build.yml --ref shi/move-miles-megatron-plugins -f variant=primary -f image_tag=custom -f custom_tag=dev-phase1-test, then on a 1-node H200 devbox:
    • cat /root/miles/miles_megatron_plugins/__init__.py succeeds (proves the file exists on the miles side).
    • ls /root/Megatron-LM/miles_megatron_plugins returns "No such file or directory" (proves the Megatron-LM side was deleted).
    • python -c "import miles_megatron_plugins.true_on_policy.contracts; print(miles_megatron_plugins.__file__)" resolves to a path under /root/miles/, not /root/Megatron-LM/. (Key sys.path resolution check.)
    • python -c "import megatron.core.transformer.transformer_layer" still succeeds (proves Megatron-LM's from miles_megatron_plugins... imports resolve via miles' install).
  • pr-test.yml workflow_dispatch run on shi/move-miles-megatron-plugins against the new test image (commit b310ddb37 from #1170, Wire CI to install sglang/Megatron-LM from third_party submodules, keeps the # TEMP container_image pin; we'll override it to dev-phase1-test if needed). Expect the same matrix as run 26211353043: stage-a-fast, stage-b-fast-gpu, stage-b-short, stage-b-sglang green; stage-a-unit-test red due to a pre-existing issue.
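
The sys.path resolution check in the test plan can also be scripted. The following is a minimal sketch, not part of this PR: resolves_under is a hypothetical helper name, and it relies only on the standard-library importlib.util.find_spec, which locates a top-level package without executing its __init__.py. The paths in the usage block are the ones quoted in the test plan.

```python
import importlib.util
from pathlib import Path


def resolves_under(module_name, expected_root):
    """Return True if `module_name` would be imported from below `expected_root`.

    Hypothetical helper for illustration. Namespace packages (no __init__.py)
    have no origin and are reported as not resolving.
    """
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return False
    origin = Path(spec.origin).resolve()
    return Path(expected_root).resolve() in origin.parents


if __name__ == "__main__":
    # On the devbox the test plan expects True for the first line, False for the second.
    print(resolves_under("miles_megatron_plugins", "/root/miles"))
    print(resolves_under("miles_megatron_plugins", "/root/Megatron-LM"))
```

Running this with and without PYTHONPATH=/root/Megatron-LM set would show whether the import still depends on the old workaround.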

Stacked on

#1170 (step 2b: CI uses submodule paths). Once #1168 / #1169 / #1170 merge in order, this PR's base auto-updates.

Coordinated change

  • radixark/Megatron-LM: PR #32 removes the same directory from miles-main. Both should merge together. After both merge:
    1. Flip .gitmodules branch = miles/move-plugins-to-miles back to branch = miles-main (the pointer's SHA is unchanged once radixark/Megatron-LM#32 lands on miles-main).
    2. Optionally delete the miles/move-plugins-to-miles branch on radixark/Megatron-LM.

Adds miles_megatron_plugins/ at the top level of the miles repo. This is
the source of truth from now on; the matching change on
radixark/Megatron-LM (PR #32) removes the same directory from there.

The third_party/Megatron-LM submodule pointer is bumped to
c3a94af51 (the Megatron-LM commit with miles_megatron_plugins/ removed),
and .gitmodules is temporarily flipped from branch=miles-main to
branch=miles/move-plugins-to-miles. The .gitmodules note is a TEMP
marker: once radixark/Megatron-LM PR #32 merges into miles-main, a
follow-up commit can flip the branch line back to miles-main.

Why this is safe:
- miles_megatron_plugins has zero imports from miles or sglang; it only
  imports megatron.core and torch. So living in either repo is fine
  from an import-graph standpoint.
- Megatron-LM source files do `from miles_megatron_plugins...` imports
  in several modules (transformer_layer, transformer_block, attention,
  gpt_model, gpt_layer_specs, tensor_parallel/{layers, matmul_tp_inv},
  distributed_data_parallel, extensions/sglang). Those imports resolve
  via miles' editable install once miles is installed - which is true in
  both production (Dockerfile installs miles last) and in CI (the
  Install miles step is what each pr-test job runs).
- After this change, `import miles_megatron_plugins.*` no longer needs
  PYTHONPATH=/root/Megatron-LM to resolve. The CPU runner's PYTHONPATH
  still includes third_party/Megatron-LM (still needed for
  megatron.training / legacy / rl) - that's Phase 2's removal target,
  not Phase 1's.
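
The first claim above (no imports from miles or sglang) is the kind of property a short static scan can confirm. A minimal sketch, assuming the package is a plain directory of .py files; top_level_imports is a hypothetical helper, not something this PR adds. It only sees static import statements, so imports done through importlib at runtime would be missed.

```python
import ast
from pathlib import Path


def top_level_imports(package_dir):
    """Collect the top-level names of all absolute imports under `package_dir`."""
    found = set()
    for path in Path(package_dir).rglob("*.py"):
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                found.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                # level == 0 skips relative imports such as `from . import x`.
                found.add(node.module.split(".")[0])
    return found


if __name__ == "__main__":
    forbidden = {"miles", "sglang"}
    offenders = top_level_imports("miles_megatron_plugins") & forbidden
    print("forbidden imports:", sorted(offenders))
```

An empty offenders set would back the statement that the package can live in either repo from an import-graph standpoint.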

setup.py's find_packages now includes miles_megatron_plugins*; pyproject's
[tool.isort] known_first_party and src_paths add miles_megatron_plugins
for consistency with the existing miles / miles_plugins entries.
Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces the miles_megatron_plugins.true_on_policy package, which provides an SGLang-compatible backend for Megatron-LM to ensure numerical parity. Key additions include SGLangFlashAttention with Ulysses Context Parallelism support, TP-invariant matmul operations using fixed-order tree summation, and custom normalization and RoPE implementations. The review feedback focuses on performance optimizations in the hot paths, such as moving signature inspections, function availability checks, and imports out of the forward methods. Additionally, a high-severity issue was identified where unconditional bfloat16 casting in the attention layer could cause numerical errors or runtime dtype mismatches, and improvements were suggested to make configuration parameter retrieval in linear layers more robust.

Comment on lines +120 to +122
query = query.to(torch.bfloat16)
key = key.to(torch.bfloat16)
value = value.to(torch.bfloat16)

high

The input tensors are unconditionally cast to torch.bfloat16. This can lead to unexpected numerical behavior if the model is configured for float16 training. Furthermore, the output is not cast back to the original input dtype, which will likely cause a RuntimeError due to dtype mismatch in subsequent layers (e.g., residual addition). Consider preserving the original dtype and casting the output back, or using query.dtype if it is a supported floating-point type.

Comment on lines +107 to +108
if not HAVE_FA3_VARLEN or fa3_varlen_func is None:
    raise ImportError("Flash Attention 3 varlen is required for SGLangFlashAttention")

medium

Checking for the availability of fa3_varlen_func inside the forward method adds unnecessary overhead to every iteration. It is better to perform this check once during initialization in __init__ to fail early and keep the hot path clean.
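
A minimal sketch of the fail-early pattern described here, written against a stand-in class so it runs without Flash Attention installed. FailFastAttention and its kernel argument are hypothetical names for illustration; the real SGLangFlashAttention constructor is not shown in this review and may look different.

```python
class FailFastAttention:
    """Stand-in for an attention module; not the real SGLangFlashAttention."""

    def __init__(self, kernel):
        # Validate the optional dependency once, at construction time,
        # so a missing kernel surfaces before any training step runs.
        if kernel is None:
            raise ImportError(
                "Flash Attention 3 varlen is required for SGLangFlashAttention"
            )
        self._kernel = kernel

    def forward(self, *args, **kwargs):
        # Hot path: no availability check here.
        return self._kernel(*args, **kwargs)
```

With this shape, a misconfigured environment fails when the model is built, and forward does no per-call validation.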

key = self.cp_layout.sequence_to_head_parallel(key, cu_seqlens_k)
value = self.cp_layout.sequence_to_head_parallel(value, cu_seqlens_k)

sig = inspect.signature(fa3_varlen_func)

medium

Calling inspect.signature inside the forward method is inefficient as it is a relatively expensive operation performed on every iteration. Since the signature of fa3_varlen_func is static, you should inspect it once at the module level or during __init__ and cache the result.
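
One way to cache the inspection is sketched below. It assumes the signature is used to decide which optional keyword arguments the kernel accepts; the review excerpt does not show how sig is consumed, so that use is a guess. supported_parameters and call_with_supported_kwargs are hypothetical helper names, and the filter does not handle callables that take **kwargs.

```python
import inspect
from functools import lru_cache


@lru_cache(maxsize=None)
def supported_parameters(func):
    """Inspect `func` once and remember the names of its parameters."""
    return frozenset(inspect.signature(func).parameters)


def call_with_supported_kwargs(func, *args, **kwargs):
    """Call `func`, dropping keyword arguments it does not declare.

    Assumes `func` lists its keyword parameters explicitly (no **kwargs).
    """
    allowed = supported_parameters(func)  # cached after the first call
    filtered = {name: value for name, value in kwargs.items() if name in allowed}
    return func(*args, **filtered)
```

After the first call for a given function, later calls pay only a dictionary lookup instead of a full signature inspection.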

Comment on lines +20 to +22
gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
sequence_parallel=kwargs.pop("sequence_parallel"),

medium

Using kwargs.pop without a default value for critical configuration parameters like gradient_accumulation_fusion, allreduce_dgrad, and sequence_parallel makes the implementation fragile. If the base class forward method changes or if this layer is used in a different context where these keys are missing, it will raise a KeyError. As per repository guidelines, these model parameters should be retrieved from the model configuration.

Suggested change
- gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
- allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
- sequence_parallel=kwargs.pop("sequence_parallel"),
+ gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion", getattr(self, "gradient_accumulation_fusion", False)),
+ allreduce_dgrad=kwargs.pop("allreduce_dgrad", self.config.allreduce_dgrad),
+ sequence_parallel=kwargs.pop("sequence_parallel", self.config.sequence_parallel),
References
  1. Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.

Comment on lines +42 to +44
gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
sequence_parallel=kwargs.pop("sequence_parallel"),

medium

Similar to the ColumnParallelLinear implementation, using kwargs.pop without defaults for these parameters is risky. Ensure these parameters are retrieved from the model configuration to avoid fragility and hardcoding.

Suggested change
- gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
- allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
- sequence_parallel=kwargs.pop("sequence_parallel"),
+ gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion", getattr(self, "gradient_accumulation_fusion", False)),
+ allreduce_dgrad=kwargs.pop("allreduce_dgrad", self.config.allreduce_dgrad),
+ sequence_parallel=kwargs.pop("sequence_parallel", self.config.sequence_parallel),
References
  1. Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.

Comment on lines +37 to +39
from megatron.core.parallel_state import get_tensor_model_parallel_world_size

return get_tensor_model_parallel_world_size()

medium

Importing from megatron.core.parallel_state inside a function that is called in the hot path (every linear layer forward pass) adds unnecessary overhead. While Python caches imports, the lookup still has a cost. Move the import to the module level or cache the imported function.
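
Whether this overhead matters can be measured rather than assumed. The sketch below uses math.sqrt as a stand-in for get_tensor_model_parallel_world_size so that it runs without Megatron-LM; that substitution is an assumption for illustration, and absolute timings vary by machine, so none are quoted.

```python
import timeit
from math import sqrt  # module-level import: the name is bound once at load time


def with_local_import(x):
    # The import statement is re-executed on every call; Python finds the
    # module in sys.modules, but the lookup and name binding still run.
    from math import sqrt as local_sqrt

    return local_sqrt(x)


def with_module_import(x):
    # Uses the name already bound at module level.
    return sqrt(x)


def measure(func, number=200_000):
    """Return the total seconds spent calling `func(2.0)` `number` times."""
    return timeit.timeit(lambda: func(2.0), number=number)


if __name__ == "__main__":
    print("function-local import:", measure(with_local_import))
    print("module-level import:  ", measure(with_module_import))
```

Comparing the two totals on the target machine gives a concrete per-call figure to weigh against the cost of the linear layer itself.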

    self.weight = torch.nn.Parameter(torch.ones(hidden_size))
    self.register_parameter("bias", None)
else:
    raise Exception("Only LayerNorm and RMSNorm are currently supported")

medium

Avoid raising a generic Exception. Use a more specific exception type like ValueError or NotImplementedError to improve error handling and clarity.

Suggested change
- raise Exception("Only LayerNorm and RMSNorm are currently supported")
+ raise NotImplementedError("Only LayerNorm and RMSNorm are currently supported")

Point the testing-only container_image override to the image built from this branch (radixark/miles:dev-phase1-test) instead of the step-2a one. Same TEMP marker, same plan to revert at merge time.