diff --git a/csrc/sycl/mp_mem_kernels_sycl.cpp b/csrc/sycl/mp_mem_kernels_sycl.cpp new file mode 100644 index 0000000000..eaaa2f77dc --- /dev/null +++ b/csrc/sycl/mp_mem_kernels_sycl.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: Apache-2.0 + +// +// SYCL implementation of multi_layer_block_kv_transfer for Intel XPU +// (PVC / Arc / Battlemage). +// +// Design notes: +// +// 1. Independence from CUDA: this file must compile with icpx alone. +// It includes only SYCL and PyTorch/ATen headers. +// +// 2. Vectorization: uses int64_t (8 bytes) as the maximum copy +// granularity, with fallback to int32_t and int16_t based on +// head_bytes alignment. No uint4 or other CUDA-specific types. +// +// 3. Kernel topology: nd_range<3> with group dimensions +// (kv_size, total_blocks, nl) and a 1-D local range (wg_size). +// Work-group size is rounded up to sub-group size 16 and capped at 256. +// Each work-item strides across all bs * scalars_per_token elements in +// the block using a stride loop. +// +// 4. Supported GPUKVFormat values (NHD + MLA only): +// NB_NL_TWO_BS_NH_HS — cross-layer single tensor +// NL_X_TWO_NB_BS_NH_HS — vLLM flash attention +// NL_X_NB_TWO_BS_NH_HS — vLLM flash infer +// NL_X_NB_BS_HS — vLLM MLA +// NL_X_NBBS_ONE_HS — SGLang MLA +// All other formats throw std::runtime_error. + +// sycl/accessor.hpp references the deprecated 'host_buffer' internally +// even when user code only uses USM pointers; suppress the noise. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include +#pragma GCC diagnostic pop + +#include +#include +#include +#include + +#include "mp_mem_kernels_sycl.h" + +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Tuning constants (mirrored from mem_kernels_sycl.cpp) +// --------------------------------------------------------------------------- +namespace { + +constexpr int INTEL_SUB_GROUP_SIZE = 16; +constexpr int MAX_WG_SIZE = 256; + +inline int round_up_to_sg(int n) { + return ((n + INTEL_SUB_GROUP_SIZE - 1) / INTEL_SUB_GROUP_SIZE) * + INTEL_SUB_GROUP_SIZE; +} + +// --------------------------------------------------------------------------- +// Kernel submit helper +// --------------------------------------------------------------------------- +// +// Template parameters: +// scalar_t — working scalar type (int64_t / int32_t / int16_t) +// LMCACHE_TO_ENGINE — true = H2D (LMCache → engine), false = D2H +// format — compile-time GPUKVFormat (eliminates branches) +// +// All five supported NHD / MLA formats have the same intra-block memory +// layout: [BS, NH, HS] (or [BS, HS] for MLA with NH==1). This means the +// flat index ``i`` within a block is identical for both the engine and the +// LMCache side, so a single stride loop suffices. +// +// paged_buffer_ptrs — scalar_t** pointing to XPU memory that holds the +// per-layer (or single cross-layer) tensor data pointers. Accessed via +// USM device dereference inside the kernel. +// +// obj0..obj3 — individual scalar_t* values captured by value so the SYCL +// lambda receives device-accessible copies without dangling pointer risk. +// +template +void submit_block_kv_transfer_kernel( + sycl::queue& queue, + scalar_t** paged_buffer_ptrs, + const int64_t* block_ids_ptr, + scalar_t* obj0, + scalar_t* obj1, + scalar_t* obj2, + scalar_t* obj3, + int total_blocks, + int num_blocks_per_object, + int skip_prefix_n_blocks, + PageBufferShapeDesc shape_desc, + int lmcache_chunk_size, + int wg_size) { + const int kv_size = shape_desc.kv_size; + const int nl = shape_desc.nl; + const int bs = shape_desc.bs; + const int nb = shape_desc.nb; + + // Pre-compute scalar counts outside the kernel to avoid repeated + // template method calls from within device code. + const size_t spt = shape_desc.scalars_per_token(); + const size_t spb = shape_desc.scalars_per_block(); + const size_t total_elements = static_cast(bs) * spt; + + // Grid: (kv_size, total_blocks, nl) groups; each group has wg_size threads. + sycl::range<3> global_range( + static_cast(kv_size), + static_cast(total_blocks), + static_cast(nl) * static_cast(wg_size)); + sycl::range<3> local_range(1, 1, static_cast(wg_size)); + + queue.parallel_for( + sycl::nd_range<3>(global_range, local_range), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(16)]] { + const int k_or_v = static_cast(item.get_group(0)); + const int flat_block_idx = static_cast(item.get_group(1)); + const int layer_idx = static_cast(item.get_group(2)); + const int tid = static_cast(item.get_local_id(2)); + const int num_threads = static_cast(item.get_local_range(2)); + + // Skip leading prefix blocks. + if (flat_block_idx < skip_prefix_n_blocks) return; + + const int obj_idx = flat_block_idx / num_blocks_per_object; + const int block_idx_in_object = flat_block_idx % num_blocks_per_object; + const int engine_block_idx = + static_cast(block_ids_ptr[flat_block_idx]); + + // ---------------------------------------------------------------- + // Engine global offset: start of this block within the engine + // tensor, in scalar_t units. Formula depends on format. + // ---------------------------------------------------------------- + size_t engine_global; + if constexpr (format == GPUKVFormat::NB_NL_TWO_BS_NH_HS) { + // Single tensor [NB, NL, 2, BS, NH, HS] + engine_global = + static_cast(k_or_v) * spb + + static_cast(layer_idx) * static_cast(kv_size) * + spb + + static_cast(engine_block_idx) * + static_cast(kv_size) * spb * + static_cast(nl); + } else if constexpr (format == GPUKVFormat::NL_X_TWO_NB_BS_NH_HS) { + // Per-layer tensor [2, NB, BS, NH, HS] + engine_global = static_cast(engine_block_idx) * spb + + static_cast(k_or_v) * + static_cast(nb) * spb; + } else if constexpr (format == GPUKVFormat::NL_X_NB_TWO_BS_NH_HS) { + // Per-layer tensor [NB, 2, BS, NH, HS] + engine_global = + static_cast(engine_block_idx) * + static_cast(kv_size) * spb + + static_cast(k_or_v) * spb; + } else if constexpr (format == GPUKVFormat::NL_X_NB_BS_HS || + format == GPUKVFormat::NL_X_NBBS_ONE_HS) { + // Per-layer MLA tensor ([NB, BS, HS] or [NBBS, 1, HS]) + engine_global = static_cast(engine_block_idx) * spb; + } + + // ---------------------------------------------------------------- + // LMCache global offset: 2LTD layout [2, L, chunk_size, NH*HS]. + // token_offset_in_object = block_idx_in_object * bs (tokens). + // ---------------------------------------------------------------- + const size_t token_offset_in_object = + static_cast(block_idx_in_object) * + static_cast(bs); + const size_t lmcache_global = + static_cast(k_or_v) * + static_cast(nl) * + static_cast(lmcache_chunk_size) * spt + + static_cast(layer_idx) * + static_cast(lmcache_chunk_size) * spt + + token_offset_in_object * spt; + + // ---------------------------------------------------------------- + // Select engine layer pointer. + // NB_NL_TWO_BS_NH_HS stores all layers in a single tensor (idx 0). + // All other formats store one tensor per layer. + // ---------------------------------------------------------------- + scalar_t* engine_layer_ptr; + if constexpr (format == GPUKVFormat::NB_NL_TWO_BS_NH_HS) { + engine_layer_ptr = paged_buffer_ptrs[0]; + } else { + engine_layer_ptr = paged_buffer_ptrs[layer_idx]; + } + + // ---------------------------------------------------------------- + // Select LMCache object pointer (up to 4 objects). + // Individual variables are captured by value to avoid dangling + // pointer risk from capturing a local array by pointer. + // ---------------------------------------------------------------- + scalar_t* lmcache_ptr; + if (obj_idx == 0) { + lmcache_ptr = obj0; + } else if (obj_idx == 1) { + lmcache_ptr = obj1; + } else if (obj_idx == 2) { + lmcache_ptr = obj2; + } else { + lmcache_ptr = obj3; + } + + // ---------------------------------------------------------------- + // Stride loop: copy total_elements = bs * spt scalars. + // + // Because all supported formats use NHD intra-block layout + // ([BS, NH, HS]), the flat index i maps directly to the same + // offset in both the engine and the LMCache buffer. + // ---------------------------------------------------------------- + for (size_t i = static_cast(tid); i < total_elements; + i += static_cast(num_threads)) { + if constexpr (LMCACHE_TO_ENGINE) { + engine_layer_ptr[engine_global + i] = + lmcache_ptr[lmcache_global + i]; + } else { + lmcache_ptr[lmcache_global + i] = + engine_layer_ptr[engine_global + i]; + } + } + }); +} + +// --------------------------------------------------------------------------- +// Dispatch macros +// --------------------------------------------------------------------------- + +#define LAUNCH_BLOCK_KERNEL(DIRECTION, FORMAT) \ + submit_block_kv_transfer_kernel( \ + queue, paged_buffer_ptrs, block_ids_ptr, \ + obj0, obj1, obj2, obj3, \ + total_blocks, num_blocks_per_object, skip_prefix_n_blocks, \ + shape_desc, lmcache_chunk_size, wg_size); + +#define DISPATCH_FORMAT(DIRECTION) \ + switch (gpu_kv_format) { \ + case GPUKVFormat::NB_NL_TWO_BS_NH_HS: \ + LAUNCH_BLOCK_KERNEL(DIRECTION, GPUKVFormat::NB_NL_TWO_BS_NH_HS); \ + break; \ + case GPUKVFormat::NL_X_TWO_NB_BS_NH_HS: \ + LAUNCH_BLOCK_KERNEL(DIRECTION, GPUKVFormat::NL_X_TWO_NB_BS_NH_HS); \ + break; \ + case GPUKVFormat::NL_X_NB_TWO_BS_NH_HS: \ + LAUNCH_BLOCK_KERNEL(DIRECTION, GPUKVFormat::NL_X_NB_TWO_BS_NH_HS); \ + break; \ + case GPUKVFormat::NL_X_NB_BS_HS: \ + LAUNCH_BLOCK_KERNEL(DIRECTION, GPUKVFormat::NL_X_NB_BS_HS); \ + break; \ + case GPUKVFormat::NL_X_NBBS_ONE_HS: \ + LAUNCH_BLOCK_KERNEL(DIRECTION, GPUKVFormat::NL_X_NBBS_ONE_HS); \ + break; \ + default: \ + throw std::runtime_error( \ + "multi_layer_block_kv_transfer (SYCL): unsupported GPUKVFormat " + \ + std::to_string(static_cast(gpu_kv_format))); \ + } + +// --------------------------------------------------------------------------- +// Templated implementation +// --------------------------------------------------------------------------- +template +void multi_layer_block_kv_transfer_templated( + const torch::Tensor& paged_buffer_ptrs_tensor, + std::vector lmcache_objects_ptrs, + const torch::Tensor& block_ids, + const torch::Device& device, + TransferDirection direction, + PageBufferShapeDesc shape_desc, + int lmcache_chunk_size, + GPUKVFormat gpu_kv_format, + int skip_prefix_n_blocks) { + // --- Validation --- + const int num_objects = static_cast(lmcache_objects_ptrs.size()); + TORCH_CHECK(num_objects >= 1 && num_objects <= 4, + "Expected 1–4 LMCache objects, got ", num_objects); + + const int total_blocks = static_cast(block_ids.size(0)); + TORCH_CHECK(total_blocks % num_objects == 0, + "block_ids length (", total_blocks, + ") must be divisible by num_objects (", num_objects, ")"); + const int num_blocks_per_object = total_blocks / num_objects; + + TORCH_CHECK( + num_blocks_per_object * shape_desc.bs == lmcache_chunk_size, + "blocks_per_object * block_size (", + num_blocks_per_object * shape_desc.bs, + ") must equal lmcache_chunk_size (", lmcache_chunk_size, ")"); + + // --- Build typed object pointers (captured by value in kernel) --- + scalar_t* obj0 = (num_objects > 0) + ? reinterpret_cast(lmcache_objects_ptrs[0]) + : nullptr; + scalar_t* obj1 = (num_objects > 1) + ? reinterpret_cast(lmcache_objects_ptrs[1]) + : nullptr; + scalar_t* obj2 = (num_objects > 2) + ? reinterpret_cast(lmcache_objects_ptrs[2]) + : nullptr; + scalar_t* obj3 = (num_objects > 3) + ? reinterpret_cast(lmcache_objects_ptrs[3]) + : nullptr; + + // --- Paged-buffer pointer array (XPU USM device memory) --- + scalar_t** paged_buffer_ptrs = + reinterpret_cast(paged_buffer_ptrs_tensor.data_ptr()); + + // --- Block IDs (XPU int64 tensor) --- + const int64_t* block_ids_ptr = block_ids.data_ptr(); + + // --- Work-group size: round up scalars_per_token to sub-group boundary --- + const size_t spt = shape_desc.scalars_per_token(); + const int wg_size = round_up_to_sg( + std::min(static_cast(spt), MAX_WG_SIZE)); + + // --- Acquire the XPU stream for the target device --- + const c10::OptionalDeviceGuard device_guard(device); + sycl::queue& queue = + c10::xpu::getCurrentXPUStream(device.index()).queue(); + + if (direction == TransferDirection::H2D) { + DISPATCH_FORMAT(true); + } else { + DISPATCH_FORMAT(false); + } +} + +#undef DISPATCH_FORMAT +#undef LAUNCH_BLOCK_KERNEL + +} // namespace + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- +void multi_layer_block_kv_transfer( + const torch::Tensor& paged_buffer_ptrs_tensor, + std::vector lmcache_objects_ptrs, + const torch::Tensor& block_ids, + const torch::Device& device, + TransferDirection direction, + PageBufferShapeDesc shape_desc, + int lmcache_chunk_size, + GPUKVFormat gpu_kv_format, + int skip_prefix_n_blocks) { + const int head_bytes = shape_desc.hs * shape_desc.element_size; + TORCH_CHECK(head_bytes % sizeof(int16_t) == 0, + "head_size * element_size (", head_bytes, + ") must be divisible by 2 for vectorized access"); + +#define LAUNCH_TEMPLATED(type) \ + do { \ + multi_layer_block_kv_transfer_templated( \ + paged_buffer_ptrs_tensor, lmcache_objects_ptrs, block_ids, \ + device, direction, shape_desc, lmcache_chunk_size, \ + gpu_kv_format, skip_prefix_n_blocks); \ + } while (0) + + if (head_bytes % sizeof(int64_t) == 0) { + LAUNCH_TEMPLATED(int64_t); // 8 bytes per transfer + } else if (head_bytes % sizeof(int32_t) == 0) { + LAUNCH_TEMPLATED(int32_t); // 4 bytes per transfer + } else { + LAUNCH_TEMPLATED(int16_t); // 2 bytes per transfer (minimum) + } + +#undef LAUNCH_TEMPLATED +} diff --git a/csrc/sycl/mp_mem_kernels_sycl.h b/csrc/sycl/mp_mem_kernels_sycl.h new file mode 100644 index 0000000000..2ee9204654 --- /dev/null +++ b/csrc/sycl/mp_mem_kernels_sycl.h @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +// This header is intentionally independent of CUDA. It must compile without +// any CUDA headers present. The SYCL build uses only the headers below. + +#include +#include +#include +#include + +#include "mem_kernels_sycl.h" // TransferDirection, GPUKVFormat + +/** + * Shape descriptor for a vLLM paged KV buffer. + * + * Mirrors the CUDA ``PageBufferShapeDesc`` in ``csrc/mp_mem_kernels.cuh`` + * but uses plain ``inline`` instead of ``__host__ __device__`` so it + * compiles without any CUDA headers. + */ +struct PageBufferShapeDesc { + int kv_size; // 1 or 2 + int nl; // num layers + int nb; // num blocks + int bs; // block size (tokens per block) + int nh; // num heads + int hs; // head size + int element_size; // bytes per scalar (1 or 2) + // Physical per-block stride in source-dtype element units, used by + // formats whose dim-0 is the block axis to step over padding bytes. + // 0 means "unset — fall back to the format-specific tight stride". + // + // CONTRACT: pass ``tensor.stride(0)`` verbatim; do NOT pre-multiply + // by any inner dimension. + // + // Honoured today only by NL_X_NB_BS_HS (MLA). All other formats + // ignore this field. + int block_stride_elems; + + /** + * Number of ScalarType elements per attention head. + * + * @tparam ScalarType The working scalar type (e.g. int64_t, int32_t, + * int16_t). + * @return hs * element_size / sizeof(ScalarType) + */ + template + inline size_t scalars_per_head() const { + return static_cast(hs) * element_size / sizeof(ScalarType); + } + + /** + * Number of ScalarType elements per token (all heads). + * + * @tparam ScalarType The working scalar type. + * @return nh * hs * element_size / sizeof(ScalarType) + */ + template + inline size_t scalars_per_token() const { + return static_cast(nh) * hs * element_size / sizeof(ScalarType); + } + + /** + * Physical per-block stride in ScalarType element units. + * + * Returns the tight ``bs * nh * hs`` stride by default, or the + * physical ``block_stride_elems`` stride when dim-0 carries padding + * (today only NL_X_NB_BS_HS / MLA). + * + * @tparam ScalarType The working scalar type. + * @return padded-or-tight stride in ScalarType units + */ + template + inline size_t scalars_per_block() const { + const size_t elems = block_stride_elems > 0 + ? static_cast(block_stride_elems) + : static_cast(bs) * nh * hs; + return elems * element_size / sizeof(ScalarType); + } +}; + +/** + * Holds up to 4 typed pointers to LMCache memory objects. + * + * @tparam ScalarType The working scalar type. + */ +template +struct MemoryObj4 { + ScalarType* objects[4]; + int num_objects; // 0–4 +}; + +/** + * Block-level multi-layer KV transfer between vLLM paged buffers and + * LMCache contiguous memory objects (SYCL / XPU implementation). + * + * @param paged_buffer_ptrs_tensor XPU int64 tensor of data pointers into + * vLLM paged buffers (one per tensor). + * @param lmcache_objects_ptrs Raw pointer values for up to 4 LMCache + * memory objects (XPU USM device pointers). + * @param block_ids XPU int64 tensor of vLLM block indices, + * one entry per block across all objects. + * @param device XPU device of the vLLM tensors. + * @param direction H2D (LMCache → vLLM) or D2H (vLLM → + * LMCache). + * @param shape_desc Shape descriptor for the paged buffer. + * @param lmcache_chunk_size Tokens per LMCache memory object. + * @param gpu_kv_format GPUKVFormat identifier. Only the 5 + * NHD / MLA formats supported by the SYCL + * backend are accepted; others throw + * std::runtime_error. + * @param skip_prefix_n_blocks Number of leading blocks (by flat index) + * to leave untouched. + */ +void multi_layer_block_kv_transfer( + const torch::Tensor& paged_buffer_ptrs_tensor, + std::vector lmcache_objects_ptrs, + const torch::Tensor& block_ids, + const torch::Device& device, + TransferDirection direction, + PageBufferShapeDesc shape_desc, + int lmcache_chunk_size, + GPUKVFormat gpu_kv_format, + int skip_prefix_n_blocks); diff --git a/csrc/sycl/pybind_sycl.cpp b/csrc/sycl/pybind_sycl.cpp index 430e08dc2d..c077db6d93 100644 --- a/csrc/sycl/pybind_sycl.cpp +++ b/csrc/sycl/pybind_sycl.cpp @@ -5,8 +5,10 @@ // Exposed as `lmcache.xpu_ops`. // #include +#include #include #include "mem_kernels_sycl.h" +#include "mp_mem_kernels_sycl.h" #include "cachegen_kernels_sycl.h" namespace py = pybind11; @@ -45,6 +47,23 @@ PYBIND11_MODULE(xpu_ops, m) { m.def("reshape_and_cache_back_flash", &reshape_and_cache_back_flash); m.def("lmcache_memcpy_async", &lmcache_memcpy_async, py::call_guard()); + m.def("multi_layer_block_kv_transfer", &multi_layer_block_kv_transfer, + py::arg("paged_buffer_ptrs_tensor"), py::arg("lmcache_objects_ptrs"), + py::arg("block_ids"), py::arg("device"), py::arg("direction"), + py::arg("shape_desc"), py::arg("lmcache_chunk_size"), + py::arg("gpu_kv_format"), py::arg("skip_prefix_n_blocks"), + py::call_guard()); + py::class_(m, "PageBufferShapeDesc") + .def(py::init<>()) + .def_readwrite("kv_size", &PageBufferShapeDesc::kv_size) + .def_readwrite("nl", &PageBufferShapeDesc::nl) + .def_readwrite("nb", &PageBufferShapeDesc::nb) + .def_readwrite("bs", &PageBufferShapeDesc::bs) + .def_readwrite("nh", &PageBufferShapeDesc::nh) + .def_readwrite("hs", &PageBufferShapeDesc::hs) + .def_readwrite("element_size", &PageBufferShapeDesc::element_size) + .def_readwrite("block_stride_elems", + &PageBufferShapeDesc::block_stride_elems); // CacheGen / RoPE kernels (Intel XPU). Names match the // lmcache.python_ops_fallback module so lmcache._get_backend() can diff --git a/setup.py b/setup.py index 9c59f2d9d9..3c330c021e 100644 --- a/setup.py +++ b/setup.py @@ -327,6 +327,7 @@ def sycl_extension() -> tuple[list, dict]: sycl_sources = [ "csrc/sycl/pybind_sycl.cpp", "csrc/sycl/mem_kernels_sycl.cpp", + "csrc/sycl/mp_mem_kernels_sycl.cpp", "csrc/sycl/cal_cdf_sycl.cpp", "csrc/sycl/pos_kernels_sycl.cpp", "csrc/sycl/ac_enc_sycl.cpp", diff --git a/tests/v1/test_mp_mem_kernels_sycl.py b/tests/v1/test_mp_mem_kernels_sycl.py new file mode 100644 index 0000000000..ce1aaa028a --- /dev/null +++ b/tests/v1/test_mp_mem_kernels_sycl.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import random + +# Third Party +import pytest +import torch + +pytest.importorskip( + "lmcache.xpu_ops", + reason="Requires SYCL extension lmcache.xpu_ops", +) + +# First Party +import lmcache.xpu_ops as xpu_ops + +# Skip all tests if an XPU device is unavailable. +pytestmark = pytest.mark.skipif( + not torch.xpu.is_available() or torch.xpu.device_count() == 0, + reason="No Intel XPU present", +) + +# --------------------------------------------------------------------------- +# Supported formats (5 NHD + MLA formats covered by the SYCL backend) +# --------------------------------------------------------------------------- +FMT_CROSS_LAYER = xpu_ops.GPUKVFormat.NB_NL_TWO_BS_NH_HS +FMT_NORMAL = xpu_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS +FMT_FLASH_INFER = xpu_ops.GPUKVFormat.NL_X_NB_TWO_BS_NH_HS +FMT_MLA = xpu_ops.GPUKVFormat.NL_X_NB_BS_HS +FMT_SGLANG_MLA = xpu_ops.GPUKVFormat.NL_X_NBBS_ONE_HS + +# (gpu_kv_format, nl, nh, hs, is_mla) +FORMAT_PARAMS = [ + (FMT_CROSS_LAYER, 4, 8, 128, False), + (FMT_NORMAL, 4, 8, 128, False), + (FMT_FLASH_INFER, 4, 8, 128, False), + (FMT_MLA, 4, 1, 576, True), + (FMT_SGLANG_MLA, 4, 1, 576, True), +] + + +# --------------------------------------------------------------------------- +# Tensor factories +# --------------------------------------------------------------------------- + + +def _create_random_tensor(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) + + +def _create_zero_tensor(shape, dtype, device): + return torch.zeros(shape, dtype=dtype, device=device) + + +def create_vllm_tensors(gpu_kv_format, nl, nb, bs, nh, hs, dtype, device): + """Create random vLLM paged-buffer tensors for the given format.""" + nbbs = nb * bs + if gpu_kv_format == FMT_NORMAL: + return [ + _create_random_tensor([2, nb, bs, nh, hs], dtype, device) for _ in range(nl) + ] + elif gpu_kv_format == FMT_CROSS_LAYER: + return [_create_random_tensor([nb, nl, 2, bs, nh, hs], dtype, device)] + elif gpu_kv_format == FMT_FLASH_INFER: + return [ + _create_random_tensor([nb, 2, bs, nh, hs], dtype, device) for _ in range(nl) + ] + elif gpu_kv_format == FMT_MLA: + return [_create_random_tensor([nb, bs, hs], dtype, device) for _ in range(nl)] + elif gpu_kv_format == FMT_SGLANG_MLA: + return [_create_random_tensor([nbbs, 1, hs], dtype, device) for _ in range(nl)] + raise ValueError(f"Unknown format: {gpu_kv_format}") + + +def create_zero_vllm_tensors(gpu_kv_format, nl, nb, bs, nh, hs, dtype, device): + """Create zero-filled vLLM paged-buffer tensors for the given format.""" + nbbs = nb * bs + if gpu_kv_format == FMT_NORMAL: + return [ + _create_zero_tensor([2, nb, bs, nh, hs], dtype, device) for _ in range(nl) + ] + elif gpu_kv_format == FMT_CROSS_LAYER: + return [_create_zero_tensor([nb, nl, 2, bs, nh, hs], dtype, device)] + elif gpu_kv_format == FMT_FLASH_INFER: + return [ + _create_zero_tensor([nb, 2, bs, nh, hs], dtype, device) for _ in range(nl) + ] + elif gpu_kv_format == FMT_MLA: + return [_create_zero_tensor([nb, bs, hs], dtype, device) for _ in range(nl)] + elif gpu_kv_format == FMT_SGLANG_MLA: + return [_create_zero_tensor([nbbs, 1, hs], dtype, device) for _ in range(nl)] + raise ValueError(f"Unknown format: {gpu_kv_format}") + + +def create_memory_objects( + kv_dim, nl, tokens_per_object, hidden_dim, num_objects, dtype, device +): + """Create zero-filled LMCache memory objects [2, L, T, NH*HS].""" + shape = [kv_dim, nl, tokens_per_object, hidden_dim] + return [_create_zero_tensor(shape, dtype, device) for _ in range(num_objects)] + + +def get_block_data(vllm_tensors, gpu_kv_format, nl, bs, block_idx): + """Extract all-layer data for *block_idx* as a list of layer tensors.""" + results = [] + for layer_idx in range(nl): + if gpu_kv_format == FMT_NORMAL: + results.append(vllm_tensors[layer_idx][:, block_idx, :, :, :].clone()) + elif gpu_kv_format == FMT_CROSS_LAYER: + results.append(vllm_tensors[0][block_idx, layer_idx, :, :, :, :].clone()) + elif gpu_kv_format == FMT_FLASH_INFER: + results.append(vllm_tensors[layer_idx][block_idx, :, :, :, :].clone()) + elif gpu_kv_format == FMT_MLA: + results.append(vllm_tensors[layer_idx][block_idx, :, :].clone()) + elif gpu_kv_format == FMT_SGLANG_MLA: + ts, ed = block_idx * bs, (block_idx + 1) * bs + results.append(vllm_tensors[layer_idx][ts:ed, 0, :].clone()) + return results + + +# --------------------------------------------------------------------------- +# Kernel call helper +# --------------------------------------------------------------------------- + + +def call_block_kernel( + vllm_tensors, + mem_objects, + block_ids, + gpu_kv_format, + direction, + nl, + nb, + bs, + nh, + hs, + is_mla, + tokens_per_object, + skip_prefix_n_blocks=0, +): + """Call xpu_ops.multi_layer_block_kv_transfer with the given arguments.""" + device = vllm_tensors[0].device + + shape_desc = xpu_ops.PageBufferShapeDesc() + shape_desc.kv_size = 1 if is_mla else 2 + shape_desc.nl = nl + shape_desc.nb = nb + shape_desc.bs = bs + shape_desc.nh = nh + shape_desc.hs = hs + shape_desc.element_size = vllm_tensors[0].element_size() + shape_desc.block_stride_elems = 0 + + ptrs = [t.data_ptr() for t in vllm_tensors] + paged_buffer_ptrs_tensor = torch.tensor(ptrs, dtype=torch.int64, device=device) + lmcache_objects_ptrs = [m.data_ptr() for m in mem_objects] + + block_ids_gpu = torch.tensor(block_ids, dtype=torch.int64, device=device) + xpu_ops.multi_layer_block_kv_transfer( + paged_buffer_ptrs_tensor, + lmcache_objects_ptrs, + block_ids_gpu, + device, + direction, + shape_desc, + tokens_per_object, + gpu_kv_format, + skip_prefix_n_blocks, + ) + + +# --------------------------------------------------------------------------- +# Test configuration +# --------------------------------------------------------------------------- +NB = 200 # Must be >= 2 * TOTAL_BLOCKS so disjoint D2H and H2D block ID sets fit +BS = 16 +NUM_MEMORY_OBJECTS = 4 +TOKENS_PER_OBJECT = 256 +BLOCKS_PER_OBJECT = TOKENS_PER_OBJECT // BS # 16 +TOTAL_BLOCKS = NUM_MEMORY_OBJECTS * BLOCKS_PER_OBJECT # 64 + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "gpu_kv_format,nl,nh,hs,is_mla", + FORMAT_PARAMS, + ids=["cross_layer", "normal", "flash_infer", "mla", "sglang_mla"], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_block_transfer_roundtrip(gpu_kv_format, nl, nh, hs, is_mla, dtype): + """ + D2H → H2D roundtrip: data written via D2H must be recoverable via H2D. + + Uses disjoint source and target block IDs so the result is unambiguous. + """ + device = torch.device("xpu") + kv_dim = 1 if is_mla else 2 + hidden_dim = nh * hs + + source_vllm = create_vllm_tensors(gpu_kv_format, nl, NB, BS, nh, hs, dtype, device) + target_vllm = create_zero_vllm_tensors( + gpu_kv_format, nl, NB, BS, nh, hs, dtype, device + ) + mem_objects = create_memory_objects( + kv_dim, nl, TOKENS_PER_OBJECT, hidden_dim, NUM_MEMORY_OBJECTS, dtype, device + ) + + rng_d2h = random.Random(42) + block_ids_d2h = rng_d2h.sample(range(NB), TOTAL_BLOCKS) + excluded = set(block_ids_d2h) + available = [i for i in range(NB) if i not in excluded] + rng_h2d = random.Random(123) + block_ids_h2d = rng_h2d.sample(available, TOTAL_BLOCKS) + + # D2H: source vLLM → LMCache memory objects + call_block_kernel( + source_vllm, + mem_objects, + block_ids_d2h, + gpu_kv_format, + xpu_ops.TransferDirection.D2H, + nl, + NB, + BS, + nh, + hs, + is_mla, + TOKENS_PER_OBJECT, + ) + torch.xpu.synchronize() + + # H2D: LMCache memory objects → target vLLM + call_block_kernel( + target_vllm, + mem_objects, + block_ids_h2d, + gpu_kv_format, + xpu_ops.TransferDirection.H2D, + nl, + NB, + BS, + nh, + hs, + is_mla, + TOKENS_PER_OBJECT, + ) + torch.xpu.synchronize() + + # Verify: target[h2d_block_i] == source[d2h_block_i] + for i in range(TOTAL_BLOCKS): + src_data = get_block_data(source_vllm, gpu_kv_format, nl, BS, block_ids_d2h[i]) + tgt_data = get_block_data(target_vllm, gpu_kv_format, nl, BS, block_ids_h2d[i]) + for layer_idx in range(nl): + assert torch.equal(src_data[layer_idx], tgt_data[layer_idx]), ( + f"Mismatch at block {i}, layer {layer_idx}" + ) + + +@pytest.mark.parametrize( + "gpu_kv_format,nl,nh,hs,is_mla", + FORMAT_PARAMS, + ids=["cross_layer", "normal", "flash_infer", "mla", "sglang_mla"], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_block_transfer_skip_prefix(gpu_kv_format, nl, nh, hs, is_mla, dtype): + """Verify skip_prefix_n_blocks skips the first N blocks globally.""" + device = torch.device("xpu") + kv_dim = 1 if is_mla else 2 + hidden_dim = nh * hs + skip = 4 + + source_vllm = create_vllm_tensors(gpu_kv_format, nl, NB, BS, nh, hs, dtype, device) + target_vllm = create_zero_vllm_tensors( + gpu_kv_format, nl, NB, BS, nh, hs, dtype, device + ) + mem_objects = create_memory_objects( + kv_dim, nl, TOKENS_PER_OBJECT, hidden_dim, NUM_MEMORY_OBJECTS, dtype, device + ) + + rng_d2h = random.Random(42) + block_ids_d2h = rng_d2h.sample(range(NB), TOTAL_BLOCKS) + excluded = set(block_ids_d2h) + available = [i for i in range(NB) if i not in excluded] + rng_h2d = random.Random(123) + block_ids_h2d = rng_h2d.sample(available, TOTAL_BLOCKS) + + # D2H with prefix skip + call_block_kernel( + source_vllm, + mem_objects, + block_ids_d2h, + gpu_kv_format, + xpu_ops.TransferDirection.D2H, + nl, + NB, + BS, + nh, + hs, + is_mla, + TOKENS_PER_OBJECT, + skip_prefix_n_blocks=skip, + ) + torch.xpu.synchronize() + + # H2D with prefix skip + call_block_kernel( + target_vllm, + mem_objects, + block_ids_h2d, + gpu_kv_format, + xpu_ops.TransferDirection.H2D, + nl, + NB, + BS, + nh, + hs, + is_mla, + TOKENS_PER_OBJECT, + skip_prefix_n_blocks=skip, + ) + torch.xpu.synchronize() + + # Non-skipped blocks [skip, TOTAL_BLOCKS) must match. + for i in range(skip, TOTAL_BLOCKS): + src_data = get_block_data(source_vllm, gpu_kv_format, nl, BS, block_ids_d2h[i]) + tgt_data = get_block_data(target_vllm, gpu_kv_format, nl, BS, block_ids_h2d[i]) + for layer_idx in range(nl): + assert torch.equal(src_data[layer_idx], tgt_data[layer_idx]), ( + f"Mismatch at block {i}, layer {layer_idx}" + ) + + # Skipped blocks in target must remain zero. + for i in range(skip): + tgt_data = get_block_data(target_vllm, gpu_kv_format, nl, BS, block_ids_h2d[i]) + for layer_idx in range(nl): + block = tgt_data[layer_idx].to(torch.float32) + assert block.abs().sum().item() == 0, ( + f"Skipped block {i}, layer {layer_idx} is not zero" + )