diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 00000000..9fc8b539 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +// SplitK PV Matmul Kernel: Accumulated P @ V across n_blocks +// +// Processes n_blocks blocks using SplitK accumulation pattern: +// Block 0: TMATMUL(C, A, B) — initialize accumulator +// Block i: TMATMUL_ACC(C, C, A, B) — accumulate into same C +// +// Per-block pij addresses: contiguous slices of pij_buf (n_blocks * M * K) +// Per-block vj addresses: value_cache base + block_indices lookup +// Single output: oi_new (M, N) fp32 = sum of P_i @ V_i across all blocks +// +// Optimizations: +// - Double-buffered L1 tiles (ping/pong for A and B via MTE2) +// - Double-buffered L0 tiles (ping/pong for L0A and L0B via MTE1) +// - TLOAD(next) overlaps with TMATMUL(current) via MTE2/M-pipe parallelism +// - Canonical 3-stage pipeline: TLOAD(MTE2) → TMOV(MTE1) → TMATMUL(M) +// - Reverse-dependency events ensure buffer safety across iterations +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128) -> (16, 128) +// Case2: (64, 64) @ ( 64, 128) -> (64, 128) +// +// pij is bfloat16 (from softmax_prepare TCVT). +// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +// pv_matmul_n_impl: accumulates the n_blocks products P_i @ V_i into a single accumulator tile and stores the result once to oi_new. +// pij_buf: P tiles read back-to-back, M * K elements per block, starting at pij_buf->start_offset. +// value_cache: V blocks; block i is read at element offset bt[bt_offset + i] * K * N from the buffer base. +// block_table_t: int32 block indices, indexed from bt_offset. +// oi_new: destination of the accumulated result, written at oi_new->start_offset. +// NOTE(review): value_cache->start_offset and block_table_t->start_offset are not applied below — confirm callers always pass offset-0 views for these two tensors. +template +static __aicore__ void pv_matmul_n_impl( + __gm__ Tensor *pij_buf, __gm__ Tensor *value_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *oi_new, + uint64_t n_blocks, uint64_t bt_offset +) { + // Decode 4D semantic: batch/q_len are constexpr 1.
+ static constexpr int BATCH = 1; + static constexpr int Q_LEN = 1; + + __gm__ bfloat16_t *pij_base = reinterpret_cast<__gm__ bfloat16_t *>(pij_buf->buffer.addr) + pij_buf->start_offset; + __gm__ bfloat16_t *val_base = reinterpret_cast<__gm__ bfloat16_t *>(value_cache->buffer.addr); + __gm__ float *oi_base = reinterpret_cast<__gm__ float *>(oi_new->buffer.addr) + oi_new->start_offset; + __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr); + + using GlobalA = GlobalTensor, Stride<1, M * K, M * K, K, 1>>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride<1, M * N, M * N, N, 1>>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + // L1 memory layout: double-buffered A and B tiles (tightly packed) + constexpr int kATileBytes = M * K * static_cast(sizeof(bfloat16_t)); + constexpr int kBTileBytes = K * N * static_cast(sizeof(bfloat16_t)); + + TileMatA aMatTile[2]; + TileMatB bMatTile[2]; + TASSIGN(aMatTile[0], 0x0); + TASSIGN(aMatTile[1], kATileBytes); + TASSIGN(bMatTile[0], 2 * kATileBytes); + TASSIGN(bMatTile[1], 2 * kATileBytes + kBTileBytes); + + // L0 memory layout: double-buffered L0A and L0B, single accumulator L0C + LeftTile aTile[2]; + RightTile bTile[2]; + AccTile cTile; + TASSIGN(aTile[0], 0x0); + TASSIGN(aTile[1], kATileBytes); + TASSIGN(bTile[0], 0x0); + TASSIGN(bTile[1], kBTileBytes); + TASSIGN(cTile, 0x0); + + GlobalOut oiGlobal(oi_base); + + // Seed reverse-dependency flags: all ping/pong buffers initially free + // PIPE_MTE1 → PIPE_MTE2: L1 buffer [0/1] safe for TLOAD to overwrite + // PIPE_M → PIPE_MTE1: L0 buffer [0/1] safe for TMOV to overwrite + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + // Event-id convention inside the loop: the reverse deps (MTE1→MTE2, M→MTE1) and the MTE1→M forward dep use the buffer parity `cur` as event id; + // the MTE2→MTE1 forward deps use EVENT_ID0 for "A loaded" and EVENT_ID1 for "B loaded". + for (uint64_t i = 0; i < n_blocks; i++) { + int cur = static_cast(i % 2); +
+ GlobalA pijGlobal(pij_base + i * M * K); + GlobalB vjGlobal(val_base + bt[bt_offset + i] * K * N); + + // Stage 1: TLOAD (MTE2: GM → L1[cur]) + // Wait for MTE1 to release L1[cur] (reverse dep from previous iteration) + wait_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur); + TLOAD(aMatTile[cur], pijGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: A in L1 ready + TLOAD(bMatTile[cur], vjGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: B in L1 ready + + // Stage 2: TMOV (MTE1: L1[cur] → L0[cur]) + // Wait for M-pipe to release L0[cur] (reverse dep from previous iteration) + wait_flag(PIPE_M, PIPE_MTE1, (event_t)cur); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: wait A loaded + TMOV(aTile[cur], aMatTile[cur]); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: wait B loaded + TMOV(bTile[cur], bMatTile[cur]); + set_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur); // reverse: release L1[cur] + + // Stage 3: TMATMUL (M-pipe: L0A[cur] × L0B[cur] → L0C) + set_flag(PIPE_MTE1, PIPE_M, (event_t)cur); // forward: L0[cur] ready + wait_flag(PIPE_MTE1, PIPE_M, (event_t)cur); + if (i == 0) { + TMATMUL(cTile, aTile[cur], bTile[cur]); + } else { + TMATMUL_ACC(cTile, cTile, aTile[cur], bTile[cur]); + } + set_flag(PIPE_M, PIPE_MTE1, (event_t)cur); // reverse: release L0[cur] + } + + // NOTE(review): with n_blocks == 0 the loop body never runs, so cTile is stored below without ever being written by TMATMUL — confirm the orchestrator never submits an empty group. + // Drain outstanding reverse-dependency flags + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(oiGlobal, cTile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +// kernel_entry: reads args[0..3] as Tensor pointers (pij_buf, value_cache, block_table_t, oi_new) and args[4..5] as n_blocks / bt_offset, then dispatches on pij_buf->shapes[2]. +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *pij_buf = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *value_cache = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *block_table_t =
+ reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *oi_new = reinterpret_cast<__gm__ Tensor *>(args[3]); + uint64_t n_blocks = static_cast(args[4]); + uint64_t bt_offset = static_cast(args[5]); + + // pij_buf is 4D (1, 1, q_tile, n_blocks*block_size) to match qk's 4D output. + // Any q_tile_size other than 16 takes the else branch and runs the <64, 64, 128> configuration. + uint64_t q_tile_size = static_cast(pij_buf->shapes[2]); + + if (q_tile_size == 16) { + pv_matmul_n_impl<16, 128, 128>(pij_buf, value_cache, block_table_t, oi_new, n_blocks, bt_offset); + } else { + pv_matmul_n_impl<64, 64, 128>(pij_buf, value_cache, block_table_t, oi_new, n_blocks, bt_offset); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 00000000..2577fd4e --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +// Multi-block QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) for each block +// +// Processes n_blocks blocks in a single kernel invocation. +// Per-block kj addresses computed from key_cache base + block_indices lookup.
+// qi is shared across all blocks (same query head against different key blocks). +// +// Output layout: n_blocks contiguous (M, N) tiles stacked vertically. +// Block i occupies sij[i*M : (i+1)*M, 0:N]. +// +// Optimizations: +// - qi TLOAD hoisted before the loop (constant across all iterations) +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128).T -> (16, 128) +// Case2: (64, 128) @ (128, 64).T -> (64, 64) +// +// Template: M=q_tile, K=head_dim, N=block_size + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +// qk_matmul_n_impl: for each of the n_blocks key blocks, multiplies the query tile by the key block and stores one (M, N) result tile at sij_base + i * M * N. +// qi is loaded into L1 once before the loop; each iteration loads only the key block selected by bt[bt_offset + i]. +// Every tile (L1 and L0) is single-buffered, so iterations run back to back rather than pipelined. +// NOTE(review): key_cache->start_offset and block_table_t->start_offset are not applied below — confirm callers always pass offset-0 views for these two tensors. +template +static __aicore__ void qk_matmul_n_impl( + __gm__ Tensor *qi, __gm__ Tensor *key_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *sij_buf, + uint64_t n_blocks, uint64_t bt_offset +) { + // Decode 4D query view: batch/q_len are constexpr 1. + static constexpr int BATCH = 1; + static constexpr int Q_LEN = 1; + + __gm__ bfloat16_t *qi_base = reinterpret_cast<__gm__ bfloat16_t *>(qi->buffer.addr) + qi->start_offset; + __gm__ bfloat16_t *key_base = reinterpret_cast<__gm__ bfloat16_t *>(key_cache->buffer.addr); + __gm__ float *sij_base = reinterpret_cast<__gm__ float *>(sij_buf->buffer.addr) + sij_buf->start_offset; + __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr); + + using GlobalA = GlobalTensor, Stride<1, M * K, M * K, K, 1>>; + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride<1, M * N, M * N, N, 1>>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Hoist qi TLOAD before the loop (qi is
+ constant across all blocks) + GlobalA qiGlobal(qi_base); + TLOAD(aMatTile, qiGlobal); + + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalB kjGlobal(key_base + bt[bt_offset + i] * N * K); + GlobalOut sijGlobal(sij_base + i * M * N); + + // Load only B each iteration (qi already in L1 from hoist) + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // TMOV qi from L1→L0A (re-copy since TMATMUL consumed L0A) and kj from L1→L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, cTile); + + // Full barrier between iterations (skipped after the last block); presumably this keeps the next TLOAD/TMOV/TMATMUL from overwriting the single-buffered tiles while they are still in use. + if (i + 1 < n_blocks) { + pipe_barrier(PIPE_ALL); + } + } + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +// kernel_entry: reads args[0..3] as Tensor pointers (qi, key_cache, block_table_t, sij_buf) and args[4..5] as n_blocks / bt_offset, then dispatches on qi->shapes[2]. +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *qi = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *key_cache = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *block_table_t = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *sij_buf = reinterpret_cast<__gm__ Tensor *>(args[3]); + uint64_t n_blocks = static_cast(args[4]); + uint64_t bt_offset = static_cast(args[5]); + + // qi is a 4D view (batch, q_len, num_heads_tile, head_dim); decode fixes batch=q_len=1. + // Any q_tile_size other than 16 takes the else branch and runs the <64, 128, 64> configuration.
+ uint64_t q_tile_size = static_cast(qi->shapes[2]); + + if (q_tile_size == 16) { + qk_matmul_n_impl<16, 128, 128>(qi, key_cache, block_table_t, sij_buf, n_blocks, bt_offset); + } else { + qk_matmul_n_impl<64, 128, 64>(qi, key_cache, block_table_t, sij_buf, n_blocks, bt_offset); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 00000000..6f238ba9 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,261 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +// Online Softmax Update + Normalize Kernel (AIV) +// +// Operates on full tiles where M=q_tile_size, N=head_dim (128): +// Case1: oi/oi_new are (16, 128), mij/lij/mi/li are 16-element vectors +// Case2: oi/oi_new are (64, 128), mij/lij/mi/li are 64-element vectors +// +// Scalar layout strategy using TRESHAPE (zero-copy UB reshape): +// Scalars loaded as DN ColMajor (M, 1) for TROWEXPANDMUL/TROWEXPANDDIV. +// For element-wise ops (TMAX, TSUB, TEXP, etc.), TRESHAPE to RowMajor (1, M).
+// After arithmetic, TRESHAPE back to ColMajor (M, 1) for row-broadcast ops. +// This eliminates the GM round-trip (TSTORE ND → TLOAD DN) used in the original. + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +// online_update_impl: merges one group's (mij, lij, oi_new) into the running accumulators (mi, li, oi). +// is_first != 0: copies mij→mi, lij→li and oi_new→oi; if is_last is also set, additionally writes dst = oi_new / lij. +// is_first == 0: computes mi_new = max(mi, mij), alpha = exp(mi - mi_new), beta = exp(mij - mi_new), li_new = alpha*li + beta*lij and oi = alpha*oi + beta*oi_new, +// then stores mi_new and li_new together with either the updated oi (is_last == 0) or dst = oi / li_new (is_last != 0). +// All seven tensors are addressed at buffer.addr + start_offset. +template +static __aicore__ void online_update_impl( + __gm__ Tensor *mij, __gm__ Tensor *lij, __gm__ Tensor *oi_new, __gm__ Tensor *mi, __gm__ Tensor *li, + __gm__ Tensor *oi, uint64_t is_first, uint64_t is_last, __gm__ Tensor *dst +) { + __gm__ float *mij_ptr = reinterpret_cast<__gm__ float *>(mij->buffer.addr); + __gm__ float *lij_ptr = reinterpret_cast<__gm__ float *>(lij->buffer.addr); + __gm__ float *oi_new_ptr = reinterpret_cast<__gm__ float *>(oi_new->buffer.addr); + __gm__ float *mi_ptr = reinterpret_cast<__gm__ float *>(mi->buffer.addr); + __gm__ float *li_ptr = reinterpret_cast<__gm__ float *>(li->buffer.addr); + __gm__ float *oi_ptr = reinterpret_cast<__gm__ float *>(oi->buffer.addr); + __gm__ float *dst_ptr = reinterpret_cast<__gm__ float *>(dst->buffer.addr); + + // Aligned rows for ColMajor DN tiles (32-byte alignment) + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + // --- GlobalTensor types --- + + // Decode 4D semantic: batch/q_len are constexpr 1. + static constexpr int BATCH = 1; + static constexpr int Q_LEN = 1; + + // 4D data views (1, 1, q_tile, head_dim) — oi, oi_new, dst.
+ using GlobalData4D = GlobalTensor, Stride<1, M * N, M * N, N, 1>>; + + // Scalar DN: M contiguous floats as (kAlignedRows, 1) ColMajor for TROWEXPAND ops and loading + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // Scalar ND: for storing mi_new and li_new back to GM + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + + // --- GlobalTensor instances --- + + GlobalData4D oiNewGlobal(oi_new_ptr + oi_new->start_offset); + GlobalData4D oiGlobal(oi_ptr + oi->start_offset); + GlobalData4D dstGlobal(dst_ptr + dst->start_offset); + + // DN globals for loading scalars as ColMajor + GlobalScalarDN mijGlobalDN(mij_ptr + mij->start_offset); + GlobalScalarDN lijGlobalDN(lij_ptr + lij->start_offset); + GlobalScalarDN miGlobalDN(mi_ptr + mi->start_offset); + GlobalScalarDN liGlobalDN(li_ptr + li->start_offset); + + // ND globals for storing scalar results + GlobalScalarND miGlobalND(mi_ptr + mi->start_offset); + GlobalScalarND liGlobalND(li_ptr + li->start_offset); + + // --- Tile types --- + + using TileDataMxN = Tile; + using TileScalarDN = Tile; + + // RowMajor (1, M) tiles for element-wise arithmetic via TRESHAPE + using TileScalarRow = Tile; + + // ND tile for storing back to GM + using TileScalarND = + Tile; + + // --- UB memory layout --- + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Data tiles + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + // Scalar DN tiles loaded from GM (ColMajor) + TileScalarDN mijDN, lijDN, miDN, liDN; + + // Temporary DN tiles for results + TileScalarDN miNewDN, alphaDN, betaDN, liNewDN, tmpDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijDN, 2 * kDataBytes); + TASSIGN(lijDN, 2 * kDataBytes + kScalarDNBytes); + TASSIGN(miDN, 2 * kDataBytes + 2 * kScalarDNBytes); + TASSIGN(liDN, 2 *
+ kDataBytes + 3 * kScalarDNBytes); + TASSIGN(miNewDN, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaDN, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewDN, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpDN, 2 * kDataBytes + 8 * kScalarDNBytes); + + if (is_first) { + // --- First block: copy inputs to accumulators --- + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Store mi = mij, li = lij, oi = oi_new + // Alias ND tiles to same UB as DN tiles for ND-format store + TileScalarND mijND, lijND; + TASSIGN(mijND, 2 * kDataBytes); // alias same UB as mijDN + TASSIGN(lijND, 2 * kDataBytes + kScalarDNBytes); // alias same UB as lijDN + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); // mi = mij + TSTORE(liGlobalND, lijND); // li = lij + TSTORE(oiGlobal, oiNewTile); // oi = oi_new + + if (is_last) { + // Single block: normalize dst = oi_new / lij + // lijDN already in ColMajor DN format, use directly for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TROWEXPANDDIV(oiNewTile, oiNewTile, lijDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + // --- Subsequent blocks: accumulate --- + + // Load all inputs as DN (ColMajor) + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); + TLOAD(miDN, miGlobalDN); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TRESHAPE: ColMajor(M,1) → RowMajor(1,M) for element-wise arithmetic + TileScalarRow miRow, mijRow, liRow, lijRow; + TRESHAPE(miRow, miDN); + TRESHAPE(mijRow, mijDN); + TRESHAPE(liRow, liDN); +
+ TRESHAPE(lijRow, lijDN); + + // Scalar arithmetic in RowMajor (1, M) layout + TileScalarRow miNewRow, alphaRow, betaRow, liNewRow, tmpRow; + TASSIGN(miNewRow, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaRow, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaRow, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewRow, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpRow, 2 * kDataBytes + 8 * kScalarDNBytes); + + TMAX(miNewRow, miRow, mijRow); // mi_new = max(mi, mij) + pipe_barrier(PIPE_V); + TSUB(alphaRow, miRow, miNewRow); // alpha_exp = mi - mi_new + pipe_barrier(PIPE_V); + TEXP(alphaRow, alphaRow); // alpha = exp(mi - mi_new) + pipe_barrier(PIPE_V); + TSUB(betaRow, mijRow, miNewRow); // beta_exp = mij - mi_new + pipe_barrier(PIPE_V); + TEXP(betaRow, betaRow); // beta = exp(mij - mi_new) + pipe_barrier(PIPE_V); + TMUL(tmpRow, alphaRow, liRow); // alpha * li + pipe_barrier(PIPE_V); + TMUL(liNewRow, betaRow, lijRow); // beta * lij + pipe_barrier(PIPE_V); + TADD(liNewRow, tmpRow, liNewRow); // li_new = alpha*li + beta*lij + + // TRESHAPE back: RowMajor(1,M) → ColMajor(M,1) for TROWEXPANDMUL + pipe_barrier(PIPE_V); + TRESHAPE(alphaDN, alphaRow); + TRESHAPE(betaDN, betaRow); + + // Scale data tiles using row-broadcast multiply + TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha + pipe_barrier(PIPE_V); + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + + // Store mi_new and li_new to GM (ND format) + // Alias ND tiles to the same UB locations as miNewRow and liNewRow + TileScalarND miNewND, liNewND; + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(liNewND, 2 * kDataBytes + 7 * kScalarDNBytes); + + if (is_last) { + // Normalize and output: dst = oi / li_new + // This branch stores mi, li and dst only; the updated oi tile is divided in place and is not written back to oi. + TRESHAPE(liNewDN, liNewRow); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liNewDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3,
+ EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new + TSTORE(dstGlobal, oiTile); + } else { + // Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new + TSTORE(oiGlobal, oiTile); + } + } + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} + +// kernel_entry: reads args[0..6] as Tensor pointers (mij, lij, oi_new, mi, li, oi, dst) and args[7..8] as the is_first / is_last flags, then dispatches on mij->shapes[2]. +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *mij = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *lij = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *oi_new = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *mi = reinterpret_cast<__gm__ Tensor *>(args[3]); + __gm__ Tensor *li = reinterpret_cast<__gm__ Tensor *>(args[4]); + __gm__ Tensor *oi = reinterpret_cast<__gm__ Tensor *>(args[5]); + __gm__ Tensor *dst = reinterpret_cast<__gm__ Tensor *>(args[6]); + uint64_t is_first = static_cast(args[7]); + uint64_t is_last = static_cast(args[8]); + // mij is 3D (1, 1, q_tile) to match softmax's 3D scalar output. + // Any q_tile_size other than 16 takes the else branch and runs the <64, 128> configuration. + uint64_t q_tile_size = static_cast(mij->shapes[2]); + + if (q_tile_size == 16) { + online_update_impl<16, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } else { + online_update_impl<64, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 00000000..0203df19 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp @@ -0,0 +1,276 @@ +/* + * Copyright (c) PyPTO Contributors.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +// Two-Pass Softmax Kernel (AIV) for n_blocks tiles +// +// Input: sij_buf (n_blocks * M, N) fp32 — QK results stacked vertically +// Output: pij_buf (n_blocks * M, N) bf16 — attention weights per block +// mij (M,) fp32 — global row max across all blocks +// lij (M,) fp32 — total row sum across all blocks +// +// Pass 1: Iterate over n_blocks tiles, apply scale, mask last block, +// find global m = max over all blocks of rowmax(S_i * scale) +// Uses TRESHAPE for DN↔Row conversion to keep globalMax in UB +// (eliminates 63 × 4 GM round-trip operations). +// Pass 2: Iterate again, compute P_i = exp(S_i * scale - m) -> bf16, +// accumulate l = sum over all blocks of rowsum(P_i) +// Uses double-buffered sij tiles to overlap TLOAD with computation. +// +// Two-pass ensures all P_i tiles share the same scale (global max), +// enabling direct TMATMUL_ACC accumulation in the PV kernel. 
+// +// Supports two tile configurations via runtime dispatch: +// Case1: M=16, N=128 (q_tile=16, block_size=128) +// Case2: M=64, N=64 (q_tile=64, block_size=64) + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +// softmax_prepare_n_impl: pass 1 reduces the row max of sij * scale_value over all n_blocks tiles and stores it to mij; +// pass 2 writes pij = exp(sij * scale_value - max) per tile (converted with TCVT) and stores the accumulated row sum to lij. +// valid_len_last: when smaller than N, the last tile is padded in place with TFILLPAD_INPLACE before it is used (the pad value comes from the tile types — presumably chosen so padded columns vanish after exp; confirm). +// NOTE(review): with n_blocks == 0 neither loop body runs, so globalMaxRow and sumAccTile are stored to mij / lij without ever being written — confirm empty groups are never submitted. +template +static __aicore__ void softmax_prepare_n_impl( + __gm__ Tensor *sij_buf, __gm__ Tensor *pij_buf, __gm__ Tensor *mij, __gm__ Tensor *lij, float scale_value, + uint64_t n_blocks, uint64_t valid_len_last +) { + __gm__ float *sij_base = reinterpret_cast<__gm__ float *>(sij_buf->buffer.addr) + sij_buf->start_offset; + __gm__ bfloat16_t *pij_base = reinterpret_cast<__gm__ bfloat16_t *>(pij_buf->buffer.addr) + pij_buf->start_offset; + __gm__ float *mij_addr = reinterpret_cast<__gm__ float *>(mij->buffer.addr) + mij->start_offset; + __gm__ float *lij_addr = reinterpret_cast<__gm__ float *>(lij->buffer.addr) + lij->start_offset; + + // Decode 4D semantic: batch/q_len are constexpr 1. + static constexpr int BATCH = 1; + static constexpr int Q_LEN = 1; + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + + // --- GlobalTensor types --- + // 4D data views (1, 1, q_tile, n_blocks*block_size) for sij/pij. + using GlobalDataMxN = GlobalTensor, Stride<1, M * N, M * N, N, 1>>; + using GlobalDataMxN_bf16 = GlobalTensor, Stride<1, M * N, M * N, N, 1>>; + // DN/ND scalar globals stay 2D: scalar vectors only need per-element layout.
+ using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + + // --- Tile types --- + using TileSijDyn = Tile; + using TileSijPad = Tile; + using TileVecMxN = Tile; + using TileVecMxN_bf16 = Tile; + using TileScalarDN = Tile; + using TileScalarND = + Tile; + // RowMajor (1, M) tile for element-wise arithmetic via TRESHAPE + using TileScalarRow = Tile; + + // --- UB memory layout (double-buffered sij) --- + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Double-buffered sij tiles + TileVecMxN sijTile_A; + TileSijPad sijPadTile_A; + TileVecMxN sijTile_B; + TileSijPad sijPadTile_B; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileVecMxN sumAccTile; + TileScalarDN localMaxDN; + TileScalarDN globalMaxDN; + TileScalarDN sumDN; + TileVecMxN_bf16 pijBf16Tile; + + // TRESHAPE aliases (same UB address as their DN counterparts) + TileScalarRow localMaxRow; + TileScalarRow globalMaxRow; + + // ND alias for storing globalMax to GM + TileScalarND globalMaxND; + + TASSIGN(sijTile_A, 0x0); + TASSIGN(sijPadTile_A, 0x0); + TASSIGN(sijTile_B, kDataBytes); + TASSIGN(sijPadTile_B, kDataBytes); + TASSIGN(pijTile, 2 * kDataBytes); + TASSIGN(tmpTile, 3 * kDataBytes); + TASSIGN(sumAccTile, 4 * kDataBytes); + int scalarBase = 5 * kDataBytes; + TASSIGN(localMaxDN, scalarBase); + TASSIGN(localMaxRow, scalarBase); // alias: same UB as localMaxDN + TASSIGN(globalMaxDN, scalarBase + kScalarDNBytes); + TASSIGN(globalMaxRow, scalarBase + kScalarDNBytes); // alias: same UB as globalMaxDN + TASSIGN(globalMaxND, scalarBase + kScalarDNBytes); // alias: same UB as globalMaxDN + TASSIGN(sumDN, scalarBase + 2 * kScalarDNBytes); + TASSIGN(pijBf16Tile, scalarBase + 3 * kScalarDNBytes); + + // GM aliases (mij/lij output buffers) + GlobalScalarND mijGlobalND(mij_addr); + GlobalScalarDN lijGlobalDN(lij_addr); + + // ======== Pass 1: Find global
+ row max via TRESHAPE (no GM round-trip) ======== + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalDataMxN sijGlobal(sij_base + i * M * N); + TLOAD(sijTile_A, sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + if (i == n_blocks - 1 && valid_len_last < static_cast(N)) { + TileSijDyn sijDynTile(static_cast(valid_len_last)); + TASSIGN(sijDynTile, 0x0); + TFILLPAD_INPLACE(sijPadTile_A, sijDynTile); + pipe_barrier(PIPE_V); + } + + TMULS(sijTile_A, sijTile_A, scale_value); + pipe_barrier(PIPE_V); + TROWMAX(localMaxDN, sijTile_A, tmpTile); + pipe_barrier(PIPE_V); + + // TRESHAPE: ColMajor(M,1) → RowMajor(1,M) for element-wise TMAX + TRESHAPE(localMaxRow, localMaxDN); + if (i == 0) { + TMAX(globalMaxRow, localMaxRow, localMaxRow); + } else { + TMAX(globalMaxRow, globalMaxRow, localMaxRow); + } + pipe_barrier(PIPE_V); + // NOTE(review): nothing here signals V→MTE2 before the next iteration's TLOAD into sijTile_A; pipe_barrier(PIPE_V) presumably orders only the vector pipe, + // so the next load may start while TMULS/TROWMAX are still reading this tile — confirm whether a reverse set_flag/wait_flag(PIPE_V, PIPE_MTE2, ...) is needed. + } + + // TRESHAPE back: RowMajor(1,M) → ColMajor(M,1) for Pass 2's TROWEXPANDSUB + TRESHAPE(globalMaxDN, globalMaxRow); + + // Store final global max to mij for online_update to consume + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobalND, globalMaxND); + + // ======== Pass 2: Compute softmax with double-buffered sij ======== + // globalMaxDN is already in UB from TRESHAPE — no reload needed. + // Sync MTE3→MTE2 to ensure the mij TSTORE completed before first sij TLOAD.
+ set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + // Pre-load first sij tile into buffer A + GlobalDataMxN sijGlobal_0(sij_base); + TLOAD(sijTile_A, sijGlobal_0); + + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalDataMxN_bf16 pijGlobal(pij_base + i * M * N); + + // Wait for current tile's TLOAD to complete + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TFILLPAD on current buffer if last block with partial valid length + if (i == n_blocks - 1 && valid_len_last < static_cast(N)) { + TileSijDyn curSijDyn(static_cast(valid_len_last)); + if (i % 2 == 0) { + TASSIGN(curSijDyn, 0x0); + TFILLPAD_INPLACE(sijPadTile_A, curSijDyn); + } else { + TASSIGN(curSijDyn, static_cast(kDataBytes)); + TFILLPAD_INPLACE(sijPadTile_B, curSijDyn); + } + pipe_barrier(PIPE_V); + } + + // Compute on current buffer (select A or B based on iteration parity) + if (i % 2 == 0) { + TMULS(sijTile_A, sijTile_A, scale_value); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(pijTile, sijTile_A, globalMaxDN); + } else { + TMULS(sijTile_B, sijTile_B, scale_value); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(pijTile, sijTile_B, globalMaxDN); + } + pipe_barrier(PIPE_V); + TEXP(pijTile, pijTile); + pipe_barrier(PIPE_V); + TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND); + pipe_barrier(PIPE_V); + // Convert back to fp32 so the row sum below is accumulated from the rounded values (presumably to match what the PV matmul consumes). + TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND); + + pipe_barrier(PIPE_V); + if (i == 0) { + TMULS(sumAccTile, pijTile, 1.0f); + } else { + TADD(sumAccTile, sumAccTile, pijTile); + } + + // Store pij (must complete before next iteration's TCVT overwrites pijBf16Tile) + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(pijGlobal, pijBf16Tile); + + // Prefetch next sij into alternate buffer (after TSTORE to avoid UB race) + // NOTE(review): this TLOAD is issued only after the MTE3→MTE2 wait that follows the pij TSTORE, so it appears to start after the current tile's vector work rather than alongside it — confirm the intended TLOAD/compute overlap. + if (i + 1 < n_blocks) { + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + GlobalDataMxN
+ sijGlobal_next(sij_base + (i + 1) * M * N); + if (i % 2 == 0) { + TLOAD(sijTile_B, sijGlobal_next); + } else { + TLOAD(sijTile_A, sijGlobal_next); + } + } + } + + // Compute final row sum from accumulated pij values + pipe_barrier(PIPE_V); + TROWSUM(sumDN, sumAccTile, tmpTile); + + // Store lij (total sum). mij already stored after Pass 1. + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(lijGlobalDN, sumDN); + + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} + +// kernel_entry: reads args[0..3] as Tensor pointers (sij_buf, pij_buf, mij, lij), args[4] as the scale, args[5..6] as n_blocks / valid_len_last, then dispatches on sij_buf->shapes[2]. +// The scale is recovered by writing args[4] to the uint64_t member of a union and reading its float member. +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *sij_buf = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *pij_buf = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *mij = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *lij = reinterpret_cast<__gm__ Tensor *>(args[3]); + union { + uint64_t u; + float f; + } scale_conv; + scale_conv.u = static_cast(args[4]); + float scale_value = scale_conv.f; + uint64_t n_blocks = static_cast(args[5]); + uint64_t valid_len_last = static_cast(args[6]); + + // sij_buf is 4D (1, 1, q_tile, n_blocks*block_size) to match qk's 4D output semantic. + // Any q_tile_size other than 16 takes the else branch and runs the <64, 64> configuration. + uint64_t q_tile_size = static_cast(sij_buf->shapes[2]); + + if (q_tile_size == 16) { + softmax_prepare_n_impl<16, 128>(sij_buf, pij_buf, mij, lij, scale_value, n_blocks, valid_len_last); + } else { + softmax_prepare_n_impl<64, 64>(sij_buf, pij_buf, mij, lij, scale_value, n_blocks, valid_len_last); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp new file mode 100644 index 00000000..45f7ba75 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) PyPTO Contributors.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +/** + * Paged Attention Orchestration - 4D input shapes, N_UNROLL=64, 4 Tasks Per Group + * + * Batches up to N_UNROLL blocks per group. Each group submits exactly 4 tasks: + * 1. QK matmul: qi @ K^T for n_blocks → sij_buf (1, 1, q_tile, n_blocks * block_size) + * 2. Softmax: two-pass over sij_buf → pij_buf, mi, li + * 3. PV matmul: SplitK accumulated P @ V → oi_new (1, 1, q_tile, head_dim) + * 4. Update: online softmax accumulation with group-level mi, li, oi_new + * + * Memory Layout (4D throughout): + * Query: (batch, seq_len=1, num_heads, head_dim) bf16 + * Key: (total_blocks, block_size, kv_head_num, head_dim) bf16 + * Value: (total_blocks, block_size, kv_head_num, head_dim) bf16 + * Out: (batch, seq_len=1, num_heads, head_dim) fp32 + */ + +#include +#include + +#include "pto_orchestration_api.h" + +#define N_UNROLL 64 + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 + +extern "C" { +/** + * Orchestration config — the executor reads these values to set up + * shared memory and runtime before calling aicpu_orchestration_entry. 
+ */ +__attribute__((visibility("default"))) PTO2OrchestrationConfig +aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 7, + }; +} + +__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + // Read dimensions from tensor metadata + // query: shape=[batch, seq_len, num_heads, head_dim] + uint64_t batch = orch_args.tensor(0).shapes[0]; + uint64_t num_heads = orch_args.tensor(0).shapes[2]; + uint64_t head_dim = orch_args.tensor(0).shapes[3]; + DataType data_type = orch_args.tensor(0).dtype; + + // key_cache: shape=[total_blocks, block_size, kv_head_num, head_dim] + uint64_t block_size = orch_args.tensor(1).shapes[1]; + + // block_table: shape=[batch, max_num_blocks_per_req] + uint64_t block_num = orch_args.tensor(3).shapes[1]; + + // scale from scalar arg + uint64_t scale_value = orch_args.scalar(0); + uint64_t q_tile = std::min(num_heads, 128UL); + uint64_t q_loop = (num_heads + q_tile - 1) / q_tile; + + // External 4D tensors inherit shape/dtype from TaskArg (golden provides 4D). + Tensor query = from_tensor_arg(orch_args.tensor(0)); + Tensor key_cache = from_tensor_arg(orch_args.tensor(1)); + Tensor value_cache = from_tensor_arg(orch_args.tensor(2)); + Tensor block_table = from_tensor_arg(orch_args.tensor(3)); + Tensor out = from_tensor_arg(orch_args.tensor(5)); + + int *host_context_lens = orch_args.tensor(4).data_as(); + + // Loop-invariant shape descriptors: 4D data tiles (1, 1, q_tile, head_dim), + // 3D scalar vectors (1, 1, q_tile). 
+ uint32_t tile4d_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)head_dim}; + uint32_t scalar_shapes[3] = {1, 1, (uint32_t)q_tile}; + TensorCreateInfo tile4d_ci(tile4d_shapes, 4, DataType::FLOAT32); + TensorCreateInfo scalar_ci(scalar_shapes, 3, DataType::FLOAT32); + + // Prefetch first block host_context_lens data into cache + __builtin_prefetch(&host_context_lens[0], 0, 3); + + for (uint64_t b_idx = 0; b_idx < batch; b_idx++) { + uint64_t cur_seq = host_context_lens[b_idx]; + uint64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + + // Prefetch next block host_context_lens data while processing current batch + if (b_idx + 1 < batch) { + __builtin_prefetch(&host_context_lens[b_idx + 1], 0, 3); + } + for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { + PTO2_SCOPE() { + // 4D views into query/out, matching (1, 1, q_tile, head_dim). + uint32_t view_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)head_dim}; + uint32_t view_offsets[4] = {(uint32_t)b_idx, 0, (uint32_t)(q_idx * q_tile), 0}; + Tensor qi = query.view(view_shapes, view_offsets); + Tensor out_view = out.view(view_shapes, view_offsets, true); + + // Per-group accumulators: oi (4D data), mi_update/li_update (3D scalars). + TaskOutputTensors alloc_outs = alloc_tensors(tile4d_ci, scalar_ci, scalar_ci); + const Tensor &oi = alloc_outs.get_ref(0); + const Tensor &li_update = alloc_outs.get_ref(1); + const Tensor &mi_update = alloc_outs.get_ref(2); + + // Reusable Arg objects — reset() before each use avoids + // repeated stack-frame construction in the inner loop. 
+ Arg params_qk, params_sf, params_pv, params_up; + + for (uint64_t bn = 0; bn < bn_this_batch; bn += N_UNROLL) { + uint64_t n_blocks = std::min((uint64_t)N_UNROLL, bn_this_batch - bn); + + // Valid length for last block in this group + uint64_t last_block_seq_start = (bn + n_blocks - 1) * block_size; + uint64_t valid_len_last = std::min(block_size, cur_seq - last_block_seq_start); + + // === Task 1: Batched QK matmul — produces 4D sij_buf === + uint32_t sij_buf_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)(n_blocks * block_size)}; + TensorCreateInfo sij_buf_ci(sij_buf_shapes, 4, DataType::FLOAT32); + + params_qk.reset(); + params_qk.add_input(qi); + params_qk.add_input(key_cache); + params_qk.add_input(block_table); + params_qk.add_output(sij_buf_ci); + params_qk.add_scalar(n_blocks); + params_qk.add_scalar(b_idx * block_num + bn); + TaskOutputTensors qk_outs = pto2_rt_submit_aic_task(FUNC_QK_MATMUL, params_qk); + const Tensor &sij_buf = qk_outs.get_ref(0); + + // === Task 2: Two-pass softmax — produces 4D pij_buf, 3D mi, li === + uint32_t pij_buf_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)(n_blocks * block_size)}; + TensorCreateInfo pij_buf_ci(pij_buf_shapes, 4, data_type); + + params_sf.reset(); + params_sf.add_input(sij_buf); + params_sf.add_output(pij_buf_ci); + params_sf.add_output(scalar_ci); + params_sf.add_output(scalar_ci); + params_sf.add_scalar(scale_value); + params_sf.add_scalar(n_blocks); + params_sf.add_scalar(valid_len_last); + TaskOutputTensors sf_outs = pto2_rt_submit_aiv_task(FUNC_SOFTMAX_PREPARE, params_sf); + const Tensor &pij_buf = sf_outs.get_ref(0); + const Tensor &mi = sf_outs.get_ref(1); + const Tensor &li = sf_outs.get_ref(2); + + // === Task 3: SplitK PV matmul — produces 4D oi_new === + params_pv.reset(); + params_pv.add_input(pij_buf); + params_pv.add_input(value_cache); + params_pv.add_input(block_table); + params_pv.add_output(tile4d_ci); + params_pv.add_scalar(n_blocks); + params_pv.add_scalar(b_idx * block_num + bn); + 
TaskOutputTensors pv_outs = pto2_rt_submit_aic_task(FUNC_PV_MATMUL, params_pv); + const Tensor &oi_new = pv_outs.get_ref(0); + + // === Task 4: Online update (per-group) === + uint64_t is_first = (bn == 0) ? 1 : 0; + uint64_t is_last = (bn + n_blocks >= bn_this_batch) ? 1 : 0; + + params_up.reset(); + params_up.add_input(mi); + params_up.add_input(li); + params_up.add_input(oi_new); + params_up.add_inout(mi_update); + params_up.add_inout(li_update); + params_up.add_inout(oi); + params_up.add_inout(out_view); + params_up.add_scalar(is_first); + params_up.add_scalar(is_last); + pto2_rt_submit_aiv_task(FUNC_ONLINE_UPDATE, params_up); + } + } + } + } +} + +} // extern "C" diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/test_paged_attention_unroll_4dims.py b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/test_paged_attention_unroll_4dims.py new file mode 100644 index 00000000..4ade6341 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/test_paged_attention_unroll_4dims.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Paged attention unroll with 4D input shapes (batch, seq_len, num_heads, head_dim). 
@scene_test(level=2, runtime="tensormap_and_ringbuffer")
class TestPagedAttentionUnroll4dims(SceneTestCase):
    """Scene test for unrolled paged attention whose query/out tensors are 4D."""

    # Comparison tolerances for the test.
    RTOL = 1e-3
    ATOL = 1e-3

    # One orchestration source plus four in-core kernels (func_id 0-3).
    CALLABLE = {
        "orchestration": {
            "source": "kernels/orchestration/paged_attention_orch.cpp",
            "function_name": "aicpu_orchestration_entry",
            "signature": [D.IN, D.IN, D.IN, D.IN, D.IN, D.OUT],
        },
        "incores": [
            {
                "func_id": 0,
                "source": "kernels/aic/aic_qk_matmul.cpp",
                "core_type": "aic",
                "signature": [D.IN, D.IN, D.IN, D.OUT],
            },
            {
                "func_id": 1,
                "source": "kernels/aiv/aiv_softmax_prepare.cpp",
                "core_type": "aiv",
                "signature": [D.IN, D.OUT, D.OUT, D.OUT],
            },
            {
                "func_id": 2,
                "source": "kernels/aic/aic_pv_matmul.cpp",
                "core_type": "aic",
                "signature": [D.IN, D.IN, D.IN, D.OUT],
            },
            {
                "func_id": 3,
                "source": "kernels/aiv/aiv_online_update.cpp",
                "core_type": "aiv",
                "signature": [D.IN, D.IN, D.IN, D.INOUT, D.INOUT, D.INOUT, D.INOUT],
            },
        ],
    }

    # Case2 and Case3 carry "manual": True; Case1 does not.
    CASES = [
        {
            "name": "Case1",
            "platforms": ["a2a3"],
            "config": {"aicpu_thread_num": 4, "block_dim": 24},
            "params": {
                "batch": 256,
                "num_heads": 16,
                "kv_head_num": 1,
                "head_dim": 128,
                "block_size": 128,
                "context_len": 8192,
                "max_model_len": 32768,
                "dtype": "bfloat16",
            },
        },
        {
            "name": "Case2",
            "platforms": ["a2a3"],
            "config": {"aicpu_thread_num": 4, "block_dim": 24},
            "params": {
                "batch": 64,
                "num_heads": 64,
                "kv_head_num": 1,
                "head_dim": 128,
                "block_size": 64,
                "context_len": 8192,
                "max_model_len": 32768,
                "dtype": "bfloat16",
            },
            "manual": True,
        },
        {
            "name": "Case3",
            "platforms": ["a2a3"],
            "config": {"aicpu_thread_num": 4, "block_dim": 24},
            "params": {
                "batch": 64,
                "num_heads": 64,
                "kv_head_num": 1,
                "head_dim": 256,
                "block_size": 64,
                "context_len": 8192,
                "max_model_len": 32768,
                "dtype": "bfloat16",
            },
            "manual": True,
        },
    ]

    def generate_args(self, params):
        """Wrap the shared generator's inputs, reshaping query/out to 4D."""
        generated = _pa_generate_inputs(params)
        shape_4d = (params["batch"], 1, params["num_heads"], params["head_dim"])
        specs = []
        for name, val in generated:
            if not isinstance(val, torch.Tensor):
                specs.append(Scalar(name, val))
                continue
            # Only query and out get the extra seq_len=1 axis.
            if name in ("query", "out"):
                val = val.reshape(shape_4d)
            specs.append(Tensor(name, val))
        return TaskArgsBuilder(*specs)

    def compute_golden(self, args, params):
        """Run the shared golden with `out` presented as a 3D tensor."""
        shape_3d = (params["batch"], params["num_heads"], params["head_dim"])
        tensors = {}
        for spec in args.specs:
            if isinstance(spec, Tensor):
                tensors[spec.name] = spec.value
        # Hand the golden a 3D reshape of out, then put the 4D tensor back.
        out_4d = tensors["out"]
        tensors["out"] = out_4d.reshape(shape_3d)
        _pa_compute_golden(tensors, params)
        tensors["out"] = out_4d