From 7525bb1e00d7368f38523c8897a97886a84808e1 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Thu, 20 Mar 2025 15:51:45 +0000 Subject: [PATCH] Add v2 test to paged_attention_decode --- python/perf-kernels/paged_attention_decode.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/perf-kernels/paged_attention_decode.py b/python/perf-kernels/paged_attention_decode.py index 63bf78e70b4a..d5973fde7eb5 100644 --- a/python/perf-kernels/paged_attention_decode.py +++ b/python/perf-kernels/paged_attention_decode.py @@ -382,8 +382,8 @@ def paged_attn_decode_v2(output: torch.Tensor, #[num_seqs, num_kv_heads*query_g _paged_attn_decode_v2_wo_dot_reduce_kernel[grid](output, exp_sums, max_logits, tmp_output, seq_lens, output.stride(0), output.stride(1), exp_sums.stride(0), exp_sums.stride(1), tmp_output.stride(0), tmp_output.stride(1), - tmp_output.stride(2), compute_type=compute_type, - HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2, + tmp_output.stride(2), HEAD_SZ=head_sz, + HEAD_SZ_POW2=head_sz_pow2, SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE, MAX_NUM_SEQ_PARTITIONS=int(max_num_partitions), MAX_NUM_SEQ_PARTITIONS_POW2=int(max_num_partitions_pow2)) @@ -412,8 +412,8 @@ def paged_attn_decode_v2(output: torch.Tensor, #[num_seqs, num_kv_heads*query_g _paged_attn_decode_v2_w_dot_reduce_kernel[grid]( output, exp_sums, max_logits, tmp_output, seq_lens, output.stride(0), output.stride(1), exp_sums.stride(0), exp_sums.stride(1), exp_sums.stride(2), tmp_output.stride(0), tmp_output.stride(1), tmp_output.stride(2), - tmp_output.stride(3), compute_type=compute_type, HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2, - QUERY_GRP_SZ=query_grp_sz, QUERY_GRP_SZ_POW2=query_grp_sz_pow2, SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE, + tmp_output.stride(3), HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2, QUERY_GRP_SZ=query_grp_sz, + QUERY_GRP_SZ_POW2=query_grp_sz_pow2, SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE, MAX_NUM_SEQ_PARTITIONS=int(max_num_partitions), MAX_NUM_SEQ_PARTITIONS_POW2=int(triton.next_power_of_2(max_num_partitions))) @@ -855,14 +855,18 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz] (8, 1, 1, 1, 1, 10), (16, 1, 1, 1, 1, 10), (64, 1, 1, 1, 1, 10), + (64, 1, 1, 1, 1, 8192), + (64, 1, 1, 1, 1, 10000), #H_Q and H_KV > 1 (1, 4, 4, 1, 1, 1), (1, 4, 4, 1, 1, 10), + (1, 4, 4, 1, 1, 10000), #Head_dim > 1 (1, 1, 1, 8, 1, 1), (1, 1, 1, 8, 1, 10), + (1, 1, 1, 8, 1, 10000), #H_Q and H_KV > 1 and Head_dim > 1 (1, 4, 4, 8, 1, 1), @@ -870,6 +874,7 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz] (4, 4, 4, 8, 1, 10), (16, 4, 4, 8, 1, 10), (32, 4, 4, 8, 1, 10), + (32, 4, 4, 8, 1, 10000), #H_Q and H_KV > 1 and Head_dim > 1 and KV_BLK_SZ > 1 (1, 1, 1, 1, 1, 1), @@ -891,12 +896,12 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz] #GQA Basic (1, 2, 1, 16, 16, 1), (1, 2, 1, 16, 16, 10), - - #GQA Basic + (1, 2, 1, 16, 16, 10000), (1, 4, 2, 16, 16, 1), (1, 4, 2, 16, 16, 10), (1, 4, 2, 128, 16, 1), (1, 4, 2, 128, 16, 10), + (1, 4, 2, 128, 16, 16384), (1, 6, 2, 128, 16, 1), (1, 6, 2, 128, 16, 10), (1, 6, 2, 128, 16, 16), @@ -906,9 +911,11 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz] (1, 6, 2, 128, 16, 56), (1, 6, 2, 128, 16, 64), (1, 6, 2, 128, 16, 128), + (1, 6, 2, 128, 16, 8192), (1, 8, 2, 128, 16, 128), (4, 8, 2, 128, 16, 128), (16, 8, 2, 128, 16, 128), + (16, 8, 2, 128, 16, 10000), (32, 8, 2, 128, 16, 128), (64, 8, 2, 128, 16, 200), ])