diff --git a/crates/inference/Cargo.toml b/crates/inference/Cargo.toml index 9a1ee8576..f1f9a205b 100644 --- a/crates/inference/Cargo.toml +++ b/crates/inference/Cargo.toml @@ -66,6 +66,11 @@ name = "bench_decode_ab" path = "src/bin/bench_decode_ab.rs" required-features = ["f16", "metal-gpu"] +[[bin]] +name = "ppl_metal" +path = "src/bin/ppl_metal.rs" +required-features = ["f16", "metal-gpu"] + [[bin]] name = "bench_logit_dump" path = "src/bin/bench_logit_dump.rs" diff --git a/crates/inference/src/bin/ppl_metal.rs b/crates/inference/src/bin/ppl_metal.rs new file mode 100644 index 000000000..cb4f470be --- /dev/null +++ b/crates/inference/src/bin/ppl_metal.rs @@ -0,0 +1,44 @@ +use lattice_inference::forward::metal_qwen35::MetalQwen35State; +use lattice_inference::model::qwen35::{PerplexityConfig, Qwen35Model}; +use lattice_inference::tokenizer::{BpeTokenizer, Tokenizer}; + +fn main() { + let home = std::env::var("HOME").unwrap(); + let model_dir = std::env::var("LATTICE_MODEL_DIR") + .unwrap_or_else(|_| format!("{home}/.lattice/models/qwen3.5-0.8b")); + let dir = std::path::Path::new(&model_dir); + let n_tokens: usize = std::env::var("PPL_TOKENS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(2048); + + eprintln!("[ppl_metal] loading {model_dir}"); + let model = Qwen35Model::from_safetensors(dir).expect("load model"); + let cfg = model.config().clone(); + let mut metal = MetalQwen35State::new(model.weights(), &cfg, 4096).expect("init metal"); + let tokenizer = BpeTokenizer::from_tokenizer_json(&dir.join("tokenizer.json")).unwrap(); + + let corpus_path = + std::env::var("CORPUS").unwrap_or_else(|_| "/tmp/wikitext2_test.txt".to_string()); + let corpus = std::fs::read_to_string(&corpus_path).expect("read corpus"); + let input = tokenizer.tokenize(&corpus); + let all_tokens: Vec = input.input_ids[..input.real_length].to_vec(); + let tokens = &all_tokens[..all_tokens.len().min(n_tokens)]; + eprintln!( + "[ppl_metal] scoring {} tokens (Metal GPU, Q8 + f16 lm_head)", + tokens.len() + ); + + let t = std::time::Instant::now(); + let ppl_cfg = PerplexityConfig { + window: 512, + stride: 256, + }; + let report = metal.compute_perplexity(tokens, &ppl_cfg).expect("ppl"); + let elapsed = t.elapsed(); + + println!("PPL: {:.4}", report.ppl); + println!("NLL: {:.6}", report.mean_nll); + println!("Tokens: {}", report.num_tokens_scored); + println!("Time: {:.1}s", elapsed.as_secs_f64()); +} diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index ae9aae982..56bb49e3a 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -297,6 +297,65 @@ kernel void gemv_q8_decode_wide( } } +// ===== FP16-weight GEMV Decode (wide): NR=4 rows per threadgroup ===== +// Same scheduling as gemv_q8_decode_wide but reads half-precision weights +// directly (no Q8 block structure). Dispatched for large-N matmuls (lm_head) +// where the NR=1 gemv_decode_m1 kernel creates excessive threadgroup count. +// Dispatch: threadgroups=(ceil(N/4), 1, 1), threads=(32, 4, 1) +kernel void gemv_decode_wide_f16( + device const float* x [[buffer(0)]], + device const half* W [[buffer(1)]], + device float* y [[buffer(2)]], + constant uint& N [[buffer(3)]], + constant uint& K [[buffer(4)]], + uint3 tgpig [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]]) +{ + const uint NR = 4; + const uint NSG = 4; + const uint nb = K / 32; + const uint first_row = tgpig.x * NR; + const uint ix = tiisg / 4; + const uint il = tiisg % 4; + + float sumf[NR] = {0.0f}; + const uint ib_start = sgitg * 8 + ix; + const uint ib_stride = NSG * 8; + device const float* xb = x + ib_start * 32 + il * 8; + + for (uint ib = ib_start; ib < nb; ib += ib_stride) { + float xl[8]; + for (uint i = 0; i < 8; i++) xl[i] = xb[i]; + xb += ib_stride * 32; + + for (uint row = 0; row < NR; row++) { + uint r = first_row + row; + if (r >= N) continue; + device const half* wrow = W + r * K + ib * 32 + il * 8; + float dot = 0.0f; + for (uint i = 0; i < 8; i++) dot += xl[i] * float(wrow[i]); + sumf[row] += dot; + } + } + + for (uint row = 0; row < NR; row++) sumf[row] = simd_sum(sumf[row]); + + threadgroup float shared[NR][4]; + if (tiisg == 0) { + for (uint row = 0; row < NR; row++) shared[row][sgitg] = sumf[row]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + for (uint row = 0; row < NR; row++) { + uint r = first_row + row; + if (r < N) { + y[r] = shared[row][0] + shared[row][1] + shared[row][2] + shared[row][3]; + } + } + } +} + // ===== Qwen3.5 RMS Norm: x = x * (1 + gamma) / sqrt(mean(x^2) + eps) ===== // Shifted norm: (1 + gamma) instead of plain gamma. // One threadgroup per row. @@ -2812,6 +2871,7 @@ kernel void moe_shared_gate_add( /// Compiled MSL pipeline state objects. struct MetalQwen35Pipelines { gemv_decode: ComputePipelineState, + gemv_decode_wide: ComputePipelineState, gemv_q8: ComputePipelineState, gemv_q8_wide: ComputePipelineState, rms_norm: ComputePipelineState, @@ -3632,6 +3692,7 @@ kernel void moe_shared_gate_add( let pipelines = MetalQwen35Pipelines { gemv_decode: make_pipeline("gemv_decode_m1")?, + gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?, gemv_q8: make_pipeline("gemv_q8_decode")?, gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?, gemv_q4: make_pipeline("gemv_q4_decode")?, @@ -7013,45 +7074,33 @@ kernel void moe_shared_gate_add( // matmul kernel is faster + handles QuaRot rotations correctly. let use_fp16_lm_head = matches!(self.engine.quant_format, QuantFormat::Q8_0); if use_fp16_lm_head { - let gemm_params = GemmParams { - m: 1, - n: cfg.vocab_size as u32, - k: hidden as u32, - lda: hidden as u32, - ldb: hidden as u32, - ldc: cfg.vocab_size as u32, - }; + let vocab_n = cfg.vocab_size as u32; + let hidden_k = hidden as u32; if let Some(ref pb) = ppl_buf { for p in 0..n { let hidden_off = (p * hidden) as u64 * 4; let logits_off = (p * cfg.vocab_size) as u64 * 4; - enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode); + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); enc.set_buffer(0, Some(&self.session.activations.hidden), hidden_off); enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); enc.set_buffer(2, Some(pb), logits_off); - enc.set_bytes( - 3, - std::mem::size_of::() as u64, - &gemm_params as *const GemmParams as *const _, - ); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); enc.dispatch_thread_groups( - MTLSize::new(cfg.vocab_size as u64, 1, 1), - MTLSize::new(256, 1, 1), + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } } else { - enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode); + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); enc.set_buffer(0, Some(&self.session.activations.hidden), last_off); enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); enc.set_buffer(2, Some(&self.session.activations.logits), 0); - enc.set_bytes( - 3, - std::mem::size_of::() as u64, - &gemm_params as *const GemmParams as *const _, - ); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); enc.dispatch_thread_groups( - MTLSize::new(cfg.vocab_size as u64, 1, 1), - MTLSize::new(256, 1, 1), + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } } else if let Some(ref pb) = ppl_buf { @@ -9670,19 +9719,21 @@ kernel void moe_shared_gate_add( 1, cfg.rms_norm_eps, ); - // lm_head: for Q8 format use FP16 (matches MLX), for Q4 use Q4 path. - // Q8 per-row symmetric quantization of embeddings causes ~0.79 PPL - // gap from per-row scale not capturing outlier embeddings cleanly. + // lm_head: for Q8 format use FP16 wide kernel (NR=4, fewer TGs than + // gemv_decode_m1 which dispatches one TG per vocab row). For Q4 use Q4 path. match self.engine.quant_format { QuantFormat::Q8_0 => { - self.dispatch_matmul_half( - enc, - &self.session.activations.hidden, - &self.engine.embed_tokens, - &self.session.activations.logits, - 1, - cfg.vocab_size as u32, - hidden as u32, + let vocab_n = cfg.vocab_size as u32; + let hidden_k = hidden as u32; + enc.set_compute_pipeline_state(&self.engine.pipelines.gemv_decode_wide); + enc.set_buffer(0, Some(&self.session.activations.hidden), 0); + enc.set_buffer(1, Some(&self.engine.embed_tokens), 0); + enc.set_buffer(2, Some(&self.session.activations.logits), 0); + enc.set_bytes(3, 4, &vocab_n as *const u32 as *const _); + enc.set_bytes(4, 4, &hidden_k as *const u32 as *const _); + enc.dispatch_thread_groups( + MTLSize::new(vocab_n.div_ceil(4) as u64, 1, 1), + MTLSize::new(32, 4, 1), ); } QuantFormat::Q4_0 => { @@ -11439,6 +11490,7 @@ kernel void moe_shared_gate_add( let pipelines = MetalQwen35Pipelines { gemv_decode: make_pipeline("gemv_decode_m1")?, + gemv_decode_wide: make_pipeline("gemv_decode_wide_f16")?, gemv_q8: make_pipeline("gemv_q8_decode")?, gemv_q8_wide: make_pipeline("gemv_q8_decode_wide")?, gemv_q4: make_pipeline("gemv_q4_decode")?, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8701bdf45 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "lattice-dev" +version = "0.1.0" +description = "Dev scripts for lattice inference engine" +requires-python = ">=3.11" + +dependencies = [ + "pyarrow>=17.0", + "datasets>=3.0", + "numpy>=1.26", + "mlx>=0.22", + "mlx-lm>=0.21", + "transformers>=4.46", + "matplotlib>=3.9", +] + +[tool.uv] +dev-dependencies = [ + "pytest>=8.0", +]