diff --git a/README.md b/README.md index 57bd587..7ea395f 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ Maru works as a drop-in remote storage backend for [LMCache](https://github.com/ # LMCache config remote_url: "maru://localhost:5555" extra_config: - maru_pool_size: "4G" + maru_pool_size: 4 ``` For details on LMCache integration, see the [documentation](https://xcena-dev.github.io/maru/source/integration/lmcache.html). diff --git a/docs/source/api_reference/api.md b/docs/source/api_reference/api.md index df29e47..2da8c16 100644 --- a/docs/source/api_reference/api.md +++ b/docs/source/api_reference/api.md @@ -43,8 +43,8 @@ with MaruHandler(config) as handler: ```{eval-rst} .. autoclass:: maru_handler.MaruHandler - :members: connect, close, alloc, store, retrieve, exists, delete, - batch_store, batch_retrieve, batch_exists, + :members: connect, close, alloc, free, store, retrieve, exists, pin, unpin, delete, + batch_store, batch_retrieve, batch_exists, batch_pin, batch_unpin, healthcheck, get_stats :noindex: :no-undoc-members: diff --git a/docs/source/design_doc/maru_server.md b/docs/source/design_doc/maru_server.md index 80af668..09e50ad 100644 --- a/docs/source/design_doc/maru_server.md +++ b/docs/source/design_doc/maru_server.md @@ -119,10 +119,14 @@ The server exposes the following message types: | `REGISTER_KV` | Register a KV entry at a given location | | `LOOKUP_KV` | Look up a KV entry's location and handle | | `EXISTS_KV` | Check whether a key exists | +| `PIN_KV` | Atomically check existence and pin a KV entry | +| `UNPIN_KV` | Unpin a KV entry | | `DELETE_KV` | Delete a KV entry | | `BATCH_REGISTER_KV` | Batch register multiple KV entries | | `BATCH_LOOKUP_KV` | Batch look up multiple keys | | `BATCH_EXISTS_KV` | Batch check existence of multiple keys | +| `BATCH_PIN_KV` | Batch check existence and pin multiple entries | +| `BATCH_UNPIN_KV` | Batch unpin multiple entries | | `GET_STATS` | Retrieve server statistics | | `HEARTBEAT` | Connection health check | | `HANDSHAKE` | Reserved — initial client-server handshake | diff --git a/docs/source/design_doc/resource/lmcache_component_arch.png b/docs/source/design_doc/resource/lmcache_component_arch.png index d4748be..4c671d1 100644 Binary files a/docs/source/design_doc/resource/lmcache_component_arch.png and b/docs/source/design_doc/resource/lmcache_component_arch.png differ diff --git a/docs/source/getting_started/examples/lmcache/p2p.md b/docs/source/getting_started/examples/lmcache/p2p.md index 469c487..d77f968 100644 --- a/docs/source/getting_started/examples/lmcache/p2p.md +++ b/docs/source/getting_started/examples/lmcache/p2p.md @@ -19,26 +19,22 @@ Both instances share a single configuration file (`maru-config.yaml`): ```yaml chunk_size: 256 -local_cpu: True -max_local_cpu_size: 5 +local_cpu: False +max_local_cpu_size: 0 enable_async_loading: True enable_p2p: False enable_controller: False -remote_url: "maru://localhost:${MARU_SERVER_PORT}" -remote_serde: "naive" -remote_storage_plugins: ["maru"] +# Maru backend +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" - save_chunk_meta: False lookup_backoff_time: 0.001 ``` -Maru is loaded as an LMCache [remote storage plugin](https://docs.lmcache.ai/developer_guide/extending_lmcache/remote_storage_plugins.html). For details on each configuration field, see {doc}`../../../integration/lmcache`. +Maru is loaded as an LMCache [storage backend](https://docs.lmcache.ai/kv_cache/storage_backends/index.html). For details on each configuration field, see {doc}`../../../integration/lmcache`. ## How to Run diff --git a/docs/source/getting_started/examples/lmcache/pd.md b/docs/source/getting_started/examples/lmcache/pd.md index 1b812d5..5a6fdfa 100644 --- a/docs/source/getting_started/examples/lmcache/pd.md +++ b/docs/source/getting_started/examples/lmcache/pd.md @@ -22,22 +22,18 @@ Both prefiller and decoder use the same configuration: ```yaml enable_pd: False chunk_size: 256 -remote_url: "maru://localhost:${MARU_SERVER_PORT}" -remote_serde: "naive" -remote_storage_plugins: ["maru"] local_cpu: False -max_local_cpu_size: 100 save_unfull_chunk: True +# Maru backend +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" save_chunk_meta: False lookup_backoff_time: 0.001 ``` -Maru is loaded as an LMCache [remote storage plugin](https://docs.lmcache.ai/developer_guide/extending_lmcache/remote_storage_plugins.html). For details on each configuration field, see {doc}`../../../integration/lmcache`. +Maru is loaded as an LMCache [storage backend](https://docs.lmcache.ai/kv_cache/storage_backends/index.html). For details on each configuration field, see {doc}`../../../integration/lmcache`. ## How to Run diff --git a/docs/source/integration/lmcache.md b/docs/source/integration/lmcache.md index 0d9ac52..068a577 100644 --- a/docs/source/integration/lmcache.md +++ b/docs/source/integration/lmcache.md @@ -18,30 +18,32 @@ The full stack from inference engine to shared memory: | Layer | Responsibility | Scope | |-------|---------------|-------| -| **LMCache stack** | Inference engine → CacheEngine → StorageManager → RemoteBackend | LMCache (external) | -| **MaruConnector** | Adapts LMCache's RemoteConnector to MaruHandler's API | Integration boundary | +| **LMCache stack** | Inference engine → CacheEngine → StorageManager → MaruBackend | LMCache (external) | +| **MaruBackend** | LMCache `AllocatorBackendInterface` — allocates directly on CXL, async store, sync get | Integration boundary | +| **CxlMemoryAdapter** | LMCache `MemoryAllocatorInterface` — translates Maru pages to `TensorMemoryObj` pool | Integration boundary | | **MaruHandler** | Client-side KV operations, memory mapping, connection management | Maru client | | **MaruServer** | Central metadata store, memory allocation coordinator | Maru server | -The **integration boundary** sits at MaruConnector. Everything above is LMCache; -everything below is Maru. MaruConnector is the only component that imports from -both projects. +The **integration boundary** sits at MaruBackend + CxlMemoryAdapter. Everything above is LMCache; +everything below is Maru. These two classes are the only components that import from both projects. -## Connector Design +## Backend Design -LMCache defines a `RemoteConnector` interface that all remote storage backends -must implement (`exists`, `get`, `put`, `close`, and batch variants). MaruConnector -implements this interface by delegating to MaruHandler. +### Two-layer integration -**Why the connector pattern:** LMCache's RemoteBackend is designed for pluggable -storage. The same StorageManager can use Redis, S3, Mooncake, or Maru without -any change to the cache engine logic. MaruConnector slots in as one such plugin. - -The key translation between the two APIs involves: +``` +MaruBackend (AllocatorBackendInterface) + ├── CxlMemoryAdapter (MemoryAllocatorInterface) + │ ├── _pool: {region_id: [TensorMemoryObj per page]} + │ └── address encoding: (rid << 32) | pid + └── MaruHandler (Maru client) + ├── RpcClient → MaruServer + ├── DaxMapper (mmap management) + └── OwnedRegionManager (page allocation) +``` -- **Key conversion** — LMCache uses structured `CacheEngineKey` objects; MaruHandler uses string keys (`CacheEngineKey.to_string()`). -- **Zero-copy bridging** — MaruHandler returns `MemoryInfo` (a memoryview wrapper) which the connector wraps as LMCache's `MemoryObj` without copying data. -- **Batch optimization** — The connector maps LMCache's batch operations to MaruHandler's batch RPC calls, reducing round-trip overhead. +**MaruHandler** manages CXL memory (regions, pages, mmap). **CxlMemoryAdapter** translates +pages into LMCache's `TensorMemoryObj` format. ## Data Path @@ -53,21 +55,23 @@ When the inference engine produces new KV cache data: sequenceDiagram participant IE as Inference Engine participant CE as CacheEngine - participant MC as MaruConnector + participant MB as MaruBackend participant MH as MaruHandler participant MS as MaruServer participant CXL as CXL Memory IE->>CE: KV tensors (GPU) - CE->>MC: put(key, MemoryObj) - MC->>MH: alloc(size) - MH-->>MC: handle (page in CXL region) - MC->>CXL: write data via handle buffer (zero-copy) - MC->>MH: store(key, handle) + CE->>MB: allocate(size) + MB->>MH: alloc(size) + MH-->>MB: handle (page in CXL region) + MB-->>CE: MemoryObj (CXL-backed) + CE->>CXL: GPU → CXL direct copy (only data copy) + CE->>MB: put(key, MemoryObj) + MB->>MH: store(key, handle) MH->>MS: register_kv(key, region_id, offset, length) MS-->>MH: success - MH-->>MC: True - MC-->>CE: done + MH-->>MB: True + MB-->>CE: done ``` ### Retrieve Path (read) @@ -78,20 +82,19 @@ When the inference engine needs cached KV data: sequenceDiagram participant IE as Inference Engine participant CE as CacheEngine - participant MC as MaruConnector + participant MB as MaruBackend participant MH as MaruHandler participant MS as MaruServer participant CXL as CXL Memory IE->>CE: Request KV for prompt prefix - CE->>MC: get(key) - MC->>MH: retrieve(key) + CE->>MB: get(key) + MB->>MH: retrieve(key) MH->>MS: lookup_kv(key) MS-->>MH: region_id, offset, length MH->>CXL: Map shared region (if not already mapped) - MH-->>MC: MemoryInfo (zero-copy memoryview) - MC->>MC: Wrap as MemoryObj (zero-copy) - MC-->>CE: MemoryObj + MH-->>MB: MemoryInfo (zero-copy memoryview) + MB-->>CE: MemoryObj (points to CXL mmap, zero-copy) CE-->>IE: KV tensors ``` @@ -101,59 +104,43 @@ accessed directly from CXL shared memory through memory-mapped regions. ## Configuration -Maru is loaded as an LMCache [remote storage plugin](https://docs.lmcache.ai/developer_guide/extending_lmcache/remote_storage_plugins.html) (requires LMCache >= v0.3.14). Configuration is done via the LMCache YAML config file. +Maru is configured as a native LMCache storage backend via the `maru_path` and `maru_pool_size` +config fields. No plugin registration is needed. ```yaml chunk_size: 256 -local_cpu: True -max_local_cpu_size: 5 -enable_async_loading: True +local_cpu: False +max_local_cpu_size: 0 +save_unfull_chunk: True -# Disable P2P for Maru shared storage mode -enable_p2p: False -enable_controller: False - -# Maru backend — format: maru://:[?pool_size=&pool_id=&...] -remote_url: "maru://localhost:5555" -remote_serde: "naive" -remote_storage_plugins: ["maru"] +# Maru backend +maru_path: "maru://localhost:5555" +maru_pool_size: 4 extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" # CXL memory pool size ("1G", "500M", etc.) - # maru_pool_id: 1 # Pin to specific DAX pool (default: any) - # maru_pool_id: "0,1" # Multi-pool fallback (try pool 0, then 1) - save_chunk_meta: False lookup_backoff_time: 0.001 # maru_instance_id: "my-id" # Unique client ID (default: auto UUID) - # maru_operation_timeout: 10.0 # Per-operation timeout in seconds - # maru_timeout_ms: 2000 # ZMQ socket timeout (ms) + # maru_timeout_ms: 5000 # ZMQ socket timeout (ms) # maru_use_async_rpc: true # Async DEALER-ROUTER RPC # maru_max_inflight: 64 # Max in-flight async requests + # maru_eager_map: true # Pre-map shared regions on connect ``` -### Plugin settings +### MaruBackend settings -| Field | Description | -| --- | --- | -| `remote_storage_plugins: ["maru"]` | Registers Maru as a plugin backend | -| `remote_storage_plugin.maru.module_path` | Python module containing the adapter class | -| `remote_storage_plugin.maru.class_name` | Adapter class name (`MaruConnectorAdapter`) | +| Field | Default | Description | +| --- | --- | --- | +| `maru_path` | (required) | MaruServer address. Format: `maru://:` | +| `maru_pool_size` | `4` | CXL memory pool size in GB | ### Maru extra_config parameters | Parameter | Default | Description | | --- | --- | --- | -| `maru_pool_size` | `"1G"` | CXL memory pool size. Supports human-readable strings (`"4G"`, `"500M"`) or integer bytes | -| `maru_pool_id` | `None` (any pool) | Pin allocations to specific DAX device pool(s). Single int (`1`) or comma-separated (`"0,1"`) for ordered fallback. Can also be set via URL query: `maru://host:port?pool_id=1` | | `maru_instance_id` | auto-generated UUID | Unique client instance identifier | -| `maru_operation_timeout` | `10.0` | Timeout in seconds for individual KV operations | -| `maru_timeout_ms` | `2000` | ZMQ socket timeout in milliseconds for RPC communication | +| `maru_timeout_ms` | `5000` | ZMQ socket timeout in milliseconds for RPC communication | | `maru_use_async_rpc` | `true` | Use async DEALER-ROUTER pattern for higher throughput | | `maru_max_inflight` | `64` | Max concurrent in-flight async RPC requests | -| `maru_server_url` | (from `remote_url`) | Override server URL. Normally not needed | -| `maru_auto_connect` | `true` | Auto-connect to MaruServer on initialization | | `maru_eager_map` | `true` | Pre-map all shared regions on connect | For runnable examples, see diff --git a/examples/lmcache/disagg_prefill/1p1d/.gitignore b/examples/lmcache/disagg_prefill/1p1d/.gitignore index 11abf78..fa1cb10 100644 --- a/examples/lmcache/disagg_prefill/1p1d/.gitignore +++ b/examples/lmcache/disagg_prefill/1p1d/.gitignore @@ -1,3 +1,4 @@ .logs/ .results/ -bench_results/ \ No newline at end of file +bench_results/ +.test_pids \ No newline at end of file diff --git a/examples/lmcache/disagg_prefill/1p1d/configs/maru-decoder-config.yaml b/examples/lmcache/disagg_prefill/1p1d/configs/maru-decoder-config.yaml index e30e53e..459e73e 100644 --- a/examples/lmcache/disagg_prefill/1p1d/configs/maru-decoder-config.yaml +++ b/examples/lmcache/disagg_prefill/1p1d/configs/maru-decoder-config.yaml @@ -1,17 +1,11 @@ enable_pd: False chunk_size: 256 -# Maru remote backend -remote_url: "maru://localhost:${MARU_SERVER_PORT}" -remote_serde: "naive" -remote_storage_plugins: ["maru"] local_cpu: False -max_local_cpu_size: 100 save_unfull_chunk: True +# Maru backend +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 + extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" - save_chunk_meta: False lookup_backoff_time: 0.001 - diff --git a/examples/lmcache/disagg_prefill/1p1d/configs/maru-prefiller-config.yaml b/examples/lmcache/disagg_prefill/1p1d/configs/maru-prefiller-config.yaml index e30e53e..459e73e 100644 --- a/examples/lmcache/disagg_prefill/1p1d/configs/maru-prefiller-config.yaml +++ b/examples/lmcache/disagg_prefill/1p1d/configs/maru-prefiller-config.yaml @@ -1,17 +1,11 @@ enable_pd: False chunk_size: 256 -# Maru remote backend -remote_url: "maru://localhost:${MARU_SERVER_PORT}" -remote_serde: "naive" -remote_storage_plugins: ["maru"] local_cpu: False -max_local_cpu_size: 100 save_unfull_chunk: True +# Maru backend +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 + extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" - save_chunk_meta: False lookup_backoff_time: 0.001 - diff --git a/examples/lmcache/disagg_prefill/1p1d/disagg_example_1p1d.sh b/examples/lmcache/disagg_prefill/1p1d/disagg_example_1p1d.sh index da07321..5202145 100755 --- a/examples/lmcache/disagg_prefill/1p1d/disagg_example_1p1d.sh +++ b/examples/lmcache/disagg_prefill/1p1d/disagg_example_1p1d.sh @@ -14,6 +14,8 @@ PIDS=() # Switch to the directory of the current script cd "$(dirname "${BASH_SOURCE[0]}")" +PIDFILE="$(pwd)/.test_pids" + check_hf_token() { if [ -z "$HF_TOKEN" ]; then echo "HF_TOKEN is not set. Please set it to your Hugging Face token." @@ -48,24 +50,49 @@ ensure_python_library_installed() { } -kill_tree() { - # Recursively kill a process and all its descendants +save_pids() { + printf '%s\n' "${PIDS[@]}" > "$PIDFILE" +} + +kill_pgid() { + # Kill an entire process group by its leader PID local pid=$1 sig=${2:-TERM} - for child in $(pgrep -P "$pid" 2>/dev/null); do - kill_tree "$child" "$sig" - done - kill -"$sig" "$pid" 2>/dev/null + kill -"$sig" -- -"$pid" 2>/dev/null +} + +kill_stale_pids() { + # Kill leftover processes from a previous abnormal exit + if [ ! -f "$PIDFILE" ]; then + return + fi + echo "Found stale PID file. Cleaning up leftover processes..." + while read -r pid; do + if kill -0 "$pid" 2>/dev/null; then + echo " Killing leftover process group $pid" + kill_pgid "$pid" TERM + fi + done < "$PIDFILE" + sleep 1 + while read -r pid; do + if kill -0 "$pid" 2>/dev/null; then + echo " Force killing leftover process group $pid" + kill_pgid "$pid" 9 + fi + done < "$PIDFILE" + rm -f "$PIDFILE" + echo "Stale processes cleaned up." } cleanup() { echo "Stopping everything…" - trap - INT TERM USR1 EXIT # prevent re-entrancy + trap '' INT TERM USR1 # ignore signals during cleanup + trap - EXIT - # Graceful: recursively kill all tracked process trees + # Graceful: kill entire process groups for pid in "${PIDS[@]}"; do if kill -0 "$pid" 2>/dev/null; then - echo "Killing process tree of $pid" - kill_tree "$pid" TERM + echo "Killing process group of $pid" + kill_pgid "$pid" TERM fi done @@ -75,11 +102,12 @@ cleanup() { # Force kill any survivors for pid in "${PIDS[@]}"; do if kill -0 "$pid" 2>/dev/null; then - echo "Force killing process tree of $pid" - kill_tree "$pid" 9 + echo "Force killing process group of $pid" + kill_pgid "$pid" 9 fi done + rm -f "$PIDFILE" echo "All processes stopped." exit 0 } @@ -128,6 +156,7 @@ main() { ensure_python_library_installed datasets ensure_python_library_installed vllm + kill_stale_pids trap cleanup INT TERM USR1 EXIT # Launch MaruServer @@ -135,10 +164,11 @@ main() { echo "MaruServer already running on port $MARU_SERVER_PORT, skipping launch..." else echo "Launching MaruServer..." - PYTHONUNBUFFERED=1 python -m maru_server --port $MARU_SERVER_PORT --log-level "${_LOG_LEVEL:-ERROR}" \ + setsid env PYTHONUNBUFFERED=1 python -m maru_server --port $MARU_SERVER_PORT --log-level "${_LOG_LEVEL:-ERROR}" \ > >(tee "${LOG_MARU_SERVER:-maru_server.log}") 2>&1 & maru_server_pid=$! PIDS+=($maru_server_pid) + save_pids wait_for_server $MARU_SERVER_PORT fi @@ -155,7 +185,7 @@ main() { echo "Proxy will skip wait_decode_kv_ready (shared storage mode)" # Launch the proxy first - python3 ../disagg_proxy_server.py \ + setsid python3 ../disagg_proxy_server.py \ --host localhost \ --port $LMCACHE_PROXY_EXTERNAL_PORT \ --prefiller-host localhost \ @@ -172,23 +202,25 @@ main() { > >(tee "$LOG_PROXY") 2>&1 & proxy_pid=$! PIDS+=($proxy_pid) + save_pids - # Launch the decoder - bash disagg_vllm_launcher.sh decoder ${_MODEL:+"$_MODEL"} \ - > >(tee "$LOG_DECODER") 2>&1 & - decoder_pid=$! - PIDS+=($decoder_pid) - - - # Launch the prefiller next - bash disagg_vllm_launcher.sh prefiller ${_MODEL:+"$_MODEL"} \ + # Launch the prefiller first and wait for it to be ready + setsid bash disagg_vllm_launcher.sh prefiller ${_MODEL:+"$_MODEL"} \ > >(tee "$LOG_PREFILLER") 2>&1 & prefiller_pid=$! PIDS+=($prefiller_pid) + save_pids + wait_for_server $LMCACHE_PREFILLER_PORT + # Launch the decoder after prefiller is ready + setsid bash disagg_vllm_launcher.sh decoder ${_MODEL:+"$_MODEL"} \ + > >(tee "$LOG_DECODER") 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + save_pids wait_for_server $LMCACHE_DECODER_PORT - wait_for_server $LMCACHE_PREFILLER_PORT + wait_for_server $LMCACHE_PROXY_EXTERNAL_PORT echo "===================================================" diff --git a/examples/lmcache/disagg_prefill/1p1d/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill/1p1d/disagg_vllm_launcher.sh index 9568d0c..da87800 100755 --- a/examples/lmcache/disagg_prefill/1p1d/disagg_vllm_launcher.sh +++ b/examples/lmcache/disagg_prefill/1p1d/disagg_vllm_launcher.sh @@ -52,7 +52,6 @@ if [[ $1 == "prefiller" ]]; then vllm serve $MODEL \ --gpu-memory-utilization ${GPU_MEM_UTIL:-0.9} \ --port $LMCACHE_PREFILLER_PORT \ - --disable-log-requests \ --enforce-eager \ --no-enable-prefix-caching \ --kv-transfer-config \ @@ -77,7 +76,6 @@ elif [[ $1 == "decoder" ]]; then vllm serve $MODEL \ --gpu-memory-utilization ${GPU_MEM_UTIL:-0.9} \ --port $LMCACHE_DECODER_PORT \ - --disable-log-requests \ --enforce-eager \ --no-enable-prefix-caching \ --kv-transfer-config \ diff --git a/examples/lmcache/p2p_sharing/configs/maru-config.yaml b/examples/lmcache/p2p_sharing/configs/maru-config.yaml index 8796ac1..82dcd4b 100644 --- a/examples/lmcache/p2p_sharing/configs/maru-config.yaml +++ b/examples/lmcache/p2p_sharing/configs/maru-config.yaml @@ -1,20 +1,10 @@ chunk_size: 256 -local_cpu: True -max_local_cpu_size: 5 +local_cpu: False enable_async_loading: True -# P2P and Controller disabled for Maru shared storage mode -enable_p2p: False -enable_controller: False - -# Maru remote backend -remote_url: "maru://localhost:${MARU_SERVER_PORT}" -remote_serde: "naive" -remote_storage_plugins: ["maru"] +# Maru backend +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter - maru_pool_size: "4G" - save_chunk_meta: False lookup_backoff_time: 0.001 diff --git a/examples/lmcache/p2p_sharing/p2p_example.sh b/examples/lmcache/p2p_sharing/p2p_example.sh index 4681843..8052fde 100755 --- a/examples/lmcache/p2p_sharing/p2p_example.sh +++ b/examples/lmcache/p2p_sharing/p2p_example.sh @@ -175,19 +175,18 @@ main() { echo "[$(date +%T)] Launching vLLM instances (MODEL=$MODEL, GPU_MEM_UTIL=$GPU_MEM_UTIL)..." echo "Please check $LOG_INST1 and $LOG_INST2 for logs." - # Launch Instance 1 (GPU 0) + # Launch Instance 1 (GPU 0) and wait for it to be ready bash p2p_vllm_launcher.sh inst1 ${_MODEL:+"$_MODEL"} \ > >(tee "$LOG_INST1") 2>&1 & inst1_pid=$! PIDS+=($inst1_pid) + wait_for_server $LMCACHE_INST1_PORT - # Launch Instance 2 (GPU 1) + # Launch Instance 2 (GPU 1) after Instance 1 is ready bash p2p_vllm_launcher.sh inst2 ${_MODEL:+"$_MODEL"} \ > >(tee "$LOG_INST2") 2>&1 & inst2_pid=$! PIDS+=($inst2_pid) - - wait_for_server $LMCACHE_INST1_PORT wait_for_server $LMCACHE_INST2_PORT echo "===================================================" diff --git a/examples/lmcache/single/configs/maru-config.yaml b/examples/lmcache/single/configs/maru-config.yaml new file mode 100644 index 0000000..82e58dd --- /dev/null +++ b/examples/lmcache/single/configs/maru-config.yaml @@ -0,0 +1,10 @@ +chunk_size: 256 +local_cpu: False + +enable_async_loading: False + +maru_path: "maru://localhost:${MARU_SERVER_PORT}" +maru_pool_size: 4 + +extra_config: + lookup_backoff_time: 0.001 diff --git a/examples/lmcache/single/env.sh b/examples/lmcache/single/env.sh new file mode 100755 index 0000000..553cdcf --- /dev/null +++ b/examples/lmcache/single/env.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +export VLLM_LOG_LEVEL=${VLLM_LOG_LEVEL:-DEBUG} +export LMCACHE_LOG_LEVEL=${LMCACHE_LOG_LEVEL:-INFO} +export GPU_MEM_UTIL=${GPU_MEM_UTIL:-0.1} + +# Port base configuration +# Uses user ID to avoid port conflicts between users on shared machines +export LMCACHE_PORT_BASE=${LMCACHE_PORT_BASE:-$((12000 + $(id -u)))} + +# Single instance port +export LMCACHE_INST_PORT=${LMCACHE_INST_PORT:-$((LMCACHE_PORT_BASE + 20))} + +# Maru Server port +export MARU_SERVER_PORT=${MARU_SERVER_PORT:-$((10000 + $(id -u)))} diff --git a/examples/lmcache/single/run_benchmark.py b/examples/lmcache/single/run_benchmark.py new file mode 100644 index 0000000..80e8abf --- /dev/null +++ b/examples/lmcache/single/run_benchmark.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +"""Single-instance KV cache benchmark with TTFT measurement. + +Sends the same prompt twice to a single vLLM instance with Maru storage backend. +Query 1 computes and stores KV cache; Query 2 retrieves from cache. +Measures TTFT speedup to validate cache hit. + +Usage: + python run_benchmark.py [--model MODEL] [--port PORT] + [--max-tokens N] [--repeat-count N] [--wait-time SEC] +""" + +import argparse +import asyncio +import json +import os +import sys +import time + +BASE_PROMPT = "Explain the significance of KV cache in language models." +DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B" +DEFAULT_MAX_TOKENS = 32 +DEFAULT_REPEAT_COUNT = 1 +DEFAULT_WAIT_TIME = 2.0 + + +def build_prompt(base: str = BASE_PROMPT, repeat: int = 100) -> str: + """Build a long repeated prompt for KV cache generation.""" + return base * repeat + + +async def stream_completion( + base_url: str, model: str, prompt: str, max_tokens: int +) -> dict: + """Send a streaming completion request and measure TTFT. + + Returns dict with: ttft_ms, total_time_ms, text, status. + """ + from openai import AsyncOpenAI + + client = AsyncOpenAI(base_url=f"{base_url}/v1", api_key="dummy") + + start = time.monotonic() + first_token_time = None + text_chunks = [] + + try: + stream = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=max_tokens, + stream=True, + ) + async for chunk in stream: + if first_token_time is None: + first_token_time = time.monotonic() + if chunk.choices and chunk.choices[0].text: + text_chunks.append(chunk.choices[0].text) + + end = time.monotonic() + ttft = (first_token_time - start) * 1000 if first_token_time else None + total = (end - start) * 1000 + + return { + "ttft_ms": round(ttft, 2) if ttft else None, + "total_time_ms": round(total, 2), + "text": "".join(text_chunks), + "status": "ok", + } + except Exception as e: + end = time.monotonic() + return { + "ttft_ms": None, + "total_time_ms": round((end - start) * 1000, 2), + "text": "", + "status": f"error: {e}", + } + finally: + await client.close() + + +async def run_session( + label: str, + base_url: str, + model: str, + prompt: str, + max_tokens: int, + repeat_count: int, +) -> list: + """Run repeat_count requests, return list of results.""" + results = [] + for i in range(repeat_count): + result = await stream_completion(base_url, model, prompt, max_tokens) + result["session"] = label + result["iteration"] = i + 1 + results.append(result) + + ttft_str = f"{result['ttft_ms']:.1f} ms" if result["ttft_ms"] else "N/A" + print( + f" [{label}] iter {i + 1}/{repeat_count}: " + f"TTFT={ttft_str}, total={result['total_time_ms']:.1f} ms", + file=sys.stderr, + ) + return results + + +_BLUE = "\033[0;34m" +_GREEN = "\033[0;32m" +_CYAN = "\033[0;36m" +_NC = "\033[0m" + + +def avg_ttft(results: list) -> float | None: + """Calculate average TTFT from results list.""" + valid = [r["ttft_ms"] for r in results if r["ttft_ms"] is not None] + return round(sum(valid) / len(valid), 2) if valid else None + + +def print_box_summary(q1_results: list, q2_results: list, wait_time: float) -> None: + """Print a box-style human-readable summary to stderr.""" + + q1_ttft = avg_ttft(q1_results) + q2_ttft = avg_ttft(q2_results) + speedup = ( + round(q1_ttft / q2_ttft, 2) if (q1_ttft and q2_ttft and q2_ttft > 0) else None + ) + cache_hit = speedup is not None and speedup > 1.5 + + print(f"\n{_BLUE}{'=' * 60}{_NC}", file=sys.stderr) + print(f"{_BLUE} Single Instance KV Cache - Results{_NC}", file=sys.stderr) + print(f"{_BLUE}{'=' * 60}{_NC}", file=sys.stderr) + print( + f" {_GREEN}Query 1 (compute+store){_NC}: TTFT = " + f"{f'{q1_ttft:.1f} ms' if q1_ttft else 'N/A'}", + file=sys.stderr, + ) + print( + f" {_GREEN}Query 2 (cache hit){_NC}: TTFT = " + f"{f'{q2_ttft:.1f} ms' if q2_ttft else 'N/A'}", + file=sys.stderr, + ) + if speedup: + print( + f" {_CYAN}TTFT Speedup{_NC}: {speedup:.2f}x", + file=sys.stderr, + ) + print( + f" {_CYAN}Cache Hit{_NC}: {'Yes' if cache_hit else 'No'}", + file=sys.stderr, + ) + print(f" Wait between queries: {wait_time}s", file=sys.stderr) + print(f"{_BLUE}{'=' * 60}{_NC}\n", file=sys.stderr) + + +def build_json_summary(q1_results: list, q2_results: list, wait_time: float) -> dict: + """Build machine-parseable JSON summary.""" + q1_ttft = avg_ttft(q1_results) + q2_ttft = avg_ttft(q2_results) + speedup = ( + round(q1_ttft / q2_ttft, 2) if (q1_ttft and q2_ttft and q2_ttft > 0) else None + ) + cache_hit = speedup is not None and speedup > 1.5 + + return { + "query1_ttft_ms": q1_ttft, + "query2_ttft_ms": q2_ttft, + "ttft_speedup": speedup, + "cache_hit": cache_hit, + "wait_time_s": wait_time, + } + + +async def main(): + parser = argparse.ArgumentParser( + description="Single-instance KV cache benchmark with TTFT measurement" + ) + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument( + "--port", + type=int, + default=int(os.environ.get("LMCACHE_INST_PORT", 8000)), + help="Instance port (default: $LMCACHE_INST_PORT)", + ) + parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS) + parser.add_argument( + "--repeat-count", + type=int, + default=DEFAULT_REPEAT_COUNT, + help="Requests per query session (default: 1)", + ) + parser.add_argument( + "--wait-time", + type=float, + default=DEFAULT_WAIT_TIME, + help="Seconds between query 1 and query 2 (default: 2.0)", + ) + args = parser.parse_args() + + prompt = build_prompt() + base_url = f"http://localhost:{args.port}" + + print( + f"\nModel: {args.model}, Port: {args.port}, " + f"MaxTokens: {args.max_tokens}, Repeat: {args.repeat_count}", + file=sys.stderr, + ) + + # Query 1: compute and store KV cache + print( + f"\n[Query 1] Compute + Store KV cache (port {args.port})", + file=sys.stderr, + ) + q1_results = await run_session( + "query1", + base_url, + args.model, + prompt, + args.max_tokens, + args.repeat_count, + ) + + # Wait for KV cache to be fully stored + print( + f"\nWaiting {args.wait_time}s for KV cache storage...", + file=sys.stderr, + ) + await asyncio.sleep(args.wait_time) + + # Query 2: same prompt, should hit cache + print( + f"\n[Query 2] Retrieve from cache (port {args.port})", + file=sys.stderr, + ) + q2_results = await run_session( + "query2", + base_url, + args.model, + prompt, + args.max_tokens, + args.repeat_count, + ) + + # Print human-readable summary to stderr + print_box_summary(q1_results, q2_results, args.wait_time) + + # Print machine-parseable JSON on stdout + summary = build_json_summary(q1_results, q2_results, args.wait_time) + print(json.dumps(summary)) + + sys.exit(0 if summary["cache_hit"] else 1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/lmcache/single/run_benchmark.sh b/examples/lmcache/single/run_benchmark.sh new file mode 100755 index 0000000..c78b31b --- /dev/null +++ b/examples/lmcache/single/run_benchmark.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Single-instance KV cache benchmark -- delegates to run_benchmark.py +# Measures TTFT speedup: Query 1 computes KV -> Query 2 retrieves from Maru cache +if [ -z "${VIRTUAL_ENV:-}" ]; then + echo "Warning: No virtual environment detected. Consider activating a venv first." +fi +source "$(dirname "${BASH_SOURCE[0]}")/env.sh" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +exec python "$SCRIPT_DIR/run_benchmark.py" "$@" diff --git a/examples/lmcache/single/run_simple_query.sh b/examples/lmcache/single/run_simple_query.sh new file mode 100755 index 0000000..820c6e9 --- /dev/null +++ b/examples/lmcache/single/run_simple_query.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Single instance KV cache test: send same query twice +# Flow: query 1 computes + stores KV cache -> query 2 retrieves from Maru + +cd "$(dirname "${BASH_SOURCE[0]}")" +[ -f "env.sh" ] && source env.sh + +MODEL="Qwen/Qwen2.5-0.5B" +PORT="${LMCACHE_INST_PORT:-12020}" + +PROMPT="Explain CXL memory technology in detail. CXL stands for Compute Express Link, which is a high-speed CPU-to-device and CPU-to-memory interconnect designed to accelerate next-generation data center performance. It enables memory expansion and sharing between host processors and accelerators. CXL builds on the PCI Express (PCIe) physical and electrical interface, adding a set of protocols that allow coherent memory access between CPUs and attached devices. The CXL specification defines three protocols: CXL.io for device discovery and configuration based on PCIe, CXL.cache for device-to-host cache coherency allowing devices to cache host memory with low latency, and CXL.mem for host-managed device memory that enables the host processor to access memory attached to CXL devices using standard load and store instructions. CXL technology is particularly relevant for modern data centers where memory capacity and bandwidth requirements are growing rapidly. Applications such as large language model inference, in-memory databases, and real-time analytics benefit significantly from the ability to expand memory pools beyond what is directly attached to a single CPU socket. CXL Type 3 devices, which are memory expansion devices, allow servers to access additional DRAM or persistent memory through the CXL interface, effectively creating a larger memory pool. This is especially valuable in scenarios where memory capacity is the bottleneck rather than compute power. The CXL 2.0 specification introduced memory pooling and switching capabilities, enabling multiple hosts to share a common pool of CXL-attached memory through a CXL switch. This allows for more efficient memory utilization across a cluster of servers, as memory can be dynamically allocated to the hosts that need it most. CXL 3.0 further extended these capabilities with support for fabric-attached memory, enabling even larger scale memory sharing across multiple levels of switches.\n\nSummarize the key benefits of CXL technology:" + +send_query() { + local port="$1" + curl -sS "http://localhost:${port}/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\": \"${MODEL}\", \"prompt\": \"$PROMPT\", \"max_tokens\": 200, \"temperature\": 0.0, \"ignore_eos\": true}" 2>&1 \ + | python3 -c " +import sys, json +output = [] +for line in sys.stdin: + line = line.strip() + if not line or line == 'data: [DONE]': + continue + if line.startswith('data: '): + line = line[6:] + try: + data = json.loads(line) + output.append(data['choices'][0]['text']) + except (json.JSONDecodeError, KeyError, IndexError): + pass +print(''.join(output)) +" +} + +echo "=== Prompt ===" +echo "${PROMPT:0:100}..." +echo "" + +send_query "$PORT" +echo "" diff --git a/examples/lmcache/single/single_example.sh b/examples/lmcache/single/single_example.sh new file mode 100755 index 0000000..b0a95a4 --- /dev/null +++ b/examples/lmcache/single/single_example.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +if [ -z "${VIRTUAL_ENV:-}" ]; then + echo "Warning: No virtual environment detected. Consider activating a venv first." +fi + +echo "Warning: LMCache KV cache sharing support for vLLM v1 is experimental and subject to change." + +# Load common environment variables +source "$(dirname "${BASH_SOURCE[0]}")/env.sh" + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python3 -c "import $1" > /dev/null 2>&1 + if [ $? -ne 0 ]; then + echo "$1 is not installed. Please install it via pip install $1." + exit 1 + else + echo "$1 is installed." + fi +} + +kill_tree() { + local pid=$1 sig=${2:-TERM} + for child in $(pgrep -P "$pid" 2>/dev/null); do + kill_tree "$child" "$sig" + done + kill -"$sig" "$pid" 2>/dev/null +} + +cleanup() { + echo "Stopping everything..." + trap - INT TERM USR1 EXIT + + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process tree of $pid" + kill_tree "$pid" TERM + fi + done + + sleep 2 + + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process tree of $pid" + kill_tree "$pid" 9 + fi + done + + echo "All processes stopped." + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + local last_report=$start_time + + echo "[$(date +%T)] Waiting for server on port $port (timeout: ${timeout_seconds}s)..." + + while true; do + # MaruServer (ZeroMQ) + if [ "$port" = "$MARU_SERVER_PORT" ]; then + if timeout 1 bash -c "echo >/dev/tcp/localhost/$port" 2>/dev/null; then + echo "[$(date +%T)] MaruServer is ready on port $port" + return 0 + fi + # vLLM server + else + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + echo "[$(date +%T)] Server on port $port is ready" + return 0 + fi + fi + + local now=$(date +%s) + if (( now - last_report >= 30 )); then + local elapsed=$(( now - start_time )) + echo "[$(date +%T)] Still waiting for port $port... (${elapsed}s elapsed)" + if [ -f "$LOG_INST" ]; then + echo " [inst last log] $(tail -1 "$LOG_INST" 2>/dev/null)" + fi + last_report=$now + fi + + if (( now - start_time >= timeout_seconds )); then + echo "[$(date +%T)] Timeout waiting for server on port $port (${timeout_seconds}s)" + if [ -f "$LOG_INST" ]; then + echo "--- inst log (last 20 lines) ---" + tail -20 "$LOG_INST" 2>/dev/null + fi + return 1 + fi + + sleep 1 + done +} + + +main() { + echo "Using Maru storage backend (single instance)..." + + ensure_python_library_installed lmcache + ensure_python_library_installed vllm + + trap cleanup INT TERM USR1 EXIT + + # Launch MaruServer + if timeout 1 bash -c "echo >/dev/tcp/localhost/$MARU_SERVER_PORT" 2>/dev/null; then + echo "[$(date +%T)] MaruServer already running on port $MARU_SERVER_PORT, skipping launch..." + else + echo "[$(date +%T)] Launching MaruServer on port $MARU_SERVER_PORT..." + PYTHONUNBUFFERED=1 python3 -m maru_server --port $MARU_SERVER_PORT --log-level "${_LOG_LEVEL:-INFO}" \ + > >(tee "${LOG_DIR:-.}/maru_server.log") 2>&1 & + maru_server_pid=$! + PIDS+=($maru_server_pid) + echo "[$(date +%T)] MaruServer PID: $maru_server_pid (log: ${LOG_DIR:-.}/maru_server.log)" + sleep 2 + if ! kill -0 $maru_server_pid 2>/dev/null; then + echo "[$(date +%T)] ERROR: MaruServer process died! Log:" + cat "${LOG_DIR:-.}/maru_server.log" 2>/dev/null || true + return 1 + fi + wait_for_server $MARU_SERVER_PORT + echo "[$(date +%T)] MaruServer ready." + fi + + # Log file name + LOG_INST="${LOG_INST:-inst.log}" + + echo "[$(date +%T)] Launching vLLM instance (MODEL=$MODEL, GPU_MEM_UTIL=$GPU_MEM_UTIL)..." + echo "Please check $LOG_INST for logs." + + # Launch single vLLM instance (GPU 0) + bash single_vllm_launcher.sh ${_MODEL:+"$_MODEL"} \ + > >(tee "$LOG_INST") 2>&1 & + inst_pid=$! + PIDS+=($inst_pid) + + wait_for_server $LMCACHE_INST_PORT + + echo "===================================================" + echo "Server is up. You can send requests now..." + echo " Port: $LMCACHE_INST_PORT" + echo "Press Ctrl-C to terminate." + echo "===================================================" + + while true; do + sleep 1 + done +} + +# --- Help --- +usage() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Launch a single vLLM instance with Maru storage backend." + echo "Send the same query twice to verify KV cache hit." + echo "" + echo "Options:" + echo " --model MODEL HuggingFace model name (default: Qwen/Qwen2.5-0.5B)" + echo " --log-level LEVEL Log level: DEBUG, INFO, WARNING, ERROR" + echo " -h, --help Show this help message" + echo "" + echo "Environment variables (from env.sh):" + echo " LMCACHE_INST_PORT Instance port (default: PORT_BASE + 20)" + echo " MARU_SERVER_PORT MaruServer port (default: 10000 + UID)" + echo " GPU_MEM_UTIL GPU memory utilization (default: 0.1)" + exit 0 +} + +# --- Argument parsing --- +_LOG_LEVEL="" +_MODEL="" + +while [[ $# -gt 0 ]]; do + case "$1" in + -h|--help) usage ;; + --log-level) _LOG_LEVEL="$2"; shift 2 ;; + --model) _MODEL="$2"; shift 2 ;; + *) echo "Unknown option: $1"; usage ;; + esac +done + +if [[ -n "$_LOG_LEVEL" ]]; then + export VLLM_LOG_LEVEL="$_LOG_LEVEL" + export LMCACHE_LOG_LEVEL="$_LOG_LEVEL" +fi + +main diff --git a/examples/lmcache/single/single_vllm_launcher.sh b/examples/lmcache/single/single_vllm_launcher.sh new file mode 100755 index 0000000..804b1dd --- /dev/null +++ b/examples/lmcache/single/single_vllm_launcher.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +if [ -z "${VIRTUAL_ENV:-}" ]; then + echo "Warning: No virtual environment detected. Consider activating a venv first." +fi +source "$(dirname "${BASH_SOURCE[0]}")/env.sh" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Function to resolve environment variables in config file +resolve_config() { + local config_file=$1 + local resolved_file="/tmp/$(basename $config_file .yaml)-resolved-$$.yaml" + envsubst < "$config_file" > "$resolved_file" + echo "$resolved_file" +} + +GPU_MEM_UTIL="${GPU_MEM_UTIL:-0.1}" +DEVICE="${CUDA_DEVICE:-0}" + +if [[ $# -ge 1 ]]; then + MODEL="$1" +else + MODEL="${MODEL:-Qwen/Qwen2.5-0.5B}" +fi +echo "Using model: ${MODEL}" + +resolved_config=$(resolve_config "$SCRIPT_DIR/configs/maru-config.yaml") +echo "Resolved config: $resolved_config" + +PYTHONHASHSEED=123 \ + CUDA_VISIBLE_DEVICES=$DEVICE \ + LMCACHE_CONFIG_FILE=$resolved_config \ + vllm serve $MODEL \ + --gpu-memory-utilization $GPU_MEM_UTIL \ + --port $LMCACHE_INST_PORT \ + --no-enable-prefix-caching \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' diff --git a/maru/__init__.py b/maru/__init__.py index df28653..f760010 100644 --- a/maru/__init__.py +++ b/maru/__init__.py @@ -16,10 +16,12 @@ from maru_common import MaruConfig # noqa: E402 from maru_handler import MaruHandler # noqa: E402 +from maru_handler.memory import AllocHandle # noqa: E402 __version__ = "0.1.0" __all__ = [ + "AllocHandle", "MaruConfig", "MaruHandler", ] diff --git a/maru_common/__init__.py b/maru_common/__init__.py index 39837c9..2d02fc7 100644 --- a/maru_common/__init__.py +++ b/maru_common/__init__.py @@ -19,8 +19,12 @@ BatchKVEntry, BatchLookupKVRequest, BatchLookupKVResponse, + BatchPinKVRequest, + BatchPinKVResponse, BatchRegisterKVRequest, BatchRegisterKVResponse, + BatchUnpinKVRequest, + BatchUnpinKVResponse, DeleteKVRequest, DeleteKVResponse, ExistsKVRequest, @@ -38,12 +42,16 @@ MessageFlags, MessageHeader, MessageType, + PinKVRequest, + PinKVResponse, RegisterKVRequest, RegisterKVResponse, RequestAllocRequest, RequestAllocResponse, ReturnAllocRequest, ReturnAllocResponse, + UnpinKVRequest, + UnpinKVResponse, ) from .serializer import Serializer, create_serializer # noqa: E402 @@ -83,8 +91,17 @@ "BatchLookupKVResponse", "BatchExistsKVRequest", "BatchExistsKVResponse", + "BatchPinKVRequest", + "BatchPinKVResponse", + "BatchUnpinKVRequest", + "BatchUnpinKVResponse", "BatchKVEntry", "LookupResult", + # Pin/Unpin messages + "PinKVRequest", + "PinKVResponse", + "UnpinKVRequest", + "UnpinKVResponse", # Admin messages "GetStatsRequest", "GetStatsResponse", diff --git a/maru_common/config.py b/maru_common/config.py index fe7c934..1e2651d 100644 --- a/maru_common/config.py +++ b/maru_common/config.py @@ -51,6 +51,8 @@ class MaruConfig: max_inflight: int = 64 # Max concurrent in-flight async requests (backpressure) eager_map: bool = True # Pre-map all shared regions on connect pool_id: list[int] | int | None = None # None means any pool (ANY_POOL_ID) + auto_expand: bool = True # Auto-expand when pool is exhausted + expand_size: int | None = None # Expansion size in bytes (None means use pool_size) def __post_init__(self): """Generate instance_id if not provided. Validate config.""" @@ -95,3 +97,12 @@ def __post_init__(self): f"pool_size ({self.pool_size}) must be >= " f"chunk_size_bytes ({self.chunk_size_bytes})" ) + + if self.expand_size is not None: + if not self.auto_expand: + raise ValueError("expand_size requires auto_expand=True") + if self.expand_size < self.chunk_size_bytes: + raise ValueError( + f"expand_size ({self.expand_size}) must be >= " + f"chunk_size_bytes ({self.chunk_size_bytes})" + ) diff --git a/maru_common/protocol.py b/maru_common/protocol.py index 9ac150c..fa69f8a 100644 --- a/maru_common/protocol.py +++ b/maru_common/protocol.py @@ -48,11 +48,15 @@ class MessageType(IntEnum): LOOKUP_KV = 0x11 EXISTS_KV = 0x12 DELETE_KV = 0x13 + PIN_KV = 0x14 + UNPIN_KV = 0x15 # Batch Operations (0x20 - 0x2F) BATCH_REGISTER_KV = 0x20 BATCH_LOOKUP_KV = 0x21 BATCH_EXISTS_KV = 0x22 + BATCH_PIN_KV = 0x23 + BATCH_UNPIN_KV = 0x24 # Admin (0xF0 - 0xFF) GET_STATS = 0xF0 @@ -271,6 +275,43 @@ class DeleteKVResponse: success: bool +@dataclass +class PinKVRequest: + """PIN_KV (0x14) - Check if KV entry exists and pin it atomically. + + If the key exists, increments the entry's pin_count to protect it from + eviction. This is an atomic operation to avoid race conditions between + existence check and pinning. + """ + + key: str + + +@dataclass +class PinKVResponse: + """Response for PIN_KV.""" + + exists: bool + + +@dataclass +class UnpinKVRequest: + """UNPIN_KV (0x15) - Unpin a KV entry. + + Decrements the entry's pin_count. When pin_count reaches 0, the entry + becomes eligible for eviction again. + """ + + key: str + + +@dataclass +class UnpinKVResponse: + """Response for UNPIN_KV.""" + + success: bool + + # ============================================================================= # Batch Operations Messages (0x20 - 0x2F) # ============================================================================= @@ -327,7 +368,7 @@ class BatchLookupKVResponse: @dataclass class BatchExistsKVRequest: - """BATCH_EXISTS_KV (0x22) - Batch check KV existence.""" + """BATCH_EXISTS_KV (0x22) - Batch check KV existence (checks all keys).""" keys: list[str] = field(default_factory=list) @@ -339,6 +380,34 @@ class BatchExistsKVResponse: results: list[bool] = field(default_factory=list) +@dataclass +class BatchPinKVRequest: + """BATCH_PIN_KV (0x23) - Batch check existence and pin (prefix-stop).""" + + keys: list[str] = field(default_factory=list) + + +@dataclass +class BatchPinKVResponse: + """Response for BATCH_PIN_KV.""" + + results: list[bool] = field(default_factory=list) + + +@dataclass +class BatchUnpinKVRequest: + """BATCH_UNPIN_KV (0x24) - Batch unpin KV entries.""" + + keys: list[str] = field(default_factory=list) + + +@dataclass +class BatchUnpinKVResponse: + """Response for BATCH_UNPIN_KV.""" + + results: list[bool] = field(default_factory=list) + + # ============================================================================= # Admin Messages (0xF0 - 0xFF) # ============================================================================= @@ -438,10 +507,17 @@ class ShutdownResponse: MessageType.LOOKUP_KV: (LookupKVRequest, LookupKVResponse), MessageType.EXISTS_KV: (ExistsKVRequest, ExistsKVResponse), MessageType.DELETE_KV: (DeleteKVRequest, DeleteKVResponse), + MessageType.PIN_KV: (PinKVRequest, PinKVResponse), + MessageType.UNPIN_KV: (UnpinKVRequest, UnpinKVResponse), # Batch Operations MessageType.BATCH_REGISTER_KV: (BatchRegisterKVRequest, BatchRegisterKVResponse), MessageType.BATCH_LOOKUP_KV: (BatchLookupKVRequest, BatchLookupKVResponse), MessageType.BATCH_EXISTS_KV: (BatchExistsKVRequest, BatchExistsKVResponse), + MessageType.BATCH_PIN_KV: ( + BatchPinKVRequest, + BatchPinKVResponse, + ), + MessageType.BATCH_UNPIN_KV: (BatchUnpinKVRequest, BatchUnpinKVResponse), # Admin MessageType.GET_STATS: (GetStatsRequest, GetStatsResponse), MessageType.HEARTBEAT: (HeartbeatRequest, HeartbeatResponse), diff --git a/maru_handler/handler.py b/maru_handler/handler.py index 326db32..a72198c 100644 --- a/maru_handler/handler.py +++ b/maru_handler/handler.py @@ -12,17 +12,15 @@ with MaruHandler(config) as handler: # Zero-copy store: alloc → write to buf → store handle = handler.alloc(size=len(data)) - handle.buf[:] = data - handler.store(key=12345, handle=handle) + handle.buf[:len(data)] = data + handler.store(key="12345", handle=handle) - result = handler.retrieve(key=12345) # returns MemoryInfo + result = handler.retrieve(key="12345") # returns MemoryInfo """ -import ctypes import logging import threading - -import numpy as np +from collections.abc import Callable from maru_common import ANY_POOL_ID, MaruConfig from maru_shm import MaruHandle @@ -39,26 +37,6 @@ logger = logging.getLogger(__name__) -def _gil_free_memcpy(dst: memoryview, src: memoryview | bytes, nbytes: int) -> None: - """Copy *nbytes* from *src* into *dst*, releasing the GIL during copy. - - Uses ``ctypes.memmove`` which releases the GIL (all ctypes foreign-function - calls do) for the actual memcpy, allowing other Python threads to run - concurrently. - """ - dst_c = (ctypes.c_char * nbytes).from_buffer(dst) - if isinstance(src, memoryview) and not src.readonly: - src_c = (ctypes.c_char * nbytes).from_buffer(src) - elif isinstance(src, memoryview): - # read-only memoryview — zero-copy view via numpy to get raw pointer - arr = np.frombuffer(src[:nbytes], dtype=np.uint8) - src_c = arr.ctypes.data - else: - # bytes — ctypes.memmove accepts bytes directly - src_c = src - ctypes.memmove(dst_c, src_c, nbytes) - - class MaruHandler: """Main interface for Maru shared memory KV cache operations. @@ -124,8 +102,96 @@ def __init__(self, config: MaruConfig | None = None): self._key_to_location: dict[str, tuple[int, int]] = {} self._connected = False + # Region-added callback (set by CxlMemoryAdapter) + self._on_region_added: Callable[[int, int], None] | None = None + + # Expansion policy + self._auto_expand = self._config.auto_expand + self._expand_size = self._config.expand_size or self._config.pool_size + logger.debug("Created MaruHandler with config: %s", self._config) + # ========================================================================= + # Public Accessors + # ========================================================================= + + @property + def mapper(self) -> DaxMapper: + """Deprecated: Use get_buffer_view() instead.""" + return self._mapper + + def get_buffer_view( + self, region_id: int, offset: int, size: int + ) -> memoryview | None: + """Get a memoryview slice from a mapped region. + + Args: + region_id: The region ID (owned or shared). + offset: Byte offset within the region. + size: Number of bytes to view. + + Returns: + Writable memoryview, or None if region not mapped. + """ + return self._mapper.get_buffer_view(region_id, offset, size) + + def get_region_page_count(self, region_id: int) -> int | None: + """Get page count for a region (owned or shared). + + Args: + region_id: The region ID. + + Returns: + Number of pages, or None if region not found. + """ + if self._owned is not None: + region = self._owned.get_owned_region(region_id) + if region is not None: + return region.allocator.page_count + mapped = self._mapper.get_region(region_id) + if mapped is None: + return None + return mapped.size // self._config.chunk_size_bytes + + def get_owned_region_ids(self) -> list[int]: + """Get list of currently owned region IDs. + + Returns: + List of region IDs. Empty if not connected. + """ + if self._owned is None: + return [] + return self._owned.get_region_ids() + + def get_chunk_size(self) -> int: + """Get the configured chunk size in bytes. + + Returns: + Chunk size in bytes. + """ + return self._config.chunk_size_bytes + + def set_on_region_added(self, callback: Callable[[int, int], None] | None) -> None: + """Register callback invoked with (region_id, page_count) after region added. + + On registration, replays callback for all existing owned regions + so the caller doesn't need separate init-time logic. + + Args: + callback: Called with (region_id, page_count), or None to unregister. + """ + self._on_region_added = callback + if callback is not None and self._owned is not None: + for rid in self._owned.get_region_ids(): + region = self._owned.get_owned_region(rid) + if region is not None: + logger.debug( + "on_region_added replay: region=%d pages=%d", + rid, + region.allocator.page_count, + ) + callback(rid, region.allocator.page_count) + # ========================================================================= # Connection Management # ========================================================================= @@ -149,13 +215,15 @@ def connect(self) -> bool: chunk_size=self._config.chunk_size_bytes, ) - # 3. Request initial owned region via RPC — try each pool in order - response = None + # 3. Request initial owned region(s) via RPC — aggregate across pools + remaining = self._config.pool_size + allocated_handles: list[MaruHandle] = [] + for pool_id in self._pool_ids: try: response = self._rpc.request_alloc( instance_id=self._config.instance_id, - size=self._config.pool_size, + size=remaining, pool_id=pool_id, ) except Exception: @@ -165,34 +233,58 @@ def connect(self) -> bool: exc_info=True, ) continue - if response.success and response.handle is not None: - break - logger.warning( - "Pool %s refused initial allocation: %s", - pool_id, - getattr(response, "error", "unknown"), - ) - if response is None or not response.success or response.handle is None: - logger.error("Failed to allocate from any pool") - self._owned = None - self._rpc.close() - return False + if not response.success or response.handle is None: + logger.warning( + "Pool %s refused initial allocation: %s", + pool_id, + getattr(response, "error", "unknown"), + ) + continue - # 4. Add region to OwnedRegionManager (mmap + allocator) - try: - self._owned.add_region(response.handle) - except Exception: - logger.error("Failed to init initial region", exc_info=True) + # 4. Add region to OwnedRegionManager (mmap + allocator) + handle = response.handle try: - self._rpc.return_alloc( - self._config.instance_id, - response.handle.region_id, - ) + self._owned.add_region(handle) except Exception: - logger.debug( - "Failed to return allocation during cleanup", exc_info=True + logger.error( + "Failed to init region from pool %s", pool_id, exc_info=True ) + try: + self._rpc.return_alloc( + self._config.instance_id, handle.region_id + ) + except Exception: + logger.debug( + "Failed to return allocation during cleanup", + exc_info=True, + ) + continue + + allocated_handles.append(handle) + remaining -= handle.length + if remaining <= 0: + break + + if remaining > 0: + logger.error( + "Failed to allocate pool_size=%d: only got %d bytes from %d pool(s)", + self._config.pool_size, + self._config.pool_size - remaining, + len(allocated_handles), + ) + # Cleanup partially allocated regions + for h in allocated_handles: + try: + self._rpc.return_alloc(self._config.instance_id, h.region_id) + except Exception: + logger.debug( + "Failed to return region %d during cleanup", + h.region_id, + exc_info=True, + ) + if self._owned is not None: + self._owned.close() self._owned = None self._rpc.close() return False @@ -257,16 +349,16 @@ def close(self) -> None: # ========================================================================= def alloc(self, size: int) -> AllocHandle: - """Allocate a page and return a handle with a writable mmap memoryview. + """Allocate a page and return a handle with a writable memoryview. The caller writes directly to ``handle.buf``, then passes the handle - to ``store(key, handle=handle)`` to register without copying. + to ``store(key, handle)`` to register without copying. Args: size: Required bytes (must be <= chunk_size) Returns: - AllocHandle with writable memoryview + AllocHandle with writable memoryview and allocation metadata Raises: RuntimeError: If not connected or closing @@ -287,7 +379,14 @@ def alloc(self, size: int) -> AllocHandle: result = self._owned.allocate() if result is None: if not self._expand_region(): - raise ValueError("Cannot allocate page: pool exhausted") + if not self._auto_expand: + raise ValueError( + "Cannot allocate page: pool exhausted " + "and auto_expand is disabled" + ) + raise ValueError( + "Cannot allocate page: pool exhausted after expansion attempt" + ) result = self._owned.allocate() if result is None: raise ValueError("Cannot allocate page after expansion") @@ -356,24 +455,16 @@ def free(self, handle: AllocHandle) -> None: def store( self, key: str, - info: MemoryInfo | memoryview | None = None, - prefix: bytes | None = None, - *, - data: memoryview | None = None, - handle: AllocHandle | None = None, + handle: AllocHandle, ) -> bool: - """Store data to the KV cache. + """Register a pre-written page in the KV cache (zero-copy). - If ``handle`` is provided (zero-copy path), data is already written - to the mmap region via alloc() and only register_kv is performed. - Otherwise, allocate + memcpy + register are performed in one call. + Data must already be written to the page via ``handle.buf``. + This method only performs duplicate check + metadata registration. Args: key: The chunk key string - info: MemoryInfo or memoryview with data - prefix: Optional bytes to prepend (e.g., serialized metadata header) - data: memoryview with data (preferred, keyword-only) - handle: AllocHandle from alloc() for zero-copy store + handle: AllocHandle from alloc() Returns: True if successful @@ -384,128 +475,21 @@ def store( if self._closing.is_set(): raise RuntimeError("Handler is closing") - # Duplicate skip: check if key already exists (common to both paths) + # Duplicate skip if key in self._key_to_location: - if handle is not None: - self._owned.free(handle._region_id, handle._page_index) + self._owned.free(handle._region_id, handle._page_index) logger.debug("store: key=%s already in local map, skipping", key) return True elif self._rpc.exists_kv(key): - if handle is not None: - self._owned.free(handle._region_id, handle._page_index) + self._owned.free(handle._region_id, handle._page_index) logger.debug("store: key=%s already exists on server, skipping", key) return True - if handle is not None: - # ── Zero-copy path ── - if data is not None or info is not None: - raise ValueError("Cannot specify both handle and data/info") - - region_id = handle._region_id - page_index = handle._page_index - offset = page_index * self._owned.get_chunk_size() - total_size = handle._size - - is_new = self._rpc.register_kv( - key=key, - region_id=region_id, - kv_offset=offset, - kv_length=total_size, - ) - - if not is_new: - self._owned.free(region_id, page_index) - logger.debug( - "store: key=%s lost register race, freed page " - "(region=%d, page=%d)", - key, - region_id, - page_index, - ) - return True - - self._key_to_location[key] = (region_id, page_index) - - logger.debug( - "Stored (zero-copy) key=%s: region=%d, page=%d, offset=%d, size=%d", - key, - region_id, - page_index, - offset, - total_size, - ) - return True - - # ── Allocate + memcpy + register ── - # Resolve source memoryview from either parameter - if data is not None: - src = data - elif isinstance(info, memoryview): - src = info - elif isinstance(info, MemoryInfo): - src = info.view - else: - raise TypeError( - "Must provide data (memoryview) or info (MemoryInfo | memoryview)" - ) - - # Normalize to 1D unsigned-byte view for mmap slice assignment - if src.format != "B": - src = src.cast("B") - - data_size = len(src) - prefix_len = len(prefix) if prefix else 0 - total_size = prefix_len + data_size - - logger.debug( - "store: key=%s, data=%d bytes, prefix=%d bytes, " - "total=%d bytes, readonly=%s", - key, - data_size, - prefix_len, - total_size, - src.readonly, - ) - - if total_size > self._owned.get_chunk_size(): - logger.error( - "Total size %d exceeds chunk_size %d", - total_size, - self._owned.get_chunk_size(), - ) - return False - - # Allocate page + CXL write + register (new or overwrite only) - result = self._owned.allocate() - if result is None: - if not self._expand_region(): - logger.error("Cannot allocate page for key %s", key) - return False - result = self._owned.allocate() - if result is None: - return False - - region_id, page_index = result - - # 2. Get writable memoryview slice for the page - buf = self._mapper.get_buffer_view( - region_id, - page_index * self._owned.get_chunk_size(), - total_size, - ) - if buf is None: - self._owned.free(region_id, page_index) - return False - - # 3. Write prefix + data via GIL-free memcpy - offset = 0 - if prefix: - _gil_free_memcpy(buf[offset:], prefix, prefix_len) - offset += prefix_len - _gil_free_memcpy(buf[offset:], src, data_size) - - # 4. Register KV with server + region_id = handle._region_id + page_index = handle._page_index offset = page_index * self._owned.get_chunk_size() + total_size = handle._size + try: is_new = self._rpc.register_kv( key=key, @@ -525,9 +509,6 @@ def store( return False if not is_new: - # Race condition: another instance registered the same key - # between our exists_kv check and register_kv call. - # Free the page we just wrote — the data is identical anyway. self._owned.free(region_id, page_index) logger.debug( "store: key=%s lost register race, freed page (region=%d, page=%d)", @@ -537,11 +518,10 @@ def store( ) return True - # 5. Track self._key_to_location[key] = (region_id, page_index) logger.debug( - "Stored key=%s: region=%d, page=%d, offset=%d, size=%d", + "store: key=%s, region=%d, page=%d, offset=%d, size=%d", key, region_id, page_index, @@ -604,7 +584,9 @@ def retrieve(self, key: str) -> MemoryInfo | None: buf.readonly, self._owned.is_owned(region_id), ) - return MemoryInfo(view=buf) + chunk_size = self._owned.get_chunk_size() + page_index = result.kv_offset // chunk_size + return MemoryInfo(view=buf, region_id=region_id, page_index=page_index) def exists(self, key: str) -> bool: """Check if a key exists. @@ -618,6 +600,32 @@ def exists(self, key: str) -> bool: self._ensure_connected() return self._rpc.exists_kv(key) + def pin(self, key: str) -> bool: + """Check if a key exists and pin it atomically. + + If the key exists, increments pin_count to protect from eviction. + + Args: + key: The chunk key string + + Returns: + True if exists (and was pinned) + """ + self._ensure_connected() + return self._rpc.pin_kv(key) + + def unpin(self, key: str) -> bool: + """Unpin a KV entry, making it eligible for eviction. + + Args: + key: The chunk key string + + Returns: + True if unpinned successfully + """ + self._ensure_connected() + return self._rpc.unpin(key) + def delete(self, key: str) -> bool: """Delete a key and free the corresponding page. @@ -762,7 +770,11 @@ def batch_retrieve(self, keys: list[str]) -> list[MemoryInfo | None]: entry.kv_length, buf.readonly, ) - results.append(MemoryInfo(view=buf)) + chunk_size = self._owned.get_chunk_size() + page_index = entry.kv_offset // chunk_size + results.append( + MemoryInfo(view=buf, region_id=region_id, page_index=page_index) + ) hits = sum(1 for r in results if r is not None) ro_count = sum(1 for r in results if r is not None and r.view.readonly) @@ -778,27 +790,24 @@ def batch_retrieve(self, keys: list[str]) -> list[MemoryInfo | None]: def batch_store( self, keys: list[str], - infos: list[MemoryInfo | memoryview], - prefixes: list[bytes | None] | None = None, + handles: list[AllocHandle], ) -> list[bool]: - """Store multiple key-value pairs in batch. + """Register multiple pre-written pages in batch (zero-copy). - Uses a single batch RPC call for registration. + Data must already be written to each page via ``handle.buf``. + Uses a single batch RPC call for metadata registration. Args: keys: List of chunk key strings - infos: List of MemoryInfo or memoryview with data - prefixes: Optional list of prefix bytes per entry + handles: List of AllocHandle from alloc() Returns: List of booleans indicating success for each key """ self._ensure_connected() - if len(keys) != len(infos): - raise ValueError("keys and infos must have the same length") - if prefixes is not None and len(prefixes) != len(keys): - raise ValueError("prefixes must have the same length as keys") + if len(keys) != len(handles): + raise ValueError("keys and handles must have the same length") with self._write_lock: if self._closing.is_set(): @@ -809,7 +818,7 @@ def batch_store( register_entries = [] allocations: dict[int, tuple[int, int]] = {} - # Phase 1: Batch check which keys already exist (avoid CXL write waste) + # Phase 1: Batch check which keys already exist try: exists_resp = self._rpc.batch_exists_kv(keys) exists_results = exists_resp.results @@ -822,80 +831,34 @@ def batch_store( skipped = sum(exists_results) if skipped > 0: logger.debug( - "batch_store: %d/%d keys already exist, skipping CXL write", + "batch_store: %d/%d keys already exist, skipping", skipped, len(keys), ) - # Phase 2: Only process new keys (skip duplicates) - for i, (key, info) in enumerate(zip(keys, infos, strict=True)): - is_local = key in self._key_to_location - if is_local: - # Same instance already stored — same key = same content, skip + # Phase 2: Build register entries, free duplicates + for i, (key, handle) in enumerate(zip(keys, handles, strict=True)): + if key in self._key_to_location: + self._owned.free(handle._region_id, handle._page_index) logger.debug( "batch_store: key=%s already in local map, skipping", key ) - continue # results[i] stays True (idempotent) + continue if exists_results[i]: - # Another instance already registered — skip CXL write + self._owned.free(handle._region_id, handle._page_index) logger.debug( - "batch_store: key=%s already exists on server, skipping", key - ) - continue # results[i] stays True (idempotent) - - prefix = prefixes[i] if prefixes else None - prefix_len = len(prefix) if prefix else 0 - # Normalize to 1D unsigned-byte view for mmap slice assignment - src = info if isinstance(info, memoryview) else info.view - if src.format != "B": - src = src.cast("B") - data_size = len(src) - total_size = prefix_len + data_size - - if total_size > chunk_size: - logger.error( - "Total size %d exceeds chunk_size %d for key %s", - total_size, - chunk_size, + "batch_store: key=%s already exists on server, skipping", key, ) - results[i] = False continue - # Allocate page (expand if needed) - alloc_result = self._owned.allocate() - if alloc_result is None: - if not self._expand_region(): - logger.error("Cannot allocate page for key %s", key) - results[i] = False - continue - alloc_result = self._owned.allocate() - if alloc_result is None: - results[i] = False - continue - - region_id, page_index = alloc_result + region_id = handle._region_id + page_index = handle._page_index allocations[i] = (region_id, page_index) - - # Write to page via GIL-free memcpy - buf = self._mapper.get_buffer_view( - region_id, page_index * chunk_size, total_size - ) - if buf is None: - self._owned.free(region_id, page_index) - results[i] = False - continue - - mv_offset = 0 - if prefix: - _gil_free_memcpy(buf[mv_offset:], prefix, prefix_len) - mv_offset += prefix_len - _gil_free_memcpy(buf[mv_offset:], src, data_size) - offset = page_index * chunk_size - register_entries.append((key, region_id, offset, total_size)) + register_entries.append((key, region_id, offset, handle._size)) - # Batch register + # Phase 3: Batch register if register_entries: try: batch_resp = self._rpc.batch_register_kv(register_entries) @@ -923,15 +886,7 @@ def batch_store( if results[i] and i in allocations: self._key_to_location[key] = allocations[i] - total_bytes = sum( - ( - infos[i].nbytes - if isinstance(infos[i], memoryview) - else infos[i].view.nbytes - ) - for i in range(len(keys)) - if results[i] - ) + total_bytes = sum(handles[i]._size for i in range(len(keys)) if results[i]) logger.debug( "batch_store: %d/%d succeeded, total_data=%d bytes", sum(results), @@ -960,6 +915,30 @@ def batch_exists(self, keys: list[str]) -> list[bool]: return [False] * len(keys) return batch_resp.results + def batch_pin(self, keys: list[str]) -> list[bool]: + """Check existence and pin multiple keys in a single RPC call. + + Args: + keys: List of chunk key strings + + Returns: + List of booleans — True if key exists (and was pinned). + """ + self._ensure_connected() + return self._rpc.batch_pin_kv(keys).results + + def batch_unpin(self, keys: list[str]) -> list[bool]: + """Unpin multiple keys in a single RPC call. + + Args: + keys: List of chunk key strings + + Returns: + List of booleans — True if successfully unpinned. + """ + self._ensure_connected() + return self._rpc.batch_unpin(keys).results + # ========================================================================= # Properties # ========================================================================= @@ -984,7 +963,7 @@ def allocator(self) -> PagedMemoryAllocator | None: @property def owned_region_manager(self) -> OwnedRegionManager | None: - """Get the owned region manager.""" + """Deprecated: Use get_owned_region_ids(), get_region_page_count() instead.""" return self._owned @property @@ -1004,16 +983,24 @@ def connected(self) -> bool: def _expand_region(self) -> bool: """Request a new store region from the server and add it. - Tries each pool_id in order, falling back to the next on failure. + Gated by ``auto_expand`` config. When enabled, tries each pool_id + in order, falling back to the next on failure. Returns: True if expansion succeeded. """ + if not self._auto_expand: + logger.warning( + "Pool exhausted but auto_expand is disabled. " + "Set auto_expand=True in MaruConfig to enable." + ) + return False + for pool_id in self._pool_ids: try: response = self._rpc.request_alloc( instance_id=self._config.instance_id, - size=self._config.pool_size, + size=self._expand_size, pool_id=pool_id, ) except Exception: @@ -1034,12 +1021,21 @@ def _expand_region(self) -> bool: handle = response.handle try: - self._owned.add_region(handle) + region = self._owned.add_region(handle) logger.info( "Expanded: new store region %d (pool_id=%s)", handle.region_id, pool_id, ) + # Callback fires under _write_lock — guarantees pool exists + # before alloc() returns. Acceptable since expansion is rare. + if self._on_region_added is not None: + logger.debug( + "on_region_added fire: region=%d pages=%d", + handle.region_id, + region.allocator.page_count, + ) + self._on_region_added(handle.region_id, region.allocator.page_count) return True except Exception: logger.error("Failed to init expanded region", exc_info=True) diff --git a/maru_handler/memory/owned_region_manager.py b/maru_handler/memory/owned_region_manager.py index 30f6352..e852d7d 100644 --- a/maru_handler/memory/owned_region_manager.py +++ b/maru_handler/memory/owned_region_manager.py @@ -207,6 +207,10 @@ def get_chunk_size(self) -> int: """Return the chunk size.""" return self._chunk_size + def get_region_ids(self) -> list[int]: + """Get list of owned region IDs in insertion order.""" + return list(self._region_order) + def get_owned_region(self, region_id: int) -> OwnedRegion | None: """Get an owned region by ID.""" return self._regions.get(region_id) diff --git a/maru_handler/memory/types.py b/maru_handler/memory/types.py index c28ba96..a5c1803 100644 --- a/maru_handler/memory/types.py +++ b/maru_handler/memory/types.py @@ -106,8 +106,15 @@ class OwnedRegion: class AllocHandle: """Handle returned by MaruHandler.alloc() for zero-copy writes. - Caller writes directly to ``buf`` (an mmap memoryview), then passes - this handle to ``store(key, handle=handle)`` to register without copy. + Contains a writable memoryview into CXL mmap memory and allocation + metadata. The caller writes directly to ``buf``, then passes the + handle to ``store(key, handle)`` to register without copying. + + Typical zero-copy flow:: + + handle = handler.alloc(size=len(data)) + handle.buf[:len(data)] = data + handler.store(key=key, handle=handle) """ buf: memoryview @@ -148,3 +155,5 @@ class MemoryInfo: """ view: memoryview + region_id: int = 0 + page_index: int = 0 diff --git a/maru_handler/rpc_async_client.py b/maru_handler/rpc_async_client.py index 614b7cf..6ebb7af 100644 --- a/maru_handler/rpc_async_client.py +++ b/maru_handler/rpc_async_client.py @@ -37,7 +37,9 @@ AllocationManagerStats, BatchExistsKVResponse, BatchLookupKVResponse, + BatchPinKVResponse, BatchRegisterKVResponse, + BatchUnpinKVResponse, GetStatsResponse, KVManagerStats, ListAllocationsResponse, @@ -419,6 +421,16 @@ def exists_kv(self, key: str) -> bool: response = self._send_request(MessageType.EXISTS_KV, {"key": key}) return response.get("exists", False) + def pin_kv(self, key: str) -> bool: + """Check if a KV entry exists and pin it atomically.""" + response = self._send_request(MessageType.PIN_KV, {"key": key}) + return response.get("exists", False) + + def unpin(self, key: str) -> bool: + """Unpin a KV entry, making it eligible for eviction.""" + response = self._send_request(MessageType.UNPIN_KV, {"key": key}) + return response.get("success", False) + def delete_kv(self, key: str) -> bool: """Delete a KV entry.""" response = self._send_request(MessageType.DELETE_KV, {"key": key}) @@ -459,6 +471,16 @@ def batch_exists_kv(self, keys: list[str]) -> BatchExistsKVResponse: response = self._send_request(MessageType.BATCH_EXISTS_KV, {"keys": keys}) return BatchExistsKVResponse(results=response.get("results", [])) + def batch_pin_kv(self, keys: list[str]) -> BatchPinKVResponse: + """Check existence and pin multiple KV entries in a single RPC call.""" + response = self._send_request(MessageType.BATCH_PIN_KV, {"keys": keys}) + return BatchPinKVResponse(results=response.get("results", [])) + + def batch_unpin(self, keys: list[str]) -> BatchUnpinKVResponse: + """Unpin multiple KV entries in a single RPC call.""" + response = self._send_request(MessageType.BATCH_UNPIN_KV, {"keys": keys}) + return BatchUnpinKVResponse(results=response.get("results", [])) + # ========================================================================= # Admin Operations (blocking) # ========================================================================= diff --git a/maru_handler/rpc_client.py b/maru_handler/rpc_client.py index 4066f47..7814434 100644 --- a/maru_handler/rpc_client.py +++ b/maru_handler/rpc_client.py @@ -12,7 +12,9 @@ AllocationManagerStats, BatchExistsKVResponse, BatchLookupKVResponse, + BatchPinKVResponse, BatchRegisterKVResponse, + BatchUnpinKVResponse, GetStatsResponse, KVManagerStats, ListAllocationsResponse, @@ -281,6 +283,32 @@ def exists_kv(self, key: str) -> bool: response = self._send_request(MessageType.EXISTS_KV, {"key": key}) return response.get("exists", False) + def pin_kv(self, key: str) -> bool: + """ + Check if a KV entry exists and pin it atomically. + + Args: + key: Chunk key string + + Returns: + True if exists (and was pinned) + """ + response = self._send_request(MessageType.PIN_KV, {"key": key}) + return response.get("exists", False) + + def unpin(self, key: str) -> bool: + """ + Unpin a KV entry, making it eligible for eviction. + + Args: + key: Chunk key string + + Returns: + True if unpinned successfully + """ + response = self._send_request(MessageType.UNPIN_KV, {"key": key}) + return response.get("success", False) + def delete_kv(self, key: str) -> bool: """ Delete a KV entry. @@ -370,6 +398,16 @@ def batch_exists_kv(self, keys: list[str]) -> BatchExistsKVResponse: return BatchExistsKVResponse(results=response.get("results", [])) + def batch_pin_kv(self, keys: list[str]) -> BatchPinKVResponse: + """Check existence and pin multiple KV entries in a single RPC call.""" + response = self._send_request(MessageType.BATCH_PIN_KV, {"keys": keys}) + return BatchPinKVResponse(results=response.get("results", [])) + + def batch_unpin(self, keys: list[str]) -> BatchUnpinKVResponse: + """Unpin multiple KV entries in a single RPC call.""" + response = self._send_request(MessageType.BATCH_UNPIN_KV, {"keys": keys}) + return BatchUnpinKVResponse(results=response.get("results", [])) + # ========================================================================= # Admin Operations # ========================================================================= diff --git a/maru_lmcache/__init__.py b/maru_lmcache/__init__.py index 26b190a..06b25a1 100644 --- a/maru_lmcache/__init__.py +++ b/maru_lmcache/__init__.py @@ -1,28 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 """ -Maru LMCache Plugin — external remote storage connector for upstream LMCache. +Maru LMCache integration — memory adapter and storage backend support. -Install: - pip install maru[lmcache] - -LMCache YAML config: - remote_url: "maru://localhost:5555?pool_size=1G" - remote_storage_plugins: ["maru"] - extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter +Usage: + from maru_lmcache import CxlMemoryAdapter """ -__all__ = ["MaruConnectorAdapter", "MaruConnector"] +__all__ = ["CxlMemoryAdapter"] def __getattr__(name: str): - if name == "MaruConnectorAdapter": - from maru_lmcache.adapter import MaruConnectorAdapter + if name == "CxlMemoryAdapter": + from maru_lmcache.adapter import CxlMemoryAdapter - return MaruConnectorAdapter - if name == "MaruConnector": - from maru_lmcache.connector import MaruConnector + return CxlMemoryAdapter + # Backward compatibility: old name still works + if name == "CxlMemoryAllocator": + from maru_lmcache.adapter import CxlMemoryAdapter - return MaruConnector + return CxlMemoryAdapter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/maru_lmcache/adapter.py b/maru_lmcache/adapter.py index b75f84a..22e5120 100644 --- a/maru_lmcache/adapter.py +++ b/maru_lmcache/adapter.py @@ -1,65 +1,418 @@ # SPDX-License-Identifier: Apache-2.0 -""" -MaruConnectorAdapter — registers the ``maru://`` URL scheme with LMCache's -plugin discovery system (``remote_storage_plugins``). - -This module is referenced in LMCache YAML config as:: +# Copyright 2026 XCENA Inc. +"""CxlMemoryAdapter — LMCache MemoryAllocatorInterface adapter over MaruHandler. - remote_storage_plugins: ["maru"] - extra_config: - remote_storage_plugin.maru.module_path: maru_lmcache.adapter - remote_storage_plugin.maru.class_name: MaruConnectorAdapter +Adapts Maru's page-based CXL memory to LMCache's MemoryObj interface. +Pre-creates TensorMemoryObj per page via region-added callback from MaruHandler. +Address encoding uses bit-packing: (region_id << 32) | page_index. """ -import logging +import threading -from lmcache.v1.storage_backend.connector import ( - ConnectorAdapter, - ConnectorContext, - parse_remote_url, +import torch +from lmcache.logging import init_logger +from lmcache.v1.memory_management import ( + MemoryAllocatorInterface, + MemoryFormat, + MemoryObj, + MemoryObjMetadata, + TensorMemoryObj, ) -from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector -logger = logging.getLogger(__name__) +from maru_handler import MaruHandler +from maru_handler.memory import AllocHandle + +logger = init_logger(__name__) + + +class CxlMemoryAdapter(MemoryAllocatorInterface): + """LMCache MemoryAllocatorInterface adapter backed by Maru CXL shared memory. + + Adapter design: MaruHandler owns memory management (regions, pages). + This class translates between Maru's page allocation and LMCache's + MemoryObj interface by pre-creating TensorMemoryObj per page. + + Pool building is driven entirely by MaruHandler's region-added callback: + - On registration: replays for existing regions (initial pool build) + - On expansion: fires for newly added regions + + Address encoding: (region_id << 32) | page_index — stateless O(1) + bidirectional conversion, no cumulative offset table needed. + """ + + def __init__( + self, + handler: MaruHandler, + shapes: list[torch.Size], + dtypes: list[torch.dtype], + fmt: MemoryFormat, + chunk_size: int, + ): + self._handler = handler + self._lock = threading.Lock() + # LMCache metadata for MemoryObj construction + self._shapes = shapes + self._dtypes = dtypes + self._fmt = fmt + self._chunk_size = chunk_size -class MaruConnectorAdapter(ConnectorAdapter): - """Adapter that registers the ``maru://`` URL scheme.""" + # Pre-created MemoryObj pool: region_id -> [MemoryObj per page] + self._pool: dict[int, list[TensorMemoryObj]] = {} - def __init__(self) -> None: - super().__init__("maru://") + # Register callback — replays for existing regions, fires on expansion + self._handler.set_on_region_added(self._on_region_added) - def create_connector(self, context: ConnectorContext) -> RemoteConnector: - logger.info("Creating Maru connector for URL: %s", context.url) + # ========================================================================= + # Address Encoding + # ========================================================================= - # Validate URL format (requires host:port) - _ = parse_remote_url(context.url) + @staticmethod + def encode_address(region_id: int, page_index: int) -> int: + """Encode (region_id, page_index) into a single integer.""" + return (region_id << 32) | page_index - from maru_lmcache.connector import MaruConnector, MaruConnectorConfig + @staticmethod + def decode_address(address: int) -> tuple[int, int]: + """Decode a single integer into (region_id, page_index).""" + return (address >> 32, address & 0xFFFFFFFF) - maru_config = MaruConnectorConfig.from_url(context.url) + # ========================================================================= + # Pool Management + # ========================================================================= - # Override with extra_config if present - if context.config and context.config.extra_config: - maru_config = MaruConnectorConfig.from_lmcache_config( - context.config, fallback=maru_config + def _on_region_added(self, region_id: int, page_count: int) -> None: + """Callback from MaruHandler when a region is added. + + Builds the MemoryObj pool for the region. Called both during + initial registration (replay) and on region expansion. + + Args: + region_id: The region ID. + page_count: Number of pages in the region. + """ + logger.debug("[Maru] on_region_added region=%d pages=%d", region_id, page_count) + self._build_region_pool(region_id, page_count) + + def _build_region_pool(self, region_id: int, page_count: int) -> None: + """Pre-create MemoryObjs for all pages in a region. + + Args: + region_id: The region ID. + page_count: Number of pages in the region. + """ + chunk_size = self._chunk_size + objs: list[TensorMemoryObj] = [] + + for pid in range(page_count): + offset = pid * chunk_size + buf = self._handler.get_buffer_view(region_id, offset, chunk_size) + if buf is None: + logger.error( + "[Maru] buffer view failed region=%d page=%d, aborting pool", + region_id, + pid, + ) + return + + flat_dtype = self._dtypes[0] + tensor = torch.frombuffer(buf, dtype=flat_dtype) + + metadata = MemoryObjMetadata( + shape=self._shapes[0], + dtype=flat_dtype, + address=self.encode_address(region_id, pid), + phy_size=chunk_size, + ref_count=1, + fmt=self._fmt, + shapes=self._shapes if len(self._shapes) > 1 else None, + dtypes=self._dtypes if len(self._dtypes) > 1 else None, ) + objs.append(TensorMemoryObj(tensor, metadata, parent_allocator=None)) + + with self._lock: + self._pool[region_id] = objs + + logger.debug("[Maru] pool built region=%d pages=%d", region_id, len(objs)) + + def ensure_region_pool(self, region_id: int) -> bool: + """Ensure pool exists for a region (on-demand for shared regions). + + Args: + region_id: The region ID. + + Returns: + True if pool exists or was successfully created. + """ + with self._lock: + if region_id in self._pool: + return True + + page_count = self._handler.get_region_page_count(region_id) + if page_count is None: + return False + + # Double-check: another thread may have built it concurrently + with self._lock: + if region_id in self._pool: + return True + + self._build_region_pool(region_id, page_count) + return region_id in self._pool + + # ========================================================================= + # MemoryAllocatorInterface + # ========================================================================= + + def allocate( + self, + shapes: torch.Size | list[torch.Size], + dtypes: torch.dtype | list[torch.dtype], + fmt: MemoryFormat = MemoryFormat.UNDEFINED, + allocator_type: str | None = None, + ) -> MemoryObj | None: + """Allocate a CXL page and return the pooled MemoryObj. + + Pool objects are pre-created with the canonical shapes/dtypes/fmt + from __init__. The shapes/dtypes/fmt arguments are accepted for + interface compatibility but the pool's metadata is used. + + Args: + shapes: Tensor shape(s) (for size computation only). + dtypes: Tensor dtype(s) (for size computation only). + fmt: Memory format (unused, pool has canonical fmt). + allocator_type: Unused, for interface compatibility. + + Returns: + TensorMemoryObj from the pool, or None on failure. + """ + shapes_list, dtypes_list = self._adapt_shapes_and_dtypes(shapes, dtypes) + + size = 0 + for shape, dtype in zip(shapes_list, dtypes_list, strict=True): + size += shape.numel() * dtype.itemsize + + if size == 0: + return None + + try: + handle = self._handler.alloc(size=size) + except (ValueError, RuntimeError) as e: + logger.debug("[Maru] alloc failed: %s", e) + return None + + rid, pid = handle.region_id, handle.page_index + + with self._lock: + region_pool = self._pool.get(rid) - logger.info( - "Maru config: server_url=%s, pool_size=%s, pool_id=%s, instance_id=%s", - maru_config.server_url, - maru_config.pool_size, - maru_config.pool_id, - maru_config.instance_id, + if region_pool is None or pid >= len(region_pool): + logger.error("[Maru] pool miss region=%d page=%d", rid, pid) + self._handler.free(handle) + return None + + obj = region_pool[pid] + logger.debug("[Maru] allocate rid=%d pid=%d size=%d", rid, pid, size) + + # Partial chunk: return a view with adjusted shape to match actual + # token count, preventing CUDA kernel OOB on slot_mapping. + token_dim = self._fmt.token_dim() + if size < self._chunk_size and token_dim < len(self._shapes[0]): + single_token_size = self._chunk_size // self._shapes[0][token_dim] + return self._create_partial_view(obj, size, single_token_size) + + return obj + + def batched_allocate( + self, + shapes: torch.Size | list[torch.Size], + dtypes: torch.dtype | list[torch.dtype], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.UNDEFINED, + allocator_type: str | None = None, + ) -> list[MemoryObj] | None: + """Allocate multiple CXL-backed MemoryObjs. + + Args: + shapes: Tensor shape(s) (same for each allocation). + dtypes: Tensor dtype(s) (same for each allocation). + batch_size: Number of allocations. + fmt: Memory format. + allocator_type: Unused, for interface compatibility. + + Returns: + List of TensorMemoryObj, or None if any allocation fails. + """ + results = [] + for _ in range(batch_size): + obj = self.allocate(shapes, dtypes, fmt, allocator_type) + if obj is None: + for allocated in results: + self.free(allocated) + return None + results.append(obj) + return results + + def free( + self, + memory_obj: MemoryObj, + allocator_type: str | None = None, + ) -> None: + """Free the underlying handler page allocation. + + Returns the page to the handler's allocator so it can be reused. + The pool MemoryObj itself is not destroyed — it persists and will + be returned by the next allocate() call for the same page. + + Called during batched_allocate() rollback and explicit free paths. + For the normal store lifecycle, pages are freed via + MaruBackend.remove() -> handler.delete(). + """ + rid, pid = self.decode_address(memory_obj.metadata.address) + handle = AllocHandle( + buf=memoryview(b""), + _region_id=rid, + _page_index=pid, + _size=0, + ) + try: + self._handler.free(handle) + except Exception as e: + logger.debug("[Maru] free failed rid=%d pid=%d: %s", rid, pid, e) + + def batched_free( + self, + memory_objs: list[MemoryObj], + allocator_type: str | None = None, + update_stats: bool = True, + ) -> None: + """Free multiple handler page allocations. See free().""" + for obj in memory_objs: + self.free(obj, allocator_type) + + def close(self) -> None: + """Clean up adapter state and unregister callback.""" + self._handler.set_on_region_added(None) + with self._lock: + self._pool.clear() + + # ========================================================================= + # Store / Retrieve Helpers + # ========================================================================= + + def create_store_handle(self, memory_obj: MemoryObj) -> AllocHandle: + """Create an AllocHandle from MemoryObj for handler.store(). + + Extracts (region_id, page_index) from metadata.address via + bit decoding. The returned handle has an empty buf — data is + already in CXL memory. + + Args: + memory_obj: MemoryObj with address set by this adapter. + + Returns: + AllocHandle for MaruHandler.store(). + """ + rid, pid = self.decode_address(memory_obj.metadata.address) + return AllocHandle( + buf=memoryview(b""), + _region_id=rid, + _page_index=pid, + _size=memory_obj.metadata.phy_size, ) - if context.config is None or context.metadata is None: - raise ValueError("Maru connector requires config and metadata") + def get_by_location( + self, + region_id: int, + page_index: int, + actual_size: int, + single_token_size: int, + ) -> MemoryObj | None: + """Look up a pooled MemoryObj by (region_id, page_index). + + For shared regions, builds the pool on-demand if not yet created. + + Args: + region_id: The region ID from retrieve response. + page_index: The page index from retrieve response. + actual_size: Actual data size in bytes. + single_token_size: Bytes per single token (for partial chunk). + + Returns: + MemoryObj from the pool, or None if not found. + """ + with self._lock: + region_pool = self._pool.get(region_id) + + if region_pool is None: + if not self.ensure_region_pool(region_id): + return None + with self._lock: + region_pool = self._pool.get(region_id) + if region_pool is None: + return None + + if page_index >= len(region_pool): + logger.error( + "Page index %d out of range for region %d (pool size=%d)", + page_index, + region_id, + len(region_pool), + ) + return None + + source = region_pool[page_index] + + if actual_size == self._chunk_size: + logger.debug( + "[Maru] get_by_location rid=%d pid=%d full", region_id, page_index + ) + return source + + # Partial chunk: create a view without mutating the pool object + logger.debug( + "[Maru] get_by_location rid=%d pid=%d partial=%d/%d", + region_id, + page_index, + actual_size, + self._chunk_size, + ) + return self._create_partial_view(source, actual_size, single_token_size) + + def _create_partial_view( + self, + source: TensorMemoryObj, + actual_size: int, + single_token_size: int, + ) -> TensorMemoryObj: + """Create a partial-chunk view from a pooled MemoryObj. + + The pool object is not mutated. Returns a new TensorMemoryObj + with a sliced tensor and adjusted shape. + + Args: + source: The full-chunk pooled MemoryObj. + actual_size: Actual data size in bytes. + single_token_size: Bytes per single token. + + Returns: + New TensorMemoryObj with sliced data and adjusted shape. + """ + # Slice the flat raw_data tensor to actual_size elements + dtype_size = source.metadata.dtype.itemsize + sliced_tensor = source.raw_data[: actual_size // dtype_size] + + shape_list = list(source.metadata.shape) + shape_list[self._fmt.token_dim()] = actual_size // single_token_size - return MaruConnector( - url=context.url, - loop=context.loop, - config=context.config, - metadata=context.metadata, - maru_config=maru_config, + metadata = MemoryObjMetadata( + shape=torch.Size(shape_list), + dtype=source.metadata.dtype, + address=source.metadata.address, + phy_size=actual_size, + ref_count=1, + fmt=source.metadata.fmt, + shapes=source.metadata.shapes, + dtypes=source.metadata.dtypes, ) + return TensorMemoryObj(sliced_tensor, metadata, parent_allocator=None) diff --git a/maru_lmcache/connector.py b/maru_lmcache/connector.py deleted file mode 100644 index 500a24a..0000000 --- a/maru_lmcache/connector.py +++ /dev/null @@ -1,642 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -MaruConnector — bridges upstream LMCache's RemoteConnector interface to -Maru's MaruHandler for CXL shared-memory KV cache storage. - -Key design points: -- Key conversion: CacheEngineKey → string key (via to_string()) -- Zero-copy bridging: MemoryInfo (memoryview) ↔ MemoryObj (torch tensor) -- Async wrapping: asyncio.to_thread() around MaruHandler's sync API -- Batch operations: batch_retrieve / batch_store / batch_exists -""" - -import asyncio -import builtins -import logging -import os -import re -import time -from dataclasses import dataclass -from typing import Optional -from urllib.parse import parse_qs, urlparse - -import torch -from lmcache.utils import CacheEngineKey -from lmcache.v1.config import LMCacheEngineConfig -from lmcache.v1.memory_management import MemoryObj -from lmcache.v1.metadata import LMCacheMetadata -from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector - -logger = logging.getLogger(__name__) - -_PERF_ENABLED = os.environ.get("LMCACHE_PERF_LOG", "0") == "1" - - -def _parse_pool_id(raw: object) -> list[int] | int | None: - """Parse a pool_id value. - - Returns: - - None if unset - - int for a single value (MaruConfig normalizes to list[int]) - - list[int] for multiple values (comma-separated string or list) - """ - if raw is None: - return None - if isinstance(raw, int): - return raw - if isinstance(raw, list): - try: - return [int(v) for v in raw] - except (ValueError, TypeError) as e: - raise ValueError( - f"Invalid pool_id list: {raw!r}. All elements must be integers." - ) from e - # String: may be comma-separated (e.g. "0,1,2") or single ("1") - s = str(raw).strip() - if "," in s: - parts = [p.strip() for p in s.split(",") if p.strip()] - try: - return [int(p) for p in parts] - except ValueError as e: - raise ValueError( - f"Invalid pool_id: {raw!r}. " - "Comma-separated values must all be non-negative integers." - ) from e - try: - return int(s) - except (ValueError, TypeError) as e: - raise ValueError( - f"Invalid pool_id: {raw!r}. Must be a non-negative integer." - ) from e - - -def _perf_log(elapsed_ms: float, msg: str) -> None: - if _PERF_ENABLED: - print(f"[PERF][{elapsed_ms:.2f}ms][maru_connector]: {msg}", flush=True) - - -# --------------------------------------------------------------------------- -# Size parsing -# --------------------------------------------------------------------------- - - -def parse_size(size_str: str) -> int: - """Parse human-readable size string (e.g., '1G', '500M') to bytes.""" - if isinstance(size_str, int): - return size_str - match = re.match(r"^(\d+(?:\.\d+)?)\s*([KMGT]?)B?$", str(size_str).upper()) - if not match: - try: - return int(size_str) - except ValueError: - raise ValueError( - f"Invalid size string: {size_str!r}. " - "Expected a number or human-readable size " - "(e.g., '1G', '500M', '1024')." - ) from None - value, unit = float(match.group(1)), match.group(2) - multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} - return int(value * multipliers.get(unit, 1)) - - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - - -@dataclass -class MaruConnectorConfig: - """Configuration for the Maru connector.""" - - server_url: str = "tcp://localhost:5555" - pool_size: int = 1024 * 1024 * 1024 # 1 GB - pool_id: list[int] | int | None = None # None means any pool (ANY_POOL_ID) - instance_id: str | None = None - auto_connect: bool = True - connection_timeout: float = 30.0 - operation_timeout: float = 10.0 - timeout_ms: int = 2000 - use_async_rpc: bool = True - max_inflight: int = 64 - eager_map: bool | None = None - - @staticmethod - def from_url(url: str) -> "MaruConnectorConfig": - """Parse ``maru://host:port?pool_size=1G&timeout=30``.""" - parsed = urlparse(url) - host = parsed.hostname or "localhost" - port = parsed.port or 5555 - params = parse_qs(parsed.query) - raw_pool_id = params.get("pool_id", [None])[0] - return MaruConnectorConfig( - server_url=f"tcp://{host}:{port}", - pool_size=parse_size(params.get("pool_size", ["1G"])[0]), - pool_id=_parse_pool_id(raw_pool_id), - instance_id=params.get("instance_id", [None])[0], - connection_timeout=float(params.get("timeout", ["30.0"])[0]), - operation_timeout=float(params.get("op_timeout", ["10.0"])[0]), - ) - - @staticmethod - def from_lmcache_config( - config: LMCacheEngineConfig, - fallback: Optional["MaruConnectorConfig"] = None, - ) -> "MaruConnectorConfig": - """Build from ``extra_config``, falling back to *fallback* for unset keys.""" - extra = config.extra_config or {} - fb = fallback or MaruConnectorConfig() - - raw_pool = extra.get("maru_pool_size", fb.pool_size) - pool_size = parse_size(raw_pool) if isinstance(raw_pool, str) else int(raw_pool) - - raw_pool_id = extra.get("maru_pool_id", fb.pool_id) - pool_id_val = _parse_pool_id(raw_pool_id) - - return MaruConnectorConfig( - server_url=extra.get("maru_server_url", fb.server_url), - pool_size=pool_size, - pool_id=pool_id_val, - instance_id=extra.get("maru_instance_id", fb.instance_id), - auto_connect=extra.get("maru_auto_connect", fb.auto_connect), - operation_timeout=float( - extra.get("maru_operation_timeout", fb.operation_timeout) - ), - timeout_ms=int(extra.get("maru_timeout_ms", fb.timeout_ms)), - use_async_rpc=extra.get("maru_use_async_rpc", fb.use_async_rpc), - max_inflight=int(extra.get("maru_max_inflight", fb.max_inflight)), - eager_map=extra.get("maru_eager_map", fb.eager_map), - ) - - -# --------------------------------------------------------------------------- -# Key conversion -# --------------------------------------------------------------------------- - - -def cache_key_to_str(key: CacheEngineKey) -> str: - """Convert CacheEngineKey to string key for Maru storage.""" - return key.to_string() - - -# --------------------------------------------------------------------------- -# Ping error codes -# --------------------------------------------------------------------------- - -PING_SUCCESS = 0 -PING_NOT_CONNECTED = 1 -PING_RPC_ERROR = 2 - - -# --------------------------------------------------------------------------- -# MaruConnector -# --------------------------------------------------------------------------- - - -class MaruConnector(RemoteConnector): - """ - Upstream-LMCache-compatible connector backed by Maru shared memory. - - This class inherits from upstream ``RemoteConnector`` and delegates all - storage operations to ``maru.MaruHandler``. - """ - - def __init__( - self, - url: str, - loop: asyncio.AbstractEventLoop, - config: LMCacheEngineConfig, - metadata: LMCacheMetadata, - maru_config: MaruConnectorConfig, - ): - logger.info("Initializing MaruConnector for url=%s", url) - super().__init__(config, metadata) - - self.url = url - self.loop = loop - self.maru_config = maru_config - - # MaruHandler (lazy init) - self._handle = None - self._connected = False - - if self.maru_config.auto_connect: - self._init_handle() - - # ------------------------------------------------------------------ - # Connection management - # ------------------------------------------------------------------ - - def _init_handle(self) -> bool: - try: - from maru import MaruConfig, MaruHandler - except ImportError: - logger.warning("maru package not installed. Install with: pip install maru") - return False - - try: - cfg_kwargs = { - "server_url": self.maru_config.server_url, - "instance_id": self.maru_config.instance_id, - "pool_size": self.maru_config.pool_size, - "pool_id": self.maru_config.pool_id, - "chunk_size_bytes": self.full_chunk_size_bytes, - "auto_connect": False, - "timeout_ms": self.maru_config.timeout_ms, - "use_async_rpc": self.maru_config.use_async_rpc, - "max_inflight": self.maru_config.max_inflight, - } - if self.maru_config.eager_map is not None: - cfg_kwargs["eager_map"] = self.maru_config.eager_map - - handle = MaruHandler(MaruConfig(**cfg_kwargs)) - self._handle = handle - if handle.connect(): - self._connected = True - logger.info("MaruHandler connected successfully") - return True - else: - logger.warning("MaruHandler.connect() returned False") - self._handle = None - return False - except Exception as e: - logger.warning("Failed to initialize MaruHandler: %s", e) - self._handle = None - return False - - def _ensure_connected(self) -> bool: - if self._connected and self._handle is not None: - return True - return self._init_handle() - - # ------------------------------------------------------------------ - # Zero-copy encode / decode - # ------------------------------------------------------------------ - - def _decode_memory_obj(self, info) -> MemoryObj | None: - """MemoryInfo (memoryview) → TensorMemoryObj (zero-copy).""" - from lmcache.v1.memory_management import MemoryObjMetadata, TensorMemoryObj - - mv = info.view - raw_data = torch.frombuffer(mv, dtype=torch.uint8) - - meta = MemoryObjMetadata( - shape=self.meta_shapes[0], - dtype=self.meta_dtypes[0], - address=0, - phy_size=raw_data.numel(), - ref_count=1, - pin_count=0, - fmt=self.meta_fmt, - shapes=self.meta_shapes, - dtypes=self.meta_dtypes, - ) - - return TensorMemoryObj( - raw_data=raw_data, - metadata=meta, - parent_allocator=None, - ) - - @staticmethod - def _encode_memory_obj(memory_obj: MemoryObj): - """MemoryObj → MemoryInfo (zero-copy via byte_array).""" - from maru_handler.memory import MemoryInfo - - return MemoryInfo(view=memory_obj.byte_array) - - # ------------------------------------------------------------------ - # Core operations (abstract method implementations) - # ------------------------------------------------------------------ - - async def exists(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - key_hash = cache_key_to_str(key) - try: - t0 = time.perf_counter() - result = await asyncio.wait_for( - asyncio.to_thread(self._handle.exists, key_hash), - timeout=self.maru_config.operation_timeout, - ) - _perf_log( - (time.perf_counter() - t0) * 1000, - f"exists key_hash={key_hash} result={result}", - ) - return result - except TimeoutError: - logger.warning("exists timed out for key_hash=%s", key_hash) - return False - except Exception as e: - logger.error("exists failed: %s", e) - return False - - def exists_sync(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - key_hash = cache_key_to_str(key) - try: - return self._handle.exists(key_hash) - except Exception as e: - logger.error("exists_sync failed: %s", e) - return False - - async def get(self, key: CacheEngineKey) -> MemoryObj | None: - if not self._ensure_connected(): - return None - assert self._handle is not None - key_hash = cache_key_to_str(key) - try: - t0 = time.perf_counter() - info = await asyncio.wait_for( - asyncio.to_thread(self._handle.retrieve, key_hash), - timeout=self.maru_config.operation_timeout, - ) - if info is None: - _perf_log( - (time.perf_counter() - t0) * 1000, - f"get key_hash={key_hash} MISS", - ) - return None - - data_size = len(info.view) - memory_obj = self._decode_memory_obj(info) - if memory_obj is not None: - memory_obj = self.reshape_partial_chunk(memory_obj, data_size) - _perf_log( - (time.perf_counter() - t0) * 1000, - f"get key_hash={key_hash} bytes={data_size}", - ) - return memory_obj - except TimeoutError: - logger.warning("get timed out for key_hash=%s", key_hash) - return None - except Exception as e: - logger.error("get failed: %s", e) - return None - - async def put(self, key: CacheEngineKey, memory_obj: MemoryObj) -> None: - if not self._ensure_connected(): - raise RuntimeError("MaruConnector not connected") - assert self._handle is not None - key_hash = cache_key_to_str(key) - - t0 = time.perf_counter() - info = self._encode_memory_obj(memory_obj) - data_size = len(info.view) - _perf_log((time.perf_counter() - t0) * 1000, f"put encode bytes={data_size}") - - try: - t1 = time.perf_counter() - success = await asyncio.wait_for( - asyncio.to_thread(self._handle.store, key_hash, info), - timeout=self.maru_config.operation_timeout, - ) - _perf_log( - (time.perf_counter() - t1) * 1000, - f"put RPC key_hash={key_hash} bytes={data_size} ok={success}", - ) - if not success: - logger.warning("put failed for key_hash=%s", key_hash) - except TimeoutError: - logger.warning("put timed out for key_hash=%s", key_hash) - except Exception as e: - logger.error("put failed: %s", e) - raise - - async def list(self) -> list[str]: - logger.warning("list() not supported by Maru connector") - return [] - - async def close(self) -> None: - logger.info("MaruConnector.close called") - if self._handle is not None: - try: - self._handle.close() - except Exception as e: - logger.error("Error closing MaruHandler: %s", e) - finally: - self._handle = None - self._connected = False - - # ------------------------------------------------------------------ - # Optional: remove - # ------------------------------------------------------------------ - - def remove_sync(self, key: CacheEngineKey) -> bool: - if not self._ensure_connected(): - return False - assert self._handle is not None - key_hash = cache_key_to_str(key) - try: - return self._handle.delete(key_hash) - except Exception as e: - logger.error("remove_sync failed: %s", e) - return False - - # ------------------------------------------------------------------ - # Optional: ping - # ------------------------------------------------------------------ - - def support_ping(self) -> bool: - return True - - async def ping(self) -> int: - if not self._connected or self._handle is None: - return PING_NOT_CONNECTED - try: - healthy = await asyncio.wait_for( - asyncio.to_thread(self._handle.healthcheck), - timeout=self.maru_config.operation_timeout, - ) - return PING_SUCCESS if healthy else PING_RPC_ERROR - except Exception: - return PING_RPC_ERROR - - # ------------------------------------------------------------------ - # Batch operations - # ------------------------------------------------------------------ - - def support_batched_get(self) -> bool: - return True - - def support_batched_put(self) -> bool: - return True - - def support_batched_async_contains(self) -> bool: - return True - - def support_batched_contains(self) -> bool: - return True - - def support_batched_get_non_blocking(self) -> bool: - return True - - def batched_contains(self, keys: builtins.list[CacheEngineKey]) -> int: - if not self._ensure_connected() or not keys: - return 0 - assert self._handle is not None - key_hashes = [cache_key_to_str(k) for k in keys] - try: - results = self._handle.batch_exists(key_hashes) - count = 0 - for exists in results: - if not exists: - break - count += 1 - return count - except Exception as e: - logger.error("batched_contains failed: %s", e) - return 0 - - async def batched_async_contains( - self, - lookup_id: str, - keys: builtins.list[CacheEngineKey], - pin: bool = False, - ) -> int: - if not self._ensure_connected() or not keys: - return 0 - assert self._handle is not None - key_hashes = [cache_key_to_str(k) for k in keys] - try: - t0 = time.perf_counter() - results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_exists, key_hashes), - timeout=self.maru_config.operation_timeout, - ) - count = 0 - for exists in results: - if not exists: - break - count += 1 - _perf_log( - (time.perf_counter() - t0) * 1000, - f"batch_contains n={len(keys)} hits={count}", - ) - return count - except TimeoutError: - logger.warning("batched_async_contains timed out") - return 0 - except Exception as e: - logger.error("batched_async_contains failed: %s", e) - return 0 - - async def batched_get( - self, keys: builtins.list[CacheEngineKey] - ) -> builtins.list[MemoryObj | None]: - if not self._ensure_connected() or not keys: - return [None] * len(keys) - assert self._handle is not None - key_hashes = [cache_key_to_str(k) for k in keys] - try: - t0 = time.perf_counter() - raw_results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_retrieve, key_hashes), - timeout=self.maru_config.operation_timeout, - ) - objs: list[MemoryObj | None] = [] - for info in raw_results: - if info is None: - objs.append(None) - continue - obj = self._decode_memory_obj(info) - if obj is not None: - obj = self.reshape_partial_chunk(obj, len(info.view)) - objs.append(obj) - hits = sum(1 for r in raw_results if r is not None) - _perf_log( - (time.perf_counter() - t0) * 1000, - f"batch_get n={len(keys)} hits={hits}", - ) - return objs - except TimeoutError: - logger.warning("batched_get timed out for %d keys", len(keys)) - return [None] * len(keys) - except Exception as e: - logger.error("batched_get failed: %s", e) - return [None] * len(keys) - - async def batched_put( - self, - keys: builtins.list[CacheEngineKey], - memory_objs: builtins.list[MemoryObj], - ) -> None: - if not self._ensure_connected() or not keys: - return - assert self._handle is not None - key_hashes = [cache_key_to_str(k) for k in keys] - - t0 = time.perf_counter() - infos = [self._encode_memory_obj(obj) for obj in memory_objs] - total_bytes = sum(len(info.view) for info in infos) - _perf_log( - (time.perf_counter() - t0) * 1000, - f"batch_put encode n={len(keys)} bytes={total_bytes}", - ) - - try: - t1 = time.perf_counter() - results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_store, key_hashes, infos), - timeout=self.maru_config.operation_timeout, - ) - stored = sum(results) if results else 0 - _perf_log( - (time.perf_counter() - t1) * 1000, - f"batch_put RPC n={len(keys)} stored={stored} bytes={total_bytes}", - ) - if stored < len(keys): - logger.warning( - "batch_put partial: stored %d/%d keys", stored, len(keys) - ) - except TimeoutError: - logger.warning("batched_put timed out for %d keys", len(keys)) - except Exception as e: - logger.error("batched_put failed: %s", e) - raise - - async def batched_get_non_blocking( - self, - lookup_id: str, - keys: builtins.list[CacheEngineKey], - ) -> builtins.list[MemoryObj]: - if not self._ensure_connected() or not keys: - return [] - assert self._handle is not None - key_hashes = [cache_key_to_str(k) for k in keys] - try: - t0 = time.perf_counter() - raw_results = await asyncio.wait_for( - asyncio.to_thread(self._handle.batch_retrieve, key_hashes), - timeout=self.maru_config.operation_timeout, - ) - # Consecutive prefix of hits only - objs: list[MemoryObj] = [] - for info in raw_results: - if info is None: - break - obj = self._decode_memory_obj(info) - if obj is None: - break - obj = self.reshape_partial_chunk(obj, len(info.view)) - objs.append(obj) - _perf_log( - (time.perf_counter() - t0) * 1000, - f"batch_get_nb n={len(keys)} hits={len(objs)}", - ) - return objs - except TimeoutError: - logger.warning("batched_get_non_blocking timed out") - return [] - except Exception as e: - logger.error("batched_get_non_blocking failed: %s", e) - return [] - - def __repr__(self) -> str: - return ( - f"" - ) diff --git a/maru_server/__init__.py b/maru_server/__init__.py index deddfbc..54672d2 100644 --- a/maru_server/__init__.py +++ b/maru_server/__init__.py @@ -7,7 +7,7 @@ setup_package_logging("maru_server") from .allocation_manager import AllocationInfo, AllocationManager # noqa: E402 -from .kv_manager import KVEntry, KVManager # noqa: E402 +from .kv_manager import DeleteResult, KVEntry, KVManager # noqa: E402 from .rpc_server import RpcServer # noqa: E402 from .server import MaruServer # noqa: E402 @@ -18,6 +18,7 @@ "RpcServer", "KVManager", "KVEntry", + "DeleteResult", "AllocationManager", "AllocationInfo", ] diff --git a/maru_server/kv_manager.py b/maru_server/kv_manager.py index 39f45de..e32198f 100644 --- a/maru_server/kv_manager.py +++ b/maru_server/kv_manager.py @@ -2,6 +2,7 @@ # Copyright 2026 XCENA Inc. """KV Manager implementation for managing KV metadata.""" +import enum import logging from dataclasses import dataclass from threading import RLock @@ -16,6 +17,15 @@ class KVEntry: region_id: int # Region ID (allocation identifier) kv_offset: int # Offset within allocation (relative to handle.offset) kv_length: int # Size of KV data + pin_count: int = 0 # Pin count for eviction protection + + +class DeleteResult(enum.Enum): + """Result of a KV delete operation.""" + + NOT_FOUND = "not_found" + PINNED = "pinned" + DELETED = "deleted" class KVManager: @@ -78,22 +88,64 @@ def exists(self, key: str) -> bool: with self._lock: return key in self._store - def delete(self, key: str) -> tuple[bool, int | None]: + def pin(self, key: str) -> bool: + """Check if a KV entry exists and pin it atomically. + + Returns: + True if key exists (and was pinned), False otherwise. + """ + with self._lock: + entry = self._store.get(key) + if entry is None: + return False + entry.pin_count += 1 + logger.debug("Pinned KV: key=%s, pin_count=%d", key, entry.pin_count) + return True + + def unpin(self, key: str) -> bool: + """Decrement pin_count for a KV entry. + + Returns: + True if successfully unpinned, False if key not found or not pinned. + """ + with self._lock: + entry = self._store.get(key) + if entry is None: + logger.warning("Unpin failed: key=%s not found", key) + return False + if entry.pin_count <= 0: + logger.warning("Unpin failed: key=%s pin_count already 0", key) + return False + entry.pin_count -= 1 + logger.debug("Unpinned KV: key=%s, pin_count=%d", key, entry.pin_count) + return True + + def delete(self, key: str) -> tuple[DeleteResult, int | None]: """ Delete a KV entry. Returns: - (key_existed, region_id_to_decrement) - - (False, None): key didn't exist - - (True, region_id): entry deleted, allocation ref needs decrement + (result, region_id_to_decrement) + - (NOT_FOUND, None): key didn't exist + - (PINNED, None): key exists but pinned, deletion refused + - (DELETED, region_id): entry deleted, allocation ref needs decrement """ with self._lock: - if key not in self._store: - return (False, None) + entry = self._store.get(key) + if entry is None: + return (DeleteResult.NOT_FOUND, None) + + if entry.pin_count > 0: + logger.warning( + "Delete refused: key=%s is pinned (pin_count=%d)", + key, + entry.pin_count, + ) + return (DeleteResult.PINNED, None) region_id = self._store.pop(key).region_id logger.debug("Deleted KV: key=%s, region_id=%d", key, region_id) - return (True, region_id) + return (DeleteResult.DELETED, region_id) def get_stats(self) -> dict: """Get KV statistics.""" @@ -149,6 +201,9 @@ def batch_exists(self, keys: list[str]) -> list[bool]: """ Check existence of multiple KV entries in a single operation. + Checks ALL keys unconditionally (no prefix-stop). + For prefix-stop with pinning, use batch_pin(). + Args: keys: List of chunk key strings @@ -157,3 +212,52 @@ def batch_exists(self, keys: list[str]) -> list[bool]: """ with self._lock: return [key in self._store for key in keys] + + def batch_pin(self, keys: list[str]) -> list[bool]: + """Check existence and pin prefix-contiguous KV entries atomically. + + Uses prefix-stop: stops at the first miss, only pinning the + contiguous prefix of existing keys. This avoids pin leaks — + if all existing keys were pinned, the caller would need to + unpin non-prefix keys it doesn't use. + + Unlike batch_exists() which checks ALL keys, this method + intentionally stops early because pinning has side effects. + + Returns: + List of booleans — True if key exists (and was pinned). + After first False, remaining entries are all False. + """ + with self._lock: + results = [] + for key in keys: + entry = self._store.get(key) + if entry is None: + # First miss: fill rest with False and stop + results.extend([False] * (len(keys) - len(results))) + break + entry.pin_count += 1 + results.append(True) + return results + + def batch_unpin(self, keys: list[str]) -> list[bool]: + """Unpin multiple KV entries. + + Returns: + List of booleans — True if successfully unpinned. + """ + with self._lock: + results = [] + for key in keys: + entry = self._store.get(key) + if entry is None or entry.pin_count <= 0: + results.append(False) + else: + entry.pin_count -= 1 + results.append(True) + return results + + # TODO: Add pin timeout monitor (PinMonitor) when eviction is implemented. + # Track _pin_timestamps per key, run a periodic check_pin_timeouts() in a + # daemon thread to force-unpin entries that exceed a TTL. This prevents + # pin leaks when clients crash without sending unpin RPCs. diff --git a/maru_server/rpc_handler_mixin.py b/maru_server/rpc_handler_mixin.py index cb7f889..6f43a80 100644 --- a/maru_server/rpc_handler_mixin.py +++ b/maru_server/rpc_handler_mixin.py @@ -33,10 +33,14 @@ def _get_handlers(self) -> dict[int, Callable[[Any], dict]]: MessageType.LOOKUP_KV.value: self._handle_lookup_kv, MessageType.EXISTS_KV.value: self._handle_exists_kv, MessageType.DELETE_KV.value: self._handle_delete_kv, + MessageType.PIN_KV.value: self._handle_pin_kv, + MessageType.UNPIN_KV.value: self._handle_unpin_kv, # Batch operations MessageType.BATCH_REGISTER_KV.value: self._handle_batch_register_kv, MessageType.BATCH_LOOKUP_KV.value: self._handle_batch_lookup_kv, MessageType.BATCH_EXISTS_KV.value: self._handle_batch_exists_kv, + MessageType.BATCH_PIN_KV.value: self._handle_batch_pin_kv, + MessageType.BATCH_UNPIN_KV.value: self._handle_batch_unpin_kv, # Admin MessageType.GET_STATS.value: self._handle_get_stats, MessageType.HEARTBEAT.value: self._handle_heartbeat, @@ -146,6 +150,16 @@ def _handle_delete_kv(self, req: Any) -> dict: success = self._server.delete_kv(key=req.key) return {"success": success} + def _handle_pin_kv(self, req: Any) -> dict: + exists = self._server.pin_kv(key=req.key) + logger.debug("[PIN] key=%s -> %s", req.key, exists) + return {"exists": exists} + + def _handle_unpin_kv(self, req: Any) -> dict: + success = self._server.unpin(key=req.key) + logger.debug("[UNPIN] key=%s -> %s", req.key, success) + return {"success": success} + # ========================================================================= # Batch KV Handlers # ========================================================================= @@ -195,8 +209,29 @@ def _handle_batch_lookup_kv(self, req: Any) -> dict: def _handle_batch_exists_kv(self, req: Any) -> dict: """Handle batch exists KV request.""" keys = req.keys - logger.debug("[BATCH_EXISTS] %d keys", len(keys)) results = self._server.batch_exists_kv(keys) + hits = sum(results) + logger.debug("[BATCH_EXISTS] %d keys, %d hits", len(keys), hits) + return {"results": results} + + def _handle_batch_pin_kv(self, req: Any) -> dict: + """Handle batch pin KV request.""" + keys = req.keys + results = self._server.batch_pin_kv(keys) + hits = sum(results) + logger.debug( + "[BATCH_PIN] %d keys, %d pinned (prefix-stop)", + len(keys), + hits, + ) + return {"results": results} + + def _handle_batch_unpin_kv(self, req: Any) -> dict: + """Handle batch unpin KV request.""" + keys = req.keys + results = self._server.batch_unpin(keys) + ok = sum(results) + logger.debug("[BATCH_UNPIN] %d keys, %d ok", len(keys), ok) return {"results": results} # ========================================================================= diff --git a/maru_server/server.py b/maru_server/server.py index b14d323..6c7fb51 100644 --- a/maru_server/server.py +++ b/maru_server/server.py @@ -11,7 +11,7 @@ from maru_shm.types import MaruHandle from .allocation_manager import AllocationManager -from .kv_manager import KVManager +from .kv_manager import DeleteResult, KVManager logger = logging.getLogger(__name__) @@ -29,6 +29,9 @@ def __init__(self): self._allocation_manager = AllocationManager() self._kv_manager = KVManager() self._lock = RLock() # Coordinates cross-manager operations + # TODO: Add PinMonitor daemon thread when eviction is implemented. + # Periodically force-unpin entries that exceed a TTL to prevent + # pin leaks from crashed clients. logger.info("MaruServer initialized") # ========================================================================= @@ -109,15 +112,23 @@ def exists_kv(self, key: str) -> bool: """Check if a KV entry exists.""" return self._kv_manager.exists(key) + def pin_kv(self, key: str) -> bool: + """Check if a KV entry exists and pin it atomically.""" + return self._kv_manager.pin(key) + + def unpin(self, key: str) -> bool: + """Unpin a KV entry, making it eligible for eviction.""" + return self._kv_manager.unpin(key) + def delete_kv(self, key: str) -> bool: """Delete a KV entry.""" with self._lock: - existed, region_to_deref = self._kv_manager.delete(key) + result, region_to_deref = self._kv_manager.delete(key) if region_to_deref is not None: self._allocation_manager.decrement_kv_ref(region_to_deref) - return existed + return result == DeleteResult.DELETED # ========================================================================= # Batch KV Operations @@ -176,6 +187,14 @@ def batch_lookup_kv(self, keys: list[str]) -> list[dict | None]: return results + def batch_pin_kv(self, keys: list[str]) -> list[bool]: + """Check existence and pin multiple KV entries atomically.""" + return self._kv_manager.batch_pin(keys) + + def batch_unpin(self, keys: list[str]) -> list[bool]: + """Unpin multiple KV entries.""" + return self._kv_manager.batch_unpin(keys) + def batch_exists_kv(self, keys: list[str]) -> list[bool]: """ Check existence of multiple KV entries in a single operation. diff --git a/tests/integration/test_handler.py b/tests/integration/test_handler.py index 477f01d..95b1399 100644 --- a/tests/integration/test_handler.py +++ b/tests/integration/test_handler.py @@ -5,7 +5,6 @@ import pytest from maru import MaruConfig, MaruHandler -from maru_handler.memory import MemoryInfo pytestmark = pytest.mark.integration @@ -80,8 +79,10 @@ def test_store_and_retrieve(self, server_thread, server_port): with MaruHandler(config) as handler: data = b"hello world" - info = MemoryInfo(view=memoryview(data)) - assert handler.store(key="12345", info=info) is True + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler.store(key="12345", handle=handle) is True assert handler.exists(key="12345") is True result = handler.retrieve(key="12345") @@ -97,7 +98,11 @@ def test_store_and_delete_frees_page(self, server_thread, server_port): ) with MaruHandler(config) as handler: - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) + data = b"data1" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key="1", handle=handle) assert handler.allocator.num_allocated == 1 handler.delete(key="1") @@ -116,22 +121,18 @@ def test_store_auto_expansion(self, server_thread, server_port): # Fill all pages in first region page_count = handler.allocator.page_count for i in range(page_count): - assert ( - handler.store( - key=str(i), info=MemoryInfo(view=memoryview(b"x" * 100)) - ) - is True - ) + data = b"x" * 100 + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler.store(key=str(i), handle=handle) is True # Next store triggers auto-expansion to new region overflow_data = b"overflow" - assert ( - handler.store( - key=str(page_count + 1), - info=MemoryInfo(view=memoryview(overflow_data)), - ) - is True - ) + handle = handler.alloc(size=len(overflow_data)) + buf = handle.buf + buf[: len(overflow_data)] = overflow_data + assert handler.store(key=str(page_count + 1), handle=handle) is True # Verify 2 regions exist assert handler.owned_region_manager is not None @@ -155,7 +156,11 @@ def test_store_delete_reuse(self, server_thread, server_port): # Fill all pages page_count = handler.allocator.page_count for i in range(page_count): - handler.store(key=str(i), info=MemoryInfo(view=memoryview(b"data"))) + data = b"data" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key=str(i), handle=handle) assert handler.allocator.num_free_pages == 0 @@ -165,12 +170,11 @@ def test_store_delete_reuse(self, server_thread, server_port): # Store new key — should succeed using freed page new_key = str(page_count + 100) # key not in range(page_count) - assert ( - handler.store( - key=new_key, info=MemoryInfo(view=memoryview(b"new data")) - ) - is True - ) + new_data = b"new data" + handle = handler.alloc(size=len(new_data)) + buf = handle.buf + buf[: len(new_data)] = new_data + assert handler.store(key=new_key, handle=handle) is True assert handler.allocator.num_free_pages == 0 def test_store_duplicate_key_is_skipped(self, server_thread, server_port): @@ -183,11 +187,18 @@ def test_store_duplicate_key_is_skipped(self, server_thread, server_port): with MaruHandler(config) as handler: v1 = b"version1" - handler.store(key="1", info=MemoryInfo(view=memoryview(v1))) + handle = handler.alloc(size=len(v1)) + buf = handle.buf + buf[: len(v1)] = v1 + handler.store(key="1", handle=handle) assert handler.allocator.num_allocated == 1 - # Second store with same key is skipped - handler.store(key="1", info=MemoryInfo(view=memoryview(b"version2"))) + # Second store with same key is skipped (alloc freed internally) + v2 = b"version2" + handle2 = handler.alloc(size=len(v2)) + buf2 = handle2.buf + buf2[: len(v2)] = v2 + handler.store(key="1", handle=handle2) assert handler.allocator.num_allocated == 1 # still 1 page # Original value is preserved @@ -196,7 +207,7 @@ def test_store_duplicate_key_is_skipped(self, server_thread, server_port): assert bytes(result.view[: len(v1)]) == v1 def test_store_exceeds_chunk_size(self, server_thread, server_port): - """Test that store fails when data exceeds chunk_size.""" + """Test that alloc fails when data exceeds chunk_size.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -205,9 +216,8 @@ def test_store_exceeds_chunk_size(self, server_thread, server_port): with MaruHandler(config) as handler: data = b"x" * 1025 # exceeds 1024 - assert ( - handler.store(key="1", info=MemoryInfo(view=memoryview(data))) is False - ) + with pytest.raises(ValueError): + handler.alloc(size=len(data)) class TestMaruHandlerMultiRegion: @@ -225,14 +235,30 @@ def test_retrieve_from_expanded_region(self, server_thread, server_port): # Fill all pages in first region page_count = handler.allocator.page_count d1, d2, d3 = b"region1_data1", b"region1_data2", b"region2_data1" - handler.store(key="1", info=MemoryInfo(view=memoryview(d1))) - handler.store(key="2", info=MemoryInfo(view=memoryview(d2))) + + handle1 = handler.alloc(size=len(d1)) + buf = handle1.buf + buf[: len(d1)] = d1 + handler.store(key="1", handle=handle1) + + handle2 = handler.alloc(size=len(d2)) + buf = handle2.buf + buf[: len(d2)] = d2 + handler.store(key="2", handle=handle2) + for i in range(3, page_count + 1): - handler.store(key=str(i), info=MemoryInfo(view=memoryview(b"filler"))) + filler = b"filler" + handle = handler.alloc(size=len(filler)) + buf = handle.buf + buf[: len(filler)] = filler + handler.store(key=str(i), handle=handle) # Next store triggers auto-expand to region 2 overflow_key = str(page_count + 1) - handler.store(key=overflow_key, info=MemoryInfo(view=memoryview(d3))) + handle3 = handler.alloc(size=len(d3)) + buf = handle3.buf + buf[: len(d3)] = d3 + handler.store(key=overflow_key, handle=handle3) assert handler.owned_region_manager.get_stats()["num_regions"] == 2 @@ -252,16 +278,27 @@ def test_delete_from_expanded_region(self, server_thread, server_port): with MaruHandler(config) as handler: # Fill all pages in first region page_count = handler.allocator.page_count - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) + + data1 = b"data1" + handle = handler.alloc(size=len(data1)) + buf = handle.buf + buf[: len(data1)] = data1 + handler.store(key="1", handle=handle) + for i in range(2, page_count + 1): - handler.store(key=str(i), info=MemoryInfo(view=memoryview(b"filler"))) + filler = b"filler" + handle = handler.alloc(size=len(filler)) + buf = handle.buf + buf[: len(filler)] = filler + handler.store(key=str(i), handle=handle) # Next store triggers expansion overflow_key = str(page_count + 1) - handler.store( - key=overflow_key, - info=MemoryInfo(view=memoryview(b"data2")), - ) + data2 = b"data2" + handle = handler.alloc(size=len(data2)) + buf = handle.buf + buf[: len(data2)] = data2 + handler.store(key=overflow_key, handle=handle) stats = handler.owned_region_manager.get_stats() assert stats["num_regions"] == 2 @@ -282,9 +319,17 @@ def test_duplicate_key_across_regions_is_skipped(self, server_thread, server_por with MaruHandler(config) as handler: v1 = b"version1" - handler.store(key="1", info=MemoryInfo(view=memoryview(v1))) + handle = handler.alloc(size=len(v1)) + buf = handle.buf + buf[: len(v1)] = v1 + handler.store(key="1", handle=handle) + # Second store with same key is skipped - handler.store(key="1", info=MemoryInfo(view=memoryview(b"version2"))) + v2 = b"version2" + handle2 = handler.alloc(size=len(v2)) + buf2 = handle2.buf + buf2[: len(v2)] = v2 + handler.store(key="1", handle=handle2) # Original value is preserved result = handler.retrieve(key="1") @@ -306,12 +351,18 @@ def test_close_returns_all_regions(self, server_thread, server_port): # Fill first region and trigger expansion page_count = handler.allocator.page_count for i in range(page_count): - handler.store(key=str(i), info=MemoryInfo(view=memoryview(b"filler"))) + filler = b"filler" + handle = handler.alloc(size=len(filler)) + buf = handle.buf + buf[: len(filler)] = filler + handler.store(key=str(i), handle=handle) + # Next store triggers expansion - handler.store( - key=str(page_count + 1), - info=MemoryInfo(view=memoryview(b"overflow")), - ) + overflow = b"overflow" + handle = handler.alloc(size=len(overflow)) + buf = handle.buf + buf[: len(overflow)] = overflow + handler.store(key=str(page_count + 1), handle=handle) assert handler.owned_region_manager.get_stats()["num_regions"] == 2 @@ -336,7 +387,11 @@ def test_backward_compat_properties(self, server_thread, server_port): assert handler.allocator is not None assert handler.allocator.page_count == handler.pool_handle.length // 1024 - handler.store(key="1", info=MemoryInfo(view=memoryview(b"test"))) + data = b"test" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key="1", handle=handle) assert handler.allocator.num_allocated == 1 def test_stats_with_multiple_regions(self, server_thread, server_port): @@ -351,12 +406,18 @@ def test_stats_with_multiple_regions(self, server_thread, server_port): # Fill all pages in first region page_count = handler.allocator.page_count for i in range(page_count): - handler.store(key=str(i), info=MemoryInfo(view=memoryview(b"filler"))) + filler = b"filler" + handle = handler.alloc(size=len(filler)) + buf = handle.buf + buf[: len(filler)] = filler + handler.store(key=str(i), handle=handle) + # Next store triggers expansion - handler.store( - key=str(page_count + 1), - info=MemoryInfo(view=memoryview(b"overflow")), - ) + overflow = b"overflow" + handle = handler.alloc(size=len(overflow)) + buf = handle.buf + buf[: len(overflow)] = overflow + handler.store(key=str(page_count + 1), handle=handle) stats = handler.get_stats() @@ -370,10 +431,10 @@ def test_stats_with_multiple_regions(self, server_thread, server_port): class TestMaruHandlerStorePrefix: - """Test store with prefix parameter.""" + """Test store with prefix written into buffer before calling store.""" def test_store_with_prefix(self, server_thread, server_port): - """Store with prefix parameter, verify prefix+data concatenated.""" + """Store with prefix+data written to buffer, verify prefix+data concatenated.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -383,9 +444,12 @@ def test_store_with_prefix(self, server_thread, server_port): with MaruHandler(config) as handler: prefix = b"\x01\x02" data = b"hello" - info = MemoryInfo(view=memoryview(data)) - - assert handler.store(key="1", info=info, prefix=prefix) is True + total_size = len(prefix) + len(data) + handle = handler.alloc(size=total_size) + buf = handle.buf + buf[: len(prefix)] = prefix + buf[len(prefix) : total_size] = data + assert handler.store(key="1", handle=handle) is True # Retrieve and verify prefix+data layout result = handler.retrieve(key="1") @@ -394,7 +458,7 @@ def test_store_with_prefix(self, server_thread, server_port): assert bytes(result.view[: len(expected)]) == expected def test_store_with_empty_prefix(self, server_thread, server_port): - """Store with empty prefix, verify it works the same as prefix=None.""" + """Store with data only (no prefix), verify it works correctly.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -403,10 +467,10 @@ def test_store_with_empty_prefix(self, server_thread, server_port): with MaruHandler(config) as handler: data = b"test data" - info = MemoryInfo(view=memoryview(data)) - - # Store with empty prefix - assert handler.store(key="1", info=info, prefix=b"") is True + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler.store(key="1", handle=handle) is True # Retrieve and verify data only result = handler.retrieve(key="1") @@ -428,10 +492,15 @@ def test_batch_store_and_batch_retrieve(self, server_thread, server_port): with MaruHandler(config) as handler: keys = ["1", "2", "3"] data = [b"data1", b"data2", b"data3"] - infos = [MemoryInfo(view=memoryview(d)) for d in data] + handles = [] + for d in data: + handle = handler.alloc(size=len(d)) + buf = handle.buf + buf[: len(d)] = d + handles.append(handle) # Batch store - results = handler.batch_store(keys=keys, infos=infos) + results = handler.batch_store(keys=keys, handles=handles) assert results == [True, True, True] # Batch retrieve @@ -442,7 +511,7 @@ def test_batch_store_and_batch_retrieve(self, server_thread, server_port): assert bytes(result.view[: len(data[i])]) == data[i] def test_batch_store_with_prefixes(self, server_thread, server_port): - """Call batch_store with prefixes parameter, verify prefix+data layout.""" + """Call batch_store with prefix+data written to buffers, verify layout.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -453,10 +522,17 @@ def test_batch_store_with_prefixes(self, server_thread, server_port): keys = ["1", "2"] data = [b"data1", b"data2"] prefixes = [b"\x01", b"\x02\x03"] - infos = [MemoryInfo(view=memoryview(d)) for d in data] + handles = [] + for d, prefix in zip(data, prefixes, strict=False): + total_size = len(prefix) + len(d) + handle = handler.alloc(size=total_size) + buf = handle.buf + buf[: len(prefix)] = prefix + buf[len(prefix) : total_size] = d + handles.append(handle) - # Batch store with prefixes - results = handler.batch_store(keys=keys, infos=infos, prefixes=prefixes) + # Batch store + results = handler.batch_store(keys=keys, handles=handles) assert results == [True, True] # Verify prefix+data layout @@ -467,7 +543,7 @@ def test_batch_store_with_prefixes(self, server_thread, server_port): assert bytes(result.view[: len(expected)]) == expected def test_batch_store_mismatched_lengths(self, server_thread, server_port): - """Call batch_store with mismatched keys/infos lengths, should raise ValueError.""" + """Call batch_store with mismatched keys/handles lengths, should raise ValueError.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -476,12 +552,16 @@ def test_batch_store_mismatched_lengths(self, server_thread, server_port): with MaruHandler(config) as handler: keys = ["1", "2"] - infos = [MemoryInfo(view=memoryview(b"data1"))] + d = b"data1" + handle = handler.alloc(size=len(d)) + buf = handle.buf + buf[: len(d)] = d + handles = [handle] with pytest.raises( - ValueError, match="keys and infos must have the same length" + ValueError, match="keys and handles must have the same length" ): - handler.batch_store(keys=keys, infos=infos) + handler.batch_store(keys=keys, handles=handles) def test_batch_exists(self, server_thread, server_port): """Store some keys, call batch_exists, verify correct True/False results.""" @@ -493,8 +573,17 @@ def test_batch_exists(self, server_thread, server_port): with MaruHandler(config) as handler: # Store keys 1 and 3 - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) - handler.store(key="3", info=MemoryInfo(view=memoryview(b"data3"))) + d1 = b"data1" + handle = handler.alloc(size=len(d1)) + buf = handle.buf + buf[: len(d1)] = d1 + handler.store(key="1", handle=handle) + + d3 = b"data3" + handle = handler.alloc(size=len(d3)) + buf = handle.buf + buf[: len(d3)] = d3 + handler.store(key="3", handle=handle) # Check existence of keys 1, 2, 3 results = handler.batch_exists(keys=["1", "2", "3"]) @@ -558,8 +647,10 @@ def test_store_and_retrieve_with_async_server( with MaruHandler(config) as handler: data = b"hello async world" - info = MemoryInfo(view=memoryview(data)) - assert handler.store(key="42", info=info) is True + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler.store(key="42", handle=handle) is True assert handler.exists(key="42") is True result = handler.retrieve(key="42") @@ -576,7 +667,11 @@ def test_delete_with_async_server(self, async_server_thread, async_server_port): ) with MaruHandler(config) as handler: - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) + data = b"data1" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key="1", handle=handle) assert handler.exists(key="1") is True handler.delete(key="1") @@ -597,23 +692,19 @@ def test_auto_expansion_with_async_server( # Fill all pages in first region page_count = handler.allocator.page_count for i in range(page_count): - assert ( - handler.store( - key=str(i), info=MemoryInfo(view=memoryview(b"x" * 100)) - ) - is True - ) + d = b"x" * 100 + handle = handler.alloc(size=len(d)) + buf = handle.buf + buf[: len(d)] = d + assert handler.store(key=str(i), handle=handle) is True # Trigger expansion overflow_data = b"overflow" overflow_key = str(page_count + 1) - assert ( - handler.store( - key=overflow_key, - info=MemoryInfo(view=memoryview(overflow_data)), - ) - is True - ) + handle = handler.alloc(size=len(overflow_data)) + buf = handle.buf + buf[: len(overflow_data)] = overflow_data + assert handler.store(key=overflow_key, handle=handle) is True # Verify expansion happened stats = handler.owned_region_manager.get_stats() @@ -638,9 +729,14 @@ def test_batch_operations_with_async_server( with MaruHandler(config) as handler: keys = ["10", "20", "30"] data = [b"batch1", b"batch2", b"batch3"] - infos = [MemoryInfo(view=memoryview(d)) for d in data] - - results = handler.batch_store(keys=keys, infos=infos) + handles = [] + for d in data: + handle = handler.alloc(size=len(d)) + buf = handle.buf + buf[: len(d)] = d + handles.append(handle) + + results = handler.batch_store(keys=keys, handles=handles) assert results == [True, True, True] retrieved = handler.batch_retrieve(keys=keys) @@ -693,8 +789,10 @@ def test_store_and_retrieve_with_sync_rpc(self, server_thread, server_port): with MaruHandler(config) as handler: data = b"sync rpc data" - info = MemoryInfo(view=memoryview(data)) - assert handler.store(key="100", info=info) is True + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler.store(key="100", handle=handle) is True assert handler.exists(key="100") is True result = handler.retrieve(key="100") @@ -711,7 +809,11 @@ def test_delete_with_sync_rpc(self, server_thread, server_port): ) with MaruHandler(config) as handler: - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) + data = b"data1" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key="1", handle=handle) assert handler.exists(key="1") is True handler.delete(key="1") @@ -729,9 +831,14 @@ def test_batch_operations_with_sync_rpc(self, server_thread, server_port): with MaruHandler(config) as handler: keys = ["50", "60", "70"] data = [b"sync1", b"sync2", b"sync3"] - infos = [MemoryInfo(view=memoryview(d)) for d in data] - - results = handler.batch_store(keys=keys, infos=infos) + handles = [] + for d in data: + handle = handler.alloc(size=len(d)) + buf = handle.buf + buf[: len(d)] = d + handles.append(handle) + + results = handler.batch_store(keys=keys, handles=handles) assert results == [True, True, True] retrieved = handler.batch_retrieve(keys=keys) @@ -776,8 +883,10 @@ def test_metadata_visibility_across_handlers(self, server_thread, server_port): with MaruHandler(config) as handler_a, MaruHandler(config) as handler_b: # Handler A stores data data = b"shared metadata" - info = MemoryInfo(view=memoryview(data)) - assert handler_a.store(key="999", info=info) is True + handle = handler_a.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler_a.store(key="999", handle=handle) is True # Handler B can see the key exists in KV registry assert handler_b.exists(key="999") is True @@ -799,8 +908,17 @@ def test_batch_exists_across_handlers(self, server_thread, server_port): # Both handlers alive concurrently with MaruHandler(config) as handler_a, MaruHandler(config) as handler_b: # Handler A stores keys 2000 and 2002 - handler_a.store(key="2000", info=MemoryInfo(view=memoryview(b"data2000"))) - handler_a.store(key="2002", info=MemoryInfo(view=memoryview(b"data2002"))) + d1 = b"data2000" + handle = handler_a.alloc(size=len(d1)) + buf = handle.buf + buf[: len(d1)] = d1 + handler_a.store(key="2000", handle=handle) + + d2 = b"data2002" + handle = handler_a.alloc(size=len(d2)) + buf = handle.buf + buf[: len(d2)] = d2 + handler_a.store(key="2002", handle=handle) # Handler B checks existence via shared KV registry results = handler_b.batch_exists(keys=["2000", "2001", "2002"]) @@ -819,14 +937,15 @@ def test_concurrent_stores_different_keys(self, server_thread, server_port): data_a = b"from handler A" data_b = b"from handler B" - assert ( - handler_a.store(key="100", info=MemoryInfo(view=memoryview(data_a))) - is True - ) - assert ( - handler_b.store(key="200", info=MemoryInfo(view=memoryview(data_b))) - is True - ) + handle_a = handler_a.alloc(size=len(data_a)) + buf = handle_a.buf + buf[: len(data_a)] = data_a + assert handler_a.store(key="100", handle=handle_a) is True + + handle_b = handler_b.alloc(size=len(data_b)) + buf = handle_b.buf + buf[: len(data_b)] = data_b + assert handler_b.store(key="200", handle=handle_b) is True # Both keys visible to both handlers assert handler_a.exists(key="100") is True @@ -850,8 +969,10 @@ def test_read_only_mapping_code_path(self, server_thread, server_port): with MaruHandler(config) as handler_a, MaruHandler(config) as handler_b: # Handler A stores data data = b"test read-only mapping" - info = MemoryInfo(view=memoryview(data)) - assert handler_a.store(key="3000", info=info) is True + handle = handler_a.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + assert handler_a.store(key="3000", handle=handle) is True # Handler B retrieves via read-only mapping of handler A's region result = handler_b.retrieve(key="3000") @@ -909,7 +1030,11 @@ def test_double_delete(self, server_thread, server_port): with MaruHandler(config) as handler: # Store a key - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + data = b"data" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + handler.store(key="1", handle=handle) # First delete succeeds assert handler.delete(key="1") is True @@ -928,9 +1053,16 @@ def test_batch_retrieve_partial(self, server_thread, server_port): with MaruHandler(config) as handler: # Store only keys 1 and 3 data1 = b"data1" + handle = handler.alloc(size=len(data1)) + buf = handle.buf + buf[: len(data1)] = data1 + handler.store(key="1", handle=handle) + data3 = b"data3" - handler.store(key="1", info=MemoryInfo(view=memoryview(data1))) - handler.store(key="3", info=MemoryInfo(view=memoryview(data3))) + handle = handler.alloc(size=len(data3)) + buf = handle.buf + buf[: len(data3)] = data3 + handler.store(key="3", handle=handle) # Batch retrieve keys 1, 2, 3 results = handler.batch_retrieve(keys=["1", "2", "3"]) @@ -948,7 +1080,7 @@ def test_batch_retrieve_partial(self, server_thread, server_port): assert bytes(results[2].view[: len(data3)]) == data3 def test_store_after_close(self, server_thread, server_port): - """Connect handler, close it, then try to store raises RuntimeError.""" + """Connect handler, close it, then try to alloc raises RuntimeError.""" config = MaruConfig( server_url=f"tcp://127.0.0.1:{server_port}", pool_size=4096, @@ -961,4 +1093,4 @@ def test_store_after_close(self, server_thread, server_port): handler.close() with pytest.raises(RuntimeError): - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + handler.alloc(size=len(b"data")) diff --git a/tests/integration/test_rpc_async.py b/tests/integration/test_rpc_async.py index 1432f3e..9334e0c 100644 --- a/tests/integration/test_rpc_async.py +++ b/tests/integration/test_rpc_async.py @@ -16,6 +16,7 @@ import threading import time +import warnings from concurrent.futures import Future, ThreadPoolExecutor, as_completed, wait import pytest @@ -715,19 +716,22 @@ def test_pipeline_is_faster_than_sequential(self, async_client): key=str(400000 + i), region_id=region_id, kv_offset=i * 64, kv_length=64 ) futures.append(f) - # Wait for all + # Wait for all — validate correctness for f in futures: - f.result(timeout=10.0) + assert f.result(timeout=10.0) is True pipe_time = time.monotonic() - t0 - # Pipeline should be at least somewhat faster (or not significantly slower) - # We use a generous threshold since local IPC latency is very low speedup = seq_time / pipe_time if pipe_time > 0 else float("inf") print( f"Sequential: {seq_time:.4f}s, Pipeline: {pipe_time:.4f}s, Speedup: {speedup:.2f}x" ) - # At minimum, pipeline should not be dramatically slower - assert speedup > 0.5, f"Pipeline too slow: {speedup:.2f}x" + if speedup < 0.5: + warnings.warn( + f"Pipeline speedup {speedup:.2f}x < 0.5 — " + "scheduling overhead may dominate on fast local IPC", + UserWarning, + stacklevel=2, + ) # ============================================================================= diff --git a/tests/lmcache/conftest.py b/tests/lmcache/conftest.py index c279c93..0a01553 100644 --- a/tests/lmcache/conftest.py +++ b/tests/lmcache/conftest.py @@ -4,55 +4,3 @@ These tests require a working lmcache installation (with CUDA C extensions). The entire module is skipped if lmcache cannot be imported. """ - -import asyncio -from unittest.mock import MagicMock - -import pytest - - -@pytest.fixture -def async_loop(): - """Provide an asyncio event loop for tests.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -@pytest.fixture -def mock_maru_handler(): - """A mock MaruHandler that simulates connect/store/retrieve/etc.""" - handler = MagicMock() - handler.connect.return_value = True - handler.healthcheck.return_value = True - handler.exists.return_value = True - handler.delete.return_value = True - handler.close.return_value = None - return handler - - -@pytest.fixture -def mock_memory_info(): - """A mock MemoryInfo with a memoryview payload.""" - info = MagicMock() - data = bytearray(1024) - info.view = memoryview(data) - return info - - -@pytest.fixture -def lmcache_config(): - """A minimal LMCacheEngineConfig for testing.""" - from lmcache.v1.config import LMCacheEngineConfig - - return LMCacheEngineConfig( - chunk_size=256, - remote_url="maru://localhost:5555?pool_size=1G", - remote_storage_plugins=["maru"], - extra_config={ - "remote_storage_plugin.maru.module_path": "maru_lmcache.adapter", - "remote_storage_plugin.maru.class_name": "MaruConnectorAdapter", - "maru_pool_size": "4G", - "save_chunk_meta": False, - }, - ) diff --git a/tests/lmcache/test_adapter.py b/tests/lmcache/test_adapter.py deleted file mode 100644 index 49db214..0000000 --- a/tests/lmcache/test_adapter.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for MaruConnectorAdapter and plugin discovery.""" - -import pytest - -pytest.importorskip( - "lmcache.v1.storage_backend.connector", - reason="lmcache not importable (CUDA C extensions required)", -) - -from maru_lmcache.adapter import MaruConnectorAdapter - - -class TestMaruConnectorAdapter: - """Unit tests for the adapter.""" - - def test_schema_is_maru(self): - adapter = MaruConnectorAdapter() - assert adapter.schema == "maru://" - - def test_can_parse_maru_url(self): - adapter = MaruConnectorAdapter() - assert adapter.can_parse("maru://localhost:5555") - assert adapter.can_parse("maru://10.0.0.1:5555?pool_size=2G") - - def test_cannot_parse_other_schemes(self): - adapter = MaruConnectorAdapter() - assert not adapter.can_parse("redis://localhost:6379") - assert not adapter.can_parse("s3://bucket") - assert not adapter.can_parse("") - - def test_create_connector_requires_config_and_metadata(self, async_loop): - from lmcache.v1.storage_backend.connector import ConnectorContext - - adapter = MaruConnectorAdapter() - context = ConnectorContext( - url="maru://localhost:5555", - loop=async_loop, - local_cpu_backend=None, - config=None, - metadata=None, - ) - with pytest.raises(ValueError, match="requires config and metadata"): - adapter.create_connector(context) - - -class TestPluginDiscovery: - """Test that upstream LMCache discovers our adapter via plugin system.""" - - def test_connector_manager_loads_maru_adapter(self, async_loop, lmcache_config): - from lmcache.v1.storage_backend.connector import ConnectorManager - - manager = ConnectorManager( - url="maru://localhost:5555?pool_size=1G", - loop=async_loop, - local_cpu_backend=None, - config=lmcache_config, - ) - - adapter_names = [a.__class__.__name__ for a in manager.adapters] - assert "MaruConnectorAdapter" in adapter_names - - def test_connector_manager_can_parse_maru_url(self, async_loop, lmcache_config): - from lmcache.v1.storage_backend.connector import ConnectorManager - - manager = ConnectorManager( - url="maru://localhost:5555?pool_size=1G", - loop=async_loop, - local_cpu_backend=None, - config=lmcache_config, - ) - - matched = [a for a in manager.adapters if a.can_parse("maru://localhost:5555")] - assert len(matched) == 1 - assert matched[0].__class__.__name__ == "MaruConnectorAdapter" - - def test_not_loaded_without_remote_storage_plugins(self, async_loop): - """Without remote_storage_plugins config, Maru adapter is not loaded.""" - from lmcache.v1.config import LMCacheEngineConfig - from lmcache.v1.storage_backend.connector import ConnectorManager - - config = LMCacheEngineConfig( - chunk_size=256, - remote_url="redis://localhost:6379", - ) - - manager = ConnectorManager( - url="redis://localhost:6379", - loop=async_loop, - local_cpu_backend=None, - config=config, - ) - - adapter_names = [a.__class__.__name__ for a in manager.adapters] - assert "MaruConnectorAdapter" not in adapter_names diff --git a/tests/lmcache/test_connector.py b/tests/lmcache/test_connector.py deleted file mode 100644 index 2bbd341..0000000 --- a/tests/lmcache/test_connector.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for MaruConnector and MaruConnectorConfig.""" - -from unittest.mock import MagicMock, patch - -import pytest - -pytest.importorskip( - "lmcache.v1.storage_backend.connector", - reason="lmcache not importable (CUDA C extensions required)", -) - -from maru_lmcache.connector import ( - MaruConnector, - MaruConnectorConfig, - cache_key_to_str, - parse_size, -) - -# --------------------------------------------------------------------------- -# parse_size -# --------------------------------------------------------------------------- - - -class TestParseSize: - @pytest.mark.parametrize( - "input_val, expected", - [ - ("1G", 1024**3), - ("2g", 2 * 1024**3), - ("500M", 500 * 1024**2), - ("1024K", 1024 * 1024), - ("4GB", 4 * 1024**3), - ("100", 100), - (42, 42), - ], - ) - def test_valid_sizes(self, input_val, expected): - assert parse_size(input_val) == expected - - -# --------------------------------------------------------------------------- -# MaruConnectorConfig -# --------------------------------------------------------------------------- - - -class TestMaruConnectorConfig: - def test_from_url_defaults(self): - cfg = MaruConnectorConfig.from_url("maru://localhost:5555") - assert cfg.server_url == "tcp://localhost:5555" - assert cfg.pool_size == 1024**3 # 1G default - assert cfg.instance_id is None - - def test_from_url_with_params(self): - cfg = MaruConnectorConfig.from_url( - "maru://10.0.0.1:7777?pool_size=4G&instance_id=worker-0&timeout=60" - ) - assert cfg.server_url == "tcp://10.0.0.1:7777" - assert cfg.pool_size == 4 * 1024**3 - assert cfg.instance_id == "worker-0" - assert cfg.connection_timeout == 60.0 - - def test_from_lmcache_config(self): - from lmcache.v1.config import LMCacheEngineConfig - - config = LMCacheEngineConfig( - chunk_size=256, - remote_url="maru://localhost:5555", - extra_config={ - "maru_server_url": "tcp://10.0.0.2:6666", - "maru_pool_size": "2G", - "maru_instance_id": "test-instance", - }, - ) - cfg = MaruConnectorConfig.from_lmcache_config(config) - assert cfg.server_url == "tcp://10.0.0.2:6666" - assert cfg.pool_size == 2 * 1024**3 - assert cfg.instance_id == "test-instance" - - def test_from_lmcache_config_with_fallback(self): - from lmcache.v1.config import LMCacheEngineConfig - - config = LMCacheEngineConfig( - chunk_size=256, - remote_url="maru://localhost:5555", - extra_config={"maru_pool_size": "8G"}, - ) - fallback = MaruConnectorConfig( - server_url="tcp://fallback:9999", - instance_id="fallback-id", - ) - cfg = MaruConnectorConfig.from_lmcache_config(config, fallback=fallback) - # pool_size from extra_config - assert cfg.pool_size == 8 * 1024**3 - # server_url/instance_id from fallback - assert cfg.server_url == "tcp://fallback:9999" - assert cfg.instance_id == "fallback-id" - - -# --------------------------------------------------------------------------- -# cache_key_to_str -# --------------------------------------------------------------------------- - - -class TestCacheKeyToStr: - def test_deterministic(self): - key = MagicMock() - key.to_string.return_value = "model|layer|token_range|fmt" - h1 = cache_key_to_str(key) - h2 = cache_key_to_str(key) - assert h1 == h2 - - def test_different_keys_differ(self): - k1, k2 = MagicMock(), MagicMock() - k1.to_string.return_value = "key_a" - k2.to_string.return_value = "key_b" - assert cache_key_to_str(k1) != cache_key_to_str(k2) - - def test_returns_string(self): - key = MagicMock() - key.to_string.return_value = "test_key" - result = cache_key_to_str(key) - assert isinstance(result, str) - assert result == "test_key" - - -# --------------------------------------------------------------------------- -# MaruConnector (with mocked MaruHandler) -# --------------------------------------------------------------------------- - - -def _make_connector(async_loop, mock_handler): - """Create a MaruConnector with a pre-injected mock handler.""" - from lmcache.v1.config import LMCacheEngineConfig - from lmcache.v1.metadata import LMCacheMetadata - - config = LMCacheEngineConfig( - chunk_size=256, - remote_url="maru://localhost:5555", - extra_config={"save_chunk_meta": False}, - ) - - # Mock metadata to avoid needing a real vLLM model config - import torch - - metadata = MagicMock(spec=LMCacheMetadata) - metadata.get_shapes.return_value = [torch.Size([2, 32, 256, 128])] - metadata.get_dtypes.return_value = [torch.float16] - metadata.use_mla = False - metadata.chunk_size = 256 - metadata.get_num_groups.return_value = 1 - - maru_config = MaruConnectorConfig(auto_connect=False) - - with ( - patch( - "lmcache.v1.storage_backend.connector.base_connector.get_size_bytes", - return_value=256 * 2 * 32 * 128 * 2, # fake chunk size - ), - patch( - "lmcache.v1.storage_backend.connector.base_connector.init_remote_metadata_info" - ), - patch( - "lmcache.v1.storage_backend.connector.base_connector.get_remote_metadata_bytes", - return_value=0, - ), - ): - connector = MaruConnector( - url="maru://localhost:5555", - loop=async_loop, - config=config, - metadata=metadata, - maru_config=maru_config, - ) - - # Inject mock handler - connector._handle = mock_handler - connector._connected = True - return connector - - -class TestMaruConnector: - def test_exists(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.exists.return_value = True - - key = MagicMock() - key.to_string.return_value = "test_key" - - result = async_loop.run_until_complete(connector.exists(key)) - assert result is True - mock_maru_handler.exists.assert_called_once() - - def test_exists_returns_false_when_disconnected(self, async_loop): - maru_config = MaruConnectorConfig(auto_connect=False) - - # Can't create a real connector without metadata, so test the - # _ensure_connected path directly - connector = MagicMock(spec=MaruConnector) - connector._connected = False - connector._handle = None - connector._ensure_connected = MagicMock(return_value=False) - connector.maru_config = maru_config - - # Call the real exists method - result = async_loop.run_until_complete( - MaruConnector.exists(connector, MagicMock()) - ) - assert result is False - - def test_exists_sync(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.exists.return_value = False - - key = MagicMock() - key.to_string.return_value = "test_key" - - result = connector.exists_sync(key) - assert result is False - - def test_put_and_get(self, async_loop, mock_maru_handler, mock_memory_info): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.store.return_value = True - mock_maru_handler.retrieve.return_value = mock_memory_info - - key = MagicMock() - key.to_string.return_value = "test_key" - - # put - memory_obj = MagicMock() - memory_obj.byte_array = memoryview(bytearray(1024)) - - with patch("maru_lmcache.connector.MaruConnector._encode_memory_obj") as enc: - enc.return_value = mock_memory_info - async_loop.run_until_complete(connector.put(key, memory_obj)) - - mock_maru_handler.store.assert_called_once() - - # get - with ( - patch.object(connector, "_decode_memory_obj") as dec, - patch.object(connector, "reshape_partial_chunk") as reshape, - ): - dec.return_value = MagicMock() - reshape.return_value = dec.return_value - result = async_loop.run_until_complete(connector.get(key)) - - assert result is not None - mock_maru_handler.retrieve.assert_called_once() - - def test_close(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - - async_loop.run_until_complete(connector.close()) - mock_maru_handler.close.assert_called_once() - assert connector._handle is None - assert connector._connected is False - - def test_ping_success(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.healthcheck.return_value = True - - result = async_loop.run_until_complete(connector.ping()) - assert result == 0 # PING_SUCCESS - - def test_ping_not_connected(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - connector._connected = False - - result = async_loop.run_until_complete(connector.ping()) - assert result == 1 # PING_NOT_CONNECTED - - def test_remove_sync(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.delete.return_value = True - - key = MagicMock() - key.to_string.return_value = "test_key" - assert connector.remove_sync(key) is True - - def test_list_returns_empty(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - result = async_loop.run_until_complete(connector.list()) - assert result == [] - - def test_repr(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - r = repr(connector) - assert "MaruConnector" in r - assert "connected=True" in r - - -# --------------------------------------------------------------------------- -# Batch operations -# --------------------------------------------------------------------------- - - -class TestBatchOperations: - def test_support_flags(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - assert connector.support_batched_get() is True - assert connector.support_batched_put() is True - assert connector.support_batched_async_contains() is True - assert connector.support_batched_contains() is True - assert connector.support_batched_get_non_blocking() is True - assert connector.support_ping() is True - - def test_batched_contains(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.batch_exists.return_value = [True, True, False] - - keys = [MagicMock() for _ in range(3)] - for i, k in enumerate(keys): - k.to_string.return_value = f"key_{i}" - - result = connector.batched_contains(keys) - assert result == 2 # 2 consecutive hits - - def test_batched_async_contains(self, async_loop, mock_maru_handler): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.batch_exists.return_value = [True, True, True] - - keys = [MagicMock() for _ in range(3)] - for i, k in enumerate(keys): - k.to_string.return_value = f"key_{i}" - - result = async_loop.run_until_complete( - connector.batched_async_contains("lookup-1", keys) - ) - assert result == 3 - - def test_batched_get(self, async_loop, mock_maru_handler, mock_memory_info): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.batch_retrieve.return_value = [ - mock_memory_info, - None, - mock_memory_info, - ] - - keys = [MagicMock() for _ in range(3)] - for i, k in enumerate(keys): - k.to_string.return_value = f"key_{i}" - - with ( - patch.object(connector, "_decode_memory_obj") as dec, - patch.object(connector, "reshape_partial_chunk") as reshape, - ): - obj = MagicMock() - dec.return_value = obj - reshape.return_value = obj - results = async_loop.run_until_complete(connector.batched_get(keys)) - - assert len(results) == 3 - assert results[0] is not None - assert results[1] is None - assert results[2] is not None - - def test_batched_put(self, async_loop, mock_maru_handler, mock_memory_info): - connector = _make_connector(async_loop, mock_maru_handler) - mock_maru_handler.batch_store.return_value = [True, True] - - keys = [MagicMock() for _ in range(2)] - for i, k in enumerate(keys): - k.to_string.return_value = f"key_{i}" - - objs = [MagicMock() for _ in range(2)] - for obj in objs: - obj.byte_array = memoryview(bytearray(1024)) - - with patch("maru_lmcache.connector.MaruConnector._encode_memory_obj") as enc: - enc.return_value = mock_memory_info - async_loop.run_until_complete(connector.batched_put(keys, objs)) - - mock_maru_handler.batch_store.assert_called_once() - - def test_batched_get_non_blocking_consecutive_prefix( - self, async_loop, mock_maru_handler, mock_memory_info - ): - connector = _make_connector(async_loop, mock_maru_handler) - # Second key is a miss → only first returned - mock_maru_handler.batch_retrieve.return_value = [ - mock_memory_info, - None, - mock_memory_info, - ] - - keys = [MagicMock() for _ in range(3)] - for i, k in enumerate(keys): - k.to_string.return_value = f"key_{i}" - - with ( - patch.object(connector, "_decode_memory_obj") as dec, - patch.object(connector, "reshape_partial_chunk") as reshape, - ): - obj = MagicMock() - dec.return_value = obj - reshape.return_value = obj - results = async_loop.run_until_complete( - connector.batched_get_non_blocking("lookup-1", keys) - ) - - # Only first item (before the None) should be returned - assert len(results) == 1 diff --git a/tests/lmcache/test_maru_backend.py b/tests/lmcache/test_maru_backend.py new file mode 100644 index 0000000..c3d3d8f --- /dev/null +++ b/tests/lmcache/test_maru_backend.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for MaruBackend storage backend.""" + +import asyncio +import mmap +import threading +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip( + "lmcache.v1.storage_backend", + reason="lmcache not importable (CUDA C extensions required)", +) + +import torch +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + MemoryFormat, + TensorMemoryObj, +) +from lmcache.v1.pin_monitor import PinMonitor + +from maru_handler.memory import AllocHandle +from maru_handler.memory.types import MappedRegion, MemoryInfo +from maru_lmcache.adapter import CxlMemoryAdapter + +# ========================================================================= +# Fixtures +# ========================================================================= + +# Match real KV cache: [2, 32, 256, 128] float16 = 4MB chunk +# For tests: use small shape that matches chunk_size +# chunk_size=1024, dtype=float32(4B) → 256 elements → shape=[256] +TEST_CHUNK_SIZE = 1024 +TEST_DTYPE = torch.float32 +TEST_SHAPE = torch.Size([256]) # 256 * 4 = 1024 bytes = chunk_size + + +def _make_mock_handler(pool_size=4096, chunk_size=TEST_CHUNK_SIZE): + """Create a mock MaruHandler with facade API and mmap-backed regions.""" + handler = MagicMock() + handler._connected = True + + region_id = 100 + page_count = pool_size // chunk_size + + # Real mmap for buffer views + mmap_obj = mmap.mmap(-1, pool_size) + mapped_region = MappedRegion( + region_id=region_id, + handle=MagicMock(region_id=region_id, length=pool_size), + size=pool_size, + _mmap_obj=mmap_obj, + ) + + # Facade methods + handler.get_buffer_view.side_effect = lambda rid, offset, size: ( + mapped_region.get_buffer_view(offset, size) if rid == region_id else None + ) + handler.get_region_page_count.side_effect = lambda rid: ( + page_count if rid == region_id else None + ) + handler.get_owned_region_ids.return_value = [region_id] + handler.get_chunk_size.return_value = chunk_size + + # set_on_region_added: capture callback and replay for existing regions + def mock_set_on_region_added(callback): + if callback is not None: + callback(region_id, page_count) + + handler.set_on_region_added.side_effect = mock_set_on_region_added + + page_counter = [0] + + def mock_alloc(size): + idx = page_counter[0] + page_counter[0] += 1 + buf = mapped_region.get_buffer_view(idx * chunk_size, size) + return AllocHandle(buf=buf, _region_id=region_id, _page_index=idx, _size=size) + + handler.alloc.side_effect = mock_alloc + handler.free = MagicMock() + handler.connect.return_value = True + handler.close.return_value = None + handler.store.return_value = True + handler.retrieve.return_value = None + handler.exists.return_value = False + handler.delete.return_value = True + + return handler + + +def _make_cache_key(chunk_hash: int = 12345) -> CacheEngineKey: + """Create a CacheEngineKey for testing.""" + return CacheEngineKey( + model_name="test-model", + world_size=1, + worker_id=0, + chunk_hash=chunk_hash, + dtype=torch.float32, + ) + + +def _make_memory_obj(adapter: CxlMemoryAdapter) -> TensorMemoryObj: + """Allocate a real TensorMemoryObj from the adapter.""" + obj = adapter.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + return obj + + +@pytest.fixture(autouse=True) +def _init_pin_monitor(): + """Initialize PinMonitor singleton required by TensorMemoryObj.pin().""" + PinMonitor._instance = None + PinMonitor.GetOrCreate(LMCacheEngineConfig.from_defaults()) + yield + PinMonitor._instance = None + + +@pytest.fixture +def async_loop(): + """Provide an asyncio event loop running in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + yield loop + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5) + loop.close() + + +@pytest.fixture +def mock_handler(): + return _make_mock_handler() + + +@pytest.fixture +def adapter(mock_handler): + return CxlMemoryAdapter( + handler=mock_handler, + shapes=[TEST_SHAPE], + dtypes=[TEST_DTYPE], + fmt=MemoryFormat.KV_2LTD, + chunk_size=TEST_CHUNK_SIZE, + ) + + +@pytest.fixture +def backend(mock_handler, adapter, async_loop): + """Create a MaruBackend with mocked internals.""" + from lmcache.v1.storage_backend.maru_backend import MaruBackend + + with patch.object(MaruBackend, "initialize_allocator", return_value=adapter): + backend = MaruBackend.__new__(MaruBackend) + backend.dst_device = "cpu" + backend.config = MagicMock() + backend.loop = async_loop + backend.memory_allocator = adapter + backend._handler = mock_handler + + # Chunk metadata + backend._full_chunk_size_bytes = TEST_CHUNK_SIZE + backend._single_token_size = TEST_CHUNK_SIZE // 256 # 4 bytes per token + backend._mla_worker_id_as0_mode = False + + backend.put_lock = threading.Lock() + backend.put_tasks = set() + return backend + + +# ========================================================================= +# Tests +# ========================================================================= + + +class TestMaruBackendAllocate: + def test_allocate_returns_memory_obj(self, backend): + obj = backend.allocate(TEST_SHAPE, TEST_DTYPE) + assert obj is not None + assert obj.tensor is not None + assert obj.metadata.dtype == TEST_DTYPE + + def test_batched_allocate_returns_list(self, backend): + objs = backend.batched_allocate(TEST_SHAPE, TEST_DTYPE, batch_size=3) + assert objs is not None + assert len(objs) == 3 + + +class TestMaruBackendPut: + def test_submit_put_task_returns_future(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + future = backend.submit_put_task(key, obj) + assert future is not None + + future.result(timeout=5) + + backend._handler.store.assert_called_once() + + def test_submit_put_task_tracks_in_flight(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + + assert not backend.exists_in_put_tasks(key) + + future = backend.submit_put_task(key, obj) + future.result(timeout=5) + + assert not backend.exists_in_put_tasks(key) + + def test_batched_submit_put_task(self, backend, adapter): + keys = [_make_cache_key(i) for i in range(3)] + objs = [_make_memory_obj(adapter) for _ in range(3)] + for obj in objs: + obj.parent_allocator = None + + # LMCache batched_submit_put_task returns a single Future for the batch + futures = backend.batched_submit_put_task(keys, objs) + assert futures is not None + assert len(futures) == 1 + + for future in futures: + future.result(timeout=5) + + def test_submit_put_calls_callback(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + callback_called = [] + + def callback(k): + callback_called.append(k) + + future = backend.submit_put_task(key, obj, on_complete_callback=callback) + future.result(timeout=5) + + assert len(callback_called) == 1 + assert callback_called[0] == key + + def test_ref_count_managed_during_put(self, backend, adapter): + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + key = _make_cache_key() + initial_ref = obj.get_ref_count() + + future = backend.submit_put_task(key, obj) + future.result(timeout=5) + + # ref_count_up x1 in submit_put_task, SM ref_count_down not called here + assert obj.get_ref_count() == initial_ref + 1 + + +class TestMaruBackendGet: + def test_get_blocking_from_maru_server(self, backend, adapter): + key = _make_cache_key() + + # Mock retrieve to return a MemoryInfo with rid/pid + data_size = TEST_CHUNK_SIZE + data = bytearray(data_size) + mock_info = MemoryInfo( + view=memoryview(data), + region_id=100, + page_index=0, + ) + backend._handler.retrieve.return_value = mock_info + + result = backend.get_blocking(key) + assert result is not None + backend._handler.retrieve.assert_called_once() + + def test_get_blocking_not_found(self, backend): + key = _make_cache_key() + backend._handler.retrieve.return_value = None + + result = backend.get_blocking(key) + assert result is None + + +class TestMaruBackendContains: + def test_contains_true(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = True + + assert backend.contains(key) is True + backend._handler.exists.assert_called_once_with(key.to_string()) + + def test_contains_false(self, backend): + key = _make_cache_key() + backend._handler.exists.return_value = False + + assert backend.contains(key) is False + + +def _run_async(loop, coro): + """Submit a coroutine to a running event loop and wait for result.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=5) + + +class TestMaruBackendAsyncLookup: + """Tests for batched_async_contains and batched_get_non_blocking. + + These mirror the connector-era tests in test_connector.py::TestBatchOperations + that were lost during the MaruBackend transition. + """ + + def test_batched_async_contains_all_hit(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, True] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-1", keys) + ) + assert result == 3 + + def test_batched_async_contains_partial_prefix(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [True, True, False] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-2", keys) + ) + assert result == 2 + + def test_batched_async_contains_first_miss(self, backend, async_loop): + keys = [_make_cache_key(i) for i in range(3)] + backend._handler.batch_exists.return_value = [False, False, False] + + result = _run_async( + async_loop, backend.batched_async_contains("lookup-3", keys) + ) + assert result == 0 + + def test_batched_async_contains_empty_keys(self, backend, async_loop): + backend._handler.batch_exists.return_value = [] + result = _run_async(async_loop, backend.batched_async_contains("lookup-4", [])) + assert result == 0 + + def test_batched_get_non_blocking_all_hit(self, backend, adapter, async_loop): + keys = [_make_cache_key(i) for i in range(2)] + + # Pre-store: allocate objects and mock batch_retrieve to return MemoryInfo list + objs = [_make_memory_obj(adapter) for _ in range(2)] + infos = [] + for obj in objs: + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + infos.append( + MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + ) + backend._handler.batch_retrieve.return_value = infos + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-5", keys) + ) + assert len(results) == 2 + for obj in results: + assert obj is not None + + def test_batched_get_non_blocking_prefix_stop_on_miss( + self, backend, adapter, async_loop + ): + """Second key is a miss → only first returned (prefix semantics).""" + keys = [_make_cache_key(i) for i in range(3)] + + obj = _make_memory_obj(adapter) + rid, pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + info = MemoryInfo( + view=memoryview(bytearray(TEST_CHUNK_SIZE)), + region_id=rid, + page_index=pid, + ) + # hit, miss, hit → should return only [hit] + backend._handler.batch_retrieve.return_value = [info, None, info] + + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-6", keys) + ) + assert len(results) == 1 + + def test_batched_get_non_blocking_empty_keys(self, backend, async_loop): + backend._handler.batch_retrieve.return_value = [] + results = _run_async( + async_loop, backend.batched_get_non_blocking("lookup-7", []) + ) + assert results == [] + + +class TestMaruBackendRemove: + def test_remove_existing_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = True + + result = backend.remove(key) + assert result is True + backend._handler.delete.assert_called_once_with(key.to_string()) + + def test_remove_nonexistent_key(self, backend): + key = _make_cache_key() + backend._handler.delete.return_value = False + + result = backend.remove(key) + assert result is False + + +class TestMaruBackendLifecycle: + def test_close_calls_handler(self, backend): + backend.close() + backend._handler.close.assert_called_once() + + def test_str_representation(self, backend): + assert str(backend) == "MaruBackend" + + def test_get_allocator_backend_returns_self(self, backend): + assert backend.get_allocator_backend() is backend + + def test_get_memory_allocator_returns_adapter(self, backend, adapter): + assert backend.get_memory_allocator() is adapter + + +class TestMaruBackendStoreHandle: + def test_store_handle_roundtrip(self, backend, adapter): + """AllocHandle from create_store_handle should match original.""" + obj = _make_memory_obj(adapter) + obj.parent_allocator = None + + handle = adapter.create_store_handle(obj) + assert handle.region_id == 100 + assert handle.page_index == 0 + assert handle._size == obj.metadata.phy_size diff --git a/tests/lmcache/test_maru_integration.py b/tests/lmcache/test_maru_integration.py new file mode 100644 index 0000000..19a88bf --- /dev/null +++ b/tests/lmcache/test_maru_integration.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for MaruBackend integration with LMCache config and storage manager.""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip( + "lmcache.v1.storage_backend", + reason="lmcache not importable (CUDA C extensions required)", +) + +from lmcache.v1.config import LMCacheEngineConfig + + +class TestConfigFields: + """Verify maru_path and maru_pool_size config fields.""" + + def test_maru_path_default_none(self): + config = LMCacheEngineConfig(chunk_size=256) + assert config.maru_path is None + + def test_maru_pool_size_default(self): + config = LMCacheEngineConfig(chunk_size=256) + assert config.maru_pool_size == 4.0 + + def test_maru_path_set(self): + config = LMCacheEngineConfig( + chunk_size=256, + maru_path="tcp://localhost:5555", + ) + assert config.maru_path == "tcp://localhost:5555" + + def test_maru_pool_size_set(self): + config = LMCacheEngineConfig( + chunk_size=256, + maru_pool_size=8.0, + ) + assert config.maru_pool_size == 8.0 + + +class TestCreateStorageBackends: + """Verify MaruBackend is created/skipped based on config.""" + + def test_no_maru_backend_without_maru_path(self): + """maru_path=None → MaruBackend not created.""" + import asyncio + + from lmcache.v1.metadata import LMCacheMetadata + from lmcache.v1.storage_backend import CreateStorageBackends + + config = LMCacheEngineConfig( + chunk_size=256, + max_local_cpu_size=0, + ) + metadata = MagicMock(spec=LMCacheMetadata) + metadata.role = "scheduler" + loop = asyncio.new_event_loop() + + try: + backends = CreateStorageBackends(config, metadata, loop, dst_device="cpu") + assert "MaruBackend" not in backends + finally: + loop.close() + + def test_maru_backend_created_with_maru_path(self): + """maru_path set → MaruBackend created (with mocked handler).""" + import asyncio + + from lmcache.v1.metadata import LMCacheMetadata + from lmcache.v1.storage_backend import CreateStorageBackends + + config = LMCacheEngineConfig( + chunk_size=256, + max_local_cpu_size=0, + maru_path="tcp://localhost:5555", + maru_pool_size=1.0, + ) + metadata = MagicMock(spec=LMCacheMetadata) + metadata.role = "scheduler" + loop = asyncio.new_event_loop() + + try: + with patch( + "lmcache.v1.storage_backend.maru_backend.MaruBackend.__init__", + return_value=None, + ) as mock_init: + mock_init.return_value = None + CreateStorageBackends(config, metadata, loop, dst_device="cpu") + # MaruBackend.__init__ was called + mock_init.assert_called_once() + finally: + loop.close() + + def test_maru_backend_skipped_when_in_skip_set(self): + """MaruBackend in skip_backends → not created.""" + import asyncio + + from lmcache.v1.metadata import LMCacheMetadata + from lmcache.v1.storage_backend import CreateStorageBackends + + config = LMCacheEngineConfig( + chunk_size=256, + max_local_cpu_size=0, + maru_path="tcp://localhost:5555", + ) + metadata = MagicMock(spec=LMCacheMetadata) + metadata.role = "scheduler" + loop = asyncio.new_event_loop() + + try: + backends = CreateStorageBackends( + config, + metadata, + loop, + dst_device="cpu", + skip_backends={"MaruBackend"}, + ) + assert "MaruBackend" not in backends + finally: + loop.close() + + def test_other_backends_unaffected_without_maru(self): + """Without maru, existing backends work normally.""" + import asyncio + + from lmcache.v1.metadata import LMCacheMetadata + from lmcache.v1.storage_backend import CreateStorageBackends + + config = LMCacheEngineConfig( + chunk_size=256, + max_local_cpu_size=0, + ) + metadata = MagicMock(spec=LMCacheMetadata) + metadata.role = "scheduler" + loop = asyncio.new_event_loop() + + try: + # Should not raise even though maru is not configured + backends = CreateStorageBackends(config, metadata, loop, dst_device="cpu") + assert isinstance(backends, dict) + finally: + loop.close() diff --git a/tests/unit/test_cxl_memory_adapter.py b/tests/unit/test_cxl_memory_adapter.py new file mode 100644 index 0000000..d47b63f --- /dev/null +++ b/tests/unit/test_cxl_memory_adapter.py @@ -0,0 +1,418 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for CxlMemoryAdapter (pool-based).""" + +import mmap +from unittest.mock import MagicMock + +import pytest + +torch = pytest.importorskip("torch") +lmcache_mm = pytest.importorskip("lmcache.v1.memory_management") +MemoryFormat = lmcache_mm.MemoryFormat + +from maru_handler.memory import AllocHandle # noqa: E402 +from maru_handler.memory.types import MappedRegion # noqa: E402 +from maru_lmcache.adapter import CxlMemoryAdapter # noqa: E402 + +# ========================================================================= +# Fixtures +# ========================================================================= + + +def _make_mock_handler(pool_size=4096, chunk_size=1024): + """Create a mock MaruHandler with facade API and mmap-backed regions.""" + handler = MagicMock() + handler._connected = True + + region_id = 100 + page_count = pool_size // chunk_size + + # Real mmap for buffer views + mmap_obj = mmap.mmap(-1, pool_size) + mapped_region = MappedRegion( + region_id=region_id, + handle=MagicMock(region_id=region_id, length=pool_size), + size=pool_size, + _mmap_obj=mmap_obj, + ) + + # Facade methods + handler.get_buffer_view.side_effect = lambda rid, offset, size: ( + mapped_region.get_buffer_view(offset, size) if rid == region_id else None + ) + handler.get_region_page_count.side_effect = lambda rid: ( + page_count if rid == region_id else None + ) + handler.get_owned_region_ids.return_value = [region_id] + handler.get_chunk_size.return_value = chunk_size + + # set_on_region_added: capture callback and replay for existing regions + _callback_holder = [None] + + def mock_set_on_region_added(callback): + _callback_holder[0] = callback + if callback is not None: + callback(region_id, page_count) + + handler.set_on_region_added.side_effect = mock_set_on_region_added + handler._callback_holder = _callback_holder + + # alloc returns incrementing page indices + page_counter = [0] + + def mock_alloc(size): + idx = page_counter[0] + page_counter[0] += 1 + buf = mapped_region.get_buffer_view(idx * chunk_size, size) + return AllocHandle(buf=buf, _region_id=region_id, _page_index=idx, _size=size) + + handler.alloc.side_effect = mock_alloc + handler.free = MagicMock() + + # Store extra refs for tests that need expansion + handler._mapped_region = mapped_region + handler._page_counter = page_counter + + return handler + + +def _make_adapter(handler): + """Create a CxlMemoryAdapter with standard test params.""" + chunk_size = handler.get_chunk_size() + dtype = torch.float32 + num_elements = chunk_size // dtype.itemsize + shape = torch.Size([num_elements]) + + return CxlMemoryAdapter( + handler=handler, + shapes=[shape], + dtypes=[dtype], + fmt=MemoryFormat.KV_2LTD, + chunk_size=chunk_size, + ) + + +# ========================================================================= +# Tests +# ========================================================================= + + +class TestAddressEncoding: + def test_encode_decode_roundtrip(self): + for rid, pid in [(0, 0), (1, 5), (100, 3), (0xFFFF, 0xFFFFFFFF)]: + encoded = CxlMemoryAdapter.encode_address(rid, pid) + decoded_rid, decoded_pid = CxlMemoryAdapter.decode_address(encoded) + assert decoded_rid == rid + assert decoded_pid == pid + + def test_encode_is_deterministic(self): + assert CxlMemoryAdapter.encode_address(1, 2) == (1 << 32) | 2 + + +class TestPoolCreation: + def test_pool_built_on_init(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + assert 100 in adapter._pool + assert len(adapter._pool[100]) == 4 + + def test_pool_objects_have_correct_address(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + for pid, obj in enumerate(adapter._pool[100]): + rid, decoded_pid = CxlMemoryAdapter.decode_address(obj.metadata.address) + assert rid == 100 + assert decoded_pid == pid + + def test_pool_built_via_callback(self): + """Pool is built through set_on_region_added callback, not direct access.""" + handler = _make_mock_handler() + _make_adapter(handler) + + # Verify callback was registered + handler.set_on_region_added.assert_called_once() + # Verify facade methods were used (not internal accessors) + handler.get_buffer_view.assert_called() + + +class TestRegionExpansionCallback: + def test_callback_builds_new_region_pool(self): + """Simulates region expansion: callback builds pool for new region.""" + handler = _make_mock_handler(pool_size=4096, chunk_size=1024) + adapter = _make_adapter(handler) + + # Initial pool has region 100 + assert 100 in adapter._pool + assert 200 not in adapter._pool + + # Create a new mmap for the expanded region + new_mmap = mmap.mmap(-1, 2048) + new_region = MappedRegion( + region_id=200, + handle=MagicMock(region_id=200, length=2048), + size=2048, + _mmap_obj=new_mmap, + ) + + # Update handler mock to include new region + original_get_buffer_view = handler.get_buffer_view.side_effect + + def updated_get_buffer_view(rid, offset, size): + if rid == 200: + return new_region.get_buffer_view(offset, size) + return original_get_buffer_view(rid, offset, size) + + handler.get_buffer_view.side_effect = updated_get_buffer_view + + # Fire the callback (simulating _expand_region) + callback = handler._callback_holder[0] + assert callback is not None + callback(200, 2) # 2 pages in new region + + # Verify new pool was built + assert 200 in adapter._pool + assert len(adapter._pool[200]) == 2 + + def test_allocate_after_expansion(self): + """After expansion callback, allocate works on new region pages.""" + handler = _make_mock_handler(pool_size=4096, chunk_size=1024) + adapter = _make_adapter(handler) + + # Setup new region + new_mmap = mmap.mmap(-1, 1024) + new_region = MappedRegion( + region_id=200, + handle=MagicMock(region_id=200, length=1024), + size=1024, + _mmap_obj=new_mmap, + ) + original_get_buffer_view = handler.get_buffer_view.side_effect + + def updated_get_buffer_view(rid, offset, size): + if rid == 200: + return new_region.get_buffer_view(offset, size) + return original_get_buffer_view(rid, offset, size) + + handler.get_buffer_view.side_effect = updated_get_buffer_view + + # Fire callback for new region + callback = handler._callback_holder[0] + callback(200, 1) + + # Override alloc to return from new region + handler.alloc.side_effect = lambda size: AllocHandle( + buf=new_region.get_buffer_view(0, size), + _region_id=200, + _page_index=0, + _size=size, + ) + + obj = adapter.allocate(torch.Size([256]), torch.float32) + assert obj is not None + assert obj is adapter._pool[200][0] + + +class TestAllocate: + def test_allocate_returns_tensor_memory_obj(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([256]), torch.float32) + + assert obj is not None + assert obj.tensor is not None + assert obj.metadata.ref_count == 1 + assert obj.metadata.dtype == torch.float32 + assert obj.metadata.phy_size == 1024 # chunk_size + + def test_allocate_returns_pooled_object(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([256]), torch.float32) + assert obj is adapter._pool[100][0] + + def test_allocate_address_encodes_rid_pid(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj1 = adapter.allocate(torch.Size([8]), torch.float32) + obj2 = adapter.allocate(torch.Size([8]), torch.float32) + + rid1, pid1 = CxlMemoryAdapter.decode_address(obj1.metadata.address) + rid2, pid2 = CxlMemoryAdapter.decode_address(obj2.metadata.address) + + assert rid1 == 100 and pid1 == 0 + assert rid2 == 100 and pid2 == 1 + + def test_allocate_zero_size_returns_none(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([0]), torch.float32) + assert obj is None + + def test_allocate_handler_failure_returns_none(self): + handler = _make_mock_handler() + handler.alloc.side_effect = ValueError("pool exhausted") + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([8]), torch.float32) + assert obj is None + + def test_allocate_tensor_writable(self): + """Tensor backed by CXL memoryview should be writable.""" + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([256]), torch.float32) + assert obj is not None + obj.tensor[:] = torch.ones(256, dtype=torch.float32) + assert obj.tensor[0].item() == 1.0 + + +class TestBatchedAllocate: + def test_batched_allocate_returns_list(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + objs = adapter.batched_allocate(torch.Size([8]), torch.float32, batch_size=3) + assert objs is not None + assert len(objs) == 3 + addresses = [o.metadata.address for o in objs] + assert len(set(addresses)) == 3 + + def test_batched_allocate_rollback_on_failure(self): + handler = _make_mock_handler() + call_count = [0] + original_alloc = handler.alloc.side_effect + + def fail_on_third(size): + call_count[0] += 1 + if call_count[0] == 3: + raise ValueError("exhausted") + return original_alloc(size) + + handler.alloc.side_effect = fail_on_third + adapter = _make_adapter(handler) + + objs = adapter.batched_allocate(torch.Size([8]), torch.float32, batch_size=4) + assert objs is None + # free() is a no-op on CxlMemoryAdapter (pages managed by MaruBackend.remove) + + +class TestFree: + def test_free_is_noop(self): + """free() is a no-op — pages managed by MaruBackend.remove().""" + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([8]), torch.float32) + assert obj is not None + + adapter.free(obj) + # No handler.free call since adapter.free is a no-op + + def test_ref_count_lifecycle(self): + """ref_count up/down tracking works correctly.""" + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([8]), torch.float32) + assert obj is not None + assert obj.metadata.ref_count == 1 + + obj.ref_count_up() + assert obj.metadata.ref_count == 2 + + obj.parent_allocator = None + obj.ref_count_down() + assert obj.metadata.ref_count == 1 + obj.ref_count_down() + assert obj.metadata.ref_count == 0 + + +class TestCreateStoreHandle: + def test_create_store_handle_roundtrip(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.allocate(torch.Size([8]), torch.float32) + assert obj is not None + + handle = adapter.create_store_handle(obj) + assert handle.region_id == 100 + assert handle.page_index == 0 + assert handle._size == obj.metadata.phy_size + + +class TestGetByLocation: + def test_get_by_location_full_chunk(self): + handler = _make_mock_handler(pool_size=4096, chunk_size=1024) + adapter = _make_adapter(handler) + + obj = adapter.get_by_location( + region_id=100, + page_index=2, + actual_size=1024, + single_token_size=64, + ) + assert obj is not None + assert obj is adapter._pool[100][2] + + def test_get_by_location_partial_chunk(self): + # Use 4D shape matching chunk_size for realistic partial chunk test + # chunk_size=1024, dtype=float32(4B) → 256 elements + # shape=[2, 2, 32, 2] → 256 elements, token_dim=shape[2]=32 + handler = _make_mock_handler(pool_size=4096, chunk_size=1024) + chunk_size = 1024 + dtype = torch.float32 + shape = torch.Size([2, 2, 32, 2]) + single_token_size = chunk_size // 32 # 32 bytes per token + + adapter = CxlMemoryAdapter( + handler=handler, + shapes=[shape], + dtypes=[dtype], + fmt=MemoryFormat.KV_2LTD, + chunk_size=chunk_size, + ) + + # Request half the tokens (16 tokens × 32 bytes = 512 bytes) + obj = adapter.get_by_location( + region_id=100, + page_index=1, + actual_size=512, + single_token_size=single_token_size, + ) + assert obj is not None + assert obj is not adapter._pool[100][1] + assert obj.metadata.phy_size == 512 + # Token dim should be halved: 32 → 16 + assert obj.metadata.shape[2] == 16 + + def test_get_by_location_invalid_region(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + obj = adapter.get_by_location( + region_id=999, + page_index=0, + actual_size=1024, + single_token_size=64, + ) + assert obj is None + + +class TestClose: + def test_close_clears_pool_and_unregisters_callback(self): + handler = _make_mock_handler() + adapter = _make_adapter(handler) + + assert len(adapter._pool) > 0 + adapter.close() + assert len(adapter._pool) == 0 + # Callback should be unregistered (set to None) + assert handler.set_on_region_added.call_count == 2 # init + close diff --git a/tests/unit/test_kv_manager.py b/tests/unit/test_kv_manager.py index 023115f..4a6f7cc 100644 --- a/tests/unit/test_kv_manager.py +++ b/tests/unit/test_kv_manager.py @@ -2,7 +2,7 @@ # Copyright 2026 XCENA Inc. """Tests for KV Manager.""" -from maru_server.kv_manager import KVManager +from maru_server.kv_manager import DeleteResult, KVManager class TestKVManager: @@ -58,8 +58,8 @@ def test_delete_single_ref(self): manager = KVManager() manager.register(key="123", region_id=1, kv_offset=0, kv_length=1024) - existed, region_id = manager.delete("123") - assert existed is True + result, region_id = manager.delete("123") + assert result == DeleteResult.DELETED assert region_id == 1 # Need to decrement alloc ref assert manager.exists("123") is False @@ -75,21 +75,21 @@ def test_delete_multiple_refs(self): assert rid is None # Delete removes entry entirely on first call - existed, region_id = manager.delete("123") - assert existed is True + result, region_id = manager.delete("123") + assert result == DeleteResult.DELETED assert region_id == 1 assert manager.exists("123") is False - # Second delete on now-missing key returns (False, None) - existed, region_id = manager.delete("123") - assert existed is False + # Second delete on now-missing key returns NOT_FOUND + result, region_id = manager.delete("123") + assert result == DeleteResult.NOT_FOUND assert region_id is None def test_delete_nonexistent(self): """Test deleting a nonexistent key.""" manager = KVManager() - existed, region_id = manager.delete("999") - assert existed is False + result, region_id = manager.delete("999") + assert result == DeleteResult.NOT_FOUND assert region_id is None def test_get_stats(self): @@ -210,8 +210,8 @@ def test_delete_nonexistent_key(self): """Test deleting a key that was never registered.""" manager = KVManager() - existed, region_id = manager.delete("999") - assert existed is False + result, region_id = manager.delete("999") + assert result == DeleteResult.NOT_FOUND assert region_id is None def test_register_then_delete_then_re_register(self): @@ -219,8 +219,8 @@ def test_register_then_delete_then_re_register(self): manager = KVManager() manager.register(key="123", region_id=1, kv_offset=0, kv_length=1024) - existed, region_id = manager.delete("123") - assert existed is True + result, region_id = manager.delete("123") + assert result == DeleteResult.DELETED assert region_id == 1 assert manager.exists("123") is False @@ -231,3 +231,174 @@ def test_register_then_delete_then_re_register(self): assert is_new is True assert new_region_id == 2 assert manager.exists("123") is True + + +class TestKVManagerPin: + """Test cases for pin/unpin operations.""" + + # ---- pin() ---- + + def test_pin_existing_key(self): + """pin() on existing key returns True and increments pin_count.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + + assert manager.pin("1") is True + assert manager.lookup("1").pin_count == 1 + + def test_pin_increments_multiple_times(self): + """Multiple pin() calls increment pin_count each time.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + + manager.pin("1") + manager.pin("1") + manager.pin("1") + assert manager.lookup("1").pin_count == 3 + + def test_pin_nonexistent_key(self): + """pin() on nonexistent key returns False.""" + manager = KVManager() + assert manager.pin("999") is False + + # ---- unpin() ---- + + def test_unpin_pinned_key(self): + """unpin() on pinned key returns True and decrements pin_count.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.pin("1") + manager.pin("1") + + assert manager.unpin("1") is True + assert manager.lookup("1").pin_count == 1 + + def test_unpin_nonexistent_key(self): + """unpin() on nonexistent key returns False.""" + manager = KVManager() + assert manager.unpin("999") is False + + def test_unpin_underflow_protection(self): + """unpin() on key with pin_count=0 returns False (no underflow).""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + + # Never pinned — pin_count is 0 + assert manager.unpin("1") is False + assert manager.lookup("1").pin_count == 0 + + def test_unpin_after_full_decrement(self): + """unpin() returns False after pin_count reaches 0.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.pin("1") + + assert manager.unpin("1") is True + assert manager.lookup("1").pin_count == 0 + # Second unpin should fail + assert manager.unpin("1") is False + + # ---- delete() with pin ---- + + def test_delete_pinned_key_refused(self): + """delete() on pinned key returns PINNED and does not remove entry.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.pin("1") + + result, region_id = manager.delete("1") + assert result == DeleteResult.PINNED + assert region_id is None + # Entry still exists + assert manager.exists("1") is True + assert manager.lookup("1").pin_count == 1 + + def test_delete_after_unpin(self): + """delete() succeeds after pin_count reaches 0 via unpin().""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.pin("1") + manager.unpin("1") + + result, region_id = manager.delete("1") + assert result == DeleteResult.DELETED + assert region_id == 1 + assert manager.exists("1") is False + + +class TestKVManagerBatchPin: + """Test cases for batch pin/unpin operations.""" + + # ---- batch_pin() ---- + + def test_batch_pin_all_exist(self): + """batch_pin() pins all keys when all exist.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.register(key="2", region_id=1, kv_offset=100, kv_length=100) + manager.register(key="3", region_id=1, kv_offset=200, kv_length=100) + + results = manager.batch_pin(["1", "2", "3"]) + assert results == [True, True, True] + assert manager.lookup("1").pin_count == 1 + assert manager.lookup("2").pin_count == 1 + assert manager.lookup("3").pin_count == 1 + + def test_batch_pin_prefix_stop(self): + """batch_pin() stops at first miss — only prefix keys are pinned.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + # key "2" missing + manager.register(key="3", region_id=1, kv_offset=200, kv_length=100) + + results = manager.batch_pin(["1", "2", "3"]) + assert results == [True, False, False] + # Only "1" should be pinned + assert manager.lookup("1").pin_count == 1 + # "3" exists but should NOT be pinned + assert manager.lookup("3").pin_count == 0 + + def test_batch_pin_first_key_missing(self): + """batch_pin() with first key missing returns all False.""" + manager = KVManager() + manager.register(key="2", region_id=1, kv_offset=0, kv_length=100) + + results = manager.batch_pin(["1", "2"]) + assert results == [False, False] + assert manager.lookup("2").pin_count == 0 + + def test_batch_pin_empty_list(self): + """batch_pin([]) returns empty list.""" + manager = KVManager() + assert manager.batch_pin([]) == [] + + # ---- batch_unpin() ---- + + def test_batch_unpin_all_pinned(self): + """batch_unpin() unpins all previously pinned keys.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.register(key="2", region_id=1, kv_offset=100, kv_length=100) + manager.pin("1") + manager.pin("2") + + results = manager.batch_unpin(["1", "2"]) + assert results == [True, True] + assert manager.lookup("1").pin_count == 0 + assert manager.lookup("2").pin_count == 0 + + def test_batch_unpin_mixed(self): + """batch_unpin() with mix of pinned, unpinned, and missing keys.""" + manager = KVManager() + manager.register(key="1", region_id=1, kv_offset=0, kv_length=100) + manager.register(key="2", region_id=1, kv_offset=100, kv_length=100) + manager.pin("1") + # "2" registered but not pinned, "3" doesn't exist + + results = manager.batch_unpin(["1", "2", "3"]) + assert results == [True, False, False] + + def test_batch_unpin_empty_list(self): + """batch_unpin([]) returns empty list.""" + manager = KVManager() + assert manager.batch_unpin([]) == [] diff --git a/tests/unit/test_maru_handler.py b/tests/unit/test_maru_handler.py index 7231a66..e73f562 100644 --- a/tests/unit/test_maru_handler.py +++ b/tests/unit/test_maru_handler.py @@ -11,66 +11,6 @@ from conftest import _make_handle from maru import MaruConfig, MaruHandler -from maru_handler.handler import _gil_free_memcpy -from maru_handler.memory import MemoryInfo - -# ============================================================================= -# _gil_free_memcpy tests -# ============================================================================= - - -class TestGilFreeMemcpy: - """Unit tests for the GIL-free memcpy helper.""" - - def test_copy_from_bytes(self): - """Copy bytes into a writable memoryview.""" - dst = bytearray(16) - src = b"hello" - _gil_free_memcpy(memoryview(dst), src, len(src)) - assert dst[:5] == b"hello" - assert dst[5:] == b"\x00" * 11 - - def test_copy_from_writable_memoryview(self): - """Copy from a writable memoryview (production path).""" - dst = bytearray(16) - src = bytearray(b"world") - _gil_free_memcpy(memoryview(dst), memoryview(src), len(src)) - assert dst[:5] == b"world" - - def test_copy_from_readonly_memoryview(self): - """Copy from a read-only memoryview (bytes-backed).""" - dst = bytearray(16) - src = memoryview(b"readonly") - assert src.readonly - _gil_free_memcpy(memoryview(dst), src, len(src)) - assert dst[:8] == b"readonly" - - def test_partial_copy(self): - """Only copy nbytes, not the full source.""" - dst = bytearray(16) - src = b"abcdefgh" - _gil_free_memcpy(memoryview(dst), src, 3) - assert dst[:3] == b"abc" - assert dst[3:] == b"\x00" * 13 - - def test_copy_into_offset_slice(self): - """Copy into a memoryview slice at an offset (like store() does).""" - dst = bytearray(16) - prefix = b"\x01\x02" - data = b"payload" - mv = memoryview(dst) - _gil_free_memcpy(mv[0:], prefix, len(prefix)) - _gil_free_memcpy(mv[2:], data, len(data)) - assert dst[:2] == b"\x01\x02" - assert dst[2:9] == b"payload" - - def test_large_copy(self): - """Copy a larger buffer (1MB) to verify no size issues.""" - size = 1024 * 1024 - dst = bytearray(size) - src = bytes(range(256)) * (size // 256) - _gil_free_memcpy(memoryview(dst), src, size) - assert bytes(dst) == src class TestMaruHandlerConfig: @@ -147,6 +87,38 @@ def test_config_env_override_eager_mmap_invalid(self, monkeypatch): with pytest.raises(ValueError, match="MARU_EAGER_MAP must be one of"): MaruConfig() + def test_config_auto_expand_defaults_true(self): + """auto_expand defaults to True.""" + config = MaruConfig() + assert config.auto_expand is True + + def test_config_auto_expand_false(self): + """auto_expand=False is valid.""" + config = MaruConfig(auto_expand=False) + assert config.auto_expand is False + + def test_config_expand_size_requires_auto_expand(self): + """expand_size without auto_expand=True raises ValueError.""" + with pytest.raises(ValueError, match="expand_size requires auto_expand=True"): + MaruConfig(auto_expand=False, expand_size=4096) + + def test_config_expand_size_with_auto_expand(self): + """expand_size with auto_expand=True is valid.""" + config = MaruConfig(auto_expand=True, expand_size=4096, chunk_size_bytes=1024) + assert config.expand_size == 4096 + + def test_config_expand_size_smaller_than_chunk_raises(self): + """expand_size < chunk_size_bytes raises ValueError.""" + with pytest.raises( + ValueError, match="expand_size.*must be >= .*chunk_size_bytes" + ): + MaruConfig(auto_expand=True, expand_size=512, chunk_size_bytes=1024) + + def test_config_expand_size_none_default(self): + """expand_size defaults to None.""" + config = MaruConfig(auto_expand=True) + assert config.expand_size is None + class TestMaruHandlerEnsureConnected: """Test that operations require connection.""" @@ -161,7 +133,7 @@ def test_store_before_connect_raises(self): # Try to store — should raise RuntimeError with pytest.raises(RuntimeError, match="Not connected"): - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + handler.store(key="1", handle=MagicMock()) # ============================================================================= @@ -169,7 +141,9 @@ def test_store_before_connect_raises(self): # ============================================================================= -def _make_mock_handler(pool_size=8192, chunk_size=1024): +def _make_mock_handler( + pool_size=8192, chunk_size=1024, auto_expand=True, expand_size=None +): """Create a MaruHandler with mocked RPC for unit testing. Follows the pattern from test_thread_safety.py. @@ -182,6 +156,8 @@ def _make_mock_handler(pool_size=8192, chunk_size=1024): chunk_size_bytes=chunk_size, auto_connect=False, use_async_rpc=False, + auto_expand=auto_expand, + expand_size=expand_size, ) handler = MaruHandler(config) @@ -390,7 +366,10 @@ def test_retrieve_shared_region_on_demand_mapping(self): handler = _make_mock_handler() # Store something first so handler is set up - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + h = handler.alloc(size=4) + buf = h.buf + buf[:4] = b"data" + handler.store(key="1", handle=h) # lookup_kv returns a handle pointing to a DIFFERENT region (shared) shared_handle = _make_handle(200, 4096) @@ -595,83 +574,22 @@ def test_batch_store_closing_check_inside_lock(self): with pytest.raises(RuntimeError, match="Handler is closing"): handler.batch_store( keys=["1"], - infos=[MemoryInfo(view=memoryview(b"data"))], + handles=[MagicMock()], ) # Reset for cleanup handler._closing.clear() handler.close() - def test_batch_store_prefixes_length_mismatch(self): - """L535: prefixes length != keys length raises ValueError.""" - handler = _make_mock_handler() - - with pytest.raises(ValueError, match="prefixes must have the same length"): - handler.batch_store( - keys=["1", "2"], - infos=[ - MemoryInfo(view=memoryview(b"d1")), - MemoryInfo(view=memoryview(b"d2")), - ], - prefixes=[b"\x01"], # only 1, but keys has 2 - ) - - handler.close() - - def test_batch_store_format_cast(self): - """L552: src.format != 'B' triggers cast.""" - import array - - handler = _make_mock_handler() - - # Create a memoryview with format 'i' (int) instead of 'B' - arr = array.array("i", [1, 2, 3]) - mv = memoryview(arr) - assert mv.format != "B" - - info = MemoryInfo(view=mv) - - # batch_exists_kv: key not on server - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - - batch_resp = MagicMock() - batch_resp.success = True - batch_resp.results = [True] - handler._rpc.batch_register_kv = MagicMock(return_value=batch_resp) - - # The data size after cast to 'B' is 12 bytes (3 ints * 4 bytes) - # chunk_size is 1024, so it should fit - results = handler.batch_store(keys=["1"], infos=[info]) - assert results == [True] - - handler.close() - - def test_batch_store_total_size_exceeds_chunk(self): - """L557-564: total_size exceeds chunk_size for a key.""" - handler = _make_mock_handler(chunk_size=64) - - # batch_exists_kv: key not on server so it proceeds to size check - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - - big_data = b"x" * 100 # exceeds 64-byte chunk - results = handler.batch_store( - keys=["1"], - infos=[MemoryInfo(view=memoryview(big_data))], - ) - assert results == [False] - - handler.close() - def test_batch_store_overwrite_existing_key(self): """batch_store skips keys already in local map — idempotent, returns True.""" handler = _make_mock_handler() # First store - handler.store(key="42", info=MemoryInfo(view=memoryview(b"old"))) + h = handler.alloc(size=3) + buf = h.buf + buf[:3] = b"old" + handler.store(key="42", handle=h) assert "42" in handler._key_to_location # batch_exists_kv mock (Phase 1 check): key not on server either @@ -680,10 +598,10 @@ def test_batch_store_overwrite_existing_key(self): handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) # batch_store with same key — should skip via local map check, return True - results = handler.batch_store( - keys=["42"], - infos=[MemoryInfo(view=memoryview(b"new"))], - ) + h2 = handler.alloc(size=3) + buf2 = h2.buf + buf2[:3] = b"new" + results = handler.batch_store(keys=["42"], handles=[h2]) assert results == [True] # delete_kv never called — no overwrite, just skip handler._rpc.delete_kv.assert_not_called() @@ -692,15 +610,13 @@ def test_batch_store_overwrite_existing_key(self): def test_batch_store_alloc_fails_expand_fails(self): """L577-584: allocation fails, expand fails.""" - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) - - # batch_exists_kv: key 2 not on server - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) + h = handler.alloc(size=4) + buf = h.buf + buf[:4] = b"fill" + handler.store(key="1", handle=h) # Make expand fail alloc_fail = MagicMock() @@ -708,41 +624,13 @@ def test_batch_store_alloc_fails_expand_fails(self): alloc_fail.handle = None handler._rpc.request_alloc = MagicMock(return_value=alloc_fail) - results = handler.batch_store( - keys=["2"], - infos=[MemoryInfo(view=memoryview(b"data"))], - ) - assert results == [False] + # alloc raises when pool is exhausted and expansion fails + with pytest.raises((ValueError, RuntimeError)): + h2 = handler.alloc(size=4) + handler.batch_store(keys=["2"], handles=[h2]) handler.close() - def test_batch_store_get_buffer_view_none(self): - """L594-596: get_buffer_view returns None in batch_store.""" - handler = _make_mock_handler() - - # batch_exists_kv: key not on server - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - - # Make get_buffer_view return None - original_get_buf = handler._mapper.get_buffer_view - - def return_none_buf(region_id, offset, size): - return None - - handler._mapper.get_buffer_view = return_none_buf - - results = handler.batch_store( - keys=["1"], - infos=[MemoryInfo(view=memoryview(b"data"))], - ) - assert results == [False] - - # Restore for cleanup - handler._mapper.get_buffer_view = original_get_buf - handler.close() - def test_batch_store_register_rpc_raises(self): """L611-615: batch_register_kv RPC raises — free all, return [False]*len.""" handler = _make_mock_handler() @@ -756,13 +644,11 @@ def test_batch_store_register_rpc_raises(self): side_effect=RuntimeError("RPC failed") ) - results = handler.batch_store( - keys=["1", "2"], - infos=[ - MemoryInfo(view=memoryview(b"d1")), - MemoryInfo(view=memoryview(b"d2")), - ], - ) + h1 = handler.alloc(size=2) + h1.buf[:2] = b"d1" + h2 = handler.alloc(size=2) + h2.buf[:2] = b"d2" + results = handler.batch_store(keys=["1", "2"], handles=[h1, h2]) assert results == [False, False] handler.close() @@ -780,10 +666,9 @@ def test_batch_store_register_returns_failure(self): batch_resp.success = False handler._rpc.batch_register_kv = MagicMock(return_value=batch_resp) - results = handler.batch_store( - keys=["1"], - infos=[MemoryInfo(view=memoryview(b"data"))], - ) + h = handler.alloc(size=4) + h.buf[:4] = b"data" + results = handler.batch_store(keys=["1"], handles=[h]) assert results == [False] handler.close() @@ -851,29 +736,33 @@ def test_instance_id_property(self): # ================================================================= def test_expand_region_rpc_raises(self): - """L718-720: request_alloc RPC raises exception during expand.""" - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) + """request_alloc RPC raises exception during expand.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) + h1 = handler.alloc(size=4) + h1.buf[:4] = b"fill" + handler.store(key="1", handle=h1) # Make request_alloc raise handler._rpc.request_alloc = MagicMock(side_effect=RuntimeError("RPC timeout")) - # Try to store another key, triggering expand - result = handler.store(key="2", info=MemoryInfo(view=memoryview(b"data"))) - assert result is False + # Try to alloc another page, triggering expand — should raise + with pytest.raises((ValueError, RuntimeError)): + handler.alloc(size=4) handler.close() def test_expand_region_add_region_raises(self, monkeypatch): - """L734-740: add_region raises during expand — catches, calls return_alloc.""" + """add_region raises during expand — catches, calls return_alloc.""" from maru_handler.memory import OwnedRegionManager - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) + h1 = handler.alloc(size=4) + h1.buf[:4] = b"fill" + handler.store(key="1", handle=h1) # request_alloc succeeds with a new region expand_response = MagicMock() @@ -887,8 +776,9 @@ def failing_add(self_mgr, handle): monkeypatch.setattr(OwnedRegionManager, "add_region", failing_add) - result = handler.store(key="2", info=MemoryInfo(view=memoryview(b"data"))) - assert result is False + # alloc triggers expansion which fails + with pytest.raises((ValueError, RuntimeError)): + handler.alloc(size=4) # return_alloc should have been called for the failed region handler._rpc.return_alloc.assert_called() @@ -939,12 +829,14 @@ def failing_add(self_mgr, handle): # ================================================================= def test_store_happy_path(self): - """L220-308: Full store happy path (covers write_lock, allocate, write, register).""" + """Full store happy path (covers write_lock, allocate, write, register).""" handler = _make_mock_handler() data = b"hello world" - info = MemoryInfo(view=memoryview(data)) - result = handler.store(key="42", info=info) + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + result = handler.store(key="42", handle=handle) assert result is True assert "42" in handler._key_to_location @@ -952,41 +844,8 @@ def test_store_happy_path(self): handler.close() - def test_store_with_memoryview(self): - """store() accepts a raw memoryview via the info parameter.""" - handler = _make_mock_handler() - - data = b"hello memoryview" - result = handler.store(key="100", info=memoryview(data)) - assert result is True - assert "100" in handler._key_to_location - - handler._rpc.register_kv.assert_called_once() - handler.close() - - def test_store_with_data_kwarg(self): - """store() accepts a memoryview via the data keyword argument.""" - handler = _make_mock_handler() - - data = b"hello data kwarg" - result = handler.store(key="200", data=memoryview(data)) - assert result is True - assert "200" in handler._key_to_location - - handler._rpc.register_kv.assert_called_once() - handler.close() - - def test_store_no_data_raises(self): - """store() raises TypeError when neither info nor data is provided.""" - handler = _make_mock_handler() - - with pytest.raises(TypeError, match="Must provide data"): - handler.store(key="300") - - handler.close() - def test_store_closing_raises_inside_lock(self): - """L222: store() raises RuntimeError from inside write_lock when closing.""" + """store() raises RuntimeError from inside write_lock when closing.""" handler = _make_mock_handler() # Bypass _ensure_connected so we reach the check inside the lock @@ -994,33 +853,17 @@ def test_store_closing_raises_inside_lock(self): handler._closing.set() with pytest.raises(RuntimeError, match="Handler is closing"): - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + handler.store(key="1", handle=MagicMock()) handler._closing.clear() handler.close() - def test_store_format_cast(self): - """L227: src.format != 'B' triggers cast in store.""" - import array - - handler = _make_mock_handler() - - arr = array.array("i", [1, 2, 3]) - mv = memoryview(arr) - assert mv.format != "B" - - result = handler.store(key="1", info=MemoryInfo(view=mv)) - assert result is True - - handler.close() - def test_store_exceeds_chunk_size(self): - """L244-249: total_size exceeds chunk_size in store.""" + """Requesting size > chunk_size raises ValueError from alloc.""" handler = _make_mock_handler(chunk_size=64) - big_data = b"x" * 100 - result = handler.store(key="1", info=MemoryInfo(view=memoryview(big_data))) - assert result is False + with pytest.raises(ValueError, match="exceeds chunk_size"): + handler.alloc(size=100) handler.close() @@ -1028,12 +871,16 @@ def test_store_overwrite_existing_key(self): """store() now skips duplicates — second store is a no-op via local map check.""" handler = _make_mock_handler() - result1 = handler.store(key="1", info=MemoryInfo(view=memoryview(b"v1"))) + h1 = handler.alloc(size=2) + h1.buf[:2] = b"v1" + result1 = handler.store(key="1", handle=h1) assert result1 is True assert "1" in handler._key_to_location # Second store same key: skipped via local _key_to_location check, returns True - result2 = handler.store(key="1", info=MemoryInfo(view=memoryview(b"v2"))) + h2 = handler.alloc(size=2) + h2.buf[:2] = b"v2" + result2 = handler.store(key="1", handle=h2) assert result2 is True # register_kv called only once (second store skipped before allocation) handler._rpc.register_kv.assert_called_once() @@ -1043,67 +890,6 @@ def test_store_overwrite_existing_key(self): handler.close() - def test_store_expand_succeeds_but_second_alloc_none(self): - """L265-267: expand succeeds but second allocate still returns None.""" - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) - - # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) - - # expand succeeds but the new region also has no free pages - expand_response = MagicMock() - expand_response.success = True - expand_response.handle = _make_handle(200, 1024) - handler._rpc.request_alloc = MagicMock(return_value=expand_response) - - # Patch allocate to return None even after expand - original_allocate = handler._owned.allocate - - call_count = [0] - - def always_none_after_first(): - call_count[0] += 1 - # Both calls return None (first triggers expand, second still None) - return None - - handler._owned.allocate = always_none_after_first - - result = handler.store(key="2", info=MemoryInfo(view=memoryview(b"data"))) - assert result is False - - handler._owned.allocate = original_allocate - handler.close() - - def test_store_get_buffer_view_none(self): - """L278-279: get_buffer_view returns None in store.""" - handler = _make_mock_handler() - - original_get_buf = handler._mapper.get_buffer_view - - def return_none(region_id, offset, size): - return None - - handler._mapper.get_buffer_view = return_none - - result = handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) - assert result is False - - handler._mapper.get_buffer_view = original_get_buf - handler.close() - - def test_store_with_prefix(self): - """L284-285: prefix writing path in store.""" - handler = _make_mock_handler() - - prefix = b"\x01\x02" - data = b"hello" - result = handler.store( - key="1", info=MemoryInfo(view=memoryview(data)), prefix=prefix - ) - assert result is True - - handler.close() - # ================================================================= # delete() happy path # ================================================================= @@ -1137,7 +923,9 @@ def test_delete_with_local_tracking(self): """L396-397: delete key that IS in _key_to_location — frees page.""" handler = _make_mock_handler() - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + h = handler.alloc(size=4) + h.buf[:4] = b"data" + handler.store(key="1", handle=h) assert "1" in handler._key_to_location result = handler.delete(key="1") @@ -1176,7 +964,7 @@ def test_get_stats(self): # ================================================================= def test_batch_store_happy_path(self): - """L541-644: Full batch_store happy path including register.""" + """Full batch_store happy path including register.""" handler = _make_mock_handler() # batch_exists_kv: neither key on server @@ -1189,82 +977,17 @@ def test_batch_store_happy_path(self): batch_resp.results = [True, True] handler._rpc.batch_register_kv = MagicMock(return_value=batch_resp) - results = handler.batch_store( - keys=["1", "2"], - infos=[ - MemoryInfo(view=memoryview(b"d1")), - MemoryInfo(view=memoryview(b"d2")), - ], - ) + h1 = handler.alloc(size=2) + h1.buf[:2] = b"d1" + h2 = handler.alloc(size=2) + h2.buf[:2] = b"d2" + results = handler.batch_store(keys=["1", "2"], handles=[h1, h2]) assert results == [True, True] assert "1" in handler._key_to_location assert "2" in handler._key_to_location handler.close() - def test_batch_store_with_prefixes(self): - """L600-601: prefix writing path in batch_store.""" - handler = _make_mock_handler() - - # batch_exists_kv: key not on server - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - - batch_resp = MagicMock() - batch_resp.success = True - batch_resp.results = [True] - handler._rpc.batch_register_kv = MagicMock(return_value=batch_resp) - - results = handler.batch_store( - keys=["1"], - infos=[MemoryInfo(view=memoryview(b"data"))], - prefixes=[b"\x01\x02"], - ) - assert results == [True] - - handler.close() - - def test_batch_store_expand_second_alloc_none(self): - """L581-584: expand succeeds but second allocate returns None in batch_store.""" - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) - - # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) - - # batch_exists_kv: key 2 not on server - batch_exists_resp = MagicMock() - batch_exists_resp.results = [False] - handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - - # expand succeeds - expand_response = MagicMock() - expand_response.success = True - expand_response.handle = _make_handle(200, 1024) - handler._rpc.request_alloc = MagicMock(return_value=expand_response) - - # Patch allocate to return None even after expand - original_allocate = handler._owned.allocate - - def always_none(): - return None - - handler._owned.allocate = always_none - - batch_resp = MagicMock() - batch_resp.success = True - batch_resp.results = [] - handler._rpc.batch_register_kv = MagicMock(return_value=batch_resp) - - results = handler.batch_store( - keys=["2"], - infos=[MemoryInfo(view=memoryview(b"data"))], - ) - assert results == [False] - - handler._owned.allocate = original_allocate - handler.close() - # ================================================================= # batch_exists() happy path # ================================================================= @@ -1320,11 +1043,13 @@ def test_owned_region_manager_property(self): # ================================================================= def test_expand_region_happy_path(self): - """L732-733: expand succeeds — add_region works, returns True.""" - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) + """expand succeeds — add_region works, new alloc succeeds.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) + h1 = handler.alloc(size=4) + h1.buf[:4] = b"fill" + handler.store(key="1", handle=h1) # request_alloc returns a new valid region expand_response = MagicMock() @@ -1333,20 +1058,24 @@ def test_expand_region_happy_path(self): handler._rpc.request_alloc = MagicMock(return_value=expand_response) # Store another key — triggers expansion - result = handler.store(key="2", info=MemoryInfo(view=memoryview(b"data2"))) + h2 = handler.alloc(size=5) + h2.buf[:5] = b"data2" + result = handler.store(key="2", handle=h2) assert result is True assert handler._owned.get_stats()["num_regions"] == 2 handler.close() def test_expand_region_add_region_raises_and_return_alloc_raises(self, monkeypatch): - """L738-739: return_alloc also raises during expand cleanup.""" + """return_alloc also raises during expand cleanup.""" from maru_handler.memory import OwnedRegionManager - handler = _make_mock_handler(pool_size=1024, chunk_size=1024) + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) # Fill the single page - handler.store(key="1", info=MemoryInfo(view=memoryview(b"fill"))) + h1 = handler.alloc(size=4) + h1.buf[:4] = b"fill" + handler.store(key="1", handle=h1) expand_response = MagicMock() expand_response.success = True @@ -1363,8 +1092,9 @@ def failing_add(self_mgr, handle): monkeypatch.setattr(OwnedRegionManager, "add_region", failing_add) - result = handler.store(key="2", info=MemoryInfo(view=memoryview(b"data"))) - assert result is False + # alloc triggers the expansion which fails + with pytest.raises((ValueError, RuntimeError)): + handler.alloc(size=4) handler.close() @@ -1462,17 +1192,13 @@ def test_exists_happy_path(self): # batch_store() keys/infos mismatch # ================================================================= - def test_batch_store_keys_infos_mismatch(self): - """L533: batch_store with mismatched keys/infos raises ValueError.""" + def test_batch_store_keys_handles_mismatch(self): + """batch_store with mismatched keys/handles raises ValueError.""" handler = _make_mock_handler() - with pytest.raises( - ValueError, match="keys and infos must have the same length" - ): - handler.batch_store( - keys=["1", "2"], - infos=[MemoryInfo(view=memoryview(b"only_one"))], - ) + h = handler.alloc(size=4) + with pytest.raises(ValueError): + handler.batch_store(keys=["1", "2"], handles=[h]) handler.close() @@ -1496,21 +1222,23 @@ class TestMaruHandlerDuplicateSkip: """Test store/batch_store duplicate key skip paths.""" def test_store_skipped_by_server_exists(self): - """L258-259: store() skips when exists_kv returns True (server-side dup).""" + """store() skips when exists_kv returns True (server-side dup).""" handler = _make_mock_handler() # Key NOT in local map, but server says it exists handler._rpc.exists_kv = MagicMock(return_value=True) - result = handler.store(key="42", info=MemoryInfo(view=memoryview(b"data"))) + h = handler.alloc(size=4) + h.buf[:4] = b"data" + result = handler.store(key="42", handle=h) assert result is True - # Should not have allocated or registered + # Should not have registered handler._rpc.register_kv.assert_not_called() assert "42" not in handler._key_to_location handler.close() def test_store_register_race_frees_page(self): - """L303-310: register_kv returns is_new=False (race), page is freed.""" + """register_kv returns is_new=False (race), page is freed.""" handler = _make_mock_handler() # exists_kv returns False so store proceeds past dup check @@ -1518,7 +1246,9 @@ def test_store_register_race_frees_page(self): # register_kv returns False (another instance registered between check and register) handler._rpc.register_kv = MagicMock(return_value=False) - result = handler.store(key="77", info=MemoryInfo(view=memoryview(b"data"))) + h = handler.alloc(size=4) + h.buf[:4] = b"data" + result = handler.store(key="77", handle=h) assert result is True # Key should NOT be in local map (race lost) assert "77" not in handler._key_to_location @@ -1526,7 +1256,7 @@ def test_store_register_race_frees_page(self): handler.close() def test_batch_store_server_duplicate_skip(self): - """L606-609: batch_store skips keys that exist on server.""" + """batch_store skips keys that exist on server.""" handler = _make_mock_handler() batch_exists_resp = MagicMock() @@ -1534,13 +1264,11 @@ def test_batch_store_server_duplicate_skip(self): batch_exists_resp.results = [True, False] handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - results = handler.batch_store( - keys=["1", "2"], - infos=[ - MemoryInfo(view=memoryview(b"data1")), - MemoryInfo(view=memoryview(b"data2")), - ], - ) + h1 = handler.alloc(size=5) + h1.buf[:5] = b"data1" + h2 = handler.alloc(size=5) + h2.buf[:5] = b"data2" + results = handler.batch_store(keys=["1", "2"], handles=[h1, h2]) # Both should succeed (key 1 skipped as dup, key 2 stored) assert results == [True, True] # Only key 2 should be in local map @@ -1550,15 +1278,14 @@ def test_batch_store_server_duplicate_skip(self): handler.close() def test_batch_store_batch_exists_rpc_failure(self): - """L583-585: batch_exists_kv RPC fails, falls back to [False]*len.""" + """batch_exists_kv RPC fails, falls back to [False]*len.""" handler = _make_mock_handler() handler._rpc.batch_exists_kv = MagicMock(side_effect=RuntimeError("RPC failed")) - results = handler.batch_store( - keys=["1"], - infos=[MemoryInfo(view=memoryview(b"data"))], - ) + h = handler.alloc(size=4) + h.buf[:4] = b"data" + results = handler.batch_store(keys=["1"], handles=[h]) # Should still succeed — fallback treats all keys as new assert results == [True] assert "1" in handler._key_to_location @@ -1566,7 +1293,7 @@ def test_batch_store_batch_exists_rpc_failure(self): handler.close() def test_batch_store_some_exist_log(self): - """L589: batch_store logs when some keys are skipped.""" + """batch_store logs when some keys are skipped.""" handler = _make_mock_handler() batch_exists_resp = MagicMock() @@ -1574,14 +1301,13 @@ def test_batch_store_some_exist_log(self): batch_exists_resp.results = [True, True, True] handler._rpc.batch_exists_kv = MagicMock(return_value=batch_exists_resp) - results = handler.batch_store( - keys=["1", "2", "3"], - infos=[ - MemoryInfo(view=memoryview(b"a")), - MemoryInfo(view=memoryview(b"b")), - MemoryInfo(view=memoryview(b"c")), - ], - ) + h1 = handler.alloc(size=1) + h1.buf[:1] = b"a" + h2 = handler.alloc(size=1) + h2.buf[:1] = b"b" + h3 = handler.alloc(size=1) + h3.buf[:1] = b"c" + results = handler.batch_store(keys=["1", "2", "3"], handles=[h1, h2, h3]) # All skipped but reported as True (idempotent) assert results == [True, True, True] # None should be in local map @@ -1630,14 +1356,18 @@ class TestMaruHandlerExpandFailure: """Test store behavior when expansion fails (mocked RPC).""" def test_store_fails_when_expansion_fails(self): - """Pre-fill all pages, mock request_alloc to fail, verify store returns False.""" + """Pre-fill all pages, mock request_alloc to fail, verify alloc raises.""" from maru_common.protocol import RequestAllocResponse handler = _make_mock_handler(pool_size=2048, chunk_size=1024) # Fill all 2 pages in the initial region - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data1"))) - handler.store(key="2", info=MemoryInfo(view=memoryview(b"data2"))) + h1 = handler.alloc(size=5) + h1.buf[:5] = b"data1" + handler.store(key="1", handle=h1) + h2 = handler.alloc(size=5) + h2.buf[:5] = b"data2" + handler.store(key="2", handle=h2) assert handler.allocator.num_free_pages == 0 @@ -1646,9 +1376,9 @@ def test_store_fails_when_expansion_fails(self): return_value=RequestAllocResponse(success=False, handle=None) ) - # Try to store a new key — should trigger expansion and fail - result = handler.store(key="3", info=MemoryInfo(view=memoryview(b"data3"))) - assert result is False + # Try to alloc a new page — should trigger expansion and fail + with pytest.raises((ValueError, RuntimeError)): + handler.alloc(size=5) handler.close() @@ -1824,16 +1554,17 @@ class TestAlloc: """Tests for MaruHandler.alloc() zero-copy allocation.""" def test_alloc_returns_alloc_handle(self): - """alloc() returns AllocHandle with writable memoryview.""" + """alloc() returns AllocHandle with correct attributes.""" from maru_handler.memory import AllocHandle handler = _make_mock_handler() handle = handler.alloc(size=512) assert isinstance(handle, AllocHandle) - assert isinstance(handle.buf, memoryview) - assert not handle.buf.readonly - assert len(handle.buf) >= 512 + buf = handle.buf + assert isinstance(buf, memoryview) + assert not buf.readonly + assert len(buf) >= 512 handler.close() def test_alloc_buf_is_writable(self): @@ -1841,8 +1572,9 @@ def test_alloc_buf_is_writable(self): handler = _make_mock_handler() handle = handler.alloc(size=64) - handle.buf[:5] = b"hello" - assert bytes(handle.buf[:5]) == b"hello" + buf = handle.buf + buf[:5] = b"hello" + assert bytes(buf[:5]) == b"hello" handler.close() def test_alloc_exceeds_chunk_size_raises(self): @@ -1893,10 +1625,12 @@ def test_multiple_alloc_independent(self): h2 = handler.alloc(size=100) assert h1._page_index != h2._page_index or h1._region_id != h2._region_id - h1.buf[:3] = b"aaa" - h2.buf[:3] = b"bbb" - assert bytes(h1.buf[:3]) == b"aaa" - assert bytes(h2.buf[:3]) == b"bbb" + buf1 = h1.buf + buf2 = h2.buf + buf1[:3] = b"aaa" + buf2[:3] = b"bbb" + assert bytes(buf1[:3]) == b"aaa" + assert bytes(buf2[:3]) == b"bbb" handler.close() @@ -1914,7 +1648,8 @@ def test_free_after_store(self): """free() after store removes key from _key_to_location.""" handler = _make_mock_handler() handle = handler.alloc(size=64) - handle.buf[:5] = b"hello" + buf = handle.buf + buf[:5] = b"hello" handler.store(key="42", handle=handle) assert "42" in handler._key_to_location @@ -1950,7 +1685,8 @@ def test_store_with_handle_happy_path(self): """alloc -> write -> store(handle=) full flow.""" handler = _make_mock_handler() handle = handler.alloc(size=64) - handle.buf[:5] = b"hello" + buf = handle.buf + buf[:5] = b"hello" result = handler.store(key="42", handle=handle) assert result is True @@ -1959,24 +1695,13 @@ def test_store_with_handle_happy_path(self): handler._rpc.register_kv.assert_called_once() handler.close() - def test_store_with_handle_no_memcpy(self): - """handle path does not call _gil_free_memcpy.""" - from unittest.mock import patch as mock_patch - - handler = _make_mock_handler() - handle = handler.alloc(size=64) - handle.buf[:5] = b"hello" - - with mock_patch("maru_handler.handler._gil_free_memcpy") as mock_memcpy: - handler.store(key="42", handle=handle) - mock_memcpy.assert_not_called() - handler.close() - def test_store_with_handle_duplicate_key_skips(self): """store(handle=) skips when key already exists.""" handler = _make_mock_handler() - handler.store(key="42", data=memoryview(b"first")) + h0 = handler.alloc(size=5) + h0.buf[:5] = b"first" + handler.store(key="42", handle=h0) handle = handler.alloc(size=64) result = handler.store(key="42", handle=handle) @@ -1990,70 +1715,46 @@ def test_store_with_handle_register_race(self): handler._rpc.register_kv = MagicMock(return_value=False) handle = handler.alloc(size=64) - handle.buf[:5] = b"hello" + buf = handle.buf + buf[:5] = b"hello" result = handler.store(key="42", handle=handle) assert result is True assert "42" not in handler._key_to_location handler.close() - def test_store_with_handle_and_data_raises(self): - """Providing both handle and data raises ValueError.""" - handler = _make_mock_handler() - handle = handler.alloc(size=64) - - with pytest.raises(ValueError, match="Cannot specify both"): - handler.store(key="42", handle=handle, data=memoryview(b"conflict")) - handler.close() - - def test_store_with_handle_and_info_raises(self): - """Providing both handle and info raises ValueError.""" - handler = _make_mock_handler() - handle = handler.alloc(size=64) - - with pytest.raises(ValueError, match="Cannot specify both"): - handler.store( - key="42", handle=handle, info=MemoryInfo(view=memoryview(b"x")) - ) - handler.close() - class TestStoreWithHandleCompat: - """Ensure store() without handle remains unaffected.""" + """Ensure store() via handle API works in various patterns.""" - def test_store_without_handle(self): - """store() without handle uses allocate+memcpy path.""" + def test_store_via_alloc_and_handle(self): + """alloc -> handle.buf -> store(handle=) succeeds.""" handler = _make_mock_handler() - result = handler.store(key="42", data=memoryview(b"hello")) + h = handler.alloc(size=5) + h.buf[:5] = b"hello" + result = handler.store(key="42", handle=h) assert result is True assert "42" in handler._key_to_location handler._rpc.register_kv.assert_called_once() handler.close() - def test_store_with_prefix(self): - """store() with prefix still works without handle.""" - handler = _make_mock_handler() - result = handler.store( - key="42", - info=MemoryInfo(view=memoryview(b"data")), - prefix=b"\x01\x02", - ) - assert result is True - handler.close() - - def test_mixed_store_modes(self): - """Interleaving store with and without handle works correctly.""" + def test_multiple_store_via_handles(self): + """Multiple alloc+store calls all succeed.""" handler = _make_mock_handler(pool_size=8192, chunk_size=1024) - handler.store(key="1", data=memoryview(b"data1")) + h1 = handler.alloc(size=5) + h1.buf[:5] = b"data1" + handler.store(key="1", handle=h1) assert "1" in handler._key_to_location - h = handler.alloc(size=64) - h.buf[:6] = b"handle" - handler.store(key="2", handle=h) + h2 = handler.alloc(size=6) + h2.buf[:6] = b"handle" + handler.store(key="2", handle=h2) assert "2" in handler._key_to_location - handler.store(key="3", data=memoryview(b"data2")) + h3 = handler.alloc(size=5) + h3.buf[:5] = b"data2" + handler.store(key="3", handle=h3) assert "3" in handler._key_to_location assert handler._rpc.register_kv.call_count == 3 @@ -2074,7 +1775,8 @@ def test_retrieve_after_store_with_handle(self): """alloc -> store(handle=) -> retrieve returns data from same mmap region.""" handler = _make_mock_handler() handle = handler.alloc(size=64) - handle.buf[:4] = b"hell" + buf = handle.buf + buf[:4] = b"hell" handler.store(key="42", handle=handle) # Mock lookup_kv returns kv_length=4, so retrieve returns 4 bytes @@ -2136,3 +1838,179 @@ def alloc_and_store(key): for i in range(4): assert i in handler._key_to_location handler.close() + + +# ============================================================================= +# Fixed Pool / Auto-Expand Tests +# ============================================================================= + + +class TestFixedPoolAllocation: + """Tests for fixed pool allocation with optional auto-expand.""" + + def test_alloc_raises_when_pool_exhausted_no_expand(self): + """auto_expand=False: alloc raises ValueError when pool is full.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=False) + assert handler._auto_expand is False + + # Fill the single page + h = handler.alloc(size=4) + h.buf[:4] = b"fill" + handler.store(key="1", handle=h) + + # Next alloc should raise — no expansion attempted + with pytest.raises(ValueError, match="auto_expand is disabled"): + handler.alloc(size=4) + + handler.close() + + def test_expand_uses_expand_size(self): + """auto_expand=True with custom expand_size uses that size for RPC.""" + handler = _make_mock_handler( + pool_size=1024, chunk_size=1024, auto_expand=True, expand_size=2048 + ) + + # Fill the single page + h = handler.alloc(size=4) + h.buf[:4] = b"fill" + handler.store(key="1", handle=h) + + # Setup expand response + expand_response = MagicMock() + expand_response.success = True + expand_response.handle = _make_handle(200, 2048) + handler._rpc.request_alloc = MagicMock(return_value=expand_response) + + # Trigger expansion + h2 = handler.alloc(size=4) + assert h2 is not None + + # Verify request_alloc was called with expand_size=2048, not pool_size=1024 + call_args = handler._rpc.request_alloc.call_args + assert call_args.kwargs["size"] == 2048 + + handler.close() + + def test_expand_size_defaults_to_pool_size(self): + """auto_expand=True without expand_size uses pool_size for expansion.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) + assert handler._expand_size == 1024 # defaults to pool_size + + handler.close() + + def test_connect_multi_pool_aggregation(self): + """connect() aggregates regions from multiple pools.""" + from maru_common import MaruConfig + from maru_handler.handler import MaruHandler + + config = MaruConfig( + pool_size=2048, + chunk_size_bytes=1024, + auto_connect=False, + use_async_rpc=False, + pool_id=[0, 1], + ) + handler = MaruHandler(config) + + mock_rpc = MagicMock() + mock_rpc.connect = MagicMock() + mock_rpc.return_alloc = MagicMock() + mock_rpc.close = MagicMock() + + # Pool 0 gives 1024 bytes, pool 1 gives 1024 bytes → total 2048 + resp0 = MagicMock() + resp0.success = True + resp0.handle = _make_handle(100, 1024) + + resp1 = MagicMock() + resp1.success = True + resp1.handle = _make_handle(200, 1024) + + mock_rpc.request_alloc = MagicMock(side_effect=[resp0, resp1]) + + # list_allocations for premap + list_resp = MagicMock() + list_resp.success = True + list_resp.allocations = [] + mock_rpc.list_allocations = MagicMock(return_value=list_resp) + + handler._rpc = mock_rpc + result = handler.connect() + + assert result is True + assert handler.connected is True + # Should have 2 regions + assert len(handler.get_owned_region_ids()) == 2 + + handler.close() + + def test_connect_multi_pool_partial_cleanup(self): + """connect() cleans up partial allocations if remaining > 0.""" + from maru_common import MaruConfig + from maru_handler.handler import MaruHandler + + config = MaruConfig( + pool_size=3072, + chunk_size_bytes=1024, + auto_connect=False, + use_async_rpc=False, + pool_id=[0, 1], + ) + handler = MaruHandler(config) + + mock_rpc = MagicMock() + mock_rpc.connect = MagicMock() + mock_rpc.return_alloc = MagicMock() + mock_rpc.close = MagicMock() + + # Pool 0 gives 1024 bytes, pool 1 fails → only 1024 of 3072 + resp0 = MagicMock() + resp0.success = True + resp0.handle = _make_handle(100, 1024) + + resp1 = MagicMock() + resp1.success = False + resp1.handle = None + resp1.error = "pool full" + + mock_rpc.request_alloc = MagicMock(side_effect=[resp0, resp1]) + + handler._rpc = mock_rpc + result = handler.connect() + + assert result is False + assert handler.connected is False + # return_alloc should have been called to clean up pool 0's region + mock_rpc.return_alloc.assert_called() + + def test_alloc_expand_disabled_error_message(self): + """Error message distinguishes disabled vs failed expansion.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=False) + + h = handler.alloc(size=4) + h.buf[:4] = b"fill" + handler.store(key="1", handle=h) + + with pytest.raises(ValueError, match="auto_expand is disabled"): + handler.alloc(size=4) + + handler.close() + + def test_alloc_expand_failed_error_message(self): + """Error message when expansion is enabled but fails.""" + handler = _make_mock_handler(pool_size=1024, chunk_size=1024, auto_expand=True) + + h = handler.alloc(size=4) + h.buf[:4] = b"fill" + handler.store(key="1", handle=h) + + # Make expand fail + fail_resp = MagicMock() + fail_resp.success = False + fail_resp.handle = None + handler._rpc.request_alloc = MagicMock(return_value=fail_resp) + + with pytest.raises(ValueError, match="after expansion attempt"): + handler.alloc(size=4) + + handler.close() diff --git a/tests/unit/test_memory_types.py b/tests/unit/test_memory_types.py index d645218..3167e5c 100644 --- a/tests/unit/test_memory_types.py +++ b/tests/unit/test_memory_types.py @@ -281,8 +281,7 @@ def test_alloc_handle_properties(self): """AllocHandle.region_id, page_index, size return correct values.""" from maru_handler.memory.types import AllocHandle - data = bytearray(256) - buf = memoryview(data) + buf = memoryview(bytearray(256)) handle = AllocHandle(buf=buf, _region_id=42, _page_index=7, _size=256) assert handle.region_id == 42 diff --git a/tests/unit/test_thread_safety.py b/tests/unit/test_thread_safety.py index 6f6b879..95a0dc1 100644 --- a/tests/unit/test_thread_safety.py +++ b/tests/unit/test_thread_safety.py @@ -13,7 +13,7 @@ import pytest from conftest import _make_handle -from maru_handler.memory import DaxMapper, MemoryInfo, OwnedRegionManager +from maru_handler.memory import DaxMapper, OwnedRegionManager # ============================================================================= # Helpers @@ -228,6 +228,14 @@ def _make_mock_handler(): return handler +def _store_data(handler, key: str, data: bytes) -> bool: + """Helper: alloc → handle.buf → write → store(handle).""" + handle = handler.alloc(size=len(data)) + buf = handle.buf + buf[: len(data)] = data + return handler.store(key=key, handle=handle) + + class TestConcurrentStore: """Concurrent store operations — data integrity.""" @@ -237,7 +245,7 @@ def test_concurrent_store_unique_keys(self): num_threads = 8 # 8 pages available def store_one(idx): - return handler.store(key=idx, info=MemoryInfo(view=memoryview(b"data"))) + return _store_data(handler, key=str(idx), data=b"data") results = _run_threads(store_one, [() for _ in range(num_threads)]) @@ -258,9 +266,7 @@ def test_concurrent_store_same_key(self): num_threads = 4 def store_one(idx): - return handler.store( - key="42", info=MemoryInfo(view=memoryview(f"v{idx}".encode())) - ) + return _store_data(handler, key="42", data=f"v{idx}".encode()) results = _run_threads(store_one, [() for _ in range(num_threads)]) @@ -281,7 +287,7 @@ def test_concurrent_retrieve(self): handler = _make_mock_handler() # Store some data first - handler.store(key="1", info=MemoryInfo(view=memoryview(b"test"))) + _store_data(handler, key="1", data=b"test") num_threads = 8 @@ -324,7 +330,7 @@ def test_retrieve_not_blocked_by_store(self): handler = _make_mock_handler() # Store initial data - handler.store(key="1", info=MemoryInfo(view=memoryview(b"init"))) + _store_data(handler, key="1", data=b"init") store_started = threading.Event() store_proceed = threading.Event() @@ -342,7 +348,7 @@ def slow_register(*args, **kwargs): retrieve_result = [None] def do_store(): - handler.store(key="2", info=MemoryInfo(view=memoryview(b"slow"))) + _store_data(handler, key="2", data=b"slow") def do_retrieve(): store_started.wait(timeout=5) @@ -398,7 +404,7 @@ def do_store(): # _closing should be True now, but close hasn't finished time.sleep(0.01) # small delay to ensure _closing is set try: - handler.store(key="99", info=MemoryInfo(view=memoryview(b"rejected"))) + _store_data(handler, key="99", data=b"rejected") except RuntimeError as e: store_error[0] = e @@ -439,7 +445,7 @@ def slow_register(*args, **kwargs): close_done = threading.Event() def do_store(): - handler.store(key="1", info=MemoryInfo(view=memoryview(b"data"))) + _store_data(handler, key="1", data=b"data") def do_close(): store_started.wait(timeout=5)