From da02f027c6ac65d317aeab34abc1913b9211e194 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 00:57:25 +0000 Subject: [PATCH 01/31] Initial plan From effa7e65b564e80737d008291bf1296bf7ec4a82 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:12:25 +0000 Subject: [PATCH 02/31] Make MP non-GPU store path fully async with preemption-aware flush --- .../integration/vllm/lmcache_mp_connector.py | 42 +++- .../vllm/vllm_multi_process_adapter.py | 12 + .../transfer_context/worker_transfer.py | 231 +++++++++++++++-- .../test_async_data_transfer_context.py | 235 ++++++++++++++++++ .../test_lmcache_mp_connector_preemption.py | 106 ++++++++ tests/v1/test_vllm_mp_adapter.py | 12 + 6 files changed, 611 insertions(+), 27 deletions(-) create mode 100644 tests/v1/multiprocess/test_async_data_transfer_context.py create mode 100644 tests/v1/test_lmcache_mp_connector_preemption.py diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 27fa2de489..2137d42239 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -496,6 +496,11 @@ class LMCacheMPConnectorMetadata(KVConnectorMetadata): def __init__(self): super().__init__() self.requests: list[LMCacheMPRequestMetadata] = [] + # True only on scheduler steps where preemption/eviction may overwrite + # blocks referenced by in-flight async stores. Worker-side + # handle_preemptions() uses this hint to flush deferred gather + # (device->CPU copy) work before vLLM can overwrite source KV blocks. + self.need_flush: bool = False def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata): self.requests.append(request_metadata) @@ -513,7 +518,11 @@ def __str__(self): f"num_blocks={len(req_meta.op.block_ids[0])}, " f"block_ids={req_meta.op.block_ids})" ) - return "[" + "\n".join(request_strs) + "]" + return ( + f"need_flush={self.need_flush}; [" + + "\n".join(request_strs) + + "]" + ) def __repr__(self): return self.__str__() @@ -741,6 +750,17 @@ def wait_for_save(self): request_ids, ops, event, cache_salts=cache_salts ) + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None: + """Flush async non-GPU gathers only when scheduler metadata requests it.""" + worker_adapter = getattr(self, "worker_adapter", None) + if self.role != KVConnectorRole.WORKER or worker_adapter is None: + return + need_flush = ( + isinstance(kv_connector_metadata, LMCacheMPConnectorMetadata) + and kv_connector_metadata.need_flush + ) + worker_adapter.handle_preemptions(need_flush) + def get_finished( self, finished_req_ids: set[str] ) -> tuple[set[str] | None, set[str] | None]: @@ -963,6 +983,7 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ metadata = LMCacheMPConnectorMetadata() + metadata.need_flush = self._scheduler_step_needs_flush(scheduler_output) self._process_retrieve_requests(metadata) self._process_new_requests(scheduler_output, metadata) @@ -976,6 +997,25 @@ def build_connector_meta( return metadata + def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool: + """Return whether this scheduler step can overwrite preempted blocks.""" + # vLLM 0.10+ exposes resumed preemptions through cached-request fields. + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + resumed_req_ids = getattr(cached_reqs, "resumed_req_ids", None) + if resumed_req_ids: + return True + resumed_flags = getattr(cached_reqs, "resumed_from_preemption", None) + if resumed_flags and any(resumed_flags): + return True + # Conservative fallback for alternate scheduler output schemas. + preempted_req_ids = getattr(scheduler_output, "preempted_req_ids", None) + if preempted_req_ids: + return True + evicted_req_ids = getattr(scheduler_output, "evicted_req_ids", None) + if evicted_req_ids: + return True + return False + def update_connector_output(self, connector_output: KVConnectorOutput): """ Update KVConnector state from worker-side connectors output. diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 86578b22db..3f855c49e5 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -1430,6 +1430,18 @@ def get_block_ids_with_load_errors(self) -> set[int]: self.error_block_ids.clear() return errors + def handle_preemptions(self, need_flush: bool) -> None: + """Handle worker-side preemption hints from connector metadata. + + When ``need_flush`` is true, synchronize deferred non-GPU gather work + before the next forward pass can overwrite paged KV blocks. + """ + if not need_flush: + return + if not self.is_healthy or self.transfer_ctx is None: + return + self.transfer_ctx.flush_inflight_gathers() + def shutdown(self): """ Shutdown the LMCache MP worker adapter diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 41c72ca7d2..245004d14c 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -4,9 +4,12 @@ # Standard from abc import ABC, abstractmethod from collections.abc import Sequence +from concurrent.futures import Future as ConcurrentFuture +from concurrent.futures import ThreadPoolExecutor from enum import Enum from typing import Any, Callable, Protocol import os +import threading # Third Party import torch @@ -39,6 +42,7 @@ # string values of :class:`MPTransferMode` (``auto`` / ``handle`` / # ``data``); ``auto`` reproduces the historical device-type-based dispatch. ENV_MP_TRANSFER_MODE = "LMCACHE_MP_TRANSFER_MODE" +DEFAULT_MAX_ASYNC_NON_GPU_STORES = 8 class MPTransferMode(str, Enum): @@ -213,6 +217,15 @@ def submit_retrieve( def close(self) -> None: """Release resources held by this context.""" + def flush_inflight_gathers(self) -> None: + """Synchronize any in-flight gather operations. + + The default implementation is a no-op. Non-GPU async save contexts can + override this to make preemption handling block until deferred reads of + vLLM paged KV data are complete. + """ + return None + class HandleTransferContext(TransferContext): """Handle-based IPC + MQ future transport context.""" @@ -303,12 +316,72 @@ def close(self) -> None: class DataTransferContext(TransferContext): - """Data transfer context for non-CUDA workers.""" + """Data transfer context for non-CUDA workers. - def __init__(self) -> None: + Store on the non-GPU path is two-phase and fully async: + 1) gather: enqueue GPU->CPU copies on a dedicated copy stream into + LMCache-owned pinned staging buffers (ordered behind the per-step event). + 2) commit: wait for gather completion in a background thread, then perform + commit_store() (pickle or SHM commit) and resolve the returned future. + + SHM note: SHM slots are generally pageable, so device->SHM DtoH copies may + implicitly synchronize. To keep gather async, we always gather into pinned + bounce buffers first, then copy to SHM slots on the commit thread. + """ + + def __init__( + self, max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES + ) -> None: self._non_gpu_context: NonGpuContext | None = None self._layout_hints: LayoutHints | None = None self._gpu_kv_format: Any = None + self._copy_stream = torch_dev.Stream() + self._commit_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lmcache_non_gpu_commit" + ) + self._max_inflight_stores = max(1, int(max_inflight_stores)) + self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) + self._inflight_lock = threading.Lock() + self._inflight_gather_events: set[Any] = set() + self._inflight_commits: set[ConcurrentFuture[None]] = set() + self._staging_pool: dict[ + tuple[tuple[int, ...], torch.dtype], list[torch.Tensor] + ] = {} + self._is_closing = False + + def _alloc_pinned_staging( + self, shape: torch.Size, dtype: torch.dtype, count: int + ) -> list[torch.Tensor]: + key = (tuple(shape), dtype) + with self._inflight_lock: + pooled = self._staging_pool.setdefault(key, []) + staged = [pooled.pop() for _ in range(min(len(pooled), count))] + if len(staged) == count: + return staged + + missing = count - len(staged) + for _ in range(missing): + try: + staged.append( + torch.empty(shape, dtype=dtype, device="cpu", pin_memory=True) + ) + except RuntimeError: + # Graceful fallback for CPU-only / pin-memory-disabled setups. + logger.warning( + "Falling back to non-pinned CPU staging buffer " + "(shape=%s, dtype=%s)", + tuple(shape), + dtype, + ) + staged.append(torch.empty(shape, dtype=dtype, device="cpu")) + return staged + + def _release_staging(self, chunks: list[torch.Tensor]) -> None: + if not chunks: + return + key = (tuple(chunks[0].shape), chunks[0].dtype) + with self._inflight_lock: + self._staging_pool.setdefault(key, []).extend(chunks) def register( self, @@ -406,37 +479,128 @@ def submit_store( _event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: + completion: MessagingFuture[bool] = MessagingFuture() if self._non_gpu_context is None: raise RuntimeError( "Data transfer context is not registered. " "Call register() before submit_store()." ) - torch_dev.synchronize() - result = self._non_gpu_context.prepare_store(key, instance_id) - out_buffers, chunk_indices = result if result is not None else (None, None) - # All chunks already in cache — nothing to gather or commit. - if chunk_indices is not None and len(chunk_indices) == 0: - future: MessagingFuture[bool] = MessagingFuture() - future.set_result(True) - return future - cpu_chunks = gather_paged_kv_to_cpu( - kv_caches, - _single_group_block_ids(block_ids), - blocks_in_chunk, - layout_hints=self._layout_hints, - gpu_kv_format=self._gpu_kv_format, - out=out_buffers, - chunk_indices=chunk_indices, - ) - if out_buffers is not None: - # SHM path uses async device->CPU copies; complete them before commit. - torch_dev.synchronize() - ok = self._non_gpu_context.commit_store(key, instance_id, cpu_chunks) + self._inflight_semaphore.acquire() + staged_chunks: list[torch.Tensor] = [] + shm_out_buffers: list[torch.Tensor] | None = None + gather_done: Any | None = None + try: + with self._inflight_lock: + if self._is_closing: + completion.set_result(False) + self._inflight_semaphore.release() + return completion + + result = self._non_gpu_context.prepare_store(key, instance_id) + out_buffers, chunk_indices = result if result is not None else (None, None) + if chunk_indices is not None and len(chunk_indices) == 0: + # All chunks are already in cache: no gather, no commit. + completion.set_result(True) + self._inflight_semaphore.release() + return completion + + full_block_ids = _single_group_block_ids(block_ids) + num_chunks = ( + len(chunk_indices) + if chunk_indices is not None + else len(full_block_ids) // blocks_in_chunk + ) + if not self._non_gpu_context.layout_desc.shapes: + raise RuntimeError("non-GPU layout_desc.shapes is empty") + if not self._non_gpu_context.layout_desc.dtypes: + raise RuntimeError("non-GPU layout_desc.dtypes is empty") + staged_chunks = self._alloc_pinned_staging( + self._non_gpu_context.layout_desc.shapes[0], + self._non_gpu_context.layout_desc.dtypes[0], + num_chunks, + ) + shm_out_buffers = out_buffers + with torch_dev.stream(self._copy_stream): + _event.wait(stream=self._copy_stream) + gather_paged_kv_to_cpu( + kv_caches, + full_block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=staged_chunks, + chunk_indices=chunk_indices, + ) + gather_done = torch_dev.Event() + gather_done.record(self._copy_stream) - future = MessagingFuture() - future.set_result(ok) - return future + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.add(gather_done) + + def _commit_after_gather() -> None: + ok = False + try: + if gather_done is not None: + gather_done.synchronize() + if shm_out_buffers is not None: + if len(staged_chunks) != len(shm_out_buffers): + raise RuntimeError( + "SHM staging chunk count mismatch: " + f"{len(staged_chunks)} vs {len(shm_out_buffers)} " + f"(request_id={_request_id}, instance_id={instance_id})" + ) + for staged, shm_view in zip( + staged_chunks, shm_out_buffers, strict=True + ): + shm_view.copy_(staged) + ok = self._non_gpu_context.commit_store( + key, instance_id, shm_out_buffers + ) + else: + ok = self._non_gpu_context.commit_store( + key, instance_id, staged_chunks + ) + if not ok: + logger.error( + "Async non-GPU commit_store failed for request_id=%s", + _request_id, + ) + except Exception: + logger.exception( + "Async non-GPU store failed for request_id=%s", + _request_id, + ) + ok = False + finally: + self._release_staging(staged_chunks) + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.discard(gather_done) + completion.set_result(ok) + self._inflight_semaphore.release() + + commit_future = self._commit_executor.submit(_commit_after_gather) + with self._inflight_lock: + self._inflight_commits.add(commit_future) + + def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: + with self._inflight_lock: + self._inflight_commits.discard(done_future) + + commit_future.add_done_callback(_drop_commit_future) + return completion + except Exception: + logger.exception("Failed to submit async non-GPU store") + if staged_chunks: + self._release_staging(staged_chunks) + if gather_done is not None: + with self._inflight_lock: + self._inflight_gather_events.discard(gather_done) + completion.set_result(False) + self._inflight_semaphore.release() + return completion def submit_retrieve( self, @@ -481,10 +645,25 @@ def submit_retrieve( return future def close(self) -> None: + with self._inflight_lock: + self._is_closing = True + gather_events = list(self._inflight_gather_events) + for event in gather_events: + try: + event.synchronize() + except Exception: + logger.exception("Failed while draining gather events") + self._commit_executor.shutdown(wait=True, cancel_futures=False) if self._non_gpu_context is not None: self._non_gpu_context.close() self._non_gpu_context = None + def flush_inflight_gathers(self) -> None: + with self._inflight_lock: + gather_events = list(self._inflight_gather_events) + for event in gather_events: + event.synchronize() + def create_transfer_context( kv_caches: dict[str, torch.Tensor], diff --git a/tests/v1/multiprocess/test_async_data_transfer_context.py b/tests/v1/multiprocess/test_async_data_transfer_context.py new file mode 100644 index 0000000000..afb2b5d63e --- /dev/null +++ b/tests/v1/multiprocess/test_async_data_transfer_context.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from contextlib import nullcontext +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Callable +from unittest.mock import MagicMock +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.multiprocess.transfer_context import worker_transfer +from lmcache.v1.multiprocess.transfer_context.worker_transfer import DataTransferContext + + +@dataclass +class _FakeStoreContext: + """Minimal non-GPU context for async store tests.""" + + commit_impl: Callable[[list[torch.Tensor]], bool] + prepare_result: tuple[list[torch.Tensor], list[int]] | None = None + + def __post_init__(self) -> None: + self.layout_desc = SimpleNamespace( + shapes=[torch.Size([2, 1, 1, 1])], dtypes=[torch.float32] + ) + + def prepare_store( + self, _key: object, _instance_id: int + ) -> tuple[list[torch.Tensor], list[int]] | None: + return self.prepare_result + + def commit_store( + self, _key: object, _instance_id: int, chunks: list[torch.Tensor] + ) -> bool: + return bool(self.commit_impl(chunks)) + + def close(self) -> None: + return None + + +class _FakeEvent: + def __init__(self, gate: threading.Event): + self._gate = gate + + def record(self, stream: object | None = None) -> None: + return None + + def wait(self, stream: object | None = None) -> None: + return None + + def synchronize(self) -> None: + self._gate.wait(timeout=2) + + def query(self) -> bool: + return self._gate.is_set() + + +class _FakeTorchDev: + def __init__(self, gather_gate: threading.Event): + self._stream = object() + self._gather_gate = gather_gate + + def Stream(self) -> object: + return object() + + def stream(self, stream: object) -> object: + return nullcontext(stream) + + def current_stream(self) -> object: + return self._stream + + def Event(self, interprocess: bool = False) -> _FakeEvent: + return _FakeEvent(self._gather_gate) + + +def _install_fake_gather(monkeypatch: pytest.MonkeyPatch) -> None: + def _gather( + _kv_caches: dict[str, torch.Tensor], + _block_ids: list[int], + _blocks_in_chunk: int, + **kwargs: object, + ) -> list[torch.Tensor]: + out = kwargs["out"] + assert isinstance(out, list) + for tensor in out: + tensor.fill_(1.0) + return out + + monkeypatch.setattr(worker_transfer, "gather_paged_kv_to_cpu", _gather) + + +def _new_context( + monkeypatch: pytest.MonkeyPatch, + *, + gather_gate: threading.Event, + commit_impl: Callable[[list[torch.Tensor]], bool], + max_inflight: int = 8, +) -> DataTransferContext: + monkeypatch.setattr(worker_transfer, "torch_dev", _FakeTorchDev(gather_gate)) + _install_fake_gather(monkeypatch) + ctx = DataTransferContext(max_inflight_stores=max_inflight) + ctx._non_gpu_context = _FakeStoreContext(commit_impl=commit_impl) + return ctx + + +def test_submit_store_returns_pending_future_until_gather_and_commit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=lambda _c: True + ) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + assert not future.query() + gather_gate.set() + assert future.result(timeout=1) is True + ctx.close() + + +def test_submit_store_commit_waits_for_gather_done( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + commit_called = threading.Event() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_called.set() + return True + + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + assert not commit_called.wait(timeout=0.05) + gather_gate.set() + assert future.result(timeout=1) is True + assert commit_called.is_set() + ctx.close() + + +def test_submit_store_backpressure_blocks_when_inflight_cap_hit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + commit_gate = threading.Event() + gather_gate.set() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_gate.wait(timeout=2) + return True + + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=_commit, max_inflight=1 + ) + first = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + done = threading.Event() + + def _submit_second() -> None: + try: + ctx.submit_store( + "r2", + object(), + 1, + {"k": torch.zeros(1)}, + [[0]], + _FakeEvent(gather_gate), + 1, + ) + finally: + done.set() + + t = threading.Thread(target=_submit_second, daemon=True) + t.start() + assert not done.wait(timeout=0.1) + commit_gate.set() + assert first.result(timeout=1) is True + t.join(timeout=1) + assert done.is_set() + ctx.close() + + +def test_close_drains_inflight_async_store(monkeypatch: pytest.MonkeyPatch) -> None: + gather_gate = threading.Event() + commit_gate = threading.Event() + gather_gate.set() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_gate.wait(timeout=2) + return True + + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + closed = threading.Event() + + def _close() -> None: + ctx.close() + closed.set() + + t = threading.Thread(target=_close, daemon=True) + t.start() + assert not closed.wait(timeout=0.05) + commit_gate.set() + t.join(timeout=1) + assert closed.is_set() + assert future.result(timeout=1) is True + + +def test_commit_failure_sets_false_and_logs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + raise RuntimeError("commit failed") + + log_exception = MagicMock() + monkeypatch.setattr(worker_transfer.logger, "exception", log_exception) + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + gather_gate.set() + assert future.result(timeout=1) is False + log_exception.assert_called_once() + ctx.close() diff --git a/tests/v1/test_lmcache_mp_connector_preemption.py b/tests/v1/test_lmcache_mp_connector_preemption.py new file mode 100644 index 0000000000..636edaa283 --- /dev/null +++ b/tests/v1/test_lmcache_mp_connector_preemption.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from types import SimpleNamespace +from unittest.mock import MagicMock + +# Third Party +import pytest + +pytest.importorskip("vllm") + +# Third Party +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole + +# First Party +from lmcache.integration.vllm.lmcache_mp_connector import ( + LMCacheMPConnector, + LMCacheMPConnectorMetadata, +) + + +def _new_connector_without_init() -> LMCacheMPConnector: + connector = LMCacheMPConnector.__new__(LMCacheMPConnector) + return connector + + +def test_build_connector_meta_sets_need_flush_from_preemption_signal() -> None: + connector = _new_connector_without_init() + connector._process_retrieve_requests = lambda metadata: None + connector._process_new_requests = lambda scheduler_output, metadata: None + connector._process_cached_requests = lambda scheduler_output, metadata: None + connector._report_block_allocation_deltas = lambda scheduler_output: None + + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=["req-1"]) + ) + metadata = connector.build_connector_meta(scheduler_output) + assert isinstance(metadata, LMCacheMPConnectorMetadata) + assert metadata.need_flush is True + + +def test_build_connector_meta_keeps_need_flush_false_without_signal() -> None: + connector = _new_connector_without_init() + connector._process_retrieve_requests = lambda metadata: None + connector._process_new_requests = lambda scheduler_output, metadata: None + connector._process_cached_requests = lambda scheduler_output, metadata: None + connector._report_block_allocation_deltas = lambda scheduler_output: None + + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace( + resumed_req_ids=[], + resumed_from_preemption=[], + ), + preempted_req_ids=[], + evicted_req_ids=[], + ) + metadata = connector.build_connector_meta(scheduler_output) + assert isinstance(metadata, LMCacheMPConnectorMetadata) + assert metadata.need_flush is False + + +@pytest.mark.parametrize( + "scheduler_output", + [ + SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace( + resumed_req_ids=[], + resumed_from_preemption=[False, True], + ) + ), + SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace( + resumed_req_ids=[], + resumed_from_preemption=[], + ), + preempted_req_ids=["req-1"], + ), + SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace( + resumed_req_ids=[], + resumed_from_preemption=[], + ), + preempted_req_ids=[], + evicted_req_ids=["req-2"], + ), + ], +) +def test_scheduler_step_needs_flush_for_all_supported_signals( + scheduler_output: SimpleNamespace, +) -> None: + connector = _new_connector_without_init() + assert connector._scheduler_step_needs_flush(scheduler_output) is True + + +def test_handle_preemptions_forwards_flush_hint_to_worker_adapter() -> None: + connector = _new_connector_without_init() + connector._role = KVConnectorRole.WORKER + connector.worker_adapter = MagicMock() + metadata = LMCacheMPConnectorMetadata() + + connector.handle_preemptions(metadata) + connector.worker_adapter.handle_preemptions.assert_called_once_with(False) + + connector.worker_adapter.reset_mock() + metadata.need_flush = True + connector.handle_preemptions(metadata) + connector.worker_adapter.handle_preemptions.assert_called_once_with(True) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 635e7ea09d..a9f08da303 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -296,3 +296,15 @@ def test_retrieve_keeps_event_until_future_finishes(fake_adapter): transfer_ctx.reset_mock() gc.collect() assert event_ref() is None + + +def test_handle_preemptions_flushes_only_when_signaled(fake_adapter): + adapter, _send_mock, _future = fake_adapter + transfer_ctx = MagicMock() + adapter.transfer_ctx = transfer_ctx + + adapter.handle_preemptions(False) + transfer_ctx.flush_inflight_gathers.assert_not_called() + + adapter.handle_preemptions(True) + transfer_ctx.flush_inflight_gathers.assert_called_once_with() From 486bd20e0021951aef36cd3c51da599dc5cedcd2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:35:50 +0000 Subject: [PATCH 03/31] Gate async non-GPU store on device capability with sync fallback --- .../integration/vllm/lmcache_mp_connector.py | 33 ++- .../transfer_context/worker_transfer.py | 231 ++++++++++++++---- .../test_async_data_transfer_context.py | 125 +++++++++- .../test_lmcache_mp_connector_preemption.py | 19 ++ 4 files changed, 358 insertions(+), 50 deletions(-) diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 2137d42239..9fc4bd2644 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -518,11 +518,7 @@ def __str__(self): f"num_blocks={len(req_meta.op.block_ids[0])}, " f"block_ids={req_meta.op.block_ids})" ) - return ( - f"need_flush={self.need_flush}; [" - + "\n".join(request_strs) - + "]" - ) + return f"need_flush={self.need_flush}; [" + "\n".join(request_strs) + "]" def __repr__(self): return self.__str__() @@ -998,8 +994,14 @@ def build_connector_meta( return metadata def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool: - """Return whether this scheduler step can overwrite preempted blocks.""" - # vLLM 0.10+ exposes resumed preemptions through cached-request fields. + """Return whether this scheduler step can overwrite preempted blocks. + + Under-syncing here risks KV-block corruption (a paged block may be + overwritten before a deferred async gather reads it), while over-syncing + only costs performance, so we prefer a spurious flush over a missed one. + """ + # In vLLM v1, resumed preemptions are surfaced on the cached-request + # struct via ``resumed_req_ids`` / ``resumed_from_preemption``. cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) resumed_req_ids = getattr(cached_reqs, "resumed_req_ids", None) if resumed_req_ids: @@ -1007,13 +1009,28 @@ def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool resumed_flags = getattr(cached_reqs, "resumed_from_preemption", None) if resumed_flags and any(resumed_flags): return True - # Conservative fallback for alternate scheduler output schemas. + # Resilient optional checks: these fields may not exist on every vLLM + # ``SchedulerOutput`` version; getattr keeps the probe harmless when so. preempted_req_ids = getattr(scheduler_output, "preempted_req_ids", None) if preempted_req_ids: return True evicted_req_ids = getattr(scheduler_output, "evicted_req_ids", None) if evicted_req_ids: return True + # Conservative fallback: if cached requests are present but expose none + # of the known preemption fields, the schema is unrecognized and we + # cannot prove the step is preemption-free. Flush rather than risk + # corruption. + if cached_reqs is not None and not ( + hasattr(cached_reqs, "resumed_req_ids") + or hasattr(cached_reqs, "resumed_from_preemption") + ): + logger.warning( + "Unrecognized scheduled_cached_reqs schema (%s); conservatively " + "flushing in-flight async gathers to avoid KV block corruption.", + type(cached_reqs).__name__, + ) + return True return False def update_connector_output(self, connector_output: KVConnectorOutput): diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 245004d14c..0dcc65e1ce 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -43,6 +43,10 @@ # ``data``); ``auto`` reproduces the historical device-type-based dispatch. ENV_MP_TRANSFER_MODE = "LMCACHE_MP_TRANSFER_MODE" DEFAULT_MAX_ASYNC_NON_GPU_STORES = 8 +# Number of background threads used to run commit (CPU->server) work for the +# async non-GPU store path. >1 so that a slow gather for one store does not +# block the commit of another store whose gather already finished. +DEFAULT_NON_GPU_COMMIT_WORKERS = 4 class MPTransferMode(str, Enum): @@ -318,29 +322,44 @@ def close(self) -> None: class DataTransferContext(TransferContext): """Data transfer context for non-CUDA workers. - Store on the non-GPU path is two-phase and fully async: + Store on the non-GPU path is two-phase and fully async *when the worker + device supports the required async primitives* (a stream, an event with + ``record``/``synchronize``/``wait``, and pinned host memory): 1) gather: enqueue GPU->CPU copies on a dedicated copy stream into LMCache-owned pinned staging buffers (ordered behind the per-step event). 2) commit: wait for gather completion in a background thread, then perform commit_store() (pickle or SHM commit) and resolve the returned future. + When those primitives are not available (e.g. a CPU-only backend without + streams/events/pinned memory), the context automatically falls back to the + original synchronous store implementation. This dispatch is internal and + capability-based; there is no user-facing async/sync flag, and async stays + the default whenever the device can support it. + SHM note: SHM slots are generally pageable, so device->SHM DtoH copies may implicitly synchronize. To keep gather async, we always gather into pinned bounce buffers first, then copy to SHM slots on the commit thread. """ def __init__( - self, max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES + self, + max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES, + commit_workers: int = DEFAULT_NON_GPU_COMMIT_WORKERS, ) -> None: self._non_gpu_context: NonGpuContext | None = None self._layout_hints: LayoutHints | None = None self._gpu_kv_format: Any = None - self._copy_stream = torch_dev.Stream() - self._commit_executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="lmcache_non_gpu_commit" - ) self._max_inflight_stores = max(1, int(max_inflight_stores)) - self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) + self._commit_workers = max(1, int(commit_workers)) + # Capability-based dispatch decided in register(); defaults to the + # synchronous fallback until the device is known to be async-capable. + self._async_capable = False + # Async-only resources. Created lazily (never in __init__) and only when + # the device is async-capable, so backends without Stream/Event/pinned + # memory never touch these primitives. + self._copy_stream: Any = None + self._commit_executor: ThreadPoolExecutor | None = None + self._inflight_semaphore: threading.BoundedSemaphore | None = None self._inflight_lock = threading.Lock() self._inflight_gather_events: set[Any] = set() self._inflight_commits: set[ConcurrentFuture[None]] = set() @@ -349,6 +368,44 @@ def __init__( ] = {} self._is_closing = False + def _detect_async_capable(self) -> bool: + """Probe whether the worker device supports the async store primitives. + + Requires a stream, an event exposing ``record``/``synchronize``/ + ``wait``, and pinned (page-locked) host memory. The probe is performed + once (cached by ``register()``); it never runs per ``submit_store``. + """ + if not hasattr(torch_dev, "Stream") or not hasattr(torch_dev, "Event"): + return False + try: + torch_dev.Stream() + event = torch_dev.Event() + except Exception: + return False + for attr in ("record", "synchronize", "wait"): + if not callable(getattr(event, attr, None)): + return False + try: + torch.empty(1, dtype=torch.uint8, device="cpu", pin_memory=True) + except (RuntimeError, TypeError): + return False + return True + + def _create_async_resources(self) -> None: + """Create the copy stream / commit executor / backpressure semaphore.""" + self._copy_stream = torch_dev.Stream() + self._commit_executor = ThreadPoolExecutor( + max_workers=self._commit_workers, + thread_name_prefix="lmcache_non_gpu_commit", + ) + self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) + + def _init_async_capability(self) -> None: + """Detect device async capability and lazily create async resources.""" + self._async_capable = self._detect_async_capable() + if self._async_capable: + self._create_async_resources() + def _alloc_pinned_staging( self, shape: torch.Size, dtype: torch.dtype, count: int ) -> list[torch.Tensor]: @@ -463,10 +520,13 @@ def register( pool_size=pool_size, ) supported_transfer_mode = "SHM" if shm_name and pool_size > 0 else "pickle" + self._init_async_capability() logger.info( - "Worker non-GPU transfer context registered (instance_id=%d, mode=%s)", + "Worker non-GPU transfer context registered " + "(instance_id=%d, mode=%s, store=%s)", instance_id, supported_transfer_mode, + "async" if self._async_capable else "sync", ) def submit_store( @@ -479,14 +539,89 @@ def submit_store( _event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: - completion: MessagingFuture[bool] = MessagingFuture() if self._non_gpu_context is None: raise RuntimeError( "Data transfer context is not registered. " "Call register() before submit_store()." ) + if self._async_capable: + return self._submit_store_async( + _request_id, + key, + instance_id, + kv_caches, + block_ids, + _event, + blocks_in_chunk, + ) + return self._submit_store_sync( + key, + instance_id, + kv_caches, + block_ids, + blocks_in_chunk, + ) - self._inflight_semaphore.acquire() + def _submit_store_sync( + self, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[list[int]], + blocks_in_chunk: int, + ) -> MessagingFuture: + """Original synchronous store path (capability fallback). + + Reproduces the pre-async behaviour exactly: synchronize, prepare, + gather, (SHM) synchronize, commit, and return an already-resolved + future. + """ + assert self._non_gpu_context is not None + torch_dev.synchronize() + result = self._non_gpu_context.prepare_store(key, instance_id) + out_buffers, chunk_indices = result if result is not None else (None, None) + # All chunks already in cache — nothing to gather or commit. + if chunk_indices is not None and len(chunk_indices) == 0: + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(True) + return future + cpu_chunks = gather_paged_kv_to_cpu( + kv_caches, + _single_group_block_ids(block_ids), + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=out_buffers, + chunk_indices=chunk_indices, + ) + if out_buffers is not None: + # SHM path uses async device->CPU copies; complete them before commit. + torch_dev.synchronize() + ok = self._non_gpu_context.commit_store(key, instance_id, cpu_chunks) + + future = MessagingFuture() + future.set_result(ok) + return future + + def _submit_store_async( + self, + _request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[list[int]], + _event: IPCEvent, + blocks_in_chunk: int, + ) -> MessagingFuture: + completion: MessagingFuture[bool] = MessagingFuture() + non_gpu_context = self._non_gpu_context + semaphore = self._inflight_semaphore + commit_executor = self._commit_executor + assert non_gpu_context is not None + assert semaphore is not None + assert commit_executor is not None + + semaphore.acquire() staged_chunks: list[torch.Tensor] = [] shm_out_buffers: list[torch.Tensor] | None = None gather_done: Any | None = None @@ -494,15 +629,15 @@ def submit_store( with self._inflight_lock: if self._is_closing: completion.set_result(False) - self._inflight_semaphore.release() + semaphore.release() return completion - result = self._non_gpu_context.prepare_store(key, instance_id) + result = non_gpu_context.prepare_store(key, instance_id) out_buffers, chunk_indices = result if result is not None else (None, None) if chunk_indices is not None and len(chunk_indices) == 0: # All chunks are already in cache: no gather, no commit. completion.set_result(True) - self._inflight_semaphore.release() + semaphore.release() return completion full_block_ids = _single_group_block_ids(block_ids) @@ -511,13 +646,13 @@ def submit_store( if chunk_indices is not None else len(full_block_ids) // blocks_in_chunk ) - if not self._non_gpu_context.layout_desc.shapes: + if not non_gpu_context.layout_desc.shapes: raise RuntimeError("non-GPU layout_desc.shapes is empty") - if not self._non_gpu_context.layout_desc.dtypes: + if not non_gpu_context.layout_desc.dtypes: raise RuntimeError("non-GPU layout_desc.dtypes is empty") staged_chunks = self._alloc_pinned_staging( - self._non_gpu_context.layout_desc.shapes[0], - self._non_gpu_context.layout_desc.dtypes[0], + non_gpu_context.layout_desc.shapes[0], + non_gpu_context.layout_desc.dtypes[0], num_chunks, ) shm_out_buffers = out_buffers @@ -555,11 +690,11 @@ def _commit_after_gather() -> None: staged_chunks, shm_out_buffers, strict=True ): shm_view.copy_(staged) - ok = self._non_gpu_context.commit_store( + ok = non_gpu_context.commit_store( key, instance_id, shm_out_buffers ) else: - ok = self._non_gpu_context.commit_store( + ok = non_gpu_context.commit_store( key, instance_id, staged_chunks ) if not ok: @@ -579,18 +714,14 @@ def _commit_after_gather() -> None: if gather_done is not None: self._inflight_gather_events.discard(gather_done) completion.set_result(ok) - self._inflight_semaphore.release() - - commit_future = self._commit_executor.submit(_commit_after_gather) - with self._inflight_lock: - self._inflight_commits.add(commit_future) - - def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: - with self._inflight_lock: - self._inflight_commits.discard(done_future) - - commit_future.add_done_callback(_drop_commit_future) - return completion + semaphore.release() + + # Submitting the commit task is the ownership-transfer point: once it + # succeeds, the commit task is solely responsible for releasing the + # semaphore, releasing staging buffers, and resolving the future. The + # except below therefore only handles failures that occur *before* + # this submit, so it can never double-release or double-resolve. + commit_future = commit_executor.submit(_commit_after_gather) except Exception: logger.exception("Failed to submit async non-GPU store") if staged_chunks: @@ -599,9 +730,19 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: with self._inflight_lock: self._inflight_gather_events.discard(gather_done) completion.set_result(False) - self._inflight_semaphore.release() + semaphore.release() return completion + with self._inflight_lock: + self._inflight_commits.add(commit_future) + + def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: + with self._inflight_lock: + self._inflight_commits.discard(done_future) + + commit_future.add_done_callback(_drop_commit_future) + return completion + def submit_retrieve( self, _request_id: str, @@ -645,20 +786,28 @@ def submit_retrieve( return future def close(self) -> None: - with self._inflight_lock: - self._is_closing = True - gather_events = list(self._inflight_gather_events) - for event in gather_events: - try: - event.synchronize() - except Exception: - logger.exception("Failed while draining gather events") - self._commit_executor.shutdown(wait=True, cancel_futures=False) + # Drain in-flight async work only when async resources were created. + # In sync (fallback) mode there is no copy stream / executor / inflight + # state, so guard against touching never-created attributes. + if self._async_capable: + with self._inflight_lock: + self._is_closing = True + gather_events = list(self._inflight_gather_events) + for event in gather_events: + try: + event.synchronize() + except Exception: + logger.exception("Failed while draining gather events") + if self._commit_executor is not None: + self._commit_executor.shutdown(wait=True, cancel_futures=False) if self._non_gpu_context is not None: self._non_gpu_context.close() self._non_gpu_context = None def flush_inflight_gathers(self) -> None: + # Cheap no-op in sync mode: no copy stream / in-flight gather events. + if not self._async_capable: + return with self._inflight_lock: gather_events = list(self._inflight_gather_events) for event in gather_events: diff --git a/tests/v1/multiprocess/test_async_data_transfer_context.py b/tests/v1/multiprocess/test_async_data_transfer_context.py index afb2b5d63e..c8ae29724b 100644 --- a/tests/v1/multiprocess/test_async_data_transfer_context.py +++ b/tests/v1/multiprocess/test_async_data_transfer_context.py @@ -84,7 +84,10 @@ def _gather( _blocks_in_chunk: int, **kwargs: object, ) -> list[torch.Tensor]: - out = kwargs["out"] + out = kwargs.get("out") + if out is None: + # Sync path may pass out=None; gather allocates its own buffers. + return [torch.ones(1)] assert isinstance(out, list) for tensor in out: tensor.fill_(1.0) @@ -104,6 +107,10 @@ def _new_context( _install_fake_gather(monkeypatch) ctx = DataTransferContext(max_inflight_stores=max_inflight) ctx._non_gpu_context = _FakeStoreContext(commit_impl=commit_impl) + # Async tests exercise the async store path directly; enable it explicitly + # so the capability probe (which needs real pinned memory) is bypassed. + ctx._async_capable = True + ctx._create_async_resources() return ctx @@ -233,3 +240,119 @@ def _commit(_chunks: list[torch.Tensor]) -> bool: assert future.result(timeout=1) is False log_exception.assert_called_once() ctx.close() + + +class _RecordingTorchDev: + """torch_dev stub that records whether async primitives are touched.""" + + def __init__(self) -> None: + self.synchronize_calls = 0 + self.stream_calls = 0 + self.event_calls = 0 + + def synchronize(self) -> None: + self.synchronize_calls += 1 + + def Stream(self) -> object: + self.stream_calls += 1 + return object() + + def stream(self, stream: object) -> object: + return nullcontext(stream) + + def Event(self, interprocess: bool = False) -> _FakeEvent: + self.event_calls += 1 + return _FakeEvent(threading.Event()) + + +def test_submit_store_sync_path_returns_resolved_future( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake = _RecordingTorchDev() + monkeypatch.setattr(worker_transfer, "torch_dev", fake) + _install_fake_gather(monkeypatch) + ctx = DataTransferContext() + ctx._non_gpu_context = _FakeStoreContext(commit_impl=lambda _c: True) + assert ctx._async_capable is False + + future = ctx.submit_store( + "r1", + object(), + 1, + {"k": torch.zeros(1)}, + [[0]], + _FakeEvent(threading.Event()), + 1, + ) + + # Sync path resolves inline and never touches copy-stream / event primitives. + assert future.query() + assert future.result(timeout=1) is True + assert fake.stream_calls == 0 + assert fake.event_calls == 0 + assert fake.synchronize_calls >= 1 + ctx.close() + + +def test_submit_store_dispatches_on_capability( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ctx = DataTransferContext() + ctx._non_gpu_context = MagicMock() + async_mock = MagicMock(return_value="async") + sync_mock = MagicMock(return_value="sync") + monkeypatch.setattr(ctx, "_submit_store_async", async_mock) + monkeypatch.setattr(ctx, "_submit_store_sync", sync_mock) + + ctx._async_capable = True + assert ctx.submit_store("r", None, 1, {}, [[0]], None, 1) == "async" + async_mock.assert_called_once() + sync_mock.assert_not_called() + + async_mock.reset_mock() + ctx._async_capable = False + assert ctx.submit_store("r", None, 1, {}, [[0]], None, 1) == "sync" + sync_mock.assert_called_once() + async_mock.assert_not_called() + + +def test_init_async_capability_non_capable_skips_resources( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # A torch_dev without Stream/Event is not async-capable. + monkeypatch.setattr(worker_transfer, "torch_dev", object()) + ctx = DataTransferContext() + assert ctx._copy_stream is None + assert ctx._commit_executor is None + + ctx._init_async_capability() + + assert ctx._async_capable is False + assert ctx._copy_stream is None + assert ctx._commit_executor is None + assert ctx._inflight_semaphore is None + # close() must not raise even though async resources were never created. + ctx.close() + + +def test_init_async_capability_capable_creates_resources( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(worker_transfer, "torch_dev", _FakeTorchDev(threading.Event())) + # Make the pinned-memory probe succeed regardless of host CUDA support. + real_empty = torch.empty + + def _fake_empty(*args: object, **kwargs: object) -> torch.Tensor: + kwargs.pop("pin_memory", None) + return real_empty(*args, **kwargs) + + monkeypatch.setattr(worker_transfer.torch, "empty", _fake_empty) + + ctx = DataTransferContext() + ctx._init_async_capability() + + assert ctx._async_capable is True + assert ctx._copy_stream is not None + assert ctx._commit_executor is not None + assert ctx._inflight_semaphore is not None + ctx.close() diff --git a/tests/v1/test_lmcache_mp_connector_preemption.py b/tests/v1/test_lmcache_mp_connector_preemption.py index 636edaa283..1f0ea93611 100644 --- a/tests/v1/test_lmcache_mp_connector_preemption.py +++ b/tests/v1/test_lmcache_mp_connector_preemption.py @@ -104,3 +104,22 @@ def test_handle_preemptions_forwards_flush_hint_to_worker_adapter() -> None: metadata.need_flush = True connector.handle_preemptions(metadata) connector.worker_adapter.handle_preemptions.assert_called_once_with(True) + + +def test_scheduler_step_needs_flush_conservative_on_unknown_schema() -> None: + connector = _new_connector_without_init() + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(unexpected_field=["req-1"]) + ) + assert connector._scheduler_step_needs_flush(scheduler_output) is True + + +def test_scheduler_step_needs_flush_false_for_recognized_no_preemption() -> None: + connector = _new_connector_without_init() + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace( + resumed_req_ids=[], + resumed_from_preemption=[], + ) + ) + assert connector._scheduler_step_needs_flush(scheduler_output) is False From 5b8a46306de8659a504bd628c048dba45d33b852 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:38:51 +0000 Subject: [PATCH 04/31] Add docstrings and clarify preemption schema comments per review --- .../integration/vllm/lmcache_mp_connector.py | 11 +++++---- .../transfer_context/worker_transfer.py | 23 ++++++++++++++++++- .../test_async_data_transfer_context.py | 3 ++- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 9fc4bd2644..1b2ef14fea 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -1017,10 +1017,13 @@ def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool evicted_req_ids = getattr(scheduler_output, "evicted_req_ids", None) if evicted_req_ids: return True - # Conservative fallback: if cached requests are present but expose none - # of the known preemption fields, the schema is unrecognized and we - # cannot prove the step is preemption-free. Flush rather than risk - # corruption. + # Conservative fallback: a recognized cached-request schema is expected + # to expose at least one of ``resumed_req_ids`` / + # ``resumed_from_preemption`` (older vLLM exposes the former, newer the + # latter; they are version-dependent alternatives, not required to + # co-exist). If cached requests are present but expose neither, the + # schema is unrecognized and we cannot prove the step is + # preemption-free, so flush rather than risk corruption. if cached_reqs is not None and not ( hasattr(cached_reqs, "resumed_req_ids") or hasattr(cached_reqs, "resumed_from_preemption") diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 0dcc65e1ce..02c561053f 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -346,6 +346,17 @@ def __init__( max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES, commit_workers: int = DEFAULT_NON_GPU_COMMIT_WORKERS, ) -> None: + """Initialize the context (async resources are created lazily). + + Args: + max_inflight_stores: Max number of concurrently in-flight async + stores before ``submit_store`` blocks (backpressure). Async + mode only. + commit_workers: Number of background threads used to run commit + (CPU->server) work in async mode. >1 so a slow gather for one + store does not block the commit of another whose gather is + already done. Async mode only. + """ self._non_gpu_context: NonGpuContext | None = None self._layout_hints: LayoutHints | None = None self._gpu_kv_format: Any = None @@ -374,6 +385,9 @@ def _detect_async_capable(self) -> bool: Requires a stream, an event exposing ``record``/``synchronize``/ ``wait``, and pinned (page-locked) host memory. The probe is performed once (cached by ``register()``); it never runs per ``submit_store``. + + Returns: + True if all required async primitives are available, else False. """ if not hasattr(torch_dev, "Stream") or not hasattr(torch_dev, "Event"): return False @@ -574,7 +588,8 @@ def _submit_store_sync( Reproduces the pre-async behaviour exactly: synchronize, prepare, gather, (SHM) synchronize, commit, and return an already-resolved - future. + future. See :meth:`submit_store` for argument semantics (``key``, + ``instance_id``, ``kv_caches``, ``block_ids``, ``blocks_in_chunk``). """ assert self._non_gpu_context is not None torch_dev.synchronize() @@ -613,6 +628,12 @@ def _submit_store_async( _event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: + """Two-phase async store path (gather on copy stream, deferred commit). + + Returns an unresolved future that resolves only after both gather + completion and the commit ACK. See :meth:`submit_store` for argument + semantics. + """ completion: MessagingFuture[bool] = MessagingFuture() non_gpu_context = self._non_gpu_context semaphore = self._inflight_semaphore diff --git a/tests/v1/multiprocess/test_async_data_transfer_context.py b/tests/v1/multiprocess/test_async_data_transfer_context.py index c8ae29724b..ca02a2ea0c 100644 --- a/tests/v1/multiprocess/test_async_data_transfer_context.py +++ b/tests/v1/multiprocess/test_async_data_transfer_context.py @@ -86,7 +86,8 @@ def _gather( ) -> list[torch.Tensor]: out = kwargs.get("out") if out is None: - # Sync path may pass out=None; gather allocates its own buffers. + # Sync fallback path passes out=None: gather allocates its own + # buffers and returns them, so mirror that contract here. return [torch.ones(1)] assert isinstance(out, list) for tensor in out: From 8a1e2835f0f8f5ac8d3dcfdcb4668f1db14fc625 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 02:04:56 +0000 Subject: [PATCH 05/31] Refactor async non-GPU store into dedicated AsyncDataTransferContext --- .../multiprocess/transfer_context/__init__.py | 2 + .../transfer_context/async_data.py | 284 +++++++++++++ .../transfer_context/worker_transfer.py | 402 +++--------------- .../test_async_data_transfer_context.py | 122 +++--- 4 files changed, 402 insertions(+), 408 deletions(-) create mode 100644 lmcache/v1/multiprocess/transfer_context/async_data.py diff --git a/lmcache/v1/multiprocess/transfer_context/__init__.py b/lmcache/v1/multiprocess/transfer_context/__init__.py index 6551664df5..833ac32748 100644 --- a/lmcache/v1/multiprocess/transfer_context/__init__.py +++ b/lmcache/v1/multiprocess/transfer_context/__init__.py @@ -7,6 +7,7 @@ """ # Local +from .async_data import AsyncDataTransferContext from .base import ( NonGpuContext, NonGpuContextMetadata, @@ -26,6 +27,7 @@ ) __all__ = [ + "AsyncDataTransferContext", "DataTransferContext", "HandleTransferContext", "MPTransferMode", diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py new file mode 100644 index 0000000000..b1fb3fd8c0 --- /dev/null +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Async non-GPU data transfer context for multiprocess worker adapters.""" + +# Standard +from concurrent.futures import Future as ConcurrentFuture +from concurrent.futures import ThreadPoolExecutor +from typing import Any +import threading + +# Third Party +import torch + +# First Party +from lmcache import torch_dev +from lmcache.utils import init_logger +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.transfer_context.base import gather_paged_kv_to_cpu +from lmcache.v1.multiprocess.transfer_context.worker_transfer import ( + DataTransferContext, + IPCEvent, + _single_group_block_ids, +) + +logger = init_logger(__name__) + +DEFAULT_MAX_ASYNC_NON_GPU_STORES = 8 +# Number of background threads used to run commit (CPU->server) work for the +# async non-GPU store path. >1 so that a slow gather for one store does not +# block the commit of another store whose gather already finished. +DEFAULT_NON_GPU_COMMIT_WORKERS = 4 + + +class AsyncDataTransferContext(DataTransferContext): + """Fully async non-GPU data transfer context (store-only async). + + Inherits :class:`DataTransferContext` and reuses its ``register()`` (layout + / SHM registration, no stream dependency) and ``submit_retrieve()`` (this + path does not change retrieve). Only the store is made async. + + Store is two-phase: + 1) gather: enqueue GPU->CPU copies on a dedicated copy stream into + LMCache-owned pinned staging buffers (ordered behind the per-step event). + 2) commit: wait for gather completion in a background thread, then perform + commit_store() (pickle or SHM commit) and resolve the returned future. + + This class is only instantiated by the factory when the device is + async-capable, so the constructor creates async resources unconditionally; + there is no ``self._async_capable`` flag. + + SHM note: SHM slots are generally pageable, so device->SHM DtoH copies may + implicitly synchronize. To keep gather async, we always gather into pinned + bounce buffers first, then copy to SHM slots on the commit thread. + """ + + def __init__( + self, + max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES, + commit_workers: int = DEFAULT_NON_GPU_COMMIT_WORKERS, + ) -> None: + """Initialize the async context and create its async resources. + + Args: + max_inflight_stores: Max number of concurrently in-flight async + stores before ``submit_store`` blocks (backpressure). + commit_workers: Number of background threads used to run commit + (CPU->server) work. >1 so a slow gather for one store does not + block the commit of another whose gather is already done. + """ + super().__init__() + self._max_inflight_stores = max(1, int(max_inflight_stores)) + self._commit_workers = max(1, int(commit_workers)) + self._copy_stream: Any = torch_dev.Stream() + self._commit_executor: ThreadPoolExecutor = ThreadPoolExecutor( + max_workers=self._commit_workers, + thread_name_prefix="lmcache_non_gpu_commit", + ) + self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) + self._inflight_lock = threading.Lock() + self._inflight_gather_events: set[Any] = set() + self._inflight_commits: set[ConcurrentFuture[None]] = set() + self._staging_pool: dict[ + tuple[tuple[int, ...], torch.dtype], list[torch.Tensor] + ] = {} + self._is_closing = False + + def _alloc_pinned_staging( + self, shape: torch.Size, dtype: torch.dtype, count: int + ) -> list[torch.Tensor]: + key = (tuple(shape), dtype) + with self._inflight_lock: + pooled = self._staging_pool.setdefault(key, []) + staged = [pooled.pop() for _ in range(min(len(pooled), count))] + if len(staged) == count: + return staged + + missing = count - len(staged) + for _ in range(missing): + try: + staged.append( + torch.empty(shape, dtype=dtype, device="cpu", pin_memory=True) + ) + except RuntimeError: + # Graceful fallback for CPU-only / pin-memory-disabled setups. + logger.warning( + "Falling back to non-pinned CPU staging buffer " + "(shape=%s, dtype=%s)", + tuple(shape), + dtype, + ) + staged.append(torch.empty(shape, dtype=dtype, device="cpu")) + return staged + + def _release_staging(self, chunks: list[torch.Tensor]) -> None: + if not chunks: + return + key = (tuple(chunks[0].shape), chunks[0].dtype) + with self._inflight_lock: + self._staging_pool.setdefault(key, []).extend(chunks) + + def submit_store( + self, + _request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[list[int]], + _event: IPCEvent, + blocks_in_chunk: int, + ) -> MessagingFuture: + """Two-phase async store (gather on copy stream, deferred commit). + + Returns an unresolved future that resolves only after both gather + completion and the commit ACK. + """ + if self._non_gpu_context is None: + raise RuntimeError( + "Data transfer context is not registered. " + "Call register() before submit_store()." + ) + completion: MessagingFuture[bool] = MessagingFuture() + non_gpu_context = self._non_gpu_context + semaphore = self._inflight_semaphore + commit_executor = self._commit_executor + + semaphore.acquire() + staged_chunks: list[torch.Tensor] = [] + shm_out_buffers: list[torch.Tensor] | None = None + gather_done: Any | None = None + try: + with self._inflight_lock: + if self._is_closing: + completion.set_result(False) + semaphore.release() + return completion + + result = non_gpu_context.prepare_store(key, instance_id) + out_buffers, chunk_indices = result if result is not None else (None, None) + if chunk_indices is not None and len(chunk_indices) == 0: + # All chunks are already in cache: no gather, no commit. + completion.set_result(True) + semaphore.release() + return completion + + full_block_ids = _single_group_block_ids(block_ids) + num_chunks = ( + len(chunk_indices) + if chunk_indices is not None + else len(full_block_ids) // blocks_in_chunk + ) + if not non_gpu_context.layout_desc.shapes: + raise RuntimeError("non-GPU layout_desc.shapes is empty") + if not non_gpu_context.layout_desc.dtypes: + raise RuntimeError("non-GPU layout_desc.dtypes is empty") + staged_chunks = self._alloc_pinned_staging( + non_gpu_context.layout_desc.shapes[0], + non_gpu_context.layout_desc.dtypes[0], + num_chunks, + ) + shm_out_buffers = out_buffers + with torch_dev.stream(self._copy_stream): + _event.wait(stream=self._copy_stream) + gather_paged_kv_to_cpu( + kv_caches, + full_block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=staged_chunks, + chunk_indices=chunk_indices, + ) + gather_done = torch_dev.Event() + gather_done.record(self._copy_stream) + + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.add(gather_done) + + def _commit_after_gather() -> None: + ok = False + try: + if gather_done is not None: + gather_done.synchronize() + if shm_out_buffers is not None: + if len(staged_chunks) != len(shm_out_buffers): + raise RuntimeError( + "SHM staging chunk count mismatch: " + f"{len(staged_chunks)} vs {len(shm_out_buffers)} " + f"(request_id={_request_id}, instance_id={instance_id})" + ) + for staged, shm_view in zip( + staged_chunks, shm_out_buffers, strict=True + ): + shm_view.copy_(staged) + ok = non_gpu_context.commit_store( + key, instance_id, shm_out_buffers + ) + else: + ok = non_gpu_context.commit_store( + key, instance_id, staged_chunks + ) + if not ok: + logger.error( + "Async non-GPU commit_store failed for request_id=%s", + _request_id, + ) + except Exception: + logger.exception( + "Async non-GPU store failed for request_id=%s", + _request_id, + ) + ok = False + finally: + self._release_staging(staged_chunks) + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.discard(gather_done) + completion.set_result(ok) + semaphore.release() + + # Submitting the commit task is the ownership-transfer point: once it + # succeeds, the commit task is solely responsible for releasing the + # semaphore, releasing staging buffers, and resolving the future. The + # except below therefore only handles failures that occur *before* + # this submit, so it can never double-release or double-resolve. + commit_future = commit_executor.submit(_commit_after_gather) + except Exception: + logger.exception("Failed to submit async non-GPU store") + if staged_chunks: + self._release_staging(staged_chunks) + if gather_done is not None: + with self._inflight_lock: + self._inflight_gather_events.discard(gather_done) + completion.set_result(False) + semaphore.release() + return completion + + with self._inflight_lock: + self._inflight_commits.add(commit_future) + + def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: + with self._inflight_lock: + self._inflight_commits.discard(done_future) + + commit_future.add_done_callback(_drop_commit_future) + return completion + + def flush_inflight_gathers(self) -> None: + with self._inflight_lock: + gather_events = list(self._inflight_gather_events) + for event in gather_events: + event.synchronize() + + def close(self) -> None: + # Drain in-flight gather/commit work before closing the base context. + with self._inflight_lock: + self._is_closing = True + gather_events = list(self._inflight_gather_events) + for event in gather_events: + try: + event.synchronize() + except Exception: + logger.exception("Failed while draining gather events") + self._commit_executor.shutdown(wait=True, cancel_futures=False) + super().close() diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 02c561053f..59fcec73c1 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -4,12 +4,9 @@ # Standard from abc import ABC, abstractmethod from collections.abc import Sequence -from concurrent.futures import Future as ConcurrentFuture -from concurrent.futures import ThreadPoolExecutor from enum import Enum from typing import Any, Callable, Protocol import os -import threading # Third Party import torch @@ -42,11 +39,6 @@ # string values of :class:`MPTransferMode` (``auto`` / ``handle`` / # ``data``); ``auto`` reproduces the historical device-type-based dispatch. ENV_MP_TRANSFER_MODE = "LMCACHE_MP_TRANSFER_MODE" -DEFAULT_MAX_ASYNC_NON_GPU_STORES = 8 -# Number of background threads used to run commit (CPU->server) work for the -# async non-GPU store path. >1 so that a slow gather for one store does not -# block the commit of another store whose gather already finished. -DEFAULT_NON_GPU_COMMIT_WORKERS = 4 class MPTransferMode(str, Enum): @@ -102,6 +94,9 @@ class IPCEvent(Protocol): def ipc_handle(self) -> object: """Return an IPC handle consumable by the multiprocess server.""" + def wait(self, stream: object | None = None) -> None: + """Make ``stream`` wait for this event (async ordering primitive).""" + SendRequest = Callable[[MessageQueueClient, RequestType, list[object]], MessagingFuture] @@ -320,139 +315,12 @@ def close(self) -> None: class DataTransferContext(TransferContext): - """Data transfer context for non-CUDA workers. - - Store on the non-GPU path is two-phase and fully async *when the worker - device supports the required async primitives* (a stream, an event with - ``record``/``synchronize``/``wait``, and pinned host memory): - 1) gather: enqueue GPU->CPU copies on a dedicated copy stream into - LMCache-owned pinned staging buffers (ordered behind the per-step event). - 2) commit: wait for gather completion in a background thread, then perform - commit_store() (pickle or SHM commit) and resolve the returned future. - - When those primitives are not available (e.g. a CPU-only backend without - streams/events/pinned memory), the context automatically falls back to the - original synchronous store implementation. This dispatch is internal and - capability-based; there is no user-facing async/sync flag, and async stays - the default whenever the device can support it. - - SHM note: SHM slots are generally pageable, so device->SHM DtoH copies may - implicitly synchronize. To keep gather async, we always gather into pinned - bounce buffers first, then copy to SHM slots on the commit thread. - """ - - def __init__( - self, - max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES, - commit_workers: int = DEFAULT_NON_GPU_COMMIT_WORKERS, - ) -> None: - """Initialize the context (async resources are created lazily). + """Data transfer context for non-CUDA workers.""" - Args: - max_inflight_stores: Max number of concurrently in-flight async - stores before ``submit_store`` blocks (backpressure). Async - mode only. - commit_workers: Number of background threads used to run commit - (CPU->server) work in async mode. >1 so a slow gather for one - store does not block the commit of another whose gather is - already done. Async mode only. - """ + def __init__(self) -> None: self._non_gpu_context: NonGpuContext | None = None self._layout_hints: LayoutHints | None = None self._gpu_kv_format: Any = None - self._max_inflight_stores = max(1, int(max_inflight_stores)) - self._commit_workers = max(1, int(commit_workers)) - # Capability-based dispatch decided in register(); defaults to the - # synchronous fallback until the device is known to be async-capable. - self._async_capable = False - # Async-only resources. Created lazily (never in __init__) and only when - # the device is async-capable, so backends without Stream/Event/pinned - # memory never touch these primitives. - self._copy_stream: Any = None - self._commit_executor: ThreadPoolExecutor | None = None - self._inflight_semaphore: threading.BoundedSemaphore | None = None - self._inflight_lock = threading.Lock() - self._inflight_gather_events: set[Any] = set() - self._inflight_commits: set[ConcurrentFuture[None]] = set() - self._staging_pool: dict[ - tuple[tuple[int, ...], torch.dtype], list[torch.Tensor] - ] = {} - self._is_closing = False - - def _detect_async_capable(self) -> bool: - """Probe whether the worker device supports the async store primitives. - - Requires a stream, an event exposing ``record``/``synchronize``/ - ``wait``, and pinned (page-locked) host memory. The probe is performed - once (cached by ``register()``); it never runs per ``submit_store``. - - Returns: - True if all required async primitives are available, else False. - """ - if not hasattr(torch_dev, "Stream") or not hasattr(torch_dev, "Event"): - return False - try: - torch_dev.Stream() - event = torch_dev.Event() - except Exception: - return False - for attr in ("record", "synchronize", "wait"): - if not callable(getattr(event, attr, None)): - return False - try: - torch.empty(1, dtype=torch.uint8, device="cpu", pin_memory=True) - except (RuntimeError, TypeError): - return False - return True - - def _create_async_resources(self) -> None: - """Create the copy stream / commit executor / backpressure semaphore.""" - self._copy_stream = torch_dev.Stream() - self._commit_executor = ThreadPoolExecutor( - max_workers=self._commit_workers, - thread_name_prefix="lmcache_non_gpu_commit", - ) - self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) - - def _init_async_capability(self) -> None: - """Detect device async capability and lazily create async resources.""" - self._async_capable = self._detect_async_capable() - if self._async_capable: - self._create_async_resources() - - def _alloc_pinned_staging( - self, shape: torch.Size, dtype: torch.dtype, count: int - ) -> list[torch.Tensor]: - key = (tuple(shape), dtype) - with self._inflight_lock: - pooled = self._staging_pool.setdefault(key, []) - staged = [pooled.pop() for _ in range(min(len(pooled), count))] - if len(staged) == count: - return staged - - missing = count - len(staged) - for _ in range(missing): - try: - staged.append( - torch.empty(shape, dtype=dtype, device="cpu", pin_memory=True) - ) - except RuntimeError: - # Graceful fallback for CPU-only / pin-memory-disabled setups. - logger.warning( - "Falling back to non-pinned CPU staging buffer " - "(shape=%s, dtype=%s)", - tuple(shape), - dtype, - ) - staged.append(torch.empty(shape, dtype=dtype, device="cpu")) - return staged - - def _release_staging(self, chunks: list[torch.Tensor]) -> None: - if not chunks: - return - key = (tuple(chunks[0].shape), chunks[0].dtype) - with self._inflight_lock: - self._staging_pool.setdefault(key, []).extend(chunks) def register( self, @@ -534,13 +402,10 @@ def register( pool_size=pool_size, ) supported_transfer_mode = "SHM" if shm_name and pool_size > 0 else "pickle" - self._init_async_capability() logger.info( - "Worker non-GPU transfer context registered " - "(instance_id=%d, mode=%s, store=%s)", + "Worker non-GPU transfer context registered (instance_id=%d, mode=%s)", instance_id, supported_transfer_mode, - "async" if self._async_capable else "sync", ) def submit_store( @@ -558,40 +423,7 @@ def submit_store( "Data transfer context is not registered. " "Call register() before submit_store()." ) - if self._async_capable: - return self._submit_store_async( - _request_id, - key, - instance_id, - kv_caches, - block_ids, - _event, - blocks_in_chunk, - ) - return self._submit_store_sync( - key, - instance_id, - kv_caches, - block_ids, - blocks_in_chunk, - ) - def _submit_store_sync( - self, - key: Any, - instance_id: int, - kv_caches: dict[str, torch.Tensor], - block_ids: list[list[int]], - blocks_in_chunk: int, - ) -> MessagingFuture: - """Original synchronous store path (capability fallback). - - Reproduces the pre-async behaviour exactly: synchronize, prepare, - gather, (SHM) synchronize, commit, and return an already-resolved - future. See :meth:`submit_store` for argument semantics (``key``, - ``instance_id``, ``kv_caches``, ``block_ids``, ``blocks_in_chunk``). - """ - assert self._non_gpu_context is not None torch_dev.synchronize() result = self._non_gpu_context.prepare_store(key, instance_id) out_buffers, chunk_indices = result if result is not None else (None, None) @@ -618,152 +450,6 @@ def _submit_store_sync( future.set_result(ok) return future - def _submit_store_async( - self, - _request_id: str, - key: Any, - instance_id: int, - kv_caches: dict[str, torch.Tensor], - block_ids: list[list[int]], - _event: IPCEvent, - blocks_in_chunk: int, - ) -> MessagingFuture: - """Two-phase async store path (gather on copy stream, deferred commit). - - Returns an unresolved future that resolves only after both gather - completion and the commit ACK. See :meth:`submit_store` for argument - semantics. - """ - completion: MessagingFuture[bool] = MessagingFuture() - non_gpu_context = self._non_gpu_context - semaphore = self._inflight_semaphore - commit_executor = self._commit_executor - assert non_gpu_context is not None - assert semaphore is not None - assert commit_executor is not None - - semaphore.acquire() - staged_chunks: list[torch.Tensor] = [] - shm_out_buffers: list[torch.Tensor] | None = None - gather_done: Any | None = None - try: - with self._inflight_lock: - if self._is_closing: - completion.set_result(False) - semaphore.release() - return completion - - result = non_gpu_context.prepare_store(key, instance_id) - out_buffers, chunk_indices = result if result is not None else (None, None) - if chunk_indices is not None and len(chunk_indices) == 0: - # All chunks are already in cache: no gather, no commit. - completion.set_result(True) - semaphore.release() - return completion - - full_block_ids = _single_group_block_ids(block_ids) - num_chunks = ( - len(chunk_indices) - if chunk_indices is not None - else len(full_block_ids) // blocks_in_chunk - ) - if not non_gpu_context.layout_desc.shapes: - raise RuntimeError("non-GPU layout_desc.shapes is empty") - if not non_gpu_context.layout_desc.dtypes: - raise RuntimeError("non-GPU layout_desc.dtypes is empty") - staged_chunks = self._alloc_pinned_staging( - non_gpu_context.layout_desc.shapes[0], - non_gpu_context.layout_desc.dtypes[0], - num_chunks, - ) - shm_out_buffers = out_buffers - with torch_dev.stream(self._copy_stream): - _event.wait(stream=self._copy_stream) - gather_paged_kv_to_cpu( - kv_caches, - full_block_ids, - blocks_in_chunk, - layout_hints=self._layout_hints, - gpu_kv_format=self._gpu_kv_format, - out=staged_chunks, - chunk_indices=chunk_indices, - ) - gather_done = torch_dev.Event() - gather_done.record(self._copy_stream) - - with self._inflight_lock: - if gather_done is not None: - self._inflight_gather_events.add(gather_done) - - def _commit_after_gather() -> None: - ok = False - try: - if gather_done is not None: - gather_done.synchronize() - if shm_out_buffers is not None: - if len(staged_chunks) != len(shm_out_buffers): - raise RuntimeError( - "SHM staging chunk count mismatch: " - f"{len(staged_chunks)} vs {len(shm_out_buffers)} " - f"(request_id={_request_id}, instance_id={instance_id})" - ) - for staged, shm_view in zip( - staged_chunks, shm_out_buffers, strict=True - ): - shm_view.copy_(staged) - ok = non_gpu_context.commit_store( - key, instance_id, shm_out_buffers - ) - else: - ok = non_gpu_context.commit_store( - key, instance_id, staged_chunks - ) - if not ok: - logger.error( - "Async non-GPU commit_store failed for request_id=%s", - _request_id, - ) - except Exception: - logger.exception( - "Async non-GPU store failed for request_id=%s", - _request_id, - ) - ok = False - finally: - self._release_staging(staged_chunks) - with self._inflight_lock: - if gather_done is not None: - self._inflight_gather_events.discard(gather_done) - completion.set_result(ok) - semaphore.release() - - # Submitting the commit task is the ownership-transfer point: once it - # succeeds, the commit task is solely responsible for releasing the - # semaphore, releasing staging buffers, and resolving the future. The - # except below therefore only handles failures that occur *before* - # this submit, so it can never double-release or double-resolve. - commit_future = commit_executor.submit(_commit_after_gather) - except Exception: - logger.exception("Failed to submit async non-GPU store") - if staged_chunks: - self._release_staging(staged_chunks) - if gather_done is not None: - with self._inflight_lock: - self._inflight_gather_events.discard(gather_done) - completion.set_result(False) - semaphore.release() - return completion - - with self._inflight_lock: - self._inflight_commits.add(commit_future) - - def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: - with self._inflight_lock: - self._inflight_commits.discard(done_future) - - commit_future.add_done_callback(_drop_commit_future) - return completion - def submit_retrieve( self, _request_id: str, @@ -807,33 +493,10 @@ def submit_retrieve( return future def close(self) -> None: - # Drain in-flight async work only when async resources were created. - # In sync (fallback) mode there is no copy stream / executor / inflight - # state, so guard against touching never-created attributes. - if self._async_capable: - with self._inflight_lock: - self._is_closing = True - gather_events = list(self._inflight_gather_events) - for event in gather_events: - try: - event.synchronize() - except Exception: - logger.exception("Failed while draining gather events") - if self._commit_executor is not None: - self._commit_executor.shutdown(wait=True, cancel_futures=False) if self._non_gpu_context is not None: self._non_gpu_context.close() self._non_gpu_context = None - def flush_inflight_gathers(self) -> None: - # Cheap no-op in sync mode: no copy stream / in-flight gather events. - if not self._async_capable: - return - with self._inflight_lock: - gather_events = list(self._inflight_gather_events) - for event in gather_events: - event.synchronize() - def create_transfer_context( kv_caches: dict[str, torch.Tensor], @@ -878,8 +541,59 @@ def create_transfer_context( if resolved_mode is MPTransferMode.HANDLE: return _build_handle_context(device_type) if resolved_mode is MPTransferMode.DATA: - return DataTransferContext() + return _build_data_context(kv_caches) # AUTO: preserve the historical device-type-based dispatch. if device_type == "cuda": return HandleTransferContext() + return _build_data_context(kv_caches) + + +def _supports_async_primitives(kv_caches: dict[str, torch.Tensor]) -> bool: + """Probe whether the worker device supports the async store primitives. + + The async non-GPU store path needs a stream, an event exposing + ``record``/``synchronize``/``wait``, and pinned (page-locked) host memory. + When any of these is unavailable (e.g. a CPU-only backend), the factory + falls back to the synchronous :class:`DataTransferContext`. This dispatch is + internal and capability-based; there is no user-facing async/sync flag. + + Args: + kv_caches: Worker KV cache tensors keyed by layer name. Currently unused + (capability is a property of ``torch_dev``), accepted to keep the + probe signature forward-compatible with device-specific checks. + + Returns: + True if all required async primitives are available, else False. + """ + if not hasattr(torch_dev, "Stream") or not hasattr(torch_dev, "Event"): + return False + try: + torch_dev.Stream() + event = torch_dev.Event() + except Exception: + return False + for attr in ("record", "synchronize", "wait"): + if not callable(getattr(event, attr, None)): + return False + try: + torch.empty(1, dtype=torch.uint8, device="cpu", pin_memory=True) + except (RuntimeError, TypeError): + return False + return True + + +def _build_data_context(kv_caches: dict[str, torch.Tensor]) -> "TransferContext": + """Build the non-GPU data context, async when device-capable else sync. + + Routes the ``DATA`` and AUTO non-cuda branches through a single capability + check. ``AsyncDataTransferContext`` is imported lazily to avoid an import + cycle and to keep the synchronous path free of stream/event dependencies. + """ + if _supports_async_primitives(kv_caches): + # Local + from lmcache.v1.multiprocess.transfer_context.async_data import ( + AsyncDataTransferContext, + ) + + return AsyncDataTransferContext() return DataTransferContext() diff --git a/tests/v1/multiprocess/test_async_data_transfer_context.py b/tests/v1/multiprocess/test_async_data_transfer_context.py index ca02a2ea0c..a2b14af5d1 100644 --- a/tests/v1/multiprocess/test_async_data_transfer_context.py +++ b/tests/v1/multiprocess/test_async_data_transfer_context.py @@ -12,7 +12,10 @@ import torch # First Party -from lmcache.v1.multiprocess.transfer_context import worker_transfer +from lmcache.v1.multiprocess.transfer_context import async_data, worker_transfer +from lmcache.v1.multiprocess.transfer_context.async_data import ( + AsyncDataTransferContext, +) from lmcache.v1.multiprocess.transfer_context.worker_transfer import DataTransferContext @@ -94,6 +97,9 @@ def _gather( tensor.fill_(1.0) return out + # The async path resolves ``gather_paged_kv_to_cpu`` from async_data, the + # sync path from worker_transfer; patch both so either is exercised. + monkeypatch.setattr(async_data, "gather_paged_kv_to_cpu", _gather) monkeypatch.setattr(worker_transfer, "gather_paged_kv_to_cpu", _gather) @@ -103,15 +109,11 @@ def _new_context( gather_gate: threading.Event, commit_impl: Callable[[list[torch.Tensor]], bool], max_inflight: int = 8, -) -> DataTransferContext: - monkeypatch.setattr(worker_transfer, "torch_dev", _FakeTorchDev(gather_gate)) +) -> AsyncDataTransferContext: + monkeypatch.setattr(async_data, "torch_dev", _FakeTorchDev(gather_gate)) _install_fake_gather(monkeypatch) - ctx = DataTransferContext(max_inflight_stores=max_inflight) + ctx = AsyncDataTransferContext(max_inflight_stores=max_inflight) ctx._non_gpu_context = _FakeStoreContext(commit_impl=commit_impl) - # Async tests exercise the async store path directly; enable it explicitly - # so the capability probe (which needs real pinned memory) is bypassed. - ctx._async_capable = True - ctx._create_async_resources() return ctx @@ -232,7 +234,7 @@ def _commit(_chunks: list[torch.Tensor]) -> bool: raise RuntimeError("commit failed") log_exception = MagicMock() - monkeypatch.setattr(worker_transfer.logger, "exception", log_exception) + monkeypatch.setattr(async_data.logger, "exception", log_exception) ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) future = ctx.submit_store( "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 @@ -243,6 +245,19 @@ def _commit(_chunks: list[torch.Tensor]) -> bool: ctx.close() +def test_flush_inflight_gathers_no_inflight_is_noop( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + gather_gate.set() + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=lambda _c: True + ) + # No in-flight events: flush is a cheap no-op and must not raise. + ctx.flush_inflight_gathers() + ctx.close() + + class _RecordingTorchDev: """torch_dev stub that records whether async primitives are touched.""" @@ -266,7 +281,7 @@ def Event(self, interprocess: bool = False) -> _FakeEvent: return _FakeEvent(threading.Event()) -def test_submit_store_sync_path_returns_resolved_future( +def test_sync_data_context_returns_resolved_future( monkeypatch: pytest.MonkeyPatch, ) -> None: fake = _RecordingTorchDev() @@ -274,7 +289,6 @@ def test_submit_store_sync_path_returns_resolved_future( _install_fake_gather(monkeypatch) ctx = DataTransferContext() ctx._non_gpu_context = _FakeStoreContext(commit_impl=lambda _c: True) - assert ctx._async_capable is False future = ctx.submit_store( "r1", @@ -292,68 +306,48 @@ def test_submit_store_sync_path_returns_resolved_future( assert fake.stream_calls == 0 assert fake.event_calls == 0 assert fake.synchronize_calls >= 1 + # flush_inflight_gathers is the inherited base no-op; neither it nor close + # must raise on the synchronous path. + ctx.flush_inflight_gathers() ctx.close() -def test_submit_store_dispatches_on_capability( - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_sync_data_context_has_no_async_resources() -> None: + # The synchronous DataTransferContext must not create any async resources + # and must not expose async-only attributes. ctx = DataTransferContext() - ctx._non_gpu_context = MagicMock() - async_mock = MagicMock(return_value="async") - sync_mock = MagicMock(return_value="sync") - monkeypatch.setattr(ctx, "_submit_store_async", async_mock) - monkeypatch.setattr(ctx, "_submit_store_sync", sync_mock) - - ctx._async_capable = True - assert ctx.submit_store("r", None, 1, {}, [[0]], None, 1) == "async" - async_mock.assert_called_once() - sync_mock.assert_not_called() - - async_mock.reset_mock() - ctx._async_capable = False - assert ctx.submit_store("r", None, 1, {}, [[0]], None, 1) == "sync" - sync_mock.assert_called_once() - async_mock.assert_not_called() + assert not hasattr(ctx, "_copy_stream") + assert not hasattr(ctx, "_commit_executor") + assert not hasattr(ctx, "_inflight_semaphore") + # close() on an unregistered sync context must not raise. + ctx.close() -def test_init_async_capability_non_capable_skips_resources( +def test_build_data_context_dispatches_on_capability( monkeypatch: pytest.MonkeyPatch, ) -> None: - # A torch_dev without Stream/Event is not async-capable. - monkeypatch.setattr(worker_transfer, "torch_dev", object()) - ctx = DataTransferContext() - assert ctx._copy_stream is None - assert ctx._commit_executor is None - - ctx._init_async_capability() - - assert ctx._async_capable is False - assert ctx._copy_stream is None - assert ctx._commit_executor is None - assert ctx._inflight_semaphore is None - # close() must not raise even though async resources were never created. - ctx.close() + kv_caches = {"k": torch.zeros(1)} + + monkeypatch.setattr(worker_transfer, "_supports_async_primitives", lambda _kv: True) + # Avoid touching real stream/event primitives when instantiating the async + # context returned by the capable branch. + monkeypatch.setattr(async_data, "torch_dev", _FakeTorchDev(threading.Event())) + capable = worker_transfer._build_data_context(kv_caches) + assert isinstance(capable, AsyncDataTransferContext) + capable.close() + + monkeypatch.setattr( + worker_transfer, "_supports_async_primitives", lambda _kv: False + ) + fallback = worker_transfer._build_data_context(kv_caches) + assert isinstance(fallback, DataTransferContext) + assert not isinstance(fallback, AsyncDataTransferContext) + fallback.close() -def test_init_async_capability_capable_creates_resources( +def test_supports_async_primitives_false_without_stream( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(worker_transfer, "torch_dev", _FakeTorchDev(threading.Event())) - # Make the pinned-memory probe succeed regardless of host CUDA support. - real_empty = torch.empty - - def _fake_empty(*args: object, **kwargs: object) -> torch.Tensor: - kwargs.pop("pin_memory", None) - return real_empty(*args, **kwargs) - - monkeypatch.setattr(worker_transfer.torch, "empty", _fake_empty) - - ctx = DataTransferContext() - ctx._init_async_capability() - - assert ctx._async_capable is True - assert ctx._copy_stream is not None - assert ctx._commit_executor is not None - assert ctx._inflight_semaphore is not None - ctx.close() + # A torch_dev without Stream/Event is not async-capable. + monkeypatch.setattr(worker_transfer, "torch_dev", object()) + assert worker_transfer._supports_async_primitives({"k": torch.zeros(1)}) is False From 2ba9c49b5d961c16f0eddf0811c316c89fd40458 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 6 Jun 2026 02:07:19 +0000 Subject: [PATCH 06/31] Improve AsyncDataTransferContext docstrings per review --- .../v1/multiprocess/transfer_context/async_data.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index b1fb3fd8c0..8be65dc5f4 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -33,6 +33,12 @@ class AsyncDataTransferContext(DataTransferContext): """Fully async non-GPU data transfer context (store-only async). + "Store-only async" means ``submit_store`` returns an *unresolved* future + that resolves only after the deferred gather (GPU->CPU copy) and commit + (CPU->server) both complete off the forward thread, while + ``submit_retrieve`` stays synchronous and returns an already-resolved + future exactly as on the base context. + Inherits :class:`DataTransferContext` and reuses its ``register()`` (layout / SHM registration, no stream dependency) and ``submit_retrieve()`` (this path does not change retrieve). Only the store is made async. @@ -265,6 +271,13 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: return completion def flush_inflight_gathers(self) -> None: + """Synchronize all in-flight gather (GPU->CPU) events. + + Called at preemption/eviction time (and during ``close``) so that vLLM + cannot overwrite paged KV blocks before a deferred gather has finished + reading them. Only gather completion is awaited; commit futures are not + affected, since commits read from LMCache-owned staging buffers. + """ with self._inflight_lock: gather_events = list(self._inflight_gather_events) for event in gather_events: From cd86a4e0c19bb3d61dbf2a8ab578a433248edcca Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Sat, 6 Jun 2026 11:21:19 +0800 Subject: [PATCH 07/31] test: align preemption-flush tests with real vLLM schema Drop the dead-field cases (resumed_from_preemption / evicted_req_ids), which do not exist on vLLM main's CachedRequestData / SchedulerOutput, and keep only the real signals (resumed_req_ids, preempted_req_ids) plus the conservative unknown-schema fallback. --- .../test_lmcache_mp_connector_preemption.py | 31 ++++--------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/tests/v1/test_lmcache_mp_connector_preemption.py b/tests/v1/test_lmcache_mp_connector_preemption.py index 1f0ea93611..5860ed254d 100644 --- a/tests/v1/test_lmcache_mp_connector_preemption.py +++ b/tests/v1/test_lmcache_mp_connector_preemption.py @@ -46,12 +46,8 @@ def test_build_connector_meta_keeps_need_flush_false_without_signal() -> None: connector._report_block_allocation_deltas = lambda scheduler_output: None scheduler_output = SimpleNamespace( - scheduled_cached_reqs=SimpleNamespace( - resumed_req_ids=[], - resumed_from_preemption=[], - ), + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]), preempted_req_ids=[], - evicted_req_ids=[], ) metadata = connector.build_connector_meta(scheduler_output) assert isinstance(metadata, LMCacheMPConnectorMetadata) @@ -61,27 +57,15 @@ def test_build_connector_meta_keeps_need_flush_false_without_signal() -> None: @pytest.mark.parametrize( "scheduler_output", [ + # Resumed-from-preemption signal (CachedRequestData.resumed_req_ids). SimpleNamespace( - scheduled_cached_reqs=SimpleNamespace( - resumed_req_ids=[], - resumed_from_preemption=[False, True], - ) + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=["req-1"]), ), + # Preempted-this-step signal (SchedulerOutput.preempted_req_ids). SimpleNamespace( - scheduled_cached_reqs=SimpleNamespace( - resumed_req_ids=[], - resumed_from_preemption=[], - ), + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]), preempted_req_ids=["req-1"], ), - SimpleNamespace( - scheduled_cached_reqs=SimpleNamespace( - resumed_req_ids=[], - resumed_from_preemption=[], - ), - preempted_req_ids=[], - evicted_req_ids=["req-2"], - ), ], ) def test_scheduler_step_needs_flush_for_all_supported_signals( @@ -117,9 +101,6 @@ def test_scheduler_step_needs_flush_conservative_on_unknown_schema() -> None: def test_scheduler_step_needs_flush_false_for_recognized_no_preemption() -> None: connector = _new_connector_without_init() scheduler_output = SimpleNamespace( - scheduled_cached_reqs=SimpleNamespace( - resumed_req_ids=[], - resumed_from_preemption=[], - ) + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]) ) assert connector._scheduler_step_needs_flush(scheduler_output) is False From 711f5851d7e3b1561faaf708a93a8c9f866168e2 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Sat, 6 Jun 2026 13:08:59 +0800 Subject: [PATCH 08/31] refactor: simplify preemption-flush detection to real vLLM fields _scheduler_step_needs_flush previously probed two fields that do not exist on vLLM main's schema: CachedRequestData.resumed_from_preemption (replaced by resumed_req_ids) and SchedulerOutput.evicted_req_ids (never existed). Those getattr checks were dead code and the comment was inaccurate. Verified against vLLM main (vllm/v1/core/sched/output.py): - CachedRequestData.resumed_req_ids: set[str] -> real resume signal - SchedulerOutput.preempted_req_ids: set[str] | None -> real preempt signal (populated unconditionally in scheduler.py) Keep only those two real signals plus the conservative unknown-schema fallback (flush when scheduled_cached_reqs lacks resumed_req_ids). This matches the test cleanup in the previous commit; behavior on real vLLM is unchanged. --- .../integration/vllm/lmcache_mp_connector.py | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 1b2ef14fea..988fb313cd 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -999,35 +999,29 @@ def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool Under-syncing here risks KV-block corruption (a paged block may be overwritten before a deferred async gather reads it), while over-syncing only costs performance, so we prefer a spurious flush over a missed one. + + Signal fields are verified against vLLM main: + - ``CachedRequestData.resumed_req_ids``: requests resumed from + preemption this step. Their blocks are replaced (not appended), so an + in-flight gather against the old blocks must be flushed first. + - ``SchedulerOutput.preempted_req_ids``: requests preempted this step. + Their blocks are freed and may be reused, so the same applies. """ - # In vLLM v1, resumed preemptions are surfaced on the cached-request - # struct via ``resumed_req_ids`` / ``resumed_from_preemption``. cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) - resumed_req_ids = getattr(cached_reqs, "resumed_req_ids", None) - if resumed_req_ids: - return True - resumed_flags = getattr(cached_reqs, "resumed_from_preemption", None) - if resumed_flags and any(resumed_flags): - return True - # Resilient optional checks: these fields may not exist on every vLLM - # ``SchedulerOutput`` version; getattr keeps the probe harmless when so. - preempted_req_ids = getattr(scheduler_output, "preempted_req_ids", None) - if preempted_req_ids: + + # Primary signal: requests resumed from preemption this step. + if getattr(cached_reqs, "resumed_req_ids", None): return True - evicted_req_ids = getattr(scheduler_output, "evicted_req_ids", None) - if evicted_req_ids: + + # Primary signal: requests preempted this step. + if getattr(scheduler_output, "preempted_req_ids", None): return True - # Conservative fallback: a recognized cached-request schema is expected - # to expose at least one of ``resumed_req_ids`` / - # ``resumed_from_preemption`` (older vLLM exposes the former, newer the - # latter; they are version-dependent alternatives, not required to - # co-exist). If cached requests are present but expose neither, the - # schema is unrecognized and we cannot prove the step is - # preemption-free, so flush rather than risk corruption. - if cached_reqs is not None and not ( - hasattr(cached_reqs, "resumed_req_ids") - or hasattr(cached_reqs, "resumed_from_preemption") - ): + + # Conservative fallback: if cached requests are present but the schema + # exposes no recognized ``resumed_req_ids`` field (e.g. a much older or + # forked vLLM), we cannot prove the step is preemption-free, so flush + # rather than risk corruption. + if cached_reqs is not None and not hasattr(cached_reqs, "resumed_req_ids"): logger.warning( "Unrecognized scheduled_cached_reqs schema (%s); conservatively " "flushing in-flight async gathers to avoid KV block corruption.", From 0dee0ac1e4437449bda93cfbe147e999fc5b85c0 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Mon, 15 Jun 2026 12:40:53 +0800 Subject: [PATCH 09/31] fix: wrap shm_view.copy_ with inference_mode(False) to avoid InferenceMode error in commit thread PyTorch's InferenceMode propagates to child threads. The commit thread inherits InferenceMode from the vLLM EngineCore main thread, causing `shm_view.copy_(staged)` to raise: "Inplace update to inference tensor outside InferenceMode is not allowed" Fix by explicitly exiting InferenceMode for the inplace copy operation. --- .../v1/multiprocess/transfer_context/async_data.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 8be65dc5f4..dc72a2e9a2 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -213,10 +213,13 @@ def _commit_after_gather() -> None: f"{len(staged_chunks)} vs {len(shm_out_buffers)} " f"(request_id={_request_id}, instance_id={instance_id})" ) - for staged, shm_view in zip( - staged_chunks, shm_out_buffers, strict=True - ): - shm_view.copy_(staged) + # Exit InferenceMode inherited from the vLLM main + # thread — inplace copy_ is disallowed under it. + with torch.inference_mode(False): + for staged, shm_view in zip( + staged_chunks, shm_out_buffers, strict=True + ): + shm_view.copy_(staged) ok = non_gpu_context.commit_store( key, instance_id, shm_out_buffers ) From 471b84b5247e214117fd991f1d86f84d061ea7b0 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Mon, 15 Jun 2026 12:42:52 +0800 Subject: [PATCH 10/31] fix: gather directly into SHM view when available, eliminating redundant staging copy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When SHM out_buffers are available from prepare_store(), gather directly into them on the copy stream — matching the synchronous DataTransferContext behavior. This removes: 1. The redundant pinned staging buffer allocation for SHM path 2. The staged→shm_view copy_ in the commit thread 3. The InferenceMode error caused by that copy_ Only the pickle path (no SHM) still uses pinned staging buffers. --- .../transfer_context/async_data.py | 79 +++++++++---------- 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index dc72a2e9a2..f0c776ce17 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -44,18 +44,15 @@ class AsyncDataTransferContext(DataTransferContext): path does not change retrieve). Only the store is made async. Store is two-phase: - 1) gather: enqueue GPU->CPU copies on a dedicated copy stream into - LMCache-owned pinned staging buffers (ordered behind the per-step event). + 1) gather: enqueue GPU->CPU copies on a dedicated copy stream. When SHM + buffers are available, gather writes directly into SHM views (matching + the synchronous path). Otherwise, gather targets pinned staging buffers. 2) commit: wait for gather completion in a background thread, then perform - commit_store() (pickle or SHM commit) and resolve the returned future. + commit_store() and resolve the returned future. This class is only instantiated by the factory when the device is async-capable, so the constructor creates async resources unconditionally; there is no ``self._async_capable`` flag. - - SHM note: SHM slots are generally pageable, so device->SHM DtoH copies may - implicitly synchronize. To keep gather async, we always gather into pinned - bounce buffers first, then copy to SHM slots on the commit thread. """ def __init__( @@ -150,8 +147,10 @@ def submit_store( semaphore.acquire() staged_chunks: list[torch.Tensor] = [] - shm_out_buffers: list[torch.Tensor] | None = None gather_done: Any | None = None + # Whether we gathered directly into SHM views (True) or into + # pinned staging buffers that need to be released later (False). + used_shm_direct = False try: with self._inflight_lock: if self._is_closing: @@ -173,16 +172,27 @@ def submit_store( if chunk_indices is not None else len(full_block_ids) // blocks_in_chunk ) - if not non_gpu_context.layout_desc.shapes: - raise RuntimeError("non-GPU layout_desc.shapes is empty") - if not non_gpu_context.layout_desc.dtypes: - raise RuntimeError("non-GPU layout_desc.dtypes is empty") - staged_chunks = self._alloc_pinned_staging( - non_gpu_context.layout_desc.shapes[0], - non_gpu_context.layout_desc.dtypes[0], - num_chunks, - ) - shm_out_buffers = out_buffers + + # Determine gather target: + # - SHM path (out_buffers available): gather directly into SHM views + # - Pickle path (no out_buffers): gather into pinned staging buffers + if out_buffers is not None: + # SHM path: gather directly into SHM views, no staging needed. + gather_target = out_buffers + used_shm_direct = True + else: + # Pickle path: allocate pinned staging buffers. + if not non_gpu_context.layout_desc.shapes: + raise RuntimeError("non-GPU layout_desc.shapes is empty") + if not non_gpu_context.layout_desc.dtypes: + raise RuntimeError("non-GPU layout_desc.dtypes is empty") + staged_chunks = self._alloc_pinned_staging( + non_gpu_context.layout_desc.shapes[0], + non_gpu_context.layout_desc.dtypes[0], + num_chunks, + ) + gather_target = staged_chunks + with torch_dev.stream(self._copy_stream): _event.wait(stream=self._copy_stream) gather_paged_kv_to_cpu( @@ -191,7 +201,7 @@ def submit_store( blocks_in_chunk, layout_hints=self._layout_hints, gpu_kv_format=self._gpu_kv_format, - out=staged_chunks, + out=gather_target, chunk_indices=chunk_indices, ) gather_done = torch_dev.Event() @@ -201,32 +211,18 @@ def submit_store( if gather_done is not None: self._inflight_gather_events.add(gather_done) + # Capture variables for the closure + _used_shm_direct = used_shm_direct + _gather_target = gather_target + def _commit_after_gather() -> None: ok = False try: if gather_done is not None: gather_done.synchronize() - if shm_out_buffers is not None: - if len(staged_chunks) != len(shm_out_buffers): - raise RuntimeError( - "SHM staging chunk count mismatch: " - f"{len(staged_chunks)} vs {len(shm_out_buffers)} " - f"(request_id={_request_id}, instance_id={instance_id})" - ) - # Exit InferenceMode inherited from the vLLM main - # thread — inplace copy_ is disallowed under it. - with torch.inference_mode(False): - for staged, shm_view in zip( - staged_chunks, shm_out_buffers, strict=True - ): - shm_view.copy_(staged) - ok = non_gpu_context.commit_store( - key, instance_id, shm_out_buffers - ) - else: - ok = non_gpu_context.commit_store( - key, instance_id, staged_chunks - ) + ok = non_gpu_context.commit_store( + key, instance_id, _gather_target + ) if not ok: logger.error( "Async non-GPU commit_store failed for request_id=%s", @@ -239,7 +235,8 @@ def _commit_after_gather() -> None: ) ok = False finally: - self._release_staging(staged_chunks) + if not _used_shm_direct: + self._release_staging(staged_chunks) with self._inflight_lock: if gather_done is not None: self._inflight_gather_events.discard(gather_done) From 8be4642885bb3b955a2275e8d24c20e9c21df7e8 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 16 Jun 2026 05:31:04 +0000 Subject: [PATCH 11/31] add logs Signed-off-by: Tony Lin --- .../multiprocess/transfer_context/async_data.py | 17 ++++++++++++++++- .../transfer_context/worker_transfer.py | 8 +++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index f0c776ce17..826de356bf 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -193,8 +193,13 @@ def submit_store( ) gather_target = staged_chunks + import time + with torch_dev.stream(self._copy_stream): + t1 = time.perf_counter() _event.wait(stream=self._copy_stream) + t2 = time.perf_counter() + gather_paged_kv_to_cpu( kv_caches, full_block_ids, @@ -204,9 +209,19 @@ def submit_store( out=gather_target, chunk_indices=chunk_indices, ) + t3 = time.perf_counter() + gather_done = torch_dev.Event() gather_done.record(self._copy_stream) - + t4 = time.perf_counter() + # Print intervals in milliseconds (ms) + logger.info( + "[Store Profiler] wait: %.3f ms | gather_to_cpu: %.3f ms | record_event: %.3f ms | total: %.3f ms", + (t2 - t1) * 1000, + (t3 - t2) * 1000, + (t4 - t3) * 1000, + (t4 - t1) * 1000 + ) with self._inflight_lock: if gather_done is not None: self._inflight_gather_events.add(gather_done) diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 59fcec73c1..8a60a01805 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -543,8 +543,8 @@ def create_transfer_context( if resolved_mode is MPTransferMode.DATA: return _build_data_context(kv_caches) # AUTO: preserve the historical device-type-based dispatch. - if device_type == "cuda": - return HandleTransferContext() + # if device_type == "cuda": + # return HandleTransferContext() return _build_data_context(kv_caches) @@ -594,6 +594,8 @@ def _build_data_context(kv_caches: dict[str, torch.Tensor]) -> "TransferContext" from lmcache.v1.multiprocess.transfer_context.async_data import ( AsyncDataTransferContext, ) - + logger.info(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> AsyncDataTransferContext ") return AsyncDataTransferContext() + + logger.info(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> SyncDataTransferContext ") return DataTransferContext() From 1dd3bd7ce17b14c333981249c55a3b1fd881c546 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 05:32:05 +0000 Subject: [PATCH 12/31] Fix SHM worker host registration --- .../v1/multiprocess/transfer_context/shm.py | 78 ++++++++++- .../test_non_cuda_data_transfer.py | 125 ++++++++++++++++++ 2 files changed, 200 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/shm.py b/lmcache/v1/multiprocess/transfer_context/shm.py index 0178d27faf..0a549f1cdc 100644 --- a/lmcache/v1/multiprocess/transfer_context/shm.py +++ b/lmcache/v1/multiprocess/transfer_context/shm.py @@ -2,6 +2,7 @@ """Shared-memory NonGpuContext implementation for multiprocess mode.""" # Standard +import ctypes from dataclasses import dataclass from multiprocessing import shared_memory from multiprocessing.resource_tracker import unregister @@ -11,6 +12,8 @@ import torch # First Party +from lmcache import torch_dev +from lmcache.logging import init_logger from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType, get_response_class @@ -19,6 +22,8 @@ NonGpuContextMetadata, ) +logger = init_logger(__name__) + @dataclass(frozen=True) class ShmSlotDescriptor: @@ -92,6 +97,9 @@ def __init__( self._pool_size = pool_size self._shm: shared_memory.SharedMemory | None = None self._shm_buffer: memoryview | None = None + self._pinned = False + self._pinned_ptr = 0 + self._pinned_size = 0 try: self._shm = shared_memory.SharedMemory( name=shm_name.lstrip("/"), create=False @@ -101,6 +109,7 @@ def __init__( # unlink the segment when this worker exits. unregister(f"/{self._shm.name}", "shared_memory") self._shm_buffer = self._shm.buf + self._register_shm_buffer() except Exception: self._shm = None self._shm_buffer = None @@ -212,7 +221,70 @@ def close(self) -> None: if self._shm is None: return try: - self._shm.close() + self._unregister_shm_buffer() finally: - self._shm = None - self._shm_buffer = None + try: + self._shm.close() + finally: + self._shm = None + self._shm_buffer = None + + def _register_shm_buffer(self) -> None: + if self._shm_buffer is None or not torch_dev.is_available(): + return + if not hasattr(torch_dev, "cudart"): + logger.warning( + "Skipping SHM host registration for shm_name=%s: " + "backend does not support cudart(); D2H copies will be synchronous", + self._shm_name, + ) + return + try: + ptr = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_buffer)) + err = torch_dev.cudart().cudaHostRegister(ptr, self._pool_size, 0) + except Exception as exc: + logger.warning( + "Failed to register SHM buffer for shm_name=%s: %s; " + "D2H copies will be synchronous", + self._shm_name, + exc, + ) + return + if err != 0: + logger.warning( + "cudaHostRegister failed for shm_name=%s (ptr=%d, size=%d, err=%s); " + "D2H copies will be synchronous", + self._shm_name, + ptr, + self._pool_size, + err, + ) + return + self._pinned = True + self._pinned_ptr = ptr + self._pinned_size = self._pool_size + + def _unregister_shm_buffer(self) -> None: + if not self._pinned or self._pinned_ptr == 0: + return + try: + err = torch_dev.cudart().cudaHostUnregister(self._pinned_ptr) + if err != 0: + logger.warning( + "cudaHostUnregister failed for shm_name=%s (ptr=%d, size=%d, " + "err=%s)", + self._shm_name, + self._pinned_ptr, + self._pinned_size, + err, + ) + except Exception as exc: + logger.warning( + "Failed to unregister SHM buffer for shm_name=%s: %s", + self._shm_name, + exc, + ) + finally: + self._pinned = False + self._pinned_ptr = 0 + self._pinned_size = 0 diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py index c60290b917..83a1b66a97 100644 --- a/tests/v1/multiprocess/test_non_cuda_data_transfer.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -1162,3 +1162,128 @@ def test_non_gpu_context_shm_close_is_idempotent() -> None: finally: if os.path.exists(shm_path): os.unlink(shm_path) + + +def test_non_gpu_context_shm_registers_and_unregisters_host_memory( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 0 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + ptr, size, flags = fake_cudart.register_calls[0] + assert ptr > 0 + assert size == 4096 + assert flags == 0 + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [ptr] + + +def test_non_gpu_context_shm_register_failure_warns_and_skips_unregister( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_fail_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 1 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + with patch.object(shm_module.logger, "warning") as warning_mock: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [] + warning_mock.assert_called_once() + assert "cudaHostRegister failed" in warning_mock.call_args[0][0] From b0374ad7336d0950a7fcefa8fa6505525fa9f18c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 05:36:00 +0000 Subject: [PATCH 13/31] Polish SHM pinning validation logs --- lmcache/v1/multiprocess/transfer_context/shm.py | 4 ++-- tests/v1/multiprocess/test_non_cuda_data_transfer.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/shm.py b/lmcache/v1/multiprocess/transfer_context/shm.py index 0a549f1cdc..7f833ba656 100644 --- a/lmcache/v1/multiprocess/transfer_context/shm.py +++ b/lmcache/v1/multiprocess/transfer_context/shm.py @@ -244,7 +244,7 @@ def _register_shm_buffer(self) -> None: err = torch_dev.cudart().cudaHostRegister(ptr, self._pool_size, 0) except Exception as exc: logger.warning( - "Failed to register SHM buffer for shm_name=%s: %s; " + "Failed to register SHM buffer for shm_name=%s: %r; " "D2H copies will be synchronous", self._shm_name, exc, @@ -280,7 +280,7 @@ def _unregister_shm_buffer(self) -> None: ) except Exception as exc: logger.warning( - "Failed to unregister SHM buffer for shm_name=%s: %s", + "Failed to unregister SHM buffer for shm_name=%s: %r", self._shm_name, exc, ) diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py index 83a1b66a97..1aaa38b6f7 100644 --- a/tests/v1/multiprocess/test_non_cuda_data_transfer.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -1286,4 +1286,10 @@ def cudart(self) -> FakeCudaRt: assert fake_cudart.unregister_calls == [] warning_mock.assert_called_once() - assert "cudaHostRegister failed" in warning_mock.call_args[0][0] + message, logged_shm_name, _logged_ptr, logged_size, logged_err = ( + warning_mock.call_args[0] + ) + assert "cudaHostRegister failed" in message + assert logged_shm_name == shm_name + assert logged_size == 4096 + assert logged_err == 1 From 4aa4c12ddc7328984cc128870648b278ad559c99 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 16 Jun 2026 06:17:31 +0000 Subject: [PATCH 14/31] add logs Signed-off-by: Tony Lin --- .../transfer_context/async_data.py | 20 +++++++++++-------- .../v1/multiprocess/transfer_context/shm.py | 1 + .../transfer_context/worker_transfer.py | 7 +++++-- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 826de356bf..61f886a3d8 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -193,12 +193,14 @@ def submit_store( ) gather_target = staged_chunks + # Standard import time + t00 = time.perf_counter() with torch_dev.stream(self._copy_stream): - t1 = time.perf_counter() _event.wait(stream=self._copy_stream) - t2 = time.perf_counter() + torch_dev.synchronize() + t1 = time.perf_counter() gather_paged_kv_to_cpu( kv_caches, @@ -209,6 +211,8 @@ def submit_store( out=gather_target, chunk_indices=chunk_indices, ) + t2 = time.perf_counter() + torch_dev.synchronize() t3 = time.perf_counter() gather_done = torch_dev.Event() @@ -216,12 +220,14 @@ def submit_store( t4 = time.perf_counter() # Print intervals in milliseconds (ms) logger.info( - "[Store Profiler] wait: %.3f ms | gather_to_cpu: %.3f ms | record_event: %.3f ms | total: %.3f ms", + "[Store Profiler] launch: %.3f ms | gpu_exec: %.3f ms | total: %.3f ms", (t2 - t1) * 1000, (t3 - t2) * 1000, - (t4 - t3) * 1000, - (t4 - t1) * 1000 + (t3 - t1) * 1000, ) + t11 = time.perf_counter() + logger.info("[Store Profiler] submit block time: %.3f ms", (t11 - t00) * 1000) + with self._inflight_lock: if gather_done is not None: self._inflight_gather_events.add(gather_done) @@ -235,9 +241,7 @@ def _commit_after_gather() -> None: try: if gather_done is not None: gather_done.synchronize() - ok = non_gpu_context.commit_store( - key, instance_id, _gather_target - ) + ok = non_gpu_context.commit_store(key, instance_id, _gather_target) if not ok: logger.error( "Async non-GPU commit_store failed for request_id=%s", diff --git a/lmcache/v1/multiprocess/transfer_context/shm.py b/lmcache/v1/multiprocess/transfer_context/shm.py index 7f833ba656..025b80e4ad 100644 --- a/lmcache/v1/multiprocess/transfer_context/shm.py +++ b/lmcache/v1/multiprocess/transfer_context/shm.py @@ -110,6 +110,7 @@ def __init__( unregister(f"/{self._shm.name}", "shared_memory") self._shm_buffer = self._shm.buf self._register_shm_buffer() + logger.info("SHM pinned=%s for shm_name=%s", self._pinned, self._shm_name) except Exception: self._shm = None self._shm_buffer = None diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 8a60a01805..cbf715feaa 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -590,11 +590,14 @@ def _build_data_context(kv_caches: dict[str, torch.Tensor]) -> "TransferContext" cycle and to keep the synchronous path free of stream/event dependencies. """ if _supports_async_primitives(kv_caches): - # Local + # First Party from lmcache.v1.multiprocess.transfer_context.async_data import ( AsyncDataTransferContext, ) - logger.info(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> AsyncDataTransferContext ") + + logger.info( + " >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> AsyncDataTransferContext " + ) return AsyncDataTransferContext() logger.info(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> SyncDataTransferContext ") From b68992bcfc7ee94eb934d2f715ec4f03fa852628 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 06:32:43 +0000 Subject: [PATCH 15/31] perf: move gather phase to background thread in AsyncDataTransferContext Previously, submit_store performed the gather kernel launch (including _event.wait() and gather_paged_kv_to_cpu()) directly on the forward thread. When the copy stream has a pending event-wait (for the forward pass to finish), CUDA runtime throttles the CPU as kernels queue up on a stream with unresolved dependencies, blocking the forward thread for ~38ms on every store. This commit moves the entire gather phase into the background _commit_after_gather thread via the commit_executor. The forward thread now only does lightweight preparation (prepare_store, buffer allocation) and immediately submits the work and returns. Background thread now: 1. Acquires copy stream context 2. Inserts event-level wait for forward completion 3. Launches gather_paged_kv_to_cpu() 4. Records gather_done event on copy stream 5. Adds gather_done to _inflight_gather_events (under lock) 6. Synchronizes gather_done (waits for GPU gather to finish) 7. Calls commit_store() and resolves the future Also removes profiling remnants: import time, t00/t1/t2/t3/t4/t11 timing variables, Store Profiler logger.info calls, and the two torch_dev.synchronize() calls that were added for profiling only. --- .../transfer_context/async_data.py | 90 ++++++++----------- 1 file changed, 38 insertions(+), 52 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 61f886a3d8..6fc5183658 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -43,12 +43,18 @@ class AsyncDataTransferContext(DataTransferContext): / SHM registration, no stream dependency) and ``submit_retrieve()`` (this path does not change retrieve). Only the store is made async. - Store is two-phase: - 1) gather: enqueue GPU->CPU copies on a dedicated copy stream. When SHM - buffers are available, gather writes directly into SHM views (matching - the synchronous path). Otherwise, gather targets pinned staging buffers. - 2) commit: wait for gather completion in a background thread, then perform - commit_store() and resolve the returned future. + Store is two-phase, both executed entirely in a background thread: + 1) gather: wait for the forward event on the copy stream, then enqueue + GPU->CPU copies. When SHM buffers are available, gather writes directly + into SHM views (matching the synchronous path). Otherwise, gather + targets pinned staging buffers. + 2) commit: wait for gather completion (via a recorded CUDA event), then + perform commit_store() and resolve the returned future. + + ``submit_store`` only performs lightweight preparation (prepare_store, + buffer allocation) on the forward thread and immediately submits all + GPU/copy work to the background ``commit_executor``, so the forward thread + is never blocked by gather kernel launch latency. This class is only instantiated by the factory when the device is async-capable, so the constructor creates async resources unconditionally; @@ -130,10 +136,12 @@ def submit_store( _event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: - """Two-phase async store (gather on copy stream, deferred commit). + """Two-phase async store (gather and commit both in background thread). - Returns an unresolved future that resolves only after both gather - completion and the commit ACK. + Performs lightweight preparation (prepare_store, buffer allocation) on + the forward thread and immediately submits the gather + commit work to + the background ``commit_executor``. Returns an unresolved future that + resolves only after both gather completion and the commit ACK. """ if self._non_gpu_context is None: raise RuntimeError( @@ -147,7 +155,6 @@ def submit_store( semaphore.acquire() staged_chunks: list[torch.Tensor] = [] - gather_done: Any | None = None # Whether we gathered directly into SHM views (True) or into # pinned staging buffers that need to be released later (False). used_shm_direct = False @@ -193,52 +200,34 @@ def submit_store( ) gather_target = staged_chunks - # Standard - import time - t00 = time.perf_counter() - - with torch_dev.stream(self._copy_stream): - _event.wait(stream=self._copy_stream) - torch_dev.synchronize() - t1 = time.perf_counter() - - gather_paged_kv_to_cpu( - kv_caches, - full_block_ids, - blocks_in_chunk, - layout_hints=self._layout_hints, - gpu_kv_format=self._gpu_kv_format, - out=gather_target, - chunk_indices=chunk_indices, - ) - t2 = time.perf_counter() - torch_dev.synchronize() - t3 = time.perf_counter() - - gather_done = torch_dev.Event() - gather_done.record(self._copy_stream) - t4 = time.perf_counter() - # Print intervals in milliseconds (ms) - logger.info( - "[Store Profiler] launch: %.3f ms | gpu_exec: %.3f ms | total: %.3f ms", - (t2 - t1) * 1000, - (t3 - t2) * 1000, - (t3 - t1) * 1000, - ) - t11 = time.perf_counter() - logger.info("[Store Profiler] submit block time: %.3f ms", (t11 - t00) * 1000) - - with self._inflight_lock: - if gather_done is not None: - self._inflight_gather_events.add(gather_done) - # Capture variables for the closure _used_shm_direct = used_shm_direct _gather_target = gather_target def _commit_after_gather() -> None: + gather_done: Any | None = None ok = False try: + with torch_dev.stream(self._copy_stream): + _event.wait(stream=self._copy_stream) + + gather_paged_kv_to_cpu( + kv_caches, + full_block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=_gather_target, + chunk_indices=chunk_indices, + ) + + gather_done = torch_dev.Event() + gather_done.record(self._copy_stream) + + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.add(gather_done) + if gather_done is not None: gather_done.synchronize() ok = non_gpu_context.commit_store(key, instance_id, _gather_target) @@ -272,9 +261,6 @@ def _commit_after_gather() -> None: logger.exception("Failed to submit async non-GPU store") if staged_chunks: self._release_staging(staged_chunks) - if gather_done is not None: - with self._inflight_lock: - self._inflight_gather_events.discard(gather_done) completion.set_result(False) semaphore.release() return completion From 9a3574a752a90e828e0913f301fa88c33ff13f71 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 16 Jun 2026 08:05:01 +0000 Subject: [PATCH 16/31] remove semaphone Signed-off-by: Tony Lin --- .../transfer_context/async_data.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 6fc5183658..23df512b52 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -83,7 +83,6 @@ def __init__( max_workers=self._commit_workers, thread_name_prefix="lmcache_non_gpu_commit", ) - self._inflight_semaphore = threading.BoundedSemaphore(self._max_inflight_stores) self._inflight_lock = threading.Lock() self._inflight_gather_events: set[Any] = set() self._inflight_commits: set[ConcurrentFuture[None]] = set() @@ -143,6 +142,8 @@ def submit_store( the background ``commit_executor``. Returns an unresolved future that resolves only after both gather completion and the commit ACK. """ + import time + _t_entry = time.perf_counter() if self._non_gpu_context is None: raise RuntimeError( "Data transfer context is not registered. " @@ -150,10 +151,8 @@ def submit_store( ) completion: MessagingFuture[bool] = MessagingFuture() non_gpu_context = self._non_gpu_context - semaphore = self._inflight_semaphore commit_executor = self._commit_executor - semaphore.acquire() staged_chunks: list[torch.Tensor] = [] # Whether we gathered directly into SHM views (True) or into # pinned staging buffers that need to be released later (False). @@ -162,7 +161,6 @@ def submit_store( with self._inflight_lock: if self._is_closing: completion.set_result(False) - semaphore.release() return completion result = non_gpu_context.prepare_store(key, instance_id) @@ -170,7 +168,6 @@ def submit_store( if chunk_indices is not None and len(chunk_indices) == 0: # All chunks are already in cache: no gather, no commit. completion.set_result(True) - semaphore.release() return completion full_block_ids = _single_group_block_ids(block_ids) @@ -208,7 +205,7 @@ def _commit_after_gather() -> None: gather_done: Any | None = None ok = False try: - with torch_dev.stream(self._copy_stream): + with torch.inference_mode(), torch_dev.stream(self._copy_stream): _event.wait(stream=self._copy_stream) gather_paged_kv_to_cpu( @@ -249,20 +246,18 @@ def _commit_after_gather() -> None: if gather_done is not None: self._inflight_gather_events.discard(gather_done) completion.set_result(ok) - semaphore.release() # Submitting the commit task is the ownership-transfer point: once it # succeeds, the commit task is solely responsible for releasing the - # semaphore, releasing staging buffers, and resolving the future. The - # except below therefore only handles failures that occur *before* - # this submit, so it can never double-release or double-resolve. + # staging buffers, and resolving the future. The except below therefore + # only handles failures that occur *before* this submit, so it can never + # double-release or double-resolve. commit_future = commit_executor.submit(_commit_after_gather) except Exception: logger.exception("Failed to submit async non-GPU store") if staged_chunks: self._release_staging(staged_chunks) completion.set_result(False) - semaphore.release() return completion with self._inflight_lock: @@ -273,6 +268,8 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: self._inflight_commits.discard(done_future) commit_future.add_done_callback(_drop_commit_future) + logger.info("[submit_store] forward thread returned at %.3f ms since entry", + (time.perf_counter() - _t_entry) * 1000) return completion def flush_inflight_gathers(self) -> None: From 98eff82726a135cac2e280abfef750f4111411dc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:05:12 +0000 Subject: [PATCH 17/31] Add comprehensive profiling instrumentation to async_data.py --- .../transfer_context/async_data.py | 77 ++++++++++++++++++- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 23df512b52..853e4366db 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any import threading +import time # Third Party import torch @@ -142,7 +143,6 @@ def submit_store( the background ``commit_executor``. Returns an unresolved future that resolves only after both gather completion and the commit ACK. """ - import time _t_entry = time.perf_counter() if self._non_gpu_context is None: raise RuntimeError( @@ -158,12 +158,16 @@ def submit_store( # pinned staging buffers that need to be released later (False). used_shm_direct = False try: + _t0 = time.perf_counter() with self._inflight_lock: if self._is_closing: completion.set_result(False) return completion + _t_lock = time.perf_counter() result = non_gpu_context.prepare_store(key, instance_id) + _t_prepare = time.perf_counter() + out_buffers, chunk_indices = result if result is not None else (None, None) if chunk_indices is not None and len(chunk_indices) == 0: # All chunks are already in cache: no gather, no commit. @@ -171,6 +175,8 @@ def submit_store( return completion full_block_ids = _single_group_block_ids(block_ids) + _t_block_ids = time.perf_counter() + num_chunks = ( len(chunk_indices) if chunk_indices is not None @@ -196,17 +202,24 @@ def submit_store( num_chunks, ) gather_target = staged_chunks + _t_alloc = time.perf_counter() # Capture variables for the closure _used_shm_direct = used_shm_direct _gather_target = gather_target + _t_submit_start = time.perf_counter() def _commit_after_gather() -> None: + _tb_entry = time.perf_counter() gather_done: Any | None = None ok = False try: + _tb0 = time.perf_counter() with torch.inference_mode(), torch_dev.stream(self._copy_stream): + _tb_stream_enter = time.perf_counter() + _event.wait(stream=self._copy_stream) + _tb_event_wait = time.perf_counter() gather_paged_kv_to_cpu( kv_caches, @@ -217,22 +230,49 @@ def _commit_after_gather() -> None: out=_gather_target, chunk_indices=chunk_indices, ) + _tb_gather = time.perf_counter() gather_done = torch_dev.Event() gather_done.record(self._copy_stream) + _tb_record = time.perf_counter() + + _tb_stream_exit = time.perf_counter() with self._inflight_lock: if gather_done is not None: self._inflight_gather_events.add(gather_done) + _tb_lock = time.perf_counter() if gather_done is not None: gather_done.synchronize() + _tb_sync = time.perf_counter() + ok = non_gpu_context.commit_store(key, instance_id, _gather_target) + _tb_commit = time.perf_counter() + if not ok: logger.error( "Async non-GPU commit_store failed for request_id=%s", _request_id, ) + + logger.info( + "[BG %s] thread_start=%.3f stream_enter=%.3f " + "event_wait=%.3f gather_launch=%.3f record=%.3f " + "stream_exit=%.3f lock=%.3f sync=%.3f commit=%.3f " + "total=%.3f ms", + _request_id, + (_tb0 - _tb_entry) * 1000, + (_tb_stream_enter - _tb0) * 1000, + (_tb_event_wait - _tb_stream_enter) * 1000, + (_tb_gather - _tb_event_wait) * 1000, + (_tb_record - _tb_gather) * 1000, + (_tb_stream_exit - _tb_record) * 1000, + (_tb_lock - _tb_stream_exit) * 1000, + (_tb_sync - _tb_lock) * 1000, + (_tb_commit - _tb_sync) * 1000, + (_tb_commit - _tb_entry) * 1000, + ) except Exception: logger.exception( "Async non-GPU store failed for request_id=%s", @@ -253,6 +293,7 @@ def _commit_after_gather() -> None: # only handles failures that occur *before* this submit, so it can never # double-release or double-resolve. commit_future = commit_executor.submit(_commit_after_gather) + _t_submit_end = time.perf_counter() except Exception: logger.exception("Failed to submit async non-GPU store") if staged_chunks: @@ -268,8 +309,22 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: self._inflight_commits.discard(done_future) commit_future.add_done_callback(_drop_commit_future) - logger.info("[submit_store] forward thread returned at %.3f ms since entry", - (time.perf_counter() - _t_entry) * 1000) + + _t_exit = time.perf_counter() + logger.info( + "[FWD %s] lock=%.3f prepare=%.3f block_ids=%.3f " + "alloc=%.3f submit=%.3f bookkeep=%.3f total=%.3f ms " + "(num_chunks=%d, shm=%s)", + _request_id, + (_t_lock - _t0) * 1000, + (_t_prepare - _t_lock) * 1000, + (_t_block_ids - _t_prepare) * 1000, + (_t_alloc - _t_block_ids) * 1000, + (_t_submit_end - _t_submit_start) * 1000, + (_t_exit - _t_submit_end) * 1000, + num_chunks, + _used_shm_direct, + ) return completion def flush_inflight_gathers(self) -> None: @@ -280,13 +335,21 @@ def flush_inflight_gathers(self) -> None: reading them. Only gather completion is awaited; commit futures are not affected, since commits read from LMCache-owned staging buffers. """ + _t0 = time.perf_counter() with self._inflight_lock: gather_events = list(self._inflight_gather_events) for event in gather_events: event.synchronize() + _t1 = time.perf_counter() + logger.info( + "[flush_inflight_gathers] synced %d events in %.3f ms", + len(gather_events), + (_t1 - _t0) * 1000, + ) def close(self) -> None: # Drain in-flight gather/commit work before closing the base context. + _t0 = time.perf_counter() with self._inflight_lock: self._is_closing = True gather_events = list(self._inflight_gather_events) @@ -295,5 +358,13 @@ def close(self) -> None: event.synchronize() except Exception: logger.exception("Failed while draining gather events") + _t1 = time.perf_counter() self._commit_executor.shutdown(wait=True, cancel_futures=False) + _t2 = time.perf_counter() + logger.info( + "[close] gather_drain=%.3f executor_shutdown=%.3f total=%.3f ms", + (_t1 - _t0) * 1000, + (_t2 - _t1) * 1000, + (_t2 - _t0) * 1000, + ) super().close() From 0fe5e8a628de37e46ad904594f2966c3b20d7813 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:05:12 +0000 Subject: [PATCH 18/31] Add comprehensive profiling instrumentation to async_data.py From 20190bc5cd75b52f499a53f26adcd6d8fb02dd10 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 17 Jun 2026 03:14:00 +0000 Subject: [PATCH 19/31] fix log Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/transfer_context/async_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index 853e4366db..b828f04632 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -322,6 +322,7 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: (_t_alloc - _t_block_ids) * 1000, (_t_submit_end - _t_submit_start) * 1000, (_t_exit - _t_submit_end) * 1000, + (_t_exit - _t_entry) * 1000, num_chunks, _used_shm_direct, ) From fe220cea7d8e59c66c0ef424ddadd12b10c12835 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:09:12 +0000 Subject: [PATCH 20/31] Fix missing total argument and use outer-scope used_shm_direct in FWD log --- lmcache/v1/multiprocess/transfer_context/async_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py index b828f04632..0f35811da5 100644 --- a/lmcache/v1/multiprocess/transfer_context/async_data.py +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -324,7 +324,7 @@ def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: (_t_exit - _t_submit_end) * 1000, (_t_exit - _t_entry) * 1000, num_chunks, - _used_shm_direct, + used_shm_direct, ) return completion From 08e03b95455ec1969574a2a1f320d4baea7924fb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 03:11:59 +0000 Subject: [PATCH 21/31] Add profiling instrumentation to CUDA IPC transfer path - worker_transfer.py: Add import time + timing to HandleTransferContext.submit_store() with [FWD-IPC] log covering ipc_handle, send_request, to_cuda_future, and total ms - gpu_transfer.py: Add granular timing to GPUTransferModule.store() with [GPU-STORE] summary log and per-chunk [GPU-STORE-CHUNK] logs covering kernel launch and memcpy_d2h --- .../v1/multiprocess/modules/gpu_transfer.py | 38 +++++++++++++++++++ .../transfer_context/worker_transfer.py | 22 +++++++++-- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 2ec9c25148..0706385f1e 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -338,12 +338,14 @@ def store( """ st = time.perf_counter() obj_keys = self._ctx.resolve_obj_keys(key) + _t_resolve = time.perf_counter() entry = self._gpu_contexts.get(instance_id) if entry is None: raise ValueError(f"No GPU context registered for instance ID {instance_id}") gpu_context = entry.gpu_context model_name = entry.model_name + _num_groups = gpu_context.kv_layer_groups_manager.num_groups # NOTE: different engine groups may have different block sizes, so # ``blocks_per_chunk[i]`` is the number of blocks in one chunk for @@ -363,6 +365,7 @@ def store( block_ids_per_group_gpu = gpu_context.copy_view_block_ids_to_gpu( gpu_block_ids ) + _t_copy_ids = time.perf_counter() # Fail closed: every LMCache group must have block IDs covering all # chunks. A short list (e.g. a caller/protocol bug) would otherwise @@ -397,6 +400,7 @@ def store( gpu_context.device, event_ipc_handle ) vllm_event.wait(stream=gpu_context.stream) + _t_event_wait = time.perf_counter() # CPU-synchronous sentinel: a GPU store is about to be enqueued. # Must be published via publish() (not publish_on_stream) so the @@ -421,6 +425,7 @@ def store( }, ), ) + _t_publish = time.perf_counter() reserved_dict: dict[ObjectKey, MemoryObj] = {} store_succeeded = False @@ -429,6 +434,7 @@ def store( reserved_dict = self._ctx.storage_manager.reserve_write( obj_keys, layout_desc, "new" ) + _t_reserve = time.perf_counter() # NOTE: Store is not batched because some obj_keys may be # skipped (not in reserved_dict), making block_ids @@ -441,6 +447,7 @@ def store( else: continue + _t_chunk_start = time.perf_counter() # Copy from GPU paged buffer to tmp buffer, then to CPU — per # group. Each group uses its own block-id list (HMA). for group_idx in range(num_groups): @@ -467,16 +474,28 @@ def store( gpu_context.gpu_kv_format_, 0, ) + _t_kernel_end = time.perf_counter() # Store is not batched, so we always use chunk_idx=0 (single slot) lmcache_memcpy_async_d2h( gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj ) + _t_memcpy_end = time.perf_counter() + logger.info( + "[GPU-STORE-CHUNK] req=%s chunk_idx=%d kernel=%.3f memcpy_d2h=%.3f ms", + key.request_id, + idx, + (_t_kernel_end - _t_chunk_start) * 1000, + (_t_memcpy_end - _t_kernel_end) * 1000, + ) store_succeeded = True + _t_loop_end = time.perf_counter() except Exception: logger.exception("Cannot store keys due to exception") return event.ipc_handle(), False finally: + _t_record_start = time.perf_counter() event.record() + _t_record_end = time.perf_counter() # Fail closed: commit the reserved objects only when every chunk # copied successfully; otherwise the whole store is skipped. stored_count = len(reserved_dict) if store_succeeded else 0 @@ -486,6 +505,7 @@ def store( "finish_write", list(reserved_dict.keys()), ) + _t_callback_end = time.perf_counter() # All reserved MemoryObjs share one layout_desc, so per-object # size is identical — avoid summing N identical values. total_bytes = ( @@ -509,6 +529,24 @@ def store( ) ed = time.perf_counter() + logger.info( + "[GPU-STORE] req=%s resolve_keys=%.3f copy_block_ids=%.3f " + "event_ipc_wait=%.3f event_publish=%.3f reserve_write=%.3f " + "kernel_loop=%.3f event_record=%.3f submit_cb=%.3f total=%.3f ms " + "(num_chunks=%d, num_groups=%d)", + key.request_id, + (_t_resolve - st) * 1000, + (_t_copy_ids - _t_resolve) * 1000, + (_t_event_wait - _t_copy_ids) * 1000, + (_t_publish - _t_event_wait) * 1000, + (_t_reserve - _t_publish) * 1000, + (_t_loop_end - _t_reserve) * 1000, + (_t_record_end - _t_record_start) * 1000, + (_t_callback_end - _t_record_end) * 1000, + (ed - st) * 1000, + len(obj_keys), + _num_groups, + ) if length := len(reserved_dict): logger.info( "Stored %d tokens in %.3f seconds", diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index cbf715feaa..4b475bef76 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Protocol import os +import time # Third Party import torch @@ -281,11 +282,26 @@ def submit_store( "Handle transfer context is not registered. " "Call register() before submit_store()." ) - return self._send_request( + _t0 = time.perf_counter() + ipc_handle = event.ipc_handle() + _t_ipc = time.perf_counter() + mq_future = self._send_request( self._mq_client, RequestType.STORE, - [key, instance_id, block_ids, event.ipc_handle()], - ).to_cuda_future() + [key, instance_id, block_ids, ipc_handle], + ) + _t_send = time.perf_counter() + cuda_future = mq_future.to_cuda_future() + _t_cuda = time.perf_counter() + logger.info( + "[FWD-IPC] req=%s ipc_handle=%.3f send_request=%.3f to_cuda_future=%.3f total=%.3f ms", + _request_id, + (_t_ipc - _t0) * 1000, + (_t_send - _t_ipc) * 1000, + (_t_cuda - _t_send) * 1000, + (_t_cuda - _t0) * 1000, + ) + return cuda_future def submit_retrieve( self, From 01c2d7ef4065afba303daa12d57ccaa849d05f20 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 03:13:21 +0000 Subject: [PATCH 22/31] Fix timing variable scoping: initialize before try block to avoid NameError risk --- lmcache/v1/multiprocess/modules/gpu_transfer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 0706385f1e..4ff7359956 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -429,6 +429,11 @@ def store( reserved_dict: dict[ObjectKey, MemoryObj] = {} store_succeeded = False + _t_reserve = _t_publish + _t_loop_end = _t_publish + _t_record_start = _t_publish + _t_record_end = _t_publish + _t_callback_end = _t_publish try: layout_desc = get_layout_desc(gpu_context, self._ctx.chunk_size) reserved_dict = self._ctx.storage_manager.reserve_write( From 357aab42ba4372cc799d1d4a51b89a825a02747a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:37:09 +0000 Subject: [PATCH 23/31] Add E2E timing from submit_store_request to get_finished --- .../vllm/vllm_multi_process_adapter.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 3f855c49e5..c4c2ea818e 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -7,6 +7,7 @@ import enum import os import threading +import time # Third Party import torch @@ -936,6 +937,10 @@ def __init__( # Prevents re-reporting the same ID after drain clears tracking sets. self._returned_finished: set[str] = set() + # Timestamps recorded when submit_store_request is called, used to + # measure E2E wall-clock time until the future is resolved. + self._store_submit_times: dict[str, float] = {} + self.model_name = model_name self.parallel_strategy = parallel_strategy @@ -1185,6 +1190,7 @@ def submit_store_request( ) self.store_futures[request_id] = future self.store_events[request_id] = event + self._store_submit_times[request_id] = time.perf_counter() @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -1345,6 +1351,7 @@ def get_finished( self.retrieve_futures.clear() self.store_events.clear() self.retrieve_events.clear() + self._store_submit_times.clear() ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine @@ -1363,6 +1370,15 @@ def get_finished( if not s_future.query(): continue + _t_done = time.perf_counter() + _t_submit = self._store_submit_times.pop(request_id, None) + if _t_submit is not None: + logger.info( + "[E2E-STORE] req=%s e2e=%.3f ms", + str(request_id), + (_t_done - _t_submit) * 1000, + ) + s_result = s_future.result() finished_stores.add(request_id) From 11529a7feab8a1659eaecf7633627ac021371418 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Jun 2026 04:37:41 +0000 Subject: [PATCH 24/31] Remove redundant str() in E2E-STORE log call --- lmcache/integration/vllm/vllm_multi_process_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index c4c2ea818e..ff035dd392 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -1375,7 +1375,7 @@ def get_finished( if _t_submit is not None: logger.info( "[E2E-STORE] req=%s e2e=%.3f ms", - str(request_id), + request_id, (_t_done - _t_submit) * 1000, ) From 959b005d8c2274d7e30000b0800d15ef3fca0be7 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Sun, 7 Jun 2026 10:06:02 +0800 Subject: [PATCH 25/31] feat(ops): add multi_layer_block_kv_transfer Python fallback as unified MP transfer primitive (#3508) Signed-off-by: Tony Lin --- lmcache/python_ops_fallback.py | 847 +++++++++++++++++++++++++ tests/v1/test_c_ops_fallback_parity.py | 4 +- tests/v1/test_python_ops_fallback.py | 847 +++++++++++++++++++++++++ 3 files changed, 1695 insertions(+), 3 deletions(-) diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py index 01bc7c2398..f1b8e15593 100644 --- a/lmcache/python_ops_fallback.py +++ b/lmcache/python_ops_fallback.py @@ -748,6 +748,853 @@ def multi_layer_kv_transfer_unilateral( key_value[kv_idx, layer_id, valid_mask_kv, :] = gathered.to(kv_device) +def _is_cross_layer_format(gpu_kv_format: GPUKVFormat) -> bool: + """Return True when a KV format uses a single cross-layer tensor.""" + return int(gpu_kv_format) in ( + int(GPUKVFormat.NB_NL_TWO_BS_NH_HS), + int(GPUKVFormat.NB_NL_TWO_NH_BS_HS), + ) + + +def _is_sglang_mha_format(gpu_kv_format: GPUKVFormat) -> bool: + """Return True when a KV format uses SGLang MHA layout (2*NL tensors).""" + return int(gpu_kv_format) in ( + int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS), + int(GPUKVFormat.TWO_X_NL_X_NB_BS_NH_HS), + ) + + +def _is_hnd_format(gpu_kv_format: GPUKVFormat) -> bool: + """Return True when a per-layer KV format stores heads before block tokens (HND).""" + return int(gpu_kv_format) in ( + int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS), + int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS), + ) + + +def _is_mla_format(gpu_kv_format: GPUKVFormat) -> bool: + """Return True when a KV format uses MLA paged layout.""" + return int(gpu_kv_format) in ( + int(GPUKVFormat.NL_X_NB_BS_HS), + int(GPUKVFormat.NL_X_NBBS_ONE_HS), + ) + + +_ELEMENT_SIZE_TO_DTYPE: dict[int, torch.dtype] = { + # Maps the byte width of a KV-cache element to a representative torch dtype. + # Only widths that commonly appear in KV caches are listed; 1-byte entries + # are treated as uint8 (raw bytes), 2-byte as float16, 4-byte as float32. + # Note: bfloat16 also has element_size == 2 but cannot be distinguished here; + # callers that need exact dtype should supply it explicitly. + 1: torch.uint8, + 2: torch.float16, + 4: torch.float32, +} + + +def _is_ptr_tensor(x: object) -> bool: + """Return True when *x* is a 1-D pointer tensor (int64 or uint64).""" + return ( + isinstance(x, torch.Tensor) + and x.dtype in (torch.int64, torch.uint64) + and x.ndim == 1 + ) + + +def _per_layer_paged_shape( + gpu_kv_format: GPUKVFormat, + nb: int, + bs: int, + nh: int, + hs: int, +) -> tuple[int, ...]: + """Return the logical shape of a single per-layer paged buffer tensor. + + Args: + gpu_kv_format: The format enum that describes how K/V tokens are laid out. + nb: Number of blocks in the paged buffer (``shape_desc.nb``). + bs: Tokens per block / block size (``shape_desc.bs``). + nh: Number of attention heads (``shape_desc.nh``). + hs: Per-head hidden size (``shape_desc.hs``). + + Returns: + A tuple representing the shape needed to reconstruct one layer's tensor + from a raw pointer via :func:`_tensor_from_ptr`. + """ + fmt = int(gpu_kv_format) + if fmt == int(GPUKVFormat.NL_X_NBBS_ONE_HS): + return (nb * bs, 1, hs) + if fmt == int(GPUKVFormat.NL_X_NB_BS_HS): + return (nb, bs, hs) + if fmt == int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS): + return (2, nb, nh, bs, hs) + if fmt == int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS): + return (nb, 2, nh, bs, hs) + if fmt == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS): + return (2, nb, bs, nh, hs) + # Covers NL_X_NB_TWO_BS_NH_HS and any future NHD variants. + return (nb, 2, bs, nh, hs) + + +def _infer_kv_dtype( + paged_buffer_ptrs_tensor: object, + lmcache_objects_ptrs: object, + shape_desc: "PageBufferShapeDesc", +) -> torch.dtype: + """Infer the KV element dtype from whichever inputs carry it. + + Inference order (first match wins): + 1. ``shape_desc.dtype`` — authoritative when set (requires the + ``set_shape_desc_dtype`` helper from PR #3514; correctly distinguishes + float16 vs bfloat16 which share ``element_size == 2``). + 2. ``paged_buffer_ptrs_tensor`` — if it is a non-pointer tensor or a list + of tensors (including nested SGLang MHA lists), the dtype of the first + tensor is used. + 3. ``lmcache_objects_ptrs`` — if it is a list of tensors, the dtype of the + first chunk tensor is used. + 4. ``shape_desc.element_size`` — looked up in :data:`_ELEMENT_SIZE_TO_DTYPE` + (ambiguous for 2-byte types; kept only as last-resort fallback). + 5. ``torch.bfloat16`` — silent default when no other source is available. + """ + # Prefer shape_desc.dtype — it is exact and avoids the element_size ambiguity. + if shape_desc is not None: + sd_dtype = getattr(shape_desc, "dtype", None) + if sd_dtype is not None: + return sd_dtype + if isinstance(paged_buffer_ptrs_tensor, list) and paged_buffer_ptrs_tensor: + first = paged_buffer_ptrs_tensor[0] + if isinstance(first, list) and first and isinstance(first[0], torch.Tensor): + return first[0].dtype + if isinstance(first, torch.Tensor): + return first.dtype + if isinstance(paged_buffer_ptrs_tensor, torch.Tensor) and not _is_ptr_tensor( + paged_buffer_ptrs_tensor + ): + return paged_buffer_ptrs_tensor.dtype + if isinstance(lmcache_objects_ptrs, list) and lmcache_objects_ptrs: + if isinstance(lmcache_objects_ptrs[0], torch.Tensor): + return lmcache_objects_ptrs[0].dtype + if shape_desc is not None and shape_desc.element_size > 0: + dtype = _ELEMENT_SIZE_TO_DTYPE.get(shape_desc.element_size) + if dtype is None: + raise ValueError( + f"Unsupported element_size {shape_desc.element_size!r} in " + "shape_desc; cannot infer KV dtype. " + f"Supported sizes: {sorted(_ELEMENT_SIZE_TO_DTYPE)}" + ) + return dtype + return torch.bfloat16 + + +def _normalize_paged_layers( + paged_buffer_ptrs_tensor: "torch.Tensor | list", + gpu_kv_format: GPUKVFormat, + shape_desc: "PageBufferShapeDesc | None" = None, + device: "torch.device | str | None" = None, + dtype: "torch.dtype | None" = None, +) -> "torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]]": + """Normalize paged buffer input based on GPU KV format. + + Accepts either tensor-form inputs (list / Tensor) or a 1-D pointer tensor + (int64 / uint64). When a pointer tensor is provided *shape_desc*, *device*, + and *dtype* must be supplied so the tensors can be reconstructed via + :func:`_tensor_from_ptr`. + + Returns: + - Single ``torch.Tensor`` for cross-layer formats. + - ``list[list[torch.Tensor]]`` (2 x NL) for SGLang MHA formats. + - ``list[torch.Tensor]`` (per-layer) for all other formats. + """ + if _is_cross_layer_format(gpu_kv_format): + if isinstance(paged_buffer_ptrs_tensor, torch.Tensor): + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor with a single entry → reconstruct full tensor. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + nl = int(shape_desc.nl) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + if int(gpu_kv_format) == int(GPUKVFormat.NB_NL_TWO_NH_BS_HS): + shape: tuple[int, ...] = (nb, nl, 2, nh, bs, hs) + else: + shape = (nb, nl, 2, bs, nh, hs) + ptr = int(paged_buffer_ptrs_tensor[0].item()) + return _tensor_from_ptr(ptr, shape, dtype, device) + return paged_buffer_ptrs_tensor + raise TypeError( + "Cross-layer formats require a single torch.Tensor input; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + if _is_sglang_mha_format(gpu_kv_format): + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor [K_L0,...,K_LN-1, V_L0,...,V_LN-1] → nested list. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + nl = int(shape_desc.nl) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + is_flat = int(gpu_kv_format) == int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS) + per_layer_shape: tuple[int, ...] = ( + (nb * bs, nh, hs) if is_flat else (nb, bs, nh, hs) + ) + ptrs = [int(p.item()) for p in paged_buffer_ptrs_tensor] + k_tensors = [ + _tensor_from_ptr(ptrs[i], per_layer_shape, dtype, device) + for i in range(nl) + ] + v_tensors = [ + _tensor_from_ptr(ptrs[nl + i], per_layer_shape, dtype, device) + for i in range(nl) + ] + return [k_tensors, v_tensors] + if isinstance(paged_buffer_ptrs_tensor, list): + # Already nested [[K tensors], [V tensors]] + if ( + len(paged_buffer_ptrs_tensor) == 2 + and isinstance(paged_buffer_ptrs_tensor[0], list) + and all( + isinstance(t, torch.Tensor) + for group in paged_buffer_ptrs_tensor + for t in group + ) + ): + return paged_buffer_ptrs_tensor + # Flat list [K_L0, ..., K_LN-1, V_L0, ..., V_LN-1] + if all(isinstance(t, torch.Tensor) for t in paged_buffer_ptrs_tensor): + if len(paged_buffer_ptrs_tensor) % 2 != 0: + raise ValueError( + "Flat SGLang MHA list must have even length (2*NL)" + ) + half = len(paged_buffer_ptrs_tensor) // 2 + return [ + paged_buffer_ptrs_tensor[:half], + paged_buffer_ptrs_tensor[half:], + ] + raise TypeError( + "SGLang MHA formats require a list[list[torch.Tensor]], a flat " + "list[torch.Tensor] (2*NL entries), or a 1-D pointer tensor; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + # Per-layer formats + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor [ptr_L0, ..., ptr_LN-1] → list of per-layer tensors. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + per_shape = _per_layer_paged_shape(gpu_kv_format, nb, bs, nh, hs) + return [ + _tensor_from_ptr(int(p.item()), per_shape, dtype, device) + for p in paged_buffer_ptrs_tensor + ] + if isinstance(paged_buffer_ptrs_tensor, list): + if not all(isinstance(t, torch.Tensor) for t in paged_buffer_ptrs_tensor): + raise TypeError( + "paged_buffer_ptrs_tensor list must contain torch.Tensor entries" + ) + return paged_buffer_ptrs_tensor + raise TypeError( + "paged_buffer_ptrs_tensor must be a list[torch.Tensor] or 1-D pointer tensor; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + + +def _normalize_lmcache_objects( + lmcache_objects_ptrs: "list[int] | list[torch.Tensor]", + shape_desc: "PageBufferShapeDesc | None" = None, + lmcache_chunk_size: "int | None" = None, + gpu_kv_format: "GPUKVFormat | None" = None, + dtype: "torch.dtype | None" = None, +) -> list[torch.Tensor]: + """Normalize LMCache object inputs to chunk tensors. + + Accepts either a list of chunk tensors or a ``list[int]`` of raw CPU pointers. + When a pointer list is provided *shape_desc*, *lmcache_chunk_size*, + *gpu_kv_format*, and *dtype* must be supplied so the tensors can be + reconstructed via :func:`_tensor_from_ptr` on the CPU. + """ + if not isinstance(lmcache_objects_ptrs, list): + raise TypeError( + "lmcache_objects_ptrs must be a list[torch.Tensor] or list[int]; " + "got: " + type(lmcache_objects_ptrs).__name__ + ) + if not lmcache_objects_ptrs: + return [] + if isinstance(lmcache_objects_ptrs[0], torch.Tensor): + return lmcache_objects_ptrs # type: ignore[return-value] + if isinstance(lmcache_objects_ptrs[0], int): + # Pointer mode: reconstruct chunk tensors (always on CPU). + if ( + shape_desc is None + or lmcache_chunk_size is None + or gpu_kv_format is None + or dtype is None + ): + raise ValueError( + "_normalize_lmcache_objects: shape_desc, lmcache_chunk_size, " + "gpu_kv_format, and dtype are required when lmcache_objects_ptrs " + "contains raw int pointers" + ) + nl = int(shape_desc.nl) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + chunk_tokens = lmcache_chunk_size + if _is_mla_format(gpu_kv_format): + chunk_shape: tuple[int, ...] = (nl, chunk_tokens, hs) + else: + chunk_shape = (2, nl, chunk_tokens, nh * hs) + return [ + _tensor_from_ptr(ptr, chunk_shape, dtype, "cpu") + for ptr in lmcache_objects_ptrs + ] + raise TypeError( + "lmcache_objects_ptrs must be a list[torch.Tensor] or list[int]; " + "got list containing: " + type(lmcache_objects_ptrs[0]).__name__ + ) + + +def _to_block_id_list(block_ids: torch.Tensor | list[int]) -> list[int]: + """Convert block IDs from tensor/list form into a Python ``list[int]``.""" + if isinstance(block_ids, torch.Tensor): + return [int(x) for x in block_ids.to(dtype=torch.int64).cpu().tolist()] + if isinstance(block_ids, list): + return [int(x) for x in block_ids] + raise TypeError("block_ids must be a torch.Tensor or list[int]") + + +def multi_layer_block_kv_transfer( + paged_buffer_ptrs_tensor: "torch.Tensor | list", + lmcache_objects_ptrs: list[int] | list[torch.Tensor], + block_ids: torch.Tensor | list[int], + device: torch.device | str, + direction: TransferDirection, + shape_desc: PageBufferShapeDesc, + lmcache_chunk_size: int, + gpu_kv_format: GPUKVFormat, + skip_prefix_n_blocks: int, +) -> None: + """Python fallback implementation of block-based multi-layer KV transfer. + + Signature intentionally mirrors the C++ binding so callers can invoke + ``lmcache.c_ops.multi_layer_block_kv_transfer`` uniformly on native and + fallback backends. + + Args: + paged_buffer_ptrs_tensor: Paged buffer pointers or tensors. + lmcache_objects_ptrs: LMCache object pointers or chunk tensors. + block_ids: Ordered engine block IDs for the transfer. + device: Target device for the transfer. + direction: Transfer direction (H2D or D2H). + shape_desc: Shape descriptor of the page buffer. + lmcache_chunk_size: Chunk size of LMCache objects. + gpu_kv_format: GPU KV cache format. + skip_prefix_n_blocks: Number of leading blocks to skip. + + Returns: + None + + Raises: + ValueError: If chunk size is invalid, or transfer direction is unsupported. + TypeError: If input types do not match expected types. + """ + if lmcache_chunk_size <= 0: + raise ValueError("lmcache_chunk_size must be positive") + if int(shape_desc.bs) <= 0 or lmcache_chunk_size % int(shape_desc.bs) != 0: + raise ValueError( + "lmcache_chunk_size must be a positive multiple of shape_desc.bs" + ) + if skip_prefix_n_blocks < 0: + raise ValueError("skip_prefix_n_blocks must be >= 0") + + is_d2h = int(direction) == int(TransferDirection.D2H) + is_h2d = int(direction) == int(TransferDirection.H2D) + if not (is_d2h or is_h2d): + raise ValueError(f"Unsupported transfer direction: {direction!r}") + + kv_dtype = _infer_kv_dtype( + paged_buffer_ptrs_tensor, lmcache_objects_ptrs, shape_desc + ) + normalized = _normalize_paged_layers( + paged_buffer_ptrs_tensor, + gpu_kv_format, + shape_desc=shape_desc, + device=device, + dtype=kv_dtype, + ) + object_tensors = _normalize_lmcache_objects( + lmcache_objects_ptrs, + shape_desc=shape_desc, + lmcache_chunk_size=lmcache_chunk_size, + gpu_kv_format=gpu_kv_format, + dtype=kv_dtype, + ) + block_id_list = _to_block_id_list(block_ids) + blocks_per_object = lmcache_chunk_size // int(shape_desc.bs) + block_size = int(shape_desc.bs) + + if _is_cross_layer_format(gpu_kv_format): + _transfer_cross_layer( + normalized, + object_tensors, + block_id_list, + blocks_per_object, + block_size, + gpu_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_sglang_mha_format(gpu_kv_format): + _transfer_sglang_mha( + normalized, + object_tensors, + block_id_list, + blocks_per_object, + block_size, + gpu_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_mla_format(gpu_kv_format): + _transfer_per_layer_mla( + normalized, + object_tensors, + block_id_list, + blocks_per_object, + block_size, + gpu_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_hnd_format(gpu_kv_format): + _transfer_per_layer_hnd( + normalized, + object_tensors, + block_id_list, + blocks_per_object, + block_size, + gpu_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + else: + _transfer_per_layer_nhd( + normalized, + object_tensors, + block_id_list, + blocks_per_object, + block_size, + gpu_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + + +def _valid_block_range( + object_idx: int, + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + skip_prefix_n_blocks: int, +) -> tuple[list[int], int] | None: + """Return valid engine block IDs and their LMCache object token offset. + + Args: + object_idx: Index of the LMCache object/chunk being processed. + block_id_list: Full ordered engine block ids for the transfer. + blocks_per_object: Number of blocks represented by one LMCache object. + block_size: Number of tokens per block. + skip_prefix_n_blocks: Number of leading flat block positions to skip. + + Returns: + ``None`` if this object has no valid blocks after skip handling. + Otherwise, a tuple of valid engine block ids and the token offset + within this LMCache object where those blocks start. + """ + object_flat_start = object_idx * blocks_per_object + valid_flat_start = max(object_flat_start, skip_prefix_n_blocks) + valid_flat_end = min(object_flat_start + blocks_per_object, len(block_id_list)) + if valid_flat_start >= valid_flat_end: + return None + offset_in_object = (valid_flat_start - object_flat_start) * block_size + return block_id_list[valid_flat_start:valid_flat_end], offset_in_object + + +def _transfer_cross_layer( + paged_tensor: torch.Tensor, + object_tensors: list[torch.Tensor], + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + gpu_kv_format: GPUKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle cross-layer formats: single tensor [NB, NL, 2, ...].""" + # NHD: [NB, NL, 2, BS, NH, HS] HND: [NB, NL, 2, NH, BS, HS] + is_hnd = int(gpu_kv_format) == int(GPUKVFormat.NB_NL_TWO_NH_BS_HS) + num_layers = paged_tensor.shape[1] + + if is_hnd: + # [NB, NL, 2, NH, BS, HS] + nh = paged_tensor.shape[3] + hs = paged_tensor.shape[5] + else: + # [NB, NL, 2, BS, NH, HS] + nh = paged_tensor.shape[4] + hs = paged_tensor.shape[5] + + # H2D: pre-transfer objects to paged device + if not is_d2h and object_tensors: + objs_on_device = [obj.to(paged_tensor.device) for obj in object_tensors] + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range( + object_idx, + block_id_list, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + engine_block_ids, offset_in_object = valid + n_valid = len(engine_block_ids) + token_end = offset_in_object + n_valid * block_size + eff_idx = torch.tensor( + engine_block_ids, dtype=torch.long, device=paged_tensor.device + ) + + if is_d2h: + selected = paged_tensor.index_select(0, eff_idx) + + for layer_idx in range(num_layers): + for kv in range(2): + if is_d2h: + slice_t = selected[:, layer_idx, kv] + if is_hnd: + # N=n_valid, BS=block_size: + # [N, NH, BS, HS] -> [N, BS, NH, HS] -> [N*BS, NH*HS] + flat = slice_t.permute(0, 2, 1, 3).reshape( + n_valid * block_size, nh * hs + ) + else: + # [N, BS, NH, HS] → [N*BS, NH*HS] + flat = slice_t.reshape(n_valid * block_size, nh * hs) + obj[kv, layer_idx, offset_in_object:token_end].copy_( + flat, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + src = obj_device[kv, layer_idx, offset_in_object:token_end] + if is_hnd: + # N=n_valid, BS=block_size: + # [N*BS, NH*HS] -> [N, BS, NH, HS] -> [N, NH, BS, HS] + src_blocks = src.reshape(n_valid, block_size, nh, hs).permute( + 0, 2, 1, 3 + ) + else: + # N=n_valid, BS=block_size: + # [N*BS, NH*HS] -> [N, BS, NH, HS] + src_blocks = src.reshape(n_valid, block_size, nh, hs) + paged_tensor[:, layer_idx, kv].index_copy_(0, eff_idx, src_blocks) + + +def _transfer_sglang_mha( + paged_tensors: list[list[torch.Tensor]], + object_tensors: list[torch.Tensor], + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + gpu_kv_format: GPUKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle SGLang MHA formats: 2*NL tensors (list[list[Tensor]]).""" + # TWO_X_NL_X_NBBS_NH_HS: each tensor [NB*BS, NH, HS] + # TWO_X_NL_X_NB_BS_NH_HS: each tensor [NB, BS, NH, HS] + is_flat = int(gpu_kv_format) == int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS) + num_layers = len(paged_tensors[0]) + + # Determine target device from first tensor + target_device = paged_tensors[0][0].device + + # H2D: pre-transfer objects + if not is_d2h and object_tensors: + objs_on_device = [obj.to(target_device) for obj in object_tensors] + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range( + object_idx, + block_id_list, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + engine_block_ids, offset_in_object = valid + n_valid = len(engine_block_ids) + token_end = offset_in_object + n_valid * block_size + eff_idx = torch.tensor(engine_block_ids, dtype=torch.long, device=target_device) + if is_flat: + # Flat token positions for all valid blocks: + # block_id * block_size + token offset. Reused across layer/KV pairs. + token_indices = ( + eff_idx[:, None] * block_size + + torch.arange(block_size, dtype=torch.long, device=target_device) + ).reshape(-1) + + for layer_idx in range(num_layers): + for kv in range(2): + layer_t = paged_tensors[kv][layer_idx] + nh = layer_t.shape[-2] + hs = layer_t.shape[-1] + + if is_d2h: + if is_flat: + # [NB*BS, NH, HS] + gathered = layer_t.index_select(0, token_indices) + else: + # [NB, BS, NH, HS] + gathered = layer_t.index_select(0, eff_idx).reshape( + n_valid * block_size, nh, hs + ) + flat = gathered.reshape(n_valid * block_size, nh * hs) + obj[kv, layer_idx, offset_in_object:token_end].copy_( + flat, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + src = obj_device[kv, layer_idx, offset_in_object:token_end] + src_shaped = src.reshape(n_valid * block_size, nh, hs) + if is_flat: + # scatter into [NB*BS, NH, HS] + layer_t.index_copy_(0, token_indices, src_shaped) + else: + # N=n_valid, BS=block_size: + # [N*BS, NH, HS] -> [N, BS, NH, HS] + src_blocks = src_shaped.reshape(n_valid, block_size, nh, hs) + layer_t.index_copy_(0, eff_idx, src_blocks) + + +def _transfer_per_layer_mla( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + gpu_kv_format: GPUKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle MLA per-layer formats: [NB, BS, HS].""" + if not is_d2h and layer_tensors and object_tensors: + target_device = layer_tensors[0].device + objs_on_device = [obj.to(target_device) for obj in object_tensors] + + for layer_idx, layer in enumerate(layer_tensors): + is_flat = int(gpu_kv_format) == int(GPUKVFormat.NL_X_NBBS_ONE_HS) + if is_flat: + token_offsets = torch.arange( + block_size, dtype=torch.long, device=layer.device + ) + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range( + object_idx, + block_id_list, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + engine_block_ids, offset_in_object = valid + n_valid = len(engine_block_ids) + token_end = offset_in_object + n_valid * block_size + eff_idx = torch.tensor( + engine_block_ids, dtype=torch.long, device=layer.device + ) + if is_flat: + token_indices = ( + eff_idx[:, None] * block_size + token_offsets[None, :] + ).reshape(-1) + + if is_d2h: + if is_flat: + layer_blocks = layer.index_select(0, token_indices) + else: + layer_blocks = layer.index_select(0, eff_idx) + flat = layer_blocks.reshape(n_valid * block_size, layer.shape[-1]) + obj[layer_idx, offset_in_object:token_end].copy_( + flat, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + src = obj_device[layer_idx, offset_in_object:token_end] + hidden_size = layer.shape[-1] + if is_flat: + src_tokens = src.reshape(n_valid * block_size, 1, hidden_size) + layer.index_copy_(0, token_indices, src_tokens) + else: + src_blocks = src.reshape(n_valid, block_size, hidden_size) + layer.index_copy_(0, eff_idx, src_blocks) + + +def _transfer_per_layer_hnd( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + gpu_kv_format: GPUKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle per-layer HND formats: heads before block tokens.""" + if not is_d2h and layer_tensors and object_tensors: + target_device = layer_tensors[0].device + objs_on_device = [obj.to(target_device) for obj in object_tensors] + + for layer_idx, layer in enumerate(layer_tensors): + # Determine K/V split based on specific format + if int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS): + k_t, v_t = layer[0], layer[1] + else: + k_t, v_t = layer[:, 0], layer[:, 1] + _nb, nh, _bs, hs = k_t.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range( + object_idx, + block_id_list, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + engine_block_ids, offset_in_object = valid + n_valid = len(engine_block_ids) + token_end = offset_in_object + n_valid * block_size + eff_idx = torch.tensor( + engine_block_ids, dtype=torch.long, device=layer.device + ) + + if is_d2h: + k_blocks = ( + k_t.index_select(0, eff_idx) + .permute(0, 2, 1, 3) + .reshape(n_valid * block_size, nh * hs) + ) + v_blocks = ( + v_t.index_select(0, eff_idx) + .permute(0, 2, 1, 3) + .reshape(n_valid * block_size, nh * hs) + ) + obj[0, layer_idx, offset_in_object:token_end].copy_( + k_blocks, non_blocking=True + ) + obj[1, layer_idx, offset_in_object:token_end].copy_( + v_blocks, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + k_src = obj_device[0, layer_idx, offset_in_object:token_end] + v_src = obj_device[1, layer_idx, offset_in_object:token_end] + k_blocks = k_src.reshape(n_valid, block_size, nh, hs).permute( + 0, 2, 1, 3 + ) + v_blocks = v_src.reshape(n_valid, block_size, nh, hs).permute( + 0, 2, 1, 3 + ) + k_t.index_copy_(0, eff_idx, k_blocks) + v_t.index_copy_(0, eff_idx, v_blocks) + + +def _transfer_per_layer_nhd( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + gpu_kv_format: GPUKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle per-layer NHD formats: block tokens before heads.""" + if not is_d2h and layer_tensors and object_tensors: + target_device = layer_tensors[0].device + objs_on_device = [obj.to(target_device) for obj in object_tensors] + + for layer_idx, layer in enumerate(layer_tensors): + # Determine K/V split based on specific format + if int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS): + k_t, v_t = layer[0], layer[1] + else: + k_t, v_t = layer[:, 0], layer[:, 1] + _nb, _bs, nh, hs = k_t.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range( + object_idx, + block_id_list, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + engine_block_ids, offset_in_object = valid + n_valid = len(engine_block_ids) + token_end = offset_in_object + n_valid * block_size + eff_idx = torch.tensor( + engine_block_ids, dtype=torch.long, device=layer.device + ) + + if is_d2h: + k_blocks = k_t.index_select(0, eff_idx).reshape( + n_valid * block_size, nh * hs + ) + v_blocks = v_t.index_select(0, eff_idx).reshape( + n_valid * block_size, nh * hs + ) + obj[0, layer_idx, offset_in_object:token_end].copy_( + k_blocks, non_blocking=True + ) + obj[1, layer_idx, offset_in_object:token_end].copy_( + v_blocks, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + k_src = obj_device[0, layer_idx, offset_in_object:token_end] + v_src = obj_device[1, layer_idx, offset_in_object:token_end] + k_t.index_copy_( + 0, + eff_idx, + k_src.reshape(n_valid, block_size, nh, hs), + ) + v_t.index_copy_( + 0, + eff_idx, + v_src.reshape(n_valid, block_size, nh, hs), + ) + + def single_layer_kv_transfer( lmc_key_value_cache: torch.Tensor, vllm_key_value_cache: torch.Tensor, diff --git a/tests/v1/test_c_ops_fallback_parity.py b/tests/v1/test_c_ops_fallback_parity.py index fb97b67576..5d15bbd918 100644 --- a/tests/v1/test_c_ops_fallback_parity.py +++ b/tests/v1/test_c_ops_fallback_parity.py @@ -212,9 +212,7 @@ def _has_real_names(params): # ── Discover the intersection automatically ── # Functions intentionally excluded from parity checks. -_EXCLUDED_FUNCS: set[str] = { - "multi_layer_block_kv_transfer", -} +_EXCLUDED_FUNCS: set[str] = set() _fallback_callables = _public_callables(fallback) _c_ops_callables = _public_callables(c_ops) if HAS_C_OPS else {} diff --git a/tests/v1/test_python_ops_fallback.py b/tests/v1/test_python_ops_fallback.py index f0f6e648e4..9bb6c5efa9 100644 --- a/tests/v1/test_python_ops_fallback.py +++ b/tests/v1/test_python_ops_fallback.py @@ -1796,6 +1796,852 @@ class _FakeStream: return {"dispatcher_integration": torch.tensor([1], dtype=torch.int32)} +def scenario_multi_layer_block_kv_transfer( + ops: Any, device: str +) -> dict[str, torch.Tensor]: + """Test multi_layer_block_kv_transfer across all GPU KV formats. + + Exercises the block-based transfer path for: + - NHD per-layer format (D2H and H2D round-trip) + - FlashInfer NHD per-layer format (interleaved K/V) + - HND per-layer format (with permute) + - FlashInfer HND per-layer format (interleaved K/V) + - MLA per-layer format (no K/V split) + - SGLang MLA flat per-layer format + - Cross-layer NHD (single tensor [NB, NL, 2, BS, NH, HS]) + - Cross-layer HND (single tensor [NB, NL, 2, NH, BS, HS]) + - SGLang MHA flat (2*NL tensors [NB*BS, NH, HS]) + - SGLang MHA block (2*NL tensors [NB, BS, NH, HS]) + - non-sequential block_ids and list[int] block_ids input + - skip_prefix_n_blocks > 0 + - num_blocks > blocks_per_chunk (multi-chunk) + """ + results = {} + + # C++ bindings (cuda_c_ops, xpu_sycl_ops) expect uint64 pointer tensors for + # paged_buffer_ptrs_tensor and list[int] for lmcache_objects_ptrs. + # The Python fallback also supports both modes on cpu/cuda (pointer inputs + # are reconstructed internally via _tensor_from_ptr). + use_tensor_list = device not in ("cpu", "cuda") + + def _alloc_chunks(shape: tuple[int, ...], count: int) -> list[torch.Tensor]: + chunks = [torch.zeros(shape, dtype=dtype) for _ in range(count)] + if device in ("cuda"): + chunks = [chunk.pin_memory() for chunk in chunks] + return chunks + + # --- NHD per-layer --- + torch.manual_seed(123) + num_layers, num_blocks, block_size = 2, 8, 4 + num_heads, head_size = 2, 8 + blocks_per_chunk = 4 + chunk_tokens = blocks_per_chunk * block_size + hidden_dim = num_heads * head_size + dtype = torch.float32 + + paged_layers = [ + torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + shape_desc = ops.PageBufferShapeDesc() + shape_desc.nl = num_layers + shape_desc.nb = num_blocks + shape_desc.bs = block_size + shape_desc.nh = num_heads + shape_desc.hs = head_size + shape_desc.element_size = dtype.itemsize + shape_desc.kv_size = 2 + gpu_kv_format = ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + num_chunks = num_blocks // blocks_per_chunk + d2h_chunks = _alloc_chunks((2, num_layers, chunk_tokens, hidden_dim), num_chunks) + block_ids = list(range(num_blocks)) + + ops.multi_layer_block_kv_transfer( + paged_layers + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers], + dtype=torch.uint64, + device=device, + ), + d2h_chunks if use_tensor_list else [c.data_ptr() for c in d2h_chunks], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format, + 0, + ) + paged_h2d = [torch.zeros_like(layer) for layer in paged_layers] + ops.multi_layer_block_kv_transfer( + paged_h2d + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d], dtype=torch.uint64, device=device + ), + d2h_chunks if use_tensor_list else [c.data_ptr() for c in d2h_chunks], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format, + 0, + ) + for i in range(num_layers): + orig = paged_layers[i].cpu() + recon = paged_h2d[i].cpu() + results[f"nhd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"NHD Layer {i} round-trip mismatch" + ) + + # --- FlashInfer NHD per-layer --- + torch.manual_seed(234) + paged_layers_fi_nhd = [ + torch.randn(num_blocks, 2, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_fi_nhd = ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS + d2h_chunks_fi_nhd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_fi_nhd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_fi_nhd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_nhd, + 0, + ) + paged_h2d_fi_nhd = [torch.zeros_like(layer) for layer in paged_layers_fi_nhd] + ops.multi_layer_block_kv_transfer( + paged_h2d_fi_nhd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_fi_nhd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_fi_nhd[i].cpu() + recon = paged_h2d_fi_nhd[i].cpu() + results[f"flashinfer_nhd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"FlashInfer NHD Layer {i} round-trip mismatch" + ) + + # --- HND per-layer --- + torch.manual_seed(456) + paged_layers_hnd = [ + torch.randn(2, num_blocks, num_heads, block_size, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_hnd = ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS + d2h_chunks_hnd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_hnd if use_tensor_list else [c.data_ptr() for c in d2h_chunks_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_hnd, + 0, + ) + paged_h2d_hnd = [torch.zeros_like(layer) for layer in paged_layers_hnd] + ops.multi_layer_block_kv_transfer( + paged_h2d_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_hnd if use_tensor_list else [c.data_ptr() for c in d2h_chunks_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_hnd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_hnd[i].cpu() + recon = paged_h2d_hnd[i].cpu() + results[f"hnd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"HND Layer {i} round-trip mismatch" + ) + + # --- FlashInfer HND per-layer --- + torch.manual_seed(567) + paged_layers_fi_hnd = [ + torch.randn(num_blocks, 2, num_heads, block_size, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_fi_hnd = ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS + d2h_chunks_fi_hnd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_fi_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_fi_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_hnd, + 0, + ) + paged_h2d_fi_hnd = [torch.zeros_like(layer) for layer in paged_layers_fi_hnd] + ops.multi_layer_block_kv_transfer( + paged_h2d_fi_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_fi_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_hnd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_fi_hnd[i].cpu() + recon = paged_h2d_fi_hnd[i].cpu() + results[f"flashinfer_hnd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"FlashInfer HND Layer {i} round-trip mismatch" + ) + + # --- MLA per-layer --- + torch.manual_seed(789) + mla_hidden = 16 + paged_layers_mla = [ + torch.randn(num_blocks, block_size, mla_hidden, dtype=dtype).to(device) + for _ in range(num_layers) + ] + shape_desc_mla = ops.PageBufferShapeDesc() + shape_desc_mla.nl = num_layers + shape_desc_mla.nb = num_blocks + shape_desc_mla.bs = block_size + shape_desc_mla.nh = 1 + shape_desc_mla.hs = mla_hidden + shape_desc_mla.element_size = dtype.itemsize + shape_desc_mla.kv_size = 1 + gpu_kv_format_mla = ops.GPUKVFormat.NL_X_NB_BS_HS + d2h_chunks_mla = _alloc_chunks((num_layers, chunk_tokens, mla_hidden), num_chunks) + ops.multi_layer_block_kv_transfer( + paged_layers_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mla if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_mla, + 0, + ) + paged_h2d_mla = [torch.zeros_like(layer) for layer in paged_layers_mla] + ops.multi_layer_block_kv_transfer( + paged_h2d_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mla if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_mla, + 0, + ) + for i in range(num_layers): + orig = paged_layers_mla[i].cpu() + recon = paged_h2d_mla[i].cpu() + results[f"mla_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"MLA Layer {i} round-trip mismatch" + ) + + # --- SGLang MLA flat per-layer --- + torch.manual_seed(890) + paged_layers_sglang_mla = [ + torch.randn(num_blocks * block_size, 1, mla_hidden, dtype=dtype).to(device) + for _ in range(num_layers) + ] + gpu_kv_format_sglang_mla = ops.GPUKVFormat.NL_X_NBBS_ONE_HS + d2h_chunks_sglang_mla = _alloc_chunks( + (num_layers, chunk_tokens, mla_hidden), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_sglang_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_sglang_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_mla + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_sglang_mla, + 0, + ) + paged_h2d_sglang_mla = [ + torch.zeros_like(layer) for layer in paged_layers_sglang_mla + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_sglang_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_mla + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_sglang_mla, + 0, + ) + for i in range(num_layers): + orig = paged_layers_sglang_mla[i].cpu() + recon = paged_h2d_sglang_mla[i].cpu() + results[f"sglang_mla_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang MLA Layer {i} round-trip mismatch" + ) + + # --- Cross-layer NHD --- + torch.manual_seed(101) + paged_cross_nhd = torch.randn( + num_blocks, + num_layers, + 2, + block_size, + num_heads, + head_size, + dtype=dtype, + ).to(device) + gpu_kv_format_cross_nhd = ops.GPUKVFormat.NB_NL_TWO_BS_NH_HS + d2h_chunks_cross_nhd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_cross_nhd + if use_tensor_list + else torch.tensor( + [paged_cross_nhd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_nhd, + 0, + ) + paged_h2d_cross_nhd = torch.zeros_like(paged_cross_nhd) + ops.multi_layer_block_kv_transfer( + paged_h2d_cross_nhd + if use_tensor_list + else torch.tensor( + [paged_h2d_cross_nhd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_nhd, + 0, + ) + orig = paged_cross_nhd.cpu() + recon = paged_h2d_cross_nhd.cpu() + results["cross_nhd"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), "Cross-layer NHD mismatch" + + # --- Cross-layer HND --- + torch.manual_seed(202) + paged_cross_hnd = torch.randn( + num_blocks, + num_layers, + 2, + num_heads, + block_size, + head_size, + dtype=dtype, + ).to(device) + gpu_kv_format_cross_hnd = ops.GPUKVFormat.NB_NL_TWO_NH_BS_HS + d2h_chunks_cross_hnd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_cross_hnd + if use_tensor_list + else torch.tensor( + [paged_cross_hnd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_hnd, + 0, + ) + paged_h2d_cross_hnd = torch.zeros_like(paged_cross_hnd) + ops.multi_layer_block_kv_transfer( + paged_h2d_cross_hnd + if use_tensor_list + else torch.tensor( + [paged_h2d_cross_hnd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_hnd, + 0, + ) + orig = paged_cross_hnd.cpu() + recon = paged_h2d_cross_hnd.cpu() + results["cross_hnd"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), "Cross-layer HND mismatch" + + # --- SGLang MHA flat (NBBS) --- + torch.manual_seed(303) + pbs = num_blocks * block_size + paged_sglang_nbbs = [ + [ + torch.randn(pbs, num_heads, head_size, dtype=dtype).to(device) + for _ in range(num_layers) + ], + [ + torch.randn(pbs, num_heads, head_size, dtype=dtype).to(device) + for _ in range(num_layers) + ], + ] + gpu_kv_format_sglang_nbbs = ops.GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS + d2h_chunks_sglang_nbbs = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_sglang_nbbs + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_sglang_nbbs[0]] + + [t.data_ptr() for t in paged_sglang_nbbs[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nbbs + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nbbs], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nbbs, + 0, + ) + paged_h2d_sglang_nbbs = [ + [torch.zeros_like(t) for t in group] for group in paged_sglang_nbbs + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_nbbs + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_h2d_sglang_nbbs[0]] + + [t.data_ptr() for t in paged_h2d_sglang_nbbs[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nbbs + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nbbs], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nbbs, + 0, + ) + for kv in range(2): + for i in range(num_layers): + orig = paged_sglang_nbbs[kv][i].cpu() + recon = paged_h2d_sglang_nbbs[kv][i].cpu() + results[f"sglang_nbbs_kv{kv}_l{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang NBBS kv={kv} layer={i} mismatch" + ) + + # --- SGLang MHA block (NB_BS) --- + torch.manual_seed(404) + paged_sglang_nb = [ + [ + torch.randn(num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ], + [ + torch.randn(num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ], + ] + gpu_kv_format_sglang_nb = ops.GPUKVFormat.TWO_X_NL_X_NB_BS_NH_HS + d2h_chunks_sglang_nb = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_sglang_nb + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_sglang_nb[0]] + + [t.data_ptr() for t in paged_sglang_nb[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nb + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nb], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nb, + 0, + ) + paged_h2d_sglang_nb = [ + [torch.zeros_like(t) for t in group] for group in paged_sglang_nb + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_nb + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_h2d_sglang_nb[0]] + + [t.data_ptr() for t in paged_h2d_sglang_nb[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nb + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nb], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nb, + 0, + ) + for kv in range(2): + for i in range(num_layers): + orig = paged_sglang_nb[kv][i].cpu() + recon = paged_h2d_sglang_nb[kv][i].cpu() + results[f"sglang_nb_kv{kv}_l{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang NB kv={kv} layer={i} mismatch" + ) + + # --- skip_prefix_n_blocks > 0 --- + torch.manual_seed(505) + skip_n = 2 + paged_layers_skip = [ + torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_nhd = ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + # With skip=2, effective blocks start at index 2. + # Object 0 occupies flat indices [0, blocks_per_chunk), skipping first 2. + # Object 1 occupies flat indices [blocks_per_chunk, 2*blocks_per_chunk). + d2h_chunks_skip = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_skip + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_skip], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_skip if use_tensor_list else [c.data_ptr() for c in d2h_chunks_skip], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + skip_n, + ) + paged_h2d_skip = [torch.zeros_like(layer) for layer in paged_layers_skip] + ops.multi_layer_block_kv_transfer( + paged_h2d_skip + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_skip], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_skip if use_tensor_list else [c.data_ptr() for c in d2h_chunks_skip], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + skip_n, + ) + # Blocks [skip_n:num_blocks] should round-trip at their original positions + for layer_idx in range(num_layers): + for kv in range(2): + orig_blocks = paged_layers_skip[layer_idx][kv, skip_n:num_blocks] + recon_blocks = paged_h2d_skip[layer_idx][kv, skip_n:num_blocks] + key = f"skip_l{layer_idx}_kv{kv}" + results[key] = torch.stack([orig_blocks.cpu(), recon_blocks.cpu()]) + assert torch.allclose(orig_blocks, recon_blocks, atol=1e-6), ( + f"Skip mismatch at layer={layer_idx} kv={kv}" + ) + # Skipped blocks [0:skip_n] should remain zero + for layer_idx in range(num_layers): + for kv in range(2): + skipped = paged_h2d_skip[layer_idx][kv, :skip_n] + assert torch.all(skipped == 0), ( + f"Skipped blocks not zero at layer={layer_idx} kv={kv}" + ) + assert torch.all(d2h_chunks_skip[0][:, :, : skip_n * block_size] == 0), ( + "Skipped D2H chunk region should remain zero" + ) + + # --- Non-sequential block_ids --- + torch.manual_seed(515) + paged_layers_permuted = [ + torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + generator = torch.Generator(device="cpu").manual_seed(616) + permuted_block_ids = torch.randperm(num_blocks, generator=generator).tolist() + d2h_chunks_permuted = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_permuted + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_permuted], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_permuted + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_permuted], + torch.tensor(permuted_block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + paged_h2d_permuted = [torch.zeros_like(layer) for layer in paged_layers_permuted] + ops.multi_layer_block_kv_transfer( + paged_h2d_permuted + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_permuted], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_permuted + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_permuted], + torch.tensor(permuted_block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_permuted[i].cpu() + recon = paged_h2d_permuted[i].cpu() + results[f"permuted_nhd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"Permuted NHD Layer {i} round-trip mismatch" + ) + + # --- Multi-chunk (num_blocks > blocks_per_chunk) --- + torch.manual_seed(606) + num_blocks_mc = 12 + paged_layers_mc = [ + torch.randn(2, num_blocks_mc, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + shape_desc_mc = ops.PageBufferShapeDesc() + shape_desc_mc.nl = num_layers + shape_desc_mc.nb = num_blocks_mc + shape_desc_mc.bs = block_size + shape_desc_mc.nh = num_heads + shape_desc_mc.hs = head_size + shape_desc_mc.element_size = dtype.itemsize + shape_desc_mc.kv_size = 2 + num_chunks_mc = num_blocks_mc // blocks_per_chunk # 3 chunks + d2h_chunks_mc = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks_mc + ) + block_ids_mc = list(range(num_blocks_mc)) + ops.multi_layer_block_kv_transfer( + paged_layers_mc + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_mc], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mc if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mc], + torch.tensor(block_ids_mc, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc_mc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + paged_h2d_mc = [torch.zeros_like(layer) for layer in paged_layers_mc] + ops.multi_layer_block_kv_transfer( + paged_h2d_mc + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_mc], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mc if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mc], + torch.tensor(block_ids_mc, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_mc[i].cpu() + recon = paged_h2d_mc[i].cpu() + results[f"multi_chunk_l{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"Multi-chunk Layer {i} round-trip mismatch" + ) + + return results + + def scenario_record_drain_event(ops: Any, device: str) -> dict[str, torch.Tensor]: """Test record_event_on_stream / drain_recorded_events contracts. @@ -1860,6 +2706,7 @@ def scenario_record_drain_event(ops: Any, device: str) -> dict[str, torch.Tensor "transfer_direction_enum": scenario_transfer_direction_enum, "multi_layer_kv_transfer": scenario_multi_layer_kv_transfer, "multi_layer_kv_transfer_unilateral": scenario_multi_layer_kv_transfer_unilateral, + "multi_layer_block_kv_transfer": scenario_multi_layer_block_kv_transfer, "single_layer_kv_transfer": scenario_single_layer_kv_transfer, "single_layer_kv_transfer_sgl": scenario_single_layer_kv_transfer_sgl, "load_and_reshape_flash": scenario_load_and_reshape_flash, From a12f41a25473e5ec7f7e7514038f4b0070124c64 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 12 Jun 2026 09:46:22 +0800 Subject: [PATCH 26/31] Optimize the Python fallback path for block transfer operations with notable speedup (#3591) * Perf: optimize Python fallback block transfer for 3x speedup - Optimize fallback block-id and D2H staging overhead - Restructure per-layer transfer loops to iterate over objects first then layers Signed-off-by: Tony Lin * apply gemini's suggestion Signed-off-by: Tony Lin * optimize flash_infer block transfer paths in python fallback Signed-off-by: Tony Lin --------- Signed-off-by: Tony Lin --- lmcache/python_ops_fallback.py | 622 ++++++++++++++++++++------------- 1 file changed, 385 insertions(+), 237 deletions(-) diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py index f1b8e15593..69038c61b9 100644 --- a/lmcache/python_ops_fallback.py +++ b/lmcache/python_ops_fallback.py @@ -257,8 +257,8 @@ class TransferDirection(IntEnum): D2H = 1 -class GPUKVFormat(IntEnum): - """Enumeration of different GPU KV cache memory layouts.""" +class EngineKVFormat(IntEnum): + """Enumeration of different engine KV cache memory layouts.""" # used by: vLLM CROSS_LAYER mode NB_NL_TWO_BS_NH_HS = 0 @@ -290,6 +290,16 @@ class GPUKVFormat(IntEnum): # used by: SGLang MHA via the MP daemon path TWO_X_NL_X_NB_BS_NH_HS = 9 + # used by: vLLM non-MLA blocks-first attention with K/V fused into the + # trailing dim. Per-layer physical shape + # [num_blocks, num_heads, block_size, 2, head_size] -- the K/V "2" axis is + # second-to-last, recovered by splitting the fused [..., 2 * head_size]. + NL_X_NB_NH_BS_TWO_HS = 10 + + +# Backward-compat alias +GPUKVFormat = EngineKVFormat + class PageBufferShapeDesc: """Python stand-in for the C++ ``PageBufferShapeDesc`` struct. @@ -539,7 +549,7 @@ def multi_layer_kv_transfer( paged_memory_device: torch.device, page_buffer_size: int, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, block_size: int = 0, head_size: int = 0, skip_prefix_n_tokens: int = 0, @@ -555,9 +565,9 @@ def multi_layer_kv_transfer( # TODO: Implement head_size support for HND layouts (NL_X_TWO_NB_NH_BS_HS, # NL_X_NB_TWO_NH_BS_HS) as next step. - if int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS), - int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS), + if int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS), ): raise NotImplementedError( "HND layouts (NL_X_TWO_NB_NH_BS_HS, NL_X_NB_TWO_NH_BS_HS) " @@ -585,11 +595,11 @@ def multi_layer_kv_transfer( valid_slots = slots_kv[valid_mask_kv].to(paged_memory_device) # 2. Determine architecture variant and tensor dimensions. - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) - is_flash_infer = int(gpu_kv_format) == int(GPUKVFormat.NL_X_NB_TWO_BS_NH_HS) + is_flash_infer = int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_TWO_BS_NH_HS) num_layers = key_value.size(1) hidden_size = key_value.size(3) @@ -673,7 +683,7 @@ def multi_layer_kv_transfer_unilateral( paged_memory_device: torch.device, page_buffer_size: int, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, ): """ Python fallback for multi_layer_kv_transfer_unilateral @@ -696,9 +706,9 @@ def multi_layer_kv_transfer_unilateral( H2D = LMCache -> PagedBuffer D2H = PagedBuffer -> LMCache """ - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) # MLA case collapses back to multi_layer_kv_transfer @@ -711,7 +721,7 @@ def multi_layer_kv_transfer_unilateral( paged_memory_device, page_buffer_size, direction, - gpu_kv_format, + engine_kv_format, 0, # block_size unused for MLA formats ) # ── Non-MLA path: unilateral (separate K/V buffers per layer) ── @@ -748,35 +758,36 @@ def multi_layer_kv_transfer_unilateral( key_value[kv_idx, layer_id, valid_mask_kv, :] = gathered.to(kv_device) -def _is_cross_layer_format(gpu_kv_format: GPUKVFormat) -> bool: +def _is_cross_layer_format(engine_kv_format: EngineKVFormat) -> bool: """Return True when a KV format uses a single cross-layer tensor.""" - return int(gpu_kv_format) in ( - int(GPUKVFormat.NB_NL_TWO_BS_NH_HS), - int(GPUKVFormat.NB_NL_TWO_NH_BS_HS), + return int(engine_kv_format) in ( + int(EngineKVFormat.NB_NL_TWO_BS_NH_HS), + int(EngineKVFormat.NB_NL_TWO_NH_BS_HS), ) -def _is_sglang_mha_format(gpu_kv_format: GPUKVFormat) -> bool: +def _is_sglang_mha_format(engine_kv_format: EngineKVFormat) -> bool: """Return True when a KV format uses SGLang MHA layout (2*NL tensors).""" - return int(gpu_kv_format) in ( - int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS), - int(GPUKVFormat.TWO_X_NL_X_NB_BS_NH_HS), + return int(engine_kv_format) in ( + int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS), + int(EngineKVFormat.TWO_X_NL_X_NB_BS_NH_HS), ) -def _is_hnd_format(gpu_kv_format: GPUKVFormat) -> bool: +def _is_hnd_format(engine_kv_format: EngineKVFormat) -> bool: """Return True when a per-layer KV format stores heads before block tokens (HND).""" - return int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS), - int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS), + return int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS), ) -def _is_mla_format(gpu_kv_format: GPUKVFormat) -> bool: +def _is_mla_format(engine_kv_format: EngineKVFormat) -> bool: """Return True when a KV format uses MLA paged layout.""" - return int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + return int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) @@ -802,7 +813,7 @@ def _is_ptr_tensor(x: object) -> bool: def _per_layer_paged_shape( - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, nb: int, bs: int, nh: int, @@ -811,7 +822,7 @@ def _per_layer_paged_shape( """Return the logical shape of a single per-layer paged buffer tensor. Args: - gpu_kv_format: The format enum that describes how K/V tokens are laid out. + engine_kv_format: The format enum that describes how K/V tokens are laid out. nb: Number of blocks in the paged buffer (``shape_desc.nb``). bs: Tokens per block / block size (``shape_desc.bs``). nh: Number of attention heads (``shape_desc.nh``). @@ -821,17 +832,23 @@ def _per_layer_paged_shape( A tuple representing the shape needed to reconstruct one layer's tensor from a raw pointer via :func:`_tensor_from_ptr`. """ - fmt = int(gpu_kv_format) - if fmt == int(GPUKVFormat.NL_X_NBBS_ONE_HS): + fmt = int(engine_kv_format) + if fmt == int(EngineKVFormat.NL_X_NBBS_ONE_HS): return (nb * bs, 1, hs) - if fmt == int(GPUKVFormat.NL_X_NB_BS_HS): + if fmt == int(EngineKVFormat.NL_X_NB_BS_HS): return (nb, bs, hs) - if fmt == int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS): + if fmt == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): return (2, nb, nh, bs, hs) - if fmt == int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS): + if fmt == int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS): return (nb, 2, nh, bs, hs) - if fmt == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS): + if fmt == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + # vLLM CPU blocks-first fused KV: K and V interleaved at the + # second-to-last dim so each layer is [NB, NH, BS, 2, HS]. + return (nb, nh, bs, 2, hs) + if fmt == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): return (2, nb, bs, nh, hs) + if fmt == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + return (nb, nh, bs, 2, hs) # Covers NL_X_NB_TWO_BS_NH_HS and any future NHD variants. return (nb, 2, bs, nh, hs) @@ -888,7 +905,7 @@ def _infer_kv_dtype( def _normalize_paged_layers( paged_buffer_ptrs_tensor: "torch.Tensor | list", - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, shape_desc: "PageBufferShapeDesc | None" = None, device: "torch.device | str | None" = None, dtype: "torch.dtype | None" = None, @@ -905,7 +922,7 @@ def _normalize_paged_layers( - ``list[list[torch.Tensor]]`` (2 x NL) for SGLang MHA formats. - ``list[torch.Tensor]`` (per-layer) for all other formats. """ - if _is_cross_layer_format(gpu_kv_format): + if _is_cross_layer_format(engine_kv_format): if isinstance(paged_buffer_ptrs_tensor, torch.Tensor): if _is_ptr_tensor(paged_buffer_ptrs_tensor): # 1-D pointer tensor with a single entry → reconstruct full tensor. @@ -919,7 +936,7 @@ def _normalize_paged_layers( bs = int(shape_desc.bs) nh = int(shape_desc.nh) hs = int(shape_desc.hs) - if int(gpu_kv_format) == int(GPUKVFormat.NB_NL_TWO_NH_BS_HS): + if int(engine_kv_format) == int(EngineKVFormat.NB_NL_TWO_NH_BS_HS): shape: tuple[int, ...] = (nb, nl, 2, nh, bs, hs) else: shape = (nb, nl, 2, bs, nh, hs) @@ -930,7 +947,7 @@ def _normalize_paged_layers( "Cross-layer formats require a single torch.Tensor input; " "got: " + type(paged_buffer_ptrs_tensor).__name__ ) - if _is_sglang_mha_format(gpu_kv_format): + if _is_sglang_mha_format(engine_kv_format): if _is_ptr_tensor(paged_buffer_ptrs_tensor): # 1-D pointer tensor [K_L0,...,K_LN-1, V_L0,...,V_LN-1] → nested list. if shape_desc is None or device is None or dtype is None: @@ -943,7 +960,7 @@ def _normalize_paged_layers( bs = int(shape_desc.bs) nh = int(shape_desc.nh) hs = int(shape_desc.hs) - is_flat = int(gpu_kv_format) == int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS) + is_flat = int(engine_kv_format) == int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS) per_layer_shape: tuple[int, ...] = ( (nb * bs, nh, hs) if is_flat else (nb, bs, nh, hs) ) @@ -997,7 +1014,7 @@ def _normalize_paged_layers( bs = int(shape_desc.bs) nh = int(shape_desc.nh) hs = int(shape_desc.hs) - per_shape = _per_layer_paged_shape(gpu_kv_format, nb, bs, nh, hs) + per_shape = _per_layer_paged_shape(engine_kv_format, nb, bs, nh, hs) return [ _tensor_from_ptr(int(p.item()), per_shape, dtype, device) for p in paged_buffer_ptrs_tensor @@ -1018,14 +1035,14 @@ def _normalize_lmcache_objects( lmcache_objects_ptrs: "list[int] | list[torch.Tensor]", shape_desc: "PageBufferShapeDesc | None" = None, lmcache_chunk_size: "int | None" = None, - gpu_kv_format: "GPUKVFormat | None" = None, + engine_kv_format: "EngineKVFormat | None" = None, dtype: "torch.dtype | None" = None, ) -> list[torch.Tensor]: """Normalize LMCache object inputs to chunk tensors. Accepts either a list of chunk tensors or a ``list[int]`` of raw CPU pointers. When a pointer list is provided *shape_desc*, *lmcache_chunk_size*, - *gpu_kv_format*, and *dtype* must be supplied so the tensors can be + *engine_kv_format*, and *dtype* must be supplied so the tensors can be reconstructed via :func:`_tensor_from_ptr` on the CPU. """ if not isinstance(lmcache_objects_ptrs, list): @@ -1042,19 +1059,19 @@ def _normalize_lmcache_objects( if ( shape_desc is None or lmcache_chunk_size is None - or gpu_kv_format is None + or engine_kv_format is None or dtype is None ): raise ValueError( "_normalize_lmcache_objects: shape_desc, lmcache_chunk_size, " - "gpu_kv_format, and dtype are required when lmcache_objects_ptrs " + "engine_kv_format, and dtype are required when lmcache_objects_ptrs " "contains raw int pointers" ) nl = int(shape_desc.nl) nh = int(shape_desc.nh) hs = int(shape_desc.hs) chunk_tokens = lmcache_chunk_size - if _is_mla_format(gpu_kv_format): + if _is_mla_format(engine_kv_format): chunk_shape: tuple[int, ...] = (nl, chunk_tokens, hs) else: chunk_shape = (2, nl, chunk_tokens, nh * hs) @@ -1085,7 +1102,7 @@ def multi_layer_block_kv_transfer( direction: TransferDirection, shape_desc: PageBufferShapeDesc, lmcache_chunk_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, skip_prefix_n_blocks: int, ) -> None: """Python fallback implementation of block-based multi-layer KV transfer. @@ -1102,7 +1119,7 @@ def multi_layer_block_kv_transfer( direction: Transfer direction (H2D or D2H). shape_desc: Shape descriptor of the page buffer. lmcache_chunk_size: Chunk size of LMCache objects. - gpu_kv_format: GPU KV cache format. + engine_kv_format: GPU KV cache format. skip_prefix_n_blocks: Number of leading blocks to skip. Returns: @@ -1131,7 +1148,7 @@ def multi_layer_block_kv_transfer( ) normalized = _normalize_paged_layers( paged_buffer_ptrs_tensor, - gpu_kv_format, + engine_kv_format, shape_desc=shape_desc, device=device, dtype=kv_dtype, @@ -1140,54 +1157,62 @@ def multi_layer_block_kv_transfer( lmcache_objects_ptrs, shape_desc=shape_desc, lmcache_chunk_size=lmcache_chunk_size, - gpu_kv_format=gpu_kv_format, + engine_kv_format=engine_kv_format, dtype=kv_dtype, ) - block_id_list = _to_block_id_list(block_ids) + n_block_ids = ( + int(block_ids.numel()) + if isinstance(block_ids, torch.Tensor) + else len(block_ids) + ) blocks_per_object = lmcache_chunk_size // int(shape_desc.bs) block_size = int(shape_desc.bs) - if _is_cross_layer_format(gpu_kv_format): + if _is_cross_layer_format(engine_kv_format): _transfer_cross_layer( normalized, object_tensors, - block_id_list, + block_ids, + n_block_ids, blocks_per_object, block_size, - gpu_kv_format, + engine_kv_format, is_d2h, skip_prefix_n_blocks, ) - elif _is_sglang_mha_format(gpu_kv_format): + elif _is_sglang_mha_format(engine_kv_format): _transfer_sglang_mha( normalized, object_tensors, - block_id_list, + block_ids, + n_block_ids, blocks_per_object, block_size, - gpu_kv_format, + engine_kv_format, is_d2h, skip_prefix_n_blocks, ) - elif _is_mla_format(gpu_kv_format): + elif _is_mla_format(engine_kv_format): _transfer_per_layer_mla( normalized, object_tensors, - block_id_list, + block_ids, + n_block_ids, blocks_per_object, block_size, - gpu_kv_format, + engine_kv_format, is_d2h, skip_prefix_n_blocks, ) - elif _is_hnd_format(gpu_kv_format): + elif _is_hnd_format(engine_kv_format): _transfer_per_layer_hnd( normalized, object_tensors, - block_id_list, + block_ids, + n_block_ids, blocks_per_object, block_size, - gpu_kv_format, + engine_kv_format, is_d2h, skip_prefix_n_blocks, ) @@ -1195,10 +1220,11 @@ def multi_layer_block_kv_transfer( _transfer_per_layer_nhd( normalized, object_tensors, - block_id_list, + block_ids, + n_block_ids, blocks_per_object, block_size, - gpu_kv_format, + engine_kv_format, is_d2h, skip_prefix_n_blocks, ) @@ -1234,19 +1260,37 @@ def _valid_block_range( return block_id_list[valid_flat_start:valid_flat_end], offset_in_object +def _valid_block_range_indices( + object_idx: int, + n_block_ids: int, + blocks_per_object: int, + block_size: int, + skip_prefix_n_blocks: int, +) -> tuple[int, int, int] | None: + """Return valid [start, end) range over flat block IDs and object token offset.""" + object_flat_start = object_idx * blocks_per_object + valid_flat_start = max(object_flat_start, skip_prefix_n_blocks) + valid_flat_end = min(object_flat_start + blocks_per_object, n_block_ids) + if valid_flat_start >= valid_flat_end: + return None + offset_in_object = (valid_flat_start - object_flat_start) * block_size + return valid_flat_start, valid_flat_end, offset_in_object + + def _transfer_cross_layer( paged_tensor: torch.Tensor, object_tensors: list[torch.Tensor], - block_id_list: list[int], + block_ids: torch.Tensor | list[int], + n_block_ids: int, blocks_per_object: int, block_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, is_d2h: bool, skip_prefix_n_blocks: int, ) -> None: """Handle cross-layer formats: single tensor [NB, NL, 2, ...].""" # NHD: [NB, NL, 2, BS, NH, HS] HND: [NB, NL, 2, NH, BS, HS] - is_hnd = int(gpu_kv_format) == int(GPUKVFormat.NB_NL_TWO_NH_BS_HS) + is_hnd = int(engine_kv_format) == int(EngineKVFormat.NB_NL_TWO_NH_BS_HS) num_layers = paged_tensor.shape[1] if is_hnd: @@ -1261,23 +1305,24 @@ def _transfer_cross_layer( # H2D: pre-transfer objects to paged device if not is_d2h and object_tensors: objs_on_device = [obj.to(paged_tensor.device) for obj in object_tensors] + block_ids_dev = torch.as_tensor( + block_ids, dtype=torch.long, device=paged_tensor.device + ) for object_idx, obj in enumerate(object_tensors): - valid = _valid_block_range( + valid = _valid_block_range_indices( object_idx, - block_id_list, + n_block_ids, blocks_per_object, block_size, skip_prefix_n_blocks, ) if valid is None: continue - engine_block_ids, offset_in_object = valid - n_valid = len(engine_block_ids) + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start token_end = offset_in_object + n_valid * block_size - eff_idx = torch.tensor( - engine_block_ids, dtype=torch.long, device=paged_tensor.device - ) + eff_idx = block_ids_dev[idx_start:idx_end] if is_d2h: selected = paged_tensor.index_select(0, eff_idx) @@ -1317,17 +1362,18 @@ def _transfer_cross_layer( def _transfer_sglang_mha( paged_tensors: list[list[torch.Tensor]], object_tensors: list[torch.Tensor], - block_id_list: list[int], + block_ids: torch.Tensor | list[int], + n_block_ids: int, blocks_per_object: int, block_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, is_d2h: bool, skip_prefix_n_blocks: int, ) -> None: """Handle SGLang MHA formats: 2*NL tensors (list[list[Tensor]]).""" # TWO_X_NL_X_NBBS_NH_HS: each tensor [NB*BS, NH, HS] # TWO_X_NL_X_NB_BS_NH_HS: each tensor [NB, BS, NH, HS] - is_flat = int(gpu_kv_format) == int(GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS) + is_flat = int(engine_kv_format) == int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS) num_layers = len(paged_tensors[0]) # Determine target device from first tensor @@ -1336,21 +1382,22 @@ def _transfer_sglang_mha( # H2D: pre-transfer objects if not is_d2h and object_tensors: objs_on_device = [obj.to(target_device) for obj in object_tensors] + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) for object_idx, obj in enumerate(object_tensors): - valid = _valid_block_range( + valid = _valid_block_range_indices( object_idx, - block_id_list, + n_block_ids, blocks_per_object, block_size, skip_prefix_n_blocks, ) if valid is None: continue - engine_block_ids, offset_in_object = valid - n_valid = len(engine_block_ids) + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start token_end = offset_in_object + n_valid * block_size - eff_idx = torch.tensor(engine_block_ids, dtype=torch.long, device=target_device) + eff_idx = block_ids_dev[idx_start:idx_end] if is_flat: # Flat token positions for all valid blocks: # block_id * block_size + token offset. Reused across layer/KV pairs. @@ -1395,57 +1442,68 @@ def _transfer_sglang_mha( def _transfer_per_layer_mla( layer_tensors: list[torch.Tensor], object_tensors: list[torch.Tensor], - block_id_list: list[int], + block_ids: torch.Tensor | list[int], + n_block_ids: int, blocks_per_object: int, block_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, is_d2h: bool, skip_prefix_n_blocks: int, ) -> None: """Handle MLA per-layer formats: [NB, BS, HS].""" - if not is_d2h and layer_tensors and object_tensors: - target_device = layer_tensors[0].device - objs_on_device = [obj.to(target_device) for obj in object_tensors] + if not layer_tensors or not object_tensors: + return + + is_flat = int(engine_kv_format) == int(EngineKVFormat.NL_X_NBBS_ONE_HS) + target_device = layer_tensors[0].device + if is_flat: + token_offsets = torch.arange(block_size, dtype=torch.long, device=target_device) + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) - for layer_idx, layer in enumerate(layer_tensors): - is_flat = int(gpu_kv_format) == int(GPUKVFormat.NL_X_NBBS_ONE_HS) + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] if is_flat: - token_offsets = torch.arange( - block_size, dtype=torch.long, device=layer.device - ) - for object_idx, obj in enumerate(object_tensors): - valid = _valid_block_range( - object_idx, - block_id_list, - blocks_per_object, - block_size, - skip_prefix_n_blocks, - ) - if valid is None: - continue - engine_block_ids, offset_in_object = valid - n_valid = len(engine_block_ids) - token_end = offset_in_object + n_valid * block_size - eff_idx = torch.tensor( - engine_block_ids, dtype=torch.long, device=layer.device - ) - if is_flat: - token_indices = ( - eff_idx[:, None] * block_size + token_offsets[None, :] - ).reshape(-1) + token_indices = ( + eff_idx[:, None] * block_size + token_offsets[None, :] + ).reshape(-1) - if is_d2h: + if is_d2h: + hidden_size = layer_tensors[0].shape[-1] + chunk_gpu = torch.empty( + len(layer_tensors), + n_valid * block_size, + hidden_size, + dtype=layer_tensors[0].dtype, + device=target_device, + ) + for layer_idx, layer in enumerate(layer_tensors): if is_flat: - layer_blocks = layer.index_select(0, token_indices) + dst = chunk_gpu[layer_idx].view( + n_valid * block_size, 1, hidden_size + ) + torch.index_select(layer, 0, token_indices, out=dst) else: - layer_blocks = layer.index_select(0, eff_idx) - flat = layer_blocks.reshape(n_valid * block_size, layer.shape[-1]) - obj[layer_idx, offset_in_object:token_end].copy_( - flat, non_blocking=True - ) - else: - obj_device = objs_on_device[object_idx] - src = obj_device[layer_idx, offset_in_object:token_end] + dst = chunk_gpu[layer_idx].view(n_valid, block_size, hidden_size) + torch.index_select(layer, 0, eff_idx, out=dst) + obj[:, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, offset_in_object:token_end].to( + target_device, non_blocking=True + ) + for layer_idx, layer in enumerate(layer_tensors): + src = chunk_gpu[layer_idx] hidden_size = layer.shape[-1] if is_flat: src_tokens = src.reshape(n_valid * block_size, 1, hidden_size) @@ -1458,141 +1516,231 @@ def _transfer_per_layer_mla( def _transfer_per_layer_hnd( layer_tensors: list[torch.Tensor], object_tensors: list[torch.Tensor], - block_id_list: list[int], + block_ids: torch.Tensor | list[int], + n_block_ids: int, blocks_per_object: int, block_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, is_d2h: bool, skip_prefix_n_blocks: int, ) -> None: """Handle per-layer HND formats: heads before block tokens.""" - if not is_d2h and layer_tensors and object_tensors: - target_device = layer_tensors[0].device - objs_on_device = [obj.to(target_device) for obj in object_tensors] + if not layer_tensors or not object_tensors: + return - for layer_idx, layer in enumerate(layer_tensors): - # Determine K/V split based on specific format - if int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS): - k_t, v_t = layer[0], layer[1] - else: - k_t, v_t = layer[:, 0], layer[:, 1] - _nb, nh, _bs, hs = k_t.shape - - for object_idx, obj in enumerate(object_tensors): - valid = _valid_block_range( - object_idx, - block_id_list, - blocks_per_object, + target_device = layer_tensors[0].device + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + first_layer = layer_tensors[0] + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + first_k = first_layer[0] + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + first_k = first_layer[:, :, :, 0] + else: + first_k = first_layer[:, 0] + _nb0, nh0, _bs0, hs0 = first_k.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + + if is_d2h: + chunk_gpu = torch.empty( + 2, + len(layer_tensors), + n_valid * block_size, + nh0 * hs0, + dtype=first_k.dtype, + device=target_device, + ) + scratch = torch.empty( + n_valid, + nh0, block_size, - skip_prefix_n_blocks, + hs0, + dtype=first_k.dtype, + device=target_device, ) - if valid is None: - continue - engine_block_ids, offset_in_object = valid - n_valid = len(engine_block_ids) - token_end = offset_in_object + n_valid * block_size - eff_idx = torch.tensor( - engine_block_ids, dtype=torch.long, device=layer.device + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + k_t, v_t = layer[0], layer[1] + torch.index_select(k_t, 0, eff_idx, out=scratch) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + torch.index_select(v_t, 0, eff_idx, out=scratch) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + k_t, v_t = layer[:, :, :, 0], layer[:, :, :, 1] + torch.index_select(k_t, 0, eff_idx, out=scratch) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + torch.index_select(v_t, 0, eff_idx, out=scratch) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + else: + # FlashInfer HND stores KV as [NB, 2, NH, BS, HS]. + # Gather on dim=0 first so reads stay contiguous in memory; + # index_select on layer[:, 0]/layer[:, 1] non-contiguous views + # triggers slower element-wise gather reads. + selected = layer.index_select(0, eff_idx) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + selected[:, 0].permute(0, 2, 1, 3) + ) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + selected[:, 1].permute(0, 2, 1, 3) + ) + obj[:, :, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, :, offset_in_object:token_end].to( + target_device, non_blocking=True ) - - if is_d2h: + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + k_t, v_t = layer[0], layer[1] + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + k_t, v_t = layer[:, :, :, 0], layer[:, :, :, 1] + else: + k_t, v_t = layer[:, 0], layer[:, 1] + _nb, nh, _bs, hs = k_t.shape k_blocks = ( - k_t.index_select(0, eff_idx) + chunk_gpu[0, layer_idx] + .reshape(n_valid, block_size, nh, hs) .permute(0, 2, 1, 3) - .reshape(n_valid * block_size, nh * hs) ) v_blocks = ( - v_t.index_select(0, eff_idx) + chunk_gpu[1, layer_idx] + .reshape(n_valid, block_size, nh, hs) .permute(0, 2, 1, 3) - .reshape(n_valid * block_size, nh * hs) - ) - obj[0, layer_idx, offset_in_object:token_end].copy_( - k_blocks, non_blocking=True ) - obj[1, layer_idx, offset_in_object:token_end].copy_( - v_blocks, non_blocking=True - ) - else: - obj_device = objs_on_device[object_idx] - k_src = obj_device[0, layer_idx, offset_in_object:token_end] - v_src = obj_device[1, layer_idx, offset_in_object:token_end] - k_blocks = k_src.reshape(n_valid, block_size, nh, hs).permute( - 0, 2, 1, 3 - ) - v_blocks = v_src.reshape(n_valid, block_size, nh, hs).permute( - 0, 2, 1, 3 - ) - k_t.index_copy_(0, eff_idx, k_blocks) - v_t.index_copy_(0, eff_idx, v_blocks) + if int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS): + layer.index_copy_( + 0, eff_idx, torch.stack([k_blocks, v_blocks], dim=1) + ) + else: + k_t.index_copy_(0, eff_idx, k_blocks) + v_t.index_copy_(0, eff_idx, v_blocks) def _transfer_per_layer_nhd( layer_tensors: list[torch.Tensor], object_tensors: list[torch.Tensor], - block_id_list: list[int], + block_ids: torch.Tensor | list[int], + n_block_ids: int, blocks_per_object: int, block_size: int, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, is_d2h: bool, skip_prefix_n_blocks: int, ) -> None: """Handle per-layer NHD formats: block tokens before heads.""" - if not is_d2h and layer_tensors and object_tensors: - target_device = layer_tensors[0].device - objs_on_device = [obj.to(target_device) for obj in object_tensors] + if not layer_tensors or not object_tensors: + return - for layer_idx, layer in enumerate(layer_tensors): - # Determine K/V split based on specific format - if int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS): - k_t, v_t = layer[0], layer[1] - else: - k_t, v_t = layer[:, 0], layer[:, 1] - _nb, _bs, nh, hs = k_t.shape - - for object_idx, obj in enumerate(object_tensors): - valid = _valid_block_range( - object_idx, - block_id_list, - blocks_per_object, - block_size, - skip_prefix_n_blocks, + target_device = layer_tensors[0].device + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + first_layer = layer_tensors[0] + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + first_k = first_layer[0] + else: + first_k = first_layer[:, 0] + _nb0, _bs0, nh0, hs0 = first_k.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + + if is_d2h: + chunk_gpu = torch.empty( + 2, + len(layer_tensors), + n_valid * block_size, + nh0 * hs0, + dtype=first_k.dtype, + device=target_device, ) - if valid is None: - continue - engine_block_ids, offset_in_object = valid - n_valid = len(engine_block_ids) - token_end = offset_in_object + n_valid * block_size - eff_idx = torch.tensor( - engine_block_ids, dtype=torch.long, device=layer.device + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + k_t, v_t = layer[0], layer[1] + torch.index_select( + k_t, + 0, + eff_idx, + out=chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0), + ) + torch.index_select( + v_t, + 0, + eff_idx, + out=chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0), + ) + else: + # FlashInfer NHD stores KV as [NB, 2, BS, NH, HS]. + # Gather on dim=0 first to avoid index_select from + # non-contiguous layer[:, 0]/layer[:, 1] views, which + # trigger slower element-wise gather reads. + selected = layer.index_select(0, eff_idx) + chunk_gpu[0, layer_idx].copy_( + selected[:, 0].reshape(n_valid * block_size, nh0 * hs0) + ) + chunk_gpu[1, layer_idx].copy_( + selected[:, 1].reshape(n_valid * block_size, nh0 * hs0) + ) + obj[:, :, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, :, offset_in_object:token_end].to( + target_device, non_blocking=True ) - - if is_d2h: - k_blocks = k_t.index_select(0, eff_idx).reshape( - n_valid * block_size, nh * hs - ) - v_blocks = v_t.index_select(0, eff_idx).reshape( - n_valid * block_size, nh * hs - ) - obj[0, layer_idx, offset_in_object:token_end].copy_( - k_blocks, non_blocking=True - ) - obj[1, layer_idx, offset_in_object:token_end].copy_( - v_blocks, non_blocking=True - ) - else: - obj_device = objs_on_device[object_idx] - k_src = obj_device[0, layer_idx, offset_in_object:token_end] - v_src = obj_device[1, layer_idx, offset_in_object:token_end] - k_t.index_copy_( - 0, - eff_idx, - k_src.reshape(n_valid, block_size, nh, hs), - ) - v_t.index_copy_( - 0, - eff_idx, - v_src.reshape(n_valid, block_size, nh, hs), - ) + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + k_t, v_t = layer[0], layer[1] + k_t.index_copy_( + 0, + eff_idx, + chunk_gpu[0, layer_idx].reshape(n_valid, block_size, nh0, hs0), + ) + v_t.index_copy_( + 0, + eff_idx, + chunk_gpu[1, layer_idx].reshape(n_valid, block_size, nh0, hs0), + ) + else: + k_blocks = chunk_gpu[0, layer_idx].reshape( + n_valid, block_size, nh0, hs0 + ) + v_blocks = chunk_gpu[1, layer_idx].reshape( + n_valid, block_size, nh0, hs0 + ) + layer.index_copy_( + 0, eff_idx, torch.stack([k_blocks, v_blocks], dim=1) + ) def single_layer_kv_transfer( @@ -1600,7 +1748,7 @@ def single_layer_kv_transfer( vllm_key_value_cache: torch.Tensor, slot_mapping: torch.Tensor, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, token_major: bool = False, ): """ @@ -1638,9 +1786,9 @@ def single_layer_kv_transfer( valid_token_indices = torch.nonzero(valid_mask_kv, as_tuple=True)[0] valid_slots = slots_kv[valid_mask_kv].to(paged_memory_device) - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) if is_mla: @@ -1665,7 +1813,7 @@ def single_layer_kv_transfer( else: # ── Non-MLA format ── # Determine vLLM layout and block_size - is_two_major = int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS) + is_two_major = int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS) # flash attn: # [2, num_blocks, block_size, num_heads, head_size] # -> dim2 = block_size From 49e060d5cde0921af3c1f9ca05052ecbafb92442 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 17 Jun 2026 05:32:22 +0000 Subject: [PATCH 27/31] add log Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/modules/gpu_transfer.py | 10 ++++++++++ .../multiprocess/transfer_context/worker_transfer.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 4ff7359956..2e3f45e56d 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -501,6 +501,16 @@ def store( _t_record_start = time.perf_counter() event.record() _t_record_end = time.perf_counter() + + _t_sync_start = time.perf_counter() + event.synchronize() # 等 GPU 上所有 kernel + D2H 真正完成 + _t_sync_end = time.perf_counter() + logger.info( + "[GPU-STORE-SYNC] req=%s gpu_sync=%.3f total_with_sync=%.3f ms", + str(key.request_id), + (_t_sync_end - _t_sync_start) * 1000, + (_t_sync_end - st) * 1000, + ) # Fail closed: commit the reserved objects only when every chunk # copied successfully; otherwise the whole store is skipped. stored_count = len(reserved_dict) if store_succeeded else 0 diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 4b475bef76..4d3e3bb84d 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -559,8 +559,8 @@ def create_transfer_context( if resolved_mode is MPTransferMode.DATA: return _build_data_context(kv_caches) # AUTO: preserve the historical device-type-based dispatch. - # if device_type == "cuda": - # return HandleTransferContext() + if device_type == "cuda": + return HandleTransferContext() return _build_data_context(kv_caches) From 3e6deea8abf415a6f489fececb39f5a70ddf0b15 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 17 Jun 2026 06:04:18 +0000 Subject: [PATCH 28/31] add use_c_ops --- .../v1/multiprocess/modules/gpu_transfer.py | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 2e3f45e56d..3b9be50e3e 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -434,6 +434,7 @@ def store( _t_record_start = _t_publish _t_record_end = _t_publish _t_callback_end = _t_publish + use_c_ops = False try: layout_desc = get_layout_desc(gpu_context, self._ctx.chunk_size) reserved_dict = self._ctx.storage_manager.reserve_write( @@ -460,30 +461,49 @@ def store( chunk_block_ids_gpu = block_ids_per_group_gpu[group_idx][ idx * bpc : (idx + 1) * bpc ] - tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx) - group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx) + if use_c_ops: + tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx) + group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx) # Kernel contract: ``group_lmcache_chunk_size`` here is the # number of *physical* slots per chunk for this group # (= logical chunk_size // compress_ratio). group_lmcache_chunk_size = gpu_context.get_physical_chunk_size( group_idx ) - lmc_ops.multi_layer_block_kv_transfer( - group_kv_pointers, - [tmp_buffer.data_ptr()], - chunk_block_ids_gpu, - gpu_context.device, - lmc_ops.TransferDirection.D2H, - gpu_context.get_shape_desc(group_idx), - group_lmcache_chunk_size, - gpu_context.gpu_kv_format_, - 0, - ) + if use_c_ops: + lmc_ops.multi_layer_block_kv_transfer( + group_kv_pointers, + [tmp_buffer.data_ptr()], + chunk_block_ids_gpu, + gpu_context.device, + lmc_ops.TransferDirection.D2H, + gpu_context.get_shape_desc(group_idx), + group_lmcache_chunk_size, + gpu_context.gpu_kv_format_, + 0, + ) + else: + group = gpu_context.kv_layer_groups_manager.kv_layer_groups[group_idx] + kv_tensors = [gpu_context.kv_tensors[i] for i in group.layer_indices] + + lmc_ops.multi_layer_block_kv_transfer( + kv_tensors, + [memory_obj.tensor], + chunk_block_ids_gpu, + gpu_context.device, + lmc_ops.TransferDirection.D2H, + gpu_context.get_shape_desc(group_idx), + group_lmcache_chunk_size, + gpu_context.gpu_kv_format_, + 0, + ) + _t_kernel_end = time.perf_counter() - # Store is not batched, so we always use chunk_idx=0 (single slot) - lmcache_memcpy_async_d2h( - gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj - ) + if use_c_ops: + # Store is not batched, so we always use chunk_idx=0 (single slot) + lmcache_memcpy_async_d2h( + gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj + ) _t_memcpy_end = time.perf_counter() logger.info( "[GPU-STORE-CHUNK] req=%s chunk_idx=%d kernel=%.3f memcpy_d2h=%.3f ms", From 1a15c7987b3a0cec8309c06d6702b67cdc39442f Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Thu, 18 Jun 2026 02:43:01 +0000 Subject: [PATCH 29/31] force data transfer --- lmcache/v1/multiprocess/modules/gpu_transfer.py | 5 +++-- lmcache/v1/multiprocess/transfer_context/worker_transfer.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 3b9be50e3e..c9f8152edf 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -434,7 +434,7 @@ def store( _t_record_start = _t_publish _t_record_end = _t_publish _t_callback_end = _t_publish - use_c_ops = False + use_c_ops = True try: layout_desc = get_layout_desc(gpu_context, self._ctx.chunk_size) reserved_dict = self._ctx.storage_manager.reserve_write( @@ -523,7 +523,8 @@ def store( _t_record_end = time.perf_counter() _t_sync_start = time.perf_counter() - event.synchronize() # 等 GPU 上所有 kernel + D2H 真正完成 + # hlin99: debug mode + # event.synchronize() # 等 GPU 上所有 kernel + D2H 真正完成 _t_sync_end = time.perf_counter() logger.info( "[GPU-STORE-SYNC] req=%s gpu_sync=%.3f total_with_sync=%.3f ms", diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 4d3e3bb84d..4b475bef76 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -559,8 +559,8 @@ def create_transfer_context( if resolved_mode is MPTransferMode.DATA: return _build_data_context(kv_caches) # AUTO: preserve the historical device-type-based dispatch. - if device_type == "cuda": - return HandleTransferContext() + # if device_type == "cuda": + # return HandleTransferContext() return _build_data_context(kv_caches) From 5ecb60ecd5bd8877788dfe46caa8db7c709a7ea9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Jun 2026 02:45:11 +0000 Subject: [PATCH 30/31] feat: add detailed timing instrumentation for non-GPU store path Add per-step timing logs to the server-side non-GPU store path so that performance can be profiled comparably to the GPU IPC path. non_gpu_transfer.py: - [SRV-PREPARE-STORE]: times context/strategy lookup (resolve_keys) and strategy.prepare_store() call, logs strategy name (shm/pickle) - [SRV-COMMIT-STORE]: times strategy.commit_store() call and total time since prepare, logs strategy name and token count - Imports ShmTransferStrategy for isinstance strategy detection server_transfer.py: - Adds `import time` - [PICKLE-COMMIT]: per-step breakdown of deserialize / reserve_write / copy_loop / finish_write with total (PickleTransferStrategy.commit_store) - [SHM-PREPARE]: per-step breakdown of resolve_keys / reserve_write / slot-descriptor loop with total (ShmTransferStrategy.prepare_store) - [SHM-COMMIT]: finish_write and total timing for the SHM fast path (ShmTransferStrategy.commit_store, non-fallback branch only) All timing uses time.perf_counter(); all logs use %s/%.3f format strings. --- .../multiprocess/modules/non_gpu_transfer.py | 28 +++++++++++++ .../multiprocess/modules/server_transfer.py | 42 ++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py index 0f15026176..47098dfd0d 100644 --- a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py @@ -35,6 +35,7 @@ # Local from .server_transfer import ( + ShmTransferStrategy, TransferStrategy, create_transfer_strategy, ) @@ -296,6 +297,7 @@ def prepare_store( Returns: PrepareStoreResponse with empty slots for pickle mode. """ + t_start = time.perf_counter() entry = self._non_gpu_contexts.get(instance_id) if entry is None: raise ValueError( @@ -306,14 +308,26 @@ def prepare_store( raise ValueError( f"transfer strategy not registered for instance ID {instance_id}" ) + t_resolve = time.perf_counter() response = strategy.prepare_store( key=key, instance_id=instance_id, context=entry.metadata, resolve_obj_keys=self._ctx.resolve_obj_keys, ) + t_prepare = time.perf_counter() session = self._ctx.session_manager.get_or_create(key.request_id) session.extras["store_start_time"] = time.perf_counter() + strategy_name = "shm" if isinstance(strategy, ShmTransferStrategy) else "pickle" + logger.info( + "[SRV-PREPARE-STORE] req=%s resolve_keys=%.3f prepare=%.3f" + " total=%.3f ms (strategy=%s)", + key.request_id, + (t_resolve - t_start) * 1000, + (t_prepare - t_resolve) * 1000, + (t_prepare - t_start) * 1000, + strategy_name, + ) return response @_lmcache_nvtx_annotate @@ -349,6 +363,7 @@ def commit_store( ) session = self._ctx.session_manager.get_or_create(key.request_id) st = session.extras.pop("store_start_time", None) + t_commit_start = time.perf_counter() result = strategy.commit_store( key=key, instance_id=instance_id, @@ -356,13 +371,26 @@ def commit_store( context=entry.metadata, resolve_obj_keys=self._ctx.resolve_obj_keys, ) + t_commit_end = time.perf_counter() if st is not None and result: num_tokens = len(self._ctx.resolve_obj_keys(key)) * self._ctx.chunk_size + strategy_name = ( + "shm" if isinstance(strategy, ShmTransferStrategy) else "pickle" + ) logger.info( "Stored %d tokens in %.3f seconds", num_tokens, time.perf_counter() - st, ) + logger.info( + "[SRV-COMMIT-STORE] req=%s commit=%.3f total_since_prepare=%.3f ms" + " (strategy=%s, num_tokens=%d)", + key.request_id, + (t_commit_end - t_commit_start) * 1000, + (t_commit_end - st) * 1000, + strategy_name, + num_tokens, + ) return result @_lmcache_nvtx_annotate diff --git a/lmcache/v1/multiprocess/modules/server_transfer.py b/lmcache/v1/multiprocess/modules/server_transfer.py index fc967f79db..c2eaa85d50 100644 --- a/lmcache/v1/multiprocess/modules/server_transfer.py +++ b/lmcache/v1/multiprocess/modules/server_transfer.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import abc import pickle +import time # Third Party import torch @@ -208,11 +209,14 @@ def commit_store( Returns: ``True`` when every reserved object is written successfully. """ + t_start = time.perf_counter() obj_keys = resolve_obj_keys(key) chunks: list[torch.Tensor] = pickle.loads(cpu_data) + t_deserialize = time.perf_counter() reserved_dict = self._storage_manager.reserve_write( obj_keys, context.layout_desc, "new" ) + t_reserve_write = time.perf_counter() written_keys: list[ObjectKey] = [] try: for idx, obj_key in enumerate(obj_keys): @@ -229,9 +233,21 @@ def commit_store( memory_obj.tensor.copy_(chunk_cpu) written_keys.append(obj_key) finally: + t_copy_loop = time.perf_counter() if written_keys: self._storage_manager.finish_write(written_keys) - + t_finish_write = time.perf_counter() + logger.info( + "[PICKLE-COMMIT] req=%s deserialize=%.3f reserve_write=%.3f" + " copy_loop=%.3f finish_write=%.3f total=%.3f ms (num_chunks=%d)", + key.request_id, + (t_deserialize - t_start) * 1000, + (t_reserve_write - t_deserialize) * 1000, + (t_copy_loop - t_reserve_write) * 1000, + (t_finish_write - t_copy_loop) * 1000, + (t_finish_write - t_start) * 1000, + len(chunks), + ) return len(written_keys) == len(reserved_dict) def prepare_retrieve( @@ -321,10 +337,13 @@ def prepare_store( Returns: Context with ``slots`` and ``chunk_indices``. """ + t_start = time.perf_counter() obj_keys = resolve_obj_keys(key) + t_resolve = time.perf_counter() reserved = self._storage_manager.reserve_write( obj_keys, context.layout_desc, "new" ) + t_reserve_write = time.perf_counter() slots: list[dict[str, Any]] = [] chunk_indices: list[int] = [] reserved_keys: list[ObjectKey] = [] @@ -350,6 +369,17 @@ def prepare_store( ] if unused_keys: self._storage_manager.finish_write(unused_keys) + t_slots = time.perf_counter() + logger.info( + "[SHM-PREPARE] req=%s resolve_keys=%.3f reserve_write=%.3f" + " slots=%.3f total=%.3f ms (num_slots=%d)", + key.request_id, + (t_resolve - t_start) * 1000, + (t_reserve_write - t_resolve) * 1000, + (t_slots - t_reserve_write) * 1000, + (t_slots - t_start) * 1000, + len(reserved_keys), + ) if not reserved_keys: return PrepareStoreResponse(context={"slots": [], "chunk_indices": []}) transfer_key = self._transfer_key_factory(key, instance_id) @@ -380,13 +410,23 @@ def commit_store( context=context, resolve_obj_keys=resolve_obj_keys, ) + t_start = time.perf_counter() transfer_key = self._transfer_key_factory(key, instance_id) with self._pending_lock: reserved_keys = self._pending_writes.pop(transfer_key, None) if reserved_keys is None: return False + t_before_fw = time.perf_counter() if reserved_keys: self._storage_manager.finish_write(reserved_keys) + t_finish_write = time.perf_counter() + logger.info( + "[SHM-COMMIT] req=%s finish_write=%.3f total=%.3f ms (num_keys=%d)", + key.request_id, + (t_finish_write - t_before_fw) * 1000, + (t_finish_write - t_start) * 1000, + len(reserved_keys), + ) return True def prepare_retrieve( From d2cf6c6ab3eaf55d89a01fce8caf1bb3e8be5199 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Jun 2026 02:46:43 +0000 Subject: [PATCH 31/31] refactor: use strategy_name property instead of isinstance checks Add abstract strategy_name property to TransferStrategy base class, overridden as "pickle" in PickleTransferStrategy and "shm" in ShmTransferStrategy. Update non_gpu_transfer.py to call strategy.strategy_name directly, removing the isinstance coupling and the now-unnecessary ShmTransferStrategy import. --- .../multiprocess/modules/non_gpu_transfer.py | 9 ++------- .../multiprocess/modules/server_transfer.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py index 47098dfd0d..607201d4f6 100644 --- a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py @@ -35,7 +35,6 @@ # Local from .server_transfer import ( - ShmTransferStrategy, TransferStrategy, create_transfer_strategy, ) @@ -318,7 +317,6 @@ def prepare_store( t_prepare = time.perf_counter() session = self._ctx.session_manager.get_or_create(key.request_id) session.extras["store_start_time"] = time.perf_counter() - strategy_name = "shm" if isinstance(strategy, ShmTransferStrategy) else "pickle" logger.info( "[SRV-PREPARE-STORE] req=%s resolve_keys=%.3f prepare=%.3f" " total=%.3f ms (strategy=%s)", @@ -326,7 +324,7 @@ def prepare_store( (t_resolve - t_start) * 1000, (t_prepare - t_resolve) * 1000, (t_prepare - t_start) * 1000, - strategy_name, + strategy.strategy_name, ) return response @@ -374,9 +372,6 @@ def commit_store( t_commit_end = time.perf_counter() if st is not None and result: num_tokens = len(self._ctx.resolve_obj_keys(key)) * self._ctx.chunk_size - strategy_name = ( - "shm" if isinstance(strategy, ShmTransferStrategy) else "pickle" - ) logger.info( "Stored %d tokens in %.3f seconds", num_tokens, @@ -388,7 +383,7 @@ def commit_store( key.request_id, (t_commit_end - t_commit_start) * 1000, (t_commit_end - st) * 1000, - strategy_name, + strategy.strategy_name, num_tokens, ) return result diff --git a/lmcache/v1/multiprocess/modules/server_transfer.py b/lmcache/v1/multiprocess/modules/server_transfer.py index c2eaa85d50..e3e178b6fa 100644 --- a/lmcache/v1/multiprocess/modules/server_transfer.py +++ b/lmcache/v1/multiprocess/modules/server_transfer.py @@ -86,6 +86,15 @@ class TransferStrategy(abc.ABC): shared-memory-based transfers behind a common interface. """ + @property + @abc.abstractmethod + def strategy_name(self) -> str: + """Return a short human-readable name identifying this strategy. + + Returns: + A lowercase string label such as ``"pickle"`` or ``"shm"``. + """ + @abc.abstractmethod def prepare_store( self, @@ -183,6 +192,11 @@ def __init__( """ self._storage_manager = storage_manager + @property + def strategy_name(self) -> str: + """Return ``"pickle"`` as the strategy identifier.""" + return "pickle" + def prepare_store( self, key: IPCCacheEngineKey, @@ -325,6 +339,11 @@ def __init__( self._transfer_key_factory = transfer_key_factory self._fallback_strategy = fallback_strategy + @property + def strategy_name(self) -> str: + """Return ``"shm"`` as the strategy identifier.""" + return "shm" + def prepare_store( self, key: IPCCacheEngineKey,