-
Notifications
You must be signed in to change notification settings - Fork 21
Refactor: multi-batch q_padded and simplify online softmax in scope2/scope12 #93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| """Qwen3-32B decode Scope 2 — RoPE + KV cache update + grouped-query attention. | ||
| 1. K RoPE + cache write, V cache write, Q RoPE + pad | ||
| 2. QK matmul | ||
| 3. Softmax | ||
| 3. Softmax | ||
| 4. SV matmul | ||
| 5. Online-softmax accumulation + final normalisation | ||
|
|
||
|
|
@@ -64,6 +64,16 @@ def qwen3_scope2( | |
| v_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16], | ||
| attn_out: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]], | ||
| ) -> pl.Tensor[[batch, hidden], pl.BF16]: | ||
| # Padding q | ||
| all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16) | ||
| with pl.incore(): | ||
| for idx in pl.range(batch * total_q_groups): | ||
| all_q_padded = pl.assemble( | ||
| all_q_padded, | ||
| pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16), | ||
| [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0], | ||
| ) | ||
|
Comment on lines
+68
to
+75
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Populate all Q groups or enforce
Also applies to: 143-144, 151-151 🤖 Prompt for AI Agents |
||
|
|
||
| for b in pl.range(batch): | ||
| ctx_len = pl.tensor.read(seq_lens, [b]) | ||
| pos = ctx_len - 1 | ||
|
|
@@ -75,16 +85,6 @@ def qwen3_scope2( | |
| sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0]) | ||
| sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim]) | ||
|
|
||
| # Workaround | ||
| all_q_padded = pl.create_tensor([total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16) | ||
| with pl.incore(): | ||
| for gi in pl.range(total_q_groups): | ||
| all_q_padded = pl.assemble( | ||
| all_q_padded, | ||
| pl.cast(pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16), | ||
| [gi * Q_HEAD_PAD, 0], | ||
| ) | ||
|
|
||
| # Stage 1: K RoPE + cache update + V cache + Q RoPE + pad. | ||
| with pl.auto_incore(): | ||
| for ki in pl.parallel(0, num_kv_heads, chunk=8): | ||
|
|
@@ -140,56 +140,22 @@ def qwen3_scope2( | |
| ), | ||
| target_type=pl.BF16, | ||
| ) | ||
| all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0]) | ||
| all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [ki * Q_HEAD_PAD + qi, half_dim]) | ||
| all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0]) | ||
| all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim]) | ||
|
|
||
| attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) | ||
| for gi in pl.range(total_q_groups): | ||
| kvh = gi // q_groups | ||
| qg = gi - kvh * q_groups | ||
| q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH | ||
| q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [gi * Q_HEAD_PAD, 0]) | ||
| q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0]) | ||
|
|
||
| # Workaround | ||
| # Stage 2: QK matmul for all active sb blocks. | ||
| all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32) | ||
| all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16) | ||
| all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) | ||
| all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) | ||
| all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) | ||
| for sb0 in pl.range(0, ctx_blocks, SB_BATCH): | ||
| with pl.incore(): | ||
| for si in pl.range(SB_BATCH): | ||
| sb = sb0 + si | ||
| if sb < ctx_blocks: | ||
| all_raw_scores = pl.assemble( | ||
| all_raw_scores, | ||
| pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), | ||
| [sb * Q_HEAD_PAD, 0], | ||
| ) | ||
| all_exp_padded = pl.assemble( | ||
| all_exp_padded, | ||
| pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16), | ||
| [sb * Q_HEAD_PAD, 0], | ||
| ) | ||
| all_oi_tmp = pl.assemble( | ||
| all_oi_tmp, | ||
| pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), | ||
| [sb * Q_HEAD_PAD, 0], | ||
| ) | ||
| mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) | ||
| all_cur_mi = pl.assemble( | ||
| all_cur_mi, | ||
| pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]), | ||
| [sb * Q_HEAD_BATCH, 0], | ||
| ) | ||
| li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) | ||
| all_cur_li = pl.assemble( | ||
| all_cur_li, | ||
| pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]), | ||
| [sb * Q_HEAD_BATCH, 0], | ||
| ) | ||
|
|
||
| # Stage 2: QK matmul for all active sb blocks. | ||
| for sb0 in pl.range(0, ctx_blocks, SB_BATCH): | ||
| with pl.incore(): | ||
| for si in pl.range(SB_BATCH): | ||
|
|
@@ -251,38 +217,22 @@ def qwen3_scope2( | |
| oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) | ||
| all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| # Workaround | ||
| with pl.incore(): | ||
| oi = pl.full([Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0) | ||
| li_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) | ||
| li = pl.reshape(li_flat, [Q_HEAD_BATCH, 1]) | ||
| mi_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0) | ||
| mi = pl.reshape(mi_flat, [Q_HEAD_BATCH, 1]) | ||
|
|
||
| # Stage 5: online softmax accumulation for active sb blocks. | ||
| for sb0 in pl.range(0, ctx_blocks, SB_BATCH): | ||
| with pl.incore(): | ||
| for si in pl.range(SB_BATCH): | ||
| sb = sb0 + si | ||
| if sb < ctx_blocks: | ||
| oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0]) | ||
| cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) | ||
| cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) | ||
| if sb == 0: | ||
| oi = oi_tmp_valid | ||
| li = cur_li | ||
| mi = cur_mi | ||
| else: | ||
| mi_new = pl.maximum(mi, cur_mi) | ||
| alpha = pl.exp(pl.sub(mi, mi_new)) | ||
| beta = pl.exp(pl.sub(cur_mi, mi_new)) | ||
| li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) | ||
| oi = pl.add(pl.row_expand_mul(oi, alpha), | ||
| pl.row_expand_mul(oi_tmp_valid, beta)) | ||
| mi = mi_new | ||
|
|
||
| # Stage 6: normalise and write back one Q-head batch. | ||
| # Stage 5: online softmax accumulation and normalisation. | ||
| with pl.incore(): | ||
| oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0]) | ||
| mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0]) | ||
| li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0]) | ||
| for sb in pl.range(1, ctx_blocks): | ||
| oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0]) | ||
| cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) | ||
| cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) | ||
| mi_new = pl.maximum(mi, cur_mi) | ||
| alpha = pl.exp(pl.sub(mi, mi_new)) | ||
| beta = pl.exp(pl.sub(cur_mi, mi_new)) | ||
| li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) | ||
| oi = pl.add(pl.row_expand_mul(oi, alpha), | ||
| pl.row_expand_mul(oi_tmp_valid, beta)) | ||
| mi = mi_new | ||
| ctx = pl.row_expand_div(oi, li) | ||
| ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * head_dim]) | ||
| ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hoisted Q buffer still only fills one group per KV head.
This buffer is allocated and consumed as
`total_q_groups`, but the producer still writes rows as if `q_groups == 1`. With any non-default configuration where `q_per_kv > Q_HEAD_BATCH`, the later groups will read uninitialized `all_q_padded` rows. Please either materialize every `(ki, qg)` group here or add an explicit guard that only the single-group case is supported. Also applies to: 224-225, 232-232
🤖 Prompt for AI Agents