Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 28 additions & 59 deletions examples/models/qwen3/qwen3_32b_decode_scope12.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
3. Softmax
4. SV matmul
5. Online-softmax accumulation + final normalisation

Intermediate q_proj/k_proj/v_proj are FP32 GM tensors between the two scopes.
"""
from __future__ import annotations
Expand Down Expand Up @@ -92,20 +91,6 @@ def qwen3_scope12(
k_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32)
v_proj = pl.create_tensor([batch, kv_hidden], dtype=pl.FP32)

# Initialize intermediate tensors to zero so assemble generates inout.
for ob in pl.range(q_out_blocks):
q0 = ob * Q_OUT_CHUNK
with pl.incore():
zero_q = pl.full([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32, value=0.0)
q_proj = pl.assemble(q_proj, zero_q, [0, q0])
for ob in pl.range(kv_out_blocks):
kv0 = ob * KV_OUT_CHUNK
with pl.incore():
zero_k = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
zero_v = pl.full([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32, value=0.0)
k_proj = pl.assemble(k_proj, zero_k, [0, kv0])
v_proj = pl.assemble(v_proj, zero_v, [0, kv0])

# ── Scope 1: input RMSNorm + Q/K/V projection ──
for b0 in pl.range(0, batch, BATCH_TILE):
normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)
Expand Down Expand Up @@ -176,6 +161,16 @@ def qwen3_scope12(
v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i)
v_proj = pl.assemble(v_proj, v_acc, [b0, kv0])

# Padding q
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
for idx in pl.range(batch * total_q_groups):
all_q_padded = pl.assemble(
all_q_padded,
pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
)
Comment on lines +164 to +172
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

The hoisted Q buffer still only fills one group per KV head.

This buffer is allocated and consumed as total_q_groups, but the producer still writes rows as if q_groups == 1. With any non-default configuration where q_per_kv > Q_HEAD_BATCH, the later groups will read uninitialized all_q_padded rows. Please either materialize every (ki, qg) group here or add an explicit guard that only the single-group case is supported.

Also applies to: 224-225, 232-232

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/qwen3/qwen3_32b_decode_scope12.py` around lines 164 - 172,
The hoisted Q buffer all_q_padded is allocated for batch * total_q_groups groups
but the producer only writes the first group per KV head; either fully
materialize every (ki, qg) group or explicitly reject multi-group configs. Fix
by iterating/writing for every group: update the producer loop that uses
pl.range(batch * total_q_groups) / pl.assemble to cover all q groups (use
total_q_groups and q_per_kv or q_groups in the loop/index arithmetic so each
group's rows at idx * Q_HEAD_PAD + Q_HEAD_BATCH + group_offset are filled), or
if you prefer a simpler change add a guard in the function (raise/exit) when
total_q_groups != 1 (check total_q_groups, q_per_kv vs Q_HEAD_BATCH) and
document that only single-group is supported; ensure references to all_q_padded,
Q_HEAD_PAD, Q_HEAD_BATCH, total_q_groups, and the pl.assemble write are updated
accordingly.


# ── Scope 2: RoPE + KV cache update + grouped-query attention ──
for b in pl.range(batch):
ctx_len = pl.tensor.read(seq_lens, [b])
Expand All @@ -188,16 +183,6 @@ def qwen3_scope12(
sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0])
sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim])

# Workaround
all_q_padded = pl.create_tensor([total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
for gi in pl.range(total_q_groups):
all_q_padded = pl.assemble(
all_q_padded,
pl.cast(pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[gi * Q_HEAD_PAD, 0],
)

# Stage 1: K RoPE + cache update + V cache + Q RoPE + pad.
with pl.auto_incore():
for ki in pl.parallel(0, num_kv_heads, chunk=8):
Expand Down Expand Up @@ -236,15 +221,15 @@ def qwen3_scope12(
pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)),
target_type=pl.BF16,
)
all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0])
all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [ki * Q_HEAD_PAD + qi, half_dim])
all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0])
all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim])

attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16)
for gi in pl.range(total_q_groups):
kvh = gi // q_groups
qg = gi - kvh * q_groups
q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH
q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [gi * Q_HEAD_PAD, 0])
q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0])

# Workaround
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
Expand Down Expand Up @@ -347,38 +332,22 @@ def qwen3_scope12(
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])

# Workaround
with pl.incore():
oi = pl.full([Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0)
li_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
li = pl.reshape(li_flat, [Q_HEAD_BATCH, 1])
mi_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
mi = pl.reshape(mi_flat, [Q_HEAD_BATCH, 1])

# Stage 5: online softmax accumulation for active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.incore():
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0])
cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
if sb == 0:
oi = oi_tmp_valid
li = cur_li
mi = cur_mi
else:
mi_new = pl.maximum(mi, cur_mi)
alpha = pl.exp(pl.sub(mi, mi_new))
beta = pl.exp(pl.sub(cur_mi, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
oi = pl.add(pl.row_expand_mul(oi, alpha),
pl.row_expand_mul(oi_tmp_valid, beta))
mi = mi_new

# Stage 6: normalise and write back one Q-head batch.
# Stage 5: online softmax accumulation and normalisation.
with pl.incore():
oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0])
mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0])
li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0])
for sb in pl.range(1, ctx_blocks):
oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0])
cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
mi_new = pl.maximum(mi, cur_mi)
alpha = pl.exp(pl.sub(mi, mi_new))
beta = pl.exp(pl.sub(cur_mi, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
oi = pl.add(pl.row_expand_mul(oi, alpha),
pl.row_expand_mul(oi_tmp_valid, beta))
mi = mi_new
ctx = pl.row_expand_div(oi, li)
ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * head_dim])
ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16)
Expand Down
110 changes: 30 additions & 80 deletions examples/models/qwen3/qwen3_32b_decode_scope2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""Qwen3-32B decode Scope 2 — RoPE + KV cache update + grouped-query attention.
1. K RoPE + cache write, V cache write, Q RoPE + pad
2. QK matmul
3. Softmax
3. Softmax
4. SV matmul
5. Online-softmax accumulation + final normalisation

Expand Down Expand Up @@ -64,6 +64,16 @@ def qwen3_scope2(
v_cache: pl.Tensor[[cache_rows, head_dim], pl.BF16],
attn_out: pl.Out[pl.Tensor[[batch, hidden], pl.BF16]],
) -> pl.Tensor[[batch, hidden], pl.BF16]:
# Padding q
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
for idx in pl.range(batch * total_q_groups):
all_q_padded = pl.assemble(
all_q_padded,
pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
)
Comment on lines +68 to +75
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Populate all Q groups or enforce q_groups == 1.

all_q_padded is now sized and sliced by total_q_groups, but these writes still only populate the first group for each KV head (... + ki * Q_HEAD_PAD + qi). If num_heads // num_kv_heads > Q_HEAD_BATCH, later gi slices for qg > 0 will read untouched rows and produce wrong attention. Either add a qg loop here and write to (ki * q_groups + qg) * Q_HEAD_PAD + qi, or reject unsupported configs explicitly.

Also applies to: 143-144, 151-151

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/qwen3/qwen3_32b_decode_scope2.py` around lines 68 - 75, The
loop that assembles all_q_padded only writes the first Q_HEAD_BATCH rows per KV
group, so when q_groups (num_heads // num_kv_heads) > 1 later gi slices will
read uninitialized rows; update the assembly to either (A) add an inner loop
over qg and write into index (ki * q_groups + qg) * Q_HEAD_PAD + qi for every q
group (ensure use of variables all_q_padded, Q_HEAD_PAD, Q_HEAD_BATCH,
total_q_groups, ki, qi, qg), or (B) explicitly assert/reject configs where
q_groups != 1 (e.g., check num_heads // num_kv_heads == 1) and raise an error;
apply the same fix pattern to the other occurrences mentioned around the later
assembly uses (the blocks at the other referenced locations).


for b in pl.range(batch):
ctx_len = pl.tensor.read(seq_lens, [b])
pos = ctx_len - 1
Expand All @@ -75,16 +85,6 @@ def qwen3_scope2(
sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0])
sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim])

# Workaround
all_q_padded = pl.create_tensor([total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
for gi in pl.range(total_q_groups):
all_q_padded = pl.assemble(
all_q_padded,
pl.cast(pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[gi * Q_HEAD_PAD, 0],
)

# Stage 1: K RoPE + cache update + V cache + Q RoPE + pad.
with pl.auto_incore():
for ki in pl.parallel(0, num_kv_heads, chunk=8):
Expand Down Expand Up @@ -140,56 +140,22 @@ def qwen3_scope2(
),
target_type=pl.BF16,
)
all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0])
all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [ki * Q_HEAD_PAD + qi, half_dim])
all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0])
all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim])

attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16)
for gi in pl.range(total_q_groups):
kvh = gi // q_groups
qg = gi - kvh * q_groups
q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH
q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [gi * Q_HEAD_PAD, 0])
q_padded = pl.slice(all_q_padded, [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0])

# Workaround
# Stage 2: QK matmul for all active sb blocks.
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.incore():
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
all_raw_scores = pl.assemble(
all_raw_scores,
pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
all_exp_padded = pl.assemble(
all_exp_padded,
pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
[sb * Q_HEAD_PAD, 0],
)
all_oi_tmp = pl.assemble(
all_oi_tmp,
pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
[sb * Q_HEAD_PAD, 0],
)
mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_mi = pl.assemble(
all_cur_mi,
pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)
li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
all_cur_li = pl.assemble(
all_cur_li,
pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]),
[sb * Q_HEAD_BATCH, 0],
)

# Stage 2: QK matmul for all active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.incore():
for si in pl.range(SB_BATCH):
Expand Down Expand Up @@ -251,38 +217,22 @@ def qwen3_scope2(
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])

# Workaround
with pl.incore():
oi = pl.full([Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0)
li_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
li = pl.reshape(li_flat, [Q_HEAD_BATCH, 1])
mi_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
mi = pl.reshape(mi_flat, [Q_HEAD_BATCH, 1])

# Stage 5: online softmax accumulation for active sb blocks.
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
with pl.incore():
for si in pl.range(SB_BATCH):
sb = sb0 + si
if sb < ctx_blocks:
oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0])
cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
if sb == 0:
oi = oi_tmp_valid
li = cur_li
mi = cur_mi
else:
mi_new = pl.maximum(mi, cur_mi)
alpha = pl.exp(pl.sub(mi, mi_new))
beta = pl.exp(pl.sub(cur_mi, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
oi = pl.add(pl.row_expand_mul(oi, alpha),
pl.row_expand_mul(oi_tmp_valid, beta))
mi = mi_new

# Stage 6: normalise and write back one Q-head batch.
# Stage 5: online softmax accumulation and normalisation.
with pl.incore():
oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0])
mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0])
li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0])
for sb in pl.range(1, ctx_blocks):
oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0])
cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
mi_new = pl.maximum(mi, cur_mi)
alpha = pl.exp(pl.sub(mi, mi_new))
beta = pl.exp(pl.sub(cur_mi, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
oi = pl.add(pl.row_expand_mul(oi, alpha),
pl.row_expand_mul(oi_tmp_valid, beta))
mi = mi_new
ctx = pl.row_expand_div(oi, li)
ctx_flat = pl.reshape(ctx, [1, Q_HEAD_BATCH * head_dim])
ctx_flat_bf16 = pl.cast(ctx_flat, target_type=pl.BF16)
Expand Down
Loading