diff --git a/crates/rmlx-models/src/gemma4/audio/mod.rs b/crates/rmlx-models/src/gemma4/audio/mod.rs
index 7274ba3..a3b3c2c 100644
--- a/crates/rmlx-models/src/gemma4/audio/mod.rs
+++ b/crates/rmlx-models/src/gemma4/audio/mod.rs
@@ -1057,5 +1057,11 @@ fn build_attention(
     }
 }
 
+// ---------------------------------------------------------------------------
+// Unified (encoder-free) audio embedder — `Gemma4UnifiedForConditionalGeneration`.
+// ---------------------------------------------------------------------------
+
+pub mod unified;
+
 #[cfg(test)]
 mod tests;
diff --git a/crates/rmlx-models/src/gemma4/audio/unified.rs b/crates/rmlx-models/src/gemma4/audio/unified.rs
new file mode 100644
index 0000000..7e86163
--- /dev/null
+++ b/crates/rmlx-models/src/gemma4/audio/unified.rs
@@ -0,0 +1,359 @@
+//! Gemma4 **unified** (`gemma4_unified` / `Gemma4UnifiedForConditionalGeneration`)
+//! encoder-free audio embedder.
+//!
+//! Faithful host+MLX port of the HF Transformers `gemma4_unified`
+//! `Gemma4UnifiedAudioFeatureExtractor` (waveform → fixed-length frames) +
+//! `Gemma4UnifiedMultimodalEmbedder` (`RMSNormNoScale -> embedding_projection`).
+//!
+//! ## Why this is a different path than the Conformer audio tower
+//!
+//! The unified 12B has **no audio transformer**. Audio is early-fusion: the raw
+//! 16 kHz mono waveform is chunked into fixed-length frames of
+//! `audio_samples_per_token` (640) samples — 40 ms each at 16 kHz — and each
+//! frame is projected straight into the shared 48-layer LM hidden space via the
+//! same `embed_audio.embedding_projection` (`RMSNormNoScale -> Linear`) the
+//! standard family reuses for its Conformer output. The 12B snapshot ships only
+//! `embed_audio.embedding_projection.{weight,scales}` — there is **no**
+//! `audio_tower.*`. The standard `gemma4` family (e4b/26b) keeps the existing
+//! [`super::AudioEncoder`] Conformer path.
+//!
+//! ## Reference pipeline (`Gemma4UnifiedAudioFeatureExtractor._extract_waveform_features`)
+//!
+//! ```text
+//! pad_len  = (-len(waveform)) % audio_samples_per_token          # zero-pad tail
+//! num_tok  = len(waveform) // audio_samples_per_token            # ceil(n / 640)
+//! features = waveform.reshape(num_tok, audio_samples_per_token)  # [T, 640] f32
+//! ```
+//!
+//! There is **no** mel spectrogram, no windowing, no per-sample normalization or
+//! scaling — the raw float waveform is fed directly. Since there is no
+//! downsampling, `num_soft_tokens == num_tok == ceil(num_samples / 640)`.
+//! `audio_embed_dim == audio_samples_per_token == 640` (each soft token's
+//! feature vector is exactly one frame of raw samples). The frames are then
+//! projected via [`super::vision::MultimodalEmbedder`] (`embed_audio`) into the
+//! 3840-wide text hidden space and scattered at the `<|audio|>` positions,
+//! exactly mirroring [`super::vision::unified::build_unified_inputs_embeds`].
+
+use std::path::Path;
+
+use rmlx_core::error::{Error, Result};
+use rmlx_mlx::{multiply, scalar_f32, Array, Device};
+use tracing::{debug, info};
+
+use super::super::vision::MultimodalEmbedder;
+
+// ---------------------------------------------------------------------------
+// Config
+// ---------------------------------------------------------------------------
+
+/// `audio_config` for `gemma4_unified` (`model_type: gemma4_unified_audio`).
+///
+/// Distinct from [`super::super::config::Gemma4AudioConfig`] (the Conformer
+/// tower): the unified audio config has **no** `num_hidden_layers` / heads /
+/// conv params — it is encoder-free. Fields are the complete contract for the
+/// embedder.
+#[allow(
+    clippy::exhaustive_structs,
+    reason = "internal closed config struct — fields are the complete gemma4_unified audio-embedder contract; adding a field requires updating from_json"
+)]
+#[derive(Debug, Clone)]
+pub struct UnifiedAudioConfig {
+    /// Raw-sample feature dim of one audio frame (640). Equals
+    /// `audio_samples_per_token` on the snapshot; the `embed_audio` projection
+    /// consumes vectors of this width.
+    pub audio_embed_dim: usize,
+    /// Raw audio samples grouped into one soft token (640 — 40 ms at 16 kHz).
+    pub audio_samples_per_token: usize,
+    /// `embed_audio.embedding_projection` input dim (640 = output_proj_dims).
+    /// Validated against the loaded projection weight at load time.
+    pub output_proj_dims: usize,
+    /// RMSNorm epsilon (1e-6).
+    pub rms_norm_eps: f32,
+    /// `<|audio|>` token id (top-level `audio_token_id`, 258881 on the 12B).
+    pub audio_token_id: u32,
+}
+
+impl UnifiedAudioConfig {
+    /// Parse from the `audio_config` JSON object. `audio_token_id` is the
+    /// top-level config value. Missing keys fall back to the verified
+    /// `gemma-4-12B` `gemma4_unified_audio` values.
+    pub fn from_json(v: &serde_json::Value, audio_token_id: u32) -> Self {
+        // A value that does not fit `usize` is treated like a missing key
+        // instead of being silently truncated by an `as` cast.
+        let u = |key: &str, dflt: usize| -> usize {
+            v.get(key)
+                .and_then(serde_json::Value::as_u64)
+                .and_then(|x| usize::try_from(x).ok())
+                .unwrap_or(dflt)
+        };
+        // f64 -> f32 narrowing is intentional: the epsilon is consumed as f32.
+        let f = |key: &str, dflt: f32| -> f32 {
+            v.get(key)
+                .and_then(serde_json::Value::as_f64)
+                .map_or(dflt, |x| x as f32)
+        };
+        Self {
+            audio_embed_dim: u("audio_embed_dim", 640),
+            audio_samples_per_token: u("audio_samples_per_token", 640),
+            output_proj_dims: u("output_proj_dims", 640),
+            rms_norm_eps: f("rms_norm_eps", 1e-6),
+            audio_token_id,
+        }
+    }
+
+    /// Read the `audio_config` object (and the top-level `audio_token_id`,
+    /// defaulting to 258881) from `config.json` inside `model_dir`.
+    ///
+    /// This function does not inspect `architectures`; deciding whether the
+    /// snapshot is the unified arch is left to the caller. Returns `Ok(None)`
+    /// if there is no `audio_config` key.
+    ///
+    /// # Errors
+    ///
+    /// [`Error::Config`] when `config.json` cannot be read or does not hold a
+    /// usable config.
+    pub fn from_model_dir(model_dir: &Path) -> Result<Option<Self>> {
+        let path = model_dir.join("config.json");
+        let data = std::fs::read(&path).map_err(|e| {
+            Error::Config(format!(
+                "gemma4_unified: cannot read {}: {e}",
+                path.display()
+            ))
+        })?;
+        let v: serde_json::Value = serde_json::from_slice(&data).map_err(|e| {
+            Error::Config(format!(
+                "gemma4_unified: malformed config.json at {}: {e}",
+                path.display()
+            ))
+        })?;
+        // No audio sub-config: nothing to parse (and nothing to validate).
+        let Some(audio_cfg) = v.get("audio_config") else {
+            return Ok(None);
+        };
+        // Token ids are u32 downstream; reject an out-of-range id instead of
+        // letting an `as u32` cast wrap it onto an unrelated token.
+        let audio_token_id = match v
+            .get("audio_token_id")
+            .and_then(serde_json::Value::as_u64)
+        {
+            Some(id) => u32::try_from(id).map_err(|_| {
+                Error::Config(format!(
+                    "gemma4_unified: audio_token_id {id} in {} does not fit in u32",
+                    path.display()
+                ))
+            })?,
+            None => 258_881,
+        };
+        Ok(Some(Self::from_json(audio_cfg, audio_token_id)))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Encoder-free audio feature extraction (host)
+// ---------------------------------------------------------------------------
+
+/// Chunk a raw 16 kHz mono waveform into fixed-length frames of
+/// `audio_samples_per_token` samples. The tail is zero-padded to a multiple of
+/// the frame size (within the last frame — no extra frame is created).
+///
+/// Faithful to `Gemma4UnifiedAudioFeatureExtractor._extract_waveform_features`:
+/// no normalization, no windowing, raw float samples. Returns the flat
+/// `[num_tokens * audio_samples_per_token]` f32 buffer plus `num_tokens`.
+///
+/// An `audio_samples_per_token` of 0 is treated as 1 so the division is
+/// always defined; an empty `samples` slice yields `(vec![], 0)`.
+#[must_use]
+pub fn extract_waveform_frames(
+    samples: &[f32],
+    audio_samples_per_token: usize,
+) -> (Vec<f32>, usize) {
+    let spt = audio_samples_per_token.max(1);
+    // ceil(len / spt) frames; the last one is completed with zeros.
+    let num_tokens = samples.len().div_ceil(spt);
+    let padded_len = num_tokens * spt;
+    let mut frames = Vec::with_capacity(padded_len);
+    frames.extend_from_slice(samples);
+    frames.resize(padded_len, 0.0_f32);
+    (frames, num_tokens)
+}
+
+/// Number of audio soft tokens a clip of `num_samples` raw samples yields:
+/// `ceil(num_samples / audio_samples_per_token)`. Must equal the model-frame
+/// count produced by [`extract_waveform_frames`].
+#[inline]
+#[must_use]
+pub fn unified_num_audio_soft_tokens(num_samples: usize, cfg: &UnifiedAudioConfig) -> usize {
+    num_samples.div_ceil(cfg.audio_samples_per_token.max(1))
+}
+
+// ---------------------------------------------------------------------------
+// Unified audio embedder
+// ---------------------------------------------------------------------------
+
+/// Encoder-free unified audio embedder (`embed_audio.*`).
+///
+/// Wraps the shared [`MultimodalEmbedder`] (`RMSNormNoScale -> embedding_projection`)
+/// — the audio path has no tower, conv front-end, or per-frame state, so the
+/// embedder *is* the projection. The host feature front-end
+/// ([`extract_waveform_frames`]) runs before [`Self::forward`].
+#[allow(missing_debug_implementations)]
+pub struct UnifiedAudioEmbedder {
+    cfg: UnifiedAudioConfig,
+    /// Shared `RMSNormNoScale -> embedding_projection` (`embed_audio.*`).
+    embed_audio: MultimodalEmbedder,
+}
+
+impl UnifiedAudioEmbedder {
+    /// Parsed unified-audio sub-config.
+    #[must_use]
+    pub fn config(&self) -> &UnifiedAudioConfig {
+        &self.cfg
+    }
+
+    /// Project chunked waveform frames into the text hidden space.
+    ///
+    /// `frames`: flat `[num_tokens * audio_embed_dim]` f32 from
+    /// [`extract_waveform_frames`]. Returns `[1, num_tokens, text_hidden]` ready
+    /// to scatter into `inputs_embeds`.
+    pub fn forward(&self, frames: &[f32], num_tokens: usize, device: Device) -> Result<Array> {
+        let width = self.cfg.audio_embed_dim;
+        // Reject a buffer whose length disagrees with the requested shape
+        // here, with a message naming both sides, instead of relying on the
+        // array constructor (frames are chunked by `audio_samples_per_token`
+        // but reshaped by `audio_embed_dim`, so a config mismatch lands here).
+        if num_tokens.checked_mul(width) != Some(frames.len()) {
+            return Err(Error::Model(format!(
+                "gemma4_unified audio: frame buffer holds {} samples, expected \
+                 {num_tokens} tokens x {width} (audio_embed_dim)",
+                frames.len()
+            )));
+        }
+        // Array shapes are i32; convert with a check rather than an `as` cast.
+        let to_i32 = |n: usize, what: &str| -> Result<i32> {
+            i32::try_from(n).map_err(|_| {
+                Error::Model(format!(
+                    "gemma4_unified audio: {what} ({n}) exceeds i32::MAX"
+                ))
+            })
+        };
+        let shape = [
+            1,
+            to_i32(num_tokens, "num_tokens")?,
+            to_i32(width, "audio_embed_dim")?,
+        ];
+        let frame_arr = Array::from_f32_slice(frames, &shape)?;
+        let out = self.embed_audio.forward(&frame_arr, device)?;
+        debug!(
+            num_tokens,
+            audio_embed_dim = self.cfg.audio_embed_dim,
+            "gemma4_unified audio: embedder forward"
+        );
+        Ok(out)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Loader
+// ---------------------------------------------------------------------------
+
+/// Load the unified audio embedder (`embed_audio.*`) from a snapshot directory.
+/// Errors if the unified-audio weights are absent (caller disables audio input
+/// on error).
+pub fn load_unified_audio_embedder(
+    model_dir: &Path,
+    cfg: &UnifiedAudioConfig,
+) -> Result<UnifiedAudioEmbedder> {
+    // embed_audio: RMSNormNoScale -> embedding_projection (reused loader).
+    let embed_audio =
+        super::super::vision::load_multimodal_embedder(model_dir, "embed_audio", cfg.rms_norm_eps)?;
+
+    // Validate the parsed `output_proj_dims` / `audio_embed_dim` against the
+    // loaded projection's actual input feature dim — making the config field
+    // load-bearing (mirrors the unified-vision loader's check). A checkpoint
+    // whose `embed_audio.embedding_projection` does not consume
+    // `output_proj_dims` features is rejected here, not deep in forward.
+    // `None` means the projection shape could not be determined; in that case
+    // no validation is performed.
+    if let Some(proj_in) = embed_audio.projection_input_dim() {
+        for (name, dim) in [
+            ("output_proj_dims", cfg.output_proj_dims),
+            ("audio_embed_dim", cfg.audio_embed_dim),
+        ] {
+            if proj_in != dim {
+                return Err(Error::Loader(format!(
+                    "gemma4_unified audio: {name} ({dim}) != embed_audio.embedding_projection \
+                     input dim ({proj_in}) — config/checkpoint mismatch"
+                )));
+            }
+        }
+    }
+
+    info!(
+        audio_embed_dim = cfg.audio_embed_dim,
+        audio_samples_per_token = cfg.audio_samples_per_token,
+        audio_token_id = cfg.audio_token_id,
+        "gemma4_unified audio: embedder loaded (encoder-free)"
+    );
+
+    Ok(UnifiedAudioEmbedder {
+        cfg: cfg.clone(),
+        embed_audio,
+    })
+}
+
+// ---------------------------------------------------------------------------
+// build unified inputs_embeds (text + scattered audio soft tokens)
+// ---------------------------------------------------------------------------
+
+/// Build the merged `inputs_embeds` for a unified-arch audio prompt.
+///
+/// Mirrors [`super::super::vision::unified::build_unified_inputs_embeds`] but
+/// routes the encode through the encoder-free [`UnifiedAudioEmbedder`]. The clip
+/// contributes `num_tokens` soft tokens scattered at its contiguous run of
+/// `audio_token_id` positions.
+///
+/// `frames`: flat `[num_tokens * audio_embed_dim]` f32 (host feature front-end).
+/// Returns `(inputs_embeds [1, seq, hidden], masked_ids [seq])`.
+///
+/// # Errors
+///
+/// [`Error::Model`] when the number of `audio_token_id` positions in
+/// `input_ids` differs from `num_tokens`, when those positions are not one
+/// contiguous run, or when the embedder output is not
+/// `[1, num_tokens, hidden]`. Errors from the underlying array ops are
+/// propagated.
+#[allow(clippy::too_many_arguments)]
+pub fn build_unified_audio_inputs_embeds(
+    model: &super::super::model::Gemma4Text,
+    embedder: &UnifiedAudioEmbedder,
+    frames: &[f32],
+    num_tokens: usize,
+    audio_token_id: u32,
+    input_ids: &[u32],
+    device: Device,
+) -> Result<(Array, Array)> {
+    let hidden = model.cfg.hidden_size as i32;
+    let seq = input_ids.len();
+
+    // Locate the audio-token positions (one contiguous run per clip).
+    let audio_positions: Vec<usize> = input_ids
+        .iter()
+        .enumerate()
+        .filter(|(_, &t)| t == audio_token_id)
+        .map(|(i, _)| i)
+        .collect();
+    if audio_positions.len() != num_tokens {
+        return Err(Error::Model(format!(
+            "gemma4_unified audio: {} audio-token ({audio_token_id}) positions in prompt != \
+             {num_tokens} audio soft tokens — scatter would misalign",
+            audio_positions.len()
+        )));
+    }
+    info!(
+        audio_tokens = audio_positions.len(),
+        num_tokens,
+        seq,
+        "gemma4_unified audio: building inputs_embeds (token count == soft tokens)"
+    );
+
+    // The audio run must be contiguous (one `<|audio|>` run per clip). This is
+    // a host-only check, so it runs before any embedding work is scheduled.
+    let first = audio_positions.first().copied().unwrap_or(0);
+    let contiguous = audio_positions
+        .iter()
+        .enumerate()
+        .all(|(k, &p)| p == first + k);
+    if !contiguous {
+        return Err(Error::Model(format!(
+            "gemma4_unified audio: audio-token positions are not contiguous (got {audio_positions:?})"
+        )));
+    }
+
+    // Scaled text embeddings: embed_tokens(ids) * sqrt(hidden).
+    let ids_i32: Vec<i32> = input_ids.iter().map(|&x| x as i32).collect();
+    let ids_arr = Array::from_i32_slice(&ids_i32, &[seq as i32])?;
+    let h_raw = model.embed_tokens.forward(&ids_arr, device)?;
+    let embed_scale = scalar_f32((model.cfg.hidden_size as f32).sqrt());
+    let mut embeds = multiply(&h_raw, &embed_scale, device)?;
+    embeds = embeds.reshape(&[1, seq as i32, hidden], device)?;
+
+    // An empty clip has no soft tokens: skip the encode and the zero-width
+    // scatter and return the plain text embeddings.
+    if num_tokens > 0 {
+        let embeds_dtype = embeds.dtype();
+
+        // Encode audio frames -> [1, num_tokens, hidden] f32.
+        let feats = embedder.forward(frames, num_tokens, device)?;
+        let fs = feats.shape();
+        if fs.first().copied() != Some(1)
+            || fs.get(1).copied() != Some(num_tokens as i32)
+            || fs.get(2).copied() != Some(hidden)
+        {
+            return Err(Error::Model(format!(
+                "gemma4_unified audio: audio feature shape {fs:?} != [1, {num_tokens}, {hidden}]"
+            )));
+        }
+        let feats = feats.astype(embeds_dtype, device)?;
+
+        embeds = embeds.slice_update(
+            &feats,
+            &[0, first as i32, 0],
+            &[1, (first + num_tokens) as i32, hidden],
+            &[1, 1, 1],
+            device,
+        )?;
+    }
+
+    // Mask audio-token ids to 0 for per-layer-input gating (matches vision).
+    // The audio positions are exactly the ids equal to `audio_token_id`, so
+    // the mask is built by value and needs no indexing.
+    let masked: Vec<i32> = input_ids
+        .iter()
+        .map(|&t| if t == audio_token_id { 0 } else { t as i32 })
+        .collect();
+    let masked_arr = Array::from_i32_slice(&masked, &[seq as i32])?;
+
+    Ok((embeds, masked_arr))
+}
+
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+#[path = "unified_tests.rs"]
+mod unified_tests;
diff --git a/crates/rmlx-models/src/gemma4/audio/unified_tests.rs b/crates/rmlx-models/src/gemma4/audio/unified_tests.rs
new file mode 100644
index 0000000..ad076a2
--- /dev/null
+++ b/crates/rmlx-models/src/gemma4/audio/unified_tests.rs
@@ -0,0 +1,180 @@
+//! Model-free unit coverage for the gemma4_unified audio embedder front-end:
+//! config parse, soft-token math, and the waveform-frame chunking/padding
+//! plumbing (frame count, padding, layout).
+
+use super::*;
+
+fn cfg_12b() -> UnifiedAudioConfig {
+    // Verified gemma-4-12B `gemma4_unified_audio` values.
+    UnifiedAudioConfig {
+        audio_embed_dim: 640,
+        audio_samples_per_token: 640,
+        output_proj_dims: 640,
+        rms_norm_eps: 1e-6,
+        audio_token_id: 258_881,
+    }
+}
+
+#[test]
+#[allow(
+    clippy::unwrap_used,
+    reason = "test asserts JSON parse succeeds on a known-good literal"
+)]
+fn config_parse_matches_snapshot() {
+    let v: serde_json::Value = serde_json::from_str(
+        r#"{
+            "model_type": "gemma4_unified_audio",
+            "audio_embed_dim": 640, "audio_samples_per_token": 640,
+            "output_proj_dims": 640, "rms_norm_eps": 1e-06
+        }"#,
+    )
+    .unwrap();
+    let cfg = UnifiedAudioConfig::from_json(&v, 258_881);
+    assert_eq!(cfg.audio_embed_dim, 640);
+    assert_eq!(cfg.audio_samples_per_token, 640);
+    assert_eq!(cfg.output_proj_dims, 640);
+    assert!((cfg.rms_norm_eps - 1e-6).abs() < 1e-9);
+    assert_eq!(cfg.audio_token_id, 258_881);
+}
+
+#[test]
+fn config_parse_reads_non_default_values() {
+    // The snapshot values coincide with the fallbacks, so the test above
+    // cannot tell "parsed" from "defaulted". Use values that differ from
+    // every default to prove the keys are actually read.
+    let v = serde_json::json!({
+        "audio_embed_dim": 320,
+        "audio_samples_per_token": 160,
+        "output_proj_dims": 80,
+        "rms_norm_eps": 1e-5
+    });
+    let cfg = UnifiedAudioConfig::from_json(&v, 7);
+    assert_eq!(cfg.audio_embed_dim, 320);
+    assert_eq!(cfg.audio_samples_per_token, 160);
+    assert_eq!(cfg.output_proj_dims, 80);
+    assert!((cfg.rms_norm_eps - 1e-5).abs() < 1e-9);
+    assert_eq!(cfg.audio_token_id, 7);
+}
+
+#[test]
+fn config_defaults_when_keys_missing() {
+    // An empty object built directly: no parse step that could fail silently.
+    let v = serde_json::json!({});
+    let cfg = UnifiedAudioConfig::from_json(&v, 1);
+    assert_eq!(cfg.audio_embed_dim, 640);
+    assert_eq!(cfg.audio_samples_per_token, 640);
+    assert_eq!(cfg.output_proj_dims, 640);
+    assert!((cfg.rms_norm_eps - 1e-6).abs() < 1e-9);
+    assert_eq!(cfg.audio_token_id, 1);
+}
+
+#[test]
+fn soft_token_count_is_ceil_div_640() {
+    let cfg = cfg_12b();
+    // Exactly one frame.
+    assert_eq!(unified_num_audio_soft_tokens(640, &cfg), 1);
+    // One sample over → two frames (ceil).
+    assert_eq!(unified_num_audio_soft_tokens(641, &cfg), 2);
+    // One short of a frame → still one frame (ceil rounds up).
+    assert_eq!(unified_num_audio_soft_tokens(639, &cfg), 1);
+    // 16000 samples (1 s @ 16 kHz) → ceil(16000/640) = 25 frames.
+    assert_eq!(unified_num_audio_soft_tokens(16_000, &cfg), 25);
+}
+
+#[test]
+fn soft_token_count_zero_for_empty_clip() {
+    let cfg = cfg_12b();
+    assert_eq!(unified_num_audio_soft_tokens(0, &cfg), 0);
+}
+
+#[test]
+#[allow(
+    clippy::indexing_slicing,
+    reason = "test indices/slices bounded by the asserted frame length (700 < 1280)"
+)]
+fn extract_frames_pads_tail_to_full_frame() {
+    // 700 samples → 2 frames (640 + 60), tail zero-padded to 1280.
+    let samples: Vec<f32> = (0..700).map(|i| i as f32).collect();
+    let (frames, num_tokens) = extract_waveform_frames(&samples, 640);
+    assert_eq!(num_tokens, 2);
+    assert_eq!(frames.len(), 2 * 640);
+    // The original samples are preserved verbatim (no scaling/normalization).
+    for (i, &v) in samples.iter().enumerate() {
+        assert!((frames[i] - v).abs() < f32::EPSILON, "sample {i} mismatch");
+    }
+    // The padded tail (700..1280) is zero.
+    for &v in &frames[700..] {
+        assert!(v.abs() < f32::EPSILON, "tail padding not zero");
+    }
+}
+
+#[test]
+fn extract_frames_exact_multiple_no_extra_frame() {
+    // 1280 = 2 * 640 exactly → 2 frames, no padding.
+    let samples = vec![1.0_f32; 1280];
+    let (frames, num_tokens) = extract_waveform_frames(&samples, 640);
+    assert_eq!(num_tokens, 2);
+    assert_eq!(frames.len(), 1280);
+    assert!(frames.iter().all(|&v| (v - 1.0).abs() < f32::EPSILON));
+}
+
+#[test]
+fn extract_frames_empty_clip_is_zero_frames() {
+    let (frames, num_tokens) = extract_waveform_frames(&[], 640);
+    assert_eq!(num_tokens, 0);
+    assert!(frames.is_empty());
+}
+
+#[test]
+fn frame_count_matches_soft_token_count() {
+    // The host front-end frame count must equal the prompt-block soft-token
+    // count for the scatter to align.
+    let cfg = cfg_12b();
+    for n in [0usize, 1, 639, 640, 641, 1500, 16_000, 48_000] {
+        let samples = vec![0.5_f32; n];
+        let (_frames, num_tokens) = extract_waveform_frames(&samples, cfg.audio_samples_per_token);
+        assert_eq!(
+            num_tokens,
+            unified_num_audio_soft_tokens(n, &cfg),
+            "frame/soft-token mismatch at n={n}"
+        );
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Model-gated integration: load the real 12B `embed_audio` projection and run
+// a forward over synthetic frames, asserting the [1, num_tokens, 3840] output
+// shape. Skips gracefully when the snapshot is not available.
+// ---------------------------------------------------------------------------
+
+/// Resolve the unified 12B snapshot dir via
+/// `$RMLX_O_MODELS_ROOT/mlx-community__gemma-4-12B-it-mxfp8`.
+/// Returns `None` when the root is unset or the directory is absent (test skips).
+fn unified_12b_dir() -> Option<std::path::PathBuf> {
+    let root = std::env::var_os("RMLX_O_MODELS_ROOT")?;
+    let p = std::path::PathBuf::from(root).join("mlx-community__gemma-4-12B-it-mxfp8");
+    p.exists().then_some(p)
+}
+
+#[test]
+#[allow(
+    clippy::expect_used,
+    reason = "model-gated integration test: .expect() documents the load/forward invariant; failure here is a genuine test failure"
+)]
+fn unified_audio_embedder_forward_real_weights() {
+    let Some(dir) = unified_12b_dir() else {
+        eprintln!("SKIP: gemma-4-12B snapshot not available (set RMLX_O_MODELS_ROOT to the model root dir)");
+        return;
+    };
+    let cfg = match UnifiedAudioConfig::from_model_dir(&dir) {
+        Ok(Some(c)) => c,
+        Ok(None) => {
+            eprintln!("SKIP: no audio_config in {}", dir.display());
+            return;
+        }
+        Err(e) => {
+            eprintln!("SKIP: config read failed: {e}");
+            return;
+        }
+    };
+    let embedder = match load_unified_audio_embedder(&dir, &cfg) {
+        Ok(e) => e,
+        Err(e) => {
+            eprintln!("SKIP: embed_audio load failed: {e}");
+            return;
+        }
+    };
+
+    // 3 frames of synthetic raw waveform (1920 samples).
+    let num_tokens = 3usize;
+    // Size the clip from the parsed config so the frame count stays at
+    // `num_tokens` regardless of the snapshot's frame width.
+    let samples = vec![0.01_f32; num_tokens * cfg.audio_samples_per_token];
+    let (frames, n) = extract_waveform_frames(&samples, cfg.audio_samples_per_token);
+    assert_eq!(n, num_tokens);
+
+    let out = embedder
+        .forward(&frames, num_tokens, Device::Cpu)
+        .expect("unified audio forward");
+    out.eval().expect("eval");
+    let shp = out.shape();
+    assert_eq!(shp.first().copied(), Some(1), "batch dim");
+    assert_eq!(shp.get(1).copied(), Some(num_tokens as i32), "token dim");
+    // Projected to text hidden (3840 on the 12B).
+    assert_eq!(shp.get(2).copied(), Some(3840), "hidden dim");
+}
diff --git a/crates/rmlx-models/src/gemma4/mod.rs b/crates/rmlx-models/src/gemma4/mod.rs
index 2684292..1e3f191 100644
--- a/crates/rmlx-models/src/gemma4/mod.rs
+++ b/crates/rmlx-models/src/gemma4/mod.rs
@@ -37,6 +37,10 @@ pub(crate) mod prompt_cache;
 mod tests;
 mod vision;
 
