feat(shmem): add MORI_MULTITHREAD_SUPPORT for SPMT #308
// exit so XLA's other state isn't disturbed) before any HIP call. This
// ensures GetHandleCacheSlot() and ShmemStatesSingleton::GetInstance()
// (both keyed by hipGetDevice()) hit the right slot.
ScopedDevice _dev_guard(mori::shmem::ShmemStatesSingleton::GetDeviceByRank(cfg.rank));
Has this been checked? I was under the impression that XLA ensures that the "context" device is set to the one the thread serves.
Yes, and could you please check it again? Maybe my understanding is incorrect.
I think the issue here is that we cannot initialize MORI correctly (when SPMT is on) unless MORI is integrated into XLA. This is because XLA manages the GPU threads, so ShmemInit() should be called by XLA internally when GPU communicators are created.
In my original branch, I added a hacky way to initialize MORI to see whether it generally works.
namespace mori {
inline constexpr int kMaxGpusPerNode = 8;
I think with CPX it is 32. Not sure if it can be split even more than that.
Yes, we will have 32 72 GPUs at rack level. It would be best not to hardcode this. Or should we get the maximum number of GPUs from the build script?
Yes, CPX and rack-level systems will have more than 8 GPUs. We have already tried this on CPX, but it was limited to a single card before. Rack-level support is also in our plan.
So for now, kMaxGpusPerNode equals 8.
if (i == rank) continue;
void* mappedPtr = nullptr;
HIP_RUNTIME_CHECK(
    hipIpcOpenMemHandle(&mappedPtr, signalHandles[i], hipIpcMemLazyEnablePeerAccess));
I tried running my XLA SPMT test and hit this error here:
[/data/mori/src/application/memory/symmetric_memory.cpp:251] hip failed with invalid device context
I think we should guard this block with SameProcessP2P(i).
Is your XLA SPMT test related to SDMA? I should add verification of the SDMA path to my TODO list; I haven't verified SDMA yet.
Yep, I use simple SDMA-based collectives there. But I think RegisterSymmMemObj can also be used without the SDMA transport.
OK. I'm continuing my work on SDMA. I will verify the SDMA path through the async_ll EP kernel.
…read (SPMT)
- Add MORI_MULTITHREAD_SUPPORT cmake option (default OFF)
- ShmemStatesSingleton::GetInstance() returns per-GPU slot via hipGetDevice() using std::array<ShmemStates, 8> for stable addresses and lock-free reads
- Embed GpuStates and ModuleStates as values in ShmemStates (no heap alloc)
- Remove file-scope globals: s_hostGpuStatesCopy, s_shmemModule, s_deviceGpuStatesAddr, s_barrierFunc
- All init/finalize functions take explicit ShmemStates* (thread-safe)
- ShmemFinalize resets status to New to allow re-init on same GPU
- PID allgather in CollectHostNames for same-process P2P detection
- Skip hipIpcOpenMemHandle for same-process peers (use direct pointer)
- Python: per-GPU JIT module loading with double-check locking
- Python: GIL release on all blocking shmem APIs
- mori_log: try/catch spdlog race on concurrent logger registration
- Add tests/python/shmem/test_spmt.py (world_size 2/4/8)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
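The "per-GPU JIT module loading with double-check locking" item in this commit can be pictured with a small Python sketch. It is an illustration of the pattern only, not MORI's actual code: `ensure_module`, `_load_module`, and `MAX_GPUS` are hypothetical names, and the loader is a stub standing in for the real JIT compile and load step.

```python
import threading

MAX_GPUS = 8  # assumption: mirrors the fixed slot count named in the commit

_modules = [None] * MAX_GPUS                    # one slot per GPU, index = device id
_locks = [threading.Lock() for _ in range(MAX_GPUS)]
load_calls = []                                 # records which devices were loaded


def _load_module(device_id):
    """Hypothetical stand-in for the expensive JIT compile and module load."""
    load_calls.append(device_id)
    return {"device": device_id}


def ensure_module(device_id):
    """Return the module for device_id, loading it at most once."""
    module = _modules[device_id]
    if module is not None:                      # first check, without the lock
        return module
    with _locks[device_id]:
        if _modules[device_id] is None:         # second check, under the lock
            _modules[device_id] = _load_module(device_id)
        return _modules[device_id]


# Eight threads race on two devices; each device is still loaded only once.
threads = [threading.Thread(target=ensure_module, args=(i % 2,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Using one lock per device rather than one global lock keeps threads that serve different GPUs from serializing on each other during their first load.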
- pybind_shmem: remove duplicate m.def for shmem_finalize / shmem_mype / shmem_npes / shmem_torch_process_group_init that overrode the GIL-release variants and effectively serialized SPMT finalize.
- symmetric_memory: explicitly hipDeviceEnablePeerAccess for same-process peers. The IPC-handle path (lazy enable via hipIpcMemLazyEnablePeerAccess) is skipped for same-process, so without this fix P2P-only SPMT would hit invalid-device-pointer at peer access time. Use hipPointerGetAttributes to discover the peer's device id without assuming a rank-to-device map.
- init.cpp ShmemFinalize: run FinalizeInternalSync before FinalizeGpuStates. FinalizeGpuStates calls FinalizeRuntime which clears states->gpuStates, including internalSyncPtr; running it first made FinalizeInternalSync early-return and leak the sync memory each init/finalize cycle.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…(SPMT) JAX
Make the XLA FFI EP path SPMT-safe so JAX users can drive multiple GPUs from one Python process (jax.devices() returns all 8) instead of being forced into the multi-process jax.distributed.initialize model. All changes are gated by MORI_MULTITHREAD_SUPPORT; the macro-OFF and multi-process paths are byte-identical to before.
Three structural caches that used hipModuleLoad-bound resources were process-global singletons. Under SPMT each thread is bound to its own device, so a singleton hands the wrong device's resources to other threads and the EP launches crash or silently use the wrong context:
- KernelRegistry::GetImpl: was a static singleton holding loaded hipModule_t. Now std::array<Impl, 8> indexed by hipGetDevice(). In multi-process every process sees its single GPU as device 0 and collapses to slot[0], equivalent to the old singleton.
- pybind_xla_ffi_ops g_handle_cache: was process-global with one mutex. Under SPMT this would deadlock: thread A holds the mutex during EpDispatchCombineHandle's cross-PE Barrier(), thread B blocks waiting for the mutex and never reaches its own Barrier(). Replaced with per-GPU HandleCacheSlot (own mutex + map).
- ShmemStatesSingleton rank→device map: XLA FFI handlers run on framework worker threads where hipGetDevice() does NOT match the rank's device. Added RegisterRankDevice/GetDeviceByRank, populated in InitializeBootStates from the user thread's device. FFI handlers (Instantiate + Impl) now look up the rank's device and hipSetDevice to it before any state access.
python/mori/shmem/api.py:_ensure_shmem_module no longer imports torch to read the current device; that import broke JAX containers that ship without torch. Now uses ctypes hipGetDevice via the existing mori.jit.hip_driver helper, working uniformly for both torch and JAX.
Adds tests/python/ops/test_dispatch_combine_jax_spmt.py: spawns one host thread per GPU in a single Python process, each thread runs ShmemInit + EpDispatchCombineOp dispatch+combine via XLA FFI. Verifies the full SPMT JAX path end-to-end. Per-thread shmem_finalize + clear_ep_handle_cache so the parametrized 2/4/8 GPU sizes can run in one pytest invocation.
Verified:
  MORI-EP JAX SPMT 2/4/8 GPU              3 passed in 22s
  MORI-EP JAX multi-process (existing)    1 passed in 17s (no regression)
  MORI-EP intranode (torchrun)            115 passed, 208 skipped
  MORI-SPMT shmem control plane           3 passed in 3s
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
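The g_handle_cache deadlock described in this commit, and why a per-GPU slot with its own mutex avoids it, can be modeled in plain Python. This is a sketch under assumptions: `HandleCacheSlot` here is an illustrative class that borrows the name from the commit, `threading.Barrier` stands in for the cross-PE Barrier(), and handle construction is reduced to a string.

```python
import threading

WORLD_SIZE = 4
# Stand-in for the cross-PE Barrier() reached during handle construction.
barrier = threading.Barrier(WORLD_SIZE)


class HandleCacheSlot:
    """One cache per GPU with its own mutex (illustrative, not MORI's class)."""

    def __init__(self):
        self.mutex = threading.Lock()
        self.handles = {}


slots = [HandleCacheSlot() for _ in range(WORLD_SIZE)]


def get_or_create_handle(device_id, key):
    slot = slots[device_id]
    with slot.mutex:  # locks only this GPU's slot
        if key not in slot.handles:
            # Every thread can arrive here at once because no lock is shared.
            barrier.wait(timeout=10)
            slot.handles[key] = f"handle-{device_id}-{key}"
        return slot.handles[key]


results = [None] * WORLD_SIZE


def worker(device_id):
    results[device_id] = get_or_create_handle(device_id, "ep")


threads = [threading.Thread(target=worker, args=(d,)) for d in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

If all four threads shared a single mutex instead, the first thread would wait at the barrier while holding it, and the other three could never enter to reach the barrier themselves.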
…slot
Under SPMT, each thread owns one GPU and one cache slot. The previous implementation iterated ALL per-GPU slots from any caller, so when thread 0 called clear_ep_handle_cache it ran ~EpDispatchCombineHandle on thread 1's, 2's, ... handles too. Each ~Handle calls ShmemFree on the buffers it allocated on its own GPU's symmetric heap, but the calling thread's hipDevice was still 0, so ShmemFree looked up those addresses in GPU 0's HeapVAManager and reported "address not found" hundreds of times before the test process eventually aborted with SIGABRT during teardown.
Fix: only clear the slot returned by GetHandleCacheSlot() (the calling thread's slot under SPMT, the global slot in single-GPU mode). Each SPMT thread is responsible for clearing its own cache as part of its shmem_finalize sequence, the same pattern as ShmemStatesSingleton.
Also: in tests/python/ops/test_dispatch_combine_jax_spmt.py add the gc.collect() between cache clear and shmem_finalize (mirrors mori.jax.shmem_finalize). Do NOT call jax.clear_caches(): it is process-global and races across SPMT threads.
After this fix the SPMT JAX EP test exits 0 cleanly with zero HeapVAManager errors, across 5 consecutive 2/4/8 GPU runs. Multi-process JAX EP regression unaffected.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ce guard
Self-review cleanups on top of the SPMT branch.
1. Centralize kMaxGpusPerNode in mori/utils/limits.hpp. Previously the constant was duplicated in three places (internal.hpp, launch.cpp, pybind_xla_ffi_ops.cpp) and bumping it for >8-GPU nodes (e.g. future MI400) would have required editing all three. Now a single inline constexpr that any TU can pull in cheaply.
2. Drop dead ShmemStatesSingleton::mutex_ field. The comment claimed "we still take a brief lock the first time to guard concurrent ShmemInit", but no code actually locks it. SPMT's contract is one thread per GPU, so each slot is accessed serially by its owner thread; cross-thread synchronization is only needed for the rank → device map below, which has its own mutex.
3. Drop dead ShmemStatesStatus::Finalized. ShmemFinalize resets to New (so the slot can be reused for re-init), so the Finalized state was never actually set, which made the check `if (status == Finalized)` in ShmemInit dead code. Removed both.
4. RAII ScopedDevice guard for XLA FFI handlers. The previous code called hipSetDevice(rank_dev) directly, leaving XLA's worker thread bound to a different device than what XLA had set on entry, a subtle violation of the convention that XLA owns its worker thread state. ScopedDevice restores the saved device on scope exit so the change is local to the FFI handler call.
5. Also: drop spurious (void)hipGetLastError() calls in three lookups. They were silently swallowing real errors from prior unchecked HIP calls. hipGetDevice / hipSetDevice return their own status; sticky errors only surface on next hipDeviceSynchronize / kernel launch.
6. Test doc fix: clarify why _build_config passes gpu_per_node == world_size (single-node SPMT, EP handle requires IsPowerOf2(gpuPerNode) && worldSize % gpuPerNode == 0). I briefly tried to "fix" this to use the physical GPU count and tripped the assertion; kept the original behavior with explanatory comments.
Verified:
  SPMT JAX EP 2/4/8 GPU          5/5 runs clean exit
  Multi-process JAX EP           no regression
  intranode EP (torchrun)        115 passed
  shmem SPMT control plane       3 passed
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
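The save/set/restore behavior of the ScopedDevice guard in item 4, combined with the rank → device lookup, can be sketched as a Python context manager. The real guard is C++ RAII over hipGetDevice/hipSetDevice; here a thread-local variable plays the role of the current device, and all function names are hypothetical analogues of the ones in the commit.

```python
import threading
from contextlib import contextmanager

_tls = threading.local()        # models per-thread HIP device state
_rank_to_device = {}
_map_lock = threading.Lock()    # the rank -> device map has its own mutex


def get_device():
    return getattr(_tls, "device", 0)


def set_device(device_id):
    _tls.device = device_id


def register_rank_device(rank, device_id):
    with _map_lock:
        _rank_to_device[rank] = device_id


def get_device_by_rank(rank):
    with _map_lock:
        return _rank_to_device[rank]


@contextmanager
def scoped_device(device_id):
    """Switch device for the duration of the block, then restore the old one."""
    saved = get_device()
    set_device(device_id)
    try:
        yield
    finally:
        set_device(saved)       # restored even if the handler raises


def ffi_handler(rank):
    with scoped_device(get_device_by_rank(rank)):
        return get_device()     # state lookups here would hit the rank's slot


register_rank_device(rank=3, device_id=5)
set_device(1)                   # pretend the framework bound this thread to device 1
seen_inside = ffi_handler(3)
seen_after = get_device()
```

The `finally` clause is what makes the change local to the handler call: the worker thread leaves with the same device it entered with.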
The earlier cleanup (d27a5df) deleted ShmemStatesStatus::Finalized as dead code on the rationale that ShmemFinalize() resets the slot to `New` to allow re-init, so Finalized was never actually set or checked.
That removal was over-aggressive. The Finalized state value is documented intent (a slot can be in three logical phases: never-init / live / torn-down) and someone might want terminal-finalize semantics later, for example to print a clearer diagnostic when a finalized slot is touched again, or to forbid re-init in a stricter deployment.
Restore both the enum value and the corresponding check in CheckStatusValid(). Leave ShmemFinalize() as-is (resets to New) so SPMT test suites that init/finalize multiple times keep working; if/when finalize semantics need to flip, only the line in ShmemFinalize() needs to change.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ests)
The C++ SPMT exploration example was useful as a proof-of-concept while
designing the SPMT implementation, but it has two structural problems
that make it a liability now:
1. False-pass behavior: when the cross-PE collective-permute kernel
produces wrong data, the example still marks the result as PASS
(line 194: `result.permute_pass = true; // expected under static
binary SPMT`). A "regression test" that always passes regardless of
correctness is worse than no test — it gives false confidence.
2. The kernel is documented to be incorrect: the example's own header
comments admit that under a statically-compiled HIP binary,
globalGpuStates is a single device symbol shared across all SPMT
threads, so the device-side kernel result "may not be correct under
SPMT". The collective-permute kernel exists but does nothing
meaningful as a regression check.
Coverage is now provided by two Python tests:
- tests/python/shmem/test_spmt.py — shmem control plane
(init/finalize/malloc/barrier) with real assertions
- tests/python/ops/test_dispatch_combine_jax_spmt.py — full EP
dispatch+combine round-trip with data verification
Both Python tests use JIT-loaded modules, so each GPU has its own
globalGpuStates and the kernels actually exercise SPMT correctly.
Drop the example and its CMake gate. Net change: -273 lines, no loss
of SPMT test coverage.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
CI's pre-commit hook flagged formatting drift on three of the SPMT commits in this PR. No semantic changes, purely:
- black: line wrapping in tests/python/shmem/test_spmt.py and tests/python/ops/test_dispatch_combine_jax_spmt.py
- clang-format: line wrapping in symmetric_memory.cpp, dispatch_combine/launch.cpp, and pybind_shmem.cpp
- cmake-format: trailing blank line in CMakeLists.txt
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Three small follow-up cleanups from a self-review pass on the SPMT PR. No behavior change for happy-path callers; pure robustness/cleanup.
1. python/mori/shmem/api.py: _current_hip_device() now sets explicit argtypes/restype on hip.hipGetDevice. Without these ctypes assumes int args + int return, which happens to be right on x86_64 Linux but is not portable. Be explicit so future ABI shifts don't silently corrupt the device id we read.
2. src/pybind/pybind_xla_ffi_ops.cpp: EpDispatchCombineInstantiate was decoding the packed ep_config twice: once to extract `rank` for SPMT device routing, and once on cache miss to construct the handle. Decode once and reuse. Saves a small amount of work on the cache-miss path; mostly a readability win.
3. include/mori/utils/mori_log.hpp: after the try/catch around spdlog::stdout_color_mt, fall back to spdlog::get(), but that call can in principle return null (e.g. if the registry was dropped between the throw and our second lookup). Bail out cleanly instead of dereferencing a null shared_ptr below.
Verified:
  SPMT JAX EP 2/4/8 GPU     3 passed in 24s
  Multi-process JAX EP      1 passed in 17s
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
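Item 1 above amounts to declaring the C prototype before calling through ctypes. A minimal sketch, assuming the caller passes in an already loaded HIP runtime library object; `current_hip_device` is a hypothetical helper written for illustration, not the body of MORI's `_current_hip_device()`.

```python
import ctypes


def current_hip_device(hip):
    """Read the calling thread's HIP device id through an explicit prototype.

    `hip` is assumed to be a ctypes library object that exposes hipGetDevice,
    whose C signature is: hipError_t hipGetDevice(int* deviceId).
    """
    fn = hip.hipGetDevice
    fn.argtypes = [ctypes.POINTER(ctypes.c_int)]  # one int* out-parameter
    fn.restype = ctypes.c_int                     # hipError_t, 0 is hipSuccess
    device_id = ctypes.c_int(-1)
    status = fn(ctypes.byref(device_id))
    if status != 0:
        raise RuntimeError(f"hipGetDevice failed with status {status}")
    return device_id.value


# Usage on a machine with ROCm installed (not executed here):
#   hip = ctypes.CDLL("libamdhip64.so")
#   print(current_hip_device(hip))
```

Checking the returned status rather than discarding it also fits the later cleanup in this PR that stops swallowing HIP errors.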
The previous SPMT JAX EP smoke test only checked that the FFI handlers
returned without crashing — it materialized one byte of the combine
output and called it a day. That's enough to catch deadlocks and gross
crashes, but it would silently pass even if dispatch routed tokens to
the wrong rank or combine produced garbage.
Port the dispatch/combine validation from
test_dispatch_combine_jax.py:
- _validate_dispatch: decode (sender_pe, local_tok_id) from
src_token_pos, look up the original input via
inputs_list[pe * inp_tok_per_rank + local_id], and check it matches
the dispatched output. Also check no two received tokens share the
same src_pos (no double-delivery).
- _validate_combine: each input token is dispatched to `unique_pes`
distinct PEs; combine sums the unique_pes copies, so combined output
should equal `input * unique_pes` (within bf16 atol/rtol).
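The combine check described above reduces to a few lines of arithmetic. The sketch below is not the ported `_validate_combine`; it assumes experts are laid out contiguously across ranks (so a token's destination PE is `expert_id // experts_per_rank`), uses plain Python lists in place of bf16 arrays, and picks illustrative tolerances.

```python
def unique_pes(expert_ids, experts_per_rank):
    """Number of distinct PEs a token is sent to (assumed contiguous layout)."""
    return len({e // experts_per_rank for e in expert_ids})


def has_double_delivery(src_positions):
    """True if two received tokens claim the same source position."""
    return len(set(src_positions)) != len(src_positions)


def validate_combine(combined, token, expert_ids, experts_per_rank,
                     rtol=1e-2, atol=1e-2):
    """Check combined == token * unique_pes within a loose tolerance."""
    n = unique_pes(expert_ids, experts_per_rank)
    expected = [x * n for x in token]
    return all(abs(c - e) <= atol + rtol * abs(e)
               for c, e in zip(combined, expected))


# A token routed to experts 0, 1 and 5 with 4 experts per rank reaches PEs
# 0 and 1, so combine should return twice the input.
token = [1.0, 2.0]
ok = validate_combine([2.0, 4.0], token, [0, 1, 5], experts_per_rank=4)
bad = validate_combine([3.0, 6.0], token, [0, 1, 5], experts_per_rank=4)
```

Note that two experts on the same PE count once, which is why the multiplier is the number of unique PEs rather than the top-k size.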
The multi-process test does cross-rank all-gather via shard_map +
jax.lax.all_gather. SPMT can't use that (no jax.distributed).
Instead, every rank generates inputs deterministically from
PRNGKey(BASE_SEED + rank), so each thread can locally reconstruct
every other rank's inputs by re-seeding — no cross-thread comm needed.
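The re-seeding idea can be shown without JAX. In this sketch `random.Random` from the standard library stands in for `PRNGKey(BASE_SEED + rank)`, and the seed value, token count, and hidden size are made-up constants chosen only for illustration.

```python
import random

BASE_SEED = 123        # assumption: any fixed value shared by all threads
TOKENS_PER_RANK = 4
HIDDEN = 3


def make_inputs(rank):
    """Inputs for one rank, fully determined by BASE_SEED + rank."""
    rng = random.Random(BASE_SEED + rank)
    return [[rng.random() for _ in range(HIDDEN)]
            for _ in range(TOKENS_PER_RANK)]


def reconstruct_all_inputs(world_size):
    """Rebuild every rank's inputs locally, with no cross-thread exchange."""
    inputs_list = []
    for rank in range(world_size):
        inputs_list.extend(make_inputs(rank))
    return inputs_list


# Any thread can index the flat list as pe * TOKENS_PER_RANK + local_id.
inputs_list = reconstruct_all_inputs(world_size=2)
token_from_rank1 = inputs_list[1 * TOKENS_PER_RANK + 2]
```

Because each generator is seeded independently, the reconstructed list is identical on every thread, which replaces the all-gather used by the multi-process test.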
Also add the env-var bypass set by the multi-process test:
MORI_SHMEM_HEAP_SIZE=16G (4G default is tight for 8-GPU EP)
XLA_FLAGS:
--xla_gpu_autotune_level=0 (skip slow first-JIT autotune)
--xla_gpu_enable_command_buffer= (disable HIP graph)
--xla_gpu_enable_triton_gemm=false (avoid Triton-AMDGPU pass errs)
Verified: 2/4/8 GPU all PASS with both "dispatch data verified" and
"combine data verified" printed per thread; clean exit; no Triton noise.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Black wrapped a few function signatures and call sites differently than I had them. Pure formatting; no behavior change. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
6011acb to 700ca53
…i-thread)
Enable AsyncLL+SDMA kernel path to work correctly in SPMT mode where multiple GPU threads share the same process address space.
Core changes:
- anvil: key sdma_channels_ by (srcDeviceId, dstDeviceId) pair instead of dstDeviceId alone, preventing channel cross-contamination between GPUs in the same process. Add mutex for thread-safe map access during concurrent shmem_init from multiple SPMT threads.
- symmetric_memory: for same-process peers (SPMT), exchange SDMA signal pointers via Allgather raw VA + hipDeviceEnablePeerAccess instead of hipIpcOpenMemHandle (which fails within the same process). Clear HIP sticky errors after hipDeviceEnablePeerAccess. On deregistration, close SDMA signal IPC handles for cross-process peers and free all SDMA GPU allocations (signalPtrs, expectSignalsPtr, peerSignalPtrs, deviceHandles_d).
- launch: fix AsyncLL kernel launch sequence in C++ path (used by JAX FFI) to match the split kernel names actually defined in ep_async_ll.hip. The previous code referenced non-existent combined kernel names.
- dispatch_combine: guard ~EpDispatchCombineHandle against use-after-finalize when XLA destroys cached FFI state after shmem_finalize.
- jax/ops.py: add AsyncLL to get_dispatch_src_token_pos kernel type list.
- tests: run each SPMT world_size in an isolated subprocess to ensure clean shmem lifecycle (AnvilLib singleton and KFD SDMA queues are not released by shmem_finalize). Add test_dispatch_combine_jax_spmt_sdma.py covering dispatch+combine E2E with data verification for world_size 2, 4, 8.
Tested:
- JAX SPMT IntraNode: 3 passed (world_size 2, 4, 8)
- JAX SPMT AsyncLL+SDMA: 3 passed (world_size 2, 4, 8)
- Torch multi-process IntraNode: passed
- Torch multi-process AsyncLL IBGDA: 68 passed
- Torch multi-process AsyncLL SDMA: 68 passed
Co-authored-by: Cursor <cursoragent@cursor.com>
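The anvil change in this commit (keying channels by source and destination device, behind a mutex) can be modeled with a small Python class. `SdmaChannelTable` and its dictionary "channels" are hypothetical; the sketch only shows why the pair key keeps two source GPUs in one process from sharing a channel to the same destination.

```python
import threading


class SdmaChannelTable:
    """Illustrative channel map keyed by (src_device_id, dst_device_id)."""

    def __init__(self):
        self._lock = threading.Lock()   # guards concurrent init from SPMT threads
        self._channels = {}

    def get_or_create(self, src_device_id, dst_device_id):
        key = (src_device_id, dst_device_id)
        with self._lock:
            if key not in self._channels:
                # Placeholder for real channel creation.
                self._channels[key] = {"src": src_device_id, "dst": dst_device_id}
            return self._channels[key]


table = SdmaChannelTable()
from_gpu0 = table.get_or_create(0, 1)
from_gpu2 = table.get_or_create(2, 1)   # same destination, different source
again = table.get_or_create(0, 1)
```

With the destination alone as the key, `from_gpu0` and `from_gpu2` would have resolved to the same entry, which is the cross-contamination the commit describes.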
700ca53 to 616c7b7
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-Authored-By: @pemeliya