Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/kernels/attention/test_cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend, _get_attn_isa

if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
Expand Down Expand Up @@ -59,6 +59,10 @@ def get_attn_isa(
)


def test_cpu_backend_requires_hnd_kv_cache_layout():
assert CPUAttentionBackend.get_required_kv_cache_layout() == "HND"


# rand number generation takes too much time, cache rand tensors
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
Expand Down
5 changes: 5 additions & 0 deletions vllm/v1/attention/backends/cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
KVCacheLayoutType,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
Expand Down Expand Up @@ -94,6 +95,10 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
return 2, num_blocks, num_kv_heads, block_size, head_size

@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"

@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
Expand Down