diff --git a/CMakeLists.txt b/CMakeLists.txt index 059807e8..12ea7c3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,8 @@ option(ENABLE_PROFILER "Enable kernel profiling" OFF) option(ENABLE_DEBUG_PRINTF "Enable debug printf in device kernels" OFF) option(ENABLE_STANDARD_MOE_ADAPT "Enable standard moe adapt" OFF) option(BUILD_OPS_DEVICE "AOT compile EP kernels to .hsaco(requires hipcc)" OFF) +option(MORI_MULTITHREAD_SUPPORT + "Enable single-process multi-thread multi-GPU support in mori_shmem" OFF) # --------------------------------------------------------------------------- # Host NIC: all vendor-specific dv libraries (mlx5, bnxt_re, ionic) are loaded diff --git a/include/mori/application/application_device_types.hpp b/include/mori/application/application_device_types.hpp index 87faec10..589f84cf 100644 --- a/include/mori/application/application_device_types.hpp +++ b/include/mori/application/application_device_types.hpp @@ -109,6 +109,10 @@ struct SymmMemObj { // SdmaPutThread writes ATOMIC to peerSignalPtrs[remotePe] + myPe*sdmaNumQueue + qId, // so the remote PE can directly read its own signalPtrs to detect completion. HSAuint64** peerSignalPtrs = nullptr; // should only placed on GPU + // Host-side copy of peer signal pointers for IPC cleanup during deregistration. + // Only entries opened via hipIpcOpenMemHandle need closing; same-process (SPMT) + // entries are raw VA and must NOT be closed. 
+ HSAuint64** peerSignalPtrsHost = nullptr; // should only be placed on CPU __device__ __host__ RdmaMemoryRegion GetRdmaMemoryRegion(int pe) const { RdmaMemoryRegion mr; diff --git a/include/mori/application/context/context.hpp b/include/mori/application/context/context.hpp index 51a599e1..36acb428 100644 --- a/include/mori/application/context/context.hpp +++ b/include/mori/application/context/context.hpp @@ -39,10 +39,10 @@ class Context { int LocalRank() const { return bootNet.GetLocalRank(); } int WorldSize() const { return bootNet.GetWorldSize(); } int LocalRankInNode() const { return rankInNode; } - std::string HostName() const; + const std::string& HostName() const { return myHostname; } TransportType GetTransportType(int destRank) const { return transportTypes[destRank]; } - std::vector GetTransportTypes() const { return transportTypes; } + const std::vector& GetTransportTypes() const { return transportTypes; } int GetNumQpPerPe() const { return numQpPerPe; } RdmaContext* GetRdmaContext() const { return rdmaContext.get(); } @@ -51,6 +51,8 @@ class Context { // Check if P2P connection is possible with a peer (same node) bool CanUseP2P(int destRank) const; + // Check if peer is in the same OS process (enables direct pointer access, skip IPC handle) + bool SameProcessP2P(int destRank) const; // Cached env-var snapshot taken at construction time. All later code MUST // consult these (not getenv) so that env-var changes after Context init @@ -67,6 +69,11 @@ class Context { void CollectHostNames(); void InitializePossibleTransports(); + struct PeerInfo { + bool sameHost{false}; // on the same node (same hostname) + bool sameProcess{false}; // in the same OS process (same pid + same host) + }; + private: BootstrapNetwork& bootNet; int rankInNode{-1}; @@ -74,7 +81,8 @@ class Context { // Snapshotted at construction; see IsSdmaEnabled() / IsP2PDisabled() above. 
bool sdmaEnabled{false}; bool p2pDisabled{false}; - std::vector hostnames; + std::string myHostname; + std::vector peerInfos; std::vector transportTypes; std::unique_ptr rdmaContext{nullptr}; diff --git a/include/mori/application/transport/sdma/anvil.hpp b/include/mori/application/transport/sdma/anvil.hpp index 2b86c988..89a707fc 100644 --- a/include/mori/application/transport/sdma/anvil.hpp +++ b/include/mori/application/transport/sdma/anvil.hpp @@ -102,8 +102,16 @@ class AnvilLib { int getSdmaEngineId(int srcDeviceId, int dstDeviceId); + struct PairHash { + std::size_t operator()(const std::pair& p) const { + return std::hash()(p.first) ^ (std::hash()(p.second) << 16); + } + }; + std::once_flag init_flag; - std::unordered_map>> sdma_channels_; + std::mutex channels_mutex_; + std::unordered_map, std::vector>, PairHash> + sdma_channels_; }; extern AnvilLib& anvil; diff --git a/include/mori/shmem/internal.hpp b/include/mori/shmem/internal.hpp index 0e76037c..c389e148 100644 --- a/include/mori/shmem/internal.hpp +++ b/include/mori/shmem/internal.hpp @@ -25,10 +25,12 @@ #include "mori/application/application_device_types.hpp" #include "mori/core/utils.hpp" #include "mori/hip_compat.hpp" +#include "mori/utils/limits.hpp" // Host-only includes: STL, ibverbs, application management classes. // Guarded so device compilation units (.hip files) do not pull them in. 
#if !defined(__HIPCC__) && !defined(__CUDACC__) +#include #include #include #include @@ -95,38 +97,10 @@ struct MemoryStates { application::SymmMemObjPtr vmmHeapObj; // SymmMemObj for the entire heap }; -enum ShmemStatesStatus { - New = 0, - Initialized = 1, - Finalized = 2, -}; - -struct ShmemStates { - ShmemStatesStatus status{ShmemStatesStatus::New}; - ShmemMode mode{ShmemMode::StaticHeap}; // Default to static heap mode - BootStates* bootStates{nullptr}; - RdmaStates* rdmaStates{nullptr}; - MemoryStates* memoryStates{nullptr}; - - // This is a temporary API for debugging only - void CheckStatusValid() { - if (status == ShmemStatesStatus::New) { - std::cout - << "Shmem state is not initialized, initialize it by calling ShmemMpiInitialize first." - << std::endl; - assert(false); - } - if (status == ShmemStatesStatus::Finalized) { - std::cout << "Shmem state has been finalized." << std::endl; - assert(false); - } - } -}; - #endif // !defined(__HIPCC__) && !defined(__CUDACC__) /* ---------------------------------------------------------------------------------------------- */ -/* Device-safe GPU-side structures */ +/* Device-safe GPU-side structures */ /* ---------------------------------------------------------------------------------------------- */ // GPU-side RDMA endpoint: only the fields used by device kernels. @@ -155,6 +129,7 @@ struct ShmemRdmaEndpoint { } }; +// GpuStates must be declared before ModuleStates and ShmemStates which embed it. struct GpuStates { int rank{-1}; int worldSize{-1}; @@ -196,19 +171,82 @@ struct RemoteAddrInfo { #if !defined(__HIPCC__) && !defined(__CUDACC__) +enum ShmemStatesStatus { + New = 0, + Initialized = 1, + // Finalized: reserved. ShmemFinalize() currently resets the slot to `New` + // so the same GPU can be re-initialized later (needed by SPMT test suites + // that run multiple init/finalize cycles). Keep this value for the case + // where future finalize semantics need to mark the slot as terminally done. 
+ Finalized = 2, +}; + +// Per-GPU JIT module state (HIP module handle + device symbol pointers) +struct ModuleStates { + hipModule_t module{nullptr}; + GpuStates* gpuStatesPtr{nullptr}; // device-side globalGpuStates address in JIT module + hipFunction_t barrierFunc{nullptr}; +}; + +struct ShmemStates { + ShmemStatesStatus status{ShmemStatesStatus::New}; + ShmemMode mode{ShmemMode::StaticHeap}; // Default to static heap mode + BootStates* bootStates{nullptr}; + RdmaStates* rdmaStates{nullptr}; + MemoryStates* memoryStates{nullptr}; + ModuleStates moduleStates; // JIT module state for this GPU + GpuStates gpuStates; // host-side copy of device GpuStates for this GPU + + // Asserts that ShmemInit has been called and the slot is currently usable. + // Used by APIs that touch GPU state (allocation, barrier, module init) + // which need a fully constructed slot. + void CheckStatusValid() { + if (status == ShmemStatesStatus::New) { + std::cout << "Shmem state is not initialized, call ShmemInit*/shmem_init_attr first." + << std::endl; + assert(false); + } + if (status == ShmemStatesStatus::Finalized) { + std::cout << "Shmem state has been finalized." << std::endl; + assert(false); + } + } +}; + // Internal functions shared between init.cpp and runtime.cpp -void CopyGpuStatesToDevice(const GpuStates* gpuStates); -void FinalizeRuntime(); -extern GpuStates s_hostGpuStatesCopy; +void CopyGpuStatesToDevice(ShmemStates* states); +void FinalizeRuntime(ShmemStates* states); class ShmemStatesSingleton { public: ShmemStatesSingleton(const ShmemStatesSingleton& obj) = delete; - static ShmemStates* GetInstance() { - static ShmemStates states; - return &states; - } + static ShmemStates* GetInstance(); + +#ifdef MORI_MULTITHREAD_SUPPORT + // SPMT: rank → HIP device id mapping, populated at ShmemInit. + // + // Needed by FFI/custom-call handlers (e.g. XLA) that run on framework worker + // threads where hipGetDevice() does not return the rank's device. 
The handler + // can look up the device for a given rank and hipSetDevice() to it before + // touching MORI state. + // + // Returns -1 if no rank-to-device mapping has been recorded yet (caller + // should fall back to hipGetDevice()-based lookup or fail loudly). + static void RegisterRankDevice(int rank, int deviceId); + static int GetDeviceByRank(int rank); +#endif + + private: +#ifdef MORI_MULTITHREAD_SUPPORT + // One ShmemStates slot per GPU, indexed by hipGetDevice(). + // std::array gives stable addresses (no reallocation, unlike vector). + // No lock needed: SPMT contract is one thread per GPU, so each slot is + // accessed serially by its owning thread; the rank → device map below is + // the only structure that needs cross-thread synchronization. + std::array states_{}; + ShmemStatesSingleton() = default; +#endif }; #endif // !defined(__HIPCC__) && !defined(__CUDACC__) diff --git a/include/mori/utils/limits.hpp b/include/mori/utils/limits.hpp new file mode 100644 index 00000000..f6236e50 --- /dev/null +++ b/include/mori/utils/limits.hpp @@ -0,0 +1,33 @@ +// Copyright © Advanced Micro Devices, Inc. All rights reserved. +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +#pragma once + +// Tiny header so any TU (device or host) can size per-GPU arrays without +// pulling in shmem/ops headers. Bump this when supporting nodes with > 8 GPUs +// (e.g. future MI400 platforms) — every per-GPU array sized by this constant +// will pick up the new value automatically. + +namespace mori { + +inline constexpr int kMaxGpusPerNode = 8; + +} // namespace mori diff --git a/include/mori/utils/mori_log.hpp b/include/mori/utils/mori_log.hpp index 4fba3ace..e6700dce 100644 --- a/include/mori/utils/mori_log.hpp +++ b/include/mori/utils/mori_log.hpp @@ -66,9 +66,18 @@ class ModuleLogger { // Use existing logger logger = existing_logger; } else { - // Create new logger - logger = spdlog::stdout_color_mt(moduleName); - logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%P] [%n] [%^%l%$] %v"); + // spdlog::stdout_color_mt throws if another thread already registered the same name + // between our spdlog::get() check and this call — catch and fall back to the winner. + try { + logger = spdlog::stdout_color_mt(moduleName); + logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%P] [%n] [%^%l%$] %v"); + } catch (const spdlog::spdlog_ex&) { + logger = spdlog::get(moduleName); + } + // Defensive: spdlog::get may still return null if registration was + // dropped between throw and our second lookup. Bail out cleanly + // instead of dereferencing a null shared_ptr below. 
+ if (!logger) return; } // Determine the log level priority: env var > global setting > provided level diff --git a/python/mori/jax/ops.py b/python/mori/jax/ops.py index c7e2f2ea..0f50c846 100755 --- a/python/mori/jax/ops.py +++ b/python/mori/jax/ops.py @@ -173,8 +173,8 @@ def get_dispatch_src_token_pos(self, total_recv_token_num): cpp.EpDispatchCombineKernelType.IntraNode.value, cpp.EpDispatchCombineKernelType.InterNodeV1.value, cpp.EpDispatchCombineKernelType.InterNodeV1LL.value, + cpp.EpDispatchCombineKernelType.AsyncLL.value, ): - # here we need to allocate enough space to accomodate handle->totalRecvTokenNum[0] items n_tokens = self.config.max_num_tokens_to_recv() return jax.ffi.ffi_call( "mori_ep", diff --git a/python/mori/shmem/api.py b/python/mori/shmem/api.py index 529e1ea1..efdd96a7 100644 --- a/python/mori/shmem/api.py +++ b/python/mori/shmem/api.py @@ -19,25 +19,64 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import ctypes +import threading + from mori import cpp as mori_cpp # Initialization flags MORI_SHMEM_INIT_WITH_MPI_COMM = mori_cpp.MORI_SHMEM_INIT_WITH_MPI_COMM MORI_SHMEM_INIT_WITH_UNIQUEID = mori_cpp.MORI_SHMEM_INIT_WITH_UNIQUEID -_shmem_module_loaded = False +# Per-GPU module loading: keyed by device ID so that each GPU context gets its +# own hipModuleLoad call even when multiple threads share one process (SPMT). +_shmem_module_lock = threading.Lock() +_shmem_module_loaded_gpus: set = set() +# Cached hsaco path (compilation is arch-specific, not instance-specific). +_shmem_hsaco: str = "" + + +def _current_hip_device() -> int: + """Return the calling thread's current HIP device id. + + Uses ctypes against libamdhip64 directly so that the JAX path (which has + no torch dependency) works the same as the PyTorch path. 
+ """ + from mori.jit.hip_driver import _get_hip_lib + + hip = _get_hip_lib() + # Set explicit ctypes signatures — without these, ctypes assumes int args + # and int return, which happens to be right on x86_64 Linux but is not + # portable. Be explicit so future ABI changes don't silently break us. + hip.hipGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)] + hip.hipGetDevice.restype = ctypes.c_int + dev = ctypes.c_int(-1) + err = hip.hipGetDevice(ctypes.byref(dev)) + if err != 0: + raise RuntimeError(f"hipGetDevice failed with error {err}") + return int(dev.value) def _ensure_shmem_module(): - """JIT-compile and load the shmem device module before ShmemInit.""" - global _shmem_module_loaded - if _shmem_module_loaded: + """JIT-compile and load the shmem device module before ShmemInit. + + Thread-safe: each GPU device context gets exactly one load_shmem_module + call, enabling single-process multi-thread (SPMT) use where each thread + owns a different GPU. + """ + device_id = _current_hip_device() + if device_id in _shmem_module_loaded_gpus: return - from mori.jit.core import compile_genco + with _shmem_module_lock: + if device_id in _shmem_module_loaded_gpus: + return + global _shmem_hsaco + if not _shmem_hsaco: + from mori.jit.core import compile_genco - hsaco = compile_genco("shmem_kernels") - mori_cpp.load_shmem_module(hsaco) - _shmem_module_loaded = True + _shmem_hsaco = compile_genco("shmem_kernels") + mori_cpp.load_shmem_module(_shmem_hsaco) + _shmem_module_loaded_gpus.add(device_id) def shmem_torch_process_group_init(group_name: str): @@ -124,7 +163,17 @@ def shmem_finalize(): Returns: Status code (0 for success) """ - return mori_cpp.shmem_finalize() + ret = mori_cpp.shmem_finalize() + # Clear this GPU's module-loaded flag so a subsequent shmem_init_attr + # call (e.g. in the next test round) will reload the JIT module. + try: + device_id = _current_hip_device() + except Exception: + # If HIP context is gone (e.g. process teardown), skip cache cleanup. 
+ return ret + with _shmem_module_lock: + _shmem_module_loaded_gpus.discard(device_id) + return ret def shmem_module_init(hip_module: int): diff --git a/setup.py b/setup.py index 3b2ccdad..c1129ca3 100644 --- a/setup.py +++ b/setup.py @@ -354,6 +354,7 @@ def build_extension(self, ext: Extension) -> None: enable_debug_printf = os.environ.get("ENABLE_DEBUG_PRINTF", "OFF") enable_standard_moe_adapt = os.environ.get("ENABLE_STANDARD_MOE_ADAPT", "OFF") + multithread_support = os.environ.get("MORI_MULTITHREAD_SUPPORT", "OFF") gpu_archs = _get_gpu_archs() print(f"[mori] GPU architecture: {gpu_archs}") build_examples = os.environ.get("BUILD_EXAMPLES", "OFF") @@ -394,6 +395,7 @@ def build_extension(self, ext: Extension) -> None: "-DBUILD_TORCH_BOOTSTRAP=OFF", f"-DBUILD_XLA_FFI_OPS={build_xla_ffi_ops}", f"-DBUILD_OPS_DEVICE={build_ops_device}", + f"-DMORI_MULTITHREAD_SUPPORT={multithread_support}", "-B", str(build_dir), "-S", diff --git a/src/application/context/context.cpp b/src/application/context/context.cpp index cbbc9063..65da856f 100644 --- a/src/application/context/context.cpp +++ b/src/application/context/context.cpp @@ -89,16 +89,18 @@ std::string GetLocalIP() { return localIP; } -std::string Context::HostName() const { return hostnames[LocalRank()]; } - bool Context::CanUseP2P(int destRank) const { if (destRank == LocalRank()) { return false; // Cannot use P2P with self } - // Check if on the same node by comparing hostnames - // Note: IsP2PDisabled only affects transport type selection (peerPtrs), - // but we still maintain P2P data path in p2pPeerPtrs - return HostName() == hostnames[destRank]; + return peerInfos[destRank].sameHost; +} + +bool Context::SameProcessP2P(int destRank) const { + if (destRank == LocalRank()) { + return false; + } + return peerInfos[destRank].sameProcess; } void Context::CollectHostNames() { @@ -107,24 +109,33 @@ void Context::CollectHostNames() { // Keep node identity stable across ranks on the same machine. 
// Using hostname+IP can split local ranks when different NICs are selected. - std::string hostIdentifier = std::string(hostname); - constexpr int IDENTIFIER_MAX = HOST_NAME_MAX + INET_ADDRSTRLEN; - std::vector globalIdentifiers(IDENTIFIER_MAX * WorldSize()); - // Create a non-const buffer for Allgather - char localBuffer[IDENTIFIER_MAX]; - strncpy(localBuffer, hostIdentifier.c_str(), IDENTIFIER_MAX - 1); - localBuffer[IDENTIFIER_MAX - 1] = '\0'; - bootNet.Allgather(localBuffer, globalIdentifiers.data(), IDENTIFIER_MAX); + // Pack pid + hostname into a fixed-size buffer for Allgather. + // Using a fixed layout avoids string parsing ambiguity. + constexpr int kPidSize = sizeof(pid_t); + constexpr int kStrMax = HOST_NAME_MAX + 1; // +1 for '\0' + constexpr int kRecordSize = kPidSize + kStrMax; - for (int i = 0; i < WorldSize(); i++) { - hostnames.push_back(&globalIdentifiers.data()[i * IDENTIFIER_MAX]); - } + pid_t myPid = getpid(); + char localBuffer[kRecordSize]; + memcpy(localBuffer, &myPid, kPidSize); + snprintf(localBuffer + kPidSize, kStrMax, "%s", hostname); - if (LocalRank() == 0) { - MORI_APP_TRACE("Collected hostnames:"); - for (int i = 0; i < hostnames.size(); i++) { - MORI_APP_TRACE(" rank {}: {}", i, hostnames[i]); + std::vector global(kRecordSize * WorldSize()); + bootNet.Allgather(localBuffer, global.data(), kRecordSize); + + myHostname = std::string(localBuffer + kPidSize); + peerInfos.resize(WorldSize()); + for (int i = 0; i < WorldSize(); i++) { + const char* rec = global.data() + i * kRecordSize; + pid_t peerPid; + memcpy(&peerPid, rec, kPidSize); + std::string peerHost(rec + kPidSize); + peerInfos[i].sameHost = (peerHost == myHostname); + peerInfos[i].sameProcess = peerInfos[i].sameHost && (peerPid == myPid); + if (LocalRank() == 0) { + MORI_APP_TRACE("rank {} hostname={} pid={} sameHost={} sameProcess={}", i, peerHost, peerPid, + peerInfos[i].sameHost, peerInfos[i].sameProcess); } } } @@ -137,7 +148,7 @@ void Context::CollectHostNames() { void 
Context::InitializePossibleTransports() { // Find my rank in node for (int i = 0; i <= LocalRank(); i++) { - if (HostName() == hostnames[i]) rankInNode++; + if (peerInfos[i].sameHost) rankInNode++; } assert(rankInNode < 8); @@ -214,7 +225,7 @@ void Context::InitializePossibleTransports() { for (int i = 0; i < WorldSize(); i++) { // Check P2P availability if (!IsP2PDisabled()) { - if (HostName() == hostnames[i]) { + if (peerInfos[i].sameHost) { peerRankInNode++; // TODO: should use TopoSystemGpu to determine if peer access is enabled, but that requires diff --git a/src/application/memory/symmetric_memory.cpp b/src/application/memory/symmetric_memory.cpp index 71545891..f5cb78d6 100644 --- a/src/application/memory/symmetric_memory.cpp +++ b/src/application/memory/symmetric_memory.cpp @@ -134,11 +134,33 @@ SymmMemObjPtr SymmMemManager::RegisterSymmMemObj(void* localPtr, size_t size, bo static_cast(calloc(worldSize, sizeof(hipIpcMemHandle_t))); bootNet.Allgather(&handle, cpuMemObj->ipcMemHandles, sizeof(hipIpcMemHandle_t)); - // Open IPC handles for all same-node peers to establish P2P data path - // This happens regardless of transport type selection + // Open IPC handles for all same-node peers to establish P2P data path. + // Skip same-process peers: hipIpcOpenMemHandle fails within the same process; + // the peer's pointer is already valid and can be used directly. for (int i = 0; i < worldSize; i++) { if (!context.CanUseP2P(i)) continue; - + if (context.SameProcessP2P(i)) { + // Direct pointer access — no IPC handle needed within the same process. + // We must still enable peer access from our current device to the peer's + // device, because hipIpcOpenMemHandle's lazy-enable path is skipped here. 
+ cpuMemObj->p2pPeerPtrs[i] = cpuMemObj->peerPtrs[i]; + hipPointerAttribute_t attr{}; + hipError_t attrErr = + hipPointerGetAttributes(&attr, reinterpret_cast(cpuMemObj->peerPtrs[i])); + if (attrErr == hipSuccess && attr.device != hipInvalidDeviceId) { + hipError_t peerErr = hipDeviceEnablePeerAccess(attr.device, 0); + (void)hipGetLastError(); + if (peerErr != hipSuccess && peerErr != hipErrorPeerAccessAlreadyEnabled) { + MORI_APP_WARN("hipDeviceEnablePeerAccess(peer={}) failed: {}", attr.device, + hipGetErrorString(peerErr)); + } + } else { + (void)hipGetLastError(); + MORI_APP_WARN("hipPointerGetAttributes failed for same-process peer {} ptr {:p}: {}", i, + reinterpret_cast(cpuMemObj->peerPtrs[i]), hipGetErrorString(attrErr)); + } + continue; + } HIP_RUNTIME_CHECK(hipIpcOpenMemHandle(reinterpret_cast(&cpuMemObj->p2pPeerPtrs[i]), cpuMemObj->ipcMemHandles[i], hipIpcMemLazyEnablePeerAccess)); @@ -216,7 +238,8 @@ SymmMemObjPtr SymmMemManager::RegisterSymmMemObj(void* localPtr, size_t size, bo HIP_RUNTIME_CHECK(hipMalloc(&gpuMemObj->expectSignalsPtr, signalArraySize)); HIP_RUNTIME_CHECK(hipMemset(gpuMemObj->expectSignalsPtr, 0, signalArraySize)); - // Exchange signal memory via IPC so each PE can write to remote PE's signalPtrs + // Exchange signal memory via IPC so each PE can write to remote PE's signalPtrs. + // Also allgather raw pointers for same-process peers (SPMT) where IPC fails. 
hipIpcMemHandle_t signalHandle; HIP_RUNTIME_CHECK(hipIpcGetMemHandle(&signalHandle, gpuMemObj->signalPtrs)); @@ -224,22 +247,46 @@ SymmMemObjPtr SymmMemManager::RegisterSymmMemObj(void* localPtr, size_t size, bo static_cast(calloc(worldSize, sizeof(hipIpcMemHandle_t))); bootNet.Allgather(&signalHandle, signalHandles, sizeof(hipIpcMemHandle_t)); + HSAuint64* mySignalPtr = gpuMemObj->signalPtrs; + auto* rawSignalPtrs = static_cast(calloc(worldSize, sizeof(HSAuint64*))); + bootNet.Allgather(&mySignalPtr, rawSignalPtrs, sizeof(HSAuint64*)); + auto* peerSignalPtrsHost = static_cast(calloc(worldSize, sizeof(HSAuint64*))); peerSignalPtrsHost[rank] = gpuMemObj->signalPtrs; for (int i = 0; i < worldSize; i++) { if (context.GetTransportType(i) != TransportType::SDMA) continue; if (i == rank) continue; + if (context.SameProcessP2P(i)) { + peerSignalPtrsHost[i] = rawSignalPtrs[i]; + hipPointerAttribute_t attr{}; + hipError_t attrErr = hipPointerGetAttributes(&attr, rawSignalPtrs[i]); + if (attrErr == hipSuccess && attr.device != hipInvalidDeviceId) { + hipError_t peerErr = hipDeviceEnablePeerAccess(attr.device, 0); + (void)hipGetLastError(); + if (peerErr != hipSuccess && peerErr != hipErrorPeerAccessAlreadyEnabled) { + MORI_APP_WARN("hipDeviceEnablePeerAccess(peer={}) failed for SDMA signal: {}", + attr.device, hipGetErrorString(peerErr)); + } + } else { + (void)hipGetLastError(); + MORI_APP_WARN( + "hipPointerGetAttributes failed for same-process SDMA signal peer {} ptr {:p}: {}", i, + reinterpret_cast(rawSignalPtrs[i]), hipGetErrorString(attrErr)); + } + continue; + } void* mappedPtr = nullptr; HIP_RUNTIME_CHECK( hipIpcOpenMemHandle(&mappedPtr, signalHandles[i], hipIpcMemLazyEnablePeerAccess)); peerSignalPtrsHost[i] = reinterpret_cast(mappedPtr); } + free(rawSignalPtrs); HIP_RUNTIME_CHECK(hipMalloc(&gpuMemObj->peerSignalPtrs, sizeof(HSAuint64*) * worldSize)); HIP_RUNTIME_CHECK(hipMemcpy(gpuMemObj->peerSignalPtrs, peerSignalPtrsHost, sizeof(HSAuint64*) * worldSize, 
hipMemcpyHostToDevice)); + cpuMemObj->peerSignalPtrsHost = peerSignalPtrsHost; free(signalHandles); - free(peerSignalPtrsHost); } SymmMemObjPtr result{cpuMemObj, gpuMemObj}; if (!heap_begin) { @@ -258,11 +305,14 @@ void SymmMemManager::DeregisterSymmMemObj(void* localPtr) { SymmMemObjPtr memObjPtr = memObjPool.at(localPtr); - // Close IPC handles for peers that had P2P connection + // Close IPC handles for peers that had P2P connection. + // Skip same-process peers: their p2pPeerPtrs are direct VA pointers, not + // IPC-mapped, so hipIpcCloseMemHandle would fail. int rank = bootNet.GetLocalRank(); int worldSize = bootNet.GetWorldSize(); for (int i = 0; i < worldSize; i++) { if (!context.CanUseP2P(i)) continue; + if (context.SameProcessP2P(i)) continue; if (memObjPtr.cpu->p2pPeerPtrs && memObjPtr.cpu->p2pPeerPtrs[i] != 0) { void* peerPtr = reinterpret_cast(memObjPtr.cpu->p2pPeerPtrs[i]); hipError_t closeErr = hipIpcCloseMemHandle(peerPtr); @@ -275,6 +325,28 @@ void SymmMemManager::DeregisterSymmMemObj(void* localPtr) { } } + // Close SDMA signal IPC handles for non-same-process peers and free SDMA GPU resources + if (memObjPtr.cpu->peerSignalPtrsHost) { + for (int i = 0; i < worldSize; i++) { + if (context.GetTransportType(i) != TransportType::SDMA) continue; + if (i == rank) continue; + if (context.SameProcessP2P(i)) continue; + if (memObjPtr.cpu->peerSignalPtrsHost[i] != nullptr) { + hipError_t closeErr = + hipIpcCloseMemHandle(reinterpret_cast(memObjPtr.cpu->peerSignalPtrsHost[i])); + if (closeErr != hipSuccess) { + MORI_APP_WARN("hipIpcCloseMemHandle failed for SDMA signal peer {}: {}", i, + hipGetErrorString(closeErr)); + } + } + } + free(memObjPtr.cpu->peerSignalPtrsHost); + } + if (memObjPtr.gpu->signalPtrs) HIP_RUNTIME_CHECK(hipFree(memObjPtr.gpu->signalPtrs)); + if (memObjPtr.gpu->expectSignalsPtr) HIP_RUNTIME_CHECK(hipFree(memObjPtr.gpu->expectSignalsPtr)); + if (memObjPtr.gpu->peerSignalPtrs) HIP_RUNTIME_CHECK(hipFree(memObjPtr.gpu->peerSignalPtrs)); + if 
(memObjPtr.gpu->deviceHandles_d) HIP_RUNTIME_CHECK(hipFree(memObjPtr.gpu->deviceHandles_d)); + free(memObjPtr.cpu->peerPtrs); free(memObjPtr.cpu->p2pPeerPtrs); free(memObjPtr.cpu->peerRkeys); diff --git a/src/application/transport/sdma/anvil.cpp b/src/application/transport/sdma/anvil.cpp index 01238c86..f194bd92 100644 --- a/src/application/transport/sdma/anvil.cpp +++ b/src/application/transport/sdma/anvil.cpp @@ -265,26 +265,27 @@ void AnvilLib::init() { } bool AnvilLib::connect(int srcDeviceId, int dstDeviceId, int numChannels) { - uint32_t engineId = getSdmaEngineId(srcDeviceId, dstDeviceId); // + 1) * 2; - // std::cout << "Connect from " << srcDeviceId << " to " << dstDeviceId << " with " << numChannels - // << " channels using engine " << engineId << std::endl; + uint32_t engineId = getSdmaEngineId(srcDeviceId, dstDeviceId); + std::lock_guard lock(channels_mutex_); + auto key = std::make_pair(srcDeviceId, dstDeviceId); for (int c = 0; c < numChannels; ++c) { - sdma_channels_[dstDeviceId].emplace_back( + sdma_channels_[key].emplace_back( std::make_unique(srcDeviceId, dstDeviceId, gpuAgents_[srcDeviceId], engineId)); } return true; } SdmaQueue* AnvilLib::getSdmaQueue(int srcDeviceId, int dstDeviceId, int channel_idx) { - if (sdma_channels_.find(dstDeviceId) == sdma_channels_.end()) { + std::lock_guard lock(channels_mutex_); + auto key = std::make_pair(srcDeviceId, dstDeviceId); + auto it = sdma_channels_.find(key); + if (it == sdma_channels_.end()) { return nullptr; } - - if (!(channel_idx < sdma_channels_[dstDeviceId].size())) { + if (!(channel_idx < static_cast(it->second.size()))) { return nullptr; } - - return sdma_channels_[dstDeviceId][channel_idx].get(); // TODO + return it->second[channel_idx].get(); } AnvilLib& AnvilLib::getInstance() { diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index 3b7f46aa..7911599f 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ 
b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -27,6 +27,7 @@ #include #include "mori/core/core.hpp" +#include "mori/shmem/internal.hpp" #include "mori/shmem/shmem_api.hpp" #include "mori/utils/env_utils.hpp" #include "mori/utils/hip_helper.hpp" @@ -166,6 +167,12 @@ EpDispatchCombineHandle::EpDispatchCombineHandle(EpDispatchCombineConfig config_ } EpDispatchCombineHandle::~EpDispatchCombineHandle() { + auto* states = mori::shmem::ShmemStatesSingleton::GetInstance(); + if (states->status != mori::shmem::ShmemStatesStatus::Initialized) { + return; + } + hipDeviceSynchronize(); + (void)hipGetLastError(); FinalizeShmemBuf(); FinalizeTokenNumSignalBuf(); FinalizeOrderMapBuf(); diff --git a/src/ops/dispatch_combine/launch.cpp b/src/ops/dispatch_combine/launch.cpp index 6f07461b..b813290b 100644 --- a/src/ops/dispatch_combine/launch.cpp +++ b/src/ops/dispatch_combine/launch.cpp @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -39,6 +40,9 @@ #include #include +#include "mori/application/utils/check.hpp" +#include "mori/utils/limits.hpp" + #ifdef __linux__ #include #endif @@ -68,8 +72,22 @@ struct KernelRegistry::Impl { }; KernelRegistry::Impl& KernelRegistry::GetImpl() { +#ifdef MORI_MULTITHREAD_SUPPORT + // SPMT: one Impl per GPU. hipModuleLoad binds to the calling thread's + // current device; sharing modules across devices would launch the wrong + // kernel image. In multi-process mode every process sees its single GPU as + // device 0, so this collapses to slot[0] — equivalent to a singleton. 
+ static std::array impls; + int id = -1; + HIP_RUNTIME_CHECK(hipGetDevice(&id)); + if (id < 0 || id >= mori::kMaxGpusPerNode) { + throw std::runtime_error("KernelRegistry: hipGetDevice() out of range: " + std::to_string(id)); + } + return impls[static_cast(id)]; +#else static Impl impl; return impl; +#endif } KernelRegistry& KernelRegistry::Instance() { @@ -447,10 +465,22 @@ void LaunchDispatch(EpDispatchCombineHandle& handle, void* input, void* weights, reg.Launch(std::string("EpDispatchInterNodeV1KernelLowLatency_") + sfx, bn, block_x, smem, stream, &args, args_size); break; - case KernelType::AsyncLL: - reg.Launch(std::string("EpDispatchLowLatencyAsyncSend_") + sfx, bn, block_x, smem, stream, - &args, args_size); + case KernelType::AsyncLL: { + int mp = handle.multiProcessorCount; + int mp_aligned = mp - (mp % handle.config.worldSize); + unsigned int mb_block = WARP_SIZE * 16; + reg.Launch(std::string("EpDispatchLowLatencyAsyncSendCopySlotAssign_") + sfx, mp_aligned, + mb_block, 0, stream, &args, args_size); + reg.Launch(std::string("EpDispatchLowLatencyAsyncSendCopyMultiBlock_") + sfx, mp_aligned, + mb_block, 0, stream, &args, args_size); + reg.Launch(std::string("EpDispatchLowLatencyAsyncSendTransfer_") + sfx, + handle.config.worldSize, block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpDispatchLowLatencyAsyncRecvTransfer_") + sfx, + handle.config.worldSize, block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpDispatchLowLatencyAsyncRecvCopyMultiBlock_") + sfx, mp_aligned, + mb_block, 0, stream, &args, args_size); break; + } default: throw std::runtime_error("Unsupported dispatch kernel_type"); } @@ -561,10 +591,19 @@ void LaunchCombine(EpDispatchCombineHandle& handle, void* input, void* weights, stream, &args, args_size); reg.Launch(std::string("EpCombineAll_") + sfx, mp, block_x, smem, stream, &args, args_size); break; - case KernelType::AsyncLL: - reg.Launch(std::string("EpCombineLowLatencyAsyncSend_") + sfx, bn, block_x, 
smem, stream, - &args, args_size); + case KernelType::AsyncLL: { + int mp = handle.multiProcessorCount; + int mp_aligned = mp - (mp % handle.config.worldSize); + reg.Launch(std::string("EpCombineLowLatencyAsyncSendCopy_") + sfx, mp_aligned, block_x, 0, + stream, &args, args_size); + reg.Launch(std::string("EpCombineLowLatencyAsyncSendTransfer_") + sfx, + handle.config.worldSize, block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpCombineLowLatencyAsyncRecvTransfer_") + sfx, + handle.config.worldSize, block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpCombineLowLatencyAsyncRecvCopy_") + sfx, mp_aligned, block_x, smem, + stream, &args, args_size); break; + } default: throw std::runtime_error("Unsupported combine kernel_type"); } @@ -578,19 +617,23 @@ void LaunchDispatchRecv(EpDispatchCombineHandle& handle, int block_num, int warp ensure_loaded(); int wpb = (warp_per_block <= 0) ? handle.config.warpNumPerBlock : warp_per_block; - int bn = (block_num <= 0) ? handle.config.blockNum : block_num; EpDispatchCombineArgsRaw args = GetEpDispatchCombineArgsRaw(handle, 0); if (handle.curHiddenDim > 0) args.config.hiddenDim = handle.curHiddenDim; unsigned int block_x = WARP_SIZE * wpb; - int smem = dispatch_shared_mem(handle.config, wpb); size_t args_size = sizeof(EpDispatchCombineArgsRaw); const char* sfx = dtype_suffix(handle.inputType); if (handle.config.kernelType == KernelType::AsyncLL) { - KernelRegistry::Instance().Launch(std::string("EpDispatchLowLatencyAsyncRecv_") + sfx, bn, - block_x, smem, stream, &args, args_size); + auto& reg = KernelRegistry::Instance(); + int mp = handle.multiProcessorCount; + int mp_aligned = mp - (mp % handle.config.worldSize); + unsigned int mb_block = WARP_SIZE * 16; + reg.Launch(std::string("EpDispatchLowLatencyAsyncRecvTransfer_") + sfx, handle.config.worldSize, + block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpDispatchLowLatencyAsyncRecvCopyMultiBlock_") + sfx, mp_aligned, + mb_block, 0, 
stream, &args, args_size); } else { throw std::runtime_error("LaunchDispatchRecv only supported for AsyncLL"); } @@ -604,7 +647,6 @@ void LaunchCombineRecv(EpDispatchCombineHandle& handle, int block_num, int warp_ ensure_loaded(); int wpb = (warp_per_block <= 0) ? handle.config.warpNumPerBlock : warp_per_block; - int bn = (block_num <= 0) ? handle.config.blockNum : block_num; EpDispatchCombineArgsRaw args = GetEpDispatchCombineArgsRaw(handle, 0); if (handle.curHiddenDim > 0) args.config.hiddenDim = handle.curHiddenDim; @@ -616,8 +658,13 @@ void LaunchCombineRecv(EpDispatchCombineHandle& handle, int block_num, int warp_ const char* sfx = dtype_suffix(handle.inputType); if (handle.config.kernelType == KernelType::AsyncLL) { - KernelRegistry::Instance().Launch(std::string("EpCombineLowLatencyAsyncRecv_") + sfx, bn, - block_x, smem, stream, &args, args_size); + auto& reg = KernelRegistry::Instance(); + int mp = handle.multiProcessorCount; + int mp_aligned = mp - (mp % handle.config.worldSize); + reg.Launch(std::string("EpCombineLowLatencyAsyncRecvTransfer_") + sfx, handle.config.worldSize, + block_x, 0, stream, &args, args_size); + reg.Launch(std::string("EpCombineLowLatencyAsyncRecvCopy_") + sfx, mp_aligned, block_x, smem, + stream, &args, args_size); } else { throw std::runtime_error("LaunchCombineRecv only supported for AsyncLL"); } diff --git a/src/pybind/pybind_shmem.cpp b/src/pybind/pybind_shmem.cpp index ee9fad82..27b61463 100644 --- a/src/pybind/pybind_shmem.cpp +++ b/src/pybind/pybind_shmem.cpp @@ -126,16 +126,21 @@ void RegisterMoriShmem(py::module_& m) { m.def("shmem_get_unique_id", &ShmemGetUniqueId, "Get a unique ID for shmem initialization (returns bytes)"); + // Release the GIL for blocking shmem init/finalize so that concurrent + // Python threads (SPMT mode) can all progress through the socket bootstrap + // handshake without deadlocking on the GIL. 
m.def("shmem_init_attr", &ShmemInitAttr, py::arg("flags"), py::arg("rank"), py::arg("nranks"), - py::arg("unique_id"), + py::arg("unique_id"), py::call_guard(), "Initialize shmem with attributes (unique_id should be bytes from shmem_get_unique_id)"); - m.def("shmem_finalize", &ShmemFinalize, "Finalize shmem"); + m.def("shmem_finalize", &ShmemFinalize, py::call_guard(), + "Finalize shmem"); // Module-specific initialization (for Triton kernels) m.def("shmem_module_init", &ShmemModuleInit, py::arg("hip_module"), "Initialize globalGpuStates in a specific HIP module (for Triton kernels)"); m.def("load_shmem_module", &LoadShmemModule, py::arg("hsaco_path"), + py::call_guard(), "Load JIT-compiled shmem module (.hsaco) with globalGpuStates and barrier kernel"); // Query APIs @@ -144,7 +149,8 @@ void RegisterMoriShmem(py::module_& m) { m.def("shmem_npes", &ShmemNPes, "Get number of PEs"); // Collective operations - m.def("shmem_barrier_all", &ShmemBarrierAll, "Global barrier synchronization"); + m.def("shmem_barrier_all", &ShmemBarrierAll, py::call_guard(), + "Global barrier synchronization"); m.def( "shmem_barrier_on_stream", @@ -152,23 +158,27 @@ void RegisterMoriShmem(py::module_& m) { py::arg("stream"), "Launch device barrier on a HIP stream"); // Symmetric memory management - m.def("shmem_malloc", &ShmemMalloc, py::arg("size"), + m.def("shmem_malloc", &ShmemMalloc, py::arg("size"), py::call_guard(), "Allocate symmetric memory (returns address as int)"); m.def("shmem_malloc_align", &ShmemMallocAlign, py::arg("alignment"), py::arg("size"), + py::call_guard(), "Allocate aligned symmetric memory (returns address as int)"); m.def("shmem_ext_malloc_with_flags", &ShmemExtMallocWithFlags, py::arg("size"), py::arg("flags"), + py::call_guard(), "Allocate symmetric memory with flags (returns address as int)"); - m.def("shmem_free", &ShmemFree, py::arg("ptr"), + m.def("shmem_free", &ShmemFree, py::arg("ptr"), py::call_guard(), "Free symmetric memory (ptr should be int address)"); 
// Buffer registration m.def("shmem_buffer_register", &ShmemBufferRegister, py::arg("ptr"), py::arg("size"), + py::call_guard(), "Register an existing buffer for RDMA (ptr should be int address)"); m.def("shmem_buffer_deregister", &ShmemBufferDeregister, py::arg("ptr"), py::arg("size"), + py::call_guard(), "Deregister a buffer from RDMA (ptr should be int address)"); // P2P address translation @@ -176,10 +186,6 @@ void RegisterMoriShmem(py::module_& m) { "Convert local symmetric memory pointer to remote P2P address. " "Returns 0 if connection uses RDMA or if pointer is invalid. " "Returns P2P accessible address if connection uses P2P transport."); - m.def("shmem_torch_process_group_init", &ShmemTorchProcessGroupInit); - m.def("shmem_finalize", &ShmemFinalize); - m.def("shmem_mype", &ShmemMyPe); - m.def("shmem_npes", &ShmemNPes); m.def("shmem_num_qp_per_pe", &ShmemNumQpPerPe); } diff --git a/src/pybind/pybind_xla_ffi_ops.cpp b/src/pybind/pybind_xla_ffi_ops.cpp index be043675..2c56615d 100644 --- a/src/pybind/pybind_xla_ffi_ops.cpp +++ b/src/pybind/pybind_xla_ffi_ops.cpp @@ -26,14 +26,18 @@ #include #include +#include #include #include +#include #include +#include "mori/application/utils/check.hpp" #include "mori/ops/dispatch_combine/launch.hpp" #include "mori/ops/ops.hpp" #include "mori/shmem/internal.hpp" #include "mori/utils/hip_helper.hpp" +#include "mori/utils/limits.hpp" #include "src/pybind/mori.hpp" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -53,7 +57,7 @@ using mori::moe::KernelType; /* ---------------------------------------------------------------------------------------------- */ namespace { -// Global cache: maps packed ep_config → shared handle. +// Cache: maps packed ep_config → shared handle. // All call sites with identical ep_config reuse the same EpDispatchCombineHandle. 
struct VecI32Hash { size_t operator()(const std::vector& v) const noexcept { @@ -66,10 +70,57 @@ struct VecI32Hash { } }; -static std::mutex g_handle_cache_mu; -static std::unordered_map, std::unique_ptr, - VecI32Hash> - g_handle_cache; +using HandleCacheMap = + std::unordered_map, std::unique_ptr, VecI32Hash>; + +struct HandleCacheSlot { + std::mutex mu; + HandleCacheMap map; +}; + +#ifdef MORI_MULTITHREAD_SUPPORT +// SPMT: per-GPU cache slot. Each thread is bound to its own device, so each +// gets its own (mu, map). This avoids a deadlock that would occur if a single +// process-global mutex were held while calling the cross-PE Barrier() below — +// thread A would hold the mutex during Barrier(), thread B would block on the +// mutex and never reach its own Barrier(). In multi-process mode every process +// sees its single GPU as device 0, so this collapses to slot[0]. +static std::array g_handle_cache_slots; + +static HandleCacheSlot& GetHandleCacheSlot() { + int id = -1; + HIP_RUNTIME_CHECK(hipGetDevice(&id)); + if (id < 0 || id >= mori::kMaxGpusPerNode) { + throw std::runtime_error("EpHandleCache: hipGetDevice() out of range: " + std::to_string(id)); + } + return g_handle_cache_slots[static_cast(id)]; +} + +// RAII helper: temporarily switch the calling thread to `target` and restore +// the previous device on scope exit. Used in XLA FFI handlers so that +// hipSetDevice does not leak out into XLA's worker-thread state. 
+class ScopedDevice { + public: + explicit ScopedDevice(int target) : saved_(-1), restore_(false) { + if (target < 0) return; + if (hipGetDevice(&saved_) != hipSuccess) return; + if (saved_ == target) return; // nothing to do + if (hipSetDevice(target) == hipSuccess) restore_ = true; + } + ~ScopedDevice() { + if (restore_) (void)hipSetDevice(saved_); + } + ScopedDevice(const ScopedDevice&) = delete; + ScopedDevice& operator=(const ScopedDevice&) = delete; + + private: + int saved_; + bool restore_; +}; +#else +static HandleCacheSlot g_handle_cache_singleton; +static HandleCacheSlot& GetHandleCacheSlot() { return g_handle_cache_singleton; } +#endif struct EpDispatchCombineState { static TypeId id; @@ -280,14 +331,27 @@ ErrorOr> EpDispatchCombineInstantiate( // ep_config share a single EpDispatchCombineHandle. std::vector key(ep_config->begin(), ep_config->end()); - std::lock_guard lock(g_handle_cache_mu); - auto& entry = g_handle_cache[key]; + // Decode once; reused below for both rank-based device routing (SPMT) and + // (on cache miss) handle construction. + auto cfg = EpDispatchCombineConfig::FromPackedI32Array(key.data(), key.size()); + +#ifdef MORI_MULTITHREAD_SUPPORT + // SPMT: XLA FFI handlers run on framework worker threads where + // hipGetDevice() does NOT match the rank's device. Look up the device + // recorded at ShmemInit time and bind it on this thread (RAII-restored on + // exit so XLA's other state isn't disturbed) before any HIP call. This + // ensures GetHandleCacheSlot() and ShmemStatesSingleton::GetInstance() + // (both keyed by hipGetDevice()) hit the right slot. 
+ ScopedDevice _dev_guard(mori::shmem::ShmemStatesSingleton::GetDeviceByRank(cfg.rank)); +#endif + + auto& slot = GetHandleCacheSlot(); + std::lock_guard lock(slot.mu); + auto& entry = slot.map[key]; if (!entry) { auto* states = mori::shmem::ShmemStatesSingleton::GetInstance(); states->CheckStatusValid(); states->bootStates->bootNet->Barrier(); - - auto cfg = EpDispatchCombineConfig::FromPackedI32Array(key.data(), key.size()); // XPUT("EpDispatchCombineInstantiate: creating new handle for rank %d " // "(#attrs: %zu)", cfg.rank, attrs.size()); entry = std::make_unique(cfg); @@ -305,6 +369,11 @@ Error EpDispatchCombineImpl(hipStream_t stream, EpDispatchCombineState* state, D auto& h = *state->handle; // XPUT("EpDispatchCombineImpl stream=%p rank=%d attrs: %zu", // stream, h.config.rank, attrs.size()); +#ifdef MORI_MULTITHREAD_SUPPORT + // SPMT: bind this thread to the rank's device for the duration of this + // FFI call (RAII-restored). See EpDispatchCombineInstantiate for rationale. + ScopedDevice _dev_guard(mori::shmem::ShmemStatesSingleton::GetDeviceByRank(h.config.rank)); +#endif if (attrs.contains("dispatch_op")) { return MoriDispatchImpl(stream, &h, attrs, args, rets); } @@ -362,8 +431,14 @@ void RegisterXLAFFIOps(py::module_& m) { }); m.def("preload_kernels", []() { mori::moe::KernelRegistry::Instance().AutoLoad(); }); m.def("clear_ep_handle_cache", []() { - std::lock_guard lock(g_handle_cache_mu); - g_handle_cache.clear(); + // Clear only the calling thread's slot. Under SPMT, each thread is + // bound to its own GPU and owns one slot; clearing all slots from one + // thread would invoke ~EpDispatchCombineHandle on OTHER GPUs' handles + // while the calling thread's hipDevice is still set to its own GPU, + // and ShmemFree would then look up addresses in the wrong VA manager. 
+ auto& slot = GetHandleCacheSlot(); + std::lock_guard lock(slot.mu); + slot.map.clear(); }); } diff --git a/src/shmem/CMakeLists.txt b/src/shmem/CMakeLists.txt index 7f123182..7c0b01cc 100644 --- a/src/shmem/CMakeLists.txt +++ b/src/shmem/CMakeLists.txt @@ -8,6 +8,11 @@ target_include_directories(mori_shmem PUBLIC ${CMAKE_SOURCE_DIR}) target_link_libraries(mori_shmem PUBLIC mori_application mori_logging ibverbs hip::host) +if(MORI_MULTITHREAD_SUPPORT) + target_compile_definitions(mori_shmem PUBLIC MORI_MULTITHREAD_SUPPORT) + message(STATUS "mori_shmem: MORI_MULTITHREAD_SUPPORT enabled") +endif() + set_target_properties( mori_shmem PROPERTIES BUILD_RPATH "$ORIGIN" diff --git a/src/shmem/init.cpp b/src/shmem/init.cpp index 6598a8bb..61df13f0 100644 --- a/src/shmem/init.cpp +++ b/src/shmem/init.cpp @@ -42,6 +42,50 @@ namespace mori { namespace shmem { +/* ---------------------------------------------------------------------------------------------- */ +/* ShmemStatesSingleton */ +/* ---------------------------------------------------------------------------------------------- */ + +#ifdef MORI_MULTITHREAD_SUPPORT +// rank → device id, populated by RegisterRankDevice at ShmemInit. +// Used by FFI handlers (XLA / custom calls) where the calling thread's +// hipGetDevice() does NOT match the rank's device. +static std::mutex g_rank_to_device_mu; +static std::unordered_map g_rank_to_device; +#endif + +ShmemStates* ShmemStatesSingleton::GetInstance() { +#ifdef MORI_MULTITHREAD_SUPPORT + // One instance per GPU, indexed by the calling thread's current HIP device. + // hipGetDevice() reads thread-local HIP state, so it is very cheap. 
+ static ShmemStatesSingleton s_inst; + int id = -1; + HIP_RUNTIME_CHECK(hipGetDevice(&id)); + if (__builtin_expect(id < 0 || id >= mori::kMaxGpusPerNode, 0)) { + MORI_SHMEM_ERROR("hipGetDevice() returned out-of-range id {}, max supported is {}", id, + mori::kMaxGpusPerNode - 1); + assert(false); + } + return &s_inst.states_[id]; +#else + static ShmemStates states; + return &states; +#endif +} + +#ifdef MORI_MULTITHREAD_SUPPORT +void ShmemStatesSingleton::RegisterRankDevice(int rank, int deviceId) { + std::lock_guard lk(g_rank_to_device_mu); + g_rank_to_device[rank] = deviceId; +} + +int ShmemStatesSingleton::GetDeviceByRank(int rank) { + std::lock_guard lk(g_rank_to_device_mu); + auto it = g_rank_to_device.find(rank); + return it == g_rank_to_device.end() ? -1 : it->second; +} +#endif + /* ---------------------------------------------------------------------------------------------- */ /* Helper Functions */ /* ---------------------------------------------------------------------------------------------- */ @@ -125,8 +169,7 @@ static bool IsROCmVersionGreaterThan7() { /* RDMA States Initialization */ /* ---------------------------------------------------------------------------------------------- */ -void RdmaStatesInit() { - ShmemStates* states = ShmemStatesSingleton::GetInstance(); +void RdmaStatesInit(ShmemStates* states) { states->rdmaStates = new RdmaStates(); RdmaStates* rdmaStates = states->rdmaStates; @@ -284,8 +327,7 @@ static bool TryInitializeVMMHeap(ShmemStates* states, application::HeapType heap /* Memory States Initialization */ /* ---------------------------------------------------------------------------------------------- */ -void MemoryStatesInit() { - ShmemStates* states = ShmemStatesSingleton::GetInstance(); +void MemoryStatesInit(ShmemStates* states) { application::Context* context = states->rdmaStates->commContext; // Create memory management objects @@ -337,7 +379,8 @@ void MemoryStatesInit() { /* 
---------------------------------------------------------------------------------------------- */ // Copy transport types to GPU device memory -static void CopyTransportTypesToGpu(GpuStates* gpuStates, const ShmemStates* states) { +static void CopyTransportTypesToGpu(ShmemStates* states) { + GpuStates* gpuStates = &states->gpuStates; int worldSize = states->bootStates->worldSize; HIP_RUNTIME_CHECK( @@ -348,7 +391,8 @@ static void CopyTransportTypesToGpu(GpuStates* gpuStates, const ShmemStates* sta } // Copy RDMA endpoints to GPU device memory -static void CopyRdmaEndpointsToGpu(GpuStates* gpuStates, const ShmemStates* states) { +static void CopyRdmaEndpointsToGpu(ShmemStates* states) { + GpuStates* gpuStates = &states->gpuStates; if (!states->rdmaStates->commContext->RdmaTransportEnabled()) { return; } @@ -381,7 +425,8 @@ static void CopyRdmaEndpointsToGpu(GpuStates* gpuStates, const ShmemStates* stat } // Configure heap information for GPU based on current heap mode -static void ConfigureHeapInfoForGpu(GpuStates* gpuStates, const ShmemStates* states) { +static void ConfigureHeapInfoForGpu(ShmemStates* states) { + GpuStates* gpuStates = &states->gpuStates; gpuStates->useVMMHeap = states->memoryStates->useVMMHeap; switch (states->mode) { @@ -447,7 +492,8 @@ static void ConfigureHeapInfoForGpu(GpuStates* gpuStates, const ShmemStates* sta } // Allocate internal synchronization memory for device barriers -static void AllocateInternalSync(GpuStates* gpuStates, const ShmemStates* states) { +static void AllocateInternalSync(ShmemStates* states) { + GpuStates* gpuStates = &states->gpuStates; constexpr size_t MORI_INTERNAL_SYNC_SIZE = 128 * sizeof(uint64_t); constexpr size_t ALIGNMENT = 256; void* syncPtr = nullptr; @@ -504,11 +550,11 @@ static void AllocateInternalSync(GpuStates* gpuStates, const ShmemStates* states } static void FinalizeInternalSync(const ShmemStates* states) { - if (s_hostGpuStatesCopy.internalSyncPtr == nullptr) { + if (states->gpuStates.internalSyncPtr 
== nullptr) { return; } - void* syncPtr = reinterpret_cast(s_hostGpuStatesCopy.internalSyncPtr); + void* syncPtr = reinterpret_cast(states->gpuStates.internalSyncPtr); switch (states->mode) { case ShmemMode::StaticHeap: { states->memoryStates->symmMemMgr->DeregisterStaticHeapSubRegion(syncPtr); @@ -532,27 +578,25 @@ static void FinalizeInternalSync(const ShmemStates* states) { // CopyGpuStatesToDevice is in runtime.cpp -void GpuStateInit() { - ShmemStates* states = ShmemStatesSingleton::GetInstance(); - - // Initialize basic GPU states - GpuStates gpuStates; - gpuStates.rank = states->bootStates->rank; - gpuStates.worldSize = states->bootStates->worldSize; - gpuStates.numQpPerPe = states->rdmaStates->commContext->GetNumQpPerPe(); +void GpuStateInit(ShmemStates* states) { + // Initialize basic GPU states (in-place, no heap alloc) + states->gpuStates = {}; + states->gpuStates.rank = states->bootStates->rank; + states->gpuStates.worldSize = states->bootStates->worldSize; + states->gpuStates.numQpPerPe = states->rdmaStates->commContext->GetNumQpPerPe(); // Copy communication metadata to GPU - CopyTransportTypesToGpu(&gpuStates, states); - CopyRdmaEndpointsToGpu(&gpuStates, states); + CopyTransportTypesToGpu(states); + CopyRdmaEndpointsToGpu(states); // Configure heap information for GPU access - ConfigureHeapInfoForGpu(&gpuStates, states); + ConfigureHeapInfoForGpu(states); // Allocate internal synchronization memory for device barriers - AllocateInternalSync(&gpuStates, states); + AllocateInternalSync(states); // Copy complete state to device - CopyGpuStatesToDevice(&gpuStates); + CopyGpuStatesToDevice(states); } /* ---------------------------------------------------------------------------------------------- */ @@ -595,6 +639,18 @@ static void InitializeBootStates(ShmemStates* states, application::BootstrapNetw MORI_SHMEM_TRACE("Bootstrap initialized: rank={}, worldSize={}", states->bootStates->rank, states->bootStates->worldSize); + +#ifdef MORI_MULTITHREAD_SUPPORT 
+ // Record rank → device mapping so FFI / custom-call handlers (XLA, etc.) + // running on framework worker threads can route to the right ShmemStates. + // We capture the current device of the calling user thread (which already + // did hipSetDevice() before ShmemInit per the SPMT contract). + int dev = -1; + if (hipGetDevice(&dev) == hipSuccess && dev >= 0) { + ShmemStatesSingleton::RegisterRankDevice(states->bootStates->rank, dev); + MORI_SHMEM_TRACE("Registered rank {} -> device {}", states->bootStates->rank, dev); + } +#endif } /* ---------------------------------------------------------------------------------------------- */ @@ -609,20 +665,15 @@ int ShmemInit(application::BootstrapNetwork* bootNet) { delete bootNet; return 0; } - if (states->status == ShmemStatesStatus::Finalized) { - MORI_SHMEM_ERROR("Shmem has been finalized, cannot re-initialize"); - delete bootNet; - return -1; - } // Configure shmem mode states->mode = ConfigureShmemMode(); // Initialize all subsystems InitializeBootStates(states, bootNet); - RdmaStatesInit(); - MemoryStatesInit(); - GpuStateInit(); + RdmaStatesInit(states); + MemoryStatesInit(states); + GpuStateInit(states); states->status = ShmemStatesStatus::Initialized; MORI_SHMEM_INFO("Shmem initialization completed"); @@ -633,10 +684,12 @@ int ShmemInit(application::BootstrapNetwork* bootNet) { /* Finalization Helpers */ /* ---------------------------------------------------------------------------------------------- */ -static void FinalizeGpuStates() { - HIP_RUNTIME_CHECK(hipFree(s_hostGpuStatesCopy.transportTypes)); - HIP_RUNTIME_CHECK(hipFree(s_hostGpuStatesCopy.rdmaEndpoints)); - FinalizeRuntime(); +static void FinalizeGpuStates(ShmemStates* states) { + hipDeviceSynchronize(); + (void)hipGetLastError(); + HIP_RUNTIME_CHECK(hipFree(states->gpuStates.transportTypes)); + HIP_RUNTIME_CHECK(hipFree(states->gpuStates.rdmaEndpoints)); + FinalizeRuntime(states); MORI_SHMEM_TRACE("GPU states finalized"); } @@ -716,16 +769,19 @@ int 
ShmemFinalize() { MORI_SHMEM_TRACE("Starting shmem finalization"); - // Clean up in reverse order of initialization - FinalizeGpuStates(); - - // Clean up internal sync memory + // Clean up in reverse order of initialization. + // FinalizeInternalSync MUST run before FinalizeGpuStates: the latter clears + // states->gpuStates (incl. internalSyncPtr), which would make + // FinalizeInternalSync early-return and leak the sync memory. FinalizeInternalSync(states); + FinalizeGpuStates(states); FinalizeHeap(states); FinalizeAllStates(states); - states->status = ShmemStatesStatus::Finalized; + // Reset to New so the slot can be reused (e.g. SPMT test suites that run + // multiple init/finalize cycles in the same process on the same GPU). + states->status = ShmemStatesStatus::New; MORI_SHMEM_INFO("Shmem finalization completed"); return 0; } diff --git a/src/shmem/runtime.cpp b/src/shmem/runtime.cpp index 7d87fa3c..fe8e818f 100644 --- a/src/shmem/runtime.cpp +++ b/src/shmem/runtime.cpp @@ -40,11 +40,6 @@ namespace shmem { /* ---------------------------------------------------------------------------------------------- */ /* JIT Module & GpuStates Management */ /* ---------------------------------------------------------------------------------------------- */ -static hipModule_t s_shmemModule = nullptr; -static GpuStates* s_deviceGpuStatesAddr = nullptr; -static hipFunction_t s_barrierFunc = nullptr; - -GpuStates s_hostGpuStatesCopy{}; using GpuStatesAddrProvider = void* (*)(); // One entry per RegisterGpuStatesAddrProvider (e.g. multiple HIP TUs + modules). 
@@ -61,38 +56,42 @@ void RegisterGpuStatesAddrProvider(GpuStatesAddrProvider provider) { void RegisterBarrierLauncher(BarrierLauncher launcher) { s_staticBarrierLauncher = launcher; } int LoadShmemModule(const char* hsaco_path) { - if (s_shmemModule != nullptr) return 0; - hipError_t err = hipModuleLoad(&s_shmemModule, hsaco_path); + ShmemStates* states = ShmemStatesSingleton::GetInstance(); + ModuleStates& ms = states->moduleStates; + + if (ms.module != nullptr) return 0; + hipError_t err = hipModuleLoad(&ms.module, hsaco_path); if (err != hipSuccess) { MORI_SHMEM_ERROR("Failed to load shmem module from {}: {}", hsaco_path, hipGetErrorString(err)); return -1; } - err = hipModuleGetGlobal(reinterpret_cast(&s_deviceGpuStatesAddr), nullptr, - s_shmemModule, "_ZN4mori5shmem15globalGpuStatesE"); + err = hipModuleGetGlobal(reinterpret_cast(&ms.gpuStatesPtr), nullptr, ms.module, + "_ZN4mori5shmem15globalGpuStatesE"); if (err != hipSuccess) { MORI_SHMEM_ERROR("globalGpuStates symbol not found in shmem module: {}", hipGetErrorString(err)); return -1; } - err = hipModuleGetFunction(&s_barrierFunc, s_shmemModule, "mori_shmem_barrier_all_block"); + err = hipModuleGetFunction(&ms.barrierFunc, ms.module, "mori_shmem_barrier_all_block"); if (err != hipSuccess) { MORI_SHMEM_ERROR("mori_shmem_barrier_all_block not found in shmem module: {}", hipGetErrorString(err)); return -1; } MORI_SHMEM_TRACE("Loaded shmem JIT module: globalGpuStates={:p}, barrier={:p}", - (void*)s_deviceGpuStatesAddr, (void*)s_barrierFunc); + (void*)ms.gpuStatesPtr, (void*)ms.barrierFunc); return 0; } -void CopyGpuStatesToDevice(const GpuStates* gpuStates) { - s_hostGpuStatesCopy = *gpuStates; +void CopyGpuStatesToDevice(ShmemStates* states) { + const GpuStates* gpuStates = &states->gpuStates; + ModuleStates& ms = states->moduleStates; - if (s_deviceGpuStatesAddr != nullptr) { + if (ms.gpuStatesPtr != nullptr) { MORI_SHMEM_TRACE("Copying GpuStates to JIT module globalGpuStates ({:p})", - 
(void*)s_deviceGpuStatesAddr); + (void*)ms.gpuStatesPtr); HIP_RUNTIME_CHECK( - hipMemcpy(s_deviceGpuStatesAddr, gpuStates, sizeof(GpuStates), hipMemcpyHostToDevice)); + hipMemcpy(ms.gpuStatesPtr, gpuStates, sizeof(GpuStates), hipMemcpyHostToDevice)); } for (auto& provider : s_gpuStatesAddrProviders) { @@ -107,15 +106,15 @@ void CopyGpuStatesToDevice(const GpuStates* gpuStates) { gpuStates->rank, gpuStates->worldSize); } -void FinalizeRuntime() { - if (s_shmemModule) { - hipModuleUnload(s_shmemModule); - s_shmemModule = nullptr; - s_deviceGpuStatesAddr = nullptr; - s_barrierFunc = nullptr; +void FinalizeRuntime(ShmemStates* states) { + ModuleStates& ms = states->moduleStates; + if (ms.module != nullptr) { + hipModuleUnload(ms.module); + ms.module = nullptr; + ms.gpuStatesPtr = nullptr; + ms.barrierFunc = nullptr; } - s_hostGpuStatesCopy = {}; - s_gpuStatesAddrProviders.clear(); + states->gpuStates = {}; } /* ---------------------------------------------------------------------------------------------- */ @@ -139,22 +138,23 @@ int ShmemModuleInit(void* hipModule) { return -1; } - MORI_SHMEM_TRACE("Module globalGpuStates address: {:p} (shmem module address: {:p})", - (void*)moduleGlobalGpuStatesAddr, (void*)s_deviceGpuStatesAddr); + MORI_SHMEM_TRACE("Module globalGpuStates address: {:p} (JIT module address: {:p})", + (void*)moduleGlobalGpuStatesAddr, (void*)states->moduleStates.gpuStatesPtr); - HIP_RUNTIME_CHECK(hipMemcpy(moduleGlobalGpuStatesAddr, &s_hostGpuStatesCopy, sizeof(GpuStates), + HIP_RUNTIME_CHECK(hipMemcpy(moduleGlobalGpuStatesAddr, &states->gpuStates, sizeof(GpuStates), hipMemcpyHostToDevice)); MORI_SHMEM_TRACE("Successfully initialized globalGpuStates in module (rank={}, worldSize={})", - s_hostGpuStatesCopy.rank, s_hostGpuStatesCopy.worldSize); + states->gpuStates.rank, states->gpuStates.worldSize); return 0; } int CopyGpuStatesToSymbol(void* deviceSymbolAddr) { if (deviceSymbolAddr == nullptr) return -1; + ShmemStates* states = 
ShmemStatesSingleton::GetInstance(); HIP_RUNTIME_CHECK( - hipMemcpy(deviceSymbolAddr, &s_hostGpuStatesCopy, sizeof(GpuStates), hipMemcpyHostToDevice)); + hipMemcpy(deviceSymbolAddr, &states->gpuStates, sizeof(GpuStates), hipMemcpyHostToDevice)); return 0; } @@ -201,9 +201,9 @@ void ShmemBarrierOnStream(hipStream_t stream) { MORI_SHMEM_TRACE("PE {} launching device barrier on stream", states->bootStates->rank); - if (s_barrierFunc != nullptr) { - hipError_t err = - hipModuleLaunchKernel(s_barrierFunc, 1, 1, 1, 1, 1, 1, 0, stream, nullptr, nullptr); + if (states->moduleStates.barrierFunc != nullptr) { + hipError_t err = hipModuleLaunchKernel(states->moduleStates.barrierFunc, 1, 1, 1, 1, 1, 1, 0, + stream, nullptr, nullptr); assert(err == hipSuccess && "ShmemBarrierOnStream launch failed"); } else if (s_staticBarrierLauncher != nullptr) { s_staticBarrierLauncher(stream); diff --git a/tests/python/shmem/test_spmt.py b/tests/python/shmem/test_spmt.py new file mode 100644 index 00000000..71f3c750 --- /dev/null +++ b/tests/python/shmem/test_spmt.py @@ -0,0 +1,137 @@ +# Copyright © Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Single-Process Multi-Thread (SPMT) shmem tests. + +Each test spawns one Python thread per GPU. Threads call ShmemInit via the +socket-bootstrap UniqueId path (no MPI, no torch.distributed). This is the +JAX/XLA model: one process, N threads, each thread owns one GPU. + +Requires MORI_MULTITHREAD_SUPPORT to be compiled in. +""" +import threading +import traceback + +import pytest +import torch + +import mori.shmem as shmem + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_num_gpus() -> int: + return torch.cuda.device_count() + + +def _thread_worker( + rank: int, + world_size: int, + unique_id: bytes, + barrier: threading.Barrier, + results: list, +): + """Per-GPU thread body: init, verify, malloc/free, barrier, finalize.""" + error = None + try: + torch.cuda.set_device(rank) + + # Phase 1: all threads have their device set → init together + barrier.wait() + + ret = shmem.shmem_init_attr( + shmem.MORI_SHMEM_INIT_WITH_UNIQUEID, rank, world_size, unique_id + ) + assert ret == 0, f"shmem_init_attr failed: {ret}" + + my_pe = shmem.shmem_mype() + n_pes = shmem.shmem_npes() + assert my_pe == rank, f"pe mismatch: expected {rank}, got {my_pe}" + assert n_pes == world_size, f"npes mismatch: expected {world_size}, got {n_pes}" + + # Phase 2: all inited → barrier + barrier.wait() + shmem.shmem_barrier_all() + + # Phase 3: symmetric malloc / free + ptr = shmem.shmem_malloc(4096) + assert ptr != 0, "shmem_malloc returned NULL" + barrier.wait() + shmem.shmem_barrier_all() + shmem.shmem_free(ptr) + + # Phase 4: finalize + barrier.wait() + shmem.shmem_finalize() + + except 
Exception: + error = traceback.format_exc() + + results[rank] = error + + +def _run_spmt(world_size: int): + """Spawn `world_size` threads and run SPMT shmem init/finalize cycle.""" + import os + + num_gpus = _get_num_gpus() + if world_size > num_gpus: + pytest.skip(f"Need {world_size} GPUs, only {num_gpus} available") + + # Set interface before UniqueId is generated so that the socket bootstrap + # uses the same interface for both root binding and peer connections. + os.environ["MORI_SOCKET_IFNAME"] = "lo" + + unique_id = shmem.shmem_get_unique_id() + + barrier = threading.Barrier(world_size) + results = [None] * world_size + threads = [ + threading.Thread( + target=_thread_worker, + args=(rank, world_size, unique_id, barrier, results), + daemon=True, + ) + for rank in range(world_size) + ] + + for t in threads: + t.start() + for t in threads: + t.join(timeout=60) + assert not t.is_alive(), f"Thread {t.name} timed out" + + for rank, err in enumerate(results): + assert err is None, f"Thread {rank} failed:\n{err}" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("world_size", [2, 4, 8]) +def test_spmt_shmem_init_finalize(world_size): + """SPMT: N threads each call ShmemInit + ShmemFinalize on their own GPU.""" + _run_spmt(world_size)