Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ option(ENABLE_PROFILER "Enable kernel profiling" OFF)
option(ENABLE_DEBUG_PRINTF "Enable debug printf in device kernels" OFF)
option(ENABLE_STANDARD_MOE_ADAPT "Enable standard moe adapt" OFF)
option(BUILD_OPS_DEVICE "AOT compile EP kernels to .hsaco(requires hipcc)" OFF)
option(MORI_MULTITHREAD_SUPPORT
"Enable single-process multi-thread multi-GPU support in mori_shmem" OFF)

# ---------------------------------------------------------------------------
# Host NIC: all vendor-specific dv libraries (mlx5, bnxt_re, ionic) are loaded
Expand Down
4 changes: 4 additions & 0 deletions include/mori/application/application_device_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ struct SymmMemObj {
// SdmaPutThread writes ATOMIC to peerSignalPtrs[remotePe] + myPe*sdmaNumQueue + qId,
// so the remote PE can directly read its own signalPtrs to detect completion.
HSAuint64** peerSignalPtrs = nullptr; // should only placed on GPU
// Host-side copy of peer signal pointers for IPC cleanup during deregistration.
// Only entries opened via hipIpcOpenMemHandle need closing; same-process (SPMT)
// entries are raw VA and must NOT be closed.
HSAuint64** peerSignalPtrsHost = nullptr; // should only placed on CPU

__device__ __host__ RdmaMemoryRegion GetRdmaMemoryRegion(int pe) const {
RdmaMemoryRegion mr;
Expand Down
14 changes: 11 additions & 3 deletions include/mori/application/context/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class Context {
int LocalRank() const { return bootNet.GetLocalRank(); }
int WorldSize() const { return bootNet.GetWorldSize(); }
int LocalRankInNode() const { return rankInNode; }
std::string HostName() const;
const std::string& HostName() const { return myHostname; }

TransportType GetTransportType(int destRank) const { return transportTypes[destRank]; }
std::vector<TransportType> GetTransportTypes() const { return transportTypes; }
const std::vector<TransportType>& GetTransportTypes() const { return transportTypes; }
int GetNumQpPerPe() const { return numQpPerPe; }

RdmaContext* GetRdmaContext() const { return rdmaContext.get(); }
Expand All @@ -51,6 +51,8 @@ class Context {

// Check if P2P connection is possible with a peer (same node)
bool CanUseP2P(int destRank) const;
// Check if peer is in the same OS process (enables direct pointer access, skip IPC handle)
bool SameProcessP2P(int destRank) const;

// Cached env-var snapshot taken at construction time. All later code MUST
// consult these (not getenv) so that env-var changes after Context init
Expand All @@ -67,14 +69,20 @@ class Context {
void CollectHostNames();
void InitializePossibleTransports();

struct PeerInfo {
bool sameHost{false}; // on the same node (same hostname+IP)
bool sameProcess{false}; // in the same OS process (same pid + same host)
};

private:
BootstrapNetwork& bootNet;
int rankInNode{-1};
int numQpPerPe{4};
// Snapshotted at construction; see IsSdmaEnabled() / IsP2PDisabled() above.
bool sdmaEnabled{false};
bool p2pDisabled{false};
std::vector<std::string> hostnames;
std::string myHostname;
std::vector<PeerInfo> peerInfos;
std::vector<TransportType> transportTypes;

std::unique_ptr<RdmaContext> rdmaContext{nullptr};
Expand Down
10 changes: 9 additions & 1 deletion include/mori/application/transport/sdma/anvil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,16 @@ class AnvilLib {

int getSdmaEngineId(int srcDeviceId, int dstDeviceId);

struct PairHash {
std::size_t operator()(const std::pair<int, int>& p) const {
return std::hash<int>()(p.first) ^ (std::hash<int>()(p.second) << 16);
}
};

std::once_flag init_flag;
std::unordered_map<int, std::vector<std::unique_ptr<SdmaQueue>>> sdma_channels_;
std::mutex channels_mutex_;
std::unordered_map<std::pair<int, int>, std::vector<std::unique_ptr<SdmaQueue>>, PairHash>
sdma_channels_;
};

extern AnvilLib& anvil;
Expand Down
110 changes: 74 additions & 36 deletions include/mori/shmem/internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
#include "mori/application/application_device_types.hpp"
#include "mori/core/utils.hpp"
#include "mori/hip_compat.hpp"
#include "mori/utils/limits.hpp"

// Host-only includes: STL, ibverbs, application management classes.
// Guarded so device compilation units (.hip files) do not pull them in.
#if !defined(__HIPCC__) && !defined(__CUDACC__)
#include <array>
#include <iostream>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -95,38 +97,10 @@ struct MemoryStates {
application::SymmMemObjPtr vmmHeapObj; // SymmMemObj for the entire heap
};

enum ShmemStatesStatus {
New = 0,
Initialized = 1,
Finalized = 2,
};

struct ShmemStates {
ShmemStatesStatus status{ShmemStatesStatus::New};
ShmemMode mode{ShmemMode::StaticHeap}; // Default to static heap mode
BootStates* bootStates{nullptr};
RdmaStates* rdmaStates{nullptr};
MemoryStates* memoryStates{nullptr};

// This is a temporary API for debugging only
void CheckStatusValid() {
if (status == ShmemStatesStatus::New) {
std::cout
<< "Shmem state is not initialized, initialize it by calling ShmemMpiInitialize first."
<< std::endl;
assert(false);
}
if (status == ShmemStatesStatus::Finalized) {
std::cout << "Shmem state has been finalized." << std::endl;
assert(false);
}
}
};

#endif // !defined(__HIPCC__) && !defined(__CUDACC__)

/* ---------------------------------------------------------------------------------------------- */
/* Device-safe GPU-side structures */
/* Device-safe GPU-side structures */
/* ---------------------------------------------------------------------------------------------- */

// GPU-side RDMA endpoint: only the fields used by device kernels.
Expand Down Expand Up @@ -155,6 +129,7 @@ struct ShmemRdmaEndpoint {
}
};

// GpuStates must be declared before ModuleStates and ShmemStates which embed it.
struct GpuStates {
int rank{-1};
int worldSize{-1};
Expand Down Expand Up @@ -196,19 +171,82 @@ struct RemoteAddrInfo {

#if !defined(__HIPCC__) && !defined(__CUDACC__)

enum ShmemStatesStatus {
New = 0,
Initialized = 1,
// Finalized: reserved. ShmemFinalize() currently resets the slot to `New`
// so the same GPU can be re-initialized later (needed by SPMT test suites
// that run multiple init/finalize cycles). Keep this value for the case
// where future finalize semantics need to mark the slot as terminally done.
Finalized = 2,
};

// Per-GPU JIT module state (HIP module handle + device symbol pointers)
struct ModuleStates {
hipModule_t module{nullptr};
GpuStates* gpuStatesPtr{nullptr}; // device-side globalGpuStates address in JIT module
hipFunction_t barrierFunc{nullptr};
};

struct ShmemStates {
ShmemStatesStatus status{ShmemStatesStatus::New};
ShmemMode mode{ShmemMode::StaticHeap}; // Default to static heap mode
BootStates* bootStates{nullptr};
RdmaStates* rdmaStates{nullptr};
MemoryStates* memoryStates{nullptr};
ModuleStates moduleStates; // JIT module state for this GPU
GpuStates gpuStates; // host-side copy of device GpuStates for this GPU

// Asserts that ShmemInit has been called and the slot is currently usable.
// Used by APIs that touch GPU state (allocation, barrier, module init)
// which need a fully constructed slot.
void CheckStatusValid() {
if (status == ShmemStatesStatus::New) {
std::cout << "Shmem state is not initialized, call ShmemInit*/shmem_init_attr first."
<< std::endl;
assert(false);
}
if (status == ShmemStatesStatus::Finalized) {
std::cout << "Shmem state has been finalized." << std::endl;
assert(false);
}
}
};

// Internal functions shared between init.cpp and runtime.cpp
void CopyGpuStatesToDevice(const GpuStates* gpuStates);
void FinalizeRuntime();
extern GpuStates s_hostGpuStatesCopy;
void CopyGpuStatesToDevice(ShmemStates* states);
void FinalizeRuntime(ShmemStates* states);

class ShmemStatesSingleton {
public:
ShmemStatesSingleton(const ShmemStatesSingleton& obj) = delete;

static ShmemStates* GetInstance() {
static ShmemStates states;
return &states;
}
static ShmemStates* GetInstance();

#ifdef MORI_MULTITHREAD_SUPPORT
// SPMT: rank → HIP device id mapping, populated at ShmemInit.
//
// Needed by FFI/custom-call handlers (e.g. XLA) that run on framework worker
// threads where hipGetDevice() does not return the rank's device. The handler
// can look up the device for a given rank and hipSetDevice() to it before
// touching MORI state.
//
// Returns -1 if no rank-to-device mapping has been recorded yet (caller
// should fall back to hipGetDevice()-based lookup or fail loudly).
static void RegisterRankDevice(int rank, int deviceId);
static int GetDeviceByRank(int rank);
#endif

private:
#ifdef MORI_MULTITHREAD_SUPPORT
// One ShmemStates slot per GPU, indexed by hipGetDevice().
// std::array gives stable addresses (no realloc unlike deque/vector).
// No lock needed: SPMT contract is one thread per GPU, so each slot is
// accessed serially by its owning thread; the rank → device map below is
// the only structure that needs cross-thread synchronization.
std::array<ShmemStates, mori::kMaxGpusPerNode> states_{};
ShmemStatesSingleton() = default;
#endif
};

#endif // !defined(__HIPCC__) && !defined(__CUDACC__)
Expand Down
33 changes: 33 additions & 0 deletions include/mori/utils/limits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright © Advanced Micro Devices, Inc. All rights reserved.
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#pragma once

// Tiny header so any TU (device or host) can size per-GPU arrays without
// pulling in shmem/ops headers. Bump this when supporting nodes with > 8 GPUs
// (e.g. future MI400 platforms) — every per-GPU array sized by this constant
// will pick up the new value automatically.

namespace mori {

inline constexpr int kMaxGpusPerNode = 8;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with CPX it is 32. Not sure if it can be split even more than that.

Copy link
Copy Markdown
Contributor

@i-chaochen i-chaochen May 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will have 32 72 GPUs at rack level. It's best to not hardcode this or we should get this max number of GPU from the build script?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, CPX and rack level will have more than 8 GPUs. We have already tried this on CPX, but it was limited to a single card before... Additionally, the rack level is also included in our plan...
So currently, kMaxGpusPerNode equals to 8...


} // namespace mori
15 changes: 12 additions & 3 deletions include/mori/utils/mori_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,18 @@ class ModuleLogger {
// Use existing logger
logger = existing_logger;
} else {
// Create new logger
logger = spdlog::stdout_color_mt(moduleName);
logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%P] [%n] [%^%l%$] %v");
// spdlog::stdout_color_mt throws if another thread already registered the same name
// between our spdlog::get() check and this call — catch and fall back to the winner.
try {
logger = spdlog::stdout_color_mt(moduleName);
logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%P] [%n] [%^%l%$] %v");
} catch (const spdlog::spdlog_ex&) {
logger = spdlog::get(moduleName);
}
// Defensive: spdlog::get may still return null if registration was
// dropped between throw and our second lookup. Bail out cleanly
// instead of dereferencing a null shared_ptr below.
if (!logger) return;
}

// Determine the log level priority: env var > global setting > provided level
Expand Down
2 changes: 1 addition & 1 deletion python/mori/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def get_dispatch_src_token_pos(self, total_recv_token_num):
cpp.EpDispatchCombineKernelType.IntraNode.value,
cpp.EpDispatchCombineKernelType.InterNodeV1.value,
cpp.EpDispatchCombineKernelType.InterNodeV1LL.value,
cpp.EpDispatchCombineKernelType.AsyncLL.value,
):
# here we need to allocate enough space to accomodate handle->totalRecvTokenNum[0] items
n_tokens = self.config.max_num_tokens_to_recv()
return jax.ffi.ffi_call(
"mori_ep",
Expand Down
67 changes: 58 additions & 9 deletions python/mori/shmem/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,64 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import ctypes
import threading

from mori import cpp as mori_cpp

# Initialization flags
MORI_SHMEM_INIT_WITH_MPI_COMM = mori_cpp.MORI_SHMEM_INIT_WITH_MPI_COMM
MORI_SHMEM_INIT_WITH_UNIQUEID = mori_cpp.MORI_SHMEM_INIT_WITH_UNIQUEID

_shmem_module_loaded = False
# Per-GPU module loading: keyed by device ID so that each GPU context gets its
# own hipModuleLoad call even when multiple threads share one process (SPMT).
_shmem_module_lock = threading.Lock()
_shmem_module_loaded_gpus: set = set()
# Cached hsaco path (compilation is arch-specific, not instance-specific).
_shmem_hsaco: str = ""


def _current_hip_device() -> int:
"""Return the calling thread's current HIP device id.

Uses ctypes against libamdhip64 directly so that the JAX path (which has
no torch dependency) works the same as the PyTorch path.
"""
from mori.jit.hip_driver import _get_hip_lib

hip = _get_hip_lib()
# Set explicit ctypes signatures — without these, ctypes assumes int args
# and int return, which happens to be right on x86_64 Linux but is not
# portable. Be explicit so future ABI changes don't silently break us.
hip.hipGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)]
hip.hipGetDevice.restype = ctypes.c_int
dev = ctypes.c_int(-1)
err = hip.hipGetDevice(ctypes.byref(dev))
if err != 0:
raise RuntimeError(f"hipGetDevice failed with error {err}")
return int(dev.value)


def _ensure_shmem_module():
"""JIT-compile and load the shmem device module before ShmemInit."""
global _shmem_module_loaded
if _shmem_module_loaded:
"""JIT-compile and load the shmem device module before ShmemInit.

Thread-safe: each GPU device context gets exactly one load_shmem_module
call, enabling single-process multi-thread (SPMT) use where each thread
owns a different GPU.
"""
device_id = _current_hip_device()
if device_id in _shmem_module_loaded_gpus:
return
from mori.jit.core import compile_genco
with _shmem_module_lock:
if device_id in _shmem_module_loaded_gpus:
return
global _shmem_hsaco
if not _shmem_hsaco:
from mori.jit.core import compile_genco

hsaco = compile_genco("shmem_kernels")
mori_cpp.load_shmem_module(hsaco)
_shmem_module_loaded = True
_shmem_hsaco = compile_genco("shmem_kernels")
mori_cpp.load_shmem_module(_shmem_hsaco)
_shmem_module_loaded_gpus.add(device_id)


def shmem_torch_process_group_init(group_name: str):
Expand Down Expand Up @@ -124,7 +163,17 @@ def shmem_finalize():
Returns:
Status code (0 for success)
"""
return mori_cpp.shmem_finalize()
ret = mori_cpp.shmem_finalize()
# Clear this GPU's module-loaded flag so a subsequent shmem_init_attr
# call (e.g. in the next test round) will reload the JIT module.
try:
device_id = _current_hip_device()
except Exception:
# If HIP context is gone (e.g. process teardown), skip cache cleanup.
return ret
with _shmem_module_lock:
_shmem_module_loaded_gpus.discard(device_id)
return ret


def shmem_module_init(hip_module: int):
Expand Down
Loading
Loading