From 7ff6660ec4a3297502332cc4437172d457576fee Mon Sep 17 00:00:00 2001 From: chenshengxin Date: Thu, 9 Apr 2026 20:32:35 +0800 Subject: [PATCH] Add: SPMD paged attention example with dual-vector softmax Implements a complete paged attention kernel using SPMD parallelism under the tensormap_and_ringbuffer runtime. The pipeline consists of QK matmul, softmax prepare, PV matmul, and online update stages with dual AIV lanes processing 8-row sub-tiles each for the softmax and accumulation phases. --- .../spmd_paged_attention/golden.py | 79 ++++++ .../kernels/aic/aic_hub.cpp | 29 ++ .../kernels/aic/aic_pv_matmul.cpp | 152 +++++++++++ .../kernels/aic/aic_qk_matmul.cpp | 158 +++++++++++ .../kernels/aiv/aiv_hub.cpp | 27 ++ .../kernels/aiv/aiv_online_update.cpp | 258 ++++++++++++++++++ .../kernels/aiv/aiv_softmax_prepare.cpp | 192 +++++++++++++ .../kernels/kernel_config.py | 85 ++++++ .../spmd_paged_attention_orch.cpp | 239 ++++++++++++++++ 9 files changed, 1219 insertions(+) create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/golden.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_hub.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_pv_matmul.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_qk_matmul.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_hub.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_online_update.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/kernel_config.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/orchestration/spmd_paged_attention_orch.cpp diff --git 
a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/golden.py b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/golden.py new file mode 100644 index 000000000..3847437a6 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/golden.py @@ -0,0 +1,79 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""SPMD Paged Attention Golden - tensormap_and_ringbuffer example (small scale, bfloat16). + +Uses SPMD parallelism: each block handles one (batch, q_tile) position. +Kernels use get_block_idx() to determine their work slice. 
+""" + +from paged_attention_golden import ( + compute_golden, # noqa: F401 + run_golden_test, +) +from paged_attention_golden import generate_inputs as _generate_inputs + +__outputs__ = ["out"] + +RTOL = 1e-2 +ATOL = 1e-2 + +ALL_CASES = { + "Case1": { + "batch": 1, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 33, + "max_model_len": 256, + "dtype": "bfloat16", + }, + "Case2": { + "batch": 1, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 128, + "max_model_len": 256, + "dtype": "bfloat16", + }, + "CaseVarSeq2": { + "batch": 2, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 33, + "context_lens_list": [33, 17], + "max_model_len": 256, + "dtype": "bfloat16", + }, + "CaseVarSeq4": { + "batch": 4, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 16, + "block_size": 16, + "context_len": 64, + "context_lens_list": [33, 64, 48, 15], + "max_model_len": 256, + "dtype": "bfloat16", + }, +} + +DEFAULT_CASE = "Case1" + + +def generate_inputs(params: dict) -> list: + return _generate_inputs(params) + + +if __name__ == "__main__": + run_golden_test(ALL_CASES, DEFAULT_CASE, generate_inputs, label="SPMD Paged Attention") diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_hub.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_hub.cpp new file mode 100644 index 000000000..eb602e0bd --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_hub.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +// AIC Hub Kernel - No-op stub used as the AIC slot of MIX (AIC+AIV0+AIV1) tasks +// when the real work happens only on the two AIVs (softmax, online update). +// Pairing an idle AIC with two active AIVs forces the scheduler to allocate a +// full cluster, which is what enables the two AIV lanes to run in parallel. + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {} diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_pv_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 000000000..f2e2652e2 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ +// SPMD PV Matmul: pij(M, K) @ vj(K, N) -> oi_new(M, N) +// +// SPMD block_idx encodes (batch_idx, q_tile_idx). +// Each block computes one 16x16 matmul using paged V cache. +// +// Args: +// args[0] = pij Tensor* (spmd_blocks*Q_TILE, block_size) data_type +// args[1] = value_cache Tensor* (kv_total_rows, head_dim) bf16 +// args[2] = block_table Tensor* (batch, max_blocks_per_req) int32 +// args[3] = context_lens Tensor* (batch,) int32 +// args[4] = oi_new Tensor* (spmd_blocks*Q_TILE, head_dim) float32 [output] +// args[5] = bn scalar: current KV block index +// args[6] = num_heads scalar +// args[7] = head_dim scalar +// args[8] = block_size scalar +// args[9] = max_num_blocks_per_req scalar +// args[10] = q_loop scalar + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "intrinsic.h" + +static constexpr int M = 16; +static constexpr int K = 16; +static constexpr int N = 16; + +template +static __aicore__ void pv_matmul_spmd(__gm__ bfloat16_t *pij_addr, __gm__ bfloat16_t *vj_addr, __gm__ float *oi_addr) { + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA pijGlobal(pij_addr); + GlobalB vjGlobal(vj_addr); + GlobalOut oiGlobal(oi_addr); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, pijGlobal); + TLOAD(bMatTile, vjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, 
EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(oiGlobal, cTile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *pij_t = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *value_cache_t = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *block_table_t = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *context_lens_t = reinterpret_cast<__gm__ Tensor *>(args[3]); + __gm__ Tensor *oi_new_t = reinterpret_cast<__gm__ Tensor *>(args[4]); + + int64_t bn = static_cast(args[5]); + int64_t num_heads = static_cast(args[6]); + int64_t head_dim = static_cast(args[7]); + int64_t block_size = static_cast(args[8]); + int64_t max_blocks_per_req = static_cast(args[9]); + int64_t q_loop = static_cast(args[10]); + + int32_t block_idx = get_block_idx(args); + int64_t batch_idx = block_idx / q_loop; + + // Check if this batch has data at this KV block + __gm__ int32_t *ctx_ptr = + reinterpret_cast<__gm__ int32_t *>(context_lens_t->buffer.addr) + context_lens_t->start_offset; + int64_t cur_seq = static_cast(ctx_ptr[batch_idx]); + int64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + + // Output pointer for this block's oi_new slice + __gm__ float *oi_addr = + reinterpret_cast<__gm__ float *>(oi_new_t->buffer.addr) + oi_new_t->start_offset + block_idx * M * head_dim; + + if (bn >= bn_this_batch) { + for (int i = 0; i < M * static_cast(head_dim); i++) { + oi_addr[i] = 0.0f; + } + return; + } + + // Look up physical block index + __gm__ int32_t *bt_ptr = + reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr) + block_table_t->start_offset; + int64_t phys_block = 
static_cast(bt_ptr[batch_idx * max_blocks_per_req + bn]); + + // pij offset: block_idx * Q_TILE * block_size + int64_t pij_offset = block_idx * M * block_size; + __gm__ bfloat16_t *pij_addr = + reinterpret_cast<__gm__ bfloat16_t *>(pij_t->buffer.addr) + pij_t->start_offset + pij_offset; + + // Value offset: phys_block * block_size * head_dim + int64_t v_offset = phys_block * block_size * head_dim; + __gm__ bfloat16_t *vj_addr = + reinterpret_cast<__gm__ bfloat16_t *>(value_cache_t->buffer.addr) + value_cache_t->start_offset + v_offset; + + pv_matmul_spmd(pij_addr, vj_addr, oi_addr); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_qk_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 000000000..bdd0844bf --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +// SPMD QK Matmul: qi(M, K) @ kj.T(K, N) -> sij(M, N) +// +// SPMD block_idx encodes (batch_idx, q_tile_idx). +// Each block computes one 16x16 matmul using paged KV. 
+// +// Args: +// args[0] = query Tensor* (batch*num_heads, head_dim) bf16 +// args[1] = key_cache Tensor* (kv_total_rows, head_dim) bf16 +// args[2] = block_table Tensor* (batch, max_blocks_per_req) int32 +// args[3] = context_lens Tensor* (batch,) int32 +// args[4] = sij Tensor* (spmd_blocks*Q_TILE, block_size) float32 [output] +// args[5] = bn scalar: current KV block index +// args[6] = num_heads scalar +// args[7] = head_dim scalar +// args[8] = block_size scalar +// args[9] = max_num_blocks_per_req scalar +// args[10] = q_loop scalar + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "intrinsic.h" + +static constexpr int M = 16; +static constexpr int K = 16; +static constexpr int N = 16; + +template +static __aicore__ void +qk_matmul_spmd(__gm__ bfloat16_t *qi_addr, __gm__ bfloat16_t *kj_addr, __gm__ float *sij_addr) { + using GlobalA = GlobalTensor, Stride>; + using GlobalB = + GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA qiGlobal(qi_addr); + GlobalB kjGlobal(kj_addr); + GlobalOut sijGlobal(sij_addr); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, qiGlobal); + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, 
cTile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *query_t = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *key_cache_t = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *block_table_t = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *context_lens_t = reinterpret_cast<__gm__ Tensor *>(args[3]); + __gm__ Tensor *sij_t = reinterpret_cast<__gm__ Tensor *>(args[4]); + + int64_t bn = static_cast(args[5]); + int64_t num_heads = static_cast(args[6]); + int64_t head_dim = static_cast(args[7]); + int64_t block_size = static_cast(args[8]); + int64_t max_blocks_per_req = static_cast(args[9]); + int64_t q_loop = static_cast(args[10]); + + int32_t block_idx = get_block_idx(args); + + // Decode (batch_idx, q_tile_idx) from block_idx + int64_t batch_idx = block_idx / q_loop; + int64_t q_tile_idx = block_idx % q_loop; + + // Check if this batch has data at this KV block + __gm__ int32_t *ctx_ptr = + reinterpret_cast<__gm__ int32_t *>(context_lens_t->buffer.addr) + context_lens_t->start_offset; + int64_t cur_seq = static_cast(ctx_ptr[batch_idx]); + int64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + + // Output pointer for this block's sij slice + __gm__ float *sij_addr = + reinterpret_cast<__gm__ float *>(sij_t->buffer.addr) + sij_t->start_offset + block_idx * M * block_size; + + if (bn >= bn_this_batch) { + // No valid KV data for this batch at this bn — zero out sij + for (int i = 0; i < M * static_cast(block_size); i++) { + sij_addr[i] = 0.0f; + } + return; + } + + // Look up physical block index from block_table + __gm__ int32_t *bt_ptr = + reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr) + block_table_t->start_offset; + int64_t phys_block = static_cast(bt_ptr[batch_idx * max_blocks_per_req + bn]); + + // Query offset: (batch_idx * num_heads + q_tile_idx * Q_TILE, 0) + int64_t 
q_offset = (batch_idx * num_heads + q_tile_idx * M) * head_dim; + __gm__ bfloat16_t *qi_addr = + reinterpret_cast<__gm__ bfloat16_t *>(query_t->buffer.addr) + query_t->start_offset + q_offset; + + // Key offset: (phys_block * block_size, 0) + int64_t k_offset = phys_block * block_size * head_dim; + __gm__ bfloat16_t *kj_addr = + reinterpret_cast<__gm__ bfloat16_t *>(key_cache_t->buffer.addr) + key_cache_t->start_offset + k_offset; + + qk_matmul_spmd(qi_addr, kj_addr, sij_addr); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_hub.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_hub.cpp new file mode 100644 index 000000000..a42f2790e --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_hub.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +// AIV Hub Kernel - No-op stub for accumulator tensor allocation. +// The runtime allocates output tensors specified in the Arg; the kernel itself does nothing. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {} diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_online_update.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 000000000..c22023fc9 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +// SPMD Online Softmax Update + Normalize Kernel (AIV) with dual-vector +// subvector split. +// +// SPMD block_idx encodes (batch_idx, q_tile_idx). +// The two AIV lanes in a cluster split the Q_TILE=16 rows 8/8 via +// get_sub_block_id(): AIV0 updates rows [0, 8), AIV1 updates rows [8, 16). +// The online softmax update is row-independent, so the two lanes never touch +// the same row of mi/li/oi accumulators or the output buffer. 
+// +// Scalar layout strategy (same as MPMD version): +// M scalar floats stored contiguously in GM can be loaded as either: +// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops +// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops (TROWEXPANDMUL/DIV) +// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. +// +// Args: +// args[0] = mij Tensor* (spmd_blocks*Q_TILE,) float32 +// args[1] = lij Tensor* (spmd_blocks*Q_TILE,) float32 +// args[2] = oi_new Tensor* (spmd_blocks*Q_TILE, head_dim) float32 +// args[3] = mi_acc Tensor* (spmd_blocks*Q_TILE,) float32 [inout] +// args[4] = li_acc Tensor* (spmd_blocks*Q_TILE,) float32 [inout] +// args[5] = oi_acc Tensor* (spmd_blocks*Q_TILE, head_dim) float32 [inout] +// args[6] = out Tensor* (batch*num_heads, head_dim) float32 [inout] +// args[7] = is_first scalar +// args[8] = is_last scalar +// args[9] = num_heads scalar +// args[10] = head_dim scalar +// args[11] = q_loop scalar + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "intrinsic.h" + +static constexpr int QT = 16; // Full Q tile rows (shared between both AIVs) +static constexpr int SUB_QT = 8; // Rows per AIV lane (QT / 2) +static constexpr int HD = 16; // Head dimension + +template +static __aicore__ void online_update_spmd( + __gm__ float *mij_ptr, __gm__ float *lij_ptr, __gm__ float *oi_new_ptr, __gm__ float *mi_ptr, + __gm__ float *li_ptr, __gm__ float *oi_ptr, __gm__ float *dst_ptr, uint64_t is_first, uint64_t is_last +) { + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = TM / kScalarCols; + constexpr int kAlignedRows = ((TM * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, TN, 1>>; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 
1, 1>, Layout::DN>; + + GlobalDataMxN oiNewGlobal(oi_new_ptr); + GlobalDataMxN oiGlobal(oi_ptr); + GlobalDataMxN dstGlobal(dst_ptr); + + GlobalScalarND mijGlobalND(mij_ptr); + GlobalScalarND lijGlobalND(lij_ptr); + GlobalScalarND miGlobalND(mi_ptr); + GlobalScalarND liGlobalND(li_ptr); + + GlobalScalarDN mijGlobalDN(mij_ptr); + GlobalScalarDN lijGlobalDN(lij_ptr); + GlobalScalarDN liGlobalDN(li_ptr); + + using TileDataMxN = Tile; + using TileScalarND = + Tile; + using TileScalarDN = Tile; + + constexpr int kDataBytes = TM * TN * sizeof(float); + constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + TileDataMxN oiNewTile; + TileDataMxN oiTile; + TileScalarND mijND, lijND, miND, liND; + TileScalarND miNewND, alphaND, betaND, tmpND; + TileScalarDN alphaDN, betaDN, liDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijND, 2 * kDataBytes); + TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); + TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); + TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); + TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); + TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); + TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); + TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + + if (is_first) { + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); + TSTORE(liGlobalND, lijND); + TSTORE(oiGlobal, oiNewTile); + + if (is_last) { + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + 
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + TLOAD(miND, miGlobalND); + TLOAD(liND, liGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TMAX(miNewND, miND, mijND); + pipe_barrier(PIPE_V); + TSUB(alphaND, miND, miNewND); + pipe_barrier(PIPE_V); + TEXP(alphaND, alphaND); + pipe_barrier(PIPE_V); + TSUB(betaND, mijND, miNewND); + pipe_barrier(PIPE_V); + TEXP(betaND, betaND); + pipe_barrier(PIPE_V); + TMUL(liND, alphaND, liND); + pipe_barrier(PIPE_V); + TMUL(tmpND, betaND, lijND); + pipe_barrier(PIPE_V); + TADD(liND, liND, tmpND); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); + TSTORE(liGlobalND, liND); + TSTORE(mijGlobalND, alphaND); + TSTORE(lijGlobalND, betaND); + + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(alphaDN, mijGlobalDN); + TLOAD(betaDN, lijGlobalDN); + if (is_last) { + TLOAD(liDN, liGlobalDN); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + TROWEXPANDMUL(oiTile, oiTile, alphaDN); + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); + + if (is_last) { + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiTile); + } else { + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(oiGlobal, oiTile); + } + } + set_flag(PIPE_MTE3, PIPE_S, 
EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + // Safety check: if called with null tensor args (misrouted hub invocation), return. + if (args[0] == 0 || args[1] == 0 || args[2] == 0) { + return; + } + + __gm__ Tensor *mij_t = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *lij_t = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *oi_new_t = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *mi_acc_t = reinterpret_cast<__gm__ Tensor *>(args[3]); + __gm__ Tensor *li_acc_t = reinterpret_cast<__gm__ Tensor *>(args[4]); + __gm__ Tensor *oi_acc_t = reinterpret_cast<__gm__ Tensor *>(args[5]); + __gm__ Tensor *out_t = reinterpret_cast<__gm__ Tensor *>(args[6]); + uint64_t is_first = static_cast(args[7]); + uint64_t is_last = static_cast(args[8]); + int64_t num_heads = static_cast(args[9]); + int64_t head_dim = static_cast(args[10]); + int64_t q_loop = static_cast(args[11]); + + int32_t block_idx = get_block_idx(args); + int32_t sub_block_id = get_sub_block_id(args); // 0 = AIV0 (rows 0..7), 1 = AIV1 (rows 8..15) + int64_t batch_idx = block_idx / q_loop; + int64_t q_tile_idx = block_idx % q_loop; + + // Scalar layout: full QT=16 rows pack to kAlignedRowsFull=16 floats per block_idx; + // each AIV lane owns kAlignedRowsSub=8 contiguous floats inside that slab. 
+ constexpr int kAlignedRowsFull = ((QT * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + constexpr int kAlignedRowsSub = ((SUB_QT * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + int64_t row_offset = sub_block_id * SUB_QT; + + // Accumulator offsets (each AIV lane owns its own 8-row sub-slice within the block_idx slab) + int64_t scalar_offset = block_idx * kAlignedRowsFull + sub_block_id * kAlignedRowsSub; + int64_t data_offset = (block_idx * QT + row_offset) * head_dim; + + __gm__ float *mij_ptr = + reinterpret_cast<__gm__ float *>(mij_t->buffer.addr) + mij_t->start_offset + scalar_offset; + __gm__ float *lij_ptr = + reinterpret_cast<__gm__ float *>(lij_t->buffer.addr) + lij_t->start_offset + scalar_offset; + __gm__ float *oi_new_ptr = + reinterpret_cast<__gm__ float *>(oi_new_t->buffer.addr) + oi_new_t->start_offset + data_offset; + __gm__ float *mi_ptr = + reinterpret_cast<__gm__ float *>(mi_acc_t->buffer.addr) + mi_acc_t->start_offset + scalar_offset; + __gm__ float *li_ptr = + reinterpret_cast<__gm__ float *>(li_acc_t->buffer.addr) + li_acc_t->start_offset + scalar_offset; + __gm__ float *oi_ptr = + reinterpret_cast<__gm__ float *>(oi_acc_t->buffer.addr) + oi_acc_t->start_offset + data_offset; + + // Output offset: (batch_idx * num_heads + q_tile_idx * QT + row_offset, 0) + int64_t out_offset = (batch_idx * num_heads + q_tile_idx * QT + row_offset) * head_dim; + __gm__ float *dst_ptr = reinterpret_cast<__gm__ float *>(out_t->buffer.addr) + out_t->start_offset + out_offset; + + online_update_spmd(mij_ptr, lij_ptr, oi_new_ptr, mi_ptr, li_ptr, oi_ptr, dst_ptr, is_first, is_last); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 000000000..85c0187da --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp @@ 
/*
 * Copyright (c) PyPTO Contributors.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * -----------------------------------------------------------------------------------------------------------
 */
// SPMD Softmax Preparation Kernel (AIV) with partial block masking and
// dual-vector subvector split.
//
// SPMD block_idx encodes (batch_idx, q_tile_idx).
// The two AIV lanes in a cluster split the Q_TILE=16 rows 8/8 via
// get_sub_block_id(): AIV0 handles rows [0, 8), AIV1 handles rows [8, 16).
//
// Computes (per sub-slice of SUB_M=8 rows):
//   sij_masked = pad(sij, valid_len, -inf)
//   sij_scale  = sij_masked * scale
//   mij        = row_max(sij_scale)   -> (SUB_M, 1)
//   pij        = exp(sij_scale - mij) -> (SUB_M, N)
//   lij        = row_sum(pij)         -> (SUB_M, 1)
//
// Args:
//   args[0] = sij           Tensor* (spmd_blocks*Q_TILE, block_size) float32 [input]
//   args[1] = context_lens  Tensor* (batch,) int32
//   args[2] = pij           Tensor* (spmd_blocks*Q_TILE, block_size) bf16 [output]
//   args[3] = mij           Tensor* (spmd_blocks*Q_TILE,) float32 [output]
//   args[4] = lij           Tensor* (spmd_blocks*Q_TILE,) float32 [output]
//   args[5] = scale_value   scalar (as float bits in uint64)
//   args[6] = bn            scalar: current KV block index
//   args[7] = block_size    scalar
//   args[8] = q_loop        scalar
//
// NOTE(review): every angle-bracket template-argument list in this file was
// stripped in patch transit (the bare `#include` lines, `template`, the
// `GlobalTensor`/`Tile` aliases, `static_cast`, and the call-site template
// args). Restore them from the upstream file before compiling; the stripped
// spots are flagged below.

#include
#include

#include "tensor.h"

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

#include "intrinsic.h"

static constexpr int M = 16;     // Full Q tile rows (shared between both AIVs)
static constexpr int SUB_M = 8;  // Rows per AIV lane (M / 2)
static constexpr int N = 16;     // block_size

// Scale, mask, and softmax-prepare one TM x TN sub-tile starting at sij_addr,
// emitting pij (bf16 probabilities), mij (row maxima) and lij (row sums) back
// to global memory. valid_len is the number of valid columns; the remainder is
// padded in place (TFILLPAD_INPLACE on the dynamic-width alias) before the row
// reductions so masked columns cannot win the rowmax or contribute to rowsum.
// NOTE(review): template parameter list (presumably <int TM, int TN>) lost in transit.
template
static __aicore__ void softmax_prepare_spmd(
    __gm__ float *sij_addr, float scale_value, uint64_t valid_len, __gm__ bfloat16_t *pij_addr,
    __gm__ float *mij_addr, __gm__ float *lij_addr
) {
    // TM rows rounded up to a 32-byte multiple, expressed in float elements.
    constexpr int kAlignedRows = ((TM * sizeof(float) + 31) / 32) * (32 / sizeof(float));

    // NOTE(review): first template arguments of these GlobalTensor aliases were
    // stripped; only the digit-led Stride packs survived.
    using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, TN, 1>>;
    using GlobalDataMxN_bf16 = GlobalTensor, Stride<1, 1, 1, TN, 1>>;
    using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>;

    GlobalDataMxN sijGlobal(sij_addr);
    GlobalDataMxN_bf16 pijGlobal(pij_addr);
    GlobalScalarDN mijGlobal(mij_addr);
    GlobalScalarDN lijGlobal(lij_addr);

    // NOTE(review): all Tile<...> argument lists below were stripped in transit.
    using TileSijDyn = Tile;   // dynamic-width view of sij used for padding
    using TileSijPad =
        Tile;
    using TileVecMxN = Tile;
    using TileVecMxN_bf16 = Tile;
    using TileScalarDN = Tile;

    TileVecMxN sijTile;
    TileSijDyn sijDynTile(static_cast(valid_len));
    TileSijPad sijPadTile;
    TileVecMxN pijTile;
    TileVecMxN tmpTile;  // scratch for TROWMAX / TROWSUM
    TileScalarDN maxTile;
    TileScalarDN sumTile;
    TileVecMxN_bf16 pijBf16Tile;

    // UB layout: sijTile / sijDynTile / sijPadTile deliberately alias offset 0
    // (same data, different views); pij, tmp, max, sum, and the bf16 buffer are
    // packed behind it at fixed byte offsets.
    TASSIGN(sijTile, 0x0);
    TASSIGN(sijDynTile, 0x0);
    TASSIGN(sijPadTile, 0x0);
    TASSIGN(pijTile, TM * TN * sizeof(float));
    TASSIGN(tmpTile, 2 * TM * TN * sizeof(float));
    TASSIGN(maxTile, 3 * TM * TN * sizeof(float));
    TASSIGN(sumTile, 3 * TM * TN * sizeof(float) + kAlignedRows * sizeof(float));
    TASSIGN(pijBf16Tile, 3 * TM * TN * sizeof(float) + 2 * kAlignedRows * sizeof(float));

    TLOAD(sijTile, sijGlobal);
    // MTE2 -> V: the vector pipe must not touch sijTile before the copy-in lands.
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    // Pad columns >= valid_len in place before any reduction reads them.
    TFILLPAD_INPLACE(sijPadTile, sijDynTile);
    pipe_barrier(PIPE_V);

    TMULS(sijTile, sijTile, scale_value);
    pipe_barrier(PIPE_V);
    TROWMAX(maxTile, sijTile, tmpTile);
    pipe_barrier(PIPE_V);
    TROWEXPANDSUB(pijTile, sijTile, maxTile);  // pij = sij*scale - rowmax (row broadcast)
    pipe_barrier(PIPE_V);
    TEXP(pijTile, pijTile);
    // Round-trip f32 -> bf16 -> f32 before TROWSUM so lij is the row sum of
    // the quantized pij that is actually stored, keeping lij consistent with
    // the bf16 output consumed by the PV matmul.
    TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND);
    TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND);
    TROWSUM(sumTile, pijTile, tmpTile);

    // V -> MTE3: results must be final before copy-out begins.
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    TSTORE(mijGlobal, maxTile);
    TSTORE(lijGlobal, sumTile);
    TSTORE(pijGlobal, pijBf16Tile);

    // MTE3 -> S: hold the scalar pipe (kernel exit) until the stores drain.
    set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
}

// Kernel entry: decode SPMD coordinates, derive this (batch, bn) pair's valid
// KV length, offset into the flat scratch tensors for this AIV lane's 8-row
// sub-slice, and either run the tile op or emit neutral values.
extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
    __gm__ Tensor *sij_t = reinterpret_cast<__gm__ Tensor *>(args[0]);
    __gm__ Tensor *context_lens_t = reinterpret_cast<__gm__ Tensor *>(args[1]);
    __gm__ Tensor *pij_t = reinterpret_cast<__gm__ Tensor *>(args[2]);
    __gm__ Tensor *mij_t = reinterpret_cast<__gm__ Tensor *>(args[3]);
    __gm__ Tensor *lij_t = reinterpret_cast<__gm__ Tensor *>(args[4]);
    // args[5] carries the float's raw bits in a uint64.
    float scale_value = from_u64(static_cast(args[5]));
    int64_t bn = static_cast(args[6]);
    int64_t block_size = static_cast(args[7]);
    int64_t q_loop = static_cast(args[8]);

    int32_t block_idx = get_block_idx(args);
    int32_t sub_block_id = get_sub_block_id(args);  // 0 = AIV0 (rows 0..7), 1 = AIV1 (rows 8..15)
    int64_t batch_idx = block_idx / q_loop;

    // Compute valid_len for this block: how many columns of sij carry real KV
    // data for KV block bn, clamped to [0, block_size].
    __gm__ int32_t *ctx_ptr =
        reinterpret_cast<__gm__ int32_t *>(context_lens_t->buffer.addr) + context_lens_t->start_offset;
    int64_t cur_seq = static_cast(ctx_ptr[batch_idx]);
    int64_t remaining = cur_seq - bn * block_size;
    uint64_t valid_len;
    if (remaining <= 0) {
        valid_len = 0;
    } else if (remaining >= block_size) {
        valid_len = static_cast(block_size);
    } else {
        valid_len = static_cast(remaining);
    }

    // Row offset for this AIV lane within the block_idx's Q_TILE slice.
    int64_t row_offset = sub_block_id * SUB_M;

    // Pointers into this block's SUB_M-row sub-slice of the flat tensors.
    int64_t data_row_offset = block_idx * M + row_offset;
    __gm__ float *sij_addr =
        reinterpret_cast<__gm__ float *>(sij_t->buffer.addr) + sij_t->start_offset + data_row_offset * block_size;
    __gm__ bfloat16_t *pij_addr =
        reinterpret_cast<__gm__ bfloat16_t *>(pij_t->buffer.addr) + pij_t->start_offset + data_row_offset * block_size;

    // Scalar layout: the full M=16 rows pack to kAlignedRowsFull=16 floats per
    // block_idx; each AIV lane owns kAlignedRowsSub=8 contiguous floats inside
    // that slab. Must match the layout the online-update kernel reads.
    constexpr int kAlignedRowsFull = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float));
    constexpr int kAlignedRowsSub = ((SUB_M * sizeof(float) + 31) / 32) * (32 / sizeof(float));
    int64_t scalar_offset = block_idx * kAlignedRowsFull + sub_block_id * kAlignedRowsSub;
    __gm__ float *mij_addr =
        reinterpret_cast<__gm__ float *>(mij_t->buffer.addr) + mij_t->start_offset + scalar_offset;
    __gm__ float *lij_addr =
        reinterpret_cast<__gm__ float *>(lij_t->buffer.addr) + lij_t->start_offset + scalar_offset;

    if (valid_len == 0) {
        // No valid KV data — emit neutral values so online_update is a no-op:
        //   mij = -1e30 (very negative so beta = exp(mij - mi_new) ≈ 0)
        //   lij = 0     (no contribution to the normalizer)
        //   pij = 0     (no attention weight)
        for (int i = 0; i < kAlignedRowsSub; i++) {
            mij_addr[i] = -1e30f;
            lij_addr[i] = 0.0f;
        }
        for (int i = 0; i < SUB_M * static_cast(block_size); i++) {
            pij_addr[i] = static_cast(0.0f);
        }
        return;
    }

    // NOTE(review): explicit template args (presumably <SUB_M, N>) lost in transit.
    softmax_prepare_spmd(sij_addr, scale_value, valid_len, pij_addr, mij_addr, lij_addr);
}
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""SPMD Paged Attention kernel and orchestration configuration.

Parallelism is SPMD over ``block_num = batch * q_loop`` positions: every block
resolves its own (batch_idx, q_tile_idx) slice via ``get_block_idx()``.

Softmax and online-update run as MIX tasks (idle AIC hub + AIV0 + AIV1); the
two AIV lanes split the 16 query rows 8/8 using ``get_sub_block_id()``.

AIC kernels (matrix multiplication on the Cube unit):
    - aic_qk_matmul: Q @ K^T (SPMD across batch*q_loop)
    - aic_pv_matmul: P @ V (SPMD across batch*q_loop)
    - aic_hub: no-op that occupies the AIC slot of softmax/update MIX tasks

AIV kernels (vector operations):
    - aiv_softmax_prepare: scale, rowmax, exp, rowsum on an 8-row sub-tile
    - aiv_online_update: online softmax accumulation + normalization on an 8-row sub-tile
    - aiv_hub: no-op, used to allocate the persistent accumulators
"""

