Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
*/
// PV Matmul Kernel: pij(M, K) @ vj(K, N) -> oi_new(M, N)
//
// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16)
// Supports three tile configurations via runtime dispatch:
//   Legacy: (16,  16) @ ( 16,  16) -> (16,  16)
//   Case1:  (16, 128) @ (128, 128) -> (16, 128)
//   Case2:  (64,  64) @ ( 64, 128) -> (64, 128)
//
// pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT).
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.
Expand Down Expand Up @@ -67,15 +69,17 @@ static __aicore__ void pv_matmul_impl(__gm__ Tensor *pij, __gm__ Tensor *vj, __g
TASSIGN(bTile, 0x0);
TASSIGN(cTile, 0x0);

// Load pij and vj to L1
// Load pij and vj to L1 with separate events for pipeline overlap
TLOAD(aMatTile, pijGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // A load done
TLOAD(bMatTile, vjGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // B load done

set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// Move A to L0A as soon as A load completes (B may still be loading)
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);

// Move to L0A/L0B
TMOV(aTile, aMatTile);
// Move B to L0B after B load completes
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
TMOV(bTile, bMatTile);

set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
Expand All @@ -97,6 +101,13 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
__gm__ Tensor *pij = reinterpret_cast<__gm__ Tensor *>(args[0]);
__gm__ Tensor *vj = reinterpret_cast<__gm__ Tensor *>(args[1]);
__gm__ Tensor *oi_new = reinterpret_cast<__gm__ Tensor *>(args[2]);

pv_matmul_impl<16, 16, 16>(pij, vj, oi_new);
uint64_t q_tile_size = static_cast<uint64_t>(pij->shapes[0]);

if (q_tile_size == 16 && pij->shapes[1] <= 16) {
pv_matmul_impl<16, 16, 16>(pij, vj, oi_new);
} else if (q_tile_size == 16) {
pv_matmul_impl<16, 128, 128>(pij, vj, oi_new);
} else {
pv_matmul_impl<64, 64, 128>(pij, vj, oi_new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
*/
// QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N)
//
// Fixed tile size: (16, 16) @ (16, 16).T -> (16, 16)
// Supports three tile configurations via runtime dispatch:
//   Legacy: (16,  16) @ ( 16,  16).T -> (16,  16)
//   Case1:  (16, 128) @ (128, 128).T -> (16, 128)
//   Case2:  (64, 128) @ (128,  64).T -> (64,  64)
//
// kj is stored as (N, K) = (block_size, head_dim) in row-major memory.
// This is equivalent to (K, N) in column-major (DN) layout.
Expand Down Expand Up @@ -68,15 +70,17 @@ static __aicore__ void qk_matmul_impl(__gm__ Tensor *qi, __gm__ Tensor *kj, __gm
TASSIGN(bTile, 0x0);
TASSIGN(cTile, 0x0);

// Load A and B to L1
// Load A and B to L1 with separate events for pipeline overlap
TLOAD(aMatTile, qiGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // A load done
TLOAD(bMatTile, kjGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // B load done

set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// Move A to L0A as soon as A load completes (B may still be loading)
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);

// Move from L1 to L0A/L0B
TMOV(aTile, aMatTile);
// Move B to L0B after B load completes
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
TMOV(bTile, bMatTile);

set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
Expand All @@ -98,6 +102,13 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
__gm__ Tensor *qi = reinterpret_cast<__gm__ Tensor *>(args[0]);
__gm__ Tensor *kj = reinterpret_cast<__gm__ Tensor *>(args[1]);
__gm__ Tensor *sij = reinterpret_cast<__gm__ Tensor *>(args[2]);

qk_matmul_impl<16, 16, 16>(qi, kj, sij);
uint64_t q_tile_size = static_cast<uint64_t>(qi->shapes[0]);

if (q_tile_size == 16 && qi->shapes[1] <= 16) {
qk_matmul_impl<16, 16, 16>(qi, kj, sij);
} else if (q_tile_size == 16) {
qk_matmul_impl<16, 128, 128>(qi, kj, sij);
} else {
qk_matmul_impl<64, 128, 64>(qi, kj, sij);
}
}
Loading
Loading