feat(gemma4): support unified any-to-any arch (Gemma4Unified) — encoder-free vision + audio early-fusion (#120)#126

Merged
Pushkinist merged 3 commits into main from fix/120-gemma4-unified-arch
Jun 17, 2026

Conversation

@Pushkinist

Owner

Closes #120.

Problem

rMLX supported Gemma4 image input only for the gemma4 /
Gemma4ForConditionalGeneration family (separate SigLIP vision_tower, e4b/26b/
31b). It did NOT support the gemma4_unified / Gemma4UnifiedForConditionalGeneration
any-to-any arch used by the dense Gemma 4 12B — the unified model loaded
text-only and image requests returned 503 no vision tower (the loader looked
for vision_tower.*, which the unified snapshot does not have).

Why it's a different architecture

The unified arch is encoder-free: vision and audio are early-fusion soft
tokens projected straight into the shared 48-layer LM — there is no vision/audio
transformer tower. The 12B snapshot ships vision_embedder.* +
embed_vision.embedding_projection and embed_audio.embedding_projection (no
vision_tower.*, no audio_tower.*). The text forward pass was already correct
(mxfp8 12B generates coherently); this PR adds only the multimodal front-end +
arch dispatch.
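
To make the early-fusion idea concrete, here is a minimal sketch of scattering already-projected soft tokens into the text embedding sequence at placeholder positions. The function name, the `Vec<Vec<f32>>` layout, and the error type are illustrative assumptions, not the rMLX implementation, which operates on MLX arrays.

```rust
/// Hypothetical helper: overwrite the embedding rows at placeholder positions
/// with projected soft tokens, in order of appearance.
fn scatter_soft_tokens(
    token_ids: &[u32],
    text_embeds: &mut [Vec<f32>],
    soft_tokens: &[Vec<f32>],
    placeholder_id: u32,
) -> Result<(), String> {
    if token_ids.len() != text_embeds.len() {
        return Err("token_ids and text_embeds must have the same length".into());
    }
    // The prompt must reserve exactly one placeholder per soft token.
    let slots = token_ids.iter().filter(|&&t| t == placeholder_id).count();
    if slots != soft_tokens.len() {
        return Err(format!(
            "{slots} placeholder positions but {} soft tokens",
            soft_tokens.len()
        ));
    }
    let mut next = 0;
    for (pos, &tok) in token_ids.iter().enumerate() {
        if tok == placeholder_id {
            if soft_tokens[next].len() != text_embeds[pos].len() {
                return Err("soft token width differs from LM hidden width".into());
            }
            text_embeds[pos].copy_from_slice(&soft_tokens[next]);
            next += 1;
        }
    }
    Ok(())
}
```

Rejecting a count mismatch up front, rather than truncating, matches the PR's stated preference for clear errors over silent drops.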

Changes

  • Arch dispatch: Gemma4(_) if is_unified_arch(...) (checks
    architectures[0]) routes to the new encoder-free path; the standard family
    keeps the SigLIP vision_tower path unchanged.
  • Vision (gemma4/vision/unified.rs): faithful port of HF
    Gemma4UnifiedVisionEmbedder + image processor — host patchify + 3×3 merge
    (pooling_kernel_size, interior layout [ky,ry,kx,rx,ch]), patch_ln1
    (true LayerNorm) → quantized patch_dense (6912→3840 +bias) → patch_ln2 →
    factorized 2D positional embedding → pos_norm → embed_vision projection →
    soft tokens scattered at IMAGE_TOKEN_ID.
  • Audio (gemma4/audio/unified.rs): faithful port of HF
    Gemma4UnifiedAudioFeatureExtractor — parameter-free front-end (resample to
    16 kHz mono → zero-pad tail to a multiple of audio_samples_per_token=640 →
    reshape to [num_tokens, 640] raw samples; no mel/window/normalization) →
    quantized embed_audio projection (640→3840) → soft tokens scattered at the
    <|audio|> run. num_soft_tokens = ceil(num_samples/640).
  • Shared embedder/quant-linear/LayerNorm helpers reused across vision+audio;
    combined image+audio in one request returns a clear error (no silent drop);
    a model without a unified embedder returns a clear error, never a silent drop.
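
The dispatch rule in the first bullet can be sketched as follows. The PR does not show the real signature of is_unified_arch, so the slice-of-strings parameter, the `Gemma4Frontend` enum, and `select_frontend` are assumptions made for illustration; only the architecture strings and the "check architectures[0]" rule come from the description above.

```rust
/// Hypothetical enum naming the two Gemma4 multimodal front-ends.
#[derive(Debug, PartialEq, Eq)]
enum Gemma4Frontend {
    /// Standard family (e4b/26b/31b): separate SigLIP vision_tower.
    SiglipTower,
    /// Unified family (12B): encoder-free early-fusion embedders.
    UnifiedEncoderFree,
}

/// Only the first entry of `architectures` is inspected.
fn is_unified_arch(architectures: &[&str]) -> bool {
    architectures
        .first()
        .map_or(false, |a| *a == "Gemma4UnifiedForConditionalGeneration")
}

/// Anything that is not the unified arch keeps the existing tower path.
fn select_frontend(architectures: &[&str]) -> Gemma4Frontend {
    if is_unified_arch(architectures) {
        Gemma4Frontend::UnifiedEncoderFree
    } else {
        Gemma4Frontend::SiglipTower
    }
}
```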

Proof (real model, gemma-4-12B-it-mxfp8, single-MLX)

Vision — before: 503 no vision tower. After: 4-quadrant image →
"Top-left: Red, Top-right: Green, Bottom-left: Blue, Bottom-right: Yellow"
(exact); counting/ordering exact; scene grounded; no-image control correctly
says no image attached. (Fine-grained OCR is coarser than the SigLIP e4b path;
this is an architectural property of the 35M encoder-free projection,
documented, not a port defect.)

Audio — before: disabled (no audio_tower.*). After: say "The launch is scheduled for Tuesday at noon." (55 soft tokens) → response grounded in
"Tuesday at noon"; no-audio control does NOT hallucinate it; a second
clip ("green … three cats") → different grounded answer.

No regression: unified text-only coherent; standard e4b SigLIP vision +
Conformer audio (#122) paths untouched and reproven. Combined image+audio → clear
503.

make lint (-D warnings) + make test green; 12 model-free unit tests + 2
model-gated integration tests (resolve the snapshot via RMLX_O_MODELS_ROOT,
skip-gracefully). Reviewed by the rust-reviewer agent (vision + audio passes);
findings fixed. No new deps, no new runtime/test env vars.

🤖 Generated with Claude Code

Add the gemma4_unified / Gemma4UnifiedForConditionalGeneration multimodal
front-end so the dense Gemma 4 12B accepts image input. The unified arch has
no SigLIP vision_tower: vision is early-fusion via soft tokens projected
straight into the shared 48-layer LM.

New encoder-free vision embedder (crates/rmlx-models/src/gemma4/vision/
unified.rs), a faithful port of HF transformers gemma4_unified
Gemma4UnifiedVisionEmbedder + Gemma4UnifiedImageProcessor:
  - host patchify (16px teacher patches) + patches_merge (3x3 -> 48x48 model
    patch, interior [ky, ry, kx, rx, ch] = contiguous sub-image)
  - patch_ln1 (LayerNorm) -> patch_dense (quantized Linear 6912->3840, +bias)
    -> patch_ln2 -> factorized 2D pos embedding (pos_embedding[x,0]+[y,1]) ->
    pos_norm -> embed_vision (reused RMSNormNoScale -> embedding_projection)
  - 280 soft tokens scattered at image-token positions; text decoder reused
    unchanged.
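
As a rough illustration of the patchify and merge step, the sketch below cuts an image into 48x48 model patches whose interior follows the [ky, ry, kx, rx, ch] order, which is the row-major order of the contiguous sub-image. It assumes a height x width x channel f32 buffer whose sides are already multiples of 48 and emits patches in raster order; resizing and the exact patch ordering used by the real processor are not covered, and the function name is hypothetical.

```rust
const TEACHER_PATCH: usize = 16;
const MERGE: usize = 3;
const MODEL_PATCH: usize = TEACHER_PATCH * MERGE; // 48
const CHANNELS: usize = 3;
const PATCH_DIM: usize = MODEL_PATCH * MODEL_PATCH * CHANNELS; // 6912

/// Hypothetical host-side patchify: one 6912-value vector per 48x48 patch.
fn patchify_merged(image: &[f32], height: usize, width: usize) -> Result<Vec<Vec<f32>>, String> {
    if height == 0 || width == 0 || height % MODEL_PATCH != 0 || width % MODEL_PATCH != 0 {
        return Err(format!("{height}x{width} is not a positive multiple of {MODEL_PATCH}"));
    }
    if image.len() != height * width * CHANNELS {
        return Err("image buffer length does not match height * width * 3".into());
    }
    let (rows, cols) = (height / MODEL_PATCH, width / MODEL_PATCH);
    let mut patches = Vec::with_capacity(rows * cols);
    for py in 0..rows {
        for px in 0..cols {
            let mut patch = Vec::with_capacity(PATCH_DIM);
            for ky in 0..MERGE {
                for ry in 0..TEACHER_PATCH {
                    let y = py * MODEL_PATCH + ky * TEACHER_PATCH + ry;
                    for kx in 0..MERGE {
                        for rx in 0..TEACHER_PATCH {
                            let x = px * MODEL_PATCH + kx * TEACHER_PATCH + rx;
                            let base = (y * width + x) * CHANNELS;
                            // Innermost axis is the channel triple.
                            patch.extend_from_slice(&image[base..base + CHANNELS]);
                        }
                    }
                }
            }
            patches.push(patch);
        }
    }
    Ok(patches)
}
```

Each output vector has 48 * 48 * 3 = 6912 values, which is the input width of patch_dense quoted above.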

Dispatch: Architecture::Gemma4 + is_unified_arch routes to the encoder-free
embedder; the standard gemma4 family (e4b/26b/31b) keeps the SigLIP tower.
patch_ln1/ln2/pos_norm are true LayerNorm (weight+bias), not RMSNorm.
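
The distinction matters because the two norms compute different things. Below is a plain-slice sketch of the textbook definitions: LayerNorm centers on the mean and applies weight and bias, while a scale-free RMSNorm only divides by the root mean square. The function names and the eps value passed by the caller are assumptions; the real code runs on MLX arrays.

```rust
/// Textbook LayerNorm over one non-empty vector: center, scale by the
/// inverse standard deviation, then apply the learned weight and bias.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(weight)
        .zip(bias)
        .map(|((v, w), b)| (v - mean) * inv_std * w + b)
        .collect()
}

/// RMSNorm without a learned scale: no mean subtraction, no weight, no bias.
fn rms_norm_no_scale(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * inv_rms).collect()
}
```

For a constant input the difference is visible directly: LayerNorm returns just the bias, whereas the scale-free RMSNorm returns values close to 1.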

Real-model verified on mlx-community__gemma-4-12B-it-mxfp8: distinct correct
colors, exact 4-quadrant + left/right/top/bottom spatial, object counting and
left-to-right ordering. No-image control correctly reports no image. Text-only
and the e4b SigLIP path are not regressed. Fine-grained OCR is weaker than the
e4b tower (architectural property of the 35M encoder-free projection).

Audio (unified embed_audio early-fusion) is a follow-up; the existing
audio_tower loader does not match embed_audio.* so audio stays disabled on 12B.

Refs #120
…review fixes

Part A (unified vision review):
- output_proj_dims is now load-bearing: assert the parsed config dim against
  the loaded embed_vision.embedding_projection input dim at load time (new
  MultimodalEmbedder::projection_input_dim), reject mismatch.
- LayerNorm returns Error::Model on rank-0 input instead of the silent
  *shape.last().unwrap_or(&1) fallback.
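
Both review fixes replace a silent fallback with an explicit error. A minimal sketch of the two checks follows; the `Error` enum here is a stand-in with a single Model variant, and `check_projection_dim` and `last_dim` are hypothetical names, not the functions in the crate.

```rust
/// Stand-in for the crate's error type; only the Model variant is sketched.
#[derive(Debug, PartialEq)]
enum Error {
    Model(String),
}

/// Load-time check: the configured output_proj_dims must equal the input
/// dimension of the loaded embedding_projection weight.
fn check_projection_dim(config_output_proj_dims: usize, loaded_input_dim: usize) -> Result<(), Error> {
    if config_output_proj_dims != loaded_input_dim {
        return Err(Error::Model(format!(
            "output_proj_dims = {config_output_proj_dims} but projection expects {loaded_input_dim}"
        )));
    }
    Ok(())
}

/// Normalized axis length; a rank-0 shape is an error, not a default of 1.
fn last_dim(shape: &[usize]) -> Result<usize, Error> {
    shape
        .last()
        .copied()
        .ok_or_else(|| Error::Model("LayerNorm input must have rank >= 1".into()))
}
```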

Part B (unified audio mandate):
- New gemma4/audio/unified.rs: encoder-free audio front-end (faithful port of
  HF Gemma4UnifiedAudioFeatureExtractor) — raw 16 kHz waveform chunked into
  fixed 640-sample frames (audio_samples_per_token == audio_embed_dim ==
  output_proj_dims == 640), no mel/Conformer, projected via the shared
  embed_audio (RMSNorm -> Linear 640->3840) into the LM hidden space and
  scattered at the <|audio|> positions (build_unified_audio_inputs_embeds),
  mirroring the unified vision path. Same output_proj_dims load-time assert.
- engine: AudioBundle::Gemma4Unified variant; load_gemma4_audio_bundle routes
  is_unified_arch before the Conformer loader; build_audio_prompt forks to the
  encoder-free path. Combined image+audio guard and no-embedder clear error
  intact.
- Tests: model-free unit coverage (config parse, ceil-div frame count, tail
  padding, frame/soft-token alignment) + model-gated integration
  (RMLX_TEST_MODEL_GEMMA4_12B / RMLX_O_MODELS_ROOT, skip-graceful).
- Docs: MODELS.md (unified now vision AND audio; remove 'audio not yet'),
  SERVER.md (audio surface forks Conformer/unified), TESTING.md (12B env var).
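
The parameter-free audio front-end reduces to padding and chunking. The sketch below covers the ceil-div frame count and tail padding that the unit tests exercise, assuming the waveform is already 16 kHz mono f32; resampling is out of scope, and the function names are illustrative rather than taken from the crate.

```rust
const AUDIO_SAMPLES_PER_TOKEN: usize = 640;

/// num_soft_tokens = ceil(num_samples / 640), using integer arithmetic.
fn num_soft_tokens(num_samples: usize) -> usize {
    (num_samples + AUDIO_SAMPLES_PER_TOKEN - 1) / AUDIO_SAMPLES_PER_TOKEN
}

/// Split a waveform into 640-sample frames, zero-padding the final frame.
fn frame_waveform(samples: &[f32]) -> Vec<Vec<f32>> {
    samples
        .chunks(AUDIO_SAMPLES_PER_TOKEN)
        .map(|chunk| {
            let mut frame = vec![0.0f32; AUDIO_SAMPLES_PER_TOKEN];
            frame[..chunk.len()].copy_from_slice(chunk);
            frame
        })
        .collect()
}
```

As a sanity check on the numbers in this PR, 55 soft tokens correspond to at most 55 * 640 = 35,200 samples, or about 2.2 seconds at 16 kHz.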
…to-discovery

Drop the single-purpose RMLX_TEST_MODEL_GEMMA4_12B env var in favour of
the existing RMLX_O_MODELS_ROOT + slug pattern, matching how Whisper tests
were updated previously. Update the skip diagnostic and remove the row from
docs/TESTING.md.
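
A model-gated test under this pattern might resolve its snapshot roughly as sketched below. The `<root>/<slug>` directory layout and both function names are assumptions for illustration; only the RMLX_O_MODELS_ROOT variable and the skip-gracefully behaviour come from the commit message.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical resolver: Some(dir) only if `<root>/<slug>` is a directory.
fn resolve_snapshot(models_root: Option<&str>, slug: &str) -> Option<PathBuf> {
    let candidate = Path::new(models_root?).join(slug);
    candidate.is_dir().then_some(candidate)
}

/// Shape of a gated test body: returns false when the test was skipped.
fn run_gated(slug: &str) -> bool {
    let root = std::env::var("RMLX_O_MODELS_ROOT").ok();
    match resolve_snapshot(root.as_deref(), slug) {
        Some(dir) => {
            println!("running against {}", dir.display());
            true
        }
        None => {
            // Skip diagnostic: say what to set instead of failing the suite.
            eprintln!("skipping: set RMLX_O_MODELS_ROOT to a directory containing {slug}");
            false
        }
    }
}
```

With this shape, a call such as run_gated("mlx-community__gemma-4-12B-it-mxfp8") reports a skip on machines without the snapshot instead of failing.
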

Linked issue: Support Gemma4 'unified' any-to-any arch (Gemma4UnifiedForConditionalGeneration): vision/audio early-fusion soft tokens