onnx: fix LayerNorm output dtype mismatch with F16 inputs #2276

Open
dqj1998 wants to merge 1 commit into sonos:main from dqj1998:fix-layer-norm-f16-cast


dqj1998 commented May 24, 2026

The LayerNorm op's `wire` expansion casts `normalized` back to `fact.datum_type` *before* applying scale/bias, then multiplies that result by `cast_scale` (which is still in `self.datum_type`, F32).

With F16 inputs this becomes F16 × F32, which `mul()` promotes to F32. The inference rule then asserts `outputs[0].datum_type == inputs[0].datum_type` (F16) against the actual F32 output, so `into_typed()` fails with:

    Output mismatch after rewiring expansion for output #0:
    expected 1,256,384,F16 got 1,256,384,F32

Fix: defer the cast back to `fact.datum_type` until after all scale/bias operations. The expansion now stays entirely in `self.datum_type` (F32) through normalized × scale + bias and casts only the final result.

Behavior is unchanged for F32 inputs (the final cast is a no-op when `fact.datum_type == self.datum_type`).
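
To make the dtype flow concrete, here is a minimal, self-contained sketch of the two orderings. It is a toy model, not tract's code: `DType`, `binary_out`, `cast`, `expand_before_fix`, and `expand_after_fix` are hypothetical names, and the promotion rule (the wider float wins) is an assumption chosen to match the F16 × F32 → F32 behavior described above. It also assumes the bias is held in the compute dtype, like the scale.

```rust
// Toy model of the dtype flow only; these are NOT tract's types or API.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F16,
    F32,
}

// Assumed promotion rule for a binary op: the wider float wins.
fn binary_out(a: DType, b: DType) -> DType {
    if a == DType::F32 || b == DType::F32 {
        DType::F32
    } else {
        DType::F16
    }
}

// A cast always yields the requested dtype (a no-op if already there).
fn cast(_from: DType, to: DType) -> DType {
    to
}

// Ordering before the fix: cast back to the input dtype first,
// then apply scale and bias, which are still in the compute dtype.
fn expand_before_fix(input: DType, compute: DType) -> DType {
    let normalized = cast(input, compute); // normalization runs in `compute`
    let cast_back = cast(normalized, input); // early cast back
    let scaled = binary_out(cast_back, compute); // mixed dtypes for F16 input
    binary_out(scaled, compute) // bias add
}

// Ordering after the fix: stay in the compute dtype through
// scale and bias, and cast only the final result.
fn expand_after_fix(input: DType, compute: DType) -> DType {
    let normalized = cast(input, compute);
    let scaled = binary_out(normalized, compute);
    let biased = binary_out(scaled, compute);
    cast(biased, input) // single deferred cast
}

fn main() {
    for input in [DType::F16, DType::F32] {
        println!(
            "input={:?} before_fix={:?} after_fix={:?}",
            input,
            expand_before_fix(input, DType::F32),
            expand_after_fix(input, DType::F32)
        );
    }
}
```

Under these assumptions, only the F16 input with the old ordering yields an output dtype that differs from the input dtype, which is the case the inference rule rejects.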

Reproduced with sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2, exported via `optimum.exporters.onnx.main_export(..., dtype="fp16")` and loaded with `into_optimized().into_runnable()`.
