Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions challenges/medium/88_prefix_cached_attention/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
<p>
Implement prefix-cached attention, the attention pattern used during <em>chunked prefill</em> in
LLM inference systems such as vLLM and TensorRT-LLM. Given query tensors for a chunk of
<code>new_len</code> tokens and packed key/value tensors containing both a cached prefix of
<code>cache_len</code> tokens and the new tokens themselves, compute scaled dot-product attention
where each new query token attends to all cached tokens (full access) and causally to the new
tokens (lower-triangular access). All tensors use <code>float32</code>.
</p>

<svg width="680" height="280" viewBox="0 0 680 280" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="680" height="280" fill="#222" rx="10"/>

<!-- Title -->
<text x="340" y="26" fill="#ccc" font-family="monospace" font-size="13" text-anchor="middle">Attention mask (cache_len=4, new_len=4, total_len=8)</text>

<!-- Column labels -->
<text x="200" y="55" fill="#60a5fa" font-family="monospace" font-size="11" text-anchor="middle">K cache (j=0..3)</text>
<text x="390" y="55" fill="#4ade80" font-family="monospace" font-size="11" text-anchor="middle">K new (j=4..7)</text>

<!-- Row labels -->
<text x="60" y="100" fill="#ccc" font-family="monospace" font-size="11" text-anchor="end">Q[0]</text>
<text x="60" y="130" fill="#ccc" font-family="monospace" font-size="11" text-anchor="end">Q[1]</text>
<text x="60" y="160" fill="#ccc" font-family="monospace" font-size="11" text-anchor="end">Q[2]</text>
<text x="60" y="190" fill="#ccc" font-family="monospace" font-size="11" text-anchor="end">Q[3]</text>

<!-- Cache block: all 4 queries attend to all 4 cache keys (full rectangle) -->
<!-- Row 0 -->
<rect x="70" y="82" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="106" y="82" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="142" y="82" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="178" y="82" width="32" height="32" fill="#1d4ed8" rx="2"/>
<!-- Row 1 -->
<rect x="70" y="114" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="106" y="114" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="142" y="114" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="178" y="114" width="32" height="32" fill="#1d4ed8" rx="2"/>
<!-- Row 2 -->
<rect x="70" y="146" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="106" y="146" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="142" y="146" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="178" y="146" width="32" height="32" fill="#1d4ed8" rx="2"/>
<!-- Row 3 -->
<rect x="70" y="178" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="106" y="178" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="142" y="178" width="32" height="32" fill="#1d4ed8" rx="2"/>
<rect x="178" y="178" width="32" height="32" fill="#1d4ed8" rx="2"/>

<!-- New-token block: lower-triangular (causal) -->
<!-- Row 0: Q[0] attends only to K_new[0] (j=4) -->
<rect x="214" y="82" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="250" y="82" width="32" height="32" fill="#333" rx="2"/>
<rect x="286" y="82" width="32" height="32" fill="#333" rx="2"/>
<rect x="322" y="82" width="32" height="32" fill="#333" rx="2"/>
<!-- Row 1: Q[1] attends to K_new[0,1] -->
<rect x="214" y="114" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="250" y="114" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="286" y="114" width="32" height="32" fill="#333" rx="2"/>
<rect x="322" y="114" width="32" height="32" fill="#333" rx="2"/>
<!-- Row 2: Q[2] attends to K_new[0,1,2] -->
<rect x="214" y="146" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="250" y="146" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="286" y="146" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="322" y="146" width="32" height="32" fill="#333" rx="2"/>
<!-- Row 3: Q[3] attends to all K_new -->
<rect x="214" y="178" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="250" y="178" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="286" y="178" width="32" height="32" fill="#15803d" rx="2"/>
<rect x="322" y="178" width="32" height="32" fill="#15803d" rx="2"/>

<!-- Divider between cache and new blocks -->
<line x1="212" y1="78" x2="212" y2="216" stroke="#aaa" stroke-width="1.5" stroke-dasharray="4,3"/>

