refactor(ccl): migrate collective kernels from AOT to JIT compilation#319
Merged
Conversation
Move CCL kernel compilation from build-time AOT (-x hip / --hip-link) to runtime JIT (hipcc --genco), consistent with the EP kernel architecture. Motivation: - Eliminates `undefined symbol: atexit` on glibc >= 2.34 caused by HIP device registration code in AOT-compiled .so files - Makes libmori_collective.so a pure host C++ library (no device code) - Enables runtime GPU architecture detection (no build-time --offload-arch) - Consistent JIT caching via ~/.mori/jit/ with EP kernels Changes: - Split all CCL __global__ kernels into __device__ _body + __global__ wrapper in kernel headers (backward compatible with AOT callers) - Add ccl_kernels.hip as JIT compilation unit with extern "C" entry points - Add CclAll2allArgs/CclAllgatherArgs/CclAllreduceArgs POD structs - Refactor C++ host classes: remove <<<>>> launches, add prepare/finish methods that fill args structs for Python-side launch - Rewrite Python ccl/collective.py to use compile_genco + HipModule + launch_struct pattern (same as dispatch_combine.py) - Remove hip::device from collective CMakeLists (host-only build) - Add ccl_kernels to MORI_PRECOMPILE path - Package .cuh files in wheel JIT sources Co-authored-by: Cursor <cursoragent@cursor.com>
db6a716 to
0e8b368
Compare
…hip::device The inter-node executors (ring_1d, one_shot) contain AOT device code (<<<>>> launches) that requires hip::device. Build them as a separate OBJECT library so the core SDMA path stays host-only. Fixes BUILD_EXAMPLES=ON compilation. Co-authored-by: Cursor <cursoragent@cursor.com>
Keep the JIT allreduce in-place path aligned with the original C++ behavior by forcing sync finish to copy results back to the user tensor. Co-authored-by: Cursor <cursoragent@cursor.com>
Run one functional CI case each for allgather, all2all, and allreduce so the CCL JIT paths are covered in intranode CI. Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The out-of-place allreduce test reads from the transit buffer via a raw pointer view that PyTorch cannot track for stream sync. Without an explicit stream sync in finish_sync, .cpu() on the transit buffer view only waits on the default stream, missing the user stream's AllGather kernel writes. Align with all2all/allgather finish_sync which already sync the stream before returning. Co-authored-by: Cursor <cursoragent@cursor.com>
Move the env var from docker exec -e to inline prefix for each python command to ensure propagation to torch.multiprocessing.spawn workers. Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Reverts 993cad2. The added hipStreamSynchronize was unnecessary — the JIT launch path handles stream ordering via hipModuleLaunchKernel, and the out-of-place test failure on this machine is a pre-existing ShmemQuietThread SDMA drain issue (cross-NUMA), not a missing sync. Co-authored-by: Cursor <cursoragent@cursor.com>
jhchouuu
commented
May 14, 2026
Collaborator
Author
jhchouuu
left a comment
There was a problem hiding this comment.
Review comments on the AOT to JIT migration.
- Clarify __getattr__ error for AllGatherIntoTensor (not yet ported to JIT) - Remove dead ReduceScatterAllGatherFusedKernel from JIT compilation - Set flagVal = async_flag_token_ in allgather async prepare - Document jit_args_ reentrancy constraint Co-authored-by: Cursor <cursoragent@cursor.com>
96007f6 to
829e2ec
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Migrate CCL (collective) kernel compilation from build-time AOT (
-x hip/--hip-link) to runtime JIT (hipcc --genco), consistent with the EP kernel architecture.Motivation
undefined symbol: atexiton glibc >= 2.34 (TheRock toolchain). AOT HIP compilation generates device registration code that referencesatexit, which newer glibc no longer exports as a dynamic symbol..hip→hipcc --genco→.hsaco→HipModule→launch_struct.libmori_collective.sois now pure host C++ (no--hip-link, no device code embedded).--offload-archbaked into the .so.Changes
__global__kernels into__device__ _body+__global__wrapper in headers (backward compatible)src/collective/kernels/ccl_kernels.hipas JIT compilation unit withextern "C"entry pointsCclAll2allArgs/CclAllgatherArgs/CclAllreduceArgsPOD args structs<<<>>>, addprepare_*/finish_*methods for Python-side launchpython/mori/ccl/collective.pyto usecompile_genco+HipModule+launch_structhip::devicefrom collective CMakeLists (host-only build)ccl_kernelstoMORI_PRECOMPILEpath.cuhfiles in wheel JIT sourcesBehavioral equivalence
All kernel names, grid/block sizes, launch order, and sync patterns are identical to the original AOT path. No runtime behavior change.
Test plan
pip install .builds without-x hipor--hip-linkfor collectivenm -D libmori_collective.so | grep "U atexit"— no undefined atexitimport mori.ops/import mori.ccl— both workMORI_PRECOMPILE=1 python -c "import mori"— ccl_kernels compiled and cachedtest_allgather --world-size 8 --elems 1048576— passedtest_all2all --world-size 8 --elems 1048576— passed