Use fp32 accumulation in portable/optimized norms #16727
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16727.
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (6 unrelated failures.) As of commit 9386079 with merge base 8e7d761.
BROKEN TRUNK: The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D91096147.
Force-pushed from 1a89352 to dfd8f9a.
Force-pushed from dfd8f9a to 9386079.
```cpp
struct ComputeDTypeTraits<int8_t> {
  using type = int32_t;
};
// For 16 bit float types, ops should perform internal math in float32.
```
This matches the ATen convention from OpMathType.h.
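For context, here is a minimal sketch of how such a trait table maps 16-bit float types to float, following the OpMathType convention referenced above. The Half/BFloat16 stand-ins and the exact spelling of the acc_t alias are assumptions for illustration, not the codebase's actual definitions.

```cpp
#include <cstdint>

// Stand-ins for the real 16-bit float types (hypothetical; the actual
// codebase provides its own Half/BFloat16 definitions).
struct Half;
struct BFloat16;

// Accumulation-dtype traits in the style of ATen's OpMathType: map each
// storage dtype to the dtype used for internal math.
template <typename T>
struct ComputeDTypeTraits {
  using type = T; // default: accumulate in the input dtype
};

// Narrow integer types accumulate in a wider integer.
template <>
struct ComputeDTypeTraits<int8_t> {
  using type = int32_t;
};

// For 16 bit float types, ops should perform internal math in float32.
template <>
struct ComputeDTypeTraits<Half> {
  using type = float;
};
template <>
struct ComputeDTypeTraits<BFloat16> {
  using type = float;
};

// Convenience alias used in kernels, e.g. acc_t<CTYPE>.
template <typename CTYPE>
using acc_t = typename ComputeDTypeTraits<CTYPE>::type;
```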
```diff
 // compute E[X] and Var[x] = E[x^2] - E[x]^2
-CTYPE sum = reduce_add(x, static_cast<CTYPE>(inner_size));
-CTYPE sq_sum = vec_powerf(x, static_cast<CTYPE>(inner_size));
+float sum = reduce_add(x, inner_size);
```
Note: We should ideally use float64 if CTYPE is double, but we're already forcing float32 in the various utility functions - reduce_add and vec_powerf return floats and use floats internally. I'm just removing the implicit cast to CTYPE on the return value.
I'm inclined to leave this for a follow-up on an as-needed basis.
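To make the fp32-accumulation pattern concrete, here is a rough scalar sketch, assuming plain loops in place of the real reduce_add / vec_powerf helpers; the function name and signature are hypothetical.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical scalar version of the mean/variance reduction: accumulate
// E[x] and E[x^2] in float regardless of the storage dtype CTYPE, and only
// cast back to CTYPE when writing normalized outputs (not shown here).
template <typename CTYPE>
void compute_mean_rstd(
    const CTYPE* x,
    size_t inner_size,
    float eps,
    float& mean,
    float& rstd) {
  // compute E[X] and Var[x] = E[x^2] - E[x]^2 in float32
  float sum = 0.0f;
  float sq_sum = 0.0f;
  for (size_t i = 0; i < inner_size; ++i) {
    const float v = static_cast<float>(x[i]);
    sum += v;
    sq_sum += v * v;
  }
  mean = sum / static_cast<float>(inner_size);
  const float var = sq_sum / static_cast<float>(inner_size) - mean * mean;
  rstd = 1.0f / std::sqrt(var + eps);
}
```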
Note that unit tests are failing on trunk. Checking the job logs on this PR shows it's an unrelated failure.
Summary:
When running in fp16, our layer norm implementation diverges from ATen, especially for larger variances. This is because ATen internally accumulates mean and variance in fp32 for fp16/bf16, whereas we accumulate in the input/output dtype.
To resolve this, I've matched the ATen behavior by using acc_t (similar to OpMathType from ATen) in optimized. In portable, I've just removed the cast back to CTYPE, as the accumulation logic was already in float.
There was also an outright bug in group norm where the element count was being cast to CTYPE, which can overflow or lose precision.
I've added an additional test case with high variance. In combination with the existing coverage, this should show that the behavior is unchanged for existing cases and no longer overflows or diverges for larger inputs.
Thanks Jason Zhu for spotting this!
Differential Revision: D91096147
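As a side note on the group norm element-count cast called out above, here is a small illustrative sketch (not the actual group norm code) of the bug pattern and the fix; the function names are hypothetical.

```cpp
#include <cstdint>

// Bug pattern: casting the element count to CTYPE before dividing. For
// CTYPE == fp16, counts above 65504 overflow to inf and counts above 2048
// are already rounded, skewing the mean.
template <typename CTYPE>
float mean_with_count_cast(float sum, int64_t count) {
  return sum / static_cast<float>(static_cast<CTYPE>(count));
}

// Fix: keep the count in a wide type (here float) for the division.
template <typename CTYPE>
float mean_with_float_count(float sum, int64_t count) {
  return sum / static_cast<float>(count);
}
```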