Skip to content

fix(gemma4): bidirectional vision attention for unified soft tokens — fixes color corruption (#127)#128

Merged
Pushkinist merged 2 commits into
mainfrom
fix/127-unified-vision-color
Jun 17, 2026
Merged

fix(gemma4): bidirectional vision attention for unified soft tokens — fixes color corruption (#127)#128
Pushkinist merged 2 commits into
mainfrom
fix/127-unified-vision-color

Conversation

@Pushkinist

Copy link
Copy Markdown
Owner

Closes #127.

Problem

The encoder-free gemma4_unified (12B) vision path (added in #126/#120)
corrupted color: on solid-color PNGs the unified 12B misnamed green→black,
yellow→purple, white→black, while the standard gemma4 SigLIP path (e4b) named
all 5 correctly on the same images. (The original #120 proof used a synthetic
4-quadrant image and a "describe it" prompt, which was too soft to catch this.)

Root cause — the issue's preprocessing diagnosis was falsified

The filed diagnosis (channel order / mean-std / RGBA stride) is wrong. The
patchify + embedder were verified numerically bit-exact to the HF
Gemma4UnifiedVisionEmbedder reference (NumPy port, all digits). The real defect:
the text decoder ran the image soft tokens causally, but Gemma4 unified sets
use_bidirectional_attention: "vision" — every soft token in an image block must
attend to every other soft token of the same block. The SigLIP path (e4b/26b/31b)
hides this because its ViT pre-integrates the image before projection; the
encoder-free path projects raw patches with no pre-integrated context, so causal
masking scrambles them (green/yellow → black/purple, spatial layout hallucinated).

Fix (general)

  • build_vision_bidi_overlay (gemma4/model.rs): builds a [1,1,seq,seq]
    additive overlay (0.0 = allowed) opening intra-image-block cells, keyed off the
    <start_of_image>/<end_of_image> markers, merged via element-wise maximum
    (combine_bidi_overlay) with each layer's causal/SWA prefill mask. Threaded as
    Option<&Array> through DecoderLayer/Attention; None (no-op, byte-identical)
    on text / audio / decode / speculative paths, and gated on has_image so a
    text prefill never even syncs ids to host.
  • LayerNorm eps corrected to the PyTorch nn.LayerNorm default 1e-5.
  • BOI/EOI marker ids promoted to a single pub const (was duplicated across crates).

Proof (real 12B-mxfp8, temp=0, identical PNGs)

Unified solid-color battery, BEFORE → AFTER: green Purple→Green, yellow
purple→Yellow, red Red→Red, blue blue→Blue = chromatic 4/4 correct. The
white-card image (white bg, red border, dark text) now reads "light gray
background, thin red border, black center" (structure correct) vs BEFORE
"corrupted pinks/mauves/grays". e4b SigLIP control unchanged (5/5). Unified text
("Paris") + unified audio paths un-regressed.

Documented limitation (not an rMLX bug)

A 100%-uniform solid white fill still reads as black, and this is inherent
to the architecture
: patch_ln1 is a 6912-dim LayerNorm applied to the raw
patch as the first op, so any achromatic uniform patch has zero variance and
collapses to the bias — white ≡ gray ≡ black by construction (confirmed by the
weight shapes and the rust-reviewer). The real Gemma4 12B has the same property;
the SigLIP path escapes it only via its ViT. Real images (with inter-patch
variance) are unaffected. Documented in docs/MODELS.md.

Tests + guards

Model-free: patchify channel-preservation + interior-layout; bidi overlay opens
intra-block / does not cross blocks / absent without markers; SWA+overlay
override. make lint (-D warnings) + make test green. Reviewed by the
rust-reviewer agent (no CRITICAL/HIGH; all findings fixed in the second commit).
No new deps, no new env vars.

🤖 Generated with Claude Code

The encoder-free gemma4_unified (12B) vision path misnamed colors and
hallucinated layout: green/yellow read as black/purple, white-card content
read as 'corrupted pinks/mauves/grays'. Root cause was NOT preprocessing or
the embedder (verified bit-exact against the HF reference at every stage —
patchify channel order, LayerNorms, mxfp8 patch_dense, pos-emb, RMSNorm
projection all match to printed precision). It was the text decoder reading
the image soft tokens *causally*.

Gemma 4 conditions each image's soft tokens with bidirectional attention
(every soft token of an image attends to every other soft token of that
image; text stays causal). The SigLIP tower path (e4b/26b/31b) integrates the
image in its ViT so causal LM attention suffices, but the encoder-free path
projects raw patches with no pre-integrated context and must be read
bidirectionally.

Fix: build a per-prefill bidirectional overlay keyed off the
<start_of_image>/<end_of_image> markers and merge it (element-wise maximum)
with each layer's causal/SWA prefill mask. Threaded as Option<&Array> through
the Gemma4 decoder layer / attention / mask builder; None on text, audio,
decode, and speculative paths (no-op). Also corrects the unified embedder's
LayerNorm eps to the PyTorch nn.LayerNorm default 1e-5 (was rms_norm_eps 1e-6).

Result on the real 12B (temp=0): solid-color naming 1/5 -> 4/5 correct
(red/green/blue/yellow), white-card border + center shape now read correctly.
Pure achromatic input (white/gray/black) remains indistinguishable — an
inherent property of the encoder-free projection (patch_ln1 normalizes away
the absolute level, so white=gray=black map to one embedding; faithful to HF),
documented in docs/MODELS.md. e4b SigLIP control unchanged (5/5 + white).

Tests: model-free guards for the bidi overlay allow/block pattern and the
patchify channel-order/value preservation (the test that would have caught
the originally-suspected channel defect).

Refs #127
…fallback; shared BOI/EOI consts

Review fixes on top of the unified-vision colour fix:

- perf: thread `has_image: bool` into `forward_h` so pure-text prefill no
  longer forces a device→host id sync via `build_vision_bidi_overlay`. The
  text path (`forward_arr`) passes `false`; the image path
  (`forward_arr_embeds`) passes `true`. Decode (seq==1 / offset>0) untouched.
- traceability: `combine_bidi_overlay` now emits a `tracing::warn!` with both
  shapes before falling back to the plain causal mask, so the (currently
  unreachable) image-block→causal regression is loud instead of silent.
- dedup: promote the BOI/EOI marker token ids to a single `pub const` in
  `rmlx-models::gemma4::vision`; the model and the server engine both import
  them (cast at the i32/u32 boundary) instead of redeclaring.
- test: pin "image-bidi overrides the sliding window inside the block" with a
  model-free SWA + overlay mask case.
- reword a doc comment to drop a ticket reference (hard rule).
@Pushkinist Pushkinist merged commit 0e6700e into main Jun 17, 2026
2 checks passed
@Pushkinist Pushkinist deleted the fix/127-unified-vision-color branch June 17, 2026 17:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gemma4-unified vision embedder corrupts colors (white→black, yellow→purple, green→black); e4b SigLIP path correct

1 participant