From 2da15192b3a0d16f180b206025786c03fc4d6d64 Mon Sep 17 00:00:00 2001 From: maobaolong Date: Sun, 7 Jun 2026 16:42:58 +0800 Subject: [PATCH 01/57] feat: add POSIX SHM infra for CPU KV-cache IPC (#3563) * feat: add POSIX SHM infra for CPU KV-cache IPC - lmcache/v1/multiprocess/posix_shm.py: thin POSIX-SHM facade (shm_create_readwrite / shm_map_readwrite / shm_munmap / shm_unlink / shm_open_pool_as_mmap) routing through CPython's _posixshmem to avoid macOS EACCES and shutdown BufferError issues - lmcache/v1/platform/cpu/shm.py: CpuShmTensorWrapper + migrate_to_shm_and_wrap for zero-copy CPU KV-cache IPC mirroring CUDA-IPC semantics - lmcache/v1/platform/cpu/__init__.py: self-register cpu factory with platform registry - tests/v1/multiprocess/test_posix_shm.py: unit tests for posix_shm - tests/v1/platform/test_cpu_shm.py: unit tests for CpuShmTensorWrapper Signed-off-by: baoloongmao * address comment Signed-off-by: baoloongmao * address comment Signed-off-by: baoloongmao * assert zero storage_offset before SHM migration Signed-off-by: baoloongmao * add warning logs to swallowed exceptions in posix_shm Signed-off-by: baoloongmao --------- Signed-off-by: baoloongmao --- .../vllm/vllm_multi_process_adapter.py | 6 +- lmcache/v1/multiprocess/posix_shm.py | 273 ++++++++++++++++++ lmcache/v1/platform/cpu/__init__.py | 34 ++- lmcache/v1/platform/cpu/shm.py | 267 +++++++++++++++++ tests/v1/multiprocess/test_posix_shm.py | 78 +++++ tests/v1/platform/__init__.py | 1 + tests/v1/platform/test_cpu_shm.py | 234 +++++++++++++++ 7 files changed, 887 insertions(+), 6 deletions(-) create mode 100644 lmcache/v1/multiprocess/posix_shm.py create mode 100644 lmcache/v1/platform/cpu/shm.py create mode 100644 tests/v1/multiprocess/test_posix_shm.py create mode 100644 tests/v1/platform/__init__.py create mode 100644 tests/v1/platform/test_cpu_shm.py diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 86578b22db..80dfde3717 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -149,7 +149,7 @@ def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache: wrappers: KVCache = [] try: for tensor in kv_caches.values(): - wrappers.append(_wrap_one_kv_cache(tensor)) + wrappers.append(wrap_one_kv_cache(tensor)) except BaseException: _release_partial_kv_wrappers(wrappers) raise @@ -165,7 +165,7 @@ def _release_partial_kv_wrappers(wrappers: list[Any]) -> None: are silently skipped. """ # First Party - from lmcache.v1.platform.cpu.shm import shm_unlink + from lmcache.v1.multiprocess.posix_shm import shm_unlink for w in wrappers: name = getattr(w, "shm_name", None) @@ -177,7 +177,7 @@ def _release_partial_kv_wrappers(wrappers: list[Any]) -> None: logger.debug("shm_unlink failed during rollback", exc_info=True) -def _wrap_one_kv_cache(tensor: torch.Tensor) -> Any: +def wrap_one_kv_cache(tensor: torch.Tensor) -> Any: """Dispatch by ``tensor.device.type`` via the platform registry. Concrete factories self-register at import time (CUDA in diff --git a/lmcache/v1/multiprocess/posix_shm.py b/lmcache/v1/multiprocess/posix_shm.py new file mode 100644 index 0000000000..9445021666 --- /dev/null +++ b/lmcache/v1/multiprocess/posix_shm.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared-memory primitives shared by SHM-based transports. + +Thin POSIX-SHM facade exposing the legacy +``shm_create_readwrite`` / ``shm_map_readwrite`` / ``shm_munmap`` / +``shm_unlink`` / ``shm_open_pool_as_mmap`` quartet, so the CPU +KV-cache wrapper, the MP non-GPU SHM transport, and the existing +tests keep working unchanged. + +We deliberately route through CPython's bundled ``_posixshmem`` C +extension (used internally by :mod:`multiprocessing.shared_memory`) +rather than the higher-level :class:`SharedMemory` wrapper. The +wrapper keeps an internal ``memoryview`` over its own ``mmap``; +when callers also export a buffer (via +``ctypes.c_uint8.from_buffer(shm.buf)`` / ``torch.frombuffer(...)``), +:meth:`SharedMemory.close` invoked from ``__del__`` at interpreter +shutdown raises ``BufferError: cannot close exported pointers +exist``. Owning the ``mmap`` ourselves and pairing every alloc with +an explicit ``shm_munmap`` keeps shutdown silent on macOS and Linux +alike. + +The previous hand-rolled libc/librt implementation tripped over +macOS' shm_open MAC label propagation when certain native +extensions (torch + a few others) were already loaded in the +parent process, producing spurious ``errno=13 / EACCES`` failures +on the child side. Routing through ``_posixshmem.shm_open`` -- the +same underlying entry point CPython's stdlib uses -- fixes that +and is identical on Linux. +""" + +# Future +from __future__ import annotations + +# Standard +import atexit +import ctypes +import logging +import mmap as _mmap +import os +import threading + +# Third Party +import _posixshmem # type: ignore[import-not-found] + +logger = logging.getLogger(__name__) + + +def _strip_leading_slash(name: str) -> str: + """Normalise a name to the bare form (no leading ``/``). + + Callers historically embed the POSIX leading slash; we keep the + on-wire name slash-prefixed but feed the bare form to + ``_posixshmem.shm_open``-derived helpers that prepend it again. + """ + return name[1:] if name.startswith("/") else name + + +def _slashed(name: str) -> str: + """Inverse of :func:`_strip_leading_slash` for shm_open calls.""" + return name if name.startswith("/") else "/" + name + + +# Per-process registry mapping the public ``int`` address back to the +# ``mmap`` object that owns the mapping, so a later ``shm_munmap`` can +# call ``mmap.close()`` exactly once and avoid leaking pages. Owners +# (creators) also remember the name so ``shm_unlink`` can find it +# without a re-open round-trip. +_REGISTRY_LOCK = threading.Lock() +_ADDR_TO_MMAP: dict[int, _mmap.mmap] = {} +_OWNED_NAMES: set[str] = set() + + +def _open_and_mmap(name: str, nbytes: int, *, create: bool) -> tuple[_mmap.mmap, int]: + """Open (or create) a POSIX SHM segment and ``mmap`` it. + + Returns a ``(mmap_obj, base_addr)`` pair. The fd is always closed + before returning so we don't leak descriptors; the kernel keeps + the mapping alive as long as ``mmap_obj`` stays alive. + """ + flags = os.O_RDWR | (os.O_CREAT | os.O_EXCL if create else 0) + fd = _posixshmem.shm_open(_slashed(name), flags, mode=0o600) + mm: _mmap.mmap | None = None + try: + if create: + os.ftruncate(fd, nbytes) + mm = _mmap.mmap(fd, nbytes, access=_mmap.ACCESS_WRITE) + addr = _addr_of_mmap(mm) + except BaseException: + if mm is not None: + mm.close() + if create: + try: + _posixshmem.shm_unlink(_slashed(name)) + except OSError: + logger.warning( + "shm_unlink failed during cleanup of %s", + name, + exc_info=True, + ) + raise + finally: + os.close(fd) + return mm, addr + + +def _addr_of_mmap(mm: _mmap.mmap) -> int: + """Return the base address of an ``mmap`` without leaking a buffer view. + + A single-byte ctypes view is created just long enough to read the + base address, then dropped before this function returns; once it is + out of scope the mmap has no exported pointers, so a later + ``mm.close()`` can complete cleanly. A 1-byte view is sufficient + -- ``ctypes.addressof`` returns the start of the buffer regardless + of its declared length. + """ + view = (ctypes.c_uint8 * 1).from_buffer(mm) + addr = ctypes.addressof(view) + del view + return addr + + +def shm_create_readwrite(name: str, nbytes: int) -> int: + """Create a new shared-memory segment and return its mmap address. + + Mirrors the previous ``shm_open(O_CREAT|O_EXCL) + ftruncate + + mmap`` sequence: collisions raise ``OSError`` (``FileExistsError`` + is a subclass), and a failure mid-way fully tears down what was + allocated. + + Args: + name: The name of the shared-memory segment. + nbytes: The size of the segment in bytes. + + Returns: + The virtual address of the mapped segment. + + Raises: + OSError: If the segment already exists or creation fails. + """ + sm_name = _strip_leading_slash(name) + mm, addr = _open_and_mmap(sm_name, nbytes, create=True) + with _REGISTRY_LOCK: + _ADDR_TO_MMAP[addr] = mm + _OWNED_NAMES.add(sm_name) + return addr + + +def shm_map_readwrite(name: str, nbytes: int) -> int: + """Open an existing shared-memory segment and return its address. + + ``nbytes`` must match the segment's actual size; ``mmap`` will + raise on a mismatch. + + Args: + name: The name of the shared-memory segment. + nbytes: The size of the segment in bytes. + + Returns: + The virtual address of the mapped segment. + + Raises: + OSError: If the segment cannot be opened or mapped. + """ + sm_name = _strip_leading_slash(name) + mm, addr = _open_and_mmap(sm_name, nbytes, create=False) + with _REGISTRY_LOCK: + _ADDR_TO_MMAP[addr] = mm + return addr + + +def shm_munmap(addr: int, nbytes: int = 0) -> None: + """Best-effort release of a previously mapped segment by address. + + The underlying mmap is closed exactly once; subsequent calls with + the same address are no-ops. + + Args: + addr: The virtual address of the mapped segment. + nbytes: Unused; kept for API compatibility so callers that + already pass the size do not need to be updated. + """ + if not addr: + return + with _REGISTRY_LOCK: + mm = _ADDR_TO_MMAP.pop(addr, None) + if mm is None: + return + try: + mm.close() + except (BufferError, ValueError) as exc: + # ``BufferError`` means callers still hold an exported view + # (e.g. a torch tensor backed by this mmap); they will release + # the mapping themselves on GC. ``ValueError`` means already + # closed -- treat both as best-effort no-ops. + logger.warning( + "shm_munmap: mmap.close() skipped for addr=%#x: %s", + addr, + exc, + ) + + +def shm_unlink(name: str) -> None: + """Best-effort segment removal. + + Idempotent: a missing segment is treated as a successful + no-op so callers can blindly call this on shutdown. + + Args: + name: The name of the shared-memory segment to unlink. + """ + sm_name = _strip_leading_slash(name) + with _REGISTRY_LOCK: + _OWNED_NAMES.discard(sm_name) + try: + _posixshmem.shm_unlink(_slashed(sm_name)) + except FileNotFoundError: + logger.debug("shm_unlink: segment %s already removed", sm_name) + except OSError: + # Mirrors the historical "best effort" contract -- e.g. + # double-unlink on shutdown should never raise. + logger.warning( + "shm_unlink: failed to unlink %s", + sm_name, + exc_info=True, + ) + + +def _atexit_cleanup() -> None: + """Unlink and munmap any SHM segments still owned by this process.""" + with _REGISTRY_LOCK: + names = list(_OWNED_NAMES) + mmaps = list(_ADDR_TO_MMAP.values()) + _ADDR_TO_MMAP.clear() + _OWNED_NAMES.clear() + for mm in mmaps: + try: + mm.close() + except (BufferError, OSError) as exc: + logger.warning("atexit: mmap.close() failed: %s", exc) + for n in names: + try: + _posixshmem.shm_unlink(_slashed(n)) + except OSError as exc: + logger.warning("atexit: shm_unlink(%s) failed: %s", n, exc) + + +atexit.register(_atexit_cleanup) + + +def shm_open_pool_as_mmap(name: str, nbytes: int) -> _mmap.mmap: + """Open an existing segment as an independent ``mmap.mmap`` object. + + Convenience helper for non-GPU SHM transports that consume the + segment via ``torch.frombuffer(mmap_obj, ...)`` rather than a raw + address. The returned mmap is independent of any registry entry, + so the caller takes ownership and is responsible for closing it. + + Args: + name: The name of the shared-memory segment. + nbytes: The size of the segment in bytes. + + Returns: + An independent ``mmap.mmap`` object backed by the segment. + + Raises: + OSError: If the segment cannot be opened or mapped. + """ + sm_name = _strip_leading_slash(name) + fd = _posixshmem.shm_open(_slashed(sm_name), os.O_RDWR, mode=0o600) + try: + return _mmap.mmap(fd, nbytes, access=_mmap.ACCESS_WRITE) + finally: + os.close(fd) diff --git a/lmcache/v1/platform/cpu/__init__.py b/lmcache/v1/platform/cpu/__init__.py index ea5769408b..bb2edd378a 100644 --- a/lmcache/v1/platform/cpu/__init__.py +++ b/lmcache/v1/platform/cpu/__init__.py @@ -1,7 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 """CPU-specific platform primitives. -This package will register a CPU KV-cache wrapper factory with -:mod:`lmcache.v1.platform._registry` once the POSIX-SHM backend -is available. +Importing this package self-registers the POSIX-SHM KV-cache wrapper +factory with :mod:`lmcache.v1.platform._registry`, so the dispatch +in :mod:`lmcache.integration.vllm.vllm_multi_process_adapter` can +pick the right wrapper based on ``tensor.device.type`` without any +if/elif chain. """ + +# Standard +from typing import Any + +# Third Party +import torch + +# First Party +from lmcache.v1.platform._registry import register_kv_wrapper + + +def _kv_wrapper_factory(tensor: torch.Tensor) -> Any: + """Indirect-dispatch wrapper. + + Defers loading :mod:`lmcache.v1.platform.cpu.shm` (which pulls in + ``multiprocess.custom_types``) until first use, so importing this + package during ``lmcache/__init__.py``'s bootstrap does not race + other imports that touch ``torch_dev``. + """ + # First Party + from lmcache.v1.platform.cpu.shm import migrate_to_shm_and_wrap + + return migrate_to_shm_and_wrap(tensor) + + +register_kv_wrapper("cpu", _kv_wrapper_factory) diff --git a/lmcache/v1/platform/cpu/shm.py b/lmcache/v1/platform/cpu/shm.py new file mode 100644 index 0000000000..7bbb702d5b --- /dev/null +++ b/lmcache/v1/platform/cpu/shm.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: Apache-2.0 +"""CPU-only KV-cache IPC wrapper backed by POSIX shared memory. + +Mirrors the GPU-mode CUDA-IPC zero-copy semantics for hosts without an +accelerator: client and LMCache mp server map the **same** physical +pages so transfers are pointer-shuffles rather than memcpys. + +Self-registers a ``"cpu"`` factory with +:mod:`lmcache.v1.platform._registry` at import time, so the +multiprocess adapter can dispatch by ``tensor.device.type`` without +any if/elif chain. +""" + +# Future +from __future__ import annotations + +# Standard +import ctypes +import itertools +import os +import threading +import weakref + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.multiprocess.custom_types import CudaIPCWrapper +from lmcache.v1.multiprocess.posix_shm import ( + shm_create_readwrite, + shm_map_readwrite, + shm_munmap, + shm_unlink, +) + +logger = init_logger(__name__) + +# Re-export POSIX-SHM primitives so existing callers keep working. +# The canonical home is :mod:`lmcache.v1.multiprocess.posix_shm`; new +# code (e.g. the MP non-GPU SHM transport) should import from there. +__all__ = [ + "CpuShmTensorWrapper", + "inject_stale_cache_entry_for_test", + "migrate_to_shm_and_wrap", + "shm_create_readwrite", + "shm_map_readwrite", + "shm_munmap", + "shm_unlink", +] + +# --------------------------------------------------------------------------- +# Wrapper class # +# --------------------------------------------------------------------------- + + +class CpuShmTensorWrapper(CudaIPCWrapper): + """IPC wrapper for CPU tensors backed by POSIX shared memory. + + Used by the ``lmcache bench kvcache --mode cpu`` path and the + vLLM CPU integration so that the client and the LMCache mp server + map the **same** physical pages for the KV cache, mirroring the + GPU-mode CUDA-IPC zero-copy semantics. + + Subclassing :class:`CudaIPCWrapper` is load-bearing for the same + reason :class:`RawCudaIPCWrapper` does it: msgspec does not + support unions of custom ext-encoded types, so all wire-level + KV-cache wrappers must share the single ext code (1) registered + for ``CudaIPCWrapper``. Pickle preserves the subclass identity + so ``to_tensor`` dispatches correctly on both sides. + """ + + # POSIX shared-memory name (``/lmcache_...``) -- leading ``/`` is + # required by ``shm_open(3)`` on both Linux and macOS. + SHM_NAME_PREFIX = "/lmcache_kv_" + + def __init__(self, tensor: torch.Tensor, shm_name: str) -> None: + if tensor.device.type != "cpu": + raise ValueError( + "CpuShmTensorWrapper requires a CPU tensor, got %s" % tensor.device + ) + if not tensor.is_contiguous(): + raise ValueError("CpuShmTensorWrapper requires a contiguous tensor") + + self.shm_name = shm_name + # ``numel * element_size`` is the correct logical byte size; the + # underlying storage may be larger when the tensor is a view. + self.nbytes = tensor.numel() * tensor.element_size() + + # CudaIPCWrapper interface fields. ``handle`` / ``device_uuid`` + # are unused on the CPU path but kept to satisfy the parent + # contract used by equality checks. + self.handle = None + self.dtype = tensor.dtype + self.shape = tuple(tensor.shape) + self.stride = tuple(tensor.stride()) + self.storage_offset = int(tensor.storage_offset()) + self.device_uuid = "cpu" + + def to_tensor(self) -> torch.Tensor: + """Reconstruct the tensor by mapping the same SHM segment. + + The returned tensor owns the mmap: a ``weakref.finalize`` hook + runs ``munmap`` once the tensor (and any views derived from it) + is garbage-collected, so the per-process virtual address space + does not leak across repeated ``to_tensor`` calls. + + We rebuild the view through ``as_strided`` so the original + memory layout (stride / storage_offset / memory_format) is + replayed faithfully on the receiving side; reshape would + silently re-coalesce strides and lose, e.g., channels_last. + """ + # Empty tensors carry no SHM segment (mmap with length 0 is + # undefined / EINVAL on POSIX); rebuild the empty view in-process. + if self.nbytes == 0: + return torch.empty(self.shape, dtype=self.dtype) + addr = shm_map_readwrite(self.shm_name, self.nbytes) + # ``torch.frombuffer`` requires a writable buffer; build one + # via ctypes so the resulting torch tensor shares storage + # with the SHM mapping (zero copy across processes). + buf_type = ctypes.c_uint8 * self.nbytes + buf = buf_type.from_address(addr) + flat = torch.frombuffer(buf, dtype=torch.uint8) + typed = flat.view(self.dtype) + out = torch.as_strided(typed, self.shape, self.stride, self.storage_offset) + # Keep ``flat`` alive for the lifetime of ``out`` so its mmap + # is not released while still in use, then munmap on cleanup. + out._lmcache_shm_buf = flat # type: ignore[attr-defined] + weakref.finalize(out, shm_munmap, addr, self.nbytes) + return out + + +# --------------------------------------------------------------------------- +# Migrate-and-wrap factory (used by the multiprocess adapter) # +# --------------------------------------------------------------------------- + +# Per-process registry of SHM segments we have created, so the same +# tensor object is only migrated to SHM once even if the factory is +# called multiple times. +# +# Keyed by ``id(tensor)`` for cheap O(1) lookup, but each entry also +# holds a ``weakref.ref`` to the original tensor and we *verify the +# referent is still that exact object* before reusing the cached SHM +# name. CPython recycles object IDs, so a fresh tensor allocated at +# the same address as a previously migrated (now garbage-collected) +# one would otherwise inherit a stale name -- and because +# :func:`shm_create_readwrite` uses ``O_EXCL``, the next migration +# would crash with ``EEXIST`` ("File exists"). The weakref-validated +# lookup below makes that race impossible: a stale entry can only +# point at a dead referent, which we treat as a miss. +_CPU_SHM_NAMES: dict[int, tuple["weakref.ReferenceType[torch.Tensor]", str]] = {} +_CPU_SHM_LOCK = threading.Lock() +_CPU_SHM_COUNTER = itertools.count() + + +def _cleanup_shm_segment(tid: int, shm_name: str, addr: int, nbytes: int) -> None: + """Release the mmap, unlink, and forget the cached SHM name.""" + with _CPU_SHM_LOCK: + # Only drop the entry if it still points at *this* segment; + # a future tensor reusing ``tid`` may already have replaced it. + cached = _CPU_SHM_NAMES.get(tid) + if cached is not None and cached[1] == shm_name: + _CPU_SHM_NAMES.pop(tid, None) + shm_munmap(addr, nbytes) + shm_unlink(shm_name) + + +def migrate_to_shm_and_wrap(tensor: torch.Tensor) -> CpuShmTensorWrapper: + """Re-point ``tensor``'s storage at a POSIX SHM segment, then wrap. + + Used as the registered ``"cpu"`` KV-wrapper factory: the LMCache mp + server can mmap the same physical pages on the receiving side. + Idempotent per tensor identity (validated via a stored weakref so + Python's id-recycling cannot produce a stale-name hit). The SHM + segment is released (``munmap`` + ``shm_unlink``) automatically + when the migrated tensor is garbage-collected. + """ + # First Party + from lmcache.v1.gpu_connector.utils import attempt_permute_to_contiguous_view + + # Validate and normalise the tensor *before* touching the registry + # or mutating storage, so a bad input never leaves things half-done. + tensor = attempt_permute_to_contiguous_view(tensor) + if tensor.device.type != "cpu": + raise ValueError( + "migrate_to_shm_and_wrap requires a CPU tensor, got %s" % tensor.device + ) + if not tensor.is_contiguous(): + raise ValueError("migrate_to_shm_and_wrap requires a contiguous tensor") + + tid = id(tensor) + + # Fast path: check the registry under the lock, return early if the + # tensor has already been migrated. + with _CPU_SHM_LOCK: + cached = _CPU_SHM_NAMES.get(tid) + if cached is not None: + ref, cached_name = cached + if ref() is tensor: + return CpuShmTensorWrapper(tensor, cached_name) + # Stale entry from a GC'd tensor whose id has been + # reused; drop it and fall through to allocate fresh. + _CPU_SHM_NAMES.pop(tid, None) + + nbytes = tensor.numel() * tensor.element_size() + assert tensor.storage_offset() == 0, ( + "migrate_to_shm_and_wrap: SHM segment is sized to " + "numel*elem_size; a nonzero storage_offset would cause " + "OOB access. Got offset=%d" % tensor.storage_offset() + ) + if nbytes == 0: + # No SHM segment for empty tensors: ``mmap`` with length 0 + # is undefined / EINVAL on POSIX. ``to_tensor`` rebuilds an + # empty view directly when ``shm_name`` is empty. + return CpuShmTensorWrapper(tensor, "") + + shm_name = "%s%d_%d" % ( + CpuShmTensorWrapper.SHM_NAME_PREFIX, + os.getpid(), + next(_CPU_SHM_COUNTER), + ) + # Perform the heavy work (syscall + tensor mutation) outside the lock + # to keep the critical section small. + addr = shm_create_readwrite(shm_name, nbytes) + try: + buf_type = ctypes.c_uint8 * nbytes + buf = buf_type.from_address(addr) + shm_storage = torch.frombuffer(buf, dtype=torch.uint8).untyped_storage() + tensor.set_( + shm_storage, + tensor.storage_offset(), + tensor.shape, + tensor.stride(), + ) + except Exception: + # Make sure the SHM resources don't leak if migration fails + # part-way (e.g. ``set_`` rejects an unusual stride). + shm_munmap(addr, nbytes) + shm_unlink(shm_name) + raise + + with _CPU_SHM_LOCK: + _CPU_SHM_NAMES[tid] = (weakref.ref(tensor), shm_name) + weakref.finalize(tensor, _cleanup_shm_segment, tid, shm_name, addr, nbytes) + logger.info( + "Migrated CPU KV cache tensor (nbytes=%d) to SHM %s", + nbytes, + shm_name, + ) + return CpuShmTensorWrapper(tensor, shm_name) + + +def inject_stale_cache_entry_for_test( + tensor: torch.Tensor, + dead_ref: "weakref.ReferenceType[torch.Tensor]", + stale_shm_name: str, +) -> None: + """Test-only hook: pre-seed the registry with a stale entry. + + Lets unit tests reproduce the CPython id-reuse race -- where a + fresh tensor lands on the same id as a previously migrated and + garbage-collected one -- without the per-test global-state + surgery that would otherwise have to reach into the module's + private dict / lock. + """ + with _CPU_SHM_LOCK: + _CPU_SHM_NAMES[id(tensor)] = (dead_ref, stale_shm_name) diff --git a/tests/v1/multiprocess/test_posix_shm.py b/tests/v1/multiprocess/test_posix_shm.py new file mode 100644 index 0000000000..a04497df50 --- /dev/null +++ b/tests/v1/multiprocess/test_posix_shm.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for ``lmcache.v1.multiprocess.posix_shm``. + +Validates the POSIX-SHM primitives and the ``mmap``-based pool helper +shared by SHM-based transports. +""" + +# Standard +import os + +# Third Party +import pytest + +# First Party +from lmcache.v1.multiprocess.posix_shm import ( + shm_create_readwrite, + shm_map_readwrite, + shm_munmap, + shm_open_pool_as_mmap, + shm_unlink, +) + + +def _unique_name(tag: str) -> str: + # macOS shm_open caps names at 31 bytes incl. leading '/'. + return "/lmc_pshm_%s_%d" % (tag, os.getpid()) + + +def test_create_map_munmap_unlink_roundtrip(): + name = _unique_name("rt") + addr = shm_create_readwrite(name, 4096) + try: + assert addr not in (0, None) + # Map again from a fresh address: same segment, different vaddr. + addr2 = shm_map_readwrite(name, 4096) + try: + assert addr2 not in (0, None) + finally: + shm_munmap(addr2, 4096) + finally: + shm_munmap(addr, 4096) + shm_unlink(name) + + +def test_create_excl_collision(): + name = _unique_name("excl") + addr = shm_create_readwrite(name, 4096) + try: + with pytest.raises(OSError): + shm_create_readwrite(name, 4096) + finally: + shm_munmap(addr, 4096) + shm_unlink(name) + + +def test_open_pool_as_mmap_zero_copy_view(): + name = _unique_name("pool") + nbytes = 4096 + addr = shm_create_readwrite(name, nbytes) + try: + mm = shm_open_pool_as_mmap(name, nbytes) + try: + mm[0:4] = b"\x01\x02\x03\x04" + mm2 = shm_open_pool_as_mmap(name, nbytes) + try: + assert bytes(mm2[0:4]) == b"\x01\x02\x03\x04" + finally: + mm2.close() + finally: + mm.close() + finally: + shm_munmap(addr, nbytes) + shm_unlink(name) + + +def test_munmap_no_op_on_zero_addr(): + # Should not crash; best-effort no-op. + shm_munmap(0, 4096) diff --git a/tests/v1/platform/__init__.py b/tests/v1/platform/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/tests/v1/platform/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/v1/platform/test_cpu_shm.py b/tests/v1/platform/test_cpu_shm.py new file mode 100644 index 0000000000..65a52cedae --- /dev/null +++ b/tests/v1/platform/test_cpu_shm.py @@ -0,0 +1,234 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for ``lmcache.v1.platform.cpu.shm``. + +Validates that the POSIX-SHM-backed wrapper can round-trip a CPU +tensor in-process: the constructed wrapper's ``to_tensor()`` view +sees writes made through the original tensor. +""" + +# Standard +import os + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.multiprocess.posix_shm import shm_unlink +from lmcache.v1.platform.cpu.shm import ( + CpuShmTensorWrapper, + migrate_to_shm_and_wrap, + shm_create_readwrite, +) + + +def test_shm_create_unlink_roundtrip(): + """``shm_create_readwrite`` succeeds and ``shm_unlink`` cleans up.""" + name = "/lmcache_test_%d" % os.getpid() + addr = shm_create_readwrite(name, 4096) + try: + assert addr not in (0, None) + finally: + shm_unlink(name) + + +def test_migrate_to_shm_and_wrap_zero_copy_view(): + """After migrate, writes via the original tensor are visible via wrapper.""" + src = torch.zeros((2, 4, 4), dtype=torch.float32) + wrapper = migrate_to_shm_and_wrap(src) + try: + assert isinstance(wrapper, CpuShmTensorWrapper) + assert wrapper.shape == (2, 4, 4) + assert wrapper.dtype == torch.float32 + # Mutate via the migrated source tensor; its storage is now the + # SHM segment, so the wrapper's reconstructed view must see it. + src.add_(7.0) + view = wrapper.to_tensor() + assert torch.equal(view, src) + finally: + shm_unlink(wrapper.shm_name) + + +def test_migrate_handles_empty_tensor(): + """Empty tensors must not call ``mmap`` (length 0 is EINVAL). + + Regression for the case where ``nbytes == 0``: the wrapper carries + an empty ``shm_name`` and ``to_tensor`` rebuilds the empty view in + process without touching POSIX shared memory. + """ + src = torch.empty((0, 4), dtype=torch.float32) + wrapper = migrate_to_shm_and_wrap(src) + assert isinstance(wrapper, CpuShmTensorWrapper) + assert wrapper.shm_name == "" + assert wrapper.nbytes == 0 + view = wrapper.to_tensor() + assert view.shape == (0, 4) + assert view.dtype == torch.float32 + + +def test_migrate_is_idempotent_on_same_tensor(): + """Re-wrapping the same tensor reuses the existing SHM segment.""" + src = torch.zeros((3, 5), dtype=torch.float32) + w1 = migrate_to_shm_and_wrap(src) + try: + w2 = migrate_to_shm_and_wrap(src) + assert w1.shm_name == w2.shm_name + finally: + shm_unlink(w1.shm_name) + + +def test_rejects_non_cpu_tensor(): + """Construction rejects tensors that are not on CPU.""" + if not torch.backends.mps.is_available(): + pytest.skip("MPS not available; cannot synthesize a non-cpu tensor") + src = torch.zeros((2, 2), device="mps") + with pytest.raises(ValueError, match="CPU tensor"): + CpuShmTensorWrapper(src, "/lmcache_test_should_not_exist") + + +def test_migrate_finalizer_unlinks_on_gc(): + """Once the migrated tensor is GC-ed, its SHM segment is unlinked.""" + # Standard + import gc + + # First Party + from lmcache.v1.platform.cpu.shm import shm_map_readwrite + + src = torch.zeros((2, 2), dtype=torch.float32) + w = migrate_to_shm_and_wrap(src) + name = w.shm_name + nbytes = w.nbytes + # Drop both references; the weakref.finalize hook should unlink. + del src, w + gc.collect() + with pytest.raises(OSError): + shm_map_readwrite(name, nbytes) + + +def test_shm_create_cleans_up_on_existing_name(): + """If ``shm_open(O_EXCL)`` fails the helper must not leave the fd open. + + We exercise the failure path by creating a segment, then asking + ``shm_create_readwrite`` to recreate the same name -- it must + raise without leaking the file descriptor it briefly held. + """ + name = "/lmcache_test_excl_%d" % os.getpid() + addr = shm_create_readwrite(name, 4096) + try: + with pytest.raises(OSError): + shm_create_readwrite(name, 4096) + finally: + shm_unlink(name) + # And after unlink, the name is reusable again. + addr2 = shm_create_readwrite(name, 4096) + assert addr2 not in (0, None) + shm_unlink(name) + _ = addr # silence unused-variable hint + + +def test_to_tensor_view_carries_munmap_finalizer(): + """``to_tensor`` returns a tensor that releases its mmap on GC.""" + # Standard + import gc + import weakref + + src = torch.zeros((2, 2), dtype=torch.float32) + w = migrate_to_shm_and_wrap(src) + try: + view = w.to_tensor() + # The view must keep ``flat`` alive so its mmap stays valid. + assert hasattr(view, "_lmcache_shm_buf") + ref = weakref.ref(view) + del view + gc.collect() + assert ref() is None + finally: + del src + gc.collect() + shm_unlink(w.shm_name) + + +def test_to_tensor_replays_stride_and_storage_offset(): + """``to_tensor`` rebuilds the view via stride+offset (not reshape).""" + src = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4).contiguous() + w = migrate_to_shm_and_wrap(src) + try: + view = w.to_tensor() + assert tuple(view.stride()) == w.stride + assert int(view.storage_offset()) == w.storage_offset + assert torch.equal(view, src) + finally: + del src, view + shm_unlink(w.shm_name) + + +def test_wrap_kv_caches_unlinks_partial_batch_on_failure(monkeypatch): + """If wrapping the N-th tensor raises, earlier SHM names are unlinked. + + Drives :func:`wrap_kv_caches` with two CPU tensors and forces the + second factory call to raise; the first iteration's SHM segment + must be ``shm_unlink``-ed so the named segment does not outlive + the failed batch. + """ + # First Party + from lmcache.integration.vllm import vllm_multi_process_adapter as adapter + from lmcache.v1.platform.cpu.shm import shm_map_readwrite + + real_wrap = adapter.wrap_one_kv_cache + state = {"n": 0, "first_name": None} + + def flaky_wrap(tensor): + state["n"] += 1 + if state["n"] == 2: + raise RuntimeError("simulated migration failure") + w = real_wrap(tensor) + state["first_name"] = w.shm_name + return w + + monkeypatch.setattr(adapter, "wrap_one_kv_cache", flaky_wrap) + + t1 = torch.zeros((2, 2), dtype=torch.float32) + t2 = torch.zeros((2, 2), dtype=torch.float32) + with pytest.raises(RuntimeError, match="simulated migration failure"): + adapter.wrap_kv_caches({"a": t1, "b": t2}) + + # The first iteration's SHM segment must no longer be openable. + nbytes = t1.numel() * t1.element_size() + with pytest.raises(OSError): + shm_map_readwrite(state["first_name"], nbytes) + + +def test_migrate_ignores_stale_entry_from_id_reuse(): + """A cached entry whose weakref is dead must not be reused. + + Simulates CPython recycling an object id by injecting a stale + ``(dead_ref, old_name)`` tuple keyed by the live tensor's id, + then calling :func:`migrate_to_shm_and_wrap`. The factory must + treat the dead entry as a miss and allocate a fresh SHM segment + -- if it blindly reused the cached name, ``shm_create_readwrite`` + would crash with ``EEXIST`` (and even worse, the fresh tensor + would be silently bound to the wrong SHM name). + """ + # Standard + import gc + import weakref as _wr + + # First Party + from lmcache.v1.platform.cpu.shm import inject_stale_cache_entry_for_test + + # Build a tensor we will let die so we have a guaranteed-dead ref. + ghost = torch.zeros((1,), dtype=torch.float32) + dead_ref = _wr.ref(ghost) + del ghost + gc.collect() + assert dead_ref() is None + + live = torch.zeros((2, 2), dtype=torch.float32) + stale_name = "/lmcache_test_stale_%d" % os.getpid() + inject_stale_cache_entry_for_test(live, dead_ref, stale_name) + + w = migrate_to_shm_and_wrap(live) + try: + assert w.shm_name != stale_name + finally: + shm_unlink(w.shm_name) From 954abb4aea8397589544e5f72f21116b98d712f0 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Mon, 8 Jun 2026 15:36:34 +0800 Subject: [PATCH 02/57] [Refactor]: Normalize flat/nested block_ids in flat_block_ids and connector __str__ (#3577) Normalize flat/nested block_ids in flat_block_ids and connector __str__ Older vLLM connectors emit a flat list[int] for the single non-hybrid group, while newer ones use nested list[list[int]]. Make flat_block_ids and the three LMCacheMPConnectorMetadata.__str__ paths tolerate both, matching the normalization already done in expand_block_ids_to_views(). Signed-off-by: Tony Lin --- lmcache/integration/vllm/lmcache_mp_connector.py | 2 +- .../integration/vllm/lmcache_mp_connector_0180.py | 2 +- .../integration/vllm/lmcache_mp_connector_0201.py | 2 +- .../integration/vllm/vllm_multi_process_adapter.py | 14 +++++++++++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 27fa2de489..012e96fd2c 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -510,7 +510,7 @@ def __str__(self): request_strs.append( f"RequestMetadata(request_id={req_meta.request_id}, " f"direction={req_meta.direction}, " - f"num_blocks={len(req_meta.op.block_ids[0])}, " + f"num_blocks={len(req_meta.op.flat_block_ids)}, " f"block_ids={req_meta.op.block_ids})" ) return "[" + "\n".join(request_strs) + "]" diff --git a/lmcache/integration/vllm/lmcache_mp_connector_0180.py b/lmcache/integration/vllm/lmcache_mp_connector_0180.py index 61fe117a87..7fa46db945 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector_0180.py +++ b/lmcache/integration/vllm/lmcache_mp_connector_0180.py @@ -433,7 +433,7 @@ def __str__(self): request_strs.append( f"RequestMetadata(request_id={req_meta.request_id}, " f"direction={req_meta.direction}, " - f"num_blocks={len(req_meta.op)}, " + f"num_blocks={len(req_meta.op.flat_block_ids)}, " f"block_ids={req_meta.op.block_ids})" ) return "[" + "\n".join(request_strs) + "]" diff --git a/lmcache/integration/vllm/lmcache_mp_connector_0201.py b/lmcache/integration/vllm/lmcache_mp_connector_0201.py index 1adc873587..6db28f5412 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector_0201.py +++ b/lmcache/integration/vllm/lmcache_mp_connector_0201.py @@ -454,7 +454,7 @@ def __str__(self): request_strs.append( f"RequestMetadata(request_id={req_meta.request_id}, " f"direction={req_meta.direction}, " - f"num_blocks={len(req_meta.op)}, " + f"num_blocks={len(req_meta.op.flat_block_ids)}, " f"block_ids={req_meta.op.block_ids})" ) return "[" + "\n".join(request_strs) + "]" diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 80dfde3717..c2fcb083a9 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -475,7 +475,19 @@ class LoadStoreOp: @property def flat_block_ids(self) -> list[int]: - """Return all block IDs flattened for group-blind error paths.""" + """Return all block IDs flattened for group-blind error paths. + + Handles both the normal ``list[list[int]]`` format and the + IPC-flattened ``list[int]`` format that vLLM v0.19.0 produces when + ``SchedulerOutput`` serializes single-element nested lists across + process boundaries (e.g. ``[[20, 21]]`` → ``[20, 21]``). + Returns an empty list when ``block_ids`` is empty. + """ + if not self.block_ids: + return [] + # Defend against IPC serialization flattening [[20, 21, …]] → [20, 21, …] + if isinstance(self.block_ids[0], int): + return list(self.block_ids) return [ block_id for group_block_ids in self.block_ids From 3a45d0f2fadaa590d4dc450d60d6d93be44691c2 Mon Sep 17 00:00:00 2001 From: feixiangpeng <155504520+feixiangpeng@users.noreply.github.com> Date: Mon, 8 Jun 2026 13:28:45 -0500 Subject: [PATCH 03/57] Added HFbucket MP (#3263) Signed-off-by: feixiangpeng <155504520+feixiangpeng@users.noreply.github.com> --- .../kv_cache/storage_backends/hfbucket.rst | 84 ++ docs/source/mp/l2_storage.rst | 59 ++ .../l2_adapters/hfbucket_l2_adapter.py | 898 ++++++++++++++++++ .../distributed/test_hfbucket_l2_adapter.py | 570 +++++++++++ 4 files changed, 1611 insertions(+) create mode 100644 lmcache/v1/distributed/l2_adapters/hfbucket_l2_adapter.py create mode 100644 tests/v1/distributed/test_hfbucket_l2_adapter.py diff --git a/docs/source/kv_cache/storage_backends/hfbucket.rst b/docs/source/kv_cache/storage_backends/hfbucket.rst index cd05342b51..2f5dc46da4 100644 --- a/docs/source/kv_cache/storage_backends/hfbucket.rst +++ b/docs/source/kv_cache/storage_backends/hfbucket.rst @@ -91,6 +91,90 @@ either ``hfbucket`` or an instance-qualified name such as ``hfbucket.prod``. existence and size metadata. +MP Mode Configuration +--------------------- + +In multi-process (MP) mode, Hugging Face Buckets are configured as an L2 +adapter through a JSON spec passed to the LMCache server. This is separate from +the non-MP ``remote_storage_plugins`` configuration above. Each +``--l2-adapter`` argument takes a JSON object whose ``"type": "hfbucket"`` +field selects the HFBucket adapter. + +.. code-block:: json + + { + "type": "hfbucket", + "bucket_handle": "hf://buckets/my-org/lmcache-kv/prod", + "token_env": "HF_TOKEN", + "create_bucket_if_missing": false, + "download_tmp_dir": "/tmp/lmcache-hfbucket-mp", + "metadata_cache_ttl_secs": 30, + "num_workers": 4, + "max_capacity_gb": 500, + "eviction": { + "eviction_policy": "LRU", + "trigger_watermark": 0.85, + "eviction_ratio": 0.2 + } + } + +HFBucket L2 Adapter Fields +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* **type** (required): must be ``"hfbucket"``. +* **bucket_handle** (required): Hugging Face Bucket handle in + ``hf://buckets//[/]`` format. +* **token_env**: environment variable used to resolve the Hugging Face access + token (default ``"HF_TOKEN"``). +* **token**: optional direct token fallback. ``token_env`` takes precedence + when the environment variable is set. Prefer ``token_env`` for production + deployments so secrets do not live in adapter JSON. +* **create_bucket_if_missing**: lazily create the bucket on the first store + operation (default ``false``). This only helps when the bucket is missing and + the token has permission to create it; it does not fix invalid credentials, + invalid handles, or network failures. +* **download_tmp_dir**: root directory for temporary load downloads (default + ``/tmp/lmcache-hfbucket-mp``). The MP adapter downloads bucket files into + per-task temporary files and then copies their bytes into the destination + ``MemoryObj`` buffers supplied by the MP controller. +* **metadata_cache_ttl_secs**: TTL for cached exact path-size metadata (default + ``30``). Set this lower when another process may modify the same bucket + prefix outside LMCache and fresher metadata is more important than reducing + Hugging Face metadata calls. +* **num_workers**: number of worker threads used for blocking Hugging Face Hub + bucket API calls (default ``4``). The HFBucket Python APIs are synchronous, + so MP mode runs upload, lookup, load, and delete work on a bounded thread + pool behind the adapter's eventfd-based completion interface. +* **max_capacity_gb**: capacity used by ``get_usage()`` for watermark-based L2 + eviction. Set to ``0`` (default) to disable aggregate capacity tracking; + ``get_usage()`` then reports the adapter as not providing an eviction signal. +* **eviction**: optional sub-dict enabling the L2 eviction controller for this + adapter. When present, keys that are currently being loaded are protected by + the lookup-and-lock path and skipped by ``delete()`` until they are unlocked. + +Differences vs Non-MP HFBucket +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* Hugging Face bucket operations are synchronous but the adapter makes submission + non-blocking by running the blocking calls on worker threads. +* MP loads do not allocate and return new memory. The MP controller provides + destination ``MemoryObj`` buffers, and the adapter copies downloaded bytes + into those buffers. +* Keys are identified by ``ObjectKey`` (``model_name`` + ``kv_rank`` + + ``chunk_hash`` + optional ``cache_salt``) rather than ``CacheEngineKey``. + The serialized MP object name is + ``@@[@]`` and is then + encoded for the bucket path. This naming is not compatible with the non-MP + HFBucket connector's ``CacheEngineKey`` object names, so a bucket prefix + populated by non-MP LMCache cannot be read directly by MP LMCache and vice + versa. +* Full object writes are batch based. Hugging Face batch writes are not + transactional, so a failed store task may still leave some objects in the + bucket. The MP adapter reconciles backend metadata after such failures so + any objects that actually landed are counted for usage and later deletion + (submitted store task is still reported as failed). + + Notes ----- diff --git a/docs/source/mp/l2_storage.rst b/docs/source/mp/l2_storage.rst index d5f0f08d2e..0f5138c251 100644 --- a/docs/source/mp/l2_storage.rst +++ b/docs/source/mp/l2_storage.rst @@ -543,6 +543,59 @@ S3-compatible endpoint (MinIO, Ceph RGW, etc.). # Local MinIO over plain HTTP --l2-adapter '{"type": "s3", "s3_endpoint": "minio.local:9000", "s3_region": "us-east-1", "disable_tls": true, "aws_access_key_id": "minio", "aws_secret_access_key": "minio123"}' +``hfbucket`` -- Hugging Face Buckets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +An L2 adapter that stores KV cache objects in a `Hugging Face Bucket +`_ using the +``huggingface_hub`` bucket APIs. Blocking Hub calls run on a bounded thread +pool driven by an asyncio loop on a daemon thread, so the L2 controller thread +is never blocked on network I/O. + +Object names are derived from the MP ``ObjectKey`` as +``@@[@]`` and then encoded with +the standard HFBucket object-name encoding plus the optional bucket prefix. +Because Hugging Face batch writes are not transactional, a store task that +partially fails reconciles backend metadata so that any objects that actually +landed are still counted for usage accounting and later deletion. + +This is a persistent remote backend best suited to warm and cold KV cache +tiers; prefer a lower-latency local adapter for the hottest cache tier. + +**Required fields:** + +- ``bucket_handle``: Bucket location in the form + ``hf://buckets//[/]``. + +**Optional fields:** + +- ``token_env`` (string, default ``"HF_TOKEN"``): Environment variable used to + resolve the Hugging Face access token. +- ``token`` (string): Direct token fallback used when ``token_env`` is unset. +- ``create_bucket_if_missing`` (bool, default ``false``): Create the bucket + lazily on the first store instead of requiring it to exist. +- ``download_tmp_dir`` (string): Root directory for temporary load downloads. +- ``metadata_cache_ttl_secs`` (float, default ``30.0``): TTL for the + path-size metadata cache that backs lookups and usage accounting. +- ``num_workers`` (int, default ``4``): Number of worker threads for blocking + Hugging Face Hub API calls. +- ``max_capacity_gb`` (float, default ``0.0``): Aggregate capacity used by + ``get_usage()``. A value of ``0`` disables aggregate eviction. +- ``eviction`` (dict): Optional eviction policy, see ``L2AdapterConfigBase``. + +**Configuration examples:** + +.. code-block:: bash + + # Minimal: use an existing bucket with a token from $HF_TOKEN + --l2-adapter '{"type": "hfbucket", "bucket_handle": "hf://buckets/my-org/lmcache-kv/prod"}' + + # Create the bucket on first store and bound the worker pool + --l2-adapter '{"type": "hfbucket", "bucket_handle": "hf://buckets/my-org/lmcache-kv/prod", "create_bucket_if_missing": true, "num_workers": 8}' + + # Enable aggregate eviction with a capacity cap + --l2-adapter '{"type": "hfbucket", "bucket_handle": "hf://buckets/my-org/lmcache-kv/prod", "max_capacity_gb": 50, "eviction": {"eviction_policy": "LRU", "trigger_watermark": 0.9, "eviction_ratio": 0.1}}' + ``mock`` -- Mock adapter for testing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -777,6 +830,12 @@ drops by ``eviction_ratio``. when ``max_capacity_gb`` is ``0`` (disabled); set a non-zero ``max_capacity_gb`` to enable the watermark-triggered eviction controller. + * - ``hfbucket`` + - ``delete`` removes objects from the bucket and frees aggregate + byte accounting. ``get_usage`` reports ``usage_fraction == -1.0`` + when ``max_capacity_gb`` is ``0`` (disabled); set a non-zero + ``max_capacity_gb`` to enable the watermark-triggered eviction + controller. Locked keys (in-flight loads) are skipped. * - ``dax`` - Full support. ``delete`` removes unlocked keys from the in-memory index immediately and recycles fixed slots once active read borrows diff --git a/lmcache/v1/distributed/l2_adapters/hfbucket_l2_adapter.py b/lmcache/v1/distributed/l2_adapters/hfbucket_l2_adapter.py new file mode 100644 index 0000000000..b1cf97f576 --- /dev/null +++ b/lmcache/v1/distributed/l2_adapters/hfbucket_l2_adapter.py @@ -0,0 +1,898 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Hugging Face Buckets L2 adapter for LMCache MP mode. +""" + +# Future +from __future__ import annotations + +# Standard +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Optional +import asyncio +import os +import shutil +import tempfile +import threading +import time + +if TYPE_CHECKING: + from lmcache.v1.distributed.internal_api import L1MemoryDesc + +# First Party +from lmcache.logging import init_logger +from lmcache.native_storage_ops import Bitmap +from lmcache.v1.distributed.api import ObjectKey +from lmcache.v1.distributed.internal_api import L2StoreResult +from lmcache.v1.distributed.l2_adapters.base import ( + L2AdapterInterface, + L2TaskId, +) +from lmcache.v1.distributed.l2_adapters.config import ( + L2AdapterConfigBase, + register_l2_adapter_type, +) +from lmcache.v1.distributed.l2_adapters.factory import ( + register_l2_adapter_factory, +) +from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.platform import create_event_notifier +from lmcache.v1.storage_backend.connector.hfbucket_connector import ( + HFBucketClient, + HFBucketClientInterface, + HFBucketLocation, + encode_hfbucket_object_name, + parse_hfbucket_handle, +) + +logger = init_logger(__name__) + +# Use a separate temp root from non-MP HFBucket to avoid collisions. +_DEFAULT_DOWNLOAD_TMP_DIR = Path(tempfile.gettempdir()) / "lmcache-hfbucket-mp" +_METADATA_CACHE_PRUNE_INTERVAL = 128 + + +@dataclass(frozen=True) +class _CachedObjectMetadata: + """Cached object size entry with expiration metadata.""" + + size_bytes: int + expires_at: float + + +class _PartialStoreFailure(RuntimeError): + """Raised when a failed HFBucket batch store still wrote some objects.""" + + def __init__( + self, + message: str, + stored_keys: list[ObjectKey], + stored_sizes: list[int], + ) -> None: + super().__init__(message) + self.stored_keys = stored_keys + self.stored_sizes = stored_sizes + + +def _object_key_to_string(key: ObjectKey) -> str: + """Serialize an MP ``ObjectKey`` to the shared L2 object-name format. + + Unsalted keys use ``@@``. Salted + keys append ``@`` so tenants/users with identical token chunks + do not collide in the backing bucket. + """ + base = f"{key.model_name}@{key.kv_rank:08x}@{key.chunk_hash.hex()}" + if key.cache_salt: + return f"{base}@{key.cache_salt}" + return base + + +def _object_key_to_bucket_path(key: ObjectKey, location: HFBucketLocation) -> str: + """Return the HFBucket object path for an MP object key.""" + encoded = encode_hfbucket_object_name(_object_key_to_string(key)) + if location.object_prefix: + return f"{location.object_prefix}/{encoded}" + return encoded + + +def _resolve_hf_token(token_env: str, token: str | None) -> str | None: + """Resolve Hugging Face token from env-first adapter config.""" + env_token = os.environ.get(token_env, "") if token_env else "" + if env_token: + return env_token + return token + + +def _get_path_info_path(path_info: object) -> str: + """Read a Hugging Face path-info object's path field defensively.""" + path = getattr(path_info, "path", "") + return path if isinstance(path, str) else "" + + +def _get_path_info_type(path_info: object) -> str: + """Read a Hugging Face path-info object's type field defensively.""" + obj_type = getattr(path_info, "type", "") + return obj_type if isinstance(obj_type, str) else "" + + +def _get_path_info_size(path_info: object) -> int: + """Read a Hugging Face path-info object's size field defensively.""" + size = getattr(path_info, "size", 0) + return size if isinstance(size, int) else 0 + + +def _is_not_found_error(exc: Exception) -> bool: + """Return whether an exception represents a missing bucket/object.""" + response = getattr(exc, "response", None) + status_code = getattr(response, "status_code", None) + if isinstance(status_code, int): + return status_code == 404 + + direct_status_code = getattr(exc, "status_code", None) + if isinstance(direct_status_code, int): + return direct_status_code == 404 + + return "404" in str(exc) + + +class HFBucketL2AdapterConfig(L2AdapterConfigBase): + """Configuration for the HFBucket MP L2 adapter. + + Fields: + - ``bucket_handle``: ``hf://buckets//[/]``. + - ``token_env``: environment variable used to resolve the HF token. + - ``token``: optional direct token fallback. + - ``create_bucket_if_missing``: create the bucket lazily on first store. + - ``download_tmp_dir``: root directory for temporary load downloads. + - ``metadata_cache_ttl_secs``: TTL for path-size metadata cache. + - ``num_workers``: worker threads for blocking Hugging Face API calls. + - ``max_capacity_gb``: capacity used by inherited L2 usage accounting. + """ + + def __init__( + self, + bucket_handle: str, + token_env: str = "HF_TOKEN", + token: Optional[str] = None, + create_bucket_if_missing: bool = False, + download_tmp_dir: str = str(_DEFAULT_DOWNLOAD_TMP_DIR), + metadata_cache_ttl_secs: float = 30.0, + num_workers: int = 4, + max_capacity_gb: float = 0.0, + ) -> None: + self.bucket_handle = bucket_handle + self.bucket_location = parse_hfbucket_handle(bucket_handle) + self.token_env = token_env + self.token = token + self.create_bucket_if_missing = create_bucket_if_missing + self.download_tmp_dir = Path(download_tmp_dir) + self.metadata_cache_ttl_secs = metadata_cache_ttl_secs + self.num_workers = num_workers + self.max_capacity_gb = max_capacity_gb + + @classmethod + def from_dict(cls, d: dict) -> "HFBucketL2AdapterConfig": + """Parse a config object from ``--l2-adapter`` JSON.""" + bucket_handle = d.get("bucket_handle") + if not isinstance(bucket_handle, str) or not bucket_handle: + raise ValueError("bucket_handle must be a non-empty string") + + token_env = d.get("token_env", "HF_TOKEN") + if not isinstance(token_env, str): + raise ValueError("token_env must be a string") + + token = d.get("token") + if token is not None and not isinstance(token, str): + raise ValueError("token must be a string") + + download_tmp_dir = d.get("download_tmp_dir", str(_DEFAULT_DOWNLOAD_TMP_DIR)) + if not isinstance(download_tmp_dir, str) or not download_tmp_dir: + raise ValueError("download_tmp_dir must be a non-empty string") + + metadata_cache_ttl_secs = d.get("metadata_cache_ttl_secs", 30.0) + if ( + not isinstance(metadata_cache_ttl_secs, (int, float)) + or isinstance(metadata_cache_ttl_secs, bool) + or metadata_cache_ttl_secs < 0 + ): + raise ValueError("metadata_cache_ttl_secs must be a non-negative number") + + num_workers = d.get("num_workers", 4) + if not isinstance(num_workers, int) or isinstance(num_workers, bool): + raise ValueError("num_workers must be a positive integer") + if num_workers <= 0: + raise ValueError("num_workers must be a positive integer") + + max_capacity_gb = d.get("max_capacity_gb", 0.0) + if ( + not isinstance(max_capacity_gb, (int, float)) + or isinstance(max_capacity_gb, bool) + or max_capacity_gb < 0 + ): + raise ValueError("max_capacity_gb must be a non-negative number") + + create_bucket_if_missing = d.get("create_bucket_if_missing", False) + if not isinstance(create_bucket_if_missing, bool): + raise ValueError("create_bucket_if_missing must be a boolean") + + cfg = cls( + bucket_handle=bucket_handle, + token_env=token_env, + token=token, + create_bucket_if_missing=create_bucket_if_missing, + download_tmp_dir=download_tmp_dir, + metadata_cache_ttl_secs=float(metadata_cache_ttl_secs), + num_workers=num_workers, + max_capacity_gb=float(max_capacity_gb), + ) + cfg.eviction_config = cls._parse_eviction_config(d) + return cfg + + @classmethod + def help(cls) -> str: + """Return CLI help text for this adapter type.""" + return ( + "HFBucket L2 adapter config fields:\n" + "- bucket_handle (str, required): " + "hf://buckets//[/]\n" + "- token_env (str): env var for HF token (default HF_TOKEN)\n" + "- token (str): direct token fallback\n" + "- create_bucket_if_missing (bool): create bucket on first store\n" + "- download_tmp_dir (str): temporary download root\n" + "- metadata_cache_ttl_secs (float): metadata cache TTL\n" + "- num_workers (int): blocking HF API worker threads\n" + "- max_capacity_gb (float): capacity for get_usage (0 = disabled)\n" + "- eviction (dict): optional, see L2AdapterConfigBase" + ) + + +class HFBucketL2Adapter(L2AdapterInterface): + """Hugging Face Buckets backed MP L2 adapter.""" + + def __init__( + self, + config: HFBucketL2AdapterConfig, + bucket_client: HFBucketClientInterface | None = None, + ) -> None: + super().__init__(max_capacity_bytes=int(config.max_capacity_gb * (1024**3))) + self._config = config + self._bucket_location = config.bucket_location + self._bucket_id = config.bucket_location.bucket_id + self._object_prefix = config.bucket_location.object_prefix + self._create_bucket_if_missing = config.create_bucket_if_missing + self._metadata_cache_ttl_secs = config.metadata_cache_ttl_secs + + if bucket_client is None: + token = _resolve_hf_token(config.token_env, config.token) + self._bucket_client: HFBucketClientInterface = HFBucketClient(token=token) + else: + self._bucket_client = bucket_client + + self._store_efd = create_event_notifier() + self._lookup_efd = create_event_notifier() + self._load_efd = create_event_notifier() + + self._next_task_id: L2TaskId = 0 + self._completed_store_tasks: dict[L2TaskId, L2StoreResult] = {} + self._completed_lookup_tasks: dict[L2TaskId, Bitmap] = {} + self._completed_load_tasks: dict[L2TaskId, Bitmap] = {} + + self._locked_keys: dict[ObjectKey, int] = defaultdict(int) + self._key_sizes: dict[ObjectKey, int] = {} + self._metadata_cache: dict[str, _CachedObjectMetadata] = {} + self._metadata_cache_updates = 0 + + self._bucket_create_checked = False + self._bucket_create_lock = threading.Lock() + + self._lock = threading.Lock() + self._closed = False + + self._download_tmp_root = config.download_tmp_dir.expanduser() + self._download_tmp_root.mkdir(parents=True, exist_ok=True) + self._download_session_dir = Path( + tempfile.mkdtemp( + prefix="hfbucket-mp-", + dir=self._download_tmp_root, + ) + ) + + self._executor = ThreadPoolExecutor( + max_workers=config.num_workers, + thread_name_prefix="hfbucket-l2", + ) + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread( + target=self._run_event_loop, + daemon=True, + name="hfbucket-l2-adapter-loop", + ) + self._loop_thread.start() + + logger.info( + "Initialized HFBucketL2Adapter (bucket_id=%s prefix=%r " + "workers=%d max_capacity_gb=%.2f)", + self._bucket_id, + self._object_prefix, + config.num_workers, + config.max_capacity_gb, + ) + + def get_store_event_fd(self) -> int: + return self._store_efd.fileno() + + def get_lookup_and_lock_event_fd(self) -> int: + return self._lookup_efd.fileno() + + def get_load_event_fd(self) -> int: + return self._load_efd.fileno() + + def submit_store_task( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> L2TaskId: + with self._lock: + task_id = self._get_next_task_id_locked() + if self._closed: + self._completed_store_tasks[task_id] = L2StoreResult(False, 0) + closed = True + else: + closed = False + + if closed: + self._store_efd.notify() + return task_id + + asyncio.run_coroutine_threadsafe( + self._execute_store(list(keys), list(objects), task_id), + self._loop, + ) + return task_id + + def pop_completed_store_tasks(self) -> dict[L2TaskId, L2StoreResult]: + with self._lock: + completed = self._completed_store_tasks + self._completed_store_tasks = {} + return completed + + def submit_lookup_and_lock_task(self, keys: list[ObjectKey]) -> L2TaskId: + with self._lock: + task_id = self._get_next_task_id_locked() + if self._closed: + self._completed_lookup_tasks[task_id] = Bitmap(len(keys)) + closed = True + else: + closed = False + + if closed: + self._lookup_efd.notify() + return task_id + + asyncio.run_coroutine_threadsafe( + self._execute_lookup(list(keys), task_id), + self._loop, + ) + return task_id + + def query_lookup_and_lock_result(self, task_id: L2TaskId) -> Optional[Bitmap]: + with self._lock: + return self._completed_lookup_tasks.pop(task_id, None) + + def submit_unlock(self, keys: list[ObjectKey]) -> None: + with self._lock: + for key in keys: + if key not in self._locked_keys: + continue + if self._locked_keys[key] <= 1: + del self._locked_keys[key] + else: + self._locked_keys[key] -= 1 + + def submit_load_task( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> L2TaskId: + with self._lock: + task_id = self._get_next_task_id_locked() + if self._closed: + self._completed_load_tasks[task_id] = Bitmap(len(keys)) + closed = True + else: + closed = False + + if closed: + self._load_efd.notify() + return task_id + + asyncio.run_coroutine_threadsafe( + self._execute_load(list(keys), list(objects), task_id), + self._loop, + ) + return task_id + + def query_load_result(self, task_id: L2TaskId) -> Optional[Bitmap]: + with self._lock: + return self._completed_load_tasks.pop(task_id, None) + + def delete(self, keys: list[ObjectKey]) -> None: + if not keys: + return + + with self._lock: + if self._closed: + return + deletable = [key for key in keys if self._locked_keys.get(key, 0) == 0] + + if not deletable: + return + + future = asyncio.run_coroutine_threadsafe( + self._execute_delete(deletable), + self._loop, + ) + try: + deleted_keys, deleted_sizes = future.result(timeout=30.0) + except Exception as exc: + logger.warning("HFBucketL2Adapter delete failed: %s", exc) + return + + if deleted_keys: + self._notify_keys_deleted(deleted_keys, deleted_sizes) + + def report_status(self) -> dict: + usage = self.get_usage() + with self._lock: + object_count = len(self._key_sizes) + locked_key_count = len(self._locked_keys) + closed = self._closed + return { + "is_healthy": self._loop_thread.is_alive() and not closed, + "type": "HFBucketL2Adapter", + "bucket_id": self._bucket_id, + "object_prefix": self._object_prefix, + "stored_object_count": object_count, + "locked_key_count": locked_key_count, + "current_size_bytes": usage.total_bytes_used, + "max_capacity_bytes": usage.total_capacity_bytes, + } + + def close(self) -> None: + if self._closed: + return + self._closed = True + + async def _stop_tasks() -> None: + tasks = [ + task + for task in asyncio.all_tasks(self._loop) + if task is not asyncio.current_task() + ] + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + if self._loop.is_running(): + try: + asyncio.run_coroutine_threadsafe(_stop_tasks(), self._loop).result( + timeout=5 + ) + except Exception: + pass + self._loop.call_soon_threadsafe(self._loop.stop) + + self._loop_thread.join(timeout=5) + try: + self._loop.close() + except Exception: + pass + + self._executor.shutdown(wait=True, cancel_futures=True) + + self._store_efd.close() + self._lookup_efd.close() + self._load_efd.close() + + with self._lock: + self._metadata_cache.clear() + self._key_sizes.clear() + self._locked_keys.clear() + + shutil.rmtree(self._download_session_dir, ignore_errors=True) + logger.info("HFBucketL2Adapter closed") + + def _run_event_loop(self) -> None: + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def _get_next_task_id_locked(self) -> L2TaskId: + task_id = self._next_task_id + self._next_task_id += 1 + return task_id + + async def _execute_store( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + task_id: L2TaskId, + ) -> None: + try: + stored_keys, stored_sizes = await self._loop.run_in_executor( + self._executor, + self._store_batch_sync, + keys, + objects, + ) + success = True + except _PartialStoreFailure as exc: + logger.exception("HFBucketL2Adapter store task partially failed") + stored_keys = exc.stored_keys + stored_sizes = exc.stored_sizes + success = False + except Exception: + logger.exception("HFBucketL2Adapter store task failed") + stored_keys = [] + stored_sizes = [] + success = False + + bytes_transferred = sum(stored_sizes) + with self._lock: + self._completed_store_tasks[task_id] = L2StoreResult( + success, + bytes_transferred, + ) + + if stored_keys: + self._notify_keys_stored(stored_keys, stored_sizes) + self._store_efd.notify() + + async def _execute_lookup( + self, + keys: list[ObjectKey], + task_id: L2TaskId, + ) -> None: + bitmap = Bitmap(len(keys)) + try: + sizes = await self._loop.run_in_executor( + self._executor, + self._resolve_object_sizes_sync, + keys, + ) + except Exception: + logger.exception("HFBucketL2Adapter lookup task failed") + sizes = [0] * len(keys) + + accessed: list[ObjectKey] = [] + with self._lock: + for i, (key, size) in enumerate(zip(keys, sizes, strict=True)): + if size <= 0: + continue + bitmap.set(i) + self._locked_keys[key] += 1 + accessed.append(key) + self._completed_lookup_tasks[task_id] = bitmap + + self._lookup_efd.notify() + if accessed: + self._notify_keys_accessed(accessed) + + async def _execute_load( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + task_id: L2TaskId, + ) -> None: + try: + bitmap = await self._loop.run_in_executor( + self._executor, + self._load_batch_sync, + keys, + objects, + ) + except Exception: + logger.exception("HFBucketL2Adapter load task failed") + bitmap = Bitmap(len(keys)) + + with self._lock: + self._completed_load_tasks[task_id] = bitmap + self._load_efd.notify() + + async def _execute_delete( + self, + keys: list[ObjectKey], + ) -> tuple[list[ObjectKey], list[int]]: + return await self._loop.run_in_executor( + self._executor, + self._delete_batch_sync, + keys, + ) + + def _store_batch_sync( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> tuple[list[ObjectKey], list[int]]: + self._ensure_bucket_for_writes() + + additions: list[tuple[bytes, str]] = [] + indexed: list[tuple[ObjectKey, str, int]] = [] + for key, obj in zip(keys, objects, strict=True): + object_path = _object_key_to_bucket_path(key, self._bucket_location) + data = memoryview(obj.byte_array).cast("B").tobytes() + additions.append((data, object_path)) + indexed.append((key, object_path, len(data))) + + if not additions: + return [], [] + + try: + self._bucket_client.upload_files(self._bucket_id, additions) + except Exception as exc: + # Hugging Face batch writes are not transactional: a request can + # write part of the batch and then fail. Fetch fresh backend + # metadata, update accounting for objects that really landed, and + # still report the submitted store task as failed. + reconciled_keys, reconciled_sizes = self._reconcile_failed_store(indexed) + raise _PartialStoreFailure( + "HFBucket batch upload failed after partial reconciliation", + reconciled_keys, + reconciled_sizes, + ) from exc + + stored_keys: list[ObjectKey] = [] + stored_sizes: list[int] = [] + with self._lock: + for key, object_path, size in indexed: + was_new = key not in self._key_sizes + self._key_sizes[key] = size + self._set_cached_object_size_locked(object_path, size) + if was_new: + stored_keys.append(key) + stored_sizes.append(size) + + return stored_keys, stored_sizes + + def _resolve_object_sizes_sync(self, keys: list[ObjectKey]) -> list[int]: + object_paths = [ + _object_key_to_bucket_path(key, self._bucket_location) for key in keys + ] + + cached: dict[str, int] = {} + unresolved_paths: list[str] = [] + with self._lock: + for object_path in object_paths: + cached_size = self._get_cached_object_size_locked(object_path) + if cached_size is None: + unresolved_paths.append(object_path) + else: + cached[object_path] = cached_size + + if unresolved_paths: + fetched = self._fetch_object_sizes_sync(unresolved_paths) + with self._lock: + for object_path, size in fetched.items(): + self._set_cached_object_size_locked(object_path, size) + cached.update(fetched) + + return [cached.get(object_path, 0) for object_path in object_paths] + + def _fetch_object_sizes_sync(self, object_paths: list[str]) -> dict[str, int]: + if not object_paths: + return {} + + try: + path_infos = self._bucket_client.get_paths_info( + self._bucket_id, + object_paths, + ) + except Exception as exc: + if _is_not_found_error(exc): + return {object_path: 0 for object_path in object_paths} + raise + + size_by_path: dict[str, int] = {} + for path_info in path_infos: + if _get_path_info_type(path_info) != "file": + continue + path = _get_path_info_path(path_info) + if path: + size_by_path[path] = _get_path_info_size(path_info) + + return { + object_path: size_by_path.get(object_path, 0) + for object_path in object_paths + } + + def _load_batch_sync( + self, + keys: list[ObjectKey], + objects: list[MemoryObj], + ) -> Bitmap: + bitmap = Bitmap(len(keys)) + object_paths = [ + _object_key_to_bucket_path(key, self._bucket_location) for key in keys + ] + + batch_dir = Path( + tempfile.mkdtemp(prefix="load-", dir=self._download_session_dir) + ) + local_paths: list[tuple[int, Path]] = [] + files: list[tuple[str, str]] = [] + for index, object_path in enumerate(object_paths): + local_path = batch_dir / f"{index}.bin" + local_paths.append((index, local_path)) + files.append((object_path, str(local_path))) + + try: + try: + self._bucket_client.download_files(self._bucket_id, files) + except Exception as exc: + if not _is_not_found_error(exc): + logger.warning("Batch download from hfbucket raised: %s", exc) + + for index, local_path in local_paths: + if not local_path.exists(): + continue + + dst = memoryview(objects[index].byte_array).cast("B") + file_size = local_path.stat().st_size + if file_size != len(dst): + logger.error( + "Downloaded object %s has %d bytes, expected %d bytes; " + "rejecting load", + object_paths[index], + file_size, + len(dst), + ) + with self._lock: + self._set_cached_object_size_locked( + object_paths[index], + file_size, + ) + continue + + with local_path.open("rb") as f: + bytes_read = f.readinto(dst) + if bytes_read != len(dst): + logger.error( + "Downloaded object %s read %d bytes, expected %d bytes; " + "rejecting load", + object_paths[index], + bytes_read, + len(dst), + ) + with self._lock: + self._set_cached_object_size_locked( + object_paths[index], + bytes_read, + ) + continue + + bitmap.set(index) + with self._lock: + self._set_cached_object_size_locked( + object_paths[index], + file_size, + ) + + return bitmap + finally: + shutil.rmtree(batch_dir, ignore_errors=True) + + def _delete_batch_sync( + self, + keys: list[ObjectKey], + ) -> tuple[list[ObjectKey], list[int]]: + object_paths = [ + _object_key_to_bucket_path(key, self._bucket_location) for key in keys + ] + + try: + self._bucket_client.delete_files(self._bucket_id, object_paths) + except Exception as exc: + if not _is_not_found_error(exc): + raise + + deleted_keys: list[ObjectKey] = [] + deleted_sizes: list[int] = [] + with self._lock: + for key, object_path in zip(keys, object_paths, strict=True): + size = self._key_sizes.pop(key, None) + self._set_cached_object_size_locked(object_path, 0) + deleted_keys.append(key) + deleted_sizes.append(size if size is not None else 0) + + return deleted_keys, deleted_sizes + + def _ensure_bucket_for_writes(self) -> None: + if not self._create_bucket_if_missing or self._bucket_create_checked: + return + + with self._bucket_create_lock: + if self._bucket_create_checked: + return + self._bucket_client.create_bucket(self._bucket_id) + self._bucket_create_checked = True + + def _refresh_cached_sizes(self, keys: list[ObjectKey]) -> None: + try: + self._resolve_object_sizes_sync(keys) + except Exception: + logger.debug("Failed to refresh hfbucket object sizes", exc_info=True) + + def _reconcile_failed_store( + self, + indexed: list[tuple[ObjectKey, str, int]], + ) -> tuple[list[ObjectKey], list[int]]: + object_paths = [object_path for _, object_path, _ in indexed] + try: + sizes_by_path = self._fetch_object_sizes_sync(object_paths) + except Exception: + logger.debug("Failed to reconcile partial hfbucket store", exc_info=True) + return [], [] + + stored_keys: list[ObjectKey] = [] + stored_sizes: list[int] = [] + with self._lock: + for key, object_path, _expected_size in indexed: + size = sizes_by_path.get(object_path, 0) + self._set_cached_object_size_locked(object_path, size) + if size <= 0: + continue + + # Only notify net-new keys. Existing keys already contributed + # to byte accounting, and cache objects should be fixed size. + was_new = key not in self._key_sizes + self._key_sizes[key] = size + if was_new: + stored_keys.append(key) + stored_sizes.append(size) + + return stored_keys, stored_sizes + + def _get_cached_object_size_locked(self, object_path: str) -> int | None: + entry = self._metadata_cache.get(object_path) + if entry is None: + return None + if entry.expires_at <= time.monotonic(): + self._metadata_cache.pop(object_path, None) + return None + return entry.size_bytes + + def _set_cached_object_size_locked(self, object_path: str, size: int) -> None: + expires_at = time.monotonic() + self._metadata_cache_ttl_secs + self._metadata_cache[object_path] = _CachedObjectMetadata( + size_bytes=size, + expires_at=expires_at, + ) + self._metadata_cache_updates += 1 + if self._metadata_cache_updates % _METADATA_CACHE_PRUNE_INTERVAL == 0: + self._prune_expired_cache_entries_locked(time.monotonic()) + + def _prune_expired_cache_entries_locked(self, now: float) -> None: + expired = [ + object_path + for object_path, entry in self._metadata_cache.items() + if entry.expires_at <= now + ] + for object_path in expired: + self._metadata_cache.pop(object_path, None) + + +register_l2_adapter_type("hfbucket", HFBucketL2AdapterConfig) + + +def _create_hfbucket_l2_adapter( + config: L2AdapterConfigBase, + l1_memory_desc: "Optional[L1MemoryDesc]" = None, +) -> L2AdapterInterface: + """Create an HFBucket L2 adapter from registry config.""" + return HFBucketL2Adapter(config) # type: ignore[arg-type] + + +register_l2_adapter_factory("hfbucket", _create_hfbucket_l2_adapter) diff --git a/tests/v1/distributed/test_hfbucket_l2_adapter.py b/tests/v1/distributed/test_hfbucket_l2_adapter.py new file mode 100644 index 0000000000..2e69c6bb74 --- /dev/null +++ b/tests/v1/distributed/test_hfbucket_l2_adapter.py @@ -0,0 +1,570 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the HFBucket MP L2 adapter.""" + +# Standard +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path +import select +import threading +import time + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.distributed.api import ObjectKey +from lmcache.v1.distributed.internal_api import L2AdapterListener +from lmcache.v1.distributed.l2_adapters import hfbucket_l2_adapter as hfmod +from lmcache.v1.distributed.l2_adapters.hfbucket_l2_adapter import ( + HFBucketL2Adapter, + HFBucketL2AdapterConfig, + _object_key_to_bucket_path, + _object_key_to_string, +) +from lmcache.v1.memory_management import ( + MemoryFormat, + MemoryObj, + MemoryObjMetadata, + TensorMemoryObj, +) +from lmcache.v1.platform import consume_fd +from lmcache.v1.storage_backend.connector.hfbucket_connector import ( + parse_hfbucket_handle, +) + +_TEST_BUCKET_HANDLE = "hf://buckets/test-org/test-bucket/prod" +_TEST_BUCKET_LOCATION = parse_hfbucket_handle(_TEST_BUCKET_HANDLE) + + +@dataclass(frozen=True) +class _FakePathInfo: + path: str + type: str + size: int + + +class _FakeBucketClient: + """In-memory HFBucket client used by adapter unit tests.""" + + def __init__(self) -> None: + self.storage: dict[str, bytes] = {} + self.created_buckets: list[str] = [] + self.deleted_paths: list[str] = [] + self.fail_upload_after: int | None = None + self._lock = threading.Lock() + + def create_bucket(self, bucket_id: str) -> None: + with self._lock: + self.created_buckets.append(bucket_id) + + def bucket_info(self, bucket_id: str) -> object: + return {"bucket_id": bucket_id} + + def get_paths_info( + self, + bucket_id: str, + paths: Sequence[str], + ) -> list[object]: + del bucket_id + with self._lock: + return [ + _FakePathInfo(path=path, type="file", size=len(self.storage[path])) + for path in paths + if path in self.storage + ] + + def list_tree(self, bucket_id: str, prefix: str) -> list[object]: + del bucket_id + with self._lock: + return [ + _FakePathInfo(path=path, type="file", size=len(data)) + for path, data in self.storage.items() + if not prefix or path.startswith(prefix) + ] + + def upload_files( + self, + bucket_id: str, + add: Sequence[tuple[bytes, str]], + ) -> None: + del bucket_id + with self._lock: + for index, (data, path) in enumerate(add, start=1): + self.storage[path] = bytes(data) + if ( + self.fail_upload_after is not None + and index >= self.fail_upload_after + ): + raise RuntimeError("injected partial upload failure") + + def download_files( + self, + bucket_id: str, + files: Sequence[tuple[str, str]], + ) -> None: + del bucket_id + with self._lock: + items = [ + (remote, local, self.storage.get(remote)) for remote, local in files + ] + + for _remote, local, data in items: + if data is None: + continue + Path(local).write_bytes(data) + + def delete_files( + self, + bucket_id: str, + delete: Sequence[str], + ) -> None: + del bucket_id + with self._lock: + for path in delete: + self.deleted_paths.append(path) + self.storage.pop(path, None) + + def contains(self, path: str) -> bool: + with self._lock: + return path in self.storage + + +def create_object_key(chunk_id: int, model_name: str = "test/model") -> ObjectKey: + return ObjectKey( + chunk_hash=ObjectKey.IntHash2Bytes(chunk_id), + model_name=model_name, + kv_rank=0, + ) + + +def create_memory_obj(size: int = 16, fill_value: float = 1.0) -> MemoryObj: + raw_data = torch.empty(size, dtype=torch.float32) + raw_data.fill_(fill_value) + metadata = MemoryObjMetadata( + shape=torch.Size([size]), + dtype=torch.float32, + address=0, + phy_size=size * 4, + fmt=MemoryFormat.KV_2LTD, + ref_count=1, + ) + return TensorMemoryObj(raw_data, metadata, parent_allocator=None) + + +def bucket_path_for_key(key: ObjectKey) -> str: + return _object_key_to_bucket_path(key, _TEST_BUCKET_LOCATION) + + +def wait_for_event_fd(event_fd: int, timeout: float = 5.0) -> bool: + poll = select.poll() + poll.register(event_fd, select.POLLIN) + events = poll.poll(timeout * 1000) + if events: + try: + consume_fd(event_fd) + except BlockingIOError: + pass + return True + return False + + +@pytest.fixture +def fake_client() -> _FakeBucketClient: + return _FakeBucketClient() + + +@pytest.fixture +def adapter(tmp_path: Path, fake_client: _FakeBucketClient): + cfg = HFBucketL2AdapterConfig( + bucket_handle=_TEST_BUCKET_HANDLE, + download_tmp_dir=str(tmp_path), + metadata_cache_ttl_secs=30, + num_workers=2, + max_capacity_gb=0.001, + ) + adapter = HFBucketL2Adapter(cfg, bucket_client=fake_client) + yield adapter + adapter.close() + + +class _RecordingListener(L2AdapterListener): + def __init__(self) -> None: + self.stored: list[list[ObjectKey]] = [] + self.accessed: list[list[ObjectKey]] = [] + self.deleted: list[list[ObjectKey]] = [] + + def on_l2_keys_stored(self, keys): + self.stored.append(list(keys)) + + def on_l2_keys_accessed(self, keys): + self.accessed.append(list(keys)) + + def on_l2_keys_deleted(self, keys): + self.deleted.append(list(keys)) + + +class TestObjectKeySerialization: + def test_format(self) -> None: + key = ObjectKey( + chunk_hash=b"\x00\x01\x02\x03", + model_name="llama", + kv_rank=255, + ) + assert _object_key_to_string(key) == "llama@000000ff@00010203" + + def test_cache_salt_appended(self) -> None: + base_key = ObjectKey( + chunk_hash=b"\x00\x01\x02\x03", + model_name="llama", + kv_rank=255, + ) + salted = ObjectKey( + chunk_hash=b"\x00\x01\x02\x03", + model_name="llama", + kv_rank=255, + cache_salt="user-42", + ) + assert _object_key_to_string(base_key) == "llama@000000ff@00010203" + assert _object_key_to_string(salted) == "llama@000000ff@00010203@user-42" + assert _object_key_to_string(base_key) != _object_key_to_string(salted) + + def test_bucket_path_uses_prefix_and_encoding(self) -> None: + cfg = HFBucketL2AdapterConfig(bucket_handle=_TEST_BUCKET_HANDLE) + key = create_object_key(1) + path = _object_key_to_bucket_path(key, cfg.bucket_location) + assert path.startswith("prod/") + assert "/" not in path.removeprefix("prod/") + + +class TestEventFdInterface: + def test_three_distinct_fds(self, adapter: HFBucketL2Adapter) -> None: + a = adapter.get_store_event_fd() + b = adapter.get_lookup_and_lock_event_fd() + c = adapter.get_load_event_fd() + assert a >= 0 and b >= 0 and c >= 0 + assert len({a, b, c}) == 3 + + +class TestStoreLookupLoad: + def test_roundtrip_single_key(self, adapter: HFBucketL2Adapter) -> None: + key = create_object_key(1) + obj = create_memory_obj(fill_value=3.14) + + tid = adapter.submit_store_task([key], [obj]) + assert wait_for_event_fd(adapter.get_store_event_fd()) + assert adapter.pop_completed_store_tasks()[tid].is_successful() + + tid = adapter.submit_lookup_and_lock_task([key]) + assert wait_for_event_fd(adapter.get_lookup_and_lock_event_fd()) + bm = adapter.query_lookup_and_lock_result(tid) + assert bm is not None and bm.test(0) is True + + dst = create_memory_obj(fill_value=0.0) + tid = adapter.submit_load_task([key], [dst]) + assert wait_for_event_fd(adapter.get_load_event_fd()) + bm = adapter.query_load_result(tid) + assert bm is not None and bm.test(0) is True + assert torch.allclose(dst.tensor, torch.full((16,), 3.14)) + + def test_partial_hits(self, adapter: HFBucketL2Adapter) -> None: + stored = [create_object_key(0), create_object_key(2)] + objs = [create_memory_obj(fill_value=float(i)) for i in range(2)] + adapter.submit_store_task(stored, objs) + wait_for_event_fd(adapter.get_store_event_fd()) + adapter.pop_completed_store_tasks() + + keys = [create_object_key(i) for i in range(4)] + tid = adapter.submit_lookup_and_lock_task(keys) + wait_for_event_fd(adapter.get_lookup_and_lock_event_fd()) + bm = adapter.query_lookup_and_lock_result(tid) + assert bm is not None + assert bm.test(0) is True + assert bm.test(1) is False + assert bm.test(2) is True + assert bm.test(3) is False + + def test_load_miss_returns_zero_bit(self, adapter: HFBucketL2Adapter) -> None: + key = create_object_key(99) + dst = create_memory_obj() + tid = adapter.submit_load_task([key], [dst]) + wait_for_event_fd(adapter.get_load_event_fd()) + bm = adapter.query_load_result(tid) + assert bm is not None and bm.test(0) is False + + def test_load_size_mismatch_returns_zero_bit( + self, + adapter: HFBucketL2Adapter, + fake_client: _FakeBucketClient, + ) -> None: + key = create_object_key(7) + object_path = bucket_path_for_key(key) + fake_client.storage[object_path] = b"too-small" + + dst = create_memory_obj() + tid = adapter.submit_load_task([key], [dst]) + wait_for_event_fd(adapter.get_load_event_fd()) + bm = adapter.query_load_result(tid) + assert bm is not None and bm.test(0) is False + + def test_query_lookup_returns_none_after_pop( + self, + adapter: HFBucketL2Adapter, + ) -> None: + key = create_object_key(1) + tid = adapter.submit_lookup_and_lock_task([key]) + wait_for_event_fd(adapter.get_lookup_and_lock_event_fd()) + assert adapter.query_lookup_and_lock_result(tid) is not None + assert adapter.query_lookup_and_lock_result(tid) is None + + def test_partial_store_failure_accounts_written_keys( + self, + adapter: HFBucketL2Adapter, + fake_client: _FakeBucketClient, + ) -> None: + fake_client.fail_upload_after = 1 + keys = [create_object_key(0), create_object_key(1)] + objs = [create_memory_obj(), create_memory_obj()] + + tid = adapter.submit_store_task(keys, objs) + assert wait_for_event_fd(adapter.get_store_event_fd()) + assert not adapter.pop_completed_store_tasks()[tid].is_successful() + + assert fake_client.contains(bucket_path_for_key(keys[0])) + assert not fake_client.contains(bucket_path_for_key(keys[1])) + assert adapter.get_usage().total_bytes_used == 64 + + +class TestEviction: + def _store(self, adapter: HFBucketL2Adapter, key: ObjectKey) -> None: + adapter.submit_store_task([key], [create_memory_obj()]) + wait_for_event_fd(adapter.get_store_event_fd()) + adapter.pop_completed_store_tasks() + + def _lookup(self, adapter: HFBucketL2Adapter, key: ObjectKey): + tid = adapter.submit_lookup_and_lock_task([key]) + wait_for_event_fd(adapter.get_lookup_and_lock_event_fd()) + return adapter.query_lookup_and_lock_result(tid) + + def test_delete_removes_key( + self, + adapter: HFBucketL2Adapter, + fake_client: _FakeBucketClient, + ) -> None: + key = create_object_key(1) + self._store(adapter, key) + object_path = bucket_path_for_key(key) + assert fake_client.contains(object_path) + + adapter.delete([key]) + assert not fake_client.contains(object_path) + + def test_lock_blocks_delete( + self, + adapter: HFBucketL2Adapter, + fake_client: _FakeBucketClient, + ) -> None: + key = create_object_key(1) + self._store(adapter, key) + bm = self._lookup(adapter, key) + assert bm is not None and bm.test(0) is True + + deletes_before = len(fake_client.deleted_paths) + adapter.delete([key]) + assert len(fake_client.deleted_paths) == deletes_before + + adapter.submit_unlock([key]) + adapter.delete([key]) + object_path = bucket_path_for_key(key) + assert not fake_client.contains(object_path) + + def test_refcount_unlock( + self, + adapter: HFBucketL2Adapter, + fake_client: _FakeBucketClient, + ) -> None: + key = create_object_key(1) + self._store(adapter, key) + self._lookup(adapter, key) + self._lookup(adapter, key) + + adapter.submit_unlock([key]) + adapter.delete([key]) + object_path = bucket_path_for_key(key) + assert fake_client.contains(object_path) + + adapter.submit_unlock([key]) + adapter.delete([key]) + assert not fake_client.contains(object_path) + + def test_delete_on_unknown_key(self, adapter: HFBucketL2Adapter) -> None: + adapter.delete([create_object_key(42)]) + + +class TestGetUsage: + def test_disabled_returns_minus_one(self, tmp_path: Path) -> None: + cfg = HFBucketL2AdapterConfig( + bucket_handle="hf://buckets/test-org/test-bucket", + download_tmp_dir=str(tmp_path), + max_capacity_gb=0.0, + ) + adapter = HFBucketL2Adapter(cfg, bucket_client=_FakeBucketClient()) + try: + usage = adapter.get_usage() + # 0/0 is defined as -1.0 to indicate disabled + assert usage.usage_fraction == -1.0 + assert usage.total_bytes_used == 0 + assert usage.total_capacity_bytes == 0 + finally: + adapter.close() + + def test_usage_grows_on_store_and_shrinks_on_delete( + self, + adapter: HFBucketL2Adapter, + ) -> None: + keys = [create_object_key(i) for i in range(4)] + objs = [create_memory_obj() for _ in range(4)] + + adapter.submit_store_task(keys, objs) + wait_for_event_fd(adapter.get_store_event_fd()) + adapter.pop_completed_store_tasks() + + total = 4 * 64 + capacity = int(0.001 * 1024**3) + usage = adapter.get_usage() + assert usage.total_bytes_used == total + assert usage.total_capacity_bytes == capacity + assert usage.usage_fraction == pytest.approx(total / capacity) + + adapter.delete(keys) + usage = adapter.get_usage() + assert usage.total_bytes_used == 0 + assert usage.usage_fraction == 0.0 + + +class TestListener: + def test_stored_accessed_and_deleted_fire( + self, + adapter: HFBucketL2Adapter, + ) -> None: + listener = _RecordingListener() + adapter.register_listener(listener) + + key = create_object_key(1) + adapter.submit_store_task([key], [create_memory_obj()]) + wait_for_event_fd(adapter.get_store_event_fd()) + adapter.pop_completed_store_tasks() + time.sleep(0.05) + assert any(key in batch for batch in listener.stored) + + tid = adapter.submit_lookup_and_lock_task([key]) + wait_for_event_fd(adapter.get_lookup_and_lock_event_fd()) + adapter.query_lookup_and_lock_result(tid) + time.sleep(0.05) + assert any(key in batch for batch in listener.accessed) + accessed_count = len(listener.accessed) + + dst = create_memory_obj(fill_value=0.0) + tid = adapter.submit_load_task([key], [dst]) + wait_for_event_fd(adapter.get_load_event_fd()) + adapter.query_load_result(tid) + time.sleep(0.05) + assert len(listener.accessed) == accessed_count + + adapter.submit_unlock([key]) + adapter.delete([key]) + assert any(key in batch for batch in listener.deleted) + + +class TestConfig: + def test_from_dict_requires_bucket_handle(self) -> None: + with pytest.raises(ValueError): + HFBucketL2AdapterConfig.from_dict({"type": "hfbucket"}) + + def test_from_dict_parses_all_fields(self) -> None: + cfg = HFBucketL2AdapterConfig.from_dict( + { + "type": "hfbucket", + "bucket_handle": _TEST_BUCKET_HANDLE, + "token_env": "HF_TEST_TOKEN", + "token": "direct-token", + "create_bucket_if_missing": True, + "download_tmp_dir": "/tmp/hf", + "metadata_cache_ttl_secs": 12.5, + "num_workers": 8, + "max_capacity_gb": 2.5, + } + ) + assert cfg.bucket_handle == _TEST_BUCKET_HANDLE + assert cfg.bucket_location.bucket_id == "test-org/test-bucket" + assert cfg.bucket_location.object_prefix == "prod" + assert cfg.token_env == "HF_TEST_TOKEN" + assert cfg.token == "direct-token" + assert cfg.create_bucket_if_missing is True + assert cfg.download_tmp_dir == Path("/tmp/hf") + assert cfg.metadata_cache_ttl_secs == 12.5 + assert cfg.num_workers == 8 + assert cfg.max_capacity_gb == 2.5 + + # strict boolean parsing + def test_from_dict_rejects_string_boolean(self) -> None: + with pytest.raises(ValueError, match="create_bucket_if_missing"): + HFBucketL2AdapterConfig.from_dict( + { + "type": "hfbucket", + "bucket_handle": _TEST_BUCKET_HANDLE, + "create_bucket_if_missing": "false", + } + ) + + def test_help_nonempty(self) -> None: + assert isinstance(HFBucketL2AdapterConfig.help(), str) + assert "bucket_handle" in HFBucketL2AdapterConfig.help() + + +class TestFactoryRegistration: + def test_create_l2_adapter_registers_hfbucket( + self, + monkeypatch, + tmp_path: Path, + ) -> None: + # First Party + from lmcache.v1.distributed.l2_adapters import create_l2_adapter + + monkeypatch.setattr( + hfmod, + "HFBucketClient", + lambda token=None: _FakeBucketClient(), + ) + cfg = HFBucketL2AdapterConfig.from_dict( + { + "type": "hfbucket", + "bucket_handle": _TEST_BUCKET_HANDLE, + "download_tmp_dir": str(tmp_path), + "num_workers": 1, + } + ) + adapter = create_l2_adapter(cfg) + try: + assert isinstance(adapter, HFBucketL2Adapter) + finally: + adapter.close() + + +class TestCleanup: + def test_close_cleans_temp_dir( + self, + tmp_path: Path, + fake_client: _FakeBucketClient, + ) -> None: + cfg = HFBucketL2AdapterConfig( + bucket_handle=_TEST_BUCKET_HANDLE, + download_tmp_dir=str(tmp_path), + ) + adapter = HFBucketL2Adapter(cfg, bucket_client=fake_client) + assert list(tmp_path.iterdir()) + + adapter.close() + + assert list(tmp_path.iterdir()) == [] From 07f68b27322ecaa884dd165db5ba99290ceb537b Mon Sep 17 00:00:00 2001 From: deng451e <57919305+deng451e@users.noreply.github.com> Date: Mon, 8 Jun 2026 12:55:21 -0700 Subject: [PATCH 04/57] [CB] Token-level matching + per-token slot scatter for non-block-aligned KV reuse (#3582) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [blend-v3] Token-level matching + per-token slot scatter for CB reuse Match fingerprints at token stride (probe_stride=1) and scatter reused KV with the per-token slot kernel (multi_layer_kv_transfer) instead of matching/scattering at vLLM block granularity. This lets CacheBlend reuse non-block-aligned matches, the common case for real workloads where the shared body starts at an arbitrary token offset (a partial vLLM block) rather than a chunk/block boundary. - register_rope: probe_stride = 1 (find matches at any token offset) - cb_unified_lookup: accept non-prefix matches at any cur_st (drop the chunk-alignment filter) - cb_retrieve_pre_computed: per-token slot scatter of the full matched range. Partial vLLM blocks are written per slot, so matched and recomputed tokens sharing a block don't conflict. Removes the block-aligned drop checks and the now-dead whole-block scatter path. Validated on prefix-suffix-tuner (non-block-aligned by construction): ~99% suffix hit, 3.91x TTFT vs full recompute, output matches the full-recompute baseline. The slot kernel is bandwidth-bound and matches the whole-block kernel's throughput (~700 GB/s), so no scatter overhead. Signed-off-by: deng451e <838677410@qq.com> * [blend-v3] Vectorize V3 matcher probe; drop obsolete probe stride Token-level matching (probe_stride=1) had turned match_sub_sequence into an O(tokens) pure-Python probe loop — ~5.7 ms at 32K context, ~7x the old block-stride cost. Replace it with a vectorized direct-address probe (numpy gather over all positions) plus a verify loop over only the surviving hits; the table is sparse (TABLE_SIZE = 2^20 >> registered chunks) so the hit set is tiny. This restores the base class's vectorization that the V3 override had dropped, keeping full-hash collision rejection. Probe stride is now obsolete (we always scan every position), so the _probe_stride field, ctor arg, and register_rope assignment are removed. Matcher microbench (CPU, per lookup): 32K ctx 5.66 -> 0.83 ms (~7x), 20K 3.43 -> 0.52, 8K 1.39 -> 0.23 — back to the pre-token-scatter block-stride baseline with full token-level matching. All 20 test_optimized_lookup_v3 tests pass. Signed-off-by: deng451e <838677410@qq.com> * update Signed-off-by: deng451e <838677410@qq.com> * update stale docstring Signed-off-by: deng451e <838677410@qq.com> --------- Signed-off-by: deng451e <838677410@qq.com> --- lmcache/v1/multiprocess/custom_types.py | 13 +- lmcache/v1/multiprocess/modules/blend_v3.py | 263 +++++++++--------- lmcache/v1/multiprocess/protocols/blend_v3.py | 13 +- 3 files changed, 141 insertions(+), 148 deletions(-) diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py index af991f05fb..a177c705f2 100644 --- a/lmcache/v1/multiprocess/custom_types.py +++ b/lmcache/v1/multiprocess/custom_types.py @@ -376,16 +376,19 @@ class CBMatchResult: @dataclass class CBUnifiedLookupResult: - """Result of ``CB_UNIFIED_LOOKUP``: prefix lookup + non-prefix fingerprint - match, reconciled in one RPC. + """Resolved payload of ``CB_UNIFIED_LOOKUP``: prefix lookup + non-prefix + fingerprint match, reconciled in one RPC. The RPC returns ``None`` (not this) + while either leg's KV is still loading into L1; this type is sent only once + both are resident. Attributes: prefix_coverage_tokens: Contiguous prefix-cache coverage (L1+L2) in tokens — what the standard LOOKUP would report. - non_prefix_segments: Block-aligned matches outside the prefix coverage + non_prefix_segments: Fingerprint matches outside the prefix coverage (cur_st order), each carrying ``(old_st, old_ed, cur_st, cur_ed, - hash)``. Already sparse-prefetched, so the retrieve set equals the - prefetched set. + hash)``. Token-aligned (any offset, not block-aligned): the per-token + slot scatter handles them. Already resident in L1, so the retrieve + set equals the prefetched set. """ prefix_coverage_tokens: int diff --git a/lmcache/v1/multiprocess/modules/blend_v3.py b/lmcache/v1/multiprocess/modules/blend_v3.py index 6905460026..20db4908a9 100644 --- a/lmcache/v1/multiprocess/modules/blend_v3.py +++ b/lmcache/v1/multiprocess/modules/blend_v3.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -"""Blend V3: paged-aware CacheBlend as an :class:`EngineModule`. +"""Blend V3: paged-aware CacheBlend as an EngineModule. -Plugs into the unified :class:`MPCacheEngine`; standard ``REGISTER_KV_CACHE`` -+ ``CB_REGISTER_ROPE_V3`` for setup; STORE wrapper registers fingerprints; +Plugs into the unified MPCacheEngine; standard REGISTER_KV_CACHE + +CB_REGISTER_ROPE_V3 for setup; STORE wrapper registers fingerprints; retrieve scatters into the request's paged blocks. """ @@ -39,7 +39,6 @@ from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext from lmcache.v1.multiprocess.engine_module import HandlerSpec, ThreadPoolType from lmcache.v1.multiprocess.gpu_context import GPUCacheContext -from lmcache.v1.multiprocess.modules.blend import BlendTokenRangeMatcher from lmcache.v1.multiprocess.modules.gpu_transfer import GPUTransferModule from lmcache.v1.multiprocess.modules.lookup import LookupModule from lmcache.v1.multiprocess.protocol import RequestType @@ -65,7 +64,7 @@ class _CBRopeState: @dataclass class _CBUnifiedJob: - """Per-request poll state for non-blocking ``cb_unified_lookup``. + """Per-request poll state for non-blocking cb_unified_lookup. Stashed across polls because the underlying status/found polls are consume-once. @@ -82,17 +81,35 @@ class _CBUnifiedJob: found_uidx: set[int] | None = None # stashed when the sparse poll completes -class BlendTokenRangeMatcherV3(BlendTokenRangeMatcher): - """V3 matcher: full-hash collision rejection + block-aligned probe stride. +class BlendTokenRangeMatcherV3: + """V3 matcher: token-level probe (any offset) + full-hash collision + rejection. Self-contained (does not inherit a base matcher).""" - Probes every ``probe_stride`` positions; lossless because retrieve drops - non-block-aligned ``cur_st`` anyway. - """ + _TABLE_BITS: int = 20 # 2^20 ~ 1 M entries + _TABLE_SIZE: int = 1 << _TABLE_BITS + _BASE: np.uint64 = np.uint64(0x9E3779B97F4A7C15) # Fibonacci-hashing const + + def __init__(self, chunk_size: int = 256): + """Initialize the V3 matcher. - def __init__(self, chunk_size: int = 256, probe_stride: int = 16): - super().__init__(chunk_size) + Args: + chunk_size (int): Tokens per non-overlapping fingerprint chunk. + """ + self.chunk_size = chunk_size + # poly_chunk_hash -> compact_chunk_id; -1 = empty + self._table_id = np.full(self._TABLE_SIZE, -1, dtype=np.int64) + self._mask = np.uint64(self._TABLE_SIZE - 1) + # compact_chunk_id -> caller token_hash (full bytes); None once evicted + self._chunk_token_hash: list[bytes | None] = [] + # token_hash -> start position in its registered sequence + self._token_hash_to_start: dict[bytes, int] = {} + # compact_chunk_id -> table slot (reverse lookup for eviction) + self._compact_id_to_slot = np.full(self._TABLE_SIZE, -1, dtype=np.int64) + # token_hash -> compact_chunk_id (for eviction lookup) + self._token_hash_to_compact_id: dict[bytes, int] = {} + self._lock = threading.Lock() + # V3 addition: compact_chunk_id -> full poly hash, for collision reject. self._chunk_poly_hash: list[int] = [] - self._probe_stride: int = probe_stride def on_new_token_hashes( self, @@ -101,9 +118,23 @@ def on_new_token_hashes( start_chunk_idx: int = 0, position_offset: int = 0, ) -> None: - """Index non-overlapping chunks; ``start_chunk_idx=1`` skips pos-0 - (handled by the standard prefix lookup); ``position_offset`` is - added to recorded positions for tail-slices.""" + """Index a stored sequence's non-overlapping chunks into the matcher. + + Records each new chunk's poly hash + start position so a later + match_sub_sequence can find it. Thread-safe (holds the matcher lock). + + Args: + token_ids (list[int]): The stored sequence's token IDs. + token_hashes (list[bytes]): Per-chunk content hashes (one per + chunk), used as the dedup/eviction key. + start_chunk_idx (int): First chunk to index; 1 skips chunk 0 (the + standard prefix lookup owns it). + position_offset (int): Added to each recorded start position (for + indexing a tail-slice of a larger sequence). + + Returns: + None. + """ arr = np.array(token_ids, dtype=np.uint64) chunk_hashes = chunk_hash_windows_numba(arr, self.chunk_size, self._BASE) n = int(chunk_hashes.shape[0]) @@ -160,74 +191,66 @@ def match_sub_sequence( self, token_ids: list[int], ) -> list[CBMatchResult]: - """Probe rolling-hash array every ``probe_stride`` positions; skips - bucket-only collisions and evicted entries. One result per unique - match; ``cur_st`` is the first block-aligned hit.""" + """Find every registered chunk reused anywhere in a query sequence. + + Vectorized direct-address probe over all token positions, then a small + verify loop over the surviving hits (a full poly-hash check rejects + bucket collisions; evicted/unknown chunks are skipped). Thread-safe. + + Args: + token_ids (list[int]): The query sequence's token IDs. + + Returns: + list[CBMatchResult]: One result per unique reused chunk (cur_st + = its first query position, old_st = its stored position). + Empty if the query is shorter than one chunk or nothing matched. + """ if len(token_ids) < self.chunk_size: return [] arr = np.array(token_ids, dtype=np.uint64) rolling = rolling_hash_windows_numba(arr, self.chunk_size, self._BASE) - n_positions = int(rolling.shape[0]) with self._lock: if not self._chunk_token_hash: - logger.info( - "[match_probe] empty fingerprint table; n_tok=%d", len(token_ids) - ) return [] - mask = int(self._mask) - stride = self._probe_stride + # Vectorized direct-address probe over all positions. The table is + # sparse (TABLE_SIZE >> registered chunks), so only true matches and + # a few bucket collisions reach the Python verify loop below. + cids_at_pos = self._table_id[rolling & self._mask] + hit_positions = np.nonzero(cids_at_pos >= 0)[0] + seen_cids: set[int] = set() results: list[CBMatchResult] = [] - n_probes = 0 - n_table_hit = 0 - n_collision = 0 - n_evicted = 0 - n_no_old_st = 0 - for q_pos in range(0, n_positions, stride): - n_probes += 1 - r = int(rolling[q_pos]) - cid = int(self._table_id[r & mask]) - if cid < 0 or cid in seen_cids: - continue - n_table_hit += 1 - if r != self._chunk_poly_hash[cid]: - n_collision += 1 + for pos in hit_positions: + pos = int(pos) + cid = int(cids_at_pos[pos]) + if cid in seen_cids: continue + if int(rolling[pos]) != self._chunk_poly_hash[cid]: + continue # bucket-only collision th = self._chunk_token_hash[cid] if th is None: - n_evicted += 1 - continue + continue # evicted old_st = self._token_hash_to_start.get(th) if old_st is None: - n_no_old_st += 1 continue seen_cids.add(cid) results.append( CBMatchResult( old_st=old_st, old_ed=old_st + self.chunk_size, - cur_st=q_pos, - cur_ed=q_pos + self.chunk_size, + cur_st=pos, + cur_ed=pos + self.chunk_size, hash=th, ) ) logger.info( - "[match_probe] n_tok=%d stride=%d n_probes=%d " - "table_hit=%d collisions=%d evicted=%d no_old_st=%d " - "→ matches=%d (sample old_st=%s cur_st=%s)", + "[match_probe] n_tok=%d table_hits=%d matches=%d", len(token_ids), - stride, - n_probes, - n_table_hit, - n_collision, - n_evicted, - n_no_old_st, + len(hit_positions), len(results), - [r.old_st for r in results[:3]], - [r.cur_st for r in results[:3]], ) return results @@ -386,8 +409,8 @@ def cb_register_rope( head_size: int, is_neox_style: bool, ) -> None: - """Bolt rope state onto an already-registered ``cache_contexts`` entry; - idempotent. ``REGISTER_KV_CACHE`` must precede this.""" + """Bolt rope state onto an already-registered cache_contexts entry; + idempotent. REGISTER_KV_CACHE must precede this.""" cache_contexts = self._gpu_transfer.cache_contexts if instance_id not in cache_contexts: raise ValueError( @@ -429,31 +452,14 @@ def cb_register_rope( self._cb_gpu_contexts[instance_id] = gpu_context self._cb_gpu_context_meta[instance_id] = (entry.model_name, entry.world_size) - # Probe stride = ie block size; must divide chunk_size. - ie_logical_block_size = ( - gpu_context.kv_layer_groups_manager.inference_engine_logical_block_size - ) - if self._ctx.chunk_size % ie_logical_block_size == 0: - self._token_range_matcher._probe_stride = ie_logical_block_size - else: - logger.warning( - "CB matcher probe stride unchanged (%d): chunk_size %d is not " - "a multiple of inference_engine_logical_block_size %d.", - self._token_range_matcher._probe_stride, - self._ctx.chunk_size, - ie_logical_block_size, - ) - logger.info( "Registered CB rope state for instance %d " - "(cos_sin_cache shape=%s dtype=%s, head_size=%d, is_neox=%s, " - "matcher_probe_stride=%d)", + "(cos_sin_cache shape=%s dtype=%s, head_size=%d, is_neox=%s)", instance_id, tuple(cos_sin_cache.shape), cos_sin_cache.dtype, head_size, is_neox_style, - self._token_range_matcher._probe_stride, ) def cb_unregister_rope(self, instance_id: int) -> None: @@ -494,7 +500,7 @@ def _drain_fingerprints_sync(self) -> None: def _match_fingerprints(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: """Drain pending registrations, fingerprint-match sub-sequences, then leftmost-greedy dedup over overlapping ranges. Returns matches sorted - by ``cur_st`` (empty if none).""" + by cur_st (empty if none).""" self._drain_fingerprints_sync() matches = self._token_range_matcher.match_sub_sequence(list(key.token_ids)) if not matches: @@ -527,8 +533,8 @@ def _sparse_prefetch_submit( layout_desc: "MemoryLayoutDesc", matches: list[CBMatchResult], ) -> "tuple[PrefetchHandle, dict[bytes, list], list[int]]": - """Coalesce all ``matches`` into one sparse prefetch and submit it - (non-blocking). The caller polls ``query_prefetch_status(handle)`` then + """Coalesce all matches into one sparse prefetch and submit it + (non-blocking). The caller polls query_prefetch_status(handle) then calls :meth:`_sparse_classify` with the found set.""" world_size = key.world_size per_hash_obj_keys: dict[bytes, list] = {} @@ -625,7 +631,7 @@ def cb_unified_lookup( """Non-blocking single-RPC CB lookup (submit-once, poll-on-recall). First call submits the prefix lookup + fingerprint match; later calls - poll both legs, returning ``None`` until the prefix and the sparse + poll both legs, returning None until the prefix and the sparse complement are both resident in L1 (so a worker thread never blocks on the L2->L1 loads). The prefix job's L1 read locks persist for the retrieve. @@ -667,11 +673,9 @@ def cb_unified_lookup( # enter the sparse prefetch, so they cannot leak a read lock. if not job.sparse_started: prefix_tokens = job.prefix_chunks * chunk_size - job.non_prefix = [ - r - for r in job.matches - if r.cur_st >= prefix_tokens and r.cur_st % chunk_size == 0 - ] + # Any offset is fine: the per-token slot scatter writes + # non-block-aligned matches. + job.non_prefix = [r for r in job.matches if r.cur_st >= prefix_tokens] if job.non_prefix: layout_desc = self._resolve_cb_layout_desc( key.model_name, key.world_size @@ -786,7 +790,7 @@ def store( return result def _drain_fingerprint_queue(self) -> None: - """Best-effort background drainer for ``_fingerprint_queue``.""" + """Best-effort background drainer for _fingerprint_queue.""" while not self._fingerprint_stop.is_set(): try: job = self._fingerprint_queue.get(timeout=0.1) @@ -816,7 +820,7 @@ def _apply_cb_rope_batched( slots_to_rope: list[tuple[int, int, int]], ) -> None: """Re-RoPE tmp-pool slots in-place (K-only, per group); list of - ``(slot_idx, old_st, cur_st)``.""" + (slot_idx, old_st, cur_st).""" if not slots_to_rope: return num_groups = gpu_context.kv_layer_groups_manager.num_groups @@ -891,7 +895,14 @@ def cb_retrieve_pre_computed( with self._lookup_obj_keys_lock: cached = self._lookup_obj_keys_cache.pop(key.request_id, None) if cached is not None and all(r.hash in cached for r in cb_match_result): - all_obj_keys = [k for r in cb_match_result for k in cached[r.hash]] + # The lookup cached all-ranks obj keys (world_size per hash). This + # retrieve is per-worker, so select THIS rank's key -> M objects, not + # M*world_size (else the zip below silently truncates and mispairs + # ranks at TP>1). Mirrors the non-cached path's per-worker resolve. + if key.worker_id is not None and key.world_size > 1: + all_obj_keys = [cached[r.hash][key.worker_id] for r in cb_match_result] + else: + all_obj_keys = [k for r in cb_match_result for k in cached[r.hash]] else: all_obj_keys = ipc_key_to_object_keys( key, [r.hash for r in cb_match_result] @@ -937,7 +948,6 @@ def cb_retrieve_pre_computed( f"chunk_size {chunk_size} must be a multiple of " f"inference_engine_logical_block_size {ie_logical_block_size}" ) - blocks_per_chunk = chunk_size // ie_logical_block_size num_groups = gpu_context.kv_layer_groups_manager.num_groups with ( @@ -982,30 +992,18 @@ def cb_retrieve_pre_computed( if memory_objs is None: return event_ipc_handle, False - # Drop malformed matches up front. + # Per-token scatter handles any cur_st; just bound the + # matched range to the allocated slots. pairs: list[tuple[CBMatchResult, Any]] = [] - for r, memory_obj in zip( - cb_match_result, memory_objs, strict=False - ): - if r.cur_st % ie_logical_block_size != 0: + num_slots = int(all_block_ids_gpu.numel()) * ie_logical_block_size + for r, memory_obj in zip(cb_match_result, memory_objs, strict=True): + if r.cur_ed > num_slots: logger.warning( - "Dropping CB match cur_st=%d: not aligned to " - "ie_logical_block_size=%d.", + "Dropping CB match cur_st=%d cur_ed=%d: exceeds " + "%d slots. Request %s.", r.cur_st, - ie_logical_block_size, - ) - continue - cbs = r.cur_st // ie_logical_block_size - if cbs + blocks_per_chunk > int(all_block_ids_gpu.numel()): - logger.warning( - "Dropping CB match cur_st=%d old_st=%d: needs " - "blocks [%d:%d) but gpu_block_ids has %d. " - "Request %s.", - r.cur_st, - r.old_st, - cbs, - cbs + blocks_per_chunk, - int(all_block_ids_gpu.numel()), + r.cur_ed, + num_slots, key.request_id, ) continue @@ -1025,7 +1023,6 @@ def cb_retrieve_pre_computed( for batch_start in range(0, len(run), max_batch): batch = run[batch_start : batch_start + max_batch] batch_len = len(batch) - first_cur_st = batch[0][0].cur_st # (a) H2D fill into per-chunk tmp slots. for slot_idx, (_, memory_obj) in enumerate(batch): @@ -1044,14 +1041,24 @@ def cb_retrieve_pre_computed( gpu_context, rope_state, batch_len, slots_to_rope ) - # (c) One batched scatter per group. - chunk_block_start = first_cur_st // ie_logical_block_size - chunk_block_end = ( - chunk_block_start + batch_len * blocks_per_chunk + # (c) Per-token slot scatter: partial vLLM blocks + # shared with recomputed tokens stay disjoint. + bs = ie_logical_block_size + pos = torch.cat( + [ + torch.arange( + r.cur_st, + r.cur_ed, + device=gpu_context.device, + dtype=torch.long, + ) + for (r, _) in batch + ] ) - chunk_block_ids_gpu = all_block_ids_gpu[ - chunk_block_start:chunk_block_end - ] + slot_mapping = all_block_ids_gpu[pos // bs] * bs + ( + pos % bs + ) + page_buffer_size = gpu_context.num_blocks * bs for group_idx in range(num_groups): tmp_buffers = ( gpu_context.get_tmp_chunk_gpu_buffer_batched( @@ -1059,23 +1066,17 @@ def cb_retrieve_pre_computed( group_idx=group_idx, ) ) - group_kv_pointers = gpu_context.get_group_kv_pointers( - group_idx - ) - group_lmcache_chunk_size = ( - gpu_context.get_physical_chunk_size(group_idx) - ) - - lmc_ops.multi_layer_block_kv_transfer( - group_kv_pointers, - [tb.data_ptr() for tb in tmp_buffers], - chunk_block_ids_gpu, + key_value = torch.cat(tmp_buffers, dim=2) + lmc_ops.multi_layer_kv_transfer( + key_value, + gpu_context.get_group_kv_pointers(group_idx), + slot_mapping, gpu_context.device, + page_buffer_size, lmc_ops.TransferDirection.H2D, - gpu_context.get_shape_desc(group_idx), - group_lmcache_chunk_size, gpu_context.gpu_kv_format_, - 0, # skip_blocks_in_chunk + block_size=bs, + head_size=rope_state.head_size, ) except Exception: logger.exception("Error during retrieving prefetched results") diff --git a/lmcache/v1/multiprocess/protocols/blend_v3.py b/lmcache/v1/multiprocess/protocols/blend_v3.py index 69a5114bf1..337c619768 100644 --- a/lmcache/v1/multiprocess/protocols/blend_v3.py +++ b/lmcache/v1/multiprocess/protocols/blend_v3.py @@ -1,16 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Blend V3 protocol — paged-aware CB pipeline. - -RPCs: -* ``CB_REGISTER_ROPE_V3`` / ``CB_UNREGISTER_ROPE_V3`` — share / release the rope - cos/sin cache onto a context already registered via ``REGISTER_KV_CACHE``. -* ``CB_RETRIEVE_PRE_COMPUTED_V3`` — scatter all matched chunks (prefix- and - non-prefix-hit) into paged KV by per-token block ID; re-RoPE only the shifted - (``old_st != cur_st``) subset. -* ``CB_UNIFIED_LOOKUP`` — the sole live lookup path: one RPC runs prefix + - non-prefix match, reconcile, one sparse-coalesced prefetch, and per-TP-rank - classify. ``(IPCCacheEngineKey, tp_size)`` → ``CBUnifiedLookupResult``. -""" +"""Blend V3 protocol definitions.""" # First Party from lmcache.v1.multiprocess.custom_types import ( From 20cf3cdb99f26f86029a7524414cea647016c761 Mon Sep 17 00:00:00 2001 From: Yihua Cheng Date: Mon, 8 Jun 2026 15:37:19 -0700 Subject: [PATCH 05/57] [Core][MP] refactor the LMCache layer group for better compat with hybrid models (#3557) Signed-off-by: ApostaC --- lmcache/utils.py | 41 +- lmcache/v1/kv_layer_groups.py | 268 ++++-- lmcache/v1/multiprocess/gpu_context.py | 602 ++++++++----- lmcache/v1/multiprocess/modules/blend_v3.py | 30 +- .../v1/multiprocess/modules/gpu_transfer.py | 94 +- tests/v1/distributed/serde/test_serde_e2e.py | 14 +- .../test_blend_v3_load_store_opts.py | 31 +- tests/v1/multiprocess/test_gpu_context.py | 818 ++++++++++-------- .../test_gpu_transfer_layout_registry.py | 1 + tests/v1/test_kv_layer_groups_manager.py | 132 ++- 10 files changed, 1306 insertions(+), 725 deletions(-) diff --git a/lmcache/utils.py b/lmcache/utils.py index 2e7193ebf2..3145795678 100644 --- a/lmcache/utils.py +++ b/lmcache/utils.py @@ -5,13 +5,15 @@ # Standard from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar, Union import asyncio +import functools import hashlib import inspect import re import threading import traceback +import warnings try: # Third Party @@ -700,6 +702,43 @@ def wrapper(*args, **kwargs): return wrapper +##### Deprecation ##### +F = TypeVar("F", bound=Callable[..., Any]) + + +def lmcache_deprecate(reason: str) -> Callable[[F], F]: + """Mark a function or method as deprecated. + + Calling the wrapped callable emits a ``DeprecationWarning`` and logs a + warning the first time it is invoked, including the supplied reason. + + Args: + reason: Human-readable explanation of why the callable is deprecated + and, ideally, what to use instead. + + Returns: + A decorator that wraps the target callable while preserving its + signature and metadata. + """ + + def decorator(func: F) -> F: + warned = False + + @functools.wraps(func) + def wrapper(*args, **kwargs): + nonlocal warned + if not warned: + message = f"{func.__qualname__} is deprecated: {reason}" + warnings.warn(message, DeprecationWarning, stacklevel=2) + logger.warning(message) + warned = True + return func(*args, **kwargs) + + return wrapper # type: ignore[return-value] + + return decorator + + #### Thread/asyncio-related utilities #### def handle_thread_exception(args): """Handle an uncaught exception reported by ``threading``. diff --git a/lmcache/v1/kv_layer_groups.py b/lmcache/v1/kv_layer_groups.py index bef741085d..2127bcb573 100644 --- a/lmcache/v1/kv_layer_groups.py +++ b/lmcache/v1/kv_layer_groups.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple # Third Party import torch @@ -14,6 +14,7 @@ # First Party from lmcache.logging import init_logger from lmcache.python_ops_fallback import set_shape_desc_dtype +from lmcache.utils import lmcache_deprecate import lmcache.c_ops as lmc_ops if TYPE_CHECKING: @@ -48,7 +49,16 @@ # block address space). Block IDs are only meaningful within one such group, so # layers from different groups must not share one LMCache group (and thus one # transfer-kernel launch) even if their tensor shape and dtype match. -LayerGroupIdentity = tuple[int, int, int, int, int, torch.dtype] +class KernelGroupIdentity(NamedTuple): + kv_size: int + num_heads: int + head_size: int + block_size: int + engine_group_idx: int + dtype: torch.dtype + + +LayerGroupIdentity = KernelGroupIdentity # Alias for compatibility # Sentinel ``per_layer_engine_group_idx`` value: a KV tensor tagged with it is @@ -66,7 +76,7 @@ def group_layers_by_identity( """Partition layer indices by :data:`LayerGroupIdentity`. This helper is shared by vLLM-side LMCache group inflation and server-side - ``KVLayerGroupInfo`` construction so both sides agree on group order. + ``KernelGroupInfo`` construction so both sides agree on group order. Args: kv_caches: Registered KV cache structure inspected for per-layer shape @@ -113,12 +123,21 @@ def group_layers_by_identity( hs = get_head_size(kv_caches, gpu_kv_format, idx) dt = get_dtype(kv_caches, gpu_kv_format, idx) bs = get_block_size(kv_caches, gpu_kv_format, idx) - groups_dict[(kv_size, nh, hs, bs, engine_group_idx, dt)].append(idx) + + identity = LayerGroupIdentity( + kv_size=kv_size, + num_heads=nh, + head_size=hs, + block_size=bs, + engine_group_idx=engine_group_idx, + dtype=dt, + ) + groups_dict[identity].append(idx) return sorted(groups_dict.items(), key=lambda kv: kv[1][0]) @dataclass -class KVLayerGroupInfo: +class KernelGroupInfo: """A single transfer-kernel dispatch unit: a set of KV layers that can ride one kernel launch with one ``PageBufferShapeDesc``. @@ -180,7 +199,7 @@ def __repr__(self) -> str: indices_repr = f"{self.layer_indices[0]}-{self.layer_indices[-1]}" sd = self.shape_desc return ( - f"KVLayerGroupInfo(layers={len(self.layer_indices)}, " + f"KernelGroupInfo(layers={len(self.layer_indices)}, " f"indices={indices_repr}, " f"shape_desc=(kv={sd.kv_size}, nl={sd.nl}, nb={sd.nb}, " f"bs={sd.bs}, nh={sd.nh}, hs={sd.hs}, " @@ -203,18 +222,41 @@ def hidden_dim_size(self) -> int: return self.shape_desc.nh * self.shape_desc.hs +KVLayerGroupInfo = KernelGroupInfo # Alias for compatibility + + +@dataclass +class ObjectGroupInfo: + """Metadata for an 'object group'. + + An object group contains one or more kernel groups whose + KV caches will be stored in the same memory object. + + This will be useful for dealing with sliding window or mamba + KV caches that needs a different prefix matching logic from + the full attention KV caches. + """ + + kernel_group_indices: list[int] + """Indices of the kernel groups belonging to this object group, in the + order they should be laid out in memory.""" + + # NOTE: will add fields to indicate the "kv cache type" of this + # object group in the follow-up PRs + + class KVLayerGroupsManager: """Partition a model's KV layers into transfer-kernel dispatch units. At construction time, every layer in ``kv_caches`` is bucketed by its :data:`LayerGroupIdentity` (``(kv_size, num_heads, head_size, block_size, engine_group_idx, dtype)``). Each bucket becomes one - :class:`KVLayerGroupInfo` holding the layer indices, a shared + :class:`KernelGroupInfo` holding the layer indices, a shared :class:`PageBufferShapeDesc`, and the group's torch dtype. Downstream consumers (``VLLMPagedMemGPUConnectorV3``, ``GPUCacheContext``, the multiprocess server) iterate - ``self.kv_layer_groups`` and issue one transfer-kernel launch per + ``self._kernel_groups`` and issue one transfer-kernel launch per group. The manager itself is a pure metadata object — it does not own any GPU buffers or perform any transfers. @@ -239,7 +281,7 @@ def __init__( ``(kv_size, num_heads, head_size, dtype)`` via the format-aware accessors in ``utils.py``. Layers with identical identities are bucketed together; each bucket becomes one - :class:`KVLayerGroupInfo`. + :class:`KernelGroupInfo`. Groups are emitted in the order of their first-appearing layer, so group indices are deterministic across runs. @@ -292,7 +334,8 @@ def __init__( if layout_hints else None ) - self.kv_layer_groups: list[KVLayerGroupInfo] = [] + self._kernel_groups: list[KernelGroupInfo] = [] + self._object_groups: list[ObjectGroupInfo] = [] num_layers = get_num_layers(kv_caches, gpu_kv_format) if num_layers == 0: @@ -334,6 +377,10 @@ def __init__( group_logical_block_size = ( max(global_logical, bs) if global_logical is not None else None ) + + # TODO (ApostaC): the code here is not very good. + # Conceptually, KV Layer Group should not be aware of lmcache logical + # chunk size at all. compress_ratio, physical_chunk_size = self._derive_compression_metadata( group_idx=group_idx, bs=bs, @@ -341,8 +388,8 @@ def __init__( lmcache_logical_chunk_size=lmcache_logical_chunk_size, ) - self.kv_layer_groups.append( - KVLayerGroupInfo( + self._kernel_groups.append( + KernelGroupInfo( layer_indices=indices, shape_desc=shape_desc, dtype=dt, @@ -354,10 +401,139 @@ def __init__( self.inference_engine_logical_block_size_ = ( self.inference_engine_logical_block_size_ - or self.kv_layer_groups[0].shape_desc.bs + or self._kernel_groups[0].shape_desc.bs + ) + + logger.info( + "KV layer groups: ---\n%s\n---", + "\n".join(repr(g) for g in self._kernel_groups), + ) + + # Detect the object groups + self._object_groups = self._detect_object_groups(group_views) + + @property + def kernel_groups(self) -> list[KernelGroupInfo]: + """List of :class:`KernelGroupInfo`, one per kernel group.""" + return self._kernel_groups + + @property + @lmcache_deprecate("`kv_layer_groups` is an outdated alias for `kernel_groups`") + def kv_layer_groups(self) -> list[KernelGroupInfo]: + """List of :class:`KernelGroupInfo`, one per kernel group.""" + return self._kernel_groups + + @property + def num_kernel_groups(self) -> int: + """Number of :class:`KernelGroupInfo` entries. + + Zero if ``kv_caches`` had no layers at construction time. + """ + return len(self._kernel_groups) + + @property + def object_groups(self) -> list[ObjectGroupInfo]: + """List of :class:`ObjectGroupInfo`, one per object group.""" + return self._object_groups + + @property + def num_object_groups(self) -> int: + """Number of :class:`ObjectGroupInfo` entries.""" + return len(self._object_groups) + + @property + @lmcache_deprecate("`num_groups` is an outdated alias for `num_kernel_groups`") + def num_groups(self) -> int: + """Number of :class:`KernelGroupInfo` entries. + + Zero if ``kv_caches`` had no layers at construction time. + """ + return len(self._kernel_groups) + + @property + def inference_engine_logical_block_size(self) -> int: + """Inference-engine-side logical block size. + + Taken from ``layout_hints`` at construction time, or falls back + to the first group's physical ``bs`` when no hint is provided + (non-vLLM engines, or vLLM without mixed-compression KV groups), + in which case every group is treated as non-compressed. + """ + return ( + self.inference_engine_logical_block_size_ + or self._kernel_groups[0].shape_desc.bs ) - logger.info("KV layer groups: %s", self.kv_layer_groups) + def get_shape_desc(self, kernel_group_idx: int) -> "lmc_ops.PageBufferShapeDesc": + """Return the :class:`PageBufferShapeDesc` for *kernel_group_idx*. + + Equivalent to ``self._kernel_groups[kernel_group_idx].shape_desc``. + + Args: + kernel_group_idx: 0-based kernel group index. + + Raises: + IndexError: If *kernel_group_idx* is out of range. + """ + return self._kernel_groups[kernel_group_idx].shape_desc + + def get_physical_chunk_size(self, kernel_group_idx: int) -> int: + """Return the per-chunk *physical* slot count for *kernel_group_idx*. + + Equivalent to + ``self._kernel_groups[kernel_group_idx].physical_chunk_size``. + For non-compressed groups this equals + ``lmcache_logical_chunk_size``; for compressed groups it equals + ``lmcache_logical_chunk_size // compress_ratio`` and is what the + block-level transfer kernel must be told (the logical chunk size + in *vLLM tokens* is not what the kernel addresses). + + Args: + kernel_group_idx: 0-based kernel group index. + + Raises: + IndexError: If *kernel_group_idx* is out of range. + """ + return self._kernel_groups[kernel_group_idx].physical_chunk_size + + def calculate_num_blocks(self, kernel_group_idx: int, num_tokens: int) -> int: + """Calculate the number of blocks for a given number of tokens in a + specified kernel group. + + Args: + kernel_group_idx: 0-based index of the kernel group. + num_tokens: The total number of tokens to be processed for the group. + + Returns: + The number of blocks. + + Raises: + IndexError: If *kernel_group_idx* is out of range. + """ + group = self._kernel_groups[kernel_group_idx] + num_physical_slots = num_tokens // group.compress_ratio + return num_physical_slots // group.shape_desc.bs + + ### Helper methods + def _detect_object_groups( + self, group_views: "Sequence[LMCacheGroupView]" + ) -> list[ObjectGroupInfo]: + """Detect object groups based on the provided group views. + + Args: + group_views: LMCache-owned engine KV cache group metadata. + + Returns: + A list of ObjectGroupInfo instances representing the detected object groups. + """ + # TODO: add the real object group detection logic based on + # the attention type metadata in the group views once it's + # available. + # Now, we are using a single object group, which means + # all kernel groups' KV caches will be stored in the same memory object. + return [ + ObjectGroupInfo(kernel_group_indices=list(range(len(self._kernel_groups)))) + ] @staticmethod def _derive_compression_metadata( @@ -406,60 +582,6 @@ def _derive_compression_metadata( ) return compress_ratio, physical_chunk_size - @property - def num_groups(self) -> int: - """Number of :class:`KVLayerGroupInfo` entries. - - Zero if ``kv_caches`` had no layers at construction time. - """ - return len(self.kv_layer_groups) - - @property - def inference_engine_logical_block_size(self): - """Inference-engine-side logical block size. - - Taken from ``layout_hints`` at construction time, or falls back - to the first group's physical ``bs`` when no hint is provided - (non-vLLM engines, or vLLM without mixed-compression KV groups), - in which case every group is treated as non-compressed. - """ - return ( - self.inference_engine_logical_block_size_ - or self.kv_layer_groups[0].shape_desc.bs - ) - - def get_shape_desc(self, group_idx: int) -> "lmc_ops.PageBufferShapeDesc": - """Return the :class:`PageBufferShapeDesc` for *group_idx*. - - Equivalent to ``self.kv_layer_groups[group_idx].shape_desc``. - - Args: - group_idx: 0-based group index. - - Raises: - IndexError: If *group_idx* is out of range. - """ - return self.kv_layer_groups[group_idx].shape_desc - - def get_physical_chunk_size(self, group_idx: int) -> int: - """Return the per-chunk *physical* slot count for *group_idx*. - - Equivalent to - ``self.kv_layer_groups[group_idx].physical_chunk_size``. - For non-compressed groups this equals - ``lmcache_logical_chunk_size``; for compressed groups it equals - ``lmcache_logical_chunk_size // compress_ratio`` and is what the - block-level transfer kernel must be told (the logical chunk size - in *vLLM tokens* is not what the kernel addresses). - - Args: - group_idx: 0-based group index. - - Raises: - IndexError: If *group_idx* is out of range. - """ - return self.kv_layer_groups[group_idx].physical_chunk_size - # ------------------------------------------------------------------ # # CLI shape-spec parser # @@ -468,7 +590,7 @@ def get_physical_chunk_size(self, group_idx: int) -> int: def parse_kvcache_shape_spec( spec_str: str, -) -> list[KVLayerGroupInfo]: +) -> list[KernelGroupInfo]: """Parse a ``--kvcache-shape-spec`` string into layer groups. **Grammar** (EBNF-ish):: @@ -513,7 +635,7 @@ def parse_kvcache_shape_spec( (handy for CLI echo-back / debug logging). Returns: - A list of :class:`KVLayerGroupInfo`, one per group. + A list of :class:`KernelGroupInfo`, one per group. Raises: ValueError: Malformed spec, unknown dtype, or a shape with a @@ -522,7 +644,7 @@ def parse_kvcache_shape_spec( if not spec_str: raise ValueError("KV shape specification cannot be empty") - groups: list[KVLayerGroupInfo] = [] + groups: list[KernelGroupInfo] = [] layer_offset = 0 for group_spec in spec_str.split(";"): @@ -574,7 +696,7 @@ def parse_kvcache_shape_spec( indices = list(range(layer_offset, layer_offset + layer_count)) groups.append( - KVLayerGroupInfo( + KernelGroupInfo( layer_indices=indices, shape_desc=shape_desc, dtype=dtype, @@ -588,7 +710,7 @@ def parse_kvcache_shape_spec( return groups -def format_kvcache_shape_spec(groups: list[KVLayerGroupInfo]) -> str: +def format_kvcache_shape_spec(groups: list[KernelGroupInfo]) -> str: """Format layer groups back into a ``--kvcache-shape-spec`` string. This is the inverse of :func:`parse_kvcache_shape_spec`; the diff --git a/lmcache/v1/multiprocess/gpu_context.py b/lmcache/v1/multiprocess/gpu_context.py index bc885467ad..182d55af15 100644 --- a/lmcache/v1/multiprocess/gpu_context.py +++ b/lmcache/v1/multiprocess/gpu_context.py @@ -22,7 +22,7 @@ # First Party from lmcache import torch_dev from lmcache.logging import init_logger -from lmcache.utils import EngineType +from lmcache.utils import EngineType, lmcache_deprecate from lmcache.v1.gpu_connector.utils import ( LayoutHints, get_attention_backend, @@ -63,6 +63,275 @@ def list_to_gpu_tensor(lis: list[int], device: torch.device) -> torch.Tensor: ) +class _TempGPUBuffer: + """ + Manages the temporary GPU buffer for GPUCacheContext + + The logical layout of the temp GPU buffer is (batch size, + object group, kernel group). + + Here is an example of batch size = 4, with 2 object groups, + and 2 kernel groups per object group: + [ + batch 0: + - object group 0: kernel group 0 | kernel group 1 | ... + - object group 1: kernel group 2 | kernel group 3 | ... + + batch 1: + - object group 0: kernel group 0 | kernel group 1 | ... + - object group 1: kernel group 2 | kernel group 3 | ... + + batch 2: + - object group 0: kernel group 0 | kernel group 1 | ... + - object group 1: kernel group 2 | kernel group 3 | ... + + batch 3: + - object group 0: kernel group 0 | kernel group 1 | ... + - object group 1: kernel group 2 | kernel group 3 | ... + ] + + During the multi-layer copy kernel launch, we will do it at kernel + group level, which means we will have: + ``` + gpu_buffers = [ + get_temp_kernel_group_buffer(batch_idx, kernel_group_idx) + for batch_idx in range(batch_size) + ] + ``` + + During the lmcache_memcpy_async launch, we will do it at the object group + level, which will be: + ``` + for i in range(batch_size): + gpu_buffer = get_temp_object_group_buffer(batch_idx, object_group_idx) + lmcache_memcpy_async(...) + ``` + """ + + def __init__( + self, + kv_layer_groups_manager: KVLayerGroupsManager, + lmcache_logical_chunk_size: int, + device: torch.device, + max_batch_size: int = 4, + ) -> None: + self._kv_groups_manager = kv_layer_groups_manager + self._lmcache_chunk_size = lmcache_logical_chunk_size + self._max_batch_size = max_batch_size + + self._temp_buffer = torch.empty( + self._get_size_for_single_batch() * max_batch_size, + dtype=torch.uint8, + device=device, + ) + + # Offset map: (batch_idx, object_group_idx, kernel_group_idx) -> + # (byte offset in the temp buffer, size of the buffer in bytes) + self._offset_map: dict[tuple[int, int, int], tuple[int, int]] = {} + + # (batch_idx, kernel_group_idx) -> (byte offset for the kernel group, + # size of the buffer in bytes). + self._offset_map_kernel_group_only: dict[tuple[int, int], tuple[int, int]] = {} + + # (batch_idx, object_group_idx) -> (byte offset for the object group, + # size of the buffer in bytes) + self._offset_map_object_group_only: dict[tuple[int, int], tuple[int, int]] = {} + + offset = 0 + for batch_idx in range(max_batch_size): + for object_group_idx in range(self._kv_groups_manager.num_object_groups): + object_group_size = 0 + object_group_start_offset = offset + + for kernel_group_idx in self._kv_groups_manager.object_groups[ + object_group_idx + ].kernel_group_indices: + key = (batch_idx, object_group_idx, kernel_group_idx) + key2 = (batch_idx, kernel_group_idx) + + size = self._get_size_for_kernel_group(kernel_group_idx) + self._offset_map[key] = (offset, size) + self._offset_map_kernel_group_only[key2] = (offset, size) + + offset += size + object_group_size += size + + key3 = (batch_idx, object_group_idx) + self._offset_map_object_group_only[key3] = ( + object_group_start_offset, + object_group_size, + ) + + # Shape/dtype cache for kernel groups + self._shape_cache_kernel_group: dict[int, tuple[torch.Size, torch.dtype]] = {} + for kernel_group_idx in range(self._kv_groups_manager.num_kernel_groups): + shape = self._get_shape_for_kernel_group( + self._lmcache_chunk_size, kernel_group_idx + ) + group = self._kv_groups_manager.kernel_groups[kernel_group_idx] + dtype = group.dtype + self._shape_cache_kernel_group[kernel_group_idx] = (shape, dtype) + + # Public APIs + @property + def max_batch_size(self) -> int: + """Maximum number of chunks (batch slots) the buffer holds.""" + return self._max_batch_size + + def get_temp_kernel_group_buffer( + self, batch_idx: int, kernel_group_idx: int + ) -> torch.Tensor: + """ + Returns the temp GPU buffer for the given batch index and kernel group index. + The returned buffer is with the correct shape and dtype for the kernel group. + + Args: + batch_idx: Index of the batch (0 <= batch_idx < max_batch_size) + kernel_group_idx: Index of the kernel group. + + Returns: + The temp GPU buffer for the given batch index and kernel group index. + + Raises: + ValueError: If the batch_idx or kernel_group_idx is out of range. + """ + key = (batch_idx, kernel_group_idx) + if key not in self._offset_map_kernel_group_only: + raise ValueError( + f"Invalid batch_idx {batch_idx} or kernel_group_idx {kernel_group_idx}" + ) + + offset, size = self._offset_map_kernel_group_only[key] + shape, dtype = self._shape_cache_kernel_group[kernel_group_idx] + return self._temp_buffer[offset : offset + size].view(dtype).view(shape) + + def get_temp_object_group_buffer( + self, batch_idx: int, object_group_idx: int + ) -> torch.Tensor: + """ + Returns the temp GPU buffer for the given batch index and object group index + The returned buffer is a flat uint8 raw tensor. + + Args: + batch_idx: Index of the batch (0 <= batch_idx < max_batch_size) + object_group_idx: Index of the object group. + + Returns: + The temp GPU buffer for the given batch index and object group index. + """ + key = (batch_idx, object_group_idx) + if key not in self._offset_map_object_group_only: + raise ValueError( + f"Invalid batch_idx {batch_idx} or object_group_idx {object_group_idx}" + ) + + offset, size = self._offset_map_object_group_only[key] + return self._temp_buffer[offset : offset + size] + + def get_kernel_group_shape_dtype( + self, + num_tokens: int, + kernel_group_idx: int, + ) -> tuple[torch.Size, torch.dtype]: + """ + Returns the shape and dtype for the given kernel group index and + number of tokens. + + Will be exported by GPUCacheContext and used to construct the + MemoryLayoutDesc + + Args: + num_tokens: Number of tokens + kernel_group_idx: Index of the kernel group. + + Returns: + The shape and dtype for the given kernel group index and + number of tokens. + """ + _, dtype = self._shape_cache_kernel_group[kernel_group_idx] + shape = self._get_shape_for_kernel_group(num_tokens, kernel_group_idx) + + return shape, dtype + + def get_cache_size_per_token(self) -> int: + """ + Returns the cache size per token (in bytes), summed across all kernel groups. + """ + return self._get_size_for_single_batch() // self._lmcache_chunk_size + + # Helper functions + def _get_shape_for_kernel_group( + self, + num_tokens: int, + kernel_group_idx: int, + ) -> torch.Size: + """ + Returns the shape of the temp GPU buffer for the given kernel group index + + Args: + num_tokens: Number of tokens + kernel_group_idx: Index of the kernel group. + + Returns: + The shape of the temp GPU buffer for the given kernel group index. + """ + group = self._kv_groups_manager.kernel_groups[kernel_group_idx] + compress_ratio = group.compress_ratio + sd = group.shape_desc + + if num_tokens % compress_ratio != 0: + raise ValueError( + f"logical_num_tokens ({num_tokens}) is not a multiple of " + f"compress_ratio ({compress_ratio}) for group {kernel_group_idx}" + ) + num_slots = num_tokens // compress_ratio + return torch.Size( + (sd.kv_size, group.num_layers, num_slots, group.hidden_dim_size) + ) + + def _get_size_for_kernel_group(self, kernel_group_idx: int) -> int: + """ + Returns the size in bytes of the temp GPU buffer for the given kernel group + index + + **Assumes the size is lmcache_chunk_size + + Will only be called during initialization + """ + shape = self._get_shape_for_kernel_group( + self._lmcache_chunk_size, kernel_group_idx + ) + kernel_group = self._kv_groups_manager.kernel_groups[kernel_group_idx] + dtype = kernel_group.dtype + return shape.numel() * dtype.itemsize + + def _get_size_for_object_group(self, object_group_idx: int) -> int: + """ + Returns the size in bytes of the temp GPU buffer for the given object group + + **Assumes the size is lmcache_chunk_size + + Will only be called during initialization + """ + object_group = self._kv_groups_manager.object_groups[object_group_idx] + return sum( + self._get_size_for_kernel_group(kernel_group_idx) + for kernel_group_idx in object_group.kernel_group_indices + ) + + def _get_size_for_single_batch(self) -> int: + """ + Returns the size in bytes of the temp GPU buffer for a single batch + (i.e., a single chunk) + + **Assumes the size is lmcache_chunk_size + """ + return sum( + self._get_size_for_object_group(object_group_idx) + for object_group_idx in range(self._kv_groups_manager.num_object_groups) + ) + + class GPUCacheContext: """ Manages the shape and pointers to vLLM GPU KV cache tensors. @@ -107,44 +376,17 @@ def __init__( # Pre-allocated GPU buffer for block IDs (up to 1M elements). # The caller copies block_ids into this buffer before launching the # block-level kernel. Single-thread assumption: no lock needed. - _MAX_BLOCK_IDS = 1_000_000 + _MAX_BLOCK_IDS = 1 << 20 self.block_ids_buffer_ = torch.empty( _MAX_BLOCK_IDS, dtype=torch.long, device=self.device_ ) # Temporary GPU buffer for transfers — a single flat uint8 buffer - # laid out in chunk-major order so that each chunk's data matches - # the layout of a MemoryObj.raw_data (all groups concatenated): - # - # [ chunk_0: group_0_bytes | group_1_bytes | ... ] - # [ chunk_1: group_0_bytes | group_1_bytes | ... ] - # ... - # - # This lets callers copy an entire chunk to/from a MemoryObj with a - # single memcpy, without needing to know the per-group layout. - # max_batch_size is the max number of chunks processed concurrently. - self.max_batch_size = 4 - # Byte size of one chunk entry (= one chunk across all groups). - # tmp_chunk_group_offsets_[g] is the byte offset of group g within - # a single chunk; tmp_chunk_group_offsets_[num_groups] == - # tmp_chunk_bytes_. - self.tmp_chunk_group_offsets_: list[int] = [0] - for group_idx, group in enumerate( - self.kv_layer_groups_manager_.kv_layer_groups - ): - # ``get_kv_buffer_shape`` takes *logical* tokens; for - # compressed groups it folds ``compress_ratio`` logical - # tokens into one physical slot internally. - shape = self.get_kv_buffer_shape(lmcache_logical_chunk_size, group_idx) - byte_size = shape.numel() * group.dtype.itemsize - self.tmp_chunk_group_offsets_.append( - self.tmp_chunk_group_offsets_[-1] + byte_size - ) - self.tmp_chunk_bytes_ = self.tmp_chunk_group_offsets_[-1] - self.tmp_gpu_buffer_ = torch.empty( - self.tmp_chunk_bytes_ * self.max_batch_size, - dtype=torch.uint8, + self._temp_buffer = _TempGPUBuffer( + kv_layer_groups_manager=self.kv_layer_groups_manager_, + lmcache_logical_chunk_size=lmcache_logical_chunk_size, device=self.device_, + max_batch_size=4, ) # GPU streams @@ -156,14 +398,6 @@ def __init__( self.cuda_stream_.cuda_stream, self.device_.index ) - _, high_priority = torch_dev.Stream.priority_range() - self.high_priority_cuda_stream_ = torch_dev.Stream( - device=self.device_, priority=high_priority - ) - self.high_priority_cupy_stream_ = cupy.cuda.ExternalStream( - self.high_priority_cuda_stream_.cuda_stream, self.device_.index - ) - # Extra initialization self.cupy_stream_.launch_host_func( lambda logger: logger.info( @@ -195,38 +429,6 @@ def stream(self) -> Any: def cupy_stream(self) -> "cupy.cuda.Stream": return self.cupy_stream_ - @property - def high_priority_stream(self) -> Any: - return self.high_priority_cuda_stream_ - - @property - def high_priority_cupy_stream(self) -> "cupy.cuda.Stream": - return self.high_priority_cupy_stream_ - - @property - def group_physical_block_sizes(self) -> list[int]: - """Per-group physical slot count (``shape_desc.bs``) in group - order. For non-compressed groups this equals - ``inference_engine_logical_block_size``; for compressed groups - it equals - ``inference_engine_logical_block_size // compress_ratio``. - """ - return [ - group.shape_desc.bs - for group in self.kv_layer_groups_manager_.kv_layer_groups - ] - - @property - def group_compress_ratios(self) -> list[int]: - """Per-group compression ratio - (= ``inference_engine_logical_block_size // shape_desc.bs``) - in group order. ``1`` for non-compressed groups. - """ - return [ - group.compress_ratio - for group in self.kv_layer_groups_manager_.kv_layer_groups - ] - @property def num_layers(self) -> int: """ @@ -256,6 +458,11 @@ def hidden_dim_sizes(self) -> list[int]: for group in self.kv_layer_groups_manager_.kv_layer_groups ] + @property + def kv_layer_groups_manager(self) -> KVLayerGroupsManager: + """Returns the KV layer groups manager.""" + return self.kv_layer_groups_manager_ + def get_shape_desc(self, group_idx: int) -> "lmc_ops.PageBufferShapeDesc": """Returns the PageBufferShapeDesc for the given KV layer group.""" return self.kv_layer_groups_manager_.get_shape_desc(group_idx) @@ -269,116 +476,70 @@ def get_physical_chunk_size(self, group_idx: int) -> int: """ return self.kv_layer_groups_manager_.get_physical_chunk_size(group_idx) - def blocks_for_tokens(self, num_logical_tokens: int, group_idx: int) -> int: - """Number of group ``group_idx`` blocks that span ``num_logical_tokens``. + def get_kernel_group_kv_pointers(self, kernel_group_idx: int) -> torch.Tensor: + """Returns the pre-computed GPU tensor of KV cache pointers for the + given kernel group index. + """ + return self.group_kv_pointers_[kernel_group_idx] - Each group counts blocks in its own ``block_size`` (``shape_desc.bs``), - which can differ across groups. For compressed groups, ``compress_ratio`` - logical tokens share one physical slot, so it is divided out first. + def get_temp_kernel_group_buffer( + self, batch_idx: int, kernel_group_idx: int + ) -> torch.Tensor: + """Returns the temporary GPU buffer for the given batch index and kernel + group index, with the correct shape and dtype for the kernel group. Args: - num_logical_tokens: Number of logical (engine-side) tokens. - group_idx: Index of the KV layer group. + batch_idx: Index of the batch (0 <= batch_idx < max_batch_size) + kernel_group_idx: Index of the kernel group. Returns: - The number of this group's blocks spanning those tokens. + The temp GPU buffer for the given batch index and kernel group index. """ - group = self.kv_layer_groups_manager_.kv_layer_groups[group_idx] - physical_slots = num_logical_tokens // group.compress_ratio - return physical_slots // group.shape_desc.bs - - @property - def kv_layer_groups_manager(self) -> KVLayerGroupsManager: - """Returns the KV layer groups manager.""" - return self.kv_layer_groups_manager_ - - @property - def gpu_kv_format_name(self) -> str: - """Returns the GPU KV format enum name (e.g. ``'NL_X_TWO_NB_BS_NH_HS'``).""" - return self.gpu_kv_format_.name - - @property - def gpu_kv_shape(self) -> str: - """Returns a human-readable shape description of the GPU KV cache layout.""" - return get_gpu_kv_shape_description(self.gpu_kv_format_) - - @property - def attention_backend(self) -> str: - """Returns the attention backend name.""" - return get_attention_backend(self.gpu_kv_format_) + return self._temp_buffer.get_temp_kernel_group_buffer( + batch_idx, kernel_group_idx + ) @property - def concrete_gpu_kv_shape(self) -> str: - """Returns the GPU KV shape with actual numeric values substituted.""" - return get_concrete_gpu_kv_shape(self.kv_caches_, self.gpu_kv_format_) + def max_batch_size(self) -> int: + """Maximum number of chunks processed concurrently in one batch.""" + return self._temp_buffer.max_batch_size - def get_group_kv_pointers(self, group_idx: int) -> torch.Tensor: - """Returns the pre-computed GPU tensor of KV cache pointers for the - given group.""" - return self.group_kv_pointers_[group_idx] - - def get_tmp_gpu_buffer_flat(self, chunk_idx: int) -> torch.Tensor: - """Returns the flat uint8 view of the temporary GPU buffer for the - given chunk index, covering all KV layer groups. - - The returned tensor will fit a memory full object corresponding - ``self.chunk_size`` tokens, so it can be copied to/from a MemoryObj - with a single memcpy. + def get_temp_object_group_buffer( + self, batch_idx: int, object_group_idx: int + ) -> torch.Tensor: + """Returns the temporary GPU buffer for the given batch index and object + group index, as a flat uint8 tensor. Args: - chunk_idx: Chunk index (0 <= chunk_idx < max_batch_size). - """ - if chunk_idx >= self.max_batch_size: - raise ValueError( - f"chunk_idx {chunk_idx} exceeds max_batch_size {self.max_batch_size}" - ) - start = chunk_idx * self.tmp_chunk_bytes_ - return self.tmp_gpu_buffer_[start : start + self.tmp_chunk_bytes_] - - def get_tmp_chunk_gpu_buffer(self, group_idx: int = 0) -> torch.Tensor: - """ - Returns a view of the temporary GPU buffer for the given group, - sized for a single chunk. The chunk holds - ``lmcache_logical_chunk_size`` logical tokens which, for a - compressed group, correspond to ``group.physical_chunk_size`` - physical slots. + batch_idx: Index of the batch (0 <= batch_idx < max_batch_size) + object_group_idx: Index of the object group. - Args: - group_idx: Index of the KV layer group (default 0). + Returns: + The temp GPU buffer for the given batch index and object group index. """ - group = self.kv_layer_groups_manager_.kv_layer_groups[group_idx] - shape = self.get_kv_buffer_shape(self.lmcache_logical_chunk_size, group_idx) - start = self.tmp_chunk_group_offsets_[group_idx] - end = self.tmp_chunk_group_offsets_[group_idx + 1] - return self.tmp_gpu_buffer_[start:end].view(group.dtype).view(shape) + return self._temp_buffer.get_temp_object_group_buffer( + batch_idx, object_group_idx + ) - def get_tmp_chunk_gpu_buffer_batched( - self, batch_size: int, group_idx: int = 0 - ) -> list[torch.Tensor]: - """ - Returns a list of ``batch_size`` non-overlapping views into the - pre-allocated temporary GPU buffer for the given group, each - sized for ``lmcache_logical_chunk_size`` tokens. + def get_kernel_group_shape_dtype( + self, + num_tokens: int, + kernel_group_idx: int, + ) -> tuple[torch.Size, torch.dtype]: + """Returns the shape and dtype for the given kernel group index and number + of tokens. + Will be exported by GPUCacheContext and used to construct the MemoryLayoutDesc Args: - batch_size: Number of concurrent requests (must be <= max_batch_size). - group_idx: Index of the KV layer group (default 0). + num_tokens: Number of tokens + kernel_group_idx: Index of the kernel group. + + Returns: + The shape and dtype for the given kernel group index and number of tokens. """ - if batch_size > self.max_batch_size: - raise ValueError( - f"batch_size {batch_size} exceeds max_batch_size {self.max_batch_size}" - ) - group = self.kv_layer_groups_manager_.kv_layer_groups[group_idx] - shape = self.get_kv_buffer_shape(self.lmcache_logical_chunk_size, group_idx) - g_start = self.tmp_chunk_group_offsets_[group_idx] - g_end = self.tmp_chunk_group_offsets_[group_idx + 1] - chunk = self.tmp_chunk_bytes_ - return [ - self.tmp_gpu_buffer_[i * chunk + g_start : i * chunk + g_end] - .view(group.dtype) - .view(shape) - for i in range(batch_size) - ] + return self._temp_buffer.get_kernel_group_shape_dtype( + num_tokens, kernel_group_idx + ) def copy_view_block_ids_to_gpu( self, block_ids_per_group: list[list[int]] @@ -410,6 +571,7 @@ def copy_view_block_ids_to_gpu( for i in range(len(block_ids_per_group)) ] + @lmcache_deprecate("will be refactored") def get_kv_buffer_shape( self, logical_num_tokens: int, group_idx: int = 0 ) -> torch.Size: @@ -429,6 +591,7 @@ def get_kv_buffer_shape( of the group's ``compress_ratio``. group_idx: Index of the KV layer group (default 0). """ + # TODO: remove this! group = self.kv_layer_groups_manager_.kv_layer_groups[group_idx] compress_ratio = group.compress_ratio if logical_num_tokens % compress_ratio != 0: @@ -442,6 +605,24 @@ def get_kv_buffer_shape( (sd.kv_size, group.num_layers, num_slots, group.hidden_dim_size) ) + def calculate_num_blocks(self, num_tokens: int, kernel_group_idx: int) -> int: + """Calculate the number of blocks for a given number of tokens in a + specified kernel group. + + Args: + kernel_group_idx: 0-based index of the kernel group. + num_tokens: The total number of tokens to be processed for the group. + + Returns: + The number of blocks. + + Raises: + IndexError: If *kernel_group_idx* is out of range. + """ + return self.kv_layer_groups_manager.calculate_num_blocks( + kernel_group_idx, num_tokens + ) + def cache_size_per_token(self) -> int: """ Returns the cache size per *logical* token (in bytes), summed @@ -453,20 +634,65 @@ def cache_size_per_token(self) -> int: endpoint and the ``lmcache describe`` CLI); sub-byte truncation from integer division is acceptable. """ - total = 0 - for group_idx, group in enumerate( - self.kv_layer_groups_manager_.kv_layer_groups - ): - # ``get_kv_buffer_shape`` now takes *logical* tokens, so - # query ``compress_ratio`` logical tokens (= 1 physical - # slot) and then divide the resulting bytes back by - # ``compress_ratio`` to recover the per-logical-token - # contribution. Equivalent to the old - # ``physical_slot_bytes // compress_ratio`` formulation. - numels = self.get_kv_buffer_shape(group.compress_ratio, group_idx).numel() - slot_bytes = numels * group.dtype.itemsize - total += slot_bytes // group.compress_ratio - return total + return self._temp_buffer.get_cache_size_per_token() + + def report_status(self) -> dict: + """Return this context's KV cache layout metadata for ``/status``. + + Builds the ``kv_cache_layout`` sub-dict surfaced by the ``/status`` + HTTP endpoint (see ``GPUTransferModule.report_status``) and consumed by + the ``lmcache`` CLI (``lmcache describe kvcache`` and + ``lmcache bench engine``). It describes only the KV cache geometry; the + owning module wraps it with ``model_name``/``world_size``, which this + context does not track. + + Returns: + A dict with one entry per documented layout field: + + - ``num_layers`` (int) + - ``inference_engine_logical_block_size`` (int) + - ``group_physical_block_sizes`` (list[int]): per-group + ``shape_desc.bs`` + - ``group_compress_ratios`` (list[int]): per-group compress ratio + - ``hidden_dim_sizes`` (str): stringified per-group hidden-dim list + - ``dtype`` (str): stringified torch dtype + - ``is_mla`` (bool) + - ``num_blocks`` (int) + - ``gpu_kv_format`` (str): GPU KV format enum name + - ``gpu_kv_shape`` (str): symbolic shape description + - ``gpu_kv_concrete_shape`` (str): shape with numeric values + - ``attention_backend`` (str) + - ``cache_size_per_token`` (int): bytes per logical token + """ + # TODO(compat): the key names and value *formatting* below are a + # contract with the `/status` endpoint and the `lmcache` CLI + # (`lmcache/cli/commands/describe.py`, `bench/engine_bench/config.py`). + # Renaming a key breaks `lmcache describe kvcache`; dropping + # `cache_size_per_token` breaks `lmcache bench engine`. `hidden_dim_sizes` + # and `dtype` are stringified only for back-compat with those consumers + # and should become a real list / structured value once the CLI is + # updated to parse them. + manager = self.kv_layer_groups_manager + kernel_groups = manager.kernel_groups + return { + "num_layers": self.num_layers, + "inference_engine_logical_block_size": ( + manager.inference_engine_logical_block_size + ), + "group_physical_block_sizes": [g.shape_desc.bs for g in kernel_groups], + "group_compress_ratios": [g.compress_ratio for g in kernel_groups], + "hidden_dim_sizes": str([g.hidden_dim_size for g in kernel_groups]), + "dtype": str(self.dtype), + "is_mla": self.is_mla, + "num_blocks": self.num_blocks, + "gpu_kv_format": self.gpu_kv_format_.name, + "gpu_kv_shape": get_gpu_kv_shape_description(self.gpu_kv_format_), + "gpu_kv_concrete_shape": get_concrete_gpu_kv_shape( + self.kv_caches_, self.gpu_kv_format_ + ), + "attention_backend": get_attention_backend(self.gpu_kv_format_), + "cache_size_per_token": self.cache_size_per_token(), + } class PlainGPUCacheContext: @@ -506,14 +732,6 @@ def __init__(self, kv_caches: KVCache, lmcache_chunk_size: int = 256): self._cuda_stream.cuda_stream, self._device.index ) - _, high_priority = torch_dev.Stream.priority_range() - self._high_priority_cuda_stream = torch_dev.Stream( - device=self._device, priority=high_priority - ) - self._high_priority_cupy_stream = cupy.cuda.ExternalStream( - self._high_priority_cuda_stream.cuda_stream, self._device.index - ) - # Extra initialization self._cupy_stream.launch_host_func( lambda logger: logger.info( @@ -557,14 +775,6 @@ def stream(self) -> Any: def cupy_stream(self) -> "cupy.cuda.Stream": return self._cupy_stream - @property - def high_priority_stream(self) -> Any: - return self._high_priority_cuda_stream - - @property - def high_priority_cupy_stream(self) -> "cupy.cuda.Stream": - return self._high_priority_cupy_stream - @property def num_layers(self) -> int: return self._num_layers diff --git a/lmcache/v1/multiprocess/modules/blend_v3.py b/lmcache/v1/multiprocess/modules/blend_v3.py index 20db4908a9..7f0dba6889 100644 --- a/lmcache/v1/multiprocess/modules/blend_v3.py +++ b/lmcache/v1/multiprocess/modules/blend_v3.py @@ -823,17 +823,18 @@ def _apply_cb_rope_batched( (slot_idx, old_st, cur_st).""" if not slots_to_rope: return - num_groups = gpu_context.kv_layer_groups_manager.num_groups + num_groups = gpu_context.kv_layer_groups_manager.num_kernel_groups for group_idx in range(num_groups): - group = gpu_context.kv_layer_groups_manager.kv_layer_groups[group_idx] + group = gpu_context.kv_layer_groups_manager.kernel_groups[group_idx] if group.compress_ratio != 1: raise RuntimeError( f"CB v3: group {group_idx} has compress_ratio=" f"{group.compress_ratio}; compressed layouts unsupported." ) - all_slots = gpu_context.get_tmp_chunk_gpu_buffer_batched( - batch_size=batch_len, group_idx=group_idx - ) + all_slots = [ + gpu_context.get_temp_kernel_group_buffer(slot_idx, group_idx) + for slot_idx in range(batch_len) + ] if all_slots[0].shape[0] != 2: raise RuntimeError( f"CB v3: group {group_idx} has kv_size={all_slots[0].shape[0]}; " @@ -948,7 +949,7 @@ def cb_retrieve_pre_computed( f"chunk_size {chunk_size} must be a multiple of " f"inference_engine_logical_block_size {ie_logical_block_size}" ) - num_groups = gpu_context.kv_layer_groups_manager.num_groups + num_groups = gpu_context.kv_layer_groups_manager.num_kernel_groups with ( torch_dev.device(gpu_context.device), @@ -1026,8 +1027,9 @@ def cb_retrieve_pre_computed( # (a) H2D fill into per-chunk tmp slots. for slot_idx, (_, memory_obj) in enumerate(batch): - flat_slot = gpu_context.get_tmp_gpu_buffer_flat( - chunk_idx=slot_idx + # Single object group => object_group_idx=0. + flat_slot = gpu_context.get_temp_object_group_buffer( + slot_idx, 0 ) lmcache_memcpy_async_h2d(memory_obj, flat_slot) @@ -1060,16 +1062,16 @@ def cb_retrieve_pre_computed( ) page_buffer_size = gpu_context.num_blocks * bs for group_idx in range(num_groups): - tmp_buffers = ( - gpu_context.get_tmp_chunk_gpu_buffer_batched( - batch_size=batch_len, - group_idx=group_idx, + tmp_buffers = [ + gpu_context.get_temp_kernel_group_buffer( + slot_idx, group_idx ) - ) + for slot_idx in range(batch_len) + ] key_value = torch.cat(tmp_buffers, dim=2) lmc_ops.multi_layer_kv_transfer( key_value, - gpu_context.get_group_kv_pointers(group_idx), + gpu_context.get_kernel_group_kv_pointers(group_idx), slot_mapping, gpu_context.device, page_buffer_size, diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 25c0bd0ab6..8b012af0c3 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -49,29 +49,35 @@ def get_layout_desc( - cache_context: GPUCacheContext, num_tokens: int + gpu_context: GPUCacheContext, + num_tokens: int, + object_group_id: int = 0, ) -> MemoryLayoutDesc: - """Get the memory layout description for a given GPU context and number of tokens. + """Get the memory layout description for a specific object group. - Supports multiple KV layer groups with different shapes and dtypes. + The returned layout describes the single memory object that backs + ``object_group_id``: one (shape, dtype) entry per kernel group in that + object group, in the kernel groups' declared layout order. Kernel groups + may have different shapes and dtypes. Args: cache_context: The GPU cache context containing the KV cache information. num_tokens: The number of tokens to determine the layout for. + object_group_id: Index of the object group whose layout to build. + Defaults to 0; under the current single-object-group assumption this + covers every kernel group. Returns: - MemoryLayoutDesc: The memory layout description containing shapes and dtypes. + MemoryLayoutDesc: The memory layout description containing shapes and + dtypes, one entry per kernel group in the object group. """ - num_groups = cache_context.kv_layer_groups_manager.num_groups - shapes = [ - cache_context.get_kv_buffer_shape(num_tokens, group_idx) - for group_idx in range(num_groups) + object_group = gpu_context.kv_layer_groups_manager.object_groups[object_group_id] + shapes_and_dtypes = [ + gpu_context.get_kernel_group_shape_dtype(num_tokens, kernel_group_idx) + for kernel_group_idx in object_group.kernel_group_indices ] - dtypes = [ - cache_context.kv_layer_groups_manager.kv_layer_groups[group_idx].dtype - for group_idx in range(num_groups) - ] - return MemoryLayoutDesc(shapes=shapes, dtypes=dtypes) + shapes, dtypes = zip(*shapes_and_dtypes, strict=False) + return MemoryLayoutDesc(shapes=list(shapes), dtypes=list(dtypes)) def batched_iteration(lst: list, batch_size: int) -> Generator[tuple, None, None]: @@ -198,23 +204,7 @@ def report_status(self) -> dict: cache_context_meta[str(instance_id)] = { "model_name": entry.model_name, "world_size": entry.world_size, - "kv_cache_layout": { - "num_layers": ctx.num_layers, - "inference_engine_logical_block_size": ( - ctx.kv_layer_groups_manager.inference_engine_logical_block_size - ), - "group_physical_block_sizes": ctx.group_physical_block_sizes, - "group_compress_ratios": ctx.group_compress_ratios, - "hidden_dim_sizes": str(ctx.hidden_dim_sizes), - "dtype": str(ctx.dtype), - "is_mla": ctx.is_mla, - "num_blocks": ctx.num_blocks, - "gpu_kv_format": ctx.gpu_kv_format_name, - "gpu_kv_shape": ctx.gpu_kv_shape, - "gpu_kv_concrete_shape": ctx.concrete_gpu_kv_shape, - "attention_backend": ctx.attention_backend, - "cache_size_per_token": ctx.cache_size_per_token(), - }, + "kv_cache_layout": ctx.report_status(), } return { @@ -279,7 +269,9 @@ def register_kv_cache( world_size=world_size, ) - layout_desc = get_layout_desc(cache_context, self._ctx.chunk_size) + layout_desc = get_layout_desc( + cache_context, self._ctx.chunk_size, object_group_id=0 + ) self._ctx.layout_desc_registry.register(model_name, world_size, layout_desc) logger.info( @@ -351,11 +343,14 @@ def store( cache_context = entry.cache_context model_name = entry.model_name + # TODO(refactor): only single-object-group transfers are wired up so far. + assert cache_context.kv_layer_groups_manager.num_object_groups == 1 + # NOTE: different engine groups may have different block sizes, so # ``blocks_per_chunk[i]`` is the number of blocks in one chunk for # group ``i``. blocks_per_chunk = [ - cache_context.blocks_for_tokens(self._ctx.chunk_size, group_idx) + cache_context.calculate_num_blocks(self._ctx.chunk_size, group_idx) for group_idx in range(cache_context.kv_layer_groups_manager.num_groups) ] @@ -431,7 +426,9 @@ def store( reserved_dict: dict[ObjectKey, MemoryObj] = {} store_succeeded = False try: - layout_desc = get_layout_desc(cache_context, self._ctx.chunk_size) + layout_desc = get_layout_desc( + cache_context, self._ctx.chunk_size, object_group_id=0 + ) reserved_dict = self._ctx.storage_manager.reserve_write( obj_keys, layout_desc, "new" ) @@ -454,8 +451,11 @@ def store( chunk_block_ids_gpu = block_ids_per_group_gpu[group_idx][ idx * bpc : (idx + 1) * bpc ] - tmp_buffer = cache_context.get_tmp_chunk_gpu_buffer(group_idx) - group_kv_pointers = cache_context.get_group_kv_pointers( + # Store is not batched, so we always use batch_idx=0. + tmp_buffer = cache_context.get_temp_kernel_group_buffer( + 0, group_idx + ) + group_kv_pointers = cache_context.get_kernel_group_kv_pointers( group_idx ) # Kernel contract: ``group_lmcache_chunk_size`` here is the @@ -475,9 +475,10 @@ def store( cache_context.gpu_kv_format_, 0, ) - # Store is not batched, so we always use chunk_idx=0 (single slot) + # Store is not batched, so we always use batch_idx=0 (single + # slot). Single object group => object_group_idx=0. lmcache_memcpy_async_d2h( - cache_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj + cache_context.get_temp_object_group_buffer(0, 0), memory_obj ) store_succeeded = True except Exception: @@ -565,6 +566,9 @@ def retrieve( cache_context = entry.cache_context model_name = entry.model_name + # TODO(refactor): only single-object-group transfers are wired up so far. + assert cache_context.kv_layer_groups_manager.num_object_groups == 1 + # CPU-synchronous sentinel: a GPU retrieve is about to be enqueued. # Must be published via publish() (not publish_on_stream) so the # drain thread sees it before MP_REQUEST_END can race MP_RETRIEVE_END. @@ -634,12 +638,13 @@ def _retrieve_loop(keys: list[ObjectKey], memory_objs: list[MemoryObj]) -> None: # Copy from CPU to GPU tmp buffers, then scatter to paged KV — per group # H2D copy: each memory_obj maps to its own batch slot for chunk_idx, memory_obj in enumerate(memory_obj_batch): + # Single object group => object_group_idx=0. lmcache_memcpy_async_h2d( memory_obj, - cache_context.get_tmp_gpu_buffer_flat(chunk_idx=chunk_idx), + cache_context.get_temp_object_group_buffer(chunk_idx, 0), ) for group_idx, group in enumerate(groups): - bpc = cache_context.blocks_for_tokens( + bpc = cache_context.calculate_num_blocks( self._ctx.chunk_size, group_idx ) chunk_block_ids_gpu = block_ids_per_group_gpu[group_idx][ @@ -656,13 +661,16 @@ def _retrieve_loop(keys: list[ObjectKey], memory_objs: list[MemoryObj]) -> None: f"expected={batch_len * bpc} " f"got={chunk_block_ids_gpu.shape[0]}" ) - group_skip_blocks = cache_context.blocks_for_tokens( + group_skip_blocks = cache_context.calculate_num_blocks( skip_tokens_in_chunk, group_idx ) - tmp_buffers = cache_context.get_tmp_chunk_gpu_buffer_batched( - batch_len, group_idx + tmp_buffers = [ + cache_context.get_temp_kernel_group_buffer(i, group_idx) + for i in range(batch_len) + ] + group_kv_pointers = cache_context.get_kernel_group_kv_pointers( + group_idx ) - group_kv_pointers = cache_context.get_group_kv_pointers(group_idx) group_lmcache_chunk_size = cache_context.get_physical_chunk_size( group_idx ) diff --git a/tests/v1/distributed/serde/test_serde_e2e.py b/tests/v1/distributed/serde/test_serde_e2e.py index b88a18fb57..1b192df05e 100644 --- a/tests/v1/distributed/serde/test_serde_e2e.py +++ b/tests/v1/distributed/serde/test_serde_e2e.py @@ -197,7 +197,7 @@ def test_store_and_prefetch_with_serde(self) -> None: write_and_wait_for_l2(sm, keys, layout) # Brief sleep so StoreController releases read locks after L2 store - time.sleep(0.1) + time.sleep(1) sm.clear() assert get_l1_object_count(sm) == 0 @@ -222,7 +222,7 @@ def test_no_memory_leak_after_full_cycle(self) -> None: keys = [make_object_key(i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() # Prefetch @@ -263,7 +263,7 @@ def test_store_and_prefetch_without_serde(self) -> None: keys = [make_object_key(i) for i in range(5)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() handle = sm.submit_prefetch_task(keys, layout) @@ -285,7 +285,7 @@ def test_no_memory_leak_without_serde(self) -> None: keys = [make_object_key(i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() handle = sm.submit_prefetch_task(keys, layout) @@ -318,7 +318,7 @@ def test_partial_prefix_with_serde(self) -> None: # Write only keys 0, 1, 3, 4 (skip 2) keys_to_write = [make_object_key(i) for i in [0, 1, 3, 4]] write_and_wait_for_l2(sm, keys_to_write, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() # Request all 5 keys — prefix should be 2 (gap at index 2) @@ -354,7 +354,7 @@ def test_repeated_cycles_no_leak(self) -> None: for cycle in range(5): keys = [make_object_key(cycle * 10 + i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() handle = sm.submit_prefetch_task(keys, layout) @@ -441,7 +441,7 @@ def _run_roundtrip( keys = [make_object_key(i) for i in range(num_keys)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(0.1) + time.sleep(1) sm.clear() assert get_l1_object_count(sm) == 0 diff --git a/tests/v1/multiprocess/test_blend_v3_load_store_opts.py b/tests/v1/multiprocess/test_blend_v3_load_store_opts.py index bcf7a73820..fd8047ed6a 100644 --- a/tests/v1/multiprocess/test_blend_v3_load_store_opts.py +++ b/tests/v1/multiprocess/test_blend_v3_load_store_opts.py @@ -230,25 +230,20 @@ def _build_fake_gpu_context(batch_size: int, num_groups: int): """Returns a MagicMock matching the minimal GPUCacheContext surface used by _apply_cb_rope_batched.""" gpu_context = MagicMock() - gpu_context.kv_layer_groups_manager.num_groups = num_groups + gpu_context.kv_layer_groups_manager.num_kernel_groups = num_groups # All groups: compress_ratio=1, kv_size=2. groups = [SimpleNamespace(compress_ratio=1) for _ in range(num_groups)] - gpu_context.kv_layer_groups_manager.kv_layer_groups = groups + gpu_context.kv_layer_groups_manager.kernel_groups = groups - # all_slots = [tmp_for_slot_0, ..., tmp_for_slot_{batch-1}] - # Each tmp shape: (2 kv, num_layers, slots_per_chunk, hidden_dim). + # Each per-(slot, group) buffer has shape + # (2 kv, num_layers, slots_per_chunk, hidden_dim). num_layers, slots_per_chunk, hidden_dim = 2, 4, 64 head_size = 32 - def _get_tmp_chunk_gpu_buffer_batched(batch_size, group_idx): - return [ - _FakeTensor((2, num_layers, slots_per_chunk, hidden_dim)) - for _ in range(batch_size) - ] + def _get_temp_kernel_group_buffer(batch_idx, kernel_group_idx): + return _FakeTensor((2, num_layers, slots_per_chunk, hidden_dim)) - gpu_context.get_tmp_chunk_gpu_buffer_batched.side_effect = ( - _get_tmp_chunk_gpu_buffer_batched - ) + gpu_context.get_temp_kernel_group_buffer.side_effect = _get_temp_kernel_group_buffer return gpu_context, head_size @@ -292,8 +287,10 @@ def repeat(self, n): eng._apply_cb_rope_batched(gpu_context, rope_state, 4, slots_to_rope) - # Per-group setup (get_tmp_chunk_gpu_buffer_batched) called once per group. - assert gpu_context.get_tmp_chunk_gpu_buffer_batched.call_count == 2 + # all_slots is built once per group (G=2), each fetching the full batch + # of slot buffers => batch_len(4) × G(2) = 8 buffer fetches, independent + # of how many slots are actually re-RoPE'd. + assert gpu_context.get_temp_kernel_group_buffer.call_count == 8 # Kernel called N=2 slots × G=2 groups = 4 times. assert ops.rotary_embedding_k_fused.call_count == 4 @@ -315,7 +312,7 @@ def test_batched_rope_noop_on_empty_slots(): with patch.object(v3_mod, "lmc_ops") as ops: eng._apply_cb_rope_batched(gpu_context, rope_state, 2, []) - assert gpu_context.get_tmp_chunk_gpu_buffer_batched.call_count == 0 + assert gpu_context.get_temp_kernel_group_buffer.call_count == 0 assert ops.rotary_embedding_k_fused.call_count == 0 @@ -325,8 +322,8 @@ def test_batched_rope_raises_on_compressed_layout(): from lmcache.v1.multiprocess.modules import blend_v3 as v3_mod gpu_context = MagicMock() - gpu_context.kv_layer_groups_manager.num_groups = 1 - gpu_context.kv_layer_groups_manager.kv_layer_groups = [ + gpu_context.kv_layer_groups_manager.num_kernel_groups = 1 + gpu_context.kv_layer_groups_manager.kernel_groups = [ SimpleNamespace(compress_ratio=2) ] rope_state = SimpleNamespace( diff --git a/tests/v1/multiprocess/test_gpu_context.py b/tests/v1/multiprocess/test_gpu_context.py index e7be624cc8..535cdf4ef3 100644 --- a/tests/v1/multiprocess/test_gpu_context.py +++ b/tests/v1/multiprocess/test_gpu_context.py @@ -1,14 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for GPUCacheContext.get_tmp_chunk_gpu_buffer, -get_tmp_chunk_gpu_buffer_batched and get_tmp_gpu_buffer_flat — verifying -contiguity, shape, non-overlapping guarantees, and multi-group layout. - -These tests construct a minimal GPUCacheContext-like object that has -just the fields the buffer methods need, avoiding the full KVCache / -CudaIPCWrapper construction. +"""Unit tests for the temp-GPU-buffer machinery in +``lmcache.v1.multiprocess.gpu_context``. + +Two layers are exercised: + +* ``_TempGPUBuffer`` -- the standalone buffer manager. It is built directly + from a real :class:`KVLayerGroupsManager` (its constructor is fully public), + so the layout invariants (per-kernel-group shape/dtype, per-object-group flat + views, non-overlap, write isolation, byte sizing) are tested in isolation. + +* ``GPUCacheContext`` -- the higher-level context that owns a ``_TempGPUBuffer`` + and exposes the per-kernel-group / per-object-group buffer accessors plus + ``get_kernel_group_kv_pointers``, ``calculate_num_blocks``, + ``kv_layer_groups_manager`` and ``report_status``. It is built through its + real public constructor using a lightweight ``to_tensor`` test double in place + of ``CudaIPCWrapper`` (same-process CUDA IPC cannot reimport its own handle). """ +# Standard +from collections.abc import Sequence + # Third Party import pytest import torch @@ -18,402 +30,486 @@ ) # First Party +from lmcache.v1.gpu_connector.utils import LayoutHints # noqa: E402 from lmcache.v1.kv_layer_groups import KVLayerGroupsManager # noqa: E402 -from lmcache.v1.multiprocess.gpu_context import GPUCacheContext # noqa: E402 +from lmcache.v1.multiprocess.gpu_context import ( # noqa: E402 + GPUCacheContext, + _TempGPUBuffer, +) import lmcache.c_ops as lmc_ops # noqa: E402 +_DEVICE = torch.device("cuda") + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _make_context( - num_layers: int = 4, - num_heads: int = 8, - head_size: int = 128, - is_mla: bool = False, - chunk_size: int = 256, - dtype: torch.dtype = torch.bfloat16, -) -> GPUCacheContext: - """Build a GPUCacheContext with a single KV layer group by directly - setting internal fields, bypassing the KVCache/IPC wrapper construction.""" - ctx = object.__new__(GPUCacheContext) - ctx.is_mla_ = is_mla - ctx.num_layers_ = num_layers - ctx.max_batch_size = 4 - - # Build a real KVLayerGroupsManager from synthetic tensors shaped to - # match the grouping signature the tests care about. - if is_mla: - kv_caches = [ - torch.empty(1, 1, head_size, dtype=dtype) for _ in range(num_layers) - ] - fmt = lmc_ops.GPUKVFormat.NL_X_NB_BS_HS - else: - kv_caches = [ - torch.empty(2, 1, 1, num_heads, head_size, dtype=dtype) - for _ in range(num_layers) - ] - fmt = lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS - manager = KVLayerGroupsManager( - kv_caches, fmt, num_blocks=1, lmcache_logical_chunk_size=chunk_size - ) - ctx.kv_layer_groups_manager_ = manager - - # Build flat tmp_gpu_buffer_ with prefix-sum offsets (new layout) - ctx.tmp_chunk_group_offsets_ = [0] - for gidx, grp in enumerate(manager.kv_layer_groups): - shape = ctx.get_kv_buffer_shape(chunk_size, gidx) - byte_size = shape.numel() * grp.dtype.itemsize - ctx.tmp_chunk_group_offsets_.append( - ctx.tmp_chunk_group_offsets_[-1] + byte_size - ) - ctx.tmp_chunk_bytes_ = ctx.tmp_chunk_group_offsets_[-1] - ctx.lmcache_logical_chunk_size = chunk_size - ctx.tmp_gpu_buffer_ = torch.empty( - ctx.tmp_chunk_bytes_ * ctx.max_batch_size, - dtype=torch.uint8, - device="cuda", +class _GroupSpec: + """Description of one homogeneous block of KV layers used to build the + synthetic ``[2, NB, BS, NH, HS]`` (non-MLA) tensors fed to the manager.""" + + def __init__( + self, + num_layers: int, + num_heads: int = 8, + head_size: int = 64, + block_size: int = 16, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.num_layers = num_layers + self.num_heads = num_heads + self.head_size = head_size + self.block_size = block_size + self.dtype = dtype + + +def _make_kv_tensors( + specs: Sequence[_GroupSpec], + num_blocks: int = 4, +) -> list[torch.Tensor]: + """Build non-MLA per-layer KV tensors shaped ``[2, NB, BS, NH, HS]``.""" + tensors: list[torch.Tensor] = [] + for spec in specs: + for _ in range(spec.num_layers): + tensors.append( + torch.empty( + 2, + num_blocks, + spec.block_size, + spec.num_heads, + spec.head_size, + dtype=spec.dtype, + device=_DEVICE, + ) + ) + return tensors + + +def _build_manager( + tensors: list[torch.Tensor], + num_blocks: int = 4, + gpu_kv_format: "lmc_ops.GPUKVFormat" = lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS, + layout_hints: LayoutHints | None = None, +) -> KVLayerGroupsManager: + """Build a real :class:`KVLayerGroupsManager` from synthetic tensors.""" + return KVLayerGroupsManager( + tensors, + gpu_kv_format=gpu_kv_format, + num_blocks=num_blocks, + layout_hints=layout_hints, ) - return ctx -def _make_context_multi_group( - groups: list[dict], +def _make_temp_buffer( + specs: Sequence[_GroupSpec], chunk_size: int = 256, - is_mla: bool = False, -) -> GPUCacheContext: - """Build a GPUCacheContext with multiple KV layer groups. - - Args: - groups: List of dicts, each with keys: - - num_layers (int) - - num_heads (int) - - head_size (int) - - dtype (torch.dtype, optional, default bfloat16) - chunk_size: Tokens per chunk. - is_mla: Whether to use MLA (kv_dim=1) layout. - """ - assert not is_mla, "multi-group helper only exercises the non-MLA path" - ctx = object.__new__(GPUCacheContext) - ctx.is_mla_ = is_mla - ctx.max_batch_size = 4 - - kv_caches: list[torch.Tensor] = [] - for g in groups: - nl = g["num_layers"] - nh = g["num_heads"] - hs = g["head_size"] - dt = g.get("dtype", torch.bfloat16) - kv_caches.extend(torch.empty(2, 1, 1, nh, hs, dtype=dt) for _ in range(nl)) - - ctx.num_layers_ = len(kv_caches) - manager = KVLayerGroupsManager( - kv_caches, - lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS, - num_blocks=1, + max_batch_size: int = 4, + num_blocks: int = 4, + layout_hints: LayoutHints | None = None, +) -> _TempGPUBuffer: + """Build a ``_TempGPUBuffer`` backed by a real manager.""" + tensors = _make_kv_tensors(specs, num_blocks=num_blocks) + manager = _build_manager(tensors, num_blocks=num_blocks, layout_hints=layout_hints) + return _TempGPUBuffer( + kv_layer_groups_manager=manager, lmcache_logical_chunk_size=chunk_size, + device=_DEVICE, + max_batch_size=max_batch_size, ) - ctx.kv_layer_groups_manager_ = manager - - # Build flat tmp_gpu_buffer_ with prefix-sum offsets - ctx.tmp_chunk_group_offsets_ = [0] - for gidx, grp in enumerate(manager.kv_layer_groups): - shape = ctx.get_kv_buffer_shape(chunk_size, gidx) - byte_size = shape.numel() * grp.dtype.itemsize - ctx.tmp_chunk_group_offsets_.append( - ctx.tmp_chunk_group_offsets_[-1] + byte_size + + +def _expected_kernel_group_shape( + manager: KVLayerGroupsManager, num_tokens: int, kernel_group_idx: int +) -> torch.Size: + """Compute the expected kernel-group buffer shape from the manager's + public metadata (kv_size, num_layers, slots, hidden_dim).""" + group = manager.kernel_groups[kernel_group_idx] + num_slots = num_tokens // group.compress_ratio + return torch.Size( + ( + group.shape_desc.kv_size, + group.num_layers, + num_slots, + group.hidden_dim_size, ) - ctx.tmp_chunk_bytes_ = ctx.tmp_chunk_group_offsets_[-1] - ctx.lmcache_logical_chunk_size = chunk_size - ctx.tmp_gpu_buffer_ = torch.empty( - ctx.tmp_chunk_bytes_ * ctx.max_batch_size, - dtype=torch.uint8, - device="cuda", ) - return ctx -# --------------------------------------------------------------------------- -# get_tmp_chunk_gpu_buffer tests -# --------------------------------------------------------------------------- +def _expected_kernel_group_bytes( + manager: KVLayerGroupsManager, chunk_size: int, kernel_group_idx: int +) -> int: + """Byte size of one kernel group's per-chunk buffer.""" + group = manager.kernel_groups[kernel_group_idx] + shape = _expected_kernel_group_shape(manager, chunk_size, kernel_group_idx) + return shape.numel() * group.dtype.itemsize -class TestGetTmpChunkGpuBuffer: - def test_contiguity(self) -> None: - ctx = _make_context(chunk_size=256) - buf = ctx.get_tmp_chunk_gpu_buffer() - assert buf.is_contiguous(), "Buffer not contiguous" +def _byte_region(buf: torch.Tensor) -> tuple[int, int]: + """Return ``(start_ptr, end_ptr)`` covering a tensor's bytes.""" + start = buf.data_ptr() + return start, start + buf.nelement() * buf.element_size() - def test_shape(self) -> None: - ctx = _make_context(chunk_size=256) - buf = ctx.get_tmp_chunk_gpu_buffer() - expected = ctx.get_kv_buffer_shape(256) - assert buf.shape == expected - def test_shape_mla(self) -> None: - ctx = _make_context(is_mla=True, num_heads=1, head_size=576, chunk_size=256) - buf = ctx.get_tmp_chunk_gpu_buffer() - expected = ctx.get_kv_buffer_shape(256) - assert buf.shape == expected - assert buf.shape[0] == 1 # kv_dim=1 for MLA +def _assert_disjoint(regions: list[tuple[int, int, str]]) -> None: + """Assert that no two ``(start, end, label)`` byte ranges overlap.""" + for i in range(len(regions)): + for j in range(i + 1, len(regions)): + s_i, e_i, label_i = regions[i] + s_j, e_j, label_j = regions[j] + assert e_i <= s_j or e_j <= s_i, f"Overlap between {label_i} and {label_j}" - def test_repeated_calls_same_ptr(self) -> None: - """Two calls should return the same base pointer (same pre-allocated slot).""" - ctx = _make_context(chunk_size=256) - buf1 = ctx.get_tmp_chunk_gpu_buffer() - buf2 = ctx.get_tmp_chunk_gpu_buffer() - assert buf1.data_ptr() == buf2.data_ptr() - - def test_write_read_roundtrip(self) -> None: - """Write a pattern, read it back to verify the view is correct.""" - ctx = _make_context(num_layers=2, num_heads=2, head_size=16, chunk_size=32) - buf = ctx.get_tmp_chunk_gpu_buffer() - buf.fill_(42.0) - assert buf.to(torch.float32).sum().item() == pytest.approx( - 42.0 * buf.numel(), rel=1e-3 - ) +class _FakeIPCWrapper: + """Test-only stand-in for ``CudaIPCWrapper``. -# --------------------------------------------------------------------------- -# get_tmp_chunk_gpu_buffer_batched tests -# --------------------------------------------------------------------------- + ``GPUCacheContext`` only needs ``to_tensor()`` from each entry of its + ``kv_caches`` argument. Same-process CUDA IPC cannot reopen its own handle, + so this test double simply hands back a locally allocated CUDA tensor, + letting the real ``GPUCacheContext`` constructor run end to end. + """ + def __init__(self, tensor: torch.Tensor) -> None: + self._tensor = tensor -class TestGetTmpChunkGpuBufferBatched: - @pytest.mark.parametrize("batch_size", [1, 2, 3, 4]) - def test_contiguity(self, batch_size: int) -> None: - ctx = _make_context(chunk_size=256) - buffers = ctx.get_tmp_chunk_gpu_buffer_batched(batch_size) - assert len(buffers) == batch_size - for i, buf in enumerate(buffers): - assert buf.is_contiguous(), f"Buffer {i} not contiguous" - - @pytest.mark.parametrize("batch_size", [1, 2, 3, 4]) - def test_shapes(self, batch_size: int) -> None: - ctx = _make_context(chunk_size=256) - buffers = ctx.get_tmp_chunk_gpu_buffer_batched(batch_size) - expected_shape = ctx.get_kv_buffer_shape(256) - for buf in buffers: - assert buf.shape == expected_shape - - @pytest.mark.parametrize("batch_size", [2, 3, 4]) - def test_non_overlapping(self, batch_size: int) -> None: - """Buffers in a batch must not overlap in memory.""" - ctx = _make_context(chunk_size=256) - buffers = ctx.get_tmp_chunk_gpu_buffer_batched(batch_size) - for i in range(len(buffers)): - for j in range(i + 1, len(buffers)): - start_i = buffers[i].data_ptr() - end_i = start_i + buffers[i].nelement() * buffers[i].element_size() - start_j = buffers[j].data_ptr() - end_j = start_j + buffers[j].nelement() * buffers[j].element_size() - assert end_i <= start_j or end_j <= start_i, ( - f"Buffers {i} and {j} overlap" - ) + def to_tensor(self) -> torch.Tensor: + """Return the wrapped local CUDA tensor (test-only).""" + return self._tensor - def test_write_isolation(self) -> None: - """Writing to one buffer must not affect another.""" - ctx = _make_context(num_layers=2, num_heads=2, head_size=16, chunk_size=32) - buffers = ctx.get_tmp_chunk_gpu_buffer_batched(4) - - # Write distinct values to each buffer - for i, buf in enumerate(buffers): - buf.fill_(float(i + 1)) - - # Verify each buffer has its own value - for i, buf in enumerate(buffers): - expected = float(i + 1) - assert buf.to(torch.float32).min().item() == pytest.approx( - expected, rel=1e-3 - ) - assert buf.to(torch.float32).max().item() == pytest.approx( - expected, rel=1e-3 - ) - def test_batch_exceeds_max_raises(self) -> None: - ctx = _make_context(chunk_size=256) - with pytest.raises(ValueError, match="exceeds max"): - ctx.get_tmp_chunk_gpu_buffer_batched(5) - - @pytest.mark.parametrize("batch_size", [1, 2, 3, 4]) - def test_mla(self, batch_size: int) -> None: - ctx = _make_context(is_mla=True, num_heads=1, head_size=576, chunk_size=256) - buffers = ctx.get_tmp_chunk_gpu_buffer_batched(batch_size) - for buf in buffers: - assert buf.is_contiguous() - assert buf.shape[0] == 1 # kv_dim=1 for MLA - - def test_consistent_with_single(self) -> None: - """get_tmp_chunk_gpu_buffer_batched(1)[0] should have the same data_ptr - and shape as get_tmp_chunk_gpu_buffer().""" - ctx = _make_context(chunk_size=256) - single = ctx.get_tmp_chunk_gpu_buffer() - batched = ctx.get_tmp_chunk_gpu_buffer_batched(1) - assert len(batched) == 1 - assert batched[0].data_ptr() == single.data_ptr() - assert batched[0].shape == single.shape +def _make_context( + specs: Sequence[_GroupSpec], + chunk_size: int = 256, + num_blocks: int = 4, + layout_hints: LayoutHints | None = None, +) -> GPUCacheContext: + """Build a real ``GPUCacheContext`` via its public constructor.""" + tensors = _make_kv_tensors(specs, num_blocks=num_blocks) + kv_caches = [_FakeIPCWrapper(t) for t in tensors] + return GPUCacheContext( + kv_caches, # type: ignore + lmcache_logical_chunk_size=chunk_size, + layout_hints=layout_hints, + ) + + +# Common group layouts reused across tests. +_SINGLE_GROUP = [_GroupSpec(num_layers=4, num_heads=8, head_size=64)] +_MULTI_GROUP = [ + _GroupSpec(num_layers=4, num_heads=8, head_size=64, dtype=torch.bfloat16), + _GroupSpec(num_layers=2, num_heads=16, head_size=64, dtype=torch.float16), +] # --------------------------------------------------------------------------- -# Multi-group tests +# _TempGPUBuffer tests # --------------------------------------------------------------------------- -class TestMultiGroup: - """Tests for multi-group flat buffer layout.""" - - GROUPS_SAME_DTYPE = [ - {"num_layers": 4, "num_heads": 8, "head_size": 128, "dtype": torch.bfloat16}, - {"num_layers": 4, "num_heads": 8, "head_size": 128, "dtype": torch.bfloat16}, - ] - GROUPS_DIFF_DTYPE = [ - {"num_layers": 4, "num_heads": 8, "head_size": 128, "dtype": torch.bfloat16}, - {"num_layers": 2, "num_heads": 4, "head_size": 64, "dtype": torch.float16}, - ] - - def test_prefix_sum_length(self) -> None: - """tmp_chunk_group_offsets_ should have num_groups+1 entries.""" - ctx = _make_context_multi_group(self.GROUPS_SAME_DTYPE) - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - assert len(ctx.tmp_chunk_group_offsets_) == num_groups + 1 - - def test_prefix_sum_monotone(self) -> None: - """Offsets must be strictly increasing.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE) - offsets = ctx.tmp_chunk_group_offsets_ - for i in range(1, len(offsets)): - assert offsets[i] > offsets[i - 1], ( - f"Offset not increasing at index {i}: {offsets}" - ) +class TestTempGPUBufferConstruction: + def test_max_batch_size_property(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP, max_batch_size=3) + assert buf.max_batch_size == 3 - def test_flat_buffer_total_size(self) -> None: - """tmp_gpu_buffer_ byte count == tmp_chunk_bytes_ * max_batch_size.""" - ctx = _make_context_multi_group(self.GROUPS_SAME_DTYPE) - assert ctx.tmp_gpu_buffer_.numel() == ctx.tmp_chunk_bytes_ * ctx.max_batch_size - - def test_groups_non_overlapping_in_chunk(self) -> None: - """Within a single chunk, different groups must occupy disjoint byte ranges.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE) - offsets = ctx.tmp_chunk_group_offsets_ - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - for i in range(num_groups): - for j in range(i + 1, num_groups): - # [offsets[i], offsets[i+1]) vs [offsets[j], offsets[j+1]) - assert offsets[i + 1] <= offsets[j] or offsets[j + 1] <= offsets[i], ( - f"Groups {i} and {j} overlap in chunk layout" - ) - def test_get_tmp_chunk_gpu_buffer_shape_per_group(self) -> None: - """get_tmp_chunk_gpu_buffer returns the correct shape for each group.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE, chunk_size=256) - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - for gidx in range(num_groups): - buf = ctx.get_tmp_chunk_gpu_buffer(group_idx=gidx) - expected = ctx.get_kv_buffer_shape(256, gidx) - assert buf.shape == expected, ( - f"Group {gidx}: expected {expected}, got {buf.shape}" - ) +class TestTempGPUBufferKernelGroupBuffer: + def test_shape_and_dtype(self) -> None: + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + buf = _TempGPUBuffer(manager, 256, _DEVICE) + for kg in range(manager.num_kernel_groups): + tensor = buf.get_temp_kernel_group_buffer(0, kg) + assert tensor.shape == _expected_kernel_group_shape(manager, 256, kg) + assert tensor.dtype == manager.kernel_groups[kg].dtype - def test_get_tmp_chunk_gpu_buffer_dtype_per_group(self) -> None: - """get_tmp_chunk_gpu_buffer returns the correct dtype for each group.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE, chunk_size=256) - groups = ctx.kv_layer_groups_manager_.kv_layer_groups - for gidx, grp in enumerate(groups): - buf = ctx.get_tmp_chunk_gpu_buffer(group_idx=gidx) - assert buf.dtype == grp.dtype, ( - f"Group {gidx}: expected dtype {grp.dtype}, got {buf.dtype}" + def test_contiguous(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP) + assert buf.get_temp_kernel_group_buffer(0, 0).is_contiguous() + + def test_repeated_calls_same_ptr(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP) + first = buf.get_temp_kernel_group_buffer(1, 0) + second = buf.get_temp_kernel_group_buffer(1, 0) + assert first.data_ptr() == second.data_ptr() + + def test_invalid_batch_idx_raises(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP, max_batch_size=4) + with pytest.raises(ValueError, match="Invalid batch_idx"): + buf.get_temp_kernel_group_buffer(4, 0) + + def test_invalid_kernel_group_idx_raises(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP) + with pytest.raises(ValueError, match="kernel_group_idx"): + buf.get_temp_kernel_group_buffer(0, 99) + + def test_buffers_non_overlapping(self) -> None: + """Every (batch, kernel_group) buffer occupies disjoint memory.""" + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + max_batch_size = 4 + buf = _TempGPUBuffer(manager, 256, _DEVICE, max_batch_size=max_batch_size) + regions: list[tuple[int, int, str]] = [] + for batch in range(max_batch_size): + for kg in range(manager.num_kernel_groups): + tensor = buf.get_temp_kernel_group_buffer(batch, kg) + start, end = _byte_region(tensor) + regions.append((start, end, f"batch={batch},kg={kg}")) + _assert_disjoint(regions) + + def test_write_isolation(self) -> None: + """Writing to one batch slot must not corrupt another.""" + buf = _make_temp_buffer( + [_GroupSpec(num_layers=2, num_heads=2, head_size=16)], + chunk_size=32, + max_batch_size=4, + ) + for batch in range(4): + buf.get_temp_kernel_group_buffer(batch, 0).fill_(float(batch + 1)) + for batch in range(4): + tensor = buf.get_temp_kernel_group_buffer(batch, 0).to(torch.float32) + assert tensor.min().item() == pytest.approx(batch + 1, rel=1e-3) + assert tensor.max().item() == pytest.approx(batch + 1, rel=1e-3) + + +class TestTempGPUBufferObjectGroupBuffer: + def test_flat_uint8(self) -> None: + buf = _make_temp_buffer(_MULTI_GROUP) + tensor = buf.get_temp_object_group_buffer(0, 0) + assert tensor.dtype == torch.uint8 + assert tensor.dim() == 1 + assert tensor.is_contiguous() + + def test_size_covers_all_kernel_groups(self) -> None: + """The single object group's flat buffer spans every kernel group's + bytes for one chunk.""" + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + chunk_size = 256 + buf = _TempGPUBuffer(manager, chunk_size, _DEVICE) + obj_group = manager.object_groups[0] + expected_bytes = sum( + _expected_kernel_group_bytes(manager, chunk_size, kg) + for kg in obj_group.kernel_group_indices + ) + assert buf.get_temp_object_group_buffer(0, 0).numel() == expected_bytes + + def test_starts_at_first_kernel_group(self) -> None: + """The object-group flat view aliases the same memory as its first + kernel group's buffer.""" + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + buf = _TempGPUBuffer(manager, 256, _DEVICE) + first_kg = manager.object_groups[0].kernel_group_indices[0] + obj_buf = buf.get_temp_object_group_buffer(0, 0) + kg_buf = buf.get_temp_kernel_group_buffer(0, first_kg) + assert obj_buf.data_ptr() == kg_buf.data_ptr() + + def test_invalid_indices_raise(self) -> None: + buf = _make_temp_buffer(_SINGLE_GROUP, max_batch_size=4) + with pytest.raises(ValueError, match="object_group_idx"): + buf.get_temp_object_group_buffer(0, 99) + with pytest.raises(ValueError, match="batch_idx"): + buf.get_temp_object_group_buffer(4, 0) + + def test_contains_kernel_group_data(self) -> None: + """Bytes written through kernel-group views are visible through the + object-group flat view at matching offsets.""" + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + chunk_size = 64 + buf = _TempGPUBuffer(manager, chunk_size, _DEVICE) + obj_group = manager.object_groups[0] + + for offset_kg, kg in enumerate(obj_group.kernel_group_indices): + buf.get_temp_kernel_group_buffer(0, kg).view(torch.uint8).fill_( + offset_kg + 1 ) - def test_groups_data_ptr_matches_offsets(self) -> None: - """data_ptr of each group's buffer should equal base + group offset.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE, chunk_size=256) - base_ptr = ctx.tmp_gpu_buffer_.data_ptr() - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - for gidx in range(num_groups): - buf = ctx.get_tmp_chunk_gpu_buffer(group_idx=gidx) - expected_ptr = base_ptr + ctx.tmp_chunk_group_offsets_[gidx] - assert buf.data_ptr() == expected_ptr, ( - f"Group {gidx}: expected ptr offset " - f"{ctx.tmp_chunk_group_offsets_[gidx]}, " - f"got {buf.data_ptr() - base_ptr}" + flat = buf.get_temp_object_group_buffer(0, 0) + cursor = 0 + for offset_kg, kg in enumerate(obj_group.kernel_group_indices): + size = _expected_kernel_group_bytes(manager, chunk_size, kg) + region = flat[cursor : cursor + size] + assert region.min().item() == offset_kg + 1 + assert region.max().item() == offset_kg + 1 + cursor += size + + def test_object_groups_non_overlapping(self) -> None: + """Object-group buffers across batch slots occupy disjoint memory.""" + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + max_batch_size = 4 + buf = _TempGPUBuffer(manager, 256, _DEVICE, max_batch_size=max_batch_size) + regions: list[tuple[int, int, str]] = [] + for batch in range(max_batch_size): + for og in range(manager.num_object_groups): + start, end = _byte_region(buf.get_temp_object_group_buffer(batch, og)) + regions.append((start, end, f"batch={batch},og={og}")) + _assert_disjoint(regions) + + +class TestTempGPUBufferShapeDtype: + def test_shape_scales_with_num_tokens(self) -> None: + tensors = _make_kv_tensors(_SINGLE_GROUP) + manager = _build_manager(tensors) + buf = _TempGPUBuffer(manager, 256, _DEVICE) + for num_tokens in (16, 128, 256): + shape, dtype = buf.get_kernel_group_shape_dtype(num_tokens, 0) + assert shape == _expected_kernel_group_shape(manager, num_tokens, 0) + assert dtype == manager.kernel_groups[0].dtype + + def test_shape_compressed_group(self) -> None: + """For a compressed group, the token dim is divided by compress_ratio.""" + tensors = _make_kv_tensors([_GroupSpec(num_layers=2, block_size=8)]) + manager = _build_manager( + tensors, layout_hints={"inference_engine_logical_block_size": 16} + ) + assert manager.kernel_groups[0].compress_ratio == 2 + buf = _TempGPUBuffer(manager, 256, _DEVICE) + shape, _ = buf.get_kernel_group_shape_dtype(256, 0) + assert shape[2] == 256 // 2 + + def test_not_divisible_by_compress_ratio_raises(self) -> None: + tensors = _make_kv_tensors([_GroupSpec(num_layers=2, block_size=8)]) + manager = _build_manager( + tensors, layout_hints={"inference_engine_logical_block_size": 16} + ) + buf = _TempGPUBuffer(manager, 256, _DEVICE) + with pytest.raises(ValueError, match="not a multiple of"): + buf.get_kernel_group_shape_dtype(255, 0) + + +class TestTempGPUBufferCacheSize: + def test_cache_size_per_token(self) -> None: + tensors = _make_kv_tensors(_MULTI_GROUP) + manager = _build_manager(tensors) + chunk_size = 256 + buf = _TempGPUBuffer(manager, chunk_size, _DEVICE) + expected = ( + sum( + _expected_kernel_group_bytes(manager, chunk_size, kg) + for kg in range(manager.num_kernel_groups) ) + // chunk_size + ) + assert buf.get_cache_size_per_token() == expected + + def test_cache_size_per_token_compressed(self) -> None: + """Compression halves per-physical-slot bytes, so the per-logical-token + size of a 2x-compressed group is half its uncompressed counterpart.""" + uncompressed = _make_temp_buffer([_GroupSpec(num_layers=2, block_size=16)]) + compressed = _make_temp_buffer( + [_GroupSpec(num_layers=2, block_size=8)], + layout_hints={"inference_engine_logical_block_size": 16}, + ) + assert ( + compressed.get_cache_size_per_token() * 2 + == uncompressed.get_cache_size_per_token() + ) - def test_write_isolation_across_groups(self) -> None: - """Writing to one group's buffer must not corrupt another group.""" - ctx = _make_context_multi_group(self.GROUPS_SAME_DTYPE, chunk_size=64) - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - buffers = [ctx.get_tmp_chunk_gpu_buffer(group_idx=g) for g in range(num_groups)] - - for i, buf in enumerate(buffers): - buf.fill_(float(i + 1)) - - for i, buf in enumerate(buffers): - expected = float(i + 1) - assert buf.to(torch.float32).min().item() == pytest.approx( - expected, rel=1e-3 - ), f"Group {i} was corrupted" - assert buf.to(torch.float32).max().item() == pytest.approx( - expected, rel=1e-3 - ), f"Group {i} was corrupted" - - @pytest.mark.parametrize("batch_size", [1, 2, 4]) - def test_batched_non_overlapping_across_groups_and_chunks( - self, batch_size: int - ) -> None: - """All (group, chunk_idx) combinations must occupy disjoint memory.""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE, chunk_size=256) - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - # Collect (data_ptr, end_ptr) for every (group, chunk) combination - regions: list[tuple[int, int, str]] = [] - for gidx in range(num_groups): - bufs = ctx.get_tmp_chunk_gpu_buffer_batched(batch_size, group_idx=gidx) - for cidx, buf in enumerate(bufs): - start = buf.data_ptr() - end = start + buf.nelement() * buf.element_size() - regions.append((start, end, f"group={gidx},chunk={cidx}")) - - for i in range(len(regions)): - for j in range(i + 1, len(regions)): - s_i, e_i, label_i = regions[i] - s_j, e_j, label_j = regions[j] - assert e_i <= s_j or e_j <= s_i, ( - f"Overlap between {label_i} and {label_j}" - ) +# --------------------------------------------------------------------------- +# GPUCacheContext tests +# --------------------------------------------------------------------------- - def test_flat_buffer_covers_all_groups(self) -> None: - """get_tmp_gpu_buffer_flat covers the full chunk (all groups).""" - ctx = _make_context_multi_group(self.GROUPS_DIFF_DTYPE, chunk_size=256) - flat = ctx.get_tmp_gpu_buffer_flat(chunk_idx=0) - assert flat.numel() == ctx.tmp_chunk_bytes_ - assert flat.dtype == torch.uint8 - - def test_flat_buffer_chunk_idx_raises(self) -> None: - """chunk_idx >= max_batch_size should raise ValueError.""" - ctx = _make_context_multi_group(self.GROUPS_SAME_DTYPE) - with pytest.raises(ValueError, match="exceeds max_batch_size"): - ctx.get_tmp_gpu_buffer_flat(chunk_idx=ctx.max_batch_size) - - def test_flat_buffer_contains_group_data(self) -> None: - """Data written via get_tmp_chunk_gpu_buffer should be visible in flat view.""" - ctx = _make_context_multi_group(self.GROUPS_SAME_DTYPE, chunk_size=64) - num_groups = len(ctx.kv_layer_groups_manager_.kv_layer_groups) - - # Fill each group with a distinct byte value - for gidx in range(num_groups): - buf = ctx.get_tmp_chunk_gpu_buffer(group_idx=gidx) - # Use view(torch.uint8) to fill raw bytes - buf.view(torch.uint8).fill_(gidx + 1) - - flat = ctx.get_tmp_gpu_buffer_flat(chunk_idx=0) - for gidx in range(num_groups): - g_start = ctx.tmp_chunk_group_offsets_[gidx] - g_end = ctx.tmp_chunk_group_offsets_[gidx + 1] - region = flat[g_start:g_end] - assert region.min().item() == gidx + 1, ( - f"Group {gidx} flat region has wrong min value" - ) - assert region.max().item() == gidx + 1, ( - f"Group {gidx} flat region has wrong max value" + +class TestGPUCacheContextBuffers: + def test_max_batch_size(self) -> None: + ctx = _make_context(_SINGLE_GROUP) + assert ctx.max_batch_size == 4 + + def test_kv_layer_groups_manager(self) -> None: + ctx = _make_context(_MULTI_GROUP) + manager = ctx.kv_layer_groups_manager + assert isinstance(manager, KVLayerGroupsManager) + assert manager.num_kernel_groups == 2 + + def test_get_temp_kernel_group_buffer(self) -> None: + ctx = _make_context(_MULTI_GROUP) + manager = ctx.kv_layer_groups_manager + for kg in range(manager.num_kernel_groups): + tensor = ctx.get_temp_kernel_group_buffer(0, kg) + assert tensor.shape == _expected_kernel_group_shape(manager, 256, kg) + assert tensor.dtype == manager.kernel_groups[kg].dtype + + def test_get_temp_object_group_buffer(self) -> None: + ctx = _make_context(_MULTI_GROUP) + tensor = ctx.get_temp_object_group_buffer(0, 0) + assert tensor.dtype == torch.uint8 + assert tensor.dim() == 1 + + def test_get_kernel_group_shape_dtype(self) -> None: + ctx = _make_context(_SINGLE_GROUP) + manager = ctx.kv_layer_groups_manager + shape, dtype = ctx.get_kernel_group_shape_dtype(128, 0) + assert shape == _expected_kernel_group_shape(manager, 128, 0) + assert dtype == manager.kernel_groups[0].dtype + + +class TestGPUCacheContextPointers: + def test_get_kernel_group_kv_pointers(self) -> None: + ctx = _make_context(_MULTI_GROUP) + manager = ctx.kv_layer_groups_manager + for kg in range(manager.num_kernel_groups): + pointers = ctx.get_kernel_group_kv_pointers(kg) + assert pointers.dtype == torch.long + # One pointer per layer in the group. + assert pointers.numel() == manager.kernel_groups[kg].num_layers + + +class TestGPUCacheContextBlocks: + def test_calculate_num_blocks_uncompressed(self) -> None: + # block_size=16, compress_ratio=1 -> 256 tokens span 16 blocks. + ctx = _make_context([_GroupSpec(num_layers=2, block_size=16)]) + assert ctx.calculate_num_blocks(256, 0) == 16 + + def test_calculate_num_blocks_matches_manager(self) -> None: + ctx = _make_context(_MULTI_GROUP) + manager = ctx.kv_layer_groups_manager + for kg in range(manager.num_kernel_groups): + assert ctx.calculate_num_blocks(256, kg) == manager.calculate_num_blocks( + kg, 256 ) + + +class TestGPUCacheContextReportStatus: + def test_report_status_fields(self) -> None: + ctx = _make_context(_SINGLE_GROUP) + status = ctx.report_status() + + expected_keys = { + "num_layers", + "inference_engine_logical_block_size", + "group_physical_block_sizes", + "group_compress_ratios", + "hidden_dim_sizes", + "dtype", + "is_mla", + "num_blocks", + "gpu_kv_format", + "gpu_kv_shape", + "gpu_kv_concrete_shape", + "attention_backend", + "cache_size_per_token", + } + assert set(status.keys()) == expected_keys + + assert status["num_layers"] == 4 + assert status["is_mla"] is False + assert status["group_compress_ratios"] == [1] + assert status["gpu_kv_format"] == "NL_X_TWO_NB_BS_NH_HS" + assert status["dtype"] == str(ctx.dtype) + assert status["cache_size_per_token"] == ctx.cache_size_per_token() + + def test_report_status_multi_group(self) -> None: + ctx = _make_context(_MULTI_GROUP) + manager = ctx.kv_layer_groups_manager + status = ctx.report_status() + assert status["num_layers"] == 6 + assert len(status["group_physical_block_sizes"]) == manager.num_kernel_groups + assert len(status["group_compress_ratios"]) == manager.num_kernel_groups + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py b/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py index f2853f6260..8ab2470bef 100644 --- a/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py +++ b/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py @@ -81,6 +81,7 @@ def fake_create_cache_context( def fake_layout_desc( gpu_context: _FakeGPUContext, num_tokens: int, + object_group_id: int = 0, ) -> MemoryLayoutDesc: """Return the shared layout descriptor used by both registrations.""" return layout_desc diff --git a/tests/v1/test_kv_layer_groups_manager.py b/tests/v1/test_kv_layer_groups_manager.py index 8dd259a562..e8cc99bf0d 100644 --- a/tests/v1/test_kv_layer_groups_manager.py +++ b/tests/v1/test_kv_layer_groups_manager.py @@ -9,8 +9,13 @@ # First Party from lmcache.v1.gpu_connector.utils import LayoutHints from lmcache.v1.kv_layer_groups import ( + EXCLUDED_ENGINE_GROUP, + KernelGroupIdentity, + KernelGroupInfo, KVLayerGroupInfo, KVLayerGroupsManager, + LayerGroupIdentity, + ObjectGroupInfo, format_kvcache_shape_spec, parse_kvcache_shape_spec, ) @@ -52,14 +57,14 @@ class TestKVLayerGroupsManager: def test_build_empty(self): manager = _build_manager([], num_blocks=32) - assert manager.kv_layer_groups == [] + assert manager.kernel_groups == [] def test_build_single_layer(self): tensors = [torch.randn(2, 32, 256, 8, 64, dtype=torch.float16)] manager = _build_manager(tensors, num_blocks=32) - assert len(manager.kv_layer_groups) == 1 - group = manager.kv_layer_groups[0] + assert len(manager.kernel_groups) == 1 + group = manager.kernel_groups[0] assert isinstance(group, KVLayerGroupInfo) assert group.layer_indices == [0] assert group.shape_desc.kv_size == 2 @@ -76,8 +81,8 @@ def test_build_multiple_layers_same_shape(self): ] manager = _build_manager(tensors, num_blocks=32) - assert len(manager.kv_layer_groups) == 1 - group = manager.kv_layer_groups[0] + assert len(manager.kernel_groups) == 1 + group = manager.kernel_groups[0] assert group.layer_indices == [0, 1, 2] assert group.shape_desc.nl == 3 assert group.shape_desc.nh == 8 @@ -96,9 +101,9 @@ def test_build_splits_same_shape_by_engine_group_idx(self): ], ) - assert len(manager.kv_layer_groups) == 2 + assert len(manager.kernel_groups) == 2 groups_by_engine_group_idx = { - group.engine_group_idx: group for group in manager.kv_layer_groups + group.engine_group_idx: group for group in manager.kernel_groups } assert groups_by_engine_group_idx[0].layer_indices == [0, 2] assert groups_by_engine_group_idx[1].layer_indices == [1, 3] @@ -121,8 +126,8 @@ def test_build_different_shapes(self): torch.randn(2, 32, 256, 8, 64, dtype=torch.float16), ] manager = _build_manager(tensors, num_blocks=32) - assert len(manager.kv_layer_groups) == 2 - group1, group2 = manager.kv_layer_groups + assert len(manager.kernel_groups) == 2 + group1, group2 = manager.kernel_groups assert group1.layer_indices == [0, 2] assert group1.shape_desc.nh == 8 assert group2.layer_indices == [1] @@ -135,8 +140,8 @@ def test_build_different_dtypes(self): torch.randn(2, 32, 256, 8, 64, dtype=torch.float16), ] manager = _build_manager(tensors, num_blocks=32) - assert len(manager.kv_layer_groups) == 2 - group1, group2 = manager.kv_layer_groups + assert len(manager.kernel_groups) == 2 + group1, group2 = manager.kernel_groups assert group1.layer_indices == [0, 2] assert group1.dtype == torch.float16 assert group2.layer_indices == [1] @@ -151,9 +156,9 @@ def test_build_mixed_differences(self): torch.randn(2, 32, 256, 16, 64, dtype=torch.float32), # nh=16, f32 ] manager = _build_manager(tensors, num_blocks=32) - assert len(manager.kv_layer_groups) == 4 + assert len(manager.kernel_groups) == 4 - groups_by_key = {(g.shape_desc.nh, g.dtype): g for g in manager.kv_layer_groups} + groups_by_key = {(g.shape_desc.nh, g.dtype): g for g in manager.kernel_groups} assert groups_by_key[(8, torch.float16)].layer_indices == [0, 3] assert groups_by_key[(8, torch.float32)].layer_indices == [1] assert groups_by_key[(16, torch.float16)].layer_indices == [2] @@ -306,5 +311,106 @@ def test_not_divisible_raises(self): self._derive(bs=6, logical=16) +class TestKernelGroupIdentity: + """The grouping key is a named tuple; ``LayerGroupIdentity`` is its alias.""" + + def test_fields_and_alias(self): + ident = KernelGroupIdentity( + kv_size=2, + num_heads=8, + head_size=64, + block_size=16, + engine_group_idx=0, + dtype=torch.float16, + ) + assert ident.kv_size == 2 + assert ident.num_heads == 8 + assert ident.head_size == 64 + assert ident.block_size == 16 + assert ident.engine_group_idx == 0 + assert ident.dtype == torch.float16 + assert LayerGroupIdentity is KernelGroupIdentity + + def test_hashable_as_dict_key(self): + ident = KernelGroupIdentity(2, 8, 64, 16, 0, torch.float16) + assert {ident: "x"}[ident] == "x" + + def test_excluded_engine_group_sentinel(self): + assert EXCLUDED_ENGINE_GROUP == -1 + + +class TestKernelAndObjectGroups: + """Kernel-group accessors, deprecated aliases, and the (currently single) + object-group layout.""" + + def test_kernel_groups_match_deprecated_alias(self): + tensors = [ + torch.randn(2, 32, 256, 8, 64, dtype=torch.float16) for _ in range(3) + ] + manager = _build_manager(tensors, num_blocks=32) + # The deprecated alias must still return the live list, not a bound + # method (regression guard for the @property/@deprecate ordering). + assert isinstance(manager.kv_layer_groups, list) + assert manager.kernel_groups is manager.kv_layer_groups + assert manager.num_kernel_groups == manager.num_groups + assert manager.num_kernel_groups == len(manager.kernel_groups) + assert all(isinstance(g, KernelGroupInfo) for g in manager.kernel_groups) + + def test_single_object_group_covers_all_kernel_groups(self): + # Two distinct kernel groups (different num_heads) still share one + # object group under the current single-object-group assumption. + tensors = [ + torch.randn(2, 32, 256, 8, 64, dtype=torch.float16), + torch.randn(2, 32, 256, 16, 64, dtype=torch.float16), + ] + manager = _build_manager(tensors, num_blocks=32) + assert manager.num_kernel_groups == 2 + assert manager.num_object_groups == 1 + obj = manager.object_groups[0] + assert isinstance(obj, ObjectGroupInfo) + assert obj.kernel_group_indices == list(range(manager.num_kernel_groups)) + + def test_empty_manager_has_no_groups(self): + # Empty registration returns early in __init__; both group lists must + # still be initialized (regression guard for missing _object_groups). + manager = _build_manager([], num_blocks=32) + assert manager.kernel_groups == [] + assert manager.num_kernel_groups == 0 + assert manager.object_groups == [] + assert manager.num_object_groups == 0 + + def test_excluded_layer_left_out_of_all_groups(self): + # Layer 2 is referenced by no group view, so it is excluded entirely. + tensors = [ + torch.randn(2, 32, 256, 8, 64, dtype=torch.float16) for _ in range(3) + ] + manager = _build_manager( + tensors, + num_blocks=32, + group_views=[LMCacheGroupView(0, (0, 1))], + ) + grouped = sorted( + idx for group in manager.kernel_groups for idx in group.layer_indices + ) + assert grouped == [0, 1] + + def test_calculate_num_blocks_uncompressed(self): + # bs=16, compress_ratio=1 -> 256 tokens span 16 blocks. + tensors = [torch.randn(2, 32, 16, 8, 64, dtype=torch.float16) for _ in range(2)] + manager = _build_manager(tensors, num_blocks=32) + assert manager.calculate_num_blocks(0, 256) == 16 + + def test_calculate_num_blocks_compressed(self): + # bs=8, ie_logical_block_size=16 -> compress_ratio=2; + # 256 logical tokens -> 128 physical slots -> 128 // 8 = 16 blocks. + tensors = [torch.randn(2, 32, 8, 8, 64, dtype=torch.float16) for _ in range(2)] + manager = _build_manager( + tensors, + num_blocks=32, + layout_hints={"inference_engine_logical_block_size": 16}, + ) + assert manager.calculate_num_blocks(0, 256) == 16 + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 874f81b83fc88dd4a6500b4059a3dfc5e5988758 Mon Sep 17 00:00:00 2001 From: Samuel Shen Date: Mon, 8 Jun 2026 17:05:00 -0700 Subject: [PATCH 06/57] [GPUKVFormat]: support vLLM CPU 2-fused KV layout (#3567) Signed-off-by: Samuel Shen --- .../scripts/run-cpu-e2e-validation.sh | 5 +- csrc/mem_kernels.cuh | 10 ++ csrc/pybind.cpp | 1 + csrc/sycl/pybind_sycl.cpp | 1 + lmcache/python_ops_fallback.py | 6 ++ lmcache/v1/gpu_connector/utils.py | 58 +++++++++- .../v1/multiprocess/transfer_context/base.py | 10 ++ .../test_blocks_first_fused_kv_format.py | 100 ++++++++++++++++++ tests/v1/utils.py | 4 + 9 files changed, 191 insertions(+), 4 deletions(-) create mode 100644 tests/v1/gpu_connector/test_blocks_first_fused_kv_format.py diff --git a/.buildkite/k3_tests/multiprocess/scripts/run-cpu-e2e-validation.sh b/.buildkite/k3_tests/multiprocess/scripts/run-cpu-e2e-validation.sh index d67f06f17b..6c3448199c 100755 --- a/.buildkite/k3_tests/multiprocess/scripts/run-cpu-e2e-validation.sh +++ b/.buildkite/k3_tests/multiprocess/scripts/run-cpu-e2e-validation.sh @@ -230,7 +230,10 @@ uv pip install -r requirements/common.txt echo "✅ Installed requirements/common.txt" echo "Installing vLLM CPU build" -uv pip install vllm --extra-index-url https://wheels.vllm.ai/71df063c494c111ab60f6a33c54aafe7b9ae1d02/cpu --index-strategy first-index --torch-backend cpu +# Un-pinned from 71df063c (LMCache #3538) now that LMCache handles the +# blocks-first fused KV layout. Running against nightly means a passing CPU +# e2e proves the new GPUKVFormat path works. +uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu --index-strategy first-index --torch-backend cpu echo "✅ vLLM CPU install completed" echo "Installing LMCache in editable mode with NO_GPU_EXT=1" diff --git a/csrc/mem_kernels.cuh b/csrc/mem_kernels.cuh index ac2e22adaa..8c00fad5f0 100644 --- a/csrc/mem_kernels.cuh +++ b/csrc/mem_kernels.cuh @@ -99,6 +99,16 @@ enum class GPUKVFormat : int { - SGLang MHA via the MP daemon path physical shape per layer: [num_blocks, block_size, num_heads, head_size] */ + + NL_X_NB_NH_BS_TWO_HS = 10, + /* + used by: + - vLLM non-MLA blocks-first attention with K/V fused into the trailing dim + physical shape per layer: [num_blocks, num_heads, block_size, 2, head_size] + (recovered by splitting the fused trailing [block_size, 2 * head_size]). + Currently only reached via the host gather/scatter path, not the CUDA + transfer kernels. + */ }; void multi_layer_kv_transfer( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index e2e6eae68b..adc7ed23ab 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -32,6 +32,7 @@ PYBIND11_MODULE(c_ops, m) { .value("NL_X_NB_TWO_NH_BS_HS", GPUKVFormat::NL_X_NB_TWO_NH_BS_HS) .value("NB_NL_TWO_NH_BS_HS", GPUKVFormat::NB_NL_TWO_NH_BS_HS) .value("TWO_X_NL_X_NB_BS_NH_HS", GPUKVFormat::TWO_X_NL_X_NB_BS_NH_HS) + .value("NL_X_NB_NH_BS_TWO_HS", GPUKVFormat::NL_X_NB_NH_BS_TWO_HS) .export_values(); m.def("multi_layer_kv_transfer", &multi_layer_kv_transfer, py::arg("key_value"), py::arg("key_value_ptrs"), diff --git a/csrc/sycl/pybind_sycl.cpp b/csrc/sycl/pybind_sycl.cpp index 6ed9d58f3a..430e08dc2d 100644 --- a/csrc/sycl/pybind_sycl.cpp +++ b/csrc/sycl/pybind_sycl.cpp @@ -27,6 +27,7 @@ PYBIND11_MODULE(xpu_ops, m) { .value("NL_X_NB_TWO_NH_BS_HS", GPUKVFormat::NL_X_NB_TWO_NH_BS_HS) .value("NB_NL_TWO_NH_BS_HS", GPUKVFormat::NB_NL_TWO_NH_BS_HS) .value("TWO_X_NL_X_NB_BS_NH_HS", GPUKVFormat::TWO_X_NL_X_NB_BS_NH_HS) + .value("NL_X_NB_NH_BS_TWO_HS", GPUKVFormat::NL_X_NB_NH_BS_TWO_HS) .export_values(); m.def("multi_layer_kv_transfer", &multi_layer_kv_transfer, py::arg("key_value"), py::arg("key_value_ptrs"), diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py index f1b8e15593..a95bf24493 100644 --- a/lmcache/python_ops_fallback.py +++ b/lmcache/python_ops_fallback.py @@ -290,6 +290,12 @@ class GPUKVFormat(IntEnum): # used by: SGLang MHA via the MP daemon path TWO_X_NL_X_NB_BS_NH_HS = 9 + # used by: vLLM non-MLA blocks-first attention with K/V fused into the + # trailing dim. Per-layer physical shape + # [num_blocks, num_heads, block_size, 2, head_size] -- the K/V "2" axis is + # second-to-last, recovered by splitting the fused [..., 2 * head_size]. + NL_X_NB_NH_BS_TWO_HS = 10 + class PageBufferShapeDesc: """Python stand-in for the C++ ``PageBufferShapeDesc`` struct. diff --git a/lmcache/v1/gpu_connector/utils.py b/lmcache/v1/gpu_connector/utils.py index beff9f0f56..8fdf21387d 100644 --- a/lmcache/v1/gpu_connector/utils.py +++ b/lmcache/v1/gpu_connector/utils.py @@ -315,6 +315,7 @@ def get_gpu_kv_shape_description(gpu_kv_format: "lmc_ops.GPUKVFormat") -> str: lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: "NL x [2, NB, NH, BS, HS]", lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS: "NL x [NB, 2, NH, BS, HS]", lmc_ops.GPUKVFormat.NB_NL_TWO_NH_BS_HS: "[NB, NL, 2, NH, BS, HS]", + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: "NL x [NB, NH, BS, 2, HS]", } return _SHAPE_DESCRIPTIONS.get(gpu_kv_format, f"Unknown ({gpu_kv_format})") @@ -340,6 +341,9 @@ def get_attention_backend(gpu_kv_format: "lmc_ops.GPUKVFormat") -> str: "vLLM non-MLA flash infer (HND layout)" ), lmc_ops.GPUKVFormat.NB_NL_TWO_NH_BS_HS: "TRT-LLM cross-layer (HND layout)", + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: ( + "vLLM non-MLA blocks-first, fused K/V" + ), } return _ATTENTION_BACKENDS.get(gpu_kv_format, f"Unknown ({gpu_kv_format})") @@ -414,6 +418,12 @@ def get_concrete_gpu_kv_shape( bs = get_block_size(kv_caches, fmt) return f"[{nb}, {nl}, 2, {nh}, {bs}, {hs}]" + if fmt == F.NL_X_NB_NH_BS_TWO_HS: + nb = get_num_blocks(kv_caches, fmt) + nh = get_num_heads(kv_caches, fmt) + bs = get_block_size(kv_caches, fmt) + return f"{nl} x [{nb}, {nh}, {bs}, 2, {hs}]" + return f"Unknown ({gpu_kv_format})" @@ -615,6 +625,22 @@ def normalize_kv_and_discover_format( detected_format = lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS else: detected_format = lmc_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS + elif tensor_dim == 4: + # vLLM non-MLA blocks-first attention: K/V fused into the + # trailing dim -> [NB, NH, BS, 2*head_size]. + # Split the fused axis so downstream sees the canonical 5D + # [NB, NH, BS, 2, HS]. + last_dim = probe.shape[3] + if last_dim % 2 != 0: + raise ValueError( + "blocks-first fused KV cache trailing dim " + f"{last_dim} is not 2 * head_size" + ) + kv_caches = [ + layer.reshape(*layer.shape[:3], 2, last_dim // 2) + for layer in kv_caches + ] + detected_format = lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS elif tensor_dim == 3: # vllm MLA detected_format = lmc_ops.GPUKVFormat.NL_X_NB_BS_HS @@ -659,6 +685,7 @@ def get_num_layers( lmc_ops.GPUKVFormat.NL_X_NB_BS_HS, lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ): return len(kv_caches) elif gpu_kv_format in ( @@ -692,8 +719,9 @@ def get_num_blocks( elif gpu_kv_format in ( lmc_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ): - # [num_blocks, 2, ...] — shape[0] is num_blocks + # [num_blocks, ...] — shape[0] is num_blocks return kv_caches[0].shape[0] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: return kv_caches[0].shape[0] @@ -731,8 +759,10 @@ def get_block_size( elif gpu_kv_format in ( lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ): - # NHD: [..., BS, NH, HS] — block_size at shape[2] + # block_size at shape[2]: NHD [..., BS, NH, HS] and the CPU fused + # layout [NB, NH, BS, 2, HS] both carry block_size at shape[2]. return kv_caches[layer_idx].shape[2] elif gpu_kv_format in ( lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, @@ -780,6 +810,10 @@ def get_page_buffer_size( # list[num_layers] of [num_blocks, 2, num_heads, block_size, head_size] # num_blocks=shape[0], block_size=shape[3] return kv_caches[0].shape[0] * kv_caches[0].shape[3] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # list[num_layers] of [num_blocks, num_heads, block_size, 2, head_size] + # num_blocks=shape[0], block_size=shape[2] + return kv_caches[0].shape[0] * kv_caches[0].shape[2] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: # list[num_layers] of [num_blocks, block_size, head_size] return kv_caches[0].shape[0] * kv_caches[0].shape[1] @@ -821,6 +855,9 @@ def get_num_heads( ): # HND: [..., NH, BS, HS] — num_heads at shape[2] return kv_caches[layer_idx].shape[2] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # CPU fused: [NB, NH, BS, 2, HS] — num_heads at shape[1] + return kv_caches[layer_idx].shape[1] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: # MLA: heads are absorbed into hidden dim, so num_heads = 1 return 1 @@ -861,6 +898,9 @@ def get_hidden_dim_size( ): # HND: [..., NH, BS, HS] — hidden_dim = NH * HS = shape[2] * shape[4] return kv_caches[layer_idx].shape[2] * kv_caches[layer_idx].shape[4] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # CPU fused: [NB, NH, BS, 2, HS] — hidden_dim = NH * HS = shape[1] * shape[4] + return kv_caches[layer_idx].shape[1] * kv_caches[layer_idx].shape[4] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: return kv_caches[layer_idx].shape[2] elif gpu_kv_format == lmc_ops.GPUKVFormat.TWO_X_NL_X_NBBS_NH_HS: @@ -895,8 +935,9 @@ def get_head_size( lmc_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS, lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ): - # Both NHD [..., NH, HS] and HND [..., BS, HS] have head_size last + # All these per-layer non-MLA layouts carry head_size as the last dim return kv_caches[layer_idx].shape[4] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: return kv_caches[layer_idx].shape[2] @@ -943,6 +984,10 @@ def get_tokens_per_layer( # k_cache = kv_caches[0][:, 0] → (NB, NH, BS, HS); tokens = NB * BS k_cache_shape = kv_caches[0][:, 0].shape return k_cache_shape[0] * k_cache_shape[2] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # list[num_layers] of [num_blocks, num_heads, block_size, 2, head_size] + # tokens = NB * BS = shape[0] * shape[2] + return kv_caches[0].shape[0] * kv_caches[0].shape[2] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: # list[num_layers] of [num_blocks, block_size, head_size] return kv_caches[0].shape[0] * kv_caches[0].shape[1] @@ -995,6 +1040,10 @@ def get_elements_per_layer( # [num_blocks, 2, ...] — k_cache is kv_caches[0][:, 0] k_cache_shape = kv_caches[0][:, 0].shape return k_cache_shape.numel() * 2 + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # [NB, NH, BS, 2, HS] — K/V at dim 3; k_cache is kv_caches[0][:, :, :, 0] + k_cache_shape = kv_caches[0][:, :, :, 0].shape + return k_cache_shape.numel() * 2 elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_BS_HS: # list[num_layers] of [num_blocks, block_size, head_size] (MLA) return kv_caches[0].numel() @@ -1022,6 +1071,7 @@ def assert_is_vllm_flash_attn_or_flash_infer(gpu_kv_format: "lmc_ops.GPUKVFormat lmc_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS, lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ) @@ -1033,6 +1083,7 @@ def is_hnd(gpu_kv_format: "lmc_ops.GPUKVFormat") -> bool: lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, lmc_ops.GPUKVFormat.NB_NL_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ) @@ -1092,6 +1143,7 @@ def get_dtype( lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NBBS_ONE_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ): return kv_caches[layer_idx].dtype elif gpu_kv_format in ( diff --git a/lmcache/v1/multiprocess/transfer_context/base.py b/lmcache/v1/multiprocess/transfer_context/base.py index 8857ed13d8..ef43dd6121 100644 --- a/lmcache/v1/multiprocess/transfer_context/base.py +++ b/lmcache/v1/multiprocess/transfer_context/base.py @@ -282,6 +282,7 @@ def gather_paged_kv_to_cpu( is_hnd = gpu_kv_format in ( lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ) block_size = get_block_size(normalized, gpu_kv_format) @@ -326,6 +327,10 @@ def gather_paged_kv_to_cpu( if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: k_t = layer[0] v_t = layer[1] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # [NB, NH, BS, 2, HS] — K/V fused at dim 3 + k_t = layer[:, :, :, 0] + v_t = layer[:, :, :, 1] else: k_t = layer[:, 0] v_t = layer[:, 1] @@ -419,6 +424,7 @@ def scatter_cpu_to_paged_kv( is_hnd = gpu_kv_format in ( lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS, ) # After normalization the structure is always a list of per-layer @@ -462,6 +468,10 @@ def scatter_cpu_to_paged_kv( if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: k_t = layer[0] v_t = layer[1] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # [NB, NH, BS, 2, HS] — K/V fused at dim 3 + k_t = layer[:, :, :, 0] + v_t = layer[:, :, :, 1] else: k_t = layer[:, 0] v_t = layer[:, 1] diff --git a/tests/v1/gpu_connector/test_blocks_first_fused_kv_format.py b/tests/v1/gpu_connector/test_blocks_first_fused_kv_format.py new file mode 100644 index 0000000000..72ce1ecd3e --- /dev/null +++ b/tests/v1/gpu_connector/test_blocks_first_fused_kv_format.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Blocks-first, fused-K/V KV layout (GPUKVFormat.NL_X_NB_NH_BS_TWO_HS). + +A non-MLA blocks-first attention backend registers its KV cache as the 4D +``[NB, NH, BS, 2 * HS]`` with K/V fused into the trailing dim (as opposed to +the 5D K/V-major ``[2, NB, NH, BS, HS]``). Discovery splits the fused axis into +the canonical 5D ``[NB, NH, BS, 2, HS]`` and classifies it as +``NL_X_NB_NH_BS_TWO_HS``. + +These tests pin discovery, the format-aware accessors, and the multiprocess +gather/scatter round-trip for that layout. +""" + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import EngineType +from lmcache.v1.gpu_connector import utils as U +from lmcache.v1.multiprocess.transfer_context.base import ( + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, +) +import lmcache.c_ops as lmc_ops + +NB, NH, BS, HS, NL = 16, 4, 128, 64, 3 +HINTS = {"kv_layout": "HND"} + + +def _raw_blocks_first_caches() -> list[torch.Tensor]: + """Per-layer blocks-first tensors as registered: [NB, NH, BS, 2 * HS].""" + torch.manual_seed(0) + return [torch.randn(NB, NH, BS, 2 * HS) for _ in range(NL)] + + +def test_discovery_splits_fused_axis(): + fmt, norm = U.normalize_kv_and_discover_format( + _raw_blocks_first_caches(), EngineType.VLLM, HINTS + ) + assert fmt == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS + # 4D [NB, NH, BS, 2*HS] -> canonical 5D [NB, NH, BS, 2, HS] + assert tuple(norm[0].shape) == (NB, NH, BS, 2, HS) + + +def test_discovery_rejects_odd_trailing_dim(): + bad = [torch.randn(NB, NH, BS, 2 * HS + 1) for _ in range(NL)] + with pytest.raises(ValueError): + U.normalize_kv_and_discover_format(bad, EngineType.VLLM, HINTS) + + +def test_accessors(): + fmt, norm = U.normalize_kv_and_discover_format( + _raw_blocks_first_caches(), EngineType.VLLM, HINTS + ) + assert U.get_num_layers(norm, fmt) == NL + assert U.get_num_blocks(norm, fmt) == NB + assert U.get_block_size(norm, fmt) == BS + assert U.get_num_heads(norm, fmt) == NH + assert U.get_head_size(norm, fmt) == HS + assert U.get_hidden_dim_size(norm, fmt) == NH * HS + assert U.get_page_buffer_size(norm, fmt) == NB * BS + assert U.get_tokens_per_layer(norm, fmt) == NB * BS + assert U.get_elements_per_layer(norm, fmt) == NB * NH * BS * HS * 2 + # get_dtype is on the register_kv_caches -> group_layers_by_identity path, + # so it must recognize this format too. + assert U.get_dtype(norm, fmt) == _raw_blocks_first_caches()[0].dtype + assert U.is_hnd(fmt) is True + assert not U.is_mla(fmt) + + +def test_mp_gather_scatter_roundtrip(): + blocks_per_chunk = 2 + block_ids = [0, 3, 5, 6] # 2 chunks + raw = _raw_blocks_first_caches() + src = {f"layer_{i}": t for i, t in enumerate(raw)} + ref = {k: v.clone() for k, v in src.items()} + idx = torch.tensor(block_ids) + + chunks = gather_paged_kv_to_cpu( + src, block_ids, blocks_per_chunk, layout_hints=HINTS + ) + # [K/V, NL, chunk_tokens, NH*HS] + assert tuple(chunks[0].shape) == (2, NL, blocks_per_chunk * BS, NH * HS) + + # Wipe the gathered blocks, scatter back, and confirm exact recovery. + dst = {k: v.clone() for k, v in src.items()} + for k in dst: + dst[k][idx] = 0.0 + scatter_cpu_to_paged_kv( + dst, block_ids, chunks, blocks_per_chunk, layout_hints=HINTS + ) + + for k in dst: + assert torch.equal(dst[k][idx], ref[k][idx]) + + # Untouched blocks must be left alone. + untouched = torch.tensor([b for b in range(NB) if b not in block_ids]) + for k in dst: + assert torch.equal(dst[k][untouched], ref[k][untouched]) diff --git a/tests/v1/utils.py b/tests/v1/utils.py index 94826b20f8..a4c98d601f 100644 --- a/tests/v1/utils.py +++ b/tests/v1/utils.py @@ -46,6 +46,7 @@ class MockGPUKVFormat: NL_X_NB_BS_HS = 2 NL_X_TWO_NB_NH_BS_HS = 3 NL_X_NB_TWO_NH_BS_HS = 4 + NL_X_NB_NH_BS_TWO_HS = 5 class MockCOps: GPUKVFormat = MockGPUKVFormat @@ -311,6 +312,9 @@ def generate_kv_cache_paged_list_tensors( shape = [2, num_blocks, num_heads, block_size, head_size] elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS: shape = [num_blocks, 2, num_heads, block_size, head_size] + elif gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_NB_NH_BS_TWO_HS: + # blocks-first, K/V fused into the trailing dim + shape = [num_blocks, num_heads, block_size, 2, head_size] else: raise ValueError(f"Unsupported gpu_kv_format: {gpu_kv_format}") From 936bb94744e36ff58c09ac8af06a57ecd7fea83c Mon Sep 17 00:00:00 2001 From: Yihua Cheng Date: Mon, 8 Jun 2026 18:07:08 -0700 Subject: [PATCH 07/57] [Refactor] Rename LMCacheGroupView to EngineGroupInfo (#3598) Signed-off-by: ApostaC --- .../vllm/hybrid-kv-cache-groups.md | 58 +++++++++---------- .../commands/bench/server_bench/command.py | 6 +- .../commands/bench/server_bench/helpers.py | 14 ++--- .../sglang/multi_process_adapter.py | 4 +- lmcache/integration/vllm/kv_cache_groups.py | 14 ++--- .../integration/vllm/lmcache_mp_connector.py | 8 ++- .../vllm/vllm_multi_process_adapter.py | 16 ++--- lmcache/v1/kv_layer_groups.py | 22 +++---- lmcache/v1/multiprocess/gpu_context.py | 6 +- lmcache/v1/multiprocess/group_view.py | 30 +++++----- .../v1/multiprocess/modules/gpu_transfer.py | 8 +-- lmcache/v1/multiprocess/protocols/engine.py | 6 +- .../transfer_context/worker_transfer.py | 14 ++--- lmcache/v1/platform/cache_context.py | 8 +-- .../test_gpu_transfer_layout_registry.py | 2 +- .../multiprocess/test_mq_handler_helpers.py | 10 ++-- tests/v1/test_kv_cache_groups.py | 46 +++++++-------- tests/v1/test_kv_layer_groups_manager.py | 20 +++---- tests/v1/test_vllm_kv_cache_groups.py | 12 ++-- tests/v1/test_vllm_mp_adapter.py | 10 ++-- 20 files changed, 159 insertions(+), 155 deletions(-) diff --git a/docs/design/integration/vllm/hybrid-kv-cache-groups.md b/docs/design/integration/vllm/hybrid-kv-cache-groups.md index 919b66afcc..95f2f05ad8 100644 --- a/docs/design/integration/vllm/hybrid-kv-cache-groups.md +++ b/docs/design/integration/vllm/hybrid-kv-cache-groups.md @@ -8,16 +8,16 @@ connector. It separates three concepts: - **Engine KV cache group** — a group defined by the serving engine (vLLM's `KVCacheConfig.kv_cache_groups`). Each is one distinct paged-block address space; block IDs are only meaningful within one group. -- **`LMCacheGroupView`** — LMCache's engine-neutral, `msgspec`-encoded view of - one such group (`group_view.py`). A `list[LMCacheGroupView]` is the +- **`EngineGroupInfo`** — LMCache's engine-neutral, `msgspec`-encoded record of + one such group (`group_view.py`). A `list[EngineGroupInfo]` is the registration contract. - **`KVLayerGroupInfo`** — the server's runtime transfer-kernel dispatch unit, - built from the views + the real tensors (`kv_layer_groups.py`). + built from the engine group infos + the real tensors (`kv_layer_groups.py`). vLLM groups layers by cache behavior; LMCache must transfer by physical layout (kv_size, num_heads, head_size, block_size, dtype) *and* keep distinct engine -block-id spaces separate. So at registration we build group views, and -store/retrieve address those views directly. +block-id spaces separate. So at registration we build engine group infos, and +store/retrieve address those infos directly. ## Goals / Non-Goals @@ -36,13 +36,13 @@ store/retrieve address those views directly. ## Types -- **`LMCacheGroupView`** (`msgspec.Struct`): `engine_group_id` (which engine - block group its layers live in; dense from 0) + `layer_indices`. Several views +- **`EngineGroupInfo`** (`msgspec.Struct`): `engine_group_id` (which engine + block group its layers live in; dense from 0) + `layer_indices`. Several infos may share an `engine_group_id` when one engine group is split by physical transfer identity. The list order is the protocol-visible group order; an empty list means a single non-hybrid group. -- Helpers in `group_view.py` operate on `Sequence[LMCacheGroupView]`: - `num_engine_groups`, `num_group_views`, `expand_block_ids_to_views`, +- Helpers in `group_view.py` operate on `Sequence[EngineGroupInfo]`: + `num_engine_groups`, `num_engine_group_infos`, `expand_engine_block_ids`, `get_engine_group_indices`. - **`KVLayerGroupInfo`** (runtime, server-only): layer indices, `PageBufferShapeDesc`, dtype, compress ratio, physical chunk size, @@ -52,17 +52,17 @@ store/retrieve address those views directly. ```text vLLM KVCacheConfig + registered kv_caches - | integration.vllm.kv_cache_groups.create_group_views_from_vllm + | integration.vllm.kv_cache_groups.create_engine_group_infos_from_vllm v -list[LMCacheGroupView] --REGISTER_KV_CACHE (msgspec)--> server msgspec-decode +list[EngineGroupInfo] --REGISTER_KV_CACHE (msgspec)--> server msgspec-decode | KVLayerGroupsManager validates against real tensors v -KVLayerGroupInfo list --STORE/RETRIEVE block_ids per view--> transfer kernels +KVLayerGroupInfo list --STORE/RETRIEVE block_ids per info--> transfer kernels ``` ## Registration -`create_group_views_from_vllm` (the only place that reads vLLM `KVCacheConfig`): +`create_engine_group_infos_from_vllm` (the only place that reads vLLM `KVCacheConfig`): 1. Inspect registered tensors for physical layout/dtype. 2. Map each registered layer to its engine group index; layers absent from @@ -71,17 +71,17 @@ KVLayerGroupInfo list --STORE/RETRIEVE block_ids per view--> transfer kernels 3. `group_layers_by_identity` splits layers by transfer identity `(kv_size, num_heads, head_size, block_size, engine_group_idx, dtype)` — the `engine_group_idx` term keeps identically-shaped layers from different engine - groups in separate views. -4. Emit one `LMCacheGroupView` per identity; send the list in the + groups in separate infos. +4. Emit one `EngineGroupInfo` per identity; send the list in the `REGISTER_KV_CACHE` payload (the message queue encodes it). ## Store and retrieve vLLM reports block IDs per engine group. The worker adapter re-indexes them to -group-view order with `expand_block_ids_to_views(group_views, block_ids)` (each -view reuses its source engine group's block IDs), so `STORE`/`RETRIEVE` receive -`list[list[int]]` indexed by view order. The server loop is then trivial: for -view `i`, use `gpu_block_ids[i]`. +engine-group-info order with `expand_engine_block_ids(engine_group_infos, block_ids)` (each +info reuses its source engine group's block IDs), so `STORE`/`RETRIEVE` receive +`list[list[int]]` indexed by info order. The server loop is then trivial: for +info `i`, use `gpu_block_ids[i]`. ### Per-group block sizes @@ -108,7 +108,7 @@ layers in `kv_cache_groups`; a sharing layer is absent from every group's `layer_names`. Such a layer's KV physically lives in its target owner's blocks, so storing/retrieving the owner already covers it. Registration therefore tags unlisted layers with `EXCLUDED_ENGINE_GROUP` and `group_layers_by_identity` -skips them — they never form their own view. (Placing them in a group would +skips them — they never form their own info. (Placing them in a group would duplicate work and, when their block size differs from the group they default into, corrupt the per-group block-id counts.) @@ -124,19 +124,19 @@ vLLM exposes two engine groups — group 0: layers [0,2,4], group 1: [1,3]. If layers 0–3 share a shape but layer 4 differs, registration produces: ```text -view 0: engine group 0, layers [0, 2] -view 1: engine group 1, layers [1, 3] -view 2: engine group 0, layers [4] +info 0: engine group 0, layers [0, 2] +info 1: engine group 1, layers [1, 3] +info 2: engine group 0, layers [4] ``` Block IDs `{group 0: [10,11], group 1: [20,21]}` are sent as -`[[10,11], [20,21], [10,11]]` (views 0 and 2 share group 0's IDs). +`[[10,11], [20,21], [10,11]]` (infos 0 and 2 share group 0's IDs). ## Invariants -- The `list[LMCacheGroupView]` order is the protocol-visible group order; callers - send one block-id list per view. -- vLLM-specific access stays in `lmcache.integration.vllm`; views carry neutral +- The `list[EngineGroupInfo]` order is the protocol-visible group order; callers + send one block-id list per info. +- vLLM-specific access stays in `lmcache.integration.vllm`; infos carry neutral metadata only. - The server reproduces grouping with the same `group_layers_by_identity`; real tensors remain the source of truth for shape/dtype/stride. @@ -151,9 +151,9 @@ but LMCache cannot store/retrieve those layers. | Area | File | |---|---| -| Group view (IPC type) + helpers | `lmcache/v1/multiprocess/group_view.py` | +| Engine group info (IPC type) + helpers | `lmcache/v1/multiprocess/group_view.py` | | Shared grouping primitive | `lmcache/v1/kv_layer_groups.py` | -| vLLM → `list[LMCacheGroupView]` | `lmcache/integration/vllm/kv_cache_groups.py` | +| vLLM → `list[EngineGroupInfo]` | `lmcache/integration/vllm/kv_cache_groups.py` | | Register / store / retrieve | `lmcache/integration/vllm/{lmcache_mp_connector,vllm_multi_process_adapter}.py` | | Server GPU context / transfer | `lmcache/v1/multiprocess/{gpu_context,modules/gpu_transfer}.py` | | ZMQ protocol | `lmcache/v1/multiprocess/protocols/engine.py` | diff --git a/lmcache/cli/commands/bench/server_bench/command.py b/lmcache/cli/commands/bench/server_bench/command.py index 8ade408a1d..269f01a22e 100644 --- a/lmcache/cli/commands/bench/server_bench/command.py +++ b/lmcache/cli/commands/bench/server_bench/command.py @@ -294,7 +294,7 @@ def run_server_bench( # noqa: ARG001 (command kept for symmetry with siblings) layer_groups = parse_kvcache_shape_spec(args.kvcache_shape_spec) # One block-id list is sent per LMCache KV group; each shape-spec # group becomes its own group server-side. - num_group_views = len(layer_groups) or 1 + num_engine_group_infos = len(layer_groups) or 1 # Echo the resolved spec so operators can verify that their # input was interpreted as intended. The echoed string is a # valid ``--kvcache-shape-spec`` itself. @@ -473,7 +473,7 @@ def run_server_bench( # noqa: ARG001 (command kept for symmetry with siblings) http_base=http_base, block_size=block_size, total_blocks=num_blocks, - num_group_views=num_group_views, + num_engine_group_infos=num_engine_group_infos, use_gpu=use_gpu, use_handle=use_handle, client_tensors=client_tensors, @@ -492,7 +492,7 @@ def run_server_bench( # noqa: ARG001 (command kept for symmetry with siblings) http_base=http_base, block_size=block_size, total_blocks=num_blocks, - num_group_views=num_group_views, + num_engine_group_infos=num_engine_group_infos, use_gpu=use_gpu, use_handle=use_handle, client_tensors=client_tensors, diff --git a/lmcache/cli/commands/bench/server_bench/helpers.py b/lmcache/cli/commands/bench/server_bench/helpers.py index dc7b99046b..1ed0be6d00 100644 --- a/lmcache/cli/commands/bench/server_bench/helpers.py +++ b/lmcache/cli/commands/bench/server_bench/helpers.py @@ -554,7 +554,7 @@ def _send_store( key: IPCCacheEngineKey, block_offset: int = 0, block_size: int = 16, - num_group_views: int = 1, + num_engine_group_infos: int = 1, use_gpu: bool = True, use_handle: bool | None = None, client_tensors: list["torch.Tensor"] | None = None, @@ -582,7 +582,7 @@ def _send_store( payloads = [ key, _INSTANCE_ID, - [block_ids] * num_group_views, + [block_ids] * num_engine_group_infos, _make_event_handle(use_gpu), ] result = _call(client, RequestType.STORE, payloads) @@ -627,7 +627,7 @@ def _send_retrieve( hit_chunks: int, block_offset: int = 0, block_size: int = 16, - num_group_views: int = 1, + num_engine_group_infos: int = 1, use_gpu: bool = True, use_handle: bool | None = None, client_tensors: list["torch.Tensor"] | None = None, @@ -654,7 +654,7 @@ def _send_retrieve( payloads = [ key, _INSTANCE_ID, - [block_ids] * num_group_views, + [block_ids] * num_engine_group_infos, _make_event_handle(use_gpu), 0, # skip_first_n_tokens ] @@ -784,7 +784,7 @@ def _process_request( http_base: str = "", block_size: int = 16, total_blocks: int = 1024, - num_group_views: int = 1, + num_engine_group_infos: int = 1, use_gpu: bool = True, use_handle: bool | None = None, client_tensors: list["torch.Tensor"] | None = None, @@ -901,7 +901,7 @@ def _process_request( hit_chunks, block_offset=block_offset, block_size=block_size, - num_group_views=num_group_views, + num_engine_group_infos=num_engine_group_infos, use_gpu=use_gpu, use_handle=use_handle, client_tensors=client_tensors, @@ -938,7 +938,7 @@ def _process_request( store_key, block_offset=store_block_off, block_size=block_size, - num_group_views=num_group_views, + num_engine_group_infos=num_engine_group_infos, use_gpu=use_gpu, use_handle=use_handle, client_tensors=client_tensors, diff --git a/lmcache/integration/sglang/multi_process_adapter.py b/lmcache/integration/sglang/multi_process_adapter.py index 16de0e2955..3bf940abde 100644 --- a/lmcache/integration/sglang/multi_process_adapter.py +++ b/lmcache/integration/sglang/multi_process_adapter.py @@ -135,12 +135,12 @@ def __init__( # Upstream's REGISTER_KV_CACHE protocol takes flat positional args: # (instance_id, kv_cache, model_name, world_size, engine_type, - # layout_hints, group_views). SGLang's natural KV layout is depth-2 + # layout_hints, engine_group_infos). SGLang's natural KV layout is depth-2 # ([K_layers, V_layers]); we flatten it on the wire to fit # ``KVCache = list[CudaIPCWrapper]``. The daemon recognizes the # SGLang-MHA flat-of-2NL pattern from ``EngineType.SGLANG`` plus the # ``tokens_per_block`` hint and un-flattens + reshapes per layer. - # SGLang is non-hybrid (a single KV cache group), so group_views is the + # SGLang is non-hybrid (a single KV cache group), so engine_group_infos is the # empty list -- which the server treats as one group spanning all layers # (matching the vLLM non-hybrid and TensorRT-LLM register paths). send_lmcache_request( diff --git a/lmcache/integration/vllm/kv_cache_groups.py b/lmcache/integration/vllm/kv_cache_groups.py index 5167d74dd8..fdc9459410 100644 --- a/lmcache/integration/vllm/kv_cache_groups.py +++ b/lmcache/integration/vllm/kv_cache_groups.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Build LMCache group views from vLLM KV cache group metadata.""" +"""Build LMCache engine group infos from vLLM KV cache group metadata.""" # Future from __future__ import annotations @@ -13,15 +13,15 @@ from lmcache.v1.gpu_connector.utils import LayoutHints # First Party -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo -def create_group_views_from_vllm( +def create_engine_group_infos_from_vllm( kv_cache_config: Any, kv_caches: Mapping[str, Any], layout_hints: "LayoutHints | None" = None, -) -> list[LMCacheGroupView]: - """Build the LMCache group views from vLLM metadata and registered tensors. +) -> list[EngineGroupInfo]: + """Build the LMCache engine group infos from vLLM metadata and registered tensors. This is the single entry point for the vLLM -> LMCache conversion. It reads the vLLM-specific fields (``KVCacheConfig.kv_cache_groups`` and @@ -41,7 +41,7 @@ def create_group_views_from_vllm( detection (e.g. ``NHD``/``HND`` and compression metadata). Returns: - The list of ``LMCacheGroupView`` in protocol order, i.e. the LMCache group + The list of ``EngineGroupInfo`` in protocol order, i.e. the LMCache group order used by store/retrieve block IDs. """ # First Party @@ -95,7 +95,7 @@ def create_group_views_from_vllm( # the shared, engine-neutral primitive the server reuses to reproduce the # same grouping from the registered tensors. return [ - LMCacheGroupView( + EngineGroupInfo( engine_group_id=identity[4], layer_indices=tuple(indices), ) diff --git a/lmcache/integration/vllm/lmcache_mp_connector.py b/lmcache/integration/vllm/lmcache_mp_connector.py index 012e96fd2c..3ca1c9086d 100644 --- a/lmcache/integration/vllm/lmcache_mp_connector.py +++ b/lmcache/integration/vllm/lmcache_mp_connector.py @@ -37,7 +37,7 @@ class SupportsHMA: # type: ignore[no-redef] # First Party from lmcache import torch_dev from lmcache.integration.vllm.kv_cache_groups import ( - create_group_views_from_vllm, + create_engine_group_infos_from_vllm, ) from lmcache.integration.vllm.utils import mla_enabled, vllm_layout_hints from lmcache.utils import init_logger as lmcache_init_logger @@ -619,12 +619,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ logger.info("Registering kv caches!") kv_cache_config = getattr(self, "_kv_cache_config", None) - group_views = create_group_views_from_vllm( + engine_group_infos = create_engine_group_infos_from_vllm( kv_cache_config, kv_caches, layout_hints=vllm_layout_hints(), ) - self.worker_adapter.register_kv_caches(kv_caches, group_views=group_views) + self.worker_adapter.register_kv_caches( + kv_caches, engine_group_infos=engine_group_infos + ) return def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index c2fcb083a9..3fcc0054e6 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -22,8 +22,8 @@ KVCache, ) from lmcache.v1.multiprocess.group_view import ( - LMCacheGroupView, - expand_block_ids_to_views, + EngineGroupInfo, + expand_engine_block_ids, ) from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture from lmcache.v1.multiprocess.protocol import RequestType, get_response_class @@ -920,7 +920,7 @@ def __init__( # Registered kv caches from vLLM self.kv_caches: dict[str, torch.Tensor] = {} - self.group_views: list[LMCacheGroupView] = [] + self.engine_group_infos: list[EngineGroupInfo] = [] # Transport context for transfer operations. self.transfer_ctx: TransferContext | None = None @@ -1034,7 +1034,7 @@ def is_first_rank_of_pp_group(self) -> bool: def register_kv_caches( self, kv_caches: dict[str, torch.Tensor], - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), ) -> None: """ Register the kv caches with LMCache server. @@ -1042,7 +1042,7 @@ def register_kv_caches( Args: kv_caches: A dict of kv caches to register. The keys are the layer names and the values are the corresponding tensors. - group_views: LMCache-owned engine KV cache group metadata. + engine_group_infos: LMCache-owned engine KV cache group metadata. Raises: ConnectionError: if the server does not respond within @@ -1050,11 +1050,11 @@ def register_kv_caches( """ logger.info("Registering kv caches") self.kv_caches = kv_caches - self.group_views = list(group_views) + self.engine_group_infos = list(engine_group_infos) self._send_register_kv_caches_request(kv_caches) def _block_ids_per_group(self, op: LoadStoreOp) -> list[list[int]]: - return expand_block_ids_to_views(self.group_views, op.block_ids) + return expand_engine_block_ids(self.engine_group_infos, op.block_ids) def _send_register_kv_caches_request( self, kv_caches: dict[str, torch.Tensor] @@ -1090,7 +1090,7 @@ def _send_register_kv_caches_request( self._mq_timeout, send_request=send_lmcache_request, layout_hints=layout_hints, - group_views=self.group_views, + engine_group_infos=self.engine_group_infos, ) except TimeoutError: raise ConnectionError( diff --git a/lmcache/v1/kv_layer_groups.py b/lmcache/v1/kv_layer_groups.py index 2127bcb573..78ea6ba055 100644 --- a/lmcache/v1/kv_layer_groups.py +++ b/lmcache/v1/kv_layer_groups.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: # First Party from lmcache.v1.gpu_connector.utils import DiscoverableKVCache, LayoutHints - from lmcache.v1.multiprocess.group_view import LMCacheGroupView + from lmcache.v1.multiprocess.group_view import EngineGroupInfo logger = init_logger(__name__) @@ -63,7 +63,7 @@ class KernelGroupIdentity(NamedTuple): # Sentinel ``per_layer_engine_group_idx`` value: a KV tensor tagged with it is # excluded from every LMCache group (used for cross-layer KV-sharing layers; see -# ``create_group_views_from_vllm``). +# ``create_engine_group_infos_from_vllm``). EXCLUDED_ENGINE_GROUP = -1 @@ -271,7 +271,7 @@ def __init__( gpu_kv_format: "lmc_ops.GPUKVFormat", num_blocks: int, layout_hints: "LayoutHints | None" = None, - group_views: "Sequence[LMCacheGroupView]" = (), + engine_group_infos: "Sequence[EngineGroupInfo]" = (), lmcache_logical_chunk_size: int = 256, ) -> None: """Partition layers into groups keyed by @@ -301,7 +301,7 @@ def __init__( group's ``compress_ratio`` and ``physical_chunk_size``. ``None`` means every group is treated as non-compressed (``compress_ratio == 1``). - group_views: LMCache-owned engine KV cache group + engine_group_infos: LMCache-owned engine KV cache group metadata. When present, it is used to keep layers from different engine block-ID spaces in separate LMCache transfer groups. @@ -342,7 +342,9 @@ def __init__( logger.debug("No KV caches available, skipping KV layer groups building") return - per_layer_engine_group_idx = get_engine_group_indices(group_views, num_layers) + per_layer_engine_group_idx = get_engine_group_indices( + engine_group_infos, num_layers + ) groups_by_identity = group_layers_by_identity( kv_caches, gpu_kv_format, num_layers, per_layer_engine_group_idx @@ -410,7 +412,7 @@ def __init__( ) # Detect the object groups - self._object_groups = self._detect_object_groups(group_views) + self._object_groups = self._detect_object_groups(engine_group_infos) @property def kernel_groups(self) -> list[KernelGroupInfo]: @@ -516,18 +518,18 @@ def calculate_num_blocks(self, kernel_group_idx: int, num_tokens: int) -> int: ### Helper methods def _detect_object_groups( - self, group_views: "Sequence[LMCacheGroupView]" + self, engine_group_infos: "Sequence[EngineGroupInfo]" ) -> list[ObjectGroupInfo]: - """Detect object groups based on the provided group views. + """Detect object groups based on the provided engine group infos. Args: - group_views: LMCache-owned engine KV cache group metadata. + engine_group_infos: LMCache-owned engine KV cache group metadata. Returns: A list of ObjectGroupInfo instances representing the detected object groups. """ # TODO: add the real object group detection logic based on - # the attention type metadata in the group views once it's + # the attention type metadata in the engine group infos once it's # available. # Now, we are using a single object group, which means # all kernel groups' KV caches will be stored in the same memory object. diff --git a/lmcache/v1/multiprocess/gpu_context.py b/lmcache/v1/multiprocess/gpu_context.py index 182d55af15..b3eef93296 100644 --- a/lmcache/v1/multiprocess/gpu_context.py +++ b/lmcache/v1/multiprocess/gpu_context.py @@ -38,7 +38,7 @@ ) from lmcache.v1.kv_layer_groups import KVLayerGroupsManager from lmcache.v1.multiprocess.custom_types import KVCache -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo # Backend selection (c_ops when CUDA is available, otherwise a pure-Python # fallback) is handled once in ``lmcache/__init__.py`` via ``_get_backend``, @@ -342,7 +342,7 @@ def __init__( kv_caches: KVCache, lmcache_logical_chunk_size: int = 256, layout_hints: LayoutHints | None = None, - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), engine_type: EngineType = EngineType.VLLM, ): unwrapped = unwrap_kv_cache_tensors(kv_caches) @@ -362,7 +362,7 @@ def __init__( gpu_kv_format=self.gpu_kv_format_, num_blocks=self.num_blocks_, layout_hints=layout_hints, - group_views=group_views, + engine_group_infos=engine_group_infos, lmcache_logical_chunk_size=lmcache_logical_chunk_size, ) diff --git a/lmcache/v1/multiprocess/group_view.py b/lmcache/v1/multiprocess/group_view.py index 7155791926..5c95cef30d 100644 --- a/lmcache/v1/multiprocess/group_view.py +++ b/lmcache/v1/multiprocess/group_view.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""LMCache's engine-neutral view of a serving engine's KV cache groups. +"""LMCache's engine-neutral description of a serving engine's KV cache groups. An *engine group* is one distinct paged-block address space exposed by the serving engine (e.g. one of vLLM's hybrid KV cache groups): block IDs are only @@ -7,8 +7,8 @@ merged into one LMCache KV group. Engine group ids are assumed dense and consecutive starting from 0. -LMCache's neutral KV cache spec is simply a ``list[LMCacheGroupView]`` (passed as -a ``Sequence[LMCacheGroupView]`` where only order matters). The group order is +LMCache's neutral KV cache spec is simply a ``list[EngineGroupInfo]`` (passed as +a ``Sequence[EngineGroupInfo]`` where only order matters). The group order is the protocol-visible LMCache group order used by store/retrieve block IDs. An empty list means a single non-hybrid group (the default for engines that do not report KV cache group metadata). Engine-specific conversion belongs in the @@ -23,26 +23,26 @@ import msgspec -class LMCacheGroupView(msgspec.Struct, frozen=True): +class EngineGroupInfo(msgspec.Struct, frozen=True): """One LMCache KV group: layers of one engine group that share a copy kernel. Carries the layer indices and which engine group they belong to. Several - ``LMCacheGroupView`` instances may share the same ``engine_group_id`` when + ``EngineGroupInfo`` instances may share the same ``engine_group_id`` when one engine group is split by physical transfer identity (e.g. differing - hidden dims). A ``list[LMCacheGroupView]`` is carried verbatim in the + hidden dims). A ``list[EngineGroupInfo]`` is carried verbatim in the ``REGISTER_KV_CACHE`` IPC payload; the message queue handles encoding/decoding. """ engine_group_id: int - """Engine group this view's layers live in (one distinct paged-block address + """Engine group these layers live in (one distinct paged-block address space). Selects which request block-id list applies. Dense from 0.""" layer_indices: tuple[int, ...] = () """Registered KV tensor indices assigned to this group.""" -def num_engine_groups(groups: Sequence[LMCacheGroupView]) -> int: +def num_engine_groups(groups: Sequence[EngineGroupInfo]) -> int: """Return the number of engine groups (block-id lists per transfer request). Engine group ids are assumed dense and consecutive from 0. @@ -59,7 +59,7 @@ def num_engine_groups(groups: Sequence[LMCacheGroupView]) -> int: return max(group.engine_group_id for group in groups) + 1 -def num_group_views(groups: Sequence[LMCacheGroupView]) -> int: +def num_engine_group_infos(groups: Sequence[EngineGroupInfo]) -> int: """Return the number of LMCache KV groups visible to transfer requests. Args: @@ -75,7 +75,7 @@ def num_group_views(groups: Sequence[LMCacheGroupView]) -> int: def _engine_group_id_per_view( - groups: Sequence[LMCacheGroupView], + groups: Sequence[EngineGroupInfo], ) -> tuple[int, ...]: """Return, per LMCache group, the engine group it draws block IDs from. @@ -84,7 +84,7 @@ def _engine_group_id_per_view( Returns: A tuple whose length equals the number of LMCache groups (i.e. - :func:`num_group_views`); element ``i`` is the engine group id + :func:`num_engine_group_infos`); element ``i`` is the engine group id that LMCache group ``i`` reads block IDs from. ``(0,)`` for an empty ``groups`` (single non-hybrid group). """ @@ -93,11 +93,11 @@ def _engine_group_id_per_view( return tuple(group.engine_group_id for group in groups) -def expand_block_ids_to_views( - groups: Sequence[LMCacheGroupView], +def expand_engine_block_ids( + groups: Sequence[EngineGroupInfo], engine_side_block_ids: Sequence[Sequence[int]] | Sequence[int], ) -> list[list[int]]: - """Re-index engine-side block IDs to one list per LMCache group. + """Expand the engine-side block id list to the list per LMCache kernel group. The serving engine reports block IDs per engine group. LMCache transfer requests are indexed by LMCache KV group, so each LMCache group reuses the @@ -174,7 +174,7 @@ def slice_block_ids_per_group( def get_engine_group_indices( - groups: Sequence[LMCacheGroupView], + groups: Sequence[EngineGroupInfo], num_registered_layers: int, ) -> list[int] | None: """Return the engine group index for each registered KV tensor. diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py index 8b012af0c3..eaf4d1ff83 100644 --- a/lmcache/v1/multiprocess/modules/gpu_transfer.py +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -36,7 +36,7 @@ ThreadPoolType, ) from lmcache.v1.multiprocess.gpu_context import GPUCacheContext -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo from lmcache.v1.multiprocess.native_completion import ( DeviceHostFuncDispatcher, submit_callback_to_stream, @@ -231,7 +231,7 @@ def register_kv_cache( world_size: int, engine_type: EngineType, layout_hints: LayoutHints, - group_views: list[LMCacheGroupView], + engine_group_infos: list[EngineGroupInfo], ) -> None: """Register the KV cache tensors for a given GPU instance ID. @@ -245,7 +245,7 @@ def register_kv_cache( Forwarded to GPUCacheContext for format detection. layout_hints: See LayoutHints. Forwarded to GPUCacheContext for GPU KV format detection. - group_views: Engine-neutral KV cache group metadata + engine_group_infos: Engine-neutral KV cache group metadata (already msgspec-decoded by the message queue). """ if instance_id in self._cache_contexts: @@ -260,7 +260,7 @@ def register_kv_cache( kv_caches, self._ctx.chunk_size, layout_hints=layout_hints or None, - group_views=group_views, + engine_group_infos=engine_group_infos, engine_type=engine_type, ) self._cache_contexts[instance_id] = ContextEntry( diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index bc61f7f2b1..b2aa603c66 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -24,7 +24,7 @@ KVCache, RegisterNonGpuContextPayload, ) -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition @@ -96,7 +96,7 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: # - engine_type: EngineType - Which serving engine produced the # caches (vLLM, SGLang, ...). Drives format detection. # - layout_hints: LayoutHints - See custom_types.LayoutHints. - # - group_views: list[LMCacheGroupView] - Engine-neutral KV cache + # - engine_group_infos: list[EngineGroupInfo] - Engine-neutral KV cache # group metadata (msgspec-encoded by the message queue). # Returns: None "REGISTER_KV_CACHE": ProtocolDefinition( @@ -107,7 +107,7 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: int, EngineType, LayoutHints, - list[LMCacheGroupView], + list[EngineGroupInfo], ], response_class=None, handler_type=HandlerType.SYNC, diff --git a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py index 41c72ca7d2..7ab3dc3cdc 100644 --- a/lmcache/v1/multiprocess/transfer_context/worker_transfer.py +++ b/lmcache/v1/multiprocess/transfer_context/worker_transfer.py @@ -18,7 +18,7 @@ from lmcache.v1.gpu_connector.utils import LayoutHints, is_mla from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.futures import MessagingFuture -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType from lmcache.v1.multiprocess.protocols.engine import RegisterNonGpuContextResponse @@ -126,7 +126,7 @@ def register( mq_timeout: float, send_request: SendRequest, layout_hints: LayoutHints | None = None, - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), ) -> None: """Register KV caches with the server and wait for ACK. @@ -140,7 +140,7 @@ def register( mq_timeout: Timeout in seconds for synchronous request wait. send_request: Request sender callable used to issue MQ requests. layout_hints: Optional inference-engine-provided layout hints. - group_views: LMCache-owned engine KV cache group metadata. + engine_group_infos: LMCache-owned engine KV cache group metadata. Raises: TimeoutError: If server registration does not complete before @@ -232,7 +232,7 @@ def register( mq_timeout: float, send_request: SendRequest, layout_hints: LayoutHints | None = None, - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), ) -> None: # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches @@ -249,7 +249,7 @@ def register( world_size, EngineType.VLLM, layout_hints, - list(group_views), + list(engine_group_infos), ], ) future.result(timeout=mq_timeout) @@ -321,11 +321,11 @@ def register( mq_timeout: float, send_request: SendRequest, layout_hints: LayoutHints | None = None, - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), ) -> None: """Register KV caches with the non-GPU context server. - ``group_views`` is accepted to satisfy the base interface but + ``engine_group_infos`` is accepted to satisfy the base interface but is currently a no-op: the non-GPU transfer path does not support hybrid KV cache groups and rejects multi-group transfers at store / retrieve time (see ``_single_group_block_ids``). diff --git a/lmcache/v1/platform/cache_context.py b/lmcache/v1/platform/cache_context.py index 3c723cb222..b3f983e74f 100644 --- a/lmcache/v1/platform/cache_context.py +++ b/lmcache/v1/platform/cache_context.py @@ -25,14 +25,14 @@ if TYPE_CHECKING: # First Party - from lmcache.v1.multiprocess.group_view import LMCacheGroupView + from lmcache.v1.multiprocess.group_view import EngineGroupInfo def create_cache_context( kv_caches: KVCache, lmcache_logical_chunk_size: int = 256, layout_hints: LayoutHints | None = None, - group_views: "Sequence[LMCacheGroupView]" = (), + engine_group_infos: "Sequence[EngineGroupInfo]" = (), engine_type: EngineType = EngineType.VLLM, ) -> Any: """Create the appropriate cache context. @@ -50,7 +50,7 @@ def create_cache_context( lmcache_logical_chunk_size: Number of tokens per LMCache chunk. layout_hints: Optional hints for GPU KV format detection. Forwarded verbatim to the concrete context constructor. - group_views: Engine-neutral KV cache group metadata. + engine_group_infos: Engine-neutral KV cache group metadata. engine_type: Which serving engine produced the caches. Returns: @@ -70,6 +70,6 @@ def create_cache_context( kv_caches, lmcache_logical_chunk_size, layout_hints, - group_views, + engine_group_infos, engine_type, ) diff --git a/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py b/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py index 8ab2470bef..09abf56228 100644 --- a/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py +++ b/tests/v1/multiprocess/test_gpu_transfer_layout_registry.py @@ -72,7 +72,7 @@ def fake_create_cache_context( kv_caches: object, lmcache_logical_chunk_size: int, layout_hints: object = None, - group_views: object = (), + engine_group_infos: object = (), engine_type: object = None, ) -> _FakeGPUContext: """Return a fake cache context without touching CUDA or wrappers.""" diff --git a/tests/v1/multiprocess/test_mq_handler_helpers.py b/tests/v1/multiprocess/test_mq_handler_helpers.py index 43610b07bc..60b1b72345 100644 --- a/tests/v1/multiprocess/test_mq_handler_helpers.py +++ b/tests/v1/multiprocess/test_mq_handler_helpers.py @@ -13,7 +13,7 @@ BlockAllocationRecord, KVCache, ) -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo from lmcache.v1.multiprocess.protocol import KeyType # ============================================================================== @@ -41,7 +41,7 @@ def register_kv_cache_handler( world_size: int, engine_type: EngineType, layout_hints: LayoutHints, - group_views: list[LMCacheGroupView], + engine_group_infos: list[EngineGroupInfo], ) -> None: """ Dummy handler for REGISTER_KV_CACHE requests. @@ -56,7 +56,7 @@ def register_kv_cache_handler( ``layout_hints["inference_engine_logical_block_size"]`` carries the logical tokens-per-engine-block (previously a standalone argument). - group_views: Engine-neutral KV cache group metadata, + engine_group_infos: Engine-neutral KV cache group metadata, msgspec-decoded from the request payload. Returns: @@ -86,8 +86,8 @@ def register_kv_cache_handler( "Expected layout_hints['inference_engine_logical_block_size'] to be int, got " f"{type(ie_logical_block_size)}" ) - assert isinstance(group_views, list), ( - f"Expected group_views to be a list, got {type(group_views)}" + assert isinstance(engine_group_infos, list), ( + f"Expected engine_group_infos to be a list, got {type(engine_group_infos)}" ) # No return value (returns None implicitly) diff --git a/tests/v1/test_kv_cache_groups.py b/tests/v1/test_kv_cache_groups.py index cb04ff45f1..b554f1f699 100644 --- a/tests/v1/test_kv_cache_groups.py +++ b/tests/v1/test_kv_cache_groups.py @@ -4,61 +4,61 @@ # First Party from lmcache.v1.multiprocess.group_view import ( - LMCacheGroupView, - expand_block_ids_to_views, + EngineGroupInfo, + expand_engine_block_ids, get_engine_group_indices, + num_engine_group_infos, num_engine_groups, - num_group_views, slice_block_ids_per_group, ) -def test_group_views_default_to_one_engine_group(): +def test_engine_group_infos_default_to_one_engine_group(): assert num_engine_groups([]) == 1 - assert num_group_views([]) == 1 + assert num_engine_group_infos([]) == 1 assert get_engine_group_indices([], 1) is None -def test_group_views_build_per_layer_engine_group_indices(): +def test_engine_group_infos_build_per_layer_engine_group_indices(): groups = [ - LMCacheGroupView(0, (0, 2)), - LMCacheGroupView(1, (1, 3)), + EngineGroupInfo(0, (0, 2)), + EngineGroupInfo(1, (1, 3)), ] assert num_engine_groups(groups) == 2 - assert num_group_views(groups) == 2 + assert num_engine_group_infos(groups) == 2 assert get_engine_group_indices(groups, 4) == [0, 1, 0, 1] -def test_group_views_expand_block_ids_to_views(): +def test_engine_group_infos_expand_engine_block_ids(): groups = [ - LMCacheGroupView(0, (0, 2)), - LMCacheGroupView(0, (4,)), - LMCacheGroupView(1, (1, 3)), + EngineGroupInfo(0, (0, 2)), + EngineGroupInfo(0, (4,)), + EngineGroupInfo(1, (1, 3)), ] - assert expand_block_ids_to_views(groups, [[10, 11], [20, 21]]) == [ + assert expand_engine_block_ids(groups, [[10, 11], [20, 21]]) == [ [10, 11], [10, 11], [20, 21], ] -def test_group_views_msgspec_round_trip(): +def test_engine_group_infos_msgspec_round_trip(): """The groups encode/decode losslessly via msgspec (the IPC path).""" groups = [ - LMCacheGroupView(0, (0, 2)), - LMCacheGroupView(1, (1, 3)), + EngineGroupInfo(0, (0, 2)), + EngineGroupInfo(1, (1, 3)), ] decoded = msgspec.msgpack.decode( - msgspec.msgpack.encode(groups), type=list[LMCacheGroupView] + msgspec.msgpack.encode(groups), type=list[EngineGroupInfo] ) assert decoded == groups -def test_group_views_exclude_uncovered_layers(): +def test_engine_group_infos_exclude_uncovered_layers(): """Layers not referenced by any group are tagged EXCLUDED_ENGINE_GROUP. Cross-layer KV-sharing layers (e.g. google/gemma-4-E4B-it) alias a target @@ -69,16 +69,16 @@ def test_group_views_exclude_uncovered_layers(): from lmcache.v1.kv_layer_groups import EXCLUDED_ENGINE_GROUP groups = [ - LMCacheGroupView(0, (0,)), - LMCacheGroupView(1, (1,)), + EngineGroupInfo(0, (0,)), + EngineGroupInfo(1, (1,)), ] # Layer 2 is not covered by any group -> excluded, not an error. assert get_engine_group_indices(groups, 3) == [0, 1, EXCLUDED_ENGINE_GROUP] -def test_group_views_reject_out_of_range_layer(): - groups = [LMCacheGroupView(0, (0, 5))] +def test_engine_group_infos_reject_out_of_range_layer(): + groups = [EngineGroupInfo(0, (0, 5))] try: get_engine_group_indices(groups, 3) diff --git a/tests/v1/test_kv_layer_groups_manager.py b/tests/v1/test_kv_layer_groups_manager.py index e8cc99bf0d..58aae45fcc 100644 --- a/tests/v1/test_kv_layer_groups_manager.py +++ b/tests/v1/test_kv_layer_groups_manager.py @@ -19,7 +19,7 @@ format_kvcache_shape_spec, parse_kvcache_shape_spec, ) -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo pytestmark = pytest.mark.skipif( not torch.cuda.is_available(), reason="PageBufferShapeDesc requires CUDA build" @@ -31,7 +31,7 @@ def _build_manager( *, num_blocks: int, layout_hints: LayoutHints | None = None, - group_views: Sequence[LMCacheGroupView] = (), + engine_group_infos: Sequence[EngineGroupInfo] = (), ) -> KVLayerGroupsManager: """Build a manager using the per-layer NHD format. @@ -48,7 +48,7 @@ def _build_manager( gpu_kv_format=lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS, num_blocks=num_blocks, layout_hints=layout_hints, - group_views=group_views, + engine_group_infos=engine_group_infos, ) @@ -95,9 +95,9 @@ def test_build_splits_same_shape_by_engine_group_idx(self): manager = _build_manager( tensors, num_blocks=32, - group_views=[ - LMCacheGroupView(0, (0, 2)), - LMCacheGroupView(1, (1, 3)), + engine_group_infos=[ + EngineGroupInfo(0, (0, 2)), + EngineGroupInfo(1, (1, 3)), ], ) @@ -108,7 +108,7 @@ def test_build_splits_same_shape_by_engine_group_idx(self): assert groups_by_engine_group_idx[0].layer_indices == [0, 2] assert groups_by_engine_group_idx[1].layer_indices == [1, 3] - def test_build_rejects_bad_group_views(self): + def test_build_rejects_bad_engine_group_infos(self): tensors = [ torch.randn(2, 32, 256, 8, 64, dtype=torch.float16) for _ in range(2) ] @@ -116,7 +116,7 @@ def test_build_rejects_bad_group_views(self): _build_manager( tensors, num_blocks=32, - group_views=[LMCacheGroupView(0, (2,))], + engine_group_infos=[EngineGroupInfo(0, (2,))], ) def test_build_different_shapes(self): @@ -380,14 +380,14 @@ def test_empty_manager_has_no_groups(self): assert manager.num_object_groups == 0 def test_excluded_layer_left_out_of_all_groups(self): - # Layer 2 is referenced by no group view, so it is excluded entirely. + # Layer 2 is referenced by no engine group info, so it is excluded entirely. tensors = [ torch.randn(2, 32, 256, 8, 64, dtype=torch.float16) for _ in range(3) ] manager = _build_manager( tensors, num_blocks=32, - group_views=[LMCacheGroupView(0, (0, 1))], + engine_group_infos=[EngineGroupInfo(0, (0, 1))], ) grouped = sorted( idx for group in manager.kernel_groups for idx in group.layer_indices diff --git a/tests/v1/test_vllm_kv_cache_groups.py b/tests/v1/test_vllm_kv_cache_groups.py index aec934ada9..d6e990c40a 100644 --- a/tests/v1/test_vllm_kv_cache_groups.py +++ b/tests/v1/test_vllm_kv_cache_groups.py @@ -7,10 +7,10 @@ # First Party from lmcache.integration.vllm.kv_cache_groups import ( - create_group_views_from_vllm, + create_engine_group_infos_from_vllm, ) from lmcache.v1.multiprocess.group_view import ( - expand_block_ids_to_views, + expand_engine_block_ids, get_engine_group_indices, num_engine_groups, ) @@ -32,7 +32,7 @@ def _same_shape_caches(names: list[str]) -> dict[str, torch.Tensor]: def test_conversion_defaults_to_single_group_without_config(): """No vLLM KV cache groups -> all layers fall into a single engine group.""" - spec = create_group_views_from_vllm( + spec = create_engine_group_infos_from_vllm( None, _same_shape_caches(["layer.0", "layer.1"]) ) @@ -43,7 +43,7 @@ def test_conversion_defaults_to_single_group_without_config(): def test_conversion_preserves_engine_group_layers(): """Two engine groups with identical tensor shape stay separate by group.""" - spec = create_group_views_from_vllm( + spec = create_engine_group_infos_from_vllm( MockKVCacheConfig( kv_cache_groups=[ MockKVCacheGroup(["layer.0", "layer.2"]), @@ -62,7 +62,7 @@ def test_conversion_splits_by_lmcache_layer_identity(): caches = _same_shape_caches(["layer.0", "layer.1", "layer.2", "layer.3"]) # layer.4 has a different head count -> distinct transfer identity. caches["layer.4"] = torch.randn(2, 32, 16, 16, 64, dtype=torch.float16) - spec = create_group_views_from_vllm( + spec = create_engine_group_infos_from_vllm( MockKVCacheConfig( kv_cache_groups=[ MockKVCacheGroup(["layer.0", "layer.2", "layer.4"]), @@ -74,7 +74,7 @@ def test_conversion_splits_by_lmcache_layer_identity(): assert [group.engine_group_id for group in spec] == [0, 1, 0] assert [group.layer_indices for group in spec] == [(0, 2), (1, 3), (4,)] - assert expand_block_ids_to_views(spec, [[10], [20]]) == [ + assert expand_engine_block_ids(spec, [[10], [20]]) == [ [10], [20], [10], diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 635e7ea09d..54f5e36e23 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -25,7 +25,7 @@ LoadStoreOp, ParallelStrategy, ) -from lmcache.v1.multiprocess.group_view import LMCacheGroupView +from lmcache.v1.multiprocess.group_view import EngineGroupInfo from lmcache.v1.multiprocess.protocol import RequestType @@ -174,10 +174,10 @@ def test_submit_store_request_expands_block_ids_to_views(fake_adapter, monkeypat fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.kv_caches = {"layer.0": fake_tensor} - adapter.group_views = [ - LMCacheGroupView(0, (0, 2)), - LMCacheGroupView(0, (4,)), - LMCacheGroupView(1, (1, 3)), + adapter.engine_group_infos = [ + EngineGroupInfo(0, (0, 2)), + EngineGroupInfo(0, (4,)), + EngineGroupInfo(1, (1, 3)), ] transfer_ctx = MagicMock() fake_future = MagicMock() From bf1a215ec9072c11c5fd390fb6d5c8ab19fb17cb Mon Sep 17 00:00:00 2001 From: Yihua Cheng Date: Mon, 8 Jun 2026 19:10:36 -0700 Subject: [PATCH 08/57] [Refactor] Change the report_status to be per-kernel-group in LMCache (#3599) Signed-off-by: ApostaC --- docs/design/cli/commands/describe.md | 103 +++++++++++++----- docs/source/cli/describe.rst | 49 ++++++--- docs/source/mp/http_api.rst | 28 +++-- lmcache/cli/commands/describe.py | 69 ++++++++---- lmcache/v1/gpu_connector/utils.py | 70 ++++++++++++ lmcache/v1/multiprocess/gpu_context.py | 97 ++++++++++------- tests/cli/test_describe.py | 43 ++++++-- tests/v1/gpu_connector/test_concrete_shape.py | 68 ++++++++++++ tests/v1/multiprocess/test_gpu_context.py | 72 ++++++++---- 9 files changed, 451 insertions(+), 148 deletions(-) create mode 100644 tests/v1/gpu_connector/test_concrete_shape.py diff --git a/docs/design/cli/commands/describe.md b/docs/design/cli/commands/describe.md index 3aa2f729c7..0613a80620 100644 --- a/docs/design/cli/commands/describe.md +++ b/docs/design/cli/commands/describe.md @@ -33,15 +33,21 @@ Uptime: 2h 14m 32s ------ Model: meta-llama/Llama-3.1-70B-Instruct --- World size: 4 GPU IDs: 0, 1, 2, 3 -Attention backend: vLLM non-MLA flash attention -GPU KV shape: NL x [2, NB, BS, NH, HS] -GPU KV tensor shape: 80 x [2, 2048, 128, 8, 128] Num layers: 80 -Block size: 128 -Hidden dim size: 1024 +Num blocks: 2048 +Cache size per token (bytes): 327680 +--- Kernel group 0 (meta-llama/Llama-3.1-70B-Instruct) --- +Kernel group index: 0 +Engine group index: 0 +Object group index: 0 +Num layers: 80 +Physical block size: 128 +Compress ratio: 1 Dtype: torch.float16 MLA: False -Num blocks: 2048 +Attention backend: vLLM non-MLA flash attention +GPU KV shape: NL x [2, NB, BS, NH, HS] +GPU KV tensor shape: 80 x [2, 2048, 128, 8, 128] ----------- L2: NixlStoreL2Adapter ------------ Type: NixlStoreL2Adapter Health: OK @@ -67,15 +73,25 @@ programmatic access: "model": "meta-llama/Llama-3.1-70B-Instruct", "world_size": 4, "gpu_ids": "0, 1, 2, 3", - "attention_backend": "vLLM non-MLA flash attention", - "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", - "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]", "num_layers": 80, - "block_size": 128, - "hidden_dim_size": 1024, + "num_blocks": 2048, + "cache_size_per_token": 327680 + } + ], + "kernel_groups": [ + { + "model": "meta-llama/Llama-3.1-70B-Instruct", + "kernel_group_idx": 0, + "engine_group_idx": 0, + "object_group_idx": 0, + "num_layers": 80, + "physical_block_size": 128, + "compress_ratio": 1, "dtype": "torch.float16", "is_mla": false, - "num_blocks": 2048 + "attention_backend": "vLLM non-MLA flash attention", + "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", + "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]" } ], "l2_adapters": [ @@ -92,8 +108,21 @@ programmatic access: ``` Per-model sections are generated for each unique `(model_name, world_size)` pair -registered with the engine. The section includes: - +registered with the engine. The model section carries the context-wide fields — +`num_layers`, `num_blocks`, and `cache_size_per_token` — and is followed by one +**kernel group** section per kernel group, since a hybrid model's groups can +differ in geometry. + +Each kernel group section includes: + +- **Kernel / engine / object group index** — the group's identity: + `kernel_group_idx` enumerates the manager's kernel groups, `engine_group_idx` + is the paged-block address space (0 for non-hybrid), and `object_group_idx` is + the owning object group. +- **Num layers** and **Physical block size** — the group's layer count and + `shape_desc.bs`. +- **Compress ratio** — logical tokens per physical slot (1 for non-compressed). +- **Dtype** and **MLA** — the group's torch dtype and MLA flag. - **Attention backend** — which attention implementation is active (e.g., `vLLM non-MLA flash attention`, `vLLM MLA`, `SGLang MHA`), derived from the `GPUKVFormat` enum. @@ -101,9 +130,8 @@ registered with the engine. The section includes: `GPUKVFormat` enum (NB=num_blocks, NL=num_layers, BS=block_size, NH=num_heads, HS=head_size, PBS=page_buffer_size). E.g., `NL x [2, NB, BS, NH, HS]`. - **GPU KV tensor shape** — the same layout with actual numeric values substituted - (e.g., `80 x [2, 2048, 128, 8, 128]`). -- **Layout details** — num_layers, block_size, hidden_dim_size, dtype, MLA flag, - num_blocks. + from the group's `shape_desc` (e.g., `80 x [2, 2048, 128, 8, 128]`), so it is + group-accurate. L2 adapter sections are generated for each adapter in `storage_manager.l2_adapters`. Fields shown depend on the adapter type: @@ -258,22 +286,41 @@ Mirror the same `start_time`, `zmq_endpoint`, and `http_endpoint` additions if **Files:** `lmcache/v1/gpu_connector/utils.py`, `lmcache/v1/multiprocess/gpu_context.py`, `lmcache/v1/multiprocess/server.py` -Three new helper functions in `utils.py` (derived from `legible_print_gpu_kv_format()`): -- `get_gpu_kv_shape_description(gpu_kv_format)` — symbolic shape (e.g., `List[num_layers] of [2, num_blocks, ...]`) +Helper functions in `utils.py` (derived from `legible_print_gpu_kv_format()`): +- `get_gpu_kv_shape_description(gpu_kv_format)` — symbolic shape (e.g., `NL x [2, NB, BS, NH, HS]`) - `get_attention_backend(gpu_kv_format)` — backend name (e.g., `vLLM non-MLA flash attention`) -- `get_concrete_gpu_kv_shape(kv_caches, gpu_kv_format)` — shape with actual values (e.g., `List[80] of [2, 2048, 128, 8, 128]`) +- `get_concrete_gpu_kv_shape(kv_caches, gpu_kv_format)` — whole-context shape with actual values +- `get_concrete_gpu_kv_shape_from_shape_desc(shape_desc, gpu_kv_format)` — **group-accurate** shape with actual values, read from a single kernel group's `PageBufferShapeDesc` (used by `report_status`) -`GPUCacheContext` exposes these as properties: `gpu_kv_format_name`, `gpu_kv_shape`, `concrete_gpu_kv_shape`, `attention_backend`. - -`report_status()` includes them in the per-GPU `kv_cache_layout` dict: +`report_status()` is organised **per kernel group**: a small set of context-wide +fields at the top level, plus a `kernel_groups` list where each entry is +self-describing. The format-derived fields (`gpu_kv_format`, `gpu_kv_shape`, +`attention_backend`, `is_mla`) and the group-accurate `gpu_kv_concrete_shape` +live inside each group: ```python "kv_cache_layout": { - ..., - "gpu_kv_format": "NL_X_TWO_NB_BS_NH_HS", - "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", - "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]", - "attention_backend": "vLLM non-MLA flash attention", + "num_layers": 80, + "inference_engine_logical_block_size": 128, + "num_blocks": 2048, + "cache_size_per_token": 327680, + "kernel_groups": [ + { + "kernel_group_idx": 0, + "engine_group_idx": 0, + "object_group_idx": 0, + "num_layers": 80, + "layer_indices": [0, 1, ...], + "physical_block_size": 128, + "compress_ratio": 1, + "dtype": "torch.float16", + "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]", + "is_mla": false, + "gpu_kv_format": "NL_X_TWO_NB_BS_NH_HS", + "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", + "attention_backend": "vLLM non-MLA flash attention", + }, + ], } ``` diff --git a/docs/source/cli/describe.rst b/docs/source/cli/describe.rst index 15b7e46e90..1b07751210 100644 --- a/docs/source/cli/describe.rst +++ b/docs/source/cli/describe.rst @@ -25,15 +25,21 @@ L2 adapters. Model: meta-llama/Llama-3.1-70B-Instruct World size: 4 GPU IDs: 0, 1, 2, 3 - Attention backend: vLLM non-MLA flash attention - GPU KV shape: NL x [2, NB, BS, NH, HS] - GPU KV tensor shape: 80 x [2, 2048, 128, 8, 128] Num layers: 80 - Block size: 128 - Hidden dim sizes: 1024 + Num blocks: 2048 + Cache size per token (bytes): 327680 + --- Kernel group 0 (meta-llama/Llama-3.1-70B-Instruct) --- + Kernel group index: 0 + Engine group index: 0 + Object group index: 0 + Num layers: 80 + Physical block size: 128 + Compress ratio: 1 Dtype: torch.float16 MLA: False - Num blocks: 2048 + Attention backend: vLLM non-MLA flash attention + GPU KV shape: NL x [2, NB, BS, NH, HS] + GPU KV tensor shape: 80 x [2, 2048, 128, 8, 128] ------------- L2: NixlStoreL2Adapter ------------- Type: NixlStoreL2Adapter Health: OK @@ -46,8 +52,9 @@ The output shows: - **Overview** — health status, engine type, chunk size. - **L1 storage** — capacity, usage, eviction policy, cached object count. -- **Registered models** — per-model KV cache layout including the GPU KV - tensor shape (symbolic and concrete), attention backend, and layer details. +- **Registered models** — per-model KV cache layout: a context-wide summary + followed by one kernel group section per kernel group, each with the GPU KV + tensor shape (symbolic and concrete), attention backend, and group geometry. - **L2 adapters** — type, health, backend, stored objects, and utilization. Options @@ -74,8 +81,8 @@ Options JSON Output ----------- -Use ``--format json`` for machine-readable output. Models and L2 adapters -are collected into lists for easy programmatic access: +Use ``--format json`` for machine-readable output. Models, kernel groups, and +L2 adapters are collected into lists for easy programmatic access: .. code-block:: bash @@ -100,15 +107,25 @@ are collected into lists for easy programmatic access: "model": "meta-llama/Llama-3.1-70B-Instruct", "world_size": 4, "gpu_ids": "0, 1, 2, 3", - "attention_backend": "vLLM non-MLA flash attention", - "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", - "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]", "num_layers": 80, - "block_size": 128, - "hidden_dim_sizes": [1024], + "num_blocks": 2048, + "cache_size_per_token": 327680 + } + ], + "kernel_groups": [ + { + "model": "meta-llama/Llama-3.1-70B-Instruct", + "kernel_group_idx": 0, + "engine_group_idx": 0, + "object_group_idx": 0, + "num_layers": 80, + "physical_block_size": 128, + "compress_ratio": 1, "dtype": "torch.float16", "is_mla": false, - "num_blocks": 2048 + "attention_backend": "vLLM non-MLA flash attention", + "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", + "gpu_kv_concrete_shape": "80 x [2, 2048, 128, 8, 128]" } ], "l2_adapters": [ diff --git a/docs/source/mp/http_api.rst b/docs/source/mp/http_api.rst index 1ac69b9a17..06086b76de 100644 --- a/docs/source/mp/http_api.rst +++ b/docs/source/mp/http_api.rst @@ -237,16 +237,26 @@ prefetch jobs. Intended for operators and debugging, not for monitoring "world_size": 1, "kv_cache_layout": { "num_layers": 32, - "block_size": 16, - "hidden_dim_sizes": "...", - "dtype": "torch.bfloat16", - "is_mla": false, + "inference_engine_logical_block_size": 16, "num_blocks": 12345, - "gpu_kv_format": "...", - "gpu_kv_shape": "...", - "gpu_kv_concrete_shape": "...", - "attention_backend": "...", - "cache_size_per_token": 131072 + "cache_size_per_token": 131072, + "kernel_groups": [ + { + "kernel_group_idx": 0, + "engine_group_idx": 0, + "object_group_idx": 0, + "num_layers": 32, + "layer_indices": [0, 1, "..."], + "physical_block_size": 16, + "compress_ratio": 1, + "dtype": "torch.bfloat16", + "gpu_kv_concrete_shape": "...", + "is_mla": false, + "gpu_kv_format": "...", + "gpu_kv_shape": "...", + "attention_backend": "..." + } + ] } } }, diff --git a/lmcache/cli/commands/describe.py b/lmcache/cli/commands/describe.py index ad98d4d8e7..d23350821b 100644 --- a/lmcache/cli/commands/describe.py +++ b/lmcache/cli/commands/describe.py @@ -171,11 +171,13 @@ def add_l1_storage(self) -> None: ) def add_models(self) -> None: - """Per-model KV cache layout sections.""" + """Per-model KV cache layout sections. + + Each model gets one section with context-wide fields, followed by + one ``kernel_groups`` list entry per kernel group carrying that + group's identity and geometry. + """ gpu_meta = self.data.get("cache_context_meta", {}) - if not gpu_meta: - # CB-only deployments populate cb_gpu_context_meta instead. - gpu_meta = self.data.get("cb_gpu_context_meta", {}) if not gpu_meta: return @@ -202,33 +204,54 @@ def add_models(self) -> None: layout = info.get("layout") if not layout: continue - sec.add( - "attention_backend", - "Attention backend", - layout.get("attention_backend"), - ) - sec.add("gpu_kv_shape", "GPU KV shape", layout.get("gpu_kv_shape")) - sec.add( - "gpu_kv_concrete_shape", - "GPU KV tensor shape", - layout.get("gpu_kv_concrete_shape"), - ) - # CB-only contexts ship a singular ``hidden_dim_size``; wrap to - # match the plural list-shape used by the regular path. - if "hidden_dim_sizes" not in layout and "hidden_dim_size" in layout: - layout = dict(layout, hidden_dim_sizes=[layout["hidden_dim_size"]]) for _key, _label in ( ("num_layers", "Num layers"), - ("block_size", "Block size"), - ("hidden_dim_sizes", "Hidden dim sizes"), - ("dtype", "Dtype"), - ("is_mla", "MLA"), ("num_blocks", "Num blocks"), ("cache_size_per_token", "Cache size per token (bytes)"), ): if _key in layout: sec.add(_key, _label, layout[_key]) + self._add_kernel_groups(idx, model_name, layout.get("kernel_groups", [])) + + def _add_kernel_groups( + self, model_idx: int, model_name: str, kernel_groups: list + ) -> None: + """Emit one ``kernel_groups`` list section per kernel group. + + Args: + model_idx: Index of the owning model section (keeps section keys + unique across models). + model_name: Human-readable model name, shown in each group header. + kernel_groups: The model layout's ``kernel_groups`` list (each a + dict produced by ``GPUCacheContext.report_status``). + """ + for group in kernel_groups: + kg_idx = group.get("kernel_group_idx") + section_key = f"model_{model_idx}_kg_{kg_idx}" + self.metrics.add_list_section( + "kernel_groups", + section_key, + f"Kernel group {kg_idx} ({model_name})", + ) + sec = self.metrics[section_key] + sec.add("model", "Model", model_name) + for _key, _label in ( + ("kernel_group_idx", "Kernel group index"), + ("engine_group_idx", "Engine group index"), + ("object_group_idx", "Object group index"), + ("num_layers", "Num layers"), + ("physical_block_size", "Physical block size"), + ("compress_ratio", "Compress ratio"), + ("dtype", "Dtype"), + ("is_mla", "MLA"), + ("attention_backend", "Attention backend"), + ("gpu_kv_shape", "GPU KV shape"), + ("gpu_kv_concrete_shape", "GPU KV tensor shape"), + ): + if _key in group: + sec.add(_key, _label, group[_key]) + def add_l2_adapters(self) -> None: """L2 adapter sections.""" l2_adapters = safe_get(self.data, "storage_manager", "l2_adapters") or [] diff --git a/lmcache/v1/gpu_connector/utils.py b/lmcache/v1/gpu_connector/utils.py index 8fdf21387d..8e39bbf672 100644 --- a/lmcache/v1/gpu_connector/utils.py +++ b/lmcache/v1/gpu_connector/utils.py @@ -427,6 +427,76 @@ def get_concrete_gpu_kv_shape( return f"Unknown ({gpu_kv_format})" +def get_concrete_gpu_kv_shape_from_shape_desc( + shape_desc: "lmc_ops.PageBufferShapeDesc", + gpu_kv_format: "lmc_ops.GPUKVFormat", +) -> str: + """Return the concrete shape for a single kernel group's ``shape_desc``. + + Like :func:`get_concrete_gpu_kv_shape`, but the numeric values are + read from a per-group :class:`PageBufferShapeDesc` instead of from + the whole ``kv_caches`` structure. This makes the result + *group-accurate*: ``shape_desc.nl`` is the number of layers in the + group (not the model total), so for hybrid models each kernel group + reports its own shape. + + For example, instead of ``NL x [2, NB, BS, NH, HS]`` this returns + ``80 x [2, 2048, 128, 8, 128]``. + + Args: + shape_desc: The kernel group's shape descriptor. Numeric values + are pulled from its ``nl``/``nb``/``bs``/``nh``/``hs`` fields; + the page-buffer-size (``PBS``) formats use ``nb * bs``. + gpu_kv_format: The GPU KV format that determines the symbolic + shape template. + + Returns: + The shape string with numeric values substituted, or + ``"Unknown ()"`` for an unrecognised format. + """ + nl = shape_desc.nl + nb = shape_desc.nb + bs = shape_desc.bs + nh = shape_desc.nh + hs = shape_desc.hs + pbs = nb * bs + + fmt = gpu_kv_format + F = lmc_ops.GPUKVFormat + + if fmt == F.NB_NL_TWO_BS_NH_HS: + return f"[{nb}, {nl}, 2, {bs}, {nh}, {hs}]" + + if fmt == F.NL_X_TWO_NB_BS_NH_HS: + return f"{nl} x [2, {nb}, {bs}, {nh}, {hs}]" + + if fmt == F.NL_X_NB_TWO_BS_NH_HS: + return f"{nl} x [{nb}, 2, {bs}, {nh}, {hs}]" + + if fmt == F.NL_X_NB_BS_HS: + return f"{nl} x [{nb}, {bs}, {hs}]" + + if fmt == F.TWO_X_NL_X_NBBS_NH_HS: + return f"2 x {nl} x [{pbs}, {nh}, {hs}]" + + if fmt == F.TWO_X_NL_X_NB_BS_NH_HS: + return f"2 x {nl} x [{nb}, {bs}, {nh}, {hs}]" + + if fmt == F.NL_X_NBBS_ONE_HS: + return f"{nl} x [{pbs}, 1, {hs}]" + + if fmt == F.NL_X_TWO_NB_NH_BS_HS: + return f"{nl} x [2, {nb}, {nh}, {bs}, {hs}]" + + if fmt == F.NL_X_NB_TWO_NH_BS_HS: + return f"{nl} x [{nb}, 2, {nh}, {bs}, {hs}]" + + if fmt == F.NB_NL_TWO_NH_BS_HS: + return f"[{nb}, {nl}, 2, {nh}, {bs}, {hs}]" + + return f"Unknown ({gpu_kv_format})" + + def legible_print_gpu_kv_format(gpu_kv_format: "lmc_ops.GPUKVFormat"): """ Print the GPU KV Format in a legible way diff --git a/lmcache/v1/multiprocess/gpu_context.py b/lmcache/v1/multiprocess/gpu_context.py index b3eef93296..ac20961721 100644 --- a/lmcache/v1/multiprocess/gpu_context.py +++ b/lmcache/v1/multiprocess/gpu_context.py @@ -26,7 +26,7 @@ from lmcache.v1.gpu_connector.utils import ( LayoutHints, get_attention_backend, - get_concrete_gpu_kv_shape, + get_concrete_gpu_kv_shape_from_shape_desc, get_device, get_dtype, get_gpu_kv_shape_description, @@ -639,59 +639,74 @@ def cache_size_per_token(self) -> int: def report_status(self) -> dict: """Return this context's KV cache layout metadata for ``/status``. - Builds the ``kv_cache_layout`` sub-dict surfaced by the ``/status`` - HTTP endpoint (see ``GPUTransferModule.report_status``) and consumed by - the ``lmcache`` CLI (``lmcache describe kvcache`` and - ``lmcache bench engine``). It describes only the KV cache geometry; the - owning module wraps it with ``model_name``/``world_size``, which this - context does not track. - Returns: - A dict with one entry per documented layout field: + A dict with these top-level fields: - - ``num_layers`` (int) + - ``num_layers`` (int): total layers in the model. - ``inference_engine_logical_block_size`` (int) - - ``group_physical_block_sizes`` (list[int]): per-group - ``shape_desc.bs`` - - ``group_compress_ratios`` (list[int]): per-group compress ratio - - ``hidden_dim_sizes`` (str): stringified per-group hidden-dim list - - ``dtype`` (str): stringified torch dtype - - ``is_mla`` (bool) - ``num_blocks`` (int) - - ``gpu_kv_format`` (str): GPU KV format enum name - - ``gpu_kv_shape`` (str): symbolic shape description - - ``gpu_kv_concrete_shape`` (str): shape with numeric values - - ``attention_backend`` (str) - - ``cache_size_per_token`` (int): bytes per logical token - """ - # TODO(compat): the key names and value *formatting* below are a - # contract with the `/status` endpoint and the `lmcache` CLI - # (`lmcache/cli/commands/describe.py`, `bench/engine_bench/config.py`). - # Renaming a key breaks `lmcache describe kvcache`; dropping - # `cache_size_per_token` breaks `lmcache bench engine`. `hidden_dim_sizes` - # and `dtype` are stringified only for back-compat with those consumers - # and should become a real list / structured value once the CLI is - # updated to parse them. + - ``cache_size_per_token`` (int): bytes per logical token, + summed across groups. + - ``kernel_groups`` (list[dict]): one entry per kernel group, + each with: + + - ``kernel_group_idx`` (int): index into ``manager.kernel_groups``. + - ``engine_group_idx`` (int): paged-block address space. + - ``object_group_idx`` (int): owning object group. + - ``num_layers`` (int): layers in this group. + - ``layer_indices`` (list[int]): the group's layer indices. + - ``physical_block_size`` (int): ``shape_desc.bs``. + - ``compress_ratio`` (int) + - ``dtype`` (str): stringified torch dtype. + - ``gpu_kv_concrete_shape`` (str): group-accurate numeric shape. + - ``is_mla`` (bool) + - ``gpu_kv_format`` (str): GPU KV format enum name. + - ``gpu_kv_shape`` (str): symbolic shape description. + - ``attention_backend`` (str) + """ manager = self.kv_layer_groups_manager kernel_groups = manager.kernel_groups + + # Reverse-map each kernel group to its owning object group. + kernel_group_to_object_group: dict[int, int] = { + kg_idx: og_idx + for og_idx, og in enumerate(manager.object_groups) + for kg_idx in og.kernel_group_indices + } + + gpu_kv_format = self.gpu_kv_format_ + group_reports: list[dict] = [] + for kernel_group_idx, group in enumerate(kernel_groups): + group_reports.append( + { + "kernel_group_idx": kernel_group_idx, + "engine_group_idx": group.engine_group_idx, + "object_group_idx": kernel_group_to_object_group.get( + kernel_group_idx, 0 + ), + "num_layers": group.num_layers, + "layer_indices": list(group.layer_indices), + "physical_block_size": group.shape_desc.bs, + "compress_ratio": group.compress_ratio, + "dtype": str(group.dtype), + "gpu_kv_concrete_shape": get_concrete_gpu_kv_shape_from_shape_desc( + group.shape_desc, gpu_kv_format + ), + "is_mla": is_mla(gpu_kv_format), + "gpu_kv_format": gpu_kv_format.name, + "gpu_kv_shape": get_gpu_kv_shape_description(gpu_kv_format), + "attention_backend": get_attention_backend(gpu_kv_format), + } + ) + return { "num_layers": self.num_layers, "inference_engine_logical_block_size": ( manager.inference_engine_logical_block_size ), - "group_physical_block_sizes": [g.shape_desc.bs for g in kernel_groups], - "group_compress_ratios": [g.compress_ratio for g in kernel_groups], - "hidden_dim_sizes": str([g.hidden_dim_size for g in kernel_groups]), - "dtype": str(self.dtype), - "is_mla": self.is_mla, "num_blocks": self.num_blocks, - "gpu_kv_format": self.gpu_kv_format_.name, - "gpu_kv_shape": get_gpu_kv_shape_description(self.gpu_kv_format_), - "gpu_kv_concrete_shape": get_concrete_gpu_kv_shape( - self.kv_caches_, self.gpu_kv_format_ - ), - "attention_backend": get_attention_backend(self.gpu_kv_format_), "cache_size_per_token": self.cache_size_per_token(), + "kernel_groups": group_reports, } diff --git a/tests/cli/test_describe.py b/tests/cli/test_describe.py index e493c86cd2..c99aea8458 100644 --- a/tests/cli/test_describe.py +++ b/tests/cli/test_describe.py @@ -35,12 +35,26 @@ "world_size": 1, "kv_cache_layout": { "num_layers": 32, - "block_size": 16, - "hidden_dim_sizes": 128, - "dtype": "torch.float16", - "is_mla": False, + "inference_engine_logical_block_size": 16, "num_blocks": 2048, "cache_size_per_token": 163840, + "kernel_groups": [ + { + "kernel_group_idx": 0, + "engine_group_idx": 0, + "object_group_idx": 0, + "num_layers": 32, + "layer_indices": list(range(32)), + "physical_block_size": 16, + "compress_ratio": 1, + "dtype": "torch.float16", + "gpu_kv_concrete_shape": "32 x [2, 2048, 16, 8, 128]", + "is_mla": False, + "gpu_kv_format": "NL_X_TWO_NB_BS_NH_HS", + "gpu_kv_shape": "NL x [2, NB, BS, NH, HS]", + "attention_backend": "vLLM non-MLA flash attention", + }, + ], }, }, }, @@ -180,11 +194,24 @@ class FakeArgs: assert model["world_size"] == 1 assert model["gpu_ids"] == "0" assert model["num_layers"] == 32 - assert model["block_size"] == 16 - assert model["hidden_dim_sizes"] == 128 - assert model["dtype"] == "torch.float16" - assert model["is_mla"] is False assert model["num_blocks"] == 2048 + assert model["cache_size_per_token"] == 163840 + + # Per-kernel-group section (list) + assert "kernel_groups" in m + kg = m["kernel_groups"][0] + assert kg["model"] == "llama" + assert kg["kernel_group_idx"] == 0 + assert kg["engine_group_idx"] == 0 + assert kg["object_group_idx"] == 0 + assert kg["num_layers"] == 32 + assert kg["physical_block_size"] == 16 + assert kg["compress_ratio"] == 1 + assert kg["dtype"] == "torch.float16" + assert kg["is_mla"] is False + assert kg["attention_backend"] == "vLLM non-MLA flash attention" + assert kg["gpu_kv_shape"] == "NL x [2, NB, BS, NH, HS]" + assert kg["gpu_kv_concrete_shape"] == "32 x [2, 2048, 16, 8, 128]" def test_unhealthy(self): """Verify health shows UNHEALTHY when is_healthy is False.""" diff --git a/tests/v1/gpu_connector/test_concrete_shape.py b/tests/v1/gpu_connector/test_concrete_shape.py new file mode 100644 index 0000000000..2abc5fc7d8 --- /dev/null +++ b/tests/v1/gpu_connector/test_concrete_shape.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for :func:`get_concrete_gpu_kv_shape_from_shape_desc`. + +These run without a CUDA build: ``lmcache.c_ops`` resolves to the +pure-Python fallback, which provides both ``PageBufferShapeDesc`` and +``GPUKVFormat``. +""" + +# First Party +from lmcache.v1.gpu_connector.utils import ( + get_concrete_gpu_kv_shape_from_shape_desc, +) +import lmcache.c_ops as lmc_ops + + +def _make_shape_desc( + *, kv_size: int, nl: int, nb: int, bs: int, nh: int, hs: int +) -> "lmc_ops.PageBufferShapeDesc": + """Build a ``PageBufferShapeDesc`` with the given geometry.""" + sd = lmc_ops.PageBufferShapeDesc() + sd.kv_size = kv_size + sd.nl = nl + sd.nb = nb + sd.bs = bs + sd.nh = nh + sd.hs = hs + sd.element_size = 2 + sd.block_stride_elems = 0 + return sd + + +def test_concrete_shape_vllm_flash_attn(): + sd = _make_shape_desc(kv_size=2, nl=32, nb=2048, bs=16, nh=8, hs=128) + out = get_concrete_gpu_kv_shape_from_shape_desc( + sd, lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + ) + assert out == "32 x [2, 2048, 16, 8, 128]" + + +def test_concrete_shape_vllm_mla(): + sd = _make_shape_desc(kv_size=1, nl=61, nb=1024, bs=64, nh=1, hs=512) + out = get_concrete_gpu_kv_shape_from_shape_desc( + sd, lmc_ops.GPUKVFormat.NL_X_NB_BS_HS + ) + assert out == "61 x [1024, 64, 512]" + + +def test_concrete_shape_uses_pbs_for_folded_formats(): + # NL_X_NBBS_ONE_HS folds num_blocks * block_size into one PBS dim. + sd = _make_shape_desc(kv_size=1, nl=2, nb=32, bs=16, nh=1, hs=128) + out = get_concrete_gpu_kv_shape_from_shape_desc( + sd, lmc_ops.GPUKVFormat.NL_X_NBBS_ONE_HS + ) + assert out == "2 x [512, 1, 128]" # 512 == 32 * 16 + + +def test_concrete_shape_is_group_accurate(): + # Two groups with different layer counts produce different shapes for + # the same format — the whole-context helper could not do this. + fmt = lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS + g0 = _make_shape_desc(kv_size=2, nl=4, nb=128, bs=16, nh=8, hs=64) + g1 = _make_shape_desc(kv_size=2, nl=2, nb=128, bs=16, nh=16, hs=64) + assert get_concrete_gpu_kv_shape_from_shape_desc(g0, fmt) == ( + "4 x [2, 128, 16, 8, 64]" + ) + assert get_concrete_gpu_kv_shape_from_shape_desc(g1, fmt) == ( + "2 x [2, 128, 16, 16, 64]" + ) diff --git a/tests/v1/multiprocess/test_gpu_context.py b/tests/v1/multiprocess/test_gpu_context.py index 535cdf4ef3..75790bed51 100644 --- a/tests/v1/multiprocess/test_gpu_context.py +++ b/tests/v1/multiprocess/test_gpu_context.py @@ -474,41 +474,67 @@ def test_calculate_num_blocks_matches_manager(self) -> None: class TestGPUCacheContextReportStatus: + _TOP_LEVEL_KEYS = { + "num_layers", + "inference_engine_logical_block_size", + "num_blocks", + "cache_size_per_token", + "kernel_groups", + } + _GROUP_KEYS = { + "kernel_group_idx", + "engine_group_idx", + "object_group_idx", + "num_layers", + "layer_indices", + "physical_block_size", + "compress_ratio", + "dtype", + "gpu_kv_concrete_shape", + "is_mla", + "gpu_kv_format", + "gpu_kv_shape", + "attention_backend", + } + def test_report_status_fields(self) -> None: ctx = _make_context(_SINGLE_GROUP) status = ctx.report_status() - expected_keys = { - "num_layers", - "inference_engine_logical_block_size", - "group_physical_block_sizes", - "group_compress_ratios", - "hidden_dim_sizes", - "dtype", - "is_mla", - "num_blocks", - "gpu_kv_format", - "gpu_kv_shape", - "gpu_kv_concrete_shape", - "attention_backend", - "cache_size_per_token", - } - assert set(status.keys()) == expected_keys - + assert set(status.keys()) == self._TOP_LEVEL_KEYS assert status["num_layers"] == 4 - assert status["is_mla"] is False - assert status["group_compress_ratios"] == [1] - assert status["gpu_kv_format"] == "NL_X_TWO_NB_BS_NH_HS" - assert status["dtype"] == str(ctx.dtype) assert status["cache_size_per_token"] == ctx.cache_size_per_token() + assert len(status["kernel_groups"]) == 1 + group = status["kernel_groups"][0] + assert set(group.keys()) == self._GROUP_KEYS + assert group["kernel_group_idx"] == 0 + assert group["num_layers"] == 4 + assert group["layer_indices"] == [0, 1, 2, 3] + assert group["is_mla"] is False + assert group["compress_ratio"] == 1 + assert group["gpu_kv_format"] == "NL_X_TWO_NB_BS_NH_HS" + assert group["dtype"] == str(ctx.dtype) + def test_report_status_multi_group(self) -> None: ctx = _make_context(_MULTI_GROUP) manager = ctx.kv_layer_groups_manager status = ctx.report_status() assert status["num_layers"] == 6 - assert len(status["group_physical_block_sizes"]) == manager.num_kernel_groups - assert len(status["group_compress_ratios"]) == manager.num_kernel_groups + assert len(status["kernel_groups"]) == manager.num_kernel_groups + + # Group reports enumerate in order and stay self-consistent with the + # manager's kernel groups. + for kg_idx, (group, kernel_group) in enumerate( + zip(status["kernel_groups"], manager.kernel_groups, strict=False) + ): + assert set(group.keys()) == self._GROUP_KEYS + assert group["kernel_group_idx"] == kg_idx + assert group["engine_group_idx"] == kernel_group.engine_group_idx + assert group["num_layers"] == kernel_group.num_layers + assert group["physical_block_size"] == kernel_group.shape_desc.bs + assert group["compress_ratio"] == kernel_group.compress_ratio + assert 0 <= group["object_group_idx"] < manager.num_object_groups if __name__ == "__main__": From cb193c741be965178f8228ae6ce85d72ce7c8c44 Mon Sep 17 00:00:00 2001 From: sonimwang <17816198144@163.com> Date: Tue, 9 Jun 2026 10:43:34 +0800 Subject: [PATCH 09/57] fix(zh_CN): correct machine translation errors in documentation (#3592) Signed-off-by: sonimwang <17816198144@163.com> --- docs/source/locale/zh_CN/LC_MESSAGES/cli/index.po | 4 ++-- .../LC_MESSAGES/developer_guide/contributing.po | 8 ++++---- docs/source/locale/zh_CN/LC_MESSAGES/index.po | 2 +- .../zh_CN/LC_MESSAGES/kv_cache_management/index.po | 4 ++-- docs/source/locale/zh_CN/LC_MESSAGES/mp/index.po | 4 ++-- .../locale/zh_CN/LC_MESSAGES/recipes/index.po | 14 +++++++------- .../locale/zh_CN/LC_MESSAGES/recipes/minimax_m2.po | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/cli/index.po b/docs/source/locale/zh_CN/LC_MESSAGES/cli/index.po index b189725888..1bb99fa3d3 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/cli/index.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/cli/index.po @@ -77,7 +77,7 @@ msgstr "``lmcache-cli``" #: ../../source/cli/index.rst:34 msgid "``pip install lmcache-cli``" -msgstr "``pip install kvcache``" +msgstr "``pip install lmcache-cli``" #: ../../source/cli/index.rst:35 msgid "" @@ -147,7 +147,7 @@ msgstr "对推理引擎(``engine``)、LMCache MP 服务器(``server``) #: ../../source/cli/index.rst:64 msgid ":doc:`kvcache`" -msgstr ":kvcache:" +msgstr ":doc:`kvcache`" #: ../../source/cli/index.rst:65 msgid "Manage KV cache state (e.g. clear L1 cache) on a running server." diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contributing.po b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contributing.po index 8dff10e725..366ba73c1c 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contributing.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contributing.po @@ -28,7 +28,7 @@ msgid "" "Thank you for your interest in contributing to LMCache! We welcome and " "accept all kinds of contributions, no matter how small or large. There " "are several ways you can contribute to the project:" -msgstr "感谢您对为 LMCache 贡献的兴趣!我们欢迎并接受各种形式的贡献,无论大小。您可以通过几种方式为项目做出贡献:" +msgstr "感谢您有兴趣为 LMCache 做出贡献!我们欢迎并接受各种形式的贡献,无论大小。您可以通过以下几种方式为项目做出贡献:" #: ../../source/developer_guide/contributing.rst:6 msgid "Identify and report any issues or bugs" @@ -51,7 +51,7 @@ msgid "" "A comprehensive list of good first issues can be found in the issue " "`[Onboarding][Q4] Welcoming contributors with good first issues! " "`_." -msgstr "可以在问题 `[Onboarding][Q4] 欢迎贡献者的好第一问题! `_ 中找到一个全面的好第一问题列表。" +msgstr "可以在 Issue `[Onboarding][Q4] Welcoming contributors with good first issues! `_ 中找到完整的适合新手的 Issue 列表。" #: ../../source/developer_guide/contributing.rst:13 msgid "" @@ -103,7 +103,7 @@ msgid "" "and there is always a need for more test coverage. If you see something " "that you think should be fixed, take ownership! Here is how you get " "started." -msgstr "对开源项目的帮助总是受欢迎的,总有一些可以改进的地方。例如,文档(就像您现在正在阅读的文本)总是可以改进,代码总是可以更清晰,变量或函数总是可以重命名或添加注释,并且总是需要更多的测试覆盖率。如果您看到认为应该修复的内容,请主动承担责任!以下是您如何开始的指南。" +msgstr "对开源项目的帮助总是受欢迎的,总有一些可以改进的地方。例如,文档(就像您现在正在阅读的文本)总是可以改进,代码总是可以更清晰,变量或函数总是可以重命名或添加注释,并且总是需要更多的测试覆盖率。如果您看到认为应该修复的内容,请主动承担责任!以下是入门指南。" #: ../../source/developer_guide/contributing.rst:33 msgid "How Can I Contribute?" @@ -487,7 +487,7 @@ msgstr "在 http://localhost:8000 本地服务文档页面: :code:`python -m h #: ../../source/developer_guide/contributing.rst:201 msgid "Thank You" -msgstr "谢谢你" +msgstr "感谢" #: ../../source/developer_guide/contributing.rst:203 msgid "" diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/index.po b/docs/source/locale/zh_CN/LC_MESSAGES/index.po index d896b3313d..d83dc9acf8 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/index.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/index.po @@ -29,7 +29,7 @@ msgstr "入门指南" #: ../../source/index.rst:86 msgid "Recipes" -msgstr "食谱" +msgstr "使用指南" #: ../../source/index.rst:94 msgid "KV Cache offloading and sharing" diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/kv_cache_management/index.po b/docs/source/locale/zh_CN/LC_MESSAGES/kv_cache_management/index.po index 63d3d35f23..afd0ff958b 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/kv_cache_management/index.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/kv_cache_management/index.po @@ -120,7 +120,7 @@ msgstr ":ref:`压缩 `: 压缩 KV Cache。" #: ../../source/kv_cache_management/index.rst:41 msgid ":ref:`Health `: Check the health status of cache workers." -msgstr "`:ref:`Health `: 检查缓存工作线程的健康状态。`" +msgstr ":ref:`Health `: 检查缓存工作线程的健康状态。" #: ../../source/kv_cache_management/index.rst:42 msgid ":ref:`Lookup `: Lookup the KV cache for a given list of tokens." @@ -138,7 +138,7 @@ msgstr ":ref:`Pin `: 持久化 KV Cache 以防止其被逐出。" msgid "" ":ref:`CheckFinish `: Check whether a (non-blocking) control" " event has finished or not." -msgstr "`:ref:`CheckFinish `: 检查一个(非阻塞)控制事件是否已经完成。`" +msgstr ":ref:`CheckFinish `: 检查一个(非阻塞)控制事件是否已经完成。" #: ../../source/kv_cache_management/index.rst:46 msgid ":ref:`QueryWorkerInfo `: Query the worker info." diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/mp/index.po b/docs/source/locale/zh_CN/LC_MESSAGES/mp/index.po index 31459db2e0..63182bf8d1 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/mp/index.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/mp/index.po @@ -37,7 +37,7 @@ msgstr "LMCache 多进程 (MP) 模式将 LMCache 作为一个 **独立服务** #: ../../source/mp/index.rst:10 msgid "Key Benefits" -msgstr "关键好处" +msgstr "主要优势" #: ../../source/mp/index.rst:12 msgid "" @@ -126,7 +126,7 @@ msgstr "``python3 -m lmcache.v1.multiprocess.server``" msgid "" "(Legacy) ZMQ-only server using MPCacheEngine (no HTTP endpoints). Prefer " "``lmcache server``." -msgstr "(遗留) 仅使用 MPCacheEngine 的 ZMQ 服务器(没有 HTTP 端点)。请使用 ``lmcache server``。" +msgstr "(遗留)仅使用 MPCacheEngine 的 ZMQ 服务器(没有 HTTP 端点)。请使用 ``lmcache server``。" #: ../../source/mp/index.rst:51 msgid "``python3 -m lmcache.v1.multiprocess.blend_server_v2``" diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/recipes/index.po b/docs/source/locale/zh_CN/LC_MESSAGES/recipes/index.po index 87474a844c..ced9f6bde5 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/recipes/index.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/recipes/index.po @@ -21,7 +21,7 @@ msgstr "" #: ../../source/recipes/index.rst:4 msgid "Recipes" -msgstr "食谱" +msgstr "使用指南" #: ../../source/recipes/index.rst:6 msgid "" @@ -34,15 +34,15 @@ msgstr "本节列出了经过 LMCache 端到端验证的模型架构,每个架 msgid "" "Engine-side documentation (how to serve the model itself) lives with the " "serving engine. Recipe pages link out rather than duplicate." -msgstr "引擎端文档(如何服务模型本身)与服务引擎一起存在。食谱页面链接而不是重复。" +msgstr "引擎端文档(如何服务模型本身)随服务引擎一起维护。使用指南页面提供外部链接,不重复已有内容。" #: ../../source/recipes/index.rst:14 msgid "Recipe page contents" -msgstr "食谱页面内容" +msgstr "使用指南页面内容" #: ../../source/recipes/index.rst:16 msgid "Each recipe page is intentionally minimal:" -msgstr "每个食谱页面都故意保持简约:" +msgstr "每个使用指南页面都故意保持简约:" #: ../../source/recipes/index.rst:18 msgid "**Validated models** -- exact HF repo IDs that have been tested." @@ -81,7 +81,7 @@ msgid "" msgstr "" "有关通用 LMCache + 引擎连接(端口、远程主机、进程内模式、发送第一个请求),请参阅 " ":doc:`../getting_started/quickstart` 和 " -":doc:`../mp/quickstart`。食谱假设这些页面是先决条件。" +":doc:`../mp/quickstart`。使用指南假设这些页面是先决条件。" #: ../../source/recipes/index.rst:33 msgid "Supported architectures" @@ -109,7 +109,7 @@ msgstr "TRT-LLM" #: ../../source/recipes/index.rst:44 msgid "Recipe" -msgstr "食谱" +msgstr "使用指南" #: ../../source/recipes/index.rst:46 msgid "``MiniMaxM2ForCausalLM``" @@ -231,7 +231,7 @@ msgstr "图例:``✓`` 已验证,``—`` 未验证。" #: ../../source/recipes/index.rst:105 msgid "Contributing a recipe" -msgstr "贡献一个食谱" +msgstr "贡献一个使用指南" #: ../../source/recipes/index.rst:107 msgid "To add a new architecture:" diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/recipes/minimax_m2.po b/docs/source/locale/zh_CN/LC_MESSAGES/recipes/minimax_m2.po index 5d7ad47c9d..b90f18d9fd 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/recipes/minimax_m2.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/recipes/minimax_m2.po @@ -98,7 +98,7 @@ msgid "" "`_, " "`MiniMax M2.5/M2.1/M2 usage guide " "`_." -msgstr "**引擎文档:** `MiniMax-M2 SGLang 食谱 `_,`MiniMax M2.5/M2.1/M2 使用指南 `_。" +msgstr "**引擎文档:** `MiniMax-M2 SGLang 实战指南 `_,`MiniMax M2.5/M2.1/M2 使用指南 `_。" #: ../../source/recipes/minimax_m2.rst:93 msgid "**Status:** Not validated with LMCache." From ae328a66b0d9bc2c5e5f09c1a686e37766a03199 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 9 Jun 2026 11:53:54 +0800 Subject: [PATCH 10/57] [CI] Improve CI stability: gemma-4 test & serde test (#3556) Signed-off-by: KuntaiDu --- .buildkite/k3_tests/multiprocess/pipeline.yml | 58 +++++------- .../multiprocess/scripts/run-hma-lm-eval.sh | 92 ++++++++----------- .../multiprocess/scripts/run-single-test.sh | 9 +- docs/source/recipes/gemma4.rst | 17 ++-- docs/source/recipes/index.rst | 7 ++ tests/v1/distributed/serde/test_serde_e2e.py | 45 ++++++--- 6 files changed, 110 insertions(+), 118 deletions(-) diff --git a/.buildkite/k3_tests/multiprocess/pipeline.yml b/.buildkite/k3_tests/multiprocess/pipeline.yml index b9fa979b86..fa47e37108 100644 --- a/.buildkite/k3_tests/multiprocess/pipeline.yml +++ b/.buildkite/k3_tests/multiprocess/pipeline.yml @@ -26,40 +26,6 @@ steps: - { name: hf-cache, hostPath: { path: /data/huggingface, type: DirectoryOrCreate } } artifact_paths: ["*.log"] - # HMA (hybrid memory allocator) correctness check on google/gemma-4-31B-it. - # It interleaves sliding + full attention whose full layers use a larger - # head_dim (512 vs 256), so vLLM gives its KV cache groups different block - # sizes -- exercising LMCache's per-group block-size handling. Runs gsm8k, - # resets vLLM's local prefix cache (LMCache preserved), reruns, and asserts - # the scores match (run1 == run2 == no-LMCache baseline). Needs 2 GPUs - # (LMCache+vLLM + baseline). It is public (no HF_TOKEN), forces TRITON_ATTN - # (so ATTENTION_BACKEND=auto and a non-zero SCORE_TOLERANCE, since - # TRITON_ATTN is not bit-exact under batch invariance), and its ~63GB of - # weights need a higher GPU_MEMORY_UTILIZATION than the 0.5 default. - - label: ":compression: hma_lm_eval_gemma4" - command: .buildkite/k3_tests/multiprocess/run.sh hma_lm_eval_gemma4 - timeout_in_minutes: 60 - env: - MODEL: "google/gemma-4-31B-it" - SCORE_TOLERANCE: "0.05" - ATTENTION_BACKEND: "auto" - GPU_MEMORY_UTILIZATION: "0.85" - # Skip CUDA-graph capture so the large model doesn't time out at launch - # (safe here: this test uses a tolerance, not the bit-exact check). - ENFORCE_EAGER: "1" - # 31B weights are large; allow longer for download + load before the - # readiness probe gives up (other models keep the 300s default). - MAX_WAIT_SECONDS: "400" - # LIMIT = number of gsm8k samples. 31B's large per-token KV makes the - # full 200-sample working set overflow the CPU pool and thrash, so run - # 2 misses LMCache; cap the samples and enlarge the pool to keep run 2 - # cache-served. CPU_BUFFER_SIZE (GB) is bounded by node RAM. - LIMIT: "100" - CPU_BUFFER_SIZE: "200" - agents: { queue: "k8s" } - plugins: [{ kubernetes: { podSpec: *pod-2gpu } }] - artifact_paths: ["*.log"] - - label: ":compression: long_doc_qa" command: .buildkite/k3_tests/multiprocess/run.sh long_doc_qa timeout_in_minutes: 30 @@ -99,6 +65,30 @@ steps: volumes: *vols artifact_paths: ["*.log"] + # HMA correctness check on google/gemma-4-31B-it (a hybrid model whose KV + # cache groups get different block sizes). Runs gsm8k, resets vLLM's prefix + # cache (LMCache preserved), reruns served by LMCache, and asserts the two + # runs' scores match. Single GPU, no baseline. + - label: ":compression: hma_lm_eval_gemma4" + command: .buildkite/k3_tests/multiprocess/run.sh hma_lm_eval_gemma4 + timeout_in_minutes: 60 + env: + MODEL: "google/gemma-4-31B-it" + # Require an exact score match between the two runs. + SCORE_TOLERANCE: "0" + ATTENTION_BACKEND: "auto" + GPU_MEMORY_UTILIZATION: "0.85" + # 31B load + CUDA-graph capture is slow; raise the readiness timeout + # above the 300s default. + MAX_WAIT_SECONDS: "600" + # Cap samples and enlarge the CPU pool so the retrieve run stays + # cache-served (31B's per-token KV is large). + LIMIT: "100" + CPU_BUFFER_SIZE: "200" + agents: { queue: "k8s" } + plugins: [{ kubernetes: { podSpec: *pod-1gpu } }] + artifact_paths: ["*.log"] + - label: ":compression: fault_tolerance" command: .buildkite/k3_tests/multiprocess/run.sh fault_tolerance timeout_in_minutes: 30 diff --git a/.buildkite/k3_tests/multiprocess/scripts/run-hma-lm-eval.sh b/.buildkite/k3_tests/multiprocess/scripts/run-hma-lm-eval.sh index f650a0a7db..5d2f33e9c5 100755 --- a/.buildkite/k3_tests/multiprocess/scripts/run-hma-lm-eval.sh +++ b/.buildkite/k3_tests/multiprocess/scripts/run-hma-lm-eval.sh @@ -11,17 +11,16 @@ # Qwen3.5/Qwen3-Next, whose state caches LMCache cannot yet transfer). # - Public, so no HF_TOKEN is required. # -# Flow: -# 1. Run lm_eval (gsm8k) against vLLM+LMCache -> populates LMCache (STORE). +# Flow (single GPU, no baseline server): +# 1. vLLM run: lm_eval (gsm8k) against vLLM+LMCache, populating LMCache. # 2. Reset vLLM's *local* prefix cache (APC) only, leaving LMCache intact, via # the dev-mode endpoint POST /reset_prefix_cache (reset_external defaults to # false, so the LMCache-managed cache is preserved). -# 3. Re-run lm_eval -> vLLM APC misses, so the -# prefix KV is served by LMCache (RETRIEVE), exercising the HMA retrieve path. -# 4. Assert the three gsm8k scores agree within SCORE_TOLERANCE (run 1 store == -# run 2 retrieve == no-LMCache baseline); a broken retrieve corrupts the KV -# and the score diverges. -# 5. Assert LMCache actually served retrieves during run 2 (non-vacuous). +# 3. LMCache retrieve run: re-run lm_eval; vLLM's APC misses, so the prefix KV +# is served by LMCache. +# 4. Assert the two runs' gsm8k scores match -- a broken LMCache would skew the +# retrieved KV and make them diverge. +# 5. Assert LMCache actually served retrieves in the retrieve run (non-vacuous). # # The reset endpoint requires VLLM_SERVER_DEV_MODE=1 (set by launch-processes.sh). set -e @@ -34,40 +33,35 @@ source "${REPO_ROOT}/.buildkite/k3_tests/common_scripts/helpers.sh" # Configuration VLLM_PORT="${VLLM_PORT:-8000}" -VLLM_BASELINE_PORT="${VLLM_BASELINE_PORT:-9000}" MODEL="${MODEL:-google/gemma-4-31B-it}" NUM_CONCURRENT="${NUM_CONCURRENT:-50}" # 31B has a large per-token KV footprint; cap the sample count so the working -# set fits the CPU pool (a too-large set thrashes and run 2 misses LMCache). +# set fits the CPU pool (a too-large set thrashes and the retrieve run misses). LIMIT="${LIMIT:-100}" -# Max allowed absolute difference in the gsm8k exact_match score across runs. -# gemma-4 forces the Triton backend, which is not bit-exact under vLLM's -# batch-invariant mode, so a correct retrieve can differ from a fresh compute by -# a small margin; the default allows a small tolerance instead of an exact match. -SCORE_TOLERANCE="${SCORE_TOLERANCE:-0.05}" -# Seconds to wait after run 1 so async LMCache stores drain before run 2. +# Max abs difference allowed between the two runs' gsm8k scores; 0 requires an +# exact match. +SCORE_TOLERANCE="${SCORE_TOLERANCE:-0}" +# Seconds to let async LMCache stores drain before the retrieve run. STORE_DRAIN_SECONDS="${STORE_DRAIN_SECONDS:-20}" BUILD_ID="${BUILD_ID:-local_$$}" RESULTS_DIR="${RESULTS_DIR:-/tmp/lmcache_ci_results_${BUILD_ID}}" -# LMCache MP server log, scanned to confirm run 2 was served by LMCache retrieves. +# LMCache MP server log, scanned to confirm the retrieve run hit LMCache. LMCACHE_LOG="${LMCACHE_LOG:-/tmp/build_${BUILD_ID}_lmcache.log}" HMA_DIR="$RESULTS_DIR/hma_lm_eval" -RUN1_DIR="$HMA_DIR/run1_store" -RUN2_DIR="$HMA_DIR/run2_retrieve" -BASELINE_DIR="$HMA_DIR/baseline" +VLLM_RUN_DIR="$HMA_DIR/vllm_run" +RETRIEVE_RUN_DIR="$HMA_DIR/retrieve_run" echo "=== HMA lm_eval correctness test ===" echo "Model: $MODEL" echo "vLLM (LMCache) port: $VLLM_PORT" -echo "vLLM baseline port: $VLLM_BASELINE_PORT" echo "Concurrent requests: $NUM_CONCURRENT" echo "Limit: $LIMIT" echo "Score tolerance: $SCORE_TOLERANCE" echo "Results dir: $HMA_DIR" echo "" -mkdir -p "$RUN1_DIR" "$RUN2_DIR" "$BASELINE_DIR" +mkdir -p "$VLLM_RUN_DIR" "$RETRIEVE_RUN_DIR" # Run one lm_eval gsm8k pass against a vLLM OpenAI-compatible server. # @@ -147,8 +141,8 @@ count_retrieves() { grep -c "Retrieved" "$LMCACHE_LOG" 2>/dev/null || true } -# ── 1. Cold run: compute + STORE into LMCache ─────────────── -run_lm_eval "$VLLM_PORT" "$RUN1_DIR" "run1 LMCache STORE" +# ── 1. vLLM run: compute from scratch, populating LMCache ─── +run_lm_eval "$VLLM_PORT" "$VLLM_RUN_DIR" "vLLM run" # Let async stores drain to the LMCache server before invalidating the APC. echo "Waiting ${STORE_DRAIN_SECONDS}s for LMCache stores to drain..." @@ -159,28 +153,25 @@ retrieves_before=$(count_retrieves) # ── 2. Invalidate vLLM's local prefix cache (keep LMCache) ── reset_vllm_prefix_cache "$VLLM_PORT" -# ── 3. Warm run: vLLM APC misses -> LMCache RETRIEVE ──────── -run_lm_eval "$VLLM_PORT" "$RUN2_DIR" "run2 LMCache RETRIEVE" +# ── 3. Retrieve run: vLLM APC misses -> LMCache serves the KV ─ +run_lm_eval "$VLLM_PORT" "$RETRIEVE_RUN_DIR" "LMCache retrieve run" retrieves_after=$(count_retrieves) -# ── 4. Baseline run: no LMCache, ground truth ────────────── -run_lm_eval "$VLLM_BASELINE_PORT" "$BASELINE_DIR" "baseline no LMCache" - -# ── 5. Compare scores and verify LMCache was actually used ── +# ── 4. Compare scores and verify LMCache was actually used ── echo "============================================" echo "=== Verifying HMA store/retrieve correctness ===" echo "============================================" -echo "LMCache retrieves logged: before run2=${retrieves_before}, after run2=${retrieves_after}" +echo "LMCache retrieves logged: before=${retrieves_before}, after=${retrieves_after}" -python3 - "$RUN1_DIR" "$RUN2_DIR" "$BASELINE_DIR" \ +python3 - "$VLLM_RUN_DIR" "$RETRIEVE_RUN_DIR" \ "$SCORE_TOLERANCE" "$retrieves_before" "$retrieves_after" <<'PYEOF' import glob import json import os import sys -run1_dir, run2_dir, baseline_dir, tol_s, before_s, after_s = sys.argv[1:7] +vllm_run_dir, retrieve_run_dir, tol_s, before_s, after_s = sys.argv[1:6] tol = float(tol_s) retrieves_before = int(before_s) retrieves_after = int(after_s) @@ -221,34 +212,25 @@ def gsm8k_exact_match(results_dir: str) -> float: raise SystemExit(f"No exact_match metric in {latest}: {sorted(metrics)}") -s1 = gsm8k_exact_match(run1_dir) -s2 = gsm8k_exact_match(run2_dir) -sb = gsm8k_exact_match(baseline_dir) +s_vllm = gsm8k_exact_match(vllm_run_dir) +s_retrieve = gsm8k_exact_match(retrieve_run_dir) -print(f" run1 (LMCache STORE) gsm8k exact_match = {s1:.4f}") -print(f" run2 (LMCache RETRIEVE) gsm8k exact_match = {s2:.4f}") -print(f" baseline (no LMCache) gsm8k exact_match = {sb:.4f}") +print(f" vLLM run gsm8k exact_match = {s_vllm:.4f}") +print(f" LMCache retrieve run gsm8k exact_match = {s_retrieve:.4f}") print(f" tolerance = {tol}") failures = [] -# run1 (store) vs run2 (retrieve): same server, the core store/retrieve check. -if abs(s1 - s2) > tol: - failures.append( - f"LMCache store-vs-retrieve score drift: |{s1:.4f} - {s2:.4f}| = " - f"{abs(s1 - s2):.4f} > {tol}" - ) -# run2 (retrieve) vs baseline (no LMCache): retrieve must match ground truth. -if abs(s2 - sb) > tol: +# The two runs must match -- a broken LMCache would skew the retrieved KV. +if abs(s_vllm - s_retrieve) > tol: failures.append( - f"Retrieve-vs-baseline score drift: |{s2:.4f} - {sb:.4f}| = " - f"{abs(s2 - sb):.4f} > {tol}" + f"score drift between runs: |{s_vllm:.4f} - {s_retrieve:.4f}| = " + f"{abs(s_vllm - s_retrieve):.4f} > {tol}" ) -# Non-vacuous: run 2 must have been served by LMCache retrieves, not recompute. +# Non-vacuous: the retrieve run must have been served by LMCache, not recompute. if retrieves_after <= retrieves_before: failures.append( - "LMCache served no retrieves during run 2 " - f"(before={retrieves_before}, after={retrieves_after}); " - "the retrieve path was not exercised" + "LMCache served no retrieves during the retrieve run " + f"(before={retrieves_before}, after={retrieves_after})" ) if failures: @@ -258,8 +240,8 @@ if failures: sys.exit(1) print( - f"\nPASS: store, retrieve, and baseline gsm8k scores match (tol={tol}); " - f"LMCache served {retrieves_after - retrieves_before} retrieves during run 2." + f"\nPASS: vLLM and LMCache-retrieve gsm8k scores match (tol={tol}); " + f"LMCache served {retrieves_after - retrieves_before} retrieves." ) PYEOF diff --git a/.buildkite/k3_tests/multiprocess/scripts/run-single-test.sh b/.buildkite/k3_tests/multiprocess/scripts/run-single-test.sh index 71dd68f762..4df0e2ad95 100755 --- a/.buildkite/k3_tests/multiprocess/scripts/run-single-test.sh +++ b/.buildkite/k3_tests/multiprocess/scripts/run-single-test.sh @@ -28,10 +28,9 @@ if [ "$TEST_NAME" = "hma_lm_eval_gemma4" ]; then # gemma-4-31B-it is public (no gating, so no HF token check) and has # heterogeneous head dims (head_dim 256 / global_head_dim 512), so vLLM # gives its KV cache groups different block sizes -- this is what exercises - # LMCache's per-group block-size handling. It forces TRITON_ATTN, which is - # not bit-exact under batch invariance, so the pipeline sets a small - # SCORE_TOLERANCE and ATTENTION_BACKEND=auto; its ~63GB of weights also need - # a higher GPU_MEMORY_UTILIZATION than the default (all set in pipeline.yml). + # LMCache's per-group block-size handling. It forces TRITON_ATTN, so the + # pipeline sets ATTENTION_BACKEND=auto; its ~63GB of weights also need a + # higher GPU_MEMORY_UTILIZATION than the default (all set in pipeline.yml). export MODEL="${MODEL:-google/gemma-4-31B-it}" else export MODEL="${MODEL:-Qwen/Qwen3-14B}" @@ -63,7 +62,7 @@ SELF_CONTAINED_TESTS=" deadlock " # Tests that compare against a baseline vLLM (no LMCache) on a second GPU. # Only these need the baseline server (and thus a 2-GPU pod); everything # else runs on GPU 0 alone, so launch-processes.sh skips the baseline. -BASELINE_TESTS=" vllm_bench long_doc_qa long_doc_qa_l2 hma_lm_eval_gemma4 " +BASELINE_TESTS=" vllm_bench long_doc_qa long_doc_qa_l2 " if [[ "$BASELINE_TESTS" == *" $TEST_NAME "* ]]; then export LAUNCH_BASELINE=true else diff --git a/docs/source/recipes/gemma4.rst b/docs/source/recipes/gemma4.rst index e4e1acc607..a2ed89894e 100644 --- a/docs/source/recipes/gemma4.rst +++ b/docs/source/recipes/gemma4.rst @@ -1,12 +1,13 @@ .. _recipe_gemma4: -Gemma4ForConditionalGeneration -=============================== +Gemma 4 +======= Validated models ---------------- - `google/gemma-4-31B-it `_ +- `google/gemma-4-12B-it `_ - `google/gemma-4-E4B-it `_ .. tab-set:: @@ -17,7 +18,8 @@ Validated models **Engine documentation:** `Gemma 4 in vLLM supported models `_ - (architecture ``Gemma4ForConditionalGeneration``). + (architectures ``Gemma4ForConditionalGeneration`` for 31B/E4B and + ``Gemma4UnifiedForConditionalGeneration`` for 12B). **Status:** Validated with LMCache. @@ -40,11 +42,12 @@ Validated models | - The smaller ``google/gemma-4-E4B-it`` runs on a single GPU: + The smaller ``google/gemma-4-12B-it`` and ``google/gemma-4-E4B-it`` run on + a single GPU: .. code-block:: bash - vllm serve google/gemma-4-E4B-it \ + vllm serve google/gemma-4-12B-it \ --kv-transfer-config \ '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' @@ -95,7 +98,3 @@ Caveats - **Cross-layer KV sharing.** ``google/gemma-4-E4B-it`` reuses some layers' KV caches across layers. LMCache stores the cache-owning layers only; the sharing layers' KV lives in the same blocks and is restored automatically. -- **Determinism.** Gemma 4 runs on the Triton attention backend, which is not - bit-exact under vLLM's batch-invariant mode, so a retrieved result may differ - from a freshly computed one by a small numerical margin rather than being - byte-identical. diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index d02b991808..8dd0fdd370 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -58,6 +58,13 @@ Supported architectures - — - :doc:`gemma4` + * - ``Gemma4UnifiedForConditionalGeneration`` + - ``google/gemma-4-12B-it`` + - ✓ + - — + - — + - :doc:`gemma4` + * - ``Gemma3ForConditionalGeneration`` - ``google/gemma-3-4b-it`` - ✓ diff --git a/tests/v1/distributed/serde/test_serde_e2e.py b/tests/v1/distributed/serde/test_serde_e2e.py index 1b192df05e..2aa42f7bbb 100644 --- a/tests/v1/distributed/serde/test_serde_e2e.py +++ b/tests/v1/distributed/serde/test_serde_e2e.py @@ -175,6 +175,29 @@ def get_l1_object_count(sm: StorageManager) -> int: return sm.report_status()["l1_manager"]["total_object_count"] +def clear_and_wait_drained(sm: StorageManager, timeout: float = 10.0) -> None: + """Clear L1 and poll until every object is evicted. + + After an L2 store the StoreController holds read locks on the stored objects + for a short window, and ``StorageManager.clear`` keeps locked objects intact. + A single clear right after the store therefore races the lock release and can + leave objects behind. Retry clear() until the locks drop and L1 drains rather + than relying on a fixed sleep. + + Raises: + AssertionError: If L1 still holds objects after ``timeout`` seconds. + """ + + def drained() -> bool: + sm.clear() + return get_l1_object_count(sm) == 0 + + if not wait_for_condition(drained, timeout=timeout): + raise AssertionError( + f"L1 did not drain after clear: {get_l1_object_count(sm)} objects remain" + ) + + # ============================================================================= # Tests: Full round-trip through serde # ============================================================================= @@ -196,9 +219,7 @@ def test_store_and_prefetch_with_serde(self) -> None: write_and_wait_for_l2(sm, keys, layout) - # Brief sleep so StoreController releases read locks after L2 store - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) assert get_l1_object_count(sm) == 0 # Prefetch from L2 @@ -222,8 +243,7 @@ def test_no_memory_leak_after_full_cycle(self) -> None: keys = [make_object_key(i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) # Prefetch handle = sm.submit_prefetch_task(keys, layout) @@ -263,8 +283,7 @@ def test_store_and_prefetch_without_serde(self) -> None: keys = [make_object_key(i) for i in range(5)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) handle = sm.submit_prefetch_task(keys, layout) hits = wait_for_prefetch_status(sm, handle) @@ -285,8 +304,7 @@ def test_no_memory_leak_without_serde(self) -> None: keys = [make_object_key(i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) handle = sm.submit_prefetch_task(keys, layout) hits = wait_for_prefetch_status(sm, handle) @@ -318,8 +336,7 @@ def test_partial_prefix_with_serde(self) -> None: # Write only keys 0, 1, 3, 4 (skip 2) keys_to_write = [make_object_key(i) for i in [0, 1, 3, 4]] write_and_wait_for_l2(sm, keys_to_write, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) # Request all 5 keys — prefix should be 2 (gap at index 2) all_keys = [make_object_key(i) for i in range(5)] @@ -354,8 +371,7 @@ def test_repeated_cycles_no_leak(self) -> None: for cycle in range(5): keys = [make_object_key(cycle * 10 + i) for i in range(3)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) handle = sm.submit_prefetch_task(keys, layout) hits = wait_for_prefetch_status(sm, handle) @@ -441,8 +457,7 @@ def _run_roundtrip( keys = [make_object_key(i) for i in range(num_keys)] write_and_wait_for_l2(sm, keys, layout) - time.sleep(1) - sm.clear() + clear_and_wait_drained(sm) assert get_l1_object_count(sm) == 0 handle = sm.submit_prefetch_task(keys, layout) From 996f03bbb68c16d7993eb02d45d1d5d1c8e06249 Mon Sep 17 00:00:00 2001 From: ruicheng <95903923+KimmoZAG@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:15:52 +0800 Subject: [PATCH 11/57] examples(kv_cache_calculator): add Hunyuan & DeepSeek models, fix head_dim/CLA calculation, add i18n UI (#2834) * examples(kv_cache_calculator): add Hunyuan & DeepSeek models, UI i18n, prefer local modelconfig Signed-off-by: KimmoZAG <995496585@qq.com> * fix(kv_cache_calculator): use prefix match for DeepSeek V3 variants; consolidate head_dim logic Signed-off-by: KimmoZAG <995496585@qq.com> --------- Signed-off-by: KimmoZAG <995496585@qq.com> --- .../kv_cache_calculator/generate_config.py | 20 +- .../kv_cache_calculator.html | 538 +++++++++++++----- examples/kv_cache_calculator/modelconfig.json | 46 +- .../benchmarks/test_xpu_kernels_microbench.py | 2 +- 4 files changed, 451 insertions(+), 155 deletions(-) diff --git a/examples/kv_cache_calculator/generate_config.py b/examples/kv_cache_calculator/generate_config.py index c7724170aa..dac665298a 100644 --- a/examples/kv_cache_calculator/generate_config.py +++ b/examples/kv_cache_calculator/generate_config.py @@ -34,17 +34,33 @@ def main(): "num_key_value_heads": getattr(config, "num_key_value_heads", None), } - if args.model == "deepseek-ai/DeepSeek-V3": + # DeepSeek MLA models (V3, V3.1, V3.2, … and R1) store + # KV in latent space + if ( + args.model.lower().startswith("deepseek-ai/deepseek-v3") + or args.model == "deepseek-ai/DeepSeek-R1" + ): config_data["kv_lora_rank"] = getattr(config, "kv_lora_rank", None) config_data["qk_rope_head_dim"] = getattr(config, "qk_rope_head_dim", None) - # Check for Qwen3 models (fuzzy matching) or GLM4 models + # Models whose head_dim is explicit in config and may + # differ from hidden_size / num_heads: + # Qwen3, GLM4, and Hunyuan dense variants. if ( "qwen/qwen3-" in args.model.lower() or "zai-org/glm-4." in args.model.lower() + or ( + args.model.lower().startswith("tencent/hunyuan-") + and args.model.lower() != "tencent/hunyuan-large" + ) ): config_data["head_dim"] = getattr(config, "head_dim", None) + # Hunyuan-Large uses CLA (Cross-Layer Attention): + # KV layers = num_hidden_layers / cla_share_factor + if args.model.lower() == "tencent/hunyuan-large": + config_data["cla_share_factor"] = getattr(config, "cla_share_factor", None) + # Convert to JSON and print string = json.dumps(config_data, indent=4) diff --git a/examples/kv_cache_calculator/kv_cache_calculator.html b/examples/kv_cache_calculator/kv_cache_calculator.html index 0dfe1e6116..92ed0c7053 100644 --- a/examples/kv_cache_calculator/kv_cache_calculator.html +++ b/examples/kv_cache_calculator/kv_cache_calculator.html @@ -5,141 +5,353 @@ KV Cache Size Calculator -
-

KV Cache Size Calculator

- - - - - - - +
+ + +
+
+

KV Cache Size Calculator

+ +
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ + +
+ + +
-
- + +
-
- Developed by Zhuohan Gu @ LMCache team -
+
Developed by Zhuohan Gu @ LMCache team
+