diff --git a/README.md b/README.md index f919a833..a930ffd0 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,174 @@ higgs serve --model mlx-community/Qwen3.6-35B-A3B-4bit Send a request to the local endpoint: ```bash +<<<<<<< HEAD +======= +higgs init # create ~/.config/higgs/config.toml +higgs serve # start with config +higgs start # start as background daemon +higgs attach # attach TUI dashboard to running daemon +higgs stop # stop daemon +``` + +### Profiles + +Named profiles let you maintain multiple configurations and run multiple instances simultaneously: + +```bash +higgs init --profile dev # create config.dev.toml +higgs init --profile prod # create config.prod.toml +higgs serve --profile dev # foreground with dev config +higgs start --profile dev # daemon with dev config (separate PID/log) +higgs start --profile prod # daemon with prod config (different port) +higgs attach --profile dev # attach TUI to dev instance +higgs stop --profile dev # stop only the dev instance +higgs doctor --profile prod # validate prod config +``` + +Each profile gets isolated runtime files (`higgs..pid`, `higgs..log`, `metrics..jsonl`). Profiles must use different ports (configured in each profile's config file). `--profile` and `--config` are mutually exclusive. 
+ +## Features + +### Local inference +- **OpenAI + Anthropic APIs** -- chat completions, text completions, embeddings, messages +- **Structured output** -- `json_schema` response format (100% schema compliance) +- **Reasoning models** -- `` tag extraction to `reasoning_content` +- **Continuous batching** -- 755 tok/s aggregate at 8 concurrent requests +- **Radix tree prefix cache** -- shared prefix reuse across requests +- **Vision** -- multimodal image+text (LLaVA-Qwen2) +- **11 architectures** -- LLaMA, Mistral, Qwen2/3, Qwen3-MoE, Qwen3-Next, Gemma 2, Phi-3, Starcoder2, DeepSeek-V2, LLaVA-Qwen2 + +### Gateway +- **Remote providers** -- proxy requests to OpenAI, Anthropic, Ollama, or any OpenAI-compatible API +- **Format translation** -- send OpenAI requests to Anthropic providers (and vice versa) with automatic conversion of request/response formats, including streaming +- **Pattern routing** -- regex-based model name matching to route requests to the right provider +- **Model rewriting** -- map model aliases to upstream model names +- **Auto-router** -- classify requests using a local LLM to pick the best provider +- **Metrics dashboard** -- TUI with live request rates, latency, token throughput, and error tracking +- **Daemon mode** -- `higgs start`/`stop`/`attach` for background operation +- **Config management** -- `higgs config get/set`, `higgs doctor` for validation + +## Configuration + +### Simple mode (CLI flags) + +| CLI Flag | Env Variable | Default | Description | +|---|---|---|---| +| `--model` | `HIGGS_MODELS` | *(required)* | Model path or HF ID (repeatable) | +| `--host` | `HIGGS_HOST` | `0.0.0.0` | Bind address | +| `--port` | `HIGGS_PORT` | `8000` | Bind port | +| `--max-tokens` | `HIGGS_MAX_TOKENS` | `32768` | Max generation tokens | +| `--api-key` | `HIGGS_API_KEY` | *(none)* | Bearer token for auth | +| `--rate-limit` | `HIGGS_RATE_LIMIT` | `0` | Requests/min per client | +| `--timeout` | `HIGGS_TIMEOUT` | `300` | Request timeout (seconds) 
| +| `--batch` | -- | `false` | Enable continuous batching | + +### Gateway mode (config file) + +Run `higgs init` to create `~/.config/higgs/config.toml`: + +```toml +[server] +host = "0.0.0.0" +port = 8000 +# max_tokens = 32768 +# timeout = 300.0 +# api_key = "sk-..." + +# --- Local models --- +[[models]] +path = "mlx-community/Llama-3.2-1B-Instruct-4bit" +# name = "llama" # optional friendly name (used as engine key and for auto_router lookup) +# batch = false +# draft_model = "mlx-community/Llama-3.2-1B-Instruct-4bit" # speculative decoding +# num_draft = 8 # draft tokens per speculative cycle (default: 8) + +# --- Remote providers --- +[provider.anthropic] +url = "https://api.anthropic.com" +format = "anthropic" + +[provider.openai] +url = "https://api.openai.com" +format = "openai" + +[provider.ollama] +url = "http://localhost:11434" +strip_auth = true + +# --- Routes --- +# First regex match wins. Requests matching a local model name are served locally. + +[[routes]] +pattern = "claude-.*" +provider = "anthropic" + +[[routes]] +pattern = "gpt-.*" +provider = "openai" + +# Model rewriting: requests for "my-alias" are sent to the provider as "actual-model-name" +# [[routes]] +# pattern = "my-alias" +# provider = "openai" +# model = "gpt-4o" + +# --- Default route --- +[default] +provider = "higgs" # "higgs" = local models only; set to a provider name to proxy unmatched requests + +# --- Auto router (optional) --- +# Classify requests with a local LLM to pick the best provider automatically. +# The model field can reference a model by name or path. 
+# [auto_router] +# enabled = true +# model = "llama" # matches [[models]] name or path +# timeout_ms = 2000 + +# --- Metrics & dashboard --- +[retention] +enabled = true +minutes = 60 + +[logging.metrics] +enabled = true +# path = "~/.config/higgs/logs/metrics.jsonl" +# max_size_mb = 50 +# max_files = 5 +``` + +#### Provider options + +| Field | Type | Default | Description | +|---|---|---|---| +| `url` | string | *(required)* | Base URL of the upstream API | +| `format` | `"openai"` or `"anthropic"` | `"openai"` | API format the provider speaks | +| `api_key` | string | *(none)* | API key to inject into proxied requests | +| `strip_auth` | bool | `false` | Remove the client's Authorization header before proxying | +| `stub_count_tokens` | bool | `false` | Return a stub for `/v1/messages/count_tokens` | + +#### Route options + +| Field | Type | Description | +|---|---|---| +| `pattern` | regex | Match against the `model` field in requests | +| `provider` | string | Provider name to forward to | +| `model` | string | Rewrite the model field before forwarding | +| `name` | string | Human label (used by auto-router) | +| `description` | string | Route description (used by auto-router for classification) | + +## API + +**OpenAI**: `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/models` +**Anthropic**: `/v1/messages`, `/v1/messages/count_tokens` +**Metrics**: `/metrics` (JSON) +**Health**: `/health` + +Format translation works transparently: send an OpenAI-format request to higgs and it will translate to Anthropic format if the matched route points to an Anthropic provider (and vice versa), including streaming responses. 
+ +```bash +# Local model +>>>>>>> feef8e47 (feat(doctor): validate draft_model path and batch incompatibility) curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ diff --git a/crates/higgs-engine/src/lib.rs b/crates/higgs-engine/src/lib.rs index 705a82b4..16150eb0 100644 --- a/crates/higgs-engine/src/lib.rs +++ b/crates/higgs-engine/src/lib.rs @@ -13,6 +13,7 @@ pub mod reasoning_parser; pub mod scheduler; pub mod simple; pub mod spec_prefill; +pub mod speculative; pub mod tool_parser; pub use tokenizers; diff --git a/crates/higgs-engine/src/speculative.rs b/crates/higgs-engine/src/speculative.rs new file mode 100644 index 00000000..ded50a9a --- /dev/null +++ b/crates/higgs-engine/src/speculative.rs @@ -0,0 +1,371 @@ +use crate::error::EngineError; + +/// Run one speculative decode cycle. +/// +/// 1. Draft `num_draft` tokens with the draft model +/// 2. Build verify batch: `[last_token_id, draft_0, ..., draft_{K-1}]` +/// 3. Call `verify_fn` with the batch to get `K+1` target-sampled token IDs +/// 4. Accept the longest matching prefix +/// 5. Advance or rollback the draft model accordingly +/// +/// Returns the accepted token IDs (1..=K+1). +pub fn speculative_step<F>( + draft: &mut dyn DraftModel, + last_token_id: u32, + num_draft: usize, + verify_fn: F, +) -> Result<Vec<u32>, EngineError> +where + F: FnOnce(&[u32]) -> Result<Vec<u32>, EngineError>, +{ + let draft_ids = draft.draft(last_token_id, num_draft)?; + let k = draft_ids.len(); + + let mut verify_batch = Vec::with_capacity(k + 1); + verify_batch.push(last_token_id); + verify_batch.extend_from_slice(&draft_ids); + + let target_ids = verify_fn(&verify_batch)?; + let accepted = accept_prefix(&draft_ids, &target_ids)?; + + let matched = if accepted.len() > k { + k + } else { + accepted.len().saturating_sub(1) + }; + + if matched > 0 { + draft.advance(matched)?; + } else { + draft.rollback()?; + } + + Ok(accepted) +} + +/// Result of one speculative decode cycle. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StepResult { + pub tokens: Vec<u32>, + pub hit_eos: bool, +} + +/// Run a full speculative decode loop until EOS or `max_tokens`. +pub fn speculative_loop<F>( + draft: &mut dyn DraftModel, + last_token_id: u32, + num_draft: usize, + max_tokens: usize, + eos_ids: &[u32], + mut verify_fn: F, +) -> Result<Vec<u32>, EngineError> +where + F: FnMut(&[u32]) -> Result<Vec<u32>, EngineError>, +{ + let mut generated = Vec::new(); + let mut current_token = last_token_id; + + while generated.len() < max_tokens { + let remaining = max_tokens - generated.len(); + let k = num_draft.min(remaining); + if k == 0 { + break; + } + + let accepted = speculative_step(draft, current_token, k, |batch| verify_fn(batch))?; + + for &token in &accepted { + if generated.len() >= max_tokens { + break; + } + generated.push(token); + if eos_ids.contains(&token) { + return Ok(generated); + } + } + + if let Some(&last) = generated.last() { + current_token = last; + } + } + + Ok(generated) +} + +/// Compute the accepted prefix from a speculative decode cycle. +/// +/// Given `draft_ids` (K tokens from the draft model) and `target_ids` (K+1 +/// samples from the target model's verify logits), return the longest prefix +/// where draft and target agree, followed by the target's first divergent +/// token. +/// +/// Invariants: +/// - `target_ids.len() == draft_ids.len() + 1` +/// - Returns 1..=K+1 tokens (always at least one token from the target) +pub fn accept_prefix(draft_ids: &[u32], target_ids: &[u32]) -> Result<Vec<u32>, EngineError> { + let k = draft_ids.len(); + if target_ids.len() != k + 1 { + return Err(EngineError::Generation(format!( + "accept_prefix: target_ids.len() ({}) must be draft_ids.len() ({k}) + 1", + target_ids.len(), + ))); + } + + // Walk both slices in lock-step: types enforce bounds, no indexing needed. + // The k+1th target token (the bonus when every draft matched) is appended + // after the loop using `.last()` on target_ids. 
+ let mut accepted = Vec::with_capacity(k + 1); + for (&target_token, &draft_token) in target_ids.iter().zip(draft_ids.iter()) { + accepted.push(target_token); + if target_token != draft_token { + return Ok(accepted); + } + } + // All k draft tokens matched — append the verify model's k+1th sample. + // Safe because we validated `target_ids.len() == k + 1` above; the + // .last() returns Some unless the slice is empty (impossible for k+1≥1). + if let Some(&bonus_token) = target_ids.last() { + accepted.push(bonus_token); + } + Ok(accepted) +} + +/// Trait for a draft model that produces candidate tokens for speculative +/// decoding. Implementations may run on any device (GPU, ANE, CPU). +pub trait DraftModel: Send { + /// Prefill the draft model with the given prompt tokens, resetting any + /// prior cache state. Must be called once before the first `draft()` call + /// in a new generation request. + fn prefill(&mut self, prompt_tokens: &[u32]) -> Result<(), EngineError>; + + /// Generate up to `num_draft` greedy tokens starting from `last_token_id`. + fn draft(&mut self, last_token_id: u32, num_draft: usize) -> Result<Vec<u32>, EngineError>; + + /// Advance internal state by `n` accepted tokens. + /// Called after verify confirms the first `n` draft tokens. + fn advance(&mut self, n: usize) -> Result<(), EngineError>; + + /// Roll back to the state before the last `draft()` call. + /// Called when the target rejects draft tokens and we need to resync. 
+ fn rollback(&mut self) -> Result<(), EngineError>; +} + +#[cfg(test)] +#[allow( + clippy::panic, + clippy::unwrap_used, + clippy::indexing_slicing, + clippy::unreachable +)] +mod tests { + use super::*; + + #[test] + fn accept_prefix_all_match_returns_k_plus_one() { + let accepted = accept_prefix(&[5, 3, 7], &[5, 3, 7, 42]).unwrap(); + assert_eq!(accepted, vec![5, 3, 7, 42]); + } + + #[test] + fn accept_prefix_first_mismatch_returns_one() { + let accepted = accept_prefix(&[5, 3, 7], &[9, 1, 2, 0]).unwrap(); + assert_eq!(accepted, vec![9]); + } + + #[test] + fn accept_prefix_mid_mismatch() { + let accepted = accept_prefix(&[5, 3, 7], &[5, 3, 9, 0]).unwrap(); + assert_eq!(accepted, vec![5, 3, 9]); + } + + #[test] + fn accept_prefix_single_draft_match() { + let accepted = accept_prefix(&[10], &[10, 99]).unwrap(); + assert_eq!(accepted, vec![10, 99]); + } + + #[test] + fn accept_prefix_single_draft_mismatch() { + let accepted = accept_prefix(&[10], &[20, 99]).unwrap(); + assert_eq!(accepted, vec![20]); + } + + #[test] + fn accept_prefix_empty_draft() { + let accepted = accept_prefix(&[], &[42]).unwrap(); + assert_eq!(accepted, vec![42]); + } + + #[test] + fn accept_prefix_wrong_length_errors() { + let err = accept_prefix(&[1, 2], &[1, 2]).unwrap_err(); + assert!(err.to_string().contains("must be")); + } + + struct MockDraft { + sequence: Vec<u32>, + cursor: usize, + draft_count: usize, + } + + impl MockDraft { + fn new(sequence: Vec<u32>) -> Self { + Self { + sequence, + cursor: 0, + draft_count: 0, + } + } + } + + impl DraftModel for MockDraft { + fn prefill(&mut self, _prompt_tokens: &[u32]) -> Result<(), EngineError> { + Ok(()) + } + + fn draft( + &mut self, + _last_token_id: u32, + num_draft: usize, + ) -> Result<Vec<u32>, EngineError> { + let mut tokens = Vec::with_capacity(num_draft); + for i in 0..num_draft { + let idx = (self.cursor + i) % self.sequence.len(); + tokens.push(self.sequence[idx]); + } + self.draft_count = num_draft; + Ok(tokens) + } + + fn advance(&mut self, n: 
usize) -> Result<(), EngineError> { + self.cursor = (self.cursor + n) % self.sequence.len(); + self.draft_count = 0; + Ok(()) + } + + fn rollback(&mut self) -> Result<(), EngineError> { + self.draft_count = 0; + Ok(()) + } + } + + #[test] + fn mock_draft_produces_tokens() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let tokens = draft.draft(0, 3).unwrap(); + assert_eq!(tokens, vec![10, 20, 30]); + } + + #[test] + fn mock_draft_advance_shifts_cursor() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let _ = draft.draft(0, 2).unwrap(); + draft.advance(2).unwrap(); + let tokens = draft.draft(0, 2).unwrap(); + assert_eq!(tokens, vec![30, 10]); + } + + #[test] + fn mock_draft_rollback_preserves_cursor() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let _ = draft.draft(0, 2).unwrap(); + draft.rollback().unwrap(); + let tokens = draft.draft(0, 2).unwrap(); + assert_eq!(tokens, vec![10, 20]); + } + + #[test] + fn step_all_accepted_returns_k_plus_one() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let accepted = speculative_step(&mut draft, 0, 3, |batch| { + assert_eq!(batch, &[0, 10, 20, 30]); + Ok(vec![10, 20, 30, 99]) + }) + .unwrap(); + assert_eq!(accepted, vec![10, 20, 30, 99]); + assert_eq!(draft.cursor, 0); + } + + #[test] + fn step_partial_accept_advances_draft() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let accepted = speculative_step(&mut draft, 0, 3, |_| Ok(vec![10, 20, 99, 55])).unwrap(); + assert_eq!(accepted, vec![10, 20, 99]); + assert_eq!(draft.cursor, 2); + } + + #[test] + fn step_no_match_rollback() { + let mut draft = MockDraft::new(vec![10, 20, 30]); + let accepted = speculative_step(&mut draft, 0, 3, |_| Ok(vec![77, 0, 0, 0])).unwrap(); + assert_eq!(accepted, vec![77]); + assert_eq!(draft.cursor, 0); + } + + #[test] + fn step_single_draft_match() { + let mut draft = MockDraft::new(vec![10]); + let accepted = speculative_step(&mut draft, 5, 1, |batch| { + assert_eq!(batch, &[5, 10]); + Ok(vec![10, 42]) 
+ }) + .unwrap(); + assert_eq!(accepted, vec![10, 42]); + } + + #[test] + fn step_verify_error_propagates() { + let mut draft = MockDraft::new(vec![10, 20]); + let err = speculative_step(&mut draft, 0, 2, |_| { + Err(EngineError::Generation("GPU OOM".into())) + }) + .unwrap_err(); + assert!(err.to_string().contains("GPU OOM")); + } + + #[test] + fn loop_generates_until_max_tokens() { + let mut draft = MockDraft::new(vec![1, 2, 3]); + let tokens = speculative_loop(&mut draft, 0, 3, 10, &[999], |batch| { + let mut target = batch[1..].to_vec(); + target.push(50); + Ok(target) + }) + .unwrap(); + assert_eq!(tokens.len(), 10); + } + + #[test] + fn loop_stops_on_eos() { + let mut draft = MockDraft::new(vec![1, 2, 0]); + let tokens = speculative_loop(&mut draft, 99, 3, 100, &[0], |batch| { + let mut target = batch[1..].to_vec(); + target.push(50); + Ok(target) + }) + .unwrap(); + assert!(tokens.contains(&0)); + assert!(tokens.len() < 100); + } + + #[test] + fn loop_with_partial_accepts_still_progresses() { + let mut draft = MockDraft::new(vec![1, 2, 3]); + let tokens = speculative_loop(&mut draft, 0, 3, 6, &[999], |batch| { + let k = batch.len() - 1; + let mut target = vec![77]; + target.resize(k + 1, 0); + Ok(target) + }) + .unwrap(); + assert_eq!(tokens.len(), 6); + assert!(tokens.iter().all(|&t| t == 77)); + } + + #[test] + fn loop_empty_max_tokens() { + let mut draft = MockDraft::new(vec![1]); + let tokens = speculative_loop(&mut draft, 0, 3, 0, &[], |_| unreachable!()).unwrap(); + assert!(tokens.is_empty()); + } +} diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 6aec6383..407f9fbf 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -91,6 +91,29 @@ pub enum AnyCache { Hybrid(Vec<Option<LayerCache>>), } +impl AnyCache { + /// Trim every layer cache by `count` tokens, discarding the most recent + /// entries. Used after speculative-decode verify to roll back rejected + /// draft tokens. 
Hybrid SSM (recurrent) layers are intentionally left + /// untouched — their state cannot be trimmed by offset alone. + pub fn trim_by(&mut self, count: usize) { + match self { + Self::KV(layers) => { + for layer in layers.iter_mut().flatten() { + layer.trim_by(count); + } + } + Self::Hybrid(layers) => { + for layer in layers.iter_mut().flatten() { + if let LayerCache::KV(kv) = layer { + kv.trim_by(count); + } + } + } + } + } +} + /// Unified model wrapper dispatching to the correct architecture. pub enum AnyModel { /// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3. @@ -1189,6 +1212,7 @@ fn remap_quantized_key(key: &str) -> Option { #[allow(clippy::panic, clippy::unwrap_used, clippy::indexing_slicing)] mod tests { use super::*; + use crate::cache::KeyValueCache; fn params(temp: f32, top_p: f32) -> SamplingParams { SamplingParams { @@ -1695,4 +1719,58 @@ mod tests { assert!((vals[1] - 2.0).abs() < 1e-5); assert!((vals[2] - 4.5).abs() < 1e-5); } + + // --- AnyCache::trim_by tests --- + + #[test] + fn any_cache_trim_by_kv_dispatches_to_each_layer() { + // Two KV layers, both at offset 0; trim_by saturates to 0. + // Verifies the dispatcher iterates None and Some(_) layers without panic. + let mut cache = AnyCache::KV(vec![ + Some(cache::SteppingKeyValueCache::new()), + None, + Some(cache::SteppingKeyValueCache::new()), + ]); + cache.trim_by(5); + if let AnyCache::KV(layers) = &cache { + assert_eq!(layers.len(), 3); + for layer in layers.iter().flatten() { + assert_eq!(layer.offset(), 0); + } + } else { + panic!("expected KV variant"); + } + } + + #[test] + fn any_cache_trim_by_hybrid_skips_arrays_layers() { + // Hybrid mixes LayerCache::KV (trimmable) and LayerCache::Arrays (recurrent, + // intentionally untouched). Verifies the dispatcher reaches into KV layers + // and leaves Arrays alone. 
+ let mut arrays = qwen3_next::ArraysCache::new(); + arrays.offset = 7; + let mut cache = AnyCache::Hybrid(vec![ + Some(LayerCache::KV(cache::SteppingKeyValueCache::new())), + Some(LayerCache::Arrays(arrays)), + None, + ]); + cache.trim_by(3); + if let AnyCache::Hybrid(layers) = &cache { + assert_eq!(layers.len(), 3); + // KV layer trimmed (saturated at 0 since starting offset was 0) + if let Some(LayerCache::KV(kv)) = layers.first().and_then(|l| l.as_ref()) { + assert_eq!(kv.offset(), 0); + } else { + panic!("expected first layer to be KV variant"); + } + // Arrays layer offset unchanged (recurrent state, can't trim by offset) + if let Some(LayerCache::Arrays(a)) = layers.get(1).and_then(|l| l.as_ref()) { + assert_eq!(a.offset, 7, "Arrays layer offset must NOT be trimmed"); + } else { + panic!("expected second layer to be Arrays variant"); + } + } else { + panic!("expected Hybrid variant"); + } + } } diff --git a/crates/higgs-models/src/speculation_policy.rs b/crates/higgs-models/src/speculation_policy.rs new file mode 100644 index 00000000..7b204e22 --- /dev/null +++ b/crates/higgs-models/src/speculation_policy.rs @@ -0,0 +1,564 @@ +/// Projection for one speculative depth. +/// +/// `depth` is the number of draft positions verified per cycle. The expected +/// token count includes the target correction/bonus token, so depth 0 would be +/// 1.0 emitted token per ordinary decode step. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct SpeculationProjection { + pub depth: usize, + pub expected_emitted_tokens: f64, + pub cycle_cost_multiplier: f64, + pub speedup: f64, +} + +/// Static description of one GDN value head. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct GateHead { + pub layer: usize, + pub head: usize, + pub a_log: f64, + pub dt_bias: f64, +} + +/// Static score derived from the nominal GDN decay gate. 
+#[derive(Debug, Clone, Copy, PartialEq)] +pub struct GateHeadScore { + pub layer: usize, + pub head: usize, + pub retention: f64, + pub time_constant_steps: f64, +} + +/// Per-head precision decision for head-selective ternary experiments. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HeadPrecision { + PreserveBf16, + QuantizeTernary, +} + +/// Controls how many long-horizon heads stay in higher precision. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ProtectionConfig { + pub protected_fraction: f64, + pub min_protected: usize, + pub max_protected: Option<usize>, + pub max_refresh_interval: usize, +} + +/// Concrete offline plan for one GDN value head. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct HeadExecutionPlan { + pub layer: usize, + pub head: usize, + pub retention: f64, + pub time_constant_steps: f64, + pub precision: HeadPrecision, + pub refresh_every_steps: usize, +} + +/// Expected emitted tokens for conditional per-position acceptance rates. +/// +/// For K draft positions, speculative decoding always emits at least one +/// target token. If position acceptances are conditional on all previous draft +/// positions being accepted, the expectation is: +/// `1 + a0 + a0*a1 + ... + a0*...*a(K-1)`. +pub fn expected_emitted_tokens_conditional(position_acceptance: &[f64]) -> f64 { + let mut prefix_probability = 1.0; + let mut emitted = 1.0; + + for &acceptance in position_acceptance { + prefix_probability *= acceptance.clamp(0.0, 1.0); + emitted += prefix_probability; + } + + emitted +} + +/// Project speedup for depths 1..=N using conditional acceptance rates. +/// +/// `cycle_cost_multipliers[i]` is the relative cost of verifying depth `i + 1`. +/// It should include any MTP-head or drafter overhead. A multiplier of 1.10 +/// means one speculative cycle costs 10% more than a baseline decode cycle. 
+pub fn project_conditional_depths( + position_acceptance: &[f64], + cycle_cost_multipliers: &[f64], +) -> Result<Vec<SpeculationProjection>, &'static str> { + if cycle_cost_multipliers.len() > position_acceptance.len() { + return Err("cycle_cost_multipliers cannot be longer than position_acceptance"); + } + + let mut projections = Vec::with_capacity(cycle_cost_multipliers.len()); + for (idx, &cycle_cost_multiplier) in cycle_cost_multipliers.iter().enumerate() { + if !cycle_cost_multiplier.is_finite() || cycle_cost_multiplier <= 0.0 { + return Err("cycle cost multipliers must be finite and positive"); + } + + let depth = idx + 1; + let expected_emitted_tokens = + expected_emitted_tokens_conditional_prefix(position_acceptance, depth); + projections.push(SpeculationProjection { + depth, + expected_emitted_tokens, + cycle_cost_multiplier, + speedup: expected_emitted_tokens / cycle_cost_multiplier, + }); + } + + Ok(projections) +} + +/// Pick the highest-speedup projection, preferring shallower depths on ties. +pub fn best_projected_depth( + projections: &[SpeculationProjection], +) -> Option<SpeculationProjection> { + projections.iter().copied().max_by(|a, b| { + a.speedup + .total_cmp(&b.speedup) + .then_with(|| b.depth.cmp(&a.depth)) + }) +} + +/// Project end-to-end tok/s from a baseline decode rate and depth projection. +pub fn projected_tps( + baseline_tps: f64, + projection: SpeculationProjection, +) -> Result<f64, &'static str> { + if !baseline_tps.is_finite() || baseline_tps <= 0.0 { + return Err("baseline_tps must be finite and positive"); + } + if !projection.speedup.is_finite() || projection.speedup <= 0.0 { + return Err("projection speedup must be finite and positive"); + } + + Ok(baseline_tps * projection.speedup) +} + +/// Whether a projection clears a target tok/s threshold. 
+pub fn meets_tps_target( + baseline_tps: f64, + projection: SpeculationProjection, + target_tps: f64, +) -> Result<bool, &'static str> { + if !target_tps.is_finite() || target_tps <= 0.0 { + return Err("target_tps must be finite and positive"); + } + + Ok(projected_tps(baseline_tps, projection)? >= target_tps) +} + +/// Nominal one-step GDN retention from static gate parameters. +/// +/// This evaluates the recurrent decay at `a = 0`: +/// `g = exp(-exp(A_log) * softplus(dt_bias))`. +pub fn nominal_gdn_retention(a_log: f64, dt_bias: f64) -> f64 { + let rate = a_log.exp() * softplus(dt_bias); + (-rate).exp().clamp(0.0, 1.0) +} + +/// Time constant, in decode steps, implied by a one-step retention value. +pub fn refresh_interval_steps(retention: f64, max_interval: f64) -> f64 { + if !max_interval.is_finite() || max_interval <= 0.0 { + return 0.0; + } + if !retention.is_finite() || retention <= 0.0 { + return 0.0; + } + if retention >= 1.0 { + return max_interval; + } + + let interval = -1.0 / retention.ln(); + if interval.is_finite() { + interval.min(max_interval) + } else { + max_interval + } +} + +/// Integer refresh cadence derived from the retention time constant. +/// +/// This floors the continuous time constant so fast-forgetting heads stay on a +/// one-step cadence. Long-horizon heads are capped by `max_interval`. +pub fn refresh_every_steps(retention: f64, max_interval: usize) -> usize { + if max_interval == 0 { + return 0; + } + + let interval = refresh_interval_steps(retention, max_interval as f64); + if interval < 1.0 { + 1 + } else { + interval.floor() as usize + } +} + +/// Score heads by static GDN retention. 
+pub fn score_gdn_heads(heads: &[GateHead]) -> Vec<GateHeadScore> { + heads + .iter() + .map(|head| { + let retention = nominal_gdn_retention(head.a_log, head.dt_bias); + GateHeadScore { + layer: head.layer, + head: head.head, + retention, + time_constant_steps: refresh_interval_steps(retention, f64::INFINITY), + } + }) + .collect() +} + +/// Return the longest-horizon heads first for selective protection. +pub fn select_long_horizon_heads(heads: &[GateHead], protect_count: usize) -> Vec<GateHeadScore> { + let mut scores = score_gdn_heads(heads); + scores.sort_by(|a, b| { + b.retention + .total_cmp(&a.retention) + .then_with(|| a.layer.cmp(&b.layer)) + .then_with(|| a.head.cmp(&b.head)) + }); + scores.truncate(protect_count.min(scores.len())); + scores +} + +/// Build a BF16-protection and temporal-refresh plan for static GDN heads. +pub fn plan_head_selective_ternary( + heads: &[GateHead], + config: ProtectionConfig, +) -> Result<Vec<HeadExecutionPlan>, &'static str> { + if !config.protected_fraction.is_finite() + || config.protected_fraction < 0.0 + || config.protected_fraction > 1.0 + { + return Err("protected_fraction must be finite and within 0..=1"); + } + if config.max_refresh_interval == 0 { + return Err("max_refresh_interval must be at least 1"); + } + + let raw_count = (heads.len() as f64 * config.protected_fraction).ceil() as usize; + let mut protect_count = raw_count.max(config.min_protected).min(heads.len()); + if let Some(max_protected) = config.max_protected { + protect_count = protect_count.min(max_protected); + } + + let protected = select_long_horizon_heads(heads, protect_count); + let protected_keys: std::collections::BTreeSet<(usize, usize)> = protected + .iter() + .map(|head| (head.layer, head.head)) + .collect(); + + let mut scores = score_gdn_heads(heads); + scores.sort_by(|a, b| a.layer.cmp(&b.layer).then_with(|| a.head.cmp(&b.head))); + + Ok(scores + .into_iter() + .map(|score| { + let precision = if protected_keys.contains(&(score.layer, score.head)) { + HeadPrecision::PreserveBf16 + } else 
{ + HeadPrecision::QuantizeTernary + }; + HeadExecutionPlan { + layer: score.layer, + head: score.head, + retention: score.retention, + time_constant_steps: score.time_constant_steps, + precision, + refresh_every_steps: refresh_every_steps( + score.retention, + config.max_refresh_interval, + ), + } + }) + .collect()) +} + +/// Extract static GDN gate heads from flattened model parameter names. +/// +/// Accepts names ending in `model.layers.N.linear_attn.A_log` and +/// `model.layers.N.linear_attn.dt_bias`, including longer prefixes such as +/// `language_model.model.layers.N...`. Non-GDN tensors are ignored. +pub fn extract_gate_heads_from_tensors<'a, I>(tensors: I) -> Result<Vec<GateHead>, String> +where + I: IntoIterator<Item = (&'a str, &'a [f64])>, +{ + let mut by_layer: std::collections::BTreeMap<usize, StaticGatePair<'a>> = + std::collections::BTreeMap::new(); + + for (name, values) in tensors { + let Some((layer, kind)) = parse_static_gate_name(name) else { + continue; + }; + let entry = by_layer.entry(layer).or_default(); + match kind { + StaticGateKind::ALog => entry.a_log = Some(values), + StaticGateKind::DtBias => entry.dt_bias = Some(values), + } + } + + let mut heads = Vec::new(); + for (layer, pair) in by_layer { + let a_log = pair + .a_log + .ok_or_else(|| format!("layer {layer}: missing A_log"))?; + let dt_bias = pair + .dt_bias + .ok_or_else(|| format!("layer {layer}: missing dt_bias"))?; + + if a_log.len() != dt_bias.len() { + return Err(format!( + "layer {layer}: mismatched A_log ({}) and dt_bias ({}) lengths", + a_log.len(), + dt_bias.len() + )); + } + + for (head, (&a_log, &dt_bias)) in a_log.iter().zip(dt_bias.iter()).enumerate() { + if !a_log.is_finite() || !dt_bias.is_finite() { + return Err(format!("layer {layer} head {head}: non-finite gate value")); + } + heads.push(GateHead { + layer, + head, + a_log, + dt_bias, + }); + } + } + + Ok(heads) +} + +#[derive(Default)] +struct StaticGatePair<'a> { + a_log: Option<&'a [f64]>, + dt_bias: Option<&'a [f64]>, +} + +#[derive(Clone, Copy)] +enum StaticGateKind { + 
ALog, + DtBias, +} + +fn parse_static_gate_name(name: &str) -> Option<(usize, StaticGateKind)> { + let (_, after_layers) = name.split_once(".layers.")?; + let (layer, after_layer) = after_layers.split_once('.')?; + let layer = layer.parse::<usize>().ok()?; + let suffix = after_layer.strip_prefix("linear_attn.")?; + match suffix { + "A_log" => Some((layer, StaticGateKind::ALog)), + "dt_bias" => Some((layer, StaticGateKind::DtBias)), + _ => None, + } +} + +fn expected_emitted_tokens_conditional_prefix(position_acceptance: &[f64], depth: usize) -> f64 { + let mut prefix_probability = 1.0; + let mut emitted = 1.0; + + for &acceptance in position_acceptance.iter().take(depth) { + prefix_probability *= acceptance.clamp(0.0, 1.0); + emitted += prefix_probability; + } + + emitted +} + +fn softplus(x: f64) -> f64 { + if x > 40.0 { + x + } else if x < -40.0 { + x.exp() + } else { + (1.0 + x.exp()).ln() + } +} + +#[cfg(test)] +#[allow(clippy::indexing_slicing, clippy::panic, clippy::unwrap_used)] +mod tests { + use super::*; + + fn assert_close(actual: f64, expected: f64, tolerance: f64) { + assert!( + (actual - expected).abs() <= tolerance, + "actual={actual} expected={expected} tolerance={tolerance}" + ); + } + + #[test] + fn conditional_acceptance_matches_geometric_mtp_model() { + let emitted = expected_emitted_tokens_conditional(&[0.6, 0.6, 0.6]); + assert_close(emitted, 2.176, 1e-12); + } + + #[test] + fn decayed_position_acceptance_can_select_mtp2_over_mtp3() { + let projections = + project_conditional_depths(&[0.847, 0.587, 0.493], &[1.03, 1.10, 1.35]).unwrap(); + + assert_close(projections[1].expected_emitted_tokens, 2.344_189, 1e-6); + let best = best_projected_depth(&projections).unwrap(); + assert_eq!(best.depth, 2); + } + + #[test] + fn static_gate_scores_rank_long_horizon_heads_first() { + let heads = [ + GateHead { + layer: 2, + head: 0, + a_log: 0.0, + dt_bias: 0.0, + }, + GateHead { + layer: 0, + head: 8, + a_log: -10.0, + dt_bias: -10.0, + }, + GateHead { + layer: 
1, + head: 3, + a_log: 2.0, + dt_bias: 1.0, + }, + ]; + + let protected = select_long_horizon_heads(&heads, 2); + assert_eq!(protected[0].layer, 0); + assert_eq!(protected[0].head, 8); + assert!(protected[0].retention > 0.999_999_99); + assert_eq!(protected[1].layer, 2); + assert_eq!(protected[1].head, 0); + } + + #[test] + fn refresh_interval_tracks_retention_time_constant() { + assert_close( + refresh_interval_steps(0.5, 512.0), + 1.442_695_040_888_963_4, + 1e-12, + ); + assert_close( + refresh_interval_steps(0.99, 512.0), + 99.499_162_473_422_07, + 1e-10, + ); + assert_close(refresh_interval_steps(0.999_999, 128.0), 128.0, 1e-12); + } + + #[test] + fn head_selective_ternary_plan_protects_configured_fraction() { + let heads = [ + GateHead { + layer: 0, + head: 0, + a_log: -10.0, + dt_bias: -10.0, + }, + GateHead { + layer: 0, + head: 1, + a_log: 0.0, + dt_bias: 0.0, + }, + GateHead { + layer: 0, + head: 2, + a_log: 2.0, + dt_bias: 1.0, + }, + GateHead { + layer: 0, + head: 3, + a_log: 3.0, + dt_bias: 2.0, + }, + ]; + + let plan = plan_head_selective_ternary( + &heads, + ProtectionConfig { + protected_fraction: 0.25, + min_protected: 1, + max_protected: None, + max_refresh_interval: 128, + }, + ) + .unwrap(); + + assert_eq!(plan.len(), 4); + assert_eq!(plan[0].precision, HeadPrecision::PreserveBf16); + assert_eq!(plan[0].refresh_every_steps, 128); + assert!( + plan[1..] 
+ .iter() + .all(|entry| entry.precision == HeadPrecision::QuantizeTernary) + ); + } + + #[test] + fn temporal_schedule_is_conservative_for_fast_forgetting_heads() { + assert_eq!(refresh_every_steps(0.5, 512), 1); + assert_eq!(refresh_every_steps(0.99, 512), 99); + assert_eq!(refresh_every_steps(0.999_999, 128), 128); + } + + #[test] + fn extracts_gate_heads_from_flattened_qwen_parameter_names() { + let a0 = [-10.0, 0.0]; + let dt0 = [-10.0, 0.0]; + let a1 = [2.0]; + let dt1 = [1.0]; + let ignored = [123.0]; + let heads = extract_gate_heads_from_tensors([ + ("model.layers.0.linear_attn.A_log", a0.as_slice()), + ("model.layers.0.linear_attn.dt_bias", dt0.as_slice()), + ( + "language_model.model.layers.1.linear_attn.A_log", + a1.as_slice(), + ), + ( + "language_model.model.layers.1.linear_attn.dt_bias", + dt1.as_slice(), + ), + ("model.layers.3.self_attn.q_proj.weight", ignored.as_slice()), + ]) + .unwrap(); + + assert_eq!(heads.len(), 3); + assert_eq!(heads[0].layer, 0); + assert_eq!(heads[0].head, 0); + assert_close(heads[0].a_log, -10.0, 0.0); + assert_close(heads[0].dt_bias, -10.0, 0.0); + assert_eq!(heads[2].layer, 1); + assert_eq!(heads[2].head, 0); + } + + #[test] + fn extractor_rejects_mismatched_static_gate_shapes() { + let a0 = [0.0, 1.0]; + let dt0 = [0.0]; + let err = extract_gate_heads_from_tensors([ + ("model.layers.0.linear_attn.A_log", a0.as_slice()), + ("model.layers.0.linear_attn.dt_bias", dt0.as_slice()), + ]) + .unwrap_err(); + + assert!(err.contains("mismatched")); + } + + #[test] + fn throughput_projection_confirms_mtp2_clears_20_tps_target() { + let projections = + project_conditional_depths(&[0.847, 0.587, 0.493], &[1.03, 1.10, 1.35]).unwrap(); + let best = best_projected_depth(&projections).unwrap(); + + assert_eq!(best.depth, 2); + assert_close(projected_tps(12.0, best).unwrap(), 25.572_970_909, 1e-9); + assert!(meets_tps_target(12.0, best, 20.0).unwrap()); + } +} diff --git a/crates/higgs/src/config.rs b/crates/higgs/src/config.rs index 
b9967aa4..58fbf9d1 100644 --- a/crates/higgs/src/config.rs +++ b/crates/higgs/src/config.rs @@ -201,6 +201,14 @@ pub struct ServeArgs { #[arg(long)] pub batch: bool, + /// Path to a draft model for speculative decoding (MLX or ANE). + #[arg(long)] + pub draft_model: Option<String>, + + /// Number of draft tokens per speculative cycle (default: 8). + #[arg(long, default_value = "8")] + pub num_draft: usize, + /// KV cache mode for simple mode models. #[arg(long, value_name = "MODE", value_parser = ["off", "turboquant"])] pub kv_cache: Option<String>, @@ -389,6 +397,12 @@ pub struct ModelConfig { /// Enable the separate batch engine for this model. #[serde(default)] pub batch: bool, + /// Path to a draft model for speculative decoding. + #[serde(default)] + pub draft_model: Option<String>, + /// Number of draft tokens per speculative cycle. + #[serde(default = "default_num_draft")] + pub num_draft: usize, /// KV-cache storage mode. #[serde(default)] pub kv_cache: KvCacheMode, @@ -416,6 +430,10 @@ const fn default_norm_correction() -> bool { true } +const fn default_num_draft() -> usize { + 8 +} + const fn default_kv_bits() -> u8 { 3 } @@ -655,6 +673,8 @@ pub fn build_simple_config(args: &ServeArgs) -> Result { name: None, mlx_profile: None, batch: args.batch, + draft_model: args.draft_model.clone(), + num_draft: args.num_draft, kv_cache, kv_bits: args.kv_bits.unwrap_or(default_kv_bits()), kv_key_bits: args.kv_key_bits, @@ -742,6 +762,8 @@ pub fn load_config_file(path: &Path, args: Option<&ServeArgs>) -> Result DoctorResult { check_config_valid(&mut result); check_server_section(config, &mut result); check_models(config, &mut result); + check_draft_models(config, &mut result); check_duplicate_models(config, &mut result); check_providers(config, &mut result).await; check_route_consistency(config, &mut result); @@ -219,6 +220,30 @@ fn check_models(config: &HiggsConfig, result: &mut DoctorResult) { } } +fn check_draft_models(config: &HiggsConfig, result: &mut DoctorResult) { + for model in
&config.models { + let Some(ref draft_path) = model.draft_model else { + continue; + }; + let label = model_label(model); + match model_resolver::resolve(draft_path) { + Ok(_) => pass(&format!("draft model for {label} resolvable"), result), + Err(err) => fail( + &format!("draft model \"{draft_path}\" for {label} not found: {err}"), + result, + ), + } + if model.batch { + warn( + &format!( + "{label} has draft_model but batch=true; speculative decoding is only supported with SimpleEngine" + ), + result, + ); + } + } +} + fn check_duplicate_models(config: &HiggsConfig, result: &mut DoctorResult) { let mut seen_paths = HashSet::new(); let mut seen_names = HashSet::new(); @@ -444,6 +469,26 @@ mod tests { } } + /// Build a `ModelConfig` with sensible test defaults. Tests override only + /// the fields they care about via struct-update syntax. + fn test_model_config(path: &str) -> ModelConfig { + ModelConfig { + path: path.to_owned(), + name: None, + mlx_profile: None, + batch: false, + draft_model: None, + num_draft: 8, + kv_cache: higgs_models::turboquant::KvCacheMode::Off, + kv_bits: 3, + kv_seed: 0, + kv_key_bits: None, + kv_value_bits: None, + kv_norm_correction: true, + kv_adaptive_dense_layers: 0, + } + } + // -- Helper function counter tests -- #[test] @@ -484,6 +529,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -497,6 +544,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -523,6 +572,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -536,6 +587,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: 
higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -671,6 +724,52 @@ mod tests { assert_eq!(result.failures, 0); } + // -- Draft model validation -- + + #[test] + fn test_draft_model_not_found_fails() { + let config = HiggsConfig { + models: vec![ModelConfig { + draft_model: Some("org/nonexistent-draft".to_owned()), + ..test_model_config("org/target-model") + }], + ..HiggsConfig::default() + }; + let mut result = empty_result(); + check_draft_models(&config, &mut result); + assert_eq!(result.failures, 1); + } + + #[test] + fn test_draft_model_with_batch_warns() { + let config = HiggsConfig { + models: vec![ModelConfig { + batch: true, + draft_model: Some("org/some-draft".to_owned()), + ..test_model_config("org/target-model") + }], + ..HiggsConfig::default() + }; + let mut result = empty_result(); + check_draft_models(&config, &mut result); + // Fails for unresolvable path + warns for batch incompatibility + assert!(result.failures >= 1); + assert_eq!(result.warnings, 1); + } + + #[test] + fn test_no_draft_model_skips() { + let config = HiggsConfig { + models: vec![test_model_config("org/model")], + ..HiggsConfig::default() + }; + let mut result = empty_result(); + check_draft_models(&config, &mut result); + assert_eq!(result.passes, 0); + assert_eq!(result.warnings, 0); + assert_eq!(result.failures, 0); + } + // -- Default provider -- #[test] @@ -946,6 +1045,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -976,6 +1077,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, @@ -1008,6 +1111,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0, diff --git 
a/crates/higgs/src/tui/mod.rs b/crates/higgs/src/tui/mod.rs index 25423366..94187907 100644 --- a/crates/higgs/src/tui/mod.rs +++ b/crates/higgs/src/tui/mod.rs @@ -646,6 +646,8 @@ mod tests { name: None, mlx_profile: None, batch: false, + draft_model: None, + num_draft: 8, kv_cache: higgs_models::turboquant::KvCacheMode::Off, kv_bits: 3, kv_seed: 0,