Refactor: Qwen3 decode with 3-scope architecture and TILELET rename#99
zhangqi-chen merged 1 commit into hw-native-sys:main
Conversation
bumble0918
commented
Apr 10, 2026
- qwen3_32b_decode.py: Refactored into 3 scopes for better incore
- Scope 1: Input RMSNorm + Q/K/V projection
- Scope 2: Attention (K RoPE + cache, QK matmul, softmax, SV matmul)
- Scope 3: Output projection, residual, RMSNorm, MLP
- Updated HIDDEN size from 5120 to 8192 (64 heads × 128 dim)
- Renamed qwen3_32b_decode_tilelet.py to qwen3_32b_decode_mixed.py for clearer TILELET-aware version naming
- Adjusted tiling constants for each scope
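Scope 2 above applies RoPE to the keys before caching. For readers unfamiliar with the step, here is a minimal plain-Python sketch of rotary position embedding; the interleaved-pair layout and default `theta` are assumptions for illustration, and the kernel may use a half-split layout instead.

```python
import math

def rope(vec, pos, theta=10000.0):
    # Rotate pairs (vec[2i], vec[2i+1]) by a position-dependent angle.
    # Pairing convention here is illustrative; layouts vary by kernel.
    d = len(vec)
    out = [0.0] * d
    for i in range(d // 2):
        freq = theta ** (-2.0 * i / d)   # per-pair rotation frequency
        angle = pos * freq
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x * c - y * s       # standard 2D rotation
        out[2 * i + 1] = x * s + y * c
    return out
```

At position 0 every angle is zero, so the rotation is the identity; norms are preserved at every position, which is why RoPE can be applied before or after caching without changing attention magnitudes.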
📝 Walkthrough

Restructured Qwen3 decode from a single-layer implementation into three explicit phases (Scope 1/2/3): Scope 1 performs input normalization and Q/K/V projections in FP32; Scope 2 implements grouped-query attention with KV cache updates and staged softmax; Scope 3 handles output projection, residual operations, and the MLP. Added a PyTorch golden reference implementation and updated function signatures with new parameters.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input Tensors
    participant S1 as Scope 1<br/>RMSNorm + Projections
    participant S2 as Scope 2<br/>Attention + KV Cache
    participant S3 as Scope 3<br/>Output + MLP + Residual
    participant Output as Output Tensors
    Input->>S1: raw activations
    S1->>S1: apply input RMSNorm
    S1->>S1: project to Q/K/V (FP32)
    S1->>S2: Q, K, V intermediates
    S2->>S2: update KV caches
    S2->>S2: apply RoPE to Q/K
    S2->>S2: grouped-query attention<br/>(staged softmax + accumulation)
    S2->>S3: attention output
    S3->>S3: output projection
    S3->>S3: add residual
    S3->>S3: post-attention RMSNorm
    S3->>S3: SwiGLU MLP
    S3->>S3: final residual assembly
    S3->>Output: decoded layer output
```
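Scopes 1 and 3 both apply RMSNorm. As a minimal scalar sketch of that operation (the standard RMSNorm definition; the kernel's tiled FP32 version is more involved):

```python
def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by weight.
    ms = sum(v * v for v in x) / len(x)   # mean of squares
    scale = (ms + eps) ** -0.5
    return [v * scale * w for v, w in zip(x, weight)]
```

Unlike LayerNorm, there is no mean subtraction or bias, which is one reason it tiles cheaply: each row needs only a single sum-of-squares reduction before the elementwise scale.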
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request refactors the Qwen3-32B single-layer decode forward implementation into three distinct scopes, optimizing tiling strategies and adding a PyTorch reference function for validation. Feedback focuses on improving code robustness and performance: specifically, reshaping the results of pl.row_max and pl.row_sum to maintain consistent 2D tensor dimensions, and replacing inefficient loops of small pl.assemble calls with pl.full for tensor initialization to reduce DMA overhead.
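The dimensionality point can be illustrated in plain Python: a row reduction that returns a `[rows, 1]` result composes with `[rows, cols]` operands without an extra reshape. This is a sketch of the shape convention the review asks for, not the `pl` API itself.

```python
def row_max(mat):
    # Returning [[m], ...] keeps the result 2D ([rows, 1]), so it can
    # line up row-for-row with the [rows, cols] scores matrix without
    # a separate reshape step -- the convention suggested for
    # pl.row_max / pl.row_sum in the review above.
    return [[max(row)] for row in mat]
```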
```python
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
cur_mi = pl.row_max(scores)
```
```python
exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi))
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
```
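The `cur_mi`/`cur_li` pair is the classic online-softmax recurrence: a running row max and a running sum of exponentials, rescaled whenever a new block raises the max. A plain-Python sketch for a single row of scores processed block by block (illustrative only; the kernel operates on row tiles and adds a bf16 round-trip):

```python
import math

def staged_softmax_stats(score_blocks):
    # Running (max, sum) over score blocks, one row at a time.
    mi = -math.inf   # running max (cur_mi analogue)
    li = 0.0         # running sum of exp(score - mi) (cur_li analogue)
    for block in score_blocks:
        new_mi = max(mi, max(block))
        # Rescale the old sum to the new max, then fold in this block.
        li = li * math.exp(mi - new_mi) + sum(math.exp(s - new_mi) for s in block)
        mi = new_mi
    return mi, li
```

The final `(mi, li)` match what a single softmax over the concatenated scores would produce, which is what makes blockwise attention numerically safe.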
```python
for idx in pl.range(batch * total_q_groups):
    all_q_padded = pl.assemble(
        all_q_padded,
        pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
        [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
    )
```
```python
for sb0 in pl.range(0, ctx_blocks, SB_BATCH):
    with pl.incore():
        for si in pl.range(SB_BATCH):
            sb = sb0 + si
            if sb < ctx_blocks:
                all_raw_scores = pl.assemble(
                    all_raw_scores,
                    pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0),
                    [sb * Q_HEAD_PAD, 0],
                )
                all_exp_padded = pl.assemble(
                    all_exp_padded,
                    pl.cast(pl.full([Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
                    [sb * Q_HEAD_PAD, 0],
                )
                all_oi_tmp = pl.assemble(
                    all_oi_tmp,
                    pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0),
                    [sb * Q_HEAD_PAD, 0],
                )
                mi_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
                all_cur_mi = pl.assemble(
                    all_cur_mi,
                    pl.reshape(mi_init_flat, [Q_HEAD_BATCH, 1]),
                    [sb * Q_HEAD_BATCH, 0],
                )
                li_init_flat = pl.full([1, Q_HEAD_BATCH], dtype=pl.FP32, value=0.0)
                all_cur_li = pl.assemble(
                    all_cur_li,
                    pl.reshape(li_init_flat, [Q_HEAD_BATCH, 1]),
                    [sb * Q_HEAD_BATCH, 0],
                )
```
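The review's pl.full-over-pl.assemble suggestion comes down to one large fill versus many small per-block writes. A plain-Python sketch of the two initialization patterns, where a list of lists stands in for the tile buffer (both produce the same zeroed buffer; on hardware, the second avoids per-block assemble/DMA round trips):

```python
def init_per_block(ctx_blocks, q_head_pad, seq_tile):
    # Pattern in the loop above: one small write per context block.
    buf = [[None] * seq_tile for _ in range(ctx_blocks * q_head_pad)]
    for sb in range(ctx_blocks):
        for r in range(q_head_pad):
            buf[sb * q_head_pad + r] = [0.0] * seq_tile
    return buf

def init_bulk(ctx_blocks, q_head_pad, seq_tile):
    # Suggested alternative: a single full-size zero fill of the
    # whole [ctx_blocks * q_head_pad, seq_tile] buffer.
    return [[0.0] * seq_tile for _ in range(ctx_blocks * q_head_pad)]
```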
Actionable comments posted: 1
🧹 Nitpick comments (3)
examples/models/qwen3/qwen3_32b_decode.py (3)
23-23: Consider using ASCII `x` instead of Unicode `×` in documentation.

The docstring uses the Unicode multiplication sign (`×`), which could cause issues in some editors or environments. This is flagged by static analysis (RUF002/RUF003).

Suggested fix:

```diff
- 1. Output projection: attn_out × wo
+ 1. Output projection: attn_out x wo
```

Also applies to line 382.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode.py` at line 23, Replace the Unicode multiplication sign in the docstrings with an ASCII 'x' so editors/static checkers don't flag it; locate the string "Output projection: attn_out × wo" (and the similar occurrence noted around line 382) in examples/models/qwen3/qwen3_32b_decode.py and change "×" to "x" (i.e., "attn_out x wo") while preserving surrounding text and formatting.
253-289: Clarify the "Workaround" comment purpose.

The comment on line 252 says "# Workaround" without explaining what it's working around. Consider adding a brief explanation for maintainability.

Suggested clarification:

```diff
- # Workaround
+ # Workaround: pre-allocate intermediate tensors for all context blocks
+ # to avoid dynamic tensor creation within the attention loop
  all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode.py` around lines 253 - 289, The lone "# Workaround" comment is too vague—update it to explain what bug or hardware/SDK limitation the loop over ctx_blocks and the in-core assembly of all_raw_scores, all_exp_padded, all_oi_tmp, all_cur_mi, and all_cur_li is addressing (e.g., avoiding OOM, alignment/padding requirements, or a compiler bug when assembling large tensors), and mention why mi_init_flat/li_init_flat are reshaped into Q_HEAD_BATCH×1; place the clarified comment immediately above the for sb0 in pl.range(...) loop and include the specific constraint or behavior being worked around plus any links or issue IDs if available.
72-79: Integer division assumes exact divisibility without validation.

The tiling calculations silently truncate if dimensions aren't evenly divisible by chunk sizes. With default constants this works, but custom configurations could produce incorrect results.

Consider adding assertions for non-default configurations:

```diff
  hidden = hidden_size
  kv_hidden = num_kv_heads * head_dim
  inter = intermediate_size
+ assert hidden % K_CHUNK == 0, f"hidden ({hidden}) must be divisible by K_CHUNK ({K_CHUNK})"
+ assert hidden % Q_OUT_CHUNK == 0, f"hidden ({hidden}) must be divisible by Q_OUT_CHUNK ({Q_OUT_CHUNK})"
+ assert kv_hidden % KV_OUT_CHUNK == 0, f"kv_hidden ({kv_hidden}) must be divisible by KV_OUT_CHUNK ({KV_OUT_CHUNK})"
+ assert inter % MLP_OUT_CHUNK == 0, f"inter ({inter}) must be divisible by MLP_OUT_CHUNK ({MLP_OUT_CHUNK})"
  hidden_blocks = hidden // K_CHUNK
  q_out_blocks = hidden // Q_OUT_CHUNK
  kv_out_blocks = kv_hidden // KV_OUT_CHUNK
  mlp_out_blocks = inter // MLP_OUT_CHUNK
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/models/qwen3/qwen3_32b_decode.py` around lines 72 - 79, The tiling variables hidden_blocks, q_out_blocks, kv_out_blocks, mlp_out_blocks, cache_rows, half_dim, q_per_kv, and q_groups perform integer division without validating divisibility; add explicit assertions (or raise ValueError) to verify that hidden % K_CHUNK == 0, hidden % Q_OUT_CHUNK == 0, kv_hidden % KV_OUT_CHUNK == 0, inter % MLP_OUT_CHUNK == 0, head_dim % 2 == 0, num_heads % num_kv_heads == 0, and q_per_kv % Q_HEAD_BATCH == 0 (using the same variable names from the diff) so misconfigured shapes fail fast with a clear error message describing which dimension/chunk pair is invalid.
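A fail-fast helper along the lines the comment suggests could look like this; plain Python, and the keyword names in the usage line are illustrative, not the module's actual constants:

```python
def check_divisible(**pairs):
    # Each keyword maps a "dim/chunk" label to a (dim, chunk) tuple;
    # raise immediately on the first pair that doesn't divide evenly.
    for name, (dim, chunk) in pairs.items():
        if dim % chunk != 0:
            raise ValueError(f"{name}: {dim} is not divisible by {chunk}")

# Example (hypothetical values): check_divisible(hidden_over_K_CHUNK=(8192, 256))
```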
📒 Files selected for processing (2)

- examples/models/qwen3/qwen3_32b_decode.py
- examples/models/qwen3/qwen3_32b_decode_mixed.py
```python
hidden = num_heads * head_dim
kv_hidden = num_kv_heads * head_dim
cache_rows = batch * num_kv_heads * max_seq_len
inter = intermediate_size
cache_rows = batch * num_kv_heads * max_seq
```
Unused variable hidden and inconsistent usage with hidden_size.
Line 493 defines hidden = num_heads * head_dim which shadows the intent but is mostly unused. Line 575 uses hidden for the output tensor shape, while all other tensor specs use hidden_size. This inconsistency could cause bugs if hidden_size ever differs from num_heads * head_dim.
Suggested fix for consistency:

```diff
- hidden = num_heads * head_dim
  kv_hidden = num_kv_heads * head_dim
  inter = intermediate_size
  cache_rows = batch * num_kv_heads * max_seq
```

And at line 575:

```diff
- TensorSpec("out", [batch, hidden], torch.bfloat16, is_output=True),
+ TensorSpec("out", [batch, hidden_size], torch.bfloat16, is_output=True),
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
kv_hidden = num_kv_heads * head_dim
inter = intermediate_size
cache_rows = batch * num_kv_heads * max_seq
```
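Using the dimensions given in the PR description (HIDDEN updated from 5120 to 64 × 128 = 8192), a quick check shows why the alias matters: `num_heads * head_dim` and the model's hidden size genuinely differ here, so `hidden` and `hidden_size` are not interchangeable.

```python
# Dimensions taken from the PR description above.
num_heads, head_dim = 64, 128
hidden_size = 5120                    # model hidden size before this PR's update

assert num_heads * head_dim == 8192   # the projection width
assert num_heads * head_dim != hidden_size  # aliasing `hidden` to either is ambiguous
```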
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/models/qwen3/qwen3_32b_decode.py` around lines 493 - 496, The code
defines a local variable hidden = num_heads * head_dim but other tensor specs
use hidden_size, causing inconsistency and a potentially unused shadow variable;
change the code so the output tensor shape and any uses of hidden reference
hidden_size (or set hidden = hidden_size if you prefer a local alias) and remove
any other unused definition of hidden or replace its usage with hidden_size to
ensure all tensor shapes (including the output tensor created near the current
output-shape code) consistently use hidden_size rather than recomputing
num_heads * head_dim.
Force-pushed from f8454ae to 3ee28c7