diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py index ef4ef8a21b..b2f80eadd3 100644 --- a/lmcache/v1/storage_backend/maru_backend.py +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -136,10 +136,10 @@ def _create_handler( # Convert maru:// scheme to tcp:// for ZMQ server_url = config.maru_path if server_url.startswith("maru://"): - server_url = "tcp://" + server_url[len("maru://"):] + server_url = "tcp://" + server_url[len("maru://") :] extra = config.extra_config or {} - maru_config = MaruConfig( + maru_kwargs = dict( server_url=server_url, instance_id=extra.get("maru_instance_id"), pool_size=self._parse_pool_size(config.maru_pool_size), @@ -150,6 +150,26 @@ def _create_handler( max_inflight=extra.get("maru_max_inflight", 64), eager_map=extra.get("maru_eager_map", True), ) + pool_id = extra.get("maru_pool_id") + if pool_id is not None: + try: + if isinstance(pool_id, list): + if pool_id: + maru_kwargs["pool_id"] = [int(p) for p in pool_id] + elif isinstance(pool_id, str): + stripped = pool_id.strip() + if stripped: + if "," in stripped: + maru_kwargs["pool_id"] = [ + int(p.strip()) for p in stripped.split(",") if p.strip() + ] + else: + maru_kwargs["pool_id"] = int(stripped) + else: + maru_kwargs["pool_id"] = int(pool_id) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid maru_pool_id={pool_id!r}: {e}") from e + maru_config = MaruConfig(**maru_kwargs) handler = MaruHandler(maru_config) if not handler.connect():