Move miles_megatron_plugins into the miles repo as a top-level package #1177
Shi-Dong wants to merge 2 commits
Adds miles_megatron_plugins/ at the top level of the miles repo. This is the source of truth from now on; the matching change on radixark/Megatron-LM (PR #32) removes the same directory from there.
The third_party/Megatron-LM submodule pointer is bumped to c3a94af51 (the Megatron-LM commit with miles_megatron_plugins/ removed), and .gitmodules is temporarily flipped from branch=miles-main to branch=miles/move-plugins-to-miles. The .gitmodules note is a TEMP marker: once radixark/Megatron-LM PR #32 merges into miles-main, a follow-up commit can flip the branch line back to miles-main.
Why this is safe:
- miles_megatron_plugins has zero imports from miles or sglang; it only imports megatron.core and torch. So living in either repo is fine from an import-graph standpoint.
- Megatron-LM source files do `from miles_megatron_plugins...` imports in several modules (transformer_layer, transformer_block, attention, gpt_model, gpt_layer_specs, tensor_parallel/{layers, matmul_tp_inv}, distributed_data_parallel, extensions/sglang). Those imports resolve via miles' editable install once miles is installed, which is true in both production (Dockerfile installs miles last) and in CI (the Install miles step is what each pr-test job runs).
- After this change, `import miles_megatron_plugins.*` no longer needs PYTHONPATH=/root/Megatron-LM to resolve. The CPU runner's PYTHONPATH still includes third_party/Megatron-LM (still needed for megatron.training / legacy / rl); that is Phase 2's removal target, not Phase 1's.
setup.py's find_packages now includes miles_megatron_plugins*; pyproject's [tool.isort] known_first_party and src_paths add miles_megatron_plugins for consistency with the existing miles / miles_plugins entries.
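The import-resolution claim above can be checked mechanically rather than by eye. A minimal sketch using only the standard library's `importlib.util.find_spec`; the helper name `resolves_under` is hypothetical and not part of miles or Megatron-LM:

```python
import importlib.util
from pathlib import Path


def resolves_under(module_name: str, root: str) -> bool:
    """Return True if `module_name` would be imported from a file under `root`.

    Hypothetical helper for illustration only.
    """
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        # Not importable, or a namespace package with no single origin file.
        return False
    origin = Path(spec.origin).resolve()
    # True when `root` is one of the ancestor directories of the module file.
    return Path(root).resolve() in origin.parents
```

On a devbox built from this branch, the analogous check would be `resolves_under("miles_megatron_plugins", "/root/miles")` returning `True` and the same call with `/root/Megatron-LM` returning `False`, assuming the install paths named in this PR.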
Code Review
This pull request introduces the miles_megatron_plugins.true_on_policy package, which provides an SGLang-compatible backend for Megatron-LM to ensure numerical parity. Key additions include SGLangFlashAttention with Ulysses Context Parallelism support, TP-invariant matmul operations using fixed-order tree summation, and custom normalization and RoPE implementations. The review feedback focuses on performance optimizations in the hot paths, such as moving signature inspections, function availability checks, and imports out of the forward methods. Additionally, a high-severity issue was identified where unconditional bfloat16 casting in the attention layer could cause numerical errors or runtime dtype mismatches, and improvements were suggested to make configuration parameter retrieval in linear layers more robust.
    query = query.to(torch.bfloat16)
    key = key.to(torch.bfloat16)
    value = value.to(torch.bfloat16)
The input tensors are unconditionally cast to torch.bfloat16. This can lead to unexpected numerical behavior if the model is configured for float16 training. Furthermore, the output is not cast back to the original input dtype, which will likely cause a RuntimeError due to dtype mismatch in subsequent layers (e.g., residual addition). Consider preserving the original dtype and casting the output back, or using query.dtype if it is a supported floating-point type.
    if not HAVE_FA3_VARLEN or fa3_varlen_func is None:
        raise ImportError("Flash Attention 3 varlen is required for SGLangFlashAttention")
    key = self.cp_layout.sequence_to_head_parallel(key, cu_seqlens_k)
    value = self.cp_layout.sequence_to_head_parallel(value, cu_seqlens_k)

    sig = inspect.signature(fa3_varlen_func)
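The review summary flags signature inspection inside forward methods as hot-path overhead. A minimal sketch of hoisting it into the constructor, assuming the signature is only used to decide which optional keywords the kernel accepts. `fake_varlen_func`, `AttentionWrapper`, and the `softcap` keyword are hypothetical stand-ins, not names from the PR:

```python
import inspect


def fake_varlen_func(q, k, v, softcap=0.0):
    """Hypothetical stand-in for an attention kernel; it just echoes its inputs."""
    return (q, k, v, softcap)


class AttentionWrapper:
    """Illustrative wrapper that inspects the kernel signature once, not per call."""

    def __init__(self, kernel):
        self.kernel = kernel
        # Computed once at construction; forward() only does set lookups.
        self._supported = frozenset(inspect.signature(kernel).parameters)

    def forward(self, q, k, v, **optional):
        # Pass through only the optional keywords this kernel version accepts.
        accepted = {name: val for name, val in optional.items() if name in self._supported}
        return self.kernel(q, k, v, **accepted)
```

The same pattern applies to the availability check: a missing kernel can be rejected once in `__init__` instead of on every call.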
    gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
    allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
    sequence_parallel=kwargs.pop("sequence_parallel"),
Using kwargs.pop without a default value for critical configuration parameters like gradient_accumulation_fusion, allreduce_dgrad, and sequence_parallel makes the implementation fragile. If the base class forward method changes or if this layer is used in a different context where these keys are missing, it will raise a KeyError. As per repository guidelines, these model parameters should be retrieved from the model configuration.
    - gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
    - allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
    - sequence_parallel=kwargs.pop("sequence_parallel"),
    + gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion", getattr(self, "gradient_accumulation_fusion", False)),
    + allreduce_dgrad=kwargs.pop("allreduce_dgrad", self.config.allreduce_dgrad),
    + sequence_parallel=kwargs.pop("sequence_parallel", self.config.sequence_parallel),
References
- Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.
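The difference between the two forms can be shown without Megatron. A minimal sketch, assuming a config object that carries `allreduce_dgrad` and `sequence_parallel` as the suggestion above does; `resolve_linear_kwargs` is an illustrative name, not a function from the PR:

```python
from types import SimpleNamespace


def resolve_linear_kwargs(kwargs: dict, config, layer=None) -> dict:
    """Pop the three flags from kwargs, falling back to layer/config attributes."""
    return {
        "gradient_accumulation_fusion": kwargs.pop(
            "gradient_accumulation_fusion",
            getattr(layer, "gradient_accumulation_fusion", False),
        ),
        "allreduce_dgrad": kwargs.pop("allreduce_dgrad", config.allreduce_dgrad),
        "sequence_parallel": kwargs.pop("sequence_parallel", config.sequence_parallel),
    }


# Hypothetical config values, chosen only for the demonstration.
demo_config = SimpleNamespace(allreduce_dgrad=True, sequence_parallel=False)
demo_kwargs = {"sequence_parallel": True, "unrelated": 1}
resolved = resolve_linear_kwargs(demo_kwargs, demo_config)
```

Here the caller-supplied `sequence_parallel` wins, the two missing keys fall back instead of raising `KeyError`, and `demo_kwargs` is left with only the unrelated key. One caveat of this form: the default expressions are evaluated eagerly, so the config must define both attributes even when the caller passes the keys explicitly.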
    gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
    allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
    sequence_parallel=kwargs.pop("sequence_parallel"),
Similar to the ColumnParallelLinear implementation, using kwargs.pop without defaults for these parameters is risky. Ensure these parameters are retrieved from the model configuration to avoid fragility and hardcoding.
    - gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion"),
    - allreduce_dgrad=kwargs.pop("allreduce_dgrad"),
    - sequence_parallel=kwargs.pop("sequence_parallel"),
    + gradient_accumulation_fusion=kwargs.pop("gradient_accumulation_fusion", getattr(self, "gradient_accumulation_fusion", False)),
    + allreduce_dgrad=kwargs.pop("allreduce_dgrad", self.config.allreduce_dgrad),
    + sequence_parallel=kwargs.pop("sequence_parallel", self.config.sequence_parallel),
References
- Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.
    from megatron.core.parallel_state import get_tensor_model_parallel_world_size

    return get_tensor_model_parallel_world_size()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.register_parameter("bias", None)
    else:
        raise Exception("Only LayerNorm and RMSNorm are currently supported")
Avoid raising a generic Exception. Use a more specific exception type like ValueError or NotImplementedError to improve error handling and clarity.
    - raise Exception("Only LayerNorm and RMSNorm are currently supported")
    + raise NotImplementedError("Only LayerNorm and RMSNorm are currently supported")
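The practical benefit of the specific type shows up at the call site: a handler for `NotImplementedError` does not swallow unrelated failures, whereas a handler for bare `Exception` would. A minimal sketch, with `build_norm` and `norm_or_fallback` as hypothetical functions that return labels rather than real norm modules:

```python
def build_norm(normalization: str) -> str:
    """Hypothetical factory; returns a label instead of a real norm module."""
    if normalization in ("LayerNorm", "RMSNorm"):
        return normalization
    raise NotImplementedError("Only LayerNorm and RMSNorm are currently supported")


def norm_or_fallback(normalization: str, fallback: str = "RMSNorm") -> str:
    """Fall back only when the requested norm is unsupported.

    Any other error raised by build_norm (a TypeError, say) still propagates.
    """
    try:
        return build_norm(normalization)
    except NotImplementedError:
        return fallback
```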
Point the testing-only container_image override to the image built from this branch (radixark/miles:dev-phase1-test) instead of the step-2a one. Same TEMP marker, same plan to revert at merge time.
Summary
- Adds `miles_megatron_plugins/` at the top level of the miles repo (`__init__.py` + `true_on_policy/` with 13 modules; ~1400 lines).
- Updates `setup.py` `find_packages(include=[...])` to expose the new package.
- Updates `pyproject.toml` `[tool.isort]` (`known_first_party`, `src_paths`) for consistency with `miles` / `miles_plugins`.
- Bumps the `third_party/Megatron-LM` submodule pointer to `c3a94af51` (the matching deletion on `radixark/Megatron-LM`; see radixark/Megatron-LM#32).
- Temporarily sets `.gitmodules` `branch = miles/move-plugins-to-miles` for `third_party/Megatron-LM`. A follow-up commit flips back to `branch = miles-main` once radixark/Megatron-LM#32 merges.
Why
`miles_megatron_plugins` is semantically miles' code that Megatron-LM imports back (true-on-policy contracts, sglang backend, matmul / attention / norm replacements). Today it lives inside the `radixark/Megatron-LM` fork and is only on `sys.path` because production sets `PYTHONPATH=/root/Megatron-LM`. That PYTHONPATH workaround doesn't survive packaging miles as a wheel for the eventual `pip install miles` goal.
Moving the package into miles makes it part of miles' regular Python package. The `from miles_megatron_plugins...` imports inside Megatron-LM source (transformer_layer, transformer_block, attention, gpt_model, gpt_layer_specs, distributed_data_parallel, tensor_parallel/{layers, matmul_tp_inv}, models/gpt/{gpt_model, gpt_layer_specs}, extensions/sglang) continue to work because miles is always installed before `megatron.core` is imported, both in production (Dockerfile installs miles last) and in CI (each pr-test job runs `Install miles` after the container starts).
This is Phase 1 of the broader plan toward `pip install miles`. Phase 2 (separate PR) will broaden Megatron-LM's `packages.find.include` so `pip install -e Megatron-LM` exposes `megatron.training` / `megatron.legacy` / `megatron.rl`, removing the remaining piece of the CPU runner's PYTHONPATH hack.
What does NOT change
`_run-ci.yml`'s CPU runner PYTHONPATH still includes `${{ github.workspace }}/third_party/Megatron-LM` because that path is still needed for `megatron.training` / `megatron.legacy` / `megatron.rl` (Phase 2's removal target). After this PR, that PYTHONPATH is no longer required for `miles_megatron_plugins` itself.
Test plan
- Build the test image with `gh workflow run docker-build.yml --ref shi/move-miles-megatron-plugins -f variant=primary -f image_tag=custom -f custom_tag=dev-phase1-test`, then on a 1-node H200 devbox:
  - `cat /root/miles/miles_megatron_plugins/__init__.py` succeeds (the file exists).
  - `ls /root/Megatron-LM/miles_megatron_plugins` returns "No such file or directory" (proves the Megatron-LM side was deleted).
  - `python -c "import miles_megatron_plugins.true_on_policy.contracts; print(miles_megatron_plugins.__file__)"` resolves to a path under `/root/miles/`, not `/root/Megatron-LM/`. (Key sys.path resolution check.)
  - `python -c "import megatron.core.transformer.transformer_layer"` still succeeds (proves Megatron-LM's `from miles_megatron_plugins...` imports resolve via miles' install).
- `pr-test.yml` workflow_dispatch run on `shi/move-miles-megatron-plugins` against the new test image (commit `b310ddb37` from #1170, "Wire CI to install sglang/Megatron-LM from third_party submodules", keeps the `# TEMP` container_image pin; we'll override it to `dev-phase1-test` if needed). Expect the same matrix as run 26211353043: stage-a-fast, stage-b-fast-gpu, stage-b-short, stage-b-sglang green; stage-a-unit-test red as a pre-existing issue.
Stacked on
#1170 (step 2b: CI uses submodule paths). Once #1168 / #1169 / #1170 merge in order, this PR's base auto-updates.
Coordinated change
radixark/Megatron-LM#32 makes the matching deletion on `miles-main`. Both should merge together. After both merge:
- Flip `.gitmodules` `branch = miles/move-plugins-to-miles` back to `branch = miles-main` (the pointer's SHA is unchanged once radixark/Megatron-LM#32 lands on `miles-main`).
- Delete the `miles/move-plugins-to-miles` branch on radixark/Megatron-LM.