diff --git a/docs/source/assets/maru-kvcache.gif b/docs/source/assets/maru-kvcache.gif new file mode 100644 index 0000000000..51483276a6 Binary files /dev/null and b/docs/source/assets/maru-kvcache.gif differ diff --git a/docs/source/kv_cache/storage_backends/index.rst b/docs/source/kv_cache/storage_backends/index.rst index 9fd7fd0e0c..2b40fe92b1 100644 --- a/docs/source/kv_cache/storage_backends/index.rst +++ b/docs/source/kv_cache/storage_backends/index.rst @@ -15,6 +15,7 @@ Supported Backends gds infinistore local_storage + maru mock mooncake nixl diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst new file mode 100644 index 0000000000..8430e6ee3f --- /dev/null +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -0,0 +1,113 @@ +Maru +==== + +.. _maru-overview: + +Overview +-------- + +`Maru <https://github.com/xcena-dev/maru>`_ is a high-performance KV cache storage engine built on CXL shared memory, +designed for LLM inference scenarios where multiple instances need to share a KV cache with minimal latency. + +.. image:: ../../assets/maru-kvcache.gif + :alt: KV Cache Sharing: Without vs With Maru + +For architecture details, see the `Maru documentation `_. + +Quick Start +----------- + +Install Maru: + +.. code-block:: bash + + git clone https://github.com/xcena-dev/maru.git + cd maru + ./install.sh + +This installs ``maru-server``, ``maru-resourced``, and the ``maru`` Python package. + +Deploy Model With Maru +~~~~~~~~~~~~~~~~~~~~~~ + +**Prerequisites:** CXL device (``/dev/dax*``), Python 3.12+, vLLM and LMCache installed. + +**1. Start the Maru Server** + +.. code-block:: bash + + maru-server + +**2. Create configuration file** (``maru-config.yaml``): + +.. code-block:: yaml + + chunk_size: 256 + local_cpu: False + max_local_cpu_size: 0 + save_unfull_chunk: True + + # Maru backend + maru_path: "maru://localhost:5555" + maru_pool_size: 4 + +**3. Start vLLM with Maru** + +.. 
code-block:: bash + + LMCACHE_CONFIG_FILE="maru-config.yaml" \ + vllm serve \ + meta-llama/Llama-3.1-8B-Instruct \ + --max-model-len 65536 \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +Configuration +------------- + +**LMCache Parameters:** + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``maru_path`` + - Required + - Maru server URL (format: ``maru://host:port``) + * - ``maru_pool_size`` + - ``4.0`` + - CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``) + +**Advanced Parameters (via extra_config):** + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``maru_instance_id`` + - auto UUID + - Unique client instance identifier + * - ``maru_timeout_ms`` + - 5000 + - ZMQ RPC socket timeout in milliseconds + * - ``maru_use_async_rpc`` + - true + - Async DEALER-ROUTER RPC (``false`` for synchronous REQ-REP) + * - ``maru_max_inflight`` + - 64 + - Max concurrent async RPC requests + * - ``maru_eager_map`` + - true + - Pre-map all shared regions on connect + +Additional Resources +-------------------- + +- `Maru GitHub Repository <https://github.com/xcena-dev/maru>`_ +- `Maru Documentation `_ diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index c6c86e563d..20bbc19276 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -236,6 +236,13 @@ "default": None, "env_converter": int, }, + # Maru CXL shared memory backend + "maru_path": {"type": Optional[str], "default": None, "env_converter": str}, + "maru_pool_size": { + "type": float, + "default": 4.0, + "env_converter": float, + }, # Other configurations # (Deprecated) The url of the actual remote lmcache instance for auditing. # Please use extra_config['audit_actual_remote_url'] instead. 
diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index b7212b9603..d49cda695a 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -218,6 +218,20 @@ def CreateStorageBackends( ) storage_backends[str(gds_backend)] = gds_backend + if config.maru_path is not None and "MaruBackend" not in _skip: + try: + # First Party + from lmcache.v1.storage_backend.maru_backend import MaruBackend + except ImportError as e: + raise ImportError( + "The 'maru' and 'maru_lmcache' packages are required " + "to use MaruBackend. Please install them according to " + "the Maru setup documentation." + ) from e + + maru_backend = MaruBackend(config, metadata, loop, dst_device) + storage_backends[str(maru_backend)] = maru_backend + if config.remote_url is not None and "RemoteBackend" not in _skip: assert local_cpu_backend is not None, ( "Remote backend requires local CPU backend as a buffer." diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py new file mode 100644 index 0000000000..21b343229d --- /dev/null +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -0,0 +1,735 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from concurrent.futures import Future +from typing import Any, Callable, List, Optional, Sequence, Union +import asyncio +import threading +import time + +# Third Party +from maru import MaruConfig, MaruHandler +from maru_lmcache import CxlMemoryAdapter +import torch + +# First Party +from lmcache.integration.vllm.utils import get_size_bytes +from lmcache.logging import init_logger +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + MemoryAllocatorInterface, + MemoryFormat, + MemoryObj, +) +from lmcache.v1.metadata import LMCacheMetadata +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface + +logger = 
init_logger(__name__) + + +class MaruBackend(AllocatorBackendInterface): + """Maru CXL shared memory storage backend. + + Implements AllocatorBackendInterface with its own CxlMemoryAdapter. + No LocalCPUBackend needed — data lives directly in CXL mmap memory. + + Put is async (Future): metadata registration via RPC. + Get is sync: CXL memory direct read (no network I/O). + + Args: + config: LMCache engine configuration. Must have maru_path set. + metadata: LMCache engine metadata. + loop: asyncio event loop for async put tasks. + dst_device: Target device string (unused for CXL, kept for interface). + """ + + def __init__( + self, + config: LMCacheEngineConfig, + metadata: LMCacheMetadata, + loop: asyncio.AbstractEventLoop, + dst_device: str = "cuda", + ): + super().__init__(dst_device=dst_device) + + if config.use_layerwise: + raise NotImplementedError( + "MaruBackend does not yet support layerwise KV cache." + ) + + # 1. Config + self.config = config + self.loop = loop + + self._full_chunk_size_bytes: int = get_size_bytes( + metadata.get_shapes(), metadata.get_dtypes() + ) + assert self._full_chunk_size_bytes % metadata.chunk_size == 0 + self._single_token_size: int = ( + self._full_chunk_size_bytes // metadata.chunk_size + ) + + self._mla_worker_id_as0_mode: bool = ( + config.get_extra_config_value( + "remote_enable_mla_worker_id_as0", metadata.use_mla + ) + and metadata.use_mla + and metadata.world_size > 1 + and metadata.worker_id != 0 + ) + + # 2. Handler + self._handler = self._create_handler(config) + + # 3. Allocator + self.memory_allocator = self.initialize_allocator(config, metadata) + + # 4. 
State + self.put_lock = threading.Lock() + self.put_tasks: set[CacheEngineKey] = set() + + def __str__(self) -> str: + return self.__class__.__name__ + + @staticmethod + def _pool_size_gb_to_bytes(size_gb: float) -> int: + """Convert pool size in GB to bytes.""" + return int(size_gb * 1024**3) + + # ========================================================================= + # Initialization helpers + # ========================================================================= + + def _create_handler( + self, + config: LMCacheEngineConfig, + ) -> "MaruHandler": + """Create and connect a MaruHandler. + + Args: + config: LMCache engine configuration. + + Returns: + Connected MaruHandler instance. + + Raises: + RuntimeError: If MaruHandler connection fails. + """ + assert config.maru_path is not None, "maru_path must be set for MaruBackend" + + # Convert maru:// scheme to tcp:// for ZMQ + server_url = config.maru_path + if server_url.startswith("maru://"): + server_url = "tcp://" + server_url[len("maru://") :] + + extra = config.extra_config or {} + maru_config = MaruConfig( + server_url=server_url, + instance_id=extra.get("maru_instance_id"), + pool_size=self._pool_size_gb_to_bytes(config.maru_pool_size), + chunk_size_bytes=self._full_chunk_size_bytes, + auto_connect=False, + timeout_ms=extra.get("maru_timeout_ms", 5000), + use_async_rpc=extra.get("maru_use_async_rpc", True), + max_inflight=extra.get("maru_max_inflight", 64), + eager_map=extra.get("maru_eager_map", True), + ) + + handler = MaruHandler(maru_config) + if not handler.connect(): + raise RuntimeError(f"Failed to connect MaruHandler to {config.maru_path}") + logger.debug("[Maru] Connected to %s", config.maru_path) + return handler + + # ========================================================================= + # AllocatorBackendInterface + # ========================================================================= + + def initialize_allocator( + self, config: LMCacheEngineConfig, metadata: LMCacheMetadata + 
) -> MemoryAllocatorInterface: + """Create CxlMemoryAdapter backed by the connected handler. + + Args: + config: LMCache engine configuration. + metadata: LMCache engine metadata. + + Returns: + CxlMemoryAdapter instance. + """ + shapes = metadata.get_shapes() + dtypes = metadata.get_dtypes() + fmt = MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD + chunk_size = self._handler.get_chunk_size() + + return CxlMemoryAdapter( + handler=self._handler, + shapes=shapes, + dtypes=dtypes, + fmt=fmt, + chunk_size=chunk_size, + ) + + def get_memory_allocator(self) -> MemoryAllocatorInterface: + """Returns the underlying CxlMemoryAdapter.""" + return self.memory_allocator + + def get_allocator_backend(self) -> "MaruBackend": + """Returns self as the allocator backend.""" + return self + + def allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[MemoryObj]: + """Allocate CXL-backed memory via CxlMemoryAdapter. + + Args: + shapes: Tensor shape(s). + dtypes: Tensor dtype(s). + fmt: Memory format. + eviction: Unused. + busy_loop: Unused. + + Returns: + MemoryObj backed by CXL memory, or None on failure. + """ + obj = self.memory_allocator.allocate(shapes, dtypes, fmt) + if obj is not None: + logger.debug( + "[Maru] allocate rid=%d pid=%d", + *CxlMemoryAdapter.decode_address(obj.metadata.address), + ) + else: + logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes) + return obj + + def batched_allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[list[MemoryObj]]: + """Allocate multiple CXL-backed MemoryObjs. + + Args: + shapes: Tensor shape(s) (same for each allocation). 
+ dtypes: Tensor dtype(s) (same for each allocation). + batch_size: Number of allocations. + fmt: Memory format. + eviction: Unused. + busy_loop: Unused. + + Returns: + List of MemoryObj, or None if any allocation fails. + """ + return self.memory_allocator.batched_allocate(shapes, dtypes, batch_size, fmt) + + # ========================================================================= + # Put (async) + # ========================================================================= + + def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: + """Check whether key is in ongoing put tasks. + + Args: + key: The cache key. + + Returns: + True if the key has a pending put task. + """ + with self.put_lock: + return key in self.put_tasks + + @staticmethod + def _create_immediate_empty_future() -> Future: + """Create a Future that is already resolved with None.""" + f: Future = Future() + f.set_result(None) + return f + + def submit_put_task( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Future: + """Submit a put task to register KV metadata with MaruServer. + + Data is already in CXL memory (zero-copy). This only registers + the key -> location metadata via RPC. + + Args: + key: The cache key. + memory_obj: MemoryObj with data already written to CXL. + on_complete_callback: Optional callback after registration. + + Returns: + Future that completes when metadata is registered. + """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return self._create_immediate_empty_future() + + assert memory_obj.tensor is not None + + # Keep CXL page alive: ref_count_down is only called on failure. + # On success the ref is retained so the CXL memory is not reclaimed. 
+ memory_obj.ref_count_up() + + with self.put_lock: + self.put_tasks.add(key) + + future = asyncio.run_coroutine_threadsafe( + self._async_store(key, memory_obj, on_complete_callback), + self.loop, + ) + return future + + def batched_submit_put_task( + self, + keys: Sequence[CacheEngineKey], + memory_objs: List[MemoryObj], + transfer_spec: Any = None, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Union[List[Future], None]: + """Submit batched put tasks via single batch_store RPC. + + Args: + keys: The cache keys. + memory_objs: MemoryObjs with data already in CXL. + transfer_spec: Unused. + on_complete_callback: Optional per-key callback. + + Returns: + List containing a single Future for the entire batch. + """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return None + + for memory_obj in memory_objs: + assert memory_obj.tensor is not None + memory_obj.ref_count_up() + + with self.put_lock: + self.put_tasks.update(keys) + + future = asyncio.run_coroutine_threadsafe( + self._async_batch_store(list(keys), memory_objs, on_complete_callback), + self.loop, + ) + return [future] + + async def _async_store( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register KV metadata with MaruServer (runs in event loop). + + Uses CxlMemoryAdapter.create_store_handle() to extract + (region_id, page_index) from the MemoryObj's encoded address. + + Args: + key: The cache key. + memory_obj: MemoryObj backed by CXL memory. + on_complete_callback: Optional callback after registration. 
+ """ + success = False + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + handle = allocator.create_store_handle(memory_obj) + key_str = key.to_string() + + await asyncio.to_thread(self._handler.store, key_str, handle) + success = True + + logger.debug( + "[Maru] store key=%s rid=%d pid=%d", + key, + handle.region_id, + handle.page_index, + ) + + except Exception as e: + logger.error("[Maru] store failed key=%s: %s", key, e) + raise + finally: + with self.put_lock: + self.put_tasks.discard(key) + + if not success: + memory_obj.ref_count_down() + + if success and on_complete_callback is not None: + try: + on_complete_callback(key) + except Exception as e: + logger.warning("on_complete_callback failed for key %s: %s", key, e) + + async def _async_batch_store( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register multiple KV metadata entries via single batch_store RPC.""" + results: Optional[list[bool]] = None + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + key_strs = [k.to_string() for k in keys] + handles = [allocator.create_store_handle(m) for m in memory_objs] + + results = await asyncio.to_thread( + self._handler.batch_store, key_strs, handles + ) + if results is not None: + logger.debug("[Maru] batch_store %d/%d ok", sum(results), len(results)) + except Exception as e: + logger.error("[Maru] batch_store failed: %s", e) + raise + finally: + with self.put_lock: + self.put_tasks.difference_update(keys) + + # Release ref_count for failed stores + for i, memory_obj in enumerate(memory_objs): + succeeded = results is not None and i < len(results) and results[i] + if not succeeded: + memory_obj.ref_count_down() + + if on_complete_callback is not None: + for i, key in enumerate(keys): + if results is not None and i < len(results) and results[i]: + try: + 
on_complete_callback(key) + except Exception as e: + logger.warning( + "on_complete_callback failed for key %s: %s", + key, + e, + ) + + # ========================================================================= + # Get (sync) + # ========================================================================= + + def get_blocking( + self, + key: CacheEngineKey, + ) -> Optional[MemoryObj]: + """Blocking get: read KV cache directly from CXL memory. + + Queries MaruServer for metadata, then returns a MemoryObj + via CxlMemoryAdapter.get_by_location(). + + Args: + key: The cache key. + + Returns: + MemoryObj backed by CXL memory, or None if not found. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + key_str = key.to_string() + mem_info = self._handler.retrieve(key_str) + if mem_info is None: + logger.debug("[Maru] get_blocking miss key=%s", key) + return None + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + logger.debug( + "[Maru] get_blocking pool miss rid=%d pid=%d", + mem_info.region_id, + mem_info.page_index, + ) + return None + + memory_obj.ref_count_up() + + logger.debug( + "[Maru] get_blocking rid=%d pid=%d size=%d", + mem_info.region_id, + mem_info.page_index, + len(mem_info.view), + ) + return memory_obj + + def batched_get_blocking( + self, + keys: List[CacheEngineKey], + ) -> List[Optional[MemoryObj]]: + """Blocking batched get via single batch_retrieve RPC. + + Args: + keys: The cache keys. + + Returns: + List of MemoryObj (None for misses). 
+ """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + results: List[Optional[MemoryObj]] = [] + for mem_info in mem_infos: + if mem_info is None: + results.append(None) + continue + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + results.append(None) + continue + memory_obj.ref_count_up() + results.append(memory_obj) + + hits = sum(1 for r in results if r is not None) + logger.debug("[Maru] batch_retrieve %d/%d hits", hits, len(results)) + return results + + # ========================================================================= + # Async lookup API (used by StorageManager.async_lookup_and_prefetch) + # ========================================================================= + + async def batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist via single batch_exists RPC. + + Returns the count of contiguous keys starting from index 0 + that exist. Stops at first miss. + + Args: + lookup_id: Unique request identifier. + keys: Keys to check in prefix order. + pin: Whether to pin. Not supported; logged as debug. + + Returns: + Number of prefix-contiguous keys that exist. + """ + return await asyncio.to_thread(self.batched_contains, keys, pin) + + async def batched_get_non_blocking( + self, + lookup_id: str, + keys: list[CacheEngineKey], + transfer_spec: Any = None, + ) -> list[MemoryObj]: + """Non-blocking batched get via single batch_retrieve RPC. + + Uses handler.batch_retrieve() for a single RPC call, then + resolves each MemoryInfo to a MemoryObj via CxlMemoryAdapter. 
+ Stops at first miss and returns the prefix. + + Args: + lookup_id: Unique request identifier. + keys: Keys to retrieve (already confirmed by contains). + transfer_spec: Unused. + + Returns: + List of MemoryObjs backed by CXL memory. + """ + + def _batch_get() -> list[MemoryObj]: + if self._mla_worker_id_as0_mode: + actual_keys = [k.with_new_worker_id(0) for k in keys] + else: + actual_keys = list(keys) + + key_strs = [k.to_string() for k in actual_keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + results: list[MemoryObj] = [] + for mem_info in mem_infos: + if mem_info is None: + break + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + break + memory_obj.ref_count_up() + memory_obj.pin() + results.append(memory_obj) + + logger.debug( + "[Maru] batch_get_non_blocking %d/%d hits", len(results), len(keys) + ) + return results + + return await asyncio.to_thread(_batch_get) + + # ========================================================================= + # Contains / Pin / Unpin / Remove + # ========================================================================= + + def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: + """Check if key exists on MaruServer. + + Args: + key: The cache key. + pin: If True, atomically check existence and pin the entry + to protect it from eviction. + + Returns: + True if key exists. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + key_str = key.to_string() + if pin: + return self._handler.pin(key_str) + return self._handler.exists(key_str) + + def batched_contains( + self, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist via single batch_exists RPC. 
+ + Args: + keys: Keys to check in prefix order. + pin: If True, atomically check and pin via + batch_pin RPC. + + Returns: + Number of prefix-contiguous keys that exist. + """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + if pin: + results = self._handler.batch_pin(key_strs) + else: + results = self._handler.batch_exists(key_strs) + num_hit = 0 + for exists in results: + if not exists: + break + num_hit += 1 + return num_hit + + def pin(self, key: CacheEngineKey) -> bool: + """Pin a key to prevent eviction on MaruServer. + + Increments the server-side pin_count. + + Args: + key: The cache key. + + Returns: + True if pinned successfully. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.pin(key.to_string()) + + def unpin(self, key: CacheEngineKey) -> bool: + """Unpin a key to allow eviction on MaruServer. + + Decrements the server-side pin_count. When pin_count reaches 0, + the entry becomes eligible for eviction. + + Args: + key: The cache key. + + Returns: + True if unpinned successfully. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.unpin(key.to_string()) + + def batched_unpin(self, keys: List[CacheEngineKey]) -> None: + """Batch-unpin keys via single RPC. + + Decrements server-side pin_count for each key. When pin_count + reaches 0, the entry becomes eligible for eviction. + + Args: + keys: The cache keys to unpin. + """ + if not keys: + return + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + key_strs = [k.to_string() for k in keys] + self._handler.batch_unpin(key_strs) + + def remove(self, key: CacheEngineKey, force: bool = True) -> bool: + """Remove a key from MaruServer. + + Args: + key: The cache key. + force: Whether to force removal. + + Returns: + True if removed successfully. 
+ """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + key_str = key.to_string() + result = self._handler.delete(key_str) + logger.debug("[Maru] remove key=%s success=%s", key, result) + return result + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """Close the backend and underlying MaruHandler.""" + while True: + with self.put_lock: + if not self.put_tasks: + break + time.sleep(0.1) + + self.memory_allocator.close() + self._handler.close() + logger.info("MaruBackend closed.") diff --git a/lmcache/v1/storage_backend/storage_manager.py b/lmcache/v1/storage_backend/storage_manager.py index e8f6e9ebf4..ec3a53d86d 100644 --- a/lmcache/v1/storage_backend/storage_manager.py +++ b/lmcache/v1/storage_backend/storage_manager.py @@ -314,6 +314,11 @@ def _get_allocator_backend( ) -> AllocatorBackendInterface: if self.enable_pd: allocator_backend = self.storage_backends["PDBackend"] + elif "MaruBackend" in self.storage_backends: + if "LocalCPUBackend" in self.storage_backends: + allocator_backend = self.storage_backends["LocalCPUBackend"] + else: + allocator_backend = self.storage_backends["MaruBackend"] else: allocator_backend = self.storage_backends["LocalCPUBackend"] assert isinstance(allocator_backend, AllocatorBackendInterface) @@ -443,7 +448,7 @@ def get( memory_obj = backend.get_blocking(key) if memory_obj: if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and "LocalCPUBackend" in self.storage_backends ): local_cpu_backend = self.storage_backends["LocalCPUBackend"] @@ -487,7 +492,7 @@ def batched_get( # Align with single-key `get()` logic: # auto-write remote data to local CPU cache if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and 
"LocalCPUBackend" in self.storage_backends and None not in memory_objs ): diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py new file mode 100644 index 0000000000..d2b3f3fe1b --- /dev/null +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -0,0 +1,790 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock, patch +import asyncio +import mmap +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import MemoryFormat, TensorMemoryObj +from lmcache.v1.pin_monitor import PinMonitor +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface +from tests.v1.utils import ( + check_method_signatures, + get_abstract_methods, + get_methods_implemented_in_class, +) + +maru = pytest.importorskip("maru", reason="maru package not installed") +maru_lmcache = pytest.importorskip( + "maru_lmcache", reason="maru_lmcache package not installed" +) + +# Third Party +from maru_handler.memory import AllocHandle # noqa: E402 +from maru_handler.memory.types import MappedRegion, MemoryInfo # noqa: E402 +from maru_lmcache.adapter import CxlMemoryAdapter # noqa: E402 + +# First Party +from lmcache.v1.storage_backend.maru_backend import MaruBackend # noqa: E402 + +# ========================================================================= +# Constants +# ========================================================================= + +TEST_CHUNK_SIZE = 1024 +TEST_DTYPE = torch.float32 +TEST_SHAPE = torch.Size([256]) # 256 * 4B = 1024 bytes = chunk_size + + +# ========================================================================= +# Helpers +# ========================================================================= + + +def _make_mock_handler(pool_size=4096, chunk_size=TEST_CHUNK_SIZE): + """Create a mock MaruHandler with 
mmap-backed regions.""" + handler = MagicMock() + handler._connected = True + + region_id = 100 + page_count = pool_size // chunk_size + + mmap_obj = mmap.mmap(-1, pool_size) + mapped_region = MappedRegion( + region_id=region_id, + handle=MagicMock(region_id=region_id, length=pool_size), + size=pool_size, + _mmap_obj=mmap_obj, + ) + + handler.get_buffer_view.side_effect = lambda rid, offset, size: ( + mapped_region.get_buffer_view(offset, size) if rid == region_id else None + ) + handler.get_region_page_count.side_effect = lambda rid: ( + page_count if rid == region_id else None + ) + handler.get_owned_region_ids.return_value = [region_id] + handler.get_chunk_size.return_value = chunk_size + + def mock_set_on_region_added(callback): + if callback is not None: + callback(region_id, page_count) + + handler.set_on_region_added.side_effect = mock_set_on_region_added + + page_counter = [0] + + def mock_alloc(size): + idx = page_counter[0] + page_counter[0] += 1 + buf = mapped_region.get_buffer_view(idx * chunk_size, size) + return AllocHandle(buf=buf, _region_id=region_id, _page_index=idx, _size=size) + + handler.alloc.side_effect = mock_alloc + handler.free = MagicMock() + handler.connect.return_value = True + handler.close.return_value = None + handler.store.return_value = True + handler.batch_store.return_value = None + handler.retrieve.return_value = None + handler.batch_retrieve.return_value = [] + handler.exists.return_value = False + handler.batch_exists.return_value = [] + handler.delete.return_value = True + handler.pin.return_value = True + handler.unpin.return_value = True + handler.batch_pin.return_value = [] + handler.batch_unpin.return_value = None + + return handler + + +def _make_cache_key(chunk_hash: int = 12345) -> CacheEngineKey: + """Create a CacheEngineKey for testing.""" + return CacheEngineKey( + model_name="test-model", + world_size=1, + worker_id=0, + chunk_hash=chunk_hash, + dtype=torch.float32, + ) + + +def _make_memory_obj(adapter: 
CxlMemoryAdapter) -> TensorMemoryObj: + """Allocate a TensorMemoryObj from the adapter.""" + obj = adapter.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + return obj + + +# ========================================================================= +# Fixtures +# ========================================================================= + + +@pytest.fixture(autouse=True) +def _init_pin_monitor(): + """Initialize PinMonitor singleton required by TensorMemoryObj.pin().""" + PinMonitor._instance = None + PinMonitor.GetOrCreate(LMCacheEngineConfig.from_defaults()) + yield + PinMonitor._instance = None + + +@pytest.fixture +def async_loop(): + """Provide an asyncio event loop running in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + yield loop + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5) + loop.close() + + +@pytest.fixture +def mock_handler(): + return _make_mock_handler() + + +@pytest.fixture +def adapter(mock_handler): + return CxlMemoryAdapter( + handler=mock_handler, + shapes=[TEST_SHAPE], + dtypes=[TEST_DTYPE], + fmt=MemoryFormat.KV_2LTD, + chunk_size=TEST_CHUNK_SIZE, + ) + + +@pytest.fixture +def backend(mock_handler, adapter, async_loop): + """Create a MaruBackend with mocked internals.""" + # Local + + with patch.object(MaruBackend, "initialize_allocator", return_value=adapter): + backend = MaruBackend.__new__(MaruBackend) + backend.dst_device = "cpu" + backend.config = MagicMock() + backend.config.maru_pool_size = 4.0 + backend.loop = async_loop + backend.memory_allocator = adapter + backend._handler = mock_handler + + backend._full_chunk_size_bytes = TEST_CHUNK_SIZE + backend._single_token_size = TEST_CHUNK_SIZE // 256 # 4 bytes per token + backend._mla_worker_id_as0_mode = False + + backend.put_lock = threading.Lock() + backend.put_tasks = set() + return backend + + +def _run_async(loop, coro): + """Submit a coroutine to a running event 
loop and wait for result.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=5) + + +# ========================================================================= +# Tests — Init & Interface Compliance +# ========================================================================= + + +class TestMaruBackendInit: + def test_str(self, backend): + assert str(backend) == "MaruBackend" + + def test_get_allocator_backend_returns_self(self, backend): + assert backend.get_allocator_backend() is backend + + def test_get_memory_allocator_returns_adapter(self, backend, adapter): + assert backend.get_memory_allocator() is adapter + + +class TestMaruBackendPoolSizeGbToBytes: + """Test _pool_size_gb_to_bytes static method.""" + + def test_4gb(self): + assert MaruBackend._pool_size_gb_to_bytes(4.0) == 4 * 1024**3 + + def test_half_gb(self): + assert MaruBackend._pool_size_gb_to_bytes(0.5) == 512 * 1024**2 + + def test_1gb(self): + assert MaruBackend._pool_size_gb_to_bytes(1.0) == 1024**3 + + def test_zero(self): + assert MaruBackend._pool_size_gb_to_bytes(0.0) == 0 + + +class TestMaruBackendInterfaceCompliance: + """Verify MaruBackend implements all required interface methods.""" + + def test_implements_all_abstract_methods(self): + abstract = get_abstract_methods(AllocatorBackendInterface) + implemented = get_methods_implemented_in_class( + MaruBackend, AllocatorBackendInterface + ) + missing = abstract - implemented + assert not missing, f"Missing abstract methods: {missing}" + + def test_method_signatures_match(self): + # Known: batched_submit_put_task uses 'memory_objs' instead of 'objs' + # TODO: Rename to 'objs' for full compliance + known_param_renames = {"batched_submit_put_task"} + + mismatches = check_method_signatures(AllocatorBackendInterface, MaruBackend) + unexpected = [m for m in mismatches if m["method"] not in known_param_renames] + assert not unexpected, f"Signature mismatches: {unexpected}" + + +# 
# =========================================================================
# Tests — Allocate
# =========================================================================


class TestMaruBackendAllocate:
    """Allocation through the backend hands back tensor-backed objects."""

    def test_allocate_returns_memory_obj(self, backend):
        """A single allocate() yields an object with a tensor and the test dtype."""
        allocated = backend.allocate(TEST_SHAPE, TEST_DTYPE)
        assert allocated is not None
        assert allocated.tensor is not None
        assert allocated.metadata.dtype == TEST_DTYPE

    def test_batched_allocate_returns_list(self, backend):
        """batched_allocate(batch_size=3) yields three tensor-backed objects."""
        batch = backend.batched_allocate(TEST_SHAPE, TEST_DTYPE, batch_size=3)
        assert batch is not None
        assert len(batch) == 3
        for allocated in batch:
            assert allocated.tensor is not None


# =========================================================================
# Tests — Put (async)
# =========================================================================


class TestMaruBackendPut:
    """Put path: submission, in-flight tracking, callbacks, failures, MLA skip."""

    @staticmethod
    def _detached_obj(adapter):
        """Allocate one object and clear its ``parent_allocator``.

        NOTE(review): presumably this keeps reference-count changes from
        calling back into the allocator — confirm against TensorMemoryObj.
        """
        mem = _make_memory_obj(adapter)
        mem.parent_allocator = None
        return mem

    @staticmethod
    def _detached_objs(adapter, count):
        """Allocate ``count`` objects, then clear ``parent_allocator`` on each."""
        batch = [_make_memory_obj(adapter) for _ in range(count)]
        for mem in batch:
            mem.parent_allocator = None
        return batch

    def test_submit_put_task_returns_future(self, backend, adapter):
        """submit_put_task returns a future and the handler stores exactly once."""
        mem = self._detached_obj(adapter)
        key = _make_cache_key()

        pending = backend.submit_put_task(key, mem)
        assert pending is not None
        pending.result(timeout=5)

        backend._handler.store.assert_called_once()

    def test_submit_put_task_tracks_in_flight(self, backend, adapter):
        """The key is absent from put_tasks both before submit and after completion."""
        mem = self._detached_obj(adapter)
        key = _make_cache_key()

        assert not backend.exists_in_put_tasks(key)

        pending = backend.submit_put_task(key, mem)
        pending.result(timeout=5)

        # Once the future has resolved the key must no longer be tracked.
        assert not backend.exists_in_put_tasks(key)

    def test_exists_in_put_tasks_true_during_store(self, backend, adapter):
        """exists_in_put_tasks is True while the handler's store call is blocked."""
        mem = self._detached_obj(adapter)
        key = _make_cache_key()

        entered = threading.Event()
        release = threading.Event()

        def blocking_store(*args, **kwargs):
            # Signal the test, then park until it lets the store finish.
            entered.set()
            release.wait(timeout=5)
            return True

        backend._handler.store.side_effect = blocking_store

        pending = backend.submit_put_task(key, mem)

        # The store call is now running, so the key must be tracked.
        assert entered.wait(timeout=5)
        assert backend.exists_in_put_tasks(key)

        # Unblock the store and check that tracking is cleared afterwards.
        release.set()
        pending.result(timeout=5)
        assert not backend.exists_in_put_tasks(key)

    def test_batched_submit_put_task(self, backend, adapter):
        """A batched submit returns futures and calls batch_store exactly once."""
        keys = [_make_cache_key(i) for i in range(3)]
        mems = self._detached_objs(adapter, 3)

        backend._handler.batch_store.return_value = [True, True, True]

        pending = backend.batched_submit_put_task(keys, mems)
        assert pending is not None

        for item in pending:
            item.result(timeout=5)

        backend._handler.batch_store.assert_called_once()

    def test_submit_put_calls_callback(self, backend, adapter):
        """on_complete_callback fires once, with the submitted key."""
        mem = self._detached_obj(adapter)
        key = _make_cache_key()
        seen = []

        def record(completed_key):
            seen.append(completed_key)

        pending = backend.submit_put_task(key, mem, on_complete_callback=record)
        pending.result(timeout=5)

        assert len(seen) == 1
        assert seen[0] == key

    def test_batched_submit_put_calls_callback_per_key(self, backend, adapter):
        """In a batched submit the callback is invoked for every key."""
        keys = [_make_cache_key(i) for i in range(3)]
        mems = self._detached_objs(adapter, 3)

        backend._handler.batch_store.return_value = [True, True, True]
        seen = []

        def record(completed_key):
            seen.append(completed_key)

        pending = backend.batched_submit_put_task(
            keys, mems, on_complete_callback=record
        )
        for item in pending:
            item.result(timeout=5)

        assert set(seen) == set(keys)

    def test_submit_put_task_skips_in_mla_mode(self, backend, adapter):
        """With _mla_worker_id_as0_mode set, the future yields None and no store runs."""
        backend._mla_worker_id_as0_mode = True
        mem = self._detached_obj(adapter)
        key = _make_cache_key()

        pending = backend.submit_put_task(key, mem)
        assert pending.result(timeout=5) is None
        backend._handler.store.assert_not_called()

    def test_submit_put_task_refcount_down_on_failure(self, backend, adapter):
        """A failing store leaves ref_count at its pre-submit value and untracks the key."""
        mem = self._detached_obj(adapter)
        key = _make_cache_key()
        ref_before = mem.get_ref_count()

        backend._handler.store.side_effect = RuntimeError("store failed")

        pending = backend.submit_put_task(key, mem)
        with pytest.raises(RuntimeError):
            pending.result(timeout=5)

        assert mem.get_ref_count() == ref_before
        assert not backend.exists_in_put_tasks(key)

    def test_batched_submit_put_task_refcount_down_on_failure(
        self, backend, adapter
    ):
        """A failing batch_store restores every ref_count and untracks every key."""
        keys = [_make_cache_key(i) for i in range(3)]
        mems = self._detached_objs(adapter, 3)
        refs_before = [mem.get_ref_count() for mem in mems]

        backend._handler.batch_store.side_effect = RuntimeError("batch failed")

        pending = backend.batched_submit_put_task(keys, mems)
        for item in pending:
            with pytest.raises(RuntimeError):
                item.result(timeout=5)

        for mem, ref_before in zip(mems, refs_before):
            assert mem.get_ref_count() == ref_before
        for key in keys:
            assert not backend.exists_in_put_tasks(key)

    def test_batched_submit_put_task_skips_in_mla_mode(self, backend, adapter):
        """With _mla_worker_id_as0_mode set, the batched submit returns None."""
        backend._mla_worker_id_as0_mode = True
        keys = [_make_cache_key(i) for i in range(3)]
        mems = self._detached_objs(adapter, 3)

        outcome = backend.batched_submit_put_task(keys, mems)
        assert outcome is None
        backend._handler.batch_store.assert_not_called()


#
# =========================================================================
# Tests — Get (sync)
# =========================================================================


def _blank_page_info(region_id, page_index):
    """Build a MemoryInfo over a fresh zero-filled TEST_CHUNK_SIZE-byte buffer."""
    return MemoryInfo(
        view=memoryview(bytearray(TEST_CHUNK_SIZE)),
        region_id=region_id,
        page_index=page_index,
    )


def _page_info_for(obj):
    """Build a MemoryInfo for the region/page decoded from ``obj``'s address."""
    region_id, page_index = CxlMemoryAdapter.decode_address(obj.metadata.address)
    return _blank_page_info(region_id, page_index)


class TestMaruBackendGet:
    """Synchronous retrieval via get_blocking / batched_get_blocking."""

    def test_get_blocking_hit(self, backend, adapter):
        """A MemoryInfo from the handler yields a result after one retrieve call."""
        key = _make_cache_key()
        backend._handler.retrieve.return_value = _blank_page_info(100, 0)

        fetched = backend.get_blocking(key)

        assert fetched is not None
        backend._handler.retrieve.assert_called_once()

    def test_get_blocking_miss(self, backend):
        """When the handler returns None, get_blocking returns None."""
        key = _make_cache_key()
        backend._handler.retrieve.return_value = None

        fetched = backend.get_blocking(key)

        assert fetched is None

    def test_get_blocking_ref_count_increases(self, backend, adapter):
        """The object returned by get_blocking has a ref_count of at least 2."""
        # Allocate first so that page 0 exists in the pool (original note).
        _make_memory_obj(adapter)

        key = _make_cache_key()
        backend._handler.retrieve.return_value = _blank_page_info(100, 0)

        fetched = backend.get_blocking(key)

        assert fetched is not None
        # Original note: pool objects start at ref_count 1 and get_blocking
        # calls ref_count_up, hence the lower bound of 2.
        assert fetched.get_ref_count() >= 2

    def test_batched_get_blocking(self, backend, adapter):
        """Two hits from batch_retrieve come back as two non-None results."""
        allocated = [_make_memory_obj(adapter) for _ in range(2)]
        keys = [_make_cache_key(i) for i in range(2)]

        backend._handler.batch_retrieve.return_value = [
            _page_info_for(obj) for obj in allocated
        ]

        fetched = backend.batched_get_blocking(keys)

        assert len(fetched) == 2
        for item in fetched:
            assert item is not None

    def test_batched_get_blocking_with_miss(self, backend, adapter):
        """A None entry from batch_retrieve stays None in the result list."""
        allocated = _make_memory_obj(adapter)
        keys = [_make_cache_key(i) for i in range(2)]

        hit = _page_info_for(allocated)
        backend._handler.batch_retrieve.return_value = [hit, None]

        fetched = backend.batched_get_blocking(keys)

        assert len(fetched) == 2
        assert fetched[0] is not None
        assert fetched[1] is None


# =========================================================================
# Tests — Contains
# =========================================================================


class TestMaruBackendContains:
    """contains / batched_contains, with and without pinning."""

    def test_contains_true(self, backend):
        """contains() forwards the stringified key to handler.exists."""
        key = _make_cache_key()
        backend._handler.exists.return_value = True

        assert backend.contains(key) is True
        backend._handler.exists.assert_called_once_with(key.to_string())

    def test_contains_false(self, backend):
        """contains() is False when handler.exists says so."""
        key = _make_cache_key()
        backend._handler.exists.return_value = False

        assert backend.contains(key) is False

    def test_batched_contains_all_hit(self, backend):
        """Three hits give a count of 3."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_exists.return_value = [True, True, True]

        hit_count = backend.batched_contains(keys)

        assert hit_count == 3

    def test_batched_contains_partial_prefix(self, backend):
        """hit, hit, miss gives a count of 2."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_exists.return_value = [True, True, False]

        hit_count = backend.batched_contains(keys)

        assert hit_count == 2

    def test_batched_contains_first_miss(self, backend):
        """A miss in first position gives 0 even though later keys hit."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_exists.return_value = [False, True, True]

        hit_count = backend.batched_contains(keys)

        assert hit_count == 0

    def test_contains_with_pin(self, backend):
        """With pin=True the handler's pin is used and exists is never called."""
        key = _make_cache_key()
        backend._handler.pin.return_value = True

        assert backend.contains(key, pin=True) is True
        backend._handler.pin.assert_called_once_with(key.to_string())
        backend._handler.exists.assert_not_called()

    def test_contains_with_pin_false(self, backend):
        """With pin=True a failed pin makes contains() False."""
        key = _make_cache_key()
        backend._handler.pin.return_value = False

        assert backend.contains(key, pin=True) is False

    def test_batched_contains_with_pin(self, backend):
        """With pin=True batch_pin gets the stringified keys; batch_exists is unused."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_pin.return_value = [True, True, True]

        hit_count = backend.batched_contains(keys, pin=True)

        assert hit_count == 3
        backend._handler.batch_pin.assert_called_once_with(
            [key.to_string() for key in keys]
        )
        backend._handler.batch_exists.assert_not_called()

    def test_batched_contains_with_pin_partial(self, backend):
        """hit, miss, hit from batch_pin gives a count of 1."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_pin.return_value = [True, False, True]

        hit_count = backend.batched_contains(keys, pin=True)

        assert hit_count == 1

    def test_batched_contains_empty(self, backend):
        """An empty key list gives a count of 0."""
        backend._handler.batch_exists.return_value = []
        assert backend.batched_contains([]) == 0


# =========================================================================
# Tests — Async Lookup
# =========================================================================


class TestMaruBackendAsyncLookup:
    """Coroutine-based lookups, driven on the background event loop."""

    def test_batched_async_contains_all_hit(self, backend, async_loop):
        """Three hits give a count of 3."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_exists.return_value = [True, True, True]

        lookup = backend.batched_async_contains("lookup-1", keys)
        hit_count = _run_async(async_loop, lookup)

        assert hit_count == 3

    def test_batched_async_contains_partial_prefix(self, backend, async_loop):
        """hit, miss, hit gives a count of 1."""
        keys = [_make_cache_key(i) for i in range(3)]
        backend._handler.batch_exists.return_value = [True, False, True]

        lookup = backend.batched_async_contains("lookup-2", keys)
        hit_count = _run_async(async_loop, lookup)

        assert hit_count == 1

    def test_batched_async_contains_empty(self, backend, async_loop):
        """An empty key list gives a count of 0."""
        backend._handler.batch_exists.return_value = []

        lookup = backend.batched_async_contains("lookup-3", [])
        hit_count = _run_async(async_loop, lookup)

        assert hit_count == 0

    def test_batched_get_non_blocking_all_hit(self, backend, adapter, async_loop):
        """Two hits from batch_retrieve come back as two non-None objects."""
        keys = [_make_cache_key(i) for i in range(2)]

        allocated = [_make_memory_obj(adapter) for _ in range(2)]
        backend._handler.batch_retrieve.return_value = [
            _page_info_for(obj) for obj in allocated
        ]

        lookup = backend.batched_get_non_blocking("lookup-4", keys)
        fetched = _run_async(async_loop, lookup)

        assert len(fetched) == 2
        for item in fetched:
            assert item is not None

    def test_batched_get_non_blocking_prefix_stop_on_miss(
        self, backend, adapter, async_loop
    ):
        """hit, miss, hit returns only the leading hit (prefix semantics)."""
        keys = [_make_cache_key(i) for i in range(3)]

        allocated = _make_memory_obj(adapter)
        hit = _page_info_for(allocated)
        # The same info object is used for both hits around the miss.
        backend._handler.batch_retrieve.return_value = [hit, None, hit]

        lookup = backend.batched_get_non_blocking("lookup-5", keys)
        fetched = _run_async(async_loop, lookup)

        assert len(fetched) == 1

    def test_batched_get_non_blocking_empty(self, backend, async_loop):
        """An empty key list returns an empty list."""
        backend._handler.batch_retrieve.return_value = []

        lookup = backend.batched_get_non_blocking("lookup-6", [])
        fetched = _run_async(async_loop, lookup)

        assert fetched == []


# =========================================================================
# Tests — Pin / Unpin / Remove
# =========================================================================


class TestMaruBackendPinRemove:
    """pin, unpin, batched_unpin and remove delegate to the handler."""

    def test_pin_delegates_to_handler(self, backend):
        """pin() forwards the stringified key and relays True."""
        key = _make_cache_key()
        backend._handler.pin.return_value = True

        assert backend.pin(key) is True
        backend._handler.pin.assert_called_once_with(key.to_string())

    def test_pin_returns_false_on_failure(self, backend):
        """pin() relays False from the handler."""
        key = _make_cache_key()
        backend._handler.pin.return_value = False

        assert backend.pin(key) is False

    def test_unpin_delegates_to_handler(self, backend):
        """unpin() forwards the stringified key and relays True."""
        key = _make_cache_key()
        backend._handler.unpin.return_value = True

        assert backend.unpin(key) is True
        backend._handler.unpin.assert_called_once_with(key.to_string())

    def test_unpin_returns_false_on_failure(self, backend):
        """unpin() relays False from the handler."""
        key = _make_cache_key()
        backend._handler.unpin.return_value = False

        assert backend.unpin(key) is False

    def test_batched_unpin(self, backend):
        """batched_unpin() makes one batch_unpin call with all stringified keys."""
        keys = [_make_cache_key(i) for i in range(3)]

        backend.batched_unpin(keys)

        backend._handler.batch_unpin.assert_called_once_with(
            [key.to_string() for key in keys]
        )

    def test_batched_unpin_empty(self, backend):
        """An empty key list never reaches the handler."""
        backend.batched_unpin([])
        backend._handler.batch_unpin.assert_not_called()

    def test_remove_existing_key(self, backend):
        """remove() forwards the stringified key to delete and relays True."""
        key = _make_cache_key()
        backend._handler.delete.return_value = True

        removed = backend.remove(key)

        assert removed is True
        backend._handler.delete.assert_called_once_with(key.to_string())

    def test_remove_nonexistent_key(self, backend):
        """remove() relays False when the handler's delete returns False."""
        key = _make_cache_key()
        backend._handler.delete.return_value = False

        removed = backend.remove(key)

        assert removed is False


# =========================================================================
# Tests — Lifecycle
# =========================================================================


class TestMaruBackendLifecycle:
    """Shutdown behaviour of the backend."""

    def test_close_calls_handler_and_allocator(self, backend):
        """close() closes both the memory allocator and the handler, once each."""
        backend.memory_allocator = MagicMock()

        backend.close()

        backend.memory_allocator.close.assert_called_once()
        backend._handler.close.assert_called_once()

    def test_close_drains_pending_put_tasks(self, backend, adapter):
        """close() succeeds after a submitted put task has finished.

        NOTE(review): the put future is awaited before close() is called, so
        no task is still in flight when close() runs.
        """
        mem = _make_memory_obj(adapter)
        mem.parent_allocator = None
        key = _make_cache_key()

        # Run a real put task through the event loop and wait for it.
        pending = backend.submit_put_task(key, mem)
        pending.result(timeout=5)

        # With the task finished, close() must still close the handler once.
        backend.close()
        backend._handler.close.assert_called_once()


# =========================================================================
# Tests — Store Handle Roundtrip
# =========================================================================


class TestMaruBackendStoreHandle:
    """Store handles derived from allocated objects."""

    def test_store_handle_roundtrip(self, backend, adapter):
        """create_store_handle reports region 100, page 0 and the object's phy_size."""
        mem = _make_memory_obj(adapter)
        mem.parent_allocator = None

        store_handle = adapter.create_store_handle(mem)

        assert store_handle.region_id == 100
        assert store_handle.page_index == 0
        assert store_handle._size == mem.metadata.phy_size