feat(gemma4): support unified any-to-any arch (Gemma4Unified) with encoder-free vision + audio early-fusion (#120) #126
Merged
Add the gemma4_unified / Gemma4UnifiedForConditionalGeneration multimodal
front-end so the dense Gemma 4 12B accepts image input. The unified arch has
no SigLIP vision_tower: vision is early-fusion via soft tokens projected
straight into the shared 48-layer LM.
New encoder-free vision embedder (crates/rmlx-models/src/gemma4/vision/
unified.rs), a faithful port of HF transformers gemma4_unified
Gemma4UnifiedVisionEmbedder + Gemma4UnifiedImageProcessor:
- host patchify (16px teacher patches) + patches_merge (3x3 -> 48x48 model
patch, interior [ky, ry, kx, rx, ch] = contiguous sub-image)
- patch_ln1 (LayerNorm) -> patch_dense (quantized Linear 6912->3840, +bias)
-> patch_ln2 -> factorized 2D pos embedding (pos_embedding[x,0]+[y,1]) ->
pos_norm -> embed_vision (reused RMSNormNoScale -> embedding_projection)
- 280 soft tokens scattered at image-token positions; text decoder reused
unchanged.
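The patch arithmetic above can be checked with a small sketch: a 3x3 merge of 16px teacher patches gives a 48x48 model patch, and 48 * 48 * 3 channels = 6912, the input width of patch_dense. The code below is illustrative only. The names `interior_index` and `extract_model_patch` are hypothetical, and the image is assumed to be row-major HWC with both sides a multiple of 48. It shows why the interior layout [ky, ry, kx, rx, ch] is the same thing as a contiguous 48x48 sub-image.

```rust
/// Side length of a teacher patch in pixels (from the commit message).
const TEACHER: usize = 16;
/// Teacher patches merged per axis (3x3 merge).
const MERGE: usize = 3;
/// Side length of one merged model patch: 3 * 16 = 48.
const MODEL: usize = TEACHER * MERGE;
/// RGB channels (assumption: 3-channel input).
const CH: usize = 3;

/// Flat offset of element [ky, ry, kx, rx, ch] inside one merged patch.
/// With y = ky*16 + ry and x = kx*16 + rx this equals (y*48 + x)*3 + ch,
/// i.e. a plain row-major 48x48x3 sub-image.
fn interior_index(ky: usize, ry: usize, kx: usize, rx: usize, ch: usize) -> usize {
    (((ky * TEACHER + ry) * MERGE + kx) * TEACHER + rx) * CH + ch
}

/// Copy the 48x48 model patch at grid cell (py, px) out of a row-major HWC
/// image. Hypothetical helper; assumes the image sides are multiples of 48.
fn extract_model_patch(image: &[f32], width: usize, py: usize, px: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(MODEL * MODEL * CH);
    for y in 0..MODEL {
        // Start of this patch row inside the full image.
        let row = ((py * MODEL + y) * width + px * MODEL) * CH;
        out.extend_from_slice(&image[row..row + MODEL * CH]);
    }
    out
}
```

Each extracted patch has 6912 values, so it can be fed to a 6912->3840 projection without any further reordering.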
Dispatch: Architecture::Gemma4 + is_unified_arch routes to the encoder-free
embedder; the standard gemma4 family (e4b/26b/31b) keeps the SigLIP tower.
patch_ln1/ln2/pos_norm are true LayerNorm (weight+bias), not RMSNorm.
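To make the LayerNorm versus RMSNorm distinction concrete, here is a minimal sketch over a single non-empty vector. It is not the rMLX implementation: the function names and the explicit `eps` argument are assumptions for illustration. LayerNorm centres the input and applies a learned weight and bias; the scale-free RMSNorm only divides by the root mean square.

```rust
/// LayerNorm over one vector: subtract the mean, divide by the standard
/// deviation, then apply a learned weight and bias.
/// Assumes `x` is non-empty and `weight`/`bias` have the same length as `x`.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(weight.iter().zip(bias))
        .map(|(v, (w, b))| (v - mean) * inv_std * w + b)
        .collect()
}

/// RMSNorm without a learned scale: divide by the root mean square only.
/// No mean subtraction, no weight, no bias. Assumes `x` is non-empty.
fn rms_norm_no_scale(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * inv_rms).collect()
}
```

Loading LayerNorm weights into an RMSNorm would silently drop both the centring step and the bias term, which is why the distinction matters for a faithful port.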
Real-model verified on mlx-community__gemma-4-12B-it-mxfp8: distinct correct
colors, exact 4-quadrant + left/right/top/bottom spatial, object counting and
left-to-right ordering. No-image control correctly reports no image. Text-only
and the e4b SigLIP path are not regressed. Fine-grained OCR is weaker than the
e4b tower (architectural property of the 35M encoder-free projection).
Audio (unified embed_audio early-fusion) is a follow-up; the existing
audio_tower loader does not match embed_audio.* so audio stays disabled on 12B.
Refs #120
…review fixes

Part A (unified vision review):
- output_proj_dims is now load-bearing: assert the parsed config dim against the loaded embed_vision.embedding_projection input dim at load time (new MultimodalEmbedder::projection_input_dim), reject mismatch.
- LayerNorm returns Error::Model on rank-0 input instead of the silent *shape.last().unwrap_or(&1) fallback.

Part B (unified audio mandate):
- New gemma4/audio/unified.rs: encoder-free audio front-end (faithful port of HF Gemma4UnifiedAudioFeatureExtractor). Raw 16 kHz waveform is chunked into fixed 640-sample frames (audio_samples_per_token == audio_embed_dim == output_proj_dims == 640), with no mel/Conformer, projected via the shared embed_audio (RMSNorm -> Linear 640->3840) into the LM hidden space and scattered at the <|audio|> positions (build_unified_audio_inputs_embeds), mirroring the unified vision path. Same output_proj_dims load-time assert.
- engine: AudioBundle::Gemma4Unified variant; load_gemma4_audio_bundle routes is_unified_arch before the Conformer loader; build_audio_prompt forks to the encoder-free path. Combined image+audio guard and no-embedder clear error intact.
- Tests: model-free unit coverage (config parse, ceil-div frame count, tail padding, frame/soft-token alignment) + model-gated integration (RMLX_TEST_MODEL_GEMMA4_12B / RMLX_O_MODELS_ROOT, skip-graceful).
- Docs: MODELS.md (unified now vision AND audio; remove 'audio not yet'), SERVER.md (audio surface forks Conformer/unified), TESTING.md (12B env var).
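The framing step described in Part B (fixed 640-sample frames, ceil-div frame count, zero-padded tail) can be sketched as below. This is an illustrative stand-in, not the code in gemma4/audio/unified.rs: the function names are hypothetical, and the input is assumed to be already resampled to 16 kHz mono.

```rust
/// Raw samples per audio soft token (audio_samples_per_token in the commit).
const SAMPLES_PER_TOKEN: usize = 640;

/// Number of soft tokens for a waveform: ceil(num_samples / 640).
fn num_soft_tokens(num_samples: usize) -> usize {
    (num_samples + SAMPLES_PER_TOKEN - 1) / SAMPLES_PER_TOKEN
}

/// Split a 16 kHz mono waveform into 640-sample frames, zero-padding the
/// last frame when the length is not a multiple of 640.
fn frame_waveform(samples: &[f32]) -> Vec<[f32; SAMPLES_PER_TOKEN]> {
    samples
        .chunks(SAMPLES_PER_TOKEN)
        .map(|chunk| {
            // Start from silence so a short tail chunk is zero-padded.
            let mut frame = [0.0f32; SAMPLES_PER_TOKEN];
            frame[..chunk.len()].copy_from_slice(chunk);
            frame
        })
        .collect()
}
```

At 16 kHz each frame covers 640 / 16000 = 40 ms of audio, and the frame count always equals `num_soft_tokens(samples.len())`, which is the frame/soft-token alignment the unit tests are said to cover.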
…to-discovery

Drop the single-purpose RMLX_TEST_MODEL_GEMMA4_12B env var in favour of the existing RMLX_O_MODELS_ROOT + slug pattern, matching how Whisper tests were updated previously. Update the skip diagnostic and remove the row from docs/TESTING.md.
Closes #120.
Problem
rMLX supported Gemma4 image input only for the gemma4 / Gemma4ForConditionalGeneration family (separate SigLIP vision_tower; e4b/26b/31b). It did NOT support the gemma4_unified / Gemma4UnifiedForConditionalGeneration any-to-any arch used by the dense Gemma 4 12B: the unified model loaded text-only, and image requests returned 503 no vision tower (the loader looked for vision_tower.*, which the unified snapshot does not have).
Why it's a different architecture
The unified arch is encoder-free: vision and audio are early-fusion soft tokens projected straight into the shared 48-layer LM; there is no vision/audio transformer tower. The 12B snapshot ships vision_embedder.* + embed_vision.embedding_projection and embed_audio.embedding_projection (no vision_tower.*, no audio_tower.*). The text forward pass was already correct (mxfp8 12B generates coherently); this PR adds only the multimodal front-end + arch dispatch.
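A rough sketch of the arch dispatch, assuming the check is a string comparison on the first entry of the config's architectures list. The real `is_unified_arch` signature is elided in this PR text, so the signature, the `FrontEnd` enum, and `select_front_end` below are hypothetical stand-ins rather than the rMLX API.

```rust
/// Which multimodal front-end to build. Hypothetical enum for illustration.
#[derive(Debug, PartialEq)]
enum FrontEnd {
    /// Standard gemma4 family: separate SigLIP vision_tower.
    SiglipTower,
    /// gemma4_unified: encoder-free soft-token embedder.
    EncoderFree,
}

/// True when architectures[0] names the unified class.
/// Assumed signature; the original only shows `is_unified_arch(...)`.
fn is_unified_arch(architectures: &[String]) -> bool {
    architectures
        .first()
        .map(|name| name == "Gemma4UnifiedForConditionalGeneration")
        .unwrap_or(false)
}

/// Pick the front-end for a Gemma4 checkpoint from its config.
fn select_front_end(architectures: &[String]) -> FrontEnd {
    if is_unified_arch(architectures) {
        FrontEnd::EncoderFree
    } else {
        FrontEnd::SiglipTower
    }
}
```

Keying on the declared architecture rather than on which weight prefixes happen to be present keeps the standard family's path untouched, including when the list is empty or names a different class.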
Changes
- Dispatch: Gemma4(_) if is_unified_arch(...) (checks architectures[0]) routes to the new encoder-free path; the standard family keeps the SigLIP vision_tower path unchanged.
- Vision (gemma4/vision/unified.rs): faithful port of HF Gemma4UnifiedVisionEmbedder + image processor. Pipeline: host patchify + 3×3 merge (pooling_kernel_size, interior layout [ky,ry,kx,rx,ch]), patch_ln1 (true LayerNorm) → quantized patch_dense (6912→3840 + bias) → patch_ln2 → factorized 2D positional embedding → pos_norm → embed_vision projection → soft tokens scattered at IMAGE_TOKEN_ID.
- Audio (gemma4/audio/unified.rs): faithful port of HF Gemma4UnifiedAudioFeatureExtractor. Pipeline: parameter-free front-end (resample to 16 kHz mono → zero-pad tail to a multiple of audio_samples_per_token=640 → reshape [num_tokens, 640] raw samples; no mel/window/normalization) → quantized embed_audio projection (640→3840) → soft tokens scattered at the <|audio|> run. num_soft_tokens = ceil(num_samples/640).
- Combined image+audio in one request returns a clear error (no silent drop); a model without a unified embedder returns a clear error, never a silent drop.
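The final "scatter" step shared by both paths can be sketched as follows: each projected soft token overwrites the text embedding at one placeholder-token position, and a count mismatch is an error rather than a silent drop. This is a simplified illustration on plain vectors. The function name, the `placeholder_id` parameter, and the `String` error type are assumptions, not the rMLX code.

```rust
/// Overwrite the embedding rows at placeholder positions with soft tokens.
/// `token_ids` and `embeds` describe the same sequence, one row per token.
/// Returns an error when the sequence lengths disagree or when the number
/// of placeholders differs from the number of soft tokens.
fn scatter_soft_tokens(
    token_ids: &[u32],
    embeds: &mut [Vec<f32>],
    soft_tokens: &[Vec<f32>],
    placeholder_id: u32,
) -> Result<(), String> {
    if token_ids.len() != embeds.len() {
        return Err(format!(
            "{} token ids but {} embedding rows",
            token_ids.len(),
            embeds.len()
        ));
    }
    // Positions of the image or audio placeholder tokens, in order.
    let positions: Vec<usize> = token_ids
        .iter()
        .enumerate()
        .filter_map(|(i, &id)| (id == placeholder_id).then_some(i))
        .collect();
    if positions.len() != soft_tokens.len() {
        return Err(format!(
            "{} placeholder positions but {} soft tokens",
            positions.len(),
            soft_tokens.len()
        ));
    }
    for (pos, soft) in positions.into_iter().zip(soft_tokens) {
        embeds[pos] = soft.clone();
    }
    Ok(())
}
```

Non-placeholder rows are left as the ordinary text embeddings, so the text decoder can consume the result unchanged.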
Proof (real model, gemma-4-12B-it-mxfp8, single-MLX)
Vision. Before: 503 no vision tower. After: 4-quadrant image → "Top-left: Red, Top-right: Green, Bottom-left: Blue, Bottom-right: Yellow" (exact); counting/ordering exact; scene grounded; no-image control correctly says no image attached. (Fine-grained OCR is coarser than the SigLIP e4b path: an architectural property of the 35M encoder-free projection, documented, not a port defect.)
Audio. Before: disabled (no audio_tower.*). After: say "The launch is scheduled for Tuesday at noon." (55 soft tokens) → response grounded in "Tuesday at noon"; no-audio control does NOT hallucinate it; a second clip ("green … three cats") → different grounded answer.
No regression: unified text-only coherent; standard e4b SigLIP vision + Conformer audio (#122) paths untouched and reproven. Combined image+audio → clear 503.
make lint (-D warnings) + make test green; 12 model-free unit tests + 2 model-gated integration tests (resolve the snapshot via RMLX_O_MODELS_ROOT, skip gracefully). Reviewed by the rust-reviewer agent (vision + audio passes); findings fixed. No new deps, no new runtime/test env vars.
🤖 Generated with Claude Code