<!-- Legend -->
<rect x="430" y="100" width="16" height="16" fill="#1d4ed8" rx="2"/>
<text x="452" y="113" fill="#ccc" font-family="monospace" font-size="12">attend (cache)</text>
<rect x="430" y="124" width="16" height="16" fill="#15803d" rx="2"/>
<text x="452" y="137" fill="#ccc" font-family="monospace" font-size="12">attend (causal)</text>
<rect x="430" y="148" width="16" height="16" fill="#333" rx="2"/>
<text x="452" y="161" fill="#ccc" font-family="monospace" font-size="12">masked out</text>

<text x="430" y="195" fill="#4ade80" font-family="monospace" font-size="11">mask: j &lt;= cache_len + i</text>
<text x="430" y="213" fill="#4ade80" font-family="monospace" font-size="11">scale = 1 / sqrt(head_dim)</text>
<text x="430" y="231" fill="#4ade80" font-family="monospace" font-size="11">output = softmax(scores) @ V</text>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K, V, output, num_heads, cache_len, new_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li>Use scaled dot-product attention with scale factor <code>1 / sqrt(head_dim)</code>.</li>
<li>
Apply the causal mask: query token <code>i</code> (at absolute sequence position
<code>cache_len + i</code>) attends to key token <code>j</code> if and only if
<code>j &le; cache_len + i</code>. Masked positions receive <code>-inf</code> before
softmax.
</li>
<li>
<code>K</code> and <code>V</code> are packed buffers of shape
<code>(num_heads, cache_len + new_len, head_dim)</code>; the first <code>cache_len</code>
positions along the sequence dimension are the cached prefix.
</li>
</ul>

