Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions python/perf-kernels/paged_attention_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def paged_attn_decode_v2(output: torch.Tensor, #[num_seqs, num_kv_heads*query_g
_paged_attn_decode_v2_wo_dot_reduce_kernel[grid](output, exp_sums, max_logits, tmp_output, seq_lens,
output.stride(0), output.stride(1), exp_sums.stride(0),
exp_sums.stride(1), tmp_output.stride(0), tmp_output.stride(1),
tmp_output.stride(2), compute_type=compute_type,
HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2,
tmp_output.stride(2), HEAD_SZ=head_sz,
HEAD_SZ_POW2=head_sz_pow2,
SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE,
MAX_NUM_SEQ_PARTITIONS=int(max_num_partitions),
MAX_NUM_SEQ_PARTITIONS_POW2=int(max_num_partitions_pow2))
Expand Down Expand Up @@ -412,8 +412,8 @@ def paged_attn_decode_v2(output: torch.Tensor, #[num_seqs, num_kv_heads*query_g
_paged_attn_decode_v2_w_dot_reduce_kernel[grid](
output, exp_sums, max_logits, tmp_output, seq_lens, output.stride(0), output.stride(1), exp_sums.stride(0),
exp_sums.stride(1), exp_sums.stride(2), tmp_output.stride(0), tmp_output.stride(1), tmp_output.stride(2),
tmp_output.stride(3), compute_type=compute_type, HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2,
QUERY_GRP_SZ=query_grp_sz, QUERY_GRP_SZ_POW2=query_grp_sz_pow2, SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE,
tmp_output.stride(3), HEAD_SZ=head_sz, HEAD_SZ_POW2=head_sz_pow2, QUERY_GRP_SZ=query_grp_sz,
QUERY_GRP_SZ_POW2=query_grp_sz_pow2, SEQ_PARTITION_SZ=_SEQ_PARTITION_SIZE,
MAX_NUM_SEQ_PARTITIONS=int(max_num_partitions),
MAX_NUM_SEQ_PARTITIONS_POW2=int(triton.next_power_of_2(max_num_partitions)))

Expand Down Expand Up @@ -855,21 +855,26 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz]
(8, 1, 1, 1, 1, 10),
(16, 1, 1, 1, 1, 10),
(64, 1, 1, 1, 1, 10),
(64, 1, 1, 1, 1, 8192),
(64, 1, 1, 1, 1, 10000),

#H_Q and H_KV > 1
(1, 4, 4, 1, 1, 1),
(1, 4, 4, 1, 1, 10),
(1, 4, 4, 1, 1, 10000),

#Head_dim > 1
(1, 1, 1, 8, 1, 1),
(1, 1, 1, 8, 1, 10),
(1, 1, 1, 8, 1, 10000),

#H_Q and H_KV > 1 and Head_dim > 1
(1, 4, 4, 8, 1, 1),
(1, 4, 4, 8, 1, 10),
(4, 4, 4, 8, 1, 10),
(16, 4, 4, 8, 1, 10),
(32, 4, 4, 8, 1, 10),
(32, 4, 4, 8, 1, 10000),

#H_Q and H_KV > 1 and Head_dim > 1 and KV_BLK_SZ > 1
(1, 1, 1, 1, 1, 1),
Expand All @@ -891,12 +896,12 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz]
#GQA Basic
(1, 2, 1, 16, 16, 1),
(1, 2, 1, 16, 16, 10),

#GQA Basic
(1, 2, 1, 16, 16, 10000),
(1, 4, 2, 16, 16, 1),
(1, 4, 2, 16, 16, 10),
(1, 4, 2, 128, 16, 1),
(1, 4, 2, 128, 16, 10),
(1, 4, 2, 128, 16, 16384),
(1, 6, 2, 128, 16, 1),
(1, 6, 2, 128, 16, 10),
(1, 6, 2, 128, 16, 16),
Expand All @@ -906,9 +911,11 @@ def paged_attention_decode_ref(output, #[num_seqs, num_q_heads, head_sz]
(1, 6, 2, 128, 16, 56),
(1, 6, 2, 128, 16, 64),
(1, 6, 2, 128, 16, 128),
(1, 6, 2, 128, 16, 8192),
(1, 8, 2, 128, 16, 128),
(4, 8, 2, 128, 16, 128),
(16, 8, 2, 128, 16, 128),
(16, 8, 2, 128, 16, 10000),
(32, 8, 2, 128, 16, 128),
(64, 8, 2, 128, 16, 200),
])
Expand Down