fix(gemma4): bidirectional vision attention for unified soft tokens, fixes color corruption (#127) #128
Merged
The encoder-free gemma4_unified (12B) vision path misnamed colors and hallucinated layout: green/yellow read as black/purple, white-card content read as 'corrupted pinks/mauves/grays'.

Root cause was NOT preprocessing or the embedder (verified bit-exact against the HF reference at every stage: patchify channel order, LayerNorms, mxfp8 patch_dense, pos-emb, RMSNorm projection all match to printed precision). It was the text decoder reading the image soft tokens *causally*. Gemma 4 conditions each image's soft tokens with bidirectional attention (every soft token of an image attends to every other soft token of that image; text stays causal). The SigLIP tower path (e4b/26b/31b) integrates the image in its ViT so causal LM attention suffices, but the encoder-free path projects raw patches with no pre-integrated context and must be read bidirectionally.

Fix: build a per-prefill bidirectional overlay keyed off the <start_of_image>/<end_of_image> markers and merge it (element-wise maximum) with each layer's causal/SWA prefill mask. Threaded as Option<&Array> through the Gemma4 decoder layer / attention / mask builder; None on text, audio, decode, and speculative paths (no-op). Also corrects the unified embedder's LayerNorm eps to the PyTorch nn.LayerNorm default 1e-5 (was rms_norm_eps 1e-6).

Result on the real 12B (temp=0): solid-color naming 1/5 -> 4/5 correct (red/green/blue/yellow), white-card border + center shape now read correctly. Pure achromatic input (white/gray/black) remains indistinguishable. This is an inherent property of the encoder-free projection (patch_ln1 normalizes away the absolute level, so white=gray=black map to one embedding; faithful to HF), documented in docs/MODELS.md. e4b SigLIP control unchanged (5/5 + white).

Tests: model-free guards for the bidi overlay allow/block pattern and the patchify channel-order/value preservation (the test that would have caught the originally-suspected channel defect).

Refs #127
…fallback; shared BOI/EOI consts

Review fixes on top of the unified-vision colour fix:
- perf: thread `has_image: bool` into `forward_h` so pure-text prefill no longer forces a device→host id sync via `build_vision_bidi_overlay`. The text path (`forward_arr`) passes `false`; the image path (`forward_arr_embeds`) passes `true`. Decode (seq==1 / offset>0) untouched.
- traceability: `combine_bidi_overlay` now emits a `tracing::warn!` with both shapes before falling back to the plain causal mask, so the (currently unreachable) image-block→causal regression is loud instead of silent.
- dedup: promote the BOI/EOI marker token ids to a single `pub const` in `rmlx-models::gemma4::vision`; the model and the server engine both import them (cast at the i32/u32 boundary) instead of redeclaring.
- test: pin "image-bidi overrides the sliding window inside the block" with a model-free SWA + overlay mask case.
- reword a doc comment to drop a ticket reference (hard rule).
Closes #127.
Problem
The encoder-free `gemma4_unified` (12B) vision path (added in #126/#120)
corrupted color: on solid-color PNGs the unified 12B misnamed green→black,
yellow→purple, white→black, while the standard `gemma4` SigLIP path (e4b) named
all 5 correctly on the same images. (The original #120 proof used a synthetic
4-quadrant image and a "describe it" prompt, which was too soft to catch this.)
Root cause — the issue's preprocessing diagnosis was falsified
The filed diagnosis (channel order / mean-std / RGBA stride) is wrong. The
patchify + embedder were verified numerically bit-exact to the HF
`Gemma4UnifiedVisionEmbedder` reference (NumPy port, all digits). The real defect:
the text decoder ran the image soft tokens causally, but Gemma4 unified sets
`use_bidirectional_attention: "vision"`, so every soft token in an image block must
attend to every other soft token of the same block. The SigLIP path (e4b/26b/31b)
hides this because its ViT pre-integrates the image before projection; the
encoder-free path projects raw patches with no pre-integrated context, so causal
masking scrambles them (green/yellow → black/purple, spatial layout hallucinated).
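To make the masking requirement concrete, here is a minimal model-free sketch in plain Rust. It is not the rMLX implementation: it uses `Vec<f32>` rather than MLX arrays, the marker ids `BOI`/`EOI` are made-up values, all function names are hypothetical, and treating the markers themselves as ordinary causal tokens is an assumption. Masks are additive (0.0 = allowed, negative infinity = blocked), so an element-wise maximum opens a cell whenever either mask allows it.

```rust
// Hypothetical marker ids, for illustration only.
const BOI: u32 = 1001; // stands in for <start_of_image>
const EOI: u32 = 1002; // stands in for <end_of_image>
const BLOCKED: f32 = f32::NEG_INFINITY;

/// For each position, the image block it belongs to (None for text and markers).
fn image_block_ids(ids: &[u32]) -> Vec<Option<usize>> {
    let mut out = vec![None; ids.len()];
    let mut current: Option<usize> = None; // block currently open, if any
    let mut next_block = 0;
    for (i, &id) in ids.iter().enumerate() {
        if id == BOI {
            current = Some(next_block);
            next_block += 1;
        } else if id == EOI {
            current = None;
        } else {
            out[i] = current;
        }
    }
    out
}

/// Additive causal mask, row-major [seq, seq]: open where key <= query.
fn causal_mask(seq: usize) -> Vec<f32> {
    let mut m = vec![BLOCKED; seq * seq];
    for q in 0..seq {
        for k in 0..=q {
            m[q * seq + k] = 0.0;
        }
    }
    m
}

/// Additive overlay: open where query and key are soft tokens of the same image block.
fn bidi_overlay(ids: &[u32]) -> Vec<f32> {
    let blocks = image_block_ids(ids);
    let seq = ids.len();
    let mut m = vec![BLOCKED; seq * seq];
    for q in 0..seq {
        for k in 0..seq {
            if blocks[q].is_some() && blocks[q] == blocks[k] {
                m[q * seq + k] = 0.0;
            }
        }
    }
    m
}

/// Element-wise maximum of two additive masks of equal length.
fn merge(base: &[f32], overlay: &[f32]) -> Vec<f32> {
    base.iter().zip(overlay).map(|(&a, &b)| a.max(b)).collect()
}
```

For a sequence such as `[text, BOI, soft, soft, soft, EOI, text]`, the merged mask lets the first soft token see the last one, while text positions keep their causal rows. A sequence without markers yields an all-blocked overlay, so the merge returns the causal mask unchanged.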
Fix (general)
- `build_vision_bidi_overlay` (`gemma4/model.rs`): builds a `[1,1,seq,seq]`
  additive overlay (0.0 = allowed) opening intra-image-block cells, keyed off the
  `<start_of_image>`/`<end_of_image>` markers, merged via element-wise `maximum`
  (`combine_bidi_overlay`) with each layer's causal/SWA prefill mask. Threaded as
  `Option<&Array>` through `DecoderLayer`/`Attention`; `None` (no-op, byte-identical)
  on text / audio / decode / speculative paths, and gated on `has_image` so a
  text prefill never even syncs ids to host.
- Unified embedder LayerNorm eps corrected to the PyTorch `nn.LayerNorm` default `1e-5`.
- BOI/EOI marker token ids promoted to a single `pub const` (was duplicated across crates).
Proof (real 12B-mxfp8, temp=0, identical PNGs)
Unified solid-color battery, BEFORE → AFTER: green Purple→Green, yellow
purple→Yellow, red Red→Red, blue blue→Blue = chromatic 4/4 correct. The
white-card image (white bg, red border, dark text) now reads "light gray
background, thin red border, black center" (structure correct) vs BEFORE
"corrupted pinks/mauves/grays". e4b SigLIP control unchanged (5/5). Unified text
("Paris") + unified audio paths un-regressed.
Documented limitation (not an rMLX bug)
A 100%-uniform solid white fill still reads as black, and this is inherent
to the architecture:
`patch_ln1` is a 6912-dim LayerNorm applied to the raw patch as the first op, so any achromatic uniform patch has zero variance and
collapses to the bias — white ≡ gray ≡ black by construction (confirmed by the
weight shapes and the rust-reviewer). The real Gemma4 12B has the same property;
the SigLIP path escapes it only via its ViT. Real images (with inter-patch
variance) are unaffected. Documented in
`docs/MODELS.md`.
Tests + guards
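The achromatic collapse described under the documented limitation can also be checked without a model. The following is a minimal sketch, not one of the PR's tests: `layer_norm` is a hypothetical plain-Rust helper implementing the standard LayerNorm formula, and the weight and bias values used with it are arbitrary stand-ins for the real `patch_ln1` parameters.

```rust
/// Standard LayerNorm over one vector:
/// (x - mean) / sqrt(var + eps) * weight + bias, with biased variance.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(weight)
        .zip(bias)
        // For a uniform input, (v - mean) is exactly zero, so only the bias survives.
        .map(|((v, w), b)| (v - mean) * inv_std * w + b)
        .collect()
}
```

Feeding uniform 6912-element patches of 1.0 (white), 0.5 (gray), and 0.0 (black) through this function gives the same output for all three, equal to the bias vector. A patch with any variation across its elements produces a different result.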
Model-free: patchify channel-preservation + interior-layout; bidi overlay opens
intra-block / does not cross blocks / absent without markers; SWA+overlay
override.
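The SWA + overlay override case can be pictured with a small model-free sketch. This is an illustration under stated assumptions, not the PR's test code: the function names are hypothetical, masks are plain row-major `Vec<f32>` instead of MLX arrays, the image block is given directly as a position range, and the window convention (a query sees itself and the `window - 1` positions before it) is assumed.

```rust
use std::ops::Range;

const BLOCKED: f32 = f32::NEG_INFINITY;

/// Additive sliding-window causal mask, row-major [seq, seq].
/// Assumed convention: query q sees keys in (q - window, q].
fn swa_mask(seq: usize, window: usize) -> Vec<f32> {
    let mut m = vec![BLOCKED; seq * seq];
    for q in 0..seq {
        let start = (q + 1).saturating_sub(window);
        for k in start..=q {
            m[q * seq + k] = 0.0;
        }
    }
    m
}

/// Additive overlay opening every (query, key) pair inside `block`.
fn block_overlay(seq: usize, block: Range<usize>) -> Vec<f32> {
    let mut m = vec![BLOCKED; seq * seq];
    for q in block.clone() {
        for k in block.clone() {
            m[q * seq + k] = 0.0;
        }
    }
    m
}

/// Element-wise maximum: a cell is open if either mask opens it.
fn merge(base: &[f32], overlay: &[f32]) -> Vec<f32> {
    base.iter().zip(overlay).map(|(&a, &b)| a.max(b)).collect()
}
```

With `seq = 8`, `window = 2`, and an image block at positions `1..6`, the sliding window alone hides position 1 from the query at position 5. After the merge that cell is open, as is the forward-looking cell from query 1 to key 5. A text query at position 7 still sees only its window.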
`make lint` (`-D warnings`) + `make test` green. Reviewed by the rust-reviewer agent (no CRITICAL/HIGH; all findings fixed in the second commit).
No new deps, no new env vars.
🤖 Generated with Claude Code