+pub use audio::unified::{
+    build_unified_audio_inputs_embeds, extract_waveform_frames, load_unified_audio_embedder,
+    unified_num_audio_soft_tokens, UnifiedAudioConfig, UnifiedAudioEmbedder,
+};
 pub use audio::{build_audio_inputs_embeds, load_audio_tower, AudioEncoder};
 pub use audio_feature_extractor::{
     AudioFeatError, AudioFeatureExtractorConfig, Gemma4AudioFeatureExtractor,
@@ -52,6 +56,11 @@ pub use preprocessor::{
 };
 pub use prompt_cache::read_cache_stats as gemma4_cache_stats;
 pub use prompt_cache::read_kv_cache_bytes as gemma4_kv_cache_bytes;
+pub use vision::unified::{
+    build_unified_inputs_embeds, is_unified_arch, load_unified_vision_embedder,
+    unified_image_processor_config, unified_num_soft_tokens, UnifiedVisionConfig,
+    UnifiedVisionEmbedder,
+};
 pub use vision::{
     build_inputs_embeds, load_multimodal_embedder, load_vision_tower, MultimodalEmbedder,
     VisionModel, IMAGE_TOKEN_ID,
diff --git a/crates/rmlx-models/src/gemma4/vision/mod.rs b/crates/rmlx-models/src/gemma4/vision/mod.rs
index 281e1e1..296a98d 100644
--- a/crates/rmlx-models/src/gemma4/vision/mod.rs
+++ b/crates/rmlx-models/src/gemma4/vision/mod.rs
@@ -307,6 +307,31 @@ impl MultimodalEmbedder {
         let normed = rms_norm(inputs_embeds, None, self.norm_eps, device)?;
         self.projection.forward(&normed, device)
     }
+
+    /// Input feature dim of `embedding_projection` — the unpacked `in_features`
+    /// the projection consumes (e.g. `mm_embed_dim` 3840 for vision,
+    /// `audio_embed_dim` 640 for unified audio). Used to validate the parsed
+    /// `output_proj_dims` config against the actual checkpoint weight so a
+    /// mismatched config is rejected at load instead of producing a silent
+    /// shape error deep in the forward pass.
+    ///
+    /// For a quantized projection the unpacked input dim is
+    /// `scales.shape()[1] * group_size`; for a plain projection it is the
+    /// weight's second axis (`[out, in]`). Returns `None` if the shapes are
+    /// degenerate — missing second axis, a negative dim, or an overflowing
+    /// product — which the caller treats as "cannot validate".
+    #[must_use]
+    pub fn projection_input_dim(&self) -> Option<usize> {
+        match &self.projection {
+            crate::layers::Linear::Plain { weight } => {
+                // Checked conversion: a negative dim must not wrap to a huge usize.
+                weight.shape().get(1).and_then(|&d| usize::try_from(d).ok())
+            }
+            crate::layers::Linear::Quantized {
+                scales, group_size, ..
+            } => {
+                let groups = scales
+                    .shape()
+                    .get(1)
+                    .and_then(|&g| usize::try_from(g).ok())?;
+                groups.checked_mul(usize::try_from(*group_size).ok()?)
+            }
+            // PARO projections never appear on the multimodal embedder path.
+            crate::layers::Linear::Paro { .. } => None,
+        }
+    }
 }
 
 // ---------------------------------------------------------------------------
@@ -532,13 +557,13 @@ fn zero_mask(seq: usize) -> Result<Array> {
 }
 
 #[inline]
-fn f32_bytes(v: &[f32]) -> &[u8] {
+pub(super) fn f32_bytes(v: &[f32]) -> &[u8] {
     // SAFETY: f32 is 4 bytes; from_bytes copies immediately.
     unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), size_of_val(v)) }
 }
 
 #[inline]
