diff --git a/csrc/storage_backends/README.md b/csrc/storage_backends/README.md index 6ce691d084..c50f7ee6da 100644 --- a/csrc/storage_backends/README.md +++ b/csrc/storage_backends/README.md @@ -49,7 +49,7 @@ Therefore the framework enforces: |------|---------| | `connector_types.h` | `Request`, `Completion`, `BatchState`, `Op` | | `connector_interface.h` | `IStorageConnector` — top-level abstract interface | -| `connector_base.h` | `ConnectorBase` — core harness (eventfd, SQ/CQ, threading, tiling). Override 4 methods per backend | +| `connector_base.h` | `ConnectorBase` — core harness (eventfd, SQ/CQ, threading, tiling). Override 4 required + 1 optional method per backend | | `connector_pybind_utils.h` | Pybind utilities with GIL release + `LMCACHE_BIND_CONNECTOR_METHODS` macro | | `redis/` | Reference implementation (RESP2 protocol over TCP) | @@ -61,8 +61,8 @@ each step. ### Step 1: C++ connector — inherit from ConnectorBase Create your connector directory (e.g., `csrc/storage_backends/mybackend/`) -and inherit from `ConnectorBase`. You only need to -override 4 methods: +and inherit from `ConnectorBase`. You need to +override 4 required methods (and optionally `do_single_delete` for eviction): ```cpp // csrc/storage_backends/mybackend/connector.h @@ -105,6 +105,11 @@ class MyConnector : public lmcache::connector::ConnectorBase { // send EXISTS, return true/false } + // Optional: delete a key (enables eviction support) + bool do_single_delete(MyConn& conn, const std::string& key) override { + // send DELETE, return true if deleted, false if not found + } + // Optional: clean shutdown of connections void shutdown_connections() override { /* close sockets */ } @@ -275,7 +280,7 @@ Python eventfd. 
## Checklist for a new backend -- [ ] C++ connector inheriting `ConnectorBase` with 4 method overrides +- [ ] C++ connector inheriting `ConnectorBase` with 4 required + 1 optional (`do_single_delete`) method overrides - [ ] Pybind module using `LMCACHE_BIND_CONNECTOR_METHODS` - [ ] `setup.py` entry for the new `CppExtension` - [ ] Python client inheriting `ConnectorClientBase` (non-MP mode) diff --git a/csrc/storage_backends/connector_base.h b/csrc/storage_backends/connector_base.h index 11c1ead13c..ecf1229d83 100644 --- a/csrc/storage_backends/connector_base.h +++ b/csrc/storage_backends/connector_base.h @@ -25,6 +25,9 @@ this base needs to have at least four methods be overridden by the derived - 3. do_single_set() - 4. do_single_exists() +optionally override do_single_delete() to support eviction (default returns +false for all keys). + see the RedisConnector (csrc/redis/) implementing the RESP2 protocol over TCP for an example */ @@ -130,6 +133,39 @@ class ConnectorBase : public IStorageConnector { return batch_future_id; } + uint64_t submit_batch_delete(const std::vector& keys) override { + if (keys.empty()) { + throw std::runtime_error("keys list is empty"); + } + + size_t num_items = keys.size(); + auto [batch_future_id, batch_state, num_tiles, tile_size] = + prepare_batch_operation(num_items, Op::BATCH_TILE_DELETE); + + // pre-allocate per-key results (1 = deleted, 0 = not found) + batch_state->per_key_results.assign(num_items, 0); + + // fan out work to threads + for (size_t tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + size_t start = tile_idx * tile_size; + size_t end = std::min(start + tile_size, num_items); + + Request tile_req; + tile_req.op = Op::BATCH_TILE_DELETE; + tile_req.future_id = batch_future_id; + tile_req.batch = batch_state; + tile_req.start_idx = start; + + for (size_t i = start; i < end; ++i) { + tile_req.keys.push_back(keys[i]); + } + + enqueue_request(std::move(tile_req)); + } + + return batch_future_id; + } + std::vector 
drain_completions() override { // Drain the eventfd that triggered this drain_completions callback drain_eventfd_(); @@ -216,6 +252,11 @@ class ConnectorBase : public IStorageConnector { size_t chunk_size) = 0; virtual bool do_single_exists(ConnectionType& conn, const std::string& key) = 0; + virtual bool do_single_delete(ConnectionType& conn, const std::string& key) { + (void)conn; + (void)key; + return false; // no-op default for backward compat with plugins + } virtual void shutdown_connections() {} bool is_stopping() const { return stop_.load(std::memory_order_acquire); } @@ -393,6 +434,23 @@ class ConnectorBase : public IStorageConnector { } comp.ok = true; break; + + case Op::BATCH_TILE_DELETE: + for (size_t i = 0; i < req.keys.size(); ++i) { + try { + bool deleted = do_single_delete(conn, req.keys[i]); + req.batch->per_key_results[req.start_idx + i] = + deleted ? 1 : 0; + } catch (const std::exception& e) { + // Per-key error tolerance: record failure + // but continue processing remaining keys + req.batch->per_key_results[req.start_idx + i] = 0; + fprintf(stderr, "[LMCache DELETE] key %s failed: %s\n", + req.keys[i].c_str(), e.what()); + } + } + comp.ok = true; + break; } } catch (const std::exception& e) { comp.ok = false; @@ -438,7 +496,8 @@ class ConnectorBase : public IStorageConnector { } // for batch exists and batch get, move per-key results if (req.batch->batch_op == Op::BATCH_TILE_EXISTS || - req.batch->batch_op == Op::BATCH_TILE_GET) { + req.batch->batch_op == Op::BATCH_TILE_GET || + req.batch->batch_op == Op::BATCH_TILE_DELETE) { batch_comp.result_bytes = std::move(req.batch->per_key_results); } push_completion(std::move(batch_comp)); diff --git a/csrc/storage_backends/connector_interface.h b/csrc/storage_backends/connector_interface.h index 9a8ac3ec9f..9f5de15aa8 100644 --- a/csrc/storage_backends/connector_interface.h +++ b/csrc/storage_backends/connector_interface.h @@ -80,6 +80,24 @@ class IStorageConnector { virtual uint64_t 
submit_batch_exists( const std::vector& keys) = 0; + /* + submit a batch DELETE operation + + deletes multiple keys in parallel. work is automatically divided + among worker threads (tiling). returns a single future_id for the entire + batch. + + args: + keys: vector of key strings to delete + + returns: + uint64_t: future id for tracking this batch operation + completion will contain result_bytes vector with 0/1 for each key + (1 = deleted, 0 = not found) + */ + virtual uint64_t submit_batch_delete( + const std::vector& keys) = 0; + /* drain all available completions diff --git a/csrc/storage_backends/connector_pybind_utils.h b/csrc/storage_backends/connector_pybind_utils.h index 7e0c462317..f8c18563f6 100644 --- a/csrc/storage_backends/connector_pybind_utils.h +++ b/csrc/storage_backends/connector_pybind_utils.h @@ -36,6 +36,10 @@ example usage (see `redis/pybind.cpp`): lmcache::connector::pybind_utils::bind_submit_batch_exists< \ ConnectorType>(), \ py::arg("keys")) \ + .def("submit_batch_delete", \ + lmcache::connector::pybind_utils::bind_submit_batch_delete< \ + ConnectorType>(), \ + py::arg("keys")) \ .def("drain_completions", \ lmcache::connector::pybind_utils::bind_drain_completions< \ ConnectorType>()) \ @@ -113,6 +117,14 @@ auto bind_submit_batch_exists() { }; } +template +auto bind_submit_batch_delete() { + return [](ConnectorType& self, const std::vector& keys) { + py::gil_scoped_release release; + return self.submit_batch_delete(keys); + }; +} + template auto bind_drain_completions() { return [](ConnectorType& self) { diff --git a/csrc/storage_backends/connector_types.h b/csrc/storage_backends/connector_types.h index e77a8665ac..35b5d75987 100644 --- a/csrc/storage_backends/connector_types.h +++ b/csrc/storage_backends/connector_types.h @@ -23,7 +23,12 @@ namespace connector { // we only support batched operations // benefits are fewer submissions and fewer completions -enum class Op : uint8_t { BATCH_TILE_GET, BATCH_TILE_SET, BATCH_TILE_EXISTS }; 
+enum class Op : uint8_t { + BATCH_TILE_GET, + BATCH_TILE_SET, + BATCH_TILE_EXISTS, + BATCH_TILE_DELETE +}; /* shared communication state between threads executing a single batch operation. diff --git a/csrc/storage_backends/fs/connector.cpp b/csrc/storage_backends/fs/connector.cpp index db056a1e8c..0b12e2d1e4 100644 --- a/csrc/storage_backends/fs/connector.cpp +++ b/csrc/storage_backends/fs/connector.cpp @@ -271,5 +271,12 @@ bool FSConnector::do_single_exists(WorkerFSConn& conn, const std::string& key) { return std::filesystem::exists(file_path); } +bool FSConnector::do_single_delete(WorkerFSConn& conn, const std::string& key) { + std::string filename = key_to_filename(key); + auto file_path = conn.base_path / filename; + std::error_code ec; + return std::filesystem::remove(file_path, ec); +} + } // namespace connector } // namespace lmcache diff --git a/csrc/storage_backends/fs/connector.h b/csrc/storage_backends/fs/connector.h index 02f95474ab..8f7fc7c8f6 100644 --- a/csrc/storage_backends/fs/connector.h +++ b/csrc/storage_backends/fs/connector.h @@ -46,6 +46,7 @@ class FSConnector : public ConnectorBase { void do_single_set(WorkerFSConn& conn, const std::string& key, const void* buf, size_t len, size_t chunk_size) override; bool do_single_exists(WorkerFSConn& conn, const std::string& key) override; + bool do_single_delete(WorkerFSConn& conn, const std::string& key) override; private: // Build the filesystem-safe filename from a serialized key string. 
diff --git a/csrc/storage_backends/mooncake/connector.cpp b/csrc/storage_backends/mooncake/connector.cpp new file mode 100644 index 0000000000..6aaa92517a --- /dev/null +++ b/csrc/storage_backends/mooncake/connector.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "connector.h" + +#include +#include +#include +#include + +namespace lmcache { +namespace connector { + +MooncakeConnector::MooncakeConnector(ConfigDict config, int num_workers) + : ConnectorBase(num_workers), config_(std::move(config)) { + // Create a RealClient via the static factory. + client_ = mooncake::RealClient::create(); + if (!client_) { + throw std::runtime_error("Failed to create mooncake RealClient"); + } + + // Forward the config dict to setup_internal(). + mooncake::ConfigDict mc_config(config_.begin(), config_.end()); + auto result = client_->setup_internal(mc_config); + if (!result.has_value()) { + throw std::runtime_error("Mooncake setup_internal failed"); + } + + start_workers(); // IMPORTANT: call at END of ctor +} + +MooncakeConnector::~MooncakeConnector() { + close(); + if (client_) { + client_->tearDownAll(); + client_.reset(); + } +} + +WorkerMooncakeConn MooncakeConnector::create_connection() { + WorkerMooncakeConn conn; + conn.client = client_.get(); + return conn; +} + +void MooncakeConnector::do_single_get(WorkerMooncakeConn& conn, + const std::string& key, void* buf, + size_t len, size_t chunk_size) { + int64_t bytes_read = conn.client->get_into(key, buf, len); + if (bytes_read < 0) { + throw std::runtime_error("Mooncake get_into failed for key: " + key); + } +} + +void MooncakeConnector::do_single_set(WorkerMooncakeConn& conn, + const std::string& key, const void* buf, + size_t len, size_t chunk_size) { + int rc = conn.client->put_from(key, const_cast(buf), len); + if (rc != 0) { + throw std::runtime_error("Mooncake put_from failed for key: " + key); + } +} + +bool MooncakeConnector::do_single_exists(WorkerMooncakeConn& conn, + const std::string& key) 
{ + // isExist returns: 1=exists, 0=not, -1=error + int result = conn.client->isExist(key); + if (result < 0) { + throw std::runtime_error("Mooncake isExist failed for key: " + key); + } + return result == 1; +} + +} // namespace connector +} // namespace lmcache diff --git a/csrc/storage_backends/mooncake/connector.h b/csrc/storage_backends/mooncake/connector.h new file mode 100644 index 0000000000..eb4a858b2c --- /dev/null +++ b/csrc/storage_backends/mooncake/connector.h @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include "../connector_base.h" +#include "real_client.h" + +#include +#include +#include +#include +#include + +namespace lmcache { +namespace connector { + +// ConfigDict mirrors mooncake::ConfigDict +// (std::unordered_map). +using ConfigDict = std::unordered_map; + +// Per-worker connection state for the Mooncake connector. +// Each worker holds a raw pointer to the shared +// RealClient (owned by MooncakeConnector). +struct WorkerMooncakeConn { + mooncake::RealClient* client{nullptr}; +}; + +class MooncakeConnector : public ConnectorBase { + public: + MooncakeConnector(ConfigDict config, int num_workers); + ~MooncakeConnector() override; + + protected: + WorkerMooncakeConn create_connection() override; + + void do_single_get(WorkerMooncakeConn& conn, const std::string& key, + void* buf, size_t len, size_t chunk_size) override; + + void do_single_set(WorkerMooncakeConn& conn, const std::string& key, + const void* buf, size_t len, size_t chunk_size) override; + + bool do_single_exists(WorkerMooncakeConn& conn, + const std::string& key) override; + + private: + // Shared Mooncake RealClient instance. + std::shared_ptr client_; + + // The original config dict (kept for diagnostics). 
+ ConfigDict config_; +}; + +} // namespace connector +} // namespace lmcache \ No newline at end of file diff --git a/csrc/storage_backends/mooncake/pybind.cpp b/csrc/storage_backends/mooncake/pybind.cpp new file mode 100644 index 0000000000..44c567e368 --- /dev/null +++ b/csrc/storage_backends/mooncake/pybind.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +#include +#include +#include "../connector_pybind_utils.h" +#include "connector.h" + +namespace py = pybind11; + +PYBIND11_MODULE(lmcache_mooncake, m) { + py::class_(m, "LMCacheMooncakeClient") + .def(py::init(), py::arg("config"), + py::arg("num_workers")) + LMCACHE_BIND_CONNECTOR_METHODS(lmcache::connector::MooncakeConnector); +} diff --git a/csrc/storage_backends/redis/connector.cpp b/csrc/storage_backends/redis/connector.cpp index 759ed191a6..320536c546 100644 --- a/csrc/storage_backends/redis/connector.cpp +++ b/csrc/storage_backends/redis/connector.cpp @@ -338,6 +338,32 @@ bool RedisConnector::do_single_exists(WorkerConn& conn, } } +// RESP DEL +bool RedisConnector::do_single_delete(WorkerConn& conn, + const std::string& key) { + // build key header using reusable buffer + const std::string& key_header = conn.make_key_header(key); + + // send DEL cmd + conn.send_multipart({{conn.del_prefix.data(), conn.del_prefix.size()}, + {key_header.data(), key_header.size()}}); + + // parse response (either :0\r\n or :1\r\n, same format as EXISTS) + char response[WorkerConn::exists_response_len]; + conn.recv_exactly(response, WorkerConn::exists_response_len); + + if (std::memcmp(response, WorkerConn::exists_one.data(), + WorkerConn::exists_response_len) == 0) { + return true; // key was deleted + } else if (std::memcmp(response, WorkerConn::exists_zero.data(), + WorkerConn::exists_response_len) == 0) { + return false; // key did not exist + } else { + throw std::runtime_error( + "DEL returned invalid response that wasn't :0\r\n or :1\r\n"); + } +} + void RedisConnector::shutdown_connections() { 
std::lock_guard lk(worker_fds_mu_); for (int fd : worker_fds_) { diff --git a/csrc/storage_backends/redis/connector.h b/csrc/storage_backends/redis/connector.h index 0fb85dec76..778c7e8e18 100644 --- a/csrc/storage_backends/redis/connector.h +++ b/csrc/storage_backends/redis/connector.h @@ -34,6 +34,7 @@ struct WorkerConn { std::string get_prefix; std::string set_prefix; std::string exists_prefix; + std::string del_prefix; // reusable buffers for building headers (avoids repeated dynamic allocations) std::string key_header_buf; @@ -53,7 +54,8 @@ struct WorkerConn { WorkerConn() : get_prefix("*2\r\n$3\r\nGET\r\n"), set_prefix("*3\r\n$3\r\nSET\r\n"), - exists_prefix("*2\r\n$6\r\nEXISTS\r\n") { + exists_prefix("*2\r\n$6\r\nEXISTS\r\n"), + del_prefix("*2\r\n$3\r\nDEL\r\n") { // pre-allocate key_header_buf to handle typical keys without reallocation // typical key format: model_name@world_size@worker_id@chunk_hash_hex@dtype // - model_name: 25-50 chars (e.g., "meta-llama/Llama-3-70b-instruct") @@ -101,6 +103,7 @@ class RedisConnector : public ConnectorBase { void do_single_set(WorkerConn& conn, const std::string& key, const void* buf, size_t len, size_t chunk_size) override; bool do_single_exists(WorkerConn& conn, const std::string& key) override; + bool do_single_delete(WorkerConn& conn, const std::string& key) override; void shutdown_connections() override; private: diff --git a/docs/design/l2_adapters/l2_eviction.md b/docs/design/l2_adapters/l2_eviction.md index af2ea4b8d7..c758a15bd9 100644 --- a/docs/design/l2_adapters/l2_eviction.md +++ b/docs/design/l2_adapters/l2_eviction.md @@ -197,7 +197,32 @@ capacity) can omit steps 2–6 and rely on the base class no-op defaults. 
| `MockL2Adapter` | ✓ | ✓ | stored, deleted | | `NixlStoreL2Adapter` | ✓ (skips pinned) | ✓ (pool-based) | stored, deleted | | `FSL2Adapter` | no-op | `(-1, -1)` | none | -| `NativeConnectorL2Adapter` | no-op | `(-1, -1)` | none | +| `NativeConnectorL2Adapter` | ✓ (via `submit_batch_delete`) | ✓ (client-side, requires `max_capacity_gb`) | stored, deleted | + +**Note on `NativeConnectorL2Adapter`:** Eviction support requires two things: + +1. The underlying C++ connector must implement `do_single_delete()` (built-in Redis + and FS connectors do; third-party plugins may not — in which case `delete()` is a + no-op). +2. The adapter must be configured with `max_capacity_gb > 0` to enable client-side + size tracking for `get_usage()`. Without it, `get_usage()` returns `(-1, -1)` and + the eviction controller will not trigger. + +Example configuration with eviction enabled: + +```json +{ + "type": "resp", + "host": "localhost", + "port": 6379, + "max_capacity_gb": 10, + "eviction": { + "eviction_policy": "LRU", + "trigger_watermark": 0.8, + "eviction_ratio": 0.2 + } +} +``` ## Data Flow: Eviction Cycle diff --git a/docs/source/assets/maru-kvcache.png b/docs/source/assets/maru-kvcache.png new file mode 100644 index 0000000000..137261221d Binary files /dev/null and b/docs/source/assets/maru-kvcache.png differ diff --git a/docs/source/developer_guide/extending_lmcache/native_connectors.rst b/docs/source/developer_guide/extending_lmcache/native_connectors.rst index c8d3974f59..ce7f6b5724 100644 --- a/docs/source/developer_guide/extending_lmcache/native_connectors.rst +++ b/docs/source/developer_guide/extending_lmcache/native_connectors.rst @@ -50,7 +50,8 @@ Step 1: C++ Connector --------------------- Create your connector directory (e.g., ``csrc/storage_backends/mybackend/``) and -inherit from ``ConnectorBase``. You only need to override 4 methods. +inherit from ``ConnectorBase``. 
You need to override 4 required methods +(and optionally ``do_single_delete`` to support eviction). **connector.h:** @@ -104,6 +105,12 @@ inherit from ``ConnectorBase``. You only need to override 4 // send EXISTS, return true/false } + // 5. DELETE: remove key (optional, has default no-op) + bool do_single_delete(MyConn& conn, + const std::string& key) override { + // send DELETE, return true if deleted, false if not found + } + // Optional: clean shutdown void shutdown_connections() override { // close sockets, free resources @@ -136,8 +143,8 @@ inherit from ``ConnectorBase``. You only need to override 4 Step 2: Pybind Module --------------------- -Use the ``LMCACHE_BIND_CONNECTOR_METHODS`` macro, which binds all 6 methods -(``event_fd``, ``submit_batch_get/set/exists``, ``drain_completions``, ``close``) +Use the ``LMCACHE_BIND_CONNECTOR_METHODS`` macro, which binds all 7 methods +(``event_fd``, ``submit_batch_get/set/exists/delete``, ``drain_completions``, ``close``) with proper GIL release and Python buffer protocol handling. .. 
code-block:: cpp @@ -255,10 +262,12 @@ Create a new file in the L2 adapters package: class MyBackendL2AdapterConfig(L2AdapterConfigBase): def __init__(self, host: str, port: int, - num_workers: int = 8): + num_workers: int = 8, + max_capacity_gb: float = 0): self.host = host self.port = port self.num_workers = num_workers + self.max_capacity_gb = max_capacity_gb @classmethod def from_dict(cls, d: dict) -> "MyBackendL2AdapterConfig": @@ -269,8 +278,10 @@ Create a new file in the L2 adapters package: if not isinstance(port, int) or port <= 0: raise ValueError("port must be a positive integer") num_workers = d.get("num_workers", 8) + max_capacity_gb = d.get("max_capacity_gb", 0) return cls(host=host, port=port, - num_workers=num_workers) + num_workers=num_workers, + max_capacity_gb=max_capacity_gb) @classmethod def help(cls) -> str: @@ -296,7 +307,10 @@ Create a new file in the L2 adapters package: native_client = LMCacheMyBackendClient( config.host, config.port, config.num_workers ) - return NativeConnectorL2Adapter(native_client) + return NativeConnectorL2Adapter( + native_client, + max_capacity_gb=config.max_capacity_gb, + ) # Self-register -- runs automatically when the module @@ -417,12 +431,18 @@ pybind ``LMCACHE_BIND_CONNECTOR_METHODS`` contract): self, keys: list[str], ) -> int: ... + def submit_batch_delete( + self, + keys: list[str], + ) -> int: ... def drain_completions( self, ) -> list[tuple[int, bool, str, list[bool] | None]]: ... def close(self) -> None: ... -The factory validates these methods at creation time and raises ``TypeError`` if any are missing. +The factory validates the 6 required methods (every method above except +``submit_batch_delete``) at creation time and raises ``TypeError`` if any are missing. +``submit_batch_delete`` is **optional** -- if absent, the adapter's ``delete()`` method will be a no-op (eviction will not remove keys from the backend). Configuration ~~~~~~~~~~~~~ @@ -459,6 +479,10 @@ Configuration - ``dict`` - no - Forwarded as ``**kwargs`` to the connector class constructor. 
+ * - ``max_capacity_gb`` + - ``float`` + - no + - Maximum L2 storage capacity in GB for client-side usage tracking. Required for L2 eviction. Default 0 (disabled). Loading Flow ~~~~~~~~~~~~ @@ -506,7 +530,7 @@ Step-by-Step: Building an External Native Connector Plugin 2. **Implement the C++ connector** inheriting from ``ConnectorBase`` and override the 4 required methods (``create_connection``, ``do_single_get``, ``do_single_set``, - ``do_single_exists``). + ``do_single_exists``) and optionally ``do_single_delete`` for eviction support. 3. **Create pybind11 bindings** using the ``LMCACHE_BIND_CONNECTOR_METHODS`` macro: @@ -583,7 +607,7 @@ Checklist Use this checklist when adding a new native connector: -1. C++ connector inheriting ``ConnectorBase`` with 4 method overrides +1. C++ connector inheriting ``ConnectorBase`` with 4 required + 1 optional (``do_single_delete``) method overrides 2. Pybind module using ``LMCACHE_BIND_CONNECTOR_METHODS`` 3. ``setup.py`` entry for the new ``CppExtension`` 4. Python client inheriting ``ConnectorClientBase`` (non-MP mode) @@ -594,7 +618,7 @@ Use this checklist when adding a new native connector: For **external** native connector plugins (``native_plugin``): 1. Separate pip-installable package with C++ pybind11 extension -2. Connector class exposing the 6 required methods +2. Connector class exposing the 6 required methods (+ optional ``submit_batch_delete`` for eviction) 3. Python factory class for backend selection 4. ``pip install -e .`` and configure via ``--l2-adapter`` JSON 5. 
Unit tests (see ``examples/lmc_external_native_connector/tests/``) diff --git a/docs/source/kv_cache/storage_backends/index.rst b/docs/source/kv_cache/storage_backends/index.rst index b39521250d..fd20dca516 100644 --- a/docs/source/kv_cache/storage_backends/index.rst +++ b/docs/source/kv_cache/storage_backends/index.rst @@ -16,6 +16,7 @@ Supported Backends gds infinistore local_storage + maru mock mooncake nixl diff --git a/docs/source/kv_cache/storage_backends/maru.rst b/docs/source/kv_cache/storage_backends/maru.rst new file mode 100644 index 0000000000..9ee69d0006 --- /dev/null +++ b/docs/source/kv_cache/storage_backends/maru.rst @@ -0,0 +1,113 @@ +Maru +==== + +.. _maru-overview: + +Overview +-------- + +`Maru `_ is a high-performance KV cache storage engine built on CXL shared memory, +designed for LLM inference scenarios where multiple instances need to share a KV cache with minimal latency. + +.. image:: ../../assets/maru-kvcache.png + :alt: KV Cache Sharing: Without vs With Maru + +For architecture details, see the `Maru documentation `_. + +Quick Start +----------- + +Install Maru: + +.. code-block:: bash + + git clone https://github.com/xcena-dev/maru.git + cd maru + ./install.sh + +This installs ``maru-server``, ``maru-resourced``, and the ``maru`` Python package. + +Deploy Model With Maru +~~~~~~~~~~~~~~~~~~~~~~ + +**Prerequisites:** CXL device (``/dev/dax*``), Python 3.12+, vLLM and LMCache installed. + +**1. Start the Maru Server** + +.. code-block:: bash + + maru-server + +**2. Create configuration file** (``maru-config.yaml``): + +.. code-block:: yaml + + chunk_size: 256 + local_cpu: False + max_local_cpu_size: 0 + save_unfull_chunk: True + + # Maru backend + maru_path: "maru://localhost:5555" + maru_pool_size: 4 + +**3. Start vLLM with Maru** + +.. 
code-block:: bash + + LMCACHE_CONFIG_FILE="maru-config.yaml" \ + vllm serve \ + meta-llama/Llama-3.1-8B-Instruct \ + --max-model-len 65536 \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +Configuration +------------- + +**LMCache Parameters:** + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``maru_path`` + - Required + - Maru server URL (format: ``maru://host:port``) + * - ``maru_pool_size`` + - ``4.0`` + - CXL memory pool size per instance in GB (e.g., ``4``, ``0.5``) + +**Advanced Parameters (via extra_config):** + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Default + - Description + * - ``maru_instance_id`` + - auto UUID + - Unique client instance identifier + * - ``maru_timeout_ms`` + - 5000 + - ZMQ RPC socket timeout in milliseconds + * - ``maru_use_async_rpc`` + - true + - Async DEALER-ROUTER RPC (``false`` for synchronous REQ-REP) + * - ``maru_max_inflight`` + - 64 + - Max concurrent async RPC requests + * - ``maru_eager_map`` + - true + - Pre-map all shared regions on connect + +Additional Resources +-------------------- + +- `Maru GitHub Repository `_ +- `Maru Documentation `_ diff --git a/docs/source/kv_cache/storage_backends/resp.rst b/docs/source/kv_cache/storage_backends/resp.rst index 9afb7b3c42..2647ab8f06 100644 --- a/docs/source/kv_cache/storage_backends/resp.rst +++ b/docs/source/kv_cache/storage_backends/resp.rst @@ -281,6 +281,45 @@ The ``--l2-adapter`` JSON accepts these fields: - str - ``""`` - Redis AUTH password (leave empty for no auth) + * - ``max_capacity_gb`` + - float + - 0 + - Maximum L2 storage capacity in GB for client-side usage tracking. Required for L2 eviction. Set to 0 (default) to disable usage tracking. + +L2 Eviction +~~~~~~~~~~~~ + +To enable automatic eviction of least-recently-used keys when the Redis backend fills up, +set ``max_capacity_gb`` and add an ``"eviction"`` block: + +.. 
code-block:: bash + + lmcache server \ + --l1-size-gb 10 \ + --eviction-policy LRU \ + --chunk-size 16 \ + --l2-adapter '{ + "type": "resp", + "host": "localhost", + "port": 6379, + "num_workers": 8, + "max_capacity_gb": 10, + "eviction": { + "eviction_policy": "LRU", + "trigger_watermark": 0.8, + "eviction_ratio": 0.2 + } + }' \ + --port 6555 + +This configures a 10 GB capacity limit. When usage exceeds 80% (``trigger_watermark``), +the eviction controller will delete the least-recently-used ~20% of stored keys +(``eviction_ratio``) using the Redis ``DEL`` command. + +.. note:: + ``max_capacity_gb`` enables **client-side** size tracking. It does not configure + the Redis server's ``maxmemory`` setting. You should set ``max_capacity_gb`` to + match or be slightly below your Redis server's available memory. Testing the Setup diff --git a/docs/source/mp/l2_storage.rst b/docs/source/mp/l2_storage.rst index a0b945086e..952038c127 100644 --- a/docs/source/mp/l2_storage.rst +++ b/docs/source/mp/l2_storage.rst @@ -128,6 +128,74 @@ object is stored as a raw ``.data`` file whose name encodes the full # With O_DIRECT for bypassing page cache --l2-adapter '{"type": "fs", "base_path": "/data/lmcache/l2", "use_odirect": true}' +``mooncake_store`` -- Mooncake Store native connector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +An L2 adapter backed by the native C++ Mooncake Store connector. Uses +`Mooncake `_ for high-performance +distributed KV cache storage with RDMA support. + +**Prerequisites -- Building with Mooncake support:** + +The Mooncake extension is **not** built by default. You must explicitly +enable it: + +.. code-block:: bash + + BUILD_MOONCAKE=1 pip install -e . --verbose + +The ``BUILD_MOONCAKE`` environment variable controls compilation: + +- ``BUILD_MOONCAKE=1``: Enable the Mooncake C++ extension. +- ``BUILD_MOONCAKE=0``: Force disable (highest priority), even if + ``MOONCAKE_INCLUDE_DIR`` is set. 
+- **Not set**: Falls back to checking ``MOONCAKE_INCLUDE_DIR`` for + backward compatibility. If ``MOONCAKE_INCLUDE_DIR`` is also unset, + the extension is skipped. + +If the Mooncake headers are not installed in the system include path +(e.g., ``/usr/local/include``), you must point to them explicitly: + +.. code-block:: bash + + BUILD_MOONCAKE=1 \ + MOONCAKE_INCLUDE_DIR=/path/to/mooncake/include \ + MOONCAKE_LIB_DIR=/path/to/mooncake/lib \ + pip install -e . --verbose + +**LMCache-specific fields:** + +- ``num_workers``: Number of C++ worker threads (default ``4``, must + be > 0). + +**Mooncake fields:** + +All other keys in the JSON config (except ``type``, ``num_workers``, +and ``eviction``) are forwarded **as-is** to Mooncake's +``setup_internal(ConfigDict)``. Refer to the +`Mooncake documentation `_ +for available setup keys (e.g., ``local_hostname``, +``metadata_server``, ``master_server_address``, ``protocol``, +``device_name``, ``global_segment_size``). + +**Configuration example:** + +.. code-block:: bash + + --l2-adapter '{ + "type": "mooncake_store", + "num_workers": 4, + "local_hostname": "node01", + "metadata_server": "http://localhost:8080/metadata", + "master_server_address": "localhost:50051", + "protocol": "tcp", + "local_buffer_size": "3221225472", + "global_segment_size": "3221225472" + }' + +For full Mooncake setup instructions (master service, metadata server, +etc.), see `Mooncake `_ . + ``mock`` -- Mock adapter for testing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -335,6 +403,8 @@ drops by ``eviction_ratio``. * - ``mock`` - Full support. Useful for testing eviction behaviour without real storage hardware. + * - ``mooncake_store`` + - No eviction support (native connector adapter). * - ``fs`` - No eviction support (``delete`` and ``get_usage`` are no-ops). 
* - native connectors diff --git a/examples/kv_cache_reuse/remote_backends/resp/README.md b/examples/kv_cache_reuse/remote_backends/resp/README.md index 61df8cf49f..89532f1439 100644 --- a/examples/kv_cache_reuse/remote_backends/resp/README.md +++ b/examples/kv_cache_reuse/remote_backends/resp/README.md @@ -91,6 +91,7 @@ The `--l2-adapter` JSON accepts these fields: | `num_workers` | int | 8 | C++ worker threads for parallel I/O | | `username` | str | `""` | Redis ACL username | | `password` | str | `""` | Redis AUTH password | +| `max_capacity_gb` | float | 0 | Max L2 capacity in GB for usage tracking (required for L2 eviction) | ### Launch vLLM with LMCache MP Connector diff --git a/examples/lmc_external_native_connector/csrc/connector.cpp b/examples/lmc_external_native_connector/csrc/connector.cpp index 86b3768925..bad1084e0b 100644 --- a/examples/lmc_external_native_connector/csrc/connector.cpp +++ b/examples/lmc_external_native_connector/csrc/connector.cpp @@ -172,6 +172,13 @@ bool ExampleFSConnector::do_single_exists(WorkerFSConn& conn, return std::filesystem::exists(path); } +bool ExampleFSConnector::do_single_delete(WorkerFSConn& conn, + const std::string& key) { + auto path = conn.base_path / safe_filename(key); + std::error_code ec; + return std::filesystem::remove(path, ec); +} + // --------------------------------------------------------------- // ExampleMemoryConnector // --------------------------------------------------------------- @@ -220,4 +227,10 @@ bool ExampleMemoryConnector::do_single_exists(WorkerMemConn& conn, return conn.store->data.count(key) > 0; } +bool ExampleMemoryConnector::do_single_delete(WorkerMemConn& conn, + const std::string& key) { + std::lock_guard lk(conn.store->mu); + return conn.store->data.erase(key) > 0; +} + } // namespace example_connector diff --git a/examples/lmc_external_native_connector/csrc/connector.h b/examples/lmc_external_native_connector/csrc/connector.h index 1ed8d7d21e..75702eafea 100644 --- 
a/examples/lmc_external_native_connector/csrc/connector.h +++ b/examples/lmc_external_native_connector/csrc/connector.h @@ -33,6 +33,7 @@ class ExampleFSConnector void do_single_set(WorkerFSConn& conn, const std::string& key, const void* buf, size_t len, size_t chunk_size) override; bool do_single_exists(WorkerFSConn& conn, const std::string& key) override; + bool do_single_delete(WorkerFSConn& conn, const std::string& key) override; private: static std::string safe_filename(const std::string& key); @@ -71,6 +72,7 @@ class ExampleMemoryConnector void do_single_set(WorkerMemConn& conn, const std::string& key, const void* buf, size_t len, size_t chunk_size) override; bool do_single_exists(WorkerMemConn& conn, const std::string& key) override; + bool do_single_delete(WorkerMemConn& conn, const std::string& key) override; private: std::shared_ptr store_; diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 64289545cc..588ac0774e 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -14,6 +14,7 @@ from lmcache.integration.request_telemetry.factory import RequestTelemetryFactory from lmcache.utils import _lmcache_nvtx_annotate, init_logger from lmcache.v1.multiprocess.custom_types import ( + BlockAllocationRecord, CudaIPCWrapper, IPCCacheEngineKey, KVCache, @@ -447,6 +448,28 @@ def end_session(self, request_id: str) -> None: [request_id], ) + def report_block_allocations( + self, + records: list[BlockAllocationRecord], + ) -> None: + """Report vLLM GPU block allocation deltas to LMCache server. + + Fire-and-forget: does not wait for a response. If the server + is unhealthy the report is silently dropped. + + Args: + records: List of BlockAllocationRecord with per-request + block and token allocation deltas. 
+ """ + if not self.is_healthy or not records: + return + + send_lmcache_request( + self.mq_client, + RequestType.REPORT_BLOCK_ALLOCATION, + [records], + ) + # Helper functions def _create_key( self, diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index c6c86e563d..20bbc19276 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -236,6 +236,13 @@ "default": None, "env_converter": int, }, + # Maru CXL shared memory backend + "maru_path": {"type": Optional[str], "default": None, "env_converter": str}, + "maru_pool_size": { + "type": float, + "default": 4.0, + "env_converter": float, + }, # Other configurations # (Deprecated) The url of the actual remote lmcache instance for auditing. # Please use extra_config['audit_actual_remote_url'] instead. diff --git a/lmcache/v1/distributed/l2_adapters/fs_native_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/fs_native_l2_adapter.py index 4cd443099a..7047b889c8 100644 --- a/lmcache/v1/distributed/l2_adapters/fs_native_l2_adapter.py +++ b/lmcache/v1/distributed/l2_adapters/fs_native_l2_adapter.py @@ -54,12 +54,14 @@ def __init__( relative_tmp_dir: str = "", use_odirect: bool = False, read_ahead_size: Optional[int] = None, + max_capacity_gb: float = 0, ): self.base_path = base_path self.num_workers = num_workers self.relative_tmp_dir = relative_tmp_dir self.use_odirect = use_odirect self.read_ahead_size = read_ahead_size + self.max_capacity_gb = max_capacity_gb @classmethod def from_dict(cls, d: dict) -> "FSNativeL2AdapterConfig": @@ -84,12 +86,17 @@ def from_dict(cls, d: dict) -> "FSNativeL2AdapterConfig": if not isinstance(read_ahead_size, int) or read_ahead_size <= 0: raise ValueError("read_ahead_size must be a positive integer") + max_capacity_gb = d.get("max_capacity_gb", 0) + if not isinstance(max_capacity_gb, (int, float)) or max_capacity_gb < 0: + raise ValueError("max_capacity_gb must be a non-negative number") + return cls( base_path=base_path, num_workers=num_workers, 
relative_tmp_dir=str(relative_tmp_dir), use_odirect=use_odirect, read_ahead_size=read_ahead_size, + max_capacity_gb=float(max_capacity_gb), ) @classmethod @@ -106,7 +113,10 @@ def help(cls) -> str: "via O_DIRECT (default false)\n" "- read_ahead_size (int): trigger fs " "readahead by reading this many bytes " - "first (optional)" + "first (optional)\n" + "- max_capacity_gb (float): max L2 capacity " + "in GB for usage tracking / eviction " + "(default 0 = disabled)" ) @@ -148,7 +158,9 @@ def _create_fs_native_l2_adapter( config.use_odirect, config.read_ahead_size, ) - return NativeConnectorL2Adapter(native_client) + return NativeConnectorL2Adapter( + native_client, max_capacity_gb=config.max_capacity_gb + ) register_l2_adapter_type("fs_native", FSNativeL2AdapterConfig) diff --git a/lmcache/v1/distributed/l2_adapters/mooncake_store_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/mooncake_store_l2_adapter.py new file mode 100644 index 0000000000..886789d398 --- /dev/null +++ b/lmcache/v1/distributed/l2_adapters/mooncake_store_l2_adapter.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Mooncake Store native L2 adapter config and factory. +""" + +# Future +from __future__ import annotations + +# Standard +from typing import ( + TYPE_CHECKING, + Dict, + Optional, +) + +if TYPE_CHECKING: + from lmcache.v1.distributed.internal_api import ( + L1MemoryDesc, + ) + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.distributed.l2_adapters.base import ( + L2AdapterInterface, +) +from lmcache.v1.distributed.l2_adapters.config import ( + L2AdapterConfigBase, + register_l2_adapter_type, +) +from lmcache.v1.distributed.l2_adapters.factory import ( + register_l2_adapter_factory, +) + +logger = init_logger(__name__) + +# Keys consumed only by LMCache (never sent to mooncake). 
+_LMCACHE_ONLY_KEYS = {"type", "num_workers", "eviction"} + + +class MooncakeStoreL2AdapterConfig(L2AdapterConfigBase): + """Config for an L2 adapter backed by the native + C++ Mooncake Store connector. + + ``setup_config`` is a string-to-string dict that is + forwarded **as-is** to mooncake's + ``RealClient::setup_internal(ConfigDict)``. + LMCache does NOT interpret, validate, or fill in + defaults for any mooncake keys — that is mooncake's + responsibility. + + ``num_workers`` is the only LMCache-specific knob. + """ + + def __init__( + self, + setup_config: Dict[str, str], + num_workers: int = 4, + ): + super().__init__() + self.setup_config: Dict[str, str] = dict(setup_config) + self.num_workers = num_workers + + @classmethod + def from_dict(cls, d: dict) -> "MooncakeStoreL2AdapterConfig": + num_workers = d.get("num_workers", 4) + if not isinstance(num_workers, int) or num_workers <= 0: + raise ValueError("num_workers must be a positive integer") + + # Everything except LMCache-only keys is + # forwarded to mooncake as str values. 
+ setup: Dict[str, str] = {} + for k, v in d.items(): + if k in _LMCACHE_ONLY_KEYS: + continue + if v is not None: + setup[k] = str(v) + + return cls( + setup_config=setup, + num_workers=num_workers, + ) + + @classmethod + def help(cls) -> str: + return ( + "Mooncake Store L2 adapter config.\n" + "All keys except LMCache-only keys are " + "forwarded as-is to mooncake's " + "setup_internal(ConfigDict).\n" + "Refer to mooncake documentation for " + "available setup keys.\n" + "- num_workers (int): C++ worker threads " + "(default 4, >0)" + ) + + +def _create_mooncake_store_l2_adapter( + config: L2AdapterConfigBase, + l1_memory_desc: "Optional[L1MemoryDesc]" = None, +) -> L2AdapterInterface: + """Create a NativeConnectorL2Adapter backed by the + C++ Mooncake Store connector.""" + try: + # First Party + from lmcache.lmcache_mooncake import ( + LMCacheMooncakeClient, + ) + except ImportError as e: + raise RuntimeError( + "Mooncake Store L2 adapter requires the " + "C++ Mooncake extension. Build with: " + "MOONCAKE_INCLUDE_DIR=/path/to/mooncake-" + "store/include pip install -e ." 
+ ) from e + + # First Party + from lmcache.v1.distributed.l2_adapters.native_connector_l2_adapter import ( # noqa: E501 + NativeConnectorL2Adapter, + ) + + assert isinstance(config, MooncakeStoreL2AdapterConfig) + native_client = LMCacheMooncakeClient( + config=config.setup_config, + num_workers=config.num_workers, + ) + logger.info( + "Created Mooncake Store L2 adapter (workers=%d)", + config.num_workers, + ) + return NativeConnectorL2Adapter(native_client) + + +# Self-register config type and adapter factory +register_l2_adapter_type("mooncake_store", MooncakeStoreL2AdapterConfig) +register_l2_adapter_factory("mooncake_store", _create_mooncake_store_l2_adapter) diff --git a/lmcache/v1/distributed/l2_adapters/native_connector_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/native_connector_l2_adapter.py index 26d8c3f245..4f44cd8fd8 100644 --- a/lmcache/v1/distributed/l2_adapters/native_connector_l2_adapter.py +++ b/lmcache/v1/distributed/l2_adapters/native_connector_l2_adapter.py @@ -77,8 +77,10 @@ class NativeConnectorL2Adapter(L2AdapterInterface): _OP_STORE = "store" _OP_LOOKUP = "lookup" _OP_LOAD = "load" + _OP_DELETE = "delete" - def __init__(self, native_client): + def __init__(self, native_client, max_capacity_gb: float = 0): + super().__init__() self._client = native_client self._client_fd: int = int(native_client.event_fd()) @@ -105,6 +107,19 @@ def __init__(self, native_client): # Client-side lock tracking (refcount per key) self._locked_keys: dict[ObjectKey, int] = defaultdict(int) + # Delete capability detection + self._has_delete = callable(getattr(native_client, "submit_batch_delete", None)) + + # Pending delete events for synchronous delete() calls + self._pending_delete_events: dict[L2TaskId, threading.Event] = {} + + # Client-side size tracking for get_usage() + self._max_capacity_bytes = int(max_capacity_gb * (1024**3)) + self._current_size_bytes: int = 0 + self._key_sizes: dict[ObjectKey, int] = {} + # Pending store sizes: native future_id -> 
(keys, per_key_sizes) + self._pending_store_sizes: dict[int, tuple[list[ObjectKey], list[int]]] = {} + # Task ID counter self._next_task_id: L2TaskId = 0 @@ -144,6 +159,7 @@ def submit_store_task( ) -> L2TaskId: key_strings = [_object_key_to_string(k) for k in keys] memviews = [_obj_to_memoryview(obj) for obj in objects] + per_key_sizes = [obj.get_size() for obj in objects] # Register pending op BEFORE submit to avoid race # with demux thread. The native submit is @@ -157,6 +173,7 @@ def submit_store_task( len(keys), None, ) + self._pending_store_sizes[future_id] = (list(keys), per_key_sizes) return task_id @@ -223,7 +240,7 @@ def submit_load_task( self._OP_LOAD, task_id, len(keys), - None, + list(keys), ) return task_id @@ -237,12 +254,57 @@ def query_load_result(self, task_id: L2TaskId) -> Bitmap | None: # --------------------------------------------------------------- def delete(self, keys: list[ObjectKey]) -> None: - # Not implemented for the native connector adapter. - pass + """Delete a batch of keys from the remote backend. + + Submits a batch delete to the native connector and blocks + until the demux thread signals completion (up to 30s timeout). + Fires ``_notify_keys_deleted`` on success so eviction policy + tracking stays in sync. + + No-op if the connector does not expose ``submit_batch_delete`` + or if the key list is empty. 
+ """ + if not keys or not self._has_delete: + return + + key_strings = [_object_key_to_string(k) for k in keys] + done_event = threading.Event() + + with self._lock: + task_id = self._get_next_task_id() + future_id = int(self._client.submit_batch_delete(key_strings)) + self._pending_ops[future_id] = ( + self._OP_DELETE, + task_id, + len(keys), + list(keys), + ) + self._pending_delete_events[task_id] = done_event + + # Block until demux thread signals completion + if not done_event.wait(timeout=30.0): + with self._lock: + self._pending_delete_events.pop(task_id, None) + # Note: _pending_ops entry may already be consumed + # by the demux thread; pop is safe either way. + for fid, entry in list(self._pending_ops.items()): + if entry[1] == task_id: + self._pending_ops.pop(fid, None) + break + logger.warning( + "delete() timed out after 30s for %d keys", + len(keys), + ) + return + + self._notify_keys_deleted(keys) def get_usage(self) -> tuple[float, float]: - # Not implemented for the native connector adapter. - return (-1.0, -1.0) + if self._max_capacity_bytes <= 0: + return (-1.0, -1.0) + with self._lock: + usage = self._current_size_bytes / self._max_capacity_bytes + return (usage, usage) # --------------------------------------------------------------- # Cleanup @@ -291,6 +353,11 @@ def _demux_loop(self) -> None: if not completions: continue + # Collect listener notifications to fire after + # releasing the lock. 
+ keys_stored: list[ObjectKey] = [] + keys_accessed: list[ObjectKey] = [] + with self._lock: for ( future_id, @@ -316,6 +383,15 @@ def _demux_loop(self) -> None: if op_type == self._OP_STORE: self._completed_stores[task_id] = ok + # Update size tracking on success + store_info = self._pending_store_sizes.pop(fid, None) + if ok and store_info is not None: + store_keys, sizes = store_info + for key, size in zip(store_keys, sizes, strict=False): + if key not in self._key_sizes: + self._key_sizes[key] = size + self._current_size_bytes += size + keys_stored.extend(store_keys) os.eventfd_write(self._store_efd, 1) elif op_type == self._OP_LOOKUP: @@ -331,14 +407,38 @@ def _demux_loop(self) -> None: elif op_type == self._OP_LOAD: bitmap = Bitmap(num_keys) + loaded_keys: list[ObjectKey] = [] if result_bools is not None: for i, loaded in enumerate(result_bools): if loaded: bitmap.set(i) + if lookup_keys is not None: + loaded_keys.append(lookup_keys[i]) elif ok: # Fallback for connectors that # do not report per-key results for i in range(num_keys): bitmap.set(i) + if lookup_keys is not None: + loaded_keys.extend(lookup_keys) + keys_accessed.extend(loaded_keys) self._completed_loads[task_id] = bitmap os.eventfd_write(self._load_efd, 1) + + elif op_type == self._OP_DELETE: + # Decrement sizes for successfully deleted keys + if result_bools is not None and lookup_keys is not None: + for i, deleted in enumerate(result_bools): + if deleted: + key = lookup_keys[i] + size = self._key_sizes.pop(key, 0) + self._current_size_bytes -= size + evt = self._pending_delete_events.pop(task_id, None) + if evt is not None: + evt.set() + + # Fire listener notifications outside the lock + if keys_stored: + self._notify_keys_stored(keys_stored) + if keys_accessed: + self._notify_keys_accessed(keys_accessed) diff --git a/lmcache/v1/distributed/l2_adapters/native_plugin_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/native_plugin_l2_adapter.py index 688501373c..aba5f08c21 100644 --- 
a/lmcache/v1/distributed/l2_adapters/native_plugin_l2_adapter.py +++ b/lmcache/v1/distributed/l2_adapters/native_plugin_l2_adapter.py @@ -55,10 +55,12 @@ def __init__( module_path: str, class_name: str, adapter_params: dict[str, Any] | None = None, + max_capacity_gb: float = 0, ): self.module_path = module_path self.class_name = class_name self.adapter_params = adapter_params or {} + self.max_capacity_gb = max_capacity_gb @classmethod def from_dict(cls, d: dict) -> "NativePluginL2AdapterConfig": @@ -74,10 +76,15 @@ def from_dict(cls, d: dict) -> "NativePluginL2AdapterConfig": if not isinstance(adapter_params, dict): raise ValueError("adapter_params must be a dict") + max_capacity_gb = d.get("max_capacity_gb", 0) + if not isinstance(max_capacity_gb, (int, float)) or max_capacity_gb < 0: + raise ValueError("max_capacity_gb must be a non-negative number") + return cls( module_path=module_path, class_name=class_name, adapter_params=adapter_params, + max_capacity_gb=float(max_capacity_gb), ) @classmethod @@ -98,7 +105,10 @@ def help(cls) -> str: '"module_path": "my_ext.connector", ' '"class_name": "MyConnectorClient", ' '"adapter_params": ' - '{"host": "localhost", "port": 1234}}' + '{"host": "localhost", "port": 1234}}\n' + "- max_capacity_gb (float): max L2 capacity " + "in GB for usage tracking / eviction " + "(default 0 = disabled)" ) @@ -166,13 +176,23 @@ def _create_native_plugin_l2_adapter( native_client.close() raise + if not callable(getattr(native_client, "submit_batch_delete", None)): + logger.warning( + "%s.%s does not expose submit_batch_delete; " + "L2 eviction delete will be a no-op.", + config.module_path, + config.class_name, + ) + logger.info( "Created native plugin L2 adapter: %s.%s (params=%s)", config.module_path, config.class_name, config.adapter_params, ) - return NativeConnectorL2Adapter(native_client) + return NativeConnectorL2Adapter( + native_client, max_capacity_gb=config.max_capacity_gb + ) register_l2_adapter_type("native_plugin", 
NativePluginL2AdapterConfig) diff --git a/lmcache/v1/distributed/l2_adapters/resp_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/resp_l2_adapter.py index 5f9ee78dda..83a141a115 100644 --- a/lmcache/v1/distributed/l2_adapters/resp_l2_adapter.py +++ b/lmcache/v1/distributed/l2_adapters/resp_l2_adapter.py @@ -53,6 +53,7 @@ def __init__( num_workers: int = 8, username: str = "", password: str = "", + max_capacity_gb: float = 0, ): super().__init__() self.host = host @@ -60,6 +61,7 @@ def __init__( self.num_workers = num_workers self.username = username self.password = password + self.max_capacity_gb = max_capacity_gb @classmethod def from_dict(cls, d: dict) -> "RESPL2AdapterConfig": @@ -78,12 +80,17 @@ def from_dict(cls, d: dict) -> "RESPL2AdapterConfig": username = d.get("username", "") password = d.get("password", "") + max_capacity_gb = d.get("max_capacity_gb", 0) + if not isinstance(max_capacity_gb, (int, float)) or max_capacity_gb < 0: + raise ValueError("max_capacity_gb must be a non-negative number") + return cls( host=host, port=port, num_workers=num_workers, username=str(username), password=str(password), + max_capacity_gb=float(max_capacity_gb), ) @classmethod @@ -98,7 +105,10 @@ def help(cls) -> str: "- username (str): auth username " "(default empty)\n" "- password (str): auth password " - "(default empty)" + "(default empty)\n" + "- max_capacity_gb (float): max L2 capacity " + "in GB for usage tracking / eviction " + "(default 0 = disabled)" ) @@ -139,7 +149,9 @@ def _create_resp_l2_adapter( config.port, config.num_workers, ) - return NativeConnectorL2Adapter(native_client) + return NativeConnectorL2Adapter( + native_client, max_capacity_gb=config.max_capacity_gb + ) # Self-register config type and adapter factory diff --git a/lmcache/v1/gpu_connector/utils.py b/lmcache/v1/gpu_connector/utils.py index c1930e04bb..c43ef9482a 100644 --- a/lmcache/v1/gpu_connector/utils.py +++ b/lmcache/v1/gpu_connector/utils.py @@ -561,7 +561,8 @@ def 
get_num_heads(kv_caches: Any, gpu_kv_format: "lmc_ops.GPUKVFormat") -> int: # HND: [..., NH, BS, HS] — num_heads at shape[2] return kv_caches[0].shape[2] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: - raise ValueError(_ATTRIBUTE_NOT_EXIST_ERROR.format(format=gpu_kv_format)) + # MLA: heads are absorbed into hidden dim, so num_heads = 1 + return 1 elif gpu_kv_format == lmc_ops.GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS: return kv_caches[0][0].shape[1] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NBBS_ONE_HS: diff --git a/lmcache/v1/mp_observability/event.py b/lmcache/v1/mp_observability/event.py index b4d50a52d4..ea018b6aca 100644 --- a/lmcache/v1/mp_observability/event.py +++ b/lmcache/v1/mp_observability/event.py @@ -50,6 +50,9 @@ class EventType(Enum): MP_LOOKUP_PREFETCH_START = "mp.lookup_prefetch.start" MP_LOOKUP_PREFETCH_END = "mp.lookup_prefetch.end" + # vLLM block allocation events + MP_VLLM_BLOCK_ALLOCATION = "mp.vllm.block_allocation" + @dataclass class Event: diff --git a/lmcache/v1/mp_observability/subscribers/logging/mp_server.py b/lmcache/v1/mp_observability/subscribers/logging/mp_server.py index 23360dc6d6..b2b0174cdc 100644 --- a/lmcache/v1/mp_observability/subscribers/logging/mp_server.py +++ b/lmcache/v1/mp_observability/subscribers/logging/mp_server.py @@ -30,6 +30,7 @@ def get_subscriptions(self) -> dict[EventType, EventCallback]: EventType.MP_RETRIEVE_END: self._on_retrieve_end, EventType.MP_LOOKUP_PREFETCH_START: self._on_lookup_prefetch_start, EventType.MP_LOOKUP_PREFETCH_END: self._on_lookup_prefetch_end, + EventType.MP_VLLM_BLOCK_ALLOCATION: self._on_block_allocation, } def _on_store_start(self, event: Event) -> None: @@ -74,3 +75,16 @@ def _on_lookup_prefetch_end(self, event: Event) -> None: event.session_id, event.metadata.get("found_count"), ) + + def _on_block_allocation(self, event: Event) -> None: + records = event.metadata.get("records", []) + for rec in records: + logger.debug( + "vLLM block allocation: req_id=%s " + 
"new_blocks=%d new_tokens=%d " + "block_ids=%s", + rec.req_id, + len(rec.new_block_ids), + len(rec.new_token_ids), + rec.new_block_ids[:10], + ) diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py index 28aeed18dc..38ca052cee 100644 --- a/lmcache/v1/multiprocess/custom_types.py +++ b/lmcache/v1/multiprocess/custom_types.py @@ -222,6 +222,15 @@ def ext_hook(code: int, data: bytes) -> Any: return msgspec.msgpack.Decoder(ext_hook=ext_hook, type=type) +@dataclass +class BlockAllocationRecord: + """A single per-request GPU block allocation delta from vLLM.""" + + req_id: str + new_block_ids: list[int] + new_token_ids: list[int] + + @dataclass class CBMatchResult: """Result of a sub-sequence match from BlendTokenRangeMatcher. diff --git a/lmcache/v1/multiprocess/protocols/__init__.py b/lmcache/v1/multiprocess/protocols/__init__.py index e06f605c9a..ed6ff22d4d 100644 --- a/lmcache/v1/multiprocess/protocols/__init__.py +++ b/lmcache/v1/multiprocess/protocols/__init__.py @@ -9,7 +9,14 @@ """ # First Party -from lmcache.v1.multiprocess.protocols import blend, blend_v2, controller, debug, engine +from lmcache.v1.multiprocess.protocols import ( + blend, + blend_v2, + controller, + debug, + engine, + observability, +) from lmcache.v1.multiprocess.protocols.base import ( HandlerType, ProtocolDefinition, @@ -29,6 +36,7 @@ class ProtocolInitializationError(Exception): ("debug", debug), ("blend", blend), ("blend_v2", blend_v2), + ("observability", observability), ] diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py index f743b57a60..383a41ff8c 100644 --- a/lmcache/v1/multiprocess/protocols/base.py +++ b/lmcache/v1/multiprocess/protocols/base.py @@ -54,6 +54,9 @@ class RequestType(enum.Enum): GET_CHUNK_SIZE = enum.auto() PING = enum.auto() + # Observability operations + REPORT_BLOCK_ALLOCATION = enum.auto() + # Debug operations NOOP = enum.auto() diff --git 
a/lmcache/v1/multiprocess/protocols/observability.py b/lmcache/v1/multiprocess/protocols/observability.py new file mode 100644 index 0000000000..add73ed0e0 --- /dev/null +++ b/lmcache/v1/multiprocess/protocols/observability.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Observability protocol definitions. + +This module defines protocols for: +- REPORT_BLOCK_ALLOCATION: Report vLLM GPU block allocation events + (fire-and-forget, no response) +""" + +# First Party +from lmcache.v1.multiprocess.custom_types import BlockAllocationRecord +from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition + +# Define request names for this protocol group +REQUEST_NAMES = [ + "REPORT_BLOCK_ALLOCATION", +] + + +def get_protocol_definitions() -> dict[str, ProtocolDefinition]: + """ + Returns protocol definitions for observability operations. + + Returns: + Dictionary mapping request names to their protocol definitions + """ + return { + # Report vLLM block allocation + # Payload: [list[BlockAllocationRecord]] - list of allocation records + # Returns: None (fire-and-forget) + "REPORT_BLOCK_ALLOCATION": ProtocolDefinition( + payload_classes=[list[BlockAllocationRecord]], + response_class=None, + handler_type=HandlerType.BLOCKING, + ), + } diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 64b901f438..aad74041ae 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -45,6 +45,7 @@ parse_args_to_mp_server_config, ) from lmcache.v1.multiprocess.custom_types import ( + BlockAllocationRecord, IPCCacheEngineKey, KVCache, ) @@ -819,6 +820,20 @@ def report_status(self) -> dict: "storage_manager": sm, } + def report_block_allocations(self, records: list[BlockAllocationRecord]) -> None: + """Publish vLLM block allocation records to the EventBus. + + Args: + records: List of BlockAllocationRecord with per-request + block and token allocation deltas. 
+ """ + self._event_bus.publish( + Event( + event_type=EventType.MP_VLLM_BLOCK_ALLOCATION, + metadata={"records": records}, + ) + ) + def debug(self) -> str: return "OK" @@ -914,6 +929,11 @@ def run_cache_server( add_handler_helper(server, RequestType.PING, engine.ping) add_handler_helper(server, RequestType.END_SESSION, engine.end_session) add_handler_helper(server, RequestType.NOOP, engine.debug) + add_handler_helper( + server, + RequestType.REPORT_BLOCK_ALLOCATION, + engine.report_block_allocations, + ) # Assign thread pools server.add_affinity_thread_pool( @@ -929,6 +949,7 @@ def run_cache_server( RequestType.END_SESSION, RequestType.CLEAR, RequestType.PING, + RequestType.REPORT_BLOCK_ALLOCATION, ], max_workers=mp_config.max_cpu_workers, ) diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index b7212b9603..d49cda695a 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -218,6 +218,20 @@ def CreateStorageBackends( ) storage_backends[str(gds_backend)] = gds_backend + if config.maru_path is not None and "MaruBackend" not in _skip: + try: + # First Party + from lmcache.v1.storage_backend.maru_backend import MaruBackend + except ImportError as e: + raise ImportError( + "The 'maru' and 'maru_lmcache' packages are required " + "to use MaruBackend. Please install them according to " + "the Maru setup documentation." + ) from e + + maru_backend = MaruBackend(config, metadata, loop, dst_device) + storage_backends[str(maru_backend)] = maru_backend + if config.remote_url is not None and "RemoteBackend" not in _skip: assert local_cpu_backend is not None, ( "Remote backend requires local CPU backend as a buffer." 
diff --git a/lmcache/v1/storage_backend/maru_backend.py b/lmcache/v1/storage_backend/maru_backend.py new file mode 100644 index 0000000000..5069e071eb --- /dev/null +++ b/lmcache/v1/storage_backend/maru_backend.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from concurrent.futures import Future +from typing import Any, Callable, List, Optional, Sequence, Union +import asyncio +import threading +import time + +# Third Party +from maru import MaruConfig, MaruHandler +from maru_lmcache import CxlMemoryAdapter +import torch + +# First Party +from lmcache.integration.vllm.utils import get_size_bytes +from lmcache.logging import init_logger +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + MemoryAllocatorInterface, + MemoryFormat, + MemoryObj, +) +from lmcache.v1.metadata import LMCacheMetadata +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface + +logger = init_logger(__name__) + + +class MaruBackend(AllocatorBackendInterface): + """Maru CXL shared memory storage backend. + + Implements AllocatorBackendInterface with its own CxlMemoryAdapter. + No LocalCPUBackend needed — data lives directly in CXL mmap memory. + + Put is async (Future): metadata registration via RPC. + Get is sync: CXL memory direct read (no network I/O). + + Args: + config: LMCache engine configuration. Must have maru_path set. + metadata: LMCache engine metadata. + loop: asyncio event loop for async put tasks. + dst_device: Target device string (unused for CXL, kept for interface). + """ + + def __init__( + self, + config: LMCacheEngineConfig, + metadata: LMCacheMetadata, + loop: asyncio.AbstractEventLoop, + dst_device: str = "cuda", + ): + super().__init__(dst_device=dst_device) + + if config.use_layerwise: + raise NotImplementedError( + "MaruBackend does not yet support layerwise KV cache." + ) + + # 1. 
Config + self.config = config + self.loop = loop + + self._full_chunk_size_bytes: int = get_size_bytes( + metadata.get_shapes(), metadata.get_dtypes() + ) + assert self._full_chunk_size_bytes % metadata.chunk_size == 0 + self._single_token_size: int = ( + self._full_chunk_size_bytes // metadata.chunk_size + ) + + self._mla_worker_id_as0_mode: bool = ( + config.get_extra_config_value( + "remote_enable_mla_worker_id_as0", metadata.use_mla + ) + and metadata.use_mla + and metadata.world_size > 1 + and metadata.worker_id != 0 + ) + + # 2. Handler + self._handler = self._create_handler(config) + + # 3. Allocator + self.memory_allocator = self.initialize_allocator(config, metadata) + + # 4. State + self.put_lock = threading.Lock() + self.put_tasks: set[CacheEngineKey] = set() + + def __str__(self) -> str: + return self.__class__.__name__ + + @staticmethod + def _pool_size_gb_to_bytes(size_gb: float) -> int: + """Convert pool size in GB to bytes.""" + return int(size_gb * 1024**3) + + # ========================================================================= + # Initialization helpers + # ========================================================================= + + def _create_handler( + self, + config: LMCacheEngineConfig, + ) -> "MaruHandler": + """Create and connect a MaruHandler. + + Args: + config: LMCache engine configuration. + + Returns: + Connected MaruHandler instance. + + Raises: + RuntimeError: If MaruHandler connection fails. 
+ """ + assert config.maru_path is not None, "maru_path must be set for MaruBackend" + + # Convert maru:// scheme to tcp:// for ZMQ + server_url = config.maru_path + if server_url.startswith("maru://"): + server_url = "tcp://" + server_url[len("maru://") :] + + extra = config.extra_config or {} + maru_config = MaruConfig( + server_url=server_url, + instance_id=extra.get("maru_instance_id"), + pool_size=self._pool_size_gb_to_bytes(config.maru_pool_size), + chunk_size_bytes=self._full_chunk_size_bytes, + auto_connect=False, + timeout_ms=extra.get("maru_timeout_ms", 5000), + use_async_rpc=extra.get("maru_use_async_rpc", True), + max_inflight=extra.get("maru_max_inflight", 64), + eager_map=extra.get("maru_eager_map", True), + ) + + handler = MaruHandler(maru_config) + if not handler.connect(): + raise RuntimeError(f"Failed to connect MaruHandler to {config.maru_path}") + logger.debug("[Maru] Connected to %s", config.maru_path) + return handler + + # ========================================================================= + # AllocatorBackendInterface + # ========================================================================= + + def initialize_allocator( + self, config: LMCacheEngineConfig, metadata: LMCacheMetadata + ) -> MemoryAllocatorInterface: + """Create CxlMemoryAdapter backed by the connected handler. + + Args: + config: LMCache engine configuration. + metadata: LMCache engine metadata. + + Returns: + CxlMemoryAdapter instance. 
+ """ + shapes = metadata.get_shapes() + dtypes = metadata.get_dtypes() + fmt = MemoryFormat.KV_MLA_FMT if metadata.use_mla else MemoryFormat.KV_2LTD + chunk_size = self._handler.get_chunk_size() + + return CxlMemoryAdapter( + handler=self._handler, + shapes=shapes, + dtypes=dtypes, + fmt=fmt, + chunk_size=chunk_size, + ) + + def get_memory_allocator(self) -> MemoryAllocatorInterface: + """Returns the underlying CxlMemoryAdapter.""" + return self.memory_allocator + + def get_allocator_backend(self) -> "MaruBackend": + """Returns self as the allocator backend.""" + return self + + def allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[MemoryObj]: + """Allocate CXL-backed memory via CxlMemoryAdapter. + + Args: + shapes: Tensor shape(s). + dtypes: Tensor dtype(s). + fmt: Memory format. + eviction: Unused. + busy_loop: Unused. + + Returns: + MemoryObj backed by CXL memory, or None on failure. + """ + obj = self.memory_allocator.allocate(shapes, dtypes, fmt) + if obj is not None: + logger.debug( + "[Maru] allocate rid=%d pid=%d", + *CxlMemoryAdapter.decode_address(obj.metadata.address), + ) + else: + logger.debug("[Maru] allocate failed shapes=%s dtypes=%s", shapes, dtypes) + return obj + + def batched_allocate( + self, + shapes: Union[torch.Size, list[torch.Size]], + dtypes: Union[torch.dtype, list[torch.dtype]], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + eviction: bool = True, + busy_loop: bool = True, + ) -> Optional[list[MemoryObj]]: + """Allocate multiple CXL-backed MemoryObjs. + + Args: + shapes: Tensor shape(s) (same for each allocation). + dtypes: Tensor dtype(s) (same for each allocation). + batch_size: Number of allocations. + fmt: Memory format. + eviction: Unused. + busy_loop: Unused. + + Returns: + List of MemoryObj, or None if any allocation fails. 
+ """ + return self.memory_allocator.batched_allocate(shapes, dtypes, batch_size, fmt) + + # ========================================================================= + # Put (async) + # ========================================================================= + + def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: + """Check whether key is in ongoing put tasks. + + Args: + key: The cache key. + + Returns: + True if the key has a pending put task. + """ + with self.put_lock: + return key in self.put_tasks + + @staticmethod + def _create_immediate_empty_future() -> Future: + """Create a Future that is already resolved with None.""" + f: Future = Future() + f.set_result(None) + return f + + def submit_put_task( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Future: + """Submit a put task to register KV metadata with MaruServer. + + Data is already in CXL memory (zero-copy). This only registers + the key -> location metadata via RPC. + + Args: + key: The cache key. + memory_obj: MemoryObj with data already written to CXL. + on_complete_callback: Optional callback after registration. + + Returns: + Future that completes when metadata is registered. + """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return self._create_immediate_empty_future() + + assert memory_obj.tensor is not None + + # Keep CXL page alive: ref_count_down is only called on failure. + # On success the ref is retained so the CXL memory is not reclaimed. 
+ memory_obj.ref_count_up() + + with self.put_lock: + self.put_tasks.add(key) + + future = asyncio.run_coroutine_threadsafe( + self._async_store(key, memory_obj, on_complete_callback), + self.loop, + ) + return future + + def batched_submit_put_task( + self, + keys: Sequence[CacheEngineKey], + memory_objs: List[MemoryObj], + transfer_spec: Any = None, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> Union[List[Future], None]: + """Submit batched put tasks via single batch_store RPC. + + Args: + keys: The cache keys. + memory_objs: MemoryObjs with data already in CXL. + transfer_spec: Unused. + on_complete_callback: Optional per-key callback. + + Returns: + List containing a single Future for the entire batch. + """ + # If MLA worker id as 0 mode is enabled, skip put tasks + if self._mla_worker_id_as0_mode: + return None + + for memory_obj in memory_objs: + assert memory_obj.tensor is not None + memory_obj.ref_count_up() + + with self.put_lock: + self.put_tasks.update(keys) + + future = asyncio.run_coroutine_threadsafe( + self._async_batch_store(list(keys), memory_objs, on_complete_callback), + self.loop, + ) + return [future] + + async def _async_store( + self, + key: CacheEngineKey, + memory_obj: MemoryObj, + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register KV metadata with MaruServer (runs in event loop). + + Uses CxlMemoryAdapter.create_store_handle() to extract + (region_id, page_index) from the MemoryObj's encoded address. + + Args: + key: The cache key. + memory_obj: MemoryObj backed by CXL memory. + on_complete_callback: Optional callback after registration. 
+ """ + success = False + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + handle = allocator.create_store_handle(memory_obj) + key_str = key.to_string() + + success = await asyncio.to_thread(self._handler.store, key_str, handle) + + logger.debug( + "[Maru] store key=%s rid=%d pid=%d", + key, + handle.region_id, + handle.page_index, + ) + + except Exception as e: + logger.error("[Maru] store failed key=%s: %s", key, e) + raise + finally: + with self.put_lock: + self.put_tasks.discard(key) + + if not success: + memory_obj.ref_count_down() + + if success and on_complete_callback is not None: + try: + on_complete_callback(key) + except Exception as e: + logger.warning("on_complete_callback failed for key %s: %s", key, e) + + async def _async_batch_store( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + on_complete_callback: Optional[Callable[[CacheEngineKey], None]] = None, + ) -> None: + """Register multiple KV metadata entries via single batch_store RPC.""" + results: Optional[list[bool]] = None + try: + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + key_strs = [k.to_string() for k in keys] + handles = [allocator.create_store_handle(m) for m in memory_objs] + + results = await asyncio.to_thread( + self._handler.batch_store, key_strs, handles + ) + if results is not None: + logger.debug("[Maru] batch_store %d/%d ok", sum(results), len(results)) + except Exception as e: + logger.error("[Maru] batch_store failed: %s", e) + raise + finally: + with self.put_lock: + self.put_tasks.difference_update(keys) + + # Release ref_count for failed stores + for i, memory_obj in enumerate(memory_objs): + succeeded = results is not None and i < len(results) and results[i] + if not succeeded: + memory_obj.ref_count_down() + + if on_complete_callback is not None: + for i, key in enumerate(keys): + if results is not None and i < len(results) and results[i]: + try: + 
on_complete_callback(key) + except Exception as e: + logger.warning( + "on_complete_callback failed for key %s: %s", + key, + e, + ) + + # ========================================================================= + # Get (sync) + # ========================================================================= + + def get_blocking( + self, + key: CacheEngineKey, + ) -> Optional[MemoryObj]: + """Blocking get: read KV cache directly from CXL memory. + + Queries MaruServer for metadata, then returns a MemoryObj + via CxlMemoryAdapter.get_by_location(). + + Args: + key: The cache key. + + Returns: + MemoryObj backed by CXL memory, or None if not found. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + key_str = key.to_string() + mem_info = self._handler.retrieve(key_str) + if mem_info is None: + logger.debug("[Maru] get_blocking miss key=%s", key) + return None + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + logger.debug( + "[Maru] get_blocking pool miss rid=%d pid=%d", + mem_info.region_id, + mem_info.page_index, + ) + return None + + memory_obj.ref_count_up() + + logger.debug( + "[Maru] get_blocking rid=%d pid=%d size=%d", + mem_info.region_id, + mem_info.page_index, + len(mem_info.view), + ) + return memory_obj + + def batched_get_blocking( + self, + keys: List[CacheEngineKey], + ) -> List[Optional[MemoryObj]]: + """Blocking batched get via single batch_retrieve RPC. + + Args: + keys: The cache keys. + + Returns: + List of MemoryObj (None for misses). 
+ """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + results: List[Optional[MemoryObj]] = [] + for mem_info in mem_infos: + if mem_info is None: + results.append(None) + continue + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + results.append(None) + continue + memory_obj.ref_count_up() + results.append(memory_obj) + + hits = sum(1 for r in results if r is not None) + logger.debug("[Maru] batch_retrieve %d/%d hits", hits, len(results)) + return results + + # ========================================================================= + # Async lookup API (used by StorageManager.async_lookup_and_prefetch) + # ========================================================================= + + async def batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist via single batch_exists RPC. + + Returns the count of contiguous keys starting from index 0 + that exist. Stops at first miss. + + Args: + lookup_id: Unique request identifier. + keys: Keys to check in prefix order. + pin: If True, atomically check and pin via batch_pin RPC. + + Returns: + Number of prefix-contiguous keys that exist. + """ + return await asyncio.to_thread(self.batched_contains, keys, pin) + + async def batched_get_non_blocking( + self, + lookup_id: str, + keys: list[CacheEngineKey], + transfer_spec: Any = None, + ) -> list[MemoryObj]: + """Non-blocking batched get via single batch_retrieve RPC. + + Uses handler.batch_retrieve() for a single RPC call, then + resolves each MemoryInfo to a MemoryObj via CxlMemoryAdapter. 
+ Stops at first miss and returns the prefix. + + Args: + lookup_id: Unique request identifier. + keys: Keys to retrieve (already confirmed by contains). + transfer_spec: Unused. + + Returns: + List of MemoryObjs backed by CXL memory. + """ + + def _batch_get() -> list[MemoryObj]: + if self._mla_worker_id_as0_mode: + actual_keys = [k.with_new_worker_id(0) for k in keys] + else: + actual_keys = list(keys) + + key_strs = [k.to_string() for k in actual_keys] + mem_infos = self._handler.batch_retrieve(key_strs) + + allocator = self.memory_allocator + assert isinstance(allocator, CxlMemoryAdapter) + + results: list[MemoryObj] = [] + for mem_info in mem_infos: + if mem_info is None: + break + memory_obj = allocator.get_by_location( + region_id=mem_info.region_id, + page_index=mem_info.page_index, + actual_size=len(mem_info.view), + single_token_size=self._single_token_size, + ) + if memory_obj is None: + break + memory_obj.ref_count_up() + memory_obj.pin() + results.append(memory_obj) + + logger.debug( + "[Maru] batch_get_non_blocking %d/%d hits", len(results), len(keys) + ) + return results + + return await asyncio.to_thread(_batch_get) + + # ========================================================================= + # Contains / Pin / Unpin / Remove + # ========================================================================= + + def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: + """Check if key exists on MaruServer. + + Args: + key: The cache key. + pin: If True, atomically check existence and pin the entry + to protect it from eviction. + + Returns: + True if key exists. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + + key_str = key.to_string() + if pin: + return self._handler.pin(key_str) + return self._handler.exists(key_str) + + def batched_contains( + self, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + """Check how many prefix keys exist via single batch_exists RPC. 
+ + Args: + keys: Keys to check in prefix order. + pin: If True, atomically check and pin via + batch_pin RPC. + + Returns: + Number of prefix-contiguous keys that exist. + """ + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + + key_strs = [k.to_string() for k in keys] + if pin: + results = self._handler.batch_pin(key_strs) + else: + results = self._handler.batch_exists(key_strs) + num_hit = 0 + for exists in results: + if not exists: + break + num_hit += 1 + return num_hit + + def pin(self, key: CacheEngineKey) -> bool: + """Pin a key to prevent eviction on MaruServer. + + Increments the server-side pin_count. + + Args: + key: The cache key. + + Returns: + True if pinned successfully. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.pin(key.to_string()) + + def unpin(self, key: CacheEngineKey) -> bool: + """Unpin a key to allow eviction on MaruServer. + + Decrements the server-side pin_count. When pin_count reaches 0, + the entry becomes eligible for eviction. + + Args: + key: The cache key. + + Returns: + True if unpinned successfully. + """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + return self._handler.unpin(key.to_string()) + + def batched_unpin(self, keys: List[CacheEngineKey]) -> None: + """Batch-unpin keys via single RPC. + + Decrements server-side pin_count for each key. When pin_count + reaches 0, the entry becomes eligible for eviction. + + Args: + keys: The cache keys to unpin. + """ + if not keys: + return + if self._mla_worker_id_as0_mode: + keys = [k.with_new_worker_id(0) for k in keys] + key_strs = [k.to_string() for k in keys] + self._handler.batch_unpin(key_strs) + + def remove(self, key: CacheEngineKey, force: bool = True) -> bool: + """Remove a key from MaruServer. + + Args: + key: The cache key. + force: Whether to force removal. + + Returns: + True if removed successfully. 
+ """ + if self._mla_worker_id_as0_mode: + key = key.with_new_worker_id(0) + key_str = key.to_string() + result = self._handler.delete(key_str) + logger.debug("[Maru] remove key=%s success=%s", key, result) + return result + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """Close the backend and underlying MaruHandler.""" + while True: + with self.put_lock: + if not self.put_tasks: + break + time.sleep(0.1) + + self.memory_allocator.close() + self._handler.close() + logger.info("MaruBackend closed.") diff --git a/lmcache/v1/storage_backend/storage_manager.py b/lmcache/v1/storage_backend/storage_manager.py index 4ac7a4d6ce..f1756eea76 100644 --- a/lmcache/v1/storage_backend/storage_manager.py +++ b/lmcache/v1/storage_backend/storage_manager.py @@ -314,6 +314,11 @@ def _get_allocator_backend( ) -> AllocatorBackendInterface: if self.enable_pd: allocator_backend = self.storage_backends["PDBackend"] + elif "MaruBackend" in self.storage_backends: + if "LocalCPUBackend" in self.storage_backends: + allocator_backend = self.storage_backends["LocalCPUBackend"] + else: + allocator_backend = self.storage_backends["MaruBackend"] else: allocator_backend = self.storage_backends["LocalCPUBackend"] assert isinstance(allocator_backend, AllocatorBackendInterface) @@ -443,7 +448,7 @@ def get( memory_obj = backend.get_blocking(key) if memory_obj: if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and "LocalCPUBackend" in self.storage_backends ): local_cpu_backend = self.storage_backends["LocalCPUBackend"] @@ -468,7 +473,25 @@ def get_non_blocking( # NOTE(Jiayi): bypass the allocator for now task = backend.get_non_blocking(key) if task: - # TODO (Jiayi): add write-back logic here + if ( + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] + 
and "LocalCPUBackend" in self.storage_backends + ): + + def _write_back(fut, k=key): + try: + memory_obj = fut.result() + if memory_obj is not None: + local_cpu = self.storage_backends["LocalCPUBackend"] + assert isinstance(local_cpu, LocalCPUBackend) + local_cpu.submit_put_task(k, memory_obj) + except Exception as e: + logger.warning( + "Write-back to LocalCPUBackend failed: %s", + e, + ) + + task.add_done_callback(_write_back) return task return None @@ -487,7 +510,7 @@ def batched_get( # Align with single-key `get()` logic: # auto-write remote data to local CPU cache if ( - backend_name not in ["LocalCPUBackend", "PDBackend"] + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] and "LocalCPUBackend" in self.storage_backends and None not in memory_objs ): @@ -544,8 +567,24 @@ def prefetch_single_done_callback( Callback function when a single prefetch task (i.e., prefetching from a single backend) is done. """ - # TODO(Jiayi): support write-back policy here - pass + if ( + backend_name not in ["LocalCPUBackend", "PDBackend", "MaruBackend"] + and "LocalCPUBackend" in self.storage_backends + ): + try: + memory_objs = future.result() + if memory_objs: + local_cpu = self.storage_backends["LocalCPUBackend"] + assert isinstance(local_cpu, LocalCPUBackend) + local_cpu.batched_submit_put_task( + keys[: len(memory_objs)], memory_objs + ) + except Exception as e: + logger.warning( + "Write-back to LocalCPUBackend failed for %s: %s", + backend_name, + e, + ) def prefetch_all_done_callback( self, diff --git a/setup.py b/setup.py index 79c76ce054..2d856d5ebd 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,49 @@ def hipify_wrapper() -> None: assert len(hipified_sources) == len(extra_files) +def _mooncake_extension( + cpp_extension, + mooncake_sources: list[str], + extra_cxx_flags: list[str], +) -> list: + """Build mooncake CppExtension if enabled via env vars. + + Returns a list with zero or one Extension objects. 
+ """ + mc_env = os.environ.get("BUILD_MOONCAKE") + if mc_env is not None: + build_mc = mc_env == "1" + else: + build_mc = os.environ.get("MOONCAKE_INCLUDE_DIR", "") != "" + if not build_mc: + return [] + + mc_include = os.environ.get("MOONCAKE_INCLUDE_DIR", "") + mc_lib = os.environ.get("MOONCAKE_LIB_DIR", "") + mc_include_dirs = [ + "csrc/storage_backends", + "csrc/storage_backends/mooncake", + ] + if mc_include: + mc_include_dirs.extend(mc_include.split(";")) + mc_library_dirs: list[str] = [] + if mc_lib: + mc_library_dirs.extend(mc_lib.split(";")) + return [ + cpp_extension.CppExtension( + "lmcache.lmcache_mooncake", + sources=mooncake_sources, + include_dirs=mc_include_dirs, + library_dirs=mc_library_dirs, + libraries=["store"], + runtime_library_dirs=mc_library_dirs, + extra_compile_args={ + "cxx": extra_cxx_flags + ["-O3", "-std=c++20", "-DYLT_ENABLE_IBV"], + }, + ), + ] + + def cuda_extension() -> tuple[list, dict]: # Third Party from torch.utils import cpp_extension # Import here @@ -96,6 +139,10 @@ def cuda_extension() -> tuple[list, dict]: "csrc/storage_backends/fs/pybind.cpp", "csrc/storage_backends/fs/connector.cpp", ] + mooncake_sources = [ + "csrc/storage_backends/mooncake/pybind.cpp", + "csrc/storage_backends/mooncake/connector.cpp", + ] ext_modules = [ cpp_extension.CUDAExtension( "lmcache.c_ops", @@ -130,6 +177,10 @@ def cuda_extension() -> tuple[list, dict]: }, ), ] + # Mooncake extension is optional. + ext_modules.extend( + _mooncake_extension(cpp_extension, mooncake_sources, [flag_cxx_abi]) + ) cmdclass = {"build_ext": cpp_extension.BuildExtension} return ext_modules, cmdclass @@ -165,6 +216,10 @@ def rocm_extension() -> tuple[list, dict]: "csrc/storage_backends/fs/pybind.cpp", "csrc/storage_backends/fs/connector.cpp", ] + mooncake_sources = [ + "csrc/storage_backends/mooncake/pybind.cpp", + "csrc/storage_backends/mooncake/connector.cpp", + ] # For HIP, we generally use CppExtension and let hipcc handle things. 
# Ensure CXX environment variable is set to hipcc when running this build. # e.g., CXX=hipcc python setup.py install @@ -221,6 +276,8 @@ def rocm_extension() -> tuple[list, dict]: }, ), ] + # Mooncake extension is optional. + ext_modules.extend(_mooncake_extension(cpp_extension, mooncake_sources, [])) cmdclass = {"build_ext": cpp_extension.BuildExtension} return ext_modules, cmdclass diff --git a/tests/v1/distributed/test_mooncake_store_l2_adapter.py b/tests/v1/distributed/test_mooncake_store_l2_adapter.py new file mode 100644 index 0000000000..2dca0bdaa3 --- /dev/null +++ b/tests/v1/distributed/test_mooncake_store_l2_adapter.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for MooncakeStoreL2AdapterConfig and factory registration. + +Integration tests require the C++ Mooncake extension and a running +Mooncake Store service. They are skipped automatically when the +extension is not available. +""" + +# Standard +import os +import select + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.distributed.api import ObjectKey +from lmcache.v1.distributed.l2_adapters.config import ( + get_registered_l2_adapter_types, + get_type_name_for_config, +) +from lmcache.v1.distributed.l2_adapters.factory import ( + create_l2_adapter_from_registry, +) +from lmcache.v1.distributed.l2_adapters.mooncake_store_l2_adapter import ( + MooncakeStoreL2AdapterConfig, +) +from lmcache.v1.memory_management import ( + MemoryFormat, + MemoryObjMetadata, + TensorMemoryObj, +) + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _native_mooncake_available() -> bool: + """Check if the C++ Mooncake extension can be imported.""" + try: + # First Party + from lmcache.lmcache_mooncake import LMCacheMooncakeClient # noqa: F401 + + return True + except ImportError: + return False + + +requires_mooncake = pytest.mark.skipif( + 
not _native_mooncake_available(), + reason="C++ Mooncake extension (lmcache_mooncake) not available", +) + + +def create_object_key(chunk_id: int, model_name: str = "test_model") -> ObjectKey: + return ObjectKey( + chunk_hash=ObjectKey.IntHash2Bytes(chunk_id), + model_name=model_name, + kv_rank=0, + ) + + +def create_memory_obj(size: int = 256, fill_value: float = 1.0) -> TensorMemoryObj: + raw_data = torch.empty(size, dtype=torch.float32) + raw_data.fill_(fill_value) + metadata = MemoryObjMetadata( + shape=torch.Size([size]), + dtype=torch.float32, + address=0, + phy_size=size * 4, + fmt=MemoryFormat.KV_2LTD, + ref_count=1, + ) + return TensorMemoryObj(raw_data, metadata, parent_allocator=None) + + +def wait_for_event_fd(event_fd: int, timeout: float = 10.0) -> bool: + poll = select.poll() + poll.register(event_fd, select.POLLIN) + events = poll.poll(timeout * 1000) + if events: + try: + os.eventfd_read(event_fd) + except BlockingIOError: + pass + return True + return False + + +# ============================================================================= +# Config Unit Tests (no C++ extension needed) +# ============================================================================= + + +class TestMooncakeStoreL2AdapterConfig: + """Unit tests for MooncakeStoreL2AdapterConfig.""" + + def test_from_dict_minimal(self): + """Minimal dict with only mooncake keys should work.""" + d = { + "type": "mooncake_store", + "local_hostname": "192.168.1.1", + "metadata_server": "etcd://localhost:2379", + "global_segment_size": "3221225472", + "local_buffer_size": "1073741824", + "protocol": "tcp", + } + config = MooncakeStoreL2AdapterConfig.from_dict(d) + + # LMCache-only keys should be stripped + assert "type" not in config.setup_config + + # Mooncake keys should be forwarded as strings + assert config.setup_config["local_hostname"] == "192.168.1.1" + assert config.setup_config["metadata_server"] == "etcd://localhost:2379" + assert config.setup_config["protocol"] == "tcp" + + # 
Default num_workers + assert config.num_workers == 4 + + def test_from_dict_with_num_workers(self): + """num_workers should be parsed and excluded from setup_config.""" + d = { + "type": "mooncake_store", + "num_workers": 8, + "local_hostname": "10.0.0.1", + } + config = MooncakeStoreL2AdapterConfig.from_dict(d) + + assert config.num_workers == 8 + assert "num_workers" not in config.setup_config + assert config.setup_config["local_hostname"] == "10.0.0.1" + + def test_from_dict_strips_lmcache_only_keys(self): + """LMCache-only keys (type, num_workers, eviction) should + not appear in setup_config.""" + d = { + "type": "mooncake_store", + "num_workers": 2, + "eviction": "lru", + "local_hostname": "host1", + } + config = MooncakeStoreL2AdapterConfig.from_dict(d) + + assert "type" not in config.setup_config + assert "num_workers" not in config.setup_config + assert "eviction" not in config.setup_config + assert config.setup_config["local_hostname"] == "host1" + + def test_from_dict_converts_values_to_str(self): + """Non-string values should be converted to strings.""" + d = { + "type": "mooncake_store", + "global_segment_size": 3221225472, + "local_buffer_size": 1073741824, + } + config = MooncakeStoreL2AdapterConfig.from_dict(d) + + assert config.setup_config["global_segment_size"] == "3221225472" + assert config.setup_config["local_buffer_size"] == "1073741824" + + def test_from_dict_skips_none_values(self): + """Keys with None values should be excluded from setup_config.""" + d = { + "type": "mooncake_store", + "local_hostname": "host1", + "optional_key": None, + } + config = MooncakeStoreL2AdapterConfig.from_dict(d) + + assert "optional_key" not in config.setup_config + assert config.setup_config["local_hostname"] == "host1" + + def test_from_dict_invalid_num_workers_zero(self): + """num_workers=0 should raise ValueError.""" + d = {"type": "mooncake_store", "num_workers": 0} + with pytest.raises(ValueError, match="num_workers"): + 
MooncakeStoreL2AdapterConfig.from_dict(d) + + def test_from_dict_invalid_num_workers_negative(self): + """Negative num_workers should raise ValueError.""" + d = {"type": "mooncake_store", "num_workers": -1} + with pytest.raises(ValueError, match="num_workers"): + MooncakeStoreL2AdapterConfig.from_dict(d) + + def test_from_dict_invalid_num_workers_string(self): + """Non-integer num_workers should raise ValueError.""" + d = {"type": "mooncake_store", "num_workers": "four"} + with pytest.raises(ValueError, match="num_workers"): + MooncakeStoreL2AdapterConfig.from_dict(d) + + def test_constructor_copies_setup_config(self): + """Constructor should copy the setup_config dict.""" + original = {"key": "value"} + config = MooncakeStoreL2AdapterConfig(setup_config=original) + + # Mutating the original should not affect the config + original["key"] = "changed" + assert config.setup_config["key"] == "value" + + def test_help_returns_string(self): + """help() should return a non-empty string.""" + h = MooncakeStoreL2AdapterConfig.help() + assert isinstance(h, str) + assert len(h) > 0 + + +# ============================================================================= +# Factory Registration Tests (no C++ extension needed) +# ============================================================================= + + +class TestMooncakeStoreRegistration: + """Tests for factory and config type registration.""" + + def test_mooncake_store_type_registered(self): + """'mooncake_store' should be in the registered adapter types.""" + assert "mooncake_store" in get_registered_l2_adapter_types() + + def test_config_type_name(self): + """get_type_name_for_config should return 'mooncake_store'.""" + config = MooncakeStoreL2AdapterConfig(setup_config={}) + name = get_type_name_for_config(config) + assert name == "mooncake_store" + + def test_factory_raises_without_extension(self): + """Factory should raise RuntimeError when C++ extension + is not available.""" + if _native_mooncake_available(): + 
pytest.skip("C++ Mooncake extension is available") + + config = MooncakeStoreL2AdapterConfig( + setup_config={"local_hostname": "localhost"}, + num_workers=2, + ) + with pytest.raises(RuntimeError, match="Mooncake"): + create_l2_adapter_from_registry(config) + + +# ============================================================================= +# Integration Tests (require C++ Mooncake extension + running service) +# ============================================================================= + +# Mooncake service connection params from environment +MOONCAKE_LOCAL_HOSTNAME = os.environ.get("MOONCAKE_LOCAL_HOSTNAME", "") +MOONCAKE_METADATA_SERVER = os.environ.get( + "MOONCAKE_METADATA_SERVER", "etcd://localhost:2379" +) + +requires_mooncake_service = pytest.mark.skipif( + not _native_mooncake_available() or not MOONCAKE_LOCAL_HOSTNAME, + reason=("C++ Mooncake extension not available or MOONCAKE_LOCAL_HOSTNAME not set"), +) + + +@requires_mooncake_service +class TestMooncakeStoreIntegration: + """Integration tests using real Mooncake Store service. + + These tests require: + 1. The C++ Mooncake extension (lmcache_mooncake) to be built + 2. A running Mooncake Store service + 3. 
MOONCAKE_LOCAL_HOSTNAME environment variable set + + Set environment variables before running: + export MOONCAKE_LOCAL_HOSTNAME= + export MOONCAKE_METADATA_SERVER=etcd://:2379 + """ + + @pytest.fixture(autouse=True) + def setup_adapter(self): + # First Party + from lmcache.v1.distributed.l2_adapters import create_l2_adapter + + config = MooncakeStoreL2AdapterConfig.from_dict( + { + "type": "mooncake_store", + "local_hostname": MOONCAKE_LOCAL_HOSTNAME, + "metadata_server": MOONCAKE_METADATA_SERVER, + "num_workers": 2, + } + ) + self.adapter = create_l2_adapter(config) + yield + self.adapter.close() + + def test_event_fds_are_distinct(self): + """Each operation should have a distinct event fd.""" + fds = { + self.adapter.get_store_event_fd(), + self.adapter.get_lookup_and_lock_event_fd(), + self.adapter.get_load_event_fd(), + } + assert len(fds) == 3 + + def test_store_and_lookup(self): + """Store objects, then verify lookup finds them.""" + keys = [create_object_key(i) for i in range(5)] + objs = [create_memory_obj(size=64, fill_value=float(i)) for i in range(5)] + + store_fd = self.adapter.get_store_event_fd() + lookup_fd = self.adapter.get_lookup_and_lock_event_fd() + + # Store + store_tid = self.adapter.submit_store_task(keys, objs) + assert wait_for_event_fd(store_fd) + completed = self.adapter.pop_completed_store_tasks() + assert completed[store_tid] is True + + # Lookup all — should find everything + lookup_tid = self.adapter.submit_lookup_and_lock_task(keys) + assert wait_for_event_fd(lookup_fd) + bitmap = self.adapter.query_lookup_and_lock_result(lookup_tid) + assert bitmap is not None + for i in range(5): + assert bitmap.test(i) is True, f"Key {i} not found in lookup" + + # Unlock + self.adapter.submit_unlock(keys) + + def test_lookup_nonexistent_keys(self): + """Lookup for keys not stored should return all zeros.""" + keys = [create_object_key(i + 10000) for i in range(3)] + lookup_fd = self.adapter.get_lookup_and_lock_event_fd() + + lookup_tid = 
self.adapter.submit_lookup_and_lock_task(keys) + assert wait_for_event_fd(lookup_fd) + bitmap = self.adapter.query_lookup_and_lock_result(lookup_tid) + assert bitmap is not None + for i in range(3): + assert bitmap.test(i) is False + + def test_full_store_lookup_load_workflow(self): + """End-to-end: store -> lookup -> load, verify data integrity.""" + key = create_object_key(42) + store_obj = create_memory_obj(size=512, fill_value=3.14) + load_obj = create_memory_obj(size=512, fill_value=0.0) + + store_fd = self.adapter.get_store_event_fd() + lookup_fd = self.adapter.get_lookup_and_lock_event_fd() + load_fd = self.adapter.get_load_event_fd() + + # Store + store_tid = self.adapter.submit_store_task([key], [store_obj]) + assert wait_for_event_fd(store_fd) + assert self.adapter.pop_completed_store_tasks()[store_tid] is True + + # Lookup + lookup_tid = self.adapter.submit_lookup_and_lock_task([key]) + assert wait_for_event_fd(lookup_fd) + bitmap = self.adapter.query_lookup_and_lock_result(lookup_tid) + assert bitmap.test(0) is True + + # Load + load_tid = self.adapter.submit_load_task([key], [load_obj]) + assert wait_for_event_fd(load_fd) + bitmap = self.adapter.query_load_result(load_tid) + assert bitmap.test(0) is True + + # Verify data integrity + assert torch.allclose(load_obj.tensor, store_obj.tensor), ( + "Loaded data does not match stored data" + ) + + # Unlock + self.adapter.submit_unlock([key]) + + def test_batch_store_lookup_load(self): + """Batch workflow with multiple objects.""" + n = 10 + keys = [create_object_key(i + 100) for i in range(n)] + store_objs = [ + create_memory_obj(size=128, fill_value=float(i * 7)) for i in range(n) + ] + load_objs = [create_memory_obj(size=128, fill_value=0.0) for _ in range(n)] + + store_fd = self.adapter.get_store_event_fd() + lookup_fd = self.adapter.get_lookup_and_lock_event_fd() + load_fd = self.adapter.get_load_event_fd() + + # Store all + store_tid = self.adapter.submit_store_task(keys, store_objs) + assert 
wait_for_event_fd(store_fd) + assert self.adapter.pop_completed_store_tasks()[store_tid] is True + + # Lookup all + lookup_tid = self.adapter.submit_lookup_and_lock_task(keys) + assert wait_for_event_fd(lookup_fd) + bitmap = self.adapter.query_lookup_and_lock_result(lookup_tid) + for i in range(n): + assert bitmap.test(i) is True + + # Load all + load_tid = self.adapter.submit_load_task(keys, load_objs) + assert wait_for_event_fd(load_fd) + bitmap = self.adapter.query_load_result(load_tid) + for i in range(n): + assert bitmap.test(i) is True + assert torch.allclose(load_objs[i].tensor, store_objs[i].tensor), ( + f"Data mismatch for key {i}" + ) + + self.adapter.submit_unlock(keys) + + def test_mixed_lookup_existing_and_missing(self): + """Lookup a mix of stored and non-stored keys.""" + stored_keys = [create_object_key(i + 200) for i in range(3)] + stored_objs = [create_memory_obj(fill_value=float(i)) for i in range(3)] + + store_fd = self.adapter.get_store_event_fd() + lookup_fd = self.adapter.get_lookup_and_lock_event_fd() + + # Store first 3 + self.adapter.submit_store_task(stored_keys, stored_objs) + assert wait_for_event_fd(store_fd) + self.adapter.pop_completed_store_tasks() + + # Lookup 5 keys (3 stored + 2 missing) + all_keys = stored_keys + [ + create_object_key(10100), + create_object_key(10101), + ] + lookup_tid = self.adapter.submit_lookup_and_lock_task(all_keys) + assert wait_for_event_fd(lookup_fd) + bitmap = self.adapter.query_lookup_and_lock_result(lookup_tid) + + for i in range(3): + assert bitmap.test(i) is True, f"Stored key {i} should be found" + assert bitmap.test(3) is False, "Missing key should not be found" + assert bitmap.test(4) is False, "Missing key should not be found" + + self.adapter.submit_unlock(stored_keys) + + def test_factory_creates_adapter(self): + """Verify the factory can create a Mooncake Store L2 adapter.""" + # First Party + from lmcache.v1.distributed.l2_adapters import create_l2_adapter + + config = 
MooncakeStoreL2AdapterConfig.from_dict( + { + "type": "mooncake_store", + "local_hostname": MOONCAKE_LOCAL_HOSTNAME, + "metadata_server": MOONCAKE_METADATA_SERVER, + "num_workers": 2, + } + ) + adapter = create_l2_adapter(config) + try: + # Should have valid event fds + assert adapter.get_store_event_fd() >= 0 + assert adapter.get_lookup_and_lock_event_fd() >= 0 + assert adapter.get_load_event_fd() >= 0 + finally: + adapter.close() diff --git a/tests/v1/distributed/test_native_connector_l2_adapter.py b/tests/v1/distributed/test_native_connector_l2_adapter.py index 402c51e656..55d83c73ad 100644 --- a/tests/v1/distributed/test_native_connector_l2_adapter.py +++ b/tests/v1/distributed/test_native_connector_l2_adapter.py @@ -108,6 +108,22 @@ def submit_batch_exists(self, keys: list[str]) -> int: return fid + def submit_batch_delete(self, keys: list[str]) -> int: + with self._lock: + fid = self._next_id + self._next_id += 1 + + results = [] + for key in keys: + if key in self._store: + del self._store[key] + results.append(True) + else: + results.append(False) + self._push_completion(fid, True, "", results) + + return fid + def drain_completions(self) -> list[tuple[int, bool, str, list[bool] | None]]: # Drain the eventfd try: @@ -946,3 +962,214 @@ def test_type_name_lookup(self): base_path="/tmp/test", ) assert get_type_name_for_config(cfg) == "fs_native" + + +# ============================================================================= +# Delete Interface Tests +# ============================================================================= + + +class TestDeleteInterface: + def test_delete_existing_key(self, adapter): + key = create_object_key(1) + obj = create_memory_obj() + store_fd = adapter.get_store_event_fd() + lookup_fd = adapter.get_lookup_and_lock_event_fd() + + # Store + adapter.submit_store_task([key], [obj]) + wait_for_event_fd(store_fd, timeout=5.0) + adapter.pop_completed_store_tasks() + + # Verify exists + task_id = 
adapter.submit_lookup_and_lock_task([key]) + wait_for_event_fd(lookup_fd, timeout=5.0) + bitmap = adapter.query_lookup_and_lock_result(task_id) + assert bitmap.test(0) is True + adapter.submit_unlock([key]) + + # Delete (synchronous) + adapter.delete([key]) + + # Verify gone + task_id = adapter.submit_lookup_and_lock_task([key]) + wait_for_event_fd(lookup_fd, timeout=5.0) + bitmap = adapter.query_lookup_and_lock_result(task_id) + assert bitmap.test(0) is False + + def test_delete_nonexistent_key(self, adapter): + key = create_object_key(999) + adapter.delete([key]) # should not raise + + def test_delete_empty_keys(self, adapter): + adapter.delete([]) # should not raise + + def test_delete_batch(self, adapter): + keys = [create_object_key(i) for i in range(5)] + objs = [create_memory_obj(fill_value=float(i)) for i in range(5)] + store_fd = adapter.get_store_event_fd() + lookup_fd = adapter.get_lookup_and_lock_event_fd() + + # Store all + adapter.submit_store_task(keys, objs) + wait_for_event_fd(store_fd, timeout=5.0) + adapter.pop_completed_store_tasks() + + # Delete first 3 + adapter.delete(keys[:3]) + + # Verify: first 3 gone, last 2 remain + task_id = adapter.submit_lookup_and_lock_task(keys) + wait_for_event_fd(lookup_fd, timeout=5.0) + bitmap = adapter.query_lookup_and_lock_result(task_id) + for i in range(3): + assert bitmap.test(i) is False + for i in range(3, 5): + assert bitmap.test(i) is True + adapter.submit_unlock(keys[3:]) + + +# ============================================================================= +# Delete Backward Compatibility Tests +# ============================================================================= + + +class TestDeleteBackwardCompatibility: + def test_delete_noop_without_submit_batch_delete(self): + """Connector without submit_batch_delete => delete is no-op.""" + + class NoDeleteConnector: + """Mock connector that only has the 6 original methods.""" + + def __init__(self): + self._efd = os.eventfd(0, os.EFD_NONBLOCK | 
os.EFD_CLOEXEC) + self._closed = False + + def event_fd(self) -> int: + return self._efd + + def submit_batch_get(self, keys, memoryviews): + return 0 + + def submit_batch_set(self, keys, memoryviews): + return 0 + + def submit_batch_exists(self, keys): + return 0 + + def drain_completions(self): + return [] + + def close(self): + if not self._closed: + self._closed = True + os.close(self._efd) + + client = NoDeleteConnector() + adp = NativeConnectorL2Adapter(client) + try: + key = create_object_key(1) + adp.delete([key]) # should not raise, just no-op + finally: + adp.close() + + +# ============================================================================= +# Usage Tracking Tests +# ============================================================================= + + +@pytest.fixture +def adapter_with_capacity(): + """Adapter with max_capacity_gb set for usage tracking tests.""" + mock_client = MockNativeConnector() + # 100 floats * 4 bytes = 400 bytes per obj; capacity = 2000 bytes = 2000/1024^3 GB + adp = NativeConnectorL2Adapter(mock_client, max_capacity_gb=2000 / (1024**3)) + yield adp + adp.close() + + +class TestUsageTracking: + def test_get_usage_without_capacity(self, adapter): + """Without max_capacity_bytes, get_usage returns (-1, -1).""" + usage = adapter.get_usage() + assert usage == (-1.0, -1.0) + + def test_get_usage_starts_at_zero(self, adapter_with_capacity): + usage, _ = adapter_with_capacity.get_usage() + assert usage == 0.0 + + def test_get_usage_after_store(self, adapter_with_capacity): + adp = adapter_with_capacity + store_fd = adp.get_store_event_fd() + + key = create_object_key(1) + obj = create_memory_obj(size=100, fill_value=1.0) # 100 floats = 400 bytes + + adp.submit_store_task([key], [obj]) + wait_for_event_fd(store_fd, timeout=5.0) + adp.pop_completed_store_tasks() + + usage, _ = adp.get_usage() + # 400 bytes / 2000 bytes = 0.2 + assert usage == pytest.approx(0.2) + + def test_get_usage_after_delete(self, adapter_with_capacity): + adp = 
adapter_with_capacity + store_fd = adp.get_store_event_fd() + + key = create_object_key(1) + obj = create_memory_obj(size=100, fill_value=1.0) + + # Store + adp.submit_store_task([key], [obj]) + wait_for_event_fd(store_fd, timeout=5.0) + adp.pop_completed_store_tasks() + + assert adp.get_usage()[0] == pytest.approx(0.2) + + # Delete + adp.delete([key]) + + assert adp.get_usage()[0] == pytest.approx(0.0) + + def test_get_usage_store_delete_cycle(self, adapter_with_capacity): + adp = adapter_with_capacity + store_fd = adp.get_store_event_fd() + + # Store 3 objects (3 * 400 = 1200 bytes) + keys = [create_object_key(i) for i in range(3)] + objs = [create_memory_obj(size=100, fill_value=float(i)) for i in range(3)] + + adp.submit_store_task(keys, objs) + wait_for_event_fd(store_fd, timeout=5.0) + adp.pop_completed_store_tasks() + + usage, _ = adp.get_usage() + assert usage == pytest.approx(1200 / 2000) + + # Delete 2 + adp.delete(keys[:2]) + + usage, _ = adp.get_usage() + assert usage == pytest.approx(400 / 2000) + + def test_idempotent_store_no_double_count(self, adapter_with_capacity): + adp = adapter_with_capacity + store_fd = adp.get_store_event_fd() + + key = create_object_key(1) + obj = create_memory_obj(size=100, fill_value=1.0) + + # Store same key twice + adp.submit_store_task([key], [obj]) + wait_for_event_fd(store_fd, timeout=5.0) + adp.pop_completed_store_tasks() + + adp.submit_store_task([key], [obj]) + wait_for_event_fd(store_fd, timeout=5.0) + adp.pop_completed_store_tasks() + + # Should only count once + usage, _ = adp.get_usage() + assert usage == pytest.approx(0.2) diff --git a/tests/v1/distributed/test_resp_l2_adapter_integration.py b/tests/v1/distributed/test_resp_l2_adapter_integration.py index 2d33a243b4..2bd60bcb0b 100644 --- a/tests/v1/distributed/test_resp_l2_adapter_integration.py +++ b/tests/v1/distributed/test_resp_l2_adapter_integration.py @@ -274,7 +274,7 @@ def test_factory_creates_adapter(self): """Verify the factory can create a RESP L2 
adapter from config.""" # First Party from lmcache.v1.distributed.l2_adapters import create_l2_adapter - from lmcache.v1.distributed.l2_adapters.native_connector_l2_adapter import ( + from lmcache.v1.distributed.l2_adapters.resp_l2_adapter import ( RESPL2AdapterConfig, ) diff --git a/tests/v1/mp_observability/test_event_bus.py b/tests/v1/mp_observability/test_event_bus.py index 3ee1b4f9f3..45e26a63aa 100644 --- a/tests/v1/mp_observability/test_event_bus.py +++ b/tests/v1/mp_observability/test_event_bus.py @@ -334,3 +334,60 @@ def test_init_with_none_uses_defaults(self): bus = init_event_bus() assert bus._config.enabled is True assert bus._config.max_queue_size == 10_000 + + +# --------------------------------------------------------------------------- +# Block allocation event +# --------------------------------------------------------------------------- + + +class TestBlockAllocationEvent: + def test_publish_block_allocation_event(self, bus): + """Verify MP_VLLM_BLOCK_ALLOCATION events are delivered to subscribers.""" + sub = _RecordingSubscriber(event_types=[EventType.MP_VLLM_BLOCK_ALLOCATION]) + bus.register_subscriber(sub) + bus.start() + + # First Party + from lmcache.v1.multiprocess.custom_types import BlockAllocationRecord + + records = [ + BlockAllocationRecord( + req_id="req-1", + new_block_ids=[0, 1, 2], + new_token_ids=[10, 20, 30], + ), + ] + bus.publish( + _make_event( + event_type=EventType.MP_VLLM_BLOCK_ALLOCATION, + session_id="", + records=records, + ) + ) + time.sleep(0.15) + bus.stop() + + assert len(sub.events) == 1 + evt = sub.events[0] + assert evt.event_type == EventType.MP_VLLM_BLOCK_ALLOCATION + assert len(evt.metadata["records"]) == 1 + assert evt.metadata["records"][0].req_id == "req-1" + assert evt.metadata["records"][0].new_block_ids == [0, 1, 2] + + def test_block_allocation_not_delivered_to_other_subscriber(self, bus): + """Verify block allocation events are not delivered to unrelated subscribers.""" + sub = 
_RecordingSubscriber(event_types=[EventType.L1_READ_FINISHED]) + bus.register_subscriber(sub) + bus.start() + + bus.publish( + _make_event( + event_type=EventType.MP_VLLM_BLOCK_ALLOCATION, + session_id="", + ) + ) + time.sleep(0.15) + bus.stop() + + assert len(sub.events) == 0 diff --git a/tests/v1/multiprocess/test_custom_types.py b/tests/v1/multiprocess/test_custom_types.py index aefe2396f3..30f102de5e 100644 --- a/tests/v1/multiprocess/test_custom_types.py +++ b/tests/v1/multiprocess/test_custom_types.py @@ -10,6 +10,7 @@ # First Party from lmcache.v1.multiprocess.custom_types import ( + BlockAllocationRecord, CudaIPCWrapper, IPCCacheEngineKey, get_customized_decoder, @@ -219,3 +220,45 @@ def test_cudaipc_wrapper_multiprocess_serialization(): f"Tensor {i}: post-modification checksum mismatch. " f"Expected {new_expected_checksum}, got {actual_checksum}" ) + + +def test_block_allocation_record_serialization(): + """Test encoding and decoding of BlockAllocationRecord using msgspec.""" + original = BlockAllocationRecord( + req_id="req-42", + new_block_ids=[10, 20, 30], + new_token_ids=[100, 200, 300, 400], + ) + + encoded = msgspec.msgpack.encode(original) + decoded = msgspec.msgpack.decode(encoded, type=BlockAllocationRecord) + + assert decoded.req_id == original.req_id + assert decoded.new_block_ids == original.new_block_ids + assert decoded.new_token_ids == original.new_token_ids + + +def test_block_allocation_record_list_serialization(): + """Test encoding and decoding of a list of BlockAllocationRecord.""" + records = [ + BlockAllocationRecord( + req_id="req-1", + new_block_ids=[1, 2], + new_token_ids=[10, 20, 30], + ), + BlockAllocationRecord( + req_id="req-2", + new_block_ids=[], + new_token_ids=[40, 50], + ), + ] + + encoded = msgspec.msgpack.encode(records) + decoded = msgspec.msgpack.decode(encoded, type=list[BlockAllocationRecord]) + + assert len(decoded) == 2 + assert decoded[0].req_id == "req-1" + assert decoded[0].new_block_ids == [1, 2] + assert 
decoded[1].req_id == "req-2" + assert decoded[1].new_block_ids == [] + assert decoded[1].new_token_ids == [40, 50] diff --git a/tests/v1/multiprocess/test_mq.py b/tests/v1/multiprocess/test_mq.py index 3dd1e36871..fce8763852 100644 --- a/tests/v1/multiprocess/test_mq.py +++ b/tests/v1/multiprocess/test_mq.py @@ -12,7 +12,11 @@ import zmq # First Party -from lmcache.v1.multiprocess.custom_types import CudaIPCWrapper, IPCCacheEngineKey +from lmcache.v1.multiprocess.custom_types import ( + BlockAllocationRecord, + CudaIPCWrapper, + IPCCacheEngineKey, +) from lmcache.v1.multiprocess.mq import ( BlockingRequestHandler, MessageQueueClient, @@ -533,6 +537,59 @@ def test_mq_lookup_with_different_key(): ) +def test_mq_report_block_allocation(): + """ + Test MessageQueue with REPORT_BLOCK_ALLOCATION request type. + REPORT_BLOCK_ALLOCATION takes (records: list[BlockAllocationRecord]) + and returns None. + """ + records = [ + BlockAllocationRecord( + req_id="req-1", + new_block_ids=[0, 1, 2], + new_token_ids=[100, 200, 300], + ), + BlockAllocationRecord( + req_id="req-2", + new_block_ids=[3, 4], + new_token_ids=[400, 500], + ), + ] + + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5566") + helper.register_handler( + RequestType.REPORT_BLOCK_ALLOCATION, + test_mq_handler_helpers.report_block_allocations_handler, + ) + + helper.run_test( + request_type=RequestType.REPORT_BLOCK_ALLOCATION, + payloads=[records], + expected_response=None, + num_requests=1, + ) + + +def test_mq_report_block_allocation_empty(): + """ + Test REPORT_BLOCK_ALLOCATION with an empty records list. 
+ """ + records: list[BlockAllocationRecord] = [] + + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5567") + helper.register_handler( + RequestType.REPORT_BLOCK_ALLOCATION, + test_mq_handler_helpers.report_block_allocations_handler, + ) + + helper.run_test( + request_type=RequestType.REPORT_BLOCK_ALLOCATION, + payloads=[records], + expected_response=None, + num_requests=1, + ) + + # ============================================================================== # Thread Pool Tests # ============================================================================== diff --git a/tests/v1/multiprocess/test_mq_handler_helpers.py b/tests/v1/multiprocess/test_mq_handler_helpers.py index 241fb7af23..df083be463 100644 --- a/tests/v1/multiprocess/test_mq_handler_helpers.py +++ b/tests/v1/multiprocess/test_mq_handler_helpers.py @@ -8,7 +8,7 @@ # First Party from lmcache.v1.gpu_connector.utils import LayoutHints -from lmcache.v1.multiprocess.custom_types import KVCache +from lmcache.v1.multiprocess.custom_types import BlockAllocationRecord, KVCache from lmcache.v1.multiprocess.protocol import KeyType # ============================================================================== @@ -201,3 +201,33 @@ def free_locks_handler(key: KeyType, tp_size: int) -> None: """ assert isinstance(key, KeyType), f"Expected key to be KeyType, got {type(key)}" assert isinstance(tp_size, int), f"Expected tp_size to be int, got {type(tp_size)}" + + +# ============================================================================== +# REPORT_BLOCK_ALLOCATION Request Handlers +# ============================================================================== + + +def report_block_allocations_handler( + records: list[BlockAllocationRecord], +) -> None: + """ + Dummy handler for REPORT_BLOCK_ALLOCATION requests. + + Args: + records: List of BlockAllocationRecord with per-request + block and token allocation deltas. 
+ + Returns: + None + """ + assert isinstance(records, list), ( + f"Expected records to be list, got {type(records)}" + ) + for rec in records: + assert isinstance(rec, BlockAllocationRecord), ( + f"Expected BlockAllocationRecord, got {type(rec)}" + ) + assert isinstance(rec.req_id, str) + assert isinstance(rec.new_block_ids, list) + assert isinstance(rec.new_token_ids, list) diff --git a/tests/v1/storage_backend/test_maru_backend.py b/tests/v1/storage_backend/test_maru_backend.py new file mode 100644 index 0000000000..94961e2a62 --- /dev/null +++ b/tests/v1/storage_backend/test_maru_backend.py @@ -0,0 +1,788 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock, patch +import asyncio +import mmap +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import MemoryFormat, TensorMemoryObj +from lmcache.v1.pin_monitor import PinMonitor +from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface +from tests.v1.utils import ( + check_method_signatures, + get_abstract_methods, + get_methods_implemented_in_class, +) + +maru = pytest.importorskip("maru", reason="maru package not installed") +maru_lmcache = pytest.importorskip( + "maru_lmcache", reason="maru_lmcache package not installed" +) + +# Third Party +from maru_handler.memory import AllocHandle # noqa: E402 +from maru_handler.memory.types import MappedRegion, MemoryInfo # noqa: E402 +from maru_lmcache.adapter import CxlMemoryAdapter # noqa: E402 + +# First Party +from lmcache.v1.storage_backend.maru_backend import MaruBackend # noqa: E402 + +# ========================================================================= +# Constants +# ========================================================================= + +TEST_CHUNK_SIZE = 1024 +TEST_DTYPE = torch.float32 +TEST_SHAPE = torch.Size([256]) # 256 * 4B = 
1024 bytes = chunk_size + + +# ========================================================================= +# Helpers +# ========================================================================= + + +def _make_mock_handler(pool_size=4096, chunk_size=TEST_CHUNK_SIZE): + """Create a mock MaruHandler with mmap-backed regions.""" + handler = MagicMock() + handler._connected = True + + region_id = 100 + page_count = pool_size // chunk_size + + mmap_obj = mmap.mmap(-1, pool_size) + mapped_region = MappedRegion( + region_id=region_id, + handle=MagicMock(region_id=region_id, length=pool_size), + size=pool_size, + _mmap_obj=mmap_obj, + ) + + handler.get_buffer_view.side_effect = lambda rid, offset, size: ( + mapped_region.get_buffer_view(offset, size) if rid == region_id else None + ) + handler.get_region_page_count.side_effect = lambda rid: ( + page_count if rid == region_id else None + ) + handler.get_owned_region_ids.return_value = [region_id] + handler.get_chunk_size.return_value = chunk_size + + def mock_set_on_region_added(callback): + if callback is not None: + callback(region_id, page_count) + + handler.set_on_region_added.side_effect = mock_set_on_region_added + + page_counter = [0] + + def mock_alloc(size): + idx = page_counter[0] + page_counter[0] += 1 + buf = mapped_region.get_buffer_view(idx * chunk_size, size) + return AllocHandle(buf=buf, _region_id=region_id, _page_index=idx, _size=size) + + handler.alloc.side_effect = mock_alloc + handler.free = MagicMock() + handler.connect.return_value = True + handler.close.return_value = None + handler.store.return_value = True + handler.batch_store.return_value = None + handler.retrieve.return_value = None + handler.batch_retrieve.return_value = [] + handler.exists.return_value = False + handler.batch_exists.return_value = [] + handler.delete.return_value = True + handler.pin.return_value = True + handler.unpin.return_value = True + handler.batch_pin.return_value = [] + handler.batch_unpin.return_value = None + + return 
handler + + +def _make_cache_key(chunk_hash: int = 12345) -> CacheEngineKey: + """Create a CacheEngineKey for testing.""" + return CacheEngineKey( + model_name="test-model", + world_size=1, + worker_id=0, + chunk_hash=chunk_hash, + dtype=torch.float32, + ) + + +def _make_memory_obj(adapter: CxlMemoryAdapter) -> TensorMemoryObj: + """Allocate a TensorMemoryObj from the adapter.""" + obj = adapter.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + return obj + + +# ========================================================================= +# Fixtures +# ========================================================================= + + +@pytest.fixture(autouse=True) +def _init_pin_monitor(): + """Initialize PinMonitor singleton required by TensorMemoryObj.pin().""" + PinMonitor._instance = None + PinMonitor.GetOrCreate(LMCacheEngineConfig.from_defaults()) + yield + PinMonitor._instance = None + + +@pytest.fixture +def async_loop(): + """Provide an asyncio event loop running in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + yield loop + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5) + loop.close() + + +@pytest.fixture +def mock_handler(): + return _make_mock_handler() + + +@pytest.fixture +def adapter(mock_handler): + return CxlMemoryAdapter( + handler=mock_handler, + shapes=[TEST_SHAPE], + dtypes=[TEST_DTYPE], + fmt=MemoryFormat.KV_2LTD, + chunk_size=TEST_CHUNK_SIZE, + ) + + +@pytest.fixture +def backend(mock_handler, adapter, async_loop): + """Create a MaruBackend with mocked internals.""" + # Local + + with patch.object(MaruBackend, "initialize_allocator", return_value=adapter): + backend = MaruBackend.__new__(MaruBackend) + backend.dst_device = "cpu" + backend.config = MagicMock() + backend.config.maru_pool_size = 4.0 + backend.loop = async_loop + backend.memory_allocator = adapter + backend._handler = mock_handler + + backend._full_chunk_size_bytes = 
TEST_CHUNK_SIZE + backend._single_token_size = TEST_CHUNK_SIZE // 256 # 4 bytes per token + backend._mla_worker_id_as0_mode = False + + backend.put_lock = threading.Lock() + backend.put_tasks = set() + return backend + + +def _run_async(loop, coro): + """Submit a coroutine to a running event loop and wait for result.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=5) + + +# ========================================================================= +# Tests — Init & Interface Compliance +# ========================================================================= + + +class TestMaruBackendInit: + def test_str(self, backend): + assert str(backend) == "MaruBackend" + + def test_get_allocator_backend_returns_self(self, backend): + assert backend.get_allocator_backend() is backend + + def test_get_memory_allocator_returns_adapter(self, backend, adapter): + assert backend.get_memory_allocator() is adapter + + +class TestMaruBackendPoolSizeGbToBytes: + """Test _pool_size_gb_to_bytes static method.""" + + def test_4gb(self): + assert MaruBackend._pool_size_gb_to_bytes(4.0) == 4 * 1024**3 + + def test_half_gb(self): + assert MaruBackend._pool_size_gb_to_bytes(0.5) == 512 * 1024**2 + + def test_1gb(self): + assert MaruBackend._pool_size_gb_to_bytes(1.0) == 1024**3 + + def test_zero(self): + assert MaruBackend._pool_size_gb_to_bytes(0.0) == 0 + + +class TestMaruBackendInterfaceCompliance: + """Verify MaruBackend implements all required interface methods.""" + + def test_implements_all_abstract_methods(self): + abstract = get_abstract_methods(AllocatorBackendInterface) + implemented = get_methods_implemented_in_class( + MaruBackend, AllocatorBackendInterface + ) + missing = abstract - implemented + assert not missing, f"Missing abstract methods: {missing}" + + def test_method_signatures_match(self): + # Known: batched_submit_put_task uses 'memory_objs' instead of 'objs' + # TODO: Rename to 'objs' for full compliance + known_param_renames 
= {"batched_submit_put_task"} + + mismatches = check_method_signatures(AllocatorBackendInterface, MaruBackend) + unexpected = [m for m in mismatches if m["method"] not in known_param_renames] + assert not unexpected, f"Signature mismatches: {unexpected}" + + +# ========================================================================= +# Tests — Allocate +# ========================================================================= + + +class TestMaruBackendAllocate: + def test_allocate_returns_memory_obj(self, backend): + obj = backend.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + assert obj.tensor is not None + assert obj.metadata.dtype == TEST_DTYPE + + def test_batched_allocate_returns_list(self, backend): + objs = backend.batched_allocate(TEST_SHAPE, TEST_DTYPE, batch_size=3) + assert objs is not None + assert len(objs) == 3 + for obj in objs: + assert obj.tensor is not None + + +# ========================================================================= +# Tests — Put (async) +# ========================================================================= + + +class TestMaruBackendPut: + def test_submit_put_task_returns_future(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + future = backend.submit_put_task(key, obj) + assert future is not None + future.result(timeout=5) + + backend._handler.store.assert_called_once() + + def test_submit_put_task_tracks_in_flight(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + assert not backend.exists_in_put_tasks(key) + + future = backend.submit_put_task(key, obj) + future.result(timeout=5) + + # After completion, key should be removed from put_tasks + assert not backend.exists_in_put_tasks(key) + + def test_exists_in_put_tasks_true_during_store(self, backend, adapter): + """Verify exists_in_put_tasks returns True while store is in progress.""" + obj = _make_memory_obj(adapter) 
+ obj.parent_allocator = None + key = _make_cache_key() + + store_entered = threading.Event() + store_proceed = threading.Event() + + def blocking_store(*args, **kwargs): + store_entered.set() + store_proceed.wait(timeout=5) + return True + + backend._handler.store.side_effect = blocking_store + + future = backend.submit_put_task(key, obj) + + # Wait until store is actually running + assert store_entered.wait(timeout=5) + assert backend.exists_in_put_tasks(key) + + # Let store complete + store_proceed.set() + future.result(timeout=5) + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task(self, backend, adapter): + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + backend._handler.batch_store.return_value = [True, True, True] + + futures = backend.batched_submit_put_task(keys, objs) + assert futures is not None + + for future in futures: + future.result(timeout=5) + + backend._handler.batch_store.assert_called_once() + + def test_submit_put_calls_callback(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + callback_called = [] + + def callback(k): + callback_called.append(k) + + future = backend.submit_put_task(key, obj, on_complete_callback=callback) + future.result(timeout=5) + + assert len(callback_called) == 1 + assert callback_called[0] == key + + def test_batched_submit_put_calls_callback_per_key(self, backend, adapter): + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + backend._handler.batch_store.return_value = [True, True, True] + callback_keys = [] + + def callback(k): + callback_keys.append(k) + + futures = backend.batched_submit_put_task( + keys, objs, on_complete_callback=callback + ) + for future in futures: + future.result(timeout=5) + + assert 
set(callback_keys) == set(keys) + + def test_submit_put_task_skips_in_mla_mode(self, backend, adapter): + """In MLA worker_id_as0 mode, submit_put_task should skip store.""" + backend._mla_worker_id_as0_mode = True + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + future = backend.submit_put_task(key, obj) + assert future.result(timeout=5) is None + backend._handler.store.assert_not_called() + + def test_submit_put_task_refcount_down_on_failure(self, backend, adapter): + """On store failure, ref_count should return to pre-submit level.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + initial_ref = obj.get_ref_count() + + backend._handler.store.side_effect = RuntimeError("store failed") + + future = backend.submit_put_task(key, obj) + with pytest.raises(RuntimeError): + future.result(timeout=5) + + assert obj.get_ref_count() == initial_ref + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task_refcount_down_on_failure(self, backend, adapter): + """On batch_store failure, ref_count should return to pre-submit level.""" + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + initial_refs = [obj.get_ref_count() for obj in objs] + + backend._handler.batch_store.side_effect = RuntimeError("batch failed") + + futures = backend.batched_submit_put_task(keys, objs) + for future in futures: + with pytest.raises(RuntimeError): + future.result(timeout=5) + + for obj, initial_ref in zip(objs, initial_refs, strict=False): + assert obj.get_ref_count() == initial_ref + for key in keys: + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task_skips_in_mla_mode(self, backend, adapter): + """In MLA worker_id_as0 mode, batched_submit_put_task should skip.""" + backend._mla_worker_id_as0_mode = True + keys = [_make_cache_key(i) for i in range(3)] + 
objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + result = backend.batched_submit_put_task(keys, objs) + assert result is None + backend._handler.batch_store.assert_not_called() + + +# ========================================================================= +# Tests — Get (sync) +# ========================================================================= + + +class TestMaruBackendGet: + def test_get_blocking_hit(self, backend, adapter): + key = _make_cache_key() + + data_size = TEST_CHUNK_SIZE + data = bytearray(data_size) + mock_info = MemoryInfo( + view=memoryview(data), + region_id=100, + page_index=0, + ) + backend._handler.retrieve.return_value = mock_info + + result = backend.get_blocking(key) + assert result is not None + backend._handler.retrieve.assert_called_once() + + def test_get_blocking_miss(self, backend): + key = _make_cache_key() + backend._handler.retrieve.return_value = None + + result = backend.get_blocking(key) + assert result is None + + def test_get_blocking_ref_count_increases(self, backend, adapter): + """After get_blocking, the returned MemoryObj should have ref_count + incremented.""" + # Pre-allocate so pool has page 0 + _make_memory_obj(adapter) + + key = _make_cache_key() + mock_info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=100, + page_index=0, + ) + backend._handler.retrieve.return_value = mock_info + + result = backend.get_blocking(key) + assert result is not None + # Pool objects start with ref_count=1, get_blocking calls ref_count_up + assert result.get_ref_count() >= 2 + + def test_batched_get_blocking(self, backend, adapter): + """batched_get_blocking returns list of MemoryObj via batch_retrieve.""" + objs = [_make_memory_obj(adapter) for _ in range(2)] + keys = [_make_cache_key(i) for i in range(2)] + + infos = [] + for obj in objs: + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + infos.append( + MemoryInfo( + 
view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + ) + backend._handler.batch_retrieve.return_value = infos + + results = backend.batched_get_blocking(keys) + assert len(results) == 2 + for r in results: + assert r is not None + + def test_batched_get_blocking_with_miss(self, backend, adapter): + """batched_get_blocking returns None for missing keys.""" + obj = _make_memory_obj(adapter) + keys = [_make_cache_key(i) for i in range(2)] + + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + backend._handler.batch_retrieve.return_value = [info, None] + + results = backend.batched_get_blocking(keys) + assert len(results) == 2 + assert results[0] is not None + assert results[1] is None + + +# ========================================================================= +# Tests — Contains +# ========================================================================= + + +class TestMaruBackendContains: + def test_contains_true(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = True + + assert backend.contains(key) is True + backend._handler.exists.assert_called_once_with(key.to_string()) + + def test_contains_false(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = False + + assert backend.contains(key) is False + + def test_batched_contains_all_hit(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, True] + + result = backend.batched_contains(keys) + assert result == 3 + + def test_batched_contains_partial_prefix(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, False] + + result = backend.batched_contains(keys) + assert result == 2 + + def test_batched_contains_first_miss(self, backend): + keys = 
[_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [False, True, True] + + result = backend.batched_contains(keys) + assert result == 0 + + def test_contains_with_pin(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = True + + assert backend.contains(key, pin=True) is True + backend._handler.pin.assert_called_once_with(key.to_string()) + backend._handler.exists.assert_not_called() + + def test_contains_with_pin_false(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = False + + assert backend.contains(key, pin=True) is False + + def test_batched_contains_with_pin(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_pin.return_value = [True, True, True] + + result = backend.batched_contains(keys, pin=True) + assert result == 3 + backend._handler.batch_pin.assert_called_once_with( + [k.to_string() for k in keys] + ) + backend._handler.batch_exists.assert_not_called() + + def test_batched_contains_with_pin_partial(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_pin.return_value = [True, False, True] + + result = backend.batched_contains(keys, pin=True) + assert result == 1 + + def test_batched_contains_empty(self, backend): + backend._handler.batch_exists.return_value = [] + assert backend.batched_contains([]) == 0 + + +# ========================================================================= +# Tests — Async Lookup +# ========================================================================= + + +class TestMaruBackendAsyncLookup: + def test_batched_async_contains_all_hit(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, True] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-1", keys) + ) + assert result == 3 + + def test_batched_async_contains_partial_prefix(self, backend, async_loop): + 
keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, False, True] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-2", keys) + ) + assert result == 1 + + def test_batched_async_contains_empty(self, backend, async_loop): + backend._handler.batch_exists.return_value = [] + result = _run_async(async_loop, backend.batched_async_contains("lookup-3", [])) + assert result == 0 + + def test_batched_get_non_blocking_all_hit(self, backend, adapter, async_loop): + keys = [_make_cache_key(i) for i in range(2)] + + objs = [_make_memory_obj(adapter) for _ in range(2)] + infos = [] + for obj in objs: + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + infos.append( + MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + ) + backend._handler.batch_retrieve.return_value = infos + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-4", keys) + ) + assert len(results) == 2 + for obj in results: + assert obj is not None + + def test_batched_get_non_blocking_prefix_stop_on_miss( + self, backend, adapter, async_loop + ): + """Second key is a miss -> only first returned (prefix semantics).""" + keys = [_make_cache_key(i) for i in range(3)] + + obj = _make_memory_obj(adapter) + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + # hit, miss, hit -> should return only [hit] + backend._handler.batch_retrieve.return_value = [info, None, info] + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-5", keys) + ) + assert len(results) == 1 + + def test_batched_get_non_blocking_empty(self, backend, async_loop): + backend._handler.batch_retrieve.return_value = [] + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-6", []) + ) + assert results == [] + + +# 
========================================================================= +# Tests — Pin / Unpin / Remove +# ========================================================================= + + +class TestMaruBackendPinRemove: + def test_pin_delegates_to_handler(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = True + + assert backend.pin(key) is True + backend._handler.pin.assert_called_once_with(key.to_string()) + + def test_pin_returns_false_on_failure(self, backend): + key = _make_cache_key() + backend._handler.pin.return_value = False + + assert backend.pin(key) is False + + def test_unpin_delegates_to_handler(self, backend): + key = _make_cache_key() + backend._handler.unpin.return_value = True + + assert backend.unpin(key) is True + backend._handler.unpin.assert_called_once_with(key.to_string()) + + def test_unpin_returns_false_on_failure(self, backend): + key = _make_cache_key() + backend._handler.unpin.return_value = False + + assert backend.unpin(key) is False + + def test_batched_unpin(self, backend): + keys = [_make_cache_key(i) for i in range(3)] + + backend.batched_unpin(keys) + backend._handler.batch_unpin.assert_called_once_with( + [k.to_string() for k in keys] + ) + + def test_batched_unpin_empty(self, backend): + backend.batched_unpin([]) + backend._handler.batch_unpin.assert_not_called() + + def test_remove_existing_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = True + + result = backend.remove(key) + assert result is True + backend._handler.delete.assert_called_once_with(key.to_string()) + + def test_remove_nonexistent_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = False + + result = backend.remove(key) + assert result is False + + +# ========================================================================= +# Tests — Lifecycle +# ========================================================================= + + +class TestMaruBackendLifecycle: + 
def test_close_calls_handler_and_allocator(self, backend): + backend.memory_allocator = MagicMock() + backend.close() + backend.memory_allocator.close.assert_called_once() + backend._handler.close.assert_called_once() + + def test_close_drains_pending_put_tasks(self, backend, adapter): + """close() should wait for in-flight put tasks to complete.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + # Submit a real put task that will complete via the event loop + future = backend.submit_put_task(key, obj) + future.result(timeout=5) + + # After drain, close should succeed + backend.close() + backend._handler.close.assert_called_once() + + +# ========================================================================= +# Tests — Store Handle Roundtrip +# ========================================================================= + + +class TestMaruBackendStoreHandle: + def test_store_handle_roundtrip(self, backend, adapter): + """AllocHandle from create_store_handle should match original.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + + handle = adapter.create_store_handle(obj) + assert handle.region_id == 100 + assert handle.page_index == 0 + assert handle._size == obj.metadata.phy_size