<h2>Example</h2>
<p>
With <code>num_heads</code> = 2, <code>cache_len</code> = 2, <code>new_len</code> = 2,
<code>head_dim</code> = 4 (total_len = 4):
</p>
<p>
<strong>Input:</strong><br>
\(Q_0\) (2&times;4):
\[
\begin{bmatrix}
1 & 0 & 0 & 1 \\
0 & 1 & 1 & 0
\end{bmatrix}
\]
\(Q_1\) (2&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & 1 \\
1 & 0 & 1 & 0
\end{bmatrix}
\]
\(K_0\) (4&times;4, cache rows first):
\[
\begin{bmatrix}
1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 \\
1 & 1 & 0 & 0 \\
0 & 0 & 1 & 1
\end{bmatrix}
\]
\(K_1\) (4&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & -1 \\
-1 & 0 & 1 & 0 \\
1 & 0 & -1 & 0 \\
0 & 1 & 0 & 1
\end{bmatrix}
\]
\(V_0\) (4&times;4):
\[
\begin{bmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12 \\
13 & 14 & 15 & 16
\end{bmatrix}
\]
\(V_1\) (4&times;4):
\[
\begin{bmatrix}
-1 & -2 & -3 & -4 \\
2 & 3 & 4 & 5 \\
6 & 7 & 8 & 9 \\
-2 & -3 & -4 & -5
\end{bmatrix}
\]
\(\text{cache\_len} = 2\), \(\text{new\_len} = 2\).<br>
Query token 0 (absolute position 2) attends to \(K[\,:\,,\,0{:}3,\,:\,]\); token 1 attends to all four keys.
</p>
<p>
<strong>Output</strong> (values rounded to 2 decimal places):<br>
\(\text{output}_0\) (2&times;4):
\[
\begin{bmatrix}
5.00 & 6.00 & 7.00 & 8.00 \\
7.00 & 8.00 & 9.00 & 10.00
\end{bmatrix}
\]
\(\text{output}_1\) (2&times;4):
\[
\begin{bmatrix}
2.33 & 2.67 & 3.00 & 3.33 \\
1.25 & 1.25 & 1.25 & 1.25
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_heads</code> &le; 64</li>
<li>0 &le; <code>cache_len</code> &le; 4,096</li>
<li>1 &le; <code>new_len</code> &le; 1,024</li>
<li>8 &le; <code>head_dim</code> &le; 256; <code>head_dim</code> is a multiple of 8</li>
<li>All tensor values are <code>float32</code></li>
<li>
Performance is measured with <code>num_heads</code> = 32, <code>cache_len</code> = 1,024,
<code>new_len</code> = 512, <code>head_dim</code> = 128
</li>
</ul>
197 changes: 197 additions & 0 deletions challenges/medium/88_prefix_cached_attention/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Prefix-cached attention (chunked prefill).

    Each query token of the new chunk attends to the entire cached prefix
    plus a causal (lower-triangular) window over the new tokens themselves —
    the attention pattern used during chunked prefill in LLM serving engines.
    """

    def __init__(self):
        super().__init__(
            name="Prefix-Cached Attention",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        output: torch.Tensor,
        num_heads: int,
        cache_len: int,
        new_len: int,
        head_dim: int,
    ):
        """Ground-truth attention: full access to the cache, causal over new tokens.

        Writes softmax((Q @ K^T) * scale + mask) @ V into ``output``, where
        scale = 1/sqrt(head_dim) and masked logits are set to -inf.
        """
        total_len = cache_len + new_len

        # Validate shape, dtype, and placement of every buffer up front.
        assert Q.shape == (num_heads, new_len, head_dim)
        assert K.shape == (num_heads, total_len, head_dim)
        assert V.shape == (num_heads, total_len, head_dim)
        assert output.shape == (num_heads, new_len, head_dim)
        for tensor in (Q, K, V, output):
            assert tensor.dtype == torch.float32
            assert tensor.device.type == "cuda"

        # Raw attention logits per head: (num_heads, new_len, total_len).
        logits = torch.bmm(Q, K.transpose(1, 2)) * (1.0 / math.sqrt(head_dim))

        # Query i sits at absolute position cache_len + i, so it may see key j
        # iff j <= cache_len + i: every cached key plus a causal slice of the
        # new keys. Build the *blocked* positions (j strictly beyond the limit).
        rows = torch.arange(new_len, device=Q.device).unsqueeze(1)  # (new_len, 1)
        cols = torch.arange(total_len, device=Q.device).unsqueeze(0)  # (1, total_len)
        blocked = cols > cache_len + rows  # (new_len, total_len)

        logits = logits.masked_fill(blocked.unsqueeze(0), float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        output.copy_(torch.bmm(probs, V))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """ctypes signature of the user-facing solve() entry point."""
        float_ptr = ctypes.POINTER(ctypes.c_float)
        signature: Dict[str, tuple] = {
            "Q": (float_ptr, "in"),
            "K": (float_ptr, "in"),
            "V": (float_ptr, "out" if False else "in"),  # K/V are read-only inputs
            "output": (float_ptr, "out"),
        }
        for scalar in ("num_heads", "cache_len", "new_len", "head_dim"):
            signature[scalar] = (ctypes.c_int, "in")
        return signature

    def _make_test_case(self, num_heads, cache_len, new_len, head_dim, zero_inputs=False):
        """Build one randomized (or all-zero) GPU test-case dict."""
        total_len = cache_len + new_len
        kwargs = {"device": "cuda", "dtype": torch.float32}
        # torch.zeros and torch.randn share the (sizes..., device=, dtype=)
        # calling convention, so the branches differ only in the factory used.
        factory = torch.zeros if zero_inputs else torch.randn
        Q = factory(num_heads, new_len, head_dim, **kwargs)
        K = factory(num_heads, total_len, head_dim, **kwargs)
        V = factory(num_heads, total_len, head_dim, **kwargs)
        output = torch.zeros(num_heads, new_len, head_dim, **kwargs)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": num_heads,
            "cache_len": cache_len,
            "new_len": new_len,
            "head_dim": head_dim,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """The hand-computed example from the challenge statement."""
        num_heads, cache_len, new_len, head_dim = 2, 2, 2, 4

        def on_gpu(data):
            # All example tensors share the same device/dtype.
            return torch.tensor(data, device="cuda", dtype=torch.float32)

        Q = on_gpu(
            [
                [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
                [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]],
            ]
        )
        # K/V are packed: the first cache_len rows per head are the prefix.
        K = on_gpu(
            [
                [
                    [1.0, 0.0, 1.0, 0.0],
                    [0.0, 1.0, 0.0, 1.0],
                    [1.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 1.0],
                ],
                [
                    [0.0, 1.0, 0.0, -1.0],
                    [-1.0, 0.0, 1.0, 0.0],
                    [1.0, 0.0, -1.0, 0.0],
                    [0.0, 1.0, 0.0, 1.0],
                ],
            ]
        )
        V = on_gpu(
            [
                [
                    [1.0, 2.0, 3.0, 4.0],
                    [5.0, 6.0, 7.0, 8.0],
                    [9.0, 10.0, 11.0, 12.0],
                    [13.0, 14.0, 15.0, 16.0],
                ],
                [
                    [-1.0, -2.0, -3.0, -4.0],
                    [2.0, 3.0, 4.0, 5.0],
                    [6.0, 7.0, 8.0, 9.0],
                    [-2.0, -3.0, -4.0, -5.0],
                ],
            ]
        )
        output = torch.zeros(num_heads, new_len, head_dim, device="cuda", dtype=torch.float32)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": num_heads,
            "cache_len": cache_len,
            "new_len": new_len,
            "head_dim": head_dim,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Functional suite: edge cases first, then assorted shape mixes."""
        torch.manual_seed(42)
        # (num_heads, cache_len, new_len, head_dim, zero_inputs)
        cases = [
            (1, 1, 1, 4, False),       # single decode step vs. one cached token
            (2, 2, 2, 4, True),        # all-zero inputs
            (2, 0, 4, 8, False),       # cache_len=0: pure causal self-attention
            (4, 16, 1, 32, False),     # new_len=1: typical autoregressive step
            (4, 32, 16, 32, False),    # power-of-2 sizes
            (8, 64, 32, 64, False),    # larger power-of-2
            (4, 30, 15, 32, False),    # non-power-of-2 sizes
            (6, 100, 50, 32, False),   # non-power-of-2 with more heads
            (8, 255, 3, 64, False),    # long cache, short new chunk
            (16, 128, 64, 64, False),  # realistic (LLaMA-style), short chunk
        ]
        return [
            self._make_test_case(heads, cached, fresh, dim, zero_inputs=zeros)
            for heads, cached, fresh, dim, zeros in cases
        ]

    def generate_performance_test(self) -> Dict[str, Any]:
        """Benchmark shape: LLaMA-3-8B-style chunked prefill."""
        torch.manual_seed(0)
        # 32 heads, head_dim=128; cache_len=1024 of prior context,
        # new_len=512 chunk being prefilled.
        return self._make_test_case(32, 1024, 512, 128)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Entry point for the prefix-cached attention challenge.
// Q, K, V, output are device pointers (all float32):
//   Q:      (num_heads, new_len, head_dim)              — queries for the new chunk
//   K, V:   (num_heads, cache_len + new_len, head_dim)  — packed cache + new tokens
//   output: (num_heads, new_len, head_dim)              — attention result
// Contract (see the Python reference impl): logits = (Q @ K^T) / sqrt(head_dim);
// query i may attend to key j iff j <= cache_len + i (full cache + causal new
// tokens, masked logits get -inf); output = softmax(logits) @ V.
// NOTE(review): body is intentionally empty — this appears to be the starter
// stub that challenge participants fill in; confirm before "fixing" it.
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
int cache_len, int new_len, int head_dim) {}
Loading
Loading