
refactor(ccl): migrate collective kernels from AOT to JIT compilation#319

Merged
jhchouuu merged 10 commits into main from feat/ccl-jit-kernels
May 14, 2026

Conversation

@jhchouuu
Collaborator

@jhchouuu jhchouuu commented May 13, 2026

Summary

Migrate CCL (collective) kernel compilation from build-time AOT (-x hip / --hip-link) to runtime JIT (hipcc --genco), consistent with the EP kernel architecture.

Motivation

  • Fix undefined symbol: atexit on glibc >= 2.34 (TheRock toolchain). AOT HIP compilation generates device registration code that references atexit, which newer glibc no longer exports as a dynamic symbol.
  • Architecture consistency: all mori kernels (EP + CCL) now use the same JIT path — .hip → hipcc --genco → .hsaco → HipModule → launch_struct.
  • Clean wheel: libmori_collective.so is now pure host C++ (no --hip-link, no device code embedded).
  • Runtime arch detection: no build-time --offload-arch baked into the .so.
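The runtime-arch point above can be sketched as follows. This is an illustrative helper, not mori's actual code: the hipcc flags `--genco` and `--offload-arch` are real, but the function names and defaults here are assumptions.

```python
# Sketch (not the mori implementation): a runtime JIT step invoking hipcc to
# produce a code object for the GPU arch detected at runtime, instead of
# baking --offload-arch into the .so at build time.
import subprocess

def build_genco_cmd(src_hip, gfx_arch, out_hsaco, include_dirs=()):
    """Assemble an `hipcc --genco` command line for one kernel source."""
    cmd = ["hipcc", "--genco", f"--offload-arch={gfx_arch}", "-O3"]
    for inc in include_dirs:
        cmd += ["-I", inc]
    cmd += ["-o", out_hsaco, src_hip]
    return cmd

def compile_to_hsaco(src_hip, gfx_arch, out_hsaco):
    # Actually running this requires a ROCm toolchain on PATH.
    subprocess.run(build_genco_cmd(src_hip, gfx_arch, out_hsaco), check=True)

cmd = build_genco_cmd("ccl_kernels.hip", "gfx942", "ccl_kernels.hsaco", ["include"])
print(" ".join(cmd))
```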

Changes

  • Split all CCL __global__ kernels into __device__ _body + __global__ wrapper in headers (backward compatible)
  • Add src/collective/kernels/ccl_kernels.hip as JIT compilation unit with extern "C" entry points
  • Add CclAll2allArgs / CclAllgatherArgs / CclAllreduceArgs POD args structs
  • Refactor C++ host classes: remove <<<>>>, add prepare_*/finish_* methods for Python-side launch
  • Rewrite python/mori/ccl/collective.py to use compile_genco + HipModule + launch_struct
  • Remove hip::device from collective CMakeLists (host-only build)
  • Add ccl_kernels to MORI_PRECOMPILE path
  • Package .cuh files in wheel JIT sources
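As a rough illustration of the POD-args idea from the list above (field names here are hypothetical, not the real CclAllgatherArgs layout): the host side fills one flat struct, and a launch_struct-style launcher hands its raw bytes to hipModuleLaunchKernel via the HIP_LAUNCH_PARAM_BUFFER_POINTER convention.

```python
# Hypothetical POD args struct; the actual CclAllgatherArgs fields differ.
import ctypes

class AllgatherArgs(ctypes.Structure):
    _fields_ = [
        ("input_ptr",  ctypes.c_void_p),   # device pointer, NULL here
        ("output_ptr", ctypes.c_void_p),
        ("elem_count", ctypes.c_size_t),
        ("rank",       ctypes.c_int),
        ("world_size", ctypes.c_int),
    ]

args = AllgatherArgs(input_ptr=None, output_ptr=None,
                     elem_count=1 << 20, rank=0, world_size=8)
buf = bytes(args)  # the flat buffer a launch_struct-style launch would pass
print(len(buf))
```

Because the struct is plain old data, the same byte layout can be declared on the C++ side and in the JIT-compiled extern "C" kernel signature without any marshalling.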

Behavioral equivalence

All kernel names, grid/block sizes, launch order, and sync patterns are identical to the original AOT path. No runtime behavior change.

Test plan

  • pip install . builds without -x hip or --hip-link for collective
  • nm -D libmori_collective.so | grep "U atexit" — no undefined atexit
  • import mori.ops / import mori.ccl — both work
  • MORI_PRECOMPILE=1 python -c "import mori" — ccl_kernels compiled and cached
  • test_allgather --world-size 8 --elems 1048576 — passed
  • test_all2all --world-size 8 --elems 1048576 — passed
  • AllReduce test
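The MORI_PRECOMPILE caching step above can be pictured with a content-addressed cache key. This is an assumption about the scheme, not mori's actual layout under ~/.mori/jit/: keying on source text, target arch, and flags means a change to any of them forces recompilation.

```python
# Sketch of a content-addressed JIT cache key (assumed scheme, not mori's).
import hashlib, os

def jit_cache_path(source_text, gfx_arch, flags=("-O3",),
                   cache_root=os.path.expanduser("~/.mori/jit")):
    h = hashlib.sha256()
    h.update(source_text.encode())
    h.update(gfx_arch.encode())
    h.update(" ".join(flags).encode())
    return os.path.join(cache_root, f"ccl_kernels-{h.hexdigest()[:16]}.hsaco")

src = "__global__ void k() {}"
a = jit_cache_path(src, "gfx90a")
b = jit_cache_path(src, "gfx942")
print(a != b)  # a different arch maps to a different cache entry
```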

@jhchouuu jhchouuu requested a review from wuyl1 May 13, 2026 07:38
Move CCL kernel compilation from build-time AOT (-x hip / --hip-link) to
runtime JIT (hipcc --genco), consistent with the EP kernel architecture.

Motivation:
- Eliminates `undefined symbol: atexit` on glibc >= 2.34 caused by HIP
  device registration code in AOT-compiled .so files
- Makes libmori_collective.so a pure host C++ library (no device code)
- Enables runtime GPU architecture detection (no build-time --offload-arch)
- Consistent JIT caching via ~/.mori/jit/ with EP kernels

Changes:
- Split all CCL __global__ kernels into __device__ _body + __global__ wrapper
  in kernel headers (backward compatible with AOT callers)
- Add ccl_kernels.hip as JIT compilation unit with extern "C" entry points
- Add CclAll2allArgs/CclAllgatherArgs/CclAllreduceArgs POD structs
- Refactor C++ host classes: remove <<<>>> launches, add prepare/finish
  methods that fill args structs for Python-side launch
- Rewrite Python ccl/collective.py to use compile_genco + HipModule +
  launch_struct pattern (same as dispatch_combine.py)
- Remove hip::device from collective CMakeLists (host-only build)
- Add ccl_kernels to MORI_PRECOMPILE path
- Package .cuh files in wheel JIT sources

Co-authored-by: Cursor <cursoragent@cursor.com>
@jhchouuu jhchouuu force-pushed the feat/ccl-jit-kernels branch from db6a716 to 0e8b368 on May 13, 2026 07:45
jhchouuu and others added 5 commits May 13, 2026 07:50
…hip::device

The inter-node executors (ring_1d, one_shot) contain AOT device code
(<<<>>> launches) that requires hip::device. Build them as a separate
OBJECT library so the core SDMA path stays host-only.

Fixes BUILD_EXAMPLES=ON compilation.

Co-authored-by: Cursor <cursoragent@cursor.com>
Keep the JIT allreduce in-place path aligned with the original C++ behavior by forcing sync finish to copy results back to the user tensor.

Co-authored-by: Cursor <cursoragent@cursor.com>
Run one functional CI case each for allgather, all2all, and allreduce so the CCL JIT paths are covered in intranode CI.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The out-of-place allreduce test reads from the transit buffer via a raw
pointer view that PyTorch cannot track for stream sync.  Without an
explicit stream sync in finish_sync, .cpu() on the transit buffer view
only waits on the default stream, missing the user stream's AllGather
kernel writes.

Align with all2all/allgather finish_sync which already sync the stream
before returning.

Co-authored-by: Cursor <cursoragent@cursor.com>
@kawhil-amd kawhil-amd closed this May 13, 2026
@kawhil-amd kawhil-amd reopened this May 13, 2026
jhchouuu and others added 3 commits May 13, 2026 13:18
Move the env var from docker exec -e to inline prefix for each python
command to ensure propagation to torch.multiprocessing.spawn workers.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Reverts 993cad2. The added hipStreamSynchronize was unnecessary —
the JIT launch path handles stream ordering via hipModuleLaunchKernel,
and the out-of-place test failure on this machine is a pre-existing
ShmemQuietThread SDMA drain issue (cross-NUMA), not a missing sync.

Co-authored-by: Cursor <cursoragent@cursor.com>

@jhchouuu jhchouuu left a comment


Review comments on the AOT-to-JIT migration.

Comment thread python/mori/ccl/__init__.py
Comment thread src/collective/kernels/ccl_kernels.hip Outdated
Comment thread src/collective/core/oneshot_allgather_sdma_class.cpp Outdated
Comment thread include/mori/collective/allgather/oneshot_allgather_sdma_class.hpp
- Clarify __getattr__ error for AllGatherIntoTensor (not yet ported to JIT)
- Remove dead ReduceScatterAllGatherFusedKernel from JIT compilation
- Set flagVal = async_flag_token_ in allgather async prepare
- Document jit_args_ reentrancy constraint

Co-authored-by: Cursor <cursoragent@cursor.com>
@jhchouuu jhchouuu force-pushed the feat/ccl-jit-kernels branch from 96007f6 to 829e2ec on May 14, 2026 03:12
@jhchouuu jhchouuu merged commit 9b318ab into main May 14, 2026
9 of 11 checks passed
@jhchouuu jhchouuu deleted the feat/ccl-jit-kernels branch May 14, 2026 07:42