from pathlib import Path

from task_interface import ArgDirection as D  # pyright: ignore[reportAttributeAccessIssue]

# Directory that holds the kernel sources referenced below.
_KERNELS_ROOT = Path(__file__).parent

ORCHESTRATION = {
    "source": str(_KERNELS_ROOT / "orchestration" / "spmd_paged_attention_orch.cpp"),
    "function_name": "aicpu_orchestration_entry",
}


def _kernel(func_id, name, core_type, filename):
    """Build one kernel descriptor; sources live under the sub-directory named after the core type."""
    return {
        "func_id": func_id,
        "name": name,
        "source": str(_KERNELS_ROOT / core_type / filename),
        "core_type": core_type,
    }


KERNELS = [
    # AIC kernels (matrix multiplication using the Cube unit)
    _kernel(0, "SPMD_QK", "aic", "aic_qk_matmul.cpp"),
    _kernel(1, "SPMD_PV", "aic", "aic_pv_matmul.cpp"),
    _kernel(2, "AIC_HUB", "aic", "aic_hub.cpp"),
    # AIV kernels (vector operations)
    _kernel(3, "SPMD_SF", "aiv", "aiv_softmax_prepare.cpp"),
    _kernel(4, "SPMD_UP", "aiv", "aiv_online_update.cpp"),
    _kernel(5, "AIV_HUB", "aiv", "aiv_hub.cpp"),
]

