diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_pv_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_pv_matmul.cpp index d7e928668..80cd6c90e 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_pv_matmul.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_pv_matmul.cpp @@ -10,7 +10,9 @@ */ // PV Matmul Kernel: pij(M, K) @ vj(K, N) -> oi_new(M, N) // -// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16) +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128) -> (16, 128) +// Case2: (64, 64) @ ( 64, 128) -> (64, 128) // // pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT). // vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. @@ -67,15 +69,17 @@ static __aicore__ void pv_matmul_impl(__gm__ Tensor *pij, __gm__ Tensor *vj, __g TASSIGN(bTile, 0x0); TASSIGN(cTile, 0x0); - // Load pij and vj to L1 + // Load pij and vj to L1 with separate events for pipeline overlap TLOAD(aMatTile, pijGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // A load done TLOAD(bMatTile, vjGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // B load done - set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + // Move A to L0A as soon as A load completes (B may still be loading) wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); - - // Move to L0A/L0B TMOV(aTile, aMatTile); + // Move B to L0B after B load completes + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); TMOV(bTile, bMatTile); set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); @@ -97,6 +101,13 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { __gm__ Tensor *pij = reinterpret_cast<__gm__ Tensor *>(args[0]); __gm__ Tensor *vj = reinterpret_cast<__gm__ Tensor *>(args[1]); __gm__ Tensor *oi_new = reinterpret_cast<__gm__ Tensor *>(args[2]); - - pv_matmul_impl<16, 16, 16>(pij, vj, oi_new); + uint64_t q_tile_size = static_cast(pij->shapes[0]); + + if (q_tile_size == 16 && pij->shapes[1] <= 16) { + pv_matmul_impl<16, 16, 16>(pij, vj, oi_new); + } else if (q_tile_size == 16) { + pv_matmul_impl<16, 128, 128>(pij, vj, oi_new); + } else { + pv_matmul_impl<64, 64, 128>(pij, vj, oi_new); + } } diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_qk_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_qk_matmul.cpp index 2db78900a..bf7ea5c6c 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_qk_matmul.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aic/aic_qk_matmul.cpp @@ -10,7 +10,9 @@ */ // QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) // -// Fixed tile size: (16, 16) @ (16, 16).T -> (16, 16) +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128).T -> (16, 128) +// Case2: (64, 128) @ (128, 64).T -> (64, 64) // // kj is stored as (N, K) = (block_size, head_dim) in row-major memory. // This is equivalent to (K, N) in column-major (DN) layout. @@ -68,15 +70,17 @@ static __aicore__ void qk_matmul_impl(__gm__ Tensor *qi, __gm__ Tensor *kj, __gm TASSIGN(bTile, 0x0); TASSIGN(cTile, 0x0); - // Load A and B to L1 + // Load A and B to L1 with separate events for pipeline overlap TLOAD(aMatTile, qiGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // A load done TLOAD(bMatTile, kjGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // B load done - set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + // Move A to L0A as soon as A load completes (B may still be loading) wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); - - // Move from L1 to L0A/L0B TMOV(aTile, aMatTile); + // Move B to L0B after B load completes + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); TMOV(bTile, bMatTile); set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); @@ -98,6 +102,13 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { __gm__ Tensor *qi = reinterpret_cast<__gm__ Tensor *>(args[0]); __gm__ Tensor *kj = reinterpret_cast<__gm__ Tensor *>(args[1]); __gm__ Tensor *sij = reinterpret_cast<__gm__ Tensor *>(args[2]); - - qk_matmul_impl<16, 16, 16>(qi, kj, sij); + uint64_t q_tile_size = static_cast(qi->shapes[0]); + + if (q_tile_size == 16 && qi->shapes[1] <= 16) { + qk_matmul_impl<16, 16, 16>(qi, kj, sij); + } else if (q_tile_size == 16) { + qk_matmul_impl<16, 128, 128>(qi, kj, sij); + } else { + qk_matmul_impl<64, 128, 64>(qi, kj, sij); + } } diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_online_update.cpp b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_online_update.cpp index 1828ad674..b7d8b0a77 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_online_update.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_online_update.cpp @@ -10,13 +10,15 @@ */ // Online Softmax Update + Normalize Kernel (AIV) // -// Fixed tile size: oi/oi_new are (16, 16), mij/lij/mi/li are 16-element vectors +// Operates on full tiles where M=q_tile_size, N=head_dim (128): +// Case1: oi/oi_new are (16, 128), mij/lij/mi/li are 16-element vectors +// Case2: oi/oi_new are (64, 128), mij/lij/mi/li are 64-element vectors // -// Scalar layout strategy: -// M scalar floats stored contiguously in GM can be loaded as either: -// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops (TMAX, TSUB, TEXP, TMUL, TADD) -// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops (TROWEXPANDMUL, TROWEXPANDDIV) -// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. +// Scalar layout strategy using TRESHAPE (zero-copy UB reshape): +// Scalars loaded as DN ColMajor (M, 1) for TROWEXPANDMUL/TROWEXPANDDIV. +// For element-wise ops (TMAX, TSUB, TEXP, etc.), TRESHAPE to RowMajor (1, M). +// After arithmetic, TRESHAPE back to ColMajor (M, 1) for row-broadcast ops. +// This eliminates the GM round-trip (TSTORE ND → TLOAD DN) used in the original. #include #include @@ -46,11 +48,6 @@ static __aicore__ void online_update_impl( __gm__ float *oi_ptr = reinterpret_cast<__gm__ float *>(oi->buffer.addr); __gm__ float *dst_ptr = reinterpret_cast<__gm__ float *>(dst->buffer.addr); - // Scalar tile dimensions for RowMajor layout: - // kScalarCols = 32 bytes / 4 bytes per float = 8 floats per row (one 32-byte block) - // kScalarRows = M / 8 (M=16 -> 2 rows) - constexpr int kScalarCols = 32 / sizeof(float); - constexpr int kScalarRows = M / kScalarCols; // Aligned rows for ColMajor DN tiles (32-byte alignment) constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); @@ -59,77 +56,84 @@ static __aicore__ void online_update_impl( // Data (M, N) RowMajor using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; - // Scalar ND: M contiguous floats as (kScalarRows, kScalarCols) RowMajor + // Scalar DN: M contiguous floats as (kAlignedRows, 1) ColMajor for TROWEXPAND ops and loading + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // Scalar ND: for storing mi_new and li_new back to GM + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; using GlobalScalarND = GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; - // Scalar DN: same M contiguous floats as (kAlignedRows, 1) ColMajor - using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; - // --- GlobalTensor instances --- GlobalDataMxN oiNewGlobal(oi_new_ptr + oi_new->start_offset); GlobalDataMxN oiGlobal(oi_ptr + oi->start_offset); GlobalDataMxN dstGlobal(dst_ptr + dst->start_offset); - // ND globals for scalar element-wise operations - GlobalScalarND mijGlobalND(mij_ptr + mij->start_offset); - GlobalScalarND lijGlobalND(lij_ptr + lij->start_offset); - GlobalScalarND miGlobalND(mi_ptr + mi->start_offset); - GlobalScalarND liGlobalND(li_ptr + li->start_offset); - - // DN globals aliased to same GM for ColMajor reload (used after ND TSTORE) + // DN globals for loading scalars as ColMajor GlobalScalarDN mijGlobalDN(mij_ptr + mij->start_offset); GlobalScalarDN lijGlobalDN(lij_ptr + lij->start_offset); + GlobalScalarDN miGlobalDN(mi_ptr + mi->start_offset); GlobalScalarDN liGlobalDN(li_ptr + li->start_offset); + // ND globals for storing scalar results + GlobalScalarND miGlobalND(mi_ptr + mi->start_offset); + GlobalScalarND liGlobalND(li_ptr + li->start_offset); + // --- Tile types --- using TileDataMxN = Tile; + using TileScalarDN = Tile; + + // RowMajor (1, M) tiles for element-wise arithmetic via TRESHAPE + using TileScalarRow = Tile; + + // ND tile for storing back to GM using TileScalarND = Tile; - using TileScalarDN = Tile; // --- UB memory layout --- constexpr int kDataBytes = M * N * sizeof(float); - constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); // Data tiles TileDataMxN oiNewTile; TileDataMxN oiTile; - // Scalar ND tiles for element-wise arithmetic - TileScalarND mijND, lijND, miND, liND; - TileScalarND miNewND, alphaND, betaND, tmpND; + // Scalar DN tiles loaded from GM (ColMajor) + TileScalarDN mijDN, lijDN, miDN, liDN; - // Scalar DN tiles for TROWEXPAND operations - TileScalarDN alphaDN, betaDN, liDN; + // Temporary DN tiles for results + TileScalarDN miNewDN, alphaDN, betaDN, liNewDN, tmpDN; TASSIGN(oiNewTile, 0); TASSIGN(oiTile, kDataBytes); - TASSIGN(mijND, 2 * kDataBytes); - TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); - TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); - TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); - TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); - TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); - TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); - TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); - TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); - TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); - TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + TASSIGN(mijDN, 2 * kDataBytes); + TASSIGN(lijDN, 2 * kDataBytes + kScalarDNBytes); + TASSIGN(miDN, 2 * kDataBytes + 2 * kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 3 * kScalarDNBytes); + TASSIGN(miNewDN, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaDN, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewDN, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpDN, 2 * kDataBytes + 8 * kScalarDNBytes); if (is_first) { // --- First block: copy inputs to accumulators --- TLOAD(oiNewTile, oiNewGlobal); - TLOAD(mijND, mijGlobalND); - TLOAD(lijND, lijGlobalND); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // Passthrough to MTE3 (no V compute needed) + // Store mi = mij, li = lij, oi = oi_new + // Alias ND tiles to the same UB as DN tiles for storing as ND format + TileScalarND mijND, lijND; + TASSIGN(mijND, 2 * kDataBytes); // alias same UB as mijDN + TASSIGN(lijND, 2 * kDataBytes + kScalarDNBytes); // alias same UB as lijDN + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); TSTORE(miGlobalND, mijND); // mi = mij @@ -138,13 +142,10 @@ static __aicore__ void online_update_impl( if (is_last) { // Single block: normalize dst = oi_new / lij - // lij stored to li buffer in ND format; reload as DN for TROWEXPANDDIV - set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - TLOAD(liDN, liGlobalDN); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); - TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + // lijDN already in ColMajor DN format, use directly for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TROWEXPANDDIV(oiNewTile, oiNewTile, lijDN); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); TSTORE(dstGlobal, oiNewTile); @@ -152,73 +153,79 @@ static __aicore__ void online_update_impl( } else { // --- Subsequent blocks: accumulate --- - // Phase 1: Load all inputs + // Load all inputs TLOAD(oiNewTile, oiNewGlobal); TLOAD(oiTile, oiGlobal); - TLOAD(mijND, mijGlobalND); - TLOAD(lijND, lijGlobalND); - TLOAD(miND, miGlobalND); - TLOAD(liND, liGlobalND); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); + TLOAD(miDN, miGlobalDN); + TLOAD(liDN, liGlobalDN); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // Phase 2: Scalar arithmetic in RowMajor (kScalarRows, kScalarCols) - // pipe_barrier(PIPE_V) required between each dependent vector operation - // to resolve RAW hazards on shared UB tiles. - TMAX(miNewND, miND, mijND); // mi_new = max(mi, mij) + // TRESHAPE: ColMajor(M,1) → RowMajor(1,M) for element-wise arithmetic + TileScalarRow miRow, mijRow, liRow, lijRow; + TRESHAPE(miRow, miDN); + TRESHAPE(mijRow, mijDN); + TRESHAPE(liRow, liDN); + TRESHAPE(lijRow, lijDN); + + // Scalar arithmetic in RowMajor (1, M) layout + TileScalarRow miNewRow, alphaRow, betaRow, liNewRow, tmpRow; + TASSIGN(miNewRow, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaRow, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaRow, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewRow, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpRow, 2 * kDataBytes + 8 * kScalarDNBytes); + + TMAX(miNewRow, miRow, mijRow); // mi_new = max(mi, mij) pipe_barrier(PIPE_V); - TSUB(alphaND, miND, miNewND); // alpha = mi - mi_new + TSUB(alphaRow, miRow, miNewRow); // alpha_exp = mi - mi_new pipe_barrier(PIPE_V); - TEXP(alphaND, alphaND); // alpha = exp(mi - mi_new) + TEXP(alphaRow, alphaRow); // alpha = exp(mi - mi_new) pipe_barrier(PIPE_V); - TSUB(betaND, mijND, miNewND); // beta = mij - mi_new + TSUB(betaRow, mijRow, miNewRow); // beta_exp = mij - mi_new pipe_barrier(PIPE_V); - TEXP(betaND, betaND); // beta = exp(mij - mi_new) + TEXP(betaRow, betaRow); // beta = exp(mij - mi_new) pipe_barrier(PIPE_V); - TMUL(liND, alphaND, liND); // li = alpha * li + TMUL(tmpRow, alphaRow, liRow); // alpha * li pipe_barrier(PIPE_V); - TMUL(tmpND, betaND, lijND); // tmp = beta * lij + TMUL(liNewRow, betaRow, lijRow); // beta * lij pipe_barrier(PIPE_V); - TADD(liND, liND, tmpND); // li = alpha * li + beta * lij (= li_new) + TADD(liNewRow, tmpRow, liNewRow); // li_new = alpha*li + beta*lij - // Phase 3: Store scalar results to GM (ND format) - // mi_new -> mi accumulator, li_new -> li accumulator - // alpha -> mij buffer (reuse), beta -> lij buffer (reuse) - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(miGlobalND, miNewND); // persist mi_new - TSTORE(liGlobalND, liND); // persist li_new - TSTORE(mijGlobalND, alphaND); // temp: alpha to mij buffer - TSTORE(lijGlobalND, betaND); // temp: beta to lij buffer - - // Phase 4: Reload alpha, beta (and li if last) as ColMajor DN - set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - TLOAD(alphaDN, mijGlobalDN); // alpha from mij buffer as DN - TLOAD(betaDN, lijGlobalDN); // beta from lij buffer as DN - if (is_last) { - TLOAD(liDN, liGlobalDN); // li_new from li buffer as DN - } - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + // TRESHAPE back: RowMajor(1,M) → ColMajor(M,1) for TROWEXPANDMUL + TRESHAPE(alphaDN, alphaRow); + TRESHAPE(betaDN, betaRow); - // Phase 5: Scale data tiles using row-broadcast multiply + // Scale data tiles using row-broadcast multiply TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta pipe_barrier(PIPE_V); TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + // Store mi_new and li_new to GM (ND format) + // Alias ND tiles to the same UB locations as miNewRow and liNewRow + TileScalarND miNewND, liNewND; + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(liNewND, 2 * kDataBytes + 7 * kScalarDNBytes); + if (is_last) { - // Phase 6: Normalize and output + // Normalize and output: dst = oi / li_new + TRESHAPE(liNewDN, liNewRow); pipe_barrier(PIPE_V); - TROWEXPANDDIV(oiTile, oiTile, liDN); // dst = oi / li_new - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TROWEXPANDDIV(oiTile, oiTile, liNewDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new TSTORE(dstGlobal, oiTile); } else { - // Phase 6: Store updated accumulators - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + // Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new TSTORE(oiGlobal, oiTile); } } @@ -236,6 +243,13 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { __gm__ Tensor *dst = reinterpret_cast<__gm__ Tensor *>(args[6]); uint64_t is_first = static_cast(args[7]); uint64_t is_last = static_cast(args[8]); + uint64_t q_tile_size = static_cast(mij->shapes[0]); - online_update_impl<16, 16>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + if (q_tile_size == 16 && oi_new->shapes[1] <= 16) { + online_update_impl<16, 16>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } else if (q_tile_size == 16) { + online_update_impl<16, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } else { + online_update_impl<64, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } } diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp index fc39bc2ef..75ee87855 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp @@ -8,17 +8,18 @@ * See LICENSE in the root of the software repository for the full text of the License. * ----------------------------------------------------------------------------------------------------------- */ - // Softmax Preparation Kernel (AIV) with partial block masking // -// Fixed tile size: sij is (16, 16) +// Operates on (M, N) tile where M=q_tile_size, N=block_size: +// Case1: sij is (16, 128) +// Case2: sij is (64, 64) // // For partial blocks (valid_len < N), positions [valid_len, N) in sij are -// filled with -inf before softmax, ensuring exp(-inf)=0 so that invalid -// key positions contribute zero attention weight. +// filled with -inf via TFILLPAD_INPLACE before softmax, ensuring exp(-inf)=0 +// so that invalid key positions contribute zero attention weight. // // Computes: -// sij_masked = pad(sij, valid_len, -inf) +// sij_masked = TFILLPAD(sij, valid_len, pad=-inf) // sij_scale = sij_masked * scale // mij = row_max(sij_scale) -> (M, 1) // pij = exp(sij_scale - mij) -> (M, N) @@ -27,9 +28,8 @@ #include #include -#include "tensor.h" // NOLINT(build/include_subdir) +#include "tensor.h" -// NOLINTNEXTLINE(build/namespaces) using namespace pto; #ifndef __gm__ @@ -37,7 +37,7 @@ using namespace pto; #endif #ifndef __aicore__ -#define __aicore__ [aicore] // NOLINT(whitespace/braces) +#define __aicore__ [aicore] #endif template @@ -79,6 +79,7 @@ static __aicore__ void softmax_prepare_impl( TileScalarDN sumTile; TileVecMxN_bf16 pijBf16Tile; + // All sij tiles share UB address 0x0 (in-place masking) TASSIGN(sijTile, 0x0); TASSIGN(sijDynTile, 0x0); TASSIGN(sijPadTile, 0x0); @@ -89,11 +90,13 @@ static __aicore__ void softmax_prepare_impl( TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); // Load full sij (M, N) tile from GM - all N columns including garbage for partial blocks + // printf("sij addr incore %x\n", sij->buffer.addr); TLOAD(sijTile, sijGlobal); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // manually fill invalid columns with -inf as a workaround. + // Mask columns [valid_len, N) with -inf. sijDynTile provides the valid boundary, + // sijPadTile provides PadValue::Min as the fill value. No-op when valid_len == N. TFILLPAD_INPLACE(sijPadTile, sijDynTile); pipe_barrier(PIPE_V); @@ -104,16 +107,26 @@ static __aicore__ void softmax_prepare_impl( TROWEXPANDSUB(pijTile, sijTile, maxTile); pipe_barrier(PIPE_V); TEXP(pijTile, pijTile); - // Truncate pij to bf16 first, then compute lij from truncated values (matches golden) + // Truncate pij to bf16 first + pipe_barrier(PIPE_V); TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); // pij bf16 ready, can store early + + // Continue computing: bf16 → f32 and rowsum while pij store proceeds in parallel + pipe_barrier(PIPE_V); TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND); + pipe_barrier(PIPE_V); TROWSUM(sumTile, pijTile, tmpTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); // sum ready - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + // Store pij (overlaps with TCVT + TROWSUM above) wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(pijGlobal, pijBf16Tile); + + // Store max and sum TSTORE(mijGlobal, maxTile); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); TSTORE(lijGlobal, sumTile); - TSTORE(pijGlobal, pijBf16Tile); set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); @@ -124,7 +137,19 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { __gm__ Tensor *pij = reinterpret_cast<__gm__ Tensor *>(args[1]); __gm__ Tensor *mij = reinterpret_cast<__gm__ Tensor *>(args[2]); __gm__ Tensor *lij = reinterpret_cast<__gm__ Tensor *>(args[3]); - float scale_value = from_u64(static_cast(args[4])); - - softmax_prepare_impl<16, 16>(sij, scale_value, pij, mij, lij); + union { + uint64_t u; + float f; + } scale_conv; + scale_conv.u = static_cast(args[4]); + float scale_value = scale_conv.f; + uint64_t q_tile_size = static_cast(sij->shapes[0]); + + if (q_tile_size == 16 && pij->shapes[1] <= 16) { + softmax_prepare_impl<16, 16>(sij, scale_value, pij, mij, lij); + } else if (q_tile_size == 16) { + softmax_prepare_impl<16, 128>(sij, scale_value, pij, mij, lij); + } else { + softmax_prepare_impl<64, 64>(sij, scale_value, pij, mij, lij); + } } diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/orchestration/paged_attention_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/orchestration/paged_attention_orch.cpp index b4600a341..5d8acb234 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/orchestration/paged_attention_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/kernels/orchestration/paged_attention_orch.cpp @@ -15,18 +15,15 @@ * Each block processes a single 16x16 matmul operation. * * Memory Layout: - * Query: (batch, 16, 16) - one 16x16 tile per batch bf16 - * Key: (total_blocks, 16, 16) - stored as K^T for direct matmul bf16 - * Value: (total_blocks, 16, 16) - direct format bf16 - * - * This file compiles as a standalone .so with zero runtime link dependencies. - * All runtime calls go through the PTO2RuntimeOps function-pointer table. + * Query: (batch, 16, 16) - one 16x16 tile per batch + * Key: (total_blocks, 16, 16) - stored as K^T for direct matmul + * Value: (total_blocks, 16, 16) - direct format */ -#include -#include - +#include #include +#include +#include #include "pto_orchestration_api.h" @@ -34,6 +31,39 @@ #define FUNC_SOFTMAX_PREPARE 1 #define FUNC_PV_MATMUL 2 #define FUNC_ONLINE_UPDATE 3 +constexpr uint64_t PLATFORM_PROF_SYS_CNT_FREQ = 50000000; // 50 MHz + +inline double cycles_to_us(uint64_t cycles) { + return (static_cast(cycles) / PLATFORM_PROF_SYS_CNT_FREQ) * 1000000.0; +} + +inline uint64_t get_sys_cnt_aicpu() { +#if defined(__aarch64__) + uint64_t ticks; + asm volatile("mrs %0, cntvct_el0" : "=r"(ticks)); + return ticks; +#elif defined(__x86_64__) + return 0; +#else + return 0; +#endif +} + +#ifdef ENABLE_PROFILING +#define CYCLE_COUNT_START() uint64_t _t0 = get_sys_cnt_aicpu(), _t1 +#define CYCLE_COUNT_LAP(acc) \ + do { \ + _t1 = get_sys_cnt_aicpu(); \ + acc += (_t1 - _t0); \ + _t0 = _t1; \ + } while (0) +#define PROF_INC(counter, n) (counter) += (n) +#else +#define CYCLE_COUNT_START() (void)0 +#define CYCLE_COUNT_LAP(acc) (void)0 +#define PROF_INC(counter, n) (void)0 +#endif + extern "C" { __attribute__((visibility("default"))) PTO2OrchestrationConfig @@ -45,26 +75,38 @@ aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { } __attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { +#ifdef ENABLE_PROFILING + uint64_t prof_param_extract = 0; + uint64_t prof_ext_tensor = 0; + uint64_t prof_scope = 0; + uint64_t prof_make_tensor = 0; + uint64_t prof_tensor_view = 0; + uint64_t prof_param_setup = 0; + uint64_t prof_submit_task = 0; + int prof_submit_count = 0; + int prof_make_count = 0; + int prof_view_count = 0; +#endif + + CYCLE_COUNT_START(); + // Read dimensions from tensor metadata - // query: shape=[batch, num_heads, head_dim] uint64_t batch = orch_args.tensor(0).shapes[0]; uint64_t num_heads = orch_args.tensor(0).shapes[1]; uint64_t head_dim = orch_args.tensor(0).shapes[2]; DataType data_type = orch_args.tensor(0).dtype; - // key_cache: shape=[total_blocks, block_size, kv_head_num, head_dim] uint64_t block_size = orch_args.tensor(1).shapes[1]; - - // block_table: shape=[batch, max_num_blocks_per_req] uint64_t block_num = orch_args.tensor(3).shapes[1]; - // scale from scalar arg uint64_t scale_value = orch_args.scalar(0); uint64_t q_head_num = num_heads; - uint64_t q_tile = 16; + uint64_t q_tile = std::min(num_heads, static_cast(128)); uint64_t q_loop = (q_head_num + q_tile - 1) / q_tile; - uint64_t elem_size = get_element_size(data_type); + CYCLE_COUNT_LAP(prof_param_extract); + + LOG_ALWAYS(">>>>>> batch = %" PRIu64, batch); // Reshape tensors for kernel consumption (2D flattened) void *query_ptr = orch_args.tensor(0).data_as(); @@ -72,22 +114,21 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip void *vc_ptr = orch_args.tensor(2).data_as(); void *out_ptr = orch_args.tensor(5).data_as(); - // Compute kv_total_rows from key_cache tensor metadata uint64_t total_blocks_count = orch_args.tensor(1).shapes[0]; - uint64_t kv_total_rows = total_blocks_count * block_size; uint32_t query_shapes[2] = {static_cast(batch * num_heads), static_cast(head_dim)}; - uint32_t key_cache_shapes[2] = {static_cast(kv_total_rows), static_cast(head_dim)}; - uint32_t value_cache_shapes[2] = {static_cast(kv_total_rows), static_cast(head_dim)}; + uint32_t key_cache_shapes[2] = { + static_cast(total_blocks_count * block_size), static_cast(head_dim) + }; + uint32_t value_cache_shapes[2] = { + static_cast(total_blocks_count * block_size), static_cast(head_dim) + }; uint32_t out_shapes[2] = {static_cast(batch * num_heads), static_cast(head_dim)}; Tensor query = make_tensor_external(query_ptr, query_shapes, 2, data_type); Tensor key_cache = make_tensor_external(kc_ptr, key_cache_shapes, 2, data_type); Tensor value_cache = make_tensor_external(vc_ptr, value_cache_shapes, 2, data_type); Tensor out = make_tensor_external(out_ptr, out_shapes, 2, DataType::FLOAT32); - LOG_DEBUG("query=%s", query.dump().c_str()); - LOG_DEBUG("key_cache=%s", key_cache.dump().c_str()); - LOG_DEBUG("value_cache=%s", value_cache.dump().c_str()); - LOG_DEBUG("out=%s", out.dump().c_str()); + CYCLE_COUNT_LAP(prof_ext_tensor); uint32_t bt_shapes[2] = {static_cast(batch), static_cast(block_num)}; Tensor block_table = @@ -103,7 +144,10 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip TensorCreateInfo tile2d_ci(tile2d_shapes, 2, DataType::FLOAT32); TensorCreateInfo scalar_ci(scalar_shapes, 1, DataType::FLOAT32); TensorCreateInfo sij_ci(sij_shapes, 2, DataType::FLOAT32); - TensorCreateInfo pij_bf16_ci(sij_shapes, 2, data_type); + TensorCreateInfo pij_f16_ci(sij_shapes, 2, data_type); + + PROF_INC(prof_make_count, 4); + CYCLE_COUNT_LAP(prof_make_tensor); for (uint64_t b_idx = 0; b_idx < batch; b_idx++) { uint32_t cl_idx[1] = {static_cast(b_idx)}; @@ -111,58 +155,82 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip uint64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { PTO2_SCOPE() { - uint32_t cur_offset = static_cast(b_idx * q_head_num + q_idx * q_tile); + CYCLE_COUNT_LAP(prof_scope); + uint64_t cur_offset = b_idx * q_head_num + q_idx * q_tile; - uint32_t qi_offsets[2] = {cur_offset, 0}; + uint32_t qi_offsets[2] = {static_cast(cur_offset), 0}; Tensor qi = query.view(tile2d_shapes, qi_offsets); - uint32_t out_view_offsets[2] = {cur_offset, 0}; + uint32_t out_view_offsets[2] = {static_cast(cur_offset), 0}; Tensor out_view = out.view(tile2d_shapes, out_view_offsets); + PROF_INC(prof_view_count, 2); + CYCLE_COUNT_LAP(prof_tensor_view); + CYCLE_COUNT_LAP(prof_param_setup); TaskOutputTensors alloc_outs = alloc_tensors(tile2d_ci, scalar_ci, scalar_ci); const Tensor &oi = alloc_outs.get_ref(0); const Tensor &li_update = alloc_outs.get_ref(1); const Tensor &mi_update = alloc_outs.get_ref(2); + PROF_INC(prof_submit_count, 1); + CYCLE_COUNT_LAP(prof_submit_task); for (uint64_t bn = 0; bn < bn_this_batch; bn++) { + PTO2_SCOPE_GUARD(); + uint32_t bt_idx[2] = {static_cast(b_idx), static_cast(bn)}; uint64_t cur_block_idx = static_cast(get_tensor_data(block_table, 2, bt_idx)); - uint64_t valid_len = - block_size < (cur_seq - bn * block_size) ? block_size : (cur_seq - bn * block_size); + uint64_t valid_len = std::min(block_size, cur_seq - bn * block_size); + CYCLE_COUNT_LAP(prof_param_extract); + uint32_t kv_shapes[2] = {static_cast(block_size), static_cast(head_dim)}; uint32_t kv_offsets[2] = {static_cast(cur_block_idx * block_size), 0}; Tensor kj = key_cache.view(kv_shapes, kv_offsets); Tensor vj = value_cache.view(kv_shapes, kv_offsets); + PROF_INC(prof_view_count, 2); + CYCLE_COUNT_LAP(prof_tensor_view); Arg params_qk; params_qk.add_input(qi); params_qk.add_input(kj); params_qk.add_output(sij_ci); + CYCLE_COUNT_LAP(prof_param_setup); TaskOutputTensors qk_outs = pto2_rt_submit_aic_task(FUNC_QK_MATMUL, params_qk); const Tensor &sij = qk_outs.get_ref(0); + PROF_INC(prof_submit_count, 1); + CYCLE_COUNT_LAP(prof_submit_task); uint32_t sij_valid_shapes[2] = {static_cast(q_tile), static_cast(valid_len)}; uint32_t sij_valid_offsets[2] = {0, 0}; Tensor sij_valid = sij.view(sij_valid_shapes, sij_valid_offsets); + PROF_INC(prof_view_count, 1); + CYCLE_COUNT_LAP(prof_tensor_view); + Arg params_sf; params_sf.add_input(sij_valid); - params_sf.add_output(pij_bf16_ci); + params_sf.add_output(pij_f16_ci); params_sf.add_output(scalar_ci); params_sf.add_output(scalar_ci); params_sf.add_scalar(scale_value); + CYCLE_COUNT_LAP(prof_param_setup); TaskOutputTensors sf_outs = pto2_rt_submit_aiv_task(FUNC_SOFTMAX_PREPARE, params_sf); - const Tensor &pij_bf16 = sf_outs.get_ref(0); + const Tensor &pij_f16 = sf_outs.get_ref(0); const Tensor &mi = sf_outs.get_ref(1); const Tensor &li = sf_outs.get_ref(2); + PROF_INC(prof_submit_count, 1); + CYCLE_COUNT_LAP(prof_submit_task); Arg params_pv; - params_pv.add_input(pij_bf16); + params_pv.add_input(pij_f16); params_pv.add_input(vj); params_pv.add_output(tile2d_ci); + CYCLE_COUNT_LAP(prof_param_setup); TaskOutputTensors pv_outs = pto2_rt_submit_aic_task(FUNC_PV_MATMUL, params_pv); const Tensor &oi_tmp = pv_outs.get_ref(0); + PROF_INC(prof_submit_count, 1); + CYCLE_COUNT_LAP(prof_submit_task); uint64_t is_first = (bn == 0) ? 1 : 0; uint64_t is_last = (bn == bn_this_batch - 1) ? 1 : 0; + CYCLE_COUNT_LAP(prof_param_extract); Arg params_up; params_up.add_input(mi); @@ -174,13 +242,52 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip params_up.add_inout(out_view); params_up.add_scalar(is_first); params_up.add_scalar(is_last); + CYCLE_COUNT_LAP(prof_param_setup); pto2_rt_submit_aiv_task(FUNC_ONLINE_UPDATE, params_up); + PROF_INC(prof_submit_count, 1); + CYCLE_COUNT_LAP(prof_submit_task); } } + CYCLE_COUNT_LAP(prof_scope); } } - LOG_INFO("tasks submitted for batch=%" PRIu64 ", num_heads=%" PRIu64, batch, num_heads); +#ifdef ENABLE_PROFILING + uint64_t total = prof_param_extract + prof_ext_tensor + prof_make_tensor + prof_tensor_view + prof_param_setup + + prof_submit_task + prof_scope; + LOG_ALWAYS( + "=== PagedAttn Orch Profiling: %d submits, %d makes, %d views, total=%.3fus ===", prof_submit_count, + prof_make_count, prof_view_count, cycles_to_us(total) + ); + if (total > 0) { + LOG_ALWAYS( + " param_extract : %7.3fus (%5.1f%%)", cycles_to_us(prof_param_extract), + prof_param_extract * 100.0 / total + ); + LOG_ALWAYS( + " ext_tensor(x4) : %7.3fus (%5.1f%%)", cycles_to_us(prof_ext_tensor), prof_ext_tensor * 100.0 / total + ); + LOG_ALWAYS( + " create_info(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", prof_make_count, cycles_to_us(prof_make_tensor), + prof_make_tensor * 100.0 / total, + prof_make_count > 0 ? cycles_to_us(prof_make_tensor) / prof_make_count : 0.0 + ); + LOG_ALWAYS( + " tensor_view(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", prof_view_count, cycles_to_us(prof_tensor_view), + prof_tensor_view * 100.0 / total, + prof_view_count > 0 ? cycles_to_us(prof_tensor_view) / prof_view_count : 0.0 + ); + LOG_ALWAYS( + " param_setup : %7.3fus (%5.1f%%)", cycles_to_us(prof_param_setup), prof_param_setup * 100.0 / total + ); + LOG_ALWAYS(" scope : %7.3fus (%5.1f%%)", cycles_to_us(prof_scope), prof_scope * 100.0 / total); + LOG_ALWAYS( + " submit_task(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", prof_submit_count, cycles_to_us(prof_submit_task), + prof_submit_task * 100.0 / total, + prof_submit_count > 0 ? cycles_to_us(prof_submit_task) / prof_submit_count : 0.0 + ); + } +#endif } } // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/test_paged_attention.py b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/test_paged_attention.py index 228d5b6e2..ee58ece6a 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention/test_paged_attention.py +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention/test_paged_attention.py @@ -7,7 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""Paged attention: small-scale (sim-compatible, bfloat16) tests.""" +"""Paged attention: online softmax with AIC/AIV subgraph splitting (bfloat16).""" import torch from simpler.task_interface import ArgDirection as D @@ -19,8 +19,8 @@ @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestPagedAttention(SceneTestCase): - RTOL = 1e-2 - ATOL = 1e-2 + RTOL = 1e-3 + ATOL = 1e-3 CALLABLE = { "orchestration": { @@ -59,8 +59,55 @@ class TestPagedAttention(SceneTestCase): CASES = [ { "name": "Case1", - "platforms": ["a2a3sim", "a2a3"], + "platforms": ["a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "params": { + "batch": 256, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 128, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "Case2", + "platforms": ["a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, + "params": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "Case3", + "platforms": ["a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, + "params": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 256, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "CaseSmall1", + "platforms": ["a2a3sim", "a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 9}, "params": { "batch": 1, "num_heads": 16, @@ -73,9 +120,10 @@ class TestPagedAttention(SceneTestCase): }, }, { - "name": "Case2", + "name": "CaseSmall2", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, "params": { "batch": 1, "num_heads": 16, @@ -91,6 +139,7 @@ class TestPagedAttention(SceneTestCase): "name": "CaseVarSeq2", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, "params": { "batch": 2, "num_heads": 16, @@ -107,6 +156,7 @@ class TestPagedAttention(SceneTestCase): "name": "CaseVarSeq4", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, "params": { "batch": 4, "num_heads": 16, diff --git a/examples/a2a3/tensormap_and_ringbuffer/paged_attention_ringbuffer/test_paged_attention_ringbuffer.py b/examples/a2a3/tensormap_and_ringbuffer/paged_attention_ringbuffer/test_paged_attention_ringbuffer.py index 1888102d3..c2b2ac802 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/paged_attention_ringbuffer/test_paged_attention_ringbuffer.py +++ b/examples/a2a3/tensormap_and_ringbuffer/paged_attention_ringbuffer/test_paged_attention_ringbuffer.py @@ -20,7 +20,7 @@ from simpler_setup.goldens.paged_attention import compute_golden as _pa_compute_golden # noqa: PLC0415 from simpler_setup.goldens.paged_attention import generate_inputs as _pa_generate_inputs # noqa: PLC0415 -PA_KERNELS = "../batch_paged_attention/kernels" +PA_KERNELS = "../../../../tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels" @scene_test(level=2, runtime="tensormap_and_ringbuffer") diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp similarity index 76% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp index c9befec1a..531bc6073 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp @@ -14,12 +14,11 @@ // Processes batch_count batches in a single kernel invocation. // Per-batch addresses are computed from global tensor bases + block_table lookup. // -// Supports three tile configurations via runtime dispatch: -// Small: (16, 16) @ ( 16, 16) -> (16, 16) [fp16] -// Case1: (16, 128) @ (128, 128) -> (16, 128) [bf16] -// Case2: (64, 64) @ ( 64, 128) -> (64, 128) [bf16] +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128) -> (16, 128) +// Case2: (64, 64) @ ( 64, 128) -> (64, 128) // -// Template: T=data_type, M=q_tile, K=block_size, N=head_dim +// Template: M=q_tile, K=block_size, N=head_dim #include #include @@ -37,25 +36,25 @@ using namespace pto; #define __aicore__ [aicore] // NOLINT(whitespace/braces) #endif -template +template static __aicore__ void pv_matmul_batch_impl( __gm__ Tensor *pij_batch, __gm__ Tensor *value_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *oi_new_batch, uint64_t batch_count, uint64_t block_idx, uint64_t block_num, uint64_t batch_start ) { - __gm__ T *pij_base = reinterpret_cast<__gm__ T *>(pij_batch->buffer.addr); - __gm__ T *val_base = reinterpret_cast<__gm__ T *>(value_cache->buffer.addr); + __gm__ bfloat16_t *pij_base = reinterpret_cast<__gm__ bfloat16_t *>(pij_batch->buffer.addr); + __gm__ bfloat16_t *val_base = reinterpret_cast<__gm__ bfloat16_t *>(value_cache->buffer.addr); __gm__ float *oi_base = reinterpret_cast<__gm__ float *>(oi_new_batch->buffer.addr); __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr); - using GlobalA = GlobalTensor, Stride>; - using GlobalB = GlobalTensor, Stride>; + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; using GlobalOut = GlobalTensor, Stride>; - using TileMatA = Tile; - using TileMatB = Tile; + using TileMatA = Tile; + using TileMatB = Tile; - using LeftTile = TileLeft; - using RightTile = TileRight; + using LeftTile = TileLeft; + using RightTile = TileRight; using AccTile = TileAcc; TileMatA aMatTile; @@ -71,9 +70,9 @@ static __aicore__ void pv_matmul_batch_impl( TASSIGN(cTile, 0x0); for (uint64_t b = 0; b < batch_count; b++) { - __gm__ T *pij_addr = pij_base + b * M * K; + __gm__ bfloat16_t *pij_addr = pij_base + b * M * K; int32_t phys_block = bt[(batch_start + b) * block_num + block_idx]; - __gm__ T *vj_addr = val_base + static_cast(phys_block) * K * N; + __gm__ bfloat16_t *vj_addr = val_base + static_cast(phys_block) * K * N; __gm__ float *oi_addr = oi_base + b * M * N; GlobalA pijGlobal(pij_addr); @@ -121,16 +120,16 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { uint64_t q_tile_size = static_cast(pij_batch->shapes[0] / batch_count); uint64_t block_size = static_cast(pij_batch->shapes[1]); - if (q_tile_size == 16 && block_size == 16) { - pv_matmul_batch_impl( + if (q_tile_size == 16 && block_size <= 16) { + pv_matmul_batch_impl<16, 16, 16>( pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start ); } else if (q_tile_size == 16) { - pv_matmul_batch_impl( + pv_matmul_batch_impl<16, 128, 128>( pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start ); } else { - pv_matmul_batch_impl( + pv_matmul_batch_impl<64, 64, 128>( pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start ); } diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp similarity index 76% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp index 0da74ecae..f9a25b4fb 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp @@ -14,12 +14,11 @@ // Processes batch_count batches in a single kernel invocation. // Per-batch addresses are computed from global tensor bases + block_table lookup. // -// Supports three tile configurations via runtime dispatch: -// Small: (16, 16) @ ( 16, 16).T -> (16, 16) [fp16] -// Case1: (16, 128) @ (128, 128).T -> (16, 128) [bf16] -// Case2: (64, 128) @ (128, 64).T -> (64, 64) [bf16] +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128).T -> (16, 128) +// Case2: (64, 128) @ (128, 64).T -> (64, 64) // -// Template: T=data_type, M=q_tile, K=head_dim, N=block_size +// Template: M=q_tile, K=head_dim, N=block_size #include #include @@ -37,26 +36,26 @@ using namespace pto; #define __aicore__ [aicore] // NOLINT(whitespace/braces) #endif -template +template static __aicore__ void qk_matmul_batch_impl( __gm__ Tensor *query, __gm__ Tensor *key_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *sij_batch, uint64_t batch_count, uint64_t block_idx, uint64_t q_offset, uint64_t block_num, uint64_t num_heads, uint64_t batch_start ) { - __gm__ T *query_base = reinterpret_cast<__gm__ T *>(query->buffer.addr); - __gm__ T *key_base = reinterpret_cast<__gm__ T *>(key_cache->buffer.addr); + __gm__ bfloat16_t *query_base = reinterpret_cast<__gm__ bfloat16_t *>(query->buffer.addr); + __gm__ bfloat16_t *key_base = reinterpret_cast<__gm__ bfloat16_t *>(key_cache->buffer.addr); __gm__ float *sij_base = reinterpret_cast<__gm__ float *>(sij_batch->buffer.addr); __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr); - using GlobalA = GlobalTensor, Stride>; - using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride, Layout::DN>; using GlobalOut = GlobalTensor, Stride>; - using TileMatA = Tile; - using TileMatB = Tile; + using TileMatA = Tile; + using TileMatB = Tile; - using LeftTile = TileLeft; - using RightTile = TileRight; + using LeftTile = TileLeft; + using RightTile = TileRight; using AccTile = TileAcc; TileMatA aMatTile; @@ -72,9 +71,9 @@ static __aicore__ void qk_matmul_batch_impl( TASSIGN(cTile, 0x0); for (uint64_t b = 0; b < batch_count; b++) { - __gm__ T *qi_addr = query_base + ((batch_start + b) * num_heads + q_offset) * K; + __gm__ bfloat16_t *qi_addr = query_base + ((batch_start + b) * num_heads + q_offset) * K; int32_t phys_block = bt[(batch_start + b) * block_num + block_idx]; - __gm__ T *kj_addr = key_base + static_cast(phys_block) * N * K; + __gm__ bfloat16_t *kj_addr = key_base + static_cast(phys_block) * N * K; __gm__ float *sij_addr = sij_base + b * M * N; GlobalA qiGlobal(qi_addr); @@ -125,18 +124,18 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { uint64_t q_tile_size = static_cast(sij_batch->shapes[0] / batch_count); uint64_t block_size = static_cast(sij_batch->shapes[1]); - if (q_tile_size == 16 && block_size == 16) { - qk_matmul_batch_impl( + if (q_tile_size == 16 && block_size <= 16) { + qk_matmul_batch_impl<16, 16, 16>( query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads, batch_start ); } else if (q_tile_size == 16) { - qk_matmul_batch_impl( + qk_matmul_batch_impl<16, 128, 128>( query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads, batch_start ); } else { - qk_matmul_batch_impl( + qk_matmul_batch_impl<64, 128, 64>( query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads, batch_start ); diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp similarity index 98% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp index d180adc76..296db9e57 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp @@ -14,8 +14,7 @@ // For each batch b, updates accumulators mi/li/oi with new block's mij/lij/oi_new. // On is_last, normalizes and writes to the output tensor at the correct batch offset. // -// Supports three tile configurations via runtime dispatch: -// Small: (16, 16) -- q_tile=16, head_dim=16 +// Supports two tile configurations via runtime dispatch: // Case1: (16, 128) -- q_tile=16, head_dim=128 // Case2: (64, 128) -- q_tile=64, head_dim=128 // @@ -211,7 +210,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { uint64_t q_tile_size = static_cast(mij_batch->shapes[0] / batch_count); uint64_t head_dim = static_cast(oi_new_batch->shapes[1]); - if (q_tile_size == 16 && head_dim == 16) { + if (q_tile_size == 16 && head_dim <= 16) { online_update_batch_impl<16, 16>( mij_batch, lij_batch, oi_new_batch, mi_batch, li_batch, oi_batch, out, is_first, is_last, batch_count, q_offset, num_heads, batch_start diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp similarity index 84% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp index dc14d4602..64d6796ac 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp @@ -20,10 +20,9 @@ // pij[b] = exp(sij_scale - mij[b]) (truncated to bf16 then back) // lij[b] = row_sum(pij[b]) // -// Supports three tile configurations via runtime dispatch: -// Small: (16, 16) -- q_tile=16, block_size=16 [truncate to fp16] -// Case1: (16, 128) -- q_tile=16, block_size=128 [truncate to bf16] -// Case2: (64, 64) -- q_tile=64, block_size=64 [truncate to bf16] +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) -- q_tile=16, block_size=128 +// Case2: (64, 64) -- q_tile=64, block_size=64 #include #include @@ -41,13 +40,13 @@ using namespace pto; #define __aicore__ [aicore] // NOLINT(whitespace/braces) #endif -template +template static __aicore__ void softmax_prepare_batch_impl( __gm__ Tensor *sij_batch, __gm__ Tensor *context_lens_t, __gm__ Tensor *pij_batch, __gm__ Tensor *mij_batch, __gm__ Tensor *lij_batch, float scale_value, uint64_t batch_count, uint64_t block_idx, uint64_t batch_start ) { __gm__ float *sij_base = reinterpret_cast<__gm__ float *>(sij_batch->buffer.addr); - __gm__ T *pij_base = reinterpret_cast<__gm__ T *>(pij_batch->buffer.addr); + __gm__ bfloat16_t *pij_base = reinterpret_cast<__gm__ bfloat16_t *>(pij_batch->buffer.addr); __gm__ float *mij_base = reinterpret_cast<__gm__ float *>(mij_batch->buffer.addr); __gm__ float *lij_base = reinterpret_cast<__gm__ float *>(lij_batch->buffer.addr); __gm__ int32_t *ctx_lens = reinterpret_cast<__gm__ int32_t *>(context_lens_t->buffer.addr); @@ -55,14 +54,14 @@ static __aicore__ void softmax_prepare_batch_impl( constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; - using GlobalDataMxN_T = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_bf16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; using TileSijDyn = Tile; using TileSijPad = Tile; using TileVecMxN = Tile; - using TileVecMxN_T = Tile; + using TileVecMxN_bf16 = Tile; using TileScalarDN = Tile; TileVecMxN sijTile; @@ -71,7 +70,7 @@ static __aicore__ void softmax_prepare_batch_impl( TileVecMxN tmpTile; TileScalarDN maxTile; TileScalarDN sumTile; - TileVecMxN_T pijTruncTile; + TileVecMxN_bf16 pijBf16Tile; TASSIGN(sijTile, 0x0); TASSIGN(sijPadTile, 0x0); @@ -79,7 +78,7 @@ static __aicore__ void softmax_prepare_batch_impl( TASSIGN(tmpTile, 2 * M * N * sizeof(float)); TASSIGN(maxTile, 3 * M * N * sizeof(float)); TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float)); - TASSIGN(pijTruncTile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); + TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); for (uint64_t b = 0; b < batch_count; b++) { int32_t cur_seq = ctx_lens[batch_start + b]; @@ -91,12 +90,12 @@ static __aicore__ void softmax_prepare_batch_impl( } __gm__ float *sij_addr = sij_base + b * M * N; - __gm__ T *pij_addr = pij_base + b * M * N; + __gm__ bfloat16_t *pij_addr = pij_base + b * M * N; __gm__ float *mij_addr = mij_base + b * M; __gm__ float *lij_addr = lij_base + b * M; GlobalDataMxN sijGlobal(sij_addr); - GlobalDataMxN_T pijGlobal(pij_addr); + GlobalDataMxN_bf16 pijGlobal(pij_addr); GlobalScalarDN mijGlobal(mij_addr); GlobalScalarDN lijGlobal(lij_addr); @@ -109,14 +108,14 @@ static __aicore__ void softmax_prepare_batch_impl( sumTile.SetValue(i, 0.0f); } for (int i = 0; i < M * N; i++) { - pijTruncTile.SetValue(i, static_cast(0.0f)); + pijBf16Tile.SetValue(i, static_cast(0.0f)); } set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); TSTORE(mijGlobal, maxTile); TSTORE(lijGlobal, sumTile); - TSTORE(pijGlobal, pijTruncTile); + TSTORE(pijGlobal, pijBf16Tile); if (b + 1 < batch_count) { pipe_barrier(PIPE_ALL); @@ -142,16 +141,16 @@ static __aicore__ void softmax_prepare_batch_impl( TEXP(pijTile, pijTile); pipe_barrier(PIPE_V); // Truncate pij to bf16 first, then compute lij from truncated values (matches golden) - TCVT(pijTruncTile, pijTile, RoundMode::CAST_ROUND); + TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); pipe_barrier(PIPE_V); - TCVT(pijTile, pijTruncTile, RoundMode::CAST_ROUND); + TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND); pipe_barrier(PIPE_V); TROWSUM(sumTile, pijTile, tmpTile); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(pijGlobal, pijTruncTile); + TSTORE(pijGlobal, pijBf16Tile); TSTORE(mijGlobal, maxTile); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); TSTORE(lijGlobal, sumTile); @@ -182,18 +181,18 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { uint64_t batch_start = static_cast(args[8]); uint64_t q_tile_size = static_cast(sij_batch->shapes[0] / batch_count); - uint64_t block_size = static_cast(sij_batch->shapes[1]); + uint64_t block_size = static_cast(pij_batch->shapes[1]); - if (q_tile_size == 16 && block_size == 16) { - softmax_prepare_batch_impl( + if (q_tile_size == 16 && block_size <= 16) { + softmax_prepare_batch_impl<16, 16>( sij_batch, context_lens_t, pij_batch, mij_batch, lij_batch, scale_value, batch_count, block_idx, batch_start ); } else if (q_tile_size == 16) { - softmax_prepare_batch_impl( + softmax_prepare_batch_impl<16, 128>( sij_batch, context_lens_t, pij_batch, mij_batch, lij_batch, scale_value, batch_count, block_idx, batch_start ); } else { - softmax_prepare_batch_impl( + softmax_prepare_batch_impl<64, 64>( sij_batch, context_lens_t, pij_batch, mij_batch, lij_batch, scale_value, batch_count, block_idx, batch_start ); } diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp similarity index 99% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp index 299848ba6..68b794f6b 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp @@ -71,8 +71,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip uint64_t scale_value = orch_args.scalar(0); - constexpr uint64_t Q_TILE_LIMIT = 128; - uint64_t q_tile = std::min(num_heads, Q_TILE_LIMIT); + uint64_t q_tile = std::min(num_heads, static_cast(128)); uint64_t q_loop = (num_heads + q_tile - 1) / q_tile; uint64_t elem_size = get_element_size(data_type); diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py similarity index 81% rename from examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py rename to tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py index 47f33671a..cc1ed20e9 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py +++ b/tests/st/a2a3/tensormap_and_ringbuffer/batch_paged_attention/test_batch_paged_attention.py @@ -7,11 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""Batch paged attention: small-scale (sim) and production-scale (hardware) tests. - -Kernels use runtime dispatch to handle both small-scale (fp16, 16x16x16 templates) -and production-scale (bf16, 128+ dimension templates) configurations. -""" +"""Batch paged attention: batched online softmax with AIC/AIV subgraph splitting (bfloat16).""" import torch from simpler.task_interface import ArgDirection as D @@ -21,8 +17,10 @@ from simpler_setup.goldens.paged_attention import generate_inputs as _pa_generate_inputs -class _BatchPagedAttentionBase(SceneTestCase): - """Shared CALLABLE, generate_args, and compute_golden for batch paged attention.""" +@scene_test(level=2, runtime="tensormap_and_ringbuffer") +class TestBatchPagedAttention(SceneTestCase): + RTOL = 1e-3 + ATOL = 1e-3 CALLABLE = { "orchestration": { @@ -58,36 +56,58 @@ class _BatchPagedAttentionBase(SceneTestCase): ], } - def generate_args(self, params): - result = _pa_generate_inputs(params) - specs = [] - for name, value in result: - if isinstance(value, torch.Tensor): - specs.append(Tensor(name, value)) - else: - specs.append(Scalar(name, value)) - return TaskArgsBuilder(*specs) - - def compute_golden(self, args, params): - tensors = {s.name: s.value for s in args.specs if isinstance(s, Tensor)} - _pa_compute_golden(tensors, params) - for s in args.specs: - if isinstance(s, Tensor) and s.name in tensors: - getattr(args, s.name)[:] = tensors[s.name] - - -@scene_test(level=2, runtime="tensormap_and_ringbuffer") -class TestBatchPagedAttention(_BatchPagedAttentionBase): - """Batch paged attention — small-scale cases (sim-compatible, float16).""" - - RTOL = 1e-2 - ATOL = 1e-2 - CASES = [ { "name": "Case1", - "platforms": ["a2a3sim", "a2a3"], + "platforms": ["a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "params": { + "batch": 256, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 128, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "Case2", + "platforms": ["a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, + "params": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "Case3", + "platforms": ["a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "manual": True, + "params": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 256, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + }, + { + "name": "CaseSmall1", + "platforms": ["a2a3sim", "a2a3"], + "config": {"aicpu_thread_num": 4, "block_dim": 9}, "params": { "batch": 1, "num_heads": 16, @@ -96,13 +116,14 @@ class TestBatchPagedAttention(_BatchPagedAttentionBase): "block_size": 16, "context_len": 33, "max_model_len": 256, - "dtype": "float16", + "dtype": "bfloat16", }, }, { - "name": "Case2", + "name": "CaseSmall2", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "config": {"aicpu_thread_num": 4, "block_dim": 9}, + "manual": True, "params": { "batch": 1, "num_heads": 16, @@ -111,13 +132,14 @@ class TestBatchPagedAttention(_BatchPagedAttentionBase): "block_size": 16, "context_len": 31, "max_model_len": 256, - "dtype": "float16", + "dtype": "bfloat16", }, }, { - "name": "Case3", + "name": "CaseSmall3", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "config": {"aicpu_thread_num": 4, "block_dim": 9}, + "manual": True, "params": { "batch": 1, "num_heads": 16, @@ -126,13 +148,14 @@ class TestBatchPagedAttention(_BatchPagedAttentionBase): "block_size": 16, "context_len": 128, "max_model_len": 256, - "dtype": "float16", + "dtype": "bfloat16", }, }, { "name": "CaseVarSeq2", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "config": {"aicpu_thread_num": 4, "block_dim": 9}, + "manual": True, "params": { "batch": 2, "num_heads": 16, @@ -142,13 +165,14 @@ class TestBatchPagedAttention(_BatchPagedAttentionBase): "context_len": 33, "context_lens_list": [33, 17], "max_model_len": 256, - "dtype": "float16", + "dtype": "bfloat16", }, }, { "name": "CaseVarSeq4", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "config": {"aicpu_thread_num": 4, "block_dim": 9}, + "manual": True, "params": { "batch": 4, "num_heads": 16, @@ -158,70 +182,27 @@ class TestBatchPagedAttention(_BatchPagedAttentionBase): "context_len": 128, "context_lens_list": [33, 64, 128, 15], "max_model_len": 256, - "dtype": "float16", + "dtype": "bfloat16", }, }, ] + def generate_args(self, params): + result = _pa_generate_inputs(params) + specs = [] + for name, value in result: + if isinstance(value, torch.Tensor): + specs.append(Tensor(name, value)) + else: + specs.append(Scalar(name, value)) + return TaskArgsBuilder(*specs) -@scene_test(level=2, runtime="tensormap_and_ringbuffer") -class TestBatchPagedAttentionLarge(_BatchPagedAttentionBase): - """Batch paged attention — production-scale cases (hardware-only, bfloat16).""" - - RTOL = 1e-3 - ATOL = 1e-3 - RUNTIME_ENV = {"PTO2_RING_HEAP": "1073741824"} - - CASES = [ - { - "name": "Case1", - "platforms": ["a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, - "manual": True, - "params": { - "batch": 256, - "num_heads": 16, - "kv_head_num": 1, - "head_dim": 128, - "block_size": 128, - "context_len": 8192, - "max_model_len": 32768, - "dtype": "bfloat16", - }, - }, - { - "name": "Case2", - "platforms": ["a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, - "manual": True, - "params": { - "batch": 64, - "num_heads": 64, - "kv_head_num": 1, - "head_dim": 128, - "block_size": 64, - "context_len": 8192, - "max_model_len": 32768, - "dtype": "bfloat16", - }, - }, - { - "name": "Case3", - "platforms": ["a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 24}, - "manual": True, - "params": { - "batch": 64, - "num_heads": 64, - "kv_head_num": 1, - "head_dim": 256, - "block_size": 64, - "context_len": 8192, - "max_model_len": 32768, - "dtype": "bfloat16", - }, - }, - ] + def compute_golden(self, args, params): + tensors = {s.name: s.value for s in args.specs if isinstance(s, Tensor)} + _pa_compute_golden(tensors, params) + for s in args.specs: + if isinstance(s, Tensor) and s.name in tensors: + getattr(args, s.name)[:] = tensors[s.name] if __name__ == "__main__": diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py b/tests/st/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py index e59a360f9..514e2189a 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py +++ b/tests/st/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py @@ -51,24 +51,28 @@ class TestBenchmarkBgemm(SceneTestCase): }, { "name": "Case1", + "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, "params": {"matmul_add_task_num": 64, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, }, { "name": "Case2", + "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, "params": {"matmul_add_task_num": 256, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, }, { "name": "Case3", + "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, "params": {"matmul_add_task_num": 64, "incore_data_size": 128, "incore_loop": 16, "grid_k": 2}, }, { "name": "Case4", + "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, "params": {"matmul_add_task_num": 64, "incore_data_size": 128, "incore_loop": 4, "grid_k": 4}, diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/mixed_example/test_mixed_example.py b/tests/st/a2a3/tensormap_and_ringbuffer/mixed_example/test_mixed_example.py index 132960cd1..daf598969 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/mixed_example/test_mixed_example.py +++ b/tests/st/a2a3/tensormap_and_ringbuffer/mixed_example/test_mixed_example.py @@ -91,6 +91,7 @@ class TestMixedExample(SceneTestCase): }, { "name": "case2", + "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 3}, "params": {"num_iters": 1}, diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/multi_round_paged_attention/test_multi_round_paged_attention.py b/tests/st/a2a3/tensormap_and_ringbuffer/multi_round_paged_attention/test_multi_round_paged_attention.py index 22377d0cc..b9520e5af 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/multi_round_paged_attention/test_multi_round_paged_attention.py +++ b/tests/st/a2a3/tensormap_and_ringbuffer/multi_round_paged_attention/test_multi_round_paged_attention.py @@ -76,7 +76,56 @@ class TestMultiRoundPagedAttention(SceneTestCase): "max_model_len": 256, "dtype": "bfloat16", }, + }, + { + "name": "Case2", + "platforms": ["a2a3sim", "a2a3"], + "manual": True, + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "params": { + "batch": 1, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 128, + "max_model_len": 256, + "dtype": "bfloat16", + }, + }, + { + "name": "CaseVarSeq2", + "platforms": ["a2a3sim", "a2a3"], + "manual": True, + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "params": { + "batch": 2, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 33, + "context_lens_list": [33, 17], + "max_model_len": 256, + "dtype": "bfloat16", + }, + }, + { + "name": "CaseVarSeq4", + "platforms": ["a2a3sim", "a2a3"], "manual": True, + "config": {"aicpu_thread_num": 4, "block_dim": 24}, + "params": { + "batch": 4, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 128, + "context_lens_list": [33, 64, 128, 15], + "max_model_len": 256, + "dtype": "bfloat16", + }, }, ] diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp index f5d88861d..954ee478f 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp @@ -105,7 +105,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip // scale from scalar arg uint64_t scale_value = orch_args.scalar(0); uint64_t q_head_num = num_heads; - uint64_t q_tile = std::min(num_heads, 128UL); + uint64_t q_tile = std::min(num_heads, static_cast(128)); uint64_t q_loop = (q_head_num + q_tile - 1) / q_tile; CYCLE_COUNT_LAP(prof_param_extract); diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/test_paged_attention_unroll.py b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/test_paged_attention_unroll.py index 1a4efbbce..847882d0a 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/test_paged_attention_unroll.py +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/test_paged_attention_unroll.py @@ -61,7 +61,6 @@ class TestPagedAttentionUnroll(SceneTestCase): "name": "Case1", "platforms": ["a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 24}, - "manual": True, "params": { "batch": 256, "num_heads": 16, diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp index 45f7ba75c..c606d8d14 100644 --- a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp @@ -66,7 +66,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip // scale from scalar arg uint64_t scale_value = orch_args.scalar(0); - uint64_t q_tile = std::min(num_heads, 128UL); + uint64_t q_tile = std::min(num_heads, static_cast(128)); uint64_t q_loop = (num_heads + q_tile - 1) / q_tile; // External 4D tensors inherit shape/dtype from TaskArg (golden provides 4D). diff --git a/tools/benchmark_rounds.sh b/tools/benchmark_rounds.sh index 6d684ea19..5e11f4327 100755 --- a/tools/benchmark_rounds.sh +++ b/tools/benchmark_rounds.sh @@ -30,9 +30,9 @@ RUN_EXAMPLE="$PROJECT_ROOT/examples/scripts/run_example.py" # --- tensormap_and_ringbuffer --- declare -A TMR_EXAMPLE_CASES=( [alternating_matmul_add]="Case1" - [benchmark_bgemm]="" + [benchmark_bgemm]="Case0" [paged_attention_unroll]="Case1,Case2" - [batch_paged_attention]="" + [batch_paged_attention]="Case1" ) TMR_EXAMPLE_ORDER=( alternating_matmul_add