Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/kv_cache/storage_backends/maru.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Deploy Model With Maru

# Maru backend
maru_path: "tcp://localhost:5555"
maru_pool_size: 4G
maru_pool_size: 4

**3. Start vLLM with Maru**

Expand Down Expand Up @@ -78,8 +78,8 @@ Configuration
- Required
- Maru server URL (format: ``tcp://host:port``)
* - ``maru_pool_size``
- ``"4G"``
- CXL memory pool size per instance (e.g., ``"4G"``, ``"500M"``)
- ``4.0``
- CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``)

**Advanced Parameters (via extra_config):**

Expand Down
6 changes: 3 additions & 3 deletions lmcache/v1/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@
# Maru CXL shared memory backend
"maru_path": {"type": Optional[str], "default": None, "env_converter": str},
"maru_pool_size": {
"type": Optional[str],
"default": None,
"env_converter": str,
"type": float,
"default": 4.0,
"env_converter": float,
},
# Other configurations
# (Deprecated) The url of the actual remote lmcache instance for auditing.
Expand Down
38 changes: 11 additions & 27 deletions lmcache/v1/storage_backend/maru_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from concurrent.futures import Future
from typing import Any, Callable, List, Optional, Sequence, Union
import asyncio
import re
import threading
import time

# Third Party
from maru import MaruConfig, MaruHandler
Expand Down Expand Up @@ -93,24 +93,9 @@ def __str__(self) -> str:
return self.__class__.__name__

@staticmethod
def _parse_pool_size(raw: Optional[str]) -> int:
"""Parse human-readable pool size (e.g. '4G', '512M') to bytes."""
_DEFAULT = 4 * 1024**3
if raw is None:
return _DEFAULT
if isinstance(raw, (int, float)):
return int(raw)
s = str(raw).strip().upper()
match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", s)
if not match:
try:
return int(s)
except ValueError:
logger.warning("Cannot parse maru_pool_size=%r, using default", raw)
return _DEFAULT
value, unit = float(match.group(1)), match.group(2)
multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}
return int(value * multipliers.get(unit, 1))
def _pool_size_gb_to_bytes(size_gb: float) -> int:
"""Convert pool size in GB to bytes."""
return int(size_gb * 1024**3)

# =========================================================================
# Initialization helpers
Expand Down Expand Up @@ -142,7 +127,7 @@ def _create_handler(
maru_config = MaruConfig(
server_url=server_url,
instance_id=extra.get("maru_instance_id"),
pool_size=self._parse_pool_size(config.maru_pool_size),
pool_size=self._pool_size_gb_to_bytes(config.maru_pool_size),
chunk_size_bytes=self._full_chunk_size_bytes,
auto_connect=False,
timeout_ms=extra.get("maru_timeout_ms", 5000),
Expand Down Expand Up @@ -739,13 +724,12 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool:

def close(self) -> None:
"""Close the backend and underlying MaruHandler."""
with self.put_lock:
pending = len(self.put_tasks)
if pending > 0:
logger.warning(
"[Maru] closing with %d in-flight put tasks still pending",
pending,
)
while True:
with self.put_lock:
if not self.put_tasks:
break
time.sleep(0.1)

self.memory_allocator.close()
self._handler.close()
logger.info("MaruBackend closed.")
55 changes: 17 additions & 38 deletions tests/v1/storage_backend/test_maru_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def backend(mock_handler, adapter, async_loop):
backend = MaruBackend.__new__(MaruBackend)
backend.dst_device = "cpu"
backend.config = MagicMock()
backend.config.maru_pool_size = "4K"
backend.config.maru_pool_size = 4.0
backend.loop = async_loop
backend.memory_allocator = adapter
backend._handler = mock_handler
Expand Down Expand Up @@ -212,41 +212,20 @@ def test_get_memory_allocator_returns_adapter(self, backend, adapter):
assert backend.get_memory_allocator() is adapter


class TestMaruBackendParsePoolSize:
"""Test _parse_pool_size static method."""
class TestMaruBackendPoolSizeGbToBytes:
"""Test _pool_size_gb_to_bytes static method."""

def test_none_returns_default(self):
result = MaruBackend._parse_pool_size(None)
assert result == 4 * 1024**3
def test_4gb(self):
assert MaruBackend._pool_size_gb_to_bytes(4.0) == 4 * 1024**3

def test_parse_gigabytes(self):
assert MaruBackend._parse_pool_size("4G") == 4 * 1024**3
assert MaruBackend._parse_pool_size("4GB") == 4 * 1024**3
def test_half_gb(self):
assert MaruBackend._pool_size_gb_to_bytes(0.5) == 512 * 1024**2

def test_parse_megabytes(self):
assert MaruBackend._parse_pool_size("512M") == 512 * 1024**2
assert MaruBackend._parse_pool_size("512MB") == 512 * 1024**2
def test_1gb(self):
assert MaruBackend._pool_size_gb_to_bytes(1.0) == 1024**3

def test_parse_kilobytes(self):
assert MaruBackend._parse_pool_size("1K") == 1024
assert MaruBackend._parse_pool_size("1KB") == 1024

def test_parse_terabytes(self):
assert MaruBackend._parse_pool_size("1T") == 1024**4

def test_parse_plain_integer_string(self):
assert MaruBackend._parse_pool_size("1048576") == 1048576

def test_parse_integer_value(self):
assert MaruBackend._parse_pool_size(2048) == 2048

def test_parse_invalid_returns_default(self):
result = MaruBackend._parse_pool_size("invalid")
assert result == 4 * 1024**3

def test_parse_case_insensitive(self):
assert MaruBackend._parse_pool_size("4g") == 4 * 1024**3
assert MaruBackend._parse_pool_size("512m") == 512 * 1024**2
def test_zero(self):
assert MaruBackend._pool_size_gb_to_bytes(0.0) == 0


class TestMaruBackendInterfaceCompliance:
Expand Down Expand Up @@ -741,17 +720,17 @@ def test_close_calls_handler_and_allocator(self, backend):
backend.memory_allocator.close.assert_called_once()
backend._handler.close.assert_called_once()

def test_close_with_pending_put_tasks(self, backend, adapter):
"""close() should warn but not raise when put tasks are pending."""
def test_close_drains_pending_put_tasks(self, backend, adapter):
"""close() should wait for in-flight put tasks to complete."""
obj = _make_memory_obj(adapter)
obj.parent_allocator = None
key = _make_cache_key()

# Manually add to put_tasks to simulate in-flight
with backend.put_lock:
backend.put_tasks.add(key)
# Submit a real put task that will complete via the event loop
future = backend.submit_put_task(key, obj)
future.result(timeout=5)

# Should not raise
# After drain, close should succeed
backend.close()
backend._handler.close.assert_called_once()

Expand Down
Loading