diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst index 3cd4cc7361..b91920f986 100644 --- a/docs/source/kv_cache/storage_backends/maru.rst +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -49,7 +49,7 @@ Deploy Model With Maru # Maru backend maru_path: "tcp://localhost:5555" - maru_pool_size: 4G + maru_pool_size: 4 **3. Start vLLM with Maru** @@ -78,8 +78,8 @@ Configuration - Required - Maru server URL (format: ``tcp://host:port``) * - ``maru_pool_size`` - - ``"4G"`` - - CXL memory pool size per instance (e.g., ``"4G"``, ``"500M"``) + - ``4.0`` + - CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``) **Advanced Parameters (via extra_config):** diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index 3e51907cec..20bbc19276 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -239,9 +239,9 @@ # Maru CXL shared memory backend "maru_path": {"type": Optional[str], "default": None, "env_converter": str}, "maru_pool_size": { - "type": Optional[str], - "default": None, - "env_converter": str, + "type": float, + "default": 4.0, + "env_converter": float, }, # Other configurations # (Deprecated) The url of the actual remote lmcache instance for auditing. diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index 99a9319cf9..21b343229d 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -4,8 +4,8 @@ from concurrent.futures import Future from typing import Any, Callable, List, Optional, Sequence, Union import asyncio -import re import threading +import time # Third Party from maru import MaruConfig, MaruHandler @@ -93,24 +93,9 @@ def __str__(self) -> str: return self.__class__.__name__ @staticmethod - def _parse_pool_size(raw: Optional[str]) -> int: - """Parse human-readable pool size (e.g. '4G', '512M') to bytes.""" - _DEFAULT = 4 * 1024**3 - if raw is None: - return _DEFAULT - if isinstance(raw, (int, float)): - return int(raw) - s = str(raw).strip().upper() - match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", s) - if not match: - try: - return int(s) - except ValueError: - logger.warning("Cannot parse maru_pool_size=%r, using default", raw) - return _DEFAULT - value, unit = float(match.group(1)), match.group(2) - multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} - return int(value * multipliers.get(unit, 1)) + def _pool_size_gb_to_bytes(size_gb: float) -> int: + """Convert pool size in GB to bytes.""" + return int(size_gb * 1024**3) # ========================================================================= # Initialization helpers @@ -142,7 +127,7 @@ def _create_handler( maru_config = MaruConfig( server_url=server_url, instance_id=extra.get("maru_instance_id"), - pool_size=self._parse_pool_size(config.maru_pool_size), + pool_size=self._pool_size_gb_to_bytes(config.maru_pool_size), chunk_size_bytes=self._full_chunk_size_bytes, auto_connect=False, timeout_ms=extra.get("maru_timeout_ms", 5000), @@ -739,13 +724,12 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: def close(self) -> None: """Close the backend and underlying MaruHandler.""" - with self.put_lock: - pending = len(self.put_tasks) - if pending > 0: - logger.warning( - "[Maru] closing with %d in-flight put tasks still pending", - pending, - ) + while True: + with self.put_lock: + if not self.put_tasks: + break + time.sleep(0.1) + self.memory_allocator.close() self._handler.close() logger.info("MaruBackend closed.") diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py index abb383b74e..a0a085af39 100644 --- a/tests/v1/storage_backend/test_maru_backend.py +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -176,7 +176,7 @@ def backend(mock_handler, adapter, async_loop): backend = MaruBackend.__new__(MaruBackend) backend.dst_device = "cpu" backend.config = MagicMock() - backend.config.maru_pool_size = "4K" + backend.config.maru_pool_size = 4.0 backend.loop = async_loop backend.memory_allocator = adapter backend._handler = mock_handler @@ -212,41 +212,20 @@ def test_get_memory_allocator_returns_adapter(self, backend, adapter): assert backend.get_memory_allocator() is adapter -class TestMaruBackendParsePoolSize: - """Test _parse_pool_size static method.""" +class TestMaruBackendPoolSizeGbToBytes: + """Test _pool_size_gb_to_bytes static method.""" - def test_none_returns_default(self): - result = MaruBackend._parse_pool_size(None) - assert result == 4 * 1024**3 + def test_4gb(self): + assert MaruBackend._pool_size_gb_to_bytes(4.0) == 4 * 1024**3 - def test_parse_gigabytes(self): - assert MaruBackend._parse_pool_size("4G") == 4 * 1024**3 - assert MaruBackend._parse_pool_size("4GB") == 4 * 1024**3 + def test_half_gb(self): + assert MaruBackend._pool_size_gb_to_bytes(0.5) == 512 * 1024**2 - def test_parse_megabytes(self): - assert MaruBackend._parse_pool_size("512M") == 512 * 1024**2 - assert MaruBackend._parse_pool_size("512MB") == 512 * 1024**2 + def test_1gb(self): + assert MaruBackend._pool_size_gb_to_bytes(1.0) == 1024**3 - def test_parse_kilobytes(self): - assert MaruBackend._parse_pool_size("1K") == 1024 - assert MaruBackend._parse_pool_size("1KB") == 1024 - - def test_parse_terabytes(self): - assert MaruBackend._parse_pool_size("1T") == 1024**4 - - def test_parse_plain_integer_string(self): - assert MaruBackend._parse_pool_size("1048576") == 1048576 - - def test_parse_integer_value(self): - assert MaruBackend._parse_pool_size(2048) == 2048 - - def test_parse_invalid_returns_default(self): - result = MaruBackend._parse_pool_size("invalid") - assert result == 4 * 1024**3 - - def test_parse_case_insensitive(self): - assert MaruBackend._parse_pool_size("4g") == 4 * 1024**3 - assert MaruBackend._parse_pool_size("512m") == 512 * 1024**2 + def test_zero(self): + assert MaruBackend._pool_size_gb_to_bytes(0.0) == 0 class TestMaruBackendInterfaceCompliance: @@ -741,17 +720,17 @@ def test_close_calls_handler_and_allocator(self, backend): backend.memory_allocator.close.assert_called_once() backend._handler.close.assert_called_once() - def test_close_with_pending_put_tasks(self, backend, adapter): - """close() should warn but not raise when put tasks are pending.""" + def test_close_drains_pending_put_tasks(self, backend, adapter): + """close() should wait for in-flight put tasks to complete.""" obj = _make_memory_obj(adapter) obj.parent_allocator = None key = _make_cache_key() - # Manually add to put_tasks to simulate in-flight - with backend.put_lock: - backend.put_tasks.add(key) + # Submit a real put task that will complete via the event loop + future = backend.submit_put_task(key, obj) + future.result(timeout=5) - # Should not raise + # After drain, close should succeed backend.close() backend._handler.close.assert_called_once()