diff --git a/challenges/medium/86_paged_attention/challenge.html b/challenges/medium/86_paged_attention/challenge.html new file mode 100644 index 00000000..9eb99eb7 --- /dev/null +++ b/challenges/medium/86_paged_attention/challenge.html @@ -0,0 +1,201 @@ +

+ Implement decode-phase attention over a paged KV cache. In LLM serving systems (e.g., vLLM), + the key and value tensors for each sequence are stored in fixed-size memory blocks (pages) that + may be scattered non-contiguously across a shared GPU memory pool. A block_table maps each + sequence's logical block indices to physical block indices in the cache pool. Given a single query vector + per sequence (one new token being generated), compute the attention output by gathering the relevant + K/V blocks via the block table and computing scaled dot-product attention over the full context. +

+ + + + + + + + + + + + + + + + block_table + + + blk 0 + blk 1 + blk 2 + + + seq 0 + + 3 + + 7 + + + + + seq 1 + + 1 + + 5 + + 9 + + values = physical block indices in pool ↓ + + + + + K_cache / V_cache pool (GPU memory) + + + + + blk 0 + + + + blk 1 + seq1.0 + + + + blk 2 + + + + blk 3 + seq0.0 + + + + blk 4 + + + + blk 5 + seq1.1 + + + + blk 6 + + + + blk 7 + seq0.1 + + + + blk 8 + + + + blk 9 + seq1.2 + + + + + + Decode Attention (per sequence s, per head h) + + 1. + Gather K, V: token t is at pool[ block_table[s, t/B] ], offset t%B + + 2. + scores[t] = Q[s,h] · K[s,h,t] / √head_dim for t = 0 .. context_lens[s]-1 + + 3. + output[s,h] = ∑_t softmax(scores)[t] · V[s,h,t] + + +

Implementation Requirements

+

+ Implement the function solve(Q, K_cache, V_cache, block_table, context_lens, output, batch_size, num_heads, head_dim, block_size, max_blocks_per_seq) + that computes paged decode-phase attention: +

+ +

+ For each sequence \(s\) and each attention head \(h\), compute: +

+
    +
  1. +
 Gather the \(\text{context_lens}[s]\) key and value vectors from the paged cache using
 \(\text{block_table}[s]\). Token at logical position \(t\) lives in physical block
 \(\text{block_table}[s,\;\lfloor t / B \rfloor]\) at offset \(t \bmod B\) within that block,
 where \(B = \text{block_size}\).
  2. +
 Compute scaled dot-product attention scores against the query:
 \[\text{scores}[t] = \frac{Q[s, h] \cdot K[s, h, t]}{\sqrt{\text{head_dim}}}\]
  3. +
 Apply softmax over all \(\text{context_lens}[s]\) positions to get attention weights.
  4. +
 Compute the weighted sum of values:
 \(\displaystyle \text{output}[s, h] = \sum_{t} \text{softmax}(\text{scores})[t] \cdot V[s, h, t]\)
+

+ Do not use external libraries beyond the framework you select. Keep the function signature unchanged. + Write results directly into output. +

+ +

Example

+

+ Input: batch_size = 1, num_heads = 1, head_dim = 4, + block_size = 2, context_lens = [2], block_table = [[0]] +

+

+ \(Q[0, 0] = \begin{bmatrix} 1.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}\) +

+

+ Keys gathered from block 0 (2 tokens): + \[ + K_0 = \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \end{bmatrix}, \quad + K_1 = \begin{bmatrix} 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix} + \] + Values gathered from block 0: + \[ + V_0 = \begin{bmatrix} 2.0 & 0.0 & 0.0 & 0.0 \end{bmatrix}, \quad + V_1 = \begin{bmatrix} 0.0 & 4.0 & 0.0 & 0.0 \end{bmatrix} + \] +

+

+ Scores (before softmax): + \[ + s_0 = \frac{Q \cdot K_0}{\sqrt{4}} = \frac{1}{2} = 0.5, \quad + s_1 = \frac{Q \cdot K_1}{\sqrt{4}} = \frac{1}{2} = 0.5 + \] + Attention weights: \(\text{softmax}([0.5, 0.5]) = [0.5, 0.5]\) + \[ + \text{output}[0, 0] = 0.5 \cdot V_0 + 0.5 \cdot V_1 = + \begin{bmatrix} 1.0 & 2.0 & 0.0 & 0.0 \end{bmatrix} + \] +

+ +

Constraints

