From 5ab8510f075b123bda7350c6bd69dd0b407561a1 Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Sun, 31 May 2026 09:24:50 -0400 Subject: [PATCH 1/2] fix(inference): wide f16 GEMV kernel restores 160 tok/s decode throughput MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The lm_head dispatch for Q8 models was using gemv_decode_m1 (NR=1, one threadgroup per vocab row) which created 151,936 threadgroups of 256 threads each — 4× more shader invocations than the prior Q8 path. This was introduced in commit 4dab27e1 which correctly switched to f16 weights for PPL quality but used an inefficient kernel for the large-N matmul. Add gemv_decode_wide_f16: an NR=4 f16 GEMV kernel (same structure as the existing gemv_q8_decode_wide) that processes 4 output rows per threadgroup, reducing lm_head dispatch from 151,936 to 37,984 threadgroups. Same f16 weights, same f32 accumulation — zero quality regression. bench_decode_ab (Qwen3.5-0.8B Q8, slope method): Before: T1=291ms, ~133 tok/s After: T1=127ms, ~160 tok/s Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/inference/src/forward/metal_qwen35.rs | 122 +++++++++++++------ 1 file changed, 87 insertions(+), 35 deletions(-) diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index ae9aae982..56bb49e3a 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -297,6 +297,65 @@ kernel void gemv_q8_decode_wide( } } +// ===== FP16-weight GEMV Decode (wide): NR=4 rows per threadgroup ===== +// Same scheduling as gemv_q8_decode_wide but reads half-precision weights +// directly (no Q8 block structure). Dispatched for large-N matmuls (lm_head) +// where the NR=1 gemv_decode_m1 kernel creates excessive threadgroup count. +// Dispatch: threadgroups=(ceil(N/4), 1, 1), threads=(32, 4, 1) +kernel void gemv_decode_wide_f16( + device const float* x [[buffer(0)]], + device const half* W [[buffer(1)]], + device float* y [[buffer(2)]], + constant uint& N [[buffer(3)]], + constant uint& K [[buffer(4)]], + uint3 tgpig [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]]) +{ + const uint NR = 4; + const uint NSG = 4; + const uint nb = K / 32; + const uint first_row = tgpig.x * NR; + const uint ix = tiisg / 4; + const uint il = tiisg % 4; + + float sumf[NR] = {0.0f}; + const uint ib_start = sgitg * 8 + ix; + const uint ib_stride = NSG * 8; + device const float* xb = x + ib_start * 32 + il * 8; + + for (uint ib = ib_start; ib < nb; ib += ib_stride) { + float xl[8]; + for (uint i = 0; i < 8; i++) xl[i] = xb[i]; + xb += ib_stride * 32; + + for (uint row = 0; row < NR; row++) { + uint r = first_row + row; + if (r >= N) continue; + device const half* wrow = W + r * K + ib * 32 + il * 8; + float dot = 0.0f; + for (uint i = 0; i < 8; i++) dot += xl[i] * float(wrow[i]); + sumf[row] += dot; + } + } + + for (uint row = 0; row < NR; row++) sumf[row] = simd_sum(sumf[row]); + + threadgroup float shared[NR][4]; + if (tiisg == 0) { + for (uint row = 0; row < NR; row++) shared[row][sgitg] = sumf[row]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + for (uint row = 0; row < NR; row++) { + uint r = first_row + row; + if (r < N) { + y[r] = shared[row][0] + shared[row][1] + shared[row][2] + shared[row][3]; + } + } + } +} + // ===== Qwen3.5 RMS Norm: x = x * (1 + gamma) / sqrt(mean(x^2) + eps) ===== // Shifted norm: (1 + gamma) instead of plain gamma. // One threadgroup per row. @@ -2812,6 +2871,7 @@ kernel void moe_shared_gate_add( /// Compiled MSL pipeline state objects. struct MetalQwen35Pipelines { gemv_decode: ComputePipelineState, + gemv_decode_wide: ComputePipelineState, gemv_q8: ComputePipelineState, gemv_q8_wide: ComputePipelineState, rms_norm: ComputePipelineState, @@ -3632,6 +3692,7 @@ kernel void moe_shared_gate_add( let pipelines = MetalQwen35Pipelines { gemv_decode: make_pipeline("gemv_decode_m1")?, + gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?, gemv_q8: make_pipeline("gemv_q8_decode")?, gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?, gemv_q4: make_pipeline("gemv_q4_decode")?, @@ -7013,45 +7074,33 @@ kernel void moe_shared_gate_add( // matmul kernel is faster + handles QuaRot rotations correctly. let use_fp16_lm_head = matches!(self.engine.quant_format, QuantFormat::Q8_0); if use_fp16_lm_head { - let gemm_params = GemmParams { - m: 1, - n: cfg.vocab_size as u32, - k: hidden as u32, - lda: hidden as u32, - ldb: hidden as u32, - ldc: cfg.vocab_size as u32, - }; + let vocab_n = cfg.vocab_size as u32; + let hidden_k = hidden as u32; if let Some(ref pb) = ppl_buf { for p in 0..n { let hidden_off = (p * hidden) as u64 * 4; let logits_off = (p * cfg.vocab_size) as u64 * 4; - enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode); + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); enc.set_buffer(0, Some(&self.session.activations.hidden), hidden_off); enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); enc.set_buffer(2, Some(pb), logits_off); - enc.set_bytes( - 3, - std::mem::size_of::() as u64, - &gemm_params as *const GemmParams as *const _, - ); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); enc.dispatch_thread_groups( - MTLSize::new(cfg.vocab_size as u64, 1, 1), - MTLSize::new(256, 1, 1), + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } } else { - enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode); + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); enc.set_buffer(0, Some(&self.session.activations.hidden), last_off); enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); enc.set_buffer(2, Some(&self.session.activations.logits), 0); - enc.set_bytes( - 3, - std::mem::size_of::() as u64, - &gemm_params as *const GemmParams as *const _, - ); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); enc.dispatch_thread_groups( - MTLSize::new(cfg.vocab_size as u64, 1, 1), - MTLSize::new(256, 1, 1), + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } } else if let Some(ref pb) = ppl_buf { @@ -9670,19 +9719,21 @@ kernel void moe_shared_gate_add( 1, cfg.rms_norm_eps, ); - // lm_head: for Q8 format use FP16 (matches MLX), for Q4 use Q4 path. - // Q8 per-row symmetric quantization of embeddings causes ~0.79 PPL - // gap from per-row scale not capturing outlier embeddings cleanly. + // lm_head: for Q8 format use FP16 wide kernel (NR=4, fewer TGs than + // gemv_decode_m1 which dispatches one TG per vocab row). For Q4 use Q4 path. match self.engine.quant_format { QuantFormat::Q8_0 => { - self.dispatch_matmul_half( - enc, - &self.session.activations.hidden, - &self.engine.embed_tokens, - &self.session.activations.logits, - 1, - cfg.vocab_size as u32, - hidden as u32, + let vocab_n = cfg.vocab_size as u32; + let hidden_k = hidden as u32; + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); + enc.set_buffer(0, Some(&self.session.activations.hidden), 0); + enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); + enc.set_buffer(2, Some(&self.session.activations.logits), 0); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); + enc.dispatch_thread_groups( + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } QuantFormat::Q4_0 => { @@ -11439,6 +11490,7 @@ kernel void moe_shared_gate_add( let pipelines = MetalQwen35Pipelines { gemv_decode: make_pipeline("gemv_decode_m1")?, + gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?, gemv_q8: make_pipeline("gemv_q8_decode")?, gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?, gemv_q4: make_pipeline("gemv_q4_decode")?, From 51154a2708f7336d5946058ad98ab6ea7f90cbf0 Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Sun, 31 May 2026 09:43:37 -0400 Subject: [PATCH 2/2] feat(inference): add ppl_metal binary + pyproject.toml for dev deps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ppl_metal: GPU-accelerated perplexity evaluation via Metal. Uses the same forward path as decode (including the wide f16 lm_head kernel). Configurable via PPL_TOKENS and CORPUS env vars. pyproject.toml: tracks common Python dev dependencies (pyarrow, datasets, mlx, numpy, matplotlib) so scripts/ and one-shot comparisons work without ad-hoc installs. Verified: Lattice PPL=20.60 vs MLX PPL=20.67 on wikitext-2 (2048 tokens, window=512, stride=256). Parity confirmed — no quality regression. Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/inference/Cargo.toml | 5 +++ crates/inference/src/bin/ppl_metal.rs | 44 +++++++++++++++++++++++++++ pyproject.toml | 20 ++++++++++++ 3 files changed, 69 insertions(+) create mode 100644 crates/inference/src/bin/ppl_metal.rs create mode 100644 pyproject.toml diff --git a/crates/inference/Cargo.toml b/crates/inference/Cargo.toml index 9a1ee8576..f1f9a205b 100644 --- a/crates/inference/Cargo.toml +++ b/crates/inference/Cargo.toml @@ -66,6 +66,11 @@ name = "bench_decode_ab" path = "src/bin/bench_decode_ab.rs" required-features = ["f16", "metal-gpu"] +[[bin]] +name = "ppl_metal" +path = "src/bin/ppl_metal.rs" +required-features = ["f16", "metal-gpu"] + [[bin]] name = "bench_logit_dump" path = "src/bin/bench_logit_dump.rs" diff --git a/crates/inference/src/bin/ppl_metal.rs b/crates/inference/src/bin/ppl_metal.rs new file mode 100644 index 000000000..cb4f470be --- /dev/null +++ b/crates/inference/src/bin/ppl_metal.rs @@ -0,0 +1,44 @@ +use lattice_inference::forward::metal_qwen35::MetalQwen35State; +use lattice_inference::model::qwen35::{PerplexityConfig, Qwen35Model}; +use lattice_inference::tokenizer::{BpeTokenizer, Tokenizer}; + +fn main() { + let home = std::env::var("HOME").unwrap(); + let model_dir = std::env::var("LATTICE_MODEL_DIR") + .unwrap_or_else(|_| format!("{home}/.lattice/models/qwen3.5-0.8b")); + let dir = std::path::Path::new(&model_dir); + let n_tokens: usize = std::env::var("PPL_TOKENS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(2048); + + eprintln!("[ppl_metal] loading {model_dir}"); + let model = Qwen35Model::from_safetensors(dir).expect("load model"); + let cfg = model.config().clone(); + let mut metal = MetalQwen35State::new(model.weights(), &cfg, 4096).expect("init metal"); + let tokenizer = BpeTokenizer::from_tokenizer_json(&dir.join("tokenizer.json")).unwrap(); + + let corpus_path = + std::env::var("CORPUS").unwrap_or_else(|_| "/tmp/wikitext2_test.txt".to_string()); + let corpus = std::fs::read_to_string(&corpus_path).expect("read corpus"); + let input = tokenizer.tokenize(&corpus); + let all_tokens: Vec = input.input_ids[..input.real_length].to_vec(); + let tokens = &all_tokens[..all_tokens.len().min(n_tokens)]; + eprintln!( + "[ppl_metal] scoring {} tokens (Metal GPU, Q8 + f16 lm_head)", + tokens.len() + ); + + let t = std::time::Instant::now(); + let ppl_cfg = PerplexityConfig { + window: 512, + stride: 256, + }; + let report = metal.compute_perplexity(tokens, &ppl_cfg).expect("ppl"); + let elapsed = t.elapsed(); + + println!("PPL: {:.4}", report.ppl); + println!("NLL: {:.6}", report.mean_nll); + println!("Tokens: {}", report.num_tokens_scored); + println!("Time: {:.1}s", elapsed.as_secs_f64()); +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8701bdf45 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "lattice-dev" +version = "0.1.0" +description = "Dev scripts for lattice inference engine" +requires-python = ">=3.11" + +dependencies = [ + "pyarrow>=17.0", + "datasets>=3.0", + "numpy>=1.26", + "mlx>=0.22", + "mlx-lm>=0.21", + "transformers>=4.46", + "matplotlib>=3.9", +] + +[tool.uv] +dev-dependencies = [ + "pytest>=8.0", +]