diff --git a/tests/kernels/attention/test_cpu_attn.py b/tests/kernels/attention/test_cpu_attn.py index 6af1bfe1e7ac..c4bd10215486 100644 --- a/tests/kernels/attention/test_cpu_attn.py +++ b/tests/kernels/attention/test_cpu_attn.py @@ -9,7 +9,7 @@ from vllm.platforms import CpuArchEnum, current_platform from vllm.utils.torch_utils import set_random_seed -from vllm.v1.attention.backends.cpu_attn import _get_attn_isa +from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend, _get_attn_isa if not current_platform.is_cpu(): pytest.skip("skipping CPU-only tests", allow_module_level=True) @@ -59,6 +59,10 @@ def get_attn_isa( ) +def test_cpu_backend_requires_hnd_kv_cache_layout(): + assert CPUAttentionBackend.get_required_kv_cache_layout() == "HND" + + # rand number generation takes too much time, cache rand tensors @functools.lru_cache(maxsize=128, typed=False) def tensor_cache( diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index e28e20045148..801a16d319e1 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -25,6 +25,7 @@ MultipleOf, ) from vllm.v1.attention.backends.utils import ( + KVCacheLayoutType, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec @@ -94,6 +95,10 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return 2, num_blocks, num_kv_heads, block_size, head_size + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return "HND" + @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False