Skip to content

feat(shmem): add MORI_MULTITHREAD_SUPPORT for SPMT#308

Open
jhchouuu wants to merge 14 commits into
mainfrom
jiahzhou/shmem-multithread-support
Open

feat(shmem): add MORI_MULTITHREAD_SUPPORT for SPMT#308
jhchouuu wants to merge 14 commits into
mainfrom
jiahzhou/shmem-multithread-support

Conversation

@jhchouuu
Copy link
Copy Markdown
Collaborator

@jhchouuu jhchouuu commented May 8, 2026

Co-Authored-By: @pemeliya

@jhchouuu jhchouuu requested review from i-chaochen and pemeliya May 9, 2026 07:44
// exit so XLA's other state isn't disturbed) before any HIP call. This
// ensures GetHandleCacheSlot() and ShmemStatesSingleton::GetInstance()
// (both keyed by hipGetDevice()) hit the right slot.
ScopedDevice _dev_guard(mori::shmem::ShmemStatesSingleton::GetDeviceByRank(cfg.rank));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been checked? I was under impression that xla ensures that the "context" device is set to the one the thread serves.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and could you please check it again? Maybe my understanding is incorrect.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the issue here is that, we cannot initialize MORI in a correct way (when SIMT is on), unless MORI is integrated to XLA. This is because XLA is managing GPU threads, so ShmemInit() shall be called by XLA intentally when GPU communicators are created

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my original branch, I added some hacky way for MORI initialization to see if it generally works


