diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 27fa2de489..988fb313cd 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -496,6 +496,11 @@ class LMCacheMPConnectorMetadata(KVConnectorMetadata): def __init__(self): super().__init__() self.requests: list[LMCacheMPRequestMetadata] = [] + # True only on scheduler steps where preemption/eviction may overwrite + # blocks referenced by in-flight async stores. Worker-side + # handle_preemptions() uses this hint to flush deferred gather + # (device->CPU copy) work before vLLM can overwrite source KV blocks. + self.need_flush: bool = False def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata): self.requests.append(request_metadata) @@ -513,7 +518,7 @@ def __str__(self): f"num_blocks={len(req_meta.op.block_ids[0])}, " f"block_ids={req_meta.op.block_ids})" ) - return "[" + "\n".join(request_strs) + "]" + return f"need_flush={self.need_flush}; [" + "\n".join(request_strs) + "]" def __repr__(self): return self.__str__() @@ -741,6 +746,17 @@ def wait_for_save(self): request_ids, ops, event, cache_salts=cache_salts ) + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None: + """Flush async non-GPU gathers only when scheduler metadata requests it.""" + worker_adapter = getattr(self, "worker_adapter", None) + if self.role != KVConnectorRole.WORKER or worker_adapter is None: + return + need_flush = ( + isinstance(kv_connector_metadata, LMCacheMPConnectorMetadata) + and kv_connector_metadata.need_flush + ) + worker_adapter.handle_preemptions(need_flush) + def get_finished( self, finished_req_ids: set[str] ) -> tuple[set[str] | None, set[str] | None]: @@ -963,6 +979,7 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. 
""" metadata = LMCacheMPConnectorMetadata() + metadata.need_flush = self._scheduler_step_needs_flush(scheduler_output) self._process_retrieve_requests(metadata) self._process_new_requests(scheduler_output, metadata) @@ -976,6 +993,43 @@ def build_connector_meta( return metadata + def _scheduler_step_needs_flush(self, scheduler_output: SchedulerOutput) -> bool: + """Return whether this scheduler step can overwrite preempted blocks. + + Under-syncing here risks KV-block corruption (a paged block may be + overwritten before a deferred async gather reads it), while over-syncing + only costs performance, so we prefer a spurious flush over a missed one. + + Signal fields are verified against vLLM main: + - ``CachedRequestData.resumed_req_ids``: requests resumed from + preemption this step. Their blocks are replaced (not appended), so an + in-flight gather against the old blocks must be flushed first. + - ``SchedulerOutput.preempted_req_ids``: requests preempted this step. + Their blocks are freed and may be reused, so the same applies. + """ + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + + # Primary signal: requests resumed from preemption this step. + if getattr(cached_reqs, "resumed_req_ids", None): + return True + + # Primary signal: requests preempted this step. + if getattr(scheduler_output, "preempted_req_ids", None): + return True + + # Conservative fallback: if cached requests are present but the schema + # exposes no recognized ``resumed_req_ids`` field (e.g. a much older or + # forked vLLM), we cannot prove the step is preemption-free, so flush + # rather than risk corruption. 
+ if cached_reqs is not None and not hasattr(cached_reqs, "resumed_req_ids"): + logger.warning( + "Unrecognized scheduled_cached_reqs schema (%s); conservatively " + "flushing in-flight async gathers to avoid KV block corruption.", + type(cached_reqs).__name__, + ) + return True + return False + def update_connector_output(self, connector_output: KVConnectorOutput): """ Update KVConnector state from worker-side connectors output. diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 86578b22db..ff035dd392 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -7,6 +7,7 @@ import enum import os import threading +import time # Third Party import torch @@ -936,6 +937,10 @@ def __init__( # Prevents re-reporting the same ID after drain clears tracking sets. self._returned_finished: set[str] = set() + # Timestamps recorded when submit_store_request is called, used to + # measure E2E wall-clock time until the future is resolved. 
+ self._store_submit_times: dict[str, float] = {} + self.model_name = model_name self.parallel_strategy = parallel_strategy @@ -1185,6 +1190,7 @@ def submit_store_request( ) self.store_futures[request_id] = future self.store_events[request_id] = event + self._store_submit_times[request_id] = time.perf_counter() @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -1345,6 +1351,7 @@ def get_finished( self.retrieve_futures.clear() self.store_events.clear() self.retrieve_events.clear() + self._store_submit_times.clear() ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine @@ -1363,6 +1370,15 @@ def get_finished( if not s_future.query(): continue + _t_done = time.perf_counter() + _t_submit = self._store_submit_times.pop(request_id, None) + if _t_submit is not None: + logger.info( + "[E2E-STORE] req=%s e2e=%.3f ms", + request_id, + (_t_done - _t_submit) * 1000, + ) + s_result = s_future.result() finished_stores.add(request_id) @@ -1430,6 +1446,18 @@ def get_block_ids_with_load_errors(self) -> set[int]: self.error_block_ids.clear() return errors + def handle_preemptions(self, need_flush: bool) -> None: + """Handle worker-side preemption hints from connector metadata. + + When ``need_flush`` is true, synchronize deferred non-GPU gather work + before the next forward pass can overwrite paged KV blocks. 
+ """ + if not need_flush: + return + if not self.is_healthy or self.transfer_ctx is None: + return + self.transfer_ctx.flush_inflight_gathers() + def shutdown(self): """ Shutdown the LMCache MP worker adapter diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py index 01bc7c2398..69038c61b9 100644 --- a/lmcache/python_ops_fallback.py +++ b/lmcache/python_ops_fallback.py @@ -257,8 +257,8 @@ class TransferDirection(IntEnum): D2H = 1 -class GPUKVFormat(IntEnum): - """Enumeration of different GPU KV cache memory layouts.""" +class EngineKVFormat(IntEnum): + """Enumeration of different engine KV cache memory layouts.""" # used by: vLLM CROSS_LAYER mode NB_NL_TWO_BS_NH_HS = 0 @@ -290,6 +290,16 @@ class GPUKVFormat(IntEnum): # used by: SGLang MHA via the MP daemon path TWO_X_NL_X_NB_BS_NH_HS = 9 + # used by: vLLM non-MLA blocks-first attention with K/V fused into the + # trailing dim. Per-layer physical shape + # [num_blocks, num_heads, block_size, 2, head_size] -- the K/V "2" axis is + # second-to-last, recovered by splitting the fused [..., 2 * head_size]. + NL_X_NB_NH_BS_TWO_HS = 10 + + +# Backward-compat alias +GPUKVFormat = EngineKVFormat + class PageBufferShapeDesc: """Python stand-in for the C++ ``PageBufferShapeDesc`` struct. @@ -539,7 +549,7 @@ def multi_layer_kv_transfer( paged_memory_device: torch.device, page_buffer_size: int, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, block_size: int = 0, head_size: int = 0, skip_prefix_n_tokens: int = 0, @@ -555,9 +565,9 @@ def multi_layer_kv_transfer( # TODO: Implement head_size support for HND layouts (NL_X_TWO_NB_NH_BS_HS, # NL_X_NB_TWO_NH_BS_HS) as next step. 
- if int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_TWO_NB_NH_BS_HS), - int(GPUKVFormat.NL_X_NB_TWO_NH_BS_HS), + if int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS), ): raise NotImplementedError( "HND layouts (NL_X_TWO_NB_NH_BS_HS, NL_X_NB_TWO_NH_BS_HS) " @@ -585,11 +595,11 @@ def multi_layer_kv_transfer( valid_slots = slots_kv[valid_mask_kv].to(paged_memory_device) # 2. Determine architecture variant and tensor dimensions. - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) - is_flash_infer = int(gpu_kv_format) == int(GPUKVFormat.NL_X_NB_TWO_BS_NH_HS) + is_flash_infer = int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_TWO_BS_NH_HS) num_layers = key_value.size(1) hidden_size = key_value.size(3) @@ -673,7 +683,7 @@ def multi_layer_kv_transfer_unilateral( paged_memory_device: torch.device, page_buffer_size: int, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, ): """ Python fallback for multi_layer_kv_transfer_unilateral @@ -696,9 +706,9 @@ def multi_layer_kv_transfer_unilateral( H2D = LMCache -> PagedBuffer D2H = PagedBuffer -> LMCache """ - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) # MLA case collapses back to multi_layer_kv_transfer @@ -711,7 +721,7 @@ def multi_layer_kv_transfer_unilateral( paged_memory_device, page_buffer_size, direction, - gpu_kv_format, + engine_kv_format, 0, # block_size unused for MLA formats ) # ── Non-MLA path: unilateral (separate K/V buffers per layer) ── @@ -748,12 +758,997 @@ def multi_layer_kv_transfer_unilateral( key_value[kv_idx, layer_id, valid_mask_kv, :] = 
gathered.to(kv_device) +def _is_cross_layer_format(engine_kv_format: EngineKVFormat) -> bool: + """Return True when a KV format uses a single cross-layer tensor.""" + return int(engine_kv_format) in ( + int(EngineKVFormat.NB_NL_TWO_BS_NH_HS), + int(EngineKVFormat.NB_NL_TWO_NH_BS_HS), + ) + + +def _is_sglang_mha_format(engine_kv_format: EngineKVFormat) -> bool: + """Return True when a KV format uses SGLang MHA layout (2*NL tensors).""" + return int(engine_kv_format) in ( + int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS), + int(EngineKVFormat.TWO_X_NL_X_NB_BS_NH_HS), + ) + + +def _is_hnd_format(engine_kv_format: EngineKVFormat) -> bool: + """Return True when a per-layer KV format stores heads before block tokens (HND).""" + return int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS), + int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS), + ) + + +def _is_mla_format(engine_kv_format: EngineKVFormat) -> bool: + """Return True when a KV format uses MLA paged layout.""" + return int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), + ) + + +_ELEMENT_SIZE_TO_DTYPE: dict[int, torch.dtype] = { + # Maps the byte width of a KV-cache element to a representative torch dtype. + # Only widths that commonly appear in KV caches are listed; 1-byte entries + # are treated as uint8 (raw bytes), 2-byte as float16, 4-byte as float32. + # Note: bfloat16 also has element_size == 2 but cannot be distinguished here; + # callers that need exact dtype should supply it explicitly. 
+ 1: torch.uint8, + 2: torch.float16, + 4: torch.float32, +} + + +def _is_ptr_tensor(x: object) -> bool: + """Return True when *x* is a 1-D pointer tensor (int64 or uint64).""" + return ( + isinstance(x, torch.Tensor) + and x.dtype in (torch.int64, torch.uint64) + and x.ndim == 1 + ) + + +def _per_layer_paged_shape( + engine_kv_format: EngineKVFormat, + nb: int, + bs: int, + nh: int, + hs: int, +) -> tuple[int, ...]: + """Return the logical shape of a single per-layer paged buffer tensor. + + Args: + engine_kv_format: The format enum that describes how K/V tokens are laid out. + nb: Number of blocks in the paged buffer (``shape_desc.nb``). + bs: Tokens per block / block size (``shape_desc.bs``). + nh: Number of attention heads (``shape_desc.nh``). + hs: Per-head hidden size (``shape_desc.hs``). + + Returns: + A tuple representing the shape needed to reconstruct one layer's tensor + from a raw pointer via :func:`_tensor_from_ptr`. + """ + fmt = int(engine_kv_format) + if fmt == int(EngineKVFormat.NL_X_NBBS_ONE_HS): + return (nb * bs, 1, hs) + if fmt == int(EngineKVFormat.NL_X_NB_BS_HS): + return (nb, bs, hs) + if fmt == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + return (2, nb, nh, bs, hs) + if fmt == int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS): + return (nb, 2, nh, bs, hs) + if fmt == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + # vLLM CPU blocks-first fused KV: K and V interleaved at the + # second-to-last dim so each layer is [NB, NH, BS, 2, HS]. + return (nb, nh, bs, 2, hs) + if fmt == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + return (2, nb, bs, nh, hs) + if fmt == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + return (nb, nh, bs, 2, hs) + # Covers NL_X_NB_TWO_BS_NH_HS and any future NHD variants. + return (nb, 2, bs, nh, hs) + + +def _infer_kv_dtype( + paged_buffer_ptrs_tensor: object, + lmcache_objects_ptrs: object, + shape_desc: "PageBufferShapeDesc", +) -> torch.dtype: + """Infer the KV element dtype from whichever inputs carry it. 
+ + Inference order (first match wins): + 1. ``shape_desc.dtype`` — authoritative when set (requires the + ``set_shape_desc_dtype`` helper from PR #3514; correctly distinguishes + float16 vs bfloat16 which share ``element_size == 2``). + 2. ``paged_buffer_ptrs_tensor`` — if it is a non-pointer tensor or a list + of tensors (including nested SGLang MHA lists), the dtype of the first + tensor is used. + 3. ``lmcache_objects_ptrs`` — if it is a list of tensors, the dtype of the + first chunk tensor is used. + 4. ``shape_desc.element_size`` — looked up in :data:`_ELEMENT_SIZE_TO_DTYPE` + (ambiguous for 2-byte types; kept only as last-resort fallback). + 5. ``torch.bfloat16`` — silent default when no other source is available. + """ + # Prefer shape_desc.dtype — it is exact and avoids the element_size ambiguity. + if shape_desc is not None: + sd_dtype = getattr(shape_desc, "dtype", None) + if sd_dtype is not None: + return sd_dtype + if isinstance(paged_buffer_ptrs_tensor, list) and paged_buffer_ptrs_tensor: + first = paged_buffer_ptrs_tensor[0] + if isinstance(first, list) and first and isinstance(first[0], torch.Tensor): + return first[0].dtype + if isinstance(first, torch.Tensor): + return first.dtype + if isinstance(paged_buffer_ptrs_tensor, torch.Tensor) and not _is_ptr_tensor( + paged_buffer_ptrs_tensor + ): + return paged_buffer_ptrs_tensor.dtype + if isinstance(lmcache_objects_ptrs, list) and lmcache_objects_ptrs: + if isinstance(lmcache_objects_ptrs[0], torch.Tensor): + return lmcache_objects_ptrs[0].dtype + if shape_desc is not None and shape_desc.element_size > 0: + dtype = _ELEMENT_SIZE_TO_DTYPE.get(shape_desc.element_size) + if dtype is None: + raise ValueError( + f"Unsupported element_size {shape_desc.element_size!r} in " + "shape_desc; cannot infer KV dtype. 
" + f"Supported sizes: {sorted(_ELEMENT_SIZE_TO_DTYPE)}" + ) + return dtype + return torch.bfloat16 + + +def _normalize_paged_layers( + paged_buffer_ptrs_tensor: "torch.Tensor | list", + engine_kv_format: EngineKVFormat, + shape_desc: "PageBufferShapeDesc | None" = None, + device: "torch.device | str | None" = None, + dtype: "torch.dtype | None" = None, +) -> "torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]]": + """Normalize paged buffer input based on GPU KV format. + + Accepts either tensor-form inputs (list / Tensor) or a 1-D pointer tensor + (int64 / uint64). When a pointer tensor is provided *shape_desc*, *device*, + and *dtype* must be supplied so the tensors can be reconstructed via + :func:`_tensor_from_ptr`. + + Returns: + - Single ``torch.Tensor`` for cross-layer formats. + - ``list[list[torch.Tensor]]`` (2 x NL) for SGLang MHA formats. + - ``list[torch.Tensor]`` (per-layer) for all other formats. + """ + if _is_cross_layer_format(engine_kv_format): + if isinstance(paged_buffer_ptrs_tensor, torch.Tensor): + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor with a single entry → reconstruct full tensor. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + nl = int(shape_desc.nl) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + if int(engine_kv_format) == int(EngineKVFormat.NB_NL_TWO_NH_BS_HS): + shape: tuple[int, ...] 
= (nb, nl, 2, nh, bs, hs) + else: + shape = (nb, nl, 2, bs, nh, hs) + ptr = int(paged_buffer_ptrs_tensor[0].item()) + return _tensor_from_ptr(ptr, shape, dtype, device) + return paged_buffer_ptrs_tensor + raise TypeError( + "Cross-layer formats require a single torch.Tensor input; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + if _is_sglang_mha_format(engine_kv_format): + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor [K_L0,...,K_LN-1, V_L0,...,V_LN-1] → nested list. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + nl = int(shape_desc.nl) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + is_flat = int(engine_kv_format) == int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS) + per_layer_shape: tuple[int, ...] = ( + (nb * bs, nh, hs) if is_flat else (nb, bs, nh, hs) + ) + ptrs = [int(p.item()) for p in paged_buffer_ptrs_tensor] + k_tensors = [ + _tensor_from_ptr(ptrs[i], per_layer_shape, dtype, device) + for i in range(nl) + ] + v_tensors = [ + _tensor_from_ptr(ptrs[nl + i], per_layer_shape, dtype, device) + for i in range(nl) + ] + return [k_tensors, v_tensors] + if isinstance(paged_buffer_ptrs_tensor, list): + # Already nested [[K tensors], [V tensors]] + if ( + len(paged_buffer_ptrs_tensor) == 2 + and isinstance(paged_buffer_ptrs_tensor[0], list) + and all( + isinstance(t, torch.Tensor) + for group in paged_buffer_ptrs_tensor + for t in group + ) + ): + return paged_buffer_ptrs_tensor + # Flat list [K_L0, ..., K_LN-1, V_L0, ..., V_LN-1] + if all(isinstance(t, torch.Tensor) for t in paged_buffer_ptrs_tensor): + if len(paged_buffer_ptrs_tensor) % 2 != 0: + raise ValueError( + "Flat SGLang MHA list must have even length (2*NL)" + ) + half = len(paged_buffer_ptrs_tensor) // 2 + return [ + paged_buffer_ptrs_tensor[:half], + 
paged_buffer_ptrs_tensor[half:], + ] + raise TypeError( + "SGLang MHA formats require a list[list[torch.Tensor]], a flat " + "list[torch.Tensor] (2*NL entries), or a 1-D pointer tensor; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + # Per-layer formats + if _is_ptr_tensor(paged_buffer_ptrs_tensor): + # 1-D pointer tensor [ptr_L0, ..., ptr_LN-1] → list of per-layer tensors. + if shape_desc is None or device is None or dtype is None: + raise ValueError( + "_normalize_paged_layers: shape_desc, device, and dtype are " + "required when paged_buffer_ptrs_tensor is a pointer tensor" + ) + nb = int(shape_desc.nb) + bs = int(shape_desc.bs) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + per_shape = _per_layer_paged_shape(engine_kv_format, nb, bs, nh, hs) + return [ + _tensor_from_ptr(int(p.item()), per_shape, dtype, device) + for p in paged_buffer_ptrs_tensor + ] + if isinstance(paged_buffer_ptrs_tensor, list): + if not all(isinstance(t, torch.Tensor) for t in paged_buffer_ptrs_tensor): + raise TypeError( + "paged_buffer_ptrs_tensor list must contain torch.Tensor entries" + ) + return paged_buffer_ptrs_tensor + raise TypeError( + "paged_buffer_ptrs_tensor must be a list[torch.Tensor] or 1-D pointer tensor; " + "got: " + type(paged_buffer_ptrs_tensor).__name__ + ) + + +def _normalize_lmcache_objects( + lmcache_objects_ptrs: "list[int] | list[torch.Tensor]", + shape_desc: "PageBufferShapeDesc | None" = None, + lmcache_chunk_size: "int | None" = None, + engine_kv_format: "EngineKVFormat | None" = None, + dtype: "torch.dtype | None" = None, +) -> list[torch.Tensor]: + """Normalize LMCache object inputs to chunk tensors. + + Accepts either a list of chunk tensors or a ``list[int]`` of raw CPU pointers. + When a pointer list is provided *shape_desc*, *lmcache_chunk_size*, + *engine_kv_format*, and *dtype* must be supplied so the tensors can be + reconstructed via :func:`_tensor_from_ptr` on the CPU. 
+ """ + if not isinstance(lmcache_objects_ptrs, list): + raise TypeError( + "lmcache_objects_ptrs must be a list[torch.Tensor] or list[int]; " + "got: " + type(lmcache_objects_ptrs).__name__ + ) + if not lmcache_objects_ptrs: + return [] + if isinstance(lmcache_objects_ptrs[0], torch.Tensor): + return lmcache_objects_ptrs # type: ignore[return-value] + if isinstance(lmcache_objects_ptrs[0], int): + # Pointer mode: reconstruct chunk tensors (always on CPU). + if ( + shape_desc is None + or lmcache_chunk_size is None + or engine_kv_format is None + or dtype is None + ): + raise ValueError( + "_normalize_lmcache_objects: shape_desc, lmcache_chunk_size, " + "engine_kv_format, and dtype are required when lmcache_objects_ptrs " + "contains raw int pointers" + ) + nl = int(shape_desc.nl) + nh = int(shape_desc.nh) + hs = int(shape_desc.hs) + chunk_tokens = lmcache_chunk_size + if _is_mla_format(engine_kv_format): + chunk_shape: tuple[int, ...] = (nl, chunk_tokens, hs) + else: + chunk_shape = (2, nl, chunk_tokens, nh * hs) + return [ + _tensor_from_ptr(ptr, chunk_shape, dtype, "cpu") + for ptr in lmcache_objects_ptrs + ] + raise TypeError( + "lmcache_objects_ptrs must be a list[torch.Tensor] or list[int]; " + "got list containing: " + type(lmcache_objects_ptrs[0]).__name__ + ) + + +def _to_block_id_list(block_ids: torch.Tensor | list[int]) -> list[int]: + """Convert block IDs from tensor/list form into a Python ``list[int]``.""" + if isinstance(block_ids, torch.Tensor): + return [int(x) for x in block_ids.to(dtype=torch.int64).cpu().tolist()] + if isinstance(block_ids, list): + return [int(x) for x in block_ids] + raise TypeError("block_ids must be a torch.Tensor or list[int]") + + +def multi_layer_block_kv_transfer( + paged_buffer_ptrs_tensor: "torch.Tensor | list", + lmcache_objects_ptrs: list[int] | list[torch.Tensor], + block_ids: torch.Tensor | list[int], + device: torch.device | str, + direction: TransferDirection, + shape_desc: PageBufferShapeDesc, + 
lmcache_chunk_size: int, + engine_kv_format: EngineKVFormat, + skip_prefix_n_blocks: int, +) -> None: + """Python fallback implementation of block-based multi-layer KV transfer. + + Signature intentionally mirrors the C++ binding so callers can invoke + ``lmcache.c_ops.multi_layer_block_kv_transfer`` uniformly on native and + fallback backends. + + Args: + paged_buffer_ptrs_tensor: Paged buffer pointers or tensors. + lmcache_objects_ptrs: LMCache object pointers or chunk tensors. + block_ids: Ordered engine block IDs for the transfer. + device: Target device for the transfer. + direction: Transfer direction (H2D or D2H). + shape_desc: Shape descriptor of the page buffer. + lmcache_chunk_size: Chunk size of LMCache objects. + engine_kv_format: GPU KV cache format. + skip_prefix_n_blocks: Number of leading blocks to skip. + + Returns: + None + + Raises: + ValueError: If chunk size is invalid, or transfer direction is unsupported. + TypeError: If input types do not match expected types. + """ + if lmcache_chunk_size <= 0: + raise ValueError("lmcache_chunk_size must be positive") + if int(shape_desc.bs) <= 0 or lmcache_chunk_size % int(shape_desc.bs) != 0: + raise ValueError( + "lmcache_chunk_size must be a positive multiple of shape_desc.bs" + ) + if skip_prefix_n_blocks < 0: + raise ValueError("skip_prefix_n_blocks must be >= 0") + + is_d2h = int(direction) == int(TransferDirection.D2H) + is_h2d = int(direction) == int(TransferDirection.H2D) + if not (is_d2h or is_h2d): + raise ValueError(f"Unsupported transfer direction: {direction!r}") + + kv_dtype = _infer_kv_dtype( + paged_buffer_ptrs_tensor, lmcache_objects_ptrs, shape_desc + ) + normalized = _normalize_paged_layers( + paged_buffer_ptrs_tensor, + engine_kv_format, + shape_desc=shape_desc, + device=device, + dtype=kv_dtype, + ) + object_tensors = _normalize_lmcache_objects( + lmcache_objects_ptrs, + shape_desc=shape_desc, + lmcache_chunk_size=lmcache_chunk_size, + engine_kv_format=engine_kv_format, + 
dtype=kv_dtype, + ) + n_block_ids = ( + int(block_ids.numel()) + if isinstance(block_ids, torch.Tensor) + else len(block_ids) + ) + blocks_per_object = lmcache_chunk_size // int(shape_desc.bs) + block_size = int(shape_desc.bs) + + if _is_cross_layer_format(engine_kv_format): + _transfer_cross_layer( + normalized, + object_tensors, + block_ids, + n_block_ids, + blocks_per_object, + block_size, + engine_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_sglang_mha_format(engine_kv_format): + _transfer_sglang_mha( + normalized, + object_tensors, + block_ids, + n_block_ids, + blocks_per_object, + block_size, + engine_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_mla_format(engine_kv_format): + _transfer_per_layer_mla( + normalized, + object_tensors, + block_ids, + n_block_ids, + blocks_per_object, + block_size, + engine_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + elif _is_hnd_format(engine_kv_format): + _transfer_per_layer_hnd( + normalized, + object_tensors, + block_ids, + n_block_ids, + blocks_per_object, + block_size, + engine_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + else: + _transfer_per_layer_nhd( + normalized, + object_tensors, + block_ids, + n_block_ids, + blocks_per_object, + block_size, + engine_kv_format, + is_d2h, + skip_prefix_n_blocks, + ) + + +def _valid_block_range( + object_idx: int, + block_id_list: list[int], + blocks_per_object: int, + block_size: int, + skip_prefix_n_blocks: int, +) -> tuple[list[int], int] | None: + """Return valid engine block IDs and their LMCache object token offset. + + Args: + object_idx: Index of the LMCache object/chunk being processed. + block_id_list: Full ordered engine block ids for the transfer. + blocks_per_object: Number of blocks represented by one LMCache object. + block_size: Number of tokens per block. + skip_prefix_n_blocks: Number of leading flat block positions to skip. + + Returns: + ``None`` if this object has no valid blocks after skip handling. 
+ Otherwise, a tuple of valid engine block ids and the token offset + within this LMCache object where those blocks start. + """ + object_flat_start = object_idx * blocks_per_object + valid_flat_start = max(object_flat_start, skip_prefix_n_blocks) + valid_flat_end = min(object_flat_start + blocks_per_object, len(block_id_list)) + if valid_flat_start >= valid_flat_end: + return None + offset_in_object = (valid_flat_start - object_flat_start) * block_size + return block_id_list[valid_flat_start:valid_flat_end], offset_in_object + + +def _valid_block_range_indices( + object_idx: int, + n_block_ids: int, + blocks_per_object: int, + block_size: int, + skip_prefix_n_blocks: int, +) -> tuple[int, int, int] | None: + """Return valid [start, end) range over flat block IDs and object token offset.""" + object_flat_start = object_idx * blocks_per_object + valid_flat_start = max(object_flat_start, skip_prefix_n_blocks) + valid_flat_end = min(object_flat_start + blocks_per_object, n_block_ids) + if valid_flat_start >= valid_flat_end: + return None + offset_in_object = (valid_flat_start - object_flat_start) * block_size + return valid_flat_start, valid_flat_end, offset_in_object + + +def _transfer_cross_layer( + paged_tensor: torch.Tensor, + object_tensors: list[torch.Tensor], + block_ids: torch.Tensor | list[int], + n_block_ids: int, + blocks_per_object: int, + block_size: int, + engine_kv_format: EngineKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle cross-layer formats: single tensor [NB, NL, 2, ...].""" + # NHD: [NB, NL, 2, BS, NH, HS] HND: [NB, NL, 2, NH, BS, HS] + is_hnd = int(engine_kv_format) == int(EngineKVFormat.NB_NL_TWO_NH_BS_HS) + num_layers = paged_tensor.shape[1] + + if is_hnd: + # [NB, NL, 2, NH, BS, HS] + nh = paged_tensor.shape[3] + hs = paged_tensor.shape[5] + else: + # [NB, NL, 2, BS, NH, HS] + nh = paged_tensor.shape[4] + hs = paged_tensor.shape[5] + + # H2D: pre-transfer objects to paged device + if not is_d2h and 
object_tensors: + objs_on_device = [obj.to(paged_tensor.device) for obj in object_tensors] + block_ids_dev = torch.as_tensor( + block_ids, dtype=torch.long, device=paged_tensor.device + ) + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + + if is_d2h: + selected = paged_tensor.index_select(0, eff_idx) + + for layer_idx in range(num_layers): + for kv in range(2): + if is_d2h: + slice_t = selected[:, layer_idx, kv] + if is_hnd: + # N=n_valid, BS=block_size: + # [N, NH, BS, HS] -> [N, BS, NH, HS] -> [N*BS, NH*HS] + flat = slice_t.permute(0, 2, 1, 3).reshape( + n_valid * block_size, nh * hs + ) + else: + # [N, BS, NH, HS] → [N*BS, NH*HS] + flat = slice_t.reshape(n_valid * block_size, nh * hs) + obj[kv, layer_idx, offset_in_object:token_end].copy_( + flat, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + src = obj_device[kv, layer_idx, offset_in_object:token_end] + if is_hnd: + # N=n_valid, BS=block_size: + # [N*BS, NH*HS] -> [N, BS, NH, HS] -> [N, NH, BS, HS] + src_blocks = src.reshape(n_valid, block_size, nh, hs).permute( + 0, 2, 1, 3 + ) + else: + # N=n_valid, BS=block_size: + # [N*BS, NH*HS] -> [N, BS, NH, HS] + src_blocks = src.reshape(n_valid, block_size, nh, hs) + paged_tensor[:, layer_idx, kv].index_copy_(0, eff_idx, src_blocks) + + +def _transfer_sglang_mha( + paged_tensors: list[list[torch.Tensor]], + object_tensors: list[torch.Tensor], + block_ids: torch.Tensor | list[int], + n_block_ids: int, + blocks_per_object: int, + block_size: int, + engine_kv_format: EngineKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle SGLang MHA formats: 2*NL tensors (list[list[Tensor]]).""" 
+ # TWO_X_NL_X_NBBS_NH_HS: each tensor [NB*BS, NH, HS] + # TWO_X_NL_X_NB_BS_NH_HS: each tensor [NB, BS, NH, HS] + is_flat = int(engine_kv_format) == int(EngineKVFormat.TWO_X_NL_X_NBBS_NH_HS) + num_layers = len(paged_tensors[0]) + + # Determine target device from first tensor + target_device = paged_tensors[0][0].device + + # H2D: pre-transfer objects + if not is_d2h and object_tensors: + objs_on_device = [obj.to(target_device) for obj in object_tensors] + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + if is_flat: + # Flat token positions for all valid blocks: + # block_id * block_size + token offset. Reused across layer/KV pairs. 
+ token_indices = ( + eff_idx[:, None] * block_size + + torch.arange(block_size, dtype=torch.long, device=target_device) + ).reshape(-1) + + for layer_idx in range(num_layers): + for kv in range(2): + layer_t = paged_tensors[kv][layer_idx] + nh = layer_t.shape[-2] + hs = layer_t.shape[-1] + + if is_d2h: + if is_flat: + # [NB*BS, NH, HS] + gathered = layer_t.index_select(0, token_indices) + else: + # [NB, BS, NH, HS] + gathered = layer_t.index_select(0, eff_idx).reshape( + n_valid * block_size, nh, hs + ) + flat = gathered.reshape(n_valid * block_size, nh * hs) + obj[kv, layer_idx, offset_in_object:token_end].copy_( + flat, non_blocking=True + ) + else: + obj_device = objs_on_device[object_idx] + src = obj_device[kv, layer_idx, offset_in_object:token_end] + src_shaped = src.reshape(n_valid * block_size, nh, hs) + if is_flat: + # scatter into [NB*BS, NH, HS] + layer_t.index_copy_(0, token_indices, src_shaped) + else: + # N=n_valid, BS=block_size: + # [N*BS, NH, HS] -> [N, BS, NH, HS] + src_blocks = src_shaped.reshape(n_valid, block_size, nh, hs) + layer_t.index_copy_(0, eff_idx, src_blocks) + + +def _transfer_per_layer_mla( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_ids: torch.Tensor | list[int], + n_block_ids: int, + blocks_per_object: int, + block_size: int, + engine_kv_format: EngineKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle MLA per-layer formats: [NB, BS, HS].""" + if not layer_tensors or not object_tensors: + return + + is_flat = int(engine_kv_format) == int(EngineKVFormat.NL_X_NBBS_ONE_HS) + target_device = layer_tensors[0].device + if is_flat: + token_offsets = torch.arange(block_size, dtype=torch.long, device=target_device) + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, 
+ ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + if is_flat: + token_indices = ( + eff_idx[:, None] * block_size + token_offsets[None, :] + ).reshape(-1) + + if is_d2h: + hidden_size = layer_tensors[0].shape[-1] + chunk_gpu = torch.empty( + len(layer_tensors), + n_valid * block_size, + hidden_size, + dtype=layer_tensors[0].dtype, + device=target_device, + ) + for layer_idx, layer in enumerate(layer_tensors): + if is_flat: + dst = chunk_gpu[layer_idx].view( + n_valid * block_size, 1, hidden_size + ) + torch.index_select(layer, 0, token_indices, out=dst) + else: + dst = chunk_gpu[layer_idx].view(n_valid, block_size, hidden_size) + torch.index_select(layer, 0, eff_idx, out=dst) + obj[:, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, offset_in_object:token_end].to( + target_device, non_blocking=True + ) + for layer_idx, layer in enumerate(layer_tensors): + src = chunk_gpu[layer_idx] + hidden_size = layer.shape[-1] + if is_flat: + src_tokens = src.reshape(n_valid * block_size, 1, hidden_size) + layer.index_copy_(0, token_indices, src_tokens) + else: + src_blocks = src.reshape(n_valid, block_size, hidden_size) + layer.index_copy_(0, eff_idx, src_blocks) + + +def _transfer_per_layer_hnd( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_ids: torch.Tensor | list[int], + n_block_ids: int, + blocks_per_object: int, + block_size: int, + engine_kv_format: EngineKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle per-layer HND formats: heads before block tokens.""" + if not layer_tensors or not object_tensors: + return + + target_device = layer_tensors[0].device + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + first_layer = layer_tensors[0] + if int(engine_kv_format) == 
int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + first_k = first_layer[0] + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + first_k = first_layer[:, :, :, 0] + else: + first_k = first_layer[:, 0] + _nb0, nh0, _bs0, hs0 = first_k.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + + if is_d2h: + chunk_gpu = torch.empty( + 2, + len(layer_tensors), + n_valid * block_size, + nh0 * hs0, + dtype=first_k.dtype, + device=target_device, + ) + scratch = torch.empty( + n_valid, + nh0, + block_size, + hs0, + dtype=first_k.dtype, + device=target_device, + ) + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + k_t, v_t = layer[0], layer[1] + torch.index_select(k_t, 0, eff_idx, out=scratch) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + torch.index_select(v_t, 0, eff_idx, out=scratch) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + k_t, v_t = layer[:, :, :, 0], layer[:, :, :, 1] + torch.index_select(k_t, 0, eff_idx, out=scratch) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + torch.index_select(v_t, 0, eff_idx, out=scratch) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + scratch.permute(0, 2, 1, 3) + ) + else: + # FlashInfer HND stores KV as [NB, 2, NH, BS, HS]. 
+ # Gather on dim=0 first so reads stay contiguous in memory; + # index_select on layer[:, 0]/layer[:, 1] non-contiguous views + # triggers slower element-wise gather reads. + selected = layer.index_select(0, eff_idx) + chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + selected[:, 0].permute(0, 2, 1, 3) + ) + chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0).copy_( + selected[:, 1].permute(0, 2, 1, 3) + ) + obj[:, :, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, :, offset_in_object:token_end].to( + target_device, non_blocking=True + ) + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_NH_BS_HS): + k_t, v_t = layer[0], layer[1] + elif int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_NH_BS_TWO_HS): + k_t, v_t = layer[:, :, :, 0], layer[:, :, :, 1] + else: + k_t, v_t = layer[:, 0], layer[:, 1] + _nb, nh, _bs, hs = k_t.shape + k_blocks = ( + chunk_gpu[0, layer_idx] + .reshape(n_valid, block_size, nh, hs) + .permute(0, 2, 1, 3) + ) + v_blocks = ( + chunk_gpu[1, layer_idx] + .reshape(n_valid, block_size, nh, hs) + .permute(0, 2, 1, 3) + ) + if int(engine_kv_format) == int(EngineKVFormat.NL_X_NB_TWO_NH_BS_HS): + layer.index_copy_( + 0, eff_idx, torch.stack([k_blocks, v_blocks], dim=1) + ) + else: + k_t.index_copy_(0, eff_idx, k_blocks) + v_t.index_copy_(0, eff_idx, v_blocks) + + +def _transfer_per_layer_nhd( + layer_tensors: list[torch.Tensor], + object_tensors: list[torch.Tensor], + block_ids: torch.Tensor | list[int], + n_block_ids: int, + blocks_per_object: int, + block_size: int, + engine_kv_format: EngineKVFormat, + is_d2h: bool, + skip_prefix_n_blocks: int, +) -> None: + """Handle per-layer NHD formats: block tokens before heads.""" + if not layer_tensors or not object_tensors: + return + + target_device = layer_tensors[0].device + block_ids_dev = torch.as_tensor(block_ids, dtype=torch.long, device=target_device) + + 
first_layer = layer_tensors[0] + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + first_k = first_layer[0] + else: + first_k = first_layer[:, 0] + _nb0, _bs0, nh0, hs0 = first_k.shape + + for object_idx, obj in enumerate(object_tensors): + valid = _valid_block_range_indices( + object_idx, + n_block_ids, + blocks_per_object, + block_size, + skip_prefix_n_blocks, + ) + if valid is None: + continue + idx_start, idx_end, offset_in_object = valid + n_valid = idx_end - idx_start + token_end = offset_in_object + n_valid * block_size + eff_idx = block_ids_dev[idx_start:idx_end] + + if is_d2h: + chunk_gpu = torch.empty( + 2, + len(layer_tensors), + n_valid * block_size, + nh0 * hs0, + dtype=first_k.dtype, + device=target_device, + ) + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + k_t, v_t = layer[0], layer[1] + torch.index_select( + k_t, + 0, + eff_idx, + out=chunk_gpu[0, layer_idx].view(n_valid, block_size, nh0, hs0), + ) + torch.index_select( + v_t, + 0, + eff_idx, + out=chunk_gpu[1, layer_idx].view(n_valid, block_size, nh0, hs0), + ) + else: + # FlashInfer NHD stores KV as [NB, 2, BS, NH, HS]. + # Gather on dim=0 first to avoid index_select from + # non-contiguous layer[:, 0]/layer[:, 1] views, which + # trigger slower element-wise gather reads. 
+ selected = layer.index_select(0, eff_idx) + chunk_gpu[0, layer_idx].copy_( + selected[:, 0].reshape(n_valid * block_size, nh0 * hs0) + ) + chunk_gpu[1, layer_idx].copy_( + selected[:, 1].reshape(n_valid * block_size, nh0 * hs0) + ) + obj[:, :, offset_in_object:token_end].copy_(chunk_gpu, non_blocking=True) + else: + chunk_gpu = obj[:, :, offset_in_object:token_end].to( + target_device, non_blocking=True + ) + for layer_idx, layer in enumerate(layer_tensors): + if int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS): + k_t, v_t = layer[0], layer[1] + k_t.index_copy_( + 0, + eff_idx, + chunk_gpu[0, layer_idx].reshape(n_valid, block_size, nh0, hs0), + ) + v_t.index_copy_( + 0, + eff_idx, + chunk_gpu[1, layer_idx].reshape(n_valid, block_size, nh0, hs0), + ) + else: + k_blocks = chunk_gpu[0, layer_idx].reshape( + n_valid, block_size, nh0, hs0 + ) + v_blocks = chunk_gpu[1, layer_idx].reshape( + n_valid, block_size, nh0, hs0 + ) + layer.index_copy_( + 0, eff_idx, torch.stack([k_blocks, v_blocks], dim=1) + ) + + def single_layer_kv_transfer( lmc_key_value_cache: torch.Tensor, vllm_key_value_cache: torch.Tensor, slot_mapping: torch.Tensor, direction: TransferDirection, - gpu_kv_format: GPUKVFormat, + engine_kv_format: EngineKVFormat, token_major: bool = False, ): """ @@ -791,9 +1786,9 @@ def single_layer_kv_transfer( valid_token_indices = torch.nonzero(valid_mask_kv, as_tuple=True)[0] valid_slots = slots_kv[valid_mask_kv].to(paged_memory_device) - is_mla = int(gpu_kv_format) in ( - int(GPUKVFormat.NL_X_NB_BS_HS), - int(GPUKVFormat.NL_X_NBBS_ONE_HS), + is_mla = int(engine_kv_format) in ( + int(EngineKVFormat.NL_X_NB_BS_HS), + int(EngineKVFormat.NL_X_NBBS_ONE_HS), ) if is_mla: @@ -818,7 +1813,7 @@ def single_layer_kv_transfer( else: # ── Non-MLA format ── # Determine vLLM layout and block_size - is_two_major = int(gpu_kv_format) == int(GPUKVFormat.NL_X_TWO_NB_BS_NH_HS) + is_two_major = int(engine_kv_format) == int(EngineKVFormat.NL_X_TWO_NB_BS_NH_HS) # flash 
attn: # [2, num_blocks, block_size, num_heads, head_size] # -> dim2 = block_size diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 2ec9c25148..c9f8152edf 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -338,12 +338,14 @@ def store( """ st = time.perf_counter() obj_keys = self._ctx.resolve_obj_keys(key) + _t_resolve = time.perf_counter() entry = self._gpu_contexts.get(instance_id) if entry is None: raise ValueError(f"No GPU context registered for instance ID {instance_id}") gpu_context = entry.gpu_context model_name = entry.model_name + _num_groups = gpu_context.kv_layer_groups_manager.num_groups # NOTE: different engine groups may have different block sizes, so # ``blocks_per_chunk[i]`` is the number of blocks in one chunk for @@ -363,6 +365,7 @@ def store( block_ids_per_group_gpu = gpu_context.copy_view_block_ids_to_gpu( gpu_block_ids ) + _t_copy_ids = time.perf_counter() # Fail closed: every LMCache group must have block IDs covering all # chunks. A short list (e.g. a caller/protocol bug) would otherwise @@ -397,6 +400,7 @@ def store( gpu_context.device, event_ipc_handle ) vllm_event.wait(stream=gpu_context.stream) + _t_event_wait = time.perf_counter() # CPU-synchronous sentinel: a GPU store is about to be enqueued. 
# Must be published via publish() (not publish_on_stream) so the @@ -421,14 +425,22 @@ def store( }, ), ) + _t_publish = time.perf_counter() reserved_dict: dict[ObjectKey, MemoryObj] = {} store_succeeded = False + _t_reserve = _t_publish + _t_loop_end = _t_publish + _t_record_start = _t_publish + _t_record_end = _t_publish + _t_callback_end = _t_publish + use_c_ops = True try: layout_desc = get_layout_desc(gpu_context, self._ctx.chunk_size) reserved_dict = self._ctx.storage_manager.reserve_write( obj_keys, layout_desc, "new" ) + _t_reserve = time.perf_counter() # NOTE: Store is not batched because some obj_keys may be # skipped (not in reserved_dict), making block_ids @@ -441,6 +453,7 @@ def store( else: continue + _t_chunk_start = time.perf_counter() # Copy from GPU paged buffer to tmp buffer, then to CPU — per # group. Each group uses its own block-id list (HMA). for group_idx in range(num_groups): @@ -448,35 +461,77 @@ def store( chunk_block_ids_gpu = block_ids_per_group_gpu[group_idx][ idx * bpc : (idx + 1) * bpc ] - tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx) - group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx) + if use_c_ops: + tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx) + group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx) # Kernel contract: ``group_lmcache_chunk_size`` here is the # number of *physical* slots per chunk for this group # (= logical chunk_size // compress_ratio). 
group_lmcache_chunk_size = gpu_context.get_physical_chunk_size( group_idx ) - lmc_ops.multi_layer_block_kv_transfer( - group_kv_pointers, - [tmp_buffer.data_ptr()], - chunk_block_ids_gpu, - gpu_context.device, - lmc_ops.TransferDirection.D2H, - gpu_context.get_shape_desc(group_idx), - group_lmcache_chunk_size, - gpu_context.gpu_kv_format_, - 0, + if use_c_ops: + lmc_ops.multi_layer_block_kv_transfer( + group_kv_pointers, + [tmp_buffer.data_ptr()], + chunk_block_ids_gpu, + gpu_context.device, + lmc_ops.TransferDirection.D2H, + gpu_context.get_shape_desc(group_idx), + group_lmcache_chunk_size, + gpu_context.gpu_kv_format_, + 0, + ) + else: + group = gpu_context.kv_layer_groups_manager.kv_layer_groups[group_idx] + kv_tensors = [gpu_context.kv_tensors[i] for i in group.layer_indices] + + lmc_ops.multi_layer_block_kv_transfer( + kv_tensors, + [memory_obj.tensor], + chunk_block_ids_gpu, + gpu_context.device, + lmc_ops.TransferDirection.D2H, + gpu_context.get_shape_desc(group_idx), + group_lmcache_chunk_size, + gpu_context.gpu_kv_format_, + 0, + ) + + _t_kernel_end = time.perf_counter() + if use_c_ops: + # Store is not batched, so we always use chunk_idx=0 (single slot) + lmcache_memcpy_async_d2h( + gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj ) - # Store is not batched, so we always use chunk_idx=0 (single slot) - lmcache_memcpy_async_d2h( - gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj + _t_memcpy_end = time.perf_counter() + logger.info( + "[GPU-STORE-CHUNK] req=%s chunk_idx=%d kernel=%.3f memcpy_d2h=%.3f ms", + key.request_id, + idx, + (_t_kernel_end - _t_chunk_start) * 1000, + (_t_memcpy_end - _t_kernel_end) * 1000, ) store_succeeded = True + _t_loop_end = time.perf_counter() except Exception: logger.exception("Cannot store keys due to exception") return event.ipc_handle(), False finally: + _t_record_start = time.perf_counter() event.record() + _t_record_end = time.perf_counter() + + _t_sync_start = time.perf_counter() + # hlin99: debug 
mode + # event.synchronize() # 等 GPU 上所有 kernel + D2H 真正完成 + _t_sync_end = time.perf_counter() + logger.info( + "[GPU-STORE-SYNC] req=%s gpu_sync=%.3f total_with_sync=%.3f ms", + str(key.request_id), + (_t_sync_end - _t_sync_start) * 1000, + (_t_sync_end - st) * 1000, + ) # Fail closed: commit the reserved objects only when every chunk # copied successfully; otherwise the whole store is skipped. stored_count = len(reserved_dict) if store_succeeded else 0 @@ -486,6 +541,7 @@ def store( "finish_write", list(reserved_dict.keys()), ) + _t_callback_end = time.perf_counter() # All reserved MemoryObjs share one layout_desc, so per-object # size is identical — avoid summing N identical values. total_bytes = ( @@ -509,6 +565,24 @@ def store( ) ed = time.perf_counter() + logger.info( + "[GPU-STORE] req=%s resolve_keys=%.3f copy_block_ids=%.3f " + "event_ipc_wait=%.3f event_publish=%.3f reserve_write=%.3f " + "kernel_loop=%.3f event_record=%.3f submit_cb=%.3f total=%.3f ms " + "(num_chunks=%d, num_groups=%d)", + key.request_id, + (_t_resolve - st) * 1000, + (_t_copy_ids - _t_resolve) * 1000, + (_t_event_wait - _t_copy_ids) * 1000, + (_t_publish - _t_event_wait) * 1000, + (_t_reserve - _t_publish) * 1000, + (_t_loop_end - _t_reserve) * 1000, + (_t_record_end - _t_record_start) * 1000, + (_t_callback_end - _t_record_end) * 1000, + (ed - st) * 1000, + len(obj_keys), + _num_groups, + ) if length := len(reserved_dict): logger.info( "Stored %d tokens in %.3f seconds", diff --git a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py index 0f15026176..607201d4f6 100644 --- a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py @@ -296,6 +296,7 @@ def prepare_store( Returns: PrepareStoreResponse with empty slots for pickle mode. 
""" + t_start = time.perf_counter() entry = self._non_gpu_contexts.get(instance_id) if entry is None: raise ValueError( @@ -306,14 +307,25 @@ def prepare_store( raise ValueError( f"transfer strategy not registered for instance ID {instance_id}" ) + t_resolve = time.perf_counter() response = strategy.prepare_store( key=key, instance_id=instance_id, context=entry.metadata, resolve_obj_keys=self._ctx.resolve_obj_keys, ) + t_prepare = time.perf_counter() session = self._ctx.session_manager.get_or_create(key.request_id) session.extras["store_start_time"] = time.perf_counter() + logger.info( + "[SRV-PREPARE-STORE] req=%s resolve_keys=%.3f prepare=%.3f" + " total=%.3f ms (strategy=%s)", + key.request_id, + (t_resolve - t_start) * 1000, + (t_prepare - t_resolve) * 1000, + (t_prepare - t_start) * 1000, + strategy.strategy_name, + ) return response @_lmcache_nvtx_annotate @@ -349,6 +361,7 @@ def commit_store( ) session = self._ctx.session_manager.get_or_create(key.request_id) st = session.extras.pop("store_start_time", None) + t_commit_start = time.perf_counter() result = strategy.commit_store( key=key, instance_id=instance_id, @@ -356,6 +369,7 @@ def commit_store( context=entry.metadata, resolve_obj_keys=self._ctx.resolve_obj_keys, ) + t_commit_end = time.perf_counter() if st is not None and result: num_tokens = len(self._ctx.resolve_obj_keys(key)) * self._ctx.chunk_size logger.info( @@ -363,6 +377,15 @@ def commit_store( num_tokens, time.perf_counter() - st, ) + logger.info( + "[SRV-COMMIT-STORE] req=%s commit=%.3f total_since_prepare=%.3f ms" + " (strategy=%s, num_tokens=%d)", + key.request_id, + (t_commit_end - t_commit_start) * 1000, + (t_commit_end - st) * 1000, + strategy.strategy_name, + num_tokens, + ) return result @_lmcache_nvtx_annotate diff --git a/lmcache/v1/multiprocess/modules/server_transfer.py b/lmcache/v1/multiprocess/modules/server_transfer.py index fc967f79db..e3e178b6fa 100644 --- a/lmcache/v1/multiprocess/modules/server_transfer.py +++ 
b/lmcache/v1/multiprocess/modules/server_transfer.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import abc import pickle +import time # Third Party import torch @@ -85,6 +86,15 @@ class TransferStrategy(abc.ABC): shared-memory-based transfers behind a common interface. """ + @property + @abc.abstractmethod + def strategy_name(self) -> str: + """Return a short human-readable name identifying this strategy. + + Returns: + A lowercase string label such as ``"pickle"`` or ``"shm"``. + """ + @abc.abstractmethod def prepare_store( self, @@ -182,6 +192,11 @@ def __init__( """ self._storage_manager = storage_manager + @property + def strategy_name(self) -> str: + """Return ``"pickle"`` as the strategy identifier.""" + return "pickle" + def prepare_store( self, key: IPCCacheEngineKey, @@ -208,11 +223,14 @@ def commit_store( Returns: ``True`` when every reserved object is written successfully. """ + t_start = time.perf_counter() obj_keys = resolve_obj_keys(key) chunks: list[torch.Tensor] = pickle.loads(cpu_data) + t_deserialize = time.perf_counter() reserved_dict = self._storage_manager.reserve_write( obj_keys, context.layout_desc, "new" ) + t_reserve_write = time.perf_counter() written_keys: list[ObjectKey] = [] try: for idx, obj_key in enumerate(obj_keys): @@ -229,9 +247,21 @@ def commit_store( memory_obj.tensor.copy_(chunk_cpu) written_keys.append(obj_key) finally: + t_copy_loop = time.perf_counter() if written_keys: self._storage_manager.finish_write(written_keys) - + t_finish_write = time.perf_counter() + logger.info( + "[PICKLE-COMMIT] req=%s deserialize=%.3f reserve_write=%.3f" + " copy_loop=%.3f finish_write=%.3f total=%.3f ms (num_chunks=%d)", + key.request_id, + (t_deserialize - t_start) * 1000, + (t_reserve_write - t_deserialize) * 1000, + (t_copy_loop - t_reserve_write) * 1000, + (t_finish_write - t_copy_loop) * 1000, + (t_finish_write - t_start) * 1000, + len(chunks), + ) return len(written_keys) == len(reserved_dict) def prepare_retrieve( @@ -309,6 
+339,11 @@ def __init__( self._transfer_key_factory = transfer_key_factory self._fallback_strategy = fallback_strategy + @property + def strategy_name(self) -> str: + """Return ``"shm"`` as the strategy identifier.""" + return "shm" + def prepare_store( self, key: IPCCacheEngineKey, @@ -321,10 +356,13 @@ def prepare_store( Returns: Context with ``slots`` and ``chunk_indices``. """ + t_start = time.perf_counter() obj_keys = resolve_obj_keys(key) + t_resolve = time.perf_counter() reserved = self._storage_manager.reserve_write( obj_keys, context.layout_desc, "new" ) + t_reserve_write = time.perf_counter() slots: list[dict[str, Any]] = [] chunk_indices: list[int] = [] reserved_keys: list[ObjectKey] = [] @@ -350,6 +388,17 @@ def prepare_store( ] if unused_keys: self._storage_manager.finish_write(unused_keys) + t_slots = time.perf_counter() + logger.info( + "[SHM-PREPARE] req=%s resolve_keys=%.3f reserve_write=%.3f" + " slots=%.3f total=%.3f ms (num_slots=%d)", + key.request_id, + (t_resolve - t_start) * 1000, + (t_reserve_write - t_resolve) * 1000, + (t_slots - t_reserve_write) * 1000, + (t_slots - t_start) * 1000, + len(reserved_keys), + ) if not reserved_keys: return PrepareStoreResponse(context={"slots": [], "chunk_indices": []}) transfer_key = self._transfer_key_factory(key, instance_id) @@ -380,13 +429,23 @@ def commit_store( context=context, resolve_obj_keys=resolve_obj_keys, ) + t_start = time.perf_counter() transfer_key = self._transfer_key_factory(key, instance_id) with self._pending_lock: reserved_keys = self._pending_writes.pop(transfer_key, None) if reserved_keys is None: return False + t_before_fw = time.perf_counter() if reserved_keys: self._storage_manager.finish_write(reserved_keys) + t_finish_write = time.perf_counter() + logger.info( + "[SHM-COMMIT] req=%s finish_write=%.3f total=%.3f ms (num_keys=%d)", + key.request_id, + (t_finish_write - t_before_fw) * 1000, + (t_finish_write - t_start) * 1000, + len(reserved_keys), + ) return True def 
prepare_retrieve( diff --git a/lmcache/v1/multiprocess/transfer_context/__init__.py b/lmcache/v1/multiprocess/transfer_context/__init__.py index 6551664df5..833ac32748 100644 --- a/lmcache/v1/multiprocess/transfer_context/__init__.py +++ b/lmcache/v1/multiprocess/transfer_context/__init__.py @@ -7,6 +7,7 @@ """ # Local +from .async_data import AsyncDataTransferContext from .base import ( NonGpuContext, NonGpuContextMetadata, @@ -26,6 +27,7 @@ ) __all__ = [ + "AsyncDataTransferContext", "DataTransferContext", "HandleTransferContext", "MPTransferMode", diff --git a/lmcache/v1/multiprocess/transfer_context/async_data.py b/lmcache/v1/multiprocess/transfer_context/async_data.py new file mode 100644 index 0000000000..0f35811da5 --- /dev/null +++ b/lmcache/v1/multiprocess/transfer_context/async_data.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Async non-GPU data transfer context for multiprocess worker adapters.""" + +# Standard +from concurrent.futures import Future as ConcurrentFuture +from concurrent.futures import ThreadPoolExecutor +from typing import Any +import threading +import time + +# Third Party +import torch + +# First Party +from lmcache import torch_dev +from lmcache.utils import init_logger +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.transfer_context.base import gather_paged_kv_to_cpu +from lmcache.v1.multiprocess.transfer_context.worker_transfer import ( + DataTransferContext, + IPCEvent, + _single_group_block_ids, +) + +logger = init_logger(__name__) + +DEFAULT_MAX_ASYNC_NON_GPU_STORES = 8 +# Number of background threads used to run commit (CPU->server) work for the +# async non-GPU store path. >1 so that a slow gather for one store does not +# block the commit of another store whose gather already finished. +DEFAULT_NON_GPU_COMMIT_WORKERS = 4 + + +class AsyncDataTransferContext(DataTransferContext): + """Fully async non-GPU data transfer context (store-only async). 
+ + "Store-only async" means ``submit_store`` returns an *unresolved* future + that resolves only after the deferred gather (GPU->CPU copy) and commit + (CPU->server) both complete off the forward thread, while + ``submit_retrieve`` stays synchronous and returns an already-resolved + future exactly as on the base context. + + Inherits :class:`DataTransferContext` and reuses its ``register()`` (layout + / SHM registration, no stream dependency) and ``submit_retrieve()`` (this + path does not change retrieve). Only the store is made async. + + Store is two-phase, both executed entirely in a background thread: + 1) gather: wait for the forward event on the copy stream, then enqueue + GPU->CPU copies. When SHM buffers are available, gather writes directly + into SHM views (matching the synchronous path). Otherwise, gather + targets pinned staging buffers. + 2) commit: wait for gather completion (via a recorded CUDA event), then + perform commit_store() and resolve the returned future. + + ``submit_store`` only performs lightweight preparation (prepare_store, + buffer allocation) on the forward thread and immediately submits all + GPU/copy work to the background ``commit_executor``, so the forward thread + is never blocked by gather kernel launch latency. + + This class is only instantiated by the factory when the device is + async-capable, so the constructor creates async resources unconditionally; + there is no ``self._async_capable`` flag. + """ + + def __init__( + self, + max_inflight_stores: int = DEFAULT_MAX_ASYNC_NON_GPU_STORES, + commit_workers: int = DEFAULT_NON_GPU_COMMIT_WORKERS, + ) -> None: + """Initialize the async context and create its async resources. + + Args: + max_inflight_stores: Max number of concurrently in-flight async + stores before ``submit_store`` blocks (backpressure). + commit_workers: Number of background threads used to run commit + (CPU->server) work. 
>1 so a slow gather for one store does not + block the commit of another whose gather is already done. + """ + super().__init__() + self._max_inflight_stores = max(1, int(max_inflight_stores)) + self._commit_workers = max(1, int(commit_workers)) + self._copy_stream: Any = torch_dev.Stream() + self._commit_executor: ThreadPoolExecutor = ThreadPoolExecutor( + max_workers=self._commit_workers, + thread_name_prefix="lmcache_non_gpu_commit", + ) + self._inflight_lock = threading.Lock() + self._inflight_gather_events: set[Any] = set() + self._inflight_commits: set[ConcurrentFuture[None]] = set() + self._staging_pool: dict[ + tuple[tuple[int, ...], torch.dtype], list[torch.Tensor] + ] = {} + self._is_closing = False + + def _alloc_pinned_staging( + self, shape: torch.Size, dtype: torch.dtype, count: int + ) -> list[torch.Tensor]: + key = (tuple(shape), dtype) + with self._inflight_lock: + pooled = self._staging_pool.setdefault(key, []) + staged = [pooled.pop() for _ in range(min(len(pooled), count))] + if len(staged) == count: + return staged + + missing = count - len(staged) + for _ in range(missing): + try: + staged.append( + torch.empty(shape, dtype=dtype, device="cpu", pin_memory=True) + ) + except RuntimeError: + # Graceful fallback for CPU-only / pin-memory-disabled setups. 
+ logger.warning( + "Falling back to non-pinned CPU staging buffer " + "(shape=%s, dtype=%s)", + tuple(shape), + dtype, + ) + staged.append(torch.empty(shape, dtype=dtype, device="cpu")) + return staged + + def _release_staging(self, chunks: list[torch.Tensor]) -> None: + if not chunks: + return + key = (tuple(chunks[0].shape), chunks[0].dtype) + with self._inflight_lock: + self._staging_pool.setdefault(key, []).extend(chunks) + + def submit_store( + self, + _request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[list[int]], + _event: IPCEvent, + blocks_in_chunk: int, + ) -> MessagingFuture: + """Two-phase async store (gather and commit both in background thread). + + Performs lightweight preparation (prepare_store, buffer allocation) on + the forward thread and immediately submits the gather + commit work to + the background ``commit_executor``. Returns an unresolved future that + resolves only after both gather completion and the commit ACK. + """ + _t_entry = time.perf_counter() + if self._non_gpu_context is None: + raise RuntimeError( + "Data transfer context is not registered. " + "Call register() before submit_store()." + ) + completion: MessagingFuture[bool] = MessagingFuture() + non_gpu_context = self._non_gpu_context + commit_executor = self._commit_executor + + staged_chunks: list[torch.Tensor] = [] + # Whether we gathered directly into SHM views (True) or into + # pinned staging buffers that need to be released later (False). 
+ used_shm_direct = False + try: + _t0 = time.perf_counter() + with self._inflight_lock: + if self._is_closing: + completion.set_result(False) + return completion + _t_lock = time.perf_counter() + + result = non_gpu_context.prepare_store(key, instance_id) + _t_prepare = time.perf_counter() + + out_buffers, chunk_indices = result if result is not None else (None, None) + if chunk_indices is not None and len(chunk_indices) == 0: + # All chunks are already in cache: no gather, no commit. + completion.set_result(True) + return completion + + full_block_ids = _single_group_block_ids(block_ids) + _t_block_ids = time.perf_counter() + + num_chunks = ( + len(chunk_indices) + if chunk_indices is not None + else len(full_block_ids) // blocks_in_chunk + ) + + # Determine gather target: + # - SHM path (out_buffers available): gather directly into SHM views + # - Pickle path (no out_buffers): gather into pinned staging buffers + if out_buffers is not None: + # SHM path: gather directly into SHM views, no staging needed. + gather_target = out_buffers + used_shm_direct = True + else: + # Pickle path: allocate pinned staging buffers. 
+ if not non_gpu_context.layout_desc.shapes: + raise RuntimeError("non-GPU layout_desc.shapes is empty") + if not non_gpu_context.layout_desc.dtypes: + raise RuntimeError("non-GPU layout_desc.dtypes is empty") + staged_chunks = self._alloc_pinned_staging( + non_gpu_context.layout_desc.shapes[0], + non_gpu_context.layout_desc.dtypes[0], + num_chunks, + ) + gather_target = staged_chunks + _t_alloc = time.perf_counter() + + # Capture variables for the closure + _used_shm_direct = used_shm_direct + _gather_target = gather_target + _t_submit_start = time.perf_counter() + + def _commit_after_gather() -> None: + _tb_entry = time.perf_counter() + gather_done: Any | None = None + ok = False + try: + _tb0 = time.perf_counter() + with torch.inference_mode(), torch_dev.stream(self._copy_stream): + _tb_stream_enter = time.perf_counter() + + _event.wait(stream=self._copy_stream) + _tb_event_wait = time.perf_counter() + + gather_paged_kv_to_cpu( + kv_caches, + full_block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=_gather_target, + chunk_indices=chunk_indices, + ) + _tb_gather = time.perf_counter() + + gather_done = torch_dev.Event() + gather_done.record(self._copy_stream) + _tb_record = time.perf_counter() + + _tb_stream_exit = time.perf_counter() + + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.add(gather_done) + _tb_lock = time.perf_counter() + + if gather_done is not None: + gather_done.synchronize() + _tb_sync = time.perf_counter() + + ok = non_gpu_context.commit_store(key, instance_id, _gather_target) + _tb_commit = time.perf_counter() + + if not ok: + logger.error( + "Async non-GPU commit_store failed for request_id=%s", + _request_id, + ) + + logger.info( + "[BG %s] thread_start=%.3f stream_enter=%.3f " + "event_wait=%.3f gather_launch=%.3f record=%.3f " + "stream_exit=%.3f lock=%.3f sync=%.3f commit=%.3f " + "total=%.3f ms", + _request_id, + (_tb0 - _tb_entry) * 1000, + 
(_tb_stream_enter - _tb0) * 1000, + (_tb_event_wait - _tb_stream_enter) * 1000, + (_tb_gather - _tb_event_wait) * 1000, + (_tb_record - _tb_gather) * 1000, + (_tb_stream_exit - _tb_record) * 1000, + (_tb_lock - _tb_stream_exit) * 1000, + (_tb_sync - _tb_lock) * 1000, + (_tb_commit - _tb_sync) * 1000, + (_tb_commit - _tb_entry) * 1000, + ) + except Exception: + logger.exception( + "Async non-GPU store failed for request_id=%s", + _request_id, + ) + ok = False + finally: + if not _used_shm_direct: + self._release_staging(staged_chunks) + with self._inflight_lock: + if gather_done is not None: + self._inflight_gather_events.discard(gather_done) + completion.set_result(ok) + + # Submitting the commit task is the ownership-transfer point: once it + # succeeds, the commit task is solely responsible for releasing the + # staging buffers, and resolving the future. The except below therefore + # only handles failures that occur *before* this submit, so it can never + # double-release or double-resolve. 
+ commit_future = commit_executor.submit(_commit_after_gather) + _t_submit_end = time.perf_counter() + except Exception: + logger.exception("Failed to submit async non-GPU store") + if staged_chunks: + self._release_staging(staged_chunks) + completion.set_result(False) + return completion + + with self._inflight_lock: + self._inflight_commits.add(commit_future) + + def _drop_commit_future(done_future: ConcurrentFuture[None]) -> None: + with self._inflight_lock: + self._inflight_commits.discard(done_future) + + commit_future.add_done_callback(_drop_commit_future) + + _t_exit = time.perf_counter() + logger.info( + "[FWD %s] lock=%.3f prepare=%.3f block_ids=%.3f " + "alloc=%.3f submit=%.3f bookkeep=%.3f total=%.3f ms " + "(num_chunks=%d, shm=%s)", + _request_id, + (_t_lock - _t0) * 1000, + (_t_prepare - _t_lock) * 1000, + (_t_block_ids - _t_prepare) * 1000, + (_t_alloc - _t_block_ids) * 1000, + (_t_submit_end - _t_submit_start) * 1000, + (_t_exit - _t_submit_end) * 1000, + (_t_exit - _t_entry) * 1000, + num_chunks, + used_shm_direct, + ) + return completion + + def flush_inflight_gathers(self) -> None: + """Synchronize all in-flight gather (GPU->CPU) events. + + Called at preemption/eviction time (and during ``close``) so that vLLM + cannot overwrite paged KV blocks before a deferred gather has finished + reading them. Only gather completion is awaited; commit futures are not + affected, since commits read from LMCache-owned staging buffers. + """ + _t0 = time.perf_counter() + with self._inflight_lock: + gather_events = list(self._inflight_gather_events) + for event in gather_events: + event.synchronize() + _t1 = time.perf_counter() + logger.info( + "[flush_inflight_gathers] synced %d events in %.3f ms", + len(gather_events), + (_t1 - _t0) * 1000, + ) + + def close(self) -> None: + # Drain in-flight gather/commit work before closing the base context. 
+ _t0 = time.perf_counter() + with self._inflight_lock: + self._is_closing = True + gather_events = list(self._inflight_gather_events) + for event in gather_events: + try: + event.synchronize() + except Exception: + logger.exception("Failed while draining gather events") + _t1 = time.perf_counter() + self._commit_executor.shutdown(wait=True, cancel_futures=False) + _t2 = time.perf_counter() + logger.info( + "[close] gather_drain=%.3f executor_shutdown=%.3f total=%.3f ms", + (_t1 - _t0) * 1000, + (_t2 - _t1) * 1000, + (_t2 - _t0) * 1000, + ) + super().close() diff --git a/lmcache/v1/multiprocess/transfer_context/shm.py b/lmcache/v1/multiprocess/transfer_context/shm.py index 0178d27faf..025b80e4ad 100644 --- a/lmcache/v1/multiprocess/transfer_context/shm.py +++ b/lmcache/v1/multiprocess/transfer_context/shm.py @@ -2,6 +2,7 @@ """Shared-memory NonGpuContext implementation for multiprocess mode.""" # Standard +import ctypes from dataclasses import dataclass from multiprocessing import shared_memory from multiprocessing.resource_tracker import unregister @@ -11,6 +12,8 @@ import torch # First Party +from lmcache import torch_dev +from lmcache.logging import init_logger from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType, get_response_class @@ -19,6 +22,8 @@ NonGpuContextMetadata, ) +logger = init_logger(__name__) + @dataclass(frozen=True) class ShmSlotDescriptor: @@ -92,6 +97,9 @@ def __init__( self._pool_size = pool_size self._shm: shared_memory.SharedMemory | None = None self._shm_buffer: memoryview | None = None + self._pinned = False + self._pinned_ptr = 0 + self._pinned_size = 0 try: self._shm = shared_memory.SharedMemory( name=shm_name.lstrip("/"), create=False @@ -101,6 +109,8 @@ def __init__( # unlink the segment when this worker exits. 
unregister(f"/{self._shm.name}", "shared_memory") self._shm_buffer = self._shm.buf + self._register_shm_buffer() + logger.info("SHM pinned=%s for shm_name=%s", self._pinned, self._shm_name) except Exception: self._shm = None self._shm_buffer = None @@ -212,7 +222,70 @@ def close(self) -> None: if self._shm is None: return try: - self._shm.close() + self._unregister_shm_buffer() finally: - self._shm = None - self._shm_buffer = None + try: + self._shm.close() + finally: + self._shm = None + self._shm_buffer = None + + def _register_shm_buffer(self) -> None: + if self._shm_buffer is None or not torch_dev.is_available(): + return + if not hasattr(torch_dev, "cudart"): + logger.warning( + "Skipping SHM host registration for shm_name=%s: " + "backend does not support cudart(); D2H copies will be synchronous", + self._shm_name, + ) + return + try: + ptr = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_buffer)) + err = torch_dev.cudart().cudaHostRegister(ptr, self._pool_size, 0) + except Exception as exc: + logger.warning( + "Failed to register SHM buffer for shm_name=%s: %r; " + "D2H copies will be synchronous", + self._shm_name, + exc, + ) + return + if err != 0: + logger.warning( + "cudaHostRegister failed for shm_name=%s (ptr=%d, size=%d, err=%s); " + "D2H copies will be synchronous", + self._shm_name, + ptr, + self._pool_size, + err, + ) + return + self._pinned = True + self._pinned_ptr = ptr + self._pinned_size = self._pool_size + + def _unregister_shm_buffer(self) -> None: + if not self._pinned or self._pinned_ptr == 0: + return + try: + err = torch_dev.cudart().cudaHostUnregister(self._pinned_ptr) + if err != 0: + logger.warning( + "cudaHostUnregister failed for shm_name=%s (ptr=%d, size=%d, " + "err=%s)", + self._shm_name, + self._pinned_ptr, + self._pinned_size, + err, + ) + except Exception as exc: + logger.warning( + "Failed to unregister SHM buffer for shm_name=%s: %r", + self._shm_name, + exc, + ) + finally: + self._pinned = False + self._pinned_ptr = 0 
+ self._pinned_size = 0 diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 41c72ca7d2..4b475bef76 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Protocol import os +import time # Third Party import torch @@ -94,6 +95,9 @@ class IPCEvent(Protocol): def ipc_handle(self) -> object: """Return an IPC handle consumable by the multiprocess server.""" + def wait(self, stream: object | None = None) -> None: + """Make ``stream`` wait for this event (async ordering primitive).""" + SendRequest = Callable[[MessageQueueClient, RequestType, list[object]], MessagingFuture] @@ -213,6 +217,15 @@ def submit_retrieve( def close(self) -> None: """Release resources held by this context.""" + def flush_inflight_gathers(self) -> None: + """Synchronize any in-flight gather operations. + + The default implementation is a no-op. Non-GPU async save contexts can + override this to make preemption handling block until deferred reads of + vLLM paged KV data are complete. + """ + return None + class HandleTransferContext(TransferContext): """Handle-based IPC + MQ future transport context.""" @@ -269,11 +282,26 @@ def submit_store( "Handle transfer context is not registered. " "Call register() before submit_store()." 
) - return self._send_request( + _t0 = time.perf_counter() + ipc_handle = event.ipc_handle() + _t_ipc = time.perf_counter() + mq_future = self._send_request( self._mq_client, RequestType.STORE, - [key, instance_id, block_ids, event.ipc_handle()], - ).to_cuda_future() + [key, instance_id, block_ids, ipc_handle], + ) + _t_send = time.perf_counter() + cuda_future = mq_future.to_cuda_future() + _t_cuda = time.perf_counter() + logger.info( + "[FWD-IPC] req=%s ipc_handle=%.3f send_request=%.3f to_cuda_future=%.3f total=%.3f ms", + _request_id, + (_t_ipc - _t0) * 1000, + (_t_send - _t_ipc) * 1000, + (_t_cuda - _t_send) * 1000, + (_t_cuda - _t0) * 1000, + ) + return cuda_future def submit_retrieve( self, @@ -529,8 +557,64 @@ def create_transfer_context( if resolved_mode is MPTransferMode.HANDLE: return _build_handle_context(device_type) if resolved_mode is MPTransferMode.DATA: - return DataTransferContext() + return _build_data_context(kv_caches) # AUTO: preserve the historical device-type-based dispatch. - if device_type == "cuda": - return HandleTransferContext() + # if device_type == "cuda": + # return HandleTransferContext() + return _build_data_context(kv_caches) + + +def _supports_async_primitives(kv_caches: dict[str, torch.Tensor]) -> bool: + """Probe whether the worker device supports the async store primitives. + + The async non-GPU store path needs a stream, an event exposing + ``record``/``synchronize``/``wait``, and pinned (page-locked) host memory. + When any of these is unavailable (e.g. a CPU-only backend), the factory + falls back to the synchronous :class:`DataTransferContext`. This dispatch is + internal and capability-based; there is no user-facing async/sync flag. + + Args: + kv_caches: Worker KV cache tensors keyed by layer name. Currently unused + (capability is a property of ``torch_dev``), accepted to keep the + probe signature forward-compatible with device-specific checks. 
+ + Returns: + True if all required async primitives are available, else False. + """ + if not hasattr(torch_dev, "Stream") or not hasattr(torch_dev, "Event"): + return False + try: + torch_dev.Stream() + event = torch_dev.Event() + except Exception: + return False + for attr in ("record", "synchronize", "wait"): + if not callable(getattr(event, attr, None)): + return False + try: + torch.empty(1, dtype=torch.uint8, device="cpu", pin_memory=True) + except (RuntimeError, TypeError): + return False + return True + + +def _build_data_context(kv_caches: dict[str, torch.Tensor]) -> "TransferContext": + """Build the non-GPU data context, async when device-capable else sync. + + Routes the ``DATA`` and AUTO non-cuda branches through a single capability + check. ``AsyncDataTransferContext`` is imported lazily to avoid an import + cycle and to keep the synchronous path free of stream/event dependencies. + """ + if _supports_async_primitives(kv_caches): + # First Party + from lmcache.v1.multiprocess.transfer_context.async_data import ( + AsyncDataTransferContext, + ) + + logger.info( + " >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> AsyncDataTransferContext " + ) + return AsyncDataTransferContext() + + logger.info(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> SyncDataTransferContext ") return DataTransferContext() diff --git a/tests/v1/multiprocess/test_async_data_transfer_context.py b/tests/v1/multiprocess/test_async_data_transfer_context.py new file mode 100644 index 0000000000..a2b14af5d1 --- /dev/null +++ b/tests/v1/multiprocess/test_async_data_transfer_context.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from contextlib import nullcontext +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Callable +from unittest.mock import MagicMock +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.multiprocess.transfer_context import async_data, worker_transfer +from 
lmcache.v1.multiprocess.transfer_context.async_data import ( + AsyncDataTransferContext, +) +from lmcache.v1.multiprocess.transfer_context.worker_transfer import DataTransferContext + + +@dataclass +class _FakeStoreContext: + """Minimal non-GPU context for async store tests.""" + + commit_impl: Callable[[list[torch.Tensor]], bool] + prepare_result: tuple[list[torch.Tensor], list[int]] | None = None + + def __post_init__(self) -> None: + self.layout_desc = SimpleNamespace( + shapes=[torch.Size([2, 1, 1, 1])], dtypes=[torch.float32] + ) + + def prepare_store( + self, _key: object, _instance_id: int + ) -> tuple[list[torch.Tensor], list[int]] | None: + return self.prepare_result + + def commit_store( + self, _key: object, _instance_id: int, chunks: list[torch.Tensor] + ) -> bool: + return bool(self.commit_impl(chunks)) + + def close(self) -> None: + return None + + +class _FakeEvent: + def __init__(self, gate: threading.Event): + self._gate = gate + + def record(self, stream: object | None = None) -> None: + return None + + def wait(self, stream: object | None = None) -> None: + return None + + def synchronize(self) -> None: + self._gate.wait(timeout=2) + + def query(self) -> bool: + return self._gate.is_set() + + +class _FakeTorchDev: + def __init__(self, gather_gate: threading.Event): + self._stream = object() + self._gather_gate = gather_gate + + def Stream(self) -> object: + return object() + + def stream(self, stream: object) -> object: + return nullcontext(stream) + + def current_stream(self) -> object: + return self._stream + + def Event(self, interprocess: bool = False) -> _FakeEvent: + return _FakeEvent(self._gather_gate) + + +def _install_fake_gather(monkeypatch: pytest.MonkeyPatch) -> None: + def _gather( + _kv_caches: dict[str, torch.Tensor], + _block_ids: list[int], + _blocks_in_chunk: int, + **kwargs: object, + ) -> list[torch.Tensor]: + out = kwargs.get("out") + if out is None: + # Sync fallback path passes out=None: gather allocates its own + # buffers 
and returns them, so mirror that contract here. + return [torch.ones(1)] + assert isinstance(out, list) + for tensor in out: + tensor.fill_(1.0) + return out + + # The async path resolves ``gather_paged_kv_to_cpu`` from async_data, the + # sync path from worker_transfer; patch both so either is exercised. + monkeypatch.setattr(async_data, "gather_paged_kv_to_cpu", _gather) + monkeypatch.setattr(worker_transfer, "gather_paged_kv_to_cpu", _gather) + + +def _new_context( + monkeypatch: pytest.MonkeyPatch, + *, + gather_gate: threading.Event, + commit_impl: Callable[[list[torch.Tensor]], bool], + max_inflight: int = 8, +) -> AsyncDataTransferContext: + monkeypatch.setattr(async_data, "torch_dev", _FakeTorchDev(gather_gate)) + _install_fake_gather(monkeypatch) + ctx = AsyncDataTransferContext(max_inflight_stores=max_inflight) + ctx._non_gpu_context = _FakeStoreContext(commit_impl=commit_impl) + return ctx + + +def test_submit_store_returns_pending_future_until_gather_and_commit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=lambda _c: True + ) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + assert not future.query() + gather_gate.set() + assert future.result(timeout=1) is True + ctx.close() + + +def test_submit_store_commit_waits_for_gather_done( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + commit_called = threading.Event() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_called.set() + return True + + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + assert not commit_called.wait(timeout=0.05) + gather_gate.set() + assert future.result(timeout=1) is True + assert commit_called.is_set() + ctx.close() + + +def 
test_submit_store_backpressure_blocks_when_inflight_cap_hit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + commit_gate = threading.Event() + gather_gate.set() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_gate.wait(timeout=2) + return True + + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=_commit, max_inflight=1 + ) + first = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + done = threading.Event() + + def _submit_second() -> None: + try: + ctx.submit_store( + "r2", + object(), + 1, + {"k": torch.zeros(1)}, + [[0]], + _FakeEvent(gather_gate), + 1, + ) + finally: + done.set() + + t = threading.Thread(target=_submit_second, daemon=True) + t.start() + assert not done.wait(timeout=0.1) + commit_gate.set() + assert first.result(timeout=1) is True + t.join(timeout=1) + assert done.is_set() + ctx.close() + + +def test_close_drains_inflight_async_store(monkeypatch: pytest.MonkeyPatch) -> None: + gather_gate = threading.Event() + commit_gate = threading.Event() + gather_gate.set() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + commit_gate.wait(timeout=2) + return True + + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + closed = threading.Event() + + def _close() -> None: + ctx.close() + closed.set() + + t = threading.Thread(target=_close, daemon=True) + t.start() + assert not closed.wait(timeout=0.05) + commit_gate.set() + t.join(timeout=1) + assert closed.is_set() + assert future.result(timeout=1) is True + + +def test_commit_failure_sets_false_and_logs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + + def _commit(_chunks: list[torch.Tensor]) -> bool: + raise RuntimeError("commit failed") + + log_exception = MagicMock() + 
monkeypatch.setattr(async_data.logger, "exception", log_exception) + ctx = _new_context(monkeypatch, gather_gate=gather_gate, commit_impl=_commit) + future = ctx.submit_store( + "r1", object(), 1, {"k": torch.zeros(1)}, [[0]], _FakeEvent(gather_gate), 1 + ) + gather_gate.set() + assert future.result(timeout=1) is False + log_exception.assert_called_once() + ctx.close() + + +def test_flush_inflight_gathers_no_inflight_is_noop( + monkeypatch: pytest.MonkeyPatch, +) -> None: + gather_gate = threading.Event() + gather_gate.set() + ctx = _new_context( + monkeypatch, gather_gate=gather_gate, commit_impl=lambda _c: True + ) + # No in-flight events: flush is a cheap no-op and must not raise. + ctx.flush_inflight_gathers() + ctx.close() + + +class _RecordingTorchDev: + """torch_dev stub that records whether async primitives are touched.""" + + def __init__(self) -> None: + self.synchronize_calls = 0 + self.stream_calls = 0 + self.event_calls = 0 + + def synchronize(self) -> None: + self.synchronize_calls += 1 + + def Stream(self) -> object: + self.stream_calls += 1 + return object() + + def stream(self, stream: object) -> object: + return nullcontext(stream) + + def Event(self, interprocess: bool = False) -> _FakeEvent: + self.event_calls += 1 + return _FakeEvent(threading.Event()) + + +def test_sync_data_context_returns_resolved_future( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake = _RecordingTorchDev() + monkeypatch.setattr(worker_transfer, "torch_dev", fake) + _install_fake_gather(monkeypatch) + ctx = DataTransferContext() + ctx._non_gpu_context = _FakeStoreContext(commit_impl=lambda _c: True) + + future = ctx.submit_store( + "r1", + object(), + 1, + {"k": torch.zeros(1)}, + [[0]], + _FakeEvent(threading.Event()), + 1, + ) + + # Sync path resolves inline and never touches copy-stream / event primitives. 
+ assert future.query() + assert future.result(timeout=1) is True + assert fake.stream_calls == 0 + assert fake.event_calls == 0 + assert fake.synchronize_calls >= 1 + # flush_inflight_gathers is the inherited base no-op; neither it nor close + # must raise on the synchronous path. + ctx.flush_inflight_gathers() + ctx.close() + + +def test_sync_data_context_has_no_async_resources() -> None: + # The synchronous DataTransferContext must not create any async resources + # and must not expose async-only attributes. + ctx = DataTransferContext() + assert not hasattr(ctx, "_copy_stream") + assert not hasattr(ctx, "_commit_executor") + assert not hasattr(ctx, "_inflight_semaphore") + # close() on an unregistered sync context must not raise. + ctx.close() + + +def test_build_data_context_dispatches_on_capability( + monkeypatch: pytest.MonkeyPatch, +) -> None: + kv_caches = {"k": torch.zeros(1)} + + monkeypatch.setattr(worker_transfer, "_supports_async_primitives", lambda _kv: True) + # Avoid touching real stream/event primitives when instantiating the async + # context returned by the capable branch. + monkeypatch.setattr(async_data, "torch_dev", _FakeTorchDev(threading.Event())) + capable = worker_transfer._build_data_context(kv_caches) + assert isinstance(capable, AsyncDataTransferContext) + capable.close() + + monkeypatch.setattr( + worker_transfer, "_supports_async_primitives", lambda _kv: False + ) + fallback = worker_transfer._build_data_context(kv_caches) + assert isinstance(fallback, DataTransferContext) + assert not isinstance(fallback, AsyncDataTransferContext) + fallback.close() + + +def test_supports_async_primitives_false_without_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # A torch_dev without Stream/Event is not async-capable. 
+ monkeypatch.setattr(worker_transfer, "torch_dev", object()) + assert worker_transfer._supports_async_primitives({"k": torch.zeros(1)}) is False diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py index c60290b917..1aaa38b6f7 100644 --- a/tests/v1/multiprocess/test_non_cuda_data_transfer.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -1162,3 +1162,134 @@ def test_non_gpu_context_shm_close_is_idempotent() -> None: finally: if os.path.exists(shm_path): os.unlink(shm_path) + + +def test_non_gpu_context_shm_registers_and_unregisters_host_memory( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 0 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + ptr, size, flags = fake_cudart.register_calls[0] + assert ptr > 0 + assert size == 4096 + assert flags == 0 + finally: + 
context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [ptr] + + +def test_non_gpu_context_shm_register_failure_warns_and_skips_unregister( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_fail_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 1 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + with patch.object(shm_module.logger, "warning") as warning_mock: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [] + warning_mock.assert_called_once() + message, logged_shm_name, _logged_ptr, logged_size, logged_err = ( + warning_mock.call_args[0] + ) + assert "cudaHostRegister failed" in message + assert logged_shm_name == shm_name + assert logged_size == 4096 + assert logged_err == 1 diff --git a/tests/v1/test_c_ops_fallback_parity.py b/tests/v1/test_c_ops_fallback_parity.py 
index fb97b67576..5d15bbd918 100644 --- a/tests/v1/test_c_ops_fallback_parity.py +++ b/tests/v1/test_c_ops_fallback_parity.py @@ -212,9 +212,7 @@ def _has_real_names(params): # ── Discover the intersection automatically ── # Functions intentionally excluded from parity checks. -_EXCLUDED_FUNCS: set[str] = { - "multi_layer_block_kv_transfer", -} +_EXCLUDED_FUNCS: set[str] = set() _fallback_callables = _public_callables(fallback) _c_ops_callables = _public_callables(c_ops) if HAS_C_OPS else {} diff --git a/tests/v1/test_lmcache_mp_connector_preemption.py b/tests/v1/test_lmcache_mp_connector_preemption.py new file mode 100644 index 0000000000..5860ed254d --- /dev/null +++ b/tests/v1/test_lmcache_mp_connector_preemption.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from types import SimpleNamespace +from unittest.mock import MagicMock + +# Third Party +import pytest + +pytest.importorskip("vllm") + +# Third Party +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole + +# First Party +from lmcache.integration.vllm.lmcache_mp_connector import ( + LMCacheMPConnector, + LMCacheMPConnectorMetadata, +) + + +def _new_connector_without_init() -> LMCacheMPConnector: + connector = LMCacheMPConnector.__new__(LMCacheMPConnector) + return connector + + +def test_build_connector_meta_sets_need_flush_from_preemption_signal() -> None: + connector = _new_connector_without_init() + connector._process_retrieve_requests = lambda metadata: None + connector._process_new_requests = lambda scheduler_output, metadata: None + connector._process_cached_requests = lambda scheduler_output, metadata: None + connector._report_block_allocation_deltas = lambda scheduler_output: None + + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=["req-1"]) + ) + metadata = connector.build_connector_meta(scheduler_output) + assert isinstance(metadata, LMCacheMPConnectorMetadata) + assert metadata.need_flush is True + + 
+def test_build_connector_meta_keeps_need_flush_false_without_signal() -> None: + connector = _new_connector_without_init() + connector._process_retrieve_requests = lambda metadata: None + connector._process_new_requests = lambda scheduler_output, metadata: None + connector._process_cached_requests = lambda scheduler_output, metadata: None + connector._report_block_allocation_deltas = lambda scheduler_output: None + + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]), + preempted_req_ids=[], + ) + metadata = connector.build_connector_meta(scheduler_output) + assert isinstance(metadata, LMCacheMPConnectorMetadata) + assert metadata.need_flush is False + + +@pytest.mark.parametrize( + "scheduler_output", + [ + # Resumed-from-preemption signal (CachedRequestData.resumed_req_ids). + SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=["req-1"]), + ), + # Preempted-this-step signal (SchedulerOutput.preempted_req_ids). + SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]), + preempted_req_ids=["req-1"], + ), + ], +) +def test_scheduler_step_needs_flush_for_all_supported_signals( + scheduler_output: SimpleNamespace, +) -> None: + connector = _new_connector_without_init() + assert connector._scheduler_step_needs_flush(scheduler_output) is True + + +def test_handle_preemptions_forwards_flush_hint_to_worker_adapter() -> None: + connector = _new_connector_without_init() + connector._role = KVConnectorRole.WORKER + connector.worker_adapter = MagicMock() + metadata = LMCacheMPConnectorMetadata() + + connector.handle_preemptions(metadata) + connector.worker_adapter.handle_preemptions.assert_called_once_with(False) + + connector.worker_adapter.reset_mock() + metadata.need_flush = True + connector.handle_preemptions(metadata) + connector.worker_adapter.handle_preemptions.assert_called_once_with(True) + + +def test_scheduler_step_needs_flush_conservative_on_unknown_schema() -> None: + 
connector = _new_connector_without_init() + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(unexpected_field=["req-1"]) + ) + assert connector._scheduler_step_needs_flush(scheduler_output) is True + + +def test_scheduler_step_needs_flush_false_for_recognized_no_preemption() -> None: + connector = _new_connector_without_init() + scheduler_output = SimpleNamespace( + scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=[]) + ) + assert connector._scheduler_step_needs_flush(scheduler_output) is False diff --git a/tests/v1/test_python_ops_fallback.py b/tests/v1/test_python_ops_fallback.py index f0f6e648e4..9bb6c5efa9 100644 --- a/tests/v1/test_python_ops_fallback.py +++ b/tests/v1/test_python_ops_fallback.py @@ -1796,6 +1796,852 @@ class _FakeStream: return {"dispatcher_integration": torch.tensor([1], dtype=torch.int32)} +def scenario_multi_layer_block_kv_transfer( + ops: Any, device: str +) -> dict[str, torch.Tensor]: + """Test multi_layer_block_kv_transfer across all GPU KV formats. + + Exercises the block-based transfer path for: + - NHD per-layer format (D2H and H2D round-trip) + - FlashInfer NHD per-layer format (interleaved K/V) + - HND per-layer format (with permute) + - FlashInfer HND per-layer format (interleaved K/V) + - MLA per-layer format (no K/V split) + - SGLang MLA flat per-layer format + - Cross-layer NHD (single tensor [NB, NL, 2, BS, NH, HS]) + - Cross-layer HND (single tensor [NB, NL, 2, NH, BS, HS]) + - SGLang MHA flat (2*NL tensors [NB*BS, NH, HS]) + - SGLang MHA block (2*NL tensors [NB, BS, NH, HS]) + - non-sequential block_ids and list[int] block_ids input + - skip_prefix_n_blocks > 0 + - num_blocks > blocks_per_chunk (multi-chunk) + """ + results = {} + + # C++ bindings (cuda_c_ops, xpu_sycl_ops) expect uint64 pointer tensors for + # paged_buffer_ptrs_tensor and list[int] for lmcache_objects_ptrs. 
+ # The Python fallback also supports both modes on cpu/cuda (pointer inputs + # are reconstructed internally via _tensor_from_ptr). + use_tensor_list = device not in ("cpu", "cuda") + + def _alloc_chunks(shape: tuple[int, ...], count: int) -> list[torch.Tensor]: + chunks = [torch.zeros(shape, dtype=dtype) for _ in range(count)] + if device in ("cuda"): + chunks = [chunk.pin_memory() for chunk in chunks] + return chunks + + # --- NHD per-layer --- + torch.manual_seed(123) + num_layers, num_blocks, block_size = 2, 8, 4 + num_heads, head_size = 2, 8 + blocks_per_chunk = 4 + chunk_tokens = blocks_per_chunk * block_size + hidden_dim = num_heads * head_size + dtype = torch.float32 + + paged_layers = [ + torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + shape_desc = ops.PageBufferShapeDesc() + shape_desc.nl = num_layers + shape_desc.nb = num_blocks + shape_desc.bs = block_size + shape_desc.nh = num_heads + shape_desc.hs = head_size + shape_desc.element_size = dtype.itemsize + shape_desc.kv_size = 2 + gpu_kv_format = ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + num_chunks = num_blocks // blocks_per_chunk + d2h_chunks = _alloc_chunks((2, num_layers, chunk_tokens, hidden_dim), num_chunks) + block_ids = list(range(num_blocks)) + + ops.multi_layer_block_kv_transfer( + paged_layers + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers], + dtype=torch.uint64, + device=device, + ), + d2h_chunks if use_tensor_list else [c.data_ptr() for c in d2h_chunks], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format, + 0, + ) + paged_h2d = [torch.zeros_like(layer) for layer in paged_layers] + ops.multi_layer_block_kv_transfer( + paged_h2d + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d], dtype=torch.uint64, device=device + ), + d2h_chunks if 
use_tensor_list else [c.data_ptr() for c in d2h_chunks], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format, + 0, + ) + for i in range(num_layers): + orig = paged_layers[i].cpu() + recon = paged_h2d[i].cpu() + results[f"nhd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"NHD Layer {i} round-trip mismatch" + ) + + # --- FlashInfer NHD per-layer --- + torch.manual_seed(234) + paged_layers_fi_nhd = [ + torch.randn(num_blocks, 2, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_fi_nhd = ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS + d2h_chunks_fi_nhd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_fi_nhd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_fi_nhd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_nhd, + 0, + ) + paged_h2d_fi_nhd = [torch.zeros_like(layer) for layer in paged_layers_fi_nhd] + ops.multi_layer_block_kv_transfer( + paged_h2d_fi_nhd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_fi_nhd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_fi_nhd[i].cpu() + recon = paged_h2d_fi_nhd[i].cpu() + results[f"flashinfer_nhd_layer{i}"] = 
torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"FlashInfer NHD Layer {i} round-trip mismatch" + ) + + # --- HND per-layer --- + torch.manual_seed(456) + paged_layers_hnd = [ + torch.randn(2, num_blocks, num_heads, block_size, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_hnd = ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS + d2h_chunks_hnd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_hnd if use_tensor_list else [c.data_ptr() for c in d2h_chunks_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_hnd, + 0, + ) + paged_h2d_hnd = [torch.zeros_like(layer) for layer in paged_layers_hnd] + ops.multi_layer_block_kv_transfer( + paged_h2d_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_hnd if use_tensor_list else [c.data_ptr() for c in d2h_chunks_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_hnd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_hnd[i].cpu() + recon = paged_h2d_hnd[i].cpu() + results[f"hnd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"HND Layer {i} round-trip mismatch" + ) + + # --- FlashInfer HND per-layer --- + torch.manual_seed(567) + paged_layers_fi_hnd = [ + torch.randn(num_blocks, 2, num_heads, block_size, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_fi_hnd = ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS + d2h_chunks_fi_hnd 
= _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_fi_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_fi_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_hnd, + 0, + ) + paged_h2d_fi_hnd = [torch.zeros_like(layer) for layer in paged_layers_fi_hnd] + ops.multi_layer_block_kv_transfer( + paged_h2d_fi_hnd + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_fi_hnd], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_fi_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_fi_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_fi_hnd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_fi_hnd[i].cpu() + recon = paged_h2d_fi_hnd[i].cpu() + results[f"flashinfer_hnd_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"FlashInfer HND Layer {i} round-trip mismatch" + ) + + # --- MLA per-layer --- + torch.manual_seed(789) + mla_hidden = 16 + paged_layers_mla = [ + torch.randn(num_blocks, block_size, mla_hidden, dtype=dtype).to(device) + for _ in range(num_layers) + ] + shape_desc_mla = ops.PageBufferShapeDesc() + shape_desc_mla.nl = num_layers + shape_desc_mla.nb = num_blocks + shape_desc_mla.bs = block_size + shape_desc_mla.nh = 1 + shape_desc_mla.hs = mla_hidden + shape_desc_mla.element_size = dtype.itemsize + shape_desc_mla.kv_size = 1 + gpu_kv_format_mla = ops.GPUKVFormat.NL_X_NB_BS_HS + d2h_chunks_mla = _alloc_chunks((num_layers, chunk_tokens, mla_hidden), num_chunks) + 
ops.multi_layer_block_kv_transfer( + paged_layers_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mla if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_mla, + 0, + ) + paged_h2d_mla = [torch.zeros_like(layer) for layer in paged_layers_mla] + ops.multi_layer_block_kv_transfer( + paged_h2d_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mla if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_mla, + 0, + ) + for i in range(num_layers): + orig = paged_layers_mla[i].cpu() + recon = paged_h2d_mla[i].cpu() + results[f"mla_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"MLA Layer {i} round-trip mismatch" + ) + + # --- SGLang MLA flat per-layer --- + torch.manual_seed(890) + paged_layers_sglang_mla = [ + torch.randn(num_blocks * block_size, 1, mla_hidden, dtype=dtype).to(device) + for _ in range(num_layers) + ] + gpu_kv_format_sglang_mla = ops.GPUKVFormat.NL_X_NBBS_ONE_HS + d2h_chunks_sglang_mla = _alloc_chunks( + (num_layers, chunk_tokens, mla_hidden), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_sglang_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_sglang_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_mla + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + 
ops.TransferDirection.D2H, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_sglang_mla, + 0, + ) + paged_h2d_sglang_mla = [ + torch.zeros_like(layer) for layer in paged_layers_sglang_mla + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_mla + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_sglang_mla], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_mla + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_mla], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mla, + chunk_tokens, + gpu_kv_format_sglang_mla, + 0, + ) + for i in range(num_layers): + orig = paged_layers_sglang_mla[i].cpu() + recon = paged_h2d_sglang_mla[i].cpu() + results[f"sglang_mla_layer{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang MLA Layer {i} round-trip mismatch" + ) + + # --- Cross-layer NHD --- + torch.manual_seed(101) + paged_cross_nhd = torch.randn( + num_blocks, + num_layers, + 2, + block_size, + num_heads, + head_size, + dtype=dtype, + ).to(device) + gpu_kv_format_cross_nhd = ops.GPUKVFormat.NB_NL_TWO_BS_NH_HS + d2h_chunks_cross_nhd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_cross_nhd + if use_tensor_list + else torch.tensor( + [paged_cross_nhd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_nhd, + 0, + ) + paged_h2d_cross_nhd = torch.zeros_like(paged_cross_nhd) + ops.multi_layer_block_kv_transfer( + paged_h2d_cross_nhd + if use_tensor_list + else torch.tensor( + [paged_h2d_cross_nhd.data_ptr()], dtype=torch.uint64, device=device + ), + 
d2h_chunks_cross_nhd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_nhd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_nhd, + 0, + ) + orig = paged_cross_nhd.cpu() + recon = paged_h2d_cross_nhd.cpu() + results["cross_nhd"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), "Cross-layer NHD mismatch" + + # --- Cross-layer HND --- + torch.manual_seed(202) + paged_cross_hnd = torch.randn( + num_blocks, + num_layers, + 2, + num_heads, + block_size, + head_size, + dtype=dtype, + ).to(device) + gpu_kv_format_cross_hnd = ops.GPUKVFormat.NB_NL_TWO_NH_BS_HS + d2h_chunks_cross_hnd = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_cross_hnd + if use_tensor_list + else torch.tensor( + [paged_cross_hnd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_hnd, + 0, + ) + paged_h2d_cross_hnd = torch.zeros_like(paged_cross_hnd) + ops.multi_layer_block_kv_transfer( + paged_h2d_cross_hnd + if use_tensor_list + else torch.tensor( + [paged_h2d_cross_hnd.data_ptr()], dtype=torch.uint64, device=device + ), + d2h_chunks_cross_hnd + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_cross_hnd], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_cross_hnd, + 0, + ) + orig = paged_cross_hnd.cpu() + recon = paged_h2d_cross_hnd.cpu() + results["cross_hnd"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), "Cross-layer HND mismatch" + + # --- SGLang MHA 
flat (NBBS) --- + torch.manual_seed(303) + pbs = num_blocks * block_size + paged_sglang_nbbs = [ + [ + torch.randn(pbs, num_heads, head_size, dtype=dtype).to(device) + for _ in range(num_layers) + ], + [ + torch.randn(pbs, num_heads, head_size, dtype=dtype).to(device) + for _ in range(num_layers) + ], + ] + gpu_kv_format_sglang_nbbs = ops.GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS + d2h_chunks_sglang_nbbs = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_sglang_nbbs + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_sglang_nbbs[0]] + + [t.data_ptr() for t in paged_sglang_nbbs[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nbbs + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nbbs], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nbbs, + 0, + ) + paged_h2d_sglang_nbbs = [ + [torch.zeros_like(t) for t in group] for group in paged_sglang_nbbs + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_nbbs + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_h2d_sglang_nbbs[0]] + + [t.data_ptr() for t in paged_h2d_sglang_nbbs[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nbbs + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nbbs], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nbbs, + 0, + ) + for kv in range(2): + for i in range(num_layers): + orig = paged_sglang_nbbs[kv][i].cpu() + recon = paged_h2d_sglang_nbbs[kv][i].cpu() + results[f"sglang_nbbs_kv{kv}_l{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang NBBS kv={kv} layer={i} mismatch" + ) + + # --- SGLang MHA block (NB_BS) --- + 
torch.manual_seed(404) + paged_sglang_nb = [ + [ + torch.randn(num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ], + [ + torch.randn(num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ], + ] + gpu_kv_format_sglang_nb = ops.GPUKVFormat.TWO_X_NL_X_NB_BS_NH_HS + d2h_chunks_sglang_nb = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_sglang_nb + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_sglang_nb[0]] + + [t.data_ptr() for t in paged_sglang_nb[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nb + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nb], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nb, + 0, + ) + paged_h2d_sglang_nb = [ + [torch.zeros_like(t) for t in group] for group in paged_sglang_nb + ] + ops.multi_layer_block_kv_transfer( + paged_h2d_sglang_nb + if use_tensor_list + else torch.tensor( + [t.data_ptr() for t in paged_h2d_sglang_nb[0]] + + [t.data_ptr() for t in paged_h2d_sglang_nb[1]], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_sglang_nb + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_sglang_nb], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_sglang_nb, + 0, + ) + for kv in range(2): + for i in range(num_layers): + orig = paged_sglang_nb[kv][i].cpu() + recon = paged_h2d_sglang_nb[kv][i].cpu() + results[f"sglang_nb_kv{kv}_l{i}"] = torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"SGLang NB kv={kv} layer={i} mismatch" + ) + + # --- skip_prefix_n_blocks > 0 --- + torch.manual_seed(505) + skip_n = 2 + paged_layers_skip = [ 
+ torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + gpu_kv_format_nhd = ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + # With skip=2, effective blocks start at index 2. + # Object 0 occupies flat indices [0, blocks_per_chunk), skipping first 2. + # Object 1 occupies flat indices [blocks_per_chunk, 2*blocks_per_chunk). + d2h_chunks_skip = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_skip + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_skip], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_skip if use_tensor_list else [c.data_ptr() for c in d2h_chunks_skip], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + skip_n, + ) + paged_h2d_skip = [torch.zeros_like(layer) for layer in paged_layers_skip] + ops.multi_layer_block_kv_transfer( + paged_h2d_skip + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_skip], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_skip if use_tensor_list else [c.data_ptr() for c in d2h_chunks_skip], + torch.tensor(block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + skip_n, + ) + # Blocks [skip_n:num_blocks] should round-trip at their original positions + for layer_idx in range(num_layers): + for kv in range(2): + orig_blocks = paged_layers_skip[layer_idx][kv, skip_n:num_blocks] + recon_blocks = paged_h2d_skip[layer_idx][kv, skip_n:num_blocks] + key = f"skip_l{layer_idx}_kv{kv}" + results[key] = torch.stack([orig_blocks.cpu(), recon_blocks.cpu()]) + assert torch.allclose(orig_blocks, recon_blocks, atol=1e-6), ( + f"Skip mismatch at layer={layer_idx} kv={kv}" + ) + # Skipped blocks [0:skip_n] should 
remain zero + for layer_idx in range(num_layers): + for kv in range(2): + skipped = paged_h2d_skip[layer_idx][kv, :skip_n] + assert torch.all(skipped == 0), ( + f"Skipped blocks not zero at layer={layer_idx} kv={kv}" + ) + assert torch.all(d2h_chunks_skip[0][:, :, : skip_n * block_size] == 0), ( + "Skipped D2H chunk region should remain zero" + ) + + # --- Non-sequential block_ids --- + torch.manual_seed(515) + paged_layers_permuted = [ + torch.randn(2, num_blocks, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + generator = torch.Generator(device="cpu").manual_seed(616) + permuted_block_ids = torch.randperm(num_blocks, generator=generator).tolist() + d2h_chunks_permuted = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks + ) + ops.multi_layer_block_kv_transfer( + paged_layers_permuted + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_permuted], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_permuted + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_permuted], + torch.tensor(permuted_block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + paged_h2d_permuted = [torch.zeros_like(layer) for layer in paged_layers_permuted] + ops.multi_layer_block_kv_transfer( + paged_h2d_permuted + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_permuted], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_permuted + if use_tensor_list + else [c.data_ptr() for c in d2h_chunks_permuted], + torch.tensor(permuted_block_ids, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_permuted[i].cpu() + recon = paged_h2d_permuted[i].cpu() + results[f"permuted_nhd_layer{i}"] = 
torch.stack([orig, recon]) + assert torch.allclose(orig, recon, atol=1e-6), ( + f"Permuted NHD Layer {i} round-trip mismatch" + ) + + # --- Multi-chunk (num_blocks > blocks_per_chunk) --- + torch.manual_seed(606) + num_blocks_mc = 12 + paged_layers_mc = [ + torch.randn(2, num_blocks_mc, block_size, num_heads, head_size, dtype=dtype).to( + device + ) + for _ in range(num_layers) + ] + shape_desc_mc = ops.PageBufferShapeDesc() + shape_desc_mc.nl = num_layers + shape_desc_mc.nb = num_blocks_mc + shape_desc_mc.bs = block_size + shape_desc_mc.nh = num_heads + shape_desc_mc.hs = head_size + shape_desc_mc.element_size = dtype.itemsize + shape_desc_mc.kv_size = 2 + num_chunks_mc = num_blocks_mc // blocks_per_chunk # 3 chunks + d2h_chunks_mc = _alloc_chunks( + (2, num_layers, chunk_tokens, hidden_dim), num_chunks_mc + ) + block_ids_mc = list(range(num_blocks_mc)) + ops.multi_layer_block_kv_transfer( + paged_layers_mc + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_layers_mc], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mc if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mc], + torch.tensor(block_ids_mc, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.D2H, + shape_desc_mc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + paged_h2d_mc = [torch.zeros_like(layer) for layer in paged_layers_mc] + ops.multi_layer_block_kv_transfer( + paged_h2d_mc + if use_tensor_list + else torch.tensor( + [layer.data_ptr() for layer in paged_h2d_mc], + dtype=torch.uint64, + device=device, + ), + d2h_chunks_mc if use_tensor_list else [c.data_ptr() for c in d2h_chunks_mc], + torch.tensor(block_ids_mc, dtype=torch.int64, device=device), + torch.device(device), + ops.TransferDirection.H2D, + shape_desc_mc, + chunk_tokens, + gpu_kv_format_nhd, + 0, + ) + for i in range(num_layers): + orig = paged_layers_mc[i].cpu() + recon = paged_h2d_mc[i].cpu() + results[f"multi_chunk_l{i}"] = torch.stack([orig, recon]) + assert 
def test_handle_preemptions_flushes_only_when_signaled(fake_adapter):
    """``handle_preemptions`` must flush in-flight gathers only on request.

    The adapter's transfer context is swapped for a ``MagicMock`` so that calls
    to ``flush_inflight_gathers`` can be observed.
    """
    adapter, _, _ = fake_adapter
    ctx = MagicMock()
    adapter.transfer_ctx = ctx

    # Without the flush hint, nothing may be flushed.
    adapter.handle_preemptions(False)
    ctx.flush_inflight_gathers.assert_not_called()

    # With the flush hint, exactly one argument-less flush is expected.
    adapter.handle_preemptions(True)
    ctx.flush_inflight_gathers.assert_called_once_with()