From 938607958717afba320082705a09b3c77c36403c Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Wed, 21 Jan 2026 14:29:52 -0800
Subject: [PATCH] Use fp32 accumulation in portable/optimized norms (#16727)

Summary:
When running in fp16, our layer norm implementation diverges from ATen,
especially for larger variances. This is because ATen internally accumulates
mean and variance in fp32 for fp16/bf16, whereas we accumulate in the
input/output dtype.

To resolve this, I've matched the ATen behavior by using acc_t (similar to
OpMathType from ATen) in optimized. In portable, I've just removed the cast
back to CTYPE, as the accumulation logic was already in float. There was also
an outright bug in group norm where the element count was cast to CTYPE, which
can overflow or lose precision.

I've added an additional test case with high variance. In combination with the
existing coverage, this should show that the behavior is unchanged for
existing cases and no longer overflows or diverges for larger inputs.

Thanks to Jason Zhu for spotting this!

Differential Revision: D91096147
---
 .../optimized/cpu/op_native_layer_norm.cpp    | 12 +++----
 kernels/optimized/test/targets.bzl            |  1 +
 kernels/optimized/utils/math_utils.h          | 11 ++++++
 kernels/portable/cpu/op_native_group_norm.cpp |  4 +--
 kernels/portable/cpu/op_native_layer_norm.cpp | 10 +++---
 kernels/portable/cpu/vec_ops.h                |  2 +-
 kernels/test/op_native_group_norm_test.cpp    | 36 +++++++++++++++++++
 kernels/test/op_native_layer_norm_test.cpp    | 21 +++++++++++
 8 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/kernels/optimized/cpu/op_native_layer_norm.cpp b/kernels/optimized/cpu/op_native_layer_norm.cpp
index 8d5410cb581..5a6db59768c 100644
--- a/kernels/optimized/cpu/op_native_layer_norm.cpp
+++ b/kernels/optimized/cpu/op_native_layer_norm.cpp
@@ -76,18 +76,18 @@ void layer_norm(
     const CTYPE* src_ptr = input_data + i * N;
     CTYPE* dst_ptr = out_data + i * N;
 
-    CTYPE mean_val;
-    CTYPE rstd_val;
+    acc_t mean_val;
+    acc_t rstd_val;
     std::tie(mean_val, rstd_val) = RowwiseMoments(src_ptr, N);
     rstd_val = CTYPE(1) / std::sqrt(rstd_val + eps);
-    const CTYPE scale = rstd_val;
-    const CTYPE offset = -rstd_val * mean_val;
+    const acc_t scale = rstd_val;
+    const acc_t offset = -rstd_val * mean_val;
 
     if (gamma_null || beta_null) {
       for (size_t j = 0; j < N; ++j) {
-        const CTYPE gamma_v = gamma_null ? CTYPE(1) : gamma_data[j];
-        const CTYPE beta_v = beta_null ? CTYPE(0) : beta_data[j];
+        const acc_t gamma_v = gamma_null ? CTYPE(1) : gamma_data[j];
+        const acc_t beta_v = beta_null ? CTYPE(0) : beta_data[j];
         dst_ptr[j] = (src_ptr[j] * scale + offset) * gamma_v + beta_v;
       }
     } else {
diff --git a/kernels/optimized/test/targets.bzl b/kernels/optimized/test/targets.bzl
index 438e7e57215..3f240661027 100644
--- a/kernels/optimized/test/targets.bzl
+++ b/kernels/optimized/test/targets.bzl
@@ -25,6 +25,7 @@ def _lib_test_bin(name, extra_deps = [], in_cpu = False):
         deps = [
             "//executorch/test/utils:utils",
            "//executorch/kernels/optimized{}:{}".format(cpu_path, lib_root),
+            "//executorch/runtime/core/portable_type:scalar_type",
         ] + extra_deps,
         preprocessor_flags = get_vec_preprocessor_flags() + get_vec_cxx_preprocessor_flags(),
     )
diff --git a/kernels/optimized/utils/math_utils.h b/kernels/optimized/utils/math_utils.h
index 0b671f3e5f1..f94d94a3ab9 100644
--- a/kernels/optimized/utils/math_utils.h
+++ b/kernels/optimized/utils/math_utils.h
@@ -11,6 +11,8 @@
 
 #include
 #include
+#include
+#include
 
 namespace executorch {
 namespace utils {
@@ -37,6 +39,15 @@ template <>
 struct ComputeDTypeTraits {
   using type = int32_t;
 };
 
+// For 16 bit float types, ops should perform internal math in float32.
+template <>
+struct ComputeDTypeTraits<Half> {
+  using type = float;
+};
+template <>
+struct ComputeDTypeTraits<BFloat16> {
+  using type = float;
+};
 
 template <typename T>
 using compute_dtype = typename ComputeDTypeTraits<T>::type;
diff --git a/kernels/portable/cpu/op_native_group_norm.cpp b/kernels/portable/cpu/op_native_group_norm.cpp
index 9e300dc7829..7a2b40cba1a 100644
--- a/kernels/portable/cpu/op_native_group_norm.cpp
+++ b/kernels/portable/cpu/op_native_group_norm.cpp
@@ -77,8 +77,8 @@ void group_norm(
     const CTYPE* x = input_data + i * inner_size;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
-    CTYPE sum = reduce_add(x, static_cast<CTYPE>(inner_size));
-    CTYPE sq_sum = vec_powerf(x, static_cast<CTYPE>(inner_size));
+    float sum = reduce_add(x, inner_size);
+    float sq_sum = vec_powerf(x, inner_size);
     double mean_value = static_cast<double>(sum) / static_cast<double>(inner_size);
     double variance =
diff --git a/kernels/portable/cpu/op_native_layer_norm.cpp b/kernels/portable/cpu/op_native_layer_norm.cpp
index 12a03a184f6..e97d17acfdf 100644
--- a/kernels/portable/cpu/op_native_layer_norm.cpp
+++ b/kernels/portable/cpu/op_native_layer_norm.cpp
@@ -73,11 +73,11 @@ void layer_norm(
     CTYPE* y = out_data + i * normalized;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
-    CTYPE sum = reduce_add(x, ct_normalized);
-    CTYPE sq_sum = vec_powerf(x, ct_normalized);
-    CTYPE mean_value = sum / ct_normalized;
-    CTYPE variance = sq_sum / ct_normalized - mean_value * mean_value;
-    CTYPE std = std::sqrt(variance + eps);
+    float sum = reduce_add(x, ct_normalized);
+    float sq_sum = vec_powerf(x, ct_normalized);
+    float mean_value = sum / ct_normalized;
+    float variance = sq_sum / ct_normalized - mean_value * mean_value;
+    float std = std::sqrt(variance + eps);
 
     // Calculate the elements of output
     for (const auto j : c10::irange(normalized)) {
diff --git a/kernels/portable/cpu/vec_ops.h b/kernels/portable/cpu/vec_ops.h
index 87dd05ac7d4..bafc1b24879 100644
--- a/kernels/portable/cpu/vec_ops.h
+++ b/kernels/portable/cpu/vec_ops.h
@@ -179,7 +179,7 @@ template <typename T>
 inline float vec_powerf(const T* x, size_t size) {
   float sum = 0;
   for (const auto i : c10::irange(size)) {
-    sum += x[i] * x[i];
+    sum += static_cast<float>(x[i]) * x[i];
   }
   return sum;
 }
diff --git a/kernels/test/op_native_group_norm_test.cpp b/kernels/test/op_native_group_norm_test.cpp
index 591df6e186b..2280adf64ee 100644
--- a/kernels/test/op_native_group_norm_test.cpp
+++ b/kernels/test/op_native_group_norm_test.cpp
@@ -243,6 +243,42 @@ class OpNativeGroupNormTest : public OperatorTest {
           0.38038814,
           0.75809801}, // expected_rstd_data
      },
+      {
+          {1, 4, 3}, // sizes
+          {0.0,
+           1000.0,
+           2000.0,
+           3000.0,
+           4000.0,
+           5000.0,
+           6000.0,
+           7000.0,
+           8000.0,
+           9000.0,
+           10000.0,
+           11000.0}, // input_data
+          {1.0, 1.0, 1.0, 1.0}, // weight_data
+          {0.0, 0.0, 0.0, 0.0}, // bias_data
+          1, // N
+          4, // C
+          3, // HxW
+          2, // group
+          1e-5, // eps
+          {-1.46385,
+           -0.87831,
+           -0.29277,
+           0.29277,
+           0.87831,
+           1.46385,
+           -1.46385,
+           -0.87831,
+           -0.29277,
+           0.29277,
+           0.87831,
+           1.46385}, // expected_data
+          {2500.0, 8500.0}, // expected_mean_data
+          {0.00058554, 0.00058554}, // expected_rstd_data
+      },
   };
 
   run_test_cases(test_cases);
diff --git a/kernels/test/op_native_layer_norm_test.cpp b/kernels/test/op_native_layer_norm_test.cpp
index d8cc2d3b2e4..930214d238c 100644
--- a/kernels/test/op_native_layer_norm_test.cpp
+++ b/kernels/test/op_native_layer_norm_test.cpp
@@ -101,6 +101,12 @@ class OpNativeLayerNormTest : public OperatorTest {
           expected,
           1e-2,
           executorch::runtime::testing::internal::kDefaultBFloat16Atol);
+    } else if constexpr (DTYPE == ScalarType::Half) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          expected,
+          1e-3,
+          executorch::runtime::testing::internal::kDefaultHalfAtol);
     } else {
       EXPECT_TENSOR_CLOSE(out0, expected);
     }
@@ -235,6 +241,21 @@ class OpNativeLayerNormTest : public OperatorTest {
           1.38873,
           -0.46291}, // expected_data
      },
+      {
+          std::string(__func__) + ": Large variance",
+          {1, 2, 3}, // sizes
+          {0.0, 1000.0, 2000.0, 3000.0, 4000.0, 5000.0}, // input_data
+          {1, 2, 3}, // normalized shape
+          {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, // weights
+          {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, // bias
+          1.0e-5, // eps
+          {-1.46385,
+           -0.87831,
+           -0.29277,
+           0.29277,
+           0.87831,
+           1.46385}, // expected_data
+      },
   };
 
   run_test_cases(test_cases);
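
Note on the failure mode (illustration only, not part of the patch): the sketch
below reproduces the divergence outside of ExecuTorch. It assumes a C++23
toolchain that provides std::float16_t in <stdfloat>; the kernels themselves
use ExecuTorch's Half/BFloat16 types and the acc_t alias from math_utils.h. It
feeds in the same high-variance row as the new test cases (0, 1000, ..., 5000)
and compares accumulating the sum of squares in fp16 against accumulating in
fp32.

    #include <cstdio>
    #include <stdfloat>

    int main() {
      // Build the row 0, 1000, ..., 5000 in fp16 (all values are exactly
      // representable).
      constexpr int n = 6;
      std::float16_t x[n];
      for (int i = 0; i < n; ++i) {
        x[i] = static_cast<std::float16_t>(i * 1000.0f);
      }

      // Old behavior: accumulate the sum of squares in the input dtype.
      // 1000 * 1000 = 1e6 is far above the fp16 max of 65504, so the
      // running sum overflows to inf and the variance is meaningless.
      std::float16_t sq_sum_h{};
      for (int i = 0; i < n; ++i) {
        sq_sum_h = static_cast<std::float16_t>(sq_sum_h + x[i] * x[i]);
      }

      // New behavior: accumulate in fp32 (the acc_t / OpMathType idea),
      // then derive mean and variance as E[x^2] - E[x]^2.
      float sum = 0.0f;
      float sq_sum = 0.0f;
      for (int i = 0; i < n; ++i) {
        const float v = static_cast<float>(x[i]);
        sum += v;
        sq_sum += v * v;
      }
      const float mean = sum / n;                  // 2500
      const float var = sq_sum / n - mean * mean;  // ~2916666.7 -> rstd ~0.00058554

      std::printf("fp16 accumulation: sq_sum = %f\n", static_cast<float>(sq_sum_h));
      std::printf("fp32 accumulation: mean = %f, var = %f\n", mean, var);
      return 0;
    }

Widening only the reduction keeps the per-element outputs in the tensor's
dtype while avoiding the overflow, which is the same trade-off ATen makes with
OpMathType.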