namespace mori {

inline constexpr int kMaxGpusPerNode = 8;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with CPX it is 32. Not sure if it can be split even more than that.

Copy link
Copy Markdown
Contributor

@i-chaochen i-chaochen May 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will have 32 72 GPUs at rack level. It's best to not hardcode this or we should get this max number of GPU from the build script?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, CPX and rack level will have more than 8 GPUs. We have already tried this on CPX, but it was limited to a single card before... Additionally, the rack level is also included in our plan...
So currently, kMaxGpusPerNode equals to 8...

if (i == rank) continue;
void* mappedPtr = nullptr;
HIP_RUNTIME_CHECK(
hipIpcOpenMemHandle(&mappedPtr, signalHandles[i], hipIpcMemLazyEnablePeerAccess));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tried running my XLA SPMT test - and getting the issue here:

[/data/mori/src/application/memory/symmetric_memory.cpp:251] hip failed with invalid device context

I think we shall guard this block by SameProcessP2P(i)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your XLA SPMT test is related to SDMA? I think I should add a verification for the SDMA path to my TODO list... I haven't verified SDMA...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, I use simple SDMA-based collectives there. But I think RegisterSymmMemObj can be used also without SDMA transport

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I'm continuing my work on SDMA. I will verify the SDMA path through the async_ll EP kernel.

jhchouuu and others added 11 commits May 13, 2026 10:13
…read (SPMT)

- Add MORI_MULTITHREAD_SUPPORT cmake option (default OFF)
- ShmemStatesSingleton::GetInstance() returns per-GPU slot via hipGetDevice()
  using std::array<ShmemStates, 8> for stable addresses and lock-free reads
- Embed GpuStates and ModuleStates as values in ShmemStates (no heap alloc)
- Remove file-scope globals: s_hostGpuStatesCopy, s_shmemModule,
  s_deviceGpuStatesAddr, s_barrierFunc
- All init/finalize functions take explicit ShmemStates* (thread-safe)
- ShmemFinalize resets status to New to allow re-init on same GPU
- PID allgather in CollectHostNames for same-process P2P detection
- Skip hipIpcOpenMemHandle for same-process peers (use direct pointer)
- Python: per-GPU JIT module loading with double-check locking
- Python: GIL release on all blocking shmem APIs
- mori_log: try/catch spdlog race on concurrent logger registration
- Add tests/python/shmem/test_spmt.py (world_size 2/4/8)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- pybind_shmem: remove duplicate m.def for shmem_finalize / shmem_mype /
  shmem_npes / shmem_torch_process_group_init that overrode the GIL-release
  variants and effectively serialized SPMT finalize.
- symmetric_memory: explicitly hipDeviceEnablePeerAccess for same-process
  peers. The IPC-handle path (lazy enable via hipIpcMemLazyEnablePeerAccess)
  is skipped for same-process, so without this fix P2P-only SPMT would hit
  invalid-device-pointer at peer access time. Use hipPointerGetAttributes
  to discover the peer's device id without assuming a rank-to-device map.
- init.cpp ShmemFinalize: run FinalizeInternalSync before FinalizeGpuStates.
  FinalizeGpuStates calls FinalizeRuntime which clears states->gpuStates,
  including internalSyncPtr — running it first made FinalizeInternalSync
  early-return and leak the sync memory each init/finalize cycle.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…(SPMT) JAX

Make the XLA FFI EP path SPMT-safe so JAX users can drive multiple GPUs
from one Python process (jax.devices() returns all 8) instead of being
forced into the multi-process jax.distributed.initialize model. All
changes are gated by MORI_MULTITHREAD_SUPPORT; the macro-OFF and
multi-process paths are byte-identical to before.

Three structural caches that used hipModuleLoad-bound resources were
process-global singletons. Under SPMT each thread is bound to its own
device, so a singleton hands the wrong device's resources to other
threads and the EP launches crash or silently use the wrong context:

- KernelRegistry::GetImpl: was a static singleton holding loaded
  hipModule_t. Now std::array<Impl, 8> indexed by hipGetDevice(). In
  multi-process every process sees its single GPU as device 0 and
  collapses to slot[0] — equivalent to the old singleton.

- pybind_xla_ffi_ops g_handle_cache: was process-global with one
  mutex. Under SPMT this would deadlock — thread A holds the mutex
  during EpDispatchCombineHandle's cross-PE Barrier(), thread B blocks
  waiting for the mutex and never reaches its own Barrier(). Replaced
  with per-GPU HandleCacheSlot (own mutex + map).

- ShmemStatesSingleton rank→device map: XLA FFI handlers run on
  framework worker threads where hipGetDevice() does NOT match the
  rank's device. Added RegisterRankDevice/GetDeviceByRank, populated
  in InitializeBootStates from the user thread's device. FFI handlers
  (Instantiate + Impl) now look up the rank's device and hipSetDevice
  to it before any state access.

python/mori/shmem/api.py:_ensure_shmem_module no longer imports torch
to read the current device — that import broke JAX containers that
ship without torch. Now uses ctypes hipGetDevice via the existing
mori.jit.hip_driver helper, working uniformly for both torch and JAX.

Adds tests/python/ops/test_dispatch_combine_jax_spmt.py: spawns one
host thread per GPU in a single Python process, each thread runs
ShmemInit + EpDispatchCombineOp dispatch+combine via XLA FFI. Verifies
the full SPMT JAX path end-to-end. Per-thread shmem_finalize +
clear_ep_handle_cache so the parametrized 2/4/8 GPU sizes can run in
one pytest invocation.

Verified:
  MORI-EP JAX SPMT 2/4/8 GPU            3 passed in 22s
  MORI-EP JAX multi-process (existing)  1 passed in 17s (no regression)
  MORI-EP intranode (torchrun)          115 passed, 208 skipped
  MORI-SPMT shmem control plane         3 passed in 3s

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…slot

Under SPMT, each thread owns one GPU and one cache slot. The previous
implementation iterated ALL per-GPU slots from any caller, so when
thread 0 called clear_ep_handle_cache it ran ~EpDispatchCombineHandle
on thread 1's, 2's, ... handles too. Each ~Handle calls ShmemFree on
the buffers it allocated on its own GPU's symmetric heap, but the
calling thread's hipDevice was still 0, so ShmemFree looked up those
addresses in GPU 0's HeapVAManager and reported "address not found"
hundreds of times before the test process eventually aborted with
SIGABRT during teardown.

Fix: only clear the slot returned by GetHandleCacheSlot() (the calling
thread's slot under SPMT, the global slot in single-GPU mode). Each
SPMT thread is responsible for clearing its own cache as part of its
shmem_finalize sequence — same pattern as ShmemStatesSingleton.

Also: in tests/python/ops/test_dispatch_combine_jax_spmt.py add the
gc.collect() between cache clear and shmem_finalize (mirrors
mori.jax.shmem_finalize). Do NOT call jax.clear_caches() — it is
process-global and races across SPMT threads.

After this fix the SPMT JAX EP test exits 0 cleanly with zero
HeapVAManager errors, across 5 consecutive 2/4/8 GPU runs.
Multi-process JAX EP regression unaffected.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ce guard

Self-review cleanups on top of the SPMT branch.

1. Centralize kMaxGpusPerNode in mori/utils/limits.hpp. Previously the
   constant was duplicated in three places (internal.hpp, launch.cpp,
   pybind_xla_ffi_ops.cpp) and bumping it for >8-GPU nodes (e.g. future
   MI400) would have required editing all three. Now a single
   inline constexpr that any TU can pull in cheaply.

2. Drop dead ShmemStatesSingleton::mutex_ field. The comment claimed
   "we still take a brief lock the first time to guard concurrent
   ShmemInit", but no code actually locks it. SPMT's contract is one
   thread per GPU, so each slot is accessed serially by its owner
   thread; cross-thread synchronization is only needed for the rank →
   device map below, which has its own mutex.

3. Drop dead ShmemStatesStatus::Finalized. ShmemFinalize resets to New
   (so the slot can be reused for re-init), so the Finalized state was
   never actually set, which made the check `if (status == Finalized)`
   in ShmemInit dead code. Removed both.

4. RAII ScopedDevice guard for XLA FFI handlers. The previous code
   called hipSetDevice(rank_dev) directly, leaving XLA's worker thread
   bound to a different device than what XLA had set on entry — a
   subtle violation of the convention that XLA owns its worker thread
   state. ScopedDevice restores the saved device on scope exit so the
   change is local to the FFI handler call.

5. Also: drop spurious (void)hipGetLastError() calls in three lookups.
   They were silently swallowing real errors from prior unchecked HIP
   calls. hipGetDevice / hipSetDevice return their own status; sticky
   errors only surface on next hipDeviceSynchronize / kernel launch.

6. Test doc fix: clarify why _build_config passes gpu_per_node ==
   world_size (single-node SPMT, EP handle requires
   IsPowerOf2(gpuPerNode) && worldSize % gpuPerNode == 0). I briefly
   tried to "fix" this to use the physical GPU count and tripped the
   assertion — kept the original behavior with explanatory comments.

Verified:
  SPMT JAX EP 2/4/8 GPU      5/5 runs clean exit
  Multi-process JAX EP        no regression
  intranode EP (torchrun)     115 passed
  shmem SPMT control plane    3 passed

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The earlier cleanup (d27a5df) deleted ShmemStatesStatus::Finalized as
dead code on the rationale that ShmemFinalize() resets the slot to `New`
to allow re-init, so Finalized was never actually set or checked.

That removal was over-aggressive. The Finalized state value is documented
intent (a slot can be in three logical phases: never-init / live /
torn-down) and someone might want terminal-finalize semantics later — for
example, to print a clearer diagnostic when a finalized slot is touched
again, or to forbid re-init in a stricter deployment.

Restore both the enum value and the corresponding check in
CheckStatusValid(). Leave ShmemFinalize() as-is (resets to New) so SPMT
test suites that init/finalize multiple times keep working; if/when
finalize semantics need to flip, only the line in ShmemFinalize() needs
to change.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ests)

The C++ SPMT exploration example was useful as a proof-of-concept while
designing the SPMT implementation, but it has two structural problems
that make it a liability now:

1. False-pass behavior: when the cross-PE collective-permute kernel
   produces wrong data, the example still marks the result as PASS
   (line 194: `result.permute_pass = true; // expected under static
   binary SPMT`). A "regression test" that always passes regardless of
   correctness is worse than no test — it gives false confidence.

2. The kernel is documented to be incorrect: the example's own header
   comments admit that under a statically-compiled HIP binary,
   globalGpuStates is a single device symbol shared across all SPMT
   threads, so the device-side kernel result "may not be correct under
   SPMT". The collective-permute kernel exists but does nothing
   meaningful as a regression check.

Coverage is now provided by two Python tests:
  - tests/python/shmem/test_spmt.py — shmem control plane
    (init/finalize/malloc/barrier) with real assertions
  - tests/python/ops/test_dispatch_combine_jax_spmt.py — full EP
    dispatch+combine round-trip with data verification

Both Python tests use JIT-loaded modules, so each GPU has its own
globalGpuStates and the kernels actually exercise SPMT correctly.

Drop the example and its CMake gate. Net change: -273 lines, no loss
of SPMT test coverage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
CI's pre-commit hook flagged formatting drift on three of the SPMT
commits in this PR. No semantic changes — purely:

- black: line wrapping in tests/python/shmem/test_spmt.py and
  tests/python/ops/test_dispatch_combine_jax_spmt.py
- clang-format: line wrapping in symmetric_memory.cpp,
  dispatch_combine/launch.cpp, and pybind_shmem.cpp
- cmake-format: trailing blank line in CMakeLists.txt

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Three small follow-up cleanups from a self-review pass on the SPMT PR.
No behavior change for happy-path callers; pure robustness/cleanup.

1. python/mori/shmem/api.py: _current_hip_device() now sets explicit
   argtypes/restype on hip.hipGetDevice. Without these ctypes assumes
   int args + int return, which happens to be right on x86_64 Linux
   but is not portable. Be explicit so future ABI shifts don't
   silently corrupt the device id we read.

2. src/pybind/pybind_xla_ffi_ops.cpp: EpDispatchCombineInstantiate
   was decoding the packed ep_config twice — once to extract `rank`
   for SPMT device routing, and once on cache miss to construct the
   handle. Decode once and reuse. Saves a small amount of work on
   the cache-miss path; mostly a readability win.

3. include/mori/utils/mori_log.hpp: after the try/catch around
   spdlog::stdout_color_mt, fall back to spdlog::get() — but that
   call can in principle return null (e.g. if the registry was
   dropped between the throw and our second lookup). Bail out
   cleanly instead of dereferencing a null shared_ptr below.

Verified:
  SPMT JAX EP 2/4/8 GPU         3 passed in 24s
  Multi-process JAX EP          1 passed in 17s

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous SPMT JAX EP smoke test only checked that the FFI handlers
returned without crashing — it materialized one byte of the combine
output and called it a day. That's enough to catch deadlocks and gross
crashes, but it would silently pass even if dispatch routed tokens to
the wrong rank or combine produced garbage.

Port the dispatch/combine validation from
test_dispatch_combine_jax.py:

- _validate_dispatch: decode (sender_pe, local_tok_id) from
  src_token_pos, look up the original input via
  inputs_list[pe * inp_tok_per_rank + local_id], and check it matches
  the dispatched output. Also check no two received tokens share the
  same src_pos (no double-delivery).

- _validate_combine: each input token is dispatched to `unique_pes`
  distinct PEs; combine sums the unique_pes copies, so combined output
  should equal `input * unique_pes` (within bf16 atol/rtol).

The multi-process test does cross-rank all-gather via shard_map +
jax.lax.all_gather. SPMT can't use that (no jax.distributed).
Instead, every rank generates inputs deterministically from
PRNGKey(BASE_SEED + rank), so each thread can locally reconstruct
every other rank's inputs by re-seeding — no cross-thread comm needed.

Also add the env-var bypass set by the multi-process test:
  MORI_SHMEM_HEAP_SIZE=16G   (4G default is tight for 8-GPU EP)
  XLA_FLAGS:
    --xla_gpu_autotune_level=0           (skip slow first-JIT autotune)
    --xla_gpu_enable_command_buffer=     (disable HIP graph)
    --xla_gpu_enable_triton_gemm=false   (avoid Triton-AMDGPU pass errs)

Verified: 2/4/8 GPU all PASS with both "dispatch data verified" and
"combine data verified" printed per thread; clean exit; no Triton noise.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Black wrapped a few function signatures and call sites differently
than I had them. Pure formatting; no behavior change.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@jhchouuu jhchouuu force-pushed the jiahzhou/shmem-multithread-support branch from 6011acb to 700ca53 Compare May 19, 2026 09:32
…i-thread)

