From e15f00b3aeaf1607589059da541ef7e10f05584a Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Sun, 15 Mar 2026 15:21:38 +0000 Subject: [PATCH 01/21] feat: maru storage backend bring-up --- lmcache/v1/config.py | 7 + lmcache/v1/storage_backend/__init__.py | 6 + lmcache/v1/storage_backend/maru_backend.py | 483 ++++++++++++++++++ lmcache/v1/storage_backend/storage_manager.py | 6 +- 4 files changed, 500 insertions(+), 2 deletions(-) create mode 100644 lmcache/v1/storage_backend/maru_backend.py diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index cd2ada16e6..813bcf2c07 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -234,6 +234,13 @@ "default": None, "env_converter": int, }, + # Maru CXL shared memory backend + "maru_path": {"type": Optional[str], "default": None, "env_converter": str}, + "maru_pool_size": { + "type": Optional[str], + "default": None, + "env_converter": str, + }, # Other configurations # (Deprecated) The url of the actual remote lmcache instance for auditing. # Please use extra_config['audit_actual_remote_url'] instead. diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index b7212b9603..3ab017b0ec 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -218,6 +218,12 @@ def CreateStorageBackends( ) storage_backends[str(gds_backend)] = gds_backend + if config.maru_path is not None and "MaruBackend" not in _skip: + from lmcache.v1.storage_backend.maru_backend import MaruBackend + + maru_backend = MaruBackend(config, metadata, loop, dst_device) + storage_backends[str(maru_backend)] = maru_backend + if config.remote_url is not None and "RemoteBackend" not in _skip: assert local_cpu_backend is not None, ( "Remote backend requires local CPU backend as a buffer." diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py new file mode 100644 index 0000000000..e885b2f7ec --- /dev/null +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from concurrent.futures import Future +from typing import Any, Callable, List, Optional, Sequence, Union +import asyncio +import re +import threading + +# Third Party +from maru import MaruConfig, MaruHandler +from maru_lmcache import CxlMemoryAllocator +import torch + +# First Party +from lmcache.integration.vllm.utils import get_size_bytes +from lmcache.logging import init_logger +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + MemoryAllocatorInterface, + MemoryFormat, + MemoryObj, +) +from lmcache.v1.metadata import LMCacheMetadata +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface + +logger = init_logger(__name__) + + +class MaruBackend(AllocatorBackendInterface): + """Maru CXL shared memory storage backend. + + Implements AllocatorBackendInterface with its own CxlMemoryAllocator. + No LocalCPUBackend needed — data lives directly in CXL mmap memory. + + Put is async (Future): metadata registration via RPC. + Get is sync: CXL memory direct read (no network I/O). + + Args: + config: LMCache engine configuration. Must have maru_path set. + metadata: LMCache engine metadata. + loop: asyncio event loop for async put tasks. + dst_device: Target device string (unused for CXL, kept for interface). + """ + + def __init__( + self, + config: LMCacheEngineConfig, + metadata: LMCacheMetadata, + loop: asyncio.AbstractEventLoop, + dst_device: str = "cuda", + ): + super().__init__(dst_device=dst_device) + + # 1. Config + self.config = config + self.loop = loop + + self._full_chunk_size_bytes: int = get_size_bytes( + metadata.get_shapes(), metadata.get_dtypes() + ) + assert self._full_chunk_size_bytes % metadata.chunk_size == 0 + self._single_token_size: int = ( + self._full_chunk_size_bytes // metadata.chunk_size + ) + + self._mla_worker_id_as0_mode: bool = ( + config.get_extra_config_value( + "remote_enable_mla_worker_id_as0", metadata.use_mla + ) + and metadata.use_mla + and metadata.world_size > 1 + and metadata.worker_id != 0 + ) + + # 2. Handler + self._handler = self._create_handler(config) + + # 3. Allocator + self.memory_allocator = self.initialize_allocator(config, metadata) + + # 4. State + self.put_lock = threading.Lock() + self.put_tasks: set[CacheEngineKey] = set() + + def __str__(self) -> str: + return self.__class__.__name__ + + @staticmethod + def _parse_pool_size(raw: Optional[str]) -> int: + """Parse human-readable pool size (e.g. '4G', '512M') to bytes.""" + _DEFAULT = 4 * 1024**3 + if raw is None: + return _DEFAULT + if isinstance(raw, (int, float)): + return int(raw) + s = str(raw).strip().upper() + match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", s) + if not match: + try: + return int(s) + except ValueError: + logger.warning("Cannot parse maru_pool_size=%r, using default", raw) + return _DEFAULT + value, unit = float(match.group(1)), match.group(2) + multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} + return int(value * multipliers.get(unit, 1)) + + # ========================================================================= + # Initialization helpers + # ========================================================================= + + def _create_handler( + self, config: LMCacheEngineConfig, + ) -> "MaruHandler": + """Create and connect a MaruHandler. + + Args: + config: LMCache engine configuration. + + Returns: + Connected MaruHandler instance. + + Raises: + RuntimeError: If MaruHandler connection fails. + """ + assert config.maru_path is not None, ( + "maru_path must be set for MaruBackend" + ) + + extra = config.extra_config or {} + maru_config = MaruConfig( + server_url=config.maru_path, + instance_id=extra.get("maru_instance_id"), + pool_size=self._parse_pool_size(config.maru_pool_size), + chunk_size_bytes=self._full_chunk_size_bytes, + auto_connect=False, + timeout_ms=extra.get("maru_timeout_ms", 5000), + use_async_rpc=extra.get("maru_use_async_rpc", True), + max_inflight=extra.get("maru_max_inflight", 64), + eager_map=extra.get("maru_eager_map", True), + ) + + handler = MaruHandler(maru_config) + if not handler.connect(): + raise RuntimeError( + f"Failed to connect MaruHandler to {config.maru_path}" + ) + logger.debug("[Maru] Connected to %s", config.maru_path) + return handler + + # ========================================================================= + # AllocatorBackendInterface + # ========================================================================= + + def initialize_allocator( + self, config: LMCacheEngineConfig, metadata: LMCacheMetadata + ) -> MemoryAllocatorInterface: + """Create CxlMemoryAllocator backed by the connected handler. + + Args: + config: LMCache engine configuration. + metadata: LMCache engine metadata. + + Returns: + CxlMemoryAllocator instance. + """ + shapes = metadata.get_shapes() + dtypes = metadata.get_dtypes() + fmt = ( + MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD + ) + chunk_size = self._handler.owned_region_manager.get_chunk_size() + + return CxlMemoryAllocator( + handler=self._handler, + shapes=shapes, + dtypes=dtypes, + fmt=fmt, + chunk_size=chunk_size, + ) + + def get_memory_allocator(self) -> MemoryAllocatorInterface: + """Returns the underlying CxlMemoryAllocator.""" + return self.memory_allocator + + def get_allocator_backend(self) -> "MaruBackend": + """Returns self as the allocator backend.""" + return self + + def allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[MemoryObj]: + """Allocate CXL-backed memory via CxlMemoryAllocator. + + Args: + shapes: Tensor shape(s). + dtypes: Tensor dtype(s). + fmt: Memory format. + eviction: Unused (no eviction policy yet). + busy_loop: Unused. + + Returns: + MemoryObj backed by CXL memory, or None on failure. + """ + obj = self.memory_allocator.allocate(shapes, dtypes, fmt) + if obj is not None: + logger.debug("[Maru] allocate rid=%d pid=%d", + *CxlMemoryAllocator.decode_address(obj.metadata.address)) + else: + logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes) + return obj + + def batched_allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[list[MemoryObj]]: + """Allocate multiple CXL-backed MemoryObjs. + + Args: + shapes: Tensor shape(s) (same for each allocation). + dtypes: Tensor dtype(s) (same for each allocation). + batch_size: Number of allocations. + fmt: Memory format. + eviction: Unused. + busy_loop: Unused. + + Returns: + List of MemoryObj, or None if any allocation fails. + """ + return self.memory_allocator.batched_allocate( + shapes, dtypes, batch_size, fmt + ) + + # ========================================================================= + # Put (async) + # ========================================================================= + + def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: + """Check whether key is in ongoing put tasks. + + Args: + key: The cache key. + + Returns: + True if the key has a pending put task. + """ + with self.put_lock: + return key in self.put_tasks + + def submit_put_task( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Future: + """Submit a put task to register KV metadata with MaruServer. + + Data is already in CXL memory (zero-copy). This only registers + the key -> location metadata via RPC. + + Args: + key: The cache key. + memory_obj: MemoryObj with data already written to CXL. + on_complete_callback: Optional callback after registration. + + Returns: + Future that completes when metadata is registered. + """ + assert memory_obj.tensor is not None + + with self.put_lock: + self.put_tasks.add(key) + + future = asyncio.run_coroutine_threadsafe( + self._async_store(key, memory_obj, on_complete_callback), + self.loop, + ) + return future + + def batched_submit_put_task( + self, + keys: Sequence[CacheEngineKey], + memory_objs: List[MemoryObj], + transfer_spec: Any = None, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Union[List[Future], None]: + """Submit batched put tasks. + + Args: + keys: The cache keys. + memory_objs: MemoryObjs with data already in CXL. + transfer_spec: Unused. + on_complete_callback: Optional per-key callback. + + Returns: + List of Futures, one per key. + """ + futures = [] + for key, memory_obj in zip(keys, memory_objs, strict=False): + future = self.submit_put_task( + key, memory_obj, on_complete_callback=on_complete_callback + ) + futures.append(future) + return futures + + async def _async_store( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register KV metadata with MaruServer (runs in event loop). + + Uses CxlMemoryAllocator.create_store_handle() to extract + (region_id, page_index) from the MemoryObj's encoded address. + + Args: + key: The cache key. + memory_obj: MemoryObj backed by CXL memory. + on_complete_callback: Optional callback after registration. + """ + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAllocator) + handle = allocator.create_store_handle(memory_obj) + key_str = key.to_string() + + await asyncio.to_thread( + self._handler.store, key_str, handle + ) + + logger.debug("[Maru] store key=%s rid=%d pid=%d", + key, handle.region_id, handle.page_index) + + except Exception as e: + logger.error("[Maru] store failed key=%s: %s", key, e) + finally: + with self.put_lock: + self.put_tasks.discard(key) + + if on_complete_callback is not None: + try: + on_complete_callback(key) + except Exception as e: + logger.warning( + "on_complete_callback failed for key %s: %s", key, e + ) + + # ========================================================================= + # Get (sync) + # ========================================================================= + + def get_blocking( + self, + key: CacheEngineKey, + ) -> Optional[MemoryObj]: + """Blocking get: read KV cache directly from CXL memory. + + Queries MaruServer for metadata, then returns a MemoryObj + via CxlMemoryAllocator.get_by_location(). + + Args: + key: The cache key. + + Returns: + MemoryObj backed by CXL memory, or None if not found. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + key_str = key.to_string() + mem_info = self._handler.retrieve(key_str) + if mem_info is None: + logger.debug("[Maru] get_blocking miss key=%s", key) + return None + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAllocator) + + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + logger.debug("[Maru] get_blocking pool miss rid=%d pid=%d", + mem_info.region_id, mem_info.page_index) + return None + + memory_obj.ref_count_up() + + logger.debug("[Maru] get_blocking rid=%d pid=%d size=%d", + mem_info.region_id, mem_info.page_index, + len(mem_info.view)) + return memory_obj + + # ========================================================================= + # Contains / Pin / Unpin / Remove + # ========================================================================= + + def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: + """Check if key exists on MaruServer. + + Args: + key: The cache key. + pin: If True, pin the entry. (TODO: delegate to handler) + + Returns: + True if key exists. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + return self._handler.exists(key.to_string()) + + def pin(self, key: CacheEngineKey) -> bool: + """Pin a key to prevent eviction. + + TODO: Delegate to MaruHandler.pin() once server-side + ref_count management is implemented. + + Args: + key: The cache key. + + Returns: + True if pinned successfully. + """ + return False + + def unpin(self, key: CacheEngineKey) -> bool: + """Unpin a key to allow eviction. + + TODO: Delegate to MaruHandler.unpin() once server-side + ref_count management is implemented. + + Args: + key: The cache key. + + Returns: + True if unpinned successfully. + """ + return False + + def remove(self, key: CacheEngineKey, force: bool = True) -> bool: + """Remove a key from MaruServer. + + Args: + key: The cache key. + force: Whether to force removal. + + Returns: + True if removed successfully. + """ + with self.data_lock: + self.data.pop(key, None) + + key_str = key.to_string() + result = self._handler.delete(key_str) + logger.debug("[Maru] remove key=%s success=%s", key, result) + return result + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """Close the backend and underlying MaruHandler.""" + self.memory_allocator.close() + self._handler.close() + logger.info("MaruBackend closed.") diff --git a/lmcache/v1/storage_backend/storage_manager.py b/lmcache/v1/storage_backend/storage_manager.py index 884e8845ca..138008228b 100644 --- a/lmcache/v1/storage_backend/storage_manager.py +++ b/lmcache/v1/storage_backend/storage_manager.py @@ -313,6 +313,8 @@ def _get_allocator_backend( ) -> AllocatorBackendInterface: if self.enable_pd: allocator_backend = self.storage_backends["PDBackend"] + elif "MaruBackend" in self.storage_backends: + allocator_backend = self.storage_backends["MaruBackend"] else: allocator_backend = self.storage_backends["LocalCPUBackend"] assert isinstance(allocator_backend, AllocatorBackendInterface) @@ -442,7 +444,7 @@ def get( memory_obj = backend.get_blocking(key) if memory_obj: if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and "LocalCPUBackend" in self.storage_backends ): local_cpu_backend = self.storage_backends["LocalCPUBackend"] @@ -486,7 +488,7 @@ def batched_get( # Align with single-key `get()` logic: # auto-write remote data to local CPU cache if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and "LocalCPUBackend" in self.storage_backends and None not in memory_objs ): From e29b2483d7b2b8d7a251576646864d5030aab640 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Sun, 15 Mar 2026 15:29:48 +0000 Subject: [PATCH 02/21] chore: fix lint error --- lmcache/v1/storage_backend/__init__.py | 1 + lmcache/v1/storage_backend/maru_backend.py | 60 +++++++++++----------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index 3ab017b0ec..bcdd6f5bff 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -219,6 +219,7 @@ def CreateStorageBackends( storage_backends[str(gds_backend)] = gds_backend if config.maru_path is not None and "MaruBackend" not in _skip: + # First Party from lmcache.v1.storage_backend.maru_backend import MaruBackend maru_backend = MaruBackend(config, metadata, loop, dst_device) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index e885b2f7ec..95942aaa64 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -112,7 +112,8 @@ def _parse_pool_size(raw: Optional[str]) -> int: # ========================================================================= def _create_handler( - self, config: LMCacheEngineConfig, + self, + config: LMCacheEngineConfig, ) -> "MaruHandler": """Create and connect a MaruHandler. @@ -125,9 +126,7 @@ def _create_handler( Raises: RuntimeError: If MaruHandler connection fails. """ - assert config.maru_path is not None, ( - "maru_path must be set for MaruBackend" - ) + assert config.maru_path is not None, "maru_path must be set for MaruBackend" extra = config.extra_config or {} maru_config = MaruConfig( @@ -144,9 +143,7 @@ def _create_handler( handler = MaruHandler(maru_config) if not handler.connect(): - raise RuntimeError( - f"Failed to connect MaruHandler to {config.maru_path}" - ) + raise RuntimeError(f"Failed to connect MaruHandler to {config.maru_path}") logger.debug("[Maru] Connected to %s", config.maru_path) return handler @@ -168,9 +165,7 @@ def initialize_allocator( """ shapes = metadata.get_shapes() dtypes = metadata.get_dtypes() - fmt = ( - MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD - ) + fmt = MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD chunk_size = self._handler.owned_region_manager.get_chunk_size() return CxlMemoryAllocator( @@ -211,8 +206,10 @@ def allocate( """ obj = self.memory_allocator.allocate(shapes, dtypes, fmt) if obj is not None: - logger.debug("[Maru] allocate rid=%d pid=%d", - *CxlMemoryAllocator.decode_address(obj.metadata.address)) + logger.debug( + "[Maru] allocate rid=%d pid=%d", + *CxlMemoryAllocator.decode_address(obj.metadata.address), + ) else: logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes) return obj @@ -239,9 +236,7 @@ def batched_allocate( Returns: List of MemoryObj, or None if any allocation fails. """ - return self.memory_allocator.batched_allocate( - shapes, dtypes, batch_size, fmt - ) + return self.memory_allocator.batched_allocate(shapes, dtypes, batch_size, fmt) # ========================================================================= # Put (async) @@ -337,12 +332,14 @@ async def _async_store( handle = allocator.create_store_handle(memory_obj) key_str = key.to_string() - await asyncio.to_thread( - self._handler.store, key_str, handle - ) + await asyncio.to_thread(self._handler.store, key_str, handle) - logger.debug("[Maru] store key=%s rid=%d pid=%d", - key, handle.region_id, handle.page_index) + logger.debug( + "[Maru] store key=%s rid=%d pid=%d", + key, + handle.region_id, + handle.page_index, + ) except Exception as e: logger.error("[Maru] store failed key=%s: %s", key, e) @@ -354,9 +351,7 @@ async def _async_store( try: on_complete_callback(key) except Exception as e: - logger.warning( - "on_complete_callback failed for key %s: %s", key, e - ) + logger.warning("on_complete_callback failed for key %s: %s", key, e) # ========================================================================= # Get (sync) @@ -396,15 +391,21 @@ def get_blocking( single_token_size=self._single_token_size, ) if memory_obj is None: - logger.debug("[Maru] get_blocking pool miss rid=%d pid=%d", - mem_info.region_id, mem_info.page_index) + logger.debug( + "[Maru] get_blocking pool miss rid=%d pid=%d", + mem_info.region_id, + mem_info.page_index, + ) return None memory_obj.ref_count_up() - logger.debug("[Maru] get_blocking rid=%d pid=%d size=%d", - mem_info.region_id, mem_info.page_index, - len(mem_info.view)) + logger.debug( + "[Maru] get_blocking rid=%d pid=%d size=%d", + mem_info.region_id, + mem_info.page_index, + len(mem_info.view), + ) return memory_obj # ========================================================================= @@ -464,9 +465,6 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: Returns: True if removed successfully. """ - with self.data_lock: - self.data.pop(key, None) - key_str = key.to_string() result = self._handler.delete(key_str) logger.debug("[Maru] remove key=%s success=%s", key, result) From c6d5dbf3a7679b1b49ffe398ccebca5b183742be Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Mon, 16 Mar 2026 04:21:42 +0000 Subject: [PATCH 03/21] refactor: update MaruBackend to use CxlMemoryAdapter facade API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename CxlMemoryAllocator → CxlMemoryAdapter import - Use handler.get_chunk_size() instead of handler.owned_region_manager.get_chunk_size() - Remove deprecated maru_connector.py and maru_adapter.py (storage backend replaces connector) --- .../storage_backend/connector/maru_adapter.py | 45 -- .../connector/maru_connector.py | 602 ------------------ lmcache/v1/storage_backend/maru_backend.py | 26 +- 3 files changed, 13 insertions(+), 660 deletions(-) delete mode 100644 lmcache/v1/storage_backend/connector/maru_adapter.py delete mode 100644 lmcache/v1/storage_backend/connector/maru_connector.py diff --git a/lmcache/v1/storage_backend/connector/maru_adapter.py b/lmcache/v1/storage_backend/connector/maru_adapter.py deleted file mode 100644 index 2d067a5284..0000000000 --- a/lmcache/v1/storage_backend/connector/maru_adapter.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# First Party -from lmcache.logging import init_logger -from lmcache.v1.storage_backend.connector import ( - ConnectorAdapter, - ConnectorContext, - parse_remote_url, -) -from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector -from lmcache.v1.storage_backend.connector.maru_connector import ( - MaruConnector, - MaruConnectorConfig, -) - -logger = init_logger(__name__) - - -class MaruConnectorAdapter(ConnectorAdapter): - """Adapter for maru scheme.""" - - def __init__(self) -> None: - super().__init__("maru://") - - def create_connector(self, context: ConnectorContext) -> RemoteConnector: - # Validate URL format (requires host:port) - _ = parse_remote_url(context.url) - - # Parse configuration from URL - maru_config = MaruConnectorConfig.from_url(context.url) - logger.info( - "Maru config from URL: server_url=%s, pool_size=%d", - maru_config.server_url, - maru_config.pool_size, - ) - - if context.config is None or context.metadata is None: - raise ValueError("Maru connector requires config and metadata") - - return MaruConnector( - url=context.url, - loop=context.loop, - local_cpu_backend=context.local_cpu_backend, - config=context.config, - metadata=context.metadata, - ) diff --git a/lmcache/v1/storage_backend/connector/maru_connector.py b/lmcache/v1/storage_backend/connector/maru_connector.py deleted file mode 100644 index 91e798eab6..0000000000 --- a/lmcache/v1/storage_backend/connector/maru_connector.py +++ /dev/null @@ -1,602 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -from dataclasses import dataclass -from typing import List, Optional, no_type_check -from urllib.parse import urlparse -import asyncio -import re - -# Third Party -import torch - -# First Party -from lmcache.logging import init_logger -from lmcache.observability import LMCStatsMonitor -from lmcache.utils import CacheEngineKey -from lmcache.v1.config import LMCacheEngineConfig -from lmcache.v1.memory_management import MemoryObj, MemoryObjMetadata, TensorMemoryObj -from lmcache.v1.metadata import LMCacheMetadata -from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector -from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend - -logger = init_logger(__name__) - - -def parse_size(size_str: str) -> int: - """Parse human-readable size string (e.g., '1G', '100M', '1024K') to bytes.""" - if isinstance(size_str, int): - return size_str - s = str(size_str).strip().upper() - match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", s) - if not match: - try: - return int(s) - except ValueError: - raise ValueError( - f"Could not parse '{size_str}' as a size string or an integer." - ) from None - value, unit = float(match.group(1)), match.group(2) - multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} - return int(value * multipliers.get(unit, 1)) - - -@dataclass -class MaruConnectorConfig: - """Configuration for Maru connector.""" - - server_url: str = "tcp://localhost:5555" - pool_size: int = 1024 * 1024 * 1024 # 1GB default - instance_id: Optional[str] = None - auto_connect: bool = True - connection_timeout: float = 30.0 - operation_timeout: float = 10.0 - timeout_ms: int = 2000 # ZMQ socket timeout in milliseconds - use_async_rpc: bool = True # Use async DEALER-ROUTER RPC - max_inflight: int = 64 # Max concurrent in-flight async requests - eager_map: Optional[bool] = None # None = defer to MaruConfig/env - - @staticmethod - def from_url(url: str) -> "MaruConnectorConfig": - """Parse maru://host:port to extract server address only.""" - parsed = urlparse(url) - host = parsed.hostname or "localhost" - port = parsed.port or 5555 - server_url = f"tcp://{host}:{port}" - - return MaruConnectorConfig( - server_url=server_url, - ) - - @staticmethod - def from_lmcache_config(config: "LMCacheEngineConfig") -> "MaruConnectorConfig": - """Load from extra_config dict. - - All Maru-specific settings should be configured here. - Supports human-readable size strings (e.g., '4G', '500M') - for maru_pool_size. - """ - extra = config.extra_config or {} - raw_pool_size = extra.get("maru_pool_size", 1024**3) - pool_size = ( - parse_size(raw_pool_size) - if isinstance(raw_pool_size, str) - else int(raw_pool_size) - ) - return MaruConnectorConfig( - server_url=extra.get("maru_server_url", "tcp://localhost:5555"), - pool_size=pool_size, - instance_id=extra.get("maru_instance_id"), - auto_connect=extra.get("maru_auto_connect", True), - operation_timeout=float(extra.get("maru_operation_timeout", 10.0)), - timeout_ms=int(extra.get("maru_timeout_ms", 2000)), - use_async_rpc=extra.get("maru_use_async_rpc", True), - max_inflight=int(extra.get("maru_max_inflight", 64)), - eager_map=extra.get("maru_eager_map"), - ) - - -# Ping error codes -PING_SUCCESS = 0 -PING_NOT_CONNECTED = 1 -PING_RPC_ERROR = 2 - - -class MaruConnector(RemoteConnector): - """ - The remote url should start with "maru://" and have one host-port pair. - """ - - def __init__( - self, - url: str, - loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, - config: LMCacheEngineConfig, - metadata: LMCacheMetadata, - ): - logger.info("init MaruConnector") - super().__init__(config, metadata) - if config.use_layerwise: - raise NotImplementedError( - "Maru connector does not yet support layerwise KV cache." - ) - - self.url = url - self.loop = loop - self.local_cpu_backend = local_cpu_backend - - # extra_config for all settings, URL for server address only - url_config = MaruConnectorConfig.from_url(url) - if config.extra_config: - self.maru_config = MaruConnectorConfig.from_lmcache_config(config) - # Use URL-derived server_url unless explicitly overridden - if not config.extra_config.get("maru_server_url"): - self.maru_config.server_url = url_config.server_url - else: - self.maru_config = url_config - - logger.info( - "Maru config: server_url=%s, pool_size=%d, instance_id=%s, eager_map=%s", - self.maru_config.server_url, - self.maru_config.pool_size, - self.maru_config.instance_id, - self.maru_config.eager_map, - ) - - # Initialize MaruHandler (lazy connection) - self._handle = None - self._connected = False - - # Metrics - self._stats_monitor = LMCStatsMonitor.GetOrCreate() - self._connection_attempts = 0 - self._connection_failures = 0 - self._rpc_errors = 0 - - # Try to connect if auto_connect is enabled - if self.maru_config.auto_connect: - self._init_handle() - - def _init_handle(self) -> bool: - try: - # Third Party - from maru import MaruConfig, MaruHandler - except ImportError: - logger.error("maru package not installed. Install with: pip install maru") - return False - - try: - maru_cfg_kwargs = dict( - server_url=self.maru_config.server_url, - instance_id=self.maru_config.instance_id, - pool_size=self.maru_config.pool_size, - chunk_size_bytes=self.full_chunk_size_bytes, - auto_connect=False, # We'll connect manually - timeout_ms=self.maru_config.timeout_ms, - use_async_rpc=self.maru_config.use_async_rpc, - max_inflight=self.maru_config.max_inflight, - ) - if self.maru_config.eager_map is not None: - maru_cfg_kwargs["eager_map"] = self.maru_config.eager_map - maru_cfg = MaruConfig(**maru_cfg_kwargs) - handle = MaruHandler(maru_cfg) - self._handle = handle - if handle.connect(): - self._connected = True - self._connection_attempts += 1 - logger.info("init maru handler success") - return True - else: - logger.error("fail to init maru handler, connect returned False") - self._connection_attempts += 1 - self._connection_failures += 1 - self._handle = None - return False - except Exception as e: - logger.error("fail to init maru handler: %s", e) - self._connection_attempts += 1 - self._connection_failures += 1 - self._handle = None - return False - - def _ensure_connected(self) -> bool: - if self._connected and self._handle is not None: - return True - return self._init_handle() - - async def exists(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - - key_str = key.to_string() - try: - result = await asyncio.wait_for( - asyncio.to_thread(self._handle.exists, key_str), - timeout=self.maru_config.operation_timeout, - ) - logger.debug( - "maru exists key_str=%s, exists=%s", - key_str, - result, - ) - return result - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning("maru exists timed out for key_str=%s", key_str) - return False - except Exception as e: - self._rpc_errors += 1 - logger.error("maru exists failed: %s", e) - return False - - def exists_sync(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - - key_str = key.to_string() - try: - result = self._handle.exists(key_str) - logger.debug("maru exists_sync key_str=%s, exists=%s", key_str, result) - return result - except Exception as e: - self._rpc_errors += 1 - logger.error("maru exists_sync failed: %s", e) - return False - - def _decode_memory_obj(self, info) -> Optional[MemoryObj]: - mv = info.view - - logger.debug("maru decode data=%d bytes", len(mv)) - - # memoryview -> torch tensor (zero-copy) - raw_data = torch.frombuffer(mv, dtype=torch.uint8) - - meta = MemoryObjMetadata( - shape=self.meta_shapes[0], - dtype=self.meta_dtypes[0], - address=0, - phy_size=raw_data.numel(), - ref_count=1, - pin_count=0, - fmt=self.meta_fmt, - shapes=self.meta_shapes, - dtypes=self.meta_dtypes, - ) - - return TensorMemoryObj( - raw_data=raw_data, - metadata=meta, - parent_allocator=None, - ) - - def _encode_memory_obj(self, memory_obj: MemoryObj): - # Third Party - from maru_handler.memory import MemoryInfo - - info = MemoryInfo(view=memory_obj.byte_array) - logger.debug("maru encode data=%d bytes", len(info.view)) - return info - - async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: - if not self._ensure_connected(): - return None - assert self._handle is not None - - key_str = key.to_string() - try: - info = await asyncio.wait_for( - asyncio.to_thread(self._handle.retrieve, key_str), - timeout=self.maru_config.operation_timeout, - ) - if info is None: - logger.debug("maru get MISS key_str=%s", key_str) - return None - - data_size = len(info.view) - logger.debug("maru get HIT key_str=%s, %d bytes", key_str, data_size) - memory_obj = self._decode_memory_obj(info) - if memory_obj is not None: - memory_obj = self.reshape_partial_chunk(memory_obj, data_size) - return memory_obj - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning("maru get timed out for key_str=%s", key_str) - return None - except Exception as e: - self._rpc_errors += 1 - logger.error("maru get failed: %s", e) - return None - - async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): - if not self._ensure_connected(): - raise RuntimeError("MaruConnector not connected to Maru server") - assert self._handle is not None - - key_str = key.to_string() - info = self._encode_memory_obj(memory_obj) - data_size = len(info.view) - - try: - success = await asyncio.wait_for( - asyncio.to_thread(self._handle.store, key_str, info), - timeout=self.maru_config.operation_timeout, - ) - if success: - logger.debug("maru put key_str=%s, %d bytes", key_str, data_size) - else: - logger.warning("maru put failed key_str=%s", key_str) - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning( - "maru put timed out for key_str=%s. Decode instance may redo prefill.", - key_str, - ) - except Exception as e: - self._rpc_errors += 1 - logger.error("maru put failed: %s", e) - raise - - # TODO - @no_type_check - async def list(self) -> List[str]: - pass - - def remove_sync(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - - key_str = key.to_string() - try: - return self._handle.delete(key_str) - except Exception as e: - self._rpc_errors += 1 - logger.error("maru remove_sync failed: %s", e) - return False - - async def close(self): - if self._handle is not None: - try: - self._handle.close() - except Exception as e: - logger.error("fail to close maru handler: %s", e) - finally: - self._handle = None - self._connected = False - logger.info("closed the maru connection") - - def support_batched_get(self) -> bool: - return True - - def support_batched_put(self) -> bool: - return True - - def support_batched_async_contains(self) -> bool: - return True - - def support_batched_contains(self) -> bool: - return True - - def batched_contains(self, keys: List[CacheEngineKey]) -> int: - if not self._ensure_connected() or not keys: - return 0 - assert self._handle is not None - - key_strs = [k.to_string() for k in keys] - try: - results = self._handle.batch_exists(key_strs) - count = 0 - for exists in results: - if not exists: - break - count += 1 - logger.debug("maru batched_contains hits=%d/%d", count, len(keys)) - return count - except Exception as e: - self._rpc_errors += 1 - logger.error("maru batched_contains failed: %s", e) - return 0 - - async def batched_get( - self, keys: List[CacheEngineKey] - ) -> List[Optional[MemoryObj]]: - if not self._ensure_connected() or not keys: - return [None] * len(keys) - assert self._handle is not None - - key_strs = [k.to_string() for k in keys] - try: - raw_results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_retrieve, key_strs), - timeout=self.maru_config.operation_timeout, - ) - hits = sum(1 for r in raw_results if r is not None) - logger.debug("maru batched_get hits=%d/%d", hits, len(keys)) - memory_objs: List[Optional[MemoryObj]] = [] - for info in raw_results: - if info is None: - memory_objs.append(None) - continue - memory_obj = self._decode_memory_obj(info) - if memory_obj is not None: - memory_obj = self.reshape_partial_chunk(memory_obj, len(info.view)) - memory_objs.append(memory_obj) - return memory_objs - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning("maru batched_get timed out for %d keys", len(keys)) - return [None] * len(keys) - except Exception as e: - self._rpc_errors += 1 - logger.error("maru batched_get failed: %s", e) - return [None] * len(keys) - - async def batched_put( - self, - keys: List[CacheEngineKey], - memory_objs: List[MemoryObj], - ): - if not self._ensure_connected() or not keys: - return - assert self._handle is not None - - key_strs = [k.to_string() for k in keys] - infos = [self._encode_memory_obj(obj) for obj in memory_objs] - total_bytes = sum(len(info.view) for info in infos) - - try: - results = await asyncio.wait_for( - asyncio.to_thread( - self._handle.batch_store, - key_strs, - infos, - ), - timeout=self.maru_config.operation_timeout, - ) - stored = sum(results) if results else 0 - if stored < len(keys): - logger.warning( - "maru batched_put partial %d/%d keys", - stored, - len(keys), - ) - else: - logger.debug( - "maru batched_put %d keys, %d bytes", - len(keys), - total_bytes, - ) - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning( - "maru batched_put timed out for %d keys. " - "Decode instance may redo prefill.", - len(keys), - ) - except Exception as e: - self._rpc_errors += 1 - logger.error("maru batched_put failed: %s", e) - raise - - async def batched_get_non_blocking( - self, - lookup_id: str, - keys: List[CacheEngineKey], - ) -> List[MemoryObj]: - if not self._ensure_connected() or not keys: - return [] - assert self._handle is not None - - key_strs = [k.to_string() for k in keys] - try: - raw_results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_retrieve, key_strs), - timeout=self.maru_config.operation_timeout, - ) - - # Build consecutive prefix of hits - memory_objs = [] - for info in raw_results: - if info is None: - break - memory_obj = self._decode_memory_obj(info) - if memory_obj is None: - break - memory_obj = self.reshape_partial_chunk(memory_obj, len(info.view)) - memory_objs.append(memory_obj) - - logger.debug( - "maru batched_get_nb lookup_id=%s, hits=%d/%d", - lookup_id, - len(memory_objs), - len(keys), - ) - return memory_objs - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning( - "maru batched_get_non_blocking timed out for lookup_id=%s, %d keys", - lookup_id, - len(keys), - ) - return [] - except Exception as e: - self._rpc_errors += 1 - logger.error("maru batched_get_non_blocking failed: %s", e) - return [] - - def support_batched_get_non_blocking(self) -> bool: - return True - - def support_ping(self) -> bool: - return True - - async def ping(self) -> int: - if not self._connected or self._handle is None: - self._stats_monitor.update_remote_ping_error_code(PING_NOT_CONNECTED) - return PING_NOT_CONNECTED - try: - healthy = await asyncio.wait_for( - asyncio.to_thread(self._handle.healthcheck), - timeout=self.maru_config.operation_timeout, - ) - if not healthy: - self._stats_monitor.update_remote_ping_error_code(PING_RPC_ERROR) - return PING_RPC_ERROR - self._stats_monitor.update_remote_ping_error_code(PING_SUCCESS) - return PING_SUCCESS - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning("maru ping timed out") - self._stats_monitor.update_remote_ping_error_code(PING_RPC_ERROR) - return PING_RPC_ERROR - except Exception as e: - self._rpc_errors += 1 - logger.warning("maru ping failed: %s", e) - self._stats_monitor.update_remote_ping_error_code(PING_RPC_ERROR) - return PING_RPC_ERROR - - async def batched_async_contains( - self, - lookup_id: str, - keys: List[CacheEngineKey], - pin: bool = False, - ) -> int: - if not self._ensure_connected() or not keys: - return 0 - assert self._handle is not None - - key_strs = [k.to_string() for k in keys] - try: - results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_exists, key_strs), - timeout=self.maru_config.operation_timeout, - ) - # Count consecutive hits from start - count = 0 - for exists in results: - if not exists: - break - count += 1 - logger.debug( - "maru batched_async_contains lookup_id=%s, hits=%d/%d", - lookup_id, - count, - len(keys), - ) - return count - except asyncio.TimeoutError: - self._rpc_errors += 1 - logger.warning( - "maru batched_async_contains timed out for lookup_id=%s, %d keys", - lookup_id, - len(keys), - ) - return 0 - except Exception as e: - self._rpc_errors += 1 - logger.error("maru batched_async_contains failed: %s", e) - return 0 diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 95942aaa64..f00f8e14d5 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -9,7 +9,7 @@ # Third Party from maru import MaruConfig, MaruHandler -from maru_lmcache import CxlMemoryAllocator +from maru_lmcache import CxlMemoryAdapter import torch # First Party @@ -31,7 +31,7 @@ class MaruBackend(AllocatorBackendInterface): """Maru CXL shared memory storage backend. - Implements AllocatorBackendInterface with its own CxlMemoryAllocator. + Implements AllocatorBackendInterface with its own CxlMemoryAdapter. No LocalCPUBackend needed — data lives directly in CXL mmap memory. Put is async (Future): metadata registration via RPC. @@ -154,21 +154,21 @@ def _create_handler( def initialize_allocator( self, config: LMCacheEngineConfig, metadata: LMCacheMetadata ) -> MemoryAllocatorInterface: - """Create CxlMemoryAllocator backed by the connected handler. + """Create CxlMemoryAdapter backed by the connected handler. Args: config: LMCache engine configuration. metadata: LMCache engine metadata. Returns: - CxlMemoryAllocator instance. + CxlMemoryAdapter instance. """ shapes = metadata.get_shapes() dtypes = metadata.get_dtypes() fmt = MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD - chunk_size = self._handler.owned_region_manager.get_chunk_size() + chunk_size = self._handler.get_chunk_size() - return CxlMemoryAllocator( + return CxlMemoryAdapter( handler=self._handler, shapes=shapes, dtypes=dtypes, @@ -177,7 +177,7 @@ def initialize_allocator( ) def get_memory_allocator(self) -> MemoryAllocatorInterface: - """Returns the underlying CxlMemoryAllocator.""" + """Returns the underlying CxlMemoryAdapter.""" return self.memory_allocator def get_allocator_backend(self) -> "MaruBackend": @@ -192,7 +192,7 @@ def allocate( eviction: bool = True, busy_loop: bool = True, ) -> Optional[MemoryObj]: - """Allocate CXL-backed memory via CxlMemoryAllocator. + """Allocate CXL-backed memory via CxlMemoryAdapter. Args: shapes: Tensor shape(s). @@ -208,7 +208,7 @@ def allocate( if obj is not None: logger.debug( "[Maru] allocate rid=%d pid=%d", - *CxlMemoryAllocator.decode_address(obj.metadata.address), + *CxlMemoryAdapter.decode_address(obj.metadata.address), ) else: logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes) @@ -318,7 +318,7 @@ async def _async_store( ) -> None: """Register KV metadata with MaruServer (runs in event loop). - Uses CxlMemoryAllocator.create_store_handle() to extract + Uses CxlMemoryAdapter.create_store_handle() to extract (region_id, page_index) from the MemoryObj's encoded address. Args: @@ -328,7 +328,7 @@ async def _async_store( """ try: allocator = self.memory_allocator - assert isinstance(allocator, CxlMemoryAllocator) + assert isinstance(allocator, CxlMemoryAdapter) handle = allocator.create_store_handle(memory_obj) key_str = key.to_string() @@ -364,7 +364,7 @@ def get_blocking( """Blocking get: read KV cache directly from CXL memory. Queries MaruServer for metadata, then returns a MemoryObj - via CxlMemoryAllocator.get_by_location(). + via CxlMemoryAdapter.get_by_location(). Args: key: The cache key. @@ -382,7 +382,7 @@ def get_blocking( return None allocator = self.memory_allocator - assert isinstance(allocator, CxlMemoryAllocator) + assert isinstance(allocator, CxlMemoryAdapter) memory_obj = allocator.get_by_location( region_id=mem_info.region_id, From 61e4b6289c1e7463e2e0b159061335a51431753b Mon Sep 17 00:00:00 2001 From: Rocky Song <167060552+youngrok-XCENA@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:29:13 +0900 Subject: [PATCH 04/21] feat/maru backend (#5) * fix(maru): address medium/low review feedback for MaruBackend - Add use_layerwise guard (NotImplementedError) - Change zip strict=False to strict=True in batched_submit_put_task - Add warning log when contains(pin=True) is called - Add warning log for in-flight put_tasks on close() * chore: remove pin warning * feat: implement batched_async_contains and batched_get_non_blocking for MaruBackend Enable MaruBackend to participate in StorageManager.async_lookup_and_prefetch() by implementing the two required async lookup APIs. Both use asyncio.to_thread to wrap sync RPC calls (handler.exists / handler.retrieve) without blocking the event loop. * fix: add maru:// URL scheme conversion and pin on get_blocking - Convert maru:// to tcp:// in _create_handler for ZMQ compatibility - Call memory_obj.pin() in get_blocking to match cleanup unpin, fixing pin_count=-1 warning on retrieve path * fix(maru): only invoke on_complete_callback on successful store Previously _async_store called on_complete_callback in the finally block regardless of success/failure, which could signal false success to callers and mask CXL page leaks on store failure. Aligns with LocalCPUBackend and NixlDynamic which only call callback on success. --- lmcache/v1/storage_backend/maru_backend.py | 92 +++++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index f00f8e14d5..ef4ef8a21b 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -53,6 +53,11 @@ def __init__( ): super().__init__(dst_device=dst_device) + if config.use_layerwise: + raise NotImplementedError( + "MaruBackend does not yet support layerwise KV cache." + ) + # 1. Config self.config = config self.loop = loop @@ -128,9 +133,14 @@ def _create_handler( """ assert config.maru_path is not None, "maru_path must be set for MaruBackend" + # Convert maru:// scheme to tcp:// for ZMQ + server_url = config.maru_path + if server_url.startswith("maru://"): + server_url = "tcp://" + server_url[len("maru://"):] + extra = config.extra_config or {} maru_config = MaruConfig( - server_url=config.maru_path, + server_url=server_url, instance_id=extra.get("maru_instance_id"), pool_size=self._parse_pool_size(config.maru_pool_size), chunk_size_bytes=self._full_chunk_size_bytes, @@ -303,7 +313,7 @@ def batched_submit_put_task( List of Futures, one per key. """ futures = [] - for key, memory_obj in zip(keys, memory_objs, strict=False): + for key, memory_obj in zip(keys, memory_objs, strict=True): future = self.submit_put_task( key, memory_obj, on_complete_callback=on_complete_callback ) @@ -326,6 +336,7 @@ async def _async_store( memory_obj: MemoryObj backed by CXL memory. on_complete_callback: Optional callback after registration. """ + success = False try: allocator = self.memory_allocator assert isinstance(allocator, CxlMemoryAdapter) @@ -333,6 +344,7 @@ async def _async_store( key_str = key.to_string() await asyncio.to_thread(self._handler.store, key_str, handle) + success = True logger.debug( "[Maru] store key=%s rid=%d pid=%d", @@ -347,7 +359,7 @@ async def _async_store( with self.put_lock: self.put_tasks.discard(key) - if on_complete_callback is not None: + if success and on_complete_callback is not None: try: on_complete_callback(key) except Exception as e: @@ -399,6 +411,7 @@ def get_blocking( return None memory_obj.ref_count_up() + memory_obj.pin() logger.debug( "[Maru] get_blocking rid=%d pid=%d size=%d", @@ -408,6 +421,72 @@ def get_blocking( ) return memory_obj + # ========================================================================= + # Async lookup API (used by StorageManager.async_lookup_and_prefetch) + # ========================================================================= + + async def batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist on MaruServer. + + Prefix-based: returns the count of contiguous keys starting + from index 0 that exist. Stops at first miss. + + Args: + lookup_id: Unique request identifier. + keys: Keys to check in prefix order. + pin: Whether to pin. Not supported; logged as debug. + + Returns: + Number of prefix-contiguous keys that exist. + """ + + def _contains_prefix() -> int: + num_hit = 0 + for key in keys: + if not self.contains(key): + break + num_hit += 1 + return num_hit + + return await asyncio.to_thread(_contains_prefix) + + async def batched_get_non_blocking( + self, + lookup_id: str, + keys: list[CacheEngineKey], + transfer_spec: Any = None, + ) -> list[MemoryObj]: + """Non-blocking batched get via CXL direct read. + + Each key triggers a metadata lookup on MaruServer followed by + a zero-copy CXL memory read. Stops at first miss and returns + the prefix that was successfully retrieved. + + Args: + lookup_id: Unique request identifier. + keys: Keys to retrieve (already confirmed by contains). + transfer_spec: Unused. + + Returns: + List of MemoryObjs backed by CXL memory. + """ + + def _get_batch() -> list[MemoryObj]: + results: list[MemoryObj] = [] + for key in keys: + mem_obj = self.get_blocking(key) + if mem_obj is None: + break + results.append(mem_obj) + return results + + return await asyncio.to_thread(_get_batch) + # ========================================================================= # Contains / Pin / Unpin / Remove # ========================================================================= @@ -476,6 +555,13 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: def close(self) -> None: """Close the backend and underlying MaruHandler.""" + with self.put_lock: + pending = len(self.put_tasks) + if pending > 0: + logger.warning( + "[Maru] closing with %d in-flight put tasks still pending", + pending, + ) self.memory_allocator.close() self._handler.close() logger.info("MaruBackend closed.") From 7c021a6ba1d437e42605beff26640017ef3c111f Mon Sep 17 00:00:00 2001 From: hyunyul-XCENA Date: Wed, 18 Mar 2026 16:39:57 +0900 Subject: [PATCH 05/21] feat(maru): use batch RPC APIs for MaruHandler operations (#11) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(maru): use batch RPC APIs for MaruHandler operations Replace per-key for-loop RPC calls with single batch RPC calls: - batched_submit_put_task: N x handler.store() → 1 x handler.batch_store() - batched_async_contains: N x handler.exists() → 1 x handler.batch_exists() - batched_get_non_blocking: N x handler.retrieve() → 1 x handler.batch_retrieve() Add missing batch methods: - batched_contains: new, using handler.batch_exists() - batched_get_blocking: new, using handler.batch_retrieve() * fix(maru): address design review findings for batch operations - Move handle creation inside try block to prevent ghost keys in put_tasks on create_store_handle failure - Only fire on_complete_callback for keys that succeeded in batch_store - Remove observability code (metrics, time imports) — belongs in separate observability branch - Update batched_submit_put_task docstring to reflect single-Future return semantics * fix: resolve ruff-format and mypy pre-commit failures - Fix ruff-format: whitespace in slice, logger arg formatting - Fix mypy: guard sum(results)/len(results) with None check * fix(maru): address PR #11 review feedback - Add assert memory_obj.tensor is not None in batched_submit_put_task for parity with single-key submit_put_task - Add summary debug logs to batched_get_blocking and batched_get_non_blocking for production observability - Deduplicate batched_async_contains by delegating to batched_contains --- lmcache/v1/storage_backend/maru_backend.py | 192 +++++++++++++++++---- 1 file changed, 159 insertions(+), 33 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index ef4ef8a21b..85e230f8de 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -136,7 +136,7 @@ def _create_handler( # Convert maru:// scheme to tcp:// for ZMQ server_url = config.maru_path if server_url.startswith("maru://"): - server_url = "tcp://" + server_url[len("maru://"):] + server_url = "tcp://" + server_url[len("maru://") :] extra = config.extra_config or {} maru_config = MaruConfig( @@ -301,7 +301,7 @@ def batched_submit_put_task( transfer_spec: Any = None, on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, ) -> Union[List[Future], None]: - """Submit batched put tasks. + """Submit batched put tasks via single batch_store RPC. Args: keys: The cache keys. @@ -310,15 +310,19 @@ def batched_submit_put_task( on_complete_callback: Optional per-key callback. Returns: - List of Futures, one per key. + List containing a single Future for the entire batch. """ - futures = [] - for key, memory_obj in zip(keys, memory_objs, strict=True): - future = self.submit_put_task( - key, memory_obj, on_complete_callback=on_complete_callback - ) - futures.append(future) - return futures + for memory_obj in memory_objs: + assert memory_obj.tensor is not None + + with self.put_lock: + self.put_tasks.update(keys) + + future = asyncio.run_coroutine_threadsafe( + self._async_batch_store(list(keys), memory_objs, on_complete_callback), + self.loop, + ) + return [future] async def _async_store( self, @@ -365,6 +369,44 @@ async def _async_store( except Exception as e: logger.warning("on_complete_callback failed for key %s: %s", key, e) + async def _async_batch_store( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register multiple KV metadata entries via single batch_store RPC.""" + results: Optional[list[bool]] = None + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + key_strs = [k.to_string() for k in keys] + handles = [allocator.create_store_handle(m) for m in memory_objs] + + results = await asyncio.to_thread( + self._handler.batch_store, key_strs, handles + ) + if results is not None: + logger.debug("[Maru] batch_store %d/%d ok", sum(results), len(results)) + except Exception as e: + logger.error("[Maru] batch_store failed: %s", e) + finally: + with self.put_lock: + self.put_tasks.difference_update(keys) + + if on_complete_callback is not None: + for i, key in enumerate(keys): + if results is not None and i < len(results) and results[i]: + try: + on_complete_callback(key) + except Exception as e: + logger.warning( + "on_complete_callback failed for key %s: %s", + key, + e, + ) + # ========================================================================= # Get (sync) # ========================================================================= @@ -421,6 +463,49 @@ def get_blocking( ) return memory_obj + def batched_get_blocking( + self, + keys: List[CacheEngineKey], + ) -> List[Optional[MemoryObj]]: + """Blocking batched get via single batch_retrieve RPC. + + Args: + keys: The cache keys. + + Returns: + List of MemoryObj (None for misses). + """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + results: List[Optional[MemoryObj]] = [] + for mem_info in mem_infos: + if mem_info is None: + results.append(None) + continue + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + results.append(None) + continue + memory_obj.ref_count_up() + memory_obj.pin() + results.append(memory_obj) + + hits = sum(1 for r in results if r is not None) + logger.debug("[Maru] batch_retrieve %d/%d hits", hits, len(results)) + return results + # ========================================================================= # Async lookup API (used by StorageManager.async_lookup_and_prefetch) # ========================================================================= @@ -431,10 +516,10 @@ async def batched_async_contains( keys: List[CacheEngineKey], pin: bool = False, ) -> int: - """Check how many prefix keys exist on MaruServer. + """Check how many prefix keys exist via single batch_exists RPC. - Prefix-based: returns the count of contiguous keys starting - from index 0 that exist. Stops at first miss. + Returns the count of contiguous keys starting from index 0 + that exist. Stops at first miss. Args: lookup_id: Unique request identifier. @@ -444,16 +529,7 @@ async def batched_async_contains( Returns: Number of prefix-contiguous keys that exist. """ - - def _contains_prefix() -> int: - num_hit = 0 - for key in keys: - if not self.contains(key): - break - num_hit += 1 - return num_hit - - return await asyncio.to_thread(_contains_prefix) + return await asyncio.to_thread(self.batched_contains, keys, pin) async def batched_get_non_blocking( self, @@ -461,11 +537,11 @@ async def batched_get_non_blocking( keys: list[CacheEngineKey], transfer_spec: Any = None, ) -> list[MemoryObj]: - """Non-blocking batched get via CXL direct read. + """Non-blocking batched get via single batch_retrieve RPC. - Each key triggers a metadata lookup on MaruServer followed by - a zero-copy CXL memory read. Stops at first miss and returns - the prefix that was successfully retrieved. + Uses handler.batch_retrieve() for a single RPC call, then + resolves each MemoryInfo to a MemoryObj via CxlMemoryAdapter. + Stops at first miss and returns the prefix. Args: lookup_id: Unique request identifier. @@ -476,16 +552,40 @@ async def batched_get_non_blocking( List of MemoryObjs backed by CXL memory. """ - def _get_batch() -> list[MemoryObj]: + def _batch_get() -> list[MemoryObj]: + if self._mla_worker_id_as0_mode: + actual_keys = [k.with_new_worker_id(0) for k in keys] + else: + actual_keys = list(keys) + + key_strs = [k.to_string() for k in actual_keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + results: list[MemoryObj] = [] - for key in keys: - mem_obj = self.get_blocking(key) - if mem_obj is None: + for mem_info in mem_infos: + if mem_info is None: + break + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: break - results.append(mem_obj) + memory_obj.ref_count_up() + memory_obj.pin() + results.append(memory_obj) + + logger.debug( + "[Maru] batch_get_non_blocking %d/%d hits", len(results), len(keys) + ) return results - return await asyncio.to_thread(_get_batch) + return await asyncio.to_thread(_batch_get) # ========================================================================= # Contains / Pin / Unpin / Remove @@ -506,6 +606,32 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: return self._handler.exists(key.to_string()) + def batched_contains( + self, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist via single batch_exists RPC. + + Args: + keys: Keys to check in prefix order. + pin: Whether to pin. Not supported. + + Returns: + Number of prefix-contiguous keys that exist. + """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + results = self._handler.batch_exists(key_strs) + num_hit = 0 + for exists in results: + if not exists: + break + num_hit += 1 + return num_hit + def pin(self, key: CacheEngineKey) -> bool: """Pin a key to prevent eviction. From eb5afe09562881233ea35fc5f71ab608b0edd1a6 Mon Sep 17 00:00:00 2001 From: jooho Date: Thu, 19 Mar 2026 21:24:24 +0900 Subject: [PATCH 06/21] feat: MaruBackend allocator fallback, ImportError, batch RPC (#7) * feat: add maru_as_primary_allocator config and PD mutual exclusion assert - Assert enable_pd=False when maru_path is set (mutual exclusion) - Add maru_as_primary_allocator config option (default: True) - Update _get_allocator_backend to respect the new option * fix: change maru_as_primary_allocator default to False * refactor: remove maru_as_primary_allocator, prefer LocalCPUBackend when available Simplify allocator selection: when MaruBackend is present, use LocalCPUBackend if it exists, otherwise fall back to MaruBackend. * refactor: remove PD mutual exclusion assert from MaruBackend creation * feat: add explicit ImportError messages for maru dependencies Follow mooncakestore_connector.py pattern to provide clear error messages when maru or maru_lmcache packages are not installed. * feat: port batch RPC calls from maru-connector with debug logging Use handler.batch_exists, batch_retrieve, batch_store instead of individual loop calls to reduce RPC round-trips. Added [DEBUG] logs for batch call tracing (to be removed later). * style: fix isort and ruff-format for maru_backend --- lmcache/v1/storage_backend/maru_backend.py | 20 +++++++++++++++++-- lmcache/v1/storage_backend/storage_manager.py | 5 ++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 85e230f8de..c9bd8f9216 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -7,9 +7,25 @@ import re import threading +try: + # Third Party + from maru import MaruConfig, MaruHandler +except ImportError as e: + raise ImportError( + "The 'maru' package is required to use MaruBackend. " + "Please install it according to the Maru setup documentation." + ) from e + +try: + # Third Party + from maru_lmcache import CxlMemoryAdapter +except ImportError as e: + raise ImportError( + "The 'maru_lmcache' package is required to use MaruBackend. " + "Please install it according to the Maru setup documentation." + ) from e + # Third Party -from maru import MaruConfig, MaruHandler -from maru_lmcache import CxlMemoryAdapter import torch # First Party diff --git a/lmcache/v1/storage_backend/storage_manager.py b/lmcache/v1/storage_backend/storage_manager.py index 138008228b..674241ca62 100644 --- a/lmcache/v1/storage_backend/storage_manager.py +++ b/lmcache/v1/storage_backend/storage_manager.py @@ -314,7 +314,10 @@ def _get_allocator_backend( if self.enable_pd: allocator_backend = self.storage_backends["PDBackend"] elif "MaruBackend" in self.storage_backends: - allocator_backend = self.storage_backends["MaruBackend"] + if "LocalCPUBackend" in self.storage_backends: + allocator_backend = self.storage_backends["LocalCPUBackend"] + else: + allocator_backend = self.storage_backends["MaruBackend"] else: allocator_backend = self.storage_backends["LocalCPUBackend"] assert isinstance(allocator_backend, AllocatorBackendInterface) From b5599d6c92d537e86c6fbeb5d2affb8b8d3ac3eb Mon Sep 17 00:00:00 2001 From: seohui-XCENA Date: Thu, 19 Mar 2026 21:24:32 +0900 Subject: [PATCH 07/21] feat: MaruBackend pin/unpin and ref_count management (#10) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement MaruBackend pin/unpin and ref_count management - contains(pin=True): call handler.exists_and_pin() for server-side pin - batched_async_contains(): use batch_exists_and_pin RPC instead of per-key calls - get_blocking(): remove local MemoryObj.pin(), keep ref_count_up() only - submit_put_task(): add ref_count_up() x2 (pool ref + async_store guard) - _async_store(): add ref_count_down() on completion - pin()/unpin(): delegate to handler.pin_kv()/unpin_kv() - batched_unpin(): batch RPC override for efficiency - StorageManager.batched_unpin(): use MaruBackend.batched_unpin() when available * feat: add server unpin in async cleanup_memory_objs for MaruBackend - When MaruBackend is present, use batched_unpin RPC instead of local MemoryObj.unpin() - Other backends keep existing memory_obj.unpin() behavior unchanged * chore: apply ruff-format * feat: add batched_unpin to StorageBackendInterface - Add batched_unpin() default implementation to abstract_backend (loops unpin per key) - MaruBackend overrides with single batch RPC for optimization - StorageManager.batched_unpin() uses backend.batched_unpin() directly * fix: revert batch RPC in batched_async_contains and add error handling in cleanup_memory_objs - Replace batch_exists/batch_exists_and_pin with per-key contains() loop (batch RPC optimization will be handled in a separate PR) - Add try/finally in cleanup_memory_objs to ensure ref_count_down runs even if batched_unpin fails * fix: add ref_count management to batched_submit_put_task Match submit_put_task pattern: ref_count_up x2 before async RPC, ref_count_down in _async_batch_store finally block. * refactor: simplify ref_count and remove local pin in MaruBackend - ref_count: single up only (no down), matching LocalCPUBackend pool pattern 1(alloc) → 2(ref_up) → 1(SM ref_down) — pool reference retained - remove memory_obj.pin() from all get paths (batched_get_blocking, batched_get_non_blocking) — server pin handles eviction protection * revert: remove batched_unpin and revert to per-key unpin Revert batched_unpin optimization to keep implementation simple. The batch RPC optimization will be addressed in a follow-up PR. * fix: add ref_count_down on async store failure and MLA worker_id fix in remove - Add ref_count_down() in _async_store() finally block when store fails - Add ref_count_down() for failed stores in _async_batched_store() - Add MLA worker_id normalization in remove() for consistency * feat: add batched_unpin to MaruBackend for single-RPC batch optimization --- lmcache/v1/storage_backend/maru_backend.py | 68 +++++++++++++++++----- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index c9bd8f9216..cc5d56edda 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -301,6 +301,8 @@ def submit_put_task( """ assert memory_obj.tensor is not None + memory_obj.ref_count_up() + with self.put_lock: self.put_tasks.add(key) @@ -330,6 +332,7 @@ def batched_submit_put_task( """ for memory_obj in memory_objs: assert memory_obj.tensor is not None + memory_obj.ref_count_up() with self.put_lock: self.put_tasks.update(keys) @@ -379,6 +382,9 @@ async def _async_store( with self.put_lock: self.put_tasks.discard(key) + if not success: + memory_obj.ref_count_down() + if success and on_complete_callback is not None: try: on_complete_callback(key) @@ -411,6 +417,12 @@ async def _async_batch_store( with self.put_lock: self.put_tasks.difference_update(keys) + # Release ref_count for failed stores + for i, memory_obj in enumerate(memory_objs): + succeeded = results is not None and i < len(results) and results[i] + if not succeeded: + memory_obj.ref_count_down() + if on_complete_callback is not None: for i, key in enumerate(keys): if results is not None and i < len(results) and results[i]: @@ -469,7 +481,6 @@ def get_blocking( return None memory_obj.ref_count_up() - memory_obj.pin() logger.debug( "[Maru] get_blocking rid=%d pid=%d size=%d", @@ -515,7 +526,6 @@ def batched_get_blocking( results.append(None) continue memory_obj.ref_count_up() - memory_obj.pin() results.append(memory_obj) hits = sum(1 for r in results if r is not None) @@ -593,7 +603,6 @@ def _batch_get() -> list[MemoryObj]: if memory_obj is None: break memory_obj.ref_count_up() - memory_obj.pin() results.append(memory_obj) logger.debug( @@ -612,7 +621,8 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: Args: key: The cache key. - pin: If True, pin the entry. (TODO: delegate to handler) + pin: If True, atomically check existence and pin the entry + to protect it from eviction. Returns: True if key exists. @@ -620,7 +630,10 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: if self._mla_worker_id_as0_mode: key = key.with_new_worker_id(0) - return self._handler.exists(key.to_string()) + key_str = key.to_string() + if pin: + return self._handler.exists_and_pin(key_str) + return self._handler.exists(key_str) def batched_contains( self, @@ -631,7 +644,8 @@ def batched_contains( Args: keys: Keys to check in prefix order. - pin: Whether to pin. Not supported. + pin: If True, atomically check and pin via + batch_exists_and_pin RPC. Returns: Number of prefix-contiguous keys that exist. @@ -640,7 +654,10 @@ def batched_contains( keys = [k.with_new_worker_id(0) for k in keys] key_strs = [k.to_string() for k in keys] - results = self._handler.batch_exists(key_strs) + if pin: + results = self._handler.batch_exists_and_pin(key_strs) + else: + results = self._handler.batch_exists(key_strs) num_hit = 0 for exists in results: if not exists: @@ -649,10 +666,9 @@ def batched_contains( return num_hit def pin(self, key: CacheEngineKey) -> bool: - """Pin a key to prevent eviction. + """Pin a key to prevent eviction on MaruServer. - TODO: Delegate to MaruHandler.pin() once server-side - ref_count management is implemented. + Increments the server-side pin_count. Args: key: The cache key. @@ -660,13 +676,15 @@ def pin(self, key: CacheEngineKey) -> bool: Returns: True if pinned successfully. """ - return False + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.pin_kv(key.to_string()) def unpin(self, key: CacheEngineKey) -> bool: - """Unpin a key to allow eviction. + """Unpin a key to allow eviction on MaruServer. - TODO: Delegate to MaruHandler.unpin() once server-side - ref_count management is implemented. + Decrements the server-side pin_count. When pin_count reaches 0, + the entry becomes eligible for eviction. Args: key: The cache key. @@ -674,7 +692,25 @@ def unpin(self, key: CacheEngineKey) -> bool: Returns: True if unpinned successfully. """ - return False + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.unpin_kv(key.to_string()) + + def batched_unpin(self, keys: List[CacheEngineKey]) -> None: + """Batch-unpin keys via single RPC. + + Decrements server-side pin_count for each key. When pin_count + reaches 0, the entry becomes eligible for eviction. + + Args: + keys: The cache keys to unpin. + """ + if not keys: + return + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + key_strs = [k.to_string() for k in keys] + self._handler.batch_unpin_kv(key_strs) def remove(self, key: CacheEngineKey, force: bool = True) -> bool: """Remove a key from MaruServer. @@ -686,6 +722,8 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: Returns: True if removed successfully. """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) key_str = key.to_string() result = self._handler.delete(key_str) logger.debug("[Maru] remove key=%s success=%s", key, result) From 2be19a9794e96374de44a6138035dcbb362b4b22 Mon Sep 17 00:00:00 2001 From: seohui-XCENA Date: Fri, 20 Mar 2026 11:50:38 +0900 Subject: [PATCH 08/21] fix: rename MaruHandler pin/unpin RPC method calls (#12) Align method names with updated MaruHandler API: - exists_and_pin -> pin - batch_exists_and_pin -> batch_pin - unpin_kv -> unpin - batch_unpin_kv -> batch_unpin --- lmcache/v1/storage_backend/maru_backend.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index cc5d56edda..8823e14307 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -632,7 +632,7 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: key_str = key.to_string() if pin: - return self._handler.exists_and_pin(key_str) + return self._handler.pin(key_str) return self._handler.exists(key_str) def batched_contains( @@ -645,7 +645,7 @@ def batched_contains( Args: keys: Keys to check in prefix order. pin: If True, atomically check and pin via - batch_exists_and_pin RPC. + batch_pin RPC. Returns: Number of prefix-contiguous keys that exist. @@ -655,7 +655,7 @@ def batched_contains( key_strs = [k.to_string() for k in keys] if pin: - results = self._handler.batch_exists_and_pin(key_strs) + results = self._handler.batch_pin(key_strs) else: results = self._handler.batch_exists(key_strs) num_hit = 0 @@ -694,7 +694,7 @@ def unpin(self, key: CacheEngineKey) -> bool: """ if self._mla_worker_id_as0_mode: key = key.with_new_worker_id(0) - return self._handler.unpin_kv(key.to_string()) + return self._handler.unpin(key.to_string()) def batched_unpin(self, keys: List[CacheEngineKey]) -> None: """Batch-unpin keys via single RPC. @@ -710,7 +710,7 @@ def batched_unpin(self, keys: List[CacheEngineKey]) -> None: if self._mla_worker_id_as0_mode: keys = [k.with_new_worker_id(0) for k in keys] key_strs = [k.to_string() for k in keys] - self._handler.batch_unpin_kv(key_strs) + self._handler.batch_unpin(key_strs) def remove(self, key: CacheEngineKey, force: bool = True) -> bool: """Remove a key from MaruServer. From d6f47e791ae8796946a5f27d6873262a03a7e636 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 04:48:03 +0000 Subject: [PATCH 09/21] tests: add maru backend test --- tests/v1/storage_backend/test_maru_backend.py | 779 ++++++++++++++++++ 1 file changed, 779 insertions(+) create mode 100644 tests/v1/storage_backend/test_maru_backend.py diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py new file mode 100644 index 0000000000..a1eab8c3ee --- /dev/null +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -0,0 +1,779 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import asyncio +import mmap +import threading +from unittest.mock import MagicMock, patch + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import MemoryFormat, TensorMemoryObj +from lmcache.v1.pin_monitor import PinMonitor +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface +from tests.v1.utils import ( + check_method_signatures, + get_abstract_methods, + get_all_methods_from_base, + get_methods_implemented_in_class, +) + +maru = pytest.importorskip("maru", reason="maru package not installed") +maru_lmcache = pytest.importorskip( + "maru_lmcache", reason="maru_lmcache package not installed" +) + +# Local +from lmcache.v1.storage_backend.maru_backend import MaruBackend +from maru_handler.memory import AllocHandle +from maru_handler.memory.types import MappedRegion, MemoryInfo +from maru_lmcache.adapter import CxlMemoryAdapter + +# ========================================================================= +# Constants +# ========================================================================= + +TEST_CHUNK_SIZE = 1024 +TEST_DTYPE = torch.float32 +TEST_SHAPE = torch.Size([256]) # 256 * 4B = 1024 bytes = chunk_size + + +# ========================================================================= +# Helpers +# ========================================================================= + + +def _make_mock_handler(pool_size=4096, chunk_size=TEST_CHUNK_SIZE): + """Create a mock MaruHandler with mmap-backed regions.""" + handler = MagicMock() + handler._connected = True + + region_id = 100 + page_count = pool_size // chunk_size + + mmap_obj = mmap.mmap(-1, pool_size) + mapped_region = MappedRegion( + region_id=region_id, + handle=MagicMock(region_id=region_id, length=pool_size), + size=pool_size, + _mmap_obj=mmap_obj, + ) + + handler.get_buffer_view.side_effect = lambda rid, offset, size: ( + mapped_region.get_buffer_view(offset, size) if rid == region_id else None + ) + handler.get_region_page_count.side_effect = lambda rid: ( + page_count if rid == region_id else None + ) + handler.get_owned_region_ids.return_value = [region_id] + handler.get_chunk_size.return_value = chunk_size + + def mock_set_on_region_added(callback): + if callback is not None: + callback(region_id, page_count) + + handler.set_on_region_added.side_effect = mock_set_on_region_added + + page_counter = [0] + + def mock_alloc(size): + idx = page_counter[0] + page_counter[0] += 1 + buf = mapped_region.get_buffer_view(idx * chunk_size, size) + return AllocHandle( + buf=buf, _region_id=region_id, _page_index=idx, _size=size + ) + + handler.alloc.side_effect = mock_alloc + handler.free = MagicMock() + handler.connect.return_value = True + handler.close.return_value = None + handler.store.return_value = True + handler.batch_store.return_value = None + handler.retrieve.return_value = None + handler.batch_retrieve.return_value = [] + handler.exists.return_value = False + handler.batch_exists.return_value = [] + handler.delete.return_value = True + handler.pin.return_value = True + handler.pin_kv.return_value = True + handler.unpin.return_value = True + handler.batch_pin.return_value = [] + handler.batch_unpin.return_value = None + + return handler + + +def _make_cache_key(chunk_hash: int = 12345) -> CacheEngineKey: + """Create a CacheEngineKey for testing.""" + return CacheEngineKey( + model_name="test-model", + world_size=1, + worker_id=0, + chunk_hash=chunk_hash, + dtype=torch.float32, + ) + + +def _make_memory_obj(adapter: CxlMemoryAdapter) -> TensorMemoryObj: + """Allocate a TensorMemoryObj from the adapter.""" + obj = adapter.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + return obj + + +# ========================================================================= +# Fixtures +# ========================================================================= + + +@pytest.fixture(autouse=True) +def _init_pin_monitor(): + """Initialize PinMonitor singleton required by TensorMemoryObj.pin().""" + PinMonitor._instance = None + PinMonitor.GetOrCreate(LMCacheEngineConfig.from_defaults()) + yield + PinMonitor._instance = None + + +@pytest.fixture +def async_loop(): + """Provide an asyncio event loop running in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + yield loop + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5) + loop.close() + + +@pytest.fixture +def mock_handler(): + return _make_mock_handler() + + +@pytest.fixture +def adapter(mock_handler): + return CxlMemoryAdapter( + handler=mock_handler, + shapes=[TEST_SHAPE], + dtypes=[TEST_DTYPE], + fmt=MemoryFormat.KV_2LTD, + chunk_size=TEST_CHUNK_SIZE, + ) + + +@pytest.fixture +def backend(mock_handler, adapter, async_loop): + """Create a MaruBackend with mocked internals.""" + # Local + + + with patch.object(MaruBackend, "initialize_allocator", return_value=adapter): + backend = MaruBackend.__new__(MaruBackend) + backend.dst_device = "cpu" + backend.config = MagicMock() + backend.config.maru_pool_size = "4K" + backend.loop = async_loop + backend.memory_allocator = adapter + backend._handler = mock_handler + + backend._full_chunk_size_bytes = TEST_CHUNK_SIZE + backend._single_token_size = TEST_CHUNK_SIZE // 256 # 4 bytes per token + backend._mla_worker_id_as0_mode = False + + backend.put_lock = threading.Lock() + backend.put_tasks = set() + return backend + + +def _run_async(loop, coro): + """Submit a coroutine to a running event loop and wait for result.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=5) + + +# ========================================================================= +# Tests — Init & Interface Compliance +# ========================================================================= + + +class TestMaruBackendInit: + def test_str(self, backend): + assert str(backend) == "MaruBackend" + + def test_get_allocator_backend_returns_self(self, backend): + assert backend.get_allocator_backend() is backend + + def test_get_memory_allocator_returns_adapter(self, backend, adapter): + assert backend.get_memory_allocator() is adapter + + +class TestMaruBackendParsePoolSize: + """Test _parse_pool_size static method.""" + + def test_none_returns_default(self): + + + result = MaruBackend._parse_pool_size(None) + assert result == 4 * 1024**3 + + def test_parse_gigabytes(self): + + + assert MaruBackend._parse_pool_size("4G") == 4 * 1024**3 + assert MaruBackend._parse_pool_size("4GB") == 4 * 1024**3 + + def test_parse_megabytes(self): + + + assert MaruBackend._parse_pool_size("512M") == 512 * 1024**2 + assert MaruBackend._parse_pool_size("512MB") == 512 * 1024**2 + + def test_parse_kilobytes(self): + + + assert MaruBackend._parse_pool_size("1K") == 1024 + assert MaruBackend._parse_pool_size("1KB") == 1024 + + def test_parse_terabytes(self): + + + assert MaruBackend._parse_pool_size("1T") == 1024**4 + + def test_parse_plain_integer_string(self): + + + assert MaruBackend._parse_pool_size("1048576") == 1048576 + + def test_parse_integer_value(self): + + + assert MaruBackend._parse_pool_size(2048) == 2048 + + def test_parse_invalid_returns_default(self): + + + result = MaruBackend._parse_pool_size("invalid") + assert result == 4 * 1024**3 + + def test_parse_case_insensitive(self): + + + assert MaruBackend._parse_pool_size("4g") == 4 * 1024**3 + assert MaruBackend._parse_pool_size("512m") == 512 * 1024**2 + + +class TestMaruBackendInterfaceCompliance: + """Verify MaruBackend implements all required interface methods.""" + + def test_implements_all_abstract_methods(self): + + + abstract = get_abstract_methods(AllocatorBackendInterface) + implemented = get_methods_implemented_in_class( + MaruBackend, AllocatorBackendInterface + ) + missing = abstract - implemented + assert not missing, f"Missing abstract methods: {missing}" + + def test_method_signatures_match(self): + + + # Known: batched_submit_put_task uses 'memory_objs' instead of 'objs' + # TODO: Rename to 'objs' for full compliance + known_param_renames = {"batched_submit_put_task"} + + mismatches = check_method_signatures(AllocatorBackendInterface, MaruBackend) + unexpected = [m for m in mismatches if m["method"] not in known_param_renames] + assert not unexpected, f"Signature mismatches: {unexpected}" + + +# ========================================================================= +# Tests — Allocate +# ========================================================================= + + +class TestMaruBackendAllocate: + def test_allocate_returns_memory_obj(self, backend): + obj = backend.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + assert obj.tensor is not None + assert obj.metadata.dtype == TEST_DTYPE + + def test_batched_allocate_returns_list(self, backend): + objs = backend.batched_allocate(TEST_SHAPE, TEST_DTYPE, batch_size=3) + assert objs is not None + assert len(objs) == 3 + for obj in objs: + assert obj.tensor is not None + + +# ========================================================================= +# Tests — Put (async) +# ========================================================================= + + +class TestMaruBackendPut: + def test_submit_put_task_returns_future(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + future = backend.submit_put_task(key, obj) + assert future is not None + future.result(timeout=5) + + backend._handler.store.assert_called_once() + + def test_submit_put_task_tracks_in_flight(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + assert not backend.exists_in_put_tasks(key) + + future = backend.submit_put_task(key, obj) + future.result(timeout=5) + + # After completion, key should be removed from put_tasks + assert not backend.exists_in_put_tasks(key) + + def test_exists_in_put_tasks_true_during_store(self, backend, adapter): + """Verify exists_in_put_tasks returns True while store is in progress.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + store_entered = threading.Event() + store_proceed = threading.Event() + + original_store = backend._handler.store + + def blocking_store(*args, **kwargs): + store_entered.set() + store_proceed.wait(timeout=5) + return original_store(*args, **kwargs) + + backend._handler.store.side_effect = blocking_store + + future = backend.submit_put_task(key, obj) + + # Wait until store is actually running + assert store_entered.wait(timeout=5) + assert backend.exists_in_put_tasks(key) + + # Let store complete + store_proceed.set() + future.result(timeout=5) + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task(self, backend, adapter): + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + backend._handler.batch_store.return_value = [True, True, True] + + futures = backend.batched_submit_put_task(keys, objs) + assert futures is not None + + for future in futures: + future.result(timeout=5) + + backend._handler.batch_store.assert_called_once() + + def test_submit_put_calls_callback(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + callback_called = [] + + def callback(k): + callback_called.append(k) + + future = backend.submit_put_task(key, obj, on_complete_callback=callback) + future.result(timeout=5) + + assert len(callback_called) == 1 + assert callback_called[0] == key + + def test_batched_submit_put_calls_callback_per_key(self, backend, adapter): + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + backend._handler.batch_store.return_value = [True, True, True] + callback_keys = [] + + def callback(k): + callback_keys.append(k) + + futures = backend.batched_submit_put_task( + keys, objs, on_complete_callback=callback + ) + for future in futures: + future.result(timeout=5) + + assert set(callback_keys) == set(keys) + + +# ========================================================================= +# Tests — Get (sync) +# ========================================================================= + + +class TestMaruBackendGet: + def test_get_blocking_hit(self, backend, adapter): + key = _make_cache_key() + + data_size = TEST_CHUNK_SIZE + data = bytearray(data_size) + mock_info = MemoryInfo( + view=memoryview(data), + region_id=100, + page_index=0, + ) + backend._handler.retrieve.return_value = mock_info + + result = backend.get_blocking(key) + assert result is not None + backend._handler.retrieve.assert_called_once() + + def test_get_blocking_miss(self, backend): + key = _make_cache_key() + backend._handler.retrieve.return_value = None + + result = backend.get_blocking(key) + assert result is None + + def test_get_blocking_ref_count_increases(self, backend, adapter): + """After get_blocking, the returned MemoryObj should have ref_count + incremented.""" + # Pre-allocate so pool has page 0 + _make_memory_obj(adapter) + + key = _make_cache_key() + mock_info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=100, + page_index=0, + ) + backend._handler.retrieve.return_value = mock_info + + result = backend.get_blocking(key) + assert result is not None + # Pool objects start with ref_count=1, get_blocking calls ref_count_up + assert result.get_ref_count() >= 2 + + def test_batched_get_blocking(self, backend, adapter): + """batched_get_blocking returns list of MemoryObj via batch_retrieve.""" + objs = [_make_memory_obj(adapter) for _ in range(2)] + keys = [_make_cache_key(i) for i in range(2)] + + infos = [] + for obj in objs: + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + infos.append( + MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + ) + backend._handler.batch_retrieve.return_value = infos + + results = backend.batched_get_blocking(keys) + assert len(results) == 2 + for r in results: + assert r is not None + + def test_batched_get_blocking_with_miss(self, backend, adapter): + """batched_get_blocking returns None for missing keys.""" + obj = _make_memory_obj(adapter) + keys = [_make_cache_key(i) for i in range(2)] + + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + backend._handler.batch_retrieve.return_value = [info, None] + + results = backend.batched_get_blocking(keys) + assert len(results) == 2 + assert results[0] is not None + assert results[1] is None + + +# ========================================================================= +# Tests — Contains +# ========================================================================= + + +class TestMaruBackendContains: + def test_contains_true(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = True + + assert backend.contains(key) is True + backend._handler.exists.assert_called_once_with(key.to_string()) + + def test_contains_false(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = False + + assert backend.contains(key) is False + + def test_batched_contains_all_hit(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, True] + + result = backend.batched_contains(keys) + assert result == 3 + + def test_batched_contains_partial_prefix(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, False] + + result = backend.batched_contains(keys) + assert result == 2 + + def test_batched_contains_first_miss(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [False, True, True] + + result = backend.batched_contains(keys) + assert result == 0 + + def test_contains_with_pin(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = True + + assert backend.contains(key, pin=True) is True + backend._handler.pin.assert_called_once_with(key.to_string()) + backend._handler.exists.assert_not_called() + + def test_contains_with_pin_false(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = False + + assert backend.contains(key, pin=True) is False + + def test_batched_contains_with_pin(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_pin.return_value = [True, True, True] + + result = backend.batched_contains(keys, pin=True) + assert result == 3 + backend._handler.batch_pin.assert_called_once_with( + [k.to_string() for k in keys] + ) + backend._handler.batch_exists.assert_not_called() + + def test_batched_contains_with_pin_partial(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_pin.return_value = [True, False, True] + + result = backend.batched_contains(keys, pin=True) + assert result == 1 + + def test_batched_contains_empty(self, backend): + backend._handler.batch_exists.return_value = [] + assert backend.batched_contains([]) == 0 + + +# ========================================================================= +# Tests — Async Lookup +# ========================================================================= + + +class TestMaruBackendAsyncLookup: + def test_batched_async_contains_all_hit(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, True] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-1", keys) + ) + assert result == 3 + + def test_batched_async_contains_partial_prefix(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, False, True] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-2", keys) + ) + assert result == 1 + + def test_batched_async_contains_empty(self, backend, async_loop): + backend._handler.batch_exists.return_value = [] + result = _run_async( + async_loop, backend.batched_async_contains("lookup-3", []) + ) + assert result == 0 + + def test_batched_get_non_blocking_all_hit(self, backend, adapter, async_loop): + keys = [_make_cache_key(i) for i in range(2)] + + objs = [_make_memory_obj(adapter) for _ in range(2)] + infos = [] + for obj in objs: + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + infos.append( + MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + ) + backend._handler.batch_retrieve.return_value = infos + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-4", keys) + ) + assert len(results) == 2 + for obj in results: + assert obj is not None + + def test_batched_get_non_blocking_prefix_stop_on_miss( + self, backend, adapter, async_loop + ): + """Second key is a miss -> only first returned (prefix semantics).""" + keys = [_make_cache_key(i) for i in range(3)] + + obj = _make_memory_obj(adapter) + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + # hit, miss, hit -> should return only [hit] + backend._handler.batch_retrieve.return_value = [info, None, info] + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-5", keys) + ) + assert len(results) == 1 + + def test_batched_get_non_blocking_empty(self, backend, async_loop): + backend._handler.batch_retrieve.return_value = [] + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-6", []) + ) + assert results == [] + + +# ========================================================================= +# Tests — Pin / Unpin / Remove +# ========================================================================= + + +class TestMaruBackendPinRemove: + def test_pin_delegates_to_handler(self, backend): + key = _make_cache_key() + backend._handler.pin_kv.return_value = True + + assert backend.pin(key) is True + backend._handler.pin_kv.assert_called_once_with(key.to_string()) + + def test_pin_returns_false_on_failure(self, backend): + key = _make_cache_key() + backend._handler.pin_kv.return_value = False + + assert backend.pin(key) is False + + def test_unpin_delegates_to_handler(self, backend): + key = _make_cache_key() + backend._handler.unpin.return_value = True + + assert backend.unpin(key) is True + backend._handler.unpin.assert_called_once_with(key.to_string()) + + def test_unpin_returns_false_on_failure(self, backend): + key = _make_cache_key() + backend._handler.unpin.return_value = False + + assert backend.unpin(key) is False + + def test_batched_unpin(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + + backend.batched_unpin(keys) + backend._handler.batch_unpin.assert_called_once_with( + [k.to_string() for k in keys] + ) + + def test_batched_unpin_empty(self, backend): + backend.batched_unpin([]) + backend._handler.batch_unpin.assert_not_called() + + def test_remove_existing_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = True + + result = backend.remove(key) + assert result is True + backend._handler.delete.assert_called_once_with(key.to_string()) + + def test_remove_nonexistent_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = False + + result = backend.remove(key) + assert result is False + + +# ========================================================================= +# Tests — Lifecycle +# ========================================================================= + + +class TestMaruBackendLifecycle: + def test_close_calls_handler_and_allocator(self, backend): + backend.memory_allocator = MagicMock() + backend.close() + backend.memory_allocator.close.assert_called_once() + backend._handler.close.assert_called_once() + + def test_close_with_pending_put_tasks(self, backend, adapter): + """close() should warn but not raise when put tasks are pending.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + # Manually add to put_tasks to simulate in-flight + with backend.put_lock: + backend.put_tasks.add(key) + + # Should not raise + backend.close() + backend._handler.close.assert_called_once() + + +# ========================================================================= +# Tests — Store Handle Roundtrip +# ========================================================================= + + +class TestMaruBackendStoreHandle: + def test_store_handle_roundtrip(self, backend, adapter): + """AllocHandle from create_store_handle should match original.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + + handle = adapter.create_store_handle(obj) + assert handle.region_id == 100 + assert handle.page_index == 0 + assert handle._size == obj.metadata.phy_size From ebca1a9620a930ee0d379f60092cce0c88af1b21 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 05:22:14 +0000 Subject: [PATCH 10/21] fix: fix ruff fails --- tests/v1/storage_backend/test_maru_backend.py | 46 ++++--------------- 1 file changed, 10 insertions(+), 36 deletions(-) diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index a1eab8c3ee..82c6868631 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from unittest.mock import MagicMock, patch import asyncio import mmap import threading -from unittest.mock import MagicMock, patch # Third Party import pytest @@ -19,7 +19,6 @@ from tests.v1.utils import ( check_method_signatures, get_abstract_methods, - get_all_methods_from_base, get_methods_implemented_in_class, ) @@ -28,11 +27,13 @@ "maru_lmcache", reason="maru_lmcache package not installed" ) -# Local -from lmcache.v1.storage_backend.maru_backend import MaruBackend -from maru_handler.memory import AllocHandle -from maru_handler.memory.types import MappedRegion, MemoryInfo -from maru_lmcache.adapter import CxlMemoryAdapter +# Third Party +from maru_handler.memory import AllocHandle # noqa: E402 +from maru_handler.memory.types import MappedRegion, MemoryInfo # noqa: E402 +from maru_lmcache.adapter import CxlMemoryAdapter # noqa: E402 + +# First Party +from lmcache.v1.storage_backend.maru_backend import MaruBackend # noqa: E402 # ========================================================================= # Constants @@ -85,9 +86,7 @@ def mock_alloc(size): idx = page_counter[0] page_counter[0] += 1 buf = mapped_region.get_buffer_view(idx * chunk_size, size) - return AllocHandle( - buf=buf, _region_id=region_id, _page_index=idx, _size=size - ) + return AllocHandle(buf=buf, _region_id=region_id, _page_index=idx, _size=size) handler.alloc.side_effect = mock_alloc handler.free = MagicMock() @@ -174,7 +173,6 @@ def backend(mock_handler, adapter, async_loop): """Create a MaruBackend with mocked internals.""" # Local - with patch.object(MaruBackend, "initialize_allocator", return_value=adapter): backend = MaruBackend.__new__(MaruBackend) backend.dst_device = "cpu" @@ -219,53 +217,35 @@ class TestMaruBackendParsePoolSize: """Test _parse_pool_size static method.""" def test_none_returns_default(self): - - result = MaruBackend._parse_pool_size(None) assert result == 4 * 1024**3 def test_parse_gigabytes(self): - - assert MaruBackend._parse_pool_size("4G") == 4 * 1024**3 assert MaruBackend._parse_pool_size("4GB") == 4 * 1024**3 def test_parse_megabytes(self): - - assert MaruBackend._parse_pool_size("512M") == 512 * 1024**2 assert MaruBackend._parse_pool_size("512MB") == 512 * 1024**2 def test_parse_kilobytes(self): - - assert MaruBackend._parse_pool_size("1K") == 1024 assert MaruBackend._parse_pool_size("1KB") == 1024 def test_parse_terabytes(self): - - assert MaruBackend._parse_pool_size("1T") == 1024**4 def test_parse_plain_integer_string(self): - - assert MaruBackend._parse_pool_size("1048576") == 1048576 def test_parse_integer_value(self): - - assert MaruBackend._parse_pool_size(2048) == 2048 def test_parse_invalid_returns_default(self): - - result = MaruBackend._parse_pool_size("invalid") assert result == 4 * 1024**3 def test_parse_case_insensitive(self): - - assert MaruBackend._parse_pool_size("4g") == 4 * 1024**3 assert MaruBackend._parse_pool_size("512m") == 512 * 1024**2 @@ -274,8 +254,6 @@ class TestMaruBackendInterfaceCompliance: """Verify MaruBackend implements all required interface methods.""" def test_implements_all_abstract_methods(self): - - abstract = get_abstract_methods(AllocatorBackendInterface) implemented = get_methods_implemented_in_class( MaruBackend, AllocatorBackendInterface @@ -284,8 +262,6 @@ def test_implements_all_abstract_methods(self): assert not missing, f"Missing abstract methods: {missing}" def test_method_signatures_match(self): - - # Known: batched_submit_put_task uses 'memory_objs' instead of 'objs' # TODO: Rename to 'objs' for full compliance known_param_renames = {"batched_submit_put_task"} @@ -618,9 +594,7 @@ def test_batched_async_contains_partial_prefix(self, backend, async_loop): def test_batched_async_contains_empty(self, backend, async_loop): backend._handler.batch_exists.return_value = [] - result = _run_async( - async_loop, backend.batched_async_contains("lookup-3", []) - ) + result = _run_async(async_loop, backend.batched_async_contains("lookup-3", [])) assert result == 0 def test_batched_get_non_blocking_all_hit(self, backend, adapter, async_loop): From 5eceac6d613fe861e752fde9bcb1704455fe14fe Mon Sep 17 00:00:00 2001 From: jooho Date: Fri, 20 Mar 2026 14:30:17 +0900 Subject: [PATCH 11/21] refactor: move maru ImportError handling to __init__.py (#13) --- lmcache/v1/storage_backend/__init__.py | 11 +++++++++-- lmcache/v1/storage_backend/maru_backend.py | 20 ++------------------ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index bcdd6f5bff..d49cda695a 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -219,8 +219,15 @@ def CreateStorageBackends( storage_backends[str(gds_backend)] = gds_backend if config.maru_path is not None and "MaruBackend" not in _skip: - # First Party - from lmcache.v1.storage_backend.maru_backend import MaruBackend + try: + # First Party + from lmcache.v1.storage_backend.maru_backend import MaruBackend + except ImportError as e: + raise ImportError( + "The 'maru' and 'maru_lmcache' packages are required " + "to use MaruBackend. Please install them according to " + "the Maru setup documentation." + ) from e maru_backend = MaruBackend(config, metadata, loop, dst_device) storage_backends[str(maru_backend)] = maru_backend diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 8823e14307..fef2a3e21c 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -7,25 +7,9 @@ import re import threading -try: - # Third Party - from maru import MaruConfig, MaruHandler -except ImportError as e: - raise ImportError( - "The 'maru' package is required to use MaruBackend. " - "Please install it according to the Maru setup documentation." - ) from e - -try: - # Third Party - from maru_lmcache import CxlMemoryAdapter -except ImportError as e: - raise ImportError( - "The 'maru_lmcache' package is required to use MaruBackend. " - "Please install it according to the Maru setup documentation." - ) from e - # Third Party +from maru import MaruConfig, MaruHandler +from maru_lmcache import CxlMemoryAdapter import torch # First Party From c4f1053e6544fb566eb1026ebde934a18fe2d7b8 Mon Sep 17 00:00:00 2001 From: hyunyul-XCENA Date: Fri, 20 Mar 2026 05:45:02 +0000 Subject: [PATCH 12/21] fix: rename handler.pin_kv() to handler.pin() and document ref_count intent - Fix AttributeError: MaruHandler exposes pin(), not pin_kv() - Add comment explaining intentional ref_count retention on successful store Signed-off-by: hyunyul-XCENA --- lmcache/v1/storage_backend/maru_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index fef2a3e21c..f29ea61168 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -285,6 +285,8 @@ def submit_put_task( """ assert memory_obj.tensor is not None + # Keep CXL page alive: ref_count_down is only called on failure. + # On success the ref is retained so the CXL memory is not reclaimed. memory_obj.ref_count_up() with self.put_lock: @@ -662,7 +664,7 @@ def pin(self, key: CacheEngineKey) -> bool: """ if self._mla_worker_id_as0_mode: key = key.with_new_worker_id(0) - return self._handler.pin_kv(key.to_string()) + return self._handler.pin(key.to_string()) def unpin(self, key: CacheEngineKey) -> bool: """Unpin a key to allow eviction on MaruServer. From 3b64cd434e79bd1c8ec2dd33a077c12d159e4105 Mon Sep 17 00:00:00 2001 From: hyunyul-XCENA Date: Fri, 20 Mar 2026 06:03:33 +0000 Subject: [PATCH 13/21] docs: update maru.rst for new MaruBackend config - Replace old connector config (remote_url, remote_serde) with maru_path and maru_pool_size top-level parameters - Remove maru_operation_timeout (unused in MaruBackend) - Update maru_timeout_ms default from 2000 to 5000 Signed-off-by: hyunyul-XCENA --- .../source/kv_cache/storage_backends/maru.rst | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst index b433ce5746..9a4a839b2f 100644 --- a/docs/source/kv_cache/storage_backends/maru.rst +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -43,14 +43,11 @@ Deploy Model With Maru .. code-block:: yaml chunk_size: 256 - local_cpu: True - max_local_cpu_size: 5 - remote_url: "maru://localhost:5555" - remote_serde: "naive" + save_unfull_chunk: True - extra_config: - maru_pool_size: "4G" - save_chunk_meta: False + # Maru backend + maru_path: "tcp://localhost:5555" + maru_pool_size: 4G **3. Start vLLM with Maru** @@ -75,11 +72,14 @@ Configuration * - Parameter - Default - Description - * - ``remote_url`` + * - ``maru_path`` - Required - - Maru server URL (format: ``maru://host:port``) + - Maru server URL (format: ``tcp://host:port``) + * - ``maru_pool_size`` + - ``"4G"`` + - CXL memory pool size per instance (e.g., ``"4G"``, ``"500M"``) -**Maru Parameters (via extra_config):** +**Advanced Parameters (via extra_config):** .. list-table:: :header-rows: 1 @@ -88,17 +88,11 @@ Configuration * - Parameter - Default - Description - * - ``maru_pool_size`` - - ``"1G"`` - - CXL memory pool size per instance (e.g., ``"4G"``, ``"500M"``) * - ``maru_instance_id`` - auto UUID - Unique client instance identifier - * - ``maru_operation_timeout`` - - 10.0 - - Per-operation timeout in seconds * - ``maru_timeout_ms`` - - 2000 + - 5000 - ZMQ RPC socket timeout in milliseconds * - ``maru_use_async_rpc`` - true From 193b5358e29452bfa87301a36115d5d513de62e6 Mon Sep 17 00:00:00 2001 From: hyunyul-XCENA Date: Fri, 20 Mar 2026 06:05:48 +0000 Subject: [PATCH 14/21] docs: add local_cpu: False to maru config example MaruBackend manages its own CXL memory allocation. LocalCPUBackend must be disabled to avoid allocator conflicts. Signed-off-by: hyunyul-XCENA --- docs/source/kv_cache/storage_backends/maru.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst index 9a4a839b2f..3cd4cc7361 100644 --- a/docs/source/kv_cache/storage_backends/maru.rst +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -43,6 +43,8 @@ Deploy Model With Maru .. code-block:: yaml chunk_size: 256 + local_cpu: False + max_local_cpu_size: 0 save_unfull_chunk: True # Maru backend From fd47482573e1463a14a1612d31b809105953b739 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 06:48:52 +0000 Subject: [PATCH 15/21] fix: fix handler method name mismatch --- tests/v1/storage_backend/test_maru_backend.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index 82c6868631..f4c6f93616 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -100,7 +100,6 @@ def mock_alloc(size): handler.batch_exists.return_value = [] handler.delete.return_value = True handler.pin.return_value = True - handler.pin_kv.return_value = True handler.unpin.return_value = True handler.batch_pin.return_value = [] handler.batch_unpin.return_value = None @@ -657,14 +656,14 @@ def test_batched_get_non_blocking_empty(self, backend, async_loop): class TestMaruBackendPinRemove: def test_pin_delegates_to_handler(self, backend): key = _make_cache_key() - backend._handler.pin_kv.return_value = True + backend._handler.pin.return_value = True assert backend.pin(key) is True - backend._handler.pin_kv.assert_called_once_with(key.to_string()) + backend._handler.pin.assert_called_once_with(key.to_string()) def test_pin_returns_false_on_failure(self, backend): key = _make_cache_key() - backend._handler.pin_kv.return_value = False + backend._handler.pin.return_value = False assert backend.pin(key) is False From 796f184828fdbc184d4884100140769e133816be Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 07:23:45 +0000 Subject: [PATCH 16/21] =?UTF-8?q?fix:=20skip=20put=20in=20MLA=20worker=5Fi?= =?UTF-8?q?d=5Fas0=20mode,=20fix=20test=20pin=5Fkv=E2=86=92pin?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add MLA worker_id_as0 skip logic to submit_put_task and batched_submit_put_task, matching RemoteBackend behavior - Fix test mock/assert to use handler.pin() instead of handler.pin_kv() (pin_kv is internal RPC method, pin is public MaruHandler API) - Add MLA skip tests for both put paths - Clean up docstring for unused params --- lmcache/v1/storage_backend/maru_backend.py | 17 +++++++++++++- tests/v1/storage_backend/test_maru_backend.py | 23 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index f29ea61168..849f070e47 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -208,7 +208,7 @@ def allocate( shapes: Tensor shape(s). dtypes: Tensor dtype(s). fmt: Memory format. - eviction: Unused (no eviction policy yet). + eviction: Unused. busy_loop: Unused. Returns: @@ -264,6 +264,13 @@ def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: with self.put_lock: return key in self.put_tasks + @staticmethod + def _create_immediate_empty_future() -> Future: + """Create a Future that is already resolved with None.""" + f: Future = Future() + f.set_result(None) + return f + def submit_put_task( self, key: CacheEngineKey, @@ -283,6 +290,10 @@ def submit_put_task( Returns: Future that completes when metadata is registered. """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return self._create_immediate_empty_future() + assert memory_obj.tensor is not None # Keep CXL page alive: ref_count_down is only called on failure. @@ -316,6 +327,10 @@ def batched_submit_put_task( Returns: List containing a single Future for the entire batch. """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return None + for memory_obj in memory_objs: assert memory_obj.tensor is not None memory_obj.ref_count_up() diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index f4c6f93616..ae94163523 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -400,6 +400,29 @@ def callback(k): assert set(callback_keys) == set(keys) + def test_submit_put_task_skips_in_mla_mode(self, backend, adapter): + """In MLA worker_id_as0 mode, submit_put_task should skip store.""" + backend._mla_worker_id_as0_mode = True + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + future = backend.submit_put_task(key, obj) + assert future.result(timeout=5) is None + backend._handler.store.assert_not_called() + + def test_batched_submit_put_task_skips_in_mla_mode(self, backend, adapter): + """In MLA worker_id_as0 mode, batched_submit_put_task should skip.""" + backend._mla_worker_id_as0_mode = True + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + result = backend.batched_submit_put_task(keys, objs) + assert result is None + backend._handler.batch_store.assert_not_called() + # ========================================================================= # Tests — Get (sync) From 1425daf4f7b26bf33a94a91a3d7e80c507d77314 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 07:26:29 +0000 Subject: [PATCH 17/21] fix: propagate async store failures to Future callers - Re-raise exceptions in _async_store and _async_batch_store after logging, so Future reflects actual store failure instead of silent success (matches RemoteBackend error propagation pattern) - Fix blocking_store test helper to avoid mock recursion --- lmcache/v1/storage_backend/maru_backend.py | 2 ++ tests/v1/storage_backend/test_maru_backend.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 849f070e47..ccc6d0960d 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -379,6 +379,7 @@ async def _async_store( except Exception as e: logger.error("[Maru] store failed key=%s: %s", key, e) + raise finally: with self.put_lock: self.put_tasks.discard(key) @@ -414,6 +415,7 @@ async def _async_batch_store( logger.debug("[Maru] batch_store %d/%d ok", sum(results), len(results)) except Exception as e: logger.error("[Maru] batch_store failed: %s", e) + raise finally: with self.put_lock: self.put_tasks.difference_update(keys) diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index ae94163523..abb383b74e 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -329,12 +329,10 @@ def test_exists_in_put_tasks_true_during_store(self, backend, adapter): store_entered = threading.Event() store_proceed = threading.Event() - original_store = backend._handler.store - def blocking_store(*args, **kwargs): store_entered.set() store_proceed.wait(timeout=5) - return original_store(*args, **kwargs) + return True backend._handler.store.side_effect = blocking_store From 000134b5ac6621bf6ed093886bad7bb7de735059 Mon Sep 17 00:00:00 2001 From: seohui-XCENA Date: Fri, 20 Mar 2026 07:28:32 +0000 Subject: [PATCH 18/21] fix: pin memory_obj in async retrieve to balance cleanup unpin Add memory_obj.pin() in batched_get_non_blocking() so that cleanup_memory_objs()'s memory_obj.unpin() call doesn't cause negative pin_count warnings. Sync retrieve paths are unaffected as they use server-side unpin via lookup_unpin(). --- lmcache/v1/storage_backend/maru_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 849f070e47..28667b085f 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -604,6 +604,7 @@ def _batch_get() -> list[MemoryObj]: if memory_obj is None: break memory_obj.ref_count_up() + memory_obj.pin() results.append(memory_obj) logger.debug( From a73e8baf896cd30817f53fc84220055d2c79e4ba Mon Sep 17 00:00:00 2001 From: jooho Date: Fri, 20 Mar 2026 17:17:44 +0900 Subject: [PATCH 19/21] fix: MaruBackend test failures, config cleanup, close() drain (#16) * refactor: simplify maru_pool_size config from str to float (GB) Align with existing config pattern (max_local_cpu_size uses float GB). Remove _parse_pool_size string parser in favor of simple GB-to-bytes conversion. This eliminates silent fallback on invalid values. * fix: drain in-flight put tasks before closing MaruBackend Wait for all pending async store tasks to complete before closing the handler, consistent with PDBackend and GDSBackend patterns. Prevents crash/data loss when _async_store is still running. --- .../source/kv_cache/storage_backends/maru.rst | 6 +- lmcache/v1/config.py | 6 +- lmcache/v1/storage_backend/maru_backend.py | 38 ++++--------- tests/v1/storage_backend/test_maru_backend.py | 55 ++++++------------- 4 files changed, 34 insertions(+), 71 deletions(-) diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst index 3cd4cc7361..b91920f986 100644 --- a/docs/source/kv_cache/storage_backends/maru.rst +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -49,7 +49,7 @@ Deploy Model With Maru # Maru backend maru_path: "tcp://localhost:5555" - maru_pool_size: 4G + maru_pool_size: 4 **3. Start vLLM with Maru** @@ -78,8 +78,8 @@ Configuration - Required - Maru server URL (format: ``tcp://host:port``) * - ``maru_pool_size`` - - ``"4G"`` - - CXL memory pool size per instance (e.g., ``"4G"``, ``"500M"``) + - ``4.0`` + - CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``) **Advanced Parameters (via extra_config):** diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index 3e51907cec..20bbc19276 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -239,9 +239,9 @@ # Maru CXL shared memory backend "maru_path": {"type": Optional[str], "default": None, "env_converter": str}, "maru_pool_size": { - "type": Optional[str], - "default": None, - "env_converter": str, + "type": float, + "default": 4.0, + "env_converter": float, }, # Other configurations # (Deprecated) The url of the actual remote lmcache instance for auditing. diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 99a9319cf9..21b343229d 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -4,8 +4,8 @@ from concurrent.futures import Future from typing import Any, Callable, List, Optional, Sequence, Union import asyncio -import re import threading +import time # Third Party from maru import MaruConfig, MaruHandler @@ -93,24 +93,9 @@ def __str__(self) -> str: return self.__class__.__name__ @staticmethod - def _parse_pool_size(raw: Optional[str]) -> int: - """Parse human-readable pool size (e.g. '4G', '512M') to bytes.""" - _DEFAULT = 4 * 1024**3 - if raw is None: - return _DEFAULT - if isinstance(raw, (int, float)): - return int(raw) - s = str(raw).strip().upper() - match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", s) - if not match: - try: - return int(s) - except ValueError: - logger.warning("Cannot parse maru_pool_size=%r, using default", raw) - return _DEFAULT - value, unit = float(match.group(1)), match.group(2) - multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} - return int(value * multipliers.get(unit, 1)) + def _pool_size_gb_to_bytes(size_gb: float) -> int: + """Convert pool size in GB to bytes.""" + return int(size_gb * 1024**3) # ========================================================================= # Initialization helpers @@ -142,7 +127,7 @@ def _create_handler( maru_config = MaruConfig( server_url=server_url, instance_id=extra.get("maru_instance_id"), - pool_size=self._parse_pool_size(config.maru_pool_size), + pool_size=self._pool_size_gb_to_bytes(config.maru_pool_size), chunk_size_bytes=self._full_chunk_size_bytes, auto_connect=False, timeout_ms=extra.get("maru_timeout_ms", 5000), @@ -739,13 +724,12 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: def close(self) -> None: """Close the backend and underlying MaruHandler.""" - with self.put_lock: - pending = len(self.put_tasks) - if pending > 0: - logger.warning( - "[Maru] closing with %d in-flight put tasks still pending", - pending, - ) + while True: + with self.put_lock: + if not self.put_tasks: + break + time.sleep(0.1) + self.memory_allocator.close() self._handler.close() logger.info("MaruBackend closed.") diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index abb383b74e..a0a085af39 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -176,7 +176,7 @@ def backend(mock_handler, adapter, async_loop): backend = MaruBackend.__new__(MaruBackend) backend.dst_device = "cpu" backend.config = MagicMock() - backend.config.maru_pool_size = "4K" + backend.config.maru_pool_size = 4.0 backend.loop = async_loop backend.memory_allocator = adapter backend._handler = mock_handler @@ -212,41 +212,20 @@ def test_get_memory_allocator_returns_adapter(self, backend, adapter): assert backend.get_memory_allocator() is adapter -class TestMaruBackendParsePoolSize: - """Test _parse_pool_size static method.""" +class TestMaruBackendPoolSizeGbToBytes: + """Test _pool_size_gb_to_bytes static method.""" - def test_none_returns_default(self): - result = MaruBackend._parse_pool_size(None) - assert result == 4 * 1024**3 + def test_4gb(self): + assert MaruBackend._pool_size_gb_to_bytes(4.0) == 4 * 1024**3 - def test_parse_gigabytes(self): - assert MaruBackend._parse_pool_size("4G") == 4 * 1024**3 - assert MaruBackend._parse_pool_size("4GB") == 4 * 1024**3 + def test_half_gb(self): + assert MaruBackend._pool_size_gb_to_bytes(0.5) == 512 * 1024**2 - def test_parse_megabytes(self): - assert MaruBackend._parse_pool_size("512M") == 512 * 1024**2 - assert MaruBackend._parse_pool_size("512MB") == 512 * 1024**2 + def test_1gb(self): + assert MaruBackend._pool_size_gb_to_bytes(1.0) == 1024**3 - def test_parse_kilobytes(self): - assert MaruBackend._parse_pool_size("1K") == 1024 - assert MaruBackend._parse_pool_size("1KB") == 1024 - - def test_parse_terabytes(self): - assert MaruBackend._parse_pool_size("1T") == 1024**4 - - def test_parse_plain_integer_string(self): - assert MaruBackend._parse_pool_size("1048576") == 1048576 - - def test_parse_integer_value(self): - assert MaruBackend._parse_pool_size(2048) == 2048 - - def test_parse_invalid_returns_default(self): - result = MaruBackend._parse_pool_size("invalid") - assert result == 4 * 1024**3 - - def test_parse_case_insensitive(self): - assert MaruBackend._parse_pool_size("4g") == 4 * 1024**3 - assert MaruBackend._parse_pool_size("512m") == 512 * 1024**2 + def test_zero(self): + assert MaruBackend._pool_size_gb_to_bytes(0.0) == 0 class TestMaruBackendInterfaceCompliance: @@ -741,17 +720,17 @@ def test_close_calls_handler_and_allocator(self, backend): backend.memory_allocator.close.assert_called_once() backend._handler.close.assert_called_once() - def test_close_with_pending_put_tasks(self, backend, adapter): - """close() should warn but not raise when put tasks are pending.""" + def test_close_drains_pending_put_tasks(self, backend, adapter): + """close() should wait for in-flight put tasks to complete.""" obj = _make_memory_obj(adapter) obj.parent_allocator = None key = _make_cache_key() - # Manually add to put_tasks to simulate in-flight - with backend.put_lock: - backend.put_tasks.add(key) + # Submit a real put task that will complete via the event loop + future = backend.submit_put_task(key, obj) + future.result(timeout=5) - # Should not raise + # After drain, close should succeed backend.close() backend._handler.close.assert_called_once() From f13580a0f46b37f8507b84237ad1c3a7211b0ac5 Mon Sep 17 00:00:00 2001 From: hyunyul-XCENA Date: Fri, 20 Mar 2026 08:44:08 +0000 Subject: [PATCH 20/21] docs: tcp -> maru --- docs/source/kv_cache/storage_backends/maru.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst index b91920f986..8430e6ee3f 100644 --- a/docs/source/kv_cache/storage_backends/maru.rst +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -48,7 +48,7 @@ Deploy Model With Maru save_unfull_chunk: True # Maru backend - maru_path: "tcp://localhost:5555" + maru_path: "maru://localhost:5555" maru_pool_size: 4 **3. Start vLLM with Maru** @@ -76,7 +76,7 @@ Configuration - Description * - ``maru_path`` - Required - - Maru server URL (format: ``tcp://host:port``) + - Maru server URL (format: ``maru://host:port``) * - ``maru_pool_size`` - ``4.0`` - CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``) From cdb7133b61c263251118242e158360dbb27ad991 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Fri, 20 Mar 2026 08:56:36 +0000 Subject: [PATCH 21/21] tests: add store failure ref_count_down tests for MaruBackend Verify that ref_count returns to pre-submit level and put_tasks are cleaned up when handler.store / batch_store raises an exception. --- tests/v1/storage_backend/test_maru_backend.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index a0a085af39..d2b3f3fe1b 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -388,6 +388,44 @@ def test_submit_put_task_skips_in_mla_mode(self, backend, adapter): assert future.result(timeout=5) is None backend._handler.store.assert_not_called() + def test_submit_put_task_refcount_down_on_failure(self, backend, adapter): + """On store failure, ref_count should return to pre-submit level.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + initial_ref = obj.get_ref_count() + + backend._handler.store.side_effect = RuntimeError("store failed") + + future = backend.submit_put_task(key, obj) + with pytest.raises(RuntimeError): + future.result(timeout=5) + + assert obj.get_ref_count() == initial_ref + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task_refcount_down_on_failure( + self, backend, adapter + ): + """On batch_store failure, ref_count should return to pre-submit level.""" + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + initial_refs = [obj.get_ref_count() for obj in objs] + + backend._handler.batch_store.side_effect = RuntimeError("batch failed") + + futures = backend.batched_submit_put_task(keys, objs) + for future in futures: + with pytest.raises(RuntimeError): + future.result(timeout=5) + + for obj, initial_ref in zip(objs, initial_refs): + assert obj.get_ref_count() == initial_ref + for key in keys: + assert not backend.exists_in_put_tasks(key) + def test_batched_submit_put_task_skips_in_mla_mode(self, backend, adapter): """In MLA worker_id_as0 mode, batched_submit_put_task should skip.""" backend._mla_worker_id_as0_mode = True