onnx: fix LayerNorm output dtype mismatch with F16 inputs#2276
Open
dqj1998 wants to merge 1 commit into
Open
Conversation
The LayerNorm op's `wire` expansion casts `normalized` back to
fact.datum_type *before* applying scale/bias, then multiplies that
result with `cast_scale` (which is still in self.datum_type, F32).
With F16 inputs this becomes F16 × F32, whose output is downgraded to
F32 by `mul()`. The inference rule then asserts
`outputs[0].datum_type == inputs[0].datum_type` (F16) against the
actual F32 output, failing `into_typed()` with:
Output mismatch after rewiring expansion for output #0:
expected 1,256,384,F16 got 1,256,384,F32
Fix: defer the cast back to fact.datum_type until after all scale/bias
operations. Now the expansion stays entirely in self.datum_type (F32)
through normalized × scale + bias, and casts only the final result.
Behavior is unchanged for F32 inputs (the final cast is a no-op when
fact.datum_type == self.datum_type).
Reproduced with sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
exported via `optimum.exporters.onnx.main_export(..., dtype="fp16")`
and loaded with `into_optimized().into_runnable()`.
kali
approved these changes
May 26, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The LayerNorm op's
wireexpansion castsnormalizedback to fact.datum_type before applying scale/bias, then multiplies that result withcast_scale(which is still in self.datum_type, F32).With F16 inputs this becomes F16 × F32, whose output is downgraded to F32 by
mul(). The inference rule then assertsoutputs[0].datum_type == inputs[0].datum_type(F16) against the actual F32 output, failinginto_typed()with:Fix: defer the cast back to fact.datum_type until after all scale/bias operations. Now the expansion stays entirely in self.datum_type (F32) through normalized × scale + bias, and casts only the final result.
Behavior is unchanged for F32 inputs (the final cast is a no-op when fact.datum_type == self.datum_type).
Reproduced with sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 exported via
optimum.exporters.onnx.main_export(..., dtype="fp16")and loaded withinto_optimized().into_runnable().