From e7e917415bb47bdac03110fb86c0cbce857a0a10 Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Thu, 4 Jun 2026 10:08:03 -0400 Subject: [PATCH] perf(inference): chunked batched prefill for long prompts on Metal Replaces the per-token prefill loop with a chunked batched Metal path for long prompts, removing per-token GPU dispatch overhead during prefill. Decode unchanged; prefill argmax parity preserved. Co-Authored-By: Claude Opus 4.8 --- crates/inference/src/bin/bench_decode_ab.rs | 30 +- crates/inference/src/forward/metal_qwen35.rs | 308 ++++++++++++++++++- 2 files changed, 317 insertions(+), 21 deletions(-) diff --git a/crates/inference/src/bin/bench_decode_ab.rs b/crates/inference/src/bin/bench_decode_ab.rs index bd2a472dd..cab4f0a57 100644 --- a/crates/inference/src/bin/bench_decode_ab.rs +++ b/crates/inference/src/bin/bench_decode_ab.rs @@ -36,7 +36,7 @@ fn run() -> Result<(), Box> { use lattice_inference::forward::metal_qwen35::{ChatMessage, MetalQwen35State}; use lattice_inference::model::qwen35::Qwen35Model; use lattice_inference::model::qwen35_config::{GenerateConfig, Qwen35Config}; - use lattice_inference::tokenizer::BpeTokenizer; + use lattice_inference::tokenizer::{BpeTokenizer, Tokenizer}; let home = std::env::var("HOME")?; let model_dir_str = std::env::var("LATTICE_MODEL_DIR") @@ -116,9 +116,31 @@ fn run() -> Result<(), Box> { }; // Same continuation prompt as the original bench, single user turn. - let prompt = "The quick brown fox jumps over the lazy dog. \ - Once upon a time in a land far away, there lived a"; - let history = vec![ChatMessage::user(prompt)]; + // Optionally pad to a target context length (BENCH_PROMPT_TOKENS) by + // repeating filler text, so we can measure decode at real context depth + // (agentic workload: long context + short response). + let base = "The quick brown fox jumps over the lazy dog. \ + Once upon a time in a land far away, there lived a wise old owl \ + who knew many secrets. Every morning the sun rose over the \ + mountains and cast long shadows across the quiet valley. "; + let prompt: String = match std::env::var("BENCH_PROMPT_TOKENS") + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(target) if target > 0 => { + let mut p = String::new(); + while tokenizer.tokenize(&p).real_length < target { + p.push_str(base); + } + p + } + _ => base.to_string(), + }; + eprintln!( + "[bench] prompt_tokens={}", + tokenizer.tokenize(&prompt).real_length + ); + let history = vec![ChatMessage::user(&prompt)]; // One warmup (not recorded). metal.reset_state(); diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index 1e5955da0..c92e40bb0 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -6348,7 +6348,9 @@ kernel void moe_shared_gate_add( /// Uses batch GEMM (M=prompt_len) for projections instead of per-token GEMV, /// giving ~10-20x speedup on prefill for typical prompt lengths. /// - /// Falls back to sequential forward_step for prompts longer than max_prefill. + /// Prompts longer than max_prefill without active LoRA are processed in + /// max_prefill-sized batched chunks; LoRA-active prompts remain on the + /// per-token forward_step fallback. pub fn forward_prefill(&mut self, token_ids: &[u32]) -> Vec { self.forward_prefill_impl(token_ids, false) } @@ -6375,18 +6377,14 @@ kernel void moe_shared_gate_add( }; } if n == 1 { - let logits = self.forward_step(token_ids[0], 0); - return logits; + return self.forward_step(token_ids[0], 0); } - // Sequential fallback (LoRA or oversize) — only supports last-token mode. - // For all_positions mode under these conditions, caller must reduce window - // or quantize a Q4 dir without LoRA. - if self.lora.is_some() || n > self.session.max_prefill { + if self.lora.is_some() { + // Batched helper does not apply LoRA adapters; stay on sequential path. if all_positions { panic!( - "forward_prefill_all_logits: n={n} exceeds max_prefill ({}) or LoRA active — \ - reduce window or run without LoRA", - self.session.max_prefill + "forward_prefill_all_logits: LoRA active — \ + batch/all-position prefill does not apply LoRA" ); } let mut last_logits = Vec::new(); @@ -6395,13 +6393,63 @@ kernel void moe_shared_gate_add( } return last_logits; } - assert!( n <= self.session.kv_cache.max_cache_len, "prefill length {} exceeds max_cache_len {}", n, self.session.kv_cache.max_cache_len ); + let max_prefill = self.session.max_prefill; + if n <= max_prefill { + return self.forward_prefill_batched_chunk(token_ids, 0, all_positions); + } + // Chunked batched prefill: each chunk is one command buffer (preserving the + // n≤512 fast path within each chunk). GDN recurrent state threads across + // boundaries automatically — session GPU buffers are mutated in place and + // never reset between chunks within a single request. + if all_positions { + let mut all_logits = Vec::with_capacity(n * vocab); + let mut start_pos = 0usize; + for chunk in token_ids.chunks(max_prefill) { + all_logits.extend(self.forward_prefill_batched_chunk(chunk, start_pos, true)); + start_pos += chunk.len(); + } + all_logits + } else { + let mut last_logits = Vec::new(); + let mut start_pos = 0usize; + for chunk in token_ids.chunks(max_prefill) { + last_logits = self.forward_prefill_batched_chunk(chunk, start_pos, false); + start_pos += chunk.len(); + } + last_logits + } + } + + /// Batched single-command-buffer prefill for a contiguous token slice starting + /// at absolute position `start_pos`. + /// + /// Writes full-attention K/V rows `start_pos..start_pos+n` and advances GDN + /// recurrent state in the session GPU buffers. Called by `forward_prefill_impl` + /// once per chunk; for `n ≤ max_prefill` it IS the entire prefill (one command + /// buffer, all 24 layers, final logits). + fn forward_prefill_batched_chunk( + &mut self, + token_ids: &[u32], + start_pos: usize, + all_positions: bool, + ) -> Vec { + let n = token_ids.len(); + debug_assert!( + n <= self.session.max_prefill, + "forward_prefill_batched_chunk: n={n} exceeds max_prefill={}", + self.session.max_prefill + ); + assert!( + start_pos + n <= self.session.kv_cache.max_cache_len, + "prefill chunk start_pos={start_pos} + n={n} exceeds max_cache_len {}", + self.session.kv_cache.max_cache_len + ); let cfg = self.engine.config.clone(); let m = n as u32; let hidden = cfg.hidden_size; @@ -6793,6 +6841,8 @@ kernel void moe_shared_gate_add( // Per-token: scatter, norm, RoPE, cache store, attention, gate for t in 0..n { + // abs_pos is the token's position in the full sequence across all chunks. + let abs_pos = start_pos + t; let q_off = (t * 2 * q_dim) as u64 * 4; let qs_off = (t * q_dim) as u64 * 4; let gz_off = (t * q_dim) as u64 * 4; @@ -6852,7 +6902,7 @@ kernel void moe_shared_gate_add( // SAFETY: RoPE table buffers cover max positions and // per-token Q/K offsets are within activation buffers. unsafe { - let pos = t as u32; + let pos = abs_pos as u32; enc.set_compute_pipeline_state(&self.engine.pipelines.partial_rope); enc.set_buffer(0, Some(&self.session.activations.q_separated), qs_off); enc.set_buffer(1, Some(&self.engine.rope_cos), 0); @@ -6878,9 +6928,10 @@ kernel void moe_shared_gate_add( ); } - // Store K, V to cache + // Store K, V to cache; use absolute row so cross-chunk tokens + // write to the correct KV cache rows rather than overwriting chunk 0. { - let dst_offset = (t * kv_dim) as u32; + let dst_offset = (abs_pos * kv_dim) as u32; enc.set_compute_pipeline_state(&self.engine.pipelines.copy_offset); enc.set_buffer(0, Some(&self.session.activations.k), k_off); enc.set_buffer(1, Some(&self.session.kv_cache.k_bufs[full_idx]), 0); @@ -6901,9 +6952,10 @@ kernel void moe_shared_gate_add( ); } - // Causal attention: query Q[t] against cache[0..t+1] + // Causal attention: Q[t] attends against cache[0..abs_pos+1], + // covering all prior chunks when start_pos > 0. { - let cache_len = (t + 1) as u32; + let cache_len = (abs_pos + 1) as u32; enc.set_compute_pipeline_state(&self.engine.pipelines.decode_attention); enc.set_buffer(0, Some(&self.session.activations.q_separated), qs_off); enc.set_buffer(1, Some(&self.session.kv_cache.k_bufs[full_idx]), 0); @@ -7172,7 +7224,7 @@ kernel void moe_shared_gate_add( self.session.last_pre_final_hidden = unsafe { read_buffer(&self.session.activations.pre_final_hidden, hidden) }; } - self.session.kv_cache.seq_len = n; + self.session.set_position(start_pos + n); if let Some(pb) = ppl_buf { // SAFETY: GPU completed, ppl_buf is StorageModeShared and sized n*vocab. @@ -15483,6 +15535,228 @@ kernel void decode_attention_reference( } } } + + // ----------------------------------------------------------------------- + // Chunked batched prefill parity tests + // ----------------------------------------------------------------------- + + /// Build a token sequence of 650-700 tokens from varied prose paragraphs. + /// Uses the provided tokenizer; grows text by appending paragraphs until + /// the encoded length is ≥ 650, then truncates to ≤ 680. + fn long_real_text_tokens(tokenizer: &BpeTokenizer) -> Vec { + let paragraphs = [ + "During a late engineering review, the team walked through the inference trace \ + one layer at a time. They checked where state changed, which buffers were reused, \ + and how a long request should preserve every earlier token.", + "The prompt continued with ordinary prose about debugging, benchmarks, release notes, \ + and careful handoffs. It used full sentences, punctuation, and varied vocabulary \ + so tokenization looked like real input rather than a repeated numeric pattern.", + "A second reviewer asked for evidence at the boundary between chunks. The answer \ + described rotary positions, cache rows, recurrent memory, and causal attention \ + in concrete terms before any optimization was accepted.", + "Verification across chunk boundaries requires that each token's absolute position \ + index matches the RoPE table row, the KV cache write offset, and the attention \ + causal mask — all three must use the same absolute coordinate, not a chunk-local one.", + ]; + let mut text = String::new(); + let mut ids = Vec::new(); + for i in 0..256usize { + if !text.is_empty() { + text.push_str("\n\n"); + } + text.push_str(paragraphs[i % paragraphs.len()]); + let input = tokenizer.tokenize(&text); + ids = input.input_ids[..input.real_length].to_vec(); + if ids.len() >= 650 { + break; + } + } + if ids.len() > 680 { + ids.truncate(680); + } + ids + } + + fn argmax_f32(xs: &[f32]) -> usize { + let mut best = 0usize; + let mut best_val = f32::NEG_INFINITY; + for (i, &x) in xs.iter().enumerate() { + if x > best_val { + best_val = x; + best = i; + } + } + best + } + + #[cfg(all(target_os = "macos", feature = "metal-gpu"))] + #[test] + fn metal_qwen35_chunked_prefill_long_prompt_matches_step_loop() { + // Parity gate: chunked forward_prefill must agree with token-by-token + // forward_step on a prompt longer than max_prefill (≈512). + // Tests that RoPE offsets, KV cache rows, attention cache_len, and + // GDN recurrent state all thread correctly across chunk boundaries. + let model_dir = + std::path::PathBuf::from(std::env::var("HOME").unwrap_or_else(|_| ".".into())) + .join(".lattice/models/qwen3.5-0.8b"); + if !model_dir.join("model.safetensors").exists() + || !model_dir.join("config.json").exists() + || !model_dir.join("tokenizer.json").exists() + { + eprintln!( + "skipping chunked prefill parity test: model files missing at {}", + model_dir.display() + ); + return; + } + + let model = crate::model::qwen35::Qwen35Model::from_safetensors(&model_dir) + .expect("load qwen3.5-0.8b"); + assert!( + model.config().num_active_linear_attention_layers() > 0, + "prompt must exercise at least one GDN layer" + ); + + let Some(device) = Device::system_default() else { + eprintln!("skipping chunked prefill parity test: no Metal device"); + return; + }; + let _ = device; + + let mut state = MetalQwen35State::new(model.weights(), model.config(), 1024) + .expect("construct Metal qwen3.5-0.8b state"); + + let tokens = long_real_text_tokens(model.tokenizer()); + assert!( + tokens.len() > state.session.max_prefill, + "prompt ({} tokens) must exceed max_prefill ({})", + tokens.len(), + state.session.max_prefill + ); + assert!( + (600..=700).contains(&tokens.len()), + "prompt length {} out of expected 600-700 range", + tokens.len() + ); + + // Path A: token-by-token reference. + state.reset_state(); + let mut step_logits = Vec::new(); + for (pos, &tok) in tokens.iter().enumerate() { + step_logits = state.forward_step(tok, pos); + } + + // Path B: chunked batched prefill. + state.reset_state(); + let prefill_logits = state.forward_prefill(&tokens); + + assert_eq!(step_logits.len(), model.config().vocab_size); + assert_eq!(prefill_logits.len(), model.config().vocab_size); + + let max_abs_diff = step_logits + .iter() + .zip(prefill_logits.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0_f32, f32::max); + + eprintln!( + "chunked prefill parity: max_abs_diff={max_abs_diff:.6}, \ + argmax_step={}, argmax_prefill={}, tokens={}", + argmax_f32(&step_logits), + argmax_f32(&prefill_logits), + tokens.len() + ); + + assert!( + max_abs_diff < 1e-2, + "chunked prefill logits diverged from step loop: max_abs_diff={max_abs_diff}" + ); + assert_eq!( + argmax_f32(&step_logits), + argmax_f32(&prefill_logits), + "argmax mismatch between step loop and chunked prefill" + ); + // Both paths must advance the session cursor to the full prompt length. + assert_eq!(state.session.kv_cache.seq_len, tokens.len()); + assert_eq!(state.session.position, tokens.len()); + } + + #[cfg(all(target_os = "macos", feature = "metal-gpu"))] + #[test] + fn metal_qwen35_prefill_all_logits_long_prompt_no_longer_panics() { + // Regression gate: forward_prefill_all_logits must NOT panic when + // n > max_prefill and LoRA is inactive. Old code panicked; new code + // chunks the request and concatenates per-chunk all-position logits. + let model_dir = + std::path::PathBuf::from(std::env::var("HOME").unwrap_or_else(|_| ".".into())) + .join(".lattice/models/qwen3.5-0.8b"); + if !model_dir.join("model.safetensors").exists() + || !model_dir.join("config.json").exists() + || !model_dir.join("tokenizer.json").exists() + { + eprintln!( + "skipping all-logits long-prompt test: model files missing at {}", + model_dir.display() + ); + return; + } + + let Some(device) = Device::system_default() else { + eprintln!("skipping all-logits long-prompt test: no Metal device"); + return; + }; + let _ = device; + + let model = crate::model::qwen35::Qwen35Model::from_safetensors(&model_dir) + .expect("load qwen3.5-0.8b"); + let mut state = MetalQwen35State::new(model.weights(), model.config(), 1024) + .expect("construct Metal qwen3.5-0.8b state"); + + let tokens = long_real_text_tokens(model.tokenizer()); + assert!(tokens.len() > state.session.max_prefill); + + // Must not panic; returns n * vocab_size f32 values. + state.reset_state(); + let flat = state.forward_prefill_all_logits(&tokens); + assert_eq!( + flat.len(), + tokens.len() * model.config().vocab_size, + "all-logits output length must be n * vocab_size" + ); + + // Final row of all-logits must agree with token-by-token final logits. + let vocab = model.config().vocab_size; + let final_row = &flat[(tokens.len() - 1) * vocab..tokens.len() * vocab]; + + state.reset_state(); + let mut step_logits = Vec::new(); + for (pos, &tok) in tokens.iter().enumerate() { + step_logits = state.forward_step(tok, pos); + } + + let max_abs_diff = step_logits + .iter() + .zip(final_row.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0_f32, f32::max); + + eprintln!( + "all-logits final-row parity: max_abs_diff={max_abs_diff:.6}, \ + argmax_step={}, argmax_all_logits_final={}", + argmax_f32(&step_logits), + argmax_f32(final_row) + ); + + assert!( + max_abs_diff < 1e-2, + "all-logits final row diverged from step loop: max_abs_diff={max_abs_diff}" + ); + assert_eq!( + argmax_f32(&step_logits), + argmax_f32(final_row), + "argmax mismatch between step loop and all-logits final row" + ); + } } impl crate::speculative::MtpTargetVerifier for MetalQwen35State {