
Conversation

@GregoryComer
Member

@GregoryComer GregoryComer commented Jan 21, 2026

Summary:
When running in fp16, our layer norm implementation diverges from ATen, especially for larger variances. This is because ATen internally accumulates mean and variance in fp32 for fp16/bf16, whereas we accumulate in the input/output dtype.

To resolve this, I've matched the ATen behavior by using acc_t (similar to ATen's OpMathType) in the optimized kernels. In the portable kernels, I've just removed the cast back to CTYPE, since the accumulation logic was already in float.

There was also an outright bug in group norm where the element count was cast to CTYPE, which can overflow or lose precision.

I've added an additional test case with high variance. In combination with the existing coverage, this should show that the behavior is unchanged for existing cases and no longer overflows or diverges for larger inputs.
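
As a rough illustration of the fix (hypothetical function and names, not the actual ExecuTorch kernels): accumulate the sums in float32 when CTYPE is fp16/bf16 and keep the element count as an integer, casting back to CTYPE only for the outputs.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical sketch: accumulate mean/variance in float32 for a 16-bit
// CTYPE, mirroring ATen's OpMathType convention. Not the real kernel code.
template <typename CTYPE>
void mean_and_rstd(const CTYPE* x, size_t n, float eps,
                   CTYPE* mean_out, CTYPE* rstd_out) {
  using acc_t = float;  // fp16/bf16 -> accumulate in float32.
  acc_t sum = 0;
  acc_t sq_sum = 0;
  for (size_t i = 0; i < n; ++i) {
    const acc_t v = static_cast<acc_t>(x[i]);
    sum += v;
    sq_sum += v * v;
  }
  // Divide by the untouched integer count; casting n to a 16-bit CTYPE
  // (as group norm previously did) is what overflows or loses precision.
  const acc_t mean = sum / static_cast<acc_t>(n);
  const acc_t var = sq_sum / static_cast<acc_t>(n) - mean * mean;
  *mean_out = static_cast<CTYPE>(mean);
  *rstd_out = static_cast<CTYPE>(1.0f / std::sqrt(var + eps));
}
```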

Thanks Jason Zhu for spotting this!

Differential Revision: D91096147

@pytorch-bot

pytorch-bot bot commented Jan 21, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16727

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit 9386079 with merge base 8e7d761:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-codesync
Contributor

meta-codesync bot commented Jan 21, 2026

@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D91096147.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 21, 2026
@GregoryComer GregoryComer added the release notes: ops & kernels Changes to the opset and any new / changed kernel implementations label Jan 21, 2026
GregoryComer added a commit to GregoryComer/executorch that referenced this pull request Jan 21, 2026
```cpp
struct ComputeDTypeTraits<int8_t> {
  using type = int32_t;
};
// For 16 bit float types, ops should perform internal math in float32.
```
Member Author

This matches the ATen convention from OpMathType.h.
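
For context, a minimal sketch of what the 16-bit specializations likely look like under that convention (the Half/BFloat16 placeholders stand in for ExecuTorch's real 16-bit float types; this is illustrative, not the exact source):

```cpp
#include <cstdint>

// Placeholder 16-bit float types, standing in for the real ExecuTorch
// Half / BFloat16 types (assumption, for illustration only).
struct Half { uint16_t bits; };
struct BFloat16 { uint16_t bits; };

// Default: compute in the same type as the IO dtype.
template <typename T>
struct ComputeDTypeTraits {
  using type = T;
};

// 16-bit float types perform internal math in float32,
// mirroring ATen's OpMathType.h.
template <>
struct ComputeDTypeTraits<Half> {
  using type = float;
};

template <>
struct ComputeDTypeTraits<BFloat16> {
  using type = float;
};
```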

```diff
 // compute E[X] and Var[x] = E[x^2] - E[x]^2
-CTYPE sum = reduce_add(x, static_cast<CTYPE>(inner_size));
-CTYPE sq_sum = vec_powerf(x, static_cast<CTYPE>(inner_size));
+float sum = reduce_add(x, inner_size);
```
Member Author

@GregoryComer GregoryComer Jan 21, 2026


Note: we should ideally use float64 when CTYPE is double, but we're already forcing float32 in the various utility functions: reduce_add and vec_powerf return float and use float internally. I'm just removing the implicit cast of the return value back to CTYPE.

I'm inclined to leave this for a follow-up on an as-needed basis.
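
For reference, that follow-up could be as small as selecting the accumulator type from CTYPE; a sketch under that assumption (hypothetical alias name, not part of this PR):

```cpp
#include <type_traits>

// Sketch only (assumption, not what this PR does): keep double
// accumulation in double while mapping everything else to float32.
template <typename CTYPE>
using norm_acc_t =
    std::conditional_t<std::is_same_v<CTYPE, double>, double, float>;
```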

@GregoryComer GregoryComer changed the title from "Use fp32 accumulation in portable/optimized layer norm" to "Use fp32 accumulation in portable/optimized norms" Jan 21, 2026
@GregoryComer GregoryComer marked this pull request as draft January 21, 2026 22:42
@GregoryComer
Copy link
Member Author

GregoryComer commented Jan 21, 2026

Note that unit tests are failing on trunk. Checking the job logs on this PR shows the failures are unrelated.

@GregoryComer GregoryComer marked this pull request as ready for review January 21, 2026 23:38
@github-project-automation github-project-automation bot moved this to To triage in ExecuTorch Core Jan 22, 2026
@GregoryComer GregoryComer moved this from To triage to In progress in ExecuTorch Core Jan 22, 2026
@GregoryComer GregoryComer self-assigned this Jan 22, 2026
@GregoryComer GregoryComer merged commit 86b4bea into pytorch:main Jan 22, 2026
140 of 147 checks passed
@github-project-automation github-project-automation bot moved this from In progress to Done in ExecuTorch Core Jan 22, 2026

Labels

CLA Signed
fb-exported
meta-exported
release notes: ops & kernels

Projects

Status: Done


3 participants