
Commit 4aba28e

Author: Dave Christopher Hong

change qwen3_32b_training_forward_and_backward.py to support pypto

1 parent 5361894 · commit 4aba28e
2 files changed: 487 additions & 407 deletions

examples-lib/qwen3/SUMMARY_QWEN3CHANGES.txt (90 additions & 37 deletions)
@@ -146,49 +146,101 @@ File mapping:
 ================================================================================
 
 8.1 Constants (scaled down for testing)
-- BATCH: 64 → 1.
-- MAX_SEQ: 4096 → 4.
-- HIDDEN: 5120 → 80.
-- INTERMEDIATE: 25600 → 400.
-- K_CHUNK: 64 → 4.
-- Q_OUT_CHUNK: 128 → 8.
-- MLP_OUT_CHUNK: 256 → 16.
-
-8.2 Scratch buffers
-- New: muon_scratch [MLP_OUT_CHUNK, MLP_OUT_CHUNK] BF16.
-- New: proxy_scratch [TOK_TILE, MLP_OUT_CHUNK] BF16.
-- New: btrans_scratch [TOK_TILE, K_CHUNK] FP32.
-  (Used for staging data for matmul transpose patterns.)
-
-8.3 Forward pass
-- Input RMSNorm: reshape-after-slice pattern for 3D tensors.
-- Attention scores: k_c staged via btrans_scratch (assemble/slice) before
-  matmul with b_trans=True.
-- Attention context: explicit ctx_acc zeroed + add(ctx_acc, matmul(...))
-  instead of direct fused matmul.
+- BATCH: 64 → 1 (64//64).
+- MAX_SEQ: 4096 → 4 (4096//1024).
+- HIDDEN: 5120 → 80 (5120//64).
+- INTERMEDIATE: 25600 → 400 (25600//64).
+- K_CHUNK: 64 → 4 (64//16).
+- Q_OUT_CHUNK: 128 → 8 (128//16).
+- MLP_OUT_CHUNK: 256 → 16 (256//16).
+
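For reference, the scaled constants restated as plain Python (an illustrative restatement of the bullets above, not a quote from the file; the //-factors are the ones annotated in each bullet):

    # Scaled-down test constants, restated from the summary bullets.
    BATCH         = 64    // 64    # 1
    MAX_SEQ       = 4096  // 1024  # 4
    HIDDEN        = 5120  // 64    # 80
    INTERMEDIATE  = 25600 // 64    # 400
    K_CHUNK       = 64    // 16    # 4
    Q_OUT_CHUNK   = 128   // 16    # 8
    MLP_OUT_CHUNK = 256   // 16    # 16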
+8.2 Top-level tensor allocation
+- loss_acc [TOK_TILE, 1] FP32 created outside all incore scopes (was
+  previously inside the single auto_incore).
+- muon_buf [MLP_OUT_CHUNK, MLP_OUT_CHUNK] FP32 created at top level as a
+  staging buffer for Newton-Schulz iterations (written by assemble,
+  read back by slice in separate incore scopes to force a memory
+  round-trip).
+- Old scratch buffers (muon_scratch, proxy_scratch, btrans_scratch) removed.
+
+8.3 Incore scope structure (major restructuring)
+  Original: one monolithic pl.auto_incore() around the entire function body.
+  New: split into many separate incore scopes to stay within the Vec buffer
+  limit (253952 bytes):
+  (a) Gradient zeroing: two pl.incore() blocks — one for the small grad
+      tensors (wq/wk/wv/wo), one for the large grads (w_gate/w_up/w_down)
+      with a chunked MLP_OUT_BLOCKS loop.
+  (b) Forward + backward: a per-token pl.auto_incore() inside the batch/
+      position loop body.
+  (c) Weight gradient stages: per-block pl.auto_incore() (see 8.7).
+  (d) Loss extraction: a separate pl.incore() at the end.
+
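For orientation, the new scope layout as a structural Python sketch. This is a hypothetical sketch only: it assumes pl.incore()/pl.auto_incore() act as context managers (the summary does not say) and invents loop-bound names such as Q_OUT_BLOCKS:

    # Hypothetical layout sketch -- pl.* usage and loop bounds are
    # assumptions based on the summary, not verified PyPTO API.
    with pl.incore():                      # (a) zero small grads (wq/wk/wv/wo)
        ...
    with pl.incore():                      # (a) zero large grads, chunked
        for blk in range(MLP_OUT_BLOCKS):  #     over MLP_OUT_BLOCKS
            ...
    for b in range(BATCH):                 # (b) forward + backward,
        for pos in range(MAX_SEQ):         #     one scope per token
            with pl.auto_incore():
                ...
    for blk in range(Q_OUT_BLOCKS):        # (c) weight-gradient stages (8.7)
        with pl.auto_incore():
            ...
    with pl.incore():                      # (d) loss extraction
        ...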
+8.4 Forward pass
+- 3D→2D slicing: the pl.reshape(pl.slice(tensor, [1, TOK, CHUNK], ...),
+  [TOK, CHUNK]) pattern is used for hidden_states and target_states.
+- Attention scores: k_c loaded directly from k_proj_tile and used with
+  b_trans=True in matmul (no staging buffer needed).
+- Attention context: explicit ctx_acc zeroed via pl.mul(ctx_acc, 0.0),
+  then accumulated via pl.add(ctx_acc, matmul(...)).
 - O projection residual: reshape-after-slice for hidden_states.
 
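A plain-NumPy analogue of the reshape-after-slice pattern (illustrative only; NumPy indexing stands in for the pl.slice/pl.reshape calls quoted above, and the shapes use the scaled test constants):

    import numpy as np

    TOK, CHUNK = 4, 80                     # MAX_SEQ, HIDDEN at test scale
    hidden_states = np.zeros((1, TOK, CHUNK), dtype=np.float32)

    # Slice a [1, TOK, CHUNK] slab out of the 3D tensor...
    slab = hidden_states[0:1, :, :]
    # ...then reshape it to the 2D [TOK, CHUNK] tile the matmuls expect.
    tile_2d = slab.reshape(TOK, CHUNK)
    assert tile_2d.shape == (TOK, CHUNK)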
-8.4 Loss accumulation
+8.5 Loss accumulation
 - Old: per-token loop with tensor.read and [1,1] scalar tensors.
 - New: vector add with [TOK_TILE, 1] accumulator (loss_prev + sq_row).
+- loss_out changed from [1] to [TOK_TILE, 1] to match loss_acc directly,
+  avoiding a layout mismatch (nd vs dn) during tile.store.
 
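The shift from scalar to vector accumulation, in NumPy terms (an illustrative sketch; sq_row here is a stand-in for the per-token squared-error column):

    import numpy as np

    TOK_TILE = 4
    sq_row = np.arange(TOK_TILE, dtype=np.float32).reshape(TOK_TILE, 1)

    # Old style: one [1,1] scalar read/add per token.
    loss_scalar = 0.0
    for t in range(TOK_TILE):
        loss_scalar += float(sq_row[t, 0])

    # New style: a single vector add into a [TOK_TILE, 1] accumulator.
    # loss_out shares this shape, so the store needs no layout change.
    loss_acc = np.zeros((TOK_TILE, 1), dtype=np.float32)
    loss_acc = loss_acc + sq_row           # loss_prev + sq_row
    assert np.isclose(loss_acc.sum(), loss_scalar)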
-8.5 Backward pass
+8.6 Backward pass
 - MLP backward: d_mlp cast to BF16 before matmul with w_down chunk.
 - Gate/up gradients: d_gate_bf16 / d_up_bf16 intermediate BF16 casts,
   sequential add pattern instead of fused add(add(matmul, matmul)).
-- Attention backward: v_c renamed to v_bwd.
-
-8.6 Weight gradients + Muon optimizer (major rewrite)
-- proxy_scratch used to stage proxy tensors for a_trans matmul patterns.
-- Stage 1 (w_down): tiled gram matrix computation via muon_scratch
-  (MLP_OUT_CHUNK//TOK_TILE iterations, slice/transpose/matmul per tile).
-- Stage 2 (wo/wq/wk/wv): proxy_ctx and proxy_n staged via proxy_scratch;
-  tiled gram with K_CHUNK//TOK_TILE iterations.
-- Stage 3 (w_gate/w_up): different NS formulation — builds ns_acc via
-  matmul(muon_bf, transpose(tile)) then matmul(tmp_bf, tile) instead of
-  computing gram then muon @ gram.
-
-8.7 Backend
+- Attention backward: q_c/k_c renamed to q_bwd/k_bwd to avoid type
+  reassignment with forward-pass variables of different shapes.
+
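The "sequential add pattern" above, structurally (a sketch with hypothetical operand names a0/b0/a1/b1; pl.add and matmul are the operations the summary names, their exact signatures are not shown in this diff):

    # Instead of one fused expression: add(add(matmul(a0, b0), matmul(a1, b1)))
    acc = matmul(a0, b0)                  # first matmul result
    acc = pl.add(acc, matmul(a1, b1))     # then accumulate the second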
+8.7 Weight gradients + Muon optimizer (major rewrite)
+  Three stages, each processing one weight block per iteration, with
+  the outer loop OUTSIDE the incore scope:
+
+  Stage 1 (grad_w_down): for each Q_OUT block:
+  - One auto_incore: compute the proxy gradient (a_trans=True matmul of
+    proxy_mlp × proxy_go), momentum update, assemble into muon_buf.
+  - Newton-Schulz loop (MUON_NS_STEPS iterations, each in its own
+    auto_incore): slice the Muon iterate from muon_buf in memory, compute
+    the Gram matrix G' = X@X^T via b_trans=True, update X ← 1.5·X − 0.5·G'@X,
+    assemble the result back to muon_buf.
+  - One auto_incore: extract the final iterate, apply the learning rate,
+    assemble into grad_w_down.
+
+  Stage 2 (grad_wo/wq/wk/wv): same pattern per Q_OUT block, with each
+  weight (wo, wq, wk, wv) having its own gradient+momentum+NS sequence.
+
+  Stage 3 (grad_w_gate/w_up): same pattern per MLP_OUT block.
+
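The three-phase staging pattern of Stage 1 above, as a hedged sketch (the pl.*/matmul call shapes, the slice/assemble arguments, and names like x0 and lr are assumptions; only operations named in this summary are used):

    # Hypothetical Stage-1 sketch -- signatures inferred, not verified API.
    with pl.auto_incore():                 # proxy gradient + momentum update
        pl.assemble(muon_buf, x0)          # seed the NS iterate X in memory

    for _ in range(MUON_NS_STEPS):         # loop OUTSIDE the incore scope
        with pl.auto_incore():
            x = pl.slice(muon_buf, ...)    # load the iterate from memory
            g = matmul(x, x, b_trans=True) # G' = X @ X^T (transposed load)
            # X <- 1.5*X - 0.5*G'@X
            x = pl.add(pl.mul(x, 1.5), pl.mul(matmul(g, x), -0.5))
            pl.assemble(muon_buf, x)       # write back: memory round-trip

    with pl.auto_incore():                 # extract final iterate, apply lr
        pl.assemble(grad_w_down, pl.mul(pl.slice(muon_buf, ...), lr))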
+  Key design decisions for the Muon Newton-Schulz implementation:
+  (a) Gram matrix reformulated: uses b_trans=True (G' = X@X^T) instead of
+      a_trans=True (G = X^T@X). The update becomes G'@X instead of X@G.
+      Mathematically equivalent by associativity: X@(X^T@X) = (X@X^T)@X.
+  (b) The b_trans=True formulation generates tile.load(transpose=True)
+      from memory to Mem.Mat, instead of tile.transpose in Mem.Vec, which
+      triggers a codegen bug (matmul K-dimension mismatch) during PTO
+      code generation for non-square operands.
+  (c) The NS loop is placed outside auto_incore so each step's matmul
+      operands come from pl.slice of muon_buf (a tensor in memory), not
+      from computed Vec tiles. This forces the converter to use transposed
+      memory loads rather than Vec transposes.
+  (d) muon_buf serves as the staging tensor — written via pl.assemble at
+      the end of each NS step, read back via pl.slice at the start of the
+      next step. The loop boundary prevents the optimizer from eliminating
+      the memory round-trip.
+
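A NumPy check of decision (a) (illustrative; the tile shape here is arbitrary and deliberately non-square to show the identity does not depend on squareness):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((8, 16))       # one Muon iterate tile

    upd_a = X @ (X.T @ X)                  # a_trans form: X @ G,  G  = X^T@X
    upd_b = (X @ X.T) @ X                  # b_trans form: G'@ X,  G' = X@X^T
    assert np.allclose(upd_a, upd_b)       # equal by associativity

    # One Newton-Schulz step as written above: X <- 1.5*X - 0.5*G'@X
    X_next = 1.5 * X - 0.5 * upd_b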
+8.8 Variable naming
+- Unique variable names per stage to avoid PyPTO type-reassignment errors:
+  mu_s1/gram_s1/ns_update_s1 (stage 1), mu_wo/gram_wo/ns_upd_wo (stage 2
+  wo), mu_wq/gram_wq/ns_upd_wq (stage 2 wq), etc.
+- proxy_tgt_q/proxy_tgt_k/proxy_tgt_v, proxy_n_k/proxy_n_v, and
+  proxy_post_g/proxy_post_u avoid reusing names across different shapes.
+
+8.9 Backend
 - BackendType.CCE → BackendType.Ascend950.
 - save_kernels / save_kernels_dir added to RunConfig.
 
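A hypothetical sketch of the resulting configuration (the field names come from the summary; the RunConfig constructor signature and any directory value are assumptions, not shown in this diff):

    cfg = RunConfig(
        backend=BackendType.Ascend950,   # was BackendType.CCE
        save_kernels=True,               # new flag per the summary
        save_kernels_dir="...",          # new field; value not shown here
    )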
@@ -206,8 +258,9 @@ prefill_tilelet → _new | RoPE: concat → create_tensor+assemble
 decode → _new              | Same as qwen3-32b (grouped Q, staged incore, etc.)
 decode_scope2 → _new       | Accumulator init: full → create_tensor+mul
 decode_tilelet → _new      | Accumulator init: full → create_tensor+mul
-training_fwd_bwd → _new    | Scaled constants, scratch buffers, reshape-after-
-                           | slice, tiled Muon NS, vector loss, Ascend950
+training_fwd_bwd → _new    | Scaled constants, multi-scope incore, reshape-after-
+                           | slice, Muon NS with muon_buf staging & b_trans Gram,
+                           | vector loss, unique variable names, Ascend950
 ---------------------------+-----------------------------------------------------
 All files                  | CCE→Ascend950, save_kernels, early return on
                            | code_runner error, pl.full removed
