
Commit d8ccf87

Add: paged attention unroll scene test with 4D input shapes
- New paged_attention_unroll_4dims test under tensormap_and_ringbuffer
- Query and output tensors use 4D format (batch, seq_len, num_heads, head_dim)
- 6 kernels: QK/PV matmul (AIC), softmax_prepare/online_update (AIV), hub stubs
- Orchestration with N_UNROLL=64, 4 tasks per group, online softmax accumulation
- Golden wraps shared paged_attention_golden with 4D reshape adapter
- Three test cases: varying batch/heads/head_dim at production scale (bfloat16)
1 parent f265671 commit d8ccf87

9 files changed

Lines changed: 1266 additions & 0 deletions
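
The orchestration code itself is not shown in this excerpt; for context, the following is a minimal NumPy sketch of the online softmax accumulation named in the commit message. All names (online_softmax_attention, q, k_blocks, v_blocks) are illustrative, the softmax scale factor is omitted, and the real test splits this work across the QK/PV matmul and softmax kernels below.

import numpy as np

def online_softmax_attention(q, k_blocks, v_blocks):
    """Streaming softmax(q @ k.T) @ v, one key/value block at a time."""
    m = np.full(q.shape[0], -np.inf)                  # running row-wise max
    l = np.zeros(q.shape[0])                          # running softmax denominator
    o = np.zeros((q.shape[0], v_blocks[0].shape[1]))  # running PV accumulator
    for kj, vj in zip(k_blocks, v_blocks):
        sij = q @ kj.T                                # block scores (QK matmul)
        m_new = np.maximum(m, sij.max(axis=1))        # updated running max
        scale = np.exp(m - m_new)                     # rescales old l and o
        pij = np.exp(sij - m_new[:, None])            # unnormalized block probabilities
        l = l * scale + pij.sum(axis=1)
        o = o * scale[:, None] + pij @ vj             # PV matmul, accumulated
        m = m_new
    return o / l[:, None]

The invariant is that after each block, o and l hold the same values as if softmax had been computed over all keys seen so far, which is what lets the orchestration process N_UNROLL blocks at a time.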


Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
"""Paged Attention Unroll Golden - tensormap_and_ringbuffer test (production scale, bfloat16).

Input shapes use 4D format: (batch, seq_len, num_heads, head_dim) for query and out.
"""

from paged_attention_golden import (
    generate_inputs as _generate_inputs,
    compute_golden as _compute_golden,
    run_golden_test,
)

__outputs__ = ["out"]

RTOL = 1e-3
ATOL = 1e-3

ALL_CASES = {
    "Case1": {
        "batch": 256,
        "num_heads": 16,
        "kv_head_num": 1,
        "head_dim": 128,
        "block_size": 128,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
    "Case2": {
        "batch": 64,
        "num_heads": 64,
        "kv_head_num": 1,
        "head_dim": 128,
        "block_size": 64,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
    "Case3": {
        "batch": 64,
        "num_heads": 64,
        "kv_head_num": 1,
        "head_dim": 256,
        "block_size": 64,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
}

DEFAULT_CASE = "Case1"


def generate_inputs(params: dict) -> list:
    result = _generate_inputs(params)
    batch = params["batch"]
    num_heads = params["num_heads"]
    head_dim = params["head_dim"]
    reshaped = []
    for name, val in result:
        if name in ("query", "out"):
            val = val.reshape(batch, 1, num_heads, head_dim)
        reshaped.append((name, val))
    return reshaped


def compute_golden(tensors: dict, params: dict) -> None:
    batch = params["batch"]
    num_heads = params["num_heads"]
    head_dim = params["head_dim"]
    out_4d = tensors["out"]
    tensors["out"] = out_4d.reshape(batch, num_heads, head_dim)
    _compute_golden(tensors, params)
    tensors["out"] = out_4d


if __name__ == "__main__":
    run_golden_test(ALL_CASES, DEFAULT_CASE, generate_inputs, label="Paged Attention Unroll")
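
One non-obvious property the compute_golden adapter above appears to rely on (an observation, not part of the diff): for a contiguous array, numpy reshape returns a view of the same buffer, so in-place writes made through the temporary 3D handle stay visible through the restored 4D handle. A minimal sketch, assuming out is contiguous:

import numpy as np

out_4d = np.zeros((2, 1, 4, 8), dtype=np.float32)  # (batch, seq_len=1, num_heads, head_dim)
out_3d = out_4d.reshape(2, 4, 8)                   # a view over the same buffer
out_3d[...] = 1.0                                  # in-place write via the 3D view
assert out_4d[0, 0, 0, 0] == 1.0                   # visible through the 4D handle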
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
// Hub stub: no-op kernel entry point (listed as "hub stubs" in the commit message).
#include <cstdint>
#include <pto/pto-inst.hpp>

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

constexpr int M = 16;
constexpr int K = 16;
constexpr int N = 16;

extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {}
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
// SplitK PV Matmul Kernel: Accumulated P @ V across n_blocks
//
// Processes n_blocks blocks using SplitK accumulation pattern:
//   Block 0: TMATMUL(C, A, B) — initialize accumulator
//   Block i: TMATMUL_ACC(C, C, A, B) — accumulate into same C
//
// Per-block pij addresses: contiguous slices of pij_buf (n_blocks * M * K)
// Per-block vj addresses: value_cache base + block_indices lookup
// Single output: oi_new (M, N) fp32 = sum of P_i @ V_i across all blocks
//
// Optimizations:
// - Double-buffered L1 tiles (ping/pong for A and B via MTE2)
// - Double-buffered L0 tiles (ping/pong for L0A and L0B via MTE1)
// - TLOAD(next) overlaps with TMATMUL(current) via MTE2/M-pipe parallelism
// - Canonical 3-stage pipeline: TLOAD(MTE2) → TMOV(MTE1) → TMATMUL(M)
// - Reverse-dependency events ensure buffer safety across iterations
//
// Supports two tile configurations via runtime dispatch:
//   Case1: (16, 128) @ (128, 128) -> (16, 128)
//   Case2: (64,  64) @ ( 64, 128) -> (64, 128)
//
// pij is bfloat16 (from softmax_prepare TCVT).
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.

#include <cstdint>
#include <pto/pto-inst.hpp>

#include "tensor.h"

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

template <int M, int K, int N>
static __aicore__ void pv_matmul_n_impl(
    __gm__ bfloat16_t* pij_base,
    __gm__ bfloat16_t* val_base,
    __gm__ float* oi_base,
    uint64_t n_blocks,
    __gm__ int32_t* block_table) {

  using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
  using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
  using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

  using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
  using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;

  using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
  using RightTile = TileRight<bfloat16_t, K, N, K, N>;
  using AccTile = TileAcc<float, M, N, M, N>;

  // L1 memory layout: double-buffered A and B tiles (tightly packed)
  constexpr int kATileBytes = M * K * static_cast<int>(sizeof(bfloat16_t));
  constexpr int kBTileBytes = K * N * static_cast<int>(sizeof(bfloat16_t));

  TileMatA aMatTile[2];
  TileMatB bMatTile[2];
  TASSIGN(aMatTile[0], 0x0);
  TASSIGN(aMatTile[1], kATileBytes);
  TASSIGN(bMatTile[0], 2 * kATileBytes);
  TASSIGN(bMatTile[1], 2 * kATileBytes + kBTileBytes);

  // L0 memory layout: double-buffered L0A and L0B, single accumulator L0C
  LeftTile aTile[2];
  RightTile bTile[2];
  AccTile cTile;
  TASSIGN(aTile[0], 0x0);
  TASSIGN(aTile[1], kATileBytes);
  TASSIGN(bTile[0], 0x0);
  TASSIGN(bTile[1], kBTileBytes);
  TASSIGN(cTile, 0x0);

  GlobalOut oiGlobal(oi_base);

  // Seed reverse-dependency flags: all ping/pong buffers initially free
  //   PIPE_MTE1 → PIPE_MTE2: L1 buffer [0/1] safe for TLOAD to overwrite
  //   PIPE_M → PIPE_MTE1: L0 buffer [0/1] safe for TMOV to overwrite
  set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
  set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
  set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
  set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

  for (uint64_t i = 0; i < n_blocks; i++) {
    int cur = static_cast<int>(i % 2);
    GlobalA pijGlobal(pij_base + i * M * K);
    GlobalB vjGlobal(val_base + block_table[i] * K * N);

    // Stage 1: TLOAD (MTE2: GM → L1[cur])
    // Wait for MTE1 to release L1[cur] (reverse dep from previous iteration)
    wait_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur);
    TLOAD(aMatTile[cur], pijGlobal);
    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);  // forward: A in L1 ready
    TLOAD(bMatTile[cur], vjGlobal);
    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);  // forward: B in L1 ready

    // Stage 2: TMOV (MTE1: L1[cur] → L0[cur])
    // Wait for M-pipe to release L0[cur] (reverse dep from previous iteration)
    wait_flag(PIPE_M, PIPE_MTE1, (event_t)cur);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);  // forward: wait A loaded
    TMOV(aTile[cur], aMatTile[cur]);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);  // forward: wait B loaded
    TMOV(bTile[cur], bMatTile[cur]);
    set_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur);  // reverse: release L1[cur]

    // Stage 3: TMATMUL (M-pipe: L0A[cur] × L0B[cur] → L0C)
    set_flag(PIPE_MTE1, PIPE_M, (event_t)cur);  // forward: L0[cur] ready
    wait_flag(PIPE_MTE1, PIPE_M, (event_t)cur);
    if (i == 0) {
      TMATMUL(cTile, aTile[cur], bTile[cur]);
    } else {
      TMATMUL_ACC(cTile, cTile, aTile[cur], bTile[cur]);
    }
    set_flag(PIPE_M, PIPE_MTE1, (event_t)cur);  // reverse: release L0[cur]
  }

  // Drain outstanding reverse-dependency flags
  wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
  wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
  wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
  wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

  set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
  wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
  TSTORE(oiGlobal, cTile);

  set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
  wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
}

extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
  __gm__ TensorData* pij_buf = reinterpret_cast<__gm__ TensorData*>(args[0]);
  __gm__ TensorData* value_cache = reinterpret_cast<__gm__ TensorData*>(args[1]);
  __gm__ TensorData* oi_new = reinterpret_cast<__gm__ TensorData*>(args[2]);
  uint64_t n_blocks = static_cast<uint64_t>(args[3]);
  __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]);

  __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_buf->buffer.addr) + pij_buf->start_offset;
  __gm__ bfloat16_t* val_base = reinterpret_cast<__gm__ bfloat16_t*>(value_cache->buffer.addr);
  __gm__ float* oi_base = reinterpret_cast<__gm__ float*>(oi_new->buffer.addr) + oi_new->start_offset;

  uint64_t q_tile_size = static_cast<uint64_t>(pij_buf->shapes[0]);

  if (q_tile_size == 16) {
    pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_table);
  } else {
    pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_table);
  }
}
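
For intuition, here is a plain NumPy rendering of what this kernel computes. It is a reference sketch, not the implementation: pij_buf and value_cache are flat arrays, and the shapes follow the header comment above.

import numpy as np

def pv_matmul_n_ref(pij_buf, value_cache, block_table, n_blocks, M, K, N):
    """oi_new = sum over blocks of P_i @ V_i (TMATMUL to init, TMATMUL_ACC after)."""
    oi = np.zeros((M, N), dtype=np.float32)
    for i in range(n_blocks):
        # Contiguous P slice i of pij_buf (n_blocks * M * K elements total)
        pij = pij_buf[i * M * K:(i + 1) * M * K].reshape(M, K)
        # Paged V block located via the block table; stored (K, N) row-major
        vj = value_cache[block_table[i] * K * N:][:K * N].reshape(K, N)
        oi += pij.astype(np.float32) @ vj.astype(np.float32)  # fp32 accumulate
    return oi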
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
// Multi-block QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) for each block
//
// Processes n_blocks blocks in a single kernel invocation.
// Per-block kj addresses computed from key_cache base + block_indices lookup.
// qi is shared across all blocks (same query head against different key blocks).
//
// Output layout: n_blocks contiguous (M, N) tiles stacked vertically.
//   Block i occupies sij[i*M : (i+1)*M, 0:N].
//
// Optimizations:
// - qi TLOAD hoisted before the loop (constant across all iterations)
//
// Supports two tile configurations via runtime dispatch:
//   Case1: (16, 128) @ (128, 128).T -> (16, 128)
//   Case2: (64, 128) @ (128, 64).T -> (64, 64)
//
// Template: M=q_tile, K=head_dim, N=block_size

#include <cstdint>
#include <pto/pto-inst.hpp>

#include "tensor.h"

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

template <int M, int K, int N>
static __aicore__ void qk_matmul_n_impl(
    __gm__ bfloat16_t* qi_base,
    __gm__ bfloat16_t* key_base,
    __gm__ float* sij_base,
    uint64_t n_blocks,
    __gm__ int32_t* block_table) {

  using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
  using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
  using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

  using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
  using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;

  using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
  using RightTile = TileRight<bfloat16_t, K, N, K, N>;
  using AccTile = TileAcc<float, M, N, M, N>;

  TileMatA aMatTile;
  TileMatB bMatTile;
  TASSIGN(aMatTile, 0x0);
  TASSIGN(bMatTile, 0x20000);

  LeftTile aTile;
  RightTile bTile;
  AccTile cTile;
  TASSIGN(aTile, 0x0);
  TASSIGN(bTile, 0x0);
  TASSIGN(cTile, 0x0);

  // Hoist qi TLOAD before the loop (qi is constant across all blocks)
  GlobalA qiGlobal(qi_base);
  TLOAD(aMatTile, qiGlobal);

  for (uint64_t i = 0; i < n_blocks; i++) {
    GlobalB kjGlobal(key_base + block_table[i] * N * K);
    GlobalOut sijGlobal(sij_base + i * M * N);

    // Load only B each iteration (qi already in L1 from hoist)
    TLOAD(bMatTile, kjGlobal);

    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);

    // TMOV qi from L1→L0A (re-copy since TMATMUL consumed L0A) and kj from L1→L0B
    TMOV(aTile, aMatTile);
    TMOV(bTile, bMatTile);

    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

    TMATMUL(cTile, aTile, bTile);

    set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);

    TSTORE(sijGlobal, cTile);

    if (i + 1 < n_blocks) {
      pipe_barrier(PIPE_ALL);
    }
  }
  set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
  wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
}

extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
  __gm__ TensorData* qi = reinterpret_cast<__gm__ TensorData*>(args[0]);
  __gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]);
  __gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[2]);
  uint64_t n_blocks = static_cast<uint64_t>(args[3]);
  __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]);

  __gm__ bfloat16_t* qi_base = reinterpret_cast<__gm__ bfloat16_t*>(qi->buffer.addr) + qi->start_offset;
  __gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr);
  __gm__ float* sij_base = reinterpret_cast<__gm__ float*>(sij_buf->buffer.addr) + sij_buf->start_offset;

  // qi is a 4D view: (1, 1, q_tile, head_dim)
  uint64_t q_tile_size = static_cast<uint64_t>(qi->shapes[2]);

  if (q_tile_size == 16) {
    qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_table);
  } else {
    qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_table);
  }
}
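
The matching reference sketch for this QK kernel: key blocks live in the cache as (block_size, head_dim) row-major, so kj.T supplies the (K, N) operand, and each block's scores land in a vertically stacked (M, N) tile of sij, as the header comment describes.

import numpy as np

def qk_matmul_n_ref(qi, key_cache, block_table, n_blocks, M, K, N):
    """sij[i*M:(i+1)*M, :] = qi @ kj.T for each of n_blocks key blocks."""
    sij = np.zeros((n_blocks * M, N), dtype=np.float32)
    for i in range(n_blocks):
        # Key block i via the block table: (N, K) = (block_size, head_dim) in memory
        kj = key_cache[block_table[i] * N * K:][:N * K].reshape(N, K)
        sij[i * M:(i + 1) * M] = qi.astype(np.float32) @ kj.T.astype(np.float32)
    return sij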
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
// Hub stub: no-op kernel entry point (listed as "hub stubs" in the commit message).
#include <cstdint>
#include <pto/pto-inst.hpp>

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

constexpr int M = 16;
constexpr int K = 16;
constexpr int N = 16;

extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {}