+
diff --git a/challenges/medium/86_paged_attention/challenge.py b/challenges/medium/86_paged_attention/challenge.py
new file mode 100644
index 00000000..824f869b
--- /dev/null
+++ b/challenges/medium/86_paged_attention/challenge.py
@@ -0,0 +1,231 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+    """Decode-phase attention over a paged (vLLM-style) KV cache.
+
+    One query token per sequence attends over `context_lens[s]` cached
+    K/V vectors whose storage is scattered across a shared block pool and
+    addressed indirectly through `block_table`.
+    """
+
+    def __init__(self):
+        super().__init__(
+            name="Paged KV-Cache Attention",
+            atol=1e-04,
+            rtol=1e-04,
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(
+        self,
+        Q: torch.Tensor,
+        K_cache: torch.Tensor,
+        V_cache: torch.Tensor,
+        block_table: torch.Tensor,
+        context_lens: torch.Tensor,
+        output: torch.Tensor,
+        batch_size: int,
+        num_heads: int,
+        head_dim: int,
+        block_size: int,
+        max_blocks_per_seq: int,
+    ):
+        """Gather each sequence's K/V blocks via `block_table`, run scaled
+        dot-product attention per head, and write results into `output`.
+
+        Shapes: Q/output (batch, heads, head_dim);
+        K_cache/V_cache (num_blocks, block_size, heads, head_dim);
+        block_table (batch, max_blocks_per_seq) int32; context_lens (batch,) int32.
+        All tensors must live on CUDA; float tensors are float32.
+        """
+        assert Q.shape == (batch_size, num_heads, head_dim)
+        assert K_cache.shape[1] == block_size
+        assert K_cache.shape[2] == num_heads
+        assert K_cache.shape[3] == head_dim
+        assert V_cache.shape == K_cache.shape
+        assert block_table.shape == (batch_size, max_blocks_per_seq)
+        assert context_lens.shape == (batch_size,)
+        assert output.shape == (batch_size, num_heads, head_dim)
+        assert Q.dtype == K_cache.dtype == V_cache.dtype == output.dtype == torch.float32
+        assert block_table.dtype == context_lens.dtype == torch.int32
+        assert Q.device.type == "cuda"
+        assert K_cache.device.type == "cuda"
+        assert V_cache.device.type == "cuda"
+        assert block_table.device.type == "cuda"
+        assert context_lens.device.type == "cuda"
+        assert output.device.type == "cuda"
+
+        scale = 1.0 / math.sqrt(head_dim)
+
+        for s in range(batch_size):
+            ctx_len = context_lens[s].item()
+            n_blocks = (ctx_len + block_size - 1) // block_size
+
+            # Gather the physical blocks assigned to this sequence
+            phys_blocks = block_table[s, :n_blocks].long()  # (n_blocks,)
+
+            # Gather K and V: (n_blocks, block_size, num_heads, head_dim)
+            K_blocks = K_cache[phys_blocks]
+            V_blocks = V_cache[phys_blocks]
+
+            # Flatten to (n_blocks * block_size, num_heads, head_dim) and trim
+            # the tail of the last, partially filled block down to ctx_len.
+            K_seq = K_blocks.reshape(-1, num_heads, head_dim)[
+                :ctx_len
+            ]  # (ctx_len, num_heads, head_dim)
+            V_seq = V_blocks.reshape(-1, num_heads, head_dim)[:ctx_len]
+
+            # Transpose to (num_heads, ctx_len, head_dim)
+            K_seq = K_seq.transpose(0, 1).contiguous()
+            V_seq = V_seq.transpose(0, 1).contiguous()
+
+            # Q[s]: (num_heads, head_dim) -> (num_heads, 1, head_dim)
+            q = Q[s].unsqueeze(1)
+
+            # Scaled dot-product: (num_heads, 1, ctx_len)
+            scores = torch.bmm(q, K_seq.transpose(1, 2)) * scale
+            attn_weights = torch.softmax(scores, dim=-1)
+
+            # Weighted sum: (num_heads, 1, head_dim) -> (num_heads, head_dim)
+            out = torch.bmm(attn_weights, V_seq).squeeze(1)
+            output[s].copy_(out)
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        """ctypes signature of the user-facing `solve` entry point."""
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K_cache": (ctypes.POINTER(ctypes.c_float), "in"),
+            "V_cache": (ctypes.POINTER(ctypes.c_float), "in"),
+            "block_table": (ctypes.POINTER(ctypes.c_int), "in"),
+            "context_lens": (ctypes.POINTER(ctypes.c_int), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "batch_size": (ctypes.c_int, "in"),
+            "num_heads": (ctypes.c_int, "in"),
+            "head_dim": (ctypes.c_int, "in"),
+            "block_size": (ctypes.c_int, "in"),
+            "max_blocks_per_seq": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(
+        self, batch_size, num_heads, head_dim, block_size, context_lens, zero_q=False
+    ):
+        """Build one randomized test case; `context_lens` may be an int
+        (broadcast to the batch) or a per-sequence list."""
+        if isinstance(context_lens, int):
+            context_lens = [context_lens] * batch_size
+
+        max_ctx = max(context_lens)
+        max_blocks_per_seq = (max_ctx + block_size - 1) // block_size
+
+        # Allocate exactly the blocks needed, assigned sequentially
+        total_blocks = sum((cl + block_size - 1) // block_size for cl in context_lens)
+
+        device = "cuda"
+        dtype = torch.float32
+
+        if zero_q:
+            Q = torch.zeros(batch_size, num_heads, head_dim, device=device, dtype=dtype)
+        else:
+            Q = torch.randn(batch_size, num_heads, head_dim, device=device, dtype=dtype)
+
+        K_cache = torch.randn(
+            total_blocks, block_size, num_heads, head_dim, device=device, dtype=dtype
+        )
+        V_cache = torch.randn(
+            total_blocks, block_size, num_heads, head_dim, device=device, dtype=dtype
+        )
+
+        block_table = torch.zeros(batch_size, max_blocks_per_seq, device=device, dtype=torch.int32)
+        ctx_lens_tensor = torch.tensor(context_lens, device=device, dtype=torch.int32)
+
+        # Assign physical blocks sequentially per sequence
+        block_idx = 0
+        for s in range(batch_size):
+            n_blocks = (context_lens[s] + block_size - 1) // block_size
+            for b in range(n_blocks):
+                block_table[s, b] = block_idx
+                block_idx += 1
+
+        output = torch.zeros(batch_size, num_heads, head_dim, device=device, dtype=dtype)
+
+        return {
+            "Q": Q,
+            "K_cache": K_cache,
+            "V_cache": V_cache,
+            "block_table": block_table,
+            "context_lens": ctx_lens_tensor,
+            "output": output,
+            "batch_size": batch_size,
+            "num_heads": num_heads,
+            "head_dim": head_dim,
+            "block_size": block_size,
+            "max_blocks_per_seq": max_blocks_per_seq,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        """Hand-computed example mirroring the challenge page."""
+        device = "cuda"
+        dtype = torch.float32
+
+        # batch=1, heads=1, head_dim=4, block_size=2, ctx_len=2
+        # Q . K / sqrt(4): [1,1,0,0].[1,0,0,0]/2 = 0.5, [1,1,0,0].[0,1,0,0]/2 = 0.5
+        # attn = softmax([0.5, 0.5]) = [0.5, 0.5]
+        # output = 0.5*[2,0,0,0] + 0.5*[0,4,0,0] = [1, 2, 0, 0]
+        Q = torch.tensor([[[1.0, 1.0, 0.0, 0.0]]], device=device, dtype=dtype)  # (1, 1, 4)
+        K_cache = torch.tensor(
+            [[[[1.0, 0.0, 0.0, 0.0]], [[0.0, 1.0, 0.0, 0.0]]]],
+            device=device,
+            dtype=dtype,
+        )  # (1 block, block_size=2, 1 head, head_dim=4)
+        V_cache = torch.tensor(
+            [[[[2.0, 0.0, 0.0, 0.0]], [[0.0, 4.0, 0.0, 0.0]]]],
+            device=device,
+            dtype=dtype,
+        )
+        block_table = torch.tensor(
+            [[0]], device=device, dtype=torch.int32
+        )  # seq 0 -> physical block 0
+        context_lens = torch.tensor([2], device=device, dtype=torch.int32)
+        output = torch.zeros(1, 1, 4, device=device, dtype=dtype)
+
+        return {
+            "Q": Q,
+            "K_cache": K_cache,
+            "V_cache": V_cache,
+            "block_table": block_table,
+            "context_lens": context_lens,
+            "output": output,
+            "batch_size": 1,
+            "num_heads": 1,
+            "head_dim": 4,
+            "block_size": 2,
+            "max_blocks_per_seq": 1,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        """Edge cases plus randomized sweeps over sizes and context lengths."""
+        torch.manual_seed(42)
+        tests = []
+
+        # Edge case: single KV token
+        tests.append(self._make_test_case(1, 1, 4, 2, 1))
+
+        # Edge case: ctx_len equals block_size exactly
+        tests.append(self._make_test_case(1, 2, 8, 4, 4))
+
+        # Zero query: softmax is uniform, output is mean of V
+        tests.append(self._make_test_case(2, 2, 8, 4, 8, zero_q=True))
+
+        # Variable context lengths within a batch
+        tests.append(self._make_test_case(4, 4, 32, 16, [16, 32, 48, 64]))
+
+        # Power-of-2 context lengths
+        tests.append(self._make_test_case(4, 4, 32, 16, 32))
+
+        # Power-of-2, larger
+        tests.append(self._make_test_case(4, 8, 64, 16, 128))
+
+        # Non-power-of-2 context length
+        tests.append(self._make_test_case(2, 4, 32, 16, 30))
+
+        # Non-power-of-2, straddles multiple blocks
+        tests.append(self._make_test_case(4, 4, 64, 16, 100))
+
+        # Mixed variable lengths with non-power-of-2
+        tests.append(self._make_test_case(4, 8, 64, 16, [50, 100, 150, 200]))
+
+        # Realistic: LLaMA-3 8B style (8 Q heads), shorter context
+        tests.append(self._make_test_case(4, 8, 128, 16, 256))
+
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        """Single large decode-shaped case used for timing."""
+        torch.manual_seed(0)
+        # Realistic LLM decode: batch=8, 32 heads, head_dim=128, block_size=16, ctx_len=2048
+        return self._make_test_case(8, 32, 128, 16, 2048)
diff --git a/challenges/medium/86_paged_attention/starter/starter.cu b/challenges/medium/86_paged_attention/starter/starter.cu
new file mode 100644
index 00000000..b72d1b9c
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.cu
@@ -0,0 +1,7 @@
+#include <cuda_runtime.h> // NOTE(review): include target was stripped in extraction; presumably cuda_runtime.h -- confirm against repo
+
+// Q, K_cache, V_cache, block_table, context_lens, output are device pointers
+extern "C" void solve(const float* Q, const float* K_cache, const float* V_cache,
+                      const int* block_table, const int* context_lens, float* output,
+                      int batch_size, int num_heads, int head_dim, int block_size,
+                      int max_blocks_per_seq) {}
diff --git a/challenges/medium/86_paged_attention/starter/starter.cute.py b/challenges/medium/86_paged_attention/starter/starter.cute.py
new file mode 100644
index 00000000..d703ed65
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.cute.py
@@ -0,0 +1,20 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K_cache: cute.Tensor,
+    V_cache: cute.Tensor,
+    block_table: cute.Tensor,
+    context_lens: cute.Tensor,
+    output: cute.Tensor,
+    batch_size: cute.Int32,
+    num_heads: cute.Int32,
+    head_dim: cute.Int32,
+    block_size: cute.Int32,
+    max_blocks_per_seq: cute.Int32,
+):
+    pass
diff --git a/challenges/medium/86_paged_attention/starter/starter.jax.py b/challenges/medium/86_paged_attention/starter/starter.jax.py
new file mode 100644
index 00000000..cd82ce9b
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.jax.py
@@ -0,0 +1,20 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K_cache, V_cache, block_table, context_lens are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array,
+    K_cache: jax.Array,
+    V_cache: jax.Array,
+    block_table: jax.Array,
+    context_lens: jax.Array,
+    batch_size: int,
+    num_heads: int,
+    head_dim: int,
+    block_size: int,
+    max_blocks_per_seq: int,
+) -> jax.Array:
+    # return output tensor directly
+    pass
diff --git a/challenges/medium/86_paged_attention/starter/starter.mojo b/challenges/medium/86_paged_attention/starter/starter.mojo
new file mode 100644
index 00000000..ce8b7e21
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.mojo
@@ -0,0 +1,21 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K_cache, V_cache, block_table, context_lens, output are device pointers
+@export
+def solve(
+    Q: UnsafePointer[Float32],
+    K_cache: UnsafePointer[Float32],
+    V_cache: UnsafePointer[Float32],
+    block_table: UnsafePointer[Int32],
+    context_lens: UnsafePointer[Int32],
+    output: UnsafePointer[Float32],
+    batch_size: Int32,
+    num_heads: Int32,
+    head_dim: Int32,
+    block_size: Int32,
+    max_blocks_per_seq: Int32,
+):
+    pass
diff --git a/challenges/medium/86_paged_attention/starter/starter.pytorch.py b/challenges/medium/86_paged_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..aeb42ce3
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.pytorch.py
@@ -0,0 +1,18 @@
+import torch
+
+
+# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_cache: torch.Tensor,
+    V_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    context_lens: torch.Tensor,
+    output: torch.Tensor,
+    batch_size: int,
+    num_heads: int,
+    head_dim: int,
+    block_size: int,
+    max_blocks_per_seq: int,
+):
+    pass
diff --git a/challenges/medium/86_paged_attention/starter/starter.triton.py b/challenges/medium/86_paged_attention/starter/starter.triton.py
new file mode 100644
index 00000000..7c392628
--- /dev/null
+++ b/challenges/medium/86_paged_attention/starter/starter.triton.py
@@ -0,0 +1,20 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_cache: torch.Tensor,
+    V_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    context_lens: torch.Tensor,
+    output: torch.Tensor,
+    batch_size: int,
+    num_heads: int,
+    head_dim: int,
+    block_size: int,
+    max_blocks_per_seq: int,
+):
+    pass