RUNTIME_CONFIG = {
    "runtime": "tensormap_and_ringbuffer",
    "aicpu_thread_num": 4,
    "block_dim": 24,
}
 * Softmax and online-update are
 * submitted as MIX tasks (AIC hub + AIV0 + AIV1) so the two AIV lanes within
 * a cluster each process one half of the 16 query rows, using
 * get_sub_block_id() to pick their 8-row slice. This mirrors the AscendC
 * reference (paged_attention_antiquantkv.h) subvector partitioning strategy.
 *
 * Memory Layout:
 *   Query:        (batch, num_heads, head_dim)                      - bfloat16
 *   Key/Value:    (total_blocks, block_size, kv_head_num, head_dim) - bfloat16
 *   Block Table:  (batch, max_num_blocks_per_req)                   - int32
 *   Context Lens: (batch,)                                          - int32
 *   Output:       (batch, num_heads, head_dim)                      - float32
 *
 * Scratch layout (runtime-allocated, indexed by block_idx * Q_TILE):
 *   sij:     (spmd_blocks * Q_TILE, block_size) float32
 *   pij:     (spmd_blocks * Q_TILE, block_size) data_type
 *   oi_new:  (spmd_blocks * Q_TILE, head_dim)   float32
 *   mij/lij: (spmd_blocks * Q_TILE,)            float32
 *   oi_acc/mi_acc/li_acc: persistent accumulators across the bn loop
 */

