Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion lmcache/v1/storage_backend/maru_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import asyncio
import re
import threading
import time

# Third Party
from maru import MaruConfig, MaruHandler
from maru_lmcache import CxlMemoryAdapter
import prometheus_client
import torch

# First Party
Expand Down Expand Up @@ -89,9 +91,48 @@ def __init__(
self.put_lock = threading.Lock()
self.put_tasks: set[CacheEngineKey] = set()

self._setup_metrics()

def __str__(self) -> str:
return self.__class__.__name__

def _setup_metrics(self):
"""Create Prometheus metrics for this MaruBackend instance.

Uses multiprocess_mode on Gauge so that metrics are visible
across process boundaries when PROMETHEUS_MULTIPROC_DIR is set
(required for vLLM V1 multi-process architecture).
"""
self._maru_put_task_gauge = prometheus_client.Gauge(
"lmcache:maru_put_task_num",
"Number of in-flight Maru put tasks",
multiprocess_mode="livemostrecent",
)
self._maru_put_task_gauge.set_function(lambda: len(self.put_tasks))

self._maru_put_failed = prometheus_client.Counter(
"lmcache:maru_put_failed_count",
"Total Maru put failures",
)
self._maru_get_failed = prometheus_client.Counter(
"lmcache:maru_get_blocking_failed_count",
"Total Maru get_blocking failures",
)
self._maru_alloc_failed = prometheus_client.Counter(
"lmcache:maru_alloc_failed_count",
"Total Maru CXL memory allocation failures",
)
self._maru_store_latency = prometheus_client.Histogram(
"lmcache:maru_store_latency_seconds",
"Maru store RPC latency in seconds",
buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0],
)
self._maru_retrieve_latency = prometheus_client.Histogram(
"lmcache:maru_retrieve_latency_seconds",
"Maru retrieve (CXL read) latency in seconds",
buckets=[0.0001, 0.0005, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5],
)

@staticmethod
def _parse_pool_size(raw: Optional[str]) -> int:
"""Parse human-readable pool size (e.g. '4G', '512M') to bytes."""
Expand Down Expand Up @@ -136,7 +177,7 @@ def _create_handler(
# Convert maru:// scheme to tcp:// for ZMQ
server_url = config.maru_path
if server_url.startswith("maru://"):
server_url = "tcp://" + server_url[len("maru://"):]
server_url = "tcp://" + server_url[len("maru://") :]

extra = config.extra_config or {}
maru_config = MaruConfig(
Expand Down Expand Up @@ -221,6 +262,7 @@ def allocate(
*CxlMemoryAdapter.decode_address(obj.metadata.address),
)
else:
self._maru_alloc_failed.inc()
logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes)
return obj

Expand Down Expand Up @@ -343,7 +385,9 @@ async def _async_store(
handle = allocator.create_store_handle(memory_obj)
key_str = key.to_string()

t0 = time.perf_counter()
await asyncio.to_thread(self._handler.store, key_str, handle)
self._maru_store_latency.observe(time.perf_counter() - t0)
success = True

logger.debug(
Expand All @@ -354,6 +398,7 @@ async def _async_store(
)

except Exception as e:
self._maru_put_failed.inc()
logger.error("[Maru] store failed key=%s: %s", key, e)
finally:
with self.put_lock:
Expand Down Expand Up @@ -387,9 +432,11 @@ def get_blocking(
if self._mla_worker_id_as0_mode:
key = key.with_new_worker_id(0)

t0 = time.perf_counter()
key_str = key.to_string()
mem_info = self._handler.retrieve(key_str)
if mem_info is None:
self._maru_get_failed.inc()
logger.debug("[Maru] get_blocking miss key=%s", key)
return None

Expand All @@ -403,6 +450,7 @@ def get_blocking(
single_token_size=self._single_token_size,
)
if memory_obj is None:
self._maru_get_failed.inc()
logger.debug(
"[Maru] get_blocking pool miss rid=%d pid=%d",
mem_info.region_id,
Expand All @@ -412,6 +460,7 @@ def get_blocking(

memory_obj.ref_count_up()
memory_obj.pin()
self._maru_retrieve_latency.observe(time.perf_counter() - t0)

logger.debug(
"[Maru] get_blocking rid=%d pid=%d size=%d",
Expand Down
Loading