Enable AsyncLL+SDMA kernel path to work correctly in SPMT mode where
multiple GPU threads share the same process address space.

Core changes:

- anvil: key sdma_channels_ by (srcDeviceId, dstDeviceId) pair instead
  of dstDeviceId alone, preventing channel cross-contamination between
  GPUs in the same process. Add mutex for thread-safe map access during
  concurrent shmem_init from multiple SPMT threads.

- symmetric_memory: for same-process peers (SPMT), exchange SDMA signal
  pointers via Allgather raw VA + hipDeviceEnablePeerAccess instead of
  hipIpcOpenMemHandle (which fails within the same process). Clear HIP
  sticky errors after hipDeviceEnablePeerAccess. On deregistration, close
  SDMA signal IPC handles for cross-process peers and free all SDMA GPU
  allocations (signalPtrs, expectSignalsPtr, peerSignalPtrs, deviceHandles_d).

- launch: fix AsyncLL kernel launch sequence in C++ path (used by JAX
  FFI) to match the split kernel names actually defined in ep_async_ll.hip.
  The previous code referenced non-existent combined kernel names.

- dispatch_combine: guard ~EpDispatchCombineHandle against use-after-
  finalize when XLA destroys cached FFI state after shmem_finalize.

- jax/ops.py: add AsyncLL to get_dispatch_src_token_pos kernel type list.

- tests: run each SPMT world_size in an isolated subprocess to ensure
  clean shmem lifecycle (AnvilLib singleton and KFD SDMA queues are not
  released by shmem_finalize). Add test_dispatch_combine_jax_spmt_sdma.py
  covering dispatch+combine E2E with data verification for world_size
  2, 4, 8.

Tested: - JAX SPMT IntraNode: 3 passed (world_size 2, 4, 8)
  - JAX SPMT AsyncLL+SDMA: 3 passed (world_size 2, 4, 8)
  - Torch multi-process IntraNode: passed
  - Torch multi-process AsyncLL IBGDA: 68 passed
  - Torch multi-process AsyncLL SDMA: 68 passed
Co-authored-by: Cursor <cursoragent@cursor.com>
@jhchouuu jhchouuu force-pushed the jiahzhou/shmem-multithread-support branch from 700ca53 to 616c7b7 Compare May 19, 2026 09:48
jhchouuu and others added 2 commits May 19, 2026 09:51
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants