@@ -146,49 +146,101 @@ File mapping:
 ================================================================================
 
 8.1 Constants (scaled down for testing)
- - BATCH: 64 → 1.
- - MAX_SEQ: 4096 → 4.
- - HIDDEN: 5120 → 80.
- - INTERMEDIATE: 25600 → 400.
- - K_CHUNK: 64 → 4.
- - Q_OUT_CHUNK: 128 → 8.
- - MLP_OUT_CHUNK: 256 → 16.
-
-8.2 Scratch buffers
- - New: muon_scratch [MLP_OUT_CHUNK, MLP_OUT_CHUNK] BF16.
- - New: proxy_scratch [TOK_TILE, MLP_OUT_CHUNK] BF16.
- - New: btrans_scratch [TOK_TILE, K_CHUNK] FP32.
-   (Used for staging data for matmul transpose patterns.)
-
-8.3 Forward pass
- - Input RMSNorm: reshape-after-slice pattern for 3D tensors.
- - Attention scores: k_c staged via btrans_scratch (assemble/slice) before
-   matmul with b_trans=True.
- - Attention context: explicit ctx_acc zeroed + add(ctx_acc, matmul(...))
-   instead of direct fused matmul.
+ - BATCH: 64 → 1 (64//64).
+ - MAX_SEQ: 4096 → 4 (4096//1024).
+ - HIDDEN: 5120 → 80 (5120//64).
+ - INTERMEDIATE: 25600 → 400 (25600//64).
+ - K_CHUNK: 64 → 4 (64//16).
+ - Q_OUT_CHUNK: 128 → 8 (128//16).
+ - MLP_OUT_CHUNK: 256 → 16 (256//16).
+
+8.2 Top-level tensor allocation
+ - loss_acc [TOK_TILE, 1] FP32 created outside all incore scopes (it was
+   previously inside the single auto_incore).
+ - muon_buf [MLP_OUT_CHUNK, MLP_OUT_CHUNK] FP32 created at top level as a
+   staging buffer for the Newton-Schulz iterations (written by assemble,
+   read back by slice in separate incore scopes to force a memory
+   round-trip); see the sketch below.
+ - Old scratch buffers (muon_scratch, proxy_scratch, btrans_scratch) removed.
+
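
A minimal sketch of the muon_buf round-trip, assuming pl.create_tensor,
pl.assemble, and pl.slice take (tensor, shape, offset)-style arguments and
that the incore primitives are context managers; none of these signatures
are confirmed by this diff, and x_tile is a placeholder name:

    loss_acc = pl.create_tensor([TOK_TILE, 1], "FP32")            # top level
    muon_buf = pl.create_tensor([MLP_OUT_CHUNK, MLP_OUT_CHUNK], "FP32")

    with pl.auto_incore():
        # scope 1: write the iterate, forcing a store to memory
        pl.assemble(muon_buf, x_tile, [0, 0])

    with pl.auto_incore():
        # scope 2: read it back from memory in a fresh scope
        x_tile = pl.slice(muon_buf, [MLP_OUT_CHUNK, MLP_OUT_CHUNK], [0, 0])
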
+8.3 Incore scope structure (major restructuring)
+ Original: one monolithic pl.auto_incore() around the entire function body.
+ New: split into many separate incore scopes to stay within the Vec buffer
+ limit (253952 bytes); see the outline below:
+ (a) Gradient zeroing: two pl.incore() blocks, one for the small grad
+     tensors (wq/wk/wv/wo) and one for the large grads (w_gate/w_up/w_down)
+     with a chunked MLP_OUT_BLOCKS loop.
+ (b) Forward + backward: a per-token pl.auto_incore() inside the batch/
+     position loop body.
+ (c) Weight gradient stages: per-block pl.auto_incore() (see 8.7).
+ (d) Loss extraction: a separate pl.incore() at the end.
+
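
A minimal outline of this scope layout, assuming pl.incore() and
pl.auto_incore() are context managers and that the loop bounds below
(MLP_OUT_BLOCKS, BATCH, MAX_SEQ) match the real code:

    with pl.incore():                        # (a) zero small grads: wq/wk/wv/wo
        ...
    with pl.incore():                        # (a) zero large grads, chunked
        for blk in range(MLP_OUT_BLOCKS):
            ...
    for b in range(BATCH):                   # (b) one scope per token
        for pos in range(MAX_SEQ):
            with pl.auto_incore():
                ...                          # forward + backward for this token
    # (c) per-block weight-gradient scopes: see 8.7
    with pl.incore():                        # (d) loss extraction at the end
        ...
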
+8.4 Forward pass
+ - 3D→2D slicing: the pl.reshape(pl.slice(tensor, [1, TOK, CHUNK], ...),
+   [TOK, CHUNK]) pattern is used for hidden_states and target_states.
+ - Attention scores: k_c is loaded directly from k_proj_tile and used with
+   b_trans=True in matmul (no staging buffer needed).
+ - Attention context: ctx_acc is explicitly zeroed via pl.mul(ctx_acc, 0.0),
+   then accumulated via pl.add(ctx_acc, matmul(...)); see the sketch below.
  - O projection residual: reshape-after-slice for hidden_states.
 
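
A sketch of the three patterns above. The slice offsets, the pl.matmul
spelling, the chunk names (p_chunks/v_chunks/num_k_chunks), and whether
pl.mul/pl.add update their first argument in place are all assumptions:

    # 3D→2D: slice one batch row, then drop the leading unit dimension
    h2d = pl.reshape(pl.slice(hidden_states, [1, TOK, CHUNK], [b, 0, 0]),
                     [TOK, CHUNK])

    # scores: k_c consumed directly with a transposed load, no staging buffer
    scores = pl.matmul(q_c, k_c, b_trans=True)

    # context: explicit zero, then accumulate chunk by chunk
    pl.mul(ctx_acc, 0.0)
    for i in range(num_k_chunks):
        pl.add(ctx_acc, pl.matmul(p_chunks[i], v_chunks[i]))
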
-8.4 Loss accumulation
+8.5 Loss accumulation
  - Old: per-token loop with tensor.read and [1,1] scalar tensors.
  - New: vector add with a [TOK_TILE, 1] accumulator (loss_prev + sq_row).
+ - loss_out changed from [1] to [TOK_TILE, 1] to match loss_acc directly,
+   avoiding a layout mismatch (nd vs dn) during tile.store; sketch below.
 
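
The change in one line, reusing the names above and assuming pl.add
returns its result:

    # old: one tensor.read plus a [1,1] scalar add per token (loop elided)
    # new: a single vector add per tile; both operands are [TOK_TILE, 1] FP32
    loss_acc = pl.add(loss_prev, sq_row)
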
-8.5 Backward pass
+8.6 Backward pass
  - MLP backward: d_mlp cast to BF16 before matmul with w_down chunk.
  - Gate/up gradients: d_gate_bf16 / d_up_bf16 intermediate BF16 casts,
    sequential add pattern instead of fused add(add(matmul, matmul)).
- - Attention backward: v_c renamed to v_bwd.
-
-8.6 Weight gradients + Muon optimizer (major rewrite)
- - proxy_scratch used to stage proxy tensors for a_trans matmul patterns.
- - Stage 1 (w_down): tiled gram matrix computation via muon_scratch
-   (MLP_OUT_CHUNK//TOK_TILE iterations, slice/transpose/matmul per tile).
- - Stage 2 (wo/wq/wk/wv): proxy_ctx and proxy_n staged via proxy_scratch;
-   tiled gram with K_CHUNK//TOK_TILE iterations.
- - Stage 3 (w_gate/w_up): different NS formulation — builds ns_acc via
-   matmul(muon_bf, transpose(tile)) then matmul(tmp_bf, tile) instead of
-   computing gram then muon @ gram.
-
-8.7 Backend
+ - Attention backward: q_c/k_c renamed to q_bwd/k_bwd to avoid type
+   reassignment with forward-pass variables of different shapes.
+
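
A sketch of the sequential-add pattern from 8.6, with pl.cast standing in
for whatever BF16 cast primitive the code actually uses; acc and the
w_gate_chunk/w_up_chunk names are placeholders:

    d_gate_bf16 = pl.cast(d_gate, "BF16")    # intermediate BF16 casts
    d_up_bf16   = pl.cast(d_up, "BF16")
    # old (fused): acc = add(add(matmul(...), matmul(...)), ...)
    acc = pl.matmul(d_gate_bf16, w_gate_chunk)
    acc = pl.add(acc, pl.matmul(d_up_bf16, w_up_chunk))  # one add at a time
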
+8.7 Weight gradients + Muon optimizer (major rewrite)
+ Three stages, each processing one weight block per iteration, with the
+ outer loop OUTSIDE the incore scope:
+
+ Stage 1 (grad_w_down): for each Q_OUT block:
+  - One auto_incore: compute the proxy gradient (a_trans=True matmul of
+    proxy_mlp × proxy_go), apply the momentum update, assemble into muon_buf.
+  - Newton-Schulz loop (MUON_NS_STEPS iterations, each in its own
+    auto_incore): slice the muon iterate from muon_buf in memory, compute
+    the Gram matrix G' = X@X^T via b_trans=True, update X ← 1.5X − 0.5·G'@X,
+    assemble the result back into muon_buf.
+  - One auto_incore: extract the final iterate, apply the learning rate,
+    assemble into grad_w_down.
+
+ Stage 2 (grad_wo/wq/wk/wv): same pattern per Q_OUT block, with each
+ weight (wo, wq, wk, wv) getting its own gradient+momentum+NS sequence.
+
+ Stage 3 (grad_w_gate/w_up): same pattern per MLP_OUT block.
+
+ Key design decisions for the Muon Newton-Schulz implementation:
+ (a) Gram matrix reformulated: uses b_trans=True (G' = X@X^T) instead of
+     a_trans=True (G = X^T@X), so the update becomes G'@X instead of X@G.
+     Mathematically equivalent by associativity: X@(X^T@X) = (X@X^T)@X.
+ (b) The b_trans=True formulation generates tile.load(transpose=True)
+     from memory into Mem.Mat, instead of a tile.transpose in Mem.Vec,
+     which triggers a codegen bug (matmul K-dimension mismatch) during
+     PTO code generation for non-square operands.
+ (c) The NS loop is placed outside auto_incore so that each step's matmul
+     operands come from pl.slice of muon_buf (a tensor in memory), not
+     from computed Vec tiles. This forces the converter to use transposed
+     memory loads rather than Vec transposes.
+ (d) muon_buf serves as the staging tensor: written via pl.assemble at
+     the end of each NS step, read back via pl.slice at the start of the
+     next step. The loop boundary prevents the optimizer from eliminating
+     the memory round-trip. A NumPy check of (a) follows below.
+
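
A plain-NumPy check of decision (a): both Gram formulations produce the
same Newton-Schulz step, and the iteration pushes singular values toward 1.
This is a stand-in for the kernel math, not PyPTO code; the 16×8 shape and
5 steps are arbitrary:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((16, 8)).astype(np.float32)
    X /= np.linalg.norm(X)            # scale down so the iteration converges

    for _ in range(5):                # MUON_NS_STEPS stand-in
        Gp = X @ X.T                  # b_trans form: G' = X @ X^T
        step_b = 1.5 * X - 0.5 * (Gp @ X)
        G = X.T @ X                   # a_trans form: G = X^T @ X
        step_a = 1.5 * X - 0.5 * (X @ G)
        assert np.allclose(step_a, step_b, atol=1e-6)  # associativity holds
        X = step_b

    print(np.linalg.svd(X, compute_uv=False))  # singular values move toward 1
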
+8.8 Variable naming
+ - Unique variable names per stage to avoid PyPTO type reassignment errors:
+   mu_s1/gram_s1/ns_update_s1 (stage 1), mu_wo/gram_wo/ns_upd_wo (stage 2
+   wo), mu_wq/gram_wq/ns_upd_wq (stage 2 wq), etc.
+ - proxy_tgt_q/proxy_tgt_k/proxy_tgt_v, proxy_n_k/proxy_n_v, and
+   proxy_post_g/proxy_post_u avoid reusing names across different shapes.
+
+8.9 Backend
  - BackendType.CCE → BackendType.Ascend950.
  - save_kernels / save_kernels_dir added to RunConfig.
 
@@ -206,8 +258,9 @@ prefill_tilelet → _new | RoPE: concat → create_tensor+assemble
 decode → _new              | Same as qwen3-32b (grouped Q, staged incore, etc.)
 decode_scope2 → _new       | Accumulator init: full → create_tensor+mul
 decode_tilelet → _new      | Accumulator init: full → create_tensor+mul
-training_fwd_bwd → _new    | Scaled constants, scratch buffers, reshape-after-
-                           | slice, tiled Muon NS, vector loss, Ascend950
+training_fwd_bwd → _new    | Scaled constants, multi-scope incore, reshape-after-
+                           | slice, Muon NS with muon_buf staging & b_trans Gram,
+                           | vector loss, unique variable names, Ascend950
 ---------------------------+-----------------------------------------------------
 All files                  | CCE→Ascend950, save_kernels, early return on
                            | code_runner error, pl.full removed