// NOTE(review): angle-bracket template-argument lists (data_as<...>(),
// get_tensor_data<...>(), static_cast<...>) and the header names on the bare
// #include lines were stripped in patch transit; restore from upstream
// (<cinttypes> is needed at minimum for PRIu64).

#include
#include

#include

#include "pto_orchestration_api.h"

// func_ids must match the KERNELS table in kernel_config.py.
#define FUNC_QK_MATMUL 0
#define FUNC_PV_MATMUL 1
#define FUNC_AIC_HUB 2
#define FUNC_SOFTMAX_PREPARE 3
#define FUNC_ONLINE_UPDATE 4
#define FUNC_AIV_HUB 5

// Query rows processed per SPMD block (split 8/8 between the two AIV lanes).
static constexpr uint64_t Q_TILE = 16;

extern "C" {

// Tells the runtime how many host args to pass to aicpu_orchestration_entry:
// 6 tensors (query, key, value, block_table, context_lens, out) + 1 scalar (scale).
__attribute__((visibility("default"))) PTO2OrchestrationConfig
aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) {
    (void)orch_args;
    return PTO2OrchestrationConfig{
        .expected_arg_count = 7,
    };
}

// Orchestration entry: derive dimensions from the host tensors, allocate the
// persistent accumulators once via the AIV hub, then for each KV block submit
// the QK -> softmax-prepare -> PV -> online-update pipeline as SPMD tasks.
__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) {
    // query: shape=[batch, num_heads, head_dim]
    uint64_t batch = orch_args.tensor(0).shapes[0];
    uint64_t num_heads = orch_args.tensor(0).shapes[1];
    uint64_t head_dim = orch_args.tensor(0).shapes[2];
    DataType data_type = orch_args.tensor(0).dtype;

    // key_cache: shape=[total_blocks, block_size, kv_head_num, head_dim]
    uint64_t block_size = orch_args.tensor(1).shapes[1];

    // block_table: shape=[batch, max_num_blocks_per_req]
    uint64_t max_num_blocks_per_req = orch_args.tensor(3).shapes[1];

    // scale from the scalar arg (raw float bits; decoded on-device via from_u64).
    uint64_t scale_value = orch_args.scalar(0);

    // One SPMD block per (batch, 16-row query tile) pair.
    uint64_t q_loop = (num_heads + Q_TILE - 1) / Q_TILE;
    int16_t spmd_block_num = static_cast(batch * q_loop);

    LOG_INFO(
        "SPMD PA: batch=%" PRIu64 " heads=%" PRIu64 " hd=%" PRIu64 " bs=%" PRIu64 " q_loop=%" PRIu64 " blocks=%d",
        batch, num_heads, head_dim, block_size, q_loop, spmd_block_num
    );

    // Wrap host-provided tensors as flat 2-D views (rows x head_dim).
    void *query_ptr = orch_args.tensor(0).data_as();
    void *kc_ptr = orch_args.tensor(1).data_as();
    void *vc_ptr = orch_args.tensor(2).data_as();
    void *out_ptr = orch_args.tensor(5).data_as();

    uint64_t total_kv_blocks = orch_args.tensor(1).shapes[0];
    uint64_t kv_total_rows = total_kv_blocks * block_size;

    uint32_t query_shapes[2] = {static_cast(batch * num_heads), static_cast(head_dim)};
    uint32_t kv_shapes[2] = {static_cast(kv_total_rows), static_cast(head_dim)};
    uint32_t out_shapes[2] = {static_cast(batch * num_heads), static_cast(head_dim)};

    Tensor query = make_tensor_external(query_ptr, query_shapes, 2, data_type);
    Tensor key_cache = make_tensor_external(kc_ptr, kv_shapes, 2, data_type);
    Tensor value_cache = make_tensor_external(vc_ptr, kv_shapes, 2, data_type);
    Tensor out = make_tensor_external(out_ptr, out_shapes, 2, DataType::FLOAT32);

    uint32_t bt_shapes[2] = {static_cast(batch), static_cast(max_num_blocks_per_req)};
    Tensor block_table =
        make_tensor_external(orch_args.tensor(3).data_as(), bt_shapes, 2, DataType::INT32, false);
    uint32_t cl_shapes[1] = {static_cast(batch)};
    Tensor context_lens =
        make_tensor_external(orch_args.tensor(4).data_as(), cl_shapes, 1, DataType::INT32, false);

    // Find the max context_len: the KV block loop runs to the longest sequence;
    // shorter sequences are masked inside the kernels via context_lens.
    uint64_t max_ctx = 0;
    for (uint64_t b = 0; b < batch; b++) {
        uint32_t idx[1] = {static_cast(b)};
        uint64_t ctx = static_cast(get_tensor_data(context_lens, 1, idx));
        if (ctx > max_ctx) max_ctx = ctx;
    }
    uint64_t max_bn = (max_ctx + block_size - 1) / block_size;

    // Scratch tensor create infos (sized for all SPMD blocks; each block owns
    // the Q_TILE-row stripe at row block_idx * Q_TILE).
    uint32_t n_rows = static_cast(spmd_block_num) * static_cast(Q_TILE);
    uint32_t sij_shapes[2] = {n_rows, static_cast(block_size)};
    uint32_t pij_shapes[2] = {n_rows, static_cast(block_size)};
    uint32_t oi_new_shapes[2] = {n_rows, static_cast(head_dim)};
    uint32_t scalar_shapes[1] = {n_rows};

    TensorCreateInfo sij_ci(sij_shapes, 2, DataType::FLOAT32);
    TensorCreateInfo pij_ci(pij_shapes, 2, data_type);
    TensorCreateInfo oi_new_ci(oi_new_shapes, 2, DataType::FLOAT32);
    TensorCreateInfo mij_ci(scalar_shapes, 1, DataType::FLOAT32);
    TensorCreateInfo lij_ci(scalar_shapes, 1, DataType::FLOAT32);
    TensorCreateInfo acc_oi_ci(oi_new_shapes, 2, DataType::FLOAT32);
    TensorCreateInfo acc_mi_ci(scalar_shapes, 1, DataType::FLOAT32);
    TensorCreateInfo acc_li_ci(scalar_shapes, 1, DataType::FLOAT32);

    PTO2_SCOPE() {
        // Allocate persistent accumulators via the no-op AIV hub; they outlive
        // a single bn iteration so online-update can carry state across blocks.
        Arg hub_args;
        hub_args.add_output(acc_oi_ci);
        hub_args.add_output(acc_mi_ci);
        hub_args.add_output(acc_li_ci);
        TaskOutputTensors hub_outs = pto2_rt_submit_aiv_task(FUNC_AIV_HUB, hub_args);
        const Tensor &oi_acc = hub_outs.get_ref(0);
        const Tensor &mi_acc = hub_outs.get_ref(1);
        const Tensor &li_acc = hub_outs.get_ref(2);

        for (uint64_t bn = 0; bn < max_bn; bn++) {
            // is_first resets the accumulators; is_last triggers the final
            // normalization + write to `out` inside the update kernel.
            uint64_t is_first = (bn == 0) ? 1 : 0;
            uint64_t is_last = (bn == max_bn - 1) ? 1 : 0;

            // -- QK Matmul (AIC, SPMD) --
            Arg qk_args;
            qk_args.add_input(query);
            qk_args.add_input(key_cache);
            qk_args.add_input(block_table);
            qk_args.add_input(context_lens);
            qk_args.add_output(sij_ci);
            qk_args.add_scalar(static_cast(bn));
            qk_args.add_scalar(static_cast(num_heads));
            qk_args.add_scalar(static_cast(head_dim));
            qk_args.add_scalar(static_cast(block_size));
            qk_args.add_scalar(static_cast(max_num_blocks_per_req));
            qk_args.add_scalar(static_cast(q_loop));
            qk_args.launch_spec.set_block_num(spmd_block_num);
            TaskOutputTensors qk_outs = pto2_rt_submit_aic_task(FUNC_QK_MATMUL, qk_args);
            const Tensor &sij = qk_outs.get_ref(0);

            // -- Softmax Prepare (MIX: AIC hub + AIV0 + AIV1, SPMD) --
            // AIV0 processes rows 0..7, AIV1 processes rows 8..15 of the
            // Q_TILE slice, discriminated via get_sub_block_id() in-kernel.
            Arg sf_args;
            sf_args.add_input(sij);
            sf_args.add_input(context_lens);
            sf_args.add_output(pij_ci);
            sf_args.add_output(mij_ci);
            sf_args.add_output(lij_ci);
            sf_args.add_scalar(scale_value);
            sf_args.add_scalar(static_cast(bn));
            sf_args.add_scalar(static_cast(block_size));
            sf_args.add_scalar(static_cast(q_loop));
            sf_args.launch_spec.set_block_num(spmd_block_num);
            MixedKernels sf_mk;
            sf_mk.aic_kernel_id = FUNC_AIC_HUB;
            sf_mk.aiv0_kernel_id = FUNC_SOFTMAX_PREPARE;
            sf_mk.aiv1_kernel_id = FUNC_SOFTMAX_PREPARE;
            TaskOutputTensors sf_outs = pto2_rt_submit_task(sf_mk, sf_args);
            const Tensor &pij = sf_outs.get_ref(0);
            const Tensor &mij = sf_outs.get_ref(1);
            const Tensor &lij = sf_outs.get_ref(2);

            // -- PV Matmul (AIC, SPMD) --
            Arg pv_args;
            pv_args.add_input(pij);
            pv_args.add_input(value_cache);
            pv_args.add_input(block_table);
            pv_args.add_input(context_lens);
            pv_args.add_output(oi_new_ci);
            pv_args.add_scalar(static_cast(bn));
            pv_args.add_scalar(static_cast(num_heads));
            pv_args.add_scalar(static_cast(head_dim));
            pv_args.add_scalar(static_cast(block_size));
            pv_args.add_scalar(static_cast(max_num_blocks_per_req));
            pv_args.add_scalar(static_cast(q_loop));
            pv_args.launch_spec.set_block_num(spmd_block_num);
            TaskOutputTensors pv_outs = pto2_rt_submit_aic_task(FUNC_PV_MATMUL, pv_args);
            const Tensor &oi_new = pv_outs.get_ref(0);

            // -- Online Update (MIX: AIC hub + AIV0 + AIV1, SPMD) --
            // Row-independent online softmax update: AIV0 updates rows 0..7 of
            // the Q_TILE accumulator slice, AIV1 updates rows 8..15.
            Arg up_args;
            up_args.add_input(mij);
            up_args.add_input(lij);
            up_args.add_input(oi_new);
            up_args.add_inout(mi_acc);
            up_args.add_inout(li_acc);
            up_args.add_inout(oi_acc);
            up_args.add_inout(out);
            up_args.add_scalar(is_first);
            up_args.add_scalar(is_last);
            up_args.add_scalar(static_cast(num_heads));
            up_args.add_scalar(static_cast(head_dim));
            up_args.add_scalar(static_cast(q_loop));
            up_args.launch_spec.set_block_num(spmd_block_num);
            MixedKernels up_mk;
            up_mk.aic_kernel_id = FUNC_AIC_HUB;
            up_mk.aiv0_kernel_id = FUNC_ONLINE_UPDATE;
            up_mk.aiv1_kernel_id = FUNC_ONLINE_UPDATE;
            pto2_rt_submit_task(up_mk, up_args);
        }
    }

    LOG_INFO("SPMD PA: %" PRIu64 " KV iters x 4 tasks, blocks=%d", max_bn, static_cast(spmd_block_num));
}

} // extern "C"