-fn i32_bytes(v: &[i32]) -> &[u8] {
+pub(super) fn i32_bytes(v: &[i32]) -> &[u8] {
     // SAFETY: i32 is 4 bytes; from_bytes copies immediately.
     unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), size_of_val(v)) }
 }
@@ -884,7 +909,7 @@ pub fn build_inputs_embeds(
 }
 
 /// Load a packed quantized weight without dtype conversion (keep U32/U8/F16).
-fn load_raw(shards: &ShardSet, name: &str) -> Result<Array> {
+pub(super) fn load_raw(shards: &ShardSet, name: &str) -> Result<Array> {
     for (_, handle) in shards.iter() {
         let st = handle.safetensors()?;
         if let Ok(t) = st.tensor(name) {
@@ -904,7 +929,7 @@ fn load_raw(shards: &ShardSet, name: &str) -> Result<Array> {
 
 /// Read top-level `quantization` (`group_size`, `bits`, `mode`) for the
 /// `embed_vision.embedding_projection` quantized Linear.
-fn read_quant_params(model_dir: &Path) -> Result<(i32, i32, crate::layers::QuantMode)> {
+pub(super) fn read_quant_params(model_dir: &Path) -> Result<(i32, i32, crate::layers::QuantMode)> {
     let v = crate::load_util::read_raw_config(model_dir)?;
     let q = v.get("quantization");
     let gs = q
@@ -922,6 +947,12 @@ fn read_quant_params(model_dir: &Path) -> Result<(i32, i32, crate::layers::Quant
     Ok((gs, bits, crate::layers::QuantMode::from(mode_str)))
 }
 
+// ---------------------------------------------------------------------------
+// Unified (encoder-free) vision embedder — `Gemma4UnifiedForConditionalGeneration`.
+// ---------------------------------------------------------------------------
+
+pub(crate) mod unified;
+
 // ---------------------------------------------------------------------------
 
 #[cfg(test)]
diff --git a/crates/rmlx-models/src/gemma4/vision/unified.rs b/crates/rmlx-models/src/gemma4/vision/unified.rs
new file mode 100644
index 0000000..28fcc11
--- /dev/null
+++ b/crates/rmlx-models/src/gemma4/vision/unified.rs
@@ -0,0 +1,650 @@
+//! Gemma4 **unified** (`gemma4_unified` / `Gemma4UnifiedForConditionalGeneration`)
+//! encoder-free vision embedder.
+//!
+//! Faithful host+MLX port of the HF Transformers `gemma4_unified`
+//! `Gemma4UnifiedVisionEmbedder` (+ image processor `patches_merge` /
+//! `convert_image_to_patches`) and `Gemma4UnifiedMultimodalEmbedder`.
+//!
+//! ## Why this is a different path than the SigLIP tower
+//!
+//! The unified 12B has **no vision transformer**. Vision is early-fusion: raw
+//! pixel patches are projected straight into the shared 48-layer LM hidden
+//! space via a single Dense + LayerNorms + factorized 2D positional embedding,
+//! producing `num_soft_tokens` (280) soft tokens. The weights are
+//! `vision_embedder.*` + `embed_vision.embedding_projection.*` — there is no
+//! `vision_tower.*`. The standard `gemma4` family (e4b/26b/31b) keeps the
+//! existing [`super::VisionModel`] SigLIP path.
+//!
+//! ## Reference pipeline (verbatim from `Gemma4UnifiedVisionEmbedder.forward`)
+//!
+//! ```text
+//! hidden = patch_ln1(pixel_values)   # LayerNorm over 6912 raw-patch dims
+//! hidden = patch_dense(hidden)       # Linear 6912 -> mm_embed_dim (3840), +bias
+//! hidden = patch_ln2(hidden)         # LayerNorm over 3840
+//! pos    = pos_embedding[x, 0, :] + pos_embedding[y, 1, :]  # factorized 2D, padding -> 0
+//! hidden = pos_norm(hidden + pos)    # LayerNorm over 3840
+//! hidden = embed_vision(hidden)      # RMSNormNoScale -> embedding_projection (3840 -> text_hidden)
+//! ```
+//!
+//! Note: `patch_ln1`/`patch_ln2`/`pos_norm` are true **LayerNorm**
+//! (mean-subtraction, learned weight **and** bias) — NOT RMSNorm. `embed_vision`
+//! is the same `RMSNormNoScale -> Linear` [`super::MultimodalEmbedder`] reused
+//! from the tower path.
+//!
+//! ## Image processing (host)
+//!
+//! Reference `Gemma4UnifiedImageProcessor`:
+//! 1. aspect-ratio-preserving resize to a patch budget (reused from
+//!    [`super::super::preprocessor`]; identical for both Gemma4 families).
+//! 2. rescale to `[0,1]` (`do_normalize=false`, mean=0/std=1 on the snapshot).
+//! 3. patchify into 16px teacher patches `[n_teacher, 16*16*3=768]`, layout
+//!    `[patch_h, patch_w, channel]` (`convert_image_to_patches`).
+//! 4. `patches_merge`: group `k=pooling_kernel_size` (3) × k teacher patches
+//!    into one 48×48 model patch `[n_model, 48*48*3=6912]`; model-patch position
+//!    = `(min teacher_x // k, min teacher_y // k)`.
+//! 5. The model patch interior is laid out `[k*16, k*16, 3]` (kernel rows ×
+//!    kernel cols × teacher-patch pixels), matching the reference reshape
+//!    `(length, k, k, 16, 16, 3) -> permute -> (length, 48, 48, 3)`.
+//!
+//! Step 1 here introduces the `2*(x-0.5)` rescale used by the tower's patch
+//! embedder. The unified embedder does **not** apply that; it consumes raw
+//! `[0,1]` pixels (patch_ln1 normalizes). So this path uses its own patchify.
+
+use std::path::Path;
+
+use rmlx_core::error::{Error, Result};
+use rmlx_loader::{load_shard_index, ShardSet};
+use rmlx_mlx::{add, multiply, scalar_f32, subtract, sum_axis_keepdims, Array, Device, Dtype};
+use tracing::{debug, info};
+
+use super::{f32_bytes, i32_bytes, load_raw, read_quant_params, MultimodalEmbedder};
+use crate::gemma4::preprocessor::{Gemma4ImageProcessorConfig, Gemma4PixelValues};
+
+// ---------------------------------------------------------------------------
+// Config
+// ---------------------------------------------------------------------------
+
+/// `vision_config` for `gemma4_unified` (`model_type: gemma4_unified_vision`).
+///
+/// Distinct from [`super::super::config::Gemma4VisionConfig`] (the SigLIP
+/// tower): the unified vision config has **no** `num_hidden_layers` / heads —
+/// it is encoder-free. Fields are the complete contract for the embedder.
+#[allow(
+    clippy::exhaustive_structs,
+    reason = "internal closed config struct — fields are the complete gemma4_unified vision-embedder contract; adding a field requires updating from_json"
+)]
+#[derive(Debug, Clone)]
+pub struct UnifiedVisionConfig {
+    /// Multimodal embed dim = LM hidden (3840). `patch_dense` out, LayerNorm dim.
+    pub mm_embed_dim: usize,
+    /// Factorized positional-embedding table length per axis (1120).
+    pub mm_posemb_size: usize,
+    /// Model patch side after pooling (48). `patch_dim = model_patch_size^2 * 3`.
+    pub model_patch_size: usize,
+    /// Teacher patch side (16).
+ pub patch_size: usize, + /// Pooling kernel (`k`, 3) merging `k*k` teacher patches into one model patch. + pub pooling_kernel_size: usize, + /// Soft tokens produced per image (280). Padding budget. + pub num_soft_tokens: usize, + /// `embed_vision.embedding_projection` input dim (3840 = output_proj_dims). + pub output_proj_dims: usize, + /// LayerNorm / RMSNorm epsilon (1e-6). + pub rms_norm_eps: f32, +} + +impl UnifiedVisionConfig { + /// `patch_dim = model_patch_size^2 * 3` — the raw-patch feature length + /// (`48*48*3 = 6912`) fed into `patch_ln1` / `patch_dense`. + #[inline] + pub fn patch_dim(&self) -> usize { + self.model_patch_size * self.model_patch_size * 3 + } + + /// Parse from the `vision_config` JSON object. Missing keys fall back to the + /// verified `gemma-4-12B` `gemma4_unified_vision` values. + pub fn from_json(v: &serde_json::Value) -> Self { + let u = |key: &str, dflt: usize| -> usize { + v.get(key) + .and_then(serde_json::Value::as_u64) + .map_or(dflt, |x| x as usize) + }; + let f = |key: &str, dflt: f32| -> f32 { + v.get(key) + .and_then(serde_json::Value::as_f64) + .map_or(dflt, |x| x as f32) + }; + Self { + mm_embed_dim: u("mm_embed_dim", 3840), + mm_posemb_size: u("mm_posemb_size", 1120), + model_patch_size: u("model_patch_size", 48), + patch_size: u("patch_size", 16), + pooling_kernel_size: u("pooling_kernel_size", 3), + num_soft_tokens: u("num_soft_tokens", 280), + output_proj_dims: u("output_proj_dims", 3840), + rms_norm_eps: f("rms_norm_eps", 1e-6), + } + } + + /// Read `vision_config` from a model dir when `architectures[0]` is the + /// unified arch. Returns `None` if there is no `vision_config` key. 
+ pub fn from_model_dir(model_dir: &Path) -> Result<Option<Self>> { + let path = model_dir.join("config.json"); + let data = std::fs::read(&path).map_err(|e| { + Error::Config(format!( + "gemma4_unified: cannot read {}: {e}", + path.display() + )) + })?; + let v: serde_json::Value = serde_json::from_slice(&data).map_err(|e| { + Error::Config(format!( + "gemma4_unified: malformed config.json at {}: {e}", + path.display() + )) + })?; + Ok(v.get("vision_config").map(Self::from_json)) + } +} + +/// Returns true when `architectures[0] == "Gemma4UnifiedForConditionalGeneration"`. +/// +/// The unified 12B loads through the same [`crate::arch::Architecture::Gemma4`] +/// text path as the tower family; this string check is the dispatch that routes +/// the vision/audio front-end to the encoder-free embedder instead of the +/// SigLIP `vision_tower` loader. +pub fn is_unified_arch(model_dir: &Path) -> bool { + let path = model_dir.join("config.json"); + let Ok(data) = std::fs::read(&path) else { + return false; + }; + let Ok(v) = serde_json::from_slice::<serde_json::Value>(&data) else { + return false; + }; + v.get("architectures") + .and_then(serde_json::Value::as_array) + .and_then(|a| a.first()) + .and_then(serde_json::Value::as_str) + == Some("Gemma4UnifiedForConditionalGeneration") +} + +// --------------------------------------------------------------------------- +// LayerNorm (weight + bias, mean-subtraction) — composed from primitives +// --------------------------------------------------------------------------- + +/// Standard LayerNorm over the last axis: `(x - mean) / sqrt(var + eps) * w + b`. +/// +/// Distinct from [`crate::layers::RmsNorm`] (no mean-subtraction). The unified +/// embedder's `patch_ln1` / `patch_ln2` / `pos_norm` are PyTorch `nn.LayerNorm` +/// (the upstream class names say "ln", and the weights carry both `.weight` and +/// `.bias`). Computed in f32 for stability — the whole embedder runs once per +/// image, off the decode loop. 
+#[allow(missing_debug_implementations)] +struct LayerNorm { + weight: Array, + bias: Array, + eps: f32, +} + +impl LayerNorm { + /// `x`: `[..., dim]`. Returns same shape. + fn forward(&self, x: &Array, device: Device) -> Result<Array> { + let x = x.astype(Dtype::F32, device)?; + let shape = x.shape(); + if shape.is_empty() { + return Err(Error::Model( + "gemma4_unified LayerNorm: rank-0 input has no last axis to normalize".to_owned(), + )); + } + let axis = (shape.len() as i32) - 1; + // SAFETY: shape non-empty (checked above) → last() is Some. + let dim = *shape.last().unwrap_or(&1) as f32; + // mean over last axis (keepdims for broadcast). + let sum = sum_axis_keepdims(&x, axis, device)?; + let mean = multiply(&sum, &scalar_f32(1.0 / dim), device)?; + let centered = subtract(&x, &mean, device)?; + // var = mean(centered^2) + let sq = multiply(&centered, &centered, device)?; + let sq_sum = sum_axis_keepdims(&sq, axis, device)?; + let var = multiply(&sq_sum, &scalar_f32(1.0 / dim), device)?; + let denom = rmlx_mlx::sqrt(&add(&var, &scalar_f32(self.eps), device)?, device)?; + let normed = rmlx_mlx::divide(&centered, &denom, device)?; + let scaled = multiply(&normed, &self.weight, device)?; + add(&scaled, &self.bias, device) + } +} + +// --------------------------------------------------------------------------- +// Unified vision embedder +// --------------------------------------------------------------------------- + +/// Encoder-free unified vision embedder (`vision_embedder.*` + `embed_vision.*`). +#[allow(missing_debug_implementations)] +pub struct UnifiedVisionEmbedder { + cfg: UnifiedVisionConfig, + patch_ln1: LayerNorm, + /// `patch_dense`: `[mm_embed_dim, patch_dim]` (possibly quantized) + bias. + patch_dense: crate::layers::Linear, + patch_dense_bias: Array, + patch_ln2: LayerNorm, + /// `[mm_posemb_size, 2, mm_embed_dim]` factorized 2D table (f32). + pos_embedding: Array, + pos_norm: LayerNorm, + /// Shared `RMSNormNoScale -> embedding_projection` (`embed_vision.*`). 
+ embed_vision: MultimodalEmbedder, +} + +impl UnifiedVisionEmbedder { + /// Parsed unified-vision sub-config. + pub fn config(&self) -> &UnifiedVisionConfig { + &self.cfg + } + + /// Run the full embedder over one preprocessed image. Returns + /// `[1, num_soft_tokens, text_hidden]` ready to scatter into `inputs_embeds`. + /// + /// `pv` is the resized/rescaled `[1, 3, H, W]` CHW buffer from the shared + /// Gemma4 preprocessor; `H`/`W` are multiples of `model_patch_size`. + #[allow( + clippy::indexing_slicing, + reason = "bounds established by construction: patch loop indices bounded by host-computed patch grid" + )] + pub fn forward(&self, pv: &Gemma4PixelValues, device: Device) -> Result { + let (patches, x_idx, y_idx) = self.patchify_and_merge(pv)?; + let n_model = x_idx.len(); + let patch_dim = self.cfg.patch_dim() as i32; + let mm = self.cfg.mm_embed_dim as i32; + + // Raw merged patches -> device [n_model, patch_dim]. + let patch_arr = Array::from_bytes( + f32_bytes(&patches), + &[n_model as i32, patch_dim], + Dtype::F32, + )?; + + // patch_ln1 -> patch_dense (+bias) -> patch_ln2 + let mut h = self.patch_ln1.forward(&patch_arr, device)?; + h = self.patch_dense.forward(&h, device)?; + h = add(&h, &self.patch_dense_bias, device)?; + h = self.patch_ln2.forward(&h, device)?; + + // Factorized 2D positional embedding (host gather of two axis rows). + let pos = self.gather_pos_embedding(&x_idx, &y_idx, device)?; // [n_model, mm] + h = add(&h, &pos, device)?; + h = self.pos_norm.forward(&h, device)?; + + // embed_vision: RMSNormNoScale -> embedding_projection -> [n_model, text_hidden]. 
+ h = h.reshape(&[1, n_model as i32, mm], device)?; + let out = self.embed_vision.forward(&h, device)?; + + debug!( + n_model_patches = n_model, + mm_embed_dim = self.cfg.mm_embed_dim, + "gemma4_unified vision: embedder forward" + ); + Ok(out) + } + + /// Host patchify (16px teacher patches) + `patches_merge` (k×k -> model + /// patch) producing flat `[n_model * patch_dim]` f32 plus per-model-patch + /// `(x, y)` positions. + /// + /// Faithful to `convert_image_to_patches` (layout `[patch_h, patch_w, ch]`) + /// and `patches_merge` (model-patch interior `[k, k, 16, 16, 3]` -> + /// `[48, 48, 3]`, position = top-left teacher position // k). + #[allow( + clippy::indexing_slicing, + reason = "bounds established by construction: all indices derived from the host-computed patch grid" + )] + fn patchify_and_merge(&self, pv: &Gemma4PixelValues) -> Result<(Vec, Vec, Vec)> { + let p = self.cfg.patch_size; // 16 + let k = self.cfg.pooling_kernel_size; // 3 + let h = pv.height; + let w = pv.width; + if !h.is_multiple_of(p * k) || !w.is_multiple_of(p * k) { + return Err(Error::Model(format!( + "gemma4_unified vision: image {h}x{w} not divisible by model_patch_size {}", + p * k + ))); + } + let p_h = h / p; // teacher rows + let p_w = w / p; // teacher cols + let m_h = p_h / k; // model rows + let m_w = p_w / k; // model cols + let n_model = m_h * m_w; + let model_patch = p * k; // 48 + let patch_dim = model_patch * model_patch * 3; // 6912 + let n_pixels = h * w; + + // Build merged patches directly in the reference target layout. The + // upstream `patches_merge` reshapes the k×k kernel group to + // `(length, k, k, p, p, 3)` then permutes to `(length, k, p, k, p, 3)` + // and flattens — i.e. the 6912-vector interior is ordered + // **`[ky, ry, kx, rx, ch]`**. That makes the model patch a *contiguous* + // (k*p)×(k*p) image: full row = `ky*p + ry`, full col = `kx*p + rx`. 
+ // (Ordering `[ky, kx, ry, rx, ch]` would tile 3×3 blocks instead and + // scramble fine detail — OCR fails, color survives.) + let mut merged = vec![0.0_f32; n_model * patch_dim]; + let mut x_idx = vec![0i32; n_model]; + let mut y_idx = vec![0i32; n_model]; + for my in 0..m_h { + for mx in 0..m_w { + let model_i = my * m_w + mx; + // model-patch position = (min teacher_x // k, min teacher_y // k) + // = (mx, my) since teacher cols/rows in a kernel are contiguous. + x_idx[model_i] = mx as i32; + y_idx[model_i] = my as i32; + let dst = model_i * patch_dim; + for ky in 0..k { + for ry in 0..p { + for kx in 0..k { + for rx in 0..p { + let y = (my * k + ky) * p + ry; // my*48 + ky*16 + ry + let x = (mx * k + kx) * p + rx; // mx*48 + kx*16 + rx + for ch in 0..3 { + let src = ch * n_pixels + y * w + x; + // interior index: [ky, ry, kx, rx, ch] over dims [k, p, k, p, 3] + let off = ((((ky * p + ry) * k + kx) * p + rx) * 3) + ch; + merged[dst + off] = pv.pixel_values[src]; + } + } + } + } + } + } + } + Ok((merged, x_idx, y_idx)) + } + + /// Gather `pos_embedding[x, 0, :] + pos_embedding[y, 1, :]` per model patch. + /// + /// `pos_embedding` is `[mm_posemb_size, 2, mm_embed_dim]`; axis-0 slice is + /// the X table, axis-1 slice is the Y table. Positions are in-range by + /// construction (no padding patches on the single-image path). + fn gather_pos_embedding(&self, x_idx: &[i32], y_idx: &[i32], device: Device) -> Result { + let n = x_idx.len() as i32; + let posemb = self.cfg.mm_posemb_size as i32; + let mm = self.cfg.mm_embed_dim as i32; + // table_x = pos_embedding[:, 0, :], table_y = pos_embedding[:, 1, :]. + let table_x = self + .pos_embedding + .slice(&[0, 0, 0], &[posemb, 1, mm], &[1, 1, 1], device)? + .reshape(&[posemb, mm], device)?; + let table_y = self + .pos_embedding + .slice(&[0, 1, 0], &[posemb, 2, mm], &[1, 1, 1], device)? 
+ .reshape(&[posemb, mm], device)?; + let x_arr = Array::from_bytes(i32_bytes(x_idx), &[n], Dtype::I32)?; + let y_arr = Array::from_bytes(i32_bytes(y_idx), &[n], Dtype::I32)?; + let pe_x = table_x.take(&x_arr, 0, device)?; + let pe_y = table_y.take(&y_arr, 0, device)?; + add(&pe_x, &pe_y, device) + } +} + +// --------------------------------------------------------------------------- +// Loader +// --------------------------------------------------------------------------- + +/// Load the unified vision embedder (`vision_embedder.*` + `embed_vision.*`) +/// from a snapshot directory. Errors if the unified-vision weights are absent +/// (caller disables image input on error). +pub fn load_unified_vision_embedder( + model_dir: &Path, + cfg: &UnifiedVisionConfig, +) -> Result { + let idx = load_shard_index(model_dir)?; + let shards = ShardSet::open(model_dir, &idx)?; + + fn load_f32(shards: &ShardSet, name: &str) -> Result { + for (_, handle) in shards.iter() { + let st = handle.safetensors()?; + if let Ok(t) = st.tensor(name) { + let tv = rmlx_loader::TensorView { + name, + dtype: t.dtype(), + shape: t.shape().to_vec(), + bytes: t.data(), + }; + let a = Array::from_safetensor_view(&tv)?; + return a.astype(Dtype::F32, Device::Cpu); + } + } + Err(Error::Loader(format!( + "gemma4_unified vision: tensor '{name}' not found in any shard" + ))) + } + let has = |name: &str| -> bool { + shards + .iter() + .any(|(_, h)| h.safetensors().is_ok_and(|st| st.tensor(name).is_ok())) + }; + + let layer_norm = |prefix: &str| -> Result { + Ok(LayerNorm { + weight: load_f32(&shards, &format!("{prefix}.weight"))?, + bias: load_f32(&shards, &format!("{prefix}.bias"))?, + eps: cfg.rms_norm_eps, + }) + }; + + // patch_dense: quantized linear (mxfp8 on the snapshot) + additive bias. 
+ let dense_base = "vision_embedder.patch_dense"; + let patch_dense = if has(&format!("{dense_base}.scales")) { + let weight = load_raw(&shards, &format!("{dense_base}.weight"))?; + let scales = load_raw(&shards, &format!("{dense_base}.scales"))?; + let biases = if has(&format!("{dense_base}.biases")) { + Some(load_raw(&shards, &format!("{dense_base}.biases"))?) + } else { + None + }; + let (gs, bits, mode) = read_quant_params(model_dir)?; + crate::layers::Linear::Quantized { + weight, + scales, + biases, + group_size: gs, + bits, + mode, + } + } else { + crate::layers::Linear::Plain { + weight: load_f32(&shards, &format!("{dense_base}.weight"))?, + } + }; + let patch_dense_bias = load_f32(&shards, &format!("{dense_base}.bias"))?; + + let pos_embedding = load_f32(&shards, "vision_embedder.pos_embedding")?; + + // embed_vision: RMSNormNoScale -> embedding_projection (reused loader). + let embed_vision = + super::load_multimodal_embedder(model_dir, "embed_vision", cfg.rms_norm_eps)?; + + // Validate the parsed `output_proj_dims` against the loaded projection's + // actual input feature dim. This makes the config field load-bearing: a + // checkpoint whose `embed_vision.embedding_projection` does not consume + // `output_proj_dims` features is rejected here instead of failing with an + // opaque shape error inside the forward pass. 
+ if let Some(proj_in) = embed_vision.projection_input_dim() { + if proj_in != cfg.output_proj_dims { + return Err(Error::Loader(format!( + "gemma4_unified vision: output_proj_dims ({}) != embed_vision.embedding_projection \ + input dim ({proj_in}) — config/checkpoint mismatch", + cfg.output_proj_dims + ))); + } + } + + info!( + mm_embed_dim = cfg.mm_embed_dim, + num_soft_tokens = cfg.num_soft_tokens, + model_patch_size = cfg.model_patch_size, + "gemma4_unified vision: embedder loaded (encoder-free)" + ); + + Ok(UnifiedVisionEmbedder { + cfg: cfg.clone(), + patch_ln1: layer_norm("vision_embedder.patch_ln1")?, + patch_dense, + patch_dense_bias, + patch_ln2: layer_norm("vision_embedder.patch_ln2")?, + pos_embedding, + pos_norm: layer_norm("vision_embedder.pos_norm")?, + embed_vision, + }) +} + +// --------------------------------------------------------------------------- +// Image processor for the unified path +// --------------------------------------------------------------------------- + +/// Number of soft tokens an image of size `(h, w)` will consume after the +/// unified embedder's `patches_merge`: +/// `(h / model_patch_size) * (w / model_patch_size)`. +/// +/// This is the count used to size the image-token block in the prompt. It must +/// equal the model-patch count produced by [`UnifiedVisionEmbedder::forward`]. +#[inline] +pub fn unified_num_soft_tokens(h: usize, w: usize, cfg: &UnifiedVisionConfig) -> usize { + let mp = cfg.model_patch_size; + (h / mp) * (w / mp) +} + +/// Build the [`Gemma4ImageProcessorConfig`] for the unified path from the +/// unified vision config. The shared preprocessor (resize + rescale) is +/// identical to the tower path; only the post-resize patchify differs (done in +/// [`UnifiedVisionEmbedder::forward`]). The processor's reported +/// `num_soft_tokens` is corrected to the model-patch count by +/// [`unified_num_soft_tokens`] at prompt-build time. 
+pub fn unified_image_processor_config(cfg: &UnifiedVisionConfig) -> Gemma4ImageProcessorConfig { + let d = Gemma4ImageProcessorConfig::default(); + Gemma4ImageProcessorConfig { + patch_size: cfg.patch_size, + max_soft_tokens: cfg.num_soft_tokens, + pooling_kernel_size: cfg.pooling_kernel_size, + ..d + } +} + +// --------------------------------------------------------------------------- +// build unified inputs_embeds (text + scattered vision soft tokens) +// --------------------------------------------------------------------------- + +/// Build the merged `inputs_embeds` for a unified-arch image prompt. +/// +/// Mirrors [`super::build_inputs_embeds`] but routes the encode through the +/// encoder-free [`UnifiedVisionEmbedder`]. Each image contributes +/// `unified_num_soft_tokens` soft tokens scattered at its contiguous run of +/// [`super::IMAGE_TOKEN_ID`] positions. +/// +/// Returns `(inputs_embeds [1, seq, hidden], masked_ids [seq])`. +#[allow( + clippy::indexing_slicing, + reason = "bounds established by construction: img_positions are filtered from input_ids; per-image runs validated contiguous before slice_update" +)] +pub fn build_unified_inputs_embeds( + model: &super::super::model::Gemma4Text, + embedder: &UnifiedVisionEmbedder, + images: &[Gemma4PixelValues], + input_ids: &[u32], + device: Device, + mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, +) -> Result<(Array, Array)> { + let hidden = model.cfg.hidden_size as i32; + let seq = input_ids.len(); + + let img_positions: Vec = input_ids + .iter() + .enumerate() + .filter(|(_, &t)| t == super::IMAGE_TOKEN_ID) + .map(|(i, _)| i) + .collect(); + // Each image's soft-token count = model-patch count for its resized size. 
+ let per_image: Vec = images + .iter() + .map(|pv| unified_num_soft_tokens(pv.height, pv.width, embedder.config())) + .collect(); + let expected: usize = per_image.iter().sum(); + if img_positions.len() != expected { + return Err(Error::Model(format!( + "gemma4_unified image: {} image-token ({}) positions in prompt != \ + {expected} vision soft tokens ({} image(s)) — scatter would misalign", + img_positions.len(), + super::IMAGE_TOKEN_ID, + images.len() + ))); + } + info!( + image_tokens = img_positions.len(), + images = images.len(), + seq, + "gemma4_unified image: building inputs_embeds (token count == soft tokens)" + ); + + // Scaled text embeddings: embed_tokens(ids) * sqrt(hidden). + let ids_i32: Vec = input_ids.iter().map(|&x| x as i32).collect(); + let ids_arr = Array::from_bytes(i32_bytes(&ids_i32), &[seq as i32], Dtype::I32)?; + let h_raw = model.embed_tokens.forward(&ids_arr, device)?; + let embed_scale = scalar_f32((model.cfg.hidden_size as f32).sqrt()); + let mut embeds = multiply(&h_raw, &embed_scale, device)?; + embeds = embeds.reshape(&[1, seq as i32, hidden], device)?; + let embeds_dtype = embeds.dtype(); + + let mut cursor = 0usize; + for (img_idx, (pv, &n_soft)) in images.iter().zip(per_image.iter()).enumerate() { + let key_bytes = crate::multimodal_cache::pixel_f32_bytes(&pv.pixel_values); + let key = crate::multimodal_cache::MmCacheKey::image_key( + key_bytes, + u16::try_from(pv.height).unwrap_or(u16::MAX), + u16::try_from(pv.width).unwrap_or(u16::MAX), + 3, + crate::multimodal_cache::MmDtype::F32, + ); + let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || { + embedder.forward(pv, device) + })?; + let fs = feats.shape(); + if fs.first().copied() != Some(1) + || fs.get(1).copied() != Some(n_soft as i32) + || fs.get(2).copied() != Some(hidden) + { + return Err(Error::Model(format!( + "gemma4_unified image: vision feature shape {fs:?} != [1, {n_soft}, {hidden}] \ + for image {img_idx}" + ))); + } + let feats = 
feats.astype(embeds_dtype, device)?; + + let run = &img_positions[cursor..cursor + n_soft]; + let first = run[0]; + let contiguous = run.iter().enumerate().all(|(k, &p)| p == first + k); + if !contiguous { + return Err(Error::Model(format!( + "gemma4_unified image: image-token positions for image {img_idx} are not \ + contiguous (got {run:?})" + ))); + } + embeds = embeds.slice_update( + &feats, + &[0, first as i32, 0], + &[1, (first + n_soft) as i32, hidden], + &[1, 1, 1], + device, + )?; + cursor += n_soft; + } + + // Mask image-token ids to 0 for per-layer-input gating (matches the tower path). + let mut masked: Vec = ids_i32; + for &p in &img_positions { + masked[p] = 0; + } + let masked_arr = Array::from_bytes(i32_bytes(&masked), &[seq as i32], Dtype::I32)?; + + Ok((embeds, masked_arr)) +} + +// --------------------------------------------------------------------------- + +#[cfg(test)] +#[path = "unified_tests.rs"] +mod unified_tests; diff --git a/crates/rmlx-models/src/gemma4/vision/unified_tests.rs b/crates/rmlx-models/src/gemma4/vision/unified_tests.rs new file mode 100644 index 0000000..8feca42 --- /dev/null +++ b/crates/rmlx-models/src/gemma4/vision/unified_tests.rs @@ -0,0 +1,98 @@ +//! Model-free unit coverage for the gemma4_unified vision embedder front-end: +//! config parse, soft-token math, and the patchify+merge shape/position plumbing. + +use std::path::Path; + +use super::*; + +fn cfg_12b() -> UnifiedVisionConfig { + // Verified gemma-4-12B `gemma4_unified_vision` values. 
+ UnifiedVisionConfig { + mm_embed_dim: 3840, + mm_posemb_size: 1120, + model_patch_size: 48, + patch_size: 16, + pooling_kernel_size: 3, + num_soft_tokens: 280, + output_proj_dims: 3840, + rms_norm_eps: 1e-6, + } +} + +#[test] +#[allow( + clippy::unwrap_used, + reason = "test asserts JSON parse succeeds on a known-good literal" +)] +fn config_parse_matches_snapshot() { + let v: serde_json::Value = serde_json::from_str( + r#"{ + "mm_embed_dim": 3840, "mm_posemb_size": 1120, "model_patch_size": 48, + "patch_size": 16, "pooling_kernel_size": 3, "num_soft_tokens": 280, + "output_proj_dims": 3840, "rms_norm_eps": 1e-06 + }"#, + ) + .unwrap(); + let cfg = UnifiedVisionConfig::from_json(&v); + assert_eq!(cfg.mm_embed_dim, 3840); + assert_eq!(cfg.mm_posemb_size, 1120); + assert_eq!(cfg.model_patch_size, 48); + assert_eq!(cfg.patch_size, 16); + assert_eq!(cfg.pooling_kernel_size, 3); + assert_eq!(cfg.num_soft_tokens, 280); + assert_eq!(cfg.output_proj_dims, 3840); + // patch_dim = 48*48*3 = 6912 (input width of patch_dense / patch_ln1). + assert_eq!(cfg.patch_dim(), 6912); +} + +#[test] +fn patch_dim_is_model_patch_squared_times_three() { + let cfg = cfg_12b(); + assert_eq!(cfg.patch_dim(), 48 * 48 * 3); +} + +#[test] +fn soft_token_count_is_model_patch_grid() { + let cfg = cfg_12b(); + // A 144x96 image: 3x2 model patches (model_patch_size=48) -> 6 soft tokens. + assert_eq!(unified_num_soft_tokens(144, 96, &cfg), 6); + // Square 672x672 -> 14x14 = 196 model patches. + assert_eq!(unified_num_soft_tokens(672, 672, &cfg), 196); + // 768x720 -> 16x15 = 240. 
+ assert_eq!(unified_num_soft_tokens(768, 720, &cfg), 240); +} + +#[test] +fn image_processor_config_carries_unified_params() { + let cfg = cfg_12b(); + let pc = unified_image_processor_config(&cfg); + assert_eq!(pc.patch_size, 16); + assert_eq!(pc.max_soft_tokens, 280); + assert_eq!(pc.pooling_kernel_size, 3); +} + +/// Validate that the public soft-token count and the documented model-patch +/// grid agree for a non-square image, and that all position ids fall inside the +/// factorized positional-embedding table (`mm_posemb_size`). +#[test] +fn model_patch_grid_and_position_bounds() { + let cfg = cfg_12b(); + // 96x144 image (h=96, w=144): model patches = 2 rows x 3 cols = 6. + let h = 96usize; + let w = 144usize; + let m_h = h / cfg.model_patch_size; // 2 + let m_w = w / cfg.model_patch_size; // 3 + assert_eq!(m_h * m_w, unified_num_soft_tokens(h, w, &cfg)); + assert_eq!(unified_num_soft_tokens(h, w, &cfg), 6); + + // Position ids span (mx, my) ∈ [0, m_w) × [0, m_h) — all in-table. + assert!(m_w <= cfg.mm_posemb_size); + assert!(m_h <= cfg.mm_posemb_size); +} + +#[test] +fn is_unified_arch_false_for_missing_dir() { + // Non-existent dir -> false (no panic). + let p = Path::new("/nonexistent/gemma4-unified-test-dir"); + assert!(!is_unified_arch(p)); +} diff --git a/crates/rmlx-server/src/engine/arch_generator.rs b/crates/rmlx-server/src/engine/arch_generator.rs index 802d1ed..a29298e 100644 --- a/crates/rmlx-server/src/engine/arch_generator.rs +++ b/crates/rmlx-server/src/engine/arch_generator.rs @@ -255,6 +255,36 @@ impl ArchGenerator { // `None` here and the image-input path is rejected at request time. // Only the Gemma4 architecture has a vision tower today. let vision: Option> = match &model { + // Gemma4 **unified** (12B): encoder-free vision embedder, no SigLIP + // tower. Distinguished from the tower family by `architectures[0]`. 
+ rmlx_models::arch::Architecture::Gemma4(_) + if rmlx_models::gemma4::is_unified_arch(model_dir) => + { + match rmlx_models::gemma4::UnifiedVisionConfig::from_model_dir(model_dir) { + Ok(Some(vcfg)) => { + match rmlx_models::gemma4::load_unified_vision_embedder(model_dir, &vcfg) { + Ok(embedder) => { + let pc = rmlx_models::gemma4::unified_image_processor_config(&vcfg); + let processor = rmlx_models::gemma4::Gemma4ImageProcessor::new(pc); + tracing::info!(model_id = %model_id, "Gemma4-unified vision embedder loaded (encoder-free, multimodal)"); + Some(Arc::new(VisionBundle::Gemma4Unified { + embedder, + processor, + })) + } + Err(e) => { + tracing::warn!(model_id = %model_id, error = %e, "unified vision embedder load failed — image input disabled"); + None + } + } + } + Ok(None) => None, + Err(e) => { + tracing::warn!(model_id = %model_id, error = %e, "unified vision_config parse failed — image input disabled"); + None + } + } + } rmlx_models::arch::Architecture::Gemma4(_) => { match rmlx_models::gemma4::Gemma4VisionConfig::from_model_dir(model_dir) { Ok(Some(vcfg)) => { diff --git a/crates/rmlx-server/src/engine/audio.rs b/crates/rmlx-server/src/engine/audio.rs index f3cd58a..7135780 100644 --- a/crates/rmlx-server/src/engine/audio.rs +++ b/crates/rmlx-server/src/engine/audio.rs @@ -26,6 +26,10 @@ const GEMMA4_AUDIO_SAMPLE_RATE: u32 = 16_000; /// `audio_config` + `audio_tower.*` weights. One variant per audio-capable /// architecture (only Gemma4's Conformer tower today). #[allow(missing_debug_implementations)] +#[allow( + clippy::large_enum_variant, + reason = "loaded once per model and held behind an Arc; the Conformer variant carries the full audio tower while the unified variant is a thin projection — boxing would add an indirection on every multimodal forward for a single long-lived value" +)] pub(crate) enum AudioBundle { /// Gemma4 Conformer audio tower + multimodal embedder + feature extractor. 
Gemma4 { @@ -34,6 +38,13 @@ pub(crate) enum AudioBundle { feature_extractor: rmlx_models::gemma4::Gemma4AudioFeatureExtractor, audio_token_id: u32, }, + /// Gemma4 **unified** (`Gemma4UnifiedForConditionalGeneration`, 12B): + /// encoder-free audio embedder (no Conformer tower). Raw 16 kHz waveform is + /// chunked into fixed-length frames and projected by `embed_audio`. + Gemma4Unified { + embedder: rmlx_models::gemma4::UnifiedAudioEmbedder, + audio_token_id: u32, + }, } /// Load the Gemma4 audio bundle (Conformer tower + `embed_audio` projector + @@ -43,6 +54,21 @@ pub(crate) enum AudioBundle { /// vision-only). Errors propagate to the caller, which logs + disables the /// audio path (audio input then returns a clear "no audio tower" error). pub(crate) fn load_gemma4_audio_bundle(model_dir: &Path) -> rmlx_core::Result> { + // Unified 12B: encoder-free audio (`model_type: gemma4_unified_audio`, no + // `audio_tower.*`). Route to the waveform-frame embedder before the + // Conformer loader, which would fail trying to read the absent tower + // tensors. + if rmlx_models::gemma4::is_unified_arch(model_dir) { + let Some(ucfg) = rmlx_models::gemma4::UnifiedAudioConfig::from_model_dir(model_dir)? else { + return Ok(None); + }; + let embedder = rmlx_models::gemma4::load_unified_audio_embedder(model_dir, &ucfg)?; + return Ok(Some(AudioBundle::Gemma4Unified { + audio_token_id: ucfg.audio_token_id, + embedder, + })); + } + let Some(acfg) = rmlx_models::gemma4::Gemma4AudioConfig::from_model_dir(model_dir)? else { return Ok(None); }; @@ -115,16 +141,40 @@ pub(crate) fn build_audio_prompt( ))); } + let model = arch + .as_gemma4() + .ok_or_else(|| Error::Other("audio input requires the Gemma4 architecture".to_owned()))?; + + // Unified 12B: encoder-free waveform-frame path. Decode + resample (shared + // with the Conformer path below), then chunk into fixed-length frames and + // project — no mel front-end, no tower. 
+ if let AudioBundle::Gemma4Unified { + embedder, + audio_token_id, + } = ab + { + return build_unified_audio_prompt( + model, + embedder, + *audio_token_id, + &audio_b64[0], + prompt_tokens, + device, + ); + } + let AudioBundle::Gemma4 { encoder, embedder, feature_extractor, audio_token_id, - } = ab; - - let model = arch - .as_gemma4() - .ok_or_else(|| Error::Other("audio input requires the Gemma4 architecture".to_owned()))?; + } = ab + else { + // Only the Conformer variant remains after the unified early return. + return Err(Error::Other( + "internal: unexpected audio bundle variant in Conformer path".to_owned(), + )); + }; // 1. base64 → bytes → mono f32 @ 16 kHz. let raw = audio_b64[0].as_str(); @@ -211,6 +261,93 @@ pub(crate) fn build_audio_prompt( Ok((aug_ids, embeds, masked_ids)) } +/// Build the unified-arch (12B) audio prompt: decode the clip, chunk the raw +/// 16 kHz waveform into fixed-length frames, project via the encoder-free +/// `embed_audio`, and scatter at the `<|audio|>` positions. +/// +/// Pipeline (mirrors HF `Gemma4UnifiedProcessor` + the unified vision splice): +/// 1. base64 → bytes → `WavDecoder` (mono f32, native rate) → resample 16 kHz. +/// 2. `extract_waveform_frames` → `[num_tokens, audio_embed_dim]` raw frames, +/// `num_tokens = ceil(num_samples / audio_samples_per_token)`. +/// 3. splice the audio block `<|audio>` + `num_tokens` × `<|audio|>` + `` +/// after the leading token. +/// 4. `build_unified_audio_inputs_embeds` projects the frames and scatters them. +#[allow( + clippy::indexing_slicing, + reason = "insert_at is 0 or 1, bounded by prompt_tokens.len() (0 → empty, 1 → non-empty); both prompt_tokens slices are in range" +)] +fn build_unified_audio_prompt( + model: &rmlx_models::gemma4::Gemma4Text, + embedder: &rmlx_models::gemma4::UnifiedAudioEmbedder, + audio_token_id: u32, + audio_b64: &str, + prompt_tokens: &[u32], + device: rmlx_mlx::Device, +) -> rmlx_core::Result<(Vec, rmlx_mlx::Array, rmlx_mlx::Array)> { + // 1. 
base64 → bytes → mono f32 @ 16 kHz (shared decode/resample helpers). + let b64 = audio_b64 + .rsplit_once(',') + .map_or(audio_b64, |(_, tail)| tail); + let bytes = crate::image_io::base64_decode(b64) + .map_err(|e| Error::Other(format!("input_audio base64 decode failed: {e}")))?; + let (samples, sample_rate) = rmlx_audio::wav::WavDecoder::decode(&bytes) + .map_err(|e| Error::Other(format!("input_audio decode failed: {e}")))?; + let samples = rmlx_audio::transcribe::resample_to_16k(&samples, sample_rate); + let dur_secs = samples.len() as f64 / f64::from(GEMMA4_AUDIO_SAMPLE_RATE); + + // 2. chunk raw waveform into fixed-length frames (encoder-free front-end). + let cfg = embedder.config(); + let (frames, num_tokens) = + rmlx_models::gemma4::extract_waveform_frames(&samples, cfg.audio_samples_per_token); + if num_tokens == 0 { + return Err(Error::Other( + "input_audio produced zero audio soft tokens (clip empty)".to_owned(), + )); + } + + tracing::info!( + sample_rate, + samples = samples.len(), + duration_secs = dur_secs, + audio_embed_dim = cfg.audio_embed_dim, + audio_samples_per_token = cfg.audio_samples_per_token, + audio_soft_tokens = num_tokens, + "Gemma4-unified audio preprocessed (encoder-free)" + ); + + // 3. splice the audio block after the prompt's leading token. 
+ let mut block = Vec::with_capacity(num_tokens + 2); + block.push(GEMMA4_BOA_TOKEN_ID); + block.extend(std::iter::repeat_n(audio_token_id, num_tokens)); + block.push(GEMMA4_EOA_TOKEN_ID); + + let insert_at = usize::from(!prompt_tokens.is_empty()); + let mut aug_ids = Vec::with_capacity(prompt_tokens.len() + block.len()); + aug_ids.extend_from_slice(&prompt_tokens[..insert_at]); + aug_ids.extend_from_slice(&block); + aug_ids.extend_from_slice(&prompt_tokens[insert_at..]); + + let in_prompt = aug_ids.iter().filter(|&&t| t == audio_token_id).count(); + tracing::info!( + audio_soft_tokens = num_tokens, + audio_tokens_in_prompt = in_prompt, + aug_len = aug_ids.len(), + "built Gemma4-unified audio prompt" + ); + + // 4. project frames + scatter. + let (embeds, masked_ids) = rmlx_models::gemma4::build_unified_audio_inputs_embeds( + model, + embedder, + &frames, + num_tokens, + audio_token_id, + &aug_ids, + device, + )?; + Ok((aug_ids, embeds, masked_ids)) +} + #[cfg(test)] #[path = "audio_tests.rs"] mod tests; diff --git a/crates/rmlx-server/src/engine/audio_tests.rs b/crates/rmlx-server/src/engine/audio_tests.rs index 6d56e9a..f0542ca 100644 --- a/crates/rmlx-server/src/engine/audio_tests.rs +++ b/crates/rmlx-server/src/engine/audio_tests.rs @@ -66,7 +66,11 @@ fn gemma4_audio_prompt_build_real_weights() { let bundle = load_gemma4_audio_bundle(&dir) .expect("load audio bundle") .expect("snapshot has an audio_config"); - let AudioBundle::Gemma4 { audio_token_id, .. } = &bundle; + // e4b ships the Conformer audio tower (not the encoder-free unified arch). + let AudioBundle::Gemma4 { audio_token_id, .. } = &bundle else { + eprintln!("SKIP: e4b loaded a non-Conformer audio bundle (unexpected)"); + return; + }; let audio_token_id = *audio_token_id; // Load the text model so build_audio_prompt can embed + scatter. 
diff --git a/crates/rmlx-server/src/engine/image.rs b/crates/rmlx-server/src/engine/image.rs index 02cb380..8892489 100644 --- a/crates/rmlx-server/src/engine/image.rs +++ b/crates/rmlx-server/src/engine/image.rs @@ -25,6 +25,12 @@ pub(crate) enum VisionBundle { embedder: rmlx_models::gemma4::MultimodalEmbedder, processor: rmlx_models::gemma4::Gemma4ImageProcessor, }, + /// Gemma4 **unified** (`Gemma4UnifiedForConditionalGeneration`, 12B): + /// encoder-free vision embedder (no SigLIP tower) + processor. + Gemma4Unified { + embedder: rmlx_models::gemma4::UnifiedVisionEmbedder, + processor: rmlx_models::gemma4::Gemma4ImageProcessor, + }, /// Gemma3 standard SigLIP tower + multimodal projector + processor. Gemma3 { vision: rmlx_models::gemma3::VisionModel, @@ -135,6 +141,67 @@ pub(crate) fn build_image_prompt( )?; Ok((aug_ids, embeds, masked_ids)) } + VisionBundle::Gemma4Unified { + embedder, + processor, + } => { + // The unified 12B loads through the Gemma4 text architecture; the + // encoder-free embedder replaces the SigLIP tower. 
+ let model = arch.as_gemma4().ok_or_else(|| { + Error::Other("image input requires the Gemma4 architecture".to_owned()) + })?; + let image_token_id = rmlx_models::gemma4::IMAGE_TOKEN_ID; + + let mut pixels = Vec::with_capacity(sources.len()); + for (i, src) in sources.iter().enumerate() { + let bytes = rmlx_server_load_image(src)?; + let pv = processor + .preprocess(&bytes) + .map_err(|e| Error::Other(format!("image {i} preprocess failed: {e}")))?; + let n_soft = rmlx_models::gemma4::unified_num_soft_tokens( + pv.height, + pv.width, + embedder.config(), + ); + tracing::info!( + image_idx = i, + width = pv.width, + height = pv.height, + num_soft_tokens = n_soft, + "Gemma4-unified image preprocessed" + ); + pixels.push((pv, n_soft)); + } + + let blocks: Vec> = pixels + .iter() + .map(|(_, n_soft)| { + let mut b = Vec::with_capacity(n_soft + 2); + b.push(GEMMA4_BOI_TOKEN_ID); + b.extend(std::iter::repeat_n(image_token_id, *n_soft)); + b.push(GEMMA4_EOI_TOKEN_ID); + b + }) + .collect(); + let aug_ids = splice(prompt_tokens, &blocks); + + let total_soft: usize = pixels.iter().map(|(_, n)| *n).sum(); + let in_prompt = aug_ids.iter().filter(|&&t| t == image_token_id).count(); + tracing::info!( + images = pixels.len(), + soft_tokens = total_soft, + image_tokens_in_prompt = in_prompt, + aug_len = aug_ids.len(), + "built Gemma4-unified image prompt" + ); + + let pv_only: Vec = + pixels.into_iter().map(|(pv, _)| pv).collect(); + let (embeds, masked_ids) = rmlx_models::gemma4::build_unified_inputs_embeds( + model, embedder, &pv_only, &aug_ids, device, mm_cache, + )?; + Ok((aug_ids, embeds, masked_ids)) + } VisionBundle::Gemma3 { vision, projector, diff --git a/docs/MODELS.md b/docs/MODELS.md index 84a1bd8..fe76853 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -83,7 +83,7 @@ A converting or new architecture supplies only: | `Qwen3VLMoeForConditionalGeneration` | `Architecture::Qwen3VlMoe` | text + image | `bf16` | config | green | | `Gemma3ForConditionalGeneration` | 
`Architecture::Gemma3` | text + image | `Planar` | `KV_MAX_SEQ_DEFAULT` | green | | `Gemma4ForConditionalGeneration` | `Architecture::Gemma4` | text + image + audio | `K8V8` / `Planar` / `K8V4` | config | green | -| `Gemma4UnifiedForConditionalGeneration` | `Architecture::Gemma4` (alias) | text (12B; vision/audio not yet) | `K8V8` | config | green | +| `Gemma4UnifiedForConditionalGeneration` | `Architecture::Gemma4` (alias) | text + image + audio (12B) | `K8V8` | config | green | | `LagunaForCausalLM` | `Architecture::Laguna` | text | `K8V8` | `KV_MAX_SEQ_DEFAULT` | green | | `BitNetForCausalLM` | `Architecture::BitNet` | text | `K8V8` | 4 096 | green | | `JinaEmbeddingsV4Model` | (encoder — no enum variant) | text + image | n/a | 128 000 | green | @@ -645,9 +645,10 @@ audio fields under `audio_config`. The 12B snapshots declare `architectures[0] = "Gemma4UnifiedForConditionalGeneration"` — an encoder-free multimodal variant whose **text** decoder is identical to Gemma4 (dense, `attention_k_eq_v`, no per-layer-input, no MoE). rMLX aliases the -arch string to the Gemma4 text loader; the multimodal-embedder tensors -(`embed_vision`/`embed_audio`/`vision_embedder.*`) are not read, so image/audio -input is not yet wired for 12B (text serves end-to-end). +arch string to the Gemma4 text loader for the decoder, and routes **both vision +and audio** through dedicated encoder-free embedders (no SigLIP tower, no +Conformer; see *Unified (encoder-free) vision* and *Unified (encoder-free) +audio* below). Text, image, and audio all serve end-to-end on the 12B. Text serves correctly at **all weight quants**, including the mixed 4/8-bit QAT snapshots (`gemma-4-12B-it-qat-4bit` affine, `gemma-4-12B-it-qat-mxfp4`): their @@ -753,6 +754,77 @@ Text, image, and audio input. rMLX implements all three towers: - Vision: SigLIP-style ViT + VisionPooler + soft-token scatter. - Audio: Conformer encoder + output projection + scatter. 
+### Unified (encoder-free) vision — `Gemma4UnifiedForConditionalGeneration` (12B) + +The unified 12B has **no SigLIP `vision_tower`**. Vision is early-fusion: raw +pixel patches are projected straight into the shared 48-layer LM hidden space +(`mm_embed_dim = 3840`) as `num_soft_tokens` soft tokens. rMLX dispatches on +`architectures[0]` (`is_unified_arch`) and loads +`crates/rmlx-models/src/gemma4/vision/unified.rs` instead of the tower loader; +the Gemma4 text decoder is reused unchanged. + +Per-image pipeline (faithful port of HF `gemma4_unified` +`Gemma4UnifiedVisionEmbedder` + `Gemma4UnifiedImageProcessor`): + +1. Shared Gemma4 preprocess: aspect-ratio resize (mult of `model_patch_size=48`) + + rescale to `[0,1]` (`do_normalize=false`). +2. Host patchify into 16px teacher patches (`[ry, rx, ch]`), then + `patches_merge`: each `3×3` (`pooling_kernel_size`) group becomes one 48×48 + model patch (`patch_dim = 48²·3 = 6912`), interior laid out `[ky, ry, kx, rx, + ch]` so the model patch is a *contiguous* sub-image; model-patch position = + `(min teacher_x // k, min teacher_y // k)`. +3. On-device: `patch_ln1` (LayerNorm 6912) → `patch_dense` (quantized Linear + 6912→3840, +bias) → `patch_ln2` (LayerNorm 3840). +4. Factorized 2D positional embedding: `pos_embedding[x, 0, :] + + pos_embedding[y, 1, :]` (table `[mm_posemb_size=1120, 2, 3840]`), added then + `pos_norm` (LayerNorm 3840). +5. `embed_vision`: `RMSNormNoScale → embedding_projection` (3840 → text hidden) — + the same [`MultimodalEmbedder`] the tower path reuses. +6. Scatter the soft tokens at the image-token run in `inputs_embeds` + (`build_unified_inputs_embeds`), then run the shared text decoder from embeds. + +`patch_ln1/ln2/pos_norm` are true **LayerNorm** (mean-subtraction, weight+bias), +not RMSNorm — verified against the snapshot's `.weight`+`.bias` tensors and the +upstream class. 
Color, spatial layout (4-quadrant, left/right/top/bottom), and +object counting are exact on the real 12B; fine-grained OCR is weaker than the +e4b SigLIP tower — an architectural property of the encoder-free 35M projection +(it lacks the semantic richness of a full vision encoder), not a port defect. + +### Unified (encoder-free) audio — `Gemma4UnifiedForConditionalGeneration` (12B) + +The unified 12B has **no Conformer `audio_tower`** either. Audio is early-fusion: +the raw 16 kHz mono waveform is chunked into fixed-length frames and projected +straight into the shared LM hidden space. rMLX dispatches on `architectures[0]` +(`is_unified_arch`) *before* the Conformer loader and loads +`crates/rmlx-models/src/gemma4/audio/unified.rs`; the snapshot ships only +`embed_audio.embedding_projection.{weight,scales}` (a quantized +640→3840 Linear) — there is no `audio_tower.*`. + +Per-clip pipeline (faithful port of HF `gemma4_unified` +`Gemma4UnifiedAudioFeatureExtractor` + the shared multimodal embedder): + +1. base64 → wav decode → resample to 16 kHz mono (shared `rmlx-audio` path). +2. Host feature front-end (`extract_waveform_frames`): zero-pad the tail to a + multiple of `audio_samples_per_token` (640), then reshape into + `[num_tokens, 640]` frames. **No** mel spectrogram, windowing, or per-sample + normalization — raw float samples. Since there is no downsampling, + `num_soft_tokens = ceil(num_samples / 640)` (one soft token per 40 ms frame). +3. `embed_audio`: `RMSNormNoScale → embedding_projection` (640 → text hidden) — + the same [`MultimodalEmbedder`] the Conformer path and `embed_vision` reuse. +4. Scatter the soft tokens at the `<|audio|>` run in `inputs_embeds` + (`build_unified_audio_inputs_embeds`), then run the shared text decoder from + embeds. 
+ +`audio_embed_dim == audio_samples_per_token == output_proj_dims == 640` on the +snapshot; the loader asserts the parsed `output_proj_dims` against the actual +`embed_audio.embedding_projection` input dim and rejects a mismatch. The +combined image+audio rejection guard still applies on the unified arch (a request +with both returns a clear error, never a silent drop). Like the encoder-free +vision path, audio grounding is real but coarse — the model reliably picks up the +spoken content (e.g. "Tuesday at noon", named colors/objects) without a full +audio encoder. Submitting audio to a model without the unified audio embedder +returns a clear "no audio tower" error. + ### Maximum context `max_position_embeddings` from `text_config`. diff --git a/docs/SERVER.md b/docs/SERVER.md index 60144ef..8153e6a 100644 --- a/docs/SERVER.md +++ b/docs/SERVER.md @@ -241,17 +241,29 @@ runs from the fused `inputs_embeds` (mirroring mlx-vlm `get_input_embeddings`). | `text` | `{type:"text", text:"…"}` | — | all | | `image_url` | `{type:"image_url", image_url:{url:""}}` | SigLIP vision | Gemma4, Gemma3, Qwen3-VL-MoE | | `input_image` | `{type:"input_image", image_url:""}` (mlx-vlm shape) | SigLIP vision | same | -| `input_audio` | `{type:"input_audio", input_audio:{data:"", format:"wav"}}` | Conformer audio (USM) | **Gemma4** (e4b/26b) | +| `input_audio` | `{type:"input_audio", input_audio:{data:"", format:"wav"}}` | Conformer audio (USM) / unified encoder-free | **Gemma4** (e4b/26b Conformer; 12B unified) | **Native audio (`input_audio`) — Gemma4.** The base64 payload is decoded (`rmlx-audio` symphonia decoder — WAV/MP3/M4A/etc.), downmixed to mono and -resampled to 16 kHz, run through the Gemma4 USM log-mel front-end, then the -Conformer `audio_tower` produces `T_sub` audio soft tokens. The prompt is -spliced with `<|audio>` + `T_sub`×`<|audio|>` + `` after the leading -token, and the soft tokens are scattered at the `<|audio|>` positions. 
`T_sub` -is derived from the encoder's SSCP downsample (`≈ mel_frames / 4`), so the -placeholder count always matches the encoder output (scatter aligns by -construction). One clip per request; >1 is rejected with a clear error. +resampled to 16 kHz. The downstream front-end then forks by architecture: + +- **Conformer (e4b/26b).** The waveform runs through the Gemma4 USM log-mel + front-end, then the Conformer `audio_tower` produces `T_sub` audio soft + tokens. `T_sub` is derived from the encoder's SSCP downsample + (`≈ mel_frames / 4`). +- **Unified encoder-free (12B `Gemma4UnifiedForConditionalGeneration`).** No mel + front-end, no Conformer: the raw 16 kHz waveform is chunked into fixed-length + 640-sample frames (`extract_waveform_frames`) and each frame is projected by + `embed_audio` (`RMSNormNoScale → Linear`, 640→hidden). `num_soft_tokens = + ceil(num_samples / 640)` (one soft token per 40 ms frame). See *Unified + (encoder-free) audio* in `docs/MODELS.md`. + +In both cases the prompt is spliced with `<|audio>` + `N`×`<|audio|>` + +`<audio|>` after the leading token, and the soft tokens are scattered at the +`<|audio|>` positions; the placeholder count always matches the front-end output +(scatter aligns by construction). One clip per request; >1 is rejected with a +clear error. Combined image+audio in one request is also rejected with a clear +error (on both the Conformer and unified arches), never a silent drop. **Not-supported path.** Submitting `input_audio` to a model without an audio tower (text-only, or a vision-only checkpoint) returns **HTTP 503**