Problem
Each of the 24 transformer layers has 2 normalization points, each currently requiring 2 Metal dispatches: one for RMSNorm, one for residual addition. That's 24 × 2 × 2 = 96 dispatches just for norm+residual.
Proposal
Fuse RMSNorm + residual add into a single kernel per normalization point:
- Read input once
- Compute mean of squares → normalize → scale by weight → add residual
- Write output once
This cuts 48 dispatches/token (from 304 to ~256), a 16% dispatch reduction.
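As a CPU-side reference for what the fused kernel would compute, here is a minimal Rust sketch. It is not the Metal kernel and not code from metal_qwen35.rs: the function names, the per-row slice layout, and the `eps` parameter are assumptions made for illustration. The first function mirrors the current two-dispatch path (normalize into a temporary, then add the residual); the second produces the same values with a single output write and no temporary buffer.

```rust
/// Unfused reference (hypothetical name): two passes over the row,
/// mirroring the two separate dispatches described above.
fn rmsnorm_then_residual(x: &[f32], weight: &[f32], residual: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Pass 1 (first dispatch): normalize and scale into a temporary buffer.
    let normed: Vec<f32> = x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect();
    // Pass 2 (second dispatch): add the residual.
    normed.iter().zip(residual).map(|(n, r)| n + r).collect()
}

/// Fused reference (hypothetical name): one reduction, then one write per
/// element computing RMSNorm(x) * weight + residual directly.
/// This CPU loop walks `x` twice; a GPU kernel could instead keep the row in
/// registers or threadgroup memory so device memory is read only once.
fn fused_rmsnorm_residual(x: &[f32], weight: &[f32], residual: &[f32], eps: f32, out: &mut [f32]) {
    assert_eq!(x.len(), weight.len());
    assert_eq!(x.len(), residual.len());
    assert_eq!(x.len(), out.len());
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    for i in 0..x.len() {
        out[i] = x[i] * inv_rms * weight[i] + residual[i];
    }
}
```

Comparing the two functions element by element within a small tolerance is one way to check a fused kernel against the existing two-dispatch path; the tolerance and `eps` would need to match whatever the real kernels use.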
Evidence
- metal_qwen35.rs line ~8663 (GDN) and ~9043 (GQA): norm and residual are dispatched separately
- Metal best practice: minimize dispatch count to reduce command-buffer overhead
- At 304 dispatches/token with ~0.02 ms overhead per dispatch, 48 fewer dispatches saves up to ~1 ms/token (48 × 0.02 ms ≈ 0.96 ms)
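The dispatch arithmetic can be checked with a small Rust helper. The function name and signature are hypothetical; the inputs are the figures quoted in this issue (24 layers, 2 normalization points per layer, 304 dispatches/token, ~0.02 ms per dispatch), and the per-dispatch overhead is an estimate, not a measurement.

```rust
/// Hypothetical helper: returns (dispatches saved per token,
/// fraction of all dispatches saved, estimated milliseconds saved per token).
fn dispatch_savings(
    layers: u32,
    norm_points_per_layer: u32,
    total_dispatches: u32,
    overhead_ms_per_dispatch: f64,
) -> (u32, f64, f64) {
    // Today: one RMSNorm dispatch plus one residual-add dispatch per norm point.
    let before = layers * norm_points_per_layer * 2;
    // Fused: a single dispatch per norm point.
    let after = layers * norm_points_per_layer;
    let saved = before - after;
    let fraction = saved as f64 / total_dispatches as f64;
    let ms_saved = saved as f64 * overhead_ms_per_dispatch;
    (saved, fraction, ms_saved)
}
```

Calling `dispatch_savings(24, 2, 304, 0.02)` gives 48 dispatches saved, about 15.8% of the total, and about 0.96 ms/token, which is an upper bound that assumes the per-dispatch overhead is fully recovered.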
Acceptance
output = RMSNorm(x) * weight + residual
Priority
P1