diff --git a/lmcache/v1/distributed/config.py b/lmcache/v1/distributed/config.py index 5bdf503ef4..fe01d61bb1 100644 --- a/lmcache/v1/distributed/config.py +++ b/lmcache/v1/distributed/config.py @@ -6,7 +6,7 @@ # Standard from dataclasses import dataclass, field -from typing import Literal +from typing import TYPE_CHECKING, Literal, Optional import argparse # First Party @@ -16,6 +16,10 @@ parse_args_to_l2_adapters_config, ) +if TYPE_CHECKING: + # First Party + from lmcache.v1.distributed.maru_memory_allocator import MaruL1Config + @dataclass class L1MemoryManagerConfig: @@ -24,10 +28,11 @@ class L1MemoryManagerConfig: """ size_in_bytes: int - """ The size of L1 memory in bytes. """ + """ The size of L1 memory in bytes. (Ignored when ``maru_config`` is set.) """ use_lazy: bool - """ Whether to use lazy initialization for L1 memory. """ + """ Whether to use lazy initialization for L1 memory. + (Ignored when ``maru_config`` is set.) """ init_size_in_bytes: int = field(default=20 << 30) """ The initial size when using lazy allocation. Default is 20GB. """ @@ -35,8 +40,15 @@ class L1MemoryManagerConfig: align_bytes: int = field(default=0x1000) """ The alignment size in bytes. Default is 4KB. """ + maru_config: Optional["MaruL1Config"] = None + """ Optional Maru backend config. When set, the L1 allocator is + constructed as ``MaruMemoryAllocator`` (CXL-backed) and the DRAM + fields above are ignored. """ + def __post_init__(self): - self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes) + # The DRAM init-size clamp only makes sense for default backends. + if self.maru_config is None: + self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes) @dataclass @@ -152,6 +164,38 @@ def add_storage_manager_args( help="The alignment size in bytes. Default is 4KB (4096 bytes).", ) + # Maru L1 backend (optional). 
When --maru-server-url is set, the + # L1 allocator becomes CXL-backed and the DRAM L1 settings above + # (--l1-size-gb / --l1-use-lazy / --l1-init-size-gb) are ignored. + # Pass ``--l1-size-gb 0`` in that case to satisfy the required flag. + maru_group = parser.add_argument_group( + "Maru L1 Backend", + "Optional CXL-backed L1 via Maru. Overrides DRAM L1 settings.", + ) + maru_group.add_argument( + "--maru-server-url", + type=str, + default=None, + help="MaruServer endpoint (e.g. maru://host:port or tcp://host:port). " + "When set, the L1 allocator is CXL-backed and the DRAM L1 settings " + "(--l1-size-gb, --l1-use-lazy, --l1-init-size-gb) are ignored.", + ) + maru_group.add_argument( + "--maru-pool-size-gb", + type=float, + default=0.0, + help="CXL pool size to request from MaruServer (GB). " + "Required when --maru-server-url is set.", + ) + maru_group.add_argument( + "--maru-instance-id", + type=str, + default=None, + help="Stable client identifier reported to MaruServer for ownership " + "tracking and restart recovery. Auto-generated if omitted " + "(acceptable for single-node setups; recommended for multi-node).", + ) + # L1 Manager Config (TTL settings) ttl_group = parser.add_argument_group( "L1 Manager TTL", "TTL configuration for L1 manager locks" @@ -274,11 +318,27 @@ def parse_args_to_config( Returns: StorageManagerConfig: The configuration object. 
""" + maru_config: Optional["MaruL1Config"] = None + if args.maru_server_url is not None: + if args.maru_pool_size_gb <= 0: + raise ValueError( + "--maru-pool-size-gb must be positive when --maru-server-url is set" + ) + # First Party + from lmcache.v1.distributed.maru_memory_allocator import MaruL1Config + + maru_config = MaruL1Config( + server_url=args.maru_server_url, + pool_size_bytes=int(args.maru_pool_size_gb * (1 << 30)), + instance_id=args.maru_instance_id, + ) + memory_config = L1MemoryManagerConfig( size_in_bytes=int(args.l1_size_gb * (1 << 30)), use_lazy=args.l1_use_lazy, init_size_in_bytes=int(args.l1_init_size_gb * (1 << 30)), align_bytes=args.l1_align_bytes, + maru_config=maru_config, ) l1_manager_config = L1ManagerConfig( diff --git a/lmcache/v1/distributed/internal_api.py b/lmcache/v1/distributed/internal_api.py index dcc660bfd9..f3aa28decc 100644 --- a/lmcache/v1/distributed/internal_api.py +++ b/lmcache/v1/distributed/internal_api.py @@ -10,6 +10,13 @@ # First Party from lmcache.v1.distributed.api import ObjectKey +from lmcache.v1.distributed.error import L1Error +from lmcache.v1.memory_management import MemoryObj + +L1OperationResult = tuple[L1Error, MemoryObj | None] +""" Result tuple returned by L1Manager (and its maru dispatcher) +read/write reservation methods: ``(error, memory_obj)``. ``memory_obj`` +is ``None`` whenever ``error != L1Error.SUCCESS``. 
""" @dataclass(frozen=True) diff --git a/lmcache/v1/distributed/l1_manager.py b/lmcache/v1/distributed/l1_manager.py index e4e4379f30..7d3506101c 100644 --- a/lmcache/v1/distributed/l1_manager.py +++ b/lmcache/v1/distributed/l1_manager.py @@ -5,18 +5,22 @@ # Standard from dataclasses import dataclass -from typing import Literal +from typing import Literal, Optional import threading +# Third Party +import torch + # First Party from lmcache.logging import init_logger from lmcache.native_storage_ops import TTLLock from lmcache.v1.distributed.api import MemoryLayoutDesc, ObjectKey from lmcache.v1.distributed.config import L1ManagerConfig from lmcache.v1.distributed.error import L1Error -from lmcache.v1.distributed.internal_api import L1ManagerListener -from lmcache.v1.distributed.memory_manager import L1MemoryManager -from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.distributed.internal_api import L1ManagerListener, L1OperationResult +from lmcache.v1.distributed.maru_l1_dispatch import MaruL1Dispatcher +from lmcache.v1.distributed.memory_manager import L1MemoryManager, _is_maru_allocator +from lmcache.v1.memory_management import MemoryFormat, MemoryObj from lmcache.v1.mp_observability.event import Event, EventType from lmcache.v1.mp_observability.event_bus import get_event_bus from lmcache.v1.mp_observability.otel_init import register_gauge @@ -78,8 +82,6 @@ def wrapper(self: "L1Manager", *args, **kwargs): return wrapper -L1OperationResult = tuple[L1Error, MemoryObj | None] - # Upper bound for the count parameter in reserve_read / finish_read # to prevent a single call from holding the global lock for too long. 
MAX_READ_LOCK_COUNT = 128 @@ -193,6 +195,28 @@ def __init__(self, config: L1ManagerConfig): self._event_bus = get_event_bus() + # When the L1 allocator is ``MaruMemoryAllocator``, L1Manager + # operates in pass-through mode: the state machine / TTLLock / + # eviction policy are bypassed and MaruServer RPCs are issued + # directly via :class:`MaruL1Dispatcher`, which encapsulates + # the maru-specific handler reference and read-side channel. + self._maru_dispatcher: Optional[MaruL1Dispatcher] = None + if _is_maru_allocator(self._memory_manager.allocator): + # Lazy-import the concrete allocator class so the maru + # runtime stays optional for non-maru deployments. + # First Party + from lmcache.v1.distributed.maru_memory_allocator import ( + MaruMemoryAllocator, + ) + + assert isinstance(self._memory_manager.allocator, MaruMemoryAllocator) + self._maru_dispatcher = MaruL1Dispatcher( + allocator=self._memory_manager.allocator, + memory_manager=self._memory_manager, + write_ttl_seconds=self._write_ttl_seconds, + read_ttl_seconds=self._read_ttl_seconds, + ) + L1Manager._gauge_target = self if not L1Manager._gauge_registered: L1Manager._gauge_registered = True @@ -213,12 +237,55 @@ def __init__(self, config: L1ManagerConfig): lambda: _l1_usage_ratio_or_zero(L1Manager._gauge_target), ) + def _is_maru_backend(self) -> bool: + """True when the L1 allocator is ``MaruMemoryAllocator``. + + In maru mode, L1Manager operates as a pass-through shim: + - The object dict / TTLLock state machine / eviction policy are + all skipped (the engine flow goes straight to MaruServer via + the dispatcher). + - Listeners are NOT invoked (controllers / observability paths + are bypassed). + """ + return self._maru_dispatcher is not None + + def register_kv_layout( + self, + shapes: list[torch.Size], + dtypes: list[torch.dtype], + fmt: MemoryFormat, + chunk_size_in_tokens: int, + ) -> None: + """Bind the KV layout to the underlying memory manager. 
+ + Forwarded from ``StorageManager.register_kv_layout``, which is + in turn invoked by ``MPCacheEngine.register_kv_cache`` after a + vLLM worker exposes its KV cache tensors. Only the maru + backend acts on the call; default DRAM allocators are + layout-agnostic so it is a no-op for them. + + Args: + shapes: KV chunk shapes (per-layer-group). + dtypes: KV chunk dtypes aligned with ``shapes``. + fmt: Memory format. + chunk_size_in_tokens: LMCache chunk size in tokens. + """ + self._memory_manager.register_kv_layout( + shapes, dtypes, fmt, chunk_size_in_tokens + ) + def register_listener(self, listener: L1ManagerListener) -> None: """Register a listener for L1Manager events. Args: listener: The listener to register. """ + if self._is_maru_backend(): + # Maru mode bypasses StoreController / PrefetchController / + # L1EvictionController, so listener callbacks are never + # invoked. Registration is silently dropped to keep the API + # surface stable. + return with self._lock: self._registered_listeners.append(listener) @@ -248,6 +315,9 @@ def reserve_read( KEY_NOT_READABLE: The key exists but is not readable. """ + if self._maru_dispatcher is not None: + return self._maru_dispatcher.reserve_read(keys) + extra_count = _validate_extra_count(extra_count) total = 1 + extra_count ret: dict[ObjectKey, L1OperationResult] = {} @@ -302,6 +372,9 @@ def unsafe_read( KEY_NOT_EXIST: The key does not exist. KEY_NOT_READABLE: The key is not readable (in this case, not read-locked). """ + if self._maru_dispatcher is not None: + return self._maru_dispatcher.unsafe_read(keys) + ret: dict[ObjectKey, L1OperationResult] = {} for key in keys: @@ -347,6 +420,9 @@ def finish_read( non-read-locked, which means the reader may read inconsistent data. 
""" + if self._maru_dispatcher is not None: + return self._maru_dispatcher.finish_read(keys) + extra_count = _validate_extra_count(extra_count) total = 1 + extra_count need_to_free: list[MemoryObj] = [] @@ -446,6 +522,11 @@ def reserve_write( KEY_NOT_WRITABLE: The key exists but is not writable. OUT_OF_MEMORY: Not enough memory to allocate for the object. """ + if self._maru_dispatcher is not None: + return self._maru_dispatcher.reserve_write( + keys, is_temporary, layout_desc, mode + ) + need_to_allocate: list[tuple[ObjectKey, bool]] = [] ret: dict[ObjectKey, L1OperationResult] = {} successful_keys: list[ObjectKey] = [] @@ -518,11 +599,18 @@ def reserve_write( def finish_write( self, keys: list[ObjectKey], + memory_objs: Optional[list[MemoryObj]] = None, ) -> dict[ObjectKey, L1Error]: """Finish write access for the given keys. Args: keys: The list of object keys to finish write access for. + memory_objs: The ``MemoryObj`` instances corresponding to + ``keys``. **Required in maru mode** (used to issue + ``MaruHandler.batch_store``); ignored in default mode + (the in-process dict already holds the MemoryObj). + Defaults to ``None`` for backward compatibility with + callers that only update L1 state. Returns: A dictionary mapping each object key to an L1Error. @@ -532,6 +620,9 @@ def finish_write( KEY_IN_WRONG_STATE: The key is not write-locked, or it's read-locked, which means the writer may have caused inconsistent data. """ + if self._maru_dispatcher is not None: + return self._maru_dispatcher.finish_write(keys, memory_objs) + ret: dict[ObjectKey, L1Error] = {} successful_keys: list[ObjectKey] = [] @@ -602,6 +693,15 @@ def finish_write_and_reserve_read( KEY_IN_WRONG_STATE: The key is not write-locked, or it already has read locks. 
""" + if self._maru_dispatcher is not None: + # Maru flow stages MemoryObjs in the side channel during + # ``reserve_read`` and the engine transitions straight to + # ``unsafe_read`` → ``finish_read``, so this atomic + # write-to-read transition is never exercised by the maru + # path. Return a safe SUCCESS response in case any caller + # still invokes it. + return self._maru_dispatcher.finish_write_and_reserve_read(keys) + extra_count = _validate_extra_count(extra_count) total = 1 + extra_count ret: dict[ObjectKey, L1OperationResult] = {} @@ -661,6 +761,9 @@ def delete(self, keys: list[ObjectKey]) -> dict[ObjectKey, L1Error]: KEY_IS_LOCKED: The key is locked (either write-locked or read-locked and cannot be deleted). """ + if self._maru_dispatcher is not None: + return self._maru_dispatcher.delete(keys) + need_to_free: list[MemoryObj] = [] ret: dict[ObjectKey, L1Error] = {} successful_keys: list[ObjectKey] = [] @@ -698,6 +801,10 @@ def touch_keys(self, keys: list[ObjectKey]): Args: keys: The list of object keys to touch. """ + if self._is_maru_backend(): + # No LRU bookkeeping in maru mode — MaruServer owns eviction + # decisions and ``touch_keys`` has no observable effect. + return for listener in self._registered_listeners: listener.on_l1_keys_accessed(keys) @@ -711,6 +818,10 @@ def clear(self, force: bool = False) -> None: If False (default), only clear unlocked objects, keeping write-locked and read-locked objects intact. """ + if self._maru_dispatcher is not None: + self._maru_dispatcher.clear(force) + return + if force: logger.warning( "L1Manager: force-clearing all %d objects " @@ -782,6 +893,14 @@ def is_key_evictable(self, key: ObjectKey) -> bool: True if the key exists and is not locked (neither read-locked nor write-locked), False otherwise. """ + if self._is_maru_backend(): + # L1EvictionController is not registered in maru mode, so + # this method is never consulted on the hot path. 
Return + # True to keep the contract simple for any defensive caller: + # MaruServer's ``pin_kv`` / ``delete_kv`` make their own + # atomic decisions and the LMCache-side answer has no + # bearing on actual eviction. + return True entry = self._objects.get(key, None) if entry is None: return False @@ -806,9 +925,15 @@ def get_l1_memory_desc(self): def close(self) -> None: """Close the L1Manager and free all resources.""" with self._lock: - all_memory_objs = [entry.memory_obj for entry in self._objects.values()] - self._memory_manager.free(all_memory_objs) - self._objects.clear() + if self._maru_dispatcher is not None: + # No in-process state machine — just drop any pending + # read-side handles so the engine can shut down cleanly. + # CXL page lifecycle remains owned by MaruServer. + self._maru_dispatcher.clear(force=False) + else: + all_memory_objs = [entry.memory_obj for entry in self._objects.values()] + self._memory_manager.free(all_memory_objs) + self._objects.clear() self._memory_manager.close() @@ -816,6 +941,9 @@ def close(self) -> None: @l1_mgr_synchronized def report_status(self) -> dict: """Return a status dict describing L1 cache state.""" + if self._maru_dispatcher is not None: + return self._maru_dispatcher.report_status() + write_locked = 0 read_locked = 0 temporary = 0 @@ -851,11 +979,20 @@ def get_object_state(self, key: ObjectKey) -> L1ObjectState | None: Returns: The L1ObjectState if the object exists, None otherwise. """ + if self._is_maru_backend(): + # No in-process state machine in maru mode (no TTLLock / + # is_temporary / L1ObjectState). + return None return self._objects.get(key, None) @l1_mgr_synchronized def memcheck(self) -> bool: """Perform memory check for L1 cache.""" + if self._is_maru_backend(): + # No object dict to introspect; MaruServer + handler own + # consistency. Always healthy from LMCache's vantage point. 
+ return True + mem_check_result = self._memory_manager.memcheck() # Log the locked objects for debugging diff --git a/lmcache/v1/distributed/l2_adapters/maru_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/maru_l2_adapter.py new file mode 100644 index 0000000000..097e9bd20f --- /dev/null +++ b/lmcache/v1/distributed/l2_adapters/maru_l2_adapter.py @@ -0,0 +1,796 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Maru-backed MP L2 adapter. + +Stores L1 (DRAM) ``MemoryObj`` payloads in a CXL pool managed by +``MaruServer``. The adapter keeps L1 as the default DRAM allocator +and uses ``MaruHandler`` directly for the L2 tier — engine hot path +is ``GPU ↔ DRAM (cudaMemcpy) ↔ CXL (DRAM↔CXL memcpy via adapter +worker)``. + +Registered under ``--l2-adapter '{"type":"maru",...}'`` via +``register_l2_adapter_factory``. +""" + +# Future +from __future__ import annotations + +# Standard +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Optional +import ctypes +import os +import threading + +if TYPE_CHECKING: + # First Party + from lmcache.v1.distributed.internal_api import L1MemoryDesc + +# Third Party +import numpy as np + +# First Party +from lmcache.logging import init_logger +from lmcache.native_storage_ops import Bitmap +from lmcache.v1.distributed.api import ObjectKey +from lmcache.v1.distributed.l2_adapters.base import ( + L2AdapterInterface, + L2TaskId, +) +from lmcache.v1.distributed.l2_adapters.config import ( + L2AdapterConfigBase, + register_l2_adapter_type, +) +from lmcache.v1.distributed.l2_adapters.factory import ( + register_l2_adapter_factory, +) +from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.platform import create_event_notifier + +logger = init_logger(__name__) + + +def _parse_positive_int(value, field_name: str) -> int: + if not isinstance(value, int) or value <= 0: + raise ValueError(f"{field_name} must be a positive integer") + return value + + +def _object_key_to_string(key: ObjectKey) -> 
str: + """``ObjectKey`` → MaruServer-stable string form. + + The format mirrors the L1 dispatcher's + :func:`lmcache.v1.distributed.maru_l1_dispatch.object_key_to_string` + so KV entries are interoperable between the two paths. + """ + base = f"{key.model_name}@{key.kv_rank:08x}@{key.chunk_hash.hex()}" + if key.cache_salt: + return f"{base}@{key.cache_salt}" + return base + + +def _memoryview_addr(mv: memoryview) -> int: + """Return the raw memory address of ``mv``'s first byte. + + Uses :func:`numpy.frombuffer` for a zero-copy uint8 view; the + numpy array exposes the underlying pointer via ``ctypes.data``. + Works for any writable / readable memoryview backed by contiguous + memory (which MaruHandler's ``AllocHandle.buf`` and + ``MemoryInfo.view`` both are). + """ + return int(np.frombuffer(mv, dtype=np.uint8).ctypes.data) + + +class MaruL2AdapterConfig(L2AdapterConfigBase): + """Configuration for the maru L2 adapter. + + The adapter connects to a ``MaruServer`` at ``server_url`` and + requests a CXL pool of ``pool_size_gb`` GiB. User-facing name + matches the L1 path (``--maru-pool-size-gb`` CLI flag); the value + is converted to bytes once inside :meth:`__init__` and exposed + internally as :attr:`pool_size_bytes` (matching + :class:`MaruL1Config.pool_size_bytes`). + + ``chunk_size_bytes`` (the MaruServer page size) is optional — + leaving it unset defers ``MaruHandler.connect()`` until the first + :meth:`MaruL2Adapter.submit_store_task` and derives the value + from the inbound ``MemoryObj.get_physical_size()`` (which + LMCache aligns to a full KV chunk). Spell it out only if you + need lookup-and-lock or load to fire before any store, or if + you want to pin the page size to a specific value. 
+ """ + + def __init__( + self, + *, + server_url: str, + pool_size_gb: float, + chunk_size_bytes: Optional[int] = None, + instance_id: Optional[str] = None, + num_store_workers: int = 1, + num_lookup_workers: int = 1, + num_load_workers: int = min(4, os.cpu_count() or 1), + timeout_ms: int = 5000, + use_async_rpc: bool = True, + max_inflight: int = 64, + eager_map: bool = True, + ) -> None: + """Build a validated config. + + Args: + server_url: MaruServer endpoint (``maru://host:port`` or + ``tcp://host:port``; the former is rewritten internally). + pool_size_gb: CXL pool quota requested from MaruServer + (GiB; ``1 GiB = 1 << 30 bytes``). Matches the + user-facing form of ``--maru-pool-size-gb`` used by + the L1 path. Internally converted to + :attr:`pool_size_bytes`. + chunk_size_bytes: MaruServer page / chunk size, in bytes. + Optional — when ``None``, derived from the first + ``MemoryObj`` handed to :meth:`submit_store_task`. + Set explicitly to lock the page size or to allow + lookup-first usage. + instance_id: Stable client identifier reported to MaruServer + (UUID auto-generated if ``None``). + num_store_workers: Worker threads for store tasks. + num_lookup_workers: Worker threads for lookup-and-lock tasks. + num_load_workers: Worker threads for load tasks. + timeout_ms: Socket timeout for MaruHandler RPCs. + use_async_rpc: Whether to use the DEALER-ROUTER async RPC + client (matches MaruHandler default). + max_inflight: Max concurrent in-flight async RPCs. + eager_map: Whether MaruHandler should pre-map all shared + regions on connect. + """ + self.server_url = server_url + self.pool_size_gb = float(pool_size_gb) + # Cached bytes form for internal use (matches MaruL1Config + # naming). Converted once here so downstream code doesn't + # re-multiply. 
+ self.pool_size_bytes: int = int(self.pool_size_gb * (1 << 30)) + self.chunk_size_bytes = chunk_size_bytes + self.instance_id = instance_id + self.num_store_workers = num_store_workers + self.num_lookup_workers = num_lookup_workers + self.num_load_workers = num_load_workers + self.timeout_ms = timeout_ms + self.use_async_rpc = use_async_rpc + self.max_inflight = max_inflight + self.eager_map = eager_map + + @classmethod + def from_dict(cls, d: dict) -> "MaruL2AdapterConfig": + """Build the config from a ``--l2-adapter`` JSON object. + + Args: + d: Parsed CLI JSON. + + Returns: + A validated ``MaruL2AdapterConfig``. + + Raises: + ValueError: If a required field is missing or any numeric + field fails its positivity check. + """ + server_url = d.get("server_url") + if not isinstance(server_url, str) or not server_url.strip(): + raise ValueError("server_url must be a non-empty string") + + pool_size_gb = d.get("pool_size_gb") + if not isinstance(pool_size_gb, (int, float)) or pool_size_gb <= 0: + raise ValueError("pool_size_gb must be a positive number") + + chunk_size_bytes = d.get("chunk_size_bytes") + if chunk_size_bytes is not None: + if not isinstance(chunk_size_bytes, int) or chunk_size_bytes <= 0: + raise ValueError( + "chunk_size_bytes must be a positive integer when provided" + ) + + instance_id = d.get("instance_id") + if instance_id is not None and not isinstance(instance_id, str): + raise ValueError("instance_id must be a string when provided") + + num_store_workers = _parse_positive_int( + d.get("num_store_workers", 1), "num_store_workers" + ) + num_lookup_workers = _parse_positive_int( + d.get("num_lookup_workers", 1), "num_lookup_workers" + ) + num_load_workers = _parse_positive_int( + d.get("num_load_workers", min(4, os.cpu_count() or 1)), + "num_load_workers", + ) + timeout_ms = _parse_positive_int(d.get("timeout_ms", 5000), "timeout_ms") + max_inflight = _parse_positive_int(d.get("max_inflight", 64), "max_inflight") + + use_async_rpc = 
bool(d.get("use_async_rpc", True)) + eager_map = bool(d.get("eager_map", True)) + + return cls( + server_url=server_url.strip(), + pool_size_gb=float(pool_size_gb), + chunk_size_bytes=chunk_size_bytes, + instance_id=instance_id, + num_store_workers=num_store_workers, + num_lookup_workers=num_lookup_workers, + num_load_workers=num_load_workers, + timeout_ms=timeout_ms, + use_async_rpc=use_async_rpc, + max_inflight=max_inflight, + eager_map=eager_map, + ) + + @classmethod + def help(cls) -> str: + """Return CLI help text for the maru L2 adapter config.""" + return ( + "Maru L2 adapter config fields:\n" + "- server_url (str): MaruServer endpoint, maru:// or " + "tcp:// (required)\n" + "- pool_size_gb (float): CXL pool size to request, in GiB (required, >0). " + "Same user-facing form as the L1 path's --maru-pool-size-gb.\n" + "- chunk_size_bytes (int): MaruServer page size (optional, >0). " + "When omitted, derived from the first stored MemoryObj's " + "physical size. Set explicitly to pin the size or to allow " + "lookup-first usage.\n" + "- instance_id (str): client identifier (optional; UUID if " + "omitted)\n" + "- num_store_workers (int): store worker threads " + "(optional, default 1)\n" + "- num_lookup_workers (int): lookup worker threads " + "(optional, default 1)\n" + "- num_load_workers (int): load worker threads " + "(optional, default min(4, cpu_count))\n" + "- timeout_ms (int): RPC socket timeout (optional, default 5000)\n" + "- use_async_rpc (bool): async DEALER-ROUTER (optional, default true)\n" + "- max_inflight (int): concurrent in-flight RPCs " + "(optional, default 64)\n" + "- eager_map (bool): pre-map regions on connect " + "(optional, default true)" + ) + + +class MaruL2Adapter(L2AdapterInterface): + """MP L2 adapter that stores KV chunks in a CXL pool via MaruServer. + + Threading model: + Three independent ``ThreadPoolExecutor`` pools serve store, + lookup-and-lock, and load tasks. 
Each completion signals its + dedicated ``EventNotifier`` so the store / prefetch controllers + can poll completions without cross-talk. ``MaruHandler`` itself + is thread-safe (async DEALER-ROUTER); the adapter's task + bookkeeping is guarded by ``self._lock``. + """ + + def __init__(self, config: MaruL2AdapterConfig) -> None: + """Prepare worker pools / event fds; defer MaruServer connect. + + The MaruHandler connection is built lazily on first use so + ``chunk_size_bytes`` can be derived from inbound MemoryObjs + when the user did not pin it in config. See + :meth:`_ensure_connected`. + + Args: + config: Validated ``MaruL2AdapterConfig``. + """ + super().__init__(max_capacity_bytes=int(config.pool_size_bytes)) + self._config = config + + # Handler stays ``None`` until ``_ensure_connected`` resolves + # a concrete ``chunk_size_bytes``. ``MaruHandler`` / + # ``maru_lmcache`` imports are also deferred inside + # ``_connect_handler`` so this module can load without the + # maru runtime installed (mirroring MaruMemoryAllocator). + self._handler: Optional[Any] = None + # ``None`` while the page size is still unknown; pinned to the + # config value (if set) or the first-store hint on connect. + self._chunk_size_bytes: Optional[int] = config.chunk_size_bytes + + # Three distinct event notifiers — controllers' fd-to-adapter + # dispatch maps require them to be unique per task type. + self._store_efd = create_event_notifier() + self._lookup_efd = create_event_notifier() + self._load_efd = create_event_notifier() + + # Lazily-built thread pools so close() can tear them down + # without surprising shutdown races when a task is mid-flight. 
+ self._store_executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor( + max_workers=config.num_store_workers, + thread_name_prefix="maru-l2-store", + ) + self._lookup_executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor( + max_workers=config.num_lookup_workers, + thread_name_prefix="maru-l2-lookup", + ) + self._load_executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor( + max_workers=config.num_load_workers, + thread_name_prefix="maru-l2-load", + ) + + # Task bookkeeping — shape matches the DAX adapter so the store + # / prefetch controllers see a familiar surface. + self._next_task_id: L2TaskId = 0 + self._completed_store_tasks: dict[L2TaskId, bool] = {} + self._completed_lookup_tasks: dict[L2TaskId, Bitmap] = {} + self._completed_load_tasks: dict[L2TaskId, Bitmap] = {} + self._inflight_store_tasks = 0 + self._inflight_lookup_tasks = 0 + self._inflight_load_tasks = 0 + + self._lock = threading.Lock() + self._closed = False + + # ------------------------------------------------------------------ + # Construction helpers + # ------------------------------------------------------------------ + + @staticmethod + def _connect_handler(config: MaruL2AdapterConfig, chunk_size_bytes: int) -> Any: + """Build and connect a ``MaruHandler`` with the resolved page size. + + ``maru`` is imported lazily so the adapter module can be + loaded without the maru runtime installed. ``chunk_size_bytes`` + is taken as an explicit argument rather than from ``config`` + because the lazy path may derive it from a store payload at + first use. + + Args: + config: Adapter configuration (everything except chunk size). + chunk_size_bytes: Resolved MaruServer page size. + + Returns: + A connected ``MaruHandler``. + + Raises: + RuntimeError: If ``MaruHandler.connect()`` fails. 
+ """ + # Third Party + from maru import MaruConfig, MaruHandler + + server_url = config.server_url + if server_url.startswith("maru://"): + server_url = "tcp://" + server_url[len("maru://") :] + + maru_config = MaruConfig( + server_url=server_url, + instance_id=config.instance_id, + pool_size=int(config.pool_size_bytes), + chunk_size_bytes=chunk_size_bytes, + auto_connect=False, + timeout_ms=config.timeout_ms, + use_async_rpc=config.use_async_rpc, + max_inflight=config.max_inflight, + eager_map=config.eager_map, + ) + + handler = MaruHandler(maru_config) + if not handler.connect(): + raise RuntimeError(f"Failed to connect MaruHandler to {config.server_url}") + logger.info( + "[MaruL2Adapter] connected: server=%s instance_id=%s " + "pool_bytes=%s chunk_size_bytes=%d", + config.server_url, + handler.instance_id, + config.pool_size_bytes, + chunk_size_bytes, + ) + return handler + + def _ensure_connected(self, hint_size: Optional[int] = None) -> Any: + """Lazy MaruHandler bring-up. + + On first call the page size is resolved (config value wins; + otherwise ``hint_size`` is used) and the handler is connected. + Subsequent calls just return the cached handler. + + Args: + hint_size: A page-size suggestion, typically + ``MemoryObj.get_physical_size()`` of the first store + payload. Ignored once ``chunk_size_bytes`` is known. + + Returns: + The connected ``MaruHandler``. + + Raises: + RuntimeError: If the adapter has been closed, or if no + page size is known and no ``hint_size`` is provided + (which happens when a lookup / load fires before any + store and the user did not pin ``chunk_size_bytes`` + in config). + """ + with self._lock: + if self._handler is not None: + return self._handler + self._ensure_open_locked() + chunk_size_bytes = self._chunk_size_bytes + if chunk_size_bytes is None: + if hint_size is None or hint_size <= 0: + raise RuntimeError( + "MaruL2Adapter: chunk_size_bytes is unknown — either " + "set it in config or run a store before lookup / load." 
+ ) + chunk_size_bytes = hint_size + self._handler = self._connect_handler(self._config, chunk_size_bytes) + self._chunk_size_bytes = chunk_size_bytes + return self._handler + + def _get_next_task_id_locked(self) -> L2TaskId: + """Return a fresh task id; caller must hold ``self._lock``.""" + task_id = self._next_task_id + self._next_task_id += 1 + return task_id + + def _ensure_open_locked(self) -> None: + """Raise if the adapter has been closed; caller must hold ``self._lock``.""" + if self._closed: + raise RuntimeError("MaruL2Adapter has been closed") + + def _signal_eventfd(self, notifier) -> None: + """Wake any controller waiting on this notifier. + + Failures during shutdown (eventfd already closed) are + downgraded to debug logs — the calling worker is already on + its way out. + """ + try: + notifier.notify() + except OSError: + logger.debug("MaruL2Adapter: eventfd notify skipped (closed)") + + # ------------------------------------------------------------------ + # Event fd accessors + # ------------------------------------------------------------------ + + def get_store_event_fd(self) -> int: + return self._store_efd.fileno() + + def get_lookup_and_lock_event_fd(self) -> int: + return self._lookup_efd.fileno() + + def get_load_event_fd(self) -> int: + return self._load_efd.fileno() + + # ------------------------------------------------------------------ + # Store path + # ------------------------------------------------------------------ + + def submit_store_task( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> L2TaskId: + """Submit an asynchronous DRAM→CXL store task. + + Args: + keys: Object keys to register with MaruServer. + objects: Aligned L1 (DRAM) ``MemoryObj`` instances whose + bytes will be copied into freshly-allocated CXL pages. + + Returns: + Task id usable with :meth:`pop_completed_store_tasks`. 
+ """ + if len(keys) != len(objects): + raise ValueError( + f"keys and objects length mismatch ({len(keys)} vs {len(objects)})" + ) + + with self._lock: + self._ensure_open_locked() + task_id = self._get_next_task_id_locked() + self._inflight_store_tasks += 1 + + assert self._store_executor is not None + self._store_executor.submit(self._execute_store_task, task_id, keys, objects) + return task_id + + def _execute_store_task( + self, + task_id: L2TaskId, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> None: + """Worker entry: alloc CXL page, memcpy DRAM→CXL, batch_store.""" + success = True + stored_sizes: list[int] = [] + try: + # First store doubles as the lazy connect trigger — the + # MemoryObj's physical size sets the page size if config + # didn't pin it. ``get_physical_size`` is LMCache's + # chunk-aligned byte count. + hint = objects[0].get_physical_size() if objects else None + handler = self._ensure_connected(hint_size=hint) + + handles: list[Any] = [] + for obj in objects: + size = obj.get_size() + handle = handler.alloc(size) + # DRAM → CXL byte copy. ``obj.data_ptr`` is the + # source (L1 DRAM); the CXL page is mapped behind + # ``handle.buf`` (zero-copy memoryview). + dst_addr = _memoryview_addr(handle.buf) + ctypes.memmove(dst_addr, obj.data_ptr, size) + handles.append(handle) + stored_sizes.append(size) + + key_strs = [_object_key_to_string(k) for k in keys] + results = handler.batch_store(key_strs, handles) + # ``batch_store`` returns per-key flags. Treat dup-skip + # (True from server) as success — the KV is in maru regardless. + success = all(results) + if not success: + # Partial-failure aggregation isn't supported by the + # store task contract — surface the overall flag. 
+ logger.warning( + "MaruL2Adapter: batch_store reported some failures " + "(task_id=%d, total=%d, ok=%d)", + task_id, + len(results), + sum(1 for r in results if r), + ) + except Exception: + logger.exception( + "MaruL2Adapter: store task %d failed (keys=%d)", + task_id, + len(keys), + ) + success = False + + with self._lock: + self._completed_store_tasks[task_id] = success + self._inflight_store_tasks -= 1 + + if success and stored_sizes: + self._notify_keys_stored(keys, stored_sizes) + self._signal_eventfd(self._store_efd) + + def pop_completed_store_tasks(self) -> dict[L2TaskId, bool]: + """Hand the controller every completed store task at once.""" + with self._lock: + out = self._completed_store_tasks + self._completed_store_tasks = {} + return out + + # ------------------------------------------------------------------ + # Lookup-and-lock path + # ------------------------------------------------------------------ + + def submit_lookup_and_lock_task(self, keys: list[ObjectKey]) -> L2TaskId: + """Submit an asynchronous lookup-and-lock task. + + Returns a bitmap (via :meth:`query_lookup_and_lock_result`) + where bit ``i`` is set when key ``i`` is present + pinned. + Maru's pin contract is prefix-stop (the server stops at the + first miss) — the bitmap reflects that. + """ + with self._lock: + self._ensure_open_locked() + task_id = self._get_next_task_id_locked() + self._inflight_lookup_tasks += 1 + + assert self._lookup_executor is not None + self._lookup_executor.submit(self._execute_lookup_task, task_id, keys) + return task_id + + def _execute_lookup_task( + self, + task_id: L2TaskId, + keys: list[ObjectKey], + ) -> None: + """Worker entry: ``batch_pin`` + record prefix-bitmap.""" + bitmap = Bitmap(len(keys)) + try: + # Lookup must follow a store (or an explicit config + # ``chunk_size_bytes``) — ``_ensure_connected`` raises + # otherwise. The exception turns into an all-miss bitmap + # so the caller sees a clean "nothing cached" answer. 
+ handler = self._ensure_connected() + key_strs = [_object_key_to_string(k) for k in keys] + pin_results = handler.batch_pin(key_strs) + # Prefix-stop: first miss ends the contiguous hit run. + for i, ok in enumerate(pin_results): + if not ok: + break + bitmap.set(i) + except Exception: + logger.exception( + "MaruL2Adapter: lookup task %d failed (keys=%d)", + task_id, + len(keys), + ) + + with self._lock: + self._completed_lookup_tasks[task_id] = bitmap + self._inflight_lookup_tasks -= 1 + + self._signal_eventfd(self._lookup_efd) + + def query_lookup_and_lock_result(self, task_id: L2TaskId) -> Optional[Bitmap]: + """Pop the bitmap for ``task_id`` (single-consumer).""" + with self._lock: + return self._completed_lookup_tasks.pop(task_id, None) + + def submit_unlock(self, keys: list[ObjectKey]) -> None: + """Release prior ``submit_lookup_and_lock_task`` locks. + + Synchronous — there is no per-task completion contract for + unlock (the controller fires it and moves on). If the + handler has not been connected yet (no prior store / lookup) + the call is a silent no-op since there can be no live pins. + """ + if not keys: + return + if self._handler is None: + return + key_strs = [_object_key_to_string(k) for k in keys] + try: + self._handler.batch_unpin(key_strs) + except Exception: + logger.exception("MaruL2Adapter: batch_unpin failed (keys=%d)", len(keys)) + + # ------------------------------------------------------------------ + # Load path + # ------------------------------------------------------------------ + + def submit_load_task( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> L2TaskId: + """Submit an asynchronous CXL→DRAM load task. + + Args: + keys: Keys whose payloads to fetch from MaruServer. + objects: Pre-allocated L1 (DRAM) destination ``MemoryObj``\\ s. + + Returns: + Task id usable with :meth:`query_load_result`. 
+ """ + if len(keys) != len(objects): + raise ValueError( + f"keys and objects length mismatch ({len(keys)} vs {len(objects)})" + ) + + with self._lock: + self._ensure_open_locked() + task_id = self._get_next_task_id_locked() + self._inflight_load_tasks += 1 + + assert self._load_executor is not None + self._load_executor.submit(self._execute_load_task, task_id, keys, objects) + return task_id + + def _execute_load_task( + self, + task_id: L2TaskId, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> None: + """Worker entry: ``batch_retrieve`` + memcpy CXL→DRAM per key.""" + bitmap = Bitmap(len(keys)) + accessed: list[ObjectKey] = [] + try: + # Load can also seed the lazy connect: the L1 destination + # ``MemoryObj`` is already chunk-aligned, so we use its + # physical size as the page-size hint if no prior store + # has resolved one. + hint = objects[0].get_physical_size() if objects else None + handler = self._ensure_connected(hint_size=hint) + + key_strs = [_object_key_to_string(k) for k in keys] + mem_infos = handler.batch_retrieve(key_strs) + for i, (obj, info) in enumerate(zip(objects, mem_infos, strict=False)): + if info is None: + continue + nbytes = len(info.view) + if nbytes <= 0: + continue + src_addr = _memoryview_addr(info.view) + ctypes.memmove(obj.data_ptr, src_addr, nbytes) + bitmap.set(i) + accessed.append(keys[i]) + except Exception: + logger.exception( + "MaruL2Adapter: load task %d failed (keys=%d)", + task_id, + len(keys), + ) + + with self._lock: + self._completed_load_tasks[task_id] = bitmap + self._inflight_load_tasks -= 1 + + if accessed: + self._notify_keys_accessed(accessed) + self._signal_eventfd(self._load_efd) + + def query_load_result(self, task_id: L2TaskId) -> Optional[Bitmap]: + """Pop the bitmap for ``task_id`` (single-consumer).""" + with self._lock: + return self._completed_load_tasks.pop(task_id, None) + + # ------------------------------------------------------------------ + # Delete + # 
------------------------------------------------------------------ + + def delete(self, keys: list[ObjectKey]) -> None: + """Remove ``keys`` from MaruServer. + + ``MaruHandler.delete`` is per-key and may return ``False`` when + the key is pinned (still being read) or missing. Both cases + are logged but not re-raised — eviction is best-effort and the + controller can retry later. A no-op when the handler has not + been connected (no prior store / lookup); there is nothing to + delete in that case. + """ + if not keys or self._handler is None: + return + for key in keys: + key_str = _object_key_to_string(key) + try: + self._handler.delete(key_str) + except Exception: + logger.exception("MaruL2Adapter: delete failed for key=%s", key_str) + + def close(self) -> None: + """Shut down worker pools, signal final eventfd events, drop + the MaruHandler connection. Best-effort: errors are logged but + do not propagate (mirroring the base adapters).""" + with self._lock: + if self._closed: + return + self._closed = True + + for executor_attr in ( + "_store_executor", + "_lookup_executor", + "_load_executor", + ): + executor = getattr(self, executor_attr) + if executor is not None: + executor.shutdown(wait=True, cancel_futures=True) + setattr(self, executor_attr, None) + + for efd_attr in ("_store_efd", "_lookup_efd", "_load_efd"): + efd = getattr(self, efd_attr, None) + if efd is not None: + try: + efd.close() + except OSError: + logger.debug("Skipping %s.close() — already closed", efd_attr) + + if self._handler is not None: + try: + self._handler.close() + except Exception: + logger.exception("[MaruL2Adapter] MaruHandler.close() failed") + self._handler = None + + +# ---------------------------------------------------------------------- +# Factory registration +# ---------------------------------------------------------------------- + + +def _create_maru_l2_adapter( + config: L2AdapterConfigBase, + l1_memory_desc: "Optional[L1MemoryDesc]" = None, +) -> L2AdapterInterface: + 
"""Factory invoked by the L2 adapter registry. + + Args: + config: Validated maru L2 config (the registry calls us with + the base type; the concrete type is enforced at + registration time by ``register_l2_adapter_type``). + l1_memory_desc: L1 buffer descriptor passed by ``StorageManager``. + Not used by this adapter — CXL ↔ DRAM transfer happens + via ``MemoryObj.data_ptr`` on the inbound ``MemoryObj``, + so no RDMA registration of a single L1 base pointer is + required. + """ + del l1_memory_desc + return MaruL2Adapter(config) # type: ignore[arg-type] + + +register_l2_adapter_type("maru", MaruL2AdapterConfig) +register_l2_adapter_factory("maru", _create_maru_l2_adapter) diff --git a/lmcache/v1/distributed/maru_l1_dispatch.py b/lmcache/v1/distributed/maru_l1_dispatch.py new file mode 100644 index 0000000000..edd53d7a00 --- /dev/null +++ b/lmcache/v1/distributed/maru_l1_dispatch.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Maru-mode dispatch logic for L1Manager. + +This module isolates the maru-specific behaviour ``L1Manager`` would +otherwise carry inline. ``L1Manager`` constructs a +:class:`MaruL1Dispatcher` when it detects a ``MaruMemoryAllocator`` +and forwards each public method to it. + +The dispatcher owns: + +- ``MaruMemoryAllocator`` reference — for the ``handler`` property and + the ``get_by_location`` / ``create_store_handle`` extension methods. +- ``L1MemoryManager`` reference — used by :meth:`reserve_write` to + drive allocation and by :meth:`report_status` to read usage stats. +- ``_pending_read_memobjs`` side channel — populated in + :meth:`reserve_read` and drained by :meth:`unsafe_read` / + :meth:`finish_read`. + +Thread safety: the dispatcher assumes the caller holds the +``L1Manager`` lock (the public methods on L1Manager are wrapped with +``@l1_mgr_synchronized``). The side-channel dict is therefore not +guarded by a separate lock here. 
def object_key_to_string(key: ObjectKey) -> str:
    """Encode an ``ObjectKey`` as the string key used in ``MaruHandler``
    RPCs.

    The fields are joined with ``@``: the model name, the KV rank as
    eight zero-padded lowercase hex digits, and the chunk hash as hex.
    A non-empty cache salt is appended as a fourth field.

    Args:
        key: The object key to encode.

    Returns:
        ``"<model_name>@<kv_rank_hex>@<chunk_hash_hex>[@<cache_salt>]"``.
    """
    encoded = f"{key.model_name}@{key.kv_rank:08x}@{key.chunk_hash.hex()}"
    salt = key.cache_salt
    return f"{encoded}@{salt}" if salt else encoded
+ + Resolves through ``MaruMemoryAllocator.handler``, which raises + if the allocator's ``init_layout`` has not been called. On the + engine hot path that ordering is guaranteed by + ``MPCacheEngine.register_kv_cache`` running before any + ``store`` / ``lookup`` RPC. + """ + return self._allocator.handler + + # ------------------------------------------------------------------ + # Read path + # ------------------------------------------------------------------ + + def reserve_read(self, keys: list[ObjectKey]) -> dict[ObjectKey, L1OperationResult]: + """Pin + retrieve + stage MemoryObjs in the side channel. + + ``MaruHandler.batch_pin`` has prefix-stop semantics — it only + pins the contiguous prefix of existing keys. We then resolve + each pinned key via ``batch_retrieve`` + ``get_by_location`` + to materialise a ``MemoryObj`` pointing at the existing CXL + page (no data copy). The resolved ``MemoryObj`` is staged in + ``self._pending_read_memobjs`` so the subsequent + ``unsafe_read`` can return it. + + If a pinned key cannot be resolved (race between pin and + retrieve), we unpin the unused tail to keep MaruServer's + ``pin_count`` accurate. + """ + handler = self.handler + key_strs = [object_key_to_string(k) for k in keys] + try: + pin_results = handler.batch_pin(key_strs) + except Exception: + logger.exception("MaruHandler.batch_pin failed for %d keys", len(keys)) + return {k: (L1Error.KEY_NOT_EXIST, None) for k in keys} + + num_pinned = 0 + for ok in pin_results: + if not ok: + break + num_pinned += 1 + + ret: dict[ObjectKey, L1OperationResult] = { + k: (L1Error.KEY_NOT_EXIST, None) for k in keys + } + if num_pinned == 0: + return ret + + try: + mem_infos = handler.batch_retrieve(key_strs[:num_pinned]) + except Exception: + logger.exception( + "MaruHandler.batch_retrieve failed for %d keys", num_pinned + ) + # Roll back the pins so MaruServer's refcount stays consistent. 
+ try: + handler.batch_unpin(key_strs[:num_pinned]) + except Exception: + logger.exception( + "MaruHandler.batch_unpin rollback failed for %d keys", + num_pinned, + ) + return ret + + resolved = 0 + for k, mi in zip(keys[:num_pinned], mem_infos, strict=False): + if mi is None: + # Race between pin and retrieve — treat this and all + # subsequent keys as miss to preserve prefix semantics. + break + mem_obj = self._allocator.get_by_location( + region_id=mi.region_id, + page_index=mi.page_index, + actual_size=len(mi.view), + ) + if mem_obj is None: + break + self._pending_read_memobjs[k] = mem_obj + ret[k] = (L1Error.SUCCESS, mem_obj) + resolved += 1 + + if resolved < num_pinned: + extras = key_strs[resolved:num_pinned] + try: + handler.batch_unpin(extras) + except Exception: + logger.exception( + "MaruHandler.batch_unpin (reconciliation) failed for %d keys", + len(extras), + ) + return ret + + def unsafe_read(self, keys: list[ObjectKey]) -> dict[ObjectKey, L1OperationResult]: + """Look up MemoryObjs staged by :meth:`reserve_read`.""" + ret: dict[ObjectKey, L1OperationResult] = {} + for k in keys: + mem_obj = self._pending_read_memobjs.get(k) + if mem_obj is None: + ret[k] = (L1Error.KEY_NOT_EXIST, None) + else: + ret[k] = (L1Error.SUCCESS, mem_obj) + return ret + + def finish_read(self, keys: list[ObjectKey]) -> dict[ObjectKey, L1Error]: + """Drop side-channel entries and ``batch_unpin``.""" + handler = self.handler + + ret: dict[ObjectKey, L1Error] = {} + to_unpin: list[str] = [] + for k in keys: + if self._pending_read_memobjs.pop(k, None) is not None: + to_unpin.append(object_key_to_string(k)) + ret[k] = L1Error.SUCCESS + else: + ret[k] = L1Error.KEY_NOT_EXIST + + if to_unpin: + try: + handler.batch_unpin(to_unpin) + except Exception: + logger.exception( + "MaruHandler.batch_unpin failed in finish_read for %d keys", + len(to_unpin), + ) + return ret + + # ------------------------------------------------------------------ + # Write path + # 
------------------------------------------------------------------ + + def reserve_write( + self, + keys: list[ObjectKey], + is_temporary: list[bool], + layout_desc: MemoryLayoutDesc, + mode: str, + ) -> dict[ObjectKey, L1OperationResult]: + """Allocate CXL ``MemoryObj``s for ``keys``. + + No in-process dict / TTLLock / state machine is used. The + engine takes the returned ``MemoryObj``s, runs cudaMemcpy + into their CXL-backed ``data_ptr``, then hands them back via + :meth:`finish_write` (which issues ``batch_store``). + + ``is_temporary`` and ``mode`` are accepted for interface + compatibility but have no effect in maru mode — the maru flow + only uses ``mode="new"``. + """ + del is_temporary, mode # unused in maru mode + + ret: dict[ObjectKey, L1OperationResult] = {} + if not keys: + return ret + + err, allocated_objs = self._memory_manager.allocate(layout_desc, len(keys)) + if err != L1Error.SUCCESS: + for k in keys: + ret[k] = (L1Error.OUT_OF_MEMORY, None) + return ret + + for k, obj in zip(keys, allocated_objs, strict=False): + ret[k] = (L1Error.SUCCESS, obj) + return ret + + def finish_write( + self, + keys: list[ObjectKey], + memory_objs: Optional[list[MemoryObj]], + ) -> dict[ObjectKey, L1Error]: + """Register KVs with MaruServer via ``batch_store``. + + ``batch_store`` performs dup-skip + auto-free transparently: + keys that already exist have their newly-allocated CXL page + returned to the pool. Both "newly registered" and + "skipped because already present" are functional successes. 
+ """ + handler = self.handler + + if memory_objs is None or len(memory_objs) != len(keys): + actual = 0 if memory_objs is None else len(memory_objs) + logger.error( + "Maru finish_write requires memory_objs matching keys " + "(keys=%d, memory_objs=%d)", + len(keys), + actual, + ) + return {k: L1Error.KEY_IN_WRONG_STATE for k in keys} + + key_strs = [object_key_to_string(k) for k in keys] + try: + handles = [self._allocator.create_store_handle(mo) for mo in memory_objs] + except Exception: + logger.exception( + "create_store_handle failed for %d MemoryObjs", len(memory_objs) + ) + return {k: L1Error.KEY_IN_WRONG_STATE for k in keys} + + try: + results = handler.batch_store(key_strs, handles) + except Exception: + logger.exception("MaruHandler.batch_store failed for %d keys", len(keys)) + return {k: L1Error.KEY_IN_WRONG_STATE for k in keys} + + ret: dict[ObjectKey, L1Error] = {} + for k, ok in zip(keys, results, strict=False): + ret[k] = L1Error.SUCCESS if ok else L1Error.KEY_IN_WRONG_STATE + return ret + + def finish_write_and_reserve_read( + self, keys: list[ObjectKey] + ) -> dict[ObjectKey, L1OperationResult]: + """Defensive no-op for the atomic write-to-read transition. + + The maru flow stages MemoryObjs in the side channel during + :meth:`reserve_read` rather than going through the + ``reserve_write(is_temporary=True)`` → ``submit_load_task`` → + ``finish_write_and_reserve_read`` sequence used by other + backends. Return ``SUCCESS`` for already-staged keys and + ``KEY_NOT_EXIST`` otherwise so any defensive caller still + sees a useful answer. 
+ """ + ret: dict[ObjectKey, L1OperationResult] = {} + for k in keys: + mem_obj = self._pending_read_memobjs.get(k) + if mem_obj is None: + ret[k] = (L1Error.KEY_NOT_EXIST, None) + else: + ret[k] = (L1Error.SUCCESS, mem_obj) + return ret + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def delete(self, keys: list[ObjectKey]) -> dict[ObjectKey, L1Error]: + """Forward to ``MaruHandler.delete`` per key. + + ``MaruHandler.delete`` returns ``False`` when the key is + pinned or missing — the API conflates the two, so we report + the softer ``KEY_NOT_EXIST`` (callers retrying after + ``KEY_IS_LOCKED`` would quickly hit it again anyway). + """ + handler = self.handler + + ret: dict[ObjectKey, L1Error] = {} + for k in keys: + key_str = object_key_to_string(k) + try: + ok = handler.delete(key_str) + except Exception: + logger.exception("MaruHandler.delete failed for key=%s", key_str) + ret[k] = L1Error.KEY_IN_WRONG_STATE + continue + ret[k] = L1Error.SUCCESS if ok else L1Error.KEY_NOT_EXIST + return ret + + def clear(self, force: bool) -> None: + """Drop staged side-channel entries only. + + The CXL pool itself is owned by ``MaruServer`` and is never + wiped by the L1 layer — ``force=True`` only affects the + in-process read-side bookkeeping. Server-side wipes go + through explicit ``MaruHandler.delete`` calls or MaruServer's + own lifecycle. 
+ """ + if force: + logger.warning( + "L1Manager (maru): force-clear drops %d pending read " + "MemoryObjs but does NOT touch MaruServer.", + len(self._pending_read_memobjs), + ) + self._pending_read_memobjs.clear() + + def report_status(self) -> dict: + """Maru-flavoured status snapshot.""" + used, total = self._memory_manager.get_memory_usage() + return { + "is_healthy": True, + "backend": "maru", + "total_object_count": 0, + "write_locked_count": 0, + "read_locked_count": 0, + "temporary_count": 0, + "pending_read_memobjs": len(self._pending_read_memobjs), + "memory_used_bytes": used, + "memory_total_bytes": total, + "memory_usage_ratio": used / total if total > 0 else 0.0, + "write_ttl_seconds": self._write_ttl_seconds, + "read_ttl_seconds": self._read_ttl_seconds, + } diff --git a/lmcache/v1/distributed/maru_memory_allocator.py b/lmcache/v1/distributed/maru_memory_allocator.py new file mode 100644 index 0000000000..b971ab5676 --- /dev/null +++ b/lmcache/v1/distributed/maru_memory_allocator.py @@ -0,0 +1,482 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Maru-backed L1 memory allocator for MP mode. + +This module exposes :class:`MaruMemoryAllocator`, an implementation of +:class:`MemoryAllocatorInterface` whose ``MemoryObj`` instances are +backed by CXL shared memory via the embedded ``CxlMemoryAdapter`` +(``maru_lmcache``). + +Lifecycle: + The allocator is constructed eagerly (when ``L1MemoryManager`` is + built) but the underlying ``MaruHandler`` connection and + ``CxlMemoryAdapter`` pool are deferred until :meth:`init_layout` is + called with the KV layout learned from the first + ``register_kv_cache`` RPC. This matches LMCache MP's two-phase + startup: the storage manager exists before any vLLM worker has + registered its KV cache tensors. + +Key invariants: + - ``MemoryObj.parent_allocator`` is ``None`` for all objects + returned by this allocator. 
LMCache's refcount-driven free path + must NOT release the underlying CXL pages — lifecycle is owned + by ``MaruServer`` (``pin_kv`` / ``unpin_kv`` / ``delete_kv``). + - :meth:`get_by_location` and :meth:`create_store_handle` are not + part of ``MemoryAllocatorInterface``; ``L1Manager``'s maru + branch reaches them through an ``isinstance`` check. + - ``maru`` and ``maru_lmcache`` are imported lazily so loading + this module does not require those packages to be installed. + +Known limitations: + *Single-model per LMCache instance.* The CXL pool is typed at the + first :meth:`init_layout` call (``CxlMemoryAdapter`` pre-creates + one ``MemoryObj`` per page with the canonical + shapes/dtypes/fmt). Subsequent registrations with a different + layout are rejected. The default DRAM allocators + (:class:`LazyMemoryAllocator` / :class:`MixedMemoryAllocator`) + support multi-model deployments transparently; maru does not. + TODO(maru-multi-model): partition the pool by layout key and hold + one ``CxlMemoryAdapter`` per distinct layout. +""" + +# Future +from __future__ import annotations + +# Standard +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.memory_management import ( + MemoryAllocatorInterface, + MemoryFormat, + MemoryObj, +) + +logger = init_logger(__name__) + + +@dataclass +class MaruL1Config: + """Configuration for :class:`MaruMemoryAllocator`. + + Carries only the layout-independent parameters known at storage + manager construction time. The KV layout + (shapes/dtypes/fmt/chunk_size) is supplied later via + :meth:`MaruMemoryAllocator.init_layout` once a vLLM worker has + registered its KV cache. + + Attributes: + server_url: MaruServer endpoint. Accepts both + ``maru://host:port`` and ``tcp://host:port``; the former + is rewritten to the latter internally. 
+ pool_size_bytes: Per-instance CXL pool quota requested from + ``MaruServer``. + instance_id: Stable identifier for this client instance, used + by ``MaruServer`` for ownership tracking, restart + recovery, and observability. If ``None``, ``MaruConfig`` + auto-generates a UUID (acceptable for single-instance / + single-node setups but not recommended for multi-node + deployments). + timeout_ms: Socket timeout for RPC calls. + use_async_rpc: Whether to use async DEALER-ROUTER RPC client. + max_inflight: Max concurrent in-flight async requests. + eager_map: Pre-map all shared regions on connect. + """ + + server_url: str + pool_size_bytes: int + instance_id: Optional[str] = None + timeout_ms: int = 5000 + use_async_rpc: bool = True + max_inflight: int = 64 + eager_map: bool = True + + +class MaruMemoryAllocator(MemoryAllocatorInterface): + """L1 memory allocator backed by CXL shared memory via Maru. + + Wraps the embedded ``CxlMemoryAdapter`` (``maru_lmcache``) and + re-exposes the slice of its API required by + :class:`MemoryAllocatorInterface`. Adds :meth:`get_by_location` + and :meth:`create_store_handle` for callers — chiefly + ``L1Manager``'s maru branch — that need to drive ``MaruHandler`` + RPCs directly. + + Lifecycle (lazy): + ``__init__`` stores the config only — no MaruServer RPC. The + ``MaruHandler`` connection and ``CxlMemoryAdapter`` pool are + built on the first :meth:`init_layout` call (triggered by the + first ``register_kv_cache`` from a vLLM worker). Calls to + :meth:`batched_allocate` / :meth:`allocate` / + :meth:`get_by_location` / :meth:`create_store_handle` before + :meth:`init_layout` raise ``RuntimeError``. + + Known limitations: + Single-model per instance — see module docstring. + ``TODO(maru-multi-model)``. 
+ """ + + def __init__(self, config: MaruL1Config) -> None: + self._config = config + self._handler: Optional[Any] = None + self._cxl_adapter: Optional[Any] = None + # Set in ``init_layout``; ``0`` is a sentinel for "not yet + # initialized" and is never returned to callers (the property + # raises before they can observe it). + self._single_token_size: int = 0 + self._shapes: Optional[List[torch.Size]] = None + self._dtypes: Optional[List[torch.dtype]] = None + self._fmt: Optional[MemoryFormat] = None + self._chunk_size_in_tokens: int = 0 + + # ------------------------------------------------------------------ + # Two-phase initialization + # ------------------------------------------------------------------ + + def init_layout( + self, + shapes: List[torch.Size], + dtypes: List[torch.dtype], + fmt: MemoryFormat, + chunk_size_in_tokens: int, + ) -> None: + """Bind a KV layout and bring up the CXL pool. + + First call connects ``MaruHandler`` (sized to the layout's + full-chunk byte budget) and constructs the + ``CxlMemoryAdapter`` pool. Subsequent calls with the same + layout are no-ops; mismatched layouts raise ``ValueError`` + (single-model constraint — see class docstring). + + Args: + shapes: KV chunk shapes (per-layer-group when + heterogeneous, otherwise single-element). + dtypes: KV chunk dtypes aligned with ``shapes``. + fmt: Memory format (e.g. ``KV_2LTD`` or ``KV_MLA_FMT``). + chunk_size_in_tokens: LMCache chunk size in tokens + (typically 256). + + Raises: + ValueError: If a layout has already been bound and the new + layout differs — maru is single-model only. + RuntimeError: If ``MaruHandler.connect()`` fails. 
+ """ + if chunk_size_in_tokens <= 0: + raise ValueError( + f"chunk_size_in_tokens must be positive, got {chunk_size_in_tokens}" + ) + + full_chunk_size_bytes = _compute_full_chunk_size_bytes(shapes, dtypes) + if full_chunk_size_bytes <= 0: + raise ValueError( + f"full_chunk_size_bytes computed to non-positive value " + f"({full_chunk_size_bytes}) from shapes={shapes} dtypes={dtypes}" + ) + if full_chunk_size_bytes % chunk_size_in_tokens != 0: + raise ValueError( + f"full_chunk_size_bytes ({full_chunk_size_bytes}) must be a " + f"multiple of chunk_size_in_tokens ({chunk_size_in_tokens})" + ) + + if self._cxl_adapter is not None: + if ( + self._shapes != shapes + or self._dtypes != dtypes + or self._fmt != fmt + or self._chunk_size_in_tokens != chunk_size_in_tokens + ): + raise ValueError( + "MaruMemoryAllocator: layout mismatch on subsequent " + "register_kv_layout call. The maru backend is " + "single-model only — see class docstring " + "(TODO maru-multi-model).\n" + f" existing: shapes={self._shapes} dtypes={self._dtypes} " + f"fmt={self._fmt} chunk={self._chunk_size_in_tokens}\n" + f" new: shapes={shapes} dtypes={dtypes} " + f"fmt={fmt} chunk={chunk_size_in_tokens}" + ) + return + + # Lazy import: maru runtime is only required once a layout is + # actually bound. Importing here keeps the module loadable on + # non-maru deployments. + # Third Party + from maru import MaruConfig, MaruHandler + from maru_lmcache import CxlMemoryAdapter + + # ``MaruHandler`` expects the ``tcp://`` scheme; ``maru://`` + # is the LMCache-facing convention (mirrors ``MaruBackend``). 
+ server_url = self._config.server_url + if server_url.startswith("maru://"): + server_url = "tcp://" + server_url[len("maru://") :] + + maru_config = MaruConfig( + server_url=server_url, + instance_id=self._config.instance_id, + pool_size=self._config.pool_size_bytes, + chunk_size_bytes=full_chunk_size_bytes, + auto_connect=False, + timeout_ms=self._config.timeout_ms, + use_async_rpc=self._config.use_async_rpc, + max_inflight=self._config.max_inflight, + eager_map=self._config.eager_map, + ) + + handler = MaruHandler(maru_config) + if not handler.connect(): + raise RuntimeError( + f"Failed to connect MaruHandler to {self._config.server_url}" + ) + logger.info( + "[MaruMemoryAllocator] connected: server=%s instance_id=%s " + "pool_size=%d chunk_size_bytes=%d", + self._config.server_url, + handler.instance_id, + self._config.pool_size_bytes, + full_chunk_size_bytes, + ) + + self._handler = handler + self._cxl_adapter = CxlMemoryAdapter( + handler=handler, + shapes=shapes, + dtypes=dtypes, + fmt=fmt, + chunk_size=handler.get_chunk_size(), + ) + self._shapes = shapes + self._dtypes = dtypes + self._fmt = fmt + self._chunk_size_in_tokens = chunk_size_in_tokens + self._single_token_size = full_chunk_size_bytes // chunk_size_in_tokens + + @property + def is_initialized(self) -> bool: + """``True`` once :meth:`init_layout` has constructed the pool.""" + return self._cxl_adapter is not None + + # ------------------------------------------------------------------ + # Accessors used by L1Manager's maru branch (via isinstance check) + # ------------------------------------------------------------------ + + @property + def handler(self) -> Any: + """The connected ``MaruHandler``. + + ``L1Manager``'s maru branch uses this to issue + ``batch_store`` / ``batch_pin`` / ``batch_retrieve`` / + ``batch_unpin`` / ``delete`` directly, bypassing the + ``L2AdapterInterface`` framework. + + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. 
+ """ + if self._handler is None: + raise RuntimeError( + "MaruMemoryAllocator.handler accessed before init_layout(); " + "the MaruHandler is built lazily on the first " + "register_kv_cache RPC." + ) + return self._handler + + @property + def single_token_size(self) -> int: + """Bytes per single token in a KV chunk. + + Used by ``L1Manager``'s maru branch when invoking + :meth:`get_by_location` to materialize a ``MemoryObj`` for a + partial chunk. + + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. + """ + if self._single_token_size == 0: + raise RuntimeError( + "MaruMemoryAllocator.single_token_size accessed before init_layout()." + ) + return self._single_token_size + + def get_by_location( + self, + region_id: int, + page_index: int, + actual_size: int, + single_token_size: Optional[int] = None, + ) -> Optional[MemoryObj]: + """Resolve a CXL ``(region_id, page_index)`` to a + ``MemoryObj``. + + Used during the RETRIEVE lookup phase: + ``MaruHandler.batch_retrieve`` reports the location and this + method materialises the pool-resident ``MemoryObj`` (no data + copy). + + Args: + region_id: Region id from ``MaruServer``. + page_index: Page index within the region. + actual_size: Actual KV chunk size in bytes (may be less + than a full chunk for trailing partial chunks). + single_token_size: Bytes-per-token override for partial + chunks. Defaults to :attr:`single_token_size`. + + Returns: + ``MemoryObj`` resolving the CXL page, or ``None`` if the + location is no longer valid (e.g. region not mapped). + + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. 
+ """ + self._require_initialized("get_by_location") + if single_token_size is None: + single_token_size = self._single_token_size + return self._cxl_adapter.get_by_location( # type: ignore[union-attr] + region_id=region_id, + page_index=page_index, + actual_size=actual_size, + single_token_size=single_token_size, + ) + + def create_store_handle(self, memory_obj: MemoryObj) -> Any: + """Reconstruct an ``AllocHandle`` from a ``MemoryObj`` for use + with ``MaruHandler.batch_store``. + + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. + """ + self._require_initialized("create_store_handle") + return self._cxl_adapter.create_store_handle(memory_obj) # type: ignore[union-attr] + + # ------------------------------------------------------------------ + # MemoryAllocatorInterface + # ------------------------------------------------------------------ + + def allocate( + self, + shapes: Union[torch.Size, List[torch.Size]], + dtypes: Union[torch.dtype, List[torch.dtype]], + fmt: MemoryFormat = MemoryFormat.UNDEFINED, + allocator_type: Optional[str] = None, + ) -> Optional[MemoryObj]: + """Allocate a single CXL-backed ``MemoryObj``. + + ``CxlMemoryAdapter`` uses the canonical shapes/dtypes/fmt + fixed at :meth:`init_layout` time; the arguments here are + accepted for interface compatibility but the pool's metadata + is authoritative. + + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. + """ + self._require_initialized("allocate") + return self._cxl_adapter.allocate(shapes, dtypes, fmt, allocator_type) # type: ignore[union-attr] + + def batched_allocate( + self, + shapes: Union[torch.Size, List[torch.Size]], + dtypes: Union[torch.dtype, List[torch.dtype]], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.UNDEFINED, + allocator_type: Optional[str] = None, + ) -> Optional[List[MemoryObj]]: + """Allocate ``batch_size`` CXL-backed ``MemoryObj`` instances. 
+ + Raises: + RuntimeError: If :meth:`init_layout` has not yet been + called. + """ + self._require_initialized("batched_allocate") + return self._cxl_adapter.batched_allocate( # type: ignore[union-attr] + shapes, dtypes, batch_size, fmt, allocator_type + ) + + def free( + self, + memory_obj: MemoryObj, + allocator_type: Optional[str] = None, + ) -> None: + """No-op. CXL lifecycle owned by MaruServer.""" + return + + def batched_free( + self, + memory_objs: List[MemoryObj], + allocator_type: Optional[str] = None, + update_stats: bool = True, + ) -> None: + """No-op. CXL lifecycle owned by MaruServer.""" + return + + def close(self) -> None: + """Close the underlying ``CxlMemoryAdapter`` and + ``MaruHandler`` if they were ever built. + + Best-effort: errors during close are logged but do not + propagate. Safe to call before :meth:`init_layout` — both + underlying objects are ``None`` in that case and the call is + a no-op. + """ + if self._cxl_adapter is not None: + try: + self._cxl_adapter.close() + except Exception: + logger.exception( + "[MaruMemoryAllocator] CxlMemoryAdapter.close() failed" + ) + self._cxl_adapter = None + if self._handler is not None: + try: + self._handler.close() + except Exception: + logger.exception("[MaruMemoryAllocator] MaruHandler.close() failed") + self._handler = None + + def memcheck(self) -> bool: + return True + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _require_initialized(self, op: str) -> None: + if self._cxl_adapter is None: + raise RuntimeError( + f"MaruMemoryAllocator.{op} called before init_layout(); " + f"call init_layout(shapes, dtypes, fmt, chunk_size_in_tokens) " + f"first (typically from MPCacheEngine.register_kv_cache)." + ) + + +def _compute_full_chunk_size_bytes( + shapes: List[torch.Size], dtypes: List[torch.dtype] +) -> int: + """Total bytes for one full KV chunk across all layer groups. 
+ + Args: + shapes: Per-layer-group shapes. + dtypes: Per-layer-group dtypes (must align with ``shapes``). + + Returns: + ``sum(shape.numel() * dtype.itemsize)``. + + Raises: + ValueError: If ``shapes`` and ``dtypes`` differ in length. + """ + if len(shapes) != len(dtypes): + raise ValueError( + f"shapes and dtypes must have the same length, " + f"got {len(shapes)} and {len(dtypes)}" + ) + return sum( + shape.numel() * dtype.itemsize + for shape, dtype in zip(shapes, dtypes, strict=True) + ) diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index e2bfff45c4..976c45131b 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +# Third Party +import torch + # First Party from lmcache.logging import init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc @@ -9,6 +12,7 @@ from lmcache.v1.lazy_memory_allocator import LazyMemoryAllocator from lmcache.v1.memory_management import ( MemoryAllocatorInterface, + MemoryFormat, MemoryObj, MixedMemoryAllocator, ) @@ -27,6 +31,18 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt Returns: MemoryAllocatorInterface: An instance of a memory allocator. """ + if config.maru_config is not None: + # Maru backend — CXL-backed allocator via MaruMemoryAllocator. + # Lazy import keeps the maru runtime optional for non-maru builds. 
+ # First Party + from lmcache.v1.distributed.maru_memory_allocator import MaruMemoryAllocator + + logger.debug( + "use maru memory allocator: server=%s pool_size=%d bytes", + config.maru_config.server_url, + config.maru_config.pool_size_bytes, + ) + return MaruMemoryAllocator(config.maru_config) if config.use_lazy: logger.debug( "use lazy memory allocator, init size is %d bytes, " @@ -51,6 +67,18 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt ) +def _is_maru_allocator(allocator: MemoryAllocatorInterface) -> bool: + """``isinstance(allocator, MaruMemoryAllocator)`` with lazy import. + + Avoids importing the maru-backed allocator (and indirectly the maru + runtime types it lazily uses) when not needed. + """ + # First Party + from lmcache.v1.distributed.maru_memory_allocator import MaruMemoryAllocator + + return isinstance(allocator, MaruMemoryAllocator) + + # MAIN CLASS class L1MemoryManager: """ @@ -66,6 +94,18 @@ def __init__(self, config: L1MemoryManagerConfig): self._size_in_bytes = config.size_in_bytes self._align_bytes = config.align_bytes + @property + def allocator(self) -> MemoryAllocatorInterface: + """Underlying memory allocator. + + Exposed primarily for callers that need allocator-specific + operations not in :class:`MemoryAllocatorInterface` — e.g. + ``L1Manager``'s maru branch reaches into + :class:`MaruMemoryAllocator` for ``handler`` / + ``get_by_location`` / ``create_store_handle``. + """ + return self._allocator + def allocate( self, layout_desc: MemoryLayoutDesc, count: int ) -> tuple[L1Error, list[MemoryObj]]: @@ -121,6 +161,25 @@ def get_memory_usage(self) -> tuple[int, int]: In the future, we may want to make a "callback" based mechanism to trigger eviction when the memory usage reaches a watermark. """ + # Maru backend: query MaruHandler stats. Eviction is owned by + # MaruServer so this is best-effort observability; on failure + # return (0, 0) rather than crash the eviction controller. 
+ if _is_maru_allocator(self._allocator): + allocator = self._allocator + # Lazy backend — handler not built until register_kv_layout. + if not allocator.is_initialized: # type: ignore[attr-defined] + return 0, 0 + try: + handler = allocator.handler # type: ignore[attr-defined] + stats = handler.get_stats() if hasattr(handler, "get_stats") else {} + used = int(stats.get("used_bytes", 0)) + total = int( + stats.get("pool_size_bytes", 0) or stats.get("pool_size", 0) + ) + return used, total + except Exception: + logger.exception("Failed to query Maru handler stats") + return 0, 0 # HACK: now trying to read this from the address manager in a ad-hoc # manner @@ -152,6 +211,14 @@ def get_l1_memory_desc(self) -> L1MemoryDesc: Raises: NotImplementedError: If the allocator type does not support this operation. """ + if _is_maru_allocator(self._allocator): + # No contiguous DRAM buffer to describe — Maru-backed L1 lives in + # CXL pages mmap'd via the handler. RDMA-style registration of a + # single base pointer does not apply. + raise NotImplementedError( + "get_l1_memory_desc is not supported for the maru backend " + "(L1 lives in CXL via mmap, not a single contiguous buffer)." + ) if isinstance(self._allocator, MixedMemoryAllocator): buffer = self._allocator.buffer elif isinstance(self._allocator, LazyMemoryAllocator): @@ -168,6 +235,35 @@ def get_l1_memory_desc(self) -> L1MemoryDesc: align_bytes=self._align_bytes, ) + def register_kv_layout( + self, + shapes: list[torch.Size], + dtypes: list[torch.dtype], + fmt: MemoryFormat, + chunk_size_in_tokens: int, + ) -> None: + """Bind the KV layout to the underlying allocator. + + Only the maru backend acts on this — its ``CxlMemoryAdapter`` + pool is typed at first registration. The default DRAM + allocators (``LazyMemoryAllocator`` / ``MixedMemoryAllocator``) + are layout-agnostic so this call is a no-op for them. 
+ + Idempotent for matching layouts; layout mismatch on a + subsequent call raises ``ValueError`` (maru single-model + constraint). + + Args: + shapes: KV chunk shapes (per-layer-group). + dtypes: KV chunk dtypes aligned with ``shapes``. + fmt: Memory format. + chunk_size_in_tokens: LMCache chunk size in tokens. + """ + if _is_maru_allocator(self._allocator): + self._allocator.init_layout( # type: ignore[attr-defined] + shapes, dtypes, fmt, chunk_size_in_tokens + ) + def close(self) -> None: """ Close the memory manager and release all resources. diff --git a/lmcache/v1/distributed/storage_manager.py b/lmcache/v1/distributed/storage_manager.py index 6b4f95bd99..01ac45d681 100644 --- a/lmcache/v1/distributed/storage_manager.py +++ b/lmcache/v1/distributed/storage_manager.py @@ -5,9 +5,12 @@ # Standard from contextlib import contextmanager -from typing import Iterator, Literal +from typing import Iterator, Literal, Optional import time +# Third Party +import torch + # First Party from lmcache.logging import init_logger from lmcache.v1.distributed.api import ( @@ -37,7 +40,7 @@ AdapterDescriptor, create_store_policy, ) -from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.memory_management import MemoryFormat, MemoryObj from lmcache.v1.mp_observability.event import Event, EventType from lmcache.v1.mp_observability.event_bus import get_event_bus from lmcache.v1.mp_observability.trace.decorator import ( @@ -54,6 +57,27 @@ def __init__(self, config: StorageManagerConfig): self._l1_manager = L1Manager(config.l1_manager_config) self._event_bus = get_event_bus() + # Per-cache_salt quota registry. Always present so the HTTP + # layer has a stable ``quota_manager`` reference; populated + # below for the default-backend path. + self._quota_manager = QuotaManager() + + # Maru-backed L1 bypasses the full controller / L2 adapter + # stack — see ``MaruMemoryAllocator`` docstring for the + # design rationale. 
``L2EvictionController`` / ``StoreController`` + # / ``PrefetchController`` are not instantiated; ``L2`` adapters + # are not created (the pool is the L2 tier, owned by MaruServer). + self._is_maru: bool = ( + config.l1_manager_config.memory_config.maru_config is not None + ) + self._l2_adapters: list[L2AdapterInterface] = [] + self._eviction_controller: Optional[L1EvictionController] = None + self._l2_eviction_controller: Optional[L2EvictionController] = None + self._store_controller: Optional[StoreController] = None + self._prefetch_controller: Optional[PrefetchController] = None + if self._is_maru: + return + # L1 eviction controller self._eviction_controller = L1EvictionController( l1_manager=self._l1_manager, @@ -66,7 +90,6 @@ def __init__(self, config: StorageManagerConfig): # ``SerdeL2AdapterWrapper`` so controllers see a plain L2 adapter # and serde is transparent. l1_memory_desc = self._l1_manager.get_l1_memory_desc() - self._l2_adapters: list[L2AdapterInterface] = [] for ac in config.l2_adapter_config.adapters: adapter: L2AdapterInterface = create_l2_adapter(ac, l1_memory_desc) if ac.serde_config is not None: @@ -77,14 +100,6 @@ def __init__(self, config: StorageManagerConfig): ) self._l2_adapters.append(adapter) - # Per-cache_salt quota registry. Shared across the L2 eviction - # controller (reads quotas each cycle) and the HTTP quota - # endpoints (CRUD). Present even when no adapter uses - # IsolatedLRU so the HTTP layer has a stable ``quota_manager`` - # reference. No explicit cleanup on close — the registry is - # just a dict protected by a lock and has no OS resources. - self._quota_manager = QuotaManager() - # Unified L2 eviction controller for all adapters with eviction # config. 
Aggregate-usage policies (``LRU``, ``noop``) need # ``max_capacity_bytes > 0`` to compute a usage fraction; @@ -204,14 +219,22 @@ def reserve_write( def finish_write( self, keys: list[ObjectKey], + memory_objs: Optional[list[MemoryObj]] = None, ) -> None: """ Finish writing the objects into the storage manager. Args: keys (list[ObjectKey]): List of object keys that have been written. + memory_objs: ``MemoryObj`` instances aligned with ``keys``. + Required when the L1 backend is maru — the caller + (``MPCacheEngine.store``) keeps the reserved + MemoryObjs alive across the GPU copy and threads them + here so the maru branch can issue + ``MaruHandler.batch_store``. Ignored by default L1 + backends, which read state from the in-process dict. """ - finish_result = self._l1_manager.finish_write(keys) + finish_result = self._l1_manager.finish_write(keys, memory_objs=memory_objs) successful_keys = [k for k, e in finish_result.items() if e == L1Error.SUCCESS] failed_keys = [k for k, e in finish_result.items() if e != L1Error.SUCCESS] self._event_bus.publish( @@ -435,6 +458,9 @@ def submit_prefetch_task( remaining_keys = keys[hit_count:] prefetch_request_id = -1 if remaining_keys and self._l2_adapters: + # In maru mode ``_l2_adapters`` is empty, so we never + # enter this branch and the controller stays ``None``. + assert self._prefetch_controller is not None prefetch_request_id = self._prefetch_controller.submit_prefetch_request( remaining_keys, layout_desc, @@ -492,7 +518,10 @@ def query_prefetch_lookup_hits( # No L2 request, the prefix hit count is final return handle.l1_prefix_hit_count - # Have L2 request, need to check the status from prefetch controller + # Have L2 request, need to check the status from prefetch + # controller. A non-(-1) request id implies the controller + # was constructed (maru mode never submits requests). 
+ assert self._prefetch_controller is not None l2_r = self._prefetch_controller.query_lookup_result(handle.prefetch_request_id) if l2_r is None: @@ -521,6 +550,7 @@ def query_prefetch_status( # Have L2 request, need to check the result from prefetch controller if handle.prefetch_request_id != -1: + assert self._prefetch_controller is not None l2_r = self._prefetch_controller.query_prefetch_result( handle.prefetch_request_id ) @@ -596,14 +626,41 @@ def clear(self, force: bool = False): """ self._l1_manager.clear(force=force) + def register_kv_layout( + self, + shapes: list[torch.Size], + dtypes: list[torch.dtype], + fmt: MemoryFormat, + chunk_size_in_tokens: int, + ) -> None: + """Bind the KV layout to the underlying allocator. + + Called from ``MPCacheEngine.register_kv_cache`` once a vLLM + worker exposes its KV cache tensors. Only the maru backend + acts on the call (its ``CxlMemoryAdapter`` pool is typed at + first registration); default backends ignore it. + + Args: + shapes: KV chunk shapes (per-layer-group). + dtypes: KV chunk dtypes aligned with ``shapes``. + fmt: Memory format. + chunk_size_in_tokens: LMCache chunk size in tokens. + """ + self._l1_manager.register_kv_layout(shapes, dtypes, fmt, chunk_size_in_tokens) + def close(self): """ Close the storage manager and release all resources. """ - self._prefetch_controller.stop() - self._store_controller.stop() - self._eviction_controller.stop() - self._l2_eviction_controller.stop() + # Maru mode leaves controllers as ``None`` (see __init__). 
+ if self._prefetch_controller is not None: + self._prefetch_controller.stop() + if self._store_controller is not None: + self._store_controller.stop() + if self._eviction_controller is not None: + self._eviction_controller.stop() + if self._l2_eviction_controller is not None: + self._l2_eviction_controller.stop() for adapter in self._l2_adapters: adapter.close() @@ -611,13 +668,30 @@ def close(self): self._l1_manager.close() def report_status(self) -> dict: - """Return a status dict aggregating all sub-component statuses.""" + """Return a status dict aggregating all sub-component statuses. + + In maru mode the controller / L2-adapter entries are absent; + only ``l1_manager`` and ``num_l2_adapters=0`` are reported. + """ l1 = self._l1_manager.report_status() + adapters = [a.report_status() for a in self._l2_adapters] + if self._is_maru: + return { + "is_healthy": l1["is_healthy"], + "l1_manager": l1, + "l2_adapters": adapters, + "num_l2_adapters": 0, + "backend": "maru", + } + + assert self._store_controller is not None + assert self._prefetch_controller is not None + assert self._eviction_controller is not None + assert self._l2_eviction_controller is not None store = self._store_controller.report_status() prefetch = self._prefetch_controller.report_status() l1_eviction = self._eviction_controller.report_status() l2_eviction = self._l2_eviction_controller.report_status() - adapters = [a.report_status() for a in self._l2_adapters] children = [l1, store, prefetch, l1_eviction, l2_eviction] + adapters return { "is_healthy": all(c["is_healthy"] for c in children), diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 1ff8f6ca17..0c1444a58c 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -35,7 +35,7 @@ lmcache_memcpy_async_h2d, ) from lmcache.v1.gpu_connector.utils import LayoutHints -from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.memory_management import MemoryFormat, 
MemoryObj from lmcache.v1.mp_observability.config import ( ObservabilityConfig, add_observability_args, @@ -255,6 +255,19 @@ def register_kv_cache( ) self.gpu_contexts[instance_id] = gpu_context self.gpu_context_meta[instance_id] = (model_name, world_size) + + # Forward the KV layout to the storage manager. The maru + # backend needs this to bring up its ``CxlMemoryAdapter`` pool + # on first registration; default DRAM backends ignore the + # call. Subsequent registrations with a different layout are + # rejected by maru (single-model constraint) and pass through + # for default backends. + layout_desc = get_layout_desc(gpu_context, self.chunk_size) + fmt = MemoryFormat.KV_MLA_FMT if gpu_context.is_mla_ else MemoryFormat.KV_2LTD + self.storage_manager.register_kv_layout( + layout_desc.shapes, layout_desc.dtypes, fmt, self.chunk_size + ) + logger.info( "Registered KV cache for GPU ID %d with %d layers", instance_id, @@ -427,9 +440,19 @@ def store( finally: event.record() if reserved_dict: + # Snapshot keys/values now so the host callback + # sees the same set even if ``reserved_dict`` is + # mutated. ``memory_objs`` is required by the + # maru backend (used to issue + # ``MaruHandler.batch_store``); the default + # backend ignores it. + finish_keys = list(reserved_dict.keys()) + finish_objs = list(reserved_dict.values()) gpu_context.cupy_stream.launch_host_func( - self.storage_manager.finish_write, - list(reserved_dict.keys()), + lambda _: self.storage_manager.finish_write( + finish_keys, memory_objs=finish_objs + ), + None, ) # All reserved MemoryObjs share one layout_desc, so per-object # size is identical — avoid summing N identical values. diff --git a/scripts/maru_smoke.py b/scripts/maru_smoke.py new file mode 100755 index 0000000000..683aac3782 --- /dev/null +++ b/scripts/maru_smoke.py @@ -0,0 +1,535 @@ +#!/home/shson/.venv/bin/python +# SPDX-License-Identifier: Apache-2.0 +"""Maru integration smoke test. 
+ +Run after ``maru-resource-manager`` and ``maru-server`` are up. Use +``--verbose`` to see full stack traces on failure. + +The script imports the ``lmcache`` package as the active environment +resolves it — i.e. the version your ``uv pip install . --no-build- +isolation`` last produced. T3 prints the loaded path so a stale +install is obvious; re-run ``uv pip install . --no-build-isolation`` +after editing the checkout if you want the smoke to exercise the +new code. + +Tiers (each builds on the previous): + + T1. maru's own ``examples/basic/single_instance.py`` runs to + completion. If this fails the issue is in the maru runtime + / environment, not in LMCache. + T2. ``MaruHandler.connect()`` returns control without crashing. + T3. ``MaruMemoryAllocator`` constructs (lazy — no RPC yet). + T4. ``MaruMemoryAllocator.init_layout()`` brings the CXL pool up. + T5. ``MaruMemoryAllocator.batched_allocate()`` returns MemoryObjs. + T6. ``L1MemoryManager.register_kv_layout()`` forwards down. + T7. ``MaruL2Adapter`` constructs (lazy — handler stays ``None``). + T8. First store triggers the lazy ``MaruHandler.connect`` and + the store → load round-trip preserves bytes. + +Usage:: + + /home/shson/.venv/bin/python scripts/maru_smoke.py + /home/shson/.venv/bin/python scripts/maru_smoke.py \ + --server maru://localhost:5555 --pool-gb 1 -v +""" + +# Standard +from pathlib import Path +import argparse +import os +import subprocess +import sys +import traceback + +PYTHON = sys.executable +MARU_EXAMPLE = Path("/home/shson/maru/examples/basic/single_instance.py") + +# ANSI colours, disabled when not a TTY so the script can be piped to a file. 
# Each escape sequence collapses to "" when stdout is not a terminal, so
# redirected output stays free of control characters.
_TTY = sys.stdout.isatty()
GREEN = "\033[32m" if _TTY else ""
RED = "\033[31m" if _TTY else ""
DIM = "\033[2m" if _TTY else ""
BOLD = "\033[1m" if _TTY else ""
RESET = "\033[0m" if _TTY else ""


def hdr(name: str) -> None:
    """Print a bold section header for the tier called ``name``."""
    print(f"\n{BOLD}=== {name} ==={RESET}", flush=True)


def ok(msg: str) -> None:
    """Print ``msg`` prefixed with a green check mark."""
    print(f" {GREEN}✓{RESET} {msg}", flush=True)


def fail(msg: str) -> None:
    """Print ``msg`` prefixed with a red cross."""
    print(f" {RED}✗{RESET} {msg}", flush=True)


def dim(msg: str) -> None:
    """Print ``msg`` in dim text (used for hints and captured output)."""
    print(f" {DIM}{msg}{RESET}", flush=True)


def is_sigbus_exit(returncode: int) -> bool:
    """Return ``True`` if ``returncode`` means the child died from SIGBUS.

    ``subprocess`` reports a child killed by signal N as ``-N``; a shell
    or wrapper relaying that death reports ``128 + N`` instead. Both
    spellings are recognised.

    Args:
        returncode: Exit status as stored in ``CompletedProcess.returncode``.

    Returns:
        Whether the status corresponds to SIGBUS. Always ``False`` on
        platforms whose ``signal`` module has no ``SIGBUS``.
    """
    # Standard
    import signal

    sigbus = getattr(signal, "SIGBUS", None)
    if sigbus is None:
        return False
    return returncode in (-int(sigbus), 128 + int(sigbus))


# ---------------------------------------------------------------------------
# T1 — maru's own example as a subprocess, captures the SIGBUS exit code.
# ---------------------------------------------------------------------------


def t1_maru_example() -> bool:
    """Run maru's bundled ``single_instance`` example in a child process.

    Returns:
        ``True`` if the example exits with status 0; ``False`` if the
        example file is missing, the run exceeds 30 seconds, or the exit
        status is non-zero (the last ten lines of stdout and stderr are
        printed in that case).
    """
    hdr("T1 — maru built-in single_instance example")
    if not MARU_EXAMPLE.is_file():
        fail(f"example not found: {MARU_EXAMPLE}")
        dim("Adjust MARU_EXAMPLE in this script if maru lives elsewhere.")
        return False

    try:
        proc = subprocess.run(
            [PYTHON, "-u", str(MARU_EXAMPLE)],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except subprocess.TimeoutExpired:
        fail("timed out after 30s")
        return False

    if proc.returncode == 0:
        ok(f"{MARU_EXAMPLE.name} ran to completion")
        return True

    fail(f"exit code {proc.returncode}")
    # The child is launched directly (no shell), so a SIGBUS death shows
    # up as a negative return code rather than 135.
    if is_sigbus_exit(proc.returncode):
        dim(f"Exit {proc.returncode} = SIGBUS. Almost certainly an mmap/DAX")
        dim("permission or pool-backing issue on the maru side.")
    dim("---- stdout (tail) ----")
    for line in proc.stdout.splitlines()[-10:]:
        dim(line)
    dim("---- stderr (tail) ----")
    for line in proc.stderr.splitlines()[-10:]:
        dim(line)
    return False


# ---------------------------------------------------------------------------
# T2 — MaruHandler.connect() returns. Run in subprocess so a SIGBUS doesn't
# kill the rest of the smoke run.
# ---------------------------------------------------------------------------


# Source of the child program used by T2. ``server_url`` is rendered with
# ``!r`` so that quotes or backslashes in the URL cannot break the
# generated Python source.
_T2_SNIPPET = r"""
import sys
from maru import MaruConfig, MaruHandler
mc = MaruConfig(
    server_url={server_url!r},
    instance_id="maru-smoke-t2",
    pool_size={pool_bytes:d},
    chunk_size_bytes=4 * 1024 * 1024,
    auto_connect=False,
    timeout_ms=5000,
)
h = MaruHandler(mc)
print("BUILT", flush=True)
ok = h.connect()
print(f"CONNECT_RETURNED ok={{ok}}", flush=True)
h.close()
print("CLOSE_OK", flush=True)
"""


def render_t2_snippet(server_url: str, pool_bytes: int) -> str:
    """Fill ``_T2_SNIPPET`` with the given endpoint and pool size.

    Args:
        server_url: Endpoint handed to ``MaruConfig`` in the child.
        pool_bytes: Pool size in bytes handed to ``MaruConfig``.

    Returns:
        Python source text suitable for ``python -c``.
    """
    return _T2_SNIPPET.format(server_url=server_url, pool_bytes=pool_bytes)


def _run_in_subprocess(snippet: str, label: str) -> tuple[bool, str]:
    """Run ``snippet`` with ``python -u -c`` and collect its output.

    Args:
        snippet: Python source to execute in the child interpreter.
        label: Short tag used in the timeout message.

    Returns:
        ``(success, text)`` where ``success`` means exit status 0 and
        ``text`` is stdout, followed by stderr when non-empty, followed
        by the exit status when it is non-zero. On a 30 second timeout
        the result is ``(False, "<label timed out>")``.
    """
    try:
        proc = subprocess.run(
            [PYTHON, "-u", "-c", snippet],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except subprocess.TimeoutExpired:
        return False, f"<{label} timed out>"
    success = proc.returncode == 0
    text = proc.stdout + ("\n[stderr]\n" + proc.stderr if proc.stderr else "")
    if not success:
        # A negative value is a signal death (-7 is SIGBUS on Linux);
        # keep it visible because the caller only gets this text back.
        text += f"\n[exit code {proc.returncode}]"
    return success, text


def t2_handler_connect(server_url: str, pool_bytes: int) -> bool:
    """Check that ``MaruHandler.connect()`` returns without crashing.

    The connect/close cycle runs in a child interpreter so that a fatal
    signal there cannot take down the smoke run itself.

    Args:
        server_url: Endpoint handed to ``MaruConfig`` in the child.
        pool_bytes: Pool size in bytes handed to ``MaruConfig``.

    Returns:
        ``True`` if the child exits with status 0 and printed
        ``CLOSE_OK``; ``False`` otherwise (the last five non-blank
        output lines are printed).
    """
    hdr("T2 — MaruHandler.connect()")
    snippet = render_t2_snippet(server_url, pool_bytes)
    success, out = _run_in_subprocess(snippet, "T2")
    last = [line for line in out.splitlines() if line.strip()][-5:]
    if not success:
        fail("connect path crashed or returned non-zero")
        for line in last:
            dim(line)
        return False
    if "CLOSE_OK" in out:
        ok("connect/close cycle clean")
        # Surface what connect() actually returned: a falsy result does
        # not fail this tier but explains failures in the later ones.
        for line in out.splitlines():
            if line.startswith("CONNECT_RETURNED"):
                dim(line)
        return True
    fail("subprocess returned 0 but did not reach CLOSE_OK")
    for line in last:
        dim(line)
    return False


# ---------------------------------------------------------------------------
# T3–T6 — LMCache-side. Run in-process now that we know the runtime is sane.
+# --------------------------------------------------------------------------- + + +def t3_allocator_construct() -> bool: + hdr("T3 — MaruMemoryAllocator __init__ (lazy)") + # First Party + from lmcache.v1.distributed.maru_memory_allocator import ( + MaruL1Config, + MaruMemoryAllocator, + ) + import lmcache + + # Surface which lmcache we picked up so a path-shadowing surprise + # (e.g. an older pip-installed copy) is obvious. + dim(f"lmcache loaded from: {Path(lmcache.__file__).parent}") + + cfg = MaruL1Config( + server_url="maru://unused-for-this-test:1", + pool_size_bytes=1, + instance_id="maru-smoke-t3", + ) + alloc = MaruMemoryAllocator(cfg) + if alloc.is_initialized: + fail("allocator is_initialized=True after __init__ (lazy contract broken)") + return False + if alloc._handler is not None or alloc._cxl_adapter is not None: + fail("handler/adapter populated before init_layout") + return False + ok("__init__ returned with handler=adapter=None, is_initialized=False") + return True + + +def t4_init_layout(server_url: str, pool_bytes: int) -> bool: + hdr("T4 — MaruMemoryAllocator.init_layout()") + # Third Party + import torch + + # First Party + from lmcache.v1.distributed.maru_memory_allocator import ( + MaruL1Config, + MaruMemoryAllocator, + ) + from lmcache.v1.memory_management import MemoryFormat + + cfg = MaruL1Config( + server_url=server_url, + pool_size_bytes=pool_bytes, + instance_id="maru-smoke-t4", + ) + alloc = MaruMemoryAllocator(cfg) + shapes = [torch.Size([2, 32, 256, 128])] # 4 MiB / chunk + dtypes = [torch.float16] + try: + alloc.init_layout(shapes, dtypes, MemoryFormat.KV_2LTD, 256) + except Exception: + fail("init_layout raised") + traceback.print_exc() + return False + if not alloc.is_initialized: + fail("init_layout returned but is_initialized=False") + return False + ok(f"init_layout OK (single_token_size={alloc.single_token_size})") + + # Idempotent same-layout + try: + alloc.init_layout(shapes, dtypes, MemoryFormat.KV_2LTD, 256) + 
ok("idempotent same-layout call accepted") + except Exception: + fail("idempotent same-layout call raised") + traceback.print_exc() + return False + + # Mismatched layout — must reject. + try: + alloc.init_layout( + [torch.Size([2, 16, 256, 128])], dtypes, MemoryFormat.KV_2LTD, 256 + ) + fail("layout-mismatch did NOT raise (single-model constraint broken)") + return False + except ValueError: + ok("layout-mismatch rejected as expected") + + # Stash for T5 by returning the live allocator. + t4_init_layout._alloc = alloc # type: ignore[attr-defined] + return True + + +def t5_batched_allocate() -> bool: + hdr("T5 — MaruMemoryAllocator.batched_allocate()") + alloc = getattr(t4_init_layout, "_alloc", None) + if alloc is None: + fail("T4 must run first to populate the allocator") + return False + # Third Party + import torch + + # First Party + from lmcache.v1.memory_management import MemoryFormat + + shapes = [torch.Size([2, 32, 256, 128])] + dtypes = [torch.float16] + try: + objs = alloc.batched_allocate( + shapes, dtypes, batch_size=4, fmt=MemoryFormat.KV_2LTD + ) + except Exception: + fail("batched_allocate raised") + traceback.print_exc() + return False + if not objs or len(objs) != 4: + fail(f"expected 4 MemoryObjs, got {None if not objs else len(objs)}") + return False + ok(f"got {len(objs)} MemoryObjs back from the pool") + + alloc.close() + ok("close() OK") + return True + + +def t6_register_kv_layout(server_url: str, pool_bytes: int) -> bool: + hdr("T6 — L1MemoryManager.register_kv_layout()") + # Third Party + import torch + + # First Party + from lmcache.v1.distributed.config import L1MemoryManagerConfig + from lmcache.v1.distributed.maru_memory_allocator import MaruL1Config + from lmcache.v1.distributed.memory_manager import L1MemoryManager + from lmcache.v1.memory_management import MemoryFormat + + maru_cfg = MaruL1Config( + server_url=server_url, + pool_size_bytes=pool_bytes, + instance_id="maru-smoke-t6", + ) + mgr = L1MemoryManager( + 
def t8_l2_store_load_roundtrip() -> bool:
    """Tier 8: store → lookup/lock → load round trip through the T7 adapter.

    Uses the adapter and chunk size stashed on ``t7_l2_adapter_connect``:

    1. Fill a numpy buffer with a deterministic byte pattern and wrap it in a
       mock exposing ``data_ptr`` and ``get_size``.
    2. Submit a store task and wait for it to complete.
    3. Submit a lookup-and-lock task and require bit 0 to be set.
    4. Load into a zeroed buffer and compare it byte-for-byte to the source.
    5. Unlock, delete and close.

    Each wait is bounded to 10 seconds and measured with the monotonic clock,
    so a wall-clock adjustment cannot shorten or extend it. The adapter is
    closed on every path once it has been obtained, including early failure
    returns and unexpected exceptions.

    Returns:
        True when all steps succeeded and the loaded bytes match the source;
        False otherwise.
    """
    hdr("T8 — first store triggers lazy connect + store → load round-trip")
    # Standard
    from unittest import mock
    import time

    # Third Party
    import numpy as np

    # First Party
    from lmcache.native_storage_ops import (  # noqa: F401 — surface import errors here
        Bitmap,
    )
    from lmcache.v1.distributed.api import ObjectKey

    adapter = getattr(t7_l2_adapter_connect, "_adapter", None)
    chunk_size_bytes = getattr(t7_l2_adapter_connect, "_chunk_size_bytes", None)
    if adapter is None or chunk_size_bytes is None:
        fail("T7 must run first to populate the adapter")
        return False

    def wait_for(poll, timeout_s: float = 10.0, interval_s: float = 0.05):
        """Call ``poll`` until it returns non-None or ``timeout_s`` elapses."""
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            result = poll()
            if result is not None:
                return result
            time.sleep(interval_s)
        return None

    close_attempted = False
    try:
        # 1) Set up a deterministic byte pattern in a DRAM-side numpy
        #    buffer and wrap it as a fake MemoryObj (data_ptr + get_size).
        payload_bytes = min(chunk_size_bytes, 65536)
        src_arr = np.frombuffer(
            bytes(i % 256 for i in range(payload_bytes)), dtype=np.uint8
        ).copy()  # writable copy

        src_obj = mock.MagicMock(name="DramMemoryObj")
        src_obj.data_ptr = int(src_arr.ctypes.data)
        src_obj.get_size = mock.MagicMock(return_value=payload_bytes)

        key = ObjectKey(
            chunk_hash=(0xCAFEBABE).to_bytes(4, "big"),
            model_name="smoke-t8",
            kv_rank=0xABCD,
            cache_salt="",
        )

        # 2) Store: alloc CXL page + memcpy DRAM→CXL + batch_store.
        store_task = adapter.submit_store_task([key], [src_obj])

        def poll_store():
            # Hand back the whole completion mapping once our task shows up.
            completed = adapter.pop_completed_store_tasks()
            return completed if store_task in completed else None

        completed = wait_for(poll_store)
        if completed is None:
            fail("store task did not complete within 10s")
            return False
        if not completed[store_task]:
            fail("store task returned failure")
            return False
        ok(f"stored {payload_bytes} bytes")

        # 3) Lookup + lock — bit 0 should be set after batch_pin.
        lookup_task = adapter.submit_lookup_and_lock_task([key])
        bm = wait_for(lambda: adapter.query_lookup_and_lock_result(lookup_task))
        if bm is None or not bm.test(0):
            fail("lookup did not report a hit for the just-stored key")
            return False
        ok("lookup hit")

        # 4) Load into a fresh DRAM buffer and verify bytes match.
        dst_arr = np.zeros(payload_bytes, dtype=np.uint8)
        dst_obj = mock.MagicMock(name="DramMemoryObj")
        dst_obj.data_ptr = int(dst_arr.ctypes.data)
        dst_obj.get_size = mock.MagicMock(return_value=payload_bytes)

        load_task = adapter.submit_load_task([key], [dst_obj])
        bm = wait_for(lambda: adapter.query_load_result(load_task))
        if bm is None or not bm.test(0):
            fail("load did not report success for the just-stored key")
            return False
        if not np.array_equal(src_arr, dst_arr):
            fail("loaded bytes do not match the source")
            return False
        ok(f"loaded {payload_bytes} bytes; DRAM↔CXL↔DRAM round-trip preserved")

        # 5) Cleanup: unlock + delete + close so the next run starts cold.
        adapter.submit_unlock([key])
        adapter.delete([key])
        close_attempted = True
        adapter.close()
        ok("unlock + delete + close OK")
        return True
    finally:
        # Any early return or exception above still needs the adapter closed.
        # Errors from this best-effort close are printed, not raised, so they
        # cannot hide the original failure.
        if not close_attempted:
            try:
                adapter.close()
            except Exception:
                traceback.print_exc()
+ results = [ + t3_allocator_construct(), + t4_init_layout(args.server, pool_bytes), + t5_batched_allocate(), + t6_register_kv_layout(args.server, pool_bytes), + t7_l2_adapter_connect(args.server, pool_bytes), + t8_l2_store_load_roundtrip(), + ] + if all(results): + print(f"\n{GREEN}{BOLD}all tiers passed.{RESET}") + return 0 + print(f"\n{RED}{BOLD}some tiers failed (see above).{RESET}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/v1/distributed/test_l1_manager_maru.py b/tests/v1/distributed/test_l1_manager_maru.py new file mode 100644 index 0000000000..c2d871c54b --- /dev/null +++ b/tests/v1/distributed/test_l1_manager_maru.py @@ -0,0 +1,619 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the maru-backend branches of L1Manager. + +Coverage: + +1. ``object_key_to_string`` — stable string form for MaruHandler RPCs. +2. ``L1Manager.__init__`` — auto-detects ``MaruMemoryAllocator`` and + constructs a :class:`MaruL1Dispatcher`. +3. ``_is_maru_backend`` — dispatch flag. +4. STORE path: ``reserve_write`` (allocate) → ``finish_write`` + (``MaruHandler.batch_store``). +5. RETRIEVE path: ``reserve_read`` (``batch_pin`` + ``batch_retrieve`` + + ``get_by_location`` + side channel) → ``unsafe_read`` (side channel + lookup) → ``finish_read`` (``batch_unpin`` + side channel clear). +6. Race-condition rollback in ``reserve_read``. +7. ``delete`` / ``clear`` / ``finish_write_and_reserve_read``. +8. No-op methods: ``register_listener`` / ``touch_keys`` / + ``is_key_evictable`` / ``memcheck`` / ``get_object_state``. +9. ``report_status`` shape in maru mode. + +The maru runtime (``maru``, ``maru_lmcache``) is NOT required: the +``MaruMemoryAllocator`` constructor is monkey-patched to install +``MagicMock`` handler + adapter instead of opening a real connection. 
@pytest.fixture
def fake_maru_allocator():
    """Patch ``MaruMemoryAllocator.__init__`` for the duration of a test.

    The replacement runs the real constructor and then installs ``MagicMock``
    handler and adapter objects plus a non-zero ``_single_token_size`` on the
    instance, mimicking the post-``init_layout`` state. The original
    constructor is restored on teardown.
    """
    original_init = MaruMemoryAllocator.__init__

    def _stubbed_init(self, config: MaruL1Config) -> None:
        original_init(self, config)
        # Mimic the post-``init_layout`` state: handler, adapter and layout
        # metadata are all present, so the allocator counts as initialized.
        self._handler = mock.MagicMock(name="MaruHandler")
        self._cxl_adapter = mock.MagicMock(name="CxlMemoryAdapter")
        self._single_token_size = 4096  # dummy non-zero

    # The context manager swaps the constructor in and puts the original
    # back when the fixture is torn down.
    with mock.patch.object(MaruMemoryAllocator, "__init__", _stubbed_init):
        yield
def _mk_key(idx: int = 0, salt: str = "") -> ObjectKey:
    """Build a test ``ObjectKey`` whose chunk hash is ``idx`` as 4 big-endian bytes.

    ``model_name`` and ``kv_rank`` are fixed; ``salt`` becomes ``cache_salt``.
    """
    hash_bytes = idx.to_bytes(4, byteorder="big")
    fields = {
        "chunk_hash": hash_bytes,
        "model_name": "test-model",
        "kv_rank": 0xABCD,
        "cache_salt": salt,
    }
    return ObjectKey(**fields)
class TestMaruFinishWrite:
    """``finish_write`` in maru mode: handle creation and ``batch_store``."""

    def test_happy_path_calls_batch_store(self, maru_mgr, maru_handler, maru_adapter):
        count = 2
        keys = [_mk_key(i) for i in range(count)]
        memory_objs = [mock.MagicMock(name=f"mo-{i}") for i in range(count)]
        # ``create_store_handle`` is configured on the adapter mock so the
        # handles passed on to ``batch_store`` are known in advance.
        handles = [mock.MagicMock(name=f"handle-{i}") for i in range(count)]
        maru_adapter.create_store_handle.side_effect = handles
        maru_handler.batch_store.return_value = [True] * count

        ret = maru_mgr.finish_write(keys, memory_objs=memory_objs)

        # batch_store must have received the key strings and the handles.
        sent_key_strs, sent_handles = maru_handler.batch_store.call_args.args
        expected_key_strs = [object_key_to_string(k) for k in keys]
        assert sent_key_strs == expected_key_strs
        assert sent_handles == handles
        for key in keys:
            assert ret[key] is L1Error.SUCCESS

    def test_dup_skip_returns_success(self, maru_mgr, maru_handler, maru_adapter):
        # ``batch_store`` reports True for newly registered keys and for
        # dup-skipped ones alike; both count as success.
        keys = [_mk_key(i) for i in range(2)]
        memory_objs = [mock.MagicMock() for _ in keys]
        maru_adapter.create_store_handle.side_effect = [mock.MagicMock() for _ in keys]
        maru_handler.batch_store.return_value = [True, True]

        ret = maru_mgr.finish_write(keys, memory_objs=memory_objs)

        for key in keys:
            assert ret[key] is L1Error.SUCCESS

    def test_missing_memory_objs_returns_error(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(2)]

        ret = maru_mgr.finish_write(keys, memory_objs=None)

        for key in keys:
            assert ret[key] is L1Error.KEY_IN_WRONG_STATE
        maru_handler.batch_store.assert_not_called()

    def test_length_mismatch_returns_error(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(3)]
        too_few_objs = [mock.MagicMock()]

        ret = maru_mgr.finish_write(keys, memory_objs=too_few_objs)

        for key in keys:
            assert ret[key] is L1Error.KEY_IN_WRONG_STATE
        maru_handler.batch_store.assert_not_called()

    def test_batch_store_exception_returns_error(
        self, maru_mgr, maru_handler, maru_adapter
    ):
        keys = [_mk_key(i) for i in range(2)]
        memory_objs = [mock.MagicMock() for _ in keys]
        maru_adapter.create_store_handle.side_effect = [mock.MagicMock() for _ in keys]
        maru_handler.batch_store.side_effect = RuntimeError("rpc fail")

        ret = maru_mgr.finish_write(keys, memory_objs=memory_objs)

        for key in keys:
            assert ret[key] is L1Error.KEY_IN_WRONG_STATE
class TestMaruReserveRead:
    """``reserve_read`` in maru mode: pin, retrieve, resolve, and roll back."""

    def test_all_hit(self, maru_mgr, maru_handler, maru_adapter):
        count = 3
        keys = [_mk_key(i) for i in range(count)]
        maru_handler.batch_pin.return_value = [True] * count
        maru_handler.batch_retrieve.return_value = [
            _FakeMemInfo(region_id=i, page_index=i, view=b"x" * 32)
            for i in range(count)
        ]
        fake_objs = [mock.MagicMock(name=f"obj-{i}") for i in range(count)]
        maru_adapter.get_by_location.side_effect = fake_objs

        ret = maru_mgr.reserve_read(keys)

        pending = maru_mgr._maru_dispatcher._pending_read_memobjs
        for key, expected in zip(keys, fake_objs, strict=False):
            err, returned = ret[key]
            assert err is L1Error.SUCCESS
            assert returned is expected
            assert pending[key] is expected

    def test_prefix_miss(self, maru_mgr, maru_handler, maru_adapter):
        """``batch_pin`` reports prefix-stop: only k0, k1 are pinned."""
        k0, k1, k2 = (_mk_key(i) for i in range(3))
        maru_handler.batch_pin.return_value = [True, True, False]
        maru_handler.batch_retrieve.return_value = [
            _FakeMemInfo(i, i, b"x" * 32) for i in range(2)
        ]
        maru_adapter.get_by_location.side_effect = [
            mock.MagicMock(),
            mock.MagicMock(),
        ]

        ret = maru_mgr.reserve_read([k0, k1, k2])

        assert ret[k0][0] is L1Error.SUCCESS
        assert ret[k1][0] is L1Error.SUCCESS
        assert ret[k2] == (L1Error.KEY_NOT_EXIST, None)
        assert k2 not in maru_mgr._maru_dispatcher._pending_read_memobjs

    def test_all_miss(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(2)]
        maru_handler.batch_pin.return_value = [False, False]

        ret = maru_mgr.reserve_read(keys)

        miss = (L1Error.KEY_NOT_EXIST, None)
        for key in keys:
            assert ret[key] == miss
        maru_handler.batch_retrieve.assert_not_called()
        assert maru_mgr._maru_dispatcher._pending_read_memobjs == {}

    def test_race_batch_retrieve_returns_none_mid_batch(
        self, maru_mgr, maru_handler, maru_adapter
    ):
        """k0 resolves; k1 races and returns None — k1 is unpinned."""
        k0, k1 = _mk_key(0), _mk_key(1)
        maru_handler.batch_pin.return_value = [True, True]
        maru_handler.batch_retrieve.return_value = [
            _FakeMemInfo(0, 0, b"x" * 32),
            None,
        ]
        maru_adapter.get_by_location.return_value = mock.MagicMock(name="obj-0")

        ret = maru_mgr.reserve_read([k0, k1])

        assert ret[k0][0] is L1Error.SUCCESS
        assert ret[k1] == (L1Error.KEY_NOT_EXIST, None)
        # Only k1's key string should be rolled back via batch_unpin.
        maru_handler.batch_unpin.assert_called_once_with([object_key_to_string(k1)])

    def test_race_get_by_location_returns_none(
        self, maru_mgr, maru_handler, maru_adapter
    ):
        key = _mk_key(0)
        maru_handler.batch_pin.return_value = [True]
        maru_handler.batch_retrieve.return_value = [_FakeMemInfo(0, 0, b"x" * 32)]
        maru_adapter.get_by_location.return_value = None

        ret = maru_mgr.reserve_read([key])

        assert ret[key] == (L1Error.KEY_NOT_EXIST, None)
        maru_handler.batch_unpin.assert_called_once_with([object_key_to_string(key)])

    def test_batch_pin_exception_returns_miss_for_all(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(2)]
        maru_handler.batch_pin.side_effect = RuntimeError("rpc fail")

        ret = maru_mgr.reserve_read(keys)

        for key in keys:
            assert ret[key] == (L1Error.KEY_NOT_EXIST, None)

    def test_batch_retrieve_exception_rolls_back_pins(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(2)]
        maru_handler.batch_pin.return_value = [True, True]
        maru_handler.batch_retrieve.side_effect = RuntimeError("rpc fail")

        ret = maru_mgr.reserve_read(keys)

        for key in keys:
            assert ret[key] == (L1Error.KEY_NOT_EXIST, None)
        # Both pins should have been rolled back.
        pinned_key_strs = [object_key_to_string(k) for k in keys]
        maru_handler.batch_unpin.assert_called_once_with(pinned_key_strs)
class TestMaruDelete:
    """``delete`` in maru mode maps the handler's result onto ``L1Error``."""

    def test_success(self, maru_mgr, maru_handler):
        keys = [_mk_key(i) for i in range(2)]
        maru_handler.delete.return_value = True

        ret = maru_mgr.delete(keys)

        for status in ret.values():
            assert status is L1Error.SUCCESS
        assert maru_handler.delete.call_count == 2

    def test_handler_returns_false_reports_not_exist(self, maru_mgr, maru_handler):
        """``MaruHandler.delete`` returns False for either missing or
        pinned keys; ``L1Manager`` maps both to ``KEY_NOT_EXIST``.
        """
        key = _mk_key(0)
        maru_handler.delete.return_value = False

        ret = maru_mgr.delete([key])

        assert ret[key] is L1Error.KEY_NOT_EXIST

    def test_exception_reports_wrong_state(self, maru_mgr, maru_handler):
        key = _mk_key(0)
        maru_handler.delete.side_effect = RuntimeError("rpc fail")

        ret = maru_mgr.delete([key])

        assert ret[key] is L1Error.KEY_IN_WRONG_STATE
class TestMaruFinishWriteAndReserveRead:
    """``finish_write_and_reserve_read`` resolves keys from the side channel."""

    def test_resolves_from_side_channel(self, maru_mgr):
        key = _mk_key(0)
        staged = mock.MagicMock()
        maru_mgr._maru_dispatcher._pending_read_memobjs[key] = staged

        result = maru_mgr.finish_write_and_reserve_read([key])

        status, memobj = result[key]
        assert status is L1Error.SUCCESS
        assert memobj is staged

    def test_missing_key_returns_not_exist(self, maru_mgr):
        # The lookup deliberately uses a freshly built, equal key.
        result = maru_mgr.finish_write_and_reserve_read([_mk_key(0)])

        expected = (L1Error.KEY_NOT_EXIST, None)
        assert result[_mk_key(0)] == expected
class TestMaruReportStatus:
    """Shape of the dict returned by ``report_status`` in maru mode."""

    def test_shape(self, maru_mgr):
        maru_mgr._maru_dispatcher._pending_read_memobjs[_mk_key(0)] = mock.MagicMock()
        # Memory manager get_memory_usage is exercised by
        # ``test_l1_memory_manager_maru.py``; here we only verify the
        # maru-mode dict shape.
        maru_mgr._memory_manager = mock.MagicMock()
        maru_mgr._maru_dispatcher._memory_manager = maru_mgr._memory_manager
        maru_mgr._memory_manager.get_memory_usage.return_value = (10, 100)

        status = maru_mgr.report_status()

        assert status["backend"] == "maru"
        assert status["is_healthy"] is True
        assert status["total_object_count"] == 0
        assert status["pending_read_memobjs"] == 1
        assert status["memory_used_bytes"] == 10
        assert status["memory_total_bytes"] == 100
        # The ratio is a float; compare with a tolerance so the test does not
        # depend on how the implementation happens to compute it.
        assert status["memory_usage_ratio"] == pytest.approx(0.1)
def test_optional_import_smoke():
    """Reference ``Optional`` so the module-level import counts as used."""
    optional_int = Optional[int]
    assert optional_int is not None  # noqa: B015
class TestL1MemoryManagerConfigMaru:
    """``L1MemoryManagerConfig`` behaviour with and without ``maru_config``."""

    def test_default_has_no_maru_config(self):
        config = L1MemoryManagerConfig(size_in_bytes=_TINY_BYTES, use_lazy=False)

        assert config.maru_config is None

    def test_default_clamps_init_size(self):
        # The 20GB default for init_size_in_bytes is clamped down to
        # size_in_bytes (1MB here).
        config = L1MemoryManagerConfig(size_in_bytes=_TINY_BYTES, use_lazy=False)

        assert config.init_size_in_bytes == _TINY_BYTES

    def test_maru_config_skips_clamp(self, maru_cfg):
        # With maru_config set the DRAM fields are ignored, so
        # size_in_bytes=0 is fine and the 20GB default for
        # init_size_in_bytes must NOT be clamped to 0.
        kwargs = {"size_in_bytes": 0, "use_lazy": False, "maru_config": maru_cfg}
        config = L1MemoryManagerConfig(**kwargs)

        assert config.maru_config is maru_cfg
        assert config.init_size_in_bytes == 20 << 30
class TestGetMemoryUsageMaru:
    """``L1MemoryManager.get_memory_usage`` for the maru backend."""

    def test_returns_zero_before_init_layout(self, maru_cfg):
        # Allocator exists, but ``init_layout`` has not been called.
        manager = _make_maru_manager(maru_cfg)

        assert manager.get_memory_usage() == (0, 0)

    def test_returns_zero_when_handler_has_no_get_stats(self, maru_cfg):
        manager = _make_maru_manager(maru_cfg)
        allocator = manager._allocator
        _fake_init_layout(allocator)
        # spec=[] → mock has no attributes (no ``get_stats``)
        allocator._handler = mock.Mock(spec=[])

        assert manager.get_memory_usage() == (0, 0)

    def test_forwards_used_and_pool_size_bytes(self, maru_cfg):
        manager = _make_maru_manager(maru_cfg)
        allocator = manager._allocator
        _fake_init_layout(allocator)
        stats = {"used_bytes": 1234, "pool_size_bytes": 5678}
        allocator._handler.get_stats.return_value = stats

        assert manager.get_memory_usage() == (1234, 5678)

    def test_falls_back_to_pool_size_key(self, maru_cfg):
        manager = _make_maru_manager(maru_cfg)
        allocator = manager._allocator
        _fake_init_layout(allocator)
        stats = {"used_bytes": 100, "pool_size": 999}
        allocator._handler.get_stats.return_value = stats

        assert manager.get_memory_usage() == (100, 999)

    def test_returns_zero_on_handler_exception(self, maru_cfg):
        manager = _make_maru_manager(maru_cfg)
        allocator = manager._allocator
        _fake_init_layout(allocator)
        allocator._handler.get_stats.side_effect = RuntimeError("boom")

        # The error is swallowed and reported as (0, 0) rather than raised.
        assert manager.get_memory_usage() == (0, 0)
Factory registration in the ``L2`` adapter registry. +3. ``submit_store_task`` — happy path (alloc + memmove DRAM→CXL + + ``batch_store``), length mismatch, handler exception. +4. ``submit_lookup_and_lock_task`` — prefix-stop bitmap (all-hit, + prefix, all-miss), handler exception. +5. ``submit_load_task`` — happy path memmove CXL→DRAM, partial-miss + bitmap, handler exception. +6. ``submit_unlock`` / ``delete`` — sync dispatch, error swallowing, + empty input. +7. ``close()`` — idempotent teardown without inflight work. +8. ``_object_key_to_string`` — encoding parity with the L1 dispatcher. + +``MaruHandler`` is monkey-patched: ``MaruL2Adapter._connect_handler`` +is replaced so it returns a ``MagicMock`` instead of dialing +MaruServer. The default ``adapter`` fixture additionally pre-pins +``_handler`` and ``_chunk_size_bytes`` so each ``submit_*`` runs as +if the lazy connect already fired (tests that care about the lazy +behaviour itself use ``lazy_adapter``). Worker pools are swapped +for an inline ``_SyncExecutor`` so submit_* returns +deterministically before assertions. +""" + +# Standard +from unittest import mock + +# Third Party +import numpy as np +import pytest + +# First Party +from lmcache.v1.distributed.api import ObjectKey + +try: + # First Party + from lmcache.v1.distributed.l2_adapters.maru_l2_adapter import ( + MaruL2Adapter, + MaruL2AdapterConfig, + _memoryview_addr, + _object_key_to_string, + ) +except ImportError: + pytest.skip("MaruL2Adapter could not be imported", allow_module_level=True) + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +class _SyncExecutor: + """Inline replacement for ``ThreadPoolExecutor`` so worker + callbacks run on the caller's thread — keeps assertions + deterministic without ``shutdown(wait=True)`` plumbing. 
+ """ + + def submit(self, fn, *args, **kwargs): + fn(*args, **kwargs) + return mock.MagicMock() + + def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None: + del wait, cancel_futures + + +class _FakeAllocHandle: + """Stand-in for ``MaruHandler.AllocHandle`` — exposes ``.buf`` as + a writable memoryview backed by a numpy uint8 array.""" + + def __init__(self, size: int) -> None: + self._arr = np.zeros(size, dtype=np.uint8) + + @property + def buf(self) -> memoryview: + return memoryview(self._arr) + + +def _mk_key(idx: int = 0, salt: str = "") -> ObjectKey: + return ObjectKey( + chunk_hash=idx.to_bytes(4, byteorder="big"), + model_name="test-model", + kv_rank=0xABCD, + cache_salt=salt, + ) + + +def _make_dram_memory_obj( + backing: np.ndarray, size_override: int | None = None +) -> mock.MagicMock: + """Wrap a contiguous numpy buffer as a fake ``MemoryObj``. + + The MagicMock exposes ``data_ptr`` (raw address) and + ``get_size()`` so :meth:`MaruL2Adapter._execute_store_task` + treats it like a real L1 DRAM allocation. + """ + mo = mock.MagicMock(name="MemoryObj") + mo.data_ptr = int(backing.ctypes.data) + mo.get_size = mock.MagicMock( + return_value=size_override if size_override is not None else backing.nbytes + ) + return mo + + +# --------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------- + + +@pytest.fixture +def base_cfg() -> MaruL2AdapterConfig: + return MaruL2AdapterConfig( + server_url="maru://localhost:5555", + pool_size_gb=1.0, + chunk_size_bytes=4 * 1024 * 1024, + instance_id="test-l2", + num_store_workers=1, + num_lookup_workers=1, + num_load_workers=1, + ) + + +@pytest.fixture +def fake_handler() -> mock.MagicMock: + h = mock.MagicMock(name="MaruHandler") + h.instance_id = "test-l2" + return h + + +@pytest.fixture +def adapter(base_cfg, fake_handler): + """Build a ``MaruL2Adapter`` with the handler pre-installed. 
+ + Bypasses the lazy connect path so existing store/load/lookup + tests see a working handler immediately. ``_connect_handler`` is + also patched so any test that does trigger ``_ensure_connected`` + picks up the same mock instead of dialing MaruServer. + """ + with mock.patch.object( + MaruL2Adapter, "_connect_handler", return_value=fake_handler + ): + a = MaruL2Adapter(base_cfg) + # Simulate the post-connect state: handler is wired in and the + # chunk size is locked to the config value. + a._handler = fake_handler + a._chunk_size_bytes = base_cfg.chunk_size_bytes + # Swap async executors for the inline sync stub. + a._store_executor = _SyncExecutor() + a._lookup_executor = _SyncExecutor() + a._load_executor = _SyncExecutor() + try: + yield a + finally: + # ``close()`` walks every executor; ``_SyncExecutor.shutdown`` + # is a no-op, so this is cheap. + a.close() + + +@pytest.fixture +def lazy_adapter(base_cfg, fake_handler): + """Build a ``MaruL2Adapter`` with NO pre-installed handler. + + For tests that exercise ``_ensure_connected`` directly: + ``_connect_handler`` is patched so the first internal call + returns ``fake_handler``. Tests assert when (and with which + chunk-size hint) the call fired. + """ + connect_mock = mock.MagicMock(return_value=fake_handler) + with mock.patch.object(MaruL2Adapter, "_connect_handler", connect_mock): + # No chunk_size_bytes in config — derived lazily. + cfg = MaruL2AdapterConfig( + server_url=base_cfg.server_url, + pool_size_gb=base_cfg.pool_size_gb, + chunk_size_bytes=None, + instance_id=base_cfg.instance_id, + ) + a = MaruL2Adapter(cfg) + # Inline-sync executors so submit_* finishes before assertions. 
+ a._store_executor = _SyncExecutor() + a._lookup_executor = _SyncExecutor() + a._load_executor = _SyncExecutor() + try: + yield a, connect_mock + finally: + a.close() + + +# ===================================================================== +# (1) Config +# ===================================================================== + + +class TestConfig: + def test_from_dict_minimal(self): + cfg = MaruL2AdapterConfig.from_dict( + { + "server_url": "maru://localhost:5555", + "pool_size_gb": 1, + "chunk_size_bytes": 4194304, + } + ) + assert cfg.server_url == "maru://localhost:5555" + assert cfg.pool_size_gb == 1.0 + assert cfg.chunk_size_bytes == 4194304 + assert cfg.instance_id is None + assert cfg.num_store_workers == 1 + assert cfg.num_lookup_workers == 1 + assert cfg.num_load_workers >= 1 + assert cfg.eager_map is True + + def test_from_dict_full(self): + cfg = MaruL2AdapterConfig.from_dict( + { + "server_url": "tcp://m:1", + "pool_size_gb": 0.5, + "chunk_size_bytes": 65536, + "instance_id": "client-x", + "num_store_workers": 2, + "num_lookup_workers": 3, + "num_load_workers": 4, + "timeout_ms": 1000, + "use_async_rpc": False, + "max_inflight": 8, + "eager_map": False, + } + ) + assert cfg.instance_id == "client-x" + assert cfg.num_store_workers == 2 + assert cfg.num_lookup_workers == 3 + assert cfg.num_load_workers == 4 + assert cfg.timeout_ms == 1000 + assert cfg.use_async_rpc is False + assert cfg.max_inflight == 8 + assert cfg.eager_map is False + + def test_from_dict_missing_server_url(self): + with pytest.raises(ValueError, match="server_url"): + MaruL2AdapterConfig.from_dict({"pool_size_gb": 1, "chunk_size_bytes": 4096}) + + def test_from_dict_empty_server_url(self): + with pytest.raises(ValueError, match="server_url"): + MaruL2AdapterConfig.from_dict( + { + "server_url": " ", + "pool_size_gb": 1, + "chunk_size_bytes": 4096, + } + ) + + def test_from_dict_pool_zero_rejected(self): + with pytest.raises(ValueError, match="pool_size_gb"): + 
MaruL2AdapterConfig.from_dict( + { + "server_url": "maru://x", + "pool_size_gb": 0, + "chunk_size_bytes": 4096, + } + ) + + def test_from_dict_chunk_size_zero_rejected(self): + with pytest.raises(ValueError, match="chunk_size_bytes"): + MaruL2AdapterConfig.from_dict( + {"server_url": "maru://x", "pool_size_gb": 1, "chunk_size_bytes": 0} + ) + + def test_from_dict_instance_id_non_string_rejected(self): + with pytest.raises(ValueError, match="instance_id"): + MaruL2AdapterConfig.from_dict( + { + "server_url": "maru://x", + "pool_size_gb": 1, + "chunk_size_bytes": 4096, + "instance_id": 123, + } + ) + + def test_help_text_mentions_required_fields(self): + txt = MaruL2AdapterConfig.help() + assert "server_url" in txt and "required" in txt + assert "pool_size_gb" in txt + assert "chunk_size_bytes" in txt + + +# ===================================================================== +# (2) Factory registration +# ===================================================================== + + +class TestRegistration: + def test_maru_registered_in_supported_types(self): + # First Party + from lmcache.v1.distributed.l2_adapters.config import ( + _L2_ADAPTER_CONFIG_REGISTRY, + ) + + # Importing the module triggered the ``register_l2_adapter_type`` call + # at module-bottom; the registry should now know "maru". 
+ assert "maru" in _L2_ADAPTER_CONFIG_REGISTRY + + def test_factory_returns_adapter(self, base_cfg, fake_handler): + # First Party + from lmcache.v1.distributed.l2_adapters.maru_l2_adapter import ( + _create_maru_l2_adapter, + ) + + with mock.patch.object( + MaruL2Adapter, "_connect_handler", return_value=fake_handler + ): + a = _create_maru_l2_adapter(base_cfg, l1_memory_desc=None) + try: + assert isinstance(a, MaruL2Adapter) + finally: + a.close() + + +# ===================================================================== +# (3) Lifecycle / event fds +# ===================================================================== + + +class TestLifecycle: + def test_event_fds_are_distinct(self, adapter): + fds = { + adapter.get_store_event_fd(), + adapter.get_lookup_and_lock_event_fd(), + adapter.get_load_event_fd(), + } + assert len(fds) == 3 + + def test_close_idempotent(self, base_cfg, fake_handler): + with mock.patch.object( + MaruL2Adapter, "_connect_handler", return_value=fake_handler + ): + a = MaruL2Adapter(base_cfg) + a._store_executor = _SyncExecutor() + a._lookup_executor = _SyncExecutor() + a._load_executor = _SyncExecutor() + a.close() + a.close() # second call must be a safe no-op + + def test_submit_after_close_raises(self, base_cfg, fake_handler): + with mock.patch.object( + MaruL2Adapter, "_connect_handler", return_value=fake_handler + ): + a = MaruL2Adapter(base_cfg) + a._store_executor = _SyncExecutor() + a._lookup_executor = _SyncExecutor() + a._load_executor = _SyncExecutor() + a.close() + with pytest.raises(RuntimeError, match="closed"): + a.submit_store_task([], []) + + +# ===================================================================== +# (4) Store path +# ===================================================================== + + +class TestStore: + def test_keys_objects_length_mismatch(self, adapter): + with pytest.raises(ValueError, match="length mismatch"): + adapter.submit_store_task([_mk_key(0), _mk_key(1)], [mock.MagicMock()]) + + def 
test_happy_path_alloc_memmove_store(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(2)] + size = 1024 + src_arrs = [np.full(size, 0xAB + i, dtype=np.uint8) for i in range(2)] + mem_objs = [_make_dram_memory_obj(a) for a in src_arrs] + + handles = [_FakeAllocHandle(size) for _ in keys] + fake_handler.alloc.side_effect = handles + fake_handler.batch_store.return_value = [True, True] + + task_id = adapter.submit_store_task(keys, mem_objs) + + # Verify alloc + batch_store were issued exactly once. + assert fake_handler.alloc.call_count == 2 + fake_handler.batch_store.assert_called_once() + passed_keys, passed_handles = fake_handler.batch_store.call_args.args + assert passed_keys == [_object_key_to_string(k) for k in keys] + assert passed_handles == handles + + # DRAM bytes must have made it into the CXL handle's buffer. + for h, src in zip(handles, src_arrs, strict=True): + assert bytes(h.buf[:size]) == bytes(src.tobytes()) + + # Completion bookkeeping. + results = adapter.pop_completed_store_tasks() + assert results == {task_id: True} + + def test_batch_store_partial_failure_marks_overall_false( + self, adapter, fake_handler + ): + keys = [_mk_key(0), _mk_key(1)] + size = 256 + srcs = [np.zeros(size, dtype=np.uint8) for _ in keys] + objs = [_make_dram_memory_obj(s) for s in srcs] + + fake_handler.alloc.side_effect = [ + _FakeAllocHandle(size), + _FakeAllocHandle(size), + ] + # One success, one failure. + fake_handler.batch_store.return_value = [True, False] + + task_id = adapter.submit_store_task(keys, objs) + assert adapter.pop_completed_store_tasks() == {task_id: False} + + def test_alloc_exception_marks_failure(self, adapter, fake_handler): + keys = [_mk_key(0)] + size = 256 + src = np.zeros(size, dtype=np.uint8) + objs = [_make_dram_memory_obj(src)] + + fake_handler.alloc.side_effect = RuntimeError("oom") + task_id = adapter.submit_store_task(keys, objs) + + # batch_store must not have been reached. 
+ fake_handler.batch_store.assert_not_called() + assert adapter.pop_completed_store_tasks() == {task_id: False} + + def test_pop_drains_completed_dict(self, adapter, fake_handler): + keys = [_mk_key(0)] + size = 16 + src = np.zeros(size, dtype=np.uint8) + fake_handler.alloc.return_value = _FakeAllocHandle(size) + fake_handler.batch_store.return_value = [True] + + task_id = adapter.submit_store_task(keys, [_make_dram_memory_obj(src)]) + assert adapter.pop_completed_store_tasks() == {task_id: True} + # Second pop yields nothing — single-consumer contract. + assert adapter.pop_completed_store_tasks() == {} + + +# ===================================================================== +# (5) Lookup-and-lock path +# ===================================================================== + + +class TestLookup: + def test_all_hit(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(3)] + fake_handler.batch_pin.return_value = [True, True, True] + task_id = adapter.submit_lookup_and_lock_task(keys) + bm = adapter.query_lookup_and_lock_result(task_id) + assert bm is not None + assert [bm.test(i) for i in range(3)] == [True, True, True] + + def test_prefix_stop(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(4)] + fake_handler.batch_pin.return_value = [True, True, False, True] + task_id = adapter.submit_lookup_and_lock_task(keys) + bm = adapter.query_lookup_and_lock_result(task_id) + # Once a miss hits, the bitmap freezes — even the trailing True + # after the first False must not be set. 
+ assert [bm.test(i) for i in range(4)] == [True, True, False, False] + + def test_all_miss(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(2)] + fake_handler.batch_pin.return_value = [False, False] + task_id = adapter.submit_lookup_and_lock_task(keys) + bm = adapter.query_lookup_and_lock_result(task_id) + assert [bm.test(i) for i in range(2)] == [False, False] + + def test_handler_exception_empty_bitmap(self, adapter, fake_handler): + keys = [_mk_key(0)] + fake_handler.batch_pin.side_effect = RuntimeError("rpc fail") + task_id = adapter.submit_lookup_and_lock_task(keys) + bm = adapter.query_lookup_and_lock_result(task_id) + assert bm is not None + assert bm.test(0) is False + + def test_query_is_single_consumer(self, adapter, fake_handler): + keys = [_mk_key(0)] + fake_handler.batch_pin.return_value = [True] + task_id = adapter.submit_lookup_and_lock_task(keys) + assert adapter.query_lookup_and_lock_result(task_id) is not None + # Second query returns None. + assert adapter.query_lookup_and_lock_result(task_id) is None + + +# ===================================================================== +# (6) Load path +# ===================================================================== + + +class TestLoad: + def test_keys_objects_length_mismatch(self, adapter): + with pytest.raises(ValueError, match="length mismatch"): + adapter.submit_load_task([_mk_key(0)], []) + + def test_happy_path_memmove_cxl_to_dram(self, adapter, fake_handler): + keys = [_mk_key(0)] + size = 1024 + cxl_arr = np.full(size, 0xCD, dtype=np.uint8) + dram_arr = np.zeros(size, dtype=np.uint8) + + mi = mock.MagicMock(name="MemoryInfo") + mi.view = memoryview(cxl_arr) + fake_handler.batch_retrieve.return_value = [mi] + + task_id = adapter.submit_load_task(keys, [_make_dram_memory_obj(dram_arr)]) + bm = adapter.query_load_result(task_id) + assert bm is not None and bm.test(0) is True + # CXL bytes must have landed in DRAM. 
+ assert bytes(dram_arr.tobytes()) == bytes(cxl_arr.tobytes()) + + def test_partial_miss_bitmap(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(3)] + size = 32 + cxl_arrs = [np.full(size, 0x10 + i, dtype=np.uint8) for i in range(3)] + # Second slot is a miss — None from MaruServer. + mem_infos = [] + for i, arr in enumerate(cxl_arrs): + if i == 1: + mem_infos.append(None) + else: + mi = mock.MagicMock() + mi.view = memoryview(arr) + mem_infos.append(mi) + fake_handler.batch_retrieve.return_value = mem_infos + + dram_arrs = [np.zeros(size, dtype=np.uint8) for _ in keys] + mem_objs = [_make_dram_memory_obj(a) for a in dram_arrs] + + task_id = adapter.submit_load_task(keys, mem_objs) + bm = adapter.query_load_result(task_id) + assert [bm.test(i) for i in range(3)] == [True, False, True] + # Hit slots got the data; miss slot stayed zero. + assert bytes(dram_arrs[0].tobytes()) == bytes(cxl_arrs[0].tobytes()) + assert int(dram_arrs[1].sum()) == 0 + assert bytes(dram_arrs[2].tobytes()) == bytes(cxl_arrs[2].tobytes()) + + def test_handler_exception_empty_bitmap(self, adapter, fake_handler): + keys = [_mk_key(0)] + fake_handler.batch_retrieve.side_effect = RuntimeError("rpc fail") + size = 16 + dram = np.zeros(size, dtype=np.uint8) + task_id = adapter.submit_load_task(keys, [_make_dram_memory_obj(dram)]) + bm = adapter.query_load_result(task_id) + assert bm is not None and bm.test(0) is False + # DRAM untouched. 
+ assert int(dram.sum()) == 0 + + +# ===================================================================== +# (7) Unlock / delete +# ===================================================================== + + +class TestUnlockDelete: + def test_unlock_invokes_batch_unpin_with_encoded_keys(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(2)] + adapter.submit_unlock(keys) + called_keys = fake_handler.batch_unpin.call_args.args[0] + assert called_keys == [_object_key_to_string(k) for k in keys] + + def test_unlock_empty_skips_rpc(self, adapter, fake_handler): + adapter.submit_unlock([]) + fake_handler.batch_unpin.assert_not_called() + + def test_unlock_swallows_exception(self, adapter, fake_handler): + fake_handler.batch_unpin.side_effect = RuntimeError("boom") + adapter.submit_unlock([_mk_key(0)]) # must not raise + + def test_delete_calls_per_key(self, adapter, fake_handler): + keys = [_mk_key(i) for i in range(3)] + adapter.delete(keys) + assert fake_handler.delete.call_count == 3 + called_keys = [call.args[0] for call in fake_handler.delete.call_args_list] + assert called_keys == [_object_key_to_string(k) for k in keys] + + def test_delete_empty_skips_rpc(self, adapter, fake_handler): + adapter.delete([]) + fake_handler.delete.assert_not_called() + + def test_delete_swallows_exception(self, adapter, fake_handler): + fake_handler.delete.side_effect = RuntimeError("boom") + adapter.delete([_mk_key(0), _mk_key(1)]) # must not raise + # Both keys attempted regardless of the first failure. 
+ assert fake_handler.delete.call_count == 2 + + +# ===================================================================== +# (8) Key encoding +# ===================================================================== + + +class TestKeyEncoding: + def test_basic(self): + k = ObjectKey( + chunk_hash=(0x01020304).to_bytes(4, "big"), + model_name="m", + kv_rank=0xAB, + cache_salt="", + ) + assert _object_key_to_string(k) == "m@000000ab@01020304" + + def test_with_salt(self): + k = ObjectKey( + chunk_hash=(0xFF).to_bytes(4, "big"), + model_name="m", + kv_rank=0xAB, + cache_salt="u1", + ) + assert _object_key_to_string(k) == "m@000000ab@000000ff@u1" + + +# ===================================================================== +# (9) _memoryview_addr +# ===================================================================== + + +class TestMemoryviewAddr: + def test_returns_positive_int(self): + buf = bytearray(b"\x00" * 16) + addr = _memoryview_addr(memoryview(buf)) + assert isinstance(addr, int) and addr > 0 + + def test_stable_across_calls(self): + # A fixed-size bytearray's pointer must not move between calls. 
+ buf = bytearray(b"\x00" * 16) + mv = memoryview(buf) + assert _memoryview_addr(mv) == _memoryview_addr(mv) + + +# ===================================================================== +# (10) Lazy MaruHandler connect +# ===================================================================== + + +class TestLazyConnect: + def test_handler_none_at_init(self, lazy_adapter): + a, connect_mock = lazy_adapter + assert a._handler is None + assert a._chunk_size_bytes is None + connect_mock.assert_not_called() + + def test_first_store_triggers_connect_with_physical_size_hint( + self, lazy_adapter, fake_handler + ): + a, connect_mock = lazy_adapter + keys = [_mk_key(0)] + size = 4096 + src = np.zeros(size, dtype=np.uint8) + obj = _make_dram_memory_obj(src, size_override=size) + # Distinct logical vs physical size — the lazy path must pick + # the *physical* size for the page allocation. + obj.get_physical_size = mock.MagicMock(return_value=size * 2) + + fake_handler.alloc.return_value = _FakeAllocHandle(size) + fake_handler.batch_store.return_value = [True] + + task_id = a.submit_store_task(keys, [obj]) + + connect_mock.assert_called_once() + # ``_connect_handler(config, chunk_size_bytes)`` — second arg + # is what we care about. + passed_chunk = connect_mock.call_args.args[1] + assert passed_chunk == size * 2 + assert a._chunk_size_bytes == size * 2 + assert a._handler is fake_handler + + # And the task itself succeeded. + assert a.pop_completed_store_tasks() == {task_id: True} + + def test_explicit_chunk_size_wins_over_hint(self, base_cfg, fake_handler): + """Config-level ``chunk_size_bytes`` is authoritative; the + first store's physical-size hint is ignored. + """ + # Pin an explicit chunk size in config. 
+ cfg = MaruL2AdapterConfig( + server_url=base_cfg.server_url, + pool_size_gb=base_cfg.pool_size_gb, + chunk_size_bytes=999_999, + instance_id=base_cfg.instance_id, + ) + connect_mock = mock.MagicMock(return_value=fake_handler) + # NB: the patch must wrap the full test body, not just the + # constructor — ``submit_store_task`` triggers ``_ensure_connected`` + # which calls ``_connect_handler``. Letting the patch lapse + # before the store would dial the real MaruServer. + with mock.patch.object(MaruL2Adapter, "_connect_handler", connect_mock): + a = MaruL2Adapter(cfg) + a._store_executor = _SyncExecutor() + a._lookup_executor = _SyncExecutor() + a._load_executor = _SyncExecutor() + try: + obj = _make_dram_memory_obj(np.zeros(64, dtype=np.uint8)) + obj.get_physical_size = mock.MagicMock(return_value=12345) + fake_handler.alloc.return_value = _FakeAllocHandle(64) + fake_handler.batch_store.return_value = [True] + + a.submit_store_task([_mk_key(0)], [obj]) + + connect_mock.assert_called_once() + assert connect_mock.call_args.args[1] == 999_999 + finally: + a.close() + + def test_lookup_before_store_yields_all_miss(self, lazy_adapter): + """Lookup without a prior store + no config chunk_size_bytes + can't connect — ``_ensure_connected`` raises and the worker + records an all-miss bitmap. + """ + a, connect_mock = lazy_adapter + keys = [_mk_key(i) for i in range(2)] + task_id = a.submit_lookup_and_lock_task(keys) + bm = a.query_lookup_and_lock_result(task_id) + assert bm is not None + assert [bm.test(i) for i in range(2)] == [False, False] + # connect never happened. + connect_mock.assert_not_called() + assert a._handler is None + + def test_load_first_triggers_connect_with_hint(self, lazy_adapter, fake_handler): + """Load also seeds the lazy connect via the destination + MemoryObj's physical size. 
+ """ + a, connect_mock = lazy_adapter + size = 2048 + dst = np.zeros(size, dtype=np.uint8) + obj = _make_dram_memory_obj(dst, size_override=size) + obj.get_physical_size = mock.MagicMock(return_value=size) + + fake_handler.batch_retrieve.return_value = [None] # miss + task_id = a.submit_load_task([_mk_key(0)], [obj]) + bm = a.query_load_result(task_id) + assert bm is not None and bm.test(0) is False + + connect_mock.assert_called_once() + assert connect_mock.call_args.args[1] == size + + def test_unlock_skips_when_not_connected(self, lazy_adapter): + a, connect_mock = lazy_adapter + # Must not raise — and must not attempt to connect, since + # there can be no live pins on a never-connected handler. + a.submit_unlock([_mk_key(0), _mk_key(1)]) + connect_mock.assert_not_called() + assert a._handler is None + + def test_delete_skips_when_not_connected(self, lazy_adapter): + a, connect_mock = lazy_adapter + a.delete([_mk_key(0)]) + connect_mock.assert_not_called() + assert a._handler is None + + def test_second_call_reuses_handler(self, lazy_adapter, fake_handler): + """``_ensure_connected`` is idempotent — only the first store + triggers ``_connect_handler``. + """ + a, connect_mock = lazy_adapter + size = 4096 + src = np.zeros(size, dtype=np.uint8) + obj = _make_dram_memory_obj(src, size_override=size) + obj.get_physical_size = mock.MagicMock(return_value=size) + fake_handler.alloc.return_value = _FakeAllocHandle(size) + fake_handler.batch_store.return_value = [True] + + a.submit_store_task([_mk_key(0)], [obj]) + a.submit_store_task([_mk_key(1)], [obj]) + + # Single connect across both stores. 
+ assert connect_mock.call_count == 1 + + def test_from_dict_chunk_size_optional(self): + """``chunk_size_bytes`` is no longer required in the + ``--l2-adapter`` JSON.""" + cfg = MaruL2AdapterConfig.from_dict( + { + "server_url": "maru://localhost:5555", + "pool_size_gb": 1, + # no chunk_size_bytes + } + ) + assert cfg.chunk_size_bytes is None diff --git a/tests/v1/distributed/test_storage_manager_maru.py b/tests/v1/distributed/test_storage_manager_maru.py new file mode 100644 index 0000000000..83759f5ffa --- /dev/null +++ b/tests/v1/distributed/test_storage_manager_maru.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the maru-backend wiring of StorageManager. + +Coverage: + +1. ``StorageManager.__init__`` in maru mode — controllers and L2 + adapters are not constructed; ``_is_maru`` is True. +2. ``register_kv_layout`` — forwards down through L1Manager and + L1MemoryManager to ``MaruMemoryAllocator.init_layout``. +3. ``finish_write`` — threads ``memory_objs`` to ``L1Manager.finish_write``. +4. ``close()`` — succeeds with controllers absent. +5. ``report_status()`` — returns a maru-shaped dict. + +The maru runtime (``maru``, ``maru_lmcache``) is NOT required: the +lazy ``MaruMemoryAllocator.__init__`` performs no RPC. Tests that +need a "pool ready" allocator monkey-patch ``init_layout`` so no +MaruServer connection is attempted. 
+""" + +# Standard +from unittest import mock + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.distributed.config import ( + EvictionConfig, + L1ManagerConfig, + L1MemoryManagerConfig, + StorageManagerConfig, +) +from lmcache.v1.distributed.l2_adapters.config import L2AdaptersConfig +from lmcache.v1.memory_management import MemoryFormat +from lmcache.v1.mp_observability.event_bus import EventBusConfig, init_event_bus + +try: + # First Party + from lmcache.v1.distributed.maru_memory_allocator import ( + MaruL1Config, + MaruMemoryAllocator, + ) + from lmcache.v1.distributed.storage_manager import StorageManager +except ImportError: + pytest.skip( + "storage_manager / maru_memory_allocator could not be imported", + allow_module_level=True, + ) + + +@pytest.fixture(autouse=True) +def _event_bus(): + """Initialize a minimal event bus for every test in this module + (the storage manager publishes events at construction time). + """ + init_event_bus(EventBusConfig(enabled=False)) + yield + + +@pytest.fixture +def maru_storage_config() -> StorageManagerConfig: + """A ``StorageManagerConfig`` with the maru L1 backend selected.""" + maru_cfg = MaruL1Config( + server_url="maru://localhost:5555", + pool_size_bytes=1 << 30, + instance_id="test-mp", + ) + memory_config = L1MemoryManagerConfig( + size_in_bytes=0, use_lazy=False, maru_config=maru_cfg + ) + l1_manager_config = L1ManagerConfig(memory_config=memory_config) + return StorageManagerConfig( + l1_manager_config=l1_manager_config, + eviction_config=EvictionConfig(eviction_policy="noop"), + l2_adapter_config=L2AdaptersConfig([]), + ) + + +@pytest.fixture +def fake_maru_init_layout(): + """Replace ``MaruMemoryAllocator.init_layout`` with a stub that + installs ``MagicMock`` handler + adapter — equivalent to the + post-connect state. Reverts on teardown. 
+ """ + real_init_layout = MaruMemoryAllocator.init_layout + + def fake_init_layout(self, shapes, dtypes, fmt, chunk_size_in_tokens): + if self._cxl_adapter is not None: + # Layout-mismatch validation is independent of the + # MaruServer connection, so defer to the real method when + # a layout is already bound. + real_init_layout(self, shapes, dtypes, fmt, chunk_size_in_tokens) + return + self._handler = mock.MagicMock(name="MaruHandler") + self._cxl_adapter = mock.MagicMock(name="CxlMemoryAdapter") + self._shapes = shapes + self._dtypes = dtypes + self._fmt = fmt + self._chunk_size_in_tokens = chunk_size_in_tokens + self._single_token_size = 4096 + + MaruMemoryAllocator.init_layout = fake_init_layout + try: + yield + finally: + MaruMemoryAllocator.init_layout = real_init_layout + + +# ========================================================================= +# (1) Maru-mode StorageManager construction +# ========================================================================= + + +class TestStorageManagerMaruInit: + def test_no_controllers_constructed(self, maru_storage_config): + mgr = StorageManager(maru_storage_config) + try: + assert mgr._is_maru is True + assert mgr._eviction_controller is None + assert mgr._l2_eviction_controller is None + assert mgr._store_controller is None + assert mgr._prefetch_controller is None + assert mgr._l2_adapters == [] + finally: + mgr.close() + + def test_quota_manager_present(self, maru_storage_config): + # The HTTP layer expects a stable quota_manager reference. 
+ mgr = StorageManager(maru_storage_config) + try: + assert mgr.quota_manager is not None + finally: + mgr.close() + + +# ========================================================================= +# (2) register_kv_layout chain +# ========================================================================= + + +class TestRegisterKvLayoutChain: + def test_forwards_to_allocator(self, maru_storage_config, fake_maru_init_layout): + mgr = StorageManager(maru_storage_config) + try: + shapes = [torch.Size([2, 32, 256, 128])] + dtypes = [torch.float16] + mgr.register_kv_layout(shapes, dtypes, MemoryFormat.KV_2LTD, 256) + + alloc = mgr._l1_manager._memory_manager._allocator + assert alloc._shapes == shapes + assert alloc._dtypes == dtypes + assert alloc._fmt is MemoryFormat.KV_2LTD + assert alloc._chunk_size_in_tokens == 256 + assert alloc._handler is not None + assert alloc._cxl_adapter is not None + finally: + mgr.close() + + def test_layout_mismatch_raises(self, maru_storage_config, fake_maru_init_layout): + mgr = StorageManager(maru_storage_config) + try: + shapes_a = [torch.Size([2, 32, 256, 128])] + shapes_b = [torch.Size([2, 32, 128, 128])] # different + dtypes = [torch.float16] + mgr.register_kv_layout(shapes_a, dtypes, MemoryFormat.KV_2LTD, 256) + with pytest.raises(ValueError, match="layout mismatch"): + mgr.register_kv_layout(shapes_b, dtypes, MemoryFormat.KV_2LTD, 256) + finally: + mgr.close() + + def test_same_layout_is_idempotent( + self, maru_storage_config, fake_maru_init_layout + ): + mgr = StorageManager(maru_storage_config) + try: + shapes = [torch.Size([2, 32, 256, 128])] + dtypes = [torch.float16] + mgr.register_kv_layout(shapes, dtypes, MemoryFormat.KV_2LTD, 256) + # Second call with the same layout: no exception. 
+ mgr.register_kv_layout(shapes, dtypes, MemoryFormat.KV_2LTD, 256) + finally: + mgr.close() + + +# ========================================================================= +# (3) finish_write threads memory_objs +# ========================================================================= + + +class TestFinishWriteThreading: + def test_memory_objs_forwarded_to_l1_manager( + self, maru_storage_config, fake_maru_init_layout + ): + mgr = StorageManager(maru_storage_config) + try: + keys = [mock.MagicMock(name=f"k-{i}") for i in range(3)] + memory_objs = [mock.MagicMock(name=f"mo-{i}") for i in range(3)] + with mock.patch.object( + mgr._l1_manager, "finish_write", return_value={} + ) as patched: + mgr.finish_write(keys, memory_objs=memory_objs) + patched.assert_called_once() + # Whether ``memory_objs`` was passed positionally or by + # keyword, the second binding should be the same list. + call = patched.call_args + assert call.kwargs.get("memory_objs") is memory_objs + finally: + mgr.close() + + def test_default_call_passes_none(self, maru_storage_config, fake_maru_init_layout): + mgr = StorageManager(maru_storage_config) + try: + with mock.patch.object( + mgr._l1_manager, "finish_write", return_value={} + ) as patched: + mgr.finish_write([mock.MagicMock()]) + assert patched.call_args.kwargs.get("memory_objs") is None + finally: + mgr.close() + + +# ========================================================================= +# (4) close() and report_status() without controllers +# ========================================================================= + + +class TestCloseAndReport: + def test_close_succeeds_without_controllers(self, maru_storage_config): + mgr = StorageManager(maru_storage_config) + # Should not raise even though all controllers are ``None``. 
+ mgr.close() + + def test_report_status_shape(self, maru_storage_config): + mgr = StorageManager(maru_storage_config) + try: + status = mgr.report_status() + assert status["backend"] == "maru" + assert status["num_l2_adapters"] == 0 + assert status["l2_adapters"] == [] + assert "l1_manager" in status + # Controller entries should not be present in maru mode. + assert "store_controller" not in status + assert "prefetch_controller" not in status + assert "l1_eviction_controller" not in status + assert "l2_eviction_controller" not in status + finally: + mgr.close()