diff --git a/kernels/optimized/cpu/op_native_layer_norm.cpp b/kernels/optimized/cpu/op_native_layer_norm.cpp index 8d5410cb581..5a6db59768c 100644 --- a/kernels/optimized/cpu/op_native_layer_norm.cpp +++ b/kernels/optimized/cpu/op_native_layer_norm.cpp @@ -76,18 +76,18 @@ void layer_norm( const CTYPE* src_ptr = input_data + i * N; CTYPE* dst_ptr = out_data + i * N; - CTYPE mean_val; - CTYPE rstd_val; + acc_t mean_val; + acc_t rstd_val; std::tie(mean_val, rstd_val) = RowwiseMoments(src_ptr, N); rstd_val = CTYPE(1) / std::sqrt(rstd_val + eps); - const CTYPE scale = rstd_val; - const CTYPE offset = -rstd_val * mean_val; + const acc_t scale = rstd_val; + const acc_t offset = -rstd_val * mean_val; if (gamma_null || beta_null) { for (size_t j = 0; j < N; ++j) { - const CTYPE gamma_v = gamma_null ? CTYPE(1) : gamma_data[j]; - const CTYPE beta_v = beta_null ? CTYPE(0) : beta_data[j]; + const acc_t gamma_v = gamma_null ? CTYPE(1) : gamma_data[j]; + const acc_t beta_v = beta_null ? CTYPE(0) : beta_data[j]; dst_ptr[j] = (src_ptr[j] * scale + offset) * gamma_v + beta_v; } } else { diff --git a/kernels/optimized/test/targets.bzl b/kernels/optimized/test/targets.bzl index 438e7e57215..3f240661027 100644 --- a/kernels/optimized/test/targets.bzl +++ b/kernels/optimized/test/targets.bzl @@ -25,6 +25,7 @@ def _lib_test_bin(name, extra_deps = [], in_cpu = False): deps = [ "//executorch/test/utils:utils", "//executorch/kernels/optimized{}:{}".format(cpu_path, lib_root), + "//executorch/runtime/core/portable_type:scalar_type", ] + extra_deps, preprocessor_flags = get_vec_preprocessor_flags() + get_vec_cxx_preprocessor_flags(), ) diff --git a/kernels/optimized/utils/math_utils.h b/kernels/optimized/utils/math_utils.h index 0b671f3e5f1..f94d94a3ab9 100644 --- a/kernels/optimized/utils/math_utils.h +++ b/kernels/optimized/utils/math_utils.h @@ -11,6 +11,8 @@ #include #include +#include +#include namespace executorch { namespace utils { @@ -37,6 +39,15 @@ template <> struct ComputeDTypeTraits { using type = int32_t; }; +// For 16 bit float types, ops should perform internal math in float32. +template <> +struct ComputeDTypeTraits { + using type = float; +}; +template <> +struct ComputeDTypeTraits { + using type = float; +}; template using compute_dtype = typename ComputeDTypeTraits::type; diff --git a/kernels/portable/cpu/op_native_group_norm.cpp b/kernels/portable/cpu/op_native_group_norm.cpp index 9e300dc7829..7a2b40cba1a 100644 --- a/kernels/portable/cpu/op_native_group_norm.cpp +++ b/kernels/portable/cpu/op_native_group_norm.cpp @@ -77,8 +77,8 @@ void group_norm( const CTYPE* x = input_data + i * inner_size; // compute E[X] and Var[x] = E[x^2] - E[x]^2 - CTYPE sum = reduce_add(x, static_cast(inner_size)); - CTYPE sq_sum = vec_powerf(x, static_cast(inner_size)); + float sum = reduce_add(x, inner_size); + float sq_sum = vec_powerf(x, inner_size); double mean_value = static_cast(sum) / static_cast(inner_size); double variance = diff --git a/kernels/portable/cpu/op_native_layer_norm.cpp b/kernels/portable/cpu/op_native_layer_norm.cpp index 12a03a184f6..e97d17acfdf 100644 --- a/kernels/portable/cpu/op_native_layer_norm.cpp +++ b/kernels/portable/cpu/op_native_layer_norm.cpp @@ -73,11 +73,11 @@ void layer_norm( CTYPE* y = out_data + i * normalized; // compute E[X] and Var[x] = E[x^2] - E[x]^2 - CTYPE sum = reduce_add(x, ct_normalized); - CTYPE sq_sum = vec_powerf(x, ct_normalized); - CTYPE mean_value = sum / ct_normalized; - CTYPE variance = sq_sum / ct_normalized - mean_value * mean_value; - CTYPE std = std::sqrt(variance + eps); + float sum = reduce_add(x, ct_normalized); + float sq_sum = vec_powerf(x, ct_normalized); + float mean_value = sum / ct_normalized; + float variance = sq_sum / ct_normalized - mean_value * mean_value; + float std = std::sqrt(variance + eps); // Calculate the elements of output for (const auto j : c10::irange(normalized)) { diff --git a/kernels/portable/cpu/vec_ops.h b/kernels/portable/cpu/vec_ops.h index 87dd05ac7d4..bafc1b24879 100644 --- a/kernels/portable/cpu/vec_ops.h +++ b/kernels/portable/cpu/vec_ops.h @@ -179,7 +179,7 @@ template inline float vec_powerf(const T* x, size_t size) { float sum = 0; for (const auto i : c10::irange(size)) { - sum += x[i] * x[i]; + sum += static_cast(x[i]) * x[i]; } return sum; } diff --git a/kernels/test/op_native_group_norm_test.cpp b/kernels/test/op_native_group_norm_test.cpp index 591df6e186b..2280adf64ee 100644 --- a/kernels/test/op_native_group_norm_test.cpp +++ b/kernels/test/op_native_group_norm_test.cpp @@ -243,6 +243,42 @@ class OpNativeGroupNormTest : public OperatorTest { 0.38038814, 0.75809801}, // expected_rstd_data }, + { + {1, 4, 3}, // sizes + {0.0, + 1000.0, + 2000.0, + 3000.0, + 4000.0, + 5000.0, + 6000.0, + 7000.0, + 8000.0, + 9000.0, + 10000.0, + 11000.0}, // input_data + {1.0, 1.0, 1.0, 1.0}, // weight_data + {0.0, 0.0, 0.0, 0.0}, // bias_data + 1, // N + 4, // C + 3, // HxW + 2, // group + 1e-5, // eps + {-1.46385, + -0.87831, + -0.29277, + 0.29277, + 0.87831, + 1.46385, + -1.46385, + -0.87831, + -0.29277, + 0.29277, + 0.87831, + 1.46385}, // expected_data + {2500.0, 8500.0}, // expected_mean_data + {0.00058554, 0.00058554}, // expected_rstd_data + }, }; run_test_cases(test_cases); diff --git a/kernels/test/op_native_layer_norm_test.cpp b/kernels/test/op_native_layer_norm_test.cpp index d8cc2d3b2e4..930214d238c 100644 --- a/kernels/test/op_native_layer_norm_test.cpp +++ b/kernels/test/op_native_layer_norm_test.cpp @@ -101,6 +101,12 @@ class OpNativeLayerNormTest : public OperatorTest { expected, 1e-2, executorch::runtime::testing::internal::kDefaultBFloat16Atol); + } else if constexpr (DTYPE == ScalarType::Half) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out0, + expected, + 1e-3, + executorch::runtime::testing::internal::kDefaultHalfAtol); } else { EXPECT_TENSOR_CLOSE(out0, expected); } @@ -235,6 +241,21 @@ class OpNativeLayerNormTest : public OperatorTest { 1.38873, -0.46291}, // expected_data }, + { + std::string(__func__) + ": Large variance", + {1, 2, 3}, // sizes + {0.0, 1000.0, 2000.0, 3000.0, 4000.0, 5000.0}, // input_data + {1, 2, 3}, // normalized shape + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, // weights + {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, // bias + 1.0e-5, // eps + {-1.46385, + -0.87831, + -0.29277, + 0.29277, + 0.87831, + 1.46385}, // expected_data + }, }; run_test_cases(test_cases);